├── README.md ├── SAM-6D ├── Data │ ├── Example │ │ ├── camera.json │ │ ├── depth.png │ │ ├── obj_000005.ply │ │ ├── outputs │ │ │ └── sam6d_results │ │ │ │ ├── vis_ism.png │ │ │ │ └── vis_pem.png │ │ └── rgb.png │ └── README.md ├── Instance_Segmentation_Model │ ├── LICENSE │ ├── README.md │ ├── configs │ │ ├── callback │ │ │ ├── base.yaml │ │ │ ├── checkpoint │ │ │ │ └── base.yaml │ │ │ └── lr │ │ │ │ └── base.yaml │ │ ├── data │ │ │ └── bop.yaml │ │ ├── download.yaml │ │ ├── hydra │ │ │ ├── hydra_logging │ │ │ │ ├── console.yaml │ │ │ │ ├── custom.yaml │ │ │ │ └── rich.yaml │ │ │ └── job_logging │ │ │ │ ├── console.yaml │ │ │ │ ├── custom.yaml │ │ │ │ └── rich.yaml │ │ ├── machine │ │ │ ├── local.yaml │ │ │ ├── slurm.yaml │ │ │ └── trainer │ │ │ │ ├── local.yaml │ │ │ │ └── slurm.yaml │ │ ├── model │ │ │ ├── ISM_fastsam.yaml │ │ │ ├── ISM_sam.yaml │ │ │ ├── descriptor_model │ │ │ │ └── dinov2.yaml │ │ │ └── segmentor_model │ │ │ │ ├── fast_sam.yaml │ │ │ │ └── sam.yaml │ │ ├── run_inference.yaml │ │ └── user │ │ │ └── default.yaml │ ├── download_dinov2.py │ ├── download_fastsam.py │ ├── download_sam.py │ ├── environment.yml │ ├── exp.sh │ ├── model │ │ ├── __init__.py │ │ ├── detector.py │ │ ├── dinov2.py │ │ ├── fast_sam.py │ │ ├── layers │ │ │ ├── __init__.py │ │ │ ├── attention.py │ │ │ ├── block.py │ │ │ ├── dino_head.py │ │ │ ├── drop_path.py │ │ │ ├── layer_scale.py │ │ │ ├── mlp.py │ │ │ ├── patch_embed.py │ │ │ └── swiglu_ffn.py │ │ ├── loss.py │ │ ├── sam.py │ │ ├── utils.py │ │ └── vision_transformer.py │ ├── provider │ │ ├── base_bop.py │ │ ├── bop.py │ │ └── bop_pbr.py │ ├── run_inference.py │ ├── run_inference_custom.py │ ├── segment_anything │ │ ├── __init__.py │ │ ├── automatic_mask_generator.py │ │ ├── build_sam.py │ │ ├── modeling │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ ├── image_encoder.py │ │ │ ├── mask_decoder.py │ │ │ ├── prompt_encoder.py │ │ │ ├── sam.py │ │ │ └── transformer.py │ │ ├── predictor.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── amg.py │ │ │ ├── onnx.py │ │ │ └── transforms.py │ └── utils │ │ ├── bbox_utils.py │ │ ├── inout.py │ │ ├── logging.py │ │ ├── poses │ │ ├── create_template_poses.py │ │ ├── find_neighbors.py │ │ ├── fps.py │ │ ├── pose_utils.py │ │ ├── predefined_poses │ │ │ ├── cam_poses_level0.npy │ │ │ ├── cam_poses_level1.npy │ │ │ ├── cam_poses_level2.npy │ │ │ ├── idx_all_level0_in_level2.npy │ │ │ ├── idx_all_level1_in_level2.npy │ │ │ ├── idx_all_level2_in_level2.npy │ │ │ ├── obj_poses_level0.npy │ │ │ ├── obj_poses_level1.npy │ │ │ └── obj_poses_level2.npy │ │ └── pyrender.py │ │ ├── trimesh_utils.py │ │ └── weight.py ├── Pose_Estimation_Model │ ├── README.md │ ├── config │ │ └── base.yaml │ ├── dependencies.sh │ ├── download_sam6d-pem.py │ ├── model │ │ ├── coarse_point_matching.py │ │ ├── feature_extraction.py │ │ ├── fine_point_matching.py │ │ ├── pointnet2 │ │ │ ├── _ext_src │ │ │ │ ├── include │ │ │ │ │ ├── ball_query.h │ │ │ │ │ ├── cuda_utils.h │ │ │ │ │ ├── group_points.h │ │ │ │ │ ├── interpolate.h │ │ │ │ │ ├── sampling.h │ │ │ │ │ └── utils.h │ │ │ │ └── src │ │ │ │ │ ├── ball_query.cpp │ │ │ │ │ ├── ball_query_gpu.cu │ │ │ │ │ ├── bindings.cpp │ │ │ │ │ ├── group_points.cpp │ │ │ │ │ ├── group_points_gpu.cu │ │ │ │ │ ├── interpolate.cpp │ │ │ │ │ ├── interpolate_gpu.cu │ │ │ │ │ ├── sampling.cpp │ │ │ │ │ └── sampling_gpu.cu │ │ │ ├── pointnet2_modules.py │ │ │ ├── pointnet2_test.py │ │ │ ├── pointnet2_utils.py │ │ │ ├── pytorch_utils.py │ │ │ └── setup.py │ │ ├── pose_estimation_model.py │ │ └── transformer.py 
│ ├── provider │ │ ├── bop_test_dataset.py │ │ └── training_dataset.py │ ├── run_inference_custom.py │ ├── test_bop.py │ ├── train.py │ └── utils │ │ ├── bop_object_utils.py │ │ ├── data_utils.py │ │ ├── draw_utils.py │ │ ├── loss_utils.py │ │ ├── model_utils.py │ │ └── solver.py ├── Render │ ├── render_bop_templates.py │ ├── render_custom_templates.py │ ├── render_gso_templates.py │ └── render_shapenet_templates.py ├── demo.sh ├── environment.yaml └── prepare.sh └── pics ├── overview_pem.png ├── overview_sam_6d.png └── vis.gif /README.md: -------------------------------------------------------------------------------- 1 | #

SAM-6D: Segment Anything Model Meets Zero-Shot 6D Object Pose Estimation

2 | 3 | ####

[Jiehong Lin](https://jiehonglin.github.io/), [Lihua Liu](https://github.com/foollh), [Dekun Lu](https://github.com/WuTanKun), [Kui Jia](http://kuijia.site/)

4 | ####

CVPR 2024

5 | ####

[[Paper]](https://arxiv.org/abs/2311.15707)

6 | 7 |

8 | 9 |

10 | 11 | 12 | ## News 13 | - [2024/03/07] We publish an updated version of our paper on [arXiv](https://arxiv.org/abs/2311.15707). 14 | - [2024/02/29] Our paper is accepted by CVPR 2024! 15 | 16 | 17 | ## Update Log 18 | - [2024/03/05] We update the demo to support [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM); you can enable it by specifying `SEGMENTOR_MODEL=fastsam` in demo.sh. 19 | - [2024/03/03] We upload a [docker image](https://hub.docker.com/r/lihualiu/sam-6d/tags) for running custom data. 20 | - [2024/03/01] We update the released [model](https://drive.google.com/file/d/1joW9IvwsaRJYxoUmGo68dBVg-HcFNyI7/view?usp=sharing) of PEM. The new model is trained with a batch size of 32, whereas the old one used 12. 21 | 22 | ## Overview 23 | In this work, we employ the Segment Anything Model as an advanced starting point for **zero-shot 6D object pose estimation** from RGB-D images, and propose a novel framework, named **SAM-6D**, which comprises the following two dedicated sub-networks to accomplish the task: 24 | - [x] [Instance Segmentation Model](https://github.com/JiehongLin/SAM-6D/tree/main/SAM-6D/Instance_Segmentation_Model) 25 | - [x] [Pose Estimation Model](https://github.com/JiehongLin/SAM-6D/tree/main/SAM-6D/Pose_Estimation_Model) 26 | 27 | 28 |

29 | 30 |

31 | 32 | 33 | ## Getting Started 34 | 35 | ### 1. Preparation 36 | Please clone the repository locally: 37 | ``` 38 | git clone https://github.com/JiehongLin/SAM-6D.git 39 | ``` 40 | Install the environment and download the model checkpoints: 41 | ``` 42 | cd SAM-6D 43 | sh prepare.sh 44 | ``` 45 | We also provide a [docker image](https://hub.docker.com/r/lihualiu/sam-6d/tags) for convenience. 46 | 47 | ### 2. Evaluation on custom data 48 | ``` 49 | # set the paths 50 | export CAD_PATH=Data/Example/obj_000005.ply # path to a given CAD model (mm) 51 | export RGB_PATH=Data/Example/rgb.png # path to a given RGB image 52 | export DEPTH_PATH=Data/Example/depth.png # path to a given depth map (mm) 53 | export CAMERA_PATH=Data/Example/camera.json # path to the given camera intrinsics 54 | export OUTPUT_DIR=Data/Example/outputs # path to a pre-defined directory for saving results 55 | 56 | # run inference 57 | cd SAM-6D 58 | sh demo.sh 59 | ``` 60 | 61 | 62 | 63 | ## Citation 64 | If you find our work useful in your research, please consider citing: 65 | 66 | @article{lin2023sam, 67 | title={SAM-6D: Segment Anything Model Meets Zero-Shot 6D Object Pose Estimation}, 68 | author={Lin, Jiehong and Liu, Lihua and Lu, Dekun and Jia, Kui}, 69 | journal={arXiv preprint arXiv:2311.15707}, 70 | year={2023} 71 | } 72 | 73 | 74 | ## Contact 75 | 76 | If you have any questions, please feel free to contact the authors. 77 | 78 | Jiehong Lin: [mortimer.jh.lin@gmail.com](mailto:mortimer.jh.lin@gmail.com) 79 | 80 | Lihua Liu: [lihualiu.scut@gmail.com](mailto:lihualiu.scut@gmail.com) 81 | 82 | Dekun Lu: [derkunlu@gmail.com](mailto:derkunlu@gmail.com) 83 | 84 | Kui Jia: [kuijia@gmail.com](mailto:kuijia@gmail.com) 85 | 86 | -------------------------------------------------------------------------------- /SAM-6D/Data/Example/camera.json: -------------------------------------------------------------------------------- 1 | {"cam_K": [572.4114, 0.0, 325.2611, 0.0, 573.57043, 242.04899, 0.0, 0.0, 1.0], "depth_scale": 1.0} -------------------------------------------------------------------------------- /SAM-6D/Data/Example/depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Data/Example/depth.png -------------------------------------------------------------------------------- /SAM-6D/Data/Example/outputs/sam6d_results/vis_ism.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Data/Example/outputs/sam6d_results/vis_ism.png -------------------------------------------------------------------------------- /SAM-6D/Data/Example/outputs/sam6d_results/vis_pem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Data/Example/outputs/sam6d_results/vis_pem.png -------------------------------------------------------------------------------- /SAM-6D/Data/Example/rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Data/Example/rgb.png -------------------------------------------------------------------------------- /SAM-6D/Data/README.md:
-------------------------------------------------------------------------------- 1 | 2 | ## Data Structure 3 | Our data in the `Data` folder is structured as follows: 4 | ``` 5 | Data 6 | ├── MegaPose-Training-Data 7 | ├── MegaPose-GSO 8 | ├──google_scanned_objects 9 | ├──templates 10 | └──train_pbr_web 11 | ├── MegaPose-ShapeNetCore 12 | ├──shapenetcorev2 13 | ├──templates 14 | └──train_pbr_web 15 | ├── BOP # https://bop.felk.cvut.cz/datasets/ 16 | ├──tudl 17 | ├──lmo 18 | ├──ycbv 19 | ├──icbin 20 | ├──hb 21 | ├──itodd 22 | └──tless 23 | └── BOP-Templates 24 | ├──tudl 25 | ├──lmo 26 | ├──ycbv 27 | ├──icbin 28 | ├──hb 29 | ├──itodd 30 | └──tless 31 | ``` 32 | 33 | 34 | ## Data Download 35 | 36 | ### Training Datasets 37 | For training the Pose Estimation Model, you may download the rendered images of the [BOP Challenge 2023 training datasets](https://github.com/thodan/bop_toolkit/blob/master/docs/bop_challenge_2023_training_datasets.md) provided officially by BOP, and place them in the respective `MegaPose-Training-Data/MegaPose-GSO/train_pbr_web` and `MegaPose-Training-Data/MegaPose-ShapeNetCore/train_pbr_web` folders. 38 | 39 | We use the [pre-processed object models](https://www.paris.inria.fr/archive_ylabbeprojectsdata/megapose/tars/) of the two datasets, provided by [MegaPose](https://github.com/megapose6d/megapose6d), and place them in the `MegaPose-Training-Data/MegaPose-GSO/google_scanned_objects` and `MegaPose-Training-Data/MegaPose-ShapeNetCore/shapenetcorev2` folders, respectively. 40 | 41 | 42 | ### BOP Datasets 43 | To evaluate our SAM-6D on the BOP datasets, you may download the test data and the object CAD models of the seven core datasets from the [BOP official site](https://bop.felk.cvut.cz/datasets/). Each dataset is organized as follows: 44 | 45 | ``` 46 | BOP 47 | ├── lmo 48 | ├──models # object CAD models 49 | ├──test # bop19 test set 50 | ├──(train_pbr) # optional; used for template selection in instance segmentation 51 | ... 52 | ... 53 | ``` 54 | 55 | You may also download the `train_pbr` data of the datasets for template selection in the Instance Segmentation Model, following [CNOS](https://github.com/nv-nguyen/cnos?tab=readme-ov-file). 56 | 57 | 58 | 59 | ## Template Rendering 60 | 61 | ### Requirements 62 | 63 | * blenderproc 64 | * trimesh 65 | * numpy 66 | * cv2 67 | 68 | 69 | ### Template Rendering of Training Objects 70 | 71 | We generate two-view templates for each training object via [BlenderProc](https://github.com/DLR-RM/BlenderProc). You may run the following commands to render the templates for the `MegaPose-GSO` dataset: 72 | 73 | ``` 74 | cd ../Render/ 75 | blenderproc run render_gso_templates.py 76 | ``` 77 | and the following for the `shapenetcorev2` dataset: 78 | 79 | ``` 80 | cd ../Render/ 81 | blenderproc run render_shapenet_templates.py 82 | ``` 83 | 84 | ### Template Rendering of Test Objects 85 | We generate 42 templates for each test object, following [CNOS](https://github.com/nv-nguyen/cnos?tab=readme-ov-file), via [BlenderProc](https://github.com/DLR-RM/BlenderProc). 86 | 87 | 88 | #### Custom Objects 89 | 90 | For a custom object, you can run the following commands to render the templates: 91 | ``` 92 | cd ../Render/ 93 | blenderproc run render_custom_templates.py --cad_path $CAD_PATH --output_dir $OUTPUT_DIR 94 | ``` 95 | The string "CAD_PATH" is the path to your CAD model, and the string "OUTPUT_DIR" is the directory in which the templates are saved.
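As a quick sanity check before rendering, you can confirm that the CAD model is expressed in millimeters, which is what the demo and the renderer expect. The following is a minimal, hypothetical sketch (not part of the repository) using `trimesh` from the requirements above; the example object path is only an illustration:

```python
# Hypothetical pre-check: load the CAD model and inspect its size,
# since SAM-6D expects CAD models (and depth maps) in millimeters.
import trimesh

mesh = trimesh.load("../Data/Example/obj_000005.ply")
print("extents (x/y/z):", mesh.extents)   # bounding-box size along each axis
print("bounds:\n", mesh.bounds)           # min/max corners of the bounding box
# Extents on the order of tens to hundreds suggest millimeters; values around
# 0.01-0.3 would suggest the model is in meters and needs rescaling.
```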
96 | 97 | 98 | #### BOP Datasets 99 | You may run the following commands to render the templates for the objects in the BOP datasets: 100 | ``` 101 | cd ../Render/ 102 | blenderproc run render_bop_templates.py --dataset_name $DATASET 103 | ``` 104 | The string "DATASET" could be set as `lmo`, `icbin`, `itodd`, `hb`, `tless`, `tudl` or `ycbv`. We also offer downloadable rendered templates [[link](https://drive.google.com/drive/folders/1fXt5Z6YDPZTJICZcywBUhu5rWnPvYAPI?usp=sharing)]. 105 | 106 | 107 | ## Acknowledgements 108 | - [MegaPose](https://github.com/megapose6d/megapose6d) 109 | - [CNOS](https://github.com/nv-nguyen/cnos?tab=readme-ov-file) 110 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Van Nguyen Nguyen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/README.md: -------------------------------------------------------------------------------- 1 | # Instance Segmentation Model (ISM) for SAM-6D 2 | 3 | 4 | ## Requirements 5 | The code has been tested with 6 | - python 3.9.6 7 | - pytorch 2.0.0 8 | - CUDA 11.3 9 | 10 | Create the conda environment: 11 | 12 | ``` 13 | conda env create -f environment.yml 14 | conda activate sam6d-ism 15 | 16 | # for using SAM 17 | pip install git+https://github.com/facebookresearch/segment-anything.git 18 | 19 | # for using FastSAM 20 | pip install ultralytics==8.0.135 21 | ``` 22 | 23 | 24 | ## Data Preparation 25 | 26 | Please refer to [[link](https://github.com/JiehongLin/SAM-6D/tree/main/SAM-6D/Data)] for more details.
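If you have downloaded the BOP data manually, a small, hypothetical check (not part of the repository) can verify that it was unpacked into the layout described in the Data README; the folder names below are taken from that README and the default `local_root_dir` of `../Data`:

```python
# Hypothetical layout check for the BOP test data used in evaluation.
from pathlib import Path

BOP_ROOT = Path("../Data/BOP")  # matches machine.root_dir with the default user config
DATASETS = ["lmo", "icbin", "itodd", "hb", "tless", "tudl", "ycbv"]

for name in DATASETS:
    for sub in ("models", "test"):
        folder = BOP_ROOT / name / sub
        print(f"{folder}: {'ok' if folder.is_dir() else 'MISSING'}")
```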
27 | 28 | 29 | ## Foundation Model Download 30 | 31 | Download the model weights of [Segment Anything](https://github.com/facebookresearch/segment-anything): 32 | ``` 33 | python download_sam.py 34 | ``` 35 | 36 | Download the model weights of [Fast Segment Anything](https://github.com/CASIA-IVA-Lab/FastSAM): 37 | ``` 38 | python download_fastsam.py 39 | ``` 40 | 41 | Download the model weights of the ViT pre-trained by [DINOv2](https://github.com/facebookresearch/dinov2): 42 | ``` 43 | python download_dinov2.py 44 | ``` 45 | 46 | 47 | ## Evaluation on BOP Datasets 48 | 49 | To evaluate the model on the BOP datasets, please run the following commands: 50 | 51 | ``` 52 | # Specify the GPU to use 53 | export CUDA_VISIBLE_DEVICES=0 54 | 55 | # with sam 56 | python run_inference.py dataset_name=$DATASET 57 | 58 | # with fastsam 59 | python run_inference.py dataset_name=$DATASET model=ISM_fastsam 60 | ``` 61 | 62 | The string "DATASET" could be set as `lmo`, `icbin`, `itodd`, `hb`, `tless`, `tudl` or `ycbv`. 63 | 64 | 65 | ## Acknowledgements 66 | 67 | - [CNOS](https://github.com/nv-nguyen/cnos) 68 | - [SAM](https://github.com/facebookresearch/segment-anything) 69 | - [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) 70 | - [DINOv2](https://github.com/facebookresearch/dinov2) 71 | 72 | 73 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/configs/callback/base.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - checkpoint: base 3 | - lr: base 4 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/configs/callback/checkpoint/base.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 2 | 3 | dirpath: ${save_dir}/checkpoints 4 | save_last: true 5 | verbose: true 6 | save_top_k: -1 7 | every_n_train_steps: 10000 # number of backprop steps -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/configs/callback/lr/base.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.callbacks.LearningRateMonitor 2 | 3 | logging_interval: step -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/configs/data/bop.yaml: -------------------------------------------------------------------------------- 1 | root_dir: ${machine.root_dir}/BOP/ 2 | source_url: https://bop.felk.cvut.cz/media/data/bop_datasets/ 3 | unzip_mode: unzip 4 | 5 | reference_dataloader: 6 | _target_: provider.bop.BOPTemplate 7 | obj_ids: 8 | template_dir: ${machine.root_dir}/BOP/ 9 | level_templates: 0 10 | pose_distribution: all 11 | processing_config: 12 | image_size: ${model.descriptor_model.image_size} 13 | max_num_scenes: 10 # config for reference frames selection 14 | max_num_frames: 500 15 | min_visib_fract: 0.8 16 | num_references: 200 17 | use_visible_mask: True 18 | 19 | query_dataloader: 20 | _target_: provider.bop.BaseBOPTest 21 | root_dir: ${machine.root_dir}/BOP/ 22 | split: 23 | reset_metaData: True 24 | processing_config: 25 | image_size: ${model.descriptor_model.image_size} 26 | 27 | train_datasets: 28 | megapose-gso: 29 | identifier: bop23_datasets/megapose-gso/gso_models.json 30 | mapping_image_key: /bop23_datasets/megapose-gso/train_pbr_web/key_to_shard.json 31 | prefix:
bop23_datasets/megapose-gso/train_pbr_web/ 32 | shard_ids: [0, 1039] 33 | megapose-shapenet: 34 | identifier: bop23_datasets/megapose-shapenet/shapenet_models.json 35 | mapping_image_key: bop23_datasets/megapose-shapenet/train_pbr_web/key_to_shard.json 36 | prefix: bop23_datasets/megapose-shapenet/train_pbr_web 37 | shard_ids: [0, 1039] 38 | 39 | datasets: 40 | lm: 41 | cad: lm_models.zip 42 | test: lm_test_bop19.zip 43 | pbr_train: lm_train_pbr.zip 44 | obj_names: [001_ape, 002_benchvise, 003_bowl, 004_camera, 005_can, 006_cat, 007_cup, 008_driller, 009_duck, 010_eggbox, 011_glue, 012_holepuncher, 013_iron, 014_lamp, 015_phone] 45 | lmo: 46 | cad: lmo_models.zip 47 | test: lmo_test_bop19.zip 48 | pbr_train: lm_train_pbr.zip 49 | obj_names: [001_ape, 005_can, 006_cat, 008_driller, 009_duck, 010_eggbox, 011_glue, 012_holepuncher] 50 | tless: 51 | cad: tless_models.zip 52 | test: tless_test_primesense_bop19.zip 53 | pbr_train: tless_train_pbr.zip 54 | obj_names: [001_obj, 002_obj, 003_obj, 004_obj, 005_obj, 006_obj, 007_obj, 008_obj, 009_obj, 010_obj, 011_obj, 012_obj, 013_obj, 014_obj, 015_obj, 016_obj, 017_obj, 018_obj, 019_obj, 020_obj, 021_obj, 022_obj, 023_obj, 024_obj, 025_obj, 026_obj, 027_obj, 028_obj, 029_obj, 030_obj] 55 | itodd: 56 | cad: itodd_models.zip 57 | test: itodd_test_bop19.zip 58 | pbr_train: itodd_train_pbr.zip 59 | obj_names: [001_obj, 002_obj, 003_obj, 004_obj, 005_obj, 006_obj, 007_obj, 008_obj, 009_obj, 010_obj, 011_obj, 012_obj, 013_obj, 014_obj, 015_obj, 016_obj, 017_obj, 018_obj, 019_obj, 020_obj, 021_obj, 022_obj, 023_obj, 024_obj, 025_obj, 026_obj, 027_obj, 028_obj] 60 | hb: 61 | cad: hb_models.zip 62 | test: hb_test_primesense_bop19.zip 63 | pbr_train: hb_train_pbr.zip 64 | obj_names: [001_red_teddy, 002_bench_wise, 003_car, 004_white_cow, 005_white_pig, 006_white_cup, 007_driller, 008_green_rabbit, 009_holepuncher, 010_brown_unknown, 011_brown_unknown, 012_black_unknown, 013_black_unknown, 014_white_painter, 015_small_unkown, 016_small_unkown, 017_small_unkown, 018_cake_box, 019_minion, 020_colored_dog, 021_phone, 022_animal, 023_yellow_dog, 024_cassette_player, 025_red_racing_car, 026_motobike, 027_heels, 028_dinosaur, 029_tea_box, 030_animal, 031_japanese_toy, 032_white_racing_car, 033_yellow_rabbit] 65 | hope: 66 | cad: hope_models.zip 67 | test: hope_test_bop19.zip 68 | obj_names: [001_alphabet_soup, 002_bbq_sauce, 003_butter, 004_cherries, 005_chocolate_pudding, 006_cookies, 007_corn, 008_cream_cheese, 009_granola_bar, 010_green_bean, 011_tomato_sauce, 012_macaroni_cheese, 013_mayo, 014_milk, 015_mushroom, 016_mustard, 017_orange_juice, 018_parmesa_cheese, 019_peaches, 020_peaches_and_carrot, 021_pineapple, 022_popcorn, 023_raisins, 024_ranch_dressing, 025_spaghetti, 026_tomato_sauce, 027_tuna, 028_yogurt] 69 | ycbv: 70 | cad: ycbv_models.zip 71 | test: ycbv_test_bop19.zip 72 | pbr_train: ycbv_train_pbr.zip 73 | obj_names: [002_master_chef_can, 003_cracker_box, 004_sugar_box, 005_tomato_soup_can, 006_mustard_bottle, 007_tuna_fish_can, 008_pudding_box, 009_gelatin_box, 010_potted_meat_can, 011_banana, 019_pitcher_base, 021_bleach_cleanser, 024_bowl, 025_mug, 035_power_drill, 036_wood_block, 037_scissors, 040_large_marker, 051_large_clamp, 052_extra_large_clamp, 061_foam_brick] 74 | ruapc: 75 | cad: ruapc_models.zip 76 | test: ruapc_test_bop19.zip 77 | obj_names: [001_red_copper_box, 002_red_cheezit_box, 003_crayon_box, 004_white_glue, 005_expo_box, 006_greenies, 007_straw_cup, 008_stick_box, 009_highland_sticker, 010_red_tennis_ball, 011_yellow_duck, 
012_blue_oreo, 013_pen_box, 014_yellow_standley] 78 | icbin: 79 | cad: icbin_models.zip 80 | test: icbin_test_bop19.zip 81 | pbr_train: icbin_train_pbr.zip 82 | obj_names: [001_blue_cup, 002_blue_box] 83 | icmi: 84 | cad: icmi_models.zip 85 | test: icmi_test_bop19.zip 86 | obj_names: [001_obj, 002_obj, 003_obj, 004_obj, 005_obj, 006_obj] 87 | tudl: 88 | cad: tudl_models.zip 89 | test: tudl_test_bop19.zip 90 | pbr_train: tudl_train_pbr.zip 91 | obj_names: [001_dinosaur, 002_white_ape, 003_white_can] 92 | tyol: 93 | cad: tyol_models.zip 94 | test: tyol_test_bop19.zip 95 | obj_names: [001_obj, 002_obj, 003_obj, 004_obj, 005_obj, 006_obj, 007_obj, 008_obj, 009_obj, 010_obj, 011_obj, 012_obj, 013_obj, 014_obj, 015_obj, 016_obj, 017_obj, 018_obj, 019_obj, 020_obj, 021_obj] -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/configs/download.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - user: default 3 | - machine: local 4 | - data: bop 5 | - _self_ 6 | - override hydra/hydra_logging: disabled 7 | - override hydra/job_logging: disabled 8 | 9 | hydra: 10 | output_subdir: null 11 | run: 12 | dir: . 13 | 14 | level: 2 15 | disable_output: true 16 | num_workers: 4 17 | gpus: "0,1,2,3" -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/configs/hydra/hydra_logging/console.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | 3 | filters: 4 | onlyimportant: 5 | (): utils.logging.LevelsFilter 6 | levels: 7 | - CRITICAL 8 | - ERROR 9 | - WARNING 10 | noimportant: 11 | (): utils.logging.LevelsFilter 12 | levels: 13 | - INFO 14 | - DEBUG 15 | - NOTSET 16 | 17 | formatters: 18 | simple: 19 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' 20 | datefmt: '%d/%m/%y %H:%M:%S' 21 | 22 | colorlog: 23 | (): colorlog.ColoredFormatter 24 | format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] 25 | - %(message)s' 26 | datefmt: '%d/%m/%y %H:%M:%S' 27 | 28 | log_colors: 29 | DEBUG: purple 30 | INFO: green 31 | WARNING: yellow 32 | ERROR: red 33 | CRITICAL: red 34 | 35 | handlers: 36 | console: 37 | class: logging.StreamHandler 38 | formatter: colorlog 39 | stream: ext://sys.stdout 40 | 41 | root: 42 | level: ${logger_level} 43 | handlers: 44 | - console 45 | 46 | disable_existing_loggers: false -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/configs/hydra/hydra_logging/custom.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | 3 | formatters: 4 | colorlog: 5 | (): colorlog.ColoredFormatter 6 | format: '[%(cyan)s%(asctime)s%(reset)s][%(purple)sHYDRA%(reset)s] %(message)s' 7 | datefmt: '%d/%m/%y %H:%M:%S' 8 | 9 | handlers: 10 | console: 11 | class: logging.StreamHandler 12 | formatter: colorlog 13 | stream: ext://sys.stdout 14 | 15 | root: 16 | level: INFO 17 | handlers: 18 | - console 19 | 20 | disable_existing_loggers: false -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/configs/hydra/hydra_logging/rich.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | 3 | formatters: 4 | colorlog: 5 | (): colorlog.ColoredFormatter 6 | format: 
'[%(cyan)s%(asctime)s%(reset)s][%(purple)sHYDRA%(reset)s] %(message)s' 7 | datefmt: '%d/%m/%y %H:%M:%S' 8 | 9 | handlers: 10 | console: 11 | class: rich.logging.RichHandler # logging.StreamHandler 12 | formatter: colorlog 13 | 14 | root: 15 | level: INFO 16 | handlers: 17 | - console 18 | 19 | disable_existing_loggers: false 20 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/configs/hydra/job_logging/console.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | 3 | filters: 4 | onlyimportant: 5 | (): utils.logging.LevelsFilter 6 | levels: 7 | - CRITICAL 8 | - ERROR 9 | - WARNING 10 | noimportant: 11 | (): utils.logging.LevelsFilter 12 | levels: 13 | - INFO 14 | - DEBUG 15 | - NOTSET 16 | 17 | formatters: 18 | simple: 19 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' 20 | datefmt: '%d/%m/%y %H:%M:%S' 21 | 22 | colorlog: 23 | (): colorlog.ColoredFormatter 24 | format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] 25 | - %(message)s' 26 | datefmt: '%d/%m/%y %H:%M:%S' 27 | 28 | log_colors: 29 | DEBUG: purple 30 | INFO: green 31 | WARNING: yellow 32 | ERROR: red 33 | CRITICAL: red 34 | 35 | handlers: 36 | console: 37 | class: logging.StreamHandler 38 | formatter: colorlog 39 | stream: ext://sys.stdout 40 | 41 | root: 42 | level: ${logger_level} 43 | handlers: 44 | - console 45 | 46 | disable_existing_loggers: false 47 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/configs/hydra/job_logging/custom.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | 3 | filters: 4 | onlyimportant: 5 | (): utils.logging.LevelsFilter 6 | levels: 7 | - CRITICAL 8 | - ERROR 9 | - WARNING 10 | noimportant: 11 | (): utils.logging.LevelsFilter 12 | levels: 13 | - INFO 14 | - DEBUG 15 | - NOTSET 16 | 17 | formatters: 18 | simple: 19 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' 20 | datefmt: '%d/%m/%y %H:%M:%S' 21 | 22 | colorlog: 23 | (): colorlog.ColoredFormatter 24 | format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] 25 | - %(message)s' 26 | datefmt: '%d/%m/%y %H:%M:%S' 27 | 28 | log_colors: 29 | DEBUG: purple 30 | INFO: green 31 | WARNING: yellow 32 | ERROR: red 33 | CRITICAL: red 34 | 35 | handlers: 36 | console: 37 | class: logging.StreamHandler 38 | formatter: colorlog 39 | stream: ext://sys.stdout 40 | 41 | file_out: 42 | class: logging.FileHandler 43 | formatter: simple 44 | filename: logs.out 45 | filters: 46 | - noimportant 47 | 48 | file_err: 49 | class: logging.FileHandler 50 | formatter: simple 51 | filename: logs.err 52 | filters: 53 | - onlyimportant 54 | 55 | root: 56 | level: ${logger_level} 57 | handlers: 58 | - console 59 | - file_out 60 | - file_err 61 | 62 | disable_existing_loggers: false -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/configs/hydra/job_logging/rich.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | 3 | filters: 4 | onlyimportant: 5 | (): utils.logging.LevelsFilter 6 | levels: 7 | - CRITICAL 8 | - ERROR 9 | - WARNING 10 | noimportant: 11 | (): utils.logging.LevelsFilter 12 | levels: 13 | - INFO 14 | - DEBUG 15 | - NOTSET 16 | 17 | formatters: 18 | verysimple: 19 | format: 
'%(message)s' 20 | simple: 21 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' 22 | datefmt: '%d/%m/%y %H:%M:%S' 23 | 24 | colorlog: 25 | (): colorlog.ColoredFormatter 26 | format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)s%(name)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] 27 | - %(message)s' 28 | datefmt: '%d/%m/%y %H:%M:%S' 29 | 30 | log_colors: 31 | DEBUG: purple 32 | INFO: green 33 | WARNING: yellow 34 | ERROR: red 35 | CRITICAL: red 36 | 37 | handlers: 38 | console: 39 | class: rich.logging.RichHandler # logging.StreamHandler 40 | formatter: verysimple # colorlog 41 | rich_tracebacks: true 42 | 43 | file_out: 44 | class: logging.FileHandler 45 | formatter: simple 46 | filename: logs.out 47 | filters: 48 | - noimportant 49 | 50 | file_err: 51 | class: logging.FileHandler 52 | formatter: simple 53 | filename: logs.err 54 | filters: 55 | - onlyimportant 56 | 57 | root: 58 | level: ${logger_level} 59 | handlers: 60 | - console 61 | - file_out 62 | - file_err 63 | 64 | disable_existing_loggers: false -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/configs/machine/local.yaml: -------------------------------------------------------------------------------- 1 | name: local 2 | 3 | # specific attributes to this machine 4 | batch_size: 16 5 | num_workers: 10 6 | 7 | 8 | # Define trainer, datasets 9 | defaults: 10 | - trainer: local 11 | 12 | dryrun: True 13 | root_dir: ${user.local_root_dir} -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/configs/machine/slurm.yaml: -------------------------------------------------------------------------------- 1 | name: slurm 2 | 3 | # specific attributes to this machine 4 | batch_size: 16 5 | num_workers: 10 6 | 7 | # Define trainer 8 | defaults: 9 | - trainer: slurm 10 | 11 | dryrun: True 12 | root_dir: ${user.slurm_root_dir} -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/configs/machine/trainer/local.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | max_epochs: 1000 4 | accelerator: gpu 5 | deterministic: False 6 | detect_anomaly: False 7 | enable_progress_bar: True 8 | strategy: ddp 9 | precision: 16 10 | accumulate_grad_batches: 11 | val_check_interval: 2000 12 | log_every_n_steps: 1 13 | num_sanity_val_steps: 2 14 | limit_val_batches: 20 15 | limit_test_batches: 16 | 17 | callbacks: "${oc.dict.values: callback}" -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/configs/machine/trainer/slurm.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | max_epochs: 1000 4 | accelerator: gpu 5 | deterministic: False 6 | detect_anomaly: False 7 | enable_progress_bar: True 8 | strategy: ddp 9 | precision: 16 10 | accumulate_grad_batches: 11 | val_check_interval: 1000 12 | log_every_n_steps: 1 13 | num_sanity_val_steps: 2 14 | limit_val_batches: 20 15 | limit_test_batches: 1.0 16 | 17 | callbacks: "${oc.dict.values: callback}" -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/configs/model/ISM_fastsam.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.detector.Instance_Segmentation_Model 2 | 
log_interval: 5 3 | log_dir: ${save_dir} 4 | segmentor_width_size: 640 # make it stable 5 | descriptor_width_size: 640 6 | visible_thred: 0.5 7 | pointcloud_sample_num: 2048 8 | 9 | defaults: 10 | - segmentor_model: fast_sam 11 | - descriptor_model: dinov2 12 | 13 | post_processing_config: 14 | mask_post_processing: 15 | min_box_size: 0.05 # relative to image size 16 | min_mask_size: 3e-4 # relative to image size 17 | nms_thresh: 0.25 18 | 19 | matching_config: 20 | metric: 21 | _target_: model.loss.PairwiseSimilarity 22 | metric: cosine 23 | chunk_size: 16 24 | aggregation_function: avg_5 25 | confidence_thresh: 0.2 26 | 27 | onboarding_config: 28 | rendering_type: pbr 29 | reset_descriptors: False 30 | level_templates: 0 # 0 is coarse, 1 is medium, 2 is dense 31 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/configs/model/ISM_sam.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.detector.Instance_Segmentation_Model 2 | log_interval: 5 3 | log_dir: ${save_dir} 4 | segmentor_width_size: 640 # make it stable 5 | descriptor_width_size: 640 6 | visible_thred: 0.5 7 | pointcloud_sample_num: 2048 8 | 9 | defaults: 10 | - segmentor_model: sam 11 | - descriptor_model: dinov2 12 | 13 | post_processing_config: 14 | mask_post_processing: 15 | min_box_size: 0.05 # relative to image size 16 | min_mask_size: 3e-4 # relative to image size 17 | nms_thresh: 0.25 18 | 19 | matching_config: 20 | metric: 21 | _target_: model.loss.PairwiseSimilarity 22 | metric: cosine 23 | chunk_size: 16 24 | aggregation_function: avg_5 25 | confidence_thresh: 0.2 26 | 27 | onboarding_config: 28 | rendering_type: pbr 29 | reset_descriptors: False 30 | level_templates: 0 # 0 is coarse, 1 is medium, 2 is dense 31 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/configs/model/descriptor_model/dinov2.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.dinov2.CustomDINOv2 2 | model_name: dinov2_vitl14 3 | # model: 4 | # _target_: torch.hub.load 5 | # repo_or_dir: facebookresearch/dinov2 6 | # model: ${model.descriptor_model.model_name} 7 | checkpoint_dir: ./checkpoints/dinov2/ 8 | token_name: x_norm_clstoken 9 | descriptor_width_size: ${model.descriptor_width_size} 10 | image_size: 224 11 | chunk_size: 16 12 | validpatch_thresh: 0.5 -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/configs/model/segmentor_model/fast_sam.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.fast_sam.FastSAM 2 | checkpoint_path: ./checkpoints/FastSAM/FastSAM-x.pt 3 | segmentor_width_size: ${model.segmentor_width_size} 4 | config: 5 | iou_threshold: 0.9 6 | conf_threshold: 0.05 7 | max_det: 200 -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/configs/model/segmentor_model/sam.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.sam.CustomSamAutomaticMaskGenerator 2 | points_per_batch: 64 3 | stability_score_thresh: 0.85 4 | box_nms_thresh: 0.7 5 | min_mask_region_area: 0 6 | crop_overlap_ratio: 7 | pred_iou_thresh: 0.88 8 | segmentor_width_size: ${model.segmentor_width_size} 9 | sam: 10 | _target_: model.sam.load_sam 11 | model_type: vit_h 12 | checkpoint_dir: 
./checkpoints/segment-anything/ -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/configs/run_inference.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - user: default 3 | - machine: local 4 | - callback: base 5 | - data: bop 6 | - model: ISM_sam 7 | - _self_ 8 | - override hydra/hydra_logging: disabled 9 | - override hydra/job_logging: disabled 10 | 11 | hydra: 12 | output_subdir: null 13 | run: 14 | dir: . 15 | 16 | save_dir: ./log/${name_exp} 17 | name_exp: sam 18 | dataset_name: -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/configs/user/default.yaml: -------------------------------------------------------------------------------- 1 | local_root_dir: ../Data 2 | slurm_root_dir: ../Data -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/download_dinov2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import os, sys 4 | import os.path as osp 5 | from utils.inout import get_root_project 6 | 7 | # set level logging 8 | logging.basicConfig(level=logging.INFO) 9 | import logging 10 | import hydra 11 | from omegaconf import DictConfig, OmegaConf 12 | model_dict = { 13 | "dinov2_vits14": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth", 14 | "dinov2_vitb14": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth", 15 | "dinov2_vitl14": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", 16 | "dinov2_vitg14": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth", 17 | } 18 | 19 | def download_model(url, output_path): 20 | import os 21 | 22 | command = f"wget -O {output_path}/{url.split('/')[-1]} {url} --no-check-certificate" 23 | os.system(command) 24 | 25 | 26 | @hydra.main( 27 | version_base=None, 28 | config_path="./configs", 29 | config_name="download", 30 | ) 31 | def download(cfg: DictConfig) -> None: 32 | model_name = "dinov2_vitl14" # default segmentation model used in CNOS 33 | save_dir = osp.join(get_root_project(), "checkpoints/dinov2") 34 | os.makedirs(save_dir, exist_ok=True) 35 | download_model(model_dict[model_name], save_dir) 36 | 37 | if __name__ == "__main__": 38 | download() 39 | 40 | 41 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/download_fastsam.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import os, sys 4 | import os.path as osp 5 | from utils.inout import get_root_project 6 | 7 | # set level logging 8 | logging.basicConfig(level=logging.INFO) 9 | import logging 10 | import hydra 11 | from omegaconf import DictConfig, OmegaConf 12 | 13 | def download_model(output_path): 14 | import os 15 | command = f"gdown --no-cookies --no-check-certificate -O '{output_path}/FastSAM-x.pt' 1m1sjY4ihXBU1fZXdQ-Xdj-mDltW-2Rqv" 16 | os.system(command) 17 | 18 | 19 | @hydra.main( 20 | version_base=None, 21 | config_path="./configs", 22 | config_name="download", 23 | ) 24 | def download(cfg: DictConfig) -> None: 25 | save_dir = osp.join(get_root_project(), "checkpoints/FastSAM") 26 | os.makedirs(save_dir, exist_ok=True) 27 | download_model(save_dir) 28 | 29 | if __name__ == "__main__": 30 | download() 31 | 
32 | 33 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/download_sam.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import os, sys 4 | import os.path as osp 5 | from utils.inout import get_root_project 6 | 7 | # set level logging 8 | logging.basicConfig(level=logging.INFO) 9 | import logging 10 | import hydra 11 | from omegaconf import DictConfig, OmegaConf 12 | model_dict = { 13 | "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", # 2560 MB 14 | "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", # 1250 MB 15 | "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", 16 | } # 375 GB 17 | 18 | def download_model(url, output_path): 19 | import os 20 | 21 | command = f"wget -O {output_path}/{url.split('/')[-1]} {url} --no-check-certificate" 22 | os.system(command) 23 | 24 | 25 | @hydra.main( 26 | version_base=None, 27 | config_path="./configs", 28 | config_name="download", 29 | ) 30 | def download(cfg: DictConfig) -> None: 31 | model_name = "vit_h" # default segmentation model used in CNOS 32 | save_dir = osp.join(get_root_project(), "checkpoints/segment-anything") 33 | os.makedirs(save_dir, exist_ok=True) 34 | download_model(model_dict[model_name], save_dir) 35 | 36 | if __name__ == "__main__": 37 | download() 38 | 39 | 40 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/environment.yml: -------------------------------------------------------------------------------- 1 | name: sam6d-ism 2 | channels: 3 | - xformers 4 | - pytorch 5 | - nvidia 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - pip 10 | - python=3.9.6 11 | - pip: 12 | - torch 13 | - torchvision 14 | - omegaconf 15 | - torchmetrics==0.10.3 16 | - fvcore 17 | - iopath 18 | - xformers==0.0.18 19 | - opencv-python 20 | - pycocotools 21 | - matplotlib 22 | - onnxruntime 23 | - onnx 24 | - scipy 25 | - ffmpeg 26 | - hydra-colorlog 27 | - hydra-core 28 | - gdown 29 | - pytorch-lightning==1.8.1 30 | - pandas 31 | - ruamel.yaml 32 | - pyrender 33 | - wandb 34 | - distinctipy 35 | - imageio 36 | # bop https://github.com/thodan/bop_toolkit.git -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/exp.sh: -------------------------------------------------------------------------------- 1 | # with sam 2 | python run_inference.py dataset_name=icbin 3 | 4 | # with fastsam 5 | python run_inference.py dataset_name=icbin model=ISM_fastsam 6 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Instance_Segmentation_Model/model/__init__.py -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/model/fast_sam.py: -------------------------------------------------------------------------------- 1 | from ultralytics import YOLO 2 | from pathlib import Path 3 | from typing import Union 4 | import numpy as np 5 | import cv2 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from segment_anything.utils.amg import MaskData 10 | import 
logging 11 | import os.path as osp 12 | from typing import Any, Dict, List, Optional, Tuple 13 | import pytorch_lightning as pl 14 | from ultralytics import yolo # noqa 15 | from ultralytics.nn.autobackend import AutoBackend 16 | 17 | 18 | class CustomYOLO(YOLO): 19 | def __init__( 20 | self, 21 | model, 22 | iou, 23 | conf, 24 | max_det, 25 | segmentor_width_size, 26 | selected_device="cpu", 27 | verbose=False, 28 | ): 29 | YOLO.__init__( 30 | self, 31 | model, 32 | ) 33 | self.overrides["iou"] = iou 34 | self.overrides["conf"] = conf 35 | self.overrides["max_det"] = max_det 36 | self.overrides["verbose"] = verbose 37 | self.overrides["imgsz"] = segmentor_width_size 38 | 39 | self.overrides["conf"] = 0.25 40 | self.overrides["mode"] = "predict" 41 | self.overrides["save"] = False 42 | 43 | self.predictor = yolo.v8.segment.SegmentationPredictor( 44 | overrides=self.overrides, _callbacks=self.callbacks 45 | ) 46 | 47 | self.not_setup = True 48 | self.selected_device = selected_device 49 | logging.info(f"Init CustomYOLO done!") 50 | 51 | def setup_model(self, device, verbose=False): 52 | """Initialize YOLO model with given parameters and set it to evaluation mode.""" 53 | model = self.predictor.model or self.predictor.args.model 54 | self.predictor.args.half &= ( 55 | device.type != "cpu" 56 | ) # half precision only supported on CUDA 57 | self.predictor.model = AutoBackend( 58 | model, 59 | device=device, 60 | dnn=self.predictor.args.dnn, 61 | data=self.predictor.args.data, 62 | fp16=self.predictor.args.half, 63 | fuse=True, 64 | verbose=verbose, 65 | ) 66 | self.predictor.device = device 67 | self.predictor.model.eval() 68 | logging.info(f"Setup model at device {device} done!") 69 | 70 | def __call__(self, source=None, stream=False): 71 | return self.predictor(source=source, stream=stream) 72 | 73 | 74 | class FastSAM(object): 75 | def __init__( 76 | self, 77 | checkpoint_path: Union[str, Path], 78 | config: dict = None, 79 | segmentor_width_size=None, 80 | device=None, 81 | ): 82 | self.model = CustomYOLO( 83 | model=checkpoint_path, 84 | iou=config.iou_threshold, 85 | conf=config.conf_threshold, 86 | max_det=config.max_det, 87 | selected_device=device, 88 | segmentor_width_size=segmentor_width_size, 89 | ) 90 | self.segmentor_width_size = segmentor_width_size 91 | self.current_device = device 92 | logging.info(f"Init FastSAM done!") 93 | 94 | def postprocess_resize(self, detections, orig_size, update_boxes=False): 95 | detections["masks"] = F.interpolate( 96 | detections["masks"].unsqueeze(1).float(), 97 | size=(orig_size[0], orig_size[1]), 98 | mode="bilinear", 99 | align_corners=False, 100 | )[:, 0, :, :] 101 | if update_boxes: 102 | scale = orig_size[1] / self.segmentor_width_size 103 | detections["boxes"] = detections["boxes"].float() * scale 104 | detections["boxes"][:, [0, 2]] = torch.clamp( 105 | detections["boxes"][:, [0, 2]], 0, orig_size[1] - 1 106 | ) 107 | detections["boxes"][:, [1, 3]] = torch.clamp( 108 | detections["boxes"][:, [1, 3]], 0, orig_size[0] - 1 109 | ) 110 | return detections 111 | 112 | @torch.no_grad() 113 | def generate_masks(self, image) -> List[Dict[str, Any]]: 114 | if self.segmentor_width_size is not None: 115 | orig_size = image.shape[:2] 116 | detections = self.model(image) 117 | 118 | masks = detections[0].masks.data 119 | boxes = detections[0].boxes.data[:, :4] # two lasts: confidence and class 120 | 121 | # define class data 122 | mask_data = { 123 | "masks": masks.to(self.current_device), 124 | "boxes": boxes.to(self.current_device), 125 | } 126 | 
if self.segmentor_width_size is not None: 127 | mask_data = self.postprocess_resize(mask_data, orig_size) 128 | return mask_data 129 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/model/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .dino_head import DINOHead 8 | from .mlp import Mlp 9 | from .patch_embed import PatchEmbed 10 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 11 | from .block import NestedTensorBlock 12 | from .attention import MemEffAttention 13 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/model/layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 10 | 11 | import logging 12 | 13 | from torch import Tensor 14 | from torch import nn 15 | 16 | 17 | logger = logging.getLogger("dinov2") 18 | 19 | 20 | try: 21 | from xformers.ops import memory_efficient_attention, unbind, fmha 22 | 23 | XFORMERS_AVAILABLE = True 24 | except ImportError: 25 | logger.warning("xFormers not available") 26 | XFORMERS_AVAILABLE = False 27 | 28 | 29 | class Attention(nn.Module): 30 | def __init__( 31 | self, 32 | dim: int, 33 | num_heads: int = 8, 34 | qkv_bias: bool = False, 35 | proj_bias: bool = True, 36 | attn_drop: float = 0.0, 37 | proj_drop: float = 0.0, 38 | ) -> None: 39 | super().__init__() 40 | self.num_heads = num_heads 41 | head_dim = dim // num_heads 42 | self.scale = head_dim**-0.5 43 | 44 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 45 | self.attn_drop = nn.Dropout(attn_drop) 46 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 47 | self.proj_drop = nn.Dropout(proj_drop) 48 | 49 | def forward(self, x: Tensor) -> Tensor: 50 | B, N, C = x.shape 51 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 52 | 53 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 54 | attn = q @ k.transpose(-2, -1) 55 | 56 | attn = attn.softmax(dim=-1) 57 | attn = self.attn_drop(attn) 58 | 59 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 60 | x = self.proj(x) 61 | x = self.proj_drop(x) 62 | return x 63 | 64 | 65 | class MemEffAttention(Attention): 66 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 67 | if not XFORMERS_AVAILABLE: 68 | assert attn_bias is None, "xFormers is required for nested tensors usage" 69 | return super().forward(x) 70 | 71 | B, N, C = x.shape 72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 73 | 74 | q, k, v = unbind(qkv, 2) 75 | 76 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 77 | x = x.reshape([B, N, C]) 78 | 79 | x = self.proj(x) 80 | x = self.proj_drop(x) 81 | return x 82 | -------------------------------------------------------------------------------- 
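The attention layers above follow the DINOv2 reference implementation and are presumably consumed by `model/vision_transformer.py`. A minimal, hypothetical usage sketch (not part of the repository; the token shape is chosen only for illustration):

```python
# Hypothetical standalone check of the attention modules defined above.
import torch
from model.layers.attention import Attention, MemEffAttention

tokens = torch.randn(2, 257, 1024)        # (batch, tokens, embedding dim)

attn = Attention(dim=1024, num_heads=16)  # plain multi-head self-attention
out = attn(tokens)
assert out.shape == tokens.shape

# MemEffAttention uses xformers' memory_efficient_attention when xFormers is
# installed and otherwise falls back to the plain forward pass above.
mem_attn = MemEffAttention(dim=1024, num_heads=16)
out = mem_attn(tokens)
```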
/SAM-6D/Instance_Segmentation_Model/model/layers/dino_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.init import trunc_normal_ 10 | from torch.nn.utils import weight_norm 11 | 12 | 13 | class DINOHead(nn.Module): 14 | def __init__( 15 | self, 16 | in_dim, 17 | out_dim, 18 | use_bn=False, 19 | nlayers=3, 20 | hidden_dim=2048, 21 | bottleneck_dim=256, 22 | mlp_bias=True, 23 | ): 24 | super().__init__() 25 | nlayers = max(nlayers, 1) 26 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) 27 | self.apply(self._init_weights) 28 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 29 | self.last_layer.weight_g.data.fill_(1) 30 | 31 | def _init_weights(self, m): 32 | if isinstance(m, nn.Linear): 33 | trunc_normal_(m.weight, std=0.02) 34 | if isinstance(m, nn.Linear) and m.bias is not None: 35 | nn.init.constant_(m.bias, 0) 36 | 37 | def forward(self, x): 38 | x = self.mlp(x) 39 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12 40 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) 41 | x = self.last_layer(x) 42 | return x 43 | 44 | 45 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): 46 | if nlayers == 1: 47 | return nn.Linear(in_dim, bottleneck_dim, bias=bias) 48 | else: 49 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] 50 | if use_bn: 51 | layers.append(nn.BatchNorm1d(hidden_dim)) 52 | layers.append(nn.GELU()) 53 | for _ in range(nlayers - 2): 54 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) 55 | if use_bn: 56 | layers.append(nn.BatchNorm1d(hidden_dim)) 57 | layers.append(nn.GELU()) 58 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) 59 | return nn.Sequential(*layers) 60 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/model/layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 10 | 11 | 12 | from torch import nn 13 | 14 | 15 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 16 | if drop_prob == 0.0 or not training: 17 | return x 18 | keep_prob = 1 - drop_prob 19 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 20 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 21 | if keep_prob > 0.0: 22 | random_tensor.div_(keep_prob) 23 | output = x * random_tensor 24 | return output 25 | 26 | 27 | class DropPath(nn.Module): 28 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 29 | 30 | def __init__(self, drop_prob=None): 31 | super(DropPath, self).__init__() 32 | self.drop_prob = drop_prob 33 | 34 | def forward(self, x): 35 | return drop_path(x, self.drop_prob, self.training) 36 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/model/layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 8 | 9 | from typing import Union 10 | 11 | import torch 12 | from torch import Tensor 13 | from torch import nn 14 | 15 | 16 | class LayerScale(nn.Module): 17 | def __init__( 18 | self, 19 | dim: int, 20 | init_values: Union[float, Tensor] = 1e-5, 21 | inplace: bool = False, 22 | ) -> None: 23 | super().__init__() 24 | self.inplace = inplace 25 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 26 | 27 | def forward(self, x: Tensor) -> Tensor: 28 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 29 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/model/layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 10 | 11 | 12 | from typing import Callable, Optional 13 | 14 | from torch import Tensor, nn 15 | 16 | 17 | class Mlp(nn.Module): 18 | def __init__( 19 | self, 20 | in_features: int, 21 | hidden_features: Optional[int] = None, 22 | out_features: Optional[int] = None, 23 | act_layer: Callable[..., nn.Module] = nn.GELU, 24 | drop: float = 0.0, 25 | bias: bool = True, 26 | ) -> None: 27 | super().__init__() 28 | out_features = out_features or in_features 29 | hidden_features = hidden_features or in_features 30 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 31 | self.act = act_layer() 32 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 33 | self.drop = nn.Dropout(drop) 34 | 35 | def forward(self, x: Tensor) -> Tensor: 36 | x = self.fc1(x) 37 | x = self.act(x) 38 | x = self.drop(x) 39 | x = self.fc2(x) 40 | x = self.drop(x) 41 | return x 42 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/model/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 10 | 11 | from typing import Callable, Optional, Tuple, Union 12 | 13 | from torch import Tensor 14 | import torch.nn as nn 15 | 16 | 17 | def make_2tuple(x): 18 | if isinstance(x, tuple): 19 | assert len(x) == 2 20 | return x 21 | 22 | assert isinstance(x, int) 23 | return (x, x) 24 | 25 | 26 | class PatchEmbed(nn.Module): 27 | """ 28 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 29 | 30 | Args: 31 | img_size: Image size. 32 | patch_size: Patch token size. 33 | in_chans: Number of input image channels. 34 | embed_dim: Number of linear projection output channels. 35 | norm_layer: Normalization layer. 
36 | """ 37 | 38 | def __init__( 39 | self, 40 | img_size: Union[int, Tuple[int, int]] = 224, 41 | patch_size: Union[int, Tuple[int, int]] = 16, 42 | in_chans: int = 3, 43 | embed_dim: int = 768, 44 | norm_layer: Optional[Callable] = None, 45 | flatten_embedding: bool = True, 46 | ) -> None: 47 | super().__init__() 48 | 49 | image_HW = make_2tuple(img_size) 50 | patch_HW = make_2tuple(patch_size) 51 | patch_grid_size = ( 52 | image_HW[0] // patch_HW[0], 53 | image_HW[1] // patch_HW[1], 54 | ) 55 | 56 | self.img_size = image_HW 57 | self.patch_size = patch_HW 58 | self.patches_resolution = patch_grid_size 59 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 60 | 61 | self.in_chans = in_chans 62 | self.embed_dim = embed_dim 63 | 64 | self.flatten_embedding = flatten_embedding 65 | 66 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 67 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 68 | 69 | def forward(self, x: Tensor) -> Tensor: 70 | _, _, H, W = x.shape 71 | patch_H, patch_W = self.patch_size 72 | 73 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 74 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 75 | 76 | x = self.proj(x) # B C H W 77 | H, W = x.size(2), x.size(3) 78 | x = x.flatten(2).transpose(1, 2) # B HW C 79 | x = self.norm(x) 80 | if not self.flatten_embedding: 81 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 82 | return x 83 | 84 | def flops(self) -> float: 85 | Ho, Wo = self.patches_resolution 86 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 87 | if self.norm is not None: 88 | flops += Ho * Wo * self.embed_dim 89 | return flops 90 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/model/layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from typing import Callable, Optional 8 | 9 | from torch import Tensor, nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class SwiGLUFFN(nn.Module): 14 | def __init__( 15 | self, 16 | in_features: int, 17 | hidden_features: Optional[int] = None, 18 | out_features: Optional[int] = None, 19 | act_layer: Callable[..., nn.Module] = None, 20 | drop: float = 0.0, 21 | bias: bool = True, 22 | ) -> None: 23 | super().__init__() 24 | out_features = out_features or in_features 25 | hidden_features = hidden_features or in_features 26 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 27 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 28 | 29 | def forward(self, x: Tensor) -> Tensor: 30 | x12 = self.w12(x) 31 | x1, x2 = x12.chunk(2, dim=-1) 32 | hidden = F.silu(x1) * x2 33 | return self.w3(hidden) 34 | 35 | 36 | try: 37 | from xformers.ops import SwiGLU 38 | 39 | XFORMERS_AVAILABLE = True 40 | except ImportError: 41 | SwiGLU = SwiGLUFFN 42 | XFORMERS_AVAILABLE = False 43 | 44 | 45 | class SwiGLUFFNFused(SwiGLU): 46 | def __init__( 47 | self, 48 | in_features: int, 49 | hidden_features: Optional[int] = None, 50 | out_features: Optional[int] = None, 51 | act_layer: Callable[..., nn.Module] = None, 52 | drop: float = 0.0, 53 | bias: bool = True, 54 | ) -> None: 55 | out_features = out_features or in_features 56 | hidden_features = hidden_features or in_features 57 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 58 | super().__init__( 59 | in_features=in_features, 60 | hidden_features=hidden_features, 61 | out_features=out_features, 62 | bias=bias, 63 | ) 64 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/model/loss.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from utils.poses.pose_utils import load_rotation_transform, convert_openCV_to_openGL_torch 4 | import torch.nn.functional as F 5 | from model.utils import BatchedData 6 | 7 | 8 | class Similarity(nn.Module): 9 | def __init__(self, metric="cosine", chunk_size=64): 10 | super(Similarity, self).__init__() 11 | self.metric = metric 12 | self.chunk_size = chunk_size 13 | 14 | def forward(self, query, reference): 15 | query = F.normalize(query, dim=-1) 16 | reference = F.normalize(reference, dim=-1) 17 | similarity = F.cosine_similarity(query, reference, dim=-1) 18 | return similarity.clamp(min=0.0, max=1.0) 19 | 20 | 21 | class PairwiseSimilarity(nn.Module): 22 | def __init__(self, metric="cosine", chunk_size=64): 23 | super(PairwiseSimilarity, self).__init__() 24 | self.metric = metric 25 | self.chunk_size = chunk_size 26 | 27 | def forward(self, query, reference): 28 | N_query = query.shape[0] 29 | N_objects, N_templates = reference.shape[0], reference.shape[1] 30 | references = reference.clone().unsqueeze(0).repeat(N_query, 1, 1, 1) 31 | queries = query.clone().unsqueeze(1).repeat(1, N_templates, 1) 32 | queries = F.normalize(queries, dim=-1) 33 | references = F.normalize(references, dim=-1) 34 | 35 | similarity = BatchedData(batch_size=None) 36 | for idx_obj in range(N_objects): 37 | sim = F.cosine_similarity( 38 | queries, references[:, idx_obj], dim=-1 39 | ) # N_query x N_templates 40 | similarity.append(sim) 41 | similarity.stack() 42 | similarity = similarity.data 43 | similarity = similarity.permute(1, 0, 2) # N_query x N_objects x N_templates 44 | return similarity.clamp(min=0.0, max=1.0) 45 | 46 | class 
MaskedPatch_MatrixSimilarity(nn.Module): 47 | def __init__(self, metric="cosine", chunk_size=64): 48 | super(MaskedPatch_MatrixSimilarity, self).__init__() 49 | self.metric = metric 50 | self.chunk_size = chunk_size 51 | 52 | def compute_straight(self, query, reference): 53 | (N_query, N_patch, N_features) = query.shape 54 | sim_matrix = torch.matmul(query, reference.permute(0, 2, 1)) # N_query x N_query_mask x N_refer_mask 55 | 56 | # N2_ref score max 57 | max_ref_patch_score = torch.max(sim_matrix, dim=-1).values 58 | # N1_query score average 59 | factor = torch.count_nonzero(query.sum(dim=-1), dim=-1) + 1e-6 60 | scores = torch.sum(max_ref_patch_score, dim=-1) / factor # N_query 61 | 62 | return scores.clamp(min=0.0, max=1.0) 63 | 64 | def compute_visible_ratio(self, query, reference, thred=0.5): 65 | 66 | sim_matrix = torch.matmul(query, reference.permute(0, 2, 1)) # N_query x N_query_mask x N_refer_mask 67 | sim_matrix = sim_matrix.max(1)[0] # N_query x N_refer_mask 68 | valid_patches = torch.count_nonzero(sim_matrix, dim=(1, )) + 1e-6 69 | 70 | # filter correspondences by the threshold 71 | filtered_matrix = sim_matrix * (sim_matrix > thred) 72 | sim_patches = torch.count_nonzero(filtered_matrix, dim=(1,)) 73 | 74 | visible_ratio = sim_patches / valid_patches 75 | 76 | return visible_ratio 77 | 78 | def compute_similarity(self, query, reference): 79 | # all template computation 80 | N_query = query.shape[0] 81 | N_objects, N_templates = reference.shape[0], reference.shape[1] 82 | references = reference.unsqueeze(0).repeat(N_query, 1, 1, 1, 1) 83 | queries = query.unsqueeze(1).repeat(1, N_templates, 1, 1) 84 | 85 | similarity = BatchedData(batch_size=None) 86 | for idx_obj in range(N_objects): 87 | sim_matrix = torch.matmul(queries, references[:, idx_obj].permute(0, 1, 3, 2)) # N_query x N_templates x N_query_mask x N_refer_mask 88 | similarity.append(sim_matrix) 89 | similarity.stack() 90 | similarity = similarity.data 91 | similarity = similarity.permute(1, 0, 2, 3, 4) # N_query x N_objects x N_templates x N1_query x N2_ref 92 | 93 | # N2_ref score max 94 | max_ref_patch_score = torch.max(similarity, dim=-1).values 95 | # N1_query score average 96 | factor = torch.count_nonzero(query.sum(dim=-1), dim=-1)[:, None, None] 97 | scores = torch.sum(max_ref_patch_score, dim=-1) / factor # N_query x N_objects x N_templates 98 | 99 | return scores.clamp(min=0.0, max=1.0) 100 | 101 | def forward_by_chunk(self, query, reference): 102 | # divide by N_query 103 | batch_query = BatchedData(batch_size=self.chunk_size, data=query) 104 | del query 105 | scores = BatchedData(batch_size=self.chunk_size) 106 | for idx_batch in range(len(batch_query)): 107 | score = self.compute_similarity(batch_query[idx_batch], reference) 108 | scores.cat(score) 109 | return scores.data 110 | 111 | def forward(self, query, reference): 112 | if query.shape[0] > self.chunk_size: 113 | scores = self.forward_by_chunk(query, reference) 114 | else: 115 | scores = self.compute_similarity(query, reference) 116 | return scores 117 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/model/sam.py: -------------------------------------------------------------------------------- 1 | from segment_anything import ( 2 | sam_model_registry, 3 | SamPredictor, 4 | SamAutomaticMaskGenerator, 5 | ) 6 | from segment_anything.modeling import Sam 7 | from segment_anything.utils.amg import MaskData, generate_crop_boxes, rle_to_mask 8 | import logging 9 |
import numpy as np 10 | import torch 11 | from torchvision.ops.boxes import batched_nms, box_area # type: ignore 12 | import os.path as osp 13 | from typing import Any, Dict, List, Optional, Tuple 14 | import cv2 15 | import torch.nn.functional as F 16 | 17 | pretrained_weight_dict = { 18 | "vit_l": "sam_vit_l_0b3195.pth", # 1250MB 19 | "vit_b": "sam_vit_b_01ec64.pth", # 375MB 20 | "vit_h": "sam_vit_h_4b8939.pth", # 2500MB 21 | } 22 | 23 | 24 | def load_sam(model_type, checkpoint_dir): 25 | logging.info(f"Loading SAM model from {checkpoint_dir}") 26 | sam = sam_model_registry[model_type]( 27 | checkpoint=osp.join(checkpoint_dir, pretrained_weight_dict[model_type]) 28 | ) 29 | return sam 30 | 31 | 32 | def load_sam_predictor(model_type, checkpoint_dir, device): 33 | logging.info(f"Loading SAM model from {checkpoint_dir}") 34 | sam = sam_model_registry[model_type]( 35 | checkpoint=osp.join(checkpoint_dir, pretrained_weight_dict[model_type]) 36 | ) 37 | sam.to(device=device) 38 | predictor = SamPredictor(sam) 39 | return predictor 40 | 41 | 42 | def load_sam_mask_generator(model_type, checkpoint_dir, device): 43 | logging.info(f"Loading SAM model from {checkpoint_dir}") 44 | sam = sam_model_registry[model_type]( 45 | checkpoint=osp.join(checkpoint_dir, pretrained_weight_dict[model_type]) 46 | ) 47 | sam.to(device=device) 48 | mask_generator = SamAutomaticMaskGenerator(sam, output_mode="coco_rle") 49 | return mask_generator 50 | 51 | 52 | class CustomSamAutomaticMaskGenerator(SamAutomaticMaskGenerator): 53 | def __init__( 54 | self, 55 | sam: Sam, 56 | min_mask_region_area: int = 0, 57 | points_per_batch: int = 64, 58 | stability_score_thresh: float = 0.85, 59 | box_nms_thresh: float = 0.7, 60 | crop_overlap_ratio: float = 512 / 1500, 61 | segmentor_width_size=None, 62 | pred_iou_thresh: float = 0.88, 63 | ): 64 | SamAutomaticMaskGenerator.__init__( 65 | self, 66 | sam, 67 | min_mask_region_area=min_mask_region_area, 68 | points_per_batch=points_per_batch, 69 | stability_score_thresh=stability_score_thresh, 70 | box_nms_thresh=box_nms_thresh, 71 | crop_overlap_ratio=crop_overlap_ratio, 72 | pred_iou_thresh=pred_iou_thresh 73 | ) 74 | self.segmentor_width_size = segmentor_width_size 75 | logging.info(f"Init CustomSamAutomaticMaskGenerator done!") 76 | 77 | def preprocess_resize(self, image: np.ndarray): 78 | orig_size = image.shape[:2] 79 | height_size = int(self.segmentor_width_size * orig_size[0] / orig_size[1]) 80 | resized_image = cv2.resize( 81 | image.copy(), (self.segmentor_width_size, height_size) # (width, height) 82 | ) 83 | return resized_image 84 | 85 | def postprocess_resize(self, detections, orig_size): 86 | detections["masks"] = F.interpolate( 87 | detections["masks"].unsqueeze(1).float(), 88 | size=(orig_size[0], orig_size[1]), 89 | mode="bilinear", 90 | align_corners=False, 91 | )[:, 0, :, :] 92 | scale = orig_size[1] / self.segmentor_width_size 93 | detections["boxes"] = detections["boxes"].float() * scale 94 | detections["boxes"][:, [0, 2]] = torch.clamp( 95 | detections["boxes"][:, [0, 2]], 0, orig_size[1] - 1 96 | ) 97 | detections["boxes"][:, [1, 3]] = torch.clamp( 98 | detections["boxes"][:, [1, 3]], 0, orig_size[0] - 1 99 | ) 100 | return detections 101 | 102 | @torch.no_grad() 103 | def generate_masks(self, image: np.ndarray) -> List[Dict[str, Any]]: 104 | if self.segmentor_width_size is not None: 105 | orig_size = image.shape[:2] 106 | image = self.preprocess_resize(image) 107 | # Generate masks 108 | mask_data = self._generate_masks(image) 109 | 110 | # Filter small 
disconnected regions and holes in masks 111 | if self.min_mask_region_area > 0: 112 | mask_data = self.postprocess_small_regions( 113 | mask_data, 114 | self.min_mask_region_area, 115 | max(self.box_nms_thresh, self.crop_nms_thresh), 116 | ) 117 | if self.segmentor_width_size is not None: 118 | mask_data = self.postprocess_resize(mask_data, orig_size) 119 | return mask_data 120 | 121 | def _generate_masks(self, image: np.ndarray) -> MaskData: 122 | orig_size = image.shape[:2] 123 | crop_boxes, layer_idxs = generate_crop_boxes( 124 | orig_size, self.crop_n_layers, self.crop_overlap_ratio 125 | ) 126 | 127 | # Iterate over image crops 128 | data = MaskData() 129 | for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 130 | crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) 131 | data.cat(crop_data) 132 | 133 | # Remove duplicate masks between crops 134 | if len(crop_boxes) > 1: 135 | # Prefer masks from smaller crops 136 | scores = 1 / box_area(data["crop_boxes"]) 137 | scores = scores.to(data["boxes"].device) 138 | keep_by_nms = batched_nms( 139 | data["boxes"].float(), 140 | scores, 141 | torch.zeros_like(data["boxes"][:, 0]), # categories 142 | iou_threshold=self.crop_nms_thresh, 143 | ) 144 | data.filter(keep_by_nms) 145 | 146 | data["masks"] = [torch.from_numpy(rle_to_mask(rle)) for rle in data["rles"]] 147 | data["masks"] = torch.stack(data["masks"]) 148 | return {"masks": data["masks"].to(data["boxes"].device), "boxes": data["boxes"]} 149 | 150 | def remove_small_detections(self, mask_data: MaskData, img_size: List) -> MaskData: 151 | # calculate area and number of pixels in each mask 152 | area = box_area(mask_data["boxes"]) / (img_size[0] * img_size[1]) 153 | idx_selected = area >= self.mask_post_processing.min_box_size 154 | mask_data.filter(idx_selected) 155 | return mask_data -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/model/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torchvision 4 | from torchvision.ops.boxes import batched_nms, box_area 5 | import logging 6 | from utils.inout import save_json, load_json, save_npz 7 | from utils.bbox_utils import xyxy_to_xywh, xywh_to_xyxy, force_binary_mask 8 | import time 9 | from PIL import Image 10 | 11 | lmo_object_ids = np.array( 12 | [ 13 | 1, 14 | 5, 15 | 6, 16 | 8, 17 | 9, 18 | 10, 19 | 11, 20 | 12, 21 | ] 22 | ) # object ID of occlusionLINEMOD is different 23 | 24 | 25 | def mask_to_rle(binary_mask): 26 | rle = {"counts": [], "size": list(binary_mask.shape)} 27 | counts = rle.get("counts") 28 | 29 | last_elem = 0 30 | running_length = 0 31 | 32 | for i, elem in enumerate(binary_mask.ravel(order="F")): 33 | if elem == last_elem: 34 | pass 35 | else: 36 | counts.append(running_length) 37 | running_length = 0 38 | last_elem = elem 39 | running_length += 1 40 | 41 | counts.append(running_length) 42 | 43 | return rle 44 | 45 | 46 | class BatchedData: 47 | """ 48 | A structure for storing data in batched format. 49 | Implements basic filtering and concatenation. 
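    Example (illustrative):
        batch = BatchedData(batch_size=4, data=torch.randn(10, 3))
        len(batch)   # -> 3 chunks
        batch[0]     # -> the first 4 rows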
50 | """ 51 | 52 | def __init__(self, batch_size, data=None, **kwargs) -> None: 53 | self.batch_size = batch_size 54 | if data is not None: 55 | self.data = data 56 | else: 57 | self.data = [] 58 | 59 | def __len__(self): 60 | assert self.batch_size is not None, "batch_size is not defined" 61 | return np.ceil(len(self.data) / self.batch_size).astype(int) 62 | 63 | def __getitem__(self, idx): 64 | assert self.batch_size is not None, "batch_size is not defined" 65 | return self.data[idx * self.batch_size : (idx + 1) * self.batch_size] 66 | 67 | def cat(self, data, dim=0): 68 | if len(self.data) == 0: 69 | self.data = data 70 | else: 71 | self.data = torch.cat([self.data, data], dim=dim) 72 | 73 | def append(self, data): 74 | self.data.append(data) 75 | 76 | def stack(self, dim=0): 77 | self.data = torch.stack(self.data, dim=dim) 78 | 79 | 80 | class Detections: 81 | """ 82 | A structure for storing detections. 83 | """ 84 | 85 | def __init__(self, data) -> None: 86 | if isinstance(data, str): 87 | data = self.load_from_file(data) 88 | for key, value in data.items(): 89 | setattr(self, key, value) 90 | self.keys = list(data.keys()) 91 | if "boxes" in self.keys: 92 | if isinstance(self.boxes, np.ndarray): 93 | self.to_torch() 94 | self.boxes = self.boxes.long() 95 | 96 | def remove_very_small_detections(self, config): 97 | img_area = self.masks.shape[1] * self.masks.shape[2] 98 | box_areas = box_area(self.boxes) / img_area 99 | mask_areas = self.masks.sum(dim=(1, 2)) / img_area 100 | keep_idxs = torch.logical_and( 101 | box_areas > config.min_box_size**2, mask_areas > config.min_mask_size 102 | ) 103 | # logging.info(f"Removing {len(keep_idxs) - keep_idxs.sum()} detections") 104 | for key in self.keys: 105 | setattr(self, key, getattr(self, key)[keep_idxs]) 106 | 107 | def apply_nms_per_object_id(self, nms_thresh=0.5): 108 | keep_idxs = BatchedData(None) 109 | all_indexes = torch.arange(len(self.object_ids), device=self.boxes.device) 110 | for object_id in torch.unique(self.object_ids): 111 | idx = self.object_ids == object_id 112 | idx_object_id = all_indexes[idx] 113 | keep_idx = torchvision.ops.nms( 114 | self.boxes[idx].float(), self.scores[idx].float(), nms_thresh 115 | ) 116 | keep_idxs.cat(idx_object_id[keep_idx]) 117 | keep_idxs = keep_idxs.data 118 | for key in self.keys: 119 | setattr(self, key, getattr(self, key)[keep_idxs]) 120 | 121 | def apply_nms(self, nms_thresh=0.5): 122 | keep_idx = torchvision.ops.nms( 123 | self.boxes.float(), self.scores.float(), nms_thresh 124 | ) 125 | for key in self.keys: 126 | setattr(self, key, getattr(self, key)[keep_idx]) 127 | 128 | def add_attribute(self, key, value): 129 | setattr(self, key, value) 130 | self.keys.append(key) 131 | 132 | def __len__(self): 133 | return len(self.boxes) 134 | 135 | def check_size(self): 136 | mask_size = len(self.masks) 137 | box_size = len(self.boxes) 138 | score_size = len(self.scores) 139 | object_id_size = len(self.object_ids) 140 | assert ( 141 | mask_size == box_size == score_size == object_id_size 142 | ), f"Size mismatch {mask_size} {box_size} {score_size} {object_id_size}" 143 | 144 | def to_numpy(self): 145 | for key in self.keys: 146 | setattr(self, key, getattr(self, key).cpu().numpy()) 147 | 148 | def to_torch(self): 149 | for key in self.keys: 150 | a = getattr(self, key) 151 | setattr(self, key, torch.from_numpy(getattr(self, key))) 152 | 153 | def save_to_file( 154 | self, scene_id, frame_id, runtime, file_path, dataset_name, return_results=False 155 | ): 156 | """ 157 | scene_id, image_id, 
category_id, bbox, time 158 | """ 159 | boxes = xyxy_to_xywh(self.boxes) 160 | results = { 161 | "scene_id": scene_id, 162 | "image_id": frame_id, 163 | "category_id": self.object_ids + 1 164 | if dataset_name != "lmo" 165 | else lmo_object_ids[self.object_ids], 166 | "score": self.scores, 167 | "bbox": boxes, 168 | "time": runtime, 169 | "segmentation": self.masks, 170 | } 171 | save_npz(file_path, results) 172 | if return_results: 173 | return results 174 | 175 | def load_from_file(self, file_path): 176 | data = np.load(file_path) 177 | masks = data["segmentation"] 178 | boxes = xywh_to_xyxy(np.array(data["bbox"])) 179 | data = { 180 | "object_ids": data["category_id"] - 1, 181 | "bbox": boxes, 182 | "scores": data["score"], 183 | "masks": masks, 184 | } 185 | logging.info(f"Loaded {file_path}") 186 | return data 187 | 188 | def filter(self, idxs): 189 | for key in self.keys: 190 | setattr(self, key, getattr(self, key)[idxs]) 191 | 192 | def clone(self): 193 | """ 194 | Clone the current object 195 | """ 196 | return Detections(self.__dict__.copy()) 197 | 198 | 199 | def convert_npz_to_json(idx, list_npz_paths): 200 | npz_path = list_npz_paths[idx] 201 | detections = np.load(npz_path) 202 | results = [] 203 | for idx_det in range(len(detections["bbox"])): 204 | result = { 205 | "scene_id": int(detections["scene_id"]), 206 | "image_id": int(detections["image_id"]), 207 | "category_id": int(detections["category_id"][idx_det]), 208 | "bbox": detections["bbox"][idx_det].tolist(), 209 | "score": float(detections["score"][idx_det]), 210 | "time": float(detections["time"]), 211 | "segmentation": mask_to_rle( 212 | force_binary_mask(detections["segmentation"][idx_det]) 213 | ), 214 | } 215 | results.append(result) 216 | return results 217 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/provider/bop.py: -------------------------------------------------------------------------------- 1 | import logging, os 2 | import os.path as osp 3 | from tqdm import tqdm 4 | import time 5 | import numpy as np 6 | import torchvision.transforms as T 7 | from pathlib import Path 8 | from PIL import Image 9 | from torch.utils.data import Dataset 10 | import os.path as osp 11 | from utils.poses.pose_utils import load_index_level_in_level2 12 | import torch 13 | from utils.bbox_utils import CropResizePad 14 | import pytorch_lightning as pl 15 | from provider.base_bop import BaseBOP 16 | import imageio.v2 as imageio 17 | from utils.inout import load_json 18 | 19 | pl.seed_everything(2023) 20 | 21 | 22 | class BOPTemplate(Dataset): 23 | def __init__( 24 | self, 25 | template_dir, 26 | obj_ids, 27 | processing_config, 28 | level_templates, 29 | pose_distribution, 30 | **kwargs, 31 | ): 32 | self.template_dir = template_dir 33 | if obj_ids is None: 34 | obj_ids = [ 35 | int(obj_id[4:]) 36 | for obj_id in os.listdir(template_dir) 37 | if osp.isdir(osp.join(template_dir, obj_id)) 38 | ] 39 | obj_ids = sorted(obj_ids) 40 | logging.info(f"Found {obj_ids} objects in {self.template_dir}") 41 | self.obj_ids = obj_ids 42 | self.processing_config = processing_config 43 | self.rgb_transform = T.Compose( 44 | [ 45 | T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 46 | ] 47 | ) 48 | self.proposal_processor = CropResizePad(self.processing_config.image_size) 49 | self.load_template_poses(level_templates, pose_distribution) 50 | 51 | def __len__(self): 52 | return len(self.obj_ids) 53 | 54 | def load_template_poses(self, level_templates, 
pose_distribution): 55 | if pose_distribution == "all": 56 | self.index_templates = load_index_level_in_level2(level_templates, "all") 57 | else: 58 | raise NotImplementedError 59 | 60 | def __getitem__(self, idx): 61 | templates, masks, boxes = [], [], [] 62 | for id_template in self.index_templates: 63 | image = Image.open( 64 | f"{self.template_dir}/obj_{self.obj_ids[idx]:06d}/{id_template:06d}.png" 65 | ) 66 | boxes.append(image.getbbox()) 67 | 68 | mask = image.getchannel("A") 69 | mask = torch.from_numpy(np.array(mask) / 255).float() 70 | masks.append(mask.unsqueeze(-1)) 71 | 72 | image = torch.from_numpy(np.array(image.convert("RGB")) / 255).float() 73 | templates.append(image) 74 | 75 | templates = torch.stack(templates).permute(0, 3, 1, 2) 76 | masks = torch.stack(masks).permute(0, 3, 1, 2) 77 | boxes = torch.tensor(np.array(boxes)) 78 | templates_croped = self.proposal_processor(images=templates, boxes=boxes) 79 | masks_cropped = self.proposal_processor(images=masks, boxes=boxes) 80 | return { 81 | "templates": self.rgb_transform(templates_croped), 82 | "template_masks": masks_cropped[:, 0, :, :], 83 | } 84 | 85 | 86 | class BaseBOPTest(BaseBOP): 87 | def __init__( 88 | self, 89 | root_dir, 90 | split, 91 | **kwargs, 92 | ): 93 | self.root_dir = root_dir 94 | self.split = split 95 | self.load_list_scene(split=split) 96 | self.load_metaData(reset_metaData=True) 97 | # shuffle metadata 98 | self.metaData = self.metaData.sample(frac=1, random_state=2021).reset_index() 99 | self.camemra_params = {} 100 | for scenes in self.list_scenes: 101 | scenes_id = scenes.split("/")[-1] 102 | self.camemra_params.setdefault(scenes_id, load_json(osp.join(scenes, "scene_camera.json"))) 103 | self.rgb_transform = T.Compose( 104 | [ 105 | T.ToTensor(), 106 | T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 107 | ] 108 | ) 109 | 110 | def load_depth_img(self, idx): 111 | depth_path = self.metaData.iloc[idx]["depth_path"] 112 | if depth_path is None: 113 | if self.root_dir.split("/")[-1] == "itodd": 114 | depth_path = self.metaData.iloc[idx]["rgb_path"].replace("gray", "depth") 115 | else: 116 | depth_path = self.metaData.iloc[idx]["rgb_path"].replace("rgb", "depth") 117 | depth = np.array(imageio.imread(depth_path)) 118 | return depth 119 | 120 | def __getitem__(self, idx): 121 | rgb_path = self.metaData.iloc[idx]["rgb_path"] 122 | scene_id = self.metaData.iloc[idx]["scene_id"] 123 | frame_id = self.metaData.iloc[idx]["frame_id"] 124 | cam_intrinsic = self.metaData.iloc[idx]["intrinsic"] 125 | image = Image.open(rgb_path) 126 | image = self.rgb_transform(image.convert("RGB")) 127 | depth = self.load_depth_img(idx) 128 | cam_intrinsic = np.array(cam_intrinsic).reshape((3, 3)) 129 | depth_scale = self.camemra_params[scene_id][f"{frame_id}"]["depth_scale"] 130 | 131 | return dict( 132 | image=image, 133 | scene_id=scene_id, 134 | frame_id=frame_id, 135 | depth=depth.astype(np.int32), 136 | cam_intrinsic=cam_intrinsic, 137 | depth_scale=depth_scale, 138 | ) 139 | 140 | if __name__ == "__main__": 141 | logging.basicConfig(level=logging.INFO) 142 | from omegaconf import DictConfig, OmegaConf 143 | from torchvision.utils import make_grid, save_image 144 | 145 | processing_config = OmegaConf.create( 146 | { 147 | "image_size": 224, 148 | } 149 | ) 150 | inv_rgb_transform = T.Compose( 151 | [ 152 | T.Normalize( 153 | mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225], 154 | std=[1 / 0.229, 1 / 0.224, 1 / 0.225], 155 | ), 156 | ] 157 | ) 158 | dataset = BOPTemplate( 159 | 
template_dir="/home/nguyen/Documents/datasets/bop23/datasets/templates_pyrender/lmo", 160 | obj_ids=None, 161 | level_templates=0, 162 | pose_distribution="all", 163 | processing_config=processing_config, 164 | ) 165 | for idx in tqdm(range(len(dataset))): 166 | sample = dataset[idx] 167 | sample["templates"] = inv_rgb_transform(sample["templates"]) 168 | save_image(sample["templates"], f"./tmp/lm_{idx}.png", nrow=7) -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/run_inference.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import hydra 4 | from omegaconf import DictConfig, OmegaConf 5 | from hydra.utils import instantiate 6 | from torch.utils.data import DataLoader 7 | 8 | 9 | @hydra.main(version_base=None, config_path="configs", config_name="run_inference") 10 | def run_inference(cfg: DictConfig): 11 | OmegaConf.set_struct(cfg, False) 12 | hydra_cfg = hydra.core.hydra_config.HydraConfig.get() 13 | output_path = hydra_cfg["runtime"]["output_dir"] 14 | logging.info( 15 | f"Training script. The outputs of hydra will be stored in: {output_path}" 16 | ) 17 | logging.info("Initializing logger, callbacks and trainer") 18 | 19 | if cfg.machine.name == "slurm": 20 | num_gpus = int(os.environ["SLURM_GPUS_ON_NODE"]) 21 | num_nodes = int(os.environ["SLURM_NNODES"]) 22 | cfg.machine.trainer.devices = num_gpus 23 | cfg.machine.trainer.num_nodes = num_nodes 24 | logging.info(f"Slurm config: {num_gpus} gpus, {num_nodes} nodes") 25 | trainer = instantiate(cfg.machine.trainer) 26 | 27 | default_ref_dataloader_config = cfg.data.reference_dataloader 28 | default_query_dataloader_config = cfg.data.query_dataloader 29 | 30 | query_dataloader_config = default_query_dataloader_config.copy() 31 | ref_dataloader_config = default_ref_dataloader_config.copy() 32 | 33 | if cfg.dataset_name in ["hb", "tless"]: 34 | query_dataloader_config.split = "test_primesense" 35 | else: 36 | query_dataloader_config.split = "test" 37 | query_dataloader_config.root_dir += f"{cfg.dataset_name}" 38 | query_dataset = instantiate(query_dataloader_config) 39 | 40 | logging.info("Initializing model") 41 | model = instantiate(cfg.model) 42 | 43 | model.ref_obj_names = cfg.data.datasets[cfg.dataset_name].obj_names 44 | model.dataset_name = cfg.dataset_name 45 | 46 | query_dataloader = DataLoader( 47 | query_dataset, 48 | batch_size=1, # only support a single image for now 49 | num_workers=cfg.machine.num_workers, 50 | shuffle=False, 51 | ) 52 | if cfg.model.onboarding_config.rendering_type == "pyrender": 53 | ref_dataloader_config.template_dir += f"templates_pyrender/{cfg.dataset_name}" 54 | ref_dataset = instantiate(ref_dataloader_config) 55 | elif cfg.model.onboarding_config.rendering_type == "pbr": 56 | logging.info("Using BlenderProc for reference images") 57 | ref_dataloader_config._target_ = "provider.bop_pbr.BOPTemplatePBR" 58 | ref_dataloader_config.root_dir = f"{query_dataloader_config.root_dir}" 59 | ref_dataloader_config.template_dir += f"templates_pyrender/{cfg.dataset_name}" 60 | if not os.path.exists(ref_dataloader_config.template_dir): 61 | os.makedirs(ref_dataloader_config.template_dir) 62 | ref_dataset = instantiate(ref_dataloader_config) 63 | ref_dataset.load_processed_metaData(reset_metaData=True) 64 | else: 65 | raise NotImplementedError 66 | model.ref_dataset = ref_dataset 67 | 68 | segmentation_name = cfg.model.segmentor_model._target_.split(".")[-1] 69 | agg_function = 
cfg.model.matching_config.aggregation_function 70 | rendering_type = cfg.model.onboarding_config.rendering_type 71 | level_template = cfg.model.onboarding_config.level_templates 72 | model.name_prediction_file = f"result_{cfg.dataset_name}" 73 | logging.info(f"Loading dataloader for {cfg.dataset_name} done!") 74 | trainer.test( 75 | model, 76 | dataloaders=query_dataloader, 77 | ) 78 | logging.info(f"---" * 20) 79 | 80 | 81 | if __name__ == "__main__": 82 | logging.basicConfig(level=logging.INFO) 83 | run_inference() 84 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | from .automatic_mask_generator import SamAutomaticMaskGenerator 16 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_model_registry = { 48 | "default": build_sam_vit_h, 49 | "vit_h": build_sam_vit_h, 50 | "vit_l": build_sam_vit_l, 51 | "vit_b": build_sam_vit_b, 52 | } 53 | 54 | 55 | def _build_sam( 56 | encoder_embed_dim, 57 | encoder_depth, 58 | encoder_num_heads, 59 | encoder_global_attn_indexes, 60 | checkpoint=None, 61 | ): 62 | prompt_embed_dim = 256 63 | image_size = 1024 64 | vit_patch_size = 16 65 | image_embedding_size = image_size // vit_patch_size 66 | sam = Sam( 67 | image_encoder=ImageEncoderViT( 68 | depth=encoder_depth, 69 | embed_dim=encoder_embed_dim, 70 | img_size=image_size, 71 | mlp_ratio=4, 72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 73 | num_heads=encoder_num_heads, 74 | patch_size=vit_patch_size, 75 | qkv_bias=True, 76 | use_rel_pos=True, 77 | global_attn_indexes=encoder_global_attn_indexes, 78 | window_size=14, 79 | out_chans=prompt_embed_dim, 80 | ), 81 | prompt_encoder=PromptEncoder( 82 | 
embed_dim=prompt_embed_dim, 83 | image_embedding_size=(image_embedding_size, image_embedding_size), 84 | input_image_size=(image_size, image_size), 85 | mask_in_chans=16, 86 | ), 87 | mask_decoder=MaskDecoder( 88 | num_multimask_outputs=3, 89 | transformer=TwoWayTransformer( 90 | depth=2, 91 | embedding_dim=prompt_embed_dim, 92 | mlp_dim=2048, 93 | num_heads=8, 94 | ), 95 | transformer_dim=prompt_embed_dim, 96 | iou_head_depth=3, 97 | iou_head_hidden_dim=256, 98 | ), 99 | pixel_mean=[123.675, 116.28, 103.53], 100 | pixel_std=[58.395, 57.12, 57.375], 101 | ) 102 | sam.eval() 103 | if checkpoint is not None: 104 | with open(checkpoint, "rb") as f: 105 | state_dict = torch.load(f) 106 | sam.load_state_dict(state_dict) 107 | return sam 108 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/segment_anything/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture. 30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | image_pe: torch.Tensor, 75 | sparse_prompt_embeddings: torch.Tensor, 76 | dense_prompt_embeddings: torch.Tensor, 77 | multimask_output: bool, 78 | ) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """ 80 | Predict masks given image and prompt embeddings. 81 | 82 | Arguments: 83 | image_embeddings (torch.Tensor): the embeddings from the image encoder 84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 87 | multimask_output (bool): Whether to return multiple masks or a single 88 | mask. 
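            (If True, the single-mask output is dropped and the num_multimask_outputs candidate
            masks are returned; if False, only the single-mask output is kept.)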
89 | 90 | Returns: 91 | torch.Tensor: batched predicted masks 92 | torch.Tensor: batched predictions of mask quality 93 | """ 94 | masks, iou_pred = self.predict_masks( 95 | image_embeddings=image_embeddings, 96 | image_pe=image_pe, 97 | sparse_prompt_embeddings=sparse_prompt_embeddings, 98 | dense_prompt_embeddings=dense_prompt_embeddings, 99 | ) 100 | 101 | # Select the correct mask or masks for output 102 | if multimask_output: 103 | mask_slice = slice(1, None) 104 | else: 105 | mask_slice = slice(0, 1) 106 | masks = masks[:, mask_slice, :, :] 107 | iou_pred = iou_pred[:, mask_slice] 108 | 109 | # Prepare output 110 | return masks, iou_pred 111 | 112 | def predict_masks( 113 | self, 114 | image_embeddings: torch.Tensor, 115 | image_pe: torch.Tensor, 116 | sparse_prompt_embeddings: torch.Tensor, 117 | dense_prompt_embeddings: torch.Tensor, 118 | ) -> Tuple[torch.Tensor, torch.Tensor]: 119 | """Predicts masks. See 'forward' for more details.""" 120 | # Concatenate output tokens 121 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 122 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 123 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 124 | 125 | # Expand per-image data in batch direction to be per-mask 126 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 127 | src = src + dense_prompt_embeddings 128 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 129 | b, c, h, w = src.shape 130 | 131 | # Run the transformer 132 | hs, src = self.transformer(src, pos_src, tokens) 133 | iou_token_out = hs[:, 0, :] 134 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 135 | 136 | # Upscale mask embeddings and predict masks using the mask tokens 137 | src = src.transpose(1, 2).view(b, c, h, w) 138 | upscaled_embedding = self.output_upscaling(src) 139 | hyper_in_list: List[torch.Tensor] = [] 140 | for i in range(self.num_mask_tokens): 141 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 142 | hyper_in = torch.stack(hyper_in_list, dim=1) 143 | b, c, h, w = upscaled_embedding.shape 144 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 145 | 146 | # Generate mask quality predictions 147 | iou_pred = self.iou_prediction_head(iou_token_out) 148 | 149 | return masks, iou_pred 150 | 151 | 152 | # Lightly adapted from 153 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 154 | class MLP(nn.Module): 155 | def __init__( 156 | self, 157 | input_dim: int, 158 | hidden_dim: int, 159 | output_dim: int, 160 | num_layers: int, 161 | sigmoid_output: bool = False, 162 | ) -> None: 163 | super().__init__() 164 | self.num_layers = num_layers 165 | h = [hidden_dim] * (num_layers - 1) 166 | self.layers = nn.ModuleList( 167 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 168 | ) 169 | self.sigmoid_output = sigmoid_output 170 | 171 | def forward(self, x): 172 | for i, layer in enumerate(self.layers): 173 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 174 | if self.sigmoid_output: 175 | x = F.sigmoid(x) 176 | return x 177 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/segment_anything/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | 
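        # Crop away the square padding added during preprocessing before resizing back to the
        # original image resolution.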
prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) 85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 96 | # The reweighting is used to avoid control flow. 97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 
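    Example (illustrative):
        transform = ResizeLongestSide(1024)
        resized = transform.apply_image(image)                     # HxWxC uint8, longest side -> 1024
        coords = transform.apply_coords(coords, image.shape[:2])   # rescaled to the same frame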
21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 
97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/utils/inout.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | import shutil 4 | import numpy as np 5 | import json 6 | import sys 7 | 8 | try: 9 | import ruamel_yaml as yaml 10 | except ModuleNotFoundError: 11 | import ruamel.yaml as yaml 12 | 13 | 14 | def create_folder(path): 15 | try: 16 | os.mkdir(path) 17 | except OSError as exc: 18 | if exc.errno != errno.EEXIST: 19 | raise 20 | pass 21 | 22 | 23 | def del_folder(path): 24 | try: 25 | shutil.rmtree(path) 26 | except OSError as exc: 27 | pass 28 | 29 | 30 | def write_txt(path, list_files): 31 | with open(path, "w") as f: 32 | for idx in list_files: 33 | f.write(idx + "\n") 34 | f.close() 35 | 36 | 37 | def open_txt(path): 38 | with open(path, "r") as f: 39 | lines = f.readlines() 40 | lines = [line.strip() for line in lines] 41 | return lines 42 | 43 | 44 | def load_json(path): 45 | with open(path, "r") as f: 46 | # info = yaml.load(f, Loader=yaml.CLoader) 47 | info = json.load(f) 48 | return info 49 | 50 | 51 | def save_json(path, info): 52 | # save to json without sorting keys or changing format 53 | with open(path, "w") as f: 54 | json.dump(info, f, indent=4) 55 | 56 | 57 | def save_json_bop23(path, info): 58 | # save to json without sorting keys or changing format 59 | with open(path, "w") as f: 60 | json.dump(info, f) 61 | 62 | 63 | def save_npz(path, info): 64 | np.savez_compressed(path, **info) 65 | 66 | 67 | def casting_format_to_save_json(data): 68 | # casting for every keys in dict to list so that it can be saved as json 69 | for key in data.keys(): 70 | if ( 71 | isinstance(data[key][0], np.ndarray) 72 | or isinstance(data[key][0], np.float32) 73 | or isinstance(data[key][0], np.float64) 74 | or isinstance(data[key][0], np.int32) 75 | or isinstance(data[key][0], np.int64) 76 | ): 77 | data[key] = np.array(data[key]).tolist() 78 | return data 79 | 80 | 81 | def get_root_project(): 82 | return os.path.dirname(os.path.dirname((os.path.abspath(__file__)))) 83 | 84 | 85 | def append_lib(path): 86 | sys.path.append(os.path.join(path, "src")) -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/utils/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | class LevelsFilter(logging.Filter): 5 | def __init__(self, levels): 6 | self.levels = [getattr(logging, level) for level in levels] 7 | 8 | def filter(self, record): 9 | return record.levelno in self.levels 10 | 11 | 12 | class StreamToLogger(object): 13 | """ 14 | Fake file-like stream object that redirects writes to a logger instance. 
15 | """ 16 | 17 | def __init__(self, logger, level): 18 | self.logger = logger 19 | self.level = level 20 | self.linebuf = "" 21 | 22 | def write(self, buf): 23 | for line in buf.rstrip().splitlines(): 24 | self.logger.log(self.level, line.rstrip()) 25 | 26 | def flush(self): 27 | pass 28 | 29 | 30 | class TqdmLoggingHandler(logging.Handler): 31 | def __init__(self, level=logging.NOTSET): 32 | super().__init__(level) 33 | 34 | def emit(self, record): 35 | import tqdm 36 | 37 | try: 38 | msg = self.format(record) 39 | tqdm.tqdm.write(msg) 40 | self.flush() 41 | except Exception: 42 | self.handleError(record) 43 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/utils/poses/create_template_poses.py: -------------------------------------------------------------------------------- 1 | # blenderproc run poses/create_template_poses.py 2 | import blenderproc 3 | import bpy 4 | import bmesh 5 | import math 6 | import numpy as np 7 | import os 8 | 9 | 10 | def get_camera_positions(nSubDiv): 11 | """ 12 | * Construct an icosphere 13 | * subdived 14 | """ 15 | 16 | bpy.ops.mesh.primitive_ico_sphere_add(location=(0, 0, 0), enter_editmode=True) 17 | # bpy.ops.export_mesh.ply(filepath='./sphere.ply') 18 | icos = bpy.context.object 19 | me = icos.data 20 | 21 | # -- cut away lower part 22 | bm = bmesh.from_edit_mesh(me) 23 | sel = [v for v in bm.verts if v.co[2] < 0] 24 | 25 | bmesh.ops.delete(bm, geom=sel, context="FACES") 26 | bmesh.update_edit_mesh(me) 27 | 28 | # -- subdivide and move new vertices out to the surface of the sphere 29 | # nSubDiv = 3 30 | for i in range(nSubDiv): 31 | bpy.ops.mesh.subdivide() 32 | 33 | bm = bmesh.from_edit_mesh(me) 34 | for v in bm.verts: 35 | l = math.sqrt(v.co[0] ** 2 + v.co[1] ** 2 + v.co[2] ** 2) 36 | v.co[0] /= l 37 | v.co[1] /= l 38 | v.co[2] /= l 39 | bmesh.update_edit_mesh(me) 40 | 41 | # -- cut away zero elevation 42 | bm = bmesh.from_edit_mesh(me) 43 | sel = [v for v in bm.verts if v.co[2] <= 0] 44 | bmesh.ops.delete(bm, geom=sel, context="FACES") 45 | bmesh.update_edit_mesh(me) 46 | 47 | # convert vertex positions to az,el 48 | positions = [] 49 | angles = [] 50 | bm = bmesh.from_edit_mesh(me) 51 | for v in bm.verts: 52 | x = v.co[0] 53 | y = v.co[1] 54 | z = v.co[2] 55 | az = math.atan2(x, y) # *180./math.pi 56 | el = math.atan2(z, math.sqrt(x**2 + y**2)) # *180./math.pi 57 | # positions.append((az,el)) 58 | angles.append((el, az)) 59 | positions.append((x, y, z)) 60 | 61 | bpy.ops.object.editmode_toggle() 62 | 63 | # sort positions, first by az and el 64 | data = zip(angles, positions) 65 | positions = sorted(data) 66 | positions = [y for x, y in positions] 67 | angles = sorted(angles) 68 | return angles, positions 69 | 70 | 71 | def normalize(vec): 72 | return vec / (np.linalg.norm(vec, axis=-1, keepdims=True)) 73 | 74 | 75 | def look_at(cam_location, point): 76 | # Cam points in positive z direction 77 | forward = point - cam_location 78 | forward = normalize(forward) 79 | 80 | tmp = np.array([0.0, 0.0, -1.0]) 81 | # print warning when camera location is parallel to tmp 82 | norm = min( 83 | np.linalg.norm(cam_location - tmp, axis=-1), 84 | np.linalg.norm(cam_location + tmp, axis=-1), 85 | ) 86 | if norm < 1e-3: 87 | print("Warning: camera location is parallel to tmp") 88 | tmp = np.array([0.0, -1.0, 0.0]) 89 | 90 | right = np.cross(tmp, forward) 91 | right = normalize(right) 92 | 93 | up = np.cross(forward, right) 94 | up = normalize(up) 95 | 96 | mat = np.stack((right, up, forward, 
cam_location), axis=-1) 97 | 98 | hom_vec = np.array([[0.0, 0.0, 0.0, 1.0]]) 99 | 100 | if len(mat.shape) > 2: 101 | hom_vec = np.tile(hom_vec, [mat.shape[0], 1, 1]) 102 | 103 | mat = np.concatenate((mat, hom_vec), axis=-2) 104 | return mat 105 | 106 | 107 | def convert_location_to_rotation(locations): 108 | obj_poses = np.zeros((len(locations), 4, 4)) 109 | for idx, pt in enumerate(locations): 110 | obj_poses[idx] = look_at(pt, np.array([0, 0, 0])) 111 | return obj_poses 112 | 113 | 114 | def inverse_transform(poses): 115 | new_poses = np.zeros_like(poses) 116 | for idx_pose in range(len(poses)): 117 | rot = poses[idx_pose, :3, :3] 118 | t = poses[idx_pose, :3, 3] 119 | rot = np.transpose(rot) 120 | t = -np.matmul(rot, t) 121 | new_poses[idx_pose][3][3] = 1 122 | new_poses[idx_pose][:3, :3] = rot 123 | new_poses[idx_pose][:3, 3] = t 124 | return new_poses 125 | 126 | 127 | save_dir = "utils/poses/predefined_poses" 128 | if not os.path.exists(save_dir): 129 | os.makedirs(save_dir) 130 | 131 | for level in [0, 1, 2]: 132 | position_icosphere = np.asarray(get_camera_positions(level)[1]) 133 | cam_poses = convert_location_to_rotation(position_icosphere) 134 | cam_poses[:, :3, 3] *= 1000. 135 | np.save(f"{save_dir}/cam_poses_level{level}.npy", cam_poses) 136 | obj_poses = inverse_transform(cam_poses) 137 | np.save(f"{save_dir}/obj_poses_level{level}.npy", obj_poses) 138 | 139 | print("Output saved to: " + save_dir) 140 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/utils/poses/find_neighbors.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | # import open3d as o3d 3 | import os 4 | from utils.poses.pose_utils import ( 5 | get_obj_poses_from_template_level, 6 | get_root_project, 7 | NearestTemplateFinder, 8 | ) 9 | import os.path as osp 10 | # from utils.vis_3d_utils import convert_numpy_to_open3d, draw_camera 11 | 12 | if __name__ == "__main__": 13 | for template_level in range(2): 14 | templates_poses_level0 = get_obj_poses_from_template_level( 15 | template_level, "all", return_cam=True 16 | ) 17 | finder = NearestTemplateFinder( 18 | level_templates=2, 19 | pose_distribution="all", 20 | return_inplane=True, 21 | ) 22 | obj_poses_level = get_obj_poses_from_template_level(template_level, "all", return_cam=False) 23 | idx_templates, inplanes = finder.search_nearest_template(obj_poses_level) 24 | print(len(obj_poses_level), len(idx_templates)) 25 | root_repo = get_root_project() 26 | save_path = os.path.join(root_repo, f"utils/poses/predefined_poses/idx_all_level{template_level}_in_level2.npy") 27 | np.save(save_path, idx_templates) 28 | 29 | # level 2 in level 2 is just itself 30 | obj_poses_level = get_obj_poses_from_template_level(2, "all", return_cam=False) 31 | print(len(obj_poses_level)) 32 | save_path = os.path.join(root_repo, "utils/poses/predefined_poses/idx_all_level2_in_level2.npy") 33 | np.save(save_path, np.arange(len(obj_poses_level))) -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/utils/poses/fps.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | # credit: https://github.com/ziruiw-dev/farthest-point-sampling/blob/master/fps_v1.py 3 | 4 | class FPS: 5 | def __init__(self, pcd_xyz, n_samples): 6 | self.n_samples = n_samples 7 | self.pcd_xyz = pcd_xyz 8 | self.n_pts = pcd_xyz.shape[0] 9 | self.dim = pcd_xyz.shape[1] 10 | 
self.selected_pts = None 11 | self.selected_pts_expanded = np.zeros(shape=(n_samples, 1, self.dim)) 12 | self.remaining_pts = np.copy(pcd_xyz) 13 | 14 | self.grouping_radius = None 15 | self.dist_pts_to_selected = ( 16 | None # Iteratively updated in step(). Finally re-used in group() 17 | ) 18 | self.labels = None 19 | 20 | # Random pick a start 21 | self.start_idx = np.random.randint(low=0, high=self.n_pts - 1) 22 | self.selected_pts_expanded[0] = self.remaining_pts[self.start_idx] 23 | self.n_selected_pts = 1 24 | self.idx_selected = [self.start_idx] 25 | 26 | def get_selected_pts(self): 27 | self.selected_pts = np.squeeze(self.selected_pts_expanded, axis=1) 28 | return self.selected_pts 29 | 30 | def step(self): 31 | if self.n_selected_pts < self.n_samples: 32 | self.dist_pts_to_selected = self.__distance__( 33 | self.remaining_pts, self.selected_pts_expanded[: self.n_selected_pts] 34 | ).T 35 | dist_pts_to_selected_min = np.min( 36 | self.dist_pts_to_selected, axis=1, keepdims=True 37 | ) 38 | res_selected_idx = np.argmax(dist_pts_to_selected_min) 39 | self.selected_pts_expanded[self.n_selected_pts] = self.remaining_pts[ 40 | res_selected_idx 41 | ] 42 | 43 | self.n_selected_pts += 1 44 | 45 | # add to idx_selected 46 | self.idx_selected.append(res_selected_idx) 47 | else: 48 | print("Got enough number samples") 49 | 50 | def fit(self): 51 | for _ in range(1, self.n_samples): 52 | self.step() 53 | return self.get_selected_pts(), self.idx_selected 54 | 55 | def group(self, radius): 56 | self.grouping_radius = radius # the grouping radius is not actually used 57 | dists = self.dist_pts_to_selected 58 | 59 | # Ignore the "points"-"selected" relations if it's larger than the radius 60 | dists = np.where(dists > radius, dists + 1000000 * radius, dists) 61 | 62 | # Find the relation with the smallest distance. 63 | # NOTE: the smallest distance may still larger than the radius. 
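        # In practice the next line assigns every remaining point to its nearest
        # selected centre: the +1e6*radius penalty added above only demotes
        # out-of-radius candidates, so a point farther than `radius` from every
        # centre still receives the label of its closest centre rather than
        # being left unassigned.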
64 | self.labels = np.argmin(dists, axis=1) 65 | return self.labels 66 | 67 | @staticmethod 68 | def __distance__(a, b): 69 | return np.linalg.norm(a - b, ord=2, axis=2) 70 | 71 | 72 | if __name__ == "__main__": 73 | points = np.random.rand(1000, 3) 74 | sampled_points, idx_selected = FPS(points, 100).fit() 75 | print(sampled_points.shape, len(idx_selected)) 76 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/cam_poses_level0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/cam_poses_level0.npy -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/cam_poses_level1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/cam_poses_level1.npy -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/cam_poses_level2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/cam_poses_level2.npy -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/idx_all_level0_in_level2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/idx_all_level0_in_level2.npy -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/idx_all_level1_in_level2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/idx_all_level1_in_level2.npy -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/idx_all_level2_in_level2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/idx_all_level2_in_level2.npy -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/obj_poses_level0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/obj_poses_level0.npy -------------------------------------------------------------------------------- 
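The `cam_poses_level*.npy` / `obj_poses_level*.npy` files in this folder are the arrays written by `utils/poses/create_template_poses.py`: camera poses sampled on a subdivided icosphere (translations scaled to millimetres) and their inverses. A minimal sketch for inspecting them, assuming it is run from `Instance_Segmentation_Model/`:

import numpy as np

level = 2  # icosphere subdivision level: 0, 1 or 2
cam_poses = np.load(f"utils/poses/predefined_poses/cam_poses_level{level}.npy")  # (N, 4, 4), translations in mm
obj_poses = np.load(f"utils/poses/predefined_poses/obj_poses_level{level}.npy")  # (N, 4, 4), inverses of cam_poses

print(cam_poses.shape, obj_poses.shape)

# obj_poses were produced by inverse_transform(cam_poses), so each pair should
# compose to (approximately) the 4x4 identity.
residual = np.abs(cam_poses @ obj_poses - np.eye(4)).max()
print("max deviation from identity:", residual)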
/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/obj_poses_level1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/obj_poses_level1.npy -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/obj_poses_level2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/SAM-6D/Instance_Segmentation_Model/utils/poses/predefined_poses/obj_poses_level2.npy -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/utils/poses/pyrender.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pyrender 3 | import trimesh 4 | import os 5 | from PIL import Image 6 | import numpy as np 7 | import os.path as osp 8 | from tqdm import tqdm 9 | import argparse 10 | from utils.trimesh_utils import as_mesh 11 | from utils.trimesh_utils import get_obj_diameter 12 | os.environ["DISPLAY"] = ":1" 13 | os.environ["PYOPENGL_PLATFORM"] = "egl" 14 | 15 | 16 | def render( 17 | mesh, 18 | output_dir, 19 | obj_poses, 20 | img_size, 21 | intrinsic, 22 | light_itensity=0.6, 23 | is_tless=False, 24 | ): 25 | # camera pose is fixed as np.eye(4) 26 | cam_pose = np.eye(4) 27 | # convert openCV camera 28 | cam_pose[1, 1] = -1 29 | cam_pose[2, 2] = -1 30 | # create scene config 31 | ambient_light = np.array([0.02, 0.02, 0.02, 1.0]) # np.array([1.0, 1.0, 1.0, 1.0]) 32 | if light_itensity != 0.6: 33 | ambient_light = np.array([1.0, 1.0, 1.0, 1.0]) 34 | scene = pyrender.Scene( 35 | bg_color=np.array([0.0, 0.0, 0.0, 0.0]), ambient_light=ambient_light 36 | ) 37 | light = pyrender.SpotLight( 38 | color=np.ones(3), 39 | intensity=light_itensity, 40 | innerConeAngle=np.pi / 16.0, 41 | outerConeAngle=np.pi / 6.0, 42 | ) 43 | scene.add(light, pose=cam_pose) 44 | 45 | # create camera and render engine 46 | fx, fy, cx, cy = intrinsic[0][0], intrinsic[1][1], intrinsic[0][2], intrinsic[1][2] 47 | camera = pyrender.IntrinsicsCamera( 48 | fx=fx, fy=fy, cx=cx, cy=cy, znear=0.05, zfar=100000 49 | ) 50 | scene.add(camera, pose=cam_pose) 51 | render_engine = pyrender.OffscreenRenderer(img_size[1], img_size[0]) 52 | cad_node = scene.add(mesh, pose=np.eye(4), name="cad") 53 | 54 | for idx_frame in range(obj_poses.shape[0]): 55 | scene.set_pose(cad_node, obj_poses[idx_frame]) 56 | rgb, depth = render_engine.render(scene, pyrender.constants.RenderFlags.RGBA) 57 | rgb = Image.fromarray(np.uint8(rgb)) 58 | rgb.save(osp.join(output_dir, f"{idx_frame:06d}.png")) 59 | 60 | 61 | if __name__ == "__main__": 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument("cad_path", nargs="?", help="Path to the model file") 64 | parser.add_argument("obj_pose", nargs="?", help="Path to the model file") 65 | parser.add_argument( 66 | "output_dir", nargs="?", help="Path to where the final files will be saved" 67 | ) 68 | parser.add_argument("gpus_devices", nargs="?", help="GPU devices") 69 | parser.add_argument("disable_output", nargs="?", help="Disable output of blender") 70 | parser.add_argument("light_itensity", nargs="?", type=float, default=0.6, help="Light itensity") 71 | parser.add_argument("radius", nargs="?", type=float, 
default=1, help="Distance from camera to object") 72 | args = parser.parse_args() 73 | print(args) 74 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus_devices 75 | poses = np.load(args.obj_pose) 76 | # we can increase high energy for lightning but it's simpler to change just scale of the object to meter 77 | # poses[:, :3, :3] = poses[:, :3, :3] / 1000.0 78 | poses[:, :3, 3] = poses[:, :3, 3] / 1000.0 79 | if args.radius != 1: 80 | poses[:, :3, 3] = poses[:, :3, 3] * args.radius 81 | if "tless" in args.output_dir: 82 | intrinsic = np.asarray( 83 | [1075.65091572, 0.0, 360, 0.0, 1073.90347929, 270, 0.0, 0.0, 1.0] 84 | ).reshape(3, 3) 85 | img_size = [540, 720] 86 | is_tless = True 87 | else: 88 | intrinsic = np.array( 89 | [[572.4114, 0.0, 325.2611], [0.0, 573.57043, 242.04899], [0.0, 0.0, 1.0]] 90 | ) 91 | img_size = [480, 640] 92 | is_tless = False 93 | 94 | # load mesh to meter 95 | mesh = trimesh.load_mesh(args.cad_path) 96 | diameter = get_obj_diameter(mesh) 97 | if diameter > 100: # object is in mm 98 | mesh.apply_scale(0.001) 99 | if is_tless: 100 | # setting uniform colors for mesh 101 | color = 0.4 102 | mesh.visual.face_colors = np.ones((len(mesh.faces), 3)) * color 103 | mesh.visual.vertex_colors = np.ones((len(mesh.vertices), 3)) * color 104 | mesh = pyrender.Mesh.from_trimesh(mesh, smooth=False) 105 | else: 106 | mesh = pyrender.Mesh.from_trimesh(as_mesh(mesh)) 107 | os.makedirs(args.output_dir, exist_ok=True) 108 | render( 109 | output_dir=args.output_dir, 110 | mesh=mesh, 111 | obj_poses=poses, 112 | intrinsic=intrinsic, 113 | img_size=(480, 640), 114 | light_itensity=args.light_itensity, 115 | ) 116 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/utils/trimesh_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import trimesh 3 | import torch 4 | 5 | 6 | def load_mesh(path, ORIGIN_GEOMETRY="BOUNDS"): 7 | mesh = as_mesh(trimesh.load(path)) 8 | if ORIGIN_GEOMETRY == "BOUNDS": 9 | AABB = mesh.bounds 10 | center = np.mean(AABB, axis=0) 11 | mesh.vertices -= center 12 | return mesh 13 | 14 | 15 | def get_bbox_from_mesh(mesh): 16 | AABB = mesh.bounds 17 | OBB = AABB_to_OBB(AABB) 18 | return OBB 19 | 20 | 21 | def get_obj_diameter(mesh_path): 22 | mesh = load_mesh(mesh_path) 23 | extents = mesh.extents * 2 24 | return np.linalg.norm(extents) 25 | 26 | 27 | def as_mesh(scene_or_mesh): 28 | if isinstance(scene_or_mesh, trimesh.Scene): 29 | result = trimesh.util.concatenate( 30 | [ 31 | trimesh.Trimesh(vertices=m.vertices, faces=m.faces) 32 | for m in scene_or_mesh.geometry.values() 33 | ] 34 | ) 35 | else: 36 | result = scene_or_mesh 37 | return result 38 | 39 | 40 | def AABB_to_OBB(AABB): 41 | """ 42 | AABB bbox to oriented bounding box 43 | """ 44 | minx, miny, minz, maxx, maxy, maxz = np.arange(6) 45 | corner_index = np.array( 46 | [ 47 | minx, 48 | miny, 49 | minz, 50 | maxx, 51 | miny, 52 | minz, 53 | maxx, 54 | maxy, 55 | minz, 56 | minx, 57 | maxy, 58 | minz, 59 | minx, 60 | miny, 61 | maxz, 62 | maxx, 63 | miny, 64 | maxz, 65 | maxx, 66 | maxy, 67 | maxz, 68 | minx, 69 | maxy, 70 | maxz, 71 | ] 72 | ).reshape((-1, 3)) 73 | 74 | corners = AABB.reshape(-1)[corner_index] 75 | return corners 76 | 77 | def depth_image_to_pointcloud_translate_torch(depth, scale, K): 78 | u = torch.arange(0, depth.shape[2]) 79 | v = torch.arange(0, depth.shape[1]) 80 | 81 | u, v = torch.meshgrid(u, v, indexing="xy") 82 | u = u.to(depth.device) 83 | v = 
v.to(depth.device) 84 | 85 | # depth metric is mm, depth_scale metric is m 86 | # K metric is m 87 | Z = depth * scale / 1000 88 | X = (u - K[0, 2]) * Z / K[0, 0] 89 | Y = (v - K[1, 2]) * Z / K[1, 1] 90 | 91 | valid = Z > 0 92 | 93 | X = X * valid 94 | Y = Y * valid 95 | Z = Z * valid 96 | 97 | # average should run on valid point 98 | valid_num = torch.count_nonzero(valid, axis=(1, 2)) + 1e-8 99 | avg_X = torch.sum(X, axis=(1, 2)) / valid_num 100 | avg_Y = torch.sum(Y, axis=(1, 2)) / valid_num 101 | avg_Z = torch.sum(Z, axis=(1, 2)) / valid_num 102 | 103 | translate = torch.vstack((avg_X, avg_Y, avg_Z)).permute(1, 0) 104 | 105 | return translate 106 | 107 | 108 | if __name__ == "__main__": 109 | mesh_path = ( 110 | "/media/nguyen/Data/dataset/ShapeNet/ShapeNetCore.v2/" 111 | "03001627/1016f4debe988507589aae130c1f06fb/models/model_normalized.obj" 112 | ) 113 | mesh = load_mesh(mesh_path) 114 | bbox = get_bbox_from_mesh(mesh) 115 | # create a visualization scene with rays, hits, and mesh 116 | scene = trimesh.Scene([mesh, trimesh.points.PointCloud(bbox)]) 117 | # display the scene 118 | scene.show() 119 | -------------------------------------------------------------------------------- /SAM-6D/Instance_Segmentation_Model/utils/weight.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import numpy as np 4 | 5 | 6 | def load_checkpoint(model, checkpoint_path, checkpoint_key=None, prefix=""): 7 | checkpoint = torch.load(checkpoint_path) 8 | if checkpoint_key is not None: 9 | pretrained_dict = checkpoint[checkpoint_key] # "state_dict" 10 | else: 11 | pretrained_dict = checkpoint 12 | pretrained_dict = {k.replace(prefix, ""): v for k, v in pretrained_dict.items()} 13 | model_dict = model.state_dict() 14 | # compare keys and update value 15 | pretrained_dict_can_load = { 16 | k: v 17 | for k, v in pretrained_dict.items() 18 | if k in model_dict and v.shape == model_dict[k].shape 19 | } 20 | pretrained_dict_cannot_load = [ 21 | k for k, v in pretrained_dict.items() if k not in model_dict 22 | ] 23 | model_dict_not_update = [ 24 | k for k, v in model_dict.items() if k not in pretrained_dict 25 | ] 26 | module_cannot_load = np.unique( 27 | [k.split(".")[0] for k in pretrained_dict_cannot_load] # 28 | ) 29 | module_not_update = np.unique([k.split(".")[0] for k in model_dict_not_update]) # 30 | logging.info(f"Cannot load: {module_cannot_load}") 31 | logging.info(f"Not update: {module_not_update}") 32 | logging.info( 33 | f"Pretrained: {len(pretrained_dict)}/ Loaded: {len(pretrained_dict_can_load)}/ Cannot loaded: {len(pretrained_dict_cannot_load)} VS Current model: {len(model_dict)}" 34 | ) 35 | model_dict.update(pretrained_dict_can_load) 36 | model.load_state_dict(model_dict) 37 | logging.info("Load pretrained done!") 38 | -------------------------------------------------------------------------------- /SAM-6D/Pose_Estimation_Model/README.md: -------------------------------------------------------------------------------- 1 | # Pose Estimation Model (PEM) for SAM-6D 2 | 3 | 4 | 5 | ![image](https://github.com/JiehongLin/SAM-6D/blob/main/pics/overview_pem.png) 6 | 7 | ## Requirements 8 | The code has been tested with 9 | - python 3.9.6 10 | - pytorch 2.0.0 11 | - CUDA 11.3 12 | 13 | Other dependencies: 14 | 15 | ``` 16 | sh dependencies.sh 17 | ``` 18 | 19 | ## Data Preparation 20 | 21 | Please refer to [[link](https://github.com/JiehongLin/SAM-6D/tree/main/SAM-6D/Data)] for more details. 
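A minimal sketch for checking that the data folders referenced by the default `config/base.yaml` (`train_dataset.data_dir`, `test_dataset.data_dir` and `test_dataset.template_dir`) are in place before training or evaluation:
```
import os

# default locations taken from config/base.yaml, relative to Pose_Estimation_Model/
expected = [
    "../Data/MegaPose-Training-Data",  # training data
    "../Data/BOP",                     # BOP test datasets
    "../Data/BOP-Templates",           # rendered object templates
]
for path in expected:
    print(("found   " if os.path.isdir(path) else "MISSING ") + path)
```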
22 | 23 | 24 | ## Model Download 25 | Our trained model is provided [[here](https://drive.google.com/file/d/1joW9IvwsaRJYxoUmGo68dBVg-HcFNyI7/view?usp=sharing)], and could be downloaded via the command: 26 | ``` 27 | python download_sam6d-pem.py 28 | ``` 29 | 30 | ## Training on MegaPose Training Set 31 | 32 | To train the Pose Estimation Model of SAM-6D, please prepare the training data and run the folowing command: 33 | ``` 34 | python train.py --gpus 0,1,2,3 --model pose_estimation_model --config config/base.yaml 35 | ``` 36 | By default, we use four GPUs of 3090ti to train the model with batchsize set as 28. 37 | 38 | 39 | ## Evaluation on BOP Datasets 40 | 41 | To evaluate the model on BOP datasets, please run the following command: 42 | ``` 43 | python test_bop.py --gpus 0 --model pose_estimation_model --config config/base.yaml --dataset $DATASET --view 42 44 | ``` 45 | The string "DATASET" could be set as `lmo`, `icbin`, `itodd`, `hb`, `tless`, `tudl`, `ycbv`, or `all`. Before evaluation, please refer to [[link](https://github.com/JiehongLin/SAM-6D/tree/main/SAM-6D/Data)] for rendering the object templates of BOP datasets, or download our [rendered templates](https://drive.google.com/drive/folders/1fXt5Z6YDPZTJICZcywBUhu5rWnPvYAPI?usp=drive_link). Besides, the instance segmentation should be done following [[link](https://github.com/JiehongLin/SAM-6D/tree/main/SAM-6D/Instance_Segmentation_Model)]; to test on your own segmentation results, you could change the "detection_paths" in the `test_bop.py` file. 46 | 47 | One could also download our trained model for evaluation: 48 | ``` 49 | python test_bop.py --gpus 0 --model pose_estimation_model --config config/base.yaml --checkpoint_path checkpoints/sam-6d-pem-base.pth --dataset $DATASET --view 42 50 | ``` 51 | 52 | 53 | ## Acknowledgements 54 | - [MegaPose](https://github.com/megapose6d/megapose6d) 55 | - [GDRNPP](https://github.com/shanice-l/gdrnpp_bop2022) 56 | - [GeoTransformer](https://github.com/qinzheng93/GeoTransformer) 57 | - [Flatten Transformer](https://github.com/LeapLabTHU/FLatten-Transformer) 58 | 59 | -------------------------------------------------------------------------------- /SAM-6D/Pose_Estimation_Model/config/base.yaml: -------------------------------------------------------------------------------- 1 | NAME_PROJECT: SAM-6D 2 | 3 | optimizer: 4 | type : Adam 5 | lr : 0.0001 6 | betas: [0.5, 0.999] 7 | eps : 0.000001 8 | weight_decay: 0.0 9 | 10 | lr_scheduler: 11 | type: WarmupCosineLR 12 | max_iters: 600000 13 | warmup_factor: 0.001 14 | warmup_iters: 1000 15 | 16 | model: 17 | coarse_npoint: 196 18 | fine_npoint: 2048 19 | feature_extraction: 20 | vit_type: vit_base 21 | up_type: linear 22 | embed_dim: 768 23 | out_dim: 256 24 | use_pyramid_feat: True 25 | pretrained: True 26 | geo_embedding: 27 | sigma_d: 0.2 28 | sigma_a: 15 29 | angle_k: 3 30 | reduction_a: max 31 | hidden_dim: 256 32 | coarse_point_matching: 33 | nblock: 3 34 | input_dim: 256 35 | hidden_dim: 256 36 | out_dim: 256 37 | temp: 0.1 38 | sim_type: cosine 39 | normalize_feat: True 40 | loss_dis_thres: 0.15 41 | nproposal1: 6000 42 | nproposal2: 300 43 | fine_point_matching: 44 | nblock: 3 45 | input_dim: 256 46 | hidden_dim: 256 47 | out_dim: 256 48 | pe_radius1: 0.1 49 | pe_radius2: 0.2 50 | focusing_factor: 3 51 | temp: 0.1 52 | sim_type: cosine 53 | normalize_feat: True 54 | loss_dis_thres: 0.15 55 | 56 | 57 | 58 | train_dataset: 59 | name: training_dataset 60 | data_dir: ../Data/MegaPose-Training-Data 61 | img_size: 224 62 | 
n_sample_observed_point: 2048 63 | n_sample_model_point: 2048 64 | n_sample_template_point: 5000 65 | min_visib_fract: 0.1 66 | min_px_count_visib: 512 67 | shift_range: 0.01 68 | rgb_mask_flag: True 69 | dilate_mask: True 70 | 71 | train_dataloader: 72 | bs : 28 73 | num_workers : 24 74 | shuffle : True 75 | drop_last : True 76 | pin_memory : False 77 | 78 | 79 | 80 | test_dataset: 81 | name: bop_test_dataset 82 | data_dir: ../Data/BOP 83 | template_dir: ../Data/BOP-Templates 84 | img_size: 224 85 | n_sample_observed_point: 2048 86 | n_sample_model_point: 1024 87 | n_sample_template_point: 5000 88 | minimum_n_point: 8 89 | rgb_mask_flag: True 90 | seg_filter_score: 0.25 91 | n_template_view: 42 92 | 93 | 94 | test_dataloader: 95 | bs : 16 96 | num_workers : 16 97 | shuffle : False 98 | drop_last : False 99 | pin_memory : False 100 | 101 | 102 | rd_seed: 1 103 | training_epoch: 15 104 | iters_to_print: 50 105 | -------------------------------------------------------------------------------- /SAM-6D/Pose_Estimation_Model/dependencies.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | pip install timm 5 | pip install gorilla-core==0.2.7.8 6 | pip install trimesh==3.22.1 7 | pip install imgaug 8 | pip install opencv-python 9 | pip install gpustat==1.0.0 10 | pip install einops 11 | 12 | cd /model/pointnet2 13 | python setup.py install 14 | cd .. 15 | -------------------------------------------------------------------------------- /SAM-6D/Pose_Estimation_Model/download_sam6d-pem.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | 4 | def download_model(output_path): 5 | import os 6 | command = f"gdown --no-cookies --no-check-certificate -O '{output_path}/sam-6d-pem-base.pth' 1joW9IvwsaRJYxoUmGo68dBVg-HcFNyI7" 7 | os.system(command) 8 | 9 | def download() -> None: 10 | root_dir = os.path.dirname((os.path.abspath(__file__))) 11 | save_dir = osp.join(root_dir, "checkpoints") 12 | os.makedirs(save_dir, exist_ok=True) 13 | download_model(save_dir) 14 | 15 | if __name__ == "__main__": 16 | download() 17 | 18 | 19 | -------------------------------------------------------------------------------- /SAM-6D/Pose_Estimation_Model/model/coarse_point_matching.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformer import GeometricTransformer 5 | from model_utils import ( 6 | compute_feature_similarity, 7 | aug_pose_noise, 8 | compute_coarse_Rt, 9 | ) 10 | from loss_utils import compute_correspondence_loss 11 | 12 | 13 | 14 | class CoarsePointMatching(nn.Module): 15 | def __init__(self, cfg, return_feat=False): 16 | super(CoarsePointMatching, self).__init__() 17 | self.cfg = cfg 18 | self.return_feat = return_feat 19 | self.nblock = self.cfg.nblock 20 | 21 | self.in_proj = nn.Linear(cfg.input_dim, cfg.hidden_dim) 22 | self.out_proj = nn.Linear(cfg.hidden_dim, cfg.out_dim) 23 | 24 | self.bg_token = nn.Parameter(torch.randn(1, 1, cfg.hidden_dim) * .02) 25 | 26 | self.transformers = [] 27 | for _ in range(self.nblock): 28 | self.transformers.append(GeometricTransformer( 29 | blocks=['self', 'cross'], 30 | d_model = cfg.hidden_dim, 31 | num_heads = 4, 32 | dropout=None, 33 | activation_fn='ReLU', 34 | return_attention_scores=False, 35 | )) 36 | self.transformers = nn.ModuleList(self.transformers) 37 | 38 | def forward(self, p1, f1, geo1, p2, f2, geo2, radius, end_points): 39 | B = f1.size(0) 
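        # p1/f1: sampled points and per-point features of the observed (cropped) scene;
        # p2/f2: points and features sampled from the object CAD model (paired with
        #        end_points['model'] further below); geo1/geo2: geometric structure
        #        embeddings precomputed in pose_estimation_model.py; radius: per-instance
        #        scale used to normalise translations.
        # The learnable bg_token prepended to both feature sets below acts as a
        # "background" slot in the similarity matrices, giving unmatched points
        # something to match against.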
40 | 41 | f1 = self.in_proj(f1) 42 | f1 = torch.cat([self.bg_token.repeat(B,1,1), f1], dim=1) # adding bg 43 | f2 = self.in_proj(f2) 44 | f2 = torch.cat([self.bg_token.repeat(B,1,1), f2], dim=1) # adding bg 45 | 46 | atten_list = [] 47 | for idx in range(self.nblock): 48 | f1, f2 = self.transformers[idx](f1, geo1, f2, geo2) 49 | 50 | if self.training or idx==self.nblock-1: 51 | atten_list.append(compute_feature_similarity( 52 | self.out_proj(f1), 53 | self.out_proj(f2), 54 | self.cfg.sim_type, 55 | self.cfg.temp, 56 | self.cfg.normalize_feat 57 | )) 58 | 59 | if self.training: 60 | gt_R = end_points['rotation_label'] 61 | gt_t = end_points['translation_label'] / (radius.reshape(-1, 1)+1e-6) 62 | init_R, init_t = aug_pose_noise(gt_R, gt_t) 63 | 64 | end_points = compute_correspondence_loss( 65 | end_points, atten_list, p1, p2, gt_R, gt_t, 66 | dis_thres=self.cfg.loss_dis_thres, 67 | loss_str='coarse' 68 | ) 69 | else: 70 | init_R, init_t = compute_coarse_Rt( 71 | atten_list[-1], p1, p2, 72 | end_points['model'] / (radius.reshape(-1, 1, 1) + 1e-6), 73 | self.cfg.nproposal1, self.cfg.nproposal2, 74 | ) 75 | end_points['init_R'] = init_R 76 | end_points['init_t'] = init_t 77 | 78 | if self.return_feat: 79 | return end_points, self.out_proj(f1), self.out_proj(f2) 80 | else: 81 | return end_points 82 | 83 | -------------------------------------------------------------------------------- /SAM-6D/Pose_Estimation_Model/model/fine_point_matching.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from transformer import SparseToDenseTransformer 6 | from model_utils import compute_feature_similarity, compute_fine_Rt 7 | from loss_utils import compute_correspondence_loss 8 | from pointnet2_utils import QueryAndGroup 9 | from pytorch_utils import SharedMLP, Conv1d 10 | 11 | 12 | class FinePointMatching(nn.Module): 13 | def __init__(self, cfg, return_feat=False): 14 | super(FinePointMatching, self).__init__() 15 | self.cfg = cfg 16 | self.return_feat = return_feat 17 | self.nblock = self.cfg.nblock 18 | 19 | self.in_proj = nn.Linear(cfg.input_dim, cfg.hidden_dim) 20 | self.out_proj = nn.Linear(cfg.hidden_dim, cfg.out_dim) 21 | 22 | self.bg_token = nn.Parameter(torch.randn(1, 1, cfg.hidden_dim) * .02) 23 | self.PE = PositionalEncoding(cfg.hidden_dim, r1=cfg.pe_radius1, r2=cfg.pe_radius2) 24 | 25 | self.transformers = [] 26 | for _ in range(self.nblock): 27 | self.transformers.append(SparseToDenseTransformer( 28 | cfg.hidden_dim, 29 | num_heads=4, 30 | sparse_blocks=['self', 'cross'], 31 | dropout=None, 32 | activation_fn='ReLU', 33 | focusing_factor=cfg.focusing_factor, 34 | with_bg_token=True, 35 | replace_bg_token=True 36 | )) 37 | self.transformers = nn.ModuleList(self.transformers) 38 | 39 | def forward(self, p1, f1, geo1, fps_idx1, p2, f2, geo2, fps_idx2, radius, end_points): 40 | B = p1.size(0) 41 | 42 | init_R = end_points['init_R'] 43 | init_t = end_points['init_t'] 44 | p1_ = (p1 - init_t.unsqueeze(1)) @ init_R 45 | 46 | f1 = self.in_proj(f1) + self.PE(p1_) 47 | f1 = torch.cat([self.bg_token.repeat(B,1,1), f1], dim=1) # adding bg 48 | 49 | f2 = self.in_proj(f2) + self.PE(p2) 50 | f2 = torch.cat([self.bg_token.repeat(B,1,1), f2], dim=1) # adding bg 51 | 52 | atten_list = [] 53 | for idx in range(self.nblock): 54 | f1, f2 = self.transformers[idx](f1, geo1, fps_idx1, f2, geo2, fps_idx2) 55 | 56 | if self.training or idx==self.nblock-1: 57 | 
atten_list.append(compute_feature_similarity( 58 | self.out_proj(f1), 59 | self.out_proj(f2), 60 | self.cfg.sim_type, 61 | self.cfg.temp, 62 | self.cfg.normalize_feat 63 | )) 64 | 65 | if self.training: 66 | gt_R = end_points['rotation_label'] 67 | gt_t = end_points['translation_label'] / (radius.reshape(-1, 1)+1e-6) 68 | 69 | end_points = compute_correspondence_loss( 70 | end_points, atten_list, p1, p2, gt_R, gt_t, 71 | dis_thres=self.cfg.loss_dis_thres, 72 | loss_str='fine' 73 | ) 74 | else: 75 | pred_R, pred_t, pred_pose_score = compute_fine_Rt( 76 | atten_list[-1], p1, p2, 77 | end_points['model'] / (radius.reshape(-1, 1, 1) + 1e-6), 78 | ) 79 | end_points['pred_R'] = pred_R 80 | end_points['pred_t'] = pred_t * (radius.reshape(-1, 1)+1e-6) 81 | end_points['pred_pose_score'] = pred_pose_score 82 | 83 | if self.return_feat: 84 | return end_points, self.out_proj(f1), self.out_proj(f2) 85 | else: 86 | return end_points 87 | 88 | 89 | 90 | class PositionalEncoding(nn.Module): 91 | def __init__(self, out_dim, r1=0.1, r2=0.2, nsample1=32, nsample2=64, use_xyz=True, bn=True): 92 | super(PositionalEncoding, self).__init__() 93 | self.group1 = QueryAndGroup(r1, nsample1, use_xyz=use_xyz) 94 | self.group2 = QueryAndGroup(r2, nsample2, use_xyz=use_xyz) 95 | input_dim = 6 if use_xyz else 3 96 | 97 | self.mlp1 = SharedMLP([input_dim, 32, 64, 128], bn=bn) 98 | self.mlp2 = SharedMLP([input_dim, 32, 64, 128], bn=bn) 99 | self.mlp3 = Conv1d(256, out_dim, 1, activation=None, bn=None) 100 | 101 | def forward(self, pts1, pts2=None): 102 | if pts2 is None: 103 | pts2 = pts1 104 | 105 | # scale1 106 | feat1 = self.group1( 107 | pts1.contiguous(), pts2.contiguous(), pts1.transpose(1,2).contiguous() 108 | ) 109 | feat1 = self.mlp1(feat1) 110 | feat1 = F.max_pool2d( 111 | feat1, kernel_size=[1, feat1.size(3)] 112 | ) 113 | 114 | # scale2 115 | feat2 = self.group2( 116 | pts1.contiguous(), pts2.contiguous(), pts1.transpose(1,2).contiguous() 117 | ) 118 | feat2 = self.mlp2(feat2) 119 | feat2 = F.max_pool2d( 120 | feat2, kernel_size=[1, feat2.size(3)] 121 | ) 122 | 123 | feat = torch.cat([feat1, feat2], dim=1).squeeze(-1) 124 | feat = self.mlp3(feat).transpose(1,2) 125 | return feat 126 | 127 | -------------------------------------------------------------------------------- /SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/include/ball_query.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | 9 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 10 | const int nsample); 11 | -------------------------------------------------------------------------------- /SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/include/cuda_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #ifndef _CUDA_UTILS_H 7 | #define _CUDA_UTILS_H 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | #include 17 | 18 | #define TOTAL_THREADS 512 19 | 20 | inline int opt_n_threads(int work_size) { 21 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 22 | 23 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 24 | } 25 | 26 | inline dim3 opt_block_config(int x, int y) { 27 | const int x_threads = opt_n_threads(x); 28 | const int y_threads = 29 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 30 | dim3 block_config(x_threads, y_threads, 1); 31 | 32 | return block_config; 33 | } 34 | 35 | #define CUDA_CHECK_ERRORS() \ 36 | do { \ 37 | cudaError_t err = cudaGetLastError(); \ 38 | if (cudaSuccess != err) { \ 39 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 40 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 41 | __FILE__); \ 42 | exit(-1); \ 43 | } \ 44 | } while (0) 45 | 46 | #endif 47 | -------------------------------------------------------------------------------- /SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/include/group_points.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | 9 | at::Tensor group_points(at::Tensor points, at::Tensor idx); 10 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 11 | -------------------------------------------------------------------------------- /SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/include/interpolate.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | 8 | #include 9 | #include 10 | 11 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows); 12 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 13 | at::Tensor weight); 14 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 15 | at::Tensor weight, const int m); 16 | -------------------------------------------------------------------------------- /SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/include/sampling.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | 9 | at::Tensor gather_points(at::Tensor points, at::Tensor idx); 10 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 11 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples); 12 | -------------------------------------------------------------------------------- /SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/include/utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #pragma once 7 | #include 8 | #include 9 | 10 | #define CHECK_CUDA(x) \ 11 | do { \ 12 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor"); \ 13 | } while (0) 14 | 15 | #define CHECK_CONTIGUOUS(x) \ 16 | do { \ 17 | TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor"); \ 18 | } while (0) 19 | 20 | #define CHECK_IS_INT(x) \ 21 | do { \ 22 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \ 23 | #x " must be an int tensor"); \ 24 | } while (0) 25 | 26 | #define CHECK_IS_FLOAT(x) \ 27 | do { \ 28 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \ 29 | #x " must be a float tensor"); \ 30 | } while (0) 31 | -------------------------------------------------------------------------------- /SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include "ball_query.h" 7 | #include "utils.h" 8 | 9 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 10 | int nsample, const float *new_xyz, 11 | const float *xyz, int *idx); 12 | 13 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 14 | const int nsample) { 15 | CHECK_CONTIGUOUS(new_xyz); 16 | CHECK_CONTIGUOUS(xyz); 17 | CHECK_IS_FLOAT(new_xyz); 18 | CHECK_IS_FLOAT(xyz); 19 | 20 | if (new_xyz.type().is_cuda()) { 21 | CHECK_CUDA(xyz); 22 | } 23 | 24 | at::Tensor idx = 25 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, 26 | at::device(new_xyz.device()).dtype(at::ScalarType::Int)); 27 | 28 | if (new_xyz.type().is_cuda()) { 29 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), 30 | radius, nsample, new_xyz.data(), 31 | xyz.data(), idx.data()); 32 | } else { 33 | TORCH_CHECK(false, "CPU not supported"); 34 | } 35 | 36 | return idx; 37 | } 38 | -------------------------------------------------------------------------------- /SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include "cuda_utils.h" 11 | 12 | // input: new_xyz(b, m, 3) xyz(b, n, 3) 13 | // output: idx(b, m, nsample) 14 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius, 15 | int nsample, 16 | const float *__restrict__ new_xyz, 17 | const float *__restrict__ xyz, 18 | int *__restrict__ idx) { 19 | int batch_index = blockIdx.x; 20 | xyz += batch_index * n * 3; 21 | new_xyz += batch_index * m * 3; 22 | idx += m * nsample * batch_index; 23 | 24 | int index = threadIdx.x; 25 | int stride = blockDim.x; 26 | 27 | float radius2 = radius * radius; 28 | for (int j = index; j < m; j += stride) { 29 | float new_x = new_xyz[j * 3 + 0]; 30 | float new_y = new_xyz[j * 3 + 1]; 31 | float new_z = new_xyz[j * 3 + 2]; 32 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { 33 | float x = xyz[k * 3 + 0]; 34 | float y = xyz[k * 3 + 1]; 35 | float z = xyz[k * 3 + 2]; 36 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + 37 | (new_z - z) * (new_z - z); 38 | if (d2 < radius2) { 39 | if (cnt == 0) { 40 | for (int l = 0; l < nsample; ++l) { 41 | idx[j * nsample + l] = k; 42 | } 43 | } 44 | idx[j * nsample + cnt] = k; 45 | ++cnt; 46 | } 47 | } 48 | } 49 | } 50 | 51 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 52 | int nsample, const float *new_xyz, 53 | const float *xyz, int *idx) { 54 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 55 | query_ball_point_kernel<<>>( 56 | b, n, m, radius, nsample, new_xyz, xyz, idx); 57 | 58 | CUDA_CHECK_ERRORS(); 59 | } 60 | -------------------------------------------------------------------------------- /SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include "ball_query.h" 7 | #include "group_points.h" 8 | #include "interpolate.h" 9 | #include "sampling.h" 10 | 11 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 12 | m.def("gather_points", &gather_points); 13 | m.def("gather_points_grad", &gather_points_grad); 14 | m.def("furthest_point_sampling", &furthest_point_sampling); 15 | 16 | m.def("three_nn", &three_nn); 17 | m.def("three_interpolate", &three_interpolate); 18 | m.def("three_interpolate_grad", &three_interpolate_grad); 19 | 20 | m.def("ball_query", &ball_query); 21 | 22 | m.def("group_points", &group_points); 23 | m.def("group_points_grad", &group_points_grad); 24 | } 25 | -------------------------------------------------------------------------------- /SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #include "group_points.h" 7 | #include "utils.h" 8 | 9 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 10 | const float *points, const int *idx, 11 | float *out); 12 | 13 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 14 | int nsample, const float *grad_out, 15 | const int *idx, float *grad_points); 16 | 17 | at::Tensor group_points(at::Tensor points, at::Tensor idx) { 18 | CHECK_CONTIGUOUS(points); 19 | CHECK_CONTIGUOUS(idx); 20 | CHECK_IS_FLOAT(points); 21 | CHECK_IS_INT(idx); 22 | 23 | if (points.type().is_cuda()) { 24 | CHECK_CUDA(idx); 25 | } 26 | 27 | at::Tensor output = 28 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)}, 29 | at::device(points.device()).dtype(at::ScalarType::Float)); 30 | 31 | if (points.type().is_cuda()) { 32 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 33 | idx.size(1), idx.size(2), points.data(), 34 | idx.data(), output.data()); 35 | } else { 36 | TORCH_CHECK(false, "CPU not supported"); 37 | } 38 | 39 | return output; 40 | } 41 | 42 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) { 43 | CHECK_CONTIGUOUS(grad_out); 44 | CHECK_CONTIGUOUS(idx); 45 | CHECK_IS_FLOAT(grad_out); 46 | CHECK_IS_INT(idx); 47 | 48 | if (grad_out.type().is_cuda()) { 49 | CHECK_CUDA(idx); 50 | } 51 | 52 | at::Tensor output = 53 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 54 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 55 | 56 | if (grad_out.type().is_cuda()) { 57 | group_points_grad_kernel_wrapper( 58 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2), 59 | grad_out.data(), idx.data(), output.data()); 60 | } else { 61 | TORCH_CHECK(false, "CPU not supported"); 62 | } 63 | 64 | return output; 65 | } 66 | -------------------------------------------------------------------------------- /SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #include 7 | #include 8 | 9 | #include "cuda_utils.h" 10 | 11 | // input: points(b, c, n) idx(b, npoints, nsample) 12 | // output: out(b, c, npoints, nsample) 13 | __global__ void group_points_kernel(int b, int c, int n, int npoints, 14 | int nsample, 15 | const float *__restrict__ points, 16 | const int *__restrict__ idx, 17 | float *__restrict__ out) { 18 | int batch_index = blockIdx.x; 19 | points += batch_index * n * c; 20 | idx += batch_index * npoints * nsample; 21 | out += batch_index * npoints * nsample * c; 22 | 23 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 24 | const int stride = blockDim.y * blockDim.x; 25 | for (int i = index; i < c * npoints; i += stride) { 26 | const int l = i / npoints; 27 | const int j = i % npoints; 28 | for (int k = 0; k < nsample; ++k) { 29 | int ii = idx[j * nsample + k]; 30 | out[(l * npoints + j) * nsample + k] = points[l * n + ii]; 31 | } 32 | } 33 | } 34 | 35 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 36 | const float *points, const int *idx, 37 | float *out) { 38 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 39 | 40 | group_points_kernel<<>>( 41 | b, c, n, npoints, nsample, points, idx, out); 42 | 43 | CUDA_CHECK_ERRORS(); 44 | } 45 | 46 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample) 47 | // output: grad_points(b, c, n) 48 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints, 49 | int nsample, 50 | const float *__restrict__ grad_out, 51 | const int *__restrict__ idx, 52 | float *__restrict__ grad_points) { 53 | int batch_index = blockIdx.x; 54 | grad_out += batch_index * npoints * nsample * c; 55 | idx += batch_index * npoints * nsample; 56 | grad_points += batch_index * n * c; 57 | 58 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 59 | const int stride = blockDim.y * blockDim.x; 60 | for (int i = index; i < c * npoints; i += stride) { 61 | const int l = i / npoints; 62 | const int j = i % npoints; 63 | for (int k = 0; k < nsample; ++k) { 64 | int ii = idx[j * nsample + k]; 65 | atomicAdd(grad_points + l * n + ii, 66 | grad_out[(l * npoints + j) * nsample + k]); 67 | } 68 | } 69 | } 70 | 71 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 72 | int nsample, const float *grad_out, 73 | const int *idx, float *grad_points) { 74 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 75 | 76 | group_points_grad_kernel<<>>( 77 | b, c, n, npoints, nsample, grad_out, idx, grad_points); 78 | 79 | CUDA_CHECK_ERRORS(); 80 | } 81 | -------------------------------------------------------------------------------- /SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #include "interpolate.h" 7 | #include "utils.h" 8 | 9 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 10 | const float *known, float *dist2, int *idx); 11 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 12 | const float *points, const int *idx, 13 | const float *weight, float *out); 14 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 15 | const float *grad_out, 16 | const int *idx, const float *weight, 17 | float *grad_points); 18 | 19 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows) { 20 | CHECK_CONTIGUOUS(unknowns); 21 | CHECK_CONTIGUOUS(knows); 22 | CHECK_IS_FLOAT(unknowns); 23 | CHECK_IS_FLOAT(knows); 24 | 25 | if (unknowns.type().is_cuda()) { 26 | CHECK_CUDA(knows); 27 | } 28 | 29 | at::Tensor idx = 30 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 31 | at::device(unknowns.device()).dtype(at::ScalarType::Int)); 32 | at::Tensor dist2 = 33 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 34 | at::device(unknowns.device()).dtype(at::ScalarType::Float)); 35 | 36 | if (unknowns.type().is_cuda()) { 37 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1), 38 | unknowns.data(), knows.data(), 39 | dist2.data(), idx.data()); 40 | } else { 41 | TORCH_CHECK(false, "CPU not supported"); 42 | } 43 | 44 | return {dist2, idx}; 45 | } 46 | 47 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 48 | at::Tensor weight) { 49 | CHECK_CONTIGUOUS(points); 50 | CHECK_CONTIGUOUS(idx); 51 | CHECK_CONTIGUOUS(weight); 52 | CHECK_IS_FLOAT(points); 53 | CHECK_IS_INT(idx); 54 | CHECK_IS_FLOAT(weight); 55 | 56 | if (points.type().is_cuda()) { 57 | CHECK_CUDA(idx); 58 | CHECK_CUDA(weight); 59 | } 60 | 61 | at::Tensor output = 62 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 63 | at::device(points.device()).dtype(at::ScalarType::Float)); 64 | 65 | if (points.type().is_cuda()) { 66 | three_interpolate_kernel_wrapper( 67 | points.size(0), points.size(1), points.size(2), idx.size(1), 68 | points.data(), idx.data(), weight.data(), 69 | output.data()); 70 | } else { 71 | TORCH_CHECK(false, "CPU not supported"); 72 | } 73 | 74 | return output; 75 | } 76 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 77 | at::Tensor weight, const int m) { 78 | CHECK_CONTIGUOUS(grad_out); 79 | CHECK_CONTIGUOUS(idx); 80 | CHECK_CONTIGUOUS(weight); 81 | CHECK_IS_FLOAT(grad_out); 82 | CHECK_IS_INT(idx); 83 | CHECK_IS_FLOAT(weight); 84 | 85 | if (grad_out.type().is_cuda()) { 86 | CHECK_CUDA(idx); 87 | CHECK_CUDA(weight); 88 | } 89 | 90 | at::Tensor output = 91 | torch::zeros({grad_out.size(0), grad_out.size(1), m}, 92 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 93 | 94 | if (grad_out.type().is_cuda()) { 95 | three_interpolate_grad_kernel_wrapper( 96 | grad_out.size(0), grad_out.size(1), grad_out.size(2), m, 97 | grad_out.data(), idx.data(), weight.data(), 98 | output.data()); 99 | } else { 100 | TORCH_CHECK(false, "CPU not supported"); 101 | } 102 | 103 | return output; 104 | } 105 | -------------------------------------------------------------------------------- /SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/src/interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include "cuda_utils.h" 11 | 12 | // input: unknown(b, n, 3) known(b, m, 3) 13 | // output: dist2(b, n, 3), idx(b, n, 3) 14 | __global__ void three_nn_kernel(int b, int n, int m, 15 | const float *__restrict__ unknown, 16 | const float *__restrict__ known, 17 | float *__restrict__ dist2, 18 | int *__restrict__ idx) { 19 | int batch_index = blockIdx.x; 20 | unknown += batch_index * n * 3; 21 | known += batch_index * m * 3; 22 | dist2 += batch_index * n * 3; 23 | idx += batch_index * n * 3; 24 | 25 | int index = threadIdx.x; 26 | int stride = blockDim.x; 27 | for (int j = index; j < n; j += stride) { 28 | float ux = unknown[j * 3 + 0]; 29 | float uy = unknown[j * 3 + 1]; 30 | float uz = unknown[j * 3 + 2]; 31 | 32 | double best1 = 1e40, best2 = 1e40, best3 = 1e40; 33 | int besti1 = 0, besti2 = 0, besti3 = 0; 34 | for (int k = 0; k < m; ++k) { 35 | float x = known[k * 3 + 0]; 36 | float y = known[k * 3 + 1]; 37 | float z = known[k * 3 + 2]; 38 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 39 | if (d < best1) { 40 | best3 = best2; 41 | besti3 = besti2; 42 | best2 = best1; 43 | besti2 = besti1; 44 | best1 = d; 45 | besti1 = k; 46 | } else if (d < best2) { 47 | best3 = best2; 48 | besti3 = besti2; 49 | best2 = d; 50 | besti2 = k; 51 | } else if (d < best3) { 52 | best3 = d; 53 | besti3 = k; 54 | } 55 | } 56 | dist2[j * 3 + 0] = best1; 57 | dist2[j * 3 + 1] = best2; 58 | dist2[j * 3 + 2] = best3; 59 | 60 | idx[j * 3 + 0] = besti1; 61 | idx[j * 3 + 1] = besti2; 62 | idx[j * 3 + 2] = besti3; 63 | } 64 | } 65 | 66 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 67 | const float *known, float *dist2, int *idx) { 68 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 69 | three_nn_kernel<<>>(b, n, m, unknown, known, 70 | dist2, idx); 71 | 72 | CUDA_CHECK_ERRORS(); 73 | } 74 | 75 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3) 76 | // output: out(b, c, n) 77 | __global__ void three_interpolate_kernel(int b, int c, int m, int n, 78 | const float *__restrict__ points, 79 | const int *__restrict__ idx, 80 | const float *__restrict__ weight, 81 | float *__restrict__ out) { 82 | int batch_index = blockIdx.x; 83 | points += batch_index * m * c; 84 | 85 | idx += batch_index * n * 3; 86 | weight += batch_index * n * 3; 87 | 88 | out += batch_index * n * c; 89 | 90 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 91 | const int stride = blockDim.y * blockDim.x; 92 | for (int i = index; i < c * n; i += stride) { 93 | const int l = i / n; 94 | const int j = i % n; 95 | float w1 = weight[j * 3 + 0]; 96 | float w2 = weight[j * 3 + 1]; 97 | float w3 = weight[j * 3 + 2]; 98 | 99 | int i1 = idx[j * 3 + 0]; 100 | int i2 = idx[j * 3 + 1]; 101 | int i3 = idx[j * 3 + 2]; 102 | 103 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 + 104 | points[l * m + i3] * w3; 105 | } 106 | } 107 | 108 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 109 | const float *points, const int *idx, 110 | const float *weight, float *out) { 111 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 112 | three_interpolate_kernel<<>>( 113 | b, c, m, n, points, idx, weight, out); 114 | 115 | CUDA_CHECK_ERRORS(); 116 | } 117 | 118 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3) 119 | // output: grad_points(b, c, m) 120 | 121 | __global__ void three_interpolate_grad_kernel( 122 | int b, int c, int n, int m, const float *__restrict__ grad_out, 123 | const int 
*__restrict__ idx, const float *__restrict__ weight, 124 | float *__restrict__ grad_points) { 125 | int batch_index = blockIdx.x; 126 | grad_out += batch_index * n * c; 127 | idx += batch_index * n * 3; 128 | weight += batch_index * n * 3; 129 | grad_points += batch_index * m * c; 130 | 131 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 132 | const int stride = blockDim.y * blockDim.x; 133 | for (int i = index; i < c * n; i += stride) { 134 | const int l = i / n; 135 | const int j = i % n; 136 | float w1 = weight[j * 3 + 0]; 137 | float w2 = weight[j * 3 + 1]; 138 | float w3 = weight[j * 3 + 2]; 139 | 140 | int i1 = idx[j * 3 + 0]; 141 | int i2 = idx[j * 3 + 1]; 142 | int i3 = idx[j * 3 + 2]; 143 | 144 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1); 145 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2); 146 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3); 147 | } 148 | } 149 | 150 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 151 | const float *grad_out, 152 | const int *idx, const float *weight, 153 | float *grad_points) { 154 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 155 | three_interpolate_grad_kernel<<>>( 156 | b, c, n, m, grad_out, idx, weight, grad_points); 157 | 158 | CUDA_CHECK_ERRORS(); 159 | } 160 | -------------------------------------------------------------------------------- /SAM-6D/Pose_Estimation_Model/model/pointnet2/_ext_src/src/sampling.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include "sampling.h" 7 | #include "utils.h" 8 | 9 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 10 | const float *points, const int *idx, 11 | float *out); 12 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 13 | const float *grad_out, const int *idx, 14 | float *grad_points); 15 | 16 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 17 | const float *dataset, float *temp, 18 | int *idxs); 19 | 20 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) { 21 | CHECK_CONTIGUOUS(points); 22 | CHECK_CONTIGUOUS(idx); 23 | CHECK_IS_FLOAT(points); 24 | CHECK_IS_INT(idx); 25 | 26 | if (points.type().is_cuda()) { 27 | CHECK_CUDA(idx); 28 | } 29 | 30 | at::Tensor output = 31 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 32 | at::device(points.device()).dtype(at::ScalarType::Float)); 33 | 34 | if (points.type().is_cuda()) { 35 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 36 | idx.size(1), points.data(), 37 | idx.data(), output.data()); 38 | } else { 39 | TORCH_CHECK(false, "CPU not supported"); 40 | } 41 | 42 | return output; 43 | } 44 | 45 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, 46 | const int n) { 47 | CHECK_CONTIGUOUS(grad_out); 48 | CHECK_CONTIGUOUS(idx); 49 | CHECK_IS_FLOAT(grad_out); 50 | CHECK_IS_INT(idx); 51 | 52 | if (grad_out.type().is_cuda()) { 53 | CHECK_CUDA(idx); 54 | } 55 | 56 | at::Tensor output = 57 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 58 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 59 | 60 | if (grad_out.type().is_cuda()) { 61 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n, 62 | idx.size(1), grad_out.data(), 63 | idx.data(), output.data()); 64 | } else { 
65 | TORCH_CHECK(false, "CPU not supported"); 66 | } 67 | 68 | return output; 69 | } 70 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) { 71 | CHECK_CONTIGUOUS(points); 72 | CHECK_IS_FLOAT(points); 73 | 74 | at::Tensor output = 75 | torch::zeros({points.size(0), nsamples}, 76 | at::device(points.device()).dtype(at::ScalarType::Int)); 77 | 78 | at::Tensor tmp = 79 | torch::full({points.size(0), points.size(1)}, 1e10, 80 | at::device(points.device()).dtype(at::ScalarType::Float)); 81 | 82 | if (points.type().is_cuda()) { 83 | furthest_point_sampling_kernel_wrapper( 84 | points.size(0), points.size(1), nsamples, points.data<float>(), 85 | tmp.data<float>(), output.data<int>()); 86 | } else { 87 | TORCH_CHECK(false, "CPU not supported"); 88 | } 89 | 90 | return output; 91 | } 92 | -------------------------------------------------------------------------------- /SAM-6D/Pose_Estimation_Model/model/pointnet2/pointnet2_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | ''' Testing customized ops. ''' 7 | 8 | import torch 9 | from torch.autograd import gradcheck 10 | import numpy as np 11 | 12 | import os 13 | import sys 14 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 15 | sys.path.append(BASE_DIR) 16 | import pointnet2_utils 17 | 18 | def test_interpolation_grad(): 19 | batch_size = 1 20 | feat_dim = 2 21 | m = 4 22 | feats = torch.randn(batch_size, feat_dim, m, requires_grad=True).float().cuda() 23 | 24 | def interpolate_func(inputs): 25 | idx = torch.from_numpy(np.array([[[0,1,2],[1,2,3]]])).int().cuda() 26 | weight = torch.from_numpy(np.array([[[1,1,1],[2,2,2]]])).float().cuda() 27 | interpolated_feats = pointnet2_utils.three_interpolate(inputs, idx, weight) 28 | return interpolated_feats 29 | 30 | assert (gradcheck(interpolate_func, feats, atol=1e-1, rtol=1e-1)) 31 | 32 | if __name__=='__main__': 33 | test_interpolation_grad() 34 | -------------------------------------------------------------------------------- /SAM-6D/Pose_Estimation_Model/model/pointnet2/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree.
5 | import os 6 | from setuptools import setup, find_packages 7 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 8 | import glob 9 | 10 | _ext_src_root = "_ext_src" 11 | _ext_sources = glob.glob("{}/src/*.cpp".format(_ext_src_root)) + glob.glob( 12 | "{}/src/*.cu".format(_ext_src_root) 13 | ) 14 | _ext_headers = glob.glob("{}/include/*".format(_ext_src_root)) 15 | 16 | setup( 17 | name='pointnet2', 18 | packages = find_packages(), 19 | ext_modules=[ 20 | CUDAExtension( 21 | name='pointnet2._ext', 22 | sources=_ext_sources, 23 | include_dirs = [os.path.join(_ext_src_root, "include")], 24 | extra_compile_args={ 25 | # "cxx": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))], 26 | # "nvcc": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))], 27 | "cxx": [], 28 | "nvcc": ["-O3", 29 | "-DCUDA_HAS_FP16=1", 30 | "-D__CUDA_NO_HALF_OPERATORS__", 31 | "-D__CUDA_NO_HALF_CONVERSIONS__", 32 | "-D__CUDA_NO_HALF2_OPERATORS__", 33 | ]},) 34 | ], 35 | cmdclass={'build_ext': BuildExtension.with_options(use_ninja=True)} 36 | ) 37 | -------------------------------------------------------------------------------- /SAM-6D/Pose_Estimation_Model/model/pose_estimation_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from feature_extraction import ViTEncoder 5 | from coarse_point_matching import CoarsePointMatching 6 | from fine_point_matching import FinePointMatching 7 | from transformer import GeometricStructureEmbedding 8 | from model_utils import sample_pts_feats 9 | 10 | 11 | class Net(nn.Module): 12 | def __init__(self, cfg): 13 | super(Net, self).__init__() 14 | self.cfg = cfg 15 | self.coarse_npoint = cfg.coarse_npoint 16 | self.fine_npoint = cfg.fine_npoint 17 | 18 | self.feature_extraction = ViTEncoder(cfg.feature_extraction, self.fine_npoint) 19 | self.geo_embedding = GeometricStructureEmbedding(cfg.geo_embedding) 20 | self.coarse_point_matching = CoarsePointMatching(cfg.coarse_point_matching) 21 | self.fine_point_matching = FinePointMatching(cfg.fine_point_matching) 22 | 23 | def forward(self, end_points): 24 | dense_pm, dense_fm, dense_po, dense_fo, radius = self.feature_extraction(end_points) 25 | 26 | # pre-compute geometric embeddings for geometric transformer 27 | bg_point = torch.ones(dense_pm.size(0),1,3).float().to(dense_pm.device) * 100 28 | 29 | sparse_pm, sparse_fm, fps_idx_m = sample_pts_feats( 30 | dense_pm, dense_fm, self.coarse_npoint, return_index=True 31 | ) 32 | geo_embedding_m = self.geo_embedding(torch.cat([bg_point, sparse_pm], dim=1)) 33 | 34 | sparse_po, sparse_fo, fps_idx_o = sample_pts_feats( 35 | dense_po, dense_fo, self.coarse_npoint, return_index=True 36 | ) 37 | geo_embedding_o = self.geo_embedding(torch.cat([bg_point, sparse_po], dim=1)) 38 | 39 | # coarse_point_matching 40 | end_points = self.coarse_point_matching( 41 | sparse_pm, sparse_fm, geo_embedding_m, 42 | sparse_po, sparse_fo, geo_embedding_o, 43 | radius, end_points, 44 | ) 45 | 46 | # fine_point_matching 47 | end_points = self.fine_point_matching( 48 | dense_pm, dense_fm, geo_embedding_m, fps_idx_m, 49 | dense_po, dense_fo, geo_embedding_o, fps_idx_o, 50 | radius, end_points 51 | ) 52 | 53 | return end_points 54 | 55 | -------------------------------------------------------------------------------- /SAM-6D/Pose_Estimation_Model/train.py: -------------------------------------------------------------------------------- 1 | 2 | import gorilla 3 | from tqdm import tqdm 4 | import argparse 
5 | import os 6 | import sys 7 | import os.path as osp 8 | import time 9 | import logging 10 | import numpy as np 11 | import random 12 | import importlib 13 | 14 | import torch 15 | from torch.autograd import Variable 16 | import torch.optim as optim 17 | 18 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 19 | sys.path.append(os.path.join(BASE_DIR, 'provider')) 20 | sys.path.append(os.path.join(BASE_DIR, 'utils')) 21 | sys.path.append(os.path.join(BASE_DIR, 'model')) 22 | sys.path.append(os.path.join(BASE_DIR, 'model', 'pointnet2')) 23 | 24 | from solver import Solver, get_logger 25 | from loss_utils import Loss 26 | 27 | def get_parser(): 28 | parser = argparse.ArgumentParser( 29 | description="Pose Estimation") 30 | 31 | parser.add_argument("--gpus", 32 | type=str, 33 | default="0", 34 | help="index of gpu") 35 | parser.add_argument("--model", 36 | type=str, 37 | default="pose_estimation_model", 38 | help="name of model") 39 | parser.add_argument("--config", 40 | type=str, 41 | default="config/base.yaml", 42 | help="path to config file") 43 | parser.add_argument("--exp_id", 44 | type=int, 45 | default=0, 46 | help="experiment id") 47 | parser.add_argument("--checkpoint_iter", 48 | type=int, 49 | default=-1, 50 | help="iter num. of checkpoint") 51 | args_cfg = parser.parse_args() 52 | 53 | return args_cfg 54 | 55 | 56 | def init(): 57 | args = get_parser() 58 | exp_name = args.model + '_' + \ 59 | osp.splitext(args.config.split("/")[-1])[0] + '_id' + str(args.exp_id) 60 | log_dir = osp.join("log", exp_name) 61 | 62 | cfg = gorilla.Config.fromfile(args.config) 63 | cfg.exp_name = exp_name 64 | cfg.gpus = args.gpus 65 | cfg.model_name = args.model 66 | cfg.log_dir = log_dir 67 | cfg.checkpoint_iter = args.checkpoint_iter 68 | 69 | if not os.path.isdir(log_dir): 70 | os.makedirs(log_dir) 71 | logger = get_logger( 72 | level_print=logging.INFO, level_save=logging.WARNING, path_file=log_dir+"/training_logger.log") 73 | gorilla.utils.set_cuda_visible_devices(gpu_ids=cfg.gpus) 74 | 75 | return logger, cfg 76 | 77 | 78 | if __name__ == "__main__": 79 | logger, cfg = init() 80 | 81 | logger.warning( 82 | "************************ Start Logging ************************") 83 | logger.info(cfg) 84 | logger.info("using gpu: {}".format(cfg.gpus)) 85 | 86 | random.seed(cfg.rd_seed) 87 | torch.manual_seed(cfg.rd_seed) 88 | 89 | # model 90 | logger.info("=> creating model ...") 91 | MODEL = importlib.import_module(cfg.model_name) 92 | model = MODEL.Net(cfg.model) 93 | if hasattr(cfg, 'pretrain_dir') and cfg.pretrain_dir is not None: 94 | logger.info('loading pretrained backbone from {}'.format(cfg.pretrain_dir)) 95 | key1, key2 = model.load_state_dict(torch.load(cfg.pretrain_dir)['model'], strict=False) 96 | if len(cfg.gpus) > 1: 97 | model = torch.nn.DataParallel(model, range(len(cfg.gpus.split(",")))) 98 | model = model.cuda() 99 | 100 | loss = Loss().cuda() 101 | count_parameters = sum(gorilla.parameter_count(model).values()) 102 | logger.warning("#Total parameters : {}".format(count_parameters)) 103 | 104 | # dataloader 105 | batchsize = cfg.train_dataloader.bs 106 | num_epoch = cfg.training_epoch 107 | 108 | if cfg.lr_scheduler.type == 'WarmupCosineLR': 109 | num_iter = cfg.lr_scheduler.max_iters 110 | if hasattr(cfg, 'warmup_iter') and cfg.warmup_iter >0: 111 | num_iter = num_iter + cfg.warmup_iter 112 | iters_per_epoch = int(np.floor(num_iter / num_epoch)) 113 | elif cfg.lr_scheduler.type == 'CyclicLR': 114 | iters_per_epoch = cfg.lr_scheduler.step_size_up+cfg.lr_scheduler.step_size_down 
115 | train_dataset = importlib.import_module(cfg.train_dataset.name) 116 | train_dataset = train_dataset.Dataset(cfg.train_dataset, iters_per_epoch*batchsize) 117 | 118 | 119 | train_dataloader = torch.utils.data.DataLoader( 120 | train_dataset, 121 | batch_size=cfg.train_dataloader.bs, 122 | num_workers=cfg.train_dataloader.num_workers, 123 | shuffle=cfg.train_dataloader.shuffle, 124 | sampler=None, 125 | drop_last=cfg.train_dataloader.drop_last, 126 | pin_memory=cfg.train_dataloader.pin_memory, 127 | ) 128 | 129 | dataloaders = { 130 | "train": train_dataloader, 131 | } 132 | 133 | # solver 134 | Trainer = Solver(model=model, loss=loss, 135 | dataloaders=dataloaders, 136 | logger=logger, 137 | cfg=cfg) 138 | Trainer.solve() 139 | 140 | logger.info('\nFinish!\n') 141 | -------------------------------------------------------------------------------- /SAM-6D/Pose_Estimation_Model/utils/bop_object_utils.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | Modified from https://github.com/rasmushaugaard/surfemb/blob/master/surfemb/data/obj.py 4 | ''' 5 | 6 | import os 7 | import glob 8 | import json 9 | import numpy as np 10 | import trimesh 11 | from tqdm import tqdm 12 | 13 | from data_utils import ( 14 | load_im, 15 | ) 16 | 17 | class Obj: 18 | def __init__( 19 | self, obj_id, 20 | mesh: trimesh.Trimesh, 21 | model_points, 22 | diameter: float, 23 | symmetry_flag: int, 24 | template_path: str, 25 | n_template_view: int, 26 | ): 27 | self.obj_id = obj_id 28 | self.mesh = mesh 29 | self.model_points = model_points 30 | self.diameter = diameter 31 | self.symmetry_flag = symmetry_flag 32 | self._get_template(template_path, n_template_view) 33 | 34 | def get_item(self, return_color=False, sample_num=2048): 35 | if return_color: 36 | model_points, _, model_colors = trimesh.sample.sample_surface(self.mesh, sample_num, sample_color=True) 37 | model_points = model_points.astype(np.float32) / 1000.0 38 | return (model_points, model_colors, self.symmetry_flag) 39 | else: 40 | return (self.model_points, self.symmetry_flag) 41 | 42 | def _get_template(self, path, nView): 43 | if nView > 0: 44 | total_nView = len(glob.glob(os.path.join(path, 'rgb_*.png'))) 45 | 46 | self.template = [] 47 | self.template_mask = [] 48 | self.template_pts = [] 49 | 50 | for v in range(nView): 51 | i = int(total_nView / nView * v) 52 | rgb_path = os.path.join(path, 'rgb_'+str(i)+'.png') 53 | xyz_path = os.path.join(path, 'xyz_'+str(i)+'.npy') 54 | mask_path = os.path.join(path, 'mask_'+str(i)+'.png') 55 | 56 | rgb = load_im(rgb_path).astype(np.uint8) 57 | xyz = np.load(xyz_path).astype(np.float32) / 1000.0 58 | mask = load_im(mask_path).astype(np.uint8) == 255 59 | 60 | self.template.append(rgb) 61 | self.template_mask.append(mask) 62 | self.template_pts.append(xyz) 63 | else: 64 | self.template = None 65 | self.template_choose = None 66 | self.template_pts = None 67 | 68 | def get_template(self, view_idx): 69 | return self.template[view_idx], self.template_mask[view_idx], self.template_pts[view_idx] 70 | 71 | 72 | def load_obj( 73 | model_path, obj_id: int, sample_num: int, 74 | template_path: str, 75 | n_template_view: int, 76 | ): 77 | models_info = json.load(open(os.path.join(model_path, 'models_info.json'))) 78 | mesh = trimesh.load_mesh(os.path.join(model_path, f'obj_{obj_id:06d}.ply')) 79 | model_points = mesh.sample(sample_num).astype(np.float32) / 1000.0 80 | diameter = models_info[str(obj_id)]['diameter'] / 1000.0 81 | if 'symmetries_continuous' in 
models_info[str(obj_id)]: 82 | symmetry_flag = 1 83 | elif 'symmetries_discrete' in models_info[str(obj_id)]: 84 | symmetry_flag = 1 85 | else: 86 | symmetry_flag = 0 87 | return Obj( 88 | obj_id, mesh, model_points, diameter, symmetry_flag, 89 | template_path, n_template_view 90 | ) 91 | 92 | 93 | def load_objs( 94 | model_path='models', 95 | template_path='render_imgs', 96 | sample_num=512, 97 | n_template_view=0, 98 | show_progressbar=True 99 | ): 100 | objs = [] 101 | obj_ids = sorted([int(p.split('/')[-1][4:10]) for p in glob.glob(os.path.join(model_path, '*.ply'))]) 102 | 103 | if n_template_view>0: 104 | template_paths = sorted(glob.glob(os.path.join(template_path, '*'))) 105 | assert len(template_paths) == len(obj_ids), '{} template_paths, {} obj_ids'.format(len(template_paths), len(obj_ids)) 106 | else: 107 | template_paths = [None for _ in range(len(obj_ids))] 108 | 109 | cnt = 0 110 | for obj_id in tqdm(obj_ids, 'loading objects') if show_progressbar else obj_ids: 111 | objs.append( 112 | load_obj(model_path, obj_id, sample_num, 113 | template_paths[cnt], n_template_view) 114 | ) 115 | cnt+=1 116 | return objs, obj_ids 117 | 118 | -------------------------------------------------------------------------------- /SAM-6D/Pose_Estimation_Model/utils/draw_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import cv2 4 | 5 | def calculate_2d_projections(coordinates_3d, intrinsics): 6 | """ 7 | Input: 8 | coordinates: [3, N] 9 | intrinsics: [3, 3] 10 | Return 11 | projected_coordinates: [N, 2] 12 | """ 13 | projected_coordinates = intrinsics @ coordinates_3d 14 | projected_coordinates = projected_coordinates[:2, :] / projected_coordinates[2, :] 15 | projected_coordinates = projected_coordinates.transpose() 16 | projected_coordinates = np.array(projected_coordinates, dtype=np.int32) 17 | 18 | return projected_coordinates 19 | 20 | def get_3d_bbox(scale, shift = 0): 21 | """ 22 | Input: 23 | scale: [3] or scalar 24 | shift: [3] or scalar 25 | Return 26 | bbox_3d: [3, N] 27 | 28 | """ 29 | if hasattr(scale, "__iter__"): 30 | bbox_3d = np.array([[scale[0] / 2, +scale[1] / 2, scale[2] / 2], 31 | [scale[0] / 2, +scale[1] / 2, -scale[2] / 2], 32 | [-scale[0] / 2, +scale[1] / 2, scale[2] / 2], 33 | [-scale[0] / 2, +scale[1] / 2, -scale[2] / 2], 34 | [+scale[0] / 2, -scale[1] / 2, scale[2] / 2], 35 | [+scale[0] / 2, -scale[1] / 2, -scale[2] / 2], 36 | [-scale[0] / 2, -scale[1] / 2, scale[2] / 2], 37 | [-scale[0] / 2, -scale[1] / 2, -scale[2] / 2]]) + shift 38 | else: 39 | bbox_3d = np.array([[scale / 2, +scale / 2, scale / 2], 40 | [scale / 2, +scale / 2, -scale / 2], 41 | [-scale / 2, +scale / 2, scale / 2], 42 | [-scale / 2, +scale / 2, -scale / 2], 43 | [+scale / 2, -scale / 2, scale / 2], 44 | [+scale / 2, -scale / 2, -scale / 2], 45 | [-scale / 2, -scale / 2, scale / 2], 46 | [-scale / 2, -scale / 2, -scale / 2]]) +shift 47 | 48 | bbox_3d = bbox_3d.transpose() 49 | return bbox_3d 50 | 51 | def draw_3d_bbox(img, imgpts, color, size=3): 52 | imgpts = np.int32(imgpts).reshape(-1, 2) 53 | 54 | # draw ground layer in darker color 55 | color_ground = (int(color[0] * 0.3), int(color[1] * 0.3), int(color[2] * 0.3)) 56 | for i, j in zip([4, 5, 6, 7],[5, 7, 4, 6]): 57 | img = cv2.line(img, tuple(imgpts[i]), tuple(imgpts[j]), color_ground, size) 58 | 59 | # draw pillars in blue color 60 | color_pillar = (int(color[0]*0.6), int(color[1]*0.6), int(color[2]*0.6)) 61 | for i, j in zip(range(4),range(4,8)): 62 | img = 
cv2.line(img, tuple(imgpts[i]), tuple(imgpts[j]), color_pillar, size) 63 | 64 | # finally, draw top layer in color 65 | for i, j in zip([0, 1, 2, 3],[1, 3, 0, 2]): 66 | img = cv2.line(img, tuple(imgpts[i]), tuple(imgpts[j]), color, size) 67 | return img 68 | 69 | def draw_3d_pts(img, imgpts, color, size=1): 70 | imgpts = np.int32(imgpts).reshape(-1, 2) 71 | for point in imgpts: 72 | img = cv2.circle(img, (point[0], point[1]), size, color, -1) 73 | return img 74 | 75 | def draw_detections(image, pred_rots, pred_trans, model_points, intrinsics, color=(255, 0, 0)): 76 | num_pred_instances = len(pred_rots) 77 | draw_image_bbox = image.copy() 78 | # 3d bbox 79 | scale = (np.max(model_points, axis=0) - np.min(model_points, axis=0)) 80 | shift = np.mean(model_points, axis=0) 81 | bbox_3d = get_3d_bbox(scale, shift) 82 | 83 | # 3d point 84 | choose = np.random.choice(np.arange(len(model_points)), 512) 85 | pts_3d = model_points[choose].T 86 | 87 | for ind in range(num_pred_instances): 88 | # draw 3d bounding box 89 | transformed_bbox_3d = pred_rots[ind]@bbox_3d + pred_trans[ind][:,np.newaxis] 90 | projected_bbox = calculate_2d_projections(transformed_bbox_3d, intrinsics[ind]) 91 | draw_image_bbox = draw_3d_bbox(draw_image_bbox, projected_bbox, color) 92 | # draw point cloud 93 | transformed_pts_3d = pred_rots[ind]@pts_3d + pred_trans[ind][:,np.newaxis] 94 | projected_pts = calculate_2d_projections(transformed_pts_3d, intrinsics[ind]) 95 | draw_image_bbox = draw_3d_pts(draw_image_bbox, projected_pts, color) 96 | 97 | return draw_image_bbox 98 | -------------------------------------------------------------------------------- /SAM-6D/Pose_Estimation_Model/utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | from model_utils import pairwise_distance 6 | 7 | def compute_correspondence_loss( 8 | end_points, 9 | atten_list, 10 | pts1, 11 | pts2, 12 | gt_r, 13 | gt_t, 14 | dis_thres=0.15, 15 | loss_str='coarse' 16 | ): 17 | CE = nn.CrossEntropyLoss(reduction ='none') 18 | 19 | gt_pts = (pts1-gt_t.unsqueeze(1))@gt_r 20 | dis_mat = torch.sqrt(pairwise_distance(gt_pts, pts2)) 21 | 22 | dis1, label1 = dis_mat.min(2) 23 | fg_label1 = (dis1<=dis_thres).float() 24 | label1 = (fg_label1 * (label1.float()+1.0)).long() 25 | 26 | dis2, label2 = dis_mat.min(1) 27 | fg_label2 = (dis2<=dis_thres).float() 28 | label2 = (fg_label2 * (label2.float()+1.0)).long() 29 | 30 | # loss 31 | for idx, atten in enumerate(atten_list): 32 | l1 = CE(atten.transpose(1,2)[:,:,1:].contiguous(), label1).mean(1) 33 | l2 = CE(atten[:,:,1:].contiguous(), label2).mean(1) 34 | end_points[loss_str + '_loss' + str(idx)] = 0.5 * (l1 + l2) 35 | 36 | # acc 37 | pred_label = torch.max(atten_list[-1][:,1:,:], dim=2)[1] 38 | end_points[loss_str + '_acc'] = (pred_label==label1).float().mean(1) 39 | 40 | # pred foreground num 41 | fg_mask = (pred_label > 0).float() 42 | end_points[loss_str + '_fg_num'] = fg_mask.sum(1) 43 | 44 | # foreground point dis 45 | fg_label = fg_mask * (pred_label - 1) 46 | fg_label = fg_label.long() 47 | pred_pts = torch.gather(pts2, 1, fg_label.unsqueeze(2).repeat(1,1,3)) 48 | pred_dis = torch.norm(pred_pts-gt_pts, dim=2) 49 | pred_dis = (pred_dis * fg_mask).sum(1) / (fg_mask.sum(1)+1e-8) 50 | end_points[loss_str + '_dis'] = pred_dis 51 | 52 | return end_points 53 | 54 | 55 | 56 | class Loss(nn.Module): 57 | def __init__(self): 58 | super(Loss, self).__init__() 59 | 60 | def forward(self, end_points): 61 | out_dicts = 
{'loss': 0} 62 | for key in end_points.keys(): 63 | if 'coarse_' in key or 'fine_' in key: 64 | out_dicts[key] = end_points[key].mean() 65 | if 'loss' in key: 66 | out_dicts['loss'] = out_dicts['loss'] + end_points[key] 67 | out_dicts['loss'] = torch.clamp(out_dicts['loss'], max=100.0).mean() 68 | return out_dicts 69 | 70 | -------------------------------------------------------------------------------- /SAM-6D/Render/render_bop_templates.py: -------------------------------------------------------------------------------- 1 | import blenderproc as bproc 2 | 3 | import os 4 | import argparse 5 | import json 6 | import cv2 7 | import numpy as np 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--dataset_name', help="The name of bop datasets") 11 | args = parser.parse_args() 12 | 13 | # set relative path of Data folder 14 | render_dir = os.path.dirname(os.path.abspath(__file__)) 15 | bop_path = os.path.join(render_dir, '../Data/BOP') 16 | output_dir = os.path.join(render_dir, '../Data/BOP-Templates') 17 | cnos_cam_fpath = os.path.join(render_dir, '../Instance_Segmentation_Model/utils/poses/predefined_poses/cam_poses_level0.npy') 18 | 19 | bproc.init() 20 | 21 | if args.dataset_name == 'tless': 22 | model_folder = 'models_cad' 23 | else: 24 | model_folder = 'models' 25 | 26 | model_path = os.path.join(bop_path, args.dataset_name, model_folder) 27 | models_info = json.load(open(os.path.join(model_path, 'models_info.json'))) 28 | for obj_id in models_info.keys(): 29 | diameter = models_info[obj_id]['diameter'] 30 | scale = 1 / diameter 31 | obj_fpath = os.path.join(model_path, f'obj_{int(obj_id):06d}.ply') 32 | 33 | cam_poses = np.load(cnos_cam_fpath) 34 | for idx, cam_pose in enumerate(cam_poses[:]): 35 | 36 | bproc.clean_up() 37 | 38 | # load object 39 | obj = bproc.loader.load_obj(obj_fpath)[0] 40 | obj.set_scale([scale, scale, scale]) 41 | obj.set_cp("category_id", obj_id) 42 | 43 | if args.dataset_name == 'tless': 44 | color = [0.4, 0.4, 0.4, 0.] 
45 | material = bproc.material.create('obj') 46 | material.set_principled_shader_value('Base Color', color) 47 | obj.set_material(0, material) 48 | 49 | # convert cnos camera poses to blender camera poses 50 | cam_pose[:3, 1:3] = -cam_pose[:3, 1:3] 51 | cam_pose[:3, -1] = cam_pose[:3, -1] * 0.001 * 2 52 | bproc.camera.add_camera_pose(cam_pose) 53 | 54 | # set light 55 | light_energy = 1000 56 | light_scale = 2.5 57 | light1 = bproc.types.Light() 58 | light1.set_type("POINT") 59 | light1.set_location([light_scale*cam_pose[:3, -1][0], light_scale*cam_pose[:3, -1][1], light_scale*cam_pose[:3, -1][2]]) 60 | light1.set_energy(light_energy) 61 | 62 | bproc.renderer.set_max_amount_of_samples(1) 63 | # render the whole pipeline 64 | data = bproc.renderer.render() 65 | # render nocs 66 | data.update(bproc.renderer.render_nocs()) 67 | 68 | # check save folder 69 | save_fpath = os.path.join(output_dir, args.dataset_name, f'obj_{int(obj_id):06d}') 70 | if not os.path.exists(save_fpath): 71 | os.makedirs(save_fpath) 72 | 73 | # save rgb image 74 | color_bgr_0 = data["colors"][0] 75 | color_bgr_0[..., :3] = color_bgr_0[..., :3][..., ::-1] 76 | cv2.imwrite(os.path.join(save_fpath,'rgb_'+str(idx)+'.png'), color_bgr_0) 77 | 78 | # save mask 79 | mask_0 = data["nocs"][0][..., -1] 80 | cv2.imwrite(os.path.join(save_fpath,'mask_'+str(idx)+'.png'), mask_0*255) 81 | 82 | # save nocs 83 | xyz_0 = 2*(data["nocs"][0][..., :3] - 0.5) 84 | np.save(os.path.join(save_fpath,'xyz_'+str(idx)+'.npy'), xyz_0.astype(np.float16)) 85 | -------------------------------------------------------------------------------- /SAM-6D/Render/render_custom_templates.py: -------------------------------------------------------------------------------- 1 | import blenderproc as bproc 2 | 3 | import os 4 | import argparse 5 | import cv2 6 | import numpy as np 7 | import trimesh 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--cad_path', help="The path of CAD model") 11 | parser.add_argument('--output_dir', help="The path to save CAD templates") 12 | parser.add_argument('--normalize', default=True, help="Whether to normalize CAD model or not") 13 | parser.add_argument('--colorize', default=False, help="Whether to colorize CAD model or not") 14 | parser.add_argument('--base_color', default=0.05, help="The base color used in CAD model") 15 | args = parser.parse_args() 16 | 17 | # set the cnos camera path 18 | render_dir = os.path.dirname(os.path.abspath(__file__)) 19 | cnos_cam_fpath = os.path.join(render_dir, '../Instance_Segmentation_Model/utils/poses/predefined_poses/cam_poses_level0.npy') 20 | 21 | bproc.init() 22 | 23 | def get_norm_info(mesh_path): 24 | mesh = trimesh.load(mesh_path, force='mesh') 25 | 26 | model_points = trimesh.sample.sample_surface(mesh, 1024)[0] 27 | model_points = model_points.astype(np.float32) 28 | 29 | min_value = np.min(model_points, axis=0) 30 | max_value = np.max(model_points, axis=0) 31 | 32 | radius = max(np.linalg.norm(max_value), np.linalg.norm(min_value)) 33 | 34 | return 1/(2*radius) 35 | 36 | 37 | # load cnos camera pose 38 | cam_poses = np.load(cnos_cam_fpath) 39 | 40 | # calculating the scale of CAD model 41 | if args.normalize: 42 | scale = get_norm_info(args.cad_path) 43 | else: 44 | scale = 1 45 | 46 | for idx, cam_pose in enumerate(cam_poses): 47 | 48 | bproc.clean_up() 49 | 50 | # load object 51 | obj = bproc.loader.load_obj(args.cad_path)[0] 52 | obj.set_scale([scale, scale, scale]) 53 | obj.set_cp("category_id", 1) 54 | 55 | # assigning material colors to untextured objects 
56 | if args.colorize: 57 | color = [args.base_color, args.base_color, args.base_color, 0.] 58 | material = bproc.material.create('obj') 59 | material.set_principled_shader_value('Base Color', color) 60 | obj.set_material(0, material) 61 | 62 | # convert cnos camera poses to blender camera poses 63 | cam_pose[:3, 1:3] = -cam_pose[:3, 1:3] 64 | cam_pose[:3, -1] = cam_pose[:3, -1] * 0.001 * 2 65 | bproc.camera.add_camera_pose(cam_pose) 66 | 67 | # set light 68 | light_scale = 2.5 69 | light_energy = 1000 70 | light1 = bproc.types.Light() 71 | light1.set_type("POINT") 72 | light1.set_location([light_scale*cam_pose[:3, -1][0], light_scale*cam_pose[:3, -1][1], light_scale*cam_pose[:3, -1][2]]) 73 | light1.set_energy(light_energy) 74 | 75 | bproc.renderer.set_max_amount_of_samples(50) 76 | # render the whole pipeline 77 | data = bproc.renderer.render() 78 | # render nocs 79 | data.update(bproc.renderer.render_nocs()) 80 | 81 | # check save folder 82 | save_fpath = os.path.join(args.output_dir, "templates") 83 | if not os.path.exists(save_fpath): 84 | os.makedirs(save_fpath) 85 | 86 | # save rgb image 87 | color_bgr_0 = data["colors"][0] 88 | color_bgr_0[..., :3] = color_bgr_0[..., :3][..., ::-1] 89 | cv2.imwrite(os.path.join(save_fpath,'rgb_'+str(idx)+'.png'), color_bgr_0) 90 | 91 | # save mask 92 | mask_0 = data["nocs"][0][..., -1] 93 | cv2.imwrite(os.path.join(save_fpath,'mask_'+str(idx)+'.png'), mask_0*255) 94 | 95 | # save nocs 96 | xyz_0 = 2*(data["nocs"][0][..., :3] - 0.5) 97 | np.save(os.path.join(save_fpath,'xyz_'+str(idx)+'.npy'), xyz_0.astype(np.float16)) -------------------------------------------------------------------------------- /SAM-6D/Render/render_gso_templates.py: -------------------------------------------------------------------------------- 1 | import blenderproc as bproc 2 | 3 | import os 4 | import cv2 5 | import numpy as np 6 | import trimesh 7 | 8 | # set relative path of Data folder 9 | render_dir = os.path.dirname(os.path.abspath(__file__)) 10 | gso_path = os.path.join(render_dir, '../Data/MegaPose-Training-Data/MegaPose-GSO/google_scanned_objects') 11 | gso_norm_path = os.path.join(gso_path, 'models_normalized') 12 | output_dir = os.path.join(render_dir, '../Data/MegaPose-Training-Data/MegaPose-GSO/templates') 13 | 14 | bproc.init() 15 | 16 | def get_norm_info(mesh_path): 17 | mesh = trimesh.load(mesh_path, force='mesh') 18 | 19 | model_points = trimesh.sample.sample_surface(mesh, 1024)[0] 20 | model_points = model_points.astype(np.float32) 21 | 22 | min_value = np.min(model_points, axis=0) 23 | max_value = np.max(model_points, axis=0) 24 | 25 | radius = max(np.linalg.norm(max_value), np.linalg.norm(min_value)) 26 | 27 | return 1/(2*radius) 28 | 29 | 30 | for idx, model_name in enumerate(os.listdir(gso_norm_path)): 31 | if not os.path.isdir(os.path.join(gso_norm_path, model_name)) or '.' 
in model_name: 32 | continue 33 | print('---------------------------'+str(model_name)+'-------------------------------------') 34 | 35 | save_fpath = os.path.join(output_dir, model_name) 36 | if not os.path.exists(save_fpath): 37 | os.makedirs(save_fpath) 38 | 39 | obj_fpath = os.path.join(gso_norm_path, model_name, 'meshes', 'model.obj') 40 | if not os.path.exists(obj_fpath): 41 | continue 42 | 43 | scale = get_norm_info(obj_fpath) 44 | 45 | bproc.clean_up() 46 | 47 | obj = bproc.loader.load_obj(obj_fpath)[0] 48 | obj.set_scale([scale, scale, scale]) 49 | obj.set_cp("category_id", idx) 50 | 51 | # set light 52 | light1 = bproc.types.Light() 53 | light1.set_type("POINT") 54 | light1.set_location([-3, -3, -3]) 55 | light1.set_energy(1000) 56 | 57 | light2 = bproc.types.Light() 58 | light2.set_type("POINT") 59 | light2.set_location([3, 3, 3]) 60 | light2.set_energy(1000) 61 | 62 | location = [[-1, -1, -1], [1, 1, 1]] 63 | # set camera locations around the object 64 | for loc in location: 65 | # compute rotation based on vector going from location towards the location of object 66 | rotation_matrix = bproc.camera.rotation_from_forward_vec(obj.get_location() - loc) 67 | # add homog cam pose based on location and rotation 68 | cam2world_matrix = bproc.math.build_transformation_mat(loc, rotation_matrix) 69 | bproc.camera.add_camera_pose(cam2world_matrix) 70 | 71 | bproc.renderer.set_max_amount_of_samples(50) 72 | # render the whole pipeline 73 | data = bproc.renderer.render() 74 | # render nocs 75 | data.update(bproc.renderer.render_nocs()) 76 | 77 | # save rgb images 78 | color_bgr_0 = data["colors"][0] 79 | color_bgr_0[..., :3] = color_bgr_0[..., :3][..., ::-1] 80 | cv2.imwrite(os.path.join(save_fpath,'rgb_'+str(0)+'.png'), color_bgr_0) 81 | color_bgr_1 = data["colors"][1] 82 | color_bgr_1[..., :3] = color_bgr_1[..., :3][..., ::-1] 83 | cv2.imwrite(os.path.join(save_fpath,'rgb_'+str(1)+'.png'), color_bgr_1) 84 | 85 | # save masks 86 | mask_0 = data["nocs"][0][..., -1] 87 | mask_1 = data["nocs"][1][..., -1] 88 | cv2.imwrite(os.path.join(save_fpath,'mask_'+str(0)+'.png'), mask_0*255) 89 | cv2.imwrite(os.path.join(save_fpath,'mask_'+str(1)+'.png'), mask_1*255) 90 | 91 | # save nocs 92 | xyz_0 = 2*(data["nocs"][0][..., :3] - 0.5) 93 | xyz_1 = 2*(data["nocs"][1][..., :3] - 0.5) 94 | np.save(os.path.join(save_fpath,'xyz_'+str(0)+'.npy'), xyz_0.astype(np.float16)) 95 | np.save(os.path.join(save_fpath,'xyz_'+str(1)+'.npy'), xyz_1.astype(np.float16)) 96 | 97 | -------------------------------------------------------------------------------- /SAM-6D/Render/render_shapenet_templates.py: -------------------------------------------------------------------------------- 1 | import blenderproc as bproc 2 | 3 | import os 4 | import cv2 5 | import numpy as np 6 | import trimesh 7 | 8 | # set relative path of Data folder 9 | render_dir = os.path.dirname(os.path.abspath(__file__)) 10 | shapenet_path = os.path.join(render_dir, '../Data/MegaPose-Training-Data/MegaPose-ShapeNetCore/shapenetcorev2') 11 | shapenet_orig_path = os.path.join(shapenet_path, 'models_orig') 12 | output_dir = os.path.join(render_dir, '../Data/MegaPose-Training-Data/MegaPose-ShapeNetCore/templates') 13 | 14 | bproc.init() 15 | 16 | def get_norm_info(mesh_path): 17 | mesh = trimesh.load(mesh_path, force='mesh') 18 | 19 | model_points = trimesh.sample.sample_surface(mesh, 1024)[0] 20 | model_points = model_points.astype(np.float32) 21 | 22 | min_value = np.min(model_points, axis=0) 23 | max_value = np.max(model_points, axis=0) 24 | 25 | radius = 
max(np.linalg.norm(max_value), np.linalg.norm(min_value)) 26 | 27 | return 1/(2*radius) 28 | 29 | 30 | for synset_id in os.listdir(shapenet_orig_path): 31 | synset_fpath = os.path.join(shapenet_orig_path, synset_id) 32 | if not os.path.isdir(synset_fpath) or '.' in synset_id: 33 | continue 34 | print('---------------------------'+str(synset_id)+'-------------------------------------') 35 | for idx, source_id in enumerate(os.listdir(synset_fpath)): 36 | print('---------------------------'+str(source_id)+'-------------------------------------') 37 | save_synset_folder = os.path.join(output_dir, synset_id) 38 | if not os.path.exists(save_synset_folder): 39 | os.makedirs(save_synset_folder) 40 | 41 | save_fpath = os.path.join(save_synset_folder, source_id) 42 | if not os.path.exists(save_fpath): 43 | os.mkdir(save_fpath) 44 | else: 45 | continue 46 | 47 | cad_path = os.path.join(shapenet_orig_path, synset_id, source_id) 48 | obj_fpath = os.path.join(cad_path, 'models', 'model_normalized.obj') 49 | 50 | if not os.path.exists(obj_fpath): 51 | continue 52 | 53 | scale = get_norm_info(obj_fpath) 54 | 55 | bproc.clean_up() 56 | 57 | obj = bproc.loader.load_shapenet(shapenet_orig_path, synset_id, source_id, move_object_origin=False) 58 | obj.set_scale([scale, scale, scale]) 59 | obj.set_cp("category_id", idx) 60 | 61 | # set light 62 | light1 = bproc.types.Light() 63 | light1.set_type("POINT") 64 | light1.set_location([-3, -3, -3]) 65 | light1.set_energy(1000) 66 | 67 | light2 = bproc.types.Light() 68 | light2.set_type("POINT") 69 | light2.set_location([3, 3, 3]) 70 | light2.set_energy(1000) 71 | 72 | location = [[-1, -1, -1], [1, 1, 1]] 73 | # set camera locations around the object 74 | for loc in location: 75 | # compute rotation based on vector going from location towards the location of object 76 | rotation_matrix = bproc.camera.rotation_from_forward_vec(obj.get_location() - loc) 77 | # add homog cam pose based on location and rotation 78 | cam2world_matrix = bproc.math.build_transformation_mat(loc, rotation_matrix) 79 | bproc.camera.add_camera_pose(cam2world_matrix) 80 | 81 | bproc.renderer.set_max_amount_of_samples(1) 82 | # render the whole pipeline 83 | data = bproc.renderer.render() 84 | # render nocs 85 | data.update(bproc.renderer.render_nocs()) 86 | 87 | # save rgb images 88 | color_bgr_0 = data["colors"][0] 89 | color_bgr_0[..., :3] = color_bgr_0[..., :3][..., ::-1] 90 | cv2.imwrite(os.path.join(save_fpath,'rgb_'+str(0)+'.png'), color_bgr_0) 91 | color_bgr_1 = data["colors"][1] 92 | color_bgr_1[..., :3] = color_bgr_1[..., :3][..., ::-1] 93 | cv2.imwrite(os.path.join(save_fpath,'rgb_'+str(1)+'.png'), color_bgr_1) 94 | 95 | # save masks 96 | mask_0 = data["nocs"][0][..., -1] 97 | mask_1 = data["nocs"][1][..., -1] 98 | cv2.imwrite(os.path.join(save_fpath,'mask_'+str(0)+'.png'), mask_0*255) 99 | cv2.imwrite(os.path.join(save_fpath,'mask_'+str(1)+'.png'), mask_1*255) 100 | 101 | # save nocs 102 | xyz_0 = 2*(data["nocs"][0][..., :3] - 0.5) 103 | xyz_1 = 2*(data["nocs"][1][..., :3] - 0.5) 104 | # xyz need to rotate 90 degree to match CAD 105 | rot90 = np.array([[1, 0, 0], 106 | [0, 0, 1], 107 | [0, -1, 0]]) 108 | h, w = xyz_0.shape[0], xyz_0.shape[1] 109 | 110 | xyz_0 = ((rot90 @ xyz_0.reshape(-1, 3).T).T).reshape(h, w, 3) 111 | xyz_1 = ((rot90 @ xyz_1.reshape(-1, 3).T).T).reshape(h, w, 3) 112 | np.save(os.path.join(save_fpath,'xyz_'+str(0)+'.npy'), xyz_0.astype(np.float16)) 113 | np.save(os.path.join(save_fpath,'xyz_'+str(1)+'.npy'), xyz_1.astype(np.float16)) 114 | 115 | 
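Note on the rendering outputs above: render_bop_templates.py, render_custom_templates.py, render_gso_templates.py and render_shapenet_templates.py all write the same per-view template layout — an rgb_{i}.png image, a mask_{i}.png foreground mask, and an xyz_{i}.npy array of object-space coordinates roughly in [-1, 1]. The short sketch below is an illustration under stated assumptions, not repository code: it shows how one such view can be read back, mirroring Obj._get_template in Pose_Estimation_Model/utils/bop_object_utils.py. The helper name load_template_view, the choice of template_dir, and the direct use of cv2/numpy are assumptions made only for this example.

import os
import cv2
import numpy as np

def load_template_view(template_dir, view_idx):
    # Illustrative helper only: read the three files written per rendered view.
    rgb = cv2.imread(os.path.join(template_dir, 'rgb_' + str(view_idx) + '.png'))  # H x W x 3, BGR as saved by cv2.imwrite
    mask = cv2.imread(os.path.join(template_dir, 'mask_' + str(view_idx) + '.png'), 0) == 255  # boolean foreground mask
    xyz = np.load(os.path.join(template_dir, 'xyz_' + str(view_idx) + '.npy')).astype(np.float32)  # H x W x 3 object-space coords
    pts = xyz[mask]  # N x 3 points on the object surface, one per foreground pixel
    return rgb, mask, pts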
-------------------------------------------------------------------------------- /SAM-6D/demo.sh: -------------------------------------------------------------------------------- 1 | # Render CAD templates 2 | cd Render 3 | blenderproc run render_custom_templates.py --output_dir $OUTPUT_DIR --cad_path $CAD_PATH #--colorize True 4 | 5 | 6 | # Run instance segmentation model 7 | export SEGMENTOR_MODEL=sam 8 | 9 | cd ../Instance_Segmentation_Model 10 | python run_inference_custom.py --segmentor_model $SEGMENTOR_MODEL --output_dir $OUTPUT_DIR --cad_path $CAD_PATH --rgb_path $RGB_PATH --depth_path $DEPTH_PATH --cam_path $CAMERA_PATH 11 | 12 | 13 | # Run pose estimation model 14 | export SEG_PATH=$OUTPUT_DIR/sam6d_results/detection_ism.json 15 | 16 | cd ../Pose_Estimation_Model 17 | python run_inference_custom.py --output_dir $OUTPUT_DIR --cad_path $CAD_PATH --rgb_path $RGB_PATH --depth_path $DEPTH_PATH --cam_path $CAMERA_PATH --seg_path $SEG_PATH 18 | 19 | -------------------------------------------------------------------------------- /SAM-6D/environment.yaml: -------------------------------------------------------------------------------- 1 | name: sam6d 2 | channels: 3 | - xformers 4 | - conda-forge 5 | - pytorch 6 | - nvidia 7 | - defaults 8 | dependencies: 9 | - pip 10 | - python=3.9.6 11 | - pip: 12 | - torch==2.0.0 13 | - torchvision==0.15.1 14 | - fvcore 15 | - xformers==0.0.18 16 | - torchmetrics==0.10.3 17 | - blenderproc==2.6.1 18 | - opencv-python 19 | # ISM 20 | - omegaconf 21 | - ruamel.yaml 22 | - hydra-colorlog 23 | - hydra-core 24 | - gdown 25 | - pandas 26 | - imageio 27 | - pyrender 28 | - pytorch-lightning==1.8.1 29 | - pycocotools 30 | - distinctipy 31 | - git+https://github.com/facebookresearch/segment-anything.git # SAM 32 | - ultralytics==8.0.135 # FastSAM 33 | # PEM 34 | - timm 35 | - gorilla-core==0.2.7.8 36 | - trimesh==4.0.8 37 | - gpustat==1.0.0 38 | - imgaug 39 | - einops -------------------------------------------------------------------------------- /SAM-6D/prepare.sh: -------------------------------------------------------------------------------- 1 | ### Create conda environment 2 | conda env create -f environment.yaml 3 | conda activate sam6d 4 | 5 | ### Install pointnet2 6 | cd Pose_Estimation_Model/model/pointnet2 7 | python setup.py install 8 | cd ../../../ 9 | 10 | ### Download ISM pretrained model 11 | cd Instance_Segmentation_Model 12 | python download_sam.py 13 | python download_fastsam.py 14 | python download_dinov2.py 15 | cd ../ 16 | 17 | ### Download PEM pretrained model 18 | cd Pose_Estimation_Model 19 | python download_sam6d-pem.py 20 | -------------------------------------------------------------------------------- /pics/overview_pem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/pics/overview_pem.png -------------------------------------------------------------------------------- /pics/overview_sam_6d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/pics/overview_sam_6d.png -------------------------------------------------------------------------------- /pics/vis.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiehongLin/SAM-6D/1c2543b3b6faa1f1d81b3c7291f8b371d71e50c2/pics/vis.gif 
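Usage note for demo.sh above: it must be launched from the SAM-6D directory with the sam6d environment from environment.yaml active, and it reads CAD_PATH, RGB_PATH, DEPTH_PATH, CAMERA_PATH and OUTPUT_DIR from the environment. The sketch below is one hedged way to drive it from Python on the bundled Data/Example scene; the use of subprocess, the checkout location sam6d_root, and the absolute-path handling are assumptions for illustration, not a documented interface of the repository.

import os
import subprocess

# Illustrative only: run the render -> segmentation -> pose demo on the example
# data. Absolute paths are used because demo.sh changes directories between
# stages; `sam6d_root` is an assumed checkout location, adjust as needed.
sam6d_root = os.path.abspath('SAM-6D')
example = os.path.join(sam6d_root, 'Data', 'Example')

env = dict(
    os.environ,
    CAD_PATH=os.path.join(example, 'obj_000005.ply'),
    RGB_PATH=os.path.join(example, 'rgb.png'),
    DEPTH_PATH=os.path.join(example, 'depth.png'),
    CAMERA_PATH=os.path.join(example, 'camera.json'),
    OUTPUT_DIR=os.path.join(example, 'outputs'),
)
subprocess.run(['bash', 'demo.sh'], cwd=sam6d_root, env=env, check=True)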
--------------------------------------------------------------------------------