├── SSR ├── datasets │ ├── __init__.py │ ├── replica │ │ ├── __init__.py │ │ └── replica_datasets.py │ ├── scannet │ │ ├── __init__.py │ │ ├── scannet_reader.py │ │ ├── scannet_utils.py │ │ └── scannet_datasets.py │ └── replica_nyu │ │ ├── __init__.py │ │ └── replica_nyu_cnn_datasets.py ├── geometry │ ├── __init__.py │ └── occupancy.py ├── utils │ ├── __init__.py │ ├── ndc_derivation.pdf │ └── image_utils.py ├── visualisation │ ├── __init__.py │ ├── tensorboard_vis.py │ └── open3d_utils.py ├── training │ ├── __init__.py │ └── training_utils.py ├── __init__.py ├── models │ ├── __init__.py │ ├── model_utils.py │ ├── semantic_nerf.py │ └── rays.py ├── data_generation │ ├── replica_render_config_vMAP.yaml │ ├── README.md │ ├── extract_inst_obj.py │ ├── transformation.py │ ├── settings.py │ └── habitat_renderer.py ├── configs │ ├── SSR_room0_config.yaml │ └── SSR_ScanNet_config.yaml └── extract_colour_mesh.py ├── imgs ├── teaser.png └── sem_mesh_room0.png ├── requirements.txt ├── .gitignore ├── README.md ├── LICENSE └── train_SSR_main.py /SSR/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SSR/datasets/replica/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SSR/datasets/scannet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SSR/datasets/replica_nyu/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SSR/geometry/__init__.py: -------------------------------------------------------------------------------- 1 | from . import occupancy -------------------------------------------------------------------------------- /SSR/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import image_utils -------------------------------------------------------------------------------- /SSR/visualisation/__init__.py: -------------------------------------------------------------------------------- 1 | from . import open3d_utils 2 | -------------------------------------------------------------------------------- /SSR/training/__init__.py: -------------------------------------------------------------------------------- 1 | from . import trainer 2 | from . 
import training_utils -------------------------------------------------------------------------------- /imgs/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Harry-Zhi/semantic_nerf/HEAD/imgs/teaser.png -------------------------------------------------------------------------------- /imgs/sem_mesh_room0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Harry-Zhi/semantic_nerf/HEAD/imgs/sem_mesh_room0.png -------------------------------------------------------------------------------- /SSR/utils/ndc_derivation.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Harry-Zhi/semantic_nerf/HEAD/SSR/utils/ndc_derivation.pdf -------------------------------------------------------------------------------- /SSR/__init__.py: -------------------------------------------------------------------------------- 1 | from . import configs 2 | from . import datasets 3 | from . import geometry 4 | from . import models 5 | from . import training 6 | from . import utils 7 | from . import visualisation 8 | __version__ = "0.0.1" -------------------------------------------------------------------------------- /SSR/models/__init__.py: -------------------------------------------------------------------------------- 1 | # from . import iMAP_model_utils 2 | # from . import iMAP_nerf 3 | # from . import semantic_nerf 4 | # from . import model_utils 5 | # from . import rays 6 | 7 | # we do not pre-load iMAP_model_utils and model_utils here since they contain functions with the same name and will cause conflicts -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.6.0 2 | torchvision==0.7.0 3 | tensorboard==2.4.1 4 | imageio==2.9.0 5 | imageio-ffmpeg==0.4.2 6 | matplotlib==3.3.2 7 | scikit-image==0.17.2 8 | scikit-learn==0.23.2 9 | tqdm==4.54.1 10 | tensorboard==2.4.1 11 | pyyaml==5.3.1 12 | trimesh==3.9.9 13 | imgviz==1.2.2 14 | open3d==0.12.0 15 | opencv-python==4.4.0.44 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/.ipynb_checkpoints 2 | **/__pycache__ 3 | *.mp4 4 | *.npy 5 | *.npz 6 | *.dae 7 | data/* 8 | logs/* 9 | 10 | .idea/ 11 | .anaconda3/ 12 | SSR/data/ 13 | SSR/results/ 14 | # Compiled python modules. 15 | *.pyc 16 | 17 | # Setuptools distribution folder. 18 | /dist/ 19 | 20 | # vim 21 | **/*.swp 22 | 23 | # vscode 24 | .vscode/ 25 | ../.vscode/ 26 | 27 | # Python egg metadata, regenerated from source files by setuptools. 
28 | /*.egg-info 29 | 30 | *.json 31 | 32 | SSR/configs/SSR_room0_config_test.yaml 33 | -------------------------------------------------------------------------------- /SSR/visualisation/tensorboard_vis.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | import os 3 | import yaml 4 | 5 | class TFVisualizer(object): 6 | def __init__(self, log_dir, vis_interval, config): 7 | self.tb_writer = SummaryWriter(log_dir=os.path.join(log_dir)) 8 | self.vis_interval = vis_interval 9 | self.config = config 10 | 11 | # dump args to tensorboard 12 | args_str = '{}'.format(yaml.dump(config, sort_keys=False, indent=4)) 13 | self.tb_writer.add_text('Exp_args', args_str, 0) 14 | 15 | def vis_scalars(self, i_iter, losses, names): 16 | for i, loss in enumerate(losses): 17 | self.tb_writer.add_scalar(names[i], loss, i_iter) 18 | 19 | 20 | def vis_histogram(self, i_iter, value, names): 21 | self.tb_writer.add_histogram(tag=names, values=value, global_step=i_iter) 22 | -------------------------------------------------------------------------------- /SSR/data_generation/replica_render_config_vMAP.yaml: -------------------------------------------------------------------------------- 1 | # Agent settings 2 | default_agent: 0 3 | gpu_id: 0 4 | width: 1200 #1280 5 | height: 680 #960 6 | sensor_height: 0 7 | 8 | color_sensor: true # RGB sensor 9 | semantic_sensor: true # Semantic sensor 10 | depth_sensor: true # Depth sensor 11 | enable_semantics: true 12 | 13 | # room_0 14 | scene_file: "/home/xin/data/vmap/room_0/habitat/mesh_semantic.ply" 15 | instance2class_mapping: "/home/xin/data/vmap/room_0/habitat/info_semantic.json" 16 | save_path: "/home/xin/data/vmap/room_0/vmap/00/" 17 | pose_file: "/home/xin/data/vmap/room_0/vmap/00/traj_w_c.txt" 18 | ## HDR texture 19 | ## issue https://github.com/facebookresearch/Replica-Dataset/issues/41#issuecomment-566251467 20 | #scene_file: "/home/xin/data/vmap/room_0/mesh.ply" 21 | #instance2class_mapping: "/home/xin/data/vmap/room_0/habitat/info_semantic.json" 22 | #save_path: "/home/xin/data/vmap/room_0/vmap/00/" 23 | #pose_file: "/home/xin/data/vmap/room_0/vmap/00/traj_w_c.txt" -------------------------------------------------------------------------------- /SSR/data_generation/README.md: -------------------------------------------------------------------------------- 1 | ## Replica Data Generation 2 | 3 | ### Download Replica Dataset 4 | Download 3D models and info files from [Replica](https://github.com/facebookresearch/Replica-Dataset) 5 | 6 | ### 3D Object Mesh Extraction 7 | Change the input path in `./data_generation/extract_inst_obj.py` and run 8 | ```angular2html 9 | python ./data_generation/extract_inst_obj.py 10 | ``` 11 | 12 | ### Camera Trajectory Generation 13 | Please refer to [Semantic-NeRF](https://github.com/Harry-Zhi/semantic_nerf/issues/25#issuecomment-1340595427) for more details. The random trajectory generation only works for single room scene. For multiple rooms scene, collision checking is needed. Welcome contributions. 14 | 15 | ### Rendering 2D Images 16 | Given camera trajectory t_wc (change pose_file in configs), we use [Habitat-Sim](https://github.com/facebookresearch/habitat-sim) to render RGB, Depth, Semantic and Instance images. 17 | 18 | #### Install Habitat-Sim 0.2.1 19 | We recommend to use conda to install habitat-sim 0.2.1. 
20 | ```angular2html 21 | conda create -n habitat python=3.8.12 cmake=3.14.0 22 | conda activate habitat 23 | conda install habitat-sim=0.2.1 withbullet -c conda-forge -c aihabitat 24 | conda install numba=0.54.1 25 | ``` 26 | 27 | #### Run rendering with configs 28 | ```angular2html 29 | python ./data_generation/habitat_renderer.py --config ./data_generation/replica_render_config_vMAP.yaml 30 | ``` 31 | Note that to get HDR img, use mesh.ply not semantic_mesh.ply. Change path in configs. And copy rgb folder to replace previous high exposure rgb. 32 | ```angular2html 33 | python ./data_generation/habitat_renderer.py --config ./data_generation/replica_render_config_vMAP.yaml 34 | ``` -------------------------------------------------------------------------------- /SSR/data_generation/extract_inst_obj.py: -------------------------------------------------------------------------------- 1 | # reference https://github.com/facebookresearch/Replica-Dataset/issues/17#issuecomment-538757418 2 | 3 | from plyfile import * 4 | import numpy as np 5 | import trimesh 6 | 7 | 8 | # path_in = 'path/to/mesh_semantic.ply' 9 | path_in = '/home/xin/data/vmap/room_0_debug/habitat/mesh_semantic.ply' 10 | 11 | print("Reading input...") 12 | mesh = trimesh.load(path_in) 13 | # mesh.show() 14 | file_in = PlyData.read(path_in) 15 | vertices_in = file_in.elements[0] 16 | faces_in = file_in.elements[1] 17 | 18 | print("Filtering data...") 19 | objects = {} 20 | sub_mesh_indices = {} 21 | for i, f in enumerate(faces_in): 22 | object_id = f[1] 23 | if not object_id in objects: 24 | objects[object_id] = [] 25 | sub_mesh_indices[object_id] = [] 26 | objects[object_id].append((f[0],)) 27 | sub_mesh_indices[object_id].append(i) 28 | sub_mesh_indices[object_id].append(i+faces_in.data.shape[0]) 29 | 30 | 31 | print("Writing data...") 32 | for object_id, faces in objects.items(): 33 | path_out = path_in + f"_{object_id}.ply" 34 | # print("sub_mesh_indices[object_id] ", sub_mesh_indices[object_id]) 35 | obj_mesh = mesh.submesh([sub_mesh_indices[object_id]], append=True) 36 | in_n = len(sub_mesh_indices[object_id]) 37 | out_n = obj_mesh.faces.shape[0] 38 | # print("obj id ", object_id) 39 | # print("in_n ", in_n) 40 | # print("out_n ", out_n) 41 | # print("faces ", len(faces)) 42 | # assert in_n == out_n 43 | obj_mesh.export(path_out) 44 | # faces_out = PlyElement.describe(np.array(faces, dtype=[('vertex_indices', 'O')]), 'face') 45 | # print("faces out ", len(PlyData([vertices_in, faces_out]).elements[1].data)) 46 | # PlyData([vertices_in, faces_out]).write(path_out+"_cmp.ply") 47 | 48 | -------------------------------------------------------------------------------- /SSR/data_generation/transformation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import quaternion 3 | import trimesh 4 | 5 | def habitat_world_transformations(): 6 | import habitat_sim 7 | # Transforms between the habitat frame H (y-up) and the world frame W (z-up). 
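    # In Habitat's y-up convention gravity (habitat_sim.geo.GRAVITY) points along -y, while
    # the z-up world frame used here has gravity along -z. The rotation below maps Habitat's
    # gravity direction onto world -z, so T_wh takes Habitat-frame points to world-frame
    # points (p_w = T_wh @ p_h) and T_hw is its inverse.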
8 | T_wh = np.identity(4) 9 | 10 | # https://stackoverflow.com/questions/1171849/finding-quaternion-representing-the-rotation-from-one-vector-to-another 11 | T_wh[0:3, 0:3] = quaternion.as_rotation_matrix(habitat_sim.utils.common.quat_from_two_vectors( 12 | habitat_sim.geo.GRAVITY, np.array([0.0, 0.0, -1.0]))) 13 | 14 | T_hw = np.linalg.inv(T_wh) 15 | 16 | return T_wh, T_hw 17 | 18 | def opencv_to_opengl_camera(transform=None): 19 | if transform is None: 20 | transform = np.eye(4) 21 | return transform @ trimesh.transformations.rotation_matrix( 22 | np.deg2rad(180), [1, 0, 0] 23 | ) 24 | 25 | def opengl_to_opencv_camera(transform=None): 26 | if transform is None: 27 | transform = np.eye(4) 28 | return transform @ trimesh.transformations.rotation_matrix( 29 | np.deg2rad(-180), [1, 0, 0] 30 | ) 31 | 32 | def Twc_to_Thc(T_wc): # opencv-camera to world transformation ---> habitat-caemra to habitat world transformation 33 | T_wh, T_hw = habitat_world_transformations() 34 | T_hc = T_hw @ T_wc @ opengl_to_opencv_camera() 35 | return T_hc 36 | 37 | 38 | def Thc_to_Twc(T_hc): # habitat-caemra to habitat world transformation ---> opencv-camera to world transformation 39 | T_wh, T_hw = habitat_world_transformations() 40 | T_wc = T_wh @ T_hc @ opencv_to_opengl_camera() 41 | return T_wc 42 | 43 | 44 | def combine_pose(t: np.array, q: quaternion.quaternion) -> np.array: 45 | T = np.identity(4) 46 | T[0:3, 3] = t 47 | T[0:3, 0:3] = quaternion.as_rotation_matrix(q) 48 | return T -------------------------------------------------------------------------------- /SSR/configs/SSR_room0_config.yaml: -------------------------------------------------------------------------------- 1 | 2 | experiment: 3 | scene_file: "PATHtoREPLICA/Replica/mesh/room_0/habitat/" # room_0,room_01, etc. 4 | save_dir: "PATHtoLOGS" # where to store ckpts and rendering 5 | dataset_dir: "PATHtoRENDERED_REPLICA_DATA" 6 | convention: "opencv" 7 | width: 320 8 | height: 240 9 | gpu: "0" 10 | 11 | enable_semantic: True 12 | enable_depth: True 13 | endpoint_feat: False 14 | 15 | model: 16 | netdepth: 8 17 | netwidth: 256 18 | netdepth_fine: 8 19 | netwidth_fine: 256 20 | chunk: 1024*128 # number of rays processed in parallel, decrease if running out of memory 21 | netchunk: 1024*128 # number of pts sent through network in parallel, decrease if running out of memory 22 | 23 | render: 24 | N_rays: 32*32*1 # average number of rays sampled from each sample within a batch 25 | N_samples: 64 # Number of different times to sample along each ray. 
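  # Note: these are the coarse (stratified) samples per ray; N_importance below adds fine
  # samples drawn by importance sampling of the coarse weights (NeRF's hierarchical
  # sampling), so the fine pass typically evaluates N_samples + N_importance points per ray.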
26 | N_importance: 128 # Number of additional fine samples per ray 27 | perturb: 1 28 | use_viewdirs: true 29 | i_embed: 0 # 'set 0 for default positional encoding, -1 for none' 30 | multires: 10 # log2 of max freq for positional encoding (3D location)' 31 | multires_views: 4 # 'log2 of max freq for positional encoding (2D direction)' 32 | raw_noise_std: 1 # 'std dev of noise added to regularize sigma_a output, 1e0 recommended') 33 | test_viz_factor: 1 # down scaling factor when rendering test and training images 34 | no_batching: True # True-sample random pixels from random images; False-sample from all random pixels from all images 35 | depth_range: [0.1, 10.0] 36 | white_bkgd: false # set to render synthetic data on a white bkgd (always use for dvoxels) 37 | 38 | train: 39 | lrate: 5e-4 40 | lrate_decay: 250e3 41 | N_iters: 200000 42 | wgt_sem: 4e-2 43 | 44 | 45 | 46 | logging: # logging/saving options 47 | step_log_print: 1 # 'frequency of console print' 48 | step_log_tfb: 500 49 | step_save_ckpt: 20000 50 | step_val: 5000 # frequency of rendering on unseen data 51 | step_vis_train: 5000 -------------------------------------------------------------------------------- /SSR/datasets/scannet/scannet_reader.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/ScanNet/ScanNet/blob/master/SensReader/python/reader.py 2 | # github: https://github.com/ScanNet/ScanNet/tree/master/SensReader/python 3 | # python 2.7 is recommended. 4 | 5 | 6 | import argparse 7 | import os, sys 8 | 9 | import os 10 | import numpy as np 11 | import argparse 12 | import random 13 | from tqdm import tqdm 14 | 15 | from SensorData import SensorData 16 | 17 | def parse_raw_data(output_path, data_filename): 18 | if not os.path.exists(output_path): 19 | os.makedirs(output_path) 20 | # load the data 21 | sys.stdout.write('loading %s...' 
% data_filename) 22 | sd = SensorData(data_filename) 23 | sys.stdout.write('loaded!\n') 24 | if opt.export_depth_images: 25 | sd.export_depth_images(os.path.join(output_path, 'depth')) 26 | if opt.export_color_images: 27 | sd.export_color_images(os.path.join(output_path, 'color')) 28 | if opt.export_poses: 29 | sd.export_poses(os.path.join(output_path, 'pose')) 30 | if opt.export_intrinsics: 31 | sd.export_intrinsics(os.path.join(output_path, 'intrinsic')) 32 | 33 | 34 | # params 35 | parser = argparse.ArgumentParser() 36 | # data paths 37 | parser.add_argument('--export_depth_images', dest='export_depth_images', action='store_true') 38 | parser.add_argument('--export_color_images', dest='export_color_images', action='store_true') 39 | parser.add_argument('--export_poses', dest='export_poses', action='store_true') 40 | parser.add_argument('--export_intrinsics', dest='export_intrinsics', action='store_true') 41 | parser.set_defaults(export_depth_images=True, export_color_images=True, export_poses=True, export_intrinsics=True) 42 | 43 | 44 | opt = parser.parse_args() 45 | print(opt) 46 | 47 | 48 | data_dir = "PATH_TO_SCANNET/ScanNet/scans_val/" # path to list of scannet scenes 49 | val_seqs = os.listdir(data_dir) 50 | with open("PATH_TO_SCANNET/ScanNet/tasks/scannetv2_val.txt") as f: 51 | val_seq_ids = f.readlines() 52 | val_seq_ids = [s.strip() for s in val_seq_ids] 53 | 54 | for i in tqdm(range(len(val_seqs))): 55 | val_id = val_seqs[i] 56 | val_seq_dir = os.path.join(data_dir, val_id, "renders") 57 | raw_data_filename = os.path.join(data_dir, val_id, val_id+".sens") 58 | parse_raw_data(val_seq_dir, raw_data_filename) 59 | 60 | if __name__ == '__main__': 61 | main() -------------------------------------------------------------------------------- /SSR/configs/SSR_ScanNet_config.yaml: -------------------------------------------------------------------------------- 1 | 2 | experiment: 3 | save_dir: "PATHtoLOGS" # where to store ckpts and rendering 4 | dataset_dir: "PATHtoScanNet_Scene_Folder" # e.g., "xxx/ScanNet/scans/scene0010_00" 5 | 6 | # All parsed scannet images per scenes are arranged into a unified folder called "renders" using scannet_reader.py, 7 | # where the subfolders "pose", "color", "depth" contains the corresponding data. 8 | # e.g., "xxx/ScanNet/scans/scene0010_00/renders/color/00001.jpg" 9 | 10 | sample_step: 100 # this is the sampling interval of the whole ScanNet sequence to determine the overall amount of training/testing images. 11 | convention: "opencv" 12 | width: 320 13 | height: 240 14 | gpu: "0" 15 | 16 | enable_semantic: True 17 | enable_depth: True 18 | endpoint_feat: False 19 | 20 | model: 21 | netdepth: 8 22 | netwidth: 256 23 | netdepth_fine: 8 24 | netwidth_fine: 256 25 | chunk: 1024*128 # number of rays processed in parallel, decrease if running out of memory 26 | netchunk: 1024*128 # number of pts sent through network in parallel, decrease if running out of memory 27 | 28 | render: 29 | N_rays: 32*32*1 # average number of rays sampled from each sample within a batch 30 | N_samples: 64 # Number of different times to sample along each ray. 
31 | N_importance: 128 # Number of additional fine samples per ray 32 | perturb: 1 33 | use_viewdirs: true 34 | i_embed: 0 # 'set 0 for default positional encoding, -1 for none' 35 | multires: 10 # log2 of max freq for positional encoding (3D location)' 36 | multires_views: 4 # 'log2 of max freq for positional encoding (2D direction)' 37 | raw_noise_std: 1 # 'std dev of noise added to regularize sigma_a output, 1e0 recommended') 38 | test_viz_factor: 1 # down scaling factor when rendering test and training images 39 | no_batching: True # True-sample random pixels from random images; False-sample from all random pixels from all images 40 | depth_range: [0.1, 10.0] 41 | white_bkgd: false # set to render synthetic data on a white bkgd (always use for dvoxels) 42 | 43 | train: 44 | lrate: 5e-4 45 | lrate_decay: 250e3 46 | N_iters: 200000 47 | wgt_sem: 4e-2 48 | 49 | 50 | 51 | logging: # logging/saving options 52 | step_log_print: 1 # 'frequency of console print' 53 | step_log_tfb: 500 54 | step_save_ckpt: 20000 55 | step_val: 5000 # frequency of rendering on unseen data 56 | step_vis_train: 5000 -------------------------------------------------------------------------------- /SSR/datasets/scannet/scannet_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | import csv 5 | 6 | def load_scannet_label_mapping(path): 7 | """ Returns a dict mapping scannet category label strings to scannet Ids 8 | 9 | scene****_**.aggregation.json contains the category labels as strings 10 | so this maps the strings to the integer scannet Id 11 | 12 | Args: 13 | path: Path to the original scannet data. 14 | This is used to get scannetv2-labels.combined.tsv 15 | 16 | Returns: 17 | mapping: A dict from strings to ints 18 | example: 19 | {'wall': 1, 20 | 'chair: 2, 21 | 'books': 22} 22 | 23 | """ 24 | 25 | mapping = {} 26 | with open(os.path.join(path, 'scannetv2-labels.combined.tsv')) as tsvfile: 27 | tsvreader = csv.reader(tsvfile, delimiter='\t') 28 | for i, line in enumerate(tsvreader): 29 | if i==0: 30 | continue 31 | scannet_id, name = int(line[0]), line[1] 32 | mapping[name] = scannet_id 33 | 34 | return mapping 35 | 36 | 37 | def load_scannet_nyu40_mapping(path): 38 | """ Returns a dict mapping scannet Ids to NYU40 Ids 39 | 40 | Args: 41 | path: Path to the original scannet data. 42 | This is used to get scannetv2-labels.combined.tsv 43 | 44 | Returns: 45 | mapping: A dict from ints to ints 46 | example: 47 | {1: 1, 48 | 2: 5, 49 | 22: 23} 50 | 51 | """ 52 | 53 | mapping = {} 54 | with open(os.path.join(path, 'scannetv2-labels.combined.tsv')) as tsvfile: 55 | tsvreader = csv.reader(tsvfile, delimiter='\t') 56 | for i, line in enumerate(tsvreader): 57 | if i==0: 58 | continue 59 | scannet_id, nyu40id = int(line[0]), int(line[4]) 60 | mapping[scannet_id] = nyu40id 61 | return mapping 62 | 63 | 64 | def load_scannet_nyu13_mapping(path): 65 | """ Returns a dict mapping scannet Ids to NYU40 Ids 66 | 67 | Args: 68 | path: Path to the original scannet data. 
69 | This is used to get scannetv2-labels.combined.tsv 70 | 71 | Returns: 72 | mapping: A dict from ints to ints 73 | example: 74 | {1: 1, 75 | 2: 5, 76 | 22: 23} 77 | 78 | """ 79 | 80 | mapping = {} 81 | with open(os.path.join(path, 'scannetv2-labels.combined.tsv')) as tsvfile: 82 | tsvreader = csv.reader(tsvfile, delimiter='\t') 83 | for i, line in enumerate(tsvreader): 84 | if i==0: 85 | continue 86 | scannet_id, nyu40id = int(line[0]), int(line[5]) 87 | mapping[scannet_id] = nyu40id 88 | return mapping -------------------------------------------------------------------------------- /SSR/geometry/occupancy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def grid_within_bound(occ_range, extents, transform, grid_dim): 6 | range_dist = occ_range[1] - occ_range[0] 7 | bounds_tranform_np = transform 8 | 9 | bounds_tranform = torch.from_numpy(bounds_tranform_np).float() 10 | scene_scale_np = extents / (range_dist * 0.9) 11 | scene_scale = torch.from_numpy(scene_scale_np).float() 12 | 13 | # todo: only make grid once, then only transform! 14 | grid_pc = make_3D_grid( 15 | occ_range, 16 | grid_dim, 17 | transform=bounds_tranform, 18 | scale=scene_scale, 19 | ) 20 | grid_pc = grid_pc.view(-1, 1, 3) 21 | 22 | return grid_pc, scene_scale 23 | 24 | def make_3D_grid(occ_range, dim, transform=None, scale=None): 25 | t = torch.linspace(occ_range[0], occ_range[1], steps=dim) 26 | grid = torch.meshgrid(t, t, t) 27 | grid_3d_norm = torch.cat( 28 | (grid[0][..., None], 29 | grid[1][..., None], 30 | grid[2][..., None]), dim=3 31 | ) 32 | 33 | if scale is not None: 34 | grid_3d = grid_3d_norm * scale 35 | if transform is not None: 36 | R1 = transform[None, None, None, 0, :3] 37 | R2 = transform[None, None, None, 1, :3] 38 | R3 = transform[None, None, None, 2, :3] 39 | 40 | grid1 = (R1 * grid_3d).sum(-1, keepdim=True) 41 | grid2 = (R2 * grid_3d).sum(-1, keepdim=True) 42 | grid3 = (R3 * grid_3d).sum(-1, keepdim=True) 43 | grid_3d = torch.cat([grid1, grid2, grid3], dim=-1) 44 | 45 | trans = transform[None, None, None, :3, 3] 46 | grid_3d = grid_3d + trans 47 | 48 | return grid_3d 49 | 50 | def make_3D_grid_np(occ_range, dim, device, transform=None, scale=None): 51 | t = torch.linspace(occ_range[0], occ_range[1], steps=dim, device=device) 52 | t = np.linspace(occ_range[0], occ_range[1], num=dim) 53 | grid = np.meshgrid(t, t, t) # list of 3 elements of shape [dim, dim, dim] 54 | 55 | grid_3d_norm = np.concatenate( 56 | (grid[0][..., None], 57 | grid[1][..., None], 58 | grid[2][..., None]), axis=3 59 | ) # shape of [dim, dim, dim, 3] 60 | 61 | if scale is not None: 62 | grid_3d = grid_3d_norm * scale 63 | if transform is not None: 64 | R1 = transform[None, None, None, 0, :3] 65 | R2 = transform[None, None, None, 1, :3] 66 | R3 = transform[None, None, None, 2, :3] 67 | 68 | grid1 = (R1 * grid_3d).sum(-1, keepdim=True) 69 | grid2 = (R2 * grid_3d).sum(-1, keepdim=True) 70 | grid3 = (R3 * grid_3d).sum(-1, keepdim=True) 71 | grid_3d = np.concatenate([grid1, grid2, grid3], dim=-1) 72 | 73 | trans = transform[None, None, None, :3, 3] 74 | grid_3d = grid_3d + trans 75 | 76 | return grid_3d 77 | 78 | 79 | 80 | def chunk_alphas(pc, chunk_size, fc_occ_map, n_embed_funcs, B_layer,): 81 | n_pts = pc.shape[0] 82 | n_chunks = int(np.ceil(n_pts / chunk_size)) 83 | alphas = [] 84 | for n in range(n_chunks): 85 | start = n * chunk_size 86 | end = start + chunk_size 87 | chunk = pc[start:end, :] 88 | points_embedding = 
embedding.positional_encoding( 89 | chunk, B_layer, num_encoding_functions=n_embed_funcs 90 | ) 91 | alpha = fc_occ_map(points_embedding, full=True).squeeze(dim=-1) 92 | alphas.append(alpha) 93 | alphas = torch.cat(alphas, dim=-1) 94 | 95 | return alphas 96 | -------------------------------------------------------------------------------- /SSR/models/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from SSR.training.training_utils import batchify 4 | 5 | 6 | def run_network_compund(inputs, fn, embed_fn, netchunk=1024 * 64): 7 | """Prepares inputs and applies network 'fn'. 8 | 9 | Input: [N_rays, N_samples, 3] 10 | """ 11 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]]) 12 | compund_fn = lambda x: fn(embed_fn(x)) 13 | 14 | outputs_flat = batchify(compund_fn, netchunk)(inputs_flat) 15 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]]) 16 | return outputs 17 | 18 | 19 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024 * 64): 20 | """Prepares inputs and applies network 'fn'. 21 | 22 | Input: [N_rays, N_samples, 3] 23 | """ 24 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]]) 25 | embedded = embed_fn(inputs_flat) 26 | 27 | if viewdirs is not None: 28 | input_dirs = viewdirs[:, None].expand(inputs.shape) 29 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]]) 30 | embedded_dirs = embeddirs_fn(input_dirs_flat) 31 | embedded = torch.cat([embedded, embedded_dirs], -1) 32 | 33 | outputs_flat = batchify(fn, netchunk)(embedded) 34 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]]) 35 | return outputs 36 | 37 | 38 | 39 | def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, enable_semantic=True, 40 | num_sem_class=0, endpoint_feat=False): 41 | """Transforms model's predictions to semantically meaningful values. 42 | Args: 43 | raw: [num_rays, num_samples along ray, 4]. Prediction from model. 44 | z_vals: [num_rays, num_samples along ray]. Integration time. 45 | rays_d: [num_rays, 3]. Direction of each ray. 46 | raw_noise_std: random perturbations added to ray samples 47 | 48 | Returns: 49 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. 50 | disp_map: [num_rays]. Disparity map. Inverse of depth map. 51 | acc_map: [num_rays]. Sum of weights along each ray. 52 | weights: [num_rays, num_samples]. Weights assigned to each sampled color. 53 | depth_map: [num_rays]. Estimated distance to object. 54 | """ 55 | raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists) 56 | 57 | dists = z_vals[..., 1:] - z_vals[..., :-1] # # (N_rays, N_samples_-1) 58 | dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[..., :1].shape).cuda()], -1) # [N_rays, N_samples] 59 | 60 | # Multiply each distance by the norm of its corresponding direction ray 61 | # to convert to real world distance (accounts for non-unit directions). 62 | dists = dists * torch.norm(rays_d[..., None, :], dim=-1) 63 | 64 | rgb = torch.sigmoid(raw[..., :3]) # [N_rays, N_samples, 3] 65 | 66 | if raw_noise_std > 0.: 67 | noise = torch.randn(raw[..., 3].shape) * raw_noise_std 68 | noise = noise.cuda() 69 | else: 70 | noise = 0. 
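    # Standard NeRF volume rendering: alpha_i = 1 - exp(-relu(sigma_i) * delta_i) and
    # w_i = alpha_i * prod_{j<i}(1 - alpha_j), i.e. alpha times the accumulated transmittance.
    # The same weights are reused below to composite colour, depth and, when enabled,
    # the semantic logits along each ray.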
71 | 72 | alpha = raw2alpha(raw[..., 3] + noise, dists) # [N_rays, N_samples] 73 | 74 | 75 | # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True) 76 | weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)).cuda(), 1.-alpha + 1e-10], -1), -1)[:, :-1] 77 | # [1, 1-a1, 1-a2, ...] 78 | # [N_rays, N_samples+1] sliced by [:, :-1] to [N_rays, N_samples] 79 | 80 | rgb_map = torch.sum(weights[..., None] * rgb, -2) # [N_rays, 3] 81 | # [N_rays, 3], the accumulated opacity along the rays, equals "1 - (1-a1)(1-a2)...(1-an)" mathematically 82 | 83 | if enable_semantic: 84 | assert num_sem_class>0 85 | # https://discuss.pytorch.org/t/multi-class-cross-entropy-loss-and-softmax-in-pytorch/24920/2 86 | sem_logits = raw[..., 4:4+num_sem_class] # [N_rays, N_samples, num_class] 87 | sem_map = torch.sum(weights[..., None] * sem_logits, -2) # [N_rays, num_class] 88 | else: 89 | sem_map = torch.tensor(0) 90 | 91 | 92 | if endpoint_feat: 93 | feat = raw[..., -128:] # [N_rays, N_samples, feat_dim] take the last 128 dim from predictions 94 | feat_map = torch.sum(weights[..., None] * feat, -2) # [N_rays, feat_dim] 95 | else: 96 | feat_map = torch.tensor(0) 97 | 98 | depth_map = torch.sum(weights * z_vals, -1) # (N_rays,) 99 | disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1)) 100 | acc_map = torch.sum(weights, -1) 101 | 102 | if white_bkgd: 103 | rgb_map = rgb_map + (1.-acc_map[..., None]) 104 | if enable_semantic: 105 | sem_map = sem_map + (1.-acc_map[..., None]) 106 | 107 | return rgb_map, disp_map, acc_map, weights, depth_map, sem_map, feat_map 108 | 109 | -------------------------------------------------------------------------------- /SSR/training/training_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn.metrics import confusion_matrix 4 | 5 | def batchify_rays(render_fn, rays_flat, chunk=1024 * 32): 6 | """Render rays in smaller minibatches to avoid OOM. 7 | """ 8 | all_ret = {} 9 | for i in range(0, rays_flat.shape[0], chunk): 10 | ret = render_fn(rays_flat[i:i + chunk]) 11 | for k in ret: 12 | if k not in all_ret: 13 | all_ret[k] = [] 14 | all_ret[k].append(ret[k]) 15 | 16 | all_ret = {k: torch.cat(all_ret[k], 0) for k in all_ret} 17 | return all_ret 18 | 19 | 20 | def batchify(fn, chunk): 21 | """Constructs a version of 'fn' that applies to smaller batches. 22 | """ 23 | if chunk is None: 24 | return fn 25 | 26 | def ret(inputs): 27 | return torch.cat([fn(inputs[i:i + chunk]) for i in range(0, inputs.shape[0], chunk)], 0) 28 | 29 | return ret 30 | 31 | 32 | def lr_poly_decay(base_lr, iter, max_iter, power): 33 | """ Polynomial learning rate decay 34 | Polynomial Decay provides a smoother decay using a polynomial function and reaches a learning rate of 0 35 | after max_update iterations. 36 | https://kiranscaria.github.io/general/2019/08/16/learning-rate-schedules.html 37 | 38 | max_iter: number of iterations to perform before the learning rate is taken to . 39 | power: the degree of the polynomial function. Smaller values of power produce slower decay and 40 | large values of learning rate for longer periods. 
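    Formula: lr = base_lr * (1 - iter / max_iter) ** power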
41 | """ 42 | return base_lr * ((1 - float(iter) / max_iter) ** (power)) 43 | 44 | 45 | def lr_exp_decay(base_lr, exp_base_lr, current_step, decay_steps): 46 | """ lr = lr0 * decay_base^(−kt) 47 | """ 48 | new_lrate = base_lr * (exp_base_lr ** (current_step / decay_steps)) 49 | return new_lrate 50 | 51 | 52 | def nanmean(data, **args): 53 | # This makes it ignore the first 'background' class 54 | return np.ma.masked_array(data, np.isnan(data)).mean(**args) 55 | # In np.ma.masked_array(data, np.isnan(data), elements of data == np.nan is invalid and will be ingorned during computation of np.mean() 56 | 57 | 58 | def calculate_segmentation_metrics(true_labels, predicted_labels, number_classes, ignore_label): 59 | if (true_labels == ignore_label).all(): 60 | return [0]*4 61 | 62 | true_labels = true_labels.flatten() 63 | predicted_labels = predicted_labels.flatten() 64 | valid_pix_ids = true_labels!=ignore_label 65 | predicted_labels = predicted_labels[valid_pix_ids] 66 | true_labels = true_labels[valid_pix_ids] 67 | 68 | conf_mat = confusion_matrix(true_labels, predicted_labels, labels=list(range(number_classes))) 69 | norm_conf_mat = np.transpose( 70 | np.transpose(conf_mat) / conf_mat.astype(np.float).sum(axis=1)) 71 | 72 | missing_class_mask = np.isnan(norm_conf_mat.sum(1)) # missing class will have NaN at corresponding class 73 | exsiting_class_mask = ~ missing_class_mask 74 | 75 | class_average_accuracy = nanmean(np.diagonal(norm_conf_mat)) 76 | total_accuracy = (np.sum(np.diagonal(conf_mat)) / np.sum(conf_mat)) 77 | ious = np.zeros(number_classes) 78 | for class_id in range(number_classes): 79 | ious[class_id] = (conf_mat[class_id, class_id] / ( 80 | np.sum(conf_mat[class_id, :]) + np.sum(conf_mat[:, class_id]) - 81 | conf_mat[class_id, class_id])) 82 | miou = nanmean(ious) 83 | miou_valid_class = np.mean(ious[exsiting_class_mask]) 84 | return miou, miou_valid_class, total_accuracy, class_average_accuracy, ious 85 | 86 | 87 | def calculate_depth_metrics(depth_trgt, depth_pred): 88 | """ Computes 2d metrics between two depth maps 89 | 90 | Args: 91 | depth_pred: mxn np.array containing prediction 92 | depth_trgt: mxn np.array containing ground truth 93 | Returns: 94 | Dict of metrics 95 | """ 96 | mask1 = depth_pred>0 # ignore values where prediction is 0 (% complete) 97 | mask = (depth_trgt<10) * (depth_trgt>0) * mask1 98 | 99 | depth_pred = depth_pred[mask] 100 | depth_trgt = depth_trgt[mask] 101 | abs_diff = np.abs(depth_pred-depth_trgt) 102 | abs_rel = abs_diff/depth_trgt 103 | sq_diff = abs_diff**2 104 | sq_rel = sq_diff/depth_trgt 105 | sq_log_diff = (np.log(depth_pred)-np.log(depth_trgt))**2 106 | thresh = np.maximum((depth_trgt / depth_pred), (depth_pred / depth_trgt)) 107 | r1 = (thresh < 1.25).astype('float') 108 | r2 = (thresh < 1.25**2).astype('float') 109 | r3 = (thresh < 1.25**3).astype('float') 110 | 111 | metrics = {} 112 | metrics['AbsRel'] = np.mean(abs_rel) 113 | metrics['AbsDiff'] = np.mean(abs_diff) 114 | metrics['SqRel'] = np.mean(sq_rel) 115 | metrics['RMSE'] = np.sqrt(np.mean(sq_diff)) 116 | metrics['LogRMSE'] = np.sqrt(np.mean(sq_log_diff)) 117 | metrics['r1'] = np.mean(r1) 118 | metrics['r2'] = np.mean(r2) 119 | metrics['r3'] = np.mean(r3) 120 | metrics['complete'] = np.mean(mask1.astype('float')) 121 | 122 | return metrics -------------------------------------------------------------------------------- /SSR/models/semantic_nerf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
torch.autograd.set_detect_anomaly(True) 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | # Misc 8 | img2mse = lambda x, y: torch.mean((x - y) ** 2) 9 | mse2psnr = lambda x: -10. * torch.log(x) / torch.log(torch.Tensor([10.])) 10 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) 11 | 12 | 13 | # Positional encoding (section 5.1) 14 | class Embedder: 15 | def __init__(self, **kwargs): 16 | self.kwargs = kwargs 17 | self.create_embedding_fn() 18 | 19 | def create_embedding_fn(self): 20 | """ 21 | Embeds x to (x, sin(2^k x), cos(2^k x), ...) 22 | """ 23 | embed_fns = [] 24 | d = self.kwargs['input_dims'] 25 | out_dim = 0 26 | if self.kwargs['include_input']: # original raw input "x" is also included in the output 27 | embed_fns.append(lambda x: x) 28 | out_dim += d 29 | 30 | max_freq = self.kwargs['max_freq_log2'] 31 | N_freqs = self.kwargs['num_freqs'] 32 | 33 | if self.kwargs['log_sampling']: 34 | freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs) 35 | else: 36 | freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, steps=N_freqs) 37 | 38 | for freq in freq_bands: 39 | for p_fn in self.kwargs['periodic_fns']: 40 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 41 | out_dim += d 42 | 43 | self.embed_fns = embed_fns 44 | self.out_dim = out_dim 45 | 46 | def embed(self, inputs): 47 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 48 | 49 | 50 | def get_embedder(multires, i=0, scalar_factor=1): 51 | if i == -1: 52 | return nn.Identity(), 3 53 | 54 | embed_kwargs = { 55 | 'include_input': True, 56 | 'input_dims': 3, 57 | 'max_freq_log2': multires - 1, 58 | 'num_freqs': multires, 59 | 'log_sampling': True, 60 | 'periodic_fns': [torch.sin, torch.cos], 61 | } 62 | 63 | embedder_obj = Embedder(**embed_kwargs) 64 | embed = lambda x, eo=embedder_obj: eo.embed(x/scalar_factor) 65 | return embed, embedder_obj.out_dim 66 | 67 | 68 | def fc_block(in_f, out_f): 69 | return torch.nn.Sequential( 70 | torch.nn.Linear(in_f, out_f), 71 | torch.nn.ReLU(out_f) 72 | ) 73 | 74 | class Semantic_NeRF(nn.Module): 75 | """ 76 | Compared to the NeRF class wich also predicts semantic logits from MLPs, here we make the semantic label only a function of 3D position 77 | instead of both positon and viewing directions. 
78 | """ 79 | def __init__(self, enable_semantic, num_semantic_classes, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False, 80 | ): 81 | super(Semantic_NeRF, self).__init__() 82 | """ 83 | D: number of layers for density (sigma) encoder 84 | W: number of hidden units in each layer 85 | input_ch: number of input channels for xyz (3+3*10*2=63 by default) 86 | in_channels_dir: number of input channels for direction (3+3*4*2=27 by default) 87 | skips: layer index to add skip connection in the Dth layer 88 | """ 89 | self.D = D 90 | self.W = W 91 | self.input_ch = input_ch 92 | self.input_ch_views = input_ch_views 93 | self.skips = skips 94 | self.use_viewdirs = use_viewdirs 95 | self.enable_semantic = enable_semantic 96 | 97 | # build the encoder 98 | self.pts_linears = nn.ModuleList( 99 | [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in 100 | range(D - 1)]) 101 | 102 | ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105) 103 | 104 | # Another layer is used to 105 | self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W // 2)]) 106 | if use_viewdirs: 107 | self.feature_linear = nn.Linear(W, W) 108 | self.alpha_linear = nn.Linear(W, 1) 109 | if enable_semantic: 110 | self.semantic_linear = nn.Sequential(fc_block(W, W // 2), nn.Linear(W // 2, num_semantic_classes)) 111 | self.rgb_linear = nn.Linear(W // 2, 3) 112 | else: 113 | self.output_linear = nn.Linear(W, output_ch) 114 | 115 | def forward(self, x, show_endpoint=False): 116 | """ 117 | Encodes input (xyz+dir) to rgb+sigma+semantics raw output 118 | Inputs: 119 | x: (B, self.in_channels_xyz(+self.in_channels_dir)) 120 | the embedded vector of 3D xyz position and viewing direction 121 | """ 122 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1) 123 | h = input_pts 124 | for i, l in enumerate(self.pts_linears): 125 | h = self.pts_linears[i](h) 126 | h = F.relu(h) 127 | if i in self.skips: 128 | h = torch.cat([input_pts, h], -1) 129 | 130 | if self.use_viewdirs: 131 | # if using view-dirs, output occupancy alpha as well as features for concatenation 132 | alpha = self.alpha_linear(h) 133 | if self.enable_semantic: 134 | sem_logits = self.semantic_linear(h) 135 | feature = self.feature_linear(h) 136 | 137 | h = torch.cat([feature, input_views], -1) 138 | 139 | for i, l in enumerate(self.views_linears): 140 | h = self.views_linears[i](h) 141 | h = F.relu(h) 142 | 143 | if show_endpoint: 144 | endpoint_feat = h 145 | rgb = self.rgb_linear(h) 146 | 147 | if self.enable_semantic: 148 | outputs = torch.cat([rgb, alpha, sem_logits], -1) 149 | else: 150 | outputs = torch.cat([rgb, alpha], -1) 151 | else: 152 | outputs = self.output_linear(h) 153 | 154 | if show_endpoint is False: 155 | return outputs 156 | else: 157 | return torch.cat([outputs, endpoint_feat], -1) 158 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Semantic-NeRF: Semantic Neural Radiance Fields 2 | 3 | ### [Project Page](https://shuaifengzhi.com/Semantic-NeRF/) | [Video](https://youtu.be/FpShWO7LVbM) | [Paper](https://arxiv.org/abs/2103.15875) | [Data(DropBox)](https://www.dropbox.com/sh/9yu1elddll00sdl/AAC-rSJdLX0C6HhKXGKMOIija?dl=0)| [Data(BaiduYun)[code:nerf]](https://pan.baidu.com/s/1UmABiPQKm_S5Elq_ffXzPA) 4 | 5 | 6 | [In-Place 
Scene Labelling and Understanding with Implicit Scene Representation](https://shuaifengzhi.com/Semantic-NeRF/) 7 | [Shuaifeng Zhi](https://shuaifengzhi.com/), 8 | [Tristan Laidlow](https://wp.doc.ic.ac.uk/twl15/), 9 | [Stefan Leutenegger](https://wp.doc.ic.ac.uk/sleutene/), 10 | [Andrew J. Davison](https://www.doc.ic.ac.uk/~ajd/), 11 |
12 | Dyson Robotics Laboratory at Imperial College \ 13 | Published in ICCV 2021 (Oral Presentation) 14 | 15 | 16 | 17 | We build upon neural radiance fields to create a scene-specific implicit 3D semantic representation, Semantic-NeRF. 18 | 19 | 20 | ## Latest Updates. 21 | - **Release of Replica Data Generation Codes.** We have provided data generation scripts for Replica sequences at [SSR/data_generation](https://github.com/Harry-Zhi/semantic_nerf/tree/main/SSR/data_generation) folder. Thanks Xin of [vMAP](https://github.com/kxhit/vMAP) for cleaning up. 22 | - **Instance Label Maps Available.** We have also provided corresponding instance label maps of pre-rendered Replica sequences in [dropbox](https://www.dropbox.com/home/Public_Hosting/Semantic_NeRF(ICCV2021)/Replica_Dataset) as a zip file *Replica_Instance_Segmentation.zip*. 23 | 24 | ## Getting Started 25 | 26 | For flawless reproduction of our results, the Ubuntu OS 20.04 is recommended. The models have been tested using Python 3.7, Pytorch 1.6.0, CUDA10.1. Higher versions should also perform similarly. 27 | 28 | ### Dependencies 29 | Main python dependencies are listed below: 30 | - Python >=3.7 31 | - torch>=1.6.0 (integrate *searchsorted* API, otherwise need to use the third party implementation [SearchSorted](https://github.com/aliutkus/torchsearchsorted) ) 32 | - cudatoolkit>=10.1 33 | 34 | Following packages are used for 3D mesh reconstruction: 35 | - trimesh==3.9.9 36 | - open3d==0.12.0 37 | 38 | With Anaconda, you can simply create a virtual environment and install dependencies with CONDA by: 39 | - `conda create -n semantic_nerf python=3.7` 40 | - `conda activate semantic_nerf` 41 | - `pip install -r requirements.txt` 42 | 43 | ## Datasets 44 | We mainly use [Replica](https://github.com/facebookresearch/Replica-Dataset) and [ScanNet](http://www.scan-net.org/) datasets for experiments, where we train a new Semantic-NeRF model on each 3D scene. Other similar indoor datasets with colour images, semantic labels and poses can also be used. 45 | 46 | ### We also provide [pre-rendered Replica data](https://www.dropbox.com/sh/9yu1elddll00sdl/AAC-rSJdLX0C6HhKXGKMOIija?dl=0) that can be directly used by Semantic-NeRF. 47 | 48 | 49 | ## Running code 50 | After cloning the codes, we can start to run Semantic-NeRF in the root directory of the repository. 51 | 52 | #### Semantic-NeRF training 53 | For standard Semantic-NeRF training with full dense semantic supervision. You can simply run following command with a chosen config file specifying data directory and hyper-params. 54 | ``` 55 | python3 train_SSR_main.py --config_file /SSR/configs/SSR_room0_config.yaml 56 | ``` 57 | 58 | Different working modes and set-ups can be chosen via commands: 59 | #### Semantic View Synthesis with Sparse Labels: 60 | ``` 61 | python3 train_SSR_main.py --sparse_views --sparse_ratio 0.6 62 | ``` 63 | Sparse ratio here is the portion of **dropped** frames in the training sequence. 
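For intuition, here is a minimal sketch (not the repository's actual selection code) of how a sparse ratio of 0.6 could be turned into kept/dropped frame indices, assuming frames are dropped uniformly at random; the `split_frames` helper below is purely illustrative.
```
import numpy as np

def split_frames(num_frames, sparse_ratio, seed=0):
    # Illustrative helper only (not part of this repo): sparse_ratio is the portion
    # of training frames whose labels are dropped; the rest keep full supervision.
    rng = np.random.default_rng(seed)
    num_dropped = int(num_frames * sparse_ratio)
    dropped = rng.choice(num_frames, size=num_dropped, replace=False)
    kept = np.setdiff1d(np.arange(num_frames), dropped)
    return kept, dropped

kept, dropped = split_frames(num_frames=900, sparse_ratio=0.6)
print(len(kept), len(dropped))  # 360 540
```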
64 | 65 | #### Pixel-wise Denoising Task: 66 | ``` 67 | python3 train_SSR_main.py --pixel_denoising --pixel_noise_ratio 0.5 68 | ``` 69 | 70 | We could also use a sparse set of frames along with denoising task: 71 | ``` 72 | python3 train_SSR_main.py --pixel_denoising --pixel_noise_ratio 0.5 --sparse_views --sparse_ratio 0.6 73 | ``` 74 | 75 | #### Region-wise Denoising task (For Replica Room2): 76 | ``` 77 | python3 train_SSR_main.py --region_denoising --region_noise_ratio 0.3 78 | ``` 79 | The argument **uniform_flip** corresponds to the two modes of "Even/Sort"in region-wise denoising task. 80 | 81 | #### Super-Resolution Task: 82 | For super-resolution with **dense** labels, please run 83 | ``` 84 | python3 train_SSR_main.py --super_resolution --sr_factor 8 --dense_sr 85 | ``` 86 | 87 | For super-resolution with **sparse** labels, please run 88 | ``` 89 | python3 train_SSR_main.py --super_resolution --sr_factor 8 90 | ``` 91 | 92 | #### Label Propagation Task: 93 | For label propagation task with single-click seed regions, please run 94 | ``` 95 | python3 train_SSR_main.py --label_propagation --partial_perc 0 96 | ``` 97 | 98 | In order to improve reproducibility, for denoising and label-propagation tasks, we can also include `--visualise_save` and `--load_saved` to save/load randomly generated labels. 99 | 100 | 101 | #### 3D Reconstruction of Replica Scenes 102 | We also provide codes for extracting 3D semantic mesh from a trained Seamntic-NeRF model. 103 | 104 | ``` 105 | python3 SSR/extract_colour_mesh.py --sem --mesh_dir PATH_TO_MESH --mesh_dir PATH_TO_MESH --training_data_dir PATH_TO_TRAINING_DATA --save_dir PATH_TO_SAVE_DIR 106 | ``` 107 | 108 | 109 | 110 | ### For more demos and qualitative results, please check our [project page](https://shuaifengzhi.com/Semantic-NeRF/) and [video](https://youtu.be/FpShWO7LVbM). 111 | 112 | 113 | ## Acknowledgements 114 | Thanks [nerf](https://github.com/bmild/nerf), [nerf-pytorch](https://github.com/yenchenlin/nerf-pytorch) and [nerf_pl](https://github.com/kwea123/nerf_pl) for providing nice and inspiring implementations of NeRF. Thank [Atlas](https://github.com/magicleap/Atlas) for scripts in processing ScanNet dataset. 115 | 116 | ## Citation 117 | If you found this code/work to be useful in your own research, please consider citing the following: 118 | ``` 119 | @inproceedings{Zhi:etal:ICCV2021, 120 | title={In-Place Scene Labelling and Understanding with Implicit Scene Representation}, 121 | author={Shuaifeng Zhi and Tristan Laidlow and Stefan Leutenegger and Andrew J. Davison}, 122 | booktitle=ICCV, 123 | year={2021} 124 | } 125 | ``` 126 | 127 | ## Contact 128 | If you have any questions, please contact s.zhi17@imperial.ac.uk or zhishuaifeng@outlook.com. 
129 | 130 | -------------------------------------------------------------------------------- /SSR/visualisation/open3d_utils.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | import numpy as np 3 | 4 | 5 | def draw_segment(t1, t2, color=(1., 1., 0.)): 6 | points = [t1, t2] 7 | 8 | lines = [[0, 1]] 9 | 10 | colors = [color for i in range(len(lines))] 11 | line_set = o3d.geometry.LineSet( 12 | points=o3d.utility.Vector3dVector(points), 13 | lines=o3d.utility.Vector2iVector(lines), 14 | ) 15 | line_set.colors = o3d.utility.Vector3dVector(colors) 16 | 17 | return line_set # line-segment 18 | 19 | 20 | def draw_trajectory(scene, transform_wc, color=(1., 1., 0.), name="trajectory"): 21 | for i in range(trajectory.shape[0] - 1): 22 | t1 = transform_wc[i, :3, 3] 23 | t2 = transform_wc[i+1, :3, 3] 24 | segment = draw_segment(t1, t2, color) 25 | scene.scene.add_geometry("{}_{}".format(name, i), segment, material) 26 | scene.force_redraw() 27 | 28 | def draw_camera_frustrums(scene, material, intrinsics, transform_wc, scale=1.0, color=(1, 0, 0), name="camera"): 29 | for i in range(len(transform_wc)): 30 | camera_frustum = gen_camera_frustrum(intrinsics, transform_wc[i]) 31 | scene.scene.add_geometry("{}_{}".format(name, i), camera_frustum, material) 32 | scene.force_redraw() 33 | 34 | 35 | def gen_camera_frustrum(intrinsics, transform_wc, scale=1.0, color=(1, 0, 0)): 36 | """ 37 | intrinsics: camera intrinsic matrix 38 | scale: the depth of the frustum front plane 39 | color: frustum line colours 40 | """ 41 | print("Draw camera frustum using o3d.geometry.LineSet.") 42 | w = intrinsics['cx'] * 2 43 | h = intrinsics['cy'] * 2 44 | xl = scale * -intrinsics['cx'] / intrinsics['fx'] # 3D coordinate of minimum x 45 | xh = scale * (w - intrinsics['cx']) / intrinsics['fx'] # 3D coordinate of maximum x 46 | yl = scale * -intrinsics['cy'] / intrinsics['fy'] # 3D coordinate of minimum y 47 | yh = scale * (h - intrinsics['cy']) / intrinsics['fy'] # 3D coordinate of maximum y 48 | verts = [ 49 | 0, 0, 0, # 0 - camera center 50 | xl, yl, scale, # 1 - upper left 51 | xh, yl, scale, # 2 - upper right 52 | xh, yh, scale, # 3 - bottom right 53 | xl, yh, scale, # 4 - bottom leff 54 | ] 55 | 56 | lines = [ 57 | [0, 1], 58 | [0, 2], 59 | [0, 3], 60 | [0, 4], 61 | [1, 2], 62 | [1, 4], 63 | [3, 2], 64 | [3, 4], 65 | ] 66 | 67 | colors = [color for i in range(len(lines))] 68 | line_set = o3d.geometry.LineSet( 69 | points=o3d.utility.Vector3dVector(points), 70 | lines=o3d.utility.Vector2iVector(lines), 71 | ) 72 | line_set.colors = o3d.utility.Vector3dVector(colors) 73 | 74 | line_set = line_set.transform(transform_wc) 75 | return line_set # camera frustun 76 | 77 | 78 | 79 | def integrate_rgbd_tsdf(tsdf_volume, rgb, dep, depth_trunc, T_wc, intrinsic): 80 | for i in range(0, len(T_wc)): 81 | print("Integrate {:d}-th image into the volume.".format(i)) 82 | color = o3d.geometry.Image(rgb[i]) 83 | depth = o3d.geometry.Image(dep[i]) 84 | rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( 85 | color, 86 | depth, 87 | depth_trunc=depth_trunc, 88 | depth_scale=1, 89 | convert_rgb_to_intensity=False, 90 | ) 91 | 92 | T_cw = np.linalg.inv(T_wc[i]) 93 | 94 | tsdf_volume.integrate( 95 | image=rgbd, 96 | intrinsic=intrinsic, 97 | extrinsic=T_cw, 98 | ) 99 | return tsdf_volume 100 | 101 | def tsdf2mesh(tsdf): 102 | mesh = tsdf.extract_triangle_mesh() 103 | mesh.compute_vertex_normals() 104 | return mesh 105 | 106 | 107 | 108 | def integrate_dep_pcd(dep, T_wc, 
intrinsic): 109 | # http://www.open3d.org/docs/latest/tutorial/Advanced/multiway_registration.html#Make-a-combined-point-cloud 110 | 111 | pcd_list = [] 112 | pcd_combined = o3d.geometry.PointCloud() 113 | for i in range(0, len(T_wc)): 114 | depth = o3d.geometry.Image(dep[i]) 115 | pcd = o3d.geometry.PointCloud.create_from_depth_image( 116 | depth_map, 117 | intrinsic, 118 | depth_scale=1, 119 | stride=1, 120 | project_valid_depth_only=True) 121 | pcd.transform(T_WC[i]) 122 | pcd_combined+= pcd 123 | 124 | # pcd_combined = pcd_combined.voxel_down_sample(voxel_size=0.02) 125 | print("Merge point clouds from multiple views.") 126 | # Flip it, otherwise the pointcloud will be upside down 127 | pcd_combined.transform([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 128 | o3d.visualization.draw_geometries([pcd_combined], 129 | zoom=0.3412, 130 | front=[0.4257, -0.2125, -0.8795], 131 | lookat=[2.6172, 2.0475, 1.532], 132 | up=[-0.0694, -0.9768, 0.2024]) 133 | return pcd_combined 134 | 135 | 136 | def draw_pc(batch_size, 137 | pcs_cam, 138 | T_WC_batch_np, 139 | im_batch=None, 140 | scene=None): 141 | 142 | pcs_w = [] 143 | for batch_i in range(batch_size): 144 | T_WC = T_WC_batch_np[batch_i] 145 | pc_cam = pcs_cam[batch_i] 146 | 147 | col = None 148 | if im_batch is not None: 149 | img = im_batch[batch_i] 150 | col = img.reshape(-1, 3) 151 | 152 | pc_tri = trimesh.PointCloud(vertices=pc_cam, colors=col) 153 | pc_tri.apply_transform(T_WC) 154 | pcs_w.append(pc_tri.vertices) 155 | 156 | if scene is not None: 157 | scene.add_geometry(pc_tri) 158 | 159 | pcs_w = np.concatenate(pcs_w, axis=0) 160 | return pcs_w 161 | 162 | 163 | 164 | def trimesh_to_open3d(src): 165 | dst = o3d.geometry.TriangleMesh() 166 | dst.vertices = o3d.utility.Vector3dVector(src.vertices) 167 | dst.triangles = o3d.utility.Vector3iVector(src.faces) 168 | vertex_colors = src.visual.vertex_colors[:, :3].astype(np.float) / 255.0 169 | dst.vertex_colors = o3d.utility.Vector3dVector(vertex_colors) 170 | dst.compute_vertex_normals() 171 | 172 | return dst 173 | 174 | 175 | def clean_mesh(o3d_mesh, keep_single_cluster=False, min_num_cluster=200): 176 | import copy 177 | 178 | o3d_mesh_clean = copy.deepcopy(o3d_mesh) 179 | # http://www.open3d.org/docs/release/tutorial/geometry/mesh.html?highlight=cluster_connected_triangles 180 | triangle_clusters, cluster_n_triangles, cluster_area = o3d_mesh_clean.cluster_connected_triangles() 181 | 182 | triangle_clusters = np.asarray(triangle_clusters) 183 | cluster_n_triangles = np.asarray(cluster_n_triangles) 184 | cluster_area = np.asarray(cluster_area) 185 | 186 | if keep_single_cluster: 187 | # keep the largest cluster.! 
188 | largest_cluster_idx = np.argmax(cluster_n_triangles) 189 | triangles_to_remove = triangle_clusters != largest_cluster_idx 190 | o3d_mesh_clean.remove_triangles_by_mask(triangles_to_remove) 191 | o3d_mesh_clean.remove_unreferenced_vertices() 192 | print("Show mesh with largest cluster kept") 193 | else: 194 | # remove small clusters 195 | triangles_to_remove = cluster_n_triangles[triangle_clusters] < min_num_cluster 196 | o3d_mesh_clean.remove_triangles_by_mask(triangles_to_remove) 197 | o3d_mesh_clean.remove_unreferenced_vertices() 198 | print("Show mesh with small clusters removed") 199 | 200 | 201 | return o3d_mesh_clean -------------------------------------------------------------------------------- /SSR/datasets/replica_nyu/replica_nyu_cnn_datasets.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import glob 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | import cv2 7 | 8 | class Replica_CNN_NYU(Dataset): 9 | def __init__(self, data_dir, train_ids, test_ids, nyu_mode, img_h=None, img_w=None, load_softmax=False): 10 | 11 | assert nyu_mode == "nyu13" or nyu_mode == "nyu34" or nyu_mode == "gt_nyu13" 12 | 13 | traj_file = os.path.join(data_dir, "traj_w_c.txt") 14 | self.rgb_dir = os.path.join(data_dir, "rgb") 15 | self.depth_dir = os.path.join(data_dir, "depth") # depth is in mm uint 16 | # self.cnn_semantic_class_dir = os.path.join(data_dir, "CNN_semantic_class_{}".format(nyu_mode)) 17 | if nyu_mode == "nyu13": 18 | self.cnn_semantic_class_dir = os.path.join(data_dir, "CNN_semantic_class_nyu13") 19 | self.gt_semantic_class_dir = os.path.join(data_dir, "semantic_class_nyu13_remap") 20 | elif nyu_mode=="nyu34": 21 | self.cnn_semantic_class_dir = os.path.join(data_dir, "CNN_semantic_class_nyu34") 22 | self.gt_semantic_class_dir = os.path.join(data_dir, "semantic_class_nyu40_remap_nyu34") 23 | elif nyu_mode == "gt_nyu13": 24 | self.cnn_semantic_class_dir = os.path.join(data_dir, "semantic_class_nyu13_remap") 25 | self.gt_semantic_class_dir = os.path.join(data_dir, "semantic_class_nyu13_remap") 26 | 27 | # self.cnn_softmax_dir = os.path.join(data_dir, "semantic_prob_CNN") 28 | 29 | 30 | self.nyu_mode = nyu_mode 31 | self.load_softmax = load_softmax 32 | 33 | self.train_ids = train_ids 34 | self.train_num = len(train_ids) 35 | self.test_ids = test_ids 36 | self.test_num = len(test_ids) 37 | 38 | self.img_h = img_h 39 | self.img_w = img_w 40 | 41 | self.Ts_full = np.loadtxt(traj_file, delimiter=" ").reshape(-1, 4, 4) 42 | 43 | self.rgb_list = sorted(glob.glob(self.rgb_dir + '/rgb*.png'), key=lambda file_name: int(file_name.split("_")[-1][:-4])) 44 | self.depth_list = sorted(glob.glob(self.depth_dir + '/depth*.png'), key=lambda file_name: int(file_name.split("_")[-1][:-4])) 45 | self.cnn_semantic_list = sorted(glob.glob(self.cnn_semantic_class_dir + '/semantic_class_*.png'), key=lambda file_name: int(file_name.split("_")[-1][:-4])) 46 | self.gt_semantic_list = sorted(glob.glob(self.gt_semantic_class_dir + '/semantic_class_*.png'), key=lambda file_name: int(file_name.split("_")[-1][:-4])) 47 | 48 | if load_softmax: 49 | self.cnn_softmax_list = sorted(glob.glob(self.cnn_softmax_dir + '/softmax_prob_*.npy'), key=lambda file_name: int(file_name.split("_")[-1][:-4])) 50 | 51 | 52 | 53 | self.train_samples = {'image': [], 'depth': [], 54 | 'cnn_semantic': [], 55 | 'gt_semantic': [], 56 | 'cnn_softmax': [], 57 | 'cnn_entropy':[], 58 | 'T_wc': []} 59 | 60 | self.test_samples = {'image': [], 'depth': [], 61 | 
'cnn_semantic': [], 62 | 'gt_semantic': [], 63 | 'cnn_softmax': [], 64 | 'cnn_entropy':[], 65 | 'T_wc': []} 66 | # training samples 67 | for idx in train_ids: 68 | image = cv2.imread(self.rgb_list[idx])[:,:,::-1] / 255.0 # change from BGR uinit 8 to RGB float 69 | depth = cv2.imread(self.depth_list[idx], cv2.IMREAD_UNCHANGED) / 1000.0 # uint16 mm depth, then turn depth from mm to meter 70 | cnn_semantic = cv2.imread(self.cnn_semantic_list[idx], cv2.IMREAD_UNCHANGED) 71 | gt_semantic = cv2.imread(self.gt_semantic_list[idx], cv2.IMREAD_UNCHANGED) 72 | 73 | 74 | if (self.img_h is not None and self.img_h != image.shape[0]) or \ 75 | (self.img_w is not None and self.img_w != image.shape[1]): 76 | image = cv2.resize(image, (self.img_w, self.img_h), interpolation=cv2.INTER_LINEAR) 77 | depth = cv2.resize(depth, (self.img_w, self.img_h), interpolation=cv2.INTER_LINEAR) 78 | cnn_semantic = cv2.resize(cnn_semantic, (self.img_w, self.img_h), interpolation=cv2.INTER_NEAREST) 79 | gt_semantic = cv2.resize(gt_semantic, (self.img_w, self.img_h), interpolation=cv2.INTER_NEAREST) 80 | T_wc = self.Ts_full[idx] 81 | 82 | self.train_samples["image"].append(image) 83 | self.train_samples["depth"].append(depth) 84 | self.train_samples["cnn_semantic"].append(cnn_semantic) 85 | self.train_samples["gt_semantic"].append(gt_semantic) 86 | self.train_samples["T_wc"].append(T_wc) 87 | 88 | 89 | # test samples 90 | for idx in test_ids: 91 | image = cv2.imread(self.rgb_list[idx])[:,:,::-1] / 255.0 # change from BGR uinit 8 to RGB float 92 | depth = cv2.imread(self.depth_list[idx], cv2.IMREAD_UNCHANGED) / 1000.0 # uint16 mm depth, then turn depth from mm to meter 93 | cnn_semantic = cv2.imread(self.cnn_semantic_list[idx], cv2.IMREAD_UNCHANGED) 94 | gt_semantic = cv2.imread(self.gt_semantic_list[idx], cv2.IMREAD_UNCHANGED) 95 | 96 | 97 | if (self.img_h is not None and self.img_h != image.shape[0]) or \ 98 | (self.img_w is not None and self.img_w != image.shape[1]): 99 | image = cv2.resize(image, (self.img_w, self.img_h), interpolation=cv2.INTER_LINEAR) 100 | depth = cv2.resize(depth, (self.img_w, self.img_h), interpolation=cv2.INTER_LINEAR) 101 | cnn_semantic = cv2.resize(cnn_semantic, (self.img_w, self.img_h), interpolation=cv2.INTER_NEAREST) 102 | gt_semantic = cv2.resize(gt_semantic, (self.img_w, self.img_h), interpolation=cv2.INTER_NEAREST) 103 | T_wc = self.Ts_full[idx] 104 | 105 | self.test_samples["image"].append(image) 106 | self.test_samples["depth"].append(depth) 107 | self.test_samples["cnn_semantic"].append(cnn_semantic) 108 | self.test_samples["gt_semantic"].append(gt_semantic) 109 | self.test_samples["T_wc"].append(T_wc) 110 | 111 | 112 | 113 | if load_softmax is True: 114 | softmax_2_entropy_np = lambda x, axis: np.sum(-np.log2(x+1e-12)*x, axis=axis, keepdims=False) # H,W 115 | # training samples 116 | cnt = 0 117 | for idx in train_ids: 118 | cnn_softmax = np.clip(np.load(self.cnn_softmax_list[idx]), a_min=0, a_max=1.0) 119 | if (self.img_h is not None and self.img_h != cnn_softmax.shape[0]) or \ 120 | (self.img_w is not None and self.img_w != cnn_softmax.shape[1]): 121 | cnn_softmax = cv2.resize(cnn_softmax, (self.img_w, self.img_h), interpolation=cv2.INTER_LINEAR) 122 | # opencv resize support resize 512 channel at maximum 123 | 124 | valid_mask = self.train_samples["gt_semantic"][cnt]>0 125 | entropy = softmax_2_entropy_np(cnn_softmax, -1)*valid_mask 126 | cnn_softmax = cnn_softmax*valid_mask[:,:,None] 127 | self.train_samples["cnn_softmax"].append(cnn_softmax) 128 | 
self.train_samples["cnn_entropy"].append(entropy) 129 | cnt += 1 130 | assert cnt==len(train_ids) 131 | 132 | # test samples 133 | cnt = 0 134 | for idx in test_ids: 135 | cnn_softmax = np.load(self.cnn_softmax_list[idx]) 136 | assert cnn_softmax.shape[-1]==34 137 | if (self.img_h is not None and self.img_h != cnn_softmax.shape[0]) or \ 138 | (self.img_w is not None and self.img_w != cnn_softmax.shape[1]): 139 | cnn_softmax = cv2.resize(cnn_softmax, (self.img_w, self.img_h), interpolation=cv2.INTER_LINEAR) 140 | # we do not need softmax for testing, can also save memory 141 | valid_mask = self.test_samples["gt_semantic"][cnt]>0 142 | entropy = softmax_2_entropy_np(cnn_softmax, -1)*valid_mask 143 | self.test_samples["cnn_entropy"].append(entropy) 144 | cnt += 1 145 | assert cnt==len(test_ids) 146 | 147 | 148 | for key in self.test_samples.keys(): # transform list of np array to array with batch dimension 149 | self.train_samples[key] = np.asarray(self.train_samples[key]) 150 | self.test_samples[key] = np.asarray(self.test_samples[key]) 151 | 152 | if nyu_mode == "nyu13" or nyu_mode == "gt_nyu13": 153 | self.semantic_classes = np.arange(14) # 0-void, 1-13 valid classes 154 | self.num_semantic_class = 14 # 13 valid class + 1 void class 155 | from SSR.utils import image_utils 156 | self.colour_map_np = image_utils.nyu13_colour_code 157 | elif nyu_mode=="nyu34": 158 | self.semantic_classes = np.arange(35) # 0-void, 1-34 valid classes 159 | self.num_semantic_class = 35 # 34 valid class + 1 void class 160 | self.colour_map_np = image_utils.nyu34_colour_code 161 | 162 | self.mask_ids = np.ones(self.train_num) # init self.mask_ids as full ones 163 | # 1 means the correspinding label map is used for semantic loss during training, while 0 means no semantic loss 164 | self.train_samples["cnn_semantic_clean"] = self.train_samples["cnn_semantic"].copy() 165 | 166 | print() 167 | print("Training Sample Summary:") 168 | for key in self.train_samples.keys(): 169 | print("{} has shape of {}, type {}.".format(key, self.train_samples[key].shape, self.train_samples[key].dtype)) 170 | print() 171 | print("Testing Sample Summary:") 172 | for key in self.test_samples.keys(): 173 | print("{} has shape of {}, type {}.".format(key, self.test_samples[key].shape, self.test_samples[key].dtype)) -------------------------------------------------------------------------------- /SSR/models/rays.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | # Ray helpers 5 | def get_rays(H, W, focal, c2w): 6 | i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) # pytorch's meshgrid has indexing='ij' 7 | i = i.t() 8 | j = j.t() 9 | dirs = torch.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -torch.ones_like(i)], -1) 10 | # Rotate ray directions from camera frame to the world frame 11 | rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 12 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 
13 | rays_o = c2w[:3,-1].expand(rays_d.shape) 14 | return rays_o, rays_d 15 | 16 | 17 | def get_rays_np(H, W, focal, c2w): 18 | i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy') 19 | dirs = np.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -np.ones_like(i)], -1) # ray directions in the camera frame, mirroring get_rays above 20 | rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # rotate ray directions from camera frame to the world frame; dot product, equals to: [c2w.dot(dir) for dir in dirs] 21 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 22 | rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d)) 23 | return rays_o, rays_d 24 | 25 | 26 | # Ray helpers 27 | def get_rays_camera(B, H, W, fx, fy, cx, cy, depth_type, convention="opencv"): 28 | 29 | assert depth_type == "z" or depth_type == "euclidean" 30 | i, j = torch.meshgrid(torch.arange(W), torch.arange(H)) # pytorch's meshgrid has indexing='ij', we transpose to "xy" mode 31 | 32 | i = i.t().float() 33 | j = j.t().float() 34 | 35 | size = [B, H, W] 36 | 37 | i_batch = torch.empty(size) 38 | j_batch = torch.empty(size) 39 | i_batch[:, :, :] = i[None, :, :] 40 | j_batch[:, :, :] = j[None, :, :] 41 | 42 | if convention == "opencv": 43 | x = (i_batch - cx) / fx 44 | y = (j_batch - cy) / fy 45 | z = torch.ones(size) 46 | elif convention == "opengl": 47 | x = (i_batch - cx) / fx 48 | y = -(j_batch - cy) / fy 49 | z = -torch.ones(size) 50 | else: 51 | assert False 52 | 53 | dirs = torch.stack((x, y, z), dim=3) # shape of [B, H, W, 3] 54 | 55 | if depth_type == 'euclidean': 56 | norm = torch.norm(dirs, dim=3, keepdim=True) 57 | dirs = dirs * (1. / norm) 58 | 59 | return dirs 60 | 61 | 62 | def get_rays_world(T_WC, dirs_C): 63 | R_WC = T_WC[:, :3, :3] # Bx3x3 64 | dirs_W = torch.matmul(R_WC[:, None, ...], dirs_C[..., None]).squeeze(-1) 65 | origins = T_WC[:, :3, -1] # Bx3 66 | origins = torch.broadcast_tensors(origins[:, None, :], dirs_W)[0] 67 | return origins, dirs_W 68 | 69 | 70 | def get_rays_camera_np(B, H, W, fx, fy, cx, cy, depth_type, convention="opencv"): 71 | assert depth_type == "z" or depth_type == "euclidean" 72 | i, j = np.meshgrid(np.arange(W, dtype=np.float32), 73 | np.arange(H, dtype=np.float32), indexing='xy') # numpy meshgrid with indexing='xy' (pytorch's meshgrid has default indexing='ij') 74 | 75 | size = [B, H, W] 76 | 77 | i_batch = np.empty(size, dtype=np.float32) 78 | j_batch = np.empty(size, dtype=np.float32) 79 | i_batch[:, :, :] = i[None, :, :] 80 | j_batch[:, :, :] = j[None, :, :] 81 | 82 | if convention == "opencv": 83 | x = (i_batch - cx) / fx 84 | y = (j_batch - cy) / fy 85 | z = np.ones(size, dtype=np.float32) 86 | elif convention == "opengl": 87 | x = (i_batch - cx) / fx 88 | y = -(j_batch - cy) / fy 89 | z = -np.ones(size, dtype=np.float32) 90 | else: 91 | assert False 92 | 93 | dirs = np.stack((x, y, z), axis=3) # shape of [B, H, W, 3] 94 | 95 | if depth_type == 'euclidean': 96 | norm = np.linalg.norm(dirs, axis=3, keepdims=True) 97 | dirs = dirs * (1.
/ norm) 98 | 99 | return dirs 100 | 101 | 102 | def get_rays_world_np(T_WC, dirs_C): 103 | R_WC = T_WC[:, :3, :3] # Bx3x3 104 | dirs_W = (R_WC * dirs_C[..., np.newaxis, :]).sum(axis=-1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 105 | # sum([B,3,3] * [B, H, W, 1, 3], axis=-1) --> [B, H, W, 3] 106 | origins = T_WC[:, :3, -1] # Bx3 107 | 108 | return origins, dirs_W 109 | 110 | 111 | def ndc_rays(H, W, focal, near, rays_o, rays_d): 112 | 113 | # Shift ray origins to near plane 114 | # solves for the t value such that o + t * d = -near 115 | t = -(near + rays_o[..., 2]) / rays_d[..., 2] 116 | rays_o = rays_o + t[..., None] * rays_d 117 | 118 | # Projection 119 | o0 = -1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2] 120 | o1 = -1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2] 121 | o2 = 1. + 2. * near / rays_o[..., 2] 122 | 123 | d0 = -1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2]) 124 | d1 = -1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2]) 125 | d2 = -2. * near / rays_o[..., 2] 126 | 127 | rays_o = torch.stack([o0, o1, o2], -1) 128 | rays_d = torch.stack([d0, d1, d2], -1) 129 | 130 | return rays_o, rays_d 131 | 132 | 133 | def stratified_bins(min_depth, 134 | max_depth, 135 | n_bins, 136 | n_rays, 137 | device): 138 | 139 | bin_limits = torch.linspace( 140 | min_depth, 141 | max_depth, 142 | n_bins + 1, 143 | device=device, 144 | ) 145 | lower_limits = bin_limits[:-1] 146 | bin_length = (max_depth - min_depth) / (n_bins) 147 | increments = torch.rand(n_rays, n_bins, device=device) * bin_length 148 | z_vals = lower_limits[None, :] + increments 149 | 150 | return z_vals 151 | 152 | 153 | def sampling_index(n_rays, batch_size, h, w): 154 | 155 | index_b = np.random.choice(np.arange(batch_size)).reshape((1, 1)) # sample one image from the full trainiing set 156 | index_hw = torch.randint(0, h * w, (1, n_rays)) 157 | 158 | return index_b, index_hw 159 | 160 | 161 | # Hierarchical sampling using inverse CDF transformations 162 | def sample_pdf(bins, weights, N_samples, det=False): 163 | """ Sample @N_importance samples from @bins with distribution defined by @weights. 164 | 165 | Inputs: 166 | bins: N_rays x (N_samples_coarse - 1) 167 | weights: N_rays x (N_samples_coarse - 2) 168 | N_samples: N_samples_fine 169 | det: deterministic or not 170 | """ 171 | # Get pdf 172 | weights = weights + 1e-5 # prevent nans, prevent division by zero (don't do inplace op!) 
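sample_pdf (continued below) is inverse-CDF sampling: normalise the coarse weights into a pdf, accumulate them into a CDF, then invert the CDF with searchsorted so that fine samples concentrate where the coarse pass put probability mass. A toy call, assuming the function as defined in this file and the shapes from its docstring (N_samples_coarse = 64 here is arbitrary):

import torch

bins = torch.linspace(0.1, 10.0, steps=63).unsqueeze(0)   # 1 ray, N_samples_coarse - 1 bin locations
weights = torch.zeros(1, 62)                              # 1 ray, N_samples_coarse - 2 coarse weights
weights[0, 30:35] = 1.0                                   # pretend the coarse pass found a surface mid-range
z_fine = sample_pdf(bins, weights, N_samples=128, det=True)
print(z_fine.shape)                                       # torch.Size([1, 128])
inside = ((z_fine >= bins[0, 30]) & (z_fine <= bins[0, 35])).float().mean()
print(inside)                                             # ~0.98: almost all fine samples land where the coarse weights were non-zero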
173 | pdf = weights / torch.sum(weights, -1, keepdim=True) 174 | cdf = torch.cumsum(pdf, -1) # N_rays x (N_samples - 2) 175 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # N_rays x (N_samples_coarse - 1) 176 | # padded to 0~1 inclusive, (N_rays, N_samples-1) 177 | 178 | # Take uniform samples 179 | if det: # generate deterministic samples 180 | u = torch.linspace(0., 1., steps=N_samples, device=bins.device) 181 | u = u.expand(list(cdf.shape[:-1]) + [N_samples]) 182 | else: 183 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples], device=bins.device) 184 | # (N_rays, N_samples_fine) 185 | 186 | # Invert CDF 187 | u = u.contiguous() 188 | inds = torch.searchsorted(cdf.detach(), u, right=True) # N_rays x N_samples_fine 189 | below = torch.max(torch.zeros_like(inds-1), inds-1) 190 | above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds) 191 | inds_g = torch.stack([below, above], -1) # (N_rays, N_samples_fine, 2) 192 | 193 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] # (N_rays, N_samples_fine, N_samples_coarse - 1) 194 | 195 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) # N_rays, N_samples_fine, 2 196 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) # N_rays, N_samples_fine, 2 197 | 198 | denom = (cdf_g[..., 1]-cdf_g[..., 0]) # # N_rays, N_samples_fine 199 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) 200 | # denom equals 0 means a bin has weight 0, in which case it will not be sampled 201 | # anyway, therefore any value for it is fine (set to 1 here) 202 | 203 | t = (u-cdf_g[..., 0])/denom 204 | samples = bins_g[..., 0] + t * (bins_g[...,1]-bins_g[...,0]) 205 | 206 | return samples 207 | 208 | 209 | def create_rays(num_rays, Ts_c2w, height, width, fx, fy, cx, cy, near, far, c2w_staticcam=None, depth_type="z", 210 | use_viewdirs=True, convention="opencv"): 211 | """ 212 | convention: 213 | "opencv" or "opengl". It defines the coordinates convention of rays from cameras. 214 | OpenCv defines x,y,z as right, down, forward while OpenGl defines x,y,z as right, up, backward (camera looking towards forward direction still, -z!) 215 | Note: Use either convention is fine, but the corresponding pose should follow the same convention. 216 | 217 | """ 218 | print('prepare rays') 219 | 220 | rays_cam = get_rays_camera(num_rays, height, width, fx, fy, cx, cy, depth_type=depth_type, convention=convention) # [N, H, W, 3] 221 | 222 | dirs_C = rays_cam.view(num_rays, -1, 3) # [N, HW, 3] 223 | rays_o, rays_d = get_rays_world(Ts_c2w, dirs_C) # origins: [B, HW, 3], dirs_W: [B, HW, 3] 224 | 225 | if use_viewdirs: 226 | # provide ray directions as input 227 | viewdirs = rays_d 228 | if c2w_staticcam is not None: 229 | # c2w_staticcam: If not None, use this transformation matrix for camera, 230 | # while using other c2w argument for viewing directions. 
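As the create_rays docstring above notes, the "opencv" and "opengl" conventions only differ in the signs of the y and z components of the camera-frame directions. A quick check with get_rays_camera from this file (toy intrinsics, all values illustrative; create_rays continues below):

import torch

B, H, W = 1, 480, 640
fx = fy = 500.0
cx, cy = W / 2.0, H / 2.0
dirs_cv = get_rays_camera(B, H, W, fx, fy, cx, cy, depth_type="z", convention="opencv")
dirs_gl = get_rays_camera(B, H, W, fx, fy, cx, cy, depth_type="z", convention="opengl")
print(dirs_cv[0, 240, 320])   # tensor([0., 0., 1.])  -> x right, y down, +z forward (OpenCV)
print(dirs_gl[0, 240, 320])   # tensor([0., 0., -1.]) -> x right, y up,  -z forward (OpenGL)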
231 | # special case to visualize effect of viewdirs 232 | rays_o, rays_d = get_rays_world(c2w_staticcam, dirs_C) # origins: [B, HW, 3], dirs_W: [B, HW, 3] 233 | 234 | viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True).float() 235 | 236 | near, far = near * torch.ones_like(rays_d[..., :1]), far * torch.ones_like(rays_d[..., :1]) 237 | rays = torch.cat([rays_o, rays_d, near, far], -1) 238 | 239 | if use_viewdirs: 240 | rays = torch.cat([rays, viewdirs], -1) 241 | return rays 242 | 243 | -------------------------------------------------------------------------------- /SSR/data_generation/settings.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import habitat_sim 6 | import habitat_sim.agent 7 | 8 | default_sim_settings = { 9 | # settings shared by example.py and benchmark.py 10 | "max_frames": 1000, 11 | "width": 640, 12 | "height": 480, 13 | "default_agent": 0, 14 | "sensor_height": 1.5, 15 | "hfov": 90, 16 | "color_sensor": True, # RGB sensor (default: ON) 17 | "semantic_sensor": False, # semantic sensor (default: OFF) 18 | "depth_sensor": False, # depth sensor (default: OFF) 19 | "ortho_rgba_sensor": False, # Orthographic RGB sensor (default: OFF) 20 | "ortho_depth_sensor": False, # Orthographic depth sensor (default: OFF) 21 | "ortho_semantic_sensor": False, # Orthographic semantic sensor (default: OFF) 22 | "fisheye_rgba_sensor": False, 23 | "fisheye_depth_sensor": False, 24 | "fisheye_semantic_sensor": False, 25 | "equirect_rgba_sensor": False, 26 | "equirect_depth_sensor": False, 27 | "equirect_semantic_sensor": False, 28 | "seed": 1, 29 | "silent": False, # do not print log info (default: OFF) 30 | # settings exclusive to example.py 31 | "save_png": False, # save the pngs to disk (default: OFF) 32 | "print_semantic_scene": False, 33 | "print_semantic_mask_stats": False, 34 | "compute_shortest_path": False, 35 | "compute_action_shortest_path": False, 36 | "scene": "data/scene_datasets/habitat-test-scenes/skokloster-castle.glb", 37 | "test_scene_data_url": "http://dl.fbaipublicfiles.com/habitat/habitat-test-scenes.zip", 38 | "goal_position": [5.047, 0.199, 11.145], 39 | "enable_physics": False, 40 | "enable_gfx_replay_save": False, 41 | "physics_config_file": "./data/default.physics_config.json", 42 | "num_objects": 10, 43 | "test_object_index": 0, 44 | "frustum_culling": True, 45 | } 46 | 47 | # build SimulatorConfiguration 48 | def make_cfg(settings): 49 | sim_cfg = habitat_sim.SimulatorConfiguration() 50 | if "frustum_culling" in settings: 51 | sim_cfg.frustum_culling = settings["frustum_culling"] 52 | else: 53 | sim_cfg.frustum_culling = False 54 | if "enable_physics" in settings: 55 | sim_cfg.enable_physics = settings["enable_physics"] 56 | if "physics_config_file" in settings: 57 | sim_cfg.physics_config_file = settings["physics_config_file"] 58 | # if not settings["silent"]: 59 | # print("sim_cfg.physics_config_file = " + sim_cfg.physics_config_file) 60 | if "scene_light_setup" in settings: 61 | sim_cfg.scene_light_setup = settings["scene_light_setup"] 62 | sim_cfg.gpu_device_id = 0 63 | if not hasattr(sim_cfg, "scene_id"): 64 | raise RuntimeError( 65 | "Error: Please upgrade habitat-sim. 
SimulatorConfig API version mismatch" 66 | ) 67 | sim_cfg.scene_id = settings["scene_file"] 68 | 69 | # define default sensor parameters (see src/esp/Sensor/Sensor.h) 70 | sensor_specs = [] 71 | 72 | def create_camera_spec(**kw_args): 73 | camera_sensor_spec = habitat_sim.CameraSensorSpec() 74 | camera_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR 75 | camera_sensor_spec.resolution = [settings["height"], settings["width"]] 76 | camera_sensor_spec.position = [0, settings["sensor_height"], 0] 77 | for k in kw_args: 78 | setattr(camera_sensor_spec, k, kw_args[k]) 79 | return camera_sensor_spec 80 | 81 | if settings["color_sensor"]: 82 | color_sensor_spec = create_camera_spec( 83 | uuid="color_sensor", 84 | # hfov=settings["hfov"], 85 | sensor_type=habitat_sim.SensorType.COLOR, 86 | sensor_subtype=habitat_sim.SensorSubType.PINHOLE, 87 | ) 88 | sensor_specs.append(color_sensor_spec) 89 | 90 | if settings["depth_sensor"]: 91 | depth_sensor_spec = create_camera_spec( 92 | uuid="depth_sensor", 93 | # hfov=settings["hfov"], 94 | sensor_type=habitat_sim.SensorType.DEPTH, 95 | channels=1, 96 | sensor_subtype=habitat_sim.SensorSubType.PINHOLE, 97 | ) 98 | sensor_specs.append(depth_sensor_spec) 99 | 100 | if settings["semantic_sensor"]: 101 | semantic_sensor_spec = create_camera_spec( 102 | uuid="semantic_sensor", 103 | # hfov=settings["hfov"], 104 | sensor_type=habitat_sim.SensorType.SEMANTIC, 105 | channels=1, 106 | sensor_subtype=habitat_sim.SensorSubType.PINHOLE, 107 | ) 108 | sensor_specs.append(semantic_sensor_spec) 109 | 110 | # if settings["ortho_rgba_sensor"]: 111 | # ortho_rgba_sensor_spec = create_camera_spec( 112 | # uuid="ortho_rgba_sensor", 113 | # sensor_type=habitat_sim.SensorType.COLOR, 114 | # sensor_subtype=habitat_sim.SensorSubType.ORTHOGRAPHIC, 115 | # ) 116 | # sensor_specs.append(ortho_rgba_sensor_spec) 117 | # 118 | # if settings["ortho_depth_sensor"]: 119 | # ortho_depth_sensor_spec = create_camera_spec( 120 | # uuid="ortho_depth_sensor", 121 | # sensor_type=habitat_sim.SensorType.DEPTH, 122 | # channels=1, 123 | # sensor_subtype=habitat_sim.SensorSubType.ORTHOGRAPHIC, 124 | # ) 125 | # sensor_specs.append(ortho_depth_sensor_spec) 126 | # 127 | # if settings["ortho_semantic_sensor"]: 128 | # ortho_semantic_sensor_spec = create_camera_spec( 129 | # uuid="ortho_semantic_sensor", 130 | # sensor_type=habitat_sim.SensorType.SEMANTIC, 131 | # channels=1, 132 | # sensor_subtype=habitat_sim.SensorSubType.ORTHOGRAPHIC, 133 | # ) 134 | # sensor_specs.append(ortho_semantic_sensor_spec) 135 | 136 | # TODO Figure out how to implement copying of specs 137 | def create_fisheye_spec(**kw_args): 138 | fisheye_sensor_spec = habitat_sim.FisheyeSensorDoubleSphereSpec() 139 | fisheye_sensor_spec.uuid = "fisheye_sensor" 140 | fisheye_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR 141 | fisheye_sensor_spec.sensor_model_type = ( 142 | habitat_sim.FisheyeSensorModelType.DOUBLE_SPHERE 143 | ) 144 | 145 | # The default value (alpha, xi) is set to match the lens "GoPro" found in Table 3 of this paper: 146 | # Vladyslav Usenko, Nikolaus Demmel and Daniel Cremers: The Double Sphere 147 | # Camera Model, The International Conference on 3D Vision (3DV), 2018 148 | # You can find the intrinsic parameters for the other lenses in the same table as well. 
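For context on how these sensor-spec builders are driven (the fisheye parameters continue right below), here is a minimal, hedged sketch of calling make_cfg with the module's default_sim_settings; habitat_renderer.py follows the same pattern with its YAML-derived config. Note that "scene_file" is the key make_cfg actually reads, so the sketch copies the default "scene" path into it, and the scene asset path itself is just the placeholder from the defaults:

import habitat_sim
from settings import default_sim_settings, make_cfg

settings = dict(default_sim_settings)          # copy the defaults defined at the top of this file
settings["scene_file"] = settings["scene"]     # make_cfg reads "scene_file"; reuse the default scene path
settings["depth_sensor"] = True                # enable depth alongside the default colour sensor
cfg = make_cfg(settings)                       # returns habitat_sim.Configuration(sim_cfg, [agent_cfg])
sim = habitat_sim.Simulator(cfg)
observations = sim.get_sensor_observations()   # dict keyed by sensor uuid, e.g. "color_sensor", "depth_sensor"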
149 | fisheye_sensor_spec.xi = -0.27 150 | fisheye_sensor_spec.alpha = 0.57 151 | fisheye_sensor_spec.focal_length = [364.84, 364.86] 152 | 153 | fisheye_sensor_spec.resolution = [settings["height"], settings["width"]] 154 | # The default principal_point_offset is the middle of the image 155 | fisheye_sensor_spec.principal_point_offset = None 156 | # default: fisheye_sensor_spec.principal_point_offset = [i/2 for i in fisheye_sensor_spec.resolution] 157 | fisheye_sensor_spec.position = [0, settings["sensor_height"], 0] 158 | for k in kw_args: 159 | setattr(fisheye_sensor_spec, k, kw_args[k]) 160 | return fisheye_sensor_spec 161 | 162 | # if settings["fisheye_rgba_sensor"]: 163 | # fisheye_rgba_sensor_spec = create_fisheye_spec(uuid="fisheye_rgba_sensor") 164 | # sensor_specs.append(fisheye_rgba_sensor_spec) 165 | # if settings["fisheye_depth_sensor"]: 166 | # fisheye_depth_sensor_spec = create_fisheye_spec( 167 | # uuid="fisheye_depth_sensor", 168 | # sensor_type=habitat_sim.SensorType.DEPTH, 169 | # channels=1, 170 | # ) 171 | # sensor_specs.append(fisheye_depth_sensor_spec) 172 | # if settings["fisheye_semantic_sensor"]: 173 | # fisheye_semantic_sensor_spec = create_fisheye_spec( 174 | # uuid="fisheye_semantic_sensor", 175 | # sensor_type=habitat_sim.SensorType.SEMANTIC, 176 | # channels=1, 177 | # ) 178 | # sensor_specs.append(fisheye_semantic_sensor_spec) 179 | 180 | def create_equirect_spec(**kw_args): 181 | equirect_sensor_spec = habitat_sim.EquirectangularSensorSpec() 182 | equirect_sensor_spec.uuid = "equirect_rgba_sensor" 183 | equirect_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR 184 | equirect_sensor_spec.resolution = [settings["height"], settings["width"]] 185 | equirect_sensor_spec.position = [0, settings["sensor_height"], 0] 186 | for k in kw_args: 187 | setattr(equirect_sensor_spec, k, kw_args[k]) 188 | return equirect_sensor_spec 189 | 190 | # if settings["equirect_rgba_sensor"]: 191 | # equirect_rgba_sensor_spec = create_equirect_spec(uuid="equirect_rgba_sensor") 192 | # sensor_specs.append(equirect_rgba_sensor_spec) 193 | # 194 | # if settings["equirect_depth_sensor"]: 195 | # equirect_depth_sensor_spec = create_equirect_spec( 196 | # uuid="equirect_depth_sensor", 197 | # sensor_type=habitat_sim.SensorType.DEPTH, 198 | # channels=1, 199 | # ) 200 | # sensor_specs.append(equirect_depth_sensor_spec) 201 | # 202 | # if settings["equirect_semantic_sensor"]: 203 | # equirect_semantic_sensor_spec = create_equirect_spec( 204 | # uuid="equirect_semantic_sensor", 205 | # sensor_type=habitat_sim.SensorType.SEMANTIC, 206 | # channels=1, 207 | # ) 208 | # sensor_specs.append(equirect_semantic_sensor_spec) 209 | 210 | # create agent specifications 211 | agent_cfg = habitat_sim.agent.AgentConfiguration() 212 | agent_cfg.sensor_specifications = sensor_specs 213 | agent_cfg.action_space = { 214 | "move_forward": habitat_sim.agent.ActionSpec( 215 | "move_forward", habitat_sim.agent.ActuationSpec(amount=0.25) 216 | ), 217 | "turn_left": habitat_sim.agent.ActionSpec( 218 | "turn_left", habitat_sim.agent.ActuationSpec(amount=10.0) 219 | ), 220 | "turn_right": habitat_sim.agent.ActionSpec( 221 | "turn_right", habitat_sim.agent.ActuationSpec(amount=10.0) 222 | ), 223 | } 224 | 225 | # override action space to no-op to test physics 226 | if sim_cfg.enable_physics: 227 | agent_cfg.action_space = { 228 | "move_forward": habitat_sim.agent.ActionSpec( 229 | "move_forward", habitat_sim.agent.ActuationSpec(amount=0.0) 230 | ) 231 | } 232 | 233 | return habitat_sim.Configuration(sim_cfg, 
[agent_cfg]) 234 | -------------------------------------------------------------------------------- /SSR/data_generation/habitat_renderer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os, sys, argparse 3 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 4 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 5 | import cv2 6 | import logging 7 | import habitat_sim as hs 8 | import numpy as np 9 | import quaternion 10 | import yaml 11 | import json 12 | from typing import Any, Dict, List, Tuple, Union 13 | from imgviz import label_colormap 14 | from PIL import Image 15 | import matplotlib.pyplot as plt 16 | import transformation 17 | import imgviz 18 | from datetime import datetime 19 | import time 20 | from settings import make_cfg 21 | 22 | # Custom type definitions 23 | Config = Dict[str, Any] 24 | Observation = hs.sensor.Observation 25 | Sim = hs.Simulator 26 | 27 | def init_habitat(config) : 28 | """Initialize the Habitat simulator with sensors and scene file""" 29 | _cfg = make_cfg(config) 30 | sim = Sim(_cfg) 31 | sim_cfg = hs.SimulatorConfiguration() 32 | sim_cfg.gpu_device_id = 0 33 | # Note: all sensors must have the same resolution 34 | camera_resolution = [config["height"], config["width"]] 35 | sensors = { 36 | "color_sensor": { 37 | "sensor_type": hs.SensorType.COLOR, 38 | "resolution": camera_resolution, 39 | "position": [0.0, config["sensor_height"], 0.0], 40 | }, 41 | "depth_sensor": { 42 | "sensor_type": hs.SensorType.DEPTH, 43 | "resolution": camera_resolution, 44 | "position": [0.0, config["sensor_height"], 0.0], 45 | }, 46 | "semantic_sensor": { 47 | "sensor_type": hs.SensorType.SEMANTIC, 48 | "resolution": camera_resolution, 49 | "position": [0.0, config["sensor_height"], 0.0], 50 | }, 51 | } 52 | 53 | sensor_specs = [] 54 | for sensor_uuid, sensor_params in sensors.items(): 55 | if config[sensor_uuid]: 56 | sensor_spec = hs.SensorSpec() 57 | sensor_spec.uuid = sensor_uuid 58 | sensor_spec.sensor_type = sensor_params["sensor_type"] 59 | sensor_spec.resolution = sensor_params["resolution"] 60 | sensor_spec.position = sensor_params["position"] 61 | 62 | sensor_specs.append(sensor_spec) 63 | 64 | # Here you can specify the amount of displacement in a forward action and the turn angle 65 | agent_cfg = hs.agent.AgentConfiguration() 66 | agent_cfg.sensor_specifications = sensor_specs 67 | agent_cfg.action_space = { 68 | "move_forward": hs.agent.ActionSpec( 69 | "move_forward", hs.agent.ActuationSpec(amount=0.25) 70 | ), 71 | "turn_left": hs.agent.ActionSpec( 72 | "turn_left", hs.agent.ActuationSpec(amount=30.0) 73 | ), 74 | "turn_right": hs.agent.ActionSpec( 75 | "turn_right", hs.agent.ActuationSpec(amount=30.0) 76 | ), 77 | } 78 | 79 | hs_cfg = hs.Configuration(sim_cfg, [agent_cfg]) 80 | # sim = Sim(hs_cfg) 81 | 82 | if config["enable_semantics"]: # extract instance to class mapping function 83 | assert os.path.exists(config["instance2class_mapping"]) 84 | with open(config["instance2class_mapping"], "r") as f: 85 | annotations = json.load(f) 86 | instance_id_to_semantic_label_id = np.array(annotations["id_to_label"]) 87 | num_classes = len(annotations["classes"]) 88 | label_colour_map = label_colormap() 89 | config["instance2semantic"] = instance_id_to_semantic_label_id 90 | config["classes"] = annotations["classes"] 91 | config["objects"] = annotations["objects"] 92 | 93 | config["num_classes"] = num_classes 94 | config["label_colour_map"] = label_colormap() 95 | config["instance_colour_map"] = 
label_colormap(500) 96 | 97 | 98 | # add camera intrinsic 99 | # hfov = float(agent_cfg.sensor_specifications[0].parameters['hfov']) * np.pi / 180. 100 | # https://aihabitat.org/docs/habitat-api/view-transform-warp.html 101 | # config['K'] = K 102 | # config['K'] = np.array([[fx, 0.0, 0.0], [0.0, fx, 0.0], [0.0, 0.0, 1.0]], 103 | # dtype=np.float64) 104 | 105 | # hfov = float(agent_cfg.sensor_specifications[0].parameters['hfov']) 106 | # fx = 1.0 / np.tan(hfov / 2.0) 107 | # config['K'] = np.array([[fx, 0.0, 0.0], [0.0, fx, 0.0], [0.0, 0.0, 1.0]], 108 | # dtype=np.float64) 109 | 110 | # Get the intrinsic camera parameters 111 | 112 | 113 | logging.info('Habitat simulator initialized') 114 | 115 | return sim, hs_cfg, config 116 | 117 | def save_renders(save_path, observation, enable_semantic, suffix=""): 118 | save_path_rgb = os.path.join(save_path, "rgb") 119 | save_path_depth = os.path.join(save_path, "depth") 120 | save_path_sem_class = os.path.join(save_path, "semantic_class") 121 | save_path_sem_instance = os.path.join(save_path, "semantic_instance") 122 | 123 | if not os.path.exists(save_path_rgb): 124 | os.makedirs(save_path_rgb) 125 | if not os.path.exists(save_path_depth): 126 | os.makedirs(save_path_depth) 127 | if not os.path.exists(save_path_sem_class): 128 | os.makedirs(save_path_sem_class) 129 | if not os.path.exists(save_path_sem_instance): 130 | os.makedirs(save_path_sem_instance) 131 | 132 | cv2.imwrite(os.path.join(save_path_rgb, "rgb{}.png".format(suffix)), observation["color_sensor"][:,:,::-1]) # change from RGB to BGR for opencv write 133 | cv2.imwrite(os.path.join(save_path_depth, "depth{}.png".format(suffix)), observation["depth_sensor_mm"]) 134 | 135 | if enable_semantic: 136 | cv2.imwrite(os.path.join(save_path_sem_class, "semantic_class{}.png".format(suffix)), observation["semantic_class"]) 137 | cv2.imwrite(os.path.join(save_path_sem_class, "vis_sem_class{}.png".format(suffix)), observation["vis_sem_class"][:,:,::-1]) 138 | 139 | cv2.imwrite(os.path.join(save_path_sem_instance, "semantic_instance{}.png".format(suffix)), observation["semantic_instance"]) 140 | cv2.imwrite(os.path.join(save_path_sem_instance, "vis_sem_instance{}.png".format(suffix)), observation["vis_sem_instance"][:,:,::-1]) 141 | 142 | 143 | def render(sim, config): 144 | """Return the sensor observations and ground truth pose""" 145 | observation = sim.get_sensor_observations() 146 | 147 | # process rgb imagem change from RGBA to RGB 148 | observation['color_sensor'] = observation['color_sensor'][..., 0:3] 149 | rgb_img = observation['color_sensor'] 150 | 151 | # process depth 152 | depth_mm = (observation['depth_sensor'].copy()*1000).astype(np.uint16) # change meters to mm 153 | observation['depth_sensor_mm'] = depth_mm 154 | 155 | # process semantics 156 | if config['enable_semantics']: 157 | 158 | # Assuming the scene has no more than 65534 objects 159 | observation['semantic_instance'] = np.clip(observation['semantic_sensor'].astype(np.uint16), 0, 65535) 160 | # observation['semantic_instance'][observation['semantic_instance']==12]=0 # mask out certain instance 161 | # Convert instance IDs to class IDs 162 | 163 | 164 | # observation['semantic_classes'] = np.zeros(observation['semantic'].shape, dtype=np.uint8) 165 | # TODO make this conversion more efficient 166 | semantic_class = config["instance2semantic"][observation['semantic_instance']] 167 | semantic_class[semantic_class < 0] = 0 168 | 169 | vis_sem_class = config["label_colour_map"][semantic_class] 170 | vis_sem_instance = 
config["instance_colour_map"][observation['semantic_instance']] # may cause error when having more than 255 instances in the scene 171 | 172 | observation['semantic_class'] = semantic_class.astype(np.uint8) 173 | observation["vis_sem_class"] = vis_sem_class.astype(np.uint8) 174 | observation["vis_sem_instance"] = vis_sem_instance.astype(np.uint8) 175 | 176 | # del observation["semantic_sensor"] 177 | 178 | # Get the camera ground truth pose (T_HC) in the habitat frame from the 179 | # position and orientation 180 | t_HC = sim.get_agent(0).get_state().position 181 | q_HC = sim.get_agent(0).get_state().rotation 182 | T_HC = transformation.combine_pose(t_HC, q_HC) 183 | 184 | observation['T_HC'] = T_HC 185 | observation['T_WC'] = transformation.Thc_to_Twc(T_HC) 186 | 187 | return observation 188 | 189 | def set_agent_position(sim, pose): 190 | # Move the agent 191 | R = pose[:3, :3] 192 | orientation_quat = quaternion.from_rotation_matrix(R) 193 | t = pose[:3, 3] 194 | position = t 195 | 196 | orientation = [orientation_quat.x, orientation_quat.y, orientation_quat.z, orientation_quat.w] 197 | agent = sim.get_agent(0) 198 | agent_state = hs.agent.AgentState(position, orientation) 199 | # agent.set_state(agent_state, reset_sensors=False) 200 | agent.set_state(agent_state) 201 | 202 | def main(): 203 | parser = argparse.ArgumentParser(description='Render Colour, Depth, Semantic, Instance labeling from Habitat-Simultation.') 204 | parser.add_argument('--config_file', type=str, 205 | default="./data_generation/replica_render_config_vMAP.yaml", 206 | help='the path to custom config file.') 207 | args = parser.parse_args() 208 | 209 | """Initialize the config dict and Habitat simulator""" 210 | # Read YAML file 211 | with open(args.config_file, 'r') as f: 212 | config = yaml.safe_load(f) 213 | 214 | config["save_path"] = os.path.join(config["save_path"]) 215 | if not os.path.exists(config["save_path"]): 216 | os.makedirs(config["save_path"]) 217 | 218 | T_wc = np.loadtxt(config["pose_file"]).reshape(-1, 4, 4) 219 | Ts_cam2world = T_wc 220 | 221 | print("-----Initialise and Set Habitat-Sim-----") 222 | sim, hs_cfg, config = init_habitat(config) 223 | # Set agent state 224 | sim.initialize_agent(config["default_agent"]) 225 | 226 | """Set agent state""" 227 | print("-----Render Images from Habitat-Sim-----") 228 | with open(os.path.join(config["save_path"], 'render_config.yaml'), 'w') as outfile: 229 | yaml.dump(config, outfile, default_flow_style=False) 230 | start_time = time.time() 231 | total_render_num = Ts_cam2world.shape[0] 232 | for i in range(total_render_num): 233 | if i % 100 == 0 : 234 | print("Rendering Process: {}/{}".format(i, total_render_num)) 235 | set_agent_position(sim, transformation.Twc_to_Thc(Ts_cam2world[i])) 236 | 237 | # replica mode 238 | observation = render(sim, config) 239 | save_renders(config["save_path"], observation, config["enable_semantics"], suffix="_{}".format(i)) 240 | 241 | end_time = time.time() 242 | print("-----Finish Habitat Rendering, Showing Trajectories.-----") 243 | print("Average rendering time per image is {} seconds.".format((end_time-start_time)/Ts_cam2world.shape[0])) 244 | 245 | if __name__ == "__main__": 246 | main() 247 | 248 | 249 | -------------------------------------------------------------------------------- /SSR/utils/image_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import imgviz 5 | from imgviz import label_colormap 6 | from imgviz 
import draw as draw_module 7 | import matplotlib.pyplot as plt 8 | 9 | def numpy2cv(image): 10 | """ 11 | 12 | :param image: a floating numpy images of shape [H,W,3] within range [0, 1] 13 | :return: 14 | """ 15 | 16 | image_cv = np.copy(image) 17 | image_cv = np.astype(np.clip(image_cv, 0, 1)*255, np.uint8)[:, :, ::-1] # uint8 BGR opencv format 18 | return image_cv 19 | 20 | 21 | 22 | 23 | def plot_semantic_legend( 24 | label, 25 | label_name, 26 | colormap=None, 27 | font_size=30, 28 | font_path=None, 29 | save_path=None, 30 | img_name=None): 31 | 32 | 33 | """Plot Colour Legend for Semantic Classes 34 | 35 | Parameters 36 | ---------- 37 | label: numpy.ndarray, (N,), int 38 | One-dimensional array containing the unique labels of exsiting semantic classes 39 | label_names: list of string 40 | Label id to label name. 41 | font_size: int 42 | Font size (default: 30). 43 | colormap: numpy.ndarray, (M, 3), numpy.uint8 44 | Label id to color. 45 | By default, :func:`~imgviz.label_colormap` is used. 46 | font_path: str 47 | Font path. 48 | 49 | Returns 50 | ------- 51 | res: numpy.ndarray, (H, W, 3), numpy.uint8 52 | Legend image of visualising semantic labels. 53 | 54 | """ 55 | 56 | label = np.unique(label) 57 | if colormap is None: 58 | colormap = label_colormap() 59 | 60 | text_sizes = np.array( 61 | [ 62 | draw_module.text_size( 63 | label_name[l], font_size, font_path=font_path 64 | ) 65 | for l in label 66 | ] 67 | ) 68 | 69 | text_height, text_width = text_sizes.max(axis=0) 70 | legend_height = text_height * len(label) + 5 71 | legend_width = text_width + 20 + (text_height - 10) 72 | 73 | 74 | legend = np.zeros((legend_height+50, legend_width+50, 3), dtype=np.uint8) 75 | aabb1 = np.array([25, 25], dtype=float) 76 | aabb2 = aabb1 + (legend_height, legend_width) 77 | 78 | legend = draw_module.rectangle( 79 | legend, aabb1, aabb2, fill=(255, 255, 255) 80 | ) # fill the legend area by white colour 81 | 82 | y1, x1 = aabb1.round().astype(int) 83 | y2, x2 = aabb2.round().astype(int) 84 | 85 | for i, l in enumerate(label): 86 | box_aabb1 = aabb1 + (i * text_height + 5, 5) 87 | box_aabb2 = box_aabb1 + (text_height - 10, text_height - 10) 88 | legend = draw_module.rectangle( 89 | legend, aabb1=box_aabb1, aabb2=box_aabb2, fill=colormap[l] 90 | ) 91 | legend = draw_module.text( 92 | legend, 93 | yx=aabb1 + (i * text_height, 10 + (text_height - 10)), 94 | text=label_name[l], 95 | size=font_size, 96 | font_path=font_path, 97 | ) 98 | 99 | 100 | plt.figure(1) 101 | plt.title("Semantic Legend!") 102 | plt.imshow(legend) 103 | plt.axis("off") 104 | 105 | img_arr = imgviz.io.pyplot_to_numpy() 106 | plt.close() 107 | if save_path is not None: 108 | import cv2 109 | if img_name is not None: 110 | sav_dir = os.path.join(save_path, img_name) 111 | else: 112 | sav_dir = os.path.join(save_path, "semantic_class_Legend.png") 113 | # plt.savefig(sav_dir, bbox_inches='tight', pad_inches=0) 114 | cv2.imwrite(sav_dir, img_arr[:,:,::-1]) 115 | return img_arr 116 | 117 | 118 | 119 | 120 | def image_vis( 121 | pred_data_dict, 122 | gt_data_dict, 123 | # enable_sem = True 124 | ): 125 | to8b_np = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) 126 | batch_size = pred_data_dict["vis_deps"].shape[0] 127 | 128 | gt_dep_row = np.concatenate(np.split(gt_data_dict["vis_deps"], batch_size, 0), axis=-2)[0] 129 | gt_raw_dep_row = np.concatenate(np.split(gt_data_dict["deps"], batch_size, 0), axis=-1)[0] 130 | 131 | gt_sem_row = np.concatenate(np.split(gt_data_dict["vis_sems"], batch_size, 0), axis=-2)[0] 132 | 
gt_sem_clean_row = np.concatenate(np.split(gt_data_dict["vis_sems_clean"], batch_size, 0), axis=-2)[0] 133 | gt_rgb_row = np.concatenate(np.split(gt_data_dict["rgbs"], batch_size, 0), axis=-2)[0] 134 | 135 | pred_dep_row = np.concatenate(np.split(pred_data_dict["vis_deps"], batch_size, 0), axis=-2)[0] 136 | pred_raw_dep_row = np.concatenate(np.split(pred_data_dict["deps"], batch_size, 0), axis=-1)[0] 137 | 138 | pred_sem_row = np.concatenate(np.split(pred_data_dict["vis_sems"], batch_size, 0), axis=-2)[0] 139 | pred_entropy_row = np.concatenate(np.split(pred_data_dict["vis_sem_uncers"], batch_size, 0), axis=-2)[0] 140 | pred_rgb_row = np.concatenate(np.split(pred_data_dict["rgbs"], batch_size, 0), axis=-2)[0] 141 | 142 | rgb_diff = np.abs(gt_rgb_row - pred_rgb_row) 143 | 144 | dep_diff = np.abs(gt_raw_dep_row - pred_raw_dep_row) 145 | dep_diff[gt_raw_dep_row== 0] = 0 146 | dep_diff_vis = imgviz.depth2rgb(dep_diff) 147 | 148 | views = [to8b_np(gt_rgb_row), to8b_np(pred_rgb_row), to8b_np(rgb_diff), 149 | gt_dep_row, pred_dep_row, dep_diff_vis, 150 | gt_sem_clean_row, gt_sem_row, pred_sem_row, pred_entropy_row] 151 | 152 | viz = np.vstack(views) 153 | return viz 154 | 155 | 156 | 157 | 158 | nyu13_colour_code = (np.array([[0, 0, 0], 159 | [0, 0, 1], # BED 160 | [0.9137,0.3490,0.1882], #BOOKS 161 | [0, 0.8549, 0], #CEILING 162 | [0.5843,0,0.9412], #CHAIR 163 | [0.8706,0.9451,0.0941], #FLOOR 164 | [1.0000,0.8078,0.8078], #FURNITURE 165 | [0,0.8784,0.8980], #OBJECTS 166 | [0.4157,0.5333,0.8000], #PAINTING 167 | [0.4588,0.1137,0.1608], #SOFA 168 | [0.9412,0.1373,0.9216], #TABLE 169 | [0,0.6549,0.6118], #TV 170 | [0.9765,0.5451,0], #WALL 171 | [0.8824,0.8980,0.7608]])*255).astype(np.uint8) 172 | 173 | 174 | # color palette for nyu34 labels 175 | nyu34_colour_code = np.array([ 176 | (0, 0, 0), 177 | 178 | (174, 199, 232), # wall 179 | (152, 223, 138), # floor 180 | (31, 119, 180), # cabinet 181 | (255, 187, 120), # bed 182 | (188, 189, 34), # chair 183 | 184 | (140, 86, 75), # sofa 185 | (255, 152, 150), # table 186 | (214, 39, 40), # door 187 | (197, 176, 213), # window 188 | # (148, 103, 189), # bookshelf 189 | 190 | (196, 156, 148), # picture 191 | (23, 190, 207), # counter 192 | (178, 76, 76), # blinds 193 | (247, 182, 210), # desk 194 | (66, 188, 102), # shelves 195 | 196 | (219, 219, 141), # curtain 197 | # (140, 57, 197), # dresser 198 | (202, 185, 52), # pillow 199 | # (51, 176, 203), # mirror 200 | (200, 54, 131), # floor 201 | 202 | (92, 193, 61), # clothes 203 | (78, 71, 183), # ceiling 204 | (172, 114, 82), # books 205 | (255, 127, 14), # refrigerator 206 | (91, 163, 138), # tv 207 | 208 | (153, 98, 156), # paper 209 | (140, 153, 101), # towel 210 | # (158, 218, 229), # shower curtain 211 | (100, 125, 154), # box 212 | # (178, 127, 135), # white board 213 | 214 | # (120, 185, 128), # person 215 | (146, 111, 194), # night stand 216 | (44, 160, 44), # toilet 217 | (112, 128, 144), # sink 218 | (96, 207, 209), # lamp 219 | 220 | (227, 119, 194), # bathtub 221 | (213, 92, 176), # bag 222 | (94, 106, 211), # other struct 223 | (82, 84, 163), # otherfurn 224 | (100, 85, 144) # other prop 225 | ]).astype(np.uint8) 226 | 227 | 228 | 229 | # color palette for nyu40 labels 230 | nyu40_colour_code = np.array([ 231 | (0, 0, 0), 232 | 233 | (174, 199, 232), # wall 234 | (152, 223, 138), # floor 235 | (31, 119, 180), # cabinet 236 | (255, 187, 120), # bed 237 | (188, 189, 34), # chair 238 | 239 | (140, 86, 75), # sofa 240 | (255, 152, 150), # table 241 | (214, 39, 40), # door 242 | (197, 176, 213), # 
window 243 | (148, 103, 189), # bookshelf 244 | 245 | (196, 156, 148), # picture 246 | (23, 190, 207), # counter 247 | (178, 76, 76), # blinds 248 | (247, 182, 210), # desk 249 | (66, 188, 102), # shelves 250 | 251 | (219, 219, 141), # curtain 252 | (140, 57, 197), # dresser 253 | (202, 185, 52), # pillow 254 | (51, 176, 203), # mirror 255 | (200, 54, 131), # floor 256 | 257 | (92, 193, 61), # clothes 258 | (78, 71, 183), # ceiling 259 | (172, 114, 82), # books 260 | (255, 127, 14), # refrigerator 261 | (91, 163, 138), # tv 262 | 263 | (153, 98, 156), # paper 264 | (140, 153, 101), # towel 265 | (158, 218, 229), # shower curtain 266 | (100, 125, 154), # box 267 | (178, 127, 135), # white board 268 | 269 | (120, 185, 128), # person 270 | (146, 111, 194), # night stand 271 | (44, 160, 44), # toilet 272 | (112, 128, 144), # sink 273 | (96, 207, 209), # lamp 274 | 275 | (227, 119, 194), # bathtub 276 | (213, 92, 176), # bag 277 | (94, 106, 211), # other struct 278 | (82, 84, 163), # otherfurn 279 | (100, 85, 144) # other prop 280 | ]).astype(np.uint8) 281 | 282 | 283 | if __name__ == "__main__": 284 | # nyu40_class_name_string = ["void", 285 | # "wall", "floor", "cabinet", "bed", "chair", 286 | # "sofa", "table", "door", "window", "book", 287 | # "picture", "counter", "blinds", "desk", "shelves", 288 | # "curtain", "dresser", "pillow", "mirror", "floor", 289 | # "clothes", "ceiling", "books", "fridge", "tv", 290 | # "paper", "towel", "shower curtain", "box", "white board", 291 | # "person", "night stand", "toilet", "sink", "lamp", 292 | # "bath tub", "bag", "other struct", "other furntr", "other prop"] # NYUv2-40-class 293 | 294 | # legend_img_arr = plot_semantic_legend(np.arange(41), nyu40_class_name_string, 295 | # colormap=nyu40_colour_code, 296 | # save_path="/home/shuaifeng/Documents/PhD_Research/SemanticSceneRepresentations/SSR", 297 | # img_name="nyu40_legned.png") 298 | 299 | 300 | nyu13_class_name_string = ["void", 301 | "bed", "books", "ceiling", "chair", "floor", 302 | "furniture", "objects", "painting/picture", "sofa", "table", 303 | "TV", "wall", "window"] # NYUv2-13-class 304 | 305 | legend_img_arr = plot_semantic_legend(np.arange(14), nyu13_class_name_string, 306 | colormap=nyu13_colour_code, 307 | save_path="/home/shuaifeng/Documents/PhD_Research/SemanticSceneRepresentations/SSR", 308 | img_name="nyu13_legned.png") -------------------------------------------------------------------------------- /SSR/extract_colour_mesh.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import argparse 7 | from SSR.datasets.replica import replica_datasets 8 | from SSR.datasets.scannet import scannet_datasets 9 | from SSR.datasets.replica_nyu import replica_nyu_cnn_datasets 10 | from SSR.datasets.scannet import scannet_datasets 11 | import open3d as o3d 12 | 13 | from SSR.training import trainer 14 | from SSR.models.model_utils import run_network 15 | from SSR.geometry.occupancy import grid_within_bound 16 | from SSR.visualisation import open3d_utils 17 | import numpy as np 18 | import yaml 19 | import json 20 | 21 | import skimage.measure as ski_measure 22 | import time 23 | from imgviz import label_colormap 24 | import trimesh 25 | 26 | 27 | @torch.no_grad() 28 | def render_fn(trainer, rays, chunk): 29 | """Do batched inference on rays using chunk.""" 30 | B = rays.shape[0] 31 | results = defaultdict(list) 32 | for i in 
range(0, B, chunk): 33 | rendered_ray_chunks = \ 34 | trainer.render_rays(rays[i:i+chunk]) 35 | 36 | for k, v in rendered_ray_chunks.items(): 37 | results[k] += [v.cpu()] 38 | 39 | for k, v in results.items(): 40 | results[k] = torch.cat(v, 0) 41 | return results 42 | 43 | 44 | def train(): 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument('--config_file', type=str, default="/home/shuaifeng/Documents/PhD_Research/CodeRelease/SemanticSceneRepresentations/SSR/configs/SSR_room0_config_test.yaml", help='config file name.') 47 | 48 | parser.add_argument('--mesh_dir', type=str, required=True, help='Path to scene file, e.g., ROOT_PATH/Replica/mesh/room_0/') 49 | parser.add_argument('--training_data_dir', type=str, required=True, help='Path to rendered data.') 50 | parser.add_argument('--save_dir', type=str, required=True, help='Path to the directory saving training logs and ckpts.') 51 | 52 | parser.add_argument('--use_vertex_normal', action="store_true", help='use vertex normals to compute color') 53 | parser.add_argument('--near_t', type=float, default=2.0, help='the near bound factor to start the ray') 54 | parser.add_argument('--sem', action="store_true") 55 | parser.add_argument('--grid_dim', type=int, default=256) 56 | parser.add_argument('--gpu', type=str, default="", help='GPU IDs.') 57 | 58 | 59 | 60 | args = parser.parse_args() 61 | 62 | config_file_path = args.config_file 63 | 64 | # Read YAML file 65 | with open(config_file_path, 'r') as f: 66 | config = yaml.safe_load(f) 67 | if len(args.gpu)>0: 68 | config["experiment"]["gpu"] = args.gpu 69 | print("Experiment GPU is {}.".format(config["experiment"]["gpu"])) 70 | trainer.select_gpus(config["experiment"]["gpu"]) 71 | 72 | 73 | to8b_np = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) 74 | logits_2_label = lambda x: torch.argmax(torch.nn.functional.softmax(x, dim=-1),dim=-1) 75 | 76 | # Cast intrinsics to right types 77 | ssr_trainer = trainer.SSRTrainer(config) 78 | 79 | 80 | near_t = args.near_t 81 | mesh_dir = args.mesh_dir 82 | training_data_dir = args.training_data_dir 83 | save_dir = args.save_dir 84 | mesh_recon_save_dir = os.path.join(save_dir, "mesh_reconstruction") 85 | os.makedirs(mesh_recon_save_dir, exist_ok=True) 86 | 87 | 88 | info_mesh_file = os.path.join(mesh_dir, "habitat", "info_semantic.json") 89 | with open(info_mesh_file, "r") as f: 90 | annotations = json.load(f) 91 | 92 | instance_id_to_semantic_label_id = np.array(annotations["id_to_label"]) 93 | instance_id_to_semantic_label_id[instance_id_to_semantic_label_id<=0] = 0 94 | semantic_classes = np.unique(instance_id_to_semantic_label_id) 95 | num_classes = len(semantic_classes) # including void class--0 96 | label_colour_map = label_colormap()[semantic_classes] 97 | valid_colour_map = label_colour_map[1:] 98 | 99 | total_num = 900 100 | step = 5 101 | ids = list(range(total_num)) 102 | train_ids = list(range(0, total_num, step)) 103 | test_ids = [x+2 for x in train_ids] 104 | 105 | replica_data_loader = replica_datasets.ReplicaDatasetCache(data_dir=training_data_dir, 106 | train_ids=train_ids, test_ids=test_ids, 107 | img_h=config["experiment"]["height"], 108 | img_w=config["experiment"]["width"]) 109 | 110 | ssr_trainer.set_params_replica() 111 | ssr_trainer.prepare_data_replica(replica_data_loader) 112 | 113 | ########################## 114 | 115 | # Create nerf model, init optimizer 116 | ssr_trainer.create_ssr() 117 | # Create rays in world coordinates 118 | ssr_trainer.init_rays() 119 | 120 | # load_ckpt into NeRF 121 | ckpt_path = 
os.path.join(save_dir, "checkpoints", "200000.ckpt") 122 | print('Reloading from', ckpt_path) 123 | ckpt = torch.load(ckpt_path) 124 | 125 | start = ckpt['global_step'] 126 | ssr_trainer.ssr_net_coarse.load_state_dict(ckpt['network_coarse_state_dict']) 127 | ssr_trainer.ssr_net_fine.load_state_dict(ckpt['network_fine_state_dict']) 128 | ssr_trainer.optimizer.load_state_dict(ckpt["optimizer_state_dict"]) 129 | ssr_trainer.training = False # enable testing mode before rendering results, need to set back during training! 130 | ssr_trainer.ssr_net_coarse.eval() 131 | ssr_trainer.ssr_net_fine.eval() 132 | 133 | 134 | level = 0.45 # level = 0 135 | threshold = 0.2 136 | draw_cameras = True 137 | grid_dim = args.grid_dim 138 | 139 | train_Ts_np = replica_data_loader.train_samples["T_wc"] 140 | mesh_file = os.path.join(mesh_dir,"mesh.ply") 141 | assert os.path.exists(mesh_file) 142 | 143 | trimesh_scene = trimesh.load(mesh_file, process=False) 144 | 145 | to_origin_transform, extents = trimesh.bounds.oriented_bounds(trimesh_scene) 146 | T_extent_to_scene = np.linalg.inv(to_origin_transform) 147 | scene_transform = T_extent_to_scene 148 | scene_extents = extents 149 | grid_query_pts, scene_scale = grid_within_bound([-1.0, 1.0], scene_extents, scene_transform, grid_dim=grid_dim) 150 | 151 | grid_query_pts = grid_query_pts.cuda().reshape(-1,1,3) # Num_rays, 1, 3-xyz 152 | viewdirs = torch.zeros_like(grid_query_pts).reshape(-1, 3) 153 | st = time.time() 154 | print("Initialise Trimesh Scenes") 155 | 156 | with torch.no_grad(): 157 | chunk = 1024 158 | run_MLP_fn = lambda pts: run_network(inputs=pts, viewdirs=torch.zeros_like(pts).squeeze(1), 159 | fn=ssr_trainer.ssr_net_fine, embed_fn=ssr_trainer.embed_fn, 160 | embeddirs_fn=ssr_trainer.embeddirs_fn, netchunk=int(2048*128)) 161 | 162 | raw = torch.cat([run_MLP_fn(grid_query_pts[i: i+chunk]).cpu() for i in range(0, grid_query_pts.shape[0], chunk)], dim=0) 163 | rgb = torch.sigmoid(raw[..., :3]) # [N_rays, N_samples, 3] 164 | alpha = raw[..., 3] # [N] 165 | sem_logits = raw[..., 4:] # [N_rays, N_samples, num_class] 166 | label_fine = logits_2_label(sem_logits).view(-1).cpu().numpy() 167 | vis_label_colour = label_colour_map[label_fine+1] 168 | 169 | print("Finish Computing Semantics!") 170 | print() 171 | 172 | def occupancy_activation(alpha, distances): 173 | occ = 1.0 - torch.exp(-F.relu(alpha) * distances) 174 | # notice we apply RELU to raw sigma before computing alpha 175 | return occ 176 | 177 | # voxel_size = (ssr_trainer.far - ssr_trainer.near) / grid_dim # or self.N_importance 178 | voxel_size = (ssr_trainer.far - ssr_trainer.near) / ssr_trainer.N_importance # or self.N_importance 179 | occ = occupancy_activation(alpha, voxel_size) 180 | print("Compute Occupancy Grids") 181 | occ = occ.reshape(grid_dim, grid_dim, grid_dim) 182 | occupancy_grid = occ.detach().cpu().numpy() 183 | 184 | print('fraction occupied:', (occupancy_grid > threshold).mean()) 185 | print('Max Occ: {}, Min Occ: {}, Mean Occ: {}'.format(occupancy_grid.max(), occupancy_grid.min(), occupancy_grid.mean())) 186 | vertices, faces, vertex_normals, _ = ski_measure.marching_cubes(occupancy_grid, level=level, gradient_direction='ascent') 187 | print() 188 | 189 | dim = occupancy_grid.shape[0] 190 | vertices = vertices / (dim - 1) 191 | mesh = trimesh.Trimesh(vertices=vertices, vertex_normals=vertex_normals, faces=faces) 192 | 193 | # Transform to [-1, 1] range 194 | mesh_canonical = mesh.copy() 195 | mesh_canonical.apply_translation([-0.5, -0.5, -0.5]) 196 | 
mesh_canonical.apply_scale(2) 197 | 198 | scene_scale = scene_extents/2.0 199 | # Transform to scene coordinates 200 | mesh_canonical.apply_scale(scene_scale) 201 | mesh_canonical.apply_transform(scene_transform) 202 | # mesh.show() 203 | exported = trimesh.exchange.export.export_mesh(mesh_canonical, os.path.join(mesh_recon_save_dir, 'mesh_canonical.ply')) 204 | print("Saving Marching Cubes mesh to mesh_canonical.ply !") 205 | exported = trimesh.exchange.export.export_mesh(mesh_canonical, os.path.join(mesh_recon_save_dir, 'mesh.ply')) 206 | print("Saving Marching Cubes mesh to mesh.ply !") 207 | 208 | 209 | o3d_mesh = open3d_utils.trimesh_to_open3d(mesh) 210 | o3d_mesh_canonical = open3d_utils.trimesh_to_open3d(mesh_canonical) 211 | 212 | print('Removing noise ...') 213 | print(f'Original Mesh has {len(o3d_mesh_canonical.vertices)/1e6:.2f} M vertices and {len(o3d_mesh_canonical.triangles)/1e6:.2f} M faces.') 214 | o3d_mesh_canonical_clean = open3d_utils.clean_mesh(o3d_mesh_canonical, keep_single_cluster=False, min_num_cluster=400) 215 | 216 | vertices_ = np.array(o3d_mesh_canonical_clean.vertices).reshape([-1, 3]).astype(np.float32) 217 | triangles = np.asarray(o3d_mesh_canonical_clean.triangles) # (n, 3) int 218 | N_vertices = vertices_.shape[0] 219 | print(f'Denoised Mesh has {len(o3d_mesh_canonical_clean.vertices)/1e6:.2f} M vertices and {len(o3d_mesh_canonical_clean.triangles)/1e6:.2f} M faces.') 220 | 221 | print("###########################################") 222 | print() 223 | print("Using Normals for colour predictions!") 224 | print() 225 | print("###########################################") 226 | 227 | ## use normal vector method as suggested by the author, see https://github.com/bmild/nerf/issues/44 228 | mesh_recon_save_dir = os.path.join(mesh_recon_save_dir,"use_vertex_normal") 229 | os.makedirs(mesh_recon_save_dir, exist_ok=True) 230 | 231 | selected_mesh = o3d_mesh_canonical_clean 232 | rays_d = - torch.FloatTensor(np.asarray(selected_mesh.vertex_normals)) # use negative normal directions as ray marching directions 233 | near = 0.1 * torch.ones_like(rays_d[:, :1]) 234 | far = 10.0 * torch.ones_like(rays_d[:, :1]) 235 | rays_o = torch.FloatTensor(vertices_) - rays_d * near * args.near_t 236 | viewdirs = rays_d 237 | viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True).float() 238 | rays = torch.cat([rays_o, rays_d, near, far, viewdirs], -1) 239 | 240 | # provide ray directions as input 241 | rays = rays.cuda() 242 | with torch.no_grad(): 243 | chunk=4096 244 | # chunk=80*1024 245 | results = render_fn(ssr_trainer, rays, chunk) 246 | 247 | # combine the output and write to file 248 | if args.sem: 249 | labels = logits_2_label(results["sem_logits_fine"]).numpy() 250 | vis_labels = valid_colour_map[labels] 251 | v_colors = vis_labels 252 | else: 253 | rgbs = results["rgb_fine"].numpy() 254 | rgbs = to8b_np(rgbs) 255 | v_colors = rgbs 256 | 257 | v_colors = v_colors.astype(np.uint8) 258 | 259 | 260 | o3d_mesh_canonical_clean.vertex_colors = o3d.utility.Vector3dVector(v_colors/255.0) 261 | 262 | if args.sem: 263 | o3d.io.write_triangle_mesh(os.path.join(mesh_recon_save_dir, 'semantic_mesh_canonical_dim{}neart_{}.ply'.format(grid_dim, near_t)), o3d_mesh_canonical_clean) 264 | print("Saving Marching Cubes mesh to semantic_mesh_canonical_dim{}neart_{}.ply".format(grid_dim, near_t)) 265 | else: 266 | o3d.io.write_triangle_mesh(os.path.join(mesh_recon_save_dir, 'colour_mesh_canonical_dim{}neart_{}.ply'.format(grid_dim, near_t)), o3d_mesh_canonical_clean) 267 | 
print("Saving Marching Cubes mesh to colour_mesh_canonical_dim{}neart_{}.ply".format(grid_dim, near_t)) 268 | 269 | print('Done!') 270 | 271 | 272 | if __name__=='__main__': 273 | train() -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Semantic-NeRF SOFTWARE 2 | 3 | LICENCE AGREEMENT 4 | 5 | WE (Imperial College of Science, Technology and Medicine, (“Imperial College 6 | London”)) ARE WILLING TO LICENSE THIS SOFTWARE TO YOU (a licensee “You”) ONLY 7 | ON THE CONDITION THAT YOU ACCEPT ALL OF THE TERMS CONTAINED IN THE FOLLOWING 8 | AGREEMENT. PLEASE READ THE AGREEMENT CAREFULLY BEFORE DOWNLOADING THE SOFTWARE. 9 | BY EXERCISING THE OPTION TO DOWNLOAD THE SOFTWARE YOU AGREE TO BE BOUND BY THE 10 | TERMS OF THE AGREEMENT. 11 | 12 | SOFTWARE LICENCE AGREEMENT (EXCLUDING BSD COMPONENTS) 13 | 14 | 1.This Agreement pertains to a worldwide, non-exclusive, temporary, fully 15 | paid-up, royalty free, non-transferable, non-sub- licensable licence (the 16 | “Licence”) to use the elastic fusion source code, including any modification, 17 | part or derivative (the “Software”). 18 | 19 | Ownership and Licence. Your rights to use and download the Software onto your 20 | computer, and all other copies that You are authorised to make, are specified 21 | in this Agreement. However, we (or our licensors) retain all rights, including 22 | but not limited to all copyright and other intellectual property rights 23 | anywhere in the world, in the Software not expressly granted to You in this 24 | Agreement. 25 | 26 | 2. Permitted use of the Licence: 27 | 28 | (a) You may download and install the Software onto one computer or server for 29 | use in accordance with Clause 2(b) of this Agreement provided that You ensure 30 | that the Software is not accessible by other users unless they have themselves 31 | accepted the terms of this licence agreement. 32 | 33 | (b) You may use the Software solely for non-commercial, internal or academic 34 | research purposes and only in accordance with the terms of this Agreement. You 35 | may not use the Software for commercial purposes, including but not limited to 36 | (1) integration of all or part of the source code or the Software into a 37 | product for sale or licence by or on behalf of You to third parties or (2) use 38 | of the Software or any derivative of it for research to develop software 39 | products for sale or licence to a third party or (3) use of the Software or any 40 | derivative of it for research to develop non-software products for sale or 41 | licence to a third party, or (4) use of the Software to provide any service to 42 | an external organisation for which payment is received. 43 | 44 | Should You wish to use the Software for commercial purposes, You shall 45 | email researchcontracts.engineering@imperial.ac.uk . 46 | 47 | (c) Right to Copy. You may copy the Software for back-up and archival purposes, 48 | provided that each copy is kept in your possession and provided You reproduce 49 | our copyright notice (set out in Schedule 1) on each copy. 50 | 51 | (d) Transfer and sub-licensing. You may not rent, lend, or lease the Software 52 | and You may not transmit, transfer or sub-license this licence to use the 53 | Software or any of your rights or obligations under this Agreement to another 54 | party. 55 | 56 | (e) Identity of Licensee. The licence granted herein is personal to You. 
You 57 | shall not permit any third party to access, modify or otherwise use the 58 | Software nor shall You access modify or otherwise use the Software on behalf of 59 | any third party. If You wish to obtain a licence for mutiple users or a site 60 | licence for the Software please contact us 61 | at researchcontracts.engineering@imperial.ac.uk . 62 | 63 | (f) Publications and presentations. You may make public, results or data 64 | obtained from, dependent on or arising from research carried out using the 65 | Software, provided that any such presentation or publication identifies the 66 | Software as the source of the results or the data, including the Copyright 67 | Notice given in each element of the Software, and stating that the Software has 68 | been made available for use by You under licence from Imperial College London 69 | and You provide a copy of any such publication to Imperial College London. 70 | 71 | 3. Prohibited Uses. You may not, without written permission from us 72 | at researchcontracts.engineering@imperial.ac.uk : 73 | 74 | (a) Use, copy, modify, merge, or transfer copies of the Software or any 75 | documentation provided by us which relates to the Software except as provided 76 | in this Agreement; 77 | 78 | (b) Use any back-up or archival copies of the Software (or allow anyone else to 79 | use such copies) for any purpose other than to replace the original copy in the 80 | event it is destroyed or becomes defective; or 81 | 82 | (c) Disassemble, decompile or "unlock", reverse translate, or in any manner 83 | decode the Software for any reason. 84 | 85 | 4. Warranty Disclaimer 86 | 87 | (a) Disclaimer. The Software has been developed for research purposes only. You 88 | acknowledge that we are providing the Software to You under this licence 89 | agreement free of charge and on condition that the disclaimer set out below 90 | shall apply. We do not represent or warrant that the Software as to: (i) the 91 | quality, accuracy or reliability of the Software; (ii) the suitability of the 92 | Software for any particular use or for use under any specific conditions; and 93 | (iii) whether use of the Software will infringe third-party rights. 94 | 95 | You acknowledge that You have reviewed and evaluated the Software to determine 96 | that it meets your needs and that You assume all responsibility and liability 97 | for determining the suitability of the Software as fit for your particular 98 | purposes and requirements. Subject to Clause 4(b), we exclude and expressly 99 | disclaim all express and implied representations, warranties, conditions and 100 | terms not stated herein (including the implied conditions or warranties of 101 | satisfactory quality, merchantable quality, merchantability and fitness for 102 | purpose). 103 | 104 | (b) Savings. Some jurisdictions may imply warranties, conditions or terms or 105 | impose obligations upon us which cannot, in whole or in part, be excluded, 106 | restricted or modified or otherwise do not allow the exclusion of implied 107 | warranties, conditions or terms, in which case the above warranty disclaimer 108 | and exclusion will only apply to You to the extent permitted in the relevant 109 | jurisdiction and does not in any event exclude any implied warranties, 110 | conditions or terms which may not under applicable law be excluded. 111 | 112 | (c) Imperial College London disclaims all responsibility for the use which is 113 | made of the Software and any liability for the outcomes arising from using the 114 | Software. 
115 | 116 | 5. Limitation of Liability 117 | 118 | (a) You acknowledge that we are providing the Software to You under this 119 | licence agreement free of charge and on condition that the limitation of 120 | liability set out below shall apply. Accordingly, subject to Clause 5(b), we 121 | exclude all liability whether in contract, tort, negligence or otherwise, in 122 | respect of the Software and/or any related documentation provided to You by us 123 | including, but not limited to, liability for loss or corruption of data, loss 124 | of contracts, loss of income, loss of profits, loss of cover and any 125 | consequential or indirect loss or damage of any kind arising out of or in 126 | connection with this licence agreement, however caused. This exclusion shall 127 | apply even if we have been advised of the possibility of such loss or damage. 128 | 129 | (b) You agree to indemnify Imperial College London and hold it harmless from 130 | and against any and all claims, damages and liabilities asserted by third 131 | parties (including claims for negligence) which arise directly or indirectly 132 | from the use of the Software or any derivative of it or the sale of any 133 | products based on the Software. You undertake to make no liability claim 134 | against any employee, student, agent or appointee of Imperial College London, 135 | in connection with this Licence or the Software. 136 | 137 | (c) Nothing in this Agreement shall have the effect of excluding or limiting 138 | our statutory liability. 139 | 140 | (d) Some jurisdictions do not allow these limitations or exclusions either 141 | wholly or in part, and, to that extent, they may not apply to you. Nothing in 142 | this licence agreement will affect your statutory rights or other relevant 143 | statutory provisions which cannot be excluded, restricted or modified, and its 144 | terms and conditions must be read and construed subject to any such statutory 145 | rights and/or provisions. 146 | 147 | 6. Confidentiality. You agree not to disclose any confidential information 148 | provided to You by us pursuant to this Agreement to any third party without our 149 | prior written consent. The obligations in this Clause 6 shall survive the 150 | termination of this Agreement for any reason. 151 | 152 | 7. Termination. 153 | 154 | (a) We may terminate this licence agreement and your right to use the Software 155 | at any time with immediate effect upon written notice to You. 156 | 157 | (b) This licence agreement and your right to use the Software automatically 158 | terminate if You: 159 | 160 | (i) fail to comply with any provisions of this Agreement; or 161 | 162 | (ii) destroy the copies of the Software in your possession, or voluntarily 163 | return the Software to us. 164 | 165 | (c) Upon termination You will destroy all copies of the Software. 166 | 167 | (d) Otherwise, the restrictions on your rights to use the Software will expire 168 | 10 (ten) years after first use of the Software under this licence agreement. 169 | 170 | 8. Miscellaneous Provisions. 171 | 172 | (a) This Agreement will be governed by and construed in accordance with the 173 | substantive laws of England and Wales whose courts shall have exclusive 174 | jurisdiction over all disputes which may arise between us. 175 | 176 | (b) This is the entire agreement between us relating to the Software, and 177 | supersedes any prior purchase order, communications, advertising or 178 | representations concerning the Software. 
179 | 180 | (c) No change or modification of this Agreement will be valid unless it is in 181 | writing, and is signed by us. 182 | 183 | (d) The unenforceability or invalidity of any part of this Agreement will not 184 | affect the enforceability or validity of the remaining parts. 185 | 186 | BSD Elements of the Software 187 | 188 | For BSD elements of the Software, the following terms shall apply: 189 | Copyright as indicated in the header of the individual element of the Software. 190 | All rights reserved. 191 | 192 | Redistribution and use in source and binary forms, with or without 193 | modification, are permitted provided that the following conditions are met: 194 | 195 | 1. Redistributions of source code must retain the above copyright notice, this 196 | list of conditions and the following disclaimer. 197 | 198 | 2. Redistributions in binary form must reproduce the above copyright notice, 199 | this list of conditions and the following disclaimer in the documentation 200 | and/or other materials provided with the distribution. 201 | 202 | 3. Neither the name of the copyright holder nor the names of its contributors 203 | may be used to endorse or promote products derived from this software without 204 | specific prior written permission. 205 | 206 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 207 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 208 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 209 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 210 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 211 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 212 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 213 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 214 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 215 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 216 | 217 | SCHEDULE 1 218 | 219 | The Software 220 | 221 | Semantic-NeRF is a scene-specific 3D semantic representation built upon Neural Radiance Fields (NeRF), jointly encoding semantics with appearance and geometry. It can be efficiently learned with a small amount of in-place supervision and reach complete and achieves accurate 2D semantic labels in room-scale 222 | scenes. It is based on the techniques described in the following publication: 223 | 224 | • Shuaifeng Zhi, Tristan Laidlow, Stefan Leutenegger, Andrew J. Davison. In-Place Scene Labelling and Understanding with Implicit Scene Representation. International Conference on Computer Vision (ICCV), 2021 225 | _________________________ 226 | 227 | Acknowledgments 228 | 229 | If you use the software, you should reference the following paper in any 230 | publication: 231 | 232 | • Shuaifeng Zhi, Tristan Laidlow, Stefan Leutenegger, Andrew J. Davison. In-Place Scene Labelling and Understanding with Implicit Scene Representation. 
International Conference on Computer Vision (ICCV), 2021 -------------------------------------------------------------------------------- /train_SSR_main.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import os 3 | import argparse 4 | 5 | from SSR.datasets.replica import replica_datasets 6 | from SSR.datasets.scannet import scannet_datasets 7 | from SSR.datasets.replica_nyu import replica_nyu_cnn_datasets 8 | # from SSR.datasets.scannet import scannet_datasets  # duplicate of the import above 9 | 10 | from SSR.training import trainer 11 | 12 | from tqdm import trange 13 | import time 14 | 15 | def train(): 16 | parser = argparse.ArgumentParser() 17 | # parser.add_argument('--config_file', type=str, default="/home/shuaifeng/Documents/PhD_Research/CodeRelease/SemanticSceneRepresentations/SSR/configs/SSR_room2_config_release.yaml", 18 | # help='config file name.') 19 | parser.add_argument('--config_file', type=str, default="/home/shuaifeng/Documents/PhD_Research/CodeRelease/SemanticSceneRepresentations/SSR/configs/SSR_room0_config_test.yaml", 20 | help='config file name.') 21 | parser.add_argument('--dataset_type', type=str, default="replica", choices=["replica", "replica_nyu_cnn", "scannet"], 22 | help='the dataset to be used.') 23 | 24 | ### working mode and specific options 25 | 26 | # sparse-views 27 | parser.add_argument("--sparse_views", action='store_true', 28 | help='Use labels from a sparse set of frames.') 29 | parser.add_argument("--sparse_ratio", type=float, default=0, 30 | help='The portion of dropped labelling frames during training, which can be used along with all working modes.') 31 | parser.add_argument("--label_map_ids", nargs='*', type=int, default=[], 32 | help='In sparse view mode, use selected frame ids from sequences as supervision.') 33 | parser.add_argument("--random_sample", action='store_true', help='Whether to randomly/evenly sample frames from the sequence.') 34 | 35 | # denoising---pixel-wise 36 | parser.add_argument("--pixel_denoising", action='store_true', 37 | help='Whether to work in pixel-denoising tasks.') 38 | parser.add_argument("--pixel_noise_ratio", type=float, default=0, 39 | help='In sparse view mode, if pixel_noise_ratio > 0, the percentage of pixels to be perturbed in each sampled frame for the pixel-wise denoising task.') 40 | 41 | # denoising---region-wise 42 | parser.add_argument("--region_denoising", action='store_true', 43 | help='Whether to work in region-denoising tasks by flipping class labels of chair instances in Replica Room_2.') 44 | parser.add_argument("--region_noise_ratio", type=float, default=0, 45 | help='In region-wise denoising task, region_noise_ratio is the percentage of chair instances to be perturbed in each sampled frame.') 46 | parser.add_argument("--uniform_flip", action='store_true', 47 | help='In region-wise denoising task, whether to change chair labels uniformly or not, i.e., by ascending area ratios. This corresponds to two set-ups mentioned in the paper.') 48 | parser.add_argument("--instance_id", nargs='*', type=int, default=[3, 6, 7, 9, 11, 12, 13, 48], 49 | help='In region-wise denoising task, the chair instance ids in Replica Room_2 to be randomly perturbed. 
The ids of all 8 chairs are [3, 6, 7, 9, 11, 12, 13, 48]') 50 | 51 | # super-resolution 52 | parser.add_argument("--super_resolution", action='store_true', 53 | help='Whether to work in the label super-resolution task.') 54 | parser.add_argument('--dense_sr', action='store_true', help='Whether to use dense (up-scaled low-resolution) labels or only the sparse valid pixels as supervision for super-resolution.') 55 | parser.add_argument('--sr_factor', type=int, default=8, help='Scaling factor of super-resolution.') 56 | 57 | # label propagation 58 | parser.add_argument("--label_propagation", action='store_true', 59 | help='Label propagation using partial seed regions.') 60 | parser.add_argument("--partial_perc", type=float, default=0, 61 | help='0: single-click propagation; 1: using 1-percent sub-regions for label propagation; 5: using 5-percent sub-regions for label propagation.') 62 | 63 | # misc. 64 | parser.add_argument('--visualise_save', action='store_true', help='Whether to save the noisy labels to the hard drive for later usage.') 65 | parser.add_argument('--load_saved', action='store_true', help='Use saved noisy labels for training to ensure consistency between experiments.') 66 | parser.add_argument('--gpu', type=str, default="", help='GPU IDs.') 67 | 68 | args = parser.parse_args() 69 | # Read YAML file 70 | with open(args.config_file, 'r') as f: 71 | config = yaml.safe_load(f) 72 | if len(args.gpu)>0: 73 | config["experiment"]["gpu"] = args.gpu 74 | print("Experiment GPU is {}.".format(config["experiment"]["gpu"])) 75 | trainer.select_gpus(config["experiment"]["gpu"]) 76 | config["experiment"].update(vars(args)) 77 | # Cast intrinsics to right types 78 | ssr_trainer = trainer.SSRTrainer(config) 79 | 80 | if args.dataset_type == "replica": 81 | print("----- Replica Dataset -----") 82 | 83 | total_num = 900 84 | step = 5 85 | train_ids = list(range(0, total_num, step)) 86 | test_ids = [x+step//2 for x in train_ids] # e.g. train_ids = [0, 5, 10, ...], test_ids = [2, 7, 12, ...] 87 | # add ids to config for later saving. 88 | config["experiment"]["train_ids"] = train_ids 89 | config["experiment"]["test_ids"] = test_ids 90 | 91 | # Todo: like NeRF, create spiral/test poses; make training and test poses/ids interleaved 92 | replica_data_loader = replica_datasets.ReplicaDatasetCache(data_dir=config["experiment"]["dataset_dir"], 93 | train_ids=train_ids, test_ids=test_ids, 94 | img_h=config["experiment"]["height"], 95 | img_w=config["experiment"]["width"]) 96 | 97 | 98 | print("--------------------") 99 | if args.super_resolution: 100 | print("Super Resolution Mode! Dense Label Flag is {}, SR Factor is {}".format(args.dense_sr, args.sr_factor)) 101 | replica_data_loader.super_resolve_label(down_scale_factor=args.sr_factor, dense_supervision=args.dense_sr) 102 | elif args.label_propagation: 103 | print("Label Propagation Mode! Partial labelling percentage is: {} ".format(args.partial_perc)) 104 | replica_data_loader.simulate_user_click_partial(perc=args.partial_perc, load_saved=args.load_saved, visualise_save=args.visualise_save) 105 | if args.sparse_views: # add view-point sampling to partial sampling 106 | print("Sparse Viewing Labels Mode under ***Partial Labelling***! Sparse Ratio is ", args.sparse_ratio) 107 | replica_data_loader.sample_label_maps(sparse_ratio=args.sparse_ratio, random_sample=args.random_sample, load_saved=args.load_saved) 108 | elif args.pixel_denoising: 109 | print("Pixel-Denoising Mode! 
Noise Ratio is ", args.pixel_noise_ratio) 110 | replica_data_loader.add_pixel_wise_noise_label(sparse_views=args.sparse_views, 111 | sparse_ratio=args.sparse_ratio, 112 | random_sample=args.random_sample, 113 | noise_ratio=args.pixel_noise_ratio, 114 | visualise_save=args.visualise_save, 115 | load_saved=args.load_saved) 116 | elif args.region_denoising: 117 | print("Chair Label Flipping for Region-wise Denoising, Flip ratio is {}, Uniform Sampling is {}".format(args.region_noise_ratio, args.uniform_flip)) 118 | replica_data_loader.add_instance_wise_noise_label(sparse_views=args.sparse_views, sparse_ratio=args.sparse_ratio, random_sample=args.random_sample, 119 | flip_ratio=args.region_noise_ratio, uniform_flip=args.uniform_flip, instance_id=args.instance_id, 120 | load_saved=args.load_saved, visualise_save=args.visualise_save) 121 | 122 | elif args.sparse_views: 123 | if len(args.label_map_ids)>0: 124 | print("Use label maps only for selected frames, ", args.label_map_ids) 125 | replica_data_loader.sample_specific_labels(args.label_map_ids, train_ids) 126 | else: 127 | print("Sparse Labels Mode! Sparsity Ratio is ", args.sparse_ratio) 128 | replica_data_loader.sample_label_maps(sparse_ratio=args.sparse_ratio, random_sample=args.random_sample, load_saved=args.load_saved) 129 | 130 | else: 131 | print("Standard setup with full dense supervision.") 132 | ssr_trainer.set_params_replica() 133 | ssr_trainer.prepare_data_replica(replica_data_loader) 134 | 135 | elif args.dataset_type == "replica_nyu_cnn": 136 | print("----- Replica Dataset with NYUv2-13 CNN Predictions -----") 137 | 138 | print("Replica_nyu_cnn mode using labels from trained CNNs: {}".format(config["experiment"]["nyu_mode"])) 139 | 140 | total_num = 900 141 | step = 5 142 | 143 | train_ids = list(range(0, total_num, step)) 144 | test_ids = [x+step//2 for x in train_ids] 145 | 146 | # add ids to config for later saving. 147 | config["experiment"]["train_ids"] = train_ids 148 | config["experiment"]["test_ids"] = test_ids 149 | 150 | replica_nyu_cnn_data_loader = replica_nyu_cnn_datasets.Replica_CNN_NYU(data_dir=config["experiment"]["dataset_dir"], 151 | train_ids=train_ids, test_ids=test_ids, 152 | img_h=config["experiment"]["height"], 153 | img_w=config["experiment"]["width"], 154 | nyu_mode=config["experiment"]["nyu_mode"], 155 | load_softmax=config["experiment"]["load_softmax"]) 156 | 157 | ssr_trainer.set_params_replica() # we still use the Replica parameters here since the images are rendered from Replica 158 | ssr_trainer.prepare_data_replica_nyu_cnn(replica_nyu_cnn_data_loader) 159 | 160 | elif args.dataset_type == "scannet": 161 | print("----- ScanNet Dataset with NYUv2-40 Conventions -----") 162 | 163 | print("processing ScanNet scene: ", os.path.basename(config["experiment"]["dataset_dir"])) 164 | # Todo: like NeRF, create spiral/test poses; make training and test poses/ids interleaved 165 | scannet_data_loader = scannet_datasets.ScanNet_Dataset(scene_dir=config["experiment"]["dataset_dir"], 166 | img_h=config["experiment"]["height"], 167 | img_w=config["experiment"]["width"], 168 | sample_step=config["experiment"]["sample_step"], 169 | save_dir=config["experiment"]["dataset_dir"]) 170 | 171 | 172 | print("--------------------") 173 | if args.super_resolution: 174 | print("Super Resolution Mode! 
Dense Label Flag is {}, SR Factor is {}".format(args.dense_sr, args.sr_factor)) 175 | scannet_data_loader.super_resolve_label(down_scale_factor=args.sr_factor, dense_supervision=args.dense_sr) 176 | 177 | elif args.label_propagation: 178 | print("Partial Segmentation Mode! Partial percentage is: {}".format(args.partial_perc)) 179 | scannet_data_loader.simulate_user_click_partial(perc=args.partial_perc, load_saved=args.load_saved, visualise_save=args.visualise_save) 180 | 181 | elif args.pixel_denoising: 182 | print("Pixel-Denoising Mode! Noise Ratio is ", args.pixel_noise_ratio) 183 | scannet_data_loader.add_pixel_wise_noise_label(sparse_views=args.sparse_views, 184 | sparse_ratio=args.sparse_ratio, 185 | random_sample=args.random_sample, 186 | noise_ratio=args.pixel_noise_ratio, 187 | visualise_save=args.visualise_save, 188 | load_saved=args.load_saved) 189 | elif args.sparse_views: 190 | print("Sparse Viewing Labels Mode! Sparse Ratio is ", args.sparse_ratio) 191 | scannet_data_loader.sample_label_maps(sparse_ratio=args.sparse_ratio, random_sample=args.random_sample, load_saved=args.load_saved) 192 | 193 | ssr_trainer.set_params_scannet(scannet_data_loader) 194 | ssr_trainer.prepare_data_scannet(scannet_data_loader) 195 | 196 | 197 | # Create NeRF model and initialise optimizer 198 | ssr_trainer.create_ssr() 199 | # Create rays in world coordinates 200 | ssr_trainer.init_rays() 201 | 202 | start = 0 203 | 204 | N_iters = int(float(config["train"]["N_iters"])) + 1 205 | global_step = start 206 | ########################## 207 | print('Begin') 208 | ##### Training loop ##### 209 | for i in trange(start, N_iters): 210 | 211 | time0 = time.time() 212 | ssr_trainer.step(global_step) 213 | 214 | dt = time.time()-time0 215 | print() 216 | print("Time per step is:", dt) 217 | global_step += 1 218 | 219 | 220 | print('done') 221 | 222 | 223 | if __name__=='__main__': 224 | train() -------------------------------------------------------------------------------- /SSR/datasets/scannet/scannet_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | from skimage.io import imread 5 | import cv2 6 | import imageio 7 | 8 | from SSR.datasets.scannet.scannet_utils import load_scannet_nyu40_mapping, load_scannet_nyu13_mapping 9 | from SSR.utils import image_utils 10 | class ScanNet_Dataset(object): 11 | def __init__(self, scene_dir, img_h=None, img_w=None, sample_step=1, save_dir=None, mode="nyu40"): 12 | # we only use rgb + poses from ScanNet 13 | self.img_h = img_h 14 | self.img_w = img_w 15 | 16 | self.scene_dir = scene_dir # scene_dir is the root directory of each sequence, i.e., xxx/ScanNet/scans/scene0088_00 17 | # scene_dir = "/home/shuaifeng/Documents/Datasets/ScanNet/scans/scene0088_00" 18 | scene_name = os.path.basename(scene_dir) 19 | data_dir = os.path.dirname(scene_dir) 20 | 21 | instance_filt_dir = os.path.join(scene_dir, scene_name+'_2d-instance-filt') 22 | label_filt_dir = os.path.join(scene_dir, scene_name+'_2d-label-filt') 23 | self.semantic_class_dir = label_filt_dir 24 | 25 | # (0 corresponds to unannotated or no depth). 
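# select the ScanNet->NYU label mapping together with the matching colour palette:
# "nyu40" keeps the 40-class NYU convention (41 colours including the void/unannotated class 0),
# while "nyu13" maps labels to the reduced 13-class NYU convention (14 colours including void).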
26 | if mode=="nyu40": 27 | label_mapping_nyu = load_scannet_nyu40_mapping(scene_dir) 28 | colour_map_np = image_utils.nyu40_colour_code 29 | assert colour_map_np.shape[0] == 41 30 | elif mode=="nyu13": 31 | label_mapping_nyu = load_scannet_nyu13_mapping(scene_dir) 32 | colour_map_np = image_utils.nyu13_colour_code 33 | assert colour_map_np.shape[0] == 14 34 | else: 35 | assert False 36 | 37 | # get camera intrinsics 38 | # we use color camera intrinsics and resize depth to match 39 | with open(os.path.join(scene_dir, "{}.txt".format(scene_name))) as info_f: 40 | info = [line.rstrip().split(' = ') for line in info_f] 41 | info = {key:value for key, value in info} 42 | intrinsics = [ 43 | [float(info['fx_color']), 0, float(info['mx_color'])], 44 | [0, float(info['fy_color']), float(info['my_color'])], 45 | [0, 0, 1]] 46 | 47 | original_colour_h = int(info["colorHeight"]) 48 | original_colour_w = int(info["colorWidth"]) 49 | original_depth_h = int(info["depthHeight"]) 50 | original_depth_w = int(info["depthWidth"]) 51 | assert original_colour_h==968 and original_colour_w==1296 and original_depth_h==480 and original_depth_w==640 52 | 53 | # load 2D colour frames and poses 54 | 55 | frame_ids = os.listdir(os.path.join(scene_dir, "renders", 'color')) 56 | frame_ids = [int(os.path.splitext(frame)[0]) for frame in frame_ids] 57 | frame_ids = sorted(frame_ids) 58 | 59 | frames_file_list = [] 60 | for i, frame_id in enumerate(frame_ids): 61 | if i%25==0: 62 | print('preparing %s frame %d/%d'%(scene_name, i, len(frame_ids))) 63 | 64 | pose = np.loadtxt(os.path.join(scene_dir, "renders", 'pose', '%d.txt' % frame_id)) 65 | 66 | # skip frames with no valid pose 67 | if not np.all(np.isfinite(pose)): 68 | continue 69 | 70 | frame = {'file_name_image': 71 | os.path.join(scene_dir, "renders", 'color', '%d.jpg'%frame_id), 72 | 'file_name_depth': 73 | os.path.join(scene_dir, "renders", 'depth', '%d.png'%frame_id), 74 | 'file_name_instance': 75 | os.path.join(instance_filt_dir, 'instance-filt', '%d.png'%frame_id), 76 | 'file_name_label': 77 | os.path.join(label_filt_dir, 'label-filt', '%d.png'%frame_id), 78 | 'intrinsics': intrinsics, 79 | 'pose': pose, 80 | } 81 | 82 | frames_file_list.append(frame) 83 | 84 | step = sample_step 85 | valid_data_num = len(frames_file_list) 86 | self.valid_data_num = valid_data_num 87 | total_ids = range(valid_data_num) 88 | train_ids = total_ids[::step] 89 | test_ids = [x+ (step//2) for x in train_ids] 90 | if test_ids[-1]>valid_data_num-1: 91 | test_ids.pop(-1) 92 | self.train_ids = train_ids 93 | self.train_num = len(train_ids) 94 | self.test_ids = test_ids 95 | self.test_num = len(test_ids) 96 | 97 | self.train_samples = {'image': [], 'depth': [], 98 | 'semantic_raw': [], # raw scannet label id 99 | 'semantic': [], # nyu40 id 100 | 'T_wc': [], 101 | 'instance': []} 102 | 103 | 104 | self.test_samples = {'image': [], 'depth': [], 105 | 'semantic_raw': [], 106 | 'semantic': [], 107 | 'T_wc': [], 108 | 'instance': []} 109 | 110 | # training samples 111 | for idx in train_ids: 112 | image = cv2.imread(frames_file_list[idx]["file_name_image"])[:,:,::-1] # change from BGR uinit 8 to RGB float 113 | image = cv2.copyMakeBorder(src=image, top=2, bottom=2, left=0, right=0, borderType=cv2.BORDER_CONSTANT, value=[0,0,0]) # pad 4 pixels to height so that images have aspect ratio of 4:3 114 | assert image.shape[0]/image.shape[1]==3/4 and image.shape[1]==original_colour_w and image.shape[0] == 972 115 | image = image/255.0 116 | 117 | depth = 
cv2.imread(frames_file_list[idx]["file_name_depth"], cv2.IMREAD_UNCHANGED) / 1000.0 # uint16 mm depth, then turn depth from mm to meter 118 | 119 | semantic = cv2.imread(frames_file_list[idx]["file_name_label"], cv2.IMREAD_UNCHANGED) 120 | semantic = cv2.copyMakeBorder(src=semantic, top=2, bottom=2, left=0, right=0, borderType=cv2.BORDER_CONSTANT, value=0) 121 | 122 | instance = cv2.imread(frames_file_list[idx]["file_name_instance"], cv2.IMREAD_UNCHANGED) 123 | instance = cv2.copyMakeBorder(src=instance, top=2, bottom=2, left=0, right=0, borderType=cv2.BORDER_CONSTANT, value=0) 124 | 125 | T_wc = frames_file_list[idx]["pose"].reshape((4, 4)) 126 | 127 | if (self.img_h is not None and self.img_h != image.shape[0]) or \ 128 | (self.img_w is not None and self.img_w != image.shape[1]): 129 | image = cv2.resize(image, (self.img_w, self.img_h), interpolation=cv2.INTER_LINEAR) 130 | depth = cv2.resize(depth, (self.img_w, self.img_h), interpolation=cv2.INTER_LINEAR) 131 | semantic = cv2.resize(semantic, (self.img_w, self.img_h), interpolation=cv2.INTER_NEAREST) 132 | instance = cv2.resize(instance, (self.img_w, self.img_h), interpolation=cv2.INTER_NEAREST) 133 | 134 | self.train_samples["image"].append(image) 135 | self.train_samples["depth"].append(depth) 136 | self.train_samples["semantic_raw"].append(semantic) 137 | self.train_samples["instance"].append(instance) 138 | self.train_samples["T_wc"].append(T_wc) 139 | 140 | 141 | # test samples 142 | for idx in test_ids: 143 | image = cv2.imread(frames_file_list[idx]["file_name_image"])[:,:,::-1] # change from BGR uinit 8 to RGB float 144 | image = cv2.copyMakeBorder(src=image, top=2, bottom=2, left=0, right=0, borderType=cv2.BORDER_CONSTANT, value=[0,0,0]) # pad 4 pixels to height so that images have aspect ratio of 4:3 145 | assert image.shape[0]/image.shape[1]==3/4 and image.shape[1]==original_colour_w and image.shape[0] == 972 146 | image = image/255.0 147 | 148 | depth = cv2.imread(frames_file_list[idx]["file_name_depth"], cv2.IMREAD_UNCHANGED) / 1000.0 # uint16 mm depth, then turn depth from mm to meter 149 | 150 | semantic = cv2.imread(frames_file_list[idx]["file_name_label"], cv2.IMREAD_UNCHANGED) 151 | semantic = cv2.copyMakeBorder(src=semantic, top=2, bottom=2, left=0, right=0, borderType=cv2.BORDER_CONSTANT, value=0) 152 | 153 | instance = cv2.imread(frames_file_list[idx]["file_name_instance"], cv2.IMREAD_UNCHANGED) 154 | instance = cv2.copyMakeBorder(src=instance, top=2, bottom=2, left=0, right=0, borderType=cv2.BORDER_CONSTANT, value=0) 155 | 156 | T_wc = frames_file_list[idx]["pose"].reshape((4, 4)) 157 | 158 | if (self.img_h is not None and self.img_h != image.shape[0]) or \ 159 | (self.img_w is not None and self.img_w != image.shape[1]): 160 | image = cv2.resize(image, (self.img_w, self.img_h), interpolation=cv2.INTER_LINEAR) 161 | depth = cv2.resize(depth, (self.img_w, self.img_h), interpolation=cv2.INTER_LINEAR) 162 | semantic = cv2.resize(semantic, (self.img_w, self.img_h), interpolation=cv2.INTER_NEAREST) 163 | instance = cv2.resize(instance, (self.img_w, self.img_h), interpolation=cv2.INTER_NEAREST) 164 | 165 | 166 | self.test_samples["image"].append(image) 167 | self.test_samples["depth"].append(depth) 168 | self.test_samples["semantic_raw"].append(semantic) 169 | self.test_samples["instance"].append(instance) 170 | self.test_samples["T_wc"].append(T_wc) 171 | 172 | 173 | scale_y = image.shape[0]/(original_colour_h+4) 174 | scale_x = image.shape[1]/original_colour_w 175 | assert scale_x == scale_y # this requires the desired 
shape to also have an aspect ratio of 4:3 176 | 177 | # we modify the camera intrinsics considering the padding and scaling 178 | self.intrinsics = np.asarray(intrinsics) 179 | self.intrinsics[1,2] += 2 # we shift c_y by 2 since we pad 2 pixels at the top of the image (4 in total) 180 | self.intrinsics[0, 0] = self.intrinsics[0, 0]*scale_x # fx 181 | self.intrinsics[1, 1] = self.intrinsics[1, 1]*scale_x # fy 182 | 183 | self.intrinsics[0, 2] = self.intrinsics[0, 2]*scale_x # cx 184 | self.intrinsics[1, 2] = self.intrinsics[1, 2]*scale_x # cy 185 | 186 | 187 | for key in self.test_samples.keys(): # transform list of np array to array with batch dimension 188 | self.train_samples[key] = np.asarray(self.train_samples[key]) 189 | self.test_samples[key] = np.asarray(self.test_samples[key]) 190 | 191 | # map ScanNet classes to the NYU definition 192 | train_semantic = self.train_samples["semantic_raw"] 193 | test_semantic = self.test_samples["semantic_raw"] 194 | 195 | train_semantic_nyu = train_semantic.copy() 196 | test_semantic_nyu = test_semantic.copy() 197 | 198 | for scan_id, nyu_id in label_mapping_nyu.items(): 199 | train_semantic_nyu[train_semantic==scan_id] = nyu_id 200 | test_semantic_nyu[test_semantic==scan_id] = nyu_id 201 | 202 | self.train_samples["semantic"] = train_semantic_nyu 203 | self.test_samples["semantic"] = test_semantic_nyu 204 | 205 | 206 | self.semantic_classes = np.unique( 207 | np.concatenate( 208 | (np.unique(self.train_samples["semantic"]), 209 | np.unique(self.test_samples["semantic"]))) 210 | ).astype(np.uint8) 211 | # each scene may not contain all 40 classes 212 | 213 | self.num_semantic_class = self.semantic_classes.shape[0] # number of semantic classes 214 | 215 | colour_map_np_remap = colour_map_np.copy()[self.semantic_classes] # take the corresponding colour map 216 | self.colour_map_np = colour_map_np 217 | self.colour_map_np_remap = colour_map_np_remap 218 | self.mask_ids = np.ones(self.train_num) # init self.mask_ids as full ones 219 | # 1 means the corresponding label map is used for semantic loss during training, while 0 means no semantic loss 220 | 221 | # save colourised ground truth label to img folder 222 | if save_dir is not None: 223 | # save colourised ground truth label to img folder 224 | vis_label_save_dir = os.path.join(save_dir, "vis-sampled-label-filt") 225 | os.makedirs(vis_label_save_dir, exist_ok=True) 226 | vis_train_label = colour_map_np[self.train_samples["semantic"]] 227 | vis_test_label = colour_map_np[self.test_samples["semantic"]] 228 | for i in range(self.train_num): 229 | label = vis_train_label[i].astype(np.uint8) 230 | cv2.imwrite(os.path.join(vis_label_save_dir, "train_vis_sem_{}.png".format(i)), label[...,::-1]) 231 | 232 | for i in range(self.test_num): 233 | label = vis_test_label[i].astype(np.uint8) 234 | cv2.imwrite(os.path.join(vis_label_save_dir, "test_vis_sem_{}.png".format(i)), label[...,::-1]) 235 | 236 | 237 | # remap existing semantic class labels to continuous labels ranging from 0 to num_class-1 238 | self.train_samples["semantic_clean"] = self.train_samples["semantic"].copy() 239 | self.train_samples["semantic_remap"] = self.train_samples["semantic"].copy() 240 | self.train_samples["semantic_remap_clean"] = self.train_samples["semantic_clean"].copy() 241 | 242 | self.test_samples["semantic_remap"] = self.test_samples["semantic"].copy() 243 | 244 | for i in range(self.num_semantic_class): 245 | self.train_samples["semantic_remap"][self.train_samples["semantic"]== self.semantic_classes[i]] = i 246 | 
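# e.g. if semantic_classes == [0, 1, 4, 7], then NYU label 7 is remapped to the contiguous class index 3;
# the same remapping is applied to the clean copy and to the test split in the lines below.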
self.train_samples["semantic_remap_clean"][self.train_samples["semantic_clean"]== self.semantic_classes[i]] = i 247 | self.test_samples["semantic_remap"][self.test_samples["semantic"]== self.semantic_classes[i]] = i 248 | 249 | 250 | self.train_samples["semantic_remap"] = self.train_samples["semantic_remap"].astype(np.uint8) 251 | self.train_samples["semantic_remap_clean"] = self.train_samples["semantic_remap_clean"].astype(np.uint8) 252 | self.test_samples["semantic_remap"] = self.test_samples["semantic_remap"].astype(np.uint8) 253 | 254 | print() 255 | print("Training Sample Summary:") 256 | for key in self.train_samples.keys(): 257 | print("{} has shape of {}, type {}.".format(key, self.train_samples[key].shape, self.train_samples[key].dtype)) 258 | print() 259 | print("Testing Sample Summary:") 260 | for key in self.test_samples.keys(): 261 | print("{} has shape of {}, type {}.".format(key, self.test_samples[key].shape, self.test_samples[key].dtype)) 262 | 263 | 264 | def sample_label_maps(self, sparse_ratio=0.5, random_sample=False, load_saved=False): 265 | if load_saved is False: 266 | K = int(self.train_num*sparse_ratio) # number of skipped training frames, mask=0 267 | N = self.train_num-K # number of used training frames, mask=1 268 | assert np.sum(self.mask_ids) == self.train_num # sanity check that all masks are avaible before sampling 269 | 270 | if K==0: # incase sparse_ratio==0: 271 | return 272 | 273 | if random_sample: 274 | self.mask_ids[:K] = 0 275 | np.random.shuffle(self.mask_ids) 276 | else: # sample evenly 277 | if sparse_ratio<=0.5: # skip less/equal than half frames 278 | assert K <= self.train_num/2 279 | q, r = divmod(self.train_num, K) 280 | indices = [q*i + min(i, r) for i in range(K)] 281 | self.mask_ids[indices] = 0 282 | 283 | else: # skip more than half frames 284 | assert K > self.train_num/2 285 | self.mask_ids = np.zeros_like(self.mask_ids) # disable all images and evenly enable N images in total 286 | q, r = divmod(self.train_num, N) 287 | indices = [q*i + min(i, r) for i in range(N)] 288 | self.mask_ids[indices] = 1 289 | print("{} of {} semantic labels are sampled (sparse ratio: {}).".format(sum(self.mask_ids), len(self.mask_ids), sparse_ratio)) 290 | noisy_sem_dir = os.path.join(self.scene_dir, "renders", "noisy_pixel_sems_sr{}".format(sparse_ratio)) 291 | if not os.path.exists(noisy_sem_dir): 292 | os.makedirs(noisy_sem_dir) 293 | with open(os.path.join(noisy_sem_dir, "mask_ids.npy"), 'wb') as f: 294 | np.save(f, self.mask_ids) 295 | elif load_saved is True: 296 | noisy_sem_dir = os.path.join(self.scene_dir, "renders", "noisy_pixel_sems_sr{}".format(sparse_ratio)) 297 | self.mask_ids = np.load(os.path.join(noisy_sem_dir, "mask_ids.npy")) 298 | 299 | 300 | def add_pixel_wise_noise_label(self, 301 | sparse_views=False, sparse_ratio=0.5, random_sample=False, 302 | noise_ratio=0.3, visualise_save=False, load_saved=False): 303 | if not load_saved: 304 | if sparse_views: 305 | self.sample_label_maps(sparse_ratio=sparse_ratio, random_sample=random_sample) 306 | num_pixel = self.img_h * self.img_w 307 | num_pixel_noisy = int(num_pixel*noise_ratio) 308 | train_sem = self.train_samples["semantic_remap"] 309 | 310 | for i in range(len(self.mask_ids)): 311 | if self.mask_ids[i] == 1: # add label noise to unmasked/available labels 312 | noisy_index_1d = np.random.permutation(num_pixel)[:num_pixel_noisy] 313 | faltten_sem = train_sem[i].flatten() 314 | faltten_sem[noisy_index_1d] = np.random.choice(self.num_semantic_class, num_pixel_noisy) 315 | # we replace the 
label of randomly selected num_pixel_noisy pixels to random labels from [1, self.num_semantic_class], 0 class is the none class 316 | train_sem[i] = faltten_sem.reshape(self.img_h, self.img_w) 317 | 318 | print("{} of {} semantic labels are added noise {} percent area ratio.".format(sum(self.mask_ids), len(self.mask_ids), noise_ratio)) 319 | 320 | if visualise_save: 321 | noisy_sem_dir = os.path.join(self.scene_dir, "renders", "noisy_pixel_sems_sr{}_nr{}".format(sparse_ratio, noise_ratio)) 322 | if not os.path.exists(noisy_sem_dir): 323 | os.makedirs(noisy_sem_dir) 324 | with open(os.path.join(noisy_sem_dir, "mask_ids.npy"), 'wb') as f: 325 | np.save(f, self.mask_ids) 326 | 327 | 328 | vis_noisy_semantic_list = [] 329 | vis_semantic_clean_list = [] 330 | 331 | colour_map_np = self.colour_map_np_remap 332 | 333 | semantic_remap = self.train_samples["semantic_remap"] # [H, W, 3] 334 | semantic_remap_clean = self.train_samples["semantic_remap_clean"] # [H, W, 3] 335 | 336 | for i in range(len(self.mask_ids)): 337 | if self.mask_ids[i] == 1: # add label noise to unmasked/available labels 338 | vis_noisy_semantic = colour_map_np[semantic_remap[i]] # [H, W, 3] 339 | vis_semantic_clean = colour_map_np[semantic_remap_clean[i]] # [H, W, 3] 340 | 341 | imageio.imwrite(os.path.join(noisy_sem_dir, "semantic_class_{}.png".format(i)), semantic_remap[i]) 342 | imageio.imwrite(os.path.join(noisy_sem_dir, "vis_sem_class_{}.png".format(i)), vis_noisy_semantic) 343 | 344 | vis_noisy_semantic_list.append(vis_noisy_semantic) 345 | vis_semantic_clean_list.append(vis_semantic_clean) 346 | else: 347 | # for mask_ids of 0, we skip these frames during training and do not add noise 348 | vis_noisy_semantic = colour_map_np[semantic_remap[i]] # [H, W, 3] 349 | vis_semantic_clean = colour_map_np[semantic_remap_clean[i]] # [H, W, 3] 350 | assert np.all(vis_noisy_semantic==vis_semantic_clean) 351 | 352 | imageio.imwrite(os.path.join(noisy_sem_dir, "semantic_class_{}.png".format(i)), semantic_remap[i]) 353 | imageio.imwrite(os.path.join(noisy_sem_dir, "vis_sem_class_{}.png".format(i)), vis_noisy_semantic) 354 | 355 | vis_noisy_semantic_list.append(vis_noisy_semantic) 356 | vis_semantic_clean_list.append(vis_semantic_clean) 357 | 358 | imageio.mimwrite(os.path.join(noisy_sem_dir, 'noisy_sem_ratio_{}.mp4'.format(noise_ratio)), 359 | np.stack(vis_noisy_semantic_list, 0), fps=30, quality=8) 360 | 361 | imageio.mimwrite(os.path.join(noisy_sem_dir, 'clean_sem.mp4'), 362 | np.stack(vis_semantic_clean_list, 0), fps=30, quality=8) 363 | else: 364 | print("Load saved noisy labels.") 365 | noisy_sem_dir = os.path.join(self.scene_dir, "renders", "noisy_pixel_sems_sr{}_nr{}".format(sparse_ratio, noise_ratio)) 366 | self.mask_ids = np.load(os.path.join(noisy_sem_dir, "mask_ids.npy")) 367 | semantic_img_list = [] 368 | semantic_path_list = sorted(glob.glob(noisy_sem_dir + '/semantic_class_*.png'), key=lambda file_name: int(file_name.split("_")[-1][:-4])) 369 | assert len(semantic_path_list)>0 370 | for idx in range(len(self.mask_ids)): 371 | semantic = imread(semantic_path_list[idx]) 372 | semantic_img_list.append(semantic) 373 | self.train_samples["semantic_remap"] = np.asarray(semantic_img_list) 374 | 375 | 376 | def super_resolve_label(self, down_scale_factor=8, dense_supervision=True): 377 | if down_scale_factor==1: 378 | return 379 | if dense_supervision: # train down-scale and up-scale again 380 | scaled_low_res_train_label = [] 381 | for i in range(self.train_num): 382 | low_res_label = 
cv2.resize(self.train_samples["semantic_remap"][i], 383 | (self.img_w//down_scale_factor, self.img_h//down_scale_factor), 384 | interpolation=cv2.INTER_NEAREST) 385 | 386 | scaled_low_res_label = cv2.resize(low_res_label, (self.img_w, self.img_h), interpolation=cv2.INTER_NEAREST) 387 | scaled_low_res_train_label.append(scaled_low_res_label) 388 | 389 | scaled_low_res_train_label = np.asarray(scaled_low_res_train_label) 390 | 391 | self.train_samples["semantic_remap"] = scaled_low_res_train_label 392 | 393 | else: # we only penalise strictly on valid pixel positions 394 | valid_low_res_pixel_mask = np.zeros((self.img_h, self.img_w)) 395 | valid_low_res_pixel_mask[::down_scale_factor, ::down_scale_factor]=1 396 | self.train_samples["semantic_remap"] = (self.train_samples["semantic_remap"]*valid_low_res_pixel_mask[None,...]).astype(np.uint8) 397 | # we mask all the decimated pixel label to void class==0 398 | 399 | 400 | 401 | def simulate_user_click_partial(self, perc=0, load_saved=False, visualise_save=True): 402 | assert perc<=100 and perc >= 0 403 | assert self.train_num == self.train_samples["semantic_remap"].shape[0] 404 | single_click=True if perc==0 else False # single_click: whether to use single click only from each class 405 | perc = perc/100.0 # make perc 406 | if not load_saved: 407 | 408 | if single_click: 409 | click_semantic_map = [] 410 | for i in range(self.train_num): 411 | if (i+1)%5==10: 412 | print("Generating partial label of ratio {} for frame {}/{}.".format(perc, i, self.train_num)) 413 | im = self.train_samples["semantic_remap"][i] 414 | void_class = [0] 415 | label_class = np.unique(im).tolist() 416 | valid_class = [i for i in label_class if i not in void_class] 417 | im_ = np.zeros_like(im) 418 | for l in valid_class: 419 | label_idx = np.transpose(np.nonzero(im == l)) 420 | sample_ind = np.random.choice(label_idx.shape[0], 1, replace=False) 421 | label_idx_ = label_idx[sample_ind] 422 | im_[label_idx_[:, 0], label_idx_[:, 1]] = l 423 | click_semantic_map.append(im_) 424 | click_semantic_map = np.asarray(click_semantic_map).astype(np.uint8) 425 | self.train_samples["semantic_remap"] = click_semantic_map 426 | 427 | print('Partial Label images with centroid sampling (extreme) has completed.') 428 | 429 | elif perc>0 and not single_click: 430 | click_semantic_map = [] 431 | for i in range(self.train_num): 432 | if (i+1)%5==10: 433 | print("Generating partial label of ratio {} for frame {}/{}.".format(perc, i, self.train_num)) 434 | im = self.train_samples["semantic_remap"][i] 435 | void_class = [0] 436 | label_class = np.unique(im).tolist() # find the unique class-ids in the current training label 437 | valid_class = [c for c in label_class if c not in void_class] 438 | 439 | im_ = np.zeros_like(im) 440 | for l in valid_class: 441 | label_mask = np.zeros_like(im) 442 | label_mask_ = im == l # binary mask of pixels equal to class-l 443 | label_idx = np.transpose(np.nonzero(label_mask_)) # Nx2 444 | sample_ind = np.random.choice(label_idx.shape[0], 1, replace=False) # shape [1,] 445 | label_idx_ = label_idx[sample_ind] # shape [1, 2] 446 | target_num = int(perc * label_mask_.sum()) # find the target and total number of pixels belong to class-l in the current image 447 | label_mask[label_idx_[0, 0], label_idx_[0, 1]] = 1 # full-zero mask with only selected pixel to be 1 448 | label_mask_true = label_mask 449 | # label_mask_true initially has only 1 True pixel, we continuously grow mask until reach expected percentage 450 | 451 | while label_mask_true.sum() < target_num: 
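# region growing: starting from the single seed pixel, repeatedly dilate the mask with a 5x5 kernel
# and intersect it with the class mask until the target pixel count is reached; if growth stalls
# (the local blob is fully covered), a fresh seed is drawn below from the still-uncovered pixels.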
452 | num_before_grow = label_mask_true.sum() 453 | label_mask = cv2.dilate(label_mask, kernel=np.ones([5, 5])) 454 | label_mask_true = label_mask * label_mask_ 455 | num_after_grow = label_mask_true.sum() 456 | # print("Before growth: {}, After growth: {}".format(num_before_grow, num_after_grow)) 457 | if num_after_grow==num_before_grow: 458 | print("Initialise Another Seed for Growing!") 459 | # the region does not grow means the very local has been filled, 460 | # and we need to initiate another seed to keep growing 461 | uncovered_region_mask = label_mask_ - label_mask_true # pixel equal to 1 means un-sampled regions belong to current class 462 | label_idx = np.transpose(np.nonzero(uncovered_region_mask)) # Nx2 463 | sample_ind = np.random.choice(label_idx.shape[0], 1, replace=False) # shape [1,] 464 | label_idx_ = label_idx[sample_ind] # shape [1, 2] 465 | label_mask[label_idx_[0, 0], label_idx_[0, 1]] = 1 466 | 467 | im_[label_mask_true.astype(bool)] = l 468 | click_semantic_map.append(im_) 469 | 470 | click_semantic_map = np.asarray(click_semantic_map).astype(np.uint8) 471 | self.train_samples["semantic_remap"] = click_semantic_map 472 | print('Partial Label images with centroid sampling has completed.') 473 | else: 474 | assert False 475 | 476 | if visualise_save: 477 | partial_sem_dir = os.path.join(self.semantic_class_dir, "partial_perc_{}".format(perc)) 478 | if not os.path.exists(partial_sem_dir): 479 | os.makedirs(partial_sem_dir) 480 | colour_map_np = self.colour_map_np_remap 481 | vis_partial_sem = [] 482 | for i in range(self.train_num): 483 | vis_partial_semantic = colour_map_np[self.train_samples["semantic_remap"][i]] # [H, W, 3] 484 | imageio.imwrite(os.path.join(partial_sem_dir, "semantic_class_{}.png".format(i)), self.train_samples["semantic_remap"][i]) 485 | imageio.imwrite(os.path.join(partial_sem_dir, "vis_sem_class_{}.png".format(i)), vis_partial_semantic) 486 | vis_partial_sem.append(vis_partial_semantic) 487 | 488 | imageio.mimwrite(os.path.join(partial_sem_dir, 'partial_sem.mp4'), self.train_samples["semantic_remap"], fps=30, quality=8) 489 | imageio.mimwrite(os.path.join(partial_sem_dir, 'vis_partial_sem.mp4'), np.stack(vis_partial_sem, 0), fps=30, quality=8) 490 | 491 | else: # load saved single-click/partial semantics 492 | saved_partial_sem_dir = os.path.join(self.semantic_class_dir, "partial_perc_{}".format(perc)) 493 | semantic_img_list = [] 494 | semantic_path_list = sorted(glob.glob(saved_partial_sem_dir + '/semantic_class_*.png'), key=lambda file_name: int(file_name.split("_")[-1][:-4])) 495 | assert len(semantic_path_list)>0 496 | for idx in range(self.train_num): 497 | semantic = imread(semantic_path_list[idx]) 498 | semantic_img_list.append(semantic) 499 | self.train_samples["semantic_remap"] = np.asarray(semantic_img_list).astype(np.uint8) 500 | -------------------------------------------------------------------------------- /SSR/datasets/replica/replica_datasets.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import glob 3 | import numpy as np 4 | from skimage.io import imread 5 | from torch.utils.data import Dataset 6 | import cv2 7 | import imageio 8 | from imgviz import label_colormap 9 | 10 | class ReplicaDatasetCache(Dataset): 11 | def __init__(self, data_dir, train_ids, test_ids, img_h=None, img_w=None): 12 | 13 | traj_file = os.path.join(data_dir, "traj_w_c.txt") 14 | self.rgb_dir = os.path.join(data_dir, "rgb") 15 | self.depth_dir = os.path.join(data_dir, "depth") # depth is in mm 
uint 16 | self.semantic_class_dir = os.path.join(data_dir, "semantic_class") 17 | self.semantic_instance_dir = os.path.join(data_dir, "semantic_instance") 18 | if not os.path.exists(self.semantic_instance_dir): 19 | self.semantic_instance_dir = None 20 | 21 | 22 | self.train_ids = train_ids 23 | self.train_num = len(train_ids) 24 | self.test_ids = test_ids 25 | self.test_num = len(test_ids) 26 | 27 | self.img_h = img_h 28 | self.img_w = img_w 29 | 30 | self.Ts_full = np.loadtxt(traj_file, delimiter=" ").reshape(-1, 4, 4) 31 | 32 | self.rgb_list = sorted(glob.glob(self.rgb_dir + '/rgb*.png'), key=lambda file_name: int(file_name.split("_")[-1][:-4])) 33 | self.depth_list = sorted(glob.glob(self.depth_dir + '/depth*.png'), key=lambda file_name: int(file_name.split("_")[-1][:-4])) 34 | self.semantic_list = sorted(glob.glob(self.semantic_class_dir + '/semantic_class_*.png'), key=lambda file_name: int(file_name.split("_")[-1][:-4])) 35 | if self.semantic_instance_dir is not None: 36 | self.instance_list = sorted(glob.glob(self.semantic_instance_dir + '/semantic_instance_*.png'), key=lambda file_name: int(file_name.split("_")[-1][:-4])) 37 | 38 | self.train_samples = {'image': [], 'depth': [], 39 | 'semantic': [], 'T_wc': [], 40 | 'instance': []} 41 | 42 | self.test_samples = {'image': [], 'depth': [], 43 | 'semantic': [], 'T_wc': [], 44 | 'instance': []} 45 | 46 | # training samples 47 | for idx in train_ids: 48 | image = cv2.imread(self.rgb_list[idx])[:,:,::-1] / 255.0 # change from BGR uinit 8 to RGB float 49 | depth = cv2.imread(self.depth_list[idx], cv2.IMREAD_UNCHANGED) / 1000.0 # uint16 mm depth, then turn depth from mm to meter 50 | semantic = cv2.imread(self.semantic_list[idx], cv2.IMREAD_UNCHANGED) 51 | if self.semantic_instance_dir is not None: 52 | instance = cv2.imread(self.instance_list[idx], cv2.IMREAD_UNCHANGED) # uint16 53 | 54 | if (self.img_h is not None and self.img_h != image.shape[0]) or \ 55 | (self.img_w is not None and self.img_w != image.shape[1]): 56 | image = cv2.resize(image, (self.img_w, self.img_h), interpolation=cv2.INTER_LINEAR) 57 | depth = cv2.resize(depth, (self.img_w, self.img_h), interpolation=cv2.INTER_LINEAR) 58 | semantic = cv2.resize(semantic, (self.img_w, self.img_h), interpolation=cv2.INTER_NEAREST) 59 | if self.semantic_instance_dir is not None: 60 | instance = cv2.resize(instance, (self.img_w, self.img_h), interpolation=cv2.INTER_NEAREST) 61 | 62 | T_wc = self.Ts_full[idx] 63 | 64 | self.train_samples["image"].append(image) 65 | self.train_samples["depth"].append(depth) 66 | self.train_samples["semantic"].append(semantic) 67 | if self.semantic_instance_dir is not None: 68 | self.train_samples["instance"].append(instance) 69 | self.train_samples["T_wc"].append(T_wc) 70 | 71 | 72 | # test samples 73 | for idx in test_ids: 74 | image = cv2.imread(self.rgb_list[idx])[:,:,::-1] / 255.0 # change from BGR uinit 8 to RGB float 75 | depth = cv2.imread(self.depth_list[idx], cv2.IMREAD_UNCHANGED) / 1000.0 # uint16 mm depth, then turn depth from mm to meter 76 | semantic = cv2.imread(self.semantic_list[idx], cv2.IMREAD_UNCHANGED) 77 | if self.semantic_instance_dir is not None: 78 | instance = cv2.imread(self.instance_list[idx], cv2.IMREAD_UNCHANGED) # uint16 79 | 80 | if (self.img_h is not None and self.img_h != image.shape[0]) or \ 81 | (self.img_w is not None and self.img_w != image.shape[1]): 82 | image = cv2.resize(image, (self.img_w, self.img_h), interpolation=cv2.INTER_LINEAR) 83 | depth = cv2.resize(depth, (self.img_w, self.img_h), 
interpolation=cv2.INTER_LINEAR) 84 | semantic = cv2.resize(semantic, (self.img_w, self.img_h), interpolation=cv2.INTER_NEAREST) 85 | if self.semantic_instance_dir is not None: 86 | instance = cv2.resize(instance, (self.img_w, self.img_h), interpolation=cv2.INTER_NEAREST) 87 | T_wc = self.Ts_full[idx] 88 | 89 | self.test_samples["image"].append(image) 90 | self.test_samples["depth"].append(depth) 91 | self.test_samples["semantic"].append(semantic) 92 | if self.semantic_instance_dir is not None: 93 | self.test_samples["instance"].append(instance) 94 | self.test_samples["T_wc"].append(T_wc) 95 | 96 | for key in self.test_samples.keys(): # transform list of np array to array with batch dimension 97 | self.train_samples[key] = np.asarray(self.train_samples[key]) 98 | self.test_samples[key] = np.asarray(self.test_samples[key]) 99 | 100 | self.semantic_classes = np.unique( 101 | np.concatenate( 102 | (np.unique(self.train_samples["semantic"]), 103 | np.unique(self.test_samples["semantic"])))).astype(np.uint8) 104 | self.num_semantic_class = self.semantic_classes.shape[0] # number of semantic classes, including the void class of 0 105 | 106 | self.colour_map_np = label_colormap()[self.semantic_classes] 107 | self.mask_ids = np.ones(self.train_num) # init self.mask_ids as full ones 108 | # 1 means the correspinding label map is used for semantic loss during training, while 0 means no semantic loss is applied on this frame 109 | 110 | # remap existing semantic class labels to continuous label ranging from 0 to num_class-1 111 | self.train_samples["semantic_clean"] = self.train_samples["semantic"].copy() 112 | self.train_samples["semantic_remap"] = self.train_samples["semantic"].copy() 113 | self.train_samples["semantic_remap_clean"] = self.train_samples["semantic_clean"].copy() 114 | 115 | self.test_samples["semantic_remap"] = self.test_samples["semantic"].copy() 116 | 117 | for i in range(self.num_semantic_class): 118 | self.train_samples["semantic_remap"][self.train_samples["semantic"]== self.semantic_classes[i]] = i 119 | self.train_samples["semantic_remap_clean"][self.train_samples["semantic_clean"]== self.semantic_classes[i]] = i 120 | self.test_samples["semantic_remap"][self.test_samples["semantic"]== self.semantic_classes[i]] = i 121 | 122 | 123 | print() 124 | print("Training Sample Summary:") 125 | for key in self.train_samples.keys(): 126 | print("{} has shape of {}, type {}.".format(key, self.train_samples[key].shape, self.train_samples[key].dtype)) 127 | print() 128 | print("Testing Sample Summary:") 129 | for key in self.test_samples.keys(): 130 | print("{} has shape of {}, type {}.".format(key, self.test_samples[key].shape, self.test_samples[key].dtype)) 131 | 132 | 133 | def sample_label_maps(self, sparse_ratio=0.5, K=0, random_sample=False, load_saved=False): 134 | """ 135 | sparse_ratio means the ratio of removed training images, e.g., 0.3 means 30% of semantic labels are removed 136 | Input: 137 | sparse_ratio: the percentage of semantic label frames to be *removed* 138 | K: the number of frames to be removed, if this is speficied we override the results computed from sparse_ratio 139 | random_sample: whether to random sample frames or interleavely/evenly sample, True--random sample; False--interleavely sample 140 | load_saved: use pre-computed mask_ids from previous experiments 141 | """ 142 | if load_saved is False: 143 | if K==0: 144 | K = int(self.train_num*sparse_ratio) # number of skipped training frames, mask=0 145 | 146 | N = self.train_num-K # number of used training 

    def sample_label_maps(self, sparse_ratio=0.5, K=0, random_sample=False, load_saved=False):
        """
        Sub-sample the dense training label maps.
        Input:
            sparse_ratio: the fraction of semantic label frames to be *removed*, e.g., 0.3 means 30% of the label maps are dropped
            K: the number of frames to be removed; if non-zero, it overrides the value computed from sparse_ratio
            random_sample: True -- drop frames at random; False -- drop them evenly spaced over the sequence
            load_saved: reuse pre-computed mask_ids saved by a previous experiment
        """
        if load_saved is False:
            if K == 0:
                K = int(self.train_num * sparse_ratio)  # number of skipped training frames, mask=0

            N = self.train_num - K  # number of used training frames, mask=1
            assert np.sum(self.mask_ids) == self.train_num  # sanity check that all masks are available before sampling

            if K == 0:  # in case sparse_ratio == 0
                return

            if random_sample:
                self.mask_ids[:K] = 0
                np.random.shuffle(self.mask_ids)
            else:  # evenly spaced sampling
                if sparse_ratio <= 0.5:  # skip less than or equal to half of the frames
                    assert K <= self.train_num / 2
                    q, r = divmod(self.train_num, K)
                    indices = [q * i + min(i, r) for i in range(K)]
                    self.mask_ids[indices] = 0
                else:  # skip more than half of the frames
                    assert K > self.train_num / 2
                    self.mask_ids = np.zeros_like(self.mask_ids)  # disable all frames, then evenly enable N frames in total
                    q, r = divmod(self.train_num, N)
                    indices = [q * i + min(i, r) for i in range(N)]
                    self.mask_ids[indices] = 1

            print("{} of {} semantic labels are sampled (sparse ratio: {}).".format(sum(self.mask_ids), len(self.mask_ids), sparse_ratio))
            noisy_sem_dir = os.path.join(self.semantic_class_dir, "noisy_pixel_sems_sr{}".format(sparse_ratio))
            if not os.path.exists(noisy_sem_dir):
                os.makedirs(noisy_sem_dir)
            with open(os.path.join(noisy_sem_dir, "mask_ids.npy"), 'wb') as f:
                np.save(f, self.mask_ids)
        elif load_saved is True:
            noisy_sem_dir = os.path.join(self.semantic_class_dir, "noisy_pixel_sems_sr{}".format(sparse_ratio))
            self.mask_ids = np.load(os.path.join(noisy_sem_dir, "mask_ids.npy"))

    def sample_specific_labels(self, frame_ids, train_ids):
        """
        Only use dense label maps for specific/selected frames.
        """
        assert np.sum(self.mask_ids) == self.train_num  # sanity check that all masks are available before sampling

        self.mask_ids = np.zeros_like(self.mask_ids)

        if len(frame_ids) == 1 and frame_ids[0] is None:
            # no semantic supervision is added at all
            return

        relative_ids = [train_ids.index(x) for x in frame_ids]

        self.mask_ids[relative_ids] = 1
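
    # Worked example (added; illustrative only) of the evenly spaced sampling used above:
    # with train_num = 10 and sparse_ratio = 0.4, K = 4 frames are dropped, so
    # q, r = divmod(10, 4) = (2, 2) and the dropped indices are
    #   [q * i + min(i, r) for i in range(4)] = [0, 3, 6, 8],
    # i.e. roughly every train_num/K-th frame loses its semantic label, while the
    # remaining 6 frames keep dense supervision (their mask_ids entries stay 1).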

    def add_pixel_wise_noise_label(self,
                                   sparse_views=False, sparse_ratio=0.0, random_sample=False,
                                   noise_ratio=0.0, visualise_save=False, load_saved=False):
        """
        Input:
            sparse_views: whether to sample a subset of the dense semantic labels for training
            sparse_ratio: the ratio of frames to be removed/skipped if sampling a subset of labels
            random_sample: True -- drop frames at random; False -- drop them evenly spaced over the sequence
            noise_ratio: the ratio of pixels per frame to be randomly perturbed
            visualise_save: whether to save the noisy labels to disk for later use
            load_saved: load previously saved noisy labels to keep experiments consistent
        """
        if not load_saved:
            if sparse_views:
                self.sample_label_maps(sparse_ratio=sparse_ratio, random_sample=random_sample)
            num_pixel = self.img_h * self.img_w
            num_pixel_noisy = int(num_pixel * noise_ratio)
            train_sem = self.train_samples["semantic_remap"]

            for i in range(len(self.mask_ids)):
                if self.mask_ids[i] == 1:  # only add label noise to unmasked/available labels
                    noisy_index_1d = np.random.permutation(num_pixel)[:num_pixel_noisy]
                    flat_sem = train_sem[i].flatten()
                    flat_sem[noisy_index_1d] = np.random.choice(self.num_semantic_class, num_pixel_noisy)
                    # the randomly selected pixels are overwritten with class ids drawn uniformly
                    # from [0, num_semantic_class-1]; note that this range includes the void class 0
                    train_sem[i] = flat_sem.reshape(self.img_h, self.img_w)

            print("{} of {} semantic labels are perturbed with pixel-wise noise (noise ratio: {}).".format(sum(self.mask_ids), len(self.mask_ids), noise_ratio))

            if visualise_save:
                noisy_sem_dir = os.path.join(self.semantic_class_dir, "noisy_pixel_sems_sr{}_nr{}".format(sparse_ratio, noise_ratio))
                if not os.path.exists(noisy_sem_dir):
                    os.makedirs(noisy_sem_dir)
                with open(os.path.join(noisy_sem_dir, "mask_ids.npy"), 'wb') as f:
                    np.save(f, self.mask_ids)

                vis_noisy_semantic_list = []
                vis_semantic_clean_list = []

                colour_map_np = self.colour_map_np
                # Replica has 101 classes in total; colour_map_np keeps only the colours of the classes present in this scene

                semantic_remap = self.train_samples["semantic_remap"]  # [N, H, W]
                semantic_remap_clean = self.train_samples["semantic_remap_clean"]  # [N, H, W]

                # save semantic labels
                for i in range(len(self.mask_ids)):
                    if self.mask_ids[i] == 1:
                        vis_noisy_semantic = colour_map_np[semantic_remap[i]]  # [H, W, 3]
                        vis_semantic_clean = colour_map_np[semantic_remap_clean[i]]  # [H, W, 3]

                        imageio.imwrite(os.path.join(noisy_sem_dir, "semantic_class_{}.png".format(i)), semantic_remap[i])
                        imageio.imwrite(os.path.join(noisy_sem_dir, "vis_sem_class_{}.png".format(i)), vis_noisy_semantic)

                        vis_noisy_semantic_list.append(vis_noisy_semantic)
                        vis_semantic_clean_list.append(vis_semantic_clean)
                    else:
                        # frames with mask_ids == 0 are skipped during training and receive no noise
                        vis_noisy_semantic = colour_map_np[semantic_remap[i]]  # [H, W, 3]
                        vis_semantic_clean = colour_map_np[semantic_remap_clean[i]]  # [H, W, 3]
                        assert np.all(vis_noisy_semantic == vis_semantic_clean)  # skipped frames must be unchanged

                        imageio.imwrite(os.path.join(noisy_sem_dir, "semantic_class_{}.png".format(i)), semantic_remap[i])
                        imageio.imwrite(os.path.join(noisy_sem_dir, "vis_sem_class_{}.png".format(i)), vis_noisy_semantic)

                        vis_noisy_semantic_list.append(vis_noisy_semantic)
                        vis_semantic_clean_list.append(vis_semantic_clean)

                imageio.mimwrite(os.path.join(noisy_sem_dir, 'noisy_sem_ratio_{}.mp4'.format(noise_ratio)),
                                 np.stack(vis_noisy_semantic_list, 0), fps=30, quality=8)

                imageio.mimwrite(os.path.join(noisy_sem_dir, 'clean_sem.mp4'),
                                 np.stack(vis_semantic_clean_list, 0), fps=30, quality=8)
        else:
            print("Load saved noisy labels.")
            noisy_sem_dir = os.path.join(self.semantic_class_dir, "noisy_pixel_sems_sr{}_nr{}".format(sparse_ratio, noise_ratio))
            assert os.path.exists(noisy_sem_dir)
            self.mask_ids = np.load(os.path.join(noisy_sem_dir, "mask_ids.npy"))
            semantic_img_list = []
            semantic_path_list = sorted(glob.glob(noisy_sem_dir + '/semantic_class_*.png'), key=lambda file_name: int(file_name.split("_")[-1][:-4]))
            assert len(semantic_path_list) > 0
            for idx in range(len(self.mask_ids)):
                semantic = imread(semantic_path_list[idx])
                semantic_img_list.append(semantic)
            self.train_samples["semantic_remap"] = np.asarray(semantic_img_list)
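
    # Illustrative arithmetic (added; not part of the original code): for a 640x480 frame,
    # num_pixel = 307200, so noise_ratio = 0.5 perturbs int(307200 * 0.5) = 153600 randomly
    # chosen pixels in each unmasked frame. A hypothetical call such as
    #
    #   dataset.add_pixel_wise_noise_label(sparse_views=True, sparse_ratio=0.5,
    #                                      noise_ratio=0.5, visualise_save=True)
    #
    # would first drop half of the label maps via sample_label_maps() and then flip half of
    # the pixels in each remaining label map to random class ids.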

    def add_instance_wise_noise_label(self, sparse_views=False, sparse_ratio=0.0, random_sample=False,
                                      flip_ratio=0.0, uniform_flip=False,
                                      instance_id=[3, 6, 7, 9, 11, 12, 13, 48],
                                      load_saved=False,
                                      visualise_save=False):
        """
        This function tests whether Semantic-NeRF can correct wrong instance labels after fusion (training).
        For the selected instances, we randomly pick a portion of the frames observing them and flip their class labels.
        (A worked sketch of the frame selection follows this method.)
        Input:
            sparse_views: whether to use a sampled subset of the original training set.
            sparse_ratio: the ratio of frames to be dropped.
            random_sample: True -- drop frames at random; False -- drop them evenly spaced over the sequence
            flip_ratio: among all the frames containing a given instance, the ratio of frames whose labels are flipped
            uniform_flip: True -- after sorting the candidate frames by instance area ratio,
                          uniformly sample the frames whose instance labels are flipped;
                          False -- flip the frames with the smallest instance area ratio.
            instance_id: the instance ids of the 8 chairs in Replica Room_2, used for adding region-wise noise
            load_saved: whether to load a previously saved self.mask_ids
            visualise_save: if True, save the processed labels to disk for further use.
        """
        num_pixel = self.img_w * self.img_h

        if not load_saved:
            if sparse_views:
                self.sample_label_maps(sparse_ratio=sparse_ratio, random_sample=random_sample, load_saved=load_saved)
            assert self.semantic_instance_dir is not None
            # note: dict.fromkeys(instance_id, []) would make every key share the same mutable list,
            # so an empty list is created for each key instead
            instance_maps_dict = dict.fromkeys(instance_id)
            for k in instance_maps_dict.keys():
                instance_maps_dict[k] = list()

            # find which training images contain the instances whose labels we want to flip
            for img_idx in range(self.train_num):
                instance_label_map = self.train_samples["instance"][img_idx]
                for ins_idx in instance_id:
                    instance_ratio = np.sum(instance_label_map == ins_idx) / num_pixel
                    if instance_ratio > 0 and self.mask_ids[img_idx] == 1:  # the instance is visible in this frame and the frame is part of the training set
                        instance_maps_dict[ins_idx].append([img_idx, instance_ratio])

            num_frame_per_instance_id = np.asarray([len(x) for x in instance_maps_dict.values()])
            num_flip_frame_per_instance_id = (num_frame_per_instance_id * flip_ratio).astype(int)

            for k, v in instance_maps_dict.items():
                instance_maps_dict[k] = sorted(instance_maps_dict[k], key=lambda x: x[1])  # sort by instance area ratio, ascending

            if not uniform_flip:
                # flip the labels of the frames with the smallest instance area ratio:
                # the intuition is that such observations are partial and more likely to be wrong.
                for i in range(len(instance_id)):  # loop over instance ids
                    selected_frame_id = [x[0] for x in instance_maps_dict[instance_id[i]][:num_flip_frame_per_instance_id[i]]]
                    for m in selected_frame_id:  # loop over the frame ids containing the selected instance
                        self.train_samples["semantic_remap"][m][self.train_samples["instance"][m] == instance_id[i]] = np.random.choice(self.num_semantic_class, 1)
            else:
                if flip_ratio <= 0.5:  # flip less than or equal to half of the frames
                    for i in range(len(instance_id)):  # loop over instance ids
                        K = num_flip_frame_per_instance_id[i]
                        q, r = divmod(num_frame_per_instance_id[i], K)
                        indices_to_flip = [q * j + min(j, r) for j in range(K)]
                        valid_frame_id_list = [x[0] for x in instance_maps_dict[instance_id[i]]]
                        selected_frame_id = [valid_frame_id_list[flip_id] for flip_id in indices_to_flip]
                        for m in selected_frame_id:  # loop over the frame ids containing the selected instance
                            self.train_samples["semantic_remap"][m][self.train_samples["instance"][m] == instance_id[i]] = np.random.choice(self.num_semantic_class, 1)
                else:  # flip more than half of the frames
                    for i in range(len(instance_id)):  # loop over instance ids
                        K = num_flip_frame_per_instance_id[i]
                        N = num_frame_per_instance_id[i] - K
                        q, r = divmod(num_frame_per_instance_id[i], N)
                        indices_NOT_flip = [q * j + min(j, r) for j in range(N)]
                        indices_to_flip = [x for x in range(num_frame_per_instance_id[i]) if x not in indices_NOT_flip]
                        valid_frame_id_list = [x[0] for x in instance_maps_dict[instance_id[i]]]
                        selected_frame_id = [valid_frame_id_list[flip_id] for flip_id in indices_to_flip]
                        for m in selected_frame_id:  # loop over the frame ids containing the selected instance
                            self.train_samples["semantic_remap"][m][self.train_samples["instance"][m] == instance_id[i]] = np.random.choice(self.num_semantic_class, 1)

            colour_map_np = self.colour_map_np
            vis_flip_semantic = [colour_map_np[sem] for sem in self.train_samples["semantic_remap"]]
            vis_gt_semantic = [colour_map_np[sem] for sem in self.train_samples["semantic_remap_clean"]]

            if visualise_save:
                flip_sem_dir = os.path.join(self.semantic_class_dir, "flipped_chair_nr_{}".format(flip_ratio))
                if not os.path.exists(flip_sem_dir):
                    os.makedirs(flip_sem_dir)

                with open(os.path.join(flip_sem_dir, "mask_ids.npy"), 'wb') as f:
                    np.save(f, self.mask_ids)

                for i in range(len(vis_flip_semantic)):
                    imageio.imwrite(os.path.join(flip_sem_dir, "semantic_class_{}.png".format(i)), self.train_samples["semantic_remap"][i])
                    imageio.imwrite(os.path.join(flip_sem_dir, "vis_sem_class_{}.png".format(i)), vis_flip_semantic[i])
                    imageio.imwrite(os.path.join(flip_sem_dir, "vis_gt_{}.png".format(i)), vis_gt_semantic[i])
        else:
            print("Load saved noisy labels.")
            flip_sem_dir = os.path.join(self.semantic_class_dir, "flipped_chair_nr_{}".format(flip_ratio))
            assert os.path.exists(flip_sem_dir)
            self.mask_ids = np.load(os.path.join(flip_sem_dir, "mask_ids.npy"))
            semantic_img_list = []
            semantic_path_list = sorted(glob.glob(flip_sem_dir + '/semantic_class_*.png'), key=lambda file_name: int(file_name.split("_")[-1][:-4]))
            assert len(semantic_path_list) > 0
            for idx in range(len(self.mask_ids)):
                semantic = imread(semantic_path_list[idx])
                semantic_img_list.append(semantic)
            self.train_samples["semantic_remap"] = np.asarray(semantic_img_list)
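
    # Worked sketch (added; illustrative only) of the frame selection above: suppose a chair
    # instance is visible in 20 training frames and flip_ratio = 0.3, so
    # num_flip_frame_per_instance_id = int(20 * 0.3) = 6 frames are flipped per instance.
    # With uniform_flip=False, the 6 frames in which the chair occupies the *smallest* image
    # area have their chair pixels overwritten with a random class id; with uniform_flip=True,
    # the 6 flipped frames are instead spread evenly over the 20 area-sorted candidates
    # (indices [0, 4, 8, 11, 14, 17], since divmod(20, 6) = (3, 2)).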

    def super_resolve_label(self, down_scale_factor=8, dense_supervision=True):
        """
        In super-resolution mode, training supervision is created by down-scaling the ground-truth label maps
        by a given factor to throw away information, and then up-scaling them back to the original size.

        Two set-ups for up-scaling:
            (1) Sparse labels: the interpolated label pixels are set to the void label 0, so the loss is only
                applied on a grid of every down_scale_factor pixels
            (2) Dense labels: the loss is also applied to the interpolated pixel values

        down_scale_factor: the scaling factor used for down-sampling and up-sampling
        dense_supervision: whether to use the dense-label mode.
        """
        if down_scale_factor == 1:
            return
        if dense_supervision:  # dense labelling: down-scale and then up-scale the label maps
            scaled_low_res_train_label = []
            for i in range(self.train_num):
                low_res_label = cv2.resize(self.train_samples["semantic_remap"][i],
                                           (self.img_w // down_scale_factor, self.img_h // down_scale_factor),
                                           interpolation=cv2.INTER_NEAREST)

                scaled_low_res_label = cv2.resize(low_res_label, (self.img_w, self.img_h), interpolation=cv2.INTER_NEAREST)
                scaled_low_res_train_label.append(scaled_low_res_label)

            scaled_low_res_train_label = np.asarray(scaled_low_res_train_label)

            self.train_samples["semantic_remap"] = scaled_low_res_train_label

        else:  # sparse labelling: the loss is only applied strictly at the valid pixel positions
            valid_low_res_pixel_mask = np.zeros((self.img_h, self.img_w))
            valid_low_res_pixel_mask[::down_scale_factor, ::down_scale_factor] = 1
            self.train_samples["semantic_remap"] = (self.train_samples["semantic_remap"] * valid_low_res_pixel_mask[None, ...]).astype(np.uint8)
            # all decimated pixel labels are masked to the void class 0
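
    # Illustrative note (added; not part of the original code): with img_h = img_w = 480 and
    # down_scale_factor = 8, the sparse set-up keeps labels only at pixel coordinates
    # (0, 0), (0, 8), ..., (472, 472), i.e. (480 // 8) ** 2 = 3600 supervised pixels per frame
    # instead of 480 * 480 = 230400, while the dense set-up keeps a label at every pixel but
    # with the spatial detail of a 60x60 label map.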

    def simulate_user_click_partial(self, perc=0, load_saved=False, visualise_save=True):
        """
        Generate partial label maps for the label-propagation task.
        perc: the percentage of pixels per class per image to be preserved, simulating partial user clicks
            0: single clicks
            1: 1% user clicks
            5: 5% user clicks
        load_saved: if True, load previously saved partial clicks to guarantee reproducibility; if False, create new partial labels
        visualise_save: if True, save the processed partial labels to disk for further use such as visualisation.
        """
        assert perc <= 100 and perc >= 0
        assert self.train_num == self.train_samples["semantic_remap"].shape[0]
        single_click = True if perc == 0 else False  # single_click: use only a single click per class
        perc = perc / 100.0  # turn perc into a fraction
        if not load_saved:

            if single_click:
                click_semantic_map = []
                for i in range(self.train_num):
                    if (i + 1) % 10 == 0:
                        print("Generating partial label of ratio {} for frame {}/{}.".format(perc, i, self.train_num))
                    im = self.train_samples["semantic_remap"][i]
                    void_class = [0]
                    label_class = np.unique(im).tolist()
                    valid_class = [c for c in label_class if c not in void_class]
                    im_ = np.zeros_like(im)
                    for l in valid_class:
                        label_idx = np.transpose(np.nonzero(im == l))
                        sample_ind = np.random.choice(label_idx.shape[0], 1, replace=False)
                        label_idx_ = label_idx[sample_ind]
                        im_[label_idx_[:, 0], label_idx_[:, 1]] = l
                    click_semantic_map.append(im_)
                click_semantic_map = np.asarray(click_semantic_map).astype(np.uint8)
                self.train_samples["semantic_remap"] = click_semantic_map

                print('Partial label images with centroid sampling (extreme) have been generated.')

            elif perc > 0 and not single_click:
                click_semantic_map = []
                for i in range(self.train_num):
                    if (i + 1) % 10 == 0:
                        print("Generating partial label of ratio {} for frame {}/{}.".format(perc, i, self.train_num))
                    im = self.train_samples["semantic_remap"][i]
                    void_class = [0]
                    label_class = np.unique(im).tolist()  # unique class ids in the current training label
                    valid_class = [c for c in label_class if c not in void_class]

                    im_ = np.zeros_like(im)
                    for l in valid_class:
                        label_mask = np.zeros_like(im)
                        label_mask_ = im == l  # binary mask of pixels belonging to class l
                        label_idx = np.transpose(np.nonzero(label_mask_))  # Nx2
                        sample_ind = np.random.choice(label_idx.shape[0], 1, replace=False)  # shape [1,]
                        label_idx_ = label_idx[sample_ind]  # shape [1, 2]
                        target_num = int(perc * label_mask_.sum())  # the target number of labelled pixels for class l in this image
                        label_mask[label_idx_[0, 0], label_idx_[0, 1]] = 1  # all-zero mask with only the selected seed pixel set to 1
                        label_mask_true = label_mask
                        # label_mask_true initially contains a single True pixel; the mask is grown until it reaches the target percentage
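                        # Added explanatory note (not part of the original code): the loop below grows the
                        # clicked region by repeated 5x5 dilation, but keeps only pixels that truly belong to
                        # class l (label_mask * label_mask_), so the "click" expands inside the object mask.
                        # E.g. starting from one seed pixel, the first dilation covers at most a 5x5 = 25-pixel
                        # patch, the next at most 9x9 = 81 pixels, and so on, until target_num pixels of class l
                        # are covered; if growth stalls because the local blob is filled, a new seed is drawn
                        # from the still-uncovered pixels of the same class.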
                        while label_mask_true.sum() < target_num:
                            num_before_grow = label_mask_true.sum()
                            label_mask = cv2.dilate(label_mask, kernel=np.ones([5, 5]))
                            label_mask_true = label_mask * label_mask_
                            num_after_grow = label_mask_true.sum()
                            if num_after_grow == num_before_grow:
                                print("Initialise another seed for growing!")
                                # the current region has stopped growing, i.e. the local blob is already filled,
                                # so another seed is initialised to keep the region growing
                                uncovered_region_mask = label_mask_ - label_mask_true  # pixels equal to 1 are still-unsampled pixels of the current class
                                label_idx = np.transpose(np.nonzero(uncovered_region_mask))  # Nx2
                                sample_ind = np.random.choice(label_idx.shape[0], 1, replace=False)  # shape [1,]
                                label_idx_ = label_idx[sample_ind]  # shape [1, 2]
                                label_mask[label_idx_[0, 0], label_idx_[0, 1]] = 1

                        im_[label_mask_true.astype(bool)] = l
                    click_semantic_map.append(im_)

                click_semantic_map = np.asarray(click_semantic_map).astype(np.uint8)
                self.train_samples["semantic_remap"] = click_semantic_map
                print('Partial label images with centroid sampling have been generated.')
            else:
                assert False

            if visualise_save:
                partial_sem_dir = os.path.join(self.semantic_class_dir, "partial_perc_{}".format(perc))
                if not os.path.exists(partial_sem_dir):
                    os.makedirs(partial_sem_dir)
                colour_map_np = self.colour_map_np
                vis_partial_sem = []
                for i in range(self.train_num):
                    vis_partial_semantic = colour_map_np[self.train_samples["semantic_remap"][i]]  # [H, W, 3]
                    imageio.imwrite(os.path.join(partial_sem_dir, "semantic_class_{}.png".format(i)), self.train_samples["semantic_remap"][i])
                    imageio.imwrite(os.path.join(partial_sem_dir, "vis_sem_class_{}.png".format(i)), vis_partial_semantic)
                    vis_partial_sem.append(vis_partial_semantic)

                imageio.mimwrite(os.path.join(partial_sem_dir, 'partial_sem.mp4'), self.train_samples["semantic_remap"], fps=30, quality=8)
                imageio.mimwrite(os.path.join(partial_sem_dir, 'vis_partial_sem.mp4'), np.stack(vis_partial_sem, 0), fps=30, quality=8)

        else:  # load saved single-click/partial semantics
            saved_partial_sem_dir = os.path.join(self.semantic_class_dir, "partial_perc_{}".format(perc))
            semantic_img_list = []
            semantic_path_list = sorted(glob.glob(saved_partial_sem_dir + '/semantic_class_*.png'), key=lambda file_name: int(file_name.split("_")[-1][:-4]))
            assert len(semantic_path_list) > 0
            for idx in range(self.train_num):
                semantic = imread(semantic_path_list[idx])
                semantic_img_list.append(semantic)
            self.train_samples["semantic_remap"] = np.asarray(semantic_img_list).astype(np.uint8)
--------------------------------------------------------------------------------