├── LICENSE ├── README.md ├── config ├── simple_dp3.yaml └── task │ ├── _real_world_task_template.yaml │ ├── rlbench │ ├── close_jar.yaml │ ├── insert_onto_square_peg.yaml │ ├── light_bulb_in.yaml │ ├── meat_off_grill.yaml │ ├── place_cups.yaml │ ├── place_shape_in_shape_sorter.yaml │ ├── place_wine_at_rack_location.yaml │ ├── put_groceries_in_cupboard.yaml │ ├── put_money_in_safe.yaml │ ├── reach_and_drag.yaml │ ├── stack_blocks.yaml │ ├── stack_cups.yaml │ └── turn_tap.yaml │ └── rlbench_multi.yaml ├── diffusion_policy_3d ├── LICENSE ├── common │ ├── checkpoint_util.py │ ├── logger_util.py │ ├── model_util.py │ ├── pytorch_util.py │ ├── replay_buffer.py │ ├── sampler.py │ └── sampler_multitask.py ├── dataset │ ├── __init__.py │ ├── base_dataset.py │ ├── rlbench_base_dataset.py │ ├── rlbench_dataset.py │ └── rlbench_dataset_list.py ├── env_runner │ └── base_runner.py ├── model │ ├── clip │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ ├── clip.py │ │ └── simple_tokenizer.py │ ├── common │ │ ├── dict_of_tensor_mixin.py │ │ ├── geodesic_loss.py │ │ ├── lr_scheduler.py │ │ ├── module_attr_mixin.py │ │ ├── normalizer.py │ │ ├── shape_util.py │ │ └── tensor_util.py │ ├── diffusion │ │ ├── conditional_unet1d.py │ │ ├── conv1d_components.py │ │ ├── ema_model.py │ │ ├── mask_generator.py │ │ ├── positional_embedding.py │ │ ├── simple_conditional_unet1d.py │ │ └── simple_conditional_unet1d_progress.py │ └── vision │ │ └── pointnet_extractor.py ├── policy │ ├── base_policy.py │ ├── dp3.py │ └── simple_dp3.py └── setup.py ├── dp3_requirements.txt ├── env_real ├── data │ ├── collect_zarr_real.py │ ├── prepare_mask.py │ ├── prepare_pose.py │ └── prepare_rgbd.py └── utils │ └── realworld_objects.py ├── env_rlbench ├── policy │ ├── dp3_policy.py │ └── subgoal_policy.py ├── runner │ ├── rl_bench_camera.py │ ├── rl_bench_dataset.py │ ├── rl_bench_env.py │ └── rlbench_runner.py ├── setup.py └── utils │ └── rlbench_utils.py ├── env_rlbench_peract ├── data │ ├── collect_zarr_rlbench_peract.py │ └── collect_zarr_rlbench_peract_multi_stage.py ├── setup.py └── utils │ ├── rlbench_objects.py │ └── rlbench_utils.py ├── foundation_pose ├── LICENSE ├── Utils.py ├── build_all.sh ├── build_all_conda.sh ├── datareader.py ├── estimater.py ├── learning │ ├── datasets │ │ ├── h5_dataset.py │ │ └── pose_dataset.py │ ├── models │ │ ├── network_modules.py │ │ ├── refine_network.py │ │ └── score_network.py │ └── training │ │ ├── predict_pose_refine.py │ │ ├── predict_score.py │ │ └── training_config.py ├── mycpp │ ├── CMakeLists.txt │ ├── include │ │ └── Utils.h │ └── src │ │ ├── Utils.cpp │ │ └── app │ │ └── pybind_api.cpp ├── setup.py └── wrapper.py ├── fp_requirements.txt ├── scripts ├── eval_policy.sh ├── eval_policy_multi.sh ├── gen_demonstration_real.sh ├── gen_demonstration_rlbench.sh └── train_policy.sh ├── tools ├── eval_dp3.py └── train_dp3.py └── utils ├── collect_utils.py ├── collect_utils_rlbench.py ├── io_utils.py ├── logger_utils.py ├── mask_utils.py ├── pose_utils.py ├── transform_utils.py ├── vis_o3d_utils.py └── vis_utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2024-Present, NVIDIA Corporation & affiliates. All rights reserved. 2 | 3 | 4 | ======================================================================= 5 | 6 | 1. Definitions 7 | 8 | "Licensor" means any person or entity that distributes its Work. 9 | 10 | "Software" means the original work of authorship made available under 11 | this License. 
12 | 13 | "Work" means the Software and any additions to or derivative works of 14 | the Software that are made available under this License. 15 | 16 | The terms "reproduce," "reproduction," "derivative works," and 17 | "distribution" have the meaning as provided under U.S. copyright law; 18 | provided, however, that for the purposes of this License, derivative 19 | works shall not include works that remain separable from, or merely 20 | link (or bind by name) to the interfaces of, the Work. 21 | 22 | Works, including the Software, are "made available" under this License 23 | by including in or with the Work either (a) a copyright notice 24 | referencing the applicability of this License to the Work, or (b) a 25 | copy of this License. 26 | 27 | 2. License Grants 28 | 29 | 2.1 Copyright Grant. Subject to the terms and conditions of this 30 | License, each Licensor grants to you a perpetual, worldwide, 31 | non-exclusive, royalty-free, copyright license to reproduce, 32 | prepare derivative works of, publicly display, publicly perform, 33 | sublicense and distribute its Work and any resulting derivative 34 | works in any form. 35 | 36 | 3. Limitations 37 | 38 | 3.1 Redistribution. You may reproduce or distribute the Work only 39 | if (a) you do so under this License, (b) you include a complete 40 | copy of this License with your distribution, and (c) you retain 41 | without modification any copyright, patent, trademark, or 42 | attribution notices that are present in the Work. 43 | 44 | 3.2 Derivative Works. You may specify that additional or different 45 | terms apply to the use, reproduction, and distribution of your 46 | derivative works of the Work ("Your Terms") only if (a) Your Terms 47 | provide that the use limitation in Section 3.3 applies to your 48 | derivative works, and (b) you identify the specific derivative 49 | works that are subject to Your Terms. Notwithstanding Your Terms, 50 | this License (including the redistribution requirements in Section 51 | 3.1) will continue to apply to the Work itself. 52 | 53 | 3.3 Use Limitation. The Work and any derivative works thereof only 54 | may be used or intended for use non-commercially. Notwithstanding 55 | the foregoing, NVIDIA and its affiliates may use the Work and any 56 | derivative works commercially. As used herein, "non-commercially" 57 | means for research or evaluation purposes only. 58 | 59 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 60 | against any Licensor (including any claim, cross-claim or 61 | counterclaim in a lawsuit) to enforce any patents that you allege 62 | are infringed by any Work, then your rights under this License from 63 | such Licensor (including the grant in Section 2.1) will terminate 64 | immediately. 65 | 66 | 3.5 Trademarks. This License does not grant any rights to use any 67 | Licensor's or its affiliates' names, logos, or trademarks, except 68 | as necessary to reproduce the notices described in this License. 69 | 70 | 3.6 Termination. If you violate any term of this License, then your 71 | rights under this License (including the grant in Section 2.1) will 72 | terminate immediately. 73 | 74 | 4. Disclaimer of Warranty. 75 | 76 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 77 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 78 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 79 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 80 | THIS LICENSE. 81 | 82 | 5. Limitation of Liability. 
83 | 84 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 85 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 86 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 87 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 88 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 89 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 90 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 91 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 92 | THE POSSIBILITY OF SUCH DAMAGES. 93 | 94 | ======================================================================= -------------------------------------------------------------------------------- /config/simple_dp3.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - task: adroit_hammer 3 | 4 | name: train_simple_dp3 5 | 6 | 7 | task_name: ${task.name} 8 | shape_meta: ${task.shape_meta} 9 | exp_name: "debug" 10 | 11 | horizon: 4 #8 #4 12 | n_obs_steps: 1 #4 #2 13 | n_action_steps: 4 14 | n_latency_steps: 0 15 | dataset_obs_steps: ${n_obs_steps} 16 | keypoint_visible_rate: 1.0 17 | obs_as_global_cond: True 18 | 19 | pose_estimation: 20 | mesh_dir: "/tmp/mesh/${task.task_name}/" 21 | 22 | policy: 23 | _target_: diffusion_policy_3d.policy.simple_dp3.SimpleDP3 24 | use_point_crop: true 25 | condition_type: film 26 | use_down_condition: true 27 | use_mid_condition: true 28 | use_up_condition: true 29 | use_lang_emb: true 30 | use_stage_emb: true 31 | use_progress: true 32 | 33 | diffusion_step_embed_dim: 128 34 | down_dims: 35 | - 128 36 | - 256 37 | - 384 38 | crop_shape: 39 | - 80 40 | - 80 41 | encoder_output_dim: 64 42 | horizon: ${horizon} 43 | kernel_size: 5 44 | n_action_steps: ${n_action_steps} 45 | n_groups: 8 46 | n_obs_steps: ${n_obs_steps} 47 | 48 | noise_scheduler: 49 | _target_: diffusers.schedulers.scheduling_ddim.DDIMScheduler 50 | num_train_timesteps: 100 51 | beta_start: 0.0001 52 | beta_end: 0.02 53 | beta_schedule: squaredcos_cap_v2 54 | clip_sample: True 55 | set_alpha_to_one: True 56 | steps_offset: 0 57 | prediction_type: sample 58 | 59 | 60 | num_inference_steps: 10 61 | obs_as_global_cond: true 62 | shape_meta: ${shape_meta} 63 | 64 | 65 | 66 | # whether to use point cloud's color 67 | use_pc_color: false 68 | # type of pointnet 69 | pointnet_type: "pointnet" 70 | 71 | 72 | pointcloud_encoder_cfg: 73 | in_channels: 3 74 | out_channels: ${policy.encoder_output_dim} 75 | use_layernorm: true 76 | final_norm: layernorm # layernorm, none 77 | normal_channel: false 78 | 79 | 80 | ema: 81 | _target_: diffusion_policy_3d.model.diffusion.ema_model.EMAModel 82 | update_after_step: 0 83 | inv_gamma: 1.0 84 | power: 0.75 85 | min_value: 0.0 86 | max_value: 0.9999 87 | 88 | dataloader: 89 | batch_size: 128 90 | num_workers: 8 91 | shuffle: True 92 | pin_memory: True 93 | persistent_workers: False 94 | 95 | val_dataloader: 96 | batch_size: 128 97 | num_workers: 8 98 | shuffle: False 99 | pin_memory: True 100 | persistent_workers: False 101 | 102 | optimizer: 103 | _target_: torch.optim.AdamW 104 | lr: 1.0e-4 105 | betas: [0.95, 0.999] 106 | eps: 1.0e-8 107 | weight_decay: 1.0e-6 108 | 109 | training: 110 | device: "cuda:0" 111 | seed: 42 112 | debug: False 113 | resume: True 114 | lr_scheduler: cosine 115 | lr_warmup_steps: 500 116 | num_epochs: 10001 #3000 117 | gradient_accumulate_every: 1 118 | use_ema: True 119 | rollout_every: 
200 120 | checkpoint_every: 1000 #200 121 | val_every: 100 #1 122 | sample_every: 5 123 | max_train_steps: null 124 | max_val_steps: null 125 | tqdm_interval_sec: 1.0 126 | 127 | logging: 128 | group: ${exp_name} 129 | id: null 130 | mode: online 131 | name: ${task.task_name}_${training.seed} 132 | project: dp3-${task.name} 133 | resume: true 134 | tags: 135 | - dp3 136 | 137 | 138 | checkpoint: 139 | save_ckpt: True #False # if True, save checkpoint every checkpoint_every 140 | topk: 141 | monitor_key: test_mean_score 142 | mode: max 143 | k: 1 144 | format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt' 145 | save_last_ckpt: True # this only saves when save_ckpt is True 146 | save_last_snapshot: False 147 | 148 | multi_run: 149 | run_dir: outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 150 | wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name} 151 | 152 | hydra: 153 | job: 154 | override_dirname: ${name} 155 | run: 156 | dir: outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 157 | sweep: 158 | dir: outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 159 | subdir: ${hydra.job.num} 160 | 161 | 162 | evaluation: 163 | eval_epoch: null -------------------------------------------------------------------------------- /config/task/_real_world_task_template.yaml: -------------------------------------------------------------------------------- 1 | name: real_world 2 | task_name: pour_water 3 | 4 | shape_meta: &shape_meta 5 | # acceptable types: rgb, low_dim 6 | obs: 7 | point_cloud: 8 | shape: [1024, 3] 9 | type: point_cloud 10 | agent_pos: 11 | shape: [7] 12 | type: low_dimx 13 | action: 14 | shape: [8] 15 | 16 | env_runner: 17 | _target_: null # !! implementing customized class for real robot deployment 18 | task_name: ${task.task_name} 19 | 20 | has_lang_emb: False 21 | has_stage_emb: False 22 | enable_stage: False 23 | pose_estimation_wrapper: null 24 | 25 | dataset: 26 | _target_: diffusion_policy_3d.dataset.rlbench_dataset.RLBenchDataset 27 | root_dir: /tmp/pour_water/ 28 | 29 | horizon: ${horizon} 30 | pad_before: ${eval:'${n_obs_steps}'} 31 | pad_after: ${eval:'${n_action_steps}'} 32 | seed: 42 33 | val_ratio: 0.1 34 | max_train_episodes: 100 35 | has_lang_emb: False 36 | has_stage_emb: False 37 | random_aug: True 38 | symmetric_axis: null 39 | symmetric_theta_start: null 40 | symmetric_theta_end: null 41 | symmetric_theta_step: null 42 | -------------------------------------------------------------------------------- /config/task/rlbench/close_jar.yaml: -------------------------------------------------------------------------------- 1 | name: rlbench 2 | task_name: close_jar 3 | 4 | shape_meta: &shape_meta 5 | # acceptable types: rgb, low_dim 6 | obs: 7 | point_cloud: 8 | shape: [1024, 3] 9 | type: point_cloud 10 | agent_pos: 11 | shape: [7] 12 | type: low_dimx 13 | action: 14 | shape: [8] 15 | 16 | env_runner: 17 | _target_: env_rlbench.runner.rlbench_runner.RLBenchRunner 18 | root_dir: "/tmp/rlbench_zarr/test/${task.task_name}/all_variations/" 19 | task_name: ${task.task_name} 20 | 21 | has_lang_emb: False 22 | has_stage_emb: False 23 | enable_stage: False 24 | use_fp: True 25 | fp_cam_name: cam_wrist 26 | pose_estimation_wrapper: null 27 | 28 | dataset: 29 | _target_: diffusion_policy_3d.dataset.rlbench_dataset.RLBenchDataset 30 | root_dir: "/tmp/rlbench_zarr/train/${task.task_name}/all_variations/" 31 | 32 | horizon: 4 33 | pad_before: 1 34 | pad_after: 4 35 | seed: 42 36 | val_ratio: 0.1 37 | max_train_episodes: 100 38 | 
has_lang_emb: False 39 | has_stage_emb: False 40 | random_aug: True 41 | symmetric_axis: null 42 | symmetric_theta_start: null 43 | symmetric_theta_end: null 44 | symmetric_theta_step: null 45 | -------------------------------------------------------------------------------- /config/task/rlbench/insert_onto_square_peg.yaml: -------------------------------------------------------------------------------- 1 | name: rlbench 2 | task_name: insert_onto_square_peg 3 | 4 | shape_meta: &shape_meta 5 | # acceptable types: rgb, low_dim 6 | obs: 7 | point_cloud: 8 | shape: [1024, 3] 9 | type: point_cloud 10 | agent_pos: 11 | shape: [7] 12 | type: low_dimx 13 | action: 14 | shape: [8] 15 | 16 | env_runner: 17 | _target_: env_rlbench.runner.rlbench_runner.RLBenchRunner 18 | root_dir: "/tmp/rlbench_zarr/test/${task.task_name}/all_variations/" 19 | task_name: ${task.task_name} 20 | 21 | has_lang_emb: False 22 | has_stage_emb: False 23 | enable_stage: False 24 | use_fp: True 25 | fp_cam_name: cam_wrist 26 | pose_estimation_wrapper: null 27 | 28 | dataset: 29 | _target_: diffusion_policy_3d.dataset.rlbench_dataset.RLBenchDataset 30 | root_dir: "/tmp/rlbench_zarr/train/${task.task_name}/all_variations/" 31 | 32 | horizon: 4 33 | pad_before: 1 34 | pad_after: 4 35 | seed: 42 36 | val_ratio: 0.1 37 | max_train_episodes: 100 38 | has_lang_emb: False 39 | has_stage_emb: False 40 | random_aug: True 41 | symmetric_axis: [[0., 0., 1.], [1., 0., 0.], [0., 1., 0.]] 42 | symmetric_theta_start: [-180., -180, -180] 43 | symmetric_theta_end: [180., 180., 180.] 44 | symmetric_theta_step: [90, 180, 180] 45 | -------------------------------------------------------------------------------- /config/task/rlbench/light_bulb_in.yaml: -------------------------------------------------------------------------------- 1 | name: rlbench 2 | task_name: light_bulb_in 3 | 4 | shape_meta: &shape_meta 5 | # acceptable types: rgb, low_dim 6 | obs: 7 | point_cloud: 8 | shape: [1024, 3] 9 | type: point_cloud 10 | agent_pos: 11 | shape: [7] 12 | type: low_dimx 13 | action: 14 | shape: [8] 15 | 16 | env_runner: 17 | _target_: env_rlbench.runner.rlbench_runner.RLBenchRunner 18 | root_dir: "/tmp/rlbench_zarr/test/${task.task_name}/all_variations/" 19 | task_name: ${task.task_name} 20 | 21 | has_lang_emb: False 22 | has_stage_emb: False 23 | enable_stage: False 24 | use_fp: True 25 | fp_cam_name: custom_cam_front_light_bulb_in 26 | pose_estimation_wrapper: null 27 | 28 | dataset: 29 | _target_: diffusion_policy_3d.dataset.rlbench_dataset.RLBenchDataset 30 | root_dir: "/tmp/rlbench_zarr/train/${task.task_name}/all_variations/" 31 | 32 | horizon: 4 33 | pad_before: 1 34 | pad_after: 4 35 | seed: 42 36 | val_ratio: 0.1 37 | max_train_episodes: 100 38 | has_lang_emb: False 39 | has_stage_emb: False 40 | random_aug: True 41 | symmetric_axis: null 42 | symmetric_theta_start: null 43 | symmetric_theta_end: null 44 | symmetric_theta_step: null 45 | -------------------------------------------------------------------------------- /config/task/rlbench/meat_off_grill.yaml: -------------------------------------------------------------------------------- 1 | name: rlbench 2 | task_name: meat_off_grill 3 | 4 | shape_meta: &shape_meta 5 | # acceptable types: rgb, low_dim 6 | obs: 7 | point_cloud: 8 | shape: [1024, 3] 9 | type: point_cloud 10 | agent_pos: 11 | shape: [7] 12 | type: low_dimx 13 | action: 14 | shape: [8] 15 | 16 | env_runner: 17 | _target_: env_rlbench.runner.rlbench_runner.RLBenchRunner 18 | root_dir: 
"/tmp/rlbench_zarr/test/${task.task_name}/all_variations/" 19 | task_name: ${task.task_name} 20 | 21 | has_lang_emb: True 22 | has_stage_emb: False 23 | enable_stage: False 24 | use_fp: True 25 | fp_cam_name: cam_wrist 26 | pose_estimation_wrapper: null 27 | 28 | dataset: 29 | _target_: diffusion_policy_3d.dataset.rlbench_dataset.RLBenchDataset 30 | root_dir: "/tmp/rlbench_zarr/train/${task.task_name}/all_variations/" 31 | 32 | horizon: 4 33 | pad_before: 1 34 | pad_after: 4 35 | seed: 42 36 | val_ratio: 0.1 37 | max_train_episodes: 100 38 | has_lang_emb: True 39 | has_stage_emb: False 40 | random_aug: True 41 | symmetric_axis: null 42 | symmetric_theta_start: null 43 | symmetric_theta_end: null 44 | symmetric_theta_step: null -------------------------------------------------------------------------------- /config/task/rlbench/place_cups.yaml: -------------------------------------------------------------------------------- 1 | name: rlbench 2 | task_name: place_cups 3 | 4 | shape_meta: &shape_meta 5 | # acceptable types: rgb, low_dim 6 | obs: 7 | point_cloud: 8 | shape: [1024, 3] 9 | type: point_cloud 10 | agent_pos: 11 | shape: [7] 12 | type: low_dimx 13 | action: 14 | shape: [8] 15 | 16 | env_runner: 17 | _target_: env_rlbench.runner.rlbench_runner.RLBenchRunner 18 | root_dir: "/tmp/rlbench_zarr/test/${task.task_name}/all_variations/" 19 | task_name: ${task.task_name} 20 | 21 | has_lang_emb: True 22 | has_stage_emb: True 23 | enable_stage: True 24 | use_fp: True 25 | fp_cam_name: cam_wrist 26 | pose_estimation_wrapper: null 27 | 28 | dataset: 29 | _target_: diffusion_policy_3d.dataset.rlbench_dataset.RLBenchDataset 30 | root_dir: "/tmp/rlbench_zarr/train/${task.task_name}/all_variations/" 31 | 32 | horizon: 4 33 | pad_before: 1 34 | pad_after: 4 35 | seed: 42 36 | val_ratio: 0.1 37 | max_train_episodes: 100 38 | has_lang_emb: True 39 | has_stage_emb: True 40 | random_aug: True 41 | symmetric_axis: null 42 | symmetric_theta_start: null 43 | symmetric_theta_end: null 44 | symmetric_theta_step: null 45 | -------------------------------------------------------------------------------- /config/task/rlbench/place_shape_in_shape_sorter.yaml: -------------------------------------------------------------------------------- 1 | name: rlbench 2 | task_name: place_shape_in_shape_sorter 3 | 4 | shape_meta: &shape_meta 5 | # acceptable types: rgb, low_dim 6 | obs: 7 | point_cloud: 8 | shape: [1024, 3] 9 | type: point_cloud 10 | agent_pos: 11 | shape: [7] 12 | type: low_dimx 13 | action: 14 | shape: [8] 15 | 16 | env_runner: 17 | _target_: env_rlbench.runner.rlbench_runner.RLBenchRunner 18 | root_dir: "/tmp/rlbench_zarr/test/${task.task_name}/all_variations/" 19 | task_name: ${task.task_name} 20 | 21 | has_lang_emb: True 22 | has_stage_emb: False 23 | enable_stage: False 24 | use_fp: True 25 | fp_cam_name: cam_wrist 26 | pose_estimation_wrapper: null 27 | 28 | dataset: 29 | _target_: diffusion_policy_3d.dataset.rlbench_dataset.RLBenchDataset 30 | root_dir: "/tmp/rlbench_zarr/train/${task.task_name}/all_variations/" 31 | 32 | horizon: 4 33 | pad_before: 1 34 | pad_after: 4 35 | seed: 42 36 | val_ratio: 0.1 37 | max_train_episodes: 100 38 | has_lang_emb: True 39 | has_stage_emb: False 40 | random_aug: True 41 | # symmetric_axis: [0., 0., 1.] 42 | # symmetric_theta_start: -180. 43 | # symmetric_theta_end: 180. 
44 | # symmetric_theta_step: 90 45 | symmetric_axis: null 46 | symmetric_theta_start: null 47 | symmetric_theta_end: null 48 | symmetric_theta_step: null 49 | -------------------------------------------------------------------------------- /config/task/rlbench/place_wine_at_rack_location.yaml: -------------------------------------------------------------------------------- 1 | name: rlbench 2 | task_name: place_wine_at_rack_location 3 | 4 | shape_meta: &shape_meta 5 | # acceptable types: rgb, low_dim 6 | obs: 7 | point_cloud: 8 | shape: [1024, 3] 9 | type: point_cloud 10 | agent_pos: 11 | shape: [7] 12 | type: low_dimx 13 | action: 14 | shape: [8] 15 | 16 | env_runner: 17 | _target_: env_rlbench.runner.rlbench_runner.RLBenchRunner 18 | root_dir: "/tmp/rlbench_zarr/test/${task.task_name}/all_variations/" 19 | task_name: ${task.task_name} 20 | 21 | has_lang_emb: True 22 | has_stage_emb: False 23 | enable_stage: False 24 | use_fp: True 25 | fp_cam_name: cam_front 26 | pose_estimation_wrapper: null 27 | 28 | dataset: 29 | _target_: diffusion_policy_3d.dataset.rlbench_dataset.RLBenchDataset 30 | root_dir: "/tmp/rlbench_zarr/train/${task.task_name}/all_variations/" 31 | 32 | horizon: 4 33 | pad_before: 1 34 | pad_after: 4 35 | seed: 42 36 | val_ratio: 0.1 37 | max_train_episodes: 100 38 | has_lang_emb: True 39 | has_stage_emb: False 40 | random_aug: True 41 | symmetric_axis: [0., 0., 1.] 42 | symmetric_theta_start: -90. 43 | symmetric_theta_end: 90. 44 | symmetric_theta_step: 15. -------------------------------------------------------------------------------- /config/task/rlbench/put_groceries_in_cupboard.yaml: -------------------------------------------------------------------------------- 1 | name: rlbench 2 | task_name: put_groceries_in_cupboard 3 | 4 | shape_meta: &shape_meta 5 | # acceptable types: rgb, low_dim 6 | obs: 7 | point_cloud: 8 | shape: [1024, 3] 9 | type: point_cloud 10 | agent_pos: 11 | shape: [7] 12 | type: low_dimx 13 | action: 14 | shape: [8] 15 | 16 | env_runner: 17 | _target_: env_rlbench.runner.rlbench_runner.RLBenchRunner 18 | root_dir: "/tmp/rlbench_zarr/test/${task.task_name}/all_variations/" 19 | task_name: ${task.task_name} 20 | 21 | has_lang_emb: True 22 | has_stage_emb: False 23 | enable_stage: False 24 | use_fp: True 25 | fp_cam_name: cam_wrist 26 | pose_estimation_wrapper: null 27 | 28 | dataset: 29 | _target_: diffusion_policy_3d.dataset.rlbench_dataset.RLBenchDataset 30 | root_dir: "/tmp/rlbench_zarr/train/${task.task_name}/all_variations/" 31 | 32 | horizon: 4 33 | pad_before: 1 34 | pad_after: 4 35 | seed: 42 36 | val_ratio: 0.1 37 | max_train_episodes: 100 38 | has_lang_emb: True 39 | has_stage_emb: False 40 | random_aug: True 41 | symmetric_axis: null 42 | symmetric_theta_start: null 43 | symmetric_theta_end: null 44 | symmetric_theta_step: null 45 | -------------------------------------------------------------------------------- /config/task/rlbench/put_money_in_safe.yaml: -------------------------------------------------------------------------------- 1 | name: rlbench 2 | task_name: put_money_in_safe 3 | 4 | shape_meta: &shape_meta 5 | # acceptable types: rgb, low_dim 6 | obs: 7 | point_cloud: 8 | shape: [1024, 3] 9 | type: point_cloud 10 | agent_pos: 11 | shape: [7] 12 | type: low_dimx 13 | action: 14 | shape: [8] 15 | 16 | env_runner: 17 | _target_: env_rlbench.runner.rlbench_runner.RLBenchRunner 18 | root_dir: "/tmp/rlbench_zarr/test/${task.task_name}/all_variations/" 19 | task_name: ${task.task_name} 20 | 21 | has_lang_emb: True 22 | has_stage_emb: 
False 23 | enable_stage: False 24 | use_fp: True 25 | fp_cam_name: cam_wrist 26 | pose_estimation_wrapper: null 27 | 28 | dataset: 29 | _target_: diffusion_policy_3d.dataset.rlbench_dataset.RLBenchDataset 30 | root_dir: "/tmp/rlbench_zarr/train/${task.task_name}/all_variations/" 31 | 32 | horizon: 4 33 | pad_before: 1 34 | pad_after: 4 35 | seed: 42 36 | val_ratio: 0.1 37 | max_train_episodes: 100 38 | has_lang_emb: True 39 | has_stage_emb: False 40 | random_aug: True 41 | symmetric_axis: null 42 | symmetric_theta_start: null 43 | symmetric_theta_end: null 44 | symmetric_theta_step: null 45 | -------------------------------------------------------------------------------- /config/task/rlbench/reach_and_drag.yaml: -------------------------------------------------------------------------------- 1 | name: rlbench 2 | task_name: reach_and_drag 3 | 4 | shape_meta: &shape_meta 5 | # acceptable types: rgb, low_dim 6 | obs: 7 | point_cloud: 8 | shape: [1024, 3] 9 | type: point_cloud 10 | agent_pos: 11 | shape: [7] 12 | type: low_dimx 13 | action: 14 | shape: [8] 15 | 16 | env_runner: 17 | _target_: env_rlbench.runner.rlbench_runner.RLBenchRunner 18 | root_dir: "/tmp/rlbench_zarr/test/${task.task_name}/all_variations/" 19 | task_name: ${task.task_name} 20 | 21 | has_lang_emb: False 22 | has_stage_emb: False 23 | enable_stage: False 24 | use_fp: True 25 | fp_cam_name: cam_front 26 | pose_estimation_wrapper: null 27 | 28 | dataset: 29 | _target_: diffusion_policy_3d.dataset.rlbench_dataset.RLBenchDataset 30 | root_dir: "/tmp/rlbench_zarr/train/${task.task_name}/all_variations/" 31 | 32 | horizon: 4 33 | pad_before: 1 34 | pad_after: 4 35 | seed: 42 36 | val_ratio: 0.1 37 | max_train_episodes: 100 38 | has_lang_emb: False 39 | has_stage_emb: False 40 | random_aug: True 41 | symmetric_axis: [[0., 0., 1.], [1., 0., 0.], [0., 1., 0.]] 42 | symmetric_theta_start: [-180., -180, -180] 43 | symmetric_theta_end: [180., 180., 180.] 44 | symmetric_theta_step: [90, 180, 180] -------------------------------------------------------------------------------- /config/task/rlbench/stack_blocks.yaml: -------------------------------------------------------------------------------- 1 | name: rlbench 2 | task_name: stack_blocks 3 | 4 | shape_meta: &shape_meta 5 | # acceptable types: rgb, low_dim 6 | obs: 7 | point_cloud: 8 | shape: [1024, 3] 9 | type: point_cloud 10 | agent_pos: 11 | shape: [7] 12 | type: low_dimx 13 | action: 14 | shape: [8] 15 | 16 | env_runner: 17 | _target_: env_rlbench.runner.rlbench_runner.RLBenchRunner 18 | root_dir: "/tmp/rlbench_zarr/test/${task.task_name}/all_variations/" 19 | task_name: ${task.task_name} 20 | 21 | has_lang_emb: False 22 | has_stage_emb: False 23 | enable_stage: True 24 | use_fp: True 25 | fp_cam_name: cam_wrist 26 | pose_estimation_wrapper: null 27 | 28 | dataset: 29 | _target_: diffusion_policy_3d.dataset.rlbench_dataset.RLBenchDataset 30 | root_dir: "/tmp/rlbench_zarr/train/${task.task_name}/all_variations/" 31 | 32 | horizon: 4 33 | pad_before: 1 34 | pad_after: 4 35 | seed: 42 36 | val_ratio: 0.1 37 | max_train_episodes: 100 38 | has_lang_emb: False 39 | has_stage_emb: False 40 | random_aug: True 41 | symmetric_axis: [[0., 0., 1.], [1., 0., 0.], [0., 1., 0.]] 42 | symmetric_theta_start: [-180., -180, -180] 43 | symmetric_theta_end: [180., 180., 180.] 
44 | symmetric_theta_step: [90, 90, 90] 45 | 46 | -------------------------------------------------------------------------------- /config/task/rlbench/stack_cups.yaml: -------------------------------------------------------------------------------- 1 | name: rlbench 2 | task_name: stack_cups 3 | 4 | shape_meta: &shape_meta 5 | # acceptable types: rgb, low_dim 6 | obs: 7 | point_cloud: 8 | shape: [1024, 3] 9 | type: point_cloud 10 | agent_pos: 11 | shape: [7] 12 | type: low_dimx 13 | action: 14 | shape: [8] 15 | 16 | env_runner: 17 | _target_: env_rlbench.runner.rlbench_runner.RLBenchRunner 18 | root_dir: "/tmp/rlbench_zarr/test/${task.task_name}/all_variations/" 19 | task_name: ${task.task_name} 20 | 21 | has_lang_emb: False 22 | has_stage_emb: False 23 | enable_stage: True 24 | use_fp: True 25 | fp_cam_name: cam_wrist 26 | pose_estimation_wrapper: null 27 | 28 | dataset: 29 | _target_: diffusion_policy_3d.dataset.rlbench_dataset.RLBenchDataset 30 | root_dir: "/tmp/rlbench_zarr/train/${task.task_name}/all_variations/" 31 | 32 | horizon: 4 33 | pad_before: 1 34 | pad_after: 4 35 | seed: 42 36 | val_ratio: 0.1 37 | max_train_episodes: 100 38 | has_lang_emb: False 39 | has_stage_emb: False 40 | random_aug: True 41 | symmetric_axis: [0., 0., 1.] 42 | symmetric_theta_start: -180. 43 | symmetric_theta_end: 180. 44 | symmetric_theta_step: 15. 45 | -------------------------------------------------------------------------------- /config/task/rlbench/turn_tap.yaml: -------------------------------------------------------------------------------- 1 | name: rlbench 2 | task_name: turn_tap 3 | 4 | shape_meta: &shape_meta 5 | # acceptable types: rgb, low_dim 6 | obs: 7 | point_cloud: 8 | shape: [1024, 3] 9 | type: point_cloud 10 | agent_pos: 11 | shape: [7] 12 | type: low_dimx 13 | action: 14 | shape: [8] 15 | 16 | env_runner: 17 | _target_: env_rlbench.runner.rlbench_runner.RLBenchRunner 18 | root_dir: "/tmp/rlbench_zarr/test/${task.task_name}/all_variations/" 19 | task_name: ${task.task_name} 20 | 21 | has_lang_emb: True 22 | has_stage_emb: False 23 | enable_stage: False 24 | use_fp: True 25 | fp_cam_name: cam_front 26 | pose_estimation_wrapper: null 27 | 28 | dataset: 29 | _target_: diffusion_policy_3d.dataset.rlbench_dataset.RLBenchDataset 30 | root_dir: "/tmp/rlbench_zarr/train/${task.task_name}/all_variations/" 31 | 32 | horizon: 4 33 | pad_before: 1 34 | pad_after: 4 35 | seed: 42 36 | val_ratio: 0.1 37 | max_train_episodes: 100 38 | has_lang_emb: True 39 | has_stage_emb: False 40 | random_aug: True 41 | symmetric_axis: null 42 | symmetric_theta_start: null 43 | symmetric_theta_end: null 44 | symmetric_theta_step: null 45 | -------------------------------------------------------------------------------- /config/task/rlbench_multi.yaml: -------------------------------------------------------------------------------- 1 | name: rlbench 2 | task_name: multitask 3 | 4 | shape_meta: &shape_meta 5 | # acceptable types: rgb, low_dim 6 | obs: 7 | point_cloud: 8 | shape: [1024, 3] 9 | type: point_cloud 10 | agent_pos: 11 | shape: [7] 12 | type: low_dimx 13 | action: 14 | shape: [8] 15 | 16 | env_runner: null 17 | 18 | dataset: 19 | _target_: diffusion_policy_3d.dataset.rlbench_dataset_list.RLBenchDatasetList 20 | root_dir: "/tmp/rlbench_zarr/train" # overrides all single-task configs -------------------------------------------------------------------------------- /diffusion_policy_3d/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright 
(c) 2024 Yanjie Ze 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /diffusion_policy_3d/common/checkpoint_util.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict 2 | import os 3 | 4 | class TopKCheckpointManager: 5 | def __init__(self, 6 | save_dir, 7 | monitor_key: str, 8 | mode='min', 9 | k=1, 10 | format_str='epoch={epoch:03d}-train_loss={train_loss:.3f}.ckpt' 11 | ): 12 | assert mode in ['max', 'min'] 13 | assert k >= 0 14 | 15 | self.save_dir = save_dir 16 | self.monitor_key = monitor_key 17 | self.mode = mode 18 | self.k = k 19 | self.format_str = format_str 20 | self.path_value_map = dict() 21 | 22 | def get_ckpt_path(self, data: Dict[str, float]) -> Optional[str]: 23 | if self.k == 0: 24 | return None 25 | 26 | value = data[self.monitor_key] 27 | ckpt_path = os.path.join( 28 | self.save_dir, self.format_str.format(**data)) 29 | 30 | if len(self.path_value_map) < self.k: 31 | # under-capacity 32 | self.path_value_map[ckpt_path] = value 33 | return ckpt_path 34 | 35 | # at capacity 36 | sorted_map = sorted(self.path_value_map.items(), key=lambda x: x[1]) 37 | min_path, min_value = sorted_map[0] 38 | max_path, max_value = sorted_map[-1] 39 | 40 | delete_path = None 41 | if self.mode == 'max': 42 | if value > min_value: 43 | delete_path = min_path 44 | else: 45 | if value < max_value: 46 | delete_path = max_path 47 | 48 | if delete_path is None: 49 | return None 50 | else: 51 | del self.path_value_map[delete_path] 52 | self.path_value_map[ckpt_path] = value 53 | 54 | if not os.path.exists(self.save_dir): 55 | os.mkdir(self.save_dir) 56 | 57 | if os.path.exists(delete_path): 58 | os.remove(delete_path) 59 | return ckpt_path 60 | -------------------------------------------------------------------------------- /diffusion_policy_3d/common/logger_util.py: -------------------------------------------------------------------------------- 1 | import heapq 2 | 3 | class LargestKRecorder: 4 | def __init__(self, K): 5 | """ 6 | Initialize the EfficientScalarRecorder. 7 | 8 | Parameters: 9 | - K: Number of largest scalars to consider when computing the average. 10 | """ 11 | self.scalars = [] 12 | self.K = K 13 | 14 | def record(self, scalar): 15 | """ 16 | Record a scalar value. 17 | 18 | Parameters: 19 | - scalar: The scalar value to be recorded. 
20 | """ 21 | if len(self.scalars) < self.K: 22 | heapq.heappush(self.scalars, scalar) 23 | else: 24 | # Compare the new scalar with the smallest value in the heap 25 | if scalar > self.scalars[0]: 26 | heapq.heappushpop(self.scalars, scalar) 27 | 28 | def average_of_largest_K(self): 29 | """ 30 | Compute the average of the largest K scalar values recorded. 31 | 32 | Returns: 33 | - avg: Average of the largest K scalars. 34 | """ 35 | if len(self.scalars) == 0: 36 | raise ValueError("No scalars have been recorded yet.") 37 | 38 | return sum(self.scalars) / len(self.scalars) 39 | 40 | -------------------------------------------------------------------------------- /diffusion_policy_3d/common/model_util.py: -------------------------------------------------------------------------------- 1 | from termcolor import cprint 2 | 3 | def print_params(model): 4 | """ 5 | Print the number of parameters in each part of the model. 6 | """ 7 | params_dict = {} 8 | 9 | all_num_param = sum(p.numel() for p in model.parameters()) 10 | 11 | for name, param in model.named_parameters(): 12 | part_name = name.split('.')[0] 13 | if part_name not in params_dict: 14 | params_dict[part_name] = 0 15 | params_dict[part_name] += param.numel() 16 | 17 | cprint(f'----------------------------------', 'cyan') 18 | cprint(f'Class name: {model.__class__.__name__}', 'cyan') 19 | cprint(f' Number of parameters: {all_num_param / 1e6:.4f}M', 'cyan') 20 | for part_name, num_params in params_dict.items(): 21 | cprint(f' {part_name}: {num_params / 1e6:.4f}M ({num_params / all_num_param:.2%})', 'cyan') 22 | cprint(f'----------------------------------', 'cyan') -------------------------------------------------------------------------------- /diffusion_policy_3d/common/pytorch_util.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Callable, List 2 | import collections 3 | import torch 4 | import torch.nn as nn 5 | 6 | def dict_apply( 7 | x: Dict[str, torch.Tensor], 8 | func: Callable[[torch.Tensor], torch.Tensor] 9 | ) -> Dict[str, torch.Tensor]: 10 | result = dict() 11 | for key, value in x.items(): 12 | if isinstance(value, dict): 13 | result[key] = dict_apply(value, func) 14 | else: 15 | result[key] = func(value) 16 | return result 17 | 18 | def pad_remaining_dims(x, target): 19 | assert x.shape == target.shape[:len(x.shape)] 20 | return x.reshape(x.shape + (1,)*(len(target.shape) - len(x.shape))) 21 | 22 | def dict_apply_split( 23 | x: Dict[str, torch.Tensor], 24 | split_func: Callable[[torch.Tensor], Dict[str, torch.Tensor]] 25 | ) -> Dict[str, torch.Tensor]: 26 | results = collections.defaultdict(dict) 27 | for key, value in x.items(): 28 | result = split_func(value) 29 | for k, v in result.items(): 30 | results[k][key] = v 31 | return results 32 | 33 | def dict_apply_reduce( 34 | x: List[Dict[str, torch.Tensor]], 35 | reduce_func: Callable[[List[torch.Tensor]], torch.Tensor] 36 | ) -> Dict[str, torch.Tensor]: 37 | result = dict() 38 | for key in x[0].keys(): 39 | result[key] = reduce_func([x_[key] for x_ in x]) 40 | return result 41 | 42 | 43 | def optimizer_to(optimizer, device): 44 | for state in optimizer.state.values(): 45 | for k, v in state.items(): 46 | if isinstance(v, torch.Tensor): 47 | state[k] = v.to(device=device) 48 | return optimizer 49 | -------------------------------------------------------------------------------- /diffusion_policy_3d/common/sampler.py: -------------------------------------------------------------------------------- 1 | from 
typing import Optional 2 | import numpy as np 3 | import numba 4 | from diffusion_policy_3d.common.replay_buffer import ReplayBuffer 5 | 6 | 7 | @numba.jit(nopython=True) 8 | def create_indices( 9 | episode_ends:np.ndarray, sequence_length:int, 10 | episode_mask: np.ndarray, 11 | pad_before: int=0, pad_after: int=0, 12 | debug:bool=True) -> np.ndarray: 13 | episode_mask.shape == episode_ends.shape 14 | pad_before = min(max(pad_before, 0), sequence_length-1) 15 | pad_after = min(max(pad_after, 0), sequence_length-1) 16 | 17 | indices = list() 18 | for i in range(len(episode_ends)): 19 | if not episode_mask[i]: 20 | # skip episode 21 | continue 22 | start_idx = 0 23 | if i > 0: 24 | start_idx = episode_ends[i-1] 25 | end_idx = episode_ends[i] 26 | episode_length = end_idx - start_idx 27 | 28 | min_start = -pad_before 29 | max_start = episode_length - sequence_length + pad_after 30 | 31 | # range stops one idx before end 32 | for idx in range(min_start, max_start+1): 33 | buffer_start_idx = max(idx, 0) + start_idx 34 | buffer_end_idx = min(idx+sequence_length, episode_length) + start_idx 35 | start_offset = buffer_start_idx - (idx+start_idx) 36 | end_offset = (idx+sequence_length+start_idx) - buffer_end_idx 37 | sample_start_idx = 0 + start_offset 38 | sample_end_idx = sequence_length - end_offset 39 | if debug: 40 | assert(start_offset >= 0) 41 | assert(end_offset >= 0) 42 | assert (sample_end_idx - sample_start_idx) == (buffer_end_idx - buffer_start_idx) 43 | indices.append([ 44 | buffer_start_idx, buffer_end_idx, 45 | sample_start_idx, sample_end_idx]) 46 | indices = np.array(indices) 47 | return indices 48 | 49 | 50 | def get_val_mask(n_episodes, val_ratio, seed=0): 51 | val_mask = np.zeros(n_episodes, dtype=bool) 52 | if val_ratio <= 0: 53 | return val_mask 54 | 55 | # have at least 1 episode for validation, and at least 1 episode for train 56 | n_val = min(max(1, round(n_episodes * val_ratio)), n_episodes-1) 57 | rng = np.random.default_rng(seed=seed) 58 | val_idxs = rng.choice(n_episodes, size=n_val, replace=False) 59 | val_mask[val_idxs] = True 60 | return val_mask 61 | 62 | 63 | def downsample_mask(mask, max_n, seed=0): 64 | # subsample training data 65 | train_mask = mask 66 | if (max_n is not None) and (np.sum(train_mask) > max_n): 67 | n_train = int(max_n) 68 | curr_train_idxs = np.nonzero(train_mask)[0] 69 | rng = np.random.default_rng(seed=seed) 70 | train_idxs_idx = rng.choice(len(curr_train_idxs), size=n_train, replace=False) 71 | train_idxs = curr_train_idxs[train_idxs_idx] 72 | train_mask = np.zeros_like(train_mask) 73 | train_mask[train_idxs] = True 74 | assert np.sum(train_mask) == n_train 75 | return train_mask 76 | 77 | class SequenceSampler: 78 | def __init__(self, 79 | replay_buffer: ReplayBuffer, 80 | sequence_length:int, 81 | pad_before:int=0, 82 | pad_after:int=0, 83 | keys=None, 84 | key_first_k=dict(), 85 | episode_mask: Optional[np.ndarray]=None, 86 | n_aug:int=-1, 87 | ): 88 | """ 89 | key_first_k: dict str: int 90 | Only take first k data from these keys (to improve perf) 91 | """ 92 | 93 | super().__init__() 94 | assert(sequence_length >= 1) 95 | if keys is None: 96 | keys = list(replay_buffer.keys()) 97 | 98 | self.n_aug = n_aug 99 | if n_aug < 0: 100 | episode_ends = replay_buffer.episode_ends[:] 101 | if episode_mask is None: 102 | episode_mask = np.ones(episode_ends.shape, dtype=bool) 103 | 104 | if np.any(episode_mask): 105 | indices = create_indices(episode_ends, 106 | sequence_length=sequence_length, 107 | pad_before=pad_before, 108 | 
pad_after=pad_after, 109 | episode_mask=episode_mask 110 | ) 111 | else: 112 | indices = np.zeros((0,4), dtype=np.int64) 113 | else: 114 | indices = [] 115 | for aug_idx in range(n_aug): 116 | episode_ends = replay_buffer.episode_ends[:] 117 | if episode_mask is None: 118 | episode_mask = np.ones(episode_ends.shape, dtype=bool) 119 | 120 | if np.any(episode_mask): 121 | cur_indices = create_indices(episode_ends, 122 | sequence_length=sequence_length, 123 | pad_before=pad_before, 124 | pad_after=pad_after, 125 | episode_mask=episode_mask 126 | ) 127 | else: 128 | cur_indices = np.zeros((0,4), dtype=np.int64) 129 | 130 | aug_idx_list = np.full((len(cur_indices), 1), fill_value=aug_idx) 131 | cur_indices = np.concatenate([aug_idx_list, cur_indices], axis=1) 132 | indices.append(cur_indices) 133 | indices = np.concatenate(indices, axis=0) 134 | 135 | # (buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx) 136 | self.indices = indices 137 | self.keys = list(keys) # prevent OmegaConf list performance problem 138 | self.sequence_length = sequence_length 139 | self.replay_buffer = replay_buffer 140 | self.key_first_k = key_first_k 141 | 142 | def __len__(self): 143 | return len(self.indices) 144 | 145 | def sample_sequence(self, idx): 146 | if self.n_aug < 0: 147 | aug_idx = None 148 | buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx \ 149 | = self.indices[idx] 150 | else: 151 | aug_idx, buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx \ 152 | = self.indices[idx] 153 | # print("aug_idx", aug_idx) 154 | # print("buffer_start_idx", buffer_start_idx, "buffer_end_idx", buffer_end_idx) 155 | result = dict() 156 | for key in self.keys: 157 | input_arr = self.replay_buffer[key] 158 | # performance optimization, avoid small allocation if possible 159 | if key not in self.key_first_k: 160 | sample = input_arr[buffer_start_idx:buffer_end_idx] 161 | else: 162 | # performance optimization, only load used obs steps 163 | n_data = buffer_end_idx - buffer_start_idx 164 | k_data = min(self.key_first_k[key], n_data) 165 | # fill value with Nan to catch bugs 166 | # the non-loaded region should never be used 167 | sample = np.full((n_data,) + input_arr.shape[1:], 168 | fill_value=np.nan, dtype=input_arr.dtype) 169 | try: 170 | sample[:k_data] = input_arr[buffer_start_idx:buffer_start_idx+k_data] 171 | except Exception as e: 172 | import pdb; pdb.set_trace() 173 | data = sample 174 | if (sample_start_idx > 0) or (sample_end_idx < self.sequence_length): 175 | data = np.zeros( 176 | shape=(self.sequence_length,) + input_arr.shape[1:], 177 | dtype=input_arr.dtype) 178 | if sample_start_idx > 0: 179 | data[:sample_start_idx] = sample[0] 180 | if sample_end_idx < self.sequence_length: 181 | data[sample_end_idx:] = sample[-1] 182 | data[sample_start_idx:sample_end_idx] = sample 183 | result[key] = data 184 | return result, aug_idx 185 | -------------------------------------------------------------------------------- /diffusion_policy_3d/common/sampler_multitask.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | import numpy as np 3 | import numba 4 | from diffusion_policy_3d.common.replay_buffer import ReplayBuffer 5 | 6 | 7 | @numba.jit(nopython=True) 8 | def create_indices( 9 | episode_ends:np.ndarray, sequence_length:int, 10 | episode_mask: np.ndarray, 11 | pad_before: int=0, pad_after: int=0, 12 | debug:bool=True) -> np.ndarray: 13 | episode_mask.shape == episode_ends.shape 14 | pad_before = 
min(max(pad_before, 0), sequence_length-1) 15 | pad_after = min(max(pad_after, 0), sequence_length-1) 16 | 17 | indices = list() 18 | for i in range(len(episode_ends)): 19 | if not episode_mask[i]: 20 | # skip episode 21 | continue 22 | start_idx = 0 23 | if i > 0: 24 | start_idx = episode_ends[i-1] 25 | end_idx = episode_ends[i] 26 | episode_length = end_idx - start_idx 27 | 28 | min_start = -pad_before 29 | max_start = episode_length - sequence_length + pad_after 30 | 31 | # range stops one idx before end 32 | for idx in range(min_start, max_start+1): 33 | buffer_start_idx = max(idx, 0) + start_idx 34 | buffer_end_idx = min(idx+sequence_length, episode_length) + start_idx 35 | start_offset = buffer_start_idx - (idx+start_idx) 36 | end_offset = (idx+sequence_length+start_idx) - buffer_end_idx 37 | sample_start_idx = 0 + start_offset 38 | sample_end_idx = sequence_length - end_offset 39 | if debug: 40 | assert(start_offset >= 0) 41 | assert(end_offset >= 0) 42 | assert (sample_end_idx - sample_start_idx) == (buffer_end_idx - buffer_start_idx) 43 | indices.append([ 44 | buffer_start_idx, buffer_end_idx, 45 | sample_start_idx, sample_end_idx]) 46 | indices = np.array(indices) 47 | return indices 48 | 49 | 50 | class MultiTaskSequenceSampler: 51 | def __init__(self, 52 | replay_buffer_list: List[ReplayBuffer], 53 | sequence_length:int, 54 | pad_before:int=0, 55 | pad_after:int=0, 56 | keys=None, 57 | key_first_k=dict(), 58 | # episode_mask: Optional[np.ndarray]=None, 59 | ): 60 | """ 61 | key_first_k: dict str: int 62 | Only take first k data from these keys (to improve perf) 63 | """ 64 | 65 | super().__init__() 66 | assert(sequence_length >= 1) 67 | if keys is None: 68 | keys = list(replay_buffer_list[0].keys()) 69 | 70 | indices = [] 71 | for task_idx, replay_buffer in enumerate(replay_buffer_list): 72 | episode_ends = replay_buffer.episode_ends[:] 73 | episode_mask = np.ones(episode_ends.shape, dtype=bool) 74 | 75 | if np.any(episode_mask): 76 | cur_indices = create_indices(episode_ends, 77 | sequence_length=sequence_length, 78 | pad_before=pad_before, 79 | pad_after=pad_after, 80 | episode_mask=episode_mask 81 | ) 82 | else: 83 | cur_indices = np.zeros((0,4), dtype=np.int64) 84 | 85 | task_idx_list = np.full((len(cur_indices), 1), fill_value=task_idx) 86 | cur_indices = np.concatenate([task_idx_list, cur_indices], axis=1) 87 | indices.append(cur_indices) 88 | indices = np.concatenate(indices, axis=0) 89 | 90 | # (task_idx, buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx) 91 | self.indices = indices 92 | self.keys = list(keys) # prevent OmegaConf list performance problem 93 | self.sequence_length = sequence_length 94 | self.replay_buffer_list = replay_buffer_list 95 | self.key_first_k = key_first_k 96 | 97 | def __len__(self): 98 | return len(self.indices) 99 | 100 | def sample_sequence(self, idx): 101 | task_idx, buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx \ 102 | = self.indices[idx] 103 | # print("buffer_start_idx", buffer_start_idx, "buffer_end_idx", buffer_end_idx) 104 | result = dict() 105 | 106 | cur_replay_buffer = self.replay_buffer_list[task_idx] 107 | for key in self.keys: 108 | input_arr = cur_replay_buffer[key] 109 | # performance optimization, avoid small allocation if possible 110 | if key not in self.key_first_k: 111 | sample = input_arr[buffer_start_idx:buffer_end_idx] 112 | else: 113 | # performance optimization, only load used obs steps 114 | n_data = buffer_end_idx - buffer_start_idx 115 | k_data = min(self.key_first_k[key], 
n_data) 116 | # fill value with Nan to catch bugs 117 | # the non-loaded region should never be used 118 | sample = np.full((n_data,) + input_arr.shape[1:], 119 | fill_value=np.nan, dtype=input_arr.dtype) 120 | try: 121 | sample[:k_data] = input_arr[buffer_start_idx:buffer_start_idx+k_data] 122 | except Exception as e: 123 | import pdb; pdb.set_trace() 124 | data = sample 125 | if (sample_start_idx > 0) or (sample_end_idx < self.sequence_length): 126 | data = np.zeros( 127 | shape=(self.sequence_length,) + input_arr.shape[1:], 128 | dtype=input_arr.dtype) 129 | if sample_start_idx > 0: 130 | data[:sample_start_idx] = sample[0] 131 | if sample_end_idx < self.sequence_length: 132 | data[sample_end_idx:] = sample[-1] 133 | data[sample_start_idx:sample_end_idx] = sample 134 | result[key] = data 135 | return result, task_idx 136 | -------------------------------------------------------------------------------- /diffusion_policy_3d/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/object_centric_diffusion/5ac86be2370b582c03e85f22d24bdcb825dfa13b/diffusion_policy_3d/dataset/__init__.py -------------------------------------------------------------------------------- /diffusion_policy_3d/dataset/base_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | import torch.nn 5 | from diffusion_policy_3d.model.common.normalizer import LinearNormalizer 6 | 7 | 8 | class BaseDataset(torch.utils.data.Dataset): 9 | def get_validation_dataset(self) -> 'BaseDataset': 10 | # return an empty dataset by default 11 | return BaseDataset() 12 | 13 | def get_normalizer(self, **kwargs) -> LinearNormalizer: 14 | raise NotImplementedError() 15 | 16 | def get_all_actions(self) -> torch.Tensor: 17 | raise NotImplementedError() 18 | 19 | def __len__(self) -> int: 20 | return 0 21 | 22 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 23 | """ 24 | output: 25 | obs: 26 | key: T, * 27 | action: T, Da 28 | """ 29 | raise NotImplementedError() 30 | -------------------------------------------------------------------------------- /diffusion_policy_3d/dataset/rlbench_dataset.py: -------------------------------------------------------------------------------- 1 | from diffusion_policy_3d.dataset.rlbench_base_dataset import RLBenchBaseDataset 2 | 3 | 4 | class RLBenchDataset(RLBenchBaseDataset): 5 | def __init__(self, *args, **kwargs): 6 | super().__init__(*args, **kwargs) -------------------------------------------------------------------------------- /diffusion_policy_3d/dataset/rlbench_dataset_list.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import numpy as np 3 | import torch 4 | import copy 5 | from typing import Dict 6 | import hydra 7 | 8 | from diffusion_policy_3d.model.common.normalizer import LinearNormalizer 9 | from diffusion_policy_3d.dataset.base_dataset import BaseDataset 10 | 11 | 12 | class RLBenchDatasetList(BaseDataset): 13 | def __init__(self, root_dir=None): 14 | self.task_list = [ 15 | 'meat_off_grill', 16 | 'put_money_in_safe', 17 | 'place_wine_at_rack_location', 18 | 'reach_and_drag', 19 | 'stack_blocks', 20 | 'close_jar', 21 | 'light_bulb_in', 22 | 'put_groceries_in_cupboard', 23 | 'place_shape_in_shape_sorter', 24 | 'insert_onto_square_peg', 25 | 'stack_cups', 26 | 'place_cups', 27 | 'turn_tap', 28 | ] 29 | self.task_id_to_name = {} 30 | 
self.task_name_to_dataset = {} 31 | 32 | self.dataset_list = [] 33 | self.global_idx_to_task_id_local_idx = {} 34 | 35 | start = 0 36 | for task_idx, task_name in enumerate(self.task_list): 37 | 38 | # read yaml 39 | with open(f"config/task/rlbench/{task_name}.yaml", 'r') as stream: 40 | dataset_cfg = yaml.safe_load(stream)["dataset"] 41 | 42 | if root_dir is not None: 43 | dataset_cfg["root_dir"] = f"{root_dir}/{task_name}/all_variations" # overrides all single-task configs 44 | else: 45 | dataset_cfg["root_dir"] = dataset_cfg["root_dir"].replace("${task.task_name}", task_name) 46 | 47 | dataset = hydra.utils.instantiate(dataset_cfg) 48 | self.dataset_list.append(dataset) 49 | 50 | self.task_id_to_name[task_idx] = task_name 51 | self.task_name_to_dataset[task_name] = dataset 52 | 53 | # global index to (task_id, local index) 54 | dataset_length = len(dataset) 55 | for sample_idx in range(dataset_length): 56 | self.global_idx_to_task_id_local_idx[start+sample_idx] = (task_idx, sample_idx) 57 | start += dataset_length 58 | 59 | def get_validation_dataset(self): 60 | val_set_list = copy.copy(self) 61 | return val_set_list 62 | 63 | def get_normalizer(self, mode='limits', **kwargs): 64 | action_all = [] 65 | agent_pos_all = [] 66 | for dataset in self.dataset_list: 67 | print(dataset.replay_buffer['action'][:, :3].shape) 68 | action_all.append(dataset.replay_buffer['action'][:, :3]) 69 | agent_pos_all.append(dataset.replay_buffer['state'][...,:][:, :3]) 70 | 71 | print(len(action_all)) 72 | 73 | action_all = np.concatenate(action_all, axis=0) 74 | agent_pos_all = np.concatenate(agent_pos_all, axis=0) 75 | 76 | data = { 77 | 'action': action_all, 78 | 'agent_pos': agent_pos_all, 79 | # 'point_cloud': self.replay_buffer['point_cloud'], 80 | } 81 | normalizer = LinearNormalizer() 82 | normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs) 83 | return normalizer 84 | 85 | def __len__(self) -> int: 86 | return len(self.global_idx_to_task_id_local_idx) 87 | 88 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 89 | task_idx, sample_idx = self.global_idx_to_task_id_local_idx[idx] 90 | 91 | select_dataset = self.dataset_list[task_idx] 92 | return select_dataset.__getitem__(sample_idx) -------------------------------------------------------------------------------- /diffusion_policy_3d/env_runner/base_runner.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from diffusion_policy_3d.policy.base_policy import BasePolicy 3 | 4 | 5 | class BaseRunner: 6 | def __init__(self, output_dir): 7 | self.output_dir = output_dir 8 | 9 | def run(self, policy: BasePolicy) -> Dict: 10 | raise NotImplementedError() 11 | -------------------------------------------------------------------------------- /diffusion_policy_3d/model/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/object_centric_diffusion/5ac86be2370b582c03e85f22d24bdcb825dfa13b/diffusion_policy_3d/model/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /diffusion_policy_3d/model/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return 
os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' 
-------------------------------------------------------------------------------- /diffusion_policy_3d/model/common/dict_of_tensor_mixin.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | class DictOfTensorMixin(nn.Module):
5 |     def __init__(self, params_dict=None):
6 |         super().__init__()
7 |         if params_dict is None:
8 |             params_dict = nn.ParameterDict()
9 |         self.params_dict = params_dict
10 | 
11 |     @property
12 |     def device(self):
13 |         return next(iter(self.parameters())).device
14 | 
15 |     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
16 |         def dfs_add(dest, keys, value: torch.Tensor):
17 |             if len(keys) == 1:
18 |                 dest[keys[0]] = value
19 |                 return
20 | 
21 |             if keys[0] not in dest:
22 |                 dest[keys[0]] = nn.ParameterDict()
23 |             dfs_add(dest[keys[0]], keys[1:], value)
24 | 
25 |         def load_dict(state_dict, prefix):
26 |             out_dict = nn.ParameterDict()
27 |             for key, value in state_dict.items():
28 |                 value: torch.Tensor
29 |                 if key.startswith(prefix):
30 |                     param_keys = key[len(prefix):].split('.')[1:]
31 |                     # if len(param_keys) == 0:
32 |                     #     import pdb; pdb.set_trace()
33 |                     dfs_add(out_dict, param_keys, value.clone())
34 |             return out_dict
35 | 
36 |         self.params_dict = load_dict(state_dict, prefix + 'params_dict')
37 |         self.params_dict.requires_grad_(False)
38 |         return
39 | 
-------------------------------------------------------------------------------- /diffusion_policy_3d/model/common/geodesic_loss.py: --------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from torch import nn
4 | from torch import Tensor
5 | 
6 | 
7 | def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
8 |     """
9 |     Convert rotations given as quaternions to rotation matrices.
10 | 
11 |     Args:
12 |         quaternions: quaternions with real part first,
13 |             as tensor of shape (..., 4).
14 | 
15 |     Returns:
16 |         Rotation matrices as tensor of shape (..., 3, 3).
17 |     """
18 |     r, i, j, k = torch.unbind(quaternions, -1)
19 |     # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
20 |     two_s = 2.0 / (quaternions * quaternions).sum(-1)
21 | 
22 |     o = torch.stack(
23 |         (
24 |             1 - two_s * (j * j + k * k),
25 |             two_s * (i * j - k * r),
26 |             two_s * (i * k + j * r),
27 |             two_s * (i * j + k * r),
28 |             1 - two_s * (i * i + k * k),
29 |             two_s * (j * k - i * r),
30 |             two_s * (i * k - j * r),
31 |             two_s * (j * k + i * r),
32 |             1 - two_s * (i * i + j * j),
33 |         ),
34 |         -1,
35 |     )
36 |     return o.reshape(quaternions.shape[:-1] + (3, 3))
37 | 
38 | 
39 | class GeodesicLoss(nn.Module):
40 |     r"""Creates a criterion that measures the distance between rotation matrices, which is
41 |     useful for pose estimation problems.
42 |     The distance ranges from 0 to :math:`pi`.
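    For instance, if `input` and `target` differ by a rotation of angle :math:`\theta \in [0, \pi]` about any axis, then :math:`\mathrm{tr}(R_{S} R_{T}^{T}) = 1 + 2\cos\theta` and the per-sample distance is exactly :math:`\theta` (illustrative special case).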
43 | See: http://www.boris-belousov.net/2016/12/01/quat-dist/#using-rotation-matrices and: 44 | "Metrics for 3D Rotations: Comparison and Analysis" (https://link.springer.com/article/10.1007/s10851-009-0161-2). 45 | 46 | Both `input` and `target` consist of rotation matrices, i.e., they have to be Tensors 47 | of size :math:`(minibatch, 3, 3)`. 48 | 49 | The loss can be described as: 50 | 51 | .. math:: 52 | \text{loss}(R_{S}, R_{T}) = \arccos\left(\frac{\text{tr} (R_{S} R_{T}^{T}) - 1}{2}\right) 53 | 54 | Args: 55 | eps (float, optional): term to improve numerical stability (default: 1e-7). See: 56 | https://github.com/pytorch/pytorch/issues/8069. 57 | 58 | reduction (string, optional): Specifies the reduction to apply to the output: 59 | ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will 60 | be applied, ``'mean'``: the weighted mean of the output is taken, 61 | ``'sum'``: the output will be summed. Default: ``'mean'`` 62 | 63 | Shape: 64 | - Input: Shape :math:`(N, 3, 3)`. 65 | - Target: Shape :math:`(N, 3, 3)`. 66 | - Output: If :attr:`reduction` is ``'none'``, then :math:`(N)`. Otherwise, scalar. 67 | """ 68 | 69 | def __init__(self, eps: float = 1e-7, reduction: str = "mean") -> None: 70 | super().__init__() 71 | self.eps = eps 72 | self.reduction = reduction 73 | 74 | def forward(self, input: Tensor, target: Tensor) -> Tensor: 75 | R_diffs = input @ target.permute(0, 2, 1) 76 | # See: https://github.com/pytorch/pytorch/issues/7500#issuecomment-502122839. 77 | traces = R_diffs.diagonal(dim1=-2, dim2=-1).sum(-1) 78 | dists = torch.acos(torch.clamp((traces - 1) / 2, -1 + self.eps, 1 - self.eps)) 79 | if self.reduction == "none": 80 | return dists 81 | elif self.reduction == "mean": 82 | return dists.mean() 83 | elif self.reduction == "sum": 84 | return dists.sum() -------------------------------------------------------------------------------- /diffusion_policy_3d/model/common/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from diffusers.optimization import ( 2 | Union, SchedulerType, Optional, 3 | Optimizer, TYPE_TO_SCHEDULER_FUNCTION 4 | ) 5 | 6 | def get_scheduler( 7 | name: Union[str, SchedulerType], 8 | optimizer: Optimizer, 9 | num_warmup_steps: Optional[int] = None, 10 | num_training_steps: Optional[int] = None, 11 | **kwargs 12 | ): 13 | """ 14 | Added kwargs vs diffuser's original implementation 15 | 16 | Unified API to get any scheduler from its name. 17 | 18 | Args: 19 | name (`str` or `SchedulerType`): 20 | The name of the scheduler to use. 21 | optimizer (`torch.optim.Optimizer`): 22 | The optimizer that will be used during training. 23 | num_warmup_steps (`int`, *optional*): 24 | The number of warmup steps to do. This is not required by all schedulers (hence the argument being 25 | optional), the function will raise an error if it's unset and the scheduler type requires it. 26 | num_training_steps (`int``, *optional*): 27 | The number of training steps to do. This is not required by all schedulers (hence the argument being 28 | optional), the function will raise an error if it's unset and the scheduler type requires it. 
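    Example (illustrative; assumes a name that diffusers registers in TYPE_TO_SCHEDULER_FUNCTION, e.g. "cosine"):
        lr_scheduler = get_scheduler("cosine", optimizer=optimizer, num_warmup_steps=500, num_training_steps=100_000)
        # then call lr_scheduler.step() once per optimizer step during training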
29 | """ 30 | name = SchedulerType(name) 31 | schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] 32 | if name == SchedulerType.CONSTANT: 33 | return schedule_func(optimizer, **kwargs) 34 | 35 | # All other schedulers require `num_warmup_steps` 36 | if num_warmup_steps is None: 37 | raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") 38 | 39 | if name == SchedulerType.CONSTANT_WITH_WARMUP: 40 | return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **kwargs) 41 | 42 | # All other schedulers require `num_training_steps` 43 | if num_training_steps is None: 44 | raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") 45 | 46 | return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **kwargs) 47 | -------------------------------------------------------------------------------- /diffusion_policy_3d/model/common/module_attr_mixin.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class ModuleAttrMixin(nn.Module): 4 | def __init__(self): 5 | super().__init__() 6 | self._dummy_variable = nn.Parameter() 7 | 8 | @property 9 | def device(self): 10 | return next(iter(self.parameters())).device 11 | 12 | @property 13 | def dtype(self): 14 | return next(iter(self.parameters())).dtype 15 | -------------------------------------------------------------------------------- /diffusion_policy_3d/model/common/shape_util.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple, Callable 2 | import torch 3 | import torch.nn as nn 4 | 5 | def get_module_device(m: nn.Module): 6 | device = torch.device('cpu') 7 | try: 8 | param = next(iter(m.parameters())) 9 | device = param.device 10 | except StopIteration: 11 | pass 12 | return device 13 | 14 | @torch.no_grad() 15 | def get_output_shape( 16 | input_shape: Tuple[int], 17 | net: Callable[[torch.Tensor], torch.Tensor] 18 | ): 19 | device = get_module_device(net) 20 | test_input = torch.zeros((1,)+tuple(input_shape), device=device) 21 | test_output = net(test_input) 22 | output_shape = tuple(test_output.shape[1:]) 23 | return output_shape 24 | -------------------------------------------------------------------------------- /diffusion_policy_3d/model/diffusion/conv1d_components.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | # from einops.layers.torch import Rearrange 5 | 6 | 7 | class Downsample1d(nn.Module): 8 | def __init__(self, dim): 9 | super().__init__() 10 | self.conv = nn.Conv1d(dim, dim, 3, 2, 1) 11 | 12 | def forward(self, x): 13 | return self.conv(x) 14 | 15 | class Upsample1d(nn.Module): 16 | def __init__(self, dim): 17 | super().__init__() 18 | self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) 19 | 20 | def forward(self, x): 21 | return self.conv(x) 22 | 23 | class Conv1dBlock(nn.Module): 24 | ''' 25 | Conv1d --> GroupNorm --> Mish 26 | ''' 27 | 28 | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): 29 | super().__init__() 30 | 31 | self.block = nn.Sequential( 32 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), 33 | # Rearrange('batch channels horizon -> batch channels 1 horizon'), 34 | nn.GroupNorm(n_groups, out_channels), 35 | # Rearrange('batch channels 1 horizon -> batch channels horizon'), 36 | nn.Mish(), 37 | ) 38 | 39 
| def forward(self, x): 40 | return self.block(x) 41 | 42 | 43 | def test(): 44 | cb = Conv1dBlock(256, 128, kernel_size=3) 45 | x = torch.zeros((1,256,16)) 46 | o = cb(x) 47 | -------------------------------------------------------------------------------- /diffusion_policy_3d/model/diffusion/ema_model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from torch.nn.modules.batchnorm import _BatchNorm 4 | 5 | class EMAModel: 6 | """ 7 | Exponential Moving Average of models weights 8 | """ 9 | 10 | def __init__( 11 | self, 12 | model, 13 | update_after_step=0, 14 | inv_gamma=1.0, 15 | power=2 / 3, 16 | min_value=0.0, 17 | max_value=0.9999 18 | ): 19 | """ 20 | @crowsonkb's notes on EMA Warmup: 21 | If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan 22 | to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), 23 | gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 24 | at 215.4k steps). 25 | Args: 26 | inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. 27 | power (float): Exponential factor of EMA warmup. Default: 2/3. 28 | min_value (float): The minimum EMA decay rate. Default: 0. 29 | """ 30 | 31 | self.averaged_model = model 32 | self.averaged_model.eval() 33 | self.averaged_model.requires_grad_(False) 34 | 35 | self.update_after_step = update_after_step 36 | self.inv_gamma = inv_gamma 37 | self.power = power 38 | self.min_value = min_value 39 | self.max_value = max_value 40 | 41 | self.decay = 0.0 42 | self.optimization_step = 0 43 | 44 | def get_decay(self, optimization_step): 45 | """ 46 | Compute the decay factor for the exponential moving average. 47 | """ 48 | step = max(0, optimization_step - self.update_after_step - 1) 49 | value = 1 - (1 + step / self.inv_gamma) ** -self.power 50 | 51 | if step <= 0: 52 | return 0.0 53 | 54 | return max(self.min_value, min(value, self.max_value)) 55 | 56 | @torch.no_grad() 57 | def step(self, new_model): 58 | self.decay = self.get_decay(self.optimization_step) 59 | 60 | # old_all_dataptrs = set() 61 | # for param in new_model.parameters(): 62 | # data_ptr = param.data_ptr() 63 | # if data_ptr != 0: 64 | # old_all_dataptrs.add(data_ptr) 65 | 66 | all_dataptrs = set() 67 | for module, ema_module in zip(new_model.modules(), self.averaged_model.modules()): 68 | for param, ema_param in zip(module.parameters(recurse=False), ema_module.parameters(recurse=False)): 69 | # iterative over immediate parameters only. 70 | if isinstance(param, dict): 71 | raise RuntimeError('Dict parameter not supported') 72 | 73 | # data_ptr = param.data_ptr() 74 | # if data_ptr != 0: 75 | # all_dataptrs.add(data_ptr) 76 | 77 | if isinstance(module, _BatchNorm): 78 | # skip batchnorms 79 | ema_param.copy_(param.to(dtype=ema_param.dtype).data) 80 | elif not param.requires_grad: 81 | ema_param.copy_(param.to(dtype=ema_param.dtype).data) 82 | else: 83 | ema_param.mul_(self.decay) 84 | ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay) 85 | 86 | # verify that iterating over module and then parameters is identical to parameters recursively. 
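        # (editor's note, illustrative) the loop above applies the standard EMA update
        #     ema_param <- decay * ema_param + (1 - decay) * param
        # with decay taken from get_decay(); e.g. with the defaults inv_gamma=1.0 and
        # power=2/3 the decay warms up to roughly 0.999 by ~31.6k steps and is clamped
        # at max_value=0.9999, matching the warmup notes in __init__.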
87 | # assert old_all_dataptrs == all_dataptrs 88 | self.optimization_step += 1 89 | -------------------------------------------------------------------------------- /diffusion_policy_3d/model/diffusion/mask_generator.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence, Optional 2 | import torch 3 | from torch import nn 4 | from diffusion_policy_3d.model.common.module_attr_mixin import ModuleAttrMixin 5 | 6 | 7 | def get_intersection_slice_mask( 8 | shape: tuple, 9 | dim_slices: Sequence[slice], 10 | device: Optional[torch.device]=None 11 | ): 12 | assert(len(shape) == len(dim_slices)) 13 | mask = torch.zeros(size=shape, dtype=torch.bool, device=device) 14 | mask[dim_slices] = True 15 | return mask 16 | 17 | 18 | def get_union_slice_mask( 19 | shape: tuple, 20 | dim_slices: Sequence[slice], 21 | device: Optional[torch.device]=None 22 | ): 23 | assert(len(shape) == len(dim_slices)) 24 | mask = torch.zeros(size=shape, dtype=torch.bool, device=device) 25 | for i in range(len(dim_slices)): 26 | this_slices = [slice(None)] * len(shape) 27 | this_slices[i] = dim_slices[i] 28 | mask[this_slices] = True 29 | return mask 30 | 31 | 32 | class DummyMaskGenerator(ModuleAttrMixin): 33 | def __init__(self): 34 | super().__init__() 35 | 36 | @torch.no_grad() 37 | def forward(self, shape): 38 | device = self.device 39 | mask = torch.ones(size=shape, dtype=torch.bool, device=device) 40 | return mask 41 | 42 | 43 | class LowdimMaskGenerator(ModuleAttrMixin): 44 | def __init__(self, 45 | action_dim, obs_dim, 46 | # obs mask setup 47 | max_n_obs_steps=2, 48 | fix_obs_steps=True, 49 | # action mask 50 | action_visible=False 51 | ): 52 | super().__init__() 53 | self.action_dim = action_dim 54 | self.obs_dim = obs_dim 55 | self.max_n_obs_steps = max_n_obs_steps 56 | self.fix_obs_steps = fix_obs_steps 57 | self.action_visible = action_visible 58 | 59 | @torch.no_grad() 60 | def forward(self, shape, seed=None): 61 | device = self.device 62 | B, T, D = shape 63 | assert D == (self.action_dim + self.obs_dim) 64 | 65 | # create all tensors on this device 66 | rng = torch.Generator(device=device) 67 | if seed is not None: 68 | rng = rng.manual_seed(seed) 69 | 70 | # generate dim mask 71 | dim_mask = torch.zeros(size=shape, 72 | dtype=torch.bool, device=device) 73 | is_action_dim = dim_mask.clone() 74 | is_action_dim[...,:self.action_dim] = True 75 | is_obs_dim = ~is_action_dim 76 | 77 | # generate obs mask 78 | if self.fix_obs_steps: 79 | obs_steps = torch.full((B,), 80 | fill_value=self.max_n_obs_steps, device=device) 81 | else: 82 | obs_steps = torch.randint( 83 | low=1, high=self.max_n_obs_steps+1, 84 | size=(B,), generator=rng, device=device) 85 | 86 | steps = torch.arange(0, T, device=device).reshape(1,T).expand(B,T) 87 | obs_mask = (steps.T < obs_steps).T.reshape(B,T,1).expand(B,T,D) 88 | obs_mask = obs_mask & is_obs_dim 89 | 90 | # generate action mask 91 | if self.action_visible: 92 | action_steps = torch.maximum( 93 | obs_steps - 1, 94 | torch.tensor(0, 95 | dtype=obs_steps.dtype, 96 | device=obs_steps.device)) 97 | action_mask = (steps.T < action_steps).T.reshape(B,T,1).expand(B,T,D) 98 | action_mask = action_mask & is_action_dim 99 | 100 | mask = obs_mask 101 | if self.action_visible: 102 | mask = mask | action_mask 103 | 104 | return mask 105 | 106 | 107 | 108 | 109 | class KeypointMaskGenerator(ModuleAttrMixin): 110 | def __init__(self, 111 | # dimensions 112 | action_dim, keypoint_dim, 113 | # obs mask setup 114 | max_n_obs_steps=2, 
fix_obs_steps=True, 115 | # keypoint mask setup 116 | keypoint_visible_rate=0.7, time_independent=False, 117 | # action mask 118 | action_visible=False, 119 | context_dim=0, # dim for context 120 | n_context_steps=1 121 | ): 122 | super().__init__() 123 | self.action_dim = action_dim 124 | self.keypoint_dim = keypoint_dim 125 | self.context_dim = context_dim 126 | self.max_n_obs_steps = max_n_obs_steps 127 | self.fix_obs_steps = fix_obs_steps 128 | self.keypoint_visible_rate = keypoint_visible_rate 129 | self.time_independent = time_independent 130 | self.action_visible = action_visible 131 | self.n_context_steps = n_context_steps 132 | 133 | @torch.no_grad() 134 | def forward(self, shape, seed=None): 135 | device = self.device 136 | B, T, D = shape 137 | all_keypoint_dims = D - self.action_dim - self.context_dim 138 | n_keypoints = all_keypoint_dims // self.keypoint_dim 139 | 140 | # create all tensors on this device 141 | rng = torch.Generator(device=device) 142 | if seed is not None: 143 | rng = rng.manual_seed(seed) 144 | 145 | # generate dim mask 146 | dim_mask = torch.zeros(size=shape, 147 | dtype=torch.bool, device=device) 148 | is_action_dim = dim_mask.clone() 149 | is_action_dim[...,:self.action_dim] = True 150 | is_context_dim = dim_mask.clone() 151 | if self.context_dim > 0: 152 | is_context_dim[...,-self.context_dim:] = True 153 | is_obs_dim = ~(is_action_dim | is_context_dim) 154 | # assumption trajectory=cat([action, keypoints, context], dim=-1) 155 | 156 | # generate obs mask 157 | if self.fix_obs_steps: 158 | obs_steps = torch.full((B,), 159 | fill_value=self.max_n_obs_steps, device=device) 160 | else: 161 | obs_steps = torch.randint( 162 | low=1, high=self.max_n_obs_steps+1, 163 | size=(B,), generator=rng, device=device) 164 | 165 | steps = torch.arange(0, T, device=device).reshape(1,T).expand(B,T) 166 | obs_mask = (steps.T < obs_steps).T.reshape(B,T,1).expand(B,T,D) 167 | obs_mask = obs_mask & is_obs_dim 168 | 169 | # generate action mask 170 | if self.action_visible: 171 | action_steps = torch.maximum( 172 | obs_steps - 1, 173 | torch.tensor(0, 174 | dtype=obs_steps.dtype, 175 | device=obs_steps.device)) 176 | action_mask = (steps.T < action_steps).T.reshape(B,T,1).expand(B,T,D) 177 | action_mask = action_mask & is_action_dim 178 | 179 | # generate keypoint mask 180 | if self.time_independent: 181 | visible_kps = torch.rand(size=(B, T, n_keypoints), 182 | generator=rng, device=device) < self.keypoint_visible_rate 183 | visible_dims = torch.repeat_interleave(visible_kps, repeats=self.keypoint_dim, dim=-1) 184 | visible_dims_mask = torch.cat([ 185 | torch.ones((B, T, self.action_dim), 186 | dtype=torch.bool, device=device), 187 | visible_dims, 188 | torch.ones((B, T, self.context_dim), 189 | dtype=torch.bool, device=device), 190 | ], axis=-1) 191 | keypoint_mask = visible_dims_mask 192 | else: 193 | visible_kps = torch.rand(size=(B,n_keypoints), 194 | generator=rng, device=device) < self.keypoint_visible_rate 195 | visible_dims = torch.repeat_interleave(visible_kps, repeats=self.keypoint_dim, dim=-1) 196 | visible_dims_mask = torch.cat([ 197 | torch.ones((B, self.action_dim), 198 | dtype=torch.bool, device=device), 199 | visible_dims, 200 | torch.ones((B, self.context_dim), 201 | dtype=torch.bool, device=device), 202 | ], axis=-1) 203 | keypoint_mask = visible_dims_mask.reshape(B,1,D).expand(B,T,D) 204 | keypoint_mask = keypoint_mask & is_obs_dim 205 | 206 | # generate context mask 207 | context_mask = is_context_dim.clone() 208 | context_mask[:,self.n_context_steps:,:] = 
False 209 | 210 | mask = obs_mask & keypoint_mask 211 | if self.action_visible: 212 | mask = mask | action_mask 213 | if self.context_dim > 0: 214 | mask = mask | context_mask 215 | 216 | return mask 217 | 218 | 219 | def test(): 220 | # kmg = KeypointMaskGenerator(2,2, random_obs_steps=True) 221 | # self = KeypointMaskGenerator(2,2,context_dim=2, action_visible=True) 222 | # self = KeypointMaskGenerator(2,2,context_dim=0, action_visible=True) 223 | self = LowdimMaskGenerator(2,20, max_n_obs_steps=3, action_visible=True) 224 | -------------------------------------------------------------------------------- /diffusion_policy_3d/model/diffusion/positional_embedding.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | class SinusoidalPosEmb(nn.Module): 6 | def __init__(self, dim): 7 | super().__init__() 8 | self.dim = dim 9 | 10 | def forward(self, x): 11 | device = x.device 12 | half_dim = self.dim // 2 13 | emb = math.log(10000) / (half_dim - 1) 14 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 15 | emb = x[:, None] * emb[None, :] 16 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 17 | return emb 18 | -------------------------------------------------------------------------------- /diffusion_policy_3d/policy/base_policy.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import torch 3 | import torch.nn as nn 4 | from diffusion_policy_3d.model.common.module_attr_mixin import ModuleAttrMixin 5 | from diffusion_policy_3d.model.common.normalizer import LinearNormalizer 6 | 7 | class BasePolicy(ModuleAttrMixin): 8 | # init accepts keyword argument shape_meta, see config/task/*_image.yaml 9 | 10 | def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 11 | """ 12 | obs_dict: 13 | str: B,To,* 14 | return: B,Ta,Da 15 | """ 16 | raise NotImplementedError() 17 | 18 | # reset state for stateful policies 19 | def reset(self): 20 | pass 21 | 22 | # ========== training =========== 23 | # no standard training interface except setting normalizer 24 | def set_normalizer(self, normalizer: LinearNormalizer): 25 | raise NotImplementedError() 26 | -------------------------------------------------------------------------------- /diffusion_policy_3d/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='diffusion_policy_3d', 5 | packages=find_packages(), 6 | ) 7 | -------------------------------------------------------------------------------- /dp3_requirements.txt: -------------------------------------------------------------------------------- 1 | zarr==2.12.0 2 | wandb 3 | omegaconf 4 | hydra-core==1.2.0 5 | dill==0.3.5.1 6 | einops==0.4.1 7 | diffusers==0.11.1 8 | huggingface-hub==0.23.3 9 | numba==0.58.1 10 | moviepy 11 | imageio 12 | av 13 | matplotlib 14 | termcolor 15 | 16 | # CLIP 17 | ftfy==6.2.0 -------------------------------------------------------------------------------- /env_real/data/collect_zarr_real.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | 10 | import os 11 | import sys 12 | sys.path.insert(0, os.getcwd()) 13 | 14 | import glob 15 | from pathlib import Path 16 | import shutil 17 | import copy 18 | import zarr 19 | import pickle 20 | from PIL import Image 21 | import numpy as np 22 | from utils.collect_utils import collect_narr_function, save_zarr 23 | from utils.pose_utils import get_rel_pose, euler_from_quaternion 24 | 25 | 26 | # create 27 | data_dict = [] 28 | cur_data_dict = { 29 | "total_count": 0, 30 | # "img_arrays": [], 31 | # "point_cloud_arrays": [], 32 | # "depth_arrays": [], 33 | "state_arrays": [], 34 | "state_in_world_arrays": [], 35 | "state_next_arrays": [], 36 | "state_next_in_world_arrays": [], 37 | "goal_arrays": [], 38 | "action_arrays": [], 39 | "progress_arrays": [], 40 | "progress_binary_arrays": [], 41 | "task_stage_arrays": [], 42 | "variation_arrays": [], 43 | "episode_ends_arrays": [], 44 | "lang_list": [], # debug only 45 | } 46 | 47 | import argparse 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument('task_name', type=str) 50 | parser.add_argument('dataset_path', default="/tmp/record3d/[task_name]/r3d/", type=str) 51 | 52 | args = parser.parse_args() 53 | 54 | task_name = args.task_name 55 | print(task_name) 56 | dataset_path = args.dataset_path 57 | print(dataset_path) 58 | r3d_path_list = glob.glob(os.path.join(dataset_path, '*.r3d')) 59 | r3d_path_list.sort() 60 | 61 | for r3d_file in r3d_path_list: 62 | print(r3d_file) 63 | 64 | extract_path = os.path.join(dataset_path, Path(r3d_file).stem) 65 | grasp_obj_poses = np.load(os.path.join(extract_path, f"grasp_object_poses.npy")) 66 | target_object_poses = np.load(os.path.join(extract_path, f"target_object_poses.npy")) 67 | target_pose = target_object_poses[0] # !! do we need to update target pose? 68 | 69 | 70 | episode_keypoints = list(range(len(grasp_obj_poses))) 71 | 72 | # create poses_dict 73 | poses_dict = { 74 | "task_stage": [], 75 | "grasp_obj_pose": [], 76 | "grasp_obj_pose_relative_to_target": [], 77 | "target_obj_pose": [], 78 | "keypoint": [], 79 | } 80 | stage_idx = 0 81 | variation = 0 # !! only one variation 82 | count = 0 83 | for idx, obj_pose in enumerate(grasp_obj_poses): 84 | obj_pose_relative_to_target = get_rel_pose(target_pose, obj_pose) # !! 
be careful about the order (target, grasp_object) 85 | 86 | poses_dict["grasp_obj_pose"].append(obj_pose) 87 | poses_dict["grasp_obj_pose_relative_to_target"].append(obj_pose_relative_to_target) 88 | poses_dict["target_obj_pose"].append(target_pose) 89 | poses_dict["task_stage"].append(stage_idx) 90 | if idx in episode_keypoints: 91 | poses_dict["keypoint"].append(count) 92 | count += 1 93 | 94 | # 95 | arrays_sub_dict_list, total_count_sub_list = collect_narr_function([poses_dict], None) 96 | assert len(arrays_sub_dict_list) == len(total_count_sub_list) == 1 97 | 98 | for arrays_sub_dict, total_count_sub in zip(arrays_sub_dict_list, total_count_sub_list): 99 | cur_data_dict["total_count"] += total_count_sub 100 | cur_data_dict["episode_ends_arrays"].append(copy.deepcopy(cur_data_dict["total_count"])) # the index of the last step of the episode 101 | # img_arrays.extend(copy.deepcopy(img_arrays_sub)) 102 | # point_cloud_arrays.extend(copy.deepcopy(point_cloud_arrays_sub)) 103 | # depth_arrays.extend(copy.deepcopy(depth_arrays_sub)) 104 | cur_data_dict["state_arrays"].extend(copy.deepcopy(arrays_sub_dict["state"])) 105 | cur_data_dict["state_in_world_arrays"].extend(copy.deepcopy(arrays_sub_dict["state_in_world"])) 106 | cur_data_dict["state_next_arrays"].extend(copy.deepcopy(arrays_sub_dict["state_next"])) 107 | cur_data_dict["state_next_in_world_arrays"].extend(copy.deepcopy(arrays_sub_dict["state_next_in_world"])) 108 | cur_data_dict["goal_arrays"].extend(copy.deepcopy(arrays_sub_dict["goal"])) 109 | cur_data_dict["action_arrays"].extend(copy.deepcopy(arrays_sub_dict["action"])) 110 | cur_data_dict["progress_arrays"].extend(copy.deepcopy(arrays_sub_dict["progress"])) 111 | cur_data_dict["progress_binary_arrays"].extend(copy.deepcopy(arrays_sub_dict["progress_binary"])) 112 | cur_data_dict["task_stage_arrays"].extend(np.array(arrays_sub_dict["task_stage"]).reshape(-1, 1)) 113 | cur_data_dict["variation_arrays"].extend(np.full((len(arrays_sub_dict["action"]), 1), fill_value=variation)) 114 | 115 | 116 | data_dict.append(cur_data_dict) 117 | assert len(data_dict) == 1 118 | 119 | # collect from all process 120 | state_arrays = [] 121 | state_in_world_arrays = [] 122 | state_next_arrays = [] 123 | state_next_in_world_arrays = [] 124 | goal_arrays = [] 125 | action_arrays = [] 126 | progress_arrays = [] 127 | progress_binary_arrays = [] 128 | task_stage_arrays = [] 129 | variation_arrays = [] 130 | episode_ends_arrays = [] 131 | for i in range(1): 132 | state_arrays.extend(data_dict[i]["state_arrays"]) 133 | state_in_world_arrays.extend(data_dict[i]["state_in_world_arrays"]) 134 | state_next_arrays.extend(data_dict[i]["state_next_arrays"]) 135 | state_next_in_world_arrays.extend(data_dict[i]["state_next_in_world_arrays"]) 136 | goal_arrays.extend(data_dict[i]["goal_arrays"]) 137 | action_arrays.extend(data_dict[i]["action_arrays"]) 138 | progress_arrays.extend(data_dict[i]["progress_arrays"]) 139 | progress_binary_arrays.extend(data_dict[i]["progress_binary_arrays"]) 140 | task_stage_arrays.extend(data_dict[i]["task_stage_arrays"]) 141 | variation_arrays.extend(data_dict[i]["variation_arrays"]) 142 | episode_ends_arrays.extend(data_dict[i]["episode_ends_arrays"]) 143 | 144 | 145 | # save zarr 146 | save_dir = f"/tmp/record3d/{task_name}/zarr/" 147 | os.makedirs(save_dir, exist_ok=True) 148 | save_zarr( 149 | save_dir, 150 | state_arrays, 151 | state_in_world_arrays, 152 | state_next_arrays, 153 | state_next_in_world_arrays, 154 | goal_arrays, 155 | action_arrays, 156 | progress_arrays, 157 | 
progress_binary_arrays, 158 | task_stage_arrays, 159 | variation_arrays, 160 | episode_ends_arrays, 161 | ) -------------------------------------------------------------------------------- /env_real/data/prepare_mask.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | 10 | import os 11 | import sys 12 | sys.path.insert(0, "/tmp/SPOT/") 13 | 14 | # Copyright (c) Tencent Inc. All rights reserved. 15 | import os 16 | import cv2 17 | import argparse 18 | import os.path as osp 19 | 20 | import torch 21 | from mmengine.config import Config, DictAction 22 | from mmengine.runner.amp import autocast 23 | from mmengine.dataset import Compose 24 | from mmengine.utils import ProgressBar 25 | from mmdet.apis import init_detector 26 | from mmdet.utils import get_test_pipeline_cfg 27 | 28 | import supervision as sv 29 | 30 | BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator(thickness=1) 31 | MASK_ANNOTATOR = sv.MaskAnnotator() 32 | 33 | 34 | class LabelAnnotator(sv.LabelAnnotator): 35 | 36 | @staticmethod 37 | def resolve_text_background_xyxy( 38 | center_coordinates, 39 | text_wh, 40 | position, 41 | ): 42 | center_x, center_y = center_coordinates 43 | text_w, text_h = text_wh 44 | return center_x, center_y, center_x + text_w, center_y + text_h 45 | 46 | 47 | LABEL_ANNOTATOR = LabelAnnotator(text_padding=4, 48 | text_scale=0.5, 49 | text_thickness=1) 50 | 51 | 52 | def inference_detector(model, 53 | image_path, 54 | texts, 55 | test_pipeline, 56 | max_dets=100, 57 | score_thr=0.3, 58 | use_amp=False, 59 | show=False, 60 | annotation=False): 61 | data_info = dict(img_id=0, img_path=image_path, texts=texts) 62 | data_info = test_pipeline(data_info) 63 | data_batch = dict(inputs=data_info['inputs'].unsqueeze(0), 64 | data_samples=[data_info['data_samples']]) 65 | 66 | with autocast(enabled=use_amp), torch.no_grad(): 67 | output = model.test_step(data_batch)[0] 68 | pred_instances = output.pred_instances 69 | pred_instances = pred_instances[pred_instances.scores.float() > 70 | score_thr] 71 | 72 | if len(pred_instances.scores) > max_dets: 73 | indices = pred_instances.scores.float().topk(max_dets)[1] 74 | pred_instances = pred_instances[indices] 75 | 76 | pred_instances = pred_instances.cpu().numpy() 77 | 78 | if 'masks' in pred_instances: 79 | masks = pred_instances['masks'] 80 | else: 81 | masks = None 82 | 83 | detections = sv.Detections(xyxy=pred_instances['bboxes'], 84 | class_id=pred_instances['labels'], 85 | confidence=pred_instances['scores'], 86 | mask=masks) 87 | 88 | labels = [ 89 | f"{texts[class_id][0]} {confidence:0.2f}" for class_id, confidence in 90 | zip(detections.class_id, detections.confidence) 91 | ] 92 | 93 | # label images 94 | image = cv2.imread(image_path) 95 | anno_image = image.copy() 96 | image = BOUNDING_BOX_ANNOTATOR.annotate(image, detections) 97 | image = LABEL_ANNOTATOR.annotate(image, detections, labels=labels) 98 | if masks is not None: 99 | image = MASK_ANNOTATOR.annotate(image, detections) 100 | 101 | if annotation: 102 | images_dict = {} 103 | annotations_dict = {} 104 | 105 | 
images_dict[osp.basename(image_path)] = anno_image 106 | annotations_dict[osp.basename(image_path)] = detections 107 | 108 | ANNOTATIONS_DIRECTORY = os.makedirs(r"./annotations", exist_ok=True) 109 | 110 | MIN_IMAGE_AREA_PERCENTAGE = 0.002 111 | MAX_IMAGE_AREA_PERCENTAGE = 0.80 112 | APPROXIMATION_PERCENTAGE = 0.75 113 | 114 | sv.DetectionDataset( 115 | classes=texts, images=images_dict, 116 | annotations=annotations_dict).as_yolo( 117 | annotations_directory_path=ANNOTATIONS_DIRECTORY, 118 | min_image_area_percentage=MIN_IMAGE_AREA_PERCENTAGE, 119 | max_image_area_percentage=MAX_IMAGE_AREA_PERCENTAGE, 120 | approximation_percentage=APPROXIMATION_PERCENTAGE) 121 | 122 | if show: 123 | cv2.imshow('Image', image) # Provide window name 124 | k = cv2.waitKey(0) 125 | if k == 27: 126 | # wait for ESC key to exit 127 | cv2.destroyAllWindows() 128 | return pred_instances['bboxes'], image 129 | 130 | 131 | if __name__ == '__main__': 132 | from pathlib import Path 133 | from zipfile import ZipFile 134 | import glob 135 | import shutil 136 | import numpy as np 137 | from env_real.utils.realworld_objects import real_task_object_dict 138 | 139 | # setup yolo world 140 | yolo_world_dir = "/tmp/YOLO-World/" 141 | config_path = f"{yolo_world_dir}/configs/pretrain/yolo_world_v2_x_vlpan_bn_2e-3_100e_4x8gpus_obj365v1_goldg_cc3mlite_train_lvis_minival.py" 142 | weight_path = f"{yolo_world_dir}/weights/yolo_world_v2_x_obj365v1_goldg_cc3mlite_pretrain_1280ft-14996a36.pth" 143 | topk = 1 144 | threshold = 0.005 145 | amp = False 146 | show = False 147 | annotation = False 148 | 149 | # load config 150 | cfg = Config.fromfile(config_path) 151 | cfg.work_dir = osp.join('./work_dirs', 152 | osp.splitext(osp.basename(config_path))[0]) 153 | # init model 154 | cfg.load_from = weight_path 155 | model = init_detector(cfg, checkpoint=weight_path, device='cuda:0') 156 | 157 | # init test pipeline 158 | test_pipeline_cfg = get_test_pipeline_cfg(cfg=cfg) 159 | # test_pipeline[0].type = 'mmdet.LoadImageFromNDArray' 160 | test_pipeline = Compose(test_pipeline_cfg) 161 | 162 | # load data 163 | import argparse 164 | parser = argparse.ArgumentParser() 165 | parser.add_argument('task_name', type=str) 166 | parser.add_argument('dataset_path', default="/tmp/record3d/[task_name]/r3d/", type=str) 167 | args = parser.parse_args() 168 | 169 | task_name = args.task_name 170 | print(task_name) 171 | dataset_path = args.dataset_path 172 | print(dataset_path) 173 | r3d_path_list = glob.glob(os.path.join(dataset_path, '*.r3d')) 174 | r3d_path_list.sort() 175 | 176 | for obj_type in ['grasp', 'target']: 177 | 178 | # setup text keyword 179 | prompts = real_task_object_dict[task_name][f"{obj_type}_object_prompt"] 180 | texts = [[t] for t in prompts] + [[' ']] 181 | 182 | for r3d_file in r3d_path_list: 183 | print(r3d_file) 184 | 185 | extract_path = os.path.join(dataset_path, Path(r3d_file).stem) 186 | rgb_dir = os.path.join(extract_path, 'rgb') 187 | 188 | # only the first frame 189 | frame_idx = 0 190 | rgb_path = os.path.join(rgb_dir, f'{frame_idx}.jpg') 191 | fname = Path(rgb_path).stem 192 | 193 | # run inference 194 | vis_dir = os.path.join(extract_path, f'vis') 195 | os.makedirs(vis_dir, exist_ok=True) 196 | model.reparameterize(texts) 197 | bboxes, vis_image = inference_detector(model, 198 | rgb_path, 199 | texts, 200 | test_pipeline, 201 | topk, 202 | threshold, 203 | use_amp=amp, 204 | show=show, 205 | annotation=annotation) 206 | cv2.imwrite((os.path.join(vis_dir, f'bbox_{obj_type}.png')), vis_image) 207 | 208 | # get mask 209 | 
bbox = bboxes[0] # !! assume only one box 210 | rgb = cv2.imread(rgb_path) 211 | mask = np.zeros(rgb.shape[:-1]) 212 | x1, y1, x2, y2 = bbox.astype(np.int) 213 | mask[y1:y2, x1:x2] = 1. 214 | # print(np.unique(mask, return_counts=True)) 215 | 216 | # save the mask 217 | mask_dir = os.path.join(extract_path, f'mask_{obj_type}') 218 | os.makedirs(mask_dir, exist_ok=True) 219 | 220 | pallet = np.array([ 221 | [0, 0, 0], 222 | [255, 255, 255] 223 | ]).astype(np.uint8) 224 | 225 | from PIL import Image 226 | mask_image = Image.fromarray(mask.astype(np.uint8)).convert('P') 227 | mask_image.putpalette(pallet) 228 | mask_image.save(os.path.join(mask_dir, fname+'.png')) 229 | 230 | # read the mask 231 | # mask_image = Image.open(os.path.join(mask_dir, fname+'.png')).convert('P') 232 | # mask = np.array(mask_image) 233 | # print(np.unique(mask, return_counts=True)) 234 | # mask_image.close() -------------------------------------------------------------------------------- /env_real/data/prepare_pose.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | 10 | import os 11 | import sys 12 | sys.path.insert(0, os.getcwd()) 13 | 14 | from foundation_pose.wrapper import FoundationPoseWrapper 15 | from pathlib import Path 16 | import glob 17 | import numpy as np 18 | import os 19 | import imageio 20 | import cv2 21 | from PIL import Image 22 | import trimesh 23 | from foundation_pose.Utils import depth2xyzmap, draw_posed_3d_box, draw_xyz_axis, toOpen3dCloud, trimesh_add_pure_colored_texture 24 | from scipy.spatial.transform import Rotation 25 | from tqdm import tqdm 26 | import open3d as o3d 27 | from env_real.utils.realworld_objects import real_task_object_dict 28 | 29 | 30 | def get_vis_pose(pose, color, K, mesh): 31 | to_origin, extents = trimesh.bounds.oriented_bounds(mesh) 32 | bbox = np.stack([-extents/2, extents/2], axis=0).reshape(2,3) 33 | 34 | center_pose = pose @ np.linalg.inv(to_origin) 35 | 36 | vis = draw_posed_3d_box(K, img=color, ob_in_cam=center_pose, bbox=bbox) 37 | vis = draw_xyz_axis(color, ob_in_cam=center_pose, scale=0.1, K=K, thickness=3, transparency=0, is_input_rgb=True) 38 | return vis 39 | 40 | if __name__=='__main__': 41 | import argparse 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument('task_name', type=str) 44 | args = parser.parse_args() 45 | task_name = args.task_name 46 | print(task_name) 47 | dataset_path = f"/tmp/record3d/{task_name}/r3d/" 48 | r3d_path_list = glob.glob(os.path.join(dataset_path, '*.r3d')) 49 | r3d_path_list.sort() 50 | 51 | 52 | for obj_type in ['grasp', 'target']: 53 | 54 | # get mesh 55 | obj_name = real_task_object_dict[task_name][f"{obj_type}_object_name"] 56 | mesh_path = f"/tmp/record3d/mesh/{obj_name}/{obj_name}.obj" 57 | mesh = trimesh.load(mesh_path) 58 | mesh.vertices = mesh.vertices - np.mean(mesh.vertices, axis=0) 59 | # mesh.show() 60 | 61 | for r3d_file in r3d_path_list: 62 | print(r3d_file) 63 | 64 | # if not "2024-09-02--19-49-49" in r3d_file: 65 | # continue 66 | 67 | # get FP 68 | pose_estimation_wrapper = 
FoundationPoseWrapper(mesh_dir=None) 69 | pose_estimation_wrapper.mesh = mesh 70 | pose_estimator = pose_estimation_wrapper.create_estimator(debug_level=0) 71 | 72 | # process data 73 | extract_path = os.path.join(dataset_path, Path(r3d_file).stem) 74 | rgb_path_list = glob.glob(os.path.join(extract_path, 'rgb', '*.jpg')) 75 | vis_pose_list = [] 76 | object_poses = [] 77 | 78 | # get K 79 | K_path = os.path.join(extract_path, 'K.txt') 80 | K = np.loadtxt(K_path).reshape(3,3) 81 | 82 | # get HW 83 | H,W = cv2.imread(os.path.join(extract_path, 'rgb', f'0.jpg')).shape[:2] 84 | 85 | for frame_idx in tqdm(range(len(rgb_path_list))): 86 | 87 | # get rgb 88 | rgb_path = os.path.join(extract_path, 'rgb', f'{frame_idx}.jpg') 89 | fname = Path(rgb_path).stem 90 | rgb = imageio.imread(rgb_path)[...,:3] 91 | # rgb = cv2.resize(rgb, (192,256), interpolation=cv2.INTER_NEAREST) 92 | 93 | # get depth 94 | depth_path = os.path.join(extract_path, 'depth', f'{frame_idx}.png') 95 | depth = cv2.imread(depth_path, cv2.IMREAD_ANYDEPTH) / 1000. 96 | depth = cv2.resize(depth, (W, H)) 97 | 98 | if frame_idx == 0: 99 | # get mask 100 | mask_path = os.path.join(extract_path, f'mask_{obj_type}', f'{frame_idx}.png') 101 | mask_image = Image.open(mask_path).convert('P') 102 | mask = np.array(mask_image) 103 | # mask = cv2.resize(mask, (192,256), interpolation=cv2.INTER_NEAREST) 104 | mask = (mask == 1).astype(bool).astype(np.uint8) 105 | 106 | fp_mat = pose_estimator.register(K=K, rgb=rgb, depth=depth, ob_mask=mask, iteration=5) 107 | else: 108 | fp_mat = pose_estimator.track_one(rgb=rgb, depth=depth, K=K, iteration=2) 109 | 110 | # get pose in world frame 111 | extrinsic_path = os.path.join(extract_path, 'cam_in_ob', fname + '.txt') 112 | extrinsic = np.loadtxt(extrinsic_path).reshape(4, 4) 113 | pose_in_world = np.matmul(extrinsic, fp_mat) 114 | pose_in_world_quat = np.concatenate([ 115 | pose_in_world[:3, 3], 116 | Rotation.from_matrix(pose_in_world[:3, :3]).as_quat(), 117 | ]) 118 | object_poses.append(pose_in_world_quat) 119 | 120 | # visualization 121 | vis_pose = get_vis_pose( 122 | pose=fp_mat, 123 | color=rgb, 124 | K=K, 125 | mesh=pose_estimation_wrapper.mesh 126 | ) 127 | vis_pose_list.append(vis_pose) 128 | 129 | # save object pose in world 130 | assert len(object_poses) == len(rgb_path_list), f"{len(object_poses)} {len(rgb_path_list)}" 131 | np.save(os.path.join(extract_path, f"{obj_type}_object_poses.npy"), np.array(object_poses)) 132 | object_poses = np.load(os.path.join(extract_path, f"{obj_type}_object_poses.npy")) 133 | 134 | # save video 135 | video_path = os.path.join(extract_path, 'vis', f'pose_track_{obj_type}.mp4') 136 | video_writer = imageio.get_writer(video_path, fps=40) 137 | for img in vis_pose_list: 138 | video_writer.append_data(img) 139 | video_writer.close() 140 | 141 | # save point cloud in target's frame 142 | for r3d_file in r3d_path_list: 143 | print(r3d_file) 144 | # save scene_in_target pointcloud 145 | pcd = o3d.io.read_point_cloud(f'{extract_path}/pc_original_in_world.ply') 146 | points = np.asarray(pcd.points) 147 | print(points.shape) 148 | 149 | from utils.pose_utils import get_rel_pose, euler_from_quaternion, compute_rel_transform 150 | import utils.transform_utils as T 151 | grasp_obj_pose = np.load(os.path.join(extract_path, f"grasp_object_poses.npy"))[0] 152 | target_obj_pose = np.load(os.path.join(extract_path, f"target_object_poses.npy"))[0] 153 | points = np.vstack([ 154 | compute_rel_transform(target_obj_pose[:3], T.quat2mat(target_obj_pose[3:]), p[:3], 
T.quat2mat(grasp_obj_pose[3:]))[0] 155 | for p in points 156 | ] 157 | ) 158 | pcd.points = o3d.utility.Vector3dVector(points.reshape(-1, 3)) 159 | o3d.io.write_point_cloud(f"{extract_path}/scene_in_target.ply", pcd) -------------------------------------------------------------------------------- /env_real/data/prepare_rgbd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | 10 | import numpy as np 11 | import cv2 12 | import os 13 | import argparse 14 | import liblzfse 15 | import open3d as o3d 16 | import json 17 | import imageio 18 | 19 | 20 | def load_depth(filepath, H, W): 21 | with open(filepath, 'rb') as depth_fh: 22 | raw_bytes = depth_fh.read() 23 | decompressed_bytes = liblzfse.decompress(raw_bytes) 24 | depth_img = np.frombuffer(decompressed_bytes, dtype=np.float32) 25 | depth_img = depth_img.reshape((H, W)) 26 | return depth_img 27 | 28 | def load_conf(filepath, H, W): 29 | with open(filepath, 'rb') as depth_fh: 30 | raw_bytes = depth_fh.read() 31 | decompressed_bytes = liblzfse.decompress(raw_bytes) 32 | conf = np.frombuffer(decompressed_bytes, dtype=np.int8) 33 | conf = conf.reshape((H, W)) 34 | return np.float32(conf) 35 | 36 | def create_point_cloud_depth(depth, rgb, fx, fy, cx, cy): 37 | depth_shape = depth.shape 38 | [x_d, y_d] = np.meshgrid(range(0, depth_shape[1]), range(0, depth_shape[0])) 39 | x3 = np.divide(np.multiply((x_d-cx), depth), fx) 40 | y3 = np.divide(np.multiply((y_d-cy), depth), fy) 41 | z3 = depth 42 | 43 | coord = np.stack((x3, y3, z3), axis=2) 44 | 45 | rgb_norm = rgb/255 46 | 47 | return np.concatenate((coord, rgb_norm), axis=2) 48 | 49 | if __name__ == '__main__': 50 | from pathlib import Path 51 | from zipfile import ZipFile 52 | import glob 53 | import shutil 54 | 55 | from scipy.spatial.transform import Rotation 56 | 57 | import argparse 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument('task_name', type=str) 60 | parser.add_argument('dataset_path', default="/tmp/record3d/[task_name]/r3d/", type=str) 61 | args = parser.parse_args() 62 | 63 | task_name = args.task_name 64 | print(task_name) 65 | dataset_path = args.dataset_path 66 | print(dataset_path) 67 | r3d_path_list = glob.glob(os.path.join(dataset_path, '*.r3d')) 68 | r3d_path_list.sort() 69 | 70 | for r3d_file in r3d_path_list: 71 | print(r3d_file) 72 | 73 | # unzip file 74 | extract_path = os.path.join(dataset_path, Path(r3d_file).stem) 75 | ZipFile(r3d_file).extractall(extract_path) 76 | 77 | # get metadata 78 | with open(os.path.join(extract_path, 'metadata'), "rb") as f: 79 | metadata = json.loads(f.read()) 80 | poses = np.asarray(metadata['poses']) # (N, 7) [x, y, z, qx, qy, qz, qw] 81 | 82 | # get intrinsics 83 | K = np.asarray(metadata['K']).reshape(3, 3).T 84 | np.savetxt(os.path.join(extract_path, 'K.txt'), K) 85 | fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2] 86 | 87 | # process data 88 | rgb_path_list = glob.glob(os.path.join(extract_path, 'rgbd', '*.jpg')) 89 | assert len(poses) == len(rgb_path_list), f"{len(poses)} {len(rgb_path_list)}" 90 | 91 | # get HW 92 | 
downscale_factor = 1. # 1. for front and 3.75 for rear camera set by ios 93 | H,W = cv2.imread(os.path.join(extract_path, 'rgbd', f'0.jpg')).shape[:2] 94 | H_dc, W_dc = int(H/downscale_factor), int(W/downscale_factor) 95 | 96 | writer = imageio.get_writer(os.path.join(extract_path, 'video.mp4'), fps=30) 97 | for frame_idx in range(len(rgb_path_list)): 98 | 99 | rgb_path = os.path.join(extract_path, 'rgbd', f'{frame_idx}.jpg') 100 | fname = Path(rgb_path).stem 101 | writer.append_data(imageio.imread(rgb_path)) 102 | 103 | # copy all rgb to dir "/rgb" 104 | rgb_dir = os.path.join(extract_path, 'rgb') 105 | os.makedirs(rgb_dir, exist_ok=True) 106 | shutil.copy(rgb_path, os.path.join(rgb_dir, fname + '.jpg')) 107 | 108 | # save depth 109 | depth_path = rgb_path.replace('.jpg', '.depth') 110 | depth = load_depth(str(depth_path), H_dc, W_dc) 111 | 112 | depth_dir = os.path.join(extract_path, 'depth') 113 | os.makedirs(depth_dir, exist_ok=True) 114 | cv2.imwrite(os.path.join(depth_dir, fname + '.png'), (depth * 1000.).astype(np.uint16)) 115 | # print(np.max(np.nan_to_num(depth)), np.min(np.nan_to_num(depth))) 116 | # depth = cv2.imread(os.path.join(rgb_dir, fname + '.png'), cv2.IMREAD_ANYDEPTH) / 1000. 117 | # print(np.max(depth), np.min(depth)) 118 | 119 | # save extrinsic 120 | pose = poses[frame_idx] # cam2world 121 | pose_mat = np.eye(4) 122 | pose_mat[:3, :3] = Rotation.from_quat(pose[3:]).as_matrix() 123 | pose_mat[3, :3] = pose[:3] 124 | extrinsic = np.linalg.inv(pose_mat) 125 | 126 | pose_dir = os.path.join(extract_path, 'cam_in_ob') 127 | os.makedirs(pose_dir, exist_ok=True) 128 | np.savetxt(os.path.join(pose_dir, fname + '.txt'), extrinsic) 129 | 130 | if frame_idx == 0: 131 | # load rgb 132 | rgb_path = os.path.join(extract_path, 'rgbd', fname+'.jpg') 133 | rgb = cv2.imread(rgb_path) 134 | rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB) 135 | 136 | # load confidence 137 | conf_path = os.path.join(extract_path, 'rgbd', fname+'.conf') 138 | if os.path.exists(conf_path): 139 | conf = load_conf(conf_path, H_dc, W_dc) 140 | else: 141 | conf = None 142 | 143 | # get pointcloud (interpolated) 144 | depth_resized = cv2.resize(depth, (W, H)) 145 | pc = create_point_cloud_depth(depth_resized, rgb, fx, fy, cx, cy).reshape(-1, 6) 146 | if conf is not None: 147 | conf_resized = cv2.resize(conf, (W, H), cv2.INTER_NEAREST_EXACT) 148 | pc = pc[conf_resized.reshape(-1) >= 2] 149 | 150 | pcd = o3d.geometry.PointCloud() 151 | pcd.points = o3d.utility.Vector3dVector(pc[:, :3]) 152 | pcd.colors = o3d.utility.Vector3dVector(pc[:, 3:]) 153 | # o3d.visualization.draw_geometries([pcd]) 154 | o3d.io.write_point_cloud(f'{extract_path}/pc_interpolated.ply', pcd) 155 | 156 | pcd.transform(extrinsic) 157 | o3d.io.write_point_cloud(f'{extract_path}/pc_interpolated_in_world.ply', pcd) 158 | 159 | # get pointcloud (original resolution) 160 | rgb_resized = cv2.resize(rgb, (W_dc, H_dc)) 161 | pc = create_point_cloud_depth(depth, rgb_resized, fx / downscale_factor, fy / downscale_factor, cx / downscale_factor, cy / downscale_factor).reshape(-1, 6) 162 | if conf is not None: 163 | pc = pc[conf.reshape(-1) >= 2] 164 | 165 | pcd2 = o3d.geometry.PointCloud() 166 | pcd2.points = o3d.utility.Vector3dVector(pc[:, :3]) 167 | pcd2.colors = o3d.utility.Vector3dVector(pc[:, 3:]) 168 | # o3d.visualization.draw_geometries([pcd, pcd2]) 169 | o3d.io.write_point_cloud(f'{extract_path}/pc_original.ply', pcd2) 170 | 171 | pcd2.transform(extrinsic) 172 | o3d.io.write_point_cloud(f'{extract_path}/pc_original_in_world.ply', pcd2) 173 | 
writer.close() -------------------------------------------------------------------------------- /env_real/utils/realworld_objects.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | 10 | real_task_object_dict = { 11 | "pour_water": { 12 | "grasp_object_name": "blue_kettle", 13 | "grasp_object_prompt": ["blue kettle"], 14 | "target_object_name": "starbucks_mug", 15 | "target_object_prompt": ["orange mug"], 16 | }, 17 | } -------------------------------------------------------------------------------- /env_rlbench/policy/subgoal_policy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | 10 | import numpy as np 11 | from diffusion_policy_3d.common.replay_buffer import ReplayBuffer 12 | from utils.pose_utils import get_rel_pose, relative_to_target_to_world 13 | 14 | 15 | class RLBenchSubGoalPolicy: 16 | def __init__(self, env): 17 | pass 18 | 19 | def _open_gripper(self): 20 | gripper = self.env._rlbench_env._scene.robot.gripper 21 | scene = self.env._rlbench_env._scene 22 | gripper.release() # remove object from the list gripper._grasped_objects 23 | 24 | done = False 25 | i = 0 26 | vel = 0.04 27 | open_amount = 1.0 #if gripper_name == 'Robotiq85Gripper' else 0.8 28 | while not done: 29 | done = gripper.actuate(open_amount, velocity=vel) 30 | scene.step() 31 | i += 1 32 | if i > 1000: 33 | self.fail('Took too many steps to open') 34 | return None 35 | 36 | def _move_away(self, axis='z', dist=0.2): 37 | gripper_pose = self.env._task._robot.arm.get_tip().get_pose() 38 | action = np.zeros(8) 39 | action[7] = 1 # open gripper 40 | action[:7] = gripper_pose 41 | if 'x' in axis: 42 | action[0] += dist[0] if isinstance(dist, list) else dist 43 | if 'y' in axis: 44 | action[1] += dist[1] if isinstance(dist, list) else dist 45 | if 'z' in axis: 46 | action[2] += dist[2] if isinstance(dist, list) else dist 47 | return action 48 | 49 | def _move_to(self, goal_pose, close_gripper=True, object_centric=False): 50 | action = np.zeros(8) 51 | action[7] = 0 # close gripper 52 | action[:7] = goal_pose # assume the action mode is EndEffectorPoseViaIK or EndEffectorPoseViaPlanning 53 | return action 54 | 55 | def _subgoal_relative_to_target_to_subgoal_gripper(self, subgoal_relative_to_target, target_obj_pose): 56 | subgoal = relative_to_target_to_world(subgoal_relative_to_target, target_obj_pose) 57 | subgoal_gripper = relative_to_target_to_world(self.T_obj_to_gripper, subgoal) 58 | return subgoal_gripper -------------------------------------------------------------------------------- /env_rlbench/runner/rl_bench_camera.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | 10 | import numpy as np 11 | from pyrep.const import RenderMode 12 | from pyrep.objects import VisionSensor 13 | 14 | def get_camera(name): 15 | if name == "custom_cam_front_light_bulb_in": 16 | cam = VisionSensor.create([1024, 576]) 17 | cam.set_explicit_handling(True) 18 | # cam.set_position([1.0, 0., 1.3]) 19 | cam.set_position([0.9, 0., 1.3]) 20 | cam.set_orientation([-np.pi, -0.4*np.pi, 0.5*np.pi]) 21 | cam.set_render_mode(RenderMode.OPENGL) 22 | else: 23 | raise NotImplementedError 24 | return cam -------------------------------------------------------------------------------- /env_rlbench/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='env_rlbench', 5 | packages=find_packages(), 6 | ) 7 | -------------------------------------------------------------------------------- /env_rlbench/utils/rlbench_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | 10 | """ 11 | A wrapper to collect object poses from demos 12 | """ 13 | from functools import partial 14 | from os.path import exists, dirname, abspath, join 15 | from typing import List, Callable 16 | 17 | import numpy as np 18 | from pyrep import PyRep 19 | from pyrep.objects.vision_sensor import VisionSensor 20 | 21 | from rlbench.backend.const import * 22 | from rlbench.const import SUPPORTED_ROBOTS 23 | from rlbench.backend.robot import Robot 24 | from rlbench.backend.scene import Scene 25 | from rlbench.environment import Environment, DIR_PATH 26 | 27 | from utils.pose_utils import get_rel_pose, euler_from_quaternion 28 | 29 | 30 | class MyScene(Scene): 31 | def __init__(self, *args, **kwargs): 32 | super().__init__(*args, **kwargs) 33 | 34 | def _get_misc(self): 35 | def _get_cam_data(cam: VisionSensor, name: str): 36 | d = {} 37 | if cam.still_exists(): 38 | d = { 39 | '%s_extrinsics' % name: cam.get_matrix(), 40 | '%s_intrinsics' % name: cam.get_intrinsic_matrix(), 41 | '%s_near' % name: cam.get_near_clipping_plane(), 42 | '%s_far' % name: cam.get_far_clipping_plane(), 43 | } 44 | return d 45 | misc = _get_cam_data(self._cam_over_shoulder_left, 'left_shoulder_camera') 46 | misc.update(_get_cam_data(self._cam_over_shoulder_right, 'right_shoulder_camera')) 47 | misc.update(_get_cam_data(self._cam_overhead, 'overhead_camera')) 48 | misc.update(_get_cam_data(self._cam_front, 'front_camera')) 49 | misc.update(_get_cam_data(self._cam_wrist, 'wrist_camera')) 50 | misc.update({"variation_index": self._variation_index}) 51 | 52 | if self._joint_position_action is not None: 53 | # Store the actual requested joint positions during demo collection 54 | misc.update({"joint_position_action": self._joint_position_action}) 55 | joint_poses = [j.get_pose() for j in self.robot.arm.joints] 56 | misc.update({'joint_poses': joint_poses}) 57 | 58 | # add object pose 59 | grasp_obj_pose = self.task._cups[0].get_pose() 60 | misc.update({'grasp_obj_pose': grasp_obj_pose}) 61 | target_obj_pose = self.task._spokes[0].get_pose() 62 | misc.update({'target_obj_pose': target_obj_pose}) 63 | 64 | 65 | obj_pose_relative_to_target = get_rel_pose(target_obj_pose, grasp_obj_pose) # !! 
careful about the order (target, grasp_object) 66 | misc.update({'obj_pose_relative_to_target': obj_pose_relative_to_target}) 67 | target_pose_relative_to_target = get_rel_pose(target_obj_pose, target_obj_pose) 68 | misc.update({'target_pose_relative_to_target': target_pose_relative_to_target}) 69 | 70 | # convert from quaternion to euler 71 | obj_pose_relative_to_target_euler = np.zeros(6) 72 | obj_pose_relative_to_target_euler[:3] = obj_pose_relative_to_target[:3] 73 | obj_pose_relative_to_target_euler[3:] = euler_from_quaternion(*obj_pose_relative_to_target[3:]) 74 | misc.update({'obj_pose_relative_to_target_euler': obj_pose_relative_to_target_euler}) 75 | target_pose_relative_to_target_euler = np.zeros(6) 76 | target_pose_relative_to_target_euler[3:] = target_pose_relative_to_target[:3] 77 | target_pose_relative_to_target_euler[3:] = euler_from_quaternion(*target_pose_relative_to_target[3:]) 78 | misc.update({'target_pose_relative_to_target_euler': target_pose_relative_to_target_euler}) 79 | 80 | return misc 81 | 82 | 83 | class MyEnvironment(Environment): 84 | def __init__(self, *args, **kwargs): 85 | super().__init__(*args, **kwargs) 86 | 87 | def launch(self): 88 | if self._pyrep is not None: 89 | raise RuntimeError('Already called launch!') 90 | self._pyrep = PyRep() 91 | self._pyrep.launch(join(DIR_PATH, TTT_FILE), headless=self._headless) 92 | 93 | arm_class, gripper_class, _ = SUPPORTED_ROBOTS[ 94 | self._robot_setup] 95 | arm_class = partial( 96 | arm_class, 97 | max_velocity=self._arm_max_velocity, 98 | max_acceleration=self._arm_max_acceleration) 99 | 100 | # We assume the panda is already loaded in the scene. 101 | if self._robot_setup != 'panda': 102 | # Remove the panda from the scene 103 | panda_arm = Panda() 104 | panda_pos = panda_arm.get_position() 105 | panda_arm.remove() 106 | arm_path = join(DIR_PATH, 'robot_ttms', self._robot_setup + '.ttm') 107 | self._pyrep.import_model(arm_path) 108 | arm, gripper = arm_class(), gripper_class() 109 | arm.set_position(panda_pos) 110 | else: 111 | arm, gripper = arm_class(), gripper_class() 112 | 113 | self._robot = Robot(arm, gripper) 114 | if self._randomize_every is None: 115 | self._scene = MyScene( 116 | self._pyrep, self._robot, self._obs_config, self._robot_setup) 117 | else: 118 | raise NotImplementedError 119 | # self._scene = MyDomainRandomizationScene( 120 | # self._pyrep, self._robot, self._obs_config, self._robot_setup, 121 | # self._randomize_every, self._frequency, 122 | # self._visual_randomization_config, 123 | # self._dynamics_randomization_config) 124 | 125 | self._action_mode.arm_action_mode.set_control_mode(self._robot) 126 | 127 | 128 | from rlbench.noise_model import NoiseModel 129 | from rlbench.backend.utils import image_to_float_array, rgb_handles_to_mask 130 | 131 | 132 | def get_rgb_depth(sensor: VisionSensor, get_rgb: bool, get_depth: bool, 133 | get_pcd: bool, rgb_noise: NoiseModel, 134 | depth_noise: NoiseModel, depth_in_meters: bool): 135 | rgb = depth = pcd = None 136 | if sensor is not None and (get_rgb or get_depth): 137 | sensor.handle_explicitly() 138 | if get_rgb: 139 | rgb = sensor.capture_rgb() 140 | if rgb_noise is not None: 141 | rgb = rgb_noise.apply(rgb) 142 | rgb = np.clip((rgb * 255.).astype(np.uint8), 0, 255) 143 | if get_depth or get_pcd: 144 | depth = sensor.capture_depth(depth_in_meters) 145 | if depth_noise is not None: 146 | depth = depth_noise.apply(depth) 147 | if get_pcd: 148 | depth_m = depth 149 | if not depth_in_meters: 150 | near = sensor.get_near_clipping_plane() 151 | 
far = sensor.get_far_clipping_plane() 152 | depth_m = near + depth * (far - near) 153 | pcd = sensor.pointcloud_from_depth(depth_m) 154 | if not get_depth: 155 | depth = None 156 | return rgb, depth, pcd 157 | 158 | 159 | def get_mask(sensor: VisionSensor, masks_as_one_channel=True): 160 | masks_as_one_channel = True 161 | mask_fn = rgb_handles_to_mask if masks_as_one_channel else lambda x: x 162 | 163 | mask = None 164 | if sensor is not None: 165 | sensor.handle_explicitly() 166 | mask = mask_fn(sensor.capture_rgb()) 167 | return mask 168 | 169 | 170 | def get_seg_mask(obj_list, mask): 171 | mask_id_list = np.unique(mask) 172 | 173 | select_id_list = [] 174 | for obj in obj_list: 175 | id = obj.get_handle() 176 | name = obj.get_object_name(obj.get_handle()) 177 | if id not in mask_id_list: 178 | id = id + 1 # !! this solves the inconsisencies between object id and mask id 179 | select_id_list.append(id) 180 | 181 | seg_mask = np.zeros_like(mask) 182 | for seg_id, obj_id in enumerate(select_id_list): 183 | seg_mask[mask == obj_id] = seg_id + 1 # start from 1 184 | return seg_mask 185 | -------------------------------------------------------------------------------- /env_rlbench_peract/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='env_rlbench_peract', 5 | packages=find_packages(), 6 | ) 7 | -------------------------------------------------------------------------------- /env_rlbench_peract/utils/rlbench_objects.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
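Editorial note: a hypothetical usage sketch of get_rgb_depth, get_mask and get_seg_mask from env_rlbench/utils/rlbench_utils.py above. The sensor names and the two objects are placeholders (the object names follow the close_jar entry of the task dictionary below), and a running PyRep scene is required.

from pyrep.objects import VisionSensor
from pyrep.objects.shape import Shape

cam = VisionSensor('cam_front')                       # placeholder RGB-D sensor name
rgb, depth, pcd = get_rgb_depth(cam, get_rgb=True, get_depth=True, get_pcd=True,
                                rgb_noise=None, depth_noise=None,
                                depth_in_meters=False)

mask_cam = VisionSensor('cam_front_mask')             # placeholder mask sensor name
raw_mask = get_mask(mask_cam)                         # (H, W) array of object handles
grasp_obj, target_obj = Shape('jar_lid0'), Shape('jar0')   # placeholder task objects
seg = get_seg_mask([grasp_obj, target_obj], raw_mask)      # 0 = background, 1 and 2 = objects

The remapping in get_seg_mask gives each requested object a small, stable label regardless of its simulator handle.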
8 | 9 | 10 | from rlbench.const import colors 11 | from rlbench.tasks.put_groceries_in_cupboard import GROCERY_NAMES 12 | from rlbench.tasks.place_shape_in_shape_sorter import SHAPE_NAMES 13 | 14 | dustpan_sizes = ['tall', 'short'] 15 | 16 | task_object_dict = { 17 | "meat_off_grill": { 18 | "grasp_object_name": { 19 | 0: "chicken", 20 | 1: "steak" 21 | }, 22 | "target_object_name": "grill", 23 | }, 24 | "turn_tap": { 25 | "grasp_object_name": { 26 | 0: "tap_left", 27 | 1: "tap_right" 28 | }, 29 | "target_object_name": "tap_main", 30 | }, 31 | "close_jar": { 32 | "grasp_object_name": "jar_lid0", #{i: f"jar_lid{i % 2}" for i in range(len(colors))}, 33 | "target_object_name": {i: f"jar{i % 2}" for i in range(len(colors))}, 34 | }, 35 | "reach_and_drag": { 36 | "grasp_object_name": "stick", 37 | "target_object_name": "target0", 38 | }, 39 | "stack_blocks": { # multiple stage 40 | "grasp_object_name": {i: f"stack_blocks_target{i}" for i in range(4)}, 41 | "target_object_name": { 42 | 0: "stack_blocks_target_plane", 43 | 1: f"stack_blocks_target{0}", 44 | 2: f"stack_blocks_target{1}", 45 | 3: f"stack_blocks_target{2}", 46 | }, 47 | }, 48 | "light_bulb_in": { 49 | "grasp_object_name": {i: f"light_bulb{i % 2}" for i in range(len(colors))}, 50 | "target_object_name": "lamp_base", 51 | }, 52 | "put_money_in_safe": { 53 | "grasp_object_name": "dollar_stack", 54 | "target_object_name": "safe_body", 55 | }, 56 | "place_wine_at_rack_location": { 57 | "grasp_object_name": "wine_bottle", 58 | "target_object_name": "rack_top", 59 | }, 60 | "put_groceries_in_cupboard":{ 61 | "grasp_object_name": {i: GROCERY_NAMES[i].replace(' ', '_') for i in range(len(GROCERY_NAMES))}, 62 | "target_object_name": "cupboard", 63 | }, 64 | "place_shape_in_shape_sorter":{ 65 | "grasp_object_name": {i: SHAPE_NAMES[i].replace(' ', '_') for i in range(len(SHAPE_NAMES))}, 66 | "target_object_name": "shape_sorter", 67 | }, 68 | "insert_onto_square_peg":{ 69 | "grasp_object_name": "square_ring", 70 | "target_object_name": "__NONE__", # handled on run time 71 | }, 72 | "stack_cups": { # multiple stage 73 | "grasp_object_name": { 74 | 0: "cup1", 75 | 1: "cup3", 76 | }, 77 | "target_object_name": { 78 | 0: "cup2", 79 | 1: "cup1", 80 | }, 81 | }, 82 | "place_cups": { # multiple stage 83 | "grasp_object_name": {i: f"mug{i}" for i in range(3)}, 84 | "target_object_name": "place_cups_holder_base", 85 | }, 86 | } -------------------------------------------------------------------------------- /foundation_pose/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022-2023, NVIDIA Corporation & affiliates. All rights reserved. 2 | 3 | 4 | ======================================================================= 5 | 6 | 1. Definitions 7 | 8 | "Licensor" means any person or entity that distributes its Work. 9 | 10 | "Software" means the original work of authorship made available under 11 | this License. 12 | 13 | "Work" means the Software and any additions to or derivative works of 14 | the Software that are made available under this License. 15 | 16 | The terms "reproduce," "reproduction," "derivative works," and 17 | "distribution" have the meaning as provided under U.S. copyright law; 18 | provided, however, that for the purposes of this License, derivative 19 | works shall not include works that remain separable from, or merely 20 | link (or bind by name) to the interfaces of, the Work. 
21 | 22 | Works, including the Software, are "made available" under this License 23 | by including in or with the Work either (a) a copyright notice 24 | referencing the applicability of this License to the Work, or (b) a 25 | copy of this License. 26 | 27 | 2. License Grants 28 | 29 | 2.1 Copyright Grant. Subject to the terms and conditions of this 30 | License, each Licensor grants to you a perpetual, worldwide, 31 | non-exclusive, royalty-free, copyright license to reproduce, 32 | prepare derivative works of, publicly display, publicly perform, 33 | sublicense and distribute its Work and any resulting derivative 34 | works in any form. 35 | 36 | 3. Limitations 37 | 38 | 3.1 Redistribution. You may reproduce or distribute the Work only 39 | if (a) you do so under this License, (b) you include a complete 40 | copy of this License with your distribution, and (c) you retain 41 | without modification any copyright, patent, trademark, or 42 | attribution notices that are present in the Work. 43 | 44 | 3.2 Derivative Works. You may specify that additional or different 45 | terms apply to the use, reproduction, and distribution of your 46 | derivative works of the Work ("Your Terms") only if (a) Your Terms 47 | provide that the use limitation in Section 3.3 applies to your 48 | derivative works, and (b) you identify the specific derivative 49 | works that are subject to Your Terms. Notwithstanding Your Terms, 50 | this License (including the redistribution requirements in Section 51 | 3.1) will continue to apply to the Work itself. 52 | 53 | 3.3 Use Limitation. The Work and any derivative works thereof only 54 | may be used or intended for use non-commercially. Notwithstanding 55 | the foregoing, NVIDIA and its affiliates may use the Work and any 56 | derivative works commercially. As used herein, "non-commercially" 57 | means for research or evaluation purposes only. 58 | 59 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 60 | against any Licensor (including any claim, cross-claim or 61 | counterclaim in a lawsuit) to enforce any patents that you allege 62 | are infringed by any Work, then your rights under this License from 63 | such Licensor (including the grant in Section 2.1) will terminate 64 | immediately. 65 | 66 | 3.5 Trademarks. This License does not grant any rights to use any 67 | Licensor�s or its affiliates� names, logos, or trademarks, except 68 | as necessary to reproduce the notices described in this License. 69 | 70 | 3.6 Termination. If you violate any term of this License, then your 71 | rights under this License (including the grant in Section 2.1) will 72 | terminate immediately. 73 | 74 | 4. Disclaimer of Warranty. 75 | 76 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 77 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 78 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 79 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 80 | THIS LICENSE. 81 | 82 | 5. Limitation of Liability. 
83 | 84 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 85 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 86 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 87 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 88 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 89 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 90 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 91 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 92 | THE POSSIBILITY OF SUCH DAMAGES. 93 | 94 | ======================================================================= -------------------------------------------------------------------------------- /foundation_pose/build_all.sh: -------------------------------------------------------------------------------- 1 | DIR=$(pwd) 2 | 3 | cd $DIR/mycpp/ && mkdir -p build && cd build && cmake .. && make -j11 4 | 5 | cd ${DIR} -------------------------------------------------------------------------------- /foundation_pose/build_all_conda.sh: -------------------------------------------------------------------------------- 1 | PROJ_ROOT=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) 2 | 3 | # Install mycpp 4 | cd ${PROJ_ROOT}/mycpp/ && \ 5 | rm -rf build && mkdir -p build && cd build && \ 6 | cmake .. && \ 7 | make -j$(nproc) 8 | 9 | # Install mycuda 10 | # cd ${PROJ_ROOT}/bundlesdf/mycuda && \ 11 | # rm -rf build *egg* *.so && \ 12 | # python -m pip install -e . 13 | 14 | cd ${PROJ_ROOT} -------------------------------------------------------------------------------- /foundation_pose/estimater.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
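Editorial note: build_all.sh and build_all_conda.sh above only compile the mycpp pybind11 extension. The snippet below is a hypothetical post-build sanity check; the build directory on sys.path and the way cluster_poses (the binding exposed by this extension) accepts lists of 4x4 float32 arrays are assumptions inferred from how estimater.py calls it.

import sys
import numpy as np

sys.path.append('foundation_pose/mycpp/build')        # assumed CMake output directory
import mycpp

poses = [np.eye(4, dtype=np.float32) for _ in range(4)]   # four identical hypotheses
symmetry_tfs = [np.eye(4, dtype=np.float32)]
clustered = mycpp.cluster_poses(30.0, 99999.0, poses, symmetry_tfs)
print(len(clustered))                                  # expect 1: duplicate poses are merged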
8 | 9 | 10 | from foundation_pose.Utils import * 11 | from foundation_pose.Utils import erode_depth 12 | from foundation_pose.datareader import * 13 | import itertools 14 | from foundation_pose.learning.training.predict_score import * 15 | from foundation_pose.learning.training.predict_pose_refine import * 16 | import yaml 17 | 18 | 19 | class FoundationPose: 20 | def __init__(self, model_pts, model_normals, symmetry_tfs=None, mesh=None, scorer:ScorePredictor=None, refiner:PoseRefinePredictor=None, cfg=None, glctx=None, debug=0, debug_dir='/home/bowen/debug/novel_pose_debug/'): 21 | self.gt_pose = None 22 | self.ignore_normal_flip = True 23 | self.debug = debug 24 | self.debug_dir = debug_dir 25 | self.cfg = cfg 26 | os.makedirs(debug_dir, exist_ok=True) 27 | 28 | self.reset_object(model_pts, model_normals, symmetry_tfs=symmetry_tfs, mesh=mesh) 29 | self.make_rotation_grid(min_n_views=40, inplane_step=60) 30 | 31 | self.glctx = glctx 32 | 33 | if scorer is not None: 34 | self.scorer = scorer 35 | else: 36 | self.scorer = ScorePredictor() 37 | 38 | if refiner is not None: 39 | self.refiner = refiner 40 | else: 41 | self.refiner = PoseRefinePredictor() 42 | 43 | self.pose_last = None # Used for tracking; per the centered mesh 44 | 45 | 46 | def reset_object(self, model_pts, model_normals, symmetry_tfs=None, mesh=None): 47 | max_xyz = mesh.vertices.max(axis=0) 48 | min_xyz = mesh.vertices.min(axis=0) 49 | self.model_center = (min_xyz+max_xyz)/2 50 | if mesh is not None: 51 | self.mesh_ori = mesh.copy() 52 | mesh = mesh.copy() 53 | mesh.vertices = mesh.vertices - self.model_center.reshape(1,3) 54 | 55 | model_pts = mesh.vertices 56 | self.diameter = compute_mesh_diameter(model_pts=mesh.vertices, n_sample=10000) 57 | self.vox_size = max(self.diameter/20.0, 0.003) 58 | # logging.info(f'self.diameter:{self.diameter}, vox_size:{self.vox_size}') 59 | self.dist_bin = self.vox_size/2 60 | self.angle_bin = 20 # Deg 61 | pcd = toOpen3dCloud(model_pts, normals=model_normals) 62 | pcd = pcd.voxel_down_sample(self.vox_size) 63 | self.max_xyz = np.asarray(pcd.points).max(axis=0) 64 | self.min_xyz = np.asarray(pcd.points).min(axis=0) 65 | self.pts = torch.tensor(np.asarray(pcd.points), dtype=torch.float32, device='cuda') 66 | self.normals = F.normalize(torch.tensor(np.asarray(pcd.normals), dtype=torch.float32, device='cuda'), dim=-1) 67 | # logging.info(f'self.pts:{self.pts.shape}') 68 | self.mesh_path = None 69 | self.mesh = mesh 70 | if self.mesh is not None: 71 | self.mesh_path = f'/tmp/{uuid.uuid4()}.obj' 72 | self.mesh.export(self.mesh_path) 73 | self.mesh_tensors = make_mesh_tensors(self.mesh) 74 | 75 | if symmetry_tfs is None: 76 | self.symmetry_tfs = torch.eye(4).float().cuda()[None] 77 | else: 78 | self.symmetry_tfs = torch.as_tensor(symmetry_tfs, device='cuda', dtype=torch.float) 79 | 80 | # logging.info("reset done") 81 | 82 | 83 | 84 | def get_tf_to_centered_mesh(self): 85 | tf_to_center = torch.eye(4, dtype=torch.float, device='cuda') 86 | tf_to_center[:3,3] = -torch.as_tensor(self.model_center, device='cuda', dtype=torch.float) 87 | return tf_to_center 88 | 89 | 90 | def to_device(self, s='cuda:0'): 91 | for k in self.__dict__: 92 | self.__dict__[k] = self.__dict__[k] 93 | if torch.is_tensor(self.__dict__[k]) or isinstance(self.__dict__[k], nn.Module): 94 | logging.info(f"Moving {k} to device {s}") 95 | self.__dict__[k] = self.__dict__[k].to(s) 96 | for k in self.mesh_tensors: 97 | logging.info(f"Moving {k} to device {s}") 98 | self.mesh_tensors[k] = self.mesh_tensors[k].to(s) 99 | if 
self.refiner is not None: 100 | self.refiner.model.to(s) 101 | if self.scorer is not None: 102 | self.scorer.model.to(s) 103 | if self.glctx is not None: 104 | self.glctx = dr.RasterizeCudaContext(s) 105 | 106 | 107 | 108 | def make_rotation_grid(self, min_n_views=40, inplane_step=60): 109 | cam_in_obs = sample_views_icosphere(n_views=min_n_views) 110 | # logging.info(f'cam_in_obs:{cam_in_obs.shape}') 111 | rot_grid = [] 112 | for i in range(len(cam_in_obs)): 113 | for inplane_rot in np.deg2rad(np.arange(0, 360, inplane_step)): 114 | cam_in_ob = cam_in_obs[i] 115 | R_inplane = euler_matrix(0,0,inplane_rot) 116 | cam_in_ob = cam_in_ob@R_inplane 117 | ob_in_cam = np.linalg.inv(cam_in_ob) 118 | rot_grid.append(ob_in_cam) 119 | 120 | rot_grid = np.asarray(rot_grid) 121 | # logging.info(f"rot_grid:{rot_grid.shape}") 122 | rot_grid = mycpp.cluster_poses(30, 99999, rot_grid, self.symmetry_tfs.data.cpu().numpy()) 123 | rot_grid = np.asarray(rot_grid) 124 | # logging.info(f"after cluster, rot_grid:{rot_grid.shape}") 125 | self.rot_grid = torch.as_tensor(rot_grid, device='cuda', dtype=torch.float) 126 | # logging.info(f"self.rot_grid: {self.rot_grid.shape}") 127 | 128 | 129 | def generate_random_pose_hypo(self, K, rgb, depth, mask, scene_pts=None): 130 | ''' 131 | @scene_pts: torch tensor (N,3) 132 | ''' 133 | ob_in_cams = self.rot_grid.clone() 134 | center = self.guess_translation(depth=depth, mask=mask, K=K) 135 | ob_in_cams[:,:3,3] = torch.tensor(center, device='cuda', dtype=torch.float).reshape(1,3) 136 | return ob_in_cams 137 | 138 | 139 | def guess_translation(self, depth, mask, K): 140 | vs,us = np.where(mask>0) 141 | if len(us)==0: 142 | logging.info(f'mask is all zero') 143 | return np.zeros((3)) 144 | uc = (us.min()+us.max())/2.0 145 | vc = (vs.min()+vs.max())/2.0 146 | valid = mask.astype(bool) & (depth>=MIN_DEPTH) 147 | if not valid.any(): 148 | logging.info(f"valid is empty") 149 | return np.zeros((3)) 150 | 151 | zc = np.median(depth[valid]) 152 | center = (np.linalg.inv(K)@np.asarray([uc,vc,1]).reshape(3,1))*zc 153 | 154 | if self.debug>=2: 155 | pcd = toOpen3dCloud(center.reshape(1,3)) 156 | o3d.io.write_point_cloud(f'{self.debug_dir}/init_center.ply', pcd) 157 | 158 | return center.reshape(3) 159 | 160 | 161 | def register(self, K, rgb, depth, ob_mask, ob_id=None, glctx=None, iteration=5): 162 | '''Copmute pose from given pts to self.pcd 163 | @pts: (N,3) np array, downsampled scene points 164 | ''' 165 | set_seed(0) 166 | 167 | if self.glctx is None: 168 | if glctx is None: 169 | self.glctx = dr.RasterizeCudaContext() 170 | # self.glctx = dr.RasterizeGLContext() 171 | else: 172 | self.glctx = glctx 173 | 174 | depth = erode_depth(depth, radius=2, device='cuda') 175 | depth = bilateral_filter_depth(depth, radius=2, device='cuda') 176 | 177 | if self.debug>=2: 178 | xyz_map = depth2xyzmap(depth, K) 179 | valid = xyz_map[...,2]>=MIN_DEPTH 180 | pcd = toOpen3dCloud(xyz_map[valid], rgb[valid]) 181 | o3d.io.write_point_cloud(f'{self.debug_dir}/scene_raw.ply',pcd) 182 | cv2.imwrite(f'{self.debug_dir}/ob_mask.png', (ob_mask*255.0).clip(0,255)) 183 | 184 | normal_map = None 185 | valid = (depth>=MIN_DEPTH) & (ob_mask>0) 186 | if valid.sum()<4: 187 | logging.info(f'valid too small, return') 188 | pose = np.eye(4) 189 | pose[:3,3] = self.guess_translation(depth=depth, mask=ob_mask, K=K) 190 | return pose 191 | 192 | if self.debug>=2: 193 | imageio.imwrite(f'{self.debug_dir}/color.png', rgb) 194 | cv2.imwrite(f'{self.debug_dir}/depth.png', (depth*1000).astype(np.uint16)) 195 | valid = 
xyz_map[...,2]>=MIN_DEPTH 196 | pcd = toOpen3dCloud(xyz_map[valid], rgb[valid]) 197 | o3d.io.write_point_cloud(f'{self.debug_dir}/scene_complete.ply',pcd) 198 | 199 | self.H, self.W = depth.shape[:2] 200 | self.K = K 201 | self.ob_id = ob_id 202 | self.ob_mask = ob_mask 203 | 204 | poses = self.generate_random_pose_hypo(K=K, rgb=rgb, depth=depth, mask=ob_mask, scene_pts=None) 205 | poses = poses.data.cpu().numpy() 206 | center = self.guess_translation(depth=depth, mask=ob_mask, K=K) 207 | 208 | poses = torch.as_tensor(poses, device='cuda', dtype=torch.float) 209 | poses[:,:3,3] = torch.as_tensor(center.reshape(1,3), device='cuda') 210 | 211 | add_errs = self.compute_add_err_to_gt_pose(poses) 212 | 213 | xyz_map = depth2xyzmap(depth, K) 214 | poses, vis = self.refiner.predict(mesh=self.mesh, mesh_tensors=self.mesh_tensors, rgb=rgb, depth=depth, K=K, ob_in_cams=poses.data.cpu().numpy(), normal_map=normal_map, xyz_map=xyz_map, glctx=self.glctx, mesh_diameter=self.diameter, iteration=iteration, get_vis=self.debug>=2) 215 | if vis is not None: 216 | imageio.imwrite(f'{self.debug_dir}/vis_refiner.png', vis) 217 | 218 | scores, vis = self.scorer.predict(mesh=self.mesh, rgb=rgb, depth=depth, K=K, ob_in_cams=poses.data.cpu().numpy(), xyz_map=xyz_map, normal_map=normal_map, mesh_tensors=self.mesh_tensors, glctx=self.glctx, mesh_diameter=self.diameter, get_vis=self.debug>=2) 219 | if vis is not None: 220 | imageio.imwrite(f'{self.debug_dir}/vis_score.png', vis) 221 | 222 | add_errs = self.compute_add_err_to_gt_pose(poses) 223 | 224 | ids = torch.as_tensor(scores).argsort(descending=True) 225 | scores = scores[ids] 226 | poses = poses[ids] 227 | 228 | 229 | best_pose = poses[0]@self.get_tf_to_centered_mesh() 230 | self.pose_last = poses[0] 231 | self.best_id = ids[0] 232 | 233 | self.poses = poses 234 | self.scores = scores 235 | 236 | return best_pose.data.cpu().numpy() 237 | 238 | 239 | def compute_add_err_to_gt_pose(self, poses): 240 | ''' 241 | @poses: wrt. the centered mesh 242 | ''' 243 | return -torch.ones(len(poses), device='cuda', dtype=torch.float) 244 | 245 | 246 | def track_one(self, rgb, depth, K, iteration, extra={}): 247 | if self.pose_last is None: 248 | logging.info("Please init pose by register first") 249 | raise RuntimeError 250 | 251 | depth = torch.as_tensor(depth, device='cuda', dtype=torch.float) 252 | depth = erode_depth(depth, radius=2, device='cuda') 253 | depth = bilateral_filter_depth(depth, radius=2, device='cuda') 254 | 255 | xyz_map = depth2xyzmap_batch(depth[None], torch.as_tensor(K, dtype=torch.float, device='cuda')[None], zfar=np.inf)[0] 256 | 257 | pose, vis = self.refiner.predict(mesh=self.mesh, mesh_tensors=self.mesh_tensors, rgb=rgb, depth=depth, K=K, ob_in_cams=self.pose_last.reshape(1,4,4).data.cpu().numpy(), normal_map=None, xyz_map=xyz_map, mesh_diameter=self.diameter, glctx=self.glctx, iteration=iteration, get_vis=self.debug>=2) 258 | if self.debug>=2: 259 | extra['vis'] = vis 260 | self.pose_last = pose 261 | return (pose@self.get_tf_to_centered_mesh()).data.cpu().numpy().reshape(4,4) 262 | 263 | 264 | -------------------------------------------------------------------------------- /foundation_pose/learning/datasets/h5_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
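Editorial note: a hedged end-to-end usage sketch of the estimator defined above: register once from an RGB-D frame plus an object mask, then track frame-to-frame. File paths, intrinsics and array shapes are placeholders; the register and track_one call signatures follow the methods in this file.

import numpy as np
import imageio
import trimesh

mesh = trimesh.load('/path/to/object.obj', force='mesh')            # placeholder mesh
est = FoundationPose(model_pts=mesh.vertices, model_normals=mesh.vertex_normals,
                     mesh=mesh, debug=0, debug_dir='/tmp/fp_debug')

K = np.array([[600., 0., 320.], [0., 600., 240.], [0., 0., 1.]])     # placeholder intrinsics
rgb = imageio.imread('/path/to/rgb_000.png')                          # (H, W, 3) uint8
depth = np.load('/path/to/depth_000.npy')                             # (H, W) float, metres
mask = np.load('/path/to/mask_000.npy')                               # (H, W) bool object mask

ob_in_cam = est.register(K=K, rgb=rgb, depth=depth, ob_mask=mask, iteration=5)   # (4, 4)
# later frames reuse the last pose and only run the refiner:
ob_in_cam = est.track_one(rgb=rgb, depth=depth, K=K, iteration=2)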
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | 10 | 11 | import os,sys,h5py,bisect,io,json 12 | code_dir = os.path.dirname(os.path.realpath(__file__)) 13 | sys.path.append(f'{code_dir}/../../../../') 14 | from Utils import * 15 | from learning.datasets.pose_dataset import * 16 | 17 | 18 | 19 | 20 | class PairH5Dataset(torch.utils.data.Dataset): 21 | def __init__(self, cfg, h5_file, mode='train', max_num_key=None, cache_data=None): 22 | self.cfg = cfg 23 | self.h5_file = h5_file 24 | self.mode = mode 25 | 26 | # logging.info(f"self.h5_file:{self.h5_file}") 27 | self.n_perturb = None 28 | self.H_ori = None 29 | self.W_ori = None 30 | self.cache_data = cache_data 31 | 32 | if self.mode=='test': 33 | pass 34 | else: 35 | self.object_keys = [] 36 | key_file = h5_file.replace('.h5','_keys.pkl') 37 | if os.path.exists(key_file): 38 | with open(key_file, 'rb') as ff: 39 | self.object_keys = pickle.load(ff) 40 | logging.info(f'object_keys loaded#:{len(self.object_keys)} from {key_file}') 41 | if max_num_key is not None: 42 | self.object_keys = self.object_keys[:max_num_key] 43 | else: 44 | with h5py.File(h5_file, 'r', libver='latest') as hf: 45 | for k in hf: 46 | self.object_keys.append(k) 47 | if max_num_key is not None and len(self.object_keys)>=max_num_key: 48 | logging.info("break due to max_num_key") 49 | break 50 | 51 | logging.info(f'self.object_keys#:{len(self.object_keys)}, max_num_key:{max_num_key}') 52 | 53 | with h5py.File(h5_file, 'r', libver='latest') as hf: 54 | group = hf[self.object_keys[0]] 55 | cnt = 0 56 | for k_perturb in group: 57 | if 'i_perturb' in k_perturb: 58 | cnt += 1 59 | if 'crop_ratio' in group[k_perturb]: 60 | self.cfg['crop_ratio'] = float(group[k_perturb]['crop_ratio'][()]) 61 | if self.H_ori is None: 62 | if 'H_ori' in group[k_perturb]: 63 | self.H_ori = int(group[k_perturb]['H_ori'][()]) 64 | self.W_ori = int(group[k_perturb]['W_ori'][()]) 65 | else: 66 | self.H_ori = 540 67 | self.W_ori = 720 68 | self.n_perturb = cnt 69 | logging.info(f'self.n_perturb:{self.n_perturb}') 70 | 71 | 72 | def __len__(self): 73 | if self.mode=='test': 74 | return 1 75 | return len(self.object_keys) 76 | 77 | 78 | 79 | def transform_depth_to_xyzmap(self, batch:BatchPoseData, H_ori, W_ori, bound=1): 80 | bs = len(batch.rgbAs) 81 | H,W = batch.rgbAs.shape[-2:] 82 | mesh_radius = batch.mesh_diameters.cuda()/2 83 | tf_to_crops = batch.tf_to_crops.cuda() 84 | crop_to_oris = batch.tf_to_crops.inverse().cuda() #(B,3,3) 85 | batch.poseA = batch.poseA.cuda() 86 | batch.Ks = batch.Ks.cuda() 87 | 88 | if batch.xyz_mapAs is None: 89 | depthAs_ori = kornia.geometry.transform.warp_perspective(batch.depthAs.cuda().expand(bs,-1,-1,-1), crop_to_oris, dsize=(H_ori, W_ori), mode='nearest', align_corners=False) 90 | batch.xyz_mapAs = depth2xyzmap_batch(depthAs_ori[:,0], batch.Ks, zfar=np.inf).permute(0,3,1,2) #(B,3,H,W) 91 | batch.xyz_mapAs = kornia.geometry.transform.warp_perspective(batch.xyz_mapAs, tf_to_crops, dsize=(H,W), mode='nearest', align_corners=False) 92 | batch.xyz_mapAs = batch.xyz_mapAs.cuda() 93 | if self.cfg['normalize_xyz']: 94 | invalid = batch.xyz_mapAs[:,2:3]=2) 99 | batch.xyz_mapAs[invalid.expand(bs,3,-1,-1)] = 0 100 | 101 | if 
batch.xyz_mapBs is None: 102 | depthBs_ori = kornia.geometry.transform.warp_perspective(batch.depthBs.cuda().expand(bs,-1,-1,-1), crop_to_oris, dsize=(H_ori, W_ori), mode='nearest', align_corners=False) 103 | batch.xyz_mapBs = depth2xyzmap_batch(depthBs_ori[:,0], batch.Ks, zfar=np.inf).permute(0,3,1,2) #(B,3,H,W) 104 | batch.xyz_mapBs = kornia.geometry.transform.warp_perspective(batch.xyz_mapBs, tf_to_crops, dsize=(H,W), mode='nearest', align_corners=False) 105 | batch.xyz_mapBs = batch.xyz_mapBs.cuda() 106 | if self.cfg['normalize_xyz']: 107 | invalid = batch.xyz_mapBs[:,2:3]=2) 112 | batch.xyz_mapBs[invalid.expand(bs,3,-1,-1)] = 0 113 | 114 | return batch 115 | 116 | 117 | 118 | def transform_batch(self, batch:BatchPoseData, H_ori, W_ori, bound=1): 119 | '''Transform the batch before feeding to the network 120 | !NOTE the H_ori, W_ori could be different at test time from the training data, and needs to be set 121 | ''' 122 | bs = len(batch.rgbAs) 123 | batch.rgbAs = batch.rgbAs.cuda().float()/255.0 124 | batch.rgbBs = batch.rgbBs.cuda().float()/255.0 125 | 126 | batch = self.transform_depth_to_xyzmap(batch, H_ori, W_ori, bound=bound) 127 | return batch 128 | 129 | 130 | 131 | 132 | class TripletH5Dataset(PairH5Dataset): 133 | def __init__(self, cfg, h5_file, mode, max_num_key=None, cache_data=None): 134 | super().__init__(cfg, h5_file, mode, max_num_key, cache_data=cache_data) 135 | 136 | 137 | def transform_depth_to_xyzmap(self, batch:BatchPoseData, H_ori, W_ori, bound=1): 138 | bs = len(batch.rgbAs) 139 | H,W = batch.rgbAs.shape[-2:] 140 | mesh_radius = batch.mesh_diameters.cuda()/2 141 | tf_to_crops = batch.tf_to_crops.cuda() 142 | crop_to_oris = batch.tf_to_crops.inverse().cuda() #(B,3,3) 143 | batch.poseA = batch.poseA.cuda() 144 | batch.Ks = batch.Ks.cuda() 145 | 146 | if batch.xyz_mapAs is None: 147 | depthAs_ori = kornia.geometry.transform.warp_perspective(batch.depthAs.cuda().expand(bs,-1,-1,-1), crop_to_oris, dsize=(H_ori, W_ori), mode='nearest', align_corners=False) 148 | batch.xyz_mapAs = depth2xyzmap_batch(depthAs_ori[:,0], batch.Ks, zfar=np.inf).permute(0,3,1,2) #(B,3,H,W) 149 | batch.xyz_mapAs = kornia.geometry.transform.warp_perspective(batch.xyz_mapAs, tf_to_crops, dsize=(H,W), mode='nearest', align_corners=False) 150 | batch.xyz_mapAs = batch.xyz_mapAs.cuda() 151 | invalid = batch.xyz_mapAs[:,2:3]<0.1 152 | batch.xyz_mapAs = (batch.xyz_mapAs-batch.poseA[:,:3,3].reshape(bs,3,1,1)) 153 | if self.cfg['normalize_xyz']: 154 | batch.xyz_mapAs *= 1/mesh_radius.reshape(bs,1,1,1) 155 | invalid = invalid.expand(bs,3,-1,-1) | (torch.abs(batch.xyz_mapAs)>=2) 156 | batch.xyz_mapAs[invalid.expand(bs,3,-1,-1)] = 0 157 | 158 | if batch.xyz_mapBs is None: 159 | depthBs_ori = kornia.geometry.transform.warp_perspective(batch.depthBs.cuda().expand(bs,-1,-1,-1), crop_to_oris, dsize=(H_ori, W_ori), mode='nearest', align_corners=False) 160 | batch.xyz_mapBs = depth2xyzmap_batch(depthBs_ori[:,0], batch.Ks, zfar=np.inf).permute(0,3,1,2) #(B,3,H,W) 161 | batch.xyz_mapBs = kornia.geometry.transform.warp_perspective(batch.xyz_mapBs, tf_to_crops, dsize=(H,W), mode='nearest', align_corners=False) 162 | batch.xyz_mapBs = batch.xyz_mapBs.cuda() 163 | invalid = batch.xyz_mapBs[:,2:3]<0.1 164 | batch.xyz_mapBs = (batch.xyz_mapBs-batch.poseA[:,:3,3].reshape(bs,3,1,1)) 165 | if self.cfg['normalize_xyz']: 166 | batch.xyz_mapBs *= 1/mesh_radius.reshape(bs,1,1,1) 167 | invalid = invalid.expand(bs,3,-1,-1) | (torch.abs(batch.xyz_mapBs)>=2) 168 | batch.xyz_mapBs[invalid.expand(bs,3,-1,-1)] = 0 169 | 170 | 
return batch 171 | 172 | 173 | def transform_batch(self, batch:BatchPoseData, H_ori, W_ori, bound=1): 174 | bs = len(batch.rgbAs) 175 | batch.rgbAs = batch.rgbAs.cuda().float()/255.0 176 | batch.rgbBs = batch.rgbBs.cuda().float()/255.0 177 | 178 | batch = self.transform_depth_to_xyzmap(batch, H_ori, W_ori, bound=bound) 179 | return batch 180 | 181 | 182 | 183 | class ScoreMultiPairH5Dataset(TripletH5Dataset): 184 | def __init__(self, cfg, h5_file, mode, max_num_key=None, cache_data=None): 185 | super().__init__(cfg, h5_file, mode, max_num_key, cache_data=cache_data) 186 | if mode in ['train', 'val']: 187 | self.cfg['train_num_pair'] = self.n_perturb 188 | 189 | 190 | class PoseRefinePairH5Dataset(PairH5Dataset): 191 | def __init__(self, cfg, h5_file, mode='train', max_num_key=None, cache_data=None): 192 | super().__init__(cfg=cfg, h5_file=h5_file, mode=mode, max_num_key=max_num_key, cache_data=cache_data) 193 | 194 | if mode!='test': 195 | with h5py.File(h5_file, 'r', libver='latest') as hf: 196 | group = hf[self.object_keys[0]] 197 | for key_perturb in group: 198 | depthA = imageio.imread(group[key_perturb]['depthA'][()]) 199 | depthB = imageio.imread(group[key_perturb]['depthB'][()]) 200 | self.cfg['n_view'] = min(self.cfg['n_view'], depthA.shape[1]//depthB.shape[1]) 201 | logging.info(f'n_view:{self.cfg["n_view"]}') 202 | self.trans_normalizer = group[key_perturb]['trans_normalizer'][()] 203 | if isinstance(self.trans_normalizer, np.ndarray): 204 | self.trans_normalizer = self.trans_normalizer.tolist() 205 | self.rot_normalizer = group[key_perturb]['rot_normalizer'][()]/180.0*np.pi 206 | logging.info(f'self.trans_normalizer:{self.trans_normalizer}, self.rot_normalizer:{self.rot_normalizer}') 207 | break 208 | 209 | 210 | def transform_batch(self, batch:BatchPoseData, H_ori, W_ori, bound=1): 211 | '''Transform the batch before feeding to the network 212 | !NOTE the H_ori, W_ori could be different at test time from the training data, and needs to be set 213 | ''' 214 | bs = len(batch.rgbAs) 215 | batch.rgbAs = batch.rgbAs.cuda().float()/255.0 216 | batch.rgbBs = batch.rgbBs.cuda().float()/255.0 217 | 218 | batch = self.transform_depth_to_xyzmap(batch, H_ori, W_ori, bound=bound) 219 | return batch 220 | 221 | -------------------------------------------------------------------------------- /foundation_pose/learning/datasets/pose_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
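Editorial note: the core of transform_depth_to_xyzmap above is the normalization of the back-projected XYZ map: points are re-centred on the pose hypothesis translation, scaled by the mesh radius when normalize_xyz is set, and invalid or out-of-bound pixels are zeroed. Below is a stripped-down numpy version on synthetic data; the 0.1 m depth cutoff and the bound of 2 mirror the code, everything else is illustrative.

import numpy as np

def normalize_xyz_map(xyz_map, pose_t, mesh_radius, z_min=0.1, bound=2.0):
    # xyz_map: (3, H, W) camera-frame points, pose_t: (3,) hypothesis translation
    invalid = xyz_map[2:3] < z_min
    out = (xyz_map - pose_t.reshape(3, 1, 1)) / mesh_radius
    invalid = np.repeat(invalid, 3, axis=0) | (np.abs(out) >= bound)
    out[invalid] = 0.0
    return out

xyz = np.random.uniform(0.3, 0.7, size=(3, 4, 4)).astype(np.float32)
print(normalize_xyz_map(xyz, pose_t=np.array([0.5, 0.5, 0.5]), mesh_radius=0.1))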
8 | 9 | 10 | import os,sys 11 | from dataclasses import dataclass 12 | from typing import Iterator, List, Optional, Set, Union 13 | import numpy as np 14 | import torch 15 | code_dir = os.path.dirname(os.path.realpath(__file__)) 16 | sys.path.append(f'{code_dir}/../../../../') 17 | from Utils import * 18 | 19 | 20 | @dataclass 21 | class PoseData: 22 | """ 23 | rgb: (h, w, 3) uint8 24 | depth: (bsz, h, w) float32 25 | bbox: (4, ) int 26 | K: (3, 3) float32 27 | """ 28 | rgb: np.ndarray = None 29 | bbox: np.ndarray = None 30 | K: np.ndarray = None 31 | depth: Optional[np.ndarray] = None 32 | object_data = None 33 | mesh_diameter: float = None 34 | rgbA: np.ndarray = None 35 | rgbB: np.ndarray = None 36 | depthA: np.ndarray = None 37 | depthB: np.ndarray = None 38 | maskA = None 39 | maskB = None 40 | poseA: np.ndarray = None #(4,4) 41 | target: float = None 42 | 43 | def __init__(self, rgbA=None, rgbB=None, depthA=None, depthB=None, maskA=None, maskB=None, normalA=None, normalB=None, xyz_mapA=None, xyz_mapB=None, poseA=None, poseB=None, K=None, target=None, mesh_diameter=None, tf_to_crop=None, crop_mask=None, model_pts=None, label=None, model_scale=None): 44 | self.rgbA = rgbA #(H,W,3) or (H,W*n_view,3) when multiview 45 | self.rgbB = rgbB 46 | self.depthA = depthA 47 | self.depthB = depthB 48 | self.poseA = poseA 49 | self.poseB = poseB 50 | self.maskA = maskA 51 | self.maskB = maskB 52 | self.crop_mask = crop_mask 53 | self.normalA = normalA 54 | self.normalB = normalB 55 | self.xyz_mapA = xyz_mapA 56 | self.xyz_mapB = xyz_mapB 57 | self.target = target 58 | self.K = K 59 | self.mesh_diameter = mesh_diameter 60 | self.tf_to_crop = tf_to_crop 61 | self.model_pts = model_pts 62 | self.label = label 63 | self.model_scale = model_scale 64 | 65 | 66 | @dataclass 67 | class BatchPoseData: 68 | """ 69 | rgbs: (bsz, 3, h, w) torch tensor uint8 70 | depths: (bsz, h, w) float32 71 | bboxes: (bsz, 4) int 72 | K: (bsz, 3, 3) float32 73 | """ 74 | 75 | rgbs: torch.Tensor = None 76 | object_datas = None 77 | bboxes: torch.Tensor = None 78 | K: torch.Tensor = None 79 | depths: Optional[torch.Tensor] = None 80 | rgbAs = None 81 | rgbBs = None 82 | depthAs = None 83 | depthBs = None 84 | normalAs = None 85 | normalBs = None 86 | poseA = None #(B,4,4) 87 | poseB = None 88 | targets = None # Score targets, torch tensor (B) 89 | 90 | def __init__(self, rgbAs=None, rgbBs=None, depthAs=None, depthBs=None, normalAs=None, normalBs=None, maskAs=None, maskBs=None, poseA=None, poseB=None, xyz_mapAs=None, xyz_mapBs=None, tf_to_crops=None, Ks=None, crop_masks=None, model_pts=None, mesh_diameters=None, labels=None): 91 | self.rgbAs = rgbAs 92 | self.rgbBs = rgbBs 93 | self.depthAs = depthAs 94 | self.depthBs = depthBs 95 | self.normalAs = normalAs 96 | self.normalBs = normalBs 97 | self.poseA = poseA 98 | self.poseB = poseB 99 | self.maskAs = maskAs 100 | self.maskBs = maskBs 101 | self.xyz_mapAs = xyz_mapAs 102 | self.xyz_mapBs = xyz_mapBs 103 | self.tf_to_crops = tf_to_crops 104 | self.crop_masks = crop_masks 105 | self.Ks = Ks 106 | self.model_pts = model_pts 107 | self.mesh_diameters = mesh_diameters 108 | self.labels = labels 109 | 110 | 111 | def pin_memory(self) -> "BatchPoseData": 112 | for k in self.__dict__: 113 | if self.__dict__[k] is not None: 114 | try: 115 | self.__dict__[k] = self.__dict__[k].pin_memory() 116 | except Exception as e: 117 | pass 118 | return self 119 | 120 | def cuda(self): 121 | for k in self.__dict__: 122 | if self.__dict__[k] is not None: 123 | try: 124 | self.__dict__[k] = 
self.__dict__[k].cuda() 125 | except: 126 | pass 127 | return self 128 | 129 | def select_by_indices(self, ids): 130 | out = BatchPoseData() 131 | for k in self.__dict__: 132 | if self.__dict__[k] is not None: 133 | out.__dict__[k] = self.__dict__[k][ids.to(self.__dict__[k].device)] 134 | return out 135 | 136 | -------------------------------------------------------------------------------- /foundation_pose/learning/models/network_modules.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | 10 | import os,sys,copy,math,tqdm 11 | import numpy as np 12 | dir_path = os.path.dirname(os.path.realpath(__file__)) 13 | sys.path.append(dir_path) 14 | import torch.nn.functional as F 15 | import torch 16 | import torch.nn as nn 17 | import time 18 | import cv2 19 | sys.path.append(f'{dir_path}/../../../../') 20 | from Utils import * 21 | 22 | 23 | 24 | class ConvBN(nn.Module): 25 | def __init__(self, C_in, C_out, kernel_size=3, stride=1, groups=1, bias=True,dilation=1,): 26 | super().__init__() 27 | padding = (kernel_size - 1) // 2 28 | self.net = nn.Sequential( 29 | nn.Conv2d(C_in, C_out, kernel_size, stride, padding, groups=groups, bias=bias,dilation=dilation), 30 | nn.BatchNorm2d(C_out), 31 | ) 32 | 33 | def forward(self, x): 34 | return self.net(x) 35 | 36 | 37 | class ConvBNReLU(nn.Module): 38 | def __init__(self, C_in, C_out, kernel_size=3, stride=1, groups=1, bias=True,dilation=1, norm_layer=nn.BatchNorm2d): 39 | super().__init__() 40 | padding = (kernel_size - 1) // 2 41 | layers = [ 42 | nn.Conv2d(C_in, C_out, kernel_size, stride, padding, groups=groups, bias=bias,dilation=dilation), 43 | ] 44 | if norm_layer is not None: 45 | layers.append(norm_layer(C_out)) 46 | layers.append(nn.ReLU(inplace=True)) 47 | self.net = nn.Sequential(*layers) 48 | 49 | def forward(self, x): 50 | return self.net(x) 51 | 52 | 53 | class ConvPadding(nn.Module): 54 | def __init__(self,C_in, C_out, kernel_size=3, stride=1, groups=1, bias=True,dilation=1): 55 | super(ConvPadding, self).__init__() 56 | padding = (kernel_size - 1) // 2 57 | self.conv = nn.Conv2d(C_in, C_out, kernel_size, stride, padding, groups=groups, bias=bias,dilation=dilation) 58 | 59 | def forward(self,x): 60 | return self.conv(x) 61 | 62 | 63 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1, bias=False): 64 | """3x3 convolution with padding""" 65 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 66 | padding=dilation, groups=groups, bias=bias, dilation=dilation) 67 | 68 | 69 | def conv1x1(in_planes, out_planes, stride=1, bias=False): 70 | """1x1 convolution""" 71 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias) 72 | 73 | class ResnetBasicBlock(nn.Module): 74 | __constants__ = ['downsample'] 75 | 76 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=nn.BatchNorm2d, bias=False): 77 | super().__init__() 78 | self.norm_layer = norm_layer 79 | if groups != 1 or base_width != 64: 80 | raise ValueError('BasicBlock only supports groups=1 
and base_width=64') 81 | if dilation > 1: 82 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 83 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 84 | self.conv1 = conv3x3(inplanes, planes, stride,bias=bias) 85 | if self.norm_layer is not None: 86 | self.bn1 = norm_layer(planes) 87 | self.relu = nn.ReLU(inplace=True) 88 | self.conv2 = conv3x3(planes, planes,bias=bias) 89 | if self.norm_layer is not None: 90 | self.bn2 = norm_layer(planes) 91 | self.downsample = downsample 92 | self.stride = stride 93 | 94 | def forward(self, x): 95 | identity = x 96 | 97 | out = self.conv1(x) 98 | if self.norm_layer is not None: 99 | out = self.bn1(out) 100 | out = self.relu(out) 101 | 102 | out = self.conv2(out) 103 | if self.norm_layer is not None: 104 | out = self.bn2(out) 105 | 106 | if self.downsample is not None: 107 | identity = self.downsample(x) 108 | out += identity 109 | out = self.relu(out) 110 | 111 | return out 112 | 113 | 114 | 115 | class PositionalEmbedding(nn.Module): 116 | def __init__(self, d_model, max_len=512): 117 | super().__init__() 118 | 119 | # Compute the positional encodings once in log space. 120 | pe = torch.zeros(max_len, d_model).float() 121 | pe.require_grad = False 122 | 123 | position = torch.arange(0, max_len).float().unsqueeze(1) #(N,1) 124 | div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()[None] 125 | 126 | pe[:, 0::2] = torch.sin(position * div_term) #(N, d_model/2) 127 | pe[:, 1::2] = torch.cos(position * div_term) 128 | 129 | pe = pe.unsqueeze(0) 130 | self.register_buffer('pe', pe) #(1, max_len, D) 131 | 132 | 133 | def forward(self, x): 134 | ''' 135 | @x: (B,N,D) 136 | ''' 137 | return x + self.pe[:, :x.size(1)] 138 | 139 | -------------------------------------------------------------------------------- /foundation_pose/learning/models/refine_network.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
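Editorial note: a quick shape check of the building blocks defined in network_modules.py above, assuming ResnetBasicBlock and PositionalEmbedding are in scope (for example via from network_modules import *, as the refine and score networks below do).

import torch
# assumes ResnetBasicBlock and PositionalEmbedding from network_modules.py are in scope

block = ResnetBasicBlock(128, 128, bias=True)        # stride 1, identity residual path
x = torch.randn(2, 128, 40, 40)
print(block(x).shape)                                 # torch.Size([2, 128, 40, 40])

pe = PositionalEmbedding(d_model=512, max_len=400)
tokens = torch.randn(2, 400, 512)                     # a 20x20 feature map flattened to 400 tokens
print(pe(tokens).shape)                               # torch.Size([2, 400, 512])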
8 | 9 | 10 | import os,sys 11 | import numpy as np 12 | code_dir = os.path.dirname(os.path.realpath(__file__)) 13 | sys.path.append(code_dir) 14 | sys.path.append(f'{code_dir}/../../../../') 15 | from Utils import * 16 | import torch.nn.functional as F 17 | import torch 18 | import torch.nn as nn 19 | import cv2 20 | from functools import partial 21 | from network_modules import * 22 | from Utils import * 23 | 24 | 25 | 26 | class RefineNet(nn.Module): 27 | def __init__(self, cfg=None, c_in=4, n_view=1): 28 | super().__init__() 29 | self.cfg = cfg 30 | if self.cfg.use_BN: 31 | norm_layer = nn.BatchNorm2d 32 | norm_layer1d = nn.BatchNorm1d 33 | else: 34 | norm_layer = None 35 | norm_layer1d = None 36 | 37 | self.encodeA = nn.Sequential( 38 | ConvBNReLU(C_in=c_in,C_out=64,kernel_size=7,stride=2, norm_layer=norm_layer), 39 | ConvBNReLU(C_in=64,C_out=128,kernel_size=3,stride=2, norm_layer=norm_layer), 40 | ResnetBasicBlock(128,128,bias=True, norm_layer=norm_layer), 41 | ResnetBasicBlock(128,128,bias=True, norm_layer=norm_layer), 42 | ) 43 | 44 | self.encodeAB = nn.Sequential( 45 | ResnetBasicBlock(256,256,bias=True, norm_layer=norm_layer), 46 | ResnetBasicBlock(256,256,bias=True, norm_layer=norm_layer), 47 | ConvBNReLU(256,512,kernel_size=3,stride=2, norm_layer=norm_layer), 48 | ResnetBasicBlock(512,512,bias=True, norm_layer=norm_layer), 49 | ResnetBasicBlock(512,512,bias=True, norm_layer=norm_layer), 50 | ) 51 | 52 | embed_dim = 512 53 | num_heads = 4 54 | self.pos_embed = PositionalEmbedding(d_model=embed_dim, max_len=400) 55 | 56 | self.trans_head = nn.Sequential( 57 | nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=512, batch_first=True), 58 | nn.Linear(512, 3), 59 | ) 60 | 61 | if self.cfg['rot_rep']=='axis_angle': 62 | rot_out_dim = 3 63 | elif self.cfg['rot_rep']=='6d': 64 | rot_out_dim = 6 65 | else: 66 | raise RuntimeError 67 | self.rot_head = nn.Sequential( 68 | nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=512, batch_first=True), 69 | nn.Linear(512, rot_out_dim), 70 | ) 71 | 72 | 73 | def forward(self, A, B): 74 | """ 75 | @A: (B,C,H,W) 76 | """ 77 | bs = len(A) 78 | output = {} 79 | 80 | x = torch.cat([A,B], dim=0) 81 | x = self.encodeA(x) 82 | a = x[:bs] 83 | b = x[bs:] 84 | 85 | ab = torch.cat((a,b),1).contiguous() 86 | ab = self.encodeAB(ab) #(B,C,H,W) 87 | 88 | ab = self.pos_embed(ab.reshape(bs, ab.shape[1], -1).permute(0,2,1)) 89 | 90 | output['trans'] = self.trans_head(ab).mean(dim=1) 91 | output['rot'] = self.rot_head(ab).mean(dim=1) 92 | 93 | return output 94 | -------------------------------------------------------------------------------- /foundation_pose/learning/models/score_network.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
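Editorial note: a shape walk-through of RefineNet above. The config fields are the two the class reads (use_BN and rot_rep); the 6-channel input (RGB plus normalized XYZ map) and the 160x160 crop size are assumptions consistent with the training configs elsewhere in this package, not values taken from this file.

import torch
from omegaconf import OmegaConf
# assumes RefineNet from refine_network.py above is importable

cfg = OmegaConf.create({'use_BN': True, 'rot_rep': 'axis_angle'})
net = RefineNet(cfg=cfg, c_in=6)
A = torch.randn(2, 6, 160, 160)               # rendered crop of the current pose hypothesis
B = torch.randn(2, 6, 160, 160)               # observed crop
out = net(A, B)
print(out['trans'].shape, out['rot'].shape)   # torch.Size([2, 3]) torch.Size([2, 3])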
8 | 9 | 10 | import os,sys 11 | import numpy as np 12 | code_dir = os.path.dirname(os.path.realpath(__file__)) 13 | sys.path.append(code_dir) 14 | sys.path.append(f'{code_dir}/../../../../') 15 | from Utils import * 16 | from functools import partial 17 | import torch.nn.functional as F 18 | import torch 19 | import torch.nn as nn 20 | import cv2 21 | from network_modules import * 22 | from Utils import * 23 | 24 | 25 | 26 | 27 | class ScoreNetMultiPair(nn.Module): 28 | def __init__(self, cfg=None, c_in=4): 29 | super().__init__() 30 | self.cfg = cfg 31 | if self.cfg.use_BN: 32 | norm_layer = nn.BatchNorm2d 33 | else: 34 | norm_layer = None 35 | 36 | self.encoderA = nn.Sequential( 37 | ConvBNReLU(C_in=c_in,C_out=64,kernel_size=7,stride=2, norm_layer=norm_layer), 38 | ConvBNReLU(C_in=64,C_out=128,kernel_size=3,stride=2, norm_layer=norm_layer), 39 | ResnetBasicBlock(128,128,bias=True, norm_layer=norm_layer), 40 | ResnetBasicBlock(128,128,bias=True, norm_layer=norm_layer), 41 | ) 42 | 43 | self.encoderAB = nn.Sequential( 44 | ResnetBasicBlock(256,256,bias=True, norm_layer=norm_layer), 45 | ResnetBasicBlock(256,256,bias=True, norm_layer=norm_layer), 46 | ConvBNReLU(256,512,kernel_size=3,stride=2, norm_layer=norm_layer), 47 | ResnetBasicBlock(512,512,bias=True, norm_layer=norm_layer), 48 | ResnetBasicBlock(512,512,bias=True, norm_layer=norm_layer), 49 | ) 50 | 51 | embed_dim = 512 52 | num_heads = 4 53 | self.att = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, bias=True, batch_first=True) 54 | self.att_cross = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, bias=True, batch_first=True) 55 | 56 | self.pos_embed = PositionalEmbedding(d_model=embed_dim, max_len=400) 57 | self.linear = nn.Linear(embed_dim, 1) 58 | 59 | 60 | def extract_feat(self, A, B): 61 | """ 62 | @A: (B*L,C,H,W) L is num of pairs 63 | """ 64 | bs = A.shape[0] # B*L 65 | 66 | x = torch.cat([A,B], dim=0) 67 | x = self.encoderA(x) 68 | a = x[:bs] 69 | b = x[bs:] 70 | ab = torch.cat((a,b), dim=1) 71 | ab = self.encoderAB(ab) 72 | ab = self.pos_embed(ab.reshape(bs, ab.shape[1], -1).permute(0,2,1)) 73 | ab, _ = self.att(ab, ab, ab) 74 | return ab.mean(dim=1).reshape(bs,-1) 75 | 76 | 77 | def forward(self, A, B, L): 78 | """ 79 | @A: (B*L,C,H,W) L is num of pairs 80 | @L: num of pairs 81 | """ 82 | output = {} 83 | bs = A.shape[0]//L 84 | feats = self.extract_feat(A, B) #(B*L, C) 85 | x = feats.reshape(bs,L,-1) 86 | x, _ = self.att_cross(x, x, x) 87 | 88 | output['score_logit'] = self.linear(x).reshape(bs,L) # (B,L) 89 | 90 | return output 91 | -------------------------------------------------------------------------------- /foundation_pose/learning/training/training_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
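Editorial note: ScoreNetMultiPair above scores L pose hypotheses per image jointly: each rendered/observed pair is encoded on its own, then the L per-pair features attend to one another before the final linear head. A small shape sketch follows; batch size, L and crop size are illustrative, with the same 6-channel input assumption as for the refiner.

import torch
from omegaconf import OmegaConf
# assumes ScoreNetMultiPair from score_network.py above is importable

cfg = OmegaConf.create({'use_BN': True})
net = ScoreNetMultiPair(cfg=cfg, c_in=6)
B, L = 2, 5                                    # 2 frames, 5 pose hypotheses each
rendered = torch.randn(B * L, 6, 160, 160)
observed = torch.randn(B * L, 6, 160, 160)
logits = net(rendered, observed, L)['score_logit']
print(logits.shape)                            # torch.Size([2, 5]); argmax gives the best hypothesis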
8 | 9 | 10 | import os,sys 11 | from dataclasses import dataclass, field 12 | from typing import List, Optional, Tuple,Union 13 | import numpy as np 14 | import omegaconf 15 | import torch 16 | 17 | 18 | @dataclass 19 | class TrainingConfig(omegaconf.dictconfig.DictConfig): 20 | input_resize: tuple = (160, 160) 21 | normalize_xyz:Optional[bool] = True 22 | use_mask:Optional[bool] = False 23 | crop_ratio:Optional[float] = None 24 | split_objects_across_gpus: bool = True 25 | max_num_key: Optional[int] = None 26 | use_normal:bool = False 27 | n_view:int = 1 28 | zfar:float = np.inf 29 | c_in:int = 6 30 | train_num_pair:Optional[int] = None 31 | make_pair_online:Optional[bool] = False 32 | render_backend:Optional[str] = 'nvdiffrast' 33 | 34 | # Run management 35 | run_id: Optional[str] = None 36 | exp_name:Optional[str] = None 37 | resume_run_id: Optional[str] = None 38 | save_dir: Optional[str] = None 39 | batch_size: int = 64 40 | epoch_size: int = 115200 41 | val_size: int = 1280 42 | n_epochs: int = 25 43 | save_epoch_interval: int = 100 44 | n_dataloader_workers: int = 20 45 | n_rendering_workers: int = 1 46 | gradient_max_norm:float = np.inf 47 | max_step_per_epoch: Optional[int] = 25000 48 | 49 | # Network 50 | use_BN:bool = True 51 | loss_type:Optional[str] = 'pairwise_valid' 52 | 53 | # Optimizer 54 | optimizer: str = "adam" 55 | weight_decay: float = 0.0 56 | clip_grad_norm: float = np.inf 57 | lr: float = 0.0001 58 | warmup_step: int = -1 # -1 means disable 59 | n_epochs_warmup: int = 1 60 | 61 | # Visualization 62 | vis_interval: Optional[int] = 1000 63 | 64 | debug: Optional[bool] = None 65 | 66 | 67 | 68 | @dataclass 69 | class TrainRefinerConfig: 70 | # Datasets 71 | input_resize: tuple = (160, 160) #(W,H) 72 | crop_ratio:Optional[float] = None 73 | max_num_key: Optional[int] = None 74 | use_normal:bool = False 75 | use_mask:Optional[bool] = False 76 | normal_uint8:bool = False 77 | normalize_xyz:Optional[bool] = True 78 | trans_normalizer:Optional[list] = None 79 | rot_normalizer:Optional[float] = None 80 | c_in:int = 6 81 | n_view:int = 1 82 | zfar:float = np.inf 83 | trans_rep:str = 'tracknet' # tracknet/deepim 84 | rot_rep:Optional[str] = 'axis_angle' # 6d/axis_angle 85 | save_dir: Optional[str] = None 86 | 87 | # Run management 88 | run_id: Optional[str] = None 89 | exp_name:Optional[str] = None 90 | batch_size: int = 64 91 | use_BN:bool = True 92 | optimizer: str = "adam" 93 | weight_decay: float = 0.0 94 | clip_grad_norm: float = np.inf 95 | lr: float = 0.0001 96 | warmup_step: int = -1 97 | loss_type:str = 'l2' # l1/l2/add 98 | 99 | vis_interval: Optional[int] = 1000 100 | debug: Optional[bool] = None 101 | 102 | 103 | -------------------------------------------------------------------------------- /foundation_pose/mycpp/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.15) 2 | project(mycpp) 3 | 4 | 5 | set(CMAKE_BUILD_TYPE Release) 6 | set(CMAKE_EXPORT_COMPILE_COMMANDS ON) 7 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -fopenmp -g3 -O3") 8 | 9 | 10 | find_package(Boost REQUIRED COMPONENTS system program_options) 11 | find_package(OpenMP REQUIRED) 12 | find_package(Eigen3 REQUIRED) 13 | find_package(pybind11 REQUIRED) 14 | 15 | include_directories( 16 | include 17 | ${BLAS_INCLUDE_DIR} 18 | ) 19 | 20 | file(GLOB MY_SRC ${PROJECT_SOURCE_DIR}/src/*.cpp) 21 | 22 | set(PYBIND11_CPP_STANDARD -std=c++14) 23 | 24 | pybind11_add_module(mycpp src/app/pybind_api.cpp ${MY_SRC}) 25 | 
target_link_libraries(mycpp PRIVATE ${Boost_LIBRARIES} ${OpenMP_CXX_FLAGS} Eigen3::Eigen) -------------------------------------------------------------------------------- /foundation_pose/mycpp/include/Utils.h: -------------------------------------------------------------------------------- 1 | /*Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | */ 9 | 10 | 11 | #pragma once 12 | 13 | // STL 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include 31 | #include 32 | #include 33 | #include 34 | #include 35 | #include 36 | #include 37 | 38 | 39 | using vectorMatrix4f = std::vector>; 40 | 41 | 42 | namespace Utils 43 | { 44 | float rotationGeodesicDistance(const Eigen::Matrix3f &R1, const Eigen::Matrix3f &R2); 45 | 46 | } // namespace Utils 47 | -------------------------------------------------------------------------------- /foundation_pose/mycpp/src/Utils.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # NVIDIA CORPORATION and its licensors retain all intellectual property 5 | # and proprietary rights in and to this software, related documentation 6 | # and any modifications thereto. Any use, reproduction, disclosure or 7 | # distribution of this software and related documentation without an express 8 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 9 | */ 10 | 11 | 12 | #include "Utils.h" 13 | 14 | 15 | 16 | namespace Utils 17 | { 18 | 19 | 20 | // Difference angle in radian 21 | float rotationGeodesicDistance(const Eigen::Matrix3f &R1, const Eigen::Matrix3f &R2) 22 | { 23 | float cos = ((R1 * R2.transpose()).trace()-1) / 2.0; 24 | cos = std::max(std::min(cos, 1.0f), -1.0f); 25 | return std::acos(cos); 26 | } 27 | 28 | 29 | } // namespace Utils 30 | -------------------------------------------------------------------------------- /foundation_pose/mycpp/src/app/pybind_api.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # NVIDIA CORPORATION and its licensors retain all intellectual property 5 | # and proprietary rights in and to this software, related documentation 6 | # and any modifications thereto. Any use, reproduction, disclosure or 7 | # distribution of this software and related documentation without an express 8 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
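Editorial note: for reference, a numpy equivalent of Utils::rotationGeodesicDistance above; it returns the angle in radians between two rotation matrices, with the cosine clamped for numerical safety exactly as the C++ does.

import numpy as np

def rotation_geodesic_distance(R1, R2):
    # angle of the relative rotation R1 @ R2.T, in radians
    cos = (np.trace(R1 @ R2.T) - 1.0) / 2.0
    return float(np.arccos(np.clip(cos, -1.0, 1.0)))

print(rotation_geodesic_distance(np.eye(3), np.eye(3)))   # 0.0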
9 | */
10 | 
11 | 
12 | #include "Utils.h"
13 | #include <pybind11/pybind11.h>
14 | #include <pybind11/numpy.h>
15 | #include <pybind11/eigen.h>
16 | #include <pybind11/stl.h>
17 | 
18 | namespace py = pybind11;
19 | 
20 | 
21 | 
22 | //@angle_diff: unit is degree
23 | //@dist_diff: unit is meter
24 | vectorMatrix4f cluster_poses(float angle_diff, float dist_diff, const vectorMatrix4f &poses_in, const vectorMatrix4f &symmetry_tfs)
25 | {
26 |   printf("num original candidates = %d\n",poses_in.size());
27 |   vectorMatrix4f poses_out;
28 |   poses_out.push_back(poses_in[0]);   // now it becomes the pose clusters
29 | 
30 |   const float radian_thres = angle_diff/180.0*M_PI;
31 | 
32 |   for (int i=1;i<poses_in.size();i++)
33 |   {
34 |     bool isnew = true;
35 |     Eigen::Matrix4f cur_pose = poses_in[i];
36 | 
37 |     for (const auto &cluster: poses_out)
38 |     {
39 |       Eigen::Vector3f t0 = cluster.block(0,3,3,1);
40 |       Eigen::Vector3f t1 = cur_pose.block(0,3,3,1);
41 |       if ((t0-t1).norm()>=dist_diff)
42 |       {
43 |         continue;
44 |       }
45 | 
46 |       /////////// Remove symmetry
47 |       for (const auto &tf: symmetry_tfs)
48 |       {
49 |         Eigen::Matrix4f cur_pose_tmp = cur_pose*tf;
50 |         float rot_diff = Utils::rotationGeodesicDistance(cur_pose_tmp.block(0,0,3,3), cluster.block(0,0,3,3));
51 |         if (rot_diff < radian_thres)
52 |         {
53 |           isnew = false;
54 |           break;
55 |         }
56 |       }
57 | 
58 |       if (!isnew) break;
59 |     }
60 | 
61 |     if (isnew)
62 |     {
63 |       poses_out.push_back(poses_in[i]);
64 |     }
65 |   }
66 | 
67 |   printf("num of pose after clustering: %d\n",poses_out.size());
68 |   return poses_out;
69 | }
70 | 
71 | 
72 | 
73 | 
74 | 
75 | PYBIND11_MODULE(mycpp, m)
76 | {
77 |   m.def("cluster_poses", &cluster_poses, py::call_guard<py::gil_scoped_release>());
78 | }
-------------------------------------------------------------------------------- /foundation_pose/setup.py: --------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | setup(
4 |     name='foundation_pose',
5 |     packages=find_packages(),
6 | )
7 | 
-------------------------------------------------------------------------------- /foundation_pose/wrapper.py: --------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, os.getcwd())
4 | import numpy as np
5 | import trimesh
6 | from foundation_pose.estimater import ScorePredictor, PoseRefinePredictor, FoundationPose
7 | from foundation_pose.Utils import draw_posed_3d_box, draw_xyz_axis, trimesh_add_pure_colored_texture
8 | 
9 | 
10 | class FoundationPoseWrapper:
11 |     def __init__(self, mesh_dir, debug_dir=None) -> None:
12 |         # load object mesh
13 |         self.debug_dir = "./debug" #debug_dir
14 |         self.mesh_dir = mesh_dir
15 |         self.mesh = None
16 | 
17 |         self.grasp_obj_name = None
18 |         self.cur_grasp_obj_name = None
19 | 
20 |     def update_grasp_obj_name(self, obj_name):
21 |         self.grasp_obj_name = obj_name
22 | 
23 |     def load_mesh(self):
24 |         assert self.grasp_obj_name is not None
25 |         mesh_path = os.path.join(self.mesh_dir, self.grasp_obj_name + ".obj")
26 |         print(mesh_path)
27 | 
28 |         mesh = trimesh.load(mesh_path)
29 |         if isinstance(mesh, trimesh.Scene):
30 |             mesh = trimesh.load(mesh_path, force='mesh', skip_materials=False)
31 | 
32 |         # work around a material issue (vertex colors wrongly recognized as a material image)
33 |         if mesh.visual.material is not None:
34 |             if mesh.visual.material.image is None: # no texture
35 |                 mesh = trimesh.load(mesh_path, force='mesh', skip_materials=True) # use vertex color
36 |         # mesh.show()
37 | 
38 |         if "light_bulb_in" in mesh_path:
39 |             mesh = trimesh.load(mesh_path, force='mesh', skip_materials=True) # use vertex color
40 | 
41 |         mesh.vertices = mesh.vertices - np.mean(mesh.vertices, axis=0)
42 | 
43 |         self.mesh = mesh
44 |         self.cur_grasp_obj_name = self.grasp_obj_name
45 | 
46 |     def create_estimator(self, debug_level=-1):
47 |         # load the mesh if it has not been loaded yet or the
grasp_obj_name changed 48 | if (self.mesh is None) or not (self.cur_grasp_obj_name == self.grasp_obj_name): 49 | self.load_mesh() 50 | 51 | debug_level = 0 if (self.debug_dir is None) or (debug_level < 0) else debug_level 52 | 53 | scorer = ScorePredictor() 54 | refiner = PoseRefinePredictor() 55 | return FoundationPose( 56 | model_pts=self.mesh.vertices, model_normals=self.mesh.vertex_normals, mesh=self.mesh, 57 | scorer=scorer, refiner=refiner, 58 | debug_dir=self.debug_dir, debug=debug_level, 59 | ) 60 | 61 | 62 | class FoundationPoseWrapperReal: 63 | def __init__(self) -> None: 64 | # load object mesh 65 | self.debug_dir = "./debug" #debug_dir 66 | self.mesh = None 67 | 68 | def create_estimator(self, debug_level=-1): 69 | assert self.mesh is not None 70 | 71 | debug_level = 0 if (self.debug_dir is None) or (debug_level < 0) else debug_level 72 | 73 | scorer = ScorePredictor() 74 | refiner = PoseRefinePredictor() 75 | return FoundationPose( 76 | model_pts=self.mesh.vertices, model_normals=self.mesh.vertex_normals, mesh=self.mesh, 77 | scorer=scorer, refiner=refiner, 78 | debug_dir=self.debug_dir, debug=debug_level, 79 | ) -------------------------------------------------------------------------------- /fp_requirements.txt: -------------------------------------------------------------------------------- 1 | # JupyterLab 2 | # jupyterlab==4.1.5 3 | # ipywidgets==8.1.2 4 | 5 | 6 | # Common Libraries 7 | numpy==1.23.5 #1.26.4 8 | scipy==1.10.1 #1.12.0 9 | # scikit-learn==1.3.2 #1.4.1.post1 10 | # pyyaml==6.0.1 11 | ruamel.yaml==0.18.6 12 | ninja==1.11.1.1 13 | h5py==3.11.0 #3.10.0 14 | numba==0.58.1 #0.56.4 #0.59.1 15 | pybind11==2.12.0 16 | 17 | # CV Libraries 18 | imageio==2.34.2 #2.34.0 19 | imageio-ffmpeg==0.5.1 20 | #opencv-python==4.9.0.80 21 | opencv-python-headless==4.10.0.84 22 | # opencv-contrib-python==4.9.0.80 23 | # plotly==5.22.0 #5.20.0 24 | open3d==0.18.0 25 | # pyglet==1.5.28 26 | # pysdf==0.1.9 27 | 28 | # Trimesh 29 | trimesh==4.2.2 30 | # xatlas==0.0.9 31 | # rtree==1.2.0 32 | 33 | # PyRender 34 | # pyrender==0.1.45 35 | # pyOpenGL>=3.1.0 36 | # pyOpenGL_accelerate>=3.1.0 37 | 38 | meshcat==0.3.2 39 | webdataset==0.2.86 40 | omegaconf==2.3.0 41 | # pypng==0.20220715.0 42 | # Panda3D==1.10.14 43 | # simplejson==3.19.2 44 | # bokeh==3.4.0 45 | # roma==1.4.4 46 | seaborn==0.13.2 47 | # pin==2.7.0 48 | # openpyxl==3.1.2 49 | # torchnet==0.0.4 50 | wandb==0.17.4 #0.16.5 51 | # colorama==0.4.6 52 | # GPUtil==1.4.0 53 | # imgaug==0.4.0 54 | # xlsxwriter==3.2.0 55 | # timm==0.9.16 56 | # albumentations==1.4.2 57 | # xatlas==0.0.9 58 | # nodejs==0.1.1 59 | # jupyterlab==4.1.5 60 | # objaverse==0.1.7 61 | # g4f==0.2.7.1 62 | # ultralytics==8.0.120 63 | # pycocotools==2.0.7 64 | # py-spy==0.3.14 65 | # pybullet==3.2.6 66 | # videoio==0.2.8 67 | kornia==0.7.3 #0.7.2 68 | einops==0.4.1 #0.7.0 69 | transformations==2022.9.26 #2024.6.1 70 | joblib==1.4.2 #1.3.2 71 | warp-lang==1.2.2 #1.0.2 72 | 73 | # For PyTorch3D 74 | fvcore==0.1.5.post20221221 -------------------------------------------------------------------------------- /scripts/eval_policy.sh: -------------------------------------------------------------------------------- 1 | DEBUG=False 2 | seed=0 3 | 4 | alg_name=simple_dp3 5 | task_name=${1} 6 | config_name=${alg_name} 7 | seed=${seed} 8 | exp_name=${task_name} 9 | run_dir="data/outputs/${exp_name}_seed${seed}" 10 | 11 | gpu_id=0 12 | use_fp=True # if false use ground-truth object pose instead 13 | eval_epoch=1000 14 | 15 | 16 | export HYDRA_FULL_ERROR=1 17 | export 
CUDA_VISIBLE_DEVICES=${gpu_id}
18 | python tools/eval_dp3.py --config-name=${config_name}.yaml \
19 | task=${task_name} \
20 | hydra.run.dir=${run_dir} \
21 | training.debug=$DEBUG \
22 | training.seed=${seed} \
23 | training.device="cuda:0" \
24 | exp_name=${exp_name} \
25 | logging.mode=${wandb_mode:-offline} \
26 | checkpoint.save_ckpt=${save_ckpt:-False} \
27 | task.env_runner.use_fp=${use_fp} \
28 | evaluation.eval_epoch=${eval_epoch}
29 | 
-------------------------------------------------------------------------------- /scripts/eval_policy_multi.sh: --------------------------------------------------------------------------------
1 | for rlbench_task_name in \
2 | meat_off_grill \
3 | insert_onto_square_peg \
4 | close_jar \
5 | light_bulb_in \
6 | place_wine_at_rack_location \
7 | put_groceries_in_cupboard \
8 | put_money_in_safe \
9 | reach_and_drag \
10 | place_shape_in_shape_sorter \
11 | stack_cups \
12 | stack_blocks \
13 | place_cups \
14 | turn_tap
15 | do
16 | DEBUG=False
17 | seed=0
18 | 
19 | alg_name=simple_dp3
20 | task_name=${rlbench_task_name}
21 | config_name=${alg_name}
22 | seed=${seed}
23 | exp_name="rlbench_multi"
24 | run_dir="data/outputs/${exp_name}_seed${seed}"
25 | 
26 | gpu_id=0
27 | use_fp=True # if false use ground-truth object pose instead
28 | eval_epoch=1000
29 | 
30 | 
31 | export HYDRA_FULL_ERROR=1
32 | export CUDA_VISIBLE_DEVICES=${gpu_id}
33 | python tools/eval_dp3.py --config-name=${config_name}.yaml \
34 | task=rlbench/${task_name} \
35 | hydra.run.dir=${run_dir} \
36 | training.debug=$DEBUG \
37 | training.seed=${seed} \
38 | training.device="cuda:0" \
39 | exp_name=${exp_name} \
40 | logging.mode=${wandb_mode:-offline} \
41 | checkpoint.save_ckpt=${save_ckpt:-False} \
42 | task.env_runner.use_fp=${use_fp} \
43 | evaluation.eval_epoch=${eval_epoch}
44 | done
45 | 
-------------------------------------------------------------------------------- /scripts/gen_demonstration_real.sh: --------------------------------------------------------------------------------
1 | BASE_DIR=$(pwd)
2 | TASK_DATASET_PATH=/tmp/pour_water/r3d/
3 | YOLO_WORLD_DIR=/tmp/YOLO-World/
4 | 
5 | for task_name in \
6 | pour_water \
7 | # task2
8 | # task3
9 | do
10 | # extract RGBD frames from the .r3d file (a video recorded with the Record3D app)
11 | conda run -n spot --no-capture-output python env_real/data/prepare_rgbd.py $task_name ${TASK_DATASET_PATH}
12 | 
13 | # get the first-frame segmentation from YOLO-World in the "yolo_world" conda environment
14 | # *must use an absolute path here for some reason
15 | cd ${YOLO_WORLD_DIR}
16 | conda run -n yolo_world --no-capture-output PYTHONPATH=${YOLO_WORLD_DIR} python ${BASE_DIR}/env_real/data/prepare_mask.py $task_name ${TASK_DATASET_PATH}
17 | 
18 | # get grasp/target object poses throughout the sequence
19 | cd ${BASE_DIR}
20 | conda run -n spot --no-capture-output python env_real/data/prepare_pose.py $task_name ${TASK_DATASET_PATH}
21 | # convert the data to zarr for training
22 | cd ${BASE_DIR}
23 | conda run -n spot --no-capture-output python env_real/data/collect_zarr_real.py $task_name ${TASK_DATASET_PATH}
24 | done
-------------------------------------------------------------------------------- /scripts/gen_demonstration_rlbench.sh: --------------------------------------------------------------------------------
1 | # peract setting
2 | for task_name in \
3 | meat_off_grill \
4 | place_wine_at_rack_location \
5 | insert_onto_square_peg \
6 | put_groceries_in_cupboard \
7 | place_shape_in_shape_sorter \
8 | reach_and_drag \
9 | put_money_in_safe \
10 | turn_tap \
11 | light_bulb_in \
12 | 
close_jar 13 | do 14 | # training set 15 | python env_rlbench_peract/data/collect_zarr_rlbench_peract.py \ 16 | --peract_demo_dir=/tmp/peract/raw/ \ 17 | --save_path=/tmp/rlbench_zarr/ \ 18 | --tasks=$task_name --variations=-1 --processes=1 --split=train --episodes_per_task=100 19 | 20 | # testing set 21 | python env_rlbench_peract/data/collect_zarr_rlbench_peract.py \ 22 | --peract_demo_dir=/tmp/peract/raw/ \ 23 | --save_path=/tmp/rlbench_zarr/ \ 24 | --tasks=$task_name --variations=-1 --processes=1 --split=test --episodes_per_task=25 25 | done 26 | 27 | 28 | # peract setting (multi-object task) 29 | for task_name in \ 30 | stack_cups \ 31 | stack_blocks \ 32 | place_cups 33 | do 34 | # training set 35 | python env_rlbench_peract/data/collect_zarr_rlbench_peract_multi_stage.py \ 36 | --peract_demo_dir=/tmp/peract/raw/ \ 37 | --save_path=/tmp/rlbench_zarr/ \ 38 | --tasks=$task_name --variations=-1 --processes=1 --split=train --episodes_per_task=100 39 | 40 | # testing set 41 | python env_rlbench_peract/data/collect_zarr_rlbench_peract_multi_stage.py \ 42 | --peract_demo_dir=/tmp/peract/raw/ \ 43 | --save_path=/tmp/rlbench_zarr/ \ 44 | --tasks=$task_name --variations=-1 --processes=1 --split=test --episodes_per_task=25 45 | done -------------------------------------------------------------------------------- /scripts/train_policy.sh: -------------------------------------------------------------------------------- 1 | DEBUG=False 2 | save_ckpt=True 3 | seed=0 4 | 5 | alg_name=simple_dp3 6 | task_name=${1} 7 | config_name=${alg_name} 8 | seed=${seed} 9 | exp_name=${task_name} 10 | run_dir="data/outputs/${exp_name}_seed${seed}" 11 | 12 | 13 | # gpu_id=$(bash scripts/find_gpu.sh) 14 | gpu_id=0 15 | echo -e "\033[33mgpu id (to use): ${gpu_id}\033[0m" 16 | 17 | 18 | if [ $DEBUG = True ]; then 19 | wandb_mode=offline 20 | # wandb_mode=online 21 | echo -e "\033[33mDebug mode!\033[0m" 22 | echo -e "\033[33mDebug mode!\033[0m" 23 | echo -e "\033[33mDebug mode!\033[0m" 24 | else 25 | wandb_mode=online 26 | echo -e "\033[33mTrain mode\033[0m" 27 | fi 28 | 29 | 30 | export HYDRA_FULL_ERROR=1 31 | export CUDA_VISIBLE_DEVICES=${gpu_id} 32 | python tools/train_dp3.py --config-name=${config_name}.yaml \ 33 | task=${task_name} \ 34 | hydra.run.dir=${run_dir} \ 35 | training.debug=$DEBUG \ 36 | training.seed=${seed} \ 37 | training.device="cuda:0" \ 38 | exp_name=${exp_name} \ 39 | logging.mode=${wandb_mode} \ 40 | checkpoint.save_ckpt=${save_ckpt} 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /tools/eval_dp3.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | import sys 3 | import os 4 | import pathlib 5 | 6 | ROOT_DIR = str(pathlib.Path(__file__).resolve().parent.parent) 7 | sys.path.append(ROOT_DIR) 8 | os.chdir(ROOT_DIR) 9 | 10 | import os 11 | import hydra 12 | from omegaconf import OmegaConf 13 | import pathlib 14 | from train_dp3 import TrainDP3Workspace 15 | 16 | OmegaConf.register_new_resolver("eval", eval, replace=True) 17 | 18 | 19 | @hydra.main( 20 | version_base=None, 21 | config_path=str(pathlib.Path(__file__).resolve().parent.parent.joinpath('config')) 22 | ) 23 | def main(cfg): 24 | workspace = TrainDP3Workspace(cfg) 25 | workspace.eval() 26 | 27 | if __name__ == "__main__": 28 | main() 29 | -------------------------------------------------------------------------------- /utils/collect_utils_rlbench.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | 10 | import logging 11 | from typing import List 12 | 13 | import numpy as np 14 | from rlbench.demo import Demo 15 | 16 | 17 | def _is_stopped(demo, i, obs, stopped_buffer, delta=0.1): 18 | next_is_not_final = i == (len(demo) - 2) 19 | gripper_state_no_change = ( 20 | i < (len(demo) - 2) and 21 | (obs.gripper_open == demo[i + 1].gripper_open and 22 | obs.gripper_open == demo[i - 1].gripper_open and 23 | demo[i - 2].gripper_open == demo[i - 1].gripper_open)) 24 | small_delta = np.allclose(obs.joint_velocities, 0, atol=delta) 25 | stopped = (stopped_buffer <= 0 and small_delta and 26 | (not next_is_not_final) and gripper_state_no_change) 27 | return stopped 28 | 29 | 30 | def keypoint_discovery(demo: Demo, 31 | stopping_delta=0.1, 32 | method='heuristic') -> List[int]: 33 | episode_keypoints = [] 34 | if method == 'heuristic': 35 | prev_gripper_open = demo[0].gripper_open 36 | stopped_buffer = 0 37 | for i, obs in enumerate(demo): 38 | stopped = _is_stopped(demo, i, obs, stopped_buffer, stopping_delta) 39 | stopped_buffer = 4 if stopped else stopped_buffer - 1 40 | # If change in gripper, or end of episode. 41 | last = i == (len(demo) - 1) 42 | if i != 0 and (obs.gripper_open != prev_gripper_open or 43 | last or stopped): 44 | episode_keypoints.append(i) 45 | prev_gripper_open = obs.gripper_open 46 | if len(episode_keypoints) > 1 and (episode_keypoints[-1] - 1) == \ 47 | episode_keypoints[-2]: 48 | episode_keypoints.pop(-2) 49 | logging.debug('Found %d keypoints.' % len(episode_keypoints), 50 | episode_keypoints) 51 | return episode_keypoints 52 | 53 | elif method == 'random': 54 | # Randomly select keypoints. 55 | episode_keypoints = np.random.choice( 56 | range(len(demo)), 57 | size=20, 58 | replace=False) 59 | episode_keypoints.sort() 60 | return episode_keypoints 61 | 62 | elif method == 'fixed_interval': 63 | # Fixed interval. 64 | episode_keypoints = [] 65 | segment_length = len(demo) // 20 66 | for i in range(0, len(demo), segment_length): 67 | episode_keypoints.append(i) 68 | return episode_keypoints 69 | 70 | else: 71 | raise NotImplementedError -------------------------------------------------------------------------------- /utils/io_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
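#
# Small JSON I/O helpers for dictionaries of numpy arrays: save_np_dict_to_json()
# serializes each array via .tolist(), and load_np_dict_from_json() rebuilds the
# arrays with np.array(). Illustrative round trip (mirrors the self-test under
# __main__ below; the file name is arbitrary):
#   save_np_dict_to_json({"a": np.arange(3)}, "example.json")
#   load_np_dict_from_json("example.json")   # -> {"a": array([0, 1, 2])}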
8 | 9 | 10 | import json 11 | import numpy as np 12 | 13 | 14 | def save_np_dict_to_json(np_dict, save_path): 15 | json_string = json.dumps({k: v.tolist() for k, v in np_dict.items()}, indent=4) 16 | with open(save_path, "w") as f: 17 | f.write(json_string) 18 | 19 | def load_np_dict_from_json(read_path): 20 | with open(read_path, "r") as f: 21 | np_dict = {k: np.array(v) for k, v in json.load(f).items()} 22 | return np_dict 23 | 24 | 25 | if __name__ == "__main__": 26 | # Test data 27 | d = { 28 | 'chicken': np.random.randn(5), 29 | 'banana': np.random.randn(5), 30 | 'carrots': np.random.randn(5) 31 | } 32 | 33 | save_path = "test.json" 34 | save_np_dict_to_json(d, save_path) 35 | a = load_np_dict_from_json(save_path) 36 | 37 | print (a) 38 | print (d) -------------------------------------------------------------------------------- /utils/logger_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | 10 | import numpy as np 11 | import imageio 12 | from PIL import Image 13 | from utils.mask_utils import palette_ADE20K 14 | 15 | 16 | def _save_rgb(data, path, type='video'): 17 | assert type in ['images', 'gif', 'video'] 18 | 19 | if type == 'gif': 20 | # save gif 21 | pil_images = [Image.fromarray(img) for img in data] 22 | pil_images[0].save(path, save_all=True, append_images=pil_images[1:], duration=50, loop=0) # duration is the number of milliseconds between frames; this is 40 frames per second 23 | elif type == 'video': 24 | # save video 25 | video_writer = imageio.get_writer(path, fps=40) 26 | for img in data: 27 | video_writer.append_data(img) 28 | video_writer.close() 29 | else: 30 | raise NotImplementedError 31 | 32 | def _save_combined_rgb(data_list, path, type='video'): 33 | assert type in ['gif', 'video'] 34 | # TODO: check if length is the same 35 | # TODO: check shape 36 | 37 | n_frame = len(data_list[0]) 38 | for data in data_list: 39 | assert len(data) == n_frame 40 | 41 | data_combined = [] 42 | for i in range(n_frame): 43 | data_combined.append( 44 | np.concatenate([frames[i] for frames in data_list], axis=1) 45 | ) 46 | 47 | if type == 'gif': 48 | # save gif 49 | pil_images = [Image.fromarray(img) for img in data_combined] 50 | pil_images[0].save(path, save_all=True, append_images=pil_images[1:], duration=50, loop=0) # duration is the number of milliseconds between frames; this is 40 frames per second 51 | elif type == 'video': 52 | # save video 53 | video_writer = imageio.get_writer(path, fps=40) 54 | for img in data_combined: 55 | video_writer.append_data(img) 56 | video_writer.close() 57 | else: 58 | raise NotImplementedError 59 | 60 | def _save_mask(data, path, palette=None): 61 | 62 | print("The function _save_mask currently only saves the first image.") 63 | # TODO: support multiple images 64 | mask_image = data[0] 65 | 66 | if palette is None: 67 | palette = palette_ADE20K 68 | mask_image = Image.fromarray(mask_image.astype(np.uint8)).convert('P') 69 | mask_image.putpalette(palette) 70 | mask_image.save(path) 71 | 72 | 73 | class EnvLogger: 74 | def __init__(self): 75 | self.data 
= {} 76 | self.data_type = {} 77 | 78 | def clear(self): 79 | for name in self.data: 80 | self.data[name].clear() 81 | 82 | def add_data_type(self, name, type): 83 | self.data[name] = [] 84 | self.data_type[name] = type 85 | 86 | def add_data(self, name, new_data): 87 | assert name in self.data 88 | self.data[name].append(new_data) 89 | 90 | def get_data(self, name): 91 | return self.data[name] 92 | 93 | def save_data(self, name, path, output_fn=None, **kwargs): 94 | if isinstance(name, list): 95 | # TODO: check if type is the same 96 | data_list = [] 97 | data_type = self.data_type[name[0]] 98 | for k in name: 99 | data_list.append(self.data[k]) 100 | output_fn = output_fn if output_fn is not None else self._get_combined_output_fn(data_type) 101 | output_fn(data_list, path, **kwargs) 102 | else: 103 | data = self.data[name] 104 | data_type = self.data_type[name] 105 | output_fn = output_fn if output_fn is not None else self._get_output_fn(data_type) 106 | output_fn(data, path, **kwargs) 107 | 108 | def _get_output_fn(self, type): 109 | assert type in ['list', 'array', 'rgb', 'depth', 'mask'] 110 | if type == 'list': 111 | raise NotImplementedError 112 | elif type == 'array': 113 | raise NotImplementedError 114 | elif type == 'rgb': 115 | return _save_rgb 116 | elif type == 'depth': 117 | raise NotImplementedError 118 | elif type == 'mask': 119 | return _save_mask 120 | else: 121 | raise NotImplementedError 122 | 123 | def _get_combined_output_fn(self, type): 124 | assert type in ['list', 'array', 'rgb', 'depth', 'mask'] 125 | if type == 'list': 126 | raise NotImplementedError 127 | elif type == 'array': 128 | raise NotImplementedError 129 | elif type == 'rgb': 130 | return _save_combined_rgb 131 | elif type == 'depth': 132 | raise NotImplementedError 133 | else: 134 | raise NotImplementedError -------------------------------------------------------------------------------- /utils/mask_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | palette_ADE20K = np.array([[0, 0, 0], 5 | [120, 120, 120], 6 | [180, 120, 120], 7 | [6, 230, 230], 8 | [80, 50, 50], 9 | [4, 200, 3], 10 | [120, 120, 80], 11 | [140, 140, 140], 12 | [204, 5, 255], 13 | [184, 255, 0], 14 | [4, 250, 7], 15 | [224, 5, 255], 16 | [235, 255, 7], 17 | [150, 5, 61], 18 | [120, 120, 70], 19 | [8, 255, 51], 20 | [255, 6, 82], 21 | [143, 255, 140], 22 | [204, 255, 4], 23 | [255, 51, 7], 24 | [204, 70, 3], 25 | [0, 102, 200], 26 | [61, 230, 250], 27 | [255, 6, 51], 28 | [11, 102, 255], 29 | [255, 7, 71], 30 | [255, 9, 224], 31 | [9, 7, 230], 32 | [220, 220, 220], 33 | [255, 9, 92], 34 | [112, 9, 255], 35 | [8, 255, 214], 36 | [7, 255, 224], 37 | [255, 184, 6], 38 | [10, 255, 71], 39 | [255, 41, 10], 40 | [7, 255, 255], 41 | [224, 255, 8], 42 | [102, 8, 255], 43 | [255, 61, 6], 44 | [255, 194, 7], 45 | [255, 122, 8], 46 | [0, 255, 20], 47 | [255, 8, 41], 48 | [255, 5, 153], 49 | [6, 51, 255], 50 | [235, 12, 255], 51 | [160, 150, 20], 52 | [0, 163, 255], 53 | [140, 140, 140], 54 | [250, 10, 15], 55 | [20, 255, 0], 56 | [31, 255, 0], 57 | [255, 31, 0], 58 | [255, 224, 0], 59 | [153, 255, 0], 60 | [0, 0, 255], 61 | [255, 71, 0], 62 | [0, 235, 255], 63 | [0, 173, 255], 64 | [31, 0, 255], 65 | [11, 200, 200], 66 | [255, 82, 0], 67 | [0, 255, 245], 68 | [0, 61, 255], 69 | [0, 255, 112], 70 | [0, 255, 133], 71 | [255, 0, 0], 72 | [255, 163, 0], 73 | [255, 102, 0], 74 | [194, 255, 0], 75 | [0, 143, 255], 76 | [51, 255, 0], 77 | [0, 82, 255], 78 | [0, 255, 41], 79 | [0, 
255, 173], 80 | [10, 0, 255], 81 | [173, 255, 0], 82 | [0, 255, 153], 83 | [255, 92, 0], 84 | [255, 0, 255], 85 | [255, 0, 245], 86 | [255, 0, 102], 87 | [255, 173, 0], 88 | [255, 0, 20], 89 | [255, 184, 184], 90 | [0, 31, 255], 91 | [0, 255, 61], 92 | [0, 71, 255], 93 | [255, 0, 204], 94 | [0, 255, 194], 95 | [0, 255, 82], 96 | [0, 10, 255], 97 | [0, 112, 255], 98 | [51, 0, 255], 99 | [0, 194, 255], 100 | [0, 122, 255], 101 | [0, 255, 163], 102 | [255, 153, 0], 103 | [0, 255, 10], 104 | [255, 112, 0], 105 | [143, 255, 0], 106 | [82, 0, 255], 107 | [163, 255, 0], 108 | [255, 235, 0], 109 | [8, 184, 170], 110 | [133, 0, 255], 111 | [0, 255, 92], 112 | [184, 0, 255], 113 | [255, 0, 31], 114 | [0, 184, 255], 115 | [0, 214, 255], 116 | [255, 0, 112], 117 | [92, 255, 0], 118 | [0, 224, 255], 119 | [112, 224, 255], 120 | [70, 184, 160], 121 | [163, 0, 255], 122 | [153, 0, 255], 123 | [71, 255, 0], 124 | [255, 0, 163], 125 | [255, 204, 0], 126 | [255, 0, 143], 127 | [0, 255, 235], 128 | [133, 255, 0], 129 | [255, 0, 235], 130 | [245, 0, 255], 131 | [255, 0, 122], 132 | [255, 245, 0], 133 | [10, 190, 212], 134 | [214, 255, 0], 135 | [0, 204, 255], 136 | [20, 0, 255], 137 | [255, 255, 0], 138 | [0, 153, 255], 139 | [0, 41, 255], 140 | [0, 255, 204], 141 | [41, 0, 255], 142 | [41, 255, 0], 143 | [173, 0, 255], 144 | [0, 245, 255], 145 | [71, 0, 255], 146 | [122, 0, 255], 147 | [0, 255, 184], 148 | [0, 92, 255], 149 | [230, 230, 230], 150 | [0, 133, 255], 151 | [255, 214, 0], 152 | [25, 194, 194], 153 | [102, 255, 0], 154 | [92, 0, 255]], np.uint8) -------------------------------------------------------------------------------- /utils/pose_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
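#
# Pose conventions used in this module: a "pose" is a 7-vector
# [x, y, z, qx, qy, qz, qw] (position plus scalar-last quaternion), while a
# "pose_euler" is a 6-vector [x, y, z, roll, pitch, yaw] in radians.
# get_rel_pose(pose1, pose2) expresses pose2 in the frame of pose1, and
# relative_to_target_to_world() maps such a relative pose back to the world
# frame; the self-test under __main__ at the bottom of this file checks that
# the two are inverses of each other.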
8 | 
9 | 
10 | import os
11 | import sys
12 | sys.path.insert(0, os.getcwd())
13 | 
14 | import numba
15 | import numpy as np
16 | import utils.transform_utils as T
17 | from scipy.spatial.transform import Rotation as R
18 | 
19 | from pyquaternion import Quaternion
20 | 
21 | def pose_euler_to_pose_quaternion(pose_euler):
22 |     pose_quaternion = np.zeros(7)  # [x, y, z, qx, qy, qz, qw]
23 |     pose_quaternion[:3] = pose_euler[:3]
24 |     pose_quaternion[3:] = quaternion_from_euler(pose_euler[3:])
25 |     return pose_quaternion
26 | 
27 | def pose_quaternion_to_pose_euler(pose_quaternion):
28 |     pose_euler = np.zeros(6)
29 |     pose_euler[:3] = pose_quaternion[:3]
30 |     pose_euler[3:] = euler_from_quaternion(*pose_quaternion[3:])
31 |     return pose_euler
32 | 
33 | # ref: https://github.com/stepjam/RLBench/blob/master/rlbench/action_modes/arm_action_modes.py#L30
34 | def calculate_goal_pose(current_pose: np.ndarray, action: np.ndarray):
35 |     a_x, a_y, a_z, a_qx, a_qy, a_qz, a_qw = action
36 |     x, y, z, qx, qy, qz, qw = current_pose
37 | 
38 |     Qp = Quaternion(a_qw, a_qx, a_qy, a_qz)
39 |     Qch = Quaternion(qw, qx, qy, qz)
40 |     QW = Qp * Qch
41 | 
42 |     # new_rot = Quaternion(
43 |     #     a_qw, a_qx, a_qy, a_qz) * Quaternion(qw, qx, qy, qz)
44 |     qw, qx, qy, qz = list(QW)
45 |     pose = [a_x + x, a_y + y, a_z + z] + [qx, qy, qz, qw]
46 |     return np.array(pose)
47 | 
48 | # ref: https://math.stackexchange.com/questions/2124361/quaternions-multiplication-order-to-rotate-unrotate
49 | def calculate_action(current_pose: np.ndarray, goal_pose: np.ndarray):
50 |     x1, y1, z1, qx1, qy1, qz1, qw1 = current_pose
51 |     x2, y2, z2, qx2, qy2, qz2, qw2 = goal_pose
52 | 
53 |     # QW == Qp * Qch
54 |     QW = Quaternion(qw2, qx2, qy2, qz2)
55 |     Qch = Quaternion(qw1, qx1, qy1, qz1)
56 |     # Qp == QW * Qch.inverse
57 |     Qp = QW * Qch.inverse
58 | 
59 |     # QW = R.from_quat([qx2, qy2, qz2, qw2])
60 |     # Qch = R.from_quat([qx1, qy1, qz1, qw1])
61 |     # Qp = QW * Qch.inv
62 |     # Qp = Qp.as_quat()
63 | 
64 |     a_qw, a_qx, a_qy, a_qz = list(Qp)
65 |     a_x, a_y, a_z = x2-x1, y2-y1, z2-z1
66 | 
67 |     action = [a_x, a_y, a_z] + [a_qx, a_qy, a_qz, a_qw]
68 |     return np.array(action)
69 | 
70 | def compute_rel_transform(A_pos, A_mat, B_pos, B_mat):
71 |     T_WA = np.vstack((np.hstack((A_mat, A_pos[:, None])), [0, 0, 0, 1]))
72 |     T_WB = np.vstack((np.hstack((B_mat, B_pos[:, None])), [0, 0, 0, 1]))
73 | 
74 |     T_AB = np.matmul(np.linalg.inv(T_WA), T_WB)
75 | 
76 |     return T_AB[:3, 3], T_AB[:3, :3]
77 | 
78 | # target_obj_pose, grasp_obj_pose
79 | def get_rel_pose(pose1, pose2):
80 |     pos1 = np.array(pose1[:3])
81 |     quat1 = np.array(pose1[3:])
82 |     mat1 = T.quat2mat(quat1)
83 |     pos2 = np.array(pose2[:3])
84 |     quat2 = np.array(pose2[3:])
85 |     mat2 = T.quat2mat(quat2)
86 | 
87 |     pos, mat = compute_rel_transform(pos1, mat1, pos2, mat2)
88 |     quat = T.mat2quat(mat)
89 |     return np.concatenate([pos, quat])
90 | 
91 | 
92 | # def realtive_to_target_to_world(env, obj_pose_relative_to_cab, cabinet):
93 | def relative_to_target_to_world(subgoal_relative_to_target, target_obj_pose):
94 |     pos1 = subgoal_relative_to_target[:3]
95 |     mat1 = T.quat2mat(subgoal_relative_to_target[3:])
96 |     pos2 = target_obj_pose[:3]
97 |     mat2 = T.quat2mat(target_obj_pose[3:])
98 | 
99 |     # T_WA = T_WB @ T_BA
100 |     T_BA = np.vstack((np.hstack((mat1, pos1[:, None])), [0, 0, 0, 1]))
101 |     T_WB = np.vstack((np.hstack((mat2, pos2[:, None])), [0, 0, 0, 1]))
102 |     T_WA = np.matmul(T_WB, T_BA)
103 | 
104 |     pos = T_WA[:3, 3]
105 |     mat = T_WA[:3, :3]
106 |     quat = T.mat2quat(mat)
107 |     return np.concatenate([pos, quat])
108 | 
109 | def euler_from_quaternion(x, y, z, w):
110 | 
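    # Quaternion input is scalar-last (x, y, z, w); the Euler angles are returned
    # by SciPy in the 'xyz' (roll, pitch, yaw) convention, in radians.
    # Illustrative sanity check: the identity quaternion maps to zero angles,
    #   euler_from_quaternion(0.0, 0.0, 0.0, 1.0)   # -> array([0., 0., 0.])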
# import math 111 | # t0 = +2.0 * (w * x + y * z) 112 | # t1 = +1.0 - 2.0 * (x * x + y * y) 113 | # roll_x = math.atan2(t0, t1) 114 | 115 | # t2 = +2.0 * (w * y - z * x) 116 | # t2 = +1.0 if t2 > +1.0 else t2 117 | # t2 = -1.0 if t2 < -1.0 else t2 118 | # pitch_y = math.asin(t2) 119 | 120 | # t3 = +2.0 * (w * z + x * y) 121 | # t4 = +1.0 - 2.0 * (y * y + z * z) 122 | # yaw_z = math.atan2(t3, t4) 123 | 124 | # return roll_x, pitch_y, yaw_z 125 | 126 | rot = R.from_quat([x, y, z, w]) # (x, y, z, w) 127 | euler = rot.as_euler('xyz') 128 | return euler 129 | 130 | # def matrix_from_quaternion(x, y, z, w): 131 | # rot = R.from_quat([x, y, z, w]) # (x, y, z, w) 132 | # mat = rot.as_matrix() 133 | # return mat 134 | 135 | def quaternion_from_euler(euler): 136 | rot = R.from_euler('xyz', euler) 137 | quat = rot.as_quat() # (x, y, z, w) 138 | return quat 139 | 140 | 141 | def rodrigues(r, calculate_jacobian=True): 142 | """Computes the Rodrigues transform and its derivative 143 | 144 | :param r: either a 3-vector representing the rotation parameter, or a full rotation matrix 145 | :param calculate_jacobian: indicates if the Jacobian of the transform is also required 146 | :returns: If `calculate_jacobian` is `True`, the Jacobian is given as the second element of the returned tuple. 147 | """ 148 | 149 | r = np.array(r, dtype=np.double) 150 | eps = np.finfo(np.double).eps 151 | 152 | if np.all(r.shape == (3, 1)) or np.all(r.shape == (1, 3)) or np.all(r.shape == (3,)): 153 | r = r.flatten() 154 | theta = np.linalg.norm(r) 155 | if theta < eps: 156 | r_out = np.eye(3) 157 | if calculate_jacobian: 158 | jac = np.zeros((3, 9)) 159 | jac[0, 5] = jac[1, 6] = jac[2, 1] = -1 160 | jac[0, 7] = jac[1, 2] = jac[2, 3] = 1 161 | 162 | else: 163 | c = np.cos(theta) 164 | s = np.sin(theta) 165 | c1 = 1. 
- c 166 | itheta = 1.0 if theta == 0.0 else 1.0 / theta 167 | r *= itheta 168 | I = np.eye(3) 169 | rrt = np.array([r * r[0], r * r[1], r * r[2]]) 170 | _r_x_ = np.array([[0, -r[2], r[1]], [r[2], 0, -r[0]], [-r[1], r[0], 0]]) 171 | r_out = c * I + c1 * rrt + s * _r_x_ 172 | if calculate_jacobian: 173 | drrt = np.array([[r[0] + r[0], r[1], r[2], r[1], 0, 0, r[2], 0, 0], 174 | [0, r[0], 0, r[0], r[1] + r[1], r[2], 0, r[2], 0], 175 | [0, 0, r[0], 0, 0, r[1], r[0], r[1], r[2] + r[2]]]) 176 | d_r_x_ = np.array([[0, 0, 0, 0, 0, -1, 0, 1, 0], 177 | [0, 0, 1, 0, 0, 0, -1, 0, 0], 178 | [0, -1, 0, 1, 0, 0, 0, 0, 0]]) 179 | I = np.array([I.flatten(), I.flatten(), I.flatten()]) 180 | ri = np.array([[r[0]], [r[1]], [r[2]]]) 181 | a0 = -s * ri 182 | a1 = (s - 2 * c1 * itheta) * ri 183 | a2 = np.ones((3, 1)) * c1 * itheta 184 | a3 = (c - s * itheta) * ri 185 | a4 = np.ones((3, 1)) * s * itheta 186 | jac = a0 * I + a1 * rrt.flatten() + a2 * drrt + a3 * _r_x_.flatten() + a4 * d_r_x_ 187 | elif np.all(r.shape == (3, 3)): 188 | u, d, v = np.linalg.svd(r) 189 | r = np.dot(u, v) 190 | rx = r[2, 1] - r[1, 2] 191 | ry = r[0, 2] - r[2, 0] 192 | rz = r[1, 0] - r[0, 1] 193 | s = np.linalg.norm(np.array([rx, ry, rz])) * np.sqrt(0.25) 194 | c = np.clip((np.sum(np.diag(r)) - 1) * 0.5, -1, 1) 195 | theta = np.arccos(c) 196 | if s < 1e-5: 197 | if c > 0: 198 | r_out = np.zeros((3, 1)) 199 | else: 200 | rx, ry, rz = np.clip(np.sqrt((np.diag(r) + 1) * 0.5), 0, np.inf) 201 | if r[0, 1] < 0: 202 | ry = -ry 203 | if r[0, 2] < 0: 204 | rz = -rz 205 | if np.abs(rx) < np.abs(ry) and np.abs(rx) < np.abs(rz) and ((r[1, 2] > 0) != (ry * rz > 0)): 206 | rz = -rz 207 | 208 | r_out = np.array([[rx, ry, rz]]).T 209 | theta /= np.linalg.norm(r_out) 210 | r_out *= theta 211 | if calculate_jacobian: 212 | jac = np.zeros((9, 3)) 213 | if c > 0: 214 | jac[1, 2] = jac[5, 0] = jac[6, 1] = -0.5 215 | jac[2, 1] = jac[3, 2] = jac[7, 0] = 0.5 216 | else: 217 | vth = 1.0 / (2.0 * s) 218 | if calculate_jacobian: 219 | dtheta_dtr = -1. 
/ s 220 | dvth_dtheta = -vth * c / s 221 | d1 = 0.5 * dvth_dtheta * dtheta_dtr 222 | d2 = 0.5 * dtheta_dtr 223 | dvardR = np.array([ 224 | [0, 0, 0, 0, 0, 1, 0, -1, 0], 225 | [0, 0, -1, 0, 0, 0, 1, 0, 0], 226 | [0, 1, 0, -1, 0, 0, 0, 0, 0], 227 | [d1, 0, 0, 0, d1, 0, 0, 0, d1], 228 | [d2, 0, 0, 0, d2, 0, 0, 0, d2]]) 229 | dvar2dvar = np.array([ 230 | [vth, 0, 0, rx, 0], 231 | [0, vth, 0, ry, 0], 232 | [0, 0, vth, rz, 0], 233 | [0, 0, 0, 0, 1]]) 234 | domegadvar2 = np.array([ 235 | [theta, 0, 0, rx * vth], 236 | [0, theta, 0, ry * vth], 237 | [0, 0, theta, rz * vth]]) 238 | jac = np.dot(np.dot(domegadvar2, dvar2dvar), dvardR) 239 | for ii in range(3): 240 | jac[ii] = jac[ii].reshape((3, 3)).T.flatten() 241 | jac = jac.T 242 | vth *= theta 243 | r_out = np.array([[rx, ry, rz]]).T * vth 244 | else: 245 | raise Exception("rodrigues: input matrix must be 1x3, 3x1 or 3x3.") 246 | if calculate_jacobian: 247 | return r_out, jac 248 | else: 249 | return r_out 250 | 251 | 252 | def rodrigues2rotmat(r): 253 | # R = np.zeros((3, 3)) 254 | r_skew = np.array([[0, -r[2], r[1]], [r[2], 0, -r[0]], [-r[1], r[0], 0]]) 255 | theta = np.linalg.norm(r) 256 | return np.identity(3) + np.sin(theta) * r_skew + (1 - np.cos(theta)) * r_skew.dot(r_skew) 257 | 258 | 259 | if __name__ == "__main__": 260 | target_obj_pose = np.array([ 0.32260922, -0.10839751, 0.96184993, 0.57075304, 0.05685034, -0.81511486, 0.08122107]) 261 | grasp_obj_pose = np.array([ 1.99161470e-01, 3.34810495e-01, 7.91927218e-01, -1.26936930e-05, 2.39512883e-06, 9.92952466e-01, -1.18513443e-01]) 262 | # obj_pose_relative_to_target = np.array([ 0.17114308, -0.45885825, -0.02656152, -0.01118924, -0.57345784, 0.0159555, 0.81900305]) 263 | 264 | np.set_printoptions(precision=3) 265 | obj_pose_relative_to_target = get_rel_pose(target_obj_pose, grasp_obj_pose) 266 | grasp_obj_pose_2 = relative_to_target_to_world(obj_pose_relative_to_target, target_obj_pose) 267 | print("grasp_obj_pose", grasp_obj_pose) 268 | print("grasp_obj_pose_2", grasp_obj_pose_2) 269 | 270 | obj_pose_relative_to_target = get_rel_pose(target_obj_pose, grasp_obj_pose) 271 | grasp_obj_pose_2 = relative_to_target_to_world(obj_pose_relative_to_target, target_obj_pose) 272 | print("grasp_obj_pose", grasp_obj_pose) 273 | print("grasp_obj_pose_2", grasp_obj_pose_2) -------------------------------------------------------------------------------- /utils/vis_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
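#
# dp3_visualize() below renders (optionally batched) trajectories of 7-vector poses
# with Open3D: current agent poses are drawn in red, ground-truth actions in green
# and predicted actions in blue; when predict_type == 'relative', actions are first
# composed with the corresponding agent pose via calculate_goal_pose() before drawing.
# get_vis_pose() overlays an oriented 3D bounding box and an XYZ axis for a single
# object pose onto an RGB image using the FoundationPose drawing utilities.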
8 | 9 | 10 | import os 11 | import sys 12 | sys.path.insert(0, os.getcwd()) 13 | 14 | import torch 15 | import numpy as np 16 | from utils.vis_o3d_utils import O3DVisualizer 17 | from utils.pose_utils import calculate_goal_pose 18 | 19 | 20 | def dp3_visualize(agent_pos, pred=None, target=None, visualize=True, predict_type='relative'): 21 | if visualize: 22 | vis = O3DVisualizer() 23 | 24 | if isinstance(agent_pos, torch.Tensor): 25 | agent_pos_cpu = agent_pos.clone().detach().cpu().numpy() 26 | if pred is not None: 27 | pred_cpu = pred.clone().detach().cpu().numpy() 28 | if target is not None: 29 | target_cpu = target.clone().detach().cpu().numpy() 30 | else: 31 | agent_pos_cpu = agent_pos 32 | if pred is not None: 33 | pred_cpu = pred 34 | if target is not None: 35 | target_cpu = target 36 | 37 | if len(agent_pos_cpu.shape) == 2: # not a batch 38 | agent_pos_cpu = agent_pos_cpu[None, ...] 39 | if pred is not None: 40 | pred_cpu = pred_cpu[None, ...] 41 | if target is not None: 42 | target_cpu = target_cpu[None, ...] 43 | print(agent_pos_cpu.shape) 44 | 45 | 46 | input_count = 0 47 | pred_count = 0 48 | gt_count = 0 49 | for batch_idx in range(len(agent_pos_cpu)): 50 | 51 | cur_pose = agent_pos_cpu[batch_idx] 52 | print(cur_pose.shape) 53 | if target is not None: 54 | gt_pose = target_cpu[batch_idx] 55 | print(gt_pose.shape) 56 | 57 | 58 | for pose_idx in range(cur_pose.shape[0]): 59 | np.set_printoptions(precision=3) 60 | print("batch %d" % batch_idx) 61 | # print("input", cur_pose[pose_idx]) 62 | if visualize: 63 | vis.add_pose_from_traj(cur_pose[pose_idx].reshape(-1, 7), pos_only=False, paint_color=[1., 0., 0.]) 64 | input_count += 1 65 | 66 | if target is not None: 67 | if predict_type == 'relative': 68 | # euler 69 | # cur_gt_pose = cur_pose[pose_idx] + gt_pose[pose_idx] 70 | # quaternion 71 | cur_gt_pose = calculate_goal_pose(cur_pose[pose_idx][:7], gt_pose[pose_idx][:7]) 72 | else: 73 | cur_gt_pose = gt_pose[pose_idx] 74 | # print("action (gt)", cur_gt_pose) 75 | if visualize: 76 | vis.add_pose_from_traj(cur_gt_pose.reshape(-1, 7), pos_only=False, paint_color=[0., 1., 0.]) 77 | gt_count += 1 78 | if pred is not None: 79 | pred_pose = pred_cpu[batch_idx] 80 | print(pred_pose.shape) 81 | 82 | if predict_type == 'relative': 83 | iterated_pose = cur_pose[0] 84 | for pose_idx in range(pred_pose.shape[0]): 85 | if predict_type == 'relative': 86 | # euler 87 | # cur_pred_pose = cur_pose[pose_idx] + pred_pose[pose_idx] 88 | # quaternion 89 | # cur_pred_pose = calculate_goal_pose(cur_pose[pose_idx], pred_pose[pose_idx]) 90 | cur_pred_pose = calculate_goal_pose(iterated_pose, pred_pose[pose_idx]) 91 | else: 92 | cur_pred_pose = pred_pose[pose_idx] 93 | # print("action (pred)", cur_pred_pose) 94 | if visualize: 95 | vis.add_pose_from_traj(cur_pred_pose.reshape(-1, 7), pos_only=False, paint_color=[0., 0., 1.]) 96 | 97 | if predict_type == 'relative': 98 | # print("iterated_pose", iterated_pose) 99 | # print("cur_pose", cur_pose[pose_idx+1]) 100 | if pose_idx+1 < cur_pose.shape[0]: 101 | print("difference", iterated_pose - cur_pose[pose_idx+1]) 102 | iterated_pose = cur_pred_pose 103 | pred_count += 1 104 | 105 | print(input_count, gt_count, pred_count) 106 | if visualize: 107 | vis.draw() 108 | 109 | 110 | import trimesh 111 | from foundation_pose.Utils import draw_posed_3d_box, draw_xyz_axis 112 | 113 | def get_vis_pose(pose, color, K, mesh): 114 | to_origin, extents = trimesh.bounds.oriented_bounds(mesh) 115 | bbox = np.stack([-extents/2, extents/2], axis=0).reshape(2,3) 116 | 117 | 
center_pose = pose @ np.linalg.inv(to_origin) 118 | 119 | vis = draw_posed_3d_box(K, img=color, ob_in_cam=center_pose, bbox=bbox) 120 | vis = draw_xyz_axis(color, ob_in_cam=center_pose, scale=0.1, K=K, thickness=3, transparency=0, is_input_rgb=True) 121 | return vis --------------------------------------------------------------------------------