├── .gitignore
├── README.md
├── _config.yml
├── _layouts
│   └── default.html
├── config
│   ├── config.yaml
│   ├── net_config_blender_multiview_2view_eval.txt
│   └── net_config_blender_multiview_2view_train.txt
├── dataloaders
│   ├── dataset_TODD.py
│   ├── dataset_blender.py
│   ├── group_multiview.py
│   ├── template.py
│   └── tools.py
├── index.md
├── inference.ipynb
├── media
│   └── main_image.jpg
├── model.jpg
├── model
│   ├── hybrid_depth_decoder.py
│   ├── multiview_net.py
│   ├── networks
│   │   ├── __init__.py
│   │   ├── layers_op.py
│   │   ├── psm_submodule.py
│   │   ├── resnet_encoder.py
│   │   ├── senet.py
│   │   └── senet_submodule.py
│   └── transformer
│       ├── __init__.py
│       └── epipolar_transformer.py
├── net_train_multiview.py
├── requirements.txt
├── runner.sh
├── src
│   ├── LICENSE.md
│   └── lib
│       ├── camera.py
│       ├── color_stuff.py
│       ├── datapoint.py
│       ├── net
│       │   ├── common.py
│       │   ├── dataset.py
│       │   ├── functions
│       │   │   └── learning_rate.py
│       │   ├── init
│       │   │   └── default_init.py
│       │   ├── losses.py
│       │   ├── models
│       │   │   ├── panoptic_net.py
│       │   │   └── simplenet.py
│       │   ├── panoptic_trainer.py
│       │   └── post_processing
│       │       ├── depth_outputs.py
│       │       ├── epnp.py
│       │       ├── eval3d.py
│       │       ├── nms.py
│       │       ├── obb_outputs.py
│       │       ├── pose_outputs.py
│       │       └── segmentation_outputs.py
│       ├── occlusions.py
│       └── transform.py
└── utils
    ├── homo_utils.py
    └── rotation_SVD.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | .idea
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## MVTrans: Multi-View Perception of Transparent Objects (ICRA 2023)
2 |
3 | [**Paper**](https://arxiv.org/abs/2302.11683) | [**Project**](https://ac-rad.github.io/MVTrans/) | [**Video**](https://youtu.be/8Qdc_xWVp-k)
4 |
5 | This repo contains the official implementation of the paper "MVTrans: Multi-View Perception of Transparent Objects".
6 |
7 | ## Introduction
8 | Transparent object perception is a crucial skill for applications such as robot manipulation in household and laboratory settings. Existing methods utilize RGB-D or stereo inputs to handle a subset of perception tasks, including depth and pose estimation. However, transparent object perception remains an open problem. In this paper, we forgo the unreliable depth maps from RGB-D sensors and extend the stereo-based method. Our proposed method, MVTrans, is an end-to-end multi-view architecture with multiple perception capabilities, including depth estimation, segmentation, and pose estimation. Additionally, we establish a novel procedural photo-realistic dataset generation pipeline and create a large-scale transparent object detection dataset, Syn-TODD, which is suitable for training networks with all three modalities: RGB-D, stereo, and multi-view RGB.
9 |
10 |
11 |
12 | ## Installation
13 | Clone the repo, set up a conda environment, and install the required packages:
14 | ```
15 | git clone https://github.com/ac-rad/MVTrans.git && cd MVTrans
16 | conda create -y --prefix ./env python=3.8
17 | ./env/bin/python -m pip install -r requirements.txt
18 | ```
19 | Weights & Biases (wandb) is used to log and visualize training results. Please follow the [instructions](https://docs.wandb.ai/) to set up wandb. To log results to the cloud, insert your wandb login key in `net_train_multiview.py`. Otherwise, to log results locally, run the following command and access the results at localhost:
20 | ```
21 | wandb offline
22 | ```
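For cloud logging, the login key mentioned above can also be supplied programmatically. A minimal sketch (the key value is a placeholder, and the exact placement inside `net_train_multiview.py` is up to you):
```
import wandb

# Alternatively, export WANDB_API_KEY in your environment instead of passing the key here.
wandb.login(key='YOUR_WANDB_API_KEY')
```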
23 |
24 | ## Dataset
25 | Our synthetic transparent object detection dataset (Syn-TODD) can be downloaded [here](https://borealisdata.ca/dataset.xhtml?persistentId=doi:10.5683/SP3/LQKTXE).
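Once downloaded, the data can be loaded with the repository's own helpers. The following is a minimal sketch based on `inference.ipynb`, assuming the 2-view evaluation config; `val_path`, `num_multiview`, and `num_samples` come from the config file:
```
import argparse

from src.lib import datapoint
from src.lib.net import common

# Parse a config file the same way net_train_multiview.py does (argparse @-file syntax).
parser = argparse.ArgumentParser(fromfile_prefix_chars='@')
common.add_train_args(parser)
hparams = parser.parse_args(['@config/net_config_blender_multiview_2view_eval.txt'])

# Build the Syn-TODD ("blender") validation split and wrap it in a data loader.
val_ds = datapoint.make_dataset(hparams.val_path, dataset='blender', multiview=True,
                                num_multiview=hparams.num_multiview,
                                num_samples=hparams.num_samples)
data_loader = common.get_loader(hparams, 'val', datapoint_dataset=val_ds)
```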
26 |
27 | ## Pre-trained Model
28 |
29 | We provide pre-trained model weights for MVTrans trained on the Syn-TODD dataset.
30 |
31 | | Model views | Link |
32 | |-------------|------|
33 | | 2 views | [here](https://borealisdata.ca/api/access/datafile/632196) |
34 | | 3 views | [here](https://borealisdata.ca/api/access/datafile/632197) |
35 | | 5 views | [here](https://borealisdata.ca/api/access/datafile/632195) |
36 |
37 | ## Training
38 | To train MVTrans from scratch, modify the data path and output directory in the configuration files under `config/`, and then run:
39 | ```
40 | ./runner.sh net_train_multiview.py @config/net_config_blender_multiview_{NUM_OF_VIEW}_train.txt
41 | ```
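For example, using the provided 2-view configuration:
```
./runner.sh net_train_multiview.py @config/net_config_blender_multiview_2view_train.txt
```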
42 |
43 | ## Evaluation
44 | To run the evaluation, modify the data path and output directory in the configuration files under `config/`, and then run:
45 | ```
46 | ./runner.sh net_train_multiview.py @config/net_config_blender_multiview_{NUM_OF_VIEW}_eval.txt
47 | ```
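For example, evaluating the 2-view model with the provided configuration:
```
./runner.sh net_train_multiview.py @config/net_config_blender_multiview_2view_eval.txt
```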
48 | ## Inference
49 | To run inference, launch Jupyter Notebook and run `inference.ipynb`.
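The notebook's core steps, condensed into a sketch (the checkpoint path is a placeholder, `hparams` and `data_loader` are built as in the Dataset section above, and the key-prefix stripping mirrors the notebook's `prune_state_dict`):
```
import torch
from importlib.machinery import SourceFileLoader

from src.lib.net.init.default_init import default_init

# Load the model class (res_fpn) directly from its source file, as the notebook does.
net_module = SourceFileLoader('res_fpn', 'model/multiview_net.py').load_module()
model = net_module.res_fpn(hparams)
model.apply(default_init)

# Restore weights; checkpoint keys carry a 6-character prefix (e.g. "model.") that is stripped.
state_dict = torch.load('PATH_TO_CHECKPOINT.ckpt', map_location='cuda:0')['state_dict']
model.load_state_dict({k[6:]: v for k, v in state_dict.items()})
model.cuda().eval()

# One multi-view batch: image is B x V x 3 x H x W, camera_poses is B x V x 4 x 4.
image, camera_poses, camera_intrinsic, seg_target, depth_target, pose_targets, _, scene_name = next(iter(data_loader))
camera_intrinsic = [item.cuda() for item in camera_intrinsic]
seg_out, depth_out, small_depth_out, pose_out, box_out, keypoint_out = model.forward(
    imgs=image.cuda(), cam_poses=camera_poses.cuda(), cam_intr=camera_intrinsic, mode='val')
```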
50 | ## Citation
51 | Please cite our paper:
52 | ```
53 | @misc{wang2023mvtrans,
54 | title={MVTrans: Multi-View Perception of Transparent Objects},
55 | author={Yi Ru Wang and Yuchi Zhao and Haoping Xu and Sagi Eppel and Alan Aspuru-Guzik and Florian Shkurti and Animesh Garg},
56 | year={2023},
57 | eprint={2302.11683},
58 | archivePrefix={arXiv},
59 | primaryClass={cs.RO}
60 | }
61 | ```
62 |
63 | ## Reference
64 | Our MVTrans architecture is built on [SimNet](https://github.com/ToyotaResearchInstitute/simnet) and [ESTDepth](https://github.com/xxlong0/ESTDepth).
65 |
--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
1 | remote_theme: pages-themes/cayman@v0.2.0
2 | plugins:
3 | - jekyll-remote-theme # add this line to the plugins list if you already have one
4 | title: "MVTrans: Multi-View Perception of Transparent Objects"
5 | repository_url: "https://github.com/ac-rad/MVTrans"
6 | show_arxiv: true
7 | arxiv_url: "https://arxiv.org/abs/2302.11683"
8 | show_dataset: true
9 | dataset_url: "https://borealisdata.ca/dataset.xhtml?persistentId=doi:10.5683/SP3/LQKTXE"
10 | description: Yi Ru Wang, Yuchi Zhao, Haoping Xu, Sagi Eppel, Alan Aspuru-Guzik, Florian Shkurti, Animesh Garg
11 |
--------------------------------------------------------------------------------
/_layouts/default.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
10 | 11 |
12 | -------------------------------------------------------------------------------- /inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "cce97a6b", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "import cv2\n", 12 | "import torch\n", 13 | "import os\n", 14 | "import argparse\n", 15 | "from torch.nn import functional as F\n", 16 | "from torch.utils.data import ConcatDataset, DataLoader\n", 17 | "from matplotlib import pyplot as plt\n", 18 | "from PIL import Image\n", 19 | "from importlib.machinery import SourceFileLoader\n", 20 | "\n", 21 | "from src.lib import datapoint, camera, transform\n", 22 | "from src.lib.net import common\n", 23 | "from src.lib.net.init.default_init import default_init\n", 24 | "from src.lib.net.dataset import extract_left_numpy_img\n", 25 | "from src.lib.net.post_processing import epnp\n", 26 | "from src.lib.net.post_processing import nms\n", 27 | "from src.lib.net.post_processing import pose_outputs as poseOut\n", 28 | "from src.lib.net.post_processing import eval3d\n", 29 | "from src.lib.net.post_processing.eval3d import measure_3d_iou, EvalMetrics, measure_ADD\n", 30 | "from src.lib.net.post_processing.segmentation_outputs import draw_segmentation_mask_gt\n", 31 | "from src.lib.net.post_processing.epnp import optimize_for_9D" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "id": "698d14b6", 37 | "metadata": {}, 38 | "source": [ 39 | "# Define Helper Functions" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "id": "aa87c45b", 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "class Detection:\n", 50 | " def __init__(self, camera_T_object=None, scale_matrix=None, box=None, obj_CAD=None):\n", 51 | " self.camera_T_object=camera_T_object #np.ndarray\n", 52 | " self.scale_matrix= scale_matrix # np.ndarray\n", 53 | " self.size_label=\"small\"\n", 54 | " self.ignore = False\n", 55 | " self.box = box \n", 56 | " self.obj_CAD=0\n", 57 | " \n", 58 | "def get_obj_pose_and_bbox(heatmap_output, vertex_output, z_centroid_output, cov_matrices, camera_model):\n", 59 | " peaks = poseOut.extract_peaks_from_centroid(np.copy(heatmap_output), max_peaks=np.inf)\n", 60 | " bboxes_ext = poseOut.extract_vertices_from_peaks(np.copy(peaks), np.copy(vertex_output), np.copy(heatmap_output)) # Shape: List(np.array([8,2])) --> y,x order\n", 61 | " z_centroids = poseOut.extract_z_centroid_from_peaks(np.copy(peaks), np.copy(z_centroid_output))\n", 62 | " cov_matrices = poseOut.extract_cov_matrices_from_peaks(np.copy(peaks), np.copy(cov_matrices))\n", 63 | " poses = []\n", 64 | " for bbox_ext, z_centroid, cov_matrix, peak in zip(bboxes_ext, z_centroids, cov_matrices, peaks):\n", 65 | " bbox_ext_flipped = bbox_ext[:, ::-1] # Switch from yx to xy\n", 66 | " # Solve for pose up to a scale factor\n", 67 | " error, camera_T_object, scale_matrix = optimize_for_9D(bbox_ext_flipped.T, camera_model, solve_for_transforms=True) \n", 68 | " abs_camera_T_object, abs_object_scale = epnp.find_absolute_scale(\n", 69 | " -1.0 * z_centroid, camera_T_object, scale_matrix\n", 70 | " )\n", 71 | " poses.append(transform.Pose(camera_T_object=abs_camera_T_object, scale_matrix=abs_object_scale))\n", 72 | " return poses, bboxes_ext\n", 73 | "\n", 74 | "def get_obj_name(scene):\n", 75 | " return scene[0].split(\"/\")[-3]\n", 76 | "\n", 77 | "def prune_state_dict(state_dict):\n", 78 | " 
for key in list(state_dict.keys()):\n", 79 | " state_dict[key[6:]] = state_dict.pop(key)\n", 80 | " return state_dict" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "id": "15648098", 86 | "metadata": {}, 87 | "source": [ 88 | "# Model Setup" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "id": "d335c955", 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "ckpt_path = 'PATH_TO_CHECKPOINT.ckpt'\n", 99 | "model_file = 'ABSOLUTE_PATH_OF_REPO/models/multiview_net.py'\n", 100 | "model_name = 'res_fpn'\n", 101 | "hparam_file= 'ABSOLUTE_PATH_OF_REPO/config/net_config_blender_multiview_2view_eval.txt'\n", 102 | "model_path = (model_file)\n", 103 | "\n", 104 | "parser = argparse.ArgumentParser(fromfile_prefix_chars='@')\n", 105 | "common.add_train_args(parser)\n", 106 | "hparams = parser.parse_args(['@config/net_config_blender_multiview_2view_eval.txt'])\n", 107 | "\n", 108 | "print('Using model class from:', model_path)\n", 109 | "net_module = SourceFileLoader(model_name, str(model_path)).load_module()\n", 110 | "net_attr = getattr(net_module, model_name)\n", 111 | "model = net_attr(hparams)\n", 112 | "model.apply(default_init)\n", 113 | "\n", 114 | "print('Restoring from checkpoint:', ckpt_path)\n", 115 | "state_dict = torch.load(ckpt_path, map_location='cuda:0')['state_dict']\n", 116 | "state_dict = prune_state_dict(state_dict)\n", 117 | "model.load_state_dict(state_dict)\n", 118 | "\n", 119 | "model.cuda()\n", 120 | "model.eval()" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "id": "a974fa5d", 126 | "metadata": {}, 127 | "source": [ 128 | "# Run inference" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "id": "d81598ac", 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "# load data from dataloader\n", 139 | "val_ds = datapoint.make_dataset(hparams.val_path, dataset = 'blender', multiview = True, num_multiview = hparams.num_multiview, num_samples = hparams.num_samples)\n", 140 | "data_loader = common.get_loader(hparams, \"val\", datapoint_dataset=val_ds)\n", 141 | "data = next(iter(data_loader)) \n", 142 | "step = 1\n", 143 | "obj_name = None\n", 144 | "step_model = 0\n", 145 | "\n", 146 | "# inference\n", 147 | "if hparams.network_type == 'simnet':\n", 148 | " image, seg_target, depth_target, pose_targets, box_targets, keypoint_targets, _, scene_name = data\n", 149 | " seg_output, depth_output, small_depth_output, pose_outputs, box_outputs, keypoint_outputs = model.forward(\n", 150 | " image.cuda(), step = step_model\n", 151 | " )\n", 152 | " step_model +=1\n", 153 | "elif hparams.network_type == 'multiview':\n", 154 | " image, camera_poses, camera_intrinsic, seg_target, depth_target, pose_targets, _, scene_name = data\n", 155 | " camera_intrinsic=[item.cuda() for item in camera_intrinsic]\n", 156 | "\n", 157 | " assert image.shape[1] == camera_poses.shape[1], f'dimension mismatch: num of imgs {image.shape} not equal to num of camera poses {camera_poses.shape}'\n", 158 | "\n", 159 | " seg_output, depth_output, small_depth_output, pose_outputs, box_outputs, keypoint_outputs = model.forward(\n", 160 | " imgs = image.cuda(), cam_poses = camera_poses.cuda(), cam_intr = camera_intrinsic, mode = 'val' \n", 161 | " )\n", 162 | "else:\n", 163 | " raise ValueError(f'Network type not supported: {hparams.network_type}')" 164 | ] 165 | }, 166 | { 167 | "cell_type": "markdown", 168 | "id": "4de8d0e5", 169 | "metadata": {}, 170 | "source": [ 171 | "# Visualization" 172 | ] 173 
| }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "id": "9fce2888", 178 | "metadata": { 179 | "scrolled": false 180 | }, 181 | "outputs": [], 182 | "source": [ 183 | "camera_model = camera.BlenderCamera()\n", 184 | "with torch.no_grad():\n", 185 | " left_image_np = extract_left_numpy_img(image[0], mode = 'multiview')\n", 186 | " depth_vis = depth_output.get_visualization_img(np.copy(left_image_np))\n", 187 | "\n", 188 | " depth_target[0].depth_pred=np.expand_dims(depth_target[0].depth_pred, axis=0)\n", 189 | " depth_target[0].convert_to_torch_from_numpy()\n", 190 | " gt_depth_vis = depth_target[0].get_visualization_img(np.copy(left_image_np))\n", 191 | "\n", 192 | " seg_vis = seg_output.get_visualization_img(np.copy(left_image_np)) \n", 193 | " seg_target[0].convert_to_numpy_from_torch()\n", 194 | " gt_seg_vis = draw_segmentation_mask_gt(np.copy(left_image_np), seg_target[0].seg_pred)\n", 195 | "\n", 196 | " c_img = cv2.cvtColor(np.array(left_image_np), cv2.COLOR_BGR2RGB)\n", 197 | " pose_vis = pose_outputs.get_visualization_img(np.copy(left_image_np), camera_model=camera_model)\n", 198 | " gt_pose_vis = pose_targets[0].get_visualization_img_gt(np.copy(left_image_np), camera_model=camera_model)\n", 199 | "\n", 200 | " # plotting \n", 201 | " rows = 2\n", 202 | " columns = 3\n", 203 | " fig = plt.figure(figsize=(15, 15))\n", 204 | "\n", 205 | " fig.add_subplot(rows, columns, 1)\n", 206 | " plt.imshow(gt_seg_vis)\n", 207 | " plt.axis('off')\n", 208 | " plt.title(\"gt_seg map\")\n", 209 | "\n", 210 | " fig.add_subplot(rows, columns, 2)\n", 211 | " plt.imshow(gt_depth_vis)\n", 212 | " plt.axis('off')\n", 213 | " plt.title(\"gt depth map\")\n", 214 | "\n", 215 | " fig.add_subplot(rows, columns, 3)\n", 216 | " plt.imshow(gt_pose_vis.astype(int))\n", 217 | " plt.axis('off')\n", 218 | " plt.title(\"gt pose vis\")\n", 219 | "\n", 220 | " fig.add_subplot(rows, columns, 4)\n", 221 | " plt.imshow(seg_vis)\n", 222 | " plt.axis('off')\n", 223 | " plt.title(\"seg map\")\n", 224 | "\n", 225 | " fig.add_subplot(rows, columns, 5)\n", 226 | " plt.imshow(depth_vis)\n", 227 | " plt.axis('off')\n", 228 | " plt.title(\"depth map\") \n", 229 | "\n", 230 | " fig.add_subplot(rows, columns, 6)\n", 231 | " plt.imshow(pose_vis.astype(int))\n", 232 | " plt.axis('off')\n", 233 | " plt.title(\"pose vis\")" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": null, 239 | "id": "a69c5d19", 240 | "metadata": {}, 241 | "outputs": [], 242 | "source": [] 243 | } 244 | ], 245 | "metadata": { 246 | "kernelspec": { 247 | "display_name": "Python 3 (ipykernel)", 248 | "language": "python", 249 | "name": "python3" 250 | }, 251 | "language_info": { 252 | "codemirror_mode": { 253 | "name": "ipython", 254 | "version": 3 255 | }, 256 | "file_extension": ".py", 257 | "mimetype": "text/x-python", 258 | "name": "python", 259 | "nbconvert_exporter": "python", 260 | "pygments_lexer": "ipython3", 261 | "version": "3.8.13" 262 | } 263 | }, 264 | "nbformat": 4, 265 | "nbformat_minor": 5 266 | } 267 | -------------------------------------------------------------------------------- /media/main_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ac-rad/MVTrans/8698a19cff8ee6ab317af88dd408e0e098c35504/media/main_image.jpg -------------------------------------------------------------------------------- /model.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ac-rad/MVTrans/8698a19cff8ee6ab317af88dd408e0e098c35504/model.jpg -------------------------------------------------------------------------------- /model/hybrid_depth_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | from utils.homo_utils import * 7 | from model.transformer.epipolar_transformer import EpipolarTransformer 8 | from model.networks.layers_op import convbn, convbnrelu, convbn_3d, convbnrelu_3d, convbntanh_3d 9 | 10 | 11 | def upsample(x): 12 | """Upsample input tensor by a factor of 2 13 | """ 14 | return F.interpolate(x, scale_factor=2, mode="nearest") 15 | 16 | 17 | class ConvBlock(nn.Module): 18 | """Layer to perform a convolution followed by ELU 19 | """ 20 | 21 | def __init__(self, in_channels, out_channels): 22 | super(ConvBlock, self).__init__() 23 | 24 | self.conv = convbn(in_channels, out_channels, 3, 1, 1, 1) 25 | self.nonlin = nn.ReLU(inplace=True) 26 | 27 | def forward(self, x): 28 | out = self.conv(x) 29 | out = self.nonlin(out) 30 | return out 31 | 32 | 33 | def depthlayer(logits, depth_values): 34 | prob_volume = torch.nn.functional.softmax(logits, dim=1) 35 | depth = torch.sum(prob_volume * depth_values, dim=1, keepdim=True) 36 | prob, _ = torch.max(prob_volume, dim=1, keepdim=True) 37 | 38 | return depth, prob 39 | 40 | 41 | class DepthHybridDecoder(nn.Module): 42 | def __init__(self, num_ch_enc, num_output_channels=1, use_skips=True, 43 | ndepths=64, depth_max=10.0, IF_EST_transformer=True): 44 | super(DepthHybridDecoder, self).__init__() 45 | 46 | self.num_output_channels = num_output_channels 47 | self.use_skips = use_skips 48 | self.IF_EST_transformer = IF_EST_transformer 49 | self.upsample_mode = 'nearest' 50 | 51 | self.num_ch_enc = num_ch_enc # [64, 64, 128, 256, 512] 52 | self.num_ch_dec = np.array([16, 32, ndepths, 128, 256]) 53 | 54 | self.ndepths = ndepths 55 | self.depth_max = depth_max 56 | 57 | self.pixel_grid = None 58 | 59 | # decoder 60 | self.upconv_4_0 = ConvBlock(self.num_ch_enc[-1], self.num_ch_dec[4]) 61 | self.upconv_4_1 = ConvBlock(self.num_ch_dec[4] + self.num_ch_enc[3], self.num_ch_dec[4]) 62 | 63 | self.upconv_3_0 = ConvBlock(self.num_ch_dec[4], self.num_ch_dec[3]) 64 | self.upconv_3_1 = ConvBlock(self.num_ch_dec[3] + self.num_ch_enc[2], self.num_ch_dec[3]) 65 | 66 | self.upconv_2_0 = ConvBlock(self.num_ch_dec[3], self.num_ch_dec[2]) 67 | self.upconv_2_1 = ConvBlock(self.num_ch_dec[2] + self.num_ch_enc[1], self.ndepths) 68 | 69 | self.upconv_1_0 = ConvBlock(self.num_ch_dec[2] + self.ndepths, self.num_ch_dec[1]) 70 | self.upconv_1_1 = ConvBlock(self.num_ch_dec[1] + self.num_ch_enc[0], self.num_ch_dec[1]) 71 | self.dispconv_1 = nn.Conv2d(self.num_ch_dec[1], self.num_output_channels, 3, 1, 1, 1, bias=True) 72 | 73 | self.upconv_0_0 = ConvBlock(self.num_ch_dec[1], self.num_ch_dec[0]) 74 | self.upconv_0_1 = ConvBlock(self.num_ch_dec[0], self.num_ch_dec[0]) 75 | self.dispconv_0 = nn.Conv2d(self.num_ch_dec[0], self.num_output_channels, 3, 1, 1, 1, bias=True) 76 | 77 | self.sigmoid = nn.Sigmoid() 78 | self.relu = nn.ReLU(inplace=True) 79 | 80 | base_channels = 32 81 | if self.IF_EST_transformer: 82 | self.epipolar_transformer = EpipolarTransformer(base_channels // 2, base_channels // 2, 3) 83 | 84 | self.dres0 = nn.Sequential(convbnrelu_3d(base_channels, base_channels, 3, 1, 1), 85 | convbnrelu_3d(base_channels, base_channels, 3, 1, 1)) 86 | 87 | self.dres1 = 
nn.Sequential(convbnrelu_3d(base_channels, base_channels, 3, 1, 1), 88 | convbnrelu_3d(base_channels, base_channels, 3, 1, 1)) 89 | 90 | self.dres2 = nn.Sequential(convbnrelu_3d(base_channels + 1, base_channels + 1, 3, 1, 1)) 91 | 92 | self.key_layer = nn.Sequential(convbnrelu_3d(base_channels + 1, base_channels // 2, 3, 1, 1)) 93 | self.value_layer = nn.Sequential(convbntanh_3d(base_channels + 1, base_channels // 2, 3, 1, 1)) 94 | 95 | self.stereo_head0 = nn.Sequential( 96 | convbnrelu_3d(base_channels // 2, base_channels // 2, 3, 1, 1), 97 | nn.Conv3d(base_channels // 2, 1, kernel_size=1, padding=0, stride=1, bias=True) 98 | ) 99 | 100 | self.stereo_head1 = nn.Sequential( 101 | convbnrelu_3d(base_channels // 2, base_channels // 2, 3, 1, 1), 102 | nn.Conv3d(base_channels // 2, 1, kernel_size=1, padding=0, stride=1, bias=True) 103 | ) 104 | 105 | def scale_cam_intr(self, cam_intr, scale): 106 | cam_intr_new = cam_intr.clone() 107 | cam_intr_new[:, :2, :] *= scale 108 | 109 | return cam_intr_new 110 | 111 | def collapse_num(self, x): 112 | if len(x.shape) == 5: 113 | B, NUM, C, H, W = x.shape 114 | x = x.view(B * NUM, C, H, W) 115 | elif len(x.shape) == 6: 116 | B, NUM, C, D, H, W = x.shape 117 | x = x.view(B * NUM, C, D, H, W) 118 | return x 119 | 120 | def expand_num(self, x, NUM): 121 | if len(x.shape) == 4: 122 | B_NUM, C, H, W = x.shape 123 | x = x.view(-1, NUM, C, H, W) 124 | elif len(x.shape) == 5: 125 | B_NUM, C, D, H, W = x.shape 126 | x = x.view(-1, NUM, C, D, H, W) 127 | return x 128 | 129 | def forward_transformer(self, costvolumes, semantic_features, cam_poses, cam_intr, 130 | depth_values, depth_min, depth_interval, 131 | pre_costs=None, pre_cam_poses=None): 132 | """ 133 | try to make it faster 134 | :param costvolumes: list of [N,C,D,H,W] 135 | :param cam_poses: list of [N,4,4] 136 | :param cam_intr: [N,3,3] 137 | :param depth_values: [N, ndepths, 1, 1] 138 | :return: 139 | """ 140 | num = len(costvolumes) 141 | 142 | B, C, D, H, W = costvolumes[0].shape 143 | 144 | depth_values_lowres = depth_values.repeat(1, 1, H, W) 145 | depth_values_highres = depth_values.repeat(1, 1, 4 * H, 4 * W) 146 | 147 | outputs = {} 148 | 149 | if self.pixel_grid is None: 150 | self.pixel_grid = set_id_grid(H, W).to(costvolumes[0].dtype).to(costvolumes[0].device) # [1, 3, H, W] 151 | self.pixel_grid = self.pixel_grid.view(1, 3, 1, H * W).repeat(B, 1, D, 1) # [B, 3, D, H*W] 152 | 153 | # scale 4 154 | x = self.upconv_4_0(semantic_features[4]) 155 | x = [upsample(x)] 156 | if self.use_skips: 157 | x += [semantic_features[3]] 158 | x = torch.cat(x, 1) 159 | x = self.upconv_4_1(x) 160 | 161 | # scale 3 162 | x = self.upconv_3_0(x) 163 | x = [upsample(x)] 164 | if self.use_skips: 165 | x += [semantic_features[2]] 166 | x = torch.cat(x, 1) 167 | x = self.upconv_3_1(x) 168 | 169 | # scale 2 170 | x = self.upconv_2_0(x) 171 | x = [upsample(x)] 172 | if self.use_skips: 173 | x += [semantic_features[1]] 174 | x = torch.cat(x, 1) 175 | semantic_vs = self.upconv_2_1(x) # after relu, [B*num, C, H, W] 176 | 177 | # stack cost volumes together 178 | costvolumes = torch.stack(costvolumes, dim=1) 179 | costvolumes = self.collapse_num(costvolumes) 180 | # 3D matching guidance features 181 | matching_x = self.dres0(costvolumes) 182 | matching_x = self.dres1(matching_x) 183 | x = torch.cat([semantic_vs.unsqueeze(1), matching_x], dim=1) # [B*num,33,D,H,W] 184 | x = self.dres2(x) 185 | 186 | value = self.value_layer(x) 187 | key = self.key_layer(x) 188 | init_logits_ = self.stereo_head0(value).squeeze(1) # 
[B*num,D,H,W] 189 | 190 | init_logits = F.interpolate(init_logits_, scale_factor=4) 191 | 192 | pred_depth_s3, pred_prob_s3 = depthlayer(init_logits, depth_values_highres) 193 | pred_depth_s3 = self.expand_num(pred_depth_s3, num) # [B, num,1,H,W] 194 | pred_prob_s3 = self.expand_num(pred_prob_s3, num) 195 | for img_idx in range(num): 196 | outputs[("depth", img_idx, 3)] = pred_depth_s3[:, img_idx, :, :, :] 197 | outputs[("init_prob", img_idx)] = pred_prob_s3[:, img_idx, :, :, :] 198 | 199 | value = self.expand_num(value, num) 200 | key = self.expand_num(key, num) 201 | values = [value[:, img_idx, :, :, :, :] for img_idx in range(num)] 202 | keys = [key[:, img_idx, :, :, :, :] for img_idx in range(num)] 203 | detached_values = [value.detach() for value in values] 204 | detached_keys = [key.detach() for key in keys] 205 | 206 | ###################################################################### 207 | # transformer 208 | if pre_costs is not None: 209 | cam_poses += pre_cam_poses 210 | values += pre_costs["values"] 211 | keys += pre_costs["keys"] 212 | pre_num = len(pre_cam_poses) 213 | else: 214 | pre_num = 0 215 | 216 | all_fused_logits = [] 217 | for i in range(num): 218 | ref_cam_pose = cam_poses[i] 219 | warped_keys = [] 220 | warped_values = [] 221 | for j in range(num + pre_num): 222 | if i != j: 223 | rel_pose = torch.matmul(cam_poses[j], torch.inverse(ref_cam_pose)) 224 | 225 | warped_key_ = warp_volume(keys[j], depth_values_lowres.view(B, 1, D, H * W), 226 | rel_pose, cam_intr, 227 | self.pixel_grid, depth_min, depth_interval) # [B,C,D,H,W] 228 | 229 | warped_value_ = warp_volume(values[j], depth_values_lowres.view(B, 1, D, H * W), 230 | rel_pose, cam_intr, 231 | self.pixel_grid, depth_min, depth_interval) # [B,C,D,H,W] 232 | 233 | warped_keys.append(warped_key_) 234 | warped_values.append(warped_value_) 235 | 236 | fused_cost = self.epipolar_transformer( 237 | target_key=keys[i], warped_keys=warped_keys, 238 | target_value=values[i], warped_values=warped_values 239 | ) 240 | 241 | values[i] = fused_cost 242 | detached_values[i] = fused_cost.detach() 243 | 244 | fused_logits_ = self.stereo_head1(fused_cost).squeeze(1) 245 | all_fused_logits.append(fused_logits_) 246 | 247 | fused_logits = F.interpolate(fused_logits_, scale_factor=4) 248 | outputs[("depth", i, 2)], outputs[("fused_prob", i)] = depthlayer(fused_logits, depth_values_highres) 249 | 250 | ###################################################################### 251 | # depth refinement 252 | all_fused_logits = torch.stack(all_fused_logits, dim=1) # [B, NUM, D, H, W] 253 | all_fused_logits = self.collapse_num(all_fused_logits) # [B*NUM, D, H, W] 254 | 255 | # scale 1 256 | x = self.upconv_1_0(torch.cat([semantic_vs, self.relu(all_fused_logits)], dim=1)) 257 | x = [upsample(x)] 258 | if self.use_skips: 259 | x += [semantic_features[0]] 260 | x = torch.cat(x, 1) 261 | x = self.upconv_1_1(x) 262 | 263 | pred_depth_s1 = F.interpolate(self.depth_max * self.sigmoid(self.dispconv_1(x)), 264 | scale_factor=2) 265 | pred_depth_s1 = self.expand_num(pred_depth_s1, num) # [B, num,1,H,W] 266 | for img_idx in range(num): 267 | outputs[("depth", img_idx, 1)] = pred_depth_s1[:, img_idx, :, :, :] 268 | 269 | # scale 0 270 | x = self.upconv_0_0(x) 271 | x = [upsample(x)] 272 | x = torch.cat(x, 1) 273 | x = self.upconv_0_1(x) 274 | 275 | pred_depth_s0 = self.depth_max * self.sigmoid(self.dispconv_0(x)) 276 | pred_depth_s0 = self.expand_num(pred_depth_s0, num) # [B, num,1,H,W] 277 | for img_idx in range(num): 278 | outputs[("depth", img_idx, 
0)] = pred_depth_s0[:, img_idx, :, :, :] 279 | 280 | return outputs, {"keys": detached_keys[-1:], "values": detached_values[-1:]}, cam_poses[-1:] 281 | 282 | def forward_notransformer(self, costvolumes, semantic_features, cam_poses, cam_intr, 283 | depth_values, depth_min, depth_interval, 284 | pre_costs=None, pre_cam_poses=None, if_trans_weight=True): 285 | """ 286 | :param costvolumes: list of [N,C,D,H,W] 287 | :param cam_poses: list of [N,4,4] 288 | :param cam_intr: [N,3,3] 289 | :param depth_values: [N, ndepths, H, W] 290 | :return: 291 | """ 292 | num = len(costvolumes) 293 | B, C, D, H, W = costvolumes[0].shape 294 | depth_values_lowres = depth_values.repeat(1, 1, H, W) 295 | depth_values_highres = depth_values.repeat(1, 1, H, W) 296 | outputs = {} 297 | 298 | if self.pixel_grid is None: 299 | self.pixel_grid = set_id_grid(H, W).to(costvolumes[0].dtype).to(costvolumes[0].device) # [1, 3, H, W] 300 | self.pixel_grid = self.pixel_grid.view(1, 3, 1, H * W).repeat(B, 1, D, 1) # [B, 3, D, H*W] 301 | 302 | # scale 4 303 | x = self.upconv_4_0(semantic_features[4]) 304 | x = [upsample(x)] 305 | if self.use_skips: 306 | x += [semantic_features[3]] 307 | x = torch.cat(x, 1) 308 | x = self.upconv_4_1(x) 309 | # scale 3 310 | x = self.upconv_3_0(x) 311 | x = [upsample(x)] 312 | if self.use_skips: 313 | x += [semantic_features[2]] 314 | x = torch.cat(x, 1) 315 | x = self.upconv_3_1(x) 316 | 317 | # scale 2 318 | x = self.upconv_2_0(x) 319 | x = [upsample(x)] 320 | if self.use_skips: 321 | x += [semantic_features[1]] 322 | x = torch.cat(x, 1) 323 | semantic_vs = self.upconv_2_1(x) # after relu, [B*num, C, H, W] 324 | 325 | 326 | # stack cost volumes together 327 | costvolumes = torch.stack(costvolumes, dim=1) 328 | costvolumes = self.collapse_num(costvolumes) 329 | # 3D matching guidance features 330 | matching_x = self.dres0(costvolumes) 331 | matching_x = self.dres1(matching_x) 332 | 333 | x = torch.cat([semantic_vs.unsqueeze(1), matching_x], dim=1) # [B*num,33,D,H,W] 334 | x = self.dres2(x) 335 | 336 | value = self.value_layer(x) 337 | key = self.key_layer(x) 338 | init_logits_ = self.stereo_head0(value).squeeze(1) # [B*num,D,H,W] 339 | init_logits = init_logits_ 340 | 341 | pred_depth_s3, pred_prob_s3 = depthlayer(init_logits, depth_values_highres) 342 | pred_depth_s3 = self.expand_num(pred_depth_s3, num) # [B, num,1,H,W] 343 | pred_prob_s3 = self.expand_num(pred_prob_s3, num) 344 | for img_idx in range(num): 345 | outputs[("depth", img_idx, 3)] = pred_depth_s3[:, img_idx, :, :, :] 346 | outputs[("init_prob", img_idx)] = pred_prob_s3[:, img_idx, :, :, :] 347 | 348 | value_expand = self.expand_num(value, num) 349 | key_expand = self.expand_num(key, num) 350 | values = [value_expand[:, img_idx, :, :, :, :] for img_idx in range(num)] 351 | keys = [key_expand[:, img_idx, :, :, :, :] for img_idx in range(num)] 352 | detached_values = [value.detach() for value in values] 353 | detached_keys = [key.detach() for key in keys] 354 | 355 | ###################################################################### 356 | 357 | all_fused_logits = self.stereo_head1(value).squeeze(1) 358 | fused_logits = all_fused_logits 359 | pred_depth_s2, pred_prob_s2 = depthlayer(fused_logits, depth_values_highres) 360 | pred_depth_s2 = self.expand_num(pred_depth_s2, num) # [B, num,1,H,W] 361 | pred_prob_s2 = self.expand_num(pred_prob_s2, num) 362 | for img_idx in range(num): 363 | outputs[("depth", img_idx, 2)] = pred_depth_s2[:, img_idx, :, :, :] 364 | outputs[("fused_prob", img_idx)] = pred_prob_s2[:, img_idx, :, :, :] 365 
| 366 | return outputs, {"keys": detached_keys[-1:], "values": detached_values[-1:]}, cam_poses[-1:] 367 | 368 | def forward(self, costvolumes, semantic_features, cam_poses, cam_intr, 369 | depth_values, depth_min, depth_interval, 370 | pre_costs=None, pre_cam_poses=None, mode="train"): 371 | 372 | flag = self.IF_EST_transformer & (pre_costs is not None or mode == "train") 373 | 374 | if flag: 375 | return self.forward_transformer(costvolumes, semantic_features, cam_poses, cam_intr, 376 | depth_values, depth_min, depth_interval, 377 | pre_costs, pre_cam_poses) 378 | else: 379 | return self.forward_notransformer(costvolumes, semantic_features, cam_poses, cam_intr, 380 | depth_values, depth_min, depth_interval, 381 | pre_costs, pre_cam_poses) 382 | -------------------------------------------------------------------------------- /model/multiview_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from src.lib.net.models import simplenet 5 | from src.lib.net.post_processing import depth_outputs 6 | from src.lib.net.models.panoptic_net import DepthHead, OBBHead, SegmentationHead, ShapeSpec 7 | from model.networks.psm_submodule import psm_feature_extraction 8 | from model.networks.resnet_encoder import ResnetEncoder 9 | from model.hybrid_depth_decoder import DepthHybridDecoder 10 | from model.networks.layers_op import convbn_3d, convbnrelu_3d 11 | from utils.homo_utils import homo_warping 12 | 13 | 14 | class MultiviewBackbone(nn.Module): 15 | 16 | def __init__(self, hparams, in_channels=3): 17 | super().__init__() 18 | 19 | def make_rgb_stem(): 20 | net = simplenet.NetFactory() 21 | x = net.input(in_dim=3, stride=1, activated=True) 22 | x = net.downscale(x, 32) 23 | x = net.downscale(x, 32) 24 | return net.bake() 25 | 26 | def make_disp_features(): 27 | net = simplenet.NetFactory() 28 | x = net.input(in_dim=1, stride=1, activated=False) 29 | x = net.layer(x, 32, rate=5) 30 | return net.bake() 31 | 32 | self.disp_features = make_disp_features() 33 | 34 | def make_rgbd_backbone(num_channels=64, out_dim=64): 35 | net = simplenet.NetFactory() 36 | x = net.input(in_dim=64, activated=True, stride=4) 37 | x = net._lateral(x, out_dim=num_channels) 38 | x4 = x = net.block(x, '111') 39 | x = net.downscale(x, num_channels * 2) 40 | x8 = x = net.block(x, '1111') 41 | x = net.downscale(x, num_channels * 4) 42 | x = net.block(x, '12591259') 43 | net.tag(net.output(x, out_dim), 'p4') 44 | x = net.upsample(x, x8, out_dim) 45 | net.tag(x, 'p3') 46 | x = net.upsample(x, x4, out_dim) 47 | net.tag(x, 'p2') 48 | return net.bake() 49 | 50 | self.rgbd_backbone = make_rgbd_backbone() 51 | self.reduce_channel = torch.nn.Conv2d(256, 32, 1) 52 | def forward(self, img_features, small_disp, robot_joint_angles=None): 53 | small_disp = small_disp #self.stereo_stem.forward(stacked_img[:, 0:3], stacked_img[:, 3:6]) 54 | left_rgb_features = self.reduce_channel(img_features) 55 | disp_features = self.disp_features(small_disp) 56 | rgbd_features = torch.cat((disp_features, left_rgb_features), axis=1) 57 | outputs = self.rgbd_backbone.forward(rgbd_features) 58 | outputs['small_disp'] = small_disp 59 | return outputs 60 | 61 | 62 | class MultiviewNet(nn.Module): 63 | 64 | def __init__(self, hparams, ndepths=64, depth_min=0.01, depth_max=10.0, resnet=50): 65 | super().__init__() 66 | self.hparams = hparams 67 | self.ndepths = ndepths 68 | self.depth_min = depth_min 69 | self.depth_max = depth_max 70 | self.depth_interval = (depth_max - depth_min) 
/ (ndepths - 1) 71 | 72 | # the to.(torch.float32) is required, if not will be all zeros 73 | self.depth_cands = torch.arange(0, ndepths, requires_grad=False).reshape(1, -1).to( 74 | torch.float32) * self.depth_interval + self.depth_min 75 | self.matchingFeature = psm_feature_extraction() 76 | self.semanticFeature = ResnetEncoder(resnet, "pretrained") # the features after bn and relu 77 | self.multiviewBackbone = MultiviewBackbone(hparams) 78 | self.stage_infos = { 79 | "stage1": { 80 | "scale": 4.0, 81 | }, 82 | "stage2": { 83 | "scale": 2.0, 84 | }, 85 | "stage3": { 86 | "scale": 1.0, 87 | } 88 | } 89 | 90 | self.pre0 = convbn_3d(64, 32, 1, 1, 0) 91 | self.pre1 = convbnrelu_3d(32, 32, 3, 1, 1) 92 | self.pre2 = convbn_3d(32, 32, 3, 1, 1) 93 | 94 | self.CostRegNet = DepthHybridDecoder(self.semanticFeature.num_ch_enc, 95 | num_output_channels=1, use_skips=True, 96 | ndepths=self.ndepths, depth_max=self.depth_max, 97 | IF_EST_transformer=False) 98 | 99 | 100 | # self.backbone = MultiviewBackbone(hparams) 101 | # ResFPN used p2,p3,p4,p5 (64 channels) 102 | # DRN uses only p2,p3,p4 (no need for p5 since dilation increases striding naturally) 103 | backbone_output_shape_4x = { 104 | #'p0': ShapeSpec(channels=64, height=None, width=None, stride=1), 105 | #'p1': ShapeSpec(channels=64, height=None, width=None, stride=2), 106 | 'p2': ShapeSpec(channels=64, height=None, width=None, stride=4), 107 | 'p3': ShapeSpec(channels=64, height=None, width=None, stride=8), 108 | 'p4': ShapeSpec(channels=64, height=None, width=None, stride=16), 109 | #'p5': ShapeSpec(channels=64, height=None, width=None, stride=32), 110 | } 111 | 112 | backbone_output_shape_8x = { 113 | #'p0': ShapeSpec(channels=64, height=None, width=None, stride=1), 114 | #'p1': ShapeSpec(channels=64, height=None, width=None, stride=2), 115 | #'p2': ShapeSpec(channels=64, height=None, width=None, stride=4), 116 | 'p3': ShapeSpec(channels=64, height=None, width=None, stride=8), 117 | 'p4': ShapeSpec(channels=64, height=None, width=None, stride=16), 118 | #'p5': ShapeSpec(channels=64, height=None, width=None, stride=32), 119 | } 120 | 121 | # Add depth head. 122 | self.depth_head = DepthHead(backbone_output_shape_4x, backbone_output_shape_8x, hparams) 123 | # Add segmentation head. 124 | self.seg_head = SegmentationHead(backbone_output_shape_4x, backbone_output_shape_8x, 3, hparams) 125 | # Add pose heads. 126 | self.pose_head = OBBHead(backbone_output_shape_4x, backbone_output_shape_8x, hparams) 127 | 128 | def get_costvolume(self, features, cam_poses, cam_intr, depth_values): 129 | """ 130 | return cost volume, [ref_feature, warped_feature] concat 131 | :param features: middle one is ref feature, others are source features 132 | :param cam_poses: 133 | :param cam_intr: 134 | :param depth_values: 135 | :return: 136 | """ 137 | num_views = len(features) 138 | ref_feature = features[0] 139 | ref_cam_pose = cam_poses[:, 0, :, :] 140 | 141 | ref_extrinsic = torch.inverse(ref_cam_pose) 142 | # step 2. 
differentiable homograph, build cost volume 143 | ref_volume = ref_feature.unsqueeze(2).repeat(1, 1, self.ndepths, 1, 1) 144 | costvolume = torch.zeros_like(ref_volume).to(ref_volume.dtype).to(ref_volume.device) 145 | for view_i in range(num_views): 146 | if view_i == 0: 147 | continue 148 | src_fea = features[view_i] 149 | src_cam_pose = cam_poses[:, view_i, :, :] 150 | src_extrinsic = torch.inverse(src_cam_pose) 151 | # warpped features 152 | src_proj_new = src_extrinsic.clone() 153 | ref_proj_new = ref_extrinsic.clone() 154 | 155 | ref_proj_new = ref_proj_new.to('cpu') 156 | cam_intr = cam_intr.to('cpu') 157 | ref_extrinsic = ref_extrinsic.to('cpu') 158 | src_extrinsic = src_extrinsic.to('cpu') 159 | 160 | src_proj_new[:, :3, :4] = torch.matmul(cam_intr, src_extrinsic[:, :3, :4]) 161 | ref_proj_new[:, :3, :4] = (cam_intr @ ref_extrinsic[:, :3, :4]).clone() 162 | ref_proj_new = ref_proj_new.to('cuda') 163 | cam_intr = cam_intr.to('cuda') 164 | ref_extrinsic = ref_extrinsic.to('cuda') 165 | src_extrinsic = src_extrinsic.to('cuda') 166 | 167 | warped_volume = homo_warping(src_fea, src_proj_new, ref_proj_new, depth_values) 168 | 169 | # it seems that ref_volume - warped_volume not good 170 | x = torch.cat([ref_volume, warped_volume], dim=1) 171 | x = self.pre0(x) 172 | x = x + self.pre2(self.pre1(x)) 173 | 174 | costvolume = costvolume + x 175 | # aggregate multiple feature volumes by variance 176 | costvolume = costvolume / (num_views - 1) 177 | del warped_volume 178 | del x 179 | return costvolume 180 | 181 | def scale_cam_intr(self, cam_intr, scale): 182 | cam_intr_new = cam_intr[0].clone() 183 | cam_intr_new[:, :2, :] *= scale 184 | cam_intr_new[:, :2, :] *= scale 185 | 186 | 187 | return cam_intr_new 188 | def forward(self, imgs, cam_poses, cam_intr, pre_costs=None, pre_cam_poses=None, mode='train'): 189 | """ 190 | input seqs (0,1,2,3,4) target view will be (1,2,3) or input three views 191 | :param imgs: 192 | :param cam_poses: 193 | :param cam_intr: 194 | :param sample: 195 | :return: 196 | """ 197 | imgs = 2 * (imgs / 255.) - 1. 198 | assert len(imgs.shape) == 5, 'expected imgs to be BxVxCxHxW' 199 | batch_size, views_num, _, height_img, width_img = imgs.shape 200 | 201 | height = height_img // 4 202 | width = width_img // 4 203 | 204 | assert views_num >= 2, f'View number should be greater 1, but is {views_num}' # the views_num should be larger than 2 205 | 206 | target_num = 0 207 | 208 | # Convert list of tensors to tensor 209 | assert len(cam_poses.shape) == 4, f'expected shape to be len 4, got {cam_poses.shape}' 210 | 211 | matching_features = self.matchingFeature(imgs.view(batch_size * views_num, 3, height_img, width_img)) 212 | 213 | matching_features = matching_features.view(batch_size, views_num, -1, height, width) 214 | 215 | matching_features = matching_features.permute(1, 0, 2, 3, 4).contiguous() 216 | 217 | semantic_features = self.semanticFeature( 218 | imgs[:, 0].view(batch_size, -1, height_img, width_img)) 219 | 220 | cam_intr_stage1 = self.scale_cam_intr(cam_intr, scale=1. 
/ self.stage_infos["stage1"]["scale"]) 221 | 222 | depth_values = self.depth_cands.view(1, self.ndepths, 1, 1 223 | ).repeat(batch_size, 1, 1, 1).to(imgs.dtype).to(imgs.device) 224 | 225 | target_cam_poses = [] 226 | 227 | # Get the cost volume 228 | cost_volume = self.get_costvolume(matching_features, 229 | cam_poses[:, :, :, :], #bs x views x 4 x 4 230 | cam_intr_stage1, 231 | depth_values) 232 | 233 | outputs, cur_costs, cur_cam_poses = self.CostRegNet(costvolumes = [cost_volume], 234 | semantic_features = semantic_features, 235 | cam_poses = target_cam_poses, 236 | cam_intr = cam_intr_stage1, 237 | depth_values = depth_values, 238 | depth_min = self.depth_min, 239 | depth_interval = self.depth_interval, 240 | pre_costs = pre_costs, 241 | pre_cam_poses = pre_cam_poses, 242 | mode = mode) 243 | 244 | # Convert depth to desired shape 245 | 246 | # Output a small displacement output (H/4, W/4) 247 | small_disp_output = outputs[("depth", 0, 2)] 248 | 249 | assert len(small_disp_output.shape) == 4, f'Expecting depth to be Nx1xHxW, but got {small_disp_output.shape}' 250 | 251 | # Get RGB Features 252 | features = self.multiviewBackbone(img_features = semantic_features[1], 253 | small_disp = small_disp_output) 254 | 255 | small_disp_output = small_disp_output.squeeze(dim=1) 256 | if self.hparams.frozen_stereo_checkpoint is not None: 257 | small_disp_output = small_disp_output.detach() 258 | assert False 259 | small_depth_output = depth_outputs.DepthOutput(small_disp_output, self.hparams.loss_depth_mult) 260 | seg_output = self.seg_head.forward(features) 261 | depth_output = self.depth_head.forward(features) 262 | pose_output = self.pose_head.forward(features) 263 | box_output = None 264 | keypoint_output = None 265 | return seg_output, depth_output, small_depth_output, pose_output, box_output, keypoint_output 266 | 267 | def res_fpn(hparams): 268 | return MultiviewNet(hparams) 269 | 270 | if __name__ == '__main__': 271 | model = MultiviewNet() -------------------------------------------------------------------------------- /model/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ac-rad/MVTrans/8698a19cff8ee6ab317af88dd408e0e098c35504/model/networks/__init__.py -------------------------------------------------------------------------------- /model/networks/layers_op.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | import torch.utils.data 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | import math 8 | 9 | 10 | def convbn(in_planes, out_planes, kernel_size, stride, pad, dilation): 11 | return nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, 12 | padding=dilation if dilation > 1 else pad, dilation=dilation, bias=False), 13 | nn.BatchNorm2d(out_planes)) 14 | 15 | 16 | def convbn_3d(in_planes, out_planes, kernel_size, stride, pad): 17 | return nn.Sequential( 18 | nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, padding=pad, stride=stride, bias=False), 19 | nn.BatchNorm3d(out_planes)) 20 | 21 | 22 | def convbnrelu(in_planes, out_planes, kernel_size, stride, pad, dilation): 23 | return nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, 24 | padding=dilation if dilation > 1 else pad, dilation=dilation, bias=False), 25 | nn.BatchNorm2d(out_planes), 26 | nn.ReLU(inplace=True)) 27 | 
28 | 29 | def convbnrelu_3d(in_planes, out_planes, kernel_size, stride, pad): 30 | return nn.Sequential( 31 | nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, padding=pad, stride=stride, bias=False), 32 | nn.BatchNorm3d(out_planes), 33 | nn.ReLU(inplace=True)) 34 | 35 | def convbntanh_3d(in_planes, out_planes, kernel_size, stride, pad): 36 | return nn.Sequential( 37 | nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, padding=pad, stride=stride, bias=False), 38 | nn.BatchNorm3d(out_planes), 39 | nn.Tanh()) 40 | 41 | -------------------------------------------------------------------------------- /model/networks/psm_submodule.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | import torch.utils.data 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | import math 8 | from utils.homo_utils import * 9 | from model.networks.layers_op import convbn 10 | torch.backends.cudnn.benchmark = False 11 | 12 | 13 | class BasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, inplanes, planes, stride, downsample, pad, dilation): 17 | super(BasicBlock, self).__init__() 18 | 19 | self.conv1 = nn.Sequential(convbn(inplanes, planes, 3, stride, pad, dilation), 20 | nn.ReLU(inplace=True)) 21 | 22 | self.conv2 = convbn(planes, planes, 3, 1, pad, dilation) 23 | 24 | self.downsample = downsample 25 | self.stride = stride 26 | 27 | def forward(self, x): 28 | out = self.conv1(x) 29 | out = self.conv2(out) 30 | 31 | if self.downsample is not None: 32 | x = self.downsample(x) 33 | 34 | out += x 35 | 36 | return out 37 | 38 | 39 | class psm_feature_extraction(nn.Module): 40 | def __init__(self): 41 | super(psm_feature_extraction, self).__init__() 42 | self.inplanes = 32 43 | self.firstconv = nn.Sequential(convbn(3, 32, 3, 2, 1, 1), 44 | nn.ReLU(inplace=True), 45 | convbn(32, 32, 3, 1, 1, 1), 46 | nn.ReLU(inplace=True), 47 | convbn(32, 32, 3, 1, 1, 1), 48 | nn.ReLU(inplace=True)) 49 | 50 | self.layer1 = self._make_layer(block = BasicBlock, planes = 32, blocks = 3, stride = 1, pad = 1, dilation = 1) 51 | self.layer2 = self._make_layer(block = BasicBlock, planes = 64, blocks = 16, stride = 2, pad = 1, dilation = 1) 52 | self.layer3 = self._make_layer(BasicBlock, 128, 3, 1, 1, 1) 53 | self.layer4 = self._make_layer(BasicBlock, 128, 3, 1, 1, 2) 54 | 55 | self.branch1 = nn.Sequential(nn.AvgPool2d((32, 32), stride=(32, 32)), 56 | convbn(in_planes = 128, out_planes = 32, kernel_size = 1, stride = 1, pad = 0, dilation = 1), 57 | nn.ReLU(inplace=True)) 58 | 59 | self.branch2 = nn.Sequential(nn.AvgPool2d((16, 16), stride=(16, 16)), 60 | convbn(128, 32, 1, 1, 0, 1), 61 | nn.ReLU(inplace=True)) 62 | 63 | self.branch3 = nn.Sequential(nn.AvgPool2d((8, 8), stride=(8, 8)), 64 | convbn(128, 32, 1, 1, 0, 1), 65 | nn.ReLU(inplace=True)) 66 | 67 | self.branch4 = nn.Sequential(nn.AvgPool2d((4, 4), stride=(4, 4)), 68 | convbn(128, 32, 1, 1, 0, 1), 69 | nn.ReLU(inplace=True)) 70 | 71 | self.lastconv = nn.Sequential(convbn(320, 128, 3, 1, 1, 1), 72 | nn.ReLU(inplace=True), 73 | nn.Conv2d(128, 32, kernel_size=1, padding=0, stride=1, bias=False)) 74 | # nn.Conv2d(128, 32, kernel_size=1, padding=0, stride=1, bias=False)) 75 | self.out_channels = [32] 76 | 77 | def _make_layer(self, block, planes, blocks, stride, pad, dilation): 78 | downsample = None 79 | if stride != 1 or self.inplanes != planes * block.expansion: 80 | downsample = nn.Sequential( 81 | 
nn.Conv2d(self.inplanes, planes * block.expansion, 82 | kernel_size=1, stride=stride, bias=False), 83 | nn.BatchNorm2d(planes * block.expansion), ) 84 | 85 | layers = [] 86 | layers.append(block(self.inplanes, planes, stride, downsample, pad, dilation)) 87 | self.inplanes = planes * block.expansion 88 | for i in range(1, blocks): 89 | layers.append(block(self.inplanes, planes, 1, None, pad, dilation)) 90 | 91 | return nn.Sequential(*layers) 92 | 93 | def forward(self, x): 94 | output = self.firstconv(x) 95 | output_layer1 = self.layer1(output) 96 | output_raw = self.layer2(output_layer1) 97 | output = self.layer3(output_raw) 98 | output_skip = self.layer4(output) 99 | 100 | output_branch1 = self.branch1(output_skip) 101 | output_branch1 = F.upsample(output_branch1, (output_skip.size()[2], output_skip.size()[3]), mode='bilinear') 102 | 103 | output_branch2 = self.branch2(output_skip) 104 | output_branch2 = F.upsample(output_branch2, (output_skip.size()[2], output_skip.size()[3]), mode='bilinear') 105 | 106 | output_branch3 = self.branch3(output_skip) 107 | output_branch3 = F.upsample(output_branch3, (output_skip.size()[2], output_skip.size()[3]), mode='bilinear') 108 | 109 | output_branch4 = self.branch4(output_skip) 110 | output_branch4 = F.upsample(output_branch4, (output_skip.size()[2], output_skip.size()[3]), mode='bilinear') 111 | 112 | output_feature = torch.cat( 113 | (output_raw, output_skip, output_branch4, output_branch3, output_branch2, output_branch1), 1) 114 | output_feature = self.lastconv(output_feature) 115 | 116 | return output_feature # [1/4 scale], the feature map is not after bn and relu 117 | -------------------------------------------------------------------------------- /model/networks/resnet_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 
6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import numpy as np 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torchvision.models as models 14 | import torch.utils.model_zoo as model_zoo 15 | 16 | 17 | class ResnetEncoder(nn.Module): 18 | """Pytorch module for a resnet encoder 19 | """ 20 | 21 | def __init__(self, num_layers, pretrained, num_input_images=1): 22 | super(ResnetEncoder, self).__init__() 23 | 24 | self.num_ch_enc = np.array([64, 64, 128, 256, 512]) 25 | 26 | resnets = {18: models.resnet18, 27 | 34: models.resnet34, 28 | 50: models.resnet50, 29 | 101: models.resnet101, 30 | 152: models.resnet152} 31 | 32 | if num_layers not in resnets: 33 | raise ValueError("{} is not a valid number of resnet layers".format(num_layers)) 34 | 35 | self.encoder = resnets[num_layers](pretrained) 36 | 37 | if num_layers > 34: 38 | self.num_ch_enc[1:] *= 4 39 | 40 | def forward(self, x): 41 | self.features = [] 42 | 43 | x = self.encoder.conv1(x) 44 | x = self.encoder.bn1(x) 45 | self.features.append(self.encoder.relu(x)) 46 | self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1]))) 47 | self.features.append(self.encoder.layer2(self.features[-1])) 48 | self.features.append(self.encoder.layer3(self.features[-1])) 49 | self.features.append(self.encoder.layer4(self.features[-1])) 50 | 51 | return self.features # the feature maps is activated by relu 52 | -------------------------------------------------------------------------------- /model/networks/senet_submodule.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torch.nn as nn 4 | from networks.psm_submodule import convbn 5 | import torch.nn.functional as F 6 | from .senet import SEModule, Bottleneck 7 | 8 | 9 | class SEBottleneck(Bottleneck): 10 | """ 11 | Bottleneck for SENet154. 
12 | """ 13 | expansion = 4 14 | 15 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 16 | downsample=None): 17 | super(SEBottleneck, self).__init__() 18 | self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False) 19 | self.bn1 = nn.BatchNorm2d(planes * 2) 20 | self.conv2 = nn.Conv2d(planes * 2, planes * 2, kernel_size=3, 21 | stride=stride, padding=1, groups=groups, 22 | bias=False) 23 | self.bn2 = nn.BatchNorm2d(planes * 2) 24 | self.conv3 = nn.Conv2d(planes * 2, planes * 4, kernel_size=1, 25 | bias=False) 26 | self.bn3 = nn.BatchNorm2d(planes * 4) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.se_module = SEModule(planes * 4, reduction=reduction) 29 | self.downsample = downsample 30 | self.stride = stride 31 | 32 | 33 | class se_feature_extraction(nn.Module): 34 | def __init__(self): 35 | super(se_feature_extraction, self).__init__() 36 | self.inplanes = 32 37 | self.firstconv = nn.Sequential(convbn(3, 32, 3, 2, 1, 1), 38 | nn.ReLU(inplace=True), 39 | convbn(32, 32, 3, 1, 1, 1), 40 | nn.ReLU(inplace=True), 41 | convbn(32, 32, 3, 1, 1, 1), 42 | nn.ReLU(inplace=True)) 43 | 44 | self.layer1 = self._make_layer(SEBottleneck, 45 | planes=32, 46 | blocks=3, 47 | stride=1, 48 | groups=32, 49 | reduction=16, 50 | downsample_kernel_size=1, 51 | downsample_padding=0) 52 | self.layer2 = self._make_layer(SEBottleneck, 53 | planes=32, 54 | blocks=3, 55 | stride=2, 56 | groups=32, 57 | reduction=16, 58 | downsample_kernel_size=3, 59 | downsample_padding=1) 60 | self.layer3 = self._make_layer(SEBottleneck, 61 | planes=32, 62 | blocks=3, 63 | stride=1, 64 | groups=32, 65 | reduction=16, 66 | downsample_kernel_size=1, 67 | downsample_padding=0) 68 | self.layer4 = self._make_layer(SEBottleneck, 69 | planes=32, 70 | blocks=3, 71 | stride=1, 72 | groups=32, 73 | reduction=16, 74 | downsample_kernel_size=1, 75 | downsample_padding=0) 76 | 77 | self.branch1 = nn.Sequential(nn.AvgPool2d((32, 32), stride=(32, 32)), 78 | convbn(128, 32, 1, 1, 0, 1), 79 | nn.ReLU(inplace=True)) 80 | 81 | self.branch2 = nn.Sequential(nn.AvgPool2d((16, 16), stride=(16, 16)), 82 | convbn(128, 32, 1, 1, 0, 1), 83 | nn.ReLU(inplace=True)) 84 | 85 | self.branch3 = nn.Sequential(nn.AvgPool2d((8, 8), stride=(8, 8)), 86 | convbn(128, 32, 1, 1, 0, 1), 87 | nn.ReLU(inplace=True)) 88 | 89 | self.branch4 = nn.Sequential(nn.AvgPool2d((4, 4), stride=(4, 4)), 90 | convbn(128, 32, 1, 1, 0, 1), 91 | nn.ReLU(inplace=True)) 92 | 93 | self.lastconv = nn.Sequential(convbn(384, 128, 3, 1, 1, 1), 94 | nn.ReLU(inplace=True), 95 | nn.Conv2d(128, 32, kernel_size=1, padding=0, stride=1, bias=False)) 96 | 97 | self.out_channels = [128, 32] 98 | 99 | def _make_layer(self, block, planes, blocks, groups, reduction, stride=1, 100 | downsample_kernel_size=1, downsample_padding=0): 101 | downsample = None 102 | if stride != 1 or self.inplanes != planes * block.expansion: 103 | downsample = nn.Sequential( 104 | nn.Conv2d(self.inplanes, planes * block.expansion, 105 | kernel_size=downsample_kernel_size, stride=stride, 106 | padding=downsample_padding, bias=False), 107 | nn.BatchNorm2d(planes * block.expansion), 108 | ) 109 | 110 | layers = [] 111 | layers.append(block(self.inplanes, planes, groups, reduction, stride, 112 | downsample)) 113 | self.inplanes = planes * block.expansion 114 | for i in range(1, blocks): 115 | layers.append(block(self.inplanes, planes, groups, reduction)) 116 | 117 | return nn.Sequential(*layers) 118 | 119 | def forward(self, x): 120 | output = self.firstconv(x) 121 | output_layer1 = self.layer1(output) 
122 | output_raw = self.layer2(output_layer1) 123 | output = self.layer3(output_raw) 124 | output_skip = self.layer4(output) 125 | 126 | output_branch1 = self.branch1(output_skip) 127 | output_branch1 = F.upsample(output_branch1, (output_skip.size()[2], output_skip.size()[3]), mode='bilinear') 128 | 129 | output_branch2 = self.branch2(output_skip) 130 | output_branch2 = F.upsample(output_branch2, (output_skip.size()[2], output_skip.size()[3]), mode='bilinear') 131 | 132 | output_branch3 = self.branch3(output_skip) 133 | output_branch3 = F.upsample(output_branch3, (output_skip.size()[2], output_skip.size()[3]), mode='bilinear') 134 | 135 | output_branch4 = self.branch4(output_skip) 136 | output_branch4 = F.upsample(output_branch4, (output_skip.size()[2], output_skip.size()[3]), mode='bilinear') 137 | 138 | output_feature = torch.cat( 139 | (output_raw, output_skip, output_branch4, output_branch3, output_branch2, output_branch1), 1) 140 | output_feature = self.lastconv(output_feature) 141 | 142 | return output_layer1, output_feature # [1/2. 1/4 scale] 143 | -------------------------------------------------------------------------------- /model/transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ac-rad/MVTrans/8698a19cff8ee6ab317af88dd408e0e098c35504/model/transformer/__init__.py -------------------------------------------------------------------------------- /model/transformer/epipolar_transformer.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | import torch.utils.data 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | import math 8 | import numpy as np 9 | 10 | class EpipolarTransformer(nn.Module): 11 | # combine GRU and transformer 12 | def __init__(self, input_channel, output_channel, kernel_size): 13 | super(EpipolarTransformer, self).__init__() 14 | self.softmax = torch.nn.Softmax(dim=-1) 15 | # self.gamma = torch.nn.Parameter(torch.zeros(1)) 16 | 17 | # filters used for gates 18 | gru_input_channel = input_channel + output_channel 19 | self.output_channel = output_channel 20 | 21 | self.gate_conv = nn.Conv3d(gru_input_channel, output_channel * 2, kernel_size, padding=1) 22 | self.reset_gate_norm = nn.GroupNorm(1, output_channel, 1e-5, True) 23 | self.update_gate_norm = nn.GroupNorm(1, output_channel, 1e-5, True) 24 | 25 | # filters used for outputs 26 | self.output_conv = nn.Conv3d(gru_input_channel, output_channel, kernel_size, padding=1) 27 | self.output_norm = nn.GroupNorm(1, output_channel, 1e-5, True) 28 | 29 | self.activation = nn.Tanh() 30 | 31 | def gates(self, x, h): 32 | # x = N x C x D x H x W 33 | # h = N x C x D x H x W 34 | 35 | # c = N x C*2 x D x H x W 36 | c = torch.cat((x, h), dim=1) 37 | f = self.gate_conv(c) 38 | 39 | # r = reset gate, u = update gate 40 | # both are N x O x D x H x W 41 | C = f.shape[1] 42 | r, u = torch.split(f, C // 2, 1) 43 | 44 | rn = self.reset_gate_norm(r) 45 | un = self.update_gate_norm(u) 46 | rns = torch.nn.functional.sigmoid(rn) 47 | uns = torch.nn.functional.sigmoid(un) 48 | return rns, uns 49 | 50 | def output(self, x, h, r, u): 51 | f = torch.cat((x, r * h), dim=1) 52 | o = self.output_conv(f) 53 | on = self.output_norm(o) 54 | return on 55 | 56 | def forward(self, target_key, target_value, warped_values=None, warped_keys=None): 57 | """ 58 | return the fused volume of target_volume 59 | """ 60 | 
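        # When warped views are given: correlate the target key with each warped key
        # (a dot product over the channel dimension), softmax the scores across views,
        # and take the attention-weighted mean of the warped values as the hidden
        # state h -- e.g. with two source views, correlations has shape [B, 1, D, H, W, 2].
        # target_value then updates h through the GRU-style reset/update gates defined
        # above; with no warped views, h starts as zeros and the module reduces to a
        # plain gated transform of target_value.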
B, C, D, H, W = target_value.shape 61 | 62 | if warped_values is not None: 63 | correlations = [] 64 | for key in warped_keys: 65 | correlation = torch.sum(target_key * key, dim=1, keepdim=True) # [B,1,D,H,W] 66 | correlations.append(correlation) 67 | 68 | correlations = torch.stack(correlations, dim=-1) # [B,1,D,H,W, N] 69 | attention_maps = self.softmax(correlations) # [B,1,D,H,W, N] 70 | 71 | values = torch.stack(warped_values, dim=-1) # [B,C,D,H,W, N] 72 | 73 | h = torch.mean(values * attention_maps.repeat(1, C, 1, 1, 1, 1), dim=-1, keepdim=False) 74 | else: 75 | h = None 76 | 77 | HC = self.output_channel 78 | if (h is None): 79 | h = torch.zeros((B, HC, D, H, W), dtype=torch.float, device=target_value.device) 80 | r, u = self.gates(target_value, h) 81 | o = self.output(target_value, h, r, u) 82 | y = self.activation(o) 83 | return u * h + (1 - u) * y 84 | 85 | 86 | -------------------------------------------------------------------------------- /net_train_multiview.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['PYTHONHASHSEED'] = str(1) 3 | import argparse 4 | import random 5 | random.seed(12345) 6 | import numpy as np 7 | np.random.seed(12345) 8 | import torch 9 | torch.manual_seed(12345) 10 | import wandb 11 | 12 | import pytorch_lightning as pl 13 | from pytorch_lightning.callbacks import ModelCheckpoint 14 | from pytorch_lightning import loggers 15 | 16 | from src.lib.net import common 17 | from src.lib import datapoint, camera 18 | from src.lib.net.post_processing.eval3d import Eval3d, extract_objects_from_detections 19 | from src.lib.net.panoptic_trainer import PanopticModel 20 | 21 | _GPU_TO_USE = [0] 22 | _NUM_NODES=2 23 | 24 | 25 | class EvalMethod(): 26 | 27 | def __init__(self, mode = None): 28 | self.eval_3d = Eval3d() 29 | if mode == 'blender': 30 | self.camera_model = camera.BlenderCamera() 31 | else: 32 | raise ValueError 33 | 34 | def process_sample(self, pose_outputs, box_outputs, seg_outputs, detections_gt, scene_name): 35 | detections = pose_outputs.get_detections(self.camera_model) 36 | if scene_name != 'sim': 37 | table_detection, detections_gt, detections = extract_objects_from_detections( 38 | detections_gt, detections 39 | ) 40 | self.eval_3d.process_sample(detections, detections_gt, scene_name) 41 | return True 42 | 43 | def process_all_dataset(self, log): 44 | log['all 3Dmap'] = self.eval_3d.process_all_3D_dataset() 45 | 46 | def draw_detections( 47 | self, pose_outputs, box_outputs, seg_outputs, keypoint_outputs, left_image_np, llog, prefix 48 | ): 49 | pose_vis = pose_outputs.get_visualization_img( 50 | np.copy(left_image_np), camera_model=self.camera_model 51 | ) 52 | llog[f'{prefix}/pose'] = wandb.Image(pose_vis, caption=prefix) 53 | seg_vis = seg_outputs.get_visualization_img(np.copy(left_image_np)) 54 | llog[f'{prefix}/seg'] = wandb.Image(seg_vis, caption=prefix) 55 | 56 | def reset(self): 57 | self.eval_3d = Eval3d() 58 | 59 | def GetLatestCheckpoint(out_folder): 60 | if len(os.listdir(out_folder)) == 0: 61 | return False 62 | else: 63 | max_mtime = 0 64 | for dirname,subdirs,files in os.walk(out_folder): 65 | print(files) 66 | for fname in files: 67 | if not fname.endswith('.ckpt'): 68 | continue 69 | full_path = os.path.join(dirname, fname) 70 | mtime = os.stat(full_path).st_mtime 71 | if mtime > max_mtime: 72 | max_mtime = mtime 73 | max_dir = dirname 74 | max_file = fname 75 | try: 76 | return os.path.join(max_dir,max_file) 77 | except: 78 | return False 79 | 80 | 81 | if __name__ 
== "__main__": 82 | parser = argparse.ArgumentParser(fromfile_prefix_chars='@') 83 | common.add_train_args(parser) 84 | hparams = parser.parse_args() 85 | 86 | # Get the mode of training 87 | if 'simnet' in hparams.train_path: 88 | training_mode = 'simnet' 89 | elif 'blender' or 'synthetic' in hparams.train_path: 90 | training_mode = 'blender' 91 | 92 | print(f'Making the {training_mode} dataset') 93 | if training_mode == 'simnet': 94 | train_ds = datapoint.make_dataset(hparams.train_path) 95 | val_ds = datapoint.make_dataset(hparams.val_path) 96 | elif training_mode == 'blender': 97 | if hparams.network_type == 'multiview': 98 | train_ds = datapoint.make_dataset(hparams.train_path, dataset ='blender', multiview = True, num_multiview = hparams.num_multiview, num_samples = hparams.num_samples) 99 | val_ds = datapoint.make_dataset(hparams.val_path, dataset ='blender', multiview = True, num_multiview = hparams.num_multiview, num_samples = hparams.num_samples) 100 | elif hparams.network_type == 'simnet': 101 | train_ds = datapoint.make_dataset(hparams.train_path, dataset ='blender') 102 | val_ds = datapoint.make_dataset(hparams.val_path, dataset ='blender') 103 | else: 104 | raise ValueError 105 | 106 | samples_per_epoch = len(train_ds.list()) 107 | samples_per_step = hparams.train_batch_size 108 | steps = hparams.max_steps 109 | steps_per_epoch = samples_per_epoch // samples_per_step 110 | epochs = int(np.ceil(steps / steps_per_epoch)) 111 | actual_steps = epochs * steps_per_epoch 112 | print('Samples per epoch', samples_per_epoch) 113 | print('Steps per epoch', steps_per_epoch) 114 | print('Target steps:', steps) 115 | print('Actual steps:', actual_steps) 116 | print('Epochs:', epochs) 117 | 118 | # Login to wandb 119 | wandb.login(key='YOUR_KEY') 120 | 121 | model = PanopticModel(hparams = hparams, 122 | epochs = epochs, 123 | train_dataset = train_ds, 124 | eval_metric = EvalMethod(mode = training_mode), 125 | val_dataset = val_ds) 126 | 127 | model_checkpoint = ModelCheckpoint(dirpath=hparams.output, save_top_k=-1, mode='min', save_last = True, monitor = 'val_loss') 128 | wandb_logger = loggers.WandbLogger(name=hparams.wandb_name, project='simnet') 129 | 130 | # Make output folder if doesn't exist 131 | if not os.path.exists(hparams.output): 132 | os.mkdir(hparams.output) 133 | latest_ckpt = GetLatestCheckpoint(out_folder = hparams.output) 134 | if not latest_ckpt: 135 | trainer = pl.Trainer( 136 | accelerator="gpu", 137 | max_epochs=epochs, 138 | gpus=_GPU_TO_USE, 139 | checkpoint_callback=model_checkpoint, 140 | default_root_dir = hparams.output, 141 | #val_check_interval=0.7, 142 | check_val_every_n_epoch=1, 143 | logger=wandb_logger, 144 | strategy='ddp', 145 | detect_anomaly=True, 146 | ) 147 | else: 148 | print('TRAINING FROM CHECKPOINT!!!!!!!!!!!!!!') 149 | trainer = pl.Trainer( 150 | accelerator="gpu", 151 | max_epochs=epochs, 152 | gpus=_GPU_TO_USE, 153 | checkpoint_callback=model_checkpoint, 154 | default_root_dir = hparams.output, 155 | #val_check_interval=0.7, 156 | check_val_every_n_epoch=1, 157 | logger=wandb_logger, 158 | resume_from_checkpoint = latest_ckpt, 159 | strategy='ddp', 160 | detect_anomaly=True, 161 | ) 162 | 163 | trainer.fit(model) 164 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.13.0 2 | actionlib==1.12.1 3 | addict==2.4.0 4 | aiohttp==3.8.1 5 | aiosignal==1.2.0 6 | angles==1.9.12 7 | anyio==3.6.1 8 | 
argon2-cffi==21.3.0 9 | argon2-cffi-bindings==21.2.0 10 | async-timeout==4.0.2 11 | attrs==21.4.0 12 | Babel==2.10.3 13 | backcall==0.2.0 14 | beautifulsoup4==4.11.1 15 | bleach==5.0.0 16 | bondpy==1.8.5 17 | boto3==1.18.1 18 | botocore==1.21.1 19 | cachetools==4.2.2 20 | camera_calibration==1.15.0 21 | camera_calibration_parsers==1.11.13 22 | catkin==0.7.29 23 | certifi==2021.5.30 24 | cffi==1.15.0 25 | charset-normalizer==2.0.3 26 | click==8.0.1 27 | colour==0.1.5 28 | configparser==5.0.2 29 | controller_manager==0.18.4 30 | controller_manager_msgs==0.18.4 31 | cv_bridge==1.13.0 32 | cycler==0.10.0 33 | dataclasses==0.6 34 | debugpy==1.6.0 35 | decorator==4.4.2 36 | defusedxml==0.7.1 37 | deprecation==2.1.0 38 | diagnostic_analysis==1.9.7 39 | diagnostic_common_diagnostics==1.9.7 40 | diagnostic_updater==1.9.7 41 | docker-pycreds==0.4.0 42 | dynamic_reconfigure==1.6.4 43 | entrypoints==0.4 44 | fastjsonschema==2.15.3 45 | frozenlist==1.3.0 46 | fsspec==2022.5.0 47 | future==0.18.2 48 | gazebo_plugins==2.8.7 49 | gazebo_ros==2.8.7 50 | gencpp==0.6.5 51 | geneus==2.2.6 52 | genlisp==0.4.16 53 | genmsg==0.5.16 54 | gennodejs==2.0.1 55 | genpy==0.6.16 56 | gitdb==4.0.7 57 | GitPython==3.1.18 58 | google-auth==1.33.0 59 | google-auth-oauthlib==0.4.4 60 | greenlet==1.1.0 61 | grpcio==1.38.1 62 | idna==3.2 63 | image_geometry==1.13.0 64 | imageio==2.9.0 65 | importlib-metadata==4.12.0 66 | importlib-resources==5.7.1 67 | interactive-markers==1.11.5 68 | ipykernel==6.13.1 69 | ipython==7.25.0 70 | ipython-genutils==0.2.0 71 | ipywidgets==7.7.0 72 | jedi==0.18.0 73 | Jinja2==3.1.2 74 | jmespath==0.10.0 75 | joblib==1.1.0 76 | joint_state_publisher==1.12.15 77 | joint_state_publisher_gui==1.12.15 78 | json5==0.9.9 79 | jsonschema==4.6.0 80 | jupyter==1.0.0 81 | jupyter-client==7.3.4 82 | jupyter-console==6.4.3 83 | jupyter-core==4.10.0 84 | jupyter-packaging==0.12.2 85 | jupyter-server==1.18.1 86 | jupyterlab==3.4.4 87 | jupyterlab-pygments==0.2.2 88 | jupyterlab-server==2.15.0 89 | jupyterlab-widgets==1.1.0 90 | kdl_parser_py==1.13.1 91 | kiwisolver==1.3.1 92 | laser_geometry==1.6.7 93 | Markdown==3.3.4 94 | MarkupSafe==2.1.1 95 | matplotlib==3.4.2 96 | matplotlib-inline==0.1.2 97 | message_filters==1.14.13 98 | mistune==0.8.4 99 | moveit-core==1.0.10 100 | moveit_commander==1.0.10 101 | moveit_ros_planning_interface==1.0.10 102 | moveit_ros_visualization==1.0.10 103 | msgpack==1.0.2 104 | multidict==6.0.2 105 | nbclassic==0.4.3 106 | nbclient==0.6.4 107 | nbconvert==6.5.0 108 | nbformat==5.4.0 109 | nest-asyncio==1.5.5 110 | networkx==2.5.1 111 | notebook==6.4.12 112 | notebook-shim==0.1.0 113 | numpy==1.21.0 114 | oauthlib==3.1.1 115 | open3d==0.15.2 116 | opencv-python==4.5.3.56 117 | OpenEXR==1.3.8 118 | packaging==21.3 119 | pandas==1.4.3 120 | pandocfilters==1.5.0 121 | parso==0.8.2 122 | pathtools==0.1.2 123 | pexpect==4.8.0 124 | pickleshare==0.7.5 125 | Pillow==9.2.0 126 | prometheus-client==0.14.1 127 | promise==2.3 128 | prompt-toolkit==3.0.19 129 | protobuf==3.17.3 130 | psutil==5.8.0 131 | ptyprocess==0.7.0 132 | pyasn1==0.4.8 133 | pyasn1-modules==0.2.8 134 | pycparser==2.21 135 | pyDeprecate==0.3.2 136 | PyEXR==0.3.10 137 | Pygments==2.9.0 138 | pynvim==0.4.3 139 | pyparsing==2.4.7 140 | pyquaternion==0.9.9 141 | pyrsistent==0.18.1 142 | python-dateutil==2.8.2 143 | python_qt_binding==0.4.4 144 | pytorch-lightning==1.6.0 145 | pytz==2022.1 146 | PyVISA==1.12.0 147 | PyWavelets==1.1.1 148 | PyYAML==5.4.1 149 | pyzmq==23.1.0 150 | qt-dotgraph==0.4.2 151 | qt-gui==0.4.2 152 | 
qt-gui-cpp==0.4.2 153 | qt-gui-py-common==0.4.2 154 | qtconsole==5.3.1 155 | QtPy==2.1.0 156 | requests==2.26.0 157 | requests-oauthlib==1.3.0 158 | resource_retriever==1.12.7 159 | rosbag==1.14.13 160 | rosboost-cfg==1.14.9 161 | rosclean==1.14.9 162 | roscreate==1.14.9 163 | rosgraph==1.14.13 164 | roslaunch==1.14.13 165 | roslib==1.14.9 166 | roslint==0.11.2 167 | roslz4==1.14.13 168 | rosmake==1.14.9 169 | rosmaster==1.14.13 170 | rosmsg==1.14.13 171 | rosnode==1.14.13 172 | rosparam==1.14.13 173 | rospy==1.14.13 174 | rosservice==1.14.13 175 | rostest==1.14.13 176 | rostopic==1.14.13 177 | rosunit==1.14.9 178 | roswtf==1.14.13 179 | rqt-moveit==0.5.10 180 | rqt-reconfigure==0.5.4 181 | rqt-robot-monitor==0.5.13 182 | rqt-rviz==0.7.0 183 | rqt_action==0.4.9 184 | rqt_bag==0.5.1 185 | rqt_bag_plugins==0.5.1 186 | rqt_console==0.4.9 187 | rqt_dep==0.4.9 188 | rqt_graph==0.4.11 189 | rqt_gui==0.5.2 190 | rqt_gui_py==0.5.2 191 | rqt_image_view==0.4.16 192 | rqt_launch==0.4.8 193 | rqt_logger_level==0.4.8 194 | rqt_msg==0.4.8 195 | rqt_nav_view==0.5.7 196 | rqt_plot==0.4.13 197 | rqt_pose_view==0.5.8 198 | rqt_publisher==0.4.8 199 | rqt_py_common==0.5.2 200 | rqt_py_console==0.4.8 201 | rqt_robot_dashboard==0.5.7 202 | rqt_robot_steering==0.5.10 203 | rqt_runtime_monitor==0.5.7 204 | rqt_service_caller==0.4.8 205 | rqt_shell==0.4.9 206 | rqt_srv==0.4.8 207 | rqt_tf_tree==0.6.0 208 | rqt_top==0.4.8 209 | rqt_topic==0.4.11 210 | rqt_web==0.4.8 211 | rsa==4.7.2 212 | rviz==1.13.24 213 | s3transfer==0.5.0 214 | scikit-image==0.18.2 215 | scikit-learn==1.1.2 216 | scipy==1.7.0 217 | Send2Trash==1.8.0 218 | sensor-msgs==1.12.8 219 | sentry-sdk==1.3.0 220 | setproctitle==1.2.3 221 | shortuuid==1.0.1 222 | six==1.16.0 223 | sklearn==0.0 224 | smach==2.0.1 225 | smach_ros==2.0.1 226 | smclib==1.8.5 227 | smmap==4.0.0 228 | sniffio==1.2.0 229 | soupsieve==2.3.2.post1 230 | srdfdom==0.5.2 231 | subprocess32==3.5.4 232 | tensorboard==2.5.0 233 | tensorboard-data-server==0.6.1 234 | tensorboard-plugin-wit==1.8.0 235 | terminado==0.15.0 236 | tf==1.12.1 237 | tf2_geometry_msgs==0.6.5 238 | tf2_kdl==0.6.5 239 | tf2_py==0.6.5 240 | tf2_ros==0.6.5 241 | tf_conversions==1.12.1 242 | threadpoolctl==3.1.0 243 | tifffile==2021.7.2 244 | tinycss2==1.1.1 245 | tomlkit==0.11.2 246 | topic_tools==1.14.13 247 | torch==1.8.1+cu111 248 | torchaudio==0.8.1 249 | torchmetrics==0.9.3 250 | torchvision==0.9.1+cu111 251 | tornado==6.1 252 | tqdm==4.61.2 253 | traitlets==5.2.2.post1 254 | typing_extensions==4.2.0 255 | urdfdom-py==0.4.6 256 | urllib3==1.26.6 257 | visa==1.0.0 258 | wandb==0.12.18 259 | wcwidth==0.2.5 260 | webencodings==0.5.1 261 | websocket-client==1.3.3 262 | Werkzeug==2.0.1 263 | widgetsnbextension==3.6.0 264 | xacro==1.13.17 265 | yarl==1.7.2 266 | ydiff==1.2 267 | zipp==3.8.0 268 | zstandard==0.15.2 269 | -------------------------------------------------------------------------------- /runner.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -Eeuo pipefail 4 | 5 | SCRIPT_DIR=$(dirname $(readlink -f $0)) 6 | export PYTHONPATH=$(readlink -f "${SCRIPT_DIR}") 7 | export OPENBLAS_NUM_THREADS=1 8 | 9 | /home/USER/envs/ENVS_NAME/bin/python $@ 10 | -------------------------------------------------------------------------------- /src/LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Toyota Research Institute (TRI) 4 | 5 | Permission is hereby granted, free of 
charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/lib/camera.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from src.lib import transform 4 | 5 | CLIPPING_PLANE_NEAR = 0.4 6 | SCALE_FACTOR = 4 7 | 8 | 9 | class BlenderCamera: 10 | 11 | def __init__( 12 | self, 13 | hfov_deg=100., 14 | vfov_deg=80., 15 | height=480, 16 | width=640, 17 | stereo_baseline=0.06499999761581421, 18 | enable_noise=False, 19 | override_intrinsics= np.array([[613.9624633789062, 0.0, 320.0, 0.0], 20 | [0.0, 613.9624633789062, 240.0, 0.0], 21 | [0.0, 0.0, 1.0, 0.0], 22 | [0.0, 0.0, 0.0, 1.0]]) 23 | ): 24 | """ The default camera model to match the Basler's rectified implementation at TOT """ 25 | # This is to go from mmt to pyrender frame 26 | self.RT_matrix = transform.Transform.from_aa(axis=transform.X_AXIS, angle_deg=180.0).matrix 27 | if override_intrinsics is not None: 28 | self._set_intrinsics(override_intrinsics) 29 | return 30 | height = height // SCALE_FACTOR 31 | width = width // SCALE_FACTOR 32 | assert height % 64 == 0 33 | assert width % 64 == 0 34 | 35 | self.height = height 36 | self.width = width 37 | 38 | # self.stereo_baseline = stereo_baseline 39 | self.is_left = True 40 | 41 | hfov = np.deg2rad(hfov_deg) 42 | vfov = np.deg2rad(vfov_deg) 43 | focal_length_x = 0.5 * width / np.tan(0.5 * hfov) 44 | focal_length_y = 0.5 * height / np.tan(0.5 * vfov) 45 | 46 | focal_length = focal_length_x 47 | focal_length_ar = focal_length_y / focal_length_x 48 | 49 | self._set_intrinsics( 50 | np.array([ 51 | [focal_length, 0., width / 2., 0.], 52 | [0., focal_length * focal_length_ar, height / 2., 0.], 53 | [0., 0., 1., 0.], 54 | [0., 0., 0., 1.], 55 | ]) 56 | ) 57 | 58 | def add_camera_noise(self, img): 59 | return camera_noise.add(img) 60 | 61 | def make_datapoint(self): 62 | k_matrix = self.K_matrix[:3, :3] 63 | if params.ENABLE_STEREO: 64 | assert self.stereo_baseline is not None 65 | return datapoint.StereoCameraDataPoint( 66 | k_matrix=k_matrix, 67 | baseline=self.stereo_baseline, 68 | ) 69 | return datapoint.CameraDataPoint(k_matrix=k_matrix,) 70 | 71 | def _set_intrinsics(self, intrinsics_matrix): 72 | assert intrinsics_matrix.shape[0] == 4 73 | assert intrinsics_matrix.shape[1] == 4 74 | 75 | self.K_matrix = intrinsics_matrix 76 | self.proj_matrix = self.K_matrix @ self.RT_matrix 77 | 78 | def project(self, points): 79 | """Project 4d homogenous points (4xN) to 
4d homogenous pixels (4xN)""" 80 | assert len(points.shape) == 2 81 | assert points.shape[0] == 4 82 | return self.proj_matrix @ points 83 | 84 | def deproject(self, pixels): 85 | """Deproject 4d homogenous pixels (4xN) to 4d homogenous points (4xN)""" 86 | assert len(pixels.shape) == 2 87 | assert pixels.shape[0] == 4 88 | return np.linalg.inv(self.proj_matrix) @ pixels 89 | 90 | def splat_points(self, hpoints_camera): 91 | """Project 4d homogenous points (4xN) to 4d homogenous points (4xN)""" 92 | assert len(hpoints_camera.shape) == 2 93 | assert hpoints_camera.shape[0] == 4 94 | hpixels = self.project(hpoints_camera) 95 | pixels = convert_homopixels_to_pixels(hpixels) 96 | depths_camera = convert_homopoints_to_points(hpoints_camera)[2, :] 97 | image = np.zeros((self.height, self.width)) 98 | pixel_cols = np.clip(np.round(pixels[0, :]).astype(np.int32), 0, self.width - 1) 99 | pixel_rows = np.clip(np.round(pixels[1, :]).astype(np.int32), 0, self.height - 1) 100 | image[pixel_rows, pixel_cols] = depths_camera < CLIPPING_PLANE_NEAR 101 | return image 102 | 103 | def deproject_depth_image(self, depth_image): 104 | assert depth_image.shape == (self.height, self.width) 105 | v, u = np.indices(depth_image.shape).astype(np.float32) 106 | z = depth_image.reshape((1, -1)) 107 | pixels = np.stack([u.flatten(), v.flatten()], axis=0) 108 | hpixels = convert_pixels_to_homopixels(pixels, z) 109 | hpoints = self.deproject(hpixels) 110 | return hpoints 111 | 112 | 113 | def convert_homopixels_to_pixels(pixels): 114 | """Project 4d homogenous pixels (4xN) to 2d pixels (2xN)""" 115 | assert len(pixels.shape) == 2 116 | assert pixels.shape[0] == 4 117 | pixels_3d = pixels[:3, :] / pixels[3:4, :] 118 | pixels_2d = pixels_3d[:2, :] / pixels_3d[2:3, :] 119 | assert pixels_2d.shape[1] == pixels.shape[1] 120 | assert pixels_2d.shape[0] == 2 121 | return pixels_2d 122 | 123 | 124 | def convert_pixels_to_homopixels(pixels, depths): 125 | """Project 2d pixels (2xN) and depths (meters, 1xN) to 4d pixels (4xN)""" 126 | assert len(pixels.shape) == 2 127 | assert pixels.shape[0] == 2 128 | assert len(depths.shape) == 2 129 | assert depths.shape[1] == pixels.shape[1] 130 | assert depths.shape[0] == 1 131 | pixels_4d = np.concatenate([ 132 | depths * pixels, 133 | depths, 134 | np.ones_like(depths), 135 | ], axis=0) 136 | assert pixels_4d.shape[0] == 4 137 | assert pixels_4d.shape[1] == pixels.shape[1] 138 | return pixels_4d 139 | 140 | 141 | def convert_points_to_homopoints(points): 142 | """Project 3d points (3xN) to 4d homogenous points (4xN)""" 143 | assert len(points.shape) == 2 144 | assert points.shape[0] == 3 145 | points_4d = np.concatenate([ 146 | points, 147 | np.ones((1, points.shape[1])), 148 | ], axis=0) 149 | assert points_4d.shape[1] == points.shape[1] 150 | assert points_4d.shape[0] == 4 151 | return points_4d 152 | 153 | 154 | def convert_homopoints_to_points(points_4d): 155 | """Project 4d homogenous points (4xN) to 3d points (3xN)""" 156 | assert len(points_4d.shape) == 2 157 | assert points_4d.shape[0] == 4 158 | points_3d = points_4d[:3, :] / points_4d[3:4, :] 159 | assert points_3d.shape[1] == points_3d.shape[1] 160 | assert points_3d.shape[0] == 3 161 | return points_3d 162 | -------------------------------------------------------------------------------- /src/lib/color_stuff.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import colour 3 | 4 | 5 | def get_colors(num_colors): 6 | assert num_colors > 0 7 | 8 | colors = 
list(colour.Color("purple").range_to(colour.Color("green"), num_colors)) 9 | color_rgb = 255 * np.array([np.array(a.get_rgb()) for a in colors]) 10 | color_rgb = [a.astype(np.int) for a in color_rgb] 11 | return color_rgb 12 | 13 | 14 | def get_panoptic_colors(): 15 | colors = [ 16 | colour.Color("yellow"), 17 | colour.Color("blue"), 18 | colour.Color("green"), 19 | colour.Color("red"), 20 | colour.Color("purple") 21 | ] 22 | color_rgb = 255 * np.array([np.array(a.get_rgb()) for a in colors]) 23 | color_rgb = [a.astype(np.int) for a in color_rgb] 24 | return color_rgb 25 | 26 | 27 | def get_unique_colors(num_colors): 28 | ''' 29 | Gives a the specified number of unique colors 30 | Args: 31 | num_colors: an int specifying the number of colors 32 | Returs: 33 | A list of rgb colors in the range of (0,255) 34 | ''' 35 | color_rgb = get_colors(num_colors) 36 | 37 | if (num_colors != len(np.unique(color_rgb, axis=0))): 38 | raise ValueError('Colors returned are not unique.') 39 | 40 | return color_rgb 41 | -------------------------------------------------------------------------------- /src/lib/datapoint.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import IPython 3 | import copy 4 | import sys 5 | import io 6 | import operator 7 | import pathlib 8 | import pickle 9 | import shortuuid 10 | 11 | from PIL import Image 12 | import cv2 13 | import boto3 14 | import numpy as np 15 | import zstandard as zstd 16 | 17 | from dataloaders.dataset_blender import BlenderLocalDataset 18 | from dataloaders.dataset_TODD import ToddLocalDataset 19 | 20 | def get_uid(): 21 | return shortuuid.uuid() 22 | 23 | 24 | # Struct for Pose Prediction 25 | @dataclasses.dataclass 26 | class Pose: 27 | heat_map: np.ndarray 28 | vertex_target: np.ndarray 29 | z_centroid: np.ndarray 30 | 31 | 32 | # Struct for Keypoint Prediction 33 | @dataclasses.dataclass 34 | class Keypoint: 35 | heat_map: np.ndarray 36 | 37 | 38 | # Struct for Oriented Bounding Box Prediction 39 | @dataclasses.dataclass 40 | class OBB: 41 | heat_map: np.ndarray 42 | vertex_target: np.ndarray 43 | z_centroid: np.ndarray 44 | cov_matrices: np.ndarray 45 | compressed: bool = False 46 | 47 | def compress(self): 48 | if self.compressed: 49 | return 50 | # Heat map scale by 4x and quantize 51 | height, width = self.heat_map.shape 52 | self.heat_map = cv2.resize( 53 | self.heat_map, (width // 4, height // 4), interpolation=cv2.INTER_CUBIC 54 | ).astype(np.float16) 55 | 56 | # Vertex field, quantize and transpose to make vertex field smooth in memory order (makes 57 | # downstream compression 50x more effective) 58 | self.vertex_target = self.vertex_target.transpose(2, 0, 1).astype(np.float16) 59 | 60 | self.compressed = True 61 | 62 | def decompress(self): 63 | if not self.compressed: 64 | return 65 | 66 | # Heat map scale by 4x and quantize 67 | height, width = self.heat_map.shape 68 | self.heat_map = cv2.resize( 69 | self.heat_map.astype(np.float32), (width * 4, height * 4), interpolation=cv2.INTER_CUBIC 70 | ) 71 | 72 | # Vertex field, quantize and transpose to make vertex field smooth in memory order (makes 73 | # downstream compression 50x more effective) 74 | self.vertex_target = self.vertex_target.astype(np.float32).transpose(1, 2, 0) 75 | self.compressed = False 76 | 77 | 78 | def compress_color_image(img, quality=90): 79 | with io.BytesIO() as buf: 80 | img = Image.fromarray(img) 81 | img.save(buf, format='jpeg', quality=quality) 82 | return buf.getvalue() 83 | 84 | 85 | def 
decompress_color_image(img_bytes): 86 | with io.BytesIO(img_bytes) as buf: 87 | img = Image.open(buf) 88 | return np.array(img) 89 | 90 | 91 | #Struct for Stereo Representation 92 | @dataclasses.dataclass 93 | class Stereo: 94 | left_color: np.ndarray 95 | right_color: np.ndarray 96 | compressed: bool = False 97 | 98 | def compress(self): 99 | if self.compressed: 100 | return 101 | self.left_color = compress_color_image(self.left_color) 102 | self.right_color = compress_color_image(self.right_color) 103 | self.compressed = True 104 | 105 | def decompress(self): 106 | if not self.compressed: 107 | return 108 | 109 | self.left_color = decompress_color_image(self.left_color) 110 | self.right_color = decompress_color_image(self.right_color) 111 | self.compressed = False 112 | 113 | 114 | # Application Specific Datapoints Should be specified here. 115 | @dataclasses.dataclass 116 | class Panoptic: 117 | stereo: Stereo 118 | depth: np.ndarray 119 | segmentation: np.ndarray 120 | object_poses: list 121 | boxes: list 122 | detections: list 123 | keypoints: list = dataclasses.field(default_factory=list) 124 | instance_mask: np.ndarray = None 125 | scene_name: str = 'sim' 126 | uid: str = dataclasses.field(default_factory=get_uid) 127 | compressed: bool = False 128 | 129 | def compress(self): 130 | self.stereo.compress() 131 | for object_pose in self.object_poses: 132 | object_pose.compress() 133 | 134 | if self.compressed: 135 | return 136 | 137 | # Depth scale by 4x and quantize 138 | height, width = self.depth.shape 139 | self.depth = cv2.resize(self.depth, (width // 4, height // 4), 140 | interpolation=cv2.INTER_CUBIC).astype(np.float16) 141 | 142 | self.compressed = True 143 | 144 | def decompress(self): 145 | self.stereo.decompress() 146 | for object_pose in self.object_poses: 147 | object_pose.decompress() 148 | 149 | if not self.compressed: 150 | return 151 | 152 | # Depth scale by 4x and quantize 153 | height, width = self.depth.shape 154 | self.depth = cv2.resize( 155 | self.depth.astype(np.float32), (width * 4, height * 4), interpolation=cv2.INTER_CUBIC 156 | ) 157 | 158 | self.compressed = False 159 | 160 | 161 | # Application Specific Datapoints Should be specified here. 
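# Compression conventions used above: heat maps and depth are stored at 1/4
# resolution as float16 and restored with bicubic upsampling on decompress,
# vertex fields keep full resolution but are stored as float16 in channel-first
# order, and color images are stored as JPEG bytes. compress_datapoint /
# decompress_datapoint below add pickle + zstandard for the on-disk
# *.pickle.zstd files used by LocalDataset and RemoteDataset.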
162 | @dataclasses.dataclass 163 | class RobotMask: 164 | stereo: Stereo 165 | depth: np.ndarray 166 | segmentation: np.ndarray 167 | uid: str = dataclasses.field(default_factory=get_uid) 168 | 169 | 170 | # End Applications Here 171 | def compress_datapoint(x): 172 | x = copy.deepcopy(x) 173 | x.compress() 174 | buf = pickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL) 175 | cctx = zstd.ZstdCompressor() 176 | cbuf = cctx.compress(buf) 177 | return cbuf 178 | 179 | 180 | def decompress_datapoint(cbuf, disable_final_decompression=False): 181 | cctx = zstd.ZstdDecompressor() 182 | buf = cctx.decompress(cbuf) 183 | x = pickle.loads(buf) 184 | if not disable_final_decompression: 185 | x.decompress() 186 | 187 | for i in range(len(x.object_poses)): 188 | print(f'Pose: {i}') 189 | # print('vertex_target: ', x.object_poses[i].vertex_target.shape, np.unique(x.object_poses[i].vertex_target)) 190 | print('z_centroid: ', x.object_poses[i].z_centroid.shape, np.unique(x.object_poses[i].z_centroid)) 191 | print('cov_matrices: ', x.object_poses[i].cov_matrices.shape, np.unique(x.object_poses[i].cov_matrices)) 192 | print('heat_map: ', x.object_poses[i].heat_map.shape,np.unique(x.object_poses[i].heat_map)) 193 | if np.unique(x.object_poses[i].cov_matrices).shape[0] > 1: 194 | # Produce segmentation based on covariance matrix values 195 | unique_vectors = np.unique(x.object_poses[i].cov_matrices, axis = 2) 196 | cov_matrix = x.object_poses[i].cov_matrices.copy() 197 | z_cent = x.object_poses[i].z_centroid.copy() 198 | for v_idx, (vector, z) in enumerate(zip(unique_vectors,np.unique(x.object_poses[i].cov_matrices))): 199 | cov_matrix[cov_matrix == vector] = v_idx / 4 200 | z_cent[z_cent == z] = v_idx/4 201 | 202 | # Plot the covariance matrix segmentation 203 | import matplotlib.pyplot as plt 204 | plt.plot() 205 | plt.imshow(cov_matrix[:,:,0], cmap='Greys', interpolation='nearest') 206 | plt.savefig('/h/helen/transparent-perception/tests/blkwht.png') 207 | plt.close() 208 | plt.plot() 209 | plt.imshow(z_cent, cmap='Greys', interpolation='nearest') 210 | plt.savefig('/h/helen/transparent-perception/tests/blkwht_z.png') 211 | plt.close() 212 | plt.plot() 213 | plt.imshow(x.object_poses[i].heat_map, interpolation='nearest') 214 | plt.savefig('/h/helen/transparent-perception/tests/blkwht_heat.png') 215 | plt.close() 216 | plt.plot() 217 | plt.imshow(x.stereo.left_color, interpolation='nearest') 218 | plt.savefig('/h/helen/transparent-perception/tests/blkwht_color.png') 219 | plt.close() 220 | raise ValueError 221 | return x 222 | 223 | def get_segmentation(path): 224 | return 0 225 | 226 | def make_dataset(uri, dataset = 'simnet', multiview = False, num_multiview = 2, num_samples = 43): 227 | if dataset == 'simnet': 228 | if ',' in uri: 229 | datasets = [] 230 | for uri in uri.split(','): 231 | datasets.append(make_one_dataset(uri)) 232 | return ConcatDataset(datasets) 233 | return make_one_dataset(uri) 234 | elif dataset == 'blender': 235 | uri, _, raw_params = uri.partition('?') 236 | path = uri.partition('file://')[2] 237 | dataset_path = pathlib.Path(path) 238 | return BlenderLocalDataset(dataset_path, multiview = multiview, num_views = num_multiview, num_samples = num_samples) 239 | elif dataset == 'TODD': 240 | uri, _, raw_params = uri.partition('?') 241 | path = uri.partition('file://')[2] 242 | dataset_path = pathlib.Path(path) 243 | return ToddLocalDataset(dataset_path, multiview = multiview, num_views = num_multiview, num_samples = num_samples) 244 | else: 245 | raise ValueError(f'Expected dataset to 
be simnet, blender, but got {dataset}') 246 | 247 | 248 | def make_one_dataset(uri): 249 | # parse parameters 250 | uri, _, raw_params = uri.partition('?') 251 | dataset = make_one_simple_dataset(uri) 252 | if not raw_params: 253 | return dataset 254 | 255 | params = {} 256 | for raw_param in raw_params.split('&'): 257 | k, _, v = raw_param.partition('=') 258 | assert k and v 259 | assert k not in params 260 | params[k] = v 261 | return FilterDataset(dataset, params) 262 | 263 | 264 | def make_one_simple_dataset(uri): 265 | if uri.startswith('s3://'): 266 | path = uri.partition('s3://')[2] 267 | bucket, _, dataset_path = path.partition('/') 268 | return RemoteDataset(bucket, dataset_path) 269 | 270 | if uri.startswith('file://'): 271 | path = uri.partition('file://')[2] 272 | dataset_path = pathlib.Path(path) 273 | return LocalDataset(dataset_path) 274 | raise ValueError(f'uri must start with `s3://` or `file://`. uri={uri}') 275 | 276 | 277 | def _datapoint_path(dataset_path, uid): 278 | return f'{dataset_path}/{uid}.pickle.zstd' 279 | 280 | 281 | class FilterDataset: 282 | 283 | def __init__(self, dataset, params): 284 | self.dataset = dataset 285 | self.params = params 286 | self.samples = None 287 | for key in params: 288 | if key == 'samples': 289 | self.samples = int(params[key]) 290 | else: 291 | raise ValueError(f'Unknown param in dataset args: {key}') 292 | 293 | def list(self): 294 | handles = self.dataset.list() 295 | if self.samples is not None: 296 | handles = handles * (self.samples // len(handles) + 1) 297 | handles = handles[:self.samples] 298 | return handles 299 | 300 | def write(self, datapoint): 301 | raise ValueError('Cannot write to concat dataset') 302 | 303 | 304 | class ConcatDataset: 305 | 306 | def __init__(self, datasets): 307 | self.datasets = datasets 308 | 309 | def list(self): 310 | handles = [] 311 | for dataset in self.datasets: 312 | handles.extend(dataset.list()) 313 | return handles 314 | 315 | def write(self, datapoint): 316 | raise ValueError('Cannot write to concat dataset') 317 | 318 | 319 | class RemoteDataset: 320 | 321 | def __init__(self, bucket, path): 322 | self.s3 = boto3.resource('s3') 323 | self.bucket = bucket 324 | self.dataset_path = path 325 | assert not path.endswith('/') 326 | self._cache_list = None 327 | 328 | def list(self): 329 | if self._cache_list is not None: 330 | return self._cache_list 331 | bucket = self.s3.Bucket(self.bucket) 332 | handles = [] 333 | for obj in bucket.objects.filter(Prefix=self.dataset_path + '/'): 334 | path = obj.key 335 | if not path.endswith('.pickle.zstd'): 336 | continue 337 | uid = path.rpartition('/')[2].partition('.pickle.zstd')[0] 338 | handles.append(RemoteReadHandle(self.bucket, self.dataset_path, uid)) 339 | x = sorted(handles, key=operator.attrgetter('uid')) 340 | self._cache_list = x 341 | return x 342 | 343 | def write(self, datapoint): 344 | buf = compress_datapoint(datapoint) 345 | path = _datapoint_path(self.dataset_path, datapoint.uid) 346 | self.s3.Bucket(self.bucket).put_object(Key=path, Body=buf) 347 | 348 | 349 | class LocalDataset: 350 | 351 | def __init__(self, dataset_path): 352 | if not dataset_path.exists(): 353 | print('New dataset directory:', dataset_path) 354 | dataset_path.mkdir(parents=True) 355 | assert dataset_path.is_dir() 356 | self.dataset_path = dataset_path 357 | 358 | def list(self): 359 | handles = [] 360 | for path in self.dataset_path.glob('*.pickle.zstd'): 361 | uid = path.name.partition('.')[0] 362 | handles.append(LocalReadHandle(self.dataset_path, uid)) 363 
| return sorted(handles, key=operator.attrgetter('uid')) 364 | 365 | def write(self, datapoint): 366 | path = _datapoint_path(self.dataset_path, datapoint.uid) 367 | buf = compress_datapoint(datapoint) 368 | with open(path, 'wb') as fh: 369 | fh.write(buf) 370 | 371 | 372 | class LocalReadHandle: 373 | 374 | def __init__(self, dataset_path, uid): 375 | self.dataset_path = dataset_path 376 | self.uid = uid 377 | 378 | def read(self, disable_final_decompression=False): 379 | path = _datapoint_path(self.dataset_path, self.uid) 380 | with open(path, 'rb') as fh: 381 | dp = decompress_datapoint(fh.read(), disable_final_decompression=disable_final_decompression) 382 | if not hasattr(dp, 'uid'): 383 | dp.uid = self.uid 384 | assert dp.uid == self.uid 385 | return dp 386 | 387 | 388 | if __name__ == '__main__': 389 | ds = make_dataset(sys.argv[1]) 390 | for dph in ds.list(): 391 | dp = dph.read() 392 | IPython.embed() 393 | break 394 | -------------------------------------------------------------------------------- /src/lib/net/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import ConcatDataset, DataLoader 3 | 4 | from src.lib.net.init.default_init import default_init 5 | from src.lib.net.dataset import Dataset 6 | 7 | 8 | def add_dataset_args(parser, prefix): 9 | group = parser.add_argument_group("{}_dataset".format(prefix)) 10 | group.add_argument("--{}_path".format(prefix), type=str, required=True) 11 | group.add_argument("--{}_batch_size".format(prefix), default=16, type=int) 12 | group.add_argument("--{}_num_workers".format(prefix), default=7, type=int) 13 | 14 | 15 | def add_train_args(parser): 16 | parser.add_argument("--max_steps", type=int, required=True) 17 | parser.add_argument("--output", type=str, required=True) 18 | parser.add_argument("--network_type", type=str, required=True) # multiview, simnet 19 | parser.add_argument("--num_multiview", type=int, required=True) # 2,3,5 20 | parser.add_argument("--num_samples", type=int, required=True) # 43, 56 21 | add_dataset_args(parser, "train") 22 | add_dataset_args(parser, "val") 23 | 24 | optim_group = parser.add_argument_group("optim") 25 | optim_group.add_argument("--optim_type", default='sgd', type=str) 26 | optim_group.add_argument("--optim_learning_rate", default=0.02, type=float) 27 | optim_group.add_argument("--optim_momentum", default=0.9, type=float) 28 | optim_group.add_argument("--optim_weight_decay", default=1e-4, type=float) 29 | optim_group.add_argument("--optim_poly_exp", default=0.9, type=float) 30 | optim_group.add_argument("--optim_warmup_epochs", default=None, type=int) 31 | parser.add_argument("--model_file", type=str, required=True) 32 | parser.add_argument("--model_name", type=str, required=True) 33 | parser.add_argument("--checkpoint", default=None, type=str) 34 | parser.add_argument("--wandb_name", type=str, required=True) 35 | # Ignore Mask Search. 
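  # Note: these training flags are normally provided through an @-prefixed argument
  # file rather than on the command line, since net_train_multiview.py constructs
  # its parser with fromfile_prefix_chars='@'
  # (e.g. `python net_train_multiview.py @some_config.txt`; the config filename
  # here is only illustrative).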
36 | parser.add_argument("--min_height", default=0.0, type=float) 37 | parser.add_argument("--min_occlusion", default=0.0, type=float) 38 | parser.add_argument("--min_truncation", default=0.0, type=float) 39 | # Backbone configs 40 | parser.add_argument("--model_norm", default='BN', type=str) 41 | parser.add_argument("--num_filters_scale", default=4, type=int) 42 | 43 | # Loss weights 44 | parser.add_argument("--frozen_stereo_checkpoint", default=None, type=str) 45 | parser.add_argument("--loss_seg_mult", default=1.0, type=float) 46 | parser.add_argument("--loss_depth_mult", default=1.0, type=float) 47 | parser.add_argument("--loss_depth_refine_mult", default=1.0, type=float) 48 | parser.add_argument("--loss_heatmap_mult", default=100.0, type=float) 49 | parser.add_argument("--loss_vertex_mult", default=0.1, type=float) 50 | parser.add_argument("--loss_z_centroid_mult", default=0.1, type=float) 51 | parser.add_argument("--loss_rotation_mult", default=0.1, type=float) 52 | parser.add_argument("--loss_keypoint_mult", default=0.1, type=float) 53 | # Stereo Stem Args 54 | parser.add_argument( 55 | "--loss_disparity_stdmean_scaled", 56 | action="store_true", 57 | help="If true, the loss will be scaled based on the standard deviation and mean of the " 58 | "ground truth disparities" 59 | ) 60 | parser.add_argument("--cost_volume_downsample_factor", default=4, type=int) 61 | parser.add_argument("--max_disparity", default=90, type=int) 62 | parser.add_argument( 63 | "--fe_features", 64 | default=16, 65 | type=int, 66 | help="Number of output features in feature extraction stage" 67 | ) 68 | parser.add_argument( 69 | "--fe_internal_features", 70 | default=32, 71 | type=int, 72 | help="Number of features in the first block of the feature extraction" 73 | ) 74 | # keypoint head args 75 | parser.add_argument("--num_keypoints", default=1, type=int) 76 | 77 | 78 | def get_config_value(hparams, prefix, key): 79 | full_key = "{}_{}".format(prefix, key) 80 | if hasattr(hparams, full_key): 81 | return getattr(hparams, full_key) 82 | else: 83 | return None 84 | 85 | 86 | def get_loader(hparams, prefix, preprocess_func=None, datapoint_dataset=None): 87 | datasets = [] 88 | path = get_config_value(hparams, prefix, 'path') 89 | datasets.append( 90 | Dataset( 91 | path, hparams, preprocess_image_func=preprocess_func, datapoint_dataset=datapoint_dataset 92 | ) 93 | ) 94 | batch_size = get_config_value(hparams, prefix, "batch_size") 95 | if hparams.network_type == 'simnet': 96 | collate_fn = simnet_collate 97 | elif hparams.network_type == 'multiview': 98 | collate_fn = multiview_collate 99 | return DataLoader( 100 | ConcatDataset(datasets), 101 | batch_size=batch_size, 102 | collate_fn=collate_fn, 103 | num_workers=get_config_value(hparams, prefix, "num_workers"), 104 | pin_memory=True, 105 | shuffle=True, 106 | drop_last=True 107 | ) 108 | 109 | 110 | def simnet_collate(batch): 111 | # list of elements per patch 112 | # Each element is a tuple of (stereo,imgs) 113 | targets = [] 114 | for ii in range(len(batch[0])): 115 | targets.append([batch_element[ii] for batch_element in batch]) 116 | 117 | stacked_images = torch.stack(targets[0]) 118 | 119 | return stacked_images, targets[1], targets[2], targets[3], targets[4], targets[5], targets[ 120 | 6], targets[7] 121 | 122 | def multiview_collate(batch): 123 | # list of elements per patch 124 | # Each element is a tuple of (stereo,imgs) 125 | targets = [] 126 | for ii in range(len(batch[0])): 127 | targets.append([batch_element[ii] for batch_element in batch]) 128 | 
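  # Each batch element from Dataset.getMultiviewSample (src/lib/net/dataset.py) is a
  # 10-tuple: targets[0] holds per-sample image stacks of shape [num_views, 3, H, W]
  # built by CombineMultiview, and targets[1] the matching [num_views, 4, 4] camera
  # poses, so the stacked tensors below are [batch, num_views, 3, H, W] and
  # [batch, num_views, 4, 4].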
129 | stacked_images = torch.stack(targets[0]) 130 | stacked_camera_poses = torch.stack(targets[1]) 131 | stacked_intrinsics = torch.stack(targets[2]) 132 | 133 | return stacked_images, stacked_camera_poses, targets[2], targets[3], targets[4], targets[5], targets[ 134 | 6], targets[9] 135 | 136 | 137 | def prune_state_dict(state_dict): 138 | for key in list(state_dict.keys()): 139 | state_dict[key[6:]] = state_dict.pop(key) 140 | return state_dict 141 | 142 | 143 | def keep_only_stereo_weights(state_dict): 144 | pruned_state_dict = {} 145 | for key in list(state_dict.keys()): 146 | if 'stereo' in key: 147 | pruned_state_dict[key] = state_dict[key] 148 | return pruned_state_dict 149 | 150 | 151 | def get_model(hparams): 152 | if hparams.network_type == 'simnet': 153 | from simnet.lib.net.models import panoptic_net 154 | net_module = panoptic_net 155 | # net_module = SourceFileLoader(hparams.model_name, str(model_path)).load_module() 156 | net_attr = getattr(net_module, hparams.model_name) 157 | model = net_attr(hparams) 158 | elif hparams.network_type == 'multiview': 159 | from model import multiview_net 160 | net_module = multiview_net 161 | net_attr = getattr(net_module, hparams.model_name) 162 | model = net_attr(hparams) 163 | 164 | model.apply(default_init) 165 | 166 | # For large models use imagenet weights. 167 | # This speeds up training and can give a +2 mAP score on car detections 168 | if hparams.num_filters_scale == 1: 169 | model.load_imagenet_weights() 170 | 171 | if hparams.frozen_stereo_checkpoint is not None: 172 | print('Restoring stereo weights from checkpoint:', hparams.frozen_stereo_checkpoint) 173 | state_dict = torch.load(hparams.frozen_stereo_checkpoint, map_location='cpu')['state_dict'] 174 | state_dict = prune_state_dict(state_dict) 175 | state_dict = keep_only_stereo_weights(state_dict) 176 | model.load_state_dict(state_dict, strict=False) 177 | 178 | if hparams.checkpoint is not None: 179 | print('Restoring from checkpoint:', hparams.checkpoint) 180 | state_dict = torch.load(hparams.checkpoint, map_location='cpu')['state_dict'] 181 | state_dict = prune_state_dict(state_dict) 182 | model.load_state_dict(state_dict, strict=False) 183 | return model 184 | -------------------------------------------------------------------------------- /src/lib/net/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Toyota Research Institute. All rights reserved. 
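# Dataset wrappers that turn stored datapoints into network-ready tensors: the
# stereo path packs normalized left/right images into a 6-channel anaglyph, while
# the multiview path stacks the left images of all views and additionally returns
# per-view camera extrinsics and the shared intrinsics; both paths replace NaN/Inf
# values in the depth and pose targets before wrapping them in the
# post_processing output classes.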
2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | from src.lib import datapoint 9 | from src.lib.net.post_processing import obb_outputs, depth_outputs, segmentation_outputs 10 | 11 | 12 | def extract_left_numpy_img(anaglyph, mode = None): 13 | if mode == 'simnet': 14 | anaglyph_np = np.ascontiguousarray(anaglyph.cpu().numpy()) 15 | anaglyph_np = anaglyph_np.transpose((1, 2, 0)) 16 | left_img = anaglyph_np[..., 0:3] * 255.0 17 | elif mode == 'multiview': 18 | anaglyph_np = np.ascontiguousarray(anaglyph[0].cpu().numpy()) 19 | anaglyph_np = anaglyph_np.transpose((1, 2, 0)) 20 | left_img = anaglyph_np[..., 0:3] * 255.0 21 | return left_img 22 | 23 | 24 | def extract_right_numpy_img(anaglyph): 25 | anaglyph_np = np.ascontiguousarray(anaglyph.cpu().numpy()) 26 | anaglyph_np = anaglyph_np.transpose((1, 2, 0)) 27 | left_img = anaglyph_np[..., 3:6] * 255.0 28 | return left_img 29 | 30 | 31 | def create_anaglyph(stereo_dp): 32 | height, width, _ = stereo_dp.left_color.shape 33 | image = np.zeros([height, width, 6], dtype=np.uint8) 34 | cv2.normalize(stereo_dp.left_color, stereo_dp.left_color, 0, 255, cv2.NORM_MINMAX) 35 | cv2.normalize(stereo_dp.right_color, stereo_dp.right_color, 0, 255, cv2.NORM_MINMAX) 36 | image[..., 0:3] = stereo_dp.left_color 37 | image[..., 3:6] = stereo_dp.right_color 38 | image = image * 1. / 255.0 39 | image = image.transpose((2, 0, 1)) # 3xHxW 40 | return torch.from_numpy(np.ascontiguousarray(image)).float() 41 | 42 | def CombineMultiview(stereo_dps): 43 | height, width, _ = stereo_dps[0].left_color.shape 44 | image = np.zeros([height, width, 3*len(stereo_dps)], dtype=np.uint8) 45 | images_combined = [] 46 | for index, stereo_dp in enumerate(stereo_dps): 47 | cv2.normalize(stereo_dp.left_color, stereo_dp.left_color, 0, 255, cv2.NORM_MINMAX) 48 | image = stereo_dp.left_color.transpose((2,0,1)) 49 | image = image * 1. 
/ 255.0 50 | images_combined.append(image) 51 | 52 | return torch.from_numpy(np.ascontiguousarray(images_combined)).float() 53 | 54 | class Dataset(Dataset): 55 | 56 | def __init__(self, dataset_uri, hparams, preprocess_image_func=None, datapoint_dataset=None): 57 | super().__init__() 58 | 59 | if datapoint_dataset is None: 60 | datapoint_dataset = datapoint.make_dataset(dataset_uri) 61 | self.datapoint_handles = datapoint_dataset.list() 62 | print(len(self.datapoint_handles)) 63 | # No need to shuffle, already shufled based on random uids 64 | self.hparams = hparams 65 | 66 | if preprocess_image_func is None: 67 | self.preprocces_image_func = create_anaglyph 68 | else: 69 | assert False 70 | self.preprocces_image_func = preprocess_image_func 71 | 72 | def __len__(self): 73 | return len(self.datapoint_handles) 74 | 75 | def getMultiviewSample(self, idx): 76 | 77 | try: 78 | dp_list = [dp.read() for dp in self.datapoint_handles[idx]] 79 | except: 80 | dp_list = [dp for dp in self.datapoint_handles[idx]] 81 | 82 | # Get anaglyph 83 | stereo_list = [dp.stereo for dp in dp_list] 84 | anaglyph = CombineMultiview(stereo_list) 85 | 86 | # Get segmentation 87 | segmentation_target = segmentation_outputs.SegmentationOutput(dp_list[0].segmentation, self.hparams) 88 | segmentation_target.convert_to_torch_from_numpy() 89 | 90 | scene_name = [dp.uid for dp in dp_list] 91 | # Check for nans, infs and large depth replace 92 | if np.isnan(dp_list[0].depth).any(): 93 | depth_mask_nan = np.isnan(dp_list[0].depth) 94 | dp_list[0].depth[depth_mask_nan] = 3.0 95 | 96 | if np.isinf(dp_list[0].depth).any(): 97 | depth_mask_inf = np.isinf(dp_list[0].depth) 98 | dp_list[0].depth[depth_mask_inf] = 3.0 99 | 100 | if (dp_list[0].depth > 3).any(): 101 | dp_list[0].depth[dp_list[0].depth > 3] = 3.0 102 | 103 | # Check for nans in covariance 104 | for pose in dp_list[0].object_poses: 105 | if np.isnan(pose.cov_matrices).any(): 106 | covariance_mask_nan = np.isnan(pose.cov_matrices) 107 | pose.cov_matrices[covariance_mask_nan] = 0.0001 108 | if np.isinf(pose.cov_matrices).any(): 109 | covariance_mask_nan = np.isnan(pose.cov_matrices) 110 | pose.cov_matrices[covariance_mask_nan] = 0.0001 111 | if np.isnan(pose.vertex_target).any(): 112 | mask = np.isnan(pose.vertex_target) 113 | pose.vertex_target[mask] = np.interp(np.flatnonzero(mask), np.flatnonzero(~mask), pose.vertex_target[~mask]) 114 | if np.isinf(pose.vertex_target).any(): 115 | mask = np.isinf(pose.vertex_target) 116 | pose.vertex_target[mask] = np.interp(np.flatnonzero(mask), np.flatnonzero(~mask), pose.vertex_target[~mask]) 117 | if np.isnan(pose.z_centroid).any(): 118 | mask = np.isnan(pose.z_centroid) 119 | pose.z_centroid[mask] = np.interp(np.flatnonzero(mask), np.flatnonzero(~mask), pose.z_centroid[~mask]) 120 | if np.isinf(pose.z_centroid).any(): 121 | mask = np.isinf(pose.z_centroid) 122 | pose.z_centroid[mask] = np.interp(np.flatnonzero(mask), np.flatnonzero(~mask), pose.z_centroid[~mask]) 123 | if np.isnan(pose.heat_map).any(): 124 | mask = np.isnan(pose.heat_map) 125 | pose.heat_map[mask] = np.interp(np.flatnonzero(mask), np.flatnonzero(~mask), pose.heat_map[~mask]) 126 | if np.isinf(pose.heat_map).any(): 127 | mask = np.isinf(pose.heat_map) 128 | pose.heat_map[mask] = np.interp(np.flatnonzero(mask), np.flatnonzero(~mask), pose.heat_map[~mask]) 129 | 130 | depth_target = depth_outputs.DepthOutput(dp_list[0].depth, self.hparams) 131 | depth_target.convert_to_torch_from_numpy() 132 | pose_target = None 133 | for pose_dp in dp_list[0].object_poses: 134 | 
pose_target = obb_outputs.OBBOutput( 135 | pose_dp.heat_map, pose_dp.vertex_target, pose_dp.z_centroid, pose_dp.cov_matrices, 136 | self.hparams 137 | ) 138 | pose_target.convert_to_torch_from_numpy() 139 | 140 | box_target = None 141 | kp_target = None 142 | 143 | # Final check to make sure there are no bad depths 144 | assert not np.isnan(dp_list[0].depth).any(), 'Depth should not have nan!!' 145 | assert not isnan(depth_target), 'Depth should not have nan!!' 146 | 147 | 148 | # Get the camera pose 149 | camera_poses = np.array([np.array(dp.camera_params['camera_extrinsic']) for dp in dp_list]) 150 | camera_poses = torch.from_numpy(np.ascontiguousarray(camera_poses)).float() 151 | 152 | # Get camera intrinsics 153 | camera_intrinsic = np.array(dp_list[0].camera_params['camera_intrinsic']).reshape((1,3,3)) 154 | camera_intrinsic = torch.from_numpy(np.ascontiguousarray(camera_intrinsic)).float() 155 | 156 | assert camera_poses.shape[0] == anaglyph.shape[0], f'Number of camera poses {camera_poses.shape} does not match number of images {anaglyph.shape}' 157 | return anaglyph, camera_poses, camera_intrinsic, segmentation_target, depth_target, pose_target, box_target, kp_target, dp_list[0].detections, scene_name 158 | 159 | def getStereoSample(self,idx): 160 | 161 | try: 162 | dp = self.datapoint_handles[idx].read() 163 | except: 164 | dp = self.datapoint_handles[idx] 165 | 166 | anaglyph = self.preprocces_image_func(dp.stereo) 167 | 168 | segmentation_target = segmentation_outputs.SegmentationOutput(dp.segmentation, self.hparams) 169 | segmentation_target.convert_to_torch_from_numpy() 170 | scene_name = dp.uid 171 | 172 | # Check for nans, infs and large depth replace 173 | if np.isnan(dp.depth).any(): 174 | depth_mask_nan = np.isnan(dp.depth) 175 | dp.depth[depth_mask_nan] = 3.0 176 | 177 | if np.isinf(dp.depth).any(): 178 | depth_mask_inf = np.isinf(dp.depth) 179 | dp.depth[depth_mask_inf] = 3.0 180 | 181 | if (dp.depth > 3).any(): 182 | dp.depth[dp.depth > 3] = 3.0 183 | 184 | # Check for nans in covariance 185 | for pose in dp.object_poses: 186 | if np.isnan(pose.cov_matrices).any(): 187 | covariance_mask_nan = np.isnan(pose.cov_matrices) 188 | pose.cov_matrices[covariance_mask_nan] = 0.0001 189 | 190 | depth_target = depth_outputs.DepthOutput(dp.depth, self.hparams) 191 | depth_target.convert_to_torch_from_numpy() 192 | pose_target = None 193 | for pose_dp in dp.object_poses: 194 | pose_target = obb_outputs.OBBOutput( 195 | pose_dp.heat_map, pose_dp.vertex_target, pose_dp.z_centroid, pose_dp.cov_matrices, 196 | self.hparams 197 | ) 198 | pose_target.convert_to_torch_from_numpy() 199 | 200 | box_target = None 201 | kp_target = None 202 | 203 | # Final check to make sure there are no bad depths 204 | assert not np.isnan(dp.depth).any(), 'Depth should not have nan!!' 205 | assert not isnan(depth_target), 'Depth should not have nan!!' 
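    # Unlike the multiview path, no camera poses or intrinsics are returned here;
    # the 8-tuple below is consumed by simnet_collate in src/lib/net/common.py,
    # which stacks only the anaglyph images and keeps the remaining targets as
    # per-sample lists.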
206 | return anaglyph, segmentation_target, depth_target, pose_target, box_target, kp_target, dp.detections, scene_name 207 | def __getitem__(self, idx): 208 | if self.hparams.network_type != 'multiview': 209 | return self.getStereoSample(idx) 210 | else: 211 | return self.getMultiviewSample(idx) 212 | 213 | def plot(data,index): 214 | from matplotlib import pyplot as plt 215 | plt.imshow((((data-np.amin(data))/(np.amax(data) - np.amin(data)))*256).astype(np.uint8), interpolation='nearest') 216 | plt.savefig(f'/h/helen/transparent-perception/tests/{index}_depth.png') 217 | plt.close() 218 | 219 | def isnan(x): 220 | return x != x 221 | -------------------------------------------------------------------------------- /src/lib/net/functions/learning_rate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Toyota Research Institute. All rights reserved. 2 | # 3 | # Originally from Koichiro Yamaguchi's pixwislab repo mirrored at: 4 | # https://github.awsinternal.tri.global/driving/pixwislab 5 | 6 | 7 | def lambda_learning_rate_poly(max_epochs, exponent): 8 | """Make a function for computing learning rate by "poly" policy. 9 | 10 | This policy does a polynomial decay of the learning rate over the epochs 11 | of training. 12 | 13 | Args: 14 | max_epochs (int): max numbers of epochs 15 | exponent (float): exponent value 16 | """ 17 | return lambda epoch: pow((1.0 - epoch / max_epochs), exponent) 18 | 19 | 20 | def lambda_warmup(warmup_period, warmup_factor, wrapped_lambda): 21 | 22 | def warmup(epoch, warmup_period, warmup_factor): 23 | if epoch > warmup_period: 24 | return 1.0 25 | else: 26 | return warmup_factor + (1.0 - warmup_factor) * (epoch / warmup_period) 27 | 28 | return lambda epoch: warmup(epoch, warmup_period, warmup_factor) * wrapped_lambda(epoch) 29 | -------------------------------------------------------------------------------- /src/lib/net/init/default_init.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def default_init(module): 5 | """Initialize parameters of the module. 6 | 7 | For convolution, weights are initialized by Kaiming method and 8 | biases are initialized to zero. 9 | For batch normalization, scales and biases are set to 1 and 0, 10 | respectively. 11 | """ 12 | if isinstance(module, nn.Conv2d): 13 | nn.init.kaiming_normal_(module.weight.data) 14 | if module.bias is not None: 15 | module.bias.data.zero_() 16 | elif isinstance(module, nn.Conv3d): 17 | nn.init.kaiming_normal_(module.weight.data) 18 | if module.bias is not None: 19 | module.bias.data.zero_() 20 | elif isinstance(module, nn.BatchNorm2d): 21 | module.weight.data.fill_(1) 22 | module.bias.data.zero_() 23 | elif isinstance(module, nn.BatchNorm3d): 24 | module.weight.data.fill_(1) 25 | module.bias.data.zero_() 26 | -------------------------------------------------------------------------------- /src/lib/net/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Toyota Research Institute. All rights reserved. 
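# Masked losses for dense prediction heads. MaskedL1Loss averages an L1 loss only
# over pixels whose (downsampled) valid_mask value exceeds centroid_threshold, so
# regression targets such as the 16-channel vertex field contribute only near
# object centroids; MaskedMSELoss zeroes out pixels flagged by the ignore mask and
# averages over the remaining valid pixels.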
2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class MaskedL1Loss(nn.Module): 8 | 9 | def __init__(self, centroid_threshold=0.3, downscale_factor=8): 10 | super().__init__() 11 | self.loss = nn.L1Loss(reduction='none') 12 | self.centroid_threshold = centroid_threshold 13 | self.downscale_factor = downscale_factor 14 | 15 | def forward(self, output, target, valid_mask): 16 | ''' 17 | output: [N,16,H,W] 18 | target: [N,16,H,W] 19 | valid_mask: [N,H,W] 20 | ''' 21 | valid_count = torch.sum( 22 | valid_mask[:, ::self.downscale_factor, ::self.downscale_factor] > self.centroid_threshold 23 | ) 24 | loss = self.loss(output, target) 25 | if len(output.shape) == 4: 26 | loss = torch.sum(loss, dim=1) 27 | loss[valid_mask[:, ::self.downscale_factor, ::self.downscale_factor] < self.centroid_threshold 28 | ] = 0.0 29 | if valid_count == 0: 30 | return torch.sum(loss) 31 | return torch.sum(loss) / valid_count 32 | 33 | 34 | class MSELoss(nn.Module): 35 | 36 | def __init__(self): 37 | super().__init__() 38 | self.loss = nn.MSELoss(reduction='none') 39 | 40 | def forward(self, output, target): 41 | ''' 42 | output: [N,H,W] 43 | target: [N,H,W] 44 | ignore_mask: [N,H,W] 45 | ''' 46 | loss = self.loss(output, target) 47 | return torch.mean(loss) 48 | 49 | 50 | class MaskedMSELoss(nn.Module): 51 | 52 | def __init__(self): 53 | super().__init__() 54 | self.loss = nn.MSELoss(reduction='none') 55 | 56 | def forward(self, output, target, ignore_mask): 57 | ''' 58 | output: [N,H,W] 59 | target: [N,H,W] 60 | ignore_mask: [N,H,W] 61 | ''' 62 | valid_sum = torch.sum(torch.logical_not(ignore_mask)) 63 | loss = self.loss(output, target) 64 | loss[ignore_mask > 0] = 0.0 65 | return torch.sum(loss) / valid_sum 66 | -------------------------------------------------------------------------------- /src/lib/net/models/simplenet.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | @dataclasses.dataclass 9 | class Node: 10 | inp: any 11 | module: nn.Module 12 | activated: bool 13 | stride: int 14 | dim: int 15 | 16 | def __hash__(self): 17 | return hash(self.module) 18 | 19 | 20 | class NetFactory(nn.Module): 21 | 22 | def __init__(self): 23 | super().__init__() 24 | self.nodes = [] 25 | self.skips = {} 26 | self.tags = {} 27 | 28 | def tag(self, node, name): 29 | self.tags[node] = name 30 | 31 | def input(self, in_dim=3, stride=1, activated=True): 32 | assert not self.nodes 33 | n = Node(inp=None, module=None, activated=activated, stride=stride, dim=in_dim) 34 | self.nodes.append(n) 35 | return n 36 | 37 | def _add(self, node): 38 | self.nodes.append(node) 39 | return node 40 | 41 | def _activate(self, node): 42 | if node.activated: 43 | return node 44 | return self._add( 45 | dataclasses.replace( 46 | node, 47 | inp=node, 48 | module=nn.Sequential(nn.BatchNorm2d(node.dim), nn.LeakyReLU()), 49 | activated=True 50 | ) 51 | ) 52 | 53 | def _conv(self, node, out_dim=None, stride=1, rate=1, kernel=3): 54 | node = self._activate(node) 55 | if out_dim is None: 56 | out_dim = node.dim 57 | padding = (kernel - 1) // 2 * rate 58 | return self._add( 59 | dataclasses.replace( 60 | node, 61 | inp=node, 62 | module=nn.Conv2d( 63 | node.dim, out_dim, kernel, stride=stride, dilation=rate, padding=padding 64 | ), 65 | activated=False, 66 | dim=out_dim, 67 | stride=node.stride * stride 68 | ) 69 | ) 70 | 71 | def _interp(self, node): 72 | return self._add( 73 | dataclasses.replace( 
74 | node, 75 | inp=node, 76 | module=nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 77 | stride=node.stride // 2 78 | ) 79 | ) 80 | 81 | def _lateral(self, node, out_dim=None): 82 | if out_dim is None: 83 | out_dim = node.dim 84 | if out_dim == node.dim: 85 | return node 86 | return self._conv(node, out_dim=out_dim, kernel=1) 87 | 88 | def output(self, node, out_dim): 89 | return self._conv(node, out_dim=out_dim, kernel=1) 90 | 91 | def downscale(self, node, out_dim): 92 | return self._conv(node, out_dim, stride=2) 93 | 94 | def upsample(self, node, skip, out_dim): 95 | skip = self._lateral(skip, out_dim=out_dim) 96 | node = self._lateral(node, out_dim=out_dim) 97 | node = self._interp(node) 98 | self.skips[node] = skip 99 | return node 100 | 101 | def layer(self, node, out_dim=None, rate=1): 102 | if out_dim is None: 103 | out_dim = node.dim 104 | skip = self._lateral(node, out_dim=out_dim) 105 | node = self._conv(node, rate=rate) 106 | node = self._conv(node, rate=rate) 107 | self.skips[node] = skip 108 | return node 109 | 110 | def block(self, node, rates): 111 | for r in [int(r) for r in rates]: 112 | node = self.layer(node, rate=r) 113 | return node 114 | 115 | def bake(self): 116 | self.modules = nn.ModuleList(n.module for n in self.nodes if n.module is not None) 117 | return self 118 | 119 | def forward(self, x): 120 | outputs = {} 121 | tag_outputs = {} 122 | for node in self.nodes: 123 | if node.module is None: # initial input 124 | pass 125 | else: 126 | if node in self.skips: 127 | x = outputs[self.skips[node]] + node.module(outputs[node.inp]) 128 | else: 129 | x = node.module(outputs[node.inp]) 130 | outputs[node] = x 131 | last = x 132 | if node in self.tags: 133 | tag_outputs[self.tags[node]] = x 134 | if tag_outputs: 135 | return tag_outputs 136 | return last 137 | 138 | 139 | def hdrn_alpha_base(num_channels): 140 | net = NetFactory() 141 | x = net.input() 142 | x = net.downscale(x, num_channels) 143 | x = net.downscale(x, num_channels) 144 | x4 = x = net.block(x, '111') 145 | x = net.downscale(x, num_channels * 2) 146 | x8 = x = net.block(x, '1111') 147 | x = net.downscale(x, num_channels * 4) 148 | x = net.block(x, '12591259') 149 | x = net.upsample(x, x8, num_channels // 2) 150 | x = net.upsample(x, x4, num_channels // 2) 151 | return net.bake() 152 | 153 | 154 | def make_process_cost_volume(num_disparities): 155 | net = NetFactory() 156 | x = net.input(in_dim=num_disparities, stride=4, activated=True) 157 | x = net.block(x, '1259') 158 | x = net.output(x, out_dim=num_disparities) 159 | return net.bake() 160 | 161 | 162 | class HdrnAlphaStereo(nn.Module): 163 | 164 | def __init__(self, hparams): 165 | super().__init__() 166 | 167 | self.num_disparities = hparams.max_disparity 168 | self.internal_scale = hparams.cost_volume_downsample_factor 169 | self.internal_num_disparities = self.num_disparities // self.internal_scale 170 | assert self.internal_scale in [4, 8, 16] 171 | 172 | self.feature_extractor = hdrn_alpha_base(hparams.fe_internal_features) 173 | self.cost_volume = DotProductCostVolume(self.internal_num_disparities) 174 | self.process_cost_volume = make_process_cost_volume(self.internal_num_disparities) 175 | 176 | self.soft_argmin = SoftArgmin() 177 | 178 | def forward(self, left_image, right_image): 179 | left_score = self.feature_extractor(left_image) 180 | right_score = self.feature_extractor(right_image) 181 | 182 | cost_volume = self.cost_volume(left_score, right_score) 183 | cost_volume = self.process_cost_volume(cost_volume) 184 | 185 
| disparity_small = self.soft_argmin(cost_volume) 186 | 187 | return disparity_small 188 | 189 | class StereoBackbone(nn.Module): 190 | 191 | def __init__(self, hparams, in_channels=3): 192 | super().__init__() 193 | 194 | def make_rgb_stem(): 195 | net = NetFactory() 196 | x = net.input(in_dim=3, stride=1, activated=True) 197 | x = net.downscale(x, 32) 198 | x = net.downscale(x, 32) 199 | return net.bake() 200 | 201 | def make_disp_features(): 202 | net = NetFactory() 203 | x = net.input(in_dim=1, stride=1, activated=False) 204 | x = net.layer(x, 32, rate=5) 205 | return net.bake() 206 | 207 | self.rgb_stem = make_rgb_stem() 208 | self.stereo_stem = HdrnAlphaStereo(hparams) 209 | self.disp_features = make_disp_features() 210 | 211 | def make_rgbd_backbone(num_channels=64, out_dim=64): 212 | net = NetFactory() 213 | x = net.input(in_dim=64, activated=True, stride=4) 214 | x = net._lateral(x, out_dim=num_channels) 215 | x4 = x = net.block(x, '111') 216 | x = net.downscale(x, num_channels * 2) 217 | x8 = x = net.block(x, '1111') 218 | x = net.downscale(x, num_channels * 4) 219 | x = net.block(x, '12591259') 220 | net.tag(net.output(x, out_dim), 'p4') 221 | x = net.upsample(x, x8, out_dim) 222 | net.tag(x, 'p3') 223 | x = net.upsample(x, x4, out_dim) 224 | net.tag(x, 'p2') 225 | return net.bake() 226 | 227 | self.rgbd_backbone = make_rgbd_backbone() 228 | 229 | def forward(self, stacked_img, step, robot_joint_angles=None): 230 | small_disp = self.stereo_stem.forward(stacked_img[:, 0:3], stacked_img[:, 3:6]) 231 | left_rgb_features = self.rgb_stem.forward(stacked_img[:, 0:3]) 232 | disp_features = self.disp_features(small_disp) 233 | rgbd_features = torch.cat((disp_features, left_rgb_features), axis=1) 234 | outputs = self.rgbd_backbone.forward(rgbd_features) 235 | outputs['small_disp'] = small_disp 236 | return outputs 237 | 238 | @property 239 | def out_channels(self): 240 | return 32 241 | 242 | @property 243 | def stride(self): 244 | return 4 # = stride 2 conv -> stride 2 max pool 245 | 246 | @property 247 | def out_channels(self): 248 | return 32 249 | 250 | @property 251 | def stride(self): 252 | return 4 # = stride 2 conv -> stride 2 max pool 253 | 254 | @torch.jit.script 255 | def cost_volume(left, right, num_disparities: int, is_right: bool): 256 | batch_size, channels, height, width = left.shape 257 | 258 | output = torch.zeros((batch_size, channels, num_disparities, height, width), 259 | dtype=left.dtype, 260 | device=left.device) 261 | 262 | for i in range(num_disparities): 263 | if not is_right: 264 | output[:, :, i, :, i:] = left[:, :, :, i:] * right[:, :, :, :width - i] 265 | else: 266 | output[:, :, i, :, :width - i] = left[:, :, :, i:] * right[:, :, :, :width - i] 267 | 268 | return output 269 | 270 | 271 | class CostVolume(nn.Module): 272 | """Compute cost volume using cross correlation of left and right feature maps""" 273 | 274 | def __init__(self, num_disparities, is_right=False): 275 | super().__init__() 276 | self.num_disparities = num_disparities 277 | self.is_right = is_right 278 | 279 | def forward(self, left, right): 280 | if torch.jit.is_scripting(): 281 | return cost_volume(left, right, self.num_disparities, self.is_right) 282 | else: 283 | return self.forward_with_amp(left, right) 284 | 285 | @torch.jit.unused 286 | def forward_with_amp(self, left, right): 287 | """This operation is unstable at float16, so compute at float32 even when using mixed precision""" 288 | with torch.cuda.amp.autocast(enabled=False): 289 | left = left.to(torch.float32) 290 | right = 
right.to(torch.float32)
291 | output = cost_volume(left, right, self.num_disparities, self.is_right)
292 | output = torch.clamp(output, -1e3, 1e3)
293 | return output
294 |
295 |
296 | @torch.jit.script
297 | def dot_product_cost_volume(left, right, num_disparities: int, is_right: bool):
298 | batch_size, channels, height, width = left.shape
299 |
300 | output = torch.zeros((batch_size, num_disparities, height, width),
301 | dtype=left.dtype,
302 | device=left.device)
303 |
304 | for i in range(num_disparities):
305 | if not is_right:
306 | output[:, i, :, i:] = (left[:, :, :, i:] * right[:, :, :, :width - i]).mean(dim=1)
307 | else:
308 | output[:, i, :, :width - i] = (left[:, :, :, i:] * right[:, :, :, :width - i]).mean(dim=1)
309 |
310 | return output
311 |
312 | class DotProductCostVolume(nn.Module):
313 | """Compute cost volume using dot product of left and right feature maps"""
314 |
315 | def __init__(self, num_disparities, is_right=False):
316 | super().__init__()
317 | self.num_disparities = num_disparities
318 | self.is_right = is_right
319 |
320 | def forward(self, left, right):
321 | return dot_product_cost_volume(left, right, self.num_disparities, self.is_right)
322 |
323 | @torch.jit.unused
324 | def forward_with_amp(self, left, right):
325 | """This operation is unstable at float16, so compute at float32 even when using mixed precision"""
326 | with torch.cuda.amp.autocast(enabled=False):
327 | left = left.to(torch.float32)
328 | right = right.to(torch.float32)
329 | output = dot_product_cost_volume(left, right, self.num_disparities, self.is_right)
330 | output = torch.clamp(output, -1e3, 1e3)
331 | return output
332 |
333 |
334 | @torch.jit.script
335 | def soft_argmin(input):
336 | _, channels, _, _ = input.shape
337 |
338 | softmin = F.softmin(input, dim=1)
339 | index_tensor = torch.arange(0, channels, dtype=softmin.dtype,
340 | device=softmin.device).view(1, channels, 1, 1)
341 | output = torch.sum(softmin * index_tensor, dim=1, keepdim=True)
342 | return output
343 |
344 |
345 | class SoftArgmin(nn.Module):
346 | """Compute soft argmin operation for given cost volume"""
347 |
348 | def forward(self, input):
349 | return soft_argmin(input)
350 |
351 |
352 | @torch.jit.script
353 | def matchability(input):
354 | softmin = F.softmin(input, dim=1)
355 | log_softmin = F.log_softmax(-input, dim=1)
356 | output = torch.sum(softmin * log_softmin, dim=1, keepdim=True)
357 | return output
358 |
359 |
360 | class Matchability(nn.Module):
361 | """Compute disparity matchability value from https://arxiv.org/abs/2008.04800"""
362 |
363 | def forward(self, input):
364 | if torch.jit.is_scripting():
365 | # Torchscript generation can't handle mixed precision, so always compute at float32.
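      # matchability is sum_d softmin(cost)_d * log_softmax(-cost)_d over the
      # disparity dimension, i.e. the negative entropy of the per-pixel disparity
      # distribution: confidently matched pixels score near 0, ambiguous ones are
      # strongly negative.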
366 | return matchability(input) 367 | else: 368 | return self.forward_with_amp(input) 369 | 370 | @torch.jit.unused 371 | def forward_with_amp(self, input): 372 | """This operation is unstable at float16, so compute at float32 even when using mixed precision""" 373 | with torch.cuda.amp.autocast(enabled=False): 374 | input = input.to(torch.float32) 375 | return matchability(input) 376 | 377 | 378 | def main(): 379 | num_channels = 32 380 | net = NetFactory() 381 | x = net.input() 382 | x = net.downscale(x, num_channels) 383 | x = net.downscale(x, num_channels) 384 | x4 = x = net.block(x, '111') 385 | x = net.downscale(x, num_channels * 2) 386 | x8 = x = net.block(x, '1111') 387 | x = net.downscale(x, num_channels * 4) 388 | x = net.block(x, '12591259') 389 | x = net.upsample(x, x8, num_channels // 2) 390 | x = net.upsample(x, x4, num_channels // 2) 391 | net.bake() 392 | 393 | x = torch.randn(5, 3, 512, 640) 394 | y = net(x) 395 | 396 | import torch._C as _C 397 | TrainingMode = _C._onnx.TrainingMode 398 | torch.onnx.export( 399 | net, 400 | x, 401 | "test_net.onnx", 402 | do_constant_folding=False, 403 | verbose=True, 404 | training=TrainingMode.TRAINING, 405 | opset_version=13 406 | ) 407 | import onnx 408 | from onnx import shape_inference 409 | onnx.save(shape_inference.infer_shapes(onnx.load('test_net.onnx')), 'test_net_shapes.onnx') 410 | 411 | 412 | if __name__ == '__main__': 413 | main() 414 | -------------------------------------------------------------------------------- /src/lib/net/post_processing/depth_outputs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from matplotlib import cm 5 | 6 | from torch.nn import functional as F 7 | from src.lib.net import losses 8 | 9 | _mse_loss = losses.MSELoss() 10 | _MAX_DISP = 3.0 11 | 12 | 13 | class DepthOutput: 14 | 15 | def __init__(self, depth_pred, loss_multiplier): 16 | self.depth_pred = depth_pred 17 | self.is_numpy = False 18 | self.loss = nn.SmoothL1Loss() 19 | self.loss_multiplier = loss_multiplier 20 | 21 | # Converters for torch to numpy 22 | def convert_to_numpy_from_torch(self): 23 | self.depth_pred = np.ascontiguousarray(self.depth_pred.cpu().numpy()) 24 | self.is_numpy = True 25 | 26 | def convert_to_torch_from_numpy(self): 27 | self.depth_pred[self.depth_pred > _MAX_DISP] = _MAX_DISP #- 1 28 | self.depth_pred = torch.from_numpy(np.ascontiguousarray(self.depth_pred)).float() 29 | self.is_numpy = False 30 | 31 | def get_visualization_img(self, left_img_np, corner_scale=1, raw_disp=True): 32 | if not self.is_numpy: 33 | self.convert_to_numpy_from_torch() 34 | disp = self.depth_pred[0] 35 | 36 | if raw_disp: 37 | return disp_map_visualize(disp) 38 | disp_scaled = disp[::corner_scale, ::corner_scale] 39 | left_img_np[:disp_scaled.shape[0], -disp_scaled.shape[1]:] = disp_map_visualize(disp_scaled) 40 | return left_img_np 41 | 42 | def get_visualization_img_gt(self, left_img_np, corner_scale=1, raw_disp=True): 43 | if not self.is_numpy: 44 | self.convert_to_numpy_from_torch() 45 | disp = self.depth_pred 46 | 47 | if raw_disp: 48 | return disp_map_visualize(disp) 49 | disp_scaled = disp[::corner_scale, ::corner_scale] 50 | left_img_np[:disp_scaled.shape[0], -disp_scaled.shape[1]:] = disp_map_visualize(disp_scaled) 51 | return left_img_np 52 | 53 | def compute_loss(self, depth_targets, log, name): 54 | if self.is_numpy: 55 | raise ValueError("Output is not in torch mode") 56 | depth_target_stacked = [] 57 | for depth_target 
in depth_targets: 58 | depth_target_stacked.append(depth_target.depth_pred.cuda()) 59 | depth_target_batch = torch.stack(depth_target_stacked) 60 | scale_factor = self.depth_pred.shape[2] / depth_target_batch.shape[2] 61 | 62 | if scale_factor != 1.0: 63 | depth_target_batch = F.interpolate( 64 | depth_target_batch[:, None, :, :], scale_factor=scale_factor 65 | )[:, 0, :, :] 66 | # scale down disparity by same factor as spatial resize 67 | depth_target_batch = depth_target_batch * scale_factor 68 | 69 | depth_loss = self.loss(self.depth_pred, depth_target_batch) / scale_factor 70 | log[name] = depth_loss 71 | return self.loss_multiplier * depth_loss 72 | 73 | def compute_metrics(self, depth_targets, log = {}, masks=None, mode = 'Train'): 74 | RMSE = [] 75 | MAE = [] 76 | REL = [] 77 | depth_target_stacked = [] 78 | for depth_target in depth_targets: 79 | depth_target_stacked.append(depth_target.depth_pred) 80 | depth_target_batch = torch.stack(depth_target_stacked) 81 | scale_factor = self.depth_pred.shape[2] / depth_target_batch.shape[2] 82 | if scale_factor != 1.0: 83 | depth_target_batch = F.interpolate( 84 | depth_target_batch[:, None, :, :], scale_factor=scale_factor 85 | )[:, 0, :, :] 86 | if masks is not None: 87 | depth_target_batch *= masks.cpu() 88 | self.depth_pred = self.depth_pred.clone()* masks.cuda() 89 | for i in range(len(depth_targets)): 90 | pred= self.depth_pred[i].cpu() 91 | gt = depth_target_batch[i].cpu() 92 | RMSE.append(self.computeRMSE(pred, gt).detach().numpy()) 93 | MAE.append(self.computeMAE(pred, gt).detach().numpy()) 94 | REL.append(self.computeREL(pred,gt).detach().numpy()) 95 | MAE = np.array(MAE) 96 | RMSE = np.array(RMSE) 97 | REL = np.array(REL) 98 | log[f'{mode}_depth_MAE_mean'] = MAE.mean() 99 | log[f'{mode}_depth_RMSE_mean'] = RMSE.mean() 100 | log[f'{mode}_depth_REL_mean'] = REL.mean() 101 | log[f'{mode}_depth_MAE_median'] = np.median(MAE) 102 | log[f'{mode}_depth_RMSE_median'] = np.median(RMSE) 103 | log[f'{mode}_depth_REL_median'] = np.median(REL) 104 | return MAE.mean(), np.median(MAE), RMSE.mean(), np.median(RMSE), REL.mean(), np.median(REL) 105 | 106 | def computeRMSE(self, pred, gt): 107 | eps=1e-5 108 | img1 = torch.zeros_like(pred) 109 | img2 = torch.zeros_like(gt) 110 | 111 | img1 = img1.copy_(pred) 112 | img2 = img2.copy_(gt) 113 | 114 | mask = gt > eps 115 | img1[~mask] = 0. 116 | img2[~mask] = 0. 117 | non_zero_count = torch.sum((pred>0).int()) 118 | return torch.sqrt(nn.MSELoss(reduction='sum')(img1, img2)/non_zero_count) 119 | 120 | def computeMAE(self, pred, gt): 121 | eps=1e-5 122 | img1 = torch.zeros_like(pred) 123 | img2 = torch.zeros_like(gt) 124 | 125 | img1 = img1.copy_(pred) 126 | img2 = img2.copy_(gt) 127 | 128 | mask = gt > eps 129 | img1[~mask] = 0. 130 | img2[~mask] = 0. 
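    # Pixels with no valid ground truth (gt <= eps) are zeroed in both images;
    # the summed L1 error is then averaged over the number of strictly positive
    # predicted depths rather than over all pixels.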
131 | non_zero_count = torch.sum((pred>0).int()) 132 | return nn.L1Loss(reduction='sum')(img1, img2)/non_zero_count 133 | 134 | def computeREL(self, pred, gt): 135 | mask = gt > 1e-5 136 | diff = torch.abs(gt[mask] - pred[mask]) / gt[mask] 137 | return diff.mean() 138 | 139 | def turbo_vis(heatmap, normalize=False, uint8_output=False): 140 | assert len(heatmap.shape) == 2 141 | if normalize: 142 | heatmap = heatmap.astype(np.float32) 143 | heatmap -= np.min(heatmap) 144 | heatmap /= np.max(heatmap) 145 | 146 | assert heatmap.dtype != np.uint8 147 | x = heatmap 148 | x = x.clip(0, 1) 149 | a = (x * 255).astype(int) 150 | b = (a + 1).clip(max=255) 151 | f = x * 255.0 - a 152 | turbo_map = np.array(cm.turbo.colors)[::-1] 153 | pseudo_color = (turbo_map[a] + (turbo_map[b] - turbo_map[a]) * f[..., np.newaxis]) 154 | pseudo_color[heatmap < 0.0] = 0.0 155 | pseudo_color[heatmap > 1.0] = 1.0 156 | if uint8_output: 157 | pseudo_color = (pseudo_color * 255).astype(np.uint8) 158 | return pseudo_color 159 | 160 | 161 | def disp_map_visualize(x, max_disp=_MAX_DISP): 162 | assert len(x.shape) == 2 163 | x = x.astype(np.float64) 164 | valid = ((x < max_disp) & np.isfinite(x)) 165 | if valid.sum() == 0: 166 | return np.zeros_like(x).astype(np.uint8) 167 | x -= np.min(x[valid]) 168 | x /= np.max(x[valid]) 169 | x = 1. - x 170 | x[~valid] = 0. 171 | 172 | x[np.isnan(x)] = 0. 173 | try: 174 | x = turbo_vis(x) 175 | except: 176 | print(np.unique(x)) 177 | raise ValueError 178 | x = (x * 255).astype(np.uint8) 179 | return x[:, :, ::-1] 180 | -------------------------------------------------------------------------------- /src/lib/net/post_processing/epnp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from src.lib import camera 3 | 4 | # Definition of unit cube centered at the orign 5 | x_width = 1.0 6 | y_depth = 1.0 7 | z_height = 1.0 8 | 9 | _WORLD_T_POINTS = np.array([ 10 | [0, 0, 0], #0 11 | [0, 0, z_height], #1 12 | [0, y_depth, z_height], #2 13 | [0, y_depth, 0], #3 14 | [x_width, 0, 0], #4 15 | [x_width, 0, z_height], #5 16 | [x_width, y_depth, z_height], #6 17 | [x_width, y_depth, 0], #7 18 | ]) - 0.5 19 | 20 | def get_2d_bbox_of_9D_box(camera_T_object, scale_matrix, camera_model): 21 | unit_box_homopoints = camera.convert_points_to_homopoints(_WORLD_T_POINTS.T) 22 | morphed_homopoints = camera_T_object @ (scale_matrix @ unit_box_homopoints) 23 | morphed_pixels = camera.convert_homopixels_to_pixels(camera_model.K_matrix @ morphed_homopoints).T 24 | bbox = [ 25 | np.array([np.min(morphed_pixels[:, 0]), 26 | np.min(morphed_pixels[:, 1])]), 27 | np.array([np.max(morphed_pixels[:, 0]), 28 | np.max(morphed_pixels[:, 1])]) 29 | ] 30 | return bbox 31 | 32 | 33 | def project_pose_onto_image(pose, camera_model): 34 | unit_box_homopoints = camera.convert_points_to_homopoints(_WORLD_T_POINTS.T) 35 | morphed_homopoints = pose.camera_T_object @ (pose.scale_matrix @ unit_box_homopoints) 36 | morphed_pixels = camera.convert_homopixels_to_pixels(camera_model.project(morphed_homopoints)).T 37 | morphed_pixels = morphed_pixels[:, ::-1] 38 | return morphed_pixels 39 | 40 | 41 | def get_2d_bbox_of_projection(bbox_ext): 42 | bbox = [ 43 | np.array([np.min(bbox_ext[:, 0]), np.min(bbox_ext[:, 1])]), 44 | np.array([np.max(bbox_ext[:, 0]), np.max(bbox_ext[:, 1])]) 45 | ] 46 | return bbox 47 | 48 | 49 | def define_control_points(): 50 | return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 0]]) 51 | 52 | 53 | def compute_alphas(Xw, Cw): 54 | X = 
np.concatenate((Xw, np.array([np.ones((8))])), axis=0) # 4x8 55 | C = Cw.T # 4x3 --> 3x4 56 | C = np.concatenate((C, np.array([np.ones((4))])), axis=0) #4x4 57 | Alpha = np.matmul(np.linalg.inv(C), X) 58 | return Alpha.T 59 | 60 | 61 | def construct_M_matrix(bbox_pixels, alphas, K_matrix): 62 | ''' 63 | More detailed ePnP explanation: https://en.wikipedia.org/wiki/Perspective-n-Point 64 | ''' 65 | M = np.zeros([16, 12]) # 16 is for bounding box verticies, 12 is for ? control pts? 66 | f_x = K_matrix[0, 0] 67 | f_y = K_matrix[1, 1] 68 | c_x = K_matrix[0, 2] 69 | c_y = K_matrix[1, 2] 70 | for ii in range(8): 71 | u = bbox_pixels[0, ii] 72 | v = bbox_pixels[1, ii] 73 | for jj in range(4): 74 | alpha = alphas[ii, jj] 75 | M[ii * 2, jj * 3] = f_x * alpha 76 | M[ii * 2, jj * 3 + 2] = (c_x - u) * alpha 77 | M[ii * 2 + 1, jj * 3 + 1] = f_y * alpha 78 | M[ii * 2 + 1, jj * 3 + 2] = (c_y - v) * alpha 79 | return M 80 | 81 | 82 | def convert_control_to_box_vertices(control_points, alphas): 83 | bbox_vertices = np.zeros([8, 3]) 84 | for i in range(8): 85 | for j in range(4): 86 | alpha = alphas[i, j] 87 | bbox_vertices[i] = bbox_vertices[i] + alpha * control_points[j] 88 | 89 | return bbox_vertices 90 | 91 | 92 | def solve_for_control_points(M): 93 | e_vals, e_vecs = np.linalg.eig(M.T @ M) 94 | control_points = e_vecs[:, np.argmin(e_vals)] 95 | control_points = control_points.reshape([4, 3]) 96 | return control_points 97 | 98 | 99 | def compute_homopoints_from_control_points(camera_control_points, alphas, K_matrix): 100 | camera_points = convert_control_to_box_vertices(camera_control_points, alphas) 101 | camera_homopoints = camera.convert_points_to_homopoints(camera_points.T) 102 | return camera_homopoints 103 | unit_box_homopoints = camera.convert_points_to_homopoints(_WORLD_T_POINTS.T) 104 | 105 | 106 | def optimize_for_9D(bbox_pixels, camera_model, solve_for_transforms=False): 107 | K_matrix = camera_model.K_matrix 108 | Cw = define_control_points() 109 | Xw = _WORLD_T_POINTS # 8x3 110 | alphas = compute_alphas(Xw.T, Cw) 111 | M = construct_M_matrix(bbox_pixels, alphas, np.copy(K_matrix)) 112 | camera_control_points = solve_for_control_points(M) 113 | camera_points = convert_control_to_box_vertices(camera_control_points, alphas) 114 | camera_homopoints = camera.convert_points_to_homopoints(camera_points.T) 115 | if solve_for_transforms: 116 | unit_box_homopoints = camera.convert_points_to_homopoints(_WORLD_T_POINTS.T) 117 | # Test both the negative and positive solutions of the control points and pick the best one. Taken from the Google MediaPipe Code base. 
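    # (solve_for_control_points returns a null-space eigenvector of M^T M that is
    # only determined up to sign, so both the positive and negative candidates are
    # aligned to the unit box with the Umeyama fit below and the solution with the
    # smaller alignment error is kept.)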
118 | error_one, camera_T_object_one, scale_matrix_one = estimateSimilarityUmeyama( 119 | unit_box_homopoints, camera_homopoints 120 | ) 121 | 122 | camera_homopoints = compute_homopoints_from_control_points( 123 | -1 * camera_control_points, alphas, K_matrix 124 | ) 125 | error_two, camera_T_object_two, scale_matrix_two = estimateSimilarityUmeyama( 126 | unit_box_homopoints, camera_homopoints 127 | ) 128 | if error_one < error_two: 129 | camera_T_object = camera_T_object_one 130 | scale_matrix = scale_matrix_one 131 | else: 132 | camera_T_object = camera_T_object_two 133 | scale_matrix = scale_matrix_two 134 | 135 | # Compute Fit to original pixles: 136 | morphed_points = camera_T_object @ (scale_matrix @ unit_box_homopoints) 137 | morphed_pixels = points_to_camera(morphed_points, K_matrix) 138 | confidence = np.linalg.norm(bbox_pixels - morphed_pixels) 139 | return confidence, camera_T_object, scale_matrix 140 | camera_homopixels = K_matrix @ camera_homopoints 141 | return camera.convert_homopixels_to_pixels(camera_homopixels).T 142 | 143 | 144 | def estimateSimilarityUmeyama(source_hom, TargetHom): 145 | # Copy of original paper is at: http://web.stanford.edu/class/cs273/refs/umeyama.pdf 146 | assert source_hom.shape[0] == 4 147 | assert TargetHom.shape[0] == 4 148 | SourceCentroid = np.mean(source_hom[:3, :], axis=1) 149 | TargetCentroid = np.mean(TargetHom[:3, :], axis=1) 150 | nPoints = source_hom.shape[1] 151 | 152 | CenteredSource = source_hom[:3, :] - np.tile(SourceCentroid, (nPoints, 1)).transpose() 153 | CenteredTarget = TargetHom[:3, :] - np.tile(TargetCentroid, (nPoints, 1)).transpose() 154 | 155 | CovMatrix = np.matmul(CenteredTarget, np.transpose(CenteredSource)) / nPoints 156 | 157 | if np.isnan(CovMatrix).any(): 158 | print('nPoints:', nPoints) 159 | print('source_hom', source_hom.shape) 160 | print('TargetHom', TargetHom.shape) 161 | raise RuntimeError('There are NANs in the input.') 162 | 163 | U, D, Vh = np.linalg.svd(CovMatrix, full_matrices=True) 164 | d = (np.linalg.det(U) * np.linalg.det(Vh)) < 0.0 165 | if d: 166 | D[-1] = -D[-1] 167 | U[:, -1] = -U[:, -1] 168 | 169 | Rotation = np.matmul(U, Vh) 170 | var_source = np.std(CenteredSource[:3, :], axis=1) 171 | var_target_aligned = np.std(np.linalg.inv(Rotation) @ CenteredTarget[:3, :], axis=1) 172 | ScaleMatrix = np.diag(var_target_aligned / var_source) 173 | 174 | Translation = TargetHom[:3, :].mean(axis=1) - source_hom[:3, :].mean(axis=1).dot( 175 | ScaleMatrix @ Rotation.T 176 | ) 177 | 178 | source_T_target = np.identity(4) 179 | source_T_target[:3, :3] = Rotation 180 | source_T_target[:3, 3] = Translation 181 | scale_matrix = np.eye(4) 182 | scale_matrix[0:3, 0:3] = ScaleMatrix 183 | # Measure error fit 184 | morphed_points = source_T_target @ (scale_matrix @ source_hom) 185 | error = np.linalg.norm(morphed_points - TargetHom) 186 | return error, source_T_target, scale_matrix 187 | 188 | 189 | def points_to_camera(world_T_homopoints, K_matrix): 190 | homopixels = K_matrix @ world_T_homopoints 191 | return camera.convert_homopixels_to_pixels(homopixels) 192 | 193 | 194 | def find_absolute_scale(new_z, camera_T_object, object_scale, debug=True): 195 | old_z = camera_T_object[2, 3] 196 | abs_camera_T_object = np.copy(camera_T_object) 197 | abs_camera_T_object[0:3, 3] = (new_z / old_z) * abs_camera_T_object[0:3, 3] 198 | abs_object_scale = np.eye(4) 199 | abs_object_scale[0:3, 0:3] = (new_z / old_z) * object_scale[0:3, 0:3] 200 | return abs_camera_T_object, abs_object_scale 201 | 
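A minimal usage sketch for the functions above (illustrative only: the K intrinsics, the _FakeCamera stand-in and the ground-truth pose values are hypothetical, not part of the repository; only _WORLD_T_POINTS, optimize_for_9D and find_absolute_scale come from this module). It projects the unit box with a known similarity transform, recovers the pose with optimize_for_9D, then pins the depth with find_absolute_scale:

import numpy as np

# Hypothetical 3x4 intrinsics in the homogeneous form used by this module.
K = np.array([[500., 0., 320., 0.],
              [0., 500., 240., 0.],
              [0., 0., 1., 0.]])

class _FakeCamera:  # minimal stand-in for the src/lib/camera models
  K_matrix = K

# Ground-truth similarity transform: a 0.2-scale unit box, 1.0 unit in front of the camera.
gt_camera_T_object = np.eye(4)
gt_camera_T_object[2, 3] = 1.0
gt_scale = np.diag([0.2, 0.2, 0.2, 1.0])

box_hom = np.concatenate([_WORLD_T_POINTS.T, np.ones((1, 8))])  # 4x8 homogeneous points
pix_hom = K @ (gt_camera_T_object @ (gt_scale @ box_hom))       # 3x8
bbox_pixels = pix_hom[:2] / pix_hom[2]                          # 2x8, row 0 = u, row 1 = v

# ePnP recovers the pose only up to a global scale along the viewing ray ...
residual, camera_T_object, scale_matrix = optimize_for_9D(
    bbox_pixels, _FakeCamera(), solve_for_transforms=True
)
# ... which find_absolute_scale resolves by pinning the box centroid to a chosen z.
abs_camera_T_object, abs_scale_matrix = find_absolute_scale(
    1.0, camera_T_object, scale_matrix
)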
-------------------------------------------------------------------------------- /src/lib/net/post_processing/nms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def run(detections, overlap_thresh=0.75, order_mode='confidence'): 5 | # initialize the list of picked detections 6 | pruned_detections = [] 7 | 8 | # sort the indexes 9 | if order_mode == 'lower_y': 10 | idxs = create_order_by_lower_y(detections) 11 | elif order_mode == 'confidence': 12 | idxs = create_order_by_score(detections) 13 | 14 | overlap_function = get_2d_one_way_iou 15 | 16 | # keep looping while some indexes still remain in the indexes list 17 | while len(idxs) > 0: 18 | # grab the last index in the indexes list and add the index value 19 | # to the list of picked indexes 20 | last = len(idxs) - 1 21 | ii = idxs[last] 22 | indices_to_suppress = [] 23 | for index, index_of_index in zip(idxs[:last], range(last)): 24 | detection_proposal = detections[index] 25 | overlap = overlap_function(detections[ii], detection_proposal) 26 | if overlap > overlap_thresh: 27 | indices_to_suppress.append(index_of_index) 28 | # Add the the pruned_detections. 29 | pruned_detections.append(detections[ii]) 30 | indices_to_suppress.append(last) 31 | idxs = np.delete(idxs, indices_to_suppress) 32 | 33 | # return only the bounding boxes that were picked 34 | return prune_by_min_height(pruned_detections) 35 | 36 | 37 | def prune_by_min_height(detections): 38 | pruned_detections = [] 39 | for detection in detections: 40 | if detection.bbox[1][0] - detection.bbox[0][0] < 12: 41 | continue 42 | pruned_detections.append(detection) 43 | return pruned_detections 44 | 45 | 46 | def create_order_by_lower_y(detections): 47 | idxs = [] 48 | for detection in detections: 49 | idxs.append(detection.bbox[1][1]) 50 | idxs = np.argsort(idxs) 51 | return idxs 52 | 53 | 54 | def create_order_by_score(detections): 55 | idxs = [] 56 | for detection in detections: 57 | idxs.append(detection.score) 58 | idxs = np.argsort(idxs) 59 | return idxs 60 | 61 | 62 | def get_2d_one_way_iou(detection_one, detection_two): 63 | box_one = np.array([ 64 | detection_one.bbox[0][0], detection_one.bbox[0][1], detection_one.bbox[1][0], 65 | detection_one.bbox[1][1] 66 | ]) 67 | box_two = np.array([ 68 | detection_two.bbox[0][0], detection_two.bbox[0][1], detection_two.bbox[1][0], 69 | detection_two.bbox[1][1] 70 | ]) 71 | # determine the (x, y)-coordinates of the intersection rectangle 72 | xA = max(box_one[0], box_two[0]) 73 | yA = max(box_one[1], box_two[1]) 74 | xB = min(box_one[2], box_two[2]) 75 | yB = min(box_one[3], box_two[3]) 76 | # compute the area of intersection rectangle 77 | inter_area = max(0, xB - xA + 1) * max(0, yB - yA + 1) 78 | # compute the area of both the prediction and ground-truth 79 | # rectangles 80 | box_one_area = (box_one[2] - box_one[0] + 1) * (box_one[3] - box_one[1] + 1) 81 | box_two_area = (box_two[2] - box_two[0] + 1) * (box_two[3] - box_two[1] + 1) 82 | # compute the intersection over union by taking the intersection 83 | # area and dividing it by the sum of prediction + ground-truth 84 | # areas - the interesection area 85 | if float(box_one_area) == 0.0: 86 | return 0 87 | return inter_area / float(box_one_area + box_two_area - inter_area) 88 | -------------------------------------------------------------------------------- /src/lib/net/post_processing/obb_outputs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 
import torch 3 | 4 | from src.lib import transform 5 | from src.lib.net.post_processing.epnp import optimize_for_9D 6 | from src.lib.net.post_processing import epnp, eval3d, nms, pose_outputs 7 | from src.lib.net import losses 8 | from src.lib.net.post_processing.eval3d import measure_3d_iou, EvalMetrics, measure_ADD 9 | import copy 10 | 11 | _mask_l1_loss = losses.MaskedL1Loss() 12 | _mse_loss = losses.MSELoss() 13 | 14 | 15 | class OBBOutput: 16 | 17 | def __init__(self, heatmap, vertex_field, z_centroid_field, cov_field, hparams, names = []): 18 | self.heatmap = heatmap 19 | self.vertex_field = vertex_field 20 | self.z_centroid_field = z_centroid_field 21 | self.cov_field = cov_field 22 | self.is_numpy = False 23 | self.hparams = hparams 24 | self.names = names 25 | 26 | # Converters for torch to numpy 27 | def convert_to_numpy_from_torch(self): 28 | 29 | self.heatmap = np.ascontiguousarray(self.heatmap.cpu().numpy()) 30 | if len(self.heatmap.shape) == 2: 31 | self.heatmap = self.heatmap.reshape((1, self.heatmap.shape[0], self.heatmap.shape[1])) 32 | self.vertex_field = np.ascontiguousarray(self.vertex_field.cpu().numpy()) 33 | if len(self.vertex_field.shape) == 3: 34 | self.vertex_field = self.vertex_field.reshape((1,self.vertex_field.shape[0], self.vertex_field.shape[1], self.vertex_field.shape[2])) 35 | self.vertex_field = self.vertex_field.transpose((0, 2, 3, 1)) 36 | self.vertex_field = self.vertex_field / 100.0 37 | self.cov_field = np.ascontiguousarray(self.cov_field.cpu().numpy()) 38 | if len(self.cov_field.shape) == 3: 39 | self.cov_field = self.cov_field.reshape((1,self.cov_field.shape[0], self.cov_field.shape[1], self.cov_field.shape[2])) 40 | self.cov_field = self.cov_field.transpose((0, 2, 3, 1)) 41 | self.cov_field = self.cov_field / 1000.0 42 | self.z_centroid_field = np.ascontiguousarray(self.z_centroid_field.cpu().numpy()) 43 | if len(self.z_centroid_field.shape) ==2: 44 | self.z_centroid_field = self.z_centroid_field.reshape((1, self.z_centroid_field.shape[0], self.z_centroid_field.shape[1])) 45 | self.z_centroid_field = self.z_centroid_field / 100.0 + 1.0 46 | self.is_numpy = True 47 | 48 | def convert_to_numpy_from_torch_gt(self): 49 | self.heatmap = np.ascontiguousarray(self.heatmap.cpu().numpy()) 50 | self.vertex_field = np.ascontiguousarray(self.vertex_field.cpu().numpy()) 51 | self.vertex_field = self.vertex_field.transpose((1, 2, 0)) 52 | self.vertex_field = self.vertex_field / 100.0 53 | self.cov_field = np.ascontiguousarray(self.cov_field.cpu().numpy()) 54 | self.cov_field = self.cov_field.transpose((1, 2, 0)) 55 | self.cov_field = self.cov_field / 1000.0 56 | self.z_centroid_field = np.ascontiguousarray(self.z_centroid_field.cpu().numpy()) 57 | self.z_centroid_field = self.z_centroid_field / 100.0 + 1.0 58 | self.is_numpy = True 59 | 60 | def convert_to_torch_from_numpy(self): 61 | self.vertex_field = self.vertex_field.transpose((2, 0, 1)) 62 | self.vertex_field = 100.0 * self.vertex_field 63 | self.vertex_field = torch.from_numpy(np.ascontiguousarray(self.vertex_field)).float() 64 | self.cov_field = self.cov_field.transpose((2, 0, 1)) 65 | self.cov_field = 1000.0 * self.cov_field 66 | self.cov_field = torch.from_numpy(np.ascontiguousarray(self.cov_field)).float() 67 | self.heatmap = torch.from_numpy(np.ascontiguousarray(self.heatmap)).float() 68 | # Normalize z_centroid by 1. 
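    # (The regression target is (z - 1.0) * 100, i.e. the centroid depth shifted by
    # 1.0 and scaled by 100, presumably to keep it in a convenient numeric range;
    # the convert_to_numpy_from_torch methods above invert this with z / 100 + 1.0.)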
69 | self.z_centroid_field = 100.0 * (self.z_centroid_field - 1.0) 70 | self.z_centroid_field = torch.from_numpy(np.ascontiguousarray(self.z_centroid_field)).float() 71 | self.is_numpy = False 72 | 73 | def get_detections(self, camera_model): 74 | if not self.is_numpy: 75 | self.convert_to_numpy_from_torch() 76 | 77 | poses, scores = compute_oriented_bounding_boxes( 78 | np.copy(self.heatmap[0]), 79 | np.copy(self.vertex_field[0]), 80 | np.copy(self.z_centroid_field[0]), 81 | np.copy(self.cov_field[0]), 82 | camera_model=camera_model 83 | ) 84 | detections = [] 85 | for pose, score in zip(poses, scores): 86 | bbox = epnp.get_2d_bbox_of_9D_box(pose.camera_T_object, pose.scale_matrix, camera_model) 87 | detections.append( 88 | eval3d.Detection( 89 | camera_T_object=pose.camera_T_object, 90 | bbox=bbox, 91 | score=score, 92 | scale_matrix=pose.scale_matrix 93 | ) 94 | ) 95 | detections = nms.run(detections) 96 | return detections 97 | 98 | def get_visualization_img(self, left_img, camera_model=None): 99 | if not self.is_numpy: 100 | self.convert_to_numpy_from_torch() 101 | return draw_oriented_bounding_box_from_outputs( 102 | self.heatmap[0], 103 | self.vertex_field[0], 104 | self.cov_field[0], 105 | self.z_centroid_field[0], 106 | left_img, 107 | camera_model=camera_model 108 | ) 109 | 110 | def get_visualization_img_gt(self, left_img, camera_model=None): 111 | if not self.is_numpy: 112 | self.convert_to_numpy_from_torch_gt() 113 | return draw_oriented_bounding_box_from_outputs( 114 | self.heatmap, 115 | self.vertex_field, 116 | self.cov_field, 117 | self.z_centroid_field, 118 | left_img, 119 | camera_model=camera_model 120 | ) 121 | 122 | def compute_metrics(self, obb_targets, log, camera_model, cad_list): 123 | td_iou_list=[] 124 | IoU_list=[] 125 | mAP_list=[] 126 | num_sample=0 127 | ap_values = [] 128 | ADD_list=[] 129 | ADD_s_list=[] 130 | AUC_list=[] 131 | less2cm_list=[] 132 | AUC_adds_list=[] 133 | less2cm_adds_list=[] 134 | poses=[] 135 | bbox_ext=[] 136 | pose_item=None 137 | 138 | with torch.no_grad(): 139 | # 3D bbox eval 140 | detection_outputs=self.get_detections(camera_model) 141 | vertex_target = torch.stack([obb_target.vertex_field for obb_target in obb_targets]) 142 | z_centroid_field_target = torch.stack([ 143 | obb_target.z_centroid_field for obb_target in obb_targets 144 | ]) 145 | heatmap_target = torch.stack([obb_target.heatmap for obb_target in obb_targets]) 146 | cov_target = torch.stack([obb_target.cov_field for obb_target in obb_targets]) 147 | 148 | obb_target = copy.deepcopy(obb_targets[0]) 149 | obb_target.vertex_field = vertex_target 150 | obb_target.z_centroid_field = z_centroid_field_target 151 | obb_target.heatmap = heatmap_target 152 | obb_target.cov_field = cov_target 153 | pose_target = obb_target 154 | gt_detections = pose_target.get_detections(camera_model) 155 | 156 | # 3D IOU 157 | true_matches, pred_matches, pred_scores, class_labels, ignore_labels, sorted_detections, overlaps = measure_3d_iou(copy.deepcopy(detection_outputs), copy.deepcopy(gt_detections)) 158 | if len(overlaps) == 0: 159 | flag=True 160 | 161 | overlaps = np.array(overlaps) 162 | for pred_match in pred_matches: 163 | for i, index in enumerate(pred_match): 164 | if index == -1: 165 | continue 166 | td_iou_list.append(overlaps[i][int(index)]) 167 | 168 | # Obj Pose eval 169 | ADD, ADD_s, AUC, less2cm, AUC_adds, less2cm_adds=measure_ADD(detection_outputs, gt_detections, CAD_list=cad_list) 170 | 171 | # 3D mAP 172 | td_mAP=EvalMetrics() 173 | 
td_mAP.process_sample(true_matches=true_matches, pred_matches=pred_matches, pred_scores=pred_scores) 174 | ap_values.append(td_mAP.process_dataset()) 175 | 176 | return [np.array(td_iou_list).mean(), np.array(ap_values).mean(), 177 | np.array(ADD).mean(), np.array(ADD_s).mean(), 178 | np.array(AUC).mean(), np.array(less2cm).mean(), 179 | np.array(AUC_adds).mean(), np.array(less2cm_adds).mean()] 180 | 181 | def compute_loss(self, obb_targets, log): 182 | if self.is_numpy: 183 | raise ValueError("Output is not in torch mode") 184 | vertex_target = torch.stack([obb_target.vertex_field for obb_target in obb_targets]) 185 | z_centroid_field_target = torch.stack([ 186 | obb_target.z_centroid_field for obb_target in obb_targets 187 | ]) 188 | heatmap_target = torch.stack([obb_target.heatmap for obb_target in obb_targets]) 189 | cov_target = torch.stack([obb_target.cov_field for obb_target in obb_targets]) 190 | 191 | heatmap_target = heatmap_target.cuda() 192 | vertex_target = vertex_target.cuda() 193 | z_centroid_field_target = z_centroid_field_target.cuda() 194 | cov_target = cov_target.cuda() 195 | 196 | cov_loss = _mask_l1_loss(cov_target, self.cov_field, heatmap_target) 197 | log['cov_loss'] = cov_loss 198 | vertex_loss = _mask_l1_loss(vertex_target, self.vertex_field, heatmap_target) 199 | log['vertex_loss'] = vertex_loss 200 | z_centroid_loss = _mask_l1_loss(z_centroid_field_target, self.z_centroid_field, heatmap_target) 201 | log['z_centroid'] = z_centroid_loss 202 | 203 | heatmap_loss = _mse_loss(heatmap_target, self.heatmap) 204 | log['heatmap'] = heatmap_loss 205 | return self.hparams.loss_vertex_mult * vertex_loss + self.hparams.loss_heatmap_mult * heatmap_loss + self.hparams.loss_z_centroid_mult * z_centroid_loss + self.hparams.loss_rotation_mult * cov_loss 206 | 207 | 208 | def extract_cov_matrices_from_peaks(peaks, cov_matrices_output, scale_factor=8): 209 | assert peaks.shape[1] == 2 210 | cov_matrices = [] 211 | for ii in range(peaks.shape[0]): 212 | index = np.zeros([2]) 213 | index[0] = int(peaks[ii, 0] / scale_factor) 214 | index[1] = int(peaks[ii, 1] / scale_factor) 215 | index = index.astype(np.int) 216 | cov_mat_values = cov_matrices_output[index[0], index[1], :] 217 | cov_matrix = np.array([[cov_mat_values[0], cov_mat_values[3], cov_mat_values[4]], 218 | [cov_mat_values[3], cov_mat_values[1], cov_mat_values[5]], 219 | [cov_mat_values[4], cov_mat_values[5], cov_mat_values[2]]]) 220 | cov_matrices.append(cov_matrix) 221 | return cov_matrices 222 | 223 | 224 | def draw_oriented_bounding_box_from_outputs( 225 | heatmap_output, vertex_output, rotation_output, z_centroid_output, c_img, camera_model=None 226 | ): 227 | poses, _ = compute_oriented_bounding_boxes( 228 | np.copy(heatmap_output), 229 | np.copy(vertex_output), 230 | np.copy(z_centroid_output), 231 | np.copy(rotation_output), 232 | camera_model=camera_model, 233 | max_detections=100, 234 | ) 235 | return pose_outputs.draw_9dof_cv2_boxes(c_img, poses, camera_model=camera_model) 236 | 237 | 238 | def solve_for_rotation_from_cov_matrix(cov_matrix): 239 | assert cov_matrix.shape[0] == 3 240 | assert cov_matrix.shape[1] == 3 241 | U, D, Vh = np.linalg.svd(cov_matrix, full_matrices=True) 242 | d = (np.linalg.det(U) * np.linalg.det(Vh)) < 0.0 243 | if d: 244 | D[-1] = -D[-1] 245 | U[:, -1] = -U[:, -1] 246 | # Rotation from world to points. 
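  # (The regressed 6-vector was unpacked into a symmetric 3x3 matrix by
  # extract_cov_matrices_from_peaks; its left singular vectors U are used as the
  # rotation block of a 4x4 homogeneous transform, with the same sign-flip
  # convention as estimateSimilarityUmeyama in epnp.py.)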
247 | rotation = np.eye(4) 248 | rotation[0:3, 0:3] = U 249 | return rotation 250 | 251 | 252 | def compute_oriented_bounding_boxes( 253 | heatmap_output, 254 | vertex_output, 255 | z_centroid_output, 256 | cov_matrices, 257 | camera_model, 258 | ground_truth_peaks=None, 259 | max_detections=np.inf, 260 | ): 261 | peaks = pose_outputs.extract_peaks_from_centroid( 262 | np.copy(heatmap_output), max_peaks=max_detections 263 | ) 264 | bboxes_ext = pose_outputs.extract_vertices_from_peaks( 265 | np.copy(peaks), np.copy(vertex_output), np.copy(heatmap_output) 266 | ) # Shape: List(np.array([8,2])) --> y,x order 267 | z_centroids = pose_outputs.extract_z_centroid_from_peaks( 268 | np.copy(peaks), np.copy(z_centroid_output) 269 | ) 270 | cov_matrices = pose_outputs.extract_cov_matrices_from_peaks(np.copy(peaks), np.copy(cov_matrices)) 271 | poses = [] 272 | scores = [] 273 | for bbox_ext, z_centroid, cov_matrix, peak in zip(bboxes_ext, z_centroids, cov_matrices, peaks): 274 | bbox_ext_flipped = bbox_ext[:, ::-1] # Switch from yx to xy 275 | # Solve for pose up to a scale factor 276 | error, camera_T_object, scale_matrix = optimize_for_9D( 277 | bbox_ext_flipped.T, camera_model, solve_for_transforms=True 278 | ) 279 | abs_camera_T_object, abs_object_scale = epnp.find_absolute_scale( 280 | -1.0 * z_centroid, camera_T_object, scale_matrix 281 | ) 282 | 283 | poses.append(transform.Pose(camera_T_object=abs_camera_T_object, scale_matrix=abs_object_scale)) 284 | scores.append(heatmap_output[peak[0], peak[1]]) 285 | 286 | return poses, scores 287 | -------------------------------------------------------------------------------- /src/lib/net/post_processing/pose_outputs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | 5 | from skimage.feature import peak_local_max 6 | 7 | from src.lib import transform, color_stuff, camera 8 | from src.lib.net.post_processing.epnp import optimize_for_9D 9 | from src.lib.net.post_processing import epnp, eval3d, nms 10 | from src.lib.net.post_processing.eval3d import measure_3d_iou, EvalMetrics 11 | from src.lib.net import losses 12 | import copy 13 | 14 | _mask_l1_loss = losses.MaskedL1Loss() 15 | _mse_loss = losses.MSELoss() 16 | 17 | 18 | class PoseOutput: 19 | 20 | def __init__(self, heatmap, vertex_field, z_centroid_field, hparams): 21 | self.heatmap = heatmap 22 | self.vertex_field = vertex_field 23 | self.z_centroid_field = z_centroid_field 24 | self.is_numpy = False 25 | self.hparams = hparams 26 | 27 | # Converters for torch to numpy 28 | def convert_to_numpy_from_torch(self): 29 | self.heatmap = np.ascontiguousarray(self.heatmap.cpu().numpy()) 30 | self.vertex_field = np.ascontiguousarray(self.vertex_field.cpu().numpy()) 31 | self.vertex_field = self.vertex_field.transpose((0, 2, 3, 1)) 32 | self.vertex_field = self.vertex_field / 100.0 33 | self.z_centroid_field = np.ascontiguousarray(self.z_centroid_field.cpu().numpy()) 34 | self.z_centroid_field = self.z_centroid_field / 100.0 + 1.0 35 | self.is_numpy = True 36 | 37 | def convert_to_torch_from_numpy(self): 38 | self.vertex_field = self.vertex_field.transpose((2, 0, 1)) 39 | self.vertex_field = 100.0 * self.vertex_field 40 | self.vertex_field = torch.from_numpy(np.ascontiguousarray(self.vertex_field)).float() 41 | self.heatmap = torch.from_numpy(np.ascontiguousarray(self.heatmap)).float() 42 | # Normalize z_centroid by 1. 
43 | self.z_centroid_field = 100.0 * (self.z_centroid_field - 1.0) 44 | self.z_centroid_field = torch.from_numpy(np.ascontiguousarray(self.z_centroid_field)).float() 45 | self.is_numpy = False 46 | 47 | def get_detections(self): 48 | if not self.is_numpy: 49 | self.convert_to_numpy_from_torch() 50 | 51 | poses, scores = compute_9D_poses( 52 | np.copy(self.heatmap[0]), np.copy(self.vertex_field[0]), np.copy(self.z_centroid_field[0]) 53 | ) 54 | 55 | detections = [] 56 | for pose, score in zip(poses, scores): 57 | bbox = epnp.get_2d_bbox_of_9D_box(pose.camera_T_object, pose.scale_matrix) 58 | detections.append( 59 | eval3d.Detection( 60 | camera_T_object=pose.camera_T_object, 61 | bbox=bbox, 62 | score=score, 63 | scale_matrix=pose.scale_matrix 64 | ) 65 | ) 66 | 67 | detections = nms.run(detections) 68 | 69 | return detections 70 | 71 | def get_visualization_img(self, left_img): 72 | if not self.is_numpy: 73 | self.convert_to_numpy_from_torch() 74 | return draw_pose_from_outputs( 75 | self.heatmap[0], 76 | self.vertex_field[0], 77 | self.z_centroid_field[0], 78 | left_img, 79 | max_detections=100, 80 | ) 81 | 82 | def compute_metrics(self, pose_targets, log): 83 | 84 | td_iou_list=[] 85 | IoU_list=[] 86 | mAP_list=[] 87 | num_sample=0 88 | ap_values = [] 89 | ADD_list=[] 90 | ADD_s_list=[] 91 | AUC_list=[] 92 | less2cm_list=[] 93 | AUC_adds_list=[] 94 | less2cm_adds_list=[] 95 | 96 | # 3D bbox eval 97 | detection_outputs=self.get_detections() 98 | 99 | poses=[] 100 | bbox_ext=[] 101 | pose_item=None 102 | for pose_target in pose_targets: 103 | gt_detections = pose_target.get_detections() 104 | # 3D IOU 105 | true_matches, pred_matches, pred_scores, class_labels, ignore_labels, sorted_detections, overlaps = measure_3d_iou(copy.deepcopy(detection_outputs), copy.deepcopy(gt_detections)) 106 | if len(overlaps) == 0: 107 | flag=True 108 | break 109 | td_iou_list.append(overlaps[0][0]) 110 | 111 | # 3D mAP 112 | td_mAP=EvalMetrics() 113 | td_mAP.process_sample(true_matches=true_matches, pred_matches=pred_matches, pred_scores=pred_scores) 114 | ap_values.append(td_mAP.process_dataset()[0]) 115 | 116 | def get_obj_pose_and_bbox(heatmap_output, vertex_output, z_centroid_output, cov_matrices, camera_model): 117 | peaks = self.extract_peaks_from_centroid(np.copy(heatmap_output), max_peaks=np.inf) 118 | bboxes_ext = self.extract_vertices_from_peaks(np.copy(peaks), np.copy(vertex_output), np.copy(heatmap_output)) # Shape: List(np.array([8,2])) --> y,x order 119 | z_centroids = self.extract_z_centroid_from_peaks(np.copy(peaks), np.copy(z_centroid_output)) 120 | cov_matrices = self.extract_cov_matrices_from_peaks(np.copy(peaks), np.copy(cov_matrices)) 121 | poses = [] 122 | for bbox_ext, z_centroid, cov_matrix, peak in zip(bboxes_ext, z_centroids, cov_matrices, peaks): 123 | bbox_ext_flipped = bbox_ext[:, ::-1] # Switch from yx to xy 124 | # Solve for pose up to a scale factor 125 | error, camera_T_object, scale_matrix = optimize_for_9D(bbox_ext_flipped.T, camera_model, solve_for_transforms=True) 126 | abs_camera_T_object, abs_object_scale = epnp.find_absolute_scale( 127 | -1.0 * z_centroid, camera_T_object, scale_matrix 128 | ) 129 | poses.append(transform.Pose(camera_T_object=abs_camera_T_object, scale_matrix=abs_object_scale)) 130 | return poses, bboxes_ext 131 | def compute_loss(self, pose_targets, log): 132 | if self.is_numpy: 133 | raise ValueError("Output is not in torch mode") 134 | vertex_target = torch.stack([pose_target.vertex_field for pose_target in pose_targets]) 135 | 
z_centroid_field_target = torch.stack([ 136 | pose_target.z_centroid_field for pose_target in pose_targets 137 | ]) 138 | heatmap_target = torch.stack([pose_target.heatmap for pose_target in pose_targets]) 139 | 140 | # Move to GPU 141 | heatmap_target = heatmap_target.cuda() 142 | vertex_target = vertex_target.cuda() 143 | z_centroid_field_target = z_centroid_field_target.cuda() 144 | 145 | vertex_loss = _mask_l1_loss(vertex_target, self.vertex_field, heatmap_target) 146 | log['vertex_loss'] = vertex_loss 147 | z_centroid_loss = _mask_l1_loss(z_centroid_field_target, self.z_centroid_field, heatmap_target) 148 | log['z_centroid'] = z_centroid_loss 149 | 150 | heatmap_loss = _mse_loss(heatmap_target, self.heatmap) 151 | log['heatmap'] = heatmap_loss 152 | return self.hparams.loss_vertex_mult * vertex_loss + self.hparams.loss_heatmap_mult * heatmap_loss + self.hparams.loss_z_centroid_mult * z_centroid_loss 153 | 154 | 155 | def extract_peaks_from_centroid( 156 | centroid_heatmap, min_distance=5, min_confidence=0.3, max_peaks=np.inf 157 | ): 158 | peaks = peak_local_max( 159 | centroid_heatmap, 160 | min_distance=min_distance, 161 | threshold_abs=min_confidence, 162 | num_peaks=max_peaks 163 | ) 164 | peaks_old = peak_local_max( 165 | centroid_heatmap, min_distance=min_distance, threshold_abs=min_confidence 166 | ) 167 | 168 | return peaks 169 | 170 | 171 | def extract_vertices_from_peaks(peaks, vertex_fields, c_img, scale_factor=8): 172 | ''' 173 | peaks: np.array (n_objs, 2) 174 | vertex_fields: np.array (h, w, 16) 175 | ''' 176 | assert peaks.shape[1] == 2 177 | assert vertex_fields.shape[2] == 16 178 | height, width = c_img.shape[0:2] 179 | vertex_fields = vertex_fields 180 | vertex_fields[:, :, ::2] = (1.0 - vertex_fields[:, :, ::2]) * (2 * height) - height 181 | vertex_fields[:, :, 1::2] = (1.0 - vertex_fields[:, :, 1::2]) * (2 * width) - width 182 | bboxes = [] 183 | for ii in range(peaks.shape[0]): 184 | bbox = get_bbox_from_vertex(vertex_fields, peaks[ii, :], scale_factor=scale_factor) 185 | bboxes.append(bbox) 186 | return bboxes # Shape: List(np.array([8,2])) 187 | 188 | 189 | def extract_z_centroid_from_peaks(peaks, z_centroid_output, scale_factor=8): 190 | assert peaks.shape[1] == 2 191 | z_centroids = [] 192 | for ii in range(peaks.shape[0]): 193 | index = np.zeros([2]) 194 | index[0] = int(peaks[ii, 0] / scale_factor) 195 | index[1] = int(peaks[ii, 1] / scale_factor) 196 | index = index.astype(np.int) 197 | z_centroids.append(z_centroid_output[index[0], index[1]]) 198 | return z_centroids 199 | 200 | 201 | def extract_cov_matrices_from_peaks(peaks, cov_matrices_output, scale_factor=8): 202 | assert peaks.shape[1] == 2 203 | cov_matrices = [] 204 | for ii in range(peaks.shape[0]): 205 | index = np.zeros([2]) 206 | index[0] = int(peaks[ii, 0] / scale_factor) 207 | index[1] = int(peaks[ii, 1] / scale_factor) 208 | index = index.astype(np.int) 209 | cov_mat_values = cov_matrices_output[index[0], index[1], :] 210 | cov_matrix = np.array([[cov_mat_values[0], cov_mat_values[3], cov_mat_values[4]], 211 | [cov_mat_values[3], cov_mat_values[1], cov_mat_values[5]], 212 | [cov_mat_values[4], cov_mat_values[5], cov_mat_values[2]]]) 213 | cov_matrices.append(cov_matrix) 214 | return cov_matrices 215 | 216 | 217 | def get_bbox_from_vertex(vertex_fields, index, scale_factor=8): 218 | ''' 219 | index: where the vertex is located; order is y, x 220 | vertex_fields: (h,w,16) 221 | ''' 222 | assert index.shape[0] == 2 223 | index[0] = int(index[0] / scale_factor) 224 | index[1] = int(index[1] / 
scale_factor) 225 | bbox = vertex_fields[index[0], index[1], :] 226 | bbox = bbox.reshape([8, 2]) # y, x order 227 | bbox = scale_factor * (index) - bbox 228 | return bbox 229 | 230 | 231 | def draw_peaks(centroid_target, peaks): 232 | centroid_target = np.clip(centroid_target, 0.0, 1.0) * 255.0 233 | color = (0, 0, 255) 234 | height, width = centroid_target.shape 235 | # Make a 3 Channel image. 236 | c_img = np.zeros([centroid_target.shape[0], centroid_target.shape[1], 3]) 237 | c_img[:, :, 1] = centroid_target 238 | for ii in range(peaks.shape[0]): 239 | point = (int(peaks[ii, 1]), int(peaks[ii, 0])) 240 | c_img = cv2.circle(c_img, point, 8, color, -1) 241 | return cv2.resize(c_img, (width, height)) 242 | 243 | 244 | def draw_pose_from_outputs( 245 | heatmap_output, vertex_output, z_centroid_output, c_img, max_detections=np.inf 246 | ): 247 | poses, _ = compute_9D_poses( 248 | np.copy(heatmap_output), 249 | np.copy(vertex_output), 250 | np.copy(z_centroid_output), 251 | max_detections=max_detections, 252 | ) 253 | return draw_9dof_cv2_boxes(c_img, poses) 254 | 255 | 256 | def draw_pose_9D_from_detections(detections, c_img): 257 | successes = [] 258 | poses = [] 259 | for detection in detections: 260 | poses.append( 261 | transform.Pose( 262 | camera_T_object=detection.camera_T_object, scale_matrix=detection.scale_matrix 263 | ) 264 | ) 265 | successes.append(detection.success) 266 | return draw_9dof_cv2_boxes(c_img, poses, successes=successes) 267 | 268 | 269 | def solve_for_rotation_from_cov_matrix(cov_matrix): 270 | assert cov_matrix.shape[0] == 3 271 | assert cov_matrix.shape[1] == 3 272 | U, D, Vh = np.linalg.svd(cov_matrix, full_matrices=True) 273 | d = (np.linalg.det(U) * np.linalg.det(Vh)) < 0.0 274 | if d: 275 | D[-1] = -D[-1] 276 | U[:, -1] = -U[:, -1] 277 | # Rotation from world to points. 
278 | rotation = np.eye(4) 279 | rotation[0:3, 0:3] = U 280 | return rotation 281 | 282 | 283 | def compute_9D_poses(heatmap_output, vertex_output, z_centroid_output, max_detections=np.inf): 284 | peaks = extract_peaks_from_centroid(np.copy(heatmap_output), max_peaks=max_detections) 285 | bboxes_ext = extract_vertices_from_peaks( 286 | np.copy(peaks), np.copy(vertex_output), np.copy(heatmap_output) 287 | ) 288 | z_centroids = extract_z_centroid_from_peaks(np.copy(peaks), np.copy(z_centroid_output)) 289 | poses = [] 290 | scores = [] 291 | for bbox_ext, z_centroid, peak in zip(bboxes_ext, z_centroids, peaks): 292 | bbox_ext_flipped = bbox_ext[:, ::-1] 293 | # Solve for pose up to a scale factor 294 | error, camera_T_object, scale_matrix = optimize_for_9D( 295 | bbox_ext_flipped.T, solve_for_transforms=True 296 | ) 297 | # Assign correct depth factor 298 | abs_camera_T_object, abs_scale_matrix = epnp.find_absolute_scale( 299 | z_centroid, camera_T_object, scale_matrix 300 | ) 301 | poses.append(transform.Pose(camera_T_object=abs_camera_T_object, scale_matrix=abs_scale_matrix)) 302 | scores.append(heatmap_output[peak[0], peak[1]]) 303 | return poses, scores 304 | 305 | 306 | def draw_9dof_cv2_boxes(c_img, poses, camera_model=None, successes=None): 307 | boxes = [] 308 | for pose in poses: 309 | # Compute the bounds of the boxes current size and location 310 | unit_box_homopoints = camera.convert_points_to_homopoints(epnp._WORLD_T_POINTS.T) 311 | morphed_homopoints = pose.camera_T_object @ (pose.scale_matrix @ unit_box_homopoints) 312 | if camera_model == None: 313 | camera_model = camera.HSRCamera() 314 | else: 315 | camera_model = camera_model 316 | morphed_pixels = camera.convert_homopixels_to_pixels( 317 | camera_model.K_matrix @ morphed_homopoints 318 | ).T 319 | boxes.append(morphed_pixels[:, ::-1]) 320 | return draw_9dof_box(c_img, boxes, successes=successes) 321 | 322 | 323 | def draw_9dof_box(c_img, boxes, successes=None): 324 | if len(boxes) == 0: 325 | return c_img 326 | if successes is None: 327 | colors = color_stuff.get_colors(len(boxes)) 328 | else: 329 | colors = [] 330 | for success in successes: 331 | #TODO(michael.laskey): Move to Enum Structure 332 | if success == 1: 333 | colors.append(np.array([0, 255, 0]).astype(np.uint8)) 334 | elif success == -1: 335 | colors.append(np.array([255, 255, 0]).astype(np.uint8)) 336 | elif success == -2: 337 | colors.append(np.array([0, 0, 255]).astype(np.uint8)) 338 | else: 339 | colors.append(np.array([255, 0, 0]).astype(np.uint8)) 340 | c_img = cv2.cvtColor(np.array(c_img), cv2.COLOR_BGR2RGB) 341 | for vertices, color in zip(boxes, colors): 342 | vertices = vertices.astype(np.int) 343 | points = [] 344 | vertex_colors = (255, 0, 0) 345 | line_color = (int(color[0]), int(color[1]), int(color[2])) 346 | circle_colors = color_stuff.get_colors(8) 347 | for i, circle_color in zip(range(vertices.shape[0]), circle_colors): 348 | color = vertex_colors 349 | point = (int(vertices[i, 1]), int(vertices[i, 0])) 350 | c_img = cv2.circle(c_img, point, 1, (0, 255, 0), -1) 351 | points.append(point) 352 | 353 | # Draw the lines 354 | thickness = 1 355 | 356 | c_img = cv2.line(c_img, points[0], points[3], line_color, thickness) 357 | c_img = cv2.line(c_img, points[0], points[4], line_color, thickness) 358 | c_img = cv2.line(c_img, points[0], points[1], line_color, thickness) 359 | 360 | c_img = cv2.line(c_img, points[5], points[1], line_color, thickness) 361 | c_img = cv2.line(c_img, points[5], points[4], line_color, thickness) 362 | c_img = 
--------------------------------------------------------------------------------
/src/lib/net/post_processing/segmentation_outputs.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import torch
4 | from torch.nn import functional as F
5 | from src.lib import color_stuff
6 | 
7 | 
8 | class SegmentationOutput:
9 | 
10 |   def __init__(self, seg_pred, hparams):
11 |     self.seg_pred = seg_pred
12 |     self.is_numpy = False
13 |     self.hparams = hparams
14 | 
15 |   # Converters for torch to numpy
16 |   def convert_to_numpy_from_torch(self):
17 |     if not self.is_numpy:
18 |       self.seg_pred = np.ascontiguousarray(self.seg_pred.detach().cpu().numpy())
19 |       self.is_numpy = True
20 | 
21 |   def convert_to_torch_from_numpy(self):
22 |     self.seg_pred = torch.from_numpy(np.ascontiguousarray(self.seg_pred)).long()
23 |     self.is_numpy = False
24 | 
25 |   def get_visualization_img(self, left_image):
26 |     if not self.is_numpy:
27 |       self.convert_to_numpy_from_torch()
28 |     return draw_segmentation_mask(left_image, self.seg_pred[0])
29 | 
30 |   def get_visualization_img_gt(self, left_image):
31 |     if not self.is_numpy:
32 |       self.convert_to_numpy_from_torch()
33 |     return draw_segmentation_mask_gt(left_image, self.seg_pred, num_classes=np.unique(self.seg_pred).shape[0])
34 | 
35 |   def get_prediction(self):
36 |     if not self.is_numpy:
37 |       self.convert_to_numpy_from_torch()
38 |     return self.seg_pred[0]
39 | 
40 |   def compute_loss(self, seg_targets, log):
41 |     if self.is_numpy:
42 |       raise ValueError("Output is not in torch mode")
43 |     seg_target_stacked = []
44 |     for seg_target in seg_targets:
45 |       seg_target_stacked.append(seg_target.seg_pred)
46 |     seg_target_batch = torch.stack(seg_target_stacked)
47 |     seg_target_batch = seg_target_batch.cuda()
48 |     seg_loss = F.cross_entropy(self.seg_pred, seg_target_batch, reduction="mean", ignore_index=-100)
49 |     log['segmentation'] = seg_loss
50 |     return self.hparams.loss_seg_mult * seg_loss
51 | 
52 |   def compute_metrics(self, seg_targets, log, threshold=0.5, mode='Train'):
53 |     smooth = 1
54 |     IoUs = []
55 |     mAPs = []
56 |     num_class = 0
57 | 
58 |     for i, seg_target in enumerate(seg_targets):
59 |       num_class = seg_target.seg_pred.max() + 1
60 |       seg_target.convert_to_numpy_from_torch()
61 |       self.convert_to_numpy_from_torch()
62 |       self.seg_pred[i] = self.seg_pred[i].astype(int)
63 |       seg_target.seg_pred = seg_target.seg_pred.astype(int)
64 | 
65 |       seg_out_pred = np.argmax(self.seg_pred[i], axis=0)
66 | 
67 |       IoU = [[self.computeIoU(seg_out_pred.astype(int) == k, seg_target.seg_pred.astype(int) == j)
68 |               for k in range(seg_target.seg_pred.astype(int).max() + 1)]
69 |              for j in range(seg_target.seg_pred.astype(int).max() + 1)]
70 | 
71 |       IoUs.append([IoU[j][j] for j in range(seg_target.seg_pred.astype(int).max() + 1)])
72 | 
73 |       positive = [np.array(IoU[j]) > threshold for j in range(seg_target.seg_pred.astype(int).max() + 1)]
74 |       TP = np.array([positive[j][j] for j in range(seg_target.seg_pred.astype(int).max() + 1)]).sum()
75 | 
76 |       num_iou = seg_target.seg_pred.astype(int).max() + 1
77 |       mAPs.append(TP / num_iou)
78 | 
79 |     IoU = np.array(IoUs).sum() / num_class / (i + 1)
80 |     mAP = np.array(mAPs).sum() / (i + 1)
81 | 
82 |     log[f'{mode}_seg_IoU'] = IoU
83 |     log[f'{mode}_seg_mAP'] = mAP
84 |     return IoU, mAP
85 | 
86 |   def computeIoU(self, pred, ground):
87 |     assert (pred.max() <= 1) and (ground.max() <= 1), f'Incorrect behaviour {pred.max()}, {ground.max()}'
88 |     assert (pred.min() >= 0) and (ground.min() >= 0), f'Incorrect behaviour {pred.min()}, {ground.min()}'
89 |     smooth = 1
90 |     intersection = (pred * ground).sum()
91 |     union = (pred + ground).sum()
92 |     return (intersection + smooth) / (union + smooth)
93 | 
94 | 
95 | def draw_segmentation_mask_gt(color_img, seg_mask, num_classes=5):
96 |   assert len(seg_mask.shape) == 2
97 |   seg_mask = seg_mask.astype(np.uint8)
98 |   colors = color_stuff.get_panoptic_colors()
99 |   color_img = color_img_to_gray(color_img)
100 |   for ii, color in zip(range(num_classes), colors):
101 |     colored_mask = np.zeros([seg_mask.shape[0], seg_mask.shape[1], 3])
102 |     colored_mask[seg_mask == ii, :] = color
103 |     color_img = cv2.addWeighted(
104 |         color_img.astype(np.uint8), 0.9, colored_mask.astype(np.uint8), 0.4, 0
105 |     )
106 |   return cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB)
107 | 
108 | 
109 | def color_img_to_gray(image):
110 |   img = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
111 |   gray_scale_img = np.zeros(image.shape)
112 |   # Broadcast the single gray channel into all three output channels.
113 |   for i in range(3):
114 |     gray_scale_img[:, :, i] = img
115 |   return gray_scale_img
116 | 
117 | 
118 | def draw_segmentation_mask(color_img, seg_mask):
119 |   assert len(seg_mask.shape) == 3
120 |   num_classes = seg_mask.shape[0]
121 |   # Convert to predictions
122 |   seg_mask_predictions = np.argmax(seg_mask, axis=0)
123 |   return draw_segmentation_mask_gt(color_img, seg_mask_predictions, num_classes=num_classes)
124 | 
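One detail worth flagging in computeIoU above: the denominator is the sum of the two mask areas rather than their set union, so heavily overlapping masks score slightly below a textbook IoU. A standalone sketch of the per-class quantity that compute_metrics builds its IoU matrix from, written against hypothetical integer label maps rather than the SegmentationOutput class:

    import numpy as np

    def per_class_iou(pred_labels, gt_labels, num_classes, smooth=1):
        # pred_labels, gt_labels: (H, W) integer class maps.
        ious = []
        for k in range(num_classes):
            pred_k = (pred_labels == k).astype(int)
            gt_k = (gt_labels == k).astype(int)
            intersection = (pred_k * gt_k).sum()
            # Matches computeIoU: area(pred) + area(gt), not the set union.
            denominator = pred_k.sum() + gt_k.sum()
            ious.append((intersection + smooth) / (denominator + smooth))
        return ious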
--------------------------------------------------------------------------------
/src/lib/occlusions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | from src.lib import camera
4 | from src.lib.net.post_processing import epnp
5 | 
6 | 
7 | ### 3D Occlusions
8 | def object_is_outside_image(detection, camera_model):
9 |   bbox = epnp.get_2d_bbox_of_9D_box(detection.camera_T_object, detection.scale_matrix, camera_model)
10 |   width = camera_model.width
11 |   height = camera_model.height
12 | 
13 |   if bbox[0][0] < 0 or bbox[0][1] < 0:
14 |     return True
15 |   if bbox[1][0] > width or bbox[1][1] > height:
16 |     return True
17 |   return False
18 | 
19 | 
20 | def get_bbox_image(detection, camera_model):
21 |   bbox = epnp.get_2d_bbox_of_9D_box(detection.camera_T_object, detection.scale_matrix, camera_model)
22 |   width = camera_model.width
23 |   height = camera_model.height
24 |   img = np.zeros([height, width])
25 |   img[int(bbox[0][1]):int(bbox[1][1]), int(bbox[0][0]):int(bbox[1][0])] = 1.0
26 |   return img
27 | 
28 | 
29 | def mark_occlusions_in_detections(
30 |     detections, occlusion_score=0.5, camera_model=None, allow_outside_of_image=False
31 | ):
32 |   if camera_model is None:
33 |     camera_model = camera.FMKCamera()
34 |   for ii in range(len(detections)):
35 |     if object_is_outside_image(detections[ii], camera_model):
36 |       detections[ii].ignore = True
37 |       continue
38 |     bbox_unocc = get_bbox_image(detections[ii], camera_model)
39 |     bbox_occ = np.copy(bbox_unocc)
40 |     bbox_prop = np.copy(bbox_occ)
41 |     for detection_proposal in detections:
42 |       # Check if the object is behind the target object.
43 |       if detection_proposal.camera_T_object[2, 3] >= detections[ii].camera_T_object[2, 3]:
44 |         continue
45 |       bbox_proposal = get_bbox_image(detection_proposal, camera_model)
46 |       bbox_occ = bbox_occ - bbox_proposal
47 |       bbox_prop = bbox_prop + bbox_proposal
48 |     occlusion_level = np.sum(bbox_occ > 0) / np.sum(bbox_unocc > 0)
49 |     if occlusion_level < occlusion_score:
50 |       detections[ii].ignore = True
51 | 
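mark_occlusions_in_detections rasterizes each detection's projected 2D box, subtracts the boxes of detections that sit closer to the camera (smaller z in camera_T_object[2, 3]), and flags a detection as ignore when the remaining visible fraction drops below occlusion_score, or when its box leaves the image entirely. A small wrapper sketch, assuming detection objects that expose camera_T_object, scale_matrix, and an ignore attribute as this module expects; the helper name is illustrative:

    from src.lib import occlusions

    def drop_occluded_detections(detections, visible_fraction=0.5, camera_model=None):
        # Marks heavily occluded / out-of-frame detections in place, then filters them.
        occlusions.mark_occlusions_in_detections(
            detections, occlusion_score=visible_fraction, camera_model=camera_model
        )
        # Detections never touched by the marker may lack the flag; treat them as visible.
        return [d for d in detections if not getattr(d, 'ignore', False)]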
--------------------------------------------------------------------------------
/src/lib/transform.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | 
3 | import numpy as np
4 | 
5 | X_AXIS = np.array([1., 0., 0.])
6 | Y_AXIS = np.array([0., 1., 0.])
7 | Z_AXIS = np.array([0., 0., 1.])
8 | 
9 | 
10 | @dataclasses.dataclass
11 | class Pose:
12 |   camera_T_object: np.ndarray
13 |   scale_matrix: np.ndarray = np.eye(4)
14 | 
15 | 
16 | class Transform:
17 | 
18 |   def __init__(self, matrix=None):
19 |     if matrix is None:
20 |       self.matrix = np.eye(4)
21 |     else:
22 |       self.matrix = matrix
23 |     self.is_concrete = True
24 | 
25 |   def apply_transform(self, transform):
26 |     assert self.is_concrete
27 |     assert isinstance(transform, Transform)
28 |     self.matrix = self.matrix @ transform.matrix
29 | 
30 |   def inverse(self):
31 |     assert self.is_concrete
32 |     return Transform(matrix=np.linalg.inv(self.matrix))
33 | 
34 |   def __repr__(self):
35 |     assert self.matrix.shape == (4, 4)
36 |     if self.is_SE3():
37 |       return f'Transform(translate={self.translation})'
38 |     else:
39 |       return f'Transform(IS_NOT_SE3,matrix={self.matrix})'
40 | 
41 |   def is_SE3(self):
42 |     return matrixIsSE3(self.matrix)
43 | 
44 |   @property
45 |   def translation(self):
46 |     return self.matrix[:3, 3]
47 | 
48 |   @translation.setter
49 |   def translation(self, value):
50 |     assert value.shape == (3,)
51 |     self.matrix[:3, 3] = value
52 | 
53 |   @property
54 |   def rotation(self):
55 |     return self.matrix[:3, :3]
56 | 
57 |   @rotation.setter
58 |   def rotation(self, value):
59 |     assert value.shape == (3, 3)
60 |     self.matrix[:3, :3] = value
61 | 
62 |   @classmethod
63 |   def from_aa(cls, axis=X_AXIS, angle_deg=0., translation=None):
64 |     assert axis.shape == (3,)
65 |     matrix = np.eye(4)
66 |     if angle_deg != 0.:
67 |       matrix[:3, :3] = axis_angle_to_rotation_matrix(axis, np.deg2rad(angle_deg))
68 |     if translation is not None:
69 |       translation = np.array(translation)
70 |       assert translation.shape == (3,)
71 |       matrix[:3, 3] = translation
72 |     return cls(matrix=matrix)
73 | 
74 | 
75 | def matrixIsSE3(matrix):
76 |   if not np.allclose(matrix[3, :], np.array([0., 0., 0., 1.])):
77 |     return False
78 |   rot = matrix[:3, :3]
79 |   if not np.allclose(rot @ rot.T, np.eye(3)):
80 |     return False
81 |   if not np.isclose(np.linalg.det(rot), 1.):
82 |     return False
83 |   return True
84 | 
85 | 
86 | def find_closest_SE3(matrix):
87 |   matrix = np.copy(matrix)
88 |   assert np.allclose(matrix[3, :], np.array([0., 0., 0., 1.]))
89 |   rotation = matrix[:3, :3]
90 |   u, s, vh = np.linalg.svd(rotation)
91 |   matrix[:3, :3] = u @ vh
92 |   assert matrixIsSE3(matrix)
93 |   return matrix
94 | 
95 | 
96 | def axis_angle_to_rotation_matrix(axis, theta):
97 |   """Return the rotation matrix associated with counterclockwise rotation about
98 |   the given axis by theta radians.
99 | 
100 |   Args:
101 |     axis: a list which specifies a unit axis
102 |     theta: an angle in radians, for which to rotate around by
103 |   Returns:
104 |     A 3x3 rotation matrix
105 |   """
106 |   axis = np.asarray(axis)
107 |   axis = axis / np.sqrt(np.dot(axis, axis))
108 |   a = np.cos(theta / 2.0)
109 |   b, c, d = -axis * np.sin(theta / 2.0)
110 |   aa, bb, cc, dd = a * a, b * b, c * c, d * d
111 |   bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
112 |   return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
113 |                    [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
114 |                    [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]])
115 | 
--------------------------------------------------------------------------------
/utils/rotation_SVD.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | def find_covariance(scene_pd: "np.array['3,N', float]") -> "np.array['3,3', float]":
5 |     scene_mean = scene_pd.mean(axis=-1).reshape((3, 1))
6 |     scene_hat = scene_pd - scene_mean
7 |     scene_hat = scene_hat[:, np.random.choice(np.arange(scene_hat.shape[1]), size=min(1000, scene_hat.shape[1] // 10), replace=False)]
8 |     covariance = np.cov(scene_hat)
9 |     return covariance
10 | 
11 | 
12 | def covariance_obj(obj: "open3d.geometry.TriangleMesh", rotation: "np.array['3,3', float]") -> "np.array['3,3', float]":
13 |     import open3d
14 |     pcd = open3d.geometry.TriangleMesh.sample_points_uniformly(obj, number_of_points=10000)
15 |     points = np.asarray(pcd.points).T
16 |     points_t = rotation @ points
17 |     C = find_covariance(points_t)
18 |     return C
19 | 
20 | 
21 | def covariance_mesh(obj: "bpy.types.Object", rotation: "np.array['3,3', float]") -> "np.array['3,3', float]":
22 |     import bpy
23 |     coords = np.array([(obj.matrix_world @ v.co) for v in obj.data.vertices])  # Nx3
24 |     coords = coords.T  # 3xN
25 |     points_t = rotation @ coords  # 3xN
26 |     C = find_covariance(points_t)
27 |     return C
28 | 
29 | 
30 | def covariance2rotation(covariance: "np.array['3,3', float]") -> "np.array['3,3', float]":
31 |     U, s, V = np.linalg.svd(covariance, full_matrices=True)
32 |     d = (np.linalg.det(U) * np.linalg.det(V)) < 0.0
33 |     if d:
34 |         s[-1] = -s[-1]
35 |         U[:, -1] = -U[:, -1]
36 |     return U
37 | 
38 | 
39 | def rand_rotation_matrix(deflection=1.0):
40 |     theta, phi, z = np.random.uniform(size=(3,))
41 |     theta = theta * 2.0 * deflection * np.pi  # Rotation about the pole (Z).
42 |     phi = phi * 2.0 * np.pi  # For direction of pole deflection.
43 |     z = z * 2.0 * deflection  # For magnitude of pole deflection.
44 |     r = np.sqrt(z)
45 |     V = (np.sin(phi) * r,
46 |          np.cos(phi) * r,
47 |          np.sqrt(2.0 - z))
48 |     st = np.sin(theta)
49 |     ct = np.cos(theta)
50 |     R = np.array(((ct, st, 0), (-st, ct, 0), (0, 0, 1)))
51 |     M = (np.outer(V, V) - np.eye(3)).dot(R)
52 |     return M
53 | 
54 | 
55 | def test_r_svd():
56 |     import open3d
57 |     r = rand_rotation_matrix()
58 |     obj = open3d.io.read_triangle_mesh("../KeyPose/Model&Keypoint/mug_0.obj")
59 |     pcd = open3d.geometry.TriangleMesh.sample_points_uniformly(obj, number_of_points=10000)
60 |     points = np.asarray(pcd.points).T
61 |     points_t = r @ points + np.random.rand(3, 1)
62 |     C = find_covariance(points_t)
63 |     r_hat = covariance2rotation(C)
64 |     print(r - r_hat)
65 | 
66 | 
67 | if __name__ == "__main__":
68 |     test_r_svd()
69 | 
--------------------------------------------------------------------------------
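covariance2rotation recovers an object's principal axes from the covariance of its rotated, translated points, which is exactly what test_r_svd exercises with an Open3D mesh. A mesh-free variant of the same check using only numpy, assuming the repo root is on PYTHONPATH; note the recovered matrix matches the applied rotation only up to the sign and ordering ambiguities inherent in principal axes, so the product below should be close to a signed permutation of the identity:

    import numpy as np
    from utils.rotation_SVD import covariance2rotation, find_covariance, rand_rotation_matrix

    def check_rotation_recovery(num_points=20000, seed=0):
        rng = np.random.default_rng(seed)
        # Anisotropic synthetic cloud so the three principal axes are well separated.
        points = np.diag([3.0, 2.0, 1.0]) @ rng.standard_normal((3, num_points))
        r = rand_rotation_matrix()
        points_t = r @ points + rng.random((3, 1))  # rotate and translate the cloud
        r_hat = covariance2rotation(find_covariance(points_t))
        return np.abs(r_hat.T @ r)  # ~identity up to axis sign flips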