├── README.md
├── SAM-6D
├── Data
│ ├── Example
│ │ ├── camera.json
│ │ ├── depth.png
│ │ ├── obj_000005.ply
│ │ ├── outputs
│ │ │ └── sam6d_results
│ │ │ │ ├── vis_ism.png
│ │ │ │ └── vis_pem.png
│ │ └── rgb.png
│ └── README.md
├── Instance_Segmentation_Model
│ ├── LICENSE
│ ├── README.md
│ ├── configs
│ │ ├── callback
│ │ │ ├── base.yaml
│ │ │ ├── checkpoint
│ │ │ │ └── base.yaml
│ │ │ └── lr
│ │ │ │ └── base.yaml
│ │ ├── data
│ │ │ └── bop.yaml
│ │ ├── download.yaml
│ │ ├── hydra
│ │ │ ├── hydra_logging
│ │ │ │ ├── console.yaml
│ │ │ │ ├── custom.yaml
│ │ │ │ └── rich.yaml
│ │ │ └── job_logging
│ │ │ │ ├── console.yaml
│ │ │ │ ├── custom.yaml
│ │ │ │ └── rich.yaml
│ │ ├── machine
│ │ │ ├── local.yaml
│ │ │ ├── slurm.yaml
│ │ │ └── trainer
│ │ │ │ ├── local.yaml
│ │ │ │ └── slurm.yaml
│ │ ├── model
│ │ │ ├── ISM_fastsam.yaml
│ │ │ ├── ISM_sam.yaml
│ │ │ ├── descriptor_model
│ │ │ │ └── dinov2.yaml
│ │ │ └── segmentor_model
│ │ │ │ ├── fast_sam.yaml
│ │ │ │ └── sam.yaml
│ │ ├── run_inference.yaml
│ │ └── user
│ │ │ └── default.yaml
│ ├── download_dinov2.py
│ ├── download_fastsam.py
│ ├── download_sam.py
│ ├── environment.yml
│ ├── exp.sh
│ ├── model
│ │ ├── __init__.py
│ │ ├── detector.py
│ │ ├── dinov2.py
│ │ ├── fast_sam.py
│ │ ├── layers
│ │ │ ├── __init__.py
│ │ │ ├── attention.py
│ │ │ ├── block.py
│ │ │ ├── dino_head.py
│ │ │ ├── drop_path.py
│ │ │ ├── layer_scale.py
│ │ │ ├── mlp.py
│ │ │ ├── patch_embed.py
│ │ │ └── swiglu_ffn.py
│ │ ├── loss.py
│ │ ├── sam.py
│ │ ├── utils.py
│ │ └── vision_transformer.py
│ ├── provider
│ │ ├── base_bop.py
│ │ ├── bop.py
│ │ └── bop_pbr.py
│ ├── run_inference.py
│ ├── run_inference_custom.py
│ ├── segment_anything
│ │ ├── __init__.py
│ │ ├── automatic_mask_generator.py
│ │ ├── build_sam.py
│ │ ├── modeling
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ ├── image_encoder.py
│ │ │ ├── mask_decoder.py
│ │ │ ├── prompt_encoder.py
│ │ │ ├── sam.py
│ │ │ └── transformer.py
│ │ ├── predictor.py
│ │ └── utils
│ │ │ ├── __init__.py
│ │ │ ├── amg.py
│ │ │ ├── onnx.py
│ │ │ └── transforms.py
│ └── utils
│ │ ├── bbox_utils.py
│ │ ├── inout.py
│ │ ├── logging.py
│ │ ├── poses
│ │ ├── create_template_poses.py
│ │ ├── find_neighbors.py
│ │ ├── fps.py
│ │ ├── pose_utils.py
│ │ ├── predefined_poses
│ │ │ ├── cam_poses_level0.npy
│ │ │ ├── cam_poses_level1.npy
│ │ │ ├── cam_poses_level2.npy
│ │ │ ├── idx_all_level0_in_level2.npy
│ │ │ ├── idx_all_level1_in_level2.npy
│ │ │ ├── idx_all_level2_in_level2.npy
│ │ │ ├── obj_poses_level0.npy
│ │ │ ├── obj_poses_level1.npy
│ │ │ └── obj_poses_level2.npy
│ │ └── pyrender.py
│ │ ├── trimesh_utils.py
│ │ └── weight.py
├── Pose_Estimation_Model
│ ├── README.md
│ ├── config
│ │ └── base.yaml
│ ├── dependencies.sh
│ ├── download_sam6d-pem.py
│ ├── model
│ │ ├── coarse_point_matching.py
│ │ ├── feature_extraction.py
│ │ ├── fine_point_matching.py
│ │ ├── pointnet2
│ │ │ ├── _ext_src
│ │ │ │ ├── include
│ │ │ │ │ ├── ball_query.h
│ │ │ │ │ ├── cuda_utils.h
│ │ │ │ │ ├── group_points.h
│ │ │ │ │ ├── interpolate.h
│ │ │ │ │ ├── sampling.h
│ │ │ │ │ └── utils.h
│ │ │ │ └── src
│ │ │ │ │ ├── ball_query.cpp
│ │ │ │ │ ├── ball_query_gpu.cu
│ │ │ │ │ ├── bindings.cpp
│ │ │ │ │ ├── group_points.cpp
│ │ │ │ │ ├── group_points_gpu.cu
│ │ │ │ │ ├── interpolate.cpp
│ │ │ │ │ ├── interpolate_gpu.cu
│ │ │ │ │ ├── sampling.cpp
│ │ │ │ │ └── sampling_gpu.cu
│ │ │ ├── pointnet2_modules.py
│ │ │ ├── pointnet2_test.py
│ │ │ ├── pointnet2_utils.py
│ │ │ ├── pytorch_utils.py
│ │ │ └── setup.py
│ │ ├── pose_estimation_model.py
│ │ └── transformer.py
│ ├── provider
│ │ ├── bop_test_dataset.py
│ │ └── training_dataset.py
│ ├── run_inference_custom.py
│ ├── test_bop.py
│ ├── train.py
│ └── utils
│ │ ├── bop_object_utils.py
│ │ ├── data_utils.py
│ │ ├── draw_utils.py
│ │ ├── loss_utils.py
│ │ ├── model_utils.py
│ │ └── solver.py
├── Render
│ ├── render_bop_templates.py
│ ├── render_custom_templates.py
│ ├── render_gso_templates.py
│ └── render_shapenet_templates.py
├── demo.sh
├── environment.yaml
└── prepare.sh
└── pics
├── overview_pem.png
├── overview_sam_6d.png
└── vis.gif
/README.md:
--------------------------------------------------------------------------------
1 | #
SAM-6D: Segment Anything Model Meets Zero-Shot 6D Object Pose Estimation
2 |
3 | #### [Jiehong Lin](https://jiehonglin.github.io/), [Lihua Liu](https://github.com/foollh), [Dekun Lu](https://github.com/WuTanKun), [Kui Jia](http://kuijia.site/)
4 | #### CVPR 2024
5 | #### [[Paper]](https://arxiv.org/abs/2311.15707)
6 |
7 |
8 |
9 |
10 |
11 |
12 | ## News
13 | - [2024/03/07] We publish an updated version of our paper on [ArXiv](https://arxiv.org/abs/2311.15707).
14 | - [2024/02/29] Our paper is accepted by CVPR2024!
15 |
16 |
17 | ## Update Log
18 | - [2024/03/05] We update the demo to support [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM), you can do this by specifying `SEGMENTOR_MODEL=fastsam` in demo.sh.
19 | - [2024/03/03] We upload a [docker image](https://hub.docker.com/r/lihualiu/sam-6d/tags) for running custom data.
20 | - [2024/03/01] We update the released [model](https://drive.google.com/file/d/1joW9IvwsaRJYxoUmGo68dBVg-HcFNyI7/view?usp=sharing) of PEM. For the new model, a larger batchsize of 32 is set, while that of the old is 12.
21 |
22 | ## Overview
23 | In this work, we employ Segment Anything Model as an advanced starting point for **zero-shot 6D object pose estimation** from RGB-D images, and propose a novel framework, named **SAM-6D**, which utilizes the following two dedicated sub-networks to realize the focused task:
24 | - [x] [Instance Segmentation Model](https://github.com/JiehongLin/SAM-6D/tree/main/SAM-6D/Instance_Segmentation_Model)
25 | - [x] [Pose Estimation Model](https://github.com/JiehongLin/SAM-6D/tree/main/SAM-6D/Pose_Estimation_Model)
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 | ## Getting Started
34 |
35 | ### 1. Preparation
36 | Please clone the repository locally:
37 | ```
38 | git clone https://github.com/JiehongLin/SAM-6D.git
39 | ```
40 | Install the environment and download the model checkpoints:
41 | ```
42 | cd SAM-6D
43 | sh prepare.sh
44 | ```
45 | We also provide a [docker image](https://hub.docker.com/r/lihualiu/sam-6d/tags) for convenience.
46 |
47 | ### 2. Evaluation on the custom data
48 | ```
49 | # set the paths
50 | export CAD_PATH=Data/Example/obj_000005.ply # path to a given cad model(mm)
51 | export RGB_PATH=Data/Example/rgb.png # path to a given RGB image
52 | export DEPTH_PATH=Data/Example/depth.png # path to a given depth map(mm)
53 | export CAMERA_PATH=Data/Example/camera.json # path to given camera intrinsics
54 | export OUTPUT_DIR=Data/Example/outputs # path to a pre-defined file for saving results
55 |
56 | # run inference
57 | cd SAM-6D
58 | sh demo.sh
59 | ```
60 |
61 |
62 |
63 | ## Citation
64 | If you find our work useful in your research, please consider citing:
65 |
66 | @article{lin2023sam,
67 | title={SAM-6D: Segment Anything Model Meets Zero-Shot 6D Object Pose Estimation},
68 | author={Lin, Jiehong and Liu, Lihua and Lu, Dekun and Jia, Kui},
69 | journal={arXiv preprint arXiv:2311.15707},
70 | year={2023}
71 | }
72 |
73 |
74 | ## Contact
75 |
76 | If you have any questions, please feel free to contact the authors.
77 |
78 | Jiehong Lin: [mortimer.jh.lin@gmail.com](mailto:mortimer.jh.lin@gmail.com)
79 |
80 | Lihua Liu: [lihualiu.scut@gmail.com](mailto:lihualiu.scut@gmail.com)
81 |
82 | Dekun Lu: [derkunlu@gmail.com](mailto:derkunlu@gmail.com)
83 |
84 | Kui Jia: [kuijia@gmail.com](kuijia@gmail.com)
85 |
86 |
--------------------------------------------------------------------------------
/SAM-6D/Data/Example/camera.json:
--------------------------------------------------------------------------------
1 | {"cam_K": [572.4114, 0.0, 325.2611, 0.0, 573.57043, 242.04899, 0.0, 0.0, 1.0], "depth_scale": 1.0}
--------------------------------------------------------------------------------
/SAM-6D/Data/Example/depth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Data/Example/depth.png
--------------------------------------------------------------------------------
/SAM-6D/Data/Example/outputs/sam6d_results/vis_ism.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Data/Example/outputs/sam6d_results/vis_ism.png
--------------------------------------------------------------------------------
/SAM-6D/Data/Example/outputs/sam6d_results/vis_pem.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Data/Example/outputs/sam6d_results/vis_pem.png
--------------------------------------------------------------------------------
/SAM-6D/Data/Example/rgb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Data/Example/rgb.png
--------------------------------------------------------------------------------
/SAM-6D/Data/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## Data Structure
3 | Our data structure in `Data` folder is constructed as follows:
4 | ```
5 | Data
6 | ├── MegaPose-Training-Data
7 | ├── MegaPose-GSO
8 | ├──google_scanned_objects
9 | ├──templates
10 | └──train_pbr_web
11 | ├── MegaPose-ShapeNetCore
12 | ├──shapenetcorev2
13 | ├──templates
14 | └──train_pbr_web
15 | ├── BOP # https://bop.felk.cvut.cz/datasets/
16 | ├──tudl
17 | ├──lmo
18 | ├──ycbv
19 | ├──icbin
20 | ├──hb
21 | ├──itodd
22 | └──tless
23 | └── BOP-Templates
24 | ├──tudl
25 | ├──lmo
26 | ├──ycbv
27 | ├──icbin
28 | ├──hb
29 | ├──itodd
30 | └──tless
31 | ```
32 |
33 |
34 | ## Data Download
35 |
36 | ### Training Datasets
37 | For training the Pose Estimation Model, you may download the rendered images of [c](https://github.com/thodan/bop_toolkit/blob/master/docs/bop_challenge_2023_training_datasets.md) provided by BOP official in the respective `MegaPose-Training-Data/MegaPose-GSO/train_pbr_web` and `MegaPose-Training-Data/MegaPose-ShapeNetCore/train_pbr_web` folders.
38 |
39 | We use the [pre-processed object models](https://www.paris.inria.fr/archive_ylabbeprojectsdata/megapose/tars/) of the two datasets, provided by [MegePose](https://github.com/megapose6d/megapose6d), and download them in the `MegaPose-Training-Data/MegaPose-GSO/google_scanned_objects` and `MegaPose-Training-Data/MegaPose-ShapeNetCore/shapenetcorev2` folders, respectively.
40 |
41 |
42 | ### BOP Datasets
43 | To evaluate our SAM-6D on BOP datasets, you may download the test data and the object CAD models of the seven core datasets from [BOP official](https://bop.felk.cvut.cz/datasets/). For each dataset, the structure could be constructed as follows:
44 |
45 | ```
46 | BOP
47 | ├── lmo
48 | ├──models # object CAD models
49 | ├──test # bop19 test set
50 | ├──(train_pbr) # maybe used in instance segmentation
51 | ...
52 | ...
53 | ```
54 |
55 | You may also download the `train_pbr` data of the datasets for template selection in the Instance Segmentation Model following [CNOS](https://github.com/nv-nguyen/cnos?tab=readme-ov-file).
56 |
57 |
58 |
59 | ## Template Rendering
60 |
61 | ### Requirements
62 |
63 | * blenderproc
64 | * trimesh
65 | * numpy
66 | * cv2
67 |
68 |
69 | ### Template Rendering of Training Objects
70 |
71 | We generate two-view templates for each training object via [Blenderproc](https://github.com/DLR-RM/BlenderProc). You may run the following commands to render the templates for `MegaPose-GSO` dataset:
72 |
73 | ```
74 | cd ../Render/
75 | blenderproc run render_gso_templates.py
76 | ```
77 | and the commands for `shapenetcorev2` dataset:
78 |
79 | ```
80 | cd ../Render/
81 | blenderproc run render_shapenet_templates.py
82 | ```
83 |
84 | ### Template Rendering of Test Objects
85 | We generate 42 templates for each test object following the [CNOS](https://github.com/nv-nguyen/cnos?tab=readme-ov-file) via [Blenderproc](https://github.com/DLR-RM/BlenderProc).
86 |
87 |
88 | #### Custom Objects
89 |
90 | For a custom object, you can run the following commands to render the templates:
91 | ```
92 | cd ../Render/
93 | blenderproc run render_custom_templates.py --cad_path $CAD_PATH --output_dir $OUTPUT_DIR
94 | ```
95 | The string "CAD_PATH" is the path to your CAD and the string "OUTPUT_DIR" is the path to save templates.
96 |
97 |
98 | #### BOP Datasets
99 | You may run the following commands to render the tempates for the objects in BOP datasets:
100 | ```
101 | cd ../Render/
102 | blenderproc run render_bop_templates.py --dataset_name $DATASET
103 | ```
104 | The string "DATASET" could be set as `lmo`, `icbin`, `itodd`, `hb`, `tless`, `tudl` or `ycbv`. We also offer downloadable rendered templates [[link](https://drive.google.com/drive/folders/1fXt5Z6YDPZTJICZcywBUhu5rWnPvYAPI?usp=sharing)].
105 |
106 |
107 | ## Acknowledgements
108 | - [MegaPose](https://github.com/megapose6d/megapose6d)
109 | - [CNOS](https://github.com/nv-nguyen/cnos?tab=readme-ov-file)
110 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Van Nguyen Nguyen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/README.md:
--------------------------------------------------------------------------------
1 | # Instance Segmentation Model (ISM) for SAM-6D
2 |
3 |
4 | ## Requirements
5 | The code has been tested with
6 | - python 3.9.6
7 | - pytorch 2.0.0
8 | - CUDA 11.3
9 |
10 | Create conda environment:
11 |
12 | ```
13 | conda env create -f environment.yml
14 | conda activate sam6d-ism
15 |
16 | # for using SAM
17 | pip install git+https://github.com/facebookresearch/segment-anything.git
18 |
19 | # for using fastSAM
20 | pip install ultralytics==8.0.135
21 | ```
22 |
23 |
24 | ## Data Preparation
25 |
26 | Please refer to [[link](https://github.com/JiehongLin/SAM-6D/tree/main/SAM-6D/Data)] for more details.
27 |
28 |
29 | ## Foundation Model Download
30 |
31 | Download model weights of [Segmenting Anything](https://github.com/facebookresearch/segment-anything):
32 | ```
33 | python download_sam.py
34 | ```
35 |
36 | Download model weights of [Fast Segmenting Anything](https://github.com/CASIA-IVA-Lab/FastSAM):
37 | ```
38 | python download_fastsam.py
39 | ```
40 |
41 | Download model weights of ViT pre-trained by [DINOv2](https://github.com/facebookresearch/dinov2):
42 | ```
43 | python download_dinov2.py
44 | ```
45 |
46 |
47 | ## Evaluation on BOP Datasets
48 |
49 | To evaluate the model on BOP datasets, please run the following commands:
50 |
51 | ```
52 | # Specify a specific GPU
53 | export CUDA_VISIBLE_DEVICES=0
54 |
55 | # with sam
56 | python run_inference.py dataset_name=$DATASET
57 |
58 | # with fastsam
59 | python run_inference.py dataset_name=$DATASET model=ISM_fastsam
60 | ```
61 |
62 | The string "DATASET" could be set as `lmo`, `icbin`, `itodd`, `hb`, `tless`, `tudl` or `ycbv`.
63 |
64 |
65 | ## Acknowledgements
66 |
67 | - [CNOS](https://github.com/nv-nguyen/cnos)
68 | - [SAM](https://github.com/facebookresearch/segment-anything)
69 | - [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM)
70 | - [DINOv2](https://github.com/facebookresearch/dinov2)
71 |
72 |
73 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/configs/callback/base.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - checkpoint: base
3 | - lr: base
4 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/configs/callback/checkpoint/base.yaml:
--------------------------------------------------------------------------------
1 | _target_: pytorch_lightning.callbacks.ModelCheckpoint
2 |
3 | dirpath: ${save_dir}/checkpoints
4 | save_last: true
5 | verbose: true
6 | save_top_k: -1
7 | every_n_train_steps: 10000 # number of backprop step
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/configs/callback/lr/base.yaml:
--------------------------------------------------------------------------------
1 | _target_: pytorch_lightning.callbacks.LearningRateMonitor
2 |
3 | logging_interval: step
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/configs/data/bop.yaml:
--------------------------------------------------------------------------------
1 | root_dir: ${machine.root_dir}/BOP/
2 | source_url: https://bop.felk.cvut.cz/media/data/bop_datasets/
3 | unzip_mode: unzip
4 |
5 | reference_dataloader:
6 | _target_: provider.bop.BOPTemplate
7 | obj_ids:
8 | template_dir: ${machine.root_dir}/BOP/
9 | level_templates: 0
10 | pose_distribution: all
11 | processing_config:
12 | image_size: ${model.descriptor_model.image_size}
13 | max_num_scenes: 10 # config for reference frames selection
14 | max_num_frames: 500
15 | min_visib_fract: 0.8
16 | num_references: 200
17 | use_visible_mask: True
18 |
19 | query_dataloader:
20 | _target_: provider.bop.BaseBOPTest
21 | root_dir: ${machine.root_dir}/BOP/
22 | split:
23 | reset_metaData: True
24 | processing_config:
25 | image_size: ${model.descriptor_model.image_size}
26 |
27 | train_datasets:
28 | megapose-gso:
29 | identifier: bop23_datasets/megapose-gso/gso_models.json
30 | mapping_image_key: /bop23_datasets/megapose-gso/train_pbr_web/key_to_shard.json
31 | prefix: bop23_datasets/megapose-gso/train_pbr_web/
32 | shard_ids: [0, 1039]
33 | megapose-shapenet:
34 | identifier: bop23_datasets/megapose-shapenet/shapenet_models.json
35 | mapping_image_key: bop23_datasets/megapose-shapenet/train_pbr_web/key_to_shard.json
36 | prefix: bop23_datasets/megapose-shapenet/train_pbr_web
37 | shard_ids: [0, 1039]
38 |
39 | datasets:
40 | lm:
41 | cad: lm_models.zip
42 | test: lm_test_bop19.zip
43 | pbr_train: lm_train_pbr.zip
44 | obj_names: [001_ape, 002_benchvise, 003_bowl, 004_camera, 005_can, 006_cat, 007_cup, 008_driller, 009_duck, 010_eggbox, 011_glue, 012_holepuncher, 013_iron, 014_lamp, 015_phone]
45 | lmo:
46 | cad: lmo_models.zip
47 | test: lmo_test_bop19.zip
48 | pbr_train: lm_train_pbr.zip
49 | obj_names: [001_ape, 005_can, 006_cat, 008_driller, 009_duck, 010_eggbox, 011_glue, 012_holepuncher]
50 | tless:
51 | cad: tless_models.zip
52 | test: tless_test_primesense_bop19.zip
53 | pbr_train: tless_train_pbr.zip
54 | obj_names: [001_obj, 002_obj, 003_obj, 004_obj, 005_obj, 006_obj, 007_obj, 008_obj, 009_obj, 010_obj, 011_obj, 012_obj, 013_obj, 014_obj, 015_obj, 016_obj, 017_obj, 018_obj, 019_obj, 020_obj, 021_obj, 022_obj, 023_obj, 024_obj, 025_obj, 026_obj, 027_obj, 028_obj, 029_obj, 030_obj]
55 | itodd:
56 | cad: itodd_models.zip
57 | test: itodd_test_bop19.zip
58 | pbr_train: itodd_train_pbr.zip
59 | obj_names: [001_obj, 002_obj, 003_obj, 004_obj, 005_obj, 006_obj, 007_obj, 008_obj, 009_obj, 010_obj, 011_obj, 012_obj, 013_obj, 014_obj, 015_obj, 016_obj, 017_obj, 018_obj, 019_obj, 020_obj, 021_obj, 022_obj, 023_obj, 024_obj, 025_obj, 026_obj, 027_obj, 028_obj]
60 | hb:
61 | cad: hb_models.zip
62 | test: hb_test_primesense_bop19.zip
63 | pbr_train: hb_train_pbr.zip
64 | obj_names: [001_red_teddy, 002_bench_wise, 003_car, 004_white_cow, 005_white_pig, 006_white_cup, 007_driller, 008_green_rabbit, 009_holepuncher, 010_brown_unknown, 011_brown_unknown, 012_black_unknown, 013_black_unknown, 014_white_painter, 015_small_unkown, 016_small_unkown, 017_small_unkown, 018_cake_box, 019_minion, 020_colored_dog, 021_phone, 022_animal, 023_yellow_dog, 024_cassette_player, 025_red_racing_car, 026_motobike, 027_heels, 028_dinosaur, 029_tea_box, 030_animal, 031_japanese_toy, 032_white_racing_car, 033_yellow_rabbit]
65 | hope:
66 | cad: hope_models.zip
67 | test: hope_test_bop19.zip
68 | obj_names: [001_alphabet_soup, 002_bbq_sauce, 003_butter, 004_cherries, 005_chocolate_pudding, 006_cookies, 007_corn, 008_cream_cheese, 009_granola_bar, 010_green_bean, 011_tomato_sauce, 012_macaroni_cheese, 013_mayo, 014_milk, 015_mushroom, 016_mustard, 017_orange_juice, 018_parmesa_cheese, 019_peaches, 020_peaches_and_carrot, 021_pineapple, 022_popcorn, 023_raisins, 024_ranch_dressing, 025_spaghetti, 026_tomato_sauce, 027_tuna, 028_yogurt]
69 | ycbv:
70 | cad: ycbv_models.zip
71 | test: ycbv_test_bop19.zip
72 | pbr_train: ycbv_train_pbr.zip
73 | obj_names: [002_master_chef_can, 003_cracker_box, 004_sugar_box, 005_tomato_soup_can, 006_mustard_bottle, 007_tuna_fish_can, 008_pudding_box, 009_gelatin_box, 010_potted_meat_can, 011_banana, 019_pitcher_base, 021_bleach_cleanser, 024_bowl, 025_mug, 035_power_drill, 036_wood_block, 037_scissors, 040_large_marker, 051_large_clamp, 052_extra_large_clamp, 061_foam_brick]
74 | ruapc:
75 | cad: ruapc_models.zip
76 | test: ruapc_test_bop19.zip
77 | obj_names: [001_red_copper_box, 002_red_cheezit_box, 003_crayon_box, 004_white_glue, 005_expo_box, 006_greenies, 007_straw_cup, 008_stick_box, 009_highland_sticker, 010_red_tennis_ball, 011_yellow_duck, 012_blue_oreo, 013_pen_box, 014_yellow_standley]
78 | icbin:
79 | cad: icbin_models.zip
80 | test: icbin_test_bop19.zip
81 | pbr_train: icbin_train_pbr.zip
82 | obj_names: [001_blue_cup, 002_blue_box]
83 | icmi:
84 | cad: icmi_models.zip
85 | test: icmi_test_bop19.zip
86 | obj_names: [001_obj, 002_obj, 003_obj, 004_obj, 005_obj, 006_obj]
87 | tudl:
88 | cad: tudl_models.zip
89 | test: tudl_test_bop19.zip
90 | pbr_train: tudl_train_pbr.zip
91 | obj_names: [001_dinosaur, 002_white_ape, 003_white_can]
92 | tyol:
93 | cad: tyol_models.zip
94 | test: tyol_test_bop19.zip
95 | obj_names: [001_obj, 002_obj, 003_obj, 004_obj, 005_obj, 006_obj, 007_obj, 008_obj, 009_obj, 010_obj, 011_obj, 012_obj, 013_obj, 014_obj, 015_obj, 016_obj, 017_obj, 018_obj, 019_obj, 020_obj, 021_obj]
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/configs/download.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - user: default
3 | - machine: local
4 | - data: bop
5 | - _self_
6 | - override hydra/hydra_logging: disabled
7 | - override hydra/job_logging: disabled
8 |
9 | hydra:
10 | output_subdir: null
11 | run:
12 | dir: .
13 |
14 | level: 2
15 | disable_output: true
16 | num_workers: 4
17 | gpus: "0,1,2,3"
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/configs/hydra/hydra_logging/console.yaml:
--------------------------------------------------------------------------------
1 | version: 1
2 |
3 | filters:
4 | onlyimportant:
5 | (): utils.logging.LevelsFilter
6 | levels:
7 | - CRITICAL
8 | - ERROR
9 | - WARNING
10 | noimportant:
11 | (): utils.logging.LevelsFilter
12 | levels:
13 | - INFO
14 | - DEBUG
15 | - NOTSET
16 |
17 | formatters:
18 | simple:
19 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
20 | datefmt: '%d/%m/%y %H:%M:%S'
21 |
22 | colorlog:
23 | (): colorlog.ColoredFormatter
24 | format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s]
25 | - %(message)s'
26 | datefmt: '%d/%m/%y %H:%M:%S'
27 |
28 | log_colors:
29 | DEBUG: purple
30 | INFO: green
31 | WARNING: yellow
32 | ERROR: red
33 | CRITICAL: red
34 |
35 | handlers:
36 | console:
37 | class: logging.StreamHandler
38 | formatter: colorlog
39 | stream: ext://sys.stdout
40 |
41 | root:
42 | level: ${logger_level}
43 | handlers:
44 | - console
45 |
46 | disable_existing_loggers: false
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/configs/hydra/hydra_logging/custom.yaml:
--------------------------------------------------------------------------------
1 | version: 1
2 |
3 | formatters:
4 | colorlog:
5 | (): colorlog.ColoredFormatter
6 | format: '[%(cyan)s%(asctime)s%(reset)s][%(purple)sHYDRA%(reset)s] %(message)s'
7 | datefmt: '%d/%m/%y %H:%M:%S'
8 |
9 | handlers:
10 | console:
11 | class: logging.StreamHandler
12 | formatter: colorlog
13 | stream: ext://sys.stdout
14 |
15 | root:
16 | level: INFO
17 | handlers:
18 | - console
19 |
20 | disable_existing_loggers: false
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/configs/hydra/hydra_logging/rich.yaml:
--------------------------------------------------------------------------------
1 | version: 1
2 |
3 | formatters:
4 | colorlog:
5 | (): colorlog.ColoredFormatter
6 | format: '[%(cyan)s%(asctime)s%(reset)s][%(purple)sHYDRA%(reset)s] %(message)s'
7 | datefmt: '%d/%m/%y %H:%M:%S'
8 |
9 | handlers:
10 | console:
11 | class: rich.logging.RichHandler # logging.StreamHandler
12 | formatter: colorlog
13 |
14 | root:
15 | level: INFO
16 | handlers:
17 | - console
18 |
19 | disable_existing_loggers: false
20 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/configs/hydra/job_logging/console.yaml:
--------------------------------------------------------------------------------
1 | version: 1
2 |
3 | filters:
4 | onlyimportant:
5 | (): utils.logging.LevelsFilter
6 | levels:
7 | - CRITICAL
8 | - ERROR
9 | - WARNING
10 | noimportant:
11 | (): utils.logging.LevelsFilter
12 | levels:
13 | - INFO
14 | - DEBUG
15 | - NOTSET
16 |
17 | formatters:
18 | simple:
19 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
20 | datefmt: '%d/%m/%y %H:%M:%S'
21 |
22 | colorlog:
23 | (): colorlog.ColoredFormatter
24 | format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s]
25 | - %(message)s'
26 | datefmt: '%d/%m/%y %H:%M:%S'
27 |
28 | log_colors:
29 | DEBUG: purple
30 | INFO: green
31 | WARNING: yellow
32 | ERROR: red
33 | CRITICAL: red
34 |
35 | handlers:
36 | console:
37 | class: logging.StreamHandler
38 | formatter: colorlog
39 | stream: ext://sys.stdout
40 |
41 | root:
42 | level: ${logger_level}
43 | handlers:
44 | - console
45 |
46 | disable_existing_loggers: false
47 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/configs/hydra/job_logging/custom.yaml:
--------------------------------------------------------------------------------
1 | version: 1
2 |
3 | filters:
4 | onlyimportant:
5 | (): utils.logging.LevelsFilter
6 | levels:
7 | - CRITICAL
8 | - ERROR
9 | - WARNING
10 | noimportant:
11 | (): utils.logging.LevelsFilter
12 | levels:
13 | - INFO
14 | - DEBUG
15 | - NOTSET
16 |
17 | formatters:
18 | simple:
19 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
20 | datefmt: '%d/%m/%y %H:%M:%S'
21 |
22 | colorlog:
23 | (): colorlog.ColoredFormatter
24 | format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s]
25 | - %(message)s'
26 | datefmt: '%d/%m/%y %H:%M:%S'
27 |
28 | log_colors:
29 | DEBUG: purple
30 | INFO: green
31 | WARNING: yellow
32 | ERROR: red
33 | CRITICAL: red
34 |
35 | handlers:
36 | console:
37 | class: logging.StreamHandler
38 | formatter: colorlog
39 | stream: ext://sys.stdout
40 |
41 | file_out:
42 | class: logging.FileHandler
43 | formatter: simple
44 | filename: logs.out
45 | filters:
46 | - noimportant
47 |
48 | file_err:
49 | class: logging.FileHandler
50 | formatter: simple
51 | filename: logs.err
52 | filters:
53 | - onlyimportant
54 |
55 | root:
56 | level: ${logger_level}
57 | handlers:
58 | - console
59 | - file_out
60 | - file_err
61 |
62 | disable_existing_loggers: false
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/configs/hydra/job_logging/rich.yaml:
--------------------------------------------------------------------------------
1 | version: 1
2 |
3 | filters:
4 | onlyimportant:
5 | (): utils.logging.LevelsFilter
6 | levels:
7 | - CRITICAL
8 | - ERROR
9 | - WARNING
10 | noimportant:
11 | (): utils.logging.LevelsFilter
12 | levels:
13 | - INFO
14 | - DEBUG
15 | - NOTSET
16 |
17 | formatters:
18 | verysimple:
19 | format: '%(message)s'
20 | simple:
21 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
22 | datefmt: '%d/%m/%y %H:%M:%S'
23 |
24 | colorlog:
25 | (): colorlog.ColoredFormatter
26 | format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s]
27 | - %(message)s'
28 | datefmt: '%d/%m/%y %H:%M:%S'
29 |
30 | log_colors:
31 | DEBUG: purple
32 | INFO: green
33 | WARNING: yellow
34 | ERROR: red
35 | CRITICAL: red
36 |
37 | handlers:
38 | console:
39 | class: rich.logging.RichHandler # logging.StreamHandler
40 | formatter: verysimple # colorlog
41 | rich_tracebacks: true
42 |
43 | file_out:
44 | class: logging.FileHandler
45 | formatter: simple
46 | filename: logs.out
47 | filters:
48 | - noimportant
49 |
50 | file_err:
51 | class: logging.FileHandler
52 | formatter: simple
53 | filename: logs.err
54 | filters:
55 | - onlyimportant
56 |
57 | root:
58 | level: ${logger_level}
59 | handlers:
60 | - console
61 | - file_out
62 | - file_err
63 |
64 | disable_existing_loggers: false
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/configs/machine/local.yaml:
--------------------------------------------------------------------------------
1 | name: local
2 |
3 | # specific attributes to this machine
4 | batch_size: 16
5 | num_workers: 10
6 |
7 |
8 | # Define trainer, datasets
9 | defaults:
10 | - trainer: local
11 |
12 | dryrun: True
13 | root_dir: ${user.local_root_dir}
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/configs/machine/slurm.yaml:
--------------------------------------------------------------------------------
1 | name: slurm
2 |
3 | # specific attributes to this machine
4 | batch_size: 16
5 | num_workers: 10
6 |
7 | # Define trainer
8 | defaults:
9 | - trainer: slurm
10 |
11 | dryrun: True
12 | root_dir: ${user.slurm_root_dir}
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/configs/machine/trainer/local.yaml:
--------------------------------------------------------------------------------
1 | _target_: pytorch_lightning.Trainer
2 |
3 | max_epochs: 1000
4 | accelerator: gpu
5 | deterministic: False
6 | detect_anomaly: False
7 | enable_progress_bar: True
8 | strategy: ddp
9 | precision: 16
10 | accumulate_grad_batches:
11 | val_check_interval: 2000
12 | log_every_n_steps: 1
13 | num_sanity_val_steps: 2
14 | limit_val_batches: 20
15 | limit_test_batches:
16 |
17 | callbacks: "${oc.dict.values: callback}"
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/configs/machine/trainer/slurm.yaml:
--------------------------------------------------------------------------------
1 | _target_: pytorch_lightning.Trainer
2 |
3 | max_epochs: 1000
4 | accelerator: gpu
5 | deterministic: False
6 | detect_anomaly: False
7 | enable_progress_bar: True
8 | strategy: ddp
9 | precision: 16
10 | accumulate_grad_batches:
11 | val_check_interval: 1000
12 | log_every_n_steps: 1
13 | num_sanity_val_steps: 2
14 | limit_val_batches: 20
15 | limit_test_batches: 1.0
16 |
17 | callbacks: "${oc.dict.values: callback}"
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/configs/model/ISM_fastsam.yaml:
--------------------------------------------------------------------------------
1 | _target_: model.detector.Instance_Segmentation_Model
2 | log_interval: 5
3 | log_dir: ${save_dir}
4 | segmentor_width_size: 640 # make it stable
5 | descriptor_width_size: 640
6 | visible_thred: 0.5
7 | pointcloud_sample_num: 2048
8 |
9 | defaults:
10 | - segmentor_model: fast_sam
11 | - descriptor_model: dinov2
12 |
13 | post_processing_config:
14 | mask_post_processing:
15 | min_box_size: 0.05 # relative to image size
16 | min_mask_size: 3e-4 # relative to image size
17 | nms_thresh: 0.25
18 |
19 | matching_config:
20 | metric:
21 | _target_: model.loss.PairwiseSimilarity
22 | metric: cosine
23 | chunk_size: 16
24 | aggregation_function: avg_5
25 | confidence_thresh: 0.2
26 |
27 | onboarding_config:
28 | rendering_type: pbr
29 | reset_descriptors: False
30 | level_templates: 0 # 0 is coarse, 1 is medium, 2 is dense
31 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/configs/model/ISM_sam.yaml:
--------------------------------------------------------------------------------
1 | _target_: model.detector.Instance_Segmentation_Model
2 | log_interval: 5
3 | log_dir: ${save_dir}
4 | segmentor_width_size: 640 # make it stable
5 | descriptor_width_size: 640
6 | visible_thred: 0.5
7 | pointcloud_sample_num: 2048
8 |
9 | defaults:
10 | - segmentor_model: sam
11 | - descriptor_model: dinov2
12 |
13 | post_processing_config:
14 | mask_post_processing:
15 | min_box_size: 0.05 # relative to image size
16 | min_mask_size: 3e-4 # relative to image size
17 | nms_thresh: 0.25
18 |
19 | matching_config:
20 | metric:
21 | _target_: model.loss.PairwiseSimilarity
22 | metric: cosine
23 | chunk_size: 16
24 | aggregation_function: avg_5
25 | confidence_thresh: 0.2
26 |
27 | onboarding_config:
28 | rendering_type: pbr
29 | reset_descriptors: False
30 | level_templates: 0 # 0 is coarse, 1 is medium, 2 is dense
31 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/configs/model/descriptor_model/dinov2.yaml:
--------------------------------------------------------------------------------
1 | _target_: model.dinov2.CustomDINOv2
2 | model_name: dinov2_vitl14
3 | # model:
4 | # _target_: torch.hub.load
5 | # repo_or_dir: facebookresearch/dinov2
6 | # model: ${model.descriptor_model.model_name}
7 | checkpoint_dir: ./checkpoints/dinov2/
8 | token_name: x_norm_clstoken
9 | descriptor_width_size: ${model.descriptor_width_size}
10 | image_size: 224
11 | chunk_size: 16
12 | validpatch_thresh: 0.5
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/configs/model/segmentor_model/fast_sam.yaml:
--------------------------------------------------------------------------------
1 | _target_: model.fast_sam.FastSAM
2 | checkpoint_path: ./checkpoints/FastSAM/FastSAM-x.pt
3 | segmentor_width_size: ${model.segmentor_width_size}
4 | config:
5 | iou_threshold: 0.9
6 | conf_threshold: 0.05
7 | max_det: 200
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/configs/model/segmentor_model/sam.yaml:
--------------------------------------------------------------------------------
1 | _target_: model.sam.CustomSamAutomaticMaskGenerator
2 | points_per_batch: 64
3 | stability_score_thresh: 0.85
4 | box_nms_thresh: 0.7
5 | min_mask_region_area: 0
6 | crop_overlap_ratio:
7 | pred_iou_thresh: 0.88
8 | segmentor_width_size: ${model.segmentor_width_size}
9 | sam:
10 | _target_: model.sam.load_sam
11 | model_type: vit_h
12 | checkpoint_dir: ./checkpoints/segment-anything/
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/configs/run_inference.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - user: default
3 | - machine: local
4 | - callback: base
5 | - data: bop
6 | - model: ISM_sam
7 | - _self_
8 | - override hydra/hydra_logging: disabled
9 | - override hydra/job_logging: disabled
10 |
11 | hydra:
12 | output_subdir: null
13 | run:
14 | dir: .
15 |
16 | save_dir: ./log/${name_exp}
17 | name_exp: sam
18 | dataset_name:
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/configs/user/default.yaml:
--------------------------------------------------------------------------------
1 | local_root_dir: ../Data
2 | slurm_root_dir: ../Data
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/download_dinov2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import os, sys
4 | import os.path as osp
5 | from utils.inout import get_root_project
6 |
7 | # set level logging
8 | logging.basicConfig(level=logging.INFO)
9 | import logging
10 | import hydra
11 | from omegaconf import DictConfig, OmegaConf
12 | model_dict = {
13 | "dinov2_vits14": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth",
14 | "dinov2_vitb14": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth",
15 | "dinov2_vitl14": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth",
16 | "dinov2_vitg14": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth",
17 | }
18 |
19 | def download_model(url, output_path):
20 | import os
21 |
22 | command = f"wget -O {output_path}/{url.split('/')[-1]} {url} --no-check-certificate"
23 | os.system(command)
24 |
25 |
26 | @hydra.main(
27 | version_base=None,
28 | config_path="./configs",
29 | config_name="download",
30 | )
31 | def download(cfg: DictConfig) -> None:
32 | model_name = "dinov2_vitl14" # default segmentation model used in CNOS
33 | save_dir = osp.join(get_root_project(), "checkpoints/dinov2")
34 | os.makedirs(save_dir, exist_ok=True)
35 | download_model(model_dict[model_name], save_dir)
36 |
37 | if __name__ == "__main__":
38 | download()
39 |
40 |
41 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/download_fastsam.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import os, sys
4 | import os.path as osp
5 | from utils.inout import get_root_project
6 |
7 | # set level logging
8 | logging.basicConfig(level=logging.INFO)
9 | import logging
10 | import hydra
11 | from omegaconf import DictConfig, OmegaConf
12 |
13 | def download_model(output_path):
14 | import os
15 | command = f"gdown --no-cookies --no-check-certificate -O '{output_path}/FastSAM-x.pt' 1m1sjY4ihXBU1fZXdQ-Xdj-mDltW-2Rqv"
16 | os.system(command)
17 |
18 |
19 | @hydra.main(
20 | version_base=None,
21 | config_path="./configs",
22 | config_name="download",
23 | )
24 | def download(cfg: DictConfig) -> None:
25 | save_dir = osp.join(get_root_project(), "checkpoints/FastSAM")
26 | os.makedirs(save_dir, exist_ok=True)
27 | download_model(save_dir)
28 |
29 | if __name__ == "__main__":
30 | download()
31 |
32 |
33 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/download_sam.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import os, sys
4 | import os.path as osp
5 | from utils.inout import get_root_project
6 |
7 | # set level logging
8 | logging.basicConfig(level=logging.INFO)
9 | import logging
10 | import hydra
11 | from omegaconf import DictConfig, OmegaConf
12 | model_dict = {
13 | "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", # 2560 MB
14 | "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", # 1250 MB
15 | "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
16 | } # 375 GB
17 |
18 | def download_model(url, output_path):
19 | import os
20 |
21 | command = f"wget -O {output_path}/{url.split('/')[-1]} {url} --no-check-certificate"
22 | os.system(command)
23 |
24 |
25 | @hydra.main(
26 | version_base=None,
27 | config_path="./configs",
28 | config_name="download",
29 | )
30 | def download(cfg: DictConfig) -> None:
31 | model_name = "vit_h" # default segmentation model used in CNOS
32 | save_dir = osp.join(get_root_project(), "checkpoints/segment-anything")
33 | os.makedirs(save_dir, exist_ok=True)
34 | download_model(model_dict[model_name], save_dir)
35 |
36 | if __name__ == "__main__":
37 | download()
38 |
39 |
40 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/environment.yml:
--------------------------------------------------------------------------------
1 | name: sam6d-ism
2 | channels:
3 | - xformers
4 | - pytorch
5 | - nvidia
6 | - conda-forge
7 | - defaults
8 | dependencies:
9 | - pip
10 | - python=3.9.6
11 | - pip:
12 | - torch
13 | - torchvision
14 | - omegaconf
15 | - torchmetrics==0.10.3
16 | - fvcore
17 | - iopath
18 | - xformers==0.0.18
19 | - opencv-python
20 | - pycocotools
21 | - matplotlib
22 | - onnxruntime
23 | - onnx
24 | - scipy
25 | - ffmpeg
26 | - hydra-colorlog
27 | - hydra-core
28 | - gdown
29 | - pytorch-lightning==1.8.1
30 | - pandas
31 | - ruamel.yaml
32 | - pyrender
33 | - wandb
34 | - distinctipy
35 | - imageio
36 | # bop https://github.com/thodan/bop_toolkit.git
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/exp.sh:
--------------------------------------------------------------------------------
1 | # with sam
2 | python run_inference.py dataset_name=icbin
3 |
4 | # with fastsam
5 | python run_inference.py dataset_name=icbin model=ISM_fastsam
6 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Instance_Segmentation_Model/model/__init__.py
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/model/fast_sam.py:
--------------------------------------------------------------------------------
1 | from ultralytics import YOLO
2 | from pathlib import Path
3 | from typing import Union
4 | import numpy as np
5 | import cv2
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from segment_anything.utils.amg import MaskData
10 | import logging
11 | import os.path as osp
12 | from typing import Any, Dict, List, Optional, Tuple
13 | import pytorch_lightning as pl
14 | from ultralytics import yolo # noqa
15 | from ultralytics.nn.autobackend import AutoBackend
16 |
17 |
18 | class CustomYOLO(YOLO):
19 | def __init__(
20 | self,
21 | model,
22 | iou,
23 | conf,
24 | max_det,
25 | segmentor_width_size,
26 | selected_device="cpu",
27 | verbose=False,
28 | ):
29 | YOLO.__init__(
30 | self,
31 | model,
32 | )
33 | self.overrides["iou"] = iou
34 | self.overrides["conf"] = conf
35 | self.overrides["max_det"] = max_det
36 | self.overrides["verbose"] = verbose
37 | self.overrides["imgsz"] = segmentor_width_size
38 |
39 | self.overrides["conf"] = 0.25
40 | self.overrides["mode"] = "predict"
41 | self.overrides["save"] = False
42 |
43 | self.predictor = yolo.v8.segment.SegmentationPredictor(
44 | overrides=self.overrides, _callbacks=self.callbacks
45 | )
46 |
47 | self.not_setup = True
48 | self.selected_device = selected_device
49 | logging.info(f"Init CustomYOLO done!")
50 |
51 | def setup_model(self, device, verbose=False):
52 | """Initialize YOLO model with given parameters and set it to evaluation mode."""
53 | model = self.predictor.model or self.predictor.args.model
54 | self.predictor.args.half &= (
55 | device.type != "cpu"
56 | ) # half precision only supported on CUDA
57 | self.predictor.model = AutoBackend(
58 | model,
59 | device=device,
60 | dnn=self.predictor.args.dnn,
61 | data=self.predictor.args.data,
62 | fp16=self.predictor.args.half,
63 | fuse=True,
64 | verbose=verbose,
65 | )
66 | self.predictor.device = device
67 | self.predictor.model.eval()
68 | logging.info(f"Setup model at device {device} done!")
69 |
70 | def __call__(self, source=None, stream=False):
71 | return self.predictor(source=source, stream=stream)
72 |
73 |
74 | class FastSAM(object):
75 | def __init__(
76 | self,
77 | checkpoint_path: Union[str, Path],
78 | config: dict = None,
79 | segmentor_width_size=None,
80 | device=None,
81 | ):
82 | self.model = CustomYOLO(
83 | model=checkpoint_path,
84 | iou=config.iou_threshold,
85 | conf=config.conf_threshold,
86 | max_det=config.max_det,
87 | selected_device=device,
88 | segmentor_width_size=segmentor_width_size,
89 | )
90 | self.segmentor_width_size = segmentor_width_size
91 | self.current_device = device
92 | logging.info(f"Init FastSAM done!")
93 |
94 | def postprocess_resize(self, detections, orig_size, update_boxes=False):
95 | detections["masks"] = F.interpolate(
96 | detections["masks"].unsqueeze(1).float(),
97 | size=(orig_size[0], orig_size[1]),
98 | mode="bilinear",
99 | align_corners=False,
100 | )[:, 0, :, :]
101 | if update_boxes:
102 | scale = orig_size[1] / self.segmentor_width_size
103 | detections["boxes"] = detections["boxes"].float() * scale
104 | detections["boxes"][:, [0, 2]] = torch.clamp(
105 | detections["boxes"][:, [0, 2]], 0, orig_size[1] - 1
106 | )
107 | detections["boxes"][:, [1, 3]] = torch.clamp(
108 | detections["boxes"][:, [1, 3]], 0, orig_size[0] - 1
109 | )
110 | return detections
111 |
112 | @torch.no_grad()
113 | def generate_masks(self, image) -> List[Dict[str, Any]]:
114 | if self.segmentor_width_size is not None:
115 | orig_size = image.shape[:2]
116 | detections = self.model(image)
117 |
118 | masks = detections[0].masks.data
119 | boxes = detections[0].boxes.data[:, :4] # two lasts: confidence and class
120 |
121 | # define class data
122 | mask_data = {
123 | "masks": masks.to(self.current_device),
124 | "boxes": boxes.to(self.current_device),
125 | }
126 | if self.segmentor_width_size is not None:
127 | mask_data = self.postprocess_resize(mask_data, orig_size)
128 | return mask_data
129 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/model/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .dino_head import DINOHead
8 | from .mlp import Mlp
9 | from .patch_embed import PatchEmbed
10 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
11 | from .block import NestedTensorBlock
12 | from .attention import MemEffAttention
13 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/model/layers/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10 |
11 | import logging
12 |
13 | from torch import Tensor
14 | from torch import nn
15 |
16 |
17 | logger = logging.getLogger("dinov2")
18 |
19 |
20 | try:
21 | from xformers.ops import memory_efficient_attention, unbind, fmha
22 |
23 | XFORMERS_AVAILABLE = True
24 | except ImportError:
25 | logger.warning("xFormers not available")
26 | XFORMERS_AVAILABLE = False
27 |
28 |
29 | class Attention(nn.Module):
30 | def __init__(
31 | self,
32 | dim: int,
33 | num_heads: int = 8,
34 | qkv_bias: bool = False,
35 | proj_bias: bool = True,
36 | attn_drop: float = 0.0,
37 | proj_drop: float = 0.0,
38 | ) -> None:
39 | super().__init__()
40 | self.num_heads = num_heads
41 | head_dim = dim // num_heads
42 | self.scale = head_dim**-0.5
43 |
44 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
45 | self.attn_drop = nn.Dropout(attn_drop)
46 | self.proj = nn.Linear(dim, dim, bias=proj_bias)
47 | self.proj_drop = nn.Dropout(proj_drop)
48 |
49 | def forward(self, x: Tensor) -> Tensor:
50 | B, N, C = x.shape
51 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
52 |
53 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
54 | attn = q @ k.transpose(-2, -1)
55 |
56 | attn = attn.softmax(dim=-1)
57 | attn = self.attn_drop(attn)
58 |
59 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
60 | x = self.proj(x)
61 | x = self.proj_drop(x)
62 | return x
63 |
64 |
65 | class MemEffAttention(Attention):
66 | def forward(self, x: Tensor, attn_bias=None) -> Tensor:
67 | if not XFORMERS_AVAILABLE:
68 | assert attn_bias is None, "xFormers is required for nested tensors usage"
69 | return super().forward(x)
70 |
71 | B, N, C = x.shape
72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
73 |
74 | q, k, v = unbind(qkv, 2)
75 |
76 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
77 | x = x.reshape([B, N, C])
78 |
79 | x = self.proj(x)
80 | x = self.proj_drop(x)
81 | return x
82 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/model/layers/dino_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn.init import trunc_normal_
10 | from torch.nn.utils import weight_norm
11 |
12 |
13 | class DINOHead(nn.Module):
14 | def __init__(
15 | self,
16 | in_dim,
17 | out_dim,
18 | use_bn=False,
19 | nlayers=3,
20 | hidden_dim=2048,
21 | bottleneck_dim=256,
22 | mlp_bias=True,
23 | ):
24 | super().__init__()
25 | nlayers = max(nlayers, 1)
26 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
27 | self.apply(self._init_weights)
28 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
29 | self.last_layer.weight_g.data.fill_(1)
30 |
31 | def _init_weights(self, m):
32 | if isinstance(m, nn.Linear):
33 | trunc_normal_(m.weight, std=0.02)
34 | if isinstance(m, nn.Linear) and m.bias is not None:
35 | nn.init.constant_(m.bias, 0)
36 |
37 | def forward(self, x):
38 | x = self.mlp(x)
39 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12
40 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
41 | x = self.last_layer(x)
42 | return x
43 |
44 |
45 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
46 | if nlayers == 1:
47 | return nn.Linear(in_dim, bottleneck_dim, bias=bias)
48 | else:
49 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
50 | if use_bn:
51 | layers.append(nn.BatchNorm1d(hidden_dim))
52 | layers.append(nn.GELU())
53 | for _ in range(nlayers - 2):
54 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
55 | if use_bn:
56 | layers.append(nn.BatchNorm1d(hidden_dim))
57 | layers.append(nn.GELU())
58 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
59 | return nn.Sequential(*layers)
60 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/model/layers/drop_path.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
10 |
11 |
12 | from torch import nn
13 |
14 |
15 | def drop_path(x, drop_prob: float = 0.0, training: bool = False):
16 | if drop_prob == 0.0 or not training:
17 | return x
18 | keep_prob = 1 - drop_prob
19 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
20 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
21 | if keep_prob > 0.0:
22 | random_tensor.div_(keep_prob)
23 | output = x * random_tensor
24 | return output
25 |
26 |
27 | class DropPath(nn.Module):
28 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
29 |
30 | def __init__(self, drop_prob=None):
31 | super(DropPath, self).__init__()
32 | self.drop_prob = drop_prob
33 |
34 | def forward(self, x):
35 | return drop_path(x, self.drop_prob, self.training)
36 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/model/layers/layer_scale.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
8 |
9 | from typing import Union
10 |
11 | import torch
12 | from torch import Tensor
13 | from torch import nn
14 |
15 |
16 | class LayerScale(nn.Module):
17 | def __init__(
18 | self,
19 | dim: int,
20 | init_values: Union[float, Tensor] = 1e-5,
21 | inplace: bool = False,
22 | ) -> None:
23 | super().__init__()
24 | self.inplace = inplace
25 | self.gamma = nn.Parameter(init_values * torch.ones(dim))
26 |
27 | def forward(self, x: Tensor) -> Tensor:
28 | return x.mul_(self.gamma) if self.inplace else x * self.gamma
29 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/model/layers/mlp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
10 |
11 |
12 | from typing import Callable, Optional
13 |
14 | from torch import Tensor, nn
15 |
16 |
17 | class Mlp(nn.Module):
18 | def __init__(
19 | self,
20 | in_features: int,
21 | hidden_features: Optional[int] = None,
22 | out_features: Optional[int] = None,
23 | act_layer: Callable[..., nn.Module] = nn.GELU,
24 | drop: float = 0.0,
25 | bias: bool = True,
26 | ) -> None:
27 | super().__init__()
28 | out_features = out_features or in_features
29 | hidden_features = hidden_features or in_features
30 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
31 | self.act = act_layer()
32 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
33 | self.drop = nn.Dropout(drop)
34 |
35 | def forward(self, x: Tensor) -> Tensor:
36 | x = self.fc1(x)
37 | x = self.act(x)
38 | x = self.drop(x)
39 | x = self.fc2(x)
40 | x = self.drop(x)
41 | return x
42 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/model/layers/patch_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10 |
11 | from typing import Callable, Optional, Tuple, Union
12 |
13 | from torch import Tensor
14 | import torch.nn as nn
15 |
16 |
17 | def make_2tuple(x):
18 | if isinstance(x, tuple):
19 | assert len(x) == 2
20 | return x
21 |
22 | assert isinstance(x, int)
23 | return (x, x)
24 |
25 |
26 | class PatchEmbed(nn.Module):
27 | """
28 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
29 |
30 | Args:
31 | img_size: Image size.
32 | patch_size: Patch token size.
33 | in_chans: Number of input image channels.
34 | embed_dim: Number of linear projection output channels.
35 | norm_layer: Normalization layer.
36 | """
37 |
38 | def __init__(
39 | self,
40 | img_size: Union[int, Tuple[int, int]] = 224,
41 | patch_size: Union[int, Tuple[int, int]] = 16,
42 | in_chans: int = 3,
43 | embed_dim: int = 768,
44 | norm_layer: Optional[Callable] = None,
45 | flatten_embedding: bool = True,
46 | ) -> None:
47 | super().__init__()
48 |
49 | image_HW = make_2tuple(img_size)
50 | patch_HW = make_2tuple(patch_size)
51 | patch_grid_size = (
52 | image_HW[0] // patch_HW[0],
53 | image_HW[1] // patch_HW[1],
54 | )
55 |
56 | self.img_size = image_HW
57 | self.patch_size = patch_HW
58 | self.patches_resolution = patch_grid_size
59 | self.num_patches = patch_grid_size[0] * patch_grid_size[1]
60 |
61 | self.in_chans = in_chans
62 | self.embed_dim = embed_dim
63 |
64 | self.flatten_embedding = flatten_embedding
65 |
66 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
67 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
68 |
69 | def forward(self, x: Tensor) -> Tensor:
70 | _, _, H, W = x.shape
71 | patch_H, patch_W = self.patch_size
72 |
73 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
74 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
75 |
76 | x = self.proj(x) # B C H W
77 | H, W = x.size(2), x.size(3)
78 | x = x.flatten(2).transpose(1, 2) # B HW C
79 | x = self.norm(x)
80 | if not self.flatten_embedding:
81 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C
82 | return x
83 |
84 | def flops(self) -> float:
85 | Ho, Wo = self.patches_resolution
86 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
87 | if self.norm is not None:
88 | flops += Ho * Wo * self.embed_dim
89 | return flops
90 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/model/layers/swiglu_ffn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Callable, Optional
8 |
9 | from torch import Tensor, nn
10 | import torch.nn.functional as F
11 |
12 |
13 | class SwiGLUFFN(nn.Module):
14 | def __init__(
15 | self,
16 | in_features: int,
17 | hidden_features: Optional[int] = None,
18 | out_features: Optional[int] = None,
19 | act_layer: Callable[..., nn.Module] = None,
20 | drop: float = 0.0,
21 | bias: bool = True,
22 | ) -> None:
23 | super().__init__()
24 | out_features = out_features or in_features
25 | hidden_features = hidden_features or in_features
26 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
27 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
28 |
29 | def forward(self, x: Tensor) -> Tensor:
30 | x12 = self.w12(x)
31 | x1, x2 = x12.chunk(2, dim=-1)
32 | hidden = F.silu(x1) * x2
33 | return self.w3(hidden)
34 |
35 |
36 | try:
37 | from xformers.ops import SwiGLU
38 |
39 | XFORMERS_AVAILABLE = True
40 | except ImportError:
41 | SwiGLU = SwiGLUFFN
42 | XFORMERS_AVAILABLE = False
43 |
44 |
45 | class SwiGLUFFNFused(SwiGLU):
46 | def __init__(
47 | self,
48 | in_features: int,
49 | hidden_features: Optional[int] = None,
50 | out_features: Optional[int] = None,
51 | act_layer: Callable[..., nn.Module] = None,
52 | drop: float = 0.0,
53 | bias: bool = True,
54 | ) -> None:
55 | out_features = out_features or in_features
56 | hidden_features = hidden_features or in_features
57 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
58 | super().__init__(
59 | in_features=in_features,
60 | hidden_features=hidden_features,
61 | out_features=out_features,
62 | bias=bias,
63 | )
64 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/model/loss.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | from utils.poses.pose_utils import load_rotation_transform, convert_openCV_to_openGL_torch
4 | import torch.nn.functional as F
5 | from model.utils import BatchedData
6 |
7 |
8 | class Similarity(nn.Module):
9 | def __init__(self, metric="cosine", chunk_size=64):
10 | super(Similarity, self).__init__()
11 | self.metric = metric
12 | self.chunk_size = chunk_size
13 |
14 | def forward(self, query, reference):
15 | query = F.normalize(query, dim=-1)
16 | reference = F.normalize(reference, dim=-1)
17 | similarity = F.cosine_similarity(query, reference, dim=-1)
18 | return similarity.clamp(min=0.0, max=1.0)
19 |
20 |
21 | class PairwiseSimilarity(nn.Module):
22 | def __init__(self, metric="cosine", chunk_size=64):
23 | super(PairwiseSimilarity, self).__init__()
24 | self.metric = metric
25 | self.chunk_size = chunk_size
26 |
27 | def forward(self, query, reference):
28 | N_query = query.shape[0]
29 | N_objects, N_templates = reference.shape[0], reference.shape[1]
30 | references = reference.clone().unsqueeze(0).repeat(N_query, 1, 1, 1)
31 | queries = query.clone().unsqueeze(1).repeat(1, N_templates, 1)
32 | queries = F.normalize(queries, dim=-1)
33 | references = F.normalize(references, dim=-1)
34 |
35 | similarity = BatchedData(batch_size=None)
36 | for idx_obj in range(N_objects):
37 | sim = F.cosine_similarity(
38 | queries, references[:, idx_obj], dim=-1
39 | ) # N_query x N_templates
40 | similarity.append(sim)
41 | similarity.stack()
42 | similarity = similarity.data
43 | similarity = similarity.permute(1, 0, 2) # N_query x N_objects x N_templates
44 | return similarity.clamp(min=0.0, max=1.0)
45 |
46 | class MaskedPatch_MatrixSimilarity(nn.Module):
47 | def __init__(self, metric="cosine", chunk_size=64):
48 | super(MaskedPatch_MatrixSimilarity, self).__init__()
49 | self.metric = metric
50 | self.chunk_size = chunk_size
51 |
52 | def compute_straight(self, query, reference):
53 | (N_query, N_patch, N_features) = query.shape
54 | sim_matrix = torch.matmul(query, reference.permute(0, 2, 1)) # N_query x N_query_mask x N_refer_mask
55 |
56 | # N2_ref score max
57 | max_ref_patch_score = torch.max(sim_matrix, dim=-1).values
58 | # N1_query score average
59 | factor = torch.count_nonzero(query.sum(dim=-1), dim=-1) + 1e-6
60 | scores = torch.sum(max_ref_patch_score, dim=-1) / factor # N_query x N_objects x N_templates
61 |
62 | return scores.clamp(min=0.0, max=1.0)
63 |
64 | def compute_visible_ratio(self, query, reference, thred=0.5):
65 |
66 | sim_matrix = torch.matmul(query, reference.permute(0, 2, 1)) # N_query x N_query_mask x N_refer_mask
67 | sim_matrix = sim_matrix.max(1)[0] # N_query x N_refer_mask
68 | valid_patches = torch.count_nonzero(sim_matrix, dim=(1, )) + 1e-6
69 |
70 | # fliter correspendence by thred
71 | flitered_matrix = sim_matrix * (sim_matrix > thred)
72 | sim_patches = torch.count_nonzero(flitered_matrix, dim=(1,))
73 |
74 | visible_ratio = sim_patches / valid_patches
75 |
76 | return visible_ratio
77 |
78 | def compute_similarity(self, query, reference):
79 | # all template computation
80 | N_query = query.shape[0]
81 | N_objects, N_templates = reference.shape[0], reference.shape[1]
82 | references = reference.unsqueeze(0).repeat(N_query, 1, 1, 1, 1)
83 | queries = query.unsqueeze(1).repeat(1, N_templates, 1, 1)
84 |
85 | similarity = BatchedData(batch_size=None)
86 | for idx_obj in range(N_objects):
87 | sim_matrix = torch.matmul(queries, references[:, idx_obj].permute(0, 1, 3, 2)) # N_query x N_templates x N_query_mask x N_refer_mask
88 | similarity.append(sim_matrix)
89 | similarity.stack()
90 | similarity = similarity.data
91 | similarity = similarity.permute(1, 0, 2, 3, 4) # N_query x N_objects x N_templates x N1_query x N2_ref
92 |
93 | # N2_ref score max
94 | max_ref_patch_score = torch.max(similarity, dim=-1).values
95 | # N1_query score average
96 | factor = torch.count_nonzero(query.sum(dim=-1), dim=-1)[:, None, None]
97 | scores = torch.sum(max_ref_patch_score, dim=-1) / factor # N_query x N_objects x N_templates
98 |
99 | return scores.clamp(min=0.0, max=1.0)
100 |
101 | def forward_by_chunk(self, query, reference):
102 | # divide by N_query
103 | batch_query = BatchedData(batch_size=self.chunk_size, data=query)
104 | del query
105 | scores = BatchedData(batch_size=self.chunk_size)
106 | for idx_batch in range(len(batch_query)):
107 | score = self.compute_similarity(batch_query[idx_batch], reference)
108 | scores.cat(score)
109 | return scores.data
110 |
111 | def forward(self, qurey, reference):
112 | if qurey.shape[0] > self.chunk_size:
113 | scores = self.forward_by_chunk(qurey, reference)
114 | else:
115 | scores = self.compute_similarity(qurey, reference)
116 | return scores
117 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/model/sam.py:
--------------------------------------------------------------------------------
1 | from segment_anything import (
2 | sam_model_registry,
3 | SamPredictor,
4 | SamAutomaticMaskGenerator,
5 | )
6 | from segment_anything.modeling import Sam
7 | from segment_anything.utils.amg import MaskData, generate_crop_boxes, rle_to_mask
8 | import logging
9 | import numpy as np
10 | import torch
11 | from torchvision.ops.boxes import batched_nms, box_area # type: ignore
12 | import os.path as osp
13 | from typing import Any, Dict, List, Optional, Tuple
14 | import cv2
15 | import torch.nn.functional as F
16 |
17 | pretrained_weight_dict = {
18 | "vit_l": "sam_vit_l_0b3195.pth", # 1250MB
19 | "vit_b": "sam_vit_b_01ec64.pth", # 375MB
20 | "vit_h": "sam_vit_h_4b8939.pth", # 2500MB
21 | }
22 |
23 |
24 | def load_sam(model_type, checkpoint_dir):
25 | logging.info(f"Loading SAM model from {checkpoint_dir}")
26 | sam = sam_model_registry[model_type](
27 | checkpoint=osp.join(checkpoint_dir, pretrained_weight_dict[model_type])
28 | )
29 | return sam
30 |
31 |
32 | def load_sam_predictor(model_type, checkpoint_dir, device):
33 | logging.info(f"Loading SAM model from {checkpoint_dir}")
34 | sam = sam_model_registry[model_type](
35 | checkpoint=osp.join(checkpoint_dir, pretrained_weight_dict[model_type])
36 | )
37 | sam.to(device=device)
38 | predictor = SamPredictor(sam)
39 | return predictor
40 |
41 |
42 | def load_sam_mask_generator(model_type, checkpoint_dir, device):
43 | logging.info(f"Loading SAM model from {checkpoint_dir}")
44 | sam = sam_model_registry[model_type](
45 | checkpoint=osp.join(checkpoint_dir, pretrained_weight_dict[model_type])
46 | )
47 | sam.to(device=device)
48 | mask_generator = SamAutomaticMaskGenerator(sam, output_mode="coco_rle")
49 | return mask_generator
50 |
51 |
52 | class CustomSamAutomaticMaskGenerator(SamAutomaticMaskGenerator):
53 | def __init__(
54 | self,
55 | sam: Sam,
56 | min_mask_region_area: int = 0,
57 | points_per_batch: int = 64,
58 | stability_score_thresh: float = 0.85,
59 | box_nms_thresh: float = 0.7,
60 | crop_overlap_ratio: float = 512 / 1500,
61 | segmentor_width_size=None,
62 | pred_iou_thresh: float = 0.88,
63 | ):
64 | SamAutomaticMaskGenerator.__init__(
65 | self,
66 | sam,
67 | min_mask_region_area=min_mask_region_area,
68 | points_per_batch=points_per_batch,
69 | stability_score_thresh=stability_score_thresh,
70 | box_nms_thresh=box_nms_thresh,
71 | crop_overlap_ratio=crop_overlap_ratio,
72 | pred_iou_thresh=pred_iou_thresh
73 | )
74 | self.segmentor_width_size = segmentor_width_size
75 | logging.info(f"Init CustomSamAutomaticMaskGenerator done!")
76 |
77 | def preprocess_resize(self, image: np.ndarray):
78 | orig_size = image.shape[:2]
79 | height_size = int(self.segmentor_width_size * orig_size[0] / orig_size[1])
80 | resized_image = cv2.resize(
81 | image.copy(), (self.segmentor_width_size, height_size) # (width, height)
82 | )
83 | return resized_image
84 |
85 | def postprocess_resize(self, detections, orig_size):
86 | detections["masks"] = F.interpolate(
87 | detections["masks"].unsqueeze(1).float(),
88 | size=(orig_size[0], orig_size[1]),
89 | mode="bilinear",
90 | align_corners=False,
91 | )[:, 0, :, :]
92 | scale = orig_size[1] / self.segmentor_width_size
93 | detections["boxes"] = detections["boxes"].float() * scale
94 | detections["boxes"][:, [0, 2]] = torch.clamp(
95 | detections["boxes"][:, [0, 2]], 0, orig_size[1] - 1
96 | )
97 | detections["boxes"][:, [1, 3]] = torch.clamp(
98 | detections["boxes"][:, [1, 3]], 0, orig_size[0] - 1
99 | )
100 | return detections
101 |
102 | @torch.no_grad()
103 | def generate_masks(self, image: np.ndarray) -> List[Dict[str, Any]]:
104 | if self.segmentor_width_size is not None:
105 | orig_size = image.shape[:2]
106 | image = self.preprocess_resize(image)
107 | # Generate masks
108 | mask_data = self._generate_masks(image)
109 |
110 | # Filter small disconnected regions and holes in masks
111 | if self.min_mask_region_area > 0:
112 | mask_data = self.postprocess_small_regions(
113 | mask_data,
114 | self.min_mask_region_area,
115 | max(self.box_nms_thresh, self.crop_nms_thresh),
116 | )
117 | if self.segmentor_width_size is not None:
118 | mask_data = self.postprocess_resize(mask_data, orig_size)
119 | return mask_data
120 |
121 | def _generate_masks(self, image: np.ndarray) -> MaskData:
122 | orig_size = image.shape[:2]
123 | crop_boxes, layer_idxs = generate_crop_boxes(
124 | orig_size, self.crop_n_layers, self.crop_overlap_ratio
125 | )
126 |
127 | # Iterate over image crops
128 | data = MaskData()
129 | for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
130 | crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
131 | data.cat(crop_data)
132 |
133 | # Remove duplicate masks between crops
134 | if len(crop_boxes) > 1:
135 | # Prefer masks from smaller crops
136 | scores = 1 / box_area(data["crop_boxes"])
137 | scores = scores.to(data["boxes"].device)
138 | keep_by_nms = batched_nms(
139 | data["boxes"].float(),
140 | scores,
141 | torch.zeros_like(data["boxes"][:, 0]), # categories
142 | iou_threshold=self.crop_nms_thresh,
143 | )
144 | data.filter(keep_by_nms)
145 |
146 | data["masks"] = [torch.from_numpy(rle_to_mask(rle)) for rle in data["rles"]]
147 | data["masks"] = torch.stack(data["masks"])
148 | return {"masks": data["masks"].to(data["boxes"].device), "boxes": data["boxes"]}
149 |
150 | def remove_small_detections(self, mask_data: MaskData, img_size: List) -> MaskData:
151 | # calculate area and number of pixels in each mask
152 | area = box_area(mask_data["boxes"]) / (img_size[0] * img_size[1])
153 | idx_selected = area >= self.mask_post_processing.min_box_size
154 | mask_data.filter(idx_selected)
155 | return mask_data
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/model/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torchvision
4 | from torchvision.ops.boxes import batched_nms, box_area
5 | import logging
6 | from utils.inout import save_json, load_json, save_npz
7 | from utils.bbox_utils import xyxy_to_xywh, xywh_to_xyxy, force_binary_mask
8 | import time
9 | from PIL import Image
10 |
11 | lmo_object_ids = np.array(
12 | [
13 | 1,
14 | 5,
15 | 6,
16 | 8,
17 | 9,
18 | 10,
19 | 11,
20 | 12,
21 | ]
22 | ) # object ID of occlusionLINEMOD is different
23 |
24 |
25 | def mask_to_rle(binary_mask):
26 | rle = {"counts": [], "size": list(binary_mask.shape)}
27 | counts = rle.get("counts")
28 |
29 | last_elem = 0
30 | running_length = 0
31 |
32 | for i, elem in enumerate(binary_mask.ravel(order="F")):
33 | if elem == last_elem:
34 | pass
35 | else:
36 | counts.append(running_length)
37 | running_length = 0
38 | last_elem = elem
39 | running_length += 1
40 |
41 | counts.append(running_length)
42 |
43 | return rle
44 |
45 |
46 | class BatchedData:
47 | """
48 | A structure for storing data in batched format.
49 | Implements basic filtering and concatenation.
50 | """
51 |
52 | def __init__(self, batch_size, data=None, **kwargs) -> None:
53 | self.batch_size = batch_size
54 | if data is not None:
55 | self.data = data
56 | else:
57 | self.data = []
58 |
59 | def __len__(self):
60 | assert self.batch_size is not None, "batch_size is not defined"
61 | return np.ceil(len(self.data) / self.batch_size).astype(int)
62 |
63 | def __getitem__(self, idx):
64 | assert self.batch_size is not None, "batch_size is not defined"
65 | return self.data[idx * self.batch_size : (idx + 1) * self.batch_size]
66 |
67 | def cat(self, data, dim=0):
68 | if len(self.data) == 0:
69 | self.data = data
70 | else:
71 | self.data = torch.cat([self.data, data], dim=dim)
72 |
73 | def append(self, data):
74 | self.data.append(data)
75 |
76 | def stack(self, dim=0):
77 | self.data = torch.stack(self.data, dim=dim)
78 |
79 |
80 | class Detections:
81 | """
82 | A structure for storing detections.
83 | """
84 |
85 | def __init__(self, data) -> None:
86 | if isinstance(data, str):
87 | data = self.load_from_file(data)
88 | for key, value in data.items():
89 | setattr(self, key, value)
90 | self.keys = list(data.keys())
91 | if "boxes" in self.keys:
92 | if isinstance(self.boxes, np.ndarray):
93 | self.to_torch()
94 | self.boxes = self.boxes.long()
95 |
96 | def remove_very_small_detections(self, config):
97 | img_area = self.masks.shape[1] * self.masks.shape[2]
98 | box_areas = box_area(self.boxes) / img_area
99 | mask_areas = self.masks.sum(dim=(1, 2)) / img_area
100 | keep_idxs = torch.logical_and(
101 | box_areas > config.min_box_size**2, mask_areas > config.min_mask_size
102 | )
103 | # logging.info(f"Removing {len(keep_idxs) - keep_idxs.sum()} detections")
104 | for key in self.keys:
105 | setattr(self, key, getattr(self, key)[keep_idxs])
106 |
107 | def apply_nms_per_object_id(self, nms_thresh=0.5):
108 | keep_idxs = BatchedData(None)
109 | all_indexes = torch.arange(len(self.object_ids), device=self.boxes.device)
110 | for object_id in torch.unique(self.object_ids):
111 | idx = self.object_ids == object_id
112 | idx_object_id = all_indexes[idx]
113 | keep_idx = torchvision.ops.nms(
114 | self.boxes[idx].float(), self.scores[idx].float(), nms_thresh
115 | )
116 | keep_idxs.cat(idx_object_id[keep_idx])
117 | keep_idxs = keep_idxs.data
118 | for key in self.keys:
119 | setattr(self, key, getattr(self, key)[keep_idxs])
120 |
121 | def apply_nms(self, nms_thresh=0.5):
122 | keep_idx = torchvision.ops.nms(
123 | self.boxes.float(), self.scores.float(), nms_thresh
124 | )
125 | for key in self.keys:
126 | setattr(self, key, getattr(self, key)[keep_idx])
127 |
128 | def add_attribute(self, key, value):
129 | setattr(self, key, value)
130 | self.keys.append(key)
131 |
132 | def __len__(self):
133 | return len(self.boxes)
134 |
135 | def check_size(self):
136 | mask_size = len(self.masks)
137 | box_size = len(self.boxes)
138 | score_size = len(self.scores)
139 | object_id_size = len(self.object_ids)
140 | assert (
141 | mask_size == box_size == score_size == object_id_size
142 | ), f"Size mismatch {mask_size} {box_size} {score_size} {object_id_size}"
143 |
144 | def to_numpy(self):
145 | for key in self.keys:
146 | setattr(self, key, getattr(self, key).cpu().numpy())
147 |
148 | def to_torch(self):
149 | for key in self.keys:
150 | a = getattr(self, key)
151 | setattr(self, key, torch.from_numpy(getattr(self, key)))
152 |
153 | def save_to_file(
154 | self, scene_id, frame_id, runtime, file_path, dataset_name, return_results=False
155 | ):
156 | """
157 | scene_id, image_id, category_id, bbox, time
158 | """
159 | boxes = xyxy_to_xywh(self.boxes)
160 | results = {
161 | "scene_id": scene_id,
162 | "image_id": frame_id,
163 | "category_id": self.object_ids + 1
164 | if dataset_name != "lmo"
165 | else lmo_object_ids[self.object_ids],
166 | "score": self.scores,
167 | "bbox": boxes,
168 | "time": runtime,
169 | "segmentation": self.masks,
170 | }
171 | save_npz(file_path, results)
172 | if return_results:
173 | return results
174 |
175 | def load_from_file(self, file_path):
176 | data = np.load(file_path)
177 | masks = data["segmentation"]
178 | boxes = xywh_to_xyxy(np.array(data["bbox"]))
179 | data = {
180 | "object_ids": data["category_id"] - 1,
181 | "bbox": boxes,
182 | "scores": data["score"],
183 | "masks": masks,
184 | }
185 | logging.info(f"Loaded {file_path}")
186 | return data
187 |
188 | def filter(self, idxs):
189 | for key in self.keys:
190 | setattr(self, key, getattr(self, key)[idxs])
191 |
192 | def clone(self):
193 | """
194 | Clone the current object
195 | """
196 | return Detections(self.__dict__.copy())
197 |
198 |
199 | def convert_npz_to_json(idx, list_npz_paths):
200 | npz_path = list_npz_paths[idx]
201 | detections = np.load(npz_path)
202 | results = []
203 | for idx_det in range(len(detections["bbox"])):
204 | result = {
205 | "scene_id": int(detections["scene_id"]),
206 | "image_id": int(detections["image_id"]),
207 | "category_id": int(detections["category_id"][idx_det]),
208 | "bbox": detections["bbox"][idx_det].tolist(),
209 | "score": float(detections["score"][idx_det]),
210 | "time": float(detections["time"]),
211 | "segmentation": mask_to_rle(
212 | force_binary_mask(detections["segmentation"][idx_det])
213 | ),
214 | }
215 | results.append(result)
216 | return results
217 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/provider/bop.py:
--------------------------------------------------------------------------------
1 | import logging, os
2 | import os.path as osp
3 | from tqdm import tqdm
4 | import time
5 | import numpy as np
6 | import torchvision.transforms as T
7 | from pathlib import Path
8 | from PIL import Image
9 | from torch.utils.data import Dataset
10 | import os.path as osp
11 | from utils.poses.pose_utils import load_index_level_in_level2
12 | import torch
13 | from utils.bbox_utils import CropResizePad
14 | import pytorch_lightning as pl
15 | from provider.base_bop import BaseBOP
16 | import imageio.v2 as imageio
17 | from utils.inout import load_json
18 |
19 | pl.seed_everything(2023)
20 |
21 |
22 | class BOPTemplate(Dataset):
23 | def __init__(
24 | self,
25 | template_dir,
26 | obj_ids,
27 | processing_config,
28 | level_templates,
29 | pose_distribution,
30 | **kwargs,
31 | ):
32 | self.template_dir = template_dir
33 | if obj_ids is None:
34 | obj_ids = [
35 | int(obj_id[4:])
36 | for obj_id in os.listdir(template_dir)
37 | if osp.isdir(osp.join(template_dir, obj_id))
38 | ]
39 | obj_ids = sorted(obj_ids)
40 | logging.info(f"Found {obj_ids} objects in {self.template_dir}")
41 | self.obj_ids = obj_ids
42 | self.processing_config = processing_config
43 | self.rgb_transform = T.Compose(
44 | [
45 | T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
46 | ]
47 | )
48 | self.proposal_processor = CropResizePad(self.processing_config.image_size)
49 | self.load_template_poses(level_templates, pose_distribution)
50 |
51 | def __len__(self):
52 | return len(self.obj_ids)
53 |
54 | def load_template_poses(self, level_templates, pose_distribution):
55 | if pose_distribution == "all":
56 | self.index_templates = load_index_level_in_level2(level_templates, "all")
57 | else:
58 | raise NotImplementedError
59 |
60 | def __getitem__(self, idx):
61 | templates, masks, boxes = [], [], []
62 | for id_template in self.index_templates:
63 | image = Image.open(
64 | f"{self.template_dir}/obj_{self.obj_ids[idx]:06d}/{id_template:06d}.png"
65 | )
66 | boxes.append(image.getbbox())
67 |
68 | mask = image.getchannel("A")
69 | mask = torch.from_numpy(np.array(mask) / 255).float()
70 | masks.append(mask.unsqueeze(-1))
71 |
72 | image = torch.from_numpy(np.array(image.convert("RGB")) / 255).float()
73 | templates.append(image)
74 |
75 | templates = torch.stack(templates).permute(0, 3, 1, 2)
76 | masks = torch.stack(masks).permute(0, 3, 1, 2)
77 | boxes = torch.tensor(np.array(boxes))
78 | templates_croped = self.proposal_processor(images=templates, boxes=boxes)
79 | masks_cropped = self.proposal_processor(images=masks, boxes=boxes)
80 | return {
81 | "templates": self.rgb_transform(templates_croped),
82 | "template_masks": masks_cropped[:, 0, :, :],
83 | }
84 |
85 |
86 | class BaseBOPTest(BaseBOP):
87 | def __init__(
88 | self,
89 | root_dir,
90 | split,
91 | **kwargs,
92 | ):
93 | self.root_dir = root_dir
94 | self.split = split
95 | self.load_list_scene(split=split)
96 | self.load_metaData(reset_metaData=True)
97 | # shuffle metadata
98 | self.metaData = self.metaData.sample(frac=1, random_state=2021).reset_index()
99 | self.camemra_params = {}
100 | for scenes in self.list_scenes:
101 | scenes_id = scenes.split("/")[-1]
102 | self.camemra_params.setdefault(scenes_id, load_json(osp.join(scenes, "scene_camera.json")))
103 | self.rgb_transform = T.Compose(
104 | [
105 | T.ToTensor(),
106 | T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
107 | ]
108 | )
109 |
110 | def load_depth_img(self, idx):
111 | depth_path = self.metaData.iloc[idx]["depth_path"]
112 | if depth_path is None:
113 | if self.root_dir.split("/")[-1] == "itodd":
114 | depth_path = self.metaData.iloc[idx]["rgb_path"].replace("gray", "depth")
115 | else:
116 | depth_path = self.metaData.iloc[idx]["rgb_path"].replace("rgb", "depth")
117 | depth = np.array(imageio.imread(depth_path))
118 | return depth
119 |
120 | def __getitem__(self, idx):
121 | rgb_path = self.metaData.iloc[idx]["rgb_path"]
122 | scene_id = self.metaData.iloc[idx]["scene_id"]
123 | frame_id = self.metaData.iloc[idx]["frame_id"]
124 | cam_intrinsic = self.metaData.iloc[idx]["intrinsic"]
125 | image = Image.open(rgb_path)
126 | image = self.rgb_transform(image.convert("RGB"))
127 | depth = self.load_depth_img(idx)
128 | cam_intrinsic = np.array(cam_intrinsic).reshape((3, 3))
129 | depth_scale = self.camemra_params[scene_id][f"{frame_id}"]["depth_scale"]
130 |
131 | return dict(
132 | image=image,
133 | scene_id=scene_id,
134 | frame_id=frame_id,
135 | depth=depth.astype(np.int32),
136 | cam_intrinsic=cam_intrinsic,
137 | depth_scale=depth_scale,
138 | )
139 |
140 | if __name__ == "__main__":
141 | logging.basicConfig(level=logging.INFO)
142 | from omegaconf import DictConfig, OmegaConf
143 | from torchvision.utils import make_grid, save_image
144 |
145 | processing_config = OmegaConf.create(
146 | {
147 | "image_size": 224,
148 | }
149 | )
150 | inv_rgb_transform = T.Compose(
151 | [
152 | T.Normalize(
153 | mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
154 | std=[1 / 0.229, 1 / 0.224, 1 / 0.225],
155 | ),
156 | ]
157 | )
158 | dataset = BOPTemplate(
159 | template_dir="/home/nguyen/Documents/datasets/bop23/datasets/templates_pyrender/lmo",
160 | obj_ids=None,
161 | level_templates=0,
162 | pose_distribution="all",
163 | processing_config=processing_config,
164 | )
165 | for idx in tqdm(range(len(dataset))):
166 | sample = dataset[idx]
167 | sample["templates"] = inv_rgb_transform(sample["templates"])
168 | save_image(sample["templates"], f"./tmp/lm_{idx}.png", nrow=7)
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/run_inference.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import hydra
4 | from omegaconf import DictConfig, OmegaConf
5 | from hydra.utils import instantiate
6 | from torch.utils.data import DataLoader
7 |
8 |
9 | @hydra.main(version_base=None, config_path="configs", config_name="run_inference")
10 | def run_inference(cfg: DictConfig):
11 | OmegaConf.set_struct(cfg, False)
12 | hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
13 | output_path = hydra_cfg["runtime"]["output_dir"]
14 | logging.info(
15 | f"Training script. The outputs of hydra will be stored in: {output_path}"
16 | )
17 | logging.info("Initializing logger, callbacks and trainer")
18 |
19 | if cfg.machine.name == "slurm":
20 | num_gpus = int(os.environ["SLURM_GPUS_ON_NODE"])
21 | num_nodes = int(os.environ["SLURM_NNODES"])
22 | cfg.machine.trainer.devices = num_gpus
23 | cfg.machine.trainer.num_nodes = num_nodes
24 | logging.info(f"Slurm config: {num_gpus} gpus, {num_nodes} nodes")
25 | trainer = instantiate(cfg.machine.trainer)
26 |
27 | default_ref_dataloader_config = cfg.data.reference_dataloader
28 | default_query_dataloader_config = cfg.data.query_dataloader
29 |
30 | query_dataloader_config = default_query_dataloader_config.copy()
31 | ref_dataloader_config = default_ref_dataloader_config.copy()
32 |
33 | if cfg.dataset_name in ["hb", "tless"]:
34 | query_dataloader_config.split = "test_primesense"
35 | else:
36 | query_dataloader_config.split = "test"
37 | query_dataloader_config.root_dir += f"{cfg.dataset_name}"
38 | query_dataset = instantiate(query_dataloader_config)
39 |
40 | logging.info("Initializing model")
41 | model = instantiate(cfg.model)
42 |
43 | model.ref_obj_names = cfg.data.datasets[cfg.dataset_name].obj_names
44 | model.dataset_name = cfg.dataset_name
45 |
46 | query_dataloader = DataLoader(
47 | query_dataset,
48 | batch_size=1, # only support a single image for now
49 | num_workers=cfg.machine.num_workers,
50 | shuffle=False,
51 | )
52 | if cfg.model.onboarding_config.rendering_type == "pyrender":
53 | ref_dataloader_config.template_dir += f"templates_pyrender/{cfg.dataset_name}"
54 | ref_dataset = instantiate(ref_dataloader_config)
55 | elif cfg.model.onboarding_config.rendering_type == "pbr":
56 | logging.info("Using BlenderProc for reference images")
57 | ref_dataloader_config._target_ = "provider.bop_pbr.BOPTemplatePBR"
58 | ref_dataloader_config.root_dir = f"{query_dataloader_config.root_dir}"
59 | ref_dataloader_config.template_dir += f"templates_pyrender/{cfg.dataset_name}"
60 | if not os.path.exists(ref_dataloader_config.template_dir):
61 | os.makedirs(ref_dataloader_config.template_dir)
62 | ref_dataset = instantiate(ref_dataloader_config)
63 | ref_dataset.load_processed_metaData(reset_metaData=True)
64 | else:
65 | raise NotImplementedError
66 | model.ref_dataset = ref_dataset
67 |
68 | segmentation_name = cfg.model.segmentor_model._target_.split(".")[-1]
69 | agg_function = cfg.model.matching_config.aggregation_function
70 | rendering_type = cfg.model.onboarding_config.rendering_type
71 | level_template = cfg.model.onboarding_config.level_templates
72 | model.name_prediction_file = f"result_{cfg.dataset_name}"
73 | logging.info(f"Loading dataloader for {cfg.dataset_name} done!")
74 | trainer.test(
75 | model,
76 | dataloaders=query_dataloader,
77 | )
78 | logging.info(f"---" * 20)
79 |
80 |
81 | if __name__ == "__main__":
82 | logging.basicConfig(level=logging.INFO)
83 | run_inference()
84 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/segment_anything/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .build_sam import (
8 | build_sam,
9 | build_sam_vit_h,
10 | build_sam_vit_l,
11 | build_sam_vit_b,
12 | sam_model_registry,
13 | )
14 | from .predictor import SamPredictor
15 | from .automatic_mask_generator import SamAutomaticMaskGenerator
16 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/segment_anything/build_sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 |
9 | from functools import partial
10 |
11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
12 |
13 |
14 | def build_sam_vit_h(checkpoint=None):
15 | return _build_sam(
16 | encoder_embed_dim=1280,
17 | encoder_depth=32,
18 | encoder_num_heads=16,
19 | encoder_global_attn_indexes=[7, 15, 23, 31],
20 | checkpoint=checkpoint,
21 | )
22 |
23 |
24 | build_sam = build_sam_vit_h
25 |
26 |
27 | def build_sam_vit_l(checkpoint=None):
28 | return _build_sam(
29 | encoder_embed_dim=1024,
30 | encoder_depth=24,
31 | encoder_num_heads=16,
32 | encoder_global_attn_indexes=[5, 11, 17, 23],
33 | checkpoint=checkpoint,
34 | )
35 |
36 |
37 | def build_sam_vit_b(checkpoint=None):
38 | return _build_sam(
39 | encoder_embed_dim=768,
40 | encoder_depth=12,
41 | encoder_num_heads=12,
42 | encoder_global_attn_indexes=[2, 5, 8, 11],
43 | checkpoint=checkpoint,
44 | )
45 |
46 |
47 | sam_model_registry = {
48 | "default": build_sam_vit_h,
49 | "vit_h": build_sam_vit_h,
50 | "vit_l": build_sam_vit_l,
51 | "vit_b": build_sam_vit_b,
52 | }
53 |
54 |
55 | def _build_sam(
56 | encoder_embed_dim,
57 | encoder_depth,
58 | encoder_num_heads,
59 | encoder_global_attn_indexes,
60 | checkpoint=None,
61 | ):
62 | prompt_embed_dim = 256
63 | image_size = 1024
64 | vit_patch_size = 16
65 | image_embedding_size = image_size // vit_patch_size
66 | sam = Sam(
67 | image_encoder=ImageEncoderViT(
68 | depth=encoder_depth,
69 | embed_dim=encoder_embed_dim,
70 | img_size=image_size,
71 | mlp_ratio=4,
72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
73 | num_heads=encoder_num_heads,
74 | patch_size=vit_patch_size,
75 | qkv_bias=True,
76 | use_rel_pos=True,
77 | global_attn_indexes=encoder_global_attn_indexes,
78 | window_size=14,
79 | out_chans=prompt_embed_dim,
80 | ),
81 | prompt_encoder=PromptEncoder(
82 | embed_dim=prompt_embed_dim,
83 | image_embedding_size=(image_embedding_size, image_embedding_size),
84 | input_image_size=(image_size, image_size),
85 | mask_in_chans=16,
86 | ),
87 | mask_decoder=MaskDecoder(
88 | num_multimask_outputs=3,
89 | transformer=TwoWayTransformer(
90 | depth=2,
91 | embedding_dim=prompt_embed_dim,
92 | mlp_dim=2048,
93 | num_heads=8,
94 | ),
95 | transformer_dim=prompt_embed_dim,
96 | iou_head_depth=3,
97 | iou_head_hidden_dim=256,
98 | ),
99 | pixel_mean=[123.675, 116.28, 103.53],
100 | pixel_std=[58.395, 57.12, 57.375],
101 | )
102 | sam.eval()
103 | if checkpoint is not None:
104 | with open(checkpoint, "rb") as f:
105 | state_dict = torch.load(f)
106 | sam.load_state_dict(state_dict)
107 | return sam
108 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/segment_anything/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .sam import Sam
8 | from .image_encoder import ImageEncoderViT
9 | from .mask_decoder import MaskDecoder
10 | from .prompt_encoder import PromptEncoder
11 | from .transformer import TwoWayTransformer
12 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/segment_anything/modeling/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from typing import Type
11 |
12 |
13 | class MLPBlock(nn.Module):
14 | def __init__(
15 | self,
16 | embedding_dim: int,
17 | mlp_dim: int,
18 | act: Type[nn.Module] = nn.GELU,
19 | ) -> None:
20 | super().__init__()
21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim)
22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim)
23 | self.act = act()
24 |
25 | def forward(self, x: torch.Tensor) -> torch.Tensor:
26 | return self.lin2(self.act(self.lin1(x)))
27 |
28 |
29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
31 | class LayerNorm2d(nn.Module):
32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
33 | super().__init__()
34 | self.weight = nn.Parameter(torch.ones(num_channels))
35 | self.bias = nn.Parameter(torch.zeros(num_channels))
36 | self.eps = eps
37 |
38 | def forward(self, x: torch.Tensor) -> torch.Tensor:
39 | u = x.mean(1, keepdim=True)
40 | s = (x - u).pow(2).mean(1, keepdim=True)
41 | x = (x - u) / torch.sqrt(s + self.eps)
42 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
43 | return x
44 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/segment_anything/modeling/mask_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn import functional as F
10 |
11 | from typing import List, Tuple, Type
12 |
13 | from .common import LayerNorm2d
14 |
15 |
16 | class MaskDecoder(nn.Module):
17 | def __init__(
18 | self,
19 | *,
20 | transformer_dim: int,
21 | transformer: nn.Module,
22 | num_multimask_outputs: int = 3,
23 | activation: Type[nn.Module] = nn.GELU,
24 | iou_head_depth: int = 3,
25 | iou_head_hidden_dim: int = 256,
26 | ) -> None:
27 | """
28 | Predicts masks given an image and prompt embeddings, using a
29 | transformer architecture.
30 |
31 | Arguments:
32 | transformer_dim (int): the channel dimension of the transformer
33 | transformer (nn.Module): the transformer used to predict masks
34 | num_multimask_outputs (int): the number of masks to predict
35 | when disambiguating masks
36 | activation (nn.Module): the type of activation to use when
37 | upscaling masks
38 | iou_head_depth (int): the depth of the MLP used to predict
39 | mask quality
40 | iou_head_hidden_dim (int): the hidden dimension of the MLP
41 | used to predict mask quality
42 | """
43 | super().__init__()
44 | self.transformer_dim = transformer_dim
45 | self.transformer = transformer
46 |
47 | self.num_multimask_outputs = num_multimask_outputs
48 |
49 | self.iou_token = nn.Embedding(1, transformer_dim)
50 | self.num_mask_tokens = num_multimask_outputs + 1
51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
52 |
53 | self.output_upscaling = nn.Sequential(
54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
55 | LayerNorm2d(transformer_dim // 4),
56 | activation(),
57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
58 | activation(),
59 | )
60 | self.output_hypernetworks_mlps = nn.ModuleList(
61 | [
62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
63 | for i in range(self.num_mask_tokens)
64 | ]
65 | )
66 |
67 | self.iou_prediction_head = MLP(
68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
69 | )
70 |
71 | def forward(
72 | self,
73 | image_embeddings: torch.Tensor,
74 | image_pe: torch.Tensor,
75 | sparse_prompt_embeddings: torch.Tensor,
76 | dense_prompt_embeddings: torch.Tensor,
77 | multimask_output: bool,
78 | ) -> Tuple[torch.Tensor, torch.Tensor]:
79 | """
80 | Predict masks given image and prompt embeddings.
81 |
82 | Arguments:
83 | image_embeddings (torch.Tensor): the embeddings from the image encoder
84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
87 | multimask_output (bool): Whether to return multiple masks or a single
88 | mask.
89 |
90 | Returns:
91 | torch.Tensor: batched predicted masks
92 | torch.Tensor: batched predictions of mask quality
93 | """
94 | masks, iou_pred = self.predict_masks(
95 | image_embeddings=image_embeddings,
96 | image_pe=image_pe,
97 | sparse_prompt_embeddings=sparse_prompt_embeddings,
98 | dense_prompt_embeddings=dense_prompt_embeddings,
99 | )
100 |
101 | # Select the correct mask or masks for output
102 | if multimask_output:
103 | mask_slice = slice(1, None)
104 | else:
105 | mask_slice = slice(0, 1)
106 | masks = masks[:, mask_slice, :, :]
107 | iou_pred = iou_pred[:, mask_slice]
108 |
109 | # Prepare output
110 | return masks, iou_pred
111 |
112 | def predict_masks(
113 | self,
114 | image_embeddings: torch.Tensor,
115 | image_pe: torch.Tensor,
116 | sparse_prompt_embeddings: torch.Tensor,
117 | dense_prompt_embeddings: torch.Tensor,
118 | ) -> Tuple[torch.Tensor, torch.Tensor]:
119 | """Predicts masks. See 'forward' for more details."""
120 | # Concatenate output tokens
121 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
122 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
123 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
124 |
125 | # Expand per-image data in batch direction to be per-mask
126 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
127 | src = src + dense_prompt_embeddings
128 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
129 | b, c, h, w = src.shape
130 |
131 | # Run the transformer
132 | hs, src = self.transformer(src, pos_src, tokens)
133 | iou_token_out = hs[:, 0, :]
134 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
135 |
136 | # Upscale mask embeddings and predict masks using the mask tokens
137 | src = src.transpose(1, 2).view(b, c, h, w)
138 | upscaled_embedding = self.output_upscaling(src)
139 | hyper_in_list: List[torch.Tensor] = []
140 | for i in range(self.num_mask_tokens):
141 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
142 | hyper_in = torch.stack(hyper_in_list, dim=1)
143 | b, c, h, w = upscaled_embedding.shape
144 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
145 |
146 | # Generate mask quality predictions
147 | iou_pred = self.iou_prediction_head(iou_token_out)
148 |
149 | return masks, iou_pred
150 |
151 |
152 | # Lightly adapted from
153 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
154 | class MLP(nn.Module):
155 | def __init__(
156 | self,
157 | input_dim: int,
158 | hidden_dim: int,
159 | output_dim: int,
160 | num_layers: int,
161 | sigmoid_output: bool = False,
162 | ) -> None:
163 | super().__init__()
164 | self.num_layers = num_layers
165 | h = [hidden_dim] * (num_layers - 1)
166 | self.layers = nn.ModuleList(
167 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
168 | )
169 | self.sigmoid_output = sigmoid_output
170 |
171 | def forward(self, x):
172 | for i, layer in enumerate(self.layers):
173 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
174 | if self.sigmoid_output:
175 | x = F.sigmoid(x)
176 | return x
177 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/segment_anything/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/segment_anything/utils/onnx.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn import functional as F
10 |
11 | from typing import Tuple
12 |
13 | from ..modeling import Sam
14 | from .amg import calculate_stability_score
15 |
16 |
17 | class SamOnnxModel(nn.Module):
18 | """
19 | This model should not be called directly, but is used in ONNX export.
20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
21 | with some functions modified to enable model tracing. Also supports extra
22 | options controlling what information. See the ONNX export script for details.
23 | """
24 |
25 | def __init__(
26 | self,
27 | model: Sam,
28 | return_single_mask: bool,
29 | use_stability_score: bool = False,
30 | return_extra_metrics: bool = False,
31 | ) -> None:
32 | super().__init__()
33 | self.mask_decoder = model.mask_decoder
34 | self.model = model
35 | self.img_size = model.image_encoder.img_size
36 | self.return_single_mask = return_single_mask
37 | self.use_stability_score = use_stability_score
38 | self.stability_score_offset = 1.0
39 | self.return_extra_metrics = return_extra_metrics
40 |
41 | @staticmethod
42 | def resize_longest_image_size(
43 | input_image_size: torch.Tensor, longest_side: int
44 | ) -> torch.Tensor:
45 | input_image_size = input_image_size.to(torch.float32)
46 | scale = longest_side / torch.max(input_image_size)
47 | transformed_size = scale * input_image_size
48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
49 | return transformed_size
50 |
51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
52 | point_coords = point_coords + 0.5
53 | point_coords = point_coords / self.img_size
54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
56 |
57 | point_embedding = point_embedding * (point_labels != -1)
58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
59 | point_labels == -1
60 | )
61 |
62 | for i in range(self.model.prompt_encoder.num_point_embeddings):
63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
64 | i
65 | ].weight * (point_labels == i)
66 |
67 | return point_embedding
68 |
69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
71 | mask_embedding = mask_embedding + (
72 | 1 - has_mask_input
73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
74 | return mask_embedding
75 |
76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
77 | masks = F.interpolate(
78 | masks,
79 | size=(self.img_size, self.img_size),
80 | mode="bilinear",
81 | align_corners=False,
82 | )
83 |
84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore
86 |
87 | orig_im_size = orig_im_size.to(torch.int64)
88 | h, w = orig_im_size[0], orig_im_size[1]
89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
90 | return masks
91 |
92 | def select_masks(
93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
94 | ) -> Tuple[torch.Tensor, torch.Tensor]:
95 | # Determine if we should return the multiclick mask or not from the number of points.
96 | # The reweighting is used to avoid control flow.
97 | score_reweight = torch.tensor(
98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
99 | ).to(iou_preds.device)
100 | score = iou_preds + (num_points - 2.5) * score_reweight
101 | best_idx = torch.argmax(score, dim=1)
102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
104 |
105 | return masks, iou_preds
106 |
107 | @torch.no_grad()
108 | def forward(
109 | self,
110 | image_embeddings: torch.Tensor,
111 | point_coords: torch.Tensor,
112 | point_labels: torch.Tensor,
113 | mask_input: torch.Tensor,
114 | has_mask_input: torch.Tensor,
115 | orig_im_size: torch.Tensor,
116 | ):
117 | sparse_embedding = self._embed_points(point_coords, point_labels)
118 | dense_embedding = self._embed_masks(mask_input, has_mask_input)
119 |
120 | masks, scores = self.model.mask_decoder.predict_masks(
121 | image_embeddings=image_embeddings,
122 | image_pe=self.model.prompt_encoder.get_dense_pe(),
123 | sparse_prompt_embeddings=sparse_embedding,
124 | dense_prompt_embeddings=dense_embedding,
125 | )
126 |
127 | if self.use_stability_score:
128 | scores = calculate_stability_score(
129 | masks, self.model.mask_threshold, self.stability_score_offset
130 | )
131 |
132 | if self.return_single_mask:
133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
134 |
135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
136 |
137 | if self.return_extra_metrics:
138 | stability_scores = calculate_stability_score(
139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset
140 | )
141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
142 | return upscaled_masks, scores, stability_scores, areas, masks
143 |
144 | return upscaled_masks, scores, masks
145 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/segment_anything/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 | from torch.nn import functional as F
10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore
11 |
12 | from copy import deepcopy
13 | from typing import Tuple
14 |
15 |
16 | class ResizeLongestSide:
17 | """
18 | Resizes images to the longest side 'target_length', as well as provides
19 | methods for resizing coordinates and boxes. Provides methods for
20 | transforming both numpy array and batched torch tensors.
21 | """
22 |
23 | def __init__(self, target_length: int) -> None:
24 | self.target_length = target_length
25 |
26 | def apply_image(self, image: np.ndarray) -> np.ndarray:
27 | """
28 | Expects a numpy array with shape HxWxC in uint8 format.
29 | """
30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
31 | return np.array(resize(to_pil_image(image), target_size))
32 |
33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
34 | """
35 | Expects a numpy array of length 2 in the final dimension. Requires the
36 | original image size in (H, W) format.
37 | """
38 | old_h, old_w = original_size
39 | new_h, new_w = self.get_preprocess_shape(
40 | original_size[0], original_size[1], self.target_length
41 | )
42 | coords = deepcopy(coords).astype(float)
43 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
44 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
45 | return coords
46 |
47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
48 | """
49 | Expects a numpy array shape Bx4. Requires the original image size
50 | in (H, W) format.
51 | """
52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
53 | return boxes.reshape(-1, 4)
54 |
55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
56 | """
57 | Expects batched images with shape BxCxHxW and float format. This
58 | transformation may not exactly match apply_image. apply_image is
59 | the transformation expected by the model.
60 | """
61 | # Expects an image in BCHW format. May not exactly match apply_image.
62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
63 | return F.interpolate(
64 | image, target_size, mode="bilinear", align_corners=False, antialias=True
65 | )
66 |
67 | def apply_coords_torch(
68 | self, coords: torch.Tensor, original_size: Tuple[int, ...]
69 | ) -> torch.Tensor:
70 | """
71 | Expects a torch tensor with length 2 in the last dimension. Requires the
72 | original image size in (H, W) format.
73 | """
74 | old_h, old_w = original_size
75 | new_h, new_w = self.get_preprocess_shape(
76 | original_size[0], original_size[1], self.target_length
77 | )
78 | coords = deepcopy(coords).to(torch.float)
79 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
80 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
81 | return coords
82 |
83 | def apply_boxes_torch(
84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...]
85 | ) -> torch.Tensor:
86 | """
87 | Expects a torch tensor with shape Bx4. Requires the original image
88 | size in (H, W) format.
89 | """
90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
91 | return boxes.reshape(-1, 4)
92 |
93 | @staticmethod
94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
95 | """
96 | Compute the output size given input size and target long side length.
97 | """
98 | scale = long_side_length * 1.0 / max(oldh, oldw)
99 | newh, neww = oldh * scale, oldw * scale
100 | neww = int(neww + 0.5)
101 | newh = int(newh + 0.5)
102 | return (newh, neww)
103 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/utils/inout.py:
--------------------------------------------------------------------------------
1 | import os
2 | import errno
3 | import shutil
4 | import numpy as np
5 | import json
6 | import sys
7 |
8 | try:
9 | import ruamel_yaml as yaml
10 | except ModuleNotFoundError:
11 | import ruamel.yaml as yaml
12 |
13 |
14 | def create_folder(path):
15 | try:
16 | os.mkdir(path)
17 | except OSError as exc:
18 | if exc.errno != errno.EEXIST:
19 | raise
20 | pass
21 |
22 |
23 | def del_folder(path):
24 | try:
25 | shutil.rmtree(path)
26 | except OSError as exc:
27 | pass
28 |
29 |
30 | def write_txt(path, list_files):
31 | with open(path, "w") as f:
32 | for idx in list_files:
33 | f.write(idx + "\n")
34 | f.close()
35 |
36 |
37 | def open_txt(path):
38 | with open(path, "r") as f:
39 | lines = f.readlines()
40 | lines = [line.strip() for line in lines]
41 | return lines
42 |
43 |
44 | def load_json(path):
45 | with open(path, "r") as f:
46 | # info = yaml.load(f, Loader=yaml.CLoader)
47 | info = json.load(f)
48 | return info
49 |
50 |
51 | def save_json(path, info):
52 | # save to json without sorting keys or changing format
53 | with open(path, "w") as f:
54 | json.dump(info, f, indent=4)
55 |
56 |
57 | def save_json_bop23(path, info):
58 | # save to json without sorting keys or changing format
59 | with open(path, "w") as f:
60 | json.dump(info, f)
61 |
62 |
63 | def save_npz(path, info):
64 | np.savez_compressed(path, **info)
65 |
66 |
67 | def casting_format_to_save_json(data):
68 | # casting for every keys in dict to list so that it can be saved as json
69 | for key in data.keys():
70 | if (
71 | isinstance(data[key][0], np.ndarray)
72 | or isinstance(data[key][0], np.float32)
73 | or isinstance(data[key][0], np.float64)
74 | or isinstance(data[key][0], np.int32)
75 | or isinstance(data[key][0], np.int64)
76 | ):
77 | data[key] = np.array(data[key]).tolist()
78 | return data
79 |
80 |
81 | def get_root_project():
82 | return os.path.dirname(os.path.dirname((os.path.abspath(__file__))))
83 |
84 |
85 | def append_lib(path):
86 | sys.path.append(os.path.join(path, "src"))
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/utils/logging.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | class LevelsFilter(logging.Filter):
5 | def __init__(self, levels):
6 | self.levels = [getattr(logging, level) for level in levels]
7 |
8 | def filter(self, record):
9 | return record.levelno in self.levels
10 |
11 |
12 | class StreamToLogger(object):
13 | """
14 | Fake file-like stream object that redirects writes to a logger instance.
15 | """
16 |
17 | def __init__(self, logger, level):
18 | self.logger = logger
19 | self.level = level
20 | self.linebuf = ""
21 |
22 | def write(self, buf):
23 | for line in buf.rstrip().splitlines():
24 | self.logger.log(self.level, line.rstrip())
25 |
26 | def flush(self):
27 | pass
28 |
29 |
30 | class TqdmLoggingHandler(logging.Handler):
31 | def __init__(self, level=logging.NOTSET):
32 | super().__init__(level)
33 |
34 | def emit(self, record):
35 | import tqdm
36 |
37 | try:
38 | msg = self.format(record)
39 | tqdm.tqdm.write(msg)
40 | self.flush()
41 | except Exception:
42 | self.handleError(record)
43 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/utils/poses/create_template_poses.py:
--------------------------------------------------------------------------------
1 | # blenderproc run poses/create_template_poses.py
2 | import blenderproc
3 | import bpy
4 | import bmesh
5 | import math
6 | import numpy as np
7 | import os
8 |
9 |
10 | def get_camera_positions(nSubDiv):
11 | """
12 | * Construct an icosphere
13 | * subdived
14 | """
15 |
16 | bpy.ops.mesh.primitive_ico_sphere_add(location=(0, 0, 0), enter_editmode=True)
17 | # bpy.ops.export_mesh.ply(filepath='./sphere.ply')
18 | icos = bpy.context.object
19 | me = icos.data
20 |
21 | # -- cut away lower part
22 | bm = bmesh.from_edit_mesh(me)
23 | sel = [v for v in bm.verts if v.co[2] < 0]
24 |
25 | bmesh.ops.delete(bm, geom=sel, context="FACES")
26 | bmesh.update_edit_mesh(me)
27 |
28 | # -- subdivide and move new vertices out to the surface of the sphere
29 | # nSubDiv = 3
30 | for i in range(nSubDiv):
31 | bpy.ops.mesh.subdivide()
32 |
33 | bm = bmesh.from_edit_mesh(me)
34 | for v in bm.verts:
35 | l = math.sqrt(v.co[0] ** 2 + v.co[1] ** 2 + v.co[2] ** 2)
36 | v.co[0] /= l
37 | v.co[1] /= l
38 | v.co[2] /= l
39 | bmesh.update_edit_mesh(me)
40 |
41 | # -- cut away zero elevation
42 | bm = bmesh.from_edit_mesh(me)
43 | sel = [v for v in bm.verts if v.co[2] <= 0]
44 | bmesh.ops.delete(bm, geom=sel, context="FACES")
45 | bmesh.update_edit_mesh(me)
46 |
47 | # convert vertex positions to az,el
48 | positions = []
49 | angles = []
50 | bm = bmesh.from_edit_mesh(me)
51 | for v in bm.verts:
52 | x = v.co[0]
53 | y = v.co[1]
54 | z = v.co[2]
55 | az = math.atan2(x, y) # *180./math.pi
56 | el = math.atan2(z, math.sqrt(x**2 + y**2)) # *180./math.pi
57 | # positions.append((az,el))
58 | angles.append((el, az))
59 | positions.append((x, y, z))
60 |
61 | bpy.ops.object.editmode_toggle()
62 |
63 | # sort positions, first by az and el
64 | data = zip(angles, positions)
65 | positions = sorted(data)
66 | positions = [y for x, y in positions]
67 | angles = sorted(angles)
68 | return angles, positions
69 |
70 |
71 | def normalize(vec):
72 | return vec / (np.linalg.norm(vec, axis=-1, keepdims=True))
73 |
74 |
75 | def look_at(cam_location, point):
76 | # Cam points in positive z direction
77 | forward = point - cam_location
78 | forward = normalize(forward)
79 |
80 | tmp = np.array([0.0, 0.0, -1.0])
81 | # print warning when camera location is parallel to tmp
82 | norm = min(
83 | np.linalg.norm(cam_location - tmp, axis=-1),
84 | np.linalg.norm(cam_location + tmp, axis=-1),
85 | )
86 | if norm < 1e-3:
87 | print("Warning: camera location is parallel to tmp")
88 | tmp = np.array([0.0, -1.0, 0.0])
89 |
90 | right = np.cross(tmp, forward)
91 | right = normalize(right)
92 |
93 | up = np.cross(forward, right)
94 | up = normalize(up)
95 |
96 | mat = np.stack((right, up, forward, cam_location), axis=-1)
97 |
98 | hom_vec = np.array([[0.0, 0.0, 0.0, 1.0]])
99 |
100 | if len(mat.shape) > 2:
101 | hom_vec = np.tile(hom_vec, [mat.shape[0], 1, 1])
102 |
103 | mat = np.concatenate((mat, hom_vec), axis=-2)
104 | return mat
105 |
106 |
107 | def convert_location_to_rotation(locations):
108 | obj_poses = np.zeros((len(locations), 4, 4))
109 | for idx, pt in enumerate(locations):
110 | obj_poses[idx] = look_at(pt, np.array([0, 0, 0]))
111 | return obj_poses
112 |
113 |
114 | def inverse_transform(poses):
115 | new_poses = np.zeros_like(poses)
116 | for idx_pose in range(len(poses)):
117 | rot = poses[idx_pose, :3, :3]
118 | t = poses[idx_pose, :3, 3]
119 | rot = np.transpose(rot)
120 | t = -np.matmul(rot, t)
121 | new_poses[idx_pose][3][3] = 1
122 | new_poses[idx_pose][:3, :3] = rot
123 | new_poses[idx_pose][:3, 3] = t
124 | return new_poses
125 |
126 |
127 | save_dir = "utils/poses/predefined_poses"
128 | if not os.path.exists(save_dir):
129 | os.makedirs(save_dir)
130 |
131 | for level in [0, 1, 2]:
132 | position_icosphere = np.asarray(get_camera_positions(level)[1])
133 | cam_poses = convert_location_to_rotation(position_icosphere)
134 | cam_poses[:, :3, 3] *= 1000.
135 | np.save(f"{save_dir}/cam_poses_level{level}.npy", cam_poses)
136 | obj_poses = inverse_transform(cam_poses)
137 | np.save(f"{save_dir}/obj_poses_level{level}.npy", obj_poses)
138 |
139 | print("Output saved to: " + save_dir)
140 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/utils/poses/find_neighbors.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | # import open3d as o3d
3 | import os
4 | from utils.poses.pose_utils import (
5 | get_obj_poses_from_template_level,
6 | get_root_project,
7 | NearestTemplateFinder,
8 | )
9 | import os.path as osp
10 | # from utils.vis_3d_utils import convert_numpy_to_open3d, draw_camera
11 |
12 | if __name__ == "__main__":
13 | for template_level in range(2):
14 | templates_poses_level0 = get_obj_poses_from_template_level(
15 | template_level, "all", return_cam=True
16 | )
17 | finder = NearestTemplateFinder(
18 | level_templates=2,
19 | pose_distribution="all",
20 | return_inplane=True,
21 | )
22 | obj_poses_level = get_obj_poses_from_template_level(template_level, "all", return_cam=False)
23 | idx_templates, inplanes = finder.search_nearest_template(obj_poses_level)
24 | print(len(obj_poses_level), len(idx_templates))
25 | root_repo = get_root_project()
26 | save_path = os.path.join(root_repo, f"utils/poses/predefined_poses/idx_all_level{template_level}_in_level2.npy")
27 | np.save(save_path, idx_templates)
28 |
29 | # level 2 in level 2 is just itself
30 | obj_poses_level = get_obj_poses_from_template_level(2, "all", return_cam=False)
31 | print(len(obj_poses_level))
32 | save_path = os.path.join(root_repo, "utils/poses/predefined_poses/idx_all_level2_in_level2.npy")
33 | np.save(save_path, np.arange(len(obj_poses_level)))
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/utils/poses/fps.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | # credit: https://github.com/ziruiw-dev/farthest-point-sampling/blob/master/fps_v1.py
3 |
4 | class FPS:
5 | def __init__(self, pcd_xyz, n_samples):
6 | self.n_samples = n_samples
7 | self.pcd_xyz = pcd_xyz
8 | self.n_pts = pcd_xyz.shape[0]
9 | self.dim = pcd_xyz.shape[1]
10 | self.selected_pts = None
11 | self.selected_pts_expanded = np.zeros(shape=(n_samples, 1, self.dim))
12 | self.remaining_pts = np.copy(pcd_xyz)
13 |
14 | self.grouping_radius = None
15 | self.dist_pts_to_selected = (
16 | None # Iteratively updated in step(). Finally re-used in group()
17 | )
18 | self.labels = None
19 |
20 | # Random pick a start
21 | self.start_idx = np.random.randint(low=0, high=self.n_pts - 1)
22 | self.selected_pts_expanded[0] = self.remaining_pts[self.start_idx]
23 | self.n_selected_pts = 1
24 | self.idx_selected = [self.start_idx]
25 |
26 | def get_selected_pts(self):
27 | self.selected_pts = np.squeeze(self.selected_pts_expanded, axis=1)
28 | return self.selected_pts
29 |
30 | def step(self):
31 | if self.n_selected_pts < self.n_samples:
32 | self.dist_pts_to_selected = self.__distance__(
33 | self.remaining_pts, self.selected_pts_expanded[: self.n_selected_pts]
34 | ).T
35 | dist_pts_to_selected_min = np.min(
36 | self.dist_pts_to_selected, axis=1, keepdims=True
37 | )
38 | res_selected_idx = np.argmax(dist_pts_to_selected_min)
39 | self.selected_pts_expanded[self.n_selected_pts] = self.remaining_pts[
40 | res_selected_idx
41 | ]
42 |
43 | self.n_selected_pts += 1
44 |
45 | # add to idx_selected
46 | self.idx_selected.append(res_selected_idx)
47 | else:
48 | print("Got enough number samples")
49 |
50 | def fit(self):
51 | for _ in range(1, self.n_samples):
52 | self.step()
53 | return self.get_selected_pts(), self.idx_selected
54 |
55 | def group(self, radius):
56 | self.grouping_radius = radius # the grouping radius is not actually used
57 | dists = self.dist_pts_to_selected
58 |
59 | # Ignore the "points"-"selected" relations if it's larger than the radius
60 | dists = np.where(dists > radius, dists + 1000000 * radius, dists)
61 |
62 | # Find the relation with the smallest distance.
63 | # NOTE: the smallest distance may still larger than the radius.
64 | self.labels = np.argmin(dists, axis=1)
65 | return self.labels
66 |
67 | @staticmethod
68 | def __distance__(a, b):
69 | return np.linalg.norm(a - b, ord=2, axis=2)
70 |
71 |
72 | if __name__ == "__main__":
73 | points = np.random.rand(1000, 3)
74 | sampled_points, idx_selected = FPS(points, 100).fit()
75 | print(sampled_points.shape, len(idx_selected))
76 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/cam_poses_level0.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/cam_poses_level0.npy
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/cam_poses_level1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/cam_poses_level1.npy
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/cam_poses_level2.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/cam_poses_level2.npy
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/idx_all_level0_in_level2.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/idx_all_level0_in_level2.npy
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/idx_all_level1_in_level2.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/idx_all_level1_in_level2.npy
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/idx_all_level2_in_level2.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/idx_all_level2_in_level2.npy
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/obj_poses_level0.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/obj_poses_level0.npy
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/obj_poses_level1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/obj_poses_level1.npy
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/obj_poses_level2.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/obj_poses_level2.npy
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/utils/poses/pyrender.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pyrender
3 | import trimesh
4 | import os
5 | from PIL import Image
6 | import numpy as np
7 | import os.path as osp
8 | from tqdm import tqdm
9 | import argparse
10 | from utils.trimesh_utils import as_mesh
11 | from utils.trimesh_utils import get_obj_diameter
12 | os.environ["DISPLAY"] = ":1"
13 | os.environ["PYOPENGL_PLATFORM"] = "egl"
14 |
15 |
16 | def render(
17 | mesh,
18 | output_dir,
19 | obj_poses,
20 | img_size,
21 | intrinsic,
22 | light_itensity=0.6,
23 | is_tless=False,
24 | ):
25 | # camera pose is fixed as np.eye(4)
26 | cam_pose = np.eye(4)
27 | # convert openCV camera
28 | cam_pose[1, 1] = -1
29 | cam_pose[2, 2] = -1
30 | # create scene config
31 | ambient_light = np.array([0.02, 0.02, 0.02, 1.0]) # np.array([1.0, 1.0, 1.0, 1.0])
32 | if light_itensity != 0.6:
33 | ambient_light = np.array([1.0, 1.0, 1.0, 1.0])
34 | scene = pyrender.Scene(
35 | bg_color=np.array([0.0, 0.0, 0.0, 0.0]), ambient_light=ambient_light
36 | )
37 | light = pyrender.SpotLight(
38 | color=np.ones(3),
39 | intensity=light_itensity,
40 | innerConeAngle=np.pi / 16.0,
41 | outerConeAngle=np.pi / 6.0,
42 | )
43 | scene.add(light, pose=cam_pose)
44 |
45 | # create camera and render engine
46 | fx, fy, cx, cy = intrinsic[0][0], intrinsic[1][1], intrinsic[0][2], intrinsic[1][2]
47 | camera = pyrender.IntrinsicsCamera(
48 | fx=fx, fy=fy, cx=cx, cy=cy, znear=0.05, zfar=100000
49 | )
50 | scene.add(camera, pose=cam_pose)
51 | render_engine = pyrender.OffscreenRenderer(img_size[1], img_size[0])
52 | cad_node = scene.add(mesh, pose=np.eye(4), name="cad")
53 |
54 | for idx_frame in range(obj_poses.shape[0]):
55 | scene.set_pose(cad_node, obj_poses[idx_frame])
56 | rgb, depth = render_engine.render(scene, pyrender.constants.RenderFlags.RGBA)
57 | rgb = Image.fromarray(np.uint8(rgb))
58 | rgb.save(osp.join(output_dir, f"{idx_frame:06d}.png"))
59 |
60 |
61 | if __name__ == "__main__":
62 | parser = argparse.ArgumentParser()
63 | parser.add_argument("cad_path", nargs="?", help="Path to the model file")
64 | parser.add_argument("obj_pose", nargs="?", help="Path to the model file")
65 | parser.add_argument(
66 | "output_dir", nargs="?", help="Path to where the final files will be saved"
67 | )
68 | parser.add_argument("gpus_devices", nargs="?", help="GPU devices")
69 | parser.add_argument("disable_output", nargs="?", help="Disable output of blender")
70 | parser.add_argument("light_itensity", nargs="?", type=float, default=0.6, help="Light itensity")
71 | parser.add_argument("radius", nargs="?", type=float, default=1, help="Distance from camera to object")
72 | args = parser.parse_args()
73 | print(args)
74 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus_devices
75 | poses = np.load(args.obj_pose)
76 | # we can increase high energy for lightning but it's simpler to change just scale of the object to meter
77 | # poses[:, :3, :3] = poses[:, :3, :3] / 1000.0
78 | poses[:, :3, 3] = poses[:, :3, 3] / 1000.0
79 | if args.radius != 1:
80 | poses[:, :3, 3] = poses[:, :3, 3] * args.radius
81 | if "tless" in args.output_dir:
82 | intrinsic = np.asarray(
83 | [1075.65091572, 0.0, 360, 0.0, 1073.90347929, 270, 0.0, 0.0, 1.0]
84 | ).reshape(3, 3)
85 | img_size = [540, 720]
86 | is_tless = True
87 | else:
88 | intrinsic = np.array(
89 | [[572.4114, 0.0, 325.2611], [0.0, 573.57043, 242.04899], [0.0, 0.0, 1.0]]
90 | )
91 | img_size = [480, 640]
92 | is_tless = False
93 |
94 | # load mesh to meter
95 | mesh = trimesh.load_mesh(args.cad_path)
96 | diameter = get_obj_diameter(mesh)
97 | if diameter > 100: # object is in mm
98 | mesh.apply_scale(0.001)
99 | if is_tless:
100 | # setting uniform colors for mesh
101 | color = 0.4
102 | mesh.visual.face_colors = np.ones((len(mesh.faces), 3)) * color
103 | mesh.visual.vertex_colors = np.ones((len(mesh.vertices), 3)) * color
104 | mesh = pyrender.Mesh.from_trimesh(mesh, smooth=False)
105 | else:
106 | mesh = pyrender.Mesh.from_trimesh(as_mesh(mesh))
107 | os.makedirs(args.output_dir, exist_ok=True)
108 | render(
109 | output_dir=args.output_dir,
110 | mesh=mesh,
111 | obj_poses=poses,
112 | intrinsic=intrinsic,
113 | img_size=(480, 640),
114 | light_itensity=args.light_itensity,
115 | )
116 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/utils/trimesh_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import trimesh
3 | import torch
4 |
5 |
6 | def load_mesh(path, ORIGIN_GEOMETRY="BOUNDS"):
7 | mesh = as_mesh(trimesh.load(path))
8 | if ORIGIN_GEOMETRY == "BOUNDS":
9 | AABB = mesh.bounds
10 | center = np.mean(AABB, axis=0)
11 | mesh.vertices -= center
12 | return mesh
13 |
14 |
15 | def get_bbox_from_mesh(mesh):
16 | AABB = mesh.bounds
17 | OBB = AABB_to_OBB(AABB)
18 | return OBB
19 |
20 |
21 | def get_obj_diameter(mesh_path):
22 | mesh = load_mesh(mesh_path)
23 | extents = mesh.extents * 2
24 | return np.linalg.norm(extents)
25 |
26 |
27 | def as_mesh(scene_or_mesh):
28 | if isinstance(scene_or_mesh, trimesh.Scene):
29 | result = trimesh.util.concatenate(
30 | [
31 | trimesh.Trimesh(vertices=m.vertices, faces=m.faces)
32 | for m in scene_or_mesh.geometry.values()
33 | ]
34 | )
35 | else:
36 | result = scene_or_mesh
37 | return result
38 |
39 |
40 | def AABB_to_OBB(AABB):
41 | """
42 | AABB bbox to oriented bounding box
43 | """
44 | minx, miny, minz, maxx, maxy, maxz = np.arange(6)
45 | corner_index = np.array(
46 | [
47 | minx,
48 | miny,
49 | minz,
50 | maxx,
51 | miny,
52 | minz,
53 | maxx,
54 | maxy,
55 | minz,
56 | minx,
57 | maxy,
58 | minz,
59 | minx,
60 | miny,
61 | maxz,
62 | maxx,
63 | miny,
64 | maxz,
65 | maxx,
66 | maxy,
67 | maxz,
68 | minx,
69 | maxy,
70 | maxz,
71 | ]
72 | ).reshape((-1, 3))
73 |
74 | corners = AABB.reshape(-1)[corner_index]
75 | return corners
76 |
77 | def depth_image_to_pointcloud_translate_torch(depth, scale, K):
78 | u = torch.arange(0, depth.shape[2])
79 | v = torch.arange(0, depth.shape[1])
80 |
81 | u, v = torch.meshgrid(u, v, indexing="xy")
82 | u = u.to(depth.device)
83 | v = v.to(depth.device)
84 |
85 | # depth metric is mm, depth_scale metric is m
86 | # K metric is m
87 | Z = depth * scale / 1000
88 | X = (u - K[0, 2]) * Z / K[0, 0]
89 | Y = (v - K[1, 2]) * Z / K[1, 1]
90 |
91 | valid = Z > 0
92 |
93 | X = X * valid
94 | Y = Y * valid
95 | Z = Z * valid
96 |
97 | # average should run on valid point
98 | valid_num = torch.count_nonzero(valid, axis=(1, 2)) + 1e-8
99 | avg_X = torch.sum(X, axis=(1, 2)) / valid_num
100 | avg_Y = torch.sum(Y, axis=(1, 2)) / valid_num
101 | avg_Z = torch.sum(Z, axis=(1, 2)) / valid_num
102 |
103 | translate = torch.vstack((avg_X, avg_Y, avg_Z)).permute(1, 0)
104 |
105 | return translate
106 |
107 |
108 | if __name__ == "__main__":
109 | mesh_path = (
110 | "/media/nguyen/Data/dataset/ShapeNet/ShapeNetCore.v2/"
111 | "03001627/1016f4debe988507589aae130c1f06fb/models/model_normalized.obj"
112 | )
113 | mesh = load_mesh(mesh_path)
114 | bbox = get_bbox_from_mesh(mesh)
115 | # create a visualization scene with rays, hits, and mesh
116 | scene = trimesh.Scene([mesh, trimesh.points.PointCloud(bbox)])
117 | # display the scene
118 | scene.show()
119 |
--------------------------------------------------------------------------------
/SAM-6D/Instance_Segmentation_Model/utils/weight.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 | import numpy as np
4 |
5 |
6 | def load_checkpoint(model, checkpoint_path, checkpoint_key=None, prefix=""):
7 | checkpoint = torch.load(checkpoint_path)
8 | if checkpoint_key is not None:
9 | pretrained_dict = checkpoint[checkpoint_key] # "state_dict"
10 | else:
11 | pretrained_dict = checkpoint
12 | pretrained_dict = {k.replace(prefix, ""): v for k, v in pretrained_dict.items()}
13 | model_dict = model.state_dict()
14 | # compare keys and update value
15 | pretrained_dict_can_load = {
16 | k: v
17 | for k, v in pretrained_dict.items()
18 | if k in model_dict and v.shape == model_dict[k].shape
19 | }
20 | pretrained_dict_cannot_load = [
21 | k for k, v in pretrained_dict.items() if k not in model_dict
22 | ]
23 | model_dict_not_update = [
24 | k for k, v in model_dict.items() if k not in pretrained_dict
25 | ]
26 | module_cannot_load = np.unique(
27 | [k.split(".")[0] for k in pretrained_dict_cannot_load] #
28 | )
29 | module_not_update = np.unique([k.split(".")[0] for k in model_dict_not_update]) #
30 | logging.info(f"Cannot load: {module_cannot_load}")
31 | logging.info(f"Not update: {module_not_update}")
32 | logging.info(
33 | f"Pretrained: {len(pretrained_dict)}/ Loaded: {len(pretrained_dict_can_load)}/ Cannot loaded: {len(pretrained_dict_cannot_load)} VS Current model: {len(model_dict)}"
34 | )
35 | model_dict.update(pretrained_dict_can_load)
36 | model.load_state_dict(model_dict)
37 | logging.info("Load pretrained done!")
38 |
--------------------------------------------------------------------------------
/SAM-6D/Pose_Estimation_Model/README.md:
--------------------------------------------------------------------------------
1 | # Pose Estimation Model (PEM) for SAM-6D
2 |
3 |
4 |
5 | 
6 |
7 | ## Requirements
8 | The code has been tested with
9 | - python 3.9.6
10 | - pytorch 2.0.0
11 | - CUDA 11.3
12 |
13 | Other dependencies:
14 |
15 | ```
16 | sh dependencies.sh
17 | ```
18 |
19 | ## Data Preparation
20 |
21 | Please refer to [[link](https://github.com/JiehongLin/SAM-6D/tree/main/SAM-6D/Data)] for more details.
22 |
23 |
24 | ## Model Download
25 | Our trained model is provided [[here](https://drive.google.com/file/d/1joW9IvwsaRJYxoUmGo68dBVg-HcFNyI7/view?usp=sharing)], and could be downloaded via the command:
26 | ```
27 | python download_sam6d-pem.py
28 | ```
29 |
30 | ## Training on MegaPose Training Set
31 |
32 | To train the Pose Estimation Model of SAM-6D, please prepare the training data and run the folowing command:
33 | ```
34 | python train.py --gpus 0,1,2,3 --model pose_estimation_model --config config/base.yaml
35 | ```
36 | By default, we use four GPUs of 3090ti to train the model with batchsize set as 28.
37 |
38 |
39 | ## Evaluation on BOP Datasets
40 |
41 | To evaluate the model on BOP datasets, please run the following command:
42 | ```
43 | python test_bop.py --gpus 0 --model pose_estimation_model --config config/base.yaml --dataset $DATASET --view 42
44 | ```
45 | The string "DATASET" could be set as `lmo`, `icbin`, `itodd`, `hb`, `tless`, `tudl`, `ycbv`, or `all`. Before evaluation, please refer to [[link](https://github.com/JiehongLin/SAM-6D/tree/main/SAM-6D/Data)] for rendering the object templates of BOP datasets, or download our [rendered templates](https://drive.google.com/drive/folders/1fXt5Z6YDPZTJICZcywBUhu5rWnPvYAPI?usp=drive_link). Besides, the instance segmentation should be done following [[link](https://github.com/JiehongLin/SAM-6D/tree/main/SAM-6D/Instance_Segmentation_Model)]; to test on your own segmentation results, you could change the "detection_paths" in the `test_bop.py` file.
46 |
47 | One could also download our trained model for evaluation:
48 | ```
49 | python test_bop.py --gpus 0 --model pose_estimation_model --config config/base.yaml --checkpoint_path checkpoints/sam-6d-pem-base.pth --dataset $DATASET --view 42
50 | ```
51 |
52 |
53 | ## Acknowledgements
54 | - [MegaPose](https://github.com/megapose6d/megapose6d)
55 | - [GDRNPP](https://github.com/shanice-l/gdrnpp_bop2022)
56 | - [GeoTransformer](https://github.com/qinzheng93/GeoTransformer)
57 | - [Flatten Transformer](https://github.com/LeapLabTHU/FLatten-Transformer)
58 |
59 |
--------------------------------------------------------------------------------
/SAM-6D/Pose_Estimation_Model/config/base.yaml:
--------------------------------------------------------------------------------
1 | NAME_PROJECT: SAM-6D
2 |
3 | optimizer:
4 | type : Adam
5 | lr : 0.0001
6 | betas: [0.5, 0.999]
7 | eps : 0.000001
8 | weight_decay: 0.0
9 |
10 | lr_scheduler:
11 | type: WarmupCosineLR
12 | max_iters: 600000
13 | warmup_factor: 0.001
14 | warmup_iters: 1000
15 |
16 | model:
17 | coarse_npoint: 196
18 | fine_npoint: 2048
19 | feature_extraction:
20 | vit_type: vit_base
21 | up_type: linear
22 | embed_dim: 768
23 | out_dim: 256
24 | use_pyramid_feat: True
25 | pretrained: True
26 | geo_embedding:
27 | sigma_d: 0.2
28 | sigma_a: 15
29 | angle_k: 3
30 | reduction_a: max
31 | hidden_dim: 256
32 | coarse_point_matching:
33 | nblock: 3
34 | input_dim: 256
35 | hidden_dim: 256
36 | out_dim: 256
37 | temp: 0.1
38 | sim_type: cosine
39 | normalize_feat: True
40 | loss_dis_thres: 0.15
41 | nproposal1: 6000
42 | nproposal2: 300
43 | fine_point_matching:
44 | nblock: 3
45 | input_dim: 256
46 | hidden_dim: 256
47 | out_dim: 256
48 | pe_radius1: 0.1
49 | pe_radius2: 0.2
50 | focusing_factor: 3
51 | temp: 0.1
52 | sim_type: cosine
53 | normalize_feat: True
54 | loss_dis_thres: 0.15
55 |
56 |
57 |
58 | train_dataset:
59 | name: training_dataset
60 | data_dir: ../Data/MegaPose-Training-Data
61 | img_size: 224
62 | n_sample_observed_point: 2048
63 | n_sample_model_point: 2048
64 | n_sample_template_point: 5000
65 | min_visib_fract: 0.1
66 | min_px_count_visib: 512
67 | shift_range: 0.01
68 | rgb_mask_flag: True
69 | dilate_mask: True
70 |
71 | train_dataloader:
72 | bs : 28
73 | num_workers : 24
74 | shuffle : True
75 | drop_last : True
76 | pin_memory : False
77 |
78 |
79 |
80 | test_dataset:
81 | name: bop_test_dataset
82 | data_dir: ../Data/BOP
83 | template_dir: ../Data/BOP-Templates
84 | img_size: 224
85 | n_sample_observed_point: 2048
86 | n_sample_model_point: 1024
87 | n_sample_template_point: 5000
88 | minimum_n_point: 8
89 | rgb_mask_flag: True
90 | seg_filter_score: 0.25
91 | n_template_view: 42
92 |
93 |
94 | test_dataloader:
95 | bs : 16
96 | num_workers : 16
97 | shuffle : False
98 | drop_last : False
99 | pin_memory : False
100 |
101 |
102 | rd_seed: 1
103 | training_epoch: 15
104 | iters_to_print: 50
105 |
--------------------------------------------------------------------------------
/SAM-6D/Pose_Estimation_Model/dependencies.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | pip install timm
5 | pip install gorilla-core==0.2.7.8
6 | pip install trimesh==3.22.1
7 | pip install imgaug
8 | pip install opencv-python
9 | pip install gpustat==1.0.0
10 | pip install einops
11 |
12 | cd /model/pointnet2
13 | python setup.py install
14 | cd ..
15 |
--------------------------------------------------------------------------------
/SAM-6D/Pose_Estimation_Model/download_sam6d-pem.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 |
4 | def download_model(output_path):
5 | import os
6 | command = f"gdown --no-cookies --no-check-certificate -O '{output_path}/sam-6d-pem-base.pth' 1joW9IvwsaRJYxoUmGo68dBVg-HcFNyI7"
7 | os.system(command)
8 |
9 | def download() -> None:
10 | root_dir = os.path.dirname((os.path.abspath(__file__)))
11 | save_dir = osp.join(root_dir, "checkpoints")
12 | os.makedirs(save_dir, exist_ok=True)
13 | download_model(save_dir)
14 |
15 | if __name__ == "__main__":
16 | download()
17 |
18 |
19 |
--------------------------------------------------------------------------------
/SAM-6D/Pose_Estimation_Model/model/coarse_point_matching.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from transformer import GeometricTransformer
5 | from model_utils import (
6 | compute_feature_similarity,
7 | aug_pose_noise,
8 | compute_coarse_Rt,
9 | )
10 | from loss_utils import compute_correspondence_loss
11 |
12 |
13 |
14 | class CoarsePointMatching(nn.Module):
15 | def __init__(self, cfg, return_feat=False):
16 | super(CoarsePointMatching, self).__init__()
17 | self.cfg = cfg
18 | self.return_feat = return_feat
19 | self.nblock = self.cfg.nblock
20 |
21 | self.in_proj = nn.Linear(cfg.input_dim, cfg.hidden_dim)
22 | self.out_proj = nn.Linear(cfg.hidden_dim, cfg.out_dim)
23 |
24 | self.bg_token = nn.Parameter(torch.randn(1, 1, cfg.hidden_dim) * .02)
25 |
26 | self.transformers = []
27 | for _ in range(self.nblock):
28 | self.transformers.append(GeometricTransformer(
29 | blocks=['self', 'cross'],
30 | d_model = cfg.hidden_dim,
31 | num_heads = 4,
32 | dropout=None,
33 | activation_fn='ReLU',
34 | return_attention_scores=False,
35 | ))
36 | self.transformers = nn.ModuleList(self.transformers)
37 |
38 | def forward(self, p1, f1, geo1, p2, f2, geo2, radius, end_points):
39 | B = f1.size(0)
40 |
41 | f1 = self.in_proj(f1)
42 | f1 = torch.cat([self.bg_token.repeat(B,1,1), f1], dim=1) # adding bg
43 | f2 = self.in_proj(f2)
44 | f2 = torch.cat([self.bg_token.repeat(B,1,1), f2], dim=1) # adding bg
45 |
46 | atten_list = []
47 | for idx in range(self.nblock):
48 | f1, f2 = self.transformers[idx](f1, geo1, f2, geo2)
49 |
50 | if self.training or idx==self.nblock-1:
51 | atten_list.append(compute_feature_similarity(
52 | self.out_proj(f1),
53 | self.out_proj(f2),
54 | self.cfg.sim_type,
55 | self.cfg.temp,
56 | self.cfg.normalize_feat
57 | ))
58 |
59 | if self.training:
60 | gt_R = end_points['rotation_label']
61 | gt_t = end_points['translation_label'] / (radius.reshape(-1, 1)+1e-6)
62 | init_R, init_t = aug_pose_noise(gt_R, gt_t)
63 |
64 | end_points = compute_correspondence_loss(
65 | end_points, atten_list, p1, p2, gt_R, gt_t,
66 | dis_thres=self.cfg.loss_dis_thres,
67 | loss_str='coarse'
68 | )
69 | else:
70 | init_R, init_t = compute_coarse_Rt(
71 | atten_list[-1], p1, p2,
72 | end_points['model'] / (radius.reshape(-1, 1, 1) + 1e-6),
73 | self.cfg.nproposal1, self.cfg.nproposal2,
74 | )
75 | end_points['init_R'] = init_R
76 | end_points['init_t'] = init_t
77 |
78 | if self.return_feat:
79 | return end_points, self.out_proj(f1), self.out_proj(f2)
80 | else:
81 | return end_points
82 |
83 |
--------------------------------------------------------------------------------
/SAM-6D/Pose_Estimation_Model/model/fine_point_matching.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from transformer import SparseToDenseTransformer
6 | from model_utils import compute_feature_similarity, compute_fine_Rt
7 | from loss_utils import compute_correspondence_loss
8 | from pointnet2_utils import QueryAndGroup
9 | from pytorch_utils import SharedMLP, Conv1d
10 |
11 |
12 | class FinePointMatching(nn.Module):
13 | def __init__(self, cfg, return_feat=False):
14 | super(FinePointMatching, self).__init__()
15 | self.cfg = cfg
16 | self.return_feat = return_feat
17 | self.nblock = self.cfg.nblock
18 |
19 | self.in_proj = nn.Linear(cfg.input_dim, cfg.hidden_dim)
20 | self.out_proj = nn.Linear(cfg.hidden_dim, cfg.out_dim)
21 |
22 | self.bg_token = nn.Parameter(torch.randn(1, 1, cfg.hidden_dim) * .02)
23 | self.PE = PositionalEncoding(cfg.hidden_dim, r1=cfg.pe_radius1, r2=cfg.pe_radius2)
24 |
25 | self.transformers = []
26 | for _ in range(self.nblock):
27 | self.transformers.append(SparseToDenseTransformer(
28 | cfg.hidden_dim,
29 | num_heads=4,
30 | sparse_blocks=['self', 'cross'],
31 | dropout=None,
32 | activation_fn='ReLU',
33 | focusing_factor=cfg.focusing_factor,
34 | with_bg_token=True,
35 | replace_bg_token=True
36 | ))
37 | self.transformers = nn.ModuleList(self.transformers)
38 |
39 | def forward(self, p1, f1, geo1, fps_idx1, p2, f2, geo2, fps_idx2, radius, end_points):
40 | B = p1.size(0)
41 |
42 | init_R = end_points['init_R']
43 | init_t = end_points['init_t']
44 | p1_ = (p1 - init_t.unsqueeze(1)) @ init_R
45 |
46 | f1 = self.in_proj(f1) + self.PE(p1_)
47 | f1 = torch.cat([self.bg_token.repeat(B,1,1), f1], dim=1) # adding bg
48 |
49 | f2 = self.in_proj(f2) + self.PE(p2)
50 | f2 = torch.cat([self.bg_token.repeat(B,1,1), f2], dim=1) # adding bg
51 |
52 | atten_list = []
53 | for idx in range(self.nblock):
54 | f1, f2 = self.transformers[idx](f1, geo1, fps_idx1, f2, geo2, fps_idx2)
55 |
56 | if self.training or idx==self.nblock-1:
57 | atten_list.append(compute_feature_similarity(
58 | self.out_proj(f1),
59 | self.out_proj(f2),
60 | self.cfg.sim_type,
61 | self.cfg.temp,
62 | self.cfg.normalize_feat
63 | ))
64 |
65 | if self.training:
66 | gt_R = end_points['rotation_label']
67 | gt_t = end_points['translation_label'] / (radius.reshape(-1, 1)+1e-6)
68 |
69 | end_points = compute_correspondence_loss(
70 | end_points, atten_list, p1, p2, gt_R, gt_t,
71 | dis_thres=self.cfg.loss_dis_thres,
72 | loss_str='fine'
73 | )
74 | else:
75 | pred_R, pred_t, pred_pose_score = compute_fine_Rt(
76 | atten_list[-1], p1, p2,
77 | end_points['model'] / (radius.reshape(-1, 1, 1) + 1e-6),
78 | )
79 | end_points['pred_R'] = pred_R
80 | end_points['pred_t'] = pred_t * (radius.reshape(-1, 1)+1e-6)
81 | end_points['pred_pose_score'] = pred_pose_score
82 |
83 | if self.return_feat:
84 | return end_points, self.out_proj(f1), self.out_proj(f2)
85 | else:
86 | return end_points
87 |
88 |
89 |
90 | class PositionalEncoding(nn.Module):
91 | def __init__(self, out_dim, r1=0.1, r2=0.2, nsample1=32, nsample2=64, use_xyz=True, bn=True):
92 | super(PositionalEncoding, self).__init__()
93 | self.group1 = QueryAndGroup(r1, nsample1, use_xyz=use_xyz)
94 | self.group2 = QueryAndGroup(r2, nsample2, use_xyz=use_xyz)
95 | input_dim = 6 if use_xyz else 3
96 |
97 | self.mlp1 = SharedMLP([input_dim, 32, 64, 128], bn=bn)
98 | self.mlp2 = SharedMLP([input_dim, 32, 64, 128], bn=bn)
99 | self.mlp3 = Conv1d(256, out_dim, 1, activation=None, bn=None)
100 |
101 | def forward(self, pts1, pts2=None):
102 | if pts2 is None:
103 | pts2 = pts1
104 |
105 | # scale1
106 | feat1 = self.group1(
107 | pts1.contiguous(), pts2.contiguous(), pts1.transpose(1,2).contiguous()
108 | )
109 | feat1 = self.mlp1(feat1)
110 | feat1 = F.max_pool2d(
111 | feat1, kernel_size=[1, feat1.size(3)]
112 | )
113 |
114 | # scale2
115 | feat2 = self.group2(
116 | pts1.contiguous(), pts2.contiguous(), pts1.transpose(1,2).contiguous()
117 | )
118 | feat2 = self.mlp2(feat2)
119 | feat2 = F.max_pool2d(
120 | feat2, kernel_size=[1, feat2.size(3)]
121 | )
122 |
123 | feat = torch.cat([feat1, feat2], dim=1).squeeze(-1)
124 | feat = self.mlp3(feat).transpose(1,2)
125 | return feat
126 |
127 |
--------------------------------------------------------------------------------
/SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/include/ball_query.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #pragma once
7 | #include
8 |
9 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
10 | const int nsample);
11 |
--------------------------------------------------------------------------------
/SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/include/cuda_utils.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #ifndef _CUDA_UTILS_H
7 | #define _CUDA_UTILS_H
8 |
9 | #include
10 | #include
11 | #include
12 |
13 | #include
14 | #include
15 |
16 | #include
17 |
18 | #define TOTAL_THREADS 512
19 |
20 | inline int opt_n_threads(int work_size) {
21 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0);
22 |
23 | return max(min(1 << pow_2, TOTAL_THREADS), 1);
24 | }
25 |
26 | inline dim3 opt_block_config(int x, int y) {
27 | const int x_threads = opt_n_threads(x);
28 | const int y_threads =
29 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1);
30 | dim3 block_config(x_threads, y_threads, 1);
31 |
32 | return block_config;
33 | }
34 |
35 | #define CUDA_CHECK_ERRORS() \
36 | do { \
37 | cudaError_t err = cudaGetLastError(); \
38 | if (cudaSuccess != err) { \
39 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \
40 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \
41 | __FILE__); \
42 | exit(-1); \
43 | } \
44 | } while (0)
45 |
46 | #endif
47 |
--------------------------------------------------------------------------------
/SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/include/group_points.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #pragma once
7 | #include
8 |
9 | at::Tensor group_points(at::Tensor points, at::Tensor idx);
10 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);
11 |
--------------------------------------------------------------------------------
/SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/include/interpolate.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #pragma once
7 |
8 | #include
9 | #include
10 |
11 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows);
12 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx,
13 | at::Tensor weight);
14 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
15 | at::Tensor weight, const int m);
16 |
--------------------------------------------------------------------------------
/SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/include/sampling.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #pragma once
7 | #include
8 |
9 | at::Tensor gather_points(at::Tensor points, at::Tensor idx);
10 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);
11 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples);
12 |
--------------------------------------------------------------------------------
/SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/include/utils.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #pragma once
7 | #include
8 | #include
9 |
10 | #define CHECK_CUDA(x) \
11 | do { \
12 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor"); \
13 | } while (0)
14 |
15 | #define CHECK_CONTIGUOUS(x) \
16 | do { \
17 | TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor"); \
18 | } while (0)
19 |
20 | #define CHECK_IS_INT(x) \
21 | do { \
22 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \
23 | #x " must be an int tensor"); \
24 | } while (0)
25 |
26 | #define CHECK_IS_FLOAT(x) \
27 | do { \
28 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \
29 | #x " must be a float tensor"); \
30 | } while (0)
31 |
--------------------------------------------------------------------------------
/SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/src/ball_query.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include "ball_query.h"
7 | #include "utils.h"
8 |
9 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
10 | int nsample, const float *new_xyz,
11 | const float *xyz, int *idx);
12 |
13 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
14 | const int nsample) {
15 | CHECK_CONTIGUOUS(new_xyz);
16 | CHECK_CONTIGUOUS(xyz);
17 | CHECK_IS_FLOAT(new_xyz);
18 | CHECK_IS_FLOAT(xyz);
19 |
20 | if (new_xyz.type().is_cuda()) {
21 | CHECK_CUDA(xyz);
22 | }
23 |
24 | at::Tensor idx =
25 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample},
26 | at::device(new_xyz.device()).dtype(at::ScalarType::Int));
27 |
28 | if (new_xyz.type().is_cuda()) {
29 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1),
30 | radius, nsample, new_xyz.data(),
31 | xyz.data(), idx.data());
32 | } else {
33 | TORCH_CHECK(false, "CPU not supported");
34 | }
35 |
36 | return idx;
37 | }
38 |
--------------------------------------------------------------------------------
/SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/src/ball_query_gpu.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include
7 | #include
8 | #include
9 |
10 | #include "cuda_utils.h"
11 |
12 | // input: new_xyz(b, m, 3) xyz(b, n, 3)
13 | // output: idx(b, m, nsample)
14 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius,
15 | int nsample,
16 | const float *__restrict__ new_xyz,
17 | const float *__restrict__ xyz,
18 | int *__restrict__ idx) {
19 | int batch_index = blockIdx.x;
20 | xyz += batch_index * n * 3;
21 | new_xyz += batch_index * m * 3;
22 | idx += m * nsample * batch_index;
23 |
24 | int index = threadIdx.x;
25 | int stride = blockDim.x;
26 |
27 | float radius2 = radius * radius;
28 | for (int j = index; j < m; j += stride) {
29 | float new_x = new_xyz[j * 3 + 0];
30 | float new_y = new_xyz[j * 3 + 1];
31 | float new_z = new_xyz[j * 3 + 2];
32 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) {
33 | float x = xyz[k * 3 + 0];
34 | float y = xyz[k * 3 + 1];
35 | float z = xyz[k * 3 + 2];
36 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
37 | (new_z - z) * (new_z - z);
38 | if (d2 < radius2) {
39 | if (cnt == 0) {
40 | for (int l = 0; l < nsample; ++l) {
41 | idx[j * nsample + l] = k;
42 | }
43 | }
44 | idx[j * nsample + cnt] = k;
45 | ++cnt;
46 | }
47 | }
48 | }
49 | }
50 |
51 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
52 | int nsample, const float *new_xyz,
53 | const float *xyz, int *idx) {
54 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
55 | query_ball_point_kernel<<>>(
56 | b, n, m, radius, nsample, new_xyz, xyz, idx);
57 |
58 | CUDA_CHECK_ERRORS();
59 | }
60 |
--------------------------------------------------------------------------------
/SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/src/bindings.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include "ball_query.h"
7 | #include "group_points.h"
8 | #include "interpolate.h"
9 | #include "sampling.h"
10 |
11 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
12 | m.def("gather_points", &gather_points);
13 | m.def("gather_points_grad", &gather_points_grad);
14 | m.def("furthest_point_sampling", &furthest_point_sampling);
15 |
16 | m.def("three_nn", &three_nn);
17 | m.def("three_interpolate", &three_interpolate);
18 | m.def("three_interpolate_grad", &three_interpolate_grad);
19 |
20 | m.def("ball_query", &ball_query);
21 |
22 | m.def("group_points", &group_points);
23 | m.def("group_points_grad", &group_points_grad);
24 | }
25 |
--------------------------------------------------------------------------------
/SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/src/group_points.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include "group_points.h"
7 | #include "utils.h"
8 |
9 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
10 | const float *points, const int *idx,
11 | float *out);
12 |
13 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
14 | int nsample, const float *grad_out,
15 | const int *idx, float *grad_points);
16 |
17 | at::Tensor group_points(at::Tensor points, at::Tensor idx) {
18 | CHECK_CONTIGUOUS(points);
19 | CHECK_CONTIGUOUS(idx);
20 | CHECK_IS_FLOAT(points);
21 | CHECK_IS_INT(idx);
22 |
23 | if (points.type().is_cuda()) {
24 | CHECK_CUDA(idx);
25 | }
26 |
27 | at::Tensor output =
28 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)},
29 | at::device(points.device()).dtype(at::ScalarType::Float));
30 |
31 | if (points.type().is_cuda()) {
32 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
33 | idx.size(1), idx.size(2), points.data(),
34 | idx.data(), output.data());
35 | } else {
36 | TORCH_CHECK(false, "CPU not supported");
37 | }
38 |
39 | return output;
40 | }
41 |
42 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) {
43 | CHECK_CONTIGUOUS(grad_out);
44 | CHECK_CONTIGUOUS(idx);
45 | CHECK_IS_FLOAT(grad_out);
46 | CHECK_IS_INT(idx);
47 |
48 | if (grad_out.type().is_cuda()) {
49 | CHECK_CUDA(idx);
50 | }
51 |
52 | at::Tensor output =
53 | torch::zeros({grad_out.size(0), grad_out.size(1), n},
54 | at::device(grad_out.device()).dtype(at::ScalarType::Float));
55 |
56 | if (grad_out.type().is_cuda()) {
57 | group_points_grad_kernel_wrapper(
58 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2),
59 | grad_out.data(), idx.data(), output.data());
60 | } else {
61 | TORCH_CHECK(false, "CPU not supported");
62 | }
63 |
64 | return output;
65 | }
66 |
--------------------------------------------------------------------------------
/SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/src/group_points_gpu.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include
7 | #include
8 |
9 | #include "cuda_utils.h"
10 |
11 | // input: points(b, c, n) idx(b, npoints, nsample)
12 | // output: out(b, c, npoints, nsample)
13 | __global__ void group_points_kernel(int b, int c, int n, int npoints,
14 | int nsample,
15 | const float *__restrict__ points,
16 | const int *__restrict__ idx,
17 | float *__restrict__ out) {
18 | int batch_index = blockIdx.x;
19 | points += batch_index * n * c;
20 | idx += batch_index * npoints * nsample;
21 | out += batch_index * npoints * nsample * c;
22 |
23 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
24 | const int stride = blockDim.y * blockDim.x;
25 | for (int i = index; i < c * npoints; i += stride) {
26 | const int l = i / npoints;
27 | const int j = i % npoints;
28 | for (int k = 0; k < nsample; ++k) {
29 | int ii = idx[j * nsample + k];
30 | out[(l * npoints + j) * nsample + k] = points[l * n + ii];
31 | }
32 | }
33 | }
34 |
35 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
36 | const float *points, const int *idx,
37 | float *out) {
38 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
39 |
40 | group_points_kernel<<>>(
41 | b, c, n, npoints, nsample, points, idx, out);
42 |
43 | CUDA_CHECK_ERRORS();
44 | }
45 |
46 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample)
47 | // output: grad_points(b, c, n)
48 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints,
49 | int nsample,
50 | const float *__restrict__ grad_out,
51 | const int *__restrict__ idx,
52 | float *__restrict__ grad_points) {
53 | int batch_index = blockIdx.x;
54 | grad_out += batch_index * npoints * nsample * c;
55 | idx += batch_index * npoints * nsample;
56 | grad_points += batch_index * n * c;
57 |
58 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
59 | const int stride = blockDim.y * blockDim.x;
60 | for (int i = index; i < c * npoints; i += stride) {
61 | const int l = i / npoints;
62 | const int j = i % npoints;
63 | for (int k = 0; k < nsample; ++k) {
64 | int ii = idx[j * nsample + k];
65 | atomicAdd(grad_points + l * n + ii,
66 | grad_out[(l * npoints + j) * nsample + k]);
67 | }
68 | }
69 | }
70 |
71 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
72 | int nsample, const float *grad_out,
73 | const int *idx, float *grad_points) {
74 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
75 |
76 | group_points_grad_kernel<<>>(
77 | b, c, n, npoints, nsample, grad_out, idx, grad_points);
78 |
79 | CUDA_CHECK_ERRORS();
80 | }
81 |
--------------------------------------------------------------------------------
/SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/src/interpolate.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include "interpolate.h"
7 | #include "utils.h"
8 |
9 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
10 | const float *known, float *dist2, int *idx);
11 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
12 | const float *points, const int *idx,
13 | const float *weight, float *out);
14 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
15 | const float *grad_out,
16 | const int *idx, const float *weight,
17 | float *grad_points);
18 |
19 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows) {
20 | CHECK_CONTIGUOUS(unknowns);
21 | CHECK_CONTIGUOUS(knows);
22 | CHECK_IS_FLOAT(unknowns);
23 | CHECK_IS_FLOAT(knows);
24 |
25 | if (unknowns.type().is_cuda()) {
26 | CHECK_CUDA(knows);
27 | }
28 |
29 | at::Tensor idx =
30 | torch::zeros({unknowns.size(0), unknowns.size(1), 3},
31 | at::device(unknowns.device()).dtype(at::ScalarType::Int));
32 | at::Tensor dist2 =
33 | torch::zeros({unknowns.size(0), unknowns.size(1), 3},
34 | at::device(unknowns.device()).dtype(at::ScalarType::Float));
35 |
36 | if (unknowns.type().is_cuda()) {
37 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1),
38 | unknowns.data(), knows.data(),
39 | dist2.data(), idx.data());
40 | } else {
41 | TORCH_CHECK(false, "CPU not supported");
42 | }
43 |
44 | return {dist2, idx};
45 | }
46 |
47 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx,
48 | at::Tensor weight) {
49 | CHECK_CONTIGUOUS(points);
50 | CHECK_CONTIGUOUS(idx);
51 | CHECK_CONTIGUOUS(weight);
52 | CHECK_IS_FLOAT(points);
53 | CHECK_IS_INT(idx);
54 | CHECK_IS_FLOAT(weight);
55 |
56 | if (points.type().is_cuda()) {
57 | CHECK_CUDA(idx);
58 | CHECK_CUDA(weight);
59 | }
60 |
61 | at::Tensor output =
62 | torch::zeros({points.size(0), points.size(1), idx.size(1)},
63 | at::device(points.device()).dtype(at::ScalarType::Float));
64 |
65 | if (points.type().is_cuda()) {
66 | three_interpolate_kernel_wrapper(
67 | points.size(0), points.size(1), points.size(2), idx.size(1),
68 | points.data(), idx.data(), weight.data(),
69 | output.data());
70 | } else {
71 | TORCH_CHECK(false, "CPU not supported");
72 | }
73 |
74 | return output;
75 | }
76 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
77 | at::Tensor weight, const int m) {
78 | CHECK_CONTIGUOUS(grad_out);
79 | CHECK_CONTIGUOUS(idx);
80 | CHECK_CONTIGUOUS(weight);
81 | CHECK_IS_FLOAT(grad_out);
82 | CHECK_IS_INT(idx);
83 | CHECK_IS_FLOAT(weight);
84 |
85 | if (grad_out.type().is_cuda()) {
86 | CHECK_CUDA(idx);
87 | CHECK_CUDA(weight);
88 | }
89 |
90 | at::Tensor output =
91 | torch::zeros({grad_out.size(0), grad_out.size(1), m},
92 | at::device(grad_out.device()).dtype(at::ScalarType::Float));
93 |
94 | if (grad_out.type().is_cuda()) {
95 | three_interpolate_grad_kernel_wrapper(
96 | grad_out.size(0), grad_out.size(1), grad_out.size(2), m,
97 | grad_out.data(), idx.data(), weight.data(),
98 | output.data());
99 | } else {
100 | TORCH_CHECK(false, "CPU not supported");
101 | }
102 |
103 | return output;
104 | }
105 |
--------------------------------------------------------------------------------
/SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/src/interpolate_gpu.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include
7 | #include
8 | #include
9 |
10 | #include "cuda_utils.h"
11 |
12 | // input: unknown(b, n, 3) known(b, m, 3)
13 | // output: dist2(b, n, 3), idx(b, n, 3)
14 | __global__ void three_nn_kernel(int b, int n, int m,
15 | const float *__restrict__ unknown,
16 | const float *__restrict__ known,
17 | float *__restrict__ dist2,
18 | int *__restrict__ idx) {
19 | int batch_index = blockIdx.x;
20 | unknown += batch_index * n * 3;
21 | known += batch_index * m * 3;
22 | dist2 += batch_index * n * 3;
23 | idx += batch_index * n * 3;
24 |
25 | int index = threadIdx.x;
26 | int stride = blockDim.x;
27 | for (int j = index; j < n; j += stride) {
28 | float ux = unknown[j * 3 + 0];
29 | float uy = unknown[j * 3 + 1];
30 | float uz = unknown[j * 3 + 2];
31 |
32 | double best1 = 1e40, best2 = 1e40, best3 = 1e40;
33 | int besti1 = 0, besti2 = 0, besti3 = 0;
34 | for (int k = 0; k < m; ++k) {
35 | float x = known[k * 3 + 0];
36 | float y = known[k * 3 + 1];
37 | float z = known[k * 3 + 2];
38 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
39 | if (d < best1) {
40 | best3 = best2;
41 | besti3 = besti2;
42 | best2 = best1;
43 | besti2 = besti1;
44 | best1 = d;
45 | besti1 = k;
46 | } else if (d < best2) {
47 | best3 = best2;
48 | besti3 = besti2;
49 | best2 = d;
50 | besti2 = k;
51 | } else if (d < best3) {
52 | best3 = d;
53 | besti3 = k;
54 | }
55 | }
56 | dist2[j * 3 + 0] = best1;
57 | dist2[j * 3 + 1] = best2;
58 | dist2[j * 3 + 2] = best3;
59 |
60 | idx[j * 3 + 0] = besti1;
61 | idx[j * 3 + 1] = besti2;
62 | idx[j * 3 + 2] = besti3;
63 | }
64 | }
65 |
66 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
67 | const float *known, float *dist2, int *idx) {
68 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
69 | three_nn_kernel<<>>(b, n, m, unknown, known,
70 | dist2, idx);
71 |
72 | CUDA_CHECK_ERRORS();
73 | }
74 |
75 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3)
76 | // output: out(b, c, n)
77 | __global__ void three_interpolate_kernel(int b, int c, int m, int n,
78 | const float *__restrict__ points,
79 | const int *__restrict__ idx,
80 | const float *__restrict__ weight,
81 | float *__restrict__ out) {
82 | int batch_index = blockIdx.x;
83 | points += batch_index * m * c;
84 |
85 | idx += batch_index * n * 3;
86 | weight += batch_index * n * 3;
87 |
88 | out += batch_index * n * c;
89 |
90 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
91 | const int stride = blockDim.y * blockDim.x;
92 | for (int i = index; i < c * n; i += stride) {
93 | const int l = i / n;
94 | const int j = i % n;
95 | float w1 = weight[j * 3 + 0];
96 | float w2 = weight[j * 3 + 1];
97 | float w3 = weight[j * 3 + 2];
98 |
99 | int i1 = idx[j * 3 + 0];
100 | int i2 = idx[j * 3 + 1];
101 | int i3 = idx[j * 3 + 2];
102 |
103 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 +
104 | points[l * m + i3] * w3;
105 | }
106 | }
107 |
108 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
109 | const float *points, const int *idx,
110 | const float *weight, float *out) {
111 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
112 | three_interpolate_kernel<<>>(
113 | b, c, m, n, points, idx, weight, out);
114 |
115 | CUDA_CHECK_ERRORS();
116 | }
117 |
118 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3)
119 | // output: grad_points(b, c, m)
120 |
121 | __global__ void three_interpolate_grad_kernel(
122 | int b, int c, int n, int m, const float *__restrict__ grad_out,
123 | const int *__restrict__ idx, const float *__restrict__ weight,
124 | float *__restrict__ grad_points) {
125 | int batch_index = blockIdx.x;
126 | grad_out += batch_index * n * c;
127 | idx += batch_index * n * 3;
128 | weight += batch_index * n * 3;
129 | grad_points += batch_index * m * c;
130 |
131 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
132 | const int stride = blockDim.y * blockDim.x;
133 | for (int i = index; i < c * n; i += stride) {
134 | const int l = i / n;
135 | const int j = i % n;
136 | float w1 = weight[j * 3 + 0];
137 | float w2 = weight[j * 3 + 1];
138 | float w3 = weight[j * 3 + 2];
139 |
140 | int i1 = idx[j * 3 + 0];
141 | int i2 = idx[j * 3 + 1];
142 | int i3 = idx[j * 3 + 2];
143 |
144 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1);
145 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2);
146 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3);
147 | }
148 | }
149 |
150 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
151 | const float *grad_out,
152 | const int *idx, const float *weight,
153 | float *grad_points) {
154 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
155 | three_interpolate_grad_kernel<<>>(
156 | b, c, n, m, grad_out, idx, weight, grad_points);
157 |
158 | CUDA_CHECK_ERRORS();
159 | }
160 |
--------------------------------------------------------------------------------
/SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/src/sampling.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include "sampling.h"
7 | #include "utils.h"
8 |
9 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
10 | const float *points, const int *idx,
11 | float *out);
12 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
13 | const float *grad_out, const int *idx,
14 | float *grad_points);
15 |
16 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
17 | const float *dataset, float *temp,
18 | int *idxs);
19 |
20 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) {
21 | CHECK_CONTIGUOUS(points);
22 | CHECK_CONTIGUOUS(idx);
23 | CHECK_IS_FLOAT(points);
24 | CHECK_IS_INT(idx);
25 |
26 | if (points.type().is_cuda()) {
27 | CHECK_CUDA(idx);
28 | }
29 |
30 | at::Tensor output =
31 | torch::zeros({points.size(0), points.size(1), idx.size(1)},
32 | at::device(points.device()).dtype(at::ScalarType::Float));
33 |
34 | if (points.type().is_cuda()) {
35 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
36 | idx.size(1), points.data(),
37 | idx.data(), output.data());
38 | } else {
39 | TORCH_CHECK(false, "CPU not supported");
40 | }
41 |
42 | return output;
43 | }
44 |
45 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx,
46 | const int n) {
47 | CHECK_CONTIGUOUS(grad_out);
48 | CHECK_CONTIGUOUS(idx);
49 | CHECK_IS_FLOAT(grad_out);
50 | CHECK_IS_INT(idx);
51 |
52 | if (grad_out.type().is_cuda()) {
53 | CHECK_CUDA(idx);
54 | }
55 |
56 | at::Tensor output =
57 | torch::zeros({grad_out.size(0), grad_out.size(1), n},
58 | at::device(grad_out.device()).dtype(at::ScalarType::Float));
59 |
60 | if (grad_out.type().is_cuda()) {
61 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n,
62 | idx.size(1), grad_out.data(),
63 | idx.data(), output.data());
64 | } else {
65 | TORCH_CHECK(false, "CPU not supported");
66 | }
67 |
68 | return output;
69 | }
70 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) {
71 | CHECK_CONTIGUOUS(points);
72 | CHECK_IS_FLOAT(points);
73 |
74 | at::Tensor output =
75 | torch::zeros({points.size(0), nsamples},
76 | at::device(points.device()).dtype(at::ScalarType::Int));
77 |
78 | at::Tensor tmp =
79 | torch::full({points.size(0), points.size(1)}, 1e10,
80 | at::device(points.device()).dtype(at::ScalarType::Float));
81 |
82 | if (points.type().is_cuda()) {
83 | furthest_point_sampling_kernel_wrapper(
84 | points.size(0), points.size(1), nsamples, points.data(),
85 | tmp.data(), output.data());
86 | } else {
87 | TORCH_CHECK(false, "CPU not supported");
88 | }
89 |
90 | return output;
91 | }
92 |
--------------------------------------------------------------------------------
/SAM-6D/Pose_Estimation_Model/model/pointnet2/pointnet2_test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | ''' Testing customized ops. '''
7 |
8 | import torch
9 | from torch.autograd import gradcheck
10 | import numpy as np
11 |
12 | import os
13 | import sys
14 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
15 | sys.path.append(BASE_DIR)
16 | import pointnet2_utils
17 |
18 | def test_interpolation_grad():
19 | batch_size = 1
20 | feat_dim = 2
21 | m = 4
22 | feats = torch.randn(batch_size, feat_dim, m, requires_grad=True).float().cuda()
23 |
24 | def interpolate_func(inputs):
25 | idx = torch.from_numpy(np.array([[[0,1,2],[1,2,3]]])).int().cuda()
26 | weight = torch.from_numpy(np.array([[[1,1,1],[2,2,2]]])).float().cuda()
27 | interpolated_feats = pointnet2_utils.three_interpolate(inputs, idx, weight)
28 | return interpolated_feats
29 |
30 | assert (gradcheck(interpolate_func, feats, atol=1e-1, rtol=1e-1))
31 |
32 | if __name__=='__main__':
33 | test_interpolation_grad()
34 |
--------------------------------------------------------------------------------
/SAM-6D/Pose_Estimation_Model/model/pointnet2/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | import os
6 | from setuptools import setup, find_packages
7 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
8 | import glob
9 |
10 | _ext_src_root = "_ext_src"
11 | _ext_sources = glob.glob("{}/src/*.cpp".format(_ext_src_root)) + glob.glob(
12 | "{}/src/*.cu".format(_ext_src_root)
13 | )
14 | _ext_headers = glob.glob("{}/include/*".format(_ext_src_root))
15 |
16 | setup(
17 | name='pointnet2',
18 | packages = find_packages(),
19 | ext_modules=[
20 | CUDAExtension(
21 | name='pointnet2._ext',
22 | sources=_ext_sources,
23 | include_dirs = [os.path.join(_ext_src_root, "include")],
24 | extra_compile_args={
25 | # "cxx": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))],
26 | # "nvcc": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))],
27 | "cxx": [],
28 | "nvcc": ["-O3",
29 | "-DCUDA_HAS_FP16=1",
30 | "-D__CUDA_NO_HALF_OPERATORS__",
31 | "-D__CUDA_NO_HALF_CONVERSIONS__",
32 | "-D__CUDA_NO_HALF2_OPERATORS__",
33 | ]},)
34 | ],
35 | cmdclass={'build_ext': BuildExtension.with_options(use_ninja=True)}
36 | )
37 |
--------------------------------------------------------------------------------
/SAM-6D/Pose_Estimation_Model/model/pose_estimation_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from feature_extraction import ViTEncoder
5 | from coarse_point_matching import CoarsePointMatching
6 | from fine_point_matching import FinePointMatching
7 | from transformer import GeometricStructureEmbedding
8 | from model_utils import sample_pts_feats
9 |
10 |
11 | class Net(nn.Module):
12 | def __init__(self, cfg):
13 | super(Net, self).__init__()
14 | self.cfg = cfg
15 | self.coarse_npoint = cfg.coarse_npoint
16 | self.fine_npoint = cfg.fine_npoint
17 |
18 | self.feature_extraction = ViTEncoder(cfg.feature_extraction, self.fine_npoint)
19 | self.geo_embedding = GeometricStructureEmbedding(cfg.geo_embedding)
20 | self.coarse_point_matching = CoarsePointMatching(cfg.coarse_point_matching)
21 | self.fine_point_matching = FinePointMatching(cfg.fine_point_matching)
22 |
23 | def forward(self, end_points):
24 | dense_pm, dense_fm, dense_po, dense_fo, radius = self.feature_extraction(end_points)
25 |
26 | # pre-compute geometric embeddings for geometric transformer
27 | bg_point = torch.ones(dense_pm.size(0),1,3).float().to(dense_pm.device) * 100
28 |
29 | sparse_pm, sparse_fm, fps_idx_m = sample_pts_feats(
30 | dense_pm, dense_fm, self.coarse_npoint, return_index=True
31 | )
32 | geo_embedding_m = self.geo_embedding(torch.cat([bg_point, sparse_pm], dim=1))
33 |
34 | sparse_po, sparse_fo, fps_idx_o = sample_pts_feats(
35 | dense_po, dense_fo, self.coarse_npoint, return_index=True
36 | )
37 | geo_embedding_o = self.geo_embedding(torch.cat([bg_point, sparse_po], dim=1))
38 |
39 | # coarse_point_matching
40 | end_points = self.coarse_point_matching(
41 | sparse_pm, sparse_fm, geo_embedding_m,
42 | sparse_po, sparse_fo, geo_embedding_o,
43 | radius, end_points,
44 | )
45 |
46 | # fine_point_matching
47 | end_points = self.fine_point_matching(
48 | dense_pm, dense_fm, geo_embedding_m, fps_idx_m,
49 | dense_po, dense_fo, geo_embedding_o, fps_idx_o,
50 | radius, end_points
51 | )
52 |
53 | return end_points
54 |
55 |
--------------------------------------------------------------------------------
/SAM-6D/Pose_Estimation_Model/train.py:
--------------------------------------------------------------------------------
1 |
2 | import gorilla
3 | from tqdm import tqdm
4 | import argparse
5 | import os
6 | import sys
7 | import os.path as osp
8 | import time
9 | import logging
10 | import numpy as np
11 | import random
12 | import importlib
13 |
14 | import torch
15 | from torch.autograd import Variable
16 | import torch.optim as optim
17 |
18 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
19 | sys.path.append(os.path.join(BASE_DIR, 'provider'))
20 | sys.path.append(os.path.join(BASE_DIR, 'utils'))
21 | sys.path.append(os.path.join(BASE_DIR, 'model'))
22 | sys.path.append(os.path.join(BASE_DIR, 'model', 'pointnet2'))
23 |
24 | from solver import Solver, get_logger
25 | from loss_utils import Loss
26 |
27 | def get_parser():
28 | parser = argparse.ArgumentParser(
29 | description="Pose Estimation")
30 |
31 | parser.add_argument("--gpus",
32 | type=str,
33 | default="0",
34 | help="index of gpu")
35 | parser.add_argument("--model",
36 | type=str,
37 | default="pose_estimation_model",
38 | help="name of model")
39 | parser.add_argument("--config",
40 | type=str,
41 | default="config/base.yaml",
42 | help="path to config file")
43 | parser.add_argument("--exp_id",
44 | type=int,
45 | default=0,
46 | help="experiment id")
47 | parser.add_argument("--checkpoint_iter",
48 | type=int,
49 | default=-1,
50 | help="iter num. of checkpoint")
51 | args_cfg = parser.parse_args()
52 |
53 | return args_cfg
54 |
55 |
56 | def init():
57 | args = get_parser()
58 | exp_name = args.model + '_' + \
59 | osp.splitext(args.config.split("/")[-1])[0] + '_id' + str(args.exp_id)
60 | log_dir = osp.join("log", exp_name)
61 |
62 | cfg = gorilla.Config.fromfile(args.config)
63 | cfg.exp_name = exp_name
64 | cfg.gpus = args.gpus
65 | cfg.model_name = args.model
66 | cfg.log_dir = log_dir
67 | cfg.checkpoint_iter = args.checkpoint_iter
68 |
69 | if not os.path.isdir(log_dir):
70 | os.makedirs(log_dir)
71 | logger = get_logger(
72 | level_print=logging.INFO, level_save=logging.WARNING, path_file=log_dir+"/training_logger.log")
73 | gorilla.utils.set_cuda_visible_devices(gpu_ids=cfg.gpus)
74 |
75 | return logger, cfg
76 |
77 |
78 | if __name__ == "__main__":
79 | logger, cfg = init()
80 |
81 | logger.warning(
82 | "************************ Start Logging ************************")
83 | logger.info(cfg)
84 | logger.info("using gpu: {}".format(cfg.gpus))
85 |
86 | random.seed(cfg.rd_seed)
87 | torch.manual_seed(cfg.rd_seed)
88 |
89 | # model
90 | logger.info("=> creating model ...")
91 | MODEL = importlib.import_module(cfg.model_name)
92 | model = MODEL.Net(cfg.model)
93 | if hasattr(cfg, 'pretrain_dir') and cfg.pretrain_dir is not None:
94 | logger.info('loading pretrained backbone from {}'.format(cfg.pretrain_dir))
95 | key1, key2 = model.load_state_dict(torch.load(cfg.pretrain_dir)['model'], strict=False)
96 | if len(cfg.gpus) > 1:
97 | model = torch.nn.DataParallel(model, range(len(cfg.gpus.split(","))))
98 | model = model.cuda()
99 |
100 | loss = Loss().cuda()
101 | count_parameters = sum(gorilla.parameter_count(model).values())
102 | logger.warning("#Total parameters : {}".format(count_parameters))
103 |
104 | # dataloader
105 | batchsize = cfg.train_dataloader.bs
106 | num_epoch = cfg.training_epoch
107 |
108 | if cfg.lr_scheduler.type == 'WarmupCosineLR':
109 | num_iter = cfg.lr_scheduler.max_iters
110 | if hasattr(cfg, 'warmup_iter') and cfg.warmup_iter >0:
111 | num_iter = num_iter + cfg.warmup_iter
112 | iters_per_epoch = int(np.floor(num_iter / num_epoch))
113 | elif cfg.lr_scheduler.type == 'CyclicLR':
114 | iters_per_epoch = cfg.lr_scheduler.step_size_up+cfg.lr_scheduler.step_size_down
115 | train_dataset = importlib.import_module(cfg.train_dataset.name)
116 | train_dataset = train_dataset.Dataset(cfg.train_dataset, iters_per_epoch*batchsize)
117 |
118 |
119 | train_dataloader = torch.utils.data.DataLoader(
120 | train_dataset,
121 | batch_size=cfg.train_dataloader.bs,
122 | num_workers=cfg.train_dataloader.num_workers,
123 | shuffle=cfg.train_dataloader.shuffle,
124 | sampler=None,
125 | drop_last=cfg.train_dataloader.drop_last,
126 | pin_memory=cfg.train_dataloader.pin_memory,
127 | )
128 |
129 | dataloaders = {
130 | "train": train_dataloader,
131 | }
132 |
133 | # solver
134 | Trainer = Solver(model=model, loss=loss,
135 | dataloaders=dataloaders,
136 | logger=logger,
137 | cfg=cfg)
138 | Trainer.solve()
139 |
140 | logger.info('\nFinish!\n')
141 |
--------------------------------------------------------------------------------
/SAM-6D/Pose_Estimation_Model/utils/bop_object_utils.py:
--------------------------------------------------------------------------------
1 |
2 | '''
3 | Modified from https://github.com/rasmushaugaard/surfemb/blob/master/surfemb/data/obj.py
4 | '''
5 |
6 | import os
7 | import glob
8 | import json
9 | import numpy as np
10 | import trimesh
11 | from tqdm import tqdm
12 |
13 | from data_utils import (
14 | load_im,
15 | )
16 |
17 | class Obj:
18 | def __init__(
19 | self, obj_id,
20 | mesh: trimesh.Trimesh,
21 | model_points,
22 | diameter: float,
23 | symmetry_flag: int,
24 | template_path: str,
25 | n_template_view: int,
26 | ):
27 | self.obj_id = obj_id
28 | self.mesh = mesh
29 | self.model_points = model_points
30 | self.diameter = diameter
31 | self.symmetry_flag = symmetry_flag
32 | self._get_template(template_path, n_template_view)
33 |
34 | def get_item(self, return_color=False, sample_num=2048):
35 | if return_color:
36 | model_points, _, model_colors = trimesh.sample.sample_surface(self.mesh, sample_num, sample_color=True)
37 | model_points = model_points.astype(np.float32) / 1000.0
38 | return (model_points, model_colors, self.symmetry_flag)
39 | else:
40 | return (self.model_points, self.symmetry_flag)
41 |
42 | def _get_template(self, path, nView):
43 | if nView > 0:
44 | total_nView = len(glob.glob(os.path.join(path, 'rgb_*.png')))
45 |
46 | self.template = []
47 | self.template_mask = []
48 | self.template_pts = []
49 |
50 | for v in range(nView):
51 | i = int(total_nView / nView * v)
52 | rgb_path = os.path.join(path, 'rgb_'+str(i)+'.png')
53 | xyz_path = os.path.join(path, 'xyz_'+str(i)+'.npy')
54 | mask_path = os.path.join(path, 'mask_'+str(i)+'.png')
55 |
56 | rgb = load_im(rgb_path).astype(np.uint8)
57 | xyz = np.load(xyz_path).astype(np.float32) / 1000.0
58 | mask = load_im(mask_path).astype(np.uint8) == 255
59 |
60 | self.template.append(rgb)
61 | self.template_mask.append(mask)
62 | self.template_pts.append(xyz)
63 | else:
64 | self.template = None
65 | self.template_choose = None
66 | self.template_pts = None
67 |
68 | def get_template(self, view_idx):
69 | return self.template[view_idx], self.template_mask[view_idx], self.template_pts[view_idx]
70 |
71 |
72 | def load_obj(
73 | model_path, obj_id: int, sample_num: int,
74 | template_path: str,
75 | n_template_view: int,
76 | ):
77 | models_info = json.load(open(os.path.join(model_path, 'models_info.json')))
78 | mesh = trimesh.load_mesh(os.path.join(model_path, f'obj_{obj_id:06d}.ply'))
79 | model_points = mesh.sample(sample_num).astype(np.float32) / 1000.0
80 | diameter = models_info[str(obj_id)]['diameter'] / 1000.0
81 | if 'symmetries_continuous' in models_info[str(obj_id)]:
82 | symmetry_flag = 1
83 | elif 'symmetries_discrete' in models_info[str(obj_id)]:
84 | symmetry_flag = 1
85 | else:
86 | symmetry_flag = 0
87 | return Obj(
88 | obj_id, mesh, model_points, diameter, symmetry_flag,
89 | template_path, n_template_view
90 | )
91 |
92 |
93 | def load_objs(
94 | model_path='models',
95 | template_path='render_imgs',
96 | sample_num=512,
97 | n_template_view=0,
98 | show_progressbar=True
99 | ):
100 | objs = []
101 | obj_ids = sorted([int(p.split('/')[-1][4:10]) for p in glob.glob(os.path.join(model_path, '*.ply'))])
102 |
103 | if n_template_view>0:
104 | template_paths = sorted(glob.glob(os.path.join(template_path, '*')))
105 | assert len(template_paths) == len(obj_ids), '{} template_paths, {} obj_ids'.format(len(template_paths), len(obj_ids))
106 | else:
107 | template_paths = [None for _ in range(len(obj_ids))]
108 |
109 | cnt = 0
110 | for obj_id in tqdm(obj_ids, 'loading objects') if show_progressbar else obj_ids:
111 | objs.append(
112 | load_obj(model_path, obj_id, sample_num,
113 | template_paths[cnt], n_template_view)
114 | )
115 | cnt+=1
116 | return objs, obj_ids
117 |
118 |
--------------------------------------------------------------------------------
/SAM-6D/Pose_Estimation_Model/utils/draw_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import cv2
4 |
5 | def calculate_2d_projections(coordinates_3d, intrinsics):
6 | """
7 | Input:
8 | coordinates: [3, N]
9 | intrinsics: [3, 3]
10 | Return
11 | projected_coordinates: [N, 2]
12 | """
13 | projected_coordinates = intrinsics @ coordinates_3d
14 | projected_coordinates = projected_coordinates[:2, :] / projected_coordinates[2, :]
15 | projected_coordinates = projected_coordinates.transpose()
16 | projected_coordinates = np.array(projected_coordinates, dtype=np.int32)
17 |
18 | return projected_coordinates
19 |
20 | def get_3d_bbox(scale, shift = 0):
21 | """
22 | Input:
23 | scale: [3] or scalar
24 | shift: [3] or scalar
25 | Return
26 | bbox_3d: [3, N]
27 |
28 | """
29 | if hasattr(scale, "__iter__"):
30 | bbox_3d = np.array([[scale[0] / 2, +scale[1] / 2, scale[2] / 2],
31 | [scale[0] / 2, +scale[1] / 2, -scale[2] / 2],
32 | [-scale[0] / 2, +scale[1] / 2, scale[2] / 2],
33 | [-scale[0] / 2, +scale[1] / 2, -scale[2] / 2],
34 | [+scale[0] / 2, -scale[1] / 2, scale[2] / 2],
35 | [+scale[0] / 2, -scale[1] / 2, -scale[2] / 2],
36 | [-scale[0] / 2, -scale[1] / 2, scale[2] / 2],
37 | [-scale[0] / 2, -scale[1] / 2, -scale[2] / 2]]) + shift
38 | else:
39 | bbox_3d = np.array([[scale / 2, +scale / 2, scale / 2],
40 | [scale / 2, +scale / 2, -scale / 2],
41 | [-scale / 2, +scale / 2, scale / 2],
42 | [-scale / 2, +scale / 2, -scale / 2],
43 | [+scale / 2, -scale / 2, scale / 2],
44 | [+scale / 2, -scale / 2, -scale / 2],
45 | [-scale / 2, -scale / 2, scale / 2],
46 | [-scale / 2, -scale / 2, -scale / 2]]) +shift
47 |
48 | bbox_3d = bbox_3d.transpose()
49 | return bbox_3d
50 |
51 | def draw_3d_bbox(img, imgpts, color, size=3):
52 | imgpts = np.int32(imgpts).reshape(-1, 2)
53 |
54 | # draw ground layer in darker color
55 | color_ground = (int(color[0] * 0.3), int(color[1] * 0.3), int(color[2] * 0.3))
56 | for i, j in zip([4, 5, 6, 7],[5, 7, 4, 6]):
57 | img = cv2.line(img, tuple(imgpts[i]), tuple(imgpts[j]), color_ground, size)
58 |
59 | # draw pillars in blue color
60 | color_pillar = (int(color[0]*0.6), int(color[1]*0.6), int(color[2]*0.6))
61 | for i, j in zip(range(4),range(4,8)):
62 | img = cv2.line(img, tuple(imgpts[i]), tuple(imgpts[j]), color_pillar, size)
63 |
64 | # finally, draw top layer in color
65 | for i, j in zip([0, 1, 2, 3],[1, 3, 0, 2]):
66 | img = cv2.line(img, tuple(imgpts[i]), tuple(imgpts[j]), color, size)
67 | return img
68 |
69 | def draw_3d_pts(img, imgpts, color, size=1):
70 | imgpts = np.int32(imgpts).reshape(-1, 2)
71 | for point in imgpts:
72 | img = cv2.circle(img, (point[0], point[1]), size, color, -1)
73 | return img
74 |
75 | def draw_detections(image, pred_rots, pred_trans, model_points, intrinsics, color=(255, 0, 0)):
76 | num_pred_instances = len(pred_rots)
77 | draw_image_bbox = image.copy()
78 | # 3d bbox
79 | scale = (np.max(model_points, axis=0) - np.min(model_points, axis=0))
80 | shift = np.mean(model_points, axis=0)
81 | bbox_3d = get_3d_bbox(scale, shift)
82 |
83 | # 3d point
84 | choose = np.random.choice(np.arange(len(model_points)), 512)
85 | pts_3d = model_points[choose].T
86 |
87 | for ind in range(num_pred_instances):
88 | # draw 3d bounding box
89 | transformed_bbox_3d = pred_rots[ind]@bbox_3d + pred_trans[ind][:,np.newaxis]
90 | projected_bbox = calculate_2d_projections(transformed_bbox_3d, intrinsics[ind])
91 | draw_image_bbox = draw_3d_bbox(draw_image_bbox, projected_bbox, color)
92 | # draw point cloud
93 | transformed_pts_3d = pred_rots[ind]@pts_3d + pred_trans[ind][:,np.newaxis]
94 | projected_pts = calculate_2d_projections(transformed_pts_3d, intrinsics[ind])
95 | draw_image_bbox = draw_3d_pts(draw_image_bbox, projected_pts, color)
96 |
97 | return draw_image_bbox
98 |
--------------------------------------------------------------------------------
/SAM-6D/Pose_Estimation_Model/utils/loss_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 |
5 | from model_utils import pairwise_distance
6 |
7 | def compute_correspondence_loss(
8 | end_points,
9 | atten_list,
10 | pts1,
11 | pts2,
12 | gt_r,
13 | gt_t,
14 | dis_thres=0.15,
15 | loss_str='coarse'
16 | ):
17 | CE = nn.CrossEntropyLoss(reduction ='none')
18 |
19 | gt_pts = (pts1-gt_t.unsqueeze(1))@gt_r
20 | dis_mat = torch.sqrt(pairwise_distance(gt_pts, pts2))
21 |
22 | dis1, label1 = dis_mat.min(2)
23 | fg_label1 = (dis1<=dis_thres).float()
24 | label1 = (fg_label1 * (label1.float()+1.0)).long()
25 |
26 | dis2, label2 = dis_mat.min(1)
27 | fg_label2 = (dis2<=dis_thres).float()
28 | label2 = (fg_label2 * (label2.float()+1.0)).long()
29 |
30 | # loss
31 | for idx, atten in enumerate(atten_list):
32 | l1 = CE(atten.transpose(1,2)[:,:,1:].contiguous(), label1).mean(1)
33 | l2 = CE(atten[:,:,1:].contiguous(), label2).mean(1)
34 | end_points[loss_str + '_loss' + str(idx)] = 0.5 * (l1 + l2)
35 |
36 | # acc
37 | pred_label = torch.max(atten_list[-1][:,1:,:], dim=2)[1]
38 | end_points[loss_str + '_acc'] = (pred_label==label1).float().mean(1)
39 |
40 | # pred foreground num
41 | fg_mask = (pred_label > 0).float()
42 | end_points[loss_str + '_fg_num'] = fg_mask.sum(1)
43 |
44 | # foreground point dis
45 | fg_label = fg_mask * (pred_label - 1)
46 | fg_label = fg_label.long()
47 | pred_pts = torch.gather(pts2, 1, fg_label.unsqueeze(2).repeat(1,1,3))
48 | pred_dis = torch.norm(pred_pts-gt_pts, dim=2)
49 | pred_dis = (pred_dis * fg_mask).sum(1) / (fg_mask.sum(1)+1e-8)
50 | end_points[loss_str + '_dis'] = pred_dis
51 |
52 | return end_points
53 |
54 |
55 |
56 | class Loss(nn.Module):
57 | def __init__(self):
58 | super(Loss, self).__init__()
59 |
60 | def forward(self, end_points):
61 | out_dicts = {'loss': 0}
62 | for key in end_points.keys():
63 | if 'coarse_' in key or 'fine_' in key:
64 | out_dicts[key] = end_points[key].mean()
65 | if 'loss' in key:
66 | out_dicts['loss'] = out_dicts['loss'] + end_points[key]
67 | out_dicts['loss'] = torch.clamp(out_dicts['loss'], max=100.0).mean()
68 | return out_dicts
69 |
70 |
--------------------------------------------------------------------------------
/SAM-6D/Render/render_bop_templates.py:
--------------------------------------------------------------------------------
1 | import blenderproc as bproc
2 |
3 | import os
4 | import argparse
5 | import json
6 | import cv2
7 | import numpy as np
8 |
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--dataset_name', help="The name of bop datasets")
11 | args = parser.parse_args()
12 |
13 | # set relative path of Data folder
14 | render_dir = os.path.dirname(os.path.abspath(__file__))
15 | bop_path = os.path.join(render_dir, '../Data/BOP')
16 | output_dir = os.path.join(render_dir, '../Data/BOP-Templates')
17 | cnos_cam_fpath = os.path.join(render_dir, '../Instance_Segmentation_Model/utils/poses/predefined_poses/cam_poses_level0.npy')
18 |
19 | bproc.init()
20 |
21 | if args.dataset_name == 'tless':
22 | model_folder = 'models_cad'
23 | else:
24 | model_folder = 'models'
25 |
26 | model_path = os.path.join(bop_path, args.dataset_name, model_folder)
27 | models_info = json.load(open(os.path.join(model_path, 'models_info.json')))
28 | for obj_id in models_info.keys():
29 | diameter = models_info[obj_id]['diameter']
30 | scale = 1 / diameter
31 | obj_fpath = os.path.join(model_path, f'obj_{int(obj_id):06d}.ply')
32 |
33 | cam_poses = np.load(cnos_cam_fpath)
34 | for idx, cam_pose in enumerate(cam_poses[:]):
35 |
36 | bproc.clean_up()
37 |
38 | # load object
39 | obj = bproc.loader.load_obj(obj_fpath)[0]
40 | obj.set_scale([scale, scale, scale])
41 | obj.set_cp("category_id", obj_id)
42 |
43 | if args.dataset_name == 'tless':
44 | color = [0.4, 0.4, 0.4, 0.]
45 | material = bproc.material.create('obj')
46 | material.set_principled_shader_value('Base Color', color)
47 | obj.set_material(0, material)
48 |
49 | # convert cnos camera poses to blender camera poses
50 | cam_pose[:3, 1:3] = -cam_pose[:3, 1:3]
51 | cam_pose[:3, -1] = cam_pose[:3, -1] * 0.001 * 2
52 | bproc.camera.add_camera_pose(cam_pose)
53 |
54 | # set light
55 | light_energy = 1000
56 | light_scale = 2.5
57 | light1 = bproc.types.Light()
58 | light1.set_type("POINT")
59 | light1.set_location([light_scale*cam_pose[:3, -1][0], light_scale*cam_pose[:3, -1][1], light_scale*cam_pose[:3, -1][2]])
60 | light1.set_energy(light_energy)
61 |
62 | bproc.renderer.set_max_amount_of_samples(1)
63 | # render the whole pipeline
64 | data = bproc.renderer.render()
65 | # render nocs
66 | data.update(bproc.renderer.render_nocs())
67 |
68 | # check save folder
69 | save_fpath = os.path.join(output_dir, args.dataset_name, f'obj_{int(obj_id):06d}')
70 | if not os.path.exists(save_fpath):
71 | os.makedirs(save_fpath)
72 |
73 | # save rgb image
74 | color_bgr_0 = data["colors"][0]
75 | color_bgr_0[..., :3] = color_bgr_0[..., :3][..., ::-1]
76 | cv2.imwrite(os.path.join(save_fpath,'rgb_'+str(idx)+'.png'), color_bgr_0)
77 |
78 | # save mask
79 | mask_0 = data["nocs"][0][..., -1]
80 | cv2.imwrite(os.path.join(save_fpath,'mask_'+str(idx)+'.png'), mask_0*255)
81 |
82 | # save nocs
83 | xyz_0 = 2*(data["nocs"][0][..., :3] - 0.5)
84 | np.save(os.path.join(save_fpath,'xyz_'+str(idx)+'.npy'), xyz_0.astype(np.float16))
85 |
--------------------------------------------------------------------------------
/SAM-6D/Render/render_custom_templates.py:
--------------------------------------------------------------------------------
1 | import blenderproc as bproc
2 |
3 | import os
4 | import argparse
5 | import cv2
6 | import numpy as np
7 | import trimesh
8 |
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--cad_path', help="The path of CAD model")
11 | parser.add_argument('--output_dir', help="The path to save CAD templates")
12 | parser.add_argument('--normalize', default=True, help="Whether to normalize CAD model or not")
13 | parser.add_argument('--colorize', default=False, help="Whether to colorize CAD model or not")
14 | parser.add_argument('--base_color', default=0.05, help="The base color used in CAD model")
15 | args = parser.parse_args()
16 |
17 | # set the cnos camera path
18 | render_dir = os.path.dirname(os.path.abspath(__file__))
19 | cnos_cam_fpath = os.path.join(render_dir, '../Instance_Segmentation_Model/utils/poses/predefined_poses/cam_poses_level0.npy')
20 |
21 | bproc.init()
22 |
23 | def get_norm_info(mesh_path):
24 | mesh = trimesh.load(mesh_path, force='mesh')
25 |
26 | model_points = trimesh.sample.sample_surface(mesh, 1024)[0]
27 | model_points = model_points.astype(np.float32)
28 |
29 | min_value = np.min(model_points, axis=0)
30 | max_value = np.max(model_points, axis=0)
31 |
32 | radius = max(np.linalg.norm(max_value), np.linalg.norm(min_value))
33 |
34 | return 1/(2*radius)
35 |
36 |
37 | # load cnos camera pose
38 | cam_poses = np.load(cnos_cam_fpath)
39 |
40 | # calculating the scale of CAD model
41 | if args.normalize:
42 | scale = get_norm_info(args.cad_path)
43 | else:
44 | scale = 1
45 |
46 | for idx, cam_pose in enumerate(cam_poses):
47 |
48 | bproc.clean_up()
49 |
50 | # load object
51 | obj = bproc.loader.load_obj(args.cad_path)[0]
52 | obj.set_scale([scale, scale, scale])
53 | obj.set_cp("category_id", 1)
54 |
55 | # assigning material colors to untextured objects
56 | if args.colorize:
57 | color = [args.base_color, args.base_color, args.base_color, 0.]
58 | material = bproc.material.create('obj')
59 | material.set_principled_shader_value('Base Color', color)
60 | obj.set_material(0, material)
61 |
62 | # convert cnos camera poses to blender camera poses
63 | cam_pose[:3, 1:3] = -cam_pose[:3, 1:3]
64 | cam_pose[:3, -1] = cam_pose[:3, -1] * 0.001 * 2
65 | bproc.camera.add_camera_pose(cam_pose)
66 |
67 | # set light
68 | light_scale = 2.5
69 | light_energy = 1000
70 | light1 = bproc.types.Light()
71 | light1.set_type("POINT")
72 | light1.set_location([light_scale*cam_pose[:3, -1][0], light_scale*cam_pose[:3, -1][1], light_scale*cam_pose[:3, -1][2]])
73 | light1.set_energy(light_energy)
74 |
75 | bproc.renderer.set_max_amount_of_samples(50)
76 | # render the whole pipeline
77 | data = bproc.renderer.render()
78 | # render nocs
79 | data.update(bproc.renderer.render_nocs())
80 |
81 | # check save folder
82 | save_fpath = os.path.join(args.output_dir, "templates")
83 | if not os.path.exists(save_fpath):
84 | os.makedirs(save_fpath)
85 |
86 | # save rgb image
87 | color_bgr_0 = data["colors"][0]
88 | color_bgr_0[..., :3] = color_bgr_0[..., :3][..., ::-1]
89 | cv2.imwrite(os.path.join(save_fpath,'rgb_'+str(idx)+'.png'), color_bgr_0)
90 |
91 | # save mask
92 | mask_0 = data["nocs"][0][..., -1]
93 | cv2.imwrite(os.path.join(save_fpath,'mask_'+str(idx)+'.png'), mask_0*255)
94 |
95 | # save nocs
96 | xyz_0 = 2*(data["nocs"][0][..., :3] - 0.5)
97 | np.save(os.path.join(save_fpath,'xyz_'+str(idx)+'.npy'), xyz_0.astype(np.float16))
--------------------------------------------------------------------------------
/SAM-6D/Render/render_gso_templates.py:
--------------------------------------------------------------------------------
1 | import blenderproc as bproc
2 |
3 | import os
4 | import cv2
5 | import numpy as np
6 | import trimesh
7 |
8 | # set relative path of Data folder
9 | render_dir = os.path.dirname(os.path.abspath(__file__))
10 | gso_path = os.path.join(render_dir, '../Data/MegaPose-Training-Data/MegaPose-GSO/google_scanned_objects')
11 | gso_norm_path = os.path.join(gso_path, 'models_normalized')
12 | output_dir = os.path.join(render_dir, '../Data/MegaPose-Training-Data/MegaPose-GSO/templates')
13 |
14 | bproc.init()
15 |
16 | def get_norm_info(mesh_path):
17 | mesh = trimesh.load(mesh_path, force='mesh')
18 |
19 | model_points = trimesh.sample.sample_surface(mesh, 1024)[0]
20 | model_points = model_points.astype(np.float32)
21 |
22 | min_value = np.min(model_points, axis=0)
23 | max_value = np.max(model_points, axis=0)
24 |
25 | radius = max(np.linalg.norm(max_value), np.linalg.norm(min_value))
26 |
27 | return 1/(2*radius)
28 |
29 |
30 | for idx, model_name in enumerate(os.listdir(gso_norm_path)):
31 | if not os.path.isdir(os.path.join(gso_norm_path, model_name)) or '.' in model_name:
32 | continue
33 | print('---------------------------'+str(model_name)+'-------------------------------------')
34 |
35 | save_fpath = os.path.join(output_dir, model_name)
36 | if not os.path.exists(save_fpath):
37 | os.makedirs(save_fpath)
38 |
39 | obj_fpath = os.path.join(gso_norm_path, model_name, 'meshes', 'model.obj')
40 | if not os.path.exists(obj_fpath):
41 | continue
42 |
43 | scale = get_norm_info(obj_fpath)
44 |
45 | bproc.clean_up()
46 |
47 | obj = bproc.loader.load_obj(obj_fpath)[0]
48 | obj.set_scale([scale, scale, scale])
49 | obj.set_cp("category_id", idx)
50 |
51 | # set light
52 | light1 = bproc.types.Light()
53 | light1.set_type("POINT")
54 | light1.set_location([-3, -3, -3])
55 | light1.set_energy(1000)
56 |
57 | light2 = bproc.types.Light()
58 | light2.set_type("POINT")
59 | light2.set_location([3, 3, 3])
60 | light2.set_energy(1000)
61 |
62 | location = [[-1, -1, -1], [1, 1, 1]]
63 | # set camera locations around the object
64 | for loc in location:
65 | # compute rotation based on vector going from location towards the location of object
66 | rotation_matrix = bproc.camera.rotation_from_forward_vec(obj.get_location() - loc)
67 | # add homog cam pose based on location and rotation
68 | cam2world_matrix = bproc.math.build_transformation_mat(loc, rotation_matrix)
69 | bproc.camera.add_camera_pose(cam2world_matrix)
70 |
71 | bproc.renderer.set_max_amount_of_samples(50)
72 | # render the whole pipeline
73 | data = bproc.renderer.render()
74 | # render nocs
75 | data.update(bproc.renderer.render_nocs())
76 |
77 | # save rgb images
78 | color_bgr_0 = data["colors"][0]
79 | color_bgr_0[..., :3] = color_bgr_0[..., :3][..., ::-1]
80 | cv2.imwrite(os.path.join(save_fpath,'rgb_'+str(0)+'.png'), color_bgr_0)
81 | color_bgr_1 = data["colors"][1]
82 | color_bgr_1[..., :3] = color_bgr_1[..., :3][..., ::-1]
83 | cv2.imwrite(os.path.join(save_fpath,'rgb_'+str(1)+'.png'), color_bgr_1)
84 |
85 | # save masks
86 | mask_0 = data["nocs"][0][..., -1]
87 | mask_1 = data["nocs"][1][..., -1]
88 | cv2.imwrite(os.path.join(save_fpath,'mask_'+str(0)+'.png'), mask_0*255)
89 | cv2.imwrite(os.path.join(save_fpath,'mask_'+str(1)+'.png'), mask_1*255)
90 |
91 | # save nocs
92 | xyz_0 = 2*(data["nocs"][0][..., :3] - 0.5)
93 | xyz_1 = 2*(data["nocs"][1][..., :3] - 0.5)
94 | np.save(os.path.join(save_fpath,'xyz_'+str(0)+'.npy'), xyz_0.astype(np.float16))
95 | np.save(os.path.join(save_fpath,'xyz_'+str(1)+'.npy'), xyz_1.astype(np.float16))
96 |
97 |
--------------------------------------------------------------------------------
/SAM-6D/Render/render_shapenet_templates.py:
--------------------------------------------------------------------------------
1 | import blenderproc as bproc
2 |
3 | import os
4 | import cv2
5 | import numpy as np
6 | import trimesh
7 |
8 | # set relative path of Data folder
9 | render_dir = os.path.dirname(os.path.abspath(__file__))
10 | shapenet_path = os.path.join(render_dir, '../Data/MegaPose-Training-Data/MegaPose-ShapeNetCore/shapenetcorev2')
11 | shapenet_orig_path = os.path.join(shapenet_path, 'models_orig')
12 | output_dir = os.path.join(render_dir, '../Data/MegaPose-Training-Data/MegaPose-ShapeNetCore/templates')
13 |
14 | bproc.init()
15 |
16 | def get_norm_info(mesh_path):
17 | mesh = trimesh.load(mesh_path, force='mesh')
18 |
19 | model_points = trimesh.sample.sample_surface(mesh, 1024)[0]
20 | model_points = model_points.astype(np.float32)
21 |
22 | min_value = np.min(model_points, axis=0)
23 | max_value = np.max(model_points, axis=0)
24 |
25 | radius = max(np.linalg.norm(max_value), np.linalg.norm(min_value))
26 |
27 | return 1/(2*radius)
28 |
29 |
30 | for synset_id in os.listdir(shapenet_orig_path):
31 | synset_fpath = os.path.join(shapenet_orig_path, synset_id)
32 | if not os.path.isdir(synset_fpath) or '.' in synset_id:
33 | continue
34 | print('---------------------------'+str(synset_id)+'-------------------------------------')
35 | for idx, source_id in enumerate(os.listdir(synset_fpath)):
36 | print('---------------------------'+str(source_id)+'-------------------------------------')
37 | save_synset_folder = os.path.join(output_dir, synset_id)
38 | if not os.path.exists(save_synset_folder):
39 | os.makedirs(save_synset_folder)
40 |
41 | save_fpath = os.path.join(save_synset_folder, source_id)
42 | if not os.path.exists(save_fpath):
43 | os.mkdir(save_fpath)
44 | else:
45 | continue
46 |
47 | cad_path = os.path.join(shapenet_orig_path, synset_id, source_id)
48 | obj_fpath = os.path.join(cad_path, 'models', 'model_normalized.obj')
49 |
50 | if not os.path.exists(obj_fpath):
51 | continue
52 |
53 | scale = get_norm_info(obj_fpath)
54 |
55 | bproc.clean_up()
56 |
57 | obj = bproc.loader.load_shapenet(shapenet_orig_path, synset_id, source_id, move_object_origin=False)
58 | obj.set_scale([scale, scale, scale])
59 | obj.set_cp("category_id", idx)
60 |
61 | # set light
62 | light1 = bproc.types.Light()
63 | light1.set_type("POINT")
64 | light1.set_location([-3, -3, -3])
65 | light1.set_energy(1000)
66 |
67 | light2 = bproc.types.Light()
68 | light2.set_type("POINT")
69 | light2.set_location([3, 3, 3])
70 | light2.set_energy(1000)
71 |
72 | location = [[-1, -1, -1], [1, 1, 1]]
73 | # set camera locations around the object
74 | for loc in location:
75 | # compute rotation based on vector going from location towards the location of object
76 | rotation_matrix = bproc.camera.rotation_from_forward_vec(obj.get_location() - loc)
77 | # add homog cam pose based on location and rotation
78 | cam2world_matrix = bproc.math.build_transformation_mat(loc, rotation_matrix)
79 | bproc.camera.add_camera_pose(cam2world_matrix)
80 |
81 | bproc.renderer.set_max_amount_of_samples(1)
82 | # render the whole pipeline
83 | data = bproc.renderer.render()
84 | # render nocs
85 | data.update(bproc.renderer.render_nocs())
86 |
87 | # save rgb images
88 | color_bgr_0 = data["colors"][0]
89 | color_bgr_0[..., :3] = color_bgr_0[..., :3][..., ::-1]
90 | cv2.imwrite(os.path.join(save_fpath,'rgb_'+str(0)+'.png'), color_bgr_0)
91 | color_bgr_1 = data["colors"][1]
92 | color_bgr_1[..., :3] = color_bgr_1[..., :3][..., ::-1]
93 | cv2.imwrite(os.path.join(save_fpath,'rgb_'+str(1)+'.png'), color_bgr_1)
94 |
95 | # save masks
96 | mask_0 = data["nocs"][0][..., -1]
97 | mask_1 = data["nocs"][1][..., -1]
98 | cv2.imwrite(os.path.join(save_fpath,'mask_'+str(0)+'.png'), mask_0*255)
99 | cv2.imwrite(os.path.join(save_fpath,'mask_'+str(1)+'.png'), mask_1*255)
100 |
101 | # save nocs
102 | xyz_0 = 2*(data["nocs"][0][..., :3] - 0.5)
103 | xyz_1 = 2*(data["nocs"][1][..., :3] - 0.5)
104 | # xyz need to rotate 90 degree to match CAD
105 | rot90 = np.array([[1, 0, 0],
106 | [0, 0, 1],
107 | [0, -1, 0]])
108 | h, w = xyz_0.shape[0], xyz_0.shape[1]
109 |
110 | xyz_0 = ((rot90 @ xyz_0.reshape(-1, 3).T).T).reshape(h, w, 3)
111 | xyz_1 = ((rot90 @ xyz_1.reshape(-1, 3).T).T).reshape(h, w, 3)
112 | np.save(os.path.join(save_fpath,'xyz_'+str(0)+'.npy'), xyz_0.astype(np.float16))
113 | np.save(os.path.join(save_fpath,'xyz_'+str(1)+'.npy'), xyz_1.astype(np.float16))
114 |
115 |
--------------------------------------------------------------------------------
/SAM-6D/demo.sh:
--------------------------------------------------------------------------------
1 | # Render CAD templates
2 | cd Render
3 | blenderproc run render_custom_templates.py --output_dir $OUTPUT_DIR --cad_path $CAD_PATH #--colorize True
4 |
5 |
6 | # Run instance segmentation model
7 | export SEGMENTOR_MODEL=sam
8 |
9 | cd ../Instance_Segmentation_Model
10 | python run_inference_custom.py --segmentor_model $SEGMENTOR_MODEL --output_dir $OUTPUT_DIR --cad_path $CAD_PATH --rgb_path $RGB_PATH --depth_path $DEPTH_PATH --cam_path $CAMERA_PATH
11 |
12 |
13 | # Run pose estimation model
14 | export SEG_PATH=$OUTPUT_DIR/sam6d_results/detection_ism.json
15 |
16 | cd ../Pose_Estimation_Model
17 | python run_inference_custom.py --output_dir $OUTPUT_DIR --cad_path $CAD_PATH --rgb_path $RGB_PATH --depth_path $DEPTH_PATH --cam_path $CAMERA_PATH --seg_path $SEG_PATH
18 |
19 |
--------------------------------------------------------------------------------
/SAM-6D/environment.yaml:
--------------------------------------------------------------------------------
1 | name: sam6d
2 | channels:
3 | - xformers
4 | - conda-forge
5 | - pytorch
6 | - nvidia
7 | - defaults
8 | dependencies:
9 | - pip
10 | - python=3.9.6
11 | - pip:
12 | - torch==2.0.0
13 | - torchvision==0.15.1
14 | - fvcore
15 | - xformers==0.0.18
16 | - torchmetrics==0.10.3
17 | - blenderproc==2.6.1
18 | - opencv-python
19 | # ISM
20 | - omegaconf
21 | - ruamel.yaml
22 | - hydra-colorlog
23 | - hydra-core
24 | - gdown
25 | - pandas
26 | - imageio
27 | - pyrender
28 | - pytorch-lightning==1.8.1
29 | - pycocotools
30 | - distinctipy
31 | - git+https://github.com/facebookresearch/segment-anything.git # SAM
32 | - ultralytics==8.0.135 # FastSAM
33 | # PEM
34 | - timm
35 | - gorilla-core==0.2.7.8
36 | - trimesh==4.0.8
37 | - gpustat==1.0.0
38 | - imgaug
39 | - einops
--------------------------------------------------------------------------------
/SAM-6D/prepare.sh:
--------------------------------------------------------------------------------
1 | ### Create conda environment
2 | conda env create -f environment.yaml
3 | conda activate sam6d
4 |
5 | ### Install pointnet2
6 | cd Pose_Estimation_Model/model/pointnet2
7 | python setup.py install
8 | cd ../../../
9 |
10 | ### Download ISM pretrained model
11 | cd Instance_Segmentation_Model
12 | python download_sam.py
13 | python download_fastsam.py
14 | python download_dinov2.py
15 | cd ../
16 |
17 | ### Download PEM pretrained model
18 | cd Pose_Estimation_Model
19 | python download_sam6d-pem.py
20 |
--------------------------------------------------------------------------------
/pics/overview_pem.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/pics/overview_pem.png
--------------------------------------------------------------------------------
/pics/overview_sam_6d.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/pics/overview_sam_6d.png
--------------------------------------------------------------------------------
/pics/vis.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/pics/vis.gif
--------------------------------------------------------------------------------