├── src ├── gd │ ├── __init__.py │ ├── experiments │ │ ├── __init__.py │ │ └── clutter_removal.py │ ├── utils │ │ ├── __init__.py │ │ ├── panda_control.py │ │ ├── transform.py │ │ ├── ros_utils.py │ │ └── btsim.py │ ├── grasp.py │ ├── baselines.py │ ├── dataset.py │ ├── networks.py │ ├── io.py │ ├── detection.py │ ├── perception.py │ └── vis.py └── nr │ ├── dataset │ ├── name2dataset.py │ └── database.py │ ├── run_training.py │ ├── utils │ ├── dataset_utils.py │ ├── field_utils.py │ ├── view_select.py │ ├── grasp_utils.py │ └── imgs_info.py │ ├── network │ ├── vis_encoder.py │ ├── init_net.py │ ├── neus.py │ ├── mvsnet │ │ ├── modules.py │ │ └── mvsnet.py │ ├── metrics.py │ ├── dist_decoder.py │ ├── aggregate_net.py │ ├── render_ops.py │ ├── loss.py │ └── ops.py │ ├── configs │ └── nrvgn_sdf.yaml │ ├── train │ ├── train_valid.py │ ├── lr_common_manager.py │ ├── train_tools.py │ └── trainer.py │ ├── asset.py │ └── main.py ├── images └── teaser.png ├── train.sh ├── data_generator └── run_pile_rand.sh ├── requirements.txt ├── run_simgrasp.sh ├── .gitignore ├── scripts ├── stat_expresult.py └── sim_grasp.py └── README.md /src/gd/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/gd/experiments/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /images/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-EPIC/GraspNeRF/HEAD/images/teaser.png -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | cd src/nr 2 | CUDA_VISIBLE_DEVICES=$1 python run_training.py --cfg configs/nrvgn_sdf.yaml 3 | cd - -------------------------------------------------------------------------------- /src/nr/dataset/name2dataset.py: -------------------------------------------------------------------------------- 1 | from dataset.train_dataset import GeneralRendererDataset, FinetuningRendererDataset 2 | 3 | name2dataset={ 4 | 'gen': GeneralRendererDataset, 5 | 'ft': FinetuningRendererDataset, 6 | } -------------------------------------------------------------------------------- /data_generator/run_pile_rand.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd /data/InterNeRF/renderer/renderer_giga_GPU6-0_rand_M 4 | 5 | # 830*6 6 | mycount=0; 7 | while (( $mycount < 100 )); do 8 | /home/xxx/blender-2.93.3-linux-x64/blender material_lib_v2.blend --background -noaudio --python render_pile_STD_rand.py -- $mycount; 9 | ((mycount=$mycount+1)); 10 | done; -------------------------------------------------------------------------------- /src/nr/run_training.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from train.trainer import Trainer 4 | from utils.base_utils import load_cfg 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--cfg', type=str, default='configs/train/gen/neuray_gen_depth_train.yaml') 8 | flags = parser.parse_args() 9 | 10 | trainer = Trainer(load_cfg(flags.cfg)) 11 | trainer.run() -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | torch 2 | tensorflow 3 | easydict 4 | inplace-abn 5 | plyfile 6 | numpy 7 | scikit-image 8 | pyyaml 9 | h5py 10 | opencv-python 11 | tqdm 12 | matplotlib 13 | scipy 14 | lpips 15 | transforms3d 16 | kornia 17 | sklearn 18 | catkin_pkg 19 | black 20 | jupyterlab 21 | pandas 22 | mpi4py 23 | open3d 24 | pybullet==2.7.9 25 | pytorch-ignite 26 | tensorboard -------------------------------------------------------------------------------- /src/gd/utils/__init__.py: -------------------------------------------------------------------------------- 1 | def workspace_lines(size): 2 | return [ 3 | [0.0, 0.0, 0.0], 4 | [size, 0.0, 0.0], 5 | [size, 0.0, 0.0], 6 | [size, size, 0.0], 7 | [size, size, 0.0], 8 | [0.0, size, 0.0], 9 | [0.0, size, 0.0], 10 | [0.0, 0.0, 0.0], 11 | [0.0, 0.0, size], 12 | [size, 0.0, size], 13 | [size, 0.0, size], 14 | [size, size, size], 15 | [size, size, size], 16 | [0.0, size, size], 17 | [0.0, size, size], 18 | [0.0, 0.0, size], 19 | [0.0, 0.0, 0.0], 20 | [0.0, 0.0, size], 21 | [size, 0.0, 0.0], 22 | [size, 0.0, size], 23 | [size, size, 0.0], 24 | [size, size, size], 25 | [0.0, size, 0.0], 26 | [0.0, size, size], 27 | ] 28 | 29 | -------------------------------------------------------------------------------- /src/nr/utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import random 4 | import torch 5 | 6 | def dummy_collate_fn(data_list): 7 | return data_list[0] 8 | 9 | def simple_collate_fn(data_list): 10 | ks=data_list[0].keys() 11 | outputs={k:[] for k in ks} 12 | for k in ks: 13 | for data in data_list: 14 | outputs[k].append(data[k]) 15 | outputs[k]=torch.stack(outputs[k],0) 16 | return outputs 17 | 18 | def set_seed(index,is_train): 19 | if is_train: 20 | np.random.seed((index+int(time.time()))%(2**16)) 21 | random.seed((index+int(time.time()))%(2**16)+1) 22 | torch.random.manual_seed((index+int(time.time()))%(2**16)+1) 23 | else: 24 | np.random.seed(index % (2 ** 16)) 25 | random.seed(index % (2 ** 16) + 1) 26 | torch.random.manual_seed(index % (2 ** 16) + 1) -------------------------------------------------------------------------------- /src/nr/network/vis_encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | from network.ops import conv3x3, ResidualBlock, conv1x1 5 | 6 | class DefaultVisEncoder(nn.Module): 7 | default_cfg={} 8 | def __init__(self, cfg): 9 | super().__init__() 10 | self.cfg={**self.default_cfg,**cfg} 11 | norm_layer = lambda dim: nn.InstanceNorm2d(dim,track_running_stats=False,affine=True) 12 | self.out_conv=nn.Sequential( 13 | conv3x3(64, 32), 14 | ResidualBlock(32, 32, norm_layer=norm_layer), 15 | ResidualBlock(32, 32, norm_layer=norm_layer), 16 | conv1x1(32, 32), 17 | ) 18 | 19 | def forward(self, ray_feats, imgs_feats): 20 | feats = self.out_conv(torch.cat([imgs_feats, ray_feats],1)) 21 | return feats 22 | 23 | name2vis_encoder={ 24 | 'default': DefaultVisEncoder, 25 | } -------------------------------------------------------------------------------- /src/gd/grasp.py: -------------------------------------------------------------------------------- 1 | import enum 2 | 3 | 4 | class Label(enum.IntEnum): 5 | FAILURE = 0 # grasp execution failed due to collision or slippage 6 | SUCCESS = 1 # object was successfully removed 7 | 8 | 9 | class Grasp(object): 10 | """Grasp parameterized 
as pose of a 2-finger robot hand. 11 | 12 | TODO(mbreyer): clarify definition of grasp frame 13 | """ 14 | 15 | def __init__(self, pose, width): 16 | self.pose = pose 17 | self.width = width 18 | 19 | 20 | def to_voxel_coordinates(grasp, voxel_size): 21 | pose = grasp.pose 22 | pose.translation /= voxel_size 23 | width = grasp.width / voxel_size 24 | return Grasp(pose, width) 25 | 26 | 27 | def from_voxel_coordinates(grasp, voxel_size): 28 | pose = grasp.pose 29 | pose.translation *= voxel_size 30 | width = grasp.width * voxel_size 31 | return Grasp(pose, width) 32 | -------------------------------------------------------------------------------- /run_simgrasp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | GPUID=0 4 | BLENDER_BIN=blender 5 | 6 | RENDERER_ASSET_DIR=./data/assets 7 | BLENDER_PROJ_PATH=./data/assets/material_lib_graspnet-v2.blend 8 | SIM_LOG_DIR="./log/`date '+%Y%m%d-%H%M%S'`" 9 | 10 | scene="pile" 11 | object_set="pile_subdiv" 12 | material_type="specular_and_transparent" 13 | render_frame_list="2,6,10,14,18,22" 14 | check_seen_scene=0 15 | expname=0 16 | 17 | NUM_TRIALS=200 18 | METHOD='graspnerf' 19 | 20 | mycount=0 21 | while (( $mycount < $NUM_TRIALS )); do 22 | $BLENDER_BIN $BLENDER_PROJ_PATH --background --python scripts/sim_grasp.py \ 23 | -- $mycount $GPUID $expname $scene $object_set $check_seen_scene $material_type \ 24 | $RENDERER_ASSET_DIR $SIM_LOG_DIR 0 $render_frame_list $METHOD 25 | 26 | python ./scripts/stat_expresult.py -- $SIM_LOG_DIR $expname 27 | ((mycount=$mycount+1)); 28 | done; 29 | 30 | python ./scripts/stat_expresult.py -- $SIM_LOG_DIR $expname -------------------------------------------------------------------------------- /src/nr/network/init_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from network.ops import interpolate_feats, masked_mean_var, ResEncoder, ResUNetLight, conv3x3, ResidualBlock, conv1x1 7 | 8 | class CostVolumeInitNet(nn.Module): 9 | default_cfg={ 10 | 'cost_volume_sn': 64, 11 | } 12 | def __init__(self,cfg): 13 | super().__init__() 14 | self.cfg={**self.default_cfg,**cfg} 15 | 16 | imagenet_mean = torch.from_numpy(np.asarray([0.485, 0.456, 0.406], np.float32)).cuda()[None, :, None, None] 17 | imagenet_std = torch.from_numpy(np.asarray([0.229, 0.224, 0.225], np.float32)).cuda()[None, :, None, None] 18 | self.register_buffer('imagenet_mean', imagenet_mean) 19 | self.register_buffer('imagenet_std', imagenet_std) 20 | 21 | self.res_net = ResUNetLight(out_dim=32) 22 | norm_layer = lambda dim: nn.InstanceNorm2d(dim, track_running_stats=False, affine=True) 23 | 24 | 25 | in_dim = 32 26 | 27 | self.out_conv = nn.Sequential( 28 | conv3x3(in_dim, 32), 29 | ResidualBlock(32, 32, norm_layer=norm_layer), 30 | conv1x1(32, 32), 31 | ) 32 | 33 | def forward(self, ref_imgs_info, src_imgs_info, is_train): 34 | ref_feats = self.res_net(ref_imgs_info['imgs']) 35 | return self.out_conv(torch.cat([ref_feats], 1)) 36 | 37 | name2init_net={ 38 | 'cost_volume': CostVolumeInitNet, 39 | } -------------------------------------------------------------------------------- /src/nr/utils/field_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def generate_grid_points_old(bound_min, bound_max, resolution): 5 | X = torch.linspace(bound_min[0], bound_max[0], 
resolution) 6 | Y = torch.linspace(bound_min[1], bound_max[1], resolution) 7 | Z = torch.linspace(bound_max[2], bound_min[2], resolution) # from top to down to be like with training rays 8 | XYZ = torch.stack(torch.meshgrid(X, Y, Z), dim=-1) 9 | 10 | return XYZ 11 | 12 | RESOLUTION = 40 13 | VOLUME_SIZE = 0.3 14 | VOXEL_SIZE = VOLUME_SIZE / RESOLUTION 15 | HALF_VOXEL_SIZE = VOXEL_SIZE / 2 16 | 17 | def generate_grid_points(): 18 | points = [] 19 | for x in range(RESOLUTION): 20 | for y in range(RESOLUTION): 21 | for z in range(RESOLUTION): 22 | points.append([x * VOXEL_SIZE + HALF_VOXEL_SIZE, 23 | y * VOXEL_SIZE + HALF_VOXEL_SIZE, 24 | z * VOXEL_SIZE + HALF_VOXEL_SIZE]) 25 | return np.array(points).astype(np.float32) 26 | 27 | TSDF_SAMPLE_POINTS = generate_grid_points() 28 | 29 | if __name__ == "__main__": 30 | GT_POINTS = np.load('points.npy') 31 | TSDF_VOLUME_MASK = np.zeros((1, 40, 40, 40), dtype=np.bool8) 32 | idxs = [] 33 | for point in GT_POINTS: 34 | i, j, k = np.floor(point / VOXEL_SIZE).astype(int) 35 | TSDF_VOLUME_MASK[0, i, j, k] = True 36 | idxs.append(i * (RESOLUTION * RESOLUTION) + j * RESOLUTION + k) 37 | print(TSDF_SAMPLE_POINTS[idxs], GT_POINTS) 38 | assert np.allclose(TSDF_SAMPLE_POINTS[idxs], GT_POINTS) 39 | 40 | -------------------------------------------------------------------------------- /src/nr/utils/view_select.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from dataset.database import BaseDatabase 4 | 5 | def compute_nearest_camera_indices(database, que_ids, ref_ids=None): 6 | if ref_ids is None: ref_ids = que_ids 7 | ref_poses = [database.get_pose(ref_id) for ref_id in ref_ids] 8 | ref_cam_pts = np.asarray([-pose[:, :3].T @ pose[:, 3] for pose in ref_poses]) 9 | que_poses = [database.get_pose(que_id) for que_id in que_ids] 10 | que_cam_pts = np.asarray([-pose[:, :3].T @ pose[:, 3] for pose in que_poses]) 11 | 12 | dists = np.linalg.norm(ref_cam_pts[None, :, :] - que_cam_pts[:, None, :], 2, 2) 13 | dists_idx = np.argsort(dists, 1) 14 | return dists_idx 15 | 16 | def select_working_views(ref_poses, que_poses, work_num, exclude_self=False): 17 | ref_cam_pts = np.asarray([-pose[:, :3].T @ pose[:, 3] for pose in ref_poses]) 18 | render_cam_pts = np.asarray([-pose[:, :3].T @ pose[:, 3] for pose in que_poses]) 19 | dists = np.linalg.norm(ref_cam_pts[None, :, :] - render_cam_pts[:, None, :], 2, 2) # qn,rfn 20 | ids = np.argsort(dists) 21 | if exclude_self: 22 | ids = ids[:, 1:work_num+1] 23 | else: 24 | ids = ids[:, :work_num] 25 | return ids 26 | 27 | def select_working_views_db(database: BaseDatabase, ref_ids, que_poses, work_num, exclude_self=False): 28 | ref_ids = database.get_img_ids() if ref_ids is None else ref_ids 29 | ref_poses = [database.get_pose(img_id) for img_id in ref_ids] 30 | 31 | ref_ids = np.asarray(ref_ids) 32 | ref_poses = np.asarray(ref_poses) 33 | indices = select_working_views(ref_poses, que_poses, work_num, exclude_self) 34 | return ref_ids[indices] # qn,wn -------------------------------------------------------------------------------- /src/nr/configs/nrvgn_sdf.yaml: -------------------------------------------------------------------------------- 1 | name: test 2 | group_name: "" 3 | # network 4 | fix_seed: true 5 | network: grasp_nerf 6 | init_net_type: cost_volume 7 | agg_net_type: neus 8 | use_hierarchical_sampling: true 9 | use_depth: true 10 | use_depth_loss: true 11 | depth_loss_weight: 1.0 12 | dist_decoder_cfg: 13 | use_vis: false 14 | fine_dist_decoder_cfg: 15 | 
use_vis: false 16 | ray_batch_num: 4096 #2048 17 | sample_volume: true 18 | render_rgb: true 19 | volume_type: [sdf] 20 | 21 | volume_resolution: 40 22 | depth_sample_num: 40 23 | fine_depth_sample_num: 40 24 | agg_net_cfg: 25 | sample_num: 40 26 | init_s: 0.3 27 | fix_s: 0 28 | fine_agg_net_cfg: 29 | sample_num: 40 30 | init_s: 0.3 31 | fix_s: 0 32 | vis_vol: false 33 | 34 | # loss 35 | loss: [render, depth, sdf, vgn] 36 | val_metric: [psnr_ssim, vis_img] 37 | key_metric_name: loss_vgn # depth_mae psnr_nr_fine 38 | key_metric_prefer: lower 39 | use_dr_loss: false 40 | use_dr_fine_loss: false 41 | use_nr_fine_loss: true 42 | render_depth: true 43 | depth_correct_ratio: 1.0 44 | depth_thresh: 0.8 45 | use_dr_prediction: false 46 | 47 | # lr 48 | total_step: 500000 49 | val_interval: 5000 50 | lr_type: exp_decay 51 | lr_cfg: 52 | lr_init: 1.0e-4 53 | decay_step: 100000 54 | decay_rate: 0.5 55 | nr_initial_training_steps: 0 56 | 57 | # dataset 58 | train_dataset_type: gen 59 | train_dataset_cfg: 60 | resolution_type: hr 61 | type2sample_weights: { vgn_syn: 100 } 62 | train_database_types: ['vgn_syn'] 63 | aug_pixel_center_sample: true 64 | aug_view_select_type: hard 65 | ref_pad_interval: 32 66 | use_src_imgs: true 67 | num_input_views: 6 68 | 69 | val_set_list: 70 | - 71 | name: vgn_syn 72 | type: gen 73 | val_scene_num: -1 # if the set, use val scene list in asset.py 74 | cfg: 75 | use_src_imgs: true 76 | num_input_views: 6 -------------------------------------------------------------------------------- /src/nr/train/train_valid.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | from network.metrics import name2key_metrics 8 | from train.train_tools import to_cuda 9 | 10 | 11 | class ValidationEvaluator: 12 | def __init__(self,cfg): 13 | self.key_metric_name=cfg['key_metric_name'] 14 | self.key_metric=name2key_metrics[self.key_metric_name] 15 | 16 | def __call__(self, model, losses, eval_dataset, step, model_name, val_set_name=None): 17 | if val_set_name is not None: model_name=f'{model_name}-{val_set_name}' 18 | model.eval() 19 | eval_results={} 20 | begin=time.time() 21 | for data_i, data in tqdm(enumerate(eval_dataset)): 22 | data = to_cuda(data) 23 | data['eval']=True 24 | data['step']=step 25 | with torch.no_grad(): 26 | outputs=model(data) 27 | for loss in losses: 28 | loss_results=loss(outputs, data, step, data_index=data_i, model_name=model_name, is_train=False) 29 | for k,v in loss_results.items(): 30 | if type(v)==torch.Tensor: 31 | v=v.detach().cpu().numpy() 32 | 33 | if k in eval_results: 34 | eval_results[k].append(v) 35 | else: 36 | eval_results[k]=[v] 37 | 38 | for k,v in eval_results.items(): 39 | eval_results[k]=np.concatenate(v,axis=0) 40 | 41 | key_metric_val=self.key_metric(eval_results) 42 | if key_metric_val != 1e6: 43 | eval_results[self.key_metric_name + '_all'] = eval_results[self.key_metric_name] 44 | eval_results[self.key_metric_name]=key_metric_val 45 | print('eval cost {} s'.format(time.time()-begin)) 46 | return eval_results, key_metric_val 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | 
downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | .venv-py2 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | 107 | # vscode 108 | .vscode 109 | 110 | # data 111 | data 112 | 113 | __pycache__ 114 | 115 | *.log 116 | log/ 117 | 118 | output/ 119 | ckpt/ 120 | 121 | *.pth -------------------------------------------------------------------------------- /src/gd/baselines.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from gpd_ros.msg import GraspConfigList 4 | import numpy as np 5 | from sensor_msgs.msg import PointCloud2 6 | import rospy 7 | 8 | from gd.grasp import Grasp 9 | from gd.utils import ros_utils 10 | from gd.utils.transform import Rotation, Transform 11 | 12 | 13 | class GPD(object): 14 | def __init__(self): 15 | self.input_topic = "/cloud_stitched" 16 | self.output_topic = "/detect_grasps/clustered_grasps" 17 | self.cloud_pub = rospy.Publisher(self.input_topic, PointCloud2, queue_size=1) 18 | 19 | def __call__(self, state): 20 | points = np.asarray(state.pc.points) 21 | msg = ros_utils.to_cloud_msg(points, frame="task") 22 | self.cloud_pub.publish(msg) 23 | 24 | tic = time.time() 25 | result = rospy.wait_for_message(self.output_topic, GraspConfigList) 26 | toc = time.time() - tic 27 | 28 | grasps, scores = self.to_grasp_list(result) 29 | 30 | return grasps, scores, toc 31 | 32 | def to_grasp_list(self, grasp_configs): 33 | grasps, scores = [], [] 34 | for grasp_config in grasp_configs.grasps: 35 | # orientation 36 | x_axis = ros_utils.from_vector3_msg(grasp_config.axis) 37 | y_axis = -ros_utils.from_vector3_msg(grasp_config.binormal) 38 | z_axis = ros_utils.from_vector3_msg(grasp_config.approach) 39 | orientation = Rotation.from_matrix(np.vstack([x_axis, y_axis, z_axis]).T) 40 | # position 41 | position = ros_utils.from_point_msg(grasp_config.position) 42 | # width 43 | width = grasp_config.width.data 44 | # score 45 | score = grasp_config.score.data 46 | 47 | if score < 0.0: 48 | continue # negative score is larger than positive score (https://github.com/atenpas/gpd/issues/32#issuecomment-387846534) 49 | 50 | grasps.append(Grasp(Transform(orientation, 
position), width)) 51 | scores.append(score) 52 | 53 | return grasps, scores 54 | -------------------------------------------------------------------------------- /src/nr/train/lr_common_manager.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | class LearningRateManager(abc.ABC): 4 | @staticmethod 5 | def set_lr_for_all(optimizer, lr): 6 | for param_group in optimizer.param_groups: 7 | param_group['lr'] = lr 8 | 9 | def construct_optimizer(self, optimizer, network): 10 | # may specify different lr for different parts 11 | # use group to set learning rate 12 | paras = network.parameters() 13 | return optimizer(paras, lr=1e-3) 14 | 15 | @abc.abstractmethod 16 | def __call__(self, optimizer, step, *args, **kwargs): 17 | pass 18 | 19 | class ExpDecayLR(LearningRateManager): 20 | def __init__(self,cfg): 21 | self.lr_init=cfg['lr_init'] 22 | self.decay_step=cfg['decay_step'] 23 | self.decay_rate=cfg['decay_rate'] 24 | self.lr_min=1e-5 25 | 26 | def __call__(self, optimizer, step, *args, **kwargs): 27 | lr=max(self.lr_init*(self.decay_rate**(step//self.decay_step)),self.lr_min) 28 | self.set_lr_for_all(optimizer,lr) 29 | return lr 30 | 31 | class ExpDecayLRRayFeats(ExpDecayLR): 32 | def construct_optimizer(self, optimizer, network): 33 | paras = network.parameters() 34 | return optimizer([para for para in paras] + network.ray_feats, lr=1e-3) 35 | 36 | class WarmUpExpDecayLR(LearningRateManager): 37 | def __init__(self, cfg): 38 | self.lr_warm=cfg['lr_warm'] 39 | self.warm_step=cfg['warm_step'] 40 | self.lr_init=cfg['lr_init'] 41 | self.decay_step=cfg['decay_step'] 42 | self.decay_rate=cfg['decay_rate'] 43 | self.lr_min=1e-5 44 | 45 | def __call__(self, optimizer, step, *args, **kwargs): 46 | if step self.fix_s: 18 | self.variance.requires_grad = True 19 | return torch.ones([len(x), 1], device=x.device) * torch.exp(self.variance * 10.0) 20 | 21 | class Embedder: 22 | def __init__(self, **kwargs): 23 | self.kwargs = kwargs 24 | self.create_embedding_fn() 25 | 26 | def create_embedding_fn(self): 27 | embed_fns = [] 28 | d = self.kwargs['input_dims'] 29 | out_dim = 0 30 | if self.kwargs['include_input']: 31 | embed_fns.append(lambda x: x) 32 | out_dim += d 33 | 34 | max_freq = self.kwargs['max_freq_log2'] 35 | N_freqs = self.kwargs['num_freqs'] 36 | 37 | if self.kwargs['log_sampling']: 38 | freq_bands = 2. 
** torch.linspace(0., max_freq, N_freqs) 39 | else: 40 | freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs) 41 | 42 | for freq in freq_bands: 43 | for p_fn in self.kwargs['periodic_fns']: 44 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 45 | out_dim += d 46 | 47 | self.embed_fns = embed_fns 48 | self.out_dim = out_dim 49 | 50 | def embed(self, inputs): 51 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 52 | 53 | 54 | def get_embedder(multires, input_dims=3): 55 | embed_kwargs = { 56 | 'include_input': True, 57 | 'input_dims': input_dims, 58 | 'max_freq_log2': multires-1, 59 | 'num_freqs': multires, 60 | 'log_sampling': True, 61 | 'periodic_fns': [torch.sin, torch.cos], 62 | } 63 | 64 | embedder_obj = Embedder(**embed_kwargs) 65 | def embed(x, eo=embedder_obj): return eo.embed(x) 66 | return embed, embedder_obj.out_dim 67 | -------------------------------------------------------------------------------- /src/nr/asset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | DATA_ROOT_DIR = '../../data/traindata_example/' 5 | VGN_TRAIN_ROOT = DATA_ROOT_DIR + 'giga_hemisphere_train_demo' 6 | 7 | def add_scenes(root, type, filter_list=None): 8 | scene_names = [] 9 | splits = os.listdir(root) 10 | for split in splits: 11 | if filter_list is not None and split not in filter_list: continue 12 | scenes = os.listdir(os.path.join(root, split)) 13 | scene_names += [f'vgn_syn/train/{type}/{split}/{fn}/w_0.8' for fn in scenes] 14 | return scene_names 15 | if os.path.exists(VGN_TRAIN_ROOT): 16 | vgn_pile_train_scene_names = sorted(add_scenes(os.path.join(VGN_TRAIN_ROOT, 'pile_full'), 'pile'), key=lambda x: x.split('/')[4]) 17 | vgn_pack_train_scene_names = sorted(add_scenes(os.path.join(VGN_TRAIN_ROOT, 'packed_full'), 'packed'), key=lambda x: x.split('/')[4]) 18 | num_scenes_pile = len(vgn_pile_train_scene_names) 19 | num_scenes_pack = len(vgn_pack_train_scene_names) 20 | vgn_pack_train_scene_names = vgn_pack_train_scene_names[:num_scenes_pack] 21 | num_val_pile = 1 22 | num_val_pack = 1 23 | print(f"total: {num_scenes_pile + num_scenes_pack} pile: {num_scenes_pile} pack: {num_scenes_pack}") 24 | vgn_val_scene_names = vgn_pile_train_scene_names[-num_val_pile:] + vgn_pack_train_scene_names[-num_val_pack:] 25 | vgn_train_scene_names = vgn_pile_train_scene_names[:-num_val_pile] + vgn_pack_train_scene_names[:-num_val_pack] 26 | 27 | VGN_SDF_DIR = DATA_ROOT_DIR + "giga_hemisphere_train_demo/scenes_tsdf_dep-nor" 28 | 29 | VGN_TEST_ROOT = '' 30 | VGN_TEST_ROOT_PILE = os.path.join(VGN_TEST_ROOT,'pile') 31 | VGN_TEST_ROOT_PACK = os.path.join(VGN_TEST_ROOT,'packed') 32 | if os.path.exists(VGN_TEST_ROOT): 33 | fns = os.listdir(VGN_TEST_ROOT_PILE) 34 | vgn_pile_test_scene_names = [f'vgn_syn/test/pile//{fn}/w_0.8' for fn in fns] 35 | fns = os.listdir(VGN_TEST_ROOT_PACK) 36 | vgn_pack_test_scene_names = [f'vgn_syn/test/packed//{fn}/w_0.8' for fn in fns] 37 | 38 | vgn_test_scene_names = vgn_pile_test_scene_names + vgn_pack_test_scene_names 39 | 40 | CSV_ROOT = DATA_ROOT_DIR + 'GIGA_demo' 41 | import pandas as pd 42 | from pathlib import Path 43 | import time 44 | t0 = time.time() 45 | VGN_PACK_TRAIN_CSV = pd.read_csv(Path(CSV_ROOT + '/data_packed_train_processed_dex_noise/grasps.csv')) 46 | VGN_PILE_TRAIN_CSV = pd.read_csv(Path(CSV_ROOT + '/data_pile_train_processed_dex_noise/grasps.csv')) 47 | print(f"finished loading csv in {time.time() - t0} s") 48 | VGN_PACK_TEST_CSV = None 49 | VGN_PILE_TEST_CSV = 
None 50 | -------------------------------------------------------------------------------- /src/gd/networks.py: -------------------------------------------------------------------------------- 1 | from builtins import super 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from scipy import ndimage 7 | 8 | 9 | def get_network(name): 10 | models = { 11 | "conv": ConvNet(), 12 | } 13 | return models[name.lower()] 14 | 15 | 16 | def load_network(path, device): 17 | """Construct the neural network and load parameters from the specified file. 18 | 19 | Args: 20 | path: Path to the model parameters. The name must conform to `vgn_name_[_...]`. 21 | 22 | """ 23 | model_name = path.stem.split("_")[1] 24 | net = get_network(model_name).to(device) 25 | net.load_state_dict(torch.load(path, map_location=device)) 26 | return net 27 | 28 | 29 | def conv(in_channels, out_channels, kernel_size): 30 | return nn.Conv3d(in_channels, out_channels, kernel_size, padding=kernel_size // 2) 31 | 32 | 33 | def conv_stride(in_channels, out_channels, kernel_size): 34 | return nn.Conv3d( 35 | in_channels, out_channels, kernel_size, stride=2, padding=kernel_size // 2 36 | ) 37 | 38 | 39 | class ConvNet(nn.Module): 40 | def __init__(self): 41 | super().__init__() 42 | self.encoder = Encoder(1, [16, 32, 64], [5, 3, 3]) 43 | self.decoder = Decoder(64, [64, 32, 16], [3, 3, 5]) 44 | self.conv_qual = conv(16, 1, 5) 45 | self.conv_rot = conv(16, 4, 5) 46 | self.conv_width = conv(16, 1, 5) 47 | 48 | def forward(self, x): 49 | x = self.encoder(x) 50 | x = self.decoder(x) 51 | qual_out = torch.sigmoid(self.conv_qual(x)) 52 | rot_out = F.normalize(self.conv_rot(x), dim=1) 53 | width_out = self.conv_width(x) 54 | return qual_out, rot_out, width_out 55 | 56 | 57 | class Encoder(nn.Module): 58 | def __init__(self, in_channels, filters, kernels): 59 | super().__init__() 60 | self.conv1 = conv_stride(in_channels, filters[0], kernels[0]) 61 | self.conv2 = conv_stride(filters[0], filters[1], kernels[1]) 62 | self.conv3 = conv_stride(filters[1], filters[2], kernels[2]) 63 | 64 | def forward(self, x): 65 | x = self.conv1(x) 66 | x = F.relu(x) 67 | 68 | x = self.conv2(x) 69 | x = F.relu(x) 70 | 71 | x = self.conv3(x) 72 | x = F.relu(x) 73 | 74 | return x 75 | 76 | 77 | class Decoder(nn.Module): 78 | def __init__(self, in_channels, filters, kernels): 79 | super().__init__() 80 | self.conv1 = conv(in_channels, filters[0], kernels[0]) 81 | self.conv2 = conv(filters[0], filters[1], kernels[1]) 82 | self.conv3 = conv(filters[1], filters[2], kernels[2]) 83 | 84 | def forward(self, x): 85 | x = self.conv1(x) 86 | x = F.relu(x) 87 | 88 | x = F.interpolate(x, 10) 89 | x = self.conv2(x) 90 | x = F.relu(x) 91 | 92 | x = F.interpolate(x, 20) 93 | x = self.conv3(x) 94 | x = F.relu(x) 95 | 96 | x = F.interpolate(x, 40) 97 | return x 98 | 99 | 100 | def count_num_trainable_parameters(net): 101 | return sum(p.numel() for p in net.parameters() if p.requires_grad) 102 | -------------------------------------------------------------------------------- /src/nr/network/mvsnet/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from inplace_abn import InPlaceABN 5 | from kornia.utils import create_meshgrid 6 | 7 | class ConvBnReLU(nn.Module): 8 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1, norm_act=InPlaceABN): 9 | super(ConvBnReLU, self).__init__() 10 | self.conv = 
nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False) 11 | self.bn = norm_act(out_channels) 12 | 13 | def forward(self, x): 14 | return self.bn(self.conv(x)) 15 | 16 | class ConvBnReLU3D(nn.Module): 17 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1, norm_act=InPlaceABN): 18 | super(ConvBnReLU3D, self).__init__() 19 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False) 20 | self.bn = norm_act(out_channels) 21 | 22 | def forward(self, x): 23 | return self.bn(self.conv(x)) 24 | 25 | def homo_warp(src_feat, src_proj, ref_proj_inv, depth_values): 26 | # src_feat: (B, C, H, W) 27 | # src_proj: (B, 4, 4) 28 | # ref_proj_inv: (B, 4, 4) 29 | # depth_values: (B, D) 30 | # out: (B, C, D, H, W) 31 | B, C, H, W = src_feat.shape 32 | D = depth_values.shape[1] 33 | device = src_feat.device 34 | dtype = src_feat.dtype 35 | 36 | transform = src_proj @ ref_proj_inv 37 | R = transform[:, :3, :3] # (B, 3, 3) 38 | T = transform[:, :3, 3:] # (B, 3, 1) 39 | # create grid from the ref frame 40 | ref_grid = create_meshgrid(H, W, normalized_coordinates=False) # (1, H, W, 2) 41 | ref_grid = ref_grid.to(device).to(dtype) 42 | ref_grid = ref_grid.permute(0, 3, 1, 2) # (1, 2, H, W) 43 | ref_grid = ref_grid.reshape(1, 2, H*W) # (1, 2, H*W) 44 | ref_grid = ref_grid.expand(B, -1, -1) # (B, 2, H*W) 45 | ref_grid = torch.cat((ref_grid, torch.ones_like(ref_grid[:,:1])), 1) # (B, 3, H*W) 46 | ref_grid_d = ref_grid.unsqueeze(2) * depth_values.view(B, 1, D, 1) # (B, 3, D, H*W) 47 | ref_grid_d = ref_grid_d.view(B, 3, D*H*W) 48 | src_grid_d = R @ ref_grid_d + T # (B, 3, D*H*W) 49 | del ref_grid_d, ref_grid, transform, R, T # release (GPU) memory 50 | div_val = src_grid_d[:, -1:] 51 | div_val[div_val<1e-4] = 1e-4 52 | src_grid = src_grid_d[:, :2] / div_val # divide by depth (B, 2, D*H*W) 53 | del src_grid_d, div_val 54 | src_grid[:, 0] = src_grid[:, 0]/((W - 1) / 2) - 1 # scale to -1~1 55 | src_grid[:, 1] = src_grid[:, 1]/((H - 1) / 2) - 1 # scale to -1~1 56 | src_grid = src_grid.permute(0, 2, 1) # (B, D*H*W, 2) 57 | src_grid = src_grid.view(B, D, H*W, 2) 58 | 59 | warped_src_feat = F.grid_sample(src_feat, src_grid, 60 | mode='bilinear', padding_mode='zeros', 61 | align_corners=True) # (B, C, D, H*W) 62 | warped_src_feat = warped_src_feat.view(B, C, D, H, W) 63 | 64 | return warped_src_feat 65 | 66 | def depth_regression(p, depth_values): 67 | # p: probability volume [B, D, H, W] 68 | # depth_values: discrete depth values [B, D] 69 | depth_values = depth_values.view(*depth_values.shape, 1, 1) 70 | depth = torch.sum(p * depth_values, 1) 71 | return depth -------------------------------------------------------------------------------- /src/gd/io.py: -------------------------------------------------------------------------------- 1 | import json 2 | import uuid 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | from gd.grasp import Grasp 8 | from gd.perception import * 9 | from gd.utils.transform import Rotation, Transform 10 | 11 | 12 | def write_setup(root, size, intrinsic, max_opening_width, finger_depth): 13 | data = { 14 | "size": size, 15 | "intrinsic": intrinsic.to_dict(), 16 | "max_opening_width": max_opening_width, 17 | "finger_depth": finger_depth, 18 | } 19 | write_json(data, root / "setup.json") 20 | 21 | 22 | def read_setup(root): 23 | data = read_json(root / "setup.json") 24 | size = data["size"] 25 | intrinsic = CameraIntrinsic.from_dict(data["intrinsic"]) 26 | max_opening_width = 
data["max_opening_width"] 27 | finger_depth = data["finger_depth"] 28 | return size, intrinsic, max_opening_width, finger_depth 29 | 30 | 31 | def write_sensor_data(root, depth_imgs, extrinsics): 32 | scene_id = uuid.uuid4().hex 33 | path = root / "scenes" / (scene_id + ".npz") 34 | np.savez_compressed(path, depth_imgs=depth_imgs, extrinsics=extrinsics) 35 | return scene_id 36 | 37 | 38 | def read_sensor_data(root, scene_id): 39 | data = np.load(root / "scenes" / (scene_id + ".npz")) 40 | return data["depth_imgs"], data["extrinsics"] 41 | 42 | 43 | def write_grasp(root, scene_id, grasp, label): 44 | # TODO concurrent writes could be an issue 45 | csv_path = root / "grasps.csv" 46 | if not csv_path.exists(): 47 | create_csv( 48 | csv_path, 49 | ["scene_id", "qx", "qy", "qz", "qw", "x", "y", "z", "width", "label"], 50 | ) 51 | qx, qy, qz, qw = grasp.pose.rotation.as_quat() 52 | x, y, z = grasp.pose.translation 53 | width = grasp.width 54 | append_csv(csv_path, scene_id, qx, qy, qz, qw, x, y, z, width, label) 55 | 56 | 57 | def read_grasp(df, i): 58 | scene_id = df.loc[i, "scene_id"] 59 | orientation = Rotation.from_quat(df.loc[i, "qx":"qw"].to_numpy(np.double)) 60 | position = df.loc[i, "x":"z"].to_numpy(np.double) 61 | width = df.loc[i, "width"] 62 | label = df.loc[i, "label"] 63 | grasp = Grasp(Transform(orientation, position), width) 64 | return scene_id, grasp, label 65 | 66 | 67 | def read_df(root): 68 | return pd.read_csv(root / "grasps.csv") 69 | 70 | 71 | def write_df(df, root): 72 | df.to_csv(root / "grasps.csv", index=False) 73 | 74 | 75 | def write_voxel_grid(root, scene_id, voxel_grid): 76 | path = root / "scenes" / (scene_id + ".npz") 77 | np.savez_compressed(path, grid=voxel_grid) 78 | 79 | 80 | def read_voxel_grid(root, scene_id): 81 | path = root / "scenes" / (scene_id + ".npz") 82 | return np.load(path)["grid"] 83 | 84 | 85 | def read_json(path): 86 | with path.open("r") as f: 87 | data = json.load(f) 88 | return data 89 | 90 | 91 | def write_json(data, path): 92 | with path.open("w") as f: 93 | json.dump(data, f, indent=4) 94 | 95 | 96 | def create_csv(path, columns): 97 | with path.open("w") as f: 98 | f.write(",".join(columns)) 99 | f.write("\n") 100 | 101 | 102 | def append_csv(path, *args): 103 | row = ",".join([str(arg) for arg in args]) 104 | with path.open("a") as f: 105 | f.write(row) 106 | f.write("\n") 107 | -------------------------------------------------------------------------------- /src/gd/utils/transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.spatial.transform 3 | 4 | 5 | class Rotation(scipy.spatial.transform.Rotation): 6 | @classmethod 7 | def identity(cls): 8 | return cls.from_quat([0.0, 0.0, 0.0, 1.0]) 9 | 10 | 11 | class Transform(object): 12 | """Rigid spatial transform between coordinate systems in 3D space. 
13 | 14 | Attributes: 15 | rotation (scipy.spatial.transform.Rotation) 16 | translation (np.ndarray) 17 | """ 18 | 19 | def __init__(self, rotation, translation): 20 | assert isinstance(rotation, scipy.spatial.transform.Rotation) 21 | assert isinstance(translation, (np.ndarray, list)) 22 | 23 | self.rotation = rotation 24 | self.translation = np.asarray(translation, np.double) 25 | 26 | def as_matrix(self): 27 | """Represent as a 4x4 matrix.""" 28 | return np.vstack( 29 | (np.c_[self.rotation.as_matrix(), self.translation], [0.0, 0.0, 0.0, 1.0]) 30 | ) 31 | 32 | def to_dict(self): 33 | """Serialize Transform object into a dictionary.""" 34 | return { 35 | "rotation": self.rotation.as_quat().tolist(), 36 | "translation": self.translation.tolist(), 37 | } 38 | 39 | def to_list(self): 40 | return np.r_[self.rotation.as_quat(), self.translation] 41 | 42 | def __mul__(self, other): 43 | """Compose this transform with another.""" 44 | rotation = self.rotation * other.rotation 45 | translation = self.rotation.apply(other.translation) + self.translation 46 | return self.__class__(rotation, translation) 47 | 48 | def transform_point(self, point): 49 | return self.rotation.apply(point) + self.translation 50 | 51 | def transform_vector(self, vector): 52 | return self.rotation.apply(vector) 53 | 54 | def inverse(self): 55 | """Compute the inverse of this transform.""" 56 | rotation = self.rotation.inv() 57 | translation = -rotation.apply(self.translation) 58 | return self.__class__(rotation, translation) 59 | 60 | @classmethod 61 | def from_matrix(cls, m): 62 | """Initialize from a 4x4 matrix.""" 63 | rotation = Rotation.from_matrix(m[:3, :3]) 64 | translation = m[:3, 3] 65 | return cls(rotation, translation) 66 | 67 | @classmethod 68 | def from_dict(cls, dictionary): 69 | rotation = Rotation.from_quat(dictionary["rotation"]) 70 | translation = np.asarray(dictionary["translation"]) 71 | return cls(rotation, translation) 72 | 73 | @classmethod 74 | def from_list(cls, list): 75 | rotation = Rotation.from_quat(list[:4]) 76 | translation = list[4:] 77 | return cls(rotation, translation) 78 | 79 | @classmethod 80 | def identity(cls): 81 | """Initialize with the identity transformation.""" 82 | rotation = Rotation.from_quat([0.0, 0.0, 0.0, 1.0]) 83 | translation = np.array([0.0, 0.0, 0.0]) 84 | return cls(rotation, translation) 85 | 86 | @classmethod 87 | def look_at(cls, eye, center, up): 88 | """Initialize with a LookAt matrix. 89 | 90 | Returns: 91 | T_eye_ref, the transform from camera to the reference frame, w.r.t. 92 | which the input arguments were defined. 
93 | """ 94 | eye = np.asarray(eye) 95 | center = np.asarray(center) 96 | 97 | forward = center - eye 98 | forward /= np.linalg.norm(forward) 99 | 100 | right = np.cross(forward, up) 101 | right /= np.linalg.norm(right) 102 | 103 | up = np.asarray(up) / np.linalg.norm(up) 104 | up = np.cross(right, forward) 105 | 106 | m = np.eye(4, 4) 107 | m[:3, 0] = right 108 | m[:3, 1] = -up 109 | m[:3, 2] = forward 110 | m[:3, 3] = eye 111 | 112 | return cls.from_matrix(m).inverse() 113 | -------------------------------------------------------------------------------- /src/gd/detection.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | from scipy import ndimage 5 | import torch 6 | 7 | 8 | from gd.grasp import * 9 | from gd.utils.transform import Transform, Rotation 10 | from gd.networks import load_network 11 | 12 | 13 | class VGN(object): 14 | def __init__(self, model_path, rviz=False): 15 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | self.net = load_network(model_path, self.device) 17 | self.rviz = rviz 18 | 19 | def __call__(self, state): 20 | tsdf_vol = state.tsdf.get_grid() 21 | voxel_size = state.tsdf.voxel_size 22 | 23 | tic = time.time() 24 | qual_vol, rot_vol, width_vol = predict(tsdf_vol, self.net, self.device) 25 | qual_vol, rot_vol, width_vol = process(tsdf_vol, qual_vol, rot_vol, width_vol) 26 | grasps, scores = select(qual_vol.copy(), rot_vol, width_vol) 27 | toc = time.time() - tic 28 | 29 | grasps, scores = np.asarray(grasps), np.asarray(scores) 30 | 31 | if len(grasps) > 0: 32 | p = np.random.permutation(len(grasps)) 33 | grasps = [from_voxel_coordinates(g, voxel_size) for g in grasps[p]] 34 | scores = scores[p] 35 | 36 | if self.rviz: 37 | from gd import vis 38 | vis.draw_quality(qual_vol, state.tsdf.voxel_size, threshold=0.01) 39 | 40 | return grasps, scores, toc 41 | 42 | 43 | def predict(tsdf_vol, net, device): 44 | assert tsdf_vol.shape == (1, 40, 40, 40) 45 | 46 | # move input to the GPU 47 | tsdf_vol = torch.from_numpy(tsdf_vol).unsqueeze(0).to(device) 48 | 49 | # forward pass 50 | with torch.no_grad(): 51 | qual_vol, rot_vol, width_vol = net(tsdf_vol) 52 | 53 | # move output back to the CPU 54 | qual_vol = qual_vol.cpu().squeeze().numpy() 55 | rot_vol = rot_vol.cpu().squeeze().numpy() 56 | width_vol = width_vol.cpu().squeeze().numpy() 57 | return qual_vol, rot_vol, width_vol 58 | 59 | 60 | def process( 61 | tsdf_vol, 62 | qual_vol, 63 | rot_vol, 64 | width_vol, 65 | gaussian_filter_sigma=1.0, 66 | min_width=1.33, 67 | max_width=9.33, 68 | ): 69 | tsdf_vol = tsdf_vol.squeeze() 70 | 71 | # smooth quality volume with a Gaussian 72 | qual_vol = ndimage.gaussian_filter( 73 | qual_vol, sigma=gaussian_filter_sigma, mode="nearest" 74 | ) 75 | 76 | # mask out voxels too far away from the surface 77 | outside_voxels = tsdf_vol > 0.5 78 | inside_voxels = np.logical_and(1e-3 < tsdf_vol, tsdf_vol < 0.5) 79 | valid_voxels = ndimage.morphology.binary_dilation( 80 | outside_voxels, iterations=2, mask=np.logical_not(inside_voxels) 81 | ) 82 | qual_vol[valid_voxels == False] = 0.0 83 | 84 | # reject voxels with predicted widths that are too small or too large 85 | qual_vol[np.logical_or(width_vol < min_width, width_vol > max_width)] = 0.0 86 | 87 | return qual_vol, rot_vol, width_vol 88 | 89 | 90 | def select(qual_vol, rot_vol, width_vol, threshold=0.90, max_filter_size=4): 91 | # threshold on grasp quality 92 | qual_vol[qual_vol < threshold] = 0.0 93 | 94 | # non maximum 
suppression 95 | max_vol = ndimage.maximum_filter(qual_vol, size=max_filter_size) 96 | qual_vol = np.where(qual_vol == max_vol, qual_vol, 0.0) 97 | mask = np.where(qual_vol, 1.0, 0.0) 98 | 99 | # construct grasps 100 | grasps, scores = [], [] 101 | for index in np.argwhere(mask): 102 | grasp, score = select_index(qual_vol, rot_vol, width_vol, index) 103 | grasps.append(grasp) 104 | scores.append(score) 105 | 106 | return grasps, scores 107 | 108 | 109 | def select_index(qual_vol, rot_vol, width_vol, index): 110 | i, j, k = index 111 | score = qual_vol[i, j, k] 112 | ori = Rotation.from_quat(rot_vol[:, i, j, k]) 113 | pos = np.array([i, j, k], dtype=np.float64) 114 | width = width_vol[i, j, k] 115 | return Grasp(Transform(ori, pos), width), score 116 | -------------------------------------------------------------------------------- /scripts/stat_expresult.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | import pandas as pd 4 | import os 5 | import numpy as np 6 | 7 | argv = sys.argv 8 | argv = argv[argv.index("--") + 1:] # get all args after "--" 9 | log_root_dir = str(argv[0]) 10 | expname = str(argv[1]) 11 | 12 | class Data(object): 13 | """Object for loading and analyzing experimental data.""" 14 | 15 | def __init__(self, logdir): 16 | self.logdir = logdir 17 | self.rounds = pd.read_csv(logdir / "rounds.csv") 18 | self.grasps = pd.read_csv(logdir / "grasps.csv") 19 | 20 | def num_rounds(self): 21 | return len(self.rounds.index) 22 | 23 | def num_grasps(self): 24 | return len(self.grasps.index) 25 | 26 | def success_rate(self): 27 | return self.grasps["label"].mean() * 100 28 | 29 | def percent_cleared(self): 30 | df = ( 31 | self.grasps[["round_id", "label"]] 32 | .groupby("round_id") 33 | .sum() 34 | .rename(columns={"label": "cleared_count"}) 35 | .merge(self.rounds, on="round_id") 36 | ) 37 | return df["cleared_count"].sum() / df["object_count"].sum() * 100 38 | 39 | def avg_planning_time(self): 40 | return self.grasps["planning_time"].mean() 41 | 42 | def read_grasp(self, i): 43 | scene_id, grasp, label = io.read_grasp(self.grasps, i) 44 | score = self.grasps.loc[i, "score"] 45 | scene_data = np.load(self.logdir / "scenes" / (scene_id + ".npz")) 46 | 47 | return scene_data["points"], grasp, score, label 48 | 49 | ############################## 50 | # Combine all trials 51 | ############################## 52 | root_path = os.path.join(log_root_dir, "exp_results", expname) 53 | 54 | round_dir_list = sorted(os.listdir(root_path)) 55 | 56 | if not os.path.exists(root_path + "_combine"): 57 | os.makedirs(root_path + "_combine") 58 | 59 | df = pd.DataFrame() 60 | for i in range(len(round_dir_list)): 61 | df_round = pd.read_csv(os.path.join(root_path, round_dir_list[i], "grasps.csv")) 62 | df_round["round_id"] = i 63 | df = pd.concat([df, df_round]) 64 | df = df.reset_index(drop=True) 65 | df.to_csv(os.path.join(root_path + "_combine", "grasps.csv"), index=False) 66 | 67 | 68 | df = pd.DataFrame() 69 | for i in range(len(round_dir_list)): 70 | df_round = pd.read_csv(os.path.join(root_path, round_dir_list[i], "rounds.csv")) 71 | df_round["round_id"] = i 72 | df = pd.concat([df, df_round]) 73 | df = df.reset_index(drop=True) 74 | df.to_csv(os.path.join(root_path + "_combine", "rounds.csv"), index=False) 75 | 76 | ############################## 77 | # Print Stat 78 | ############################## 79 | logdir = Path(os.path.join(log_root_dir, "exp_results", expname+"_combine")) 80 | data = Data(logdir) 81 | 82 
| # First, we compute the following metrics for the experiment: 83 | # * **Success rate**: the ratio of successful grasp executions, 84 | # * **Percent cleared**: the percentage of objects removed during each round, 85 | try: 86 | print("Path: ",str(logdir)) 87 | print("Num grasps: ", data.num_grasps()) 88 | print("Success rate: ", data.success_rate()) 89 | print("Percent cleared: ", data.percent_cleared()) 90 | except: 91 | print("[W] Incomplete results, exit") 92 | exit() 93 | ############################## 94 | # Calc first-time grasping SR 95 | ############################## 96 | 97 | sum_label = 0 98 | firstgrasp_fail_expidx_list = [] 99 | for i in range(len(round_dir_list)): 100 | #print(i) 101 | df_round = pd.read_csv(os.path.join(root_path, round_dir_list[i], "grasps.csv")) 102 | df = df_round.iloc[0:1,:] 103 | 104 | label = df[["label"]].to_numpy(np.float32) 105 | if label.shape[0] == 0: 106 | firstgrasp_fail_expidx_list.append(i) 107 | continue 108 | sum_label += label[0,0] 109 | if label[0,0]==0: 110 | firstgrasp_fail_expidx_list.append(i) 111 | 112 | print("First grasp success rate: ", sum_label / len(round_dir_list)) 113 | print("First grasp fail:", len(firstgrasp_fail_expidx_list),"/",len(round_dir_list), ", exp id: ", firstgrasp_fail_expidx_list) 114 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GraspNeRF: Multiview-based 6-DoF Grasp Detection for Transparent and Specular Objects Using Generalizable NeRF (ICRA 2023) 2 | 3 | This is the official repository of [**GraspNeRF: Multiview-based 6-DoF Grasp Detection for Transparent and Specular Objects Using Generalizable NeRF**](https://arxiv.org/abs/2210.06575). 4 | 5 | For more information, please visit our [**project page**](https://pku-epic.github.io/GraspNeRF/). 6 | 7 | ## Introduction 8 | 9 | 10 | In this work, we propose a multiview RGB-based 6-DoF grasp detection network, **GraspNeRF**, 11 | that leverages the **generalizable neural radiance field (NeRF)** to achieve material-agnostic object grasping in clutter. 12 | Compared to the existing NeRF-based 3-DoF grasp detection methods that rely on densely captured input images and time-consuming per-scene optimization, 13 | our system can perform zero-shot NeRF construction with **sparse RGB inputs** and reliably detect **6-DoF grasps**, both in **real-time**. 14 | The proposed framework jointly learns generalizable NeRF and grasp detection in an **end-to-end** manner, optimizing the scene representation construction for grasping. 15 | For training data, we generate a large-scale photorealistic **domain-randomized synthetic dataset** of grasping in cluttered tabletop scenes that enables direct transfer to the real world. 16 | Experiments in synthetic and real-world environments demonstrate that our method significantly outperforms all the baselines in all the experiments. 17 | 18 | ## Overview 19 | This repository provides: 20 | - PyTorch code and weights of GraspNeRF. 21 | - Grasp Simulator based on blender and pybullet. 22 | - Multiview 6-DoF Grasping Dataset Generator and Examples. 23 | 24 | ## Dependency 25 | 1. Please run 26 | ``` 27 | pip install -r requirements.txt 28 | ``` 29 | to install the dependencies. 30 | 31 | 2. (optional) Please install [blender 2.93.3--Ubuntu](https://www.blender.org/) if you need to run the simulation. 32 | 33 | ## Data & Checkpoints 34 | 1.
Please generate or download and uncompress the [example data](https://drive.google.com/file/d/1Ku-EotayUhfv5DtXAvFitGzzdMF84Ve2/view?usp=share_link) to `data/` for training, and the [rendering assets](https://drive.google.com/file/d/1Udvi2QQ6AtYDLUWY0oH-PO2R6kZBxJLT/view?usp=share_link) to `data/assets` for simulation. 35 | Specifically, download the [imagenet valset](https://image-net.org/data/ILSVRC/2010/ILSVRC2010_images_val.tar) to `data/assets/imagenet/images/val`, which is used as random textures in the simulation. 36 | 2. We provide pretrained weights for testing. Please download the [checkpoint](https://drive.google.com/file/d/1k-Cy4NO2isCBYc3az-34HEdcNxDptDgU/view?usp=share_link) to `src/nr/ckpt/test`. 37 | 38 | ## Testing 39 | Our grasp simulation pipeline depends on blender and pybullet. Please verify the installation before running the simulation. 40 | 41 | After the dependencies and assets are ready, please run 42 | ``` 43 | bash run_simgrasp.sh 44 | ``` 45 | 46 | ## Training 47 | After the training data is ready, please run 48 | ``` 49 | bash train.sh GPU_ID 50 | ``` 51 | e.g. `bash train.sh 0`. 52 | 53 | ## Data Generator 54 | 1. Download the scene descriptor files from [GIGA](https://github.com/UT-Austin-RPL/GIGA#pre-generated-data) and [assets](https://drive.google.com/file/d/1-59zcQ8h5esT_ogjaDjtzQ6sG70WNWzU/view?usp=share_link). 55 | 2. For example, run 56 | ``` 57 | bash run_pile_rand.sh 58 | ``` 59 | in `./data_generator` for pile data generation. 60 | 61 | ## Citation 62 | If you find our work useful in your research, please consider citing: 63 | 64 | ``` 65 | @inproceedings{Dai2023GraspNeRF, 66 | title={GraspNeRF: Multiview-based 6-DoF Grasp Detection for Transparent and Specular Objects Using Generalizable NeRF}, 67 | author={Qiyu Dai and Yan Zhu and Yiran Geng and Ciyu Ruan and Jiazhao Zhang and He Wang}, 68 | booktitle={IEEE International Conference on Robotics and Automation (ICRA)}, 69 | year={2023}} 70 | ``` 71 | 72 | ## License 73 | 74 | This work and the dataset are licensed under [CC BY-NC 4.0][cc-by-nc]. 75 | 76 | [![CC BY-NC 4.0][cc-by-nc-image]][cc-by-nc] 77 | 78 | [cc-by-nc]: https://creativecommons.org/licenses/by-nc/4.0/ 79 | [cc-by-nc-image]: https://licensebuttons.net/l/by-nc/4.0/88x31.png 80 | 81 | ## Contact 82 | If you have any questions, please open a GitHub issue or contact us: 83 | 84 | Qiyu Dai: qiyudai@pku.edu.cn, Yan Zhu: zhuyan_@stu.pku.edu.cn, He Wang: hewang@pku.edu.cn 85 | -------------------------------------------------------------------------------- /src/gd/perception.py: -------------------------------------------------------------------------------- 1 | from math import cos, sin 2 | import time 3 | 4 | import numpy as np 5 | import open3d as o3d 6 | 7 | from gd.utils.transform import Transform 8 | 9 | 10 | class CameraIntrinsic(object): 11 | """Intrinsic parameters of a pinhole camera model. 12 | 13 | Attributes: 14 | width (int): The width in pixels of the camera. 15 | height(int): The height in pixels of the camera. 16 | K: The intrinsic camera matrix.
17 | """ 18 | 19 | def __init__(self, width, height, fx, fy, cx, cy, channel=1): 20 | self.width = width 21 | self.height = height 22 | self.channel = channel 23 | self.K = np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]]) 24 | 25 | @property 26 | def fx(self): 27 | return self.K[0, 0] 28 | 29 | @property 30 | def fy(self): 31 | return self.K[1, 1] 32 | 33 | @property 34 | def cx(self): 35 | return self.K[0, 2] 36 | 37 | @property 38 | def cy(self): 39 | return self.K[1, 2] 40 | 41 | def to_dict(self): 42 | """Serialize intrinsic parameters to a dict object.""" 43 | data = { 44 | "width": self.width, 45 | "height": self.height, 46 | "channel": self.channel, 47 | "K": self.K.flatten().tolist(), 48 | } 49 | return data 50 | 51 | @classmethod 52 | def from_dict(cls, data): 53 | """Deserialize intrinisic parameters from a dict object.""" 54 | intrinsic = cls( 55 | width=data["width"], 56 | height=data["height"], 57 | channel=data["channel"], 58 | fx=data["K"][0], 59 | fy=data["K"][4], 60 | cx=data["K"][2], 61 | cy=data["K"][5], 62 | ) 63 | return intrinsic 64 | 65 | 66 | class TSDFVolume(object): 67 | """Integration of multiple depth images using a TSDF.""" 68 | 69 | def __init__(self, size, resolution): 70 | self.size = size 71 | self.resolution = resolution 72 | self.voxel_size = self.size / self.resolution 73 | self.sdf_trunc = 4 * self.voxel_size 74 | 75 | self._volume = o3d.pipelines.integration.UniformTSDFVolume( 76 | length=self.size, 77 | resolution=self.resolution, 78 | sdf_trunc=self.sdf_trunc, 79 | color_type=o3d.pipelines.integration.TSDFVolumeColorType.NoColor, 80 | ) 81 | 82 | def integrate(self, depth_img, intrinsic, extrinsic): 83 | """ 84 | Args: 85 | depth_img: The depth image. 86 | intrinsic: The intrinsic parameters of a pinhole camera model. 87 | extrinsics: The transform from the TSDF to camera coordinates, T_eye_task. 
88 | """ 89 | rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( 90 | o3d.geometry.Image(np.empty_like(depth_img)), 91 | o3d.geometry.Image(depth_img), 92 | depth_scale=1.0, 93 | depth_trunc=2.0, 94 | convert_rgb_to_intensity=False, 95 | ) 96 | 97 | intrinsic = o3d.camera.PinholeCameraIntrinsic( 98 | width=intrinsic.width, 99 | height=intrinsic.height, 100 | fx=intrinsic.fx, 101 | fy=intrinsic.fy, 102 | cx=intrinsic.cx, 103 | cy=intrinsic.cy, 104 | ) 105 | 106 | 107 | self._volume.integrate(rgbd, intrinsic, extrinsic) 108 | 109 | def get_grid(self): 110 | cloud = self._volume.extract_voxel_point_cloud() 111 | points = np.asarray(cloud.points) 112 | distances = np.asarray(cloud.colors)[:, [0]] 113 | grid = np.zeros((1, 40, 40, 40), dtype=np.float32) 114 | for idx, point in enumerate(points): 115 | i, j, k = np.floor(point / self.voxel_size).astype(int) 116 | grid[0, i, j, k] = distances[idx] 117 | return grid 118 | 119 | def get_cloud(self): 120 | return self._volume.extract_point_cloud() 121 | 122 | 123 | def create_tsdf(size, resolution, depth_imgs, intrinsic, extrinsics): 124 | tsdf = TSDFVolume(size, resolution) 125 | for i in range(depth_imgs.shape[0]): 126 | extrinsic = Transform.from_list(extrinsics[i]) 127 | tsdf.integrate(depth_imgs[i], intrinsic, extrinsic) 128 | return tsdf 129 | 130 | 131 | def camera_on_sphere(origin, radius, theta, phi): 132 | eye = np.r_[ 133 | radius * sin(theta) * cos(phi), 134 | radius * sin(theta) * sin(phi), 135 | radius * cos(theta), 136 | ] 137 | target = np.array([0.0, 0.0, 0.0]) 138 | up = np.array([0.0, 0.0, 1.0]) # this breaks when looking straight down 139 | return Transform.look_at(eye, target, up) * origin.inverse() 140 | -------------------------------------------------------------------------------- /src/gd/utils/ros_utils.py: -------------------------------------------------------------------------------- 1 | import geometry_msgs.msg 2 | import numpy as np 3 | import rospy 4 | from sensor_msgs.msg import PointCloud2, PointField 5 | import std_msgs.msg 6 | 7 | 8 | from gd.utils.transform import Rotation, Transform 9 | 10 | 11 | def to_point_msg(position): 12 | """Convert numpy array to a Point message.""" 13 | msg = geometry_msgs.msg.Point() 14 | msg.x = position[0] 15 | msg.y = position[1] 16 | msg.z = position[2] 17 | return msg 18 | 19 | 20 | def from_point_msg(msg): 21 | """Convert a Point message to a numpy array.""" 22 | return np.r_[msg.x, msg.y, msg.z] 23 | 24 | 25 | def to_vector3_msg(vector3): 26 | """Convert numpy array to a Vector3 message.""" 27 | msg = geometry_msgs.msg.Vector3() 28 | msg.x = vector3[0] 29 | msg.y = vector3[1] 30 | msg.z = vector3[2] 31 | return msg 32 | 33 | 34 | def from_vector3_msg(msg): 35 | """Convert a Vector3 message to a numpy array.""" 36 | return np.r_[msg.x, msg.y, msg.z] 37 | 38 | 39 | def to_quat_msg(orientation): 40 | """Convert a `Rotation` object to a Quaternion message.""" 41 | quat = orientation.as_quat() 42 | msg = geometry_msgs.msg.Quaternion() 43 | msg.x = quat[0] 44 | msg.y = quat[1] 45 | msg.z = quat[2] 46 | msg.w = quat[3] 47 | return msg 48 | 49 | 50 | def from_quat_msg(msg): 51 | """Convert a Quaternion message to a Rotation object.""" 52 | return Rotation.from_quat([msg.x, msg.y, msg.z, msg.w]) 53 | 54 | 55 | def to_pose_msg(transform): 56 | """Convert a `Transform` object to a Pose message.""" 57 | msg = geometry_msgs.msg.Pose() 58 | msg.position = to_point_msg(transform.translation) 59 | msg.orientation = to_quat_msg(transform.rotation) 60 | return msg 61 | 62 | 63 | def 
to_transform_msg(transform): 64 | """Convert a `Transform` object to a Transform message.""" 65 | msg = geometry_msgs.msg.Transform() 66 | msg.translation = to_vector3_msg(transform.translation) 67 | msg.rotation = to_quat_msg(transform.rotation) 68 | return msg 69 | 70 | 71 | def from_transform_msg(msg): 72 | """Convert a Transform message to a Transform object.""" 73 | translation = from_vector3_msg(msg.translation) 74 | rotation = from_quat_msg(msg.rotation) 75 | return Transform(rotation, translation) 76 | 77 | 78 | def to_color_msg(color): 79 | """Convert a numpy array to a ColorRGBA message.""" 80 | msg = std_msgs.msg.ColorRGBA() 81 | msg.r = color[0] 82 | msg.g = color[1] 83 | msg.b = color[2] 84 | msg.a = color[3] if len(color) == 4 else 1.0 85 | return msg 86 | 87 | 88 | def to_cloud_msg(points, intensities=None, frame=None, stamp=None): 89 | """Convert list of unstructured points to a PointCloud2 message. 90 | 91 | Args: 92 | points: Point coordinates as array of shape (N,3). 93 | colors: Colors as array of shape (N,3). 94 | frame 95 | stamp 96 | """ 97 | msg = PointCloud2() 98 | msg.header.frame_id = frame 99 | msg.header.stamp = stamp or rospy.Time.now() 100 | 101 | msg.height = 1 102 | msg.width = points.shape[0] 103 | msg.is_bigendian = False 104 | msg.is_dense = False 105 | 106 | msg.fields = [ 107 | PointField("x", 0, PointField.FLOAT32, 1), 108 | PointField("y", 4, PointField.FLOAT32, 1), 109 | PointField("z", 8, PointField.FLOAT32, 1), 110 | ] 111 | msg.point_step = 12 112 | data = points 113 | 114 | if intensities is not None: 115 | msg.fields.append(PointField("intensity", 12, PointField.FLOAT32, 1)) 116 | msg.point_step += 4 117 | data = np.hstack([points, intensities]) 118 | 119 | msg.row_step = msg.point_step * points.shape[0] 120 | msg.data = data.astype(np.float32).tostring() 121 | 122 | return msg 123 | 124 | 125 | class TransformTree(object): 126 | def __init__(self): 127 | import tf2_ros 128 | self._buffer = tf2_ros.Buffer() 129 | self._listener = tf2_ros.TransformListener(self._buffer) 130 | self._broadcaster = tf2_ros.TransformBroadcaster() 131 | self._static_broadcaster = tf2_ros.StaticTransformBroadcaster() 132 | 133 | def lookup(self, target_frame, source_frame, time, timeout=rospy.Duration(0)): 134 | msg = self._buffer.lookup_transform(target_frame, source_frame, time, timeout) 135 | return from_transform_msg(msg.transform) 136 | 137 | def broadcast(self, transform, target_frame, source_frame): 138 | msg = geometry_msgs.msg.TransformStamped() 139 | msg.header.stamp = rospy.Time.now() 140 | msg.header.frame_id = target_frame 141 | msg.child_frame_id = source_frame 142 | msg.transform = to_transform_msg(transform) 143 | self._broadcaster.sendTransform(msg) 144 | 145 | def broadcast_static(self, transform, target_frame, source_frame): 146 | msg = geometry_msgs.msg.TransformStamped() 147 | msg.header.stamp = rospy.Time.now() 148 | msg.header.frame_id = target_frame 149 | msg.child_frame_id = source_frame 150 | msg.transform = to_transform_msg(transform) 151 | self._static_broadcaster.sendTransform(msg) 152 | -------------------------------------------------------------------------------- /scripts/sim_grasp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | from pathlib import Path 5 | 6 | def main(args, round_idx, gpuid, render_frame_list): 7 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpuid) 8 | 9 | sys.path.append("src") 10 | from nr.main import GraspNeRFPlanner 11 | 12 | 
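# (Added note.) Only the GraspNeRF planner is wired up in this script; any other
# value of `method` falls through to the NotImplementedError branch below.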
if args.method == "graspnerf": 13 | grasp_planner = GraspNeRFPlanner(args) 14 | else: 15 | print("No such method!") 16 | raise NotImplementedError 17 | 18 | from gd.experiments import clutter_removal 19 | clutter_removal.run( 20 | grasp_plan_fn=grasp_planner, 21 | logdir=args.logdir, 22 | description=args.description, 23 | scene=args.scene, 24 | object_set=args.object_set, 25 | num_objects=args.num_objects, 26 | num_rounds=args.num_rounds, 27 | seed=args.seed, 28 | sim_gui=args.sim_gui, 29 | rviz=args.rviz, 30 | round_idx = round_idx, 31 | renderer_root_dir = args.renderer_root_dir, 32 | gpuid = gpuid, 33 | args = args, 34 | render_frame_list = render_frame_list 35 | ) 36 | 37 | class ArgumentParserForBlender(argparse.ArgumentParser): 38 | """ 39 | This class is identical to its superclass, except for the parse_args 40 | method (see docstring). It resolves the ambiguity generated when calling 41 | Blender from the CLI with a python script, and both Blender and the script 42 | have arguments. E.g., the following call will make Blender crash because 43 | it will try to process the script's -a and -b flags: 44 | >>> blender --python my_script.py -a 1 -b 2 45 | 46 | To bypass this issue this class uses the fact that Blender will ignore all 47 | arguments given after a double-dash ('--'). The approach is that all 48 | arguments before '--' go to Blender, arguments after go to the script. 49 | The following calls work fine: 50 | >>> blender --python my_script.py -- -a 1 -b 2 51 | >>> blender --python my_script.py -- 52 | """ 53 | 54 | def _get_argv_after_doubledash(self): 55 | """ 56 | Given the sys.argv as a list of strings, this method returns the 57 | sublist right after the '--' element (if present, otherwise returns 58 | an empty list). 59 | """ 60 | try: 61 | idx = sys.argv.index("---") 62 | return sys.argv[idx+1:] # the list after '--' 63 | except ValueError as e: # '--' not in the list: 64 | return [] 65 | 66 | # overrides superclass 67 | def parse_args(self): 68 | """ 69 | This method is expected to behave identically as in the superclass, 70 | except that the sys.argv list will be pre-processed using 71 | _get_argv_after_doubledash before. See the docstring of the class for 72 | usage examples and details. 
73 | """ 74 | return super().parse_args(args=self._get_argv_after_doubledash()) 75 | 76 | if __name__ == "__main__": 77 | argv = sys.argv 78 | argv = argv[argv.index("--") + 1:] # get all args after "--" 79 | round_idx = int(argv[0]) 80 | gpuid = int(argv[1]) 81 | expname = str(argv[2]) 82 | scene = str(argv[3]) 83 | object_set = str(argv[4]) 84 | check_seen_scene = bool(int(argv[5])) 85 | material_type = str(argv[6]) 86 | blender_asset_dir = str(argv[7]) 87 | log_root_dir = str(argv[8]) 88 | use_gt_tsdf = bool(int(argv[9])) 89 | render_frame_list=[int(frame_id) for frame_id in str(argv[10]).replace(' ','').split(",")] 90 | method = str(argv[11]) 91 | print("########## Simulation Start ##########") 92 | print("Round %d\nmethod: %s\nmaterial_type: %s\nviews: %s "%(round_idx, method, material_type, str(render_frame_list))) 93 | print("######################################") 94 | 95 | parser = ArgumentParserForBlender() ### argparse.ArgumentParser() 96 | parser.add_argument("---model", type=Path, default="") 97 | parser.add_argument("---logdir", type=Path, default=expname) 98 | parser.add_argument("---description", type=str, default="") 99 | parser.add_argument("---scene", type=str, choices=["pile", "packed", "single"], default=scene) 100 | parser.add_argument("---object-set", type=str, default=object_set) 101 | parser.add_argument("---num-objects", type=int, default=5) 102 | parser.add_argument("---num-rounds", type=int, default=200) 103 | parser.add_argument("---seed", type=int, default=42) 104 | parser.add_argument("---sim-gui", type=bool, default=False) 105 | parser.add_argument("---rviz", action="store_true") 106 | 107 | ### 108 | parser.add_argument("---renderer_root_dir", type=str, default=blender_asset_dir) 109 | parser.add_argument("---log_root_dir", type=str, default=log_root_dir) 110 | parser.add_argument("---obj_texture_image_root_path", type=str, default=blender_asset_dir+"/imagenet") #TODO 111 | parser.add_argument("---cfg_fn", type=str, default="src/nr/configs/nrvgn_sdf.yaml") 112 | parser.add_argument('---database_name', type=str, default='vgn_syn/train/packed/packed_170-220/032cd891d9be4a16be5ea4be9f7eca2b/w_0.8', help='//') 113 | 114 | parser.add_argument("---gen_scene_descriptor", type=bool, default=False) 115 | parser.add_argument("---load_scene_descriptor", type=bool, default=True) 116 | parser.add_argument("---material_type", type=str, default=material_type) 117 | parser.add_argument("---method", type=str, default=method) 118 | 119 | # pybullet camera parameter 120 | parser.add_argument("---camera_focal", type=float, default=446.31) #TODO 121 | 122 | ### 123 | args = parser.parse_args() 124 | main(args, round_idx, gpuid, render_frame_list) -------------------------------------------------------------------------------- /src/nr/network/metrics.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | from skimage.io import imsave 5 | 6 | from network.loss import Loss 7 | from utils.base_utils import color_map_backward, make_dir 8 | from skimage.metrics import structural_similarity 9 | import numpy as np 10 | 11 | from utils.draw_utils import concat_images_list 12 | 13 | 14 | def compute_psnr(img_gt, img_pr, use_vis_scores=False, vis_scores=None, vis_scores_thresh=1.5): 15 | if use_vis_scores: 16 | mask = vis_scores >= vis_scores_thresh 17 | mask = mask.flatten() 18 | img_gt = img_gt.reshape([-1, 3]).astype(np.float32)[mask] 19 | img_pr = img_pr.reshape([-1, 3]).astype(np.float32)[mask] 20 | 
mse = np.mean((img_gt - img_pr) ** 2, 0) 21 | 22 | img_gt = img_gt.reshape([-1, 3]).astype(np.float32) 23 | img_pr = img_pr.reshape([-1, 3]).astype(np.float32) 24 | mse = np.mean((img_gt - img_pr) ** 2, 0) 25 | mse = np.mean(mse) 26 | psnr = 10 * np.log10(255 * 255 / mse) 27 | return psnr 28 | 29 | def compute_mae(depth_pr, depth_gt): 30 | return np.mean(np.abs(depth_pr - depth_gt)) 31 | 32 | class PSNR_SSIM(Loss): 33 | default_cfg = { 34 | 'eval_margin_ratio': 1.0, 35 | } 36 | def __init__(self, cfg): 37 | super().__init__([]) 38 | self.cfg={**self.default_cfg,**cfg} 39 | 40 | def __call__(self, data_pr, data_gt, step, **kwargs): 41 | rgbs_gt = data_pr['pixel_colors_gt'] # 1,rn,3 42 | rgbs_pr = data_pr['pixel_colors_nr'] # 1,rn,3 43 | if 'que_imgs_info' in data_gt: 44 | h, w = data_gt['que_imgs_info']['imgs'].shape[2:] 45 | else: 46 | h, w = data_pr['que_imgs_info']['imgs'].shape[2:] 47 | rgbs_pr = rgbs_pr.reshape([h,w,3]).detach().cpu().numpy() 48 | rgbs_pr=color_map_backward(rgbs_pr) 49 | 50 | rgbs_gt = rgbs_gt.reshape([h,w,3]).detach().cpu().numpy() 51 | rgbs_gt = color_map_backward(rgbs_gt) 52 | 53 | h, w, _ = rgbs_gt.shape 54 | h_margin = int(h * (1 - self.cfg['eval_margin_ratio'])) // 2 55 | w_margin = int(w * (1 - self.cfg['eval_margin_ratio'])) // 2 56 | rgbs_gt = rgbs_gt[h_margin:h - h_margin, w_margin:w - w_margin] 57 | rgbs_pr = rgbs_pr[h_margin:h - h_margin, w_margin:w - w_margin] 58 | 59 | psnr = compute_psnr(rgbs_gt,rgbs_pr) 60 | outputs={ 61 | 'psnr_nr': torch.tensor([psnr],dtype=torch.float32), 62 | } 63 | 64 | def compute_psnr_prefix(suffix): 65 | if f'pixel_colors_{suffix}' in data_pr: 66 | rgbs_other = data_pr[f'pixel_colors_{suffix}'] # 1,rn,3 67 | # h, w = data_pr['shape'] 68 | rgbs_other = rgbs_other.reshape([h,w,3]).detach().cpu().numpy() 69 | rgbs_other=color_map_backward(rgbs_other) 70 | psnr = compute_psnr(rgbs_gt,rgbs_other) 71 | ssim = structural_similarity(rgbs_gt,rgbs_other,win_size=11,multichannel=True,data_range=255) 72 | outputs[f'psnr_{suffix}']=torch.tensor([psnr], dtype=torch.float32) 73 | 74 | # compute_psnr_prefix('nr') 75 | compute_psnr_prefix('dr') 76 | compute_psnr_prefix('nr_fine') 77 | compute_psnr_prefix('dr_fine') 78 | 79 | depth_pr = data_pr['render_depth'].reshape([h,w]).detach().cpu().numpy() 80 | depth_gt = data_gt['que_imgs_info']['true_depth'][0,0].cpu().numpy() 81 | 82 | 83 | outputs['depth_mae'] = torch.tensor([compute_mae(depth_pr, depth_gt)],dtype=torch.float32) # higher is better 84 | return outputs 85 | 86 | class VisualizeImage(Loss): 87 | def __init__(self, cfg): 88 | super().__init__([]) 89 | 90 | def __call__(self, data_pr, data_gt, step, **kwargs): 91 | if 'que_imgs_info' in data_gt: 92 | h, w = data_gt['que_imgs_info']['imgs'].shape[2:] 93 | else: 94 | h, w = data_pr['que_imgs_info']['imgs'].shape[2:] 95 | def get_img(key): 96 | rgbs = data_pr[key] # 1,rn,3 97 | rgbs = rgbs.reshape([h,w,3]).detach().cpu().numpy() 98 | rgbs = color_map_backward(rgbs) 99 | return rgbs 100 | 101 | outputs={} 102 | imgs=[get_img('pixel_colors_gt'), get_img('pixel_colors_nr')] 103 | if 'pixel_colors_dr' in data_pr: imgs.append(get_img('pixel_colors_dr')) 104 | if 'pixel_colors_nr_fine' in data_pr: imgs.append(get_img('pixel_colors_nr_fine')) 105 | if 'pixel_colors_dr_fine' in data_pr: imgs.append(get_img('pixel_colors_dr_fine')) 106 | 107 | data_index=kwargs['data_index'] 108 | model_name=kwargs['model_name'] 109 | Path(f'data/vis_val/{model_name}').mkdir(exist_ok=True, parents=True) 110 | if h<=64 and w<=64: 111 | 
imsave(f'data/vis_val/{model_name}/step-{step}-index-{data_index}.png',concat_images_list(*imgs)) 112 | else: 113 | imsave(f'data/vis_val/{model_name}/step-{step}-index-{data_index}.jpg', concat_images_list(*imgs)) 114 | return outputs 115 | 116 | name2metrics={ 117 | 'psnr_ssim': PSNR_SSIM, 118 | 'vis_img': VisualizeImage, 119 | } 120 | 121 | def psnr_nr(results): 122 | return np.mean(results['psnr_nr']) 123 | 124 | def psnr_nr_fine(results): 125 | return np.mean(results['psnr_nr_fine']) 126 | 127 | def depth_mae(results): 128 | return np.mean(results['depth_mae']) 129 | 130 | def sdf_mae(results): 131 | return np.mean(results['sdf_mae']) 132 | 133 | def loss_vgn(results): 134 | if 'loss_vgn' in results: 135 | return np.mean(results['loss_vgn']) 136 | else: 137 | return 1e6 138 | 139 | name2key_metrics={ 140 | 'psnr_nr': psnr_nr, 141 | 'psnr_nr_fine': psnr_nr_fine, 142 | 'depth_mae': depth_mae, 143 | 'loss_vgn': loss_vgn, 144 | 'sdf_mae': sdf_mae 145 | } -------------------------------------------------------------------------------- /src/nr/utils/grasp_utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from pathlib import Path 3 | import numpy as np 4 | from scipy import ndimage 5 | import sys 6 | sys.path.append("./src") 7 | import time 8 | 9 | from nr.utils.base_utils import color_map_forward 10 | from nr.utils.draw_utils import draw_cube, extract_surface_points_from_volume 11 | from gd.utils.transform import Transform, Rotation 12 | from skimage.io import imsave 13 | import cv2 14 | 15 | class Grasp(object): 16 | """Grasp parameterized as pose of a 2-finger robot hand. 17 | 18 | TODO(mbreyer): clarify definition of grasp frame 19 | """ 20 | 21 | def __init__(self, pose, width): 22 | self.pose = pose 23 | self.width = width 24 | 25 | 26 | def to_voxel_coordinates(grasp, voxel_size): 27 | pose = grasp.pose 28 | pose.translation /= voxel_size 29 | width = grasp.width / voxel_size 30 | return Grasp(pose, width) 31 | 32 | 33 | def from_voxel_coordinates(grasp, voxel_size): 34 | pose = grasp.pose 35 | pose.translation *= voxel_size 36 | width = grasp.width * voxel_size 37 | return Grasp(pose, width) 38 | 39 | 40 | def process( 41 | tsdf_vol, 42 | qual_vol, 43 | rot_vol, 44 | width_vol, 45 | gaussian_filter_sigma=1.0, 46 | min_width=0, 47 | max_width=12, 48 | ): 49 | tsdf_vol = tsdf_vol.squeeze() 50 | qual_vol = qual_vol.squeeze() 51 | rot_vol = rot_vol.squeeze() 52 | width_vol = width_vol.squeeze() 53 | # smooth quality volume with a Gaussian 54 | qual_vol = ndimage.gaussian_filter( 55 | qual_vol, sigma=gaussian_filter_sigma, mode="nearest" 56 | ) 57 | 58 | # mask out voxels too far away from the surface 59 | outside_voxels = tsdf_vol > 0.1 60 | inside_voxels = np.logical_and(-1 < tsdf_vol, tsdf_vol < -0.1) 61 | valid_voxels = ndimage.morphology.binary_dilation( 62 | outside_voxels, iterations=2, mask=np.logical_not(inside_voxels) 63 | ) 64 | qual_vol[valid_voxels == False] = 0.0 65 | # reject voxels with predicted widths that are too small or too large 66 | qual_vol[np.logical_or(width_vol < min_width, width_vol > max_width)] = 0.0 67 | 68 | return tsdf_vol, qual_vol, rot_vol, width_vol 69 | 70 | def select_index(qual_vol, rot_vol, width_vol, index): 71 | i, j, k = index 72 | score = qual_vol[i, j, k] 73 | ori = Rotation.from_quat(rot_vol[:, i, j, k]) 74 | pos = np.array([i, j, k], dtype=np.float64) 75 | width = width_vol[i, j, k] 76 | return Grasp(Transform(ori, pos), width), score 77 | 78 | def select(qual_vol, rot_vol, 
width_vol, threshold=0.90, max_filter_size=4): 79 | # threshold on grasp quality 80 | qual_vol[qual_vol < threshold] = 0.0 81 | 82 | # non maximum suppression 83 | max_vol = ndimage.maximum_filter(qual_vol, size=max_filter_size) 84 | qual_vol = np.where(qual_vol == max_vol, qual_vol, 0.0) 85 | mask = np.where(qual_vol, 1.0, 0.0) 86 | 87 | # construct grasps 88 | grasps, scores = [], [] 89 | for index in np.argwhere(mask): 90 | grasp, score = select_index(qual_vol, rot_vol, width_vol, index) 91 | grasps.append(grasp) 92 | scores.append(score) 93 | 94 | return grasps, scores 95 | 96 | 97 | def sim_grasp(database, alp_vol, qual_vol, rot_vol, width_vol, top_k=10): 98 | from utils.grasp_utils import select, process 99 | qual_vol, rot_vol, width_vol = process(alp_vol, qual_vol, rot_vol, width_vol) 100 | grasps, scores = select(qual_vol.copy(), rot_vol, width_vol) 101 | grasps, scores = np.asarray(grasps), np.asarray(scores) 102 | 103 | img = None 104 | if len(grasps) > 0: 105 | p = np.argsort(scores)[::-1][:top_k] 106 | grasps = [g for g in grasps[p]] 107 | scores = scores[p] 108 | pos = np.array([ g.pose.translation for g in grasps ]) 109 | rot = np.array([ g.pose.rotation.as_matrix() for g in grasps ]) 110 | width = np.array([ g.width for g in grasps ]) 111 | 112 | img = database.visualize_grasping(pos, rot, width) 113 | database.visualize_grasping_3d(pos, rot, width, scores) 114 | 115 | 116 | return grasps, scores, img 117 | 118 | 119 | def run_real(run_id, model, images: list, extrinsics: list, intrinsic, save_img=True): 120 | extrinsics = np.stack(extrinsics, 0) 121 | intrinsics = np.repeat(np.expand_dims(intrinsic, 0), extrinsics.shape[0], axis=0) 122 | depth_range = np.repeat(np.expand_dims(np.r_[0.2, 0.8], 0), extrinsics.shape[0], axis=0).astype(np.float32) 123 | bbox3d = [[-0.15, -0.15, 0.00], [0.15, 0.15, 0.3]] 124 | 125 | if save_img: 126 | save_path = f'data/grasp_capture/{run_id}' 127 | if not Path(save_path).exists(): 128 | Path(save_path).mkdir(parents=True) 129 | for i, img in enumerate(images): 130 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 131 | # img = draw_cube(img, extrinsics[i][:3,:3], extrinsics[i][:3,3], intrinsics[i], length=0.3, bias=bbox3d[0]) 132 | cv2.imwrite(f"{save_path}/{i}.png", img) 133 | 134 | images = color_map_forward(np.stack(images, 0)).transpose([0, 3, 1, 2]) 135 | 136 | t0 = time.time() 137 | tsdf_vol, qual_vol, rot_vol, width_vol = model(images, extrinsics, intrinsics, depth_range=depth_range, bbox3d=bbox3d, que_id=3) 138 | t = time.time() - t0 139 | 140 | tsdf_vol, qual_vol, rot_vol, width_vol = process(tsdf_vol, qual_vol, rot_vol, width_vol) 141 | grasps, scores = select(qual_vol.copy(), rot_vol, width_vol) 142 | grasps, scores = np.asarray(grasps), np.asarray(scores) 143 | 144 | if len(grasps) > 0: 145 | p = np.random.permutation(len(grasps)) 146 | grasps = [from_voxel_coordinates(g, 0.3 / 40) for g in grasps[p]] 147 | scores = scores[p] 148 | 149 | pc = extract_surface_points_from_volume(tsdf_vol, (-0.2, 0.2)) 150 | 151 | return grasps, scores, tsdf_vol, pc, t -------------------------------------------------------------------------------- /src/nr/train/train_tools.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | from collections import OrderedDict 4 | 5 | import torch 6 | import numpy as np 7 | from torch.utils.tensorboard import SummaryWriter 8 | import torch.nn as nn 9 | 10 | 11 | def load_model(model, optim, model_dir, epoch=-1): 12 | if not os.path.exists(model_dir): 13 
| return 0 14 | 15 | pths = [int(pth.split('.')[0]) for pth in os.listdir(model_dir)] 16 | if len(pths) == 0: 17 | return 0 18 | if epoch == -1: 19 | pth = max(pths) 20 | else: 21 | pth = epoch 22 | 23 | pretrained_model = torch.load(os.path.join(model_dir, '{}.pth'.format(pth))) 24 | model.load_state_dict(pretrained_model['net']) 25 | optim.load_state_dict(pretrained_model['optim']) 26 | print('load {} epoch {}'.format(model_dir, pretrained_model['epoch'] + 1)) 27 | return pretrained_model['epoch'] + 1 28 | 29 | def adjust_learning_rate(optimizer, epoch, lr_decay_rate, lr_decay_epoch, min_lr=1e-5): 30 | if ((epoch + 1) % lr_decay_epoch) != 0: 31 | return 32 | 33 | for param_group in optimizer.param_groups: 34 | # print(param_group) 35 | lr_before = param_group['lr'] 36 | param_group['lr'] = param_group['lr'] * lr_decay_rate 37 | param_group['lr'] = max(param_group['lr'], min_lr) 38 | 39 | print('changing learning rate {:5f} to {:.5f}'.format(lr_before, max(param_group['lr'], min_lr))) 40 | 41 | def reset_learning_rate(optimizer, lr): 42 | for param_group in optimizer.param_groups: 43 | # print(param_group) 44 | # lr_before = param_group['lr'] 45 | param_group['lr'] = lr 46 | # print('changing learning rate {:5f} to {:.5f}'.format(lr_before,lr)) 47 | return lr 48 | 49 | def save_model(net, optim, epoch, model_dir): 50 | os.system('mkdir -p {}'.format(model_dir)) 51 | torch.save({ 52 | 'net': net.feats_state_dict(), 53 | 'optim': optim.feats_state_dict(), 54 | 'epoch': epoch 55 | }, os.path.join(model_dir, '{}.pth'.format(epoch))) 56 | 57 | class Recorder(object): 58 | def __init__(self, rec_dir, rec_fn): 59 | self.rec_dir = rec_dir 60 | self.rec_fn = rec_fn 61 | self.data = OrderedDict() 62 | self.writer = SummaryWriter(log_dir=rec_dir) 63 | 64 | def rec_loss(self, losses_batch, step, epoch, prefix='train', dump=False): 65 | for k, v in losses_batch.items(): 66 | name = '{}/{}'.format(prefix, k) 67 | if name in self.data: 68 | self.data[name].append(v) 69 | else: 70 | self.data[name] = [v] 71 | 72 | if dump: 73 | if prefix == 'train': 74 | msg = '{} epoch {} step {} '.format(prefix, epoch, step) 75 | else: 76 | msg = '{} epoch {} '.format(prefix, epoch) 77 | for k, v in self.data.items(): 78 | if not k.startswith(prefix): continue 79 | if len(v) > 0: 80 | msg += '{} {:.5f} '.format(k.split('/')[-1], np.mean(v)) 81 | self.writer.add_scalar(k, np.mean(v), step) 82 | self.data[k] = [] 83 | 84 | print(msg) 85 | with open(self.rec_fn, 'a') as f: 86 | f.write(msg + '\n') 87 | 88 | def rec_msg(self, msg): 89 | print(msg) 90 | with open(self.rec_fn, 'a') as f: 91 | f.write(msg + '\n') 92 | 93 | 94 | class Logger: 95 | def __init__(self, log_dir): 96 | self.log_dir=log_dir 97 | self.data = OrderedDict() 98 | self.writer = SummaryWriter(log_dir=log_dir + "/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")) 99 | 100 | def log(self,data, prefix='train',step=None,verbose=False): 101 | msg=f'{prefix} ' 102 | for k, v in data.items(): 103 | msg += f'{k} {v:.5f} ' 104 | self.writer.add_scalar(f'{prefix}/{k}',v,step) 105 | 106 | if verbose: 107 | print(msg) 108 | with open(os.path.join(self.log_dir,f'{prefix}.txt'), 'a') as f: 109 | f.write(msg + '\n') 110 | 111 | def print_shape(obj): 112 | if type(obj) == list or type(obj) == tuple: 113 | shapes = [item.shape for item in obj] 114 | print(shapes) 115 | else: 116 | print(obj.shape) 117 | 118 | def overwrite_configs(cfg_base: dict, cfg: dict): 119 | keysNotinBase = [] 120 | for key in cfg.keys(): 121 | if key in cfg_base.keys(): 122 | cfg_base[key] = 
cfg[key] 123 | else: 124 | keysNotinBase.append(key) 125 | cfg_base.update({key: cfg[key]}) 126 | if len(keysNotinBase) != 0: 127 | print('==== WARNING: These keys are not set in DEFAULT_BASE_CONFIG... ====') 128 | print(keysNotinBase) 129 | return cfg_base 130 | 131 | def to_cuda(data): 132 | if type(data)==list: 133 | results = [] 134 | for i, item in enumerate(data): 135 | results.append(to_cuda(item)) 136 | return results 137 | elif type(data)==dict: 138 | results={} 139 | for k,v in data.items(): 140 | results[k]=to_cuda(v) 141 | return results 142 | elif type(data).__name__ == "Tensor": 143 | return data.cuda() 144 | else: 145 | return data 146 | 147 | def dim_extend(data_list): 148 | results = [] 149 | for i, tensor in enumerate(data_list): 150 | results.append(tensor[None,...]) 151 | return results 152 | 153 | class MultiGPUWrapper(nn.Module): 154 | def __init__(self,network,losses): 155 | super().__init__() 156 | self.network=network 157 | self.losses=losses 158 | 159 | def forward(self, data_gt): 160 | results={} 161 | data_pr=self.network(data_gt) 162 | results.update(data_pr) 163 | for loss in self.losses: 164 | results.update(loss(data_pr,data_gt,data_gt['step'])) 165 | return results 166 | 167 | class DummyLoss: 168 | def __init__(self,losses): 169 | self.keys=[] 170 | for loss in losses: 171 | self.keys+=loss.keys 172 | 173 | def __call__(self, data_pr, data_gt, step): 174 | return {key: data_pr[key] for key in self.keys} 175 | -------------------------------------------------------------------------------- /src/nr/network/dist_decoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | from network.ops import AddBias 5 | 6 | def get_near_far_points(depth, interval, depth_range, is_ref, fixed_interval=False, fixed_interval_val=0.01): 7 | """ is_ref | not is_ref 8 | :param depth: [...,dn] rfn,qn,rn,dn or qn,rn,dn 9 | :param interval: [...,dn] 1,qn,rn,dn or qn,rn,dn 10 | :param depth_range: rfn,2 or qn,2 11 | :param is_ref: 12 | :param fixed_interval: 13 | :param fixed_interval_val: 14 | :return: near far [rfn,qn,rn,dn] or [qn,rn,dn] 15 | """ 16 | if is_ref: 17 | ref_near = depth_range[:, 0] 18 | ref_far = depth_range[:, 1] 19 | ref_near = -1 / ref_near[:, None, None, None] 20 | ref_far = -1 / ref_far[:, None, None, None] 21 | depth = torch.clamp(depth, min=1e-5) 22 | depth = -1 / depth 23 | depth = (depth - ref_near) / (ref_far - ref_near) 24 | else: 25 | que_near = depth_range[:, 0] # qn 26 | que_far = depth_range[:, 1] # qn 27 | que_near = -1 / que_near[:, None, None] 28 | que_far = -1 / que_far[:, None, None] 29 | depth = torch.clamp(depth, min=1e-5) 30 | depth = -1 / depth 31 | depth = (depth - que_near) / (que_far - que_near) 32 | 33 | if not fixed_interval: 34 | if is_ref: 35 | interval_half = interval / 2 36 | interval_ext = torch.cat([interval_half[..., 0:1], interval_half], -1) 37 | near = depth - interval_ext[..., :-1] 38 | far = depth + interval_ext[..., 1:] 39 | else: 40 | interval_half = interval / 2 41 | first = depth[..., 0] - interval_half[..., 0] 42 | last = depth[..., -1] + interval_half[..., -1] 43 | depth_ext = (depth[..., :-1] + depth[..., 1:]) / 2 44 | depth_ext = torch.cat([first[..., None], depth_ext, last[..., None]], -1) 45 | near = depth_ext[..., :-1] 46 | far = depth_ext[..., 1:] 47 | else: 48 | near = depth - fixed_interval_val/2 49 | far = depth + fixed_interval_val/2 50 | 51 | return near, far 52 | 53 | class MixtureLogisticsDistDecoder(nn.Module): 54 | 
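"""Decode per-ray features into a two-component mixture-of-logistics over depth.

    (Added description, inferred from the code below.) The heads predict the two
    mixture means, their inverse scales (`var_decoder`, offset by `bias_val`), a
    blend weight (`aw_decoder`), and optionally a visibility term (`vis_decoder`);
    `compute_prob` converts these into per-interval hit probabilities and
    visibilities using the logistic CDF `0.5 + 0.5 * tanh(x)`.
    """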
default_cfg={ 55 | 'feats_dim': 32, 56 | 'bias_val': 0.05, 57 | "use_vis": True, 58 | } 59 | def __init__(self,cfg): 60 | super().__init__() 61 | self.cfg={**self.default_cfg,**cfg} 62 | ray_feats_dim = self.cfg["feats_dim"] 63 | run_dim = ray_feats_dim 64 | self.mean_decoder=nn.Sequential( 65 | nn.Linear(ray_feats_dim, run_dim), 66 | nn.ELU(), 67 | nn.Linear(run_dim, run_dim), 68 | nn.ELU(), 69 | nn.Linear(run_dim, 2), 70 | nn.Softplus() 71 | ) 72 | self.var_decoder=nn.Sequential( 73 | nn.Linear(ray_feats_dim, run_dim), 74 | nn.ELU(), 75 | nn.Linear(run_dim, run_dim), 76 | nn.ELU(), 77 | nn.Linear(run_dim, 2), 78 | nn.Softplus(), 79 | AddBias(self.cfg['bias_val']), 80 | ) 81 | self.aw_decoder=nn.Sequential( 82 | nn.Linear(ray_feats_dim, run_dim), 83 | nn.ELU(), 84 | nn.Linear(run_dim, run_dim), 85 | nn.ELU(), 86 | nn.Linear(run_dim, 1), 87 | nn.Sigmoid(), 88 | ) 89 | if self.cfg['use_vis']: 90 | self.vis_decoder=nn.Sequential( 91 | nn.Linear(ray_feats_dim, run_dim), 92 | nn.ELU(), 93 | nn.Linear(run_dim, run_dim), 94 | nn.ELU(), 95 | nn.Linear(run_dim, 1), 96 | nn.Sigmoid(), 97 | ) 98 | 99 | def forward(self, feats): 100 | prj_mean = self.mean_decoder(feats) 101 | prj_var = self.var_decoder(feats) 102 | prj_aw = self.aw_decoder(feats) 103 | if self.cfg['use_vis']: 104 | prj_vis = self.vis_decoder(feats) 105 | else: 106 | prj_vis = None 107 | return prj_mean, prj_var, prj_vis, prj_aw 108 | 109 | def compute_prob(self, depth, interval, mean, var, vis, aw, is_ref, depth_range): 110 | """ 111 | :param depth: [...,dn] rfn,qn,rn,dn or qn,rn,dn 112 | :param interval: [...,dn] 1,qn,rn,dn or qn,rn,dn 113 | :param mean: [...,1 or dn] rfn,qn,rn,dn,2 or qn,rn,1,2 114 | :param var: [...,1 or dn] rfn,qn,rn,dn,2 or qn,rn,1,2 115 | :param vis: [...,1 or dn] rfn,qn,rn,dn,1 or qn,rn,1,1 116 | :param aw: [...,1 or dn] rfn,qn,rn,dn,1 or qn,rn,1,1 117 | :param is_ref: 118 | :param depth_range: rfn,2 or qn,2 119 | :return: 120 | """ 121 | if interval.shape != (1,0): 122 | near, far = get_near_far_points(depth, interval, depth_range, is_ref) 123 | else: 124 | near, far = get_near_far_points(depth, interval, depth_range, is_ref, fixed_interval=True, fixed_interval_val=0.01) 125 | # near and far [rfn,qn,rn,dn] or [qn,rn,dn] 126 | mix = torch.cat([aw, 1 - aw],-1) # [...,2] 127 | near, far = near[...,None], far[...,None] 128 | 129 | d0 = (near - mean) * var # [...,2] 130 | d1 = (far - mean) * var # [...,2] 131 | cdf0 = (0.5 + 0.5 * torch.tanh(d0)) # t(z_i) 132 | cdf1 = (0.5 + 0.5 * torch.tanh(d1)) # t(z_{i+1}) 133 | if self.cfg['use_vis']: 134 | cdf0, cdf1 = cdf0 * vis, cdf1 * vis 135 | visibility = 1 - cdf0 136 | hit_prob = cdf1 - cdf0 137 | visibility = torch.sum(visibility*mix, -1) 138 | hit_prob = torch.sum(hit_prob*mix, -1) 139 | 140 | eps = 1e-5 141 | alpha_value = torch.log(hit_prob / (visibility - hit_prob + eps) + eps) 142 | return alpha_value, visibility, hit_prob 143 | 144 | def decode_alpha_value(self, alpha_value): 145 | alpha_value = torch.sigmoid(alpha_value) 146 | return alpha_value 147 | 148 | def predict_mean(self,prj_ray_feats): 149 | prj_mean = self.mean_decoder(prj_ray_feats) 150 | return prj_mean 151 | 152 | def predict_aw(self,prj_ray_feats): 153 | return self.aw_decoder(prj_ray_feats) 154 | 155 | 156 | name2dist_decoder={ 157 | 'mixture_logistics': MixtureLogisticsDistDecoder 158 | } -------------------------------------------------------------------------------- /src/gd/vis.py: -------------------------------------------------------------------------------- 1 | """Render volumes, point 
clouds, and grasp detections in rviz.""" 2 | 3 | import matplotlib.colors 4 | import numpy as np 5 | from sensor_msgs.msg import PointCloud2 6 | import rospy 7 | from rospy import Publisher 8 | from visualization_msgs.msg import Marker, MarkerArray 9 | 10 | from gd.utils import ros_utils, workspace_lines 11 | from gd.utils.transform import Transform, Rotation 12 | 13 | 14 | cmap = matplotlib.colors.LinearSegmentedColormap.from_list("RedGreen", ["r", "g"]) 15 | DELETE_MARKER_MSG = Marker(action=Marker.DELETEALL) 16 | DELETE_MARKER_ARRAY_MSG = MarkerArray(markers=[DELETE_MARKER_MSG]) 17 | 18 | 19 | def draw_workspace(size): 20 | scale = size * 0.005 21 | pose = Transform.identity() 22 | scale = [scale, 0.0, 0.0] 23 | color = [0.5, 0.5, 0.5] 24 | msg = _create_marker_msg(Marker.LINE_LIST, "task", pose, scale, color) 25 | msg.points = [ros_utils.to_point_msg(point) for point in workspace_lines(size)] 26 | pubs["workspace"].publish(msg) 27 | 28 | 29 | def draw_tsdf(vol, voxel_size, threshold=0.01): 30 | msg = _create_vol_msg(vol, voxel_size, threshold) 31 | pubs["tsdf"].publish(msg) 32 | 33 | 34 | def draw_points(points): 35 | msg = ros_utils.to_cloud_msg(points, frame="task") 36 | pubs["points"].publish(msg) 37 | 38 | 39 | def draw_quality(vol, voxel_size, threshold=0.01): 40 | msg = _create_vol_msg(vol, voxel_size, threshold) 41 | pubs["quality"].publish(msg) 42 | 43 | 44 | def draw_volume(vol, voxel_size, threshold=0.01): 45 | msg = _create_vol_msg(vol, voxel_size, threshold) 46 | pubs["debug"].publish(msg) 47 | 48 | 49 | def draw_grasp(grasp, score, finger_depth): 50 | radius = 0.1 * finger_depth 51 | w, d = grasp.width, finger_depth 52 | color = cmap(float(score)) 53 | 54 | markers = [] 55 | 56 | # left finger 57 | pose = grasp.pose * Transform(Rotation.identity(), [0.0, -w / 2, d / 2]) 58 | scale = [radius, radius, d] 59 | msg = _create_marker_msg(Marker.CYLINDER, "task", pose, scale, color) 60 | msg.id = 0 61 | markers.append(msg) 62 | 63 | # right finger 64 | pose = grasp.pose * Transform(Rotation.identity(), [0.0, w / 2, d / 2]) 65 | scale = [radius, radius, d] 66 | msg = _create_marker_msg(Marker.CYLINDER, "task", pose, scale, color) 67 | msg.id = 1 68 | markers.append(msg) 69 | 70 | # wrist 71 | pose = grasp.pose * Transform(Rotation.identity(), [0.0, 0.0, -d / 4]) 72 | scale = [radius, radius, d / 2] 73 | msg = _create_marker_msg(Marker.CYLINDER, "task", pose, scale, color) 74 | msg.id = 2 75 | markers.append(msg) 76 | 77 | # palm 78 | pose = grasp.pose * Transform( 79 | Rotation.from_rotvec(np.pi / 2 * np.r_[1.0, 0.0, 0.0]), [0.0, 0.0, 0.0] 80 | ) 81 | scale = [radius, radius, w] 82 | msg = _create_marker_msg(Marker.CYLINDER, "task", pose, scale, color) 83 | msg.id = 3 84 | markers.append(msg) 85 | 86 | pubs["grasp"].publish(MarkerArray(markers=markers)) 87 | 88 | 89 | def draw_grasps(grasps, scores, finger_depth): 90 | markers = [] 91 | for i, (grasp, score) in enumerate(zip(grasps, scores)): 92 | msg = _create_grasp_marker_msg(grasp, score, finger_depth) 93 | msg.id = i 94 | markers.append(msg) 95 | msg = MarkerArray(markers=markers) 96 | pubs["grasps"].publish(msg) 97 | 98 | 99 | def clear(): 100 | pubs["workspace"].publish(DELETE_MARKER_MSG) 101 | pubs["tsdf"].publish(ros_utils.to_cloud_msg(np.array([]), frame="task")) 102 | pubs["points"].publish(ros_utils.to_cloud_msg(np.array([]), frame="task")) 103 | clear_quality() 104 | pubs["grasp"].publish(DELETE_MARKER_ARRAY_MSG) 105 | clear_grasps() 106 | pubs["debug"].publish(ros_utils.to_cloud_msg(np.array([]), frame="task")) 107 | 
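# ----------------------------------------------------------------------------
# (Added, illustrative only.) A minimal sketch of how the helpers above are
# typically chained after a planning step: clear stale markers, then publish
# the workspace box, TSDF, point cloud and grasp candidates. The argument
# names used here are assumptions, not names used elsewhere in this repository.
def _example_draw_scene(size, voxel_size, tsdf_grid, points, grasps, scores, finger_depth):
    clear()
    draw_workspace(size)
    draw_tsdf(tsdf_grid, voxel_size)
    draw_points(points)
    draw_grasps(grasps, scores, finger_depth)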
108 | 109 | def clear_quality(): 110 | pubs["quality"].publish(ros_utils.to_cloud_msg(np.array([]), frame="task")) 111 | 112 | 113 | def clear_grasps(): 114 | pubs["grasps"].publish(DELETE_MARKER_ARRAY_MSG) 115 | 116 | 117 | def _create_publishers(): 118 | pubs = dict() 119 | pubs["workspace"] = Publisher("/workspace", Marker, queue_size=1, latch=True) 120 | pubs["tsdf"] = Publisher("/tsdf", PointCloud2, queue_size=1, latch=True) 121 | pubs["points"] = Publisher("/points", PointCloud2, queue_size=1, latch=True) 122 | pubs["quality"] = Publisher("/quality", PointCloud2, queue_size=1, latch=True) 123 | pubs["grasp"] = Publisher("/grasp", MarkerArray, queue_size=1, latch=True) 124 | pubs["grasps"] = Publisher("/grasps", MarkerArray, queue_size=1, latch=True) 125 | pubs["debug"] = Publisher("/debug", PointCloud2, queue_size=1, latch=True) 126 | return pubs 127 | 128 | 129 | def _create_marker_msg(marker_type, frame, pose, scale, color): 130 | msg = Marker() 131 | msg.header.frame_id = frame 132 | msg.header.stamp = rospy.Time() 133 | msg.type = marker_type 134 | msg.action = Marker.ADD 135 | msg.pose = ros_utils.to_pose_msg(pose) 136 | msg.scale = ros_utils.to_vector3_msg(scale) 137 | msg.color = ros_utils.to_color_msg(color) 138 | return msg 139 | 140 | 141 | def _create_vol_msg(vol, voxel_size, threshold): 142 | vol = vol.squeeze() 143 | if type(threshold) is tuple: 144 | idx_arr = np.logical_and(vol > threshold[0], vol < threshold[1]) 145 | else: 146 | idx_arr = vol > threshold 147 | 148 | points = np.argwhere(idx_arr) * voxel_size 149 | values = np.expand_dims(vol[idx_arr], 1) 150 | return ros_utils.to_cloud_msg(points, values, frame="task") 151 | 152 | 153 | def _create_grasp_marker_msg(grasp, score, finger_depth): 154 | radius = 0.1 * finger_depth 155 | w, d = grasp.width, finger_depth 156 | scale = [radius, 0.0, 0.0] 157 | color = list(cmap(float(score))) 158 | if score < 0.01: 159 | color = [0., 0., 1, 0.8] 160 | msg = _create_marker_msg(Marker.LINE_LIST, "task", grasp.pose, scale, color) 161 | msg.points = [ros_utils.to_point_msg(point) for point in _gripper_lines(w, d)] 162 | return msg 163 | 164 | 165 | def _gripper_lines(width, depth): 166 | return [ 167 | [0.0, 0.0, -depth / 2.0], 168 | [0.0, 0.0, 0.0], 169 | [0.0, -width / 2.0, 0.0], 170 | [0.0, -width / 2.0, depth], 171 | [0.0, width / 2.0, 0.0], 172 | [0.0, width / 2.0, depth], 173 | [0.0, -width / 2.0, 0.0], 174 | [0.0, width / 2.0, 0.0], 175 | ] 176 | 177 | 178 | pubs = _create_publishers() 179 | -------------------------------------------------------------------------------- /src/nr/network/aggregate_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from easydict import EasyDict 5 | import numpy as np 6 | 7 | from network.ibrnet import IBRNetWithNeuRay, IBRNetWithNeuRayNeus 8 | from network.neus import SingleVarianceNetwork 9 | 10 | 11 | def get_dir_diff(prj_dir,que_dir): 12 | rfn, qn, rn, dn, _ = prj_dir.shape 13 | dir_diff = prj_dir - que_dir.unsqueeze(0) # rfn,qn,rn,dn,3 14 | dir_dot = torch.sum(prj_dir * que_dir.unsqueeze(0), -1, keepdim=True) 15 | dir_diff = torch.cat([dir_diff, dir_dot], -1) # rfn,qn,rn,dn,4 16 | dir_diff = dir_diff.reshape(rfn, qn * rn, dn, -1).permute(1, 2, 0, 3) 17 | return dir_diff 18 | 19 | class BaseAggregationNet(nn.Module): 20 | default_cfg={ 21 | 'sample_num': 64, 22 | 'neuray_dim': 32, 23 | 'use_img_feats': False, 24 | } 25 | def __init__(self, cfg): 26 | 
super().__init__() 27 | self.cfg={**self.default_cfg, **cfg} 28 | dim = self.cfg['neuray_dim'] 29 | self.prob_embed = nn.Sequential( 30 | nn.Linear(2+32, dim), 31 | nn.ReLU(), 32 | nn.Linear(dim, dim), 33 | ) 34 | 35 | def _get_embedding(self, prj_dict, que_dir): 36 | """ 37 | :param prj_dict 38 | prj_ray_feats: rfn,qn,rn,dn,f 39 | prj_hit_prob: rfn,qn,rn,dn,1 40 | prj_vis: rfn,qn,rn,dn,1 41 | prj_alpha: rfn,qn,rn,dn,1 42 | prj_rgb: rfn,qn,rn,dn,3 43 | prj_dir: rfn,qn,rn,dn,3 44 | :param que_dir: qn,rn,dn,3 45 | :return: qn,rn,dn 46 | """ 47 | hit_prob_val = (prj_dict['hit_prob']-0.5)*2 48 | vis_val = (prj_dict['vis']-0.5)*2 49 | 50 | prj_hit_prob, prj_vis, prj_rgb, prj_dir, prj_ray_feats = \ 51 | hit_prob_val, vis_val, prj_dict['rgb'], prj_dict['dir'], prj_dict['ray_feats'] 52 | rfn,qn,rn,dn,_ = hit_prob_val.shape 53 | 54 | prob_embedding = self.prob_embed(torch.cat([prj_ray_feats, prj_hit_prob, prj_vis],-1)) 55 | 56 | if que_dir is not None: 57 | dir_diff = get_dir_diff(prj_dir, que_dir) 58 | else: 59 | _,qn,rn,dn,_ = prj_hit_prob.shape 60 | dir_diff = torch.zeros((rfn, qn * rn, dn, 4)).permute(1, 2, 0, 3).to(prj_hit_prob.device) 61 | 62 | valid_mask = prj_dict['mask'] 63 | valid_mask = valid_mask.float() # rfn,qn,rn,dn 64 | valid_mask = valid_mask.reshape(rfn, qn * rn, dn, -1).permute(1, 2, 0, 3) 65 | 66 | prj_img_feats = prj_dict['img_feats'] 67 | prj_img_feats = torch.cat([prj_rgb, prj_img_feats], -1) 68 | prj_img_feats = prj_img_feats.reshape(rfn, qn * rn, dn, -1).permute(1, 2, 0, 3) 69 | prob_embedding = prob_embedding.reshape(rfn, qn * rn, dn, -1).permute(1, 2, 0, 3) 70 | return prj_img_feats, prob_embedding, dir_diff, valid_mask 71 | 72 | class DefaultAggregationNet(BaseAggregationNet): 73 | def __init__(self,cfg): 74 | super().__init__(cfg) 75 | dim = self.cfg['neuray_dim'] 76 | self.agg_impl = IBRNetWithNeuRay(dim,n_samples=self.cfg['sample_num']) 77 | 78 | def forward(self, prj_dict, que_dir, que_pts=None, que_dists=None): 79 | qn,rn,dn,_ = que_dir.shape 80 | prj_img_feats, prob_embedding, dir_diff, valid_mask = self._get_embedding(prj_dict, que_dir) 81 | outs = self.agg_impl(prj_img_feats, prob_embedding, dir_diff, valid_mask) 82 | colors = outs[...,:3] # qn*rn,dn,3 83 | density = outs[...,3] # qn*rn,dn,0 84 | return density.reshape(qn,rn,dn), colors.reshape(qn,rn,dn,3) 85 | 86 | 87 | class NeusAggregationNet(BaseAggregationNet): 88 | neus_default_cfg = { 89 | 'cos_anneal_end_iter': 0, 90 | 'init_s': 0.3, 91 | 'fix_s': False 92 | } 93 | def __init__(self,cfg): 94 | cfg = {**self.neus_default_cfg, **cfg} 95 | super().__init__(cfg) 96 | dim = self.cfg['neuray_dim'] 97 | self.agg_impl = IBRNetWithNeuRayNeus(dim,n_samples=self.cfg['sample_num']) 98 | self.deviation_network = SingleVarianceNetwork(self.cfg['init_s'], self.cfg['fix_s']) 99 | self.step = 0 100 | self.cos_anneal_ratio = 1.0 101 | 102 | def _update_cos_anneal_ratio(self): 103 | self.cos_anneal_ratio = np.min([1.0, self.step / self.cfg['cos_anneal_end_iter']]) 104 | 105 | def _get_alpha_from_sdf(self, sdf, grad, que_dir, que_dists): 106 | qn,rn,dn,_ = que_dir.shape 107 | inv_s = self.deviation_network(torch.zeros([1, 3], device=sdf.device))[:, :1].clip(1e-6, 1e6) # Single parameter 108 | inv_s = inv_s.expand(qn*rn, dn) 109 | true_cos = (-que_dir * grad).sum(-1, keepdim=True) 110 | # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes 111 | # the cos value "not dead" at the beginning training iterations, for better convergence. 
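# (Added note.) The remainder of this function follows the NeuS opacity
        # derivation: the SDF is extrapolated to both ends of each ray section with
        # the annealed cosine, both estimates are squashed by a sigmoid scaled by the
        # learned inverse standard deviation `inv_s`, and each section's alpha is
        # clip((prev_cdf - next_cdf) / prev_cdf, 0, 1), up to small epsilons.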
112 | iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - self.cos_anneal_ratio) + 113 | F.relu(-true_cos) * self.cos_anneal_ratio)[0].squeeze(-1) # always non-positive 114 | # Estimate signed distances at section points 115 | estimated_next_sdf = sdf + iter_cos * que_dists[0] * 0.5 116 | estimated_prev_sdf = sdf - iter_cos * que_dists[0] * 0.5 117 | prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s) 118 | next_cdf = torch.sigmoid(estimated_next_sdf * inv_s) 119 | p = prev_cdf - next_cdf 120 | c = prev_cdf 121 | alpha = ((p + 1e-5) / (c + 1e-5)).reshape(qn,rn,dn).clip(0.0, 1.0) 122 | 123 | return alpha 124 | 125 | def forward(self, prj_dict, que_dir, que_pts, que_dists, is_train): 126 | if self.cfg['cos_anneal_end_iter'] and is_train: 127 | self._update_cos_anneal_ratio() 128 | qn,rn,dn,_ = que_dir.shape 129 | prj_img_feats, prob_embedding, dir_diff, valid_mask = self._get_embedding(prj_dict, que_dir) 130 | outs, grad = self.agg_impl(prj_img_feats, prob_embedding, dir_diff, valid_mask, que_pts) 131 | colors = outs[...,:3] # qn*rn,dn,3 132 | sdf = outs[...,3] # qn*rn,dn,0 133 | if que_dists is None: 134 | return None, sdf.reshape(qn,rn,dn), colors.reshape(qn,rn,dn,3), None, None 135 | if is_train: 136 | self.step += 1 137 | self.deviation_network.set_step(self.step) 138 | alpha = self._get_alpha_from_sdf(sdf, grad, que_dir, que_dists) 139 | grad_error = torch.mean((torch.linalg.norm(grad.reshape(qn,rn,dn,3), ord=2, dim=-1) - 1.0) ** 2).reshape(1,1) 140 | return alpha.reshape(qn,rn,dn), sdf.reshape(qn,rn,dn), colors.reshape(qn,rn,dn,3), grad_error, self.deviation_network.variance.reshape(1,1) 141 | 142 | 143 | name2agg_net={ 144 | 'default': DefaultAggregationNet, 145 | 'neus': NeusAggregationNet 146 | } -------------------------------------------------------------------------------- /src/nr/utils/imgs_info.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from utils.base_utils import color_map_forward, pad_img_end 5 | 6 | def random_crop(ref_imgs_info, que_imgs_info, target_size): 7 | imgs = ref_imgs_info['imgs'] 8 | n, _, h, w = imgs.shape 9 | out_h, out_w = target_size[0], target_size[1] 10 | if out_w >= w or out_h >= h: 11 | return ref_imgs_info 12 | 13 | center_h = np.random.randint(low=out_h // 2 + 1, high=h - out_h // 2 - 1) 14 | center_w = np.random.randint(low=out_w // 2 + 1, high=w - out_w // 2 - 1) 15 | 16 | def crop(tensor): 17 | tensor = tensor[:, :, center_h - out_h // 2:center_h + out_h // 2, 18 | center_w - out_w // 2:center_w + out_w // 2] 19 | return tensor 20 | 21 | def crop_imgs_info(imgs_info): 22 | imgs_info['imgs'] = crop(imgs_info['imgs']) 23 | if 'depth' in imgs_info: imgs_info['depth'] = crop(imgs_info['depth']) 24 | if 'true_depth' in imgs_info: imgs_info['true_depth'] = crop(imgs_info['true_depth']) 25 | if 'masks' in imgs_info: imgs_info['masks'] = crop(imgs_info['masks']) 26 | 27 | Ks = imgs_info['Ks'] # n, 3, 3 28 | h_init = center_h - out_h // 2 29 | w_init = center_w - out_w // 2 30 | Ks[:,0,2]-=w_init 31 | Ks[:,1,2]-=h_init 32 | imgs_info['Ks']=Ks 33 | return imgs_info 34 | 35 | return crop_imgs_info(ref_imgs_info), crop_imgs_info(que_imgs_info) 36 | 37 | def random_flip(ref_imgs_info,que_imgs_info): 38 | def flip(tensor): 39 | tensor = np.flip(tensor.transpose([0, 2, 3, 1]), 2) # n,h,w,3 40 | tensor = np.ascontiguousarray(tensor.transpose([0, 3, 1, 2])) 41 | return tensor 42 | 43 | def flip_imgs_info(imgs_info): 44 | imgs_info['imgs'] = flip(imgs_info['imgs']) 45 | if 
'depth' in imgs_info: imgs_info['depth'] = flip(imgs_info['depth']) 46 | if 'true_depth' in imgs_info: imgs_info['true_depth'] = flip(imgs_info['true_depth']) 47 | if 'masks' in imgs_info: imgs_info['masks'] = flip(imgs_info['masks']) 48 | 49 | Ks = imgs_info['Ks'] # n, 3, 3 50 | Ks[:, 0, :] *= -1 51 | w = imgs_info['imgs'].shape[-1] 52 | Ks[:, 0, 2] += w - 1 53 | imgs_info['Ks'] = Ks 54 | return imgs_info 55 | 56 | ref_imgs_info = flip_imgs_info(ref_imgs_info) 57 | que_imgs_info = flip_imgs_info(que_imgs_info) 58 | return ref_imgs_info, que_imgs_info 59 | 60 | def pad_imgs_info(ref_imgs_info,pad_interval): 61 | ref_imgs, ref_depths, ref_masks = ref_imgs_info['imgs'], ref_imgs_info['depth'], ref_imgs_info['masks'] 62 | ref_depth_gt = ref_imgs_info['true_depth'] if 'true_depth' in ref_imgs_info else None 63 | rfn, _, h, w = ref_imgs.shape 64 | ph = (pad_interval - (h % pad_interval)) % pad_interval 65 | pw = (pad_interval - (w % pad_interval)) % pad_interval 66 | if ph != 0 or pw != 0: 67 | ref_imgs = np.pad(ref_imgs, ((0, 0), (0, 0), (0, ph), (0, pw)), 'reflect') 68 | ref_depths = np.pad(ref_depths, ((0, 0), (0, 0), (0, ph), (0, pw)), 'reflect') 69 | ref_masks = np.pad(ref_masks, ((0, 0), (0, 0), (0, ph), (0, pw)), 'reflect') 70 | if ref_depth_gt is not None: 71 | ref_depth_gt = np.pad(ref_depth_gt, ((0, 0), (0, 0), (0, ph), (0, pw)), 'reflect') 72 | ref_imgs_info['imgs'], ref_imgs_info['depth'], ref_imgs_info['masks'] = ref_imgs, ref_depths, ref_masks 73 | if ref_depth_gt is not None: 74 | ref_imgs_info['true_depth'] = ref_depth_gt 75 | return ref_imgs_info 76 | 77 | def build_imgs_info(database, ref_ids, pad_interval=-1, is_aligned=True, align_depth_range=False, has_mask=True, has_depth=True, replace_none_depth = False): 78 | if not is_aligned: 79 | assert has_depth 80 | rfn = len(ref_ids) 81 | ref_imgs, ref_masks, ref_depths, shapes = [], [], [], [] 82 | for ref_id in ref_ids: 83 | img = database.get_image(ref_id) 84 | shapes.append([img.shape[0], img.shape[1]]) 85 | ref_imgs.append(img) 86 | ref_masks.append(database.get_mask(ref_id)) 87 | ref_depths.append(database.get_depth(ref_id)) 88 | 89 | shapes = np.asarray(shapes) 90 | th, tw = np.max(shapes, 0) 91 | for rfi in range(rfn): 92 | ref_imgs[rfi] = pad_img_end(ref_imgs[rfi], th, tw, 'reflect') 93 | ref_masks[rfi] = pad_img_end(ref_masks[rfi][:, :, None], th, tw, 'constant', 0)[..., 0] 94 | ref_depths[rfi] = pad_img_end(ref_depths[rfi][:, :, None], th, tw, 'constant', 0)[..., 0] 95 | ref_imgs = color_map_forward(np.stack(ref_imgs, 0)).transpose([0, 3, 1, 2]) 96 | ref_masks = np.stack(ref_masks, 0)[:, None, :, :] 97 | ref_depths = np.stack(ref_depths, 0)[:, None, :, :] 98 | else: 99 | ref_imgs = color_map_forward(np.asarray([database.get_image(ref_id) for ref_id in ref_ids])).transpose([0, 3, 1, 2]) 100 | if has_mask: 101 | ref_masks = np.asarray([database.get_mask(ref_id) for ref_id in ref_ids], dtype=np.float32)[:, None, :, :] 102 | else: 103 | b, _, h, w = ref_imgs.shape 104 | ref_masks = np.ones([b, _, h, w], dtype=np.float32) 105 | if has_depth: 106 | ref_depths = [database.get_depth(ref_id) for ref_id in ref_ids] 107 | if replace_none_depth: 108 | b, _, h, w = ref_imgs.shape 109 | for i, depth in enumerate(ref_depths): 110 | if depth is None: ref_depths[i] = np.zeros([h, w], dtype=np.float32) 111 | ref_depths = np.asarray(ref_depths, dtype=np.float32)[:, None, :, :] 112 | else: ref_depths = None 113 | 114 | ref_poses = np.asarray([database.get_pose(ref_id) for ref_id in ref_ids], dtype=np.float32) 115 | ref_Ks = 
np.asarray([database.get_K(ref_id) for ref_id in ref_ids], dtype=np.float32) 116 | ref_depth_range = np.asarray([database.get_depth_range(ref_id) for ref_id in ref_ids], dtype=np.float32) 117 | if align_depth_range: 118 | ref_depth_range[:,0]=np.min(ref_depth_range[:,0]) 119 | ref_depth_range[:,1]=np.max(ref_depth_range[:,1]) 120 | ref_imgs_info = {'imgs': ref_imgs, 'poses': ref_poses, 'Ks': ref_Ks, 'depth_range': ref_depth_range, 'masks': ref_masks, 'bbox3d': database.get_bbox3d()} 121 | if has_depth: ref_imgs_info['depth'] = ref_depths 122 | if pad_interval!=-1: 123 | ref_imgs_info = pad_imgs_info(ref_imgs_info, pad_interval) 124 | return ref_imgs_info 125 | 126 | def build_render_imgs_info(que_pose,que_K,que_shape,que_depth_range): 127 | h, w = que_shape 128 | h, w = int(h), int(w) 129 | que_coords = np.stack(np.meshgrid(np.arange(w), np.arange(h)), -1) 130 | que_coords = que_coords.reshape([1, -1, 2]).astype(np.float32) 131 | return {'poses': que_pose.astype(np.float32)[None,:,:], # 1,3,4 132 | 'Ks': que_K.astype(np.float32)[None,:,:], # 1,3,3 133 | 'coords': que_coords, 134 | 'depth_range': np.asarray(que_depth_range, np.float32)[None, :], 135 | 'shape': (h,w)} 136 | 137 | def build_canonical_info(bbox, resolution, que_pose, que_K): 138 | x_min,x_max,y_min,y_max = bbox 139 | print('bbox', bbox) 140 | que_coords = np.stack(np.meshgrid(np.linspace(y_min, y_max, 2), np.linspace(x_min, x_max, 2)), -1) 141 | print('que_coords', que_coords) 142 | return {'poses': que_pose.astype(np.float32)[None,:,:], # 1,3,4 143 | 'Ks': que_K.astype(np.float32)[None,:,:], # 1,3,3 144 | 'coords': que_coords, 145 | 'depth_range': np.asarray([0.5, 0.8], np.float32)[None, :], 146 | 'shape': (resolution, resolution)} 147 | 148 | def imgs_info_to_torch(imgs_info): 149 | for k, v in imgs_info.items(): 150 | if isinstance(v,np.ndarray): 151 | imgs_info[k] = torch.from_numpy(v).float() 152 | return imgs_info 153 | 154 | def grasp_info_to_torch(info): 155 | torch_info = [] 156 | for item in info: 157 | torch_info.append(torch.from_numpy(item)) 158 | return torch_info 159 | def imgs_info_slice(imgs_info, indices): 160 | imgs_info_out={} 161 | imgs_info_out['bbox3d'] = imgs_info['bbox3d'] 162 | for k, v in imgs_info.items(): 163 | if k != 'bbox3d' and v is not None: 164 | imgs_info_out[k] = v[indices] 165 | return imgs_info_out 166 | -------------------------------------------------------------------------------- /src/gd/experiments/clutter_removal.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import uuid 3 | import os 4 | import numpy as np 5 | import pandas as pd 6 | import sys 7 | import shutil 8 | from pathlib import Path 9 | from gd import io 10 | from gd.grasp import * 11 | from gd.simulation import ClutterRemovalSim 12 | sys.path.append("./") 13 | from rd.render import blender_init_scene, blender_render, blender_update_sceneobj 14 | 15 | MAX_CONSECUTIVE_FAILURES = 2 16 | 17 | State = collections.namedtuple("State", ["tsdf", "pc"]) 18 | 19 | def copydirs(from_file, to_file): 20 | if not os.path.exists(to_file): 21 | os.makedirs(to_file) 22 | files = os.listdir(from_file) 23 | for f in files: 24 | if os.path.isdir(from_file + '/' + f): 25 | copydirs(from_file + '/' + f, to_file + '/' + f) 26 | else: 27 | shutil.copy(from_file + '/' + f, to_file + '/' + f) 28 | 29 | 30 | def run( 31 | grasp_plan_fn, 32 | logdir, 33 | description, 34 | scene, 35 | object_set, 36 | num_objects=5, 37 | n=6, 38 | N=None, 39 | num_rounds=40, 40 | seed=1, 41 | 
sim_gui=False, 42 | rviz=False, 43 | round_idx=0, 44 | renderer_root_dir="", 45 | gpuid=None, 46 | args=None, 47 | render_frame_list=[] 48 | ): 49 | """Run several rounds of simulated clutter removal experiments. 50 | 51 | Each round, m objects are randomly placed in a tray. Then, the grasping pipeline is 52 | run until (a) no objects remain, (b) the planner failed to find a grasp hypothesis, 53 | or (c) maximum number of consecutive failed grasp attempts. 54 | """ 55 | sim = ClutterRemovalSim(scene, object_set, gui=sim_gui, seed=seed, renderer_root_dir=renderer_root_dir, args=args) 56 | logger = Logger(args.log_root_dir, logdir, description, round_idx) 57 | 58 | # output modality 59 | output_modality_dict = {'RGB': 1, 60 | 'IR': 0, 61 | 'NOCS': 0, 62 | 'Mask': 0, 63 | 'Normal': 0} 64 | 65 | for n_round in range(round_idx, round_idx+1): 66 | urdfs_and_poses_dict = sim.reset(num_objects, round_idx) 67 | renderer, quaternion_list, translation_list, path_scene = blender_init_scene(renderer_root_dir, args.log_root_dir, args.obj_texture_image_root_path, scene, urdfs_and_poses_dict, round_idx, logdir, False, args.material_type, gpuid, output_modality_dict) 68 | 69 | render_finished = False 70 | render_fail_times = 0 71 | while not render_finished and render_fail_times < 3: 72 | try: 73 | blender_render(renderer, quaternion_list, translation_list, path_scene, render_frame_list, output_modality_dict, args.camera_focal, is_init=True) 74 | render_finished = True 75 | except: 76 | render_fail_times += 1 77 | if not render_finished: 78 | raise RuntimeError("Blender render failed for 3 times.") 79 | 80 | path_scene_backup = os.path.join(path_scene + "_backup", "%d_init"%n_round) 81 | if os.path.exists(path_scene_backup) == False: 82 | os.makedirs(path_scene_backup) 83 | copydirs(path_scene, path_scene_backup) 84 | 85 | round_id = logger.last_round_id() + 1 86 | logger.log_round(round_id, sim.num_objects) 87 | 88 | consecutive_failures = 1 89 | last_label = None 90 | 91 | n_grasp = 0 92 | while sim.num_objects > 0 and consecutive_failures < MAX_CONSECUTIVE_FAILURES: 93 | timings = {} 94 | 95 | timings["integration"] = 0 96 | 97 | gt_tsdf, gt_pc, _ = sim.acquire_tsdf(n=n, N=N) 98 | 99 | if args.method == "graspnerf": 100 | grasps, scores, timings["planning"] = grasp_plan_fn(render_frame_list, round_idx, n_grasp, gt_tsdf) 101 | else: 102 | raise NotImplementedError 103 | 104 | if len(grasps) == 0: 105 | print("no detections found, abort this round") 106 | break 107 | else: 108 | print(f"{len(grasps)} detections found.") 109 | 110 | # execute grasp 111 | grasp, score = grasps[0], scores[0] 112 | (label, _), remain_obj_inws_infos = sim.execute_grasp(grasp, allow_contact=True) 113 | 114 | # render the modified scene after grasping 115 | obj_name_list = [str(value[0]).split("/")[-1][:-5] for value in remain_obj_inws_infos] 116 | obj_quat_list = [value[2][[3, 0, 1, 2]] for value in remain_obj_inws_infos] 117 | obj_trans_list = [value[3] for value in remain_obj_inws_infos] 118 | obj_uid_list = [value[4] for value in remain_obj_inws_infos] 119 | 120 | # update blender scene 121 | blender_update_sceneobj(obj_name_list, obj_trans_list, obj_quat_list, obj_uid_list) 122 | 123 | # render updated scene 124 | render_finished = False 125 | render_fail_times = 0 126 | while not render_finished and render_fail_times < 3: 127 | try: 128 | blender_render(renderer, quaternion_list, translation_list, path_scene, render_frame_list, output_modality_dict, args.camera_focal) 129 | render_finished = True 130 | except: 131 | 
render_fail_times += 1 132 | if not render_finished: 133 | raise RuntimeError("Blender render failed for 3 times.") 134 | 135 | 136 | path_scene_backup = os.path.join(path_scene+"_backup", "%d_%d"%(n_round,n_grasp)) 137 | if os.path.exists(path_scene_backup)==False: 138 | os.makedirs(path_scene_backup) 139 | copydirs(path_scene, path_scene_backup) 140 | 141 | # log the grasp 142 | logger.log_grasp(round_id, timings, grasp, score, label) 143 | 144 | if last_label == Label.FAILURE and label == Label.FAILURE: 145 | consecutive_failures += 1 146 | else: 147 | consecutive_failures = 1 148 | last_label = label 149 | 150 | n_grasp += 1 151 | 152 | 153 | class Logger(object): 154 | def __init__(self, log_root_dir, expname, description, round_idx): 155 | self.logdir = Path(os.path.join(log_root_dir, "exp_results", expname , "%04d"%int(round_idx)))#description 156 | self.scenes_dir = self.logdir / "scenes" 157 | self.scenes_dir.mkdir(parents=True, exist_ok=True) 158 | 159 | self.rounds_csv_path = self.logdir / "rounds.csv" 160 | self.grasps_csv_path = self.logdir / "grasps.csv" 161 | self._create_csv_files_if_needed() 162 | 163 | def _create_csv_files_if_needed(self): 164 | if not self.rounds_csv_path.exists(): 165 | io.create_csv(self.rounds_csv_path, ["round_id", "object_count"]) 166 | 167 | if not self.grasps_csv_path.exists(): 168 | columns = [ 169 | "round_id", 170 | "scene_id", 171 | "qx", 172 | "qy", 173 | "qz", 174 | "qw", 175 | "x", 176 | "y", 177 | "z", 178 | "width", 179 | "score", 180 | "label", 181 | "integration_time", 182 | "planning_time", 183 | ] 184 | io.create_csv(self.grasps_csv_path, columns) 185 | 186 | def last_round_id(self): 187 | df = pd.read_csv(self.rounds_csv_path) 188 | return -1 if df.empty else df["round_id"].max() 189 | 190 | def log_round(self, round_id, object_count): 191 | io.append_csv(self.rounds_csv_path, round_id, object_count) 192 | 193 | def log_grasp(self, round_id, timings, grasp, score, label): 194 | # log scene 195 | scene_id = uuid.uuid4().hex 196 | 197 | # log grasp 198 | qx, qy, qz, qw = grasp.pose.rotation.as_quat() 199 | x, y, z = grasp.pose.translation 200 | width = grasp.width 201 | label = int(label) 202 | io.append_csv( 203 | self.grasps_csv_path, 204 | round_id, 205 | scene_id, 206 | qx, 207 | qy, 208 | qz, 209 | qw, 210 | x, 211 | y, 212 | z, 213 | width, 214 | score, 215 | label, 216 | timings["integration"], 217 | timings["planning"], 218 | ) 219 | 220 | 221 | class Data(object): 222 | """Object for loading and analyzing experimental data.""" 223 | 224 | def __init__(self, logdir): 225 | self.logdir = logdir 226 | self.rounds = pd.read_csv(logdir / "rounds.csv") 227 | self.grasps = pd.read_csv(logdir / "grasps.csv") 228 | 229 | def num_rounds(self): 230 | return len(self.rounds.index) 231 | 232 | def num_grasps(self): 233 | return len(self.grasps.index) 234 | 235 | def success_rate(self): 236 | return self.grasps["label"].mean() * 100 237 | 238 | def percent_cleared(self): 239 | df = ( 240 | self.grasps[["round_id", "label"]] 241 | .groupby("round_id") 242 | .sum() 243 | .rename(columns={"label": "cleared_count"}) 244 | .merge(self.rounds, on="round_id") 245 | ) 246 | return df["cleared_count"].sum() / df["object_count"].sum() * 100 247 | 248 | def avg_planning_time(self): 249 | return self.grasps["planning_time"].mean() 250 | 251 | def read_grasp(self, i): 252 | scene_id, grasp, label = io.read_grasp(self.grasps, i) 253 | score = self.grasps.loc[i, "score"] 254 | scene_data = np.load(self.logdir / "scenes" / (scene_id + ".npz")) 255 | 256 | 
return scene_data["points"], grasp, score, label -------------------------------------------------------------------------------- /src/nr/network/render_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from network.ops import interpolate_feats 3 | 4 | def coords2rays(coords, poses, Ks): 5 | """ 6 | :param coords: [rfn,rn,2] 7 | :param poses: [rfn,3,4] 8 | :param Ks: [rfn,3,3] 9 | :return: 10 | ref_rays: 11 | centers: [rfn,rn,3] 12 | directions: [rfn,rn,3] 13 | """ 14 | rot = poses[:, :, :3].unsqueeze(1).permute(0, 1, 3, 2) # rfn,1,3,3 15 | trans = -rot @ poses[:, :, 3:].unsqueeze(1) # rfn,1,3,1 16 | 17 | rfn, rn, _ = coords.shape 18 | centers = trans.repeat(1, rn, 1, 1).squeeze(-1) # rfn,rn,3 19 | coords = torch.cat([coords, torch.ones([rfn, rn, 1], dtype=torch.float32, device=coords.device)], 2) # rfn,rn,3 20 | Ks_inv = torch.inverse(Ks).unsqueeze(1) 21 | cam_xyz = Ks_inv @ coords.unsqueeze(3) 22 | cam_xyz = rot @ cam_xyz + trans 23 | directions = cam_xyz.squeeze(3) - centers 24 | # directions = directions / torch.clamp(torch.norm(directions, dim=2, keepdim=True), min=1e-4) 25 | return centers, directions 26 | 27 | def depth2points(que_imgs_info, que_depth): 28 | """ 29 | :param que_imgs_info: 30 | :param que_depth: qn,rn,dn 31 | :return: 32 | """ 33 | cneters, directions = coords2rays(que_imgs_info['coords'],que_imgs_info['poses'],que_imgs_info['Ks']) # centers, directions qn,rn,3 34 | qn, rn, _ = cneters.shape 35 | que_pts = cneters.unsqueeze(2) + directions.unsqueeze(2) * que_depth.unsqueeze(3) # qn,rn,dn,3 36 | qn, rn, dn, _ = que_pts.shape 37 | que_dir = -directions / torch.norm(directions, dim=2, keepdim=True) # qn,rn,3 38 | que_dir = que_dir.unsqueeze(2).repeat(1, 1, dn, 1) 39 | return que_pts, que_dir # qn,rn,dn,3 40 | 41 | def depth2dists(depth): 42 | device = depth.device 43 | dists = depth[...,1:]-depth[...,:-1] 44 | return torch.cat([dists, torch.full([*depth.shape[:-1], 1], 1e6, dtype=torch.float32, device=device)], -1) 45 | 46 | def depth2inv_dists(depth,depth_range): 47 | near, far = -1 / depth_range[:, 0], -1 / depth_range[:, 1] 48 | near, far = near[:, None, None], far[:, None, None] 49 | depth_inv = -1 / depth # qn,rn,dn 50 | depth_inv = (depth_inv - near) / (far - near) 51 | dists = depth2dists(depth_inv) # qn,rn,dn 52 | return dists 53 | 54 | def interpolate_feature_map(ray_feats, coords, mask, h, w, border_type='border'): 55 | """ 56 | :param ray_feats: rfn,f,h,w 57 | :param coords: rfn,pn,2 58 | :param mask: rfn,pn 59 | :param h: 60 | :param w: 61 | :param border_type: 62 | :return: 63 | """ 64 | fh, fw = ray_feats.shape[-2:] 65 | if fh == h and fw == w: 66 | cur_ray_feats = interpolate_feats(ray_feats, coords, h, w, border_type, True) # rfn,pn,f 67 | else: 68 | cur_ray_feats = interpolate_feats(ray_feats, coords, h, w, border_type, False) # rfn,pn,f 69 | cur_ray_feats = cur_ray_feats * mask.float().unsqueeze(-1) # rfn,pn,f 70 | return cur_ray_feats 71 | 72 | def alpha_values2hit_prob(alpha_values): 73 | """ 74 | :param alpha_values: qn,rn,dn 75 | :return: qn,rn,dn 76 | """ 77 | no_hit_density = torch.cat([torch.ones((*alpha_values.shape[:-1], 1)) 78 | .to(alpha_values.device), 1. 
- alpha_values + 1e-10], -1) # rn,k+1 79 | hit_prob = alpha_values * torch.cumprod(no_hit_density, -1)[..., :-1] # [n,k] 80 | return hit_prob 81 | 82 | def project_points_coords(pts, Rt, K): 83 | """ 84 | :param pts: [pn,3] 85 | :param Rt: [rfn,3,4] 86 | :param K: [rfn,3,3] 87 | :return: 88 | coords: [rfn,pn,2] 89 | invalid_mask: [rfn,pn] 90 | """ 91 | pn = pts.shape[0] 92 | hpts = torch.cat([pts,torch.ones([pn,1],device=pts.device,dtype=torch.float32)],1) 93 | srn = Rt.shape[0] 94 | KRt = K @ Rt # rfn,3,4 95 | last_row = torch.zeros([srn,1,4],device=pts.device,dtype=torch.float32) 96 | last_row[:,:,3] = 1.0 97 | H = torch.cat([KRt,last_row],1) # rfn,4,4 98 | pts_cam = H[:,None,:,:] @ hpts[None,:,:,None] 99 | pts_cam = pts_cam[:,:,:3,0] 100 | depth = pts_cam[:,:,2:] 101 | invalid_mask = torch.abs(depth)<1e-4 102 | depth[invalid_mask] = 1e-3 103 | pts_2d = pts_cam[:,:,:2]/depth 104 | return pts_2d, ~(invalid_mask[...,0]), depth 105 | 106 | def project_points_directions(poses,points): 107 | """ 108 | :param poses: rfn,3,4 109 | :param points: pn,3 110 | :return: rfn,pn,3 111 | """ 112 | cam_pts = -poses[:, :, :3].permute(0, 2, 1) @ poses[:, :, 3:] # rfn,3,1 113 | dir = points.unsqueeze(0) - cam_pts.permute(0, 2, 1) # [1,pn,3] - [rfn,1,3] -> rfn,pn,3 114 | dir = -dir / torch.clamp_min(torch.norm(dir, dim=2, keepdim=True), min=1e-5) # rfn,pn,3 115 | return dir 116 | 117 | def project_points_ref_views(ref_imgs_info, que_points): 118 | """ 119 | :param ref_imgs_info: 120 | :param que_points: pn,3 121 | :return: 122 | """ 123 | prj_pts, prj_valid_mask, prj_depth = project_points_coords( 124 | que_points, ref_imgs_info['poses'], ref_imgs_info['Ks']) # rfn,pn,2 125 | h,w=ref_imgs_info['imgs'].shape[-2:] 126 | prj_img_invalid_mask = (prj_pts[..., 0] < -0.5) | (prj_pts[..., 0] >= w - 0.5) | \ 127 | (prj_pts[..., 1] < -0.5) | (prj_pts[..., 1] >= h - 0.5) 128 | valid_mask = prj_valid_mask & (~prj_img_invalid_mask) 129 | prj_dir = project_points_directions(ref_imgs_info['poses'], que_points) # rfn,pn,3 130 | return prj_dir, prj_pts, prj_depth, valid_mask 131 | 132 | def project_points_dict(ref_imgs_info, que_pts): 133 | # project all points 134 | qn, rn, dn, _ = que_pts.shape 135 | prj_dir, prj_pts, prj_depth, prj_mask = project_points_ref_views(ref_imgs_info, que_pts.reshape([qn * rn * dn, 3])) 136 | rfn, _, h, w = ref_imgs_info['imgs'].shape 137 | prj_ray_feats = interpolate_feature_map(ref_imgs_info['ray_feats'], prj_pts, prj_mask, h, w) 138 | prj_rgb = interpolate_feature_map(ref_imgs_info['imgs'], prj_pts, prj_mask, h, w) 139 | prj_dict = {'dir':prj_dir, 'pts':prj_pts, 'depth':prj_depth, 'mask': prj_mask.float(), 'ray_feats':prj_ray_feats, 'rgb':prj_rgb} 140 | 141 | # post process 142 | for k, v in prj_dict.items(): 143 | prj_dict[k]=v.reshape(rfn,qn,rn,dn,-1) 144 | return prj_dict 145 | 146 | def sample_depth(depth_range, coords, sample_num, random_sample): 147 | """ 148 | :param depth_range: qn,2 149 | :param sample_num: 150 | :param random_sample: 151 | :return: 152 | """ 153 | qn, rn, _ = coords.shape 154 | device = coords.device 155 | near, far = depth_range[:,0], depth_range[:,1] # qn,2 156 | dn = sample_num 157 | assert(dn>2) 158 | interval = (1 / far - 1 / near) / (dn - 1) # qn 159 | val = torch.arange(1, dn - 1, dtype=torch.float32, device=near.device)[None, None, :] 160 | if random_sample: 161 | val = val + (torch.rand(qn, rn, dn-2, dtype=torch.float32, device=device) - 0.5) * 0.999 162 | else: 163 | val = val + torch.zeros(qn, rn, dn-2, dtype=torch.float32, device=device) 164 | ticks = 
interval[:, None, None] * val 165 | 166 | diff = (1 / far - 1 / near) 167 | ticks = torch.cat([torch.zeros(qn,rn,1,dtype=torch.float32,device=device),ticks,diff[:,None,None].repeat(1,rn,1)],-1) 168 | que_depth = 1 / (1 / near[:, None, None] + ticks) # qn, dn, 169 | que_dists = torch.cat([que_depth[...,1:],torch.full([*que_depth.shape[:-1],1],1e6,dtype=torch.float32,device=device)],-1) - que_depth 170 | return que_depth, que_dists # qn, rn, dn 171 | 172 | def sample_fine_depth(depth, hit_prob, depth_range, sample_num, random_sample, inv_mode=True): 173 | """ 174 | :param depth: qn,rn,dn 175 | :param hit_prob: qn,rn,dn 176 | :param depth_range: qn,2 177 | :param sample_num: 178 | :param random_sample: 179 | :param inv_mode: 180 | :return: qn,rn,dn 181 | """ 182 | if inv_mode: 183 | near, far = depth_range[0,0], depth_range[0,1] 184 | near, far = -1/near, -1/far 185 | depth_inv = -1 / depth # qn,rn,dn 186 | depth_inv = (depth_inv - near) / (far - near) 187 | depth = depth_inv 188 | 189 | depth_center = (depth[...,1:] + depth[...,:-1])/2 190 | depth_center = torch.cat([depth[...,0:1],depth_center,depth[...,-1:]],-1) # rfn,pn,dn+1 191 | fdn = sample_num 192 | # Get pdf 193 | hit_prob = hit_prob + 1e-5 # prevent nans 194 | pdf = hit_prob / torch.sum(hit_prob, -1, keepdim=True) # rfn,pn,dn-1 195 | cdf = torch.cumsum(pdf, -1) # rfn,pn,dn-1 196 | cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1) # rfn,pn,dn 197 | 198 | # Take uniform samples 199 | if not random_sample: 200 | interval = 1 / fdn 201 | u = 0.5*interval+torch.arange(fdn)*interval 202 | # u = torch.linspace(0., 1., steps=fdn) 203 | u = u.expand(list(cdf.shape[:-1]) + [fdn]) # rfn,pn,fdn 204 | else: 205 | u = torch.rand(list(cdf.shape[:-1]) + [fdn]) 206 | 207 | # Invert CDF 208 | device = pdf.device 209 | u = u.to(device).contiguous() # rfn,pn,fdn 210 | inds = torch.searchsorted(cdf, u, right=True) # rfn,pn,fdn 211 | below = torch.max(torch.zeros_like(inds-1), inds-1) # rfn,pn,fdn 212 | above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds) # rfn,pn,fdn 213 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) # rfn,pn,fdn,2 214 | 215 | matched_shape = [*inds_g.shape[:-1], cdf.shape[-1]] 216 | cdf_g = torch.gather(cdf.unsqueeze(-2).expand(matched_shape), -1, inds_g) # rfn,pn,fdn,2 217 | bins_g = torch.gather(depth_center.unsqueeze(-2).expand(matched_shape), -1, inds_g) # rfn,pn,fdn,2 218 | 219 | denom = (cdf_g[...,1]-cdf_g[...,0]) # rfn,pn,fdn 220 | denom = torch.where(denom<1e-5, torch.ones_like(denom), denom) 221 | t = (u-cdf_g[...,0])/denom 222 | fine_depth = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0]) 223 | 224 | if inv_mode: 225 | near, far = depth_range[0,0], depth_range[0,1] 226 | near, far = -1/near, -1/far 227 | fine_depth = fine_depth * (far - near) + near 228 | fine_depth = -1/fine_depth 229 | return fine_depth 230 | -------------------------------------------------------------------------------- /src/nr/train/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | import numpy as np 5 | from torch.nn import DataParallel 6 | from torch.optim import Adam, SGD 7 | from torch.utils.data import DataLoader 8 | from tqdm import tqdm 9 | 10 | from dataset.name2dataset import name2dataset 11 | from network.loss import name2loss 12 | from network.renderer import name2network 13 | from train.lr_common_manager import name2lr_manager 14 | from network.metrics import name2metrics 15 | from train.train_tools 
import to_cuda, Logger, reset_learning_rate, MultiGPUWrapper, DummyLoss 16 | from train.train_valid import ValidationEvaluator 17 | from utils.dataset_utils import simple_collate_fn, dummy_collate_fn 18 | from asset import vgn_val_scene_names 19 | 20 | class Trainer: 21 | default_cfg={ 22 | "optimizer_type": 'adam', 23 | "multi_gpus": False, 24 | "lr_type": "exp_decay", 25 | "lr_cfg":{ 26 | "lr_init": 1.0e-4, 27 | "decay_step": 100000, 28 | "decay_rate": 0.5, 29 | }, 30 | "total_step": 300000, 31 | "train_log_step": 20, 32 | "val_interval": 10000, 33 | "save_interval": 1000, 34 | "worker_num": 8, 35 | "fix_seed": False 36 | } 37 | def _init_dataset(self): 38 | self.train_set=name2dataset[self.cfg['train_dataset_type']](self.cfg['train_dataset_cfg'], True) 39 | self.train_set=DataLoader(self.train_set,1,True,num_workers=self.cfg['worker_num'],collate_fn=dummy_collate_fn) 40 | print(f'train set len {len(self.train_set)}') 41 | self.val_set_list, self.val_set_names = [], [] 42 | for val_set_cfg in self.cfg['val_set_list']: 43 | name, val_type, val_cfg = val_set_cfg['name'], val_set_cfg['type'], val_set_cfg['cfg'] 44 | if 'val_scene_num' in val_set_cfg: 45 | num = val_set_cfg['val_scene_num'] 46 | num = len(vgn_val_scene_names) if num == -1 else num 47 | names, val_types = [name] * num, [val_type] * num 48 | val_cfgs = [] 49 | for i in range(num): 50 | val_cfgs.append({**val_cfg, **{'val_database_name': vgn_val_scene_names[i]}}) 51 | else: 52 | names, val_types, val_cfgs = [name], [val_type], [val_cfg] 53 | for name, val_type, val_cfg in zip(names, val_types, val_cfgs): 54 | val_set = name2dataset[val_type](val_cfg, False) 55 | val_set = DataLoader(val_set,1,False,num_workers=self.cfg['worker_num'],collate_fn=dummy_collate_fn) 56 | self.val_set_list.append(val_set) 57 | self.val_set_names.append(name) 58 | print(f"val set num: {len(self.val_set_list)}") 59 | 60 | def _init_network(self): 61 | self.network=name2network[self.cfg['network']](self.cfg).cuda() 62 | 63 | # loss 64 | self.val_losses = [] 65 | for loss_name in self.cfg['loss']: 66 | self.val_losses.append(name2loss[loss_name](self.cfg)) 67 | self.val_metrics = [] 68 | 69 | # metrics 70 | for metric_name in self.cfg['val_metric']: 71 | if metric_name in name2metrics: 72 | self.val_metrics.append(name2metrics[metric_name](self.cfg)) 73 | else: 74 | self.val_metrics.append(name2loss[metric_name](self.cfg)) 75 | 76 | # we do not support multi gpu training for NeuRay 77 | if self.cfg['multi_gpus']: 78 | raise NotImplementedError 79 | # make multi gpu network 80 | # self.train_network=DataParallel(MultiGPUWrapper(self.network,self.val_losses)) 81 | # self.train_losses=[DummyLoss(self.val_losses)] 82 | else: 83 | self.train_network=self.network 84 | self.train_losses=self.val_losses 85 | 86 | if self.cfg['optimizer_type']=='adam': 87 | self.optimizer = Adam 88 | elif self.cfg['optimizer_type']=='sgd': 89 | self.optimizer = SGD 90 | else: 91 | raise NotImplementedError 92 | 93 | self.val_evaluator=ValidationEvaluator(self.cfg) 94 | self.lr_manager=name2lr_manager[self.cfg['lr_type']](self.cfg['lr_cfg']) 95 | self.optimizer=self.lr_manager.construct_optimizer(self.optimizer,self.network) 96 | 97 | def __init__(self,cfg): 98 | self.cfg={**self.default_cfg,**cfg} 99 | if self.cfg['fix_seed']: 100 | seed = 0 101 | torch.manual_seed(seed) 102 | np.random.seed(seed) 103 | random.seed(seed) 104 | torch.cuda.manual_seed_all(seed) 105 | os.environ['PYTHONHASHSEED'] = str(seed) 106 | print("fix seed") 107 | self.model_name=cfg['name'] 108 | 
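        # Note: checkpoints are written under data/model/<group_name>/<name>;
        # model.pth holds the latest state and model_best.pth the best
        # validation metric (see _save_model / _load_model below).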
self.model_dir=os.path.join('data/model',cfg['group_name'],cfg['name']) 109 | if not os.path.exists(self.model_dir): os.makedirs(self.model_dir) 110 | self.pth_fn=os.path.join(self.model_dir,'model.pth') 111 | self.best_pth_fn=os.path.join(self.model_dir,'model_best.pth') 112 | assert self.cfg["key_metric_prefer"] in ['higher', 'lower'] 113 | self.better = lambda x, y: x > y if self.cfg["key_metric_prefer"] == 'higher' else x < y 114 | 115 | def run(self): 116 | self._init_dataset() 117 | self._init_network() 118 | self._init_logger() 119 | 120 | best_para,start_step=self._load_model() 121 | if self.cfg["key_metric_prefer"] == 'lower' and start_step == 0: 122 | best_para = 1e6 123 | train_iter=iter(self.train_set) 124 | 125 | pbar=tqdm(total=self.cfg['total_step'],bar_format='{r_bar}') 126 | pbar.update(start_step) 127 | for step in range(start_step,self.cfg['total_step']): 128 | try: 129 | train_data = next(train_iter) 130 | except StopIteration: 131 | self.train_set.dataset.reset() 132 | train_iter = iter(self.train_set) 133 | train_data = next(train_iter) 134 | if not self.cfg['multi_gpus']: 135 | train_data = to_cuda(train_data) 136 | train_data['step']=step 137 | 138 | self.train_network.train() 139 | self.network.train() 140 | lr = self.lr_manager(self.optimizer, step) 141 | 142 | self.optimizer.zero_grad() 143 | self.train_network.zero_grad() 144 | 145 | log_info={} 146 | outputs=self.train_network(train_data) 147 | for loss in self.train_losses: 148 | loss_results = loss(outputs,train_data,step) 149 | for k,v in loss_results.items(): 150 | log_info[k]=v 151 | 152 | loss=0 153 | for k,v in log_info.items(): 154 | if k.startswith('loss'): 155 | loss=loss+torch.mean(v) 156 | 157 | loss.backward() 158 | self.optimizer.step() 159 | if ((step+1) % self.cfg['train_log_step']) == 0: 160 | self._log_data(log_info,step+1,'train') 161 | 162 | if step==0 or (step+1)%self.cfg['val_interval']==0 or (step+1)==self.cfg['total_step']: 163 | torch.cuda.empty_cache() 164 | val_results={} 165 | val_para = 0 166 | for vi, val_set in enumerate(self.val_set_list): 167 | val_results_cur, val_para_cur = self.val_evaluator( 168 | self.network, self.val_losses + self.val_metrics, val_set, step, 169 | self.model_name, val_set_name=self.val_set_names[vi]) 170 | for k,v in val_results_cur.items(): 171 | key = f'{self.val_set_names[vi]}-{k}' 172 | if not key in val_results: 173 | val_results[key] = v 174 | else: 175 | val_results[key] += v 176 | val_para += val_para_cur 177 | 178 | # average all items 179 | for k,v in val_results.items(): 180 | val_results[k] /= len(self.val_set_list) 181 | val_para /= len(self.val_set_list) 182 | 183 | if step and self.better(val_para, best_para): # do not save the first step 184 | print(f'New best model {self.cfg["key_metric_name"]}: {val_para:.5f} previous {best_para:.5f}') 185 | best_para=val_para 186 | self._save_model(step+1,best_para,self.best_pth_fn) 187 | self._log_data(val_results,step+1,'val') 188 | del val_results, val_para, val_para_cur, val_results_cur 189 | 190 | if (step+1)%self.cfg['save_interval']==0: 191 | self._save_model(step+1,best_para) 192 | 193 | pbar.set_postfix(loss=float(loss.detach().cpu().numpy()),lr=lr) 194 | pbar.update(1) 195 | del loss, log_info 196 | 197 | pbar.close() 198 | 199 | def _load_model(self): 200 | best_para,start_step=0,0 201 | if os.path.exists(self.pth_fn): 202 | checkpoint=torch.load(self.pth_fn) 203 | best_para = checkpoint['best_para'] 204 | start_step = checkpoint['step'] 205 | 
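            # Resuming restores both the network weights and the optimizer state,
            # so training continues exactly from the saved step.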
self.network.load_state_dict(checkpoint['network_state_dict']) 206 | self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 207 | print(f'==> resuming from the latest {self.pth_fn} of step {start_step} with best metric {best_para}') 208 | 209 | return best_para, start_step 210 | 211 | def _save_model(self, step, best_para, save_fn=None): 212 | save_fn = self.pth_fn if save_fn is None else save_fn 213 | torch.save({ 214 | 'step':step, 215 | 'best_para':best_para, 216 | 'network_state_dict': self.network.state_dict(), 217 | 'optimizer_state_dict': self.optimizer.state_dict(), 218 | },save_fn) 219 | 220 | def _init_logger(self): 221 | self.logger = Logger(self.model_dir) 222 | 223 | def _log_data(self,results,step,prefix='train',verbose=False): 224 | log_results={} 225 | for k, v in results.items(): 226 | if isinstance(v,float) or np.isscalar(v): 227 | log_results[k] = v 228 | elif type(v)==np.ndarray: 229 | log_results[k]=np.mean(v) 230 | else: 231 | log_results[k]=np.mean(v.detach().cpu().numpy()) 232 | self.logger.log(log_results,prefix,step,verbose) 233 | -------------------------------------------------------------------------------- /src/nr/network/mvsnet/mvsnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from network.mvsnet.modules import ConvBnReLU, ConvBnReLU3D, depth_regression, homo_warp 5 | from inplace_abn import InPlaceABN 6 | 7 | class FeatureNet(nn.Module): 8 | def __init__(self, norm_act=InPlaceABN): 9 | super(FeatureNet, self).__init__() 10 | self.inplanes = 32 11 | 12 | self.conv0 = ConvBnReLU(3, 8, 3, 1, 1, norm_act=norm_act) 13 | self.conv1 = ConvBnReLU(8, 8, 3, 1, 1, norm_act=norm_act) 14 | 15 | self.conv2 = ConvBnReLU(8, 16, 5, 2, 2, norm_act=norm_act) 16 | self.conv3 = ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act) 17 | self.conv4 = ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act) 18 | 19 | self.conv5 = ConvBnReLU(16, 32, 5, 2, 2, norm_act=norm_act) 20 | self.conv6 = ConvBnReLU(32, 32, 3, 1, 1, norm_act=norm_act) 21 | self.feature = nn.Conv2d(32, 32, 3, 1, 1) 22 | 23 | def forward(self, x): 24 | x = self.conv1(self.conv0(x)) 25 | x = self.conv4(self.conv3(self.conv2(x))) 26 | x = self.feature(self.conv6(self.conv5(x))) 27 | return x 28 | 29 | class CostRegNet(nn.Module): 30 | def __init__(self, norm_act=InPlaceABN): 31 | super(CostRegNet, self).__init__() 32 | self.conv0 = ConvBnReLU3D(32, 8, norm_act=norm_act) 33 | 34 | self.conv1 = ConvBnReLU3D(8, 16, stride=2, norm_act=norm_act) 35 | self.conv2 = ConvBnReLU3D(16, 16, norm_act=norm_act) 36 | 37 | self.conv3 = ConvBnReLU3D(16, 32, stride=2, norm_act=norm_act) 38 | self.conv4 = ConvBnReLU3D(32, 32, norm_act=norm_act) 39 | 40 | self.conv5 = ConvBnReLU3D(32, 64, stride=2, norm_act=norm_act) 41 | self.conv6 = ConvBnReLU3D(64, 64, norm_act=norm_act) 42 | 43 | self.conv7 = nn.Sequential( 44 | nn.ConvTranspose3d(64, 32, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False), 45 | norm_act(32)) 46 | 47 | self.conv9 = nn.Sequential( 48 | nn.ConvTranspose3d(32, 16, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False), 49 | norm_act(16)) 50 | 51 | self.conv11 = nn.Sequential( 52 | nn.ConvTranspose3d(16, 8, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False), 53 | norm_act(8)) 54 | 55 | self.prob = nn.Conv3d(8, 1, 3, stride=1, padding=1) 56 | 57 | def forward(self, x): 58 | conv0 = self.conv0(x) 59 | conv2 = self.conv2(self.conv1(conv0)) 60 | conv4 = 
self.conv4(self.conv3(conv2)) 61 | x = self.conv6(self.conv5(conv4)) 62 | x = conv4 + self.conv7(x) 63 | del conv4 64 | x = conv2 + self.conv9(x) 65 | del conv2 66 | x = conv0 + self.conv11(x) 67 | del conv0 68 | x = self.prob(x) 69 | return x 70 | 71 | class MVSNet(nn.Module): 72 | def __init__(self, norm_act=InPlaceABN): 73 | super(MVSNet, self).__init__() 74 | self.feature = FeatureNet(norm_act) 75 | self.cost_regularization = CostRegNet(norm_act) 76 | 77 | def forward(self, imgs, proj_mats, depth_values): 78 | # imgs: (B, V, 3, H, W) 79 | # proj_mats: (B, V, 4, 4) 80 | # depth_values: (B, D) 81 | B, V, _, H, W = imgs.shape 82 | D = depth_values.shape[1] 83 | 84 | # step 1. feature extraction 85 | # in: images; out: 32-channel feature maps 86 | imgs = imgs.reshape(B*V, 3, H, W) 87 | feats = self.feature(imgs) # (B*V, F, h, w) 88 | del imgs 89 | feats = feats.reshape(B, V, *feats.shape[1:]) # (B, V, F, h, w) 90 | ref_feats, src_feats = feats[:, 0], feats[:, 1:] 91 | ref_proj, src_projs = proj_mats[:, 0], proj_mats[:, 1:] 92 | src_feats = src_feats.permute(1, 0, 2, 3, 4) # (V-1, B, F, h, w) 93 | src_projs = src_projs.permute(1, 0, 2, 3) # (V-1, B, 4, 4) 94 | 95 | # step 2. differentiable homograph, build cost volume 96 | ref_volume = ref_feats.unsqueeze(2).repeat(1, 1, D, 1, 1) # (B, F, D, h, w) 97 | volume_sum = ref_volume 98 | volume_sq_sum = ref_volume ** 2 99 | del ref_volume 100 | 101 | ref_proj = torch.inverse(ref_proj) 102 | for src_feat, src_proj in zip(src_feats, src_projs): 103 | warped_volume = homo_warp(src_feat, src_proj, ref_proj, depth_values) 104 | volume_sum = volume_sum + warped_volume 105 | volume_sq_sum = volume_sq_sum + warped_volume ** 2 106 | del warped_volume 107 | # aggregate multiple feature volumes by variance 108 | volume_variance = volume_sq_sum.div_(V).sub_(volume_sum.div_(V).pow_(2)) 109 | del volume_sq_sum, volume_sum 110 | 111 | # step 3. cost volume regularization 112 | cost_reg = self.cost_regularization(volume_variance).squeeze(1) 113 | prob_volume = F.softmax(cost_reg, 1) # (B, D, h, w) 114 | depth = depth_regression(prob_volume, depth_values) 115 | 116 | with torch.no_grad(): 117 | # sum probability of 4 consecutive depth indices 118 | prob_volume_sum4 = 4 * F.avg_pool3d(F.pad(prob_volume.unsqueeze(1), 119 | pad=(0, 0, 0, 0, 1, 2)), 120 | (4, 1, 1), stride=1).squeeze(1) # (B, D, h, w) 121 | # find the (rounded) index that is the final prediction 122 | depth_index = depth_regression(prob_volume, 123 | torch.arange(D, 124 | device=prob_volume.device, 125 | dtype=prob_volume.dtype) 126 | ).long() # (B, h, w) 127 | # the confidence is the 4-sum probability at this index 128 | confidence = torch.gather(prob_volume_sum4, 1, 129 | depth_index.unsqueeze(1)).squeeze(1) # (B, h, w) 130 | 131 | return depth, confidence 132 | 133 | def construct_cost_volume(self, ref_imgs, ref_nn_idx, ref_prjs, depth_values, batch_num=2): 134 | # ref_imgs rfn,3,h,w 135 | # ref_nn_ids: rfn,nn 136 | # ref_prjs: rfn,4,4 note it is already scaled!!! 
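        #   ("already scaled" presumably means the intrinsics folded into ref_prjs
        #   are pre-scaled to the 1/4-resolution feature maps produced by
        #   FeatureNet, since the returned cost volume is rfn,dn,h//4,w//4)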
137 | # depth_values: rfn,dn 138 | # return: rfn,dn,h//4,w//4 139 | ref_feats = self.feature(ref_imgs) # rfn,f,h,w 140 | ref_prjs_inv = torch.inverse(ref_prjs) # rfn,4,4 141 | dn = depth_values.shape[1] 142 | 143 | rfn, n_num = ref_nn_idx.shape 144 | cost_reg_all = [] 145 | for rfi in range(0,rfn,batch_num): 146 | volume_sum, volume_sum_sq = ref_feats[rfi:rfi+batch_num].unsqueeze(2), ref_feats[rfi:rfi+batch_num].unsqueeze(2)**2 # 1,f,1,h,w 147 | volume_sum, volume_sum_sq = volume_sum.repeat(1, 1, dn, 1, 1), volume_sum_sq.repeat(1, 1, dn, 1, 1) 148 | for ni in range(n_num): 149 | warp_feats = homo_warp(ref_feats[ref_nn_idx[rfi:rfi+batch_num,ni]],ref_prjs[ref_nn_idx[rfi:rfi+batch_num,ni]], 150 | ref_prjs_inv[rfi:rfi+batch_num],depth_values[rfi:rfi+batch_num]) # 1,f,dn,h,w 151 | volume_sum += warp_feats 152 | volume_sum_sq += warp_feats**2 153 | volume_variance = volume_sum_sq.div_(n_num+1).sub_(volume_sum.div_(n_num+1).pow_(2)) # 1,f,dn,h,w 154 | del volume_sum_sq, volume_sum 155 | # 1,dn,h,w 156 | cost_reg_all.append(self.cost_regularization(volume_variance).squeeze(1)) 157 | cost_reg_all = torch.cat(cost_reg_all,0) 158 | return cost_reg_all 159 | 160 | def construct_cost_volume_with_src(self, ref_imgs, src_imgs, ref_nn_idx, ref_prjs, src_prjs, depth_values, batch_num=2): 161 | # ref_imgs rfn,3,h,w 162 | # src_imgs srn,3,h,w 163 | # ref_nn_ids: rfn,nn 164 | # ref_prjs: rfn,4,4 note it is already scaled!!! 165 | # src_prjs: src,4,4 note it is already scaled!!! 166 | # depth_values: rfn,dn 167 | # return: rfn,dn,h//4,w//4 168 | ref_feats = self.feature(ref_imgs) # rfn,f,h,w 169 | src_feats = self.feature(src_imgs) # src,f,h,w 170 | ref_prjs_inv = torch.inverse(ref_prjs) # rfn,4,4 171 | dn = depth_values.shape[1] 172 | 173 | rfn, n_num = ref_nn_idx.shape 174 | cost_reg_all = [] 175 | for rfi in range(0,rfn,batch_num): 176 | volume_sum, volume_sum_sq = ref_feats[rfi:rfi+batch_num].unsqueeze(2), ref_feats[rfi:rfi+batch_num].unsqueeze(2)**2 # 1,f,1,h,w 177 | volume_sum, volume_sum_sq = volume_sum.repeat(1, 1, dn, 1, 1), volume_sum_sq.repeat(1, 1, dn, 1, 1) 178 | for ni in range(n_num): 179 | warp_feats = homo_warp(src_feats[ref_nn_idx[rfi:rfi+batch_num,ni]],src_prjs[ref_nn_idx[rfi:rfi+batch_num,ni]], 180 | ref_prjs_inv[rfi:rfi+batch_num],depth_values[rfi:rfi+batch_num]) # 1,f,dn,h,w 181 | volume_sum += warp_feats 182 | volume_sum_sq += warp_feats**2 183 | volume_variance = volume_sum_sq.div_(n_num+1).sub_(volume_sum.div_(n_num+1).pow_(2)) # 1,f,dn,h,w 184 | del volume_sum_sq, volume_sum 185 | # 1,dn,h,w 186 | cost_reg_all.append(self.cost_regularization(volume_variance).squeeze(1)) 187 | cost_reg_all = torch.cat(cost_reg_all,0) 188 | return cost_reg_all 189 | 190 | 191 | def extract_model_state_dict(ckpt_path, prefixes_to_ignore=[]): 192 | checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu')) 193 | checkpoint_ = {} 194 | if 'state_dict' in checkpoint: # if it's a pytorch-lightning checkpoint 195 | for k, v in checkpoint['state_dict'].items(): 196 | if not k.startswith('model.'): 197 | continue 198 | k = k[6:] # remove 'model.' 
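                # for/else below: the else branch only runs if no ignored prefix
                # matched (i.e. the loop was not broken), so the key is kept;
                # a matching prefix triggers `break` and the key is dropped.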
199 | for prefix in prefixes_to_ignore: 200 | if k.startswith(prefix): 201 | print('ignore', k) 202 | break 203 | else: 204 | checkpoint_[k] = v 205 | else: # if it only has model weights 206 | for k, v in checkpoint.items(): 207 | for prefix in prefixes_to_ignore: 208 | if k.startswith(prefix): 209 | print('ignore', k) 210 | break 211 | else: 212 | checkpoint_[k] = v 213 | return checkpoint_ 214 | 215 | def load_ckpt(model, ckpt_path, prefixes_to_ignore=[]): 216 | model_dict = model.state_dict() 217 | checkpoint_ = extract_model_state_dict(ckpt_path, prefixes_to_ignore) 218 | model_dict.update(checkpoint_) 219 | model.load_state_dict(model_dict) -------------------------------------------------------------------------------- /src/nr/main.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | import time 3 | 4 | sys.path.append("./src/nr") 5 | from pathlib import Path 6 | import numpy as np 7 | 8 | import torch 9 | from skimage.io import imsave, imread 10 | from network.renderer import name2network 11 | from utils.base_utils import load_cfg, to_cuda 12 | from utils.imgs_info import build_render_imgs_info, imgs_info_to_torch, grasp_info_to_torch 13 | from network.renderer import name2network 14 | from utils.base_utils import color_map_forward 15 | from network.loss import VGNLoss 16 | from tqdm import tqdm 17 | from scipy import ndimage 18 | import cv2 19 | from gd.utils.transform import Transform, Rotation 20 | from gd.grasp import * 21 | 22 | 23 | def process( 24 | tsdf_vol, 25 | qual_vol, 26 | rot_vol, 27 | width_vol, 28 | gaussian_filter_sigma=1.0, 29 | min_width=1.33, 30 | max_width=9.33, 31 | tsdf_thres_high = 0.5, 32 | tsdf_thres_low = 1e-3, 33 | n_grasp=0 34 | ): 35 | tsdf_vol = tsdf_vol.squeeze() 36 | qual_vol = qual_vol.squeeze() 37 | rot_vol = rot_vol.squeeze() 38 | width_vol = width_vol.squeeze() 39 | # smooth quality volume with a Gaussian 40 | qual_vol = ndimage.gaussian_filter( 41 | qual_vol, sigma=gaussian_filter_sigma, mode="nearest" 42 | ) 43 | 44 | # mask out voxels too far away from the surface 45 | outside_voxels = tsdf_vol > tsdf_thres_high 46 | inside_voxels = np.logical_and(tsdf_thres_low < tsdf_vol, tsdf_vol < tsdf_thres_high) 47 | valid_voxels = ndimage.morphology.binary_dilation( 48 | outside_voxels, iterations=2, mask=np.logical_not(inside_voxels) 49 | ) 50 | qual_vol[valid_voxels == False] = 0.0 51 | 52 | # reject voxels with predicted widths that are too small or too large 53 | qual_vol[np.logical_or(width_vol < min_width, width_vol > max_width)] = 0.0 54 | 55 | return qual_vol, rot_vol, width_vol 56 | 57 | 58 | def select(qual_vol, rot_vol, width_vol, threshold=0.90, max_filter_size=4): 59 | qual_vol[qual_vol < threshold] = 0.0 60 | 61 | # non maximum suppression 62 | max_vol = ndimage.maximum_filter(qual_vol, size=max_filter_size) 63 | 64 | qual_vol = np.where(qual_vol == max_vol, qual_vol, 0.0) 65 | mask = np.where(qual_vol, 1.0, 0.0) 66 | 67 | # construct grasps 68 | grasps, scores, indexs = [], [], [] 69 | for index in np.argwhere(mask): 70 | indexs.append(index) 71 | grasp, score = select_index(qual_vol, rot_vol, width_vol, index) 72 | grasps.append(grasp) 73 | scores.append(score) 74 | return grasps, scores, indexs 75 | 76 | 77 | def select_index(qual_vol, rot_vol, width_vol, index): 78 | i, j, k = index 79 | score = qual_vol[i, j, k] 80 | rot = rot_vol[:, i, j, k] 81 | ori = Rotation.from_quat(rot) 82 | pos = np.array([i, j, k], dtype=np.float64) 83 | width = width_vol[i, j, k] 84 | return 
Grasp(Transform(ori, pos), width), score 85 | 86 | 87 | class GraspNeRFPlanner(object): 88 | def set_params(self, args): 89 | self.args = args 90 | self.voxel_size = 0.3 / 40 91 | self.bbox3d = [[-0.15, -0.15, -0.0503],[0.15, 0.15, 0.2497]] 92 | self.tsdf_thres_high = 0 93 | self.tsdf_thres_low = -0.85 94 | 95 | self.renderer_root_dir = self.args.renderer_root_dir 96 | tp, split, scene_type, scene_split, scene_id, background_size = args.database_name.split('/') 97 | background, size = background_size.split('_') 98 | self.split = split 99 | self.tp = tp 100 | self.downSample = float(size) 101 | tp2wh = { 102 | 'vgn_syn': (640, 360) 103 | } 104 | src_wh = tp2wh[tp] 105 | self.img_wh = (np.array(src_wh) * self.downSample).astype(int) 106 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 107 | self.K = np.array([[892.62, 0.0, 639.5], 108 | [0.0, 892.62, 359.5], 109 | [0.0, 0.0, 1.0]]) 110 | self.K[:2] = self.K[:2] * self.downSample 111 | if self.tp == 'vgn_syn': 112 | self.K[:2] /= 2 113 | self.depth_thres = { 114 | 'vgn_syn': 0.8, 115 | } 116 | 117 | if args.object_set == "graspnet": 118 | dir_name = "pile_graspnet_test" 119 | else: 120 | if self.args.scene == "pile": 121 | dir_name = "pile_pile_test_200" 122 | elif self.args.scene == "packed": 123 | dir_name = "packed_packed_test_200" 124 | elif self.args.scene == "single": 125 | dir_name = "single_single_test_200" 126 | 127 | scene_root_dir = os.path.join(self.renderer_root_dir, "data/mesh_pose_list", dir_name) 128 | self.mesh_pose_list = [i for i in sorted(os.listdir(scene_root_dir))] 129 | self.depth_root_dir = "" 130 | self.depth_list = [] 131 | 132 | def __init__(self, args=None, cfg_fn=None, debug_dir=None) -> None: 133 | default_render_cfg = { 134 | 'min_wn': 3, # working view number 135 | 'ref_pad_interval': 16, # input image size should be multiple of 16 136 | 'use_src_imgs': False, # use source images to construct cost volume or not 137 | 'cost_volume_nn_num': 3, # number of source views used in cost volume 138 | 'use_depth': True, # use colmap depth in rendering or not, 139 | } 140 | # load render cfg 141 | if cfg_fn is None: 142 | self.set_params(args) 143 | cfg = load_cfg(args.cfg_fn) 144 | else: 145 | cfg = load_cfg(cfg_fn) 146 | 147 | print(f"[I] GraspNeRFPlanner: using ckpt: {cfg['name']}") 148 | render_cfg = cfg['train_dataset_cfg'] if 'train_dataset_cfg' in cfg else {} 149 | render_cfg = {**default_render_cfg, **render_cfg} 150 | cfg['render_rgb'] = False # only for training. Disable in grasping. 
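        # (presumably the RGB branch only feeds the photometric rendering loss used
        # during training, so it is skipped when the planner only needs the
        # geometry and grasp volumes)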
151 | # load model 152 | self.net = name2network[cfg['network']](cfg) 153 | ckpt_filename = 'model_best' 154 | ckpt = torch.load(Path('src/nr/ckpt') / cfg["group_name"] / cfg["name"] / f'{ckpt_filename}.pth') 155 | self.net.load_state_dict(ckpt['network_state_dict']) 156 | self.net.cuda() 157 | self.net.eval() 158 | self.step = ckpt["step"] 159 | self.output_dir = debug_dir 160 | if debug_dir is not None: 161 | if not Path(debug_dir).exists(): 162 | Path(debug_dir).mkdir(parents=True) 163 | self.loss = VGNLoss({}) 164 | self.num_input_views = render_cfg['num_input_views'] 165 | print(f"[I] GraspNeRFPlanner: load model at step {self.step} of best metric {ckpt['best_para']}") 166 | 167 | def get_image(self, img_id, round_idx): 168 | img_filename = os.path.join(self.args.log_root_dir, "rendered_results/" + str(self.args.logdir).split("/")[-1], "rgb/%04d.png"%img_id) 169 | img = imread(img_filename)[:,:,:3] 170 | img = cv2.resize(img, self.img_wh) 171 | return np.asarray(img, dtype=np.float32) 172 | 173 | def get_pose(self, img_id): 174 | poses_ori = np.load(Path(self.renderer_root_dir) / 'camera_pose.npy') 175 | poses = [np.linalg.inv(p @ self.blender2opencv)[:3,:] for p in poses_ori] 176 | return poses[img_id].astype(np.float32).copy() 177 | 178 | def get_K(self, img_id): 179 | return self.K.astype(np.float32).copy() 180 | 181 | def get_depth_range(self,img_id, round_idx, fixed=False): 182 | if fixed: 183 | return np.array([0.2,0.8]) 184 | depth = self.get_depth(img_id, round_idx) 185 | nf = [max(0, np.min(depth)), min(self.depth_thres[self.tp], np.max(depth))] 186 | return np.array(nf) 187 | 188 | def __call__(self, test_view_id, round_idx, n_grasp, gt_tsdf): 189 | # load data for test 190 | images = [self.get_image(i, round_idx) for i in test_view_id] 191 | images = color_map_forward(np.stack(images, 0)).transpose([0, 3, 1, 2]) 192 | extrinsics = np.stack([self.get_pose(i) for i in test_view_id], 0) 193 | intrinsics = np.stack([self.get_K(i) for i in test_view_id], 0) 194 | depth_range = np.asarray([self.get_depth_range(i, round_idx, fixed = True) for i in test_view_id], dtype=np.float32) 195 | 196 | tsdf_vol, qual_vol_ori, rot_vol_ori, width_vol_ori, toc = self.core(images, extrinsics, intrinsics, depth_range, self.bbox3d) 197 | 198 | qual_vol, rot_vol, width_vol = process(tsdf_vol, qual_vol_ori, rot_vol_ori, width_vol_ori, tsdf_thres_high=self.tsdf_thres_high, tsdf_thres_low=self.tsdf_thres_low, n_grasp=n_grasp) 199 | grasps, scores, indexs = select(qual_vol.copy(), rot_vol, width_vol) 200 | grasps, scores, indexs = np.asarray(grasps), np.asarray(scores), np.asarray(indexs) 201 | 202 | if len(grasps) > 0: 203 | np.random.seed(self.args.seed + round_idx + n_grasp) 204 | p = np.random.permutation(len(grasps)) 205 | grasps = [from_voxel_coordinates(g, self.voxel_size) for g in grasps[p]] 206 | scores = scores[p] 207 | indexs = indexs[p] 208 | 209 | return grasps, scores, toc 210 | 211 | def core(self, 212 | images: np.ndarray, 213 | extrinsics: np.ndarray, 214 | intrinsics: np.ndarray, 215 | depth_range=[0.2, 0.8], 216 | bbox3d=[[-0.15, -0.15, -0.05],[0.15, 0.15, 0.25]], gt_info=None, que_id=0): 217 | """ 218 | @args 219 | images: np array of shape (3, 3, h, w), image in RGB format 220 | extrinsics: np array of shape (3, 4, 4), the transformation matrix from world to camera 221 | intrinsics: np array of shape (3, 3, 3) 222 | @rets 223 | volume, label, rot, width: np array of shape (1, 1, res, res, res) 224 | """ 225 | _, _, h, w = images.shape 226 | assert h % 32 == 0 and w % 32 == 0 227 | 
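        # (the input resolution must be a multiple of 32, most likely because the
        # MVSNet feature extractor and the 3D cost-volume U-Net downsample by
        # powers of two and their skip connections need matching spatial sizes)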
extrinsics = extrinsics[:, :3, :] 228 | que_imgs_info = build_render_imgs_info(extrinsics[que_id], intrinsics[que_id], (h, w), depth_range[que_id]) 229 | src_imgs_info = {'imgs': images, 'poses': extrinsics.astype(np.float32), 'Ks': intrinsics.astype(np.float32), 'depth_range': depth_range.astype(np.float32), 230 | 'bbox3d': np.array(bbox3d)} 231 | 232 | ref_imgs_info = src_imgs_info.copy() 233 | num_views = images.shape[0] 234 | ref_imgs_info['nn_ids'] = np.arange(num_views).repeat(num_views, 0) 235 | data = {'step': self.step , 'eval': True} 236 | if not gt_info: 237 | data['full_vol'] = True 238 | else: 239 | data['grasp_info'] = to_cuda(grasp_info_to_torch(gt_info)) 240 | data['que_imgs_info'] = to_cuda(imgs_info_to_torch(que_imgs_info)) 241 | data['src_imgs_info'] = to_cuda(imgs_info_to_torch(src_imgs_info)) 242 | data['ref_imgs_info'] = to_cuda(imgs_info_to_torch(ref_imgs_info)) 243 | 244 | with torch.no_grad(): 245 | t0 = time.time() 246 | render_info = self.net(data) 247 | t = time.time() - t0 248 | 249 | if gt_info: 250 | return self.loss(render_info, data, self.step, False) 251 | 252 | label, rot, width = render_info['vgn_pred'] 253 | 254 | return render_info['volume'].cpu().numpy(), label.cpu().numpy(), rot.cpu().numpy(), width.cpu().numpy(), t -------------------------------------------------------------------------------- /src/gd/utils/btsim.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | import pybullet 5 | from pybullet_utils import bullet_client 6 | 7 | 8 | from vgn.utils.transform import Rotation, Transform 9 | 10 | assert pybullet.isNumpyEnabled(), "Pybullet needs to be built with NumPy" 11 | 12 | 13 | class BtWorld(object): 14 | """Interface to a PyBullet physics server. 15 | 16 | Attributes: 17 | dt: Time step of the physics simulation. 18 | rtf: Real time factor. If negative, the simulation is run as fast as possible. 19 | sim_time: Virtual time elpased since the last simulation reset. 
20 | """ 21 | 22 | def __init__(self, gui=True): 23 | connection_mode = pybullet.GUI if gui else pybullet.DIRECT 24 | self.p = bullet_client.BulletClient(connection_mode) 25 | 26 | self.gui = gui 27 | self.dt = 1.0 / 240.0 28 | self.solver_iterations = 150 29 | 30 | self.bodies_urdfs = {} ## 31 | 32 | self.reset() 33 | 34 | def set_gravity(self, gravity): 35 | self.p.setGravity(*gravity) 36 | 37 | def load_urdf(self, urdf_path, pose, scale=1.0): 38 | body = Body.from_urdf(self.p, urdf_path, pose, scale) 39 | self.bodies[body.uid] = body 40 | self.bodies_urdfs[body.uid] = [urdf_path, scale] ## 41 | 42 | return body 43 | 44 | def remove_body(self, body, isRemoveObjPerGrasp=False): 45 | self.p.removeBody(body.uid) 46 | del self.bodies[body.uid] 47 | if isRemoveObjPerGrasp: ## 48 | self.bodies_urdfs.pop(body.uid) 49 | 50 | def add_constraint(self, *argv, **kwargs): 51 | """See `Constraint` below.""" 52 | constraint = Constraint(self.p, *argv, **kwargs) 53 | return constraint 54 | 55 | def add_camera(self, intrinsic, near, far): 56 | camera = Camera(self.p, intrinsic, near, far) 57 | return camera 58 | 59 | def get_contacts(self, bodyA): 60 | points = self.p.getContactPoints(bodyA.uid) 61 | contacts = [] 62 | for point in points: 63 | ### 64 | urdf = self.bodies_urdfs[point[2]][0] 65 | if str(urdf).split("/")[-1] == "plane.urdf": 66 | continue 67 | ### 68 | contact = Contact( 69 | bodyA=self.bodies[point[1]], 70 | bodyB=self.bodies[point[2]], 71 | point=point[5], 72 | normal=point[7], 73 | depth=point[8], 74 | force=point[9], 75 | ) 76 | contacts.append(contact) 77 | return contacts 78 | 79 | def reset(self): 80 | self.p.resetSimulation() 81 | self.p.setPhysicsEngineParameter( 82 | fixedTimeStep=self.dt, numSolverIterations=self.solver_iterations 83 | ) 84 | self.bodies = {} 85 | self.sim_time = 0.0 86 | 87 | def step(self): 88 | self.p.stepSimulation() 89 | self.sim_time += self.dt 90 | if self.gui: 91 | time.sleep(self.dt) 92 | 93 | def save_state(self): 94 | return self.p.saveState() 95 | 96 | def restore_state(self, state_uid): 97 | self.p.restoreState(stateId=state_uid) 98 | 99 | def close(self): 100 | self.p.disconnect() 101 | 102 | 103 | class Body(object): 104 | """Interface to a multibody simulated in PyBullet. 105 | 106 | Attributes: 107 | uid: The unique id of the body within the physics server. 108 | name: The name of the body. 109 | joints: A dict mapping joint names to Joint objects. 110 | links: A dict mapping link names to Link objects. 
111 | """ 112 | 113 | def __init__(self, physics_client, body_uid): 114 | self.p = physics_client 115 | self.uid = body_uid 116 | self.name = self.p.getBodyInfo(self.uid)[1].decode("utf-8") 117 | self.joints, self.links = {}, {} 118 | for i in range(self.p.getNumJoints(self.uid)): 119 | joint_info = self.p.getJointInfo(self.uid, i) 120 | joint_name = joint_info[1].decode("utf8") 121 | self.joints[joint_name] = Joint(self.p, self.uid, i) 122 | link_name = joint_info[12].decode("utf8") 123 | self.links[link_name] = Link(self.p, self.uid, i) 124 | 125 | @classmethod 126 | def from_urdf(cls, physics_client, urdf_path, pose, scale): 127 | body_uid = physics_client.loadURDF( 128 | str(urdf_path), 129 | pose.translation, 130 | pose.rotation.as_quat(), 131 | globalScaling=scale, 132 | ) 133 | return cls(physics_client, body_uid) 134 | 135 | def get_pose(self): 136 | pos, ori = self.p.getBasePositionAndOrientation(self.uid) 137 | return Transform(Rotation.from_quat(ori), np.asarray(pos)) 138 | 139 | def set_pose(self, pose): 140 | self.p.resetBasePositionAndOrientation( 141 | self.uid, pose.translation, pose.rotation.as_quat() 142 | ) 143 | 144 | def get_velocity(self): 145 | linear, angular = self.p.getBaseVelocity(self.uid) 146 | return linear, angular 147 | 148 | 149 | class Link(object): 150 | """Interface to a link simulated in Pybullet. 151 | 152 | Attributes: 153 | link_index: The index of the joint. 154 | """ 155 | 156 | def __init__(self, physics_client, body_uid, link_index): 157 | self.p = physics_client 158 | self.body_uid = body_uid 159 | self.link_index = link_index 160 | 161 | def get_pose(self): 162 | link_state = self.p.getLinkState(self.body_uid, self.link_index) 163 | pos, ori = link_state[0], link_state[1] 164 | return Transform(Rotation.from_quat(ori), pos) 165 | 166 | 167 | class Joint(object): 168 | """Interface to a joint simulated in PyBullet. 169 | 170 | Attributes: 171 | joint_index: The index of the joint. 172 | lower_limit: Lower position limit of the joint. 173 | upper_limit: Upper position limit of the joint. 174 | effort: The maximum joint effort. 175 | """ 176 | 177 | def __init__(self, physics_client, body_uid, joint_index): 178 | self.p = physics_client 179 | self.body_uid = body_uid 180 | self.joint_index = joint_index 181 | 182 | joint_info = self.p.getJointInfo(body_uid, joint_index) 183 | self.lower_limit = joint_info[8] 184 | self.upper_limit = joint_info[9] 185 | self.effort = joint_info[10] 186 | 187 | def get_position(self): 188 | joint_state = self.p.getJointState(self.body_uid, self.joint_index) 189 | return joint_state[0] 190 | 191 | def set_position(self, position, kinematics=False): 192 | if kinematics: 193 | self.p.resetJointState(self.body_uid, self.joint_index, position) 194 | self.p.setJointMotorControl2( 195 | self.body_uid, 196 | self.joint_index, 197 | pybullet.POSITION_CONTROL, 198 | targetPosition=position, 199 | force=self.effort, 200 | ) 201 | 202 | 203 | class Constraint(object): 204 | """Interface to a constraint in PyBullet. 205 | 206 | Attributes: 207 | uid: The unique id of the constraint within the physics server. 208 | """ 209 | 210 | def __init__( 211 | self, 212 | physics_client, 213 | parent, 214 | parent_link, 215 | child, 216 | child_link, 217 | joint_type, 218 | joint_axis, 219 | parent_frame, 220 | child_frame, 221 | ): 222 | """ 223 | Create a new constraint between links of bodies. 224 | 225 | Args: 226 | parent: 227 | parent_link: None for the base. 228 | child: None for a fixed frame in world coordinates. 
229 | 230 | """ 231 | self.p = physics_client 232 | parent_body_uid = parent.uid 233 | parent_link_index = parent_link.link_index if parent_link else -1 234 | child_body_uid = child.uid if child else -1 235 | child_link_index = child_link.link_index if child_link else -1 236 | 237 | self.uid = self.p.createConstraint( 238 | parentBodyUniqueId=parent_body_uid, 239 | parentLinkIndex=parent_link_index, 240 | childBodyUniqueId=child_body_uid, 241 | childLinkIndex=child_link_index, 242 | jointType=joint_type, 243 | jointAxis=joint_axis, 244 | parentFramePosition=parent_frame.translation, 245 | parentFrameOrientation=parent_frame.rotation.as_quat(), 246 | childFramePosition=child_frame.translation, 247 | childFrameOrientation=child_frame.rotation.as_quat(), 248 | ) 249 | 250 | def change(self, **kwargs): 251 | self.p.changeConstraint(self.uid, **kwargs) 252 | 253 | 254 | class Contact(object): 255 | """Contact point between two multibodies. 256 | 257 | Attributes: 258 | point: Contact point. 259 | normal: Normal vector from ... to ... 260 | depth: Penetration depth 261 | force: Contact force acting on body ... 262 | """ 263 | 264 | def __init__(self, bodyA, bodyB, point, normal, depth, force): 265 | self.bodyA = bodyA 266 | self.bodyB = bodyB 267 | self.point = point 268 | self.normal = normal 269 | self.depth = depth 270 | self.force = force 271 | 272 | 273 | class Camera(object): 274 | """Virtual RGB-D camera based on the PyBullet camera interface. 275 | 276 | Attributes: 277 | intrinsic: The camera intrinsic parameters. 278 | """ 279 | 280 | def __init__(self, physics_client, intrinsic, near, far): 281 | self.intrinsic = intrinsic 282 | self.near = near 283 | self.far = far 284 | self.proj_matrix = _build_projection_matrix(intrinsic, near, far) 285 | self.p = physics_client 286 | 287 | def render(self, extrinsic): 288 | """Render synthetic RGB and depth images. 289 | 290 | Args: 291 | extrinsic: Extrinsic parameters, T_cam_ref. 292 | """ 293 | # Construct OpenGL compatible view and projection matrices. 
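        # (the Z row of the extrinsic is negated to convert the OpenCV-style camera,
        # which looks down +Z, to OpenGL's convention of looking down -Z; both
        # matrices are then flattened column-major, order="F", as PyBullet expects)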
294 | gl_view_matrix = extrinsic 295 | gl_view_matrix[2, :] *= -1 # flip the Z axis 296 | gl_view_matrix = gl_view_matrix.flatten(order="F") 297 | gl_proj_matrix = self.proj_matrix.flatten(order="F") 298 | 299 | result = self.p.getCameraImage( 300 | width=self.intrinsic.width, 301 | height=self.intrinsic.height, 302 | viewMatrix=gl_view_matrix, 303 | projectionMatrix=gl_proj_matrix, 304 | renderer=pybullet.ER_TINY_RENDERER, 305 | ) 306 | 307 | rgb, z_buffer = result[2][:, :, :3], result[3] 308 | depth = ( 309 | 1.0 * self.far * self.near / (self.far - (self.far - self.near) * z_buffer) 310 | ) 311 | return rgb, depth 312 | 313 | 314 | def _build_projection_matrix(intrinsic, near, far): 315 | perspective = np.array( 316 | [ 317 | [intrinsic.fx, 0.0, -intrinsic.cx, 0.0], 318 | [0.0, intrinsic.fy, -intrinsic.cy, 0.0], 319 | [0.0, 0.0, near + far, near * far], 320 | [0.0, 0.0, -1.0, 0.0], 321 | ] 322 | ) 323 | ortho = _gl_ortho(0.0, intrinsic.width, intrinsic.height, 0.0, near, far) 324 | return np.matmul(ortho, perspective) 325 | 326 | 327 | def _gl_ortho(left, right, bottom, top, near, far): 328 | ortho = np.diag( 329 | [2.0 / (right - left), 2.0 / (top - bottom), -2.0 / (far - near), 1.0] 330 | ) 331 | ortho[0, 3] = -(right + left) / (right - left) 332 | ortho[1, 3] = -(top + bottom) / (top - bottom) 333 | ortho[2, 3] = -(far + near) / (far - near) 334 | return ortho -------------------------------------------------------------------------------- /src/nr/network/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import pyquaternion as pyq 5 | import math 6 | from network.ops import interpolate_feats 7 | import torch.nn.functional as F 8 | import torchmetrics 9 | from utils.base_utils import calc_rot_error_from_qxyzw 10 | 11 | class Loss: 12 | def __init__(self, keys): 13 | """ 14 | keys are used in multi-gpu model, DummyLoss in train_tools.py 15 | :param keys: the output keys of the dict 16 | """ 17 | self.keys=keys 18 | 19 | def __call__(self, data_pr, data_gt, step, **kwargs): 20 | pass 21 | 22 | class ConsistencyLoss(Loss): 23 | default_cfg={ 24 | 'use_ray_mask': False, 25 | 'use_dr_loss': False, 26 | 'use_dr_fine_loss': False, 27 | 'use_nr_fine_loss': False, 28 | } 29 | def __init__(self, cfg): 30 | self.cfg={**self.default_cfg,**cfg} 31 | super().__init__([f'loss_prob','loss_prob_fine']) 32 | 33 | def __call__(self, data_pr, data_gt, step, **kwargs): 34 | if 'hit_prob_self' not in data_pr: return {} 35 | prob0 = data_pr['hit_prob_nr'].detach() # qn,rn,dn 36 | prob1 = data_pr['hit_prob_self'] # qn,rn,dn 37 | if self.cfg['use_ray_mask']: 38 | ray_mask = data_pr['ray_mask'].float() # 1,rn 39 | else: 40 | ray_mask = 1 41 | ce = - prob0 * torch.log(prob1 + 1e-5) - (1 - prob0) * torch.log(1 - prob1 + 1e-5) 42 | outputs={'loss_prob': torch.mean(torch.mean(ce,-1),1)} 43 | if 'hit_prob_nr_fine' in data_pr: 44 | prob0 = data_pr['hit_prob_nr_fine'].detach() # qn,rn,dn 45 | prob1 = data_pr['hit_prob_self_fine'] # qn,rn,dn 46 | ce = - prob0 * torch.log(prob1 + 1e-5) - (1 - prob0) * torch.log(1 - prob1 + 1e-5) 47 | outputs['loss_prob_fine']=torch.mean(torch.mean(ce,-1),1) 48 | return outputs 49 | 50 | class RenderLoss(Loss): 51 | default_cfg={ 52 | 'use_ray_mask': True, 53 | 'use_dr_loss': False, 54 | 'use_dr_fine_loss': False, 55 | 'use_nr_fine_loss': False, 56 | 'disable_at_eval': True, 57 | 'render_loss_weight': 0.01 58 | } 59 | def __init__(self, cfg): 60 | 
self.cfg={**self.default_cfg,**cfg} 61 | super().__init__([f'loss_rgb']) 62 | 63 | def __call__(self, data_pr, data_gt, step, is_train=True, **kwargs): 64 | if not is_train and self.cfg['disable_at_eval']: 65 | return {} 66 | rgb_gt = data_pr['pixel_colors_gt'] # 1,rn,3 67 | rgb_nr = data_pr['pixel_colors_nr'] # 1,rn,3 68 | def compute_loss(rgb_pr,rgb_gt): 69 | loss=torch.sum((rgb_pr-rgb_gt)**2,-1) # b,n 70 | if self.cfg['use_ray_mask']: 71 | ray_mask = data_pr['ray_mask'].float() # 1,rn 72 | loss = torch.sum(loss*ray_mask,1)/(torch.sum(ray_mask,1)+1e-3) 73 | else: 74 | loss = torch.mean(loss, 1) 75 | return loss * self.cfg['render_loss_weight'] 76 | 77 | results = {'loss_rgb_nr': compute_loss(rgb_nr, rgb_gt)} 78 | if self.cfg['use_dr_loss']: 79 | rgb_dr = data_pr['pixel_colors_dr'] # 1,rn,3 80 | results['loss_rgb_dr'] = compute_loss(rgb_dr, rgb_gt) 81 | if self.cfg['use_dr_fine_loss']: 82 | results['loss_rgb_dr_fine'] = compute_loss(data_pr['pixel_colors_dr_fine'], rgb_gt) 83 | if self.cfg['use_nr_fine_loss']: 84 | results['loss_rgb_nr_fine'] = compute_loss(data_pr['pixel_colors_nr_fine'], rgb_gt) 85 | return results 86 | 87 | class DepthLoss(Loss): 88 | default_cfg={ 89 | 'depth_correct_thresh': 0.02, 90 | 'depth_loss_type': 'l2', 91 | 'depth_loss_l1_beta': 0.05, 92 | 'depth_loss_weight': 1, 93 | 'disable_at_eval': True, 94 | } 95 | def __init__(self, cfg): 96 | super().__init__(['loss_depth']) 97 | self.cfg={**self.default_cfg,**cfg} 98 | if self.cfg['depth_loss_type']=='smooth_l1': 99 | self.loss_op=nn.SmoothL1Loss(reduction='none',beta=self.cfg['depth_loss_l1_beta']) 100 | 101 | def __call__(self, data_pr, data_gt, step, is_train=True, **kwargs): 102 | if not is_train and self.cfg['disable_at_eval']: 103 | return {} 104 | if 'true_depth' not in data_gt['ref_imgs_info']: 105 | print('no') 106 | return {'loss_depth': torch.zeros([1], dtype=torch.float32, device=data_pr['pixel_colors_nr'].device)} 107 | coords = data_pr['depth_coords'] # rfn,pn,2 108 | depth_pr = data_pr['depth_mean'] # rfn,pn 109 | depth_maps = data_gt['ref_imgs_info']['true_depth'] # rfn,1,h,w 110 | rfn, _, h, w = depth_maps.shape 111 | depth_gt = interpolate_feats( 112 | depth_maps,coords,h,w,padding_mode='border',align_corners=True)[...,0] # rfn,pn 113 | 114 | # transform to inverse depth coordinate 115 | depth_range = data_gt['ref_imgs_info']['depth_range'] # rfn,2 116 | near, far = -1/depth_range[:,0:1], -1/depth_range[:,1:2] # rfn,1 117 | def process(depth): 118 | depth = torch.clamp(depth, min=1e-5) 119 | depth = -1 / depth 120 | depth = (depth - near) / (far - near) 121 | depth = torch.clamp(depth, min=0, max=1.0) 122 | return depth 123 | depth_gt = process(depth_gt) 124 | 125 | # compute loss 126 | def compute_loss(depth_pr): 127 | if self.cfg['depth_loss_type']=='l2': 128 | loss = (depth_gt - depth_pr)**2 129 | elif self.cfg['depth_loss_type']=='smooth_l1': 130 | loss = self.loss_op(depth_gt, depth_pr) 131 | 132 | if data_gt['scene_name'].startswith('gso'): 133 | depth_maps_noise = data_gt['ref_imgs_info']['depth'] # rfn,1,h,w 134 | depth_aug = interpolate_feats(depth_maps_noise, coords, h, w, padding_mode='border', align_corners=True)[..., 0] # rfn,pn 135 | depth_aug = process(depth_aug) 136 | mask = (torch.abs(depth_aug-depth_gt) 0: 170 | valid_mask = data_gt['ref_imgs_info']['sdf_gt'] != -1.0 171 | outputs['loss_sdf'] = self.loss_fn(data_gt['ref_imgs_info']['sdf_gt'] * valid_mask, data_pr['volume'][0,0] * valid_mask)[None] * self.cfg['loss_sdf_weight'] 172 | if self.cfg['loss_eikonal_weight'] > 0: 173 | 
outputs['loss_eikonal'] = (data_pr['sdf_gradient_error']).mean()[None] * self.cfg['loss_eikonal_weight'] 174 | if self.cfg['record_s']: 175 | outputs['variance'] = data_pr['s'][None] 176 | if self.cfg['loss_s_weight'] > 0: 177 | outputs['loss_s'] = torch.norm(data_pr['s']).mean()[None] * self.cfg['loss_s_weight'] 178 | return outputs 179 | 180 | class VGNLoss(Loss): 181 | default_cfg={ 182 | 'loss_vgn_weight': 1e-2, 183 | } 184 | def __init__(self, cfg): 185 | super().__init__(['loss_vgn']) 186 | self.cfg={**self.default_cfg,**cfg} 187 | 188 | def _loss_fn(self, y_pred, y, is_train): 189 | label_pred, rotation_pred, width_pred = y_pred 190 | _, label, rotations, width = y 191 | loss_qual = self._qual_loss_fn(label_pred, label) 192 | acc = self._acc_fn(label_pred, label) 193 | loss_rot_raw = self._rot_loss_fn(rotation_pred, rotations) 194 | loss_rot = label * loss_rot_raw 195 | loss_width_raw = 0.01 * self._width_loss_fn(width_pred, width) 196 | loss_width = label * loss_width_raw 197 | loss = loss_qual + loss_rot + loss_width 198 | loss_item = {'loss_vgn': loss.mean()[None] * self.cfg['loss_vgn_weight'], 199 | 'vgn_total_loss':loss.mean()[None],'vgn_qual_loss': loss_qual.mean()[None], 200 | 'vgn_rot_loss': loss_rot.mean()[None], 'vgn_width_loss':loss_width.mean()[None], 201 | 'vgn_qual_acc': acc[None]} 202 | 203 | num = torch.count_nonzero(label) 204 | angle_torch = label * self._angle_error_fn(rotation_pred, rotations, 'torch') 205 | loss_item['vgn_rot_err'] = (angle_torch.sum() / num)[None] if num else torch.zeros((1,),device=label.device) 206 | return loss_item 207 | 208 | def _qual_loss_fn(self, pred, target): 209 | return F.binary_cross_entropy(pred, target, reduction="none") 210 | 211 | def _acc_fn(self, pred, target): 212 | return 100 * (torch.round(pred) == target).float().sum() / target.shape[0] 213 | 214 | def _pr_fn(self, pred, target): 215 | p, r = torchmetrics.functional.precision_recall(torch.round(pred).to(torch.int), target.to(torch.int), 'macro',num_classes=2) 216 | return p[None] * 100, r[None] * 100 217 | 218 | def _rot_loss_fn(self, pred, target): 219 | loss0 = self._quat_loss_fn(pred, target[:, 0]) 220 | loss1 = self._quat_loss_fn(pred, target[:, 1]) 221 | return torch.min(loss0, loss1) 222 | 223 | def _angle_error_fn(self, pred, target, method='torch'): 224 | if method == 'np': 225 | def _angle_error(q1, q2, ): 226 | q1 = pyq.Quaternion(q1[[3,0,1,2]]) 227 | q1 /= q1.norm 228 | q2 = pyq.Quaternion(q2[[3,0,1,2]]) 229 | q2 /= q2.norm 230 | qd = q1.conjugate * q2 231 | qdv = pyq.Quaternion(0, qd.x, qd.y, qd.z) 232 | err = 2 * math.atan2(qdv.norm, qd.w) / math.pi * 180 233 | return min(err, 360 - err) 234 | q1s = pred.detach().cpu().numpy() 235 | q2s = target.detach().cpu().numpy() 236 | err = [] 237 | for q1,q2 in zip(q1s, q2s): 238 | err.append(min(_angle_error(q1, q2[0]), _angle_error(q1, q2[1]))) 239 | return torch.tensor(err, device = pred.device) 240 | elif method == 'torch': 241 | return calc_rot_error_from_qxyzw(pred, target) 242 | else: 243 | raise NotImplementedError 244 | 245 | def _quat_loss_fn(self, pred, target): 246 | return 1.0 - torch.abs(torch.sum(pred * target, dim=1)) 247 | 248 | def _width_loss_fn(self, pred, target): 249 | return F.mse_loss(pred, target, reduction="none") 250 | 251 | def __call__(self, data_pr, data_gt, step, is_train=True, **kwargs): 252 | return self._loss_fn(data_pr['vgn_pred'], data_gt['grasp_info'], is_train) 253 | 254 | name2loss={ 255 | 'render': RenderLoss, 256 | 'depth': DepthLoss, 257 | 'consist': ConsistencyLoss, 258 | 
'vgn': VGNLoss, 259 | 'sdf': SDFLoss 260 | } -------------------------------------------------------------------------------- /src/nr/dataset/database.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import glob 3 | import json 4 | import os 5 | import re 6 | from pathlib import Path 7 | import sys 8 | import open3d as o3d 9 | from utils.draw_utils import draw_gripper_o3d 10 | os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" 11 | import cv2 12 | import numpy as np 13 | from skimage.io import imread, imsave 14 | 15 | from asset import VGN_TRAIN_ROOT, VGN_TEST_ROOT, VGN_PILE_TRAIN_CSV, VGN_PACK_TRAIN_CSV, VGN_PILE_TEST_CSV,VGN_PACK_TEST_CSV, VGN_SDF_DIR 16 | 17 | from utils.draw_utils import draw_cube, draw_axis, draw_points, draw_gripper, draw_world_points 18 | sys.path.append("../") 19 | from gd.utils.transform import Rotation, Transform 20 | 21 | class BaseDatabase(abc.ABC): 22 | def __init__(self, database_name): 23 | self.database_name = database_name 24 | 25 | @abc.abstractmethod 26 | def get_image(self, img_id): 27 | pass 28 | 29 | @abc.abstractmethod 30 | def get_K(self, img_id): 31 | pass 32 | 33 | @abc.abstractmethod 34 | def get_pose(self, img_id): 35 | pass 36 | 37 | @abc.abstractmethod 38 | def get_img_ids(self,check_depth_exist=False): 39 | pass 40 | 41 | @abc.abstractmethod 42 | def get_bbox(self, img_id): 43 | pass 44 | 45 | @abc.abstractmethod 46 | def get_depth(self,img_id): 47 | pass 48 | 49 | @abc.abstractmethod 50 | def get_mask(self,img_id): 51 | pass 52 | 53 | @abc.abstractmethod 54 | def get_depth_range(self,img_id): 55 | pass 56 | 57 | class GraspSynDatabase(BaseDatabase): 58 | def __init__(self, database_name): 59 | super().__init__(database_name) 60 | self.debug_save_dir = Path(f'output/nrvgn/{database_name}') 61 | tp, split, scene_type, scene_split, scene_id, background_size = database_name.split('/') 62 | background, size = background_size.split('_') 63 | self.split = split 64 | self.scene_id = scene_id 65 | self.scene_type = scene_type 66 | self.tp = tp 67 | self.downSample = float(size) 68 | tp2wh = { 69 | 'vgn_syn': (640, 360) 70 | } 71 | src_wh = tp2wh[tp] 72 | self.img_wh = (np.array(src_wh) * self.downSample).astype(int) 73 | 74 | root_dir = {'test': { 75 | 'vgn_syn': VGN_TEST_ROOT, 76 | }, 77 | 'train': { 78 | 'vgn_syn': VGN_TRAIN_ROOT, 79 | }, 80 | } 81 | 82 | if tp == 'vgn_syn': 83 | self.root_dir = Path(root_dir[split][tp]) / (scene_type + "_full") / scene_split / scene_id 84 | else: 85 | raise NotImplementedError 86 | 87 | tp2len = {'grasp_syn': 256, 88 | 'vgn_syn':24} 89 | self.depth_img_ids = self.img_ids = list(range(tp2len[tp])) 90 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 91 | 92 | self.K = np.array([[ 93 | 892.62, 94 | 0.0, 95 | 639.5 96 | ], 97 | [ 98 | 0.0, 99 | 892.62, 100 | 359.5 101 | ], 102 | [ 103 | 0.0, 104 | 0.0, 105 | 1.0 106 | ]]) 107 | self.K[:2] = self.K[:2] * self.downSample 108 | if self.tp == 'vgn_syn': 109 | self.K[:2] /= 2 110 | self.poses_ori = np.load(self.root_dir / 'camera_pose.npy') 111 | self.poses = [np.linalg.inv(p @ self.blender2opencv)[:3,:] for p in self.poses_ori] 112 | 113 | 114 | self.depth_thres = { 115 | 'grasp_syn': 1.5, 116 | 'vgn_syn': 0.8, 117 | } 118 | self.fixed_depth_range = [0.2, 0.8] 119 | 120 | tp2bbox3d = {'grasp_syn': [[-0.35, -0.45, 0], 121 | [0.15, 0.05, 0.5]], 122 | 'vgn_syn': [[-0.15, -0.15, -0.05], 123 | [0.15, 0.15, 0.25]]} 124 | self.bbox3d = tp2bbox3d[tp] 125 | 126 | def get_split(self): 127 | 
return self.split 128 | 129 | def get_image(self, img_id): 130 | img_filename = os.path.join(self.root_dir, 131 | f'rgb/{img_id:04d}.png') 132 | img = imread(img_filename)[:,:,:3] 133 | img = cv2.resize(img, self.img_wh) 134 | #img[self.get_mask(img_id)] = 255 135 | return np.asarray(img, dtype=np.float32) 136 | 137 | def get_K(self, img_id): 138 | return self.K.astype(np.float32).copy() 139 | 140 | def get_pose(self, img_id): 141 | return self.poses[img_id].astype(np.float32).copy() 142 | 143 | def get_img_ids(self,check_depth_exist=False): 144 | if check_depth_exist: return self.depth_img_ids 145 | return self.img_ids 146 | 147 | def get_bbox3d(self, vis=False): 148 | if vis: 149 | img_id = 0 150 | img = self.get_image(img_id) 151 | cRb = self.poses[img_id][:3,:3] 152 | ctb = self.poses[img_id][:3,3] 153 | l = self.bbox3d[1][0] - self.bbox3d[0][0] 154 | img = draw_cube(img, cRb, ctb, self.K, length=l, bias=self.bbox3d[0]) 155 | if not self.debug_save_dir.exists(): 156 | self.debug_save_dir.mkdir(parents=True) 157 | imsave(str(self.debug_save_dir / 'bbox3d.jpg'), img) 158 | return self.bbox3d 159 | 160 | def get_bbox(self, img_id, vis=False): 161 | mask = self.get_mask(img_id,'obj') 162 | xs,ys=np.nonzero(mask) 163 | x_min,x_max=np.min(xs,0),np.max(xs,0) 164 | y_min,y_max=np.min(ys,0),np.max(ys,0) 165 | 166 | if vis: 167 | img = self.get_image(img_id) 168 | img = cv2.rectangle(img, (y_min, x_min), (y_max, x_max), (255,0,0), 2) 169 | 170 | imsave(str(self.debug_save_dir / 'box.jpg'), img) 171 | 172 | return [x_min,x_max,y_min,y_max] 173 | 174 | def _depth_existence(self,img_id): 175 | return True 176 | 177 | def get_depth(self, img_id): 178 | depth_filename = os.path.join(self.root_dir, 179 | f'depth/{img_id:04d}.exr') 180 | depth_h = cv2.imread(depth_filename, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)[:,:,0] 181 | depth_h = cv2.resize(depth_h, self.img_wh, interpolation=cv2.INTER_NEAREST) 182 | 183 | return depth_h 184 | 185 | def get_mask(self, img_id, tp='desk'): 186 | if tp == 'desk': 187 | mask = self.get_depth(img_id) < self.depth_thres[self.tp] 188 | return (mask.astype(np.bool)) 189 | elif tp == 'obj': 190 | mask_filename = os.path.join(self.root_dir, 191 | f'mask/{img_id:04d}.exr') 192 | mask = cv2.imread(mask_filename, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)[:,:,0] 193 | mask = cv2.resize(mask, self.img_wh, interpolation=cv2.INTER_NEAREST) 194 | cv2.imwrite('mask.jpg', mask * 256) 195 | 196 | return ~(mask.astype(np.bool)) 197 | else: 198 | return np.ones((self.img_wh[1], self.img_wh[0])) 199 | 200 | def get_depth_range(self,img_id, fixed=True): 201 | if fixed: 202 | return np.array(self.fixed_depth_range) 203 | depth = self.get_depth(img_id) 204 | nf = [max(0, np.min(depth)), min(self.depth_thres[self.tp], np.max(depth))] 205 | return np.array(nf) 206 | 207 | def get_sdf(self): 208 | sdf_volume = np.load( Path(VGN_SDF_DIR) / f'{self.scene_id}.npz')['grid'][0] 209 | return sdf_volume * 2 - 1 210 | 211 | class VGNSynDatabase(GraspSynDatabase): 212 | def __init__(self, database_name): 213 | super().__init__(database_name) 214 | split = self.get_split() 215 | 216 | if self.scene_type == 'packed': 217 | csv = VGN_PACK_TEST_CSV if split == 'test' else VGN_PACK_TRAIN_CSV 218 | elif self.scene_type == 'pile': 219 | csv = VGN_PILE_TEST_CSV if split == 'test' else VGN_PILE_TRAIN_CSV 220 | else: 221 | return 222 | 223 | self.df = csv 224 | self.df = self.df[self.df["scene_id"] == self.scene_id] 225 | assert len(self.df) > 0, f"empty grasping info {database_name}" 226 | 227 | def 
visualize_grasping(self, pos, rot, w, label=None, img_id=3,save_img=False): 228 | voxel_size = 0.3 / 40 229 | pts_w = pos * voxel_size 230 | width = w * voxel_size 231 | 232 | img = self.get_image(img_id) 233 | 234 | t = np.array([[-0.15, -0.15, -0.05]]).repeat(pts_w.shape[0], axis=0) 235 | pts_b = pts_w + t 236 | 237 | cRb = self.poses[img_id][:3,:3] 238 | ctb = self.poses[img_id][:3,3] # + np.array([-0.15, -0.15, -0.05]) 239 | 240 | for gid in range(pts_w.shape[0]): 241 | if label is not None and label[gid] == 0: 242 | continue 243 | btg = pts_b[gid] 244 | wRg = rot[gid] 245 | bRg = wRg 246 | bTg = np.eye(4) 247 | bTg[:3,:3] = bRg 248 | bTg[:3,3] = btg 249 | cTb = self.poses[img_id] 250 | cTg = cTb @ bTg 251 | img = draw_gripper(img, cTg[:3,:3], cTg[:3,3], self.K, width[gid], 2) 252 | img = draw_world_points(img, pts_b[gid], cRb, ctb, self.K) 253 | 254 | if save_img: 255 | save_dir = str(self.debug_save_dir / f'gripper_test-{img_id}.jpg') 256 | print("save to", save_dir) 257 | imsave(save_dir, img) 258 | return img 259 | 260 | def visualize_grasping_3d(self, pos, rot, w, label=None, voxel_size = 0.3 / 40): 261 | pts_w = pos * voxel_size 262 | width = w * voxel_size 263 | 264 | geometry = o3d.geometry.TriangleMesh() 265 | for gid in range(pts_w.shape[0]): 266 | if label is not None and label[gid] == 0: 267 | continue 268 | wRg = rot[gid] 269 | y_ccw_90 = np.array([[0, 0, -1], [0, 1,0], [1, 0, 0]]) 270 | _R = wRg @ y_ccw_90 271 | _t = pts_w[gid] 272 | 273 | geometry_gripper = draw_gripper_o3d(_R, _t, width[gid]) 274 | geometry += geometry_gripper 275 | 276 | o3d.io.write_triangle_mesh(str(self.debug_save_dir / f'gripper.ply'), geometry) 277 | 278 | def get_grasp_info(self): 279 | pos = self.df[["i","j","k"]].to_numpy(np.single) 280 | index = np.round(pos).astype(np.long) 281 | l = pos.shape[0] 282 | width = self.df[["width"]].to_numpy(np.single).reshape(l) 283 | label = self.df[["label"]].to_numpy(np.float32).reshape(l) 284 | rotations = np.empty((l, 2, 4), dtype=np.single) 285 | q = self.df[["qx","qy","qz","qw"]].to_numpy(np.single) 286 | ori = Rotation.from_quat(q) 287 | R = Rotation.from_rotvec(np.pi * np.r_[0.0, 0.0, 1.0]) 288 | rotations[:,0] = ori.as_quat() 289 | rotations[:,1] = (ori * R).as_quat() 290 | 291 | # for i in range(4): 292 | # self.visualize_grasping(pos, ori.as_matrix(), width, label, i) 293 | # exit() 294 | return (index, label, rotations, width) 295 | 296 | 297 | def parse_database_name(database_name:str)->BaseDatabase: 298 | name2database={ 299 | 'vgn_syn': VGNSynDatabase, 300 | } 301 | database_type = database_name.split('/')[0] 302 | if database_type in name2database: 303 | return name2database[database_type](database_name) 304 | else: 305 | raise NotImplementedError 306 | 307 | def get_database_split(database: BaseDatabase, split_type='val'): 308 | database_name = database.database_name 309 | if split_type.startswith('val'): 310 | splits = split_type.split('_') 311 | depth_valid = not(len(splits)>1 and splits[1]=='all') 312 | if database_name.startswith('vgn'): 313 | val_ids = database.get_img_ids()[2:24:8]# TODO 314 | train_ids = [img_id for img_id in database.get_img_ids(check_depth_exist=depth_valid) if img_id not in val_ids] 315 | else: 316 | raise NotImplementedError 317 | elif split_type.startswith('test'): 318 | splits = split_type.split('_') 319 | depth_valid = not(len(splits)>1 and splits[1]=='all') 320 | if database_name.startswith('vgn'): 321 | val_ids = database.get_img_ids()[2:24:8] + [0]# TODO 322 | train_ids = [img_id for img_id in 
database.get_img_ids(check_depth_exist=depth_valid) if img_id not in val_ids] 323 | else: 324 | raise NotImplementedError 325 | else: 326 | raise NotImplementedError 327 | return train_ids, val_ids -------------------------------------------------------------------------------- /src/nr/network/ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 6 | """3x3 convolution with padding""" 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 8 | padding=dilation, groups=groups, bias=False, dilation=dilation, padding_mode='reflect') 9 | 10 | def conv1x1(in_planes, out_planes, stride=1): 11 | """1x1 convolution""" 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False, padding_mode='reflect') 13 | 14 | def interpolate_feats(feats, points, h=None, w=None, padding_mode='zeros', align_corners=False, inter_mode='bilinear'): 15 | """ 16 | 17 | :param feats: b,f,h,w 18 | :param points: b,n,2 19 | :param h: float 20 | :param w: float 21 | :param padding_mode: 22 | :param align_corners: 23 | :param inter_mode: 24 | :return: 25 | """ 26 | b, _, ch, cw = feats.shape 27 | if h is None and w is None: 28 | h, w = ch, cw 29 | x_norm = points[:, :, 0] / (w - 1) * 2 - 1 30 | y_norm = points[:, :, 1] / (h - 1) * 2 - 1 31 | points_norm = torch.stack([x_norm, y_norm], -1).unsqueeze(1) # [srn,1,n,2] 32 | feats_inter = F.grid_sample(feats, points_norm, mode=inter_mode, padding_mode=padding_mode, align_corners=align_corners).squeeze(2) # srn,f,n 33 | feats_inter = feats_inter.permute(0,2,1) 34 | return feats_inter 35 | 36 | def masked_mean_var(feats,mask,dim=2): 37 | mask=mask.float() # b,1,n,1 38 | mask_sum = torch.clamp_min(torch.sum(mask,dim,keepdim=True),min=1e-4) # b,1,1,1 39 | feats_mean = torch.sum(feats*mask,dim,keepdim=True)/mask_sum # b,f,1,1 40 | feats_var = torch.sum((feats-feats_mean)**2*mask,dim,keepdim=True)/mask_sum # b,f,1,1 41 | return feats_mean, feats_var 42 | 43 | class ResidualBlock(nn.Module): 44 | def __init__(self, dim_in, dim_out, dim_inter=None, use_norm=True, norm_layer=nn.BatchNorm2d,bias=False): 45 | super().__init__() 46 | if dim_inter is None: 47 | dim_inter=dim_out 48 | 49 | if use_norm: 50 | self.conv=nn.Sequential( 51 | norm_layer(dim_in), 52 | nn.ReLU(True), 53 | nn.Conv2d(dim_in,dim_inter,3,1,1,bias=bias,padding_mode='reflect'), 54 | norm_layer(dim_inter), 55 | nn.ReLU(True), 56 | nn.Conv2d(dim_inter,dim_out,3,1,1,bias=bias,padding_mode='reflect'), 57 | ) 58 | else: 59 | self.conv=nn.Sequential( 60 | nn.ReLU(True), 61 | nn.Conv2d(dim_in,dim_inter,3,1,1), 62 | nn.ReLU(True), 63 | nn.Conv2d(dim_inter,dim_out,3,1,1), 64 | ) 65 | 66 | self.short_cut=None 67 | if dim_in!=dim_out: 68 | self.short_cut=nn.Conv2d(dim_in,dim_out,1,1) 69 | 70 | def forward(self, feats): 71 | feats_out=self.conv(feats) 72 | if self.short_cut is not None: 73 | feats_out=self.short_cut(feats)+feats_out 74 | else: 75 | feats_out=feats_out+feats 76 | return feats_out 77 | 78 | class AddBias(nn.Module): 79 | def __init__(self,val): 80 | super().__init__() 81 | self.val=val 82 | 83 | def forward(self,x): 84 | return x+self.val 85 | 86 | class BasicBlock(nn.Module): 87 | expansion = 1 88 | 89 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 90 | base_width=64, dilation=1, norm_layer=None): 91 | super(BasicBlock, self).__init__() 92 | if norm_layer is 
None: 93 | norm_layer = nn.BatchNorm2d 94 | if groups != 1 or base_width != 64: 95 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 96 | if dilation > 1: 97 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 98 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 99 | self.conv1 = conv3x3(inplanes, planes, stride) 100 | self.bn1 = norm_layer(planes, track_running_stats=False, affine=True) 101 | self.relu = nn.ReLU(inplace=True) 102 | self.conv2 = conv3x3(planes, planes) 103 | self.bn2 = norm_layer(planes, track_running_stats=False, affine=True) 104 | self.downsample = downsample 105 | self.stride = stride 106 | 107 | def forward(self, x): 108 | identity = x 109 | 110 | out = self.conv1(x) 111 | out = self.bn1(out) 112 | out = self.relu(out) 113 | 114 | out = self.conv2(out) 115 | out = self.bn2(out) 116 | 117 | if self.downsample is not None: 118 | identity = self.downsample(x) 119 | 120 | out += identity 121 | out = self.relu(out) 122 | 123 | return out 124 | 125 | class conv(nn.Module): 126 | def __init__(self, num_in_layers, num_out_layers, kernel_size, stride): 127 | super(conv, self).__init__() 128 | self.kernel_size = kernel_size 129 | self.conv = nn.Conv2d(num_in_layers, 130 | num_out_layers, 131 | kernel_size=kernel_size, 132 | stride=stride, 133 | padding=(self.kernel_size - 1) // 2, 134 | padding_mode='reflect') 135 | self.bn = nn.InstanceNorm2d(num_out_layers, track_running_stats=False, affine=True) 136 | 137 | def forward(self, x): 138 | return F.elu(self.bn(self.conv(x)), inplace=True) 139 | 140 | class upconv(nn.Module): 141 | def __init__(self, num_in_layers, num_out_layers, kernel_size, scale): 142 | super(upconv, self).__init__() 143 | self.scale = scale 144 | self.conv = conv(num_in_layers, num_out_layers, kernel_size, 1) 145 | 146 | def forward(self, x): 147 | x = nn.functional.interpolate(x, scale_factor=self.scale, align_corners=True, mode='bilinear') 148 | return self.conv(x) 149 | 150 | class ResUNetLight(nn.Module): 151 | def __init__(self, in_dim=3, layers=(2, 3, 6, 3), out_dim=32, inplanes=32): 152 | super(ResUNetLight, self).__init__() 153 | # layers = [2, 3, 6, 3] 154 | norm_layer = nn.InstanceNorm2d 155 | self._norm_layer = norm_layer 156 | self.dilation = 1 157 | block = BasicBlock 158 | replace_stride_with_dilation = [False, False, False] 159 | self.inplanes = inplanes 160 | self.groups = 1 # seems useless 161 | self.base_width = 64 # seems useless 162 | self.conv1 = nn.Conv2d(in_dim, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False, 163 | padding_mode='reflect') 164 | self.bn1 = norm_layer(self.inplanes, track_running_stats=False, affine=True) 165 | self.relu = nn.ReLU(inplace=True) 166 | self.layer1 = self._make_layer(block, 32, layers[0], stride=2) 167 | self.layer2 = self._make_layer(block, 64, layers[1], stride=2, 168 | dilate=replace_stride_with_dilation[0]) 169 | self.layer3 = self._make_layer(block, 128, layers[2], stride=2, 170 | dilate=replace_stride_with_dilation[1]) 171 | 172 | # decoder 173 | self.upconv3 = upconv(128, 64, 3, 2) 174 | self.iconv3 = conv(64 + 64, 64, 3, 1) 175 | self.upconv2 = upconv(64, 32, 3, 2) 176 | self.iconv2 = conv(32 + 32, 32, 3, 1) 177 | 178 | # fine-level conv 179 | self.out_conv = nn.Conv2d(32, out_dim, 1, 1) 180 | 181 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 182 | norm_layer = self._norm_layer 183 | downsample = None 184 | previous_dilation = self.dilation 185 | if dilate: 186 | self.dilation *= 
stride 187 | stride = 1 188 | if stride != 1 or self.inplanes != planes * block.expansion: 189 | downsample = nn.Sequential( 190 | conv1x1(self.inplanes, planes * block.expansion, stride), 191 | norm_layer(planes * block.expansion, track_running_stats=False, affine=True), 192 | ) 193 | 194 | layers = [] 195 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 196 | self.base_width, previous_dilation, norm_layer)) 197 | self.inplanes = planes * block.expansion 198 | for _ in range(1, blocks): 199 | layers.append(block(self.inplanes, planes, groups=self.groups, 200 | base_width=self.base_width, dilation=self.dilation, 201 | norm_layer=norm_layer)) 202 | 203 | return nn.Sequential(*layers) 204 | 205 | def skipconnect(self, x1, x2): 206 | diffY = x2.size()[2] - x1.size()[2] 207 | diffX = x2.size()[3] - x1.size()[3] 208 | 209 | x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, 210 | diffY // 2, diffY - diffY // 2)) 211 | x = torch.cat([x2, x1], dim=1) 212 | return x 213 | 214 | def forward(self, x): 215 | x = self.relu(self.bn1(self.conv1(x))) 216 | 217 | x1 = self.layer1(x) 218 | x2 = self.layer2(x1) 219 | x3 = self.layer3(x2) 220 | 221 | x = self.upconv3(x3) 222 | x = self.skipconnect(x2, x) 223 | x = self.iconv3(x) 224 | 225 | x = self.upconv2(x) 226 | x = self.skipconnect(x1, x) 227 | x = self.iconv2(x) 228 | 229 | x_out = self.out_conv(x) 230 | return x_out 231 | 232 | class ResEncoder(nn.Module): 233 | def __init__(self): 234 | super(ResEncoder, self).__init__() 235 | self.inplanes = 32 236 | filters = [32, 64, 128] 237 | layers = [2, 2, 2, 2] 238 | out_planes = 32 239 | 240 | norm_layer = nn.InstanceNorm2d 241 | self._norm_layer = norm_layer 242 | self.dilation = 1 243 | block = BasicBlock 244 | replace_stride_with_dilation = [False, False, False] 245 | self.groups = 1 246 | self.base_width = 64 247 | 248 | self.conv1 = nn.Conv2d(12, self.inplanes, kernel_size=8, stride=2, padding=2, 249 | bias=False, padding_mode='reflect') 250 | self.bn1 = norm_layer(self.inplanes, track_running_stats=False, affine=True) 251 | self.relu = nn.ReLU(inplace=True) 252 | self.layer1 = self._make_layer(block, filters[0], layers[0], stride=2) 253 | self.layer2 = self._make_layer(block, filters[1], layers[1], stride=2, 254 | dilate=replace_stride_with_dilation[0]) 255 | self.layer3 = self._make_layer(block, filters[2], layers[2], stride=2, 256 | dilate=replace_stride_with_dilation[1]) 257 | 258 | # decoder 259 | self.upconv3 = upconv(filters[2], filters[1], 3, 2) 260 | self.iconv3 = conv(filters[1]*2, filters[1], 3, 1) 261 | self.upconv2 = upconv(filters[1], filters[0], 3, 2) 262 | self.iconv2 = conv(filters[0]*2, out_planes, 3, 1) 263 | self.out_conv = nn.Conv2d(out_planes, out_planes, 1, 1) 264 | 265 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 266 | norm_layer = self._norm_layer 267 | downsample = None 268 | previous_dilation = self.dilation 269 | if dilate: 270 | self.dilation *= stride 271 | stride = 1 272 | if stride != 1 or self.inplanes != planes * block.expansion: 273 | downsample = nn.Sequential( 274 | conv1x1(self.inplanes, planes * block.expansion, stride), 275 | norm_layer(planes * block.expansion, track_running_stats=False, affine=True), 276 | ) 277 | 278 | layers = [] 279 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 280 | self.base_width, 1, norm_layer)) 281 | self.inplanes = planes * block.expansion 282 | for _ in range(1, blocks): 283 | layers.append(block(self.inplanes, planes, groups=self.groups, 284 | 
base_width=self.base_width, dilation=self.dilation, 285 | norm_layer=norm_layer)) 286 | 287 | return nn.Sequential(*layers) 288 | 289 | def skipconnect(self, x1, x2): 290 | diffY = x2.size()[2] - x1.size()[2] 291 | diffX = x2.size()[3] - x1.size()[3] 292 | 293 | x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, 294 | diffY // 2, diffY - diffY // 2)) 295 | 296 | # for padding issues, see 297 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 298 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 299 | 300 | x = torch.cat([x2, x1], dim=1) 301 | return x 302 | 303 | def forward(self, x): 304 | x = self.relu(self.bn1(self.conv1(x))) 305 | 306 | x1 = self.layer1(x) 307 | x2 = self.layer2(x1) 308 | x3 = self.layer3(x2) 309 | 310 | x = self.upconv3(x3) 311 | x = self.skipconnect(x2, x) 312 | x = self.iconv3(x) 313 | 314 | x = self.upconv2(x) 315 | x = self.skipconnect(x1, x) 316 | x = self.iconv2(x) 317 | 318 | x_out = self.out_conv(x) 319 | return x_out --------------------------------------------------------------------------------
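The feature extractors in network/ops.py follow a standard residual U-Net layout: three stride-2 residual stages, two upsampling stages with skip connections, and a 1x1 output convolution, so ResUNetLight turns a b,3,h,w image batch into a b,32,h/4,w/4 feature map that interpolate_feats can then sample at per-pixel coordinates. The short sketch below is illustrative only and not part of the repository; it assumes src/nr is on the Python path and uses dummy tensors.

# illustrative sketch, not a repository file; assumes src/nr is on PYTHONPATH
import torch
from network.ops import ResUNetLight, interpolate_feats

encoder = ResUNetLight(in_dim=3, layers=(2, 3, 6, 3), out_dim=32, inplanes=32)
imgs = torch.randn(2, 3, 288, 512)                         # b,3,h,w dummy RGB batch
feats = encoder(imgs)                                      # b,32,h/4,w/4 feature maps
pts = torch.rand(2, 100, 2) * torch.tensor([512., 288.])   # b,n,2 pixel coordinates (x,y)
pt_feats = interpolate_feats(feats, pts, h=288, w=512, padding_mode='border', align_corners=True)
print(pt_feats.shape)                                      # torch.Size([2, 100, 32]), i.e. b,n,f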