├── .gitignore
├── README.md
├── assets
│   └── network.png
├── configs
│   ├── _base_
│   │   ├── catre_base.py
│   │   └── common_base.py
│   └── catre
│       └── NOCS_REAL
│           ├── aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e.py
│           └── aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e_initspd.py
├── core
│   ├── __init__.py
│   ├── base_data_loader.py
│   ├── catre
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── cmra.py
│   │   │   ├── data_loader.py
│   │   │   ├── dataset_factory.py
│   │   │   └── nocs.py
│   │   ├── engine
│   │   │   ├── __init__.py
│   │   │   ├── batch_test.py
│   │   │   ├── batching.py
│   │   │   ├── catre_custom_evaluator.py
│   │   │   ├── catre_evaluator.py
│   │   │   ├── engine.py
│   │   │   ├── engine_utils.py
│   │   │   └── test_utils.py
│   │   ├── losses
│   │   │   ├── l2_loss.py
│   │   │   ├── pm_loss.py
│   │   │   └── rot_loss.py
│   │   ├── main_catre.py
│   │   ├── models
│   │   │   ├── CATRE_disR_shared.py
│   │   │   ├── heads
│   │   │   │   ├── conv_out_per_rot_head.py
│   │   │   │   └── fc_trans_size_head.py
│   │   │   ├── model_utils.py
│   │   │   ├── net_factory.py
│   │   │   ├── pointnets
│   │   │   │   └── pointnet.py
│   │   │   └── pose_scale_from_delta_init.py
│   │   ├── test_catre.sh
│   │   ├── tools
│   │   │   ├── camera25_prepare_spd_init_results.py
│   │   │   └── prepare_spd_init_results.py
│   │   └── train_catre.sh
│   └── utils
│       ├── __init__.py
│       ├── apex_trainer.py
│       ├── augment.py
│       ├── camera_geometry.py
│       ├── cat_data_utils.py
│       ├── data_utils.py
│       ├── dataset_utils.py
│       ├── default_args_setup.py
│       ├── depth_aug.py
│       ├── depth_image_smoothing.py
│       ├── edge_utils.py
│       ├── farthest_points_torch.py
│       ├── lie_algebra.py
│       ├── my_checkpoint.py
│       ├── my_comm.py
│       ├── my_distributed_sampler.py
│       ├── my_setup.py
│       ├── my_visualizer.py
│       ├── my_writer.py
│       ├── pose_aug.py
│       ├── pose_utils.py
│       ├── quaternion_lf.py
│       ├── rot_reps.py
│       ├── solver_utils.py
│       ├── ssd_color_transform.py
│       ├── timm_utils.py
│       ├── utils.py
│       └── zoom_utils.py
├── datasets
│   └── NOCS
│       ├── REAL
│       │   ├── real_test
│       │   └── real_train
│       ├── obj_models
│       │   ├── abs_scale.pkl
│       │   ├── cr_normed_mean_model_points_spd.pkl
│       │   ├── mug_handle.pkl
│       │   ├── mug_meta.pkl
│       │   ├── real_test_spd.pkl
│       │   └── real_train_spd.pkl
│       └── test_init_poses
│           ├── init_pose_dualposenet_nocs_real.json
│           ├── init_pose_nocs_real.json
│           └── init_pose_spd_nocs_real.json
├── docs
│   └── INSTALL.md
├── lib
│   ├── __init__.py
│   ├── pysixd
│   │   ├── RT_transform.py
│   │   ├── __init__.py
│   │   ├── colors.json
│   │   ├── comparative_report.py
│   │   ├── config.py
│   │   ├── dataset_params.py
│   │   ├── dataset_params_sixd.py
│   │   ├── eval_calc_errors.py
│   │   ├── eval_loc.py
│   │   ├── eval_loc_origin.py
│   │   ├── eval_plots.py
│   │   ├── eval_utils.py
│   │   ├── inout.py
│   │   ├── misc.py
│   │   ├── pose_error.py
│   │   ├── pose_error_more.py
│   │   ├── pose_matching.py
│   │   ├── pycoco_utils.py
│   │   ├── renderer.py
│   │   ├── renderer_cpp.py
│   │   ├── renderer_glumpy.py
│   │   ├── renderer_py.py
│   │   ├── renderer_pyrender.py
│   │   ├── renderer_vispy.py
│   │   ├── score.py
│   │   ├── se3.py
│   │   ├── test_set_bb8_sixd.yml
│   │   ├── transform.py
│   │   ├── uv_projection.py
│   │   ├── view_sampler.py
│   │   ├── visibility.py
│   │   └── visualization.py
│   ├── structures
│   │   ├── __init__.py
│   │   ├── centers_2d.py
│   │   ├── keypoints_2d.py
│   │   ├── keypoints_3d.py
│   │   ├── my_list.py
│   │   ├── my_maps.py
│   │   ├── my_masks.py
│   │   ├── poses.py
│   │   ├── quats.py
│   │   ├── rots.py
│   │   └── translations.py
│   ├── torch_utils
│   │   ├── __init__.py
│   │   ├── color
│   │   │   ├── __init__.py
│   │   │   ├── gray.py
│   │   │   ├── hls.py
│   │   │   ├── hsv.py
│   │   │   ├── lab.py
│   │   │   ├── luv.py
│   │   │   ├── rgb.py
│   │   │   ├── xyz.py
│   │   │   ├── ycbcr.py
│   │   │   └── yuv.py
│   │   ├── layers
│   │   │   ├── __init__.py
│   │   │   ├── acon.py
│   │   │   ├── activations_me.py
│   │   │   ├── conv_module.py
│   │   │   ├── coord_attention.py
│   │   │   ├── dropblock
│   │   │   │   ├── README.md
│   │   │   │   ├── __init__.py
│   │   │   │   ├── dropblock.py
│   │   │   │   └── scheduler.py
│   │   │   ├── layer_utils.py
│   │   │   ├── mean_conv_deconv.py
│   │   │   └── std_conv_transpose.py
│   │   ├── misc.py
│   │   ├── solver
│   │   │   ├── AdaBelief.py
│   │   │   ├── __init__.py
│   │   │   ├── adamp.py
│   │   │   ├── badam.py
│   │   │   ├── grad_clip_d2.py
│   │   │   ├── lookahead.py
│   │   │   ├── lr_scheduler.py
│   │   │   ├── madgrad.py
│   │   │   ├── nadamw.py
│   │   │   ├── optimize.py
│   │   │   ├── over9000.py
│   │   │   ├── radam.py
│   │   │   ├── ralamb.py
│   │   │   ├── ranger.py
│   │   │   ├── ranger2020.py
│   │   │   ├── ranger21.py
│   │   │   ├── ranger_adabelief.py
│   │   │   ├── rmsprop_tf.py
│   │   │   ├── sgd_gc.py
│   │   │   └── sgdp.py
│   │   └── torch_utils.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── bbox_utils.py
│   │   ├── config_utils.py
│   │   ├── fs.py
│   │   ├── is_binary_file.py
│   │   ├── logger.py
│   │   ├── mask_utils.py
│   │   ├── setup_logger.py
│   │   ├── setup_logger_loguru.py
│   │   ├── time_utils.py
│   │   └── utils.py
│   └── vis_utils
│       ├── __init__.py
│       ├── cmap_plt2cv.py
│       ├── colormap.py
│       ├── image.py
│       └── optflow.py
├── output
│   └── catre
│       └── NOCS_REAL
│           └── aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e
│               └── model_final_wo_optim-82cf930e.pth
├── preprocess
│   ├── pose_data.py
│   └── shape_dataset.py
├── ref
│   ├── __init__.py
│   ├── cmra.py
│   └── nocs.py
├── requirements
│   └── requirements.txt
└── scripts
    └── install_deps.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | *.so.*
3 | *.tar.gz
4 | *.egg-info*
5 |
6 | *.ttf
7 | *.jpg
8 |
9 |
10 | # compilation and distribution
11 | __pycache__
12 | _ext
13 | *.pyc
14 | *.so
15 | detectron2.egg-info/
16 | build/
17 | dist/
18 | .cache/
19 |
20 | # pytorch/python/numpy formats
21 | #*.pth
22 | #*.pkl
23 | *.npy
24 | *.engine
25 | *.onnx
26 | events.out.tfevents*
27 |
28 | # ipython/jupyter notebooks
29 | *.ipynb
30 | **/.ipynb_checkpoints/
31 |
32 | # Editor temporaries
33 | *.swn
34 | *.swo
35 | *.swp
36 | *~
37 |
38 | # Pycharm editor settings
39 | .idea
40 |
41 | # VSCode editor settings
42 | .vscode
43 |
44 | # project dirs
45 | /models
46 |
47 | #external/*
48 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CATRE
2 |
3 | This repo provides the implementation of the ECCV'22 paper:
4 |
5 | **CATRE: Iterative Point Clouds Alignment for Category-level Object Pose Refinement**
6 | [[arXiv](https://arxiv.org/abs/2207.08082)][[Video](https://www.bilibili.com/video/BV1e8411e7Jt/?share_source=copy_web)]
7 |
8 | ## Overview
9 |
10 | 
11 |
12 |
13 | ## Dependencies
14 |
15 | See [INSTALL.md](./docs/INSTALL.md)
16 |
17 | ## Datasets
18 |
19 | Prepare the datasets folder like this:
20 | ```bash
21 | datasets/
22 | ├── NOCS
23 | │   ├── REAL
24 | │   │   ├── real_test    # download from http://download.cs.stanford.edu/orion/nocs/real_test.zip
25 | │   │   ├── real_train   # download from http://download.cs.stanford.edu/orion/nocs/real_train.zip
26 | │   │   └── image_set    # generated by pose_data.py
27 | │   ├── gts              # download from http://download.cs.stanford.edu/orion/nocs/gts.zip
28 | │   │   └── real_test
29 | │   ├── test_init_poses  # we provide
30 | │   └── obj_models       # we provide some necessary files; the complete files can be downloaded from http://download.cs.stanford.edu/orion/nocs/obj_models.zip
31 | ```
32 |
33 | Run the following python script to prepare the datasets (modified from https://github.com/mentian/object-deformnet):
34 | ```bash
35 | # NOTE: this code will directly modify the data
36 | cd $ROOT/preprocess
37 | python pose_data.py
38 | ```
39 |
40 | ## Reproduce the results
41 |
42 | The trained model has been saved at `output/catre/NOCS_REAL/aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e/model_final_wo_optim-82cf930e.pth`. Run the following command to reproduce the results:
43 |
44 | ```
45 | ./core/catre/test_catre.sh configs/catre/NOCS_REAL/aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e.py 1 output/catre/NOCS_REAL/aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e/model_final_wo_optim-82cf930e.pth
46 | ```
47 |
48 | ## NOTE
49 |
50 | **NOTE** that there is a small bug in the original NOCS evaluation [code](https://github.com/hughw19/NOCS_CVPR2019/blob/78a31c2026a954add1a2711286ff45ce1603b8ab/utils.py#L252) w.r.t. IoU. We fixed this bug in our evaluation code and re-evaluated all the compared methods in the paper (we only revised the IoU values and kept the rotation/translation results the same, although the R/t accuracy also changes slightly). See the revised [code](https://github.com/THU-DA-6D-Pose-Group/CATRE/blob/b649cbad6ed2121b22a37f7fe16ad923688d4995/core/catre/engine/test_utils.py#L158) for details. Thanks also to [Peng et al.](https://github.com/swords123/SSC-6D/blob/bb0dcd5e5b789ea2a80c6c3fa16ccc2bf0a445d1/eval/utils.py#L114) for further confirming this bug.
51 |
52 | ## Training
53 |
54 | ```
55 | ./core/catre/train_catre.sh configs/catre/NOCS_REAL/aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e.py <gpu_ids> (other args)
56 | ```
57 |
58 | ## Testing
59 | ```
60 | ./core/catre/test_catre.sh configs/catre/NOCS_REAL/aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e.py <gpu_ids> <ckpt_path> (other args)
61 | ```
62 |
63 | ## Citation
64 | If you find this repo useful in your research, please consider citing:
65 | ```
66 | @InProceedings{liu_2022_catre,
67 | title = {{CATRE:} Iterative Point Clouds Alignment for Category-level Object Pose Refinement},
68 | author = {Liu, Xingyu and Wang, Gu and Li, Yi and Ji, Xiangyang},
69 | booktitle = {European Conference on Computer Vision (ECCV)},
70 | month = {October},
71 | year = {2022}
72 | }
73 | ```
74 |
--------------------------------------------------------------------------------
/assets/network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/assets/network.png
--------------------------------------------------------------------------------
/configs/catre/NOCS_REAL/aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e.py:
--------------------------------------------------------------------------------
1 | _base_ = ["../../_base_/catre_base.py"]
2 |
3 | # fix mug mean shape
4 | OUTPUT_DIR = "output/catre/NOCS_REAL/aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e"
5 | INPUT = dict(
6 | COLOR_AUG_PROB=0.0,
7 | DEPTH_SAMPLE_BALL_RATIO=0.6,
8 | # NOTE: used points!
9 | BBOX_TYPE_TEST="est", # from_pose | est | gt | gt_aug (TODO)
10 | INIT_POSE_TYPE_TRAIN=["gt_noise"], # gt_noise | random | canonical
11 | NOISE_ROT_STD_TRAIN=(10, 5, 2.5, 1.25), # randomly choose one
12 | NOISE_TRANS_STD_TRAIN=[
13 | (0.02, 0.02, 0.02),
14 | (0.01, 0.01, 0.01),
15 | (0.005, 0.005, 0.005),
16 | ],
17 | NOISE_SCALE_STD_TRAIN=[
18 | (0.01, 0.01, 0.01),
19 | (0.005, 0.005, 0.005),
20 | (0.002, 0.002, 0.002),
21 | ],
22 | INIT_POSE_TYPE_TEST="est", # gt_noise | est | canonical
23 |     KPS_TYPE="mean_shape",  # bbox_from_scale | mean_shape | fps (ablation)
24 | WITH_DEPTH=True,
25 | AUG_DEPTH=True,
26 | WITH_PCL=True,
27 | WITH_IMG=False,
28 | BP_DEPTH=False,
29 | NUM_KPS=1024,
30 | NUM_PCL=1024,
31 | # augmentation when training
32 | BBOX3D_AUG_PROB=0.5,
33 | RT_AUG_PROB=0.5,
34 | # pose focalization
35 | ZERO_CENTER_INPUT=True,
36 | )
37 |
38 | DATALOADER = dict(
39 | NUM_WORKERS=24,
40 | )
41 |
42 | SOLVER = dict(
43 | IMS_PER_BATCH=16,
44 | TOTAL_EPOCHS=120,
45 | LR_SCHEDULER_NAME="flat_and_anneal",
46 | ANNEAL_METHOD="cosine", # "cosine"
47 | ANNEAL_POINT=0.72,
48 | # REL_STEPS=(0.3125, 0.625, 0.9375),
49 | OPTIMIZER_CFG=dict(_delete_=True, type="Ranger", lr=1e-4, weight_decay=0),
50 | WEIGHT_DECAY=0.0,
51 | WARMUP_FACTOR=0.001,
52 | WARMUP_ITERS=1000,
53 | )
54 |
55 | DATASETS = dict(
56 | TRAIN=("nocs_train_real",),
57 | TEST=("nocs_test_real",),
58 | INIT_POSE_FILES_TEST=("datasets/NOCS/test_init_poses/init_pose_spd_nocs_real.json",),
59 | )
60 |
61 | MODEL = dict(
62 | LOAD_POSES_TEST=True,
63 | PIXEL_MEAN=[0.0, 0.0, 0.0],
64 | PIXEL_STD=[255.0, 255.0, 255.0],
65 | REFINE_SCLAE=True,
66 | CATRE=dict(
67 | NAME="CATRE_disR_shared", # used module file name (define different model types)
68 | TASK="refine", # refine | init | init+refine
69 | NUM_CLASSES=6, # only valid for class aware
70 | N_ITER_TRAIN=4,
71 | N_ITER_TRAIN_WARM_EPOCH=4, # linearly increase the refine iter from 1 to N_ITER_TRAIN until this epoch
72 | N_ITER_TEST=4,
73 | PCLNET=dict(
74 | FREEZE=False,
75 | INIT_CFG=dict(
76 | type="point_net",
77 | num_points=1024,
78 | global_feat=False,
79 | feature_transform=True,
80 | out_dim=1024,
81 | ),
82 | ),
83 | ## disentangled pose head for delta R/T/s
84 | ROT_HEAD=dict(
85 | ROT_TYPE="ego_rot6d", # {ego|allo}_rot6d
86 | INIT_CFG=dict(
87 | type="ConvOutPerRotHead",
88 | in_dim=1088,
89 | num_layers=2,
90 | kernel_size=1,
91 | feat_dim=256,
92 | norm="GN", # BN | GN | none
93 | num_gn_groups=32,
94 | act="gelu", # relu | lrelu | silu (swish) | gelu | mish
95 | num_points=1024 + 1024,
96 | rot_dim=3, # ego_rot6d
97 | norm_input=False,
98 | ),
99 | SCLAE_TYPE="iter_add",
100 | ),
101 | TS_HEAD=dict(
102 | WITH_KPS_FEATURE=False,
103 | WITH_INIT_SCALE=True,
104 | INIT_CFG=dict(
105 | type="FC_TransSizeHead",
106 | in_dim=1088 + 3,
107 | num_layers=2,
108 | feat_dim=256,
109 | norm="GN", # BN | GN | none
110 | num_gn_groups=32,
111 | act="gelu", # relu | lrelu | silu (swish) | gelu | mish
112 | norm_input=False,
113 | ),
114 | ),
115 | LOSS_CFG=dict(
116 | # point matching loss ----------------
117 | PM_LOSS_SYM=True, # use symmetric PM loss
118 | PM_NORM_BY_EXTENT=False, # 1. / extent.max(1, keepdim=True)[0]
119 | # if False, the trans loss is in point matching loss
120 | PM_R_ONLY=True, # only do R loss in PM
121 | PM_WITH_SCALE=True,
122 | PM_LW=1.0,
123 | # rot loss --------------
124 | ROT_LOSS_TYPE="angular", # angular | L2
125 | ROT_LW=1.0,
126 | ROT_YAXIS_LOSS_TYPE="L1",
127 | # trans loss -----------
128 | TRANS_LOSS_TYPE="L1",
129 | TRANS_LOSS_DISENTANGLE=True,
130 | TRANS_LW=1.0,
131 | # scale loss ----------------------------------
132 | SCALE_LOSS_TYPE="L1",
133 | SCALE_LW=1.0,
134 | ),
135 | ),
136 | )
137 |
--------------------------------------------------------------------------------
/configs/catre/NOCS_REAL/aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e_initspd.py:
--------------------------------------------------------------------------------
1 | _base_ = ["../../_base_/catre_base.py"]
2 |
3 | # fix mug mean shape
4 | OUTPUT_DIR = "output/catre/NOCS_REAL/aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e_initspd"
5 | INPUT = dict(
6 | COLOR_AUG_PROB=0.0,
7 | DEPTH_SAMPLE_BALL_RATIO=0.6,
8 | # NOTE: used points!
9 | BBOX_TYPE_TEST="est", # from_pose | est | gt | gt_aug (TODO)
10 | INIT_POSE_TYPE_TRAIN=["gt_noise"], # gt_noise | random | canonical
11 | NOISE_ROT_STD_TRAIN=(10, 5, 2.5, 1.25), # randomly choose one
12 | NOISE_TRANS_STD_TRAIN=[
13 | (0.02, 0.02, 0.02),
14 | (0.01, 0.01, 0.01),
15 | (0.005, 0.005, 0.005),
16 | ],
17 | NOISE_SCALE_STD_TRAIN=[
18 | (0.01, 0.01, 0.01),
19 | (0.005, 0.005, 0.005),
20 | (0.002, 0.002, 0.002),
21 | ],
22 | INIT_POSE_TYPE_TEST="est", # gt_noise | est | canonical
23 |     KPS_TYPE="mean_shape",  # bbox_from_scale | mean_shape | fps (ablation)
24 | WITH_DEPTH=True,
25 | AUG_DEPTH=True,
26 | WITH_PCL=True,
27 | WITH_IMG=False,
28 | BP_DEPTH=False,
29 | NUM_KPS=1024,
30 | NUM_PCL=1024,
31 | # augmentation when training
32 | BBOX3D_AUG_PROB=0.5,
33 | RT_AUG_PROB=0.5,
34 | ZERO_CENTER_INPUT=True,
35 | )
36 |
37 | DATALOADER = dict(
38 | NUM_WORKERS=24,
39 | )
40 |
41 | SOLVER = dict(
42 | IMS_PER_BATCH=32,
43 | TOTAL_EPOCHS=120,
44 | LR_SCHEDULER_NAME="flat_and_anneal",
45 | ANNEAL_METHOD="cosine", # "cosine"
46 | ANNEAL_POINT=0.72,
47 | # REL_STEPS=(0.3125, 0.625, 0.9375),
48 | OPTIMIZER_CFG=dict(_delete_=True, type="Ranger", lr=1e-4, weight_decay=0),
49 | WEIGHT_DECAY=0.0,
50 | WARMUP_FACTOR=0.001,
51 | WARMUP_ITERS=1000,
52 | )
53 |
54 | DATASETS = dict(
55 | TRAIN=("nocs_train_real",),
56 | TEST=("nocs_test_real",),
57 | INIT_POSE_FILES_TEST=("datasets/NOCS/test_init_poses/init_pose_dualposenet_nocs_real.json",),
58 | )
59 |
60 | MODEL = dict(
61 | LOAD_POSES_TEST=True,
62 | PIXEL_MEAN=[0.0, 0.0, 0.0],
63 | PIXEL_STD=[255.0, 255.0, 255.0],
64 | REFINE_SCLAE=True,
65 | CATRE=dict(
66 | NAME="CATRE_disR_shared", # used module file name (define different model types)
67 | TASK="refine", # refine | init | init+refine
68 | NUM_CLASSES=6, # only valid for class aware
69 | N_ITER_TRAIN=4,
70 | N_ITER_TRAIN_WARM_EPOCH=4, # linearly increase the refine iter from 1 to N_ITER_TRAIN until this epoch
71 | N_ITER_TEST=4,
72 | PCLNET=dict(
73 | FREEZE=False,
74 | INIT_CFG=dict(
75 | type="point_net",
76 | num_points=1024,
77 | global_feat=False,
78 | feature_transform=True,
79 | out_dim=1024,
80 | ),
81 | ),
82 | ## pose head for delta R/T/s
83 | ROT_HEAD=dict(
84 | ROT_TYPE="ego_rot6d", # {ego|allo}_rot6d
85 | INIT_CFG=dict(
86 | type="ConvOutPerRotHead",
87 | in_dim=1088,
88 | num_layers=2,
89 | kernel_size=1,
90 | feat_dim=256,
91 | norm="GN", # BN | GN | none
92 | num_gn_groups=32,
93 | act="gelu", # relu | lrelu | silu (swish) | gelu | mish
94 | num_points=1024 + 1024,
95 | rot_dim=3, # ego_rot6d
96 | norm_input=False,
97 | ),
98 | SCLAE_TYPE="iter_add",
99 | ),
100 | TS_HEAD=dict(
101 | WITH_KPS_FEATURE=False,
102 | WITH_INIT_SCALE=True,
103 | INIT_CFG=dict(
104 | type="FC_TransSizeHead",
105 | in_dim=1088 + 3,
106 | num_layers=2,
107 | feat_dim=256,
108 | norm="GN", # BN | GN | none
109 | num_gn_groups=32,
110 | act="gelu", # relu | lrelu | silu (swish) | gelu | mish
111 | norm_input=False,
112 | ),
113 | ),
114 | LOSS_CFG=dict(
115 | # point matching loss ----------------
116 | PM_LOSS_SYM=True, # use symmetric PM loss
117 | PM_NORM_BY_EXTENT=False, # 1. / extent.max(1, keepdim=True)[0]
118 | # if False, the trans loss is in point matching loss
119 | PM_R_ONLY=True, # only do R loss in PM
120 | PM_WITH_SCALE=True,
121 | PM_LW=1.0,
122 | # rot loss --------------
123 | ROT_LOSS_TYPE="angular", # angular | L2
124 | ROT_LW=1.0,
125 | ROT_YAXIS_LOSS_TYPE="L1",
126 | # trans loss -----------
127 | TRANS_LOSS_TYPE="L1",
128 | TRANS_LOSS_DISENTANGLE=True,
129 | TRANS_LW=1.0,
130 | # scale loss ----------------------------------
131 | SCALE_LOSS_TYPE="L1",
132 | SCALE_LW=1.0,
133 | ),
134 | ),
135 | )
136 |
--------------------------------------------------------------------------------
/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/core/__init__.py
--------------------------------------------------------------------------------
/core/catre/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/core/catre/datasets/__init__.py
--------------------------------------------------------------------------------
/core/catre/datasets/dataset_factory.py:
--------------------------------------------------------------------------------
1 | """Register datasets in this file will be imported in project root to register
2 | the datasets."""
3 | import logging
4 | import os
5 | import os.path as osp
6 | import mmcv
7 | import detectron2.utils.comm as comm
8 | import ref
9 | from detectron2.data import DatasetCatalog, MetadataCatalog
10 | from core.catre.datasets import (
11 | nocs,
12 | cmra,
13 | )
14 |
15 | cur_dir = osp.dirname(osp.abspath(__file__))
16 | # from lib.utils.utils import iprint
17 | __all__ = ["register_dataset", "register_datasets", "register_datasets_in_cfg", "get_available_datasets"]
18 | _DSET_MOD_NAMES = [
19 | "nocs",
20 | "cmra",
21 | ]
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 |
26 | def register_dataset(mod_name, dset_name, data_cfg=None):
27 | """
28 | mod_name: a module under core.datasets or other dataset source file imported here
29 | dset_name: dataset name
30 | data_cfg: dataset config
31 | """
32 | register_func = eval(mod_name)
33 | register_func.register_with_name_cfg(dset_name, data_cfg)
34 |
35 |
36 | def get_available_datasets(mod_name):
37 | return eval(mod_name).get_available_datasets()
38 |
39 |
40 | def register_datasets_in_cfg(cfg):
41 | for split in [
42 | "TRAIN",
43 | "TEST",
44 | "VAL",
45 | "TRAIN2",
46 | ]:
47 | for name in cfg.DATASETS.get(split, []):
48 | if name in DatasetCatalog.list():
49 | continue
50 | registered = False
51 | # try to find in pre-defined datasets
52 |             # NOTE: it is better to have all datasets pre-defined
53 | for _mod_name in _DSET_MOD_NAMES:
54 | if name in get_available_datasets(_mod_name):
55 | register_dataset(_mod_name, name, data_cfg=None)
56 | registered = True
57 | break
58 |             # not pre-defined; not recommended
59 | if not registered:
60 | # try to get mod_name and data_cfg from cfg
61 | """load data_cfg and mod_name from file
62 | cfg.DATA_CFG[name] = 'path_to_cfg'
63 | """
64 | assert "DATA_CFG" in cfg and name in cfg.DATA_CFG, "no cfg.DATA_CFG.{}".format(name)
65 | assert osp.exists(cfg.DATA_CFG[name])
66 | data_cfg = mmcv.load(cfg.DATA_CFG[name])
67 | mod_name = data_cfg.pop("mod_name", None)
68 | assert mod_name in _DSET_MOD_NAMES, mod_name
69 | register_dataset(mod_name, name, data_cfg)
70 |
71 |
72 | def register_datasets(dataset_names):
73 | for name in dataset_names:
74 | if name in DatasetCatalog.list():
75 | continue
76 | registered = False
77 | # try to find in pre-defined datasets
78 |         # NOTE: it is better to have all datasets pre-defined
79 | for _mod_name in _DSET_MOD_NAMES:
80 | if name in get_available_datasets(_mod_name):
81 | register_dataset(_mod_name, name, data_cfg=None)
82 | registered = True
83 | break
84 |
85 |         # not pre-defined; not recommended
86 | if not registered:
87 | raise ValueError(f"dataset {name} is not defined")
88 |
--------------------------------------------------------------------------------
/core/catre/engine/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/core/catre/engine/__init__.py
--------------------------------------------------------------------------------
/core/catre/engine/batch_test.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import random
3 | import torch
4 |
5 | from .engine_utils import get_normed_kps
6 | from lib.vis_utils.image import heatmap, grid_show
7 | from lib.pysixd.misc import transform_normed_pts_batch
8 |
9 |
10 | def batch_data_test(cfg, data, device="cuda", dtype=torch.float32):
11 | # batch test data and flatten
12 | tensor_kwargs = {"dtype": dtype, "device": device}
13 | to_float_args = {"dtype": dtype, "device": device, "non_blocking": True}
14 | to_long_args = {"dtype": torch.long, "device": device, "non_blocking": True}
15 |
16 | batch = {}
17 | num_imgs = len(data)
18 | # construct flattened instance data =============================
19 | batch["obj_cls"] = torch.cat([d["instances"].obj_classes for d in data], dim=0).to(**to_long_args)
20 | batch["obj_bbox"] = torch.cat([d["instances"].obj_boxes.tensor for d in data], dim=0).to(**to_float_args)
21 | # NOTE: initial pose or the output pose estimate
22 | batch["obj_pose_est"] = torch.cat([d["instances"].obj_poses.tensor for d in data], dim=0).to(**to_float_args)
23 | batch["obj_scale_est"] = torch.cat([d["instances"].obj_scales for d in data], dim=0).to(**to_float_args)
24 |
25 | batch["obj_mean_points"] = torch.cat([d["instances"].obj_mean_points for d in data], dim=0).to(**to_float_args)
26 | batch["obj_mean_scales"] = torch.cat([d["instances"].obj_mean_scales for d in data], dim=0).to(**to_float_args)
27 |
28 | if cfg.INPUT.KPS_TYPE.lower() == "fps":
29 | # NOTE: only an ablation setting!
30 | batch["obj_fps_points"] = torch.cat([d["instances"].obj_fps_points for d in data], dim=0).to(**to_float_args)
31 |
32 | num_insts_per_im = [len(d["instances"]) for d in data]
33 | n_obj = len(batch["obj_cls"])
34 | K_list = []
35 | sym_infos_list = []
36 | im_ids = []
37 | inst_ids = []
38 | for i_im in range(num_imgs):
39 | sym_infos_list.extend(data[i_im]["instances"].obj_sym_infos)
40 | for i_inst in range(num_insts_per_im[i_im]):
41 | im_ids.append(i_im)
42 | inst_ids.append(i_inst)
43 | K_list.append(data[i_im]["cam"].clone())
44 |
45 | batch["im_id"] = torch.tensor(im_ids, **tensor_kwargs)
46 | batch["inst_id"] = torch.tensor(inst_ids, **tensor_kwargs)
47 | batch["K"] = torch.stack(K_list, dim=0).to(**to_float_args)
48 | batch["sym_info"] = sym_infos_list
49 |
50 | input_cfg = cfg.INPUT
51 |
52 | batch["pcl"] = torch.cat([d["instances"].pcl for d in data], dim=0).to(**to_float_args)
53 |
54 | if input_cfg.WITH_IMG:
55 | batch["img"] = torch.stack([d["image"] for d in data]).to(**to_float_args)
56 |
57 | if input_cfg.WITH_DEPTH:
58 | batch["depth_obs"] = torch.stack([d["depth"] for d in data], dim=0).to(**to_float_args)
59 |
60 | return batch
61 |
62 |
63 | def batch_updater_test(cfg, batch, poses_est=None, scales_est=None, device="cuda", dtype=torch.float32):
64 | """
65 | iter=0: poses_est=None, obj_pose_est is from data loader
66 | if REFINE_SCLAE is False, keep init_scale unchanged from iter 0 ~ max_num
67 | """
68 | tensor_kwargs = {"dtype": dtype, "device": device}
69 | to_float_args = {"dtype": dtype, "device": device, "non_blocking": True}
70 |
71 | n_obj = batch["obj_cls"].shape[0]
72 | if poses_est is not None:
73 | batch["obj_pose_est"] = poses_est
74 |
75 | if scales_est is not None and cfg.MODEL.REFINE_SCLAE:
76 | batch["obj_scale_est"] = scales_est
77 |
78 | if "obj_kps" not in batch:
79 | get_normed_kps(cfg, batch, **to_float_args)
80 |
81 | r_est = batch["obj_pose_est"][:, :3, :3]
82 | t_est = batch["obj_pose_est"][:, :3, 3:4]
83 | s_est = batch["obj_scale_est"]
84 |
85 | tfd_kps = transform_normed_pts_batch(
86 | batch["obj_kps"],
87 | r_est,
88 | t=None if cfg.INPUT.ZERO_CENTER_INPUT else t_est,
89 | scale=s_est,
90 | )
91 |
92 | batch["tfd_kps"] = tfd_kps.permute(0, 2, 1) # [bs, 3, num_k]
93 |
94 | if cfg.INPUT.ZERO_CENTER_INPUT:
95 | batch["x"] = batch["pcl"].permute(0, 2, 1) - t_est.view(n_obj, 3, 1) # [bs, 3, num_k] - [bs, 3, 1]
96 | else:
97 | batch["x"] = batch["pcl"].permute(0, 2, 1)
98 |
99 | # done batch update test------------------------------------------
100 |
--------------------------------------------------------------------------------
/core/catre/losses/l2_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def l2_loss(pred, target, reduction="mean"):
6 | assert pred.size() == target.size() and target.numel() > 0
7 | assert pred.size()[0] == target.size()[0]
8 | batch_size = pred.size()[0]
9 | loss = torch.norm((pred - target).view(batch_size, -1), p=2, dim=1, keepdim=True)
10 | # loss = torch.sqrt(torch.sum(((pred - target)** 2).view(batch_size, -1), 1))
11 | # print(loss.shape)
12 | """
13 | _mse_loss = nn.MSELoss(reduction='none')
14 | loss_mse = _mse_loss(pred, target)
15 | print('l2 from mse loss: {}'.format(
16 | torch.sqrt(
17 | torch.sum(
18 | loss_mse.view(batch_size, -1),
19 | 1
20 | )
21 | ).mean()))
22 | """
23 | if reduction == "mean":
24 | loss = loss.mean()
25 | elif reduction == "sum":
26 | loss = loss.sum()
27 | return loss
28 |
29 |
30 | class L2Loss(nn.Module):
31 | def __init__(self, reduction="mean", loss_weight=1.0):
32 | super(L2Loss, self).__init__()
33 | self.reduction = reduction
34 | self.loss_weight = loss_weight
35 |
36 | def forward(self, pred, target):
37 | loss = self.loss_weight * l2_loss(pred, target, reduction=self.reduction)
38 | return loss
39 |
40 |
41 | if __name__ == "__main__":
42 |
43 | _l2_loss = L2Loss(reduction="mean")
44 | torch.manual_seed(2)
45 | pred = torch.randn(8, 3, 4)
46 | targets = torch.randn(8, 3, 4)
47 | # pred.requires_grad = True
48 | # targets.requires_grad = True
49 |
50 | loss_l2 = _l2_loss(pred, targets)
51 | batch_size = 8
52 | # print('mse loss: {}'.format(loss_mse))
53 | # print('sqrt(mse loss): {}'.format(torch.sqrt(loss_mse)))
54 |
55 | _mse_loss = nn.MSELoss(reduction="none")
56 | loss_mse = _mse_loss(pred, targets)
57 | print("l2 from mse loss: {}".format(torch.sqrt(torch.sum(loss_mse.view(batch_size, -1), 1)).mean()))
58 | print("l2 loss: {}".format(loss_l2))
59 | # print('squared l2 loss: {}'.format(loss_l2 ** 2))
60 |
--------------------------------------------------------------------------------
/core/catre/losses/rot_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def angular_distance(r1, r2, reduction="mean"):
5 | """https://math.stackexchange.com/questions/90081/quaternion-distance
6 | https.
7 |
8 | ://github.com/papagina/RotationContinuity/blob/master/sanity_test/code/tool
9 | s.py.
10 |
11 | 1 - ^2 <==> (1-cos(theta)) / 2
12 | """
13 | assert r1.shape == r2.shape
14 | if r1.shape[-1] == 4:
15 | return angular_distance_quat(r1, r2, reduction=reduction)
16 | if len(r1.shape) == 2 and r1.shape[-1] == 3: # bs * 3
17 | return angular_distance_vec(r1, r2, reduction=reduction)
18 | else:
19 | return angular_distance_rot(r1, r2, reduction=reduction)
20 |
21 |
22 | def angular_distance_quat(pred_q, gt_q, reduction="mean"):
23 | dist = 1 - torch.pow(torch.bmm(pred_q.view(-1, 1, 4), gt_q.view(-1, 4, 1)), 2)
24 | if reduction == "mean":
25 | return dist.mean()
26 | elif reduction == "sum":
27 | return dist.sum()
28 | else:
29 | return dist
30 |
31 |
32 | def angular_distance_vec(vec_1, vec_2, reduction="mean"):
33 | cos = torch.bmm(vec_1.unsqueeze(1), vec_2.unsqueeze(2)).squeeze() / (
34 | torch.norm(vec_1, dim=1) * torch.norm(vec_2, dim=1)
35 | ) # [-1, 1]
36 | dist = (1 - cos) / 2 # [0, 1]
37 | if reduction == "mean":
38 | return dist.mean()
39 | elif reduction == "sum":
40 | return dist.sum()
41 | else:
42 | return dist
43 |
44 |
45 | def angular_distance_rot(m1, m2, reduction="mean"):
46 | m = torch.bmm(m1, m2.transpose(1, 2)) # b*3*3
47 | m_trace = torch.einsum("bii->b", m) # batch trace
48 | cos = (m_trace - 1) / 2 # [-1, 1]
49 | # eps = 1e-6
50 | # cos = torch.clamp(cos, -1+eps, 1-eps) # avoid nan
51 | # theta = torch.acos(cos)
52 | dist = (1 - cos) / 2 # [0, 1]
53 | if reduction == "mean":
54 | return dist.mean()
55 | elif reduction == "sum":
56 | return dist.sum()
57 | else:
58 | return dist
59 |
60 |
61 | def rot_l2_loss(m1, m2):
62 | error = torch.pow(m1 - m2, 2).mean() # batch
63 | return error
64 |
65 |
66 | if __name__ == "__main__":
67 | import sys
68 | import os.path as osp
69 |
70 | cur_dir = osp.dirname(__file__)
71 | sys.path.insert(0, osp.join(cur_dir, "../../../"))
72 | from lib.pysixd.transform import random_quaternion
73 | from transforms3d.quaternions import quat2mat
74 |
75 | q1 = random_quaternion()
76 | q2 = random_quaternion()
77 | m1 = quat2mat(q1)
78 | m2 = quat2mat(q2)
79 | dtype = torch.float32
80 | device = "cpu"
81 | q1 = torch.tensor([q1, q1], dtype=dtype, device=device).view(-1, 4)
82 | q2 = torch.tensor([q2, q2], dtype=dtype, device=device).view(-1, 4)
83 | m1 = torch.tensor([m1, m1], dtype=dtype, device=device).view(-1, 3, 3)
84 | m2 = torch.tensor([m2, m2], dtype=dtype, device=device).view(-1, 3, 3)
85 | dist_q = angular_distance_quat(q1, q2)
86 | dist_r = angular_distance_rot(m1, m2)
87 | print("dist q: ", dist_q)
88 | print("dist r: ", dist_r)
89 | print(angular_distance(q1, q2))
90 | print(angular_distance(m1, m2))
91 |
--------------------------------------------------------------------------------
/core/catre/models/heads/conv_out_per_rot_head.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.modules.batchnorm import _BatchNorm
4 | from mmcv.cnn import normal_init, constant_init
5 |
6 | from lib.torch_utils.layers.layer_utils import get_norm, get_nn_act_func
7 | from lib.torch_utils.layers.conv_module import ConvModule
8 |
9 |
10 | class ConvOutPerRotHead(nn.Module):
11 | def __init__(
12 | self,
13 | in_dim=1024,
14 | feat_dim=256,
15 | num_layers=2,
16 | rot_dim=3,
17 | norm="GN",
18 | num_gn_groups=32,
19 | act="gelu",
20 | num_classes=1,
21 | kernel_size=1,
22 | num_points=1,
23 | per_rot_sup=False,
24 | norm_input=False,
25 | dropout=False,
26 | point_bias=True,
27 | **args,
28 | ):
29 | super(ConvOutPerRotHead, self).__init__()
30 | self.per_rot_sup = per_rot_sup
31 | self.rot_head_x = RotHead(
32 | in_dim,
33 | feat_dim,
34 | num_layers,
35 | rot_dim,
36 | norm,
37 | num_gn_groups,
38 | act,
39 | num_classes,
40 | kernel_size,
41 | num_points,
42 | norm_input,
43 | dropout,
44 | point_bias,
45 | )
46 | self.rot_head_y = RotHead(
47 | in_dim,
48 | feat_dim,
49 | num_layers,
50 | rot_dim,
51 | norm,
52 | num_gn_groups,
53 | act,
54 | num_classes,
55 | kernel_size,
56 | num_points,
57 | norm_input,
58 | dropout,
59 | point_bias,
60 | )
61 |
62 | def forward(self, x):
63 | rx, feat_x = self.rot_head_x(x)
64 | ry, feat_y = self.rot_head_y(x)
65 | r_pred = torch.cat((rx, ry), dim=1)
66 | feat = torch.cat((feat_x, feat_y), dim=1)
67 |
68 | if self.per_rot_sup:
69 | return r_pred, feat # return bs * 6
70 | else:
71 | return r_pred
72 |
73 |
74 | class RotHead(nn.Module):
75 | def __init__(
76 | self,
77 | in_dim=1024,
78 | feat_dim=256,
79 | num_layers=2,
80 | rot_dim=4,
81 | norm="none",
82 | num_gn_groups=32,
83 | act="leaky_relu",
84 | num_classes=1,
85 | kernel_size=1,
86 | num_points=1,
87 | norm_input=False,
88 | dropout=False,
89 | point_bias=True,
90 | ):
91 | super().__init__()
92 | self.norm = get_norm(norm, feat_dim, num_gn_groups=num_gn_groups)
93 | self.act_func = act_func = get_nn_act_func(act)
94 | self.num_classes = num_classes
95 | self.rot_dim = rot_dim
96 |
97 | self.layers = nn.ModuleList()
98 |
99 | if norm_input:
100 | self.layers.append(nn.BatchNorm1d(in_dim))
101 | for _i in range(num_layers):
102 | _in_dim = in_dim if _i == 0 else feat_dim
103 | self.layers.append(nn.Conv1d(_in_dim, feat_dim, kernel_size))
104 | self.layers.append(get_norm(norm, feat_dim, num_gn_groups=num_gn_groups))
105 | self.layers.append(act_func)
106 | if dropout:
107 | self.layers.append(nn.Dropout(p=0.2))
108 |
109 | self.neck = nn.ModuleList()
110 | self.neck.append(nn.Conv1d(feat_dim, rot_dim * num_classes, 1))
111 |
112 | self.conv_p = nn.Conv1d(num_points, 1, 1, bias=point_bias)
113 |
114 | # init ------------------------------------
115 | self._init_weights()
116 |
117 | def _init_weights(self):
118 | for m in self.modules():
119 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Conv1d)):
120 | normal_init(m, std=0.001)
121 | elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
122 | constant_init(m, 1)
123 | elif isinstance(m, nn.Linear):
124 | normal_init(m, std=0.001)
125 |
126 | def forward(self, x):
127 | for _layer in self.layers:
128 | x = _layer(x)
129 |
130 | for _layer in self.neck:
131 | x = _layer(x)
132 |
133 | feat = x.clone()
134 | x = x.permute(0, 2, 1)
135 | x = self.conv_p(x)
136 |
137 | x = x.squeeze(1)
138 | x = x.contiguous()
139 |
140 | return x, feat
141 |
142 |
143 | if __name__ == "__main__":
144 | points = torch.rand(8, 1088, 1024 + 32) # bs x feature x num_p
145 | rot_head = ConvOutPerRotHead(in_dim=1088, num_points=1024 + 32)
146 | rot, feat = rot_head(points)
147 | print(rot.shape)
148 | print(feat.shape)
149 |
--------------------------------------------------------------------------------
/core/catre/models/heads/fc_trans_size_head.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.modules.batchnorm import _BatchNorm
4 | from mmcv.cnn import normal_init, constant_init
5 | from lib.torch_utils.layers.layer_utils import get_norm, get_nn_act_func
6 | from lib.torch_utils.layers.conv_module import ConvModule
7 |
8 |
9 | class FC_TransSizeHead(nn.Module):
10 | def __init__(
11 | self,
12 | in_dim=1024,
13 | feat_dim=256,
14 | num_layers=2,
15 | rot_dim=4,
16 | norm="none",
17 | num_gn_groups=32,
18 | act="leaky_relu",
19 | num_classes=1,
20 | norm_input=False,
21 | dropout=False,
22 | ):
23 | """
24 | rot_dim: 4 for quaternion, 6 for rot6d
25 | num_classes: default 1 (either single class or class-agnostic)
26 | """
27 | super().__init__()
28 | self.norm = get_norm(norm, feat_dim, num_gn_groups=num_gn_groups)
29 | self.act_func = act_func = get_nn_act_func(act)
30 | self.num_classes = num_classes
31 | self.rot_dim = rot_dim
32 |
33 | self.linears = nn.ModuleList()
34 | if norm_input:
35 | self.linears.append(nn.BatchNorm1d(in_dim))
36 | for _i in range(num_layers):
37 | _in_dim = in_dim if _i == 0 else feat_dim
38 | self.linears.append(nn.Linear(_in_dim, feat_dim))
39 | self.linears.append(get_norm(norm, feat_dim, num_gn_groups=num_gn_groups))
40 | self.linears.append(act_func)
41 | if dropout:
42 | self.linears.append(nn.Dropout(p=0.5))
43 |
44 | self.fc_t = nn.Linear(feat_dim, 3 * num_classes)
45 | self.fc_s = nn.Linear(feat_dim, 3 * num_classes)
46 |
47 | # init ------------------------------------
48 | self._init_weights()
49 |
50 | def _init_weights(self):
51 | for m in self.modules():
52 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Conv1d)):
53 | normal_init(m, std=0.001)
54 | elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
55 | constant_init(m, 1)
56 | elif isinstance(m, nn.Linear):
57 | normal_init(m, std=0.001)
58 | normal_init(self.fc_t, std=0.01)
59 | normal_init(self.fc_s, std=0.01)
60 |
61 | def forward(self, x):
62 | """
63 | x: should be flattened
64 | """
65 | for _layer in self.linears:
66 | x = _layer(x)
67 |
68 | trans = self.fc_t(x)
69 | scale = self.fc_s(x)
70 | return trans, scale
71 |
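
# A minimal usage sketch for FC_TransSizeHead, assuming a flattened global
# feature whose dimension matches the NOCS_REAL configs (in_dim = 1088 + 3,
# GN norm, gelu activation). Both branches output per-instance 3-vectors
# (a translation delta and a size delta).
if __name__ == "__main__":
    feat = torch.rand(8, 1088 + 3)  # bs x flattened feature dim
    head = FC_TransSizeHead(in_dim=1088 + 3, norm="GN", num_gn_groups=32, act="gelu")
    trans, scale = head(feat)
    print(trans.shape, scale.shape)  # expected: torch.Size([8, 3]) torch.Size([8, 3])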
--------------------------------------------------------------------------------
/core/catre/models/net_factory.py:
--------------------------------------------------------------------------------
1 | from .pointnets.pointnet import PointNetfeat
2 |
3 | from .heads.fc_trans_size_head import FC_TransSizeHead
4 | from .heads.conv_out_per_rot_head import ConvOutPerRotHead
5 |
6 | PCLNETS = {
7 | "point_net": PointNetfeat,
8 | }
9 |
10 | HEADS = {
11 | "FC_TransSizeHead": FC_TransSizeHead,
12 | "ConvOutPerRotHead": ConvOutPerRotHead,
13 | }
14 |
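
# A minimal sketch of how these registries are typically consumed: the "type"
# string in a config's INIT_CFG selects the class, and the remaining keys are
# its constructor kwargs. The values below mirror the NOCS_REAL configs; the
# actual builder code (e.g. in model_utils.py) may differ in detail.
if __name__ == "__main__":
    init_cfg = dict(type="ConvOutPerRotHead", in_dim=1088, num_points=1024 + 1024)
    head_cls = HEADS[init_cfg.pop("type")]
    rot_head = head_cls(**init_cfg)
    print(type(rot_head).__name__)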
--------------------------------------------------------------------------------
/core/catre/models/pose_scale_from_delta_init.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from core.utils.utils import (
4 | allo_to_ego_mat_torch,
5 | )
6 |
7 |
8 | def pose_scale_from_delta_init(
9 | rot_deltas,
10 | trans_deltas,
11 | scale_deltas,
12 | rot_inits,
13 | trans_inits,
14 | scale_inits,
15 | Ks=None,
16 | K_aware=False,
17 | delta_T_space="3D",
18 | delta_T_weight=1.0,
19 | delta_z_style="cosypose",
20 | eps=1e-4,
21 | is_allo=False,
22 | scale_type="add_iter",
23 | ):
24 | """
25 | Args:
26 | rot_deltas: [b,3,3]
27 | trans_deltas: [b,3], vxvyvz, delta translations in image space
28 | rot_inits: [b,3,3]
29 | trans_inits: [b,3]
30 | Ks: if None, using constants 1
31 | otherwise use zoomed Ks
32 | K_aware: whether to use zoomed K
33 | delta_T_space: image | 3D
34 | delta_T_weight: deepim-pytorch uses 0.1, default 1.0
35 |         delta_z_style: cosypose (_vz = ztgt / zsrc) | deepim (vz = log(zsrc/ztgt))
36 | eps:
37 | is_allo:
38 | Returns:
39 |         rot_tgts, trans_tgts, scale_tgts
40 | """
41 | bs = rot_deltas.shape[0]
42 | assert rot_deltas.shape == (bs, 3, 3)
43 | assert rot_inits.shape == (bs, 3, 3)
44 | assert trans_deltas.shape == (bs, 3)
45 | assert trans_inits.shape == (bs, 3)
46 |
47 | # trans============================================
48 | trans_deltas = trans_deltas * delta_T_weight
49 |
50 | if delta_T_space == "image":
51 | # Translation in image space
52 | zsrc = trans_inits[:, [2]] # [b,1]
53 | vz = trans_deltas[:, [2]] # [b,1]
54 | if delta_z_style == "cosypose":
55 | # NOTE: directly predict vz = 1/exp(_vz)
56 | # log(zsrc/ztgt) = _vz ==> ztgt = 1/exp(_vz) * zsrc
57 | ztgt = vz * zsrc # [b,1]
58 | else: # deepim
59 | # vz = log(zsrc/ztgt) ==> ztgt = zsrc / exp(vz)
60 | ztgt = torch.div(zsrc, torch.exp(vz)) # [b,1]
61 |
62 | if K_aware:
63 | assert Ks is not None and Ks.shape == (bs, 3, 3)
64 | vxvy = trans_deltas[:, :2] # [b,2]
65 | fxfy = Ks[:, [0, 1], [0, 1]] # [b,2]
66 | else: # deepim: treat fx, fy as 1
67 | vxvy = trans_deltas[:, :2] # [b,2]
68 | fxfy = torch.ones_like(vxvy)
69 |
70 | xy_src = trans_inits[:, :2] # [b,2]
71 | xy_tgt = ztgt * (vxvy / fxfy + xy_src / zsrc) # [b,2]
72 | trans_tgts = torch.cat([xy_tgt, ztgt], dim=-1) # [b,3]
73 | elif delta_T_space == "3D":
74 | trans_tgts = trans_inits + trans_deltas
75 | else:
76 | raise ValueError("Unknown delta_T_space: {}".format(delta_T_space))
77 |
78 | # scale =========================================
79 | if "add" in scale_type:
80 | scale_tgts = scale_inits + scale_deltas
81 | else:
82 | # NOTE: add exp to make scale_deltas zero-centered
83 | # scale_deltas =: log(s/mean_s)
84 | scale_tgts = scale_inits * torch.exp(scale_deltas)
85 |
86 | # rot ===========================================
87 | if is_allo:
88 | ego_rot_deltas = allo_to_ego_mat_torch(trans_tgts, rot_deltas, eps=eps)
89 | else:
90 | ego_rot_deltas = rot_deltas
91 |
92 | # Rotation in camera frame
93 | rot_tgts = ego_rot_deltas @ rot_inits
94 |
95 | return rot_tgts, trans_tgts, scale_tgts
96 |
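
# A minimal self-check sketch of the composition above, assuming 3D translation
# deltas, the additive scale update (SCLAE_TYPE="iter_add" in the configs), and
# ego rotation deltas. Identity deltas should return the initial pose and scale
# unchanged.
if __name__ == "__main__":
    bs = 2
    rot_inits = torch.eye(3).repeat(bs, 1, 1)
    trans_inits = torch.tensor([[0.1, -0.05, 0.8]]).repeat(bs, 1)
    scale_inits = torch.tensor([[0.20, 0.15, 0.10]]).repeat(bs, 1)

    rot_deltas = torch.eye(3).repeat(bs, 1, 1)  # no rotation update
    trans_deltas = torch.zeros(bs, 3)           # no translation update
    scale_deltas = torch.zeros(bs, 3)           # no scale update

    rot_tgt, trans_tgt, scale_tgt = pose_scale_from_delta_init(
        rot_deltas, trans_deltas, scale_deltas,
        rot_inits, trans_inits, scale_inits,
        delta_T_space="3D", scale_type="iter_add", is_allo=False,
    )
    assert torch.allclose(rot_tgt, rot_inits)
    assert torch.allclose(trans_tgt, trans_inits)
    assert torch.allclose(scale_tgt, scale_inits)
    print("identity deltas keep the initial pose and scale unchanged")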
--------------------------------------------------------------------------------
/core/catre/test_catre.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # test
3 | set -x
4 | this_dir=$(dirname "$0")
5 | # commonly used opts:
6 |
7 | # MODEL.WEIGHTS: resume or pretrained, or test checkpoint
8 | CFG=$1
9 | CUDA_VISIBLE_DEVICES=$2
10 | IFS=',' read -ra GPUS <<< "$CUDA_VISIBLE_DEVICES"
11 | # GPUS=($(echo "$CUDA_VISIBLE_DEVICES" | tr ',' '\n'))
12 | NGPU=${#GPUS[@]} # echo "${GPUS[0]}"
13 | echo "use gpu ids: $CUDA_VISIBLE_DEVICES num gpus: $NGPU"
14 | CKPT=$3
15 | if [ ! -f "$CKPT" ]; then
16 | echo "$CKPT does not exist."
17 | exit 1
18 | fi
19 | NCCL_DEBUG=INFO
20 | OMP_NUM_THREADS=1
21 | MKL_NUM_THREADS=1
22 | PYTHONPATH="$this_dir/../..":$PYTHONPATH \
23 | CUDA_VISIBLE_DEVICES=$2 python $this_dir/main_catre.py \
24 | --config-file $CFG --num-gpus $NGPU --eval-only \
25 | --opts MODEL.WEIGHTS=$CKPT \
26 | ${@:4}
27 |
--------------------------------------------------------------------------------
/core/catre/tools/camera25_prepare_spd_init_results.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import mmcv
3 | import os.path as osp
4 | import glob
5 | import sys
6 | from tqdm import tqdm
7 | import setproctitle
8 |
9 | cur_dir = osp.dirname(osp.abspath(__file__))
10 | PROJ_ROOT = osp.normpath(osp.join(cur_dir, "../../../"))
11 | sys.path.insert(0, PROJ_ROOT)
12 | from lib.pysixd import inout, misc
13 | from lib.utils.mask_utils import binary_mask_to_rle
14 |
15 | setproctitle.setproctitle(osp.basename(__file__).split(".")[0])
16 |
17 | data_root = osp.join(PROJ_ROOT, "datasets/NOCS")
18 |
19 | # original spd results
20 | spd_pose_dir = osp.join(PROJ_ROOT, "datasets/NOCS/deformnet_eval/eval_camera") # in bbnc11
21 | spd_seg_dir = osp.join(PROJ_ROOT, "datasets/NOCS/deformnet_eval/mrcnn_results/val")
22 |
23 | # our format
24 | init_pose_dir = osp.join(data_root, "test_init_poses")
25 | mmcv.mkdir_or_exist(init_pose_dir)
26 | init_pose_path = osp.join(init_pose_dir, "init_pose_spd_nocs_cmra.json")
27 |
28 | if __name__ == "__main__":
29 | results = {}
30 |
31 | CACHED = False
32 | if CACHED:
33 | results = mmcv.load(init_pose_path)
34 | else:
35 | spd_pose_paths = glob.glob(osp.join(spd_pose_dir, "results*.pkl"))
36 |
37 | num_total = 0
38 | for idx, spd_pose_path in enumerate(tqdm(spd_pose_paths)):
39 | preds = mmcv.load(spd_pose_path)
40 | bboxes = preds["pred_bboxes"]
41 | scores = preds["pred_scores"]
42 | poses = preds["pred_RTs"][:, :3]
43 | pred_scales = preds["pred_scales"]
44 | class_ids = preds["pred_class_ids"]
45 | mug_handles = preds["gt_handle_visibility"]
46 |
47 | scene_id, im_id = spd_pose_path.split("/")[-1].split(".")[0].split("_")[-2:]
48 | scene_im_id = f"{scene_id}/{im_id}"
49 |
50 | seg_path = osp.join(spd_seg_dir, f"results_val_{scene_id}_{im_id}.pkl")
51 | assert osp.exists(seg_path), seg_path
52 |
53 | masks = mmcv.load(seg_path)["masks"].astype("int") # bool -> int
54 | assert masks.shape[2] == len(class_ids)
55 | results[scene_im_id] = []
56 | i = 0
57 | for class_id, pose, scale, score, bbox, mug_handle in zip(
58 | class_ids, poses, pred_scales, scores, bboxes, mug_handles
59 | ):
60 | # [sR -> R], normed_scale -> scale
61 | R = pose[:3, :3]
62 | if R.tolist() == [[1, 0, 0], [0, 1, 0], [0, 0, 1]]:
63 | # print("ill pose")
64 | i += 1
65 | continue
66 | nocs_scale = pow(np.linalg.det(R), 1 / 3)
67 | abs_scale = scale * nocs_scale
68 | pose[:3, :3] = R / nocs_scale
69 | # mask2rle
70 | mask = masks[:, :, i]
71 | mask_rle = binary_mask_to_rle(mask)
72 | y1, x1, y2, x2 = bbox.tolist()
73 | bbox = [x1, y1, x2, y2]
74 | cur_res = {
75 | "obj_id": int(class_id),
76 | "pose_est": pose.tolist(),
77 | "scale_est": abs_scale.tolist(),
78 | "bbox_est": bbox,
79 | "score": float(score),
80 | "mug_handle": int(mug_handle),
81 | "segmentation": mask_rle,
82 | }
83 | results[scene_im_id].append(cur_res)
84 | i += 1
85 |
86 | print(init_pose_path)
87 | inout.save_json(init_pose_path, results, sort=False)
88 |
89 | VIS = False
90 | if VIS:
91 | from core.utils.data_utils import read_image_mmcv
92 | from lib.utils.mask_utils import cocosegm2mask
93 | from lib.vis_utils.image import grid_show, heatmap
94 | from core.catre.engine.test_utils import get_3d_bbox
95 | import ref
96 |
97 | for scene_im_id, r in results.items():
98 | img_path = f"datasets/NOCS/REAL/real_test/{scene_im_id}_color.png"
99 | img = read_image_mmcv(img_path, format="BGR")
100 | K = ref.nocs.real_intrinsics
101 | anno = r[0]
102 | imH, imW = img.shape[:2]
103 | mask = cocosegm2mask(anno["segmentation"], imH, imW)
104 | pose = np.array(anno["pose_est"]).reshape(3, 4)
105 | scale = np.array(anno["scale_est"])
106 | bbox = get_3d_bbox(scale).transpose()
107 | kpts_2d = misc.project_pts(bbox, K, pose[:, :3], pose[:, 3])
108 | img_vis_kpts2d = misc.draw_projected_box3d(img.copy(), kpts_2d)
109 | grid_show([img[:, :, ::-1], img_vis_kpts2d, mask], row=3, col=1)
110 |
--------------------------------------------------------------------------------
/core/catre/tools/prepare_spd_init_results.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import mmcv
3 | import os.path as osp
4 | import glob
5 | import sys
6 | from tqdm import tqdm
7 | import setproctitle
8 |
9 | cur_dir = osp.dirname(osp.abspath(__file__))
10 | PROJ_ROOT = osp.normpath(osp.join(cur_dir, "../../../"))
11 | sys.path.insert(0, PROJ_ROOT)
12 | from lib.pysixd import inout, misc
13 | from lib.utils.mask_utils import binary_mask_to_rle
14 |
15 | setproctitle.setproctitle(osp.basename(__file__).split(".")[0])
16 |
17 | data_root = osp.join(PROJ_ROOT, "datasets/NOCS")
18 |
19 | # original spd results
20 | spd_pose_dir = osp.join(PROJ_ROOT, "datasets/NOCS/deformnet_eval/eval_real")
21 | spd_seg_dir = osp.join(PROJ_ROOT, "datasets/NOCS/deformnet_eval/mrcnn_results/real_test")
22 |
23 | # our format
24 | init_pose_dir = osp.join(data_root, "test_init_poses")
25 | mmcv.mkdir_or_exist(init_pose_dir)
26 | init_pose_path = osp.join(init_pose_dir, "init_pose_spd_nocs_real.json")
27 |
28 |
29 | if __name__ == "__main__":
30 | results = {}
31 |
32 | CACHED = False
33 | if CACHED:
34 | results = mmcv.load(init_pose_path)
35 | else:
36 | spd_pose_paths = glob.glob(osp.join(spd_pose_dir, "results*.pkl"))
37 |
38 | num_total = 0
39 | for idx, spd_pose_path in enumerate(tqdm(spd_pose_paths)):
40 | preds = mmcv.load(spd_pose_path)
41 | bboxes = preds["pred_bboxes"]
42 | scores = preds["pred_scores"]
43 | poses = preds["pred_RTs"][:, :3]
44 | pred_scales = preds["pred_scales"]
45 | class_ids = preds["pred_class_ids"]
46 | mug_handles = preds["gt_handle_visibility"]
47 |
48 | scene_id, im_id = spd_pose_path.split("/")[-1].split(".")[0].split("_")[-2:]
49 | scene_im_id = f"scene_{scene_id}/{im_id}"
50 |
51 | seg_path = osp.join(spd_seg_dir, f"results_test_scene_{scene_id}_{im_id}.pkl")
52 | assert osp.exists(seg_path), seg_path
53 |
54 | masks = mmcv.load(seg_path)["masks"].astype("int") # bool -> int
55 | assert masks.shape[2] == len(class_ids)
56 | results[scene_im_id] = []
57 | i = 0
58 | for class_id, pose, scale, score, bbox, mug_handle in zip(
59 | class_ids, poses, pred_scales, scores, bboxes, mug_handles
60 | ):
61 | # [sR -> R], normed_scale -> scale
62 | R = pose[:3, :3]
63 | nocs_scale = pow(np.linalg.det(R), 1 / 3)
64 | abs_scale = scale * nocs_scale
65 | pose[:3, :3] = R / nocs_scale
66 | # mask2rle
67 | mask = masks[:, :, i]
68 | mask_rle = binary_mask_to_rle(mask)
69 | y1, x1, y2, x2 = bbox.tolist()
70 | bbox = [x1, y1, x2, y2]
71 | cur_res = {
72 | "obj_id": int(class_id),
73 | "pose_est": pose.tolist(),
74 | "scale_est": abs_scale.tolist(),
75 | "bbox_est": bbox,
76 | "score": float(score),
77 | "mug_handle": int(mug_handle),
78 | "segmentation": mask_rle,
79 | }
80 | results[scene_im_id].append(cur_res)
81 | i += 1
82 |
83 | print(init_pose_path)
84 | inout.save_json(init_pose_path, results, sort=False)
85 |
86 | VIS = False
87 | if VIS:
88 | from core.utils.data_utils import read_image_mmcv
89 | from lib.utils.mask_utils import cocosegm2mask
90 | from lib.vis_utils.image import grid_show, heatmap
91 | from core.catre.engine.test_utils import get_3d_bbox
92 | import ref
93 |
94 | for scene_im_id, r in results.items():
95 | img_path = f"datasets/NOCS/REAL/real_test/{scene_im_id}_color.png"
96 | img = read_image_mmcv(img_path, format="BGR")
97 | K = ref.nocs.real_intrinsics
98 | anno = r[0]
99 | imH, imW = img.shape[:2]
100 | mask = cocosegm2mask(anno["segmentation"], imH, imW)
101 | pose = np.array(anno["pose_est"]).reshape(3, 4)
102 | scale = np.array(anno["scale_est"])
103 | bbox = get_3d_bbox(scale).transpose()
104 | kpts_2d = misc.project_pts(bbox, K, pose[:, :3], pose[:, 3])
105 | img_vis_kpts2d = misc.draw_projected_box3d(img.copy(), kpts_2d)
106 | grid_show([img[:, :, ::-1], img_vis_kpts2d, mask], row=3, col=1)
107 |
--------------------------------------------------------------------------------
/core/catre/train_catre.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -x
3 | this_dir=$(dirname "$0")
4 | # commonly used opts:
5 |
6 | # MODEL.WEIGHTS: resume or pretrained, or test checkpoint
7 | CFG=$1
8 | CUDA_VISIBLE_DEVICES=$2
9 | IFS=',' read -ra GPUS <<< "$CUDA_VISIBLE_DEVICES"
10 | # GPUS=($(echo "$CUDA_VISIBLE_DEVICES" | tr ',' '\n'))
11 | NGPU=${#GPUS[@]} # echo "${GPUS[0]}"
12 | echo "use gpu ids: $CUDA_VISIBLE_DEVICES num gpus: $NGPU"
13 | # CUDA_LAUNCH_BLOCKING=1
14 | NCCL_DEBUG=INFO
15 | OMP_NUM_THREADS=1
16 | MKL_NUM_THREADS=1
17 | PYTHONPATH="$this_dir/../..":$PYTHONPATH \
18 | CUDA_VISIBLE_DEVICES=$2 python $this_dir/main_catre.py \
19 | --config-file $CFG --num-gpus $NGPU ${@:3}
20 |
--------------------------------------------------------------------------------
/core/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/core/utils/__init__.py
--------------------------------------------------------------------------------
/core/utils/apex_trainer.py:
--------------------------------------------------------------------------------
1 | # just a reference implementation, not actually used
2 |
3 | import logging
4 | import time
5 | import torch
6 | from detectron2.engine import SimpleTrainer
7 | import core.utils.my_comm as comm
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 | try:
12 | import apex
13 | from apex import amp
14 | except:
15 | logger.exception("Please install apex from https://www.github.com/nvidia/apex to use this ApexTrainer.")
16 |
17 |
18 | class ApexTrainer(SimpleTrainer):
19 | """Like :class:`SimpleTrainer`, but uses NVIDIA's apex automatic mixed
20 | precision in the training loop."""
21 |
22 | def __init__(self, model, data_loader, optimizer, apex_opt_level="O1"):
23 | """
24 | Args:
25 | model, data_loader, optimizer: same as in :class:`SimpleTrainer`.
26 |             apex_opt_level: apex AMP optimization level, e.g. "O1".
27 | """
28 | if comm.get_world_size() > 1:
29 | model, optimizer = amp.initialize(model, optimizer, opt_level=apex_opt_level)
30 | super().__init__(model, data_loader, optimizer)
31 |
32 | def run_step(self):
33 | """Implement the AMP training logic using apex."""
34 | assert self.model.training, "[ApexTrainer] model was changed to eval mode!"
35 | assert torch.cuda.is_available(), "[ApexTrainer] CUDA is required for AMP training!"
36 |
37 | start = time.perf_counter()
38 | data = next(self._data_loader_iter)
39 | data_time = time.perf_counter() - start
40 |
41 | loss_dict = self.model(data)
42 | if isinstance(loss_dict, torch.Tensor):
43 | losses = loss_dict
44 | loss_dict = {"total_loss": loss_dict}
45 | else:
46 | losses = sum(loss_dict.values())
47 |
48 | self.optimizer.zero_grad()
49 | with amp.scale_loss(losses, self.optimizer) as scaled_loss:
50 |             scaled_loss.backward()
51 |
52 | self._write_metrics(loss_dict, data_time)
53 |
54 | self.optimizer.step()
55 |
56 | def state_dict(self):
57 | ret = super().state_dict()
58 | # save amp state according to
59 | # https://nvidia.github.io/apex/amp.html#checkpointing
60 | ret["amp"] = amp.state_dict()
61 | return ret
62 |
63 | def load_state_dict(self, state_dict):
64 | super().load_state_dict(state_dict)
65 | if "amp" in state_dict:
66 | amp.load_state_dict(state_dict["amp"])
67 |
--------------------------------------------------------------------------------
/core/utils/camera_geometry.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import torch
4 |
5 |
6 | def get_K_crop_resize(K, crop_xy, resize_ratio):
7 | """
8 | Args:
9 | K: [b,3,3]
10 | crop_xy: [b, 2] left top of crop boxes
11 | resize_ratio: [b,2] or [b,1]
12 | """
13 | assert K.shape[1:] == (3, 3)
14 | assert crop_xy.shape[1:] == (2,)
15 | assert resize_ratio.shape[1:] == (2,) or resize_ratio.shape[1:] == (1,)
16 | bs = K.shape[0]
17 |
18 | new_K = K.clone()
19 | new_K[:, [0, 1], 2] = K[:, [0, 1], 2] - crop_xy # [b, 2]
20 | new_K[:, [0, 1]] = new_K[:, [0, 1]] * resize_ratio.view(bs, -1, 1)
21 | return new_K
22 |
23 |
24 | def project_points(points_3d, K, pose, z_min=None):
25 | """
26 | Args:
27 | points_3d: BxPx3
28 | K: Bx3x3
29 | pose: Bx3x4
30 |         z_min: prevent zero division, e.g. 0.1
31 | Returns:
32 | projected 2d points: BxPx2
33 | """
34 | assert K.shape[-2:] == (3, 3)
35 | assert pose.shape[-2:] == (3, 4)
36 | batch_size = points_3d.shape[0]
37 | n_points = points_3d.shape[1]
38 | device = points_3d.device
39 | if points_3d.shape[-1] == 3:
40 | points_3d = torch.cat((points_3d, torch.ones(batch_size, n_points, 1).to(device)), dim=-1)
41 | P = K @ pose[:, :3]
42 | suv = (P.unsqueeze(1) @ points_3d.unsqueeze(-1)).squeeze(-1) # Bx1x3x4 @ BxPx4x1 -> BxPx3
43 | if z_min is not None:
44 | z = suv[..., -1]
45 | suv[..., -1] = torch.max(torch.ones_like(z) * z_min, z) # eg. z_min=0.1
46 | suv = suv / suv[..., [-1]]
47 | return suv[..., :2] # BxPx2
48 |
49 |
50 | def centers_2d_from_t(K, t, z_min=None):
51 | """can also get the centers via projecting the zero point (B,1,3)
52 | Args:
53 | K: Bx3x3
54 | t: Bx3
55 | z_min: to prevent zero division
56 | Returns:
57 | centers_2d: Bx2
58 | """
59 | assert K.ndim == 3 and K.shape[-2:] == (3, 3), K.shape
60 | bs = K.shape[0]
61 | proj = (K @ t.view(bs, 3, 1)).view(bs, 3)
62 | if z_min is not None:
63 | z = proj[..., -1]
64 | proj[..., -1] = torch.max(torch.ones_like(z) * z_min, z) # eg. z_min=0.1
65 | centers_2d = proj[:, :2] / proj[:, [-1]] # Nx2
66 | return centers_2d
67 |
68 |
69 | def centers_2d_from_pose(K, pose, z_min=None):
70 | """can also get the centers via projecting the zero point (B,1,3)
71 | Args:
72 | K: Bx3x3
73 |         pose: Bx3x4 (only uses the translation)
74 | z_min: to prevent zero division
75 | Returns:
76 | centers_2d: Bx2
77 | """
78 | assert K.ndim == 3 and K.shape[-2:] == (3, 3), K.shape
79 | assert pose.ndim == 3 and pose.shape[-2:] == (3, 4), pose.shape
80 | bs = pose.shape[0]
81 | proj = (K @ pose[:, :3, [3]]).view(bs, 3) # Nx3x3 @ Nx3x1 -> Nx3x1 -> Nx3
82 | if z_min is not None:
83 | z = proj[..., -1]
84 | proj[..., -1] = torch.max(torch.ones_like(z) * z_min, z) # eg. z_min=0.1
85 | centers_2d = proj[:, :2] / proj[:, [-1]] # Nx2
86 | return centers_2d
87 |
88 |
89 | def boxes_from_points_2d(uv):
90 | """
91 | Args:
92 | uv: BxPx2 projected 2d points
93 | Returns:
94 | Bx4
95 | """
96 | assert uv.ndim == 3 and uv.shape[-1] == 2, uv.shape
97 | x1 = uv[..., 0].min(dim=1)[0] # (B,)
98 | y1 = uv[..., 1].min(dim=1)[0]
99 |
100 | x2 = uv[..., 0].max(dim=1)[0]
101 | y2 = uv[..., 1].max(dim=1)[0]
102 |
103 | return torch.stack([x1, y1, x2, y2], dim=1) # Bx4
104 |
105 |
106 | def bboxes_from_pose(points_3d, K, pose, z_min=None, imH=480, imW=640, clamp=False):
107 | points_2d = project_points(points_3d, K=K, pose=pose, z_min=z_min)
108 | boxes = boxes_from_points_2d(points_2d)
109 | if clamp:
110 | boxes[:, [0, 2]] = torch.clamp(boxes[:, [0, 2]], 0, imW - 1)
111 | boxes[:, [1, 3]] = torch.clamp(boxes[:, [1, 3]], 0, imH - 1)
112 | return boxes
113 |
114 |
115 | def adapt_image_by_K(
116 | image, *, K_old, K_new, interpolation=cv2.INTER_LINEAR, border_type=cv2.BORDER_REFLECT, height=480, width=640
117 | ):
118 | """adapt image from old K to new K."""
119 | H_old, W_old = image.shape[:2]
120 | K_old = K_old.copy()
121 | K_old[0, :] = K_old[0, :] / W_old * width
122 | K_old[1, :] = K_old[1, :] / H_old * height
123 |
124 | focal_scale_x = K_new[0, 0] / K_old[0, 0]
125 | focal_scale_y = K_new[1, 1] / K_old[1, 1]
126 | ox, oy = K_new[0, 2] - K_old[0, 2], K_new[1, 2] - K_old[1, 2]
127 |
128 | image = cv2.resize(
129 | image,
130 | (int(width * focal_scale_x), int(height * focal_scale_y)),
131 | interpolation=interpolation,
132 | )
133 |
134 | image = cv2.copyMakeBorder(image, 200, 200, 200, 200, borderType=border_type)
135 | # print(image.shape)
136 | y1 = int(round(image.shape[0] / 2 - oy - height / 2))
137 | y2 = int(round(image.shape[0] / 2 - oy + height / 2))
138 | x1 = int(round(image.shape[1] / 2 - ox - width / 2))
139 | x2 = int(round(image.shape[1] / 2 - ox + width / 2))
140 | # print(x1, y1, x2, y2)
141 | return image[y1:y2, x1:x2]
142 |
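A minimal usage sketch of the projection helpers above; the intrinsics, pose, and points are made-up values for illustration only:

```
import torch
from core.utils.camera_geometry import get_K_crop_resize, project_points, bboxes_from_pose

# hypothetical 640x480 camera and a toy object 1 m in front of it
K = torch.tensor([[[572.4, 0.0, 325.3],
                   [0.0, 573.6, 242.0],
                   [0.0, 0.0, 1.0]]])              # [1,3,3]
pose = torch.eye(3, 4).unsqueeze(0)                # [1,3,4], identity rotation
pose[:, 2, 3] = 1.0                                # 1 m along z
points_3d = torch.rand(1, 100, 3) * 0.1            # object points in meters

uv = project_points(points_3d, K, pose, z_min=0.1)        # [1,100,2]
box = bboxes_from_pose(points_3d, K, pose, clamp=True)     # [1,4] (x1,y1,x2,y2)

# intrinsics after cropping at the box corner and resizing the crop by 2x
K_new = get_K_crop_resize(K, crop_xy=box[:, :2], resize_ratio=torch.full((1, 1), 2.0))
```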
--------------------------------------------------------------------------------
/core/utils/depth_aug.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 |
4 |
5 | def add_noise_depth(depth, level=0.005, depth_valid_min=0):
6 | # from DeepIM-PyTorch and se3tracknet
7 | # in deepim: level=0.1, valid_min=0
8 | # in se3tracknet, level=5/1000, depth_valid_min = 100/1000 = 0.1
9 |
10 | if len(depth.shape) == 3:
11 | mask = depth[:, :, -1] > depth_valid_min
12 | row, col, ch = depth.shape
13 | noise_level = random.uniform(0, level)
14 | gauss = noise_level * np.random.randn(row, col)
15 | gauss = np.repeat(gauss[:, :, np.newaxis], ch, axis=2)
16 | else: # 2
17 | mask = depth > depth_valid_min
18 | row, col = depth.shape
19 | noise_level = random.uniform(0, level)
20 | gauss = noise_level * np.random.randn(row, col)
21 | noisy = depth.copy()
22 | noisy[mask] = depth[mask] + gauss[mask]
23 | return noisy
24 |
25 |
26 | if __name__ == "__main__":
27 | from lib.vis_utils.image import heatmap, grid_show
28 | import mmcv
29 | import cv2
30 | from skimage.restoration import denoise_bilateral
31 |
32 | # depth = mmcv.imread("datasets/BOP_DATASETS/ycbv/train_pbr/000000/depth/000000.png", "unchanged") / 10000.0
33 | # depth_aug = add_noise_depth(depth, level=0.005, depth_valid_min=0.1)
34 | # diff = depth_aug - depth
35 | # grid_show([
36 | # heatmap(depth, to_rgb=True), heatmap(depth_aug, to_rgb=True),
37 | # heatmap(diff, to_rgb=True)
38 | # ], ["depth", "depth_aug", "diff"], row=1, col=3)
39 |
40 | depth = (mmcv.imread("datasets/BOP_DATASETS/ycbv/test/000048/depth/000001.png", "unchanged") / 10000.0).astype(
41 | "float32"
42 | )
43 | # diameter, pix_sigma, space_sigma
44 | depth_aug = cv2.bilateralFilter(depth, 11, 0.1, 30)
45 | # depth_aug = denoise_bilateral(depth, sigma_color=0.05, sigma_spatial=15)
46 | diff = depth_aug - depth
47 | grid_show(
48 | [heatmap(depth, to_rgb=True), heatmap(depth_aug, to_rgb=True), heatmap(diff, to_rgb=True)],
49 | ["depth", "depth_aug", "diff"],
50 | row=1,
51 | col=3,
52 | )
53 |
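A small self-contained sketch of `add_noise_depth` on a synthetic depth map; the values are illustrative, real depth would come from the dataset paths used in the `__main__` block above:

```
import numpy as np
from core.utils.depth_aug import add_noise_depth

depth = np.full((480, 640), 1.0, dtype=np.float32)   # a flat plane at 1 m
depth[:100, :100] = 0.0                              # simulate missing measurements
noisy = add_noise_depth(depth, level=0.005, depth_valid_min=0.1)
# Gaussian noise is only added where depth > depth_valid_min; invalid pixels stay 0
assert np.allclose(noisy[:100, :100], 0.0)
```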
--------------------------------------------------------------------------------
/core/utils/farthest_points_torch.py:
--------------------------------------------------------------------------------
1 | # https://github.com/NVlabs/latentfusion/blob/master/latentfusion/three/utils.py
2 | import torch
3 | from torch.nn import functional as F
4 |
5 |
6 | def farthest_points(
7 | data,
8 | n_clusters: int,
9 | dist_func=F.pairwise_distance,
10 | return_center_indexes=True,
11 | return_distances=False,
12 | verbose=False,
13 | init_center=True,
14 | ):
15 | """Performs farthest point sampling on data points.
16 |
17 | Args:
18 | data (torch.tensor): data points.
19 | n_clusters (int): number of clusters.
20 | dist_func (Callable): distance function that is used to compare two data points.
21 | return_center_indexes (bool): if True, returns the indexes of the center of clusters.
22 | return_distances (bool): if True, return distances of each point from centers.
23 | Returns clusters, [centers, distances]:
24 | clusters (torch.tensor): the cluster index for each element in data.
25 | centers (torch.tensor): the integer index of each center.
26 | distances (torch.tensor): closest distances of each point to any of the cluster centers.
27 | """
28 | if n_clusters >= data.shape[0]:
29 | if return_center_indexes:
30 | return (
31 | torch.arange(data.shape[0], dtype=torch.long),
32 | torch.arange(data.shape[0], dtype=torch.long),
33 | )
34 |
35 | return torch.arange(data.shape[0], dtype=torch.long)
36 |
37 | clusters = torch.full((data.shape[0],), fill_value=-1, dtype=torch.long)
38 | centers = torch.zeros(n_clusters, dtype=torch.long)
39 |
40 | if init_center:
41 | broadcasted_data = torch.mean(data, 0, keepdim=True).expand(data.shape[0], -1)
42 | distances = dist_func(broadcasted_data, data)
43 | else:
44 | distances = torch.full((data.shape[0],), fill_value=1e7, dtype=torch.float32)
45 |
46 | for i in range(n_clusters):
47 | center_idx = torch.argmax(distances)
48 | centers[i] = center_idx
49 |
50 | broadcasted_data = data[center_idx].unsqueeze(0).expand(data.shape[0], -1)
51 | new_distances = dist_func(broadcasted_data, data)
52 | distances = torch.min(distances, new_distances)
53 | clusters[distances == new_distances] = i
54 | if verbose:
55 | print("farthest points max distance : {}".format(torch.max(distances)))
56 |
57 | if return_center_indexes:
58 | if return_distances:
59 | return clusters, centers, distances
60 | return clusters, centers
61 |
62 | return clusters
63 |
64 |
65 | def get_fps_and_center_torch(points, num_fps: int, init_center=True, dist_func=F.pairwise_distance):
66 | center = torch.mean(points, 0, keepdim=True)
67 | _, fps_inds = farthest_points(
68 | points,
69 | n_clusters=num_fps,
70 | dist_func=dist_func,
71 | return_center_indexes=True,
72 | init_center=init_center,
73 | )
74 | fps_pts = points[fps_inds]
75 | return torch.cat([fps_pts, center], dim=0)
76 |
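A minimal sketch of farthest point sampling on a random point cloud (sizes are arbitrary):

```
import torch
from core.utils.farthest_points_torch import farthest_points, get_fps_and_center_torch

points = torch.rand(1024, 3)                      # hypothetical point cloud
clusters, centers = farthest_points(points, n_clusters=8, return_center_indexes=True)
sampled = points[centers]                         # 8 well-spread points

# 8 FPS points with the cloud's mean point appended at the end -> (9, 3)
fps_and_center = get_fps_and_center_torch(points, num_fps=8)
```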
--------------------------------------------------------------------------------
/core/utils/my_setup.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os.path as osp
3 |
4 |
5 | def get_project_root():
6 | cur_dir = osp.dirname(osp.abspath(__file__))
7 | proj_root = osp.normpath(osp.join(cur_dir, "../../"))
8 | return proj_root
9 |
10 |
11 | PROJ_ROOT = get_project_root()
12 |
13 |
14 | def get_data_root():
15 | proj_root = get_project_root()
16 | return osp.join(proj_root, "datasets")
17 |
18 |
19 | DATA_ROOT = get_data_root()
20 |
21 |
22 | def setup_for_distributed(is_master):
23 |     """This function disables printing when not in the master process."""
24 | import builtins as __builtin__
25 |
26 | builtin_print = __builtin__.print
27 | if not is_master:
28 | logging.getLogger("core").setLevel("WARN")
29 | logging.getLogger("d2").setLevel("WARN")
30 | logging.getLogger("lib").setLevel("WARN")
31 | logging.getLogger("my").setLevel("WARN")
32 |
33 | def print(*args, **kwargs):
34 | force = kwargs.pop("force", False)
35 | if is_master or force:
36 | builtin_print(*args, **kwargs)
37 |
38 | __builtin__.print = print
39 |
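A sketch of how `setup_for_distributed` can be used; the `rank` value is a placeholder that would normally come from the distributed launcher:

```
from core.utils.my_setup import setup_for_distributed

rank = 0                                        # placeholder; e.g. torch.distributed.get_rank()
setup_for_distributed(is_master=(rank == 0))

print("visible only on the master process")
print("visible on every process", force=True)   # non-master processes can still force output
```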
--------------------------------------------------------------------------------
/core/utils/ssd_color_transform.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | # ref: https://github.com/facebookresearch/detectron2/blob/master/projects/PointRend/point_rend/color_augmentation.py
3 |
4 | import numpy as np
5 | import random
6 | import cv2
7 | from fvcore.transforms.transform import Transform
8 |
9 |
10 | class ColorAugSSDTransform(Transform):
11 | """
12 | A color related data augmentation used in Single Shot Multibox Detector (SSD).
13 | Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy,
14 | Scott Reed, Cheng-Yang Fu, Alexander C. Berg.
15 | SSD: Single Shot MultiBox Detector. ECCV 2016.
16 | Implementation based on:
17 | https://github.com/weiliu89/caffe/blob
18 | /4817bf8b4200b35ada8ed0dc378dceaf38c539e4
19 | /src/caffe/util/im_transforms.cpp
20 | https://github.com/chainer/chainercv/blob
21 | /7159616642e0be7c5b3ef380b848e16b7e99355b/chainercv
22 | /links/model/ssd/transforms.py
23 | """
24 |
25 | def __init__(
26 | self,
27 | img_format,
28 | brightness_delta=32,
29 | contrast_low=0.5,
30 | contrast_high=1.5,
31 | saturation_low=0.5,
32 | saturation_high=1.5,
33 | hue_delta=18,
34 | ):
35 | super().__init__()
36 | assert img_format in ["BGR", "RGB"]
37 | self.is_rgb = img_format == "RGB"
38 | del img_format
39 | self._set_attributes(locals())
40 |
41 | def apply_coords(self, coords):
42 | return coords
43 |
44 | def apply_segmentation(self, segmentation):
45 | return segmentation
46 |
47 | def apply_image(self, img, interp=None):
48 | if self.is_rgb:
49 | img = img[:, :, [2, 1, 0]]
50 | img = self.brightness(img)
51 | if random.randrange(2):
52 | img = self.contrast(img)
53 | img = self.saturation(img)
54 | img = self.hue(img)
55 | else:
56 | img = self.saturation(img)
57 | img = self.hue(img)
58 | img = self.contrast(img)
59 | if self.is_rgb:
60 | img = img[:, :, [2, 1, 0]]
61 | return img
62 |
63 | def convert(self, img, alpha=1, beta=0):
64 | img = img.astype(np.float32) * alpha + beta
65 | img = np.clip(img, 0, 255)
66 | return img.astype(np.uint8)
67 |
68 | def brightness(self, img):
69 | if random.randrange(2):
70 | return self.convert(
71 | img,
72 | beta=random.uniform(-self.brightness_delta, self.brightness_delta),
73 | )
74 | return img
75 |
76 | def contrast(self, img):
77 | if random.randrange(2):
78 | return self.convert(
79 | img,
80 | alpha=random.uniform(self.contrast_low, self.contrast_high),
81 | )
82 | return img
83 |
84 | def saturation(self, img):
85 | if random.randrange(2):
86 | img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
87 | img[:, :, 1] = self.convert(
88 | img[:, :, 1],
89 | alpha=random.uniform(self.saturation_low, self.saturation_high),
90 | )
91 | return cv2.cvtColor(img, cv2.COLOR_HSV2BGR)
92 | return img
93 |
94 | def hue(self, img):
95 | if random.randrange(2):
96 | img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
97 | img[:, :, 0] = (img[:, :, 0].astype(int) + random.randint(-self.hue_delta, self.hue_delta)) % 180
98 | return cv2.cvtColor(img, cv2.COLOR_HSV2BGR)
99 | return img
100 |
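A brief usage sketch on a synthetic BGR image; only the photometric content is changed, while coordinates and segmentations pass through untouched:

```
import numpy as np
from core.utils.ssd_color_transform import ColorAugSSDTransform

img = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)   # fake BGR image
aug = ColorAugSSDTransform(img_format="BGR")
img_aug = aug.apply_image(img)                                   # photometric jitter only
coords = aug.apply_coords(np.zeros((5, 2), dtype=np.float32))    # returned unchanged
```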
--------------------------------------------------------------------------------
/core/utils/timm_utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import timm
4 | import pathlib
5 |
6 | _logger = logging.getLogger(__name__)
7 |
8 |
9 | def my_create_timm_model(**init_args):
10 |     # HACK: fix the bug for features_only=True and checkpoint_path != ""
11 | # https://github.com/rwightman/pytorch-image-models/issues/488
12 | if init_args.get("checkpoint_path", "") != "" and init_args.get("features_only", True):
13 | init_args = copy.deepcopy(init_args)
14 | full_model_name = init_args["model_name"]
15 | modules = timm.models.list_modules()
16 | # find the mod which has the longest common name in model_name
17 | mod_len = 0
18 | for m in modules:
19 | if m in full_model_name:
20 | cur_mod_len = len(m)
21 | if cur_mod_len > mod_len:
22 | mod = m
23 | mod_len = cur_mod_len
24 | if mod_len >= 1:
25 | if hasattr(timm.models.__dict__[mod], "default_cfgs"):
26 | ckpt_path = init_args.pop("checkpoint_path")
27 | ckpt_url = pathlib.Path(ckpt_path).resolve().as_uri()
28 | _logger.warning(f"hacking model pretrained url to {ckpt_url}")
29 | timm.models.__dict__[mod].default_cfgs[full_model_name]["url"] = ckpt_url
30 | init_args["pretrained"] = True
31 | else:
32 | raise ValueError(f"model_name {full_model_name} has no module in timm")
33 |
34 | backbone = timm.create_model(**init_args)
35 | return backbone
36 |
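A hedged usage sketch; `model_name` and `checkpoint_path` below are placeholders, and passing `features_only=True` together with a non-empty checkpoint path triggers the URL-override workaround above:

```
from core.utils.timm_utils import my_create_timm_model

backbone = my_create_timm_model(
    model_name="resnet34",                   # placeholder timm model name
    pretrained=False,
    features_only=True,
    out_indices=(1, 2, 3, 4),
    checkpoint_path="path/to/resnet34.pth",  # placeholder local checkpoint
)
```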
--------------------------------------------------------------------------------
/core/utils/zoom_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from detectron2.layers.roi_align import ROIAlign
3 | from torchvision.ops import RoIPool
4 |
5 |
6 | def deepim_boxes(
7 | ren_boxes,
8 | ren_centers_2d,
9 | obs_boxes=None,
10 | lamb=1.4,
11 | imHW=(480, 640),
12 | outHW=(480, 640),
13 | clamp=False,
14 | ):
15 | """
16 | Args:
17 | ren_boxes: N x 4
18 | ren_centers_2d: Nx2, rendered object center is the crop center
19 | obs_boxes: N x 4, if None, only use the rendered boxes/centers to determine the crop region
20 | lamb: enlarge the scale of cropped region
21 |         imHW (tuple): (imH, imW) of the input image, used when clamp=True
22 |         outHW (tuple): (outH, outW) of the output crop
23 | Returns:
24 | crop_boxes (Tensor): N x 4, either the common region from obs/ren or just obs
25 | resize_ratios (Tensor): Nx2, resize ratio of (w,h), actually the same in w,h because we keep the aspect ratio
26 | """
27 | ren_x1, ren_y1, ren_x2, ren_y2 = (ren_boxes[:, i] for i in range(4)) # (N,)
28 | ren_cx = ren_centers_2d[:, 0] # (N,)
29 | ren_cy = ren_centers_2d[:, 1] # (N,)
30 |
31 | outH, outW = outHW
32 | aspect_ratio = outW / outH # 4/3 or 1
33 |
34 | if obs_boxes is not None:
35 | obs_x1, obs_y1, obs_x2, obs_y2 = (obs_boxes[:, i] for i in range(4)) # (N,)
36 | xdists = torch.stack(
37 | [
38 | ren_cx - obs_x1,
39 | ren_cx - ren_x1,
40 | obs_x2 - ren_cx,
41 | ren_x2 - ren_cx,
42 | ],
43 | dim=1,
44 | ).abs()
45 | ydists = torch.stack(
46 | [
47 | ren_cy - obs_y1,
48 | ren_cy - ren_y1,
49 | obs_y2 - ren_cy,
50 | ren_y2 - ren_cy,
51 | ],
52 | dim=1,
53 | ).abs()
54 | else:
55 | xdists = torch.stack([ren_cx - ren_x1, ren_x2 - ren_cx], dim=1).abs()
56 | ydists = torch.stack([ren_cy - ren_y1, ren_y2 - ren_cy], dim=1).abs()
57 | xdist = xdists.max(dim=1)[0] # (N,)
58 | ydist = ydists.max(dim=1)[0]
59 |
60 | crop_h = torch.max(xdist / aspect_ratio, ydist).clamp(min=1) * 2 * lamb # (N,)
61 | crop_w = crop_h * aspect_ratio # (N,)
62 |
63 | x1, y1, x2, y2 = (
64 | ren_cx - crop_w / 2,
65 | ren_cy - crop_h / 2,
66 | ren_cx + crop_w / 2,
67 | ren_cy + crop_h / 2,
68 | )
69 | boxes = torch.stack([x1, y1, x2, y2], dim=1)
70 | assert not clamp
71 | if clamp:
72 | imH, imW = imHW
73 | boxes[:, [0, 2]] = torch.clamp(boxes[:, [0, 2]], 0, imW - 1)
74 | boxes[:, [1, 3]] = torch.clamp(boxes[:, [1, 3]], 0, imH - 1)
75 |
76 | resize_ratios = torch.stack([outW / crop_w, outH / crop_h], dim=1)
77 | return boxes, resize_ratios
78 |
79 |
80 | def batch_crop_resize(x, rois, out_H, out_W, aligned=True, interpolation="bilinear"):
81 | """
82 | Args:
83 | x: BCHW
84 | rois: Bx5, rois[:, 0] is the idx into x
85 | out_H (int):
86 | out_W (int):
87 | """
88 | output_size = (out_H, out_W)
89 | if interpolation == "bilinear":
90 | op = ROIAlign(output_size, 1.0, 0, aligned=aligned)
91 | elif interpolation == "nearest":
92 |         op = RoIPool(output_size, 1.0)
93 | else:
94 | raise ValueError(f"Wrong interpolation type: {interpolation}")
95 | return op(x, rois)
96 |
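A sketch of the zoom pipeline: derive DeepIM-style crop boxes from rendered boxes/centers, then crop-and-resize a batch of images with them (all tensors below are synthetic):

```
import torch
from core.utils.zoom_utils import deepim_boxes, batch_crop_resize

B = 2
imgs = torch.rand(B, 3, 480, 640)
ren_boxes = torch.tensor([[100., 120., 300., 360.],
                          [200., 150., 400., 330.]])      # x1,y1,x2,y2
ren_centers_2d = torch.tensor([[200., 240.], [300., 240.]])

boxes, resize_ratios = deepim_boxes(ren_boxes, ren_centers_2d, outHW=(256, 256))
rois = torch.cat([torch.arange(B, dtype=torch.float32).view(B, 1), boxes], dim=1)  # Bx5
crops = batch_crop_resize(imgs, rois, out_H=256, out_W=256)                        # Bx3x256x256
```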
--------------------------------------------------------------------------------
/datasets/NOCS/REAL/real_test:
--------------------------------------------------------------------------------
1 | /data2/lxy/object_pose_benchmark/datasets/NOCS/REAL/real_test/
--------------------------------------------------------------------------------
/datasets/NOCS/REAL/real_train:
--------------------------------------------------------------------------------
1 | /data2/lxy/object_pose_benchmark/datasets/NOCS/REAL/real_train/
--------------------------------------------------------------------------------
/datasets/NOCS/obj_models/abs_scale.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/datasets/NOCS/obj_models/abs_scale.pkl
--------------------------------------------------------------------------------
/datasets/NOCS/obj_models/cr_normed_mean_model_points_spd.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/datasets/NOCS/obj_models/cr_normed_mean_model_points_spd.pkl
--------------------------------------------------------------------------------
/datasets/NOCS/obj_models/mug_handle.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/datasets/NOCS/obj_models/mug_handle.pkl
--------------------------------------------------------------------------------
/datasets/NOCS/obj_models/mug_meta.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/datasets/NOCS/obj_models/mug_meta.pkl
--------------------------------------------------------------------------------
/datasets/NOCS/obj_models/real_test_spd.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/datasets/NOCS/obj_models/real_test_spd.pkl
--------------------------------------------------------------------------------
/datasets/NOCS/obj_models/real_train_spd.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/datasets/NOCS/obj_models/real_train_spd.pkl
--------------------------------------------------------------------------------
/docs/INSTALL.md:
--------------------------------------------------------------------------------
1 | # Installation
2 |
3 | * CUDA >= 10.1, Ubuntu >= 16.04
4 |
5 | * Python >= 3.6, PyTorch >= 1.9, torchvision
6 | ```
7 | ## create a new environment
8 | conda create -n py37 python=3.7.4
9 | conda activate py37 # maybe add this line to the end of ~/.bashrc
10 | conda install ipython
11 | ## install pytorch: https://pytorch.org/get-started/locally/ (check cuda version)
12 | pip install torchvision -U # will also install corresponding torch
13 | ```
14 |
15 | * `detectron2` from [source](https://github.com/facebookresearch/detectron2).
16 | ```
17 | git clone https://github.com/facebookresearch/detectron2.git
18 | cd detectron2
19 | pip install ninja
20 | pip install -e .
21 | ```
22 |
23 | * `sh scripts/install_deps.sh`
24 |
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/lib/__init__.py
--------------------------------------------------------------------------------
/lib/pysixd/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/lib/pysixd/__init__.py
--------------------------------------------------------------------------------
/lib/pysixd/colors.json:
--------------------------------------------------------------------------------
1 | [
2 | [0.89, 0.28, 0.13],
3 | [0.45, 0.38, 0.92],
4 | [0.35, 0.73, 0.63],
5 | [0.62, 0.28, 0.91],
6 | [0.65, 0.71, 0.22],
7 | [0.8, 0.29, 0.89],
8 | [0.27, 0.55, 0.22],
9 | [0.37, 0.46, 0.84],
10 | [0.84, 0.63, 0.22],
11 | [0.68, 0.29, 0.71],
12 | [0.48, 0.75, 0.48],
13 | [0.88, 0.27, 0.75],
14 | [0.82, 0.45, 0.2],
15 | [0.86, 0.27, 0.27],
16 | [0.52, 0.49, 0.18],
17 | [0.33, 0.67, 0.25],
18 | [0.67, 0.42, 0.29],
19 | [0.67, 0.46, 0.86],
20 | [0.36, 0.72, 0.84],
21 | [0.85, 0.29, 0.4],
22 | [0.24, 0.53, 0.55],
23 | [0.85, 0.55, 0.8],
24 | [0.4, 0.51, 0.33],
25 | [0.56, 0.38, 0.63],
26 | [0.78, 0.66, 0.46],
27 | [0.33, 0.5, 0.72],
28 | [0.83, 0.31, 0.56],
29 | [0.56, 0.61, 0.85],
30 | [0.89, 0.58, 0.57],
31 | [0.67, 0.4, 0.49]
32 | ]
--------------------------------------------------------------------------------
/lib/pysixd/config.py:
--------------------------------------------------------------------------------
1 | # Author: Tomas Hodan (hodantom@cmp.felk.cvut.cz)
2 | # Center for Machine Perception, Czech Technical University in Prague
3 |
4 | """Configuration of the BOP Toolkit."""
5 |
6 | import os
7 |
8 |
9 | ######## Basic ########
10 |
11 | # Folder with the BOP datasets.
12 | if "BOP_PATH" in os.environ:
13 | datasets_path = os.environ["BOP_PATH"]
14 | else:
15 | datasets_path = r"datasets/BOP_DATASETS/"
16 |
17 | # Folder with pose results to be evaluated.
18 | # results_path = r'/path/to/folder/with/results'
19 | results_path = r"output/bop_results"
20 |
21 | # Folder for the calculated pose errors and performance scores.
22 | # eval_path = r'/path/to/eval/folder'
23 | eval_path = r"output/bop_eval/"
24 |
25 | ######## Extended ########
26 |
27 | # Folder for outputs (e.g. visualizations).
28 | # output_path = r'/path/to/output/folder'
29 | output_path = r"output/bop_output/"
30 |
31 | # For offscreen C++ rendering: Path to the build folder of bop_renderer (github.com/thodan/bop_renderer).
32 | bop_renderer_path = r"bop_renderer/build"
33 |
34 | # Executable of the MeshLab server.
35 | # meshlab_server_path = r'/path/to/meshlabserver.exe'
36 | meshlab_server_path = r"/usr/bin/meshlabserver"
37 |
--------------------------------------------------------------------------------
/lib/pysixd/renderer.py:
--------------------------------------------------------------------------------
1 | # Author: Tomas Hodan (hodantom@cmp.felk.cvut.cz)
2 | # Center for Machine Perception, Czech Technical University in Prague
3 |
4 | """Abstract class of a renderer and a factory function to create a renderer.
5 |
6 | The renderer produces an RGB/depth image of a 3D mesh model in a
7 | specified pose for given camera parameters and illumination settings.
8 | """
9 |
10 |
11 | class Renderer(object):
12 | """Abstract class of a renderer."""
13 |
14 | def __init__(self, width, height):
15 | """Constructor.
16 |
17 | :param width: Width of the rendered image.
18 | :param height: Height of the rendered image.
19 | """
20 | self.width = width
21 | self.height = height
22 |
23 | # 3D location of a point light (in the camera coordinates).
24 | self.light_cam_pos = (0, 0, 0)
25 |
26 | # Set light color and weights.
27 | self.light_color = (1.0, 1.0, 1.0) # Used only in C++ renderer.
28 | self.light_ambient_weight = 0.5
29 | self.light_diffuse_weight = 1.0 # Used only in C++ renderer.
30 | self.light_specular_weight = 0.0 # Used only in C++ renderer.
31 | self.light_specular_shininess = 0.0 # Used only in C++ renderer.
32 |
33 | def set_light_cam_pos(self, light_cam_pos):
34 | """Sets the 3D location of a point light.
35 |
36 | :param light_cam_pos: [X, Y, Z].
37 | """
38 | self.light_cam_pos = light_cam_pos
39 |
40 | def set_light_ambient_weight(self, light_ambient_weight):
41 | """Sets weight of the ambient light.
42 |
43 | :param light_ambient_weight: Scalar from 0 to 1.
44 | """
45 | self.light_ambient_weight = light_ambient_weight
46 |
47 | def add_object(self, obj_id, model_path, **kwargs):
48 | """Loads an object model.
49 |
50 | :param obj_id: Object identifier.
51 | :param model_path: Path to the object model file.
52 | """
53 | raise NotImplementedError
54 |
55 | def remove_object(self, obj_id):
56 | """Removes an object model.
57 |
58 | :param obj_id: Identifier of the object to remove.
59 | """
60 | raise NotImplementedError
61 |
62 | def render_object(self, obj_id, R, t, fx, fy, cx, cy):
63 | """Renders an object model in the specified pose.
64 |
65 | :param obj_id: Object identifier.
66 | :param R: 3x3 ndarray with a rotation matrix.
67 | :param t: 3x1 ndarray with a translation vector.
68 | :param fx: Focal length (X axis).
69 | :param fy: Focal length (Y axis).
70 | :param cx: The X coordinate of the principal point.
71 | :param cy: The Y coordinate of the principal point.
72 | :return: Returns a dictionary with rendered images.
73 | """
74 | raise NotImplementedError
75 |
76 |
77 | def create_renderer(
78 | width,
79 | height,
80 | renderer_type="cpp",
81 | mode="rgb+depth",
82 | shading="phong",
83 | bg_color=(0.0, 0.0, 0.0, 0.0),
84 | ):
85 | """A factory to create a renderer.
86 |
87 | Note: Parameters mode, shading and bg_color are currently supported only by
88 | the Python renderer (renderer_type='python').
89 |
90 | :param width: Width of the rendered image.
91 | :param height: Height of the rendered image.
92 |     :param renderer_type: Type of renderer (options: 'cpp', 'python', 'pyrender', 'vispy').
93 | :param mode: Rendering mode ('rgb+depth', 'rgb', 'depth').
94 | :param shading: Type of shading ('flat', 'phong').
95 | :param bg_color: Color of the background (R, G, B, A).
96 | :return: Instance of a renderer of the specified type.
97 | """
98 | if renderer_type == "python":
99 | from . import renderer_py
100 |
101 | return renderer_py.RendererPython(width, height, mode, shading, bg_color)
102 |
103 | elif renderer_type == "cpp":
104 | from . import renderer_cpp
105 |
106 | return renderer_cpp.RendererCpp(width, height)
107 |
108 | elif renderer_type == "pyrender":
109 | from . import renderer_pyrender
110 |
111 | return renderer_pyrender.Renderer(width, height, mode=mode, bg_color=bg_color)
112 |
113 | elif renderer_type == "vispy":
114 | from . import renderer_vispy
115 |
116 | return renderer_vispy.RendererVispy(width, height, mode, shading=shading, bg_color=bg_color)
117 |
118 | else:
119 | raise ValueError("Unknown renderer type.")
120 |
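A hedged usage sketch of the factory; the model path, intrinsics, and translation units are placeholders, and the 'vispy' backend is chosen here only because it does not require the compiled C++ renderer:

```
import numpy as np
from lib.pysixd import renderer

ren = renderer.create_renderer(640, 480, renderer_type="vispy", mode="rgb+depth")
ren.add_object(obj_id=1, model_path="path/to/obj_000001.ply")   # placeholder model path

R = np.eye(3, dtype=np.float32)
t = np.array([[0.0], [0.0], [500.0]], dtype=np.float32)         # translation along z (units follow the model)
out = ren.render_object(1, R, t, fx=572.4, fy=573.6, cx=325.3, cy=242.0)
rgb, depth = out["rgb"], out["depth"]
```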
--------------------------------------------------------------------------------
/lib/pysixd/renderer_cpp.py:
--------------------------------------------------------------------------------
1 | # Author: Tomas Hodan (hodantom@cmp.felk.cvut.cz)
2 | # Center for Machine Perception, Czech Technical University in Prague
3 |
4 | """An interface to the C++ based renderer (bop_renderer)."""
5 |
6 | import sys
7 | import numpy as np
8 |
9 | from lib.pysixd import config
10 | from lib.pysixd import renderer
11 |
12 | # C++ renderer (https://github.com/thodan/bop_renderer)
13 | sys.path.append(config.bop_renderer_path)
14 | import bop_renderer
15 |
16 |
17 | class RendererCpp(renderer.Renderer):
18 | """An interface to the C++ based renderer."""
19 |
20 | def __init__(self, width, height):
21 | """See base class."""
22 | super(RendererCpp, self).__init__(width, height)
23 | self.renderer = bop_renderer.Renderer()
24 | self.renderer.init(width, height)
25 | self._set_light()
26 |
27 | def _set_light(self):
28 | self.renderer.set_light(
29 | list(self.light_cam_pos),
30 | list(self.light_color),
31 | self.light_ambient_weight,
32 | self.light_diffuse_weight,
33 | self.light_specular_weight,
34 | self.light_specular_shininess,
35 | )
36 |
37 | def set_light_cam_pos(self, light_cam_pos):
38 | """See base class."""
39 | super(RendererCpp, self).set_light_cam_pos(light_cam_pos)
40 | self._set_light()
41 |
42 | def set_light_ambient_weight(self, light_ambient_weight):
43 | """See base class."""
44 | super(RendererCpp, self).set_light_ambient_weight(light_ambient_weight)
45 | self._set_light()
46 |
47 | def add_object(self, obj_id, model_path, **kwargs):
48 | """See base class.
49 |
50 | NEEDS TO BE CALLED RIGHT AFTER CREATING THE RENDERER (this is
51 | due to some memory issues in the C++ renderer which need to be
52 | fixed).
53 | """
54 | self.renderer.add_object(obj_id, model_path)
55 |
56 | def remove_object(self, obj_id):
57 | """See base class."""
58 | self.renderer.remove_object(obj_id)
59 |
60 | def render_object(self, obj_id, R, t, fx, fy, cx, cy):
61 | """See base class."""
62 | R_l = R.astype(np.float32).flatten().tolist()
63 | t_l = t.astype(np.float32).flatten().tolist()
64 | self.renderer.render_object(obj_id, R_l, t_l, fx, fy, cx, cy)
65 | rgb = self.renderer.get_color_image(obj_id)
66 | depth = self.renderer.get_depth_image(obj_id).astype(np.float32)
67 | return {"rgb": rgb, "depth": depth}
68 |
--------------------------------------------------------------------------------
/lib/pysixd/se3.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import numpy.matlib as npm
3 | from transforms3d.quaternions import mat2quat, quat2mat, quat2axangle
4 | import scipy.stats as sci_stats
5 |
6 |
7 | # RT is a 3x4 matrix
8 | def se3_inverse(RT):
9 | R = RT[0:3, 0:3]
10 | T = RT[0:3, 3].reshape((3, 1))
11 | RT_new = np.zeros((3, 4), dtype=np.float32)
12 | RT_new[0:3, 0:3] = R.transpose()
13 | RT_new[0:3, 3] = -1 * np.dot(R.transpose(), T).reshape(3)
14 | return RT_new
15 |
16 |
17 | def se3_mul(RT1, RT2):
18 | R1 = RT1[0:3, 0:3]
19 | T1 = RT1[0:3, 3].reshape((3, 1))
20 |
21 | R2 = RT2[0:3, 0:3]
22 | T2 = RT2[0:3, 3].reshape((3, 1))
23 |
24 | RT_new = np.zeros((3, 4), dtype=np.float32)
25 | RT_new[0:3, 0:3] = np.dot(R1, R2)
26 | T_new = np.dot(R1, T2) + T1
27 | RT_new[0:3, 3] = T_new.reshape(3)
28 | return RT_new
29 |
30 |
31 | def T_inv_transform(T_src, T_tgt):
32 | """
33 | :param T_src:
34 | :param T_tgt:
35 | :return: T_delta: delta in pixel
36 | """
37 | T_delta = np.zeros((3,), dtype=np.float32)
38 |
39 | T_delta[0] = T_tgt[0] / T_tgt[2] - T_src[0] / T_src[2]
40 | T_delta[1] = T_tgt[1] / T_tgt[2] - T_src[1] / T_src[2]
41 | T_delta[2] = np.log(T_src[2] / T_tgt[2])
42 |
43 | return T_delta
44 |
45 |
46 | def rotation_x(theta):
47 | t = theta * np.pi / 180.0
48 | R = np.zeros((3, 3), dtype=np.float32)
49 | R[0, 0] = 1
50 | R[1, 1] = np.cos(t)
51 | R[1, 2] = -(np.sin(t))
52 | R[2, 1] = np.sin(t)
53 | R[2, 2] = np.cos(t)
54 | return R
55 |
56 |
57 | def rotation_y(theta):
58 | t = theta * np.pi / 180.0
59 | R = np.zeros((3, 3), dtype=np.float32)
60 | R[0, 0] = np.cos(t)
61 | R[0, 2] = np.sin(t)
62 | R[1, 1] = 1
63 | R[2, 0] = -(np.sin(t))
64 | R[2, 2] = np.cos(t)
65 | return R
66 |
67 |
68 | def rotation_z(theta):
69 | t = theta * np.pi / 180.0
70 | R = np.zeros((3, 3), dtype=np.float32)
71 | R[0, 0] = np.cos(t)
72 | R[0, 1] = -(np.sin(t))
73 | R[1, 0] = np.sin(t)
74 | R[1, 1] = np.cos(t)
75 | R[2, 2] = 1
76 | return R
77 |
78 |
79 | def angular_distance(quat):
80 | vec, theta = quat2axangle(quat)
81 | return theta / np.pi * 180
82 |
83 |
84 | # Q is a Nx4 numpy matrix and contains the quaternions to average in the rows.
85 | # The quaternions are arranged as (w,x,y,z), with w being the scalar
86 | # The result will be the average quaternion of the input. Note that the signs
87 | # of the output quaternion can be reversed, since q and -q describe the same orientation
88 | def averageQuaternions(Q):
89 | # Number of quaternions to average
90 | M = Q.shape[0]
91 | A = npm.zeros(shape=(4, 4))
92 |
93 | for i in range(0, M):
94 | q = Q[i, :]
95 | # multiply q with its transposed version q' and add A
96 | A = np.outer(q, q) + A
97 |
98 | # scale
99 | A = (1.0 / M) * A
100 | # compute eigenvalues and -vectors
101 | eigenValues, eigenVectors = np.linalg.eig(A)
102 | # Sort by largest eigenvalue
103 | eigenVectors = eigenVectors[:, eigenValues.argsort()[::-1]]
104 | # return the real part of the largest eigenvector (has only real part)
105 | return np.real(eigenVectors[:, 0].A1)
106 |
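A small numeric sketch of the SE(3) helpers and the quaternion averaging above (toy values):

```
import numpy as np
from lib.pysixd.se3 import se3_mul, se3_inverse, rotation_z, averageQuaternions

RT = np.zeros((3, 4), dtype=np.float32)
RT[:3, :3] = rotation_z(30)                      # 30 degrees about z
RT[:3, 3] = [0.1, -0.2, 1.0]

# composing a pose with its inverse recovers (approximately) the identity pose
RT_id = se3_mul(RT, se3_inverse(RT))
assert np.allclose(RT_id[:3, :3], np.eye(3), atol=1e-5)

# averaging a quaternion with itself returns the same orientation (up to sign)
q = np.array([0.9238795, 0.0, 0.0, 0.3826834])   # (w,x,y,z), 45 degrees about z
q_avg = averageQuaternions(np.stack([q, q]))
```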
--------------------------------------------------------------------------------
/lib/pysixd/visibility.py:
--------------------------------------------------------------------------------
1 | # Author: Tomas Hodan (hodantom@cmp.felk.cvut.cz)
2 | # Center for Machine Perception, Czech Technical University in Prague
3 |
4 | """Estimation of the visible object surface from depth images."""
5 |
6 | import numpy as np
7 |
8 |
9 | def _estimate_visib_mask(d_test, d_model, delta, visib_mode="bop19"):
10 | """Estimates a mask of the visible object surface.
11 |
12 | :param d_test: Distance image of a scene in which the visibility is estimated.
13 | :param d_model: Rendered distance image of the object model.
14 | :param delta: Tolerance used in the visibility test.
15 | :param visib_mode: Visibility mode:
16 | 1) 'bop18' - Object is considered NOT VISIBLE at pixels with missing depth.
17 | 2) 'bop19' - Object is considered VISIBLE at pixels with missing depth. This
18 |         allows the VSD pose error function to be used also on shiny objects, which
19 | are typically not captured well by the depth sensors. A possible problem
20 | with this mode is that some invisible parts can be considered visible.
21 | However, the shadows of missing depth measurements, where this problem is
22 | expected to appear and which are often present at depth discontinuities,
23 | are typically relatively narrow and therefore this problem is less
24 | significant.
25 | :return: Visibility mask.
26 | """
27 | assert d_test.shape == d_model.shape
28 |
29 | if visib_mode == "bop18":
30 | mask_valid = np.logical_and(d_test > 0, d_model > 0)
31 | d_diff = d_model.astype(np.float32) - d_test.astype(np.float32)
32 | visib_mask = np.logical_and(d_diff <= delta, mask_valid)
33 |
34 | elif visib_mode == "bop19":
35 | d_diff = d_model.astype(np.float32) - d_test.astype(np.float32)
36 | visib_mask = np.logical_and(np.logical_or(d_diff <= delta, d_test == 0), d_model > 0)
37 |
38 | else:
39 | raise ValueError("Unknown visibility mode.")
40 |
41 | return visib_mask
42 |
43 |
44 | def estimate_visib_mask_gt(d_test, d_gt, delta, visib_mode="bop19"):
45 | """Estimates a mask of the visible object surface in the ground-truth pose.
46 |
47 | :param d_test: Distance image of a scene in which the visibility is estimated.
48 | :param d_gt: Rendered distance image of the object model in the GT pose.
49 | :param delta: Tolerance used in the visibility test.
50 | :param visib_mode: See _estimate_visib_mask.
51 | :return: Visibility mask.
52 | """
53 | visib_gt = _estimate_visib_mask(d_test, d_gt, delta, visib_mode)
54 | return visib_gt
55 |
56 |
57 | def estimate_visib_mask_est(d_test, d_est, visib_gt, delta, visib_mode="bop19"):
58 | """Estimates a mask of the visible object surface in the estimated pose.
59 |
60 | For an explanation of why the visibility mask is calculated differently for
61 | the estimated and the ground-truth pose, see equation (14) and related text in
62 | Hodan et al., On Evaluation of 6D Object Pose Estimation, ECCVW'16.
63 |
64 | :param d_test: Distance image of a scene in which the visibility is estimated.
65 | :param d_est: Rendered distance image of the object model in the est. pose.
66 | :param visib_gt: Visibility mask of the object model in the GT pose (from
67 | function estimate_visib_mask_gt).
68 | :param delta: Tolerance used in the visibility test.
69 | :param visib_mode: See _estimate_visib_mask.
70 | :return: Visibility mask.
71 | """
72 | visib_est = _estimate_visib_mask(d_test, d_est, delta, visib_mode)
73 | visib_est = np.logical_or(visib_est, np.logical_and(visib_gt, d_est > 0))
74 | return visib_est
75 |
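A toy sketch of the visibility masks on tiny synthetic distance images (arbitrary depth units):

```
import numpy as np
from lib.pysixd.visibility import estimate_visib_mask_gt, estimate_visib_mask_est

d_test = np.array([[1.0, 1.0, 0.0],
                   [2.0, 2.0, 2.0]], dtype=np.float32)   # scene depth; one missing pixel
d_gt = np.array([[1.0, 1.5, 1.0],
                 [0.0, 2.0, 2.5]], dtype=np.float32)     # object rendered in the GT pose
d_est = d_gt.copy()                                      # pretend the estimated pose renders identically

visib_gt = estimate_visib_mask_gt(d_test, d_gt, delta=0.15, visib_mode="bop19")
visib_est = estimate_visib_mask_est(d_test, d_est, visib_gt, delta=0.15, visib_mode="bop19")
```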
--------------------------------------------------------------------------------
/lib/structures/__init__.py:
--------------------------------------------------------------------------------
1 | from .centers_2d import Center2Ds
2 | from .keypoints_2d import Keypoints2Ds
3 | from .keypoints_3d import Keypoints3Ds
4 | from .my_maps import MyMaps
5 | from .my_masks import MyBitMasks
6 | from .quats import Quats
7 | from .translations import Translations
8 | from .poses import Poses
9 | from .rots import Rots
10 | from .my_list import MyList
11 |
--------------------------------------------------------------------------------
/lib/structures/centers_2d.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Iterator, List, Tuple, Union
2 |
3 | import torch
4 | from torch import device
5 | from detectron2.utils.env import TORCH_VERSION
6 |
7 | if TORCH_VERSION < (1, 8):
8 | _maybe_jit_unused = torch.jit.unused
9 | else:
10 |
11 | def _maybe_jit_unused(x):
12 | return x
13 |
14 |
15 | class Center2Ds:
16 |     """This structure stores a list of 2d centers (object/bbox centers) as a Nx2
17 | torch.Tensor.
18 |
19 | Attributes:
20 | tensor: float matrix of Nx2.
21 | """
22 |
23 | def __init__(self, tensor: torch.Tensor):
24 | """
25 | Args:
26 | tensor (Tensor[float]):
27 | * a Nx2 matrix. Each row is (x, y).
28 | """
29 | device = tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu")
30 | tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device)
31 | if tensor.numel() == 0:
32 | # Use reshape, so we don't end up creating a new tensor that does not depend on
33 | # the inputs (and consequently confuses jit)
34 |             tensor = tensor.reshape((-1, 2)).to(dtype=torch.float32, device=device)
35 | assert tensor.ndim == 2 and (tensor.shape[-1] == 2), tensor.shape
36 |
37 | self.tensor = tensor
38 |
39 | def clone(self) -> "Center2Ds":
40 | """Clone the Center2Ds.
41 |
42 | Returns:
43 | Center2Ds
44 | """
45 | return Center2Ds(self.tensor.clone())
46 |
47 | @_maybe_jit_unused
48 | def to(self, device: torch.device = None, **kwargs) -> "Center2Ds":
49 | return Center2Ds(self.tensor.to(device=device, **kwargs))
50 |
51 | def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Center2Ds":
52 | """
53 | Returns:
54 | Center2Ds: Create a new :class:`Center2Ds` by indexing.
55 |
56 | The following usage are allowed:
57 | 1. `new_center2ds = center2ds[3]`: return a `Center2Ds` which contains only one 2d center.
58 | 2. `new_center2ds = center2ds[2:10]`: return a slice of center2ds.
59 | 3. `new_center2ds = center2ds[vector]`, where vector is a torch.BoolTensor
60 | with `length = len(center2ds)`. Nonzero elements in the vector will be selected.
61 |
62 | Note that the returned Center2Ds might share storage with this Center2Ds,
63 | subject to Pytorch's indexing semantics.
64 | """
65 | if isinstance(item, int):
66 | return Center2Ds(self.tensor[item].view(1, -1))
67 | b = self.tensor[item]
68 | assert b.ndim == 2, "Indexing on Center2Ds with {} failed to return a matrix!".format(item)
69 | return Center2Ds(b)
70 |
71 | def __len__(self) -> int:
72 | return self.tensor.shape[0]
73 |
74 | def __repr__(self) -> str:
75 | return "Center2Ds(" + str(self.tensor) + ")"
76 |
77 | @classmethod
78 |     def cat(cls, center2ds_list: List["Center2Ds"]) -> "Center2Ds":
79 | """Concatenates a list of Center2Ds into a single Center2Ds.
80 |
81 | Arguments:
82 | center2ds_list (list[Center2Ds])
83 |
84 | Returns:
85 | Center2Ds: the concatenated Center2Ds
86 | """
87 | if torch.jit.is_scripting():
88 | # https://github.com/pytorch/pytorch/issues/18627
89 | # 1. staticmethod can be used in torchscript, But we can not use
90 | # `type(center2ds).staticmethod` because torchscript only supports function
91 | # `type` with input type `torch.Tensor`.
92 | # 2. classmethod is not fully supported by torchscript. We explicitly assign
93 | # cls to Center2Ds as a workaround to get torchscript support.
94 | cls = Center2Ds
95 | assert isinstance(center2ds_list, (list, tuple))
96 | if len(center2ds_list) == 0:
97 | return cls(torch.empty(0))
98 | assert all(isinstance(center2ds, Center2Ds) for center2ds in center2ds_list)
99 |
100 | # use torch.cat (v.s. layers.cat) so the returned tensor never share storage with input
101 |         cat_center2ds = cls(torch.cat([b.tensor for b in center2ds_list], dim=0))
102 | return cat_center2ds
103 |
104 | @property
105 | def device(self) -> device:
106 | return self.tensor.device
107 |
108 | # type "Iterator[torch.Tensor]", yield, and iter() not supported by torchscript
109 | # https://github.com/pytorch/pytorch/issues/18627
110 | @torch.jit.unused
111 | def __iter__(self) -> Iterator[torch.Tensor]:
112 | """Yield a 2d center as a Tensor of shape (2,) at a time."""
113 | yield from self.tensor
114 |
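A short sketch of the structure's behavior (length, indexing, concatenation) on a synthetic batch:

```
import torch
from lib.structures import Center2Ds

centers = Center2Ds(torch.tensor([[320.0, 240.0],
                                  [100.0, 50.0],
                                  [400.0, 300.0]]))
assert len(centers) == 3
one = centers[1]                                   # still a Center2Ds, tensor shape (1, 2)
keep = centers[torch.tensor([True, False, True])]  # boolean indexing
merged = Center2Ds.cat([centers, keep])            # 5 centers
```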
--------------------------------------------------------------------------------
/lib/structures/keypoints_2d.py:
--------------------------------------------------------------------------------
1 | # for example: store projected bbox3d+center3d => (N,9,2)
2 | from typing import Any, Iterator, List, Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 | from torch import device
7 | from detectron2.utils.env import TORCH_VERSION
8 |
9 | if TORCH_VERSION < (1, 8):
10 | _maybe_jit_unused = torch.jit.unused
11 | else:
12 |
13 | def _maybe_jit_unused(x):
14 | return x
15 |
16 |
17 | class Keypoints2Ds:
18 | """Modified from class Keypoints.
19 |
20 | Stores 2d keypoint annotation data. GT Instances have a
21 | `gt_2d_keypoints` property containing the x,y location of each
22 | keypoint. This tensor has shape (N, K, 2) where N is the number of
23 | instances and K is the number of keypoints per instance.
24 | """
25 |
26 | def __init__(self, keypoints: Union[torch.Tensor, np.ndarray, List[List[float]]]):
27 | """
28 | Arguments:
29 | keypoints: A Tensor, numpy array, or list of the x, y of each keypoint.
30 | The shape should be (N, K, 2) where N is the number of
31 | instances, and K is the number of keypoints per instance.
32 | """
33 | device = keypoints.device if isinstance(keypoints, torch.Tensor) else torch.device("cpu")
34 | keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=device)
35 | assert keypoints.ndim == 3 and keypoints.shape[2] == 2, keypoints.shape
36 | self.tensor = keypoints
37 |
38 | def __len__(self) -> int:
39 | return self.tensor.shape[0]
40 |
41 | def clone(self) -> "Keypoints2Ds":
42 | """Clone the Keypoints2Ds.
43 |
44 | Returns:
45 | Keypoints2Ds
46 | """
47 | return Keypoints2Ds(self.tensor.clone())
48 |
49 | def to(self, *args: Any, **kwargs: Any) -> "Keypoints2Ds":
50 | return type(self)(self.tensor.to(*args, **kwargs))
51 |
52 | def to_heatmap(self, boxes: torch.Tensor, heatmap_size: int) -> torch.Tensor:
53 |         # TODO: convert 2d keypoints to heatmap as proposed in Integral Regression
54 | # copy from d2 if needed
55 | raise NotImplementedError
56 |
57 | def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Keypoints2Ds":
58 | """Create a new `Keypoints2Ds` by indexing on this `Keypoints2Ds`.
59 |
60 | The following usage are allowed:
61 |
62 | 1. `new_kpts = kpts[3]`: return a `Keypoints2Ds` which contains only one instance.
63 | 2. `new_kpts = kpts[2:10]`: return a slice of key points.
64 | 3. `new_kpts = kpts[vector]`, where vector is a torch.ByteTensor
65 | with `length = len(kpts)`. Nonzero elements in the vector will be selected.
66 |
67 | Note that the returned Keypoints2Ds might share storage with this Keypoints2Ds,
68 | subject to Pytorch's indexing semantics.
69 | """
70 | if isinstance(item, int):
71 | return Keypoints2Ds([self.tensor[item]])
72 | return Keypoints2Ds(self.tensor[item])
73 |
74 | def __repr__(self) -> str:
75 | s = self.__class__.__name__ + "("
76 | s += "num_instances={})".format(len(self.tensor))
77 | return s
78 |
79 | @classmethod
80 | def cat(cls, keypoints2ds_list: List["Keypoints2Ds"]) -> "Keypoints2Ds":
81 | """Concatenates a list of Keypoints2Ds into a single Keypoints2Ds.
82 |
83 | Arguments:
84 | keypoints2ds_list (list[Keypoints2Ds])
85 |
86 | Returns:
87 | Keypoints2Ds: the concatenated Keypoints2Ds
88 | """
89 | if torch.jit.is_scripting():
90 | # https://github.com/pytorch/pytorch/issues/18627
91 | # 1. staticmethod can be used in torchscript, But we can not use
92 | # `type(xxx).staticmethod` because torchscript only supports function
93 | # `type` with input type `torch.Tensor`.
94 | # 2. classmethod is not fully supported by torchscript. We explicitly assign
95 | # cls to ThisClassName as a workaround to get torchscript support.
96 | cls = Keypoints2Ds
97 | assert isinstance(keypoints2ds_list, (list, tuple))
98 | if len(keypoints2ds_list) == 0:
99 | return cls(torch.empty(0))
100 | assert all(isinstance(keypoints2ds, Keypoints2Ds) for keypoints2ds in keypoints2ds_list)
101 |
102 | # use torch.cat (v.s. layers.cat) so the returned tensor never share storage with input
103 | cat_keypoints2ds = type(keypoints2ds_list[0])(torch.cat([b.tensor for b in keypoints2ds_list], dim=0))
104 | return cat_keypoints2ds
105 |
106 | @property
107 | def device(self) -> device:
108 | return self.tensor.device
109 |
110 | # type "Iterator[torch.Tensor]", yield, and iter() not supported by torchscript
111 | # https://github.com/pytorch/pytorch/issues/18627
112 | @torch.jit.unused
113 | def __iter__(self) -> Iterator[torch.Tensor]:
114 |         """Yield the 2d keypoints of one instance as a Tensor of shape (K, 2) at a time."""
115 | yield from self.tensor
116 |
--------------------------------------------------------------------------------
/lib/structures/keypoints_3d.py:
--------------------------------------------------------------------------------
1 | # for example: store bbox3d+center3d => (N,9,3)
2 | from typing import Any, Iterator, List, Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 | from torch import device
7 | from detectron2.utils.env import TORCH_VERSION
8 |
9 | if TORCH_VERSION < (1, 8):
10 | _maybe_jit_unused = torch.jit.unused
11 | else:
12 |
13 | def _maybe_jit_unused(x):
14 | return x
15 |
16 |
17 | class Keypoints3Ds:
18 | """Modified from class Keypoints.
19 |
20 | Stores 3d keypoint annotation data. GT Instances have a
21 | `gt_3d_keypoints` property containing the x,y,z location of each
22 | keypoint. This tensor has shape (N, K, 3) where N is the number of
23 | instances and K is the number of keypoints per instance.
24 | """
25 |
26 | def __init__(self, keypoints: Union[torch.Tensor, np.ndarray, List[List[float]]]):
27 | """
28 | Arguments:
29 | keypoints: A Tensor, numpy array, or list of the x, y, z of each keypoint.
30 | The shape should be (N, K, 3) where N is the number of
31 | instances, and K is the number of keypoints per instance.
32 | """
33 | device = keypoints.device if isinstance(keypoints, torch.Tensor) else torch.device("cpu")
34 | keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=device)
35 | assert keypoints.ndim == 3 and keypoints.shape[2] == 3, keypoints.shape
36 | self.tensor = keypoints
37 |
38 | def __len__(self) -> int:
39 | return self.tensor.shape[0]
40 |
41 | def clone(self) -> "Keypoints3Ds":
42 | """Clone the Keypoints3Ds.
43 |
44 | Returns:
45 | Keypoints3Ds
46 | """
47 | return Keypoints3Ds(self.tensor.clone())
48 |
49 | def to(self, *args: Any, **kwargs: Any) -> "Keypoints3Ds":
50 | return type(self)(self.tensor.to(*args, **kwargs))
51 |
52 | def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Keypoints3Ds":
53 | """Create a new `Keypoints3Ds` by indexing on this `Keypoints3Ds`.
54 |
55 | The following usage are allowed:
56 |
57 | 1. `new_kpts = kpts[3]`: return a `Keypoints3Ds` which contains only one instance.
58 | 2. `new_kpts = kpts[2:10]`: return a slice of key points.
59 | 3. `new_kpts = kpts[vector]`, where vector is a torch.ByteTensor
60 | with `length = len(kpts)`. Nonzero elements in the vector will be selected.
61 |
62 | Note that the returned Keypoints3Ds might share storage with this Keypoints3Ds,
63 | subject to Pytorch's indexing semantics.
64 | """
65 | if isinstance(item, int):
66 | return Keypoints3Ds([self.tensor[item]])
67 | return Keypoints3Ds(self.tensor[item])
68 |
69 | def __repr__(self) -> str:
70 | s = self.__class__.__name__ + "("
71 | s += "num_instances={})".format(len(self.tensor))
72 | return s
73 |
74 | @classmethod
75 | def cat(cls, keypoints3ds_list: List["Keypoints3Ds"]) -> "Keypoints3Ds":
76 | """Concatenates a list of Keypoints3Ds into a single Keypoints3Ds.
77 |
78 | Arguments:
79 | keypoints3ds_list (list[Keypoints3Ds])
80 |
81 | Returns:
82 | Keypoints3Ds: the concatenated Keypoints3Ds
83 | """
84 | if torch.jit.is_scripting():
85 | # https://github.com/pytorch/pytorch/issues/18627
86 | # 1. staticmethod can be used in torchscript, But we can not use
87 | # `type(keypoits3ds).staticmethod` because torchscript only supports function
88 | # `type` with input type `torch.Tensor`.
89 | # 2. classmethod is not fully supported by torchscript. We explicitly assign
90 | # cls to Keypoints3Ds as a workaround to get torchscript support.
91 | cls = Keypoints3Ds
92 | assert isinstance(keypoints3ds_list, (list, tuple))
93 | if len(keypoints3ds_list) == 0:
94 | return cls(torch.empty(0))
95 | assert all(isinstance(keypoints3ds, Keypoints3Ds) for keypoints3ds in keypoints3ds_list)
96 |
97 | # use torch.cat (v.s. layers.cat) so the returned tensor never share storage with input
98 | cat_keypoints3ds = type(keypoints3ds_list[0])(torch.cat([b.tensor for b in keypoints3ds_list], dim=0))
99 | return cat_keypoints3ds
100 |
101 | @property
102 | def device(self) -> device:
103 | return self.tensor.device
104 |
105 | # type "Iterator[torch.Tensor]", yield, and iter() not supported by torchscript
106 | # https://github.com/pytorch/pytorch/issues/18627
107 | @torch.jit.unused
108 | def __iter__(self) -> Iterator[torch.Tensor]:
109 |         """Yield the 3d keypoints of one instance as a Tensor of shape (K, 3) at a time."""
110 | yield from self.tensor
111 |
--------------------------------------------------------------------------------
/lib/structures/my_list.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class MyList(list):
6 | def __getitem__(self, index):
7 | """support indexing using torch.Tensor."""
8 | if isinstance(index, torch.Tensor):
9 | if isinstance(index, torch.BoolTensor):
10 | return [self[i] for i, idx in enumerate(index) if idx]
11 | else:
12 | return [self[int(i)] for i in index]
13 | elif isinstance(index, (list, tuple)):
14 | if len(index) > 0 and isinstance(index[0], bool):
15 | return [self[i] for i, idx in enumerate(index) if idx]
16 | else:
17 | return [self[int(i)] for i in index]
18 | elif isinstance(index, np.ndarray):
19 |             if index.dtype == bool:
20 | return [self[i] for i, idx in enumerate(index) if idx]
21 | else:
22 | return [self[int(i)] for i in index]
23 |
24 | return list.__getitem__(self, index)
25 |
26 |
27 | if __name__ == "__main__":
28 | a = [None, "a", 1, 2.3]
29 | a = MyList(a)
30 | print(a)
31 | print(type(a), isinstance(a, list))
32 | print("\ntorch bool index")
33 | index = torch.tensor([True, False, True, False])
34 | print(index)
35 | print(a[index])
36 |
37 | print("torch int index")
38 | index = torch.tensor([0, 2, 3])
39 | print(index)
40 | print(a[index])
41 |
42 | print("\nnumpy bool index")
43 | index = np.array([True, False, True, False])
44 | print(index)
45 | print(a[index])
46 |
47 | print("numpy int index")
48 | index = np.array([0, 2, 3])
49 | print(index)
50 | print(a[index])
51 |
52 | print("\nlist bool index")
53 | index = [True, False, True, False]
54 | print(index)
55 | print(a[index])
56 |
57 | print("list int index")
58 | index = [0, 2, 3]
59 | print(index)
60 | print(a[index])
61 |
62 | print("\ntuple bool index")
63 | index = (True, False, True, False)
64 | print(index)
65 | print(a[index])
66 |
67 | print("tuple int index")
68 | index = (0, 2, 3)
69 | print(index)
70 | print(a[index])
71 |
72 | # print(a[1:-1])
73 |
--------------------------------------------------------------------------------
/lib/structures/my_maps.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import Any, Iterator, List, Union
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from detectron2.layers.roi_align import ROIAlign
8 | from torchvision.ops import RoIPool
9 |
10 |
11 | class MyMaps(object):
12 | """# NOTE: This class stores the maps (NOCS, coordinates map, pvnet vector
13 | maps, offset maps, heatmaps) for all objects in one image, support cpu_only
14 | option.
15 |
16 | Attributes:
17 |         tensor: float Tensor of N,C,H,W, representing N instances in the image.
18 | """
19 |
20 | def __init__(self, tensor: Union[torch.Tensor, np.ndarray], cpu_only: bool = True):
21 | """
22 | Args:
23 | tensor: float Tensor of N,C,H,W, representing N instances in the image.
24 | cpu_only: keep the maps on cpu even when to(device) is called
25 | """
26 | device = tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu")
27 | tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device)
28 | assert tensor.dim() == 4, tensor.size()
29 | self.image_size = tensor.shape[-2:]
30 | self.tensor = tensor
31 | self.cpu_only = cpu_only
32 |
33 | def to(self, device: str, **kwargs) -> "MyMaps":
34 | if not self.cpu_only:
35 | return MyMaps(self.tensor.to(device, **kwargs), cpu_only=False)
36 | else:
37 | return MyMaps(self.tensor.to("cpu", **kwargs), cpu_only=True)
38 |
39 | def to_device(self, device: str = "cuda", **kwargs) -> "MyMaps":
40 | # force to device
41 | return MyMaps(self.tensor.to(device, **kwargs), cpu_only=False)
42 |
43 | def crop_and_resize(
44 | self,
45 | boxes: torch.Tensor,
46 | map_size: int,
47 | interpolation: str = "bilinear",
48 | ) -> torch.Tensor:
49 | """# NOTE: if self.cpu_only, convert boxes to cpu
50 | Crop each map by the given box, and resize results to (map_size, map_size).
51 | This can be used to prepare training targets.
52 | Args:
53 | boxes (Tensor): Nx4 tensor storing the boxes for each map
54 | map_size (int): the size of the rasterized map.
55 | interpolation (str): bilinear | nearest
56 |
57 | Returns:
58 | Tensor:
59 |             A float tensor of shape (N, C, map_size, map_size), where
60 | N is the number of predicted boxes for this image.
61 | """
62 | assert len(boxes) == len(self), "{} != {}".format(len(boxes), len(self))
63 | if self.cpu_only:
64 | device = "cpu"
65 | else:
66 | device = self.tensor.device
67 |
68 | batch_inds = torch.arange(len(boxes), device=device).to(dtype=boxes.dtype)[:, None]
69 | rois = torch.cat([batch_inds, boxes.to(device)], dim=1) # Nx5
70 |
71 | maps = self.tensor.to(dtype=torch.float32)
72 | rois = rois.to(device=device)
73 | # on cpu, speed compared to cv2?
74 | if interpolation == "nearest":
75 | op = RoIPool((map_size, map_size), 1.0)
76 | elif interpolation == "bilinear":
77 | op = ROIAlign((map_size, map_size), 1.0, 0, aligned=True)
78 | else:
79 | raise ValueError(f"Unknown interpolation type: {interpolation}")
80 | output = op.forward(maps, rois)
81 | return output
82 |
83 | def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "MyMaps":
84 | """
85 | Returns:
86 | MyMaps: Create a new :class:`MyMaps` by indexing.
87 |
88 | The following usage are allowed:
89 |
90 | 1. `new_maps = maps[3]`: return a `MyMaps` which contains only one map.
91 | 2. `new_maps = maps[2:10]`: return a slice of maps.
92 | 3. `new_maps = maps[vector]`, where vector is a torch.BoolTensor
93 | with `length = len(maps)`. Nonzero elements in the vector will be selected.
94 |
95 | Note that the returned object might share storage with this object,
96 | subject to Pytorch's indexing semantics.
97 | """
98 | if isinstance(item, int):
99 |             return MyMaps(self.tensor[item].unsqueeze(0))
100 | m = self.tensor[item]
101 | assert m.dim() == 4, "Indexing on MyMaps with {} returns a tensor with shape {}!".format(item, m.shape)
102 | return MyMaps(m)
103 |
104 | def __iter__(self) -> torch.Tensor:
105 | yield from self.tensor
106 |
107 | def __repr__(self) -> str:
108 | s = self.__class__.__name__ + "("
109 | s += "num_instances={})".format(len(self.tensor))
110 | return s
111 |
112 | def __len__(self) -> int:
113 | return self.tensor.shape[0]
114 |
115 | def nonempty(self) -> torch.Tensor:
116 | """Find maps that are non-empty.
117 |
118 | Returns:
119 | Tensor: a BoolTensor which represents
120 | whether each map is empty (False) or non-empty (True).
121 | """
122 | return self.tensor.flatten(1).any(dim=1)
123 |
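A sketch of cropping per-instance maps with boxes (synthetic maps and boxes):

```
import torch
from lib.structures import MyMaps

maps = MyMaps(torch.rand(2, 3, 480, 640), cpu_only=True)    # 2 instances, 3-channel maps
boxes = torch.tensor([[100.0, 120.0, 300.0, 360.0],
                      [200.0, 150.0, 400.0, 330.0]])
crops = maps.crop_and_resize(boxes, map_size=64)             # (2, 3, 64, 64)
```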
--------------------------------------------------------------------------------
/lib/structures/poses.py:
--------------------------------------------------------------------------------
1 | # poses in [rot|trans] format (N, 3, 4).
2 | from typing import Any, Iterator, List, Tuple, Union
3 | import torch
4 | from torch import device
5 | from detectron2.utils.env import TORCH_VERSION
6 |
7 | if TORCH_VERSION < (1, 8):
8 | _maybe_jit_unused = torch.jit.unused
9 | else:
10 |
11 | def _maybe_jit_unused(x):
12 | return x
13 |
14 |
15 | class Poses:
16 |     """This structure stores a list of 6d poses as a Nx3x4 ([rot|trans])
17 | torch.Tensor. It supports some common methods about poses, and also behaves
18 | like a Tensor (support indexing, `to(device)`, `.device`, and iteration
19 | over all 6d poses)
20 |
21 | Attributes:
22 | tensor (torch.Tensor): float matrix of Nx3x4.
23 | """
24 |
25 | def __init__(self, tensor: torch.Tensor):
26 | """
27 | Args:
28 | tensor (Tensor[float]):
29 | * a Nx3x4 matrix.
30 | """
31 | device = tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu")
32 | tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device)
33 | if tensor.numel() == 0:
34 | # Use reshape, so we don't end up creating a new tensor that does not depend on
35 | # the inputs (and consequently confuses jit)
36 |             tensor = tensor.reshape((0, 3, 4)).to(dtype=torch.float32, device=device)
37 | assert tensor.ndim == 3 and (tensor.shape[1:] == (3, 4)), tensor.shape
38 |
39 | self.tensor = tensor
40 |
41 | def clone(self) -> "Poses":
42 | """Clone the Poses.
43 |
44 | Returns:
45 | Poses
46 | """
47 | return Poses(self.tensor.clone())
48 |
49 | @_maybe_jit_unused
50 | def to(self, device: torch.device = None, **kwargs) -> "Poses":
51 | return Poses(self.tensor.to(device=device, **kwargs))
52 |
53 | def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Poses":
54 | """
55 | Returns:
56 | Poses: Create a new :class:`Poses` by indexing.
57 |
58 |         The following usages are allowed:
59 | 1. `new_poses = poses[3]`: return a `Poses` which contains only one pose.
60 | 2. `new_poses = poses[2:10]`: return a slice of poses.
61 | 3. `new_poses = poses[vector]`, where vector is a torch.BoolTensor
62 | with `length = len(poses)`. Nonzero elements in the vector will be selected.
63 |
64 | Note that the returned Poses might share storage with this Poses,
65 | subject to Pytorch's indexing semantics.
66 | """
67 | if isinstance(item, int):
68 | return Poses(self.tensor[item].view(1, 3, 4))
69 | b = self.tensor[item]
70 | assert b.ndim == 3, "Indexing on Poses with {} failed!".format(item)
71 | return Poses(b)
72 |
73 | def __len__(self) -> int:
74 | return self.tensor.shape[0]
75 |
76 | def __repr__(self) -> str:
77 | return "Poses(" + str(self.tensor) + ")"
78 |
79 | def get_centers_2d(self, K: torch.Tensor) -> torch.Tensor:
80 | """
81 | Args:
82 | K: camera intrinsic matrices, 1x3x3 or Nx3x3
83 | Returns:
84 | The 2d projected object centers in a Nx2 array of (x, y).
85 | """
86 | assert K.ndim == 3, K.shape
87 | bs = self.tensor.shape[0]
88 | proj = (K @ self.tensor[:, :3, [3]]).view(bs, 3) # Nx3
89 | centers_2d = proj[:, :2] / proj[:, 2:3] # Nx2
90 | return centers_2d
91 |
92 | @classmethod
93 | def cat(cls, poses_list: List["Poses"]) -> "Poses":
94 | """Concatenates a list of Poses into a single Poses.
95 |
96 | Arguments:
97 | poses_list (list[Poses])
98 |
99 | Returns:
100 | Poses: the concatenated Poses
101 | """
102 | if torch.jit.is_scripting():
103 | # https://github.com/pytorch/pytorch/issues/18627
104 |         # 1. staticmethod can be used in torchscript, but we cannot use
105 | # `type(poses).staticmethod` because torchscript only supports function
106 | # `type` with input type `torch.Tensor`.
107 | # 2. classmethod is not fully supported by torchscript. We explicitly assign
108 | # cls to Poses as a workaround to get torchscript support.
109 | cls = Poses
110 | assert isinstance(poses_list, (list, tuple))
111 | if len(poses_list) == 0:
112 | return cls(torch.empty(0))
113 | assert all(isinstance(pose, Poses) for pose in poses_list)
114 |
115 |         # use torch.cat (vs. layers.cat) so the returned poses never share storage with the input
116 |         cat_poses = cls(torch.cat([p.tensor for p in poses_list], dim=0))
117 | return cat_poses
118 |
119 | @property
120 | def device(self) -> device:
121 | return self.tensor.device
122 |
123 | # type "Iterator[torch.Tensor]", yield, and iter() not supported by torchscript
124 | # https://github.com/pytorch/pytorch/issues/18627
125 | @torch.jit.unused
126 | def __iter__(self) -> Iterator[torch.Tensor]:
127 | """Yield a 6d pose as a Tensor of shape (3,4) at a time."""
128 | yield from self.tensor
129 |
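A small usage sketch for `Poses`; the intrinsic matrix `K` here is an arbitrary illustrative pinhole camera, not a dataset value, and the import path is assumed from this repo's layout.

```python
import torch
from lib.structures.poses import Poses  # import path assumed from this repo's layout

# Two [R|t] poses: identity rotations with different translations.
rt = torch.zeros(2, 3, 4)
rt[:, :3, :3] = torch.eye(3)
rt[0, :, 3] = torch.tensor([0.0, 0.0, 1.0])
rt[1, :, 3] = torch.tensor([0.1, -0.05, 0.8])
poses = Poses(rt)

# A single pinhole intrinsic matrix (1x3x3), broadcast over the batch.
K = torch.tensor([[[572.4, 0.0, 325.3],
                   [0.0, 573.6, 242.0],
                   [0.0, 0.0, 1.0]]])

centers_2d = poses.get_centers_2d(K)       # Nx2 projected object centers
single = poses[0]                          # a Poses holding one pose
merged = Poses.cat([poses, poses.clone()])
print(centers_2d.shape, len(single), len(merged))  # torch.Size([2, 2]) 1 4
```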
--------------------------------------------------------------------------------
/lib/structures/quats.py:
--------------------------------------------------------------------------------
1 | # quaternions
2 | from typing import Any, Iterator, List, Tuple, Union
3 |
4 | import torch
5 | from torch import device
6 | from detectron2.utils.env import TORCH_VERSION
7 |
8 | if TORCH_VERSION < (1, 8):
9 | _maybe_jit_unused = torch.jit.unused
10 | else:
11 |
12 | def _maybe_jit_unused(x):
13 | return x
14 |
15 |
16 | class Quats:
17 | """This structure stores a list of quaternions as a Nx4 torch.Tensor. It
18 | supports some common methods about quats, and also behaves like a Tensor
19 |     (supports indexing, `to(device)`, `.device`, and iteration over all quats)
20 |
21 | Attributes:
22 | tensor: float matrix of Nx4.
23 | """
24 |
25 | def __init__(self, tensor: torch.Tensor):
26 | """
27 | Args:
28 | tensor (Tensor[float]):
29 | * a Nx4 matrix. Each row is (qw, qx, qy, qz).
30 | """
31 | device = tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu")
32 | tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device)
33 | if tensor.numel() == 0:
34 | # Use reshape, so we don't end up creating a new tensor that does not depend on
35 | # the inputs (and consequently confuses jit)
36 |             tensor = tensor.reshape((0, 4)).to(dtype=torch.float32, device=device)
37 | assert tensor.ndim == 2 and (tensor.shape[-1] == 4), tensor.shape
38 |
39 | self.tensor = tensor
40 |
41 | def clone(self) -> "Quats":
42 | """Clone the Quats.
43 |
44 | Returns:
45 | Quats
46 | """
47 | return Quats(self.tensor.clone())
48 |
49 | @_maybe_jit_unused
50 | def to(self, device: torch.device = None, **kwargs) -> "Quats":
51 | return Quats(self.tensor.to(device=device, **kwargs))
52 |
53 | def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Quats":
54 | """
55 | Returns:
56 | Quats: Create a new :class:`Quats` by indexing.
57 |
58 |         The following usages are allowed:
59 | 1. `new_quats = quats[3]`: return a `Quats` which contains only one quat.
60 | 2. `new_quats = quats[2:10]`: return a slice of quats.
61 | 3. `new_quats = quats[vector]`, where vector is a torch.BoolTensor
62 | with `length = len(quats)`. Nonzero elements in the vector will be selected.
63 |
64 | Note that the returned Quats might share storage with this Quats,
65 | subject to Pytorch's indexing semantics.
66 | """
67 | if isinstance(item, int):
68 | return Quats(self.tensor[item].view(1, -1))
69 | b = self.tensor[item]
70 | assert b.ndim == 2, "Indexing on Quats with {} failed to return a matrix!".format(item)
71 | return Quats(b)
72 |
73 | def __len__(self) -> int:
74 | return self.tensor.shape[0]
75 |
76 | def __repr__(self) -> str:
77 | return "Quats(" + str(self.tensor) + ")"
78 |
79 | @classmethod
80 | def cat(cls, quats_list: List["Quats"]) -> "Quats":
81 | """Concatenates a list of Quats into a single Quats.
82 |
83 | Arguments:
84 | quats_list (list[Quats])
85 |
86 | Returns:
87 | Quats: the concatenated Quats
88 | """
89 | if torch.jit.is_scripting():
90 | # https://github.com/pytorch/pytorch/issues/18627
91 |         # 1. staticmethod can be used in torchscript, but we cannot use
92 | # `type(quats).staticmethod` because torchscript only supports function
93 | # `type` with input type `torch.Tensor`.
94 | # 2. classmethod is not fully supported by torchscript. We explicitly assign
95 | # cls to Quats as a workaround to get torchscript support.
96 | cls = Quats
97 | assert isinstance(quats_list, (list, tuple))
98 | if len(quats_list) == 0:
99 | return cls(torch.empty(0))
100 | assert all(isinstance(quats, Quats) for quats in quats_list)
101 |
102 |         # use torch.cat (vs. layers.cat) so the returned quats never share storage with the input
103 |         cat_quats = cls(torch.cat([q.tensor for q in quats_list], dim=0))
104 | return cat_quats
105 |
106 | @property
107 | def device(self) -> device:
108 | return self.tensor.device
109 |
110 | # type "Iterator[torch.Tensor]", yield, and iter() not supported by torchscript
111 | # https://github.com/pytorch/pytorch/issues/18627
112 | @torch.jit.unused
113 | def __iter__(self) -> Iterator[torch.Tensor]:
114 | """Yield a quat as a Tensor of shape (4,) at a time."""
115 | yield from self.tensor
116 |
--------------------------------------------------------------------------------
/lib/structures/rots.py:
--------------------------------------------------------------------------------
1 | # rotation matrices, format (N, 3, 3).
2 | from typing import Any, Iterator, List, Tuple, Union
3 | import torch
4 | from torch import device
5 | from detectron2.utils.env import TORCH_VERSION
6 |
7 | if TORCH_VERSION < (1, 8):
8 | _maybe_jit_unused = torch.jit.unused
9 | else:
10 |
11 | def _maybe_jit_unused(x):
12 | return x
13 |
14 |
15 | class Rots:
16 | """This structure stores a list of rotation matrices as a Nx3x3
17 | torch.Tensor. It supports some common methods about rots, and also behaves
18 |     like a Tensor (supports indexing, `to(device)`, `.device`, and iteration
19 | over all rots)
20 |
21 | Attributes:
22 | tensor (torch.Tensor): float matrix of Nx3x3.
23 | """
24 |
25 | def __init__(self, tensor: torch.Tensor):
26 | """
27 | Args:
28 | tensor (Tensor[float]):
29 | * a Nx3x3 matrix.
30 | """
31 | device = tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu")
32 | tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device)
33 | if tensor.numel() == 0:
34 | # Use reshape, so we don't end up creating a new tensor that does not depend on
35 | # the inputs (and consequently confuses jit)
36 |             tensor = tensor.reshape((0, 3, 3)).to(dtype=torch.float32, device=device)
37 | assert tensor.ndim == 3 and (tensor.shape[1:] == (3, 3)), tensor.shape
38 |
39 | self.tensor = tensor
40 |
41 | def clone(self) -> "Rots":
42 | """Clone the Rots.
43 |
44 | Returns:
45 | Rots
46 | """
47 | return Rots(self.tensor.clone())
48 |
49 | @_maybe_jit_unused
50 | def to(self, device: torch.device = None, **kwargs) -> "Rots":
51 | return Rots(self.tensor.to(device=device, **kwargs))
52 |
53 | def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Rots":
54 | """
55 | Returns:
56 | Rots: Create a new :class:`Rots` by indexing.
57 |
58 |         The following usages are allowed:
59 |         1. `new_rots = rots[3]`: return a `Rots` which contains only one rotation matrix.
60 | 2. `new_rots = rots[2:10]`: return a slice of rots.
61 | 3. `new_rots = rots[vector]`, where vector is a torch.BoolTensor
62 | with `length = len(rots)`. Nonzero elements in the vector will be selected.
63 |
64 | Note that the returned Rots might share storage with this Rots,
65 | subject to Pytorch's indexing semantics.
66 | """
67 | if isinstance(item, int):
68 | return Rots(self.tensor[item].view(1, 3, 3))
69 | b = self.tensor[item]
70 | assert b.ndim == 3, "Indexing on Rots with {} failed!".format(item)
71 | return Rots(b)
72 |
73 | def __len__(self) -> int:
74 | return self.tensor.shape[0]
75 |
76 | def __repr__(self) -> str:
77 | return "Rots(" + str(self.tensor) + ")"
78 |
79 | @classmethod
80 | def cat(cls, rots_list: List["Rots"]) -> "Rots":
81 | """Concatenates a list of Rots into a single Rots.
82 |
83 | Arguments:
84 | rots_list (list[Rots])
85 |
86 | Returns:
87 | Rots: the concatenated Rots
88 | """
89 | if torch.jit.is_scripting():
90 | # https://github.com/pytorch/pytorch/issues/18627
91 | # 1. staticmethod can be used in torchscript, But we can not use
92 | # `type(xxx).staticmethod` because torchscript only supports function
93 | # `type` with input type `torch.Tensor`.
94 | # 2. classmethod is not fully supported by torchscript. We explicitly assign
95 | # cls to ThisClassName as a workaround to get torchscript support.
96 | cls = Rots
97 | assert isinstance(rots_list, (list, tuple))
98 | if len(rots_list) == 0:
99 | return cls(torch.empty(0))
100 | assert all(isinstance(pose, Rots) for pose in rots_list)
101 |
102 |         # use torch.cat (vs. layers.cat) so the returned tensor never shares storage with the input
103 |         cat_rots = cls(torch.cat([p.tensor for p in rots_list], dim=0))
104 | return cat_rots
105 |
106 | @property
107 | def device(self) -> device:
108 | return self.tensor.device
109 |
110 | # type "Iterator[torch.Tensor]", yield, and iter() not supported by torchscript
111 | # https://github.com/pytorch/pytorch/issues/18627
112 | @torch.jit.unused
113 | def __iter__(self) -> Iterator[torch.Tensor]:
114 | """Yield a rot as a Tensor of shape (3,3) at a time."""
115 | yield from self.tensor
116 |
--------------------------------------------------------------------------------
/lib/structures/translations.py:
--------------------------------------------------------------------------------
1 | # translations
2 | from typing import Any, Iterator, List, Tuple, Union
3 |
4 | import torch
5 | from torch import device
6 | from detectron2.utils.env import TORCH_VERSION
7 |
8 | if TORCH_VERSION < (1, 8):
9 | _maybe_jit_unused = torch.jit.unused
10 | else:
11 |
12 | def _maybe_jit_unused(x):
13 | return x
14 |
15 |
16 | class Translations:
17 |     """This structure stores a list of translations as a Nx3 torch.Tensor.
18 |
19 | Attributes:
20 | tensor: float matrix of Nx3.
21 | """
22 |
23 | def __init__(self, tensor: torch.Tensor):
24 | """
25 | Args:
26 | tensor (Tensor[float]):
27 | * a Nx3 matrix. Each row is (tx, ty, tz).
28 | """
29 | device = tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu")
30 | tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device)
31 | if tensor.numel() == 0:
32 | # Use reshape, so we don't end up creating a new tensor that does not depend on
33 | # the inputs (and consequently confuses jit)
34 |             tensor = tensor.reshape((0, 3)).to(dtype=torch.float32, device=device)
35 | assert tensor.ndim == 2 and (tensor.shape[-1] == 3), tensor.shape
36 |
37 | self.tensor = tensor
38 |
39 | def clone(self) -> "Translations":
40 | """Clone the Translations.
41 |
42 | Returns:
43 | Translations
44 | """
45 | return Translations(self.tensor.clone())
46 |
47 | @_maybe_jit_unused
48 | def to(self, device: torch.device = None, **kwargs) -> "Translations":
49 | return Translations(self.tensor.to(device=device, **kwargs))
50 |
51 | def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Translations":
52 | """
53 | Returns:
54 | Translations: Create a new :class:`Translations` by indexing.
55 |
56 |         The following usages are allowed:
57 | 1. `new_transes = transes[3]`: return a `Translations` which contains only one translation.
58 | 2. `new_transes = transes[2:10]`: return a slice of transes.
59 | 3. `new_transes = transes[vector]`, where vector is a torch.BoolTensor
60 | with `length = len(transes)`. Nonzero elements in the vector will be selected.
61 |
62 | Note that the returned Translations might share storage with this Translations,
63 | subject to Pytorch's indexing semantics.
64 | """
65 | if isinstance(item, int):
66 | return Translations(self.tensor[item].view(1, -1))
67 | b = self.tensor[item]
68 | assert b.ndim == 2, "Indexing on Translations with {} failed to return a matrix!".format(item)
69 | return Translations(b)
70 |
71 | def __len__(self) -> int:
72 | return self.tensor.shape[0]
73 |
74 | def __repr__(self) -> str:
75 | return "Translations(" + str(self.tensor) + ")"
76 |
77 | def get_centers_2d(self, K: torch.Tensor) -> torch.Tensor:
78 | """
79 | Args:
80 | K: camera intrinsic matrices, Nx3x3 or 1x3x3
81 | Returns:
82 | The 2d projected object centers in a Nx2 array of (x, y).
83 | """
84 | bs = self.tensor.shape[0]
85 | proj = (K @ self.tensor.view(bs, 3, 1)).view(bs, 3)
86 | centers_2d = proj[:, :2] / proj[:, 2:3] # Nx2
87 | return centers_2d
88 |
89 | @classmethod
90 |     def cat(cls, transes_list: List["Translations"]) -> "Translations":
91 | """Concatenates a list of Translations into a single Translations.
92 |
93 | Arguments:
94 | transes_list (list[Translations])
95 |
96 | Returns:
97 | Translations: the concatenated Translations
98 | """
99 | if torch.jit.is_scripting():
100 | # https://github.com/pytorch/pytorch/issues/18627
101 |         # 1. staticmethod can be used in torchscript, but we cannot use
102 | # `type(transes).staticmethod` because torchscript only supports function
103 | # `type` with input type `torch.Tensor`.
104 | # 2. classmethod is not fully supported by torchscript. We explicitly assign
105 | # cls to Translations as a workaround to get torchscript support.
106 | cls = Translations
107 | assert isinstance(transes_list, (list, tuple))
108 | if len(transes_list) == 0:
109 | return cls(torch.empty(0))
110 | assert all(isinstance(transes, Translations) for transes in transes_list)
111 |
112 |         # use torch.cat (vs. layers.cat) so the returned transes never share storage with the input
113 |         cat_transes = cls(torch.cat([t.tensor for t in transes_list], dim=0))
114 | return cat_transes
115 |
116 | @property
117 | def device(self) -> device:
118 | return self.tensor.device
119 |
120 | # type "Iterator[torch.Tensor]", yield, and iter() not supported by torchscript
121 | # https://github.com/pytorch/pytorch/issues/18627
122 | @torch.jit.unused
123 | def __iter__(self) -> Iterator[torch.Tensor]:
124 | """Yield a translation as a Tensor of shape (3,) at a time."""
125 | yield from self.tensor
126 |
--------------------------------------------------------------------------------
/lib/torch_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/lib/torch_utils/__init__.py
--------------------------------------------------------------------------------
/lib/torch_utils/color/__init__.py:
--------------------------------------------------------------------------------
1 | from .gray import rgb_to_grayscale, RgbToGrayscale
2 | from .gray import bgr_to_grayscale, BgrToGrayscale
3 | from .rgb import BgrToRgb, bgr_to_rgb
4 | from .rgb import RgbToBgr, rgb_to_bgr
5 | from .rgb import RgbToRgba, rgb_to_rgba
6 | from .rgb import BgrToRgba, bgr_to_rgba
7 | from .rgb import RgbaToRgb, rgba_to_rgb
8 | from .rgb import RgbaToBgr, rgba_to_bgr
9 | from .hsv import RgbToHsv, rgb_to_hsv
10 | from .hsv import HsvToRgb, hsv_to_rgb
11 | from .hls import RgbToHls, rgb_to_hls
12 | from .hls import HlsToRgb, hls_to_rgb
13 | from .ycbcr import RgbToYcbcr, rgb_to_ycbcr
14 | from .ycbcr import YcbcrToRgb, ycbcr_to_rgb
15 | from .yuv import RgbToYuv, YuvToRgb, rgb_to_yuv, yuv_to_rgb
16 | from .xyz import RgbToXyz, XyzToRgb, rgb_to_xyz, xyz_to_rgb
17 | from .luv import RgbToLuv, LuvToRgb, rgb_to_luv, luv_to_rgb
18 | from .lab import RgbToLab, LabToRgb, rgb_to_lab, lab_to_rgb
19 |
20 |
21 | __all__ = [
22 | "rgb_to_grayscale",
23 | "bgr_to_grayscale",
24 | "bgr_to_rgb",
25 | "rgb_to_bgr",
26 | "rgb_to_rgba",
27 | "rgb_to_hsv",
28 | "hsv_to_rgb",
29 | "rgb_to_hls",
30 | "hls_to_rgb",
31 | "rgb_to_ycbcr",
32 | "ycbcr_to_rgb",
33 | "rgb_to_yuv",
34 | "yuv_to_rgb",
35 | "rgb_to_xyz",
36 | "xyz_to_rgb",
37 | "rgb_to_lab",
38 | "lab_to_rgb",
39 | "RgbToGrayscale",
40 | "BgrToGrayscale",
41 | "BgrToRgb",
42 | "RgbToBgr",
43 | "RgbToRgba",
44 | "RgbToHsv",
45 | "HsvToRgb",
46 | "RgbToHls",
47 | "HlsToRgb",
48 | "RgbToYcbcr",
49 | "YcbcrToRgb",
50 | "RgbToYuv",
51 | "YuvToRgb",
52 | "RgbToXyz",
53 | "XyzToRgb",
54 | "RgbToLuv",
55 | "LuvToRgb",
56 | "LabToRgb",
57 | "RgbToLab",
58 | ]
59 |
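A short sketch of how these conversions compose; both the functional and the `nn.Module` forms operate on `(*, 3, H, W)` tensors assumed to lie in the (0, 1) range.

```python
import torch
from lib.torch_utils.color import rgb_to_hsv, hsv_to_rgb, RgbToYcbcr

imgs = torch.rand(4, 3, 32, 32)  # batch of RGB images in (0, 1)

hsv = rgb_to_hsv(imgs)           # functional form
rgb_back = hsv_to_rgb(hsv)       # approximate inverse of the above
ycbcr = RgbToYcbcr()(imgs)       # nn.Module form, usable inside nn.Sequential

print(hsv.shape, ycbcr.shape)         # both torch.Size([4, 3, 32, 32])
print((imgs - rgb_back).abs().max())  # round-trip error, should be small
```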
--------------------------------------------------------------------------------
/lib/torch_utils/color/gray.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .rgb import bgr_to_rgb
5 |
6 |
7 | def rgb_to_grayscale(image: torch.Tensor) -> torch.Tensor:
8 |     r"""Convert an RGB image to its grayscale version.
9 |
10 | The image data is assumed to be in the range of (0, 1).
11 |
12 | Args:
13 | image (torch.Tensor): RGB image to be converted to grayscale with shape :math:`(*,3,H,W)`.
14 |
15 | Returns:
16 | torch.Tensor: grayscale version of the image with shape :math:`(*,1,H,W)`.
17 |
18 | Example:
19 | >>> input = torch.rand(2, 3, 4, 5)
20 | >>> gray = rgb_to_grayscale(input) # 2x1x4x5
21 | """
22 | if not isinstance(image, torch.Tensor):
23 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(image)))
24 |
25 | if len(image.shape) < 3 or image.shape[-3] != 3:
26 | raise ValueError("Input size must have a shape of (*, 3, H, W). Got {}".format(image.shape))
27 |
28 | r: torch.Tensor = image[..., 0:1, :, :]
29 | g: torch.Tensor = image[..., 1:2, :, :]
30 | b: torch.Tensor = image[..., 2:3, :, :]
31 |
32 | gray: torch.Tensor = 0.299 * r + 0.587 * g + 0.114 * b
33 | return gray
34 |
35 |
36 | def bgr_to_grayscale(image: torch.Tensor) -> torch.Tensor:
37 | r"""Convert a BGR image to grayscale.
38 |
39 | The image data is assumed to be in the range of (0, 1). First flips to RGB, then converts.
40 |
41 | Args:
42 | image (torch.Tensor): BGR image to be converted to grayscale with shape :math:`(*,3,H,W)`.
43 |
44 | Returns:
45 | torch.Tensor: grayscale version of the image with shape :math:`(*,1,H,W)`.
46 |
47 | Example:
48 | >>> input = torch.rand(2, 3, 4, 5)
49 | >>> gray = bgr_to_grayscale(input) # 2x1x4x5
50 | """
51 | if not isinstance(image, torch.Tensor):
52 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(image)))
53 |
54 | if len(image.shape) < 3 or image.shape[-3] != 3:
55 | raise ValueError("Input size must have a shape of (*, 3, H, W). Got {}".format(image.shape))
56 |
57 | image_rgb = bgr_to_rgb(image)
58 | gray: torch.Tensor = rgb_to_grayscale(image_rgb)
59 | return gray
60 |
61 |
62 | class RgbToGrayscale(nn.Module):
63 |     r"""Module to convert an RGB image to its grayscale version.
64 |
65 | The image data is assumed to be in the range of (0, 1).
66 |
67 | Shape:
68 | - image: :math:`(*, 3, H, W)`
69 | - output: :math:`(*, 1, H, W)`
70 |
71 | reference:
72 | https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html
73 |
74 | Example:
75 | >>> input = torch.rand(2, 3, 4, 5)
76 | >>> gray = RgbToGrayscale()
77 | >>> output = gray(input) # 2x1x4x5
78 | """
79 |
80 | def __init__(self) -> None:
81 | super(RgbToGrayscale, self).__init__()
82 |
83 | def forward(self, image: torch.Tensor) -> torch.Tensor: # type: ignore
84 | return rgb_to_grayscale(image)
85 |
86 |
87 | class BgrToGrayscale(nn.Module):
88 |     r"""Module to convert a BGR image to its grayscale version.
89 |
90 | The image data is assumed to be in the range of (0, 1). First flips to RGB, then converts.
91 |
92 | Shape:
93 | - image: :math:`(*, 3, H, W)`
94 | - output: :math:`(*, 1, H, W)`
95 |
96 | reference:
97 | https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html
98 |
99 | Example:
100 | >>> input = torch.rand(2, 3, 4, 5)
101 | >>> gray = BgrToGrayscale()
102 | >>> output = gray(input) # 2x1x4x5
103 | """
104 |
105 | def __init__(self) -> None:
106 | super(BgrToGrayscale, self).__init__()
107 |
108 | def forward(self, image: torch.Tensor) -> torch.Tensor: # type: ignore
109 | return bgr_to_grayscale(image)
110 |
--------------------------------------------------------------------------------
/lib/torch_utils/color/xyz.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def rgb_to_xyz(image: torch.Tensor) -> torch.Tensor:
6 | r"""Converts a RGB image to XYZ.
7 |
8 | Args:
9 | image (torch.Tensor): RGB Image to be converted to XYZ with shape :math:`(*, 3, H, W)`.
10 |
11 | Returns:
12 | torch.Tensor: XYZ version of the image with shape :math:`(*, 3, H, W)`.
13 |
14 | Example:
15 | >>> input = torch.rand(2, 3, 4, 5)
16 | >>> output = rgb_to_xyz(input) # 2x3x4x5
17 | """
18 | if not isinstance(image, torch.Tensor):
19 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(image)))
20 |
21 | if len(image.shape) < 3 or image.shape[-3] != 3:
22 | raise ValueError("Input size must have a shape of (*, 3, H, W). Got {}".format(image.shape))
23 |
24 | r: torch.Tensor = image[..., 0, :, :]
25 | g: torch.Tensor = image[..., 1, :, :]
26 | b: torch.Tensor = image[..., 2, :, :]
27 |
28 | x: torch.Tensor = 0.412453 * r + 0.357580 * g + 0.180423 * b
29 | y: torch.Tensor = 0.212671 * r + 0.715160 * g + 0.072169 * b
30 | z: torch.Tensor = 0.019334 * r + 0.119193 * g + 0.950227 * b
31 |
32 | out: torch.Tensor = torch.stack([x, y, z], -3)
33 |
34 | return out
35 |
36 |
37 | def xyz_to_rgb(image: torch.Tensor) -> torch.Tensor:
38 | r"""Converts a XYZ image to RGB.
39 |
40 | Args:
41 | image (torch.Tensor): XYZ Image to be converted to RGB with shape :math:`(*, 3, H, W)`.
42 |
43 | Returns:
44 | torch.Tensor: RGB version of the image with shape :math:`(*, 3, H, W)`.
45 |
46 | Example:
47 | >>> input = torch.rand(2, 3, 4, 5)
48 | >>> output = xyz_to_rgb(input) # 2x3x4x5
49 | """
50 | if not isinstance(image, torch.Tensor):
51 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(image)))
52 |
53 | if len(image.shape) < 3 or image.shape[-3] != 3:
54 | raise ValueError("Input size must have a shape of (*, 3, H, W). Got {}".format(image.shape))
55 |
56 | x: torch.Tensor = image[..., 0, :, :]
57 | y: torch.Tensor = image[..., 1, :, :]
58 | z: torch.Tensor = image[..., 2, :, :]
59 |
60 | r: torch.Tensor = 3.2404813432005266 * x + -1.5371515162713185 * y + -0.4985363261688878 * z
61 | g: torch.Tensor = -0.9692549499965682 * x + 1.8759900014898907 * y + 0.0415559265582928 * z
62 | b: torch.Tensor = 0.0556466391351772 * x + -0.2040413383665112 * y + 1.0573110696453443 * z
63 |
64 | out: torch.Tensor = torch.stack([r, g, b], dim=-3)
65 |
66 | return out
67 |
68 |
69 | class RgbToXyz(nn.Module):
70 | r"""Converts an image from RGB to XYZ.
71 |
72 | The image data is assumed to be in the range of (0, 1).
73 |
74 | Returns:
75 | torch.Tensor: XYZ version of the image.
76 |
77 | Shape:
78 | - image: :math:`(*, 3, H, W)`
79 | - output: :math:`(*, 3, H, W)`
80 |
81 | Examples:
82 | >>> input = torch.rand(2, 3, 4, 5)
83 | >>> xyz = RgbToXyz()
84 | >>> output = xyz(input) # 2x3x4x5
85 |
86 | Reference:
87 | [1] https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html
88 | """
89 |
90 | def __init__(self) -> None:
91 | super(RgbToXyz, self).__init__()
92 |
93 | def forward(self, image: torch.Tensor) -> torch.Tensor:
94 | return rgb_to_xyz(image)
95 |
96 |
97 | class XyzToRgb(nn.Module):
98 | r"""Converts an image from XYZ to RGB.
99 |
100 | Returns:
101 | torch.Tensor: RGB version of the image.
102 |
103 | Shape:
104 | - image: :math:`(*, 3, H, W)`
105 | - output: :math:`(*, 3, H, W)`
106 |
107 | Examples:
108 | >>> input = torch.rand(2, 3, 4, 5)
109 | >>> rgb = XyzToRgb()
110 | >>> output = rgb(input) # 2x3x4x5
111 |
112 | Reference:
113 | [1] https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html
114 | """
115 |
116 | def __init__(self) -> None:
117 | super(XyzToRgb, self).__init__()
118 |
119 | def forward(self, image: torch.Tensor) -> torch.Tensor:
120 | return xyz_to_rgb(image)
121 |
--------------------------------------------------------------------------------
/lib/torch_utils/color/ycbcr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def rgb_to_ycbcr(image: torch.Tensor) -> torch.Tensor:
6 | r"""Convert an RGB image to YCbCr.
7 |
8 | Args:
9 | image (torch.Tensor): RGB Image to be converted to YCbCr with shape :math:`(*, 3, H, W)`.
10 |
11 | Returns:
12 | torch.Tensor: YCbCr version of the image with shape :math:`(*, 3, H, W)`.
13 |
14 | Examples:
15 | >>> input = torch.rand(2, 3, 4, 5)
16 | >>> output = rgb_to_ycbcr(input) # 2x3x4x5
17 | """
18 | if not isinstance(image, torch.Tensor):
19 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(image)))
20 |
21 | if len(image.shape) < 3 or image.shape[-3] != 3:
22 | raise ValueError("Input size must have a shape of (*, 3, H, W). Got {}".format(image.shape))
23 |
24 | r: torch.Tensor = image[..., 0, :, :]
25 | g: torch.Tensor = image[..., 1, :, :]
26 | b: torch.Tensor = image[..., 2, :, :]
27 |
28 | delta: float = 0.5
29 | y: torch.Tensor = 0.299 * r + 0.587 * g + 0.114 * b
30 | cb: torch.Tensor = (b - y) * 0.564 + delta
31 | cr: torch.Tensor = (r - y) * 0.713 + delta
32 | return torch.stack([y, cb, cr], -3)
33 |
34 |
35 | def ycbcr_to_rgb(image: torch.Tensor) -> torch.Tensor:
36 | r"""Convert an YCbCr image to RGB.
37 |
38 | The image data is assumed to be in the range of (0, 1).
39 |
40 | Args:
41 | image (torch.Tensor): YCbCr Image to be converted to RGB with shape :math:`(*, 3, H, W)`.
42 |
43 | Returns:
44 | torch.Tensor: RGB version of the image with shape :math:`(*, 3, H, W)`.
45 |
46 | Examples:
47 | >>> input = torch.rand(2, 3, 4, 5)
48 | >>> output = ycbcr_to_rgb(input) # 2x3x4x5
49 | """
50 | if not isinstance(image, torch.Tensor):
51 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(image)))
52 |
53 | if len(image.shape) < 3 or image.shape[-3] != 3:
54 | raise ValueError("Input size must have a shape of (*, 3, H, W). Got {}".format(image.shape))
55 |
56 | y: torch.Tensor = image[..., 0, :, :]
57 | cb: torch.Tensor = image[..., 1, :, :]
58 | cr: torch.Tensor = image[..., 2, :, :]
59 |
60 | delta: float = 0.5
61 | cb_shifted: torch.Tensor = cb - delta
62 | cr_shifted: torch.Tensor = cr - delta
63 |
64 | r: torch.Tensor = y + 1.403 * cr_shifted
65 | g: torch.Tensor = y - 0.714 * cr_shifted - 0.344 * cb_shifted
66 | b: torch.Tensor = y + 1.773 * cb_shifted
67 | return torch.stack([r, g, b], -3)
68 |
69 |
70 | class RgbToYcbcr(nn.Module):
71 | r"""Convert an image from RGB to YCbCr.
72 |
73 | The image data is assumed to be in the range of (0, 1).
74 |
75 | Returns:
76 | torch.Tensor: YCbCr version of the image.
77 |
78 | Shape:
79 | - image: :math:`(*, 3, H, W)`
80 | - output: :math:`(*, 3, H, W)`
81 |
82 | Examples:
83 | >>> input = torch.rand(2, 3, 4, 5)
84 | >>> ycbcr = RgbToYcbcr()
85 | >>> output = ycbcr(input) # 2x3x4x5
86 | """
87 |
88 | def __init__(self) -> None:
89 | super(RgbToYcbcr, self).__init__()
90 |
91 | def forward(self, image: torch.Tensor) -> torch.Tensor:
92 | return rgb_to_ycbcr(image)
93 |
94 |
95 | class YcbcrToRgb(nn.Module):
96 |     r"""Convert an image from YCbCr to RGB.
97 |
98 | The image data is assumed to be in the range of (0, 1).
99 |
100 | Returns:
101 | torch.Tensor: RGB version of the image.
102 |
103 | Shape:
104 | - image: :math:`(*, 3, H, W)`
105 | - output: :math:`(*, 3, H, W)`
106 |
107 | Examples:
108 | >>> input = torch.rand(2, 3, 4, 5)
109 | >>> rgb = YcbcrToRgb()
110 | >>> output = rgb(input) # 2x3x4x5
111 | """
112 |
113 | def __init__(self) -> None:
114 | super(YcbcrToRgb, self).__init__()
115 |
116 | def forward(self, image: torch.Tensor) -> torch.Tensor:
117 | return ycbcr_to_rgb(image)
118 |
--------------------------------------------------------------------------------
/lib/torch_utils/color/yuv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def rgb_to_yuv(image: torch.Tensor) -> torch.Tensor:
6 | r"""Convert an RGB image to YUV.
7 |
8 | The image data is assumed to be in the range of (0, 1).
9 |
10 | Args:
11 | image (torch.Tensor): RGB Image to be converted to YUV with shape :math:`(*, 3, H, W)`.
12 |
13 | Returns:
14 | torch.Tensor: YUV version of the image with shape :math:`(*, 3, H, W)`.
15 |
16 | Example:
17 | >>> input = torch.rand(2, 3, 4, 5)
18 | >>> output = rgb_to_yuv(input) # 2x3x4x5
19 | """
20 | if not isinstance(image, torch.Tensor):
21 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(image)))
22 |
23 | if len(image.shape) < 3 or image.shape[-3] != 3:
24 | raise ValueError("Input size must have a shape of (*, 3, H, W). Got {}".format(image.shape))
25 |
26 | r: torch.Tensor = image[..., 0, :, :]
27 | g: torch.Tensor = image[..., 1, :, :]
28 | b: torch.Tensor = image[..., 2, :, :]
29 |
30 | y: torch.Tensor = 0.299 * r + 0.587 * g + 0.114 * b
31 | u: torch.Tensor = -0.147 * r - 0.289 * g + 0.436 * b
32 | v: torch.Tensor = 0.615 * r - 0.515 * g - 0.100 * b
33 |
34 | out: torch.Tensor = torch.stack([y, u, v], -3)
35 |
36 | return out
37 |
38 |
39 | def yuv_to_rgb(image: torch.Tensor) -> torch.Tensor:
40 | r"""Convert an YUV image to RGB.
41 |
42 | The image data is assumed to be in the range of (0, 1).
43 |
44 | Args:
45 | image (torch.Tensor): YUV Image to be converted to RGB with shape :math:`(*, 3, H, W)`.
46 |
47 | Returns:
48 | torch.Tensor: RGB version of the image with shape :math:`(*, 3, H, W)`.
49 |
50 | Example:
51 | >>> input = torch.rand(2, 3, 4, 5)
52 | >>> output = yuv_to_rgb(input) # 2x3x4x5
53 | """
54 | if not isinstance(image, torch.Tensor):
55 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(image)))
56 |
57 | if len(image.shape) < 3 or image.shape[-3] != 3:
58 | raise ValueError("Input size must have a shape of (*, 3, H, W). Got {}".format(image.shape))
59 |
60 | y: torch.Tensor = image[..., 0, :, :]
61 | u: torch.Tensor = image[..., 1, :, :]
62 | v: torch.Tensor = image[..., 2, :, :]
63 |
64 |     r: torch.Tensor = y + 1.14 * v  # coefficient for u is 0
65 |     g: torch.Tensor = y + -0.396 * u - 0.581 * v
66 |     b: torch.Tensor = y + 2.029 * u  # coefficient for v is 0
67 |
68 | out: torch.Tensor = torch.stack([r, g, b], -3)
69 |
70 | return out
71 |
72 |
73 | class RgbToYuv(nn.Module):
74 | r"""Convert an image from RGB to YUV.
75 |
76 | The image data is assumed to be in the range of (0, 1).
77 |
78 | Returns:
79 | torch.Tensor: YUV version of the image.
80 |
81 | Shape:
82 | - image: :math:`(*, 3, H, W)`
83 | - output: :math:`(*, 3, H, W)`
84 |
85 | Examples:
86 | >>> input = torch.rand(2, 3, 4, 5)
87 | >>> yuv = RgbToYuv()
88 | >>> output = yuv(input) # 2x3x4x5
89 |
90 | Reference::
91 | [1] https://es.wikipedia.org/wiki/YUV#RGB_a_Y'UV
92 | """
93 |
94 | def __init__(self) -> None:
95 | super(RgbToYuv, self).__init__()
96 |
97 | def forward(self, input: torch.Tensor) -> torch.Tensor:
98 | return rgb_to_yuv(input)
99 |
100 |
101 | class YuvToRgb(nn.Module):
102 | r"""Convert an image from YUV to RGB.
103 |
104 | The image data is assumed to be in the range of (0, 1).
105 |
106 | Returns:
107 | torch.Tensor: RGB version of the image.
108 |
109 | Shape:
110 | - image: :math:`(*, 3, H, W)`
111 | - output: :math:`(*, 3, H, W)`
112 |
113 | Examples:
114 | >>> input = torch.rand(2, 3, 4, 5)
115 | >>> rgb = YuvToRgb()
116 | >>> output = rgb(input) # 2x3x4x5
117 | """
118 |
119 | def __init__(self) -> None:
120 | super(YuvToRgb, self).__init__()
121 |
122 | def forward(self, input: torch.Tensor) -> torch.Tensor:
123 | return yuv_to_rgb(input)
124 |
--------------------------------------------------------------------------------
/lib/torch_utils/layers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/lib/torch_utils/layers/__init__.py
--------------------------------------------------------------------------------
/lib/torch_utils/layers/acon.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class AconC(nn.Module):
6 | r"""ACON activation (activate or not).
7 | # AconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is a learnable parameter
8 | # according to "Activate or Not: Learning Customized Activation" .
9 | """
10 |
11 | def __init__(self, width):
12 | super().__init__()
13 | self.p1 = nn.Parameter(torch.randn(1, width, 1, 1))
14 | self.p2 = nn.Parameter(torch.randn(1, width, 1, 1))
15 | self.beta = nn.Parameter(torch.ones(1, width, 1, 1))
16 |
17 | def forward(self, x):
18 | return (self.p1 * x - self.p2 * x) * torch.sigmoid(self.beta * (self.p1 * x - self.p2 * x)) + self.p2 * x
19 |
20 |
21 | class MetaAconC(nn.Module):
22 | r"""ACON activation (activate or not).
23 | # MetaAconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is generated by a small network
24 | # according to "Activate or Not: Learning Customized Activation" .
25 | """
26 |
27 | def __init__(self, width, r=16):
28 | super().__init__()
29 | self.fc1 = nn.Conv2d(width, max(r, width // r), kernel_size=1, stride=1, bias=True)
30 | self.bn1 = nn.BatchNorm2d(max(r, width // r))
31 | self.fc2 = nn.Conv2d(max(r, width // r), width, kernel_size=1, stride=1, bias=True)
32 | self.bn2 = nn.BatchNorm2d(width)
33 |
34 | self.p1 = nn.Parameter(torch.randn(1, width, 1, 1))
35 | self.p2 = nn.Parameter(torch.randn(1, width, 1, 1))
36 |
37 | def forward(self, x):
38 | beta = torch.sigmoid(
39 | self.bn2(self.fc2(self.bn1(self.fc1(x.mean(dim=2, keepdims=True).mean(dim=3, keepdims=True)))))
40 | )
41 | return (self.p1 * x - self.p2 * x) * torch.sigmoid(beta * (self.p1 * x - self.p2 * x)) + self.p2 * x
42 |
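A minimal sketch of using these activations as drop-in layers; `width` must equal the channel count of the incoming feature map, since the learnable parameters are shaped `(1, width, 1, 1)`. The import path is assumed from this repo's layout.

```python
import torch
import torch.nn as nn
from lib.torch_utils.layers.acon import AconC, MetaAconC  # path assumed from this repo's layout

block = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, padding=1),
    AconC(width=32),            # beta is a learnable per-channel parameter
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    MetaAconC(width=64, r=16),  # beta is generated by a small squeeze network
)

x = torch.rand(2, 3, 56, 56)
print(block(x).shape)  # torch.Size([2, 64, 56, 56])
```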
--------------------------------------------------------------------------------
/lib/torch_utils/layers/coord_attention.py:
--------------------------------------------------------------------------------
1 | """Coordinate Attention (CVPR 2021).
2 |
3 | Modified from https://github.com/Andrew-Qibin/CoordAttention/blob/main/coordatt.py.
4 | """
5 |
6 | import torch
7 | import torch.nn as nn
8 | import math
9 | import torch.nn.functional as F
10 |
11 |
12 | class h_sigmoid(nn.Module):
13 | def __init__(self, inplace=True):
14 | super(h_sigmoid, self).__init__()
15 | self.relu = nn.ReLU6(inplace=inplace)
16 |
17 | def forward(self, x):
18 | return self.relu(x + 3) / 6
19 |
20 |
21 | class h_swish(nn.Module):
22 | def __init__(self, inplace=True):
23 | super(h_swish, self).__init__()
24 | self.sigmoid = h_sigmoid(inplace=inplace)
25 |
26 | def forward(self, x):
27 | return x * self.sigmoid(x)
28 |
29 |
30 | class CoordAtt(nn.Module):
31 | def __init__(self, inp, oup, reduction=32):
32 | super(CoordAtt, self).__init__()
33 | self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
34 | self.pool_w = nn.AdaptiveAvgPool2d((1, None))
35 |
36 | mip = max(8, inp // reduction)
37 |
38 | self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
39 | self.bn1 = nn.BatchNorm2d(mip)
40 | self.act = h_swish()
41 |
42 | self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
43 | self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
44 |
45 | def forward(self, x):
46 | identity = x
47 |
48 | n, c, h, w = x.size()
49 | x_h = self.pool_h(x)
50 | x_w = self.pool_w(x).permute(0, 1, 3, 2)
51 |
52 | y = torch.cat([x_h, x_w], dim=2)
53 | y = self.conv1(y)
54 | y = self.bn1(y)
55 | y = self.act(y)
56 |
57 | x_h, x_w = torch.split(y, [h, w], dim=2)
58 | x_w = x_w.permute(0, 1, 3, 2)
59 |
60 | a_h = self.conv_h(x_h).sigmoid()
61 | a_w = self.conv_w(x_w).sigmoid()
62 |
63 | out = identity * a_w * a_h
64 |
65 | return out
66 |
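A usage sketch for `CoordAtt`; because the output is an element-wise reweighting of the input (`identity * a_w * a_h`), `inp` and `oup` should match when it is used as an in-place attention block.

```python
import torch
from lib.torch_utils.layers.coord_attention import CoordAtt  # path assumed from this repo's layout

att = CoordAtt(inp=64, oup=64, reduction=32)  # mip = max(8, 64 // 32) = 8 hidden channels

feat = torch.rand(2, 64, 32, 48)  # (N, C, H, W)
out = att(feat)                   # same shape, reweighted by h- and w-attention maps
print(out.shape)                  # torch.Size([2, 64, 32, 48])
```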
--------------------------------------------------------------------------------
/lib/torch_utils/layers/dropblock/README.md:
--------------------------------------------------------------------------------
1 | # DropBlock
2 |
3 | 
4 | [](https://pepy.tech/project/dropblock)
5 |
6 |
7 | Implementation of [DropBlock: A regularization method for convolutional networks](https://arxiv.org/pdf/1810.12890.pdf)
8 | in PyTorch.
9 |
10 | ## Abstract
11 |
12 | Deep neural networks often work well when they are over-parameterized
13 | and trained with a massive amount of noise and regularization, such as
14 | weight decay and dropout. Although dropout is widely used as a regularization
15 | technique for fully connected layers, it is often less effective for convolutional layers.
16 | This lack of success of dropout for convolutional layers is perhaps due to the fact
17 | that activation units in convolutional layers are spatially correlated so
18 | information can still flow through convolutional networks despite dropout.
19 | Thus a structured form of dropout is needed to regularize convolutional networks.
20 | In this paper, we introduce DropBlock, a form of structured dropout, where units in a
21 | contiguous region of a feature map are dropped together.
22 | We found that applying DropBlock in skip connections in addition to the
23 | convolution layers increases the accuracy. Also, gradually increasing the number
24 | of dropped units during training leads to better accuracy and makes the model more robust to hyperparameter choices.
25 | Extensive experiments show that DropBlock works better than dropout in regularizing
26 | convolutional networks. On ImageNet classification, ResNet-50 architecture with
27 | DropBlock achieves 78.13% accuracy, which is more than 1.6% improvement on the baseline.
28 | On COCO detection, DropBlock improves Average Precision of RetinaNet from 36.8% to 38.4%.
29 |
30 |
31 | ## Installation
32 |
33 | Install directly from PyPI:
34 |
35 | pip install dropblock
36 |
37 | or the bleeding edge version from github:
38 |
39 | pip install git+https://github.com/miguelvr/dropblock.git#egg=dropblock
40 |
41 | **NOTE**: Implementation and tests were done in Python 3.6; if you have problems with other versions of Python, please open an issue.
42 |
43 | ## Usage
44 |
45 |
46 | For 2D inputs (DropBlock2D):
47 |
48 | ```python
49 | import torch
50 | from dropblock import DropBlock2D
51 |
52 | # (bsize, n_feats, height, width)
53 | x = torch.rand(100, 10, 16, 16)
54 |
55 | drop_block = DropBlock2D(block_size=3, drop_prob=0.3)
56 | regularized_x = drop_block(x)
57 | ```
58 |
59 | For 3D inputs (DropBlock3D):
60 |
61 | ```python
62 | import torch
63 | from dropblock import DropBlock3D
64 |
65 | # (bsize, n_feats, depth, height, width)
66 | x = torch.rand(100, 10, 16, 16, 16)
67 |
68 | drop_block = DropBlock3D(block_size=3, drop_prob=0.3)
69 | regularized_x = drop_block(x)
70 | ```
71 |
72 | Scheduled Dropblock:
73 |
74 | ```python
75 | import torch
76 | from dropblock import DropBlock2D, LinearScheduler
77 |
78 | # (bsize, n_feats, depth, height, width)
79 | loader = [torch.rand(20, 10, 16, 16) for _ in range(10)]
80 |
81 | drop_block = LinearScheduler(
82 | DropBlock2D(block_size=3, drop_prob=0.),
83 | start_value=0.,
84 | stop_value=0.25,
85 | nr_steps=5
86 | )
87 |
88 | probs = []
89 | for x in loader:
90 | drop_block.step()
91 | regularized_x = drop_block(x)
92 | probs.append(drop_block.dropblock.drop_prob)
93 |
94 | print(probs)
95 | ```
96 |
97 | The drop probabilities will be:
98 | ```
99 | >>> [0. , 0.0625, 0.125 , 0.1875, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25]
100 | ```
101 |
102 | The user should include the `step()` call at the start of the batch loop,
103 | or at the start of a model's `forward` call.
104 |
105 | Check [examples/resnet-cifar10.py](examples/resnet-cifar10.py) to
106 | see an implementation example.
107 |
108 | ## Implementation details
109 |
110 | We use `drop_prob` instead of `keep_prob` as a matter of preference,
111 | and to keep the argument consistent with pytorch's dropout.
112 | Regardless, everything else should work similarly to what is described in the paper.
113 |
114 | ## Benchmark
115 |
116 | Refer to [BENCHMARK.md](BENCHMARK.md)
117 |
118 | ## Reference
119 | [Ghiasi et al., 2018] DropBlock: A regularization method for convolutional networks
120 |
121 | ## TODO
122 | - [x] Scheduled DropBlock
123 | - [x] Get benchmark numbers
124 | - [x] Extend the concept for 3D images
125 |
--------------------------------------------------------------------------------
/lib/torch_utils/layers/dropblock/__init__.py:
--------------------------------------------------------------------------------
1 | from .dropblock import DropBlock2D, DropBlock3D
2 | from .scheduler import LinearScheduler
3 |
4 | __all__ = ["DropBlock2D", "DropBlock3D", "LinearScheduler"]
5 |
--------------------------------------------------------------------------------
/lib/torch_utils/layers/dropblock/dropblock.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 |
6 | class DropBlock2D(nn.Module):
7 | r"""Randomly zeroes 2D spatial blocks of the input tensor.
8 |
9 | As described in the paper
10 | `DropBlock: A regularization method for convolutional networks`_ ,
11 |     dropping whole blocks of a feature map removes semantic
12 |     information more effectively than regular dropout.
13 |
14 | Args:
15 | drop_prob (float): probability of an element to be dropped.
16 | block_size (int): size of the block to drop
17 |
18 | Shape:
19 | - Input: `(N, C, H, W)`
20 | - Output: `(N, C, H, W)`
21 |
22 | .. _DropBlock: A regularization method for convolutional networks:
23 | https://arxiv.org/abs/1810.12890
24 |
25 | """
26 |
27 | def __init__(self, drop_prob, block_size):
28 | super(DropBlock2D, self).__init__()
29 |
30 | self.drop_prob = drop_prob
31 | self.block_size = block_size
32 |
33 | def forward(self, x):
34 | # shape: (bsize, channels, height, width)
35 |
36 | assert x.dim() == 4, "Expected input with 4 dimensions (bsize, channels, height, width)"
37 |
38 | if not self.training or self.drop_prob == 0.0:
39 | return x
40 | else:
41 | # get gamma value
42 | gamma = self._compute_gamma(x)
43 |
44 | # sample mask
45 | mask = (torch.rand(x.shape[0], *x.shape[2:]) < gamma).float()
46 |
47 | # place mask on input device
48 | mask = mask.to(x.device)
49 |
50 | # compute block mask
51 | block_mask = self._compute_block_mask(mask)
52 |
53 | # apply block mask
54 | out = x * block_mask[:, None, :, :]
55 |
56 | # scale output
57 | out = out * block_mask.numel() / block_mask.sum()
58 |
59 | return out
60 |
61 | def _compute_block_mask(self, mask):
62 | block_mask = F.max_pool2d(
63 | input=mask[:, None, :, :],
64 | kernel_size=(self.block_size, self.block_size),
65 | stride=(1, 1),
66 | padding=self.block_size // 2,
67 | )
68 |
69 | if self.block_size % 2 == 0:
70 | block_mask = block_mask[:, :, :-1, :-1]
71 |
72 | block_mask = 1 - block_mask.squeeze(1)
73 |
74 | return block_mask
75 |
76 | def _compute_gamma(self, x):
77 | return self.drop_prob / (self.block_size**2)
78 |
79 |
80 | class DropBlock3D(DropBlock2D):
81 | r"""Randomly zeroes 3D spatial blocks of the input tensor.
82 |
83 | An extension to the concept described in the paper
84 | `DropBlock: A regularization method for convolutional networks`_ ,
85 |     dropping whole blocks of a feature map removes semantic
86 |     information more effectively than regular dropout.
87 |
88 | Args:
89 | drop_prob (float): probability of an element to be dropped.
90 | block_size (int): size of the block to drop
91 |
92 | Shape:
93 | - Input: `(N, C, D, H, W)`
94 | - Output: `(N, C, D, H, W)`
95 |
96 | .. _DropBlock: A regularization method for convolutional networks:
97 | https://arxiv.org/abs/1810.12890
98 |
99 | """
100 |
101 | def __init__(self, drop_prob, block_size):
102 | super(DropBlock3D, self).__init__(drop_prob, block_size)
103 |
104 | def forward(self, x):
105 | # shape: (bsize, channels, depth, height, width)
106 |
107 | assert x.dim() == 5, "Expected input with 5 dimensions (bsize, channels, depth, height, width)"
108 |
109 | if not self.training or self.drop_prob == 0.0:
110 | return x
111 | else:
112 | # get gamma value
113 | gamma = self._compute_gamma(x)
114 |
115 | # sample mask
116 | mask = (torch.rand(x.shape[0], *x.shape[2:]) < gamma).float()
117 |
118 | # place mask on input device
119 | mask = mask.to(x.device)
120 |
121 | # compute block mask
122 | block_mask = self._compute_block_mask(mask)
123 |
124 | # apply block mask
125 | out = x * block_mask[:, None, :, :, :]
126 |
127 | # scale output
128 | out = out * block_mask.numel() / block_mask.sum()
129 |
130 | return out
131 |
132 | def _compute_block_mask(self, mask):
133 | block_mask = F.max_pool3d(
134 | input=mask[:, None, :, :, :],
135 | kernel_size=(self.block_size, self.block_size, self.block_size),
136 | stride=(1, 1, 1),
137 | padding=self.block_size // 2,
138 | )
139 |
140 | if self.block_size % 2 == 0:
141 | block_mask = block_mask[:, :, :-1, :-1, :-1]
142 |
143 | block_mask = 1 - block_mask.squeeze(1)
144 |
145 | return block_mask
146 |
147 | def _compute_gamma(self, x):
148 | return self.drop_prob / (self.block_size**3)
149 |
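Complementing the README above, a small sketch of the train/eval behavior: DropBlock is only active in training mode and is an identity map in eval mode (or when `drop_prob == 0`).

```python
import torch
from lib.torch_utils.layers.dropblock import DropBlock2D

drop = DropBlock2D(drop_prob=0.2, block_size=5)
x = torch.rand(8, 32, 28, 28)

drop.train()
y_train = drop(x)   # contiguous blocks zeroed, remaining activations rescaled

drop.eval()
y_eval = drop(x)    # unchanged in eval mode
print(torch.equal(x, y_eval))  # True
```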
--------------------------------------------------------------------------------
/lib/torch_utils/layers/dropblock/scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch import nn
3 |
4 |
5 | class LinearScheduler(nn.Module):
6 | def __init__(self, dropblock, start_value, stop_value, nr_steps):
7 | super(LinearScheduler, self).__init__()
8 | self.dropblock = dropblock
9 | self.i = 0
10 | self.drop_values = np.linspace(start=start_value, stop=stop_value, num=int(nr_steps))
11 |
12 | def forward(self, x):
13 | return self.dropblock(x)
14 |
15 | def step(self):
16 | if self.i < len(self.drop_values):
17 | self.dropblock.drop_prob = self.drop_values[self.i]
18 |
19 | self.i += 1
20 |
--------------------------------------------------------------------------------
/lib/torch_utils/layers/mean_conv_deconv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from typing import Optional, List, Tuple, Union
5 | from torch import Tensor
6 | from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t
7 |
8 |
9 | class MeanConv2d(nn.Conv2d):
10 | """Conv2d with weight centralization.
11 |
12 | ref: Weight and Gradient Centralization in Deep Neural Networks. https://arxiv.org/pdf/2010.00866.pdf
13 | """
14 |
15 | def forward(self, x):
16 | w = self.weight # [c_out, c_in, k, k]
17 | w = w - torch.mean(w, dim=[1, 2, 3], keepdim=True)
18 | return F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)
19 |
20 |
21 | class MeanConvTranspose2d(nn.ConvTranspose2d):
22 | """ConvTranspose2d with Weight Centralization.
23 |
24 |     ref: Weight and Gradient Centralization in Deep Neural Networks.
25 |     https://arxiv.org/pdf/2010.00866.pdf
26 | """
27 |
28 | def __init__(
29 | self,
30 | in_channels: int,
31 | out_channels: int,
32 | kernel_size: _size_2_t,
33 | stride: _size_2_t = 1,
34 | padding: _size_2_t = 0,
35 | output_padding: _size_2_t = 0,
36 | groups: int = 1,
37 | bias: bool = True,
38 | dilation: int = 1,
39 | padding_mode: str = "zeros",
40 | device=None,
41 | dtype=None,
42 | eps=1e-6,
43 | ):
44 | super().__init__(
45 | in_channels,
46 | out_channels,
47 | kernel_size,
48 | stride=stride,
49 | padding=padding,
50 | output_padding=output_padding,
51 | dilation=dilation,
52 | groups=groups,
53 | bias=bias,
54 | padding_mode=padding_mode,
55 | device=device,
56 | dtype=dtype,
57 | )
58 | self.eps = eps
59 |
60 | def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
61 | if self.padding_mode != "zeros":
62 | raise ValueError("Only `zeros` padding mode is supported for ConvTranspose2d")
63 |
64 | assert isinstance(self.padding, tuple)
65 | # One cannot replace List by Tuple or Sequence in "_output_padding" because
66 | # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
67 | output_padding = self._output_padding(
68 | input, output_size, self.stride, self.padding, self.kernel_size, self.dilation
69 | ) # type: ignore[arg-type]
70 |
71 | w = self.weight
72 | w = w - torch.mean(w, dim=[1, 2, 3], keepdim=True)
73 | return F.conv_transpose2d(
74 | input, w, self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation
75 | )
76 |
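A sketch of using `MeanConv2d` as a drop-in replacement for `nn.Conv2d`: the constructor is identical and only the forward pass changes, centering the kernel over its (c_in, kH, kW) dimensions. `MeanConvTranspose2d` is constructed the same way as `nn.ConvTranspose2d`. The import path is assumed from this repo's layout.

```python
import torch
from lib.torch_utils.layers.mean_conv_deconv import MeanConv2d  # path assumed from this repo's layout

conv = MeanConv2d(16, 32, kernel_size=3, stride=2, padding=1)

x = torch.rand(2, 16, 24, 24)
y = conv(x)
print(y.shape)  # torch.Size([2, 32, 12, 12])

# The effective (centered) kernel has zero mean over (c_in, kH, kW) per output channel.
w = conv.weight - conv.weight.mean(dim=[1, 2, 3], keepdim=True)
print(w.mean(dim=[1, 2, 3]).abs().max())  # ~0
```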
--------------------------------------------------------------------------------
/lib/torch_utils/layers/std_conv_transpose.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from torch import nn
3 | from typing import Optional, List, Tuple, Union
4 | from torch import Tensor
5 | from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t
6 |
7 |
8 | class StdConvTranspose2d(nn.ConvTranspose2d):
9 | """ConvTranspose2d with Weight Standardization.
10 |
11 | Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
12 | https://arxiv.org/abs/1903.10520v2
13 | """
14 |
15 | def __init__(
16 | self,
17 | in_channels: int,
18 | out_channels: int,
19 | kernel_size: _size_2_t,
20 | stride: _size_2_t = 1,
21 | padding: _size_2_t = 0,
22 | output_padding: _size_2_t = 0,
23 | groups: int = 1,
24 | bias: bool = True,
25 | dilation: int = 1,
26 | padding_mode: str = "zeros",
27 | device=None,
28 | dtype=None,
29 | eps=1e-6,
30 | ):
31 | super().__init__(
32 | in_channels,
33 | out_channels,
34 | kernel_size,
35 | stride=stride,
36 | padding=padding,
37 | output_padding=output_padding,
38 | dilation=dilation,
39 | groups=groups,
40 | bias=bias,
41 | padding_mode=padding_mode,
42 | device=device,
43 | dtype=dtype,
44 | )
45 | self.eps = eps
46 |
47 | def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
48 | if self.padding_mode != "zeros":
49 | raise ValueError("Only `zeros` padding mode is supported for ConvTranspose2d")
50 |
51 | assert isinstance(self.padding, tuple)
52 | # One cannot replace List by Tuple or Sequence in "_output_padding" because
53 | # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
54 | output_padding = self._output_padding(
55 | input, output_size, self.stride, self.padding, self.kernel_size, self.dilation
56 | ) # type: ignore[arg-type]
57 |
58 | weight = F.batch_norm(
59 | self.weight.reshape(1, self.out_channels, -1), None, None, training=True, momentum=0.0, eps=self.eps
60 | ).reshape_as(self.weight)
61 | return F.conv_transpose2d(
62 | input, weight, self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation
63 | )
64 |
--------------------------------------------------------------------------------
/lib/torch_utils/misc.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import torch
4 |
5 |
6 | # ----------------------------------------------------------------------------
7 | # Replace NaN/Inf with specified numerical values.
8 | # (from https://github.com/NVlabs/stylegan2-ada-pytorch/blob/main/torch_utils/misc.py)
9 |
10 | try:
11 | nan_to_num = torch.nan_to_num # 1.8.0a0
12 | except AttributeError:
13 |
14 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
15 | assert isinstance(input, torch.Tensor)
16 | if posinf is None:
17 | posinf = torch.finfo(input.dtype).max
18 | if neginf is None:
19 | neginf = torch.finfo(input.dtype).min
20 | assert nan == 0
21 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
22 |
23 |
24 | def set_nan_to_0(a, name=None, verbose=False):
25 | if torch.isnan(a).any():
26 | if verbose and name is not None:
27 | print("nan in {}".format(name))
28 | a[a != a] = 0
29 | return a
30 |
31 |
32 | # ----------------------------------------------------------------------------
33 | # Symbolic assert.
34 |
35 | try:
36 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
37 | except AttributeError:
38 | symbolic_assert = torch.Assert # 1.7.0
39 |
40 | # ----------------------------------------------------------------------------
41 |
42 |
43 | class suppress_tracer_warnings(warnings.catch_warnings):
44 | """Context manager to suppress known warnings in torch.jit.trace()."""
45 |
46 | def __enter__(self):
47 | super().__enter__()
48 | warnings.simplefilter("ignore", category=torch.jit.TracerWarning)
49 | return self
50 |
51 |
52 | # ----------------------------------------------------------------------------
53 |
54 |
55 | def assert_shape(tensor, ref_shape):
56 | """Assert that the shape of a tensor matches the given list of integers.
57 |
58 | None indicates that the size of a dimension is allowed to vary.
59 | Performs symbolic assertion when used in torch.jit.trace().
60 | """
61 | if tensor.ndim != len(ref_shape):
62 | raise AssertionError(f"Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}")
63 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
64 | if ref_size is None:
65 | pass
66 | elif isinstance(ref_size, torch.Tensor):
67 | with suppress_tracer_warnings(): # as_tensor results are registered as constants
68 | symbolic_assert(
69 | torch.equal(torch.as_tensor(size), ref_size),
70 | f"Wrong size for dimension {idx}",
71 | )
72 | elif isinstance(size, torch.Tensor):
73 | with suppress_tracer_warnings(): # as_tensor results are registered as constants
74 | symbolic_assert(
75 | torch.equal(size, torch.as_tensor(ref_size)),
76 | f"Wrong size for dimension {idx}: expected {ref_size}",
77 | )
78 | elif size != ref_size:
79 | raise AssertionError(f"Wrong size for dimension {idx}: got {size}, expected {ref_size}")
80 |
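
A minimal usage sketch for the helpers above (not part of the repository; it assumes the repo root is on PYTHONPATH so `lib.torch_utils.misc` is importable):

    # Hypothetical usage of assert_shape / nan_to_num (illustration only).
    import torch
    from lib.torch_utils.misc import assert_shape, nan_to_num

    x = torch.randn(4, 3, 64, 64)
    assert_shape(x, [4, 3, None, None])  # None lets a dimension vary; mismatches raise AssertionError

    y = torch.tensor([float("nan"), float("inf"), 1.0])
    print(nan_to_num(y, nan=0.0))  # NaN -> 0.0, +inf clamped to the dtype maximum
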
--------------------------------------------------------------------------------
/lib/torch_utils/solver/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/lib/torch_utils/solver/__init__.py
--------------------------------------------------------------------------------
/lib/torch_utils/solver/adamp.py:
--------------------------------------------------------------------------------
1 | """AdamP Copyright (c) 2020-present NAVER Corp.
2 |
3 | MIT license
4 | """
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.optim.optimizer import Optimizer, required
10 | import math
11 |
12 |
13 | class AdamP(Optimizer):
14 | def __init__(
15 | self,
16 | params,
17 | lr=1e-3,
18 | betas=(0.9, 0.999),
19 | eps=1e-8,
20 | weight_decay=0,
21 | delta=0.1,
22 | wd_ratio=0.1,
23 | nesterov=False,
24 | ):
25 | defaults = dict(
26 | lr=lr,
27 | betas=betas,
28 | eps=eps,
29 | weight_decay=weight_decay,
30 | delta=delta,
31 | wd_ratio=wd_ratio,
32 | nesterov=nesterov,
33 | )
34 | super(AdamP, self).__init__(params, defaults)
35 |
36 | def _channel_view(self, x):
37 | return x.view(x.size(0), -1)
38 |
39 | def _layer_view(self, x):
40 | return x.view(1, -1)
41 |
42 | def _cosine_similarity(self, x, y, eps, view_func):
43 | x = view_func(x)
44 | y = view_func(y)
45 |
46 | return F.cosine_similarity(x, y, dim=1, eps=eps).abs_()
47 |
48 | def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
49 | wd = 1
50 | expand_size = [-1] + [1] * (len(p.shape) - 1)
51 | for view_func in [self._channel_view, self._layer_view]:
52 |
53 | cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)
54 |
55 | if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
56 | p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps)
57 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
58 | wd = wd_ratio
59 |
60 | return perturb, wd
61 |
62 | return perturb, wd
63 |
64 | def step(self, closure=None):
65 | loss = None
66 | if closure is not None:
67 | loss = closure()
68 |
69 | for group in self.param_groups:
70 | for p in group["params"]:
71 | if p.grad is None:
72 | continue
73 |
74 | grad = p.grad.data
75 | beta1, beta2 = group["betas"]
76 | nesterov = group["nesterov"]
77 |
78 | state = self.state[p]
79 |
80 | # State initialization
81 | if len(state) == 0:
82 | state["step"] = 0
83 | state["exp_avg"] = torch.zeros_like(p.data)
84 | state["exp_avg_sq"] = torch.zeros_like(p.data)
85 |
86 | # Adam
87 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
88 |
89 | state["step"] += 1
90 | bias_correction1 = 1 - beta1 ** state["step"]
91 | bias_correction2 = 1 - beta2 ** state["step"]
92 |
93 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
94 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
95 |
96 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group["eps"])
97 | step_size = group["lr"] / bias_correction1
98 |
99 | if nesterov:
100 | perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom
101 | else:
102 | perturb = exp_avg / denom
103 |
104 | # Projection
105 | wd_ratio = 1
106 | if len(p.shape) > 1:
107 | perturb, wd_ratio = self._projection(
108 | p,
109 | grad,
110 | perturb,
111 | group["delta"],
112 | group["wd_ratio"],
113 | group["eps"],
114 | )
115 |
116 | # Weight decay
117 | if group["weight_decay"] > 0:
118 | p.data.mul_(1 - group["lr"] * group["weight_decay"] * wd_ratio)
119 |
120 | # Step
121 | p.data.add_(perturb, alpha=-step_size)
122 |
123 | return loss
124 |
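
A minimal training-step sketch using AdamP (not part of the repository; the model and data below are placeholders):

    # Hypothetical single optimization step with AdamP (illustration only).
    import torch
    import torch.nn as nn
    from lib.torch_utils.solver.adamp import AdamP

    model = nn.Linear(16, 4)
    opt = AdamP(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-2)

    x, target = torch.randn(8, 16), torch.randn(8, 4)
    loss = nn.functional.mse_loss(model(x), target)
    opt.zero_grad()
    loss.backward()
    opt.step()  # the projection branch only triggers for parameters with ndim > 1 (the weight matrix)
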
--------------------------------------------------------------------------------
/lib/torch_utils/solver/grad_clip_d2.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # modified to support full_model gradient norm clip
3 | import copy
4 | import itertools
5 | import logging
6 | from enum import Enum
7 | from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union
8 | from omegaconf.omegaconf import OmegaConf
9 |
10 | import torch
11 | from detectron2.config import CfgNode
12 | from fvcore.common.param_scheduler import CosineParamScheduler, MultiStepParamScheduler
13 | from lib.utils.config_utils import try_get_key
14 |
15 |
16 | _GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]]
17 | _GradientClipper = Callable[[_GradientClipperInput], None]
18 |
19 |
20 | class GradientClipType(Enum):
21 | VALUE = "value"
22 | NORM = "norm"
23 | FULL_MODEL = "full_model"
24 |
25 |
26 | def _create_gradient_clipper(cfg) -> _GradientClipper:
27 | """Creates gradient clipping closure to clip by value or by norm, according
28 | to the provided config."""
29 | cfg = copy.deepcopy(cfg)
30 |
31 | _clip_value = try_get_key(cfg, "CLIP_VALUE", "clip_value", default=1.0)
32 | _norm_type = try_get_key(cfg, "NORM_TYPE", "norm_type", default=2.0)
33 |
34 | def clip_grad_norm(p: _GradientClipperInput):
35 | torch.nn.utils.clip_grad_norm_(p, _clip_value, _norm_type)
36 |
37 | def clip_grad_value(p: _GradientClipperInput):
38 | torch.nn.utils.clip_grad_value_(p, _clip_value)
39 |
40 | _GRADIENT_CLIP_TYPE_TO_CLIPPER = {
41 | GradientClipType.VALUE: clip_grad_value,
42 | GradientClipType.NORM: clip_grad_norm,
43 | GradientClipType.FULL_MODEL: clip_grad_norm,
44 | }
45 | _clip_type = try_get_key(cfg, "CLIP_TYPE", "clip_type", default="full_model")
46 | return _GRADIENT_CLIP_TYPE_TO_CLIPPER[GradientClipType(_clip_type)]
47 |
48 |
49 | def _generate_optimizer_class_with_gradient_clipping(
50 | optimizer: Type[torch.optim.Optimizer],
51 | *,
52 | per_param_clipper: Optional[_GradientClipper] = None,
53 | global_clipper: Optional[_GradientClipper] = None,
54 | ) -> Type[torch.optim.Optimizer]:
55 | """Dynamically creates a new type that inherits the type of a given
56 | instance and overrides the `step` method to add gradient clipping."""
57 | assert (
58 | per_param_clipper is None or global_clipper is None
59 | ), "Not allowed to use both per-parameter clipping and global clipping"
60 |
61 | def optimizer_wgc_step(self, closure=None):
62 | if per_param_clipper is not None:
63 | for group in self.param_groups:
64 | for p in group["params"]:
65 | per_param_clipper(p)
66 | else:
67 | # global clipper for future use with detr
68 | # (https://github.com/facebookresearch/detr/pull/287)
69 | all_params = itertools.chain(*[g["params"] for g in self.param_groups])
70 | global_clipper(all_params)
71 | super(type(self), self).step(closure)
72 |
73 | OptimizerWithGradientClip = type(
74 | optimizer.__name__ + "WithGradientClip",
75 | (optimizer,),
76 | {"step": optimizer_wgc_step},
77 | )
78 | return OptimizerWithGradientClip
79 |
80 |
81 | def maybe_add_gradient_clipping(cfg: CfgNode, optimizer: Type[torch.optim.Optimizer]) -> Type[torch.optim.Optimizer]:
82 | """If gradient clipping is enabled through config options, wraps the
83 | existing optimizer type to become a new dynamically created class
84 | OptimizerWithGradientClip that inherits the given optimizer and overrides
85 | the `step` method to include gradient clipping.
86 |
87 | Args:
88 | cfg: CfgNode, configuration options
89 | optimizer: type. A subclass of torch.optim.Optimizer
90 | Return:
91 | type: either the input `optimizer` (if gradient clipping is disabled), or
92 | a subclass of it with gradient clipping included in the `step` method.
93 | """
94 | clip_cfg = try_get_key(
95 | cfg, "SOLVER.CLIP_GRADIENTS", "train.grad_clip", default=OmegaConf.create(dict(enabled=False))
96 | )
97 | if not try_get_key(clip_cfg, "ENABLED", "enabled", default=False):
98 | return optimizer
99 | if isinstance(optimizer, torch.optim.Optimizer):
100 | optimizer_type = type(optimizer)
101 | else:
102 | assert issubclass(optimizer, torch.optim.Optimizer), optimizer
103 | optimizer_type = optimizer
104 |
105 | grad_clipper = _create_gradient_clipper(clip_cfg)
106 | _clip_type = try_get_key(clip_cfg, "CLIP_TYPE", "clip_type", default="full_model")
107 | if _clip_type != "full_model":
108 | OptimizerWithGradientClip = _generate_optimizer_class_with_gradient_clipping(
109 | optimizer_type, per_param_clipper=grad_clipper
110 | )
111 | else:
112 | OptimizerWithGradientClip = _generate_optimizer_class_with_gradient_clipping(
113 | optimizer_type, global_clipper=grad_clipper
114 | )
115 | if isinstance(optimizer, torch.optim.Optimizer):
116 | optimizer.__class__ = OptimizerWithGradientClip # a bit hacky, not recommended
117 | return optimizer
118 | else:
119 | return OptimizerWithGradientClip
120 |
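
A minimal sketch of wrapping an optimizer class with gradient clipping (not part of the repository; the config keys follow the `train.grad_clip` fallback accepted by try_get_key):

    # Hypothetical usage of maybe_add_gradient_clipping (illustration only).
    import torch
    from omegaconf import OmegaConf
    from lib.torch_utils.solver.grad_clip_d2 import maybe_add_gradient_clipping

    cfg = OmegaConf.create(
        {"train": {"grad_clip": {"enabled": True, "clip_type": "full_model", "clip_value": 1.0, "norm_type": 2.0}}}
    )
    SGDWithClip = maybe_add_gradient_clipping(cfg, torch.optim.SGD)  # dynamically created subclass
    model = torch.nn.Linear(8, 2)
    opt = SGDWithClip(model.parameters(), lr=0.1)  # step() now clips the full-model grad norm to 1.0
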
--------------------------------------------------------------------------------
/lib/torch_utils/solver/lookahead.py:
--------------------------------------------------------------------------------
1 | """Lookahead Optimizer Wrapper. Implementation modified from:
2 | https://github.com/alphadl/lookahead.pytorch.
3 |
4 | Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
5 | """
6 | # https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lookahead.py
7 | import torch
8 | from torch.optim.optimizer import Optimizer
9 | from collections import defaultdict
10 |
11 | # from lib.utils import logger
12 |
13 |
14 | class Lookahead(Optimizer):
15 | def __init__(self, base_optimizer, alpha=0.5, k=6):
16 | if not 0.0 <= alpha <= 1.0:
17 | raise ValueError(f"Invalid slow update rate: {alpha}")
18 | if not 1 <= k:
19 | raise ValueError(f"Invalid lookahead steps: {k}")
20 | defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
21 | self.base_optimizer = base_optimizer
22 | self.param_groups = self.base_optimizer.param_groups
23 | self.defaults = base_optimizer.defaults
24 | self.defaults.update(defaults)
25 | self.state = defaultdict(dict)
26 | # manually add our defaults to the param groups
27 | for name, default in defaults.items():
28 | for group in self.param_groups:
29 | group.setdefault(name, default)
30 |
31 | def update_slow(self, group):
32 | for fast_p in group["params"]:
33 | if fast_p.grad is None:
34 | continue
35 | param_state = self.state[fast_p]
36 | if "slow_buffer" not in param_state:
37 | param_state["slow_buffer"] = torch.empty_like(fast_p.data)
38 | param_state["slow_buffer"].copy_(fast_p.data)
39 | slow = param_state["slow_buffer"]
40 |             slow.add_(fast_p.data - slow, alpha=group["lookahead_alpha"])
41 | fast_p.data.copy_(slow)
42 |
43 | def sync_lookahead(self):
44 | for group in self.param_groups:
45 | self.update_slow(group)
46 |
47 | def step(self, closure=None):
48 | # assert id(self.param_groups) == id(self.base_optimizer.param_groups)
49 | loss = self.base_optimizer.step(closure)
50 | for group in self.param_groups:
51 | group["lookahead_step"] += 1
52 | if group["lookahead_step"] % group["lookahead_k"] == 0:
53 | self.update_slow(group)
54 | return loss
55 |
56 | def state_dict(self):
57 | fast_state_dict = self.base_optimizer.state_dict()
58 | slow_state = {(id(k) if isinstance(k, torch.Tensor) else k): v for k, v in self.state.items()}
59 | fast_state = fast_state_dict["state"]
60 | param_groups = fast_state_dict["param_groups"]
61 | return {
62 | "state": fast_state,
63 | "slow_state": slow_state,
64 | "param_groups": param_groups,
65 | }
66 |
67 | def load_state_dict(self, state_dict):
68 | fast_state_dict = {
69 | "state": state_dict["state"],
70 | "param_groups": state_dict["param_groups"],
71 | }
72 | self.base_optimizer.load_state_dict(fast_state_dict)
73 |
74 | # We want to restore the slow state, but share param_groups reference
75 | # with base_optimizer. This is a bit redundant but least code
76 | slow_state_new = False
77 | if "slow_state" not in state_dict:
78 | print("Loading state_dict from optimizer without Lookahead applied.")
79 | state_dict["slow_state"] = defaultdict(dict)
80 | slow_state_new = True
81 | slow_state_dict = {
82 | "state": state_dict["slow_state"],
83 | "param_groups": state_dict["param_groups"], # this is pointless but saves code
84 | }
85 | super(Lookahead, self).load_state_dict(slow_state_dict)
86 | self.param_groups = self.base_optimizer.param_groups # make both ref same container
87 | if slow_state_new:
88 | # reapply defaults to catch missing lookahead specific ones
89 | for name, default in self.defaults.items():
90 | for group in self.param_groups:
91 | group.setdefault(name, default)
92 |
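
A minimal sketch of wrapping a base optimizer with Lookahead (not part of the repository; the base optimizer and model are placeholders):

    # Hypothetical Lookahead usage (illustration only).
    import torch
    from lib.torch_utils.solver.lookahead import Lookahead

    model = torch.nn.Linear(8, 2)
    base = torch.optim.Adam(model.parameters(), lr=1e-3)
    opt = Lookahead(base, alpha=0.5, k=6)  # slow weights updated every k fast steps

    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    opt.step()            # delegates to Adam, then updates the slow buffers on every k-th call
    opt.sync_lookahead()  # force a slow-weight sync, e.g. before evaluation or checkpointing
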
--------------------------------------------------------------------------------
/lib/torch_utils/solver/over9000.py:
--------------------------------------------------------------------------------
1 | ####
2 | # CODE TAKEN FROM https://github.com/mgrankin/over9000
3 | ####
4 |
5 | # import torch, math
6 | # from torch.optim.optimizer import Optimizer
7 | # import itertools as it
8 | from .lookahead import Lookahead
9 | from .ralamb import Ralamb
10 |
11 |
12 | # RangerLars = Over9000 = RAdam + LARS + LookAHead
13 |
14 | # Lookahead implementation from https://github.com/lonePatient/lookahead_pytorch/blob/master/optimizer.py
15 | # RAdam + LARS implementation from https://gist.github.com/redknightlois/c4023d393eb8f92bb44b2ab582d7ec20
16 |
17 |
18 | def Over9000(params, alpha=0.5, k=6, *args, **kwargs):
19 | ralamb = Ralamb(params, *args, **kwargs)
20 | return Lookahead(ralamb, alpha, k)
21 |
22 |
23 | RangerLars = Over9000
24 |
--------------------------------------------------------------------------------
/lib/torch_utils/solver/ralamb.py:
--------------------------------------------------------------------------------
1 | ####
2 | # CODE TAKEN FROM https://github.com/mgrankin/over9000
3 | ####
4 |
5 | import torch, math
6 | from torch.optim.optimizer import Optimizer
7 |
8 | # RAdam + LARS
9 | class Ralamb(Optimizer):
10 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
11 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
12 | self.buffer = [[None, None, None] for ind in range(10)]
13 | super(Ralamb, self).__init__(params, defaults)
14 |
15 | def __setstate__(self, state):
16 | super(Ralamb, self).__setstate__(state)
17 |
18 | def step(self, closure=None):
19 |
20 | loss = None
21 | if closure is not None:
22 | loss = closure()
23 |
24 | for group in self.param_groups:
25 |
26 | for p in group["params"]:
27 | if p.grad is None:
28 | continue
29 | grad = p.grad.data.float()
30 | if grad.is_sparse:
31 | raise RuntimeError("Ralamb does not support sparse gradients")
32 |
33 | p_data_fp32 = p.data.float()
34 |
35 | state = self.state[p]
36 |
37 | if len(state) == 0:
38 | state["step"] = 0
39 | state["exp_avg"] = torch.zeros_like(p_data_fp32)
40 | state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
41 | else:
42 | state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
43 | state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)
44 |
45 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
46 | beta1, beta2 = group["betas"]
47 |
48 | # Decay the first and second moment running average coefficient
49 | # m_t
50 |                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
51 |                 # v_t
52 |                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
53 |
54 | state["step"] += 1
55 | buffered = self.buffer[int(state["step"] % 10)]
56 |
57 | if state["step"] == buffered[0]:
58 | N_sma, radam_step_size = buffered[1], buffered[2]
59 | else:
60 | buffered[0] = state["step"]
61 | beta2_t = beta2 ** state["step"]
62 | N_sma_max = 2 / (1 - beta2) - 1
63 | N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)
64 | buffered[1] = N_sma
65 |
66 | # more conservative since it's an approximated value
67 | if N_sma >= 5:
68 | radam_step_size = math.sqrt(
69 | (1 - beta2_t)
70 | * (N_sma - 4)
71 | / (N_sma_max - 4)
72 | * (N_sma - 2)
73 | / N_sma
74 | * N_sma_max
75 | / (N_sma_max - 2)
76 | ) / (1 - beta1 ** state["step"])
77 | else:
78 | radam_step_size = 1.0 / (1 - beta1 ** state["step"])
79 | buffered[2] = radam_step_size
80 |
81 | if group["weight_decay"] != 0:
82 |                     p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"])
83 |
84 | # more conservative since it's an approximated value
85 | radam_step = p_data_fp32.clone()
86 | if N_sma >= 5:
87 | denom = exp_avg_sq.sqrt().add_(group["eps"])
88 |                     radam_step.addcdiv_(exp_avg, denom, value=-radam_step_size * group["lr"])
89 |                 else:
90 |                     radam_step.add_(exp_avg, alpha=-radam_step_size * group["lr"])
91 |
92 | radam_norm = radam_step.pow(2).sum().sqrt()
93 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
94 | if weight_norm == 0 or radam_norm == 0:
95 | trust_ratio = 1
96 | else:
97 | trust_ratio = weight_norm / radam_norm
98 |
99 | state["weight_norm"] = weight_norm
100 | state["adam_norm"] = radam_norm
101 | state["trust_ratio"] = trust_ratio
102 |
103 | if N_sma >= 5:
104 |                     p_data_fp32.addcdiv_(
105 |                         exp_avg,
106 |                         denom,
107 |                         value=-radam_step_size * group["lr"] * trust_ratio,
108 |                     )
109 |                 else:
110 |                     p_data_fp32.add_(exp_avg, alpha=-radam_step_size * group["lr"] * trust_ratio)
111 |
112 | p.data.copy_(p_data_fp32)
113 |
114 | return loss
115 |
--------------------------------------------------------------------------------
/lib/torch_utils/solver/sgdp.py:
--------------------------------------------------------------------------------
1 | """AdamP Copyright (c) 2020-present NAVER Corp.
2 |
3 | MIT license
4 | """
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.optim.optimizer import Optimizer, required
10 | import math
11 |
12 |
13 | class SGDP(Optimizer):
14 | def __init__(
15 | self,
16 | params,
17 | lr=required,
18 | momentum=0,
19 | dampening=0,
20 | weight_decay=0,
21 | nesterov=False,
22 | eps=1e-8,
23 | delta=0.1,
24 | wd_ratio=0.1,
25 | ):
26 | defaults = dict(
27 | lr=lr,
28 | momentum=momentum,
29 | dampening=dampening,
30 | weight_decay=weight_decay,
31 | nesterov=nesterov,
32 | eps=eps,
33 | delta=delta,
34 | wd_ratio=wd_ratio,
35 | )
36 | super(SGDP, self).__init__(params, defaults)
37 |
38 | def _channel_view(self, x):
39 | return x.view(x.size(0), -1)
40 |
41 | def _layer_view(self, x):
42 | return x.view(1, -1)
43 |
44 | def _cosine_similarity(self, x, y, eps, view_func):
45 | x = view_func(x)
46 | y = view_func(y)
47 |
48 | return F.cosine_similarity(x, y, dim=1, eps=eps).abs_()
49 |
50 | def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
51 | wd = 1
52 | expand_size = [-1] + [1] * (len(p.shape) - 1)
53 | for view_func in [self._channel_view, self._layer_view]:
54 |
55 | cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)
56 |
57 | if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
58 | p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps)
59 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
60 | wd = wd_ratio
61 |
62 | return perturb, wd
63 |
64 | return perturb, wd
65 |
66 | def step(self, closure=None):
67 | loss = None
68 | if closure is not None:
69 | loss = closure()
70 |
71 | for group in self.param_groups:
72 | momentum = group["momentum"]
73 | dampening = group["dampening"]
74 | nesterov = group["nesterov"]
75 |
76 | for p in group["params"]:
77 | if p.grad is None:
78 | continue
79 | grad = p.grad.data
80 | state = self.state[p]
81 |
82 | # State initialization
83 | if len(state) == 0:
84 | state["momentum"] = torch.zeros_like(p.data)
85 |
86 | # SGD
87 | buf = state["momentum"]
88 | buf.mul_(momentum).add_(grad, alpha=1 - dampening)
89 | if nesterov:
90 | d_p = grad + momentum * buf
91 | else:
92 | d_p = buf
93 |
94 | # Projection
95 | wd_ratio = 1
96 | if len(p.shape) > 1:
97 | d_p, wd_ratio = self._projection(
98 | p,
99 | grad,
100 | d_p,
101 | group["delta"],
102 | group["wd_ratio"],
103 | group["eps"],
104 | )
105 |
106 | # Weight decay
107 | if group["weight_decay"] > 0:
108 | p.data.mul_(1 - group["lr"] * group["weight_decay"] * wd_ratio / (1 - momentum))
109 |
110 | # Step
111 | p.data.add_(d_p, alpha=-group["lr"])
112 |
113 | return loss
114 |
--------------------------------------------------------------------------------
/lib/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/lib/utils/__init__.py
--------------------------------------------------------------------------------
/lib/utils/config_utils.py:
--------------------------------------------------------------------------------
1 | from omegaconf import OmegaConf
2 | from mmcv import Config
3 |
4 |
5 | def try_get_key(cfg, *keys, default=None):
6 | """# modified from detectron2 to also support mmcv Config.
7 |
8 | Try select keys from cfg until the first key that exists. Otherwise
9 | return default.
10 | """
11 | from detectron2.config import CfgNode
12 |
13 | if isinstance(cfg, CfgNode):
14 | cfg = OmegaConf.create(cfg.dump())
15 | elif isinstance(cfg, Config): # mmcv Config
16 | cfg = OmegaConf.create(cfg._cfg_dict.to_dict())
17 | elif isinstance(cfg, dict): # raw dict
18 | cfg = OmegaConf.create(cfg)
19 |
20 | for k in keys:
21 | none = object()
22 | p = OmegaConf.select(cfg, k, default=none)
23 | if p is not none:
24 | return p
25 | return default
26 |
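
A minimal sketch of try_get_key resolving the first matching key across config layouts (not part of the repository; the key names below are placeholders):

    # Hypothetical try_get_key usage (illustration only).
    from lib.utils.config_utils import try_get_key

    cfg = {"SOLVER": {"BASE_LR": 1e-4}}  # plain dicts, mmcv Config, CfgNode and OmegaConf all work
    lr = try_get_key(cfg, "train.base_lr", "SOLVER.BASE_LR", default=1e-3)
    print(lr)  # 1e-4: the first key is absent, so the second (dotted) key is selected
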
--------------------------------------------------------------------------------
/lib/utils/fs.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: fs.py from tensorpack
3 | import os
4 | from six.moves import urllib
5 | import errno
6 | import tqdm
7 | from . import logger
8 | from .utils import execute_only_once
9 |
10 | __all__ = ["mkdir_p", "download", "recursive_walk", "get_dataset_path"]
11 |
12 |
13 | def mkdir_p(dirname):
14 | """Like "mkdir -p", make a dir recursively, but do nothing if the dir
15 | exists.
16 |
17 | Args:
18 | dirname(str):
19 | """
20 | assert dirname is not None
21 | if dirname == "" or os.path.isdir(dirname):
22 | return
23 | try:
24 | os.makedirs(dirname)
25 | except OSError as e:
26 | if e.errno != errno.EEXIST:
27 | raise e
28 |
29 |
30 | def download(url, dir, filename=None, expect_size=None):
31 | """Download URL to a directory.
32 |
33 | Will figure out the filename automatically from URL, if not given.
34 | """
35 | mkdir_p(dir)
36 | if filename is None:
37 | filename = url.split("/")[-1]
38 | fpath = os.path.join(dir, filename)
39 |
40 | if os.path.isfile(fpath):
41 | if expect_size is not None and os.stat(fpath).st_size == expect_size:
42 | logger.info("File {} exists! Skip download.".format(filename))
43 | return fpath
44 | else:
45 | logger.warning("File {} exists. Will overwrite with a new download!".format(filename))
46 |
47 | def hook(t):
48 | last_b = [0]
49 |
50 | def inner(b, bsize, tsize=None):
51 | if tsize is not None:
52 | t.total = tsize
53 | t.update((b - last_b[0]) * bsize)
54 | last_b[0] = b
55 |
56 | return inner
57 |
58 | try:
59 | with tqdm.tqdm(unit="B", unit_scale=True, miniters=1, desc=filename) as t:
60 | fpath, _ = urllib.request.urlretrieve(url, fpath, reporthook=hook(t))
61 | statinfo = os.stat(fpath)
62 | size = statinfo.st_size
63 | except IOError:
64 | logger.error("Failed to download {}".format(url))
65 | raise
66 | assert size > 0, "Downloaded an empty file from {}!".format(url)
67 |
68 | if expect_size is not None and size != expect_size:
69 | logger.error("File downloaded from {} does not match the expected size!".format(url))
70 | logger.error("You may have downloaded a broken file, or the upstream may have modified the file.")
71 |
72 | # TODO human-readable size
73 |     logger.info("Successfully downloaded " + filename + ". " + str(size) + " bytes.")
74 | return fpath
75 |
76 |
77 | def recursive_walk(rootdir):
78 | """
79 | Yields:
80 | str: All files in rootdir, recursively.
81 | """
82 | for r, dirs, files in os.walk(rootdir):
83 | for f in files:
84 | yield os.path.join(r, f)
85 |
86 |
87 | def get_dataset_path(*args):
88 | """Get the path to some dataset under ``$TENSORPACK_DATASET``.
89 |
90 | Args:
91 | args: strings to be joined to form path.
92 |
93 | Returns:
94 | str: path to the dataset.
95 | """
96 | d = os.environ.get("TENSORPACK_DATASET", None)
97 | if d is None:
98 | d = os.path.join(os.path.expanduser("~"), "tensorpack_data")
99 | if execute_only_once():
100 | logger.warning("Env var $TENSORPACK_DATASET not set, using {} for datasets.".format(d))
101 | if not os.path.isdir(d):
102 | mkdir_p(d)
103 | logger.info("Created the directory {}.".format(d))
104 | assert os.path.isdir(d), d
105 | return os.path.join(d, *args)
106 |
107 |
108 | if __name__ == "__main__":
109 | download("http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz", ".")
110 |
--------------------------------------------------------------------------------
/lib/utils/time_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import timeit
3 | import time
4 | from datetime import datetime, timedelta
5 | from . import logger
6 |
7 |
8 | def average_time_of_func(func, num_iter=10000, ms=False):
9 | """
10 | ms: if True, xxx ms/iter
11 | """
12 | duration = timeit.timeit(func, number=num_iter)
13 | avg_time = duration / num_iter
14 | if ms:
15 | avg_time *= 1000
16 | logger.info("{} {} ms/iter".format(func.__name__, avg_time))
17 | else:
18 | logger.info("{} {} s/iter".format(func.__name__, avg_time))
19 | return avg_time
20 |
21 |
22 | def my_timeit(func, number=100000):
23 | tic = time.perf_counter()
24 | for i in range(number):
25 | func()
26 | return time.perf_counter() - tic
27 |
28 |
29 | def get_time_str(fmt="%Y%m%d_%H%M%S", hours_offset=8):
30 | # get UTC+8 time by default
31 | # set hours_offset to 0 to get UTC time
32 | # use utc time to avoid the problem of mis-configured timezone on some machines
33 | return (datetime.utcnow() + timedelta(hours=hours_offset)).strftime(fmt)
34 |
35 |
36 | # def get_time_str(fmt='%Y%m%d_%H%M%S'):
37 | # # from mmcv.runner import get_time_str
38 | # return time.strftime(fmt, time.localtime()) # defined in mmcv
39 |
40 |
41 | def get_time_delta(sec):
42 | """Humanize timedelta given in seconds, modified from maskrcnn-
43 | benchmark."""
44 | if sec < 0:
45 | logger.warning("get_time_delta() obtains negative seconds!")
46 | return "{:.3g} seconds".format(sec)
47 | delta_time_str = str(timedelta(seconds=sec))
48 | return delta_time_str
49 |
50 |
51 | class Timer(object):
52 | # modified from maskrcnn-benchmark
53 | def __init__(self):
54 | self.reset()
55 |
56 | @property
57 | def average_time(self):
58 | return self.total_time / self.calls if self.calls > 0 else 0.0
59 |
60 | def tic(self):
61 |         # use time.perf_counter instead of time.clock because time.clock
62 |         # does not normalize for multithreading
63 | self.start_time = time.perf_counter()
64 |
65 | def toc(self, average=True):
66 | self.add(time.perf_counter() - self.start_time)
67 | if average:
68 | return self.average_time
69 | else:
70 | return self.diff
71 |
72 | def add(self, time_diff):
73 | self.diff = time_diff
74 | self.total_time += self.diff
75 | self.calls += 1
76 |
77 | def reset(self):
78 | self.total_time = 0.0
79 | self.calls = 0
80 | self.start_time = 0.0
81 | self.diff = 0.0
82 |
83 | def avg_time_str(self):
84 | time_str = get_time_delta(self.average_time)
85 | return time_str
86 |
87 |
88 | def humanize_time_delta(sec):
89 | """Humanize timedelta given in seconds
90 | Args:
91 | sec (float): time difference in seconds. Must be positive.
92 | Returns:
93 | str - time difference as a readable string
94 | Example:
95 | .. code-block:: python
96 | print(humanize_time_delta(1)) # 1 second
97 | print(humanize_time_delta(60 + 1)) # 1 minute 1 second
98 | print(humanize_time_delta(87.6)) # 1 minute 27 seconds
99 | print(humanize_time_delta(0.01)) # 0.01 seconds
100 | print(humanize_time_delta(60 * 60 + 1)) # 1 hour 1 second
101 | print(humanize_time_delta(60 * 60 * 24 + 1)) # 1 day 1 second
102 | print(humanize_time_delta(60 * 60 * 24 + 60 * 2 + 60*60*9 + 3)) # 1 day 9 hours 2 minutes 3 seconds
103 | """
104 | if sec < 0:
105 | logger.warning("humanize_time_delta() obtains negative seconds!")
106 | return "{:.3g} seconds".format(sec)
107 | if sec == 0:
108 | return "0 second"
109 | _time = datetime(2000, 1, 1) + timedelta(seconds=int(sec))
110 | units = ["day", "hour", "minute", "second"]
111 | vals = [int(sec // 86400), _time.hour, _time.minute, _time.second]
112 | if sec < 60:
113 | vals[-1] = sec
114 |
115 | def _format(v, u):
116 | return "{:.3g} {}{}".format(v, u, "s" if v > 1 else "")
117 |
118 | ans = []
119 | for v, u in zip(vals, units):
120 | if v > 0:
121 | ans.append(_format(v, u))
122 | return " ".join(ans)
123 |
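
A minimal sketch exercising the timing helpers above (not part of the repository):

    # Hypothetical Timer / humanize_time_delta usage (illustration only).
    import time
    from lib.utils.time_utils import Timer, get_time_str, humanize_time_delta

    timer = Timer()
    for _ in range(3):
        timer.tic()
        time.sleep(0.01)
        timer.toc()                   # accumulates total_time and calls
    print(timer.avg_time_str())       # roughly "0:00:00.010xxx"
    print(get_time_str())             # e.g. "20240101_120000" (UTC+8 by default)
    print(humanize_time_delta(3723))  # "1 hour 2 minutes 3 seconds"
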
--------------------------------------------------------------------------------
/lib/vis_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/lib/vis_utils/__init__.py
--------------------------------------------------------------------------------
/lib/vis_utils/colormap.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | """An awesome colormap for really neat visualizations.
3 |
4 | Copied from Detectron, and removed gray colors.
5 | """
6 |
7 | import numpy as np
8 |
9 | __all__ = ["colormap", "random_color"]
10 |
11 | # yapf: disable
12 | # RGB:
13 | _COLORS = np.array(
14 | [
15 | 0.000, 0.447, 0.741,
16 | 0.850, 0.325, 0.098,
17 | 0.929, 0.694, 0.125,
18 | 0.494, 0.184, 0.556,
19 | 0.466, 0.674, 0.188,
20 | 0.301, 0.745, 0.933,
21 | 0.635, 0.078, 0.184,
22 | 0.300, 0.300, 0.300,
23 | 0.600, 0.600, 0.600,
24 | 1.000, 0.000, 0.000,
25 | 1.000, 0.500, 0.000,
26 | 0.749, 0.749, 0.000,
27 | 0.000, 1.000, 0.000,
28 | 0.000, 0.000, 1.000,
29 | 0.667, 0.000, 1.000,
30 | 0.333, 0.333, 0.000,
31 | 0.333, 0.667, 0.000,
32 | 0.333, 1.000, 0.000,
33 | 0.667, 0.333, 0.000,
34 | 0.667, 0.667, 0.000,
35 | 0.667, 1.000, 0.000,
36 | 1.000, 0.333, 0.000,
37 | 1.000, 0.667, 0.000,
38 | 1.000, 1.000, 0.000,
39 | 0.000, 0.333, 0.500,
40 | 0.000, 0.667, 0.500,
41 | 0.000, 1.000, 0.500,
42 | 0.333, 0.000, 0.500,
43 | 0.333, 0.333, 0.500,
44 | 0.333, 0.667, 0.500,
45 | 0.333, 1.000, 0.500,
46 | 0.667, 0.000, 0.500,
47 | 0.667, 0.333, 0.500,
48 | 0.667, 0.667, 0.500,
49 | 0.667, 1.000, 0.500,
50 | 1.000, 0.000, 0.500,
51 | 1.000, 0.333, 0.500,
52 | 1.000, 0.667, 0.500,
53 | 1.000, 1.000, 0.500,
54 | 0.000, 0.333, 1.000,
55 | 0.000, 0.667, 1.000,
56 | 0.000, 1.000, 1.000,
57 | 0.333, 0.000, 1.000,
58 | 0.333, 0.333, 1.000,
59 | 0.333, 0.667, 1.000,
60 | 0.333, 1.000, 1.000,
61 | 0.667, 0.000, 1.000,
62 | 0.667, 0.333, 1.000,
63 | 0.667, 0.667, 1.000,
64 | 0.667, 1.000, 1.000,
65 | 1.000, 0.000, 1.000,
66 | 1.000, 0.333, 1.000,
67 | 1.000, 0.667, 1.000,
68 | 0.333, 0.000, 0.000,
69 | 0.500, 0.000, 0.000,
70 | 0.667, 0.000, 0.000,
71 | 0.833, 0.000, 0.000,
72 | 1.000, 0.000, 0.000,
73 | 0.000, 0.167, 0.000,
74 | 0.000, 0.333, 0.000,
75 | 0.000, 0.500, 0.000,
76 | 0.000, 0.667, 0.000,
77 | 0.000, 0.833, 0.000,
78 | 0.000, 1.000, 0.000,
79 | 0.000, 0.000, 0.167,
80 | 0.000, 0.000, 0.333,
81 | 0.000, 0.000, 0.500,
82 | 0.000, 0.000, 0.667,
83 | 0.000, 0.000, 0.833,
84 | 0.000, 0.000, 1.000,
85 | 0.000, 0.000, 0.000,
86 | 0.143, 0.143, 0.143,
87 | 0.857, 0.857, 0.857,
88 | 1.000, 1.000, 1.000
89 | ]
90 | ).astype(np.float32).reshape(-1, 3)
91 | # yapf: enable
92 |
93 |
94 | def colormap(rgb=False, maximum=255):
95 | """
96 | Args:
97 | rgb (bool): whether to return RGB colors or BGR colors.
98 | maximum (int): either 255 or 1
99 |
100 | Returns:
101 | ndarray: a float32 array of Nx3 colors, in range [0, 255] or [0, 1]
102 | """
103 | assert maximum in [255, 1], maximum
104 | c = _COLORS * maximum
105 | if not rgb:
106 | c = c[:, ::-1]
107 | return c
108 |
109 |
110 | def random_color(rgb=False, maximum=255):
111 | """
112 | Args:
113 | rgb (bool): whether to return RGB colors or BGR colors.
114 | maximum (int): either 255 or 1
115 |
116 | Returns:
117 | ndarray: a vector of 3 numbers
118 | """
119 | idx = np.random.randint(0, len(_COLORS))
120 | ret = _COLORS[idx] * maximum
121 | if not rgb:
122 | ret = ret[::-1]
123 | return ret
124 |
125 |
126 | if __name__ == "__main__":
127 | import cv2
128 |
129 | size = 100
130 | H, W = 10, 10
131 | canvas = np.random.rand(H * size, W * size, 3).astype("float32")
132 | for h in range(H):
133 | for w in range(W):
134 | idx = h * W + w
135 | if idx >= len(_COLORS):
136 | break
137 | canvas[h * size : (h + 1) * size, w * size : (w + 1) * size] = _COLORS[idx]
138 | cv2.imshow("a", canvas)
139 | cv2.waitKey(0)
140 |
--------------------------------------------------------------------------------
/lib/vis_utils/optflow.py:
--------------------------------------------------------------------------------
1 | """Visualization of optical flow (optflow)."""
2 | from __future__ import division
3 |
4 | import numpy as np
5 |
6 | from mmcv.image import rgb2bgr
7 | from mmcv.video import flowread
8 | from .image import imshow # cv2 imshow
9 | import matplotlib.pyplot as plt
10 |
11 |
12 | def flowshow(flow, win_name="", wait_time=0, vis_tool="matplotlib"):
13 | """Show optical flow.
14 |
15 | Args:
16 | flow (ndarray or str): The optical flow to be displayed.
17 | win_name (str): The window name.
18 | wait_time (int): Value of waitKey param.
19 | """
20 | flow = flowread(flow) # HxWx2
21 | flow_img = flow2rgb(flow)
22 | if vis_tool == "matplotlib":
23 | fig = plt.figure(frameon=False, figsize=(8, 6), dpi=100)
24 | tmp = fig.add_subplot(1, 1, 1)
25 | tmp.set_title("{}".format(win_name))
26 | plt.axis("off")
27 | plt.imshow(flow_img)
28 | plt.show()
29 | else:
30 | imshow(rgb2bgr(flow_img), win_name, wait_time)
31 |
32 |
33 | def flow2rgb(flow, color_wheel=None, unknown_thr=1e6):
34 | # NOTE: the same as flowlib.get_flow_show(flow, mode='Y')
35 | """Convert flow map to RGB image.
36 |
37 | Args:
38 | flow (ndarray): Array of optical flow.
39 | color_wheel (ndarray or None): Color wheel used to map flow field to
40 | RGB colorspace. Default color wheel will be used if not specified.
41 |         unknown_thr (float): Values above this threshold will be marked as
42 | unknown and thus ignored.
43 | Returns:
44 | ndarray: RGB image that can be visualized.
45 | """
46 | assert flow.ndim == 3 and flow.shape[-1] == 2, flow.shape
47 | if color_wheel is None:
48 | color_wheel = make_color_wheel()
49 | assert color_wheel.ndim == 2 and color_wheel.shape[1] == 3
50 | num_bins = color_wheel.shape[0]
51 |
52 | dx = flow[:, :, 0].copy()
53 | dy = flow[:, :, 1].copy()
54 |
55 | ignore_inds = np.isnan(dx) | np.isnan(dy) | (np.abs(dx) > unknown_thr) | (np.abs(dy) > unknown_thr)
56 | dx[ignore_inds] = 0
57 | dy[ignore_inds] = 0
58 |
59 | rad = np.sqrt(dx**2 + dy**2)
60 | if np.any(rad > np.finfo(float).eps):
61 | max_rad = np.max(rad)
62 | dx /= max_rad
63 | dy /= max_rad
64 |
65 | [h, w] = dx.shape
66 |
67 | rad = np.sqrt(dx**2 + dy**2)
68 | angle = np.arctan2(-dy, -dx) / np.pi
69 |
70 | bin_real = (angle + 1) / 2 * (num_bins - 1)
71 | bin_left = np.floor(bin_real).astype(int)
72 | bin_right = (bin_left + 1) % num_bins
73 | w = (bin_real - bin_left.astype(np.float32))[..., None]
74 | flow_img = (1 - w) * color_wheel[bin_left, :] + w * color_wheel[bin_right, :]
75 | small_ind = rad <= 1
76 | flow_img[small_ind] = 1 - rad[small_ind, None] * (1 - flow_img[small_ind])
77 | flow_img[np.logical_not(small_ind)] *= 0.75
78 |
79 | flow_img[ignore_inds, :] = 0
80 |
81 | return flow_img
82 |
83 |
84 | def make_color_wheel(bins=None):
85 | """Build a color wheel.
86 |
87 | Args:
88 | bins(list or tuple, optional): Specify the number of bins for each
89 | color range, corresponding to six ranges: red -> yellow,
90 | yellow -> green, green -> cyan, cyan -> blue, blue -> magenta,
91 | magenta -> red. [15, 6, 4, 11, 13, 6] is used for default
92 | (see Middlebury).
93 | Returns:
94 | ndarray: Color wheel of shape (total_bins, 3).
95 | """
96 | if bins is None:
97 | bins = [15, 6, 4, 11, 13, 6]
98 | assert len(bins) == 6
99 |
100 | RY, YG, GC, CB, BM, MR = tuple(bins)
101 |
102 | ry = [1, np.arange(RY) / RY, 0]
103 | yg = [1 - np.arange(YG) / YG, 1, 0]
104 | gc = [0, 1, np.arange(GC) / GC]
105 | cb = [0, 1 - np.arange(CB) / CB, 1]
106 | bm = [np.arange(BM) / BM, 0, 1]
107 | mr = [1, 0, 1 - np.arange(MR) / MR]
108 |
109 | num_bins = RY + YG + GC + CB + BM + MR
110 |
111 | color_wheel = np.zeros((3, num_bins), dtype=np.float32)
112 |
113 | col = 0
114 | for i, color in enumerate([ry, yg, gc, cb, bm, mr]):
115 | for j in range(3):
116 | color_wheel[j, col : col + bins[i]] = color[j]
117 | col += bins[i]
118 |
119 | return color_wheel.T
120 |
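
A minimal sketch converting a synthetic flow field to an RGB image (not part of the repository; importing the module also pulls in mmcv, matplotlib and lib/vis_utils/image.py):

    # Hypothetical flow2rgb usage on a synthetic flow field (illustration only).
    import numpy as np
    from lib.vis_utils.optflow import flow2rgb, make_color_wheel

    h, w = 64, 64
    flow = np.zeros((h, w, 2), dtype=np.float32)
    flow[..., 0] = np.linspace(-1.0, 1.0, w)[None, :]  # horizontal motion varies left to right

    rgb = flow2rgb(flow, color_wheel=make_color_wheel())
    print(rgb.shape)  # (64, 64, 3), values in [0, 1]
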
--------------------------------------------------------------------------------
/output/catre/NOCS_REAL/aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e/model_final_wo_optim-82cf930e.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-DA-6D-Pose-Group/CATRE/89b1b375d38cf4cc2286912522628d2266d8c41b/output/catre/NOCS_REAL/aug05_kpsMS_r9d_catreDisR_shared_tspcl_convPerRot_scaleexp_120e/model_final_wo_optim-82cf930e.pth
--------------------------------------------------------------------------------
/preprocess/shape_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Originally written by https://github.com/mentian/object-deformnet/
3 | """
4 |
5 | import h5py
6 | import numpy as np
7 | import torch.utils.data as data
8 |
9 |
10 | class ShapeDataset(data.Dataset):
11 | def __init__(self, h5_file, mode, n_points=2048, augment=False):
12 | assert mode == "train" or mode == "val", 'Mode must be "train" or "val".'
13 | self.mode = mode
14 | self.n_points = n_points
15 | self.augment = augment
16 | # load data from h5py file
17 | with h5py.File(h5_file, "r") as f:
18 | self.length = f[self.mode].attrs["len"]
19 | self.data = f[self.mode]["data"][:]
20 | self.label = f[self.mode]["label"][:]
21 | # augmentation parameters
22 | self.sigma = 0.01
23 | self.clip = 0.02
24 | self.shift_range = 0.02
25 |
26 | def __len__(self):
27 | return self.length
28 |
29 | def __getitem__(self, index):
30 | xyz = self.data[index]
31 | label = self.label[index] - 1 # data saved indexed from 1
32 | # randomly downsample
33 | np_data = xyz.shape[0]
34 | assert np_data >= self.n_points, "Not enough points in shape."
35 | idx = np.random.choice(np_data, self.n_points)
36 | xyz = xyz[idx, :]
37 | # data augmentation
38 | if self.augment:
39 | jitter = np.clip(self.sigma * np.random.randn(self.n_points, 3), -self.clip, self.clip)
40 | xyz[:, :3] += jitter
41 | shift = np.random.uniform(-self.shift_range, self.shift_range, (1, 3))
42 | xyz[:, :3] += shift
43 | return xyz, label
44 |
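
A minimal loading sketch for ShapeDataset (not part of the repository; the h5 path below is a placeholder for whatever file the preprocessing step produced):

    # Hypothetical ShapeDataset usage (illustration only; the .h5 path is assumed).
    from torch.utils.data import DataLoader
    from preprocess.shape_dataset import ShapeDataset

    dataset = ShapeDataset("path/to/shape_data.h5", mode="train", n_points=2048, augment=True)
    loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
    xyz, label = next(iter(loader))  # xyz: (32, 2048, 3) points, label: (32,) zero-indexed class ids
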
--------------------------------------------------------------------------------
/ref/__init__.py:
--------------------------------------------------------------------------------
1 | from . import nocs
2 |
--------------------------------------------------------------------------------
/ref/cmra.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """This file includes necessary params, info."""
3 | import os
4 | import mmcv
5 | import os.path as osp
6 |
7 | import numpy as np
8 |
9 | # ---------------------------------------------------------------- #
10 | # ROOT PATH INFO
11 | # ---------------------------------------------------------------- #
12 | cur_dir = osp.abspath(osp.dirname(__file__))
13 | root_dir = osp.normpath(osp.join(cur_dir, ".."))
14 | # directory storing experiment data (result, model checkpoints, etc).
15 | output_dir = osp.join(root_dir, "output")
16 |
17 | data_root = osp.join(root_dir, "datasets")
18 |
19 | # ---------------------------------------------------------------- #
20 | # NOCS DATASET
21 | # ---------------------------------------------------------------- #
22 | dataset_root = osp.join(data_root, "NOCS/")
23 | train_dir = osp.join(dataset_root, "CAMERA/")
24 |
25 | model_dir = osp.join(dataset_root, "obj_models")
26 | mean_model_path = osp.join(model_dir, "cr_normed_mean_model_points_spd.pkl")
27 | train_model_path = osp.join(model_dir, "camera_train.pkl")
28 | test_model_path = osp.join(model_dir, "camera_val.pkl")
29 |
30 | # object info
31 | objects = ["bottle", "bowl", "camera", "can", "laptop", "mug"]
32 |
33 | obj2id = {"bottle": 1, "bowl": 2, "camera": 3, "can": 4, "laptop": 5, "mug": 6}
34 |
35 | obj_num = len(objects)
36 |
37 | id2obj = {_id: _name for _name, _id in obj2id.items()}
38 |
39 | # id2obj_camera = {1: "02876657", 2: "02880940", 3: "02942699", 4: "02946921", 5: "03642806", 6: "03797390"}
40 | # objects_camera = list(id2obj_camera.values())
41 | # obj2id_camera = {_name: _id for _id, _name in id2obj_camera.items()}
42 |
43 | # Camera info
44 | width = 640
45 | height = 480
46 | center = (height / 2, width / 2)
47 |
48 | intrinsics = np.array([[577.5, 0, 319.5], [0, 577.5, 239.5], [0, 0, 1]], dtype=np.float32)  # 3x3 K: fx = fy = 577.5, cx = 319.5, cy = 239.5
49 |
50 | mean_scale = {
51 | "bottle": 0.001 * np.array([81, 218.5, 80.25], dtype=np.float32),
52 | "bowl": 0.001 * np.array([168.75, 67.75, 168.75], dtype=np.float32),
53 | "camera": 0.001 * np.array([116.0, 121.75, 175.5], dtype=np.float32),
54 | "can": 0.001 * np.array([112.5, 188.25, 115.0], dtype=np.float32),
55 | "laptop": 0.001 * np.array([145.25, 111.25, 168.0], dtype=np.float32),
56 | "mug": 0.001 * np.array([167.5, 135.0, 124.25], dtype=np.float32),
57 | }
58 |
59 |
60 | def get_mean_bbox3d():
61 | mean_bboxes = {}
62 |     for key, value in mean_scale.items():
63 | minx, maxx = -value[0] / 2, value[0] / 2
64 | miny, maxy = -value[1] / 2, value[1] / 2
65 | minz, maxz = -value[2] / 2, value[2] / 2
66 |
67 | mean_bboxes[key] = np.array(
68 | [
69 | [maxx, maxy, maxz],
70 | [minx, maxy, maxz],
71 | [minx, miny, maxz],
72 | [maxx, miny, maxz],
73 | [maxx, maxy, minz],
74 | [minx, maxy, minz],
75 | [minx, miny, minz],
76 | [maxx, miny, minz],
77 | ],
78 | dtype=np.float32,
79 | )
80 | return mean_bboxes
81 |
82 |
83 | def get_sym_info(obj_name, mug_handle=1):
84 |     # Y axis points upwards, the x axis passes through the handle, and the z axis is orthogonal to both
85 | # return sym axis
86 | if obj_name == "bottle":
87 |         sym = np.array([0, 1, 0], dtype=int)  # np.int is removed in recent NumPy
88 | elif obj_name == "bowl":
89 |         sym = np.array([0, 1, 0], dtype=int)
90 | elif obj_name == "camera":
91 | sym = None
92 | elif obj_name == "can":
93 |         sym = np.array([0, 1, 0], dtype=int)
94 | elif obj_name == "laptop":
95 | sym = None
96 | elif obj_name == "mug":
97 | if mug_handle == 1:
98 | sym = None
99 | else:
100 |             sym = np.array([0, 1, 0], dtype=int)
101 | else:
102 |         raise NotImplementedError(f"No such object class: {obj_name}")
103 | return sym
104 |
105 |
106 | def get_fps_points():
107 | """key is inst_name generated by
108 | core/catre/tools/nocs/nocs_fps_sample.py."""
109 | fps_points_path = osp.join(model_dir, "cmra_fps_points_spd.pkl")
110 | assert osp.exists(fps_points_path), fps_points_path
111 | fps_dict = mmcv.load(fps_points_path)
112 | return fps_dict
113 |
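
A minimal sketch of the category metadata helpers above (not part of the repository; it assumes `ref/nocs.py` imports cleanly, since `ref/__init__.py` pulls it in):

    # Hypothetical usage of the CAMERA/NOCS reference info (illustration only).
    from ref import cmra

    print(cmra.obj2id["mug"])           # 6
    print(cmra.mean_scale["bottle"])    # mean x/y/z extent of the bottle category (meters)
    print(cmra.get_sym_info("bottle"))  # [0 1 0]: rotational symmetry about the Y axis
    print(cmra.get_sym_info("mug", mug_handle=0))  # without a visible handle, treated as symmetric
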
--------------------------------------------------------------------------------
/requirements/requirements.txt:
--------------------------------------------------------------------------------
1 | cython
2 | plyfile
3 |
4 | ujson # make json loading in cocoapi_nvidia faster
5 | pycocotools
6 |
7 | cffi
8 | ninja
9 | black
10 | docformatter
11 | setproctitle
12 | fastfunc
13 | meshplex
14 | OpenEXR
15 | vispy>=0.6.4
16 | yacs>=0.1.8
17 | tabulate
18 | pytest-runner
19 | pytest
20 | ipdb
21 | tqdm
22 | numba
23 | mmcv-full
24 | imagecorruptions
25 | pyassimp==4.1.3 # 4.1.4 will cause egl_renderer SegmentFault
26 | pypng
27 | imgaug>=0.4.0
28 | albumentations
29 | transforms3d
30 | # pyquaternion
31 | torchvision
32 | open3d
33 | fvcore
34 | tensorboardX
35 | einops
36 | pytorch3d
37 | # timm # pytorch-image-models
38 | # git+ssh://git@github.com/rwightman/pytorch-image-models.git # the latest timm
39 | glfw
40 | imageio
41 | imageio-ffmpeg
42 | PyOpenGL # >=3.1.5
43 | PyOpenGL_accelerate
44 | chardet
45 | h5py
46 |
47 | thop # https://github.com/Lyken17/pytorch-OpCounter
48 | loguru
49 |
50 | pytorch-lightning # tested for 1.6.0.dev0
51 | fairscale
52 | deepspeed
53 |
54 | # verified versions
55 | onnx==1.8.1
56 | onnxruntime==1.8.0
57 | onnx-simplifier==0.3.5
58 |
--------------------------------------------------------------------------------
/scripts/install_deps.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # some other dependencies
3 | set -x
4 | install=${1:-"all"}
5 |
6 | if test "$install" = "all"; then
7 | echo "Installing apt dependencies"
8 | sudo apt-get install -y libjpeg-dev zlib1g-dev
9 | sudo apt-get install -y libopenexr-dev
10 | sudo apt-get install -y openexr
11 | sudo apt-get install -y python3-dev
12 | sudo apt-get install -y libglfw3-dev libglfw3
13 | sudo apt-get install -y libglew-dev
14 | sudo apt-get install -y libassimp-dev
15 | sudo apt-get install -y libnuma-dev # for byteps
16 | sudo apt install -y clang
17 | ## for bop cpp renderer
18 | sudo apt install -y curl
19 | sudo apt install -y autoconf
20 | sudo apt-get install -y build-essential libtool
21 |
22 | ## for uncertainty pnp
23 | sudo apt-get install -y libeigen3-dev
24 | sudo apt-get install -y libgoogle-glog-dev
25 | sudo apt-get install -y libsuitesparse-dev
26 | sudo apt-get install -y libatlas-base-dev
27 |
28 | ## for nvdiffrast/egl
29 | sudo apt-get install -y --no-install-recommends \
30 | cmake curl pkg-config
31 | sudo apt-get install -y --no-install-recommends \
32 | libgles2 \
33 | libgl1-mesa-dev \
34 | libegl1-mesa-dev \
35 | libgles2-mesa-dev
36 | # (only available for Ubuntu >= 18.04)
37 | sudo apt-get install -y --no-install-recommends \
38 | libglvnd0 \
39 | libgl1 \
40 | libglx0 \
41 | libegl1 \
42 | libglvnd-dev
43 |
44 | sudo apt-get install -y libglew-dev
45 | # for GLEW, add this into ~/.bashrc
46 | # export LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH
47 | fi
48 |
49 | pip install -r requirements/requirements.txt
50 |
51 | pip uninstall -y pillow  # non-interactive; pillow-simd replaces it below
52 | CC="cc -mavx2" pip install -U --force-reinstall pillow-simd
53 |
--------------------------------------------------------------------------------