├── .gitignore ├── README.md ├── configs ├── __init__.py ├── config.yaml ├── data │ └── default.yaml ├── datasets │ └── default.yaml ├── eval │ └── default.yaml ├── lifelong │ ├── agem.yaml │ ├── base.yaml │ ├── er.yaml │ ├── ewc.yaml │ ├── groot_single_task.yaml │ ├── multitask.yaml │ ├── packnet.yaml │ ├── single_task.yaml │ └── vos_3d_single_task_low_precision.yaml ├── policy │ ├── bc_mae_policy.yaml │ ├── bc_mae_policy.yaml~ │ ├── bc_peract_policy.yaml~ │ ├── bc_real_robot_rnn_rgbd_policy.yaml │ ├── bc_rnn_policy.yaml │ ├── bc_rnn_rgbd_policy.yaml │ ├── bc_transformer_policy.yaml │ ├── bc_vilt_policy.yaml │ ├── bc_viola_policy.yaml │ ├── data_augmentation │ │ ├── batch_wise_img_color_jitter_group_aug.yaml │ │ ├── identity_aug.yaml │ │ ├── img_color_jitter_group_aug.yaml │ │ ├── se3_augmentation.yaml │ │ ├── se3_augmentation_2.yaml │ │ ├── se3_augmentation_3.yaml │ │ ├── translation_aug.yaml │ │ └── translation_aug_group.yaml │ ├── groot_real_robot_no_wrist_transformer_policy.yaml │ ├── image_encoder │ │ ├── patch_encoder.yaml │ │ └── resnet_encoder.yaml │ ├── language_encoder │ │ ├── clip_encoder.yaml │ │ ├── identity_encoder.yaml │ │ ├── mlp_encoder.yaml │ │ └── rnn_encoder.yaml │ ├── pcd_encoder │ │ ├── masked_pointnet_encoder.yaml │ │ └── pointnet_encoder.yaml │ ├── policy_head │ │ └── gmm_head.yaml │ ├── position_encoding │ │ └── sinusoidal_position_encoding.yaml │ ├── vos_3d_masked_transformer_policy.yaml │ ├── vos_3d_object_transformer_policy.yaml │ ├── vos_3d_real_robot_masked_no_wrist_transformer_policy.yaml │ ├── vos_3d_real_robot_object_transformer_policy.yaml │ ├── vos_3d_robot_embeddings_masked_transformer_policy.yaml │ ├── vos_3d_robot_masked_ee_transformer_policy.yaml │ ├── vos_3d_robot_masked_no_wrist_transformer_policy.yaml │ ├── vos_3d_robot_masked_transformer_policy.yaml │ ├── vos_3d_transformer_policy.yaml │ ├── vos_3d_wrist_transformer_policy.yaml │ ├── vos_3d_wrist_transformer_policy_v2.yaml │ └── wrist_transformer_policy.yaml ├── real_robot_experiments │ └── example_exp.yaml ├── templates │ ├── augmentation_template.yaml │ ├── grouping_template.yaml │ └── masked_transformer.yaml └── train │ ├── default.yaml │ ├── optimizer │ └── adam_w.yaml │ └── scheduler │ ├── cosine_annealing.yaml │ └── cosine_annealing_warm_restarts.yaml ├── docs └── scribble_demo.gif ├── download_example_data.sh ├── groot_imitation ├── __init__.py ├── groot_algo │ ├── __init__.py │ ├── dataset_preprocessing │ │ ├── __init__.py │ │ ├── pcd_generation.py │ │ └── vos_annotation.py │ ├── dino_features.py │ ├── env_wrapper.py │ ├── eval_utils.py │ ├── groot_transformer │ │ ├── __init__.py │ │ ├── groot_real_robot_no_wrist_transformer_policy.py │ │ └── groot_transformer_algo.py │ ├── input_utils.py │ ├── misc_utils.py │ ├── modules.py │ ├── o3d_modules.py │ ├── point_mae_modules.py │ ├── sam_operator.py │ └── xmem_tracker.py ├── segmentation_correspondence_model │ ├── init_path.py │ ├── sam_amg.py │ └── scm.py └── vision_model_configs │ ├── sam_config.yaml │ └── xmem_config.yaml ├── real_robot_scripts ├── __init__.py ├── create_dataset.py ├── deoxys_data_collection.py ├── deoxys_reset_joints.py ├── evaluation_configs │ ├── eval_camera.yaml │ ├── eval_canonical.yaml │ ├── eval_distractions.yaml │ └── eval_new_instances.yaml ├── groot_img_utils.py ├── init_path.py ├── real_robot_observation_cfg_example.yml ├── real_robot_sam_result.py ├── real_robot_utils.py ├── verify_camera_observations.ipynb └── verify_camera_observations.py ├── requirements.txt ├── scripts ├── aug_demo_generation.py ├── 
dataset_configs │ ├── config.yaml │ └── datasets │ │ └── default.yaml ├── init_path.py ├── interactive_demo_example.py ├── interactive_demo_from_datasets.py ├── prepare_training_set.py ├── real_robot_eval_checkpoint.py ├── single_task_real_robot_training.py └── single_task_training.py ├── setup.py ├── setup_vision_models.sh ├── third_party └── XMem │ ├── .gitignore │ ├── LICENSE │ ├── README.md │ ├── dataset │ ├── __init__.py │ ├── range_transform.py │ ├── reseed.py │ ├── static_dataset.py │ ├── tps.py │ ├── util.py │ └── vos_dataset.py │ ├── docs │ ├── DEMO.md │ ├── ECCV-logo.png │ ├── FAILURE_CASES.md │ ├── GETTING_STARTED.md │ ├── INFERENCE.md │ ├── PALETTE.md │ ├── RESULTS.md │ ├── TRAINING.md │ ├── icon.png │ ├── index.html │ └── style.css │ ├── eval.py │ ├── inference │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── mask_mapper.py │ │ ├── test_datasets.py │ │ └── video_reader.py │ ├── inference_core.py │ ├── interact │ │ ├── __init__.py │ │ ├── fbrs │ │ │ ├── LICENSE │ │ │ ├── __init__.py │ │ │ ├── controller.py │ │ │ ├── inference │ │ │ │ ├── __init__.py │ │ │ │ ├── clicker.py │ │ │ │ ├── evaluation.py │ │ │ │ ├── predictors │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── base.py │ │ │ │ │ ├── brs.py │ │ │ │ │ ├── brs_functors.py │ │ │ │ │ └── brs_losses.py │ │ │ │ ├── transforms │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── base.py │ │ │ │ │ ├── crops.py │ │ │ │ │ ├── flip.py │ │ │ │ │ ├── limit_longest_side.py │ │ │ │ │ └── zoom_in.py │ │ │ │ └── utils.py │ │ │ ├── model │ │ │ │ ├── __init__.py │ │ │ │ ├── initializer.py │ │ │ │ ├── is_deeplab_model.py │ │ │ │ ├── is_hrnet_model.py │ │ │ │ ├── losses.py │ │ │ │ ├── metrics.py │ │ │ │ ├── modeling │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── basic_blocks.py │ │ │ │ │ ├── deeplab_v3.py │ │ │ │ │ ├── hrnet_ocr.py │ │ │ │ │ ├── ocr.py │ │ │ │ │ ├── resnet.py │ │ │ │ │ └── resnetv1b.py │ │ │ │ ├── ops.py │ │ │ │ └── syncbn │ │ │ │ │ ├── LICENSE │ │ │ │ │ ├── README.md │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── modules │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── functional │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── _csrc.py │ │ │ │ │ ├── csrc │ │ │ │ │ │ ├── bn.h │ │ │ │ │ │ ├── cuda │ │ │ │ │ │ │ ├── bn_cuda.cu │ │ │ │ │ │ │ ├── common.h │ │ │ │ │ │ │ └── ext_lib.h │ │ │ │ │ │ └── ext_lib.cpp │ │ │ │ │ └── syncbn.py │ │ │ │ │ └── nn │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── syncbn.py │ │ │ └── utils │ │ │ │ ├── __init__.py │ │ │ │ ├── cython │ │ │ │ ├── __init__.py │ │ │ │ ├── _get_dist_maps.pyx │ │ │ │ ├── _get_dist_maps.pyxbld │ │ │ │ └── dist_maps.py │ │ │ │ ├── misc.py │ │ │ │ └── vis.py │ │ ├── fbrs_controller.py │ │ ├── gui.py │ │ ├── gui_utils.py │ │ ├── interaction.py │ │ ├── interactive_utils.py │ │ ├── resource_manager.py │ │ ├── s2m │ │ │ ├── __init__.py │ │ │ ├── _deeplab.py │ │ │ ├── s2m_network.py │ │ │ ├── s2m_resnet.py │ │ │ └── utils.py │ │ ├── s2m_controller.py │ │ └── timer.py │ ├── kv_memory_store.py │ └── memory_manager.py │ ├── interactive_demo.py │ ├── merge_multi_scale.py │ ├── model │ ├── __init__.py │ ├── aggregate.py │ ├── cbam.py │ ├── group_modules.py │ ├── losses.py │ ├── memory_util.py │ ├── modules.py │ ├── network.py │ ├── resnet.py │ └── trainer.py │ ├── requirements.txt │ ├── requirements_demo.txt │ ├── scripts │ ├── __init__.py │ ├── download_bl30k.py │ ├── download_datasets.py │ ├── download_models.sh │ ├── download_models_demo.sh │ ├── expand_long_vid.py │ └── resize_youtube.py │ ├── train.py │ └── util │ ├── __init__.py │ ├── configuration.py │ ├── davis_subset.txt │ ├── image_saver.py │ ├── load_subset.py │ ├── 
log_integrator.py │ ├── logger.py │ ├── palette.py │ ├── tensor_util.py │ └── yv_subset.txt └── walkthrough_example.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | /finalize_result.py~ 3 | /result_annotation.py~ 4 | *.ply 5 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/GROOT/3d9eceae26c599c325f9457658ab41d2dd2a1319/configs/__init__.py -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - data: default 6 | - datasets: default 7 | - policy: groot_real_robot_no_wrist_transformer_policy 8 | - train: default 9 | - eval: default 10 | - lifelong: groot_single_task 11 | - real_robot_experiments: null 12 | - test: null 13 | 14 | seed: 10000 15 | use_wandb: false 16 | wandb_project: "lifelong learning" 17 | folder: null # use default path 18 | bddl_folder: null # use default path 19 | init_states_folder: null # use default path 20 | load_previous_model: false 21 | device: "cuda" 22 | task_embedding_format: "bert" 23 | task_embedding_one_hot_offset: 1 24 | pretrain: false 25 | pretrain_model_path: "" 26 | benchmark_name: "GROOT_Real_Robot_Benchmark" 27 | task_id: 0 28 | 29 | 30 | dataset_path: null 31 | save_video: true 32 | 33 | vos_annotation: true 34 | object_pcd: true 35 | pcd_aug: true 36 | pcd_grouping: true 37 | delete_intermediate_files: true 38 | -------------------------------------------------------------------------------- /configs/data/default.yaml: -------------------------------------------------------------------------------- 1 | # observation related 2 | data_modality: 3 | - "image" 4 | - "proprio" 5 | seq_len: 10 6 | frame_stack: 1 7 | use_eye_in_hand: true 8 | use_gripper: true 9 | use_joint: true 10 | use_ee: false 11 | 12 | max_word_len: 25 13 | 14 | state_dim: null 15 | num_kp: 64 16 | img_h: 128 17 | img_w: 128 18 | 19 | task_group_size: 1 20 | task_order_index: 0 21 | shuffle_task: false 22 | 23 | obs: 24 | modality: 25 | rgb: [] # ["agentview_rgb", "eye_in_hand_rgb"] 26 | depth: [] 27 | low_dim: ["gripper_states", "joint_states"] 28 | pcd: ["xyz"] 29 | wrist_depth: [] 30 | 31 | # mapping from obs.modality keys to robosuite environment observation keys 32 | obs_key_mapping: 33 | agentview_rgb: agentview_image 34 | eye_in_hand_rgb: robot0_eye_in_hand_image 35 | eye_in_hand_depth: robot0_eye_in_hand_depth 36 | gripper_states: robot0_gripper_qpos 37 | joint_states: robot0_joint_pos 38 | ee_states: ee_states 39 | xyz: xyz 40 | # This is point cloud rgb, not images 41 | rgb: rgb 42 | neighborhood_10_64: neighborhood_10_64 43 | centers_10_64: centers_10_64 44 | 45 | # action related 46 | affine_translate: 4 47 | action_scale: 1.0 48 | train_dataset_ratio: 0.8 49 | -------------------------------------------------------------------------------- /configs/datasets/default.yaml: -------------------------------------------------------------------------------- 1 | max_points: 512 2 | num_group: 10 3 | group_size: 64 4 | 5 | # dataset point cloud normalization range 6 | max_array: [0.69943695, 0.5, 0.45784091] 7 | min_array: [0.0, -0.5, 0.0] 8 | 9 | # configuration for augmenting depth point clouds 10 | aug: 11 | workspace_center: [0.6, 0.0, 0.0] 12 | 
rotations: [30] 13 | 14 | erode: false -------------------------------------------------------------------------------- /configs/eval/default.yaml: -------------------------------------------------------------------------------- 1 | load_path: "" # only used when separately evaluating a pretrained model 2 | eval: true 3 | batch_size: 64 4 | num_workers: 4 5 | n_eval: 20 6 | eval_every: 9 7 | max_steps: 600 8 | use_mp: true 9 | num_procs: 20 10 | save_sim_states: false 11 | -------------------------------------------------------------------------------- /configs/lifelong/agem.yaml: -------------------------------------------------------------------------------- 1 | algo: AGEM 2 | n_memories: 1000 3 | -------------------------------------------------------------------------------- /configs/lifelong/base.yaml: -------------------------------------------------------------------------------- 1 | algo: Sequential 2 | -------------------------------------------------------------------------------- /configs/lifelong/er.yaml: -------------------------------------------------------------------------------- 1 | algo: ER 2 | n_memories: 1000 3 | -------------------------------------------------------------------------------- /configs/lifelong/ewc.yaml: -------------------------------------------------------------------------------- 1 | algo: EWC 2 | e_lambda: 50000 3 | gamma: 0.9 4 | -------------------------------------------------------------------------------- /configs/lifelong/groot_single_task.yaml: -------------------------------------------------------------------------------- 1 | algo: GROOTSingleTask 2 | 3 | -------------------------------------------------------------------------------- /configs/lifelong/multitask.yaml: -------------------------------------------------------------------------------- 1 | algo: Multitask 2 | eval_in_train: false 3 | -------------------------------------------------------------------------------- /configs/lifelong/packnet.yaml: -------------------------------------------------------------------------------- 1 | algo: PackNet 2 | prune_perc: 0.75 3 | post_prune_epochs: 50 4 | post_eval_every: 5 5 | -------------------------------------------------------------------------------- /configs/lifelong/single_task.yaml: -------------------------------------------------------------------------------- 1 | algo: SingleTask 2 | -------------------------------------------------------------------------------- /configs/lifelong/vos_3d_single_task_low_precision.yaml: -------------------------------------------------------------------------------- 1 | algo: VOS3DSingleTaskLowPrecision 2 | 3 | -------------------------------------------------------------------------------- /configs/policy/bc_mae_policy.yaml: -------------------------------------------------------------------------------- 1 | policy_type: BCMAEPolicy 2 | extra_num_layers: 0 3 | extra_hidden_size: 128 4 | embed_size: 128 5 | 6 | spatial_transformer_input_size: null 7 | spatial_transformer_num_layers: 7 8 | spatial_transformer_num_heads: 8 9 | spatial_transformer_head_output_size: 120 10 | spatial_transformer_mlp_hidden_size: 256 11 | spatial_transformer_dropout: 0.1 12 | 13 | spatial_down_sample: true 14 | spatial_down_sample_embed_size: 64 15 | 16 | transformer_input_size: null 17 | transformer_num_layers: 4 18 | transformer_num_heads: 6 19 | transformer_head_output_size: 64 20 | transformer_mlp_hidden_size: 256 21 | transformer_dropout: 0.1 22 | transformer_max_seq_len: 10 23 | 24 | defaults: 25 | - 
data_augmentation@color_aug: batch_wise_img_color_jitter_group_aug.yaml 26 | - data_augmentation@translation_aug: translation_aug.yaml 27 | - image_encoder: patch_encoder.yaml 28 | - language_encoder: mlp_encoder.yaml 29 | - position_encoding@temporal_position_encoding: sinusoidal_position_encoding.yaml 30 | - policy_head: gmm_head.yaml 31 | -------------------------------------------------------------------------------- /configs/policy/bc_mae_policy.yaml~: -------------------------------------------------------------------------------- 1 | policy_type: BCPeractPolicy 2 | extra_num_layers: 0 3 | extra_hidden_size: 128 4 | embed_size: 64 5 | 6 | transformer_input_size: null 7 | transformer_num_layers: 4 8 | transformer_num_heads: 6 9 | transformer_head_output_size: 64 10 | transformer_mlp_hidden_size: 256 11 | transformer_dropout: 0.1 12 | transformer_max_seq_len: 10 13 | 14 | defaults: 15 | - data_augmentation@color_aug: batch_wise_img_color_jitter_group_aug.yaml 16 | - data_augmentation@translation_aug: translation_aug.yaml 17 | - data_augmentation@pcd_aug: se3_augmentation_3.yaml 18 | - image_encoder@wrist_encoder: resnet_encoder.yaml 19 | # - image_encoder: resnet_encoder.yaml 20 | - pcd_encoder: pointnet_encoder.yaml 21 | - language_encoder: mlp_encoder.yaml 22 | - position_encoding@temporal_position_encoding: sinusoidal_position_encoding.yaml 23 | - policy_head: gmm_head.yaml 24 | -------------------------------------------------------------------------------- /configs/policy/bc_peract_policy.yaml~: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/GROOT/3d9eceae26c599c325f9457658ab41d2dd2a1319/configs/policy/bc_peract_policy.yaml~ -------------------------------------------------------------------------------- /configs/policy/bc_real_robot_rnn_rgbd_policy.yaml: -------------------------------------------------------------------------------- 1 | policy_type: BCRealRobotRNNRGBDPolicy 2 | image_embed_size: 64 3 | text_embed_size: 32 4 | 5 | rnn_hidden_size: 1024 6 | rnn_num_layers: 2 7 | rnn_dropout: 0.0 8 | rnn_bidirectional: false 9 | 10 | defaults: 11 | - data_augmentation@color_aug: batch_wise_img_color_jitter_group_aug.yaml 12 | - data_augmentation@translation_aug: translation_aug.yaml 13 | - image_encoder: resnet_encoder 14 | - language_encoder: mlp_encoder 15 | - policy_head: gmm_head 16 | -------------------------------------------------------------------------------- /configs/policy/bc_rnn_policy.yaml: -------------------------------------------------------------------------------- 1 | policy_type: BCRNNPolicy 2 | image_embed_size: 64 3 | text_embed_size: 32 4 | 5 | rnn_hidden_size: 1024 6 | rnn_num_layers: 2 7 | rnn_dropout: 0.0 8 | rnn_bidirectional: false 9 | 10 | defaults: 11 | - data_augmentation@color_aug: batch_wise_img_color_jitter_group_aug.yaml 12 | - data_augmentation@translation_aug: translation_aug.yaml 13 | - image_encoder: resnet_encoder 14 | - language_encoder: mlp_encoder 15 | - policy_head: gmm_head 16 | -------------------------------------------------------------------------------- /configs/policy/bc_rnn_rgbd_policy.yaml: -------------------------------------------------------------------------------- 1 | policy_type: BCRNNRGBDPolicy 2 | image_embed_size: 64 3 | text_embed_size: 32 4 | 5 | rnn_hidden_size: 1024 6 | rnn_num_layers: 2 7 | rnn_dropout: 0.0 8 | rnn_bidirectional: false 9 | 10 | defaults: 11 | - data_augmentation@color_aug: batch_wise_img_color_jitter_group_aug.yaml 12 
| - data_augmentation@translation_aug: translation_aug.yaml 13 | - image_encoder: resnet_encoder 14 | - language_encoder: mlp_encoder 15 | - policy_head: gmm_head 16 | -------------------------------------------------------------------------------- /configs/policy/bc_transformer_policy.yaml: -------------------------------------------------------------------------------- 1 | policy_type: BCTransformerPolicy 2 | extra_num_layers: 0 3 | extra_hidden_size: 128 4 | embed_size: 64 5 | 6 | transformer_input_size: null 7 | transformer_num_layers: 4 8 | transformer_num_heads: 6 9 | transformer_head_output_size: 64 10 | transformer_mlp_hidden_size: 256 11 | transformer_dropout: 0.1 12 | transformer_max_seq_len: 10 13 | 14 | defaults: 15 | - data_augmentation@color_aug: batch_wise_img_color_jitter_group_aug.yaml 16 | - data_augmentation@translation_aug: translation_aug.yaml 17 | - image_encoder: resnet_encoder.yaml 18 | - language_encoder: mlp_encoder.yaml 19 | - position_encoding@temporal_position_encoding: sinusoidal_position_encoding.yaml 20 | - policy_head: gmm_head.yaml 21 | -------------------------------------------------------------------------------- /configs/policy/bc_vilt_policy.yaml: -------------------------------------------------------------------------------- 1 | policy_type: BCViLTPolicy 2 | extra_num_layers: 0 3 | extra_hidden_size: 128 4 | embed_size: 128 5 | 6 | spatial_transformer_input_size: null 7 | spatial_transformer_num_layers: 7 8 | spatial_transformer_num_heads: 8 9 | spatial_transformer_head_output_size: 120 10 | spatial_transformer_mlp_hidden_size: 256 11 | spatial_transformer_dropout: 0.1 12 | 13 | spatial_down_sample: true 14 | spatial_down_sample_embed_size: 64 15 | 16 | transformer_input_size: null 17 | transformer_num_layers: 4 18 | transformer_num_heads: 6 19 | transformer_head_output_size: 64 20 | transformer_mlp_hidden_size: 256 21 | transformer_dropout: 0.1 22 | transformer_max_seq_len: 10 23 | 24 | defaults: 25 | - data_augmentation@color_aug: batch_wise_img_color_jitter_group_aug.yaml 26 | - data_augmentation@translation_aug: translation_aug.yaml 27 | - image_encoder: patch_encoder.yaml 28 | - language_encoder: mlp_encoder.yaml 29 | - position_encoding@temporal_position_encoding: sinusoidal_position_encoding.yaml 30 | - policy_head: gmm_head.yaml 31 | -------------------------------------------------------------------------------- /configs/policy/bc_viola_policy.yaml: -------------------------------------------------------------------------------- 1 | policy_type: BCViolaPolicy 2 | extra_num_layers: 0 3 | extra_hidden_size: 128 4 | embed_size: 64 5 | 6 | transformer_input_size: null 7 | transformer_num_layers: 4 8 | transformer_num_heads: 6 9 | transformer_head_output_size: 64 10 | transformer_mlp_hidden_size: 256 11 | transformer_dropout: 0.1 12 | transformer_max_seq_len: 10 13 | 14 | defaults: 15 | - data_augmentation@color_aug: batch_wise_img_color_jitter_group_aug.yaml 16 | - data_augmentation@translation_aug: translation_aug.yaml 17 | - data_augmentation@pcd_aug: se3_augmentation_3.yaml 18 | - image_encoder@wrist_encoder: resnet_encoder.yaml 19 | # - image_encoder: resnet_encoder.yaml 20 | - pcd_encoder: pointnet_encoder.yaml 21 | - language_encoder: mlp_encoder.yaml 22 | - position_encoding@temporal_position_encoding: sinusoidal_position_encoding.yaml 23 | - policy_head: gmm_head.yaml 24 | -------------------------------------------------------------------------------- /configs/policy/data_augmentation/batch_wise_img_color_jitter_group_aug.yaml: 
-------------------------------------------------------------------------------- 1 | network: BatchWiseImgColorJitterAug 2 | 3 | network_kwargs: 4 | input_shape: null 5 | brightness: 0.3 6 | contrast: 0.3 7 | saturation: 0.3 8 | hue: 0.3 9 | epsilon: 0.1 -------------------------------------------------------------------------------- /configs/policy/data_augmentation/identity_aug.yaml: -------------------------------------------------------------------------------- 1 | network: IdentityAug 2 | 3 | network_kwargs: 4 | input_shape: null -------------------------------------------------------------------------------- /configs/policy/data_augmentation/img_color_jitter_group_aug.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/GROOT/3d9eceae26c599c325f9457658ab41d2dd2a1319/configs/policy/data_augmentation/img_color_jitter_group_aug.yaml -------------------------------------------------------------------------------- /configs/policy/data_augmentation/se3_augmentation.yaml: -------------------------------------------------------------------------------- 1 | network: SE3Augmentation 2 | 3 | network_kwargs: 4 | mean: 0.0 5 | std: 0.02 6 | enabled: True 7 | rot_range: [0.1, 0.1, 0.1] -------------------------------------------------------------------------------- /configs/policy/data_augmentation/se3_augmentation_2.yaml: -------------------------------------------------------------------------------- 1 | network: SE3Augmentation2 2 | 3 | network_kwargs: 4 | mean: 0.0 5 | std: 0.02 6 | enabled: True 7 | use_position: True 8 | use_rotation: True 9 | location_range: [1., 1., 1.] 10 | rot_range: [0.1, 0.1, 0.1] -------------------------------------------------------------------------------- /configs/policy/data_augmentation/se3_augmentation_3.yaml: -------------------------------------------------------------------------------- 1 | network: SE3Augmentation3 2 | 3 | network_kwargs: 4 | mean: 0.0 5 | std: 0.002 6 | enabled: True 7 | use_position: True 8 | use_rotation: True 9 | location_range: [1., 1., 1.] 
10 | rot_range: [0.1, 0.1, 0.1] -------------------------------------------------------------------------------- /configs/policy/data_augmentation/translation_aug.yaml: -------------------------------------------------------------------------------- 1 | network: TranslationAug 2 | 3 | network_kwargs: 4 | input_shape: null 5 | translation: 8 -------------------------------------------------------------------------------- /configs/policy/data_augmentation/translation_aug_group.yaml: -------------------------------------------------------------------------------- 1 | network: TranslationAugGroup 2 | 3 | network_kwargs: 4 | input_shapes: {} 5 | translation: 8 -------------------------------------------------------------------------------- /configs/policy/groot_real_robot_no_wrist_transformer_policy.yaml: -------------------------------------------------------------------------------- 1 | policy_type: GROOTRealRobotNoWristTransformerPolicy 2 | extra_num_layers: 0 3 | extra_hidden_size: 128 4 | embed_size: 64 5 | 6 | transformer_input_size: null 7 | transformer_num_layers: 4 8 | transformer_num_heads: 6 9 | transformer_head_output_size: 64 10 | transformer_mlp_hidden_size: 256 11 | transformer_dropout: 0.1 12 | transformer_max_seq_len: 10 13 | 14 | defaults: 15 | - data_augmentation@color_aug: batch_wise_img_color_jitter_group_aug.yaml 16 | - data_augmentation@translation_aug: translation_aug.yaml 17 | - data_augmentation@pcd_aug: se3_augmentation_3.yaml 18 | - image_encoder@wrist_encoder: resnet_encoder.yaml 19 | # - image_encoder: resnet_encoder.yaml 20 | - pcd_encoder: masked_pointnet_encoder.yaml 21 | - language_encoder: mlp_encoder.yaml 22 | - position_encoding@temporal_position_encoding: sinusoidal_position_encoding.yaml 23 | - policy_head: gmm_head.yaml 24 | -------------------------------------------------------------------------------- /configs/policy/image_encoder/patch_encoder.yaml: -------------------------------------------------------------------------------- 1 | network: PatchEncoder 2 | network_kwargs: 3 | patch_size: [8, 8] 4 | no_patch_embed_bias: false 5 | -------------------------------------------------------------------------------- /configs/policy/image_encoder/resnet_encoder.yaml: -------------------------------------------------------------------------------- 1 | network: ResnetEncoder 2 | network_kwargs: 3 | pretrained: false 4 | freeze: false 5 | remove_layer_num: 4 6 | no_stride: false 7 | language_fusion: 'film' 8 | -------------------------------------------------------------------------------- /configs/policy/language_encoder/clip_encoder.yaml: -------------------------------------------------------------------------------- 1 | network: CLIPEncoder 2 | network_kwargs: 3 | model_type: "ViT-B/32" 4 | hidden_size: 128 5 | output_size: 128 6 | num_layers: 1 7 | download_path: "./clip" 8 | -------------------------------------------------------------------------------- /configs/policy/language_encoder/identity_encoder.yaml: -------------------------------------------------------------------------------- 1 | network: IdentityEncoder 2 | network_kwargs: 3 | dummy: true 4 | -------------------------------------------------------------------------------- /configs/policy/language_encoder/mlp_encoder.yaml: -------------------------------------------------------------------------------- 1 | network: MLPEncoder 2 | network_kwargs: 3 | input_size: 768 4 | hidden_size: 128 5 | output_size: 128 6 | num_layers: 1 7 | 
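The policy configs in this directory are assembled by Hydra: each policy file's defaults list pulls in one file from image_encoder/, language_encoder/, pcd_encoder/, data_augmentation/, position_encoding/, and policy_head/, and package overrides such as data_augmentation@color_aug decide under which key that sub-config lands inside the policy node. The sketch below inspects the composed tree with Hydra's compose API; it assumes a recent hydra-core and omegaconf install and a script placed at the repository root so that config_path="configs" resolves, and it is only an illustration, not one of the repository's training entry points (those live under scripts/).

from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose configs/config.yaml the way a training script would.
# (Assumption: this snippet lives at the repo root, so "configs" is resolvable.)
with initialize(config_path="configs", version_base=None):
    cfg = compose(config_name="config")

# config.yaml selects groot_real_robot_no_wrist_transformer_policy by default.
print(cfg.policy.policy_type)

# The `language_encoder: mlp_encoder.yaml` default shows up under cfg.policy.language_encoder,
# while the @color_aug / @translation_aug package overrides land under those keys.
print(OmegaConf.to_yaml(cfg.policy.language_encoder))
print(cfg.policy.color_aug.network, cfg.policy.translation_aug.network)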
-------------------------------------------------------------------------------- /configs/policy/language_encoder/rnn_encoder.yaml: -------------------------------------------------------------------------------- 1 | network: RNNEncoder 2 | network_kwargs: 3 | input_size: 768 4 | hidden_size: 128 5 | output_size: 16 6 | num_layers: 1 7 | -------------------------------------------------------------------------------- /configs/policy/pcd_encoder/masked_pointnet_encoder.yaml: -------------------------------------------------------------------------------- 1 | network: MaskedPointNetEncoder 2 | network_kwargs: 3 | group_cfg: 4 | num_group: 10 5 | group_size: 64 6 | masked_encoder_cfg: 7 | mask_ratio: 0.6 8 | mask_type: "rand" 9 | embed_dim: null 10 | output_size: null -------------------------------------------------------------------------------- /configs/policy/pcd_encoder/pointnet_encoder.yaml: -------------------------------------------------------------------------------- 1 | network: PointNetEncoder 2 | network_kwargs: 3 | input_shape: null 4 | output_size: 32 5 | language_dim: null 6 | global_feat: True 7 | -------------------------------------------------------------------------------- /configs/policy/policy_head/gmm_head.yaml: -------------------------------------------------------------------------------- 1 | network: GMMHead 2 | 3 | network_kwargs: 4 | hidden_size: 1024 5 | num_layers: 2 6 | min_std: 0.0001 7 | num_modes: 5 8 | low_eval_noise: false 9 | activation: "softplus" 10 | 11 | loss_kwargs: 12 | loss_coef: 1.0 13 | -------------------------------------------------------------------------------- /configs/policy/position_encoding/sinusoidal_position_encoding.yaml: -------------------------------------------------------------------------------- 1 | network: SinusoidalPositionEncoding 2 | network_kwargs: 3 | input_size: null 4 | inv_freq_factor: 10 5 | factor_ratio: null 6 | 7 | -------------------------------------------------------------------------------- /configs/policy/vos_3d_masked_transformer_policy.yaml: -------------------------------------------------------------------------------- 1 | policy_type: VOS3DMaskedTransformerPolicy 2 | extra_num_layers: 0 3 | extra_hidden_size: 128 4 | embed_size: 64 5 | 6 | transformer_input_size: null 7 | transformer_num_layers: 4 8 | transformer_num_heads: 6 9 | transformer_head_output_size: 64 10 | transformer_mlp_hidden_size: 256 11 | transformer_dropout: 0.1 12 | transformer_max_seq_len: 10 13 | 14 | defaults: 15 | - data_augmentation@color_aug: batch_wise_img_color_jitter_group_aug.yaml 16 | - data_augmentation@translation_aug: translation_aug.yaml 17 | - data_augmentation@pcd_aug: se3_augmentation_3.yaml 18 | - image_encoder@wrist_encoder: resnet_encoder.yaml 19 | # - image_encoder: resnet_encoder.yaml 20 | - pcd_encoder: masked_pointnet_encoder.yaml 21 | - language_encoder: mlp_encoder.yaml 22 | - position_encoding@temporal_position_encoding: sinusoidal_position_encoding.yaml 23 | - policy_head: gmm_head.yaml 24 | -------------------------------------------------------------------------------- /configs/policy/vos_3d_object_transformer_policy.yaml: -------------------------------------------------------------------------------- 1 | policy_type: VOS3DObjectTransformerPolicy 2 | extra_num_layers: 0 3 | extra_hidden_size: 128 4 | embed_size: 64 5 | 6 | transformer_input_size: null 7 | transformer_num_layers: 4 8 | transformer_num_heads: 6 9 | transformer_head_output_size: 64 10 | transformer_mlp_hidden_size: 256 11 | 
transformer_dropout: 0.1 12 | transformer_max_seq_len: 10 13 | 14 | defaults: 15 | - data_augmentation@color_aug: batch_wise_img_color_jitter_group_aug.yaml 16 | - data_augmentation@translation_aug: translation_aug.yaml 17 | - data_augmentation@pcd_aug: se3_augmentation_3.yaml 18 | - image_encoder@wrist_encoder: resnet_encoder.yaml 19 | # - image_encoder: resnet_encoder.yaml 20 | - pcd_encoder: pointnet_encoder.yaml 21 | - language_encoder: mlp_encoder.yaml 22 | - position_encoding@temporal_position_encoding: sinusoidal_position_encoding.yaml 23 | - policy_head: gmm_head.yaml 24 | -------------------------------------------------------------------------------- /configs/policy/vos_3d_real_robot_masked_no_wrist_transformer_policy.yaml: -------------------------------------------------------------------------------- 1 | policy_type: VOS3DRealRobotMaskedNoWristTransformerPolicy 2 | extra_num_layers: 0 3 | extra_hidden_size: 128 4 | embed_size: 64 5 | 6 | transformer_input_size: null 7 | transformer_num_layers: 4 8 | transformer_num_heads: 6 9 | transformer_head_output_size: 64 10 | transformer_mlp_hidden_size: 256 11 | transformer_dropout: 0.1 12 | transformer_max_seq_len: 10 13 | 14 | defaults: 15 | - data_augmentation@color_aug: batch_wise_img_color_jitter_group_aug.yaml 16 | - data_augmentation@translation_aug: translation_aug.yaml 17 | - data_augmentation@pcd_aug: se3_augmentation_3.yaml 18 | - image_encoder@wrist_encoder: resnet_encoder.yaml 19 | # - image_encoder: resnet_encoder.yaml 20 | - pcd_encoder: masked_pointnet_encoder.yaml 21 | - language_encoder: mlp_encoder.yaml 22 | - position_encoding@temporal_position_encoding: sinusoidal_position_encoding.yaml 23 | - policy_head: gmm_head.yaml 24 | -------------------------------------------------------------------------------- /configs/policy/vos_3d_real_robot_object_transformer_policy.yaml: -------------------------------------------------------------------------------- 1 | policy_type: VOS3DRealRobotObjectTransformerPolicy 2 | extra_num_layers: 0 3 | extra_hidden_size: 128 4 | embed_size: 64 5 | 6 | transformer_input_size: null 7 | transformer_num_layers: 4 8 | transformer_num_heads: 6 9 | transformer_head_output_size: 64 10 | transformer_mlp_hidden_size: 256 11 | transformer_dropout: 0.1 12 | transformer_max_seq_len: 10 13 | 14 | defaults: 15 | - data_augmentation@color_aug: batch_wise_img_color_jitter_group_aug.yaml 16 | - data_augmentation@translation_aug: translation_aug.yaml 17 | - data_augmentation@pcd_aug: se3_augmentation_3.yaml 18 | - image_encoder@wrist_encoder: resnet_encoder.yaml 19 | # - image_encoder: resnet_encoder.yaml 20 | - pcd_encoder: pointnet_encoder.yaml 21 | - language_encoder: mlp_encoder.yaml 22 | - position_encoding@temporal_position_encoding: sinusoidal_position_encoding.yaml 23 | - policy_head: gmm_head.yaml 24 | -------------------------------------------------------------------------------- /configs/policy/vos_3d_robot_embeddings_masked_transformer_policy.yaml: -------------------------------------------------------------------------------- 1 | policy_type: VOS3DRobotEmbeddingsMaskedTransformerPolicy 2 | extra_num_layers: 0 3 | extra_hidden_size: 128 4 | embed_size: 64 5 | 6 | transformer_input_size: null 7 | transformer_num_layers: 4 8 | transformer_num_heads: 6 9 | transformer_head_output_size: 64 10 | transformer_mlp_hidden_size: 256 11 | transformer_dropout: 0.1 12 | transformer_max_seq_len: 10 13 | 14 | defaults: 15 | - data_augmentation@color_aug: batch_wise_img_color_jitter_group_aug.yaml 16 | - 
data_augmentation@translation_aug: translation_aug.yaml 17 | - data_augmentation@pcd_aug: se3_augmentation_3.yaml 18 | - image_encoder@wrist_encoder: resnet_encoder.yaml 19 | # - image_encoder: resnet_encoder.yaml 20 | - pcd_encoder: masked_pointnet_encoder.yaml 21 | - language_encoder: mlp_encoder.yaml 22 | - position_encoding@temporal_position_encoding: sinusoidal_position_encoding.yaml 23 | - policy_head: gmm_head.yaml 24 | -------------------------------------------------------------------------------- /configs/policy/vos_3d_robot_masked_ee_transformer_policy.yaml: -------------------------------------------------------------------------------- 1 | policy_type: VOS3DRobotMaskedEETransformerPolicy 2 | extra_num_layers: 0 3 | extra_hidden_size: 128 4 | embed_size: 64 5 | 6 | transformer_input_size: null 7 | transformer_num_layers: 4 8 | transformer_num_heads: 6 9 | transformer_head_output_size: 64 10 | transformer_mlp_hidden_size: 256 11 | transformer_dropout: 0.1 12 | transformer_max_seq_len: 10 13 | 14 | defaults: 15 | - data_augmentation@color_aug: batch_wise_img_color_jitter_group_aug.yaml 16 | - data_augmentation@translation_aug: translation_aug.yaml 17 | - data_augmentation@pcd_aug: se3_augmentation_3.yaml 18 | - image_encoder@wrist_encoder: resnet_encoder.yaml 19 | # - image_encoder: resnet_encoder.yaml 20 | - pcd_encoder: masked_pointnet_encoder.yaml 21 | - language_encoder: mlp_encoder.yaml 22 | - position_encoding@temporal_position_encoding: sinusoidal_position_encoding.yaml 23 | - policy_head: gmm_head.yaml 24 | -------------------------------------------------------------------------------- /configs/policy/vos_3d_robot_masked_no_wrist_transformer_policy.yaml: -------------------------------------------------------------------------------- 1 | policy_type: VOS3DRobotMaskedNoWristTransformerPolicy 2 | extra_num_layers: 0 3 | extra_hidden_size: 128 4 | embed_size: 64 5 | 6 | transformer_input_size: null 7 | transformer_num_layers: 4 8 | transformer_num_heads: 6 9 | transformer_head_output_size: 64 10 | transformer_mlp_hidden_size: 256 11 | transformer_dropout: 0.1 12 | transformer_max_seq_len: 10 13 | 14 | defaults: 15 | - data_augmentation@color_aug: batch_wise_img_color_jitter_group_aug.yaml 16 | - data_augmentation@translation_aug: translation_aug.yaml 17 | - data_augmentation@pcd_aug: se3_augmentation_3.yaml 18 | - image_encoder@wrist_encoder: resnet_encoder.yaml 19 | # - image_encoder: resnet_encoder.yaml 20 | - pcd_encoder: masked_pointnet_encoder.yaml 21 | - language_encoder: mlp_encoder.yaml 22 | - position_encoding@temporal_position_encoding: sinusoidal_position_encoding.yaml 23 | - policy_head: gmm_head.yaml 24 | -------------------------------------------------------------------------------- /configs/policy/vos_3d_robot_masked_transformer_policy.yaml: -------------------------------------------------------------------------------- 1 | policy_type: VOS3DRobotMaskedTransformerPolicy 2 | extra_num_layers: 0 3 | extra_hidden_size: 128 4 | embed_size: 64 5 | 6 | transformer_input_size: null 7 | transformer_num_layers: 4 8 | transformer_num_heads: 6 9 | transformer_head_output_size: 64 10 | transformer_mlp_hidden_size: 256 11 | transformer_dropout: 0.1 12 | transformer_max_seq_len: 10 13 | 14 | defaults: 15 | - data_augmentation@color_aug: batch_wise_img_color_jitter_group_aug.yaml 16 | - data_augmentation@translation_aug: translation_aug.yaml 17 | - data_augmentation@pcd_aug: se3_augmentation_3.yaml 18 | - image_encoder@wrist_encoder: resnet_encoder.yaml 19 | # - 
image_encoder: resnet_encoder.yaml 20 | - pcd_encoder: masked_pointnet_encoder.yaml 21 | - language_encoder: mlp_encoder.yaml 22 | - position_encoding@temporal_position_encoding: sinusoidal_position_encoding.yaml 23 | - policy_head: gmm_head.yaml 24 | -------------------------------------------------------------------------------- /configs/policy/vos_3d_transformer_policy.yaml: -------------------------------------------------------------------------------- 1 | policy_type: VOS3DTransformerPolicy 2 | extra_num_layers: 0 3 | extra_hidden_size: 128 4 | embed_size: 64 5 | 6 | transformer_input_size: null 7 | transformer_num_layers: 4 8 | transformer_num_heads: 6 9 | transformer_head_output_size: 64 10 | transformer_mlp_hidden_size: 256 11 | transformer_dropout: 0.1 12 | transformer_max_seq_len: 10 13 | 14 | defaults: 15 | - data_augmentation@color_aug: batch_wise_img_color_jitter_group_aug.yaml 16 | - data_augmentation@translation_aug: translation_aug.yaml 17 | - data_augmentation@pcd_aug: se3_augmentation.yaml 18 | # - image_encoder: resnet_encoder.yaml 19 | - pcd_encoder: pointnet_encoder.yaml 20 | - language_encoder: mlp_encoder.yaml 21 | - position_encoding@temporal_position_encoding: sinusoidal_position_encoding.yaml 22 | - policy_head: gmm_head.yaml 23 | -------------------------------------------------------------------------------- /configs/policy/vos_3d_wrist_transformer_policy.yaml: -------------------------------------------------------------------------------- 1 | policy_type: VOS3DWristTransformerPolicy 2 | extra_num_layers: 0 3 | extra_hidden_size: 128 4 | embed_size: 64 5 | 6 | transformer_input_size: null 7 | transformer_num_layers: 4 8 | transformer_num_heads: 6 9 | transformer_head_output_size: 64 10 | transformer_mlp_hidden_size: 256 11 | transformer_dropout: 0.1 12 | transformer_max_seq_len: 10 13 | 14 | defaults: 15 | - data_augmentation@color_aug: batch_wise_img_color_jitter_group_aug.yaml 16 | - data_augmentation@translation_aug: translation_aug.yaml 17 | - data_augmentation@pcd_aug: se3_augmentation.yaml 18 | - image_encoder@wrist_encoder: resnet_encoder.yaml 19 | # - image_encoder: resnet_encoder.yaml 20 | - pcd_encoder: pointnet_encoder.yaml 21 | - language_encoder: mlp_encoder.yaml 22 | - position_encoding@temporal_position_encoding: sinusoidal_position_encoding.yaml 23 | - policy_head: gmm_head.yaml 24 | -------------------------------------------------------------------------------- /configs/policy/vos_3d_wrist_transformer_policy_v2.yaml: -------------------------------------------------------------------------------- 1 | policy_type: VOS3DWristTransformerPolicyV2 2 | extra_num_layers: 0 3 | extra_hidden_size: 128 4 | embed_size: 64 5 | 6 | transformer_input_size: null 7 | transformer_num_layers: 4 8 | transformer_num_heads: 6 9 | transformer_head_output_size: 64 10 | transformer_mlp_hidden_size: 256 11 | transformer_dropout: 0.1 12 | transformer_max_seq_len: 10 13 | 14 | defaults: 15 | - data_augmentation@color_aug: batch_wise_img_color_jitter_group_aug.yaml 16 | - data_augmentation@translation_aug: translation_aug.yaml 17 | - data_augmentation@pcd_aug: se3_augmentation_3.yaml 18 | - image_encoder@wrist_encoder: resnet_encoder.yaml 19 | # - image_encoder: resnet_encoder.yaml 20 | - pcd_encoder: pointnet_encoder.yaml 21 | - language_encoder: mlp_encoder.yaml 22 | - position_encoding@temporal_position_encoding: sinusoidal_position_encoding.yaml 23 | - policy_head: gmm_head.yaml 24 | -------------------------------------------------------------------------------- 
/configs/policy/wrist_transformer_policy.yaml: -------------------------------------------------------------------------------- 1 | policy_type: WristTransformerPolicy 2 | extra_num_layers: 0 3 | extra_hidden_size: 128 4 | embed_size: 64 5 | 6 | transformer_input_size: null 7 | transformer_num_layers: 4 8 | transformer_num_heads: 6 9 | transformer_head_output_size: 64 10 | transformer_mlp_hidden_size: 256 11 | transformer_dropout: 0.1 12 | transformer_max_seq_len: 10 13 | 14 | defaults: 15 | - data_augmentation@color_aug: batch_wise_img_color_jitter_group_aug.yaml 16 | - data_augmentation@translation_aug: translation_aug.yaml 17 | - image_encoder@wrist_encoder: resnet_encoder.yaml 18 | - language_encoder: mlp_encoder.yaml 19 | - position_encoding@temporal_position_encoding: sinusoidal_position_encoding.yaml 20 | - policy_head: gmm_head.yaml 21 | -------------------------------------------------------------------------------- /configs/real_robot_experiments/example_exp.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /policy: groot_real_robot_no_wrist_transformer_policy 5 | - override /policy/data_augmentation@policy.pcd_aug: se3_augmentation_3 6 | - override /train/scheduler: cosine_annealing_warm_restarts 7 | 8 | 9 | task_id: 3 10 | 11 | benchmark_name: GROOT_Real_Robot_Benchmark 12 | 13 | policy: 14 | pcd_aug: 15 | network_kwargs: 16 | mean: 0.0 17 | std: 0.001 18 | use_position: True 19 | enabled: True 20 | use_rotation: False 21 | rot_range: [0.0, 0.0, 0.0] 22 | # embed_size: 32 23 | # transformer_head_output_size: 32 24 | # transformer_mlp_hidden_size: 128 25 | pcd_encoder: 26 | network_kwargs: 27 | masked_encoder_cfg: 28 | mask_ratio: 0.75 29 | train: 30 | grad_clip: 100. 31 | n_epochs: 100 32 | # batch_size: 64 33 | 34 | data: 35 | obs: 36 | modality: 37 | pcd: ["xyz"] 38 | # normalized_pcd: ["xyz"] 39 | wrist_depth: [] 40 | low_dim: ["gripper_states", "joint_states"] 41 | grouped_pcd: ["neighborhood_10_64", "centers_10_64"] 42 | 43 | use_joint: true 44 | 45 | eval: 46 | eval_every: 5 47 | 48 | datasets: 49 | # dataset point cloud normalization range 50 | max_array: [0.69943695, 0.5, 0.45784091] 51 | min_array: [0.0, -0.5, 0.0] 52 | 53 | experiment_description: "normalized point clouds, canonical, no aug dataset" 54 | -------------------------------------------------------------------------------- /configs/templates/augmentation_template.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /policy: GROOT_robot_masked_transformer_policy 5 | - override /policy/data_augmentation@policy.pcd_aug: se3_augmentation_3 6 | - override /train/scheduler: cosine_annealing_warm_restarts 7 | 8 | benchmark_name: GROOT_Real_Robot_Benchmark 9 | 10 | task_id: 0 11 | 12 | policy: 13 | pcd_aug: 14 | network_kwargs: 15 | mean: 0.0 16 | std: 0.001 17 | use_position: True 18 | enabled: True 19 | use_rotation: False 20 | rot_range: [0.0, 0.0, 0.0] 21 | pcd_encoder: 22 | network_kwargs: 23 | group_cfg: 24 | num_group: 10 25 | group_size: 64 26 | 27 | train: 28 | grad_clip: 100. 
29 | n_epochs: 100 30 | # batch_size: 64 31 | 32 | data: 33 | obs: 34 | modality: 35 | pcd: [] 36 | normalized_pcd: ["xyz"] 37 | wrist_depth: ["eye_in_hand_depth"] 38 | low_dim: ["gripper_states", "joint_states"] 39 | grouped_pcd: ["neighborhood_10_64", "centers_10_64"] 40 | 41 | use_joint: true 42 | 43 | experiment_description: "" -------------------------------------------------------------------------------- /configs/templates/grouping_template.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /policy: groot_robot_masked_transformer_policy 5 | - override /policy/data_augmentation@policy.pcd_aug: se3_augmentation_3 6 | - override /train/scheduler: cosine_annealing_warm_restarts 7 | 8 | benchmark_name: GROOT_Real_Robot_Benchmark 9 | 10 | task_id: 1 11 | 12 | policy: 13 | pcd_aug: 14 | network_kwargs: 15 | mean: 0.0 16 | std: 0.001 17 | use_position: True 18 | enabled: True 19 | use_rotation: False 20 | rot_range: [0.0, 0.0, 0.0] 21 | pcd_encoder: 22 | network_kwargs: 23 | group_cfg: 24 | num_group: 10 25 | group_size: 64 26 | 27 | train: 28 | grad_clip: 100. 29 | n_epochs: 100 30 | # batch_size: 64 31 | 32 | data: 33 | obs: 34 | modality: 35 | pcd: [] 36 | normalized_pcd: ["xyz"] 37 | wrist_depth: ["eye_in_hand_depth"] 38 | low_dim: ["gripper_states", "joint_states"] 39 | grouped_pcd: ["neighborhood_5_32", "centers_5_32"] 40 | 41 | use_joint: true 42 | 43 | experiment_description: "" -------------------------------------------------------------------------------- /configs/templates/masked_transformer.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /policy: groot_masked_transformer_policy 5 | - override /policy/data_augmentation@policy.pcd_aug: se3_augmentation_3 6 | - override /train/scheduler: cosine_annealing_warm_restarts 7 | 8 | 9 | task_id: 2 10 | 11 | policy: 12 | pcd_aug: 13 | network_kwargs: 14 | mean: 0.0 15 | std: 0.002 16 | use_position: True 17 | enabled: True 18 | use_rotation: False 19 | rot_range: [0.0, 0.0, 0.0] 20 | 21 | train: 22 | grad_clip: 100. 23 | n_epochs: 100 24 | 25 | data: 26 | obs: 27 | modality: 28 | pcd: [] 29 | normalized_pcd: ["xyz"] 30 | wrist_depth: ["eye_in_hand_depth"] 31 | low_dim: ["gripper_states", "joint_states"] 32 | 33 | use_joint: true -------------------------------------------------------------------------------- /configs/train/default.yaml: -------------------------------------------------------------------------------- 1 | # training 2 | n_epochs: 100 3 | batch_size: 16 # 32 4 | num_workers: 4 5 | grad_clip: 100. 
6 | loss_scale: 1.0 7 | 8 | # resume training 9 | resume: false 10 | resume_path: "" 11 | debug: false 12 | 13 | use_augmentation: true 14 | 15 | defaults: 16 | - optimizer@optimizer: adam_w.yaml 17 | - scheduler@scheduler: cosine_annealing.yaml 18 | -------------------------------------------------------------------------------- /configs/train/optimizer/adam_w.yaml: -------------------------------------------------------------------------------- 1 | name: torch.optim.AdamW 2 | 3 | kwargs: 4 | lr: 0.0001 5 | betas: [0.9, 0.999] 6 | weight_decay: 0.0001 7 | -------------------------------------------------------------------------------- /configs/train/scheduler/cosine_annealing.yaml: -------------------------------------------------------------------------------- 1 | name: torch.optim.lr_scheduler.CosineAnnealingLR 2 | 3 | kwargs: 4 | eta_min: 1e-5 5 | last_epoch: -1 6 | -------------------------------------------------------------------------------- /configs/train/scheduler/cosine_annealing_warm_restarts.yaml: -------------------------------------------------------------------------------- 1 | name: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts 2 | 3 | kwargs: 4 | T_0: 1 5 | T_mult: 1 6 | eta_min: 1e-6 7 | last_epoch: -1 -------------------------------------------------------------------------------- /docs/scribble_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/GROOT/3d9eceae26c599c325f9457658ab41d2dd2a1319/docs/scribble_demo.gif -------------------------------------------------------------------------------- /download_example_data.sh: -------------------------------------------------------------------------------- 1 | wget -O frame.jpg https://utexas.box.com/shared/static/zammwzfn7b5utqv6y94ra9nx0zc729fy.jpg 2 | 3 | wget -O frame_annotation.png https://utexas.box.com/shared/static/18wy8v2fanqjzakl4b27y351626o6e0h.png 4 | 5 | wget -O example_demo.hdf5 https://utexas.box.com/shared/static/v68ixuug2j4edg537dgunml8patv4x1e.hdf5 6 | 7 | wget -O example_new_object.jpg https://utexas.box.com/shared/static/2eky71626yqr71mgfg0z1w40hfzfekj9.jpg 8 | 9 | mkdir datasets 10 | mkdir -p datasets/annotations/example_demo 11 | 12 | mv frame.jpg datasets/annotations/example_demo 13 | mv frame_annotation.png datasets/annotations/example_demo 14 | mv example_demo.hdf5 datasets/ 15 | mv example_new_object.jpg datasets/annotations/example_demo 16 | -------------------------------------------------------------------------------- /groot_imitation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/GROOT/3d9eceae26c599c325f9457658ab41d2dd2a1319/groot_imitation/__init__.py -------------------------------------------------------------------------------- /groot_imitation/groot_algo/dataset_preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/GROOT/3d9eceae26c599c325f9457658ab41d2dd2a1319/groot_imitation/groot_algo/dataset_preprocessing/__init__.py -------------------------------------------------------------------------------- /groot_imitation/groot_algo/dataset_preprocessing/vos_annotation.py: -------------------------------------------------------------------------------- 1 | """Get all the masks for the demonstration dataset""" 2 | import h5py 3 | import cv2 4 | import os 5 | import argparse 6 | import numpy as np 7 | from PIL 
import Image 8 | from tqdm import tqdm 9 | 10 | from kaede_utils.visualization_utils.video_utils import KaedeVideoWriter 11 | 12 | import init_path 13 | from groot_imitation.groot_algo.xmem_tracker import XMemTracker 14 | from groot_imitation.groot_algo.misc_utils import get_annotation_path, get_first_frame_annotation, VideoWriter 15 | from groot_imitation.groot_algo.o3d_modules import convert_convention 16 | 17 | # def parse_args(): 18 | # args = argparse.ArgumentParser(description='Get all the masks for the demonstration dataset') 19 | # args.add_argument('--dataset', type=str, default='data/demonstration_dataset.hdf5', help='path to the demonstration dataset') 20 | # # visualize video for debugging 21 | # args.add_argument('--save-video', action='store_true', help='visualize the video') 22 | # # see if it's real data 23 | # args.add_argument('--real', action='store_true', help='real data') 24 | # args.add_argument('--multi-instance', action='store_true', help='multi-instance case') 25 | # args.add_argument('--verbose', action='store_true', help='verbose') 26 | # return args.parse_args() 27 | 28 | def dataset_vos_annotation(cfg, 29 | dataset_name, 30 | mask_dataset_name, 31 | xmem_tracker, 32 | annotation_folder, 33 | save_video=True, 34 | is_real_robot=True, 35 | verbose=False, 36 | ): 37 | """This is the case where we only focus on manipulation one specific-instance.""" 38 | 39 | if save_video: 40 | video_writer = VideoWriter(annotation_folder, "annotation_video.mp4", fps=40.0) 41 | 42 | # first_frame = cv2.imread(os.path.join(annotation_folder, "frame.jpg")) 43 | # first_frame = first_frame[:, :, ::-1] 44 | # first_frame_annotation = np.array(Image.open((os.path.join(annotation_folder, "frame_annotation.png")))) 45 | first_frame, first_frame_annotation = get_first_frame_annotation(annotation_folder) 46 | with h5py.File(dataset_name, 'r') as dataset, h5py.File(mask_dataset_name, 'w') as new_dataset: 47 | 48 | # TODO: Speciify if it's a multi-instance case 49 | count = 0 50 | new_dataset.create_group("data") 51 | for demo in tqdm(dataset["data"].keys()): 52 | xmem_tracker.clear_memory() 53 | if verbose: 54 | print("processing demo: ", demo) 55 | images = dataset[f"data/{demo}/obs/agentview_rgb"][()] 56 | image_list = [first_frame] 57 | for image in images: 58 | image_list.append(convert_convention(image, real_robot=is_real_robot)) 59 | image_list = [cv2.resize(image, (first_frame_annotation.shape[1], first_frame_annotation.shape[0]), interpolation=cv2.INTER_AREA) for image in image_list] 60 | masks = xmem_tracker.track_video(image_list, first_frame_annotation) 61 | 62 | if verbose: 63 | print(len(image_list), len(masks)) 64 | 65 | new_dataset.create_group(f"data/{demo}/obs") 66 | new_dataset[f"data/{demo}/obs"].create_dataset("agentview_masks", data=np.stack(masks[1:], axis=0)) 67 | assert(len(masks[1:]) == len(images)) 68 | if save_video: 69 | overlay_images = [] 70 | for rgb_img, mask in zip(image_list, masks): 71 | colored_mask = Image.fromarray(mask) 72 | colored_mask.putpalette(xmem_tracker.palette) 73 | colored_mask = np.array(colored_mask.convert("RGB")) 74 | overlay_img = cv2.addWeighted(rgb_img, 0.7, colored_mask, 0.3, 0) 75 | 76 | overlay_images.append(overlay_img) 77 | video_writer.append_image(overlay_img) 78 | 79 | if save_video: 80 | video_writer.save(flip=True, bgr=False) 81 | -------------------------------------------------------------------------------- /groot_imitation/groot_algo/eval_utils.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import robomimic.utils.obs_utils as ObsUtils 4 | import robomimic.utils.tensor_utils as TensorUtils 5 | from libero.lifelong.utils import * 6 | 7 | 8 | def raw_obs_to_tensor_obs(obs, task_emb, cfg): 9 | """ 10 | Prepare the tensor observations as input for the algorithm. 11 | """ 12 | 13 | data = { 14 | "obs": {}, 15 | "task_emb": task_emb 16 | } 17 | 18 | all_obs_keys = [] 19 | for modality_name, modality_list in cfg.data.obs.modality.items(): 20 | for obs_name in modality_list: 21 | data["obs"][obs_name] = [] 22 | all_obs_keys += modality_list 23 | 24 | for obs_name in all_obs_keys: 25 | if obs_name in cfg.data.obs_key_mapping: 26 | mapped_name = cfg.data.obs_key_mapping[obs_name] 27 | else: 28 | mapped_name = obs_name 29 | if "neighborhood" in mapped_name or "centers" in mapped_name: 30 | continue 31 | data["obs"][obs_name] = torch.from_numpy(ObsUtils.process_obs( 32 | obs[mapped_name], 33 | obs_key=obs_name)).float().unsqueeze(0) 34 | 35 | data = TensorUtils.map_tensor(data, 36 | lambda x: safe_device(x, device=cfg.device)) 37 | return data 38 | 39 | 40 | def raw_real_obs_to_tensor_obs(obs, task_emb, cfg): 41 | """ 42 | Prepare the tensor observations as input for the algorithm. 43 | """ 44 | 45 | data = { 46 | "obs": {}, 47 | "task_emb": task_emb 48 | } 49 | 50 | all_obs_keys = [] 51 | for modality_name, modality_list in cfg.data.obs.modality.items(): 52 | for obs_name in modality_list: 53 | data["obs"][obs_name] = [] 54 | all_obs_keys += modality_list 55 | 56 | for obs_name in all_obs_keys: 57 | mapped_name = obs_name 58 | if "neighborhood" in mapped_name or "centers" in mapped_name: 59 | continue 60 | data["obs"][obs_name] = torch.from_numpy(ObsUtils.process_obs( 61 | obs[mapped_name], 62 | obs_key=obs_name)).float().unsqueeze(0) 63 | 64 | data = TensorUtils.map_tensor(data, 65 | lambda x: safe_device(x, device=cfg.device)) 66 | return data 67 | -------------------------------------------------------------------------------- /groot_imitation/groot_algo/groot_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .groot_transformer_algo import GROOTSingleTask 2 | from .groot_real_robot_no_wrist_transformer_policy import GROOTRealRobotNoWristTransformerPolicy 3 | -------------------------------------------------------------------------------- /groot_imitation/groot_algo/input_utils.py: -------------------------------------------------------------------------------- 1 | """Process inputs""" 2 | import os 3 | 4 | import numpy as np 5 | from PIL import Image 6 | -------------------------------------------------------------------------------- /groot_imitation/segmentation_correspondence_model/init_path.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | path = os.path.dirname(os.path.realpath(__file__)) 4 | sys.path.insert(0, os.path.join(path, '../')) 5 | 6 | # For XMem 7 | sys.path.append("./third_party/XMem") 8 | sys.path.append("./third_party/XMem/model") 9 | sys.path.append("./third_party/XMem/util") 10 | sys.path.append("./third_party/XMem/inference") 11 | 12 | -------------------------------------------------------------------------------- /groot_imitation/vision_model_configs/sam_config.yaml: -------------------------------------------------------------------------------- 1 | points_per_side: 32 # type: int or null 2 | 
points_per_batch: 64 # type: int 3 | pred_iou_thresh: 0.88 # type: float 4 | stability_score_thresh: 0.95 # type: float 5 | stability_score_offset: 1.0 # type: float 6 | box_nms_thresh: 0.7 # type: float 7 | crop_n_layers: 0 # type: int 8 | crop_nms_thresh: 0.7 # type: float 9 | crop_overlap_ratio: 0.3413333333333333 # (512/1500) type: float 10 | crop_n_points_downscale_factor: 1 # type: int 11 | point_grids: null # type: List[np.ndarray] or null 12 | min_mask_region_area: 0 # type: int 13 | -------------------------------------------------------------------------------- /groot_imitation/vision_model_configs/xmem_config.yaml: -------------------------------------------------------------------------------- 1 | benchmark: False 2 | disable_long_term: False 3 | enable_long_term: True 4 | enable_long_term_count_usage: True 5 | max_mid_term_frames: 10 6 | min_mid_term_frames: 5 7 | max_long_term_elements: 10000 8 | num_prototypes: 128 9 | buffer_size: 100 10 | top_k: 30 11 | mem_every: 5 12 | deep_update_every: -1 13 | save_scores: False 14 | flip: False 15 | size: 480 16 | enable_long_term: True 17 | enable_long_term_count_usage: True 18 | 19 | model: './xmem_checkpoints/XMem.pth' 20 | s2m_model: 'xmem_checkpoints/s2m.pth' 21 | fbrs_model: 'xmem_checkpoins/fbrs.pth' -------------------------------------------------------------------------------- /real_robot_scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/GROOT/3d9eceae26c599c325f9457658ab41d2dd2a1319/real_robot_scripts/__init__.py -------------------------------------------------------------------------------- /real_robot_scripts/deoxys_reset_joints.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from deoxys import config_root 4 | from deoxys.franka_interface import FrankaInterface 5 | from deoxys.utils import YamlConfig 6 | from deoxys.utils.log_utils import get_deoxys_example_logger 7 | 8 | logger = get_deoxys_example_logger() 9 | 10 | 11 | class DeoxysResetJoints: 12 | def __init__(self, 13 | interface_cfg="charmander.yml", controller_cfg="joint-position-controller.yml"): 14 | self.interface_cfg = interface_cfg 15 | self.controller_cfg = controller_cfg 16 | self.controller_type = "JOINT_POSITION" 17 | self.reset_joint_positions = [ 18 | 0.09162008114028396, 19 | -0.19826458111314524, 20 | -0.01990020486871322, 21 | -2.4732269941140346, 22 | -0.01307073642274261, 23 | 2.30396583422025, 24 | 0.8480939705504309, 25 | ] 26 | 27 | def run(self): 28 | robot_interface = FrankaInterface( 29 | config_root + f"/{self.interface_cfg}", use_visualizer=False 30 | ) 31 | controller_cfg = YamlConfig(config_root + f"/{self.controller_cfg}").as_easydict() 32 | 33 | self.reset_joint_positions = [ 34 | e + np.clip(np.random.randn() * 0.005, -0.005, 0.005) 35 | for e in self.reset_joint_positions 36 | ] 37 | action = self.reset_joint_positions + [-1.0] 38 | 39 | while True: 40 | if len(robot_interface._state_buffer) > 0: 41 | logger.info(f"Current Robot joint: {np.round(robot_interface.last_q, 3)}") 42 | logger.info(f"Desired Robot joint: {np.round(robot_interface.last_q_d, 3)}") 43 | 44 | if ( 45 | np.max( 46 | np.abs( 47 | np.array(robot_interface._state_buffer[-1].q) 48 | - np.array(self.reset_joint_positions) 49 | ) 50 | ) 51 | < 1e-3 52 | ): 53 | break 54 | robot_interface.control( 55 | controller_type=self.controller_type, 56 | action=action, 57 | controller_cfg=controller_cfg, 58 | ) 59 
| robot_interface.close() 60 | 61 | 62 | if __name__ == "__main__": 63 | reset_joint = DeoxysResetJoints() 64 | 65 | reset_joint.run() 66 | -------------------------------------------------------------------------------- /real_robot_scripts/evaluation_configs/eval_camera.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: "camera" 2 | new_instance_idx: -1 -------------------------------------------------------------------------------- /real_robot_scripts/evaluation_configs/eval_canonical.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: "canonical" 2 | new_instance_idx: -1 -------------------------------------------------------------------------------- /real_robot_scripts/evaluation_configs/eval_distractions.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: "distractions" 2 | new_instance_idx: -1 -------------------------------------------------------------------------------- /real_robot_scripts/evaluation_configs/eval_new_instances.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: "new_instances" 2 | new_instance_idx: -1 -------------------------------------------------------------------------------- /real_robot_scripts/groot_img_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | class ImageProcessor(): 5 | def __init__(self): 6 | pass 7 | 8 | def get_fx_fy_dict(self, img_size=224): 9 | if img_size == 224: 10 | fx_fy_dict = { 11 | "k4a": { 12 | 0: {"fx": 0.35, "fy": 0.35}, 13 | 1: {"fx": 0.35, "fy": 0.35}, 14 | 2: {"fx": 0.4, "fy": 0.6} 15 | }, 16 | "rs": { 17 | 0: {"fx": 0.49, "fy": 0.49}, 18 | 1: {"fx": 0.49, "fy": 0.49}, 19 | } 20 | } 21 | # elif img_size == 128: 22 | # fx_fy_dict = {0: {"fx": 0.2, "fy": 0.2}, 1: {"fx": 0.2, "fy": 0.2}, 2: {"fx": 0.2, "fy": 0.3}} 23 | # elif img_size == 84: 24 | # fx_fy_dict = {0: {"fx": 0.13, "fy": 0.13}, 1: {"fx": 0.13, "fy": 0.13}, 2: {"fx": 0.15, "fy": 0.225}} 25 | return fx_fy_dict 26 | 27 | def resize_img( 28 | self, 29 | img: np.ndarray, 30 | camera_type: str, 31 | img_w: int=224, 32 | img_h: int=224, 33 | offset_w: int=0, 34 | offset_h: int=0, 35 | fx: float=None, 36 | fy: float=None) -> np.ndarray: 37 | if camera_type == "k4a": 38 | if fx is None: 39 | fx = 0.2 40 | if fy is None: 41 | fy = 0.2 42 | resized_img = cv2.resize(img, (0, 0), fx=fx, fy=fy, interpolation = cv2.INTER_NEAREST) 43 | w = resized_img.shape[0] 44 | h = resized_img.shape[1] 45 | 46 | if camera_type == "rs": 47 | if fx is None: 48 | fx = 0.2 49 | if fy is None: 50 | fy = 0.3 51 | resized_img = cv2.resize(img, (0, 0), fx=fx, fy=fy, interpolation = cv2.INTER_NEAREST) 52 | w = resized_img.shape[0] 53 | h = resized_img.shape[1] 54 | 55 | resized_img = resized_img[w//2-img_w//2:w//2+img_w//2, h//2-img_h//2:h//2+img_h//2, ...] 
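        # The slicing above center-crops the fx/fy-downscaled frame to an img_w x img_h window
        # around the image midpoint; note that `w` and `h` here are the row and column counts of
        # the resized array, so the camera-specific fx/fy values from get_fx_fy_dict must leave a
        # resized frame of at least img_w x img_h pixels for the crop to stay in bounds.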
56 | return resized_img 57 | 58 | def resize_intrinsics( 59 | self, 60 | original_image_size: np.ndarray, 61 | intrinsic_matrix: np.ndarray, 62 | camera_type: str, 63 | img_w: int=224, 64 | img_h: int=224, 65 | fx: float=None, 66 | fy: float=None) -> np.ndarray: 67 | if camera_type == "k4a": 68 | if fx is None: 69 | fx = 0.2 70 | if fy is None: 71 | fy = 0.2 72 | elif camera_type == "rs": 73 | if fx is None: 74 | fx = 0.2 75 | if fy is None: 76 | fy = 0.3 77 | 78 | fake_image = np.zeros((original_image_size[0], original_image_size[1], 3)) 79 | 80 | resized_img = cv2.resize(fake_image, (0, 0), fx=fx, fy=fy) 81 | new_intrinsic_matrix = intrinsic_matrix.copy() 82 | w, h = resized_img.shape[0], resized_img.shape[1] 83 | new_intrinsic_matrix[0, 0] = intrinsic_matrix[0, 0] * fx 84 | new_intrinsic_matrix[1, 1] = intrinsic_matrix[1, 1] * fy 85 | new_intrinsic_matrix[0, 2] = intrinsic_matrix[0, 2] * fx 86 | new_intrinsic_matrix[1, 2] = intrinsic_matrix[1, 2] * fy 87 | new_intrinsic_matrix[0, 2] = new_intrinsic_matrix[0, 2] - (w//2-img_w//2) 88 | new_intrinsic_matrix[1, 2] = new_intrinsic_matrix[1, 2] - (h//2-img_h//2) 89 | return new_intrinsic_matrix 90 | -------------------------------------------------------------------------------- /real_robot_scripts/init_path.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | path = os.path.dirname(os.path.realpath(__file__)) 4 | sys.path.insert(0, os.path.join(path, '../')) 5 | 6 | # For XMem 7 | sys.path.append("./third_party/XMem") 8 | sys.path.append("./third_party/XMem/model") 9 | sys.path.append("./third_party/XMem/util") 10 | sys.path.append("./third_party/XMem/inference") 11 | # sys.path.append("./third_party/XMem/inference/interact") 12 | 13 | # import robosuite.macros as macros 14 | # macros.IMAGE_CONVENTION = "opencv" 15 | -------------------------------------------------------------------------------- /real_robot_scripts/real_robot_observation_cfg_example.yml: -------------------------------------------------------------------------------- 1 | # camera_ids: [0, 2] 2 | 3 | camera_refs : ["rs_0"] 4 | camera_types: 5 | "camera_rs_0": "rs" 6 | "camera_rs_1": "rs" 7 | 8 | img_h: 224 9 | img_w: 224 10 | 11 | camera_name_conversion: 12 | "camera_rs_0": "agentview" 13 | "camera_rs_1": "agentview" 14 | "camera_k4a_0": "agentview" 15 | "camera_k4a_1": "agentview" 16 | 17 | obs_key_mapping: 18 | # camera_0 and camera_1 are workspace cameras 19 | camera_0_color: agentview_rgb 20 | camera_0_depth: agentview_depth 21 | camera_1_color: agentview_rgb 22 | camera_1_depth: agentview_depth 23 | 24 | # camera_2 is the eye-in-hand camera 25 | camera_2_color: eye_in_hand_rgb 26 | camera_2_depth: eye_in_hand_depth 27 | gripper_states: gripper_states 28 | joint_states: joint-states 29 | ee_states: ee_states 30 | xyz: xyz 31 | # This is point cloud rgb, not images 32 | rgb: rgb 33 | neighborhood_10_64: neighborhood_10_64 34 | centers_10_64: centers_10_64 35 | 36 | 37 | datasets: 38 | max_points: 512 39 | num_group: 10 40 | group_size: 64 41 | 42 | # dataset point cloud normalization range 43 | max_array: [0.69943695, 0.5, 0.45784091] 44 | min_array: [0.0, -0.5, 0.0] 45 | 46 | # configuration for augmenting depth point clouds 47 | aug: 48 | workspace_center: [0.6, 0.0, 0.0] 49 | rotations: [30] 50 | 51 | erode: false -------------------------------------------------------------------------------- /real_robot_scripts/verify_camera_observations.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | This is a script to verify the camera observations of the real robot. We assume that you are using deoxys_vision for capturing images. If you are using a different vision pipeline, please modify the code accordingly. 3 | """ 4 | 5 | import plotly.graph_objs as go 6 | 7 | from deoxys import config_root 8 | from deoxys.franka_interface import FrankaInterface 9 | from deoxys.utils import YamlConfig 10 | from deoxys.utils.input_utils import input2action 11 | from deoxys.utils.io_devices import SpaceMouse 12 | from deoxys.utils.log_utils import get_deoxys_example_logger 13 | 14 | from deoxys_vision.networking.camera_redis_interface import \ 15 | CameraRedisSubInterface 16 | from deoxys_vision.utils.calibration_utils import load_default_extrinsics, load_default_intrinsics 17 | from deoxys_vision.utils.camera_utils import assert_camera_ref_convention, get_camera_info 18 | 19 | import init_path 20 | from groot_imitation.groot_algo.dataset_preprocessing.pcd_generation import scene_pcd_fn 21 | 22 | from real_robot_scripts.groot_img_utils import ImageProcessor 23 | from real_robot_scripts.real_robot_utils import RealRobotObsProcessor 24 | 25 | def main(): 26 | # Make sure that you've launched camera nodes somewhere else 27 | observation_cfg = YamlConfig("real_robot_scripts/real_robot_observation_cfg_example.yml").as_easydict() 28 | 29 | observation_cfg.cameras = [] 30 | for camera_ref in observation_cfg.camera_refs: 31 | assert_camera_ref_convention(camera_ref) 32 | camera_info = get_camera_info(camera_ref) 33 | 34 | observation_cfg.cameras.append(camera_info) 35 | 36 | # cr_interfaces = {} 37 | # for camera_info in observation_cfg.cameras: 38 | # cr_interface = CameraRedisSubInterface(camera_info=camera_info) 39 | # cr_interface.start() 40 | # cr_interfaces[camera_info.camera_name] = cr_interface 41 | 42 | # # type_fn = lambda x: observation_cfg.camera_types[f"camera_{x}"] 43 | 44 | # color_images = [] 45 | # depth_images = [] 46 | # for camera_name in cr_interfaces.keys(): 47 | # camera_type = observation_cfg.cameras[camera_name].camera_type 48 | # camera_id = observation_cfg.cameras[camera_name].camera_id 49 | # extrinsics = load_default_extrinsics(camera_id, camera_type, calibration_method="tsai", fmt="dict") 50 | # intrinsics = load_default_intrinsics(camera_id, camera_type, fmt="dict") 51 | 52 | # imgs = cr_interfaces[camera_id].get_img() 53 | # img_info = cr_interfaces[camera_id].get_img_info() 54 | # color_images.append(imgs['color']) 55 | # depth_images.append(imgs['depth']) 56 | 57 | 58 | obs_processor = RealRobotObsProcessor(observation_cfg, 59 | processor_name="ImageProcessor") 60 | obs_processor.load_intrinsic_matrix(resize=False) 61 | obs_processor.load_extrinsic_matrix() 62 | extrinsic_matrix = obs_processor.get_extrinsic_matrix("agentview") 63 | intrinsic_matrix = obs_processor.get_intrinsic_matrix("agentview") 64 | 65 | color_imgs, depth_imgs = obs_processor.get_original_imgs() 66 | print(color_imgs[0].shape) 67 | 68 | pcd_points, pcd_colors = scene_pcd_fn( 69 | observation_cfg, 70 | rgb_img_input=color_imgs[0], 71 | depth_img_input=depth_imgs[0], 72 | extrinsic_matrix=extrinsic_matrix, 73 | intrinsic_matrix=intrinsic_matrix, 74 | ) 75 | 76 | # visualize point clouds using plotly 77 | color_str = ['rgb('+str(r)+','+str(g)+','+str(b)+')' for r,g,b in pcd_colors] 78 | 79 | # Extract x, y, and z columns from the point cloud 80 | x_vals = pcd_points[:, 0] 81 | y_vals = pcd_points[:, 1] 82 | 
z_vals = pcd_points[:, 2] 83 | 84 | # Create the scatter3d plot 85 | rgbd_scatter = go.Scatter3d( 86 | x=x_vals, 87 | y=y_vals, 88 | z=z_vals, 89 | mode='markers', 90 | marker=dict(size=3, color=color_str, opacity=0.8) 91 | ) 92 | 93 | # Set the layout for the plot 94 | layout = go.Layout( 95 | margin=dict(l=0, r=0, b=0, t=0) 96 | ) 97 | 98 | fig = go.Figure(data=[rgbd_scatter], layout=layout) 99 | 100 | # Show the figure 101 | fig.show() 102 | 103 | # Get the camera info 104 | 105 | 106 | 107 | 108 | 109 | 110 | if __name__ == "__main__": 111 | main() 112 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | hydra-core==1.2.0 2 | numpy==1.22.4 3 | wandb==0.13.1 4 | easydict==1.9 5 | transformers==4.21.1 6 | opencv-python==4.6.0.66 7 | robomimic==0.2.0 8 | einops==0.4.1 9 | thop==0.1.1-2209072238 10 | robosuite==1.4.0 11 | bddl==1.0.1 12 | future==0.18.2 13 | matplotlib==3.5.3 14 | cloudpickle==2.1.0 15 | gym==0.25.2 16 | open3d 17 | ninja 18 | -------------------------------------------------------------------------------- /scripts/dataset_configs/config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | 6 | dataset_path: null 7 | save_video: true 8 | 9 | vos_annotation: true 10 | object_pcd: true 11 | pcd_aug: true 12 | pcd_grouping: true 13 | delete_intermediate_files: true 14 | 15 | datasets: 16 | max_points: 512 17 | num_group: 10 18 | group_size: 64 19 | 20 | # dataset point cloud normalization range 21 | max_array: [0.69943695, 0.5, 0.45784091] 22 | min_array: [0.0, -0.5, 0.0] 23 | 24 | # configuration for augmenting depth point clouds 25 | aug: 26 | workspace_center: [0.6, 0.0, 0.0] 27 | rotations: [30] 28 | 29 | erode: false -------------------------------------------------------------------------------- /scripts/dataset_configs/datasets/default.yaml: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /scripts/init_path.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | path = os.path.dirname(os.path.realpath(__file__)) 4 | sys.path.insert(0, os.path.join(path, '../')) 5 | 6 | # For XMem 7 | sys.path.append("./third_party/XMem") 8 | sys.path.append("./third_party/XMem/model") 9 | sys.path.append("./third_party/XMem/util") 10 | sys.path.append("./third_party/XMem/inference") 11 | # sys.path.append("./third_party/XMem/inference/interact") 12 | 13 | # import robosuite.macros as macros 14 | # macros.IMAGE_CONVENTION = "opencv" 15 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # read the contents of your README file 2 | from os import path 3 | 4 | from setuptools import find_packages, setup 5 | 6 | this_directory = path.abspath(path.dirname(__file__)) 7 | with open(path.join(this_directory, "./README.md"), encoding="utf-8") as f: 8 | lines = f.readlines() 9 | 10 | # remove images from README 11 | lines = [x for x in lines if ".png" not in x] 12 | long_description = "".join(lines) 13 | 14 | setup( 15 | name="groot", 16 | packages=[package for package in find_packages() if package.startswith("groot_imitation")], 17 | install_requires=[], 18 | 
eager_resources=["*"], 19 | include_package_data=True, 20 | python_requires=">=3", 21 | description="GROOT at UT-Austin RPL", 22 | author="Yifeng Zhu", 23 | # url="https://github.com/ARISE-Initiative/robosuite", 24 | author_email="yifengz@cs.utexas.edu", 25 | version="0.1.0", 26 | long_description=long_description, 27 | long_description_content_type="text/markdown", 28 | ) -------------------------------------------------------------------------------- /setup_vision_models.sh: -------------------------------------------------------------------------------- 1 | 2 | cd third_party 3 | 4 | # git clone git@github.com:hkchengrex/XMem.git 5 | 6 | cd XMem 7 | 8 | # inside third_party/XMem/ 9 | bash scripts/download_models_demo.sh 10 | 11 | # back to third_party/ 12 | cd .. 13 | 14 | mkdir xmem_checkpoints 15 | 16 | mv XMem/saves/*pth xmem_checkpoints/ 17 | 18 | # dinov2 19 | git clone git@github.com:facebookresearch/dinov2.git 20 | cd dinov2 21 | git checkout fc49f49d734c767272a4ea0e18ff2ab8e60fc92d 22 | pip install -r requirements.txt 23 | pip install -e . 24 | 25 | cd .. 26 | 27 | # SAM 28 | git clone https://github.com/facebookresearch/segment-anything 29 | cd segment-anything 30 | git checkout 6fdee8f2727f4506cfbbe553e23b895e27956588 31 | pip install -e . 32 | 33 | cd .. 34 | mkdir sam_checkpoints 35 | cd sam_checkpoints 36 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth 37 | 38 | -------------------------------------------------------------------------------- /third_party/XMem/.gitignore: -------------------------------------------------------------------------------- 1 | log/ 2 | saves 3 | saves/ 4 | output/ 5 | .vscode/ 6 | workspace/ 7 | run*.sh 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | pip-wheel-metadata/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 
100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 103 | __pypackages__/ 104 | 105 | # Celery stuff 106 | celerybeat-schedule 107 | celerybeat.pid 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | 136 | # Pyre type checker 137 | .pyre/ 138 | -------------------------------------------------------------------------------- /third_party/XMem/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/GROOT/3d9eceae26c599c325f9457658ab41d2dd2a1319/third_party/XMem/dataset/__init__.py -------------------------------------------------------------------------------- /third_party/XMem/dataset/range_transform.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as transforms 2 | 3 | im_mean = (124, 116, 104) 4 | 5 | im_normalization = transforms.Normalize( 6 | mean=[0.485, 0.456, 0.406], 7 | std=[0.229, 0.224, 0.225] 8 | ) 9 | 10 | inv_im_trans = transforms.Normalize( 11 | mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], 12 | std=[1/0.229, 1/0.224, 1/0.225]) 13 | -------------------------------------------------------------------------------- /third_party/XMem/dataset/reseed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | 4 | def reseed(seed): 5 | random.seed(seed) 6 | torch.manual_seed(seed) -------------------------------------------------------------------------------- /third_party/XMem/dataset/tps.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import cv2 4 | import thinplate as tps 5 | 6 | cv2.setNumThreads(0) 7 | 8 | def pick_random_points(h, w, n_samples): 9 | y_idx = np.random.choice(np.arange(h), size=n_samples, replace=False) 10 | x_idx = np.random.choice(np.arange(w), size=n_samples, replace=False) 11 | return y_idx/h, x_idx/w 12 | 13 | 14 | def warp_dual_cv(img, mask, c_src, c_dst): 15 | dshape = img.shape 16 | theta = tps.tps_theta_from_points(c_src, c_dst, reduced=True) 17 | grid = tps.tps_grid(theta, c_dst, dshape) 18 | mapx, mapy = tps.tps_grid_to_remap(grid, img.shape) 19 | return cv2.remap(img, mapx, mapy, cv2.INTER_LINEAR), cv2.remap(mask, mapx, mapy, cv2.INTER_NEAREST) 20 | 21 | 22 | def random_tps_warp(img, mask, scale, n_ctrl_pts=12): 23 | """ 24 | Apply a random TPS warp of the input image and mask 25 | Uses randomness from numpy 26 | """ 27 | img = np.asarray(img) 28 | mask = np.asarray(mask) 29 | 30 | h, w = mask.shape 31 | points = pick_random_points(h, w, n_ctrl_pts) 32 | c_src = np.stack(points, 1) 33 | c_dst = c_src + np.random.normal(scale=scale, size=c_src.shape) 34 | warp_im, warp_gt = warp_dual_cv(img, mask, c_src, c_dst) 35 | 36 | return Image.fromarray(warp_im), Image.fromarray(warp_gt) 37 | 38 | -------------------------------------------------------------------------------- /third_party/XMem/dataset/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def 
all_to_onehot(masks, labels): 5 | if len(masks.shape) == 3: 6 | Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8) 7 | else: 8 | Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]), dtype=np.uint8) 9 | 10 | for ni, l in enumerate(labels): 11 | Ms[ni] = (masks == l).astype(np.uint8) 12 | 13 | return Ms 14 | -------------------------------------------------------------------------------- /third_party/XMem/docs/ECCV-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/GROOT/3d9eceae26c599c325f9457658ab41d2dd2a1319/third_party/XMem/docs/ECCV-logo.png -------------------------------------------------------------------------------- /third_party/XMem/docs/FAILURE_CASES.md: -------------------------------------------------------------------------------- 1 | # Failure Cases 2 | 3 | Like all methods, XMem can fail. Here, we try to show some illustrative and frankly consistent failure modes that we noticed. We slowed down all videos for visualization. 4 | 5 | ## Fast motion, similar objects 6 | 7 | The first one is fast motion with similarly-looking objects that do not provide sufficient appearance clues for XMem to track. Below is an example from the YouTubeVOS validation set (0e8a6b63bb): 8 | 9 | https://user-images.githubusercontent.com/7107196/179459162-80b65a6c-439d-4239-819f-68804d9412e9.mp4 10 | 11 | And the source video: 12 | 13 | https://user-images.githubusercontent.com/7107196/181700094-356284bc-e8a4-4757-ab84-1e9009fddd4b.mp4 14 | 15 | Technically it can be solved by using more positional and motion clues. XMem is not sufficiently proficient at those. 16 | 17 | ## Shot changes; saliency shift 18 | 19 | Ever wondered why I did not include the final scene of Chika Dance when the roach flies off? Because it failed there. 20 | 21 | XMem seems to be attracted to any new salient object in the scene when the (true) target object is missing. By new I mean an object that did not appear (or had a different appearance) earlier in the video -- as XMem could not have a memory representation for that object. This happens a lot if the camera shot changes. 22 | 23 | https://user-images.githubusercontent.com/7107196/179459190-d736937a-6925-4472-b46e-dcf94e1cafc0.mp4 24 | 25 | Note that the first shot change is not as problematic. 26 | -------------------------------------------------------------------------------- /third_party/XMem/docs/GETTING_STARTED.md: -------------------------------------------------------------------------------- 1 | # Getting Started 2 | 3 | Our code is tested on Ubuntu. I have briefly tested the GUI on Windows (with a PyQt5 fix in the heading of interactive_demo.py). 4 | 5 | ## Requirements 6 | 7 | * Python 3.8+ 8 | * PyTorch 1.11+ (See [PyTorch](https://pytorch.org/) for installation instructions) 9 | * `torchvision` corresponding to the PyTorch version 10 | * OpenCV (try `pip install opencv-python`) 11 | * Others: `pip install -r requirements.txt` 12 | 13 | ## Dataset 14 | 15 | I recommend either softlinking (`ln -s`) existing data or use the provided `scripts/download_datasets.py` to structure the datasets as our format. 16 | 17 | `python -m scripts.download_dataset` 18 | 19 | The structure is the same as the one in STCN -- you can place XMem in the same folder as STCN and it will work. 20 | The script uses Google Drive and sometimes fails when certain files are blocked from automatic download. 
You would have to do some manual work in that case. 21 | It does not download BL30K because it is huge and we don't want to crash your harddisks. 22 | 23 | ```bash 24 | ├── XMem 25 | ├── BL30K 26 | ├── DAVIS 27 | │ ├── 2016 28 | │ │ ├── Annotations 29 | │ │ └── ... 30 | │ └── 2017 31 | │ ├── test-dev 32 | │ │ ├── Annotations 33 | │ │ └── ... 34 | │ └── trainval 35 | │ ├── Annotations 36 | │ └── ... 37 | ├── static 38 | │ ├── BIG_small 39 | │ └── ... 40 | ├── long_video_set 41 | │ ├── long_video 42 | │ ├── long_video_x3 43 | │ ├── long_video_davis 44 | │ └── ... 45 | ├── YouTube 46 | │ ├── all_frames 47 | │ │ └── valid_all_frames 48 | │ ├── train 49 | │ ├── train_480p 50 | │ └── valid 51 | └── YouTube2018 52 | ├── all_frames 53 | │ └── valid_all_frames 54 | └── valid 55 | ``` 56 | 57 | ## Long-Time Video 58 | 59 | It comes from [AFB-URR](https://github.com/xmlyqing00/AFB-URR). Please following their license when using this data. We release our extended version (X3) and corresponding `_davis` versions such that the DAVIS evaluation can be used directly. They can be downloaded [[here]](TODO). The script above would also attempt to download it. 60 | 61 | ### BL30K 62 | 63 | You can either use the automatic script `download_bl30k.py` or download it manually from [MiVOS](https://github.com/hkchengrex/MiVOS/#bl30k). Note that each segment is about 115GB in size -- 700GB in total. You are going to need ~1TB of free disk space to run the script (including extraction buffer). 64 | The script uses Google Drive and sometimes fails when certain files are blocked from automatic download. You would have to do some manual work in that case. 65 | -------------------------------------------------------------------------------- /third_party/XMem/docs/PALETTE.md: -------------------------------------------------------------------------------- 1 | # Palette 2 | 3 | > Some image formats, such as GIF or PNG, can use a palette, which is a table of (usually) 256 colors to allow for better compression. Basically, instead of representing each pixel with its full color triplet, which takes 24bits (plus eventual 8 more for transparency), they use a 8 bit index that represent the position inside the palette, and thus the color. 4 | -- https://docs.geoserver.org/2.22.x/en/user/tutorials/palettedimage/palettedimage.html 5 | 6 | So those mask files that look like color images are single-channel, `uint8` arrays under the hood. When `PIL` reads them, it (correctly) gives you a two-dimensional array (`opencv` does not work AFAIK). If what you get is instead of three-dimensional, `H*W*3` array, then your mask is not actually a paletted mask, but just a colored image. Reading and saving a paletted mask through `opencv` or MS Paint would destroy the palette. 7 | 8 | Our code, when asked to generate multi-object segmentation (e.g., DAVIS 2017/YouTubeVOS), always reads and writes single-channel mask. If there is a palette in the input, we will use it in the output. The code does not care whether a palette is actually used -- we can read grayscale images just fine. 9 | 10 | Importantly, we use `np.unique` to determine the number of objects in the mask. This would fail if: 11 | 12 | 1. Colored images, instead of paletted masks are used. 13 | 2. The masks have "smooth" edges, produced by feathering/downsizing/compression. For example, when you draw the mask in a painting software, make sure you set the brush hardness to maximum. 
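
A quick way to check whether a mask file is actually paletted is to open it with `PIL` and inspect the array shape and the unique label values. The sketch below is a minimal check along those lines; the file names are illustrative:

```python
import numpy as np
from PIL import Image

mask_img = Image.open("frame_annotation.png")  # a paletted PNG opens in mode 'P'
palette = mask_img.getpalette()                # flat [r, g, b, ...] list, or None if not paletted

mask = np.array(mask_img)
print(mask.shape)        # (H, W) for a proper paletted mask; (H, W, 3) means a colored image
print(np.unique(mask))   # label indices; non-zero values are treated as objects

# Save a copy while keeping the palette intact
out = Image.fromarray(mask)
if palette is not None:
    out.putpalette(palette)
out.save("frame_annotation_copy.png")
```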
14 | -------------------------------------------------------------------------------- /third_party/XMem/docs/RESULTS.md: -------------------------------------------------------------------------------- 1 | # Results 2 | 3 | ## Preamble 4 | 5 | Our code, by default, uses automatic mixed precision (AMP). Its effect on the output is negligible. 6 | All speeds reported in the paper are recorded with AMP turned off (`--benchmark`). 7 | Due to refactoring, there might be slight differences between the outputs produced by this code base with the precomputed results/results reported in the paper. This difference rarely leads to a change of the least significant figure (i.e., 0.1). 8 | 9 | **For most complete results, please see the paper (and the appendix)!** 10 | 11 | All available precomputed results can be found [[here]](https://drive.google.com/drive/folders/1UxHPXJbQLHjF5zYVn3XZCXfi_NYL81Bf?usp=sharing). 12 | 13 | ## Pretrained models 14 | 15 | We provide four pretrained models for download: 16 | 17 | 1. XMem.pth (Default) 18 | 2. XMem-s012.pth (Trained with BL30K) 19 | 3. XMem-s2.pth (No pretraining on static images) 20 | 4. XMem-no-sensory (No sensory memory) 21 | 22 | The model without pretraining is for reference. The model without sensory memory might be more suitable for tasks without spatial continuity, like mask tracking in a multi-camera 3D reconstruction setting, though I would encourage you to try the base model as well. 23 | 24 | Download them from [[GitHub]](https://github.com/hkchengrex/XMem/releases/tag/v1.0) or [[Google Drive]](https://drive.google.com/drive/folders/1QYsog7zNzcxGXTGBzEhMUg8QVJwZB6D1?usp=sharing). 25 | 26 | ## Long-Time Video 27 | 28 | [[Precomputed Results]](https://drive.google.com/drive/folders/1NADcetigH6d83mUvyb2rH4VVjwFA76Lh?usp=sharing) 29 | 30 | ### Long-Time Video (1X) 31 | 32 | | Model | J&F | J | F | 33 | | --- | :--:|:--:|:---:| 34 | | XMem | 89.8±0.2 | 88.0±0.2 | 91.6±0.2 | 35 | 36 | ### Long-Time Video (3X) 37 | 38 | | Model | J&F | J | F | 39 | | --- | :--:|:--:|:---:| 40 | | XMem | 90.0±0.4 | 88.2±0.3 | 91.8±0.4 | 41 | 42 | ## DAVIS 43 | 44 | [[Precomputed Results]](https://drive.google.com/drive/folders/1XTOGevTedRSjHnFVsZyTdxJG-iHjO0Re?usp=sharing) 45 | 46 | ### DAVIS 2016 47 | 48 | | Model | J&F | J | F | FPS | FPS (AMP) | 49 | | --- | :--:|:--:|:---:|:---:|:---:| 50 | | XMem | 91.5 | 90.4 | 92.7 | 29.6 | 40.3 | 51 | | XMem-s012 | 92.0 | 90.7 | 93.2 | 29.6 | 40.3 | 52 | | XMem-s2 | 90.8 | 89.6 | 91.9 | 29.6 | 40.3 | 53 | 54 | ### DAVIS 2017 validation 55 | 56 | | Model | J&F | J | F | FPS | FPS (AMP) | 57 | | --- | :--:|:--:|:---:|:---:|:---:| 58 | | XMem | 86.2 | 82.9 | 89.5 | 22.6 | 33.9 | 59 | | XMem-s012 | 87.7 | 84.0 | 91.4 | 22.6 | 33.9 | 60 | | XMem-s2 | 84.5 | 81.4 | 87.6 | 22.6 | 33.9 | 61 | | XMem-no-sensory | 85.1 | - | - | 23.1 | - | 62 | 63 | ### DAVIS 2017 test-dev 64 | 65 | | Model | J&F | J | F | 66 | | --- | :--:|:--:|:---:| 67 | | XMem | 81.0 | 77.4 | 84.5 | 68 | | XMem-s012 | 81.2 | 77.6 | 84.7 | 69 | | XMem-s2 | 79.8 | 61.4 | 68.1 | 70 | | XMem-s012 (600p) | 82.5 | 79.1 | 85.8 | 71 | 72 | ## YouTubeVOS 73 | 74 | We use all available frames in YouTubeVOS by default. 75 | See [INFERENCE.md](./INFERENCE.md) if you want to evaluate with sparse frames for some reason. 
76 | 77 | [[Precomputed Results]](https://drive.google.com/drive/folders/1P_BmOdcG6OP5mWGqWzCZrhQJ7AaLME4E?usp=sharing) 78 | 79 | [[Precomputed Results (sparse)]](https://drive.google.com/drive/folders/1IRV1fHepufUXM45EEbtl9D4pkoh9POSZ?usp=sharing) 80 | 81 | ### YouTubeVOS 2018 validation 82 | 83 | | Model | G | J-Seen | F-Seen | J-Unseen | F-Unseen | FPS | FPS (AMP) | 84 | | --- | :--:|:--:|:---:|:---:|:---:|:---:|:---:| 85 | | XMem | 85.7 | 84.6 | 89.3 | 80.2 | 88.7 | 22.6 | 31.7 | 86 | | XMem-s012 | 86.1 | 85.1 | 89.8 | 80.3 | 89.2 | 22.6 | 31.7 | 87 | | XMem-s2 | 84.3 | 83.9 | 88.8 | 77.7 | 86.7 | 22.6 | 31.7 | 88 | | XMem-no-sensory | 84.4 | - | - | - | - | 23.1 | - | 89 | 90 | ### YouTubeVOS 2019 validation 91 | 92 | | Model | G | J-Seen | F-Seen | J-Unseen | F-Unseen | 93 | | --- | :--:|:--:|:---:|:---:|:---:| 94 | | XMem | 85.5 | 84.3 | 88.6 | 80.3 | 88.6 | 95 | | XMem-s012 | 85.8 | 84.8 | 89.2 | 80.3 | 88.8 | 96 | | XMem-s2 | 84.2 | 83.8 | 88.3 | 78.1 | 86.7 | 97 | 98 | ## Multi-scale evaluation 99 | 100 | Please see the appendix for quantitative results. 101 | 102 | [[DAVIS-MS Precomputed Results]](https://drive.google.com/drive/folders/1H3VHKDO09izp6KR3sE-LzWbjyM-jpftn?usp=sharing) 103 | 104 | [[YouTubeVOS-MS Precomputed Results]](https://drive.google.com/drive/folders/1ww5HVRbMKXraLd2dy1rtk6kLjEawW9Kn?usp=sharing) 105 | -------------------------------------------------------------------------------- /third_party/XMem/docs/TRAINING.md: -------------------------------------------------------------------------------- 1 | # Training 2 | 3 | First, set up the datasets following [GETTING STARTED.md](./GETTING_STARTED.md). 4 | 5 | The model is trained progressively with different stages (0: static images; 1: BL30K; 2: longer main training; 3: shorter main training). After each stage finishes, we start the next stage by loading the latest trained weight. 6 | For example, the base model is pretrained with static images followed by the shorter main training (s03). 7 | 8 | To train the base model on two GPUs, you can use: 9 | 10 | ```bash 11 | python -m torch.distributed.run --master_port 25763 --nproc_per_node=2 train.py --exp_id retrain --stage 03 12 | ``` 13 | (**NOTE**: Unexplained accuracy decrease might occur if you are not using two GPUs to train. See https://github.com/hkchengrex/XMem/issues/71.) 14 | 15 | `master_port` needs to point to an unused port. 16 | `nproc_per_node` refers to the number of GPUs to be used (specify `CUDA_VISIBLE_DEVICES` to select which GPUs to use). 17 | `exp_id` is an identifier you give to this training job. 18 | 19 | See other available command line arguments in `util/configuration.py`. 20 | **Unlike the training code of STCN, batch sizes are effective. You don't have to adjust the batch size when you use more/fewer GPUs.** 21 | 22 | We implemented automatic staging in this code base. You don't have to train different stages by yourself like in STCN (but that is still supported). 23 | `stage` is a string that we split to determine the training stages. Examples include `0` (static images only), `03` (base training), `012` (with BL30K), `2` (main training only). 24 | 25 | You can use `tensorboard` to visualize the training process. 26 | 27 | ## Outputs 28 | 29 | The model files and checkpoints will be saved in `./saves/[name containing datetime and exp_id]`. 30 | 31 | `.pth` files with `_checkpoint` store the network weights, optimizer states, etc. and can be used to resume training (with `--load_checkpoint`). 
32 | 33 | Other `.pth` files store the network weights only and can be used for inference. We note that there are variations in performance across different training runs and across the last few saved models. For the base model, we most often note that main training at 107K iterations leads to the best result (full training is 110K). 34 | 35 | We measure the median and std scores across five training runs of the base model: 36 | 37 | | Dataset | median | std | 38 | | --- | :--:|:--:| 39 | | DAVIS J&F | 86.2 | 0.23 | 40 | | YouTubeVOS 2018 G | 85.6 | 0.21 41 | 42 | ## Pretrained models 43 | 44 | You can start training from scratch, or use any of our pretrained models for fine-tuning. For example, you can load our stage 0 model to skip main training: 45 | 46 | ```bash 47 | python -m torch.distributed.launch --master_port 25763 --nproc_per_node=2 train.py --exp_id retrain_stage3_only --stage 3 --load_network saves/XMem-s0.pth 48 | ``` 49 | 50 | Download them from [[GitHub]](https://github.com/hkchengrex/XMem/releases/tag/v1.0) or [[Google Drive]](https://drive.google.com/drive/folders/1QYsog7zNzcxGXTGBzEhMUg8QVJwZB6D1?usp=sharing). 51 | -------------------------------------------------------------------------------- /third_party/XMem/docs/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/GROOT/3d9eceae26c599c325f9457658ab41d2dd2a1319/third_party/XMem/docs/icon.png -------------------------------------------------------------------------------- /third_party/XMem/docs/style.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Roboto', sans-serif; 3 | font-size:18px; 4 | margin-left: auto; 5 | margin-right: auto; 6 | font-weight: 300; 7 | height: 100%; 8 | max-width: 1000px; 9 | } 10 | 11 | .light { 12 | font-weight: 100; 13 | } 14 | 15 | .heavy { 16 | font-weight: 400; 17 | } 18 | 19 | .column { 20 | float: left; 21 | } 22 | 23 | .metric_table { 24 | border-collapse: collapse; 25 | margin-left: 15px; 26 | margin-right: auto; 27 | } 28 | 29 | .metric_table th{ 30 | border-bottom: 1px solid #555; 31 | padding-left: 15px; 32 | padding-right: 15px; 33 | } 34 | 35 | .metric_table td{ 36 | padding-left: 15px; 37 | padding-right: 15px; 38 | } 39 | 40 | .metric_table .left_align{ 41 | text-align: left; 42 | } 43 | 44 | a:link,a:visited 45 | { 46 | color: #05538f; 47 | text-decoration: none; 48 | } 49 | 50 | a:hover { 51 | color: #63cbdd; 52 | } 53 | 54 | hr 55 | { 56 | border: 0; 57 | height: 1px; 58 | background-image: linear-gradient(to right, rgba(0, 0, 0, 0), rgba(0, 0, 0, 0.75), rgba(0, 0, 0, 0)); 59 | } 60 | -------------------------------------------------------------------------------- /third_party/XMem/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/GROOT/3d9eceae26c599c325f9457658ab41d2dd2a1319/third_party/XMem/inference/__init__.py -------------------------------------------------------------------------------- /third_party/XMem/inference/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/GROOT/3d9eceae26c599c325f9457658ab41d2dd2a1319/third_party/XMem/inference/data/__init__.py -------------------------------------------------------------------------------- /third_party/XMem/inference/data/mask_mapper.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from dataset.util import all_to_onehot 5 | 6 | 7 | class MaskMapper: 8 | """ 9 | This class is used to convert a indexed-mask to a one-hot representation. 10 | It also takes care of remapping non-continuous indices 11 | It has two modes: 12 | 1. Default. Only masks with new indices are supposed to go into the remapper. 13 | This is also the case for YouTubeVOS. 14 | i.e., regions with index 0 are not "background", but "don't care". 15 | 16 | 2. Exhaustive. Regions with index 0 are considered "background". 17 | Every single pixel is considered to be "labeled". 18 | """ 19 | def __init__(self): 20 | self.labels = [] 21 | self.remappings = {} 22 | 23 | # if coherent, no mapping is required 24 | self.coherent = True 25 | 26 | def convert_mask(self, mask, exhaustive=False): 27 | # mask is in index representation, H*W numpy array 28 | labels = np.unique(mask).astype(np.uint8) 29 | labels = labels[labels!=0].tolist() 30 | 31 | new_labels = list(set(labels) - set(self.labels)) 32 | if not exhaustive: 33 | assert len(new_labels) == len(labels), 'Old labels found in non-exhaustive mode' 34 | 35 | # add new remappings 36 | for i, l in enumerate(new_labels): 37 | self.remappings[l] = i+len(self.labels)+1 38 | if self.coherent and i+len(self.labels)+1 != l: 39 | self.coherent = False 40 | 41 | if exhaustive: 42 | new_mapped_labels = range(1, len(self.labels)+len(new_labels)+1) 43 | else: 44 | if self.coherent: 45 | new_mapped_labels = new_labels 46 | else: 47 | new_mapped_labels = range(len(self.labels)+1, len(self.labels)+len(new_labels)+1) 48 | 49 | self.labels.extend(new_labels) 50 | mask = torch.from_numpy(all_to_onehot(mask, self.labels)).float() 51 | 52 | # mask num_objects*H*W 53 | return mask, new_mapped_labels 54 | 55 | 56 | def remap_index_mask(self, mask): 57 | # mask is in index representation, H*W numpy array 58 | if self.coherent: 59 | return mask 60 | 61 | new_mask = np.zeros_like(mask) 62 | for l, i in self.remappings.items(): 63 | new_mask[mask==i] = l 64 | return new_mask 65 | 66 | def clear_labels(self): 67 | self.labels = [] 68 | self.remappings = {} 69 | -------------------------------------------------------------------------------- /third_party/XMem/inference/data/test_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | import json 4 | 5 | from inference.data.video_reader import VideoReader 6 | 7 | 8 | class LongTestDataset: 9 | def __init__(self, data_root, size=-1): 10 | self.image_dir = path.join(data_root, 'JPEGImages') 11 | self.mask_dir = path.join(data_root, 'Annotations') 12 | self.size = size 13 | 14 | self.vid_list = sorted(os.listdir(self.image_dir)) 15 | 16 | def get_datasets(self): 17 | for video in self.vid_list: 18 | yield VideoReader(video, 19 | path.join(self.image_dir, video), 20 | path.join(self.mask_dir, video), 21 | to_save = [ 22 | name[:-4] for name in os.listdir(path.join(self.mask_dir, video)) 23 | ], 24 | size=self.size, 25 | ) 26 | 27 | def __len__(self): 28 | return len(self.vid_list) 29 | 30 | 31 | class DAVISTestDataset: 32 | def __init__(self, data_root, imset='2017/val.txt', size=-1): 33 | if size != 480: 34 | self.image_dir = path.join(data_root, 'JPEGImages', 'Full-Resolution') 35 | self.mask_dir = path.join(data_root, 'Annotations', 'Full-Resolution') 36 | if not path.exists(self.image_dir): 37 | print(f'{self.image_dir} not found. 
Look at other options.') 38 | self.image_dir = path.join(data_root, 'JPEGImages', '1080p') 39 | self.mask_dir = path.join(data_root, 'Annotations', '1080p') 40 | assert path.exists(self.image_dir), 'path not found' 41 | else: 42 | self.image_dir = path.join(data_root, 'JPEGImages', '480p') 43 | self.mask_dir = path.join(data_root, 'Annotations', '480p') 44 | self.size_dir = path.join(data_root, 'JPEGImages', '480p') 45 | self.size = size 46 | 47 | with open(path.join(data_root, 'ImageSets', imset)) as f: 48 | self.vid_list = sorted([line.strip() for line in f]) 49 | 50 | def get_datasets(self): 51 | for video in self.vid_list: 52 | yield VideoReader(video, 53 | path.join(self.image_dir, video), 54 | path.join(self.mask_dir, video), 55 | size=self.size, 56 | size_dir=path.join(self.size_dir, video), 57 | ) 58 | 59 | def __len__(self): 60 | return len(self.vid_list) 61 | 62 | 63 | class YouTubeVOSTestDataset: 64 | def __init__(self, data_root, split, size=480): 65 | self.image_dir = path.join(data_root, 'all_frames', split+'_all_frames', 'JPEGImages') 66 | self.mask_dir = path.join(data_root, split, 'Annotations') 67 | self.size = size 68 | 69 | self.vid_list = sorted(os.listdir(self.image_dir)) 70 | self.req_frame_list = {} 71 | 72 | with open(path.join(data_root, split, 'meta.json')) as f: 73 | # read meta.json to know which frame is required for evaluation 74 | meta = json.load(f)['videos'] 75 | 76 | for vid in self.vid_list: 77 | req_frames = [] 78 | objects = meta[vid]['objects'] 79 | for value in objects.values(): 80 | req_frames.extend(value['frames']) 81 | 82 | req_frames = list(set(req_frames)) 83 | self.req_frame_list[vid] = req_frames 84 | 85 | def get_datasets(self): 86 | for video in self.vid_list: 87 | yield VideoReader(video, 88 | path.join(self.image_dir, video), 89 | path.join(self.mask_dir, video), 90 | size=self.size, 91 | to_save=self.req_frame_list[video], 92 | use_all_mask=True 93 | ) 94 | 95 | def __len__(self): 96 | return len(self.vid_list) 97 | -------------------------------------------------------------------------------- /third_party/XMem/inference/data/video_reader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | 4 | from torch.utils.data.dataset import Dataset 5 | from torchvision import transforms 6 | from torchvision.transforms import InterpolationMode 7 | import torch.nn.functional as F 8 | from PIL import Image 9 | import numpy as np 10 | 11 | from dataset.range_transform import im_normalization 12 | 13 | 14 | class VideoReader(Dataset): 15 | """ 16 | This class is used to read a video, one frame at a time 17 | """ 18 | def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all_mask=False, size_dir=None): 19 | """ 20 | image_dir - points to a directory of jpg images 21 | mask_dir - points to a directory of png masks 22 | size - resize min. side to size. Does nothing if <0. 23 | to_save - optionally contains a list of file names without extensions 24 | where the segmentation mask is required 25 | use_all_mask - when true, read all available mask in mask_dir. 26 | Default false. Set to true for YouTubeVOS validation. 
27 | """ 28 | self.vid_name = vid_name 29 | self.image_dir = image_dir 30 | self.mask_dir = mask_dir 31 | self.to_save = to_save 32 | self.use_all_mask = use_all_mask 33 | if size_dir is None: 34 | self.size_dir = self.image_dir 35 | else: 36 | self.size_dir = size_dir 37 | 38 | self.frames = sorted(os.listdir(self.image_dir)) 39 | self.palette = Image.open(path.join(mask_dir, sorted(os.listdir(mask_dir))[0])).getpalette() 40 | self.first_gt_path = path.join(self.mask_dir, sorted(os.listdir(self.mask_dir))[0]) 41 | 42 | if size < 0: 43 | self.im_transform = transforms.Compose([ 44 | transforms.ToTensor(), 45 | im_normalization, 46 | ]) 47 | else: 48 | self.im_transform = transforms.Compose([ 49 | transforms.ToTensor(), 50 | im_normalization, 51 | transforms.Resize(size, interpolation=InterpolationMode.BILINEAR), 52 | ]) 53 | self.size = size 54 | 55 | 56 | def __getitem__(self, idx): 57 | frame = self.frames[idx] 58 | info = {} 59 | data = {} 60 | info['frame'] = frame 61 | info['save'] = (self.to_save is None) or (frame[:-4] in self.to_save) 62 | 63 | im_path = path.join(self.image_dir, frame) 64 | img = Image.open(im_path).convert('RGB') 65 | 66 | if self.image_dir == self.size_dir: 67 | shape = np.array(img).shape[:2] 68 | else: 69 | size_path = path.join(self.size_dir, frame) 70 | size_im = Image.open(size_path).convert('RGB') 71 | shape = np.array(size_im).shape[:2] 72 | 73 | gt_path = path.join(self.mask_dir, frame[:-4]+'.png') 74 | img = self.im_transform(img) 75 | 76 | load_mask = self.use_all_mask or (gt_path == self.first_gt_path) 77 | if load_mask and path.exists(gt_path): 78 | mask = Image.open(gt_path).convert('P') 79 | mask = np.array(mask, dtype=np.uint8) 80 | data['mask'] = mask 81 | 82 | info['shape'] = shape 83 | info['need_resize'] = not (self.size < 0) 84 | data['rgb'] = img 85 | data['info'] = info 86 | 87 | return data 88 | 89 | def resize_mask(self, mask): 90 | # mask transform is applied AFTER mapper, so we need to post-process it in eval.py 91 | h, w = mask.shape[-2:] 92 | min_hw = min(h, w) 93 | return F.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)), 94 | mode='nearest') 95 | 96 | def get_palette(self): 97 | return self.palette 98 | 99 | def __len__(self): 100 | return len(self.frames) -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/GROOT/3d9eceae26c599c325f9457658ab41d2dd2a1319/third_party/XMem/inference/interact/__init__.py -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/GROOT/3d9eceae26c599c325f9457658ab41d2dd2a1319/third_party/XMem/inference/interact/fbrs/__init__.py -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/controller.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..fbrs.inference import clicker 4 | from ..fbrs.inference.predictors import get_predictor 5 | 6 | 7 | class InteractiveController: 8 | def __init__(self, net, device, predictor_params, prob_thresh=0.5): 9 | self.net = net.to(device) 10 | self.prob_thresh = prob_thresh 11 | self.clicker = clicker.Clicker() 12 | 
self.states = [] 13 | self.probs_history = [] 14 | self.object_count = 0 15 | self._result_mask = None 16 | 17 | self.image = None 18 | self.predictor = None 19 | self.device = device 20 | self.predictor_params = predictor_params 21 | self.reset_predictor() 22 | 23 | def set_image(self, image): 24 | self.image = image 25 | self._result_mask = torch.zeros(image.shape[-2:], dtype=torch.uint8) 26 | self.object_count = 0 27 | self.reset_last_object() 28 | 29 | def add_click(self, x, y, is_positive): 30 | self.states.append({ 31 | 'clicker': self.clicker.get_state(), 32 | 'predictor': self.predictor.get_states() 33 | }) 34 | 35 | click = clicker.Click(is_positive=is_positive, coords=(y, x)) 36 | self.clicker.add_click(click) 37 | pred = self.predictor.get_prediction(self.clicker) 38 | torch.cuda.empty_cache() 39 | 40 | if self.probs_history: 41 | self.probs_history.append((self.probs_history[-1][0], pred)) 42 | else: 43 | self.probs_history.append((torch.zeros_like(pred), pred)) 44 | 45 | def undo_click(self): 46 | if not self.states: 47 | return 48 | 49 | prev_state = self.states.pop() 50 | self.clicker.set_state(prev_state['clicker']) 51 | self.predictor.set_states(prev_state['predictor']) 52 | self.probs_history.pop() 53 | 54 | def partially_finish_object(self): 55 | object_prob = self.current_object_prob 56 | if object_prob is None: 57 | return 58 | 59 | self.probs_history.append((object_prob, torch.zeros_like(object_prob))) 60 | self.states.append(self.states[-1]) 61 | 62 | self.clicker.reset_clicks() 63 | self.reset_predictor() 64 | 65 | def finish_object(self): 66 | object_prob = self.current_object_prob 67 | if object_prob is None: 68 | return 69 | 70 | self.object_count += 1 71 | object_mask = object_prob > self.prob_thresh 72 | self._result_mask[object_mask] = self.object_count 73 | self.reset_last_object() 74 | 75 | def reset_last_object(self): 76 | self.states = [] 77 | self.probs_history = [] 78 | self.clicker.reset_clicks() 79 | self.reset_predictor() 80 | 81 | def reset_predictor(self, predictor_params=None): 82 | if predictor_params is not None: 83 | self.predictor_params = predictor_params 84 | self.predictor = get_predictor(self.net, device=self.device, 85 | **self.predictor_params) 86 | if self.image is not None: 87 | self.predictor.set_input_image(self.image) 88 | 89 | @property 90 | def current_object_prob(self): 91 | if self.probs_history: 92 | current_prob_total, current_prob_additive = self.probs_history[-1] 93 | return torch.maximum(current_prob_total, current_prob_additive) 94 | else: 95 | return None 96 | 97 | @property 98 | def is_incomplete_mask(self): 99 | return len(self.probs_history) > 0 100 | 101 | @property 102 | def result_mask(self): 103 | return self._result_mask.clone() 104 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/GROOT/3d9eceae26c599c325f9457658ab41d2dd2a1319/third_party/XMem/inference/interact/fbrs/inference/__init__.py -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/inference/clicker.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import numpy as np 4 | from copy import deepcopy 5 | from scipy.ndimage import distance_transform_edt 6 | 7 | Click = namedtuple('Click', 
['is_positive', 'coords']) 8 | 9 | 10 | class Clicker(object): 11 | def __init__(self, gt_mask=None, init_clicks=None, ignore_label=-1): 12 | if gt_mask is not None: 13 | self.gt_mask = gt_mask == 1 14 | self.not_ignore_mask = gt_mask != ignore_label 15 | else: 16 | self.gt_mask = None 17 | 18 | self.reset_clicks() 19 | 20 | if init_clicks is not None: 21 | for click in init_clicks: 22 | self.add_click(click) 23 | 24 | def make_next_click(self, pred_mask): 25 | assert self.gt_mask is not None 26 | click = self._get_click(pred_mask) 27 | self.add_click(click) 28 | 29 | def get_clicks(self, clicks_limit=None): 30 | return self.clicks_list[:clicks_limit] 31 | 32 | def _get_click(self, pred_mask, padding=True): 33 | fn_mask = np.logical_and(np.logical_and(self.gt_mask, np.logical_not(pred_mask)), self.not_ignore_mask) 34 | fp_mask = np.logical_and(np.logical_and(np.logical_not(self.gt_mask), pred_mask), self.not_ignore_mask) 35 | 36 | if padding: 37 | fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), 'constant') 38 | fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), 'constant') 39 | 40 | fn_mask_dt = distance_transform_edt(fn_mask) 41 | fp_mask_dt = distance_transform_edt(fp_mask) 42 | 43 | if padding: 44 | fn_mask_dt = fn_mask_dt[1:-1, 1:-1] 45 | fp_mask_dt = fp_mask_dt[1:-1, 1:-1] 46 | 47 | fn_mask_dt = fn_mask_dt * self.not_clicked_map 48 | fp_mask_dt = fp_mask_dt * self.not_clicked_map 49 | 50 | fn_max_dist = np.max(fn_mask_dt) 51 | fp_max_dist = np.max(fp_mask_dt) 52 | 53 | is_positive = fn_max_dist > fp_max_dist 54 | if is_positive: 55 | coords_y, coords_x = np.where(fn_mask_dt == fn_max_dist) # coords is [y, x] 56 | else: 57 | coords_y, coords_x = np.where(fp_mask_dt == fp_max_dist) # coords is [y, x] 58 | 59 | return Click(is_positive=is_positive, coords=(coords_y[0], coords_x[0])) 60 | 61 | def add_click(self, click): 62 | coords = click.coords 63 | 64 | if click.is_positive: 65 | self.num_pos_clicks += 1 66 | else: 67 | self.num_neg_clicks += 1 68 | 69 | self.clicks_list.append(click) 70 | if self.gt_mask is not None: 71 | self.not_clicked_map[coords[0], coords[1]] = False 72 | 73 | def _remove_last_click(self): 74 | click = self.clicks_list.pop() 75 | coords = click.coords 76 | 77 | if click.is_positive: 78 | self.num_pos_clicks -= 1 79 | else: 80 | self.num_neg_clicks -= 1 81 | 82 | if self.gt_mask is not None: 83 | self.not_clicked_map[coords[0], coords[1]] = True 84 | 85 | def reset_clicks(self): 86 | if self.gt_mask is not None: 87 | self.not_clicked_map = np.ones_like(self.gt_mask, dtype=np.bool) 88 | 89 | self.num_pos_clicks = 0 90 | self.num_neg_clicks = 0 91 | 92 | self.clicks_list = [] 93 | 94 | def get_state(self): 95 | return deepcopy(self.clicks_list) 96 | 97 | def set_state(self, state): 98 | self.reset_clicks() 99 | for click in state: 100 | self.add_click(click) 101 | 102 | def __len__(self): 103 | return len(self.clicks_list) 104 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/inference/evaluation.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from ..inference import utils 7 | from ..inference.clicker import Clicker 8 | 9 | try: 10 | get_ipython() 11 | from tqdm import tqdm_notebook as tqdm 12 | except NameError: 13 | from tqdm import tqdm 14 | 15 | 16 | def evaluate_dataset(dataset, predictor, oracle_eval=False, **kwargs): 17 | all_ious = [] 18 | 19 | start_time = time() 20 | for index in 
tqdm(range(len(dataset)), leave=False): 21 | sample = dataset.get_sample(index) 22 | item = dataset[index] 23 | 24 | if oracle_eval: 25 | gt_mask = torch.tensor(sample['instances_mask'], dtype=torch.float32) 26 | gt_mask = gt_mask.unsqueeze(0).unsqueeze(0) 27 | predictor.opt_functor.mask_loss.set_gt_mask(gt_mask) 28 | _, sample_ious, _ = evaluate_sample(item['images'], sample['instances_mask'], predictor, **kwargs) 29 | all_ious.append(sample_ious) 30 | end_time = time() 31 | elapsed_time = end_time - start_time 32 | 33 | return all_ious, elapsed_time 34 | 35 | 36 | def evaluate_sample(image_nd, instances_mask, predictor, max_iou_thr, 37 | pred_thr=0.49, max_clicks=20): 38 | clicker = Clicker(gt_mask=instances_mask) 39 | pred_mask = np.zeros_like(instances_mask) 40 | ious_list = [] 41 | 42 | with torch.no_grad(): 43 | predictor.set_input_image(image_nd) 44 | 45 | for click_number in range(max_clicks): 46 | clicker.make_next_click(pred_mask) 47 | pred_probs = predictor.get_prediction(clicker) 48 | pred_mask = pred_probs > pred_thr 49 | 50 | iou = utils.get_iou(instances_mask, pred_mask) 51 | ious_list.append(iou) 52 | 53 | if iou >= max_iou_thr: 54 | break 55 | 56 | return clicker.clicks_list, np.array(ious_list, dtype=np.float32), pred_probs 57 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/inference/predictors/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BasePredictor 2 | from .brs import InputBRSPredictor, FeatureBRSPredictor, HRNetFeatureBRSPredictor 3 | from .brs_functors import InputOptimizer, ScaleBiasOptimizer 4 | from ..transforms import ZoomIn 5 | from ...model.is_hrnet_model import DistMapsHRNetModel 6 | 7 | 8 | def get_predictor(net, brs_mode, device, 9 | prob_thresh=0.49, 10 | with_flip=True, 11 | zoom_in_params=dict(), 12 | predictor_params=None, 13 | brs_opt_func_params=None, 14 | lbfgs_params=None): 15 | lbfgs_params_ = { 16 | 'm': 20, 17 | 'factr': 0, 18 | 'pgtol': 1e-8, 19 | 'maxfun': 20, 20 | } 21 | 22 | predictor_params_ = { 23 | 'optimize_after_n_clicks': 1 24 | } 25 | 26 | if zoom_in_params is not None: 27 | zoom_in = ZoomIn(**zoom_in_params) 28 | else: 29 | zoom_in = None 30 | 31 | if lbfgs_params is not None: 32 | lbfgs_params_.update(lbfgs_params) 33 | lbfgs_params_['maxiter'] = 2 * lbfgs_params_['maxfun'] 34 | 35 | if brs_opt_func_params is None: 36 | brs_opt_func_params = dict() 37 | 38 | if brs_mode == 'NoBRS': 39 | if predictor_params is not None: 40 | predictor_params_.update(predictor_params) 41 | predictor = BasePredictor(net, device, zoom_in=zoom_in, with_flip=with_flip, **predictor_params_) 42 | elif brs_mode.startswith('f-BRS'): 43 | predictor_params_.update({ 44 | 'net_clicks_limit': 8, 45 | }) 46 | if predictor_params is not None: 47 | predictor_params_.update(predictor_params) 48 | 49 | insertion_mode = { 50 | 'f-BRS-A': 'after_c4', 51 | 'f-BRS-B': 'after_aspp', 52 | 'f-BRS-C': 'after_deeplab' 53 | }[brs_mode] 54 | 55 | opt_functor = ScaleBiasOptimizer(prob_thresh=prob_thresh, 56 | with_flip=with_flip, 57 | optimizer_params=lbfgs_params_, 58 | **brs_opt_func_params) 59 | 60 | if isinstance(net, DistMapsHRNetModel): 61 | FeaturePredictor = HRNetFeatureBRSPredictor 62 | insertion_mode = {'after_c4': 'A', 'after_aspp': 'A', 'after_deeplab': 'C'}[insertion_mode] 63 | else: 64 | FeaturePredictor = FeatureBRSPredictor 65 | 66 | predictor = FeaturePredictor(net, device, 67 | opt_functor=opt_functor, 68 | 
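A note on how the Clicker defined above is driven: make_next_click drops the next simulated click at the point deepest inside the largest remaining error region (false negatives yield positive clicks, false positives yield negative ones), which is exactly what the evaluate_sample loop above relies on. A minimal sketch, assuming Clicker is imported from the clicker module above and a NumPy version where the deprecated np.bool alias used in reset_clicks still resolves; the constant empty prediction is a stand-in for a real model:

import numpy as np

gt = np.zeros((64, 64), dtype=np.int32)    # toy ground truth: a filled square
gt[16:48, 16:48] = 1

clicker = Clicker(gt_mask=gt)
pred_mask = np.zeros_like(gt, dtype=bool)  # no model here, so the prediction stays empty

for _ in range(3):
    clicker.make_next_click(pred_mask)     # each click lands deep inside the false-negative region

print([(c.is_positive, c.coords) for c in clicker.get_clicks()])
# every click is positive, because only false negatives exist in this toy setup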
with_flip=with_flip, 69 | insertion_mode=insertion_mode, 70 | zoom_in=zoom_in, 71 | **predictor_params_) 72 | elif brs_mode == 'RGB-BRS' or brs_mode == 'DistMap-BRS': 73 | use_dmaps = brs_mode == 'DistMap-BRS' 74 | 75 | predictor_params_.update({ 76 | 'net_clicks_limit': 5, 77 | }) 78 | if predictor_params is not None: 79 | predictor_params_.update(predictor_params) 80 | 81 | opt_functor = InputOptimizer(prob_thresh=prob_thresh, 82 | with_flip=with_flip, 83 | optimizer_params=lbfgs_params_, 84 | **brs_opt_func_params) 85 | 86 | predictor = InputBRSPredictor(net, device, 87 | optimize_target='dmaps' if use_dmaps else 'rgb', 88 | opt_functor=opt_functor, 89 | with_flip=with_flip, 90 | zoom_in=zoom_in, 91 | **predictor_params_) 92 | else: 93 | raise NotImplementedError 94 | 95 | return predictor 96 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/inference/predictors/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from ..transforms import AddHorizontalFlip, SigmoidForPred, LimitLongestSide 5 | 6 | 7 | class BasePredictor(object): 8 | def __init__(self, net, device, 9 | net_clicks_limit=None, 10 | with_flip=False, 11 | zoom_in=None, 12 | max_size=None, 13 | **kwargs): 14 | self.net = net 15 | self.with_flip = with_flip 16 | self.net_clicks_limit = net_clicks_limit 17 | self.original_image = None 18 | self.device = device 19 | self.zoom_in = zoom_in 20 | 21 | self.transforms = [zoom_in] if zoom_in is not None else [] 22 | if max_size is not None: 23 | self.transforms.append(LimitLongestSide(max_size=max_size)) 24 | self.transforms.append(SigmoidForPred()) 25 | if with_flip: 26 | self.transforms.append(AddHorizontalFlip()) 27 | 28 | def set_input_image(self, image_nd): 29 | for transform in self.transforms: 30 | transform.reset() 31 | self.original_image = image_nd.to(self.device) 32 | if len(self.original_image.shape) == 3: 33 | self.original_image = self.original_image.unsqueeze(0) 34 | 35 | def get_prediction(self, clicker): 36 | clicks_list = clicker.get_clicks() 37 | 38 | image_nd, clicks_lists, is_image_changed = self.apply_transforms( 39 | self.original_image, [clicks_list] 40 | ) 41 | 42 | pred_logits = self._get_prediction(image_nd, clicks_lists, is_image_changed) 43 | prediction = F.interpolate(pred_logits, mode='bilinear', align_corners=True, 44 | size=image_nd.size()[2:]) 45 | 46 | for t in reversed(self.transforms): 47 | prediction = t.inv_transform(prediction) 48 | 49 | if self.zoom_in is not None and self.zoom_in.check_possible_recalculation(): 50 | print('zooming') 51 | return self.get_prediction(clicker) 52 | 53 | # return prediction.cpu().numpy()[0, 0] 54 | return prediction 55 | 56 | def _get_prediction(self, image_nd, clicks_lists, is_image_changed): 57 | points_nd = self.get_points_nd(clicks_lists) 58 | return self.net(image_nd, points_nd)['instances'] 59 | 60 | def _get_transform_states(self): 61 | return [x.get_state() for x in self.transforms] 62 | 63 | def _set_transform_states(self, states): 64 | assert len(states) == len(self.transforms) 65 | for state, transform in zip(states, self.transforms): 66 | transform.set_state(state) 67 | 68 | def apply_transforms(self, image_nd, clicks_lists): 69 | is_image_changed = False 70 | for t in self.transforms: 71 | image_nd, clicks_lists = t.transform(image_nd, clicks_lists) 72 | is_image_changed |= t.image_changed 73 | 74 | return image_nd, clicks_lists, 
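get_predictor above is the factory through which the interactive GUI obtains a predictor. A hedged sketch of a typical call; net is assumed to be an interactive-segmentation model already loaded on the target device (e.g. via utils.load_is_model, as FBRSController does near the end of this excerpt), and the zoom-in values mirror the ones used there:

import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

predictor = get_predictor(
    net, brs_mode='f-BRS-B', device=device,
    prob_thresh=0.49,
    zoom_in_params={'skip_clicks': 1, 'target_size': 480, 'expansion_ratio': 1.4},
    predictor_params={'net_clicks_limit': 8},
)
# predictor.set_input_image(image_nd)         # (3, H, W) or (1, 3, H, W) float tensor
# probs = predictor.get_prediction(clicker)   # clicker as defined in clicker.py above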
is_image_changed 75 | 76 | def get_points_nd(self, clicks_lists): 77 | total_clicks = [] 78 | num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists] 79 | num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)] 80 | num_max_points = max(num_pos_clicks + num_neg_clicks) 81 | if self.net_clicks_limit is not None: 82 | num_max_points = min(self.net_clicks_limit, num_max_points) 83 | num_max_points = max(1, num_max_points) 84 | 85 | for clicks_list in clicks_lists: 86 | clicks_list = clicks_list[:self.net_clicks_limit] 87 | pos_clicks = [click.coords for click in clicks_list if click.is_positive] 88 | pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1)] 89 | 90 | neg_clicks = [click.coords for click in clicks_list if not click.is_positive] 91 | neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1)] 92 | total_clicks.append(pos_clicks + neg_clicks) 93 | 94 | return torch.tensor(total_clicks, device=self.device) 95 | 96 | def get_states(self): 97 | return {'transform_states': self._get_transform_states()} 98 | 99 | def set_states(self, states): 100 | self._set_transform_states(states['transform_states']) 101 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/inference/predictors/brs_functors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from ...model.metrics import _compute_iou 5 | from .brs_losses import BRSMaskLoss 6 | 7 | 8 | class BaseOptimizer: 9 | def __init__(self, optimizer_params, 10 | prob_thresh=0.49, 11 | reg_weight=1e-3, 12 | min_iou_diff=0.01, 13 | brs_loss=BRSMaskLoss(), 14 | with_flip=False, 15 | flip_average=False, 16 | **kwargs): 17 | self.brs_loss = brs_loss 18 | self.optimizer_params = optimizer_params 19 | self.prob_thresh = prob_thresh 20 | self.reg_weight = reg_weight 21 | self.min_iou_diff = min_iou_diff 22 | self.with_flip = with_flip 23 | self.flip_average = flip_average 24 | 25 | self.best_prediction = None 26 | self._get_prediction_logits = None 27 | self._opt_shape = None 28 | self._best_loss = None 29 | self._click_masks = None 30 | self._last_mask = None 31 | self.device = None 32 | 33 | def init_click(self, get_prediction_logits, pos_mask, neg_mask, device, shape=None): 34 | self.best_prediction = None 35 | self._get_prediction_logits = get_prediction_logits 36 | self._click_masks = (pos_mask, neg_mask) 37 | self._opt_shape = shape 38 | self._last_mask = None 39 | self.device = device 40 | 41 | def __call__(self, x): 42 | opt_params = torch.from_numpy(x).float().to(self.device) 43 | opt_params.requires_grad_(True) 44 | 45 | with torch.enable_grad(): 46 | opt_vars, reg_loss = self.unpack_opt_params(opt_params) 47 | result_before_sigmoid = self._get_prediction_logits(*opt_vars) 48 | result = torch.sigmoid(result_before_sigmoid) 49 | 50 | pos_mask, neg_mask = self._click_masks 51 | if self.with_flip and self.flip_average: 52 | result, result_flipped = torch.chunk(result, 2, dim=0) 53 | result = 0.5 * (result + torch.flip(result_flipped, dims=[3])) 54 | pos_mask, neg_mask = pos_mask[:result.shape[0]], neg_mask[:result.shape[0]] 55 | 56 | loss, f_max_pos, f_max_neg = self.brs_loss(result, pos_mask, neg_mask) 57 | loss = loss + reg_loss 58 | 59 | f_val = loss.detach().cpu().numpy() 60 | if self.best_prediction is None or f_val < self._best_loss: 61 | self.best_prediction = 
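The click tensor built by get_points_nd above (and consumed by DistMaps later in this excerpt) lists positive click coordinates first, then negative ones, with each half padded to a common length using (-1, -1). A small self-contained sketch that mirrors that layout:

import torch
from collections import namedtuple

Click = namedtuple('Click', ['is_positive', 'coords'])   # same fields as fbrs' Click

clicks = [Click(True, (10, 12)), Click(True, (40, 41)), Click(False, (5, 60))]

num_max_points = max(sum(c.is_positive for c in clicks),
                     sum(not c.is_positive for c in clicks))   # -> 2
pos = [c.coords for c in clicks if c.is_positive]
neg = [c.coords for c in clicks if not c.is_positive]
pos += [(-1, -1)] * (num_max_points - len(pos))
neg += [(-1, -1)] * (num_max_points - len(neg))

points = torch.tensor([pos + neg])   # shape (1, 2 * num_max_points, 2)
print(points.shape, points.tolist())
# torch.Size([1, 4, 2]) [[[10, 12], [40, 41], [5, 60], [-1, -1]]]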
result_before_sigmoid.detach() 62 | self._best_loss = f_val 63 | 64 | if f_max_pos < (1 - self.prob_thresh) and f_max_neg < self.prob_thresh: 65 | return [f_val, np.zeros_like(x)] 66 | 67 | current_mask = result > self.prob_thresh 68 | if self._last_mask is not None and self.min_iou_diff > 0: 69 | diff_iou = _compute_iou(current_mask, self._last_mask) 70 | if len(diff_iou) > 0 and diff_iou.mean() > 1 - self.min_iou_diff: 71 | return [f_val, np.zeros_like(x)] 72 | self._last_mask = current_mask 73 | 74 | loss.backward() 75 | f_grad = opt_params.grad.cpu().numpy().ravel().astype(np.float32) 76 | 77 | return [f_val, f_grad] 78 | 79 | def unpack_opt_params(self, opt_params): 80 | raise NotImplementedError 81 | 82 | 83 | class InputOptimizer(BaseOptimizer): 84 | def unpack_opt_params(self, opt_params): 85 | opt_params = opt_params.view(self._opt_shape) 86 | if self.with_flip: 87 | opt_params_flipped = torch.flip(opt_params, dims=[3]) 88 | opt_params = torch.cat([opt_params, opt_params_flipped], dim=0) 89 | reg_loss = self.reg_weight * torch.sum(opt_params**2) 90 | 91 | return (opt_params,), reg_loss 92 | 93 | 94 | class ScaleBiasOptimizer(BaseOptimizer): 95 | def __init__(self, *args, scale_act=None, reg_bias_weight=10.0, **kwargs): 96 | super().__init__(*args, **kwargs) 97 | self.scale_act = scale_act 98 | self.reg_bias_weight = reg_bias_weight 99 | 100 | def unpack_opt_params(self, opt_params): 101 | scale, bias = torch.chunk(opt_params, 2, dim=0) 102 | reg_loss = self.reg_weight * (torch.sum(scale**2) + self.reg_bias_weight * torch.sum(bias**2)) 103 | 104 | if self.scale_act == 'tanh': 105 | scale = torch.tanh(scale) 106 | elif self.scale_act == 'sin': 107 | scale = torch.sin(scale) 108 | 109 | return (1 + scale, bias), reg_loss 110 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/inference/predictors/brs_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ...model.losses import SigmoidBinaryCrossEntropyLoss 4 | 5 | 6 | class BRSMaskLoss(torch.nn.Module): 7 | def __init__(self, eps=1e-5): 8 | super().__init__() 9 | self._eps = eps 10 | 11 | def forward(self, result, pos_mask, neg_mask): 12 | pos_diff = (1 - result) * pos_mask 13 | pos_target = torch.sum(pos_diff ** 2) 14 | pos_target = pos_target / (torch.sum(pos_mask) + self._eps) 15 | 16 | neg_diff = result * neg_mask 17 | neg_target = torch.sum(neg_diff ** 2) 18 | neg_target = neg_target / (torch.sum(neg_mask) + self._eps) 19 | 20 | loss = pos_target + neg_target 21 | 22 | with torch.no_grad(): 23 | f_max_pos = torch.max(torch.abs(pos_diff)).item() 24 | f_max_neg = torch.max(torch.abs(neg_diff)).item() 25 | 26 | return loss, f_max_pos, f_max_neg 27 | 28 | 29 | class OracleMaskLoss(torch.nn.Module): 30 | def __init__(self): 31 | super().__init__() 32 | self.gt_mask = None 33 | self.loss = SigmoidBinaryCrossEntropyLoss(from_sigmoid=True) 34 | self.predictor = None 35 | self.history = [] 36 | 37 | def set_gt_mask(self, gt_mask): 38 | self.gt_mask = gt_mask 39 | self.history = [] 40 | 41 | def forward(self, result, pos_mask, neg_mask): 42 | gt_mask = self.gt_mask.to(result.device) 43 | if self.predictor.object_roi is not None: 44 | r1, r2, c1, c2 = self.predictor.object_roi[:4] 45 | gt_mask = gt_mask[:, :, r1:r2 + 1, c1:c2 + 1] 46 | gt_mask = torch.nn.functional.interpolate(gt_mask, result.size()[2:], mode='bilinear', align_corners=True) 47 | 48 | if result.shape[0] == 2: 49 | gt_mask_flipped = 
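BaseOptimizer.__call__ above returns a [scalar loss, flat gradient array] pair, and the optimizer_params handed to it by get_predictor (m, factr, pgtol, maxfun) are L-BFGS-B options, so the functor follows the calling convention of a SciPy-style L-BFGS-B driver. A minimal sketch of that convention on a plain quadratic, independent of fbrs:

import numpy as np
from scipy.optimize import fmin_l_bfgs_b

def functor(x):
    # same contract as BaseOptimizer.__call__: a scalar loss plus a flattened gradient
    loss = float(np.sum((x - 3.0) ** 2))
    grad = 2.0 * (x - 3.0)
    return loss, grad

x_opt, f_opt, info = fmin_l_bfgs_b(functor, np.zeros(4), m=20, factr=0, pgtol=1e-8, maxfun=20)
print(np.round(x_opt, 3))   # ~[3. 3. 3. 3.]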
torch.flip(gt_mask, dims=[3]) 50 | gt_mask = torch.cat([gt_mask, gt_mask_flipped], dim=0) 51 | 52 | loss = self.loss(result, gt_mask) 53 | self.history.append(loss.detach().cpu().numpy()[0]) 54 | 55 | if len(self.history) > 5 and abs(self.history[-5] - self.history[-1]) < 1e-5: 56 | return 0, 0, 0 57 | 58 | return loss, 1.0, 1.0 59 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/inference/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import SigmoidForPred 2 | from .flip import AddHorizontalFlip 3 | from .zoom_in import ZoomIn 4 | from .limit_longest_side import LimitLongestSide 5 | from .crops import Crops 6 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/inference/transforms/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseTransform(object): 5 | def __init__(self): 6 | self.image_changed = False 7 | 8 | def transform(self, image_nd, clicks_lists): 9 | raise NotImplementedError 10 | 11 | def inv_transform(self, prob_map): 12 | raise NotImplementedError 13 | 14 | def reset(self): 15 | raise NotImplementedError 16 | 17 | def get_state(self): 18 | raise NotImplementedError 19 | 20 | def set_state(self, state): 21 | raise NotImplementedError 22 | 23 | 24 | class SigmoidForPred(BaseTransform): 25 | def transform(self, image_nd, clicks_lists): 26 | return image_nd, clicks_lists 27 | 28 | def inv_transform(self, prob_map): 29 | return torch.sigmoid(prob_map) 30 | 31 | def reset(self): 32 | pass 33 | 34 | def get_state(self): 35 | return None 36 | 37 | def set_state(self, state): 38 | pass 39 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/inference/transforms/crops.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import numpy as np 5 | 6 | from ...inference.clicker import Click 7 | from .base import BaseTransform 8 | 9 | 10 | class Crops(BaseTransform): 11 | def __init__(self, crop_size=(320, 480), min_overlap=0.2): 12 | super().__init__() 13 | self.crop_height, self.crop_width = crop_size 14 | self.min_overlap = min_overlap 15 | 16 | self.x_offsets = None 17 | self.y_offsets = None 18 | self._counts = None 19 | 20 | def transform(self, image_nd, clicks_lists): 21 | assert image_nd.shape[0] == 1 and len(clicks_lists) == 1 22 | image_height, image_width = image_nd.shape[2:4] 23 | self._counts = None 24 | 25 | if image_height < self.crop_height or image_width < self.crop_width: 26 | return image_nd, clicks_lists 27 | 28 | self.x_offsets = get_offsets(image_width, self.crop_width, self.min_overlap) 29 | self.y_offsets = get_offsets(image_height, self.crop_height, self.min_overlap) 30 | self._counts = np.zeros((image_height, image_width)) 31 | 32 | image_crops = [] 33 | for dy in self.y_offsets: 34 | for dx in self.x_offsets: 35 | self._counts[dy:dy + self.crop_height, dx:dx + self.crop_width] += 1 36 | image_crop = image_nd[:, :, dy:dy + self.crop_height, dx:dx + self.crop_width] 37 | image_crops.append(image_crop) 38 | image_crops = torch.cat(image_crops, dim=0) 39 | self._counts = torch.tensor(self._counts, device=image_nd.device, dtype=torch.float32) 40 | 41 | clicks_list = clicks_lists[0] 42 | clicks_lists = [] 43 | for dy in self.y_offsets: 
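BRSMaskLoss above penalizes only the clicked pixels: probabilities under positive clicks are pushed toward 1, probabilities under negative clicks toward 0, and the returned maxima of the per-pixel errors are what BaseOptimizer uses for its early exit. A tiny check with hypothetical one-pixel click masks, assuming BRSMaskLoss from brs_losses.py above is in scope:

import torch

loss_fn = BRSMaskLoss()

result = torch.full((1, 1, 4, 4), 0.2)     # predicted probabilities
pos_mask = torch.zeros((1, 1, 4, 4)); pos_mask[0, 0, 1, 1] = 1.0   # one positive click
neg_mask = torch.zeros((1, 1, 4, 4)); neg_mask[0, 0, 3, 3] = 1.0   # one negative click

loss, f_max_pos, f_max_neg = loss_fn(result, pos_mask, neg_mask)
print(float(loss), f_max_pos, f_max_neg)   # ~0.68, 0.8, 0.2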
44 | for dx in self.x_offsets: 45 | crop_clicks = [Click(is_positive=x.is_positive, coords=(x.coords[0] - dy, x.coords[1] - dx)) 46 | for x in clicks_list] 47 | clicks_lists.append(crop_clicks) 48 | 49 | return image_crops, clicks_lists 50 | 51 | def inv_transform(self, prob_map): 52 | if self._counts is None: 53 | return prob_map 54 | 55 | new_prob_map = torch.zeros((1, 1, *self._counts.shape), 56 | dtype=prob_map.dtype, device=prob_map.device) 57 | 58 | crop_indx = 0 59 | for dy in self.y_offsets: 60 | for dx in self.x_offsets: 61 | new_prob_map[0, 0, dy:dy + self.crop_height, dx:dx + self.crop_width] += prob_map[crop_indx, 0] 62 | crop_indx += 1 63 | new_prob_map = torch.div(new_prob_map, self._counts) 64 | 65 | return new_prob_map 66 | 67 | def get_state(self): 68 | return self.x_offsets, self.y_offsets, self._counts 69 | 70 | def set_state(self, state): 71 | self.x_offsets, self.y_offsets, self._counts = state 72 | 73 | def reset(self): 74 | self.x_offsets = None 75 | self.y_offsets = None 76 | self._counts = None 77 | 78 | 79 | def get_offsets(length, crop_size, min_overlap_ratio=0.2): 80 | if length == crop_size: 81 | return [0] 82 | 83 | N = (length / crop_size - min_overlap_ratio) / (1 - min_overlap_ratio) 84 | N = math.ceil(N) 85 | 86 | overlap_ratio = (N - length / crop_size) / (N - 1) 87 | overlap_width = int(crop_size * overlap_ratio) 88 | 89 | offsets = [0] 90 | for i in range(1, N): 91 | new_offset = offsets[-1] + crop_size - overlap_width 92 | if new_offset + crop_size > length: 93 | new_offset = length - crop_size 94 | 95 | offsets.append(new_offset) 96 | 97 | return offsets 98 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/inference/transforms/flip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..clicker import Click 4 | from .base import BaseTransform 5 | 6 | 7 | class AddHorizontalFlip(BaseTransform): 8 | def transform(self, image_nd, clicks_lists): 9 | assert len(image_nd.shape) == 4 10 | image_nd = torch.cat([image_nd, torch.flip(image_nd, dims=[3])], dim=0) 11 | 12 | image_width = image_nd.shape[3] 13 | clicks_lists_flipped = [] 14 | for clicks_list in clicks_lists: 15 | clicks_list_flipped = [Click(is_positive=click.is_positive, 16 | coords=(click.coords[0], image_width - click.coords[1] - 1)) 17 | for click in clicks_list] 18 | clicks_lists_flipped.append(clicks_list_flipped) 19 | clicks_lists = clicks_lists + clicks_lists_flipped 20 | 21 | return image_nd, clicks_lists 22 | 23 | def inv_transform(self, prob_map): 24 | assert len(prob_map.shape) == 4 and prob_map.shape[0] % 2 == 0 25 | num_maps = prob_map.shape[0] // 2 26 | prob_map, prob_map_flipped = prob_map[:num_maps], prob_map[num_maps:] 27 | 28 | return 0.5 * (prob_map + torch.flip(prob_map_flipped, dims=[3])) 29 | 30 | def get_state(self): 31 | return None 32 | 33 | def set_state(self, state): 34 | pass 35 | 36 | def reset(self): 37 | pass 38 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/inference/transforms/limit_longest_side.py: -------------------------------------------------------------------------------- 1 | from .zoom_in import ZoomIn, get_roi_image_nd 2 | 3 | 4 | class LimitLongestSide(ZoomIn): 5 | def __init__(self, max_size=800): 6 | super().__init__(target_size=max_size, skip_clicks=0) 7 | 8 | def transform(self, image_nd, clicks_lists): 9 | assert image_nd.shape[0] == 1 and 
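A worked example of get_offsets above, which decides how many overlapping crops Crops takes along one axis; the numbers are purely illustrative and assume get_offsets from crops.py above is in scope:

# length=500, crop_size=320, min_overlap_ratio=0.2
# N = ceil((500/320 - 0.2) / (1 - 0.2)) = ceil(1.70...) = 2 crops
# overlap_ratio = (2 - 500/320) / (2 - 1) = 0.4375, so overlap_width = int(320 * 0.4375) = 140
# offsets = [0, 0 + 320 - 140] = [0, 180]; the second crop ends at 180 + 320 = 500 exactly
print(get_offsets(500, 320, 0.2))   # -> [0, 180]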
len(clicks_lists) == 1 10 | image_max_size = max(image_nd.shape[2:4]) 11 | self.image_changed = False 12 | 13 | if image_max_size <= self.target_size: 14 | return image_nd, clicks_lists 15 | self._input_image = image_nd 16 | 17 | self._object_roi = (0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1) 18 | self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size) 19 | self.image_changed = True 20 | 21 | tclicks_lists = [self._transform_clicks(clicks_lists[0])] 22 | return self._roi_image, tclicks_lists 23 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/GROOT/3d9eceae26c599c325f9457658ab41d2dd2a1319/third_party/XMem/inference/interact/fbrs/model/__init__.py -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/model/initializer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class Initializer(object): 7 | def __init__(self, local_init=True, gamma=None): 8 | self.local_init = local_init 9 | self.gamma = gamma 10 | 11 | def __call__(self, m): 12 | if getattr(m, '__initialized', False): 13 | return 14 | 15 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, 16 | nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, 17 | nn.GroupNorm, nn.SyncBatchNorm)) or 'BatchNorm' in m.__class__.__name__: 18 | if m.weight is not None: 19 | self._init_gamma(m.weight.data) 20 | if m.bias is not None: 21 | self._init_beta(m.bias.data) 22 | else: 23 | if getattr(m, 'weight', None) is not None: 24 | self._init_weight(m.weight.data) 25 | if getattr(m, 'bias', None) is not None: 26 | self._init_bias(m.bias.data) 27 | 28 | if self.local_init: 29 | object.__setattr__(m, '__initialized', True) 30 | 31 | def _init_weight(self, data): 32 | nn.init.uniform_(data, -0.07, 0.07) 33 | 34 | def _init_bias(self, data): 35 | nn.init.constant_(data, 0) 36 | 37 | def _init_gamma(self, data): 38 | if self.gamma is None: 39 | nn.init.constant_(data, 1.0) 40 | else: 41 | nn.init.normal_(data, 1.0, self.gamma) 42 | 43 | def _init_beta(self, data): 44 | nn.init.constant_(data, 0) 45 | 46 | 47 | class Bilinear(Initializer): 48 | def __init__(self, scale, groups, in_channels, **kwargs): 49 | super().__init__(**kwargs) 50 | self.scale = scale 51 | self.groups = groups 52 | self.in_channels = in_channels 53 | 54 | def _init_weight(self, data): 55 | """Reset the weight and bias.""" 56 | bilinear_kernel = self.get_bilinear_kernel(self.scale) 57 | weight = torch.zeros_like(data) 58 | for i in range(self.in_channels): 59 | if self.groups == 1: 60 | j = i 61 | else: 62 | j = 0 63 | weight[i, j] = bilinear_kernel 64 | data[:] = weight 65 | 66 | @staticmethod 67 | def get_bilinear_kernel(scale): 68 | """Generate a bilinear upsampling kernel.""" 69 | kernel_size = 2 * scale - scale % 2 70 | scale = (kernel_size + 1) // 2 71 | center = scale - 0.5 * (1 + kernel_size % 2) 72 | 73 | og = np.ogrid[:kernel_size, :kernel_size] 74 | kernel = (1 - np.abs(og[0] - center) / scale) * (1 - np.abs(og[1] - center) / scale) 75 | 76 | return torch.tensor(kernel, dtype=torch.float32) 77 | 78 | 79 | class XavierGluon(Initializer): 80 | def __init__(self, rnd_type='uniform', factor_type='avg', magnitude=3, **kwargs): 81 | 
super().__init__(**kwargs) 82 | 83 | self.rnd_type = rnd_type 84 | self.factor_type = factor_type 85 | self.magnitude = float(magnitude) 86 | 87 | def _init_weight(self, arr): 88 | fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(arr) 89 | 90 | if self.factor_type == 'avg': 91 | factor = (fan_in + fan_out) / 2.0 92 | elif self.factor_type == 'in': 93 | factor = fan_in 94 | elif self.factor_type == 'out': 95 | factor = fan_out 96 | else: 97 | raise ValueError('Incorrect factor type') 98 | scale = np.sqrt(self.magnitude / factor) 99 | 100 | if self.rnd_type == 'uniform': 101 | nn.init.uniform_(arr, -scale, scale) 102 | elif self.rnd_type == 'gaussian': 103 | nn.init.normal_(arr, 0, scale) 104 | else: 105 | raise ValueError('Unknown random type') 106 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/model/is_deeplab_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .ops import DistMaps 5 | from .modeling.deeplab_v3 import DeepLabV3Plus 6 | from .modeling.basic_blocks import SepConvHead 7 | 8 | 9 | def get_deeplab_model(backbone='resnet50', deeplab_ch=256, aspp_dropout=0.5, 10 | norm_layer=nn.BatchNorm2d, backbone_norm_layer=None, 11 | use_rgb_conv=True, cpu_dist_maps=False, 12 | norm_radius=260): 13 | model = DistMapsModel( 14 | feature_extractor=DeepLabV3Plus(backbone=backbone, 15 | ch=deeplab_ch, 16 | project_dropout=aspp_dropout, 17 | norm_layer=norm_layer, 18 | backbone_norm_layer=backbone_norm_layer), 19 | head=SepConvHead(1, in_channels=deeplab_ch, mid_channels=deeplab_ch // 2, 20 | num_layers=2, norm_layer=norm_layer), 21 | use_rgb_conv=use_rgb_conv, 22 | norm_layer=norm_layer, 23 | norm_radius=norm_radius, 24 | cpu_dist_maps=cpu_dist_maps 25 | ) 26 | 27 | return model 28 | 29 | 30 | class DistMapsModel(nn.Module): 31 | def __init__(self, feature_extractor, head, norm_layer=nn.BatchNorm2d, use_rgb_conv=True, 32 | cpu_dist_maps=False, norm_radius=260): 33 | super(DistMapsModel, self).__init__() 34 | 35 | if use_rgb_conv: 36 | self.rgb_conv = nn.Sequential( 37 | nn.Conv2d(in_channels=5, out_channels=8, kernel_size=1), 38 | nn.LeakyReLU(negative_slope=0.2), 39 | norm_layer(8), 40 | nn.Conv2d(in_channels=8, out_channels=3, kernel_size=1), 41 | ) 42 | else: 43 | self.rgb_conv = None 44 | 45 | self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0, 46 | cpu_mode=cpu_dist_maps) 47 | self.feature_extractor = feature_extractor 48 | self.head = head 49 | 50 | def forward(self, image, points): 51 | coord_features = self.dist_maps(image, points) 52 | 53 | if self.rgb_conv is not None: 54 | x = self.rgb_conv(torch.cat((image, coord_features), dim=1)) 55 | else: 56 | c1, c2 = torch.chunk(coord_features, 2, dim=1) 57 | c3 = torch.ones_like(c1) 58 | coord_features = torch.cat((c1, c2, c3), dim=1) 59 | x = 0.8 * image * coord_features + 0.2 * image 60 | 61 | backbone_features = self.feature_extractor(x) 62 | instance_out = self.head(backbone_features[0]) 63 | instance_out = nn.functional.interpolate(instance_out, size=image.size()[2:], 64 | mode='bilinear', align_corners=True) 65 | 66 | return {'instances': instance_out} 67 | 68 | def load_weights(self, path_to_weights): 69 | current_state_dict = self.state_dict() 70 | new_state_dict = torch.load(path_to_weights, map_location='cpu') 71 | current_state_dict.update(new_state_dict) 72 | self.load_state_dict(current_state_dict) 73 | 74 | def 
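The Bilinear initializer above fills upsampling weights with the classic fixed bilinear kernel; for scale=2 the 1-D profile is [0.25, 0.75, 0.75, 0.25] and the 2-D kernel is its outer product. A quick check of the static helper defined above:

k = Bilinear.get_bilinear_kernel(2)   # 4x4 tensor
print(k[1])   # tensor([0.1875, 0.5625, 0.5625, 0.1875]) == 0.75 * [0.25, 0.75, 0.75, 0.25]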
get_trainable_params(self): 75 | backbone_params = nn.ParameterList() 76 | other_params = nn.ParameterList() 77 | 78 | for name, param in self.named_parameters(): 79 | if param.requires_grad: 80 | if 'backbone' in name: 81 | backbone_params.append(param) 82 | else: 83 | other_params.append(param) 84 | return backbone_params, other_params 85 | 86 | 87 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/model/is_hrnet_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .ops import DistMaps 5 | from .modeling.hrnet_ocr import HighResolutionNet 6 | 7 | 8 | def get_hrnet_model(width=48, ocr_width=256, small=False, norm_radius=260, 9 | use_rgb_conv=True, with_aux_output=False, cpu_dist_maps=False, 10 | norm_layer=nn.BatchNorm2d): 11 | model = DistMapsHRNetModel( 12 | feature_extractor=HighResolutionNet(width=width, ocr_width=ocr_width, small=small, 13 | num_classes=1, norm_layer=norm_layer), 14 | use_rgb_conv=use_rgb_conv, 15 | with_aux_output=with_aux_output, 16 | norm_layer=norm_layer, 17 | norm_radius=norm_radius, 18 | cpu_dist_maps=cpu_dist_maps 19 | ) 20 | 21 | return model 22 | 23 | 24 | class DistMapsHRNetModel(nn.Module): 25 | def __init__(self, feature_extractor, use_rgb_conv=True, with_aux_output=False, 26 | norm_layer=nn.BatchNorm2d, norm_radius=260, cpu_dist_maps=False): 27 | super(DistMapsHRNetModel, self).__init__() 28 | self.with_aux_output = with_aux_output 29 | 30 | if use_rgb_conv: 31 | self.rgb_conv = nn.Sequential( 32 | nn.Conv2d(in_channels=5, out_channels=8, kernel_size=1), 33 | nn.LeakyReLU(negative_slope=0.2), 34 | norm_layer(8), 35 | nn.Conv2d(in_channels=8, out_channels=3, kernel_size=1), 36 | ) 37 | else: 38 | self.rgb_conv = None 39 | 40 | self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0, cpu_mode=cpu_dist_maps) 41 | self.feature_extractor = feature_extractor 42 | 43 | def forward(self, image, points): 44 | coord_features = self.dist_maps(image, points) 45 | 46 | if self.rgb_conv is not None: 47 | x = self.rgb_conv(torch.cat((image, coord_features), dim=1)) 48 | else: 49 | c1, c2 = torch.chunk(coord_features, 2, dim=1) 50 | c3 = torch.ones_like(c1) 51 | coord_features = torch.cat((c1, c2, c3), dim=1) 52 | x = 0.8 * image * coord_features + 0.2 * image 53 | 54 | feature_extractor_out = self.feature_extractor(x) 55 | instance_out = feature_extractor_out[0] 56 | instance_out = nn.functional.interpolate(instance_out, size=image.size()[2:], 57 | mode='bilinear', align_corners=True) 58 | outputs = {'instances': instance_out} 59 | if self.with_aux_output: 60 | instance_aux_out = feature_extractor_out[1] 61 | instance_aux_out = nn.functional.interpolate(instance_aux_out, size=image.size()[2:], 62 | mode='bilinear', align_corners=True) 63 | outputs['instances_aux'] = instance_aux_out 64 | 65 | return outputs 66 | 67 | def load_weights(self, path_to_weights): 68 | current_state_dict = self.state_dict() 69 | new_state_dict = torch.load(path_to_weights) 70 | current_state_dict.update(new_state_dict) 71 | self.load_state_dict(current_state_dict) 72 | 73 | def get_trainable_params(self): 74 | backbone_params = nn.ParameterList() 75 | other_params = nn.ParameterList() 76 | other_params_keys = [] 77 | nonbackbone_keywords = ['rgb_conv', 'aux_head', 'cls_head', 'conv3x3_ocr', 'ocr_distri_head'] 78 | 79 | for name, param in self.named_parameters(): 80 | if param.requires_grad: 81 | if any(x in name 
for x in nonbackbone_keywords): 82 | other_params.append(param) 83 | other_params_keys.append(name) 84 | else: 85 | backbone_params.append(param) 86 | print('Nonbackbone params:', sorted(other_params_keys)) 87 | return backbone_params, other_params 88 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/model/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from ..utils import misc 5 | 6 | 7 | class TrainMetric(object): 8 | def __init__(self, pred_outputs, gt_outputs): 9 | self.pred_outputs = pred_outputs 10 | self.gt_outputs = gt_outputs 11 | 12 | def update(self, *args, **kwargs): 13 | raise NotImplementedError 14 | 15 | def get_epoch_value(self): 16 | raise NotImplementedError 17 | 18 | def reset_epoch_stats(self): 19 | raise NotImplementedError 20 | 21 | def log_states(self, sw, tag_prefix, global_step): 22 | pass 23 | 24 | @property 25 | def name(self): 26 | return type(self).__name__ 27 | 28 | 29 | class AdaptiveIoU(TrainMetric): 30 | def __init__(self, init_thresh=0.4, thresh_step=0.025, thresh_beta=0.99, iou_beta=0.9, 31 | ignore_label=-1, from_logits=True, 32 | pred_output='instances', gt_output='instances'): 33 | super().__init__(pred_outputs=(pred_output,), gt_outputs=(gt_output,)) 34 | self._ignore_label = ignore_label 35 | self._from_logits = from_logits 36 | self._iou_thresh = init_thresh 37 | self._thresh_step = thresh_step 38 | self._thresh_beta = thresh_beta 39 | self._iou_beta = iou_beta 40 | self._ema_iou = 0.0 41 | self._epoch_iou_sum = 0.0 42 | self._epoch_batch_count = 0 43 | 44 | def update(self, pred, gt): 45 | gt_mask = gt > 0 46 | if self._from_logits: 47 | pred = torch.sigmoid(pred) 48 | 49 | gt_mask_area = torch.sum(gt_mask, dim=(1, 2)).detach().cpu().numpy() 50 | if np.all(gt_mask_area == 0): 51 | return 52 | 53 | ignore_mask = gt == self._ignore_label 54 | max_iou = _compute_iou(pred > self._iou_thresh, gt_mask, ignore_mask).mean() 55 | best_thresh = self._iou_thresh 56 | for t in [best_thresh - self._thresh_step, best_thresh + self._thresh_step]: 57 | temp_iou = _compute_iou(pred > t, gt_mask, ignore_mask).mean() 58 | if temp_iou > max_iou: 59 | max_iou = temp_iou 60 | best_thresh = t 61 | 62 | self._iou_thresh = self._thresh_beta * self._iou_thresh + (1 - self._thresh_beta) * best_thresh 63 | self._ema_iou = self._iou_beta * self._ema_iou + (1 - self._iou_beta) * max_iou 64 | self._epoch_iou_sum += max_iou 65 | self._epoch_batch_count += 1 66 | 67 | def get_epoch_value(self): 68 | if self._epoch_batch_count > 0: 69 | return self._epoch_iou_sum / self._epoch_batch_count 70 | else: 71 | return 0.0 72 | 73 | def reset_epoch_stats(self): 74 | self._epoch_iou_sum = 0.0 75 | self._epoch_batch_count = 0 76 | 77 | def log_states(self, sw, tag_prefix, global_step): 78 | sw.add_scalar(tag=tag_prefix + '_ema_iou', value=self._ema_iou, global_step=global_step) 79 | sw.add_scalar(tag=tag_prefix + '_iou_thresh', value=self._iou_thresh, global_step=global_step) 80 | 81 | @property 82 | def iou_thresh(self): 83 | return self._iou_thresh 84 | 85 | 86 | def _compute_iou(pred_mask, gt_mask, ignore_mask=None, keep_ignore=False): 87 | if ignore_mask is not None: 88 | pred_mask = torch.where(ignore_mask, torch.zeros_like(pred_mask), pred_mask) 89 | 90 | reduction_dims = misc.get_dims_with_exclusion(gt_mask.dim(), 0) 91 | union = torch.mean((pred_mask | gt_mask).float(), dim=reduction_dims).detach().cpu().numpy() 92 | 
intersection = torch.mean((pred_mask & gt_mask).float(), dim=reduction_dims).detach().cpu().numpy() 93 | nonzero = union > 0 94 | 95 | iou = intersection[nonzero] / union[nonzero] 96 | if not keep_ignore: 97 | return iou 98 | else: 99 | result = np.full_like(intersection, -1) 100 | result[nonzero] = iou 101 | return result 102 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/model/modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/GROOT/3d9eceae26c599c325f9457658ab41d2dd2a1319/third_party/XMem/inference/interact/fbrs/model/modeling/__init__.py -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/model/modeling/basic_blocks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from ...model import ops 4 | 5 | 6 | class ConvHead(nn.Module): 7 | def __init__(self, out_channels, in_channels=32, num_layers=1, 8 | kernel_size=3, padding=1, 9 | norm_layer=nn.BatchNorm2d): 10 | super(ConvHead, self).__init__() 11 | convhead = [] 12 | 13 | for i in range(num_layers): 14 | convhead.extend([ 15 | nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding), 16 | nn.ReLU(), 17 | norm_layer(in_channels) if norm_layer is not None else nn.Identity() 18 | ]) 19 | convhead.append(nn.Conv2d(in_channels, out_channels, 1, padding=0)) 20 | 21 | self.convhead = nn.Sequential(*convhead) 22 | 23 | def forward(self, *inputs): 24 | return self.convhead(inputs[0]) 25 | 26 | 27 | class SepConvHead(nn.Module): 28 | def __init__(self, num_outputs, in_channels, mid_channels, num_layers=1, 29 | kernel_size=3, padding=1, dropout_ratio=0.0, dropout_indx=0, 30 | norm_layer=nn.BatchNorm2d): 31 | super(SepConvHead, self).__init__() 32 | 33 | sepconvhead = [] 34 | 35 | for i in range(num_layers): 36 | sepconvhead.append( 37 | SeparableConv2d(in_channels=in_channels if i == 0 else mid_channels, 38 | out_channels=mid_channels, 39 | dw_kernel=kernel_size, dw_padding=padding, 40 | norm_layer=norm_layer, activation='relu') 41 | ) 42 | if dropout_ratio > 0 and dropout_indx == i: 43 | sepconvhead.append(nn.Dropout(dropout_ratio)) 44 | 45 | sepconvhead.append( 46 | nn.Conv2d(in_channels=mid_channels, out_channels=num_outputs, kernel_size=1, padding=0) 47 | ) 48 | 49 | self.layers = nn.Sequential(*sepconvhead) 50 | 51 | def forward(self, *inputs): 52 | x = inputs[0] 53 | 54 | return self.layers(x) 55 | 56 | 57 | class SeparableConv2d(nn.Module): 58 | def __init__(self, in_channels, out_channels, dw_kernel, dw_padding, dw_stride=1, 59 | activation=None, use_bias=False, norm_layer=None): 60 | super(SeparableConv2d, self).__init__() 61 | _activation = ops.select_activation_function(activation) 62 | self.body = nn.Sequential( 63 | nn.Conv2d(in_channels, in_channels, kernel_size=dw_kernel, stride=dw_stride, 64 | padding=dw_padding, bias=use_bias, groups=in_channels), 65 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=use_bias), 66 | norm_layer(out_channels) if norm_layer is not None else nn.Identity(), 67 | _activation() 68 | ) 69 | 70 | def forward(self, x): 71 | return self.body(x) 72 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/model/modeling/resnet.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from .resnetv1b import resnet34_v1b, resnet50_v1s, resnet101_v1s, resnet152_v1s 3 | 4 | 5 | class ResNetBackbone(torch.nn.Module): 6 | def __init__(self, backbone='resnet50', pretrained_base=True, dilated=True, **kwargs): 7 | super(ResNetBackbone, self).__init__() 8 | 9 | if backbone == 'resnet34': 10 | pretrained = resnet34_v1b(pretrained=pretrained_base, dilated=dilated, **kwargs) 11 | elif backbone == 'resnet50': 12 | pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) 13 | elif backbone == 'resnet101': 14 | pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) 15 | elif backbone == 'resnet152': 16 | pretrained = resnet152_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) 17 | else: 18 | raise RuntimeError(f'unknown backbone: {backbone}') 19 | 20 | self.conv1 = pretrained.conv1 21 | self.bn1 = pretrained.bn1 22 | self.relu = pretrained.relu 23 | self.maxpool = pretrained.maxpool 24 | self.layer1 = pretrained.layer1 25 | self.layer2 = pretrained.layer2 26 | self.layer3 = pretrained.layer3 27 | self.layer4 = pretrained.layer4 28 | 29 | def forward(self, x): 30 | x = self.conv1(x) 31 | x = self.bn1(x) 32 | x = self.relu(x) 33 | x = self.maxpool(x) 34 | c1 = self.layer1(x) 35 | c2 = self.layer2(c1) 36 | c3 = self.layer3(c2) 37 | c4 = self.layer4(c3) 38 | 39 | return c1, c2, c3, c4 40 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/model/ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | import numpy as np 4 | 5 | from . import initializer as initializer 6 | from ..utils.cython import get_dist_maps 7 | 8 | 9 | def select_activation_function(activation): 10 | if isinstance(activation, str): 11 | if activation.lower() == 'relu': 12 | return nn.ReLU 13 | elif activation.lower() == 'softplus': 14 | return nn.Softplus 15 | else: 16 | raise ValueError(f"Unknown activation type {activation}") 17 | elif isinstance(activation, nn.Module): 18 | return activation 19 | else: 20 | raise ValueError(f"Unknown activation type {activation}") 21 | 22 | 23 | class BilinearConvTranspose2d(nn.ConvTranspose2d): 24 | def __init__(self, in_channels, out_channels, scale, groups=1): 25 | kernel_size = 2 * scale - scale % 2 26 | self.scale = scale 27 | 28 | super().__init__( 29 | in_channels, out_channels, 30 | kernel_size=kernel_size, 31 | stride=scale, 32 | padding=1, 33 | groups=groups, 34 | bias=False) 35 | 36 | self.apply(initializer.Bilinear(scale=scale, in_channels=in_channels, groups=groups)) 37 | 38 | 39 | class DistMaps(nn.Module): 40 | def __init__(self, norm_radius, spatial_scale=1.0, cpu_mode=False): 41 | super(DistMaps, self).__init__() 42 | self.spatial_scale = spatial_scale 43 | self.norm_radius = norm_radius 44 | self.cpu_mode = cpu_mode 45 | 46 | def get_coord_features(self, points, batchsize, rows, cols): 47 | if self.cpu_mode: 48 | coords = [] 49 | for i in range(batchsize): 50 | norm_delimeter = self.spatial_scale * self.norm_radius 51 | coords.append(get_dist_maps(points[i].cpu().float().numpy(), rows, cols, 52 | norm_delimeter)) 53 | coords = torch.from_numpy(np.stack(coords, axis=0)).to(points.device).float() 54 | else: 55 | num_points = points.shape[1] // 2 56 | points = points.view(-1, 2) 57 | invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0 58 | 
row_array = torch.arange(start=0, end=rows, step=1, dtype=torch.float32, device=points.device) 59 | col_array = torch.arange(start=0, end=cols, step=1, dtype=torch.float32, device=points.device) 60 | 61 | coord_rows, coord_cols = torch.meshgrid(row_array, col_array) 62 | coords = torch.stack((coord_rows, coord_cols), dim=0).unsqueeze(0).repeat(points.size(0), 1, 1, 1) 63 | 64 | add_xy = (points * self.spatial_scale).view(points.size(0), points.size(1), 1, 1) 65 | coords.add_(-add_xy) 66 | coords.div_(self.norm_radius * self.spatial_scale) 67 | coords.mul_(coords) 68 | 69 | coords[:, 0] += coords[:, 1] 70 | coords = coords[:, :1] 71 | 72 | coords[invalid_points, :, :, :] = 1e6 73 | 74 | coords = coords.view(-1, num_points, 1, rows, cols) 75 | coords = coords.min(dim=1)[0] # -> (bs * num_masks * 2) x 1 x h x w 76 | coords = coords.view(-1, 2, rows, cols) 77 | 78 | coords.sqrt_().mul_(2).tanh_() 79 | 80 | return coords 81 | 82 | def forward(self, x, coords): 83 | return self.get_coord_features(coords, x.shape[0], x.shape[2], x.shape[3]) 84 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/model/syncbn/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Tamaki Kojima 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
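DistMaps above turns the padded click tensor (the positive-then-negative layout sketched earlier) into a two-channel map of tanh-squashed normalized distances, one channel per click polarity. A shape-level sketch, assuming DistMaps from ops.py above is importable; the non-cpu_mode branch is plain tensor arithmetic, so it also runs on CPU tensors:

import torch

dist_maps = DistMaps(norm_radius=260, spatial_scale=1.0, cpu_mode=False)

image = torch.zeros(1, 3, 256, 256)                   # only the shape is consumed
points = torch.tensor([[[64., 64.], [-1., -1.],       # positive clicks (row, col), padded with (-1, -1)
                        [128., 200.], [-1., -1.]]])   # negative clicks, padded likewise
coord_features = dist_maps(image, points)
print(coord_features.shape)   # torch.Size([1, 2, 256, 256])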
22 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/model/syncbn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/GROOT/3d9eceae26c599c325f9457658ab41d2dd2a1319/third_party/XMem/inference/interact/fbrs/model/syncbn/__init__.py -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/model/syncbn/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/GROOT/3d9eceae26c599c325f9457658ab41d2dd2a1319/third_party/XMem/inference/interact/fbrs/model/syncbn/modules/__init__.py -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/model/syncbn/modules/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from .syncbn import batchnorm2d_sync 2 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/model/syncbn/modules/functional/_csrc.py: -------------------------------------------------------------------------------- 1 | """ 2 | /*****************************************************************************/ 3 | 4 | Extension module loader 5 | 6 | code referenced from : https://github.com/facebookresearch/maskrcnn-benchmark 7 | 8 | /*****************************************************************************/ 9 | """ 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import glob 15 | import os.path 16 | 17 | import torch 18 | 19 | try: 20 | from torch.utils.cpp_extension import load 21 | from torch.utils.cpp_extension import CUDA_HOME 22 | except ImportError: 23 | raise ImportError( 24 | "The cpp layer extensions requires PyTorch 0.4 or higher") 25 | 26 | 27 | def _load_C_extensions(): 28 | this_dir = os.path.dirname(os.path.abspath(__file__)) 29 | this_dir = os.path.join(this_dir, "csrc") 30 | 31 | main_file = glob.glob(os.path.join(this_dir, "*.cpp")) 32 | sources_cpu = glob.glob(os.path.join(this_dir, "cpu", "*.cpp")) 33 | sources_cuda = glob.glob(os.path.join(this_dir, "cuda", "*.cu")) 34 | 35 | sources = main_file + sources_cpu 36 | 37 | extra_cflags = [] 38 | extra_cuda_cflags = [] 39 | if torch.cuda.is_available() and CUDA_HOME is not None: 40 | sources.extend(sources_cuda) 41 | extra_cflags = ["-O3", "-DWITH_CUDA"] 42 | extra_cuda_cflags = ["--expt-extended-lambda"] 43 | sources = [os.path.join(this_dir, s) for s in sources] 44 | extra_include_paths = [this_dir] 45 | return load( 46 | name="ext_lib", 47 | sources=sources, 48 | extra_cflags=extra_cflags, 49 | extra_include_paths=extra_include_paths, 50 | extra_cuda_cflags=extra_cuda_cflags, 51 | ) 52 | 53 | 54 | _backend = _load_C_extensions() 55 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/bn.h: -------------------------------------------------------------------------------- 1 | /***************************************************************************** 2 | 3 | SyncBN 4 | 5 | *****************************************************************************/ 6 | #pragma once 7 | 8 | #ifdef WITH_CUDA 9 | #include "cuda/ext_lib.h" 10 | 
#endif 11 | 12 | /// SyncBN 13 | 14 | std::vector<at::Tensor> syncbn_sum_sqsum(const at::Tensor& x) { 15 | if (x.is_cuda()) { 16 | #ifdef WITH_CUDA 17 | return syncbn_sum_sqsum_cuda(x); 18 | #else 19 | AT_ERROR("Not compiled with GPU support"); 20 | #endif 21 | } else { 22 | AT_ERROR("CPU implementation not supported"); 23 | } 24 | } 25 | 26 | at::Tensor syncbn_forward(const at::Tensor& x, const at::Tensor& weight, 27 | const at::Tensor& bias, const at::Tensor& mean, 28 | const at::Tensor& var, bool affine, float eps) { 29 | if (x.is_cuda()) { 30 | #ifdef WITH_CUDA 31 | return syncbn_forward_cuda(x, weight, bias, mean, var, affine, eps); 32 | #else 33 | AT_ERROR("Not compiled with GPU support"); 34 | #endif 35 | } else { 36 | AT_ERROR("CPU implementation not supported"); 37 | } 38 | } 39 | 40 | std::vector<at::Tensor> syncbn_backward_xhat(const at::Tensor& dz, 41 | const at::Tensor& x, 42 | const at::Tensor& mean, 43 | const at::Tensor& var, float eps) { 44 | if (dz.is_cuda()) { 45 | #ifdef WITH_CUDA 46 | return syncbn_backward_xhat_cuda(dz, x, mean, var, eps); 47 | #else 48 | AT_ERROR("Not compiled with GPU support"); 49 | #endif 50 | } else { 51 | AT_ERROR("CPU implementation not supported"); 52 | } 53 | } 54 | 55 | std::vector<at::Tensor> syncbn_backward( 56 | const at::Tensor& dz, const at::Tensor& x, const at::Tensor& weight, 57 | const at::Tensor& bias, const at::Tensor& mean, const at::Tensor& var, 58 | const at::Tensor& sum_dz, const at::Tensor& sum_dz_xhat, bool affine, 59 | float eps) { 60 | if (dz.is_cuda()) { 61 | #ifdef WITH_CUDA 62 | return syncbn_backward_cuda(dz, x, weight, bias, mean, var, sum_dz, 63 | sum_dz_xhat, affine, eps); 64 | #else 65 | AT_ERROR("Not compiled with GPU support"); 66 | #endif 67 | } else { 68 | AT_ERROR("CPU implementation not supported"); 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/common.h: -------------------------------------------------------------------------------- 1 | /***************************************************************************** 2 | 3 | CUDA utility funcs 4 | 5 | code referenced from : https://github.com/mapillary/inplace_abn 6 | 7 | *****************************************************************************/ 8 | #pragma once 9 | 10 | #include <ATen/ATen.h> 11 | 12 | // Checks 13 | #ifndef AT_CHECK 14 | #define AT_CHECK AT_ASSERT 15 | #endif 16 | #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 17 | #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous") 18 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 19 | 20 | /* 21 | * General settings 22 | */ 23 | const int WARP_SIZE = 32; 24 | const int MAX_BLOCK_SIZE = 512; 25 | 26 | template <typename T> 27 | struct Pair { 28 | T v1, v2; 29 | __device__ Pair() {} 30 | __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {} 31 | __device__ Pair(T v) : v1(v), v2(v) {} 32 | __device__ Pair(int v) : v1(v), v2(v) {} 33 | __device__ Pair &operator+=(const Pair &a) { 34 | v1 += a.v1; 35 | v2 += a.v2; 36 | return *this; 37 | } 38 | }; 39 | 40 | /* 41 | * Utility functions 42 | */ 43 | template <typename T> 44 | __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, 45 | int width = warpSize, 46 | unsigned int mask = 0xffffffff) { 47 | #if CUDART_VERSION >= 9000 48 | return __shfl_xor_sync(mask, value, laneMask, width); 49 | #else 50 | return __shfl_xor(value, laneMask, width); 51 | #endif 52 | } 53 | 54 | __device__ __forceinline__ int getMSB(int val) 
{ return 31 - __clz(val); } 55 | 56 | static int getNumThreads(int nElem) { 57 | int threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE}; 58 | for (int i = 0; i != 5; ++i) { 59 | if (nElem <= threadSizes[i]) { 60 | return threadSizes[i]; 61 | } 62 | } 63 | return MAX_BLOCK_SIZE; 64 | } 65 | 66 | template <typename T> 67 | static __device__ __forceinline__ T warpSum(T val) { 68 | #if __CUDA_ARCH__ >= 300 69 | for (int i = 0; i < getMSB(WARP_SIZE); ++i) { 70 | val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE); 71 | } 72 | #else 73 | __shared__ T values[MAX_BLOCK_SIZE]; 74 | values[threadIdx.x] = val; 75 | __threadfence_block(); 76 | const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE; 77 | for (int i = 1; i < WARP_SIZE; i++) { 78 | val += values[base + ((i + threadIdx.x) % WARP_SIZE)]; 79 | } 80 | #endif 81 | return val; 82 | } 83 | 84 | template <typename T> 85 | static __device__ __forceinline__ Pair<T> warpSum(Pair<T> value) { 86 | value.v1 = warpSum(value.v1); 87 | value.v2 = warpSum(value.v2); 88 | return value; 89 | } 90 | 91 | template <typename T, typename Op> 92 | __device__ T reduce(Op op, int plane, int N, int C, int S) { 93 | T sum = (T)0; 94 | for (int batch = 0; batch < N; ++batch) { 95 | for (int x = threadIdx.x; x < S; x += blockDim.x) { 96 | sum += op(batch, plane, x); 97 | } 98 | } 99 | 100 | // sum over NumThreads within a warp 101 | sum = warpSum(sum); 102 | 103 | // 'transpose', and reduce within warp again 104 | __shared__ T shared[32]; 105 | __syncthreads(); 106 | if (threadIdx.x % WARP_SIZE == 0) { 107 | shared[threadIdx.x / WARP_SIZE] = sum; 108 | } 109 | if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) { 110 | // zero out the other entries in shared 111 | shared[threadIdx.x] = (T)0; 112 | } 113 | __syncthreads(); 114 | if (threadIdx.x / WARP_SIZE == 0) { 115 | sum = warpSum(shared[threadIdx.x]); 116 | if (threadIdx.x == 0) { 117 | shared[0] = sum; 118 | } 119 | } 120 | __syncthreads(); 121 | 122 | // Everyone picks it up, should be broadcast into the whole gradInput 123 | return shared[0]; 124 | } -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/cuda/ext_lib.h: -------------------------------------------------------------------------------- 1 | /***************************************************************************** 2 | 3 | CUDA SyncBN code 4 | 5 | *****************************************************************************/ 6 | #pragma once 7 | #include <ATen/ATen.h> 8 | #include <vector> 9 | 10 | /// Sync-BN 11 | std::vector<at::Tensor> syncbn_sum_sqsum_cuda(const at::Tensor& x); 12 | at::Tensor syncbn_forward_cuda(const at::Tensor& x, const at::Tensor& weight, 13 | const at::Tensor& bias, const at::Tensor& mean, 14 | const at::Tensor& var, bool affine, float eps); 15 | std::vector<at::Tensor> syncbn_backward_xhat_cuda(const at::Tensor& dz, 16 | const at::Tensor& x, 17 | const at::Tensor& mean, 18 | const at::Tensor& var, 19 | float eps); 20 | std::vector<at::Tensor> syncbn_backward_cuda( 21 | const at::Tensor& dz, const at::Tensor& x, const at::Tensor& weight, 22 | const at::Tensor& bias, const at::Tensor& mean, const at::Tensor& var, 23 | const at::Tensor& sum_dz, const at::Tensor& sum_dz_xhat, bool affine, 24 | float eps); 25 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/model/syncbn/modules/functional/csrc/ext_lib.cpp: -------------------------------------------------------------------------------- 1 | #include "bn.h" 2 | 3 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, 
m) { 4 | m.def("syncbn_sum_sqsum", &syncbn_sum_sqsum, "Sum and Sum^2 computation"); 5 | m.def("syncbn_forward", &syncbn_forward, "SyncBN forward computation"); 6 | m.def("syncbn_backward_xhat", &syncbn_backward_xhat, 7 | "First part of SyncBN backward computation"); 8 | m.def("syncbn_backward", &syncbn_backward, 9 | "Second part of SyncBN backward computation"); 10 | } -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/model/syncbn/modules/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .syncbn import * 2 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/GROOT/3d9eceae26c599c325f9457658ab41d2dd2a1319/third_party/XMem/inference/interact/fbrs/utils/__init__.py -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/utils/cython/__init__.py: -------------------------------------------------------------------------------- 1 | # noinspection PyUnresolvedReferences 2 | from .dist_maps import get_dist_maps -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/utils/cython/_get_dist_maps.pyx: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | cimport cython 3 | cimport numpy as np 4 | from libc.stdlib cimport malloc, free 5 | 6 | ctypedef struct qnode: 7 | int row 8 | int col 9 | int layer 10 | int orig_row 11 | int orig_col 12 | 13 | @cython.infer_types(True) 14 | @cython.boundscheck(False) 15 | @cython.wraparound(False) 16 | @cython.nonecheck(False) 17 | def get_dist_maps(np.ndarray[np.float32_t, ndim=2, mode="c"] points, 18 | int height, int width, float norm_delimeter): 19 | cdef np.ndarray[np.float32_t, ndim=3, mode="c"] dist_maps = \ 20 | np.full((2, height, width), 1e6, dtype=np.float32, order="C") 21 | 22 | cdef int *dxy = [-1, 0, 0, -1, 0, 1, 1, 0] 23 | cdef int i, j, x, y, dx, dy 24 | cdef qnode v 25 | cdef qnode *q = <qnode *> malloc((4 * height * width + 1) * sizeof(qnode)) 26 | cdef int qhead = 0, qtail = -1 27 | cdef float ndist 28 | 29 | for i in range(points.shape[0]): 30 | x, y = round(points[i, 0]), round(points[i, 1]) 31 | if x >= 0: 32 | qtail += 1 33 | q[qtail].row = x 34 | q[qtail].col = y 35 | q[qtail].orig_row = x 36 | q[qtail].orig_col = y 37 | if i >= points.shape[0] / 2: 38 | q[qtail].layer = 1 39 | else: 40 | q[qtail].layer = 0 41 | dist_maps[q[qtail].layer, x, y] = 0 42 | 43 | while qtail - qhead + 1 > 0: 44 | v = q[qhead] 45 | qhead += 1 46 | 47 | for k in range(4): 48 | x = v.row + dxy[2 * k] 49 | y = v.col + dxy[2 * k + 1] 50 | 51 | ndist = ((x - v.orig_row)/norm_delimeter) ** 2 + ((y - v.orig_col)/norm_delimeter) ** 2 52 | if (x >= 0 and y >= 0 and x < height and y < width and 53 | dist_maps[v.layer, x, y] > ndist): 54 | qtail += 1 55 | q[qtail].orig_col = v.orig_col 56 | q[qtail].orig_row = v.orig_row 57 | q[qtail].layer = v.layer 58 | q[qtail].row = x 59 | q[qtail].col = y 60 | dist_maps[v.layer, x, y] = ndist 61 | 62 | free(q) 63 | return dist_maps 64 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/utils/cython/_get_dist_maps.pyxbld: 
-------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | def make_ext(modname, pyxfilename): 4 | from distutils.extension import Extension 5 | return Extension(modname, [pyxfilename], 6 | include_dirs=[numpy.get_include()], 7 | extra_compile_args=['-O3'], language='c++') 8 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/utils/cython/dist_maps.py: -------------------------------------------------------------------------------- 1 | import pyximport; pyximport.install(pyximport=True, language_level=3) 2 | # noinspection PyUnresolvedReferences 3 | from ._get_dist_maps import get_dist_maps -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs/utils/misc.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import numpy as np 5 | 6 | 7 | def get_dims_with_exclusion(dim, exclude=None): 8 | dims = list(range(dim)) 9 | if exclude is not None: 10 | dims.remove(exclude) 11 | 12 | return dims 13 | 14 | 15 | def get_unique_labels(mask): 16 | return np.nonzero(np.bincount(mask.flatten() + 1))[0] - 1 17 | 18 | 19 | def get_bbox_from_mask(mask): 20 | rows = np.any(mask, axis=1) 21 | cols = np.any(mask, axis=0) 22 | rmin, rmax = np.where(rows)[0][[0, -1]] 23 | cmin, cmax = np.where(cols)[0][[0, -1]] 24 | 25 | return rmin, rmax, cmin, cmax 26 | 27 | 28 | def expand_bbox(bbox, expand_ratio, min_crop_size=None): 29 | rmin, rmax, cmin, cmax = bbox 30 | rcenter = 0.5 * (rmin + rmax) 31 | ccenter = 0.5 * (cmin + cmax) 32 | height = expand_ratio * (rmax - rmin + 1) 33 | width = expand_ratio * (cmax - cmin + 1) 34 | if min_crop_size is not None: 35 | height = max(height, min_crop_size) 36 | width = max(width, min_crop_size) 37 | 38 | rmin = int(round(rcenter - 0.5 * height)) 39 | rmax = int(round(rcenter + 0.5 * height)) 40 | cmin = int(round(ccenter - 0.5 * width)) 41 | cmax = int(round(ccenter + 0.5 * width)) 42 | 43 | return rmin, rmax, cmin, cmax 44 | 45 | 46 | def clamp_bbox(bbox, rmin, rmax, cmin, cmax): 47 | return (max(rmin, bbox[0]), min(rmax, bbox[1]), 48 | max(cmin, bbox[2]), min(cmax, bbox[3])) 49 | 50 | 51 | def get_bbox_iou(b1, b2): 52 | h_iou = get_segments_iou(b1[:2], b2[:2]) 53 | w_iou = get_segments_iou(b1[2:4], b2[2:4]) 54 | return h_iou * w_iou 55 | 56 | 57 | def get_segments_iou(s1, s2): 58 | a, b = s1 59 | c, d = s2 60 | intersection = max(0, min(b, d) - max(a, c) + 1) 61 | union = max(1e-6, max(b, d) - min(a, c) + 1) 62 | return intersection / union 63 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/fbrs_controller.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .fbrs.controller import InteractiveController 3 | from .fbrs.inference import utils 4 | 5 | 6 | class FBRSController: 7 | def __init__(self, checkpoint_path, device='cuda:0', max_size=800): 8 | model = utils.load_is_model(checkpoint_path, device, cpu_dist_maps=True, norm_radius=260) 9 | 10 | # Predictor params 11 | zoomin_params = { 12 | 'skip_clicks': 1, 13 | 'target_size': 480, 14 | 'expansion_ratio': 1.4, 15 | } 16 | 17 | predictor_params = { 18 | 'brs_mode': 'f-BRS-B', 19 | 'prob_thresh': 0.5, 20 | 'zoom_in_params': zoomin_params, 21 | 'predictor_params': { 22 | 'net_clicks_limit': 8, 23 | 'max_size': 800, 24 | }, 25 
| 'brs_opt_func_params': {'min_iou_diff': 1e-3}, 26 | 'lbfgs_params': {'maxfun': 20} 27 | } 28 | 29 | self.controller = InteractiveController(model, device, predictor_params) 30 | self.anchored = False 31 | self.device = device 32 | 33 | def unanchor(self): 34 | self.anchored = False 35 | 36 | def interact(self, image, x, y, is_positive): 37 | image = image.to(self.device, non_blocking=True) 38 | if not self.anchored: 39 | self.controller.set_image(image) 40 | self.controller.reset_predictor() 41 | self.anchored = True 42 | 43 | self.controller.add_click(x, y, is_positive) 44 | # return self.controller.result_mask 45 | # return self.controller.probs_history[-1][1] 46 | return (self.controller.probs_history[-1][1]>0.5).float() 47 | 48 | def undo(self): 49 | self.controller.undo_click() 50 | if len(self.controller.probs_history) == 0: 51 | return None 52 | else: 53 | return (self.controller.probs_history[-1][1]>0.5).float() -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/gui_utils.py: -------------------------------------------------------------------------------- 1 | from PyQt5.QtCore import Qt 2 | from PyQt5.QtWidgets import (QHBoxLayout, QLabel, QSpinBox, QVBoxLayout, QProgressBar) 3 | 4 | 5 | def create_parameter_box(min_val, max_val, text, step=1, callback=None): 6 | layout = QHBoxLayout() 7 | 8 | dial = QSpinBox() 9 | dial.setMaximumHeight(28) 10 | dial.setMaximumWidth(150) 11 | dial.setMinimum(min_val) 12 | dial.setMaximum(max_val) 13 | dial.setAlignment(Qt.AlignRight) 14 | dial.setSingleStep(step) 15 | dial.valueChanged.connect(callback) 16 | 17 | label = QLabel(text) 18 | label.setAlignment(Qt.AlignRight) 19 | 20 | layout.addWidget(label) 21 | layout.addWidget(dial) 22 | 23 | return dial, layout 24 | 25 | 26 | def create_gauge(text): 27 | layout = QHBoxLayout() 28 | 29 | gauge = QProgressBar() 30 | gauge.setMaximumHeight(28) 31 | gauge.setMaximumWidth(200) 32 | gauge.setAlignment(Qt.AlignCenter) 33 | 34 | label = QLabel(text) 35 | label.setAlignment(Qt.AlignRight) 36 | 37 | layout.addWidget(label) 38 | layout.addWidget(gauge) 39 | 40 | return gauge, layout 41 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/s2m/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/GROOT/3d9eceae26c599c325f9457658ab41d2dd2a1319/third_party/XMem/inference/interact/s2m/__init__.py -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/s2m/s2m_network.py: -------------------------------------------------------------------------------- 1 | # Credit: https://github.com/VainF/DeepLabV3Plus-Pytorch 2 | 3 | from .utils import IntermediateLayerGetter 4 | from ._deeplab import DeepLabHead, DeepLabHeadV3Plus, DeepLabV3 5 | from . 
import s2m_resnet 6 | 7 | def _segm_resnet(name, backbone_name, num_classes, output_stride, pretrained_backbone): 8 | 9 | if output_stride==8: 10 | replace_stride_with_dilation=[False, True, True] 11 | aspp_dilate = [12, 24, 36] 12 | else: 13 | replace_stride_with_dilation=[False, False, True] 14 | aspp_dilate = [6, 12, 18] 15 | 16 | backbone = s2m_resnet.__dict__[backbone_name]( 17 | pretrained=pretrained_backbone, 18 | replace_stride_with_dilation=replace_stride_with_dilation) 19 | 20 | inplanes = 2048 21 | low_level_planes = 256 22 | 23 | if name=='deeplabv3plus': 24 | return_layers = {'layer4': 'out', 'layer1': 'low_level'} 25 | classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate) 26 | elif name=='deeplabv3': 27 | return_layers = {'layer4': 'out'} 28 | classifier = DeepLabHead(inplanes , num_classes, aspp_dilate) 29 | backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) 30 | 31 | model = DeepLabV3(backbone, classifier) 32 | return model 33 | 34 | def _load_model(arch_type, backbone, num_classes, output_stride, pretrained_backbone): 35 | 36 | if backbone.startswith('resnet'): 37 | model = _segm_resnet(arch_type, backbone, num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone) 38 | else: 39 | raise NotImplementedError 40 | return model 41 | 42 | 43 | # Deeplab v3 44 | def deeplabv3_resnet50(num_classes=1, output_stride=16, pretrained_backbone=False): 45 | """Constructs a DeepLabV3 model with a ResNet-50 backbone. 46 | 47 | Args: 48 | num_classes (int): number of classes. 49 | output_stride (int): output stride for deeplab. 50 | pretrained_backbone (bool): If True, use the pretrained backbone. 51 | """ 52 | return _load_model('deeplabv3', 'resnet50', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone) 53 | 54 | 55 | # Deeplab v3+ 56 | def deeplabv3plus_resnet50(num_classes=1, output_stride=16, pretrained_backbone=False): 57 | """Constructs a DeepLabV3 model with a ResNet-50 backbone. 58 | 59 | Args: 60 | num_classes (int): number of classes. 61 | output_stride (int): output stride for deeplab. 62 | pretrained_backbone (bool): If True, use the pretrained backbone. 63 | """ 64 | return _load_model('deeplabv3plus', 'resnet50', num_classes, output_stride=output_stride, pretrained_backbone=pretrained_backbone) 65 | 66 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/s2m/utils.py: -------------------------------------------------------------------------------- 1 | # Credit: https://github.com/VainF/DeepLabV3Plus-Pytorch 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | import torch.nn.functional as F 7 | from collections import OrderedDict 8 | 9 | class _SimpleSegmentationModel(nn.Module): 10 | def __init__(self, backbone, classifier): 11 | super(_SimpleSegmentationModel, self).__init__() 12 | self.backbone = backbone 13 | self.classifier = classifier 14 | 15 | def forward(self, x): 16 | input_shape = x.shape[-2:] 17 | features = self.backbone(x) 18 | x = self.classifier(features) 19 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False) 20 | return x 21 | 22 | 23 | class IntermediateLayerGetter(nn.ModuleDict): 24 | """ 25 | Module wrapper that returns intermediate layers from a model 26 | 27 | It has a strong assumption that the modules have been registered 28 | into the model in the same order as they are used. 
29 | This means that one should **not** reuse the same nn.Module 30 | twice in the forward if you want this to work. 31 | 32 | Additionally, it is only able to query submodules that are directly 33 | assigned to the model. So if `model` is passed, `model.feature1` can 34 | be returned, but not `model.feature1.layer2`. 35 | 36 | Arguments: 37 | model (nn.Module): model on which we will extract the features 38 | return_layers (Dict[name, new_name]): a dict containing the names 39 | of the modules for which the activations will be returned as 40 | the key of the dict, and the value of the dict is the name 41 | of the returned activation (which the user can specify). 42 | 43 | Examples:: 44 | 45 | >>> m = torchvision.models.resnet18(pretrained=True) 46 | >>> # extract layer1 and layer3, giving as names `feat1` and feat2` 47 | >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m, 48 | >>> {'layer1': 'feat1', 'layer3': 'feat2'}) 49 | >>> out = new_m(torch.rand(1, 3, 224, 224)) 50 | >>> print([(k, v.shape) for k, v in out.items()]) 51 | >>> [('feat1', torch.Size([1, 64, 56, 56])), 52 | >>> ('feat2', torch.Size([1, 256, 14, 14]))] 53 | """ 54 | def __init__(self, model, return_layers): 55 | if not set(return_layers).issubset([name for name, _ in model.named_children()]): 56 | raise ValueError("return_layers are not present in model") 57 | 58 | orig_return_layers = return_layers 59 | return_layers = {k: v for k, v in return_layers.items()} 60 | layers = OrderedDict() 61 | for name, module in model.named_children(): 62 | layers[name] = module 63 | if name in return_layers: 64 | del return_layers[name] 65 | if not return_layers: 66 | break 67 | 68 | super(IntermediateLayerGetter, self).__init__(layers) 69 | self.return_layers = orig_return_layers 70 | 71 | def forward(self, x): 72 | out = OrderedDict() 73 | for name, module in self.named_children(): 74 | x = module(x) 75 | if name in self.return_layers: 76 | out_name = self.return_layers[name] 77 | out[out_name] = x 78 | return out 79 | -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/s2m_controller.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from ..interact.s2m.s2m_network import deeplabv3plus_resnet50 as S2M 4 | 5 | from util.tensor_util import pad_divide_by, unpad 6 | 7 | 8 | class S2MController: 9 | """ 10 | A controller for Scribble-to-Mask (for user interaction, not for DAVIS) 11 | Takes the image, previous mask, and scribbles to produce a new mask 12 | ignore_class is usually 255 13 | 0 is NOT the ignore class -- it is the label for the background 14 | """ 15 | def __init__(self, s2m_net:S2M, num_objects, ignore_class, device='cuda:0'): 16 | self.s2m_net = s2m_net 17 | self.num_objects = num_objects 18 | self.ignore_class = ignore_class 19 | self.device = device 20 | 21 | def interact(self, image, prev_mask, scr_mask): 22 | image = image.to(self.device, non_blocking=True) 23 | prev_mask = prev_mask.unsqueeze(0) 24 | 25 | h, w = image.shape[-2:] 26 | unaggre_mask = torch.zeros((self.num_objects, h, w), dtype=torch.float32, device=image.device) 27 | 28 | for ki in range(1, self.num_objects+1): 29 | p_srb = (scr_mask==ki).astype(np.uint8) 30 | n_srb = ((scr_mask!=ki) * (scr_mask!=self.ignore_class)).astype(np.uint8) 31 | 32 | Rs = torch.from_numpy(np.stack([p_srb, n_srb], 0)).unsqueeze(0).float().to(image.device) 33 | 34 | inputs = torch.cat([image, 
(prev_mask==ki).float().unsqueeze(0), Rs], 1) 35 | inputs, pads = pad_divide_by(inputs, 16) 36 | 37 | unaggre_mask[ki-1] = unpad(torch.sigmoid(self.s2m_net(inputs)), pads) 38 | 39 | return unaggre_mask -------------------------------------------------------------------------------- /third_party/XMem/inference/interact/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | class Timer: 4 | def __init__(self): 5 | self._acc_time = 0 6 | self._paused = True 7 | 8 | def start(self): 9 | if self._paused: 10 | self.last_time = time.time() 11 | self._paused = False 12 | return self 13 | 14 | def pause(self): 15 | self.count() 16 | self._paused = True 17 | return self 18 | 19 | def count(self): 20 | if self._paused: 21 | return self._acc_time 22 | t = time.time() 23 | self._acc_time += t - self.last_time 24 | self.last_time = t 25 | return self._acc_time 26 | 27 | def format(self): 28 | # count = int(self.count()*100) 29 | # return '%02d:%02d:%02d' % (count//6000, (count//100)%60, count%100) 30 | return '%03.2f' % self.count() 31 | 32 | def __str__(self): 33 | return self.format() -------------------------------------------------------------------------------- /third_party/XMem/interactive_demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple user interface for XMem 3 | """ 4 | 5 | import os 6 | # fix for Windows 7 | if 'QT_QPA_PLATFORM_PLUGIN_PATH' not in os.environ: 8 | os.environ['QT_QPA_PLATFORM_PLUGIN_PATH'] = '' 9 | 10 | import sys 11 | from argparse import ArgumentParser 12 | 13 | import torch 14 | 15 | from model.network import XMem 16 | from inference.interact.s2m_controller import S2MController 17 | from inference.interact.fbrs_controller import FBRSController 18 | from inference.interact.s2m.s2m_network import deeplabv3plus_resnet50 as S2M 19 | 20 | from PyQt5.QtWidgets import QApplication 21 | from inference.interact.gui import App 22 | from inference.interact.resource_manager import ResourceManager 23 | 24 | torch.set_grad_enabled(False) 25 | 26 | 27 | if __name__ == '__main__': 28 | 29 | # Arguments parsing 30 | parser = ArgumentParser() 31 | parser.add_argument('--model', default='./saves/XMem.pth') 32 | parser.add_argument('--s2m_model', default='saves/s2m.pth') 33 | parser.add_argument('--fbrs_model', default='saves/fbrs.pth') 34 | 35 | """ 36 | Priority 1: If a "images" folder exists in the workspace, we will read from that directory 37 | Priority 2: If --images is specified, we will copy/resize those images to the workspace 38 | Priority 3: If --video is specified, we will extract the frames to the workspace (in an "images" folder) and read from there 39 | 40 | In any case, if a "masks" folder exists in the workspace, we will use that to initialize the mask 41 | That way, you can continue annotation from an interrupted run as long as the same workspace is used. 42 | """ 43 | parser.add_argument('--images', help='Folders containing input images.', default=None) 44 | parser.add_argument('--video', help='Video file readable by OpenCV.', default=None) 45 | parser.add_argument('--workspace', help='directory for storing buffered images (if needed) and output masks', default=None) 46 | 47 | parser.add_argument('--buffer_size', help='Correlate with CPU memory consumption', type=int, default=100) 48 | 49 | parser.add_argument('--num_objects', type=int, default=1) 50 | 51 | # Long-memory options 52 | # Defaults. Some can be changed in the GUI. 
53 | parser.add_argument('--max_mid_term_frames', help='T_max in paper, decrease to save memory', type=int, default=10) 54 | parser.add_argument('--min_mid_term_frames', help='T_min in paper, decrease to save memory', type=int, default=5) 55 | parser.add_argument('--max_long_term_elements', help='LT_max in paper, increase if objects disappear for a long time', 56 | type=int, default=10000) 57 | parser.add_argument('--num_prototypes', help='P in paper', type=int, default=128) 58 | 59 | parser.add_argument('--top_k', type=int, default=30) 60 | parser.add_argument('--mem_every', type=int, default=10) 61 | parser.add_argument('--deep_update_every', help='Leave -1 normally to synchronize with mem_every', type=int, default=-1) 62 | parser.add_argument('--no_amp', help='Turn off AMP', action='store_true') 63 | parser.add_argument('--size', default=480, type=int, 64 | help='Resize the shorter side to this size. -1 to use original resolution. ') 65 | args = parser.parse_args() 66 | 67 | config = vars(args) 68 | config['enable_long_term'] = True 69 | config['enable_long_term_count_usage'] = True 70 | 71 | with torch.cuda.amp.autocast(enabled=not args.no_amp): 72 | 73 | # Load our checkpoint 74 | network = XMem(config, args.model).cuda().eval() 75 | 76 | # Loads the S2M model 77 | if args.s2m_model is not None: 78 | s2m_saved = torch.load(args.s2m_model) 79 | s2m_model = S2M().cuda().eval() 80 | s2m_model.load_state_dict(s2m_saved) 81 | else: 82 | s2m_model = None 83 | 84 | s2m_controller = S2MController(s2m_model, args.num_objects, ignore_class=255) 85 | if args.fbrs_model is not None: 86 | fbrs_controller = FBRSController(args.fbrs_model) 87 | else: 88 | fbrs_controller = None 89 | 90 | # Manages most IO 91 | resource_manager = ResourceManager(config) 92 | 93 | app = QApplication(sys.argv) 94 | ex = App(network, resource_manager, s2m_controller, fbrs_controller, config) 95 | sys.exit(app.exec_()) 96 | -------------------------------------------------------------------------------- /third_party/XMem/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/GROOT/3d9eceae26c599c325f9457658ab41d2dd2a1319/third_party/XMem/model/__init__.py -------------------------------------------------------------------------------- /third_party/XMem/model/aggregate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | # Soft aggregation from STM 6 | def aggregate(prob, dim, return_logits=False): 7 | new_prob = torch.cat([ 8 | torch.prod(1-prob, dim=dim, keepdim=True), 9 | prob 10 | ], dim).clamp(1e-7, 1-1e-7) 11 | logits = torch.log((new_prob /(1-new_prob))) 12 | prob = F.softmax(logits, dim=dim) 13 | 14 | if return_logits: 15 | return logits, prob 16 | else: 17 | return prob -------------------------------------------------------------------------------- /third_party/XMem/model/cbam.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | class BasicConv(nn.Module): 8 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): 9 | super(BasicConv, self).__init__() 10 | self.out_channels = out_planes 11 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, 
stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) 12 | 13 | def forward(self, x): 14 | x = self.conv(x) 15 | return x 16 | 17 | class Flatten(nn.Module): 18 | def forward(self, x): 19 | return x.view(x.size(0), -1) 20 | 21 | class ChannelGate(nn.Module): 22 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): 23 | super(ChannelGate, self).__init__() 24 | self.gate_channels = gate_channels 25 | self.mlp = nn.Sequential( 26 | Flatten(), 27 | nn.Linear(gate_channels, gate_channels // reduction_ratio), 28 | nn.ReLU(), 29 | nn.Linear(gate_channels // reduction_ratio, gate_channels) 30 | ) 31 | self.pool_types = pool_types 32 | def forward(self, x): 33 | channel_att_sum = None 34 | for pool_type in self.pool_types: 35 | if pool_type=='avg': 36 | avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 37 | channel_att_raw = self.mlp( avg_pool ) 38 | elif pool_type=='max': 39 | max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 40 | channel_att_raw = self.mlp( max_pool ) 41 | 42 | if channel_att_sum is None: 43 | channel_att_sum = channel_att_raw 44 | else: 45 | channel_att_sum = channel_att_sum + channel_att_raw 46 | 47 | scale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x) 48 | return x * scale 49 | 50 | class ChannelPool(nn.Module): 51 | def forward(self, x): 52 | return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 ) 53 | 54 | class SpatialGate(nn.Module): 55 | def __init__(self): 56 | super(SpatialGate, self).__init__() 57 | kernel_size = 7 58 | self.compress = ChannelPool() 59 | self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2) 60 | def forward(self, x): 61 | x_compress = self.compress(x) 62 | x_out = self.spatial(x_compress) 63 | scale = torch.sigmoid(x_out) # broadcasting 64 | return x * scale 65 | 66 | class CBAM(nn.Module): 67 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): 68 | super(CBAM, self).__init__() 69 | self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) 70 | self.no_spatial=no_spatial 71 | if not no_spatial: 72 | self.SpatialGate = SpatialGate() 73 | def forward(self, x): 74 | x_out = self.ChannelGate(x) 75 | if not self.no_spatial: 76 | x_out = self.SpatialGate(x_out) 77 | return x_out 78 | -------------------------------------------------------------------------------- /third_party/XMem/model/group_modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | Group-specific modules 3 | They handle features that also depends on the mask. 4 | Features are typically of shape 5 | batch_size * num_objects * num_channels * H * W 6 | 7 | All of them are permutation equivariant w.r.t. 
to the num_objects dimension 8 | """ 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | def interpolate_groups(g, ratio, mode, align_corners): 16 | batch_size, num_objects = g.shape[:2] 17 | g = F.interpolate(g.flatten(start_dim=0, end_dim=1), 18 | scale_factor=ratio, mode=mode, align_corners=align_corners) 19 | g = g.view(batch_size, num_objects, *g.shape[1:]) 20 | return g 21 | 22 | def upsample_groups(g, ratio=2, mode='bilinear', align_corners=False): 23 | return interpolate_groups(g, ratio, mode, align_corners) 24 | 25 | def downsample_groups(g, ratio=1/2, mode='area', align_corners=None): 26 | return interpolate_groups(g, ratio, mode, align_corners) 27 | 28 | 29 | class GConv2D(nn.Conv2d): 30 | def forward(self, g): 31 | batch_size, num_objects = g.shape[:2] 32 | g = super().forward(g.flatten(start_dim=0, end_dim=1)) 33 | return g.view(batch_size, num_objects, *g.shape[1:]) 34 | 35 | 36 | class GroupResBlock(nn.Module): 37 | def __init__(self, in_dim, out_dim): 38 | super().__init__() 39 | 40 | if in_dim == out_dim: 41 | self.downsample = None 42 | else: 43 | self.downsample = GConv2D(in_dim, out_dim, kernel_size=3, padding=1) 44 | 45 | self.conv1 = GConv2D(in_dim, out_dim, kernel_size=3, padding=1) 46 | self.conv2 = GConv2D(out_dim, out_dim, kernel_size=3, padding=1) 47 | 48 | def forward(self, g): 49 | out_g = self.conv1(F.relu(g)) 50 | out_g = self.conv2(F.relu(out_g)) 51 | 52 | if self.downsample is not None: 53 | g = self.downsample(g) 54 | 55 | return out_g + g 56 | 57 | 58 | class MainToGroupDistributor(nn.Module): 59 | def __init__(self, x_transform=None, method='cat', reverse_order=False): 60 | super().__init__() 61 | 62 | self.x_transform = x_transform 63 | self.method = method 64 | self.reverse_order = reverse_order 65 | 66 | def forward(self, x, g): 67 | num_objects = g.shape[1] 68 | 69 | if self.x_transform is not None: 70 | x = self.x_transform(x) 71 | 72 | if self.method == 'cat': 73 | if self.reverse_order: 74 | g = torch.cat([g, x.unsqueeze(1).expand(-1,num_objects,-1,-1,-1)], 2) 75 | else: 76 | g = torch.cat([x.unsqueeze(1).expand(-1,num_objects,-1,-1,-1), g], 2) 77 | elif self.method == 'add': 78 | g = x.unsqueeze(1).expand(-1,num_objects,-1,-1,-1) + g 79 | else: 80 | raise NotImplementedError 81 | 82 | return g 83 | -------------------------------------------------------------------------------- /third_party/XMem/model/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from collections import defaultdict 6 | 7 | 8 | def dice_loss(input_mask, cls_gt): 9 | num_objects = input_mask.shape[1] 10 | losses = [] 11 | for i in range(num_objects): 12 | mask = input_mask[:,i].flatten(start_dim=1) 13 | # background not in mask, so we add one to cls_gt 14 | gt = (cls_gt==(i+1)).float().flatten(start_dim=1) 15 | numerator = 2 * (mask * gt).sum(-1) 16 | denominator = mask.sum(-1) + gt.sum(-1) 17 | loss = 1 - (numerator + 1) / (denominator + 1) 18 | losses.append(loss) 19 | return torch.cat(losses).mean() 20 | 21 | 22 | # https://stackoverflow.com/questions/63735255/how-do-i-compute-bootstrapped-cross-entropy-loss-in-pytorch 23 | class BootstrappedCE(nn.Module): 24 | def __init__(self, start_warm, end_warm, top_p=0.15): 25 | super().__init__() 26 | 27 | self.start_warm = start_warm 28 | self.end_warm = end_warm 29 | self.top_p = top_p 30 | 31 | def forward(self, input, target, it): 32 | if it < 
self.start_warm: 33 | return F.cross_entropy(input, target), 1.0 34 | 35 | raw_loss = F.cross_entropy(input, target, reduction='none').view(-1) 36 | num_pixels = raw_loss.numel() 37 | 38 | if it > self.end_warm: 39 | this_p = self.top_p 40 | else: 41 | this_p = self.top_p + (1-self.top_p)*((self.end_warm-it)/(self.end_warm-self.start_warm)) 42 | loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False) 43 | return loss.mean(), this_p 44 | 45 | 46 | class LossComputer: 47 | def __init__(self, config): 48 | super().__init__() 49 | self.config = config 50 | self.bce = BootstrappedCE(config['start_warm'], config['end_warm']) 51 | 52 | def compute(self, data, num_objects, it): 53 | losses = defaultdict(int) 54 | 55 | b, t = data['rgb'].shape[:2] 56 | 57 | losses['total_loss'] = 0 58 | for ti in range(1, t): 59 | for bi in range(b): 60 | loss, p = self.bce(data[f'logits_{ti}'][bi:bi+1, :num_objects[bi]+1], data['cls_gt'][bi:bi+1,ti,0], it) 61 | losses['p'] += p / b / (t-1) 62 | losses[f'ce_loss_{ti}'] += loss / b 63 | 64 | losses['total_loss'] += losses['ce_loss_%d'%ti] 65 | losses[f'dice_loss_{ti}'] = dice_loss(data[f'masks_{ti}'], data['cls_gt'][:,ti,0]) 66 | losses['total_loss'] += losses[f'dice_loss_{ti}'] 67 | 68 | return losses 69 | -------------------------------------------------------------------------------- /third_party/XMem/model/memory_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from typing import Optional 5 | 6 | 7 | def get_similarity(mk, ms, qk, qe): 8 | # used for training/inference and memory reading/memory potentiation 9 | # mk: B x CK x [N] - Memory keys 10 | # ms: B x 1 x [N] - Memory shrinkage 11 | # qk: B x CK x [HW/P] - Query keys 12 | # qe: B x CK x [HW/P] - Query selection 13 | # Dimensions in [] are flattened 14 | CK = mk.shape[1] 15 | mk = mk.flatten(start_dim=2) 16 | ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None 17 | qk = qk.flatten(start_dim=2) 18 | qe = qe.flatten(start_dim=2) if qe is not None else None 19 | 20 | if qe is not None: 21 | # See appendix for derivation 22 | # or you can just trust me ヽ(ー_ー )ノ 23 | mk = mk.transpose(1, 2) 24 | a_sq = (mk.pow(2) @ qe) 25 | two_ab = 2 * (mk @ (qk * qe)) 26 | b_sq = (qe * qk.pow(2)).sum(1, keepdim=True) 27 | similarity = (-a_sq+two_ab-b_sq) 28 | else: 29 | # similar to STCN if we don't have the selection term 30 | a_sq = mk.pow(2).sum(1).unsqueeze(2) 31 | two_ab = 2 * (mk.transpose(1, 2) @ qk) 32 | similarity = (-a_sq+two_ab) 33 | 34 | if ms is not None: 35 | similarity = similarity * ms / math.sqrt(CK) # B*N*HW 36 | else: 37 | similarity = similarity / math.sqrt(CK) # B*N*HW 38 | 39 | return similarity 40 | 41 | def do_softmax(similarity, top_k: Optional[int]=None, inplace=False, return_usage=False): 42 | # normalize similarity with top-k softmax 43 | # similarity: B x N x [HW/P] 44 | # use inplace with care 45 | if top_k is not None: 46 | values, indices = torch.topk(similarity, k=top_k, dim=1) 47 | 48 | x_exp = values.exp_() 49 | x_exp /= torch.sum(x_exp, dim=1, keepdim=True) 50 | if inplace: 51 | similarity.zero_().scatter_(1, indices, x_exp) # B*N*HW 52 | affinity = similarity 53 | else: 54 | affinity = torch.zeros_like(similarity).scatter_(1, indices, x_exp) # B*N*HW 55 | else: 56 | maxes = torch.max(similarity, dim=1, keepdim=True)[0] 57 | x_exp = torch.exp(similarity - maxes) 58 | x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True) 59 | affinity = x_exp / x_exp_sum 60 | 
indices = None 61 | 62 | if return_usage: 63 | return affinity, affinity.sum(dim=2) 64 | 65 | return affinity 66 | 67 | def get_affinity(mk, ms, qk, qe): 68 | # shorthand used in training with no top-k 69 | similarity = get_similarity(mk, ms, qk, qe) 70 | affinity = do_softmax(similarity) 71 | return affinity 72 | 73 | def readout(affinity, mv): 74 | B, CV, T, H, W = mv.shape 75 | 76 | mo = mv.view(B, CV, T*H*W) 77 | mem = torch.bmm(mo, affinity) 78 | mem = mem.view(B, CV, H, W) 79 | 80 | return mem 81 | -------------------------------------------------------------------------------- /third_party/XMem/requirements.txt: -------------------------------------------------------------------------------- 1 | progressbar2 2 | gdown 3 | gitpython 4 | git+https://github.com/cheind/py-thin-plate-spline 5 | hickle 6 | tensorboard 7 | numpy -------------------------------------------------------------------------------- /third_party/XMem/requirements_demo.txt: -------------------------------------------------------------------------------- 1 | PyQt5 2 | Cython 3 | scipy -------------------------------------------------------------------------------- /third_party/XMem/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/GROOT/3d9eceae26c599c325f9457658ab41d2dd2a1319/third_party/XMem/scripts/__init__.py -------------------------------------------------------------------------------- /third_party/XMem/scripts/download_bl30k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gdown 3 | import tarfile 4 | 5 | 6 | LICENSE = """ 7 | This dataset is a derivative of ShapeNet. 8 | Please read and respect their licenses and terms before use. 9 | Textures and skybox image are obtained from Google image search with the "non-commercial reuse" flag. 10 | Do not use this dataset for commercial purposes. 11 | You should cite both ShapeNet and our paper if you use this dataset. 12 | """ 13 | 14 | print(LICENSE) 15 | print('Datasets will be downloaded and extracted to ../BL30K') 16 | print('The script will download and extract the segment one by one') 17 | print('You are going to need ~1TB of free disk space') 18 | reply = input('[y] to confirm, others to exit: ') 19 | if reply != 'y': 20 | exit() 21 | 22 | links = [ 23 | 'https://drive.google.com/uc?id=1z9V5zxLOJLNt1Uj7RFqaP2FZWKzyXvVc', 24 | 'https://drive.google.com/uc?id=11-IzgNwEAPxgagb67FSrBdzZR7OKAEdJ', 25 | 'https://drive.google.com/uc?id=1ZfIv6GTo-OGpXpoKen1fUvDQ0A_WoQ-Q', 26 | 'https://drive.google.com/uc?id=1G4eXgYS2kL7_Cc0x3N1g1x7Zl8D_aU_-', 27 | 'https://drive.google.com/uc?id=1Y8q0V_oBwJIY27W_6-8CD1dRqV2gNTdE', 28 | 'https://drive.google.com/uc?id=1nawBAazf_unMv46qGBHhWcQ4JXZ5883r', 29 | ] 30 | 31 | names = [ 32 | 'BL30K_a.tar', 33 | 'BL30K_b.tar', 34 | 'BL30K_c.tar', 35 | 'BL30K_d.tar', 36 | 'BL30K_e.tar', 37 | 'BL30K_f.tar', 38 | ] 39 | 40 | for i, link in enumerate(links): 41 | print('Downloading segment %d/%d ...' 
% (i, len(links))) 42 | gdown.download(link, output='../%s' % names[i], quiet=False) 43 | print('Extracting...') 44 | with tarfile.open('../%s' % names[i], 'r') as tar_file: 45 | tar_file.extractall('../%s' % names[i]) 46 | print('Cleaning up...') 47 | os.remove('../%s' % names[i]) 48 | 49 | 50 | print('Done.') -------------------------------------------------------------------------------- /third_party/XMem/scripts/download_models.sh: -------------------------------------------------------------------------------- 1 | wget -P ./saves/ https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem.pth 2 | wget -P ./saves/ https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth -------------------------------------------------------------------------------- /third_party/XMem/scripts/download_models_demo.sh: -------------------------------------------------------------------------------- 1 | wget -P ./saves/ https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem.pth 2 | wget -P ./saves/ https://github.com/hkchengrex/XMem/releases/download/v1.0/fbrs.pth 3 | wget -P ./saves/ https://github.com/hkchengrex/XMem/releases/download/v1.0/s2m.pth -------------------------------------------------------------------------------- /third_party/XMem/scripts/expand_long_vid.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from os import path 4 | from shutil import copy2 5 | 6 | input_path = sys.argv[1] 7 | output_path = sys.argv[2] 8 | multiplier = int(sys.argv[3]) 9 | image_path = path.join(input_path, 'JPEGImages') 10 | gt_path = path.join(input_path, 'Annotations') 11 | 12 | videos = sorted(os.listdir(image_path)) 13 | 14 | for vid in videos: 15 | os.makedirs(path.join(output_path, 'JPEGImages', vid), exist_ok=True) 16 | os.makedirs(path.join(output_path, 'Annotations', vid), exist_ok=True) 17 | frames = sorted(os.listdir(path.join(image_path, vid))) 18 | 19 | num_frames = len(frames) 20 | counter = 0 21 | output_counter = 0 22 | direction = 1 23 | for _ in range(multiplier): 24 | for _ in range(num_frames): 25 | copy2(path.join(image_path, vid, frames[counter]), 26 | path.join(output_path, 'JPEGImages', vid, f'{output_counter:05d}.jpg')) 27 | 28 | mask_path = path.join(gt_path, vid, frames[counter].replace('.jpg', '.png')) 29 | if path.exists(mask_path): 30 | copy2(mask_path, 31 | path.join(output_path, 'Annotations', vid, f'{output_counter:05d}.png')) 32 | 33 | counter += direction 34 | output_counter += 1 35 | if counter == 0 or counter == len(frames) - 1: 36 | direction *= -1 37 | -------------------------------------------------------------------------------- /third_party/XMem/scripts/resize_youtube.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from os import path 4 | 5 | from PIL import Image 6 | import numpy as np 7 | from progressbar import progressbar 8 | from multiprocessing import Pool 9 | 10 | new_min_size = 480 11 | 12 | def resize_vid_jpeg(inputs): 13 | vid_name, folder_path, out_path = inputs 14 | 15 | vid_path = path.join(folder_path, vid_name) 16 | vid_out_path = path.join(out_path, 'JPEGImages', vid_name) 17 | os.makedirs(vid_out_path, exist_ok=True) 18 | 19 | for im_name in os.listdir(vid_path): 20 | hr_im = Image.open(path.join(vid_path, im_name)) 21 | w, h = hr_im.size 22 | 23 | ratio = new_min_size / min(w, h) 24 | 25 | lr_im = hr_im.resize((int(w*ratio), int(h*ratio)), Image.BICUBIC) 26 | 
lr_im.save(path.join(vid_out_path, im_name)) 27 | 28 | def resize_vid_anno(inputs): 29 | vid_name, folder_path, out_path = inputs 30 | 31 | vid_path = path.join(folder_path, vid_name) 32 | vid_out_path = path.join(out_path, 'Annotations', vid_name) 33 | os.makedirs(vid_out_path, exist_ok=True) 34 | 35 | for im_name in os.listdir(vid_path): 36 | hr_im = Image.open(path.join(vid_path, im_name)).convert('P') 37 | w, h = hr_im.size 38 | 39 | ratio = new_min_size / min(w, h) 40 | 41 | lr_im = hr_im.resize((int(w*ratio), int(h*ratio)), Image.NEAREST) 42 | lr_im.save(path.join(vid_out_path, im_name)) 43 | 44 | 45 | def resize_all(in_path, out_path): 46 | for folder in os.listdir(in_path): 47 | 48 | if folder not in ['JPEGImages', 'Annotations']: 49 | continue 50 | folder_path = path.join(in_path, folder) 51 | videos = os.listdir(folder_path) 52 | 53 | videos = [(v, folder_path, out_path) for v in videos] 54 | 55 | if folder == 'JPEGImages': 56 | print('Processing images') 57 | os.makedirs(path.join(out_path, 'JPEGImages'), exist_ok=True) 58 | 59 | pool = Pool(processes=8) 60 | for _ in progressbar(pool.imap_unordered(resize_vid_jpeg, videos), max_value=len(videos)): 61 | pass 62 | else: 63 | print('Processing annotations') 64 | os.makedirs(path.join(out_path, 'Annotations'), exist_ok=True) 65 | 66 | pool = Pool(processes=8) 67 | for _ in progressbar(pool.imap_unordered(resize_vid_anno, videos), max_value=len(videos)): 68 | pass 69 | 70 | 71 | if __name__ == '__main__': 72 | in_path = sys.argv[1] 73 | out_path = sys.argv[2] 74 | 75 | resize_all(in_path, out_path) 76 | 77 | print('Done.') -------------------------------------------------------------------------------- /third_party/XMem/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UT-Austin-RPL/GROOT/3d9eceae26c599c325f9457658ab41d2dd2a1319/third_party/XMem/util/__init__.py -------------------------------------------------------------------------------- /third_party/XMem/util/davis_subset.txt: -------------------------------------------------------------------------------- 1 | bear 2 | bmx-bumps 3 | boat 4 | boxing-fisheye 5 | breakdance-flare 6 | bus 7 | car-turn 8 | cat-girl 9 | classic-car 10 | color-run 11 | crossing 12 | dance-jump 13 | dancing 14 | disc-jockey 15 | dog-agility 16 | dog-gooses 17 | dogs-scale 18 | drift-turn 19 | drone 20 | elephant 21 | flamingo 22 | hike 23 | hockey 24 | horsejump-low 25 | kid-football 26 | kite-walk 27 | koala 28 | lady-running 29 | lindy-hop 30 | longboard 31 | lucia 32 | mallard-fly 33 | mallard-water 34 | miami-surf 35 | motocross-bumps 36 | motorbike 37 | night-race 38 | paragliding 39 | planes-water 40 | rallye 41 | rhino 42 | rollerblade 43 | schoolgirls 44 | scooter-board 45 | scooter-gray 46 | sheep 47 | skate-park 48 | snowboard 49 | soccerball 50 | stroller 51 | stunt 52 | surf 53 | swing 54 | tennis 55 | tractor-sand 56 | train 57 | tuk-tuk 58 | upside-down 59 | varanus-cage 60 | walking -------------------------------------------------------------------------------- /third_party/XMem/util/load_subset.py: -------------------------------------------------------------------------------- 1 | """ 2 | load_subset.py - Presents a subset of data 3 | DAVIS - only the training set 4 | YouTubeVOS - I manually filtered some erroneous ones out but I haven't checked all 5 | """ 6 | 7 | 8 | def load_sub_davis(path='util/davis_subset.txt'): 9 | with open(path, mode='r') as f: 10 | subset = set(f.read().splitlines()) 11 | 
return subset 12 | 13 | def load_sub_yv(path='util/yv_subset.txt'): 14 | with open(path, mode='r') as f: 15 | subset = set(f.read().splitlines()) 16 | return subset 17 | -------------------------------------------------------------------------------- /third_party/XMem/util/log_integrator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Integrate numerical values for some iterations 3 | Typically used for loss computation / logging to tensorboard 4 | Call finalize and create a new Integrator when you want to display/log 5 | """ 6 | 7 | import torch 8 | 9 | 10 | class Integrator: 11 | def __init__(self, logger, distributed=True, local_rank=0, world_size=1): 12 | self.values = {} 13 | self.counts = {} 14 | self.hooks = [] # List is used here to maintain insertion order 15 | 16 | self.logger = logger 17 | 18 | self.distributed = distributed 19 | self.local_rank = local_rank 20 | self.world_size = world_size 21 | 22 | def add_tensor(self, key, tensor): 23 | if key not in self.values: 24 | self.counts[key] = 1 25 | if type(tensor) == float or type(tensor) == int: 26 | self.values[key] = tensor 27 | else: 28 | self.values[key] = tensor.mean().item() 29 | else: 30 | self.counts[key] += 1 31 | if type(tensor) == float or type(tensor) == int: 32 | self.values[key] += tensor 33 | else: 34 | self.values[key] += tensor.mean().item() 35 | 36 | def add_dict(self, tensor_dict): 37 | for k, v in tensor_dict.items(): 38 | self.add_tensor(k, v) 39 | 40 | def add_hook(self, hook): 41 | """ 42 | Adds a custom hook, i.e. compute new metrics using values in the dict 43 | The hook takes the dict as argument, and returns a (k, v) tuple 44 | e.g. for computing IoU 45 | """ 46 | if type(hook) == list: 47 | self.hooks.extend(hook) 48 | else: 49 | self.hooks.append(hook) 50 | 51 | def reset_except_hooks(self): 52 | self.values = {} 53 | self.counts = {} 54 | 55 | # Average and output the metrics 56 | def finalize(self, prefix, it, f=None): 57 | 58 | for hook in self.hooks: 59 | k, v = hook(self.values) 60 | self.add_tensor(k, v) 61 | 62 | for k, v in self.values.items(): 63 | 64 | if k[:4] == 'hide': 65 | continue 66 | 67 | avg = v / self.counts[k] 68 | 69 | if self.distributed: 70 | # Inplace operation 71 | avg = torch.tensor(avg).cuda() 72 | torch.distributed.reduce(avg, dst=0) 73 | 74 | if self.local_rank == 0: 75 | avg = (avg/self.world_size).cpu().item() 76 | self.logger.log_metrics(prefix, k, avg, it, f) 77 | else: 78 | # Simple does it 79 | self.logger.log_metrics(prefix, k, avg, it, f) 80 | 81 | -------------------------------------------------------------------------------- /third_party/XMem/util/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dumps things to tensorboard and console 3 | """ 4 | 5 | import os 6 | import warnings 7 | 8 | import torchvision.transforms as transforms 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | 12 | def tensor_to_numpy(image): 13 | image_np = (image.numpy() * 255).astype('uint8') 14 | return image_np 15 | 16 | def detach_to_cpu(x): 17 | return x.detach().cpu() 18 | 19 | def fix_width_trunc(x): 20 | return ('{:.9s}'.format('{:0.9f}'.format(x))) 21 | 22 | class TensorboardLogger: 23 | def __init__(self, short_id, id, git_info): 24 | self.short_id = short_id 25 | if self.short_id == 'NULL': 26 | self.short_id = 'DEBUG' 27 | 28 | if id is None: 29 | self.no_log = True 30 | warnings.warn('Logging has been disbaled.') 31 | else: 32 | self.no_log = False 33 | 34 | 
self.inv_im_trans = transforms.Normalize( 35 | mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], 36 | std=[1/0.229, 1/0.224, 1/0.225]) 37 | 38 | self.inv_seg_trans = transforms.Normalize( 39 | mean=[-0.5/0.5], 40 | std=[1/0.5]) 41 | 42 | log_path = os.path.join('.', 'saves', '%s' % id) 43 | self.logger = SummaryWriter(log_path) 44 | 45 | self.log_string('git', git_info) 46 | 47 | def log_scalar(self, tag, x, step): 48 | if self.no_log: 49 | warnings.warn('Logging has been disabled.') 50 | return 51 | self.logger.add_scalar(tag, x, step) 52 | 53 | def log_metrics(self, l1_tag, l2_tag, val, step, f=None): 54 | tag = l1_tag + '/' + l2_tag 55 | text = '{:s} - It {:6d} [{:5s}] [{:13}]: {:s}'.format(self.short_id, step, l1_tag.upper(), l2_tag, fix_width_trunc(val)) 56 | print(text) 57 | if f is not None: 58 | f.write(text + '\n') 59 | f.flush() 60 | self.log_scalar(tag, val, step) 61 | 62 | def log_im(self, tag, x, step): 63 | if self.no_log: 64 | warnings.warn('Logging has been disabled.') 65 | return 66 | x = detach_to_cpu(x) 67 | x = self.inv_im_trans(x) 68 | x = tensor_to_numpy(x) 69 | self.logger.add_image(tag, x, step) 70 | 71 | def log_cv2(self, tag, x, step): 72 | if self.no_log: 73 | warnings.warn('Logging has been disabled.') 74 | return 75 | x = x.transpose((2, 0, 1)) 76 | self.logger.add_image(tag, x, step) 77 | 78 | def log_seg(self, tag, x, step): 79 | if self.no_log: 80 | warnings.warn('Logging has been disabled.') 81 | return 82 | x = detach_to_cpu(x) 83 | x = self.inv_seg_trans(x) 84 | x = tensor_to_numpy(x) 85 | self.logger.add_image(tag, x, step) 86 | 87 | def log_gray(self, tag, x, step): 88 | if self.no_log: 89 | warnings.warn('Logging has been disabled.') 90 | return 91 | x = detach_to_cpu(x) 92 | x = tensor_to_numpy(x) 93 | self.logger.add_image(tag, x, step) 94 | 95 | def log_string(self, tag, x): 96 | print(tag, x) 97 | if self.no_log: 98 | warnings.warn('Logging has been disabled.') 99 | return 100 | self.logger.add_text(tag, x) 101 | -------------------------------------------------------------------------------- /third_party/XMem/util/palette.py: -------------------------------------------------------------------------------- 1 | davis_palette = b'\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 
\x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 \xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0' 2 | 3 | youtube_palette = b'\x00\x00\x00\xec_g\xf9\x91W\xfa\xc8c\x99\xc7\x94b\xb3\xb2f\x99\xcc\xc5\x94\xc5\xabyg\xff\xff\xffes~\x0b\x0b\x0b\x0c\x0c\x0c\r\r\r\x0e\x0e\x0e\x0f\x0f\x0f' 4 | -------------------------------------------------------------------------------- /third_party/XMem/util/tensor_util.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | 3 | 4 | def compute_tensor_iu(seg, gt): 5 | intersection = (seg & gt).float().sum() 6 | union = (seg | gt).float().sum() 7 | 8 | return intersection, union 9 | 10 | def compute_tensor_iou(seg, gt): 11 | intersection, union = compute_tensor_iu(seg, gt) 12 | iou = (intersection + 1e-6) / (union + 1e-6) 13 | 14 | return iou 15 | 16 | # STM 17 | def pad_divide_by(in_img, d): 18 | h, w = in_img.shape[-2:] 19 | 20 | if h % d > 0: 21 | new_h = h + d - h % d 22 | else: 23 | new_h = h 24 | if w % d > 0: 25 | new_w = w + d - w % d 26 | else: 27 | new_w = w 28 | lh, uh = int((new_h-h) / 2), int(new_h-h) - int((new_h-h) / 2) 29 | lw, uw = int((new_w-w) / 2), int(new_w-w) - int((new_w-w) / 2) 30 | pad_array = (int(lw), int(uw), int(lh), int(uh)) 31 | out = F.pad(in_img, pad_array) 32 | return out, pad_array 33 | 34 | def unpad(img, pad): 35 | if len(img.shape) == 4: 36 | if pad[2]+pad[3] > 0: 37 | img = img[:,:,pad[2]:-pad[3],:] 38 | if pad[0]+pad[1] > 0: 39 | img = img[:,:,:,pad[0]:-pad[1]] 40 | elif len(img.shape) == 3: 41 | if pad[2]+pad[3] > 0: 42 | img = img[:,pad[2]:-pad[3],:] 43 | if pad[0]+pad[1] > 0: 44 | img = img[:,:,pad[0]:-pad[1]] 45 | else: 46 | raise NotImplementedError 47 | return img --------------------------------------------------------------------------------