├── .gitignore ├── LICENSE ├── README.md ├── assets └── realword_vis.png ├── config ├── idgc.yaml ├── infer_idgc_test.yaml ├── infer_idgc_train.yaml ├── infer_qgc_test.yaml ├── qgc.yaml └── test_base.yaml ├── datasets ├── __init__.py ├── refine_datasets.py └── task_dex_datasets.py ├── model ├── __init__.py ├── backbone │ ├── __init__.py │ ├── clip_sd.py │ ├── pointnet.py │ ├── pointnet2.py │ ├── pointnet2_utils.py │ └── resnet.py ├── decoder │ ├── __init__.py │ └── unet.py ├── irf.py ├── ldgd.py ├── loss │ ├── __init__.py │ ├── grasp_loss_pose.py │ └── matcher.py └── utils │ ├── __init__.py │ ├── diffusion_utils.py │ ├── hand_model.py │ ├── helpers.py │ └── position_embedding.py ├── requirements.txt ├── test.py ├── thirdparty ├── pointnet2 │ ├── _ext_src │ │ ├── include │ │ │ ├── ball_query.h │ │ │ ├── cuda_utils.h │ │ │ ├── cylinder_query.h │ │ │ ├── group_points.h │ │ │ ├── interpolate.h │ │ │ ├── sampling.h │ │ │ └── utils.h │ │ └── src │ │ │ ├── ball_query.cpp │ │ │ ├── ball_query_gpu.cu │ │ │ ├── bindings.cpp │ │ │ ├── cylinder_query.cpp │ │ │ ├── cylinder_query_gpu.cu │ │ │ ├── group_points.cpp │ │ │ ├── group_points_gpu.cu │ │ │ ├── interpolate.cpp │ │ │ ├── interpolate_gpu.cu │ │ │ ├── sampling.cpp │ │ │ └── sampling_gpu.cu │ ├── pointnet2_modules.py │ ├── pointnet2_utils.py │ ├── pytorch_utils.py │ └── setup.py └── pytorch_kinematics │ ├── .gitignore │ ├── README.md │ ├── pytorch_kinematics │ ├── __init__.py │ ├── chain.py │ ├── frame.py │ ├── jacobian.py │ ├── mjcf.py │ ├── mjcf_parser │ │ ├── __init__.py │ │ ├── attribute.py │ │ ├── base.py │ │ ├── constants.py │ │ ├── copier.py │ │ ├── debugging.py │ │ ├── element.py │ │ ├── io.py │ │ ├── namescope.py │ │ ├── parser.py │ │ ├── schema.py │ │ ├── schema.xml │ │ └── util.py │ ├── sdf.py │ ├── transforms │ │ ├── __init__.py │ │ ├── math.py │ │ ├── rotation_conversions.py │ │ ├── so3.py │ │ └── transform3d.py │ ├── urdf.py │ └── urdf_parser_py │ │ ├── __init__.py │ │ ├── sdf.py │ │ ├── urdf.py │ │ └── xml_reflection │ │ ├── __init__.py │ │ ├── basics.py │ │ └── core.py │ ├── setup.py │ └── tests │ ├── __init__.py │ ├── ant.xml │ ├── humanoid.xml │ ├── kuka_iiwa.urdf │ ├── prismatic_robot.urdf │ ├── simple_arm.sdf │ ├── test_jacobian.py │ ├── test_kinematics.py │ └── test_transform.py ├── train.py └── utils ├── __init__.py ├── calc_diversity.py ├── config_utils.py ├── eval_utils.py ├── grasp_init.py ├── pc_util.py └── rotation_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | #pretrained weight 2 | pretrained/ 3 | 4 | # cache 5 | **/__pycache__/ 6 | 7 | # vscode 8 | .vscode/ 9 | 10 | # exp 11 | Experiments/ 12 | 13 | # datasets 14 | data 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 iSEE 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | #

Grasp as You Say: Language-guided Dexterous Grasp Generation

2 | 3 | ###

*Yi-Lin Wei, Jian-Jian Jiang, Chengyi Xing, Xiantuo Tan, Xiao-Ming Wu, Hao Li,
Mark Cutkosky, Wei-Shi Zheng*

4 | 5 | ####

[[Paper]](https://arxiv.org/abs/2405.19291)     [[Project]](https://isee-laboratory.github.io/DexGYS/)

6 | 
7 | ![-](assets/realword_vis.png)
8 | ### (NeurIPS 2024) Official repository of the paper "Grasp as You Say: Language-guided Dexterous Grasp Generation"
9 | 
10 | 
11 | ## Install
12 | - Create a new `conda` environment and activate it.
13 | ```
14 | conda create -n dexgys python=3.8
15 | conda activate dexgys
16 | ```
17 | - Install the dependencies.
18 | ```
19 | conda install -y pytorch==1.10.0 torchvision==0.11.0 cudatoolkit=11.3 -c pytorch -c conda-forge
20 | pip install -r requirements.txt
21 | ```
22 | - Build the packages.
23 | > **Note**: The CUDA environment must be the same when building and when running (recommendation: CUDA 11 or higher).
24 | ```
25 | cd thirdparty/pytorch_kinematics
26 | pip install -e .
27 | 
28 | cd ../pointnet2
29 | python setup.py install
30 | 
31 | cd ../
32 | git clone https://github.com/wrc042/CSDF.git
33 | cd CSDF
34 | pip install -e .
35 | cd ../../
36 | ```
37 | 
38 | ## Data Preparation
39 | 1. Download the dexterous grasp labels and language labels of DexGYS from here ["coming soon"], and put the "dexgys" folder in "./data".
40 | 
41 | 2. Download the ShadowHand MJCF model from [here](https://mirrors.pku.edu.cn/dl-release/UniDexGrasp_CVPR2023/), and put the "mjcf" folder in "./data".
42 | 
43 | 3. Download the 3D object meshes from [here](https://oakink.net/), and put the "oakink" folder in "./data".
44 | 
45 | 4. Finally, the directory should be organized as follows:
46 | ```
47 | ./data/
48 | ├── dexgys/
49 | │   ├── train_with_guide_v2.1.json
50 | │   ├── test_with_guide_v2.1.json
51 | ├── oakink/
52 | │   ├── shape/
53 | └── mjcf/
54 | ```
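If it helps, here is a minimal sketch for sanity-checking this layout before training; it assumes the default "./data" location shown above and only reports missing paths:
```
import os

# Paths taken from the directory tree above.
expected = [
    "./data/dexgys/train_with_guide_v2.1.json",
    "./data/dexgys/test_with_guide_v2.1.json",
    "./data/oakink/shape",
    "./data/mjcf",
]

missing = [p for p in expected if not os.path.exists(p)]
print("Missing paths:" if missing else "Data layout looks complete.")
for p in missing:
    print(" ", p)
```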
55 | 
56 | ## Usage
57 | ### Train
58 | 1. Train the Intention and Diversity Grasp Component (IDGC).
59 | ```
60 | python train.py -t "./config/idgc.yaml"
61 | ```
62 | 2. Run IDGC inference on the train and test sets to obtain training and testing pairs for QGC.
63 | ```
64 | python ./test.py \
65 | --train_cfg ./config/idgc.yaml \
66 | --test_cfg ./config/infer_idgc_train.yaml \
67 | --override model.checkpoint_path \"\"
68 | ```
69 | ```
70 | python ./test.py \
71 | --train_cfg ./config/idgc.yaml \
72 | --test_cfg ./config/infer_idgc_test.yaml \
73 | --override model.checkpoint_path \"\"
74 | ```
75 | 
76 | 
77 | 3. Train the Quality Grasp Component (QGC).
78 | 
79 | - Set "data.train.pose_path" and "data.val.pose_path" in "./config/qgc.yaml" to the output paths of step 2.
80 | - For example:
81 | ```
82 | data:
83 |   name: refinement
84 |   train:
85 |     data_root: &data_root "./data/oakink"
86 |     pose_path: ./Experiments/idgc/test_results/epoch__train/matched_results.json
87 |     ...
88 |   val:
89 |     data_root: *data_root
90 |     pose_path: ./Experiments/idgc/test_results/epoch__test/matched_results.json
91 | ```
92 | - Then run:
93 | ```
94 | python train.py -t "./config/qgc.yaml"
95 | ```
96 | 
97 | ### Test
98 | - Run QGC inference to refine the coarse outcome of IDGC.
99 | - Set "data.test.pose_path" in "./config/infer_qgc_test.yaml" to the output path of IDGC.
100 | ```
101 | data:
102 |   name: refinement
103 |   train:
104 |     data_root: &data_root "./data/oakink"
105 |     pose_path: ./Experiments/idgc/test_results/epoch__train/matched_results.json
106 |     sample_in_pose: &sample_in_pose True
107 | ```
108 | - Then run:
109 | ```
110 | python ./test.py \
111 | --train_cfg ./config/qgc.yaml \
112 | --test_cfg ./config/infer_qgc_test.yaml \
113 | --override model.checkpoint_path \"\"
114 | ```
115 | 
116 | ## TODO
117 | - [ ] Release the datasets of GraspGYSNet
118 | - [ ] Release the visualization code of the GraspGYS framework
119 | - [ ] Release the evaluation code of the GraspGYS framework
120 | - [x] Release the training code of the GraspGYS framework
121 | - [x] Release the inference code of the GraspGYS framework
122 | 
123 | ## Acknowledgements
124 | 
125 | The code of this repository is based on the following repositories. We would like to thank the authors for sharing their work.
126 | 
127 | - [UniDexGrasp](https://github.com/PKU-EPIC/UniDexGrasp)
128 | 
129 | - [Scene-Diffuser](https://github.com/scenediffuser/Scene-Diffuser)
130 | 
131 | - [DGTR](https://github.com/iSEE-Laboratory/DGTR)
132 | 
133 | ## Contact
134 | - Email: weiylin5@mail2.sysu.edu.cn
135 | 
136 | ## Citation
137 | Please cite this work if you find it useful.
138 | ```
139 | @article{wei2024grasp,
140 |   title={Grasp as you say: language-guided dexterous grasp generation},
141 |   author={Wei, Yi-Lin and Jiang, Jian-Jian and Xing, Chengyi and Tan, Xian-Tuo and Wu, Xiao-Ming and Li, Hao and Cutkosky, Mark and Zheng, Wei-Shi},
142 |   journal={arXiv preprint arXiv:2405.19291},
143 |   year={2024}
144 | }
145 | ```
146 | 
--------------------------------------------------------------------------------
/assets/realword_vis.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/iSEE-Laboratory/Grasp-as-You-Say/4694fa9369523d09d0c2cea6ec7bcddd2cb206d8/assets/realword_vis.png
--------------------------------------------------------------------------------
/config/idgc.yaml:
--------------------------------------------------------------------------------
1 | device: &device cuda:0
2 | rotation_type: &rotation_type euler
3 | rotation_dim: &rotation_dim 3
4 | 
5 | ncols: 120
6 | epochs: &epochs 200
7 | print_freq: 500
8 | validate_freq: 1
9 | save_root: ./Experiments/idgc
10 | save_top_n: 20
11 | log_dir: logs
12 | seed: 3407
13 | norm_type: &norm_type minmax11
14 | guidence_type: &guidence_type "fine" # fine
15 | frozen_clip: True
16 | 
17 | data:
18 |   name: task_pose
19 |   train:
20 |     data_root: &data_root "./data/oakink"
21 |     pose_path: "./data/dexgys/train_with_guide_v2.1.json"
22 |     num_rotation_aug: &num_rotation_aug 1
23 |     num_equivalent_aug: &num_equivalent_aug 1
24 |     sample_in_pose: &sample_in_pose True
25 |     guidence_type: *guidence_type
26 |     rotation_type: *rotation_type
27 |     norm_type: *norm_type
28 |     batch_size: 48
29 |     num_workers: 4
30 | 
31 |   val:
32 |     data_root: *data_root
33 |     pose_path: "./data/dexgys/test_with_guide_v2.1.json"
34 |     num_rotation_aug: 1
35 |     sample_in_pose: False
36 |     guidence_type: *guidence_type
37 |     rotation_type: *rotation_type
38 |     norm_type: *norm_type
39 |     batch_size: 48
40 |     num_workers: 4
41 | 
42 | model:
43 |   name: LDGD
44 |   steps: 100
45 |   schedule_cfg:
46 |     beta: [0.0001, 0.01]
47 |     beta_schedule: 'linear'
48 |     s: 0.008
49 |   rand_t_type: 'half' # 'half' or 'all'
50 |   loss_type: 'l2' # 'l1' or 'l2'
51 |   out_sigmoid: False
52 |   pred_x0: True
53 |   device: *device
54 |   rotation_type: *rotation_type
55 | 
56 |   decoder:
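    # The decoder below is the conditional UNet denoiser (model/decoder/unet.py): object
    # point-cloud features from the PointNet++ backbone and CLIP text-token features are
    # concatenated and injected as cross-attention context (trans_condition_type: "txt_obj").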
57 | name: unet 58 | use_guidence: True 59 | use_obj: True 60 | plus_condition_type: ("") 61 | trans_condition_type: "txt_obj" 62 | language_encoder: 63 | name: clip_sd 64 | version: 'ViT-L/14' 65 | use_pre: False 66 | use_adapter: True 67 | dim_in: 768 68 | dim_out: 512 69 | reduction: 1 70 | device: *device 71 | backbone: 72 | name: pointnet2 73 | use_pooling: False 74 | layer1: 75 | npoint: 1024 76 | radius_list: [0.02] 77 | nsample_list: [64] 78 | mlp_list: [0, 64, 128] 79 | layer2: 80 | npoint: 256 81 | radius_list: [0.05] 82 | nsample_list: [32] 83 | mlp_list: [128, 256, 256] 84 | layer3: 85 | npoint: 64 86 | radius_list: [0.1] 87 | nsample_list: [16] 88 | mlp_list: 89 | - 256 90 | - 512 91 | - &encoder_out 512 92 | use_xyz: true 93 | normalize_xyz: true 94 | d_x: 28 # placeholder 95 | d_model: 512 96 | time_embed_mult: 2 97 | nblocks: 4 98 | resblock_dropout: 0.0 99 | transformer_num_heads: 8 100 | transformer_dim_head: 64 101 | transformer_dropout: 0.1 102 | transformer_depth: 1 103 | transformer_mult_ff: 2 104 | context_dim: 512 105 | use_position_embedding: false # for input x 106 | 107 | criterion: 108 | hand_model: 109 | mjcf_path: ./data/mjcf/shadow_hand.xml 110 | mesh_path: ./data/mjcf/meshes 111 | n_surface_points: 1024 112 | contact_points_path: ./data/mjcf/contact_points.json 113 | penetration_points_path: ./data/mjcf/penetration_points.json 114 | fingertip_points_path: ./data/mjcf/fingertip.json 115 | loss_weights: 116 | hand_chamfer: 1.0 117 | para: 10.0 118 | obj_penetration: 50.0 119 | self_penetration: 10.0 120 | 121 | cost_weights: 122 | hand_mesh: 0.0 123 | qpos: 1.0 124 | translation: 2.0 125 | rotation: 2.0 126 | device: *device 127 | rotation_type: *rotation_type 128 | norm_type: *norm_type 129 | 130 | 131 | optimizer: 132 | name: adam 133 | lr: 1.0e-4 134 | weight_decay: 1.0e-4 135 | 136 | scheduler: 137 | name: cosine 138 | t_max: *epochs 139 | min_lr: 1.0e-5 140 | -------------------------------------------------------------------------------- /config/infer_idgc_test.yaml: -------------------------------------------------------------------------------- 1 | set: "test" 2 | data: 3 | test: 4 | data_root: "./data/oakink" 5 | pose_path: "./data/dexgys/test_with_guide_v2.1.json" 6 | num_rotation_aug: 1 7 | num_equivalent_aug: 1 8 | sample_in_pose: False 9 | guidence_type: "fine" 10 | rotation_type: "euler" 11 | norm_type: "minmax11" 12 | batch_size: 128 13 | num_workers: 8 14 | 15 | hand_model: 16 | mjcf_path: ./data/mjcf/shadow_hand.xml 17 | mesh_path: ./data/mjcf/meshes 18 | n_surface_points: 1024 19 | contact_points_path: ./data/mjcf/contact_points.json 20 | penetration_points_path: ./data/mjcf/penetration_points.json 21 | model: 22 | checkpoint_path: None 23 | 24 | q1: 25 | lambda_torque: 10 26 | m: 8 27 | mu: 1 28 | nms: true 29 | thres_contact: 0.01 30 | thres_pen: 0.005 31 | thres_tpen: 0.01 32 | -------------------------------------------------------------------------------- /config/infer_idgc_train.yaml: -------------------------------------------------------------------------------- 1 | set: "train" 2 | data: 3 | test: 4 | data_root: "./data/oakink" 5 | pose_path: "./data/dexgys/train_with_guide_v2.1.json" 6 | num_rotation_aug: 1 7 | num_equivalent_aug: 1 8 | sample_in_pose: False 9 | guidence_type: "fine" 10 | rotation_type: "euler" 11 | norm_type: "minmax11" 12 | batch_size: 128 13 | num_workers: 8 14 | 15 | hand_model: 16 | mjcf_path: ./data/mjcf/shadow_hand.xml 17 | mesh_path: ./data/mjcf/meshes 18 | n_surface_points: 1024 19 | contact_points_path: 
./data/mjcf/contact_points.json 20 | penetration_points_path: ./data/mjcf/penetration_points.json 21 | model: 22 | checkpoint_path: None 23 | 24 | q1: 25 | lambda_torque: 10 26 | m: 8 27 | mu: 1 28 | nms: true 29 | thres_contact: 0.01 30 | thres_pen: 0.005 31 | thres_tpen: 0.01 32 | -------------------------------------------------------------------------------- /config/infer_qgc_test.yaml: -------------------------------------------------------------------------------- 1 | set: "test" 2 | data: 3 | test: 4 | data_root: ./data/oakink 5 | pose_path: ./Experiments/idgc/test_results/epoch_1_test/matched_results.json 6 | sample_in_pose: true 7 | guidence_type: fine 8 | rotation_type: euler 9 | norm_type: minmax11 10 | batch_size: 16 11 | num_workers: 4 12 | 13 | hand_model: 14 | mjcf_path: ./data/mjcf/shadow_hand.xml 15 | mesh_path: ./data/mjcf/meshes 16 | n_surface_points: 1024 17 | contact_points_path: ./data/mjcf/contact_points.json 18 | penetration_points_path: ./data/mjcf/penetration_points.json 19 | model: 20 | checkpoint_path: None 21 | 22 | q1: 23 | lambda_torque: 10 24 | m: 8 25 | mu: 1 26 | nms: true 27 | thres_contact: 0.01 28 | thres_pen: 0.005 29 | thres_tpen: 0.01 30 | -------------------------------------------------------------------------------- /config/qgc.yaml: -------------------------------------------------------------------------------- 1 | device: &device cuda:0 2 | rotation_type: &rotation_type euler 3 | rotation_dim: &rotation_dim 3 4 | 5 | ncols: 120 6 | epochs: &epochs 100 7 | print_freq: 500 8 | validate_freq: 1 9 | save_root: ./Experiments/qgc 10 | save_top_n: 20 11 | log_dir: logs 12 | seed: 3407 13 | norm_type: &norm_type minmax11 14 | guidence_type: &guidence_type "fine" # fine 15 | frozen_clip: False 16 | 17 | data: 18 | name: refinement 19 | train: 20 | data_root: &data_root "./data/oakink" 21 | pose_path: ./Experiments/idgc/test_results/epoch_1_train/matched_results.json 22 | sample_in_pose: &sample_in_pose True 23 | guidence_type: *guidence_type 24 | rotation_type: *rotation_type 25 | norm_type: *norm_type 26 | batch_size: 32 27 | num_workers: 8 28 | 29 | val: 30 | data_root: *data_root 31 | pose_path: ./Experiments/idgc/test_results/epoch_1_test/matched_results.json 32 | sample_in_pose: True 33 | guidence_type: *guidence_type 34 | rotation_type: *rotation_type 35 | norm_type: *norm_type 36 | batch_size: 32 37 | num_workers: 8 38 | 39 | model: 40 | name: IRF 41 | steps: 100 42 | schedule_cfg: 43 | beta: [0.0001, 0.01] 44 | beta_schedule: 'linear' 45 | s: 0.008 46 | rand_t_type: 'half' # 'half' or 'all' 47 | loss_type: 'l2' # 'l1' or 'l2' 48 | out_sigmoid: False 49 | pred_abs: False 50 | device: *device 51 | rotation_type: *rotation_type 52 | 53 | decoder: 54 | name: unet 55 | task_num: 0 56 | cls_dim: 0 57 | use_guidence: False 58 | use_obj: True 59 | use_hand: True 60 | plus_condition_type: () 61 | trans_condition_type: "obj_hand" 62 | backbone: 63 | name: pointnet2 64 | use_pooling: False 65 | layer1: 66 | npoint: 1024 67 | radius_list: [0.02] 68 | nsample_list: [64] 69 | mlp_list: [0, 64, 128] 70 | layer2: 71 | npoint: 128 72 | radius_list: [0.04] 73 | nsample_list: [16] 74 | mlp_list: [128, 256, 256] 75 | layer3: 76 | npoint: 16 77 | radius_list: [0.08] 78 | nsample_list: [4] 79 | mlp_list: 80 | - 256 81 | - 512 82 | - &encoder_out 512 83 | use_xyz: true 84 | normalize_xyz: true 85 | 86 | hand_backbone: 87 | name: pointnet2 88 | use_pooling: False 89 | layer1: 90 | npoint: 1024 91 | radius_list: [0.02] 92 | nsample_list: [64] 93 | mlp_list: [0, 64, 128] 
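    # layer2/layer3 below progressively downsample the hand surface cloud (1024 -> 128 -> 16 centers),
    # mirroring the object backbone above; the hand point cloud itself is the ShadowHand surface
    # sampled from the coarse grasp (see model/irf.py).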
94 | layer2: 95 | npoint: 128 96 | radius_list: [0.04] 97 | nsample_list: [16] 98 | mlp_list: [128, 256, 256] 99 | layer3: 100 | npoint: 16 101 | radius_list: [0.08] 102 | nsample_list: [4] 103 | mlp_list: 104 | - 256 105 | - 512 106 | - *encoder_out 107 | use_xyz: true 108 | normalize_xyz: true 109 | 110 | d_x: 28 # placeholder 111 | d_model: 512 112 | time_embed_mult: 2 113 | nblocks: 4 114 | resblock_dropout: 0.0 115 | transformer_num_heads: 8 116 | transformer_dim_head: 64 117 | transformer_dropout: 0.1 118 | transformer_depth: 1 119 | transformer_mult_ff: 2 120 | context_dim: 512 121 | use_position_embedding: false # for input x 122 | 123 | criterion: 124 | hand_model: 125 | mjcf_path: ./data/mjcf/shadow_hand.xml 126 | mesh_path: ./data/mjcf/meshes 127 | n_surface_points: 1024 128 | contact_points_path: ./data/mjcf/contact_points.json 129 | penetration_points_path: ./data/mjcf/penetration_points.json 130 | fingertip_points_path: ./data/mjcf/fingertip.json 131 | loss_weights: 132 | hand_chamfer: 1.0 133 | para: 10.0 134 | obj_penetration: 100.0 135 | self_penetration: 10.0 136 | cmap: 10.0 137 | 138 | cost_weights: 139 | hand_mesh: 0.0 140 | qpos: 1.0 141 | translation: 2.0 142 | rotation: 2.0 143 | device: *device 144 | rotation_type: *rotation_type 145 | norm_type: *norm_type 146 | 147 | 148 | optimizer: 149 | name: adam 150 | lr: 1.0e-4 151 | weight_decay: 1.0e-4 152 | 153 | scheduler: 154 | name: cosine 155 | t_max: *epochs 156 | min_lr: 1.0e-5 157 | -------------------------------------------------------------------------------- /config/test_base.yaml: -------------------------------------------------------------------------------- 1 | set: "test" 2 | data: 3 | test: 4 | data_root: "./data/oakink" 5 | pose_path: "./data/dexgys/test_with_guide_v2.1.json" 6 | num_rotation_aug: 1 7 | num_equivalent_aug: 1 8 | sample_in_pose: False 9 | guidence_type: "fine" 10 | rotation_type: "euler" 11 | norm_type: "minmax11" 12 | batch_size: 128 13 | num_workers: 8 14 | 15 | hand_model: 16 | mjcf_path: ./data/mjcf/shadow_hand.xml 17 | mesh_path: ./data/mjcf/meshes 18 | n_surface_points: 1024 19 | contact_points_path: ./data/mjcf/contact_points.json 20 | penetration_points_path: ./data/mjcf/penetration_points.json 21 | model: 22 | checkpoint_path: None 23 | 24 | q1: 25 | lambda_torque: 10 26 | m: 8 27 | mu: 1 28 | nms: true 29 | thres_contact: 0.01 30 | thres_pen: 0.005 31 | thres_tpen: 0.01 32 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | from torch.utils.data import Dataset 4 | from .task_dex_datasets import TaskDataset_Pose 5 | from .refine_datasets import RefineDataset 6 | 7 | 8 | def build_datasets(data_cfg) -> Tuple[Dataset, Dataset, Optional[Dataset]]: 9 | 10 | if data_cfg.name.lower() == "task_pose": 11 | if not hasattr(data_cfg, "test"): 12 | train_set = TaskDataset_Pose( 13 | data_root=data_cfg.train.data_root, 14 | pose_path=data_cfg.train.pose_path, 15 | rotation_type=data_cfg.train.rotation_type, 16 | sample_in_pose = data_cfg.train.sample_in_pose, 17 | norm_type=data_cfg.train.norm_type, 18 | guidence_type=data_cfg.train.guidence_type, 19 | is_train=True, 20 | ) 21 | val_set = TaskDataset_Pose( 22 | data_root=data_cfg.val.data_root, 23 | pose_path=data_cfg.val.pose_path, 24 | rotation_type=data_cfg.val.rotation_type, 25 | sample_in_pose = data_cfg.val.sample_in_pose, 26 | 
norm_type=data_cfg.val.norm_type, 27 | guidence_type=data_cfg.val.guidence_type, 28 | is_train=False, 29 | ) 30 | test_set = None 31 | elif hasattr(data_cfg, "test"): 32 | train_set = None 33 | val_set = None 34 | test_set = TaskDataset_Pose( 35 | data_root=data_cfg.test.data_root, 36 | pose_path=data_cfg.test.pose_path, 37 | rotation_type=data_cfg.test.rotation_type, 38 | sample_in_pose = data_cfg.test.sample_in_pose, 39 | norm_type=data_cfg.test.norm_type, 40 | guidence_type=data_cfg.test.guidence_type, 41 | is_train=False, 42 | ) 43 | 44 | else: 45 | raise Exception("1") 46 | return train_set, val_set, test_set 47 | elif data_cfg.name.lower() == "refinement": 48 | if not hasattr(data_cfg, "test"): 49 | train_set = RefineDataset( 50 | data_root=data_cfg.train.data_root, 51 | pose_path=data_cfg.train.pose_path, 52 | rotation_type=data_cfg.train.rotation_type, 53 | sample_in_pose = data_cfg.train.sample_in_pose, 54 | norm_type=data_cfg.train.norm_type, 55 | guidence_type=data_cfg.train.guidence_type, 56 | is_train=True, 57 | ) 58 | val_set = RefineDataset( 59 | data_root=data_cfg.val.data_root, 60 | pose_path=data_cfg.val.pose_path, 61 | rotation_type=data_cfg.val.rotation_type, 62 | sample_in_pose = data_cfg.val.sample_in_pose, 63 | norm_type=data_cfg.val.norm_type, 64 | guidence_type=data_cfg.val.guidence_type, 65 | is_train=False, 66 | ) 67 | test_set = None 68 | elif hasattr(data_cfg, "test"): 69 | train_set = None 70 | val_set = None 71 | test_set = RefineDataset( 72 | data_root=data_cfg.test.data_root, 73 | pose_path=data_cfg.test.pose_path, 74 | rotation_type=data_cfg.test.rotation_type, 75 | sample_in_pose = data_cfg.test.sample_in_pose, 76 | norm_type=data_cfg.test.norm_type, 77 | guidence_type=data_cfg.test.guidence_type, 78 | is_train=False, 79 | ) 80 | return train_set, val_set, test_set 81 | else: 82 | raise NotImplementedError(f"Unable to build {data_cfg.name} dataset") 83 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .ldgd import DDPM 3 | from .irf import RenfinmentTransformer 4 | def build_model(cfg): 5 | if cfg.name.lower() == "ldgd": 6 | return DDPM(cfg) 7 | elif cfg.name.lower() == "irf": 8 | return RenfinmentTransformer(cfg) 9 | 10 | else: 11 | raise Exception("1") 12 | -------------------------------------------------------------------------------- /model/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .pointnet2 import Pointnet2Backbone 2 | from .resnet import build_resnet_backbone 3 | from .pointnet import PointNetEncoder 4 | from .clip_sd import ClipCustom 5 | 6 | def build_backbone(backbone_cfg): 7 | if backbone_cfg.name.lower() == "resnet": 8 | return build_resnet_backbone(backbone_cfg) 9 | elif backbone_cfg.name.lower() == "pointnet2": 10 | return Pointnet2Backbone(backbone_cfg) 11 | elif backbone_cfg.name.lower() == "pointnet": 12 | return PointNetEncoder(backbone_cfg) 13 | elif backbone_cfg.name.lower() == "clip_sd": 14 | return ClipCustom(backbone_cfg) 15 | else: 16 | raise NotImplementedError(f"No such backbone: {backbone_cfg.name}") 17 | -------------------------------------------------------------------------------- /model/backbone/clip_sd.py: -------------------------------------------------------------------------------- 1 | import clip 2 | import torch.nn as nn 3 | 4 | class Adapter(nn.Module): 5 | def __init__(self, c_in, c_out, reduction=4): 
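        # Bottleneck adapter MLP: c_in -> c_in // reduction -> c_out, with ReLU after each layer;
        # with the idgc.yaml settings (dim_in: 768, dim_out: 512, reduction: 1) it maps CLIP
        # ViT-L/14 token features to the 512-d conditioning width.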
6 | super().__init__() 7 | self.fc = nn.Sequential( 8 | nn.Linear(c_in, c_in // reduction, bias=False), 9 | nn.ReLU(inplace=True), 10 | nn.Linear(c_in // reduction, c_in // reduction, bias=False), 11 | nn.ReLU(inplace=True), 12 | nn.Linear(c_in // reduction, c_out, bias=False), 13 | nn.ReLU(inplace=True) 14 | ) 15 | 16 | def forward(self, x): 17 | x = self.fc(x) 18 | return x 19 | 20 | class ClipCustom(nn.Module): 21 | """ 22 | Uses the CLIP transformer encoder for text. 23 | """ 24 | def __init__(self, cfg): 25 | super().__init__() 26 | self.device = cfg.device 27 | 28 | self.model, _ = clip.load(cfg.version, jit=False, device="cpu") 29 | if not hasattr(cfg, "use_adapter") or cfg.use_adapter: 30 | self.adapter = Adapter(cfg.dim_in, cfg.dim_out, cfg.reduction) 31 | self.use_adapter = True 32 | else: 33 | self.use_adapter = False 34 | 35 | def freeze(self): 36 | self.model = self.model.eval() 37 | for param in self.model.parameters(): 38 | param.requires_grad = False 39 | 40 | def forward(self, text, pre_data=None): 41 | 42 | tokens = clip.tokenize(text).to(self.device) 43 | 44 | x = self.model.token_embedding(tokens).type(self.model.dtype) # [batch_size, n_ctx, d_model] 45 | 46 | x = x + self.model.positional_embedding.type(self.model.dtype) 47 | x = x.permute(1, 0, 2) # NLD -> LND 48 | x = self.model.transformer(x) 49 | x = x.permute(1, 0, 2) # LND -> NLD 50 | x = self.model.ln_final(x).type(self.model.dtype) 51 | 52 | if self.use_adapter: 53 | x = self.adapter(x) 54 | return x, tokens 55 | 56 | 57 | if __name__ == "__main__": 58 | model = ClipCustom().to("cuda") 59 | cls_token, txt_token = model(["abcd","acv ooo"]) -------------------------------------------------------------------------------- /model/backbone/pointnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | import torch.utils.data 5 | import numpy as np 6 | import torch.nn.functional as F 7 | 8 | def size_splits(tensor, split_sizes, dim=0): 9 | """Splits the tensor according to chunks of split_sizes. 10 | 11 | Arguments: 12 | tensor (Tensor): tensor to split. 13 | split_sizes (list(int)): sizes of chunks 14 | dim (int): dimension along which to split the tensor. 
15 | """ 16 | if dim < 0: 17 | dim += tensor.dim() 18 | 19 | dim_size = tensor.size(dim) 20 | if dim_size != torch.sum(torch.Tensor(split_sizes)): 21 | raise KeyError("Sum of split sizes exceeds tensor dim") 22 | 23 | splits = torch.cumsum(torch.Tensor([0] + split_sizes), dim=0)[:-1] 24 | 25 | return tuple(tensor.narrow(int(dim), int(start), int(length)) 26 | for start, length in zip(splits, split_sizes)) 27 | 28 | 29 | class STN3d(nn.Module): 30 | def __init__(self, channel): 31 | super(STN3d, self).__init__() 32 | self.conv1 = torch.nn.Conv1d(channel, 64, 1) 33 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 34 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 35 | self.fc1 = nn.Linear(1024, 512) 36 | self.fc2 = nn.Linear(512, 256) 37 | self.fc3 = nn.Linear(256, 9) 38 | self.relu = nn.ReLU() 39 | 40 | self.bn1 = nn.BatchNorm1d(64) 41 | self.bn2 = nn.BatchNorm1d(128) 42 | self.bn3 = nn.BatchNorm1d(1024) 43 | self.bn4 = nn.BatchNorm1d(512) 44 | self.bn5 = nn.BatchNorm1d(256) 45 | 46 | def forward(self, x): 47 | batchsize = x.size()[0] 48 | x = F.relu(self.bn1(self.conv1(x))) 49 | x = F.relu(self.bn2(self.conv2(x))) 50 | x = F.relu(self.bn3(self.conv3(x))) 51 | x = torch.max(x, 2, keepdim=True)[0] 52 | x = x.view(-1, 1024) 53 | 54 | x = F.relu(self.bn4(self.fc1(x))) 55 | x = F.relu(self.bn5(self.fc2(x))) 56 | x = self.fc3(x) 57 | 58 | #iden = Variable(torch.from_numpy(np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(np.float32))).view(1, 9).repeat( 59 | # batchsize, 1) 60 | iden = torch.from_numpy(np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(np.float32)).view(1, 9).repeat(batchsize, 1) 61 | if x.is_cuda: 62 | iden = iden.to(x.device) 63 | x = x + iden 64 | x = x.view(-1, 3, 3) 65 | return x 66 | 67 | class STNkd(nn.Module): 68 | def __init__(self, k=64): 69 | super(STNkd, self).__init__() 70 | self.conv1 = torch.nn.Conv1d(k, 64, 1) 71 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 72 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 73 | self.fc1 = nn.Linear(1024, 512) 74 | self.fc2 = nn.Linear(512, 256) 75 | self.fc3 = nn.Linear(256, k * k) 76 | self.relu = nn.ReLU() 77 | 78 | self.bn1 = nn.BatchNorm1d(64) 79 | self.bn2 = nn.BatchNorm1d(128) 80 | self.bn3 = nn.BatchNorm1d(1024) 81 | self.bn4 = nn.BatchNorm1d(512) 82 | self.bn5 = nn.BatchNorm1d(256) 83 | 84 | self.k = k 85 | 86 | def forward(self, x): 87 | batchsize = x.size()[0] 88 | x = F.relu(self.bn1(self.conv1(x))) 89 | x = F.relu(self.bn2(self.conv2(x))) 90 | x = F.relu(self.bn3(self.conv3(x))) 91 | x = torch.max(x, 2, keepdim=True)[0] 92 | x = x.view(-1, 1024) 93 | 94 | x = F.relu(self.bn4(self.fc1(x))) 95 | x = F.relu(self.bn5(self.fc2(x))) 96 | x = self.fc3(x) 97 | 98 | iden = torch.from_numpy(np.eye(self.k).flatten().astype(np.float32)).view(1, self.k * self.k).repeat(batchsize, 1) 99 | if x.is_cuda: 100 | iden = iden.cuda() 101 | x = x + iden 102 | x = x.view(-1, self.k, self.k) 103 | return x 104 | 105 | class PointNetEncoderLight(nn.Module): 106 | def __init__(self, cfg=None): 107 | super().__init__() 108 | # channel = cfg.channel 109 | channel = 3 110 | self.conv1 = torch.nn.Conv1d(channel, 32, 1) 111 | self.conv2 = torch.nn.Conv1d(32, 64, 1) 112 | self.conv3 = torch.nn.Conv1d(64, 64, 1) 113 | self.conv4 = torch.nn.Conv1d(128, 32, 1) 114 | 115 | self.bn1 = nn.BatchNorm1d(32) 116 | self.bn2 = nn.BatchNorm1d(64) 117 | self.bn3 = nn.BatchNorm1d(64) 118 | self.bn4 = nn.BatchNorm1d(32) 119 | 120 | def forward(self, x): 121 | B, D, N = x.size() 122 | assert D == 3 123 | 124 | x = F.relu(self.bn1(self.conv1(x))) 125 | x = F.relu(self.bn2(self.conv2(x))) 
126 | g = F.relu(self.bn3(self.conv3(x))) 127 | g = torch.max(x, 2, keepdim=True)[0] 128 | g = g.view(-1, 64, 1).repeat(1, 1, N) 129 | x = torch.cat([x, g], 1) 130 | x = F.relu(self.bn4(self.conv4(x))) 131 | 132 | return x 133 | 134 | class PointNetEncoder(nn.Module): 135 | def __init__(self, cfg): 136 | super(PointNetEncoder, self).__init__() 137 | global_feat=cfg.global_feat 138 | feature_transform=cfg.feature_transform 139 | channel=cfg.channel 140 | use_stn=cfg.use_stn 141 | self.stn = STN3d(channel) if use_stn else None 142 | self.conv1 = torch.nn.Conv1d(channel, 64, 1) 143 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 144 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 145 | self.bn1 = nn.BatchNorm1d(64) 146 | self.bn2 = nn.BatchNorm1d(128) 147 | self.bn3 = nn.BatchNorm1d(1024) 148 | self.global_feat = global_feat 149 | self.feature_transform = feature_transform 150 | if self.feature_transform: 151 | self.fstn = STNkd(k=64) 152 | 153 | def forward(self, x): 154 | B, D, N = x.size() 155 | assert D == 3 156 | if self.stn: 157 | trans = self.stn(x) 158 | x = x.transpose(2, 1) 159 | if self.stn: 160 | x = torch.bmm(x, trans) 161 | x = x.transpose(2, 1) 162 | x = F.relu(self.bn1(self.conv1(x))) 163 | 164 | if self.feature_transform: 165 | trans_feat = self.fstn(x) 166 | x = x.transpose(2, 1) 167 | x = torch.bmm(x, trans_feat) 168 | x = x.transpose(2, 1) 169 | else: 170 | trans_feat = None 171 | 172 | pointfeat = x 173 | x = F.relu(self.bn2(self.conv2(x))) 174 | x = self.bn3(self.conv3(x)) 175 | x = torch.max(x, 2, keepdim=True)[0] 176 | x = x.view(-1, 1024) 177 | if self.global_feat: 178 | return x, pointfeat, trans_feat 179 | else: 180 | raise NotImplementedError() 181 | x = x.view(-1, 1024, 1).repeat(1, 1, N) 182 | return torch.cat([x, pointfeat], 1), trans, trans_feat 183 | 184 | if __name__ == '__main__': 185 | points = torch.randn(2, 3, 778) 186 | print(points.size()) 187 | pointnet = PointNetEncoderLight() 188 | x = pointnet(points) 189 | print(x.size()) -------------------------------------------------------------------------------- /model/backbone/pointnet2.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | sys.path.append("./thirdparty/pointnet2/") 4 | 5 | import torch.nn as nn 6 | from pointnet2_modules import PointnetSAModuleVotes 7 | from torch.functional import Tensor 8 | import torch 9 | 10 | class Pointnet2Backbone(nn.Module): 11 | """ 12 | Backbone network for point cloud feature learning. 13 | Based on Pointnet++ single-scale grouping network. 14 | 15 | Parameters 16 | ---------- 17 | input_feature_dim: int 18 | Number of input channels in the feature descriptor for each point. 19 | e.g. 3 for RGB. 
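    Layer settings (npoint, radius_list, nsample_list, mlp_list) are read from cfg.layer1/2/3;
    see the backbone sections of config/idgc.yaml and config/qgc.yaml.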
20 | """ 21 | 22 | def __init__(self, cfg): 23 | super().__init__() 24 | 25 | self.sa1 = PointnetSAModuleVotes( 26 | npoint=cfg.layer1.npoint, 27 | radius=cfg.layer1.radius_list[0], 28 | nsample=cfg.layer1.nsample_list[0], 29 | mlp=cfg.layer1.mlp_list, 30 | use_xyz=cfg.use_xyz, 31 | normalize_xyz=cfg.normalize_xyz 32 | ) 33 | 34 | self.sa2 = PointnetSAModuleVotes( 35 | npoint=cfg.layer2.npoint, 36 | radius=cfg.layer2.radius_list[0], 37 | nsample=cfg.layer2.nsample_list[0], 38 | mlp=cfg.layer2.mlp_list, 39 | use_xyz=cfg.use_xyz, 40 | normalize_xyz=cfg.normalize_xyz 41 | ) 42 | 43 | self.sa3 = PointnetSAModuleVotes( 44 | npoint=cfg.layer3.npoint, 45 | radius=cfg.layer3.radius_list[0], 46 | nsample=cfg.layer3.nsample_list[0], 47 | mlp=cfg.layer3.mlp_list, 48 | use_xyz=cfg.use_xyz, 49 | normalize_xyz=cfg.normalize_xyz 50 | ) 51 | if cfg.use_pooling: 52 | self.gap = torch.nn.AdaptiveAvgPool1d(1) 53 | self.use_pooling = cfg.use_pooling 54 | 55 | def _break_up_pc(self, pc): 56 | xyz = pc[..., 0:3].contiguous() 57 | features = ( 58 | pc[..., 3:].transpose(1, 2).contiguous() 59 | if pc.size(-1) > 3 else None 60 | ) 61 | return xyz, features 62 | 63 | def forward(self, pointcloud: Tensor): 64 | """ 65 | Forward pass of the network 66 | 67 | Parameters 68 | ---------- 69 | pointcloud: Variable(Tensor) 70 | (B, N, 3 + input_feature_dim) tensor 71 | Point cloud to run predicts on 72 | Each point in the point-cloud MUST 73 | be formated as (x, y, z, features...) 74 | 75 | Returns 76 | ---------- 77 | xyz: float32 Tensor of shape (B, K, 3) 78 | features: float32 Tensor of shape (B, D, K) 79 | inds: int64 Tensor of shape (B, K) values in [0, N-1] 80 | """ 81 | xyz, features = self._break_up_pc(pointcloud) 82 | xyz, features, fps_inds = self.sa1(xyz, features) 83 | xyz, features, fps_inds = self.sa2(xyz, features) 84 | xyz, features, fps_inds = self.sa3(xyz, features) 85 | if self.use_pooling: 86 | features = self.gap(features) 87 | return xyz, features 88 | -------------------------------------------------------------------------------- /model/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | import torch.nn as nn 7 | import math 8 | import torch.utils.model_zoo as model_zoo 9 | 10 | 11 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 12 | 'resnet152'] 13 | 14 | 15 | model_urls = { 16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 21 | } 22 | 23 | 24 | def conv3x3(in_planes, out_planes, stride=1): 25 | """3x3 convolution with padding""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 27 | padding=1, bias=False) 28 | 29 | 30 | class BasicBlock(nn.Module): 31 | expansion = 1 32 | 33 | def __init__(self, inplanes, planes, stride=1, downsample=None): 34 | super(BasicBlock, self).__init__() 35 | self.conv1 = conv3x3(inplanes, planes, stride) 36 | self.bn1 = nn.BatchNorm2d(planes) 37 | self.relu = nn.ReLU(inplace=True) 38 | self.conv2 = conv3x3(planes, planes) 39 | self.bn2 = nn.BatchNorm2d(planes) 40 | self.downsample = downsample 41 | self.stride = stride 42 | 43 | def forward(self, x): 44 | residual = x 45 | 46 | out = self.conv1(x) 47 | out = self.bn1(out) 48 | out = self.relu(out) 49 | 50 | out = self.conv2(out) 51 | out = self.bn2(out) 52 | 53 | if self.downsample is not None: 54 | residual = self.downsample(x) 55 | 56 | out += residual 57 | out = self.relu(out) 58 | 59 | return out 60 | 61 | 62 | class Bottleneck(nn.Module): 63 | expansion = 4 64 | 65 | def __init__(self, inplanes, planes, stride=1, downsample=None): 66 | super(Bottleneck, self).__init__() 67 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 68 | self.bn1 = nn.BatchNorm2d(planes) 69 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 70 | padding=1, bias=False) 71 | self.bn2 = nn.BatchNorm2d(planes) 72 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 73 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 74 | self.relu = nn.ReLU(inplace=True) 75 | self.downsample = downsample 76 | self.stride = stride 77 | 78 | def forward(self, x): 79 | residual = x 80 | 81 | out = self.conv1(x) 82 | out = self.bn1(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv2(out) 86 | out = self.bn2(out) 87 | out = self.relu(out) 88 | 89 | out = self.conv3(out) 90 | out = self.bn3(out) 91 | 92 | if self.downsample is not None: 93 | residual = self.downsample(x) 94 | 95 | out += residual 96 | out = self.relu(out) 97 | 98 | return out 99 | 100 | 101 | class ResNet(nn.Module): 102 | 103 | def __init__(self, block, layers, num_classes=1000): 104 | self.inplanes = 64 105 | super(ResNet, self).__init__() 106 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 107 | bias=False) 108 | self.bn1 = nn.BatchNorm2d(64) 109 | self.relu = nn.ReLU(inplace=True) 110 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 111 | self.layer1 = self._make_layer(block, 64, layers[0]) 112 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 113 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 114 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 115 | self.avgpool = nn.AvgPool2d(7, stride=1) 
116 | self.fc1 = nn.Linear(512 * block.expansion, 1024) 117 | 118 | for m in self.modules(): 119 | if isinstance(m, nn.Conv2d): 120 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 121 | elif isinstance(m, nn.BatchNorm2d): 122 | nn.init.constant_(m.weight, 1) 123 | nn.init.constant_(m.bias, 0) 124 | 125 | def _make_layer(self, block, planes, blocks, stride=1): 126 | downsample = None 127 | if stride != 1 or self.inplanes != planes * block.expansion: 128 | downsample = nn.Sequential( 129 | nn.Conv2d(self.inplanes, planes * block.expansion, 130 | kernel_size=1, stride=stride, bias=False), 131 | nn.BatchNorm2d(planes * block.expansion), 132 | ) 133 | 134 | layers = [] 135 | layers.append(block(self.inplanes, planes, stride, downsample)) 136 | self.inplanes = planes * block.expansion 137 | for i in range(1, blocks): 138 | layers.append(block(self.inplanes, planes)) 139 | 140 | return nn.Sequential(*layers) 141 | 142 | def forward(self, x): 143 | x = self.conv1(x) 144 | x = self.bn1(x) 145 | x = self.relu(x) 146 | x = self.maxpool(x) 147 | 148 | x = self.layer1(x) 149 | x = self.layer2(x) 150 | x = self.layer3(x) 151 | x = self.layer4(x) 152 | 153 | x = self.avgpool(x) 154 | x = x.view(x.size(0), -1) 155 | x = self.relu(x) 156 | 157 | x = self.fc1(x) 158 | x = self.relu(x) 159 | 160 | return x 161 | 162 | 163 | def resnet18(pretrained=False, **kwargs): 164 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 165 | if pretrained: 166 | pretrained_state_dict = model_zoo.load_url(model_urls['resnet18']) 167 | model.load_state_dict(pretrained_state_dict, strict=False) 168 | return model 169 | 170 | 171 | def resnet34(pretrained=False, **kwargs): 172 | """Constructs a ResNet-34 model. 173 | 174 | Args: 175 | pretrained (bool): If True, returns a model pre-trained on ImageNet 176 | """ 177 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 178 | if pretrained: 179 | pretrained_state_dict = model_zoo.load_url(model_urls['resnet34']) 180 | model.load_state_dict(pretrained_state_dict, strict=False) 181 | return model 182 | 183 | 184 | def resnet50(pretrained=False, **kwargs): 185 | """Constructs a ResNet-50 model. 186 | 187 | Args: 188 | pretrained (bool): If True, returns a model pre-trained on ImageNet 189 | """ 190 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 191 | if pretrained: 192 | pretrained_state_dict = model_zoo.load_url(model_urls['resnet50']) 193 | model.load_state_dict(pretrained_state_dict, strict=False) 194 | return model 195 | 196 | 197 | def resnet101(pretrained=False, **kwargs): 198 | """Constructs a ResNet-101 model. 199 | 200 | Args: 201 | pretrained (bool): If True, returns a model pre-trained on ImageNet 202 | """ 203 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 204 | if pretrained: 205 | pretrained_state_dict = model_zoo.load_url(model_urls['resnet101']) 206 | model.load_state_dict(pretrained_state_dict, strict=False) 207 | return model 208 | 209 | 210 | def resnet152(pretrained=False, **kwargs): 211 | """Constructs a ResNet-152 model. 
212 | 213 | Args: 214 | pretrained (bool): If True, returns a model pre-trained on ImageNet 215 | """ 216 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 217 | if pretrained: 218 | pretrained_state_dict = model_zoo.load_url(model_urls['resnet152']) 219 | model.load_state_dict(pretrained_state_dict, strict=False) 220 | return model 221 | 222 | 223 | model_functions = { 224 | 18: resnet18, 225 | 34: resnet34, 226 | 50: resnet50, 227 | 101: resnet101, 228 | 152: resnet152 229 | } 230 | 231 | def build_resnet_backbone(depth, **kwargs): 232 | return model_functions[depth](**kwargs) -------------------------------------------------------------------------------- /model/decoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet import UNetModel 2 | 3 | def build_decoder(decoder_cfg): 4 | if decoder_cfg.name.lower() == "unet": 5 | return UNetModel(decoder_cfg) 6 | else: 7 | raise NotImplementedError(f"No such decode: {decoder_cfg.name}") 8 | -------------------------------------------------------------------------------- /model/decoder/unet.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from einops import rearrange 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from model.utils.diffusion_utils import timestep_embedding, ResBlock, SpatialTransformer 8 | from model.backbone import build_backbone 9 | 10 | class UNetModel(nn.Module): 11 | def __init__(self, cfg) -> None: 12 | super(UNetModel, self).__init__() 13 | 14 | self.d_x = cfg.d_x 15 | self.d_model = cfg.d_model 16 | self.nblocks = cfg.nblocks 17 | self.resblock_dropout = cfg.resblock_dropout 18 | self.transformer_num_heads = cfg.transformer_num_heads 19 | self.transformer_dim_head = cfg.transformer_dim_head 20 | self.transformer_dropout = cfg.transformer_dropout 21 | self.transformer_depth = cfg.transformer_depth 22 | self.transformer_mult_ff = cfg.transformer_mult_ff 23 | self.context_dim = cfg.context_dim 24 | self.use_position_embedding = cfg.use_position_embedding # for input sequence x 25 | self.plus_condition_type = cfg.plus_condition_type 26 | self.trans_condition_type = cfg.trans_condition_type 27 | 28 | ## create scene model from config 29 | self.scene_model = build_backbone(cfg.backbone) 30 | if hasattr(cfg, "use_hand") and cfg.use_hand: 31 | self.hand_model = build_backbone(cfg.hand_backbone) 32 | if hasattr(cfg, "use_guidence") and cfg.use_guidence: 33 | self.language_model = build_backbone(cfg.language_encoder) 34 | 35 | time_embed_dim = self.d_model * cfg.time_embed_mult 36 | self.time_embed = nn.Sequential( 37 | nn.Linear(self.d_model, time_embed_dim), 38 | nn.SiLU(), 39 | nn.Linear(time_embed_dim, time_embed_dim), 40 | ) 41 | 42 | self.in_layers = nn.Sequential( 43 | nn.Conv1d(self.d_x, self.d_model, 1) 44 | ) 45 | 46 | self.layers = nn.ModuleList() 47 | for i in range(self.nblocks): 48 | self.layers.append( 49 | ResBlock( 50 | self.d_model, 51 | time_embed_dim, 52 | self.resblock_dropout, 53 | self.plus_condition_type, 54 | self.d_model, 55 | ) 56 | ) 57 | self.layers.append( 58 | SpatialTransformer( 59 | self.d_model, 60 | self.transformer_num_heads, 61 | self.transformer_dim_head, 62 | depth=self.transformer_depth, 63 | dropout=self.transformer_dropout, 64 | mult_ff=self.transformer_mult_ff, 65 | context_dim=self.context_dim, 66 | ) 67 | ) 68 | 69 | self.out_layers = nn.Sequential( 70 | nn.GroupNorm(32, self.d_model), 71 | nn.SiLU(), 72 | 
nn.Conv1d(self.d_model, self.d_x, 1), 73 | ) 74 | 75 | def condition_mask(self, obj_cond, txt_cond, text_vector): 76 | cond = torch.cat([obj_cond, txt_cond], dim=1) 77 | batch_size, seq_length, embedding_size = cond.size() 78 | 79 | txt_index = text_vector.argmax(dim=-1) + cond.shape[0] + 2 80 | 81 | range_matrix = torch.arange(seq_length).expand(batch_size, seq_length).to(obj_cond.device) 82 | 83 | attention_mask = range_matrix < txt_index.unsqueeze(1) 84 | attention_mask[:,cond.shape[0]+1] = False 85 | return cond, attention_mask 86 | 87 | def forward(self, x_t: torch.Tensor, ts: torch.Tensor, data:Dict) -> torch.Tensor: 88 | """ Apply the model to an input batch 89 | 90 | Args: 91 | x_t: the input data, or 92 | ts: timestep, 1-D batch of timesteps 93 | cond: condition feature 94 | 95 | Return: 96 | the denoised target data, i.e., $x_{t-1}$ 97 | """ 98 | 99 | if self.trans_condition_type == "txt_obj": 100 | cond = torch.cat([data["cond_obj"], data["cond_txt"]], dim=1) 101 | atten_mask = None 102 | elif self.trans_condition_type == "obj_hand": 103 | cond = torch.cat([data["cond_obj"], data["cond_hand"]], dim=1) 104 | atten_mask = None 105 | else: 106 | raise Exception("no valid trans condition") 107 | 108 | in_shape = len(x_t.shape) 109 | if in_shape == 2: 110 | x_t = x_t.unsqueeze(1) 111 | assert len(x_t.shape) == 3 112 | 113 | ## time embedding 114 | if ts != None: 115 | t_emb = timestep_embedding(ts, self.d_model) 116 | t_emb = self.time_embed(t_emb) 117 | else: 118 | t_emb = None 119 | 120 | h = rearrange(x_t, 'b l c -> b c l') 121 | h = self.in_layers(h) # 122 | 123 | ## prepare position embedding for input x 124 | if self.use_position_embedding: 125 | B, DX, TX = h.shape 126 | pos_Q = torch.arange(TX, dtype=h.dtype, device=h.device) 127 | pos_embedding_Q = timestep_embedding(pos_Q, DX) # 128 | h = h + pos_embedding_Q.permute(1, 0) # 129 | 130 | for i in range(self.nblocks): 131 | h = self.layers[i * 2 + 0](h, t_emb, data) 132 | h = self.layers[i * 2 + 1](h, context=cond, mask=atten_mask) 133 | h = self.out_layers(h) 134 | h = rearrange(h, 'b c l -> b l c') 135 | 136 | ## reverse to original shape 137 | if in_shape == 2: 138 | h = h.squeeze(1) 139 | 140 | return h 141 | 142 | def condition_obj(self, data: Dict) -> torch.Tensor: 143 | """ Obtain scene feature with scene model 144 | 145 | Args: 146 | data: dataloader-provided data 147 | 148 | Return: 149 | Condition feature 150 | """ 151 | 152 | b = data['obj_pc'].shape[0] 153 | obj_pc = data['obj_pc'].to(torch.float32) 154 | _, obj_feat = self.scene_model(obj_pc) 155 | obj_feat = obj_feat.permute(0,2,1).contiguous() 156 | # (B, C, N) 157 | return obj_feat 158 | 159 | def condition_language(self, data: Dict) -> torch.Tensor: 160 | 161 | cond_txt, text_vector = self.language_model(data["guidence"], data["clip_data"] if "clip_data" in data else None) 162 | cond_txt_cls = cond_txt[torch.arange(cond_txt.shape[0]), text_vector.argmax(dim=-1)] 163 | 164 | return cond_txt_cls, cond_txt 165 | 166 | def condition_hand(self, data: Dict) -> torch.Tensor: 167 | b = data['hand_pc'].shape[0] 168 | hand_pc = data['hand_pc'] 169 | _, hand_feat = self.hand_model(hand_pc) 170 | hand_feat = hand_feat.permute(0,2,1).contiguous() 171 | # (B, C, N) 172 | return hand_feat 173 | 174 | -------------------------------------------------------------------------------- /model/irf.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | import torch 3 | import torch.nn as nn 4 | import 
torch.nn.functional as F 5 | from torch.functional import Tensor 6 | from .decoder import build_decoder 7 | from .loss import GraspLossPose 8 | from .utils.diffusion_utils import make_schedule_ddpm 9 | 10 | class RenfinmentTransformer(nn.Module): 11 | def __init__(self, cfg) -> None: 12 | super(RenfinmentTransformer, self).__init__() 13 | 14 | self.eps_model = build_decoder(cfg.decoder) 15 | self.criterion = GraspLossPose(cfg.criterion) 16 | self.loss_weights = cfg.criterion.loss_weights 17 | self.timesteps = cfg.steps 18 | self.schedule_cfg = cfg.schedule_cfg 19 | self.rand_t_type = cfg.rand_t_type 20 | self.pred_abs = cfg.pred_abs 21 | 22 | for k, v in make_schedule_ddpm(self.timesteps, **self.schedule_cfg).items(): 23 | self.register_buffer(k, v) 24 | 25 | @property 26 | def device(self): 27 | return self.betas.device 28 | 29 | 30 | def forward(self, data: Dict): 31 | """ Reverse diffusion process, sampling with the given data containing condition 32 | 33 | Args: 34 | data: test data, data['norm_pose'] gives the target data, 35 | data['obj_pc'] gives the condition 36 | 37 | Return: 38 | Computed loss 39 | """ 40 | B = data['coarse_norm_pose'].shape[0] 41 | 42 | ## predict noise 43 | data['hand_pc'] = self.criterion.hand_model(data['coarse_pose'], with_surface_points=True)["surface_points"] 44 | data['cond_hand'] = self.eps_model.condition_hand(data) 45 | 46 | data["cond_obj"] = self.eps_model.condition_obj(data) 47 | 48 | output = self.eps_model(data['coarse_norm_pose'], None, data) 49 | 50 | if not self.pred_abs: 51 | output += data['coarse_norm_pose'] 52 | if self.training: 53 | loss_dict = self.criterion({"pred_pose_norm": output}, data) 54 | else: 55 | loss_dict, _e = self.criterion({"pred_pose_norm": output}, data) 56 | 57 | loss = 0 58 | for k, v in loss_dict.items(): 59 | if k in self.loss_weights: 60 | loss += v * self.loss_weights[k] 61 | 62 | return loss, loss_dict, None 63 | 64 | def forward_test(self, data: Dict): 65 | """ Reverse diffusion process, sampling with the given data containing condition 66 | 67 | Args: 68 | data: test data, data['norm_pose'] gives the target data, 69 | data['obj_pc'] gives the condition 70 | 71 | Return: 72 | Computed loss 73 | """ 74 | B = data['coarse_norm_pose'].shape[0] 75 | 76 | ## predict noise 77 | data['hand_pc'] = self.criterion.hand_model(data['coarse_pose'], with_surface_points=True)["surface_points"] 78 | data['cond_hand'] = self.eps_model.condition_hand(data) 79 | 80 | data["cond_obj"] = self.eps_model.condition_obj(data) 81 | 82 | output = self.eps_model(data['coarse_norm_pose'], None, data) 83 | 84 | if not self.pred_abs: 85 | output += data['coarse_norm_pose'] 86 | if self.training: 87 | loss_dict = self.criterion({"pred_pose_norm": output}, data) 88 | else: 89 | loss_dict, preditions = self.criterion({"pred_pose_norm": output}, data) 90 | 91 | loss = 0 92 | for k, v in loss_dict.items(): 93 | if k in self.loss_weights: 94 | loss += v * self.loss_weights[k] 95 | 96 | return loss, loss_dict, preditions 97 | 98 | def forward_infer(self, data: Dict, k=4): 99 | """ Reverse diffusion process, sampling with the given data containing condition 100 | 101 | Args: 102 | data: test data, data['norm_pose'] gives the target data, 103 | data['obj_pc'] gives the condition 104 | 105 | Return: 106 | Computed loss 107 | """ 108 | B = data['coarse_norm_pose'].shape[0] 109 | 110 | ## predict noise 111 | data['hand_pc'] = self.criterion.hand_model(data['coarse_pose'], with_surface_points=True)["surface_points"] 112 | data['cond_hand'] = 
self.eps_model.condition_hand(data) 113 | 114 | data["cond_obj"] = self.eps_model.condition_obj(data) 115 | 116 | output = self.eps_model(data['coarse_norm_pose'], None, data) 117 | 118 | if not self.pred_abs: 119 | output += data['coarse_norm_pose'] 120 | 121 | preds_hand, targets_hand = self.criterion.infer_norm_process_dict({"pred_pose_norm": output}, data) 122 | 123 | return preds_hand, targets_hand 124 | 125 | 126 | @torch.no_grad() 127 | def sample(self, data: Dict, k: int=1) -> torch.Tensor: 128 | """ Reverse diffusion process, sampling with the given data containing condition 129 | In this method, the sampled results are unnormalized and converted to absolute representation. 130 | 131 | Args: 132 | data: test data, data['norm_pose'] gives the target data shape 133 | k: the number of sampled data 134 | 135 | Return: 136 | Sampled results, the shape is 137 | """ 138 | ## TODO ddim sample function 139 | ksamples = [] 140 | for _ in range(k): 141 | ksamples.append(self.p_sample_loop(data)) 142 | ksamples = torch.stack(ksamples, dim=1) 143 | return ksamples 144 | 145 | -------------------------------------------------------------------------------- /model/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .grasp_loss_pose import GraspLossPose 2 | -------------------------------------------------------------------------------- /model/loss/matcher.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Callable, Dict, List 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from scipy.optimize import linear_sum_assignment 8 | from torch.functional import Tensor 9 | 10 | 11 | class Matcher(nn.Module): 12 | def __init__(self, weight_dict: Dict[str, float]): 13 | super().__init__() 14 | # exclude weights equaling 0 15 | self.weight_dict = {k: v for k, v in weight_dict.items() if v > 0} 16 | 17 | @torch.no_grad() 18 | def forward(self, preds: Dict, targets: Dict): 19 | _example = preds["rotation"] 20 | device = _example.device 21 | batch_size, nqueries = _example.shape[:2] 22 | rotation_type = targets["rotation_type"][0] 23 | 24 | cost_matrices = [] 25 | for name, weight in self.weight_dict.items(): 26 | m = getattr(self, f"get_{name}_cost_mat") 27 | if name == "rotation": 28 | cost_mat = m(preds, targets, weight=weight, rotation_type=rotation_type) 29 | cost_mat = m(preds, targets, weight=weight) 30 | cost_matrices.append(cost_mat) 31 | final_cost = [sum(x).detach().cpu().numpy() for x in zip(*cost_matrices)] 32 | 33 | assignments = [] 34 | # auxiliary variables useful for batched loss computation 35 | per_query_gt_inds = torch.zeros( 36 | [batch_size, nqueries], dtype=torch.int64, device=device 37 | ) 38 | query_matched_mask = torch.zeros( 39 | [batch_size, nqueries], dtype=torch.float32, device=device 40 | ) 41 | for b in range(batch_size): 42 | assign = linear_sum_assignment(final_cost[b]) 43 | assign = [ 44 | torch.from_numpy(x).long().to(device) 45 | for x in assign 46 | ] 47 | per_query_gt_inds[b, assign[0]] = assign[1] 48 | query_matched_mask[b, assign[0]] = 1 49 | assignments.append(assign) 50 | return { 51 | "final_cost": final_cost, 52 | "assignments": assignments, 53 | "per_query_gt_inds": per_query_gt_inds, 54 | "query_matched_mask": query_matched_mask, 55 | } 56 | 57 | def get_hand_mesh_cost_mat( 58 | self, 59 | prediction: Tensor, 60 | targets: List[Tensor], 61 | weight: float = 1.0, 62 | ) -> 
List[Tensor]: 63 | # TODO: implement chamfer loss for hand mesh cost 64 | raise NotImplementedError("Unable to calculate hand mesh cost matrix yet. Please help me to implement it ^_^") 65 | 66 | def get_qpos_cost_mat( 67 | self, 68 | prediction: Tensor, 69 | targets: List[Tensor], 70 | weight: float = 1.0, 71 | ) -> List[Tensor]: 72 | pred_qpos = prediction["qpos_norm"] 73 | target_qpos = [x[..., 3:25] for x in targets["norm_pose"]] 74 | return self._get_cost_mat_by_elementwise(pred_qpos, target_qpos, weight=weight) 75 | 76 | def get_translation_cost_mat( 77 | self, 78 | prediction: Tensor, 79 | targets: List[Tensor], 80 | weight: float = 1.0, 81 | ) -> List[Tensor]: 82 | pred_translation = prediction["translation_norm"] 83 | target_translation = [x[..., :3] for x in targets["norm_pose"]] 84 | return self._get_cost_mat_by_elementwise(pred_translation, target_translation, weight=weight) 85 | 86 | def get_rotation_cost_mat( 87 | self, 88 | prediction: Tensor, 89 | targets: List[Tensor], 90 | weight: float = 1.0, 91 | rotation_type: str = "quaternion", 92 | ) -> List[Tensor]: 93 | if hasattr(self, f"_get_{rotation_type}_cost_mat"): 94 | m = getattr(self, f"_get_{rotation_type}_cost_mat") 95 | pred_rotation = prediction["rotation"] # (num_queries, D) 96 | target_rotation = [x[..., 25:] for x in targets["norm_pose"]] # [(ngt1, D), (ngt2, D), ...] 97 | return m(pred_rotation, target_rotation, weight) 98 | else: 99 | raise NotImplementedError(f"Unable to get {rotation_type} cost matrix") 100 | 101 | def _get_cost_mat_by_elementwise( 102 | self, 103 | prediction: Tensor, 104 | targets: List[Tensor], 105 | weight: float = 1.0, 106 | element_wise_func: Callable[[Tensor, Tensor], Tensor] = partial(F.l1_loss, reduction="none"), 107 | ) -> List[Tensor]: 108 | """ 109 | calculate cost matrix by element-wise operations 110 | 111 | Params: 112 | prediction: B, nqueries, D 113 | targets: [(ngt1, D), (ngt2, D), ...] 114 | weight: a float number for current cost matrix 115 | element_wise_func: an element-wise function for two tensors. Default is l1_loss 116 | 117 | return: 118 | cost_mat: [(nqueries, ngt1), (nqueries, ngt2), ...] 
119 | """ 120 | B = prediction.size(0) 121 | assert B == len(targets), f"batch size and len(targets) should be the same" 122 | nqueries = prediction.size(1) 123 | cost_mat = [] 124 | for i in range(B): 125 | rot, gt = prediction[i], targets[i] 126 | ngt = gt.size(0) 127 | rot = rot.unsqueeze(1).expand(-1, ngt, -1) # (nqueries, ngt, D) 128 | gt = gt.unsqueeze(0).expand(nqueries, -1, -1) # (nqueries, ngt, D) 129 | cost = element_wise_func(rot, gt).sum(-1) 130 | cost_mat.append(weight * cost) 131 | return cost_mat 132 | 133 | def _get_quaternion_cost_mat(self, prediction: Tensor, targets: List[Tensor], weight: float = 1.0) -> List[Tensor]: 134 | B = prediction.size(0) 135 | assert B == len(targets), f"batch size and len(targets) should be the same" 136 | cost_mat = [] 137 | for i in range(B): 138 | rot, gt = prediction[i], targets[i] 139 | cost = 1 - (rot @ gt.T).abs().detach() 140 | cost_mat.append(weight * cost) 141 | return cost_mat 142 | 143 | def _get_rotation_6d_cost_mat(self, prediction: Tensor, targets: List[Tensor], weight: float = 1.0) -> List[Tensor]: 144 | cost_mat = self._get_cost_mat_by_elementwise( 145 | prediction, 146 | targets, 147 | weight=weight, 148 | ) 149 | return cost_mat 150 | 151 | def _get_euler_cost_mat(self, prediction: Tensor, targets: List[Tensor], weight: float = 1.0) -> List[Tensor]: 152 | """ 153 | specially-designed l1 loss for euler angles 154 | """ 155 | B = prediction.size(0) 156 | assert B == len(targets), f"batch size and len(targets) should be the same" 157 | cost_mat = [] 158 | for i in range(B): 159 | rot, gt = prediction[i].unsqueeze(1), targets[i].unsqueeze(0) 160 | error = (rot - gt).abs().sum(-1) 161 | cost = torch.where(error < 0.5, error, 1 - error) 162 | cost_mat.append(weight * cost) 163 | return cost_mat 164 | -------------------------------------------------------------------------------- /model/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/Grasp-as-You-Say/4694fa9369523d09d0c2cea6ec7bcddd2cb206d8/model/utils/__init__.py -------------------------------------------------------------------------------- /model/utils/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
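# Generic network building blocks:
# - BatchNormDim1Swap applies BatchNorm1d to sequence-first (HW, N, C) tensors by permuting to (N, C, HW) and back.
# - GenericMLP stacks Linear/Conv1d layers with configurable normalization, activation, dropout, and weight init.
# - get_clones deep-copies a module N times into an nn.ModuleList.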
2 | import torch.nn as nn 3 | from functools import partial 4 | import copy 5 | 6 | 7 | class BatchNormDim1Swap(nn.BatchNorm1d): 8 | """ 9 | Used for nn.Transformer that uses a HW x N x C rep 10 | """ 11 | 12 | def forward(self, x): 13 | """ 14 | x: HW x N x C 15 | permute to N x C x HW 16 | Apply BN on C 17 | permute back 18 | """ 19 | hw, n, c = x.shape 20 | x = x.permute(1, 2, 0) 21 | x = super(BatchNormDim1Swap, self).forward(x) 22 | # x: n x c x hw -> hw x n x c 23 | x = x.permute(2, 0, 1) 24 | return x 25 | 26 | 27 | NORM_DICT = { 28 | "bn": BatchNormDim1Swap, 29 | "bn1d": nn.BatchNorm1d, 30 | "id": nn.Identity, 31 | "ln": nn.LayerNorm, 32 | } 33 | 34 | ACTIVATION_DICT = { 35 | "relu": nn.ReLU, 36 | "gelu": nn.GELU, 37 | "leakyrelu": partial(nn.LeakyReLU, negative_slope=0.1), 38 | } 39 | 40 | WEIGHT_INIT_DICT = { 41 | "xavier_uniform": nn.init.xavier_uniform_, 42 | } 43 | 44 | 45 | class GenericMLP(nn.Module): 46 | def __init__( 47 | self, 48 | input_dim, 49 | hidden_dims, 50 | output_dim, 51 | norm_fn_name=None, 52 | activation="relu", 53 | use_conv=False, 54 | dropout=None, 55 | hidden_use_bias=False, 56 | output_use_bias=True, 57 | output_use_activation=False, 58 | output_use_norm=False, 59 | weight_init_name=None, 60 | ): 61 | super().__init__() 62 | activation = ACTIVATION_DICT[activation] 63 | norm = None 64 | if norm_fn_name is not None: 65 | norm = NORM_DICT[norm_fn_name] 66 | if norm_fn_name == "ln" and use_conv: 67 | norm = lambda x: nn.GroupNorm(1, x) # easier way to use LayerNorm 68 | 69 | if dropout is not None: 70 | if not isinstance(dropout, list): 71 | dropout = [dropout for _ in range(len(hidden_dims))] 72 | 73 | layers = [] 74 | prev_dim = input_dim 75 | for idx, x in enumerate(hidden_dims): 76 | if use_conv: 77 | layer = nn.Conv1d(prev_dim, x, 1, bias=hidden_use_bias) 78 | else: 79 | layer = nn.Linear(prev_dim, x, bias=hidden_use_bias) 80 | layers.append(layer) 81 | if norm: 82 | layers.append(norm(x)) 83 | layers.append(activation()) 84 | if dropout is not None: 85 | layers.append(nn.Dropout(p=dropout[idx])) 86 | prev_dim = x 87 | if use_conv: 88 | layer = nn.Conv1d(prev_dim, output_dim, 1, bias=output_use_bias) 89 | else: 90 | layer = nn.Linear(prev_dim, output_dim, bias=output_use_bias) 91 | layers.append(layer) 92 | 93 | if output_use_norm: 94 | layers.append(norm(output_dim)) 95 | 96 | if output_use_activation: 97 | layers.append(activation()) 98 | 99 | self.layers = nn.Sequential(*layers) 100 | 101 | if weight_init_name is not None: 102 | self.do_weight_init(weight_init_name) 103 | 104 | def do_weight_init(self, weight_init_name): 105 | func = WEIGHT_INIT_DICT[weight_init_name] 106 | for (_, param) in self.named_parameters(): 107 | if param.dim() > 1: # skips batchnorm/layernorm 108 | func(param) 109 | 110 | def forward(self, x): 111 | output = self.layers(x) 112 | return output 113 | 114 | 115 | def get_clones(module, N): 116 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 117 | -------------------------------------------------------------------------------- /model/utils/position_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Various positional encodings for the transformer. 
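PositionEmbeddingCoordsSine (below) maps point coordinates to either classic sine/cosine features or random-Fourier features (the Fourier variant follows https://people.eecs.berkeley.edu/~bmild/fourfeat/), optionally normalizing the input coordinates to a given range first.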
4 | """ 5 | import math 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | from utils.pc_util import shift_scale_points 11 | 12 | 13 | class PositionEmbeddingCoordsSine(nn.Module): 14 | def __init__( 15 | self, 16 | temperature=10000, 17 | normalize=False, 18 | scale=None, 19 | pos_type="fourier", 20 | d_pos=None, 21 | d_in=3, 22 | gauss_scale=1.0, 23 | ): 24 | super().__init__() 25 | self.temperature = temperature 26 | self.normalize = normalize 27 | if scale is not None and normalize is False: 28 | raise ValueError("normalize should be True if scale is passed") 29 | if scale is None: 30 | scale = 2 * math.pi 31 | assert pos_type in ["sine", "fourier"] 32 | self.pos_type = pos_type 33 | self.scale = scale 34 | if pos_type == "fourier": 35 | assert d_pos is not None 36 | assert d_pos % 2 == 0 37 | # define a gaussian matrix input_ch -> output_ch 38 | B = torch.empty((d_in, d_pos // 2)).normal_() 39 | B *= gauss_scale 40 | self.register_buffer("gauss_B", B) 41 | self.d_pos = d_pos 42 | 43 | def get_sine_embeddings(self, xyz, num_channels, input_range): 44 | # clone coords so that shift/scale operations do not affect original tensor 45 | orig_xyz = xyz 46 | xyz = orig_xyz.clone() 47 | 48 | ncoords = xyz.shape[1] 49 | if self.normalize: 50 | xyz = shift_scale_points(xyz, src_range=input_range) 51 | 52 | ndim = num_channels // xyz.shape[2] 53 | if ndim % 2 != 0: 54 | ndim -= 1 55 | # automatically handle remainder by assiging it to the first dim 56 | rems = num_channels - (ndim * xyz.shape[2]) 57 | 58 | assert ( 59 | ndim % 2 == 0 60 | ), f"Cannot handle odd sized ndim={ndim} where num_channels={num_channels} and xyz={xyz.shape}" 61 | 62 | final_embeds = [] 63 | prev_dim = 0 64 | 65 | for d in range(xyz.shape[2]): 66 | cdim = ndim 67 | if rems > 0: 68 | # add remainder in increments of two to maintain even size 69 | cdim += 2 70 | rems -= 2 71 | 72 | if cdim != prev_dim: 73 | dim_t = torch.arange(cdim, dtype=torch.float32, device=xyz.device) 74 | dim_t = self.temperature ** (2 * (dim_t // 2) / cdim) 75 | 76 | # create batch x cdim x nccords embedding 77 | raw_pos = xyz[:, :, d] 78 | if self.scale: 79 | raw_pos *= self.scale 80 | pos = raw_pos[:, :, None] / dim_t 81 | pos = torch.stack( 82 | (pos[:, :, 0::2].sin(), pos[:, :, 1::2].cos()), dim=3 83 | ).flatten(2) 84 | final_embeds.append(pos) 85 | prev_dim = cdim 86 | 87 | final_embeds = torch.cat(final_embeds, dim=2).permute(0, 2, 1) 88 | return final_embeds 89 | 90 | def get_fourier_embeddings(self, xyz, num_channels=None, input_range=None): 91 | # Follows - https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html 92 | 93 | if num_channels is None: 94 | num_channels = self.gauss_B.shape[1] * 2 95 | 96 | bsize, npoints = xyz.shape[0], xyz.shape[1] 97 | assert num_channels > 0 and num_channels % 2 == 0 98 | d_in, max_d_out = self.gauss_B.shape[0], self.gauss_B.shape[1] 99 | d_out = num_channels // 2 100 | assert d_out <= max_d_out 101 | assert d_in == xyz.shape[-1] 102 | 103 | # clone coords so that shift/scale operations do not affect original tensor 104 | orig_xyz = xyz 105 | xyz = orig_xyz.clone() 106 | 107 | ncoords = xyz.shape[1] 108 | if self.normalize: 109 | xyz = shift_scale_points(xyz, src_range=input_range) 110 | 111 | xyz *= 2 * np.pi 112 | xyz_proj = torch.mm(xyz.view(-1, d_in), self.gauss_B[:, :d_out]).view( 113 | bsize, npoints, d_out 114 | ) 115 | final_embeds = [xyz_proj.sin(), xyz_proj.cos()] 116 | 117 | # return batch x d_pos x npoints embedding 118 | final_embeds = torch.cat(final_embeds, 
dim=2).permute(0, 2, 1).contiguous() 119 | return final_embeds 120 | 121 | def forward(self, xyz, num_channels=None, input_range=None): 122 | assert isinstance(xyz, torch.Tensor) 123 | assert xyz.ndim == 3 124 | # xyz is batch x npoints x 3 125 | if self.pos_type == "sine": 126 | with torch.no_grad(): 127 | return self.get_sine_embeddings(xyz, num_channels, input_range) 128 | elif self.pos_type == "fourier": 129 | with torch.no_grad(): 130 | return self.get_fourier_embeddings(xyz, num_channels, input_range) 131 | else: 132 | raise ValueError(f"Unknown {self.pos_type}") 133 | 134 | def extra_repr(self): 135 | st = f"type={self.pos_type}, scale={self.scale}, normalize={self.normalize}" 136 | if hasattr(self, "gauss_B"): 137 | st += ( 138 | f", gaussB={self.gauss_B.shape}, gaussBsum={self.gauss_B.sum().item()}" 139 | ) 140 | return st 141 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.4.1 2 | hydra-core==1.3.2 3 | transforms3d==0.4.1 4 | lxml==4.9.2 5 | trimesh==4.0.5 6 | scipy==1.10.0 7 | UMNN==1.68 8 | healpy==1.16.2 9 | omegaconf==2.2.2 10 | plotly==5.14.1 11 | tensorboard==2.11.2 12 | Pillow==10.0.0 13 | multimethod==1.9.1 14 | ftfy 15 | regex 16 | tqdm 17 | git+https://github.com/openai/CLIP.git 18 | git+https://github.com/facebookresearch/pytorch3d.git@v0.7.2 19 | -------------------------------------------------------------------------------- /thirdparty/pointnet2/_ext_src/include/ball_query.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | 9 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 10 | const int nsample); 11 | -------------------------------------------------------------------------------- /thirdparty/pointnet2/_ext_src/include/cuda_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
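// Shared launch helpers for the kernels in this extension: opt_n_threads picks a
// power-of-two thread count (capped at TOTAL_THREADS) for a given work size,
// opt_block_config builds a 2-D block from two work sizes, and CUDA_CHECK_ERRORS
// aborts with a message if the previous kernel launch reported an error.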
5 | 6 | #ifndef _CUDA_UTILS_H 7 | #define _CUDA_UTILS_H 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | #include 17 | 18 | #define TOTAL_THREADS 512 19 | 20 | inline int opt_n_threads(int work_size) { 21 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 22 | 23 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 24 | } 25 | 26 | inline dim3 opt_block_config(int x, int y) { 27 | const int x_threads = opt_n_threads(x); 28 | const int y_threads = 29 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 30 | dim3 block_config(x_threads, y_threads, 1); 31 | 32 | return block_config; 33 | } 34 | 35 | #define CUDA_CHECK_ERRORS() \ 36 | do { \ 37 | cudaError_t err = cudaGetLastError(); \ 38 | if (cudaSuccess != err) { \ 39 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 40 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 41 | __FILE__); \ 42 | exit(-1); \ 43 | } \ 44 | } while (0) 45 | 46 | #endif 47 | -------------------------------------------------------------------------------- /thirdparty/pointnet2/_ext_src/include/cylinder_query.h: -------------------------------------------------------------------------------- 1 | // Author: chenxi-wang 2 | 3 | #pragma once 4 | #include 5 | 6 | at::Tensor cylinder_query(at::Tensor new_xyz, at::Tensor xyz, at::Tensor rot, const float radius, const float hmin, const float hmax, 7 | const int nsample); 8 | -------------------------------------------------------------------------------- /thirdparty/pointnet2/_ext_src/include/group_points.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | 9 | at::Tensor group_points(at::Tensor points, at::Tensor idx); 10 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 11 | -------------------------------------------------------------------------------- /thirdparty/pointnet2/_ext_src/include/interpolate.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | 8 | #include 9 | #include 10 | 11 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows); 12 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 13 | at::Tensor weight); 14 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 15 | at::Tensor weight, const int m); 16 | -------------------------------------------------------------------------------- /thirdparty/pointnet2/_ext_src/include/sampling.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #pragma once 7 | #include 8 | 9 | at::Tensor gather_points(at::Tensor points, at::Tensor idx); 10 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 11 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples); 12 | -------------------------------------------------------------------------------- /thirdparty/pointnet2/_ext_src/include/utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | #include 9 | 10 | #define CHECK_CUDA(x) \ 11 | do { \ 12 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor"); \ 13 | } while (0) 14 | 15 | #define CHECK_CONTIGUOUS(x) \ 16 | do { \ 17 | TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor"); \ 18 | } while (0) 19 | 20 | #define CHECK_IS_INT(x) \ 21 | do { \ 22 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \ 23 | #x " must be an int tensor"); \ 24 | } while (0) 25 | 26 | #define CHECK_IS_FLOAT(x) \ 27 | do { \ 28 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \ 29 | #x " must be a float tensor"); \ 30 | } while (0) 31 | -------------------------------------------------------------------------------- /thirdparty/pointnet2/_ext_src/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include "ball_query.h" 7 | #include "utils.h" 8 | 9 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 10 | int nsample, const float *new_xyz, 11 | const float *xyz, int *idx); 12 | 13 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 14 | const int nsample) { 15 | CHECK_CONTIGUOUS(new_xyz); 16 | CHECK_CONTIGUOUS(xyz); 17 | CHECK_IS_FLOAT(new_xyz); 18 | CHECK_IS_FLOAT(xyz); 19 | 20 | if (new_xyz.type().is_cuda()) { 21 | CHECK_CUDA(xyz); 22 | } 23 | 24 | at::Tensor idx = 25 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, 26 | at::device(new_xyz.device()).dtype(at::ScalarType::Int)); 27 | 28 | if (new_xyz.type().is_cuda()) { 29 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), 30 | radius, nsample, new_xyz.data(), 31 | xyz.data(), idx.data()); 32 | } else { 33 | TORCH_CHECK(false, "CPU not supported"); 34 | } 35 | 36 | return idx; 37 | } 38 | -------------------------------------------------------------------------------- /thirdparty/pointnet2/_ext_src/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
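// Ball query kernel: each block handles one batch element and its threads stride
// over the m query points. For every query point, the indices of the first
// nsample points of xyz lying within radius are recorded; the first hit pre-fills
// every slot, so queries with fewer than nsample neighbours return duplicates of
// that first index.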
5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include "cuda_utils.h" 11 | 12 | // input: new_xyz(b, m, 3) xyz(b, n, 3) 13 | // output: idx(b, m, nsample) 14 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius, 15 | int nsample, 16 | const float *__restrict__ new_xyz, 17 | const float *__restrict__ xyz, 18 | int *__restrict__ idx) { 19 | int batch_index = blockIdx.x; 20 | xyz += batch_index * n * 3; 21 | new_xyz += batch_index * m * 3; 22 | idx += m * nsample * batch_index; 23 | 24 | int index = threadIdx.x; 25 | int stride = blockDim.x; 26 | 27 | float radius2 = radius * radius; 28 | for (int j = index; j < m; j += stride) { 29 | float new_x = new_xyz[j * 3 + 0]; 30 | float new_y = new_xyz[j * 3 + 1]; 31 | float new_z = new_xyz[j * 3 + 2]; 32 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { 33 | float x = xyz[k * 3 + 0]; 34 | float y = xyz[k * 3 + 1]; 35 | float z = xyz[k * 3 + 2]; 36 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + 37 | (new_z - z) * (new_z - z); 38 | if (d2 < radius2) { 39 | if (cnt == 0) { 40 | for (int l = 0; l < nsample; ++l) { 41 | idx[j * nsample + l] = k; 42 | } 43 | } 44 | idx[j * nsample + cnt] = k; 45 | ++cnt; 46 | } 47 | } 48 | } 49 | } 50 | 51 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 52 | int nsample, const float *new_xyz, 53 | const float *xyz, int *idx) { 54 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 55 | query_ball_point_kernel<<>>( 56 | b, n, m, radius, nsample, new_xyz, xyz, idx); 57 | 58 | CUDA_CHECK_ERRORS(); 59 | } 60 | -------------------------------------------------------------------------------- /thirdparty/pointnet2/_ext_src/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #include "ball_query.h" 7 | #include "group_points.h" 8 | #include "interpolate.h" 9 | #include "sampling.h" 10 | #include "cylinder_query.h" 11 | 12 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 13 | m.def("gather_points", &gather_points); 14 | m.def("gather_points_grad", &gather_points_grad); 15 | m.def("furthest_point_sampling", &furthest_point_sampling); 16 | 17 | m.def("three_nn", &three_nn); 18 | m.def("three_interpolate", &three_interpolate); 19 | m.def("three_interpolate_grad", &three_interpolate_grad); 20 | 21 | m.def("ball_query", &ball_query); 22 | 23 | m.def("group_points", &group_points); 24 | m.def("group_points_grad", &group_points_grad); 25 | 26 | m.def("cylinder_query", &cylinder_query); 27 | } 28 | -------------------------------------------------------------------------------- /thirdparty/pointnet2/_ext_src/src/cylinder_query.cpp: -------------------------------------------------------------------------------- 1 | // Author: chenxi-wang 2 | 3 | #include "cylinder_query.h" 4 | #include "utils.h" 5 | 6 | void query_cylinder_point_kernel_wrapper(int b, int n, int m, float radius, float hmin, float hmax, 7 | int nsample, const float *new_xyz, 8 | const float *xyz, const float *rot, int *idx); 9 | 10 | at::Tensor cylinder_query(at::Tensor new_xyz, at::Tensor xyz, at::Tensor rot, const float radius, const float hmin, const float hmax, 11 | const int nsample) { 12 | CHECK_CONTIGUOUS(new_xyz); 13 | CHECK_CONTIGUOUS(xyz); 14 | CHECK_CONTIGUOUS(rot); 15 | CHECK_IS_FLOAT(new_xyz); 16 | CHECK_IS_FLOAT(xyz); 17 | CHECK_IS_FLOAT(rot); 18 | 19 | if (new_xyz.type().is_cuda()) { 20 | CHECK_CUDA(xyz); 21 | CHECK_CUDA(rot); 22 | } 23 | 24 | at::Tensor idx = 25 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, 26 | at::device(new_xyz.device()).dtype(at::ScalarType::Int)); 27 | 28 | if (new_xyz.type().is_cuda()) { 29 | query_cylinder_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), 30 | radius, hmin, hmax, nsample, new_xyz.data(), 31 | xyz.data(), rot.data(), idx.data()); 32 | } else { 33 | TORCH_CHECK(false, "CPU not supported"); 34 | } 35 | 36 | return idx; 37 | } 38 | -------------------------------------------------------------------------------- /thirdparty/pointnet2/_ext_src/src/cylinder_query_gpu.cu: -------------------------------------------------------------------------------- 1 | // Author: chenxi-wang 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "cuda_utils.h" 8 | 9 | __global__ void query_cylinder_point_kernel(int b, int n, int m, float radius, float hmin, float hmax, 10 | int nsample, 11 | const float *__restrict__ new_xyz, 12 | const float *__restrict__ xyz, 13 | const float *__restrict__ rot, 14 | int *__restrict__ idx) { 15 | int batch_index = blockIdx.x; 16 | xyz += batch_index * n * 3; 17 | new_xyz += batch_index * m * 3; 18 | rot += batch_index * m * 9; 19 | idx += m * nsample * batch_index; 20 | 21 | int index = threadIdx.x; 22 | int stride = blockDim.x; 23 | 24 | float radius2 = radius * radius; 25 | for (int j = index; j < m; j += stride) { 26 | float new_x = new_xyz[j * 3 + 0]; 27 | float new_y = new_xyz[j * 3 + 1]; 28 | float new_z = new_xyz[j * 3 + 2]; 29 | float r0 = rot[j * 9 + 0]; 30 | float r1 = rot[j * 9 + 1]; 31 | float r2 = rot[j * 9 + 2]; 32 | float r3 = rot[j * 9 + 3]; 33 | float r4 = rot[j * 9 + 4]; 34 | float r5 = rot[j * 9 + 5]; 35 | float r6 = rot[j * 9 + 6]; 36 | float r7 = rot[j * 9 + 7]; 37 | float r8 = rot[j * 9 + 8]; 38 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { 39 | float x = xyz[k * 3 
+ 0] - new_x; 40 | float y = xyz[k * 3 + 1] - new_y; 41 | float z = xyz[k * 3 + 2] - new_z; 42 | float x_rot = r0 * x + r3 * y + r6 * z; 43 | float y_rot = r1 * x + r4 * y + r7 * z; 44 | float z_rot = r2 * x + r5 * y + r8 * z; 45 | float d2 = y_rot * y_rot + z_rot * z_rot; 46 | if (d2 < radius2 && x_rot > hmin && x_rot < hmax) { 47 | if (cnt == 0) { 48 | for (int l = 0; l < nsample; ++l) { 49 | idx[j * nsample + l] = k; 50 | } 51 | } 52 | idx[j * nsample + cnt] = k; 53 | ++cnt; 54 | } 55 | } 56 | } 57 | } 58 | 59 | void query_cylinder_point_kernel_wrapper(int b, int n, int m, float radius, float hmin, float hmax, 60 | int nsample, const float *new_xyz, 61 | const float *xyz, const float *rot, int *idx) { 62 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 63 | query_cylinder_point_kernel<<>>( 64 | b, n, m, radius, hmin, hmax, nsample, new_xyz, xyz, rot, idx); 65 | 66 | CUDA_CHECK_ERRORS(); 67 | } 68 | -------------------------------------------------------------------------------- /thirdparty/pointnet2/_ext_src/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include "group_points.h" 7 | #include "utils.h" 8 | 9 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 10 | const float *points, const int *idx, 11 | float *out); 12 | 13 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 14 | int nsample, const float *grad_out, 15 | const int *idx, float *grad_points); 16 | 17 | at::Tensor group_points(at::Tensor points, at::Tensor idx) { 18 | CHECK_CONTIGUOUS(points); 19 | CHECK_CONTIGUOUS(idx); 20 | CHECK_IS_FLOAT(points); 21 | CHECK_IS_INT(idx); 22 | 23 | if (points.type().is_cuda()) { 24 | CHECK_CUDA(idx); 25 | } 26 | 27 | at::Tensor output = 28 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)}, 29 | at::device(points.device()).dtype(at::ScalarType::Float)); 30 | 31 | if (points.type().is_cuda()) { 32 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 33 | idx.size(1), idx.size(2), points.data(), 34 | idx.data(), output.data()); 35 | } else { 36 | TORCH_CHECK(false, "CPU not supported"); 37 | } 38 | 39 | return output; 40 | } 41 | 42 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) { 43 | CHECK_CONTIGUOUS(grad_out); 44 | CHECK_CONTIGUOUS(idx); 45 | CHECK_IS_FLOAT(grad_out); 46 | CHECK_IS_INT(idx); 47 | 48 | if (grad_out.type().is_cuda()) { 49 | CHECK_CUDA(idx); 50 | } 51 | 52 | at::Tensor output = 53 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 54 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 55 | 56 | if (grad_out.type().is_cuda()) { 57 | group_points_grad_kernel_wrapper( 58 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2), 59 | grad_out.data(), idx.data(), output.data()); 60 | } else { 61 | TORCH_CHECK(false, "CPU not supported"); 62 | } 63 | 64 | return output; 65 | } 66 | -------------------------------------------------------------------------------- /thirdparty/pointnet2/_ext_src/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 
2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include 7 | #include 8 | 9 | #include "cuda_utils.h" 10 | 11 | // input: points(b, c, n) idx(b, npoints, nsample) 12 | // output: out(b, c, npoints, nsample) 13 | __global__ void group_points_kernel(int b, int c, int n, int npoints, 14 | int nsample, 15 | const float *__restrict__ points, 16 | const int *__restrict__ idx, 17 | float *__restrict__ out) { 18 | int batch_index = blockIdx.x; 19 | points += batch_index * n * c; 20 | idx += batch_index * npoints * nsample; 21 | out += batch_index * npoints * nsample * c; 22 | 23 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 24 | const int stride = blockDim.y * blockDim.x; 25 | for (int i = index; i < c * npoints; i += stride) { 26 | const int l = i / npoints; 27 | const int j = i % npoints; 28 | for (int k = 0; k < nsample; ++k) { 29 | int ii = idx[j * nsample + k]; 30 | out[(l * npoints + j) * nsample + k] = points[l * n + ii]; 31 | } 32 | } 33 | } 34 | 35 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 36 | const float *points, const int *idx, 37 | float *out) { 38 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 39 | 40 | group_points_kernel<<>>( 41 | b, c, n, npoints, nsample, points, idx, out); 42 | 43 | CUDA_CHECK_ERRORS(); 44 | } 45 | 46 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample) 47 | // output: grad_points(b, c, n) 48 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints, 49 | int nsample, 50 | const float *__restrict__ grad_out, 51 | const int *__restrict__ idx, 52 | float *__restrict__ grad_points) { 53 | int batch_index = blockIdx.x; 54 | grad_out += batch_index * npoints * nsample * c; 55 | idx += batch_index * npoints * nsample; 56 | grad_points += batch_index * n * c; 57 | 58 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 59 | const int stride = blockDim.y * blockDim.x; 60 | for (int i = index; i < c * npoints; i += stride) { 61 | const int l = i / npoints; 62 | const int j = i % npoints; 63 | for (int k = 0; k < nsample; ++k) { 64 | int ii = idx[j * nsample + k]; 65 | atomicAdd(grad_points + l * n + ii, 66 | grad_out[(l * npoints + j) * nsample + k]); 67 | } 68 | } 69 | } 70 | 71 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 72 | int nsample, const float *grad_out, 73 | const int *idx, float *grad_points) { 74 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 75 | 76 | group_points_grad_kernel<<>>( 77 | b, c, n, npoints, nsample, grad_out, idx, grad_points); 78 | 79 | CUDA_CHECK_ERRORS(); 80 | } 81 | -------------------------------------------------------------------------------- /thirdparty/pointnet2/_ext_src/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
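// Host-side wrappers: each function checks that its inputs are contiguous tensors
// of the expected dtype, allocates the output on the input's device, and launches
// the corresponding CUDA kernel (CPU execution is not supported).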
5 | 6 | #include "interpolate.h" 7 | #include "utils.h" 8 | 9 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 10 | const float *known, float *dist2, int *idx); 11 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 12 | const float *points, const int *idx, 13 | const float *weight, float *out); 14 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 15 | const float *grad_out, 16 | const int *idx, const float *weight, 17 | float *grad_points); 18 | 19 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows) { 20 | CHECK_CONTIGUOUS(unknowns); 21 | CHECK_CONTIGUOUS(knows); 22 | CHECK_IS_FLOAT(unknowns); 23 | CHECK_IS_FLOAT(knows); 24 | 25 | if (unknowns.type().is_cuda()) { 26 | CHECK_CUDA(knows); 27 | } 28 | 29 | at::Tensor idx = 30 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 31 | at::device(unknowns.device()).dtype(at::ScalarType::Int)); 32 | at::Tensor dist2 = 33 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 34 | at::device(unknowns.device()).dtype(at::ScalarType::Float)); 35 | 36 | if (unknowns.type().is_cuda()) { 37 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1), 38 | unknowns.data(), knows.data(), 39 | dist2.data(), idx.data()); 40 | } else { 41 | TORCH_CHECK(false, "CPU not supported"); 42 | } 43 | 44 | return {dist2, idx}; 45 | } 46 | 47 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 48 | at::Tensor weight) { 49 | CHECK_CONTIGUOUS(points); 50 | CHECK_CONTIGUOUS(idx); 51 | CHECK_CONTIGUOUS(weight); 52 | CHECK_IS_FLOAT(points); 53 | CHECK_IS_INT(idx); 54 | CHECK_IS_FLOAT(weight); 55 | 56 | if (points.type().is_cuda()) { 57 | CHECK_CUDA(idx); 58 | CHECK_CUDA(weight); 59 | } 60 | 61 | at::Tensor output = 62 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 63 | at::device(points.device()).dtype(at::ScalarType::Float)); 64 | 65 | if (points.type().is_cuda()) { 66 | three_interpolate_kernel_wrapper( 67 | points.size(0), points.size(1), points.size(2), idx.size(1), 68 | points.data(), idx.data(), weight.data(), 69 | output.data()); 70 | } else { 71 | TORCH_CHECK(false, "CPU not supported"); 72 | } 73 | 74 | return output; 75 | } 76 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 77 | at::Tensor weight, const int m) { 78 | CHECK_CONTIGUOUS(grad_out); 79 | CHECK_CONTIGUOUS(idx); 80 | CHECK_CONTIGUOUS(weight); 81 | CHECK_IS_FLOAT(grad_out); 82 | CHECK_IS_INT(idx); 83 | CHECK_IS_FLOAT(weight); 84 | 85 | if (grad_out.type().is_cuda()) { 86 | CHECK_CUDA(idx); 87 | CHECK_CUDA(weight); 88 | } 89 | 90 | at::Tensor output = 91 | torch::zeros({grad_out.size(0), grad_out.size(1), m}, 92 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 93 | 94 | if (grad_out.type().is_cuda()) { 95 | three_interpolate_grad_kernel_wrapper( 96 | grad_out.size(0), grad_out.size(1), grad_out.size(2), m, 97 | grad_out.data(), idx.data(), weight.data(), 98 | output.data()); 99 | } else { 100 | TORCH_CHECK(false, "CPU not supported"); 101 | } 102 | 103 | return output; 104 | } 105 | -------------------------------------------------------------------------------- /thirdparty/pointnet2/_ext_src/src/interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
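// three_nn_kernel brute-force searches, for every unknown point, the three closest
// known points (one block per batch element, threads striding over points);
// three_interpolate_kernel blends the three indexed feature vectors with the given
// weights, and the grad kernel scatters gradients back with atomicAdd.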
5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include "cuda_utils.h" 11 | 12 | // input: unknown(b, n, 3) known(b, m, 3) 13 | // output: dist2(b, n, 3), idx(b, n, 3) 14 | __global__ void three_nn_kernel(int b, int n, int m, 15 | const float *__restrict__ unknown, 16 | const float *__restrict__ known, 17 | float *__restrict__ dist2, 18 | int *__restrict__ idx) { 19 | int batch_index = blockIdx.x; 20 | unknown += batch_index * n * 3; 21 | known += batch_index * m * 3; 22 | dist2 += batch_index * n * 3; 23 | idx += batch_index * n * 3; 24 | 25 | int index = threadIdx.x; 26 | int stride = blockDim.x; 27 | for (int j = index; j < n; j += stride) { 28 | float ux = unknown[j * 3 + 0]; 29 | float uy = unknown[j * 3 + 1]; 30 | float uz = unknown[j * 3 + 2]; 31 | 32 | double best1 = 1e40, best2 = 1e40, best3 = 1e40; 33 | int besti1 = 0, besti2 = 0, besti3 = 0; 34 | for (int k = 0; k < m; ++k) { 35 | float x = known[k * 3 + 0]; 36 | float y = known[k * 3 + 1]; 37 | float z = known[k * 3 + 2]; 38 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 39 | if (d < best1) { 40 | best3 = best2; 41 | besti3 = besti2; 42 | best2 = best1; 43 | besti2 = besti1; 44 | best1 = d; 45 | besti1 = k; 46 | } else if (d < best2) { 47 | best3 = best2; 48 | besti3 = besti2; 49 | best2 = d; 50 | besti2 = k; 51 | } else if (d < best3) { 52 | best3 = d; 53 | besti3 = k; 54 | } 55 | } 56 | dist2[j * 3 + 0] = best1; 57 | dist2[j * 3 + 1] = best2; 58 | dist2[j * 3 + 2] = best3; 59 | 60 | idx[j * 3 + 0] = besti1; 61 | idx[j * 3 + 1] = besti2; 62 | idx[j * 3 + 2] = besti3; 63 | } 64 | } 65 | 66 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 67 | const float *known, float *dist2, int *idx) { 68 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 69 | three_nn_kernel<<>>(b, n, m, unknown, known, 70 | dist2, idx); 71 | 72 | CUDA_CHECK_ERRORS(); 73 | } 74 | 75 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3) 76 | // output: out(b, c, n) 77 | __global__ void three_interpolate_kernel(int b, int c, int m, int n, 78 | const float *__restrict__ points, 79 | const int *__restrict__ idx, 80 | const float *__restrict__ weight, 81 | float *__restrict__ out) { 82 | int batch_index = blockIdx.x; 83 | points += batch_index * m * c; 84 | 85 | idx += batch_index * n * 3; 86 | weight += batch_index * n * 3; 87 | 88 | out += batch_index * n * c; 89 | 90 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 91 | const int stride = blockDim.y * blockDim.x; 92 | for (int i = index; i < c * n; i += stride) { 93 | const int l = i / n; 94 | const int j = i % n; 95 | float w1 = weight[j * 3 + 0]; 96 | float w2 = weight[j * 3 + 1]; 97 | float w3 = weight[j * 3 + 2]; 98 | 99 | int i1 = idx[j * 3 + 0]; 100 | int i2 = idx[j * 3 + 1]; 101 | int i3 = idx[j * 3 + 2]; 102 | 103 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 + 104 | points[l * m + i3] * w3; 105 | } 106 | } 107 | 108 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 109 | const float *points, const int *idx, 110 | const float *weight, float *out) { 111 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 112 | three_interpolate_kernel<<>>( 113 | b, c, m, n, points, idx, weight, out); 114 | 115 | CUDA_CHECK_ERRORS(); 116 | } 117 | 118 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3) 119 | // output: grad_points(b, c, m) 120 | 121 | __global__ void three_interpolate_grad_kernel( 122 | int b, int c, int n, int m, const float *__restrict__ grad_out, 123 | const int 
*__restrict__ idx, const float *__restrict__ weight, 124 | float *__restrict__ grad_points) { 125 | int batch_index = blockIdx.x; 126 | grad_out += batch_index * n * c; 127 | idx += batch_index * n * 3; 128 | weight += batch_index * n * 3; 129 | grad_points += batch_index * m * c; 130 | 131 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 132 | const int stride = blockDim.y * blockDim.x; 133 | for (int i = index; i < c * n; i += stride) { 134 | const int l = i / n; 135 | const int j = i % n; 136 | float w1 = weight[j * 3 + 0]; 137 | float w2 = weight[j * 3 + 1]; 138 | float w3 = weight[j * 3 + 2]; 139 | 140 | int i1 = idx[j * 3 + 0]; 141 | int i2 = idx[j * 3 + 1]; 142 | int i3 = idx[j * 3 + 2]; 143 | 144 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1); 145 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2); 146 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3); 147 | } 148 | } 149 | 150 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 151 | const float *grad_out, 152 | const int *idx, const float *weight, 153 | float *grad_points) { 154 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 155 | three_interpolate_grad_kernel<<>>( 156 | b, c, n, m, grad_out, idx, weight, grad_points); 157 | 158 | CUDA_CHECK_ERRORS(); 159 | } 160 | -------------------------------------------------------------------------------- /thirdparty/pointnet2/_ext_src/src/sampling.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include "sampling.h" 7 | #include "utils.h" 8 | 9 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 10 | const float *points, const int *idx, 11 | float *out); 12 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 13 | const float *grad_out, const int *idx, 14 | float *grad_points); 15 | 16 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 17 | const float *dataset, float *temp, 18 | int *idxs); 19 | 20 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) { 21 | CHECK_CONTIGUOUS(points); 22 | CHECK_CONTIGUOUS(idx); 23 | CHECK_IS_FLOAT(points); 24 | CHECK_IS_INT(idx); 25 | 26 | if (points.type().is_cuda()) { 27 | CHECK_CUDA(idx); 28 | } 29 | 30 | at::Tensor output = 31 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 32 | at::device(points.device()).dtype(at::ScalarType::Float)); 33 | 34 | if (points.type().is_cuda()) { 35 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 36 | idx.size(1), points.data(), 37 | idx.data(), output.data()); 38 | } else { 39 | TORCH_CHECK(false, "CPU not supported"); 40 | } 41 | 42 | return output; 43 | } 44 | 45 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, 46 | const int n) { 47 | CHECK_CONTIGUOUS(grad_out); 48 | CHECK_CONTIGUOUS(idx); 49 | CHECK_IS_FLOAT(grad_out); 50 | CHECK_IS_INT(idx); 51 | 52 | if (grad_out.type().is_cuda()) { 53 | CHECK_CUDA(idx); 54 | } 55 | 56 | at::Tensor output = 57 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 58 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 59 | 60 | if (grad_out.type().is_cuda()) { 61 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n, 62 | idx.size(1), grad_out.data(), 63 | idx.data(), output.data()); 64 | } else { 65 | TORCH_CHECK(false, 
"CPU not supported"); 66 | } 67 | 68 | return output; 69 | } 70 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) { 71 | CHECK_CONTIGUOUS(points); 72 | CHECK_IS_FLOAT(points); 73 | 74 | at::Tensor output = 75 | torch::zeros({points.size(0), nsamples}, 76 | at::device(points.device()).dtype(at::ScalarType::Int)); 77 | 78 | at::Tensor tmp = 79 | torch::full({points.size(0), points.size(1)}, 1e10, 80 | at::device(points.device()).dtype(at::ScalarType::Float)); 81 | 82 | if (points.type().is_cuda()) { 83 | furthest_point_sampling_kernel_wrapper( 84 | points.size(0), points.size(1), nsamples, points.data(), 85 | tmp.data(), output.data()); 86 | } else { 87 | TORCH_CHECK(false, "CPU not supported"); 88 | } 89 | 90 | return output; 91 | } 92 | -------------------------------------------------------------------------------- /thirdparty/pointnet2/_ext_src/src/sampling_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include 7 | #include 8 | 9 | #include "cuda_utils.h" 10 | 11 | // input: points(b, c, n) idx(b, m) 12 | // output: out(b, c, m) 13 | __global__ void gather_points_kernel(int b, int c, int n, int m, 14 | const float *__restrict__ points, 15 | const int *__restrict__ idx, 16 | float *__restrict__ out) { 17 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 18 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 19 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 20 | int a = idx[i * m + j]; 21 | out[(i * c + l) * m + j] = points[(i * c + l) * n + a]; 22 | } 23 | } 24 | } 25 | } 26 | 27 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 28 | const float *points, const int *idx, 29 | float *out) { 30 | gather_points_kernel<<>>(b, c, n, npoints, 32 | points, idx, out); 33 | 34 | CUDA_CHECK_ERRORS(); 35 | } 36 | 37 | // input: grad_out(b, c, m) idx(b, m) 38 | // output: grad_points(b, c, n) 39 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m, 40 | const float *__restrict__ grad_out, 41 | const int *__restrict__ idx, 42 | float *__restrict__ grad_points) { 43 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 44 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 45 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 46 | int a = idx[i * m + j]; 47 | atomicAdd(grad_points + (i * c + l) * n + a, 48 | grad_out[(i * c + l) * m + j]); 49 | } 50 | } 51 | } 52 | } 53 | 54 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 55 | const float *grad_out, const int *idx, 56 | float *grad_points) { 57 | gather_points_grad_kernel<<>>( 59 | b, c, n, npoints, grad_out, idx, grad_points); 60 | 61 | CUDA_CHECK_ERRORS(); 62 | } 63 | 64 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, 65 | int idx1, int idx2) { 66 | const float v1 = dists[idx1], v2 = dists[idx2]; 67 | const int i1 = dists_i[idx1], i2 = dists_i[idx2]; 68 | dists[idx1] = max(v1, v2); 69 | dists_i[idx1] = v2 > v1 ? 
i2 : i1; 70 | } 71 | 72 | // Input dataset: (b, n, 3), tmp: (b, n) 73 | // Ouput idxs (b, m) 74 | template 75 | __global__ void furthest_point_sampling_kernel( 76 | int b, int n, int m, const float *__restrict__ dataset, 77 | float *__restrict__ temp, int *__restrict__ idxs) { 78 | if (m <= 0) return; 79 | __shared__ float dists[block_size]; 80 | __shared__ int dists_i[block_size]; 81 | 82 | int batch_index = blockIdx.x; 83 | dataset += batch_index * n * 3; 84 | temp += batch_index * n; 85 | idxs += batch_index * m; 86 | 87 | int tid = threadIdx.x; 88 | const int stride = block_size; 89 | 90 | int old = 0; 91 | if (threadIdx.x == 0) idxs[0] = old; 92 | 93 | __syncthreads(); 94 | for (int j = 1; j < m; j++) { 95 | int besti = 0; 96 | float best = -1; 97 | float x1 = dataset[old * 3 + 0]; 98 | float y1 = dataset[old * 3 + 1]; 99 | float z1 = dataset[old * 3 + 2]; 100 | for (int k = tid; k < n; k += stride) { 101 | float x2, y2, z2; 102 | x2 = dataset[k * 3 + 0]; 103 | y2 = dataset[k * 3 + 1]; 104 | z2 = dataset[k * 3 + 2]; 105 | // float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); 106 | // if (mag <= 1e-3) continue; 107 | 108 | float d = 109 | (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); 110 | 111 | float d2 = min(d, temp[k]); 112 | temp[k] = d2; 113 | besti = d2 > best ? k : besti; 114 | best = d2 > best ? d2 : best; 115 | } 116 | dists[tid] = best; 117 | dists_i[tid] = besti; 118 | __syncthreads(); 119 | 120 | if (block_size >= 512) { 121 | if (tid < 256) { 122 | __update(dists, dists_i, tid, tid + 256); 123 | } 124 | __syncthreads(); 125 | } 126 | if (block_size >= 256) { 127 | if (tid < 128) { 128 | __update(dists, dists_i, tid, tid + 128); 129 | } 130 | __syncthreads(); 131 | } 132 | if (block_size >= 128) { 133 | if (tid < 64) { 134 | __update(dists, dists_i, tid, tid + 64); 135 | } 136 | __syncthreads(); 137 | } 138 | if (block_size >= 64) { 139 | if (tid < 32) { 140 | __update(dists, dists_i, tid, tid + 32); 141 | } 142 | __syncthreads(); 143 | } 144 | if (block_size >= 32) { 145 | if (tid < 16) { 146 | __update(dists, dists_i, tid, tid + 16); 147 | } 148 | __syncthreads(); 149 | } 150 | if (block_size >= 16) { 151 | if (tid < 8) { 152 | __update(dists, dists_i, tid, tid + 8); 153 | } 154 | __syncthreads(); 155 | } 156 | if (block_size >= 8) { 157 | if (tid < 4) { 158 | __update(dists, dists_i, tid, tid + 4); 159 | } 160 | __syncthreads(); 161 | } 162 | if (block_size >= 4) { 163 | if (tid < 2) { 164 | __update(dists, dists_i, tid, tid + 2); 165 | } 166 | __syncthreads(); 167 | } 168 | if (block_size >= 2) { 169 | if (tid < 1) { 170 | __update(dists, dists_i, tid, tid + 1); 171 | } 172 | __syncthreads(); 173 | } 174 | 175 | old = dists_i[0]; 176 | if (tid == 0) idxs[j] = old; 177 | } 178 | } 179 | 180 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 181 | const float *dataset, float *temp, 182 | int *idxs) { 183 | unsigned int n_threads = opt_n_threads(n); 184 | 185 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 186 | 187 | switch (n_threads) { 188 | case 512: 189 | furthest_point_sampling_kernel<512> 190 | <<>>(b, n, m, dataset, temp, idxs); 191 | break; 192 | case 256: 193 | furthest_point_sampling_kernel<256> 194 | <<>>(b, n, m, dataset, temp, idxs); 195 | break; 196 | case 128: 197 | furthest_point_sampling_kernel<128> 198 | <<>>(b, n, m, dataset, temp, idxs); 199 | break; 200 | case 64: 201 | furthest_point_sampling_kernel<64> 202 | <<>>(b, n, m, dataset, temp, idxs); 203 | break; 204 | case 32: 205 | 
furthest_point_sampling_kernel<32> 206 | <<>>(b, n, m, dataset, temp, idxs); 207 | break; 208 | case 16: 209 | furthest_point_sampling_kernel<16> 210 | <<>>(b, n, m, dataset, temp, idxs); 211 | break; 212 | case 8: 213 | furthest_point_sampling_kernel<8> 214 | <<>>(b, n, m, dataset, temp, idxs); 215 | break; 216 | case 4: 217 | furthest_point_sampling_kernel<4> 218 | <<>>(b, n, m, dataset, temp, idxs); 219 | break; 220 | case 2: 221 | furthest_point_sampling_kernel<2> 222 | <<>>(b, n, m, dataset, temp, idxs); 223 | break; 224 | case 1: 225 | furthest_point_sampling_kernel<1> 226 | <<>>(b, n, m, dataset, temp, idxs); 227 | break; 228 | default: 229 | furthest_point_sampling_kernel<512> 230 | <<>>(b, n, m, dataset, temp, idxs); 231 | } 232 | 233 | CUDA_CHECK_ERRORS(); 234 | } 235 | -------------------------------------------------------------------------------- /thirdparty/pointnet2/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | ''' Modified based on Ref: https://github.com/erikwijmans/Pointnet2_PyTorch ''' 7 | import torch 8 | import torch.nn as nn 9 | from typing import List, Tuple 10 | 11 | class SharedMLP(nn.Sequential): 12 | 13 | def __init__( 14 | self, 15 | args: List[int], 16 | *, 17 | bn: bool = False, 18 | activation=nn.ReLU(inplace=True), 19 | preact: bool = False, 20 | first: bool = False, 21 | name: str = "" 22 | ): 23 | super().__init__() 24 | 25 | for i in range(len(args) - 1): 26 | self.add_module( 27 | name + 'layer{}'.format(i), 28 | Conv2d( 29 | args[i], 30 | args[i + 1], 31 | bn=(not first or not preact or (i != 0)) and bn, 32 | activation=activation 33 | if (not first or not preact or (i != 0)) else None, 34 | preact=preact 35 | ) 36 | ) 37 | 38 | 39 | class _BNBase(nn.Sequential): 40 | 41 | def __init__(self, in_size, batch_norm=None, name=""): 42 | super().__init__() 43 | self.add_module(name + "bn", batch_norm(in_size)) 44 | 45 | nn.init.constant_(self[0].weight, 1.0) 46 | nn.init.constant_(self[0].bias, 0) 47 | 48 | 49 | class BatchNorm1d(_BNBase): 50 | 51 | def __init__(self, in_size: int, *, name: str = ""): 52 | super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name) 53 | 54 | 55 | class BatchNorm2d(_BNBase): 56 | 57 | def __init__(self, in_size: int, name: str = ""): 58 | super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name) 59 | 60 | 61 | class BatchNorm3d(_BNBase): 62 | 63 | def __init__(self, in_size: int, name: str = ""): 64 | super().__init__(in_size, batch_norm=nn.BatchNorm3d, name=name) 65 | 66 | 67 | class _ConvBase(nn.Sequential): 68 | 69 | def __init__( 70 | self, 71 | in_size, 72 | out_size, 73 | kernel_size, 74 | stride, 75 | padding, 76 | activation, 77 | bn, 78 | init, 79 | conv=None, 80 | batch_norm=None, 81 | bias=True, 82 | preact=False, 83 | name="" 84 | ): 85 | super().__init__() 86 | 87 | bias = bias and (not bn) 88 | conv_unit = conv( 89 | in_size, 90 | out_size, 91 | kernel_size=kernel_size, 92 | stride=stride, 93 | padding=padding, 94 | bias=bias 95 | ) 96 | init(conv_unit.weight) 97 | if bias: 98 | nn.init.constant_(conv_unit.bias, 0) 99 | 100 | if bn: 101 | if not preact: 102 | bn_unit = batch_norm(out_size) 103 | else: 104 | bn_unit = batch_norm(in_size) 105 | 106 | if preact: 107 | if bn: 108 | self.add_module(name + 'bn', bn_unit) 109 | 110 | if activation is not None: 
111 | self.add_module(name + 'activation', activation) 112 | 113 | self.add_module(name + 'conv', conv_unit) 114 | 115 | if not preact: 116 | if bn: 117 | self.add_module(name + 'bn', bn_unit) 118 | 119 | if activation is not None: 120 | self.add_module(name + 'activation', activation) 121 | 122 | 123 | class Conv1d(_ConvBase): 124 | 125 | def __init__( 126 | self, 127 | in_size: int, 128 | out_size: int, 129 | *, 130 | kernel_size: int = 1, 131 | stride: int = 1, 132 | padding: int = 0, 133 | activation=nn.ReLU(inplace=True), 134 | bn: bool = False, 135 | init=nn.init.kaiming_normal_, 136 | bias: bool = True, 137 | preact: bool = False, 138 | name: str = "" 139 | ): 140 | super().__init__( 141 | in_size, 142 | out_size, 143 | kernel_size, 144 | stride, 145 | padding, 146 | activation, 147 | bn, 148 | init, 149 | conv=nn.Conv1d, 150 | batch_norm=BatchNorm1d, 151 | bias=bias, 152 | preact=preact, 153 | name=name 154 | ) 155 | 156 | 157 | class Conv2d(_ConvBase): 158 | 159 | def __init__( 160 | self, 161 | in_size: int, 162 | out_size: int, 163 | *, 164 | kernel_size: Tuple[int, int] = (1, 1), 165 | stride: Tuple[int, int] = (1, 1), 166 | padding: Tuple[int, int] = (0, 0), 167 | activation=nn.ReLU(inplace=True), 168 | bn: bool = False, 169 | init=nn.init.kaiming_normal_, 170 | bias: bool = True, 171 | preact: bool = False, 172 | name: str = "" 173 | ): 174 | super().__init__( 175 | in_size, 176 | out_size, 177 | kernel_size, 178 | stride, 179 | padding, 180 | activation, 181 | bn, 182 | init, 183 | conv=nn.Conv2d, 184 | batch_norm=BatchNorm2d, 185 | bias=bias, 186 | preact=preact, 187 | name=name 188 | ) 189 | 190 | 191 | class Conv3d(_ConvBase): 192 | 193 | def __init__( 194 | self, 195 | in_size: int, 196 | out_size: int, 197 | *, 198 | kernel_size: Tuple[int, int, int] = (1, 1, 1), 199 | stride: Tuple[int, int, int] = (1, 1, 1), 200 | padding: Tuple[int, int, int] = (0, 0, 0), 201 | activation=nn.ReLU(inplace=True), 202 | bn: bool = False, 203 | init=nn.init.kaiming_normal_, 204 | bias: bool = True, 205 | preact: bool = False, 206 | name: str = "" 207 | ): 208 | super().__init__( 209 | in_size, 210 | out_size, 211 | kernel_size, 212 | stride, 213 | padding, 214 | activation, 215 | bn, 216 | init, 217 | conv=nn.Conv3d, 218 | batch_norm=BatchNorm3d, 219 | bias=bias, 220 | preact=preact, 221 | name=name 222 | ) 223 | 224 | 225 | class FC(nn.Sequential): 226 | 227 | def __init__( 228 | self, 229 | in_size: int, 230 | out_size: int, 231 | *, 232 | activation=nn.ReLU(inplace=True), 233 | bn: bool = False, 234 | init=None, 235 | preact: bool = False, 236 | name: str = "" 237 | ): 238 | super().__init__() 239 | 240 | fc = nn.Linear(in_size, out_size, bias=not bn) 241 | if init is not None: 242 | init(fc.weight) 243 | if not bn: 244 | nn.init.constant_(fc.bias, 0) 245 | 246 | if preact: 247 | if bn: 248 | self.add_module(name + 'bn', BatchNorm1d(in_size)) 249 | 250 | if activation is not None: 251 | self.add_module(name + 'activation', activation) 252 | 253 | self.add_module(name + 'fc', fc) 254 | 255 | if not preact: 256 | if bn: 257 | self.add_module(name + 'bn', BatchNorm1d(out_size)) 258 | 259 | if activation is not None: 260 | self.add_module(name + 'activation', activation) 261 | 262 | def set_bn_momentum_default(bn_momentum): 263 | 264 | def fn(m): 265 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 266 | m.momentum = bn_momentum 267 | 268 | return fn 269 | 270 | 271 | class BNMomentumScheduler(object): 272 | 273 | def __init__( 274 | self, model, bn_lambda, 
last_epoch=-1, 275 | setter=set_bn_momentum_default 276 | ): 277 | if not isinstance(model, nn.Module): 278 | raise RuntimeError( 279 | "Class '{}' is not a PyTorch nn Module".format( 280 | type(model).__name__ 281 | ) 282 | ) 283 | 284 | self.model = model 285 | self.setter = setter 286 | self.lmbd = bn_lambda 287 | 288 | self.step(last_epoch + 1) 289 | self.last_epoch = last_epoch 290 | 291 | def step(self, epoch=None): 292 | if epoch is None: 293 | epoch = self.last_epoch + 1 294 | 295 | self.last_epoch = epoch 296 | self.model.apply(self.setter(self.lmbd(epoch))) 297 | 298 | 299 | -------------------------------------------------------------------------------- /thirdparty/pointnet2/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from setuptools import setup 7 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 8 | import glob 9 | import os 10 | ROOT = os.path.dirname(os.path.abspath(__file__)) 11 | 12 | _ext_src_root = "_ext_src" 13 | _ext_sources = glob.glob("{}/src/*.cpp".format(_ext_src_root)) + glob.glob( 14 | "{}/src/*.cu".format(_ext_src_root) 15 | ) 16 | _ext_headers = glob.glob("{}/include/*".format(_ext_src_root)) 17 | 18 | setup( 19 | name='pointnet2', 20 | ext_modules=[ 21 | CUDAExtension( 22 | name='pointnet2._ext', 23 | sources=_ext_sources, 24 | extra_compile_args={ 25 | "cxx": ["-O2", "-I{}".format("{}/{}/include".format(ROOT, _ext_src_root))], 26 | "nvcc": ["-O2", "-I{}".format("{}/{}/include".format(ROOT, _ext_src_root))], 27 | }, 28 | ) 29 | ], 30 | cmdclass={ 31 | 'build_ext': BuildExtension 32 | } 33 | ) 34 | -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *.egg-info 3 | __pycache__ 4 | temp* -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Robot Kinematics 2 | - Parallel and differentiable forward kinematics (FK) and Jacobian calculation 3 | - Load robot description from URDF, SDF, and MJCF formats 4 | 5 | # Usage 6 | Clone repository somewhere, then `pip3 install -e .` to install in editable mode. 7 | 8 | See `tests` for code samples; some are also shown here. 
9 | 10 | ## Forward Kinematics (FK) 11 | ```python 12 | import math 13 | import pytorch_kinematics as pk 14 | 15 | # load robot description from URDF and specify end effector link 16 | chain = pk.build_serial_chain_from_urdf(open("kuka_iiwa.urdf").read(), "lbr_iiwa_link_7") 17 | # prints out the (nested) tree of links 18 | print(chain) 19 | # prints out list of joint names 20 | print(chain.get_joint_parameter_names()) 21 | 22 | # specify joint values (can do so in many forms) 23 | th = [0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0] 24 | # do forward kinematics and get transform objects; end_only=False gives a dictionary of transforms for all links 25 | ret = chain.forward_kinematics(th, end_only=False) 26 | # look up the transform for a specific link 27 | tg = ret['lbr_iiwa_link_7'] 28 | # get transform matrix (1,4,4), then convert to separate position and unit quaternion 29 | m = tg.get_matrix() 30 | pos = m[:, :3, 3] 31 | rot = pk.matrix_to_quaternion(m[:, :3, :3]) 32 | ``` 33 | 34 | We can parallelize FK by passing in 2D joint values, and also use CUDA if available 35 | ```python 36 | import torch 37 | import pytorch_kinematics as pk 38 | 39 | d = "cuda" if torch.cuda.is_available() else "cpu" 40 | dtype = torch.float64 41 | 42 | chain = pk.build_serial_chain_from_urdf(open("kuka_iiwa.urdf").read(), "lbr_iiwa_link_7") 43 | chain = chain.to(dtype=dtype, device=d) 44 | 45 | N = 1000 46 | th_batch = torch.rand(N, len(chain.get_joint_parameter_names()), dtype=dtype, device=d) 47 | 48 | # order of magnitudes faster when doing FK in parallel 49 | # elapsed 0.008678913116455078s for N=1000 when parallel 50 | # (N,4,4) transform matrix; only the one for the end effector is returned since end_only=True by default 51 | tg_batch = chain.forward_kinematics(th_batch) 52 | 53 | # elapsed 8.44686508178711s for N=1000 when serial 54 | for i in range(N): 55 | tg = chain.forward_kinematics(th_batch[i]) 56 | ``` 57 | 58 | We can compute gradients through the FK 59 | ```python 60 | import torch 61 | import math 62 | import pytorch_kinematics as pk 63 | 64 | chain = pk.build_serial_chain_from_urdf(open("kuka_iiwa.urdf").read(), "lbr_iiwa_link_7") 65 | 66 | # require gradient through the input joint values 67 | th = torch.tensor([0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0], requires_grad=True) 68 | tg = chain.forward_kinematics(th) 69 | m = tg.get_matrix() 70 | pos = m[:, :3, 3] 71 | pos.norm().backward() 72 | # now th.grad is populated 73 | ``` 74 | 75 | We can load SDF and MJCF descriptions too, and pass in joint values via a dictionary (unspecified joints get th=0) for non-serial chains 76 | ```python 77 | import math 78 | import torch 79 | import pytorch_kinematics as pk 80 | 81 | chain = pk.build_chain_from_sdf(open("simple_arm.sdf").read()) 82 | ret = chain.forward_kinematics({'arm_elbow_pan_joint': math.pi / 2.0, 'arm_wrist_lift_joint': -0.5}) 83 | # recall that we specify joint values and get link transforms 84 | tg = ret['arm_wrist_roll'] 85 | 86 | # can also do this in parallel 87 | N = 100 88 | ret = chain.forward_kinematics({'arm_elbow_pan_joint': torch.rand(N, 1), 'arm_wrist_lift_joint': torch.rand(N, 1)}) 89 | # (N, 4, 4) transform object 90 | tg = ret['arm_wrist_roll'] 91 | 92 | # building the robot from a MJCF file 93 | chain = pk.build_chain_from_mjcf(open("ant.xml").read()) 94 | print(chain) 95 | print(chain.get_joint_parameter_names()) 96 | th = {'hip_1': 1.0, 'ankle_1': 1} 97 | ret = chain.forward_kinematics(th) 98 | 99 | chain = 
pk.build_chain_from_mjcf(open("humanoid.xml").read()) 100 | print(chain) 101 | print(chain.get_joint_parameter_names()) 102 | th = {'left_knee': 0.0, 'right_knee': 0.0} 103 | ret = chain.forward_kinematics(th) 104 | ``` 105 | 106 | ## Jacobian calculation 107 | The Jacobian (in the kinematics context) is a matrix describing how the end effector changes with respect to joint value changes 108 | (where ![dx](https://latex.codecogs.com/png.latex?%5Cinline%20%5Cdot%7Bx%7D) is the twist, or stacked velocity and angular velocity): 109 | ![jacobian](https://latex.codecogs.com/png.latex?%5Cinline%20%5Cdot%7Bx%7D%3DJ%5Cdot%7Bq%7D) 110 | 111 | For `SerialChain` we provide a differentiable and parallelizable method for computing the Jacobian with respect to the base frame. 112 | ```python 113 | import math 114 | import torch 115 | import pytorch_kinematics as pk 116 | 117 | # can convert Chain to SerialChain by choosing end effector frame 118 | chain = pk.build_chain_from_sdf(open("simple_arm.sdf").read()) 119 | # print(chain) to see the available links for use as end effector 120 | # note that any link can be chosen; it doesn't have to be a link with no children 121 | chain = pk.SerialChain(chain, "arm_wrist_roll_frame") 122 | 123 | chain = pk.build_serial_chain_from_urdf(open("kuka_iiwa.urdf").read(), "lbr_iiwa_link_7") 124 | th = torch.tensor([0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0]) 125 | # (1,6,7) tensor, with 7 corresponding to the DOF of the robot 126 | J = chain.jacobian(th) 127 | 128 | # get Jacobian in parallel and use CUDA if available 129 | N = 1000 130 | d = "cuda" if torch.cuda.is_available() else "cpu" 131 | dtype = torch.float64 132 | 133 | chain = chain.to(dtype=dtype, device=d) 134 | # Jacobian calculation is differentiable 135 | th = torch.rand(N, 7, dtype=dtype, device=d, requires_grad=True) 136 | # (N,6,7) 137 | J = chain.jacobian(th) 138 | 139 | # can get Jacobian at a point offset from the end effector (location is specified in EE link frame) 140 | # by default location is at the origin of the EE frame 141 | loc = torch.rand(N, 3, dtype=dtype, device=d) 142 | J = chain.jacobian(th, locations=loc) 143 | ``` 144 | 145 | The Jacobian can be used to do inverse kinematics. See [IK survey](https://www.math.ucsd.edu/~sbuss/ResearchWeb/ikmethods/iksurvey.pdf) 146 | for a survey of ways to do so. Note that IK may be better performed through other means (but doing it through the Jacobian can give an end-to-end differentiable method). 147 | 148 | # Credits 149 | - `pytorch_kinematics/transforms` is extracted from [pytorch3d](https://github.com/facebookresearch/pytorch3d) with minor extensions. 150 | This was done instead of including `pytorch3d` as a dependency because it is hard to install and most of its code is unrelated. 151 | An important difference is that we use left hand multiplied transforms as is convention in robotics (T * pt) instead of their 152 | right hand multiplied transforms. 153 | - `pytorch_kinematics/urdf_parser_py`, and `pytorch_kinematics/mjcf_parser` is extracted from [kinpy](https://github.com/neka-nat/kinpy), as well as the FK logic. 154 | This repository ports the logic to pytorch, parallelizes it, and provides some extensions. 
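
The inverse-kinematics note in the Jacobian section above can be made concrete. Below is a minimal, hypothetical sketch (not part of this package) of a damped least-squares, position-only IK loop; it assumes the `kuka_iiwa.urdf` from the earlier examples and an arbitrary, made-up target position, and relies only on `forward_kinematics`, `jacobian`, and standard PyTorch operations.

```python
import torch
import pytorch_kinematics as pk

chain = pk.build_serial_chain_from_urdf(open("kuka_iiwa.urdf").read(), "lbr_iiwa_link_7")
chain = chain.to(dtype=torch.float64)

# hypothetical goal position in the base frame (assumed reachable)
target_pos = torch.tensor([[0.3, 0.2, 0.8]], dtype=torch.float64)

th = torch.full((1, 7), 0.1, dtype=torch.float64)   # initial joint values
damping, step = 1e-2, 0.5

for _ in range(100):
    m = chain.forward_kinematics(th).get_matrix()    # (1, 4, 4) end-effector pose
    pos = m[:, :3, 3]
    err = torch.zeros(1, 6, 1, dtype=torch.float64)
    err[:, :3, 0] = target_pos - pos                 # position error only; orientation ignored
    J = chain.jacobian(th)                           # (1, 6, 7)
    JT = J.transpose(1, 2)
    # damped least squares: dq = J^T (J J^T + lambda^2 I)^{-1} err
    dq = JT @ torch.inverse(J @ JT + damping ** 2 * torch.eye(6, dtype=torch.float64)) @ err
    th = th + step * dq.squeeze(-1)
```

Because every step above is differentiable, the same loop can be unrolled inside a larger computation graph if gradients with respect to the target or the joint initialization are needed.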
155 | -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/pytorch_kinematics/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_kinematics.sdf import * 2 | from pytorch_kinematics.urdf import * 3 | from pytorch_kinematics.mjcf import * 4 | from pytorch_kinematics.transforms import * 5 | from pytorch_kinematics.chain import * 6 | -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/pytorch_kinematics/chain.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from . import jacobian 3 | import pytorch_kinematics.transforms as tf 4 | 5 | 6 | def ensure_2d_tensor(th, dtype, device): 7 | if not torch.is_tensor(th): 8 | th = torch.tensor(th, dtype=dtype, device=device) 9 | if len(th.shape) == 0: 10 | N = 1 11 | th = th.view(1,1) 12 | elif len(th.shape) == 1: 13 | N = len(th) 14 | th = th.view(-1, 1) 15 | else: 16 | N = th.shape[0] 17 | return th, N 18 | 19 | 20 | class Chain(object): 21 | def __init__(self, root_frame, dtype=torch.float32, device="cpu"): 22 | self._root = root_frame 23 | self.dtype = dtype 24 | self.device = device 25 | 26 | def to(self, dtype=None, device=None): 27 | if dtype is not None: 28 | self.dtype = dtype 29 | if device is not None: 30 | self.device = device 31 | self._root = self._root.to(dtype=self.dtype, device=self.device) 32 | return self 33 | 34 | def __str__(self): 35 | return str(self._root) 36 | 37 | @staticmethod 38 | def _find_frame_recursive(name, frame): 39 | for child in frame.children: 40 | if child.name == name: 41 | return child 42 | ret = Chain._find_frame_recursive(name, child) 43 | if not ret is None: 44 | return ret 45 | return None 46 | 47 | def find_frame(self, name): 48 | if self._root.name == name: 49 | return self._root 50 | return self._find_frame_recursive(name, self._root) 51 | 52 | @staticmethod 53 | def _find_link_recursive(name, frame): 54 | for child in frame.children: 55 | if child.link.name == name: 56 | return child.link 57 | ret = Chain._find_link_recursive(name, child) 58 | if not ret is None: 59 | return ret 60 | return None 61 | 62 | def find_link(self, name): 63 | if self._root.link.name == name: 64 | return self._root.link 65 | return self._find_link_recursive(name, self._root) 66 | 67 | @staticmethod 68 | def _get_joint_parameter_names(frame, exclude_fixed=True): 69 | joint_names = [] 70 | if not (exclude_fixed and frame.joint.joint_type == "fixed"): 71 | joint_names.append(frame.joint.name) 72 | for child in frame.children: 73 | joint_names.extend(Chain._get_joint_parameter_names(child, exclude_fixed)) 74 | return joint_names 75 | 76 | def get_joint_parameter_names(self, exclude_fixed=True): 77 | names = self._get_joint_parameter_names(self._root, exclude_fixed) 78 | return sorted(set(names), key=names.index) 79 | 80 | def add_frame(self, frame, parent_name): 81 | frame = self.find_frame(parent_name) 82 | if not frame is None: 83 | frame.add_child(frame) 84 | 85 | @staticmethod 86 | def _forward_kinematics(root, th_dict, world=tf.Transform3d()): 87 | link_transforms = {} 88 | 89 | th, N = ensure_2d_tensor(th_dict.get(root.joint.name, 0.0), world.dtype, world.device) 90 | 91 | trans = world.compose(root.get_transform(th.view(N, 1))) 92 | link_transforms[root.link.name] = trans.compose(root.link.offset) 93 | for child in root.children: 94 | link_transforms.update(Chain._forward_kinematics(child, 
th_dict, trans)) 95 | return link_transforms 96 | 97 | def forward_kinematics(self, th, world=tf.Transform3d()): 98 | if not isinstance(th, dict): 99 | jn = self.get_joint_parameter_names() 100 | assert len(jn) == th.shape[1] 101 | th_dict = dict((j, th[:,i]) for i, j in enumerate(jn)) 102 | else: 103 | th_dict = th 104 | if world.dtype != self.dtype or world.device != self.device: 105 | world = world.to(dtype=self.dtype, device=self.device, copy=True) 106 | return self._forward_kinematics(self._root, th_dict, world) 107 | 108 | 109 | class SerialChain(Chain): 110 | def __init__(self, chain, end_frame_name, root_frame_name="", **kwargs): 111 | if root_frame_name == "": 112 | super(SerialChain, self).__init__(chain._root, **kwargs) 113 | else: 114 | super(SerialChain, self).__init__(chain.find_frame(root_frame_name), **kwargs) 115 | if self._root is None: 116 | raise ValueError("Invalid root frame name %s." % root_frame_name) 117 | self._serial_frames = self._generate_serial_chain_recurse(self._root, end_frame_name) 118 | if self._serial_frames is None: 119 | raise ValueError("Invalid end frame name %s." % end_frame_name) 120 | 121 | @staticmethod 122 | def _generate_serial_chain_recurse(root_frame, end_frame_name): 123 | for child in root_frame.children: 124 | if child.name == end_frame_name: 125 | return [child] 126 | else: 127 | frames = SerialChain._generate_serial_chain_recurse(child, end_frame_name) 128 | if not frames is None: 129 | return [child] + frames 130 | return None 131 | 132 | def get_joint_parameter_names(self, exclude_fixed=True): 133 | names = [] 134 | for f in self._serial_frames: 135 | if exclude_fixed and f.joint.joint_type == 'fixed': 136 | continue 137 | names.append(f.joint.name) 138 | return names 139 | 140 | def forward_kinematics(self, th, world=tf.Transform3d(), end_only=True): 141 | if world.dtype != self.dtype or world.device != self.device: 142 | world = world.to(dtype=self.dtype, device=self.device, copy=True) 143 | th, N = ensure_2d_tensor(th, self.dtype, self.device) 144 | 145 | cnt = 0 146 | link_transforms = {} 147 | trans = tf.Transform3d(matrix=world.get_matrix().repeat(N, 1, 1)) 148 | for f in self._serial_frames: 149 | trans = trans.compose(f.get_transform(th[:, cnt].view(N, 1))) 150 | link_transforms[f.link.name] = trans.compose(f.link.offset) 151 | if f.joint.joint_type != "fixed": 152 | cnt += 1 153 | return link_transforms[self._serial_frames[-1].link.name] if end_only else link_transforms 154 | 155 | def jacobian(self, th, locations=None): 156 | if locations is not None: 157 | locations = tf.Transform3d(pos=locations) 158 | return jacobian.calc_jacobian(self, th, tool=locations) 159 | -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/pytorch_kinematics/frame.py: -------------------------------------------------------------------------------- 1 | import pytorch_kinematics.transforms as tf 2 | import torch 3 | 4 | 5 | class Visual(object): 6 | TYPES = ['box', 'cylinder', 'sphere', 'capsule', 'mesh'] 7 | 8 | def __init__(self, offset=tf.Transform3d(), 9 | geom_type=None, geom_param=None): 10 | self.offset = offset 11 | self.geom_type = geom_type 12 | self.geom_param = geom_param 13 | 14 | def __repr__(self): 15 | return "Visual(offset={0}, geom_type='{1}', geom_param={2})".format(self.offset, 16 | self.geom_type, 17 | self.geom_param) 18 | 19 | 20 | class Link(object): 21 | def __init__(self, name=None, offset=tf.Transform3d(), 22 | visuals=()): 23 | self.name = name 24 | self.offset = 
offset 25 | self.visuals = visuals 26 | 27 | def to(self, *args, **kwargs): 28 | self.offset = self.offset.to(*args, **kwargs) 29 | return self 30 | 31 | def __repr__(self): 32 | return "Link(name='{0}', offset={1}, visuals={2})".format(self.name, 33 | self.offset, 34 | self.visuals) 35 | 36 | 37 | class Joint(object): 38 | TYPES = ['fixed', 'revolute', 'prismatic'] 39 | 40 | def __init__(self, name=None, offset=tf.Transform3d(), joint_type='fixed', axis=(0.0, 0.0, 1.0), jrange=None, 41 | dtype=torch.float32, device="cpu"): 42 | self.name = name 43 | self.range = jrange 44 | self.offset = offset 45 | if joint_type not in self.TYPES: 46 | raise RuntimeError("joint specified as {} type not, but we only support {}".format(joint_type, self.TYPES)) 47 | self.joint_type = joint_type 48 | if axis is None: 49 | self.axis = torch.tensor([0.0, 0.0, 1.0], dtype=dtype, device=device) 50 | else: 51 | if torch.is_tensor(axis): 52 | self.axis = axis.clone().detach().to(dtype=dtype, device=device) 53 | else: 54 | self.axis = torch.tensor(axis, dtype=dtype, device=device) 55 | # normalize axis to have norm 1 (needed for correct representation scaling with theta) 56 | self.axis = self.axis / self.axis.norm() 57 | 58 | def to(self, *args, **kwargs): 59 | self.axis = self.axis.to(*args, **kwargs) 60 | self.offset = self.offset.to(*args, **kwargs) 61 | return self 62 | 63 | def __repr__(self): 64 | return "Joint(name='{0}', offset={1}, joint_type='{2}', axis={3}, range={4})".format(self.name, 65 | self.offset, 66 | self.joint_type, 67 | self.axis, 68 | self.range) 69 | 70 | 71 | class Frame(object): 72 | def __init__(self, name=None, link=Link(), 73 | joint=Joint(), children=()): 74 | self.name = 'None' if name is None else name 75 | self.link = link 76 | self.joint = joint 77 | self.children = children 78 | 79 | def __str__(self, level=0): 80 | ret = " \t" * level + self.name + "\n" 81 | for child in self.children: 82 | ret += child.__str__(level + 1) 83 | return ret 84 | 85 | def to(self, *args, **kwargs): 86 | self.joint = self.joint.to(*args, **kwargs) 87 | self.link = self.link.to(*args, **kwargs) 88 | self.children = [c.to(*args, **kwargs) for c in self.children] 89 | return self 90 | 91 | def add_child(self, child): 92 | self.children.append(child) 93 | 94 | def is_end(self): 95 | return (len(self.children) == 0) 96 | 97 | def get_transform(self, theta): 98 | dtype = self.joint.axis.dtype 99 | d = self.joint.axis.device 100 | if self.joint.joint_type == 'revolute': 101 | t = tf.Transform3d(rot=tf.axis_angle_to_quaternion(theta * self.joint.axis), dtype=dtype, device=d) 102 | elif self.joint.joint_type == 'prismatic': 103 | t = tf.Transform3d(pos=theta * self.joint.axis, dtype=dtype, device=d) 104 | elif self.joint.joint_type == 'fixed': 105 | t = tf.Transform3d(default_batch_size=theta.shape[0], dtype=dtype, device=d) 106 | else: 107 | raise ValueError("Unsupported joint type %s." % self.joint.joint_type) 108 | return self.joint.offset.compose(t) 109 | -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/pytorch_kinematics/jacobian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch_kinematics import transforms 3 | 4 | 5 | def calc_jacobian(serial_chain, th, tool=None): 6 | """ 7 | Return robot Jacobian J in base frame (N,6,DOF) where dot{x} = J dot{q} 8 | The first 3 rows relate the translational velocities and the 9 | last 3 rows relate the angular velocities. 
10 | 11 | tool is the transformation wrt the end effector; default is identity. If specified, will have to 12 | specify for each of the N inputs 13 | """ 14 | if not torch.is_tensor(th): 15 | th = torch.tensor(th, dtype=serial_chain.dtype, device=serial_chain.device) 16 | if len(th.shape) <= 1: 17 | N = 1 18 | th = th.view(1, -1) 19 | else: 20 | N = th.shape[0] 21 | ndof = th.shape[1] 22 | 23 | j_fl = torch.zeros((N, 6, ndof), dtype=serial_chain.dtype, device=serial_chain.device) 24 | 25 | if tool is None: 26 | cur_transform = transforms.Transform3d(device=serial_chain.device, 27 | dtype=serial_chain.dtype).get_matrix().repeat(N, 1, 1) 28 | else: 29 | if tool.dtype != serial_chain.dtype or tool.device != serial_chain.device: 30 | tool = tool.to(device=serial_chain.device, copy=True, dtype=serial_chain.dtype) 31 | cur_transform = tool.get_matrix() 32 | 33 | cnt = 0 34 | for f in reversed(serial_chain._serial_frames): 35 | if f.joint.joint_type == "revolute": 36 | cnt += 1 37 | d = torch.stack([-cur_transform[:, 0, 0] * cur_transform[:, 1, 3] 38 | + cur_transform[:, 1, 0] * cur_transform[:, 0, 3], 39 | -cur_transform[:, 0, 1] * cur_transform[:, 1, 3] 40 | + cur_transform[:, 1, 1] * cur_transform[:, 0, 3], 41 | -cur_transform[:, 0, 2] * cur_transform[:, 1, 3] 42 | + cur_transform[:, 1, 2] * cur_transform[:, 0, 3]]).transpose(0, 1) 43 | delta = cur_transform[:, 2, 0:3] 44 | j_fl[:, :, -cnt] = torch.cat((d, delta), dim=-1) 45 | elif f.joint.joint_type == "prismatic": 46 | cnt += 1 47 | j_fl[:, :3, -cnt] = f.joint.axis.repeat(N, 1) @ cur_transform[:, :3, :3] 48 | cur_frame_transform = f.get_transform(th[:, -cnt].view(N, 1)).get_matrix() 49 | cur_transform = cur_frame_transform @ cur_transform 50 | 51 | # currently j_fl is Jacobian in flange (end-effector) frame, convert to base/world frame 52 | pose = serial_chain.forward_kinematics(th).get_matrix() 53 | rotation = pose[:, :3, :3] 54 | j_tr = torch.zeros((N, 6, 6), dtype=serial_chain.dtype, device=serial_chain.device) 55 | j_tr[:, :3, :3] = rotation 56 | j_tr[:, 3:, 3:] = rotation 57 | j_w = j_tr @ j_fl 58 | return j_w 59 | -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/pytorch_kinematics/mjcf.py: -------------------------------------------------------------------------------- 1 | from . import frame 2 | from . import chain 3 | from . import mjcf_parser 4 | import pytorch_kinematics.transforms as tf 5 | 6 | JOINT_TYPE_MAP = {'hinge': 'revolute', None: 'revolute'} 7 | 8 | 9 | def geoms_to_visuals(geom, base=tf.Transform3d()): 10 | visuals = [] 11 | for g in geom: 12 | if g.type == 'capsule': 13 | param = g.size 14 | elif g.type == 'sphere': 15 | param = g.size[0] 16 | elif g.type == 'box': 17 | param = g.size 18 | elif g.type == 'mesh': 19 | param = (g.mesh.name, g.mesh.scale) 20 | else: 21 | # print(g.name) 22 | param = (g.mesh.name, g.mesh.scale) 23 | # raise ValueError('Invalid geometry type %s.' 
% g.type) 24 | visuals.append(frame.Visual(offset=base.compose(tf.Transform3d(rot=g.quat, pos=g.pos)), 25 | geom_type=g.type, 26 | geom_param=param)) 27 | return visuals 28 | 29 | 30 | def body_to_link(body, base=tf.Transform3d()): 31 | return frame.Link(body.name, 32 | offset=base.compose(tf.Transform3d(rot=body.quat, pos=body.pos))) 33 | 34 | 35 | def joint_to_joint(joint, base=tf.Transform3d()): 36 | return frame.Joint(joint.name, 37 | offset=base.compose(tf.Transform3d(pos=joint.pos)), 38 | joint_type=JOINT_TYPE_MAP[joint.type], 39 | jrange=joint.range, 40 | axis=joint.axis) 41 | 42 | 43 | def add_composite_joint(root_frame, joints, base=tf.Transform3d()): 44 | if len(joints) > 0: 45 | root_frame.children = root_frame.children + (frame.Frame(link=frame.Link(name=root_frame.link.name + '_child'), 46 | joint=joint_to_joint(joints[0], base)),) 47 | ret, offset = add_composite_joint(root_frame.children[-1], joints[1:]) 48 | return ret, root_frame.joint.offset.compose(offset) 49 | else: 50 | return root_frame, root_frame.joint.offset 51 | 52 | 53 | def _build_chain_recurse(root_frame, root_body): 54 | base = root_frame.link.offset 55 | cur_frame, cur_base = add_composite_joint(root_frame, root_body.joint, base) 56 | jbase = cur_base.inverse().compose(base) 57 | if len(root_body.joint) > 0: 58 | cur_frame.link.visuals = geoms_to_visuals(root_body.geom, jbase) 59 | else: 60 | cur_frame.link.visuals = geoms_to_visuals(root_body.geom) 61 | for b in root_body.body: 62 | cur_frame.children = cur_frame.children + (frame.Frame(),) 63 | next_frame = cur_frame.children[-1] 64 | next_frame.name = b.name + "_frame" 65 | next_frame.link = body_to_link(b, jbase) 66 | _build_chain_recurse(next_frame, b) 67 | 68 | 69 | def build_chain_from_mjcf(data): 70 | """ 71 | Build a Chain object from MJCF data. 72 | 73 | Parameters 74 | ---------- 75 | data : str 76 | MJCF string data. 77 | 78 | Returns 79 | ------- 80 | chain.Chain 81 | Chain object created from MJCF. 82 | """ 83 | model = mjcf_parser.from_xml_string(data) 84 | root_body = model.worldbody.body[0] 85 | root_frame = frame.Frame(root_body.name + "_frame", 86 | link=body_to_link(root_body), 87 | joint=frame.Joint()) 88 | _build_chain_recurse(root_frame, root_body) 89 | return chain.Chain(root_frame) 90 | 91 | 92 | def build_serial_chain_from_mjcf(data, end_link_name, root_link_name=""): 93 | """ 94 | Build a SerialChain object from MJCF data. 95 | 96 | Parameters 97 | ---------- 98 | data : str 99 | MJCF string data. 100 | end_link_name : str 101 | The name of the link that is the end effector. 102 | root_link_name : str, optional 103 | The name of the root link. 104 | 105 | Returns 106 | ------- 107 | chain.SerialChain 108 | SerialChain object created from MJCF. 109 | """ 110 | mjcf_chain = build_chain_from_mjcf(data) 111 | return chain.SerialChain(mjcf_chain, end_link_name + "_frame", 112 | "" if root_link_name == "" else root_link_name + "_frame") 113 | -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/pytorch_kinematics/mjcf_parser/__init__.py: -------------------------------------------------------------------------------- 1 | from .parser import * -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/pytorch_kinematics/mjcf_parser/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The dm_control Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Base class for all MJCF elements in the object model.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import abc 23 | 24 | import six 25 | 26 | 27 | @six.add_metaclass(abc.ABCMeta) 28 | class Element(object): 29 | """Abstract base class for an MJCF element. 30 | 31 | This class is provided so that `isinstance(foo, Element)` is `True` for all 32 | Element-like objects. We do not implement the actual element here because 33 | the actual object returned from traversing the object hierarchy is a 34 | weakproxy-like proxy to an actual element. This is because we do not allow 35 | orphaned non-root elements, so when a particular element is removed from the 36 | tree, all references held automatically become invalid. 37 | """ 38 | __slots__ = [] 39 | -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/pytorch_kinematics/mjcf_parser/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Magic constants used within `dm_control.mjcf`.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | PREFIX_SEPARATOR = '/' 23 | PREFIX_SEPARATOR_ESCAPE = '\\' 24 | 25 | # Used to disambiguate namespaces between attachment frames. 
26 | NAMESPACE_SEPARATOR = '@' 27 | 28 | # Magic attribute names 29 | BASEPATH = 'basepath' 30 | CHILDCLASS = 'childclass' 31 | CLASS = 'class' 32 | DEFAULT = 'default' 33 | DCLASS = 'dclass' 34 | 35 | # Magic tags 36 | ACTUATOR = 'actuator' 37 | BODY = 'body' 38 | DEFAULT = 'default' 39 | MESH = 'mesh' 40 | SITE = 'site' 41 | TENDON = 'tendon' 42 | WORLDBODY = 'worldbody' 43 | 44 | MJDATA_TRIGGERS_DIRTY = [ 45 | 'qpos', 'qvel', 'act', 'ctrl', 'qfrc_applied', 'xfrc_applied'] 46 | MJMODEL_DOESNT_TRIGGER_DIRTY = [ 47 | 'rgba', 'matid', 'emission', 'specular', 'shininess', 'reflectance'] 48 | 49 | # When writing into `model.{body,geom,site}_{pos,quat}` we must ensure that the 50 | # corresponding rows in `model.{body,geom,site}_sameframe` are set to zero, 51 | # otherwise MuJoCo will use the body or inertial frame instead of our modified 52 | # pos/quat values. We must do the same for `body_{ipos,iquat}` and 53 | # `body_simple`. 54 | MJMODEL_DISABLE_ON_WRITE = { 55 | # Field name in MjModel: (attribute names of Binding instance to be zeroed) 56 | 'body_pos': ('sameframe',), 57 | 'body_quat': ('sameframe',), 58 | 'geom_pos': ('sameframe',), 59 | 'geom_quat': ('sameframe',), 60 | 'site_pos': ('sameframe',), 61 | 'site_quat': ('sameframe',), 62 | 'body_ipos': ('simple', 'sameframe'), 63 | 'body_iquat': ('simple', 'sameframe'), 64 | } 65 | 66 | # This is the actual upper limit on VFS filename length, despite what it says 67 | # in the header file (100) or the error message (99). 68 | MAX_VFS_FILENAME_LENGTH = 98 69 | 70 | # The prefix used in the schema to denote reference_namespace that are defined 71 | # via another attribute. 72 | INDIRECT_REFERENCE_NAMESPACE_PREFIX = 'attrib:' 73 | -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/pytorch_kinematics/mjcf_parser/copier.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Helper object for keeping track of new elements created when copying MJCF.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from . 
import constants 23 | 24 | 25 | class Copier(object): 26 | """Helper for keeping track of new elements created when copying MJCF.""" 27 | 28 | def __init__(self, source): 29 | if source._attachments: # pylint: disable=protected-access 30 | raise NotImplementedError('Cannot copy from elements with attachments') 31 | self._source = source 32 | 33 | def copy_into(self, destination, override_attributes=False): 34 | """Copies this copier's element into a destination MJCF element.""" 35 | newly_created_elements = {} 36 | destination._check_valid_attachment(self._source) # pylint: disable=protected-access 37 | if override_attributes: 38 | destination.set_attributes(**self._source.get_attributes()) 39 | else: 40 | destination._sync_attributes(self._source, copying=True) # pylint: disable=protected-access 41 | for source_child in self._source.all_children(): 42 | dest_child = None 43 | # First, if source_child has an identifier, we look for an existing child 44 | # element of self with the same identifier to override. 45 | if source_child.spec.identifier and override_attributes: 46 | identifier_attr = source_child.spec.identifier 47 | if identifier_attr == constants.CLASS: 48 | identifier_attr = constants.DCLASS 49 | identifier = getattr(source_child, identifier_attr) 50 | if identifier: 51 | dest_child = destination.find(source_child.spec.namespace, identifier) 52 | if dest_child is not None and dest_child.parent is not destination: 53 | raise ValueError( 54 | '<{}> with identifier {!r} is already a child of another element' 55 | .format(source_child.spec.namespace, identifier)) 56 | # Next, we cover the case where either the child is not a repeated element 57 | # or if source_child has an identifier attribute but it isn't set. 58 | if not source_child.spec.repeated and dest_child is None: 59 | dest_child = destination.get_children(source_child.tag) 60 | 61 | # Add a new element if dest_child doesn't exist, either because it is 62 | # supposed to be a repeated child, or because it's an uncreated on-demand. 63 | if dest_child is None: 64 | dest_child = destination.add( 65 | source_child.tag, **source_child.get_attributes()) 66 | newly_created_elements[source_child] = dest_child 67 | override_child_attributes = True 68 | else: 69 | override_child_attributes = override_attributes 70 | 71 | # Finally, copy attributes into dest_child. 72 | child_copier = Copier(source_child) 73 | newly_created_elements.update( 74 | child_copier.copy_into(dest_child, override_child_attributes)) 75 | return newly_created_elements 76 | -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/pytorch_kinematics/mjcf_parser/io.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-2018 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | 16 | """IO functions.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | 23 | def GetResource(name, mode='rb'): 24 | with open(name, mode=mode) as f: 25 | return f.read() 26 | 27 | 28 | def GetResourceFilename(name, mode='rb'): 29 | del mode # Unused. 30 | return name 31 | 32 | 33 | GetResourceAsFile = open # pylint: disable=invalid-name 34 | -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/pytorch_kinematics/mjcf_parser/namescope.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """An object to manage the scoping of identifiers in MJCF models.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | 24 | from . import constants 25 | import six 26 | 27 | 28 | class NameScope(object): 29 | """A name scoping context for an MJCF model. 30 | 31 | This object maintains the uniqueness of identifiers within each MJCF 32 | namespace. Examples of MJCF namespaces include 'body', 'joint', and 'geom'. 33 | Each namescope also carries a name, and can have a parent namescope. 34 | When MJCF models are merged, all identifiers gain a hierarchical prefix 35 | separated by '/', which is the concatenation of all scope names up to 36 | the root namescope. 37 | """ 38 | 39 | def __init__(self, name, mjcf_model, model_dir='', assets=None): 40 | """Initializes a scope with the given name. 41 | 42 | Args: 43 | name: The scope's name 44 | mjcf_model: The RootElement of the MJCF model associated with this scope. 45 | model_dir: (optional) Path to the directory containing the model XML file. 46 | This is used to prefix the paths of all asset files. 47 | assets: (optional) A dictionary of pre-loaded assets, of the form 48 | `{filename: bytestring}`. If present, PyMJCF will search for assets in 49 | this dictionary before attempting to load them from the filesystem. 
50 | """ 51 | self._parent = None 52 | self._name = name 53 | self._mjcf_model = mjcf_model 54 | self._namespaces = collections.defaultdict(dict) 55 | self._model_dir = model_dir 56 | self._files = set() 57 | self._assets = assets or {} 58 | self._revision = 0 59 | 60 | @property 61 | def revision(self): 62 | return self._revision 63 | 64 | def increment_revision(self): 65 | self._revision += 1 66 | for namescope in six.itervalues(self._namespaces['namescope']): 67 | namescope.increment_revision() 68 | 69 | @property 70 | def name(self): 71 | """This scope's name.""" 72 | return self._name 73 | 74 | @property 75 | def files(self): 76 | """A set containing the `File` attributes registered in this scope.""" 77 | return self._files 78 | 79 | @property 80 | def assets(self): 81 | """A dictionary containing pre-loaded assets.""" 82 | return self._assets 83 | 84 | @property 85 | def model_dir(self): 86 | """Path to the directory containing the model XML file.""" 87 | return self._model_dir 88 | 89 | @name.setter 90 | def name(self, new_name): 91 | if self._parent: 92 | self._parent.add('namescope', new_name, self) 93 | self._parent.remove('namescope', self._name) 94 | self._name = new_name 95 | self.increment_revision() 96 | 97 | @property 98 | def mjcf_model(self): 99 | return self._mjcf_model 100 | 101 | @property 102 | def parent(self): 103 | """This parent `NameScope`, or `None` if this is a root scope.""" 104 | return self._parent 105 | 106 | @parent.setter 107 | def parent(self, new_parent): 108 | if self._parent: 109 | self._parent.remove('namescope', self._name) 110 | self._parent = new_parent 111 | if self._parent: 112 | self._parent.add('namescope', self._name, self) 113 | self.increment_revision() 114 | 115 | @property 116 | def root(self): 117 | if self._parent is None: 118 | return self 119 | else: 120 | return self._parent.root 121 | 122 | def full_prefix(self, prefix_root=None, as_list=False): 123 | """The prefix for identifiers belonging to this scope. 124 | 125 | Args: 126 | prefix_root: (optional) A `NameScope` object to be treated as root 127 | for the purpose of calculating the prefix. If `None` then no prefix 128 | is produced. 129 | as_list: (optional) A boolean, if `True` return the list of prefix 130 | components. If `False`, return the full prefix string separated by 131 | `mjcf.constants.PREFIX_SEPARATOR`. 132 | 133 | Returns: 134 | The prefix string. 135 | """ 136 | prefix_root = prefix_root or self 137 | if prefix_root != self and self._parent: 138 | prefix_list = self._parent.full_prefix(prefix_root, as_list=True) 139 | prefix_list.append(self._name) 140 | else: 141 | prefix_list = [] 142 | if as_list: 143 | return prefix_list 144 | else: 145 | if prefix_list: 146 | prefix_list.append('') 147 | return constants.PREFIX_SEPARATOR.join(prefix_list) 148 | 149 | def _assign(self, namespace, identifier, obj): 150 | """Checks a proposed identifier's validity before assigning to an object.""" 151 | namespace_dict = self._namespaces[namespace] 152 | if not isinstance(identifier, str): 153 | raise ValueError( 154 | 'Identifier must be a string: got {}'.format(type(identifier))) 155 | elif constants.PREFIX_SEPARATOR in identifier: 156 | raise ValueError( 157 | 'Identifier cannot contain {!r}: got {}' 158 | .format(constants.PREFIX_SEPARATOR, identifier)) 159 | else: 160 | namespace_dict[identifier] = obj 161 | 162 | def add(self, namespace, identifier, obj): 163 | """Add an identifier to this name scope. 
164 | 165 | Args: 166 | namespace: A string specifying the namespace to which the 167 | identifier belongs. 168 | identifier: The identifier string. 169 | obj: The object referred to by the identifier. 170 | 171 | Raises: 172 | ValueError: If `identifier` not valid. 173 | """ 174 | namespace_dict = self._namespaces[namespace] 175 | if identifier in namespace_dict: 176 | raise ValueError('Duplicated identifier {!r} in namespace <{}>' 177 | .format(identifier, namespace)) 178 | else: 179 | self._assign(namespace, identifier, obj) 180 | self.increment_revision() 181 | 182 | def replace(self, namespace, identifier, obj): 183 | """Reassociates an identifier with a different object. 184 | 185 | Args: 186 | namespace: A string specifying the namespace to which the 187 | identifier belongs. 188 | identifier: The identifier string. 189 | obj: The object referred to by the identifier. 190 | 191 | Raises: 192 | ValueError: If `identifier` not valid. 193 | """ 194 | self._assign(namespace, identifier, obj) 195 | self.increment_revision() 196 | 197 | def remove(self, namespace, identifier): 198 | """Removes an identifier from this name scope. 199 | 200 | Args: 201 | namespace: A string specifying the namespace to which the 202 | identifier belongs. 203 | identifier: The identifier string. 204 | 205 | Raises: 206 | KeyError: If `identifier` does not exist in this scope. 207 | """ 208 | del self._namespaces[namespace][identifier] 209 | self.increment_revision() 210 | 211 | def rename(self, namespace, old_identifier, new_identifier): 212 | obj = self.get(namespace, old_identifier) 213 | self.add(namespace, new_identifier, obj) 214 | self.remove(namespace, old_identifier) 215 | 216 | def get(self, namespace, identifier): 217 | return self._namespaces[namespace][identifier] 218 | 219 | def has_identifier(self, namespace, identifier): 220 | return identifier in self._namespaces[namespace] 221 | -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/pytorch_kinematics/mjcf_parser/util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-2018 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | 16 | """Various helper functions and classes.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import sys 23 | import six 24 | 25 | DEFAULT_ENCODING = sys.getdefaultencoding() 26 | 27 | 28 | def to_binary_string(s): 29 | """Convert text string to binary.""" 30 | if isinstance(s, six.binary_type): 31 | return s 32 | return s.encode(DEFAULT_ENCODING) 33 | 34 | 35 | def to_native_string(s): 36 | """Convert a text or binary string to the native string format.""" 37 | if six.PY3 and isinstance(s, six.binary_type): 38 | return s.decode(DEFAULT_ENCODING) 39 | elif six.PY2 and isinstance(s, six.text_type): 40 | return s.encode(DEFAULT_ENCODING) 41 | else: 42 | return s 43 | -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/pytorch_kinematics/sdf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from .urdf_parser_py.sdf import SDF, Mesh, Cylinder, Box, Sphere 4 | from . import frame 5 | from . import chain 6 | import pytorch_kinematics.transforms as tf 7 | 8 | JOINT_TYPE_MAP = {'revolute': 'revolute', 9 | 'prismatic': 'prismatic', 10 | 'fixed': 'fixed'} 11 | 12 | 13 | def _convert_transform(pose): 14 | if pose is None: 15 | return tf.Transform3d() 16 | else: 17 | return tf.Transform3d(rot=tf.euler_angles_to_matrix(torch.tensor(pose[3:]), "ZYX"), pos=pose[:3]) 18 | 19 | 20 | def _convert_visuals(visuals): 21 | vlist = [] 22 | for v in visuals: 23 | v_tf = _convert_transform(v.pose) 24 | if isinstance(v.geometry, Mesh): 25 | g_type = "mesh" 26 | g_param = v.geometry.filename 27 | elif isinstance(v.geometry, Cylinder): 28 | g_type = "cylinder" 29 | v_tf = v_tf.compose( 30 | tf.Transform3d(rot=tf.euler_angles_to_matrix(torch.tensor([0.5 * math.pi, 0, 0]), "ZYX"))) 31 | g_param = (v.geometry.radius, v.geometry.length) 32 | elif isinstance(v.geometry, Box): 33 | g_type = "box" 34 | g_param = v.geometry.size 35 | elif isinstance(v.geometry, Sphere): 36 | g_type = "sphere" 37 | g_param = v.geometry.radius 38 | else: 39 | g_type = None 40 | g_param = None 41 | vlist.append(frame.Visual(v_tf, g_type, g_param)) 42 | return vlist 43 | 44 | 45 | def _build_chain_recurse(root_frame, lmap, joints): 46 | children = [] 47 | for j in joints: 48 | if j.parent == root_frame.link.name: 49 | child_frame = frame.Frame(j.child + "_frame") 50 | link_p = lmap[j.parent] 51 | link_c = lmap[j.child] 52 | t_p = _convert_transform(link_p.pose) 53 | t_c = _convert_transform(link_c.pose) 54 | child_frame.joint = frame.Joint(j.name, offset=t_p.inverse().compose(t_c), 55 | joint_type=JOINT_TYPE_MAP[j.type], axis=j.axis.xyz) 56 | child_frame.link = frame.Link(link_c.name, offset=tf.Transform3d(), 57 | visuals=_convert_visuals(link_c.visuals)) 58 | child_frame.children = _build_chain_recurse(child_frame, lmap, joints) 59 | children.append(child_frame) 60 | return children 61 | 62 | 63 | def build_chain_from_sdf(data): 64 | """ 65 | Build a Chain object from SDF data. 66 | 67 | Parameters 68 | ---------- 69 | data : str 70 | SDF string data. 71 | 72 | Returns 73 | ------- 74 | chain.Chain 75 | Chain object created from SDF. 
76 | """ 77 | sdf = SDF.from_xml_string(data) 78 | robot = sdf.model 79 | lmap = robot.link_map 80 | joints = robot.joints 81 | n_joints = len(joints) 82 | has_root = [True for _ in range(len(joints))] 83 | for i in range(n_joints): 84 | for j in range(i + 1, n_joints): 85 | if joints[i].parent == joints[j].child: 86 | has_root[i] = False 87 | elif joints[j].parent == joints[i].child: 88 | has_root[j] = False 89 | for i in range(n_joints): 90 | if has_root[i]: 91 | root_link = lmap[joints[i].parent] 92 | break 93 | root_frame = frame.Frame(root_link.name + "_frame") 94 | root_frame.joint = frame.Joint(offset=_convert_transform(root_link.pose)) 95 | root_frame.link = frame.Link(root_link.name, tf.Transform3d(), 96 | _convert_visuals(root_link.visuals)) 97 | root_frame.children = _build_chain_recurse(root_frame, lmap, joints) 98 | return chain.Chain(root_frame) 99 | -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/pytorch_kinematics/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from .rotation_conversions import ( 4 | axis_angle_to_matrix, 5 | axis_angle_to_quaternion, 6 | euler_angles_to_matrix, 7 | matrix_to_euler_angles, 8 | matrix_to_quaternion, 9 | matrix_to_rotation_6d, 10 | quaternion_apply, 11 | quaternion_invert, 12 | quaternion_multiply, 13 | quaternion_raw_multiply, 14 | quaternion_to_matrix, 15 | random_quaternions, 16 | random_rotation, 17 | random_rotations, 18 | rotation_6d_to_matrix, 19 | standardize_quaternion, 20 | xyzw_to_wxyz, 21 | wxyz_to_xyzw 22 | ) 23 | from .so3 import ( 24 | so3_exp_map, 25 | so3_log_map, 26 | so3_relative_angle, 27 | so3_rotation_angle, 28 | ) 29 | from .transform3d import Rotate, RotateAxisAngle, Scale, Transform3d, Translate 30 | 31 | 32 | __all__ = [k for k in globals().keys() if not k.startswith("_")] 33 | -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/pytorch_kinematics/transforms/math.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from typing import Tuple, Union 9 | 10 | import torch 11 | 12 | 13 | def acos_linear_extrapolation( 14 | x: torch.Tensor, 15 | bound: Union[float, Tuple[float, float]] = 1.0 - 1e-4, 16 | ) -> torch.Tensor: 17 | """ 18 | Implements `arccos(x)` which is linearly extrapolated outside `x`'s original 19 | domain of `(-1, 1)`. This allows for stable backpropagation in case `x` 20 | is not guaranteed to be strictly within `(-1, 1)`. 21 | More specifically: 22 | ``` 23 | if -bound <= x <= bound: 24 | acos_linear_extrapolation(x) = acos(x) 25 | elif x <= -bound: # 1st order Taylor approximation 26 | acos_linear_extrapolation(x) = acos(-bound) + dacos/dx(-bound) * (x - (-bound)) 27 | else: # x >= bound 28 | acos_linear_extrapolation(x) = acos(bound) + dacos/dx(bound) * (x - bound) 29 | ``` 30 | Note that `bound` can be made more specific with setting 31 | `bound=[lower_bound, upper_bound]` as detailed below. 32 | Args: 33 | x: Input `Tensor`. 34 | bound: A float constant or a float 2-tuple defining the region for the 35 | linear extrapolation of `acos`. 
36 | If `bound` is a float scalar, linearly interpolates acos for 37 | `x <= -bound` or `bound <= x`. 38 | If `bound` is a 2-tuple, the first/second element of `bound` 39 | describes the lower/upper bound that defines the lower/upper 40 | extrapolation region, i.e. the region where 41 | `x <= bound[0]`/`bound[1] <= x`. 42 | Note that all elements of `bound` have to be within (-1, 1). 43 | Returns: 44 | acos_linear_extrapolation: `Tensor` containing the extrapolated `arccos(x)`. 45 | """ 46 | 47 | if isinstance(bound, float): 48 | upper_bound = bound 49 | lower_bound = -bound 50 | else: 51 | lower_bound, upper_bound = bound 52 | 53 | if lower_bound > upper_bound: 54 | raise ValueError("lower bound has to be smaller or equal to upper bound.") 55 | 56 | if lower_bound <= -1.0 or upper_bound >= 1.0: 57 | raise ValueError("Both lower bound and upper bound have to be within (-1, 1).") 58 | 59 | # init an empty tensor and define the domain sets 60 | acos_extrap = torch.empty_like(x) 61 | x_upper = x >= upper_bound 62 | x_lower = x <= lower_bound 63 | x_mid = (~x_upper) & (~x_lower) 64 | 65 | # acos calculation for upper_bound < x < lower_bound 66 | acos_extrap[x_mid] = torch.acos(x[x_mid]) 67 | # the linear extrapolation for x >= upper_bound 68 | acos_extrap[x_upper] = _acos_linear_approximation(x[x_upper], upper_bound) 69 | # the linear extrapolation for x <= lower_bound 70 | acos_extrap[x_lower] = _acos_linear_approximation(x[x_lower], lower_bound) 71 | 72 | return acos_extrap 73 | 74 | 75 | def _acos_linear_approximation(x: torch.Tensor, x0: float) -> torch.Tensor: 76 | """ 77 | Calculates the 1st order Taylor expansion of `arccos(x)` around `x0`. 78 | """ 79 | return (x - x0) * _dacos_dx(x0) + math.acos(x0) 80 | 81 | 82 | def _dacos_dx(x: float) -> float: 83 | """ 84 | Calculates the derivative of `arccos(x)` w.r.t. `x`. 85 | """ 86 | return (-1.0) / math.sqrt(1.0 - x * x) 87 | -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/pytorch_kinematics/urdf.py: -------------------------------------------------------------------------------- 1 | from .urdf_parser_py.urdf import URDF, Mesh, Cylinder, Box, Sphere 2 | from . import frame 3 | from . 
import chain 4 | import torch 5 | import pytorch_kinematics.transforms as tf 6 | # has better RPY to quaternion transformation 7 | import transformations as tf2 8 | 9 | JOINT_TYPE_MAP = {'revolute': 'revolute', 10 | 'continuous': 'revolute', 11 | 'prismatic': 'prismatic', 12 | 'fixed': 'fixed'} 13 | 14 | 15 | def _convert_transform(origin): 16 | if origin is None: 17 | return tf.Transform3d() 18 | else: 19 | return tf.Transform3d(rot=torch.tensor(tf2.quaternion_from_euler(*origin.rpy, "sxyz"), dtype=torch.float32), 20 | pos=origin.xyz) 21 | 22 | 23 | def _convert_visual(visual): 24 | if visual is None or visual.geometry is None: 25 | return frame.Visual() 26 | else: 27 | v_tf = _convert_transform(visual.origin) 28 | if isinstance(visual.geometry, Mesh): 29 | g_type = "mesh" 30 | g_param = visual.geometry.filename 31 | elif isinstance(visual.geometry, Cylinder): 32 | g_type = "cylinder" 33 | g_param = (visual.geometry.radius, visual.geometry.length) 34 | elif isinstance(visual.geometry, Box): 35 | g_type = "box" 36 | g_param = visual.geometry.size 37 | elif isinstance(visual.geometry, Sphere): 38 | g_type = "sphere" 39 | g_param = visual.geometry.radius 40 | else: 41 | g_type = None 42 | g_param = None 43 | return frame.Visual(v_tf, g_type, g_param) 44 | 45 | 46 | def _build_chain_recurse(root_frame, lmap, joints): 47 | children = [] 48 | for j in joints: 49 | if j.parent == root_frame.link.name: 50 | child_frame = frame.Frame(j.child + "_frame") 51 | child_frame.joint = frame.Joint(j.name, offset=_convert_transform(j.origin), 52 | joint_type=JOINT_TYPE_MAP[j.type], axis=j.axis) 53 | link = lmap[j.child] 54 | child_frame.link = frame.Link(link.name, offset=_convert_transform(link.origin), 55 | visuals=[_convert_visual(link.visual)]) 56 | child_frame.children = _build_chain_recurse(child_frame, lmap, joints) 57 | children.append(child_frame) 58 | return children 59 | 60 | 61 | def build_chain_from_urdf(data): 62 | """ 63 | Build a Chain object from URDF data. 64 | 65 | Parameters 66 | ---------- 67 | data : str 68 | URDF string data. 69 | 70 | Returns 71 | ------- 72 | chain.Chain 73 | Chain object created from URDF. 74 | 75 | Example 76 | ------- 77 | >>> import pytorch_kinematics as pk 78 | >>> data = ''' 79 | ... 80 | ... 81 | ... 82 | ... 83 | ... 84 | ... 85 | ... ''' 86 | >>> chain = pk.build_chain_from_urdf(data) 87 | >>> print(chain) 88 | link1_frame 89 | link2_frame 90 | 91 | """ 92 | robot = URDF.from_xml_string(data) 93 | lmap = robot.link_map 94 | joints = robot.joints 95 | n_joints = len(joints) 96 | has_root = [True for _ in range(len(joints))] 97 | for i in range(n_joints): 98 | for j in range(i + 1, n_joints): 99 | if joints[i].parent == joints[j].child: 100 | has_root[i] = False 101 | elif joints[j].parent == joints[i].child: 102 | has_root[j] = False 103 | for i in range(n_joints): 104 | if has_root[i]: 105 | root_link = lmap[joints[i].parent] 106 | break 107 | root_frame = frame.Frame(root_link.name + "_frame") 108 | root_frame.joint = frame.Joint() 109 | root_frame.link = frame.Link(root_link.name, _convert_transform(root_link.origin), 110 | [_convert_visual(root_link.visual)]) 111 | root_frame.children = _build_chain_recurse(root_frame, lmap, joints) 112 | return chain.Chain(root_frame) 113 | 114 | 115 | def build_serial_chain_from_urdf(data, end_link_name, root_link_name=""): 116 | """ 117 | Build a SerialChain object from urdf data. 118 | 119 | Parameters 120 | ---------- 121 | data : str 122 | URDF string data. 
123 | end_link_name : str 124 | The name of the link that is the end effector. 125 | root_link_name : str, optional 126 | The name of the root link. 127 | 128 | Returns 129 | ------- 130 | chain.SerialChain 131 | SerialChain object created from URDF. 132 | """ 133 | urdf_chain = build_chain_from_urdf(data) 134 | return chain.SerialChain(urdf_chain, end_link_name + "_frame", 135 | "" if root_link_name == "" else root_link_name + "_frame") 136 | -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/pytorch_kinematics/urdf_parser_py/__init__.py: -------------------------------------------------------------------------------- 1 | from . import urdf 2 | from . import sdf 3 | -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/pytorch_kinematics/urdf_parser_py/xml_reflection/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/pytorch_kinematics/urdf_parser_py/xml_reflection/basics.py: -------------------------------------------------------------------------------- 1 | import string 2 | import yaml 3 | import collections 4 | from lxml import etree 5 | 6 | 7 | def xml_string(rootXml, addHeader=True): 8 | # Meh 9 | xmlString = etree.tostring(rootXml, pretty_print=True, encoding='unicode') 10 | if addHeader: 11 | xmlString = '\n' + xmlString 12 | return xmlString 13 | 14 | 15 | def dict_sub(obj, keys): 16 | return dict((key, obj[key]) for key in keys) 17 | 18 | 19 | def node_add(doc, sub): 20 | if sub is None: 21 | return None 22 | if type(sub) == str: 23 | return etree.SubElement(doc, sub) 24 | elif isinstance(sub, etree._Element): 25 | doc.append(sub) # This screws up the rest of the tree for prettyprint 26 | return sub 27 | else: 28 | raise Exception('Invalid sub value') 29 | 30 | 31 | def pfloat(x): 32 | return str(x).rstrip('.') 33 | 34 | 35 | def xml_children(node): 36 | children = node.getchildren() 37 | 38 | def predicate(node): 39 | return not isinstance(node, etree._Comment) 40 | 41 | return list(filter(predicate, children)) 42 | 43 | 44 | def isstring(obj): 45 | try: 46 | return isinstance(obj, basestring) 47 | except NameError: 48 | return isinstance(obj, str) 49 | 50 | 51 | def to_yaml(obj): 52 | """ Simplify yaml representation for pretty printing """ 53 | # Is there a better way to do this by adding a representation with 54 | # yaml.Dumper? 
55 | # Ordered dict: http://pyyaml.org/ticket/29#comment:11 56 | if obj is None or isstring(obj): 57 | out = str(obj) 58 | elif type(obj) in [int, float, bool]: 59 | return obj 60 | elif hasattr(obj, 'to_yaml'): 61 | out = obj.to_yaml() 62 | elif isinstance(obj, etree._Element): 63 | out = etree.tostring(obj, pretty_print=True) 64 | elif type(obj) == dict: 65 | out = {} 66 | for (var, value) in obj.items(): 67 | out[str(var)] = to_yaml(value) 68 | elif hasattr(obj, 'tolist'): 69 | # For numpy objects 70 | out = to_yaml(obj.tolist()) 71 | elif isinstance(obj, collections.Iterable): 72 | out = [to_yaml(item) for item in obj] 73 | else: 74 | out = str(obj) 75 | return out 76 | 77 | 78 | class SelectiveReflection(object): 79 | def get_refl_vars(self): 80 | return list(vars(self).keys()) 81 | 82 | 83 | class YamlReflection(SelectiveReflection): 84 | def to_yaml(self): 85 | raw = dict((var, getattr(self, var)) for var in self.get_refl_vars()) 86 | return to_yaml(raw) 87 | 88 | def __str__(self): 89 | # Good idea? Will it remove other important things? 90 | return yaml.dump(self.to_yaml()).rstrip() 91 | -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='pytorch_kinematics', 5 | version='0.3.0', 6 | packages=['pytorch_kinematics'], 7 | url='https://github.com/UM-ARM-Lab/pytorch_kinematics', 8 | license='MIT', 9 | author='zhsh', 10 | author_email='zhsh@umich.edu', 11 | description='Robot kinematics implemented in pytorch', 12 | install_requires=[ 13 | 'torch', 14 | 'numpy', 15 | 'transformations', 16 | 'absl-py' 17 | ], 18 | tests_require=[ 19 | 'pytest' 20 | ] 21 | ) 22 | -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iSEE-Laboratory/Grasp-as-You-Say/4694fa9369523d09d0c2cea6ec7bcddd2cb206d8/thirdparty/pytorch_kinematics/tests/__init__.py -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/tests/ant.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 72 | -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/tests/prismatic_robot.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/tests/test_jacobian.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import pytorch_kinematics as pk 4 | 5 | 6 | def test_correctness(): 7 | chain = pk.build_serial_chain_from_urdf(open("kuka_iiwa.urdf").read(), "lbr_iiwa_link_7") 8 | th = torch.tensor([0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0]) 9 | J = chain.jacobian(th) 10 | 11 | assert torch.allclose(J, torch.tensor([[[0, 1.41421356e-02, 0, 2.82842712e-01, 0, 0, 0], 12 | [-6.60827561e-01, 0, -4.57275649e-01, 0, 5.72756493e-02, 0, 0], 13 | [0, 6.60827561e-01, 0, -3.63842712e-01, 0, 8.10000000e-02, 0], 14 | [0, 0, -7.07106781e-01, 
0, -7.07106781e-01, 0, -1], 15 | [0, 1, 0, -1, 0, 1, 0], 16 | [1, 0, 7.07106781e-01, 0, -7.07106781e-01, 0, 0]]])) 17 | 18 | chain = pk.build_chain_from_sdf(open("simple_arm.sdf").read()) 19 | chain = pk.SerialChain(chain, "arm_wrist_roll_frame") 20 | th = torch.tensor([0.8, 0.2, -0.5, -0.3]) 21 | J = chain.jacobian(th) 22 | torch.allclose(J, torch.tensor([[[0., -1.51017878, -0.46280904, 0.], 23 | [0., 0.37144033, 0.29716627, 0.], 24 | [0., 0., 0., 0.], 25 | [0., 0., 0., 0.], 26 | [0., 0., 0., 0.], 27 | [0., 1., 1., 1.]]])) 28 | 29 | 30 | def test_jacobian_at_different_loc_than_ee(): 31 | chain = pk.build_serial_chain_from_urdf(open("kuka_iiwa.urdf").read(), "lbr_iiwa_link_7") 32 | th = torch.tensor([0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0]) 33 | loc = torch.tensor([0.1, 0, 0]) 34 | J = chain.jacobian(th, locations=loc) 35 | J_c1 = torch.tensor([[[-0., 0.11414214, -0., 0.18284271, 0., 0.1, 0.], 36 | [-0.66082756, -0., -0.38656497, -0., 0.12798633, -0., 0.1], 37 | [-0., 0.66082756, -0., -0.36384271, 0., 0.081, -0.], 38 | [-0., -0., -0.70710678, -0., -0.70710678, 0., -1.], 39 | [0., 1., 0., -1., 0., 1., 0.], 40 | [1., 0., 0.70710678, 0., -0.70710678, -0., 0.]]]) 41 | 42 | assert torch.allclose(J, J_c1) 43 | 44 | loc = torch.tensor([-0.1, 0.05, 0]) 45 | J = chain.jacobian(th, locations=loc) 46 | J_c2 = torch.tensor([[[-0.05, -0.08585786, -0.03535534, 0.38284271, 0.03535534, -0.1, -0.], 47 | [-0.66082756, -0., -0.52798633, -0., -0.01343503, 0., -0.1], 48 | [-0., 0.66082756, -0.03535534, -0.36384271, -0.03535534, 0.081, -0.05], 49 | [-0., -0., -0.70710678, -0., -0.70710678, 0., -1.], 50 | [0., 1., 0., -1., 0., 1., 0.], 51 | [1., 0., 0.70710678, 0., -0.70710678, -0., 0.]]]) 52 | 53 | assert torch.allclose(J, J_c2) 54 | 55 | # check that batching the location is fine 56 | th = th.repeat(2, 1) 57 | loc = torch.tensor([[0.1, 0, 0], [-0.1, 0.05, 0]]) 58 | J = chain.jacobian(th, locations=loc) 59 | assert torch.allclose(J, torch.cat((J_c1, J_c2))) 60 | 61 | 62 | def test_parallel(): 63 | N = 100 64 | chain = pk.build_serial_chain_from_urdf(open("kuka_iiwa.urdf").read(), "lbr_iiwa_link_7") 65 | th = torch.cat( 66 | (torch.tensor([[0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0]]), torch.rand(N, 7))) 67 | J = chain.jacobian(th) 68 | for i in range(N): 69 | J_i = chain.jacobian(th[i]) 70 | assert torch.allclose(J[i], J_i) 71 | 72 | 73 | def test_dtype_device(): 74 | N = 1000 75 | d = "cuda" if torch.cuda.is_available() else "cpu" 76 | dtype = torch.float64 77 | 78 | chain = pk.build_serial_chain_from_urdf(open("kuka_iiwa.urdf").read(), "lbr_iiwa_link_7") 79 | chain = chain.to(dtype=dtype, device=d) 80 | th = torch.rand(N, 7, dtype=dtype, device=d) 81 | J = chain.jacobian(th) 82 | assert J.dtype is dtype 83 | 84 | 85 | def test_gradient(): 86 | N = 10 87 | d = "cuda" if torch.cuda.is_available() else "cpu" 88 | dtype = torch.float64 89 | 90 | chain = pk.build_serial_chain_from_urdf(open("kuka_iiwa.urdf").read(), "lbr_iiwa_link_7") 91 | chain = chain.to(dtype=dtype, device=d) 92 | th = torch.rand(N, 7, dtype=dtype, device=d, requires_grad=True) 93 | J = chain.jacobian(th) 94 | assert th.grad is None 95 | J.norm().backward() 96 | assert th.grad is not None 97 | 98 | 99 | def test_jacobian_prismatic(): 100 | chain = pk.build_serial_chain_from_urdf(open("prismatic_robot.urdf").read(), "link4") 101 | th = torch.zeros(3) 102 | tg = chain.forward_kinematics(th) 103 | m = tg.get_matrix() 104 | pos = m[0, :3, 3] 105 | assert torch.allclose(pos, torch.tensor([0, 0, 1.])) 
106 | th = torch.tensor([0, 0.1, 0]) 107 | tg = chain.forward_kinematics(th) 108 | m = tg.get_matrix() 109 | pos = m[0, :3, 3] 110 | assert torch.allclose(pos, torch.tensor([0, -0.1, 1.])) 111 | th = torch.tensor([0.1, 0.1, 0]) 112 | tg = chain.forward_kinematics(th) 113 | m = tg.get_matrix() 114 | pos = m[0, :3, 3] 115 | assert torch.allclose(pos, torch.tensor([0, -0.1, 1.1])) 116 | th = torch.tensor([0.1, 0.1, 0.1]) 117 | tg = chain.forward_kinematics(th) 118 | m = tg.get_matrix() 119 | pos = m[0, :3, 3] 120 | assert torch.allclose(pos, torch.tensor([0.1, -0.1, 1.1])) 121 | 122 | J = chain.jacobian(th) 123 | assert torch.allclose(J, torch.tensor([[[0., 0., 1.], 124 | [0., -1., 0.], 125 | [1., 0., 0.], 126 | [0., 0., 0.], 127 | [0., 0., 0.], 128 | [0., 0., 0.]]])) 129 | 130 | 131 | if __name__ == "__main__": 132 | test_correctness() 133 | test_parallel() 134 | test_dtype_device() 135 | test_gradient() 136 | test_jacobian_prismatic() 137 | test_jacobian_at_different_loc_than_ee() 138 | -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/tests/test_kinematics.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import pytorch_kinematics as pk 5 | 6 | 7 | def quat_pos_from_transform3d(tg): 8 | m = tg.get_matrix() 9 | pos = m[:, :3, 3] 10 | rot = pk.matrix_to_quaternion(m[:, :3, :3]) 11 | return pos, rot 12 | 13 | 14 | def quaternion_equality(a, b): 15 | # negative of a quaternion is the same rotation 16 | return torch.allclose(a, b) or torch.allclose(a, -b) 17 | 18 | 19 | def test_fkik(): 20 | data = '' \ 21 | '' \ 22 | '' \ 23 | '' \ 24 | '' \ 25 | '' \ 26 | '' \ 27 | '' \ 28 | '' \ 29 | '' \ 30 | '' \ 31 | '' \ 32 | '' \ 33 | '' \ 34 | '' 35 | chain = pk.build_serial_chain_from_urdf(data, 'link3') 36 | th1 = torch.tensor([0.42553542, 0.17529176]) 37 | tg = chain.forward_kinematics(th1) 38 | pos, rot = quat_pos_from_transform3d(tg) 39 | assert torch.allclose(pos, torch.tensor([[1.91081784, 0.41280851, 0.0000]])) 40 | assert quaternion_equality(rot, torch.tensor([[0.95521418, 0.0000, 0.0000, 0.2959153]])) 41 | print(tg) 42 | # TODO implement and test inverse kinematics 43 | # th2 = chain.inverse_kinematics(tg) 44 | # self.assertTrue(np.allclose(th1, th2, atol=1.0e-6)) 45 | # test batch kinematics 46 | N = 20 47 | th_batch = torch.rand(N, 2) 48 | tg_batch = chain.forward_kinematics(th_batch) 49 | m = tg_batch.get_matrix() 50 | for i in range(N): 51 | tg = chain.forward_kinematics(th_batch[i]) 52 | assert torch.allclose(tg.get_matrix().view(4, 4), m[i]) 53 | 54 | # check that gradients are passed through 55 | th2 = torch.tensor([0.42553542, 0.17529176], requires_grad=True) 56 | tg = chain.forward_kinematics(th2) 57 | pos, rot = quat_pos_from_transform3d(tg) 58 | # note that since we are using existing operations we are not checking grad calculation correctness 59 | assert th2.grad is None 60 | pos.norm().backward() 61 | assert th2.grad is not None 62 | 63 | 64 | def test_urdf(): 65 | chain = pk.build_serial_chain_from_urdf(open("kuka_iiwa.urdf").read(), "lbr_iiwa_link_7") 66 | print(chain) 67 | print(chain.get_joint_parameter_names()) 68 | th = [0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0] 69 | ret = chain.forward_kinematics(th, end_only=False) 70 | tg = ret['lbr_iiwa_link_7'] 71 | pos, rot = quat_pos_from_transform3d(tg) 72 | assert quaternion_equality(rot, torch.tensor([7.07106781e-01, 0, -7.07106781e-01, 0])) 73 | assert 
torch.allclose(pos, torch.tensor([-6.60827561e-01, 0, 3.74142136e-01])) 74 | 75 | N = 1000 76 | d = "cuda" if torch.cuda.is_available() else "cpu" 77 | dtype = torch.float64 78 | 79 | th_batch = torch.rand(N, len(chain.get_joint_parameter_names()), dtype=dtype, device=d) 80 | chain = chain.to(dtype=dtype, device=d) 81 | 82 | import time 83 | start = time.time() 84 | tg_batch = chain.forward_kinematics(th_batch) 85 | m = tg_batch.get_matrix() 86 | elapsed = time.time() - start 87 | print("elapsed {}s for N={} when parallel".format(elapsed, N)) 88 | 89 | start = time.time() 90 | elapsed = 0 91 | for i in range(N): 92 | tg = chain.forward_kinematics(th_batch[i]) 93 | elapsed += time.time() - start 94 | start = time.time() 95 | assert torch.allclose(tg.get_matrix().view(4, 4), m[i]) 96 | print("elapsed {}s for N={} when serial".format(elapsed, N)) 97 | 98 | 99 | # test robot with prismatic and fixed joints 100 | def test_fk_simple_arm(): 101 | chain = pk.build_chain_from_sdf(open("simple_arm.sdf").read()) 102 | # print(chain) 103 | # print(chain.get_joint_parameter_names()) 104 | ret = chain.forward_kinematics({'arm_elbow_pan_joint': math.pi / 2.0, 'arm_wrist_lift_joint': -0.5}) 105 | tg = ret['arm_wrist_roll'] 106 | pos, rot = quat_pos_from_transform3d(tg) 107 | assert quaternion_equality(rot, torch.tensor([0.70710678, 0., 0., 0.70710678])) 108 | assert torch.allclose(pos, torch.tensor([1.05, 0.55, 0.5])) 109 | 110 | N = 100 111 | ret = chain.forward_kinematics({'arm_elbow_pan_joint': torch.rand(N, 1), 'arm_wrist_lift_joint': torch.rand(N, 1)}) 112 | tg = ret['arm_wrist_roll'] 113 | assert list(tg.get_matrix().shape) == [N, 4, 4] 114 | 115 | 116 | def test_cuda(): 117 | if torch.cuda.is_available(): 118 | d = "cuda" 119 | dtype = torch.float64 120 | chain = pk.build_chain_from_sdf(open("simple_arm.sdf").read()) 121 | chain = chain.to(dtype=dtype, device=d) 122 | 123 | ret = chain.forward_kinematics({'arm_elbow_pan_joint': math.pi / 2.0, 'arm_wrist_lift_joint': -0.5}) 124 | tg = ret['arm_wrist_roll'] 125 | pos, rot = quat_pos_from_transform3d(tg) 126 | assert quaternion_equality(rot, torch.tensor([0.70710678, 0., 0., 0.70710678], dtype=dtype, device=d)) 127 | assert torch.allclose(pos, torch.tensor([1.05, 0.55, 0.5], dtype=dtype, device=d)) 128 | 129 | data = '' \ 130 | '' \ 131 | '' \ 132 | '' \ 133 | '' \ 134 | '' \ 135 | '' \ 136 | '' \ 137 | '' \ 138 | '' \ 139 | '' \ 140 | '' \ 141 | '' \ 142 | '' \ 143 | '' 144 | chain = pk.build_serial_chain_from_urdf(data, 'link3') 145 | chain = chain.to(dtype=dtype, device=d) 146 | N = 20 147 | th_batch = torch.rand(N, 2).to(device=d, dtype=dtype) 148 | tg_batch = chain.forward_kinematics(th_batch) 149 | m = tg_batch.get_matrix() 150 | for i in range(N): 151 | tg = chain.forward_kinematics(th_batch[i]) 152 | assert torch.allclose(tg.get_matrix().view(4, 4), m[i]) 153 | 154 | 155 | # test more complex robot and the MJCF parser 156 | def test_fk_mjcf(): 157 | chain = pk.build_chain_from_mjcf(open("ant.xml").read()) 158 | print(chain) 159 | print(chain.get_joint_parameter_names()) 160 | th = {'hip_1': 1.0, 'ankle_1': 1} 161 | ret = chain.forward_kinematics(th) 162 | tg = ret['aux_1_child'] 163 | pos, rot = quat_pos_from_transform3d(tg) 164 | assert quaternion_equality(rot, torch.tensor([0.87758256, 0., 0., 0.47942554])) 165 | assert torch.allclose(pos, torch.tensor([0.2, 0.2, 0.75])) 166 | tg = ret['front_left_foot_child'] 167 | pos, rot = quat_pos_from_transform3d(tg) 168 | assert quaternion_equality(rot, torch.tensor([0.77015115, -0.4600326, 
0.13497724, 0.42073549])) 169 | assert torch.allclose(pos, torch.tensor([0.13976626, 0.47635466, 0.75])) 170 | print(ret) 171 | 172 | 173 | def test_fk_mjcf_humanoid(): 174 | chain = pk.build_chain_from_mjcf(open("humanoid.xml").read()) 175 | print(chain) 176 | print(chain.get_joint_parameter_names()) 177 | th = {'left_knee': 0.0, 'right_knee': 0.0} 178 | ret = chain.forward_kinematics(th) 179 | print(ret) 180 | 181 | 182 | if __name__ == "__main__": 183 | test_fkik() 184 | test_fk_simple_arm() 185 | test_fk_mjcf() 186 | test_cuda() 187 | test_urdf() 188 | # test_fk_mjcf_humanoid() 189 | -------------------------------------------------------------------------------- /thirdparty/pytorch_kinematics/tests/test_transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_kinematics.transforms as tf 3 | 4 | 5 | def test_transform(): 6 | N = 20 7 | mats = tf.random_rotations(N, dtype=torch.float64, device="cpu", requires_grad=True) 8 | assert list(mats.shape) == [N, 3, 3] 9 | # test batch conversions 10 | quat = tf.matrix_to_quaternion(mats) 11 | assert list(quat.shape) == [N, 4] 12 | mats_recovered = tf.quaternion_to_matrix(quat) 13 | assert torch.allclose(mats, mats_recovered) 14 | 15 | quat_identity = tf.quaternion_multiply(quat, tf.quaternion_invert(quat)) 16 | assert torch.allclose(tf.quaternion_to_matrix(quat_identity), torch.eye(3, dtype=torch.float64).repeat(N, 1, 1)) 17 | 18 | 19 | def test_translations(): 20 | t = tf.Translate(1, 2, 3) 21 | points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view( 22 | 1, 3, 3 23 | ) 24 | points_out = t.transform_points(points) 25 | points_out_expected = torch.tensor( 26 | [[2.0, 2.0, 3.0], [1.0, 3.0, 3.0], [1.5, 2.5, 3.0]] 27 | ).view(1, 3, 3) 28 | assert torch.allclose(points_out, points_out_expected) 29 | 30 | N = 20 31 | points = torch.randn((N, N, 3)) 32 | translation = torch.randn((N, 3)) 33 | transforms = tf.Transform3d(pos=translation) 34 | translated_points = transforms.transform_points(points) 35 | assert torch.allclose(translated_points, translation.repeat(N, 1, 1).transpose(0, 1) + points) 36 | returned_points = transforms.inverse().transform_points(translated_points) 37 | assert torch.allclose(returned_points, points, atol=1e-6) 38 | 39 | 40 | def test_rotate_axis_angle(): 41 | t = tf.Transform3d().rotate_axis_angle(90.0, axis="Z") 42 | points = torch.tensor([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 1.0]]).view( 43 | 1, 3, 3 44 | ) 45 | normals = torch.tensor( 46 | [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]] 47 | ).view(1, 3, 3) 48 | points_out = t.transform_points(points) 49 | normals_out = t.transform_normals(normals) 50 | points_out_expected = torch.tensor( 51 | [[0.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 1.0]] 52 | ).view(1, 3, 3) 53 | normals_out_expected = torch.tensor( 54 | [[0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]] 55 | ).view(1, 3, 3) 56 | assert torch.allclose(points_out, points_out_expected) 57 | assert torch.allclose(normals_out, normals_out_expected) 58 | 59 | 60 | def test_rotate(): 61 | R = tf.so3_exp_map(torch.randn((1, 3))) 62 | t = tf.Transform3d().rotate(R) 63 | points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view( 64 | 1, 3, 3 65 | ) 66 | normals = torch.tensor( 67 | [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]] 68 | ).view(1, 3, 3) 69 | points_out = t.transform_points(points) 70 | normals_out = t.transform_normals(normals) 71 | points_out_expected = torch.bmm(points, 
R.transpose(-1, -2)) 72 | normals_out_expected = torch.bmm(normals, R.transpose(-1, -2)) 73 | assert torch.allclose(points_out, points_out_expected) 74 | assert torch.allclose(normals_out, normals_out_expected) 75 | for i in range(3): 76 | assert torch.allclose(points_out[0, i], R @ points[0, i]) 77 | assert torch.allclose(normals_out[0, i], R @ normals[0, i]) 78 | 79 | 80 | def test_transform_combined(): 81 | R = tf.so3_exp_map(torch.randn((1, 3))) 82 | tr = torch.randn((1, 3)) 83 | t = tf.Transform3d(rot=R, pos=tr) 84 | N = 10 85 | points = torch.randn((N, 3)) 86 | normals = torch.randn((N, 3)) 87 | points_out = t.transform_points(points) 88 | normals_out = t.transform_normals(normals) 89 | for i in range(N): 90 | assert torch.allclose(points_out[i], R @ points[i] + tr) 91 | assert torch.allclose(normals_out[i], R @ normals[i]) 92 | 93 | 94 | def test_euler(): 95 | euler_angles = torch.tensor([1, 0, 0.5]) 96 | t = tf.Transform3d(rot=euler_angles) 97 | sxyz_matrix = torch.tensor([[0.87758256, -0.47942554, 0., 0., ], 98 | [0.25903472, 0.47415988, -0.84147098, 0.], 99 | [0.40342268, 0.73846026, 0.54030231, 0.], 100 | [0., 0., 0., 1.]]) 101 | # from tf.transformations import euler_matrix 102 | # print(euler_matrix(*euler_angles, "rxyz")) 103 | # print(t.get_matrix()) 104 | assert torch.allclose(sxyz_matrix, t.get_matrix()) 105 | 106 | 107 | def test_quaternions(): 108 | n = 10 109 | q = tf.random_quaternions(n) 110 | q_tf = tf.wxyz_to_xyzw(q) 111 | assert torch.allclose(q, tf.xyzw_to_wxyz(q_tf)) 112 | 113 | 114 | if __name__ == "__main__": 115 | test_transform() 116 | test_translations() 117 | test_rotate_axis_angle() 118 | test_rotate() 119 | test_euler() 120 | test_quaternions() 121 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .config_utils import EasyConfig 2 | -------------------------------------------------------------------------------- /utils/config_utils.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import json 3 | import os 4 | from ast import literal_eval 5 | from typing import Any, Dict, List, Tuple, Union 6 | from multimethod import multimethod 7 | import yaml 8 | 9 | 10 | class EasyConfig(dict): 11 | 12 | def __getattr__(self, key: str) -> Any: 13 | if key not in self: 14 | raise AttributeError(key) 15 | return self[key] 16 | 17 | def __setattr__(self, key: str, value: Any) -> None: 18 | self[key] = value 19 | 20 | def __delattr__(self, key: str) -> None: 21 | del self[key] 22 | 23 | def load(self, fpath: str, *, recursive: bool = False) -> None: 24 | """load cfg from yaml 25 | 26 | Args: 27 | fpath (str): path to the yaml file 28 | recursive (bool, optional): recursively load its parent default yaml files. Defaults to False.
29 | """ 30 | if not os.path.exists(fpath): 31 | raise FileNotFoundError(fpath) 32 | fpaths = [fpath] 33 | if recursive: 34 | extension = os.path.splitext(fpath)[1] 35 | while os.path.dirname(fpath) != fpath: 36 | fpath = os.path.dirname(fpath) 37 | fpaths.append(os.path.join(fpath, 'default' + extension)) 38 | for fpath in reversed(fpaths): 39 | if os.path.exists(fpath): 40 | with open(fpath) as f: 41 | cfg_dict = yaml.safe_load(f) 42 | self.update(cfg_dict) 43 | 44 | def reload(self, fpath: str, *, recursive: bool = False) -> None: 45 | self.clear() 46 | self.load(fpath, recursive=recursive) 47 | 48 | # multimethod makes Python support function overloading 49 | @multimethod 50 | def update(self, other: Dict) -> None: 51 | for key, value in other.items(): 52 | if isinstance(value, dict): 53 | if key not in self or not isinstance(self[key], EasyConfig): 54 | self[key] = EasyConfig() 55 | # recursively update 56 | self[key].update(value) 57 | else: 58 | self[key] = value 59 | 60 | @multimethod 61 | def update(self, opts: Union[List, Tuple]) -> None: 62 | index = 0 63 | while index < len(opts): 64 | opt = opts[index] 65 | if opt.startswith('--'): 66 | opt = opt[2:] 67 | if '=' in opt: 68 | key, value = opt.split('=', 1) 69 | index += 1 70 | else: 71 | key, value = opt, opts[index + 1] 72 | index += 2 73 | current = self 74 | subkeys = key.split('.') 75 | try: 76 | value = literal_eval(value) 77 | except: 78 | pass 79 | for subkey in subkeys[:-1]: 80 | current = current.setdefault(subkey, EasyConfig()) 81 | current[subkeys[-1]] = value 82 | 83 | def dict(self) -> Dict[str, Any]: 84 | configs = dict() 85 | for key, value in self.items(): 86 | if isinstance(value, EasyConfig): 87 | value = value.dict() 88 | configs[key] = value 89 | return configs 90 | 91 | def hash(self) -> str: 92 | buffer = json.dumps(self.dict(), sort_keys=True) 93 | return hashlib.sha256(buffer.encode()).hexdigest() 94 | 95 | def __str__(self) -> str: 96 | texts = [] 97 | for key, value in self.items(): 98 | if isinstance(value, EasyConfig): 99 | separator = '\n' 100 | else: 101 | separator = ' ' 102 | text = key + ':' + separator + str(value) 103 | lines = text.split('\n') 104 | for k, line in enumerate(lines[1:]): 105 | lines[k + 1] = (' ' * 2) + line 106 | texts.extend(lines) 107 | return '\n'.join(texts) 108 | -------------------------------------------------------------------------------- /utils/rotation_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | import pytorch3d.transforms as T 5 | from torch.functional import Tensor 6 | 7 | 8 | class EulerConverter: 9 | """ 10 | A class for converting Euler angles with axes = 'sxyz' (as defined in transforms3d) 11 | to other rotation representations. 12 | 13 | Supports batch operations. 14 | 15 | Expects a tensor of shape (..., 3) as input. 16 | 17 | All types of outputs can be transformed to rotation matrices, which 18 | are identical to the results of transforms3d.euler.euler2mat(euler, axes='sxyz'), 19 | by functions like pytorch3d.transforms.XXX_to_matrix().
20 | """ 21 | 22 | def to_euler(self, euler): 23 | return euler 24 | 25 | def to_matrix(self, euler): 26 | return T.axis_angle_to_matrix( 27 | self.to_axisangle(euler) 28 | ) 29 | 30 | def to_rotation_6d(self, euler): 31 | return T.matrix_to_rotation_6d( 32 | self.to_matrix(euler), 33 | ) 34 | 35 | def to_quaternion(self, euler): 36 | return T.axis_angle_to_quaternion( 37 | self.to_axisangle(euler) 38 | ) 39 | 40 | def to_axisangle(self, euler): 41 | return torch.flip( 42 | T.matrix_to_axis_angle(T.euler_angles_to_matrix(euler, "ZYX")), 43 | dims=[-1] 44 | ) 45 | 46 | 47 | class RotNorm: 48 | """ 49 | A class for normalizing rotation representations 50 | """ 51 | @staticmethod 52 | def norm_euler(euler: Tensor) -> Tensor: 53 | """ 54 | euler: A tensor of size: (B, 3, N) 55 | """ 56 | lower_bounds = torch.ones_like(euler) * math.pi * -1.0 57 | upper_bounds = torch.ones_like(euler) * math.pi 58 | return (euler - lower_bounds) / (upper_bounds - lower_bounds) 59 | 60 | @staticmethod 61 | def norm_quaternion(quaternion: Tensor) -> Tensor: 62 | """ 63 | quaternion: A tensor of size: (B, 4, N) 64 | """ 65 | return F.normalize(quaternion) 66 | 67 | @staticmethod 68 | def norm_rotation_6d(rot6d: Tensor) -> Tensor: 69 | """ 70 | rot6d: A tensor of size: (B, 6, N) 71 | """ 72 | vector_1 = F.normalize(rot6d[:, :3, :]) 73 | vector_2 = F.normalize(rot6d[:, 3:6, :]) 74 | return torch.cat([vector_1, vector_2], dim=1) 75 | 76 | def norm_other(tensor) -> Tensor: 77 | return tensor 78 | 79 | 80 | class Rot2Axisangle: 81 | """ 82 | A class for converting rotation representations to axisangle 83 | """ 84 | @staticmethod 85 | def euler2axisangle(euler): 86 | return torch.flip( 87 | T.matrix_to_axis_angle(T.euler_angles_to_matrix(euler, "ZYX")), 88 | dims=[-1] 89 | ) 90 | 91 | @staticmethod 92 | def quaternion2axisangle(quaternion): 93 | return T.quaternion_to_axis_angle(quaternion) 94 | 95 | @staticmethod 96 | def rotation_6d2axisangle(rot6d): 97 | B, N = rot6d.shape[:2] 98 | mat = robust_compute_rotation_matrix_from_ortho6d(rot6d.reshape(B * N, 6)) 99 | return T.matrix_to_axis_angle(mat.reshape(B, N, 3, 3)) 100 | 101 | @staticmethod 102 | def matrix2axisangle(mat): 103 | return T.matrix_to_axis_angle(mat) 104 | 105 | @staticmethod 106 | def axisangle2axisangle(axisangle): 107 | return axisangle 108 | 109 | 110 | # Codes borrowed from dexgraspnet 111 | def robust_compute_rotation_matrix_from_ortho6d(poses): 112 | """ 113 | Instead of making 2nd vector orthogonal to first 114 | create a base that takes into account the two predicted 115 | directions equally 116 | """ 117 | x_raw = poses[:, 0:3] # batch*3 118 | y_raw = poses[:, 3:6] # batch*3 119 | 120 | x = normalize_vector(x_raw) # batch*3 121 | y = normalize_vector(y_raw) # batch*3 122 | middle = normalize_vector(x + y) 123 | orthmid = normalize_vector(x - y) 124 | x = normalize_vector(middle + orthmid) 125 | y = normalize_vector(middle - orthmid) 126 | # Their scalar product should be small ! 127 | # assert torch.einsum("ij,ij->i", [x, y]).abs().max() < 0.00001 128 | z = normalize_vector(cross_product(x, y)) 129 | 130 | x = x.view(-1, 3, 1) 131 | y = y.view(-1, 3, 1) 132 | z = z.view(-1, 3, 1) 133 | matrix = torch.cat((x, y, z), 2) # batch*3*3 134 | # Check for reflection in matrix ! 
If found, flip last vector TODO 135 | # assert (torch.stack([torch.det(mat) for mat in matrix ])< 0).sum() == 0 136 | return matrix 137 | 138 | 139 | def normalize_vector(v): 140 | batch = v.shape[0] 141 | v_mag = torch.sqrt(v.pow(2).sum(1)) # batch 142 | v_mag = torch.max(v_mag, v.new([1e-8])) 143 | v_mag = v_mag.view(batch, 1).expand(batch, v.shape[1]) 144 | v = v/v_mag 145 | return v 146 | 147 | 148 | def cross_product(u, v): 149 | batch = u.shape[0] 150 | i = u[:, 1] * v[:, 2] - u[:, 2] * v[:, 1] 151 | j = u[:, 2] * v[:, 0] - u[:, 0] * v[:, 2] 152 | k = u[:, 0] * v[:, 1] - u[:, 1] * v[:, 0] 153 | 154 | out = torch.cat((i.view(batch, 1), j.view(batch, 1), k.view(batch, 1)), 1) 155 | 156 | return out 157 | --------------------------------------------------------------------------------
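A minimal usage sketch for the rotation helpers in utils/rotation_utils.py (illustrative only, not part of the repository; it assumes torch and pytorch3d are installed and that the repository root is on PYTHONPATH):

# Illustrative sanity check for utils/rotation_utils.py (hypothetical snippet, not repository code).
import torch
import pytorch3d.transforms as T
from utils.rotation_utils import EulerConverter, robust_compute_rotation_matrix_from_ortho6d

torch.manual_seed(0)
conv = EulerConverter()
euler = torch.tensor([[0.3, -0.2, 1.1]])        # (B, 3) Euler angles in the sxyz convention
mat = conv.to_matrix(euler)                     # (B, 3, 3) rotation matrices
quat = conv.to_quaternion(euler)                # (B, 4) wxyz quaternions

# Both outputs should describe the same rotation.
assert torch.allclose(T.quaternion_to_matrix(quat), mat, atol=1e-5)

# robust_compute_rotation_matrix_from_ortho6d maps arbitrary 6D vectors to proper
# rotation matrices: orthonormal columns (R^T R = I) and determinant +1.
R = robust_compute_rotation_matrix_from_ortho6d(torch.randn(4, 6))
assert torch.allclose(R.transpose(1, 2) @ R, torch.eye(3).expand(4, 3, 3), atol=1e-5)
assert torch.allclose(torch.det(R), torch.ones(4), atol=1e-5)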
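A small usage sketch for the EasyConfig class in utils/config_utils.py (illustrative only; the file name example.yaml and the keys below are hypothetical): load a YAML file, apply command-line style overrides where dotted keys create nested sub-configs, then read values by attribute access.

# Hypothetical usage of EasyConfig; "example.yaml" and the keys shown do not exist in the repository.
from utils.config_utils import EasyConfig

cfg = EasyConfig()
cfg.load("example.yaml", recursive=False)                 # populate from a YAML file (raises FileNotFoundError if missing)
cfg.update(["--optimizer.lr=0.001", "--epochs", "50"])    # CLI-style overrides; literal_eval turns "0.001" and "50" into numbers
print(cfg.optimizer.lr, cfg.epochs)                       # attribute access provided by __getattr__
print(cfg.hash())                                         # sha256 of the sorted JSON dump, handy as an experiment id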