├── .gitignore ├── README.md ├── assets └── network.png ├── configs ├── _base_ │ ├── catre_base.py │ └── common_base.py └── catre │ └── NOCS_REAL │ ├── aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e.py │ └── aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e_initspd.py ├── core ├── __init__.py ├── base_data_loader.py ├── catre │ ├── datasets │ │ ├── __init__.py │ │ ├── cmra.py │ │ ├── data_loader.py │ │ ├── dataset_factory.py │ │ └── nocs.py │ ├── engine │ │ ├── __init__.py │ │ ├── batch_test.py │ │ ├── batching.py │ │ ├── catre_custom_evaluator.py │ │ ├── catre_evaluator.py │ │ ├── engine.py │ │ ├── engine_utils.py │ │ └── test_utils.py │ ├── losses │ │ ├── l2_loss.py │ │ ├── pm_loss.py │ │ └── rot_loss.py │ ├── main_catre.py │ ├── models │ │ ├── CATRE_disR_shared.py │ │ ├── heads │ │ │ ├── conv_out_per_rot_head.py │ │ │ └── fc_trans_size_head.py │ │ ├── model_utils.py │ │ ├── net_factory.py │ │ ├── pointnets │ │ │ └── pointnet.py │ │ └── pose_scale_from_delta_init.py │ ├── test_catre.sh │ ├── tools │ │ ├── camera25_prepare_spd_init_results.py │ │ └── prepare_spd_init_results.py │ └── train_catre.sh └── utils │ ├── __init__.py │ ├── apex_trainer.py │ ├── augment.py │ ├── camera_geometry.py │ ├── cat_data_utils.py │ ├── data_utils.py │ ├── dataset_utils.py │ ├── default_args_setup.py │ ├── depth_aug.py │ ├── depth_image_smoothing.py │ ├── edge_utils.py │ ├── farthest_points_torch.py │ ├── lie_algebra.py │ ├── my_checkpoint.py │ ├── my_comm.py │ ├── my_distributed_sampler.py │ ├── my_setup.py │ ├── my_visualizer.py │ ├── my_writer.py │ ├── pose_aug.py │ ├── pose_utils.py │ ├── quaternion_lf.py │ ├── rot_reps.py │ ├── solver_utils.py │ ├── ssd_color_transform.py │ ├── timm_utils.py │ ├── utils.py │ └── zoom_utils.py ├── datasets └── NOCS │ ├── REAL │ ├── real_test │ └── real_train │ ├── obj_models │ ├── abs_scale.pkl │ ├── cr_normed_mean_model_points_spd.pkl │ ├── mug_handle.pkl │ ├── mug_meta.pkl │ ├── real_test_spd.pkl │ └── real_train_spd.pkl │ └── test_init_poses │ ├── init_pose_dualposenet_nocs_real.json │ ├── init_pose_nocs_real.json │ └── init_pose_spd_nocs_real.json ├── docs └── INSTALL.md ├── lib ├── __init__.py ├── pysixd │ ├── RT_transform.py │ ├── __init__.py │ ├── colors.json │ ├── comparative_report.py │ ├── config.py │ ├── dataset_params.py │ ├── dataset_params_sixd.py │ ├── eval_calc_errors.py │ ├── eval_loc.py │ ├── eval_loc_origin.py │ ├── eval_plots.py │ ├── eval_utils.py │ ├── inout.py │ ├── misc.py │ ├── pose_error.py │ ├── pose_error_more.py │ ├── pose_matching.py │ ├── pycoco_utils.py │ ├── renderer.py │ ├── renderer_cpp.py │ ├── renderer_glumpy.py │ ├── renderer_py.py │ ├── renderer_pyrender.py │ ├── renderer_vispy.py │ ├── score.py │ ├── se3.py │ ├── test_set_bb8_sixd.yml │ ├── transform.py │ ├── uv_projection.py │ ├── view_sampler.py │ ├── visibility.py │ └── visualization.py ├── structures │ ├── __init__.py │ ├── centers_2d.py │ ├── keypoints_2d.py │ ├── keypoints_3d.py │ ├── my_list.py │ ├── my_maps.py │ ├── my_masks.py │ ├── poses.py │ ├── quats.py │ ├── rots.py │ └── translations.py ├── torch_utils │ ├── __init__.py │ ├── color │ │ ├── __init__.py │ │ ├── gray.py │ │ ├── hls.py │ │ ├── hsv.py │ │ ├── lab.py │ │ ├── luv.py │ │ ├── rgb.py │ │ ├── xyz.py │ │ ├── ycbcr.py │ │ └── yuv.py │ ├── layers │ │ ├── __init__.py │ │ ├── acon.py │ │ ├── activations_me.py │ │ ├── conv_module.py │ │ ├── coord_attention.py │ │ ├── dropblock │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── dropblock.py │ │ │ └── scheduler.py │ │ ├── layer_utils.py 
│ │ ├── mean_conv_deconv.py │ │ └── std_conv_transpose.py │ ├── misc.py │ ├── solver │ │ ├── AdaBelief.py │ │ ├── __init__.py │ │ ├── adamp.py │ │ ├── badam.py │ │ ├── grad_clip_d2.py │ │ ├── lookahead.py │ │ ├── lr_scheduler.py │ │ ├── madgrad.py │ │ ├── nadamw.py │ │ ├── optimize.py │ │ ├── over9000.py │ │ ├── radam.py │ │ ├── ralamb.py │ │ ├── ranger.py │ │ ├── ranger2020.py │ │ ├── ranger21.py │ │ ├── ranger_adabelief.py │ │ ├── rmsprop_tf.py │ │ ├── sgd_gc.py │ │ └── sgdp.py │ └── torch_utils.py ├── utils │ ├── __init__.py │ ├── bbox_utils.py │ ├── config_utils.py │ ├── fs.py │ ├── is_binary_file.py │ ├── logger.py │ ├── mask_utils.py │ ├── setup_logger.py │ ├── setup_logger_loguru.py │ ├── time_utils.py │ └── utils.py └── vis_utils │ ├── __init__.py │ ├── cmap_plt2cv.py │ ├── colormap.py │ ├── image.py │ └── optflow.py ├── output └── catre │ └── NOCS_REAL │ └── aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e │ └── model_final_wo_optim-82cf930e.pth ├── preprocess ├── pose_data.py └── shape_dataset.py ├── ref ├── __init__.py ├── cmra.py └── nocs.py ├── requirements └── requirements.txt └── scripts └── install_deps.sh /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *.so.* 3 | *.tar.gz 4 | *.egg-info* 5 | 6 | *.ttf 7 | *.jpg 8 | 9 | 10 | # compilation and distribution 11 | __pycache__ 12 | _ext 13 | *.pyc 14 | *.so 15 | detectron2.egg-info/ 16 | build/ 17 | dist/ 18 | .cache/ 19 | 20 | # pytorch/python/numpy formats 21 | #*.pth 22 | #*.pkl 23 | *.npy 24 | *.engine 25 | *.onnx 26 | events.out.tfevents* 27 | 28 | # ipython/jupyter notebooks 29 | *.ipynb 30 | **/.ipynb_checkpoints/ 31 | 32 | # Editor temporaries 33 | *.swn 34 | *.swo 35 | *.swp 36 | *~ 37 | 38 | # Pycharm editor settings 39 | .idea 40 | 41 | # VSCode editor settings 42 | .vscode 43 | 44 | # project dirs 45 | /models 46 | 47 | #external/* 48 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CATRE 2 | 3 | This repo provides for the implementation of the ECCV'22 paper: 4 | 5 | **CATRE: Iterative Point Clouds Alignment for Category-level Object Pose Refinement**
6 | [[arXiv](https://arxiv.org/abs/2207.08082)][[Video](https://www.bilibili.com/video/BV1e8411e7Jt/?share_source=copy_web)] 7 | 8 | ## Overview 9 | 10 | ![](assets/network.png) 11 | 12 | 13 | ## Dependencies 14 | 15 | See [INSTALL.md](./docs/INSTALL.md) 16 | 17 | ## Datasets 18 | 19 | Prepare the datasets folder as follows: 20 | ```bash 21 | datasets/ 22 | ├── NOCS 23 | ├── REAL 24 | ├── real_test # download from http://download.cs.stanford.edu/orion/nocs/real_test.zip 25 | ├── real_train # download from http://download.cs.stanford.edu/orion/nocs/real_train.zip 26 | └── image_set # generated by pose_data.py 27 | ├── gts # download from http://download.cs.stanford.edu/orion/nocs/gts.zip 28 | └── real_test 29 | ├── test_init_poses # provided in this repo 30 | └── obj_models # we provide some necessary files; the complete set can be downloaded from http://download.cs.stanford.edu/orion/nocs/obj_models.zip 31 | ``` 32 | 33 | Run the following script to prepare the datasets (modified from https://github.com/mentian/object-deformnet): 34 | ```bash 35 | # NOTE: this code will directly modify the data 36 | cd $ROOT/preprocess 37 | python pose_data.py 38 | ``` 39 | 40 | ## Reproduce the results 41 | 42 | The trained model is provided at `output/catre/NOCS_REAL/aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e/model_final_wo_optim-82cf930e.pth`. Run the following command to reproduce the results: 43 | 44 | ``` 45 | ./core/catre/test_catre.sh configs/catre/NOCS_REAL/aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e.py 1 output/catre/NOCS_REAL/aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e/model_final_wo_optim-82cf930e.pth 46 | ``` 47 | 48 | ## NOTE 49 | 50 | **NOTE** that there is a small bug in the original NOCS evaluation [code](https://github.com/hughw19/NOCS_CVPR2019/blob/78a31c2026a954add1a2711286ff45ce1603b8ab/utils.py#L252) w.r.t. IoU. We fixed this bug in our evaluation code and re-evaluated all the compared methods in the paper (we only revised the IoU values and kept the rotation/translation results unchanged, although strictly speaking the R/t accuracy would also change slightly). See the revised [code](https://github.com/THU-DA-6D-Pose-Group/CATRE/blob/b649cbad6ed2121b22a37f7fe16ad923688d4995/core/catre/engine/test_utils.py#L158) for details. Thanks also to [Peng et al.](https://github.com/swords123/SSC-6D/blob/bb0dcd5e5b789ea2a80c6c3fa16ccc2bf0a445d1/eval/utils.py#L114) for further confirming this bug.
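For intuition, the IoU metric discussed above measures the volumetric overlap of two 3D bounding boxes recovered from the estimated scale and pose. The snippet below is only a minimal sketch for the simplified axis-aligned case; it is not the repository's fixed evaluation code in `core/catre/engine/test_utils.py`, which also handles rotated boxes and symmetric categories.

```python
# Illustration only (assumes axis-aligned boxes given as center + extent).
import numpy as np

def iou_3d_axis_aligned(center1, extent1, center2, extent2):
    """IoU of two axis-aligned 3D boxes; extents are full sizes (w, h, d)."""
    min1, max1 = center1 - extent1 / 2.0, center1 + extent1 / 2.0
    min2, max2 = center2 - extent2 / 2.0, center2 + extent2 / 2.0
    overlap = np.clip(np.minimum(max1, max2) - np.maximum(min1, min2), 0.0, None)
    inter = overlap.prod()
    union = extent1.prod() + extent2.prod() - inter
    return float(inter / union)

# Two unit cubes offset by 0.5 along x -> IoU = 0.5 / 1.5 = 1/3.
print(iou_3d_axis_aligned(np.zeros(3), np.ones(3), np.array([0.5, 0.0, 0.0]), np.ones(3)))
```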
51 | 52 | ## Training 53 | 54 | ``` 55 | ./core/catre/train_catre.sh configs/catre/NOCS_REAL/aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e.py (other args) 56 | ``` 57 | 58 | ## Testing 59 | ``` 60 | ./core/catre/test_catre.sh configs/catre/NOCS_REAL/aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e.py (other args) 61 | ``` 62 | 63 | ## Citation 64 | If you find this repo useful in your research, please consider citing: 65 | ``` 66 | @InProceedings{liu_2022_catre, 67 | title = {{CATRE:} Iterative Point Clouds Alignment for Category-level Object Pose Refinement}, 68 | author = {Liu, Xingyu and Wang, Gu and Li, Yi and Ji, Xiangyang}, 69 | booktitle = {European Conference on Computer Vision (ECCV)}, 70 | month = {October}, 71 | year = {2022} 72 | } 73 | ``` 74 | -------------------------------------------------------------------------------- /assets/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/assets/network.png -------------------------------------------------------------------------------- /configs/catre/NOCS_REAL/aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e.py: -------------------------------------------------------------------------------- 1 | _base_ = ["../../_base_/catre_base.py"] 2 | 3 | # fix mug mean shape 4 | OUTPUT_DIR = "output/catre/NOCS_REAL/aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e" 5 | INPUT = dict( 6 | COLOR_AUG_PROB=0.0, 7 | DEPTH_SAMPLE_BALL_RATIO=0.6, 8 | # NOTE: used points! 9 | BBOX_TYPE_TEST="est", # from_pose | est | gt | gt_aug (TODO) 10 | INIT_POSE_TYPE_TRAIN=["gt_noise"], # gt_noise | random | canonical 11 | NOISE_ROT_STD_TRAIN=(10, 5, 2.5, 1.25), # randomly choose one 12 | NOISE_TRANS_STD_TRAIN=[ 13 | (0.02, 0.02, 0.02), 14 | (0.01, 0.01, 0.01), 15 | (0.005, 0.005, 0.005), 16 | ], 17 | NOISE_SCALE_STD_TRAIN=[ 18 | (0.01, 0.01, 0.01), 19 | (0.005, 0.005, 0.005), 20 | (0.002, 0.002, 0.002), 21 | ], 22 | INIT_POSE_TYPE_TEST="est", # gt_noise | est | canonical 23 | KPS_TYPE="mean_shape", # bbox_from_scale | mean_shape |fps (abla) 24 | WITH_DEPTH=True, 25 | AUG_DEPTH=True, 26 | WITH_PCL=True, 27 | WITH_IMG=False, 28 | BP_DEPTH=False, 29 | NUM_KPS=1024, 30 | NUM_PCL=1024, 31 | # augmentation when training 32 | BBOX3D_AUG_PROB=0.5, 33 | RT_AUG_PROB=0.5, 34 | # pose focalization 35 | ZERO_CENTER_INPUT=True, 36 | ) 37 | 38 | DATALOADER = dict( 39 | NUM_WORKERS=24, 40 | ) 41 | 42 | SOLVER = dict( 43 | IMS_PER_BATCH=16, 44 | TOTAL_EPOCHS=120, 45 | LR_SCHEDULER_NAME="flat_and_anneal", 46 | ANNEAL_METHOD="cosine", # "cosine" 47 | ANNEAL_POINT=0.72, 48 | # REL_STEPS=(0.3125, 0.625, 0.9375), 49 | OPTIMIZER_CFG=dict(_delete_=True, type="Ranger", lr=1e-4, weight_decay=0), 50 | WEIGHT_DECAY=0.0, 51 | WARMUP_FACTOR=0.001, 52 | WARMUP_ITERS=1000, 53 | ) 54 | 55 | DATASETS = dict( 56 | TRAIN=("nocs_train_real",), 57 | TEST=("nocs_test_real",), 58 | INIT_POSE_FILES_TEST=("datasets/NOCS/test_init_poses/init_pose_spd_nocs_real.json",), 59 | ) 60 | 61 | MODEL = dict( 62 | LOAD_POSES_TEST=True, 63 | PIXEL_MEAN=[0.0, 0.0, 0.0], 64 | PIXEL_STD=[255.0, 255.0, 255.0], 65 | REFINE_SCLAE=True, 66 | CATRE=dict( 67 | NAME="CATRE_disR_shared", # used module file name (define different model types) 68 | TASK="refine", # refine | init | init+refine 69 | NUM_CLASSES=6, # only valid for class aware 70 | N_ITER_TRAIN=4, 71 | N_ITER_TRAIN_WARM_EPOCH=4, # linearly increase the refine 
iter from 1 to N_ITER_TRAIN until this epoch 72 | N_ITER_TEST=4, 73 | PCLNET=dict( 74 | FREEZE=False, 75 | INIT_CFG=dict( 76 | type="point_net", 77 | num_points=1024, 78 | global_feat=False, 79 | feature_transform=True, 80 | out_dim=1024, 81 | ), 82 | ), 83 | ## disentangled pose head for delta R/T/s 84 | ROT_HEAD=dict( 85 | ROT_TYPE="ego_rot6d", # {ego|allo}_rot6d 86 | INIT_CFG=dict( 87 | type="ConvOutPerRotHead", 88 | in_dim=1088, 89 | num_layers=2, 90 | kernel_size=1, 91 | feat_dim=256, 92 | norm="GN", # BN | GN | none 93 | num_gn_groups=32, 94 | act="gelu", # relu | lrelu | silu (swish) | gelu | mish 95 | num_points=1024 + 1024, 96 | rot_dim=3, # ego_rot6d 97 | norm_input=False, 98 | ), 99 | SCLAE_TYPE="iter_add", 100 | ), 101 | TS_HEAD=dict( 102 | WITH_KPS_FEATURE=False, 103 | WITH_INIT_SCALE=True, 104 | INIT_CFG=dict( 105 | type="FC_TransSizeHead", 106 | in_dim=1088 + 3, 107 | num_layers=2, 108 | feat_dim=256, 109 | norm="GN", # BN | GN | none 110 | num_gn_groups=32, 111 | act="gelu", # relu | lrelu | silu (swish) | gelu | mish 112 | norm_input=False, 113 | ), 114 | ), 115 | LOSS_CFG=dict( 116 | # point matching loss ---------------- 117 | PM_LOSS_SYM=True, # use symmetric PM loss 118 | PM_NORM_BY_EXTENT=False, # 1. / extent.max(1, keepdim=True)[0] 119 | # if False, the trans loss is in point matching loss 120 | PM_R_ONLY=True, # only do R loss in PM 121 | PM_WITH_SCALE=True, 122 | PM_LW=1.0, 123 | # rot loss -------------- 124 | ROT_LOSS_TYPE="angular", # angular | L2 125 | ROT_LW=1.0, 126 | ROT_YAXIS_LOSS_TYPE="L1", 127 | # trans loss ----------- 128 | TRANS_LOSS_TYPE="L1", 129 | TRANS_LOSS_DISENTANGLE=True, 130 | TRANS_LW=1.0, 131 | # scale loss ---------------------------------- 132 | SCALE_LOSS_TYPE="L1", 133 | SCALE_LW=1.0, 134 | ), 135 | ), 136 | ) 137 | -------------------------------------------------------------------------------- /configs/catre/NOCS_REAL/aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e_initspd.py: -------------------------------------------------------------------------------- 1 | _base_ = ["../../_base_/catre_base.py"] 2 | 3 | # fix mug mean shape 4 | OUTPUT_DIR = "output/catre/NOCS_REAL/aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e_initspd" 5 | INPUT = dict( 6 | COLOR_AUG_PROB=0.0, 7 | DEPTH_SAMPLE_BALL_RATIO=0.6, 8 | # NOTE: used points! 
9 | BBOX_TYPE_TEST="est", # from_pose | est | gt | gt_aug (TODO) 10 | INIT_POSE_TYPE_TRAIN=["gt_noise"], # gt_noise | random | canonical 11 | NOISE_ROT_STD_TRAIN=(10, 5, 2.5, 1.25), # randomly choose one 12 | NOISE_TRANS_STD_TRAIN=[ 13 | (0.02, 0.02, 0.02), 14 | (0.01, 0.01, 0.01), 15 | (0.005, 0.005, 0.005), 16 | ], 17 | NOISE_SCALE_STD_TRAIN=[ 18 | (0.01, 0.01, 0.01), 19 | (0.005, 0.005, 0.005), 20 | (0.002, 0.002, 0.002), 21 | ], 22 | INIT_POSE_TYPE_TEST="est", # gt_noise | est | canonical 23 | KPS_TYPE="mean_shape", # bbox_from_scale | mean_shape |fps (abla) 24 | WITH_DEPTH=True, 25 | AUG_DEPTH=True, 26 | WITH_PCL=True, 27 | WITH_IMG=False, 28 | BP_DEPTH=False, 29 | NUM_KPS=1024, 30 | NUM_PCL=1024, 31 | # augmentation when training 32 | BBOX3D_AUG_PROB=0.5, 33 | RT_AUG_PROB=0.5, 34 | ZERO_CENTER_INPUT=True, 35 | ) 36 | 37 | DATALOADER = dict( 38 | NUM_WORKERS=24, 39 | ) 40 | 41 | SOLVER = dict( 42 | IMS_PER_BATCH=32, 43 | TOTAL_EPOCHS=120, 44 | LR_SCHEDULER_NAME="flat_and_anneal", 45 | ANNEAL_METHOD="cosine", # "cosine" 46 | ANNEAL_POINT=0.72, 47 | # REL_STEPS=(0.3125, 0.625, 0.9375), 48 | OPTIMIZER_CFG=dict(_delete_=True, type="Ranger", lr=1e-4, weight_decay=0), 49 | WEIGHT_DECAY=0.0, 50 | WARMUP_FACTOR=0.001, 51 | WARMUP_ITERS=1000, 52 | ) 53 | 54 | DATASETS = dict( 55 | TRAIN=("nocs_train_real",), 56 | TEST=("nocs_test_real",), 57 | INIT_POSE_FILES_TEST=("datasets/NOCS/test_init_poses/init_pose_dualposenet_nocs_real.json",), 58 | ) 59 | 60 | MODEL = dict( 61 | LOAD_POSES_TEST=True, 62 | PIXEL_MEAN=[0.0, 0.0, 0.0], 63 | PIXEL_STD=[255.0, 255.0, 255.0], 64 | REFINE_SCLAE=True, 65 | CATRE=dict( 66 | NAME="CATRE_disR_shared", # used module file name (define different model types) 67 | TASK="refine", # refine | init | init+refine 68 | NUM_CLASSES=6, # only valid for class aware 69 | N_ITER_TRAIN=4, 70 | N_ITER_TRAIN_WARM_EPOCH=4, # linearly increase the refine iter from 1 to N_ITER_TRAIN until this epoch 71 | N_ITER_TEST=4, 72 | PCLNET=dict( 73 | FREEZE=False, 74 | INIT_CFG=dict( 75 | type="point_net", 76 | num_points=1024, 77 | global_feat=False, 78 | feature_transform=True, 79 | out_dim=1024, 80 | ), 81 | ), 82 | ## pose head for delta R/T/s 83 | ROT_HEAD=dict( 84 | ROT_TYPE="ego_rot6d", # {ego|allo}_rot6d 85 | INIT_CFG=dict( 86 | type="ConvOutPerRotHead", 87 | in_dim=1088, 88 | num_layers=2, 89 | kernel_size=1, 90 | feat_dim=256, 91 | norm="GN", # BN | GN | none 92 | num_gn_groups=32, 93 | act="gelu", # relu | lrelu | silu (swish) | gelu | mish 94 | num_points=1024 + 1024, 95 | rot_dim=3, # ego_rot6d 96 | norm_input=False, 97 | ), 98 | SCLAE_TYPE="iter_add", 99 | ), 100 | TS_HEAD=dict( 101 | WITH_KPS_FEATURE=False, 102 | WITH_INIT_SCALE=True, 103 | INIT_CFG=dict( 104 | type="FC_TransSizeHead", 105 | in_dim=1088 + 3, 106 | num_layers=2, 107 | feat_dim=256, 108 | norm="GN", # BN | GN | none 109 | num_gn_groups=32, 110 | act="gelu", # relu | lrelu | silu (swish) | gelu | mish 111 | norm_input=False, 112 | ), 113 | ), 114 | LOSS_CFG=dict( 115 | # point matching loss ---------------- 116 | PM_LOSS_SYM=True, # use symmetric PM loss 117 | PM_NORM_BY_EXTENT=False, # 1. 
/ extent.max(1, keepdim=True)[0] 118 | # if False, the trans loss is in point matching loss 119 | PM_R_ONLY=True, # only do R loss in PM 120 | PM_WITH_SCALE=True, 121 | PM_LW=1.0, 122 | # rot loss -------------- 123 | ROT_LOSS_TYPE="angular", # angular | L2 124 | ROT_LW=1.0, 125 | ROT_YAXIS_LOSS_TYPE="L1", 126 | # trans loss ----------- 127 | TRANS_LOSS_TYPE="L1", 128 | TRANS_LOSS_DISENTANGLE=True, 129 | TRANS_LW=1.0, 130 | # scale loss ---------------------------------- 131 | SCALE_LOSS_TYPE="L1", 132 | SCALE_LW=1.0, 133 | ), 134 | ), 135 | ) 136 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/core/__init__.py -------------------------------------------------------------------------------- /core/catre/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/core/catre/datasets/__init__.py -------------------------------------------------------------------------------- /core/catre/datasets/dataset_factory.py: -------------------------------------------------------------------------------- 1 | """Register datasets in this file will be imported in project root to register 2 | the datasets.""" 3 | import logging 4 | import os 5 | import os.path as osp 6 | import mmcv 7 | import detectron2.utils.comm as comm 8 | import ref 9 | from detectron2.data import DatasetCatalog, MetadataCatalog 10 | from core.catre.datasets import ( 11 | nocs, 12 | cmra, 13 | ) 14 | 15 | cur_dir = osp.dirname(osp.abspath(__file__)) 16 | # from lib.utils.utils import iprint 17 | __all__ = ["register_dataset", "register_datasets", "register_datasets_in_cfg", "get_available_datasets"] 18 | _DSET_MOD_NAMES = [ 19 | "nocs", 20 | "cmra", 21 | ] 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def register_dataset(mod_name, dset_name, data_cfg=None): 27 | """ 28 | mod_name: a module under core.datasets or other dataset source file imported here 29 | dset_name: dataset name 30 | data_cfg: dataset config 31 | """ 32 | register_func = eval(mod_name) 33 | register_func.register_with_name_cfg(dset_name, data_cfg) 34 | 35 | 36 | def get_available_datasets(mod_name): 37 | return eval(mod_name).get_available_datasets() 38 | 39 | 40 | def register_datasets_in_cfg(cfg): 41 | for split in [ 42 | "TRAIN", 43 | "TEST", 44 | "VAL", 45 | "TRAIN2", 46 | ]: 47 | for name in cfg.DATASETS.get(split, []): 48 | if name in DatasetCatalog.list(): 49 | continue 50 | registered = False 51 | # try to find in pre-defined datasets 52 | # NOTE: it is better to let all datasets pre-refined 53 | for _mod_name in _DSET_MOD_NAMES: 54 | if name in get_available_datasets(_mod_name): 55 | register_dataset(_mod_name, name, data_cfg=None) 56 | registered = True 57 | break 58 | # not in pre-defined; not recommend 59 | if not registered: 60 | # try to get mod_name and data_cfg from cfg 61 | """load data_cfg and mod_name from file 62 | cfg.DATA_CFG[name] = 'path_to_cfg' 63 | """ 64 | assert "DATA_CFG" in cfg and name in cfg.DATA_CFG, "no cfg.DATA_CFG.{}".format(name) 65 | assert osp.exists(cfg.DATA_CFG[name]) 66 | data_cfg = mmcv.load(cfg.DATA_CFG[name]) 67 | mod_name = data_cfg.pop("mod_name", None) 68 | assert mod_name in _DSET_MOD_NAMES, mod_name 69 | 
register_dataset(mod_name, name, data_cfg) 70 | 71 | 72 | def register_datasets(dataset_names): 73 | for name in dataset_names: 74 | if name in DatasetCatalog.list(): 75 | continue 76 | registered = False 77 | # try to find in pre-defined datasets 78 | # NOTE: it is better to let all datasets pre-refined 79 | for _mod_name in _DSET_MOD_NAMES: 80 | if name in get_available_datasets(_mod_name): 81 | register_dataset(_mod_name, name, data_cfg=None) 82 | registered = True 83 | break 84 | 85 | # not in pre-defined; not recommend 86 | if not registered: 87 | raise ValueError(f"dataset {name} is not defined") 88 | -------------------------------------------------------------------------------- /core/catre/engine/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/core/catre/engine/__init__.py -------------------------------------------------------------------------------- /core/catre/engine/batch_test.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import random 3 | import torch 4 | 5 | from .engine_utils import get_normed_kps 6 | from lib.vis_utils.image import heatmap, grid_show 7 | from lib.pysixd.misc import transform_normed_pts_batch 8 | 9 | 10 | def batch_data_test(cfg, data, device="cuda", dtype=torch.float32): 11 | # batch test data and flatten 12 | tensor_kwargs = {"dtype": dtype, "device": device} 13 | to_float_args = {"dtype": dtype, "device": device, "non_blocking": True} 14 | to_long_args = {"dtype": torch.long, "device": device, "non_blocking": True} 15 | 16 | batch = {} 17 | num_imgs = len(data) 18 | # construct flattened instance data ============================= 19 | batch["obj_cls"] = torch.cat([d["instances"].obj_classes for d in data], dim=0).to(**to_long_args) 20 | batch["obj_bbox"] = torch.cat([d["instances"].obj_boxes.tensor for d in data], dim=0).to(**to_float_args) 21 | # NOTE: initial pose or the output pose estimate 22 | batch["obj_pose_est"] = torch.cat([d["instances"].obj_poses.tensor for d in data], dim=0).to(**to_float_args) 23 | batch["obj_scale_est"] = torch.cat([d["instances"].obj_scales for d in data], dim=0).to(**to_float_args) 24 | 25 | batch["obj_mean_points"] = torch.cat([d["instances"].obj_mean_points for d in data], dim=0).to(**to_float_args) 26 | batch["obj_mean_scales"] = torch.cat([d["instances"].obj_mean_scales for d in data], dim=0).to(**to_float_args) 27 | 28 | if cfg.INPUT.KPS_TYPE.lower() == "fps": 29 | # NOTE: only an ablation setting! 
30 | batch["obj_fps_points"] = torch.cat([d["instances"].obj_fps_points for d in data], dim=0).to(**to_float_args) 31 | 32 | num_insts_per_im = [len(d["instances"]) for d in data] 33 | n_obj = len(batch["obj_cls"]) 34 | K_list = [] 35 | sym_infos_list = [] 36 | im_ids = [] 37 | inst_ids = [] 38 | for i_im in range(num_imgs): 39 | sym_infos_list.extend(data[i_im]["instances"].obj_sym_infos) 40 | for i_inst in range(num_insts_per_im[i_im]): 41 | im_ids.append(i_im) 42 | inst_ids.append(i_inst) 43 | K_list.append(data[i_im]["cam"].clone()) 44 | 45 | batch["im_id"] = torch.tensor(im_ids, **tensor_kwargs) 46 | batch["inst_id"] = torch.tensor(inst_ids, **tensor_kwargs) 47 | batch["K"] = torch.stack(K_list, dim=0).to(**to_float_args) 48 | batch["sym_info"] = sym_infos_list 49 | 50 | input_cfg = cfg.INPUT 51 | 52 | batch["pcl"] = torch.cat([d["instances"].pcl for d in data], dim=0).to(**to_float_args) 53 | 54 | if input_cfg.WITH_IMG: 55 | batch["img"] = torch.stack([d["image"] for d in data]).to(**to_float_args) 56 | 57 | if input_cfg.WITH_DEPTH: 58 | batch["depth_obs"] = torch.stack([d["depth"] for d in data], dim=0).to(**to_float_args) 59 | 60 | return batch 61 | 62 | 63 | def batch_updater_test(cfg, batch, poses_est=None, scales_est=None, device="cuda", dtype=torch.float32): 64 | """ 65 | iter=0: poses_est=None, obj_pose_est is from data loader 66 | if REFINE_SCLAE is False, keep init_scale unchanged from iter 0 ~ max_num 67 | """ 68 | tensor_kwargs = {"dtype": dtype, "device": device} 69 | to_float_args = {"dtype": dtype, "device": device, "non_blocking": True} 70 | 71 | n_obj = batch["obj_cls"].shape[0] 72 | if poses_est is not None: 73 | batch["obj_pose_est"] = poses_est 74 | 75 | if scales_est is not None and cfg.MODEL.REFINE_SCLAE: 76 | batch["obj_scale_est"] = scales_est 77 | 78 | if "obj_kps" not in batch: 79 | get_normed_kps(cfg, batch, **to_float_args) 80 | 81 | r_est = batch["obj_pose_est"][:, :3, :3] 82 | t_est = batch["obj_pose_est"][:, :3, 3:4] 83 | s_est = batch["obj_scale_est"] 84 | 85 | tfd_kps = transform_normed_pts_batch( 86 | batch["obj_kps"], 87 | r_est, 88 | t=None if cfg.INPUT.ZERO_CENTER_INPUT else t_est, 89 | scale=s_est, 90 | ) 91 | 92 | batch["tfd_kps"] = tfd_kps.permute(0, 2, 1) # [bs, 3, num_k] 93 | 94 | if cfg.INPUT.ZERO_CENTER_INPUT: 95 | batch["x"] = batch["pcl"].permute(0, 2, 1) - t_est.view(n_obj, 3, 1) # [bs, 3, num_k] - [bs, 3, 1] 96 | else: 97 | batch["x"] = batch["pcl"].permute(0, 2, 1) 98 | 99 | # done batch update test------------------------------------------ 100 | -------------------------------------------------------------------------------- /core/catre/losses/l2_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def l2_loss(pred, target, reduction="mean"): 6 | assert pred.size() == target.size() and target.numel() > 0 7 | assert pred.size()[0] == target.size()[0] 8 | batch_size = pred.size()[0] 9 | loss = torch.norm((pred - target).view(batch_size, -1), p=2, dim=1, keepdim=True) 10 | # loss = torch.sqrt(torch.sum(((pred - target)** 2).view(batch_size, -1), 1)) 11 | # print(loss.shape) 12 | """ 13 | _mse_loss = nn.MSELoss(reduction='none') 14 | loss_mse = _mse_loss(pred, target) 15 | print('l2 from mse loss: {}'.format( 16 | torch.sqrt( 17 | torch.sum( 18 | loss_mse.view(batch_size, -1), 19 | 1 20 | ) 21 | ).mean())) 22 | """ 23 | if reduction == "mean": 24 | loss = loss.mean() 25 | elif reduction == "sum": 26 | loss = loss.sum() 27 | return loss 28 | 29 | 30 
| class L2Loss(nn.Module): 31 | def __init__(self, reduction="mean", loss_weight=1.0): 32 | super(L2Loss, self).__init__() 33 | self.reduction = reduction 34 | self.loss_weight = loss_weight 35 | 36 | def forward(self, pred, target): 37 | loss = self.loss_weight * l2_loss(pred, target, reduction=self.reduction) 38 | return loss 39 | 40 | 41 | if __name__ == "__main__": 42 | 43 | _l2_loss = L2Loss(reduction="mean") 44 | torch.manual_seed(2) 45 | pred = torch.randn(8, 3, 4) 46 | targets = torch.randn(8, 3, 4) 47 | # pred.requires_grad = True 48 | # targets.requires_grad = True 49 | 50 | loss_l2 = _l2_loss(pred, targets) 51 | batch_size = 8 52 | # print('mse loss: {}'.format(loss_mse)) 53 | # print('sqrt(mse loss): {}'.format(torch.sqrt(loss_mse))) 54 | 55 | _mse_loss = nn.MSELoss(reduction="none") 56 | loss_mse = _mse_loss(pred, targets) 57 | print("l2 from mse loss: {}".format(torch.sqrt(torch.sum(loss_mse.view(batch_size, -1), 1)).mean())) 58 | print("l2 loss: {}".format(loss_l2)) 59 | # print('squared l2 loss: {}'.format(loss_l2 ** 2)) 60 | -------------------------------------------------------------------------------- /core/catre/losses/rot_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def angular_distance(r1, r2, reduction="mean"): 5 | """https://math.stackexchange.com/questions/90081/quaternion-distance 6 | https://github.com/papagina/RotationContinuity/blob/master/sanity_test/code/tools.py 7 | 8 | 9 | 10 | 11 | 1 - <q1, q2>^2 <==> (1-cos(theta)) / 2 12 | """ 13 | assert r1.shape == r2.shape 14 | if r1.shape[-1] == 4: 15 | return angular_distance_quat(r1, r2, reduction=reduction) 16 | if len(r1.shape) == 2 and r1.shape[-1] == 3: # bs * 3 17 | return angular_distance_vec(r1, r2, reduction=reduction) 18 | else: 19 | return angular_distance_rot(r1, r2, reduction=reduction) 20 | 21 | 22 | def angular_distance_quat(pred_q, gt_q, reduction="mean"): 23 | dist = 1 - torch.pow(torch.bmm(pred_q.view(-1, 1, 4), gt_q.view(-1, 4, 1)), 2) 24 | if reduction == "mean": 25 | return dist.mean() 26 | elif reduction == "sum": 27 | return dist.sum() 28 | else: 29 | return dist 30 | 31 | 32 | def angular_distance_vec(vec_1, vec_2, reduction="mean"): 33 | cos = torch.bmm(vec_1.unsqueeze(1), vec_2.unsqueeze(2)).squeeze() / ( 34 | torch.norm(vec_1, dim=1) * torch.norm(vec_2, dim=1) 35 | ) # [-1, 1] 36 | dist = (1 - cos) / 2 # [0, 1] 37 | if reduction == "mean": 38 | return dist.mean() 39 | elif reduction == "sum": 40 | return dist.sum() 41 | else: 42 | return dist 43 | 44 | 45 | def angular_distance_rot(m1, m2, reduction="mean"): 46 | m = torch.bmm(m1, m2.transpose(1, 2)) # b*3*3 47 | m_trace = torch.einsum("bii->b", m) # batch trace 48 | cos = (m_trace - 1) / 2 # [-1, 1] 49 | # eps = 1e-6 50 | # cos = torch.clamp(cos, -1+eps, 1-eps) # avoid nan 51 | # theta = torch.acos(cos) 52 | dist = (1 - cos) / 2 # [0, 1] 53 | if reduction == "mean": 54 | return dist.mean() 55 | elif reduction == "sum": 56 | return dist.sum() 57 | else: 58 | return dist 59 | 60 | 61 | def rot_l2_loss(m1, m2): 62 | error = torch.pow(m1 - m2, 2).mean() # batch 63 | return error 64 | 65 | 66 | if __name__ == "__main__": 67 | import sys 68 | import os.path as osp 69 | 70 | cur_dir = osp.dirname(__file__) 71 | sys.path.insert(0, osp.join(cur_dir, "../../../")) 72 | from lib.pysixd.transform import random_quaternion 73 | from transforms3d.quaternions import quat2mat 74 | 75 | q1 = random_quaternion() 76 | q2 = random_quaternion() 77 | m1 = quat2mat(q1) 78 | m2 = quat2mat(q2)
79 | dtype = torch.float32 80 | device = "cpu" 81 | q1 = torch.tensor([q1, q1], dtype=dtype, device=device).view(-1, 4) 82 | q2 = torch.tensor([q2, q2], dtype=dtype, device=device).view(-1, 4) 83 | m1 = torch.tensor([m1, m1], dtype=dtype, device=device).view(-1, 3, 3) 84 | m2 = torch.tensor([m2, m2], dtype=dtype, device=device).view(-1, 3, 3) 85 | dist_q = angular_distance_quat(q1, q2) 86 | dist_r = angular_distance_rot(m1, m2) 87 | print("dist q: ", dist_q) 88 | print("dist r: ", dist_r) 89 | print(angular_distance(q1, q2)) 90 | print(angular_distance(m1, m2)) 91 | -------------------------------------------------------------------------------- /core/catre/models/heads/conv_out_per_rot_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.modules.batchnorm import _BatchNorm 4 | from mmcv.cnn import normal_init, constant_init 5 | 6 | from lib.torch_utils.layers.layer_utils import get_norm, get_nn_act_func 7 | from lib.torch_utils.layers.conv_module import ConvModule 8 | 9 | 10 | class ConvOutPerRotHead(nn.Module): 11 | def __init__( 12 | self, 13 | in_dim=1024, 14 | feat_dim=256, 15 | num_layers=2, 16 | rot_dim=3, 17 | norm="GN", 18 | num_gn_groups=32, 19 | act="gelu", 20 | num_classes=1, 21 | kernel_size=1, 22 | num_points=1, 23 | per_rot_sup=False, 24 | norm_input=False, 25 | dropout=False, 26 | point_bias=True, 27 | **args, 28 | ): 29 | super(ConvOutPerRotHead, self).__init__() 30 | self.per_rot_sup = per_rot_sup 31 | self.rot_head_x = RotHead( 32 | in_dim, 33 | feat_dim, 34 | num_layers, 35 | rot_dim, 36 | norm, 37 | num_gn_groups, 38 | act, 39 | num_classes, 40 | kernel_size, 41 | num_points, 42 | norm_input, 43 | dropout, 44 | point_bias, 45 | ) 46 | self.rot_head_y = RotHead( 47 | in_dim, 48 | feat_dim, 49 | num_layers, 50 | rot_dim, 51 | norm, 52 | num_gn_groups, 53 | act, 54 | num_classes, 55 | kernel_size, 56 | num_points, 57 | norm_input, 58 | dropout, 59 | point_bias, 60 | ) 61 | 62 | def forward(self, x): 63 | rx, feat_x = self.rot_head_x(x) 64 | ry, feat_y = self.rot_head_y(x) 65 | r_pred = torch.cat((rx, ry), dim=1) 66 | feat = torch.cat((feat_x, feat_y), dim=1) 67 | 68 | if self.per_rot_sup: 69 | return r_pred, feat # return bs * 6 70 | else: 71 | return r_pred 72 | 73 | 74 | class RotHead(nn.Module): 75 | def __init__( 76 | self, 77 | in_dim=1024, 78 | feat_dim=256, 79 | num_layers=2, 80 | rot_dim=4, 81 | norm="none", 82 | num_gn_groups=32, 83 | act="leaky_relu", 84 | num_classes=1, 85 | kernel_size=1, 86 | num_points=1, 87 | norm_input=False, 88 | dropout=False, 89 | point_bias=True, 90 | ): 91 | super().__init__() 92 | self.norm = get_norm(norm, feat_dim, num_gn_groups=num_gn_groups) 93 | self.act_func = act_func = get_nn_act_func(act) 94 | self.num_classes = num_classes 95 | self.rot_dim = rot_dim 96 | 97 | self.layers = nn.ModuleList() 98 | 99 | if norm_input: 100 | self.layers.append(nn.BatchNorm1d(in_dim)) 101 | for _i in range(num_layers): 102 | _in_dim = in_dim if _i == 0 else feat_dim 103 | self.layers.append(nn.Conv1d(_in_dim, feat_dim, kernel_size)) 104 | self.layers.append(get_norm(norm, feat_dim, num_gn_groups=num_gn_groups)) 105 | self.layers.append(act_func) 106 | if dropout: 107 | self.layers.append(nn.Dropout(p=0.2)) 108 | 109 | self.neck = nn.ModuleList() 110 | self.neck.append(nn.Conv1d(feat_dim, rot_dim * num_classes, 1)) 111 | 112 | self.conv_p = nn.Conv1d(num_points, 1, 1, bias=point_bias) 113 | 114 | # init ------------------------------------ 115 | 
self._init_weights() 116 | 117 | def _init_weights(self): 118 | for m in self.modules(): 119 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Conv1d)): 120 | normal_init(m, std=0.001) 121 | elif isinstance(m, (_BatchNorm, nn.GroupNorm)): 122 | constant_init(m, 1) 123 | elif isinstance(m, nn.Linear): 124 | normal_init(m, std=0.001) 125 | 126 | def forward(self, x): 127 | for _layer in self.layers: 128 | x = _layer(x) 129 | 130 | for _layer in self.neck: 131 | x = _layer(x) 132 | 133 | feat = x.clone() 134 | x = x.permute(0, 2, 1) 135 | x = self.conv_p(x) 136 | 137 | x = x.squeeze(1) 138 | x = x.contiguous() 139 | 140 | return x, feat 141 | 142 | 143 | if __name__ == "__main__": 144 | points = torch.rand(8, 1088, 1024 + 32) # bs x feature x num_p 145 | rot_head = ConvOutPerRotHead(in_dim=1088, num_points=1024 + 32) 146 | rot, feat = rot_head(points) 147 | print(rot.shape) 148 | print(feat.shape) 149 | -------------------------------------------------------------------------------- /core/catre/models/heads/fc_trans_size_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.modules.batchnorm import _BatchNorm 4 | from mmcv.cnn import normal_init, constant_init 5 | from lib.torch_utils.layers.layer_utils import get_norm, get_nn_act_func 6 | from lib.torch_utils.layers.conv_module import ConvModule 7 | 8 | 9 | class FC_TransSizeHead(nn.Module): 10 | def __init__( 11 | self, 12 | in_dim=1024, 13 | feat_dim=256, 14 | num_layers=2, 15 | rot_dim=4, 16 | norm="none", 17 | num_gn_groups=32, 18 | act="leaky_relu", 19 | num_classes=1, 20 | norm_input=False, 21 | dropout=False, 22 | ): 23 | """ 24 | rot_dim: 4 for quaternion, 6 for rot6d 25 | num_classes: default 1 (either single class or class-agnostic) 26 | """ 27 | super().__init__() 28 | self.norm = get_norm(norm, feat_dim, num_gn_groups=num_gn_groups) 29 | self.act_func = act_func = get_nn_act_func(act) 30 | self.num_classes = num_classes 31 | self.rot_dim = rot_dim 32 | 33 | self.linears = nn.ModuleList() 34 | if norm_input: 35 | self.linears.append(nn.BatchNorm1d(in_dim)) 36 | for _i in range(num_layers): 37 | _in_dim = in_dim if _i == 0 else feat_dim 38 | self.linears.append(nn.Linear(_in_dim, feat_dim)) 39 | self.linears.append(get_norm(norm, feat_dim, num_gn_groups=num_gn_groups)) 40 | self.linears.append(act_func) 41 | if dropout: 42 | self.linears.append(nn.Dropout(p=0.5)) 43 | 44 | self.fc_t = nn.Linear(feat_dim, 3 * num_classes) 45 | self.fc_s = nn.Linear(feat_dim, 3 * num_classes) 46 | 47 | # init ------------------------------------ 48 | self._init_weights() 49 | 50 | def _init_weights(self): 51 | for m in self.modules(): 52 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Conv1d)): 53 | normal_init(m, std=0.001) 54 | elif isinstance(m, (_BatchNorm, nn.GroupNorm)): 55 | constant_init(m, 1) 56 | elif isinstance(m, nn.Linear): 57 | normal_init(m, std=0.001) 58 | normal_init(self.fc_t, std=0.01) 59 | normal_init(self.fc_s, std=0.01) 60 | 61 | def forward(self, x): 62 | """ 63 | x: should be flattened 64 | """ 65 | for _layer in self.linears: 66 | x = _layer(x) 67 | 68 | trans = self.fc_t(x) 69 | scale = self.fc_s(x) 70 | return trans, scale 71 | -------------------------------------------------------------------------------- /core/catre/models/net_factory.py: -------------------------------------------------------------------------------- 1 | from .pointnets.pointnet import PointNetfeat 2 | 3 | from .heads.fc_trans_size_head import 
FC_TransSizeHead 4 | from .heads.conv_out_per_rot_head import ConvOutPerRotHead 5 | 6 | PCLNETS = { 7 | "point_net": PointNetfeat, 8 | } 9 | 10 | HEADS = { 11 | "FC_TransSizeHead": FC_TransSizeHead, 12 | "ConvOutPerRotHead": ConvOutPerRotHead, 13 | } 14 | -------------------------------------------------------------------------------- /core/catre/models/pose_scale_from_delta_init.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from core.utils.utils import ( 4 | allo_to_ego_mat_torch, 5 | ) 6 | 7 | 8 | def pose_scale_from_delta_init( 9 | rot_deltas, 10 | trans_deltas, 11 | scale_deltas, 12 | rot_inits, 13 | trans_inits, 14 | scale_inits, 15 | Ks=None, 16 | K_aware=False, 17 | delta_T_space="3D", 18 | delta_T_weight=1.0, 19 | delta_z_style="cosypose", 20 | eps=1e-4, 21 | is_allo=False, 22 | scale_type="add_iter", 23 | ): 24 | """ 25 | Args: 26 | rot_deltas: [b,3,3] 27 | trans_deltas: [b,3], vxvyvz, delta translations in image space 28 | rot_inits: [b,3,3] 29 | trans_inits: [b,3] 30 | Ks: if None, using constants 1 31 | otherwise use zoomed Ks 32 | K_aware: whether to use zoomed K 33 | delta_T_space: image | 3D 34 | delta_T_weight: deepim-pytorch uses 0.1, default 1.0 35 | delta_z_style: cosypose (_vz = ztgt / zsrc) | deepim (vz = log(zrsc/ztgt)) 36 | eps: 37 | is_allo: 38 | Returns: 39 | rot_tgts, trans_tgts 40 | """ 41 | bs = rot_deltas.shape[0] 42 | assert rot_deltas.shape == (bs, 3, 3) 43 | assert rot_inits.shape == (bs, 3, 3) 44 | assert trans_deltas.shape == (bs, 3) 45 | assert trans_inits.shape == (bs, 3) 46 | 47 | # trans============================================ 48 | trans_deltas = trans_deltas * delta_T_weight 49 | 50 | if delta_T_space == "image": 51 | # Translation in image space 52 | zsrc = trans_inits[:, [2]] # [b,1] 53 | vz = trans_deltas[:, [2]] # [b,1] 54 | if delta_z_style == "cosypose": 55 | # NOTE: directly predict vz = 1/exp(_vz) 56 | # log(zsrc/ztgt) = _vz ==> ztgt = 1/exp(_vz) * zsrc 57 | ztgt = vz * zsrc # [b,1] 58 | else: # deepim 59 | # vz = log(zsrc/ztgt) ==> ztgt = zsrc / exp(vz) 60 | ztgt = torch.div(zsrc, torch.exp(vz)) # [b,1] 61 | 62 | if K_aware: 63 | assert Ks is not None and Ks.shape == (bs, 3, 3) 64 | vxvy = trans_deltas[:, :2] # [b,2] 65 | fxfy = Ks[:, [0, 1], [0, 1]] # [b,2] 66 | else: # deepim: treat fx, fy as 1 67 | vxvy = trans_deltas[:, :2] # [b,2] 68 | fxfy = torch.ones_like(vxvy) 69 | 70 | xy_src = trans_inits[:, :2] # [b,2] 71 | xy_tgt = ztgt * (vxvy / fxfy + xy_src / zsrc) # [b,2] 72 | trans_tgts = torch.cat([xy_tgt, ztgt], dim=-1) # [b,3] 73 | elif delta_T_space == "3D": 74 | trans_tgts = trans_inits + trans_deltas 75 | else: 76 | raise ValueError("Unknown delta_T_space: {}".format(delta_T_space)) 77 | 78 | # scale ========================================= 79 | if "add" in scale_type: 80 | scale_tgts = scale_inits + scale_deltas 81 | else: 82 | # NOTE: add exp to make scale_deltas zero-centered 83 | # scale_deltas =: log(s/mean_s) 84 | scale_tgts = scale_inits * torch.exp(scale_deltas) 85 | 86 | # rot =========================================== 87 | if is_allo: 88 | ego_rot_deltas = allo_to_ego_mat_torch(trans_tgts, rot_deltas, eps=eps) 89 | else: 90 | ego_rot_deltas = rot_deltas 91 | 92 | # Rotation in camera frame 93 | rot_tgts = ego_rot_deltas @ rot_inits 94 | 95 | return rot_tgts, trans_tgts, scale_tgts 96 | -------------------------------------------------------------------------------- /core/catre/test_catre.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # test 3 | set -x 4 | this_dir=$(dirname "$0") 5 | # commonly used opts: 6 | 7 | # MODEL.WEIGHTS: resume or pretrained, or test checkpoint 8 | CFG=$1 9 | CUDA_VISIBLE_DEVICES=$2 10 | IFS=',' read -ra GPUS <<< "$CUDA_VISIBLE_DEVICES" 11 | # GPUS=($(echo "$CUDA_VISIBLE_DEVICES" | tr ',' '\n')) 12 | NGPU=${#GPUS[@]} # echo "${GPUS[0]}" 13 | echo "use gpu ids: $CUDA_VISIBLE_DEVICES num gpus: $NGPU" 14 | CKPT=$3 15 | if [ ! -f "$CKPT" ]; then 16 | echo "$CKPT does not exist." 17 | exit 1 18 | fi 19 | NCCL_DEBUG=INFO 20 | OMP_NUM_THREADS=1 21 | MKL_NUM_THREADS=1 22 | PYTHONPATH="$this_dir/../..":$PYTHONPATH \ 23 | CUDA_VISIBLE_DEVICES=$2 python $this_dir/main_catre.py \ 24 | --config-file $CFG --num-gpus $NGPU --eval-only \ 25 | --opts MODEL.WEIGHTS=$CKPT \ 26 | ${@:4} 27 | -------------------------------------------------------------------------------- /core/catre/tools/camera25_prepare_spd_init_results.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import mmcv 3 | import os.path as osp 4 | import glob 5 | import sys 6 | from tqdm import tqdm 7 | import setproctitle 8 | 9 | cur_dir = osp.dirname(osp.abspath(__file__)) 10 | PROJ_ROOT = osp.normpath(osp.join(cur_dir, "../../../")) 11 | sys.path.insert(0, PROJ_ROOT) 12 | from lib.pysixd import inout, misc 13 | from lib.utils.mask_utils import binary_mask_to_rle 14 | 15 | setproctitle.setproctitle(osp.basename(__file__).split(".")[0]) 16 | 17 | data_root = osp.join(PROJ_ROOT, "datasets/NOCS") 18 | 19 | # original spd results 20 | spd_pose_dir = osp.join(PROJ_ROOT, "datasets/NOCS/deformnet_eval/eval_camera") # in bbnc11 21 | spd_seg_dir = osp.join(PROJ_ROOT, "datasets/NOCS/deformnet_eval/mrcnn_results/val") 22 | 23 | # our format 24 | init_pose_dir = osp.join(data_root, "test_init_poses") 25 | mmcv.mkdir_or_exist(init_pose_dir) 26 | init_pose_path = osp.join(init_pose_dir, "init_pose_spd_nocs_cmra.json") 27 | 28 | if __name__ == "__main__": 29 | results = {} 30 | 31 | CACHED = False 32 | if CACHED: 33 | results = mmcv.load(init_pose_path) 34 | else: 35 | spd_pose_paths = glob.glob(osp.join(spd_pose_dir, "results*.pkl")) 36 | 37 | num_total = 0 38 | for idx, spd_pose_path in enumerate(tqdm(spd_pose_paths)): 39 | preds = mmcv.load(spd_pose_path) 40 | bboxes = preds["pred_bboxes"] 41 | scores = preds["pred_scores"] 42 | poses = preds["pred_RTs"][:, :3] 43 | pred_scales = preds["pred_scales"] 44 | class_ids = preds["pred_class_ids"] 45 | mug_handles = preds["gt_handle_visibility"] 46 | 47 | scene_id, im_id = spd_pose_path.split("/")[-1].split(".")[0].split("_")[-2:] 48 | scene_im_id = f"{scene_id}/{im_id}" 49 | 50 | seg_path = osp.join(spd_seg_dir, f"results_val_{scene_id}_{im_id}.pkl") 51 | assert osp.exists(seg_path), seg_path 52 | 53 | masks = mmcv.load(seg_path)["masks"].astype("int") # bool -> int 54 | assert masks.shape[2] == len(class_ids) 55 | results[scene_im_id] = [] 56 | i = 0 57 | for class_id, pose, scale, score, bbox, mug_handle in zip( 58 | class_ids, poses, pred_scales, scores, bboxes, mug_handles 59 | ): 60 | # [sR -> R], normed_scale -> scale 61 | R = pose[:3, :3] 62 | if R.tolist() == [[1, 0, 0], [0, 1, 0], [0, 0, 1]]: 63 | # print("ill pose") 64 | i += 1 65 | continue 66 | nocs_scale = pow(np.linalg.det(R), 1 / 3) 67 | abs_scale = scale * nocs_scale 68 | pose[:3, :3] = R / nocs_scale 69 | # mask2rle 70 | mask = masks[:, :, i] 71 | mask_rle = 
binary_mask_to_rle(mask) 72 | y1, x1, y2, x2 = bbox.tolist() 73 | bbox = [x1, y1, x2, y2] 74 | cur_res = { 75 | "obj_id": int(class_id), 76 | "pose_est": pose.tolist(), 77 | "scale_est": abs_scale.tolist(), 78 | "bbox_est": bbox, 79 | "score": float(score), 80 | "mug_handle": int(mug_handle), 81 | "segmentation": mask_rle, 82 | } 83 | results[scene_im_id].append(cur_res) 84 | i += 1 85 | 86 | print(init_pose_path) 87 | inout.save_json(init_pose_path, results, sort=False) 88 | 89 | VIS = False 90 | if VIS: 91 | from core.utils.data_utils import read_image_mmcv 92 | from lib.utils.mask_utils import cocosegm2mask 93 | from lib.vis_utils.image import grid_show, heatmap 94 | from core.catre.engine.test_utils import get_3d_bbox 95 | import ref 96 | 97 | for scene_im_id, r in results.items(): 98 | img_path = f"datasets/NOCS/REAL/real_test/{scene_im_id}_color.png" 99 | img = read_image_mmcv(img_path, format="BGR") 100 | K = ref.nocs.real_intrinsics 101 | anno = r[0] 102 | imH, imW = img.shape[:2] 103 | mask = cocosegm2mask(anno["segmentation"], imH, imW) 104 | pose = np.array(anno["pose_est"]).reshape(3, 4) 105 | scale = np.array(anno["scale_est"]) 106 | bbox = get_3d_bbox(scale).transpose() 107 | kpts_2d = misc.project_pts(bbox, K, pose[:, :3], pose[:, 3]) 108 | img_vis_kpts2d = misc.draw_projected_box3d(img.copy(), kpts_2d) 109 | grid_show([img[:, :, ::-1], img_vis_kpts2d, mask], row=3, col=1) 110 | -------------------------------------------------------------------------------- /core/catre/tools/prepare_spd_init_results.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import mmcv 3 | import os.path as osp 4 | import glob 5 | import sys 6 | from tqdm import tqdm 7 | import setproctitle 8 | 9 | cur_dir = osp.dirname(osp.abspath(__file__)) 10 | PROJ_ROOT = osp.normpath(osp.join(cur_dir, "../../../")) 11 | sys.path.insert(0, PROJ_ROOT) 12 | from lib.pysixd import inout, misc 13 | from lib.utils.mask_utils import binary_mask_to_rle 14 | 15 | setproctitle.setproctitle(osp.basename(__file__).split(".")[0]) 16 | 17 | data_root = osp.join(PROJ_ROOT, "datasets/NOCS") 18 | 19 | # original spd results 20 | spd_pose_dir = osp.join(PROJ_ROOT, "datasets/NOCS/deformnet_eval/eval_real") 21 | spd_seg_dir = osp.join(PROJ_ROOT, "datasets/NOCS/deformnet_eval/mrcnn_results/real_test") 22 | 23 | # our format 24 | init_pose_dir = osp.join(data_root, "test_init_poses") 25 | mmcv.mkdir_or_exist(init_pose_dir) 26 | init_pose_path = osp.join(init_pose_dir, "init_pose_spd_nocs_real.json") 27 | 28 | 29 | if __name__ == "__main__": 30 | results = {} 31 | 32 | CACHED = False 33 | if CACHED: 34 | results = mmcv.load(init_pose_path) 35 | else: 36 | spd_pose_paths = glob.glob(osp.join(spd_pose_dir, "results*.pkl")) 37 | 38 | num_total = 0 39 | for idx, spd_pose_path in enumerate(tqdm(spd_pose_paths)): 40 | preds = mmcv.load(spd_pose_path) 41 | bboxes = preds["pred_bboxes"] 42 | scores = preds["pred_scores"] 43 | poses = preds["pred_RTs"][:, :3] 44 | pred_scales = preds["pred_scales"] 45 | class_ids = preds["pred_class_ids"] 46 | mug_handles = preds["gt_handle_visibility"] 47 | 48 | scene_id, im_id = spd_pose_path.split("/")[-1].split(".")[0].split("_")[-2:] 49 | scene_im_id = f"scene_{scene_id}/{im_id}" 50 | 51 | seg_path = osp.join(spd_seg_dir, f"results_test_scene_{scene_id}_{im_id}.pkl") 52 | assert osp.exists(seg_path), seg_path 53 | 54 | masks = mmcv.load(seg_path)["masks"].astype("int") # bool -> int 55 | assert masks.shape[2] == len(class_ids) 56 | 
results[scene_im_id] = [] 57 | i = 0 58 | for class_id, pose, scale, score, bbox, mug_handle in zip( 59 | class_ids, poses, pred_scales, scores, bboxes, mug_handles 60 | ): 61 | # [sR -> R], normed_scale -> scale 62 | R = pose[:3, :3] 63 | nocs_scale = pow(np.linalg.det(R), 1 / 3) 64 | abs_scale = scale * nocs_scale 65 | pose[:3, :3] = R / nocs_scale 66 | # mask2rle 67 | mask = masks[:, :, i] 68 | mask_rle = binary_mask_to_rle(mask) 69 | y1, x1, y2, x2 = bbox.tolist() 70 | bbox = [x1, y1, x2, y2] 71 | cur_res = { 72 | "obj_id": int(class_id), 73 | "pose_est": pose.tolist(), 74 | "scale_est": abs_scale.tolist(), 75 | "bbox_est": bbox, 76 | "score": float(score), 77 | "mug_handle": int(mug_handle), 78 | "segmentation": mask_rle, 79 | } 80 | results[scene_im_id].append(cur_res) 81 | i += 1 82 | 83 | print(init_pose_path) 84 | inout.save_json(init_pose_path, results, sort=False) 85 | 86 | VIS = False 87 | if VIS: 88 | from core.utils.data_utils import read_image_mmcv 89 | from lib.utils.mask_utils import cocosegm2mask 90 | from lib.vis_utils.image import grid_show, heatmap 91 | from core.catre.engine.test_utils import get_3d_bbox 92 | import ref 93 | 94 | for scene_im_id, r in results.items(): 95 | img_path = f"datasets/NOCS/REAL/real_test/{scene_im_id}_color.png" 96 | img = read_image_mmcv(img_path, format="BGR") 97 | K = ref.nocs.real_intrinsics 98 | anno = r[0] 99 | imH, imW = img.shape[:2] 100 | mask = cocosegm2mask(anno["segmentation"], imH, imW) 101 | pose = np.array(anno["pose_est"]).reshape(3, 4) 102 | scale = np.array(anno["scale_est"]) 103 | bbox = get_3d_bbox(scale).transpose() 104 | kpts_2d = misc.project_pts(bbox, K, pose[:, :3], pose[:, 3]) 105 | img_vis_kpts2d = misc.draw_projected_box3d(img.copy(), kpts_2d) 106 | grid_show([img[:, :, ::-1], img_vis_kpts2d, mask], row=3, col=1) 107 | -------------------------------------------------------------------------------- /core/catre/train_catre.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -x 3 | this_dir=$(dirname "$0") 4 | # commonly used opts: 5 | 6 | # MODEL.WEIGHTS: resume or pretrained, or test checkpoint 7 | CFG=$1 8 | CUDA_VISIBLE_DEVICES=$2 9 | IFS=',' read -ra GPUS <<< "$CUDA_VISIBLE_DEVICES" 10 | # GPUS=($(echo "$CUDA_VISIBLE_DEVICES" | tr ',' '\n')) 11 | NGPU=${#GPUS[@]} # echo "${GPUS[0]}" 12 | echo "use gpu ids: $CUDA_VISIBLE_DEVICES num gpus: $NGPU" 13 | # CUDA_LAUNCH_BLOCKING=1 14 | NCCL_DEBUG=INFO 15 | OMP_NUM_THREADS=1 16 | MKL_NUM_THREADS=1 17 | PYTHONPATH="$this_dir/../..":$PYTHONPATH \ 18 | CUDA_VISIBLE_DEVICES=$2 python $this_dir/main_catre.py \ 19 | --config-file $CFG --num-gpus $NGPU ${@:3} 20 | -------------------------------------------------------------------------------- /core/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/core/utils/__init__.py -------------------------------------------------------------------------------- /core/utils/apex_trainer.py: -------------------------------------------------------------------------------- 1 | # just a reference implementation, no use 2 | 3 | import logging 4 | import time 5 | import torch 6 | from detectron2.engine import SimpleTrainer 7 | import core.utils.my_comm as comm 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | try: 12 | import apex 13 | from apex import amp 14 | except: 15 | logger.exception("Please install apex from 
https://www.github.com/nvidia/apex to use this ApexTrainer.") 16 | 17 | 18 | class ApexTrainer(SimpleTrainer): 19 | """Like :class:`SimpleTrainer`, but uses NVIDIA's apex automatic mixed 20 | precision in the training loop.""" 21 | 22 | def __init__(self, model, data_loader, optimizer, apex_opt_level="O1"): 23 | """ 24 | Args: 25 | model, data_loader, optimizer: same as in :class:`SimpleTrainer`. 26 | grad_scaler: torch GradScaler to automatically scale gradients. 27 | """ 28 | if comm.get_world_size() > 1: 29 | model, optimizer = amp.initialize(model, optimizer, opt_level=apex_opt_level) 30 | super().__init__(model, data_loader, optimizer) 31 | 32 | def run_step(self): 33 | """Implement the AMP training logic using apex.""" 34 | assert self.model.training, "[ApexTrainer] model was changed to eval mode!" 35 | assert torch.cuda.is_available(), "[ApexTrainer] CUDA is required for AMP training!" 36 | 37 | start = time.perf_counter() 38 | data = next(self._data_loader_iter) 39 | data_time = time.perf_counter() - start 40 | 41 | loss_dict = self.model(data) 42 | if isinstance(loss_dict, torch.Tensor): 43 | losses = loss_dict 44 | loss_dict = {"total_loss": loss_dict} 45 | else: 46 | losses = sum(loss_dict.values()) 47 | 48 | self.optimizer.zero_grad() 49 | with amp.scale_loss(losses, self.optimizer) as scaled_loss: 50 | scaled_loss.backwward() 51 | 52 | self._write_metrics(loss_dict, data_time) 53 | 54 | self.optimizer.step() 55 | 56 | def state_dict(self): 57 | ret = super().state_dict() 58 | # save amp state according to 59 | # https://nvidia.github.io/apex/amp.html#checkpointing 60 | ret["amp"] = amp.state_dict() 61 | return ret 62 | 63 | def load_state_dict(self, state_dict): 64 | super().load_state_dict(state_dict) 65 | if "amp" in state_dict: 66 | amp.load_state_dict(state_dict["amp"]) 67 | -------------------------------------------------------------------------------- /core/utils/camera_geometry.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | 5 | 6 | def get_K_crop_resize(K, crop_xy, resize_ratio): 7 | """ 8 | Args: 9 | K: [b,3,3] 10 | crop_xy: [b, 2] left top of crop boxes 11 | resize_ratio: [b,2] or [b,1] 12 | """ 13 | assert K.shape[1:] == (3, 3) 14 | assert crop_xy.shape[1:] == (2,) 15 | assert resize_ratio.shape[1:] == (2,) or resize_ratio.shape[1:] == (1,) 16 | bs = K.shape[0] 17 | 18 | new_K = K.clone() 19 | new_K[:, [0, 1], 2] = K[:, [0, 1], 2] - crop_xy # [b, 2] 20 | new_K[:, [0, 1]] = new_K[:, [0, 1]] * resize_ratio.view(bs, -1, 1) 21 | return new_K 22 | 23 | 24 | def project_points(points_3d, K, pose, z_min=None): 25 | """ 26 | Args: 27 | points_3d: BxPx3 28 | K: Bx3x3 29 | pose: Bx3x4 30 | z_min: prevent zero devision, eg. 0.1 31 | Returns: 32 | projected 2d points: BxPx2 33 | """ 34 | assert K.shape[-2:] == (3, 3) 35 | assert pose.shape[-2:] == (3, 4) 36 | batch_size = points_3d.shape[0] 37 | n_points = points_3d.shape[1] 38 | device = points_3d.device 39 | if points_3d.shape[-1] == 3: 40 | points_3d = torch.cat((points_3d, torch.ones(batch_size, n_points, 1).to(device)), dim=-1) 41 | P = K @ pose[:, :3] 42 | suv = (P.unsqueeze(1) @ points_3d.unsqueeze(-1)).squeeze(-1) # Bx1x3x4 @ BxPx4x1 -> BxPx3 43 | if z_min is not None: 44 | z = suv[..., -1] 45 | suv[..., -1] = torch.max(torch.ones_like(z) * z_min, z) # eg. 
z_min=0.1 46 | suv = suv / suv[..., [-1]] 47 | return suv[..., :2] # BxPx2 48 | 49 | 50 | def centers_2d_from_t(K, t, z_min=None): 51 | """can also get the centers via projecting the zero point (B,1,3) 52 | Args: 53 | K: Bx3x3 54 | t: Bx3 55 | z_min: to prevent zero division 56 | Returns: 57 | centers_2d: Bx2 58 | """ 59 | assert K.ndim == 3 and K.shape[-2:] == (3, 3), K.shape 60 | bs = K.shape[0] 61 | proj = (K @ t.view(bs, 3, 1)).view(bs, 3) 62 | if z_min is not None: 63 | z = proj[..., -1] 64 | proj[..., -1] = torch.max(torch.ones_like(z) * z_min, z) # eg. z_min=0.1 65 | centers_2d = proj[:, :2] / proj[:, [-1]] # Nx2 66 | return centers_2d 67 | 68 | 69 | def centers_2d_from_pose(K, pose, z_min=None): 70 | """can also get the centers via projecting the zero point (B,1,3) 71 | Args: 72 | K: Bx3x3 73 | pose: Bx3x4 (only use the transltion) 74 | z_min: to prevent zero division 75 | Returns: 76 | centers_2d: Bx2 77 | """ 78 | assert K.ndim == 3 and K.shape[-2:] == (3, 3), K.shape 79 | assert pose.ndim == 3 and pose.shape[-2:] == (3, 4), pose.shape 80 | bs = pose.shape[0] 81 | proj = (K @ pose[:, :3, [3]]).view(bs, 3) # Nx3x3 @ Nx3x1 -> Nx3x1 -> Nx3 82 | if z_min is not None: 83 | z = proj[..., -1] 84 | proj[..., -1] = torch.max(torch.ones_like(z) * z_min, z) # eg. z_min=0.1 85 | centers_2d = proj[:, :2] / proj[:, [-1]] # Nx2 86 | return centers_2d 87 | 88 | 89 | def boxes_from_points_2d(uv): 90 | """ 91 | Args: 92 | uv: BxPx2 projected 2d points 93 | Returns: 94 | Bx4 95 | """ 96 | assert uv.ndim == 3 and uv.shape[-1] == 2, uv.shape 97 | x1 = uv[..., 0].min(dim=1)[0] # (B,) 98 | y1 = uv[..., 1].min(dim=1)[0] 99 | 100 | x2 = uv[..., 0].max(dim=1)[0] 101 | y2 = uv[..., 1].max(dim=1)[0] 102 | 103 | return torch.stack([x1, y1, x2, y2], dim=1) # Bx4 104 | 105 | 106 | def bboxes_from_pose(points_3d, K, pose, z_min=None, imH=480, imW=640, clamp=False): 107 | points_2d = project_points(points_3d, K=K, pose=pose, z_min=z_min) 108 | boxes = boxes_from_points_2d(points_2d) 109 | if clamp: 110 | boxes[:, [0, 2]] = torch.clamp(boxes[:, [0, 2]], 0, imW - 1) 111 | boxes[:, [1, 3]] = torch.clamp(boxes[:, [1, 3]], 0, imH - 1) 112 | return boxes 113 | 114 | 115 | def adapt_image_by_K( 116 | image, *, K_old, K_new, interpolation=cv2.INTER_LINEAR, border_type=cv2.BORDER_REFLECT, height=480, width=640 117 | ): 118 | """adapt image from old K to new K.""" 119 | H_old, W_old = image.shape[:2] 120 | K_old = K_old.copy() 121 | K_old[0, :] = K_old[0, :] / W_old * width 122 | K_old[1, :] = K_old[1, :] / H_old * height 123 | 124 | focal_scale_x = K_new[0, 0] / K_old[0, 0] 125 | focal_scale_y = K_new[1, 1] / K_old[1, 1] 126 | ox, oy = K_new[0, 2] - K_old[0, 2], K_new[1, 2] - K_old[1, 2] 127 | 128 | image = cv2.resize( 129 | image, 130 | (int(width * focal_scale_x), int(height * focal_scale_y)), 131 | interpolation=interpolation, 132 | ) 133 | 134 | image = cv2.copyMakeBorder(image, 200, 200, 200, 200, borderType=border_type) 135 | # print(image.shape) 136 | y1 = int(round(image.shape[0] / 2 - oy - height / 2)) 137 | y2 = int(round(image.shape[0] / 2 - oy + height / 2)) 138 | x1 = int(round(image.shape[1] / 2 - ox - width / 2)) 139 | x2 = int(round(image.shape[1] / 2 - ox + width / 2)) 140 | # print(x1, y1, x2, y2) 141 | return image[y1:y2, x1:x2] 142 | -------------------------------------------------------------------------------- /core/utils/depth_aug.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | 4 | 5 | def add_noise_depth(depth, 
level=0.005, depth_valid_min=0): 6 | # from DeepIM-PyTorch and se3tracknet 7 | # in deepim: level=0.1, valid_min=0 8 | # in se3tracknet, level=5/1000, depth_valid_min = 100/1000 = 0.1 9 | 10 | if len(depth.shape) == 3: 11 | mask = depth[:, :, -1] > depth_valid_min 12 | row, col, ch = depth.shape 13 | noise_level = random.uniform(0, level) 14 | gauss = noise_level * np.random.randn(row, col) 15 | gauss = np.repeat(gauss[:, :, np.newaxis], ch, axis=2) 16 | else: # 2 17 | mask = depth > depth_valid_min 18 | row, col = depth.shape 19 | noise_level = random.uniform(0, level) 20 | gauss = noise_level * np.random.randn(row, col) 21 | noisy = depth.copy() 22 | noisy[mask] = depth[mask] + gauss[mask] 23 | return noisy 24 | 25 | 26 | if __name__ == "__main__": 27 | from lib.vis_utils.image import heatmap, grid_show 28 | import mmcv 29 | import cv2 30 | from skimage.restoration import denoise_bilateral 31 | 32 | # depth = mmcv.imread("datasets/BOP_DATASETS/ycbv/train_pbr/000000/depth/000000.png", "unchanged") / 10000.0 33 | # depth_aug = add_noise_depth(depth, level=0.005, depth_valid_min=0.1) 34 | # diff = depth_aug - depth 35 | # grid_show([ 36 | # heatmap(depth, to_rgb=True), heatmap(depth_aug, to_rgb=True), 37 | # heatmap(diff, to_rgb=True) 38 | # ], ["depth", "depth_aug", "diff"], row=1, col=3) 39 | 40 | depth = (mmcv.imread("datasets/BOP_DATASETS/ycbv/test/000048/depth/000001.png", "unchanged") / 10000.0).astype( 41 | "float32" 42 | ) 43 | # diameter, pix_sigma, space_sigma 44 | depth_aug = cv2.bilateralFilter(depth, 11, 0.1, 30) 45 | # depth_aug = denoise_bilateral(depth, sigma_color=0.05, sigma_spatial=15) 46 | diff = depth_aug - depth 47 | grid_show( 48 | [heatmap(depth, to_rgb=True), heatmap(depth_aug, to_rgb=True), heatmap(diff, to_rgb=True)], 49 | ["depth", "depth_aug", "diff"], 50 | row=1, 51 | col=3, 52 | ) 53 | -------------------------------------------------------------------------------- /core/utils/farthest_points_torch.py: -------------------------------------------------------------------------------- 1 | # https://github.com/NVlabs/latentfusion/blob/master/latentfusion/three/utils.py 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | 6 | def farthest_points( 7 | data, 8 | n_clusters: int, 9 | dist_func=F.pairwise_distance, 10 | return_center_indexes=True, 11 | return_distances=False, 12 | verbose=False, 13 | init_center=True, 14 | ): 15 | """Performs farthest point sampling on data points. 16 | 17 | Args: 18 | data (torch.tensor): data points. 19 | n_clusters (int): number of clusters. 20 | dist_func (Callable): distance function that is used to compare two data points. 21 | return_center_indexes (bool): if True, returns the indexes of the center of clusters. 22 | return_distances (bool): if True, return distances of each point from centers. 23 | Returns clusters, [centers, distances]: 24 | clusters (torch.tensor): the cluster index for each element in data. 25 | centers (torch.tensor): the integer index of each center. 26 | distances (torch.tensor): closest distances of each point to any of the cluster centers. 
27 | """ 28 | if n_clusters >= data.shape[0]: 29 | if return_center_indexes: 30 | return ( 31 | torch.arange(data.shape[0], dtype=torch.long), 32 | torch.arange(data.shape[0], dtype=torch.long), 33 | ) 34 | 35 | return torch.arange(data.shape[0], dtype=torch.long) 36 | 37 | clusters = torch.full((data.shape[0],), fill_value=-1, dtype=torch.long) 38 | centers = torch.zeros(n_clusters, dtype=torch.long) 39 | 40 | if init_center: 41 | broadcasted_data = torch.mean(data, 0, keepdim=True).expand(data.shape[0], -1) 42 | distances = dist_func(broadcasted_data, data) 43 | else: 44 | distances = torch.full((data.shape[0],), fill_value=1e7, dtype=torch.float32) 45 | 46 | for i in range(n_clusters): 47 | center_idx = torch.argmax(distances) 48 | centers[i] = center_idx 49 | 50 | broadcasted_data = data[center_idx].unsqueeze(0).expand(data.shape[0], -1) 51 | new_distances = dist_func(broadcasted_data, data) 52 | distances = torch.min(distances, new_distances) 53 | clusters[distances == new_distances] = i 54 | if verbose: 55 | print("farthest points max distance : {}".format(torch.max(distances))) 56 | 57 | if return_center_indexes: 58 | if return_distances: 59 | return clusters, centers, distances 60 | return clusters, centers 61 | 62 | return clusters 63 | 64 | 65 | def get_fps_and_center_torch(points, num_fps: int, init_center=True, dist_func=F.pairwise_distance): 66 | center = torch.mean(points, 0, keepdim=True) 67 | _, fps_inds = farthest_points( 68 | points, 69 | n_clusters=num_fps, 70 | dist_func=dist_func, 71 | return_center_indexes=True, 72 | init_center=init_center, 73 | ) 74 | fps_pts = points[fps_inds] 75 | return torch.cat([fps_pts, center], dim=0) 76 | -------------------------------------------------------------------------------- /core/utils/my_setup.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os.path as osp 3 | 4 | 5 | def get_project_root(): 6 | cur_dir = osp.dirname(osp.abspath(__file__)) 7 | proj_root = osp.normpath(osp.join(cur_dir, "../../")) 8 | return proj_root 9 | 10 | 11 | PROJ_ROOT = get_project_root() 12 | 13 | 14 | def get_data_root(): 15 | proj_root = get_project_root() 16 | return osp.join(proj_root, "datasets") 17 | 18 | 19 | DATA_ROOT = get_data_root() 20 | 21 | 22 | def setup_for_distributed(is_master): 23 | """This function disables printing when not in master process.""" 24 | import builtins as __builtin__ 25 | 26 | builtin_print = __builtin__.print 27 | if not is_master: 28 | logging.getLogger("core").setLevel("WARN") 29 | logging.getLogger("d2").setLevel("WARN") 30 | logging.getLogger("lib").setLevel("WARN") 31 | logging.getLogger("my").setLevel("WARN") 32 | 33 | def print(*args, **kwargs): 34 | force = kwargs.pop("force", False) 35 | if is_master or force: 36 | builtin_print(*args, **kwargs) 37 | 38 | __builtin__.print = print 39 | -------------------------------------------------------------------------------- /core/utils/ssd_color_transform.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # ref: https://github.com/facebookresearch/detectron2/blob/master/projects/PointRend/point_rend/color_augmentation.py 3 | 4 | import numpy as np 5 | import random 6 | import cv2 7 | from fvcore.transforms.transform import Transform 8 | 9 | 10 | class ColorAugSSDTransform(Transform): 11 | """ 12 | A color related data augmentation used in Single Shot Multibox Detector (SSD). 
13 | Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, 14 | Scott Reed, Cheng-Yang Fu, Alexander C. Berg. 15 | SSD: Single Shot MultiBox Detector. ECCV 2016. 16 | Implementation based on: 17 | https://github.com/weiliu89/caffe/blob 18 | /4817bf8b4200b35ada8ed0dc378dceaf38c539e4 19 | /src/caffe/util/im_transforms.cpp 20 | https://github.com/chainer/chainercv/blob 21 | /7159616642e0be7c5b3ef380b848e16b7e99355b/chainercv 22 | /links/model/ssd/transforms.py 23 | """ 24 | 25 | def __init__( 26 | self, 27 | img_format, 28 | brightness_delta=32, 29 | contrast_low=0.5, 30 | contrast_high=1.5, 31 | saturation_low=0.5, 32 | saturation_high=1.5, 33 | hue_delta=18, 34 | ): 35 | super().__init__() 36 | assert img_format in ["BGR", "RGB"] 37 | self.is_rgb = img_format == "RGB" 38 | del img_format 39 | self._set_attributes(locals()) 40 | 41 | def apply_coords(self, coords): 42 | return coords 43 | 44 | def apply_segmentation(self, segmentation): 45 | return segmentation 46 | 47 | def apply_image(self, img, interp=None): 48 | if self.is_rgb: 49 | img = img[:, :, [2, 1, 0]] 50 | img = self.brightness(img) 51 | if random.randrange(2): 52 | img = self.contrast(img) 53 | img = self.saturation(img) 54 | img = self.hue(img) 55 | else: 56 | img = self.saturation(img) 57 | img = self.hue(img) 58 | img = self.contrast(img) 59 | if self.is_rgb: 60 | img = img[:, :, [2, 1, 0]] 61 | return img 62 | 63 | def convert(self, img, alpha=1, beta=0): 64 | img = img.astype(np.float32) * alpha + beta 65 | img = np.clip(img, 0, 255) 66 | return img.astype(np.uint8) 67 | 68 | def brightness(self, img): 69 | if random.randrange(2): 70 | return self.convert( 71 | img, 72 | beta=random.uniform(-self.brightness_delta, self.brightness_delta), 73 | ) 74 | return img 75 | 76 | def contrast(self, img): 77 | if random.randrange(2): 78 | return self.convert( 79 | img, 80 | alpha=random.uniform(self.contrast_low, self.contrast_high), 81 | ) 82 | return img 83 | 84 | def saturation(self, img): 85 | if random.randrange(2): 86 | img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) 87 | img[:, :, 1] = self.convert( 88 | img[:, :, 1], 89 | alpha=random.uniform(self.saturation_low, self.saturation_high), 90 | ) 91 | return cv2.cvtColor(img, cv2.COLOR_HSV2BGR) 92 | return img 93 | 94 | def hue(self, img): 95 | if random.randrange(2): 96 | img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) 97 | img[:, :, 0] = (img[:, :, 0].astype(int) + random.randint(-self.hue_delta, self.hue_delta)) % 180 98 | return cv2.cvtColor(img, cv2.COLOR_HSV2BGR) 99 | return img 100 | -------------------------------------------------------------------------------- /core/utils/timm_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import timm 4 | import pathlib 5 | 6 | _logger = logging.getLogger(__name__) 7 | 8 | 9 | def my_create_timm_model(**init_args): 10 | # HACK: fix the bug for feature_only=True and checkpoint_path != "" 11 | # https://github.com/rwightman/pytorch-image-models/issues/488 12 | if init_args.get("checkpoint_path", "") != "" and init_args.get("features_only", True): 13 | init_args = copy.deepcopy(init_args) 14 | full_model_name = init_args["model_name"] 15 | modules = timm.models.list_modules() 16 | # find the mod which has the longest common name in model_name 17 | mod_len = 0 18 | for m in modules: 19 | if m in full_model_name: 20 | cur_mod_len = len(m) 21 | if cur_mod_len > mod_len: 22 | mod = m 23 | mod_len = cur_mod_len 24 | if mod_len >= 1: 25 | if 
hasattr(timm.models.__dict__[mod], "default_cfgs"): 26 | ckpt_path = init_args.pop("checkpoint_path") 27 | ckpt_url = pathlib.Path(ckpt_path).resolve().as_uri() 28 | _logger.warning(f"hacking model pretrained url to {ckpt_url}") 29 | timm.models.__dict__[mod].default_cfgs[full_model_name]["url"] = ckpt_url 30 | init_args["pretrained"] = True 31 | else: 32 | raise ValueError(f"model_name {full_model_name} has no module in timm") 33 | 34 | backbone = timm.create_model(**init_args) 35 | return backbone 36 | -------------------------------------------------------------------------------- /core/utils/zoom_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from detectron2.layers.roi_align import ROIAlign 3 | from torchvision.ops import RoIPool 4 | 5 | 6 | def deepim_boxes( 7 | ren_boxes, 8 | ren_centers_2d, 9 | obs_boxes=None, 10 | lamb=1.4, 11 | imHW=(480, 640), 12 | outHW=(480, 640), 13 | clamp=False, 14 | ): 15 | """ 16 | Args: 17 | ren_boxes: N x 4 18 | ren_centers_2d: Nx2, rendered object center is the crop center 19 | obs_boxes: N x 4, if None, only use the rendered boxes/centers to determine the crop region 20 | lamb: enlarge the scale of cropped region 21 | imH (int): 22 | imW (int): 23 | Returns: 24 | crop_boxes (Tensor): N x 4, either the common region from obs/ren or just obs 25 | resize_ratios (Tensor): Nx2, resize ratio of (w,h), actually the same in w,h because we keep the aspect ratio 26 | """ 27 | ren_x1, ren_y1, ren_x2, ren_y2 = (ren_boxes[:, i] for i in range(4)) # (N,) 28 | ren_cx = ren_centers_2d[:, 0] # (N,) 29 | ren_cy = ren_centers_2d[:, 1] # (N,) 30 | 31 | outH, outW = outHW 32 | aspect_ratio = outW / outH # 4/3 or 1 33 | 34 | if obs_boxes is not None: 35 | obs_x1, obs_y1, obs_x2, obs_y2 = (obs_boxes[:, i] for i in range(4)) # (N,) 36 | xdists = torch.stack( 37 | [ 38 | ren_cx - obs_x1, 39 | ren_cx - ren_x1, 40 | obs_x2 - ren_cx, 41 | ren_x2 - ren_cx, 42 | ], 43 | dim=1, 44 | ).abs() 45 | ydists = torch.stack( 46 | [ 47 | ren_cy - obs_y1, 48 | ren_cy - ren_y1, 49 | obs_y2 - ren_cy, 50 | ren_y2 - ren_cy, 51 | ], 52 | dim=1, 53 | ).abs() 54 | else: 55 | xdists = torch.stack([ren_cx - ren_x1, ren_x2 - ren_cx], dim=1).abs() 56 | ydists = torch.stack([ren_cy - ren_y1, ren_y2 - ren_cy], dim=1).abs() 57 | xdist = xdists.max(dim=1)[0] # (N,) 58 | ydist = ydists.max(dim=1)[0] 59 | 60 | crop_h = torch.max(xdist / aspect_ratio, ydist).clamp(min=1) * 2 * lamb # (N,) 61 | crop_w = crop_h * aspect_ratio # (N,) 62 | 63 | x1, y1, x2, y2 = ( 64 | ren_cx - crop_w / 2, 65 | ren_cy - crop_h / 2, 66 | ren_cx + crop_w / 2, 67 | ren_cy + crop_h / 2, 68 | ) 69 | boxes = torch.stack([x1, y1, x2, y2], dim=1) 70 | assert not clamp 71 | if clamp: 72 | imH, imW = imHW 73 | boxes[:, [0, 2]] = torch.clamp(boxes[:, [0, 2]], 0, imW - 1) 74 | boxes[:, [1, 3]] = torch.clamp(boxes[:, [1, 3]], 0, imH - 1) 75 | 76 | resize_ratios = torch.stack([outW / crop_w, outH / crop_h], dim=1) 77 | return boxes, resize_ratios 78 | 79 | 80 | def batch_crop_resize(x, rois, out_H, out_W, aligned=True, interpolation="bilinear"): 81 | """ 82 | Args: 83 | x: BCHW 84 | rois: Bx5, rois[:, 0] is the idx into x 85 | out_H (int): 86 | out_W (int): 87 | """ 88 | output_size = (out_H, out_W) 89 | if interpolation == "bilinear": 90 | op = ROIAlign(output_size, 1.0, 0, aligned=aligned) 91 | elif interpolation == "nearest": 92 | op = RoIPool(output_size, 1.0) # 93 | else: 94 | raise ValueError(f"Wrong interpolation type: {interpolation}") 95 | return op(x, rois) 96 | 
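A minimal usage sketch (not part of the repository) of how the zoom utilities above might be chained with get_K_crop_resize from core/utils/camera_geometry.py: build enlarged crop boxes around rendered boxes, crop/resize the images, and update the intrinsics to match. It assumes the project root is on PYTHONPATH (with detectron2/torchvision installed); the boxes and intrinsics are made-up values for illustration only.

import torch
from core.utils.zoom_utils import deepim_boxes, batch_crop_resize
from core.utils.camera_geometry import get_K_crop_resize

B, outH, outW = 2, 256, 256
images = torch.rand(B, 3, 480, 640)  # BxCxHxW
K = torch.tensor([[591.0, 0.0, 322.5], [0.0, 590.0, 244.5], [0.0, 0.0, 1.0]]).repeat(B, 1, 1)  # Bx3x3
ren_boxes = torch.tensor([[100.0, 120.0, 300.0, 320.0], [200.0, 150.0, 400.0, 330.0]])  # Bx4, (x1,y1,x2,y2)
ren_centers_2d = torch.tensor([[200.0, 220.0], [300.0, 240.0]])  # Bx2, rendered object centers

# enlarged crop boxes around the rendered boxes and the (w,h) ratios to reach outHW
crop_boxes, resize_ratios = deepim_boxes(ren_boxes, ren_centers_2d, lamb=1.4, outHW=(outH, outW))

# rois are Bx5 with the batch index in column 0
batch_inds = torch.arange(B, dtype=torch.float32)[:, None]
rois = torch.cat([batch_inds, crop_boxes], dim=1)
zoomed = batch_crop_resize(images, rois, out_H=outH, out_W=outW)  # BxCxoutHxoutW

# shift the principal point by the crop's left-top corner and scale by the resize ratios
new_K = get_K_crop_resize(K, crop_boxes[:, :2], resize_ratios)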
-------------------------------------------------------------------------------- /datasets/NOCS/REAL/real_test: -------------------------------------------------------------------------------- 1 | /data2/lxy/object_pose_benchmark/datasets/NOCS/REAL/real_test/ -------------------------------------------------------------------------------- /datasets/NOCS/REAL/real_train: -------------------------------------------------------------------------------- 1 | /data2/lxy/object_pose_benchmark/datasets/NOCS/REAL/real_train/ -------------------------------------------------------------------------------- /datasets/NOCS/obj_models/abs_scale.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/datasets/NOCS/obj_models/abs_scale.pkl -------------------------------------------------------------------------------- /datasets/NOCS/obj_models/cr_normed_mean_model_points_spd.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/datasets/NOCS/obj_models/cr_normed_mean_model_points_spd.pkl -------------------------------------------------------------------------------- /datasets/NOCS/obj_models/mug_handle.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/datasets/NOCS/obj_models/mug_handle.pkl -------------------------------------------------------------------------------- /datasets/NOCS/obj_models/mug_meta.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/datasets/NOCS/obj_models/mug_meta.pkl -------------------------------------------------------------------------------- /datasets/NOCS/obj_models/real_test_spd.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/datasets/NOCS/obj_models/real_test_spd.pkl -------------------------------------------------------------------------------- /datasets/NOCS/obj_models/real_train_spd.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/datasets/NOCS/obj_models/real_train_spd.pkl -------------------------------------------------------------------------------- /docs/INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | * CUDA >= 10.1, Ubuntu >= 16.04 4 | 5 | * Python >= 3.6, PyTorch >= 1.9, torchvision 6 | ``` 7 | ## create a new environment 8 | conda create -n py37 python=3.7.4 9 | conda activate py37 # maybe add this line to the end of ~/.bashrc 10 | conda install ipython 11 | ## install pytorch: https://pytorch.org/get-started/locally/ (check cuda version) 12 | pip install torchvision -U # will also install corresponding torch 13 | ``` 14 | 15 | * `detectron2` from [source](https://github.com/facebookresearch/detectron2). 16 | ``` 17 | git clone https://github.com/facebookresearch/detectron2.git 18 | cd detectron2 19 | pip install ninja 20 | pip install -e . 
21 | ``` 22 | 23 | * `sh scripts/install_deps.sh` 24 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/lib/__init__.py -------------------------------------------------------------------------------- /lib/pysixd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/lib/pysixd/__init__.py -------------------------------------------------------------------------------- /lib/pysixd/colors.json: -------------------------------------------------------------------------------- 1 | [ 2 | [0.89, 0.28, 0.13], 3 | [0.45, 0.38, 0.92], 4 | [0.35, 0.73, 0.63], 5 | [0.62, 0.28, 0.91], 6 | [0.65, 0.71, 0.22], 7 | [0.8, 0.29, 0.89], 8 | [0.27, 0.55, 0.22], 9 | [0.37, 0.46, 0.84], 10 | [0.84, 0.63, 0.22], 11 | [0.68, 0.29, 0.71], 12 | [0.48, 0.75, 0.48], 13 | [0.88, 0.27, 0.75], 14 | [0.82, 0.45, 0.2], 15 | [0.86, 0.27, 0.27], 16 | [0.52, 0.49, 0.18], 17 | [0.33, 0.67, 0.25], 18 | [0.67, 0.42, 0.29], 19 | [0.67, 0.46, 0.86], 20 | [0.36, 0.72, 0.84], 21 | [0.85, 0.29, 0.4], 22 | [0.24, 0.53, 0.55], 23 | [0.85, 0.55, 0.8], 24 | [0.4, 0.51, 0.33], 25 | [0.56, 0.38, 0.63], 26 | [0.78, 0.66, 0.46], 27 | [0.33, 0.5, 0.72], 28 | [0.83, 0.31, 0.56], 29 | [0.56, 0.61, 0.85], 30 | [0.89, 0.58, 0.57], 31 | [0.67, 0.4, 0.49] 32 | ] -------------------------------------------------------------------------------- /lib/pysixd/config.py: -------------------------------------------------------------------------------- 1 | # Author: Tomas Hodan (hodantom@cmp.felk.cvut.cz) 2 | # Center for Machine Perception, Czech Technical University in Prague 3 | 4 | """Configuration of the BOP Toolkit.""" 5 | 6 | import os 7 | 8 | 9 | ######## Basic ######## 10 | 11 | # Folder with the BOP datasets. 12 | if "BOP_PATH" in os.environ: 13 | datasets_path = os.environ["BOP_PATH"] 14 | else: 15 | datasets_path = r"datasets/BOP_DATASETS/" 16 | 17 | # Folder with pose results to be evaluated. 18 | # results_path = r'/path/to/folder/with/results' 19 | results_path = r"output/bop_results" 20 | 21 | # Folder for the calculated pose errors and performance scores. 22 | # eval_path = r'/path/to/eval/folder' 23 | eval_path = r"output/bop_eval/" 24 | 25 | ######## Extended ######## 26 | 27 | # Folder for outputs (e.g. visualizations). 28 | # output_path = r'/path/to/output/folder' 29 | output_path = r"output/bop_output/" 30 | 31 | # For offscreen C++ rendering: Path to the build folder of bop_renderer (github.com/thodan/bop_renderer). 32 | bop_renderer_path = r"bop_renderer/build" 33 | 34 | # Executable of the MeshLab server. 35 | # meshlab_server_path = r'/path/to/meshlabserver.exe' 36 | meshlab_server_path = r"/usr/bin/meshlabserver" 37 | -------------------------------------------------------------------------------- /lib/pysixd/renderer.py: -------------------------------------------------------------------------------- 1 | # Author: Tomas Hodan (hodantom@cmp.felk.cvut.cz) 2 | # Center for Machine Perception, Czech Technical University in Prague 3 | 4 | """Abstract class of a renderer and a factory function to create a renderer. 5 | 6 | The renderer produces an RGB/depth image of a 3D mesh model in a 7 | specified pose for given camera parameters and illumination settings. 
8 | """ 9 | 10 | 11 | class Renderer(object): 12 | """Abstract class of a renderer.""" 13 | 14 | def __init__(self, width, height): 15 | """Constructor. 16 | 17 | :param width: Width of the rendered image. 18 | :param height: Height of the rendered image. 19 | """ 20 | self.width = width 21 | self.height = height 22 | 23 | # 3D location of a point light (in the camera coordinates). 24 | self.light_cam_pos = (0, 0, 0) 25 | 26 | # Set light color and weights. 27 | self.light_color = (1.0, 1.0, 1.0) # Used only in C++ renderer. 28 | self.light_ambient_weight = 0.5 29 | self.light_diffuse_weight = 1.0 # Used only in C++ renderer. 30 | self.light_specular_weight = 0.0 # Used only in C++ renderer. 31 | self.light_specular_shininess = 0.0 # Used only in C++ renderer. 32 | 33 | def set_light_cam_pos(self, light_cam_pos): 34 | """Sets the 3D location of a point light. 35 | 36 | :param light_cam_pos: [X, Y, Z]. 37 | """ 38 | self.light_cam_pos = light_cam_pos 39 | 40 | def set_light_ambient_weight(self, light_ambient_weight): 41 | """Sets weight of the ambient light. 42 | 43 | :param light_ambient_weight: Scalar from 0 to 1. 44 | """ 45 | self.light_ambient_weight = light_ambient_weight 46 | 47 | def add_object(self, obj_id, model_path, **kwargs): 48 | """Loads an object model. 49 | 50 | :param obj_id: Object identifier. 51 | :param model_path: Path to the object model file. 52 | """ 53 | raise NotImplementedError 54 | 55 | def remove_object(self, obj_id): 56 | """Removes an object model. 57 | 58 | :param obj_id: Identifier of the object to remove. 59 | """ 60 | raise NotImplementedError 61 | 62 | def render_object(self, obj_id, R, t, fx, fy, cx, cy): 63 | """Renders an object model in the specified pose. 64 | 65 | :param obj_id: Object identifier. 66 | :param R: 3x3 ndarray with a rotation matrix. 67 | :param t: 3x1 ndarray with a translation vector. 68 | :param fx: Focal length (X axis). 69 | :param fy: Focal length (Y axis). 70 | :param cx: The X coordinate of the principal point. 71 | :param cy: The Y coordinate of the principal point. 72 | :return: Returns a dictionary with rendered images. 73 | """ 74 | raise NotImplementedError 75 | 76 | 77 | def create_renderer( 78 | width, 79 | height, 80 | renderer_type="cpp", 81 | mode="rgb+depth", 82 | shading="phong", 83 | bg_color=(0.0, 0.0, 0.0, 0.0), 84 | ): 85 | """A factory to create a renderer. 86 | 87 | Note: Parameters mode, shading and bg_color are currently supported only by 88 | the Python renderer (renderer_type='python'). 89 | 90 | :param width: Width of the rendered image. 91 | :param height: Height of the rendered image. 92 | :param renderer_type: Type of renderer (options: 'cpp', 'python'). 93 | :param mode: Rendering mode ('rgb+depth', 'rgb', 'depth'). 94 | :param shading: Type of shading ('flat', 'phong'). 95 | :param bg_color: Color of the background (R, G, B, A). 96 | :return: Instance of a renderer of the specified type. 97 | """ 98 | if renderer_type == "python": 99 | from . import renderer_py 100 | 101 | return renderer_py.RendererPython(width, height, mode, shading, bg_color) 102 | 103 | elif renderer_type == "cpp": 104 | from . import renderer_cpp 105 | 106 | return renderer_cpp.RendererCpp(width, height) 107 | 108 | elif renderer_type == "pyrender": 109 | from . import renderer_pyrender 110 | 111 | return renderer_pyrender.Renderer(width, height, mode=mode, bg_color=bg_color) 112 | 113 | elif renderer_type == "vispy": 114 | from . 
import renderer_vispy 115 | 116 | return renderer_vispy.RendererVispy(width, height, mode, shading=shading, bg_color=bg_color) 117 | 118 | else: 119 | raise ValueError("Unknown renderer type.") 120 | -------------------------------------------------------------------------------- /lib/pysixd/renderer_cpp.py: -------------------------------------------------------------------------------- 1 | # Author: Tomas Hodan (hodantom@cmp.felk.cvut.cz) 2 | # Center for Machine Perception, Czech Technical University in Prague 3 | 4 | """An interface to the C++ based renderer (bop_renderer).""" 5 | 6 | import sys 7 | import numpy as np 8 | 9 | from lib.pysixd import config 10 | from lib.pysixd import renderer 11 | 12 | # C++ renderer (https://github.com/thodan/bop_renderer) 13 | sys.path.append(config.bop_renderer_path) 14 | import bop_renderer 15 | 16 | 17 | class RendererCpp(renderer.Renderer): 18 | """An interface to the C++ based renderer.""" 19 | 20 | def __init__(self, width, height): 21 | """See base class.""" 22 | super(RendererCpp, self).__init__(width, height) 23 | self.renderer = bop_renderer.Renderer() 24 | self.renderer.init(width, height) 25 | self._set_light() 26 | 27 | def _set_light(self): 28 | self.renderer.set_light( 29 | list(self.light_cam_pos), 30 | list(self.light_color), 31 | self.light_ambient_weight, 32 | self.light_diffuse_weight, 33 | self.light_specular_weight, 34 | self.light_specular_shininess, 35 | ) 36 | 37 | def set_light_cam_pos(self, light_cam_pos): 38 | """See base class.""" 39 | super(RendererCpp, self).set_light_cam_pos(light_cam_pos) 40 | self._set_light() 41 | 42 | def set_light_ambient_weight(self, light_ambient_weight): 43 | """See base class.""" 44 | super(RendererCpp, self).set_light_ambient_weight(light_ambient_weight) 45 | self._set_light() 46 | 47 | def add_object(self, obj_id, model_path, **kwargs): 48 | """See base class. 49 | 50 | NEEDS TO BE CALLED RIGHT AFTER CREATING THE RENDERER (this is 51 | due to some memory issues in the C++ renderer which need to be 52 | fixed). 
53 | """ 54 | self.renderer.add_object(obj_id, model_path) 55 | 56 | def remove_object(self, obj_id): 57 | """See base class.""" 58 | self.renderer.remove_object(obj_id) 59 | 60 | def render_object(self, obj_id, R, t, fx, fy, cx, cy): 61 | """See base class.""" 62 | R_l = R.astype(np.float32).flatten().tolist() 63 | t_l = t.astype(np.float32).flatten().tolist() 64 | self.renderer.render_object(obj_id, R_l, t_l, fx, fy, cx, cy) 65 | rgb = self.renderer.get_color_image(obj_id) 66 | depth = self.renderer.get_depth_image(obj_id).astype(np.float32) 67 | return {"rgb": rgb, "depth": depth} 68 | -------------------------------------------------------------------------------- /lib/pysixd/se3.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.matlib as npm 3 | from transforms3d.quaternions import mat2quat, quat2mat, quat2axangle 4 | import scipy.stats as sci_stats 5 | 6 | 7 | # RT is a 3x4 matrix 8 | def se3_inverse(RT): 9 | R = RT[0:3, 0:3] 10 | T = RT[0:3, 3].reshape((3, 1)) 11 | RT_new = np.zeros((3, 4), dtype=np.float32) 12 | RT_new[0:3, 0:3] = R.transpose() 13 | RT_new[0:3, 3] = -1 * np.dot(R.transpose(), T).reshape(3) 14 | return RT_new 15 | 16 | 17 | def se3_mul(RT1, RT2): 18 | R1 = RT1[0:3, 0:3] 19 | T1 = RT1[0:3, 3].reshape((3, 1)) 20 | 21 | R2 = RT2[0:3, 0:3] 22 | T2 = RT2[0:3, 3].reshape((3, 1)) 23 | 24 | RT_new = np.zeros((3, 4), dtype=np.float32) 25 | RT_new[0:3, 0:3] = np.dot(R1, R2) 26 | T_new = np.dot(R1, T2) + T1 27 | RT_new[0:3, 3] = T_new.reshape(3) 28 | return RT_new 29 | 30 | 31 | def T_inv_transform(T_src, T_tgt): 32 | """ 33 | :param T_src: 34 | :param T_tgt: 35 | :return: T_delta: delta in pixel 36 | """ 37 | T_delta = np.zeros((3,), dtype=np.float32) 38 | 39 | T_delta[0] = T_tgt[0] / T_tgt[2] - T_src[0] / T_src[2] 40 | T_delta[1] = T_tgt[1] / T_tgt[2] - T_src[1] / T_src[2] 41 | T_delta[2] = np.log(T_src[2] / T_tgt[2]) 42 | 43 | return T_delta 44 | 45 | 46 | def rotation_x(theta): 47 | t = theta * np.pi / 180.0 48 | R = np.zeros((3, 3), dtype=np.float32) 49 | R[0, 0] = 1 50 | R[1, 1] = np.cos(t) 51 | R[1, 2] = -(np.sin(t)) 52 | R[2, 1] = np.sin(t) 53 | R[2, 2] = np.cos(t) 54 | return R 55 | 56 | 57 | def rotation_y(theta): 58 | t = theta * np.pi / 180.0 59 | R = np.zeros((3, 3), dtype=np.float32) 60 | R[0, 0] = np.cos(t) 61 | R[0, 2] = np.sin(t) 62 | R[1, 1] = 1 63 | R[2, 0] = -(np.sin(t)) 64 | R[2, 2] = np.cos(t) 65 | return R 66 | 67 | 68 | def rotation_z(theta): 69 | t = theta * np.pi / 180.0 70 | R = np.zeros((3, 3), dtype=np.float32) 71 | R[0, 0] = np.cos(t) 72 | R[0, 1] = -(np.sin(t)) 73 | R[1, 0] = np.sin(t) 74 | R[1, 1] = np.cos(t) 75 | R[2, 2] = 1 76 | return R 77 | 78 | 79 | def angular_distance(quat): 80 | vec, theta = quat2axangle(quat) 81 | return theta / np.pi * 180 82 | 83 | 84 | # Q is a Nx4 numpy matrix and contains the quaternions to average in the rows. 85 | # The quaternions are arranged as (w,x,y,z), with w being the scalar 86 | # The result will be the average quaternion of the input. 
Note that the signs 87 | # of the output quaternion can be reversed, since q and -q describe the same orientation 88 | def averageQuaternions(Q): 89 | # Number of quaternions to average 90 | M = Q.shape[0] 91 | A = npm.zeros(shape=(4, 4)) 92 | 93 | for i in range(0, M): 94 | q = Q[i, :] 95 | # multiply q with its transposed version q' and add A 96 | A = np.outer(q, q) + A 97 | 98 | # scale 99 | A = (1.0 / M) * A 100 | # compute eigenvalues and -vectors 101 | eigenValues, eigenVectors = np.linalg.eig(A) 102 | # Sort by largest eigenvalue 103 | eigenVectors = eigenVectors[:, eigenValues.argsort()[::-1]] 104 | # return the real part of the largest eigenvector (has only real part) 105 | return np.real(eigenVectors[:, 0].A1) 106 | -------------------------------------------------------------------------------- /lib/pysixd/visibility.py: -------------------------------------------------------------------------------- 1 | # Author: Tomas Hodan (hodantom@cmp.felk.cvut.cz) 2 | # Center for Machine Perception, Czech Technical University in Prague 3 | 4 | """Estimation of the visible object surface from depth images.""" 5 | 6 | import numpy as np 7 | 8 | 9 | def _estimate_visib_mask(d_test, d_model, delta, visib_mode="bop19"): 10 | """Estimates a mask of the visible object surface. 11 | 12 | :param d_test: Distance image of a scene in which the visibility is estimated. 13 | :param d_model: Rendered distance image of the object model. 14 | :param delta: Tolerance used in the visibility test. 15 | :param visib_mode: Visibility mode: 16 | 1) 'bop18' - Object is considered NOT VISIBLE at pixels with missing depth. 17 | 2) 'bop19' - Object is considered VISIBLE at pixels with missing depth. This 18 | allows to use the VSD pose error function also on shiny objects, which 19 | are typically not captured well by the depth sensors. A possible problem 20 | with this mode is that some invisible parts can be considered visible. 21 | However, the shadows of missing depth measurements, where this problem is 22 | expected to appear and which are often present at depth discontinuities, 23 | are typically relatively narrow and therefore this problem is less 24 | significant. 25 | :return: Visibility mask. 26 | """ 27 | assert d_test.shape == d_model.shape 28 | 29 | if visib_mode == "bop18": 30 | mask_valid = np.logical_and(d_test > 0, d_model > 0) 31 | d_diff = d_model.astype(np.float32) - d_test.astype(np.float32) 32 | visib_mask = np.logical_and(d_diff <= delta, mask_valid) 33 | 34 | elif visib_mode == "bop19": 35 | d_diff = d_model.astype(np.float32) - d_test.astype(np.float32) 36 | visib_mask = np.logical_and(np.logical_or(d_diff <= delta, d_test == 0), d_model > 0) 37 | 38 | else: 39 | raise ValueError("Unknown visibility mode.") 40 | 41 | return visib_mask 42 | 43 | 44 | def estimate_visib_mask_gt(d_test, d_gt, delta, visib_mode="bop19"): 45 | """Estimates a mask of the visible object surface in the ground-truth pose. 46 | 47 | :param d_test: Distance image of a scene in which the visibility is estimated. 48 | :param d_gt: Rendered distance image of the object model in the GT pose. 49 | :param delta: Tolerance used in the visibility test. 50 | :param visib_mode: See _estimate_visib_mask. 51 | :return: Visibility mask. 52 | """ 53 | visib_gt = _estimate_visib_mask(d_test, d_gt, delta, visib_mode) 54 | return visib_gt 55 | 56 | 57 | def estimate_visib_mask_est(d_test, d_est, visib_gt, delta, visib_mode="bop19"): 58 | """Estimates a mask of the visible object surface in the estimated pose. 
59 | 60 | For an explanation of why the visibility mask is calculated differently for 61 | the estimated and the ground-truth pose, see equation (14) and related text in 62 | Hodan et al., On Evaluation of 6D Object Pose Estimation, ECCVW'16. 63 | 64 | :param d_test: Distance image of a scene in which the visibility is estimated. 65 | :param d_est: Rendered distance image of the object model in the est. pose. 66 | :param visib_gt: Visibility mask of the object model in the GT pose (from 67 | function estimate_visib_mask_gt). 68 | :param delta: Tolerance used in the visibility test. 69 | :param visib_mode: See _estimate_visib_mask. 70 | :return: Visibility mask. 71 | """ 72 | visib_est = _estimate_visib_mask(d_test, d_est, delta, visib_mode) 73 | visib_est = np.logical_or(visib_est, np.logical_and(visib_gt, d_est > 0)) 74 | return visib_est 75 | -------------------------------------------------------------------------------- /lib/structures/__init__.py: -------------------------------------------------------------------------------- 1 | from .centers_2d import Center2Ds 2 | from .keypoints_2d import Keypoints2Ds 3 | from .keypoints_3d import Keypoints3Ds 4 | from .my_maps import MyMaps 5 | from .my_masks import MyBitMasks 6 | from .quats import Quats 7 | from .translations import Translations 8 | from .poses import Poses 9 | from .rots import Rots 10 | from .my_list import MyList 11 | -------------------------------------------------------------------------------- /lib/structures/centers_2d.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Iterator, List, Tuple, Union 2 | 3 | import torch 4 | from torch import device 5 | from detectron2.utils.env import TORCH_VERSION 6 | 7 | if TORCH_VERSION < (1, 8): 8 | _maybe_jit_unused = torch.jit.unused 9 | else: 10 | 11 | def _maybe_jit_unused(x): 12 | return x 13 | 14 | 15 | class Center2Ds: 16 | """This structure stores a list of 2d centers (object/bbox centers) a Nx2 17 | torch.Tensor. 18 | 19 | Attributes: 20 | tensor: float matrix of Nx2. 21 | """ 22 | 23 | def __init__(self, tensor: torch.Tensor): 24 | """ 25 | Args: 26 | tensor (Tensor[float]): 27 | * a Nx2 matrix. Each row is (x, y). 28 | """ 29 | device = tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu") 30 | tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device) 31 | if tensor.numel() == 0: 32 | # Use reshape, so we don't end up creating a new tensor that does not depend on 33 | # the inputs (and consequently confuses jit) 34 | tensor = torch.reshape(0, 2).to(dtype=torch.float32, device=device) 35 | assert tensor.ndim == 2 and (tensor.shape[-1] == 2), tensor.shape 36 | 37 | self.tensor = tensor 38 | 39 | def clone(self) -> "Center2Ds": 40 | """Clone the Center2Ds. 41 | 42 | Returns: 43 | Center2Ds 44 | """ 45 | return Center2Ds(self.tensor.clone()) 46 | 47 | @_maybe_jit_unused 48 | def to(self, device: torch.device = None, **kwargs) -> "Center2Ds": 49 | return Center2Ds(self.tensor.to(device=device, **kwargs)) 50 | 51 | def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Center2Ds": 52 | """ 53 | Returns: 54 | Center2Ds: Create a new :class:`Center2Ds` by indexing. 55 | 56 | The following usage are allowed: 57 | 1. `new_center2ds = center2ds[3]`: return a `Center2Ds` which contains only one 2d center. 58 | 2. `new_center2ds = center2ds[2:10]`: return a slice of center2ds. 59 | 3. 
`new_center2ds = center2ds[vector]`, where vector is a torch.BoolTensor 60 | with `length = len(center2ds)`. Nonzero elements in the vector will be selected. 61 | 62 | Note that the returned Center2Ds might share storage with this Center2Ds, 63 | subject to Pytorch's indexing semantics. 64 | """ 65 | if isinstance(item, int): 66 | return Center2Ds(self.tensor[item].view(1, -1)) 67 | b = self.tensor[item] 68 | assert b.ndim == 2, "Indexing on Center2Ds with {} failed to return a matrix!".format(item) 69 | return Center2Ds(b) 70 | 71 | def __len__(self) -> int: 72 | return self.tensor.shape[0] 73 | 74 | def __repr__(self) -> str: 75 | return "Center2Ds(" + str(self.tensor) + ")" 76 | 77 | @classmethod 78 | def cat(center2ds_list: List["Center2Ds"]) -> "Center2Ds": 79 | """Concatenates a list of Center2Ds into a single Center2Ds. 80 | 81 | Arguments: 82 | center2ds_list (list[Center2Ds]) 83 | 84 | Returns: 85 | Center2Ds: the concatenated Center2Ds 86 | """ 87 | if torch.jit.is_scripting(): 88 | # https://github.com/pytorch/pytorch/issues/18627 89 | # 1. staticmethod can be used in torchscript, But we can not use 90 | # `type(center2ds).staticmethod` because torchscript only supports function 91 | # `type` with input type `torch.Tensor`. 92 | # 2. classmethod is not fully supported by torchscript. We explicitly assign 93 | # cls to Center2Ds as a workaround to get torchscript support. 94 | cls = Center2Ds 95 | assert isinstance(center2ds_list, (list, tuple)) 96 | if len(center2ds_list) == 0: 97 | return cls(torch.empty(0)) 98 | assert all(isinstance(center2ds, Center2Ds) for center2ds in center2ds_list) 99 | 100 | # use torch.cat (v.s. layers.cat) so the returned tensor never share storage with input 101 | cat_center2ds = cls(center2ds_list[0])(torch.cat([b.tensor for b in center2ds_list], dim=0)) 102 | return cat_center2ds 103 | 104 | @property 105 | def device(self) -> device: 106 | return self.tensor.device 107 | 108 | # type "Iterator[torch.Tensor]", yield, and iter() not supported by torchscript 109 | # https://github.com/pytorch/pytorch/issues/18627 110 | @torch.jit.unused 111 | def __iter__(self) -> Iterator[torch.Tensor]: 112 | """Yield a 2d center as a Tensor of shape (2,) at a time.""" 113 | yield from self.tensor 114 | -------------------------------------------------------------------------------- /lib/structures/keypoints_2d.py: -------------------------------------------------------------------------------- 1 | # for example: store projected bbox3d+center3d => (N,9,2) 2 | from typing import Any, Iterator, List, Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | from torch import device 7 | from detectron2.utils.env import TORCH_VERSION 8 | 9 | if TORCH_VERSION < (1, 8): 10 | _maybe_jit_unused = torch.jit.unused 11 | else: 12 | 13 | def _maybe_jit_unused(x): 14 | return x 15 | 16 | 17 | class Keypoints2Ds: 18 | """Modified from class Keypoints. 19 | 20 | Stores 2d keypoint annotation data. GT Instances have a 21 | `gt_2d_keypoints` property containing the x,y location of each 22 | keypoint. This tensor has shape (N, K, 2) where N is the number of 23 | instances and K is the number of keypoints per instance. 24 | """ 25 | 26 | def __init__(self, keypoints: Union[torch.Tensor, np.ndarray, List[List[float]]]): 27 | """ 28 | Arguments: 29 | keypoints: A Tensor, numpy array, or list of the x, y of each keypoint. 30 | The shape should be (N, K, 2) where N is the number of 31 | instances, and K is the number of keypoints per instance. 
32 | """ 33 | device = keypoints.device if isinstance(keypoints, torch.Tensor) else torch.device("cpu") 34 | keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=device) 35 | assert keypoints.ndim == 3 and keypoints.shape[2] == 2, keypoints.shape 36 | self.tensor = keypoints 37 | 38 | def __len__(self) -> int: 39 | return self.tensor.shape[0] 40 | 41 | def clone(self) -> "Keypoints2Ds": 42 | """Clone the Keypoints2Ds. 43 | 44 | Returns: 45 | Keypoints2Ds 46 | """ 47 | return Keypoints2Ds(self.tensor.clone()) 48 | 49 | def to(self, *args: Any, **kwargs: Any) -> "Keypoints2Ds": 50 | return type(self)(self.tensor.to(*args, **kwargs)) 51 | 52 | def to_heatmap(self, boxes: torch.Tensor, heatmap_size: int) -> torch.Tensor: 53 | # TODO: convert 2d keypoints to heatmap as proposed in Integoral Regression 54 | # copy from d2 if needed 55 | raise NotImplementedError 56 | 57 | def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Keypoints2Ds": 58 | """Create a new `Keypoints2Ds` by indexing on this `Keypoints2Ds`. 59 | 60 | The following usage are allowed: 61 | 62 | 1. `new_kpts = kpts[3]`: return a `Keypoints2Ds` which contains only one instance. 63 | 2. `new_kpts = kpts[2:10]`: return a slice of key points. 64 | 3. `new_kpts = kpts[vector]`, where vector is a torch.ByteTensor 65 | with `length = len(kpts)`. Nonzero elements in the vector will be selected. 66 | 67 | Note that the returned Keypoints2Ds might share storage with this Keypoints2Ds, 68 | subject to Pytorch's indexing semantics. 69 | """ 70 | if isinstance(item, int): 71 | return Keypoints2Ds([self.tensor[item]]) 72 | return Keypoints2Ds(self.tensor[item]) 73 | 74 | def __repr__(self) -> str: 75 | s = self.__class__.__name__ + "(" 76 | s += "num_instances={})".format(len(self.tensor)) 77 | return s 78 | 79 | @classmethod 80 | def cat(cls, keypoints2ds_list: List["Keypoints2Ds"]) -> "Keypoints2Ds": 81 | """Concatenates a list of Keypoints2Ds into a single Keypoints2Ds. 82 | 83 | Arguments: 84 | keypoints2ds_list (list[Keypoints2Ds]) 85 | 86 | Returns: 87 | Keypoints2Ds: the concatenated Keypoints2Ds 88 | """ 89 | if torch.jit.is_scripting(): 90 | # https://github.com/pytorch/pytorch/issues/18627 91 | # 1. staticmethod can be used in torchscript, But we can not use 92 | # `type(xxx).staticmethod` because torchscript only supports function 93 | # `type` with input type `torch.Tensor`. 94 | # 2. classmethod is not fully supported by torchscript. We explicitly assign 95 | # cls to ThisClassName as a workaround to get torchscript support. 96 | cls = Keypoints2Ds 97 | assert isinstance(keypoints2ds_list, (list, tuple)) 98 | if len(keypoints2ds_list) == 0: 99 | return cls(torch.empty(0)) 100 | assert all(isinstance(keypoints2ds, Keypoints2Ds) for keypoints2ds in keypoints2ds_list) 101 | 102 | # use torch.cat (v.s. 
layers.cat) so the returned tensor never share storage with input 103 | cat_keypoints2ds = type(keypoints2ds_list[0])(torch.cat([b.tensor for b in keypoints2ds_list], dim=0)) 104 | return cat_keypoints2ds 105 | 106 | @property 107 | def device(self) -> device: 108 | return self.tensor.device 109 | 110 | # type "Iterator[torch.Tensor]", yield, and iter() not supported by torchscript 111 | # https://github.com/pytorch/pytorch/issues/18627 112 | @torch.jit.unused 113 | def __iter__(self) -> Iterator[torch.Tensor]: 114 | """Yield a 2d center as a Tensor of shape (2,) at a time.""" 115 | yield from self.tensor 116 | -------------------------------------------------------------------------------- /lib/structures/keypoints_3d.py: -------------------------------------------------------------------------------- 1 | # for example: store bbox3d+center3d => (N,9,3) 2 | from typing import Any, Iterator, List, Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | from torch import device 7 | from detectron2.utils.env import TORCH_VERSION 8 | 9 | if TORCH_VERSION < (1, 8): 10 | _maybe_jit_unused = torch.jit.unused 11 | else: 12 | 13 | def _maybe_jit_unused(x): 14 | return x 15 | 16 | 17 | class Keypoints3Ds: 18 | """Modified from class Keypoints. 19 | 20 | Stores 3d keypoint annotation data. GT Instances have a 21 | `gt_3d_keypoints` property containing the x,y,z location of each 22 | keypoint. This tensor has shape (N, K, 3) where N is the number of 23 | instances and K is the number of keypoints per instance. 24 | """ 25 | 26 | def __init__(self, keypoints: Union[torch.Tensor, np.ndarray, List[List[float]]]): 27 | """ 28 | Arguments: 29 | keypoints: A Tensor, numpy array, or list of the x, y, z of each keypoint. 30 | The shape should be (N, K, 3) where N is the number of 31 | instances, and K is the number of keypoints per instance. 32 | """ 33 | device = keypoints.device if isinstance(keypoints, torch.Tensor) else torch.device("cpu") 34 | keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=device) 35 | assert keypoints.ndim == 3 and keypoints.shape[2] == 3, keypoints.shape 36 | self.tensor = keypoints 37 | 38 | def __len__(self) -> int: 39 | return self.tensor.shape[0] 40 | 41 | def clone(self) -> "Keypoints3Ds": 42 | """Clone the Keypoints3Ds. 43 | 44 | Returns: 45 | Keypoints3Ds 46 | """ 47 | return Keypoints3Ds(self.tensor.clone()) 48 | 49 | def to(self, *args: Any, **kwargs: Any) -> "Keypoints3Ds": 50 | return type(self)(self.tensor.to(*args, **kwargs)) 51 | 52 | def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Keypoints3Ds": 53 | """Create a new `Keypoints3Ds` by indexing on this `Keypoints3Ds`. 54 | 55 | The following usage are allowed: 56 | 57 | 1. `new_kpts = kpts[3]`: return a `Keypoints3Ds` which contains only one instance. 58 | 2. `new_kpts = kpts[2:10]`: return a slice of key points. 59 | 3. `new_kpts = kpts[vector]`, where vector is a torch.ByteTensor 60 | with `length = len(kpts)`. Nonzero elements in the vector will be selected. 61 | 62 | Note that the returned Keypoints3Ds might share storage with this Keypoints3Ds, 63 | subject to Pytorch's indexing semantics. 
64 | """ 65 | if isinstance(item, int): 66 | return Keypoints3Ds([self.tensor[item]]) 67 | return Keypoints3Ds(self.tensor[item]) 68 | 69 | def __repr__(self) -> str: 70 | s = self.__class__.__name__ + "(" 71 | s += "num_instances={})".format(len(self.tensor)) 72 | return s 73 | 74 | @classmethod 75 | def cat(cls, keypoints3ds_list: List["Keypoints3Ds"]) -> "Keypoints3Ds": 76 | """Concatenates a list of Keypoints3Ds into a single Keypoints3Ds. 77 | 78 | Arguments: 79 | keypoints3ds_list (list[Keypoints3Ds]) 80 | 81 | Returns: 82 | Keypoints3Ds: the concatenated Keypoints3Ds 83 | """ 84 | if torch.jit.is_scripting(): 85 | # https://github.com/pytorch/pytorch/issues/18627 86 | # 1. staticmethod can be used in torchscript, But we can not use 87 | # `type(keypoits3ds).staticmethod` because torchscript only supports function 88 | # `type` with input type `torch.Tensor`. 89 | # 2. classmethod is not fully supported by torchscript. We explicitly assign 90 | # cls to Keypoints3Ds as a workaround to get torchscript support. 91 | cls = Keypoints3Ds 92 | assert isinstance(keypoints3ds_list, (list, tuple)) 93 | if len(keypoints3ds_list) == 0: 94 | return cls(torch.empty(0)) 95 | assert all(isinstance(keypoints3ds, Keypoints3Ds) for keypoints3ds in keypoints3ds_list) 96 | 97 | # use torch.cat (v.s. layers.cat) so the returned tensor never share storage with input 98 | cat_keypoints3ds = type(keypoints3ds_list[0])(torch.cat([b.tensor for b in keypoints3ds_list], dim=0)) 99 | return cat_keypoints3ds 100 | 101 | @property 102 | def device(self) -> device: 103 | return self.tensor.device 104 | 105 | # type "Iterator[torch.Tensor]", yield, and iter() not supported by torchscript 106 | # https://github.com/pytorch/pytorch/issues/18627 107 | @torch.jit.unused 108 | def __iter__(self) -> Iterator[torch.Tensor]: 109 | """Yield a 2d center as a Tensor of shape (2,) at a time.""" 110 | yield from self.tensor 111 | -------------------------------------------------------------------------------- /lib/structures/my_list.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class MyList(list): 6 | def __getitem__(self, index): 7 | """support indexing using torch.Tensor.""" 8 | if isinstance(index, torch.Tensor): 9 | if isinstance(index, torch.BoolTensor): 10 | return [self[i] for i, idx in enumerate(index) if idx] 11 | else: 12 | return [self[int(i)] for i in index] 13 | elif isinstance(index, (list, tuple)): 14 | if len(index) > 0 and isinstance(index[0], bool): 15 | return [self[i] for i, idx in enumerate(index) if idx] 16 | else: 17 | return [self[int(i)] for i in index] 18 | elif isinstance(index, np.ndarray): 19 | if index.dtype == np.bool: 20 | return [self[i] for i, idx in enumerate(index) if idx] 21 | else: 22 | return [self[int(i)] for i in index] 23 | 24 | return list.__getitem__(self, index) 25 | 26 | 27 | if __name__ == "__main__": 28 | a = [None, "a", 1, 2.3] 29 | a = MyList(a) 30 | print(a) 31 | print(type(a), isinstance(a, list)) 32 | print("\ntorch bool index") 33 | index = torch.tensor([True, False, True, False]) 34 | print(index) 35 | print(a[index]) 36 | 37 | print("torch int index") 38 | index = torch.tensor([0, 2, 3]) 39 | print(index) 40 | print(a[index]) 41 | 42 | print("\nnumpy bool index") 43 | index = np.array([True, False, True, False]) 44 | print(index) 45 | print(a[index]) 46 | 47 | print("numpy int index") 48 | index = np.array([0, 2, 3]) 49 | print(index) 50 | print(a[index]) 51 | 52 | print("\nlist bool 
index") 53 | index = [True, False, True, False] 54 | print(index) 55 | print(a[index]) 56 | 57 | print("list int index") 58 | index = [0, 2, 3] 59 | print(index) 60 | print(a[index]) 61 | 62 | print("\ntuple bool index") 63 | index = (True, False, True, False) 64 | print(index) 65 | print(a[index]) 66 | 67 | print("tuple int index") 68 | index = (0, 2, 3) 69 | print(index) 70 | print(a[index]) 71 | 72 | # print(a[1:-1]) 73 | -------------------------------------------------------------------------------- /lib/structures/my_maps.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Any, Iterator, List, Union 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from detectron2.layers.roi_align import ROIAlign 8 | from torchvision.ops import RoIPool 9 | 10 | 11 | class MyMaps(object): 12 | """# NOTE: This class stores the maps (NOCS, coordinates map, pvnet vector 13 | maps, offset maps, heatmaps) for all objects in one image, support cpu_only 14 | option. 15 | 16 | Attributes: 17 | tensor: bool Tensor of N,C,H,W, representing N instances in the image. 18 | """ 19 | 20 | def __init__(self, tensor: Union[torch.Tensor, np.ndarray], cpu_only: bool = True): 21 | """ 22 | Args: 23 | tensor: float Tensor of N,C,H,W, representing N instances in the image. 24 | cpu_only: keep the maps on cpu even when to(device) is called 25 | """ 26 | device = tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu") 27 | tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device) 28 | assert tensor.dim() == 4, tensor.size() 29 | self.image_size = tensor.shape[-2:] 30 | self.tensor = tensor 31 | self.cpu_only = cpu_only 32 | 33 | def to(self, device: str, **kwargs) -> "MyMaps": 34 | if not self.cpu_only: 35 | return MyMaps(self.tensor.to(device, **kwargs), cpu_only=False) 36 | else: 37 | return MyMaps(self.tensor.to("cpu", **kwargs), cpu_only=True) 38 | 39 | def to_device(self, device: str = "cuda", **kwargs) -> "MyMaps": 40 | # force to device 41 | return MyMaps(self.tensor.to(device, **kwargs), cpu_only=False) 42 | 43 | def crop_and_resize( 44 | self, 45 | boxes: torch.Tensor, 46 | map_size: int, 47 | interpolation: str = "bilinear", 48 | ) -> torch.Tensor: 49 | """# NOTE: if self.cpu_only, convert boxes to cpu 50 | Crop each map by the given box, and resize results to (map_size, map_size). 51 | This can be used to prepare training targets. 52 | Args: 53 | boxes (Tensor): Nx4 tensor storing the boxes for each map 54 | map_size (int): the size of the rasterized map. 55 | interpolation (str): bilinear | nearest 56 | 57 | Returns: 58 | Tensor: 59 | A bool tensor of shape (N, C, map_size, map_size), where 60 | N is the number of predicted boxes for this image. 61 | """ 62 | assert len(boxes) == len(self), "{} != {}".format(len(boxes), len(self)) 63 | if self.cpu_only: 64 | device = "cpu" 65 | else: 66 | device = self.tensor.device 67 | 68 | batch_inds = torch.arange(len(boxes), device=device).to(dtype=boxes.dtype)[:, None] 69 | rois = torch.cat([batch_inds, boxes.to(device)], dim=1) # Nx5 70 | 71 | maps = self.tensor.to(dtype=torch.float32) 72 | rois = rois.to(device=device) 73 | # on cpu, speed compared to cv2? 
74 | if interpolation == "nearest": 75 | op = RoIPool((map_size, map_size), 1.0) 76 | elif interpolation == "bilinear": 77 | op = ROIAlign((map_size, map_size), 1.0, 0, aligned=True) 78 | else: 79 | raise ValueError(f"Unknown interpolation type: {interpolation}") 80 | output = op.forward(maps, rois) 81 | return output 82 | 83 | def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "MyMaps": 84 | """ 85 | Returns: 86 | MyMaps: Create a new :class:`MyMaps` by indexing. 87 | 88 | The following usage are allowed: 89 | 90 | 1. `new_maps = maps[3]`: return a `MyMaps` which contains only one map. 91 | 2. `new_maps = maps[2:10]`: return a slice of maps. 92 | 3. `new_maps = maps[vector]`, where vector is a torch.BoolTensor 93 | with `length = len(maps)`. Nonzero elements in the vector will be selected. 94 | 95 | Note that the returned object might share storage with this object, 96 | subject to Pytorch's indexing semantics. 97 | """ 98 | if isinstance(item, int): 99 | return MyMaps(self.tensor[item].view(1, -1)) 100 | m = self.tensor[item] 101 | assert m.dim() == 4, "Indexing on MyMaps with {} returns a tensor with shape {}!".format(item, m.shape) 102 | return MyMaps(m) 103 | 104 | def __iter__(self) -> torch.Tensor: 105 | yield from self.tensor 106 | 107 | def __repr__(self) -> str: 108 | s = self.__class__.__name__ + "(" 109 | s += "num_instances={})".format(len(self.tensor)) 110 | return s 111 | 112 | def __len__(self) -> int: 113 | return self.tensor.shape[0] 114 | 115 | def nonempty(self) -> torch.Tensor: 116 | """Find maps that are non-empty. 117 | 118 | Returns: 119 | Tensor: a BoolTensor which represents 120 | whether each map is empty (False) or non-empty (True). 121 | """ 122 | return self.tensor.flatten(1).any(dim=1) 123 | -------------------------------------------------------------------------------- /lib/structures/poses.py: -------------------------------------------------------------------------------- 1 | # poses in [rot|trans] format (N, 3, 4). 2 | from typing import Any, Iterator, List, Tuple, Union 3 | import torch 4 | from torch import device 5 | from detectron2.utils.env import TORCH_VERSION 6 | 7 | if TORCH_VERSION < (1, 8): 8 | _maybe_jit_unused = torch.jit.unused 9 | else: 10 | 11 | def _maybe_jit_unused(x): 12 | return x 13 | 14 | 15 | class Poses: 16 | """This structure stores a list of 6d poses as a Nx3x4 ([rot|trans] 17 | torch.Tensor. It supports some common methods about poses, and also behaves 18 | like a Tensor (support indexing, `to(device)`, `.device`, and iteration 19 | over all 6d poses) 20 | 21 | Attributes: 22 | tensor (torch.Tensor): float matrix of Nx3x4. 23 | """ 24 | 25 | def __init__(self, tensor: torch.Tensor): 26 | """ 27 | Args: 28 | tensor (Tensor[float]): 29 | * a Nx3x4 matrix. 30 | """ 31 | device = tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu") 32 | tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device) 33 | if tensor.numel() == 0: 34 | # Use reshape, so we don't end up creating a new tensor that does not depend on 35 | # the inputs (and consequently confuses jit) 36 | tensor = torch.reshape(0, 3, 4).to(dtype=torch.float32, device=device) 37 | assert tensor.ndim == 3 and (tensor.shape[1:] == (3, 4)), tensor.shape 38 | 39 | self.tensor = tensor 40 | 41 | def clone(self) -> "Poses": 42 | """Clone the Poses. 
43 | 44 | Returns: 45 | Poses 46 | """ 47 | return Poses(self.tensor.clone()) 48 | 49 | @_maybe_jit_unused 50 | def to(self, device: torch.device = None, **kwargs) -> "Poses": 51 | return Poses(self.tensor.to(device=device, **kwargs)) 52 | 53 | def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Poses": 54 | """ 55 | Returns: 56 | Poses: Create a new :class:`Poses` by indexing. 57 | 58 | The following usage are allowed: 59 | 1. `new_poses = poses[3]`: return a `Poses` which contains only one pose. 60 | 2. `new_poses = poses[2:10]`: return a slice of poses. 61 | 3. `new_poses = poses[vector]`, where vector is a torch.BoolTensor 62 | with `length = len(poses)`. Nonzero elements in the vector will be selected. 63 | 64 | Note that the returned Poses might share storage with this Poses, 65 | subject to Pytorch's indexing semantics. 66 | """ 67 | if isinstance(item, int): 68 | return Poses(self.tensor[item].view(1, 3, 4)) 69 | b = self.tensor[item] 70 | assert b.ndim == 3, "Indexing on Poses with {} failed!".format(item) 71 | return Poses(b) 72 | 73 | def __len__(self) -> int: 74 | return self.tensor.shape[0] 75 | 76 | def __repr__(self) -> str: 77 | return "Poses(" + str(self.tensor) + ")" 78 | 79 | def get_centers_2d(self, K: torch.Tensor) -> torch.Tensor: 80 | """ 81 | Args: 82 | K: camera intrinsic matrices, 1x3x3 or Nx3x3 83 | Returns: 84 | The 2d projected object centers in a Nx2 array of (x, y). 85 | """ 86 | assert K.ndim == 3, K.shape 87 | bs = self.tensor.shape[0] 88 | proj = (K @ self.tensor[:, :3, [3]]).view(bs, 3) # Nx3 89 | centers_2d = proj[:, :2] / proj[:, 2:3] # Nx2 90 | return centers_2d 91 | 92 | @classmethod 93 | def cat(cls, poses_list: List["Poses"]) -> "Poses": 94 | """Concatenates a list of Poses into a single Poses. 95 | 96 | Arguments: 97 | poses_list (list[Poses]) 98 | 99 | Returns: 100 | Poses: the concatenated Poses 101 | """ 102 | if torch.jit.is_scripting(): 103 | # https://github.com/pytorch/pytorch/issues/18627 104 | # 1. staticmethod can be used in torchscript, But we can not use 105 | # `type(poses).staticmethod` because torchscript only supports function 106 | # `type` with input type `torch.Tensor`. 107 | # 2. classmethod is not fully supported by torchscript. We explicitly assign 108 | # cls to Poses as a workaround to get torchscript support. 109 | cls = Poses 110 | assert isinstance(poses_list, (list, tuple)) 111 | if len(poses_list) == 0: 112 | return cls(torch.empty(0)) 113 | assert all(isinstance(pose, Poses) for pose in poses_list) 114 | 115 | # use torch.cat (v.s. 
layers.cat) so the returned poses never share storage with input
116 |         cat_poses = cls(torch.cat([p.tensor for p in poses_list], dim=0))
117 |         return cat_poses
118 | 
119 |     @property
120 |     def device(self) -> device:
121 |         return self.tensor.device
122 | 
123 |     # type "Iterator[torch.Tensor]", yield, and iter() not supported by torchscript
124 |     # https://github.com/pytorch/pytorch/issues/18627
125 |     @torch.jit.unused
126 |     def __iter__(self) -> Iterator[torch.Tensor]:
127 |         """Yield a 6d pose as a Tensor of shape (3,4) at a time."""
128 |         yield from self.tensor
129 | 
--------------------------------------------------------------------------------
/lib/structures/quats.py:
--------------------------------------------------------------------------------
1 | # quaternions
2 | from typing import Any, Iterator, List, Tuple, Union
3 | 
4 | import torch
5 | from torch import device
6 | from detectron2.utils.env import TORCH_VERSION
7 | 
8 | if TORCH_VERSION < (1, 8):
9 |     _maybe_jit_unused = torch.jit.unused
10 | else:
11 | 
12 |     def _maybe_jit_unused(x):
13 |         return x
14 | 
15 | 
16 | class Quats:
17 |     """This structure stores a list of quaternions as a Nx4 torch.Tensor. It
18 |     supports some common methods about quats, and also behaves like a Tensor
19 |     (support indexing, `to(device)`, `.device`, and iteration over all quats)
20 | 
21 |     Attributes:
22 |         tensor: float matrix of Nx4.
23 |     """
24 | 
25 |     def __init__(self, tensor: torch.Tensor):
26 |         """
27 |         Args:
28 |             tensor (Tensor[float]):
29 |                 * a Nx4 matrix. Each row is (qw, qx, qy, qz).
30 |         """
31 |         device = tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu")
32 |         tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device)
33 |         if tensor.numel() == 0:
34 |             # Use reshape, so we don't end up creating a new tensor that does not depend on
35 |             # the inputs (and consequently confuses jit)
36 |             tensor = tensor.reshape(-1, 4).to(dtype=torch.float32, device=device)
37 |         assert tensor.ndim == 2 and (tensor.shape[-1] == 4), tensor.shape
38 | 
39 |         self.tensor = tensor
40 | 
41 |     def clone(self) -> "Quats":
42 |         """Clone the Quats.
43 | 
44 |         Returns:
45 |             Quats
46 |         """
47 |         return Quats(self.tensor.clone())
48 | 
49 |     @_maybe_jit_unused
50 |     def to(self, device: torch.device = None, **kwargs) -> "Quats":
51 |         return Quats(self.tensor.to(device=device, **kwargs))
52 | 
53 |     def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Quats":
54 |         """
55 |         Returns:
56 |             Quats: Create a new :class:`Quats` by indexing.
57 | 
58 |         The following usages are allowed:
59 |         1. `new_quats = quats[3]`: return a `Quats` which contains only one quat.
60 |         2. `new_quats = quats[2:10]`: return a slice of quats.
61 |         3. `new_quats = quats[vector]`, where vector is a torch.BoolTensor
62 |            with `length = len(quats)`. Nonzero elements in the vector will be selected.
63 | 
64 |         Note that the returned Quats might share storage with this Quats,
65 |         subject to Pytorch's indexing semantics.
66 |         """
67 |         if isinstance(item, int):
68 |             return Quats(self.tensor[item].view(1, -1))
69 |         b = self.tensor[item]
70 |         assert b.ndim == 2, "Indexing on Quats with {} failed to return a matrix!".format(item)
71 |         return Quats(b)
72 | 
73 |     def __len__(self) -> int:
74 |         return self.tensor.shape[0]
75 | 
76 |     def __repr__(self) -> str:
77 |         return "Quats(" + str(self.tensor) + ")"
78 | 
79 |     @classmethod
80 |     def cat(cls, quats_list: List["Quats"]) -> "Quats":
81 |         """Concatenates a list of Quats into a single Quats.
82 | 
83 |         Arguments:
84 |             quats_list (list[Quats])
85 | 
86 |         Returns:
87 |             Quats: the concatenated Quats
88 |         """
89 |         if torch.jit.is_scripting():
90 |             # https://github.com/pytorch/pytorch/issues/18627
91 |             # 1. staticmethod can be used in torchscript, But we can not use
92 |             # `type(quats).staticmethod` because torchscript only supports function
93 |             # `type` with input type `torch.Tensor`.
94 |             # 2. classmethod is not fully supported by torchscript. We explicitly assign
95 |             # cls to Quats as a workaround to get torchscript support.
96 |             cls = Quats
97 |         assert isinstance(quats_list, (list, tuple))
98 |         if len(quats_list) == 0:
99 |             return cls(torch.empty(0))
100 |         assert all(isinstance(quats, Quats) for quats in quats_list)
101 | 
102 |         # use torch.cat (v.s. layers.cat) so the returned quats never share storage with input
103 |         cat_quats = cls(torch.cat([q.tensor for q in quats_list], dim=0))
104 |         return cat_quats
105 | 
106 |     @property
107 |     def device(self) -> device:
108 |         return self.tensor.device
109 | 
110 |     # type "Iterator[torch.Tensor]", yield, and iter() not supported by torchscript
111 |     # https://github.com/pytorch/pytorch/issues/18627
112 |     @torch.jit.unused
113 |     def __iter__(self) -> Iterator[torch.Tensor]:
114 |         """Yield a quat as a Tensor of shape (4,) at a time."""
115 |         yield from self.tensor
116 | 
--------------------------------------------------------------------------------
/lib/structures/rots.py:
--------------------------------------------------------------------------------
1 | # rotation matrices, format (N, 3, 3).
2 | from typing import Any, Iterator, List, Tuple, Union
3 | import torch
4 | from torch import device
5 | from detectron2.utils.env import TORCH_VERSION
6 | 
7 | if TORCH_VERSION < (1, 8):
8 |     _maybe_jit_unused = torch.jit.unused
9 | else:
10 | 
11 |     def _maybe_jit_unused(x):
12 |         return x
13 | 
14 | 
15 | class Rots:
16 |     """This structure stores a list of rotation matrices as a Nx3x3
17 |     torch.Tensor. It supports some common methods about rots, and also behaves
18 |     like a Tensor (support indexing, `to(device)`, `.device`, and iteration
19 |     over all rots)
20 | 
21 |     Attributes:
22 |         tensor (torch.Tensor): float matrix of Nx3x3.
23 |     """
24 | 
25 |     def __init__(self, tensor: torch.Tensor):
26 |         """
27 |         Args:
28 |             tensor (Tensor[float]):
29 |                 * a Nx3x3 matrix.
30 |         """
31 |         device = tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu")
32 |         tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device)
33 |         if tensor.numel() == 0:
34 |             # Use reshape, so we don't end up creating a new tensor that does not depend on
35 |             # the inputs (and consequently confuses jit)
36 |             tensor = tensor.reshape(-1, 3, 3).to(dtype=torch.float32, device=device)
37 |         assert tensor.ndim == 3 and (tensor.shape[1:] == (3, 3)), tensor.shape
38 | 
39 |         self.tensor = tensor
40 | 
41 |     def clone(self) -> "Rots":
42 |         """Clone the Rots.
43 | 
44 |         Returns:
45 |             Rots
46 |         """
47 |         return Rots(self.tensor.clone())
48 | 
49 |     @_maybe_jit_unused
50 |     def to(self, device: torch.device = None, **kwargs) -> "Rots":
51 |         return Rots(self.tensor.to(device=device, **kwargs))
52 | 
53 |     def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Rots":
54 |         """
55 |         Returns:
56 |             Rots: Create a new :class:`Rots` by indexing.
57 | 
58 |         The following usages are allowed:
59 |         1. `new_rots = rots[3]`: return a `Rots` which contains only one rotation.
60 |         2. `new_rots = rots[2:10]`: return a slice of rots.
61 |         3. `new_rots = rots[vector]`, where vector is a torch.BoolTensor
62 |            with `length = len(rots)`. Nonzero elements in the vector will be selected.
63 | 
64 |         Note that the returned Rots might share storage with this Rots,
65 |         subject to Pytorch's indexing semantics.
66 |         """
67 |         if isinstance(item, int):
68 |             return Rots(self.tensor[item].view(1, 3, 3))
69 |         b = self.tensor[item]
70 |         assert b.ndim == 3, "Indexing on Rots with {} failed!".format(item)
71 |         return Rots(b)
72 | 
73 |     def __len__(self) -> int:
74 |         return self.tensor.shape[0]
75 | 
76 |     def __repr__(self) -> str:
77 |         return "Rots(" + str(self.tensor) + ")"
78 | 
79 |     @classmethod
80 |     def cat(cls, rots_list: List["Rots"]) -> "Rots":
81 |         """Concatenates a list of Rots into a single Rots.
82 | 
83 |         Arguments:
84 |             rots_list (list[Rots])
85 | 
86 |         Returns:
87 |             Rots: the concatenated Rots
88 |         """
89 |         if torch.jit.is_scripting():
90 |             # https://github.com/pytorch/pytorch/issues/18627
91 |             # 1. staticmethod can be used in torchscript, But we can not use
92 |             # `type(xxx).staticmethod` because torchscript only supports function
93 |             # `type` with input type `torch.Tensor`.
94 |             # 2. classmethod is not fully supported by torchscript. We explicitly assign
95 |             # cls to ThisClassName as a workaround to get torchscript support.
96 |             cls = Rots
97 |         assert isinstance(rots_list, (list, tuple))
98 |         if len(rots_list) == 0:
99 |             return cls(torch.empty(0))
100 |         assert all(isinstance(pose, Rots) for pose in rots_list)
101 | 
102 |         # use torch.cat (v.s. layers.cat) so the returned tensor never share storage with input
103 |         cat_rots = cls(torch.cat([p.tensor for p in rots_list], dim=0))
104 |         return cat_rots
105 | 
106 |     @property
107 |     def device(self) -> device:
108 |         return self.tensor.device
109 | 
110 |     # type "Iterator[torch.Tensor]", yield, and iter() not supported by torchscript
111 |     # https://github.com/pytorch/pytorch/issues/18627
112 |     @torch.jit.unused
113 |     def __iter__(self) -> Iterator[torch.Tensor]:
114 |         """Yield a rot as a Tensor of shape (3,3) at a time."""
115 |         yield from self.tensor
116 | 
--------------------------------------------------------------------------------
/lib/structures/translations.py:
--------------------------------------------------------------------------------
1 | # translations
2 | from typing import Any, Iterator, List, Tuple, Union
3 | 
4 | import torch
5 | from torch import device
6 | from detectron2.utils.env import TORCH_VERSION
7 | 
8 | if TORCH_VERSION < (1, 8):
9 |     _maybe_jit_unused = torch.jit.unused
10 | else:
11 | 
12 |     def _maybe_jit_unused(x):
13 |         return x
14 | 
15 | 
16 | class Translations:
17 |     """This structure stores a list of translations as a Nx3 torch.Tensor.
18 | 
19 |     Attributes:
20 |         tensor: float matrix of Nx3.
21 |     """
22 | 
23 |     def __init__(self, tensor: torch.Tensor):
24 |         """
25 |         Args:
26 |             tensor (Tensor[float]):
27 |                 * a Nx3 matrix. Each row is (tx, ty, tz).
28 |         """
29 |         device = tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu")
30 |         tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device)
31 |         if tensor.numel() == 0:
32 |             # Use reshape, so we don't end up creating a new tensor that does not depend on
33 |             # the inputs (and consequently confuses jit)
34 |             tensor = tensor.reshape(-1, 3).to(dtype=torch.float32, device=device)
35 |         assert tensor.ndim == 2 and (tensor.shape[-1] == 3), tensor.shape
36 | 
37 |         self.tensor = tensor
38 | 
39 |     def clone(self) -> "Translations":
40 |         """Clone the Translations.
41 | 
42 |         Returns:
43 |             Translations
44 |         """
45 |         return Translations(self.tensor.clone())
46 | 
47 |     @_maybe_jit_unused
48 |     def to(self, device: torch.device = None, **kwargs) -> "Translations":
49 |         return Translations(self.tensor.to(device=device, **kwargs))
50 | 
51 |     def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Translations":
52 |         """
53 |         Returns:
54 |             Translations: Create a new :class:`Translations` by indexing.
55 | 
56 |         The following usages are allowed:
57 |         1. `new_transes = transes[3]`: return a `Translations` which contains only one translation.
58 |         2. `new_transes = transes[2:10]`: return a slice of transes.
59 |         3. `new_transes = transes[vector]`, where vector is a torch.BoolTensor
60 |            with `length = len(transes)`. Nonzero elements in the vector will be selected.
61 | 
62 |         Note that the returned Translations might share storage with this Translations,
63 |         subject to Pytorch's indexing semantics.
64 |         """
65 |         if isinstance(item, int):
66 |             return Translations(self.tensor[item].view(1, -1))
67 |         b = self.tensor[item]
68 |         assert b.ndim == 2, "Indexing on Translations with {} failed to return a matrix!".format(item)
69 |         return Translations(b)
70 | 
71 |     def __len__(self) -> int:
72 |         return self.tensor.shape[0]
73 | 
74 |     def __repr__(self) -> str:
75 |         return "Translations(" + str(self.tensor) + ")"
76 | 
77 |     def get_centers_2d(self, K: torch.Tensor) -> torch.Tensor:
78 |         """
79 |         Args:
80 |             K: camera intrinsic matrices, Nx3x3 or 1x3x3
81 |         Returns:
82 |             The 2d projected object centers in a Nx2 array of (x, y).
83 |         """
84 |         bs = self.tensor.shape[0]
85 |         proj = (K @ self.tensor.view(bs, 3, 1)).view(bs, 3)
86 |         centers_2d = proj[:, :2] / proj[:, 2:3]  # Nx2
87 |         return centers_2d
88 | 
89 |     @classmethod
90 |     def cat(cls, transes_list: List["Translations"]) -> "Translations":
91 |         """Concatenates a list of Translations into a single Translations.
92 | 
93 |         Arguments:
94 |             transes_list (list[Translations])
95 | 
96 |         Returns:
97 |             Translations: the concatenated Translations
98 |         """
99 |         if torch.jit.is_scripting():
100 |             # https://github.com/pytorch/pytorch/issues/18627
101 |             # 1. staticmethod can be used in torchscript, But we can not use
102 |             # `type(transes).staticmethod` because torchscript only supports function
103 |             # `type` with input type `torch.Tensor`.
104 |             # 2. classmethod is not fully supported by torchscript. We explicitly assign
105 |             # cls to Translations as a workaround to get torchscript support.
106 |             cls = Translations
107 |         assert isinstance(transes_list, (list, tuple))
108 |         if len(transes_list) == 0:
109 |             return cls(torch.empty(0))
110 |         assert all(isinstance(transes, Translations) for transes in transes_list)
111 | 
112 |         # use torch.cat (v.s.
layers.cat) so the returned transes never share storage with input 113 | cat_transes = cls(transes_list[0])(torch.cat([t.tensor for t in transes_list], dim=0)) 114 | return cat_transes 115 | 116 | @property 117 | def device(self) -> device: 118 | return self.tensor.device 119 | 120 | # type "Iterator[torch.Tensor]", yield, and iter() not supported by torchscript 121 | # https://github.com/pytorch/pytorch/issues/18627 122 | @torch.jit.unused 123 | def __iter__(self) -> Iterator[torch.Tensor]: 124 | """Yield a translation as a Tensor of shape (3,) at a time.""" 125 | yield from self.tensor 126 | -------------------------------------------------------------------------------- /lib/torch_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/lib/torch_utils/__init__.py -------------------------------------------------------------------------------- /lib/torch_utils/color/__init__.py: -------------------------------------------------------------------------------- 1 | from .gray import rgb_to_grayscale, RgbToGrayscale 2 | from .gray import bgr_to_grayscale, BgrToGrayscale 3 | from .rgb import BgrToRgb, bgr_to_rgb 4 | from .rgb import RgbToBgr, rgb_to_bgr 5 | from .rgb import RgbToRgba, rgb_to_rgba 6 | from .rgb import BgrToRgba, bgr_to_rgba 7 | from .rgb import RgbaToRgb, rgba_to_rgb 8 | from .rgb import RgbaToBgr, rgba_to_bgr 9 | from .hsv import RgbToHsv, rgb_to_hsv 10 | from .hsv import HsvToRgb, hsv_to_rgb 11 | from .hls import RgbToHls, rgb_to_hls 12 | from .hls import HlsToRgb, hls_to_rgb 13 | from .ycbcr import RgbToYcbcr, rgb_to_ycbcr 14 | from .ycbcr import YcbcrToRgb, ycbcr_to_rgb 15 | from .yuv import RgbToYuv, YuvToRgb, rgb_to_yuv, yuv_to_rgb 16 | from .xyz import RgbToXyz, XyzToRgb, rgb_to_xyz, xyz_to_rgb 17 | from .luv import RgbToLuv, LuvToRgb, rgb_to_luv, luv_to_rgb 18 | from .lab import RgbToLab, LabToRgb, rgb_to_lab, lab_to_rgb 19 | 20 | 21 | __all__ = [ 22 | "rgb_to_grayscale", 23 | "bgr_to_grayscale", 24 | "bgr_to_rgb", 25 | "rgb_to_bgr", 26 | "rgb_to_rgba", 27 | "rgb_to_hsv", 28 | "hsv_to_rgb", 29 | "rgb_to_hls", 30 | "hls_to_rgb", 31 | "rgb_to_ycbcr", 32 | "ycbcr_to_rgb", 33 | "rgb_to_yuv", 34 | "yuv_to_rgb", 35 | "rgb_to_xyz", 36 | "xyz_to_rgb", 37 | "rgb_to_lab", 38 | "lab_to_rgb", 39 | "RgbToGrayscale", 40 | "BgrToGrayscale", 41 | "BgrToRgb", 42 | "RgbToBgr", 43 | "RgbToRgba", 44 | "RgbToHsv", 45 | "HsvToRgb", 46 | "RgbToHls", 47 | "HlsToRgb", 48 | "RgbToYcbcr", 49 | "YcbcrToRgb", 50 | "RgbToYuv", 51 | "YuvToRgb", 52 | "RgbToXyz", 53 | "XyzToRgb", 54 | "RgbToLuv", 55 | "LuvToRgb", 56 | "LabToRgb", 57 | "RgbToLab", 58 | ] 59 | -------------------------------------------------------------------------------- /lib/torch_utils/color/gray.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .rgb import bgr_to_rgb 5 | 6 | 7 | def rgb_to_grayscale(image: torch.Tensor) -> torch.Tensor: 8 | r"""Convert a RGB image to grayscale version of image. 9 | 10 | The image data is assumed to be in the range of (0, 1). 11 | 12 | Args: 13 | image (torch.Tensor): RGB image to be converted to grayscale with shape :math:`(*,3,H,W)`. 14 | 15 | Returns: 16 | torch.Tensor: grayscale version of the image with shape :math:`(*,1,H,W)`. 
17 | 18 | Example: 19 | >>> input = torch.rand(2, 3, 4, 5) 20 | >>> gray = rgb_to_grayscale(input) # 2x1x4x5 21 | """ 22 | if not isinstance(image, torch.Tensor): 23 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(image))) 24 | 25 | if len(image.shape) < 3 or image.shape[-3] != 3: 26 | raise ValueError("Input size must have a shape of (*, 3, H, W). Got {}".format(image.shape)) 27 | 28 | r: torch.Tensor = image[..., 0:1, :, :] 29 | g: torch.Tensor = image[..., 1:2, :, :] 30 | b: torch.Tensor = image[..., 2:3, :, :] 31 | 32 | gray: torch.Tensor = 0.299 * r + 0.587 * g + 0.114 * b 33 | return gray 34 | 35 | 36 | def bgr_to_grayscale(image: torch.Tensor) -> torch.Tensor: 37 | r"""Convert a BGR image to grayscale. 38 | 39 | The image data is assumed to be in the range of (0, 1). First flips to RGB, then converts. 40 | 41 | Args: 42 | image (torch.Tensor): BGR image to be converted to grayscale with shape :math:`(*,3,H,W)`. 43 | 44 | Returns: 45 | torch.Tensor: grayscale version of the image with shape :math:`(*,1,H,W)`. 46 | 47 | Example: 48 | >>> input = torch.rand(2, 3, 4, 5) 49 | >>> gray = bgr_to_grayscale(input) # 2x1x4x5 50 | """ 51 | if not isinstance(image, torch.Tensor): 52 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(image))) 53 | 54 | if len(image.shape) < 3 or image.shape[-3] != 3: 55 | raise ValueError("Input size must have a shape of (*, 3, H, W). Got {}".format(image.shape)) 56 | 57 | image_rgb = bgr_to_rgb(image) 58 | gray: torch.Tensor = rgb_to_grayscale(image_rgb) 59 | return gray 60 | 61 | 62 | class RgbToGrayscale(nn.Module): 63 | r"""Module to convert a RGB image to grayscale version of image. 64 | 65 | The image data is assumed to be in the range of (0, 1). 66 | 67 | Shape: 68 | - image: :math:`(*, 3, H, W)` 69 | - output: :math:`(*, 1, H, W)` 70 | 71 | reference: 72 | https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html 73 | 74 | Example: 75 | >>> input = torch.rand(2, 3, 4, 5) 76 | >>> gray = RgbToGrayscale() 77 | >>> output = gray(input) # 2x1x4x5 78 | """ 79 | 80 | def __init__(self) -> None: 81 | super(RgbToGrayscale, self).__init__() 82 | 83 | def forward(self, image: torch.Tensor) -> torch.Tensor: # type: ignore 84 | return rgb_to_grayscale(image) 85 | 86 | 87 | class BgrToGrayscale(nn.Module): 88 | r"""Module to convert a BGR image to grayscale version of image. 89 | 90 | The image data is assumed to be in the range of (0, 1). First flips to RGB, then converts. 91 | 92 | Shape: 93 | - image: :math:`(*, 3, H, W)` 94 | - output: :math:`(*, 1, H, W)` 95 | 96 | reference: 97 | https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html 98 | 99 | Example: 100 | >>> input = torch.rand(2, 3, 4, 5) 101 | >>> gray = BgrToGrayscale() 102 | >>> output = gray(input) # 2x1x4x5 103 | """ 104 | 105 | def __init__(self) -> None: 106 | super(BgrToGrayscale, self).__init__() 107 | 108 | def forward(self, image: torch.Tensor) -> torch.Tensor: # type: ignore 109 | return bgr_to_grayscale(image) 110 | -------------------------------------------------------------------------------- /lib/torch_utils/color/xyz.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def rgb_to_xyz(image: torch.Tensor) -> torch.Tensor: 6 | r"""Converts a RGB image to XYZ. 7 | 8 | Args: 9 | image (torch.Tensor): RGB Image to be converted to XYZ with shape :math:`(*, 3, H, W)`. 
10 | 11 | Returns: 12 | torch.Tensor: XYZ version of the image with shape :math:`(*, 3, H, W)`. 13 | 14 | Example: 15 | >>> input = torch.rand(2, 3, 4, 5) 16 | >>> output = rgb_to_xyz(input) # 2x3x4x5 17 | """ 18 | if not isinstance(image, torch.Tensor): 19 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(image))) 20 | 21 | if len(image.shape) < 3 or image.shape[-3] != 3: 22 | raise ValueError("Input size must have a shape of (*, 3, H, W). Got {}".format(image.shape)) 23 | 24 | r: torch.Tensor = image[..., 0, :, :] 25 | g: torch.Tensor = image[..., 1, :, :] 26 | b: torch.Tensor = image[..., 2, :, :] 27 | 28 | x: torch.Tensor = 0.412453 * r + 0.357580 * g + 0.180423 * b 29 | y: torch.Tensor = 0.212671 * r + 0.715160 * g + 0.072169 * b 30 | z: torch.Tensor = 0.019334 * r + 0.119193 * g + 0.950227 * b 31 | 32 | out: torch.Tensor = torch.stack([x, y, z], -3) 33 | 34 | return out 35 | 36 | 37 | def xyz_to_rgb(image: torch.Tensor) -> torch.Tensor: 38 | r"""Converts a XYZ image to RGB. 39 | 40 | Args: 41 | image (torch.Tensor): XYZ Image to be converted to RGB with shape :math:`(*, 3, H, W)`. 42 | 43 | Returns: 44 | torch.Tensor: RGB version of the image with shape :math:`(*, 3, H, W)`. 45 | 46 | Example: 47 | >>> input = torch.rand(2, 3, 4, 5) 48 | >>> output = xyz_to_rgb(input) # 2x3x4x5 49 | """ 50 | if not isinstance(image, torch.Tensor): 51 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(image))) 52 | 53 | if len(image.shape) < 3 or image.shape[-3] != 3: 54 | raise ValueError("Input size must have a shape of (*, 3, H, W). Got {}".format(image.shape)) 55 | 56 | x: torch.Tensor = image[..., 0, :, :] 57 | y: torch.Tensor = image[..., 1, :, :] 58 | z: torch.Tensor = image[..., 2, :, :] 59 | 60 | r: torch.Tensor = 3.2404813432005266 * x + -1.5371515162713185 * y + -0.4985363261688878 * z 61 | g: torch.Tensor = -0.9692549499965682 * x + 1.8759900014898907 * y + 0.0415559265582928 * z 62 | b: torch.Tensor = 0.0556466391351772 * x + -0.2040413383665112 * y + 1.0573110696453443 * z 63 | 64 | out: torch.Tensor = torch.stack([r, g, b], dim=-3) 65 | 66 | return out 67 | 68 | 69 | class RgbToXyz(nn.Module): 70 | r"""Converts an image from RGB to XYZ. 71 | 72 | The image data is assumed to be in the range of (0, 1). 73 | 74 | Returns: 75 | torch.Tensor: XYZ version of the image. 76 | 77 | Shape: 78 | - image: :math:`(*, 3, H, W)` 79 | - output: :math:`(*, 3, H, W)` 80 | 81 | Examples: 82 | >>> input = torch.rand(2, 3, 4, 5) 83 | >>> xyz = RgbToXyz() 84 | >>> output = xyz(input) # 2x3x4x5 85 | 86 | Reference: 87 | [1] https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html 88 | """ 89 | 90 | def __init__(self) -> None: 91 | super(RgbToXyz, self).__init__() 92 | 93 | def forward(self, image: torch.Tensor) -> torch.Tensor: 94 | return rgb_to_xyz(image) 95 | 96 | 97 | class XyzToRgb(nn.Module): 98 | r"""Converts an image from XYZ to RGB. 99 | 100 | Returns: 101 | torch.Tensor: RGB version of the image. 
102 | 103 | Shape: 104 | - image: :math:`(*, 3, H, W)` 105 | - output: :math:`(*, 3, H, W)` 106 | 107 | Examples: 108 | >>> input = torch.rand(2, 3, 4, 5) 109 | >>> rgb = XyzToRgb() 110 | >>> output = rgb(input) # 2x3x4x5 111 | 112 | Reference: 113 | [1] https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html 114 | """ 115 | 116 | def __init__(self) -> None: 117 | super(XyzToRgb, self).__init__() 118 | 119 | def forward(self, image: torch.Tensor) -> torch.Tensor: 120 | return xyz_to_rgb(image) 121 | -------------------------------------------------------------------------------- /lib/torch_utils/color/ycbcr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def rgb_to_ycbcr(image: torch.Tensor) -> torch.Tensor: 6 | r"""Convert an RGB image to YCbCr. 7 | 8 | Args: 9 | image (torch.Tensor): RGB Image to be converted to YCbCr with shape :math:`(*, 3, H, W)`. 10 | 11 | Returns: 12 | torch.Tensor: YCbCr version of the image with shape :math:`(*, 3, H, W)`. 13 | 14 | Examples: 15 | >>> input = torch.rand(2, 3, 4, 5) 16 | >>> output = rgb_to_ycbcr(input) # 2x3x4x5 17 | """ 18 | if not isinstance(image, torch.Tensor): 19 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(image))) 20 | 21 | if len(image.shape) < 3 or image.shape[-3] != 3: 22 | raise ValueError("Input size must have a shape of (*, 3, H, W). Got {}".format(image.shape)) 23 | 24 | r: torch.Tensor = image[..., 0, :, :] 25 | g: torch.Tensor = image[..., 1, :, :] 26 | b: torch.Tensor = image[..., 2, :, :] 27 | 28 | delta: float = 0.5 29 | y: torch.Tensor = 0.299 * r + 0.587 * g + 0.114 * b 30 | cb: torch.Tensor = (b - y) * 0.564 + delta 31 | cr: torch.Tensor = (r - y) * 0.713 + delta 32 | return torch.stack([y, cb, cr], -3) 33 | 34 | 35 | def ycbcr_to_rgb(image: torch.Tensor) -> torch.Tensor: 36 | r"""Convert an YCbCr image to RGB. 37 | 38 | The image data is assumed to be in the range of (0, 1). 39 | 40 | Args: 41 | image (torch.Tensor): YCbCr Image to be converted to RGB with shape :math:`(*, 3, H, W)`. 42 | 43 | Returns: 44 | torch.Tensor: RGB version of the image with shape :math:`(*, 3, H, W)`. 45 | 46 | Examples: 47 | >>> input = torch.rand(2, 3, 4, 5) 48 | >>> output = ycbcr_to_rgb(input) # 2x3x4x5 49 | """ 50 | if not isinstance(image, torch.Tensor): 51 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(image))) 52 | 53 | if len(image.shape) < 3 or image.shape[-3] != 3: 54 | raise ValueError("Input size must have a shape of (*, 3, H, W). Got {}".format(image.shape)) 55 | 56 | y: torch.Tensor = image[..., 0, :, :] 57 | cb: torch.Tensor = image[..., 1, :, :] 58 | cr: torch.Tensor = image[..., 2, :, :] 59 | 60 | delta: float = 0.5 61 | cb_shifted: torch.Tensor = cb - delta 62 | cr_shifted: torch.Tensor = cr - delta 63 | 64 | r: torch.Tensor = y + 1.403 * cr_shifted 65 | g: torch.Tensor = y - 0.714 * cr_shifted - 0.344 * cb_shifted 66 | b: torch.Tensor = y + 1.773 * cb_shifted 67 | return torch.stack([r, g, b], -3) 68 | 69 | 70 | class RgbToYcbcr(nn.Module): 71 | r"""Convert an image from RGB to YCbCr. 72 | 73 | The image data is assumed to be in the range of (0, 1). 74 | 75 | Returns: 76 | torch.Tensor: YCbCr version of the image. 
77 | 78 | Shape: 79 | - image: :math:`(*, 3, H, W)` 80 | - output: :math:`(*, 3, H, W)` 81 | 82 | Examples: 83 | >>> input = torch.rand(2, 3, 4, 5) 84 | >>> ycbcr = RgbToYcbcr() 85 | >>> output = ycbcr(input) # 2x3x4x5 86 | """ 87 | 88 | def __init__(self) -> None: 89 | super(RgbToYcbcr, self).__init__() 90 | 91 | def forward(self, image: torch.Tensor) -> torch.Tensor: 92 | return rgb_to_ycbcr(image) 93 | 94 | 95 | class YcbcrToRgb(nn.Module): 96 | r"""Convert an image from YCbCr to Rgb. 97 | 98 | The image data is assumed to be in the range of (0, 1). 99 | 100 | Returns: 101 | torch.Tensor: RGB version of the image. 102 | 103 | Shape: 104 | - image: :math:`(*, 3, H, W)` 105 | - output: :math:`(*, 3, H, W)` 106 | 107 | Examples: 108 | >>> input = torch.rand(2, 3, 4, 5) 109 | >>> rgb = YcbcrToRgb() 110 | >>> output = rgb(input) # 2x3x4x5 111 | """ 112 | 113 | def __init__(self) -> None: 114 | super(YcbcrToRgb, self).__init__() 115 | 116 | def forward(self, image: torch.Tensor) -> torch.Tensor: 117 | return ycbcr_to_rgb(image) 118 | -------------------------------------------------------------------------------- /lib/torch_utils/color/yuv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def rgb_to_yuv(image: torch.Tensor) -> torch.Tensor: 6 | r"""Convert an RGB image to YUV. 7 | 8 | The image data is assumed to be in the range of (0, 1). 9 | 10 | Args: 11 | image (torch.Tensor): RGB Image to be converted to YUV with shape :math:`(*, 3, H, W)`. 12 | 13 | Returns: 14 | torch.Tensor: YUV version of the image with shape :math:`(*, 3, H, W)`. 15 | 16 | Example: 17 | >>> input = torch.rand(2, 3, 4, 5) 18 | >>> output = rgb_to_yuv(input) # 2x3x4x5 19 | """ 20 | if not isinstance(image, torch.Tensor): 21 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(image))) 22 | 23 | if len(image.shape) < 3 or image.shape[-3] != 3: 24 | raise ValueError("Input size must have a shape of (*, 3, H, W). Got {}".format(image.shape)) 25 | 26 | r: torch.Tensor = image[..., 0, :, :] 27 | g: torch.Tensor = image[..., 1, :, :] 28 | b: torch.Tensor = image[..., 2, :, :] 29 | 30 | y: torch.Tensor = 0.299 * r + 0.587 * g + 0.114 * b 31 | u: torch.Tensor = -0.147 * r - 0.289 * g + 0.436 * b 32 | v: torch.Tensor = 0.615 * r - 0.515 * g - 0.100 * b 33 | 34 | out: torch.Tensor = torch.stack([y, u, v], -3) 35 | 36 | return out 37 | 38 | 39 | def yuv_to_rgb(image: torch.Tensor) -> torch.Tensor: 40 | r"""Convert an YUV image to RGB. 41 | 42 | The image data is assumed to be in the range of (0, 1). 43 | 44 | Args: 45 | image (torch.Tensor): YUV Image to be converted to RGB with shape :math:`(*, 3, H, W)`. 46 | 47 | Returns: 48 | torch.Tensor: RGB version of the image with shape :math:`(*, 3, H, W)`. 49 | 50 | Example: 51 | >>> input = torch.rand(2, 3, 4, 5) 52 | >>> output = yuv_to_rgb(input) # 2x3x4x5 53 | """ 54 | if not isinstance(image, torch.Tensor): 55 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(image))) 56 | 57 | if len(image.shape) < 3 or image.shape[-3] != 3: 58 | raise ValueError("Input size must have a shape of (*, 3, H, W). 
Got {}".format(image.shape)) 59 | 60 | y: torch.Tensor = image[..., 0, :, :] 61 | u: torch.Tensor = image[..., 1, :, :] 62 | v: torch.Tensor = image[..., 2, :, :] 63 | 64 | r: torch.Tensor = y + 1.14 * v # coefficient for g is 0 65 | g: torch.Tensor = y + -0.396 * u - 0.581 * v 66 | b: torch.Tensor = y + 2.029 * u # coefficient for b is 0 67 | 68 | out: torch.Tensor = torch.stack([r, g, b], -3) 69 | 70 | return out 71 | 72 | 73 | class RgbToYuv(nn.Module): 74 | r"""Convert an image from RGB to YUV. 75 | 76 | The image data is assumed to be in the range of (0, 1). 77 | 78 | Returns: 79 | torch.Tensor: YUV version of the image. 80 | 81 | Shape: 82 | - image: :math:`(*, 3, H, W)` 83 | - output: :math:`(*, 3, H, W)` 84 | 85 | Examples: 86 | >>> input = torch.rand(2, 3, 4, 5) 87 | >>> yuv = RgbToYuv() 88 | >>> output = yuv(input) # 2x3x4x5 89 | 90 | Reference:: 91 | [1] https://es.wikipedia.org/wiki/YUV#RGB_a_Y'UV 92 | """ 93 | 94 | def __init__(self) -> None: 95 | super(RgbToYuv, self).__init__() 96 | 97 | def forward(self, input: torch.Tensor) -> torch.Tensor: 98 | return rgb_to_yuv(input) 99 | 100 | 101 | class YuvToRgb(nn.Module): 102 | r"""Convert an image from YUV to RGB. 103 | 104 | The image data is assumed to be in the range of (0, 1). 105 | 106 | Returns: 107 | torch.Tensor: RGB version of the image. 108 | 109 | Shape: 110 | - image: :math:`(*, 3, H, W)` 111 | - output: :math:`(*, 3, H, W)` 112 | 113 | Examples: 114 | >>> input = torch.rand(2, 3, 4, 5) 115 | >>> rgb = YuvToRgb() 116 | >>> output = rgb(input) # 2x3x4x5 117 | """ 118 | 119 | def __init__(self) -> None: 120 | super(YuvToRgb, self).__init__() 121 | 122 | def forward(self, input: torch.Tensor) -> torch.Tensor: 123 | return yuv_to_rgb(input) 124 | -------------------------------------------------------------------------------- /lib/torch_utils/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/lib/torch_utils/layers/__init__.py -------------------------------------------------------------------------------- /lib/torch_utils/layers/acon.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class AconC(nn.Module): 6 | r"""ACON activation (activate or not). 7 | # AconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is a learnable parameter 8 | # according to "Activate or Not: Learning Customized Activation" . 9 | """ 10 | 11 | def __init__(self, width): 12 | super().__init__() 13 | self.p1 = nn.Parameter(torch.randn(1, width, 1, 1)) 14 | self.p2 = nn.Parameter(torch.randn(1, width, 1, 1)) 15 | self.beta = nn.Parameter(torch.ones(1, width, 1, 1)) 16 | 17 | def forward(self, x): 18 | return (self.p1 * x - self.p2 * x) * torch.sigmoid(self.beta * (self.p1 * x - self.p2 * x)) + self.p2 * x 19 | 20 | 21 | class MetaAconC(nn.Module): 22 | r"""ACON activation (activate or not). 23 | # MetaAconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is generated by a small network 24 | # according to "Activate or Not: Learning Customized Activation" . 
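    # Illustrative usage (a sketch; the width and the input sizes below are arbitrary):
    #   act = MetaAconC(width=64)
    #   y = act(torch.randn(2, 64, 32, 32))  # output keeps the input shape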
25 | """ 26 | 27 | def __init__(self, width, r=16): 28 | super().__init__() 29 | self.fc1 = nn.Conv2d(width, max(r, width // r), kernel_size=1, stride=1, bias=True) 30 | self.bn1 = nn.BatchNorm2d(max(r, width // r)) 31 | self.fc2 = nn.Conv2d(max(r, width // r), width, kernel_size=1, stride=1, bias=True) 32 | self.bn2 = nn.BatchNorm2d(width) 33 | 34 | self.p1 = nn.Parameter(torch.randn(1, width, 1, 1)) 35 | self.p2 = nn.Parameter(torch.randn(1, width, 1, 1)) 36 | 37 | def forward(self, x): 38 | beta = torch.sigmoid( 39 | self.bn2(self.fc2(self.bn1(self.fc1(x.mean(dim=2, keepdims=True).mean(dim=3, keepdims=True))))) 40 | ) 41 | return (self.p1 * x - self.p2 * x) * torch.sigmoid(beta * (self.p1 * x - self.p2 * x)) + self.p2 * x 42 | -------------------------------------------------------------------------------- /lib/torch_utils/layers/coord_attention.py: -------------------------------------------------------------------------------- 1 | """Coordinate Attention (CVPR 2021). 2 | 3 | Modified from https://github.com/Andrew-Qibin/CoordAttention/blob/main/coordatt.py. 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import math 9 | import torch.nn.functional as F 10 | 11 | 12 | class h_sigmoid(nn.Module): 13 | def __init__(self, inplace=True): 14 | super(h_sigmoid, self).__init__() 15 | self.relu = nn.ReLU6(inplace=inplace) 16 | 17 | def forward(self, x): 18 | return self.relu(x + 3) / 6 19 | 20 | 21 | class h_swish(nn.Module): 22 | def __init__(self, inplace=True): 23 | super(h_swish, self).__init__() 24 | self.sigmoid = h_sigmoid(inplace=inplace) 25 | 26 | def forward(self, x): 27 | return x * self.sigmoid(x) 28 | 29 | 30 | class CoordAtt(nn.Module): 31 | def __init__(self, inp, oup, reduction=32): 32 | super(CoordAtt, self).__init__() 33 | self.pool_h = nn.AdaptiveAvgPool2d((None, 1)) 34 | self.pool_w = nn.AdaptiveAvgPool2d((1, None)) 35 | 36 | mip = max(8, inp // reduction) 37 | 38 | self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0) 39 | self.bn1 = nn.BatchNorm2d(mip) 40 | self.act = h_swish() 41 | 42 | self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0) 43 | self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0) 44 | 45 | def forward(self, x): 46 | identity = x 47 | 48 | n, c, h, w = x.size() 49 | x_h = self.pool_h(x) 50 | x_w = self.pool_w(x).permute(0, 1, 3, 2) 51 | 52 | y = torch.cat([x_h, x_w], dim=2) 53 | y = self.conv1(y) 54 | y = self.bn1(y) 55 | y = self.act(y) 56 | 57 | x_h, x_w = torch.split(y, [h, w], dim=2) 58 | x_w = x_w.permute(0, 1, 3, 2) 59 | 60 | a_h = self.conv_h(x_h).sigmoid() 61 | a_w = self.conv_w(x_w).sigmoid() 62 | 63 | out = identity * a_w * a_h 64 | 65 | return out 66 | -------------------------------------------------------------------------------- /lib/torch_utils/layers/dropblock/README.md: -------------------------------------------------------------------------------- 1 | # DropBlock 2 | 3 | ![build](https://travis-ci.org/miguelvr/dropblock.png?branch=master) 4 | [![Downloads](https://pepy.tech/badge/dropblock)](https://pepy.tech/project/dropblock) 5 | 6 | 7 | Implementation of [DropBlock: A regularization method for convolutional networks](https://arxiv.org/pdf/1810.12890.pdf) 8 | in PyTorch. 9 | 10 | ## Abstract 11 | 12 | Deep neural networks often work well when they are over-parameterized 13 | and trained with a massive amount of noise and regularization, such as 14 | weight decay and dropout. 
Although dropout is widely used as a regularization 15 | technique for fully connected layers, it is often less effective for convolutional layers. 16 | This lack of success of dropout for convolutional layers is perhaps due to the fact 17 | that activation units in convolutional layers are spatially correlated so 18 | information can still flow through convolutional networks despite dropout. 19 | Thus a structured form of dropout is needed to regularize convolutional networks. 20 | In this paper, we introduce DropBlock, a form of structured dropout, where units in a 21 | contiguous region of a feature map are dropped together. 22 | We found that applying DropBlock in skip connections in addition to the 23 | convolution layers increases the accuracy. Also, gradually increasing number 24 | of dropped units during training leads to better accuracy and more robust to hyperparameter choices. 25 | Extensive experiments show that DropBlock works better than dropout in regularizing 26 | convolutional networks. On ImageNet classification, ResNet-50 architecture with 27 | DropBlock achieves 78.13% accuracy, which is more than 1.6% improvement on the baseline. 28 | On COCO detection, DropBlock improves Average Precision of RetinaNet from 36.8% to 38.4%. 29 | 30 | 31 | ## Installation 32 | 33 | Install directly from PyPI: 34 | 35 | pip install dropblock 36 | 37 | or the bleeding edge version from github: 38 | 39 | pip install git+https://github.com/miguelvr/dropblock.git#egg=dropblock 40 | 41 | **NOTE**: Implementation and tests were done in Python 3.6, if you have problems with other versions of python please open an issue. 42 | 43 | ## Usage 44 | 45 | 46 | For 2D inputs (DropBlock2D): 47 | 48 | ```python 49 | import torch 50 | from dropblock import DropBlock2D 51 | 52 | # (bsize, n_feats, height, width) 53 | x = torch.rand(100, 10, 16, 16) 54 | 55 | drop_block = DropBlock2D(block_size=3, drop_prob=0.3) 56 | regularized_x = drop_block(x) 57 | ``` 58 | 59 | For 3D inputs (DropBlock3D): 60 | 61 | ```python 62 | import torch 63 | from dropblock import DropBlock3D 64 | 65 | # (bsize, n_feats, depth, height, width) 66 | x = torch.rand(100, 10, 16, 16, 16) 67 | 68 | drop_block = DropBlock3D(block_size=3, drop_prob=0.3) 69 | regularized_x = drop_block(x) 70 | ``` 71 | 72 | Scheduled Dropblock: 73 | 74 | ```python 75 | import torch 76 | from dropblock import DropBlock2D, LinearScheduler 77 | 78 | # (bsize, n_feats, depth, height, width) 79 | loader = [torch.rand(20, 10, 16, 16) for _ in range(10)] 80 | 81 | drop_block = LinearScheduler( 82 | DropBlock2D(block_size=3, drop_prob=0.), 83 | start_value=0., 84 | stop_value=0.25, 85 | nr_steps=5 86 | ) 87 | 88 | probs = [] 89 | for x in loader: 90 | drop_block.step() 91 | regularized_x = drop_block(x) 92 | probs.append(drop_block.dropblock.drop_prob) 93 | 94 | print(probs) 95 | ``` 96 | 97 | The drop probabilities will be: 98 | ``` 99 | >>> [0. , 0.0625, 0.125 , 0.1875, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25] 100 | ``` 101 | 102 | The user should include the `step()` call at the start of the batch loop, 103 | or at the the start of a model's `forward` call. 104 | 105 | Check [examples/resnet-cifar10.py](examples/resnet-cifar10.py) to 106 | see an implementation example. 107 | 108 | ## Implementation details 109 | 110 | We use `drop_prob` instead of `keep_prob` as a matter of preference, 111 | and to keep the argument consistent with pytorch's dropout. 112 | Regardless, everything else should work similarly to what is described in the paper. 
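
To make the `drop_prob` convention concrete, here is a minimal sketch (not part of the
upstream README; the toy network, probability and tensor sizes are arbitrary) showing that
DropBlock is only active in training mode:

```python
import torch
from torch import nn
from dropblock import DropBlock2D

# toy model with DropBlock placed after an early convolution
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    DropBlock2D(drop_prob=0.1, block_size=5),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
)

x = torch.rand(4, 3, 32, 32)

model.train()
y_train = model(x)  # contiguous blocks are zeroed and surviving activations are rescaled

model.eval()
y_eval = model(x)   # DropBlock is a no-op outside training
```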
113 | 114 | ## Benchmark 115 | 116 | Refer to [BENCHMARK.md](BENCHMARK.md) 117 | 118 | ## Reference 119 | [Ghiasi et al., 2018] DropBlock: A regularization method for convolutional networks 120 | 121 | ## TODO 122 | - [x] Scheduled DropBlock 123 | - [x] Get benchmark numbers 124 | - [x] Extend the concept for 3D images 125 | -------------------------------------------------------------------------------- /lib/torch_utils/layers/dropblock/__init__.py: -------------------------------------------------------------------------------- 1 | from .dropblock import DropBlock2D, DropBlock3D 2 | from .scheduler import LinearScheduler 3 | 4 | __all__ = ["DropBlock2D", "DropBlock3D", "LinearScheduler"] 5 | -------------------------------------------------------------------------------- /lib/torch_utils/layers/dropblock/dropblock.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | class DropBlock2D(nn.Module): 7 | r"""Randomly zeroes 2D spatial blocks of the input tensor. 8 | 9 | As described in the paper 10 | `DropBlock: A regularization method for convolutional networks`_ , 11 | dropping whole blocks of feature map allows to remove semantic 12 | information as compared to regular dropout. 13 | 14 | Args: 15 | drop_prob (float): probability of an element to be dropped. 16 | block_size (int): size of the block to drop 17 | 18 | Shape: 19 | - Input: `(N, C, H, W)` 20 | - Output: `(N, C, H, W)` 21 | 22 | .. _DropBlock: A regularization method for convolutional networks: 23 | https://arxiv.org/abs/1810.12890 24 | 25 | """ 26 | 27 | def __init__(self, drop_prob, block_size): 28 | super(DropBlock2D, self).__init__() 29 | 30 | self.drop_prob = drop_prob 31 | self.block_size = block_size 32 | 33 | def forward(self, x): 34 | # shape: (bsize, channels, height, width) 35 | 36 | assert x.dim() == 4, "Expected input with 4 dimensions (bsize, channels, height, width)" 37 | 38 | if not self.training or self.drop_prob == 0.0: 39 | return x 40 | else: 41 | # get gamma value 42 | gamma = self._compute_gamma(x) 43 | 44 | # sample mask 45 | mask = (torch.rand(x.shape[0], *x.shape[2:]) < gamma).float() 46 | 47 | # place mask on input device 48 | mask = mask.to(x.device) 49 | 50 | # compute block mask 51 | block_mask = self._compute_block_mask(mask) 52 | 53 | # apply block mask 54 | out = x * block_mask[:, None, :, :] 55 | 56 | # scale output 57 | out = out * block_mask.numel() / block_mask.sum() 58 | 59 | return out 60 | 61 | def _compute_block_mask(self, mask): 62 | block_mask = F.max_pool2d( 63 | input=mask[:, None, :, :], 64 | kernel_size=(self.block_size, self.block_size), 65 | stride=(1, 1), 66 | padding=self.block_size // 2, 67 | ) 68 | 69 | if self.block_size % 2 == 0: 70 | block_mask = block_mask[:, :, :-1, :-1] 71 | 72 | block_mask = 1 - block_mask.squeeze(1) 73 | 74 | return block_mask 75 | 76 | def _compute_gamma(self, x): 77 | return self.drop_prob / (self.block_size**2) 78 | 79 | 80 | class DropBlock3D(DropBlock2D): 81 | r"""Randomly zeroes 3D spatial blocks of the input tensor. 82 | 83 | An extension to the concept described in the paper 84 | `DropBlock: A regularization method for convolutional networks`_ , 85 | dropping whole blocks of feature map allows to remove semantic 86 | information as compared to regular dropout. 87 | 88 | Args: 89 | drop_prob (float): probability of an element to be dropped. 
90 | block_size (int): size of the block to drop 91 | 92 | Shape: 93 | - Input: `(N, C, D, H, W)` 94 | - Output: `(N, C, D, H, W)` 95 | 96 | .. _DropBlock: A regularization method for convolutional networks: 97 | https://arxiv.org/abs/1810.12890 98 | 99 | """ 100 | 101 | def __init__(self, drop_prob, block_size): 102 | super(DropBlock3D, self).__init__(drop_prob, block_size) 103 | 104 | def forward(self, x): 105 | # shape: (bsize, channels, depth, height, width) 106 | 107 | assert x.dim() == 5, "Expected input with 5 dimensions (bsize, channels, depth, height, width)" 108 | 109 | if not self.training or self.drop_prob == 0.0: 110 | return x 111 | else: 112 | # get gamma value 113 | gamma = self._compute_gamma(x) 114 | 115 | # sample mask 116 | mask = (torch.rand(x.shape[0], *x.shape[2:]) < gamma).float() 117 | 118 | # place mask on input device 119 | mask = mask.to(x.device) 120 | 121 | # compute block mask 122 | block_mask = self._compute_block_mask(mask) 123 | 124 | # apply block mask 125 | out = x * block_mask[:, None, :, :, :] 126 | 127 | # scale output 128 | out = out * block_mask.numel() / block_mask.sum() 129 | 130 | return out 131 | 132 | def _compute_block_mask(self, mask): 133 | block_mask = F.max_pool3d( 134 | input=mask[:, None, :, :, :], 135 | kernel_size=(self.block_size, self.block_size, self.block_size), 136 | stride=(1, 1, 1), 137 | padding=self.block_size // 2, 138 | ) 139 | 140 | if self.block_size % 2 == 0: 141 | block_mask = block_mask[:, :, :-1, :-1, :-1] 142 | 143 | block_mask = 1 - block_mask.squeeze(1) 144 | 145 | return block_mask 146 | 147 | def _compute_gamma(self, x): 148 | return self.drop_prob / (self.block_size**3) 149 | -------------------------------------------------------------------------------- /lib/torch_utils/layers/dropblock/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch import nn 3 | 4 | 5 | class LinearScheduler(nn.Module): 6 | def __init__(self, dropblock, start_value, stop_value, nr_steps): 7 | super(LinearScheduler, self).__init__() 8 | self.dropblock = dropblock 9 | self.i = 0 10 | self.drop_values = np.linspace(start=start_value, stop=stop_value, num=int(nr_steps)) 11 | 12 | def forward(self, x): 13 | return self.dropblock(x) 14 | 15 | def step(self): 16 | if self.i < len(self.drop_values): 17 | self.dropblock.drop_prob = self.drop_values[self.i] 18 | 19 | self.i += 1 20 | -------------------------------------------------------------------------------- /lib/torch_utils/layers/mean_conv_deconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from typing import Optional, List, Tuple, Union 5 | from torch import Tensor 6 | from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t 7 | 8 | 9 | class MeanConv2d(nn.Conv2d): 10 | """Conv2d with weight centralization. 11 | 12 | ref: Weight and Gradient Centralization in Deep Neural Networks. https://arxiv.org/pdf/2010.00866.pdf 13 | """ 14 | 15 | def forward(self, x): 16 | w = self.weight # [c_out, c_in, k, k] 17 | w = w - torch.mean(w, dim=[1, 2, 3], keepdim=True) 18 | return F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups) 19 | 20 | 21 | class MeanConvTranspose2d(nn.ConvTranspose2d): 22 | """ConvTranspose2d with Weight Centralization. 
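    The weight is centralized by subtracting its mean over all dimensions except the
    first (see ``forward``), mirroring ``MeanConv2d`` above.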
23 | 24 | Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` - 25 | https://arxiv.org/abs/1903.10520v2 26 | """ 27 | 28 | def __init__( 29 | self, 30 | in_channels: int, 31 | out_channels: int, 32 | kernel_size: _size_2_t, 33 | stride: _size_2_t = 1, 34 | padding: _size_2_t = 0, 35 | output_padding: _size_2_t = 0, 36 | groups: int = 1, 37 | bias: bool = True, 38 | dilation: int = 1, 39 | padding_mode: str = "zeros", 40 | device=None, 41 | dtype=None, 42 | eps=1e-6, 43 | ): 44 | super().__init__( 45 | in_channels, 46 | out_channels, 47 | kernel_size, 48 | stride=stride, 49 | padding=padding, 50 | output_padding=output_padding, 51 | dilation=dilation, 52 | groups=groups, 53 | bias=bias, 54 | padding_mode=padding_mode, 55 | device=device, 56 | dtype=dtype, 57 | ) 58 | self.eps = eps 59 | 60 | def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor: 61 | if self.padding_mode != "zeros": 62 | raise ValueError("Only `zeros` padding mode is supported for ConvTranspose2d") 63 | 64 | assert isinstance(self.padding, tuple) 65 | # One cannot replace List by Tuple or Sequence in "_output_padding" because 66 | # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. 67 | output_padding = self._output_padding( 68 | input, output_size, self.stride, self.padding, self.kernel_size, self.dilation 69 | ) # type: ignore[arg-type] 70 | 71 | w = self.weight 72 | w = w - torch.mean(w, dim=[1, 2, 3], keepdim=True) 73 | return F.conv_transpose2d( 74 | input, w, self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation 75 | ) 76 | -------------------------------------------------------------------------------- /lib/torch_utils/layers/std_conv_transpose.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from torch import nn 3 | from typing import Optional, List, Tuple, Union 4 | from torch import Tensor 5 | from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t 6 | 7 | 8 | class StdConvTranspose2d(nn.ConvTranspose2d): 9 | """ConvTranspose2d with Weight Standardization. 10 | 11 | Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` - 12 | https://arxiv.org/abs/1903.10520v2 13 | """ 14 | 15 | def __init__( 16 | self, 17 | in_channels: int, 18 | out_channels: int, 19 | kernel_size: _size_2_t, 20 | stride: _size_2_t = 1, 21 | padding: _size_2_t = 0, 22 | output_padding: _size_2_t = 0, 23 | groups: int = 1, 24 | bias: bool = True, 25 | dilation: int = 1, 26 | padding_mode: str = "zeros", 27 | device=None, 28 | dtype=None, 29 | eps=1e-6, 30 | ): 31 | super().__init__( 32 | in_channels, 33 | out_channels, 34 | kernel_size, 35 | stride=stride, 36 | padding=padding, 37 | output_padding=output_padding, 38 | dilation=dilation, 39 | groups=groups, 40 | bias=bias, 41 | padding_mode=padding_mode, 42 | device=device, 43 | dtype=dtype, 44 | ) 45 | self.eps = eps 46 | 47 | def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor: 48 | if self.padding_mode != "zeros": 49 | raise ValueError("Only `zeros` padding mode is supported for ConvTranspose2d") 50 | 51 | assert isinstance(self.padding, tuple) 52 | # One cannot replace List by Tuple or Sequence in "_output_padding" because 53 | # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. 
54 | output_padding = self._output_padding( 55 | input, output_size, self.stride, self.padding, self.kernel_size, self.dilation 56 | ) # type: ignore[arg-type] 57 | 58 | weight = F.batch_norm( 59 | self.weight.reshape(1, self.out_channels, -1), None, None, training=True, momentum=0.0, eps=self.eps 60 | ).reshape_as(self.weight) 61 | return F.conv_transpose2d( 62 | input, weight, self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation 63 | ) 64 | -------------------------------------------------------------------------------- /lib/torch_utils/misc.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | 5 | 6 | # ---------------------------------------------------------------------------- 7 | # Replace NaN/Inf with specified numerical values. 8 | # (from https://github.com/NVlabs/stylegan2-ada-pytorch/blob/main/torch_utils/misc.py) 9 | 10 | try: 11 | nan_to_num = torch.nan_to_num # 1.8.0a0 12 | except AttributeError: 13 | 14 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin 15 | assert isinstance(input, torch.Tensor) 16 | if posinf is None: 17 | posinf = torch.finfo(input.dtype).max 18 | if neginf is None: 19 | neginf = torch.finfo(input.dtype).min 20 | assert nan == 0 21 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) 22 | 23 | 24 | def set_nan_to_0(a, name=None, verbose=False): 25 | if torch.isnan(a).any(): 26 | if verbose and name is not None: 27 | print("nan in {}".format(name)) 28 | a[a != a] = 0 29 | return a 30 | 31 | 32 | # ---------------------------------------------------------------------------- 33 | # Symbolic assert. 34 | 35 | try: 36 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access 37 | except AttributeError: 38 | symbolic_assert = torch.Assert # 1.7.0 39 | 40 | # ---------------------------------------------------------------------------- 41 | 42 | 43 | class suppress_tracer_warnings(warnings.catch_warnings): 44 | """Context manager to suppress known warnings in torch.jit.trace().""" 45 | 46 | def __enter__(self): 47 | super().__enter__() 48 | warnings.simplefilter("ignore", category=torch.jit.TracerWarning) 49 | return self 50 | 51 | 52 | # ---------------------------------------------------------------------------- 53 | 54 | 55 | def assert_shape(tensor, ref_shape): 56 | """Assert that the shape of a tensor matches the given list of integers. 57 | 58 | None indicates that the size of a dimension is allowed to vary. 59 | Performs symbolic assertion when used in torch.jit.trace(). 
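
    Example (illustrative; the shapes below are arbitrary):
        >>> assert_shape(torch.zeros(2, 3, 4), [2, None, 4])  # passes, dim 1 may vary
        >>> assert_shape(torch.zeros(2, 3), [2, None, 4])     # raises AssertionError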
60 | """ 61 | if tensor.ndim != len(ref_shape): 62 | raise AssertionError(f"Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}") 63 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): 64 | if ref_size is None: 65 | pass 66 | elif isinstance(ref_size, torch.Tensor): 67 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 68 | symbolic_assert( 69 | torch.equal(torch.as_tensor(size), ref_size), 70 | f"Wrong size for dimension {idx}", 71 | ) 72 | elif isinstance(size, torch.Tensor): 73 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 74 | symbolic_assert( 75 | torch.equal(size, torch.as_tensor(ref_size)), 76 | f"Wrong size for dimension {idx}: expected {ref_size}", 77 | ) 78 | elif size != ref_size: 79 | raise AssertionError(f"Wrong size for dimension {idx}: got {size}, expected {ref_size}") 80 | -------------------------------------------------------------------------------- /lib/torch_utils/solver/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/lib/torch_utils/solver/__init__.py -------------------------------------------------------------------------------- /lib/torch_utils/solver/adamp.py: -------------------------------------------------------------------------------- 1 | """AdamP Copyright (c) 2020-present NAVER Corp. 2 | 3 | MIT license 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.optim.optimizer import Optimizer, required 10 | import math 11 | 12 | 13 | class AdamP(Optimizer): 14 | def __init__( 15 | self, 16 | params, 17 | lr=1e-3, 18 | betas=(0.9, 0.999), 19 | eps=1e-8, 20 | weight_decay=0, 21 | delta=0.1, 22 | wd_ratio=0.1, 23 | nesterov=False, 24 | ): 25 | defaults = dict( 26 | lr=lr, 27 | betas=betas, 28 | eps=eps, 29 | weight_decay=weight_decay, 30 | delta=delta, 31 | wd_ratio=wd_ratio, 32 | nesterov=nesterov, 33 | ) 34 | super(AdamP, self).__init__(params, defaults) 35 | 36 | def _channel_view(self, x): 37 | return x.view(x.size(0), -1) 38 | 39 | def _layer_view(self, x): 40 | return x.view(1, -1) 41 | 42 | def _cosine_similarity(self, x, y, eps, view_func): 43 | x = view_func(x) 44 | y = view_func(y) 45 | 46 | return F.cosine_similarity(x, y, dim=1, eps=eps).abs_() 47 | 48 | def _projection(self, p, grad, perturb, delta, wd_ratio, eps): 49 | wd = 1 50 | expand_size = [-1] + [1] * (len(p.shape) - 1) 51 | for view_func in [self._channel_view, self._layer_view]: 52 | 53 | cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func) 54 | 55 | if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)): 56 | p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps) 57 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size) 58 | wd = wd_ratio 59 | 60 | return perturb, wd 61 | 62 | return perturb, wd 63 | 64 | def step(self, closure=None): 65 | loss = None 66 | if closure is not None: 67 | loss = closure() 68 | 69 | for group in self.param_groups: 70 | for p in group["params"]: 71 | if p.grad is None: 72 | continue 73 | 74 | grad = p.grad.data 75 | beta1, beta2 = group["betas"] 76 | nesterov = group["nesterov"] 77 | 78 | state = self.state[p] 79 | 80 | # State initialization 81 | if len(state) == 0: 82 | state["step"] = 0 83 | state["exp_avg"] = torch.zeros_like(p.data) 84 | state["exp_avg_sq"] = torch.zeros_like(p.data) 85 | 86 | 
# Adam 87 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 88 | 89 | state["step"] += 1 90 | bias_correction1 = 1 - beta1 ** state["step"] 91 | bias_correction2 = 1 - beta2 ** state["step"] 92 | 93 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 94 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 95 | 96 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group["eps"]) 97 | step_size = group["lr"] / bias_correction1 98 | 99 | if nesterov: 100 | perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom 101 | else: 102 | perturb = exp_avg / denom 103 | 104 | # Projection 105 | wd_ratio = 1 106 | if len(p.shape) > 1: 107 | perturb, wd_ratio = self._projection( 108 | p, 109 | grad, 110 | perturb, 111 | group["delta"], 112 | group["wd_ratio"], 113 | group["eps"], 114 | ) 115 | 116 | # Weight decay 117 | if group["weight_decay"] > 0: 118 | p.data.mul_(1 - group["lr"] * group["weight_decay"] * wd_ratio) 119 | 120 | # Step 121 | p.data.add_(perturb, alpha=-step_size) 122 | 123 | return loss 124 | -------------------------------------------------------------------------------- /lib/torch_utils/solver/grad_clip_d2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # modified to support full_model gradient norm clip 3 | import copy 4 | import itertools 5 | import logging 6 | from enum import Enum 7 | from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union 8 | from omegaconf.omegaconf import OmegaConf 9 | 10 | import torch 11 | from detectron2.config import CfgNode 12 | from fvcore.common.param_scheduler import CosineParamScheduler, MultiStepParamScheduler 13 | from lib.utils.config_utils import try_get_key 14 | 15 | 16 | _GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]] 17 | _GradientClipper = Callable[[_GradientClipperInput], None] 18 | 19 | 20 | class GradientClipType(Enum): 21 | VALUE = "value" 22 | NORM = "norm" 23 | FULL_MODEL = "full_model" 24 | 25 | 26 | def _create_gradient_clipper(cfg) -> _GradientClipper: 27 | """Creates gradient clipping closure to clip by value or by norm, according 28 | to the provided config.""" 29 | cfg = copy.deepcopy(cfg) 30 | 31 | _clip_value = try_get_key(cfg, "CLIP_VALUE", "clip_value", default=1.0) 32 | _norm_type = try_get_key(cfg, "NORM_TYPE", "norm_type", default=2.0) 33 | 34 | def clip_grad_norm(p: _GradientClipperInput): 35 | torch.nn.utils.clip_grad_norm_(p, _clip_value, _norm_type) 36 | 37 | def clip_grad_value(p: _GradientClipperInput): 38 | torch.nn.utils.clip_grad_value_(p, _clip_value) 39 | 40 | _GRADIENT_CLIP_TYPE_TO_CLIPPER = { 41 | GradientClipType.VALUE: clip_grad_value, 42 | GradientClipType.NORM: clip_grad_norm, 43 | GradientClipType.FULL_MODEL: clip_grad_norm, 44 | } 45 | _clip_type = try_get_key(cfg, "CLIP_TYPE", "clip_type", default="full_model") 46 | return _GRADIENT_CLIP_TYPE_TO_CLIPPER[GradientClipType(_clip_type)] 47 | 48 | 49 | def _generate_optimizer_class_with_gradient_clipping( 50 | optimizer: Type[torch.optim.Optimizer], 51 | *, 52 | per_param_clipper: Optional[_GradientClipper] = None, 53 | global_clipper: Optional[_GradientClipper] = None, 54 | ) -> Type[torch.optim.Optimizer]: 55 | """Dynamically creates a new type that inherits the type of a given 56 | instance and overrides the `step` method to add gradient clipping.""" 57 | assert ( 58 | per_param_clipper is None or global_clipper is None 59 | ), "Not allowed to use both per-parameter clipping and 
global clipping" 60 | 61 | def optimizer_wgc_step(self, closure=None): 62 | if per_param_clipper is not None: 63 | for group in self.param_groups: 64 | for p in group["params"]: 65 | per_param_clipper(p) 66 | else: 67 | # global clipper for future use with detr 68 | # (https://github.com/facebookresearch/detr/pull/287) 69 | all_params = itertools.chain(*[g["params"] for g in self.param_groups]) 70 | global_clipper(all_params) 71 | super(type(self), self).step(closure) 72 | 73 | OptimizerWithGradientClip = type( 74 | optimizer.__name__ + "WithGradientClip", 75 | (optimizer,), 76 | {"step": optimizer_wgc_step}, 77 | ) 78 | return OptimizerWithGradientClip 79 | 80 | 81 | def maybe_add_gradient_clipping(cfg: CfgNode, optimizer: Type[torch.optim.Optimizer]) -> Type[torch.optim.Optimizer]: 82 | """If gradient clipping is enabled through config options, wraps the 83 | existing optimizer type to become a new dynamically created class 84 | OptimizerWithGradientClip that inherits the given optimizer and overrides 85 | the `step` method to include gradient clipping. 86 | 87 | Args: 88 | cfg: CfgNode, configuration options 89 | optimizer: type. A subclass of torch.optim.Optimizer 90 | Return: 91 | type: either the input `optimizer` (if gradient clipping is disabled), or 92 | a subclass of it with gradient clipping included in the `step` method. 93 | """ 94 | clip_cfg = try_get_key( 95 | cfg, "SOLVER.CLIP_GRADIENTS", "train.grad_clip", default=OmegaConf.create(dict(enabled=False)) 96 | ) 97 | if not try_get_key(clip_cfg, "ENABLED", "enabled", default=False): 98 | return optimizer 99 | if isinstance(optimizer, torch.optim.Optimizer): 100 | optimizer_type = type(optimizer) 101 | else: 102 | assert issubclass(optimizer, torch.optim.Optimizer), optimizer 103 | optimizer_type = optimizer 104 | 105 | grad_clipper = _create_gradient_clipper(clip_cfg) 106 | _clip_type = try_get_key(clip_cfg, "CLIP_TYPE", "clip_type", default="full_model") 107 | if _clip_type != "full_model": 108 | OptimizerWithGradientClip = _generate_optimizer_class_with_gradient_clipping( 109 | optimizer_type, per_param_clipper=grad_clipper 110 | ) 111 | else: 112 | OptimizerWithGradientClip = _generate_optimizer_class_with_gradient_clipping( 113 | optimizer_type, global_clipper=grad_clipper 114 | ) 115 | if isinstance(optimizer, torch.optim.Optimizer): 116 | optimizer.__class__ = OptimizerWithGradientClip # a bit hacky, not recommended 117 | return optimizer 118 | else: 119 | return OptimizerWithGradientClip 120 | -------------------------------------------------------------------------------- /lib/torch_utils/solver/lookahead.py: -------------------------------------------------------------------------------- 1 | """Lookahead Optimizer Wrapper. Implementation modified from: 2 | https://github.com/alphadl/lookahead.pytorch. 
3 | 4 | Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 5 | """ 6 | # https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lookahead.py 7 | import torch 8 | from torch.optim.optimizer import Optimizer 9 | from collections import defaultdict 10 | 11 | # from lib.utils import logger 12 | 13 | 14 | class Lookahead(Optimizer): 15 | def __init__(self, base_optimizer, alpha=0.5, k=6): 16 | if not 0.0 <= alpha <= 1.0: 17 | raise ValueError(f"Invalid slow update rate: {alpha}") 18 | if not 1 <= k: 19 | raise ValueError(f"Invalid lookahead steps: {k}") 20 | defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) 21 | self.base_optimizer = base_optimizer 22 | self.param_groups = self.base_optimizer.param_groups 23 | self.defaults = base_optimizer.defaults 24 | self.defaults.update(defaults) 25 | self.state = defaultdict(dict) 26 | # manually add our defaults to the param groups 27 | for name, default in defaults.items(): 28 | for group in self.param_groups: 29 | group.setdefault(name, default) 30 | 31 | def update_slow(self, group): 32 | for fast_p in group["params"]: 33 | if fast_p.grad is None: 34 | continue 35 | param_state = self.state[fast_p] 36 | if "slow_buffer" not in param_state: 37 | param_state["slow_buffer"] = torch.empty_like(fast_p.data) 38 | param_state["slow_buffer"].copy_(fast_p.data) 39 | slow = param_state["slow_buffer"] 40 | slow.add_(group["lookahead_alpha"], fast_p.data - slow) 41 | fast_p.data.copy_(slow) 42 | 43 | def sync_lookahead(self): 44 | for group in self.param_groups: 45 | self.update_slow(group) 46 | 47 | def step(self, closure=None): 48 | # assert id(self.param_groups) == id(self.base_optimizer.param_groups) 49 | loss = self.base_optimizer.step(closure) 50 | for group in self.param_groups: 51 | group["lookahead_step"] += 1 52 | if group["lookahead_step"] % group["lookahead_k"] == 0: 53 | self.update_slow(group) 54 | return loss 55 | 56 | def state_dict(self): 57 | fast_state_dict = self.base_optimizer.state_dict() 58 | slow_state = {(id(k) if isinstance(k, torch.Tensor) else k): v for k, v in self.state.items()} 59 | fast_state = fast_state_dict["state"] 60 | param_groups = fast_state_dict["param_groups"] 61 | return { 62 | "state": fast_state, 63 | "slow_state": slow_state, 64 | "param_groups": param_groups, 65 | } 66 | 67 | def load_state_dict(self, state_dict): 68 | fast_state_dict = { 69 | "state": state_dict["state"], 70 | "param_groups": state_dict["param_groups"], 71 | } 72 | self.base_optimizer.load_state_dict(fast_state_dict) 73 | 74 | # We want to restore the slow state, but share param_groups reference 75 | # with base_optimizer. 
This is a bit redundant but least code 76 | slow_state_new = False 77 | if "slow_state" not in state_dict: 78 | print("Loading state_dict from optimizer without Lookahead applied.") 79 | state_dict["slow_state"] = defaultdict(dict) 80 | slow_state_new = True 81 | slow_state_dict = { 82 | "state": state_dict["slow_state"], 83 | "param_groups": state_dict["param_groups"], # this is pointless but saves code 84 | } 85 | super(Lookahead, self).load_state_dict(slow_state_dict) 86 | self.param_groups = self.base_optimizer.param_groups # make both ref same container 87 | if slow_state_new: 88 | # reapply defaults to catch missing lookahead specific ones 89 | for name, default in self.defaults.items(): 90 | for group in self.param_groups: 91 | group.setdefault(name, default) 92 | -------------------------------------------------------------------------------- /lib/torch_utils/solver/over9000.py: -------------------------------------------------------------------------------- 1 | #### 2 | # CODE TAKEN FROM https://github.com/mgrankin/over9000 3 | #### 4 | 5 | # import torch, math 6 | # from torch.optim.optimizer import Optimizer 7 | # import itertools as it 8 | from .lookahead import Lookahead 9 | from .ralamb import Ralamb 10 | 11 | 12 | # RangerLars = Over9000 = RAdam + LARS + LookAHead 13 | 14 | # Lookahead implementation from https://github.com/lonePatient/lookahead_pytorch/blob/master/optimizer.py 15 | # RAdam + LARS implementation from https://gist.github.com/redknightlois/c4023d393eb8f92bb44b2ab582d7ec20 16 | 17 | 18 | def Over9000(params, alpha=0.5, k=6, *args, **kwargs): 19 | ralamb = Ralamb(params, *args, **kwargs) 20 | return Lookahead(ralamb, alpha, k) 21 | 22 | 23 | RangerLars = Over9000 24 | -------------------------------------------------------------------------------- /lib/torch_utils/solver/ralamb.py: -------------------------------------------------------------------------------- 1 | #### 2 | # CODE TAKEN FROM https://github.com/mgrankin/over9000 3 | #### 4 | 5 | import torch, math 6 | from torch.optim.optimizer import Optimizer 7 | 8 | # RAdam + LARS 9 | class Ralamb(Optimizer): 10 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 11 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 12 | self.buffer = [[None, None, None] for ind in range(10)] 13 | super(Ralamb, self).__init__(params, defaults) 14 | 15 | def __setstate__(self, state): 16 | super(Ralamb, self).__setstate__(state) 17 | 18 | def step(self, closure=None): 19 | 20 | loss = None 21 | if closure is not None: 22 | loss = closure() 23 | 24 | for group in self.param_groups: 25 | 26 | for p in group["params"]: 27 | if p.grad is None: 28 | continue 29 | grad = p.grad.data.float() 30 | if grad.is_sparse: 31 | raise RuntimeError("Ralamb does not support sparse gradients") 32 | 33 | p_data_fp32 = p.data.float() 34 | 35 | state = self.state[p] 36 | 37 | if len(state) == 0: 38 | state["step"] = 0 39 | state["exp_avg"] = torch.zeros_like(p_data_fp32) 40 | state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) 41 | else: 42 | state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32) 43 | state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32) 44 | 45 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 46 | beta1, beta2 = group["betas"] 47 | 48 | # Decay the first and second moment running average coefficient 49 | # m_t 50 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 51 | # v_t 52 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 53 | 54 | state["step"] += 
1 55 | buffered = self.buffer[int(state["step"] % 10)] 56 | 57 | if state["step"] == buffered[0]: 58 | N_sma, radam_step_size = buffered[1], buffered[2] 59 | else: 60 | buffered[0] = state["step"] 61 | beta2_t = beta2 ** state["step"] 62 | N_sma_max = 2 / (1 - beta2) - 1 63 | N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t) 64 | buffered[1] = N_sma 65 | 66 | # more conservative since it's an approximated value 67 | if N_sma >= 5: 68 | radam_step_size = math.sqrt( 69 | (1 - beta2_t) 70 | * (N_sma - 4) 71 | / (N_sma_max - 4) 72 | * (N_sma - 2) 73 | / N_sma 74 | * N_sma_max 75 | / (N_sma_max - 2) 76 | ) / (1 - beta1 ** state["step"]) 77 | else: 78 | radam_step_size = 1.0 / (1 - beta1 ** state["step"]) 79 | buffered[2] = radam_step_size 80 | 81 | if group["weight_decay"] != 0: 82 | p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32) 83 | 84 | # more conservative since it's an approximated value 85 | radam_step = p_data_fp32.clone() 86 | if N_sma >= 5: 87 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 88 | radam_step.addcdiv_(-radam_step_size * group["lr"], exp_avg, denom) 89 | else: 90 | radam_step.add_(-radam_step_size * group["lr"], exp_avg) 91 | 92 | radam_norm = radam_step.pow(2).sum().sqrt() 93 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) 94 | if weight_norm == 0 or radam_norm == 0: 95 | trust_ratio = 1 96 | else: 97 | trust_ratio = weight_norm / radam_norm 98 | 99 | state["weight_norm"] = weight_norm 100 | state["adam_norm"] = radam_norm 101 | state["trust_ratio"] = trust_ratio 102 | 103 | if N_sma >= 5: 104 | p_data_fp32.addcdiv_( 105 | -radam_step_size * group["lr"] * trust_ratio, 106 | exp_avg, 107 | denom, 108 | ) 109 | else: 110 | p_data_fp32.add_(-radam_step_size * group["lr"] * trust_ratio, exp_avg) 111 | 112 | p.data.copy_(p_data_fp32) 113 | 114 | return loss 115 | -------------------------------------------------------------------------------- /lib/torch_utils/solver/sgdp.py: -------------------------------------------------------------------------------- 1 | """AdamP Copyright (c) 2020-present NAVER Corp. 
2 | 3 | MIT license 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.optim.optimizer import Optimizer, required 10 | import math 11 | 12 | 13 | class SGDP(Optimizer): 14 | def __init__( 15 | self, 16 | params, 17 | lr=required, 18 | momentum=0, 19 | dampening=0, 20 | weight_decay=0, 21 | nesterov=False, 22 | eps=1e-8, 23 | delta=0.1, 24 | wd_ratio=0.1, 25 | ): 26 | defaults = dict( 27 | lr=lr, 28 | momentum=momentum, 29 | dampening=dampening, 30 | weight_decay=weight_decay, 31 | nesterov=nesterov, 32 | eps=eps, 33 | delta=delta, 34 | wd_ratio=wd_ratio, 35 | ) 36 | super(SGDP, self).__init__(params, defaults) 37 | 38 | def _channel_view(self, x): 39 | return x.view(x.size(0), -1) 40 | 41 | def _layer_view(self, x): 42 | return x.view(1, -1) 43 | 44 | def _cosine_similarity(self, x, y, eps, view_func): 45 | x = view_func(x) 46 | y = view_func(y) 47 | 48 | return F.cosine_similarity(x, y, dim=1, eps=eps).abs_() 49 | 50 | def _projection(self, p, grad, perturb, delta, wd_ratio, eps): 51 | wd = 1 52 | expand_size = [-1] + [1] * (len(p.shape) - 1) 53 | for view_func in [self._channel_view, self._layer_view]: 54 | 55 | cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func) 56 | 57 | if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)): 58 | p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps) 59 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size) 60 | wd = wd_ratio 61 | 62 | return perturb, wd 63 | 64 | return perturb, wd 65 | 66 | def step(self, closure=None): 67 | loss = None 68 | if closure is not None: 69 | loss = closure() 70 | 71 | for group in self.param_groups: 72 | momentum = group["momentum"] 73 | dampening = group["dampening"] 74 | nesterov = group["nesterov"] 75 | 76 | for p in group["params"]: 77 | if p.grad is None: 78 | continue 79 | grad = p.grad.data 80 | state = self.state[p] 81 | 82 | # State initialization 83 | if len(state) == 0: 84 | state["momentum"] = torch.zeros_like(p.data) 85 | 86 | # SGD 87 | buf = state["momentum"] 88 | buf.mul_(momentum).add_(grad, alpha=1 - dampening) 89 | if nesterov: 90 | d_p = grad + momentum * buf 91 | else: 92 | d_p = buf 93 | 94 | # Projection 95 | wd_ratio = 1 96 | if len(p.shape) > 1: 97 | d_p, wd_ratio = self._projection( 98 | p, 99 | grad, 100 | d_p, 101 | group["delta"], 102 | group["wd_ratio"], 103 | group["eps"], 104 | ) 105 | 106 | # Weight decay 107 | if group["weight_decay"] > 0: 108 | p.data.mul_(1 - group["lr"] * group["weight_decay"] * wd_ratio / (1 - momentum)) 109 | 110 | # Step 111 | p.data.add_(d_p, alpha=-group["lr"]) 112 | 113 | return loss 114 | -------------------------------------------------------------------------------- /lib/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/lib/utils/__init__.py -------------------------------------------------------------------------------- /lib/utils/config_utils.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | from mmcv import Config 3 | 4 | 5 | def try_get_key(cfg, *keys, default=None): 6 | """# modified from detectron2 to also support mmcv Config. 7 | 8 | Try select keys from cfg until the first key that exists. Otherwise 9 | return default. 
10 | """ 11 | from detectron2.config import CfgNode 12 | 13 | if isinstance(cfg, CfgNode): 14 | cfg = OmegaConf.create(cfg.dump()) 15 | elif isinstance(cfg, Config): # mmcv Config 16 | cfg = OmegaConf.create(cfg._cfg_dict.to_dict()) 17 | elif isinstance(cfg, dict): # raw dict 18 | cfg = OmegaConf.create(cfg) 19 | 20 | for k in keys: 21 | none = object() 22 | p = OmegaConf.select(cfg, k, default=none) 23 | if p is not none: 24 | return p 25 | return default 26 | -------------------------------------------------------------------------------- /lib/utils/fs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: fs.py from tensorpack 3 | import os 4 | from six.moves import urllib 5 | import errno 6 | import tqdm 7 | from . import logger 8 | from .utils import execute_only_once 9 | 10 | __all__ = ["mkdir_p", "download", "recursive_walk", "get_dataset_path"] 11 | 12 | 13 | def mkdir_p(dirname): 14 | """Like "mkdir -p", make a dir recursively, but do nothing if the dir 15 | exists. 16 | 17 | Args: 18 | dirname(str): 19 | """ 20 | assert dirname is not None 21 | if dirname == "" or os.path.isdir(dirname): 22 | return 23 | try: 24 | os.makedirs(dirname) 25 | except OSError as e: 26 | if e.errno != errno.EEXIST: 27 | raise e 28 | 29 | 30 | def download(url, dir, filename=None, expect_size=None): 31 | """Download URL to a directory. 32 | 33 | Will figure out the filename automatically from URL, if not given. 34 | """ 35 | mkdir_p(dir) 36 | if filename is None: 37 | filename = url.split("/")[-1] 38 | fpath = os.path.join(dir, filename) 39 | 40 | if os.path.isfile(fpath): 41 | if expect_size is not None and os.stat(fpath).st_size == expect_size: 42 | logger.info("File {} exists! Skip download.".format(filename)) 43 | return fpath 44 | else: 45 | logger.warning("File {} exists. Will overwrite with a new download!".format(filename)) 46 | 47 | def hook(t): 48 | last_b = [0] 49 | 50 | def inner(b, bsize, tsize=None): 51 | if tsize is not None: 52 | t.total = tsize 53 | t.update((b - last_b[0]) * bsize) 54 | last_b[0] = b 55 | 56 | return inner 57 | 58 | try: 59 | with tqdm.tqdm(unit="B", unit_scale=True, miniters=1, desc=filename) as t: 60 | fpath, _ = urllib.request.urlretrieve(url, fpath, reporthook=hook(t)) 61 | statinfo = os.stat(fpath) 62 | size = statinfo.st_size 63 | except IOError: 64 | logger.error("Failed to download {}".format(url)) 65 | raise 66 | assert size > 0, "Downloaded an empty file from {}!".format(url) 67 | 68 | if expect_size is not None and size != expect_size: 69 | logger.error("File downloaded from {} does not match the expected size!".format(url)) 70 | logger.error("You may have downloaded a broken file, or the upstream may have modified the file.") 71 | 72 | # TODO human-readable size 73 | logger.info("Succesfully downloaded " + filename + ". " + str(size) + " bytes.") 74 | return fpath 75 | 76 | 77 | def recursive_walk(rootdir): 78 | """ 79 | Yields: 80 | str: All files in rootdir, recursively. 81 | """ 82 | for r, dirs, files in os.walk(rootdir): 83 | for f in files: 84 | yield os.path.join(r, f) 85 | 86 | 87 | def get_dataset_path(*args): 88 | """Get the path to some dataset under ``$TENSORPACK_DATASET``. 89 | 90 | Args: 91 | args: strings to be joined to form path. 92 | 93 | Returns: 94 | str: path to the dataset. 
95 | """ 96 | d = os.environ.get("TENSORPACK_DATASET", None) 97 | if d is None: 98 | d = os.path.join(os.path.expanduser("~"), "tensorpack_data") 99 | if execute_only_once(): 100 | logger.warning("Env var $TENSORPACK_DATASET not set, using {} for datasets.".format(d)) 101 | if not os.path.isdir(d): 102 | mkdir_p(d) 103 | logger.info("Created the directory {}.".format(d)) 104 | assert os.path.isdir(d), d 105 | return os.path.join(d, *args) 106 | 107 | 108 | if __name__ == "__main__": 109 | download("http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz", ".") 110 | -------------------------------------------------------------------------------- /lib/utils/time_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import timeit 3 | import time 4 | from datetime import datetime, timedelta 5 | from . import logger 6 | 7 | 8 | def average_time_of_func(func, num_iter=10000, ms=False): 9 | """ 10 | ms: if True, xxx ms/iter 11 | """ 12 | duration = timeit.timeit(func, number=num_iter) 13 | avg_time = duration / num_iter 14 | if ms: 15 | avg_time *= 1000 16 | logger.info("{} {} ms/iter".format(func.__name__, avg_time)) 17 | else: 18 | logger.info("{} {} s/iter".format(func.__name__, avg_time)) 19 | return avg_time 20 | 21 | 22 | def my_timeit(func, number=100000): 23 | tic = time.perf_counter() 24 | for i in range(number): 25 | func() 26 | return time.perf_counter() - tic 27 | 28 | 29 | def get_time_str(fmt="%Y%m%d_%H%M%S", hours_offset=8): 30 | # get UTC+8 time by default 31 | # set hours_offset to 0 to get UTC time 32 | # use utc time to avoid the problem of mis-configured timezone on some machines 33 | return (datetime.utcnow() + timedelta(hours=hours_offset)).strftime(fmt) 34 | 35 | 36 | # def get_time_str(fmt='%Y%m%d_%H%M%S'): 37 | # # from mmcv.runner import get_time_str 38 | # return time.strftime(fmt, time.localtime()) # defined in mmcv 39 | 40 | 41 | def get_time_delta(sec): 42 | """Humanize timedelta given in seconds, modified from maskrcnn- 43 | benchmark.""" 44 | if sec < 0: 45 | logger.warning("get_time_delta() obtains negative seconds!") 46 | return "{:.3g} seconds".format(sec) 47 | delta_time_str = str(timedelta(seconds=sec)) 48 | return delta_time_str 49 | 50 | 51 | class Timer(object): 52 | # modified from maskrcnn-benchmark 53 | def __init__(self): 54 | self.reset() 55 | 56 | @property 57 | def average_time(self): 58 | return self.total_time / self.calls if self.calls > 0 else 0.0 59 | 60 | def tic(self): 61 | # using time.time instead of time.clock because time time.clock 62 | # does not normalize for multithreading 63 | self.start_time = time.perf_counter() 64 | 65 | def toc(self, average=True): 66 | self.add(time.perf_counter() - self.start_time) 67 | if average: 68 | return self.average_time 69 | else: 70 | return self.diff 71 | 72 | def add(self, time_diff): 73 | self.diff = time_diff 74 | self.total_time += self.diff 75 | self.calls += 1 76 | 77 | def reset(self): 78 | self.total_time = 0.0 79 | self.calls = 0 80 | self.start_time = 0.0 81 | self.diff = 0.0 82 | 83 | def avg_time_str(self): 84 | time_str = get_time_delta(self.average_time) 85 | return time_str 86 | 87 | 88 | def humanize_time_delta(sec): 89 | """Humanize timedelta given in seconds 90 | Args: 91 | sec (float): time difference in seconds. Must be positive. 92 | Returns: 93 | str - time difference as a readable string 94 | Example: 95 | .. 
code-block:: python 96 | print(humanize_time_delta(1)) # 1 second 97 | print(humanize_time_delta(60 + 1)) # 1 minute 1 second 98 | print(humanize_time_delta(87.6)) # 1 minute 27 seconds 99 | print(humanize_time_delta(0.01)) # 0.01 seconds 100 | print(humanize_time_delta(60 * 60 + 1)) # 1 hour 1 second 101 | print(humanize_time_delta(60 * 60 * 24 + 1)) # 1 day 1 second 102 | print(humanize_time_delta(60 * 60 * 24 + 60 * 2 + 60*60*9 + 3)) # 1 day 9 hours 2 minutes 3 seconds 103 | """ 104 | if sec < 0: 105 | logger.warning("humanize_time_delta() obtains negative seconds!") 106 | return "{:.3g} seconds".format(sec) 107 | if sec == 0: 108 | return "0 second" 109 | _time = datetime(2000, 1, 1) + timedelta(seconds=int(sec)) 110 | units = ["day", "hour", "minute", "second"] 111 | vals = [int(sec // 86400), _time.hour, _time.minute, _time.second] 112 | if sec < 60: 113 | vals[-1] = sec 114 | 115 | def _format(v, u): 116 | return "{:.3g} {}{}".format(v, u, "s" if v > 1 else "") 117 | 118 | ans = [] 119 | for v, u in zip(vals, units): 120 | if v > 0: 121 | ans.append(_format(v, u)) 122 | return " ".join(ans) 123 | -------------------------------------------------------------------------------- /lib/vis_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/lib/vis_utils/__init__.py -------------------------------------------------------------------------------- /lib/vis_utils/colormap.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | """An awesome colormap for really neat visualizations. 3 | 4 | Copied from Detectron, and removed gray colors. 
5 | """ 6 | 7 | import numpy as np 8 | 9 | __all__ = ["colormap", "random_color"] 10 | 11 | # yapf: disable 12 | # RGB: 13 | _COLORS = np.array( 14 | [ 15 | 0.000, 0.447, 0.741, 16 | 0.850, 0.325, 0.098, 17 | 0.929, 0.694, 0.125, 18 | 0.494, 0.184, 0.556, 19 | 0.466, 0.674, 0.188, 20 | 0.301, 0.745, 0.933, 21 | 0.635, 0.078, 0.184, 22 | 0.300, 0.300, 0.300, 23 | 0.600, 0.600, 0.600, 24 | 1.000, 0.000, 0.000, 25 | 1.000, 0.500, 0.000, 26 | 0.749, 0.749, 0.000, 27 | 0.000, 1.000, 0.000, 28 | 0.000, 0.000, 1.000, 29 | 0.667, 0.000, 1.000, 30 | 0.333, 0.333, 0.000, 31 | 0.333, 0.667, 0.000, 32 | 0.333, 1.000, 0.000, 33 | 0.667, 0.333, 0.000, 34 | 0.667, 0.667, 0.000, 35 | 0.667, 1.000, 0.000, 36 | 1.000, 0.333, 0.000, 37 | 1.000, 0.667, 0.000, 38 | 1.000, 1.000, 0.000, 39 | 0.000, 0.333, 0.500, 40 | 0.000, 0.667, 0.500, 41 | 0.000, 1.000, 0.500, 42 | 0.333, 0.000, 0.500, 43 | 0.333, 0.333, 0.500, 44 | 0.333, 0.667, 0.500, 45 | 0.333, 1.000, 0.500, 46 | 0.667, 0.000, 0.500, 47 | 0.667, 0.333, 0.500, 48 | 0.667, 0.667, 0.500, 49 | 0.667, 1.000, 0.500, 50 | 1.000, 0.000, 0.500, 51 | 1.000, 0.333, 0.500, 52 | 1.000, 0.667, 0.500, 53 | 1.000, 1.000, 0.500, 54 | 0.000, 0.333, 1.000, 55 | 0.000, 0.667, 1.000, 56 | 0.000, 1.000, 1.000, 57 | 0.333, 0.000, 1.000, 58 | 0.333, 0.333, 1.000, 59 | 0.333, 0.667, 1.000, 60 | 0.333, 1.000, 1.000, 61 | 0.667, 0.000, 1.000, 62 | 0.667, 0.333, 1.000, 63 | 0.667, 0.667, 1.000, 64 | 0.667, 1.000, 1.000, 65 | 1.000, 0.000, 1.000, 66 | 1.000, 0.333, 1.000, 67 | 1.000, 0.667, 1.000, 68 | 0.333, 0.000, 0.000, 69 | 0.500, 0.000, 0.000, 70 | 0.667, 0.000, 0.000, 71 | 0.833, 0.000, 0.000, 72 | 1.000, 0.000, 0.000, 73 | 0.000, 0.167, 0.000, 74 | 0.000, 0.333, 0.000, 75 | 0.000, 0.500, 0.000, 76 | 0.000, 0.667, 0.000, 77 | 0.000, 0.833, 0.000, 78 | 0.000, 1.000, 0.000, 79 | 0.000, 0.000, 0.167, 80 | 0.000, 0.000, 0.333, 81 | 0.000, 0.000, 0.500, 82 | 0.000, 0.000, 0.667, 83 | 0.000, 0.000, 0.833, 84 | 0.000, 0.000, 1.000, 85 | 0.000, 0.000, 0.000, 86 | 0.143, 0.143, 0.143, 87 | 0.857, 0.857, 0.857, 88 | 1.000, 1.000, 1.000 89 | ] 90 | ).astype(np.float32).reshape(-1, 3) 91 | # yapf: enable 92 | 93 | 94 | def colormap(rgb=False, maximum=255): 95 | """ 96 | Args: 97 | rgb (bool): whether to return RGB colors or BGR colors. 98 | maximum (int): either 255 or 1 99 | 100 | Returns: 101 | ndarray: a float32 array of Nx3 colors, in range [0, 255] or [0, 1] 102 | """ 103 | assert maximum in [255, 1], maximum 104 | c = _COLORS * maximum 105 | if not rgb: 106 | c = c[:, ::-1] 107 | return c 108 | 109 | 110 | def random_color(rgb=False, maximum=255): 111 | """ 112 | Args: 113 | rgb (bool): whether to return RGB colors or BGR colors. 
114 | maximum (int): either 255 or 1 115 | 116 | Returns: 117 | ndarray: a vector of 3 numbers 118 | """ 119 | idx = np.random.randint(0, len(_COLORS)) 120 | ret = _COLORS[idx] * maximum 121 | if not rgb: 122 | ret = ret[::-1] 123 | return ret 124 | 125 | 126 | if __name__ == "__main__": 127 | import cv2 128 | 129 | size = 100 130 | H, W = 10, 10 131 | canvas = np.random.rand(H * size, W * size, 3).astype("float32") 132 | for h in range(H): 133 | for w in range(W): 134 | idx = h * W + w 135 | if idx >= len(_COLORS): 136 | break 137 | canvas[h * size : (h + 1) * size, w * size : (w + 1) * size] = _COLORS[idx] 138 | cv2.imshow("a", canvas) 139 | cv2.waitKey(0) 140 | -------------------------------------------------------------------------------- /lib/vis_utils/optflow.py: -------------------------------------------------------------------------------- 1 | """visulization of optical flow (optflow)""" 2 | from __future__ import division 3 | 4 | import numpy as np 5 | 6 | from mmcv.image import rgb2bgr 7 | from mmcv.video import flowread 8 | from .image import imshow # cv2 imshow 9 | import matplotlib.pyplot as plt 10 | 11 | 12 | def flowshow(flow, win_name="", wait_time=0, vis_tool="matplotlib"): 13 | """Show optical flow. 14 | 15 | Args: 16 | flow (ndarray or str): The optical flow to be displayed. 17 | win_name (str): The window name. 18 | wait_time (int): Value of waitKey param. 19 | """ 20 | flow = flowread(flow) # HxWx2 21 | flow_img = flow2rgb(flow) 22 | if vis_tool == "matplotlib": 23 | fig = plt.figure(frameon=False, figsize=(8, 6), dpi=100) 24 | tmp = fig.add_subplot(1, 1, 1) 25 | tmp.set_title("{}".format(win_name)) 26 | plt.axis("off") 27 | plt.imshow(flow_img) 28 | plt.show() 29 | else: 30 | imshow(rgb2bgr(flow_img), win_name, wait_time) 31 | 32 | 33 | def flow2rgb(flow, color_wheel=None, unknown_thr=1e6): 34 | # NOTE: the same as flowlib.get_flow_show(flow, mode='Y') 35 | """Convert flow map to RGB image. 36 | 37 | Args: 38 | flow (ndarray): Array of optical flow. 39 | color_wheel (ndarray or None): Color wheel used to map flow field to 40 | RGB colorspace. Default color wheel will be used if not specified. 41 | unknown_thr (str): Values above this threshold will be marked as 42 | unknown and thus ignored. 43 | Returns: 44 | ndarray: RGB image that can be visualized. 
45 | """ 46 | assert flow.ndim == 3 and flow.shape[-1] == 2, flow.shape 47 | if color_wheel is None: 48 | color_wheel = make_color_wheel() 49 | assert color_wheel.ndim == 2 and color_wheel.shape[1] == 3 50 | num_bins = color_wheel.shape[0] 51 | 52 | dx = flow[:, :, 0].copy() 53 | dy = flow[:, :, 1].copy() 54 | 55 | ignore_inds = np.isnan(dx) | np.isnan(dy) | (np.abs(dx) > unknown_thr) | (np.abs(dy) > unknown_thr) 56 | dx[ignore_inds] = 0 57 | dy[ignore_inds] = 0 58 | 59 | rad = np.sqrt(dx**2 + dy**2) 60 | if np.any(rad > np.finfo(float).eps): 61 | max_rad = np.max(rad) 62 | dx /= max_rad 63 | dy /= max_rad 64 | 65 | [h, w] = dx.shape 66 | 67 | rad = np.sqrt(dx**2 + dy**2) 68 | angle = np.arctan2(-dy, -dx) / np.pi 69 | 70 | bin_real = (angle + 1) / 2 * (num_bins - 1) 71 | bin_left = np.floor(bin_real).astype(int) 72 | bin_right = (bin_left + 1) % num_bins 73 | w = (bin_real - bin_left.astype(np.float32))[..., None] 74 | flow_img = (1 - w) * color_wheel[bin_left, :] + w * color_wheel[bin_right, :] 75 | small_ind = rad <= 1 76 | flow_img[small_ind] = 1 - rad[small_ind, None] * (1 - flow_img[small_ind]) 77 | flow_img[np.logical_not(small_ind)] *= 0.75 78 | 79 | flow_img[ignore_inds, :] = 0 80 | 81 | return flow_img 82 | 83 | 84 | def make_color_wheel(bins=None): 85 | """Build a color wheel. 86 | 87 | Args: 88 | bins(list or tuple, optional): Specify the number of bins for each 89 | color range, corresponding to six ranges: red -> yellow, 90 | yellow -> green, green -> cyan, cyan -> blue, blue -> magenta, 91 | magenta -> red. [15, 6, 4, 11, 13, 6] is used for default 92 | (see Middlebury). 93 | Returns: 94 | ndarray: Color wheel of shape (total_bins, 3). 95 | """ 96 | if bins is None: 97 | bins = [15, 6, 4, 11, 13, 6] 98 | assert len(bins) == 6 99 | 100 | RY, YG, GC, CB, BM, MR = tuple(bins) 101 | 102 | ry = [1, np.arange(RY) / RY, 0] 103 | yg = [1 - np.arange(YG) / YG, 1, 0] 104 | gc = [0, 1, np.arange(GC) / GC] 105 | cb = [0, 1 - np.arange(CB) / CB, 1] 106 | bm = [np.arange(BM) / BM, 0, 1] 107 | mr = [1, 0, 1 - np.arange(MR) / MR] 108 | 109 | num_bins = RY + YG + GC + CB + BM + MR 110 | 111 | color_wheel = np.zeros((3, num_bins), dtype=np.float32) 112 | 113 | col = 0 114 | for i, color in enumerate([ry, yg, gc, cb, bm, mr]): 115 | for j in range(3): 116 | color_wheel[j, col : col + bins[i]] = color[j] 117 | col += bins[i] 118 | 119 | return color_wheel.T 120 | -------------------------------------------------------------------------------- /output/catre/NOCS_REAL/aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e/model_final_wo_optim-82cf930e.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/output/catre/NOCS_REAL/aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e/model_final_wo_optim-82cf930e.pth -------------------------------------------------------------------------------- /preprocess/shape_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Originally writen by https://github.com/mentian/object-deformnet/ 3 | """ 4 | 5 | import h5py 6 | import numpy as np 7 | import torch.utils.data as data 8 | 9 | 10 | class ShapeDataset(data.Dataset): 11 | def __init__(self, h5_file, mode, n_points=2048, augment=False): 12 | assert mode == "train" or mode == "val", 'Mode must be "train" or "val".' 
13 | self.mode = mode 14 | self.n_points = n_points 15 | self.augment = augment 16 | # load data from h5py file 17 | with h5py.File(h5_file, "r") as f: 18 | self.length = f[self.mode].attrs["len"] 19 | self.data = f[self.mode]["data"][:] 20 | self.label = f[self.mode]["label"][:] 21 | # augmentation parameters 22 | self.sigma = 0.01 23 | self.clip = 0.02 24 | self.shift_range = 0.02 25 | 26 | def __len__(self): 27 | return self.length 28 | 29 | def __getitem__(self, index): 30 | xyz = self.data[index] 31 | label = self.label[index] - 1 # data saved indexed from 1 32 | # randomly downsample 33 | np_data = xyz.shape[0] 34 | assert np_data >= self.n_points, "Not enough points in shape." 35 | idx = np.random.choice(np_data, self.n_points) 36 | xyz = xyz[idx, :] 37 | # data augmentation 38 | if self.augment: 39 | jitter = np.clip(self.sigma * np.random.randn(self.n_points, 3), -self.clip, self.clip) 40 | xyz[:, :3] += jitter 41 | shift = np.random.uniform(-self.shift_range, self.shift_range, (1, 3)) 42 | xyz[:, :3] += shift 43 | return xyz, label 44 | -------------------------------------------------------------------------------- /ref/__init__.py: -------------------------------------------------------------------------------- 1 | from . import nocs 2 | -------------------------------------------------------------------------------- /ref/cmra.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """This file includes necessary params, info.""" 3 | import os 4 | import mmcv 5 | import os.path as osp 6 | 7 | import numpy as np 8 | 9 | # ---------------------------------------------------------------- # 10 | # ROOT PATH INFO 11 | # ---------------------------------------------------------------- # 12 | cur_dir = osp.abspath(osp.dirname(__file__)) 13 | root_dir = osp.normpath(osp.join(cur_dir, "..")) 14 | # directory storing experiment data (result, model checkpoints, etc). 
15 | output_dir = osp.join(root_dir, "output") 16 | 17 | data_root = osp.join(root_dir, "datasets") 18 | 19 | # ---------------------------------------------------------------- # 20 | # NOCS DATASET 21 | # ---------------------------------------------------------------- # 22 | dataset_root = osp.join(data_root, "NOCS/") 23 | train_dir = osp.join(dataset_root, "CAMERA/") 24 | 25 | model_dir = osp.join(dataset_root, "obj_models") 26 | mean_model_path = osp.join(model_dir, "cr_normed_mean_model_points_spd.pkl") 27 | train_model_path = osp.join(model_dir, "camera_train.pkl") 28 | test_model_path = osp.join(model_dir, "camera_val.pkl") 29 | 30 | # object info 31 | objects = ["bottle", "bowl", "camera", "can", "laptop", "mug"] 32 | 33 | obj2id = {"bottle": 1, "bowl": 2, "camera": 3, "can": 4, "laptop": 5, "mug": 6} 34 | 35 | obj_num = len(objects) 36 | 37 | id2obj = {_id: _name for _name, _id in obj2id.items()} 38 | 39 | # id2obj_camera = {1: "02876657", 2: "02880940", 3: "02942699", 4: "02946921", 5: "03642806", 6: "03797390"} 40 | # objects_camera = list(id2obj_camera.values()) 41 | # obj2id_camera = {_name: _id for _id, _name in id2obj_camera.items()} 42 | 43 | # Camera info 44 | width = 640 45 | height = 480 46 | center = (height / 2, width / 2) 47 | 48 | intrinsics = np.array([[577.5, 0, 319.5], [0, 577.5, 239.5], [0, 0, 1]], dtype=np.float32) # 3x3 camera intrinsic matrix K (fx, fy, cx, cy) 49 | 50 | mean_scale = { 51 | "bottle": 0.001 * np.array([81, 218.5, 80.25], dtype=np.float32), 52 | "bowl": 0.001 * np.array([168.75, 67.75, 168.75], dtype=np.float32), 53 | "camera": 0.001 * np.array([116.0, 121.75, 175.5], dtype=np.float32), 54 | "can": 0.001 * np.array([112.5, 188.25, 115.0], dtype=np.float32), 55 | "laptop": 0.001 * np.array([145.25, 111.25, 168.0], dtype=np.float32), 56 | "mug": 0.001 * np.array([167.5, 135.0, 124.25], dtype=np.float32), 57 | } 58 | 59 | 60 | def get_mean_bbox3d(): 61 | mean_bboxes = {} 62 | for key, value in mean_scale.items(): # per-category mean extents (x, y, z) in meters 63 | minx, maxx = -value[0] / 2, value[0] / 2 64 | miny, maxy = -value[1] / 2, value[1] / 2 65 | minz, maxz = -value[2] / 2, value[2] / 2 66 | 67 | mean_bboxes[key] = np.array( 68 | [ 69 | [maxx, maxy, maxz], 70 | [minx, maxy, maxz], 71 | [minx, miny, maxz], 72 | [maxx, miny, maxz], 73 | [maxx, maxy, minz], 74 | [minx, maxy, minz], 75 | [minx, miny, minz], 76 | [maxx, miny, minz], 77 | ], 78 | dtype=np.float32, 79 | ) 80 | return mean_bboxes 81 | 82 | 83 | def get_sym_info(obj_name, mug_handle=1): 84 | # Y axis points upwards, x axis passes through the handle, z axis otherwise 85 | # return sym axis 86 | if obj_name == "bottle": 87 | sym = np.array([0, 1, 0], dtype=int) 88 | elif obj_name == "bowl": 89 | sym = np.array([0, 1, 0], dtype=int) 90 | elif obj_name == "camera": 91 | sym = None 92 | elif obj_name == "can": 93 | sym = np.array([0, 1, 0], dtype=int) 94 | elif obj_name == "laptop": 95 | sym = None 96 | elif obj_name == "mug": 97 | if mug_handle == 1: 98 | sym = None 99 | else: 100 | sym = np.array([0, 1, 0], dtype=int) 101 | else: 102 | raise NotImplementedError(f"No such object class {obj_name}") 103 | return sym 104 | 105 | 106 | def get_fps_points(): 107 | """key is inst_name generated by 108 | core/catre/tools/nocs/nocs_fps_sample.py.""" 109 | fps_points_path = osp.join(model_dir, "cmra_fps_points_spd.pkl") 110 | assert osp.exists(fps_points_path), fps_points_path 111 | fps_dict = mmcv.load(fps_points_path) 112 | return fps_dict 113 | -------------------------------------------------------------------------------- /requirements/requirements.txt:
-------------------------------------------------------------------------------- 1 | cython 2 | plyfile 3 | 4 | ujson # make json loading in cocoapi_nvidia faster 5 | pycocotools 6 | 7 | cffi 8 | ninja 9 | black 10 | docformatter 11 | setproctitle 12 | fastfunc 13 | meshplex 14 | OpenEXR 15 | vispy>=0.6.4 16 | yacs>=0.1.8 17 | tabulate 18 | pytest-runner 19 | pytest 20 | ipdb 21 | tqdm 22 | numba 23 | mmcv-full 24 | imagecorruptions 25 | pyassimp==4.1.3 # 4.1.4 will cause egl_renderer SegmentFault 26 | pypng 27 | imgaug>=0.4.0 28 | albumentations 29 | transforms3d 30 | # pyquaternion 31 | torchvision 32 | open3d 33 | fvcore 34 | tensorboardX 35 | einops 36 | pytorch3d 37 | # timm # pytorch-image-models 38 | # git+ssh://git@github.com/rwightman/pytorch-image-models.git # the latest timm 39 | glfw 40 | imageio 41 | imageio-ffmpeg 42 | PyOpenGL # >=3.1.5 43 | PyOpenGL_accelerate 44 | chardet 45 | h5py 46 | 47 | thop # https://github.com/Lyken17/pytorch-OpCounter 48 | loguru 49 | 50 | pytorch-lightning # tested for 1.6.0.dev0 51 | fairscale 52 | deepspeed 53 | 54 | # verified versions 55 | onnx==1.8.1 56 | onnxruntime==1.8.0 57 | onnx-simplifier==0.3.5 58 | -------------------------------------------------------------------------------- /scripts/install_deps.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # some other dependencies 3 | set -x 4 | install=${1:-"all"} 5 | 6 | if test "$install" = "all"; then 7 | echo "Installing apt dependencies" 8 | sudo apt-get install -y libjpeg-dev zlib1g-dev 9 | sudo apt-get install -y libopenexr-dev 10 | sudo apt-get install -y openexr 11 | sudo apt-get install -y python3-dev 12 | sudo apt-get install -y libglfw3-dev libglfw3 13 | sudo apt-get install -y libglew-dev 14 | sudo apt-get install -y libassimp-dev 15 | sudo apt-get install -y libnuma-dev # for byteps 16 | sudo apt install -y clang 17 | ## for bop cpp renderer 18 | sudo apt install -y curl 19 | sudo apt install -y autoconf 20 | sudo apt-get install -y build-essential libtool 21 | 22 | ## for uncertainty pnp 23 | sudo apt-get install -y libeigen3-dev 24 | sudo apt-get install -y libgoogle-glog-dev 25 | sudo apt-get install -y libsuitesparse-dev 26 | sudo apt-get install -y libatlas-base-dev 27 | 28 | ## for nvdiffrast/egl 29 | sudo apt-get install -y --no-install-recommends \ 30 | cmake curl pkg-config 31 | sudo apt-get install -y --no-install-recommends \ 32 | libgles2 \ 33 | libgl1-mesa-dev \ 34 | libegl1-mesa-dev \ 35 | libgles2-mesa-dev 36 | # (only available for Ubuntu >= 18.04) 37 | sudo apt-get install -y --no-install-recommends \ 38 | libglvnd0 \ 39 | libgl1 \ 40 | libglx0 \ 41 | libegl1 \ 42 | libglvnd-dev 43 | 44 | sudo apt-get install -y libglew-dev 45 | # for GLEW, add this into ~/.bashrc 46 | # export LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH 47 | fi 48 | 49 | pip install -r requirements/requirements.txt 50 | 51 | pip uninstall pillow 52 | CC="cc -mavx2" pip install -U --force-reinstall pillow-simd 53 | --------------------------------------------------------------------------------
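The solver utilities collected under lib/torch_utils/solver are designed to compose: AdamP (or SGDP) provides the base update, Lookahead wraps any base optimizer with slow/fast weights, and maybe_add_gradient_clipping swaps in a step() that clips gradients first when the config enables it. The following is a minimal, illustrative sketch only — it is not part of the repository, and the toy nn.Linear model, the OmegaConf config values, and the hyperparameters are placeholder assumptions rather than how CATRE's training scripts necessarily configure these pieces:

import torch
import torch.nn as nn
from omegaconf import OmegaConf

from lib.torch_utils.solver.adamp import AdamP
from lib.torch_utils.solver.grad_clip_d2 import maybe_add_gradient_clipping
from lib.torch_utils.solver.lookahead import Lookahead

# toy stand-in network (placeholder, for illustration only)
model = nn.Linear(16, 4)

# AdamP base optimizer wrapped with Lookahead slow/fast weights
base_opt = AdamP(model.parameters(), lr=1e-4, weight_decay=1e-2, nesterov=True)
optimizer = Lookahead(base_opt, alpha=0.5, k=6)

# hypothetical config enabling full-model gradient-norm clipping;
# try_get_key() also accepts detectron2-style "SOLVER.CLIP_GRADIENTS" keys
cfg = OmegaConf.create(
    {"train": {"grad_clip": {"enabled": True, "clip_type": "full_model", "clip_value": 10.0, "norm_type": 2.0}}}
)
optimizer = maybe_add_gradient_clipping(cfg, optimizer)  # step() now clips before updating

x, target = torch.randn(8, 16), torch.randn(8, 4)
loss = nn.functional.mse_loss(model(x), target)
loss.backward()
optimizer.step()  # clip grads -> AdamP fast update -> Lookahead slow update every k steps
model.zero_grad()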