├── data
│   └── .gitkeep
├── network
│   └── .gitkeep
├── notebooks
│   ├── README.md
│   ├── imitation_expert_demo.ipynb
│   ├── generateHang.ipynb
│   ├── generateObject.ipynb
│   └── evaluation_grasp.ipynb
├── README.md
├── src
│   ├── frame.py
│   ├── evaluation.py
│   ├── functional_object.py
│   ├── data_gen_utils.py
│   ├── dataset.py
│   ├── utils.py
│   ├── vector_object.py
│   ├── simulation_utils.py
│   └── training.py
└── visualize_data.ipynb

/data/.gitkeep:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/network/.gitkeep:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/notebooks/README.md:
--------------------------------------------------------------------------------
1 | ## This folder contains the notebook files for data generation, evaluation, and sequential manipulation demos.
2 | 
3 | To run these notebooks, our [robotics codebase](https://github.com/MarcToussaint/rai) needs to be installed; an installation guide will be provided in the future.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Visual Constraints (DVC)
2 | 
3 | This is a PyTorch implementation of the paper "Deep Visual Constraints: Neural Implicit Models for Manipulation Planning from Visual Input".
4 | 
5 | [[Project Page]](https://sites.google.com/view/deep-visual-constraints) [[Paper]](https://arxiv.org/abs/2112.04812) [[Video]](https://youtu.be/r__mIGTu6Jg)
6 | 
7 | ## Requirements
8 | - PyTorch
9 | - torchvision
10 | - [h5py](https://docs.h5py.org/en/stable/quick.html)
11 | - [trimesh](https://trimsh.org/trimesh.html)
12 | - SciPy
13 | - pyglet
14 | - Matplotlib
15 | - scikit-image
16 | - tqdm
17 | - TensorBoard
18 | 
19 | 
20 | ## Instructions
21 | 1. Download [the pretrained network](https://drive.google.com/drive/folders/1RcjmbazIrejv6QT8cSJ9V62KSbA2ip5k?usp=sharing) files into the folder './network'
22 | 2. Download [the dataset](https://drive.google.com/file/d/12Ycx9oJkd8lape1SuQ2k0w75yp0IQ1pF/view?usp=sharing) and extract it to the folder './data'
23 | 3. Run 'visualize_*.ipynb' to visualize data, learned SDFs (& mesh reconstruction), PCAs on learned features, or tasks (optimized grasp/hang poses)
24 | 4. 
Run 'train_PIFO.ipynb' to train the whole framework 25 | 26 | ## Citation 27 | ``` 28 | @article{ha2022dvc, 29 | title={Deep Visual Constraints: Neural Implicit Models for Manipulation Planning from Visual Input}, 30 | author={Ha, Jung-Su and Driess, Danny and Toussaint, Marc}, 31 | journal={IEEE Robotics and Automation Letters, 2022}, 32 | year={2022} 33 | } 34 | ``` 35 | -------------------------------------------------------------------------------- /src/frame.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from .utils import * 5 | from .functional_object import * 6 | from .feature import * 7 | 8 | from skimage import measure 9 | import matplotlib.pyplot as plt 10 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection 11 | 12 | 13 | class Frame(nn.Module): 14 | def __init__(self, **kwargs): 15 | super(Frame, self).__init__() 16 | 17 | self.backbone = kwargs.get("backbone", None) 18 | self.sdf_head = kwargs.get("sdf_head", None) 19 | self.grasp_head = kwargs.get("grasp_head", None) 20 | self.hanging_head = kwargs.get("hanging_head", None) 21 | 22 | def build_backbone(self, **C): 23 | self.backbone = FunctionalObject(**C) 24 | 25 | def build_sdf_head(self, width): 26 | if width is None: 27 | self.sdf_head = nn.Identity() 28 | else: 29 | layer_list = [nn.Linear(self.backbone.out_dim, width[0]), nn.ReLU(inplace=True)] 30 | for i in range(len(width)-1): 31 | layer_list.extend([ 32 | nn.Linear(width[i], width[i+1]), nn.ReLU(inplace=True) 33 | ]) 34 | layer_list.append(nn.Linear(width[-1], 1)) 35 | self.sdf_head = nn.Sequential(*layer_list) 36 | 37 | 38 | def build_keypoint_head(self, name, width, key_points, sdf_object=False, train_pts=False): 39 | num_points = key_points.shape[0] 40 | in_dim = num_points if sdf_object else self.backbone.out_dim*num_points 41 | layer_list = [ 42 | nn.Linear(in_dim, width[0]), nn.ReLU(inplace=True) 43 | ] 44 | for i in range(len(width)-1): 45 | layer_list.extend([ 46 | nn.Linear(width[i], width[i+1]), nn.ReLU(inplace=True) 47 | ]) 48 | layer_list.append(nn.Linear(width[-1], 1)) 49 | 50 | setattr(self, name+'_head', nn.Sequential(*layer_list)) 51 | head = getattr(self, name+'_head') 52 | head.key_points = nn.parameter.Parameter( 53 | key_points.view(1,1,num_points,3), requires_grad=train_pts) 54 | head.name = name 55 | 56 | def extract_mesh(self, 57 | images=None, 58 | projection_matrices=None, 59 | center=[0,0,0], 60 | scale=.15, 61 | num_grid=50, 62 | sdf_scale=10., 63 | delta=0., 64 | draw=True, 65 | return_sdf=False): 66 | assert self.sdf_head is not None, "sdf_head is not defined!" 
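        # NOTE: extract_mesh() evaluates the learned SDF head on a regular
        # num_grid**3 grid of half-width `scale` around `center` (processed in
        # chunks of num_grid**2 points under torch.no_grad()), then runs
        # skimage's marching_cubes at iso-level `delta` to reconstruct a mesh.
        # If `images` is None, the image features cached by the most recent
        # backbone.encode() call are reused instead of re-encoding.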
67 | 68 | 69 | if images is None: 70 | images = self.backbone.images 71 | else: 72 | self.backbone.encode(images, projection_matrices) 73 | 74 | device = images.device 75 | num_views = images.shape[1] 76 | 77 | F_sdf = SDF_Feature(self) 78 | 79 | dx = center[0]+scale*torch.linspace(-1, 1, num_grid, device=device) 80 | dy = center[1]+scale*torch.linspace(-1, 1, num_grid, device=device) 81 | dz = center[2]+scale*torch.linspace(-1, 1, num_grid, device=device) 82 | grid_x, grid_y, grid_z = torch.meshgrid(dx, dy, dz) 83 | grid_x, grid_y, grid_z = grid_x.flatten(), grid_y.flatten(), grid_z.flatten() 84 | pts = torch.stack([grid_x, grid_y, grid_z], dim=1).unsqueeze(0) # (1, num_grid**3, 3) 85 | 86 | 87 | L = pts.shape[1] 88 | N = num_grid**2 89 | mu = np.zeros((L,1)) 90 | for i in range(L//N): 91 | with torch.no_grad(): 92 | mu[i*N:(i+1)*N] = F_sdf(pts[:,i*N:(i+1)*N,:])[0].view(-1, 1).detach().cpu().numpy()/sdf_scale 93 | mu = mu.reshape((num_grid, num_grid, num_grid)) 94 | vertices, faces, normals, _ = measure.marching_cubes(mu, delta) 95 | vertices = np.array(center).reshape(1,3)-scale + vertices * 2*scale/(num_grid-1) 96 | if draw: 97 | mesh = Poly3DCollection(vertices[faces], 98 | facecolors='w', 99 | edgecolors='k', 100 | linewidths=1, 101 | alpha=0.5) 102 | 103 | fig = plt.figure() 104 | ax = plt.subplot(111, projection='3d') 105 | ax.set_xlim([center[0]-scale, center[0]+scale]) 106 | ax.set_ylim([center[1]-scale, center[1]+scale]) 107 | ax.set_zlim([center[2]-scale, center[2]+scale]) 108 | ax.set_xlabel('x') 109 | ax.set_ylabel('y') 110 | ax.set_zlabel('z') 111 | ax.grid() 112 | ax.add_collection3d(mesh) 113 | plt.tight_layout() 114 | 115 | render_images = images.cpu().squeeze(0) 116 | fig = plt.figure() 117 | for i in range(num_views): 118 | ax = plt.subplot(np.ceil(num_views/5),5,i+1) 119 | ax.imshow(render_images[i,...].permute(1,2,0)) 120 | plt.tight_layout() 121 | plt.show() 122 | 123 | if return_sdf: 124 | return vertices, faces, normals, mu.flatten() 125 | else: 126 | return vertices, faces, normals -------------------------------------------------------------------------------- /src/evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | import h5py 4 | 5 | import time 6 | import sys 7 | sys.path.append('../PIFO/rai-fork/rai/ry') 8 | import libry as ry 9 | import numpy as np 10 | 11 | 12 | def inCollision(C, frames1, frames2): 13 | for f1 in frames1: 14 | for f2 in frames2: 15 | y = -C.evalFeature(ry.FS.pairCollision_negScalar, [f1, f2])[0] 16 | if y < 0: 17 | return True 18 | return False 19 | 20 | 21 | def signed_distance(C, frames1, frames2): 22 | dist = np.inf 23 | for f1 in frames1: 24 | for f2 in frames2: 25 | y = -C.evalFeature(ry.FS.pairCollision_negScalar, [f1, f2])[0] 26 | dist = min(dist, y) 27 | return dist 28 | 29 | 30 | ################################################################################ 31 | 32 | def isFeasible_grasp(C): 33 | C2 = ry.Config() 34 | C2.copy(C) 35 | S = C2.simulation(ry.SimulatorEngine.bullet, 0) 36 | S.setGravity(np.zeros(3)) 37 | S.closeGripper("gripper", objFrameName="mug") 38 | tau = 0.01 39 | for _ in range(200): 40 | S.step([], tau, ry.ControlMode.none) 41 | f = S.getGripperIsGrasping("gripper") 42 | S = 0 43 | return f 44 | 45 | def isFeasible_hang(C, mugFrames, tau=0): 46 | mug = C.frame('mug') 47 | mug_pos0 = np.array(mug.getPosition()) 48 | 49 | vectors = np.array([[0,0,1.], 50 | [0,0,-1.], 51 | [0,1.,0], 52 | [0,-1.,0]]) 53 | 54 | 55 | for v in 
vectors: 56 | mug_pos = mug_pos0.copy() 57 | cum_dist = 0. 58 | while True: 59 | time.sleep(tau) 60 | dist = signed_distance(C, mugFrames, ['hook']) 61 | if dist < 0.: 62 | mug.setPosition(mug_pos0) 63 | break 64 | 65 | cum_dist += max(abs(dist), 1e-4) 66 | mug_pos = mug_pos0 + cum_dist*v 67 | if cum_dist > 0.3: 68 | mug.setPosition(mug_pos0) 69 | return False 70 | mug.setPosition(mug_pos) 71 | 72 | return True 73 | 74 | ########################################################################################## 75 | 76 | def check_grasp_feasibility(poses, mesh_coll_filename, mass, com): 77 | """ 78 | Args: 79 | poses: (N, 7) poses 80 | mesh_coll_filename, mass, com 81 | Returns: 82 | (N) feasibility 83 | """ 84 | N = poses.shape[0] 85 | feasibility = np.zeros(N) 86 | 87 | C = ry.Config() 88 | C.addFile('gripperWorld.g') 89 | gripperFrames = ['gripper', 90 | 'L_finger', 'L_finger_1', 'L_finger_2', 'L_finger_3', 91 | 'R_finger', 'R_finger_1', 'R_finger_2', 'R_finger_3'] 92 | 93 | mugPos = np.array([0,0,1.]) 94 | C.addMeshFrame(mesh_coll_filename, 'mug', mass=mass, com=com).setPosition(mugPos) 95 | # mug's position has changed because of com 96 | mugFrames = [] 97 | for fname in C.getFrameNames(): 98 | if fname[:3] == 'mug': 99 | C.frame(fname).setContact(1) 100 | mugFrames.append(fname) 101 | 102 | gripper = C.frame('gripper') 103 | gripperCenter = C.frame('gripperCenter') 104 | 105 | m_gripperCenter = C.addFrame('m_gripperCenter').setShape(ry.ST.marker, [0.]) 106 | m_gripper = C.addFrame('m_gripper', 'm_gripperCenter').setShape(ry.ST.marker, [0.]) 107 | m_gripper.setRelativeAffineMatrix(np.linalg.inv(gripperCenter.getRelativeAffineMatrix())) 108 | 109 | S = C.simulation(ry.SimulatorEngine.bullet, 0) 110 | fInit = C.getFrameState() 111 | 112 | for n in range(N): 113 | pose = poses[n] 114 | 115 | m_gripperCenter.setPosition(pose[:3]+mugPos-com)\ 116 | .setQuaternion(pose[3:]) 117 | gripper.setAffineMatrix(m_gripper.getAffineMatrix()) 118 | feasibility[n] = (not inCollision(C, gripperFrames, mugFrames)) and isFeasible_grasp(C) 119 | C.setFrameState(fInit) 120 | 121 | return feasibility 122 | 123 | def check_hang_feasibility(poses, mesh_coll_filename, mass, com): 124 | """ 125 | Args: 126 | poses: (N, 7) poses 127 | mesh_coll_filename, mass, com 128 | Returns: 129 | (N) feasibility 130 | """ 131 | N = poses.shape[0] 132 | feasibility = np.zeros(N) 133 | tau = 0. 
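    # NOTE: each candidate pose below is applied by rigidly placing the mug
    # relative to the hook via T_hook @ inv(T); a pose counts as feasible only
    # if the mug is collision-free on the hook AND isFeasible_hang() cannot
    # slide it off the hook along any of its four test directions.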
134 | 135 | for n in range(N): 136 | C = ry.Config() 137 | hook = C.addFrame('hook').setShape(ry.ST.capsule, [.15*2, .002]).setColor([.4, .7, .4]).setPosition([0,0,0.8]).setRotationRad(np.pi/2, 0, 1, 0) 138 | hook_len = hook.info()['size'][0] 139 | hook_radii = hook.info()['size'][1] 140 | T_hook = hook.getAffineMatrix() 141 | 142 | 143 | mug =C.addMeshFrame(mesh_coll_filename, 'mug', mass=mass, com=com).setPosition([0,0,0.5]) 144 | # mug's position has changed because of com 145 | mugFrames = [] 146 | for fname in C.getFrameNames(): 147 | if fname[:3] == 'mug': 148 | C.frame(fname).setContact(1) 149 | mugFrames.append(fname) 150 | 151 | 152 | pose = poses[n] 153 | T = mug.setPosition(pose[:3]-com).setQuaternion(pose[3:]).getAffineMatrix() 154 | T_mug = T_hook@np.linalg.inv(T) 155 | mug.setAffineMatrix(T_mug) 156 | feasibility[n] = (not inCollision(C, mugFrames, ['hook']))\ 157 | and isFeasible_hang(C, mugFrames, tau) 158 | 159 | # S = C.simulation(ry.SimulatorEngine.bullet, 0) 160 | # tau = 0.01 161 | # for _ in range(500): 162 | # # time.sleep(tau) 163 | # S.step([], tau, ry.ControlMode.none) 164 | # feasibility[n] = (mug.getPosition()[2] > 0.5) 165 | 166 | return feasibility -------------------------------------------------------------------------------- /src/functional_object.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torchvision 4 | import numpy as np 5 | from .utils import * 6 | from .dataset import * 7 | from torchvision import transforms 8 | 9 | 10 | class ResBlock(nn.Module): 11 | def __init__(self, in_channel, out_channel): 12 | super().__init__() 13 | 14 | self.residual = nn.Sequential( 15 | nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False), 16 | nn.BatchNorm2d(out_channel), 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1, bias=False), 19 | nn.BatchNorm2d(out_channel) 20 | ) 21 | 22 | if out_channel != in_channel: 23 | self.shortcut = nn.Sequential( 24 | nn.Conv2d(in_channel, out_channel, 1, bias=False), 25 | nn.BatchNorm2d(out_channel) 26 | ) 27 | else: 28 | self.shortcut = nn.Identity() 29 | 30 | self.relu = nn.ReLU(inplace=True) 31 | 32 | 33 | def forward(self, x): 34 | return self.relu(self.residual(x)+self.shortcut(x)) 35 | 36 | 37 | class ResUNet(nn.Module): 38 | def __init__(self, out_dim=64): 39 | super().__init__() 40 | 41 | net = torchvision.models.resnet.resnet34(pretrained=True) 42 | self.pool = net.maxpool 43 | self.layer0 = nn.Sequential(net.conv1, net.bn1, net.relu) 44 | self.layer1 = net.layer1 45 | self.layer2 = net.layer2 46 | self.layer3 = net.layer3 47 | self.layer4 = net.layer4 48 | 49 | channels = [net.bn1.num_features] 50 | for l in [self.layer1, self.layer2, self.layer3, self.layer4]: 51 | channels.append(l[-1].bn2.num_features) 52 | # [64, 64, 128, 256, 512] 53 | 54 | depth = len(channels)-1 55 | self.up_convs = nn.ModuleList( 56 | [nn.Sequential(nn.ConvTranspose2d(channels[i+1], channels[i], 2, 2, bias=False), 57 | nn.BatchNorm2d(channels[i]), 58 | nn.ReLU(inplace=True)) for i in reversed(range(depth))] 59 | ) 60 | self.up_blocks = nn.ModuleList( 61 | [ResBlock(channels[i]*2, channels[i]) for i in reversed(range(depth))] 62 | ) 63 | self.out_layer = nn.Conv2d(channels[0], out_dim, 1) 64 | 65 | 66 | def forward(self, image): 67 | """ 68 | Args: 69 | image: (BxC_inxHxW) tensor of input image 70 | Returns: 71 | list of (BxC_outxHxW) tensors of output features 72 | """ 73 | down_features = [] 74 
| x = self.layer0(image) 75 | down_features.append(x) 76 | x = self.layer1(self.pool(x)) 77 | down_features.append(x) 78 | x = self.layer2(x) 79 | down_features.append(x) 80 | x = self.layer3(x) 81 | down_features.append(x) 82 | x = self.layer4(x) 83 | 84 | for up_conv, up_block, down_feature in zip(self.up_convs, self.up_blocks, down_features[::-1]): 85 | x = up_conv(x) 86 | x = torch.cat([x, down_feature], dim=1) 87 | x = up_block(x) 88 | 89 | return self.out_layer(x) 90 | 91 | class FunctionalObject(nn.Module): 92 | def __init__(self, **C): 93 | super().__init__() 94 | self.C = {} 95 | self.C['FEAT_IMG'] = 64 96 | self.C['FEAT_UVZ'] = 32 97 | self.C['WIDTH_LIFTER'] = [256, 128] 98 | self.C['PIXEL_ALIGNED'] = True 99 | self.C.update(C) 100 | self.pixel_aligned = self.C['PIXEL_ALIGNED'] 101 | self.build_modules() 102 | 103 | def build_modules(self, pixel_aligned=True): 104 | if self.pixel_aligned: 105 | self.image_encoder = ResUNet(out_dim=self.C['FEAT_IMG']) 106 | else: 107 | self.image_encoder = torchvision.models.resnet.resnet34(pretrained=True) 108 | num_channels = self.image_encoder.layer4[-1].bn2.num_features 109 | self.image_encoder.fc = nn.Linear(num_channels, self.C['FEAT_IMG']) 110 | 111 | self.uvz_encoder = nn.Sequential( 112 | nn.Linear(3, self.C['FEAT_UVZ']), 113 | nn.ReLU(inplace=True) 114 | ) 115 | 116 | lifter_layers = [ 117 | nn.Linear(self.C['FEAT_IMG']+self.C['FEAT_UVZ'], self.C['WIDTH_LIFTER'][0]), 118 | nn.ReLU(inplace=True) 119 | ] 120 | for i in range(len(self.C['WIDTH_LIFTER']) - 1): 121 | lifter_layers.extend([ 122 | nn.Linear(self.C['WIDTH_LIFTER'][i], self.C['WIDTH_LIFTER'][i+1]), 123 | nn.ReLU(inplace=True) 124 | ]) 125 | self.feature_lifter = nn.Sequential(*lifter_layers) 126 | self.out_dim = self.C['WIDTH_LIFTER'][-1] 127 | 128 | mean=[0.485, 0.456, 0.406] 129 | std=[0.229, 0.224, 0.225] 130 | self.normalizer = transforms.Normalize(mean=mean, std=std) 131 | self.unnormalizer = UnNormalize(mean=mean, std=std) 132 | 133 | 134 | 135 | def forward(self, points, images, projection_matrices): 136 | """ 137 | Args: 138 | points: (B, N, 3) world coordinates of points 139 | images: (B, num_views, C, H, W) input images 140 | projections: (B, num_views, 4, 4) projection matrices for each image 141 | Returns: 142 | (B, num_view, N, Feat) features for each point 143 | """ 144 | 145 | self.encode(images, projection_matrices) 146 | self.features = self.query(points) # (B, num_view, N, Feat) 147 | return self.features 148 | 149 | def encode(self, images, projection_matrices): 150 | """ 151 | Args: 152 | images: (B, num_views, C, H, W) input images 153 | projection_matrices: (B, num_views, 4, 4) projection matrices for each image 154 | """ 155 | if images is None: 156 | return 157 | 158 | self.images = images.clone() 159 | images = self.normalizer(images) 160 | 161 | B, self.num_views, C, H, W = images.shape 162 | images = images.view(B*self.num_views, C, H, W) # (B * num_views, C, H, W) 163 | self.img_features = self.image_encoder(images) # (B * num_views, feat_img, H, W) 164 | 165 | self.projection_matrices = projection_matrices.view(B*self.num_views, 4, 4) 166 | # (B * num_views, 4, 4) 167 | 168 | def query(self, points): 169 | """ 170 | Query the network predictions for each point - should be called after filtering. 
171 | Args: 172 | points: (B, N, 3) world space coordinates of points 173 | Returns: 174 | (B, num_view, N, Feat) features for each point 175 | """ 176 | 177 | B, N, _ = points.shape 178 | points = torch.repeat_interleave(points, repeats=self.num_views, dim=0) 179 | # (B * num_views, N, 3) 180 | uv, z = perspective(points, self.projection_matrices) 181 | # (B * num_views, N, 2), (B * num_views, N, 1) 182 | 183 | if self.pixel_aligned: 184 | img_feat = index(self.img_features, uv) # (B * num_views, N, Feat_img) 185 | else: 186 | img_feat = self.img_features.unsqueeze(1).repeat(1,N,1) # (B * num_views, N, Feat_img) 187 | 188 | 189 | uvz_feat = self.uvz_encoder(torch.cat([uv,z], dim=2)) 190 | feat_all = torch.cat([img_feat, uvz_feat], dim=2).view(B, self.num_views, N, -1) 191 | # (B, num_views, N, Feat_all) 192 | 193 | return self.feature_lifter(feat_all) # (B, num_view, N, Feat) -------------------------------------------------------------------------------- /visualize_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "id": "7711b00c", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from src.utils import *\n", 11 | "import torch\n", 12 | "\n", 13 | "import h5py\n", 14 | "\n", 15 | "data_hdf5 = h5py.File('data/train_batch.hdf5', mode='r')\n", 16 | "filename = data_hdf5['object/mesh_filename'][:]\n", 17 | "grasp_pose = data_hdf5['grasp/pose'][:] # (len, N, 7)\n", 18 | "hang_pose = data_hdf5['hang/pose'][:] # (len, N, 7)" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 4, 24 | "id": "b6d9d203", 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "ename": "KeyboardInterrupt", 29 | "evalue": "", 30 | "output_type": "error", 31 | "traceback": [ 32 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 33 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 34 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;31m# a = view_scene_hang(torch.Tensor(hang_pose[i,::10]), 'data/meshes_coll/'+filename[i].decode())\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0ma\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshow\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mviewer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'gl'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0;31m# pose = grasp_pose[:,::10]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;31m# a = view_scene_grasp_batch(pose, np.ones(pose.shape[:2]), [i.decode() for i in filename])\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 35 | "\u001b[0;32m~/miniconda3/envs/torchSource/lib/python3.9/site-packages/trimesh/scene/scene.py\u001b[0m in \u001b[0;36mshow\u001b[0;34m(self, viewer, **kwargs)\u001b[0m\n\u001b[1;32m 1095\u001b[0m \u001b[0;31m# if pyglet is not available\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1096\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mviewer\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mSceneViewer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1097\u001b[0;31m \u001b[0;32mreturn\u001b[0m 
\u001b[0mSceneViewer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1098\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mviewer\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'notebook'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1099\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mviewer\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mscene_to_notebook\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 36 | "\u001b[0;32m~/miniconda3/envs/torchSource/lib/python3.9/site-packages/trimesh/viewer/windowed.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, scene, smooth, flags, visible, resolution, start_loop, callback, callback_period, caption, fixed, offset_lines, line_settings, background, window_conf, profile, **kwargs)\u001b[0m\n\u001b[1;32m 197\u001b[0m callback_period)\n\u001b[1;32m 198\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mstart_loop\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 199\u001b[0;31m \u001b[0mpyglet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 200\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 201\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_redraw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 37 | "\u001b[0;32m~/miniconda3/envs/torchSource/lib/python3.9/site-packages/pyglet/app/__init__.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m()\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 106\u001b[0m \"\"\"\n\u001b[0;32m--> 107\u001b[0;31m \u001b[0mevent_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 108\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 38 | "\u001b[0;32m~/miniconda3/envs/torchSource/lib/python3.9/site-packages/pyglet/app/base.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhas_exit\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0mtimeout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0midle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 168\u001b[0;31m \u001b[0mplatform_event_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 169\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 170\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_running\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 39 | "\u001b[0;32m~/miniconda3/envs/torchSource/lib/python3.9/site-packages/pyglet/app/xlib.py\u001b[0m in 
\u001b[0;36mstep\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[0;31m# If no devices were ready, wait until one gets ready\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mpending_devices\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 113\u001b[0;31m \u001b[0mpending_devices\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mselect\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mselect\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_select_devices\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 114\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mpending_devices\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 40 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 41 | ] 42 | } 43 | ], 44 | "source": [ 45 | "\n", 46 | "for i in range(grasp_pose.shape[0]):\n", 47 | " a = view_scene_grasp(torch.Tensor(grasp_pose[i,::10]), 'data/meshes_coll/'+filename[i].decode())\n", 48 | "# a = view_scene_hang(torch.Tensor(hang_pose[i,::10]), 'data/meshes_coll/'+filename[i].decode())\n", 49 | " a.show(viewer='gl')" 50 | ] 51 | } 52 | ], 53 | "metadata": { 54 | "kernelspec": { 55 | "display_name": "torchSource", 56 | "language": "python", 57 | "name": "torchsource" 58 | }, 59 | "language_info": { 60 | "codemirror_mode": { 61 | "name": "ipython", 62 | "version": 3 63 | }, 64 | "file_extension": ".py", 65 | "mimetype": "text/x-python", 66 | "name": "python", 67 | "nbconvert_exporter": "python", 68 | "pygments_lexer": "ipython3", 69 | "version": "3.9.7" 70 | } 71 | }, 72 | "nbformat": 4, 73 | "nbformat_minor": 5 74 | } 75 | -------------------------------------------------------------------------------- /notebooks/imitation_expert_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 21, 6 | "id": "13d330ae-bbb8-4c99-be45-337dcddaca10", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import os\n", 11 | "from os import path\n", 12 | "import h5py\n", 13 | "import time\n", 14 | "import sys\n", 15 | "\n", 16 | "sys.path.append('../../../rai-fork/rai/ry')\n", 17 | "import libry as ry\n", 18 | "import numpy as np\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "\n", 21 | "from src.simulation_utils import *" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 22, 27 | "id": "17220cd8-952a-418a-ab7c-3dc4f5556ffd", 28 | "metadata": {}, 29 | "outputs": [ 30 | { 31 | "name": "stdout", 32 | "output_type": "stream", 33 | "text": [ 34 | "0.10730413538236316\n" 35 | ] 36 | } 37 | ], 38 | "source": [ 39 | "load_dir = '../dataGeneration_vF/data/object'\n", 40 | "ind = 2\n", 41 | "\n", 42 | "filename = [fn for fn in os.listdir(load_dir) if fn.endswith('.hdf5')][ind]\n", 43 | " \n", 44 | "data_obj = h5py.File(path.join(load_dir, filename), mode='r')\n", 45 | "filename = data_obj['filename'][()].decode()\n", 46 | "mesh_coll_name = path.join('data/meshes_coll', filename)\n", 47 | 
"size = data_obj['size'][()]\n", 48 | "print(size)\n", 49 | "mass = data_obj['mass'][()]\n", 50 | "com = data_obj['com'][:]\n", 51 | "data_obj.close()\n", 52 | "\n", 53 | "C = ry.Config()\n", 54 | "C.addFile('world4.g')\n", 55 | "mug = C.addMeshFrame(mesh_coll_name, 'mug', mass=mass, com=com)\n", 56 | "mug.setPosition([0.,-0.2,1.0]).setQuaternion([1,0,0,1.])\n", 57 | "C.selectJoints([j for j in C.getJointNames() if j not in ['L_finger']])\n", 58 | "\n", 59 | "S = C.simulation(ry.SimulatorEngine.bullet, 1)\n", 60 | "for _ in range(500):\n", 61 | " S.step([], 0.01, ry.ControlMode.none)\n", 62 | "fInit = C.getFrameState()\n", 63 | "qInit = C.getJointState()" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 23, 69 | "id": "6b7e6a48-d8ba-4b29-93fc-92aa3eec9c92", 70 | "metadata": { 71 | "tags": [] 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "C.selectJoints([j for j in C.getJointNames() if j not in ['L_finger']])\n", 76 | "V = ry.ConfigurationViewer()\n", 77 | "V.setConfiguration(C)\n", 78 | "\n", 79 | "C.addFrame('graspedObj', 'gripperCenter').setRelativePosition([0,0,-.1])\n", 80 | "\n", 81 | "stepsPerPhase = 10\n", 82 | "komo = C.komo(3., stepsPerPhase, 5., 2, False)\n", 83 | "komo.verbose(3)\n", 84 | "# komo.animateOptimization(True)\n", 85 | "\n", 86 | "Sk = [[1., 3.], ry.SY.stable, [\"gripper\", \"mug\"]]\n", 87 | "komo.addSkeleton(Sk)\n", 88 | "\n", 89 | "komo.add_qControlObjective([], 2)\n", 90 | "komo.add_qControlObjective([], 1)\n", 91 | "\n", 92 | "colls = [\"gripper_coll\", \n", 93 | " \"L_finger_coll1\", \"L_finger_coll2\",\n", 94 | " \"R_finger_coll1\", \"R_finger_coll2\"]\n", 95 | "\n", 96 | "\n", 97 | "komo.addObjective([1.], ry.FS.positionRel, [\"gripperCenter\", \"mug\"], ry.OT.eq, [1e0], target=[0.1,0,0])\n", 98 | "komo.addObjective([1.], ry.FS.vectorZ, [\"gripperCenter\"], ry.OT.eq, [1e0], target=[0,1/np.sqrt(2),1/np.sqrt(2)])\n", 99 | "komo.addObjective([1.], ry.FS.scalarProductXZ, [\"gripperCenter\", \"world\"], ry.OT.eq, [1e0])\n", 100 | "komo.addObjective([.7, 1.], ry.FS.quaternion, [\"gripperCenter\"], ry.OT.eq, [1e0], order=1)\n", 101 | "komo.addObjective([.7, 1.], ry.FS.positionRel, [\"mug\", \"gripperCenter\"], ry.OT.eq, [1e1], target=[0,0,-.1], order=2)\n", 102 | "komo.addObjective([1.], ry.FS.qItself, C.getJointNames(), ry.OT.eq, [1e1], order=1)\n", 103 | "\n", 104 | "\n", 105 | "\n", 106 | "komo.addObjective([1., 2.], ry.FS.vectorZ, [\"mug\"], ry.OT.eq, [1e0], order=1)\n", 107 | "komo.addObjective([2., 3.], ry.FS.position, [\"gripperCenter\"], ry.OT.eq, [1e0], target=[.15, .3, 1.3])\n", 108 | "\n", 109 | "komo.addObjective([3.], ry.FS.vectorZ, [\"mug\"], ry.OT.eq, [1e0], target=[1/np.sqrt(2), 0, -1/np.sqrt(2)])\n", 110 | "komo.addObjective([3.], ry.FS.qItself, C.getJointNames(), ry.OT.eq, [1e1], order=1)\n", 111 | "\n", 112 | "komo.optimize(0.1)\n", 113 | "\n", 114 | "V=komo.view()" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 25, 120 | "id": "e3105eea-b2b1-45cc-88e4-65ca4596cb52", 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "def stepTo(S, q, tau=0.1):\n", 125 | " N = int(tau/0.01)\n", 126 | " dq = q - S.get_q()\n", 127 | " for i in range(N):\n", 128 | " S.step(S.get_q()+dq/N, tau/N, ry.ControlMode.position)\n", 129 | " time.sleep(tau/N)\n", 130 | " \n", 131 | "RealWorld = ry.Config()\n", 132 | "RealWorld.addFile('world4.g')\n", 133 | "mug = RealWorld.addMeshFrame(mesh_coll_name, 'mug', mass=mass, com=com)\n", 134 | "RealWorld.setFrameState(fInit)\n", 135 | "for i in range(10):\n", 
136 | " pos = mug.getPosition() + np.array([0,0,.05+.05*i])\n", 137 | " b = RealWorld.addFrame('ball_'+str(i)).setShape(ry.ST.sphere, [0.015]).setColor([.8,.6,.6])\n", 138 | " b.setPosition(pos).setMass(0.00001)#.addAttribute('friction',0.0)\n", 139 | "RealWorld.selectJoints([j for j in RealWorld.getJointNames() if j not in ['L_finger']])\n", 140 | "camera_name_list = ['camera_'+str(i) for i in range(4)]\n", 141 | "camera = RealWorld.cameraView()\n", 142 | "for camera_name in camera_name_list: \n", 143 | " camera.addSensorFromFrame(camera_name)\n", 144 | "\n", 145 | "t = 0\n", 146 | "save_dir = 'screenshots/imitation_ref/'\n", 147 | "\n", 148 | "S = RealWorld.simulation(ry.SimulatorEngine.bullet, 4)\n", 149 | "tau = 0.1\n", 150 | "for _ in range(500):\n", 151 | " S.step([], 0.01, ry.ControlMode.none)\n", 152 | " \n", 153 | "for t in range(stepsPerPhase):\n", 154 | " C.setFrameState(komo.getConfiguration(t))\n", 155 | " q = C.getJointState()\n", 156 | " stepTo(S, q)\n", 157 | " plt.imsave(save_dir+str(t).zfill(3)+'.png', S.getScreenshot()[::-1])\n", 158 | " t += 1\n", 159 | "\n", 160 | " \n", 161 | "S.closeGripper(\"gripper\", speed=3., objFrameName=\"mug\")\n", 162 | "while not (S.getGripperIsGrasping(\"gripper\") or S.getGripperIsClose(\"gripper\")):\n", 163 | " stepTo(S, S.get_q())\n", 164 | " \n", 165 | "plt.imsave(save_dir+str(t).zfill(3)+'.png', S.getScreenshot()[::-1])\n", 166 | "t += 1\n", 167 | "for t in range(stepsPerPhase,2*stepsPerPhase):\n", 168 | " C.setFrameState(komo.getConfiguration(t))\n", 169 | " q = C.getJointState()\n", 170 | " stepTo(S, q, 0.2)\n", 171 | " plt.imsave(save_dir+str(t).zfill(3)+'.png', S.getScreenshot()[::-1])\n", 172 | " t += 1\n", 173 | " \n", 174 | "out1 = get_all_images(RealWorld, \n", 175 | " camera, \n", 176 | " camera_name_list, \n", 177 | " ['mug'], \n", 178 | " r=0.15, )\n", 179 | " \n", 180 | " \n", 181 | "for t in range(2*stepsPerPhase, 3*stepsPerPhase):\n", 182 | " C.setFrameState(komo.getConfiguration(t))\n", 183 | " q = C.getJointState()\n", 184 | " stepTo(S, q, 0.2)\n", 185 | " plt.imsave(save_dir+str(t).zfill(3)+'.png', S.getScreenshot()[::-1])\n", 186 | " t += 1\n", 187 | "\n", 188 | "out2 = get_all_images(RealWorld, \n", 189 | " camera, \n", 190 | " camera_name_list, \n", 191 | " ['mug'], \n", 192 | " r=0.15, )\n", 193 | "\n", 194 | "for i in range(10):\n", 195 | " for _ in range(10): S.step([], 0.01, ry.ControlMode.none)\n", 196 | " plt.imsave(save_dir+str(t).zfill(3)+'.png', S.getScreenshot()[::-1])\n", 197 | " t += 1" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 16, 203 | "id": "3c69358b-d7ba-4de0-a053-44bb36c94f03", 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "with h5py.File('target1.hdf5', mode='w') as f:\n", 208 | " f.create_dataset('rgb', data=out1[2])\n", 209 | " f.create_dataset('projection', data=out1[3])\n", 210 | " f.create_dataset('obj_pos', data=out1[4])\n", 211 | " f.create_dataset('obj_r', data=out1[5])\n", 212 | " \n", 213 | "with h5py.File('target2.hdf5', mode='w') as f:\n", 214 | " f.create_dataset('rgb', data=out2[2])\n", 215 | " f.create_dataset('projection', data=out2[3])\n", 216 | " f.create_dataset('obj_pos', data=out2[4])\n", 217 | " f.create_dataset('obj_r', data=out2[5])" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 26, 223 | "id": "1610b1af-9f3a-4c5e-b0ff-b3081179a654", 224 | "metadata": {}, 225 | "outputs": [ 226 | { 227 | "data": { 228 | "text/plain": [ 229 | "0" 230 | ] 231 | }, 232 | "execution_count": 26, 233 | 
"metadata": {}, 234 | "output_type": "execute_result" 235 | } 236 | ], 237 | "source": [ 238 | "os.system('ffmpeg -r 5 -pattern_type glob -i \\'screenshots/imitation_ref/*.png\\' -c:v libx264 imitation_ref.mp4')" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": null, 244 | "id": "a9019527-fe69-4ed9-a4fd-f6819a5efbbd", 245 | "metadata": {}, 246 | "outputs": [], 247 | "source": [] 248 | } 249 | ], 250 | "metadata": { 251 | "kernelspec": { 252 | "display_name": "torchSource", 253 | "language": "python", 254 | "name": "torchsource" 255 | }, 256 | "language_info": { 257 | "codemirror_mode": { 258 | "name": "ipython", 259 | "version": 3 260 | }, 261 | "file_extension": ".py", 262 | "mimetype": "text/x-python", 263 | "name": "python", 264 | "nbconvert_exporter": "python", 265 | "pygments_lexer": "ipython3", 266 | "version": "3.9.7" 267 | } 268 | }, 269 | "nbformat": 4, 270 | "nbformat_minor": 5 271 | } 272 | -------------------------------------------------------------------------------- /notebooks/generateHang.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "68d6fec3", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import trimesh\n", 11 | "\n", 12 | "import os\n", 13 | "from os import path\n", 14 | "import h5py\n", 15 | "\n", 16 | "from src.data_gen_utils import *\n", 17 | "\n", 18 | "import time\n", 19 | "import sys\n", 20 | "sys.path.append('../../../rai-fork/rai/ry')\n", 21 | "from scipy.spatial.transform import Rotation\n", 22 | "\n", 23 | "import libry as ry\n", 24 | "import multiprocessing as mp" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 6, 30 | "id": "89fa7202", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "def inCollision(C, mugFrames):\n", 35 | " for name in mugFrames:\n", 36 | " y = -C.evalFeature(ry.FS.pairCollision_negScalar, ['hook', name])[0]\n", 37 | " if y < 0:\n", 38 | " return True\n", 39 | " return False\n", 40 | "\n", 41 | "def distance(C, mugFrames):\n", 42 | " dist = np.inf\n", 43 | " for name in mugFrames:\n", 44 | " y = -C.evalFeature(ry.FS.pairCollision_negScalar, ['hook', name])[0]\n", 45 | " dist = min(dist, y)\n", 46 | " return dist\n", 47 | " \n", 48 | "\n", 49 | "def isFeasible(C, mugFrames, size):\n", 50 | " mug = C.frame('mug')\n", 51 | " mug_pos0 = mug.getPosition()\n", 52 | " \n", 53 | " ## up\n", 54 | " mug_pos = mug_pos0.copy()\n", 55 | " cum_dist = 0.\n", 56 | " while True:\n", 57 | " dist = distance(C, mugFrames)\n", 58 | " if dist < 0:\n", 59 | " mug.setPosition(mug_pos0)\n", 60 | " break\n", 61 | " mug_pos[2] += max(abs(dist), 1e-4)\n", 62 | " cum_dist += max(abs(dist), 1e-4)\n", 63 | " if cum_dist > 2*size:\n", 64 | " mug.setPosition(mug_pos0)\n", 65 | " return False\n", 66 | " mug.setPosition(mug_pos)\n", 67 | " \n", 68 | " ## down\n", 69 | " mug_pos = mug_pos0.copy()\n", 70 | " cum_dist = 0.\n", 71 | " while True:\n", 72 | " dist = distance(C, mugFrames)\n", 73 | " if dist < 0:\n", 74 | " mug.setPosition(mug_pos0)\n", 75 | " break\n", 76 | " mug_pos[2] -= max(abs(dist), 1e-4)\n", 77 | " cum_dist += max(abs(dist), 1e-4)\n", 78 | " if cum_dist > 2*size:\n", 79 | " mug.setPosition(mug_pos0)\n", 80 | " return False\n", 81 | " mug.setPosition(mug_pos)\n", 82 | " \n", 83 | " ## left\n", 84 | " mug_pos = mug_pos0.copy()\n", 85 | " cum_dist = 0.\n", 86 | " while True:\n", 87 | " dist = distance(C, mugFrames)\n", 88 | " if dist < 0:\n", 89 | " 
mug.setPosition(mug_pos0)\n", 90 | " break\n", 91 | " mug_pos[1] += max(abs(dist), 1e-4)\n", 92 | " cum_dist += max(abs(dist), 1e-4)\n", 93 | " if cum_dist > 2*size:\n", 94 | " mug.setPosition(mug_pos0)\n", 95 | " return False\n", 96 | " mug.setPosition(mug_pos)\n", 97 | " \n", 98 | " ## right\n", 99 | " mug_pos = mug_pos0.copy()\n", 100 | " cum_dist = 0.\n", 101 | " while True:\n", 102 | " dist = distance(C, mugFrames)\n", 103 | " if dist < 0:\n", 104 | " mug.setPosition(mug_pos0)\n", 105 | " break\n", 106 | " mug_pos[1] -= max(abs(dist), 1e-4)\n", 107 | " cum_dist += max(abs(dist), 1e-4)\n", 108 | " if cum_dist > 2*size:\n", 109 | " mug.setPosition(mug_pos0)\n", 110 | " return False\n", 111 | " mug.setPosition(mug_pos)\n", 112 | " \n", 113 | " return True\n", 114 | "\n", 115 | "def get_feasible_hang(N, mesh_filename, mesh_coll_filename, mass, com, size, view=False):\n", 116 | " feasible_poses = []\n", 117 | " object_trimesh = trimesh.load(mesh_filename)\n", 118 | " ray_intersector = trimesh.ray.ray_triangle.RayMeshIntersector(object_trimesh)\n", 119 | " \n", 120 | " C = ry.Config()\n", 121 | " \n", 122 | " hook = C.addFrame('hook').setShape(ry.ST.capsule, [.15, .002]).setColor([.4, .7, .4]).setPosition([0,0,0.8]).setRotationRad(np.pi/2, 0, 1, 0)\n", 123 | " hook_len = hook.info()['size'][0]\n", 124 | " hook_radii = hook.info()['size'][1]\n", 125 | " T_hook = hook.getAffineMatrix()\n", 126 | " \n", 127 | " C.addMeshFrame(mesh_coll_filename, 'mug', mass=mass, com=com).setPosition([0,0,0.5])\n", 128 | " mug = C.frame('mug') # mug's position can be changed because of com\n", 129 | " mugFrames = []\n", 130 | " for fname in C.getFrameNames():\n", 131 | " if fname[:3] == 'mug':\n", 132 | " C.frame(fname).setContact(1)\n", 133 | " mugFrames.append(fname)\n", 134 | " \n", 135 | " if view: \n", 136 | " V = ry.ConfigurationViewer()\n", 137 | " V.setConfiguration(C)\n", 138 | "\n", 139 | " pos_prev, vectorZ_prev = None, None\n", 140 | " while True:\n", 141 | " # sample a collision free point & ray\n", 142 | " while True:\n", 143 | " \n", 144 | " if pos_prev is None:\n", 145 | " pos = (np.random.rand(3)-0.5)*2*(size+hook_len/2)\n", 146 | " else:\n", 147 | " pos = pos_prev + np.random.randn(3)*0.01\n", 148 | " \n", 149 | " if trimesh.proximity.signed_distance(object_trimesh, pos.reshape(1,3))[0] < -hook_radii:\n", 150 | " if vectorZ_prev is None:\n", 151 | " vectorZ = np.random.randn(3)\n", 152 | " else:\n", 153 | " vectorZ = pos_prev + np.random.randn(3)*0.1\n", 154 | " \n", 155 | " vectorZ /= np.linalg.norm(vectorZ)\n", 156 | " intersect = ray_intersector.intersects_any(ray_origins = np.stack([pos, pos], axis=0),\n", 157 | " ray_directions = np.stack([vectorZ, -vectorZ], axis=0))\n", 158 | " if not intersect[0] and not intersect[1]:\n", 159 | " break\n", 160 | " \n", 161 | " vectorX = np.random.randn(3)\n", 162 | " vectorX -= np.dot(vectorX, vectorZ)*vectorZ\n", 163 | " vectorX /= np.linalg.norm(vectorX)\n", 164 | " rot = np.stack([vectorX, np.cross(vectorZ, vectorX), vectorZ], axis=1)\n", 165 | " T = np.eye(4)\n", 166 | " T[:3,:3] = rot\n", 167 | " T[:3,3] = pos-com\n", 168 | " \n", 169 | " \n", 170 | " T_mug = T_hook@np.linalg.inv(T)\n", 171 | " mug.setAffineMatrix(T_mug)\n", 172 | "# V.setConfiguration(C)\n", 173 | " if not inCollision(C, mugFrames) and isFeasible(C, mugFrames, size):\n", 174 | " quat = Rotation.from_matrix(rot).as_quat()[[3,0,1,2]]\n", 175 | " pose = np.hstack([pos, quat])\n", 176 | " feasible_poses.append(pose)\n", 177 | " pos_prev, vectorZ_prev = pos, vectorZ\n", 178 | " if 
view:\n", 179 | " V.setConfiguration(C)\n", 180 | "# input()\n", 181 | " \n", 182 | " if len(feasible_poses) == N:\n", 183 | " return np.array(feasible_poses)" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "id": "ed418bf3", 190 | "metadata": {}, 191 | "outputs": [ 192 | { 193 | "name": "stdout", 194 | "output_type": "stream", 195 | "text": [ 196 | "547b3284875a6ce535deb3b0c692a2a.hdf5\n" 197 | ] 198 | } 199 | ], 200 | "source": [ 201 | "load_dir = 'data/object'\n", 202 | "mesh_dir = 'data/meshes'\n", 203 | "save_dir = 'data/hang'\n", 204 | "\n", 205 | "def start_process():\n", 206 | " print('Starting', mp.current_process().name)\n", 207 | "\n", 208 | " \n", 209 | "def saveHangData(filename):\n", 210 | " print(filename)\n", 211 | " \n", 212 | " data_obj = h5py.File(path.join(load_dir, filename), mode='r')\n", 213 | " name = data_obj['filename'][()].decode()\n", 214 | " size = data_obj['size'][()]\n", 215 | " mass = data_obj['mass'][()]\n", 216 | " com = data_obj['com'][:]\n", 217 | " data_obj.close()\n", 218 | " \n", 219 | " mesh_filename = path.join(mesh_dir, name) \n", 220 | " mesh_coll_filename = path.join(mesh_dir+'_coll', name) \n", 221 | " \n", 222 | " feasible_poses = get_feasible_hang(1000, \n", 223 | " mesh_filename, \n", 224 | " mesh_coll_filename, \n", 225 | " mass, \n", 226 | " com, \n", 227 | " size,\n", 228 | " view=False)\n", 229 | " data = h5py.File(path.join(save_dir, filename), mode='w')\n", 230 | " data.create_dataset(\"pose\", data=feasible_poses, dtype=np.float32)\n", 231 | " data.close()\n", 232 | " \n", 233 | " \n", 234 | "filename_list = [fn for fn in os.listdir(load_dir) if fn.endswith('.hdf5')]\n", 235 | "# saveHangData(filename_list[0])\n", 236 | "pool = mp.Pool(processes=31,\n", 237 | " initializer=start_process)\n", 238 | "outputs = pool.map(saveHangData, filename_list)\n", 239 | "\n", 240 | "pool.close() # no more tasks\n", 241 | "pool.join() # wrap up current tasks" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": null, 247 | "id": "aef96046", 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": null, 255 | "id": "4b1777e8", 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "id": "9723a9d2", 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "id": "6ae4aa8a", 272 | "metadata": {}, 273 | "outputs": [], 274 | "source": [] 275 | } 276 | ], 277 | "metadata": { 278 | "kernelspec": { 279 | "display_name": "torchSource", 280 | "language": "python", 281 | "name": "torchsource" 282 | }, 283 | "language_info": { 284 | "codemirror_mode": { 285 | "name": "ipython", 286 | "version": 3 287 | }, 288 | "file_extension": ".py", 289 | "mimetype": "text/x-python", 290 | "name": "python", 291 | "nbconvert_exporter": "python", 292 | "pygments_lexer": "ipython3", 293 | "version": "3.9.7" 294 | } 295 | }, 296 | "nbformat": 4, 297 | "nbformat_minor": 5 298 | } 299 | -------------------------------------------------------------------------------- /src/data_gen_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import mesh_to_sdf 3 | 4 | import trimesh 5 | import pyrender 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | 9 | from scipy.spatial.transform import Rotation 10 | 11 | import sys 12 | 
sys.path.append('../../../rai-fork/rai/ry') 13 | import libry as ry 14 | 15 | 16 | def angle(v1, v2): 17 | a = np.dot(v1/np.linalg.norm(v1), v2/np.linalg.norm(v2)).clip(-1, 1) 18 | angle = np.arccos(a)*180/np.pi 19 | return angle 20 | 21 | def get_rotation_matrix(angle, axis): 22 | matrix = np.identity(4) 23 | matrix[:3, :3] = Rotation.from_euler(axis, angle).as_matrix() 24 | return matrix 25 | 26 | def get_equidistant_camera_angles(total_num_cam): 27 | indices = np.arange(0, total_num_cam, dtype=float) + 0.5 28 | theta = np.arccos(1 - 2*indices/total_num_cam) 29 | phi = (np.pi * (1 + 5**0.5) * indices) % (2 * np.pi) 30 | return phi, theta 31 | 32 | def get_camera_transform(cam_pos, origin): 33 | rel_pos = cam_pos - origin 34 | theta = np.arctan2(np.linalg.norm(rel_pos[:2]), rel_pos[2]) 35 | phi = np.pi/2 + np.arctan2(rel_pos[1], rel_pos[0]) 36 | distance = np.linalg.norm(rel_pos) 37 | T = get_camera_transform_looking_at_origin(phi, theta, distance) 38 | T[:3,3] += origin 39 | 40 | return T 41 | 42 | def get_camera_transform_looking_at_origin(phi, theta, camera_distance): 43 | camera_transform = np.identity(4) 44 | camera_transform[2, 3] = camera_distance 45 | camera_transform = np.matmul(get_rotation_matrix(theta, axis='x'), camera_transform) 46 | camera_transform = np.matmul(get_rotation_matrix(phi, axis='z'), camera_transform) 47 | return camera_transform 48 | 49 | def get_camera_projection_matrix(cam_distance, r=1.): 50 | fov = 2*np.arcsin(r/cam_distance) 51 | cam_projection_matrix = np.zeros((4,4)) 52 | cam_projection_matrix[0,0] = 1/np.tan(fov/2) 53 | cam_projection_matrix[1,1] = -1/np.tan(fov/2) 54 | cam_projection_matrix[2,2] = -1/r 55 | cam_projection_matrix[2,3] = -cam_distance/r 56 | # cam_projection_matrix[2,2] = -cam_distance/r 57 | # cam_projection_matrix[2,3] = -(cam_distance**2-r**2)/r 58 | cam_projection_matrix[3,2] = -1. 59 | return cam_projection_matrix 60 | 61 | # allows for varying center 62 | # TODO: should revert it back!! 
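# NOTE: get_camera_projection_matrix() above picks fov = 2*arcsin(r / cam_distance),
# so a sphere of radius r at the look-at point exactly fills the image, and depth is
# normalized by r; the render_mesh* functions below return these matrices (together
# with inverse camera transforms), which are presumably the per-view projections later
# consumed by FunctionalObject's pixel-aligned feature queries.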
63 | def render_mesh_rai2(mesh_filename, 64 | num_cam = 10, 65 | mu_cam_distance = 1.5, 66 | sig_cam_distance = .3, 67 | obj_radius = .1, 68 | mu_obj_center = np.zeros(3), 69 | sig_obj_center = 0.01, 70 | res = 100, 71 | view = False): 72 | 73 | C = ry.Config() 74 | camera = C.cameraView() 75 | mug = C.addMeshFrame(mesh_filename, 'mug') 76 | quat0 = mug.getQuaternion() 77 | 78 | if view: plt.figure(figsize=(20,20)) 79 | phi, theta = get_equidistant_camera_angles(num_cam) 80 | cam_trans_inv_list = [] 81 | cam_projection_list = [] 82 | rgb_list = [] 83 | for i, (phi, theta) in enumerate(zip(phi,theta)): 84 | # print(i, np.rad2deg(phi), np.rad2deg(theta)) 85 | cam_distance = mu_cam_distance + np.random.randn(1)*sig_cam_distance 86 | camera_transform = get_camera_transform_looking_at_origin(phi, theta, cam_distance) 87 | 88 | obj_center = mu_obj_center + np.random.randn(1)*sig_obj_center 89 | camera_transform[:3,3] += obj_center 90 | P = get_camera_projection_matrix(cam_distance, obj_radius) 91 | 92 | cam_name = 'camera'+str(i) 93 | C.addFrame(cam_name, 'mug').setShape(ry.ST.marker, [0.001])\ 94 | .setRelativeAffineMatrix(camera_transform) 95 | 96 | random_pos = np.array([0,0,.9]) + np.random.randn(3)*np.array([.5,.5,.1]) 97 | random_quat = np.random.randn(4); random_quat /= np.linalg.norm(random_quat) 98 | mug.setQuaternion(random_quat).setPosition(random_pos) # randomizes lighting 99 | 100 | camera.updateConfig(C) 101 | camera.addSensor(cam_name, cam_name, res, res, 0.5*P[0,0]) 102 | rgb, depth = camera.computeImageAndDepth() 103 | mask = camera.extractMask('mug') 104 | mug.setQuaternion(quat0) 105 | rgb *= np.expand_dims(mask,axis=2) 106 | 107 | rgb_list.append(rgb) 108 | cam_trans_inv_list.append(np.linalg.inv(camera_transform)) 109 | cam_projection_list.append(P) 110 | 111 | if view: 112 | ax = plt.subplot(int(np.ceil(num_cam/10)),10,i+1) 113 | ax.imshow(rgb) 114 | 115 | if view: 116 | plt.show() 117 | V = ry.ConfigurationViewer() 118 | V.setConfiguration(C) 119 | input() 120 | 121 | return np.stack(rgb_list), np.stack(cam_trans_inv_list), np.stack(cam_projection_list) 122 | 123 | 124 | def render_mesh_rai(mesh_filename, 125 | num_cam = 10, 126 | cam_distance_center = 15, 127 | mu_cam_distance = 1.5, 128 | sig_cam_distance = .3, 129 | obj_radius = 1., 130 | obj_center = np.zeros(3), 131 | res = 100, 132 | view = False): 133 | 134 | C = ry.Config() 135 | camera = C.cameraView() 136 | mug = C.addMeshFrame(mesh_filename, 'mug') 137 | quat0 = mug.getQuaternion() 138 | 139 | if view: plt.figure(figsize=(20,20)) 140 | phi, theta = get_equidistant_camera_angles(num_cam) 141 | cam_trans_inv_list = [] 142 | cam_projection_list = [] 143 | rgb_list = [] 144 | for i, (phi, theta) in enumerate(zip(phi,theta)): 145 | cam_distance = mu_cam_distance + np.random.randn(1)*sig_cam_distance 146 | camera_transform = get_camera_transform_looking_at_origin(phi, theta, cam_distance) 147 | camera_transform[:3,3] += obj_center 148 | P = get_camera_projection_matrix(cam_distance, obj_radius) 149 | 150 | cam_name = 'camera'+str(i) 151 | C.addFrame(cam_name, 'mug').setShape(ry.ST.marker, [0.0])\ 152 | .setRelativeAffineMatrix(camera_transform) 153 | random_pos = np.array([0,0,1.9]) + np.random.randn(3)*np.array([.5,.5,.1]) 154 | random_quat = np.random.randn(4); random_quat /= np.linalg.norm(random_quat) 155 | mug.setQuaternion(random_quat).setPosition(random_pos) # randomizes lighting 156 | 157 | camera.updateConfig(C) 158 | camera.addSensor(cam_name, cam_name, res, res, 0.5*P[0,0]) 159 | rgb, depth = 
camera.computeImageAndDepth() 160 | mask = camera.extractMask('mug') 161 | mug.setQuaternion(quat0) 162 | rgb *= np.expand_dims(mask,axis=2) 163 | 164 | rgb_list.append(rgb) 165 | cam_trans_inv_list.append(np.linalg.inv(camera_transform)) 166 | cam_projection_list.append(P) 167 | 168 | if view: 169 | ax = plt.subplot(int(np.ceil(num_cam/10)),10,i+1) 170 | ax.imshow(rgb) 171 | 172 | if view: 173 | plt.show() 174 | # V = ry.ConfigurationViewer() 175 | # V.setConfiguration(C) 176 | # input() 177 | 178 | return np.stack(rgb_list), np.stack(cam_trans_inv_list), np.stack(cam_projection_list) 179 | 180 | 181 | 182 | def render_mesh(mesh_filename, 183 | num_cam = 10, 184 | cam_distance_center = 15, 185 | sig_cam_distance = 3., 186 | obj_radius = 1., 187 | obj_center = np.zeros(3), 188 | res = 100, 189 | view = False): 190 | 191 | obj_trimesh = trimesh.load(mesh_filename) 192 | obj_mesh = pyrender.Mesh.from_trimesh(obj_trimesh) 193 | cam = pyrender.PerspectiveCamera(yfov=2*np.arcsin(obj_radius/cam_distance_center), aspectRatio=1.0) 194 | point_l = pyrender.PointLight(color=np.ones(3), intensity=3*cam_distance_center**2) 195 | 196 | scene = pyrender.Scene(bg_color=np.zeros(3), ambient_light=0.1*np.ones(3)) 197 | obj_node = scene.add(obj_mesh) 198 | 199 | point_l_node = scene.add(point_l) 200 | cam_node = scene.add(cam) 201 | 202 | r = pyrender.OffscreenRenderer(viewport_width=res, viewport_height=res) 203 | 204 | if view: plt.figure(figsize=(20,20)) 205 | phi, theta = get_equidistant_camera_angles(num_cam) 206 | cam_trans_inv_list = [] 207 | cam_projection_list = [] 208 | rgb_list = [] 209 | for i, (phi, theta) in enumerate(zip(phi,theta)): 210 | # print(i, np.rad2deg(phi), np.rad2deg(theta)) 211 | cam_distance = (cam_distance_center + np.random.randn(1)*sig_cam_distance)*obj_radius 212 | camera_transform = get_camera_transform_looking_at_origin(phi, theta, cam_distance) 213 | camera_transform[:3,3] += obj_center 214 | 215 | cam_node.camera.yfov = 2*np.arcsin(obj_radius/cam_distance) 216 | cam_node.matrix = camera_transform 217 | 218 | point_l_node.light.intensity = 3*cam_distance**2 219 | point_l_node.matrix = camera_transform 220 | rgb, depth = r.render(scene) 221 | 222 | rgb_list.append(rgb) 223 | cam_trans_inv_list.append(np.linalg.inv(camera_transform)) 224 | cam_projection_list.append(get_camera_projection_matrix(cam_distance, obj_radius)) 225 | 226 | if view: 227 | ax = plt.subplot(int(np.ceil(num_cam/10)),10,i+1) 228 | ax.imshow(rgb) 229 | 230 | r.delete() 231 | if view: plt.show() 232 | 233 | return np.stack(rgb_list), np.stack(cam_trans_inv_list), np.stack(cam_projection_list) 234 | 235 | 236 | def compute_sdf(obj_trimesh, N=5000, sig=0.01, scale=1., center=np.zeros(3), view=False): 237 | 238 | surface_points = obj_trimesh.sample(N) 239 | # unit_samples = sample_uniform_points_in_unit_sphere(N//5)*scale + center 240 | global_samples = np.random.randn(N//2,3)*scale + center.reshape(1,3) 241 | 242 | points = np.concatenate([ 243 | surface_points + np.random.randn(N,3)*sig*scale, 244 | surface_points + np.random.randn(N,3)*sig*scale*10, 245 | global_samples], axis = 0) 246 | 247 | sdf = trimesh.proximity.signed_distance(obj_trimesh, points) 248 | 249 | if view: 250 | colors = np.zeros(points.shape) 251 | colors[sdf < 0, 2] = 1 252 | colors[sdf > 0, 0] = 1 253 | scene = pyrender.Scene() 254 | scene.add(pyrender.Mesh.from_points(points, colors=colors)) 255 | pyrender.viewer.Viewer(scene, use_raymond_lighting=True, point_size=10.) 
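    # NOTE: the returned `points` combine near-surface samples at two noise scales
    # (sig and 10*sig, both multiplied by `scale`) with coarser global Gaussian
    # samples around `center`; `sdf` holds their signed distances as computed by
    # trimesh.proximity.signed_distance.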
256 | 257 | return points, sdf 258 | 259 | def compute_sdf2(obj_trimesh, N=5000, sig=0.01, scale=1., center=np.zeros(3), view=False): 260 | 261 | obj_trimesh.apply_scale(1/scale) 262 | point_cloud = mesh_to_sdf.get_surface_point_cloud(obj_trimesh) 263 | 264 | 265 | surface_points = point_cloud.get_random_surface_points(N) 266 | global_samples = np.random.randn(N//2,3) + center.reshape(1,3) 267 | 268 | points = np.concatenate([ 269 | surface_points + np.random.randn(N,3)*sig, 270 | surface_points + np.random.randn(N,3)*sig*10, 271 | global_samples 272 | ], axis = 0) 273 | 274 | sdf = point_cloud.get_sdf_in_batches(points, 275 | use_depth_buffer=True, 276 | sample_count=10000000) 277 | 278 | if view: 279 | colors = np.zeros(points.shape) 280 | colors[sdf < 0, 2] = 1 281 | colors[sdf > 0, 0] = 1 282 | scene = pyrender.Scene() 283 | scene.add(pyrender.Mesh.from_points(points, colors=colors)) 284 | pyrender.viewer.Viewer(scene, use_raymond_lighting=True, point_size=10.) 285 | 286 | return points*scale, sdf*scale -------------------------------------------------------------------------------- /notebooks/generateObject.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "6de3bc58", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import os\n", 11 | "from os import path\n", 12 | "import h5py\n", 13 | "\n", 14 | "import pybullet as p\n", 15 | "\n", 16 | "import trimesh\n", 17 | "import numpy as np\n", 18 | "import matplotlib.pyplot as plt\n", 19 | "import multiprocessing as mp" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 1, 25 | "id": "ae7f8f5c", 26 | "metadata": { 27 | "tags": [] 28 | }, 29 | "outputs": [ 30 | { 31 | "ename": "NameError", 32 | "evalue": "name 'os' is not defined", 33 | "output_type": "error", 34 | "traceback": [ 35 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 36 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 37 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mdata_dir\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'data/old/object'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mfilename_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlistdir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mstart_process\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Starting'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcurrent_process\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 38 | "\u001b[0;31mNameError\u001b[0m: name 'os' is not defined" 39 | ] 40 | } 41 | ], 42 | "source": [ 43 | "data_dir = 'data/old/object'\n", 44 | "filename_list = os.listdir(data_dir)\n", 45 | "def start_process():\n", 46 | " print('Starting', mp.current_process().name)\n", 47 | "\n", 48 | "def 
modifyObjectData(filename):\n", 49 | " print(filename)\n", 50 | " \n", 51 | " with h5py.File(path.join(data_dir, filename), mode='r') as data:\n", 52 | " mesh_filename = data['filename'][()].decode()\n", 53 | " old_size = data['size'][()]\n", 54 | " new_size = .1 + np.random.RandomState().rand()*.05 # 20 ~ 30 cm \n", 55 | " \n", 56 | " \n", 57 | " mug_trimesh = trimesh.load(path.join('data/old/meshes', mesh_filename))\n", 58 | " mug_trimesh.apply_scale(new_size/old_size)\n", 59 | " mug_trimesh.export(path.join('data/meshes', mesh_filename))\n", 60 | " \n", 61 | " \n", 62 | " p.vhacd(path.join('data/meshes', mesh_filename), \n", 63 | " path.join('data/meshes_coll', mesh_filename), \n", 64 | " \"log.txt\")\n", 65 | "\n", 66 | " mug_trimesh = trimesh.load(path.join('data/meshes_coll', mesh_filename))\n", 67 | " com = mug_trimesh.center_mass\n", 68 | " mass = mug_trimesh.mass\n", 69 | " \n", 70 | " with h5py.File(path.join('data/object', filename), mode='w') as data:\n", 71 | " data.create_dataset(\"filename\", data=mesh_filename)\n", 72 | " data.create_dataset(\"size\", data=new_size)\n", 73 | " data.create_dataset(\"com\", data=com)\n", 74 | " data.create_dataset(\"mass\", data=mass)\n", 75 | " \n", 76 | " return mass\n", 77 | "\n", 78 | "pool = mp.Pool(processes=7,\n", 79 | " initializer=start_process)\n", 80 | "mass_list = pool.map(modifyObjectData, filename_list)\n", 81 | "pool.close() # no more tasks\n", 82 | "pool.join() # wrap up current tasks\n", 83 | "\n", 84 | "plt.plot(mass_list)\n", 85 | "plt.show()" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 11, 91 | "id": "499b3ec8", 92 | "metadata": {}, 93 | "outputs": [ 94 | { 95 | "data": { 96 | "text/plain": [ 97 | "['547b3284875a6ce535deb3b0c692a2a.hdf5',\n", 98 | " 'e9499e4a9f632725d6e865157050a80e.hdf5',\n", 99 | " 'e79d807e1093c6174e716404e6ec3a5f.hdf5',\n", 100 | " 'fad118b32085f3f2c2c72e575af174cd.hdf5',\n", 101 | " 'bed29baf625ce9145b68309557f3a78c.hdf5',\n", 102 | " '1305b9266d38eb4d9f818dd0aa1a251.hdf5',\n", 103 | " '9c930a8a3411f069e7f67f334aa9295c.hdf5',\n", 104 | " '6500ccc65e210b14d829190312080ea3.hdf5',\n", 105 | " '214dbcace712e49de195a69ef7c885a4.hdf5',\n", 106 | " '5c7c4cb503a757147dbda56eabff0c47.hdf5',\n", 107 | " '6e884701bfddd1f71e1138649f4c219.hdf5',\n", 108 | " '79e673336e836d1333becb3a9550cbb1.hdf5',\n", 109 | " '1d18255a04d22794e521eeb8bb14c5b3.hdf5',\n", 110 | " 'c6bc2c9770a59b5ddd195661813efe58.hdf5',\n", 111 | " '599e604a8265cc0a98765d8aa3638e70.hdf5',\n", 112 | " 'f1e439307b834015770a0ff1161fa15a.hdf5',\n", 113 | " 'ca198dc3f7dc0cacec6338171298c66b.hdf5',\n", 114 | " '2d10421716b16580e45ef4135c266a12.hdf5',\n", 115 | " 'b88bcf33f25c6cb15b4f129f868dedb.hdf5',\n", 116 | " '6faf1f04bde838e477f883dde7397db2.hdf5',\n", 117 | " '73b8b6456221f4ea20d3c05c08e26f.hdf5',\n", 118 | " '9961ccbafd59cb03fe36eb2ab5fa00e0.hdf5',\n", 119 | " '40f9a6cc6b2c3b3a78060a3a3a55e18f.hdf5',\n", 120 | " 'b18bf84bcd457ffbc2868ebdda32b945.hdf5',\n", 121 | " '62634df2ad8f19b87d1b7935311a2ed0.hdf5',\n", 122 | " '46955fddcc83a50f79b586547e543694.hdf5',\n", 123 | " '5c48d471200d2bf16e8a121e6886e18d.hdf5',\n", 124 | " '6aec84952a5ffcf33f60d03e1cb068dc.hdf5',\n", 125 | " '5fe74baba21bba7ca4eec1b19b3a18f8.hdf5',\n", 126 | " '8570d9a8d24cb0acbebd3c0c0c70fb03.hdf5',\n", 127 | " 'dfa8a3a0c8a552b62bc8a44b22fcb3b9.hdf5',\n", 128 | " '5d72df6bc7e93e6dd0cd466c08863ebd.hdf5',\n", 129 | " '91f90c3a50410c0dc27effd6fd6b7eb0.hdf5',\n", 130 | " 'f09e51579600cfbb88b651d2e4ea0846.hdf5',\n", 131 | " 
'15bd6225c209a8e3654b0ce7754570c8.hdf5',\n", 132 | " '7d282cc3cedd40c8b5c4f4801d3aada.hdf5',\n", 133 | " 'a8f7a0edd3edc3299e54b4084dc33544.hdf5',\n", 134 | " 'c82b9f1b98f044fc15cf6e5ad80f2da.hdf5',\n", 135 | " 'c34718bd10e378186c6c61abcbd83e5a.hdf5',\n", 136 | " '414772162ef70ec29109ad7f9c200d62.hdf5',\n", 137 | " 'd75af64aa166c24eacbe2257d0988c9c.hdf5',\n", 138 | " '17952a204c0a9f526c69dceb67157a66.hdf5',\n", 139 | " 'e9bd4ee553eb35c1d5ccc40b510e4bd.hdf5',\n", 140 | " '3143a4accdc23349cac584186c95ce9b.hdf5',\n", 141 | " 'c0c130c04edabc657c2b66248f91b3d8.hdf5',\n", 142 | " '6a9b31e1298ca1109c515ccf0f61e75f.hdf5',\n", 143 | " 'a0c78f254b037f88933dc172307a6bb9.hdf5',\n", 144 | " 'ba10400c108e5c3f54e1b6f41fdd78a.hdf5',\n", 145 | " 'd0a3fdd33c7e1eb040bc4e38b9ba163e.hdf5',\n", 146 | " '7a8ea24474846c5c2f23d8349a133d2b.hdf5',\n", 147 | " '8f6c86feaa74698d5c91ee20ade72edc.hdf5',\n", 148 | " '639a1f7d09d23ea37d70172a29ade99a.hdf5',\n", 149 | " '8b1dca1414ba88cb91986c63a4d7a99a.hdf5',\n", 150 | " '99eaa69cf6fe8811dec712af445786fe.hdf5',\n", 151 | " 'f99e19b8c4a729353deb88581ea8417a.hdf5',\n", 152 | " 'ff1a44e1c1785d618bca309f2c51966a.hdf5',\n", 153 | " '9d8c711750a73b06ad1d789f3b2120d0.hdf5',\n", 154 | " '586e67c53f181dc22adf8abaa25e0215.hdf5',\n", 155 | " '9af98540f45411467246665d3d3724c.hdf5',\n", 156 | " '4b7888feea81219ab5f4a9188bfa0ef6.hdf5',\n", 157 | " 'c2eacc521dd65bf7a1c742bb4ffef210.hdf5',\n", 158 | " 'bea77759a3e5f9037ae0031c221d81a4.hdf5',\n", 159 | " '57f73714cbc425e44ae022a8f6e258a7.hdf5',\n", 160 | " 'b7e705de46ebdcc14af54ba5738cb1c5.hdf5',\n", 161 | " 'c60f62684618cb52a4136492f17b9a59.hdf5',\n", 162 | " '2852b888abae54b0e3523e99fd841f4.hdf5',\n", 163 | " '128ecbc10df5b05d96eaf1340564a4de.hdf5',\n", 164 | " '387b695db51190d3be276203d0b1a33f.hdf5',\n", 165 | " 'edaf960fb6afdadc4cebc4b5998de5d0.hdf5',\n", 166 | " 'ec846432f3ebedf0a6f32a8797e3b9e9.hdf5',\n", 167 | " '9278005254c8db7e95f577622f465c85.hdf5',\n", 168 | " '28f1e7bc572a633cb9946438ed40eeb9.hdf5',\n", 169 | " '52273f4b17505e898ef19a48ac4fcfdf.hdf5',\n", 170 | " '345d3e7252156db8d44ee24d6b5498e1.hdf5',\n", 171 | " '10f6e09036350e92b3f21f1137c3c347.hdf5',\n", 172 | " '85a2511c375b5b32f72755048bac3f96.hdf5',\n", 173 | " '8012f52dd0a4d2f718a93a45bf780820.hdf5',\n", 174 | " 'b811555ccf5ef6c4948fa2daa427fe1f.hdf5',\n", 175 | " '3d1754b7cb46c0ce5c8081810641ef6.hdf5',\n", 176 | " '71995893d717598c9de7b195ccfa970.hdf5',\n", 177 | " '67b9abb424cf22a22d7082a28b056a5.hdf5',\n", 178 | " '92d6394732e6058d4bcbafcc905a9b98.hdf5',\n", 179 | " '4d9764afa3fbeb1b6c69dceb67157a66.hdf5',\n", 180 | " 'a6d9f9ae39728831808951ff5fb582ac.hdf5',\n", 181 | " 'b46e89995f4f9cc5161e440f04bd2a2.hdf5',\n", 182 | " '6dd59cc1130a426571215a0b56898e5e.hdf5',\n", 183 | " '187859d3c3a2fd23f54e1b6f41fdd78a.hdf5',\n", 184 | " 'dcec634f18e12427c2c72e575af174cd.hdf5',\n", 185 | " 'd38295b8d83e8cdec712af445786fe.hdf5',\n", 186 | " '1c9f9e25c654cbca3c71bf3f4dd78475.hdf5',\n", 187 | " '962883677a586bd84a60c1a189046dd1.hdf5',\n", 188 | " '7d6baadd51d0703455da767dfc5b748e.hdf5',\n", 189 | " '2037531c43448c3016329cbc378d2a2.hdf5',\n", 190 | " '4b8b10d03552e0891898dfa8eb8eefff.hdf5',\n", 191 | " '1a97f3c83016abca21d0de04f408950f.hdf5',\n", 192 | " '54f2d6a0b431839c99785666a0fe0255.hdf5',\n", 193 | " '896f1d494bac0ebcdec712af445786fe.hdf5',\n", 194 | " '34869e23f9fdee027528ae0782b54aae.hdf5',\n", 195 | " '44f9c4e1ea3532b8d7b20fded0142d7a.hdf5',\n", 196 | " '8b780e521c906eaf95a4f7ae0be930ac.hdf5',\n", 197 | " '403fb4eb4fc6235adf0c7dbe7f8f4c8e.hdf5',\n", 198 | " 
'b4ae56d6638d5338de671f28c83d2dcb.hdf5',\n", 199 | " '6c379385bf0a23ffdec712af445786fe.hdf5',\n", 200 | " '62684ad0321b35189a3501eead248b52.hdf5',\n", 201 | " 'e94e46bc5833f2f5e57b873e4f3ef3a4.hdf5',\n", 202 | " 'daee5cf285b8d210eeb8d422649e5f2b.hdf5',\n", 203 | " 'a637500654ca8d16c97cfc3e8a6b1d16.hdf5',\n", 204 | " 'd7ba704184d424dfd56d9106430c3fe.hdf5',\n", 205 | " 'e984fd7e97c2be347eaeab1f0c9120b7.hdf5',\n", 206 | " 'b6f30c63c946c286cf6897d8875cfd5e.hdf5',\n", 207 | " '336122c3105440d193e42e2720468bf0.hdf5',\n", 208 | " '633379db14d4d2b287dd60af81c93a3c.hdf5',\n", 209 | " 'cf777e14ca2c7a19b4aad3cc5ce7ee8.hdf5',\n", 210 | " '83827973c79ca7631c9ec1e03e401f54.hdf5',\n", 211 | " '46ed9dad0440c043d33646b0990bb4a.hdf5',\n", 212 | " '48e260a614c0fd4434a8988fdcee4fde.hdf5',\n", 213 | " '34ae0b61b0d8aaf2d7b20fded0142d7a.hdf5',\n", 214 | " 'b9f9f5b48ab1153626829c11d9aba173.hdf5',\n", 215 | " '9737c77d3263062b8ca7a0a01bcd55b6.hdf5',\n", 216 | " 'f1c5b9bb744afd96d6e1954365b10b52.hdf5',\n", 217 | " '61c10dccfa8e508e2d66cbf6a91063.hdf5',\n", 218 | " '1eaf8db2dd2b710c7d5b1b70ae595e60.hdf5',\n", 219 | " '43e1cabc5dd2fa91fffc97a61124b1a9.hdf5',\n", 220 | " 'd46b98f63a017578ea456f4bbbc96af9.hdf5',\n", 221 | " '928a383f79698c3fb6d9bc28c8d8a2c4.hdf5',\n", 222 | " 'ea127b5b9ba0696967699ff4ba91a25.hdf5',\n", 223 | " '39361b14ba19303ee42cfae782879837.hdf5',\n", 224 | " 'f7d776fd68b126f23b67070c4a034f08.hdf5',\n", 225 | " 'e6dedae946ff5265a95fb60c110b25aa.hdf5',\n", 226 | " 'c51b79493419eccdc1584fff35347dc6.hdf5',\n", 227 | " '37f56901a07da69dac6b8e58caf61f95.hdf5']" 228 | ] 229 | }, 230 | "execution_count": 11, 231 | "metadata": {}, 232 | "output_type": "execute_result" 233 | } 234 | ], 235 | "source": [ 236 | "filename_list" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": null, 242 | "id": "8878a39b", 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [] 246 | } 247 | ], 248 | "metadata": { 249 | "kernelspec": { 250 | "display_name": "Python 3", 251 | "language": "python", 252 | "name": "python3" 253 | }, 254 | "language_info": { 255 | "codemirror_mode": { 256 | "name": "ipython", 257 | "version": 3 258 | }, 259 | "file_extension": ".py", 260 | "mimetype": "text/x-python", 261 | "name": "python", 262 | "nbconvert_exporter": "python", 263 | "pygments_lexer": "ipython3", 264 | "version": "3.9.7" 265 | } 266 | }, 267 | "nbformat": 4, 268 | "nbformat_minor": 5 269 | } 270 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import h5py 4 | import os 5 | from os import path 6 | 7 | import numpy as np 8 | from torchvision import transforms 9 | 10 | from src.utils import * 11 | import matplotlib.pyplot as plt 12 | 13 | import scipy 14 | 15 | class UnNormalize: 16 | def __init__(self, mean, std): 17 | self.mean = mean 18 | self.std = std 19 | 20 | def __call__(self, tensor): 21 | """ 22 | Args: 23 | tensor (Tensor): Tensor image of size (C, H, W) or (B, C, H, W) to be normalized. 24 | Returns: 25 | Tensor: un-normalized image. 
26 | """ 27 | dtype = tensor.dtype 28 | mean = torch.as_tensor(self.mean, dtype=dtype, device=tensor.device) 29 | std = torch.as_tensor(self.std, dtype=dtype, device=tensor.device) 30 | if (std == 0).any(): 31 | raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype)) 32 | if mean.ndim == 1: 33 | mean = mean.view(-1, 1, 1) 34 | if std.ndim == 1: 35 | std = std.view(-1, 1, 1) 36 | tensor.mul_(std).add_(mean) 37 | # The normalize code -> t.sub_(m).div_(s) 38 | return tensor 39 | 40 | 41 | class RandomEraser: 42 | """Erase part of images""" 43 | def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3)): 44 | self.eraser = transforms.RandomErasing(p=p, 45 | scale=scale, 46 | ratio=ratio, 47 | value=0, 48 | inplace=True) 49 | def __call__(self, rgb): 50 | num_views = rgb.shape[0] 51 | for i in range(num_views): 52 | rgb[i] = self.eraser(rgb[i]) 53 | 54 | return rgb 55 | 56 | class PIFODataset(Dataset): 57 | def __init__(self, 58 | filename, 59 | num_views=2, 60 | num_points=300, 61 | num_grasps=100, 62 | num_hangs=100, 63 | grasp_draw_points=torch.eye(3), 64 | hang_draw_points=torch.eye(3), 65 | random_erase=True, 66 | on_gpu_memory=False): 67 | 68 | data_hdf5 = h5py.File(filename, mode='r') 69 | self.filename = data_hdf5['object/mesh_filename'][:] 70 | self.mass = data_hdf5['object/mass'][:] 71 | self.com = data_hdf5['object/com'][:] 72 | 73 | self.rgb = torch.from_numpy(data_hdf5['camera/rgb'][:]).permute(0,1,4,2,3).contiguous().to(torch.float32)/255. 74 | self.cam_extrinsic = torch.from_numpy(data_hdf5['camera/cam_extrinsic'][:]) 75 | self.cam_intrinsic = torch.from_numpy(data_hdf5['camera/cam_intrinsic'][:]) 76 | 77 | self.point = torch.from_numpy(data_hdf5['sdf/point'][:]) 78 | self.sdf = torch.from_numpy(data_hdf5['sdf/sdf'][:]) 79 | 80 | self.grasp_pose = torch.from_numpy(data_hdf5['grasp/pose'][:]) # (len, N, 7) 81 | self.hang_pose = torch.from_numpy(data_hdf5['hang/pose'][:]) # (len, N, 7) 82 | 83 | data_hdf5.close() 84 | 85 | self.grasp_draw_points = grasp_draw_points 86 | self.hang_draw_points = hang_draw_points 87 | 88 | self.device = "cpu" 89 | if on_gpu_memory: 90 | self.to_device("cuda") 91 | 92 | self.total_views = self.rgb.shape[1] 93 | self.total_points = self.point.shape[1] 94 | self.total_grasps = self.grasp_pose.shape[1] 95 | self.total_hangs = self.hang_pose.shape[1] 96 | 97 | self.num_views = num_views 98 | self.num_points = num_points 99 | self.num_grasps = num_grasps 100 | self.num_hangs = num_hangs 101 | 102 | self.random_erase = random_erase 103 | self.random_eraser = RandomEraser(p=0.5, 104 | scale=(0.02, 0.9), 105 | ratio=(0.3, 3.3)) 106 | 107 | 108 | def __len__(self): 109 | return self.filename.shape[0] 110 | 111 | def __getitem__(self, idx): 112 | cam_inds = torch.randperm(self.total_views)[:self.num_views] 113 | rgb = self.rgb[idx, cam_inds] 114 | if self.random_erase: rgb = self.random_eraser(rgb) 115 | 116 | point_inds = torch.randperm(self.total_points)[:self.num_points] 117 | grasp_inds = torch.randperm(self.total_grasps)[:self.num_grasps] 118 | hang_inds = torch.randperm(self.total_hangs)[:self.num_hangs] 119 | 120 | sample = { 121 | 'rgb': rgb, 122 | 'cam_extrinsic': self.cam_extrinsic[idx, cam_inds], 123 | 'cam_intrinsic': self.cam_intrinsic[idx, cam_inds], 124 | 125 | 'points': self.point[idx, point_inds], 126 | 'sdf': self.sdf[idx, point_inds], 127 | 128 | 'grasp_poses': self.grasp_pose[idx, grasp_inds], 129 | 'grasp_poses_all': self.grasp_pose[idx], 130 | 'hang_poses': self.hang_pose[idx, 
hang_inds], 131 | 'hang_poses_all': self.hang_pose[idx], 132 | 133 | 'filenames': self.filename[idx], 134 | 'masses': self.mass[idx], 135 | 'coms': self.com[idx], 136 | } 137 | return sample 138 | 139 | def show_data(self, idx, image_only=False): 140 | data = self.__getitem__(idx) 141 | 142 | imgs = data['rgb'].cpu().permute(0,2,3,1) 143 | fig = plt.figure(figsize=(10,10)) 144 | for i in range(self.num_views): 145 | ax = plt.subplot(2,self.num_views,i+1) 146 | ax.imshow(imgs[i]) 147 | ax.grid() 148 | 149 | if not image_only: 150 | points = data['points'].unsqueeze(0).repeat(self.num_views,1,1) 151 | projections = data['cam_intrinsic'].bmm(torch.inverse(data['cam_extrinsic'])) 152 | uvAll, z = perspective(points, projections) 153 | sd = data['sdf'].cpu() 154 | for i in range(self.num_views): 155 | ax = plt.subplot(2,self.num_views,self.num_views+i+1) 156 | uv = uvAll.cpu()[i,:,:] 157 | pc = ax.scatter(uv[:,0], -uv[:,1], c = sd, s=30.) 158 | # pc.set_clim([-0.1, 0.1]) 159 | ax.axis('square') 160 | ax.axis([-1,1,-1,1]) 161 | ax.grid() 162 | plt.colorbar(pc) 163 | 164 | fig = plt.figure() 165 | test_color = index(data['rgb'], uvAll).cpu() 166 | print(test_color.shape, uvAll.shape) 167 | for i in range(self.num_views): 168 | ax = plt.subplot(1,self.num_views,i+1) 169 | uv = uvAll.cpu()[i,:,:] 170 | ax.scatter(uv[:,0], -uv[:,1], c = test_color[i,:], s=30.) 171 | ax.axis('square') 172 | ax.axis([-1,1,-1,1]) 173 | ax.grid() 174 | 175 | grasp_poses = data['grasp_poses'] 176 | hang_poses = data['hang_poses'] 177 | 178 | for key_points, poses in zip([self.grasp_draw_points, self.hang_draw_points], 179 | [grasp_poses, hang_poses]): 180 | 181 | num_points = key_points.shape[0] 182 | poses_repeat = poses.unsqueeze(1).repeat(1, num_points, 1) # (N, num_points, 7) 183 | 184 | points = quaternion_apply(poses_repeat[..., 3:], key_points.unsqueeze(0)) # (N, num_points, 3) 185 | points += poses_repeat[..., :3] #(N, 4, 3) 186 | points = points.view(1,-1,3).repeat(self.num_views,1,1) 187 | 188 | uv, z = perspective(points, projections) 189 | uv = uv.view(self.num_views, -1, num_points, 2).cpu() 190 | 191 | fig = plt.figure(figsize=(10,5)) 192 | uv0 = uv.mean(dim=2) # (num_view, N, 2) 193 | for i in range(self.num_views): 194 | ax = plt.subplot(1,self.num_views,i+1) 195 | for j in range(num_points): 196 | tmp = torch.stack([uv0[i], uv[i,:,j]], dim=0) #(2, N, 2) 197 | ax.plot(tmp[...,0], -tmp[...,1]) 198 | ax.axis('square') 199 | ax.axis([-1,1,-1,1]) 200 | ax.grid() 201 | 202 | plt.show() 203 | 204 | def to_device(self, device): 205 | self.device = device 206 | 207 | self.rgb = self.rgb.to(device) 208 | self.cam_extrinsic = self.cam_extrinsic.to(device) 209 | self.cam_intrinsic = self.cam_intrinsic.to(device) 210 | 211 | self.point = self.point.to(device) 212 | self.sdf = self.sdf.to(device) 213 | 214 | self.grasp_pose = self.grasp_pose.to(device) 215 | self.hang_pose = self.hang_pose.to(device) 216 | 217 | self.grasp_draw_points = self.grasp_draw_points.to(device) 218 | self.hang_draw_points = self.hang_draw_points.to(device) 219 | 220 | 221 | class RandomImageWarper: 222 | def __init__(self, img_res=None, sig_center=0.01, obj_r=0.15, return_cam_params=False): 223 | self.img_res = img_res 224 | self.sig_center = sig_center 225 | self.obj_r = obj_r 226 | self.return_cam_params = return_cam_params 227 | 228 | def __call__(self, rgb, T1, K1): 229 | return batched_random_warping(rgb, T1, K1, 230 | self.img_res, 231 | self.sig_center, 232 | self.obj_r, 233 | self.return_cam_params) 234 | 235 | class PoseSampler: 236 
| def __init__(self, scale): 237 | self.scale = scale 238 | 239 | def __call__(self, poses, poses_all): 240 | return get_pose_and_cost(poses, poses_all, self.scale) 241 | 242 | def batched_random_warping(rgb, T1, K1, img_res=None, sig_center=0.01, obj_r=0.1, return_cam_params=False): 243 | """ 244 | Warp images with random homography and compute corresponding transform matrix 245 | Args 246 | rgb: (B, num_views, 3, H_in, W_in) 247 | T1: (B, num_views, 4, 4) camera extrinsic 248 | K1: (B, num_views, 4, 4) camera intrinsic 249 | 250 | Return 251 | rgb_warped: (B, num_views, 3, H_out, W_out) 252 | projection: (B, num_views, 4, 4) camera projection matrix (= K1@T2_inv) 253 | """ 254 | 255 | B, num_views, _, H_in, W_in = rgb.shape 256 | if img_res is None: 257 | H_out, W_out = H_in, W_in 258 | else: 259 | H_out, W_out = img_res 260 | 261 | device = rgb.device 262 | 263 | new_origin = sig_center*torch.rand(B,1,3).to(device) 264 | rel_pos = T1[..., :3, 3] - new_origin # (B, num_views, 3) 265 | cam_distance = rel_pos.norm(dim=2) # (B, num_views) 266 | 267 | # compute a new intrinsic 268 | fov = 2*torch.asin(obj_r/cam_distance) # (B, num_views) 269 | K2 = torch.zeros_like(K1) # (B, num_views, 4, 4) 270 | K2[...,0,0] = 1/torch.tan(fov/2) 271 | K2[...,1,1] = -1/torch.tan(fov/2) 272 | # K2[...,2,2] = -cam_distance/obj_r 273 | # K2[...,2,3] = -(cam_distance**2-obj_r**2)/obj_r 274 | K2[...,2,2] = -1/obj_r 275 | K2[...,2,3] = -cam_distance/obj_r 276 | K2[...,3,2] = -1. 277 | 278 | # extrinsic 279 | theta = torch.atan2(rel_pos[...,:2].norm(dim=2), rel_pos[...,2]) # (B, num_views) 280 | phi = np.pi/2 + torch.atan2(rel_pos[...,1], rel_pos[...,0]) # (B, num_views) 281 | T2 = torch.eye(4).to(device).repeat(B,num_views,1,1) # (B, num_views, 4, 4) 282 | T2[...,2,3] = cam_distance 283 | T2 = batch_rotation_matrix(theta, axis='X').matmul(T2) 284 | T2 = batch_rotation_matrix(phi, axis='Z').matmul(T2) 285 | T2[...,:3,3] += new_origin 286 | 287 | # camera roll 288 | random_roll = torch.rand(B,num_views).to(device)*np.pi*2. 
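    # What follows applies the sampled roll to the new extrinsic T2 and then builds a 3x3
    # map `Hinv` from pixel coordinates of the new virtual camera (T2, K2) to pixel
    # coordinates of the original view (T1, K1): the 4x4 intrinsics are reduced to 3x3 via
    # `idx` (dropping the depth row/column) and combined with the relative rotation R_1_2;
    # grid_sample then pulls colors from the original images at the mapped locations to
    # produce the warped views.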
289 | rot_roll = axis_angle_to_matrix('Z', random_roll) 290 | T2[...,:3,:3] = T2[...,:3,:3].matmul(rot_roll) 291 | 292 | # Homography 293 | idx = np.ix_(np.arange(B), np.arange(num_views), [0,1,3], [0,1,2]) 294 | R_1_2 = T1[...,:3,:3].transpose(-1,-2).matmul(T2[...,:3,:3]) 295 | Hinv = K1[idx].matmul(R_1_2).matmul(torch.inverse(K2[idx])) # (B, num_views, 3, 3) 296 | 297 | # Warp 298 | x = torch.linspace(-1, 1, H_out).to(device) 299 | y = torch.linspace(-1, 1, W_out).to(device) 300 | grid_v, grid_u = torch.meshgrid(x, y) 301 | base_grid = torch.stack([grid_u, grid_v, torch.ones_like(grid_u)], dim=2) # (H, W, 3) 302 | grid = base_grid.view(1,1,H_out*W_out,3).matmul(Hinv.transpose(-1,-2)) # (B, num_views, H*W, 3) 303 | grid = grid[...,:2]/grid[...,2:3] 304 | 305 | rgb_warped = torch.nn.functional.grid_sample(rgb.view(B*num_views, 3, H_in, W_in), 306 | grid.view(B*num_views, H_out, W_out, 2), 307 | mode='bilinear', 308 | padding_mode='zeros', 309 | align_corners=True) # (B*num_views, 3, H, W) 310 | 311 | if return_cam_params: 312 | cam_pos = T1[..., :3, 3] 313 | return rgb_warped.view(B, num_views, 3, H_out, W_out), K2.matmul(torch.inverse(T2)), cam_pos, new_origin, random_roll.view(B,num_views,1) 314 | 315 | else: 316 | return rgb_warped.view(B, num_views, 3, H_out, W_out), K2.matmul(torch.inverse(T2)) 317 | 318 | def batch_rotation_matrix(angle, axis): 319 | B, num_views = angle.shape 320 | T = torch.eye(4).repeat(B,num_views,1,1).to(angle.device) 321 | T[...,:3,:3] = axis_angle_to_matrix(axis, angle) 322 | return T 323 | 324 | def compute_cost(perturbed_poses, poses_all, scale=None): 325 | perturbed_poses = perturbed_poses.unsqueeze(2) # (B, num_poses, 1, 7) 326 | feasible_poses = poses_all.unsqueeze(1) # (B, 1, total_poses, 7) 327 | 328 | pos_diff = perturbed_poses[...,:3] - feasible_poses[..., :3]# (B, num_poses, total_poses, 3) 329 | quat_diff = quaternion_multiply(quaternion_invert(feasible_poses[..., 3:]), perturbed_poses[...,3:]) 330 | rotvec_diff = quaternion_to_axis_angle(quat_diff)# (B, num_poses, total_poses, 3) 331 | total_diff = torch.cat([pos_diff, rotvec_diff], dim=3) # (B, num_poses, total_poses, 6) 332 | 333 | if scale is not None: 334 | total_diff *= scale.to(poses_all.device) 335 | 336 | costs = total_diff.norm(dim=3).min(dim=2)[0] # (B, num_poses) 337 | 338 | return costs 339 | 340 | 341 | def get_pose_and_cost(poses, poses_all, scale=None): 342 | """ 343 | Args: 344 | poses: (B, num_poses, 7) 345 | poses_all: (B, total_poses, 7) 346 | 347 | Return: 348 | poses (B, num_poses, 7) 349 | cost (B, num_poses) 350 | """ 351 | 352 | B, num_poses, _ = poses.shape 353 | device = poses.device 354 | 355 | noise_pos = .2*torch.randn(B, num_poses, 3, device=device) 356 | noise_quat = random_quaternions(B*num_poses, device=device).view(B, num_poses, 4) 357 | t = torch.rand(B, num_poses, 1, device=device) 358 | 359 | perturbed_poses = torch.zeros_like(poses) #(B, num_poses, 7) 360 | perturbed_poses[...,:3] = (1-t)*poses[..., :3] + t*noise_pos 361 | perturbed_poses[...,3:] = quaternion_slerp(poses[...,3:], noise_quat, t) 362 | 363 | costs = compute_cost(perturbed_poses, poses_all, scale) 364 | 365 | return perturbed_poses, costs 366 | 367 | 368 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Optional 4 | import trimesh 5 | from scipy.spatial.transform import Rotation 6 | 7 | 8 | 9 | 10 | def 
view_scene_hang(pose7d, filename): 11 | pose7d = pose7d.detach().cpu().view(-1, 7).numpy() 12 | 13 | mug_meshes = [] 14 | hook_meshes = [] 15 | 16 | sn = int(np.sqrt(pose7d.shape[0])) 17 | for n in range(pose7d.shape[0]): 18 | 19 | T = np.eye(4) 20 | T[:3,:3] = Rotation.from_quat(pose7d[n,[4,5,6,3]]).as_matrix() 21 | T[:3,3] = pose7d[n,:3] 22 | Tinv = np.linalg.inv(T) 23 | mug_mesh = trimesh.load(filename).apply_transform(Tinv) 24 | 25 | T2 = np.eye(4) 26 | T2[0,3] = .4*(n//sn - 0.5*(sn-1)) 27 | T2[1,3] = .4*(n%sn - 0.5*(sn-1)) 28 | mug_meshes.append(mug_mesh.apply_transform(T2)) 29 | hook_meshes.append(create_hook().apply_transform(T2)) 30 | 31 | axis_mesh = trimesh.creation.axis(origin_size=0.004, 32 | axis_radius=0.003, 33 | axis_length=0.03) 34 | return trimesh.Scene(mug_meshes + hook_meshes + [axis_mesh]) 35 | 36 | def create_hook(color=[100, 200, 100], cylinder_radius=0.002): 37 | pillar = trimesh.creation.cylinder( 38 | radius=cylinder_radius*2, 39 | segment=[ 40 | [0, 0.1, -0.15], 41 | [0, -0.2, -0.15], 42 | ], 43 | ) 44 | hook = trimesh.creation.cylinder(radius=cylinder_radius, 45 | height=.15*2) 46 | tmp = trimesh.util.concatenate([pillar, hook]) 47 | tmp.visual.vertex_colors = color 48 | return tmp 49 | 50 | 51 | def to_device(data, device): 52 | for key in data: 53 | if isinstance(data[key], torch.Tensor): 54 | data[key] = data[key].to(device) 55 | return data 56 | 57 | def view_scene_hang_batch(pose7d_batch, feasibility_batch, filename_list): 58 | """ 59 | Args: 60 | pose7d_batch: (B, N, 7) 61 | feasibility_batch: (B, N) 62 | filename_list: (B,) 63 | """ 64 | 65 | B, N, _ = pose7d_batch.shape 66 | mug_meshes = [] 67 | hook_meshes = [] 68 | 69 | sn = int(np.sqrt(len(filename_list))) 70 | for b, filename in enumerate(filename_list): 71 | 72 | mesh_filename = 'data/meshes_coll/'+filename 73 | T2 = np.eye(4) 74 | T2[0,3] = .4*(b//sn - 0.5*(sn-1)) 75 | T2[1,3] = .4*(b%sn - 0.5*(sn-1)) 76 | hook_meshes.append(create_hook().apply_transform(T2)) 77 | 78 | for n in range(N): 79 | pose7d = pose7d_batch[b,n] 80 | T = np.eye(4) 81 | T[:3,:3] = Rotation.from_quat(pose7d[[4,5,6,3]]).as_matrix() 82 | T[:3,3] = pose7d[:3] 83 | Tinv = np.linalg.inv(T) 84 | mug_mesh = trimesh.load(mesh_filename).apply_transform(Tinv) 85 | if not feasibility_batch[b,n]: 86 | mug_mesh.visual.vertex_colors = [200, 100, 100] 87 | 88 | mug_meshes.append(mug_mesh.apply_transform(T2)) 89 | 90 | 91 | axis_mesh = trimesh.creation.axis(origin_size=0.004, 92 | axis_radius=0.003, 93 | axis_length=0.03) 94 | return trimesh.Scene(mug_meshes + hook_meshes + [axis_mesh]) 95 | 96 | def view_scene_grasp_batch(pose7d_batch, feasibility_batch, filename_list, draw_coll=False): 97 | """ 98 | Args: 99 | pose7d_batch: (B, N, 7) 100 | feasibility_batch: (B, N) 101 | filename_list: (B,) 102 | """ 103 | 104 | B, N, _ = pose7d_batch.shape 105 | 106 | mug_meshes = [] 107 | gripper_meshes = [] 108 | 109 | sn = int(np.sqrt(len(filename_list))) 110 | for b, filename in enumerate(filename_list): 111 | mesh_filename = 'data/meshes_coll/'+filename 112 | T2 = np.eye(4) 113 | T2[0,3] = .4*(b//sn - 0.5*(sn-1)) 114 | T2[1,3] = .4*(b%sn - 0.5*(sn-1)) 115 | mug_meshes.append(trimesh.load(mesh_filename).apply_transform(T2)) 116 | 117 | for n in range(N): 118 | pose7d = pose7d_batch[b,n] 119 | T = np.eye(4) 120 | T[:3,:3] = Rotation.from_quat(pose7d[[4,5,6,3]]).as_matrix() 121 | T[:3,3] = pose7d[:3] 122 | 123 | if draw_coll: 124 | gripper_mesh = create_gripper_coll().apply_transform(T) 125 | else: 126 | gripper_mesh = 
create_gripper_marker().apply_transform(T) 127 | 128 | if not feasibility_batch[b,n]: 129 | gripper_mesh.visual.vertex_colors = [200, 100, 100] 130 | 131 | gripper_meshes.append(gripper_mesh.apply_transform(T2)) 132 | 133 | return trimesh.Scene(mug_meshes + gripper_meshes) 134 | 135 | def view_scene_grasp(pose7d, filename): 136 | 137 | obj_mesh = trimesh.load(filename) 138 | pose7d = pose7d.detach().cpu().view(-1, 7).numpy() 139 | grippers_mesh = [] 140 | for n in range(pose7d.shape[0]): 141 | T = np.eye(4) 142 | T[:3,:3] = Rotation.from_quat(pose7d[n,[4,5,6,3]]).as_matrix() 143 | T[:3,3] = pose7d[n,:3] 144 | grippers_mesh.append(create_gripper_marker().apply_transform(T)) 145 | 146 | return trimesh.Scene([obj_mesh] + grippers_mesh) 147 | 148 | def create_gripper_coll(color=[0, 0, 255, 120]): 149 | pose7d = np.array([0, 0, .064, -.5, .5, -.5, .5]) 150 | T = np.eye(4) 151 | T[:3,:3] = Rotation.from_quat(pose7d[[4,5,6,3]]).as_matrix() 152 | T[:3,3] = pose7d[:3] 153 | T1 = np.eye(4) 154 | T1[2,3] = -.075 155 | 156 | coll0 = trimesh.creation.capsule(radius=0.03, height=.15).apply_transform(T@T1) 157 | 158 | coll1 = trimesh.creation.icosphere(radius=0.02) 159 | coll1.vertices += [-0.058, 0, 0] 160 | 161 | coll2 = trimesh.creation.icosphere(radius=0.02) 162 | coll2.vertices += [0.058, 0, 0] 163 | 164 | coll3 = trimesh.creation.icosphere(radius=0.02) 165 | coll3.vertices += [-0.07, 0, 0.0256] 166 | 167 | coll4 = trimesh.creation.icosphere(radius=0.02) 168 | coll4.vertices += [0.07, 0, 0.0256] 169 | 170 | tmp = trimesh.util.concatenate([coll0, coll1, coll2, coll3, coll4]) 171 | tmp.visual.vertex_colors = color 172 | 173 | return tmp 174 | 175 | 176 | 177 | def create_gripper_marker(color=[0, 0, 255], tube_radius=0.001): 178 | """Create a 3D mesh visualizing a parallel yaw gripper. It consists of four cylinders. 179 | Args: 180 | color (list, optional): RGB values of marker. Defaults to [0, 0, 255]. 181 | tube_radius (float, optional): Radius of cylinders. Defaults to 0.001. 182 | sections (int, optional): Number of sections of each cylinder. Defaults to 6. 183 | Returns: 184 | trimesh.Trimesh: A mesh that represents a simple parallel yaw gripper. 185 | """ 186 | cfl = trimesh.creation.cylinder( 187 | radius=0.002, segment=[[0.05, 0, -0.02], [0.05, 0, 0.045]], 188 | ) 189 | cfr = trimesh.creation.cylinder( 190 | radius=0.002, segment=[[-0.05, 0, -0.02], [-0.05, 0, 0.045]], 191 | ) 192 | cb1 = trimesh.creation.cylinder( 193 | radius=0.002, segment=[[0, 0, 0.045], [0, 0, 0.090]] 194 | ) 195 | cb2 = trimesh.creation.cylinder( 196 | radius=0.002, segment=[[-0.05, 0, 0.045], [0.05, 0, 0.045]], 197 | ) 198 | 199 | tmp = trimesh.util.concatenate([cb1, cb2, cfr, cfl]) 200 | tmp.visual.vertex_colors = color 201 | 202 | return tmp 203 | 204 | def apply_delta(x, delta_x): 205 | """ 206 | apply delta_x 207 | Input 208 | x: (B, 7, 1) pose 209 | delta_x: (B, 6, 1) delta 210 | 211 | Return (B, 7, 1) 212 | """ 213 | y = torch.empty_like(x) 214 | y[:, :3] = x[:, :3] + delta_x[:, :3] 215 | 216 | delta_q = axis_angle_to_quaternion(delta_x[:, 3:, 0]) 217 | y[:, 3:] = quaternion_multiply(x[:, 3:, 0], delta_q).view(-1, 4, 1) 218 | return y 219 | 220 | ############################################################################ 221 | 222 | def quaternion_to_matrix(quaternions): 223 | """ 224 | Convert rotations given as quaternions to rotation matrices. 225 | 226 | Args: 227 | quaternions: quaternions with real part first, 228 | as tensor of shape (..., 4). 
229 | 230 | Returns: 231 | Rotation matrices as tensor of shape (..., 3, 3). 232 | """ 233 | r, i, j, k = torch.unbind(quaternions, -1) 234 | two_s = 2.0 / (quaternions * quaternions).sum(-1) 235 | 236 | o = torch.stack( 237 | ( 238 | 1 - two_s * (j * j + k * k), 239 | two_s * (i * j - k * r), 240 | two_s * (i * k + j * r), 241 | two_s * (i * j + k * r), 242 | 1 - two_s * (i * i + k * k), 243 | two_s * (j * k - i * r), 244 | two_s * (i * k - j * r), 245 | two_s * (j * k + i * r), 246 | 1 - two_s * (i * i + j * j), 247 | ), 248 | -1, 249 | ) 250 | return o.reshape(quaternions.shape[:-1] + (3, 3)) 251 | 252 | 253 | def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: 254 | """ 255 | Convert rotations given as rotation matrices to quaternions. 256 | 257 | Args: 258 | matrix: Rotation matrices as tensor of shape (..., 3, 3). 259 | 260 | Returns: 261 | quaternions with real part first, as tensor of shape (..., 4). 262 | """ 263 | if matrix.size(-1) != 3 or matrix.size(-2) != 3: 264 | raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") 265 | 266 | batch_dim = matrix.shape[:-2] 267 | m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( 268 | matrix.reshape(*batch_dim, 9), dim=-1 269 | ) 270 | 271 | q_abs = _sqrt_positive_part( 272 | torch.stack( 273 | [ 274 | 1.0 + m00 + m11 + m22, 275 | 1.0 + m00 - m11 - m22, 276 | 1.0 - m00 + m11 - m22, 277 | 1.0 - m00 - m11 + m22, 278 | ], 279 | dim=-1, 280 | ) 281 | ) 282 | 283 | # we produce the desired quaternion multiplied by each of r, i, j, k 284 | quat_by_rijk = torch.stack( 285 | [ 286 | torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), 287 | torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), 288 | torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), 289 | torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), 290 | ], 291 | dim=-2, 292 | ) 293 | 294 | # We floor here at 0.1 but the exact level is not important; if q_abs is small, 295 | # the candidate won't be picked. 296 | # pyre-ignore [16]: `torch.Tensor` has no attribute `new_tensor`. 297 | quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(q_abs.new_tensor(0.1))) 298 | 299 | # if not for numerical problems, quat_candidates[i] should be same (up to a sign), 300 | # forall i; we pick the best-conditioned one (with the largest denominator) 301 | 302 | return quat_candidates[ 303 | F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : # pyre-ignore[16] 304 | ].reshape(*batch_dim, 4) 305 | 306 | def axis_angle_to_matrix(axis: str, angle): 307 | """ 308 | Return the rotation matrices for one of the rotations about an axis 309 | of which Euler angles describe, for each value of the angle given. 310 | 311 | Args: 312 | axis: Axis label "X" or "Y or "Z". 313 | angle: any shape tensor of Euler angles in radians 314 | 315 | Returns: 316 | Rotation matrices as tensor of shape (..., 3, 3). 
317 | """ 318 | 319 | cos = torch.cos(angle) 320 | sin = torch.sin(angle) 321 | one = torch.ones_like(angle) 322 | zero = torch.zeros_like(angle) 323 | 324 | if axis == "X": 325 | R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) 326 | if axis == "Y": 327 | R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) 328 | if axis == "Z": 329 | R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) 330 | 331 | return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) 332 | 333 | 334 | def axis_angle_to_quaternion(axis_angle): 335 | """ 336 | Convert rotations given as axis/angle to quaternions. 337 | 338 | Args: 339 | axis_angle: Rotations given as a vector in axis angle form, 340 | as a tensor of shape (..., 3), where the magnitude is 341 | the angle turned anticlockwise in radians around the 342 | vector's direction. 343 | 344 | Returns: 345 | quaternions with real part first, as tensor of shape (..., 4). 346 | """ 347 | angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) 348 | half_angles = 0.5 * angles 349 | eps = 1e-6 350 | small_angles = angles.abs() < eps 351 | sin_half_angles_over_angles = torch.empty_like(angles) 352 | sin_half_angles_over_angles[~small_angles] = ( 353 | torch.sin(half_angles[~small_angles]) / angles[~small_angles] 354 | ) 355 | # for x small, sin(x/2) is about x/2 - (x/2)^3/6 356 | # so sin(x/2)/x is about 1/2 - (x*x)/48 357 | sin_half_angles_over_angles[small_angles] = ( 358 | 0.5 - (angles[small_angles] * angles[small_angles]) / 48 359 | ) 360 | quaternions = torch.cat( 361 | [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 362 | ) 363 | return quaternions 364 | 365 | def quaternion_to_axis_angle(quaternions): 366 | """ 367 | Convert rotations given as quaternions to axis/angle. 368 | 369 | Args: 370 | quaternions: quaternions with real part first, 371 | as tensor of shape (..., 4). 372 | 373 | Returns: 374 | Rotations given as a vector in axis angle form, as a tensor 375 | of shape (..., 3), where the magnitude is the angle 376 | turned anticlockwise in radians around the vector's 377 | direction. 378 | """ 379 | norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) 380 | half_angles = torch.atan2(norms, quaternions[..., :1]) 381 | angles = 2 * half_angles 382 | eps = 1e-6 383 | small_angles = angles.abs() < eps 384 | sin_half_angles_over_angles = torch.empty_like(angles) 385 | sin_half_angles_over_angles[~small_angles] = ( 386 | torch.sin(half_angles[~small_angles]) / angles[~small_angles] 387 | ) 388 | # for x small, sin(x/2) is about x/2 - (x/2)^3/6 389 | # so sin(x/2)/x is about 1/2 - (x*x)/48 390 | sin_half_angles_over_angles[small_angles] = ( 391 | 0.5 - (angles[small_angles] * angles[small_angles]) / 48 392 | ) 393 | return quaternions[..., 1:] / sin_half_angles_over_angles 394 | 395 | def standardize_quaternion(quaternions): 396 | """ 397 | Convert a unit quaternion to a standard form: one in which the real 398 | part is non negative. 399 | 400 | Args: 401 | quaternions: Quaternions with real part first, 402 | as tensor of shape (..., 4). 403 | 404 | Returns: 405 | Standardized quaternions as tensor of shape (..., 4). 406 | """ 407 | return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) 408 | 409 | def quaternion_raw_multiply(a, b): 410 | """ 411 | Multiply two quaternions. 412 | Usual torch rules for broadcasting apply. 413 | 414 | Args: 415 | a: Quaternions as tensor of shape (..., 4), real part first. 416 | b: Quaternions as tensor of shape (..., 4), real part first. 
417 | 418 | Returns: 419 | The product of a and b, a tensor of quaternions shape (..., 4). 420 | """ 421 | aw, ax, ay, az = torch.unbind(a, -1) 422 | bw, bx, by, bz = torch.unbind(b, -1) 423 | ow = aw * bw - ax * bx - ay * by - az * bz 424 | ox = aw * bx + ax * bw + ay * bz - az * by 425 | oy = aw * by - ax * bz + ay * bw + az * bx 426 | oz = aw * bz + ax * by - ay * bx + az * bw 427 | return torch.stack((ow, ox, oy, oz), -1) 428 | 429 | def quaternion_multiply(a, b): 430 | """ 431 | Multiply two quaternions representing rotations, returning the quaternion 432 | representing their composition, i.e. the versor with nonnegative real part. 433 | Usual torch rules for broadcasting apply. 434 | 435 | Args: 436 | a: Quaternions as tensor of shape (..., 4), real part first. 437 | b: Quaternions as tensor of shape (..., 4), real part first. 438 | 439 | Returns: 440 | The product of a and b, a tensor of quaternions of shape (..., 4). 441 | """ 442 | ab = quaternion_raw_multiply(a, b) 443 | return standardize_quaternion(ab) 444 | 445 | def quaternion_invert(quaternion): 446 | """ 447 | Given a quaternion representing rotation, get the quaternion representing 448 | its inverse. 449 | 450 | Args: 451 | quaternion: Quaternions as tensor of shape (..., 4), with real part 452 | first, which must be versors (unit quaternions). 453 | 454 | Returns: 455 | The inverse, a tensor of quaternions of shape (..., 4). 456 | """ 457 | 458 | return quaternion * torch.tensor([1, -1, -1, -1], 459 | dtype=quaternion.dtype, 460 | device=quaternion.device) 461 | 462 | 463 | 464 | def quaternion_apply(quaternion, point): 465 | """ 466 | Apply the rotation given by a quaternion to a 3D point. 467 | Usual torch rules for broadcasting apply. 468 | 469 | Args: 470 | quaternion: Tensor of quaternions, real part first, of shape (..., 4). 471 | point: Tensor of 3D points of shape (..., 3). 472 | 473 | Returns: 474 | Tensor of rotated points of shape (..., 3). 475 | """ 476 | if point.size(-1) != 3: 477 | raise ValueError(f"Points are not in 3D, f{point.shape}.") 478 | real_parts = point.new_zeros(point.shape[:-1] + (1,)) 479 | point_as_quaternion = torch.cat((real_parts, point), -1) 480 | out = quaternion_raw_multiply( 481 | quaternion_raw_multiply(quaternion, point_as_quaternion), 482 | quaternion_invert(quaternion), 483 | ) 484 | return out[..., 1:] 485 | 486 | 487 | def _copysign(a, b): 488 | """ 489 | Return a tensor where each element has the absolute value taken from the, 490 | corresponding element of a, with sign taken from the corresponding 491 | element of b. This is like the standard copysign floating-point operation, 492 | but is not careful about negative 0 and NaN. 493 | 494 | Args: 495 | a: source tensor. 496 | b: tensor whose signs will be used, of the same shape as a. 497 | 498 | Returns: 499 | Tensor of the same shape as a with the signs of b. 500 | """ 501 | signs_differ = (a < 0) != (b < 0) 502 | return torch.where(signs_differ, -a, a) 503 | 504 | def random_quaternions( 505 | n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False 506 | ): 507 | """ 508 | Generate random quaternions representing rotations, 509 | i.e. versors with nonnegative real part. 510 | 511 | Args: 512 | n: Number of quaternions in a batch to return. 513 | dtype: Type to return. 514 | device: Desired device of returned tensor. Default: 515 | uses the current device for the default tensor type. 516 | requires_grad: Whether the resulting tensor should have the gradient 517 | flag set. 
518 | 519 | Returns: 520 | Quaternions as tensor of shape (N, 4). 521 | """ 522 | o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad) 523 | # o = o.sign()*torch.where(o.abs() > 1e-4, o.abs(), 1e-4*torch.ones_like(o)) 524 | s = (o * o).sum(1) 525 | o = o / _copysign(torch.sqrt(s+1e-10), o[:, 0])[:, None] 526 | return o 527 | 528 | def quaternion_slerp(q0, q1, t): 529 | """ 530 | Apply the rotation given by a quaternion to a 3D point. 531 | Usual torch rules for broadcasting apply. 532 | 533 | Args: 534 | q0, q1: Tensor of quaternions, real part first, of shape (..., 4). 535 | t: interpolation parameter between 0 and 1 (..., 1). 536 | 537 | Returns: 538 | Tensor of rotated points of shape (..., 4). 539 | """ 540 | 541 | q_tmp = quaternion_multiply(q1, quaternion_invert(q0)) 542 | rot = t*quaternion_to_axis_angle(q_tmp) 543 | q_interp = quaternion_multiply(axis_angle_to_quaternion(rot), q0) 544 | 545 | return q_interp 546 | 547 | 548 | def index(feat, uv): 549 | """ 550 | Extract image features at uv coordinates 551 | 552 | Args: 553 | feat: (B, Feat_img, H, W) image features 554 | uv: (B, N, 2) uv coordinates in the image plane, range [-1, 1] 555 | 556 | Returns: 557 | (B, N, Feat_img) image features 558 | """ 559 | samples = torch.nn.functional.grid_sample(feat, 560 | uv.unsqueeze(2), # (B, N, 1, 2) 561 | align_corners=True, 562 | padding_mode="border")# (B, Feat_img, N, 1) 563 | 564 | return samples.transpose(1,2).squeeze(-1) # (B, N, Feat_img) 565 | 566 | 567 | def perspective(points, projection_matrices): 568 | """ 569 | Compute the perspective projections of 3D points into the image plane by given projection matrix 570 | 571 | Args: 572 | points: (B, N, 3) Tensor of 3D points 573 | projection_matrices: (B, 4, 4) Tensor of projection matrix 574 | 575 | Returns: 576 | uv: (B, N, 2) uv coordinates in image space 577 | z: (B, N, 1) normalized depth 578 | """ 579 | 580 | tmp = torch.ones_like(points[...,0:1]) 581 | points = torch.cat([points, tmp], dim=2) # (B, N, 4) 582 | 583 | homo = points.bmm(projection_matrices.transpose(1,2)) # (B, N, 4) 584 | 585 | uv, z, w = torch.split(homo, [2,1,1], dim=2) 586 | w = w.clamp(min=1e-2) # clamp depth near the camera (1 cm) 587 | 588 | # return uv/w, z/w # (B, N, 2), (B, N, 1) 589 | return uv/w, z # (B, N, 2), (B, N, 1) -------------------------------------------------------------------------------- /src/vector_object.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from .utils import * 5 | from .functional_object import * 6 | from .feature import * 7 | 8 | from skimage import measure 9 | import matplotlib.pyplot as plt 10 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection 11 | 12 | class vectorObject(nn.Module): 13 | def __init__(self, **C): 14 | super().__init__() 15 | self.C = {} 16 | self.C['FEAT_IMG'] = 64 17 | self.C['FEAT_CAM'] = 32 18 | self.C['WIDTH_LIFTER'] = [256, 128] 19 | self.C.update(C) 20 | self.build_modules() 21 | 22 | def build_modules(self): 23 | self.image_encoder = torchvision.models.resnet.resnet34(pretrained=True) 24 | num_channels = self.image_encoder.layer4[-1].bn2.num_features 25 | self.image_encoder.fc = nn.Linear(num_channels, self.C['FEAT_IMG']) 26 | 27 | self.cam_encoder = nn.Sequential( 28 | nn.Linear(4, self.C['FEAT_CAM']), 29 | nn.ReLU(inplace=True) 30 | ) 31 | 32 | lifter_layers = [ 33 | nn.Linear(self.C['FEAT_IMG']+self.C['FEAT_CAM'], self.C['WIDTH_LIFTER'][0]), 34 | nn.ReLU(inplace=True) 35 | ] 
36 | for i in range(len(self.C['WIDTH_LIFTER']) - 1): 37 | lifter_layers.extend([ 38 | nn.Linear(self.C['WIDTH_LIFTER'][i], self.C['WIDTH_LIFTER'][i+1]), 39 | nn.ReLU(inplace=True) 40 | ]) 41 | self.feature_lifter = nn.Sequential(*lifter_layers) 42 | self.out_dim = self.C['WIDTH_LIFTER'][-1] 43 | 44 | mean=[0.485, 0.456, 0.406] 45 | std=[0.229, 0.224, 0.225] 46 | self.normalizer = transforms.Normalize(mean=mean, std=std) 47 | self.unnormalizer = UnNormalize(mean=mean, std=std) 48 | 49 | def encode(self, images, cam_params): 50 | self.forward(images, cam_params) 51 | 52 | def forward(self, images, cam_params): 53 | """ 54 | Args: 55 | images: (B, num_views, C, H, W) input images 56 | cam_params: (B, num_views, 4) camera rel position and roll for each image 57 | Returns: 58 | (B, num_view, Feat) features for each point 59 | """ 60 | 61 | if images is not None: 62 | B, self.num_views, C, H, W = images.shape 63 | images = images.view(-1, C, H, W) # (B * num_views, C, H, W) 64 | images = self.normalizer(images) 65 | cam_params = cam_params.view(-1, 4) 66 | 67 | 68 | img_feat = self.image_encoder(images) # (B * num_views, feat_img) 69 | cam_feat = self.cam_encoder(cam_params) # (B * num_views, feat_cam) 70 | feat_all = torch.cat([img_feat, cam_feat], dim=1).view(B, self.num_views, -1) 71 | 72 | self.feature = self.feature_lifter(feat_all) # (B, num_view, Feat) 73 | 74 | return self.feature 75 | 76 | 77 | class Frame_vec(nn.Module): 78 | def __init__(self, **kwargs): 79 | super(Frame_vec, self).__init__() 80 | 81 | self.backbone = kwargs.get("backbone", None) 82 | self.sdf_head = kwargs.get("sdf_head", None) 83 | self.grasp_head = kwargs.get("grasp_head", None) 84 | self.placing_head = kwargs.get("placing_head", None) 85 | self.hanging_head = kwargs.get("hanging_head", None) 86 | 87 | def build_backbone(self, **C): 88 | self.backbone = vectorObject(**C) 89 | 90 | def build_sdf_head(self, width): 91 | layer_list = [nn.Linear(3+self.backbone.out_dim, width[0]), nn.ReLU(inplace=True)] 92 | for i in range(len(width)-1): 93 | layer_list.extend([ 94 | nn.Linear(width[i], width[i+1]), nn.ReLU(inplace=True) 95 | ]) 96 | layer_list.append(nn.Linear(width[-1], 1)) 97 | self.sdf_head = nn.Sequential(*layer_list) 98 | 99 | 100 | def build_pose_head(self, name, width): 101 | layer_list = [ 102 | nn.Linear(7+self.backbone.out_dim, width[0]), nn.ReLU(inplace=True) 103 | ] 104 | for i in range(len(width)-1): 105 | layer_list.extend([ 106 | nn.Linear(width[i], width[i+1]), nn.ReLU(inplace=True) 107 | ]) 108 | layer_list.append(nn.Linear(width[-1], 1)) 109 | 110 | setattr(self, name+'_head', nn.Sequential(*layer_list)) 111 | head = getattr(self, name+'_head') 112 | head.name = name 113 | 114 | def extract_mesh(self, 115 | images=None, 116 | cam_params=None, 117 | center=[0,0,0], 118 | scale=.15, 119 | num_grid=50, 120 | sdf_scale=10., 121 | delta=0., 122 | draw=True, 123 | return_sdf=False): 124 | assert self.sdf_head is not None, "sdf_head is not defined!" 
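        # Sketch of what follows: re-encode the views if images are given, evaluate the SDF
        # head (scaled by 1/sdf_scale) on a num_grid^3 lattice of half-width `scale` around
        # `center` in chunks of num_grid^2 points, run marching cubes at iso-level `delta`,
        # and map the resulting vertices from grid indices back to world coordinates.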
125 | 126 | 127 | if images is None: 128 | images = self.backbone.images 129 | else: 130 | self.backbone.encode(images, cam_params) 131 | 132 | device = images.device 133 | num_views = images.shape[1] 134 | 135 | F_sdf = SDF_Feature_vec(self) 136 | 137 | dx = center[0]+scale*torch.linspace(-1, 1, num_grid, device=device) 138 | dy = center[1]+scale*torch.linspace(-1, 1, num_grid, device=device) 139 | dz = center[2]+scale*torch.linspace(-1, 1, num_grid, device=device) 140 | grid_x, grid_y, grid_z = torch.meshgrid(dx, dy, dz) 141 | grid_x, grid_y, grid_z = grid_x.flatten(), grid_y.flatten(), grid_z.flatten() 142 | pts = torch.stack([grid_x, grid_y, grid_z], dim=1).unsqueeze(0) # (1, num_grid**3, 3) 143 | 144 | 145 | L = pts.shape[1] 146 | N = num_grid**2 147 | mu = np.zeros((L,1)) 148 | for i in range(L//N): 149 | with torch.no_grad(): 150 | mu[i*N:(i+1)*N] = F_sdf(pts[:,i*N:(i+1)*N,:])[0].view(-1, 1).detach().cpu().numpy()/sdf_scale 151 | mu = mu.reshape((num_grid, num_grid, num_grid)) 152 | vertices, faces, normals, _ = measure.marching_cubes(mu, delta) 153 | vertices = np.array(center).reshape(1,3)-scale + vertices * 2*scale/(num_grid-1) 154 | if draw: 155 | mesh = Poly3DCollection(vertices[faces], 156 | facecolors='w', 157 | edgecolors='k', 158 | linewidths=1, 159 | alpha=0.5) 160 | 161 | fig = plt.figure() 162 | ax = plt.subplot(111, projection='3d') 163 | ax.set_xlim([center[0]-scale, center[0]+scale]) 164 | ax.set_ylim([center[1]-scale, center[1]+scale]) 165 | ax.set_zlim([center[2]-scale, center[2]+scale]) 166 | ax.set_xlabel('x') 167 | ax.set_ylabel('y') 168 | ax.set_zlabel('z') 169 | ax.grid() 170 | ax.add_collection3d(mesh) 171 | plt.tight_layout() 172 | 173 | render_images = images.cpu().squeeze(0) 174 | fig = plt.figure() 175 | for i in range(num_views): 176 | ax = plt.subplot(np.ceil(num_views/5),5,i+1) 177 | ax.imshow(render_images[i,...].permute(1,2,0)) 178 | plt.tight_layout() 179 | plt.show() 180 | 181 | if return_sdf: 182 | return vertices, faces, normals, mu.flatten() 183 | else: 184 | return vertices, faces, normals 185 | 186 | class SDF_Feature_vec(nn.Module): 187 | def __init__(self, frame): 188 | super().__init__() 189 | self.frame = frame 190 | self.backbone = frame.backbone 191 | self.head = frame.sdf_head 192 | self.name = 'sdf' 193 | 194 | def forward(self, points, images=None, cam_params=None, grad_method=None): 195 | """ 196 | Args: 197 | points: (B, N, 3) world coordinates of points 198 | images: (B, num_views, C, H, W) input images 199 | cam_params: (B, num_views, 8) camera pose & fov for each image 200 | grad_method: {"FD", "AD", None} how to compute grads - (forward) finite diff / auto diff 201 | Returns: 202 | sdf: (B, N) sdf predictions for each point 203 | (optional) grads: (B, N, 3) grads of sdf w.r.t. points 204 | """ 205 | assert self.head is not None, "head is not defined!" 
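        # View-averaged backbone features are tiled to every query point and concatenated
        # with the raw xyz coordinates before the sdf head; grad_method="AD" additionally
        # returns d(sdf)/d(points) via autograd.
        # Illustrative call (a sketch, assuming F_sdf = SDF_Feature_vec(frame) and the
        # backbone has already been encoded):
        #   sdf = F_sdf(points)                           # (B, N)
        #   sdf, grads = F_sdf(points, grad_method="AD")  # (B, N), (B, N, 3)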
206 | 207 | B, N = points.shape[:2] 208 | features = self.backbone(images, cam_params).mean(dim=1, keepdims=True).repeat(1,N,1) 209 | # (B, N, Feat) 210 | 211 | if grad_method is None: 212 | 213 | features = torch.cat([points, features], dim=2) # (B, N, 3+Feat) 214 | 215 | return self.head(features).view(B, N) #(B, N) 216 | 217 | elif grad_method == "AD": 218 | with torch.enable_grad(): 219 | points.requires_grad_() 220 | features = torch.cat([points, features], dim=2) # (B, N, 3+Feat) 221 | sdf = self.head(features).view(B, N) #(B, N) 222 | 223 | grads = torch.autograd.grad(outputs=sdf, 224 | inputs=points, 225 | grad_outputs=torch.ones_like(sdf))[0] 226 | return sdf, grads 227 | 228 | 229 | class Pose_Feature_vec(nn.Module): 230 | def __init__(self, frame, name): 231 | super().__init__() 232 | self.frame = frame 233 | self.backbone = frame.backbone 234 | self.head = getattr(frame, name+'_head') 235 | self.name = name 236 | 237 | collision_shapes = [] 238 | if name == 'grasp': 239 | collision_shapes.append({'shape': 'capsule', 240 | 'size': [0.15, 0.03], 241 | 'pos': torch.Tensor([0, 0, 0.064]), 242 | 'quat': torch.Tensor([-.5, .5, -.5, .5])}) 243 | fing_colls = torch.Tensor([[-0.058, 0, 0], 244 | [ 0.058, 0, 0], 245 | [-0.07, 0, 0.0256], 246 | [ 0.07, 0, 0.0256]]) 247 | for pos in fing_colls: 248 | collision_shapes.append({'shape': 'sphere', 'size': [0.02], 'pos': pos}) 249 | 250 | elif name == 'hang': 251 | collision_shapes.append({'shape': 'capsule', 252 | 'size': [.15*2, .002], 253 | 'pos': torch.zeros(3), 254 | 'quat': torch.Tensor([1,0,0,0])}) 255 | 256 | self.compute_collision_points(collision_shapes) 257 | 258 | def forward(self, poses, images=None, cam_params=None, grad_method=None): 259 | """ 260 | Args: 261 | poses: (B, N, 7) poses 262 | images: (B, num_views, C, H, W) input images 263 | cam_params: (B, num_views, 8) camera pose & fov for each image 264 | grad_method: {"FD", "AD", None} how to compute grads - (forward) finite diff, auto diff, none 265 | Returns: 266 | y: (B, N) task feature predictions for all pose 267 | (optional) grads: (B, N, 6) grads of task feature w.r.t. poses 268 | """ 269 | 270 | assert self.head is not None, "head is not defined!" 
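        # As with the SDF feature, view-averaged backbone features are tiled to every pose and
        # concatenated with the 7d pose before the task head. With grad_method="AD", gradients
        # are taken w.r.t. a local 6d perturbation delta_x (translation plus a small-angle
        # rotation composed onto the quaternion), which is the Jacobian consumed by the damped
        # Gauss-Newton steps in optimize().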
271 | 272 | B, N = poses.shape[:2] 273 | device = poses.device 274 | 275 | features = self.backbone(images, cam_params).mean(dim=1, keepdim=True).repeat(1,N,1) 276 | # (B, N, Feat) 277 | 278 | if grad_method is None: 279 | features = torch.cat([poses, features], dim=2) # (B, N, 3+Feat) 280 | return self.head(features).view(B, N) #(B, N) 281 | 282 | 283 | elif grad_method == "AD": 284 | with torch.enable_grad(): 285 | delta_x = torch.zeros(B, N, 6, device=device).requires_grad_() 286 | poses2 = torch.zeros_like(poses) 287 | poses2[...,:3] = poses[...,:3]+delta_x[..., :3] 288 | delta_q = torch.cat([ 289 | torch.ones(B,N,1, device=device), 0.5*delta_x[..., 3:] 290 | ], dim=2) 291 | 292 | poses2[...,3:] = quaternion_multiply(poses[..., 3:], delta_q) 293 | 294 | features = torch.cat([poses2, features], dim=2) # (B, N, 7+Feat) 295 | y = self.head(features).view(B, N) #(B, N) 296 | 297 | grads = torch.autograd.grad(outputs=y, 298 | inputs=delta_x, 299 | grad_outputs=torch.ones_like(y))[0] 300 | return y, grads 301 | 302 | 303 | def compute_collision_points(self, collision_shapes): 304 | if len(collision_shapes) == 0: return 305 | x = [] 306 | rad = [] 307 | for coll_shape in collision_shapes: 308 | if coll_shape['shape'] == 'sphere': 309 | x.append(coll_shape['pos']) 310 | rad.append(coll_shape['size'][-1]) 311 | elif coll_shape['shape'] == 'capsule': 312 | l = coll_shape['size'][0] 313 | r = coll_shape['size'][-1] 314 | N_capsule = int(0.5+l/r+1) 315 | tmp_pos = torch.zeros(N_capsule,3) 316 | tmp_pos[:,2] = torch.linspace(-l/2,l/2,N_capsule) 317 | half_delta = 0.5*l/(N_capsule-1) 318 | rad_ = np.sqrt(half_delta**2+r**2) 319 | for i in range(N_capsule): 320 | x_tmp = tmp_pos[i] 321 | x_tmp = quaternion_apply(coll_shape['quat'], x_tmp) 322 | x_tmp += coll_shape['pos'] 323 | x.append(x_tmp) 324 | rad.append(rad_) 325 | 326 | self.pts_coll = torch.stack(x, dim=0) 327 | self.rads_coll = torch.Tensor(rad) 328 | 329 | 330 | def eval_features(self, poses, w_coll=0., coll_margin=0., return_grad=False): 331 | """ 332 | poses: (B,N,7) 333 | """ 334 | 335 | B, N, _ = poses.shape 336 | device = poses.device 337 | 338 | 339 | if return_grad: 340 | y, grads = self.forward(poses, grad_method="AD") # (B,N), (B,N,6) 341 | phi, J_phi = y.view(B,N,1,1), grads.view(B,N,1,6) # (B,N,1,1), (B,N,1,6) 342 | else: 343 | phi = self.forward(poses).view(B,N,1,1) # (B,N,1,1) 344 | 345 | if w_coll==0.: 346 | phi_coll = torch.zeros_like(phi) 347 | if return_grad: 348 | J_phi_coll = torch.zeros_like(J_phi) 349 | else: 350 | 351 | with torch.enable_grad(): 352 | if return_grad: 353 | delta_x = torch.zeros(B, N, 6, device=device).requires_grad_() 354 | poses2 = torch.zeros_like(poses) 355 | poses2[...,:3] = poses[...,:3]+delta_x[..., :3] 356 | delta_q = torch.cat([ 357 | torch.ones(B,N,1, device=device), 0.5*delta_x[..., 3:] 358 | ], dim=2) 359 | poses2[...,3:] = quaternion_multiply(poses[..., 3:], delta_q) 360 | else: 361 | poses2 = poses 362 | 363 | K = self.pts_coll.shape[0] 364 | x = self.pts_coll.to(device).expand(B,N,K,3) 365 | rad = self.rads_coll.to(device).expand(B,N,K) 366 | 367 | 368 | poses2 = poses2.unsqueeze(2) # (B,N,1,7) 369 | x = quaternion_apply(poses2[...,3:], x) # (B,N,K,3) 370 | x += poses2[...,:3] # (B,N,K,3) 371 | 372 | 373 | 374 | 375 | y_coll = self.F_sdf(x.view(B, N*K, 3)).view(B,N,K)*0.1 # negative: inside 376 | y_coll -= rad # (B, N, K) 377 | y_coll -= coll_margin # (B, N, K) 378 | 379 | phi_coll = y_coll.clamp_(max=0.).view(B,N,K,1).sum(dim=2, keepdim=True) # (B,N,1,1) 380 | 381 | if return_grad: 382 | 
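                # Back-propagate the clamped collision penalty phi_coll through the local pose
                # perturbation delta_x so its Jacobian can be stacked with the task-feature Jacobian.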
grads = torch.autograd.grad(outputs=phi_coll, 383 | inputs=delta_x, 384 | grad_outputs=torch.ones_like(phi_coll))[0] 385 | J_phi_coll = grads.view(B,N,1,6) 386 | 387 | 388 | 389 | phi = torch.cat([phi, w_coll*phi_coll], dim=2) # (B,N,2,1) 390 | if return_grad: 391 | J_phi = torch.cat([J_phi, w_coll*J_phi_coll], dim=2) # (B,N,2,6) 392 | return phi, J_phi 393 | else: 394 | return phi 395 | 396 | def optimize(self, 397 | poses, 398 | images, 399 | cam_params, 400 | w_coll=0., 401 | coll_margin=1e-3, 402 | max_iter=301, 403 | print_interval=1000, 404 | line_search=True, 405 | max_line_search=10000, 406 | max_step = 0.2, 407 | gamma=1e-4): 408 | """ 409 | Args: 410 | poses: ((B,) N, 7) poses 411 | images: ((B,) num_views, C, H, W) input images 412 | projections: ((B,) num_views, 4, 4) projection matrices for each image 413 | Returns: 414 | ((B,) N, 7) optimized poses 415 | ((B,) N) costs 416 | """ 417 | 418 | with torch.no_grad(): 419 | 420 | if w_coll > 0: 421 | self.F_sdf = SDF_Feature_vec(self.frame) 422 | 423 | batch = True 424 | if len(poses.shape) == 2: 425 | poses = poses.unsqueeze(0) 426 | images = images.unsqueeze(0) 427 | cam_params = cam_params.unsqueeze(0) 428 | batch = False 429 | 430 | self.backbone.encode(images, cam_params) 431 | 432 | B, N = poses.shape[:2] 433 | device = poses.device 434 | gammaI = gamma*torch.eye(6, device=device).view(1,1,6,6) 435 | num_tiny_steps = 0 436 | 437 | for i in range(max_iter): 438 | phi, J_phi = self.eval_features(poses, 439 | w_coll, 440 | coll_margin, 441 | return_grad=True) 442 | # (B,N,2,1), (B,N,2,6) 443 | 444 | f = phi.transpose(-1,-2).matmul(phi) 445 | g = 2*J_phi.transpose(-1,-2).matmul(phi) 446 | H = 2*J_phi.transpose(-1,-2).matmul(J_phi) + gammaI 447 | 448 | 449 | delta_x = -torch.linalg.solve(H, g).view(B,N,6) 450 | max_delta = delta_x.abs().max(dim=2,keepdims=True)[0].clamp(min=max_step) 451 | delta_x *= (max_step/max_delta) 452 | 453 | alpha = 1.*torch.ones(B,N,1).to(device) 454 | for _ in range(max_line_search): 455 | poses_tmp = poses.clone() 456 | delta_x_tmp = alpha*delta_x 457 | 458 | poses_tmp[...,:3] += delta_x_tmp[..., :3] 459 | delta_q_tmp = axis_angle_to_quaternion(delta_x_tmp[..., 3:]) 460 | poses_tmp[...,3:] = quaternion_multiply(poses_tmp[..., 3:], delta_q_tmp) 461 | 462 | phi_tmp = self.eval_features(poses_tmp, w_coll, coll_margin) # (B,N,d,1) 463 | f_tmp = phi_tmp.transpose(-1,-2).matmul(phi_tmp) 464 | 465 | masks = (f_tmp > f + 0.5*g.transpose(-1,-2).matmul(delta_x_tmp.view(B,N,6,1))).view(B,N,1) 466 | if masks.sum() == 0 or (not line_search): 467 | break 468 | else: 469 | alpha = ~masks*alpha + masks*alpha*0.5 470 | 471 | poses, delta_x = poses_tmp, delta_x_tmp, 472 | costs, colls = phi_tmp[:,:,0,0].abs(), phi_tmp[:,:,1,0].abs() 473 | 474 | max_diff = (delta_x).abs().max().item() 475 | if max_diff < 1e-4: 476 | num_tiny_steps += 1 477 | else: 478 | num_tiny_steps = 0 479 | 480 | if num_tiny_steps > 4: 481 | break 482 | 483 | if i % print_interval == 0: 484 | print('iter: {}, cost: {}, coll: {}'.format(i, costs.max().item(), colls.max().item(), max_diff)) 485 | 486 | 487 | if not batch: 488 | poses = poses.squeeze(0) 489 | costs = costs.squeeze(0) 490 | colls = colls.squeeze(0) 491 | 492 | return poses, costs.cpu().numpy(), colls.cpu().numpy() 493 | 494 | 495 | def check_feasibility(self, poses, filenames, masses, coms): 496 | """ 497 | Args: 498 | poses: ((B,) N, 7) poses 499 | Returns: 500 | ((B,) N) feasibility 501 | """ 502 | batch = True 503 | if len(poses.shape) == 2: 504 | poses = poses.unsqueeze(0) 505 | 
filenames = np.expand_dims(filenames, 0) 506 | masses = np.expand_dims(masses, 0) 507 | coms = np.expand_dims(coms, 0) 508 | batch = False 509 | 510 | B, N = poses.shape[0:2] 511 | feasibility = np.zeros((B, N)) 512 | 513 | for b in range(B): 514 | try: 515 | pose, mass, com = poses[b], masses[b], coms[b] 516 | mesh_coll_filename = 'data/meshes_coll/' + filenames[b] 517 | if self.name == "grasp": 518 | feasibility[b] = check_grasp_feasibility(pose.cpu().numpy(), 519 | mesh_coll_filename, mass, com) 520 | elif self.name == "hang": 521 | feasibility[b] = check_hang_feasibility(pose.cpu().numpy(), 522 | mesh_coll_filename, mass, com) 523 | except: 524 | print('feasibility check failed!') 525 | 526 | if not batch: 527 | feasibility = feasibility.squeeze(0) 528 | 529 | return feasibility -------------------------------------------------------------------------------- /src/simulation_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | import h5py 4 | import time 5 | import sys 6 | 7 | sys.path.append('../../../rai-fork/rai/ry') 8 | import libry as ry 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | 12 | from .data_gen_utils import * 13 | import scipy 14 | import torch 15 | from src.frame import Frame 16 | from src.feature import KeyPoint_Feature, JIT_Collision_Feature, JIT_Keypoint_Feature, JIT_ICP_Feature #, JIT_PoseEstimation_Feature 17 | from src.utils import * 18 | 19 | class Simulation: 20 | def __init__(self, worldFileName, mug_inds=[], mug_poses=[], verbose=1, addBalls=False, sort=False, addImp=False): 21 | 22 | self.RealWorld = ry.Config() 23 | self.RealWorld.addFile(worldFileName) 24 | self.RealWorld.selectJoints([j for j in self.RealWorld.getJointNames() if j[-6:] != 'finger']) 25 | self.sort = sort 26 | 27 | for ind, pose in zip(mug_inds, mug_poses): 28 | self.addMug(ind, pose) 29 | 30 | self.camera_name_list = ['camera_'+str(i) for i in range(4)] 31 | self.camera = self.RealWorld.cameraView() 32 | for camera_name in self.camera_name_list: 33 | self.camera.addSensorFromFrame(camera_name) 34 | 35 | self.qInit = self.RealWorld.getJointState() 36 | self.fInit = self.RealWorld.getFrameState() 37 | 38 | 39 | self.tau = 0.01 40 | self.verbose = verbose 41 | self.S = self.RealWorld.simulation(ry.SimulatorEngine.bullet, verbose) 42 | if addImp: self.S.addImp(ry.ImpType.noPenetrations, [], []) 43 | self.stepNone(3., False) 44 | if addBalls: 45 | MugPos = self.RealWorld.frame('mug'+str(mug_inds[0])).getPosition() 46 | for i in range(10): 47 | pos = MugPos + np.array([0,0,.02+.032*i]) 48 | b = self.RealWorld.addFrame('ball_'+str(i)).setShape(ry.ST.sphere, [0.015]) 49 | b.setPosition(pos).setMass(0.00001).setColor([.8,.6,.6]) 50 | self.S = self.RealWorld.simulation(ry.SimulatorEngine.bullet, verbose) 51 | if addImp: self.S.addImp(ry.ImpType.noPenetrations, [], []) 52 | self.stepNone(3., False) 53 | self.qInit = self.RealWorld.getJointState() 54 | self.fInit = self.RealWorld.getFrameState() 55 | 56 | def initialize(self): 57 | self.S.setState(self.fInit) 58 | self.S.step([], self.tau, ry.ControlMode.none) 59 | 60 | def addMug(self, ind, pose): 61 | load_dir = '../dataGeneration_vF/data/object' 62 | if self.sort: 63 | filename = sorted(os.listdir(load_dir))[ind] 64 | else: 65 | filename = os.listdir(load_dir)[ind] 66 | 67 | data_obj = h5py.File(path.join(load_dir, filename), mode='r') 68 | mesh_coll_name = path.join('data/meshes_coll', data_obj['filename'][()].decode()) 69 | size = data_obj['size'][()] 
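# the remaining fields of the HDF5 object record give the physical
# properties (mass, centre of mass) that are passed to addMeshFrame below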
70 | mass = data_obj['mass'][()] 71 | com = data_obj['com'][:] 72 | data_obj.close() 73 | 74 | mug = self.RealWorld.addMeshFrame(mesh_coll_name, 75 | 'mug'+str(ind), 76 | mass=mass, 77 | com=com) 78 | mug.setPosition(pose[:3]).setQuaternion(pose[3:]) 79 | 80 | def closeGripper(self, ind, gripper_prefix=""): 81 | self.S.closeGripper(gripper_prefix+"gripper", speed=1., objFrameName="mug"+str(ind)) 82 | while True: 83 | self.stepNone() 84 | if self.S.getGripperIsGrasping(gripper_prefix+"gripper"): 85 | return True 86 | elif self.S.getGripperIsClose(gripper_prefix+"gripper"): 87 | return False 88 | 89 | def openGripper(self, gripper_prefix=""): 90 | self.S.openGripper(gripper_prefix+"gripper", speed=3.) 91 | 92 | def goingBack(self, tau=3.): 93 | self.stepPosition(self.qInit, tau) 94 | 95 | 96 | def showToCameras(self, tau=1.): 97 | gripperTarget = self.RealWorld.frame("gripperCenter").getPosition() 98 | gripperTarget += np.array([0,0,.2]) 99 | komo = self.RealWorld.komo_IK(False) 100 | komo.addObjective([], ry.FS.position, ["gripperCenter"], ry.OT.eq, target=gripperTarget) 101 | komo.addObjective([], ry.FS.scalarProductXZ, ["gripperCenter", "world"], ry.OT.eq) 102 | # komo.addObjective([], ry.FS.scalarProductZZ, ["gripperCenter", "world"], ry.OT.eq) 103 | komo.addObjective([], ry.FS.scalarProductXZ, ["world", "gripperCenter"], ry.OT.eq, target=[-1.]) 104 | komo.optimize() 105 | 106 | qTarget = komo.getJointState_t(0) 107 | self.stepPosition(qTarget, tau) 108 | 109 | def get_q(self): 110 | return self.S.get_q() 111 | 112 | def getMugPosition(self, mug_ind): 113 | return self.RealWorld.frame('mug'+str(mug_ind)).getPosition() 114 | 115 | def distanceToTable(self, mug_ind): 116 | dist = np.inf 117 | for f2 in [f for f in self.RealWorld.getFrameNames() if f[:3]=='mug']: 118 | y = -self.RealWorld.evalFeature(ry.FS.pairCollision_negScalar, ['table', f2])[0] 119 | dist = min(dist, y) 120 | return dist 121 | 122 | def isHung(self, mug_ind): 123 | pos = self.getMugPosition(mug_ind) 124 | dist = self.distanceToTable(mug_ind) 125 | 126 | return pos[2]>0.8 and dist>1e-2 127 | 128 | def executeTrajectory(self, traj, tau=0.1): 129 | for t in range(traj.shape[0]): 130 | self.stepPosition(traj[t], tau) 131 | 132 | def stepNone(self, tau=None, realTime=True): 133 | if tau is None: tau = self.tau 134 | for _ in range(int(tau/self.tau)): 135 | self.S.step([], self.tau, ry.ControlMode.none) 136 | if realTime and self.verbose>0: time.sleep(self.tau) 137 | 138 | def stepPosition(self, target, tau, realTime=True): 139 | delta_x = target-self.S.get_q() 140 | delta_x = np.where(delta_x < np.pi, delta_x, delta_x-2*np.pi) 141 | delta_x = np.where(delta_x > -np.pi, delta_x, delta_x+2*np.pi) 142 | N = (int(tau/self.tau)) 143 | for _ in range(N): 144 | q = self.S.get_q() + delta_x/N 145 | self.S.step(q, self.tau, ry.ControlMode.position) 146 | if realTime and self.verbose>0: time.sleep(self.tau) 147 | 148 | def takePicture(self, inds, draw=True): 149 | out = get_all_images(self.RealWorld, 150 | self.camera, 151 | self.camera_name_list, 152 | ['mug'+str(ind) for ind in inds], 153 | r=0.15, 154 | res=128) 155 | if draw and self.verbose>0: 156 | rgb_list = out[0] 157 | mask_list = out[1] 158 | rgb_focused_list = out[2] 159 | num_views = len(rgb_list) 160 | num_objs = len(rgb_focused_list) 161 | plt.figure(figsize=(15,int(2*(num_objs+1)))) 162 | for i in range(num_views): 163 | plt.subplot(num_objs+1, num_views, i+1) 164 | plt.imshow(rgb_list[i]) 165 | 166 | for j in range(num_objs): 167 | plt.subplot(num_objs+1, num_views, 
num_views*(j+1)+i+1) 168 | plt.imshow(mask_list[j][i]) 169 | 170 | plt.figure(figsize=(15,int(2*num_objs))) 171 | for i in range(num_views): 172 | for j in range(num_objs): 173 | plt.subplot(num_objs, num_views, num_views*j+i+1) 174 | plt.imshow(rgb_focused_list[j][i]) 175 | plt.show() 176 | 177 | return out 178 | 179 | class Configuration: 180 | def __init__(self, worldFileName, exp_name, mug_inds, view=True): 181 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 182 | 183 | state = torch.load('network/'+exp_name+'.pth.tar') 184 | config = state['config'] 185 | 186 | self.obj = Frame() 187 | self.obj.build_backbone(pretrained=True, **config) 188 | self.obj.build_sdf_head(config['SDF_HEAD_HIDDEN']) 189 | self.obj.build_keypoint_head('grasp', config['GRASP_HEAD_HIDDEN'], config['GRIPPER_POINTS']) 190 | self.obj.build_keypoint_head('hang', config['HANG_HEAD_HIDDEN'], config['HOOK_POINTS']) 191 | self.obj.load_state_dict(state['network']) 192 | self.obj.to(self.device).eval() 193 | 194 | self.F_grasp = KeyPoint_Feature(self.obj, 'grasp') 195 | self.F_hang = KeyPoint_Feature(self.obj, 'hang') 196 | 197 | self.graspInitDict = {} 198 | self.hangInitDict = {} 199 | 200 | self.sdf_scale = config['SDF_SCALE'] 201 | 202 | self.C = ry.Config() 203 | self.C.addFile(worldFileName) 204 | self.C.selectJoints([j for j in self.C.getJointNames() if j[-6:] != 'finger']) 205 | self.gripper_colls = ["gripper_coll", 206 | "L_finger_coll1", "L_finger_coll2", 207 | "R_finger_coll1", "R_finger_coll2"] 208 | 209 | for ind in mug_inds: 210 | self.C.addFrame('obj'+str(ind)) 211 | self.C.addFrame('mesh'+str(ind), 'obj'+str(ind)) 212 | 213 | self.view = view 214 | if view: 215 | self.V = ry.ConfigurationViewer() 216 | self.V.setConfiguration(self.C) 217 | 218 | def extractKeyFeatures(self, 219 | rgb_list, 220 | projection_list, 221 | obj_pos_list, 222 | obj_r_list): 223 | 224 | dx = torch.linspace(-1., 1., 10).to(self.device) 225 | grid_x, grid_y, grid_z = torch.meshgrid(dx, dx, dx) 226 | 227 | key_points_list, features_list = [], [] 228 | for rgb, projection, obj_pos, obj_r in zip(rgb_list, projection_list, obj_pos_list, obj_r_list): 229 | rgb_tensor = torch.Tensor(rgb).permute(0,3,1,2).to(self.device)/255. 230 | projection_tensor = torch.Tensor(projection).to(self.device) 231 | 232 | key_points = torch.stack([grid_x.flatten(), 233 | grid_y.flatten(), 234 | grid_z.flatten()], dim=1) 235 | key_points *= obj_r 236 | key_points += torch.Tensor(obj_pos).view(1,3).to(self.device) 237 | 238 | with torch.no_grad(): 239 | features = self.obj.backbone(key_points.unsqueeze(0), 240 | rgb_tensor.unsqueeze(0), 241 | projection_tensor.unsqueeze(0)).mean(dim=1) 242 | 243 | key_points_list.append(key_points.squeeze(0)) 244 | features_list.append(features.squeeze(0)) 245 | 246 | return key_points_list, features_list 247 | 248 | def setJointState(self, q): 249 | self.C.setJointState(q) 250 | 251 | def updatePIFO(self, 252 | inds, 253 | rgb_list, 254 | projection_list, 255 | obj_pos_list, 256 | obj_r_list, 257 | compute_init_guess=True, 258 | target_list=None): 259 | """ 260 | Args 261 | rgb_list, projection_list: (num_obj, num_cam) 262 | obj_pos_list, obj_r_list: (num_obj, ) 263 | target_list: list of dictionary (num_obj, ) {"target_name_list", "key_points_list", "features_list"} 264 | """ 265 | 266 | for i, (ind, rgb, projection, obj_pos, obj_r) in enumerate(zip(inds, rgb_list, projection_list, obj_pos_list, obj_r_list)): 267 | rgb_tensor = torch.Tensor(np.array(rgb)).permute(0,3,1,2).to(self.device)/255. 
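# rgb is a list of (H, W, 3) uint8 views of one object; the line above
# converts it to a (num_views, 3, H, W) float tensor in [0, 1] before it is
# fed to the backbone encoder two lines below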
268 | projection_tensor = torch.Tensor(np.array(projection)).to(self.device) 269 | self.obj.backbone.encode(rgb_tensor.unsqueeze(0), projection_tensor.unsqueeze(0)) 270 | 271 | JIT_Collision_Feature(self.obj, self.sdf_scale).save("jit/sdfNet"+str(ind)+".pt") 272 | JIT_Keypoint_Feature(self.obj, "grasp").save("jit/graspNet"+str(ind)+".pt") 273 | JIT_Keypoint_Feature(self.obj, "hang").save("jit/hangNet"+str(ind)+".pt") 274 | 275 | if target_list is not None: 276 | dx = torch.linspace(-1., 1., 5).to(self.device) 277 | grid_x, grid_y, grid_z = torch.meshgrid(dx, dx, dx) 278 | 279 | scene_points = torch.stack([grid_x.flatten(), 280 | grid_y.flatten(), 281 | grid_z.flatten()], dim=1)*0.7 282 | scene_points *= obj_r 283 | scene_points += torch.Tensor(obj_pos).view(1,3).to(self.device) 284 | 285 | for j, target_name in enumerate(target_list[i]["target_name_list"]): 286 | JIT_ICP_Feature(self.obj, 287 | scene_points, 288 | target_list[i]["key_points_list"][j], 289 | target_list[i]["features_list"][j]).save("jit/ICPNet"+str(ind)+target_name+".pt") 290 | 291 | vertices, faces, normals = self.obj.extract_mesh(center=obj_pos, 292 | scale=obj_r, 293 | delta=0.0, 294 | draw=False) 295 | 296 | self.C.frame('obj'+str(ind)).setPosition(obj_pos) 297 | self.C.frame('mesh'+str(ind)).setMesh(vertices, faces).setRelativePosition(-obj_pos) 298 | 299 | 300 | if compute_init_guess: 301 | N_init = 10 302 | pos_init = obj_pos+np.array([0,0,.2])+np.random.randn(N_init,3)*.05 303 | quat_init = np.tile(np.array([1,0,0,0]), (N_init,1)) 304 | x_init = np.concatenate([pos_init, quat_init], axis=1) 305 | x_init = torch.Tensor(x_init).unsqueeze(0).to(self.device) 306 | self.graspInitDict[str(ind)] = self.getInitPose(self.F_grasp, x_init, 1e3, 1e-3) 307 | 308 | pos_init = obj_pos+np.random.randn(N_init,3)*.1 309 | quat_init = np.random.randn(N_init,4) 310 | quat_init /= np.linalg.norm(quat_init, axis=1, keepdims=True) 311 | x_init = np.concatenate([pos_init, quat_init], axis=1) 312 | x_init = torch.Tensor(x_init).unsqueeze(0).to(self.device) 313 | self.hangInitDict[str(ind)] = self.getInitPose(self.F_hang, x_init, 1e2, 1e-8) 314 | 315 | else: 316 | self.graspInitDict[str(ind)] = None 317 | self.hangInitDict[str(ind)] = None 318 | 319 | 320 | if self.view: 321 | self.V.recopyMeshes(self.C) 322 | self.V.setConfiguration(self.C) 323 | self.fInit = self.C.getFrameState() 324 | self.qInit = self.C.getJointState() 325 | 326 | def getInitPose(self, F, x_init, w_coll, coll_margin): 327 | x, cost, coll = F.optimize(x_init, max_iter=100) 328 | x, cost, coll = F.optimize(x, w_coll=w_coll, coll_margin=coll_margin, max_iter=100) 329 | best_ind = np.argmin(np.square(cost)+np.square(coll*w_coll), axis=1).flatten() 330 | 331 | return x[:, best_ind].squeeze().cpu().numpy() 332 | 333 | 334 | def solveKOMO(self, action_list, initSymbols=None, stepsPerPhase=10, verbose=3, animate=False): 335 | """ 336 | action: (grasp, gripper_prefix, mug_ind), (hang, hook_prefix, mug_ind), (pose, target_name, mug_ind) 337 | """ 338 | 339 | 340 | komo = self.C.komo(len(action_list), stepsPerPhase, 5., 2, False) 341 | komo.verbose(verbose) 342 | komo.animateOptimization(animate) 343 | Sk = [] 344 | if initSymbols is not None: 345 | for s in initSymbols: 346 | if s[0] == "grasp": 347 | Sk.extend([[0., 0.], ry.SY.stable, [s[1]+"gripper", "obj"+str(s[2])]]) 348 | 349 | for k, action in enumerate(action_list): 350 | if action[0] == "grasp": 351 | Sk.extend([[k+1., k+1.], ry.SY.stable, [action[1]+"gripper", "obj"+str(action[2])]]) 352 | elif action[0] == "hang": 353 | 
Sk.extend([[k+1., len(action_list)], ry.SY.stable, [action[1]+"hook", "obj"+str(action[2])]]) 354 | elif action[0] == "pose": 355 | Sk.extend([[k+1., len(action_list)], ry.SY.stable, ["gripper", "obj"+str(action[2])]]) 356 | 357 | if len(Sk)>0: komo.addSkeleton(Sk) 358 | 359 | komo.add_qControlObjective([], 2) 360 | komo.add_qControlObjective([], 1) 361 | # komo.add_qControlObjective([], 0, target=self.qInit) 362 | 363 | for k, action in enumerate(action_list): 364 | komo.addObjective([k+1], ry.FS.qItself, self.C.getJointNames(), ry.OT.eq, [1e1], order=1) 365 | 366 | mug_ind = action[-1] 367 | meshName = "mesh"+str(mug_ind) 368 | objName = "obj"+str(mug_ind) 369 | if action[0] == "grasp": 370 | komo.add_PFAccumulatedCollision([k+0.7, k+1.], 371 | [meshName]+[action[1]+c for c in self.gripper_colls], 372 | "jit/sdfNet"+str(mug_ind)+".pt", 373 | ry.OT.eq, [1e1]) 374 | komo.add_PFKeypointObjective([k+1], 375 | [meshName, action[1]+"gripperCenter"], 376 | "jit/graspNet"+str(mug_ind)+".pt", 377 | ry.OT.eq, [1e-1]) 378 | 379 | # komo.addObjective([k+0.7, k+1.], ry.FS.quaternion, 380 | # [action[1]+"gripperCenter"], ry.OT.eq, 381 | # [1e0], order=1) 382 | komo.addObjective([k+0.7, k+1.], ry.FS.positionRel, 383 | [objName, action[1]+"gripperCenter"], ry.OT.eq, 384 | [1e1], target=[0,0,-1/stepsPerPhase], order=2) 385 | 386 | poseInit = self.graspInitDict[str(mug_ind)] 387 | if poseInit is not None: 388 | komo.addObjective([k+1.], ry.FS.poseRel, 389 | [action[1]+"gripperCenter", meshName], ry.OT.eq, 390 | [1e-1], target=poseInit) 391 | # komo.addObjective([k+1.], ry.FS.positionRel, 392 | # [action[1]+"gripperCenter", meshName], ry.OT.eq, 393 | # [1e-1], target=poseInit[:3]) 394 | # komo.addObjective([k+1.], ry.FS.quaternionRel, 395 | # [action[1]+"gripperCenter", meshName], ry.OT.eq, 396 | # [1e-1], target=poseInit[3:]) 397 | 398 | if len(action[1])>0: 399 | for c1 in self.gripper_colls: 400 | for c2 in self.gripper_colls: 401 | komo.addObjective([k+0.7, k+1.], ry.FS.distance, ['R_'+c1, 'L_'+c2], ry.OT.ineq, [1e1]) 402 | 403 | 404 | elif action[0] == "hang": 405 | komo.add_PFAccumulatedCollision([k+0.7, k+1.], 406 | [meshName, action[1]+"hook_coll"], 407 | "jit/sdfNet"+str(mug_ind)+".pt", 408 | ry.OT.eq, [1e0], margin=0.0) 409 | komo.add_PFKeypointObjective([k+1.], 410 | [meshName, action[1]+"hook"], 411 | "jit/hangNet"+str(mug_ind)+".pt", 412 | ry.OT.eq, [1e0]) 413 | komo.addObjective([k+0.7, k+1.], ry.FS.positionRel, 414 | [objName, action[1]+"hook"], ry.OT.eq, 415 | [1e1], target=[0,0,1/stepsPerPhase], order=2) 416 | 417 | poseInit = self.hangInitDict[str(mug_ind)] 418 | if poseInit is not None: 419 | komo.addObjective([k+1.], ry.FS.poseRel, 420 | [action[1]+"hook", meshName], ry.OT.eq, 421 | [1e-1], target=poseInit) 422 | # komo.addObjective([k+1.], ry.FS.positionRel, 423 | # [action[1]+"hook", meshName], ry.OT.eq, 424 | # [1e-1], target=poseInit[:3]) 425 | # komo.addObjective([k+1.], ry.FS.quaternionRel, 426 | # [action[1]+"hook", meshName], ry.OT.eq, 427 | # [1e-1], target=poseInit[3:]) 428 | 429 | 430 | elif action[0] == "pose": 431 | komo.add_PFICPObjective([k+1], meshName, 432 | "jit/ICPNet"+str(mug_ind)+action[1]+".pt", 433 | ry.OT.eq, [1e-1]) 434 | 435 | 436 | komo.optimize(0.1) 437 | 438 | traj = np.zeros((len(action_list)*stepsPerPhase, self.C.getJointDimension())) 439 | for t in range(traj.shape[0]): 440 | self.C.setFrameState(komo.getConfiguration(t)) 441 | traj[t] = self.C.getJointState() 442 | self.C.setFrameState(self.fInit) 443 | return traj, komo 444 | 445 | 446 | def 
get_all_images(C, camera, camera_name_list, obj_name_list, r=0.15, res=128): 447 | rgb_list, mask_list, T_list, K_list = get_images(C, camera, camera_name_list, obj_name_list) 448 | rgb_focused_list, projection_list, obj_pos_list, obj_r_list = get_focused_images(rgb_list, mask_list, T_list, K_list, r, res) 449 | 450 | 451 | return rgb_list, mask_list, rgb_focused_list, projection_list, obj_pos_list, obj_r_list 452 | 453 | 454 | def get_images(C, camera, camera_name_list, obj_name_list): 455 | """ 456 | Take pictures from cameras 457 | Return 458 | rgb_list: (num_cam, ) 459 | mask_list, T_list, K_list: (num_obj, num_cam) 460 | """ 461 | rgb_list, T_list, K_list = [], [], [] 462 | mask_list = [[] for _ in obj_name_list] 463 | for camera_name in camera_name_list: 464 | camera.selectSensor(camera_name) 465 | camera.updateConfig(C) 466 | T, K = camera.getCameraMatrices() 467 | T_list.append(T) 468 | K_list.append(K) 469 | rgb_list.append(camera.computeImageAndDepth()[0]) 470 | for i, obj_name in enumerate(obj_name_list): 471 | mask_list[i].append(camera.extractMask(obj_name)) 472 | 473 | return rgb_list, mask_list, T_list, K_list 474 | 475 | 476 | def get_focused_images(rgb_list, mask_list, T_list, K_list, r=None, res=128): 477 | """ 478 | Multiview processing 479 | Args 480 | rgb_list: (num_cam) 481 | mask_list, T_list, K_list: (num_obj, num_cam) 482 | 483 | Return 484 | rgb_focused_list, projection_list: (num_obj, num_cam) 485 | obj_pos_list, obj_r_list: (num_obj, ) 486 | """ 487 | 488 | rgb_focused_list, projection_list, obj_pos_list, obj_r_list = [], [], [], [] 489 | for obj_mask_list in mask_list: 490 | obj_pos, obj_r = find_obj_pos_size(obj_mask_list, T_list, K_list) 491 | if r is None: r = obj_r 492 | 493 | obj_rgb_focused_list = [] 494 | obj_projection_list = [] 495 | for rgb, mask, T1, K1_rai in zip(rgb_list, obj_mask_list, T_list, K_list): 496 | H, T2, K2 = get_homography_matrix(T1, K1_rai, obj_pos, r) 497 | Hinv = np.linalg.inv(H) 498 | rgb_masked = rgb*np.expand_dims(mask,axis=2) 499 | rgb_focused = warp_with_homography(rgb_masked, Hinv, res) 500 | 501 | obj_rgb_focused_list.append(rgb_focused) 502 | obj_projection_list.append(K2@np.linalg.inv(T2)) 503 | 504 | rgb_focused_list.append(obj_rgb_focused_list) 505 | projection_list.append(obj_projection_list) 506 | obj_pos_list.append(obj_pos) 507 | obj_r_list.append(obj_r) 508 | 509 | return rgb_focused_list, projection_list, obj_pos_list, obj_r_list 510 | 511 | def get_homography_matrix(T1, K1_tmp, obj_pos, obj_r): 512 | f1x = K1_tmp[0,0] 513 | f1y = K1_tmp[1,1] 514 | K1 = np.diag([f1x, -f1y, -1]) 515 | 516 | cam_pos = T1[:3,3] 517 | T2 = get_camera_transform(cam_pos, obj_pos) 518 | cam_distance = np.linalg.norm(cam_pos - obj_pos) 519 | K2_full = get_camera_projection_matrix(cam_distance, obj_r) 520 | f2x = K2_full[0,0] 521 | f2y = K2_full[1,1] 522 | K2 = np.diag([f2x, f2y, -1]) 523 | 524 | R_2_1 = T2[:3,:3].T @ (T1[:3,:3]) 525 | H = K2 @ R_2_1 @ np.linalg.inv(K1) 526 | 527 | return H, T2, K2_full 528 | 529 | 530 | def mask_in_sphere(x, mask_list, T_list, K_list): 531 | dists = [] 532 | for mask, T1, K1_rai in zip(mask_list, T_list, K_list): 533 | # get uv coordinate of mask 534 | H, W = mask.shape 535 | i, j = np.where(mask) 536 | 537 | u = (2*j+1)/W - 1. 538 | v = (2*i+1)/H - 1. 
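# (u, v) are the mask pixel centres in normalized image coordinates,
# i.e. (2*index + 1)/size - 1, which maps the image onto the square [-1, 1]^2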
539 | 540 | # get homography matrix 541 | obj_pos = x[:3] 542 | obj_r = x[3] 543 | H = get_homography_matrix(T1, K1_rai, obj_pos, obj_r)[0] 544 | 545 | # compute transformed uv coordinate 546 | uv1_ = np.stack([u,v,np.ones_like(u)], axis=0) 547 | uv2_ = H@uv1_ 548 | uv2 = uv2_[:2]/uv2_[2:3] 549 | 550 | uv_distance = 1.-np.linalg.norm(uv2, axis=0) 551 | 552 | dists.append(uv_distance) 553 | 554 | return np.hstack(dists) 555 | 556 | 557 | def find_obj_pos_size(mask_list, T_list, K_list): 558 | con1 = { 559 | 'type': 'ineq', 560 | 'fun': lambda x: mask_in_sphere(x, mask_list, T_list, K_list) 561 | } 562 | con2 = { 563 | 'type': 'ineq', 564 | 'fun': lambda x: x[3]-0.01 565 | } 566 | fun = lambda x: x[3]**2 567 | x0 = np.array([0,0,0.9,.1]) 568 | 569 | res = scipy.optimize.minimize(fun, x0, 570 | constraints=[con1, con2], 571 | # options={'disp': True} 572 | ) 573 | 574 | obj_pos = res['x'][:3] 575 | obj_r = res['x'][3] 576 | 577 | return obj_pos, obj_r 578 | 579 | def warp_with_homography(rgb, H, res): 580 | 581 | device = "cuda" if torch.cuda.is_available() else "cpu" 582 | 583 | rgb = torch.Tensor(rgb).permute(2,0,1).unsqueeze(0).to(device) 584 | H = torch.Tensor(H).to(device) 585 | 586 | x = torch.linspace(-1, 1, res).to(device) 587 | grid_v, grid_u = torch.meshgrid(x,x) 588 | base_grid = torch.stack([grid_u, grid_v, torch.ones_like(grid_u)], dim=2) # (res, res, 3) 589 | grid = base_grid.view(-1,3).mm(H.transpose(0,1)) # (res*res, 3) 590 | 591 | grid = grid[...,:2]/grid[...,2:3] 592 | grid = grid.view(1,res,res,2) # (1, res, res, 2) 593 | 594 | rgb_focused = torch.nn.functional.grid_sample(rgb, 595 | grid, 596 | mode='bilinear', 597 | padding_mode='zeros', 598 | align_corners=True) # (1, 3, res, res) 599 | 600 | return rgb_focused.squeeze(0).permute(1,2,0).cpu().numpy().astype(np.uint8) 601 | 602 | -------------------------------------------------------------------------------- /src/training.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from torch.utils.data import DataLoader 7 | 8 | from .dataset import * 9 | from .feature import * 10 | from .utils import * 11 | import h5py 12 | 13 | from .vector_object import * 14 | 15 | 16 | class Trainer: 17 | def __init__(self, obj, config): 18 | self.C = {} 19 | self.C['LEARNING_RATE'] = 1e-4 20 | self.C['NUM_EPOCHS'] = 1000000 21 | 22 | self.C['BATCH_SIZE'] = 39 23 | self.C['NUM_WORKERS'] = 0 24 | self.C['DATA_ON_GPU'] = True 25 | self.C['PIN_MEMORY'] = False 26 | 27 | self.C['PRINT_INTERVAL'] = 50 28 | self.C['TEST_INTERVAL'] = 50 29 | self.C['LOG_INTERVAL'] = 1 30 | self.C['SAVE_INTERVAL'] = 500 31 | 32 | self.C['IMG_RES'] = (128,128) 33 | self.C['NUM_VIEWS'] = 4 34 | 35 | self.C['WEIGHTED_L1'] = False 36 | self.C['DETACH_BACKBONE'] = False 37 | 38 | self.C['GRASP_LOSS_WEIGHT'] = 1. 39 | self.C['HANG_LOSS_WEIGHT'] = 1. 
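# the defaults above are only a base configuration; self.C.update(config)
# below overrides them with the experiment-specific settings.
# A minimal illustrative config (keys inferred from their usage in this
# class; the values here are placeholders, not the authors' settings):
#   config = {
#       'EXP_NAME': 'PIFO_run',
#       'DATA_FILENAME': 'data/train.hdf5',
#       'TEST_DATA_FILENAME': 'data/test.hdf5',
#       'NUM_POINTS': ..., 'NUM_GRASPS': ..., 'NUM_HANGS': ...,
#       'GRASP_DRAW_POINTS': ..., 'HANG_DRAW_POINTS': ...,
#       'SDF_SCALE': ...,
#       'GRASP_COST_SCALE': torch.tensor(...), 'HANG_COST_SCALE': torch.tensor(...),
#   }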
40 | 41 | self.C.update(config) 42 | 43 | self.device = torch.device("cuda" if self.C['DATA_ON_GPU'] else "cpu") 44 | 45 | 46 | self.trainset = PIFODataset(self.C['DATA_FILENAME'], 47 | num_views=self.C['NUM_VIEWS'], 48 | num_points=self.C['NUM_POINTS'], 49 | num_grasps=self.C['NUM_GRASPS'], 50 | num_hangs=self.C['NUM_HANGS'], 51 | grasp_draw_points=self.C['GRASP_DRAW_POINTS'], 52 | hang_draw_points=self.C['HANG_DRAW_POINTS'], 53 | on_gpu_memory=self.C['DATA_ON_GPU']) 54 | 55 | self.train_loader = DataLoader(self.trainset, 56 | batch_size=self.C['BATCH_SIZE'], 57 | shuffle=True, 58 | num_workers=self.C['NUM_WORKERS'], 59 | pin_memory=self.C['PIN_MEMORY']) 60 | 61 | self.testset = PIFODataset(self.C['TEST_DATA_FILENAME'], 62 | num_views=self.C['NUM_VIEWS'], 63 | num_points=self.C['NUM_POINTS'], 64 | num_grasps=self.C['NUM_GRASPS'], 65 | num_hangs=self.C['NUM_HANGS'], 66 | grasp_draw_points=self.C['GRASP_DRAW_POINTS'], 67 | hang_draw_points=self.C['HANG_DRAW_POINTS'], 68 | on_gpu_memory=self.C['DATA_ON_GPU']) 69 | 70 | self.test_loader = DataLoader(self.testset, 71 | # batch_size=self.C['BATCH_SIZE'], 72 | batch_size=len(self.testset), 73 | shuffle=True, 74 | num_workers=self.C['NUM_WORKERS'], 75 | pin_memory=self.C['PIN_MEMORY']) 76 | 77 | self.warper = RandomImageWarper(img_res=self.C['IMG_RES']) 78 | self.grasp_sampler = PoseSampler(scale=self.C['GRASP_COST_SCALE'].to(self.device)) 79 | self.hang_sampler = PoseSampler(scale=self.C['HANG_COST_SCALE'].to(self.device)) 80 | 81 | 82 | self.obj = obj 83 | self.F_sdf = SDF_Feature(obj) 84 | self.F_grasp = KeyPoint_Feature(obj, 'grasp', self.C['DETACH_BACKBONE']) 85 | self.F_hang = KeyPoint_Feature(obj, 'hang', self.C['DETACH_BACKBONE']) 86 | 87 | self.optimizer = torch.optim.Adam(obj.parameters(), lr=self.C['LEARNING_RATE']) 88 | 89 | if self.C['WEIGHTED_L1']: 90 | self.L1 = torch.nn.L1Loss(reduction='none') 91 | else: 92 | self.L1 = torch.nn.L1Loss() 93 | 94 | 95 | self.train_writer = SummaryWriter('runs/train/'+self.C['EXP_NAME']) 96 | self.test_writer = SummaryWriter('runs/test/'+self.C['EXP_NAME']) 97 | 98 | self.global_iter = 0 99 | 100 | def close(self): 101 | self.train_writer.close() 102 | self.test_writer.close() 103 | 104 | def weighted_loss(self, loss_func, preds, targets, h): 105 | tmp_loss = loss_func(preds, targets) 106 | far_samples = (targets.abs()>h).float() 107 | weight = 1.*far_samples + 10.*(1-far_samples) 108 | 109 | return (tmp_loss*weight).mean() 110 | 111 | def save_state(self, filename): 112 | state = { 113 | 'epoch': self.global_iter, 114 | 'config': self.C, 115 | 'network': self.obj.state_dict(), 116 | } 117 | torch.save(state, filename) 118 | 119 | def to_device(self, data): 120 | for key in data: 121 | if isinstance(data[key], torch.Tensor): 122 | data[key] = data[key].to(self.device) 123 | return data 124 | 125 | def forward_loss(self, data): 126 | data = self.to_device(data) 127 | rgb, projections = self.warper(data['rgb'], 128 | data['cam_extrinsic'], 129 | data['cam_intrinsic']) 130 | 131 | grasp_poses, grasp_costs = self.grasp_sampler(data['grasp_poses'], 132 | data['grasp_poses_all']) 133 | hang_poses, hang_costs = self.hang_sampler(data['hang_poses'], 134 | data['hang_poses_all']) 135 | 136 | self.obj.backbone.encode(rgb, projections) 137 | 138 | loss_dict = {'total_loss': 0} 139 | 140 | sdf_pred = self.F_sdf(data['points']) 141 | sdf_target = self.C['SDF_SCALE']*data['sdf'] 142 | if self.C['WEIGHTED_L1']: 143 | sdf_loss = self.weighted_loss(self.L1, 144 | sdf_pred, 145 | sdf_target, 146 | 
.01*self.C['SDF_SCALE']) 147 | else: 148 | sdf_loss = self.L1(sdf_pred, sdf_target) 149 | loss_dict['total_loss'] += sdf_loss 150 | loss_dict['sdf_loss'] = sdf_loss 151 | 152 | if self.C['GRASP_LOSS_WEIGHT'] > 0.: 153 | grasp_pred = self.F_grasp(grasp_poses).abs() 154 | if self.C['WEIGHTED_L1']: 155 | grasp_loss = self.weighted_loss(self.L1, grasp_pred, grasp_costs, 0.5) 156 | else: 157 | grasp_loss = self.L1(grasp_pred, grasp_costs) 158 | else: 159 | grasp_loss = torch.tensor(0., device=self.device) 160 | 161 | loss_dict['total_loss'] += self.C['GRASP_LOSS_WEIGHT']*grasp_loss 162 | loss_dict['grasp_loss'] = grasp_loss 163 | 164 | if self.C['HANG_LOSS_WEIGHT'] > 0.: 165 | hang_pred = self.F_hang(hang_poses).abs() 166 | if self.C['WEIGHTED_L1']: 167 | hang_loss = self.weighted_loss(self.L1, hang_pred, hang_costs, 0.5) 168 | else: 169 | hang_loss = self.L1(hang_pred, hang_costs) 170 | else: 171 | hang_loss = torch.tensor(0., device=self.device) 172 | 173 | loss_dict['total_loss'] += self.C['HANG_LOSS_WEIGHT']*hang_loss 174 | loss_dict['hang_loss'] = hang_loss 175 | 176 | return loss_dict 177 | 178 | def train(self, epoch): 179 | self.global_iter += 1 180 | self.obj.train() 181 | train_loss_dict = {'total_loss': 0., 'sdf_loss': 0., 'grasp_loss': 0., 'hang_loss': 0.} 182 | for data in self.train_loader: 183 | self.optimizer.zero_grad() 184 | loss_dict = self.forward_loss(data) 185 | loss_dict['total_loss'].backward() 186 | self.optimizer.step() 187 | 188 | w = data['sdf'].shape[0]/len(self.trainset) 189 | for l in train_loss_dict: 190 | train_loss_dict[l] += loss_dict[l].item()*w 191 | 192 | if epoch % self.C['LOG_INTERVAL'] == 0: 193 | for l in train_loss_dict: 194 | self.train_writer.add_scalar(l, train_loss_dict[l], self.global_iter) 195 | 196 | def test(self, epoch): 197 | self.obj.eval() 198 | test_loss_dict = {'total_loss': 0., 'sdf_loss': 0., 'grasp_loss': 0., 'hang_loss': 0.} 199 | with torch.no_grad(): 200 | for data in self.test_loader: 201 | loss_dict = self.forward_loss(data) 202 | w = data['sdf'].shape[0]/len(self.testset) 203 | for l in test_loss_dict: 204 | test_loss_dict[l] += loss_dict[l].item()*w 205 | 206 | for l in test_loss_dict: 207 | self.test_writer.add_scalar(l, test_loss_dict[l], self.global_iter) 208 | 209 | return test_loss_dict['total_loss'] 210 | 211 | 212 | # def get_feasibility(self, Feature, N, data): 213 | # B, device = data['rgb'].shape[0], data['rgb'].device 214 | # rgb, projections = self.warper(data['rgb'], 215 | # data['cam_extrinsic'], 216 | # data['cam_intrinsic']) 217 | # x = torch.cat([ 218 | # 0.2*torch.randn(B, N, 3, device=device), 219 | # random_quaternions(B*N, device=device).view(B,N,4) 220 | # ], dim=2) 221 | # x, cost, coll = Feature.optimize(x, 222 | # rgb, 223 | # projections, 224 | # print_interval=10, 225 | # max_iter=101, 226 | # gamma=1e-4) 227 | # best_ind = torch.Tensor(cost).to(device).argmin(dim=1).view(B,1,1).expand(-1,1,7) 228 | # feasibility = Feature.check_feasibility(torch.gather(x, dim=1, index=best_ind), 229 | # data['filenames'], 230 | # data['masses'].numpy(), 231 | # data['coms'].numpy()) # (B, 1) 232 | # return feasibility.sum() 233 | 234 | 235 | # def evaluate(self, N=20): 236 | # num_feasible_grasp = 0 237 | # num_feasible_hang = 0 238 | # self.train_loader.dataset.random_erase = False 239 | # self.test_loader.dataset.random_erase = False 240 | # for data in self.train_loader: 241 | # num_feasible_grasp += self.get_feasibility(self.F_grasp, N, data) 242 | # num_feasible_hang += self.get_feasibility(self.F_hang, N, data) 
243 | 244 | # print(num_feasible_grasp, num_feasible_hang) 245 | # self.train_writer.add_scalar('eval_grasp', num_feasible_grasp/len(self.train_loader.dataset)*100., self.global_iter) 246 | # self.train_writer.add_scalar('eval_hang', num_feasible_hang/len(self.train_loader.dataset)*100., self.global_iter) 247 | 248 | 249 | # num_feasible_grasp = 0 250 | # num_feasible_hang = 0 251 | # for data in self.test_loader: 252 | # num_feasible_grasp += self.get_feasibility(self.F_grasp, N, data) 253 | # num_feasible_hang += self.get_feasibility(self.F_hang, N, data) 254 | 255 | # print(num_feasible_grasp, num_feasible_hang) 256 | # self.test_writer.add_scalar('eval_grasp', num_feasible_grasp/len(self.test_loader.dataset)*100., self.global_iter) 257 | # self.test_writer.add_scalar('eval_hang', num_feasible_hang/len(self.test_loader.dataset)*100., self.global_iter) 258 | 259 | 260 | # self.train_loader.dataset.random_erase = True 261 | # self.test_loader.dataset.random_erase = True 262 | 263 | def get_optim_results(self, Feature, N, data): 264 | B, device = data['rgb'].shape[0], data['rgb'].device 265 | rgb, projections = self.warper(data['rgb'], 266 | data['cam_extrinsic'], 267 | data['cam_intrinsic']) 268 | x = torch.cat([ 269 | 0.2*torch.randn(B, N, 3, device=device), 270 | random_quaternions(B*N, device=device).view(B,N,4) 271 | ], dim=2) 272 | x, cost, coll = Feature.optimize(x, 273 | rgb, 274 | projections, 275 | # print_interval=10, 276 | max_iter=101, 277 | gamma=1e-4) 278 | best_ind = torch.Tensor(cost).to(device).argmin(dim=1).view(B,1,1).expand(-1,1,7) 279 | best_x = torch.gather(x, dim=1, index=best_ind).view(-1,7).cpu().numpy() 280 | 281 | return best_x, data['filenames'], data['masses'].numpy(), data['coms'].numpy() 282 | 283 | def save_optims(self, N=20): 284 | self.train_loader.dataset.random_erase = False 285 | self.test_loader.dataset.random_erase = False 286 | 287 | x, filenames, masses, coms = [], [], [], [] 288 | x_h, filenames_h, masses_h, coms_h = [], [], [], [] 289 | for data in self.train_loader: 290 | x_, filenames_, masses_, coms_ = self.get_optim_results(self.F_grasp, N, data) 291 | x.append(x_) 292 | filenames.extend(filenames_) 293 | masses.append(masses_) 294 | coms.append(coms_) 295 | 296 | x_, filenames_, masses_, coms_ = self.get_optim_results(self.F_hang, N, data) 297 | x_h.append(x_) 298 | filenames_h.extend(filenames_) 299 | masses_h.append(masses_) 300 | coms_h.append(coms_) 301 | 302 | x = np.concatenate(x, axis=0) 303 | masses = np.concatenate(masses, axis=0) 304 | coms = np.concatenate(coms, axis=0) 305 | 306 | optim_data = h5py.File('evals/'+self.C['EXP_NAME']+'/train_grasp_'+str(self.global_iter)+'.hdf5', mode='w') 307 | dt = h5py.special_dtype(vlen=str) 308 | optim_data.create_dataset("mesh_filename", data=np.array(filenames, dtype=dt)) 309 | optim_data.create_dataset("best_x", data=x) 310 | optim_data.create_dataset("mass", data=masses) 311 | optim_data.create_dataset("com", data=coms) 312 | optim_data.close() 313 | 314 | 315 | x_h = np.concatenate(x_h, axis=0) 316 | masses_h = np.concatenate(masses_h, axis=0) 317 | coms_h = np.concatenate(coms_h, axis=0) 318 | 319 | optim_data = h5py.File('evals/'+self.C['EXP_NAME']+'/train_hang_'+str(self.global_iter)+'.hdf5', mode='w') 320 | dt = h5py.special_dtype(vlen=str) 321 | optim_data.create_dataset("mesh_filename", data=np.array(filenames_h, dtype=dt)) 322 | optim_data.create_dataset("best_x", data=x_h) 323 | optim_data.create_dataset("mass", data=masses_h) 324 | optim_data.create_dataset("com", data=coms_h) 325 
| optim_data.close() 326 | 327 | x, filenames, masses, coms = [], [], [], [] 328 | x_h, filenames_h, masses_h, coms_h = [], [], [], [] 329 | for data in self.test_loader: 330 | x_, filenames_, masses_, coms_ = self.get_optim_results(self.F_grasp, N, data) 331 | x.append(x_) 332 | filenames.extend(filenames_) 333 | masses.append(masses_) 334 | coms.append(coms_) 335 | 336 | x_, filenames_, masses_, coms_ = self.get_optim_results(self.F_hang, N, data) 337 | x_h.append(x_) 338 | filenames_h.extend(filenames_) 339 | masses_h.append(masses_) 340 | coms_h.append(coms_) 341 | 342 | x = np.concatenate(x, axis=0) 343 | masses = np.concatenate(masses, axis=0) 344 | coms = np.concatenate(coms, axis=0) 345 | 346 | optim_data = h5py.File('evals/'+self.C['EXP_NAME']+'/test_grasp_'+str(self.global_iter)+'.hdf5', mode='w') 347 | dt = h5py.special_dtype(vlen=str) 348 | optim_data.create_dataset("mesh_filename", data=np.array(filenames, dtype=dt)) 349 | optim_data.create_dataset("best_x", data=x) 350 | optim_data.create_dataset("mass", data=masses) 351 | optim_data.create_dataset("com", data=coms) 352 | optim_data.close() 353 | 354 | 355 | x_h = np.concatenate(x_h, axis=0) 356 | masses_h = np.concatenate(masses_h, axis=0) 357 | coms_h = np.concatenate(coms_h, axis=0) 358 | 359 | optim_data = h5py.File('evals/'+self.C['EXP_NAME']+'/test_hang_'+str(self.global_iter)+'.hdf5', mode='w') 360 | dt = h5py.special_dtype(vlen=str) 361 | optim_data.create_dataset("mesh_filename", data=np.array(filenames_h, dtype=dt)) 362 | optim_data.create_dataset("best_x", data=x_h) 363 | optim_data.create_dataset("mass", data=masses_h) 364 | optim_data.create_dataset("com", data=coms_h) 365 | optim_data.close() 366 | 367 | self.train_loader.dataset.random_erase = True 368 | self.test_loader.dataset.random_erase = True 369 | 370 | 371 | 372 | 373 | def perturb(new_origin, cam_pos, points, grasp_poses=None, hang_poses=None): 374 | """ 375 | Args: 376 | new_origin: (B, 1, 3) 377 | cam_pos: (B, num_views, 3) 378 | points: (B, num_points, 3) 379 | grasp_poses: (B, num_poses, 7) 380 | hang_poses: (B, num_poses, 7) 381 | 382 | Return: 383 | """ 384 | 385 | new_quat_inv = random_quaternions(new_origin.shape[0], device = new_origin.device).unsqueeze(1) 386 | # (B, 1, 4) 387 | 388 | new_cam_pos = quaternion_apply(new_quat_inv, cam_pos-new_origin) 389 | 390 | new_points = quaternion_apply(new_quat_inv, points-new_origin) 391 | 392 | grasp_pos, grasp_quat = torch.split(grasp_poses, [3,4], dim=2) 393 | new_grasp_pos = quaternion_apply(new_quat_inv, grasp_pos-new_origin) 394 | new_grasp_quat = quaternion_multiply(new_quat_inv, grasp_quat) 395 | new_grasp_pose = torch.cat([new_grasp_pos, new_grasp_quat], dim=2) 396 | 397 | hang_pos, hang_quat = torch.split(hang_poses, [3,4], dim=2) 398 | new_hang_pos = quaternion_apply(new_quat_inv, hang_pos-new_origin) 399 | new_hang_quat = quaternion_multiply(new_quat_inv, hang_quat) 400 | new_hang_pose = torch.cat([new_hang_pos, new_hang_quat], dim=2) 401 | 402 | return new_cam_pos, new_points, new_grasp_pose, new_hang_pose 403 | 404 | 405 | class Trainer_vec(Trainer): 406 | def __init__(self, obj, config): 407 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 408 | self.C = {} 409 | self.C['LEARNING_RATE'] = 1e-4 410 | self.C['NUM_EPOCHS'] = 1000000 411 | 412 | self.C['BATCH_SIZE'] = 50 413 | self.C['NUM_WORKERS'] = 0 414 | self.C['DATA_ON_GPU'] = True 415 | self.C['PIN_MEMORY'] = False 416 | 417 | self.C['PRINT_INTERVAL'] = 50 418 | self.C['TEST_INTERVAL'] = 50 419 | 
self.C['LOG_INTERVAL'] = 1 420 | self.C['SAVE_INTERVAL'] = 500 421 | 422 | self.C['IMG_RES'] = (128,128) 423 | self.C['NUM_VIEWS'] = 4 424 | 425 | self.C['GRASP_LOSS_WEIGHT'] = 1. 426 | self.C['HANG_LOSS_WEIGHT'] = 1. 427 | 428 | self.C.update(config) 429 | 430 | 431 | 432 | self.trainset = PIFODataset(self.C['DATA_FILENAME'], 433 | num_views=self.C['NUM_VIEWS'], 434 | num_points=self.C['NUM_POINTS'], 435 | num_grasps=self.C['NUM_GRASPS'], 436 | num_hangs=self.C['NUM_HANGS'], 437 | grasp_draw_points=self.C['GRASP_DRAW_POINTS'], 438 | hang_draw_points=self.C['HANG_DRAW_POINTS'], 439 | on_gpu_memory=self.C['DATA_ON_GPU']) 440 | 441 | self.train_loader = DataLoader(self.trainset, 442 | batch_size=self.C['BATCH_SIZE'], 443 | shuffle=True, 444 | num_workers=self.C['NUM_WORKERS'], 445 | pin_memory=self.C['PIN_MEMORY']) 446 | 447 | self.testset = PIFODataset(self.C['TEST_DATA_FILENAME'], 448 | num_views=self.C['NUM_VIEWS'], 449 | num_points=self.C['NUM_POINTS'], 450 | num_grasps=self.C['NUM_GRASPS'], 451 | num_hangs=self.C['NUM_HANGS'], 452 | grasp_draw_points=self.C['GRASP_DRAW_POINTS'], 453 | hang_draw_points=self.C['HANG_DRAW_POINTS'], 454 | on_gpu_memory=self.C['DATA_ON_GPU']) 455 | 456 | self.test_loader = DataLoader(self.testset, 457 | batch_size=self.C['BATCH_SIZE'], 458 | shuffle=True, 459 | num_workers=self.C['NUM_WORKERS'], 460 | pin_memory=self.C['PIN_MEMORY']) 461 | 462 | self.warper = RandomImageWarper(img_res=self.C['IMG_RES'], return_cam_params=True) 463 | self.grasp_sampler = PoseSampler(scale=self.C['GRASP_COST_SCALE'].to(self.device)) 464 | self.hang_sampler = PoseSampler(scale=self.C['HANG_COST_SCALE'].to(self.device)) 465 | 466 | 467 | self.obj = obj 468 | self.F_sdf = SDF_Feature_vec(obj) 469 | self.F_grasp = Pose_Feature_vec(obj, 'grasp') 470 | self.F_hang = Pose_Feature_vec(obj, 'hang') 471 | 472 | self.optimizer = torch.optim.Adam(obj.parameters(), lr=self.C['LEARNING_RATE']) 473 | 474 | self.L1 = torch.nn.L1Loss(reduction='none') 475 | 476 | 477 | self.train_writer = SummaryWriter('runs/train/'+self.C['EXP_NAME']) 478 | self.test_writer = SummaryWriter('runs/test/'+self.C['EXP_NAME']) 479 | 480 | self.global_iter = 0 481 | 482 | def forward_loss(self, data): 483 | data = self.to_device(data) 484 | rgb, projections, cam_pos, new_origin, cam_roll = self.warper( 485 | data['rgb'], data['cam_extrinsic'], data['cam_intrinsic'] 486 | ) 487 | 488 | grasp_poses, grasp_costs = self.grasp_sampler(data['grasp_poses'], 489 | data['grasp_poses_all']) 490 | hang_poses, hang_costs = self.hang_sampler(data['hang_poses'], 491 | data['hang_poses_all']) 492 | 493 | cam_pos, points, grasp_poses, hang_poses = perturb(new_origin, cam_pos, data['points'], grasp_poses, hang_poses) 494 | self.obj.backbone.encode(rgb, torch.cat([cam_pos, cam_roll], dim=2)) 495 | 496 | loss_dict = {'total_loss': 0} 497 | 498 | sdf_pred = self.F_sdf(points) 499 | sdf_target = self.C['SDF_SCALE']*data['sdf'] 500 | sdf_loss = self.weighted_loss(self.L1, 501 | sdf_pred, 502 | sdf_target, 503 | .05*self.C['SDF_SCALE']) 504 | loss_dict['total_loss'] += sdf_loss 505 | loss_dict['sdf_loss'] = sdf_loss 506 | 507 | if self.C['GRASP_LOSS_WEIGHT'] > 0.: 508 | grasp_pred = self.F_grasp(grasp_poses).abs() 509 | grasp_loss = self.weighted_loss(self.L1, grasp_pred, grasp_costs, 1) 510 | else: 511 | grasp_loss = torch.tensor(0., device=self.device) 512 | 513 | loss_dict['total_loss'] += self.C['GRASP_LOSS_WEIGHT']*grasp_loss 514 | loss_dict['grasp_loss'] = grasp_loss 515 | 516 | if self.C['HANG_LOSS_WEIGHT'] > 0.: 517 | hang_pred = 
self.F_hang(hang_poses).abs() 518 | hang_loss = self.weighted_loss(self.L1, hang_pred, hang_costs, 1) 519 | else: 520 | hang_loss = torch.tensor(0., device=self.device) 521 | 522 | loss_dict['total_loss'] += self.C['HANG_LOSS_WEIGHT']*hang_loss 523 | loss_dict['hang_loss'] = hang_loss 524 | 525 | return loss_dict 526 | 527 | 528 | 529 | class Trainer_notShared(Trainer): 530 | def __init__(self, obj_list, config): 531 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 532 | self.C = {} 533 | self.C['LEARNING_RATE'] = 1e-4 534 | self.C['NUM_EPOCHS'] = 1000000 535 | 536 | self.C['BATCH_SIZE'] = 50 537 | self.C['NUM_WORKERS'] = 0 538 | self.C['DATA_ON_GPU'] = True 539 | self.C['PIN_MEMORY'] = False 540 | 541 | self.C['PRINT_INTERVAL'] = 50 542 | self.C['TEST_INTERVAL'] = 50 543 | self.C['LOG_INTERVAL'] = 1 544 | self.C['SAVE_INTERVAL'] = 500 545 | 546 | self.C['IMG_RES'] = (128,128) 547 | self.C['NUM_VIEWS'] = 4 548 | 549 | self.C['GRASP_LOSS_WEIGHT'] = 1. 550 | self.C['HANG_LOSS_WEIGHT'] = 1. 551 | 552 | self.C.update(config) 553 | 554 | 555 | 556 | self.trainset = PIFODataset(self.C['DATA_FILENAME'], 557 | num_views=self.C['NUM_VIEWS'], 558 | num_points=self.C['NUM_POINTS'], 559 | num_grasps=self.C['NUM_GRASPS'], 560 | num_hangs=self.C['NUM_HANGS'], 561 | grasp_draw_points=self.C['GRASP_DRAW_POINTS'], 562 | hang_draw_points=self.C['HANG_DRAW_POINTS'], 563 | on_gpu_memory=self.C['DATA_ON_GPU']) 564 | 565 | self.train_loader = DataLoader(self.trainset, 566 | batch_size=self.C['BATCH_SIZE'], 567 | shuffle=True, 568 | num_workers=self.C['NUM_WORKERS'], 569 | pin_memory=self.C['PIN_MEMORY']) 570 | 571 | self.testset = PIFODataset(self.C['TEST_DATA_FILENAME'], 572 | num_views=self.C['NUM_VIEWS'], 573 | num_points=self.C['NUM_POINTS'], 574 | num_grasps=self.C['NUM_GRASPS'], 575 | num_hangs=self.C['NUM_HANGS'], 576 | grasp_draw_points=self.C['GRASP_DRAW_POINTS'], 577 | hang_draw_points=self.C['HANG_DRAW_POINTS'], 578 | on_gpu_memory=self.C['DATA_ON_GPU']) 579 | 580 | self.test_loader = DataLoader(self.testset, 581 | batch_size=self.C['BATCH_SIZE'], 582 | shuffle=True, 583 | num_workers=self.C['NUM_WORKERS'], 584 | pin_memory=self.C['PIN_MEMORY']) 585 | 586 | self.warper = RandomImageWarper(img_res=self.C['IMG_RES']) 587 | self.grasp_sampler = PoseSampler(scale=self.C['GRASP_COST_SCALE'].to(self.device)) 588 | self.hang_sampler = PoseSampler(scale=self.C['HANG_COST_SCALE'].to(self.device)) 589 | 590 | self.obj_list = obj_list 591 | self.F_sdf = SDF_Feature(obj_list[0]) 592 | self.F_grasp = KeyPoint_Feature(obj_list[1], 'grasp') 593 | self.F_hang = KeyPoint_Feature(obj_list[2], 'hang') 594 | 595 | params = [] 596 | for obj in obj_list: 597 | params += list(obj.parameters()) 598 | self.optimizer = torch.optim.Adam(params, lr=self.C['LEARNING_RATE']) 599 | 600 | self.L1 = torch.nn.L1Loss(reduction='none') 601 | 602 | 603 | self.train_writer = SummaryWriter('runs/train/'+self.C['EXP_NAME']) 604 | self.test_writer = SummaryWriter('runs/test/'+self.C['EXP_NAME']) 605 | 606 | self.global_iter = 0 607 | 608 | def forward_loss(self, data): 609 | data = self.to_device(data) 610 | rgb, projections = self.warper(data['rgb'], 611 | data['cam_extrinsic'], 612 | data['cam_intrinsic']) 613 | 614 | grasp_poses, grasp_costs = self.grasp_sampler(data['grasp_poses'], 615 | data['grasp_poses_all']) 616 | hang_poses, hang_costs = self.hang_sampler(data['hang_poses'], 617 | data['hang_poses_all']) 618 | 619 | loss_dict = {'total_loss': 0} 620 | 621 | sdf_pred = self.F_sdf(data['points'], rgb, 
projections) 622 | sdf_target = self.C['SDF_SCALE']*data['sdf'] 623 | sdf_loss = self.weighted_loss(self.L1, 624 | sdf_pred, 625 | sdf_target, 626 | .05*self.C['SDF_SCALE']) 627 | loss_dict['total_loss'] += sdf_loss 628 | loss_dict['sdf_loss'] = sdf_loss 629 | 630 | if self.C['GRASP_LOSS_WEIGHT'] > 0.: 631 | grasp_pred = self.F_grasp(grasp_poses, rgb, projections).abs() 632 | grasp_loss = self.weighted_loss(self.L1, grasp_pred, grasp_costs, 1) 633 | else: 634 | grasp_loss = torch.tensor(0., device=self.device) 635 | 636 | loss_dict['total_loss'] += self.C['GRASP_LOSS_WEIGHT']*grasp_loss 637 | loss_dict['grasp_loss'] = grasp_loss 638 | 639 | if self.C['HANG_LOSS_WEIGHT'] > 0.: 640 | hang_pred = self.F_hang(hang_poses, rgb, projections).abs() 641 | hang_loss = self.weighted_loss(self.L1, hang_pred, hang_costs, 1) 642 | else: 643 | hang_loss = torch.tensor(0., device=self.device) 644 | 645 | loss_dict['total_loss'] += self.C['HANG_LOSS_WEIGHT']*hang_loss 646 | loss_dict['hang_loss'] = hang_loss 647 | 648 | return loss_dict 649 | 650 | def train(self, epoch): 651 | self.global_iter += 1 652 | for obj in self.obj_list: 653 | obj.train() 654 | train_loss_dict = {'total_loss': 0., 'sdf_loss': 0., 'grasp_loss': 0., 'hang_loss': 0.} 655 | for data in self.train_loader: 656 | self.optimizer.zero_grad() 657 | loss_dict = self.forward_loss(data) 658 | loss_dict['total_loss'].backward() 659 | self.optimizer.step() 660 | 661 | w = data['sdf'].shape[0]/len(self.trainset) 662 | for l in train_loss_dict: 663 | train_loss_dict[l] += loss_dict[l].item()*w 664 | 665 | if epoch % self.C['LOG_INTERVAL'] == 0: 666 | for l in train_loss_dict: 667 | self.train_writer.add_scalar(l, train_loss_dict[l], self.global_iter) 668 | 669 | def test(self, epoch): 670 | for obj in self.obj_list: 671 | obj.eval() 672 | test_loss_dict = {'total_loss': 0., 'sdf_loss': 0., 'grasp_loss': 0., 'hang_loss': 0.} 673 | with torch.no_grad(): 674 | for data in self.test_loader: 675 | loss_dict = self.forward_loss(data) 676 | w = data['sdf'].shape[0]/len(self.testset) 677 | for l in test_loss_dict: 678 | test_loss_dict[l] += loss_dict[l].item()*w 679 | 680 | for l in test_loss_dict: 681 | self.test_writer.add_scalar(l, test_loss_dict[l], self.global_iter) 682 | 683 | return test_loss_dict['total_loss'] 684 | 685 | 686 | def save_state(self, filename): 687 | state = { 688 | 'epoch': self.global_iter, 689 | 'config': self.C, 690 | 'network0': self.obj_list[0].state_dict(), 691 | 'network1': self.obj_list[1].state_dict(), 692 | 'network2': self.obj_list[2].state_dict(), 693 | } 694 | torch.save(state, filename) -------------------------------------------------------------------------------- /notebooks/evaluation_grasp.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "74ce6375-87eb-49d4-ba23-5da6f828b7e1", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# %matplotlib widget\n", 11 | "import os\n", 12 | "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n", 13 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"4\"\n", 14 | "os.environ['PYOPENGL_PLATFORM'] = 'egl'\n", 15 | "\n", 16 | "from src.vector_object import *\n", 17 | "\n", 18 | "from src.frame import Frame\n", 19 | "\n", 20 | "from src.feature import SDF_Feature\n", 21 | "from src.dataset import *\n", 22 | "from src.utils import *\n", 23 | "from src.data_gen_utils import *\n", 24 | "\n", 25 | "from os import path\n", 26 | "from tqdm.notebook import 
tqdm\n", 27 | "import matplotlib.pyplot as plt\n", 28 | "\n", 29 | "import trimesh" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "id": "362d9a3f-a6fd-4573-bf38-fe0ca96d8bd4", 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", 40 | "exp_name = 'PIFO_best'\n", 41 | "state = torch.load('network/'+exp_name+'.pth.tar')\n", 42 | "C = state['config']\n", 43 | "\n", 44 | "trainset = PIFODataset(C['DATA_FILENAME'],\n", 45 | " num_views=C['NUM_VIEWS'],\n", 46 | " num_points=C['NUM_POINTS'],\n", 47 | " num_grasps=C['NUM_GRASPS'],\n", 48 | " num_hangs=C['NUM_HANGS'],\n", 49 | " grasp_draw_points=C['GRASP_DRAW_POINTS'],\n", 50 | " hang_draw_points=C['HANG_DRAW_POINTS'],\n", 51 | " random_erase=False,\n", 52 | " on_gpu_memory=True)\n", 53 | "\n", 54 | "testset = PIFODataset('data/test_batch.hdf5',\n", 55 | " num_views=C['NUM_VIEWS'],\n", 56 | " num_points=C['NUM_POINTS'],\n", 57 | " num_grasps=C['NUM_GRASPS'],\n", 58 | " num_hangs=C['NUM_HANGS'],\n", 59 | " grasp_draw_points=C['GRASP_DRAW_POINTS'],\n", 60 | " hang_draw_points=C['HANG_DRAW_POINTS'],\n", 61 | " random_erase=False,\n", 62 | " on_gpu_memory=True)\n", 63 | "\n", 64 | "\n", 65 | "warper = RandomImageWarper(img_res=C['IMG_RES'], \n", 66 | " sig_center=0, \n", 67 | " return_cam_params=True)\n", 68 | "\n", 69 | "# PIFO\n", 70 | "obj1 = Frame()\n", 71 | "obj1.build_backbone(pretrained=True, **C)\n", 72 | "obj1.build_sdf_head(C['SDF_HEAD_HIDDEN'])\n", 73 | "obj1.build_keypoint_head('grasp', C['GRASP_HEAD_HIDDEN'], C['GRIPPER_POINTS'])\n", 74 | "obj1.build_keypoint_head('hang', C['HANG_HEAD_HIDDEN'], C['HOOK_POINTS'])\n", 75 | "obj1.load_state_dict(state['network'])\n", 76 | "obj1.to(device).eval()\n", 77 | "F_grasp1 = KeyPoint_Feature(obj1, 'grasp')\n", 78 | "F_hang1 = KeyPoint_Feature(obj1, 'hang')\n", 79 | "\n", 80 | "# noPixel\n", 81 | "exp_name = 'noPixelAligned_best'\n", 82 | "state = torch.load('network/'+exp_name+'.pth.tar')\n", 83 | "C = state['config']\n", 84 | "obj2 = Frame()\n", 85 | "obj2.build_backbone(pretrained=True, **C)\n", 86 | "obj2.build_sdf_head(C['SDF_HEAD_HIDDEN'])\n", 87 | "obj2.build_keypoint_head('grasp', C['GRASP_HEAD_HIDDEN'], C['GRIPPER_POINTS'])\n", 88 | "obj2.build_keypoint_head('hang', C['HANG_HEAD_HIDDEN'], C['HOOK_POINTS'])\n", 89 | "obj2.load_state_dict(state['network'])\n", 90 | "obj2.to(device).eval()\n", 91 | "F_grasp2 = KeyPoint_Feature(obj2, 'grasp')\n", 92 | "F_hang2 = KeyPoint_Feature(obj2, 'hang')\n", 93 | "\n", 94 | "# vecObj\n", 95 | "exp_name = 'vectorObject_best'\n", 96 | "state = torch.load('network/'+exp_name+'.pth.tar')\n", 97 | "C = state['config']\n", 98 | "obj3 = Frame_vec()\n", 99 | "obj3.build_backbone(pretrained=True, **C)\n", 100 | "obj3.build_sdf_head(C['SDF_HEAD_HIDDEN'])\n", 101 | "obj3.build_pose_head('grasp', C['GRASP_HEAD_HIDDEN'])\n", 102 | "obj3.build_pose_head('hang', C['HANG_HEAD_HIDDEN'])\n", 103 | "obj3.load_state_dict(state['network'])\n", 104 | "obj3.to(device).eval()\n", 105 | "\n", 106 | "F_grasp3 = Pose_Feature_vec(obj3, 'grasp')\n", 107 | "F_hang3 = Pose_Feature_vec(obj3, 'hang')" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 3, 113 | "id": "609e45e2-3e80-4ac2-a7a0-fd372d824798", 114 | "metadata": { 115 | "collapsed": true, 116 | "jupyter": { 117 | "outputs_hidden": true 118 | }, 119 | "tags": [] 120 | }, 121 | "outputs": [ 122 | { 123 | "name": "stdout", 124 | "output_type": "stream", 125 | "text": [ 126 | "======================== 
2 ========================\n", 127 | "iter: 0, cost: 6.0169219970703125, coll: 0.0\n", 128 | "iter: 0, cost: 3.974701166152954, coll: 233.3343048095703\n", 129 | "iter: 0, cost: 5.02092981338501, coll: 0.0\n", 130 | "iter: 0, cost: 2.5152018070220947, coll: 188.2131805419922\n", 131 | "iter: 0, cost: 5.2272868156433105, coll: 0.0\n", 132 | "iter: 0, cost: 2.806736946105957, coll: 278.63421630859375\n", 133 | "iter: 0, cost: 5.037403106689453, coll: 0.0\n", 134 | "iter: 0, cost: 2.8562281131744385, coll: 199.72267150878906\n", 135 | "iter: 0, cost: 4.787069797515869, coll: 0.0\n", 136 | "iter: 0, cost: 2.8516359329223633, coll: 217.81399536132812\n", 137 | "iter: 0, cost: 4.920601844787598, coll: 0.0\n", 138 | "iter: 0, cost: 2.92130446434021, coll: 277.5904235839844\n", 139 | "iter: 0, cost: 5.023969650268555, coll: 0.0\n", 140 | "iter: 0, cost: 3.83841609954834, coll: 242.57069396972656\n", 141 | "iter: 0, cost: 4.968585014343262, coll: 0.0\n", 142 | "iter: 0, cost: 2.5577192306518555, coll: 204.83006286621094\n", 143 | "iter: 0, cost: 4.99035120010376, coll: 0.0\n", 144 | "iter: 0, cost: 2.4886393547058105, coll: 282.6738586425781\n", 145 | "iter: 0, cost: 5.9312920570373535, coll: 0.0\n", 146 | "iter: 0, cost: 4.017564296722412, coll: 255.24298095703125\n", 147 | "iter: 0, cost: 5.5413312911987305, coll: 0.0\n", 148 | "iter: 0, cost: 2.598666191101074, coll: 226.20899963378906\n", 149 | "iter: 0, cost: 5.262668132781982, coll: 0.0\n", 150 | "iter: 0, cost: 2.7295918464660645, coll: 283.5034484863281\n", 151 | "iter: 0, cost: 5.744115352630615, coll: 0.0\n", 152 | "iter: 0, cost: 2.9624688625335693, coll: 235.9788055419922\n", 153 | "iter: 0, cost: 5.158945560455322, coll: 0.0\n", 154 | "iter: 0, cost: 2.7271575927734375, coll: 237.20535278320312\n", 155 | "iter: 0, cost: 5.259897708892822, coll: 0.0\n", 156 | "iter: 0, cost: 2.6747183799743652, coll: 262.427001953125\n", 157 | "iter: 0, cost: 7.499991416931152, coll: 0.0\n", 158 | "iter: 0, cost: 5.040233135223389, coll: 207.8163299560547\n", 159 | "iter: 0, cost: 5.364550590515137, coll: 0.0\n", 160 | "iter: 0, cost: 2.6009504795074463, coll: 216.362548828125\n", 161 | "iter: 0, cost: 4.828066349029541, coll: 0.0\n", 162 | "iter: 0, cost: 2.8663976192474365, coll: 281.2763366699219\n", 163 | "iter: 0, cost: 5.058730125427246, coll: 0.0\n", 164 | "iter: 0, cost: 3.5493361949920654, coll: 219.26673889160156\n", 165 | "iter: 0, cost: 4.969138145446777, coll: 0.0\n", 166 | "iter: 0, cost: 2.518332004547119, coll: 288.5077819824219\n", 167 | "iter: 0, cost: 5.130303382873535, coll: 0.0\n", 168 | "iter: 0, cost: 3.028677225112915, coll: 277.05694580078125\n", 169 | "iter: 0, cost: 4.851702690124512, coll: 0.0\n", 170 | "iter: 0, cost: 4.089438438415527, coll: 220.92559814453125\n", 171 | "iter: 0, cost: 4.686545372009277, coll: 0.0\n", 172 | "iter: 0, cost: 2.848837375640869, coll: 199.71475219726562\n", 173 | "iter: 0, cost: 4.62904167175293, coll: 0.0\n", 174 | "iter: 0, cost: 2.7931790351867676, coll: 289.46856689453125\n", 175 | "iter: 0, cost: 6.296993732452393, coll: 0.0\n", 176 | "iter: 0, cost: 3.31953501701355, coll: 238.0959014892578\n", 177 | "iter: 0, cost: 4.726524353027344, coll: 0.0\n", 178 | "iter: 0, cost: 2.7483277320861816, coll: 192.95188903808594\n", 179 | "iter: 0, cost: 4.901299476623535, coll: 0.0\n", 180 | "iter: 0, cost: 2.9958345890045166, coll: 276.0810852050781\n", 181 | "iter: 0, cost: 5.695014953613281, coll: 0.0\n", 182 | "iter: 0, cost: 2.896209239959717, coll: 249.1409454345703\n", 183 | "iter: 0, cost: 
5.641671180725098, coll: 0.0\n", 184 | "iter: 0, cost: 2.5541064739227295, coll: 212.2735595703125\n", 185 | "iter: 0, cost: 5.580175399780273, coll: 0.0\n", 186 | "iter: 0, cost: 3.1418075561523438, coll: 269.0629577636719\n", 187 | "0.6576923076923077 0.8294871794871795\n", 188 | "0.6756410256410257 0.808974358974359\n", 189 | "0.13205128205128205 0.007692307692307693\n", 190 | "iter: 0, cost: 6.921337604522705, coll: 0.0\n", 191 | "iter: 0, cost: 3.356431722640991, coll: 233.58775329589844\n", 192 | "iter: 0, cost: 4.804630279541016, coll: 0.0\n", 193 | "iter: 0, cost: 2.3112099170684814, coll: 179.9398956298828\n", 194 | "iter: 0, cost: 5.096407413482666, coll: 0.0\n", 195 | "iter: 0, cost: 2.9967024326324463, coll: 274.8147888183594\n", 196 | "iter: 0, cost: 5.5002665519714355, coll: 0.0\n", 197 | "iter: 0, cost: 3.70255708694458, coll: 180.61471557617188\n", 198 | "iter: 0, cost: 4.143067359924316, coll: 0.0\n", 199 | "iter: 0, cost: 2.1916821002960205, coll: 150.0166015625\n", 200 | "iter: 0, cost: 4.190585613250732, coll: 0.0\n", 201 | "iter: 0, cost: 2.8816428184509277, coll: 252.2906036376953\n", 202 | "iter: 0, cost: 5.223201751708984, coll: 0.0\n", 203 | "iter: 0, cost: 2.6801912784576416, coll: 235.3170166015625\n", 204 | "iter: 0, cost: 5.154111862182617, coll: 0.0\n", 205 | "iter: 0, cost: 2.9495022296905518, coll: 175.77511596679688\n", 206 | "iter: 0, cost: 4.972971439361572, coll: 0.0\n", 207 | "iter: 0, cost: 2.97935152053833, coll: 276.0395202636719\n", 208 | "iter: 0, cost: 5.827913284301758, coll: 0.0\n", 209 | "iter: 0, cost: 2.5509707927703857, coll: 215.46536254882812\n", 210 | "iter: 0, cost: 4.594428539276123, coll: 0.0\n", 211 | "iter: 0, cost: 2.458381414413452, coll: 159.38174438476562\n", 212 | "iter: 0, cost: 4.500207424163818, coll: 0.0\n", 213 | "iter: 0, cost: 2.29235577583313, coll: 265.7178649902344\n", 214 | "iter: 0, cost: 4.251039028167725, coll: 0.0\n", 215 | "iter: 0, cost: 2.542681932449341, coll: 208.62742614746094\n", 216 | "iter: 0, cost: 3.9566948413848877, coll: 0.0\n", 217 | "iter: 0, cost: 2.4060375690460205, coll: 210.0897979736328\n", 218 | "iter: 0, cost: 4.460279941558838, coll: 0.0\n", 219 | "iter: 0, cost: 3.1250696182250977, coll: 274.03729248046875\n", 220 | "iter: 0, cost: 5.453641891479492, coll: 0.0\n", 221 | "iter: 0, cost: 2.8635387420654297, coll: 223.63833618164062\n", 222 | "iter: 0, cost: 5.51327657699585, coll: 0.0\n", 223 | "iter: 0, cost: 2.4457695484161377, coll: 187.05482482910156\n", 224 | "iter: 0, cost: 5.475167751312256, coll: 0.0\n", 225 | "iter: 0, cost: 2.514143943786621, coll: 255.46145629882812\n", 226 | "iter: 0, cost: 4.591771602630615, coll: 0.0\n", 227 | "iter: 0, cost: 2.745338201522827, coll: 181.58197021484375\n", 228 | "iter: 0, cost: 4.475475788116455, coll: 0.0\n", 229 | "iter: 0, cost: 2.294517755508423, coll: 203.07382202148438\n", 230 | "iter: 0, cost: 4.307751178741455, coll: 0.0\n", 231 | "iter: 0, cost: 2.6489882469177246, coll: 238.17575073242188\n", 232 | "iter: 0, cost: 5.537105083465576, coll: 0.0\n", 233 | "iter: 0, cost: 2.5215260982513428, coll: 195.8470916748047\n", 234 | "iter: 0, cost: 4.805717468261719, coll: 0.0\n", 235 | "iter: 0, cost: 2.343567132949829, coll: 219.32403564453125\n", 236 | "iter: 0, cost: 4.526413440704346, coll: 0.0\n", 237 | "iter: 0, cost: 3.0172171592712402, coll: 257.03851318359375\n", 238 | "iter: 0, cost: 5.25303840637207, coll: 0.0\n", 239 | "iter: 0, cost: 2.9329965114593506, coll: 202.9410400390625\n", 240 | "iter: 0, cost: 5.10517692565918, coll: 0.0\n", 
241 | "iter: 0, cost: 2.606665849685669, coll: 168.86875915527344\n", 242 | "iter: 0, cost: 5.093562602996826, coll: 0.0\n", 243 | "iter: 0, cost: 2.794419527053833, coll: 279.3061218261719\n", 244 | "iter: 0, cost: 5.089090347290039, coll: 0.0\n", 245 | "iter: 0, cost: 2.43367338180542, coll: 269.2640686035156\n", 246 | "iter: 0, cost: 4.268590927124023, coll: 0.0\n", 247 | "iter: 0, cost: 2.358184814453125, coll: 183.20774841308594\n", 248 | "iter: 0, cost: 4.708502769470215, coll: 0.0\n", 249 | "iter: 0, cost: 2.505507230758667, coll: 216.8807373046875\n", 250 | "0.5535714285714286 0.7714285714285715\n", 251 | "0.6392857142857142 0.7035714285714286\n", 252 | "0.12857142857142856 0.0035714285714285713\n", 253 | "======================== 8 ========================\n", 254 | "iter: 0, cost: 4.469699382781982, coll: 0.0\n", 255 | "iter: 0, cost: 2.6389689445495605, coll: 226.32936096191406\n", 256 | "iter: 0, cost: 4.4651665687561035, coll: 0.0\n", 257 | "iter: 0, cost: 2.7793893814086914, coll: 199.25123596191406\n", 258 | "iter: 0, cost: 4.613893985748291, coll: 0.0\n", 259 | "iter: 0, cost: 2.994060516357422, coll: 268.7557067871094\n", 260 | "iter: 0, cost: 5.410276412963867, coll: 0.0\n", 261 | "iter: 0, cost: 2.898651123046875, coll: 195.3013458251953\n", 262 | "iter: 0, cost: 5.5550336837768555, coll: 0.0\n", 263 | "iter: 0, cost: 2.341693878173828, coll: 276.87017822265625\n", 264 | "iter: 0, cost: 5.252536296844482, coll: 0.0\n", 265 | "iter: 0, cost: 3.2382848262786865, coll: 273.85693359375\n", 266 | "iter: 0, cost: 4.754010200500488, coll: 0.0\n", 267 | "iter: 0, cost: 2.671107769012451, coll: 224.0349884033203\n", 268 | "iter: 0, cost: 4.522400856018066, coll: 0.0\n", 269 | "iter: 0, cost: 2.474883556365967, coll: 207.3258819580078\n", 270 | "iter: 0, cost: 4.69058895111084, coll: 0.0\n", 271 | "iter: 0, cost: 2.769101858139038, coll: 274.862060546875\n", 272 | "iter: 0, cost: 5.0015339851379395, coll: 0.0\n", 273 | "iter: 0, cost: 2.716750144958496, coll: 244.9803009033203\n", 274 | "iter: 0, cost: 5.117935657501221, coll: 0.0\n", 275 | "iter: 0, cost: 2.4479916095733643, coll: 226.3140106201172\n", 276 | "iter: 0, cost: 5.224783420562744, coll: 0.0\n", 277 | "iter: 0, cost: 2.8813579082489014, coll: 280.4502258300781\n", 278 | "iter: 0, cost: 5.073554039001465, coll: 0.0\n", 279 | "iter: 0, cost: 2.6444737911224365, coll: 265.61859130859375\n", 280 | "iter: 0, cost: 5.023900508880615, coll: 0.0\n", 281 | "iter: 0, cost: 2.3993871212005615, coll: 202.03834533691406\n", 282 | "iter: 0, cost: 5.046306610107422, coll: 0.0\n", 283 | "iter: 0, cost: 3.093268632888794, coll: 278.5723876953125\n", 284 | "iter: 0, cost: 5.974239349365234, coll: 0.0\n", 285 | "iter: 0, cost: 2.688624382019043, coll: 232.18124389648438\n", 286 | "iter: 0, cost: 5.943511009216309, coll: 0.0\n", 287 | "iter: 0, cost: 2.3936665058135986, coll: 230.2108154296875\n", 288 | "iter: 0, cost: 6.12375545501709, coll: 0.0\n", 289 | "iter: 0, cost: 2.943547010421753, coll: 272.3543701171875\n", 290 | "iter: 0, cost: 5.526000022888184, coll: 0.0\n", 291 | "iter: 0, cost: 2.8847405910491943, coll: 198.6845245361328\n", 292 | "iter: 0, cost: 5.487680912017822, coll: 0.0\n", 293 | "iter: 0, cost: 3.2671730518341064, coll: 221.1943817138672\n", 294 | "iter: 0, cost: 5.271404266357422, coll: 0.0\n", 295 | "iter: 0, cost: 2.8672759532928467, coll: 277.38104248046875\n", 296 | "iter: 0, cost: 5.4088616371154785, coll: 0.0\n", 297 | "iter: 0, cost: 2.9759442806243896, coll: 225.27723693847656\n", 298 | "iter: 0, cost: 
5.526289463043213, coll: 0.0\n", 299 | "iter: 0, cost: 2.7067198753356934, coll: 223.29933166503906\n", 300 | "iter: 0, cost: 4.9404072761535645, coll: 0.0\n", 301 | "iter: 0, cost: 2.747190237045288, coll: 277.987060546875\n", 302 | "iter: 0, cost: 5.298285484313965, coll: 0.0\n", 303 | "iter: 0, cost: 2.5897316932678223, coll: 194.94854736328125\n", 304 | "iter: 0, cost: 5.112144947052002, coll: 0.0\n", 305 | "iter: 0, cost: 2.663052558898926, coll: 186.8408660888672\n", 306 | "iter: 0, cost: 5.6352057456970215, coll: 0.0\n", 307 | "iter: 0, cost: 3.2570478916168213, coll: 268.6293029785156\n", 308 | "iter: 0, cost: 5.354095458984375, coll: 0.0\n", 309 | "iter: 0, cost: 2.5515127182006836, coll: 228.58885192871094\n", 310 | "iter: 0, cost: 5.62315559387207, coll: 0.0\n", 311 | "iter: 0, cost: 2.3387677669525146, coll: 200.41683959960938\n", 312 | "iter: 0, cost: 5.8819427490234375, coll: 0.0\n", 313 | "iter: 0, cost: 2.8612565994262695, coll: 280.8544616699219\n", 314 | "0.7192307692307692 0.8871794871794871\n", 315 | "0.7128205128205128 0.8397435897435898\n", 316 | "0.28974358974358977 0.005128205128205128\n", 317 | "iter: 0, cost: 4.655050277709961, coll: 0.0\n", 318 | "iter: 0, cost: 2.0962297916412354, coll: 222.926025390625\n", 319 | "iter: 0, cost: 4.838199138641357, coll: 0.0\n", 320 | "iter: 0, cost: 2.2492175102233887, coll: 199.38177490234375\n", 321 | "iter: 0, cost: 5.040234565734863, coll: 0.0\n", 322 | "iter: 0, cost: 3.2888858318328857, coll: 252.36007690429688\n", 323 | "iter: 0, cost: 4.568339824676514, coll: 0.0\n", 324 | "iter: 0, cost: 2.583643674850464, coll: 174.03721618652344\n", 325 | "iter: 0, cost: 4.62748384475708, coll: 0.0\n", 326 | "iter: 0, cost: 2.7131402492523193, coll: 179.92442321777344\n", 327 | "iter: 0, cost: 4.575658798217773, coll: 0.0\n", 328 | "iter: 0, cost: 2.6152474880218506, coll: 280.50738525390625\n", 329 | "iter: 0, cost: 4.725114345550537, coll: 0.0\n", 330 | "iter: 0, cost: 2.3749935626983643, coll: 201.59312438964844\n", 331 | "iter: 0, cost: 4.477592945098877, coll: 0.0\n", 332 | "iter: 0, cost: 2.39646577835083, coll: 175.77798461914062\n", 333 | "iter: 0, cost: 4.860243797302246, coll: 0.0\n", 334 | "iter: 0, cost: 2.6849052906036377, coll: 245.21368408203125\n", 335 | "iter: 0, cost: 5.667869567871094, coll: 0.0\n", 336 | "iter: 0, cost: 2.4528801441192627, coll: 198.05458068847656\n", 337 | "iter: 0, cost: 5.694857120513916, coll: 0.0\n", 338 | "iter: 0, cost: 2.3487985134124756, coll: 221.2257537841797\n", 339 | "iter: 0, cost: 5.859572887420654, coll: 0.0\n", 340 | "iter: 0, cost: 2.529609441757202, coll: 270.6771240234375\n", 341 | "iter: 0, cost: 4.623867988586426, coll: 0.0\n", 342 | "iter: 0, cost: 2.4987170696258545, coll: 157.00521850585938\n", 343 | "iter: 0, cost: 4.446608066558838, coll: 0.0\n", 344 | "iter: 0, cost: 2.5910463333129883, coll: 175.7313995361328\n", 345 | "iter: 0, cost: 4.6360273361206055, coll: 0.0\n", 346 | "iter: 0, cost: 3.147291660308838, coll: 281.3092041015625\n", 347 | "iter: 0, cost: 4.345896244049072, coll: 0.0\n", 348 | "iter: 0, cost: 2.2911813259124756, coll: 201.53880310058594\n", 349 | "iter: 0, cost: 4.436519145965576, coll: 0.0\n", 350 | "iter: 0, cost: 2.3071446418762207, coll: 176.85630798339844\n", 351 | "iter: 0, cost: 4.571264743804932, coll: 0.0\n", 352 | "iter: 0, cost: 2.9744346141815186, coll: 277.62158203125\n", 353 | "iter: 0, cost: 4.64131498336792, coll: 0.0\n", 354 | "iter: 0, cost: 2.49574613571167, coll: 183.47987365722656\n", 355 | "iter: 0, cost: 4.5885009765625, coll: 
0.0\n", 356 | "iter: 0, cost: 2.5473387241363525, coll: 159.8391571044922\n", 357 | "iter: 0, cost: 4.545409679412842, coll: 0.0\n", 358 | "iter: 0, cost: 2.4146811962127686, coll: 257.1953430175781\n", 359 | "iter: 0, cost: 4.410649299621582, coll: 0.0\n", 360 | "iter: 0, cost: 2.6315581798553467, coll: 202.832275390625\n", 361 | "iter: 0, cost: 4.700764179229736, coll: 0.0\n", 362 | "iter: 0, cost: 2.6255152225494385, coll: 195.7467498779297\n", 363 | "iter: 0, cost: 4.630579471588135, coll: 0.0\n", 364 | "iter: 0, cost: 2.792097330093384, coll: 269.8475646972656\n", 365 | "iter: 0, cost: 5.199542999267578, coll: 0.0\n", 366 | "iter: 0, cost: 2.4331114292144775, coll: 211.89283752441406\n", 367 | "iter: 0, cost: 5.076155662536621, coll: 0.0\n", 368 | "iter: 0, cost: 2.6994900703430176, coll: 227.6631317138672\n", 369 | "iter: 0, cost: 5.2984747886657715, coll: 0.0\n", 370 | "iter: 0, cost: 3.532865047454834, coll: 257.8428955078125\n", 371 | "iter: 0, cost: 4.415054798126221, coll: 0.0\n", 372 | "iter: 0, cost: 2.51259183883667, coll: 222.4306640625\n", 373 | "iter: 0, cost: 4.339607238769531, coll: 0.0\n", 374 | "iter: 0, cost: 2.42262601852417, coll: 187.06309509277344\n", 375 | "iter: 0, cost: 4.375247478485107, coll: 0.0\n", 376 | "iter: 0, cost: 2.7221128940582275, coll: 275.1281433105469\n", 377 | "0.6928571428571428 0.85\n", 378 | "0.6714285714285714 0.7928571428571428\n", 379 | "0.2392857142857143 0.007142857142857143\n" 380 | ] 381 | } 382 | ], 383 | "source": [ 384 | "for num_views in [2, 8]:\n", 385 | " print('======================== '+str(num_views)+' ========================')\n", 386 | " for t, dataset in enumerate([trainset, testset]):\n", 387 | " dataset.num_views = num_views\n", 388 | " \n", 389 | " x_PIFO, x_noPixel, x_vecObj = [], [], []\n", 390 | " y_PIFO, y_noPixel, y_vecObj = [], [], []\n", 391 | " B, N, num_best = len(dataset), 10, 1\n", 392 | "\n", 393 | " for it in range(10):\n", 394 | " rgb_list, projections_list, cam_params_list, filename_list, mass_list, com_list = [], [], [], [], [], []\n", 395 | " for i in range(B):\n", 396 | " data = to_device(dataset[i], device)\n", 397 | " rgb, projections, cam_pos, new_origin, cam_roll = warper(data['rgb'].unsqueeze(0), \n", 398 | " data['cam_extrinsic'].unsqueeze(0), \n", 399 | " data['cam_intrinsic'].unsqueeze(0))\n", 400 | " rgb_list.append(rgb)\n", 401 | " projections_list.append(projections)\n", 402 | " cam_params_list.append(torch.cat([cam_pos, cam_roll], dim=2))\n", 403 | " filename_list.append(data['filenames'])\n", 404 | " mass_list.append(data['masses'])\n", 405 | " com_list.append(data['coms'])\n", 406 | "\n", 407 | " x_init = torch.cat([.2*torch.randn(B,N,3, device=device), \n", 408 | " random_quaternions(B*N, device=device).view(B,N,4)], dim=2)\n", 409 | "\n", 410 | " x, cost, coll = F_grasp1.optimize(x_init.clone(),\n", 411 | " torch.cat(rgb_list), \n", 412 | " torch.cat(projections_list))\n", 413 | "\n", 414 | " best_inds = torch.tensor(cost).argsort(dim=1)[:, :num_best].to(device).view(B,num_best,1).expand(-1,-1,7)\n", 415 | " best_poses = torch.gather(x, dim=1, index=best_inds)\n", 416 | "\n", 417 | " x_PIFO.append(best_poses)\n", 418 | "\n", 419 | " x, cost, coll = F_grasp1.optimize(x,\n", 420 | " torch.cat(rgb_list), \n", 421 | " torch.cat(projections_list),\n", 422 | " w_coll=1e3)\n", 423 | "\n", 424 | " best_inds = torch.tensor(np.square(cost)+np.square(coll*1e3)).argsort(dim=1)[:, :num_best].to(device).view(B,num_best,1).expand(-1,-1,7)\n", 425 | " best_poses = torch.gather(x, dim=1, 
index=best_inds)\n", 426 | "\n", 427 | " y_PIFO.append(best_poses)\n", 428 | "\n", 429 | "\n", 430 | " ### 2\n", 431 | " x, cost, coll = F_grasp2.optimize(x_init.clone(),\n", 432 | " torch.cat(rgb_list), \n", 433 | " torch.cat(projections_list))\n", 434 | "\n", 435 | " best_inds = torch.tensor(cost).argsort(dim=1)[:, :num_best].to(device).view(B,num_best,1).expand(-1,-1,7)\n", 436 | " best_poses = torch.gather(x, dim=1, index=best_inds)\n", 437 | "\n", 438 | " x_noPixel.append(best_poses)\n", 439 | "\n", 440 | " x, cost, coll = F_grasp2.optimize(x,\n", 441 | " torch.cat(rgb_list), \n", 442 | " torch.cat(projections_list),\n", 443 | " w_coll=1e3)\n", 444 | "\n", 445 | " best_inds = torch.tensor(np.square(cost)+np.square(coll*1e3)).argsort(dim=1)[:, :num_best].to(device).view(B,num_best,1).expand(-1,-1,7)\n", 446 | " best_poses = torch.gather(x, dim=1, index=best_inds)\n", 447 | "\n", 448 | " y_noPixel.append(best_poses)\n", 449 | "\n", 450 | "\n", 451 | " ### 3\n", 452 | " x, cost, coll = F_grasp3.optimize(x_init.clone(),\n", 453 | " torch.cat(rgb_list), \n", 454 | " torch.cat(cam_params_list))\n", 455 | "\n", 456 | " best_inds = torch.tensor(cost).argsort(dim=1)[:, :num_best].to(device).view(B,num_best,1).expand(-1,-1,7)\n", 457 | " best_poses = torch.gather(x, dim=1, index=best_inds)\n", 458 | "\n", 459 | " x_vecObj.append(best_poses)\n", 460 | "\n", 461 | " x, cost, coll = F_grasp3.optimize(x,\n", 462 | " torch.cat(rgb_list), \n", 463 | " torch.cat(cam_params_list),\n", 464 | " w_coll=1e3)\n", 465 | "\n", 466 | " best_inds = torch.tensor(np.square(cost)+np.square(coll*1e3)).argsort(dim=1)[:, :num_best].to(device).view(B,num_best,1).expand(-1,-1,7)\n", 467 | " best_poses = torch.gather(x, dim=1, index=best_inds)\n", 468 | "\n", 469 | " y_vecObj.append(best_poses)\n", 470 | "\n", 471 | " f1_PIFO = F_grasp1.check_feasibility(torch.cat(x_PIFO, dim=1),\n", 472 | " filename_list, \n", 473 | " mass_list,\n", 474 | " com_list)\n", 475 | " f2_PIFO = F_grasp1.check_feasibility(torch.cat(y_PIFO, dim=1),\n", 476 | " filename_list, \n", 477 | " mass_list,\n", 478 | " com_list)\n", 479 | "\n", 480 | " f1_noPixel = F_grasp2.check_feasibility(torch.cat(x_noPixel, dim=1),\n", 481 | " filename_list, \n", 482 | " mass_list,\n", 483 | " com_list)\n", 484 | " f2_noPixel = F_grasp2.check_feasibility(torch.cat(y_noPixel, dim=1),\n", 485 | " filename_list, \n", 486 | " mass_list,\n", 487 | " com_list)\n", 488 | "\n", 489 | " f1_vecObj = F_grasp3.check_feasibility(torch.cat(x_vecObj, dim=1),\n", 490 | " filename_list, \n", 491 | " mass_list,\n", 492 | " com_list)\n", 493 | " f2_vecObj = F_grasp3.check_feasibility(torch.cat(y_vecObj, dim=1),\n", 494 | " filename_list, \n", 495 | " mass_list,\n", 496 | " com_list)\n", 497 | "\n", 498 | " print(f1_PIFO.sum()/f1_PIFO.size, f2_PIFO.sum()/f2_PIFO.size)\n", 499 | " print(f1_noPixel.sum()/f1_noPixel.size, f2_noPixel.sum()/f2_noPixel.size)\n", 500 | " print(f1_vecObj.sum()/f1_vecObj.size, f2_vecObj.sum()/f2_vecObj.size)\n", 501 | "\n", 502 | " data_name = 'train' if t == 0 else 'test'\n", 503 | " with h5py.File('evals/grasp/'+data_name+'_'+str(num_views)+'.hdf5', mode='w') as f:\n", 504 | " f.create_dataset(\"x_PIFO\", data=torch.cat(x_PIFO, dim=1).cpu().numpy())\n", 505 | " f.create_dataset(\"y_PIFO\", data=torch.cat(y_PIFO, dim=1).cpu().numpy())\n", 506 | "\n", 507 | " f.create_dataset(\"x_noPixel\", data=torch.cat(x_noPixel, dim=1).cpu().numpy())\n", 508 | " f.create_dataset(\"y_noPixel\", data=torch.cat(y_noPixel, dim=1).cpu().numpy())\n", 509 | "\n", 510 | " 
f.create_dataset(\"x_vecObj\", data=torch.cat(x_vecObj, dim=1).cpu().numpy())\n", 511 | " f.create_dataset(\"y_vecObj\", data=torch.cat(y_vecObj, dim=1).cpu().numpy())\n", 512 | "\n", 513 | " f.create_dataset(\"f1_PIFO\", data=f1_PIFO)\n", 514 | " f.create_dataset(\"f2_PIFO\", data=f2_PIFO)\n", 515 | "\n", 516 | " f.create_dataset(\"f1_noPixel\", data=f1_noPixel)\n", 517 | " f.create_dataset(\"f2_noPixel\", data=f2_noPixel)\n", 518 | "\n", 519 | " f.create_dataset(\"f1_vecObj\", data=f1_vecObj)\n", 520 | " f.create_dataset(\"f2_vecObj\", data=f2_vecObj)" 521 | ] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": 7, 526 | "id": "e95c23a0-9811-48d0-95f4-312c11a83ec3", 527 | "metadata": {}, 528 | "outputs": [ 529 | { 530 | "name": "stdout", 531 | "output_type": "stream", 532 | "text": [ 533 | "======================== train_2 ========================\n", 534 | "0.6576923076923077 0.8294871794871795\n", 535 | "0.6756410256410257 0.808974358974359\n", 536 | "0.13205128205128205 0.007692307692307693\n", 537 | "======================== test_2 ========================\n", 538 | "0.5535714285714286 0.7714285714285715\n", 539 | "0.6392857142857142 0.7035714285714286\n", 540 | "0.12857142857142856 0.0035714285714285713\n", 541 | "======================== train_4 ========================\n", 542 | "0.6897435897435897 0.8807692307692307\n", 543 | "0.6230769230769231 0.8269230769230769\n", 544 | "0.21153846153846154 0.005128205128205128\n", 545 | "======================== test_4 ========================\n", 546 | "0.6392857142857142 0.825\n", 547 | "0.6178571428571429 0.7571428571428571\n", 548 | "0.225 0.0035714285714285713\n", 549 | "======================== train_8 ========================\n", 550 | "0.7192307692307692 0.8871794871794871\n", 551 | "0.7128205128205128 0.8397435897435898\n", 552 | "0.28974358974358977 0.005128205128205128\n", 553 | "======================== test_8 ========================\n", 554 | "0.6928571428571428 0.85\n", 555 | "0.6714285714285714 0.7928571428571428\n", 556 | "0.2392857142857143 0.007142857142857143\n" 557 | ] 558 | } 559 | ], 560 | "source": [ 561 | "for num_views in [2, 4, 8]:\n", 562 | " for data_name in ['train', 'test']:\n", 563 | " with h5py.File('evals/grasp/'+data_name+'_'+str(num_views)+'.hdf5', mode='r') as f:\n", 564 | " f1_PIFO, f1_noPixel, f1_vecObj = f['f1_PIFO'][:], f['f1_noPixel'][:], f['f1_vecObj'][:] \n", 565 | " f2_PIFO, f2_noPixel, f2_vecObj = f['f2_PIFO'][:], f['f2_noPixel'][:], f['f2_vecObj'][:] \n", 566 | " print('======================== '+data_name+'_'+str(num_views)+' ========================')\n", 567 | " print(f1_PIFO.sum()/f1_PIFO.size, f2_PIFO.sum()/f2_PIFO.size)\n", 568 | " print(f1_noPixel.sum()/f1_noPixel.size, f2_noPixel.sum()/f2_noPixel.size)\n", 569 | " print(f1_vecObj.sum()/f1_vecObj.size, f2_vecObj.sum()/f2_vecObj.size)" 570 | ] 571 | }, 572 | { 573 | "cell_type": "code", 574 | "execution_count": null, 575 | "id": "55c98461-1eda-494e-bd71-215716f6a14d", 576 | "metadata": {}, 577 | "outputs": [], 578 | "source": [] 579 | } 580 | ], 581 | "metadata": { 582 | "kernelspec": { 583 | "display_name": "Python 3", 584 | "language": "python", 585 | "name": "python3" 586 | }, 587 | "language_info": { 588 | "codemirror_mode": { 589 | "name": "ipython", 590 | "version": 3 591 | }, 592 | "file_extension": ".py", 593 | "mimetype": "text/x-python", 594 | "name": "python", 595 | "nbconvert_exporter": "python", 596 | "pygments_lexer": "ipython3", 597 | "version": "3.8.10" 598 | } 599 | }, 600 | "nbformat": 4, 601 | 
"nbformat_minor": 5 602 | } 603 | --------------------------------------------------------------------------------