├── CatePoseEstimation ├── README.md ├── align.py ├── config.py ├── configs │ └── swin_tiny_patch4_window7_224_lite.yaml ├── data │ └── DOWNLOAD.sh ├── datasets │ ├── __pycache__ │ │ └── datasets.cpython-37.pyc │ ├── datasets.py │ └── transform.py ├── inference.py ├── networks │ ├── CrossAttention.py │ ├── SwinDRNet.py │ ├── SwinTransformer.py │ ├── UPerNet.py │ └── __pycache__ │ │ ├── CrossAttention.cpython-37.pyc │ │ ├── SwinDRNet.cpython-37.pyc │ │ ├── SwinTransformer.cpython-37.pyc │ │ └── UPerNet.cpython-37.pyc ├── pretrained_model │ └── DOWNLOAD.sh ├── requirments.txt ├── train.py ├── trainer.py └── utils │ ├── __pycache__ │ ├── aligning.cpython-37.pyc │ ├── api_utils.cpython-37.pyc │ ├── create_image_grid.cpython-37.pyc │ ├── loss_functions.cpython-37.pyc │ ├── metrics_depth_restoration.cpython-37.pyc │ ├── metrics_nocs.cpython-37.pyc │ └── metrics_sem_seg.cpython-37.pyc │ ├── aligning.py │ ├── api_utils.py │ ├── create_image_grid.py │ ├── loss_functions.py │ ├── metrics_depth_restoration.py │ ├── metrics_nocs.py │ └── metrics_sem_seg.py ├── DepthSensorSimulator ├── README.md ├── modify_material.py ├── renderer.py ├── requirement.txt ├── run_renderer.sh ├── stereo_matching.py ├── teaser_dreds.png └── texture │ └── texture_0.jpg ├── README.md ├── SwinDRNet ├── README.md ├── config.py ├── configs │ └── swin_tiny_patch4_window7_224_lite.yaml ├── data │ └── DOWNLOAD.sh ├── datasets │ ├── .DS_Store │ ├── datasets.py │ └── transform.py ├── images │ ├── .DS_Store │ └── SwinDRNet.png ├── inference.py ├── models │ └── DOWNLOAD.sh ├── networks │ ├── CrossAttention.py │ ├── SwinDRNet.py │ ├── SwinTransformer.py │ └── UPerNet.py ├── pretrained_model │ └── DOWNLOAD.sh ├── requirments.txt ├── train.py ├── trainer.py └── utils │ ├── api_utils.py │ ├── create_image_grid.py │ ├── loss_functions.py │ ├── metrics_depth_restoration.py │ └── metrics_sem_seg.py └── images └── teaser.png /CatePoseEstimation/README.md: -------------------------------------------------------------------------------- 1 | # SwinDRNet for Category-level Pose Estimation 2 | PyTorch code and weights of the SwinDRNet baseline for category-level pose estimation. 3 | ## System Dependencies 4 | ```bash 5 | $ sudo apt-get install libhdf5-10 libhdf5-serial-dev libhdf5-dev libhdf5-cpp-11 6 | $ sudo apt install libopenexr-dev zlib1g-dev openexr 7 | ``` 8 | ## Setup 9 | - ### Install pip dependencies 10 | We have tested on Ubuntu 20.04 with NVIDIA GeForce RTX 2080 and RTX 3090 GPUs, using Python 3.7. The code may work on other systems. Install the dependencies using pip: 11 | ```bash 12 | $ pip install -r requirments.txt 13 | ``` 14 | - ### Download dataset and models 15 | 16 | 1. Download the pre-trained SwinDRNet model and dataset. In the scripts below, be sure to comment out the files you do not want, as they are very large. Alternatively, you can download the files [manually](https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/). 17 | 18 | ```bash 19 | # Download DREDS and STD Dataset 20 | $ cd data 21 | $ bash DOWNLOAD.sh 22 | $ cd .. 23 | 24 | # Download the pretrained model 25 | $ cd pretrained_model 26 | $ bash DOWNLOAD.sh 27 | $ cd .. 28 | 29 | ``` 30 | 2. Model: We provide our pretrained pose estimation model [here](https://drive.google.com/file/d/1MqUIUhJYLljnoj66mjiyq36VxZHUYQ77/view?usp=sharing). Please download it to /results/ckpt/. 31 | 3. Extract the downloaded dataset and merge the train split of DREDS-CatKnown following the file structure below.
32 | ``` 33 | data 34 | ├── DREDS 35 | │   ├── DREDS-CatKnown 36 | │ │ ├── train 37 | │ │ │ ├── 00001 38 | │ │ │ └── ... 39 | │ │ ├── val 40 | │ │ │ ├── 01162 41 | │ │ │ └── ... 42 | │ │ └── test 43 | │ │ ├── 00000 44 | │ │ └── ... 45 | │   └── DREDS-CatNovel 46 | │ ├── 00029 47 | │ └── ... 48 | ├── STD 49 | │  ├── STD-CatKnown 50 | │ │ ├── test_0 51 | │ │ └── ... 52 | │   └── STD-CatNovel 53 | │ ├── test_novel_0-1 54 | │ └── ... 55 | └── cad_model 56 |  ├── syn_train 57 | │ ├── 00000000 58 | │ └── ... 59 |  ├── syn_test 60 | │ ├── 00000000 61 | │ └── ... 62 |  ├── real_cat_known 63 | │ ├── aeroplane 64 | │ └── ... 65 |  └── real_cat_novel 66 | ├── 0_trans_teapot 67 | └── ... 68 | ``` 69 | 70 | 71 | ## Training 72 | - Start training by running: 73 | ```bash 74 | # An example command for training 75 | $ python train.py --train_data_path PATH_DREDS_CatKnown_TrainSplit --val_data_path PATH_DREDS_CatKnown_ValSplit --val_obj_path PATH_DREDS_CatKnown_CADMODEL 76 | ``` 77 | 78 | ## Testing 79 | - Start testing by running: 80 | ```bash 81 | # An example command for testing 82 | $ python inference.py --val_data_type TYPE_OF_DATA --train_data_path PATH_DREDS_CatKnown_TrainSplit --val_data_path PATH_DREDS_CatKnown_TestSplit --val_obj_path PATH_DREDS_CatKnown_CADMODEL --val_depth_path PATH_VAL_DEPTH 83 | ``` 84 | 85 | ## Note 86 | We have fixed a bug in the IoU calculation of the [original NOCS code](https://github.com/hughw19/NOCS_CVPR2019/blob/78a31c2026a954add1a2711286ff45ce1603b8ab/utils.py#L252), as shown [here](https://github.com/PKU-EPIC/DREDS/blob/310a4921d8cf5a565d6c2547d98070db0e59c999/CatePoseEstimation/utils/metrics_nocs.py#L234), and re-evaluated the related results. The modification does not affect our conclusions in the paper. Thanks to [Liu et al.](https://github.com/THU-DA-6D-Pose-Group/CATRE#note) for their confirmation.
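For intuition about the metric touched by this fix, below is a minimal sketch of 3D IoU for the axis-aligned case only. It is illustrative rather than the project's implementation: the actual evaluation in `utils/metrics_nocs.py` handles rotated and scaled 3D bounding boxes.

```python
import numpy as np

def axis_aligned_3d_iou(box_a, box_b):
    """IoU of two axis-aligned 3D boxes, each given as a (min_xyz, max_xyz) pair."""
    (min_a, max_a), (min_b, max_b) = box_a, box_b
    # Overlap along each axis, clamped at zero when the boxes do not intersect
    inter_dims = np.clip(np.minimum(max_a, max_b) - np.maximum(min_a, min_b), 0.0, None)
    inter_vol = np.prod(inter_dims)
    vol_a = np.prod(np.asarray(max_a) - np.asarray(min_a))
    vol_b = np.prod(np.asarray(max_b) - np.asarray(min_b))
    union = vol_a + vol_b - inter_vol
    return float(inter_vol / union) if union > 0 else 0.0

# Two unit cubes overlapping by half along x: IoU = 0.5 / 1.5 = 1/3
print(axis_aligned_3d_iou((np.zeros(3), np.ones(3)),
                          (np.array([0.5, 0.0, 0.0]), np.array([1.5, 1.0, 1.0]))))
```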
-------------------------------------------------------------------------------- /CatePoseEstimation/align.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" 5 | from PIL import Image 6 | from utils.metrics_nocs import align, prepare_data_posefitting, draw_detections 7 | from datasets.datasets import exr_loader,load_meta 8 | syn_depth_path = '/data/sensor/data/real_data/test_0/0000_gt_depth.exr' 9 | nocs_path = '/data/sensor/data/real_data/test_0/0000_coord.png' 10 | mask_path = '/data/sensor/data/real_data/test_0/0000_mask.png' 11 | meta_path = '/data/sensor/data/real_data/test_0/0000_meta.txt' 12 | obj_dir = '/data/sensor/data/cad_model/real_cat_known' 13 | synset_names = ['other', # 0 14 | 'bottle', # 1 15 | 'bowl', # 2 16 | 'camera', # 3 17 | 'can', # 4 18 | 'car', # 5 19 | 'mug', # 6 20 | 'aeroplane', # 7 21 | 'BG', # 8 22 | ] 23 | intrinsics = np.zeros((3,3)) 24 | 25 | img_h = 720 26 | img_w = 1280 27 | fx = 918.295227050781 28 | fy = 917.5439453125 29 | 30 | cx = img_w * 0.5 - 0.5 31 | cy = img_h * 0.5 - 0.5 32 | camera_params = { 33 | 'fx': fx, 34 | 'fy': fy, 35 | 'cx': cx, 36 | 'cy': cy, 37 | 'yres': img_h, 38 | 'xres': img_w, 39 | } 40 | hw = (640 / img_w, 360 / img_h) 41 | camera_params['fx'] *= hw[0] 42 | camera_params['fy'] *= hw[1] 43 | camera_params['cx'] *= hw[0] 44 | camera_params['cy'] *= hw[1] 45 | camera_params['xres'] *= hw[0] 46 | camera_params['yres'] *= hw[1] 47 | intrinsics[0,0] = camera_params['fx'] 48 | intrinsics[0,2] = camera_params['cx'] 49 | intrinsics[1,1] = camera_params['fy'] 50 | intrinsics[1,2] = camera_params['cy'] 51 | intrinsics[2,2] = 1.0 52 | 53 | 54 | _syn_depth = cv2.imread(syn_depth_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) 55 | if len(_syn_depth.shape) == 3: 56 | _syn_depth = _syn_depth[:, :, 0] 57 | coords = Image.open(nocs_path).convert('RGB') 58 | coords = np.array(coords) / 255. 
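# Note: the camera intrinsics above were rescaled from the original 1280x720 frame to the
# 640x360 resolution used for pose fitting, and the NOCS coordinate map (stored as an
# 8-bit RGB PNG) is divided by 255 so that each channel becomes a normalized object
# coordinate in [0, 1].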
59 | 60 | if mask_path.split('.')[-1] == 'exr': 61 | _mask = exr_loader(mask_path, ndim=1) 62 | else: 63 | _mask = Image.open(mask_path) 64 | _mask = np.array(_mask) 65 | if mask_path.split('.')[-1] == 'exr': 66 | _mask = np.array(_mask * 255, dtype=np.int32) 67 | _meta = load_meta(meta_path) 68 | 69 | _mask_sem = np.full(_mask.shape, 8) #, 0) 70 | _scale = np.ones((10,3)) #(class_num+1,3) 71 | for i in range(len(_meta)): 72 | _mask_sem[_mask == _meta[i]["index"]] = _meta[i]["label"] #1 73 | if _meta[i]["instance_folder"] !=" " : 74 | 75 | bbox_file = os.path.join(obj_dir,_meta[i]["instance_folder"] ,_meta[i]["name"],'bbox.txt') 76 | bbox = np.loadtxt(bbox_file) 77 | _scale[_meta[i]["label"]] = bbox[0, :] - bbox[1, :] 78 | else : 79 | _scale[_meta[i]["label"]] = np.ones((3)) 80 | _scale[_meta[i]["label"]] /= np.linalg.norm(_scale[_meta[i]["label"]]) 81 | 82 | 83 | gt_class_ids , gt_scores , gt_masks ,gt_coords ,\ 84 | gt_boxes = prepare_data_posefitting(_mask_sem,coords) 85 | 86 | result = {} 87 | result['gt_RTs'], scales, error_message, _ = align(gt_class_ids, 88 | gt_masks, 89 | gt_coords, 90 | _syn_depth, 91 | intrinsics, 92 | synset_names) 93 | 94 | if 1: 95 | output_path = 'tmp' 96 | draw_rgb = False 97 | save_dir =os.path.join(output_path ,'save_{}'.format(i)) 98 | if not os.path.exists(save_dir) : 99 | os.mkdir(save_dir) 100 | data = 'camera' 101 | result['gt_handle_visibility'] = np.ones_like(gt_class_ids) 102 | rgb_path = '/data/sensor/data/real_data/test_0/0000_color.png' 103 | _rgb = Image.open(rgb_path).convert('RGB') 104 | _rgb = _rgb.resize((640,360)) 105 | _rgb = np.array(_rgb) 106 | 107 | draw_detections(_rgb, save_dir, data, 1, intrinsics, synset_names, draw_rgb, 108 | gt_boxes, gt_class_ids, gt_masks, gt_coords, result['gt_RTs'], scales, np.ones(gt_boxes.shape[0]), 109 | gt_boxes, gt_class_ids, gt_masks, gt_coords, result['gt_RTs'], np.ones(gt_boxes.shape[0]), scales) -------------------------------------------------------------------------------- /CatePoseEstimation/config.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # --------------------------------------------------------' 7 | 8 | import os 9 | import yaml 10 | from yacs.config import CfgNode as CN 11 | 12 | _C = CN() 13 | 14 | # Base config files 15 | _C.BASE = [''] 16 | 17 | # ----------------------------------------------------------------------------- 18 | # Data settings 19 | # ----------------------------------------------------------------------------- 20 | _C.DATA = CN() 21 | # Batch size for a single GPU, could be overwritten by command line argument 22 | _C.DATA.BATCH_SIZE = 128 23 | # Path to dataset, could be overwritten by command line argument 24 | _C.DATA.DATA_PATH = '' 25 | # Dataset name 26 | _C.DATA.DATASET = 'imagenet' 27 | # Input image size 28 | _C.DATA.IMG_SIZE = 224 29 | # Interpolation to resize image (random, bilinear, bicubic) 30 | _C.DATA.INTERPOLATION = 'bicubic' 31 | # Use zipped dataset instead of folder dataset 32 | # could be overwritten by command line argument 33 | _C.DATA.ZIP_MODE = False 34 | # Cache Data in Memory, could be overwritten by command line argument 35 | _C.DATA.CACHE_MODE = 'part' 36 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. 
37 | _C.DATA.PIN_MEMORY = True 38 | # Number of data loading threads 39 | _C.DATA.NUM_WORKERS = 8 40 | 41 | # ----------------------------------------------------------------------------- 42 | # Model settings 43 | # ----------------------------------------------------------------------------- 44 | _C.MODEL = CN() 45 | # Model type 46 | _C.MODEL.TYPE = 'swin' 47 | # Model name 48 | _C.MODEL.NAME = 'swin_tiny_patch4_window7_224' 49 | # Checkpoint to resume, could be overwritten by command line argument 50 | _C.MODEL.PRETRAIN_CKPT = './pretrained_ckpt/swin_tiny_patch4_window7_224.pth' 51 | _C.MODEL.RESUME = '' 52 | # Number of classes, overwritten in data preparation 53 | _C.MODEL.NUM_CLASSES = 1000 54 | # Dropout rate 55 | _C.MODEL.DROP_RATE = 0.0 56 | # Drop path rate 57 | _C.MODEL.DROP_PATH_RATE = 0.1 58 | # Label Smoothing 59 | _C.MODEL.LABEL_SMOOTHING = 0.1 60 | 61 | # Swin Transformer parameters 62 | _C.MODEL.SWIN = CN() 63 | _C.MODEL.SWIN.PATCH_SIZE = 4 64 | _C.MODEL.SWIN.IN_CHANS = 3 65 | _C.MODEL.SWIN.EMBED_DIM = 96 66 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] 67 | _C.MODEL.SWIN.DECODER_DEPTHS = [2, 2, 6, 2] 68 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] 69 | _C.MODEL.SWIN.WINDOW_SIZE = 7 70 | _C.MODEL.SWIN.MLP_RATIO = 4. 71 | _C.MODEL.SWIN.QKV_BIAS = True 72 | _C.MODEL.SWIN.QK_SCALE = None 73 | _C.MODEL.SWIN.APE = False 74 | _C.MODEL.SWIN.PATCH_NORM = True 75 | _C.MODEL.SWIN.FINAL_UPSAMPLE= "expand_first" 76 | 77 | # ----------------------------------------------------------------------------- 78 | # Training settings 79 | # ----------------------------------------------------------------------------- 80 | _C.TRAIN = CN() 81 | _C.TRAIN.START_EPOCH = 0 82 | _C.TRAIN.EPOCHS = 300 83 | _C.TRAIN.WARMUP_EPOCHS = 20 84 | _C.TRAIN.WEIGHT_DECAY = 0.05 85 | _C.TRAIN.BASE_LR = 5e-4 86 | _C.TRAIN.WARMUP_LR = 5e-7 87 | _C.TRAIN.MIN_LR = 5e-6 88 | # Clip gradient norm 89 | _C.TRAIN.CLIP_GRAD = 5.0 90 | # Auto resume from latest checkpoint 91 | _C.TRAIN.AUTO_RESUME = True 92 | # Gradient accumulation steps 93 | # could be overwritten by command line argument 94 | _C.TRAIN.ACCUMULATION_STEPS = 0 95 | # Whether to use gradient checkpointing to save memory 96 | # could be overwritten by command line argument 97 | _C.TRAIN.USE_CHECKPOINT = False 98 | 99 | # LR scheduler 100 | _C.TRAIN.LR_SCHEDULER = CN() 101 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine' 102 | # Epoch interval to decay LR, used in StepLRScheduler 103 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 104 | # LR decay rate, used in StepLRScheduler 105 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 106 | 107 | # Optimizer 108 | _C.TRAIN.OPTIMIZER = CN() 109 | _C.TRAIN.OPTIMIZER.NAME = 'adamw' 110 | # Optimizer Epsilon 111 | _C.TRAIN.OPTIMIZER.EPS = 1e-8 112 | # Optimizer Betas 113 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) 114 | # SGD momentum 115 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 116 | 117 | # ----------------------------------------------------------------------------- 118 | # Augmentation settings 119 | # ----------------------------------------------------------------------------- 120 | _C.AUG = CN() 121 | # Color jitter factor 122 | _C.AUG.COLOR_JITTER = 0.4 123 | # Use AutoAugment policy. 
"v0" or "original" 124 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' 125 | # Random erase prob 126 | _C.AUG.REPROB = 0.25 127 | # Random erase mode 128 | _C.AUG.REMODE = 'pixel' 129 | # Random erase count 130 | _C.AUG.RECOUNT = 1 131 | # Mixup alpha, mixup enabled if > 0 132 | _C.AUG.MIXUP = 0.8 133 | # Cutmix alpha, cutmix enabled if > 0 134 | _C.AUG.CUTMIX = 1.0 135 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set 136 | _C.AUG.CUTMIX_MINMAX = None 137 | # Probability of performing mixup or cutmix when either/both is enabled 138 | _C.AUG.MIXUP_PROB = 1.0 139 | # Probability of switching to cutmix when both mixup and cutmix enabled 140 | _C.AUG.MIXUP_SWITCH_PROB = 0.5 141 | # How to apply mixup/cutmix params. Per "batch", "pair", or "elem" 142 | _C.AUG.MIXUP_MODE = 'batch' 143 | 144 | # ----------------------------------------------------------------------------- 145 | # Testing settings 146 | # ----------------------------------------------------------------------------- 147 | _C.TEST = CN() 148 | # Whether to use center crop when testing 149 | _C.TEST.CROP = True 150 | 151 | # ----------------------------------------------------------------------------- 152 | # Misc 153 | # ----------------------------------------------------------------------------- 154 | # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2') 155 | # overwritten by command line argument 156 | _C.AMP_OPT_LEVEL = '' 157 | # Path to output folder, overwritten by command line argument 158 | _C.OUTPUT = '' 159 | # Tag of experiment, overwritten by command line argument 160 | _C.TAG = 'default' 161 | # Frequency to save checkpoint 162 | _C.SAVE_FREQ = 1 163 | # Frequency to logging info 164 | _C.PRINT_FREQ = 10 165 | # Fixed random seed 166 | _C.SEED = 0 167 | # Perform evaluation only, overwritten by command line argument 168 | _C.EVAL_MODE = False 169 | # Test throughput only, overwritten by command line argument 170 | _C.THROUGHPUT_MODE = False 171 | # local rank for DistributedDataParallel, given by command line argument 172 | _C.LOCAL_RANK = 0 173 | 174 | 175 | def _update_config_from_file(config, cfg_file): 176 | config.defrost() 177 | with open(cfg_file, 'r') as f: 178 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) 179 | 180 | for cfg in yaml_cfg.setdefault('BASE', ['']): 181 | if cfg: 182 | _update_config_from_file( 183 | config, os.path.join(os.path.dirname(cfg_file), cfg) 184 | ) 185 | print('=> merge config from {}'.format(cfg_file)) 186 | config.merge_from_file(cfg_file) 187 | config.freeze() 188 | 189 | 190 | def update_config(config, args): 191 | _update_config_from_file(config, args.cfg) 192 | 193 | config.defrost() 194 | if args.opts: 195 | config.merge_from_list(args.opts) 196 | 197 | # merge from specific arguments 198 | if args.batch_size: 199 | config.DATA.BATCH_SIZE = args.batch_size 200 | if args.zip: 201 | config.DATA.ZIP_MODE = True 202 | if args.cache_mode: 203 | config.DATA.CACHE_MODE = args.cache_mode 204 | if args.resume: 205 | config.MODEL.RESUME = args.resume 206 | if args.accumulation_steps: 207 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps 208 | if args.use_checkpoint: 209 | config.TRAIN.USE_CHECKPOINT = True 210 | if args.amp_opt_level: 211 | config.AMP_OPT_LEVEL = args.amp_opt_level 212 | if args.tag: 213 | config.TAG = args.tag 214 | if args.eval: 215 | config.EVAL_MODE = True 216 | if args.throughput: 217 | config.THROUGHPUT_MODE = True 218 | 219 | config.freeze() 220 | 221 | 222 | def get_config(args): 223 | """Get a yacs CfgNode object 
with default values.""" 224 | # Return a clone so that the defaults will not be altered 225 | # This is for the "local variable" use pattern 226 | config = _C.clone() 227 | update_config(config, args) 228 | 229 | return config 230 | -------------------------------------------------------------------------------- /CatePoseEstimation/configs/swin_tiny_patch4_window7_224_lite.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_tiny_patch4_window7_224 4 | DROP_PATH_RATE: 0.2 5 | PRETRAIN_CKPT: "pretrain_model/swin_tiny_patch4_window7_224.pth" 6 | SWIN: 7 | FINAL_UPSAMPLE: "expand_first" 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 6, 2 ] 10 | DECODER_DEPTHS: [ 2, 2, 2, 1] 11 | NUM_HEADS: [ 3, 6, 12, 24 ] 12 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /CatePoseEstimation/data/DOWNLOAD.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script makes it easy to download DREDS/STD dataset. 4 | # Files are at https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/ 5 | 6 | # Comment out any data you do not want. 7 | 8 | echo 'Warning: Files are *very* large. Be sure to comment out any files you do not want.' 9 | 10 | 11 | #----- DREDS Dataset ----------------------------------- 12 | mkdir DREDS 13 | cd DREDS 14 | mkdir DREDS-CatKnown 15 | cd DREDS-CatKnown 16 | 17 | wget https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/DREDS-CatKnown/test/test.gz # DREDS-CatKnown-Test (73.4G) 18 | wget https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/DREDS-CatKnown/val/val.tar.gz # DREDS-CatKnown-Val (10.5G) 19 | 20 | wget https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/DREDS-CatKnown/train/train_part0.tar.gz # DREDS-CatKnown-Train-Part0 (74.2G) 21 | wget https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/DREDS-CatKnown/train/train_part1.tar.gz # DREDS-CatKnown-Train-Part1 (73.5G) 22 | wget https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/DREDS-CatKnown/train/train_part2.tar.gz # DREDS-CatKnown-Train-Part2 (73.7G) 23 | wget https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/DREDS-CatKnown/train/train_part3.tar.gz # DREDS-CatKnown-Train-Part3 (73.4G) 24 | wget https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/DREDS-CatKnown/train/train_part4.tar.gz # DREDS-CatKnown-Train-Part4 (73.4G) 25 | 26 | cd .. 27 | 28 | wget https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/DREDS-CatNovel/DREDS-CatNovel.tar.gz # DREDS-CatNovel (45.4G) 29 | 30 | cd .. 31 | 32 | 33 | #----- STD Dataset ----------------------------------- 34 | mkdir STD 35 | cd STD 36 | 37 | wget https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/STD-CatKnown/STD-CatKnown.tar.gz # STD-CatKnown () 38 | wget https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/STD-CatNovel/STD-CatNovel.tar.gz # STD-CatNovel () 39 | 40 | cd .. 
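Once the archives above have been downloaded and extracted, the directory layout should match the tree in the CatePoseEstimation README. A minimal sanity-check sketch is given below; it assumes everything was extracted under `./data` (adjust `data_root` to your setup), and note that the CAD models appear in the README tree but are not fetched by this script.

```python
import os

data_root = "data"  # adjust to wherever the archives were extracted
expected = [
    "DREDS/DREDS-CatKnown/train",
    "DREDS/DREDS-CatKnown/val",
    "DREDS/DREDS-CatKnown/test",
    "DREDS/DREDS-CatNovel",
    "STD/STD-CatKnown",
    "STD/STD-CatNovel",
    "cad_model/syn_train",
    "cad_model/syn_test",
]

# Print one line per expected directory so missing pieces are easy to spot
for rel in expected:
    path = os.path.join(data_root, rel)
    print(("ok      " if os.path.isdir(path) else "MISSING ") + path)
```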
-------------------------------------------------------------------------------- /CatePoseEstimation/datasets/__pycache__/datasets.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-EPIC/DREDS/1b0ac3091dbf0af91a48ce28d000fb4e2fe31350/CatePoseEstimation/datasets/__pycache__/datasets.cpython-37.pyc -------------------------------------------------------------------------------- /CatePoseEstimation/inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import numpy as np 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | from networks.SwinDRNet import SwinDRNet 9 | from trainer import SwinDRNetTrainer 10 | from config import get_config 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--mask_transparent', action='store_true', default=True, help='material mask') 15 | parser.add_argument('--mask_specular', action='store_true', default=True, help='material mask') 16 | parser.add_argument('--mask_diffuse', action='store_true', default=True, help='material mask') 17 | 18 | parser.add_argument('--train_data_path', type=str, 19 | default='/data/DREDS/DREDS-CatKnown/train', help='root dir for training dataset') 20 | parser.add_argument('--train_obj_path', type=str, 21 | default='/data/cad_model/syn_train', help='root dir for obj') 22 | 23 | parser.add_argument('--val_data_path', type=str, 24 | default='/data/DREDS/DREDS-CatKnown/test', help='root dir for data') 25 | parser.add_argument('--val_data_type', type=str, 26 | default='sim', help='type of val dataset') 27 | parser.add_argument('--val_obj_path', type=str, 28 | default='/data/cad_model/syn_test', help='root dir for obj') 29 | parser.add_argument('--val_depth_path', type=str, 30 | default='/data/DREDS/DREDS-CatKnown/test', help='root dir for depth') 31 | 32 | 33 | 34 | 35 | 36 | parser.add_argument('--output_dir', type=str, 37 | default='results/inference', help='output dir') 38 | parser.add_argument('--decode_mode', type=str, 39 | default='multi_head', help='Select encode mode') 40 | parser.add_argument('--checkpoint_save_path', type=str, 41 | default='results/inference', help='Choose a path to save checkpoints') 42 | 43 | 44 | 45 | parser.add_argument('--val_interation_interval', type=int, 46 | default=5000, help='The iteration interval to perform validation') 47 | 48 | parser.add_argument('--percentageDataForTraining', type=float, 49 | default=1.0, help='The percentage of full training data for training') 50 | parser.add_argument('--percentageDataForVal', type=float, 51 | default=1.0, help='The percentage of full training data for training') 52 | 53 | parser.add_argument('--num_classes', type=int, 54 | default=9, help='output channel of network') 55 | parser.add_argument('--max_epochs', type=int, default=20, 56 | help='maximum epoch number to train') 57 | parser.add_argument('--batch_size', type=int, default=8, 58 | help='batch_size per gpu') 59 | parser.add_argument('--n_gpu', type=int, default=1, help='total gpu') 60 | parser.add_argument('--deterministic', type=int, default=1, 61 | help='whether use deterministic training') 62 | parser.add_argument('--base_lr', type=float, default=0.0001, 63 | help='segmentation network learning rate') 64 | parser.add_argument('--img_size', type=int, 65 | default=224, help='input patch size of network input') 66 | parser.add_argument('--seed', type=int, 67 | default=1234, 
help='random seed') 68 | 69 | parser.add_argument('--cfg', type=str, default="configs/swin_tiny_patch4_window7_224_lite.yaml", metavar="FILE", help='path to config file', ) 70 | parser.add_argument( 71 | "--opts", 72 | help="Modify config options by adding 'KEY VALUE' pairs. ", 73 | default=None, 74 | nargs='+', 75 | ) 76 | parser.add_argument('--zip', action='store_true', default=True, help='use zipped dataset instead of folder dataset') 77 | parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'], 78 | help='no: no cache, ' 79 | 'full: cache all data, ' 80 | 'part: sharding the dataset into nonoverlapping pieces and only cache one piece') 81 | parser.add_argument('--resume',type=str, default='./output-1/epoch_149.pth', help='resume from checkpoint') 82 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") 83 | parser.add_argument('--use-checkpoint', action='store_true', 84 | help="whether to use gradient checkpointing to save memory") 85 | parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'], 86 | help='mixed precision opt level, if O0, no amp is used') 87 | parser.add_argument('--tag', help='tag of experiment') 88 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 89 | parser.add_argument('--throughput', action='store_true', help='Test throughput only') 90 | 91 | 92 | args = parser.parse_args() 93 | config = get_config(args) 94 | 95 | os.environ["CUDA_VISIBLE_DEVICES"] = "5" 96 | device_list = [0] 97 | model_path = "results/ckpt/checkpoint-iter-00150000.pth" 98 | 99 | 100 | if __name__ == "__main__": 101 | if not args.deterministic: 102 | cudnn.benchmark = True 103 | cudnn.deterministic = False 104 | else: 105 | cudnn.benchmark = False 106 | cudnn.deterministic = True 107 | 108 | random.seed(args.seed) 109 | np.random.seed(args.seed) 110 | torch.manual_seed(args.seed) 111 | torch.cuda.manual_seed(args.seed) 112 | 113 | if not os.path.exists(args.output_dir): 114 | os.makedirs(args.output_dir) 115 | 116 | net = SwinDRNet(config, img_size=args.img_size, num_classes=args.num_classes).cuda() 117 | trainer = SwinDRNetTrainer 118 | _trainer = trainer(args, net, device_list, model_path) 119 | 120 | _trainer.inference() 121 | 122 | -------------------------------------------------------------------------------- /CatePoseEstimation/networks/SwinDRNet.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import copy 7 | import logging 8 | import math 9 | import warnings 10 | import torch.nn.functional as F 11 | 12 | from os.path import join as pjoin 13 | 14 | import torch 15 | import torch.nn as nn 16 | import numpy as np 17 | from torch.autograd import gradcheck, Variable 18 | 19 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm 20 | from torch.nn.modules.utils import _pair 21 | from scipy import ndimage 22 | from .SwinTransformer import SwinTransformerSys 23 | 24 | from .UPerNet import UPerHead, FCNHead 25 | from .CrossAttention import CrossAttention 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | class SwinDRNet(nn.Module): 31 | """ SwinDRNet. 
32 | A PyTorch impl of SwinDRNet, a depth restoration network proposed in: 33 | `Domain Randomization-Enhanced Depth Simulation and Restoration for 34 | Perceiving and Grasping Specular and Transparent Objects' (ECCV2022) 35 | """ 36 | 37 | def __init__(self, config, img_size=224, num_classes=3): 38 | super(SwinDRNet, self).__init__() 39 | self.num_classes = num_classes 40 | self.config = config 41 | self.img_size = img_size 42 | 43 | self.backbone_rgb_branch = SwinTransformerSys(img_size=config.DATA.IMG_SIZE, 44 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 45 | in_chans=config.MODEL.SWIN.IN_CHANS, 46 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 47 | depths=config.MODEL.SWIN.DEPTHS, 48 | num_heads=config.MODEL.SWIN.NUM_HEADS, 49 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 50 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 51 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 52 | qk_scale=config.MODEL.SWIN.QK_SCALE, 53 | drop_rate=config.MODEL.DROP_RATE, 54 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 55 | ape=config.MODEL.SWIN.APE, 56 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 57 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 58 | self.backbone_xyz_branch = SwinTransformerSys(img_size=config.DATA.IMG_SIZE, 59 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 60 | in_chans=config.MODEL.SWIN.IN_CHANS, 61 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 62 | depths=config.MODEL.SWIN.DEPTHS, 63 | num_heads=config.MODEL.SWIN.NUM_HEADS, 64 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 65 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 66 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 67 | qk_scale=config.MODEL.SWIN.QK_SCALE, 68 | drop_rate=config.MODEL.DROP_RATE, 69 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 70 | ape=config.MODEL.SWIN.APE, 71 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 72 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 73 | 74 | # self.decode_head_sem_seg = UPerHead(num_classes=self.num_classes, img_size = self.img_size) 75 | # self.decode_head_coord = UPerHead(num_classes=3, img_size = self.img_size) 76 | 77 | self.decode_head_sem_seg = UPerHead(num_classes=self.num_classes, img_size = self.img_size) 78 | self.decode_head_coord = UPerHead(num_classes=3, img_size = self.img_size) 79 | self.decode_head_depth_restoration = UPerHead(num_classes=1, img_size = self.img_size) 80 | self.decode_head_confidence = UPerHead(num_classes=2, img_size = self.img_size) 81 | 82 | self.cross_attention_0 = CrossAttention(in_channel=96, depth=config.MODEL.SWIN.DEPTHS[0], num_heads=config.MODEL.SWIN.NUM_HEADS[0]) 83 | self.cross_attention_1 = CrossAttention(in_channel=192, depth=config.MODEL.SWIN.DEPTHS[1], num_heads=config.MODEL.SWIN.NUM_HEADS[1]) 84 | self.cross_attention_2 = CrossAttention(in_channel=384, depth=config.MODEL.SWIN.DEPTHS[2], num_heads=config.MODEL.SWIN.NUM_HEADS[2]) 85 | self.cross_attention_3 = CrossAttention(in_channel=768, depth=config.MODEL.SWIN.DEPTHS[3], num_heads=config.MODEL.SWIN.NUM_HEADS[3]) 86 | 87 | self.softmax = nn.Softmax(dim=1) 88 | 89 | 90 | def forward(self, rgb, xyz): 91 | """Forward function.""" 92 | 93 | rgb = rgb.repeat(1,3,1,1) if rgb.size()[1] == 1 else rgb # B, C, H, W 94 | xyz = xyz.repeat(1,3,1,1) if xyz.size()[1] == 1 else xyz # B, C, H, W 95 | 96 | # depth = torch.unsqueeze(xyz[:, 2, :, :], 1) 97 | # depth = depth.repeat(1, 3, 1, 1) 98 | 99 | input_org_shape = rgb.shape[2:] 100 | rgb_feature = self.backbone_rgb_branch(rgb) 101 | xyz_feature = self.backbone_xyz_branch(xyz) 102 | 103 | shortcut = torch.unsqueeze(xyz[:, 2, :, :], 1) 104 | 105 | # fusion 106 | 107 | x = [] 108 | out = 
self.cross_attention_0(tuple([rgb_feature[0], xyz_feature[0]])) # [B, 96, 56, 56] 109 | x.append(out) 110 | out = self.cross_attention_1(tuple([rgb_feature[1], xyz_feature[1]])) # [B, 192, 28, 28] 111 | x.append(out) 112 | out = self.cross_attention_2(tuple([rgb_feature[2], xyz_feature[2]])) # [B, 384, 14, 14] 113 | x.append(out) 114 | out = self.cross_attention_3(tuple([rgb_feature[3], xyz_feature[3]])) # [B, 768, 7, 7] 115 | x.append(out) 116 | # pred_sem_seg = self.decode_head_sem_seg(x, input_org_shape) 117 | # pred_coord = self.decode_head_coord(x, input_org_shape) 118 | pred_sem_seg = self.decode_head_sem_seg(x, input_org_shape) 119 | pred_coord = self.decode_head_coord(x, input_org_shape) 120 | 121 | return pred_sem_seg, pred_coord 122 | 123 | 124 | def init_weights(self, pretrained=None): 125 | """Initialize the weights in backbone and heads. 126 | Args: 127 | pretrained (str, optional): Path to pre-trained weights. 128 | Defaults to None. 129 | """ 130 | self.backbone_rgb_branch.init_weights(pretrained=pretrained) 131 | self.backbone_xyz_branch.init_weights(pretrained=pretrained) 132 | self.decode_head_confidence.init_weights() 133 | self.decode_head_depth_restoration.init_weights() 134 | self.cross_attention_0.init_weights() 135 | self.cross_attention_1.init_weights() 136 | self.cross_attention_2.init_weights() 137 | self.cross_attention_3.init_weights() 138 | # self.decode_head_sem_seg.init_weights() 139 | # self.decode_head_coord.init_weights() 140 | -------------------------------------------------------------------------------- /CatePoseEstimation/networks/UPerNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import warnings 5 | 6 | 7 | def resize(input, 8 | size=None, 9 | scale_factor=None, 10 | mode='nearest', 11 | align_corners=None, 12 | warning=True): 13 | if warning: 14 | if size is not None and align_corners: 15 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 16 | output_h, output_w = tuple(int(x) for x in size) 17 | if output_h > input_h or output_w > output_h: 18 | if ((output_h > 1 and output_w > 1 and input_h > 1 19 | and input_w > 1) and (output_h - 1) % (input_h - 1) 20 | and (output_w - 1) % (input_w - 1)): 21 | warnings.warn( 22 | f'When align_corners={align_corners}, ' 23 | 'the output would more aligned if ' 24 | f'input size {(input_h, input_w)} is `x+1` and ' 25 | f'out size {(output_h, output_w)} is `nx+1`') 26 | if isinstance(size, torch.Size): 27 | size = tuple(int(x) for x in size) 28 | return F.interpolate(input, size, scale_factor, mode, align_corners) 29 | 30 | 31 | class PPM(nn.ModuleList): 32 | """Pooling Pyramid Module used in PSPNet. 33 | 34 | Args: 35 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid 36 | Module. 37 | in_channels (int): Input channels. 38 | channels (int): Channels after modules, before conv_seg. 39 | conv_cfg (dict|None): Config of conv layers. 40 | norm_cfg (dict|None): Config of norm layers. 41 | act_cfg (dict): Config of activation layers. 42 | align_corners (bool): align_corners argument of F.interpolate. 
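        Example (an illustrative sketch; the channel and spatial sizes follow the Swin-T features used elsewhere in this repo):
            ppm = PPM(pool_scales=(1, 2, 3, 6), in_channels=768,
                      channels=512, align_corners=False).eval()  # eval(): untrained BatchNorm accepts a single sample
            outs = ppm(torch.randn(1, 768, 7, 7))  # list of 4 tensors, each [1, 512, 7, 7]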
43 | """ 44 | 45 | def __init__(self, pool_scales, in_channels, channels, align_corners): 46 | super(PPM, self).__init__() 47 | self.pool_scales = pool_scales 48 | self.align_corners = align_corners 49 | self.in_channels = in_channels 50 | self.channels = channels 51 | for pool_scale in pool_scales: 52 | self.append( 53 | nn.Sequential( 54 | nn.AdaptiveAvgPool2d(pool_scale), 55 | ConvModule( 56 | self.in_channels, 57 | self.channels, 58 | 1))) 59 | 60 | def forward(self, x): 61 | """Forward function.""" 62 | ppm_outs = [] 63 | for ppm in self: 64 | ppm_out = ppm(x) 65 | upsampled_ppm_out = resize( 66 | ppm_out, 67 | size=x.size()[2:], 68 | mode='bilinear', 69 | align_corners=self.align_corners) 70 | ppm_outs.append(upsampled_ppm_out) 71 | return ppm_outs 72 | 73 | 74 | class ConvModule(nn.Module): 75 | def __init__(self, 76 | in_channels, 77 | out_channels, 78 | kernel_size, 79 | stride=1, 80 | padding=0, 81 | dilation=1, 82 | groups=1, 83 | bias=False, 84 | inplace=True, 85 | with_spectral_norm=False, 86 | padding_mode='zeros', 87 | order=('conv', 'norm', 'act')): 88 | super().__init__() 89 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode) 90 | self.bn = nn.BatchNorm2d(out_channels) 91 | self.activate = nn.ReLU(inplace) 92 | self.with_spectral_norm = with_spectral_norm 93 | official_padding_mode = ['zeros', 'circular'] 94 | self.with_explicit_padding = padding_mode not in official_padding_mode 95 | self.order = order 96 | assert isinstance(self.order, tuple) and len(self.order) == 3 97 | assert set(order) == set(['conv', 'norm', 'act']) 98 | 99 | def forward(self, x, activate=True, norm=True): 100 | for layer in self.order: 101 | if layer == 'conv': 102 | #if self.with_explicit_padding: 103 | # x = self.padding_layer(x) 104 | x = self.conv(x) 105 | elif layer == 'norm' and norm: 106 | x = self.bn(x) 107 | elif layer == 'act' and activate: 108 | x = self.activate(x) 109 | return x 110 | 111 | 112 | class UPerHead(nn.Module): 113 | """Unified Perceptual Parsing for Scene Understanding. 114 | 115 | This head is the implementation of `UPerNet 116 | `_. 117 | 118 | Args: 119 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid 120 | Module applied on the last feature. Default: (1, 2, 3, 6). 
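        Example (an illustrative sketch; the feature shapes follow the Swin-T outputs noted in SwinDRNet.forward):
            head = UPerHead(num_classes=9, img_size=224).eval()  # eval(): untrained BatchNorm accepts a single sample
            feats = [torch.randn(1, c, s, s)
                     for c, s in [(96, 56), (192, 28), (384, 14), (768, 7)]]
            logits = head(feats, input_org_shape=(224, 224))  # [1, 9, 224, 224]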
121 | """ 122 | 123 | def __init__(self, pool_scales=(1, 2, 3, 6), num_classes=150, in_channels=[96, 192, 384, 768], 124 | channels=512, dropout_ratio=0.1, align_corners=False, in_index=[0, 1, 2, 3], img_size = 512): 125 | super(UPerHead, self).__init__()#(input_transform='multiple_select', **kwargs) 126 | self.in_channels = in_channels 127 | self.num_classes = num_classes 128 | self.channels = channels 129 | self.dropout_ratio = dropout_ratio 130 | self.align_corners = align_corners 131 | self.input_transform = 'multiple_select' 132 | self.in_index = in_index 133 | self.img_size = img_size 134 | # PSP Module 135 | self.psp_modules = PPM( 136 | pool_scales, 137 | self.in_channels[-1], 138 | self.channels, 139 | align_corners=self.align_corners) 140 | self.bottleneck = ConvModule( 141 | self.in_channels[-1] + len(pool_scales) * self.channels, # in_channels 142 | self.channels, # out_channels 143 | 3, # kernel_size 144 | padding=1 # kernel_size 145 | ) 146 | # FPN Module 147 | self.lateral_convs = nn.ModuleList() 148 | self.fpn_convs = nn.ModuleList() 149 | for in_channels in self.in_channels[:-1]: # skip the top layer 150 | l_conv = ConvModule( 151 | in_channels, 152 | self.channels, 153 | 1, 154 | inplace=False) 155 | fpn_conv = ConvModule( 156 | self.channels, 157 | self.channels, 158 | 3, 159 | padding=1, 160 | inplace=False) 161 | self.lateral_convs.append(l_conv) 162 | self.fpn_convs.append(fpn_conv) 163 | 164 | self.fpn_bottleneck = ConvModule( 165 | len(self.in_channels) * self.channels, 166 | self.channels, 167 | 3, 168 | padding=1) 169 | 170 | #self.fpn_bottleneck = nn.Sequential( 171 | 172 | self.conv_seg = nn.Conv2d(self.channels, self.num_classes, kernel_size=1) 173 | self.dropout = nn.Dropout2d(self.dropout_ratio) 174 | 175 | def psp_forward(self, inputs): 176 | """Forward function of PSP module.""" 177 | x = inputs[-1] 178 | psp_outs = [x] 179 | psp_outs.extend(self.psp_modules(x)) 180 | psp_outs = torch.cat(psp_outs, dim=1) 181 | output = self.bottleneck(psp_outs) 182 | return output 183 | 184 | def _transform_inputs(self, inputs): 185 | """Transform inputs for decoder. 186 | 187 | Args: 188 | inputs (list[Tensor]): List of multi-level img features. 
189 | 190 | Returns: 191 | Tensor: The transformed inputs 192 | """ 193 | if self.input_transform == 'resize_concat': 194 | inputs = [inputs[i] for i in self.in_index] 195 | upsampled_inputs = [ 196 | resize( 197 | input=x, 198 | size=inputs[0].shape[2:], 199 | mode='bilinear', 200 | align_corners=self.align_corners) for x in inputs 201 | ] 202 | inputs = torch.cat(upsampled_inputs, dim=1) 203 | elif self.input_transform == 'multiple_select': 204 | inputs = [inputs[i] for i in self.in_index] 205 | else: 206 | inputs = inputs[self.in_index] 207 | 208 | return inputs 209 | 210 | def init_weights(self): 211 | """Initialize weights of classification layer.""" 212 | normal_init(self.conv_seg, mean=0, std=0.01) 213 | 214 | def cls_seg(self, feat): 215 | """Classify each pixel.""" 216 | if self.dropout is not None: 217 | feat = self.dropout(feat) 218 | output = self.conv_seg(feat) 219 | return output 220 | 221 | def forward(self, inputs, input_org_shape): 222 | """Forward function.""" 223 | 224 | inputs = self._transform_inputs(inputs) 225 | 226 | # build laterals 227 | laterals = [ 228 | lateral_conv(inputs[i]) 229 | for i, lateral_conv in enumerate(self.lateral_convs) 230 | ] 231 | 232 | laterals.append(self.psp_forward(inputs)) 233 | 234 | # build top-down path 235 | used_backbone_levels = len(laterals) 236 | for i in range(used_backbone_levels - 1, 0, -1): 237 | prev_shape = laterals[i - 1].shape[2:] 238 | laterals[i - 1] += resize( 239 | laterals[i], 240 | size=prev_shape, 241 | mode='bilinear', 242 | align_corners=self.align_corners) 243 | 244 | # build outputs 245 | fpn_outs = [ 246 | self.fpn_convs[i](laterals[i]) 247 | for i in range(used_backbone_levels - 1) 248 | ] 249 | # append psp feature 250 | fpn_outs.append(laterals[-1]) 251 | 252 | ######################原始方案########################### 253 | # 直接在最后得到logits后做resize 254 | for i in range(used_backbone_levels - 1, 0, -1): 255 | fpn_outs[i] = resize( 256 | fpn_outs[i], 257 | size=fpn_outs[0].shape[2:], 258 | mode='bilinear', 259 | align_corners=self.align_corners) 260 | fpn_outs = torch.cat(fpn_outs, dim=1) 261 | output = self.fpn_bottleneck(fpn_outs) 262 | 263 | output = self.cls_seg(output) # [bs, 150, 128, 128] 264 | output = resize( 265 | input=output, 266 | size=(input_org_shape[0], input_org_shape[1]), 267 | mode='bilinear', 268 | align_corners=self.align_corners) 269 | ######################原始方案########################### 270 | 271 | ######################方案二############################# 272 | # 在fpn后seg_cls前resize 273 | # for i in range(used_backbone_levels - 1, 0, -1): 274 | # fpn_outs[i] = resize( 275 | # fpn_outs[i], 276 | # size=fpn_outs[0].shape[2:], 277 | # mode='bilinear', 278 | # align_corners=self.align_corners) 279 | # fpn_outs = torch.cat(fpn_outs, dim=1) 280 | # output = self.fpn_bottleneck(fpn_outs) 281 | 282 | # output = resize( 283 | # input=output, 284 | # size=(input_org_shape[0], input_org_shape[1]), 285 | # mode='bilinear', 286 | # align_corners=self.align_corners) 287 | 288 | # output = self.cls_seg(output) # [bs, 150, 128, 128] 289 | ######################################################### 290 | 291 | ######################方案三############################# 292 | # 在fpn前resize 293 | # for i in range(used_backbone_levels - 1, -1, -1): 294 | # fpn_outs[i] = resize( 295 | # fpn_outs[i], 296 | # # size=fpn_outs[0].shape[2:], 297 | # size=(input_org_shape[0], input_org_shape[1]), 298 | # mode='bilinear', 299 | # align_corners=self.align_corners) 300 | # fpn_outs = torch.cat(fpn_outs, dim=1) 301 | # output = 
self.fpn_bottleneck(fpn_outs) 302 | # output = self.cls_seg(output) # [bs, 150, 128, 128] 303 | ######################################################### 304 | return output 305 | 306 | 307 | class FCNHead(nn.Module): 308 | """Fully Convolution Networks for Semantic Segmentation. 309 | 310 | This head is implemented of `FCNNet `_. 311 | 312 | Args: 313 | num_convs (int): Number of convs in the head. Default: 2. 314 | kernel_size (int): The kernel size for convs in the head. Default: 3. 315 | concat_input (bool): Whether concat the input and output of convs 316 | before classification layer. 317 | """ 318 | 319 | def __init__(self, 320 | in_channels, 321 | in_index, 322 | channels, 323 | num_classes, 324 | align_corners=False, 325 | dropout_ratio=0.5, 326 | 327 | num_convs=2, 328 | kernel_size=3, 329 | concat_input=True, 330 | img_size = 512): 331 | self.in_channels = in_channels 332 | self.in_index = in_index 333 | self.channels = channels 334 | self.num_classes = num_classes 335 | self.align_corners = align_corners 336 | self.dropout_ratio = dropout_ratio 337 | self.img_size = img_size 338 | 339 | assert num_convs >= 0 340 | self.num_convs = num_convs 341 | self.concat_input = concat_input 342 | self.kernel_size = kernel_size 343 | self.input_transform = None 344 | super(FCNHead, self).__init__() 345 | if num_convs == 0: 346 | assert self.in_channels == self.channels 347 | 348 | convs = [] 349 | convs.append( 350 | ConvModule( 351 | self.in_channels, 352 | self.channels, 353 | kernel_size=kernel_size, 354 | padding=kernel_size // 2)) 355 | for i in range(num_convs - 1): 356 | convs.append( 357 | ConvModule( 358 | self.channels, 359 | self.channels, 360 | kernel_size=kernel_size, 361 | padding=kernel_size // 2)) 362 | if num_convs == 0: 363 | self.convs = nn.Identity() 364 | else: 365 | self.convs = nn.Sequential(*convs) 366 | if self.concat_input: 367 | self.conv_cat = ConvModule( 368 | self.in_channels + self.channels, 369 | self.channels, 370 | kernel_size=kernel_size, 371 | padding=kernel_size // 2) 372 | 373 | self.conv_seg = nn.Conv2d(self.channels, self.num_classes, kernel_size=1) 374 | self.dropout = nn.Dropout2d(self.dropout_ratio) 375 | 376 | def _transform_inputs(self, inputs): 377 | """Transform inputs for decoder. 378 | 379 | Args: 380 | inputs (list[Tensor]): List of multi-level img features. 
381 | 382 | Returns: 383 | Tensor: The transformed inputs 384 | """ 385 | if self.input_transform == 'resize_concat': 386 | inputs = [inputs[i] for i in self.in_index] 387 | upsampled_inputs = [ 388 | resize( 389 | input=x, 390 | size=inputs[0].shape[2:], 391 | mode='bilinear', 392 | align_corners=self.align_corners) for x in inputs 393 | ] 394 | inputs = torch.cat(upsampled_inputs, dim=1) 395 | elif self.input_transform == 'multiple_select': 396 | inputs = [inputs[i] for i in self.in_index] 397 | else: 398 | inputs = inputs[self.in_index] 399 | 400 | return inputs 401 | 402 | def cls_seg(self, feat): 403 | """Classify each pixel.""" 404 | if self.dropout is not None: 405 | feat = self.dropout(feat) 406 | output = self.conv_seg(feat) 407 | return output 408 | 409 | # def normal_init(self, module, mean=0, std=1, bias=0): 410 | # if hasattr(module, 'weight') and module.weight is not None: 411 | # nn.init.normal_(module.weight, mean, std) 412 | # if hasattr(module, 'bias') and module.bias is not None: 413 | # nn.init.constant_(module.bias, bias) 414 | 415 | def init_weights(self): 416 | """Initialize weights of classification layer.""" 417 | normal_init(self.conv_seg, mean=0, std=0.01) 418 | 419 | def forward(self, inputs, input_org_shape): 420 | """Forward function.""" 421 | x = self._transform_inputs(inputs) 422 | output = self.convs(x) 423 | if self.concat_input: 424 | output = self.conv_cat(torch.cat([x, output], dim=1)) 425 | output = self.cls_seg(output) # [bs, 150, 32, 32] 426 | 427 | # resize回原图尺寸 428 | output = resize( 429 | input=output, 430 | size=(input_org_shape[0], input_org_shape[1]), 431 | mode='bilinear', 432 | align_corners=self.align_corners) 433 | 434 | return output 435 | 436 | 437 | def normal_init(module, mean=0, std=1, bias=0): 438 | if hasattr(module, 'weight') and module.weight is not None: 439 | nn.init.normal_(module.weight, mean, std) 440 | #nn.init.constant_(module.weight, 0) 441 | if hasattr(module, 'bias') and module.bias is not None: 442 | nn.init.constant_(module.bias, bias) -------------------------------------------------------------------------------- /CatePoseEstimation/networks/__pycache__/CrossAttention.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-EPIC/DREDS/1b0ac3091dbf0af91a48ce28d000fb4e2fe31350/CatePoseEstimation/networks/__pycache__/CrossAttention.cpython-37.pyc -------------------------------------------------------------------------------- /CatePoseEstimation/networks/__pycache__/SwinDRNet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-EPIC/DREDS/1b0ac3091dbf0af91a48ce28d000fb4e2fe31350/CatePoseEstimation/networks/__pycache__/SwinDRNet.cpython-37.pyc -------------------------------------------------------------------------------- /CatePoseEstimation/networks/__pycache__/SwinTransformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-EPIC/DREDS/1b0ac3091dbf0af91a48ce28d000fb4e2fe31350/CatePoseEstimation/networks/__pycache__/SwinTransformer.cpython-37.pyc -------------------------------------------------------------------------------- /CatePoseEstimation/networks/__pycache__/UPerNet.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PKU-EPIC/DREDS/1b0ac3091dbf0af91a48ce28d000fb4e2fe31350/CatePoseEstimation/networks/__pycache__/UPerNet.cpython-37.pyc -------------------------------------------------------------------------------- /CatePoseEstimation/pretrained_model/DOWNLOAD.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script makes it easy to download the pretrained model. 4 | # Files are at https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/checkpoint/SwinDRNet/pretrain_model/ 5 | 6 | 7 | wget https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/checkpoint/SwinDRNet/pretrain_model/swin_tiny_patch4_window7_224.pth -------------------------------------------------------------------------------- /CatePoseEstimation/requirments.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | cachetools==5.0.0 3 | certifi==2021.10.8 4 | charset-normalizer==2.0.10 5 | cycler==0.11.0 6 | einops==0.4.0 7 | fonttools==4.28.5 8 | google-auth==2.4.0 9 | google-auth-oauthlib==0.4.6 10 | grpcio==1.43.0 11 | idna==3.3 12 | imageio==2.14.0 13 | imgaug==0.4.0 14 | importlib-metadata==4.10.1 15 | kiwisolver==1.3.2 16 | Markdown==3.3.6 17 | matplotlib==3.5.1 18 | networkx==2.6.3 19 | numpy==1.21.5 20 | oauthlib==3.1.1 21 | opencv-python==4.5.5.62 22 | OpenEXR==1.3.2 23 | packaging==21.3 24 | Pillow==9.0.0 25 | protobuf==3.19.3 26 | pyasn1==0.4.8 27 | pyasn1-modules==0.2.8 28 | pyparsing==3.0.7 29 | python-dateutil==2.8.2 30 | PyWavelets==1.2.0 31 | PyYAML==6.0 32 | requests==2.27.1 33 | requests-oauthlib==1.3.0 34 | rsa==4.8 35 | scikit-image==0.19.1 36 | scipy==1.7.3 37 | Shapely==1.8.0 38 | six==1.16.0 39 | tensorboard==2.8.0 40 | tensorboard-data-server==0.6.1 41 | tensorboard-plugin-wit==1.8.1 42 | termcolor==1.1.0 43 | tifffile==2021.11.2 44 | timm==0.5.4 45 | torch==1.7.1 46 | torchaudio==0.7.2 47 | torchvision==0.8.2 48 | tqdm==4.62.3 49 | typing_extensions==4.0.1 50 | urllib3==1.26.8 51 | Werkzeug==2.0.2 52 | yacs==0.1.8 53 | zipp==3.7.0 -------------------------------------------------------------------------------- /CatePoseEstimation/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import numpy as np 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | from networks.SwinDRNet import SwinDRNet 9 | from trainer import SwinDRNetTrainer 10 | from config import get_config 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--mask_transparent', action='store_true', default=True, help='material mask') 15 | parser.add_argument('--mask_specular', action='store_true', default=True, help='material mask') 16 | parser.add_argument('--mask_diffuse', action='store_true', default=True, help='material mask') 17 | 18 | parser.add_argument('--train_data_path', type=str, 19 | default='/data/DREDS/DREDS-CatKnown/train', help='root dir for training dataset') 20 | parser.add_argument('--val_data_path', type=str, 21 | default='/data/DREDS/DREDS-CatKnown/val', help='root dir for validation dataset') 22 | parser.add_argument('--val_data_type', type=str, 23 | default='sim', help='type of val dataset (real/sim)') 24 | parser.add_argument('--val_depth_path', type=str, 25 | default=None, help='root dir for val depth') 26 | parser.add_argument('--train_obj_path', type=str, 27 | default='/data/cad_model/syn_train', help='root dir for train obj') 28 | 
parser.add_argument('--val_obj_path', type=str, 29 | default='/data/cad_model/syn_train', help='root dir for val obj') 30 | parser.add_argument('--output_dir', type=str, 31 | default='results', help='output dir') 32 | parser.add_argument('--decode_mode', type=str, 33 | default='multi_head', help='Select encode mode') 34 | parser.add_argument('--checkpoint_save_path', type=str, 35 | default='results/ckpt', help='Choose a path to save checkpoints') 36 | 37 | parser.add_argument('--val_interation_interval', type=int, 38 | default=5000, help='The iteration interval to perform validation') 39 | 40 | parser.add_argument('--percentageDataForTraining', type=float, 41 | default=1.0, help='The percentage of full training data for training') 42 | parser.add_argument('--percentageDataForVal', type=float, 43 | default=1.0, help='The percentage of full training data for training') 44 | 45 | parser.add_argument('--num_classes', type=int, 46 | default=9, help='output channel of network') 47 | parser.add_argument('--max_epochs', type=int, default=20, 48 | help='maximum epoch number to train') 49 | parser.add_argument('--batch_size', type=int, default=8, 50 | help='batch_size per gpu') 51 | parser.add_argument('--n_gpu', type=int, default=1, help='total gpu') 52 | parser.add_argument('--deterministic', type=int, default=1, 53 | help='whether use deterministic training') 54 | parser.add_argument('--base_lr', type=float, default=0.0001, 55 | help='segmentation network learning rate') 56 | parser.add_argument('--img_size', type=int, 57 | default=224, help='input patch size of network input') 58 | parser.add_argument('--seed', type=int, 59 | default=1234, help='random seed') 60 | 61 | parser.add_argument('--cfg', type=str, default="configs/swin_tiny_patch4_window7_224_lite.yaml", metavar="FILE", help='path to config file', ) 62 | parser.add_argument( 63 | "--opts", 64 | help="Modify config options by adding 'KEY VALUE' pairs. 
", 65 | default=None, 66 | nargs='+', 67 | ) 68 | parser.add_argument('--zip', action='store_true', default=True, help='use zipped dataset instead of folder dataset') 69 | parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'], 70 | help='no: no cache, ' 71 | 'full: cache all data, ' 72 | 'part: sharding the dataset into nonoverlapping pieces and only cache one piece') 73 | parser.add_argument('--resume',type=str, default='./output-1/epoch_149.pth', help='resume from checkpoint') 74 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") 75 | parser.add_argument('--use-checkpoint', action='store_true', 76 | help="whether to use gradient checkpointing to save memory") 77 | parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'], 78 | help='mixed precision opt level, if O0, no amp is used') 79 | parser.add_argument('--tag', help='tag of experiment') 80 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 81 | parser.add_argument('--throughput', action='store_true', help='Test throughput only') 82 | 83 | 84 | args = parser.parse_args() 85 | config = get_config(args) 86 | 87 | os.environ["CUDA_VISIBLE_DEVICES"] = "5" 88 | device_list = [0] 89 | 90 | if __name__ == "__main__": 91 | if not args.deterministic: 92 | cudnn.benchmark = True 93 | cudnn.deterministic = False 94 | else: 95 | cudnn.benchmark = False 96 | cudnn.deterministic = True 97 | 98 | random.seed(args.seed) 99 | np.random.seed(args.seed) 100 | torch.manual_seed(args.seed) 101 | torch.cuda.manual_seed(args.seed) 102 | 103 | if not os.path.exists(args.output_dir): 104 | os.makedirs(args.output_dir) 105 | 106 | net = SwinDRNet(config, img_size=args.img_size, num_classes=args.num_classes).cuda() 107 | continue_ckpt_path = None 108 | 109 | if continue_ckpt_path is None: 110 | # net.load_from(config) 111 | pretrained_path = config.MODEL.PRETRAIN_CKPT 112 | net.init_weights(pretrained_path) 113 | 114 | trainer = SwinDRNetTrainer 115 | _trainer = trainer(args, net, device_list, continue_ckpt_path) 116 | _trainer.train() -------------------------------------------------------------------------------- /CatePoseEstimation/utils/__pycache__/aligning.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-EPIC/DREDS/1b0ac3091dbf0af91a48ce28d000fb4e2fe31350/CatePoseEstimation/utils/__pycache__/aligning.cpython-37.pyc -------------------------------------------------------------------------------- /CatePoseEstimation/utils/__pycache__/api_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-EPIC/DREDS/1b0ac3091dbf0af91a48ce28d000fb4e2fe31350/CatePoseEstimation/utils/__pycache__/api_utils.cpython-37.pyc -------------------------------------------------------------------------------- /CatePoseEstimation/utils/__pycache__/create_image_grid.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-EPIC/DREDS/1b0ac3091dbf0af91a48ce28d000fb4e2fe31350/CatePoseEstimation/utils/__pycache__/create_image_grid.cpython-37.pyc -------------------------------------------------------------------------------- /CatePoseEstimation/utils/__pycache__/loss_functions.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PKU-EPIC/DREDS/1b0ac3091dbf0af91a48ce28d000fb4e2fe31350/CatePoseEstimation/utils/__pycache__/loss_functions.cpython-37.pyc -------------------------------------------------------------------------------- /CatePoseEstimation/utils/__pycache__/metrics_depth_restoration.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-EPIC/DREDS/1b0ac3091dbf0af91a48ce28d000fb4e2fe31350/CatePoseEstimation/utils/__pycache__/metrics_depth_restoration.cpython-37.pyc -------------------------------------------------------------------------------- /CatePoseEstimation/utils/__pycache__/metrics_nocs.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-EPIC/DREDS/1b0ac3091dbf0af91a48ce28d000fb4e2fe31350/CatePoseEstimation/utils/__pycache__/metrics_nocs.cpython-37.pyc -------------------------------------------------------------------------------- /CatePoseEstimation/utils/__pycache__/metrics_sem_seg.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-EPIC/DREDS/1b0ac3091dbf0af91a48ce28d000fb4e2fe31350/CatePoseEstimation/utils/__pycache__/metrics_sem_seg.cpython-37.pyc -------------------------------------------------------------------------------- /CatePoseEstimation/utils/aligning.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Normalized Object Coordinate Space for Category-Level 6D Object Pose and Size Estimation 3 | RANSAC for Similarity Transformation Estimation 4 | 5 | Written by Srinath Sridhar 6 | ''' 7 | 8 | import numpy as np 9 | import cv2 10 | import itertools 11 | 12 | def estimateSimilarityTransform(source: np.array, target: np.array, verbose=False): 13 | SourceHom = np.transpose(np.hstack([source, np.ones([source.shape[0], 1])])) 14 | TargetHom = np.transpose(np.hstack([target, np.ones([source.shape[0], 1])])) 15 | 16 | # Auto-parameter selection based on source-target heuristics 17 | TargetNorm = np.mean(np.linalg.norm(target, axis=1)) 18 | SourceNorm = np.mean(np.linalg.norm(source, axis=1)) 19 | RatioTS = (TargetNorm / SourceNorm) 20 | RatioST = (SourceNorm / TargetNorm) 21 | PassT = RatioST if(RatioST>RatioTS) else RatioTS 22 | StopT = PassT / 100 23 | nIter = 100 24 | if verbose: 25 | print('Pass threshold: ', PassT) 26 | print('Stop threshold: ', StopT) 27 | print('Number of iterations: ', nIter) 28 | 29 | SourceInliersHom, TargetInliersHom, BestInlierRatio = getRANSACInliers(SourceHom, TargetHom, MaxIterations=nIter, PassThreshold=PassT, StopThreshold=StopT) 30 | 31 | if(BestInlierRatio < 0.1): 32 | print('[ WARN ] - Something is wrong. 
Small BestInlierRatio: ', BestInlierRatio) 33 | return None, None, None, None 34 | 35 | Scales, Rotation, Translation, OutTransform = estimateSimilarityUmeyama(SourceInliersHom, TargetInliersHom) 36 | 37 | if verbose: 38 | print('BestInlierRatio:', BestInlierRatio) 39 | print('Rotation:\n', Rotation) 40 | print('Translation:\n', Translation) 41 | print('Scales:', Scales) 42 | 43 | return Scales, Rotation, Translation, OutTransform 44 | 45 | def estimateRestrictedAffineTransform(source: np.array, target: np.array, verbose=False): 46 | SourceHom = np.transpose(np.hstack([source, np.ones([source.shape[0], 1])])) 47 | TargetHom = np.transpose(np.hstack([target, np.ones([source.shape[0], 1])])) 48 | 49 | RetVal, AffineTrans, Inliers = cv2.estimateAffine3D(source, target) 50 | # We assume no shear in the affine matrix and decompose into rotation, non-uniform scales, and translation 51 | Translation = AffineTrans[:3, 3] 52 | NUScaleRotMat = AffineTrans[:3, :3] 53 | # NUScaleRotMat should be the matrix SR, where S is a diagonal scale matrix and R is the rotation matrix (equivalently RS) 54 | # Let us do the SVD of NUScaleRotMat to obtain R1*S*R2 and then R = R1 * R2 55 | R1, ScalesSorted, R2 = np.linalg.svd(NUScaleRotMat, full_matrices=True) 56 | 57 | if verbose: 58 | print('-----------------------------------------------------------------------') 59 | # Now, the scales are sort in ascending order which is painful because we don't know the x, y, z scales 60 | # Let's figure that out by evaluating all 6 possible permutations of the scales 61 | ScalePermutations = list(itertools.permutations(ScalesSorted)) 62 | MinResidual = 1e8 63 | Scales = ScalePermutations[0] 64 | OutTransform = np.identity(4) 65 | Rotation = np.identity(3) 66 | for ScaleCand in ScalePermutations: 67 | CurrScale = np.asarray(ScaleCand) 68 | CurrTransform = np.identity(4) 69 | CurrRotation = (np.diag(1 / CurrScale) @ NUScaleRotMat).transpose() 70 | CurrTransform[:3, :3] = np.diag(CurrScale) @ CurrRotation 71 | CurrTransform[:3, 3] = Translation 72 | # Residual = evaluateModel(CurrTransform, SourceHom, TargetHom) 73 | Residual = evaluateModelNonHom(source, target, CurrScale,CurrRotation, Translation) 74 | if verbose: 75 | # print('CurrTransform:\n', CurrTransform) 76 | print('CurrScale:', CurrScale) 77 | print('Residual:', Residual) 78 | print('AltRes:', evaluateModelNoThresh(CurrTransform, SourceHom, TargetHom)) 79 | if Residual < MinResidual: 80 | MinResidual = Residual 81 | Scales = CurrScale 82 | Rotation = CurrRotation 83 | OutTransform = CurrTransform 84 | 85 | if verbose: 86 | print('Best Scale:', Scales) 87 | 88 | if verbose: 89 | print('Affine Scales:', Scales) 90 | print('Affine Translation:', Translation) 91 | print('Affine Rotation:\n', Rotation) 92 | print('-----------------------------------------------------------------------') 93 | 94 | return Scales, Rotation, Translation, OutTransform 95 | 96 | def getRANSACInliers(SourceHom, TargetHom, MaxIterations=100, PassThreshold=200, StopThreshold=1): 97 | BestResidual = 1e10 98 | BestInlierRatio = 0 99 | BestInlierIdx = np.arange(SourceHom.shape[1]) 100 | for i in range(0, MaxIterations): 101 | # Pick 5 random (but corresponding) points from source and target 102 | RandIdx = np.random.randint(SourceHom.shape[1], size=5) 103 | _, _, _, OutTransform = estimateSimilarityUmeyama(SourceHom[:, RandIdx], TargetHom[:, RandIdx]) 104 | Residual, InlierRatio, InlierIdx = evaluateModel(OutTransform, SourceHom, TargetHom, PassThreshold) 105 | if Residual < BestResidual: 106 | 
BestResidual = Residual 107 | BestInlierRatio = InlierRatio 108 | BestInlierIdx = InlierIdx 109 | if BestResidual < StopThreshold: 110 | break 111 | 112 | # print('Iteration: ', i) 113 | # print('Residual: ', Residual) 114 | # print('Inlier ratio: ', InlierRatio) 115 | 116 | return SourceHom[:, BestInlierIdx], TargetHom[:, BestInlierIdx], BestInlierRatio 117 | 118 | def evaluateModel(OutTransform, SourceHom, TargetHom, PassThreshold): 119 | Diff = TargetHom - np.matmul(OutTransform, SourceHom) 120 | ResidualVec = np.linalg.norm(Diff[:3, :], axis=0) 121 | Residual = np.linalg.norm(ResidualVec) 122 | InlierIdx = np.where(ResidualVec < PassThreshold) 123 | nInliers = np.count_nonzero(InlierIdx) 124 | InlierRatio = nInliers / SourceHom.shape[1] 125 | return Residual, InlierRatio, InlierIdx[0] 126 | 127 | def evaluateModelNoThresh(OutTransform, SourceHom, TargetHom): 128 | Diff = TargetHom - np.matmul(OutTransform, SourceHom) 129 | ResidualVec = np.linalg.norm(Diff[:3, :], axis=0) 130 | Residual = np.linalg.norm(ResidualVec) 131 | return Residual 132 | 133 | def evaluateModelNonHom(source, target, Scales, Rotation, Translation): 134 | RepTrans = np.tile(Translation, (source.shape[0], 1)) 135 | TransSource = (np.diag(Scales) @ Rotation @ source.transpose() + RepTrans.transpose()).transpose() 136 | Diff = target - TransSource 137 | ResidualVec = np.linalg.norm(Diff, axis=0) 138 | Residual = np.linalg.norm(ResidualVec) 139 | return Residual 140 | 141 | def testNonUniformScale(SourceHom, TargetHom): 142 | OutTransform = np.matmul(TargetHom, np.linalg.pinv(SourceHom)) 143 | ScaledRotation = OutTransform[:3, :3] 144 | Translation = OutTransform[:3, 3] 145 | Sx = np.linalg.norm(ScaledRotation[0, :]) 146 | Sy = np.linalg.norm(ScaledRotation[1, :]) 147 | Sz = np.linalg.norm(ScaledRotation[2, :]) 148 | Rotation = np.vstack([ScaledRotation[0, :] / Sx, ScaledRotation[1, :] / Sy, ScaledRotation[2, :] / Sz]) 149 | print('Rotation matrix norm:', np.linalg.norm(Rotation)) 150 | Scales = np.array([Sx, Sy, Sz]) 151 | 152 | # # Check 153 | # Diff = TargetHom - np.matmul(OutTransform, SourceHom) 154 | # Residual = np.linalg.norm(Diff[:3, :], axis=0) 155 | return Scales, Rotation, Translation, OutTransform 156 | 157 | def estimateSimilarityUmeyama(SourceHom, TargetHom): 158 | # Copy of original paper is at: http://web.stanford.edu/class/cs273/refs/umeyama.pdf 159 | SourceCentroid = np.mean(SourceHom[:3, :], axis=1) 160 | TargetCentroid = np.mean(TargetHom[:3, :], axis=1) 161 | nPoints = SourceHom.shape[1] 162 | 163 | CenteredSource = SourceHom[:3, :] - np.tile(SourceCentroid, (nPoints, 1)).transpose() 164 | CenteredTarget = TargetHom[:3, :] - np.tile(TargetCentroid, (nPoints, 1)).transpose() 165 | 166 | CovMatrix = np.matmul(CenteredTarget, np.transpose(CenteredSource)) / nPoints 167 | 168 | if np.isnan(CovMatrix).any(): 169 | print('nPoints:', nPoints) 170 | print(SourceHom.shape) 171 | print(TargetHom.shape) 172 | raise RuntimeError('There are NANs in the input.') 173 | 174 | U, D, Vh = np.linalg.svd(CovMatrix, full_matrices=True) 175 | d = (np.linalg.det(U) * np.linalg.det(Vh)) < 0.0 176 | if d: 177 | D[-1] = -D[-1] 178 | U[:, -1] = -U[:, -1] 179 | 180 | Rotation = np.matmul(U, Vh).T # Transpose is the one that works 181 | 182 | varP = np.var(SourceHom[:3, :], axis=1).sum() 183 | ScaleFact = 1/varP * np.sum(D) # scale factor 184 | Scales = np.array([ScaleFact, ScaleFact, ScaleFact]) 185 | ScaleMatrix = np.diag(Scales) 186 | 187 | Translation = TargetHom[:3, :].mean(axis=1) - SourceHom[:3, 
:].mean(axis=1).dot(ScaleFact*Rotation) 188 | 189 | OutTransform = np.identity(4) 190 | OutTransform[:3, :3] = ScaleMatrix @ Rotation 191 | OutTransform[:3, 3] = Translation 192 | 193 | # # Check 194 | # Diff = TargetHom - np.matmul(OutTransform, SourceHom) 195 | # Residual = np.linalg.norm(Diff[:3, :], axis=0) 196 | return Scales, Rotation, Translation, OutTransform 197 | -------------------------------------------------------------------------------- /CatePoseEstimation/utils/api_utils.py: -------------------------------------------------------------------------------- 1 | '''Misc functions like functions for reading and saving EXR images using OpenEXR, saving pointclouds, etc. 2 | ''' 3 | import struct 4 | 5 | import numpy as np 6 | import cv2 7 | import Imath 8 | import OpenEXR 9 | from PIL import Image 10 | import torch 11 | import torch.nn.functional as F 12 | # from torchvision.utils import make_grid 13 | 14 | 15 | def exr_loader(EXR_PATH, ndim=3): 16 | """Loads a .exr file as a numpy array 17 | 18 | Args: 19 | EXR_PATH: path to the exr file 20 | ndim: number of channels that should be in returned array. Valid values are 1 and 3. 21 | if ndim=1, only the 'R' channel is taken from exr file 22 | if ndim=3, the 'R', 'G' and 'B' channels are taken from exr file. 23 | The exr file must have 3 channels in this case. 24 | Returns: 25 | numpy.ndarray (dtype=np.float32): If ndim=1, shape is (height x width) 26 | If ndim=3, shape is (3 x height x width) 27 | 28 | """ 29 | 30 | exr_file = OpenEXR.InputFile(EXR_PATH) 31 | cm_dw = exr_file.header()['dataWindow'] 32 | size = (cm_dw.max.x - cm_dw.min.x + 1, cm_dw.max.y - cm_dw.min.y + 1) 33 | 34 | pt = Imath.PixelType(Imath.PixelType.FLOAT) 35 | 36 | if ndim == 3: 37 | # read channels indivudally 38 | allchannels = [] 39 | for c in ['R', 'G', 'B']: 40 | # transform data to numpy 41 | channel = np.frombuffer(exr_file.channel(c, pt), dtype=np.float32) 42 | channel.shape = (size[1], size[0]) 43 | allchannels.append(channel) 44 | 45 | # create array and transpose dimensions to match tensor style 46 | exr_arr = np.array(allchannels).transpose((0, 1, 2)) 47 | return exr_arr 48 | 49 | if ndim == 1: 50 | # transform data to numpy 51 | channel = np.frombuffer(exr_file.channel('R', pt), dtype=np.float32) 52 | channel.shape = (size[1], size[0]) # Numpy arrays are (row, col) 53 | exr_arr = np.array(channel) 54 | return exr_arr 55 | 56 | 57 | def exr_saver(EXR_PATH, ndarr, ndim=3): 58 | '''Saves a numpy array as an EXR file with HALF precision (float16) 59 | Args: 60 | EXR_PATH (str): The path to which file will be saved 61 | ndarr (ndarray): A numpy array containing img data 62 | ndim (int): The num of dimensions in the saved exr image, either 3 or 1. 63 | If ndim = 3, ndarr should be of shape (height, width) or (3 x height x width), 64 | If ndim = 1, ndarr should be of shape (height, width) 65 | Returns: 66 | None 67 | ''' 68 | if ndim == 3: 69 | # Check params 70 | if len(ndarr.shape) == 2: 71 | # If a depth image of shape (height x width) is passed, convert into shape (3 x height x width) 72 | ndarr = np.stack((ndarr, ndarr, ndarr), axis=0) 73 | 74 | if ndarr.shape[0] != 3 or len(ndarr.shape) != 3: 75 | raise ValueError( 76 | 'The shape of the tensor should be (3 x height x width) for ndim = 3. 
Given shape is {}'.format( 77 | ndarr.shape)) 78 | 79 | # Convert each channel to strings 80 | Rs = ndarr[0, :, :].astype(np.float16).tostring() 81 | Gs = ndarr[1, :, :].astype(np.float16).tostring() 82 | Bs = ndarr[2, :, :].astype(np.float16).tostring() 83 | 84 | # Write the three color channels to the output file 85 | HEADER = OpenEXR.Header(ndarr.shape[2], ndarr.shape[1]) 86 | half_chan = Imath.Channel(Imath.PixelType(Imath.PixelType.HALF)) 87 | HEADER['channels'] = dict([(c, half_chan) for c in "RGB"]) 88 | 89 | out = OpenEXR.OutputFile(EXR_PATH, HEADER) 90 | out.writePixels({'R': Rs, 'G': Gs, 'B': Bs}) 91 | out.close() 92 | elif ndim == 1: 93 | # Check params 94 | if len(ndarr.shape) != 2: 95 | raise ValueError(('The shape of the tensor should be (height x width) for ndim = 1. ' + 96 | 'Given shape is {}'.format(ndarr.shape))) 97 | 98 | # Convert each channel to strings 99 | Rs = ndarr[:, :].astype(np.float16).tostring() 100 | 101 | # Write the color channel to the output file 102 | HEADER = OpenEXR.Header(ndarr.shape[1], ndarr.shape[0]) 103 | half_chan = Imath.Channel(Imath.PixelType(Imath.PixelType.HALF)) 104 | HEADER['channels'] = dict([(c, half_chan) for c in "R"]) 105 | 106 | out = OpenEXR.OutputFile(EXR_PATH, HEADER) 107 | out.writePixels({'R': Rs}) 108 | out.close() 109 | 110 | 111 | def save_uint16_png(path, image): 112 | '''save weight file - scaled png representation of outlines estimation 113 | 114 | Args: 115 | path (str): path to save the file 116 | image (numpy.ndarray): 16-bit single channel image to be saved. 117 | Shape=(H, W), dtype=np.uint16 118 | ''' 119 | assert image.dtype == np.uint16, ("data type of the array should be np.uint16." + "Got {}".format(image.dtype)) 120 | assert len(image.shape) == 2, ("Shape of input image should be (H, W)" + "Got {}".format(len(image.shape))) 121 | 122 | array_buffer = image.tobytes() 123 | img = Image.new("I", image.T.shape) 124 | img.frombytes(array_buffer, 'raw', 'I;16') 125 | img.save(path) 126 | 127 | 128 | def _normalize_depth_img(depth_img, dtype=np.uint8, min_depth=0.0, max_depth=1.0): 129 | '''Converts a floating point depth image to uint8 or uint16 image. 130 | The depth image is first scaled to (0.0, max_depth) and then scaled and converted to given datatype. 131 | 132 | Args: 133 | depth_img (numpy.float32): Depth image, value is depth in meters 134 | dtype (numpy.dtype, optional): Defaults to np.uint16. Output data type. Must be np.uint8 or np.uint16 135 | max_depth (float, optional): The max depth to be considered in the input depth image. The min depth is 136 | considered to be 0.0. 137 | Raises: 138 | ValueError: If wrong dtype is given 139 | 140 | Returns: 141 | numpy.ndarray: Depth image scaled to given dtype 142 | ''' 143 | 144 | if dtype != np.uint16 and dtype != np.uint8: 145 | raise ValueError('Unsupported dtype {}. 
Must be one of ("np.uint8", "np.uint16")'.format(dtype)) 146 | 147 | # Clip depth image to given range 148 | depth_img = np.ma.masked_array(depth_img, mask=(depth_img == 0.0)) 149 | depth_img = np.ma.clip(depth_img, min_depth, max_depth) 150 | 151 | # Get min/max value of given datatype 152 | type_info = np.iinfo(dtype) 153 | min_val = type_info.min 154 | max_val = type_info.max 155 | 156 | # Scale the depth image to given datatype range 157 | depth_img = ((depth_img - min_depth) / (max_depth - min_depth)) * max_val 158 | depth_img = depth_img.astype(dtype) 159 | 160 | depth_img = np.ma.filled(depth_img, fill_value=0) # Convert back to normal numpy array from masked numpy array 161 | 162 | return depth_img 163 | 164 | 165 | def depth2rgb(depth_img, min_depth=0.0, max_depth=1.5, color_mode=cv2.COLORMAP_JET, reverse_scale=False, 166 | dynamic_scaling=False): 167 | '''Generates RGB representation of a depth image. 168 | To do so, the depth image has to be normalized by specifying a min and max depth to be considered. 169 | 170 | Holes in the depth image (0.0) appear black in color. 171 | 172 | Args: 173 | depth_img (numpy.ndarray): Depth image, values in meters. Shape=(H, W), dtype=np.float32 174 | min_depth (float): Min depth to be considered 175 | max_depth (float): Max depth to be considered 176 | color_mode (int): Integer or cv2 object representing Which coloring scheme to use. 177 | Please consult https://docs.opencv.org/master/d3/d50/group__imgproc__colormap.html 178 | 179 | Each mode is mapped to an int. Eg: cv2.COLORMAP_AUTUMN = 0. 180 | This mapping changes from version to version. 181 | reverse_scale (bool): Whether to make the largest values the smallest to reverse the color mapping 182 | dynamic_scaling (bool): If true, the depth image will be colored according to the min/max depth value within the 183 | image, rather that the passed arguments. 184 | Returns: 185 | numpy.ndarray: RGB representation of depth image. Shape=(H,W,3) 186 | ''' 187 | # Map depth image to Color Map 188 | if dynamic_scaling: 189 | depth_img_scaled = _normalize_depth_img(depth_img, dtype=np.uint8, 190 | min_depth=max(depth_img[depth_img > 0].min(), min_depth), # Add a small epsilon so that min depth does not show up as black (invalid pixels) 191 | max_depth=min(depth_img.max(), max_depth)) 192 | else: 193 | depth_img_scaled = _normalize_depth_img(depth_img, dtype=np.uint8, min_depth=min_depth, max_depth=max_depth) 194 | 195 | if reverse_scale is True: 196 | depth_img_scaled = np.ma.masked_array(depth_img_scaled, mask=(depth_img_scaled == 0.0)) 197 | depth_img_scaled = 255 - depth_img_scaled 198 | depth_img_scaled = np.ma.filled(depth_img_scaled, fill_value=0) 199 | 200 | depth_img_mapped = cv2.applyColorMap(depth_img_scaled, color_mode) 201 | depth_img_mapped = cv2.cvtColor(depth_img_mapped, cv2.COLOR_BGR2RGB) 202 | 203 | # Make holes in input depth black: 204 | depth_img_mapped[depth_img_scaled == 0, :] = 0 205 | 206 | return depth_img_mapped 207 | 208 | 209 | def scale_depth(depth_image): 210 | '''Convert depth in meters (float32) to a scaled uint16 format as required by depth2depth module. 211 | 212 | Args: 213 | depth_image (numpy.ndarray, float32): Depth Image 214 | 215 | Returns: 216 | numpy.ndarray: scaled depth image. dtype=np.uint16 217 | ''' 218 | 219 | assert depth_image.dtype == np.float32, "data type of the array should be float32. 
Got {}".format(depth_image.dtype) 220 | SCALING_FACTOR = 4000 221 | OUTPUT_DTYPE = np.uint16 222 | 223 | # Prevent Overflow of data by clipping depth values 224 | type_info = np.iinfo(OUTPUT_DTYPE) 225 | max_val = type_info.max 226 | depth_image = np.clip(depth_image, 0, np.floor(max_val / SCALING_FACTOR)) 227 | 228 | return (depth_image * SCALING_FACTOR).astype(OUTPUT_DTYPE) 229 | 230 | 231 | def unscale_depth(depth_image): 232 | '''Unscale the depth image from uint16 to denote the depth in meters (float32) 233 | 234 | Args: 235 | depth_image (numpy.ndarray, uint16): Depth Image 236 | 237 | Returns: 238 | numpy.ndarray: unscaled depth image. dtype=np.float32 239 | ''' 240 | 241 | assert depth_image.dtype == np.uint16, "data type of the array should be uint16. Got {}".format(depth_image.dtype) 242 | SCALING_FACTOR = 4000 243 | 244 | return depth_image.astype(np.float32) / SCALING_FACTOR 245 | 246 | 247 | def normal_to_rgb(normals_to_convert, output_dtype='float'): 248 | '''Converts a surface normals array into an RGB image. 249 | Surface normals are represented in a range of (-1,1), 250 | This is converted to a range of (0,255) for a numpy image, or a range of (0,1) to represent PIL Image. 251 | 252 | The surface normals' axes are mapped as (x,y,z) -> (R,G,B). 253 | 254 | Args: 255 | normals_to_convert (numpy.ndarray): Surface normals, dtype float32, range [-1, 1] 256 | output_dtype (str): format of output, possibel values = ['float', 'uint8'] 257 | if 'float', range of output (0,1) 258 | if 'uint8', range of output (0,255) 259 | ''' 260 | camera_normal_rgb = (normals_to_convert + 1) / 2 261 | if output_dtype == 'uint8': 262 | camera_normal_rgb *= 255 263 | camera_normal_rgb = camera_normal_rgb.astype(np.uint8) 264 | elif output_dtype == 'float': 265 | pass 266 | else: 267 | raise NotImplementedError('Possible values for "output_dtype" are only float and uint8. received value {}'.format(output_dtype)) 268 | 269 | return camera_normal_rgb 270 | 271 | 272 | def _get_point_cloud(color_image, depth_image, fx, fy, cx, cy): 273 | """Creates point cloud from rgb images and depth image 274 | 275 | Args: 276 | color image (numpy.ndarray): Shape=[H, W, C], dtype=np.uint8 277 | depth image (numpy.ndarray): Shape=[H, W], dtype=np.float32. Each pixel contains depth in meters. 278 | fx (int): The focal len along x-axis in pixels of camera used to capture image. 279 | fy (int): The focal len along y-axis in pixels of camera used to capture image. 280 | cx (int): The center of the image (along x-axis, pixels) as per camera used to capture image. 281 | cy (int): The center of the image (along y-axis, pixels) as per camera used to capture image. 282 | Returns: 283 | numpy.ndarray: camera_points - The XYZ location of each pixel. Shape: (num of pixels, 3) 284 | numpy.ndarray: color_points - The RGB color of each pixel. 
Shape: (num of pixels, 3) 285 | """ 286 | # camera instrinsic parameters 287 | # camera_intrinsics = [[fx 0 cx], 288 | # [0 fy cy], 289 | # [0 0 1]] 290 | camera_intrinsics = np.asarray([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) 291 | 292 | image_height = depth_image.shape[0] 293 | image_width = depth_image.shape[1] 294 | pixel_x, pixel_y = np.meshgrid(np.linspace(0, image_width - 1, image_width), 295 | np.linspace(0, image_height - 1, image_height)) 296 | camera_points_x = np.multiply(pixel_x - camera_intrinsics[0, 2], (depth_image / camera_intrinsics[0, 0])) 297 | camera_points_y = np.multiply(pixel_y - camera_intrinsics[1, 2], (depth_image / camera_intrinsics[1, 1])) 298 | camera_points_z = depth_image 299 | camera_points = np.array([camera_points_x, camera_points_y, camera_points_z]).transpose(1, 2, 0).reshape(-1, 3) 300 | 301 | color_points = color_image.reshape(-1, 3) 302 | 303 | return camera_points, color_points 304 | 305 | 306 | def write_point_cloud(filename, color_image, depth_image, fx, fy, cx, cy): 307 | """Creates and Writes a .ply point cloud file using RGB and Depth images. 308 | 309 | Args: 310 | filename (str): The path to the file which should be written. It should end with extension '.ply' 311 | color image (numpy.ndarray): Shape=[H, W, C], dtype=np.uint8 312 | depth image (numpy.ndarray): Shape=[H, W], dtype=np.float32. Each pixel contains depth in meters. 313 | fx (int): The focal len along x-axis in pixels of camera used to capture image. 314 | fy (int): The focal len along y-axis in pixels of camera used to capture image. 315 | cx (int): The center of the image (along x-axis, pixels) as per camera used to capture image. 316 | cy (int): The center of the image (along y-axis, pixels) as per camera used to capture image. 317 | """ 318 | xyz_points, rgb_points = _get_point_cloud(color_image, depth_image, fx, fy, cx, cy) 319 | 320 | # Write header of .ply file 321 | with open(filename, 'wb') as fid: 322 | fid.write(bytes('ply\n', 'utf-8')) 323 | fid.write(bytes('format binary_little_endian 1.0\n', 'utf-8')) 324 | fid.write(bytes('element vertex %d\n' % xyz_points.shape[0], 'utf-8')) 325 | fid.write(bytes('property float x\n', 'utf-8')) 326 | fid.write(bytes('property float y\n', 'utf-8')) 327 | fid.write(bytes('property float z\n', 'utf-8')) 328 | fid.write(bytes('property uchar red\n', 'utf-8')) 329 | fid.write(bytes('property uchar green\n', 'utf-8')) 330 | fid.write(bytes('property uchar blue\n', 'utf-8')) 331 | fid.write(bytes('end_header\n', 'utf-8')) 332 | 333 | # Write 3D points to .ply file 334 | for i in range(xyz_points.shape[0]): 335 | fid.write( 336 | bytearray( 337 | struct.pack("fffccc", xyz_points[i, 0], xyz_points[i, 1], xyz_points[i, 2], 338 | rgb_points[i, 0].tostring(), rgb_points[i, 1].tostring(), rgb_points[i, 2].tostring()))) 339 | 340 | 341 | def imdenormalize(img, mean, std, to_bgr=False, to_rgb=False): 342 | assert img.dtype != np.uint8 343 | mean = mean.reshape(1, -1).astype(np.float64) 344 | std = std.reshape(1, -1).astype(np.float64) 345 | img = cv2.multiply(img, std) # make a copy 346 | cv2.add(img, mean, img) # inplace 347 | #if to_bgr: 348 | # cv2.cvtColor(img, cv2.COLOR_RGB2BGR, img) # inplace 349 | if to_rgb: 350 | cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace 351 | return 255-img 352 | 353 | 354 | def depth_to_xyz(depthImage, f, scale_h=1., scale_w=1.): 355 | # input depth image[B, 1, H, W] 356 | # output xyz image[B, 3, H, W] 357 | 358 | fx = f * scale_w 359 | fy = f * scale_h 360 | B, C, H, W = depthImage.shape 361 | device = 
depthImage.device 362 | du = W//2 - 0.5 363 | dv = H//2 - 0.5 364 | 365 | xyz = torch.zeros([B, H, W, 3], device=device) 366 | imageIndexX = torch.arange(0, W, 1, device=device) - du 367 | imageIndexY = torch.arange(0, H, 1, device=device) - dv 368 | depthImage = depthImage.squeeze() 369 | if B == 1: 370 | depthImage = depthImage.unsqueeze(0) 371 | 372 | xyz[:, :, :, 0] = depthImage/fx * imageIndexX 373 | xyz[:, :, :, 1] = (depthImage.transpose(1, 2)/fy * imageIndexY.T).transpose(1, 2) 374 | xyz[:, :, :, 2] = depthImage 375 | xyz = xyz.permute(0, 3, 1, 2).to(device) 376 | return xyz 377 | 378 | 379 | def gradient(x): 380 | # idea from tf.image.image_gradients(image) 381 | # https://github.com/tensorflow/tensorflow/blob/r2.1/tensorflow/python/ops/image_ops_impl.py#L3441-L3512 382 | # x: (b,c,h,w), float32 or float64 383 | # dx, dy: (b,c,h,w) 384 | 385 | # gradient step=1 386 | left = x 387 | right = F.pad(x, [0, 1, 0, 0])[:, :, :, 1:] 388 | top = x 389 | bottom = F.pad(x, [0, 0, 0, 1])[:, :, 1:, :] 390 | 391 | # dx, dy = torch.abs(right - left), torch.abs(bottom - top) 392 | dx, dy = right - left, bottom - top 393 | # dx will always have zeros in the last column, right-left 394 | # dy will always have zeros in the last row, bottom-top 395 | dx[:, :, :, -1] = 0 396 | dy[:, :, -1, :] = 0 397 | 398 | return dx, dy 399 | 400 | 401 | def get_surface_normal(x, f, scale_h, scale_w): 402 | xyz = depth_to_xyz(x, f, scale_h, scale_w) 403 | dx,dy = gradient(xyz) 404 | surface_normal = torch.cross(dx, dy, dim=1) 405 | surface_normal = surface_normal / (torch.norm(surface_normal,dim=1,keepdim=True)+1e-8) 406 | return surface_normal, dx, dy 407 | 408 | 409 | # def create_grid_image(inputs, outputs, labels, max_num_images_to_save=3): 410 | # '''Make a grid of images for display purposes 411 | # Size of grid is (3, N, 3), where each coloum belongs to input, output, label resp 412 | 413 | # Args: 414 | # inputs (Tensor): Batch Tensor of shape (B x C x H x W) 415 | # outputs (Tensor): Batch Tensor of shape (B x C x H x W) 416 | # labels (Tensor): Batch Tensor of shape (B x C x H x W) 417 | # max_num_images_to_save (int, optional): Defaults to 3. Out of the given tensors, chooses a 418 | # max number of imaged to put in grid 419 | 420 | # Returns: 421 | # numpy.ndarray: A numpy array with of input images arranged in a grid 422 | # ''' 423 | 424 | # img_tensor = inputs[:max_num_images_to_save] 425 | 426 | # output_tensor = outputs[:max_num_images_to_save] 427 | # output_tensor_rgb = normal_to_rgb(output_tensor) 428 | 429 | # label_tensor = labels[:max_num_images_to_save] 430 | # label_tensor_rgb = normal_to_rgb(label_tensor) 431 | 432 | # images = torch.cat((img_tensor, output_tensor_rgb, label_tensor_rgb), dim=3) 433 | # grid_image = make_grid(images, 1, normalize=True, scale_each=True) 434 | 435 | # return grid_image 436 | 437 | class label2color(object): 438 | def __init__(self,class_num): 439 | self.class_num = class_num 440 | 441 | self.colors = self.create_pascal_label_colormap(self.class_num) 442 | 443 | def to_color_img(self,imgs): 444 | # img:bs,3,height,width 445 | color_imgs = [] 446 | for i in range(imgs.shape[0]): 447 | score_i = imgs[i,...] 
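# note: score_i is a per-class score map (C x H x W); the lines below move it to the CPU, reshape it to HWC, take the per-pixel argmax over classes, and look the labels up in the PASCAL-style colormap built by create_pascal_label_colormap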
448 | score_i = score_i.cpu().numpy() 449 | score_i = np.transpose(score_i,(1,2,0)) 450 | # np.save('pre.npy',score_i) 451 | score_i = np.argmax(score_i,axis=2) 452 | color_imgs.append(self.colors[score_i]) 453 | return color_imgs 454 | 455 | def single_img_color(self,img): 456 | score_i = img 457 | score_i = score_i.numpy() 458 | score_i = np.transpose(score_i,(1,2,0)) 459 | # np.save('pre.npy',score_i) 460 | # score_i = np.argmax(score_i,axis=2) 461 | return self.colors[score_i] 462 | 463 | def bit_get(self,val, idx): 464 | """Gets the bit value. 465 | Args: 466 | val: Input value, int or numpy int array. 467 | idx: Which bit of the input val. 468 | Returns: 469 | The "idx"-th bit of input val. 470 | """ 471 | return (val >> idx) & 1 472 | 473 | def create_pascal_label_colormap(self,class_num): 474 | """Creates a label colormap used in PASCAL VOC segmentation benchmark. 475 | Returns: 476 | A colormap for visualizing segmentation results. 477 | """ 478 | colormap = np.zeros((class_num, 3), dtype=int) 479 | ind = np.arange(class_num, dtype=int) 480 | 481 | for shift in reversed(range(8)): 482 | for channel in range(3): 483 | colormap[:, channel] |= self.bit_get(ind, channel) << shift 484 | ind >>= 3 485 | 486 | return colormap 487 | -------------------------------------------------------------------------------- /CatePoseEstimation/utils/create_image_grid.py: -------------------------------------------------------------------------------- 1 | '''Functions for reading and saving EXR images using OpenEXR. 2 | ''' 3 | 4 | import sys 5 | 6 | import cv2 7 | import numpy as np 8 | import torch 9 | from torchvision.utils import make_grid 10 | import torch.nn as nn 11 | from utils.api_utils import depth2rgb, label2color 12 | import imageio 13 | import os 14 | 15 | sys.path.append('../..') 16 | 17 | def seg_mask_to_rgb(seg_mask, num_classes): 18 | l2c = label2color(num_classes + 1) 19 | seg_mask_color = np.zeros((seg_mask.shape[0], 3, seg_mask.shape[2], seg_mask.shape[3])) 20 | for i in range(seg_mask.shape[0]): 21 | color = l2c.single_img_color(seg_mask[i])#.squeeze(2).transpose(2,0,1).unsqueeze(0) 22 | color = np.squeeze(color,axis=2) 23 | color = color.transpose((2,0,1)) 24 | color = color[np.newaxis,:,:,:] 25 | seg_mask_color[i] = color 26 | seg_mask_color = torch.from_numpy(seg_mask_color) 27 | return seg_mask_color 28 | 29 | def xyz_to_rgb(xyz_map): 30 | xyz_rgb = torch.ones_like(xyz_map) 31 | for i in range(xyz_rgb.shape[0]): 32 | xyz_rgb[i] = torch.div((xyz_map[i] - xyz_map[i].min()), 33 | (xyz_map[i].max() - xyz_map[i].min()).item()) 34 | return xyz_rgb 35 | 36 | def normal_to_rgb(normals_to_convert): 37 | '''Converts a surface normals array into an RGB image. 38 | Surface normals are represented in a range of (-1,1), 39 | This is converted to a range of (0,255) to be written 40 | into an image. 41 | The surface normals are normally in camera co-ords, 42 | with positive z axis coming out of the page. And the axes are 43 | mapped as (x,y,z) -> (R,G,B). 
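(In this implementation the result is left in the (0, 1) range; no further scaling to (0, 255) is applied here.)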
44 | 45 | Args: 46 | normals_to_convert (numpy.ndarray): Surface normals, dtype float32, range [-1, 1] 47 | ''' 48 | camera_normal_rgb = (normals_to_convert + 1) / 2 49 | return camera_normal_rgb 50 | 51 | def create_grid_image(inputs, outputs, labels, rgb, normal_pred, normal_labels, confidence_1, confidence_2, sem_masks=None, pred_sem_seg=None, coords=None, pred_coords=None, max_num_images_to_save=3): 52 | '''Make a grid of images for display purposes 53 | Size of grid is (3, N, 3), where each coloum belongs to input, output, label resp 54 | 55 | Args: 56 | inputs (Tensor): Batch Tensor of shape (B x C x H x W) 57 | outputs (Tensor): Batch Tensor of shape (B x C x H x W) 58 | labels (Tensor): Batch Tensor of shape (B x C x H x W) 59 | max_num_images_to_save (int, optional): Defaults to 3. Out of the given tensors, chooses a 60 | max number of imaged to put in grid 61 | 62 | Returns: 63 | numpy.ndarray: A numpy array with of input images arranged in a grid 64 | ''' 65 | 66 | rgb_tensor = rgb[:max_num_images_to_save] 67 | normal_pred_tensor = normal_pred[:max_num_images_to_save] 68 | normal_pred_tensor = normal_to_rgb(normal_pred_tensor) 69 | normal_labels_tensor = normal_labels[:max_num_images_to_save] 70 | normal_labels_tensor = normal_to_rgb(normal_labels_tensor) 71 | 72 | if not coords is None: 73 | coords_tensor = coords[:max_num_images_to_save] 74 | pred_coords_tensor = pred_coords[:max_num_images_to_save] 75 | 76 | pred_sem_seg_tensor = pred_sem_seg[:max_num_images_to_save] 77 | sem_masks_tensor = sem_masks[:max_num_images_to_save] 78 | pred_sem_seg_tensor = seg_mask_to_rgb(pred_sem_seg_tensor, 8) 79 | sem_masks_tensor = seg_mask_to_rgb(sem_masks_tensor, 8) 80 | 81 | img_tensor = inputs[:max_num_images_to_save] 82 | output_tensor = outputs[:max_num_images_to_save] 83 | label_tensor = labels[:max_num_images_to_save] 84 | confidence_1_tensor = confidence_1[:max_num_images_to_save] 85 | confidence_2_tensor = confidence_2[:max_num_images_to_save] 86 | 87 | 88 | output_tensor[output_tensor < 0] = 0 89 | output_tensor[output_tensor > 4] = 0 90 | 91 | label_tensor[label_tensor < 0] = 0 92 | label_tensor[label_tensor > 4] = 4 93 | 94 | img_tensor = xyz_to_rgb(img_tensor) 95 | 96 | 97 | output_tensor = output_tensor.repeat(1, 3, 1, 1) 98 | label_tensor = label_tensor.repeat(1, 3, 1, 1) 99 | confidence_1_tensor = confidence_1_tensor.repeat(1, 3, 1, 1) 100 | confidence_2_tensor = confidence_2_tensor.repeat(1, 3, 1, 1) 101 | 102 | if coords is None: 103 | images = torch.cat((img_tensor, confidence_1_tensor, confidence_2_tensor, output_tensor, \ 104 | label_tensor, rgb_tensor, normal_pred_tensor, normal_labels_tensor), dim=3) 105 | else: 106 | images = torch.cat((img_tensor, confidence_1_tensor, confidence_2_tensor, output_tensor, \ 107 | label_tensor, rgb_tensor, normal_pred_tensor, normal_labels_tensor, pred_coords_tensor, \ 108 | coords_tensor, pred_sem_seg_tensor, sem_masks_tensor), dim=3) 109 | 110 | # grid_image = make_grid(images, 1, normalize=True, scale_each=True) 111 | grid_image = make_grid(images, 1, normalize=False, scale_each=False) 112 | 113 | return grid_image 114 | -------------------------------------------------------------------------------- /CatePoseEstimation/utils/metrics_depth_restoration.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import cv2 4 | import torch.nn.functional as F 5 | 6 | 7 | def get_metrics_depth_restoration_train(gt, pred, width, height, seg_mask=None): 8 | 9 | gt = 
gt.detach().permute(0, 2, 3, 1).cpu().numpy().astype("float32") 10 | pred = pred.detach().permute(0, 2, 3, 1).cpu().numpy().astype("float32") 11 | if not seg_mask is None: 12 | seg_mask = seg_mask.detach().cpu().permute(0, 2, 3, 1).numpy() 13 | 14 | gt_depth = gt 15 | pred_depth = pred 16 | gt_depth[np.isnan(gt_depth)] = 0 17 | gt_depth[np.isinf(gt_depth)] = 0 18 | mask_valid_region = (gt_depth > 0) 19 | 20 | if not seg_mask is None: 21 | seg_mask = seg_mask.astype(np.uint8) 22 | mask_valid_region = np.logical_and(mask_valid_region, seg_mask) 23 | 24 | gt = torch.from_numpy(gt_depth).float().cuda() 25 | pred = torch.from_numpy(pred_depth).float().cuda() 26 | mask = torch.from_numpy(mask_valid_region).bool().cuda() 27 | gt = gt[mask] 28 | pred = pred[mask] 29 | 30 | thresh = torch.max(gt / pred, pred / gt) 31 | 32 | a1 = (thresh < 1.05).float().mean() 33 | a2 = (thresh < 1.10).float().mean() 34 | a3 = (thresh < 1.25).float().mean() 35 | 36 | rmse = ((gt - pred)**2).mean().sqrt() 37 | abs_rel = ((gt - pred).abs() / gt).mean() 38 | mae = (gt - pred).abs().mean() 39 | 40 | return a1, a2, a3, rmse, abs_rel, mae 41 | 42 | 43 | def get_metrics_depth_restoration_inference(gt, pred, width, height, seg_mask=None): 44 | B = gt.shape[0] 45 | gt = F.interpolate(gt, size=[width, height], mode="nearest") 46 | pred = F.interpolate(pred, size=[width, height], mode="nearest") 47 | 48 | gt = gt.detach().permute(0, 2, 3, 1).cpu().numpy().astype("float32") 49 | pred = pred.detach().permute(0, 2, 3, 1).cpu().numpy().astype("float32") 50 | if not seg_mask is None: 51 | seg_mask = seg_mask.float() 52 | seg_mask = F.interpolate(seg_mask, size=[width, height], mode="nearest") 53 | seg_mask = seg_mask.detach().cpu().permute(0, 2, 3, 1).numpy() 54 | 55 | gt_depth = gt 56 | pred_depth = pred 57 | gt_depth[np.isnan(gt_depth)] = 0 58 | gt_depth[np.isinf(gt_depth)] = 0 59 | mask_valid_region = (gt_depth > 0) 60 | 61 | if not seg_mask is None: 62 | seg_mask = seg_mask.astype(np.uint8) 63 | mask_valid_region = np.logical_and(mask_valid_region, seg_mask) 64 | 65 | gt = torch.from_numpy(gt_depth).float().cuda() 66 | pred = torch.from_numpy(pred_depth).float().cuda() 67 | mask = torch.from_numpy(mask_valid_region).bool().cuda() 68 | 69 | a1 = 0.0 70 | a2 = 0.0 71 | a3 = 0.0 72 | rmse = 0.0 73 | abs_rel = 0.0 74 | mae = 0.0 75 | 76 | num_valid = 0 77 | 78 | for i in range(B): 79 | gt_i = gt[i][mask[i]] 80 | pred_i = pred[i][mask[i]] 81 | # print(len(gt_i)) 82 | 83 | if len(gt_i) > 0: 84 | num_valid += 1 85 | thresh = torch.max(gt_i / pred_i, pred_i / gt_i) 86 | 87 | a1_i = (thresh < 1.05).float().mean() 88 | a2_i = (thresh < 1.10).float().mean() 89 | a3_i = (thresh < 1.25).float().mean() 90 | 91 | rmse_i = ((gt_i - pred_i)**2).mean().sqrt() 92 | abs_rel_i = ((gt_i - pred_i).abs() / gt_i).mean() 93 | mae_i = (gt_i - pred_i).abs().mean() 94 | a1 += a1_i 95 | a2 += a2_i 96 | a3 += a3_i 97 | rmse += rmse_i 98 | abs_rel += abs_rel_i 99 | mae += mae_i 100 | # print(a1.item(), a2.item(), a3.item(), rmse.item(), abs_rel.item(), mae.item()) 101 | 102 | return a1, a2, a3, rmse, abs_rel, mae, num_valid -------------------------------------------------------------------------------- /CatePoseEstimation/utils/metrics_sem_seg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def intersect_and_union(pred_label, 4 | label, 5 | num_classes, 6 | ignore_index, 7 | label_map=dict(), 8 | reduce_zero_label=False): 9 | """Calculate intersection and Union. 
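Per-class areas are obtained by histogramming the label indices with bins = np.arange(num_classes + 1), so the returned intersect/union/prediction/label histograms can simply be summed across images (see total_intersect_and_union).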
10 | 11 | Args: 12 | pred_label (ndarray): Prediction segmentation map. 13 | label (ndarray): Ground truth segmentation map. 14 | num_classes (int): Number of categories. 15 | ignore_index (int): Index that will be ignored in evaluation. 16 | label_map (dict): Mapping old labels to new labels. The parameter will 17 | work only when label is str. Default: dict(). 18 | reduce_zero_label (bool): Wether ignore zero label. The parameter will 19 | work only when label is str. Default: False. 20 | 21 | Returns: 22 | ndarray: The intersection of prediction and ground truth histogram 23 | on all classes. 24 | ndarray: The union of prediction and ground truth histogram on all 25 | classes. 26 | ndarray: The prediction histogram on all classes. 27 | ndarray: The ground truth histogram on all classes. 28 | """ 29 | # modify if custom classes 30 | if label_map is not None: 31 | for old_id, new_id in label_map.items(): 32 | label[label == old_id] = new_id 33 | if reduce_zero_label: 34 | # avoid using underflow conversion 35 | label[label == 0] = 255 36 | label = label - 1 37 | label[label == 254] = 255 38 | 39 | mask = (label != ignore_index) 40 | pred_label = pred_label[mask] 41 | label = label[mask] 42 | 43 | intersect = pred_label[pred_label == label] 44 | area_intersect, _ = np.histogram( 45 | intersect, bins=np.arange(num_classes + 1)) 46 | area_pred_label, _ = np.histogram( 47 | pred_label, bins=np.arange(num_classes + 1)) 48 | area_label, _ = np.histogram(label, bins=np.arange(num_classes + 1)) 49 | area_union = area_pred_label + area_label - area_intersect 50 | 51 | return area_intersect, area_union, area_pred_label, area_label 52 | 53 | 54 | def total_intersect_and_union(results, 55 | gt_seg_maps, 56 | num_classes, 57 | ignore_index, 58 | label_map=dict(), 59 | reduce_zero_label=False): 60 | """Calculate Total Intersection and Union. 61 | 62 | Args: 63 | results (list[ndarray]): List of prediction segmentation maps. 64 | gt_seg_maps (list[ndarray]): list of ground truth segmentation maps. 65 | num_classes (int): Number of categories. 66 | ignore_index (int): Index that will be ignored in evaluation. 67 | label_map (dict): Mapping old labels to new labels. Default: dict(). 68 | reduce_zero_label (bool): Wether ignore zero label. Default: False. 69 | 70 | Returns: 71 | ndarray: The intersection of prediction and ground truth histogram 72 | on all classes. 73 | ndarray: The union of prediction and ground truth histogram on all 74 | classes. 75 | ndarray: The prediction histogram on all classes. 76 | ndarray: The ground truth histogram on all classes. 
77 | """ 78 | 79 | num_imgs = len(results) 80 | assert len(gt_seg_maps) == num_imgs 81 | total_area_intersect = np.zeros((num_classes, ), dtype=np.float) 82 | total_area_union = np.zeros((num_classes, ), dtype=np.float) 83 | total_area_pred_label = np.zeros((num_classes, ), dtype=np.float) 84 | total_area_label = np.zeros((num_classes, ), dtype=np.float) 85 | for i in range(num_imgs): 86 | area_intersect, area_union, area_pred_label, area_label = \ 87 | intersect_and_union(results[i], gt_seg_maps[i], num_classes, 88 | ignore_index, label_map, reduce_zero_label) 89 | total_area_intersect += area_intersect 90 | total_area_union += area_union 91 | total_area_pred_label += area_pred_label 92 | total_area_label += area_label 93 | return total_area_intersect, total_area_union, \ 94 | total_area_pred_label, total_area_label 95 | 96 | 97 | def mean_iou(results, 98 | gt_seg_maps, 99 | num_classes, 100 | ignore_index, 101 | nan_to_num=None, 102 | label_map=dict(), 103 | reduce_zero_label=False): 104 | """Calculate Mean Intersection and Union (mIoU) 105 | 106 | Args: 107 | results (list[ndarray]): List of prediction segmentation maps. 108 | gt_seg_maps (list[ndarray]): list of ground truth segmentation maps. 109 | num_classes (int): Number of categories. 110 | ignore_index (int): Index that will be ignored in evaluation. 111 | nan_to_num (int, optional): If specified, NaN values will be replaced 112 | by the numbers defined by the user. Default: None. 113 | label_map (dict): Mapping old labels to new labels. Default: dict(). 114 | reduce_zero_label (bool): Wether ignore zero label. Default: False. 115 | 116 | Returns: 117 | float: Overall accuracy on all images. 118 | ndarray: Per category accuracy, shape (num_classes, ). 119 | ndarray: Per category IoU, shape (num_classes, ). 120 | """ 121 | 122 | all_acc, acc, iou = eval_metrics( 123 | results=results, 124 | gt_seg_maps=gt_seg_maps, 125 | num_classes=num_classes, 126 | ignore_index=ignore_index, 127 | metrics=['mIoU'], 128 | nan_to_num=nan_to_num, 129 | label_map=label_map, 130 | reduce_zero_label=reduce_zero_label) 131 | return all_acc, acc, iou 132 | 133 | 134 | def mean_dice(results, 135 | gt_seg_maps, 136 | num_classes, 137 | ignore_index, 138 | nan_to_num=None, 139 | label_map=dict(), 140 | reduce_zero_label=False): 141 | """Calculate Mean Dice (mDice) 142 | 143 | Args: 144 | results (list[ndarray]): List of prediction segmentation maps. 145 | gt_seg_maps (list[ndarray]): list of ground truth segmentation maps. 146 | num_classes (int): Number of categories. 147 | ignore_index (int): Index that will be ignored in evaluation. 148 | nan_to_num (int, optional): If specified, NaN values will be replaced 149 | by the numbers defined by the user. Default: None. 150 | label_map (dict): Mapping old labels to new labels. Default: dict(). 151 | reduce_zero_label (bool): Wether ignore zero label. Default: False. 152 | 153 | Returns: 154 | float: Overall accuracy on all images. 155 | ndarray: Per category accuracy, shape (num_classes, ). 156 | ndarray: Per category dice, shape (num_classes, ). 
157 | """ 158 | 159 | all_acc, acc, dice = eval_metrics( 160 | results=results, 161 | gt_seg_maps=gt_seg_maps, 162 | num_classes=num_classes, 163 | ignore_index=ignore_index, 164 | metrics=['mDice'], 165 | nan_to_num=nan_to_num, 166 | label_map=label_map, 167 | reduce_zero_label=reduce_zero_label) 168 | return all_acc, acc, dice 169 | 170 | 171 | def get_metrics_sem_seg(results, 172 | gt_seg_maps, 173 | num_classes, 174 | ignore_index, 175 | metrics=['mIoU'], 176 | nan_to_num=None, 177 | label_map=dict(), 178 | reduce_zero_label=False): 179 | """Calculate evaluation metrics 180 | Args: 181 | results (list[ndarray]): List of prediction segmentation maps. 182 | gt_seg_maps (list[ndarray]): list of ground truth segmentation maps. 183 | num_classes (int): Number of categories. 184 | ignore_index (int): Index that will be ignored in evaluation. 185 | metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'. 186 | nan_to_num (int, optional): If specified, NaN values will be replaced 187 | by the numbers defined by the user. Default: None. 188 | label_map (dict): Mapping old labels to new labels. Default: dict(). 189 | reduce_zero_label (bool): Wether ignore zero label. Default: False. 190 | Returns: 191 | float: Overall accuracy on all images. 192 | ndarray: Per category accuracy, shape (num_classes, ). 193 | ndarray: Per category evalution metrics, shape (num_classes, ). 194 | """ 195 | 196 | if isinstance(metrics, str): 197 | metrics = [metrics] 198 | allowed_metrics = ['mIoU', 'mDice'] 199 | if not set(metrics).issubset(set(allowed_metrics)): 200 | raise KeyError('metrics {} is not supported'.format(metrics)) 201 | total_area_intersect, total_area_union, total_area_pred_label, \ 202 | total_area_label = total_intersect_and_union(results, gt_seg_maps, 203 | num_classes, ignore_index, 204 | label_map, 205 | reduce_zero_label) 206 | all_acc = total_area_intersect.sum() / total_area_label.sum() 207 | acc = total_area_intersect / total_area_label 208 | ret_metrics = [all_acc, acc] 209 | for metric in metrics: 210 | if metric == 'mIoU': 211 | iou = total_area_intersect / total_area_union 212 | ret_metrics.append(iou) 213 | elif metric == 'mDice': 214 | dice = 2 * total_area_intersect / ( 215 | total_area_pred_label + total_area_label) 216 | ret_metrics.append(dice) 217 | if nan_to_num is not None: 218 | ret_metrics = [ 219 | np.nan_to_num(metric, nan=nan_to_num) for metric in ret_metrics 220 | ] 221 | return ret_metrics -------------------------------------------------------------------------------- /DepthSensorSimulator/README.md: -------------------------------------------------------------------------------- 1 | # Domain Randomization-Enhanced Depth Sensor Simulator 2 | ![teaser](teaser_dreds.png) 3 | This is the official implementation of **Domain Randomization-Enhanced Depth Sensor Simulation (DREDS)** pipeline. We simulate an active stereo depth system (**RealSense D415**) using physics-based rendering and generate a large-scale synthetic data that contains **photorealistic RGB images along with their simulated depths carrying realistic sensor noises**. To facilitate generalization, we further adopt domain randomization techniques that randomize the object textures, object materials (from specular, transparent, to diffuse), object layout, floor textures, illuminations along camera poses. 
Experiments demonstrate that the data generated from this pipeline bridges the sim-to-real domain gap, and significantly improves the generalization ability on depth restoration and the downstream tasks. 4 | 5 | ## Installation 6 | ### IR rendering (Blender-python) 7 | - Download [Blender 2.93.3 (Linux x64)](https://download.blender.org/release/Blender2.93/blender-2.93.3-linux-x64.tar.xz) compressed file and uncompress. 8 | - Download the [environment map asset](https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/simulator/envmap_lib.tar.gz) and the [CAD model](https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/cad_model/). Please follow the [file structure](https://github.com/PKU-EPIC/DREDS/tree/main/DepthSensorSimulator#file-structure). 9 | - Download the [blend file](https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/simulator/material_lib_v2.blend). 10 | - Install the Python packages (Numpy, etc.) into the Blender built-in Python environment. 11 | ``` 12 | cd /home/YourUserName/blender-2.93.3-linux-x64/2.93/python/bin 13 | ./python3.9 -m ensurepip 14 | ./python3.9 -m pip install --upgrade pip --user 15 | ./python3.9 -m pip install numpy --user 16 | ``` 17 | 18 | ### Stereo matching (PyTorch) 19 | - Create a new Anaconda virtual environment 20 | ``` 21 | conda create --name DepthSensorSimulator python=3.9 22 | conda activate DepthSensorSimulator 23 | ``` 24 | - Install the packages: 25 | ``` 26 | conda install pytorch==1.10.0 -c pytorch -c conda-forge 27 | pip install -r requirement.txt 28 | ``` 29 | 30 | ## Usage 31 | - Run the shell script to start data generation, including RGB image, left and right IR images, GT depth, NOCS map, mask, and surface normal. Please set the correct Blender software package root path, working root path, GPU id, etc in run_renderer.sh and renderer.py. 32 | ``` 33 | bash run_renderer.sh 34 | ``` 35 | The rendered results would be saved in ./rendered_output. 36 | - Run the stereo matching code for left and right IR images to get the simulated depth map. 37 | ``` 38 | python stereo_matching.py 39 | ``` 40 | The simulated depths would also be saved in ./rendered_output. See [File Structure](https://github.com/PKU-EPIC/DREDS/tree/main/DepthSensorSimulator#file-structure) 41 | 42 | ## File Structure 43 | ``` 44 | Depth Sensor Simulator 45 | ├── envmap_lib 46 | │   ├── abandoned_factory_canteen_01_1k.hdr 47 | │   └── ... 48 | ├── texture 49 | │   └── texture_0.jpg 50 | ├── cad_model 51 | │   ├── 02691156 52 | │   │   ├── 1a32f10b20170883663e90eaf6b4ca52 53 | │   │   └── ... 54 | │   └── ... 55 | ├── rendered_output # rendered results 56 | │   ├── 00000 # scene id, including 30 frames 57 | │   │   ├── 0000_color.png 58 | │   │   ├── 0000_coord.png # NOCS map (category-level pose annotation) 59 | │   │   ├── 0000_depth_120.exr # GT depth 60 | │   │   ├── 0000_ir_l.png # left IR images 61 | │   │   ├── 0000_ir_r.png # right IR images 62 | │   │   ├── 0000_mask.exr 63 | │   │   ├── 0000_meta.txt # [instance id, category label, category name, instance name, object scale, material label] 64 | │   │   ├── 0000_normal.exr 65 | │   │   ├── 0000_simDepthImage.exr # simulated depth calculated from the stereo IR images 66 | │   │   └── ... 67 | │   ├── 00001 68 | │   └── ... 
69 | ├── material_lib_v2.blend 70 | ├── run_renderer.sh 71 | ├── renderer.py 72 | ├── modify_material.py 73 | └── stereo_matching.py 74 | ``` 75 | 76 | ## Todo 77 | - Detailed user instruction -------------------------------------------------------------------------------- /DepthSensorSimulator/requirement.txt: -------------------------------------------------------------------------------- 1 | pillow 2 | matplotlib 3 | kornia 4 | opencv-python -------------------------------------------------------------------------------- /DepthSensorSimulator/run_renderer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # set working root and number of scene 4 | cd /data/sensor/renderer/DepthSensorSimulator 5 | 6 | # run renderer.py 7 | # scene id: 0~2999 8 | mycount=0; 9 | while (( $mycount < 3000 )); do 10 | /home/qiyudai/blender-2.93.3-linux-x64/blender material_lib_v2.blend --background --python renderer.py -- $mycount; 11 | ((mycount=$mycount+1)); 12 | done; -------------------------------------------------------------------------------- /DepthSensorSimulator/teaser_dreds.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-EPIC/DREDS/1b0ac3091dbf0af91a48ce28d000fb4e2fe31350/DepthSensorSimulator/teaser_dreds.png -------------------------------------------------------------------------------- /DepthSensorSimulator/texture/texture_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-EPIC/DREDS/1b0ac3091dbf0af91a48ce28d000fb4e2fe31350/DepthSensorSimulator/texture/texture_0.jpg -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Domain Randomization-Enhanced Depth Simulation and Restoration for Perceiving and Grasping Specular and Transparent Objects (ECCV 2022) 2 | 3 | This is the official repository of [**Domain Randomization-Enhanced Depth Simulation and Restoration for Perceiving and Grasping Specular and Transparent Objects**](http://arxiv.org/abs/2208.03792). 4 | 5 | For more information, please visit our [**project page**](https://pku-epic.github.io/DREDS/). 6 | 7 | ## Introduction 8 | ![teaser](images/teaser.png) 9 | 10 | This paper investigates the problem of specular and transparent object depth simulation and restoration. We propose a system composed of a RGBD fusion network **SwinDRNet** for depth restoration, along with a **synthetic data generation pipeline, Domain Randomization-Enhanced Depth Simulation**, to generate the large-scale **synthetic RGBD dataset, DREDS**, that contains 130k photorealistic RGB images and simulated depths with realistic sensor noise. We also curate a **real-world dataset, STD**, that captures 30 cluttered scenes composed of 50 objects with various materials from specular, transparent, to diffuse. Experiments demonstrate that training on our simulated data, SwinDRNet can directly generalize to real RGBD images and significantly boosts the performance of perception and interaction tasks (e.g. 
**category-level pose estimation, object grasping**) 11 | 12 | ## Overview 13 | This repository provides: 14 | - [Dataset: DREDS (simulated), STD (real)](https://github.com/PKU-EPIC/DREDS#dataset) 15 | - Blender-python code and asset of [Domain randomization-enhanced depth sensor simulator](https://github.com/PKU-EPIC/DREDS/blob/main/DepthSensorSimulator) 16 | - PyTorch code and weights of [Depth restoration network SwinDRNet](https://github.com/PKU-EPIC/DREDS/blob/main/SwinDRNet) 17 | - PyTorch code and weights of [SwinDRNet baseline for category-level pose estimation](https://github.com/PKU-EPIC/DREDS/blob/main/CatePoseEstimation) 18 | 19 | ## Dataset 20 | 21 | ### DREDS dataset (simulated) 22 | - [**DREDS-CatKnown**](https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/DREDS-CatKnown/): 100,200 training and 19,380 testing RGBD images made of 1,801 objects spanning 7 categories from ShapeNetCore, with randomized specular, transparent, and diffuse materials. 23 | - [**DREDS-CatNovel**](https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/DREDS-CatNovel/): 11,520 images of 60 category-novel objects, transformed from GraspNet-1Billion (which provides CAD models and pose annotations) by changing the object materials to specular or transparent, to verify the ability to generalize to new object categories. 24 | 25 | ### STD dataset (real) 26 | - [**STD-CatKnown**](https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/STD-CatKnown/): 27,000 RGBD images of 42 category-level objects spanning 7 categories, captured from 25 different scenes with various backgrounds and illumination. 27 | - [**STD-CatNovel**](https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/STD-CatNovel/): 11,000 RGBD images of 8 category-novel objects from 5 scenes. 28 | 29 | ### CAD models 30 | We provide the [**CAD models**](https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/cad_model/) of our DREDS and STD datasets, including: 1,801 of DREDS-CatKnown (syn_train & syn_test), 42 of STD-CatKnown (real_cat_known), and 8 of STD-CatNovel (real_cat_novel). 31 | 32 | NOTE: The data is only for non-commercial use. 33 | 34 | ## Citation 35 | If you find our work useful in your research, please consider citing: 36 | 37 | ``` 38 | @inproceedings{dai2022dreds, 39 | title={Domain Randomization-Enhanced Depth Simulation and Restoration for Perceiving and Grasping Specular and Transparent Objects}, 40 | author={Dai, Qiyu and Zhang, Jiyao and Li, Qiwei and Wu, Tianhao and Dong, Hao and Liu, Ziyuan and Tan, Ping and Wang, He}, 41 | booktitle={European Conference on Computer Vision (ECCV)}, 42 | year={2022} 43 | } 44 | ``` 45 | 46 | ## License 47 | 48 | This work and the dataset are licensed under [CC BY-NC 4.0][cc-by-nc]. 
49 | 50 | [![CC BY-NC 4.0][cc-by-nc-image]][cc-by-nc] 51 | 52 | [cc-by-nc]: https://creativecommons.org/licenses/by-nc/4.0/ 53 | [cc-by-nc-image]: https://licensebuttons.net/l/by-nc/4.0/88x31.png 54 | 55 | ## Contact 56 | If you have any questions, please open a github issue or contact us: 57 | 58 | Qiyu Dai: qiyudai@pku.edu.cn, Jiyao Zhang: zhangjiyao@stu.xjtu.edu.cn, Qiwei Li: lqw@pku.edu.cn, He Wang: hewang@pku.edu.cn -------------------------------------------------------------------------------- /SwinDRNet/README.md: -------------------------------------------------------------------------------- 1 | SwinDRNet 2 | ---------- 3 | 4 | This is the official implementation of SwinDRNet, a depth restoration network proposed in _["Domain Randomization-Enhanced Depth Simulation and Restoration for Perceiving and Grasping Specular and Transparent Objects"](https://arxiv.org/abs/2208.03792)_. SwinDRNet takes inputs of a colored RGB image along with its aligned depth image and outputs a refined depth that restores the error area of the depth image and completes the invalid area caused by specular and transparent objects. The refined depth can be directly used for some downstream tasks (e.g., category-level object 6D pose estimation and robotic grasping). For more details, please see our paper and video. 5 | 6 | ![SwinDRNet](./images/SwinDRNet.png) 7 | 8 | ## System Dependencies 9 | ```bash 10 | $ sudo apt-get install libhdf5-10 libhdf5-serial-dev libhdf5-dev libhdf5-cpp-11 11 | $ sudo apt install libopenexr-dev zlib1g-dev openexr 12 | ``` 13 | ## Setup 14 | - ### Install pip dependencies 15 | We have tested on Ubuntu 20.04 with an NVIDIA GeForce RTX 2080 and NVIDIA GeForce RTX 3090 with Python 3.7. The code may work on other systems.Install the dependencies using pip: 16 | ```bash 17 | $ pip install -r requirments.txt 18 | ``` 19 | - ### Download dataset and models 20 | 21 | 1. Download the pre-trained model, our model and dataset. In the scripts below, be sure to comment out files you do not want, as they are very large. Alternatively, you can download files [manually](https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/) 22 | 23 | ```bash 24 | # Download DREDS and STD Dataset 25 | $ cd data 26 | $ bash DOWNLOAD.sh 27 | $ cd .. 28 | 29 | # Download the pretrained model 30 | $ cd pretrained_model 31 | $ bash DOWNLOAD.sh 32 | $ cd .. 33 | 34 | # Download our model 35 | $ cd models 36 | $ bash DOWNLOAD.sh 37 | $ cd .. 38 | ``` 39 | 2. Extract the downloaded dataset and merge the train split of DREDS-CatKnown following the file structure. 40 | ``` 41 | data 42 | ├── DREDS 43 | │   ├── DREDS-CatKnown 44 | │ │ ├── train 45 | │ │ │ ├── 00001 46 | │ │ │ └── ... 47 | │ │ ├── val 48 | │ │ │ ├── 01162 49 | │ │ │ └── ... 50 | │ │ └── test 51 | │ │ ├── 00000 52 | │ │ └── ... 53 | │   └── DREDS-CatNovel 54 | │ ├── 00029 55 | │ └── ... 56 | └── STD 57 |    ├── STD-CatKnown 58 | │ ├── test_0 59 | │ └── ... 60 |    └── STD-CatNovel 61 | ├── test_novel_0-1 62 | └── ... 
63 | ``` 64 | 65 | ## Training 66 | Below is an example for training a SwinDRNet model using the train split of DREDS-CatKnown dataset: 67 | ```bash 68 | # An example command for training 69 | $ python train.py --train_data_path PATH_DRED_CatKnown_TrainSplit --val_data_path PATH_DRED_CatKnown_ValSplit --val_data_type sim/real 70 | ``` 71 | ## Testing 72 | Below is an example for testing the trained SwinDRNet model: 73 | ```bash 74 | # An example command for testing 75 | $ python inference.py --train_data_path PATH_DRED_CatKnown_TrainSplit --val_data_path PATH_DRED_CatKnown_TestSplit --val_data_type sim/real 76 | ``` 77 | -------------------------------------------------------------------------------- /SwinDRNet/config.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # --------------------------------------------------------' 7 | 8 | import os 9 | import yaml 10 | from yacs.config import CfgNode as CN 11 | 12 | _C = CN() 13 | # Base config files 14 | _C.BASE = [''] 15 | 16 | # ----------------------------------------------------------------------------- 17 | # Data settings 18 | # ----------------------------------------------------------------------------- 19 | _C.DATA = CN() 20 | # Batch size for a single GPU, could be overwritten by command line argument 21 | _C.DATA.BATCH_SIZE = 128 22 | # Path to dataset, could be overwritten by command line argument 23 | _C.DATA.DATA_PATH = '' 24 | # Dataset name 25 | _C.DATA.DATASET = 'imagenet' 26 | # Input image size 27 | _C.DATA.IMG_SIZE = 224 28 | # Interpolation to resize image (random, bilinear, bicubic) 29 | _C.DATA.INTERPOLATION = 'bicubic' 30 | # Use zipped dataset instead of folder dataset 31 | # could be overwritten by command line argument 32 | _C.DATA.ZIP_MODE = False 33 | # Cache Data in Memory, could be overwritten by command line argument 34 | _C.DATA.CACHE_MODE = 'part' 35 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. 36 | _C.DATA.PIN_MEMORY = True 37 | # Number of data loading threads 38 | _C.DATA.NUM_WORKERS = 8 39 | 40 | # ----------------------------------------------------------------------------- 41 | # Model settings 42 | # ----------------------------------------------------------------------------- 43 | _C.MODEL = CN() 44 | # Model type 45 | _C.MODEL.TYPE = 'swin' 46 | # Model name 47 | _C.MODEL.NAME = 'swin_tiny_patch4_window7_224' 48 | # Checkpoint to resume, could be overwritten by command line argument 49 | _C.MODEL.PRETRAIN_CKPT = './pretrained_ckpt/swin_tiny_patch4_window7_224.pth' 50 | _C.MODEL.RESUME = '' 51 | # Number of classes, overwritten in data preparation 52 | _C.MODEL.NUM_CLASSES = 1000 53 | # Dropout rate 54 | _C.MODEL.DROP_RATE = 0.0 55 | # Drop path rate 56 | _C.MODEL.DROP_PATH_RATE = 0.1 57 | # Label Smoothing 58 | _C.MODEL.LABEL_SMOOTHING = 0.1 59 | 60 | # Swin Transformer parameters 61 | _C.MODEL.SWIN = CN() 62 | _C.MODEL.SWIN.PATCH_SIZE = 4 63 | _C.MODEL.SWIN.IN_CHANS = 3 64 | _C.MODEL.SWIN.EMBED_DIM = 96 65 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] 66 | _C.MODEL.SWIN.DECODER_DEPTHS = [2, 2, 6, 2] 67 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] 68 | _C.MODEL.SWIN.WINDOW_SIZE = 7 69 | _C.MODEL.SWIN.MLP_RATIO = 4. 
70 | _C.MODEL.SWIN.QKV_BIAS = True 71 | _C.MODEL.SWIN.QK_SCALE = None 72 | _C.MODEL.SWIN.APE = False 73 | _C.MODEL.SWIN.PATCH_NORM = True 74 | _C.MODEL.SWIN.FINAL_UPSAMPLE= "expand_first" 75 | 76 | # ----------------------------------------------------------------------------- 77 | # Training settings 78 | # ----------------------------------------------------------------------------- 79 | _C.TRAIN = CN() 80 | _C.TRAIN.START_EPOCH = 0 81 | _C.TRAIN.EPOCHS = 300 82 | _C.TRAIN.WARMUP_EPOCHS = 20 83 | _C.TRAIN.WEIGHT_DECAY = 0.05 84 | _C.TRAIN.BASE_LR = 5e-4 85 | _C.TRAIN.WARMUP_LR = 5e-7 86 | _C.TRAIN.MIN_LR = 5e-6 87 | # Clip gradient norm 88 | _C.TRAIN.CLIP_GRAD = 5.0 89 | # Auto resume from latest checkpoint 90 | _C.TRAIN.AUTO_RESUME = True 91 | # Gradient accumulation steps 92 | # could be overwritten by command line argument 93 | _C.TRAIN.ACCUMULATION_STEPS = 0 94 | # Whether to use gradient checkpointing to save memory 95 | # could be overwritten by command line argument 96 | _C.TRAIN.USE_CHECKPOINT = False 97 | 98 | # LR scheduler 99 | _C.TRAIN.LR_SCHEDULER = CN() 100 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine' 101 | # Epoch interval to decay LR, used in StepLRScheduler 102 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 103 | # LR decay rate, used in StepLRScheduler 104 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 105 | 106 | # Optimizer 107 | _C.TRAIN.OPTIMIZER = CN() 108 | _C.TRAIN.OPTIMIZER.NAME = 'adamw' 109 | # Optimizer Epsilon 110 | _C.TRAIN.OPTIMIZER.EPS = 1e-8 111 | # Optimizer Betas 112 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) 113 | # SGD momentum 114 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 115 | 116 | # ----------------------------------------------------------------------------- 117 | # Augmentation settings 118 | # ----------------------------------------------------------------------------- 119 | _C.AUG = CN() 120 | # Color jitter factor 121 | _C.AUG.COLOR_JITTER = 0.4 122 | # Use AutoAugment policy. "v0" or "original" 123 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' 124 | # Random erase prob 125 | _C.AUG.REPROB = 0.25 126 | # Random erase mode 127 | _C.AUG.REMODE = 'pixel' 128 | # Random erase count 129 | _C.AUG.RECOUNT = 1 130 | # Mixup alpha, mixup enabled if > 0 131 | _C.AUG.MIXUP = 0.8 132 | # Cutmix alpha, cutmix enabled if > 0 133 | _C.AUG.CUTMIX = 1.0 134 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set 135 | _C.AUG.CUTMIX_MINMAX = None 136 | # Probability of performing mixup or cutmix when either/both is enabled 137 | _C.AUG.MIXUP_PROB = 1.0 138 | # Probability of switching to cutmix when both mixup and cutmix enabled 139 | _C.AUG.MIXUP_SWITCH_PROB = 0.5 140 | # How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem" 141 | _C.AUG.MIXUP_MODE = 'batch' 142 | 143 | # ----------------------------------------------------------------------------- 144 | # Testing settings 145 | # ----------------------------------------------------------------------------- 146 | _C.TEST = CN() 147 | # Whether to use center crop when testing 148 | _C.TEST.CROP = True 149 | 150 | # ----------------------------------------------------------------------------- 151 | # Misc 152 | # ----------------------------------------------------------------------------- 153 | # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2') 154 | # overwritten by command line argument 155 | _C.AMP_OPT_LEVEL = '' 156 | # Path to output folder, overwritten by command line argument 157 | _C.OUTPUT = '' 158 | # Tag of experiment, overwritten by command line argument 159 | _C.TAG = 'default' 160 | # Frequency to save checkpoint 161 | _C.SAVE_FREQ = 1 162 | # Frequency to logging info 163 | _C.PRINT_FREQ = 10 164 | # Fixed random seed 165 | _C.SEED = 0 166 | # Perform evaluation only, overwritten by command line argument 167 | _C.EVAL_MODE = False 168 | # Test throughput only, overwritten by command line argument 169 | _C.THROUGHPUT_MODE = False 170 | # local rank for DistributedDataParallel, given by command line argument 171 | _C.LOCAL_RANK = 0 172 | 173 | 174 | def _update_config_from_file(config, cfg_file): 175 | config.defrost() 176 | with open(cfg_file, 'r') as f: 177 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) 178 | 179 | for cfg in yaml_cfg.setdefault('BASE', ['']): 180 | if cfg: 181 | _update_config_from_file( 182 | config, os.path.join(os.path.dirname(cfg_file), cfg) 183 | ) 184 | print('=> merge config from {}'.format(cfg_file)) 185 | config.merge_from_file(cfg_file) 186 | config.freeze() 187 | 188 | 189 | def update_config(config, args): 190 | _update_config_from_file(config, args.cfg) 191 | 192 | config.defrost() 193 | if args.opts: 194 | config.merge_from_list(args.opts) 195 | 196 | # merge from specific arguments 197 | if args.batch_size: 198 | config.DATA.BATCH_SIZE = args.batch_size 199 | if args.zip: 200 | config.DATA.ZIP_MODE = True 201 | if args.cache_mode: 202 | config.DATA.CACHE_MODE = args.cache_mode 203 | if args.resume: 204 | config.MODEL.RESUME = args.resume 205 | if args.accumulation_steps: 206 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps 207 | if args.use_checkpoint: 208 | config.TRAIN.USE_CHECKPOINT = True 209 | if args.amp_opt_level: 210 | config.AMP_OPT_LEVEL = args.amp_opt_level 211 | if args.tag: 212 | config.TAG = args.tag 213 | if args.eval: 214 | config.EVAL_MODE = True 215 | if args.throughput: 216 | config.THROUGHPUT_MODE = True 217 | 218 | config.freeze() 219 | 220 | 221 | def get_config(args): 222 | """Get a yacs CfgNode object with default values.""" 223 | # Return a clone so that the defaults will not be altered 224 | # This is for the "local variable" use pattern 225 | config = _C.clone() 226 | update_config(config, args) 227 | 228 | return config 229 | -------------------------------------------------------------------------------- /SwinDRNet/configs/swin_tiny_patch4_window7_224_lite.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_tiny_patch4_window7_224 4 | DROP_PATH_RATE: 0.2 5 | PRETRAIN_CKPT: "pretrained_model/swin_tiny_patch4_window7_224.pth" 6 | SWIN: 7 | FINAL_UPSAMPLE: "expand_first" 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 6, 2 ] 10 | DECODER_DEPTHS: [ 2, 2, 2, 1] 
11 | NUM_HEADS: [ 3, 6, 12, 24 ] 12 | WINDOW_SIZE: 7 13 | -------------------------------------------------------------------------------- /SwinDRNet/data/DOWNLOAD.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script makes it easy to download DREDS/STD dataset. 4 | # Files are at https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/ 5 | 6 | # Comment out any data you do not want. 7 | 8 | echo 'Warning: Files are *very* large. Be sure to comment out any files you do not want.' 9 | 10 | 11 | #----- DREDS Dataset ----------------------------------- 12 | mkdir DREDS 13 | cd DREDS 14 | mkdir DREDS-CatKnown 15 | cd DREDS-CatKnown 16 | 17 | wget https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/DREDS-CatKnown/test/test.gz # DREDS-CatKnown-Test (73.4G) 18 | wget https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/DREDS-CatKnown/val/val.tar.gz # DREDS-CatKnown-Val (10.5G) 19 | 20 | wget https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/DREDS-CatKnown/train/train_part0.tar.gz # DREDS-CatKnown-Train-Part0 (74.2G) 21 | wget https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/DREDS-CatKnown/train/train_part1.tar.gz # DREDS-CatKnown-Train-Part1 (73.5G) 22 | wget https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/DREDS-CatKnown/train/train_part2.tar.gz # DREDS-CatKnown-Train-Part2 (73.7G) 23 | wget https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/DREDS-CatKnown/train/train_part3.tar.gz # DREDS-CatKnown-Train-Part3 (73.4G) 24 | wget https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/DREDS-CatKnown/train/train_part4.tar.gz # DREDS-CatKnown-Train-Part4 (73.4G) 25 | 26 | cd .. 27 | 28 | wget https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/DREDS-CatNovel/DREDS-CatNovel.tar.gz # DREDS-CatNovel (45.4G) 29 | 30 | cd .. 31 | 32 | 33 | #----- STD Dataset ----------------------------------- 34 | mkdir STD 35 | cd STD 36 | 37 | wget https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/STD-CatKnown/STD-CatKnown.tar.gz # STD-CatKnown () 38 | wget https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/data/STD-CatNovel/STD-CatNovel.tar.gz # STD-CatNovel () 39 | 40 | cd .. 
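41 | 42 | # ----- (Optional) Extract and merge the DREDS-CatKnown train split ----- 43 | # A minimal sketch of the "extract and merge the train split" step from the 44 | # README. It assumes each train_part*.tar.gz unpacks into per-scene folders 45 | # (e.g. 00001/) that all belong under DREDS-CatKnown/train/; check the 46 | # archive layout first, then uncomment to use. 47 | # 48 | # cd DREDS/DREDS-CatKnown 49 | # mkdir -p train 50 | # for part in train_part*.tar.gz; do 51 | #     tar -xzf "$part" -C train 52 | # done 53 | # cd ../..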
-------------------------------------------------------------------------------- /SwinDRNet/datasets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-EPIC/DREDS/1b0ac3091dbf0af91a48ce28d000fb4e2fe31350/SwinDRNet/datasets/.DS_Store -------------------------------------------------------------------------------- /SwinDRNet/images/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-EPIC/DREDS/1b0ac3091dbf0af91a48ce28d000fb4e2fe31350/SwinDRNet/images/.DS_Store -------------------------------------------------------------------------------- /SwinDRNet/images/SwinDRNet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-EPIC/DREDS/1b0ac3091dbf0af91a48ce28d000fb4e2fe31350/SwinDRNet/images/SwinDRNet.png -------------------------------------------------------------------------------- /SwinDRNet/inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import numpy as np 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | from networks.SwinDRNet import SwinDRNet 9 | from trainer import SwinDRNetTrainer 10 | from config import get_config 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--mask_transparent', action='store_true', default=True, help='material mask') 15 | parser.add_argument('--mask_specular', action='store_true', default=True, help='material mask') 16 | parser.add_argument('--mask_diffuse', action='store_true', default=True, help='material mask') 17 | 18 | parser.add_argument('--train_data_path', type=str, 19 | default='/data/DREDS/DREDS-CatKnown/train', help='root dir for training dataset') 20 | 21 | parser.add_argument('--val_data_path', type=str, 22 | default='/data/DREDS/DREDS-CatKnown/test', help='root dir for validation dataset') 23 | parser.add_argument('--val_data_type', type=str, 24 | default='sim', help='type of val dataset (real/sim)') 25 | 26 | # parser.add_argument('--val_data_path', type=str, 27 | # default='/data/DREDS/DREDS-CatNovel', help='root dir for data') 28 | # parser.add_argument('--val_data_type', type=str, 29 | # default='sim', help='type of val dataset') 30 | 31 | # parser.add_argument('--val_data_path', type=str, 32 | # default='/data/STD/STD-CatKnown', help='root dir for data') 33 | # parser.add_argument('--val_data_type', type=str, 34 | # default='real', help='type of val dataset') 35 | 36 | # parser.add_argument('--val_data_path', type=str, 37 | # default='/data/STD/STD-CatNovel', help='root dir for data') 38 | # parser.add_argument('--val_data_type', type=str, 39 | # default='real', help='type of val dataset') 40 | 41 | 42 | parser.add_argument('--output_dir', type=str, 43 | default='results/DREDS_CatKnown', help='output dir') 44 | parser.add_argument('--checkpoint_save_path', type=str, 45 | default='models/DREDS', help='Choose a path to save checkpoints') 46 | 47 | 48 | parser.add_argument('--decode_mode', type=str, 49 | default='multi_head', help='Select encode mode') 50 | parser.add_argument('--val_interation_interval', type=int, 51 | default=5000, help='The iteration interval to perform validation') 52 | 53 | parser.add_argument('--percentageDataForTraining', type=float, 54 | default=1.0, help='The percentage of full training data for training') 55 | 
parser.add_argument('--percentageDataForVal', type=float, 56 | default=1.0, help='The percentage of full training data for training') 57 | 58 | parser.add_argument('--num_classes', type=int, 59 | default=9, help='output channel of network') 60 | parser.add_argument('--max_epochs', type=int, default=20, 61 | help='maximum epoch number to train') 62 | parser.add_argument('--batch_size', type=int, default=64, 63 | help='batch_size per gpu') 64 | parser.add_argument('--n_gpu', type=int, default=1, help='total gpu') 65 | parser.add_argument('--deterministic', type=int, default=1, 66 | help='whether use deterministic training') 67 | parser.add_argument('--base_lr', type=float, default=0.0001, 68 | help='segmentation network learning rate') 69 | parser.add_argument('--img_size', type=int, 70 | default=224, help='input patch size of network input') 71 | parser.add_argument('--seed', type=int, 72 | default=1234, help='random seed') 73 | 74 | parser.add_argument('--cfg', type=str, default="configs/swin_tiny_patch4_window7_224_lite.yaml", metavar="FILE", help='path to config file', ) 75 | parser.add_argument( 76 | "--opts", 77 | help="Modify config options by adding 'KEY VALUE' pairs. ", 78 | default=None, 79 | nargs='+', 80 | ) 81 | parser.add_argument('--zip', action='store_true', default=True, help='use zipped dataset instead of folder dataset') 82 | parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'], 83 | help='no: no cache, ' 84 | 'full: cache all data, ' 85 | 'part: sharding the dataset into nonoverlapping pieces and only cache one piece') 86 | parser.add_argument('--resume',type=str, default='./output-1/epoch_149.pth', help='resume from checkpoint') 87 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") 88 | parser.add_argument('--use-checkpoint', action='store_true', 89 | help="whether to use gradient checkpointing to save memory") 90 | parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'], 91 | help='mixed precision opt level, if O0, no amp is used') 92 | parser.add_argument('--tag', help='tag of experiment') 93 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 94 | parser.add_argument('--throughput', action='store_true', help='Test throughput only') 95 | 96 | 97 | args = parser.parse_args() 98 | config = get_config(args) 99 | 100 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 101 | device_list = [0] 102 | model_path = "trained_model/model.pth" 103 | 104 | 105 | if __name__ == "__main__": 106 | if not args.deterministic: 107 | cudnn.benchmark = True 108 | cudnn.deterministic = False 109 | else: 110 | cudnn.benchmark = False 111 | cudnn.deterministic = True 112 | 113 | random.seed(args.seed) 114 | np.random.seed(args.seed) 115 | torch.manual_seed(args.seed) 116 | torch.cuda.manual_seed(args.seed) 117 | 118 | if not os.path.exists(args.output_dir): 119 | os.makedirs(args.output_dir) 120 | 121 | net = SwinDRNet(config, img_size=args.img_size, num_classes=args.num_classes).cuda() 122 | trainer = SwinDRNetTrainer 123 | _trainer = trainer(args, net, device_list, model_path) 124 | 125 | metrics_instance, metrics_background = _trainer.inference() 126 | with open(os.path.join(args.output_dir, 'result.txt'), 'w') as f: 127 | f.write('instance metrics: ') 128 | f.write('\n') 129 | for key in metrics_instance: 130 | f.write(key) 131 | f.write(': ') 132 | f.write(str(metrics_instance[key].item())) 133 | f.write('\n') 134 | 135 | 136 | f.write('\n') 137 | 
f.write('\n') 138 | 139 | f.write('background metrics: ') 140 | f.write('\n') 141 | for key in metrics_background: 142 | f.write(key) 143 | f.write(': ') 144 | f.write(str(metrics_background[key].item())) 145 | f.write('\n') 146 | f.close() 147 | print('Done!') 148 | print('save path: ', os.path.join(args.output_dir, 'result.txt')) 149 | -------------------------------------------------------------------------------- /SwinDRNet/models/DOWNLOAD.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script makes it easy to download the trained model. 4 | # Files are at https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/checkpoint/SwinDRNet/models/ 5 | 6 | 7 | wget https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/checkpoint/SwinDRNet/models/model.pth -------------------------------------------------------------------------------- /SwinDRNet/networks/SwinDRNet.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import copy 7 | import logging 8 | import math 9 | import warnings 10 | import torch.nn.functional as F 11 | from os.path import join as pjoin 12 | 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from torch.autograd import gradcheck, Variable 17 | 18 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm 19 | from torch.nn.modules.utils import _pair 20 | from scipy import ndimage 21 | from .SwinTransformer import SwinTransformerSys 22 | 23 | from .UPerNet import UPerHead, FCNHead 24 | from .CrossAttention import CrossAttention 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | class SwinDRNet(nn.Module): 30 | """ SwinDRNet. 
31 | A PyTorch impl of SwinDRNet, a depth restoration network proposed in: 32 | `Domain Randomization-Enhanced Depth Simulation and Restoration for 33 | Perceiving and Grasping Specular and Transparent Objects' (ECCV2022) 34 | """ 35 | 36 | def __init__(self, config, img_size=224, num_classes=3): 37 | super(SwinDRNet, self).__init__() 38 | self.num_classes = num_classes 39 | self.config = config 40 | self.img_size = img_size 41 | 42 | self.backbone_rgb_branch = SwinTransformerSys(img_size=config.DATA.IMG_SIZE, 43 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 44 | in_chans=config.MODEL.SWIN.IN_CHANS, 45 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 46 | depths=config.MODEL.SWIN.DEPTHS, 47 | num_heads=config.MODEL.SWIN.NUM_HEADS, 48 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 49 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 50 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 51 | qk_scale=config.MODEL.SWIN.QK_SCALE, 52 | drop_rate=config.MODEL.DROP_RATE, 53 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 54 | ape=config.MODEL.SWIN.APE, 55 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 56 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 57 | self.backbone_xyz_branch = SwinTransformerSys(img_size=config.DATA.IMG_SIZE, 58 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 59 | in_chans=config.MODEL.SWIN.IN_CHANS, 60 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 61 | depths=config.MODEL.SWIN.DEPTHS, 62 | num_heads=config.MODEL.SWIN.NUM_HEADS, 63 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 64 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 65 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 66 | qk_scale=config.MODEL.SWIN.QK_SCALE, 67 | drop_rate=config.MODEL.DROP_RATE, 68 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 69 | ape=config.MODEL.SWIN.APE, 70 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 71 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 72 | 73 | # self.decode_head_sem_seg = UPerHead(num_classes=self.num_classes, img_size = self.img_size) 74 | # self.decode_head_coord = UPerHead(num_classes=3, img_size = self.img_size) 75 | 76 | self.decode_head_depth_restoration = UPerHead(num_classes=1, in_channels=[288, 576, 1152, 2304], img_size = self.img_size) 77 | self.decode_head_confidence = UPerHead(num_classes=2, in_channels=[288, 576, 1152, 2304], img_size = self.img_size) 78 | 79 | self.cross_attention_0 = CrossAttention(in_channel=96, depth=1, num_heads=1) 80 | self.cross_attention_1 = CrossAttention(in_channel=192, depth=1, num_heads=1) 81 | self.cross_attention_2 = CrossAttention(in_channel=384, depth=1, num_heads=1) 82 | self.cross_attention_3 = CrossAttention(in_channel=768, depth=1, num_heads=1) 83 | 84 | self.softmax = nn.Softmax(dim=1) 85 | 86 | 87 | def forward(self, rgb, depth): 88 | """Forward function.""" 89 | 90 | rgb = rgb.repeat(1,3,1,1) if rgb.size()[1] == 1 else rgb # B, C, H, W 91 | depth = depth.repeat(1,3,1,1) if depth.size()[1] == 1 else depth # replicate single-channel depth to 3 channels; B, C, H, W 92 | 93 | # depth = torch.unsqueeze(xyz[:, 2, :, :], 1) 94 | # depth = depth.repeat(1, 3, 1, 1) 95 | 96 | input_org_shape = rgb.shape[2:] 97 | rgb_feature = self.backbone_rgb_branch(rgb) 98 | depth_feature = self.backbone_xyz_branch(depth) 99 | 100 | shortcut = torch.unsqueeze(depth[:, 2, :, :], 1) 101 | 102 | # fusion 103 | x = [] 104 | out = self.cross_attention_0(tuple([rgb_feature[0], depth_feature[0]])) # [B, 96, 56, 56] 105 | x.append(torch.cat((out, rgb_feature[0], depth_feature[0]), 1)) 106 | out = self.cross_attention_1(tuple([rgb_feature[1], depth_feature[1]])) # [B, 192, 28, 28] 107 | x.append(torch.cat((out, rgb_feature[1], depth_feature[1]), 1)) 108 | out = 
self.cross_attention_2(tuple([rgb_feature[2], depth_feature[2]])) # [B, 384, 14, 14] 109 | x.append(torch.cat((out, rgb_feature[2], depth_feature[2]), 1)) 110 | out = self.cross_attention_3(tuple([rgb_feature[3], depth_feature[3]])) # [B, 768, 7, 7] 111 | x.append(torch.cat((out, rgb_feature[3], depth_feature[3]), 1)) 112 | 113 | # pred_sem_seg = self.decode_head_sem_seg(x, input_org_shape) 114 | # pred_coord = self.decode_head_coord(x, input_org_shape) 115 | pred_depth_initial = self.decode_head_depth_restoration(x, input_org_shape) 116 | confidence = self.softmax(self.decode_head_confidence(x, input_org_shape)) 117 | 118 | confidence_depth = confidence[:, 0, :, :].unsqueeze(1) 119 | confidence_initial = confidence[:, 1, :, :].unsqueeze(1) 120 | 121 | pred_depth = confidence_depth * shortcut + confidence_initial * pred_depth_initial 122 | 123 | return pred_depth, pred_depth_initial, confidence_depth, confidence_initial# , pred_sem_seg, pred_coord 124 | 125 | 126 | def init_weights(self, pretrained=None): 127 | """Initialize the weights in backbone and heads. 128 | Args: 129 | pretrained (str, optional): Path to pre-trained weights. 130 | Defaults to None. 131 | """ 132 | self.backbone_rgb_branch.init_weights(pretrained=pretrained) 133 | self.backbone_xyz_branch.init_weights(pretrained=pretrained) 134 | self.decode_head_confidence.init_weights() 135 | self.decode_head_depth_restoration.init_weights() 136 | self.cross_attention_0.init_weights() 137 | self.cross_attention_1.init_weights() 138 | self.cross_attention_2.init_weights() 139 | self.cross_attention_3.init_weights() 140 | # self.decode_head_sem_seg.init_weights() 141 | # self.decode_head_coord.init_weights() 142 | -------------------------------------------------------------------------------- /SwinDRNet/networks/UPerNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import warnings 5 | 6 | def resize(input, 7 | size=None, 8 | scale_factor=None, 9 | mode='nearest', 10 | align_corners=None, 11 | warning=True): 12 | if warning: 13 | if size is not None and align_corners: 14 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 15 | output_h, output_w = tuple(int(x) for x in size) 16 | if output_h > input_h or output_w > output_h: 17 | if ((output_h > 1 and output_w > 1 and input_h > 1 18 | and input_w > 1) and (output_h - 1) % (input_h - 1) 19 | and (output_w - 1) % (input_w - 1)): 20 | warnings.warn( 21 | f'When align_corners={align_corners}, ' 22 | 'the output would more aligned if ' 23 | f'input size {(input_h, input_w)} is `x+1` and ' 24 | f'out size {(output_h, output_w)} is `nx+1`') 25 | if isinstance(size, torch.Size): 26 | size = tuple(int(x) for x in size) 27 | return F.interpolate(input, size, scale_factor, mode, align_corners) 28 | 29 | 30 | class PPM(nn.ModuleList): 31 | """Pooling Pyramid Module used in PSPNet. 32 | 33 | Args: 34 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid 35 | Module. 36 | in_channels (int): Input channels. 37 | channels (int): Channels after modules, before conv_seg. 38 | conv_cfg (dict|None): Config of conv layers. 39 | norm_cfg (dict|None): Config of norm layers. 40 | act_cfg (dict): Config of activation layers. 41 | align_corners (bool): align_corners argument of F.interpolate. 
42 | """ 43 | 44 | def __init__(self, pool_scales, in_channels, channels, align_corners): 45 | super(PPM, self).__init__() 46 | self.pool_scales = pool_scales 47 | self.align_corners = align_corners 48 | self.in_channels = in_channels 49 | self.channels = channels 50 | for pool_scale in pool_scales: 51 | self.append( 52 | nn.Sequential( 53 | nn.AdaptiveAvgPool2d(pool_scale), 54 | ConvModule( 55 | self.in_channels, 56 | self.channels, 57 | 1))) 58 | 59 | def forward(self, x): 60 | """Forward function.""" 61 | ppm_outs = [] 62 | for ppm in self: 63 | ppm_out = ppm(x) 64 | upsampled_ppm_out = resize( 65 | ppm_out, 66 | size=x.size()[2:], 67 | mode='bilinear', 68 | align_corners=self.align_corners) 69 | ppm_outs.append(upsampled_ppm_out) 70 | return ppm_outs 71 | 72 | 73 | class ConvModule(nn.Module): 74 | def __init__(self, 75 | in_channels, 76 | out_channels, 77 | kernel_size, 78 | stride=1, 79 | padding=0, 80 | dilation=1, 81 | groups=1, 82 | bias=False, 83 | inplace=True, 84 | with_spectral_norm=False, 85 | padding_mode='zeros', 86 | order=('conv', 'norm', 'act')): 87 | super().__init__() 88 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode) 89 | self.bn = nn.BatchNorm2d(out_channels) 90 | self.activate = nn.ReLU(inplace) 91 | self.with_spectral_norm = with_spectral_norm 92 | official_padding_mode = ['zeros', 'circular'] 93 | self.with_explicit_padding = padding_mode not in official_padding_mode 94 | self.order = order 95 | assert isinstance(self.order, tuple) and len(self.order) == 3 96 | assert set(order) == set(['conv', 'norm', 'act']) 97 | 98 | def forward(self, x, activate=True, norm=True): 99 | for layer in self.order: 100 | if layer == 'conv': 101 | #if self.with_explicit_padding: 102 | # x = self.padding_layer(x) 103 | x = self.conv(x) 104 | elif layer == 'norm' and norm: 105 | x = self.bn(x) 106 | elif layer == 'act' and activate: 107 | x = self.activate(x) 108 | return x 109 | 110 | 111 | class UPerHead(nn.Module): 112 | """Unified Perceptual Parsing for Scene Understanding. 113 | 114 | This head is the implementation of `UPerNet 115 | `_. 116 | 117 | Args: 118 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid 119 | Module applied on the last feature. Default: (1, 2, 3, 6). 
120 | """ 121 | 122 | def __init__(self, pool_scales=(1, 2, 3, 6), num_classes=150, in_channels=[96, 192, 384, 768], 123 | channels=512, dropout_ratio=0.1, align_corners=False, in_index=[0, 1, 2, 3], img_size = 512): 124 | super(UPerHead, self).__init__()#(input_transform='multiple_select', **kwargs) 125 | self.in_channels = in_channels 126 | self.num_classes = num_classes 127 | self.channels = channels 128 | self.dropout_ratio = dropout_ratio 129 | self.align_corners = align_corners 130 | self.input_transform = 'multiple_select' 131 | self.in_index = in_index 132 | self.img_size = img_size 133 | # PSP Module 134 | self.psp_modules = PPM( 135 | pool_scales, 136 | self.in_channels[-1], 137 | self.channels, 138 | align_corners=self.align_corners) 139 | self.bottleneck = ConvModule( 140 | self.in_channels[-1] + len(pool_scales) * self.channels, # in_channels 141 | self.channels, # out_channels 142 | 3, # kernel_size 143 | padding=1 # kernel_size 144 | ) 145 | # FPN Module 146 | self.lateral_convs = nn.ModuleList() 147 | self.fpn_convs = nn.ModuleList() 148 | for in_channels in self.in_channels[:-1]: # skip the top layer 149 | l_conv = ConvModule( 150 | in_channels, 151 | self.channels, 152 | 1, 153 | inplace=False) 154 | fpn_conv = ConvModule( 155 | self.channels, 156 | self.channels, 157 | 3, 158 | padding=1, 159 | inplace=False) 160 | self.lateral_convs.append(l_conv) 161 | self.fpn_convs.append(fpn_conv) 162 | 163 | self.fpn_bottleneck = ConvModule( 164 | len(self.in_channels) * self.channels, 165 | self.channels, 166 | 3, 167 | padding=1) 168 | 169 | #self.fpn_bottleneck = nn.Sequential( 170 | 171 | self.conv_seg = nn.Conv2d(self.channels, self.num_classes, kernel_size=1) 172 | self.dropout = nn.Dropout2d(self.dropout_ratio) 173 | 174 | def psp_forward(self, inputs): 175 | """Forward function of PSP module.""" 176 | x = inputs[-1] 177 | psp_outs = [x] 178 | psp_outs.extend(self.psp_modules(x)) 179 | psp_outs = torch.cat(psp_outs, dim=1) 180 | output = self.bottleneck(psp_outs) 181 | return output 182 | 183 | def _transform_inputs(self, inputs): 184 | """Transform inputs for decoder. 185 | 186 | Args: 187 | inputs (list[Tensor]): List of multi-level img features. 
188 | 189 | Returns: 190 | Tensor: The transformed inputs 191 | """ 192 | if self.input_transform == 'resize_concat': 193 | inputs = [inputs[i] for i in self.in_index] 194 | upsampled_inputs = [ 195 | resize( 196 | input=x, 197 | size=inputs[0].shape[2:], 198 | mode='bilinear', 199 | align_corners=self.align_corners) for x in inputs 200 | ] 201 | inputs = torch.cat(upsampled_inputs, dim=1) 202 | elif self.input_transform == 'multiple_select': 203 | inputs = [inputs[i] for i in self.in_index] 204 | else: 205 | inputs = inputs[self.in_index] 206 | 207 | return inputs 208 | 209 | def init_weights(self): 210 | """Initialize weights of classification layer.""" 211 | normal_init(self.conv_seg, mean=0, std=0.01) 212 | 213 | def cls_seg(self, feat): 214 | """Classify each pixel.""" 215 | if self.dropout is not None: 216 | feat = self.dropout(feat) 217 | output = self.conv_seg(feat) 218 | return output 219 | 220 | def forward(self, inputs, input_org_shape): 221 | """Forward function.""" 222 | 223 | inputs = self._transform_inputs(inputs) 224 | 225 | # build laterals 226 | laterals = [ 227 | lateral_conv(inputs[i]) 228 | for i, lateral_conv in enumerate(self.lateral_convs) 229 | ] 230 | 231 | laterals.append(self.psp_forward(inputs)) 232 | 233 | # build top-down path 234 | used_backbone_levels = len(laterals) 235 | for i in range(used_backbone_levels - 1, 0, -1): 236 | prev_shape = laterals[i - 1].shape[2:] 237 | laterals[i - 1] += resize( 238 | laterals[i], 239 | size=prev_shape, 240 | mode='bilinear', 241 | align_corners=self.align_corners) 242 | 243 | # build outputs 244 | fpn_outs = [ 245 | self.fpn_convs[i](laterals[i]) 246 | for i in range(used_backbone_levels - 1) 247 | ] 248 | # append psp feature 249 | fpn_outs.append(laterals[-1]) 250 | 251 | ###################### original scheme ########################### 252 | # resize at the very end, after obtaining the final logits 253 | for i in range(used_backbone_levels - 1, 0, -1): 254 | fpn_outs[i] = resize( 255 | fpn_outs[i], 256 | size=fpn_outs[0].shape[2:], 257 | mode='bilinear', 258 | align_corners=self.align_corners) 259 | fpn_outs = torch.cat(fpn_outs, dim=1) 260 | output = self.fpn_bottleneck(fpn_outs) 261 | 262 | output = self.cls_seg(output) # [bs, 150, 128, 128] 263 | output = resize( 264 | input=output, 265 | size=(input_org_shape[0], input_org_shape[1]), 266 | mode='bilinear', 267 | align_corners=self.align_corners) 268 | ###################### original scheme ########################### 269 | 270 | ###################### scheme 2 ############################# 271 | # resize after the FPN and before cls_seg 272 | # for i in range(used_backbone_levels - 1, 0, -1): 273 | # fpn_outs[i] = resize( 274 | # fpn_outs[i], 275 | # size=fpn_outs[0].shape[2:], 276 | # mode='bilinear', 277 | # align_corners=self.align_corners) 278 | # fpn_outs = torch.cat(fpn_outs, dim=1) 279 | # output = self.fpn_bottleneck(fpn_outs) 280 | 281 | # output = resize( 282 | # input=output, 283 | # size=(input_org_shape[0], input_org_shape[1]), 284 | # mode='bilinear', 285 | # align_corners=self.align_corners) 286 | 287 | # output = self.cls_seg(output) # [bs, 150, 128, 128] 288 | ######################################################### 289 | 290 | ###################### scheme 3 ############################# 291 | # resize before the FPN 292 | # for i in range(used_backbone_levels - 1, -1, -1): 293 | # fpn_outs[i] = resize( 294 | # fpn_outs[i], 295 | # # size=fpn_outs[0].shape[2:], 296 | # size=(input_org_shape[0], input_org_shape[1]), 297 | # mode='bilinear', 298 | # align_corners=self.align_corners) 299 | # fpn_outs = torch.cat(fpn_outs, dim=1) 300 | # output = 
self.fpn_bottleneck(fpn_outs) 301 | # output = self.cls_seg(output) # [bs, 150, 128, 128] 302 | ######################################################### 303 | return output 304 | 305 | 306 | class FCNHead(nn.Module): 307 | """Fully Convolution Networks for Semantic Segmentation. 308 | 309 | This head is implemented of `FCNNet `_. 310 | 311 | Args: 312 | num_convs (int): Number of convs in the head. Default: 2. 313 | kernel_size (int): The kernel size for convs in the head. Default: 3. 314 | concat_input (bool): Whether concat the input and output of convs 315 | before classification layer. 316 | """ 317 | 318 | def __init__(self, 319 | in_channels, 320 | in_index, 321 | channels, 322 | num_classes, 323 | align_corners=False, 324 | dropout_ratio=0.5, 325 | 326 | num_convs=2, 327 | kernel_size=3, 328 | concat_input=True, 329 | img_size = 512): 330 | self.in_channels = in_channels 331 | self.in_index = in_index 332 | self.channels = channels 333 | self.num_classes = num_classes 334 | self.align_corners = align_corners 335 | self.dropout_ratio = dropout_ratio 336 | self.img_size = img_size 337 | 338 | assert num_convs >= 0 339 | self.num_convs = num_convs 340 | self.concat_input = concat_input 341 | self.kernel_size = kernel_size 342 | self.input_transform = None 343 | super(FCNHead, self).__init__() 344 | if num_convs == 0: 345 | assert self.in_channels == self.channels 346 | 347 | convs = [] 348 | convs.append( 349 | ConvModule( 350 | self.in_channels, 351 | self.channels, 352 | kernel_size=kernel_size, 353 | padding=kernel_size // 2)) 354 | for i in range(num_convs - 1): 355 | convs.append( 356 | ConvModule( 357 | self.channels, 358 | self.channels, 359 | kernel_size=kernel_size, 360 | padding=kernel_size // 2)) 361 | if num_convs == 0: 362 | self.convs = nn.Identity() 363 | else: 364 | self.convs = nn.Sequential(*convs) 365 | if self.concat_input: 366 | self.conv_cat = ConvModule( 367 | self.in_channels + self.channels, 368 | self.channels, 369 | kernel_size=kernel_size, 370 | padding=kernel_size // 2) 371 | 372 | self.conv_seg = nn.Conv2d(self.channels, self.num_classes, kernel_size=1) 373 | self.dropout = nn.Dropout2d(self.dropout_ratio) 374 | 375 | def _transform_inputs(self, inputs): 376 | """Transform inputs for decoder. 377 | 378 | Args: 379 | inputs (list[Tensor]): List of multi-level img features. 
380 | 381 | Returns: 382 | Tensor: The transformed inputs 383 | """ 384 | if self.input_transform == 'resize_concat': 385 | inputs = [inputs[i] for i in self.in_index] 386 | upsampled_inputs = [ 387 | resize( 388 | input=x, 389 | size=inputs[0].shape[2:], 390 | mode='bilinear', 391 | align_corners=self.align_corners) for x in inputs 392 | ] 393 | inputs = torch.cat(upsampled_inputs, dim=1) 394 | elif self.input_transform == 'multiple_select': 395 | inputs = [inputs[i] for i in self.in_index] 396 | else: 397 | inputs = inputs[self.in_index] 398 | 399 | return inputs 400 | 401 | def cls_seg(self, feat): 402 | """Classify each pixel.""" 403 | if self.dropout is not None: 404 | feat = self.dropout(feat) 405 | output = self.conv_seg(feat) 406 | return output 407 | 408 | # def normal_init(self, module, mean=0, std=1, bias=0): 409 | # if hasattr(module, 'weight') and module.weight is not None: 410 | # nn.init.normal_(module.weight, mean, std) 411 | # if hasattr(module, 'bias') and module.bias is not None: 412 | # nn.init.constant_(module.bias, bias) 413 | 414 | def init_weights(self): 415 | """Initialize weights of classification layer.""" 416 | normal_init(self.conv_seg, mean=0, std=0.01) 417 | 418 | def forward(self, inputs, input_org_shape): 419 | """Forward function.""" 420 | x = self._transform_inputs(inputs) 421 | output = self.convs(x) 422 | if self.concat_input: 423 | output = self.conv_cat(torch.cat([x, output], dim=1)) 424 | output = self.cls_seg(output) # [bs, 150, 32, 32] 425 | 426 | # resize back to the original image size 427 | output = resize( 428 | input=output, 429 | size=(input_org_shape[0], input_org_shape[1]), 430 | mode='bilinear', 431 | align_corners=self.align_corners) 432 | 433 | return output 434 | 435 | 436 | def normal_init(module, mean=0, std=1, bias=0): 437 | if hasattr(module, 'weight') and module.weight is not None: 438 | nn.init.normal_(module.weight, mean, std) 439 | #nn.init.constant_(module.weight, 0) 440 | if hasattr(module, 'bias') and module.bias is not None: 441 | nn.init.constant_(module.bias, bias) -------------------------------------------------------------------------------- /SwinDRNet/pretrained_model/DOWNLOAD.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script makes it easy to download the pretrained model. 
4 | # Files are at https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/checkpoint/SwinDRNet/pretrain_model/ 5 | 6 | 7 | wget https://mirrors.pku.edu.cn/dl-release/DREDS_ECCV2022/checkpoint/SwinDRNet/pretrain_model/swin_tiny_patch4_window7_224.pth -------------------------------------------------------------------------------- /SwinDRNet/requirments.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | cachetools==5.0.0 3 | certifi==2021.10.8 4 | charset-normalizer==2.0.10 5 | cycler==0.11.0 6 | einops==0.4.0 7 | fonttools==4.28.5 8 | google-auth==2.4.0 9 | google-auth-oauthlib==0.4.6 10 | grpcio==1.43.0 11 | idna==3.3 12 | imageio==2.14.0 13 | imgaug==0.4.0 14 | importlib-metadata==4.10.1 15 | kiwisolver==1.3.2 16 | Markdown==3.3.6 17 | matplotlib==3.5.1 18 | networkx==2.6.3 19 | numpy==1.21.5 20 | oauthlib==3.1.1 21 | opencv-python==4.5.5.62 22 | OpenEXR==1.3.2 23 | packaging==21.3 24 | Pillow==9.0.0 25 | protobuf==3.19.3 26 | pyasn1==0.4.8 27 | pyasn1-modules==0.2.8 28 | pyparsing==3.0.7 29 | python-dateutil==2.8.2 30 | PyWavelets==1.2.0 31 | PyYAML==6.0 32 | requests==2.27.1 33 | requests-oauthlib==1.3.0 34 | rsa==4.8 35 | scikit-image==0.19.1 36 | scipy==1.7.3 37 | Shapely==1.8.0 38 | six==1.16.0 39 | tensorboard==2.8.0 40 | tensorboard-data-server==0.6.1 41 | tensorboard-plugin-wit==1.8.1 42 | termcolor==1.1.0 43 | tifffile==2021.11.2 44 | timm==0.5.4 45 | torch==1.7.1 46 | torchaudio==0.7.2 47 | torchvision==0.8.2 48 | tqdm==4.62.3 49 | typing_extensions==4.0.1 50 | urllib3==1.26.8 51 | Werkzeug==2.0.2 52 | yacs==0.1.8 53 | zipp==3.7.0 -------------------------------------------------------------------------------- /SwinDRNet/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import numpy as np 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | from networks.SwinDRNet import SwinDRNet 9 | from trainer import SwinDRNetTrainer 10 | from config import get_config 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--mask_transparent', action='store_true', default=True, help='material mask') 14 | parser.add_argument('--mask_specular', action='store_true', default=True, help='material mask') 15 | parser.add_argument('--mask_diffuse', action='store_true', default=True, help='material mask') 16 | 17 | parser.add_argument('--train_data_path', type=str, 18 | default='/data/DREDS/DREDS-CatKnown/train', help='root dir for training dataset') 19 | parser.add_argument('--val_data_path', type=str, 20 | default='/data/DREDS/DREDS-CatKnown/val', help='root dir for validation dataset') 21 | parser.add_argument('--val_data_type', type=str, 22 | default='sim', help='type of val dataset (real/sim)') 23 | parser.add_argument('--output_dir', type=str, 24 | default='results', help='output dir') 25 | 26 | parser.add_argument('--decode_mode', type=str, 27 | default='multi_head', help='Select encode mode') 28 | parser.add_argument('--checkpoint_save_path', type=str, 29 | default='models', help='Choose a path to save checkpoints') 30 | 31 | parser.add_argument('--val_interation_interval', type=int, 32 | default=5000, help='The iteration interval to perform validation') 33 | 34 | parser.add_argument('--percentageDataForTraining', type=float, 35 | default=1.0, help='The percentage of full training data for training') 36 | parser.add_argument('--percentageDataForVal', type=float, 37 | default=1.0, help='The percentage 
of full training data for training') 38 | 39 | parser.add_argument('--num_classes', type=int, 40 | default=9, help='output channel of network') 41 | parser.add_argument('--max_epochs', type=int, default=20, 42 | help='maximum epoch number to train') 43 | parser.add_argument('--batch_size', type=int, default=8, 44 | help='batch_size per gpu') 45 | parser.add_argument('--n_gpu', type=int, default=1, help='total gpu') 46 | parser.add_argument('--deterministic', type=int, default=1, 47 | help='whether use deterministic training') 48 | parser.add_argument('--base_lr', type=float, default=0.0001, 49 | help='segmentation network learning rate') 50 | parser.add_argument('--img_size', type=int, 51 | default=224, help='input patch size of network input') 52 | parser.add_argument('--seed', type=int, 53 | default=1234, help='random seed') 54 | 55 | parser.add_argument('--cfg', type=str, default="configs/swin_tiny_patch4_window7_224_lite.yaml", metavar="FILE", help='path to config file', ) 56 | parser.add_argument( 57 | "--opts", 58 | help="Modify config options by adding 'KEY VALUE' pairs. ", 59 | default=None, 60 | nargs='+', 61 | ) 62 | parser.add_argument('--zip', action='store_true', default=True, help='use zipped dataset instead of folder dataset') 63 | parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'], 64 | help='no: no cache, ' 65 | 'full: cache all data, ' 66 | 'part: sharding the dataset into nonoverlapping pieces and only cache one piece') 67 | parser.add_argument('--resume',type=str, default='./output-1/epoch_149.pth', help='resume from checkpoint') 68 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") 69 | parser.add_argument('--use-checkpoint', action='store_true', 70 | help="whether to use gradient checkpointing to save memory") 71 | parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'], 72 | help='mixed precision opt level, if O0, no amp is used') 73 | parser.add_argument('--tag', help='tag of experiment') 74 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 75 | parser.add_argument('--throughput', action='store_true', help='Test throughput only') 76 | 77 | 78 | args = parser.parse_args() 79 | config = get_config(args) 80 | 81 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" 82 | device_list = [0] 83 | 84 | if __name__ == "__main__": 85 | if not args.deterministic: 86 | cudnn.benchmark = True 87 | cudnn.deterministic = False 88 | else: 89 | cudnn.benchmark = False 90 | cudnn.deterministic = True 91 | 92 | random.seed(args.seed) 93 | np.random.seed(args.seed) 94 | torch.manual_seed(args.seed) 95 | torch.cuda.manual_seed(args.seed) 96 | 97 | if not os.path.exists(args.output_dir): 98 | os.makedirs(args.output_dir) 99 | 100 | net = SwinDRNet(config, img_size=args.img_size, num_classes=args.num_classes).cuda() 101 | continue_ckpt_path = None 102 | 103 | if continue_ckpt_path is None: 104 | # net.load_from(config) 105 | pretrained_path = config.MODEL.PRETRAIN_CKPT 106 | net.init_weights(pretrained_path) 107 | 108 | trainer = SwinDRNetTrainer 109 | _trainer = trainer(args, net, device_list, continue_ckpt_path) 110 | _trainer.train() -------------------------------------------------------------------------------- /SwinDRNet/utils/api_utils.py: -------------------------------------------------------------------------------- 1 | '''Misc functions like functions for reading and saving EXR images using OpenEXR, saving pointclouds, etc. 
2 | ''' 3 | import struct 4 | import numpy as np 5 | import cv2 6 | import Imath 7 | import OpenEXR 8 | from PIL import Image 9 | import torch 10 | import torch.nn.functional as F 11 | # from torchvision.utils import make_grid 12 | 13 | 14 | def exr_loader(EXR_PATH, ndim=3): 15 | """Loads a .exr file as a numpy array 16 | 17 | Args: 18 | EXR_PATH: path to the exr file 19 | ndim: number of channels that should be in returned array. Valid values are 1 and 3. 20 | if ndim=1, only the 'R' channel is taken from exr file 21 | if ndim=3, the 'R', 'G' and 'B' channels are taken from exr file. 22 | The exr file must have 3 channels in this case. 23 | Returns: 24 | numpy.ndarray (dtype=np.float32): If ndim=1, shape is (height x width) 25 | If ndim=3, shape is (3 x height x width) 26 | 27 | """ 28 | 29 | exr_file = OpenEXR.InputFile(EXR_PATH) 30 | cm_dw = exr_file.header()['dataWindow'] 31 | size = (cm_dw.max.x - cm_dw.min.x + 1, cm_dw.max.y - cm_dw.min.y + 1) 32 | 33 | pt = Imath.PixelType(Imath.PixelType.FLOAT) 34 | 35 | if ndim == 3: 36 | # read channels indivudally 37 | allchannels = [] 38 | for c in ['R', 'G', 'B']: 39 | # transform data to numpy 40 | channel = np.frombuffer(exr_file.channel(c, pt), dtype=np.float32) 41 | channel.shape = (size[1], size[0]) 42 | allchannels.append(channel) 43 | 44 | # create array and transpose dimensions to match tensor style 45 | exr_arr = np.array(allchannels).transpose((0, 1, 2)) 46 | return exr_arr 47 | 48 | if ndim == 1: 49 | # transform data to numpy 50 | channel = np.frombuffer(exr_file.channel('R', pt), dtype=np.float32) 51 | channel.shape = (size[1], size[0]) # Numpy arrays are (row, col) 52 | exr_arr = np.array(channel) 53 | return exr_arr 54 | 55 | 56 | def exr_saver(EXR_PATH, ndarr, ndim=3): 57 | '''Saves a numpy array as an EXR file with HALF precision (float16) 58 | Args: 59 | EXR_PATH (str): The path to which file will be saved 60 | ndarr (ndarray): A numpy array containing img data 61 | ndim (int): The num of dimensions in the saved exr image, either 3 or 1. 62 | If ndim = 3, ndarr should be of shape (height, width) or (3 x height x width), 63 | If ndim = 1, ndarr should be of shape (height, width) 64 | Returns: 65 | None 66 | ''' 67 | if ndim == 3: 68 | # Check params 69 | if len(ndarr.shape) == 2: 70 | # If a depth image of shape (height x width) is passed, convert into shape (3 x height x width) 71 | ndarr = np.stack((ndarr, ndarr, ndarr), axis=0) 72 | 73 | if ndarr.shape[0] != 3 or len(ndarr.shape) != 3: 74 | raise ValueError( 75 | 'The shape of the tensor should be (3 x height x width) for ndim = 3. Given shape is {}'.format( 76 | ndarr.shape)) 77 | 78 | # Convert each channel to strings 79 | Rs = ndarr[0, :, :].astype(np.float16).tostring() 80 | Gs = ndarr[1, :, :].astype(np.float16).tostring() 81 | Bs = ndarr[2, :, :].astype(np.float16).tostring() 82 | 83 | # Write the three color channels to the output file 84 | HEADER = OpenEXR.Header(ndarr.shape[2], ndarr.shape[1]) 85 | half_chan = Imath.Channel(Imath.PixelType(Imath.PixelType.HALF)) 86 | HEADER['channels'] = dict([(c, half_chan) for c in "RGB"]) 87 | 88 | out = OpenEXR.OutputFile(EXR_PATH, HEADER) 89 | out.writePixels({'R': Rs, 'G': Gs, 'B': Bs}) 90 | out.close() 91 | elif ndim == 1: 92 | # Check params 93 | if len(ndarr.shape) != 2: 94 | raise ValueError(('The shape of the tensor should be (height x width) for ndim = 1. 
' + 95 | 'Given shape is {}'.format(ndarr.shape))) 96 | 97 | # Convert each channel to strings 98 | Rs = ndarr[:, :].astype(np.float16).tostring() 99 | 100 | # Write the color channel to the output file 101 | HEADER = OpenEXR.Header(ndarr.shape[1], ndarr.shape[0]) 102 | half_chan = Imath.Channel(Imath.PixelType(Imath.PixelType.HALF)) 103 | HEADER['channels'] = dict([(c, half_chan) for c in "R"]) 104 | 105 | out = OpenEXR.OutputFile(EXR_PATH, HEADER) 106 | out.writePixels({'R': Rs}) 107 | out.close() 108 | 109 | 110 | def save_uint16_png(path, image): 111 | '''save weight file - scaled png representation of outlines estimation 112 | 113 | Args: 114 | path (str): path to save the file 115 | image (numpy.ndarray): 16-bit single channel image to be saved. 116 | Shape=(H, W), dtype=np.uint16 117 | ''' 118 | assert image.dtype == np.uint16, ("data type of the array should be np.uint16." + "Got {}".format(image.dtype)) 119 | assert len(image.shape) == 2, ("Shape of input image should be (H, W)" + "Got {}".format(len(image.shape))) 120 | 121 | array_buffer = image.tobytes() 122 | img = Image.new("I", image.T.shape) 123 | img.frombytes(array_buffer, 'raw', 'I;16') 124 | img.save(path) 125 | 126 | 127 | def _normalize_depth_img(depth_img, dtype=np.uint8, min_depth=0.0, max_depth=1.0): 128 | '''Converts a floating point depth image to uint8 or uint16 image. 129 | The depth image is first scaled to (0.0, max_depth) and then scaled and converted to given datatype. 130 | 131 | Args: 132 | depth_img (numpy.float32): Depth image, value is depth in meters 133 | dtype (numpy.dtype, optional): Defaults to np.uint16. Output data type. Must be np.uint8 or np.uint16 134 | max_depth (float, optional): The max depth to be considered in the input depth image. The min depth is 135 | considered to be 0.0. 136 | Raises: 137 | ValueError: If wrong dtype is given 138 | 139 | Returns: 140 | numpy.ndarray: Depth image scaled to given dtype 141 | ''' 142 | 143 | if dtype != np.uint16 and dtype != np.uint8: 144 | raise ValueError('Unsupported dtype {}. Must be one of ("np.uint8", "np.uint16")'.format(dtype)) 145 | 146 | # Clip depth image to given range 147 | depth_img = np.ma.masked_array(depth_img, mask=(depth_img == 0.0)) 148 | depth_img = np.ma.clip(depth_img, min_depth, max_depth) 149 | 150 | # Get min/max value of given datatype 151 | type_info = np.iinfo(dtype) 152 | min_val = type_info.min 153 | max_val = type_info.max 154 | 155 | # Scale the depth image to given datatype range 156 | depth_img = ((depth_img - min_depth) / (max_depth - min_depth)) * max_val 157 | depth_img = depth_img.astype(dtype) 158 | 159 | depth_img = np.ma.filled(depth_img, fill_value=0) # Convert back to normal numpy array from masked numpy array 160 | 161 | return depth_img 162 | 163 | 164 | def depth2rgb(depth_img, min_depth=0.0, max_depth=1.5, color_mode=cv2.COLORMAP_JET, reverse_scale=False, 165 | dynamic_scaling=False): 166 | '''Generates RGB representation of a depth image. 167 | To do so, the depth image has to be normalized by specifying a min and max depth to be considered. 168 | 169 | Holes in the depth image (0.0) appear black in color. 170 | 171 | Args: 172 | depth_img (numpy.ndarray): Depth image, values in meters. Shape=(H, W), dtype=np.float32 173 | min_depth (float): Min depth to be considered 174 | max_depth (float): Max depth to be considered 175 | color_mode (int): Integer or cv2 object representing Which coloring scheme to use. 
176 | Please consult https://docs.opencv.org/master/d3/d50/group__imgproc__colormap.html 177 | 178 | Each mode is mapped to an int. Eg: cv2.COLORMAP_AUTUMN = 0. 179 | This mapping changes from version to version. 180 | reverse_scale (bool): Whether to make the largest values the smallest to reverse the color mapping 181 | dynamic_scaling (bool): If true, the depth image will be colored according to the min/max depth value within the 182 | image, rather that the passed arguments. 183 | Returns: 184 | numpy.ndarray: RGB representation of depth image. Shape=(H,W,3) 185 | ''' 186 | # Map depth image to Color Map 187 | if dynamic_scaling: 188 | depth_img_scaled = _normalize_depth_img(depth_img, dtype=np.uint8, 189 | min_depth=max(depth_img[depth_img > 0].min(), min_depth), # Add a small epsilon so that min depth does not show up as black (invalid pixels) 190 | max_depth=min(depth_img.max(), max_depth)) 191 | else: 192 | depth_img_scaled = _normalize_depth_img(depth_img, dtype=np.uint8, min_depth=min_depth, max_depth=max_depth) 193 | 194 | if reverse_scale is True: 195 | depth_img_scaled = np.ma.masked_array(depth_img_scaled, mask=(depth_img_scaled == 0.0)) 196 | depth_img_scaled = 255 - depth_img_scaled 197 | depth_img_scaled = np.ma.filled(depth_img_scaled, fill_value=0) 198 | 199 | depth_img_mapped = cv2.applyColorMap(depth_img_scaled, color_mode) 200 | depth_img_mapped = cv2.cvtColor(depth_img_mapped, cv2.COLOR_BGR2RGB) 201 | 202 | # Make holes in input depth black: 203 | depth_img_mapped[depth_img_scaled == 0, :] = 0 204 | 205 | return depth_img_mapped 206 | 207 | 208 | def scale_depth(depth_image): 209 | '''Convert depth in meters (float32) to a scaled uint16 format as required by depth2depth module. 210 | 211 | Args: 212 | depth_image (numpy.ndarray, float32): Depth Image 213 | 214 | Returns: 215 | numpy.ndarray: scaled depth image. dtype=np.uint16 216 | ''' 217 | 218 | assert depth_image.dtype == np.float32, "data type of the array should be float32. Got {}".format(depth_image.dtype) 219 | SCALING_FACTOR = 4000 220 | OUTPUT_DTYPE = np.uint16 221 | 222 | # Prevent Overflow of data by clipping depth values 223 | type_info = np.iinfo(OUTPUT_DTYPE) 224 | max_val = type_info.max 225 | depth_image = np.clip(depth_image, 0, np.floor(max_val / SCALING_FACTOR)) 226 | 227 | return (depth_image * SCALING_FACTOR).astype(OUTPUT_DTYPE) 228 | 229 | 230 | def unscale_depth(depth_image): 231 | '''Unscale the depth image from uint16 to denote the depth in meters (float32) 232 | 233 | Args: 234 | depth_image (numpy.ndarray, uint16): Depth Image 235 | 236 | Returns: 237 | numpy.ndarray: unscaled depth image. dtype=np.float32 238 | ''' 239 | 240 | assert depth_image.dtype == np.uint16, "data type of the array should be uint16. Got {}".format(depth_image.dtype) 241 | SCALING_FACTOR = 4000 242 | 243 | return depth_image.astype(np.float32) / SCALING_FACTOR 244 | 245 | 246 | def normal_to_rgb(normals_to_convert, output_dtype='float'): 247 | '''Converts a surface normals array into an RGB image. 248 | Surface normals are represented in a range of (-1,1), 249 | This is converted to a range of (0,255) for a numpy image, or a range of (0,1) to represent PIL Image. 250 | 251 | The surface normals' axes are mapped as (x,y,z) -> (R,G,B). 
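    For example, a normal (0, 0, 1) maps to (0.5, 0.5, 1.0) when output_dtype='float' and to (127, 127, 255) when output_dtype='uint8'.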
252 | 
253 |     Args:
254 |         normals_to_convert (numpy.ndarray): Surface normals, dtype float32, range [-1, 1]
255 |         output_dtype (str): format of output, possible values = ['float', 'uint8']
256 |                             if 'float', range of output (0,1)
257 |                             if 'uint8', range of output (0,255)
258 |     '''
259 |     camera_normal_rgb = (normals_to_convert + 1) / 2
260 |     if output_dtype == 'uint8':
261 |         camera_normal_rgb *= 255
262 |         camera_normal_rgb = camera_normal_rgb.astype(np.uint8)
263 |     elif output_dtype == 'float':
264 |         pass
265 |     else:
266 |         raise NotImplementedError('Possible values for "output_dtype" are only float and uint8. Received value {}'.format(output_dtype))
267 | 
268 |     return camera_normal_rgb
269 | 
270 | 
271 | def _get_point_cloud(color_image, depth_image, fx, fy, cx, cy):
272 |     """Creates point cloud from rgb images and depth image
273 | 
274 |     Args:
275 |         color_image (numpy.ndarray): Shape=[H, W, C], dtype=np.uint8
276 |         depth_image (numpy.ndarray): Shape=[H, W], dtype=np.float32. Each pixel contains depth in meters.
277 |         fx (int): The focal length along x-axis in pixels of camera used to capture image.
278 |         fy (int): The focal length along y-axis in pixels of camera used to capture image.
279 |         cx (int): The center of the image (along x-axis, pixels) as per camera used to capture image.
280 |         cy (int): The center of the image (along y-axis, pixels) as per camera used to capture image.
281 |     Returns:
282 |         numpy.ndarray: camera_points - The XYZ location of each pixel. Shape: (num of pixels, 3)
283 |         numpy.ndarray: color_points - The RGB color of each pixel. Shape: (num of pixels, 3)
284 |     """
285 |     # camera intrinsic parameters
286 |     # camera_intrinsics = [[fx 0 cx],
287 |     #                      [0 fy cy],
288 |     #                      [0 0 1]]
289 |     camera_intrinsics = np.asarray([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
290 | 
291 |     image_height = depth_image.shape[0]
292 |     image_width = depth_image.shape[1]
293 |     pixel_x, pixel_y = np.meshgrid(np.linspace(0, image_width - 1, image_width),
294 |                                    np.linspace(0, image_height - 1, image_height))
295 |     camera_points_x = np.multiply(pixel_x - camera_intrinsics[0, 2], (depth_image / camera_intrinsics[0, 0]))
296 |     camera_points_y = np.multiply(pixel_y - camera_intrinsics[1, 2], (depth_image / camera_intrinsics[1, 1]))
297 |     camera_points_z = depth_image
298 |     camera_points = np.array([camera_points_x, camera_points_y, camera_points_z]).transpose(1, 2, 0).reshape(-1, 3)
299 | 
300 |     color_points = color_image.reshape(-1, 3)
301 | 
302 |     return camera_points, color_points
303 | 
304 | 
305 | def write_point_cloud(filename, color_image, depth_image, fx, fy, cx, cy):
306 |     """Creates and writes a .ply point cloud file using RGB and Depth images.
307 | 
308 |     Args:
309 |         filename (str): The path to the file which should be written. It should end with extension '.ply'
310 |         color_image (numpy.ndarray): Shape=[H, W, C], dtype=np.uint8
311 |         depth_image (numpy.ndarray): Shape=[H, W], dtype=np.float32. Each pixel contains depth in meters.
312 |         fx (int): The focal length along x-axis in pixels of camera used to capture image.
313 |         fy (int): The focal length along y-axis in pixels of camera used to capture image.
314 |         cx (int): The center of the image (along x-axis, pixels) as per camera used to capture image.
315 |         cy (int): The center of the image (along y-axis, pixels) as per camera used to capture image.
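        Example (illustrative; the intrinsics below are placeholders, not values from this repo):
            write_point_cloud('cloud.ply', color_img, depth_img, fx=600.0, fy=600.0, cx=320.0, cy=240.0)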
316 | """ 317 | xyz_points, rgb_points = _get_point_cloud(color_image, depth_image, fx, fy, cx, cy) 318 | 319 | # Write header of .ply file 320 | with open(filename, 'wb') as fid: 321 | fid.write(bytes('ply\n', 'utf-8')) 322 | fid.write(bytes('format binary_little_endian 1.0\n', 'utf-8')) 323 | fid.write(bytes('element vertex %d\n' % xyz_points.shape[0], 'utf-8')) 324 | fid.write(bytes('property float x\n', 'utf-8')) 325 | fid.write(bytes('property float y\n', 'utf-8')) 326 | fid.write(bytes('property float z\n', 'utf-8')) 327 | fid.write(bytes('property uchar red\n', 'utf-8')) 328 | fid.write(bytes('property uchar green\n', 'utf-8')) 329 | fid.write(bytes('property uchar blue\n', 'utf-8')) 330 | fid.write(bytes('end_header\n', 'utf-8')) 331 | 332 | # Write 3D points to .ply file 333 | for i in range(xyz_points.shape[0]): 334 | fid.write( 335 | bytearray( 336 | struct.pack("fffccc", xyz_points[i, 0], xyz_points[i, 1], xyz_points[i, 2], 337 | rgb_points[i, 0].tostring(), rgb_points[i, 1].tostring(), rgb_points[i, 2].tostring()))) 338 | 339 | 340 | def imdenormalize(img, mean, std, to_bgr=False, to_rgb=False): 341 | assert img.dtype != np.uint8 342 | mean = mean.reshape(1, -1).astype(np.float64) 343 | std = std.reshape(1, -1).astype(np.float64) 344 | img = cv2.multiply(img, std) # make a copy 345 | cv2.add(img, mean, img) # inplace 346 | #if to_bgr: 347 | # cv2.cvtColor(img, cv2.COLOR_RGB2BGR, img) # inplace 348 | if to_rgb: 349 | cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace 350 | return 255-img 351 | 352 | 353 | def depth_to_xyz(depthImage, f, scale_h=1., scale_w=1.): 354 | # input depth image[B, 1, H, W] 355 | # output xyz image[B, 3, H, W] 356 | 357 | fx = f * scale_w 358 | fy = f * scale_h 359 | B, C, H, W = depthImage.shape 360 | device = depthImage.device 361 | du = W//2 - 0.5 362 | dv = H//2 - 0.5 363 | 364 | xyz = torch.zeros([B, H, W, 3], device=device) 365 | imageIndexX = torch.arange(0, W, 1, device=device) - du 366 | imageIndexY = torch.arange(0, H, 1, device=device) - dv 367 | depthImage = depthImage.squeeze() 368 | if B == 1: 369 | depthImage = depthImage.unsqueeze(0) 370 | 371 | xyz[:, :, :, 0] = depthImage/fx * imageIndexX 372 | xyz[:, :, :, 1] = (depthImage.transpose(1, 2)/fy * imageIndexY.T).transpose(1, 2) 373 | xyz[:, :, :, 2] = depthImage 374 | xyz = xyz.permute(0, 3, 1, 2).to(device) 375 | return xyz 376 | 377 | 378 | def gradient(x): 379 | # idea from tf.image.image_gradients(image) 380 | # https://github.com/tensorflow/tensorflow/blob/r2.1/tensorflow/python/ops/image_ops_impl.py#L3441-L3512 381 | # x: (b,c,h,w), float32 or float64 382 | # dx, dy: (b,c,h,w) 383 | 384 | # gradient step=1 385 | left = x 386 | right = F.pad(x, [0, 1, 0, 0])[:, :, :, 1:] 387 | top = x 388 | bottom = F.pad(x, [0, 0, 0, 1])[:, :, 1:, :] 389 | 390 | # dx, dy = torch.abs(right - left), torch.abs(bottom - top) 391 | dx, dy = right - left, bottom - top 392 | # dx will always have zeros in the last column, right-left 393 | # dy will always have zeros in the last row, bottom-top 394 | dx[:, :, :, -1] = 0 395 | dy[:, :, -1, :] = 0 396 | 397 | return dx, dy 398 | 399 | 400 | def get_surface_normal(x, f, scale_h, scale_w): 401 | xyz = depth_to_xyz(x, f, scale_h, scale_w) 402 | dx,dy = gradient(xyz) 403 | surface_normal = torch.cross(dx, dy, dim=1) 404 | surface_normal = surface_normal / (torch.norm(surface_normal,dim=1,keepdim=True)+1e-8) 405 | return surface_normal, dx, dy 406 | 407 | 408 | # def create_grid_image(inputs, outputs, labels, max_num_images_to_save=3): 409 | # '''Make a grid of 
images for display purposes 410 | # Size of grid is (3, N, 3), where each coloum belongs to input, output, label resp 411 | 412 | # Args: 413 | # inputs (Tensor): Batch Tensor of shape (B x C x H x W) 414 | # outputs (Tensor): Batch Tensor of shape (B x C x H x W) 415 | # labels (Tensor): Batch Tensor of shape (B x C x H x W) 416 | # max_num_images_to_save (int, optional): Defaults to 3. Out of the given tensors, chooses a 417 | # max number of imaged to put in grid 418 | 419 | # Returns: 420 | # numpy.ndarray: A numpy array with of input images arranged in a grid 421 | # ''' 422 | 423 | # img_tensor = inputs[:max_num_images_to_save] 424 | 425 | # output_tensor = outputs[:max_num_images_to_save] 426 | # output_tensor_rgb = normal_to_rgb(output_tensor) 427 | 428 | # label_tensor = labels[:max_num_images_to_save] 429 | # label_tensor_rgb = normal_to_rgb(label_tensor) 430 | 431 | # images = torch.cat((img_tensor, output_tensor_rgb, label_tensor_rgb), dim=3) 432 | # grid_image = make_grid(images, 1, normalize=True, scale_each=True) 433 | 434 | # return grid_image 435 | 436 | class label2color(object): 437 | def __init__(self,class_num): 438 | self.class_num = class_num 439 | 440 | self.colors = self.create_pascal_label_colormap(self.class_num) 441 | 442 | def to_color_img(self,imgs): 443 | # img:bs,3,height,width 444 | color_imgs = [] 445 | for i in range(imgs.shape[0]): 446 | score_i = imgs[i,...] 447 | score_i = score_i.cpu().numpy() 448 | score_i = np.transpose(score_i,(1,2,0)) 449 | # np.save('pre.npy',score_i) 450 | score_i = np.argmax(score_i,axis=2) 451 | color_imgs.append(self.colors[score_i]) 452 | return color_imgs 453 | 454 | def single_img_color(self,img): 455 | score_i = img 456 | score_i = score_i.numpy() 457 | score_i = np.transpose(score_i,(1,2,0)) 458 | # np.save('pre.npy',score_i) 459 | # score_i = np.argmax(score_i,axis=2) 460 | return self.colors[score_i] 461 | 462 | def bit_get(self,val, idx): 463 | """Gets the bit value. 464 | Args: 465 | val: Input value, int or numpy int array. 466 | idx: Which bit of the input val. 467 | Returns: 468 | The "idx"-th bit of input val. 469 | """ 470 | return (val >> idx) & 1 471 | 472 | def create_pascal_label_colormap(self,class_num): 473 | """Creates a label colormap used in PASCAL VOC segmentation benchmark. 474 | Returns: 475 | A colormap for visualizing segmentation results. 476 | """ 477 | colormap = np.zeros((class_num, 3), dtype=int) 478 | ind = np.arange(class_num, dtype=int) 479 | 480 | for shift in reversed(range(8)): 481 | for channel in range(3): 482 | colormap[:, channel] |= self.bit_get(ind, channel) << shift 483 | ind >>= 3 484 | 485 | return colormap 486 | -------------------------------------------------------------------------------- /SwinDRNet/utils/create_image_grid.py: -------------------------------------------------------------------------------- 1 | '''Functions for reading and saving EXR images using OpenEXR. 
2 | ''' 3 | 4 | import sys 5 | import cv2 6 | import numpy as np 7 | import torch 8 | from torchvision.utils import make_grid 9 | import torch.nn as nn 10 | from utils.api_utils import depth2rgb, label2color 11 | import imageio 12 | import os 13 | 14 | sys.path.append('../..') 15 | 16 | def seg_mask_to_rgb(seg_mask, num_classes): 17 | l2c = label2color(num_classes + 1) 18 | seg_mask_color = np.zeros((seg_mask.shape[0], 3, seg_mask.shape[2], seg_mask.shape[3])) 19 | for i in range(seg_mask.shape[0]): 20 | color = l2c.single_img_color(seg_mask[i])#.squeeze(2).transpose(2,0,1).unsqueeze(0) 21 | color = np.squeeze(color,axis=2) 22 | color = color.transpose((2,0,1)) 23 | color = color[np.newaxis,:,:,:] 24 | seg_mask_color[i] = color 25 | seg_mask_color = torch.from_numpy(seg_mask_color) 26 | return seg_mask_color 27 | 28 | def xyz_to_rgb(xyz_map): 29 | xyz_rgb = torch.ones_like(xyz_map) 30 | for i in range(xyz_rgb.shape[0]): 31 | xyz_rgb[i] = torch.div((xyz_map[i] - xyz_map[i].min()), 32 | (xyz_map[i].max() - xyz_map[i].min()).item()) 33 | return xyz_rgb 34 | 35 | def normal_to_rgb(normals_to_convert): 36 | '''Converts a surface normals array into an RGB image. 37 | Surface normals are represented in a range of (-1,1), 38 | This is converted to a range of (0,255) to be written 39 | into an image. 40 | The surface normals are normally in camera co-ords, 41 | with positive z axis coming out of the page. And the axes are 42 | mapped as (x,y,z) -> (R,G,B). 43 | 44 | Args: 45 | normals_to_convert (numpy.ndarray): Surface normals, dtype float32, range [-1, 1] 46 | ''' 47 | camera_normal_rgb = (normals_to_convert + 1) / 2 48 | return camera_normal_rgb 49 | 50 | def create_grid_image(inputs, outputs, labels, rgb, normal_pred, normal_labels, confidence_1, confidence_2, sem_masks=None, pred_sem_seg=None, coords=None, pred_coords=None, max_num_images_to_save=3): 51 | '''Make a grid of images for display purposes 52 | Size of grid is (3, N, 3), where each coloum belongs to input, output, label resp 53 | 54 | Args: 55 | inputs (Tensor): Batch Tensor of shape (B x C x H x W) 56 | outputs (Tensor): Batch Tensor of shape (B x C x H x W) 57 | labels (Tensor): Batch Tensor of shape (B x C x H x W) 58 | max_num_images_to_save (int, optional): Defaults to 3. 
Out of the given tensors, chooses a 59 | max number of imaged to put in grid 60 | 61 | Returns: 62 | numpy.ndarray: A numpy array with of input images arranged in a grid 63 | ''' 64 | 65 | rgb_tensor = rgb[:max_num_images_to_save] 66 | normal_pred_tensor = normal_pred[:max_num_images_to_save] 67 | normal_pred_tensor = normal_to_rgb(normal_pred_tensor) 68 | normal_labels_tensor = normal_labels[:max_num_images_to_save] 69 | normal_labels_tensor = normal_to_rgb(normal_labels_tensor) 70 | 71 | if not coords is None: 72 | coords_tensor = coords[:max_num_images_to_save] 73 | pred_coords_tensor = pred_coords[:max_num_images_to_save] 74 | 75 | pred_sem_seg_tensor = pred_sem_seg[:max_num_images_to_save] 76 | sem_masks_tensor = sem_masks[:max_num_images_to_save] 77 | pred_sem_seg_tensor = seg_mask_to_rgb(pred_sem_seg_tensor, 8) 78 | sem_masks_tensor = seg_mask_to_rgb(sem_masks_tensor, 8) 79 | 80 | img_tensor = inputs[:max_num_images_to_save] 81 | output_tensor = outputs[:max_num_images_to_save] 82 | label_tensor = labels[:max_num_images_to_save] 83 | confidence_1_tensor = confidence_1[:max_num_images_to_save] 84 | confidence_2_tensor = confidence_2[:max_num_images_to_save] 85 | 86 | 87 | output_tensor[output_tensor < 0] = 0 88 | output_tensor[output_tensor > 4] = 0 89 | 90 | label_tensor[label_tensor < 0] = 0 91 | label_tensor[label_tensor > 4] = 4 92 | 93 | img_tensor = xyz_to_rgb(img_tensor) 94 | 95 | 96 | output_tensor = output_tensor.repeat(1, 3, 1, 1) 97 | label_tensor = label_tensor.repeat(1, 3, 1, 1) 98 | confidence_1_tensor = confidence_1_tensor.repeat(1, 3, 1, 1) 99 | confidence_2_tensor = confidence_2_tensor.repeat(1, 3, 1, 1) 100 | 101 | if coords is None: 102 | images = torch.cat((img_tensor, confidence_1_tensor, confidence_2_tensor, output_tensor, \ 103 | label_tensor, rgb_tensor, normal_pred_tensor, normal_labels_tensor), dim=3) 104 | else: 105 | images = torch.cat((img_tensor, confidence_1_tensor, confidence_2_tensor, output_tensor, \ 106 | label_tensor, rgb_tensor, normal_pred_tensor, normal_labels_tensor, pred_coords_tensor, \ 107 | coords_tensor, pred_sem_seg_tensor, sem_masks_tensor), dim=3) 108 | 109 | # grid_image = make_grid(images, 1, normalize=True, scale_each=True) 110 | grid_image = make_grid(images, 1, normalize=False, scale_each=False) 111 | 112 | return grid_image 113 | -------------------------------------------------------------------------------- /SwinDRNet/utils/metrics_depth_restoration.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import cv2 4 | import torch.nn.functional as F 5 | 6 | def get_metrics_depth_restoration_train(gt, pred, width, height, seg_mask=None): 7 | 8 | gt = gt.detach().permute(0, 2, 3, 1).cpu().numpy().astype("float32") 9 | pred = pred.detach().permute(0, 2, 3, 1).cpu().numpy().astype("float32") 10 | if not seg_mask is None: 11 | seg_mask = seg_mask.detach().cpu().permute(0, 2, 3, 1).numpy() 12 | 13 | gt_depth = gt 14 | pred_depth = pred 15 | gt_depth[np.isnan(gt_depth)] = 0 16 | gt_depth[np.isinf(gt_depth)] = 0 17 | mask_valid_region = (gt_depth > 0) 18 | 19 | if not seg_mask is None: 20 | seg_mask = seg_mask.astype(np.uint8) 21 | mask_valid_region = np.logical_and(mask_valid_region, seg_mask) 22 | 23 | gt = torch.from_numpy(gt_depth).float().cuda() 24 | pred = torch.from_numpy(pred_depth).float().cuda() 25 | mask = torch.from_numpy(mask_valid_region).bool().cuda() 26 | gt = gt[mask] 27 | pred = pred[mask] 28 | 29 | thresh = torch.max(gt / pred, pred / gt) 30 | 
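    # Depth-restoration metrics over the valid (masked) pixels:
    # a1/a2/a3 are the fractions of pixels with max(gt/pred, pred/gt) below
    # 1.05 / 1.10 / 1.25; rmse, abs_rel and mae are the usual depth errors.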
31 | a1 = (thresh < 1.05).float().mean() 32 | a2 = (thresh < 1.10).float().mean() 33 | a3 = (thresh < 1.25).float().mean() 34 | 35 | rmse = ((gt - pred)**2).mean().sqrt() 36 | abs_rel = ((gt - pred).abs() / gt).mean() 37 | mae = (gt - pred).abs().mean() 38 | 39 | return a1, a2, a3, rmse, abs_rel, mae 40 | 41 | 42 | def get_metrics_depth_restoration_inference(gt, pred, width, height, seg_mask=None): 43 | B = gt.shape[0] 44 | gt = F.interpolate(gt, size=[width, height], mode="nearest") 45 | pred = F.interpolate(pred, size=[width, height], mode="nearest") 46 | 47 | gt = gt.detach().permute(0, 2, 3, 1).cpu().numpy().astype("float32") 48 | pred = pred.detach().permute(0, 2, 3, 1).cpu().numpy().astype("float32") 49 | if not seg_mask is None: 50 | seg_mask = seg_mask.float() 51 | seg_mask = F.interpolate(seg_mask, size=[width, height], mode="nearest") 52 | seg_mask = seg_mask.detach().cpu().permute(0, 2, 3, 1).numpy() 53 | 54 | gt_depth = gt 55 | pred_depth = pred 56 | gt_depth[np.isnan(gt_depth)] = 0 57 | gt_depth[np.isinf(gt_depth)] = 0 58 | mask_valid_region = (gt_depth > 0) 59 | 60 | if not seg_mask is None: 61 | seg_mask = seg_mask.astype(np.uint8) 62 | mask_valid_region = np.logical_and(mask_valid_region, seg_mask) 63 | 64 | gt = torch.from_numpy(gt_depth).float().cuda() 65 | pred = torch.from_numpy(pred_depth).float().cuda() 66 | mask = torch.from_numpy(mask_valid_region).bool().cuda() 67 | 68 | a1 = 0.0 69 | a2 = 0.0 70 | a3 = 0.0 71 | rmse = 0.0 72 | abs_rel = 0.0 73 | mae = 0.0 74 | 75 | num_valid = 0 76 | 77 | for i in range(B): 78 | gt_i = gt[i][mask[i]] 79 | pred_i = pred[i][mask[i]] 80 | # print(len(gt_i)) 81 | 82 | if len(gt_i) > 0: 83 | num_valid += 1 84 | thresh = torch.max(gt_i / pred_i, pred_i / gt_i) 85 | 86 | a1_i = (thresh < 1.05).float().mean() 87 | a2_i = (thresh < 1.10).float().mean() 88 | a3_i = (thresh < 1.25).float().mean() 89 | 90 | rmse_i = ((gt_i - pred_i)**2).mean().sqrt() 91 | abs_rel_i = ((gt_i - pred_i).abs() / gt_i).mean() 92 | mae_i = (gt_i - pred_i).abs().mean() 93 | a1 += a1_i 94 | a2 += a2_i 95 | a3 += a3_i 96 | rmse += rmse_i 97 | abs_rel += abs_rel_i 98 | mae += mae_i 99 | # print(a1.item(), a2.item(), a3.item(), rmse.item(), abs_rel.item(), mae.item()) 100 | 101 | return a1, a2, a3, rmse, abs_rel, mae, num_valid -------------------------------------------------------------------------------- /SwinDRNet/utils/metrics_sem_seg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def intersect_and_union(pred_label, 5 | label, 6 | num_classes, 7 | ignore_index, 8 | label_map=dict(), 9 | reduce_zero_label=False): 10 | """Calculate intersection and Union. 11 | 12 | Args: 13 | pred_label (ndarray): Prediction segmentation map. 14 | label (ndarray): Ground truth segmentation map. 15 | num_classes (int): Number of categories. 16 | ignore_index (int): Index that will be ignored in evaluation. 17 | label_map (dict): Mapping old labels to new labels. The parameter will 18 | work only when label is str. Default: dict(). 19 | reduce_zero_label (bool): Wether ignore zero label. The parameter will 20 | work only when label is str. Default: False. 21 | 22 | Returns: 23 | ndarray: The intersection of prediction and ground truth histogram 24 | on all classes. 25 | ndarray: The union of prediction and ground truth histogram on all 26 | classes. 27 | ndarray: The prediction histogram on all classes. 28 | ndarray: The ground truth histogram on all classes. 
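        Example (illustrative): with num_classes=3, each returned array has
            shape (3,), and per-class IoU downstream is area_intersect / area_union.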
29 |     """
30 |     # modify if custom classes
31 |     if label_map is not None:
32 |         for old_id, new_id in label_map.items():
33 |             label[label == old_id] = new_id
34 |     if reduce_zero_label:
35 |         # avoid using underflow conversion
36 |         label[label == 0] = 255
37 |         label = label - 1
38 |         label[label == 254] = 255
39 | 
40 |     mask = (label != ignore_index)
41 |     pred_label = pred_label[mask]
42 |     label = label[mask]
43 | 
44 |     intersect = pred_label[pred_label == label]
45 |     area_intersect, _ = np.histogram(
46 |         intersect, bins=np.arange(num_classes + 1))
47 |     area_pred_label, _ = np.histogram(
48 |         pred_label, bins=np.arange(num_classes + 1))
49 |     area_label, _ = np.histogram(label, bins=np.arange(num_classes + 1))
50 |     area_union = area_pred_label + area_label - area_intersect
51 | 
52 |     return area_intersect, area_union, area_pred_label, area_label
53 | 
54 | 
55 | def total_intersect_and_union(results,
56 |                               gt_seg_maps,
57 |                               num_classes,
58 |                               ignore_index,
59 |                               label_map=dict(),
60 |                               reduce_zero_label=False):
61 |     """Calculate Total Intersection and Union.
62 | 
63 |     Args:
64 |         results (list[ndarray]): List of prediction segmentation maps.
65 |         gt_seg_maps (list[ndarray]): List of ground truth segmentation maps.
66 |         num_classes (int): Number of categories.
67 |         ignore_index (int): Index that will be ignored in evaluation.
68 |         label_map (dict): Mapping old labels to new labels. Default: dict().
69 |         reduce_zero_label (bool): Whether to ignore the zero label. Default: False.
70 | 
71 |     Returns:
72 |         ndarray: The intersection of prediction and ground truth histogram
73 |             on all classes.
74 |         ndarray: The union of prediction and ground truth histogram on all
75 |             classes.
76 |         ndarray: The prediction histogram on all classes.
77 |         ndarray: The ground truth histogram on all classes.
78 |     """
79 | 
80 |     num_imgs = len(results)
81 |     assert len(gt_seg_maps) == num_imgs
82 |     total_area_intersect = np.zeros((num_classes, ), dtype=np.float64)
83 |     total_area_union = np.zeros((num_classes, ), dtype=np.float64)
84 |     total_area_pred_label = np.zeros((num_classes, ), dtype=np.float64)
85 |     total_area_label = np.zeros((num_classes, ), dtype=np.float64)
86 |     for i in range(num_imgs):
87 |         area_intersect, area_union, area_pred_label, area_label = \
88 |             intersect_and_union(results[i], gt_seg_maps[i], num_classes,
89 |                                 ignore_index, label_map, reduce_zero_label)
90 |         total_area_intersect += area_intersect
91 |         total_area_union += area_union
92 |         total_area_pred_label += area_pred_label
93 |         total_area_label += area_label
94 |     return total_area_intersect, total_area_union, \
95 |         total_area_pred_label, total_area_label
96 | 
97 | 
98 | def mean_iou(results,
99 |              gt_seg_maps,
100 |              num_classes,
101 |              ignore_index,
102 |              nan_to_num=None,
103 |              label_map=dict(),
104 |              reduce_zero_label=False):
105 |     """Calculate Mean Intersection and Union (mIoU)
106 | 
107 |     Args:
108 |         results (list[ndarray]): List of prediction segmentation maps.
109 |         gt_seg_maps (list[ndarray]): List of ground truth segmentation maps.
110 |         num_classes (int): Number of categories.
111 |         ignore_index (int): Index that will be ignored in evaluation.
112 |         nan_to_num (int, optional): If specified, NaN values will be replaced
113 |             by the numbers defined by the user. Default: None.
114 |         label_map (dict): Mapping old labels to new labels. Default: dict().
115 |         reduce_zero_label (bool): Whether to ignore the zero label. Default: False.
116 | 
117 |     Returns:
118 |         float: Overall accuracy on all images.
119 |         ndarray: Per category accuracy, shape (num_classes, ).
120 |         ndarray: Per category IoU, shape (num_classes, ).
121 |     """
122 | 
123 |     all_acc, acc, iou = get_metrics_sem_seg(
124 |         results=results,
125 |         gt_seg_maps=gt_seg_maps,
126 |         num_classes=num_classes,
127 |         ignore_index=ignore_index,
128 |         metrics=['mIoU'],
129 |         nan_to_num=nan_to_num,
130 |         label_map=label_map,
131 |         reduce_zero_label=reduce_zero_label)
132 |     return all_acc, acc, iou
133 | 
134 | 
135 | def mean_dice(results,
136 |               gt_seg_maps,
137 |               num_classes,
138 |               ignore_index,
139 |               nan_to_num=None,
140 |               label_map=dict(),
141 |               reduce_zero_label=False):
142 |     """Calculate Mean Dice (mDice)
143 | 
144 |     Args:
145 |         results (list[ndarray]): List of prediction segmentation maps.
146 |         gt_seg_maps (list[ndarray]): List of ground truth segmentation maps.
147 |         num_classes (int): Number of categories.
148 |         ignore_index (int): Index that will be ignored in evaluation.
149 |         nan_to_num (int, optional): If specified, NaN values will be replaced
150 |             by the numbers defined by the user. Default: None.
151 |         label_map (dict): Mapping old labels to new labels. Default: dict().
152 |         reduce_zero_label (bool): Whether to ignore the zero label. Default: False.
153 | 
154 |     Returns:
155 |         float: Overall accuracy on all images.
156 |         ndarray: Per category accuracy, shape (num_classes, ).
157 |         ndarray: Per category dice, shape (num_classes, ).
158 |     """
159 | 
160 |     all_acc, acc, dice = get_metrics_sem_seg(
161 |         results=results,
162 |         gt_seg_maps=gt_seg_maps,
163 |         num_classes=num_classes,
164 |         ignore_index=ignore_index,
165 |         metrics=['mDice'],
166 |         nan_to_num=nan_to_num,
167 |         label_map=label_map,
168 |         reduce_zero_label=reduce_zero_label)
169 |     return all_acc, acc, dice
170 | 
171 | 
172 | def get_metrics_sem_seg(results,
173 |                         gt_seg_maps,
174 |                         num_classes,
175 |                         ignore_index,
176 |                         metrics=['mIoU'],
177 |                         nan_to_num=None,
178 |                         label_map=dict(),
179 |                         reduce_zero_label=False):
180 |     """Calculate evaluation metrics
181 |     Args:
182 |         results (list[ndarray]): List of prediction segmentation maps.
183 |         gt_seg_maps (list[ndarray]): List of ground truth segmentation maps.
184 |         num_classes (int): Number of categories.
185 |         ignore_index (int): Index that will be ignored in evaluation.
186 |         metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
187 |         nan_to_num (int, optional): If specified, NaN values will be replaced
188 |             by the numbers defined by the user. Default: None.
189 |         label_map (dict): Mapping old labels to new labels. Default: dict().
190 |         reduce_zero_label (bool): Whether to ignore the zero label. Default: False.
191 |     Returns:
192 |         float: Overall accuracy on all images.
193 |         ndarray: Per category accuracy, shape (num_classes, ).
194 |         ndarray: Per category evaluation metrics, shape (num_classes, ).
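        Example (illustrative call; ignore_index=255 is a placeholder):
            all_acc, acc, iou = get_metrics_sem_seg(results, gt_seg_maps,
                num_classes, ignore_index=255, metrics=['mIoU'])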
195 | """ 196 | 197 | if isinstance(metrics, str): 198 | metrics = [metrics] 199 | allowed_metrics = ['mIoU', 'mDice'] 200 | if not set(metrics).issubset(set(allowed_metrics)): 201 | raise KeyError('metrics {} is not supported'.format(metrics)) 202 | total_area_intersect, total_area_union, total_area_pred_label, \ 203 | total_area_label = total_intersect_and_union(results, gt_seg_maps, 204 | num_classes, ignore_index, 205 | label_map, 206 | reduce_zero_label) 207 | all_acc = total_area_intersect.sum() / total_area_label.sum() 208 | acc = total_area_intersect / total_area_label 209 | ret_metrics = [all_acc, acc] 210 | for metric in metrics: 211 | if metric == 'mIoU': 212 | iou = total_area_intersect / total_area_union 213 | ret_metrics.append(iou) 214 | elif metric == 'mDice': 215 | dice = 2 * total_area_intersect / ( 216 | total_area_pred_label + total_area_label) 217 | ret_metrics.append(dice) 218 | if nan_to_num is not None: 219 | ret_metrics = [ 220 | np.nan_to_num(metric, nan=nan_to_num) for metric in ret_metrics 221 | ] 222 | return ret_metrics -------------------------------------------------------------------------------- /images/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-EPIC/DREDS/1b0ac3091dbf0af91a48ce28d000fb4e2fe31350/images/teaser.png --------------------------------------------------------------------------------