├── .vscode └── settings.json ├── assets ├── object_names_test.txt ├── object_render.xml ├── gripper_render.xml ├── object_sampler.py ├── scan_object_process.py ├── icon_process.py ├── finger_sampler.py ├── finger_3d.py └── object_names.txt ├── generator ├── train_diffusion_2d.sh ├── train_diffusion_3d.sh ├── guided_sample_2d.sh ├── guided_sample_3d.sh ├── dataloader.py ├── train.py └── diffusion_utils.py ├── sim ├── run_sim_3d.sh ├── run_sim_2d.sh ├── render_mesh.py ├── sim_3d.py └── sim_2d.py ├── requirements.txt ├── dynamics ├── train_dynamics_2d.sh ├── train_dynamics_3d.sh ├── models │ ├── pointnet2.py │ └── pointnet2_utils.py ├── utils.py ├── profile_forward_3d.py ├── parser.py ├── dataloader.py ├── profile_forward_2d.py ├── trainer.py ├── main.py ├── metrics.py ├── sim_test_mj_3d.py └── sim_test_mj.py ├── README.md └── .gitignore /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.analysis.typeCheckingMode": "basic", 3 | "python.analysis.autoImportCompletions": true 4 | } -------------------------------------------------------------------------------- /assets/object_names_test.txt: -------------------------------------------------------------------------------- 1 | 3D_Dollhouse_Swing 2 | BABY_CAR 3 | Ecoforms_Plant_Container_B4_Har 4 | Threshold_Bamboo_Ceramic_Soap_Dish 5 | Squirt_Strain_Fruit_Basket 6 | Office_Depot_Canon_CLI_8CMY_Remanufactured_Ink_Cartridges_Color_Cyan_Magenta_Yellow_3_count -------------------------------------------------------------------------------- /generator/train_diffusion_2d.sh: -------------------------------------------------------------------------------- 1 | python generator/train.py --num_fingers=200000 --save_dir='' --learning_rate=1e-4 --lr_warmup_steps=0 --num_epochs=1000 --val_step=100 --num_workers=0 --num_train_timesteps=15 --num_inference_steps=5 --ema_power=0.85 --batch_size=2048 --ctrlpts_dim=14 -------------------------------------------------------------------------------- /generator/train_diffusion_3d.sh: -------------------------------------------------------------------------------- 1 | python generator/train.py --num_fingers=200000 --save_dir='' --fingers_3d --ctrlpts_dim=42 --ctrlpts_x_dim=7 --ctrlpts_z_dim=3 --learning_rate=1e-4 --lr_warmup_steps=0 --num_epochs=1000 --val_step=100 --num_workers=0 --num_train_timesteps=15 --num_inference_steps=5 --ema_power=0.85 --batch_size=1024 -------------------------------------------------------------------------------- /sim/run_sim_3d.sh: -------------------------------------------------------------------------------- 1 | model_root='' 2 | save_dir='' 3 | num_cpus=256 4 | 5 | for object_idx in {0..300}; do 6 | for ((i=0; i<2000; i+=512)) do 7 | python sim/sim_3d.py $model_root $i $object_idx 512 1 $save_dir $num_cpus 8 | done 9 | done -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers==0.11.1 2 | geomdl==5.3.1 3 | imageio==2.31.4 4 | matplotlib==3.7.3 5 | mujoco==3.1.1 6 | numpy==1.24.4 7 | open3d==0.17.0 8 | opencv_python==4.8.1.78 9 | pytorch_lightning==2.1.0 10 | ray==2.7.1 11 | scipy==1.12.0 12 | torch==2.0.1 13 | tqdm==4.66.1 14 | transforms3d==0.4.1 15 | triangle==20230923 16 | trimesh==3.23.5 17 | wandb==0.15.12 18 | -------------------------------------------------------------------------------- /dynamics/train_dynamics_2d.sh: 
-------------------------------------------------------------------------------- 1 | python dynamics/main.py --save_dir='' --wandb_id='' --ctrlpts_dim=14 --batch_size=128 --object_max_num_vertices=100 --data_dir='' --test_data_dir='' --object_dir='' --learning_rate=1e-4 --weight_decay=0 --num_epochs=100 --val_step=1 --save_ckpt_step=1000 --patience=100 --num_workers=8 --num_train_timesteps=15 --num_inference_steps=5 --num_timesteps_per_batch=1 -------------------------------------------------------------------------------- /sim/run_sim_2d.sh: -------------------------------------------------------------------------------- 1 | model_root='' 2 | save_dir='' 3 | num_cpus=128 # number of cpus to use for parallel simulation 4 | 5 | for object_idx in {0..1000}; do # number of objects 6 | for ((i=0; i<1000; i+=512)) do # number of manipulators 7 | python sim/sim_2d.py $model_root $i $object_idx 512 1 $save_dir $num_cpus 8 | done 9 | done -------------------------------------------------------------------------------- /dynamics/train_dynamics_3d.sh: -------------------------------------------------------------------------------- 1 | python dynamics/main.py --save_dir='' --wandb_id='' --fingers_3d --batch_size=1 --use_sub_batch --sub_bs=2048 --object_max_num_vertices=512 --data_dir='' --test_data_dir='' --object_dir='' --ctrlpts_dim=42 --ctrlpts_x_dim=7 --ctrlpts_z_dim=3 --learning_rate=1e-4 --weight_decay=0 --num_epochs=100 --val_step=1 --save_ckpt_step=1000 --patience=100 --num_workers=16 --num_train_timesteps=15 --num_inference_steps=5 --num_timesteps_per_batch=1 -------------------------------------------------------------------------------- /generator/guided_sample_2d.sh: -------------------------------------------------------------------------------- 1 | python generator/train.py --mode='test' --checkpoint_path='ckpts/dynamics_2d.pt' \ 2 | --classifier_guidance --diffusion_checkpoint_path='ckpts/diffusion_2d.pt' --object_dir='/Icons-50.npy' --save_dir='' \ 3 | --ctrlpts_dim=14 --num_fingers=16 --grid_size=360 --num_pos=5 --object_max_num_vertices=100 \ 4 | --num_workers=0 --num_train_timesteps=15 --num_inference_steps=5 --ema_power=0.85 --batch_size=16 --num_cpus=32 --seed=0 5 | -------------------------------------------------------------------------------- /assets/object_render.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /generator/guided_sample_3d.sh: -------------------------------------------------------------------------------- 1 | python generator/train.py --mode='test' --checkpoint_path='ckpts/dynamics_3d.pt' \ 2 | --diffusion_checkpoint_path='ckpts/diffusion_3d.ckpt' --object_dir='' --save_dir='' \ 3 | --classifier_guidance --num_fingers=16 --grid_size=45 --num_pos=5 --fingers_3d --object_max_num_vertices=512 --ctrlpts_dim=42 --ctrlpts_x_dim=7 --ctrlpts_z_dim=3 \ 4 | --num_workers=0 --num_train_timesteps=15 --num_inference_steps=5 --ema_power=0.85 --batch_size=16 --sub_bs=512 --num_cpus=32 --seed=0 5 | -------------------------------------------------------------------------------- /assets/gripper_render.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /generator/dataloader.py:
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset 4 | 5 | class GripperDataset(Dataset): 6 | def __init__(self, gripper_pts, gripper_pts_max_x, gripper_pts_min_x, gripper_pts_max_y, gripper_pts_min_y): 7 | self.gripper_pts = gripper_pts 8 | self.gripper_pts_max_x = gripper_pts_max_x 9 | self.gripper_pts_min_x = gripper_pts_min_x 10 | self.gripper_pts_max_y = gripper_pts_max_y 11 | self.gripper_pts_min_y = gripper_pts_min_y 12 | 13 | def __len__(self): 14 | return len(self.gripper_pts) 15 | 16 | def __getitem__(self, idx): 17 | # IMPORTANT: normalize the input to [-1, 1] 18 | ctrlpts = self.gripper_pts[idx, :, 1].astype(np.float32) 19 | ctrlpts = (ctrlpts - self.gripper_pts_min_y) / (self.gripper_pts_max_y - self.gripper_pts_min_y) * 2.0 - 1.0 20 | return ctrlpts.reshape((-1, 1)) -------------------------------------------------------------------------------- /assets/object_sampler.py: -------------------------------------------------------------------------------- 1 | import xml.etree.ElementTree as ET 2 | 3 | def generate_object_xml(num_collision, object_idx, save_path): 4 | # Create the root element 5 | root = ET.Element("mujoco", model="object") 6 | 7 | # Create the 'asset' element 8 | asset = ET.SubElement(root, "asset") 9 | ET.SubElement(asset, "mesh", name="object", file="objects/%d/object.obj" % object_idx) 10 | 11 | for i in range(num_collision): 12 | ET.SubElement(asset, "mesh", name=f"object{i:03d}", file=f"objects/{object_idx}/object{i:03d}.obj") 13 | 14 | # Create the 'worldbody' element 15 | worldbody = ET.SubElement(root, "worldbody") 16 | body = ET.SubElement(worldbody, "body", name="object") 17 | 18 | # Add 'freejoint' and 'geom' elements to 'body' 19 | ET.SubElement(body, "freejoint", name="object_root") 20 | object_v = ET.SubElement(body, "geom", mesh="object", type="mesh") 21 | object_v.set("class", "visual") 22 | 23 | for i in range(num_collision): 24 | object_c = ET.SubElement(body, "geom", mesh=f"object{i:03d}", type="mesh") 25 | object_c.set("class", "collision") 26 | 27 | # Create an ElementTree object and write to file 28 | tree = ET.ElementTree(root) 29 | tree.write(save_path) -------------------------------------------------------------------------------- /dynamics/models/pointnet2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from os.path import join as pjoin 4 | BASEPATH = os.path.dirname(__file__) 5 | sys.path.insert(0, BASEPATH) 6 | sys.path.insert(0, pjoin(BASEPATH, '..')) 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from dynamics.models.pointnet2_utils import PointNetSetAbstraction 10 | 11 | class PointNet2(nn.Module): 12 | def __init__(self, num_output_ch, normal_channel=False): 13 | super(PointNet2, self).__init__() 14 | in_channel = 6 if normal_channel else 3 15 | self.normal_channel = normal_channel 16 | self.num_output_ch = num_output_ch 17 | self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32, in_channel=in_channel, mlp=[64, 128], group_all=False) 18 | self.sa2 = PointNetSetAbstraction(npoint=128, radius=0.4, nsample=64, in_channel=128 + 3, mlp=[128, num_output_ch], group_all=False) 19 | self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=256 + 3, mlp=[num_output_ch], group_all=True) 20 | 21 | def forward(self, xyz): 22 | B, _, _ = xyz.shape 23 | if self.normal_channel: 24 | norm = xyz[:, 3:, :] 25 | 
xyz = xyz[:, :3, :] 26 | else: 27 | norm = None 28 | l1_xyz, l1_points = self.sa1(xyz, norm) 29 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 30 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 31 | x = l3_points.view(B, self.num_output_ch) 32 | return x, l3_points -------------------------------------------------------------------------------- /sim/render_mesh.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | from os.path import join as pjoin 5 | BASEPATH = os.path.dirname(__file__) 6 | sys.path.insert(0, BASEPATH) 7 | sys.path.insert(0, pjoin(BASEPATH, '..')) 8 | 9 | import numpy as np 10 | import mujoco 11 | import subprocess 12 | import subprocess 13 | from transforms3d import euler 14 | 15 | from assets.icon_process import extract_contours 16 | 17 | color_map = np.asarray([ 18 | [0, 0, 0], 19 | [255, 255, 255], 20 | ], dtype=np.uint8) 21 | color_maps = np.concatenate([color_map for _ in range(32)], axis=0) 22 | 23 | def render_mesh(gripper_root: str): 24 | subprocess.call(['cp', 'assets/gripper_render.xml', os.path.join(gripper_root, 'gripper_render.xml')]) 25 | model = mujoco.MjModel.from_xml_path(os.path.join(gripper_root, 'gripper_render.xml')) 26 | data = mujoco.MjData(model) 27 | renderer = mujoco.Renderer(model, 256, 256) 28 | camera = mujoco.MjvCamera() 29 | camera.lookat[:] = [0.0, 0.0, 0.0] 30 | camera.distance = 0.9 31 | camera.azimuth = 180 32 | camera.elevation = -30 33 | 34 | mujoco.mj_step(model, data) 35 | renderer.update_scene(data, camera) 36 | img = renderer.render() 37 | return img 38 | 39 | def render_object_mesh(object_root, z_rots): 40 | subprocess.call(['cp', 'assets/object_render.xml', os.path.join(object_root, 'object_render.xml')]) 41 | model = mujoco.MjModel.from_xml_path(os.path.join(object_root, 'object_render.xml')) 42 | data = mujoco.MjData(model) 43 | renderer = mujoco.Renderer(model, 128, 128) 44 | renderer.enable_segmentation_rendering() 45 | camera = mujoco.MjvCamera() 46 | camera.lookat[:] = [0.0, 0.0, 0.0] 47 | camera.distance = 0.8 48 | camera.azimuth = 135 49 | camera.elevation = -45 50 | obj_root_idx = [model.joint(jointid).name for jointid in range(model.njnt)].index("object_root") 51 | obj_jnt = model.joint(obj_root_idx) 52 | assert obj_jnt.type == 0 # freejoint 53 | contours = [] 54 | for z_rot in z_rots: 55 | data.qpos[obj_jnt.qposadr[0] : obj_jnt.qposadr[0] + 3] = [0, 0, 0,] 56 | data.qpos[ 57 | obj_jnt.qposadr[0] + 3 : obj_jnt.qposadr[0] + 7 58 | ] = euler.euler2quat(0, 0, z_rot) 59 | mujoco.mj_step(model, data) 60 | renderer.update_scene(data, camera) 61 | img = renderer.render()[..., 0] 62 | img = color_maps[img] 63 | contour = extract_contours(img, num_points=100, rescale=False) 64 | contours.append(contour) 65 | return contours -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dynamics-Guided Diffusion Model for Robot Manipulator Design 2 | 3 | ### [Paper](https://arxiv.org/abs/2402.15038) | [Website](https://dgdm-robot.github.io) | [Video](https://www.youtube.com/watch?v=0m5nTWgHULg) 4 | [Xiaomeng Xu](https://xxm19.github.io/), [Huy Ha](https://www.cs.columbia.edu/~huy/), [Shuran Song](https://shurans.github.io/) 5 | 6 | ### Dependencies 7 | Required packages can be installed by: 8 | ``` 9 | pip install -r requirements.txt 10 | ``` 11 | 12 | ## Data Preparation 13 | 14 | ### Download object dataset 15 | #### 2D objects 
16 | Download 2D object icons from [Icons50 dataset](https://www.kaggle.com/datasets/danhendrycks/icons50). 17 | 18 | #### 3D objects 19 | Download 3D object meshes from [MuJoCo scanned object dataset](https://github.com/kevinzakka/mujoco_scanned_objects). 20 | 21 | ### Generate simulation data 22 | Replace ```OBJECT_DIR``` in ```sim/sim_2d.py``` and ```sim/sim_3d.py``` with the directory to object dataset. 23 | 24 | Install [v-hacd](https://github.com/kmammou/v-hacd). 25 | 26 | #### 2D 27 | ``` 28 | bash sim/run_sim_2d.sh 29 | ``` 30 | #### 3D 31 | ``` 32 | bash sim/run_sim_3d.sh 33 | ``` 34 | 35 | Note for data generation: Sometimes the sampled objects or manipulators may have weird shapes and thus lead to qhull error when doing convex decomposition. And ray is used to parallelize cpu-based data generation, which sometimes may lead to timeout issues. Therefore, it is expected if you see some error message and the data for some object-manipulator pairs is not generated, but it should be fine as long as you see most data is being generated. 36 | 37 | ## Training 38 | [Download pretrained model checkpoints](https://drive.google.com/drive/folders/1jjC6G5Qv_ZkJwTjk2mCBkSyXkZu_w5EB?usp=sharing) 39 | ### Train Dynamics Model 40 | #### 2D 41 | ``` 42 | bash dynamics/train_dynamics_2d.sh 43 | ``` 44 | #### 3D 45 | ``` 46 | bash dynamics/train_dynamics_3d.sh 47 | ``` 48 | 49 | ### Train Diffusion Model 50 | #### 2D 51 | ``` 52 | bash generator/train_diffusion_2d.sh 53 | ``` 54 | #### 3D 55 | ``` 56 | bash generator/train_diffusion_3d.sh 57 | ``` 58 | 59 | ## Inference 60 | ### Generate Task-Specific Manipulators 61 | #### 2D 62 | ``` 63 | bash generator/guided_sample_2d.sh 64 | ``` 65 | #### 3D 66 | ``` 67 | bash generator/guided_sample_3d.sh 68 | ``` 69 | 70 | ## Citation 71 | If you find DGDM useful for your work, please cite: 72 | ``` 73 | @misc{xu2024dynamicsguided, 74 | title={Dynamics-Guided Diffusion Model for Robot Manipulator Design}, 75 | author={Xiaomeng Xu and Huy Ha and Shuran Song}, 76 | year={2024}, 77 | eprint={2402.15038}, 78 | archivePrefix={arXiv}, 79 | primaryClass={cs.RO} 80 | } 81 | ``` 82 | 83 | 84 | ## Contact 85 | If you have any questions, please feel free to contact Xiaomeng Xu (xuxm@stanford.edu) -------------------------------------------------------------------------------- /dynamics/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from matplotlib import pyplot as plt 4 | import open3d as o3d 5 | 6 | def continuous_signed_delta(theta1, theta2): 7 | delta = theta2 - theta1 8 | if delta > np.pi: 9 | delta = delta - 2*np.pi 10 | elif delta < -np.pi: 11 | delta = delta + 2*np.pi 12 | return delta 13 | 14 | def sample_pts_from_mesh(mesh_file, num_points=1024): 15 | mesh = o3d.io.read_triangle_mesh(mesh_file) 16 | pcd = mesh.sample_points_uniformly(number_of_points=num_points) 17 | pts = np.asarray(pcd.points).reshape((-1, 3)) 18 | return pts 19 | 20 | def visualize_finals(finals, save_path): 21 | plt.clf() 22 | f = plt.figure(figsize=(10, 6)) 23 | ax = f.add_subplot(111) 24 | ax.set(ylim=(0, 2*np.pi)) 25 | ax.scatter(np.arange(len(finals)), finals, s=2) 26 | plt.savefig(save_path) 27 | plt.close() 28 | 29 | def visualize_profile(profile, save_path, ori_range=[-1.0, 1.0]): 30 | plt.clf() 31 | signs = np.sign(profile) 32 | 33 | radii = np.array([1]) 34 | thetas = np.linspace(ori_range[0] * np.pi + np.pi, ori_range[1] * np.pi + np.pi, len(profile)) 35 | theta, r = np.meshgrid(thetas, radii) 36 
| u = - 2 * np.pi / len(profile) * np.sin(theta) * signs 37 | v = 2 * np.pi / len(profile) * np.cos(theta) * signs 38 | 39 | f = plt.figure(figsize=(40, 40)) 40 | ax = f.add_subplot(polar=True) 41 | ax.quiver(theta, r, u, v, profile, scale=1, width=0.005, headwidth=4, headlength=2, headaxislength=2, cmap='bwr') 42 | 43 | plt.savefig(save_path) 44 | plt.close() 45 | 46 | def visualize_profile_xy_theta(input_ori, input_pos, profile_ori, profile_x, profile_y, save_dir): 47 | os.makedirs(save_dir, exist_ok=True) 48 | plt.clf() 49 | f = plt.figure(figsize=(60, 20)) 50 | ax = f.add_subplot(131, projection='3d') 51 | ax.set(xlim=(-3, 3), ylim=(-3, 3), zlim=(-1, 1)) 52 | color = np.asarray(['r' if ori == 1 else 'b' if ori == -1 else 'g' for ori in profile_ori]) 53 | x = (input_pos[:, 0]+2.0) * np.cos(input_ori) 54 | y = (input_pos[:, 0]+2.0) * np.sin(input_ori) 55 | z = input_pos[:, 1] 56 | ax.scatter(x, y, z, c=color, s=1) 57 | 58 | ax = f.add_subplot(132, projection='3d') 59 | ax.set(xlim=(-3, 3), ylim=(-3, 3), zlim=(-1, 1)) 60 | color = np.asarray(['r' if x == 1 else 'b' if x == -1 else 'g' for x in profile_x]) 61 | ax.scatter(x, y, z, c=color, s=1) 62 | 63 | ax = f.add_subplot(133, projection='3d') 64 | ax.set(xlim=(-3, 3), ylim=(-3, 3), zlim=(-1, 1)) 65 | color = np.asarray(['r' if y == 1 else 'b' if y == -1 else 'g' for y in profile_y]) 66 | ax.scatter(x, y, z, c=color, s=1) 67 | plt.savefig(os.path.join(save_dir, 'profile.png')) 68 | plt.close() 69 | 70 | def visualize_ctrlpts(ctrlpts, save_path): 71 | num_pt = ctrlpts.shape[0] // 2 72 | plt.clf() 73 | f = plt.figure() 74 | ax = f.add_subplot(211) 75 | ax.set(xlim=(-0.12, 0.12), ylim=(-0.045, 0.015)) 76 | ax.scatter(ctrlpts[:num_pt, 0], ctrlpts[:num_pt, 1]) 77 | ax = f.add_subplot(212) 78 | ax.set(xlim=(-0.12, 0.12), ylim=(-0.045, 0.015)) 79 | ax.scatter(ctrlpts[num_pt:, 0], ctrlpts[num_pt:, 1]) 80 | plt.savefig(save_path) -------------------------------------------------------------------------------- /dynamics/profile_forward_3d.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from os.path import join as pjoin 4 | BASEPATH = os.path.dirname(__file__) 5 | sys.path.insert(0, BASEPATH) 6 | sys.path.insert(0, pjoin(BASEPATH, '..')) 7 | import torch 8 | import torch.nn as nn 9 | 10 | from dynamics.profile_forward_2d import get_embedder, timestep_embedding 11 | from dynamics.models.pointnet2 import PointNet2 12 | 13 | class ProfileForward3DModel(nn.Module): 14 | def __init__(self, W=256, params_ch=1250, ori_ch=1, pos_ch=2, output_ch=3): 15 | super(ProfileForward3DModel, self).__init__() 16 | self.W = W 17 | self.ori_ch = ori_ch 18 | self.pos_ch = pos_ch 19 | self.output_ch = output_ch 20 | self.ori_embed, ori_embed_dim = get_embedder(ori_ch, 4, 0, scalar_factor=1) 21 | self.pos_embed, pos_embed_dim = get_embedder(pos_ch, 4, 0, scalar_factor=1) 22 | self.ori_ch = ori_embed_dim 23 | self.pos_ch = pos_embed_dim 24 | self.pose_embed_dim = ori_embed_dim + pos_embed_dim 25 | self.time_embed_dim = W 26 | self.time_encoder = nn.Sequential( 27 | nn.Linear(W // 2, self.time_embed_dim), 28 | nn.SiLU(), 29 | nn.Linear(self.time_embed_dim, self.time_embed_dim), 30 | ) 31 | self.object_encoder = PointNet2(W) 32 | self.object_encode_dim = W 33 | self.gripper_encoder = nn.Sequential( 34 | nn.Linear(params_ch, W), 35 | nn.ReLU(), 36 | nn.Linear(W, W), 37 | ) 38 | self.gripper_encode_dim = W 39 | self.linears = nn.Sequential( 40 | nn.Linear(self.gripper_encode_dim + self.pose_embed_dim + 
self.time_embed_dim + self.object_encode_dim, W * 2), 41 | nn.BatchNorm1d(W * 2), 42 | nn.ReLU(), 43 | nn.Linear(W * 2, W), 44 | nn.BatchNorm1d(W), 45 | nn.ReLU(), 46 | nn.Linear(W, W), 47 | nn.BatchNorm1d(W), 48 | nn.ReLU(), 49 | nn.Linear(W, W), 50 | nn.BatchNorm1d(W), 51 | nn.ReLU(), 52 | nn.Linear(W, W), 53 | nn.BatchNorm1d(W), 54 | nn.ReLU(), 55 | nn.Linear(W, W), 56 | nn.BatchNorm1d(W), 57 | nn.ReLU(), 58 | nn.Linear(W, W), 59 | nn.BatchNorm1d(W), 60 | nn.ReLU(), 61 | nn.Linear(W, W), 62 | nn.BatchNorm1d(W), 63 | nn.ReLU(), 64 | ) 65 | self.output = nn.Linear(W, output_ch) 66 | 67 | def forward(self, x_ctrl, x_ori, x_pos, timesteps=None, object_vertices=None): 68 | ''' 69 | input: 70 | ctrlpts [batch_size, 3, 1250] / [batch_size, 3, 42] 71 | ori [batch_size, 1] 72 | pos [batch_size, 2] 73 | timesteps [batch_size,] 74 | object_pts [batch_size, 1024, 3] 75 | output: 76 | profile [batch_size, 9] 77 | ''' 78 | x_ctrl = self.gripper_encoder(x_ctrl[:, 1, :]) 79 | x_ori = self.ori_embed(x_ori) 80 | x_pos = self.pos_embed(x_pos) 81 | x_pose = torch.cat([x_ori, x_pos], dim=1) 82 | x_object, _ = self.object_encoder(object_vertices) 83 | time_emb = timestep_embedding(timesteps, self.time_embed_dim) 84 | x = self.linears(torch.cat([x_object, x_ctrl, x_pose, time_emb], dim=1)) 85 | x = self.output(x) 86 | return x -------------------------------------------------------------------------------- /dynamics/parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse(): 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--batch_size', type=int, default=1024) 6 | parser.add_argument('--use_sub_batch', action='store_true', help='use sub batch to avoid OOM') 7 | parser.add_argument('--sub_bs', type=int, default=1024, help='sub batch size for training') 8 | parser.add_argument('--num_epochs', type=int, default=1000, help='number of epochs for training') 9 | parser.add_argument('--num_fingers', type=int, default=1000, help='number of fingers') 10 | parser.add_argument('--ctrlpts_dim', type=int, default=14) 11 | parser.add_argument('--ctrlpts_x_dim', type=int, default=7) 12 | parser.add_argument('--ctrlpts_z_dim', type=int, default=3) 13 | parser.add_argument('--learning_rate', type=float, default=1e-4, help='learning rate for optimizer') 14 | parser.add_argument('--lr_warmup_steps', type=int, default=100, help='learning rate warmup steps for optimizer') 15 | parser.add_argument('--weight_decay', type=float, default=0, help='weight decay for optimizer') 16 | parser.add_argument('--patience', type=int, default=500, help='patience for early stopping when training dynamics model') 17 | parser.add_argument('--checkpoint_path', type=str, default=None, help='path to load dynamics model checkpoints') 18 | parser.add_argument('--save_dir', type=str, help='path to save model checkpoints') 19 | parser.add_argument('--wandb_id', type=str, default=None, help='wandb id') 20 | parser.add_argument('--data_dir', type=str, default='', help='path to data directory') 21 | parser.add_argument('--test_data_dir', type=str, default='', help='path to test data directory') 22 | parser.add_argument('--object_dir', type=str, default='', help='path to object directory') 23 | parser.add_argument('--num_workers', type=int, default=4, help='number of workers for dataloader') 24 | parser.add_argument('--mode', type=str, default='train', help='train or test') 25 | parser.add_argument('--grid_size', type=int, default=360, help='number of initial orientations 
sampled for each object') 26 | parser.add_argument('--num_pos', type=int, default=9, help='number of initial positions sampled for each object') 27 | parser.add_argument('--save_ckpt_step', type=int, default=10, help='step to save model checkpoints') 28 | parser.add_argument('--val_step', type=int, default=100, help='step to validate model') 29 | parser.add_argument('--num_train_timesteps', type=int, default=1000, help='number of training timesteps for diffusion model') 30 | parser.add_argument('--num_timesteps_per_batch', type=int, default=1, help='number of timesteps per batch') 31 | parser.add_argument('--num_inference_steps', type=int, default=100, help='number of inference steps for diffusion model') 32 | parser.add_argument('--ema_power', type=float, default=0.75, help='ema power') 33 | parser.add_argument('--object_max_num_vertices', type=int, default=10, help='max number of vertices for object encoder') 34 | parser.add_argument('--diffusion_checkpoint_path', type=str, default=None, help='path to load diffusion model checkpoints') 35 | parser.add_argument('--classifier_guidance', action='store_true', help='use classifier guidance') 36 | parser.add_argument('--num_cpus', type=int, default=4, help='number of cpus used in parallel for simulation') 37 | parser.add_argument('--fingers_3d', action='store_true', help='use 3d fingers') 38 | parser.add_argument('--render_video', action='store_true', help='render videos visualizing interactions of fingers and objects') 39 | parser.add_argument('--seed', type=int, default=0, help='random seed') 40 | args = parser.parse_args() 41 | return args -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | MUJOCO_LOG.TXT 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | *.mp4 158 | *.out 159 | *.png 160 | *.ckpt 161 | *.mtl 162 | *.obj 163 | *.stl 164 | *.zarr 165 | wandb/ 166 | histogram/ 167 | test/ 168 | train_gt/ 169 | train_predicted/ 170 | gripper_diffusion/ 171 | vis/ 172 | render_output/ 173 | -------------------------------------------------------------------------------- /assets/scan_object_process.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from tqdm import tqdm 5 | import open3d as o3d 6 | import xml.etree.ElementTree as ET 7 | 8 | def get_bbox(data_dir): 9 | max = [] 10 | min = [] 11 | for root, dirs, files in os.walk(data_dir): 12 | for dir in tqdm(dirs): 13 | mesh_file = os.path.join(root, dir, 'model.obj') 14 | mesh = o3d.io.read_triangle_mesh(mesh_file) 15 | bbox = mesh.get_axis_aligned_bounding_box() 16 | max.append(bbox.get_max_bound().reshape(-1)) 17 | min.append(bbox.get_min_bound().reshape(-1)) 18 | max = np.stack(max, axis=0) 19 | # plot histogram 20 | plt.clf() 21 | plt.hist(max[..., 0], bins=100) 22 | plt.savefig('max_x.png') 23 | plt.clf() 24 | plt.hist(max[..., 1], bins=100) 25 | plt.savefig('max_y.png') 26 | plt.clf() 27 | plt.hist(max[..., 2], bins=100) 28 | plt.savefig('max_z.png') 29 | min = np.stack(min, axis=0) 30 | plt.clf() 31 | plt.hist(min[..., 0], bins=100) 32 | plt.savefig('min_x.png') 33 | plt.clf() 34 | plt.hist(min[..., 1], bins=100) 35 | plt.savefig('min_y.png') 36 | plt.clf() 37 | plt.hist(min[..., 2], bins=100) 38 | plt.savefig('min_z.png') 39 | print('max: ', np.max(max, axis=0)) 40 | print('min: ', np.min(min, axis=0)) 41 | 42 | def filter_object(data_dir): 43 | object_names = [] 44 | for root, dirs, files in os.walk(data_dir): 45 | for dir in tqdm(dirs): 46 | mesh_file = 
os.path.join(root, dir, 'model.obj') 47 | mesh = o3d.io.read_triangle_mesh(mesh_file) 48 | bbox = mesh.get_axis_aligned_bounding_box() 49 | max = bbox.get_max_bound().reshape(-1) 50 | min = bbox.get_min_bound().reshape(-1) 51 | if max[0] < 0.1 and min[0] > -0.1 and max[1] < 0.1 and min[1] > -0.1 and max[2] < 0.12: 52 | object_names.append(dir) 53 | # save object names 54 | with open('assets/object_names.txt', 'w') as f: 55 | for name in object_names: 56 | f.write(name + '\n') 57 | 58 | def read_object_names(test=False): 59 | filename = 'assets/object_names_test.txt' if test else 'assets/object_names.txt' 60 | object_names = [] 61 | with open(filename, 'r') as f: 62 | for line in f.readlines(): 63 | object_names.append(line.strip()) 64 | return object_names 65 | 66 | def generate_object_3d_xml(num_collision, object_idx, save_path): 67 | # Create the root element 68 | root = ET.Element("mujoco", model="object") 69 | 70 | # Create the 'asset' element 71 | asset = ET.SubElement(root, "asset") 72 | ET.SubElement(asset, "mesh", name="object", file="objects/%d/model.obj" % object_idx) 73 | 74 | for i in range(num_collision): 75 | ET.SubElement(asset, "mesh", name=f"object{i:03d}", file=f"objects/{object_idx}/model_collision_{i}.obj") 76 | 77 | # Create the 'worldbody' element 78 | worldbody = ET.SubElement(root, "worldbody") 79 | body = ET.SubElement(worldbody, "body", name="object") 80 | 81 | # Add 'freejoint' and 'geom' elements to 'body' 82 | ET.SubElement(body, "freejoint", name="object_root") 83 | object_v = ET.SubElement(body, "geom", mesh="object", type="mesh") 84 | object_v.set("class", "visual") 85 | 86 | for i in range(num_collision): 87 | object_c = ET.SubElement(body, "geom", mesh=f"object{i:03d}", type="mesh") 88 | object_c.set("class", "collision") 89 | 90 | # Create an ElementTree object and write to file 91 | tree = ET.ElementTree(root) 92 | tree.write(save_path) 93 | 94 | 95 | if __name__ == '__main__': 96 | # get_bbox('/store/real/xuxm/mujoco_scanned_objects/models') 97 | filter_object('/store/real/xuxm/mujoco_scanned_objects/models') -------------------------------------------------------------------------------- /assets/icon_process.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import trimesh 5 | import triangle 6 | 7 | def resample_contour(contour, num_points): 8 | # Flatten the contour array 9 | contour = contour.reshape(-1, 2) 10 | 11 | # Calculate the distance between each pair of points 12 | distances = np.sqrt(np.sum(np.diff(contour, axis=0)**2, axis=1)) 13 | distances = np.insert(distances, 0, 0) # insert 0 at the start 14 | cumulative_distances = np.cumsum(distances) 15 | 16 | # Create an array of evenly spaced distances along the contour 17 | uniform_distances = np.linspace(0, cumulative_distances[-1], num_points) 18 | 19 | # Use linear interpolation to find the x, y coordinates at the uniform distances 20 | uniform_contour_x = np.interp(uniform_distances, cumulative_distances, contour[:, 0]) 21 | uniform_contour_y = np.interp(uniform_distances, cumulative_distances, contour[:, 1]) 22 | 23 | # Stack the coordinates together 24 | uniform_contour = np.vstack((uniform_contour_x, uniform_contour_y)).T 25 | uniform_contour = uniform_contour.reshape(-1, 1, 2).astype(np.int32) 26 | 27 | return uniform_contour 28 | 29 | def extract_contours(image, num_points=100, rescale=True): 30 | image = cv2.resize(image, (128, 128)) 31 | 32 | # Convert to grayscale 33 | gray = 
cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 34 | 35 | # Apply thresholding 36 | _, thresh = cv2.threshold(gray, 240, 255, cv2.THRESH_BINARY_INV) 37 | 38 | # Find contours 39 | contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 40 | 41 | # Since we want to apply this to the largest contour, let's first identify it 42 | contour_lengths = [cv2.arcLength(contour, True) for contour in contours] 43 | max_contour = contours[np.argmax(contour_lengths)] 44 | 45 | # Resample the largest contour to have a fixed number of points 46 | resampled_contour = resample_contour(max_contour, num_points) 47 | # reshape from (num_points, 1, 2) to (num_points, 2) 48 | resampled_contour = resampled_contour.reshape(-1, 2) 49 | 50 | if rescale: 51 | # rescale the contour to be in [-0.05, 0.05] 52 | resampled_contour = resampled_contour / 128 * 0.1 - 0.05 53 | 54 | return resampled_contour 55 | 56 | def draw_contour(image): 57 | contour = extract_contours(image, 100, rescale=False) 58 | image_with_contour = cv2.resize(image, (128, 128)) 59 | cv2.drawContours(image_with_contour, [contour], -1, (0, 255, 0), 1) 60 | return contour, image_with_contour 61 | 62 | def generate_icon_mesh(img, height, num_points=100): 63 | contour = extract_contours(img, num_points) 64 | x = contour[..., 0] 65 | y = contour[..., 1] 66 | z = np.zeros_like(x) 67 | vertices_2d = np.stack([x, y, z], axis=-1) 68 | # Extrude 69 | vertices_3d = np.concatenate([ 70 | vertices_2d, 71 | vertices_2d + [0, 0, height] 72 | ]) 73 | 74 | # Generate indices for side faces 75 | indices = np.arange(0, num_points) 76 | side_faces_upper = np.stack([indices, np.roll(indices, -1) + num_points, np.roll(indices, -1)], axis=1) 77 | side_faces_lower = np.stack([indices, indices + num_points, np.roll(indices, -1) + num_points], axis=1) 78 | sides = np.concatenate([side_faces_upper, side_faces_lower]) 79 | 80 | # Triangulate top and bottom faces 81 | # keep the boundary-edges of the triangulation 82 | top_faces = triangle.triangulate({'vertices': contour, 'segments': np.stack([indices, np.roll(indices, -1)], axis=1)}, 'p')['triangles'] 83 | bottom_faces = top_faces + num_points 84 | top_faces[:, [1, 2]] = top_faces[:, [2, 1]] 85 | 86 | # Combine faces 87 | faces_3d = np.concatenate([sides, top_faces, bottom_faces]) 88 | 89 | # Create mesh 90 | mesh = trimesh.Trimesh(vertices=vertices_3d, faces=faces_3d) 91 | 92 | return mesh, contour 93 | 94 | def save_icon_mesh(img, height, num_points, save_dir): 95 | os.makedirs(save_dir, exist_ok=True) 96 | mesh, contour = generate_icon_mesh(img, height, num_points) 97 | mesh_path = os.path.join(save_dir, 'object.obj') 98 | mesh.export(mesh_path) 99 | return contour, mesh_path -------------------------------------------------------------------------------- /dynamics/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset 5 | from dynamics.utils import sample_pts_from_mesh 6 | 7 | class DynamicsDataset(Dataset): 8 | def __init__(self, dataset_dir, object_mesh_dir, fingers_3d, gripper_pts_max_x, gripper_pts_min_x, gripper_pts_max_y, gripper_pts_min_y, gripper_pts_max_z, gripper_pts_min_z, object_max_num_vertices=10, object_pts_max_x=0.05, object_pts_min_x=-0.05, object_pts_max_y=0.05, object_pts_min_y=-0.05, object_pts_max_z=0.05, object_pts_min_z=-0.05): 9 | self.fingers_3d = fingers_3d 10 | if fingers_3d: 11 | self.threshold = np.array([0.02, 0.001, 0.001]) 12 | self.std = 
np.array([0.0312, 0.0016, 0.0026]) 13 | else: 14 | self.threshold = np.array([0.03, 0.002, 0.003]) 15 | self.std = np.array([0.0565, 0.0026, 0.0047]) 16 | self.gripper_pts_max_x = gripper_pts_max_x 17 | self.gripper_pts_min_x = gripper_pts_min_x 18 | self.gripper_pts_max_y = gripper_pts_max_y 19 | self.gripper_pts_min_y = gripper_pts_min_y 20 | self.gripper_pts_max_z = gripper_pts_max_z 21 | self.gripper_pts_min_z = gripper_pts_min_z 22 | self.object_max_num_vertices = object_max_num_vertices 23 | self.object_pts_max_x = object_pts_max_x 24 | self.object_pts_min_x = object_pts_min_x 25 | self.object_pts_max_y = object_pts_max_y 26 | self.object_pts_min_y = object_pts_min_y 27 | self.object_pts_max_z = object_pts_max_z 28 | self.object_pts_min_z = object_pts_min_z 29 | self.data_files = [] 30 | for root, dirs, files in os.walk(dataset_dir): 31 | for file in files: 32 | if file.endswith('.npz'): 33 | self.data_files.append(os.path.join(root, file)) 34 | self.object_pts = {} # used for caching object points 35 | self.object_mesh_dir = object_mesh_dir 36 | 37 | def __len__(self): 38 | return len(self.data_files) 39 | 40 | def __getitem__(self, idx): 41 | data = np.load(self.data_files[idx], allow_pickle=True)['arr_0'].item() 42 | # normalize with std (already zero-mean) 43 | train_scores = np.stack([data['delta_theta']/self.std[0], data['delta_pos'][:, 0]/self.std[1], data['delta_pos'][:, 1]/self.std[2]], axis=1) 44 | train_scores = torch.from_numpy(train_scores).float() 45 | train_ctrlpts = data['ctrlpts'] 46 | train_ctrlpts[..., 0] = (train_ctrlpts[..., 0] - self.gripper_pts_min_x) / (self.gripper_pts_max_x - self.gripper_pts_min_x) * 2.0 - 1.0 47 | train_ctrlpts[..., 1] = (train_ctrlpts[..., 1] - self.gripper_pts_min_y) / (self.gripper_pts_max_y - self.gripper_pts_min_y) * 2.0 - 1.0 48 | if self.fingers_3d: 49 | train_ctrlpts[..., 2] = (train_ctrlpts[..., 2] - self.gripper_pts_min_z) / (self.gripper_pts_max_z - self.gripper_pts_min_z) * 2.0 - 1.0 50 | train_ctrlpts = torch.from_numpy(train_ctrlpts).float() 51 | train_input_ori = data['obj_theta'] / np.pi - 1.0 52 | train_input_pos = data['obj_pos'][..., :2] / 0.03 53 | train_input_ori = torch.from_numpy(train_input_ori).float() 54 | train_input_pos = torch.from_numpy(train_input_pos).float() 55 | if self.fingers_3d: 56 | object_name = data['object_name'] 57 | if object_name not in self.object_pts.keys(): 58 | mesh_file = os.path.join(self.object_mesh_dir, object_name, 'model.obj') 59 | object_vertices = sample_pts_from_mesh(mesh_file, self.object_max_num_vertices) 60 | object_vertices[..., 0] = (object_vertices[..., 0] - self.object_pts_min_x) / (self.object_pts_max_x - self.object_pts_min_x) * 2.0 - 1.0 61 | object_vertices[..., 1] = (object_vertices[..., 1] - self.object_pts_min_y) / (self.object_pts_max_y - self.object_pts_min_y) * 2.0 - 1.0 62 | object_vertices[..., 2] = (object_vertices[..., 2] - self.object_pts_min_z) / (self.object_pts_max_z - self.object_pts_min_z) * 2.0 - 1.0 63 | self.object_pts[object_name] = object_vertices 64 | else: 65 | object_vertices = self.object_pts[object_name] 66 | object_vertices = torch.from_numpy(object_vertices).float() 67 | else: 68 | object_vertices = data['object_vertices'] 69 | object_vertices[..., 0] = (object_vertices[..., 0] - self.object_pts_min_x) / (self.object_pts_max_x - self.object_pts_min_x) * 2.0 - 1.0 70 | object_vertices[..., 1] = (object_vertices[..., 1] - self.object_pts_min_y) / (self.object_pts_max_y - self.object_pts_min_y) * 2.0 - 1.0 71 | object_vertices = 
torch.from_numpy(object_vertices).float() 72 | object_vertices = torch.cat([object_vertices, torch.zeros(self.object_max_num_vertices - object_vertices.shape[0], 2)], dim=0) 73 | return { 74 | 'ctrlpts': train_ctrlpts, 75 | 'scores': train_scores, 76 | 'input_ori': train_input_ori, 77 | 'input_pos': train_input_pos, 78 | 'object_vertices': object_vertices, 79 | } -------------------------------------------------------------------------------- /dynamics/profile_forward_2d.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | class Embedder: 6 | def __init__(self, **kwargs): 7 | self.kwargs = kwargs 8 | self.create_embedding_fn() 9 | 10 | def create_embedding_fn(self): 11 | """ 12 | Embeds x to (x, sin(2^k x), cos(2^k x), ...) 13 | """ 14 | embed_fns = [] 15 | d = self.kwargs['input_dims'] 16 | out_dim = 0 17 | if self.kwargs['include_input']: # original raw input "x" is also included in the output 18 | embed_fns.append(lambda x: x) 19 | out_dim += d 20 | 21 | max_freq = self.kwargs['max_freq_log2'] 22 | N_freqs = self.kwargs['num_freqs'] 23 | 24 | if self.kwargs['log_sampling']: 25 | freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs) 26 | else: 27 | freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, steps=N_freqs) 28 | 29 | for freq in freq_bands: 30 | for p_fn in self.kwargs['periodic_fns']: 31 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 32 | out_dim += d 33 | 34 | self.embed_fns = embed_fns 35 | self.out_dim = out_dim 36 | 37 | def embed(self, inputs): 38 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 39 | 40 | 41 | def get_embedder(input_dims, multires, i=0, scalar_factor=1): 42 | if i == -1: 43 | return nn.Identity(), 3 44 | 45 | embed_kwargs = { 46 | 'include_input': True, 47 | 'input_dims': input_dims, 48 | 'max_freq_log2': multires - 1, 49 | 'num_freqs': multires, 50 | 'log_sampling': True, 51 | 'periodic_fns': [torch.sin, torch.cos], 52 | } 53 | 54 | embedder_obj = Embedder(**embed_kwargs) 55 | embed = lambda x, eo=embedder_obj: eo.embed(x/scalar_factor) 56 | return embed, embedder_obj.out_dim 57 | 58 | def timestep_embedding(timesteps, dim, max_period=10000): 59 | """ 60 | Create sinusoidal timestep embeddings. 61 | 62 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 63 | These may be fractional. 64 | :param dim: the dimension of the output. 65 | :param max_period: controls the minimum frequency of the embeddings. 66 | :return: an [N x dim] Tensor of positional embeddings. 
67 | """ 68 | half = dim // 2 69 | freqs = torch.exp( 70 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 71 | ).to(device=timesteps.device) 72 | args = timesteps[:, None].float() * freqs[None] 73 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 74 | if dim % 2: 75 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 76 | return embedding 77 | 78 | class ProfileForward2DModel(nn.Module): 79 | def __init__(self, W=256, params_ch=400, ori_ch=1, pos_ch=2, output_ch=3, object_ch=20): 80 | super(ProfileForward2DModel, self).__init__() 81 | self.W = W 82 | self.params_ch = params_ch 83 | self.ori_ch = ori_ch 84 | self.pos_ch = pos_ch 85 | self.output_ch = output_ch 86 | self.ori_embed, ori_embed_dim = get_embedder(ori_ch, 4, 0, scalar_factor=1) 87 | self.pos_embed, pos_embed_dim = get_embedder(pos_ch, 4, 0, scalar_factor=1) 88 | self.ori_ch = ori_embed_dim 89 | self.pos_ch = pos_embed_dim 90 | self.pose_embed_dim = ori_embed_dim + pos_embed_dim 91 | self.time_embed_dim = W 92 | self.time_encoder = nn.Sequential( 93 | nn.Linear(W // 2, self.time_embed_dim), 94 | nn.SiLU(), 95 | nn.Linear(self.time_embed_dim, self.time_embed_dim), 96 | ) 97 | self.object_encode_dim = W 98 | self.object_encoder = nn.Sequential( 99 | nn.Linear(object_ch, self.object_encode_dim), 100 | nn.ReLU(), 101 | nn.Linear(self.object_encode_dim, self.object_encode_dim), 102 | ) 103 | self.gripper_encoder = nn.Sequential( 104 | nn.Linear(params_ch, W), 105 | nn.ReLU(), 106 | nn.Linear(W, W), 107 | ) 108 | self.gripper_encode_dim = W 109 | self.linears = nn.Sequential( 110 | nn.Linear(self.gripper_encode_dim + self.pose_embed_dim + self.time_embed_dim + self.object_encode_dim, W), 111 | nn.BatchNorm1d(W), 112 | nn.ReLU(), 113 | nn.Linear(W, W), 114 | nn.BatchNorm1d(W), 115 | nn.ReLU(), 116 | nn.Linear(W, W), 117 | nn.BatchNorm1d(W), 118 | nn.ReLU(), 119 | nn.Linear(W, W), 120 | nn.BatchNorm1d(W), 121 | nn.ReLU(), 122 | nn.Linear(W, W), 123 | nn.BatchNorm1d(W), 124 | nn.ReLU(), 125 | nn.Linear(W, W), 126 | nn.BatchNorm1d(W), 127 | nn.ReLU(), 128 | nn.Linear(W, W), 129 | nn.BatchNorm1d(W), 130 | nn.ReLU(), 131 | nn.Linear(W, W), 132 | nn.BatchNorm1d(W), 133 | nn.ReLU(), 134 | ) 135 | self.output = nn.Linear(W, output_ch) 136 | 137 | def forward(self, x_ctrl, x_ori, x_pos, timesteps, object_vertices): 138 | ''' 139 | input: 140 | ctrlpts [batch_size, 400] 141 | ori [batch_size, 1] 142 | pos [batch_size, 2] 143 | timesteps [batch_size,] 144 | object_vertices [batch_size, 20] 145 | output: 146 | profile [batch_size, 9] 147 | ''' 148 | x_ctrl = self.gripper_encoder(x_ctrl) 149 | x_ori = self.ori_embed(x_ori) 150 | x_pos = self.pos_embed(x_pos) 151 | x_pose = torch.cat([x_ori, x_pos], dim=1) 152 | x_object = self.object_encoder(object_vertices) 153 | time_emb = self.time_encoder(timestep_embedding(timesteps, self.W // 2)) 154 | x = self.linears(torch.cat([x_object, x_ctrl, x_pose, time_emb], dim=1)) 155 | x = self.output(x) 156 | return x 157 | -------------------------------------------------------------------------------- /dynamics/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import typing 3 | import sys 4 | from os.path import join as pjoin 5 | BASEPATH = os.path.dirname(__file__) 6 | sys.path.insert(0, BASEPATH) 7 | sys.path.insert(0, pjoin(BASEPATH, '..')) 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from dynamics.profile_forward_3d import ProfileForward3DModel 13 | from 
dynamics.profile_forward_2d import ProfileForward2DModel 14 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler 15 | 16 | class Trainer(object): 17 | def __init__(self, args): 18 | super(Trainer, self).__init__() 19 | self.use_sub_batch = args.use_sub_batch 20 | self.sub_batch_size = args.sub_bs 21 | self.grid_size = args.grid_size 22 | self.learning_rate = args.learning_rate 23 | self.weight_decay = args.weight_decay 24 | self.num_epochs = args.num_epochs 25 | self.ckpt_path = args.checkpoint_path 26 | self.fingers_3d = args.fingers_3d 27 | if self.fingers_3d: 28 | self.gripperpts_dim = args.ctrlpts_dim 29 | self.object_vertices_dim = args.object_max_num_vertices 30 | else: 31 | self.gripperpts_dim = args.ctrlpts_dim 32 | self.object_vertices_dim = 2*args.object_max_num_vertices 33 | self.loss_fn = nn.MSELoss() 34 | self.num_timesteps_per_batch = args.num_timesteps_per_batch 35 | self.num_inference_steps = args.num_inference_steps 36 | self.noise_scheduler = DDIMScheduler(num_train_timesteps=args.num_train_timesteps,beta_schedule='squaredcos_cap_v2', clip_sample=True, prediction_type='epsilon') # squared cosine beta schedule 37 | self.noise_scheduler.set_timesteps(self.num_inference_steps) 38 | 39 | def create_model(self): 40 | if self.fingers_3d: 41 | self.model = nn.DataParallel(ProfileForward3DModel(output_ch=3, params_ch=self.gripperpts_dim).cuda()) 42 | else: 43 | self.model = nn.DataParallel(ProfileForward2DModel(output_ch=3, params_ch=self.gripperpts_dim, object_ch=self.object_vertices_dim).cuda()) 44 | for param in self.model.parameters(): 45 | param.requires_grad = True 46 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate, betas=(0.9, 0.95), weight_decay=self.weight_decay) 47 | self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.num_epochs, eta_min=1e-2*self.learning_rate) 48 | if self.ckpt_path is not None: 49 | print('loading checkpoint from', self.ckpt_path) 50 | self.model.load_state_dict(torch.load(self.ckpt_path)) 51 | print('done') 52 | 53 | def step(self, ctrl, score, input_ori=None, input_pos=None, object_vertices=None): 54 | self.model.train() 55 | if self.fingers_3d: 56 | input_ctrl_all = ctrl.repeat(self.num_timesteps_per_batch, 1, 1) # already normalized to [-1,1] 57 | object_vertices_all = object_vertices.repeat(self.num_timesteps_per_batch, 1, 1) 58 | else: 59 | input_ctrl_all = ctrl.repeat(self.num_timesteps_per_batch, 1) # already normalized to [-1,1] 60 | object_vertices_all = object_vertices.repeat(self.num_timesteps_per_batch, 1) 61 | input_ori_all = input_ori.repeat(self.num_timesteps_per_batch, 1) 62 | input_pos_all = input_pos.repeat(self.num_timesteps_per_batch, 1) 63 | score_all = score.repeat(self.num_timesteps_per_batch, 1) 64 | 65 | # sample noise to add 66 | if self.fingers_3d: 67 | noise = torch.cat([torch.zeros((input_ctrl_all.shape[0]*self.num_timesteps_per_batch, 1, self.gripperpts_dim)), torch.randn((input_ctrl_all.shape[0]*self.num_timesteps_per_batch, 1, self.gripperpts_dim)), torch.zeros((input_ctrl_all.shape[0]*self.num_timesteps_per_batch, 1, self.gripperpts_dim))], dim=1).cuda() 68 | else: 69 | noise = torch.randn((input_ctrl_all.shape[0]*self.num_timesteps_per_batch, self.gripperpts_dim),).cuda() 70 | timesteps = torch.randint( 71 | 0, 72 | self.noise_scheduler.config.num_train_timesteps, # type: ignore 73 | (input_ctrl_all.shape[0],), 74 | ).long().cuda() 75 | noisy_ctrl_all = self.noise_scheduler.add_noise( 76 | 
original_samples=typing.cast(torch.FloatTensor, input_ctrl_all), 77 | noise=typing.cast(torch.FloatTensor, noise), 78 | timesteps=typing.cast(torch.IntTensor, timesteps), 79 | ) 80 | timesteps = timesteps.float() / self.noise_scheduler.config.num_train_timesteps # rescale to [0,1] 81 | if self.use_sub_batch: 82 | all_loss = 0.0 83 | all_pred = [] 84 | for i in range(0, noisy_ctrl_all.shape[0], self.sub_batch_size): 85 | pred = self.model(noisy_ctrl_all[i:i+self.sub_batch_size], input_ori_all[i:i+self.sub_batch_size], input_pos_all[i:i+self.sub_batch_size], timesteps[i:i+self.sub_batch_size], 86 | object_vertices=object_vertices_all[i:i+self.sub_batch_size]) 87 | loss = self.loss_fn(pred, score_all[i:i+self.sub_batch_size]) 88 | all_loss += loss.item() 89 | all_pred.append(pred.detach()) 90 | self.optimizer.zero_grad() 91 | loss.backward() 92 | self.optimizer.step() 93 | all_loss /= (noisy_ctrl_all.shape[0] / self.sub_batch_size) 94 | all_pred = torch.cat(all_pred, dim=0) 95 | return all_loss, all_pred 96 | else: 97 | pred = self.model(noisy_ctrl_all, input_ori_all, input_pos_all, timesteps=timesteps, object_vertices=object_vertices_all) 98 | loss = self.loss_fn(pred, score) 99 | 100 | self.optimizer.zero_grad() 101 | loss.backward() 102 | self.optimizer.step() 103 | return loss.item(), pred.detach() 104 | 105 | def save_checkpoint(self, checkpoint_path): 106 | torch.save(self.model.state_dict(), checkpoint_path) 107 | 108 | def inference(self, ctrl, score, input_ori=None, input_pos=None, object_vertices=None): 109 | self.model.eval() 110 | with torch.no_grad(): 111 | if self.fingers_3d: 112 | input_ctrl_all = ctrl.repeat(self.num_timesteps_per_batch, 1, 1) 113 | object_vertices_all = object_vertices.repeat(self.num_timesteps_per_batch, 1, 1) 114 | else: 115 | input_ctrl_all = ctrl.repeat(self.num_timesteps_per_batch, 1) # already normalized to [-1,1] 116 | object_vertices_all = object_vertices.repeat(self.num_timesteps_per_batch, 1) 117 | input_ori_all = input_ori.repeat(self.num_timesteps_per_batch, 1) 118 | input_pos_all = input_pos.repeat(self.num_timesteps_per_batch, 1) 119 | score_all = score.repeat(self.num_timesteps_per_batch, 1) 120 | 121 | # sample noise to add 122 | if self.fingers_3d: 123 | noise = torch.cat([torch.zeros((input_ctrl_all.shape[0]*self.num_timesteps_per_batch, 1, self.gripperpts_dim)), torch.randn((input_ctrl_all.shape[0]*self.num_timesteps_per_batch, 1, self.gripperpts_dim)), torch.zeros((input_ctrl_all.shape[0]*self.num_timesteps_per_batch, 1, self.gripperpts_dim))], dim=1).cuda() 124 | else: 125 | noise = torch.randn((input_ctrl_all.shape[0]*self.num_timesteps_per_batch, self.gripperpts_dim),).cuda() 126 | timesteps = torch.randint( 127 | 0, 128 | self.noise_scheduler.config.num_train_timesteps, # type: ignore 129 | (input_ctrl_all.shape[0],), 130 | ).long().cuda() 131 | noisy_ctrl_all = self.noise_scheduler.add_noise( 132 | original_samples=typing.cast(torch.FloatTensor, input_ctrl_all), 133 | noise=typing.cast(torch.FloatTensor, noise), 134 | timesteps=typing.cast(torch.IntTensor, timesteps), 135 | ) 136 | timesteps = timesteps.float() / self.noise_scheduler.config.num_train_timesteps 137 | if self.use_sub_batch: 138 | all_loss = 0.0 139 | all_pred = [] 140 | for i in range(0, noisy_ctrl_all.shape[0], self.sub_batch_size): 141 | pred = self.model(noisy_ctrl_all[i:i+self.sub_batch_size], input_ori_all[i:i+self.sub_batch_size], input_pos_all[i:i+self.sub_batch_size], timesteps[i:i+self.sub_batch_size], 
object_vertices=object_vertices_all[i:i+self.sub_batch_size]) 142 | loss = self.loss_fn(pred, score_all[i:i+self.sub_batch_size]) 143 | all_loss += loss.item() 144 | all_pred.append(pred.detach()) 145 | all_loss /= (noisy_ctrl_all.shape[0] / self.sub_batch_size) 146 | all_pred = torch.cat(all_pred, dim=0) 147 | return all_pred, all_loss 148 | else: 149 | pred = self.model(noisy_ctrl_all, input_ori_all, input_pos_all, timesteps, object_vertices=object_vertices_all) 150 | loss = self.loss_fn(pred, score) 151 | return pred, loss.item() -------------------------------------------------------------------------------- /assets/finger_sampler.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from scipy.interpolate import CubicSpline 4 | import trimesh 5 | import xml.etree.ElementTree as ET 6 | 7 | def generate_finger_shape(x, y, width, height, num_points=100): 8 | # Create spline (cubic curve, also degree=3 b-spline) 9 | cs = CubicSpline(x, y) 10 | x_new = np.linspace(x.min(), x.max(), num_points) 11 | y_new = cs(x_new) 12 | z = np.zeros_like(x_new) 13 | vertices_2d = np.stack([x_new, y_new, z], axis=-1) 14 | 15 | # Extrude 16 | vertices_3d = np.concatenate([ 17 | vertices_2d, 18 | vertices_2d + [0, width, 0], 19 | vertices_2d + [0, width, height], 20 | vertices_2d + [0, 0, height] 21 | ]) 22 | 23 | # Create faces 24 | bottom = [[i+num_points, i+num_points+1, i+1, i] for i in range(num_points-1)] 25 | top = [[i+2*num_points, i+3*num_points, i+3*num_points+1, i+2*num_points+1] for i in range(num_points-1)] 26 | left = [[i, i+1, i+3*num_points+1, i+3*num_points] for i in range(num_points-1)] 27 | right = [[i+2*num_points, i+2*num_points+1, i+num_points+1, i+num_points] for i in range(num_points-1)] 28 | front = [[3*num_points, 2*num_points, num_points, 0]] 29 | back = [[num_points-1, 2*num_points-1, 3*num_points-1, 4*num_points-1]] 30 | 31 | faces_3d = left + right + front + back + top + bottom 32 | 33 | # Create mesh 34 | mesh = trimesh.Trimesh(vertices=vertices_3d, faces=faces_3d) 35 | 36 | return mesh, x_new, y_new 37 | 38 | def generate_gripper(finger_x, finger_yl, finger_yr, num_points): 39 | cs_l = CubicSpline(finger_x, finger_yl) 40 | x_new = np.linspace(finger_x.min(), finger_x.max(), num_points) 41 | y_new_l = cs_l(x_new) 42 | cs_r = CubicSpline(finger_x, finger_yr) 43 | y_new_r = cs_r(x_new) 44 | ctrlptsl = np.stack([finger_x, finger_yl], axis=-1) 45 | ctrlptsr = np.stack([finger_x, finger_yr], axis=-1) 46 | ctrlpts = np.concatenate((ctrlptsl, ctrlptsr), axis=0) 47 | allptsl = np.stack([x_new, y_new_l], axis=-1) 48 | allptsr = np.stack([x_new, y_new_r], axis=-1) 49 | allpts = np.concatenate((allptsl, allptsr), axis=0) 50 | return ctrlpts, allpts 51 | 52 | def save_gripper(finger_x, finger_yl, finger_yr, width, height, num_points, save_gripper_dir): 53 | os.makedirs(save_gripper_dir, exist_ok=True) 54 | meshl, x_new_l, y_new_l = generate_finger_shape(finger_x, finger_yl, width, height, num_points) 55 | meshl.export(os.path.join(save_gripper_dir, 'fingerl.obj')) 56 | meshr, x_new_r, y_new_r = generate_finger_shape(finger_x, finger_yr, width, height, num_points) 57 | meshr.export(os.path.join(save_gripper_dir, 'fingerr.obj')) 58 | ctrlptsl = np.stack([finger_x, finger_yl], axis=-1) 59 | ctrlptsr = np.stack([finger_x, finger_yr], axis=-1) 60 | ctrlpts = np.concatenate((ctrlptsl, ctrlptsr), axis=0) 61 | allptsl = np.stack([x_new_l, y_new_l], axis=-1) 62 | allptsr = np.stack([x_new_r, y_new_r], axis=-1) 63 | allpts = 
np.concatenate((allptsl, allptsr), axis=0) 64 | return ctrlpts, allpts 65 | 66 | def create_mesh_elements(num_meshes, mesh_prefix, gripper_idx): 67 | """ Create mesh elements for a given prefix and number of meshes. """ 68 | return [ET.Element("mesh", name=f"{mesh_prefix}{i:03d}", file=f"grippers/{gripper_idx}/{mesh_prefix}{i:03d}.obj") 69 | for i in range(num_meshes)] 70 | 71 | def create_geom_elements(num_meshes, mesh_prefix): 72 | """ Create geom elements for a given prefix and number of meshes. """ 73 | return [ET.Element("geom", mesh=f"{mesh_prefix}{i:03d}", type="mesh", attrib={"class": "collision"}) 74 | for i in range(num_meshes)] 75 | 76 | def generate_xml_optimized(left_num_collision_meshes, right_num_collision_meshes, gripper_idx, save_path): 77 | root = ET.Element("mujoco", model="gripper_2d") 78 | asset = ET.SubElement(root, "asset") 79 | 80 | # Creating mesh elements for left and right 81 | left_meshes = create_mesh_elements(left_num_collision_meshes, "fingerl", gripper_idx) 82 | right_meshes = create_mesh_elements(right_num_collision_meshes, "fingerr", gripper_idx) 83 | asset.extend([ET.Element("mesh", name="fingerl", file=f"grippers/{gripper_idx}/fingerl.obj"), 84 | ET.Element("mesh", name="fingerr", file=f"grippers/{gripper_idx}/fingerr.obj")] + left_meshes + right_meshes) 85 | 86 | default = ET.SubElement(root, "default") 87 | ET.SubElement(default, "joint", type="slide", axis="0 1 0", damping="1") 88 | 89 | worldbody = ET.SubElement(root, "worldbody") 90 | fingers = ET.SubElement(worldbody, "body", name="fingers", pos="0 0 0") 91 | 92 | # Left jaw and its geometries 93 | left_jaw = ET.SubElement(fingers, "body", name="left_jaw", pos="0 -0.15 0") 94 | ET.SubElement(left_jaw, "joint", name="left_grip") 95 | fingerl = ET.SubElement(left_jaw, "geom", mesh="fingerl", type="mesh", attrib={"class": "visual"}) 96 | left_jaw.extend(create_geom_elements(left_num_collision_meshes, "fingerl")) 97 | 98 | # Right jaw and its geometries 99 | right_jaw = ET.SubElement(fingers, "body", name="right_jaw", pos="0 0.15 0") 100 | ET.SubElement(right_jaw, "joint", name="right_grip") 101 | fingerr = ET.SubElement(right_jaw, "geom", mesh="fingerr", type="mesh", attrib={"class": "visual"}) 102 | right_jaw.extend(create_geom_elements(right_num_collision_meshes, "fingerr")) 103 | 104 | actuator = ET.SubElement(root, "actuator") 105 | left_act = ET.SubElement(actuator, "position", name="left", joint="left_grip", ctrlrange="0 0.1", kp="10") 106 | right_act = ET.SubElement(actuator, "position", name="right", joint="right_grip", ctrlrange="-0.1 0", kp="10") 107 | 108 | tree = ET.ElementTree(root) 109 | tree.write(save_path) 110 | 111 | def generate_xml(left_num_collision_meshes, right_num_collision_meshes, gripper_idx, save_path): 112 | root = ET.Element("mujoco", model="gripper_2d") 113 | asset = ET.SubElement(root, "asset") 114 | # Creating mesh elements for left and right 115 | left_meshes = create_mesh_elements(left_num_collision_meshes, "fingerl", gripper_idx) 116 | right_meshes = create_mesh_elements(right_num_collision_meshes, "fingerr", gripper_idx) 117 | asset.extend([ET.Element("mesh", name="fingerl", file=f"grippers/{gripper_idx}/fingerl.obj"), 118 | ET.Element("mesh", name="fingerr", file=f"grippers/{gripper_idx}/fingerr.obj")] + left_meshes + right_meshes) 119 | 120 | default = ET.SubElement(root, "default") 121 | ET.SubElement(default, "joint", type="slide", axis="0 1 0", damping="1") 122 | 123 | worldbody = ET.SubElement(root, "worldbody") 124 | fingers = ET.SubElement(worldbody, "body", 
name="fingers", pos="0 0 0") 125 | 126 | left_jaw = ET.SubElement(fingers, "body", name="left_jaw", pos="0 -0.15 0") 127 | ET.SubElement(left_jaw, "joint", name="left_grip") 128 | fingerl = ET.SubElement(left_jaw, "geom", mesh="fingerl", type="mesh") 129 | fingerl.set("class", "visual") 130 | 131 | for i in range(left_num_collision_meshes): 132 | fingerl_c = ET.SubElement(left_jaw, "geom", mesh=f"fingerl{i:03d}", type="mesh") 133 | fingerl_c.set("class", "collision") 134 | 135 | right_jaw = ET.SubElement(fingers, "body", name="right_jaw", pos="0 0.15 0") 136 | ET.SubElement(right_jaw, "joint", name="right_grip") 137 | fingerr = ET.SubElement(right_jaw, "geom", mesh="fingerr", type="mesh") 138 | fingerr.set("class", "visual") 139 | for i in range(right_num_collision_meshes): 140 | fingerr_c = ET.SubElement(right_jaw, "geom", mesh=f"fingerr{i:03d}", type="mesh") 141 | fingerr_c.set("class", "collision") 142 | 143 | actuator = ET.SubElement(root, "actuator") 144 | left_act = ET.SubElement(actuator, "position", name="left", joint="left_grip") 145 | left_act.set("ctrlrange", "0 0.1") 146 | left_act.set("kp", "10") 147 | right_act = ET.SubElement(actuator, "position", name="right", joint="right_grip") 148 | right_act.set("ctrlrange", "-0.1 0") 149 | right_act.set("kp", "10") 150 | 151 | tree = ET.ElementTree(root) 152 | tree.write(save_path) 153 | 154 | def generate_scene_xml(object_idx, gripper_idx, save_path): 155 | root = ET.Element("mujoco", model="scene") 156 | 157 | defaults = ET.SubElement(root, "default") 158 | 159 | # Add collision default 160 | collision_default = ET.SubElement(defaults, "default", {"class": "collision"}) 161 | ET.SubElement(collision_default, "geom", group="3", condim="4", friction="1.0 0.005 0.0001") 162 | 163 | # Add visual default 164 | visual_default = ET.SubElement(defaults, "default", {"class": "visual"}) 165 | ET.SubElement(visual_default, "geom", group="2", contype="0", conaffinity="0") 166 | 167 | # Include external XML files 168 | ET.SubElement(root, "include", file="object_%d.xml" % object_idx) 169 | ET.SubElement(root, "include", file="gripper_%d.xml" % gripper_idx) 170 | 171 | # Create worldbody and its child elements 172 | worldbody = ET.SubElement(root, "worldbody") 173 | body = ET.SubElement(worldbody, "body", name="plane", pos="0 0 -0.01") 174 | ET.SubElement(body, "geom", type="plane", size="1 1 0.1", rgba="1.0 1.0 1.0 1") 175 | 176 | tree = ET.ElementTree(root) 177 | tree.write(save_path) -------------------------------------------------------------------------------- /assets/finger_3d.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os 3 | import numpy as np 4 | from geomdl import BSpline 5 | from geomdl import utilities 6 | from geomdl import exchange 7 | import trimesh 8 | import time 9 | import xml.etree.ElementTree as ET 10 | 11 | TMP_DIR = './tmp' 12 | 13 | def generate_3d_finger_shape(control_points, degree_u=3, degree_v=2, sample_size=100): 14 | """Generate 3D finger shape from 3D control points. 15 | 16 | Args: 17 | control_points (np.array): Control points of the finger shape. 18 | degree_u (int, optional): Degree of the Bezier surface in u-direction. Defaults to 3. 19 | degree_v (int, optional): Degree of the Bezier surface in v-direction. Defaults to 2. 20 | sample_size (int, optional): Number of samples. Defaults to 100. 
21 | """ 22 | surf = BSpline.Surface() 23 | surf.degree_u = degree_u 24 | surf.degree_v = degree_v 25 | surf.set_ctrlpts(control_points, 7, 3) 26 | surf.knotvector_u = utilities.generate_knot_vector(surf.degree_u, surf.ctrlpts_size_u) 27 | surf.knotvector_v = utilities.generate_knot_vector(surf.degree_v, surf.ctrlpts_size_v) 28 | surf.sample_size = sample_size 29 | surf.evaluate() 30 | os.makedirs(TMP_DIR, exist_ok=True) 31 | tmp_file = os.path.join(TMP_DIR, '%f.obj' % time.time()) 32 | exchange.export_obj(surf, tmp_file) 33 | mesh = trimesh.load(tmp_file) 34 | vertices = mesh.vertices 35 | faces = mesh.faces 36 | return vertices, faces 37 | 38 | def generate_3d_finger_mesh(control_points, degree_u=3, degree_v=2, sample_size=25, width=0.12): 39 | surf_vertices, surf_faces = generate_3d_finger_shape(control_points, degree_u, degree_v, sample_size) 40 | num_surf_vertices = surf_vertices.shape[0] 41 | all_vertices = np.concatenate([ 42 | surf_vertices, 43 | surf_vertices + [0, width, 0] 44 | ]) 45 | surf_contour_indices = np.concatenate([np.arange(sample_size-1), np.arange(sample_size-1, sample_size**2-sample_size, sample_size), np.arange(sample_size**2-1, sample_size**2-sample_size, -1), np.arange(sample_size**2-sample_size, 0, -sample_size)]) 46 | side_faces_upper = np.stack([surf_contour_indices, np.roll(surf_contour_indices, -1), np.roll(surf_contour_indices, -1)+num_surf_vertices], axis=-1) 47 | side_faces_lower = np.stack([surf_contour_indices, np.roll(surf_contour_indices, -1)+num_surf_vertices, surf_contour_indices+num_surf_vertices], axis=-1) 48 | bottom_faces = surf_faces + num_surf_vertices 49 | bottom_faces[:, [1, 2]] = bottom_faces[:, [2, 1]] 50 | all_faces = np.concatenate([ 51 | surf_faces, 52 | bottom_faces, 53 | side_faces_upper, 54 | side_faces_lower, 55 | ]) 56 | mesh = trimesh.Trimesh(vertices=all_vertices, faces=all_faces) 57 | return mesh, surf_vertices 58 | 59 | def generate_3d_finger_vertices(control_points, degree_u=3, degree_v=2, sample_size=25): 60 | surf = BSpline.Surface() 61 | surf.degree_u = degree_u 62 | surf.degree_v = degree_v 63 | surf.set_ctrlpts(control_points, 7, 3) 64 | surf.knotvector_u = utilities.generate_knot_vector(surf.degree_u, surf.ctrlpts_size_u) 65 | surf.knotvector_v = utilities.generate_knot_vector(surf.degree_v, surf.ctrlpts_size_v) 66 | surf.sample_size = sample_size 67 | return np.array(surf.evalpts).reshape(-1, 3) 68 | 69 | def save_3d_gripper(yl, yr, width=0.12, sample_size=25, save_gripper_dir=''): 70 | x = np.linspace(-0.12, 0.12, 7) 71 | z = np.linspace(0, 0.12, 3) 72 | x_n, z_n = np.meshgrid(x, z) 73 | ctrlpts_l = np.stack([x_n.T.reshape(-1), yl, z_n.T.reshape(-1)], axis=-1) 74 | ctrlpts_r = np.stack([x_n.T.reshape(-1), yr, z_n.T.reshape(-1)], axis=-1) 75 | mesh_l, vertices_l = generate_3d_finger_mesh(ctrlpts_l.tolist(), width=width, sample_size=sample_size) 76 | mesh_r, vertices_r = generate_3d_finger_mesh(ctrlpts_r.tolist(), width=width, sample_size=sample_size) 77 | os.makedirs(save_gripper_dir, exist_ok=True) 78 | mesh_l.export(os.path.join(save_gripper_dir, 'fingerl.obj')) 79 | mesh_r.export(os.path.join(save_gripper_dir, 'fingerr.obj')) 80 | return np.concatenate((ctrlpts_l, ctrlpts_r), axis=0), np.concatenate((vertices_l, vertices_r), axis=0) 81 | 82 | def generate_3d_ctrlpts(yl, yr): 83 | x = np.linspace(-0.12, 0.12, 7) 84 | z = np.linspace(0, 0.12, 3) 85 | x_n, z_n = np.meshgrid(x, z) 86 | ctrlpts_l = np.stack([x_n.T.reshape(-1), yl, z_n.T.reshape(-1)], axis=-1) 87 | ctrlpts_r = np.stack([x_n.T.reshape(-1), yr, 
z_n.T.reshape(-1)], axis=-1) 88 | return np.concatenate((ctrlpts_l, ctrlpts_r), axis=0) 89 | 90 | def generate_3d_gripper(yl, yr, sample_size=25): 91 | x = np.linspace(-0.12, 0.12, 7) 92 | z = np.linspace(0, 0.12, 3) 93 | x_n, z_n = np.meshgrid(x, z) 94 | ctrlpts_l = np.stack([x_n.T.reshape(-1), yl, z_n.T.reshape(-1)], axis=-1) 95 | ctrlpts_r = np.stack([x_n.T.reshape(-1), yr, z_n.T.reshape(-1)], axis=-1) 96 | vertices_l, _ = generate_3d_finger_shape(ctrlpts_l.tolist(), sample_size=sample_size) 97 | vertices_r, _ = generate_3d_finger_shape(ctrlpts_r.tolist(), sample_size=sample_size) 98 | return np.concatenate((ctrlpts_l, ctrlpts_r), axis=0), np.concatenate((vertices_l, vertices_r), axis=0) 99 | 100 | def create_mesh_elements(num_meshes, mesh_prefix, gripper_idx): 101 | """ Create mesh elements for a given prefix and number of meshes. """ 102 | return [ET.Element("mesh", name=f"{mesh_prefix}{i:03d}", file=f"grippers/{gripper_idx}/{mesh_prefix}{i:03d}.obj") 103 | for i in range(num_meshes)] 104 | 105 | def create_geom_elements(num_meshes, mesh_prefix): 106 | """ Create geom elements for a given prefix and number of meshes. """ 107 | return [ET.Element("geom", mesh=f"{mesh_prefix}{i:03d}", type="mesh", attrib={"class": "collision"}) 108 | for i in range(num_meshes)] 109 | 110 | def generate_gripper_3d_xml(left_num_collision_meshes, right_num_collision_meshes, gripper_idx, save_path): 111 | root = ET.Element("mujoco", model="gripper_3d") 112 | asset = ET.SubElement(root, "asset") 113 | # Creating mesh elements for left and right 114 | left_meshes = create_mesh_elements(left_num_collision_meshes, "fingerl", gripper_idx) 115 | right_meshes = create_mesh_elements(right_num_collision_meshes, "fingerr", gripper_idx) 116 | asset.extend([ET.Element("mesh", name="fingerl", file=f"grippers/{gripper_idx}/fingerl.obj"), 117 | ET.Element("mesh", name="fingerr", file=f"grippers/{gripper_idx}/fingerr.obj")] + left_meshes + right_meshes) 118 | 119 | default = ET.SubElement(root, "default") 120 | 121 | ET.SubElement(default, "joint", type="slide", axis="0 1 0", damping="1") 122 | 123 | worldbody = ET.SubElement(root, "worldbody") 124 | fingers = ET.SubElement(worldbody, "body", name="fingers", pos="0 0 0") 125 | 126 | left_jaw = ET.SubElement(fingers, "body", name="left_jaw", pos="0 -0.23 0") 127 | ET.SubElement(left_jaw, "joint", name="left_grip") 128 | fingerl = ET.SubElement(left_jaw, "geom", mesh="fingerl", type="mesh", rgba="0.9333 0.7804 0.3490 1") 129 | fingerl.set("class", "visual") 130 | 131 | for i in range(left_num_collision_meshes): 132 | fingerl_c = ET.SubElement(left_jaw, "geom", mesh=f"fingerl{i:03d}", type="mesh") 133 | fingerl_c.set("class", "collision") 134 | 135 | right_jaw = ET.SubElement(fingers, "body", name="right_jaw", pos="0 0.23 0") 136 | ET.SubElement(right_jaw, "joint", name="right_grip") 137 | fingerr = ET.SubElement(right_jaw, "geom", mesh="fingerr", type="mesh", rgba="0.6941 0.7647 0.5059 1") 138 | fingerr.set("class", "visual") 139 | for i in range(right_num_collision_meshes): 140 | fingerr_c = ET.SubElement(right_jaw, "geom", mesh=f"fingerr{i:03d}", type="mesh") 141 | fingerr_c.set("class", "collision") 142 | 143 | actuator = ET.SubElement(root, "actuator") 144 | left_act = ET.SubElement(actuator, "position", name="left", joint="left_grip") 145 | left_act.set("ctrlrange", "0 0.1") 146 | left_act.set("kp", "10") 147 | right_act = ET.SubElement(actuator, "position", name="right", joint="right_grip") 148 | right_act.set("ctrlrange", "-0.1 0") 149 | right_act.set("kp", "10") 150 
| 151 | tree = ET.ElementTree(root) 152 | tree.write(save_path) 153 | 154 | def generate_scene_3d_xml(object_idx, gripper_idx, save_path): 155 | root = ET.Element("mujoco", model="scene") 156 | 157 | defaults = ET.SubElement(root, "default") 158 | 159 | # Add collision default 160 | collision_default = ET.SubElement(defaults, "default", {"class": "collision"}) 161 | ET.SubElement(collision_default, "geom", group="3", condim="4", friction="1.0 0.005 0.0001") 162 | 163 | # Add visual default 164 | visual_default = ET.SubElement(defaults, "default", {"class": "visual"}) 165 | ET.SubElement(visual_default, "geom", group="2", contype="0", conaffinity="0") 166 | 167 | # Include external XML files 168 | ET.SubElement(root, "include", file="object_%d.xml" % object_idx) 169 | ET.SubElement(root, "include", file="gripper_%d.xml" % gripper_idx) 170 | 171 | # Create worldbody and its child elements 172 | worldbody = ET.SubElement(root, "worldbody") 173 | body = ET.SubElement(worldbody, "body", name="plane", pos="0 0 -0.01") 174 | ET.SubElement(body, "geom", type="plane", size="1 1 0.1", rgba="1.0 1.0 1.0 1") 175 | 176 | tree = ET.ElementTree(root) 177 | tree.write(save_path) -------------------------------------------------------------------------------- /generator/train.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning import Trainer as LightningTrainer 2 | from pytorch_lightning.loggers import WandbLogger 3 | from pytorch_lightning.callbacks import ( 4 | LearningRateMonitor, 5 | ModelCheckpoint, 6 | RichProgressBar, 7 | ) 8 | import numpy as np 9 | import os 10 | import sys 11 | from os.path import join as pjoin 12 | 13 | BASEPATH = os.path.dirname(__file__) 14 | sys.path.insert(0, BASEPATH) 15 | sys.path.insert(0, pjoin(BASEPATH, '..')) 16 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler 17 | 18 | from generator.diffusion import Diffusion 19 | from generator.diffusion_utils import ConditionalUnet1D 20 | from generator.dataloader import GripperDataset 21 | from dynamics.parser import parse 22 | from assets.finger_3d import generate_3d_ctrlpts 23 | from assets.icon_process import extract_contours 24 | from assets.scan_object_process import read_object_names 25 | from dynamics.utils import sample_pts_from_mesh 26 | from dynamics.profile_forward_3d import ProfileForward3DModel 27 | from dynamics.profile_forward_2d import ProfileForward2DModel 28 | 29 | import torch 30 | import torch.nn as nn 31 | from torch.utils.data import DataLoader 32 | 33 | torch.multiprocessing.set_sharing_strategy("file_system") 34 | 35 | rank_idx = os.environ.get("NODE_RANK", 0) 36 | OBJECT_IDS = [10000, 2009, 2114, 2082, 1041, 2048, 1045, 1019] 37 | 38 | def train(args): 39 | total_num = args.num_fingers 40 | train_ids = list(range(int(total_num*0.9))) 41 | val_ids = list(range(int(total_num*0.9), total_num)) 42 | gripper_pts = [] 43 | for idx in range(total_num): 44 | rs = np.random.RandomState(idx) 45 | if args.fingers_3d: 46 | yl = rs.uniform(-0.1, 0, size=(21)) 47 | yr = rs.uniform(-0.1, 0, size=(21)) 48 | ctrlpts = generate_3d_ctrlpts(yl, yr) 49 | gripper_pts.append(ctrlpts) 50 | else: 51 | x = np.linspace(-0.12, 0.12, 7) 52 | yl = rs.uniform(-0.045, 0.015, size=(7)) 53 | yr = rs.uniform(-0.045, 0.015, size=(7)) 54 | ctrlptsl = np.stack([x, yl], axis=-1) 55 | ctrlptsr = np.stack([x, yr], axis=-1) 56 | ctrlpts = np.concatenate((ctrlptsl, ctrlptsr), axis=0) 57 | gripper_pts.append(ctrlpts) 58 | gripper_pts = np.stack(gripper_pts, axis=0) 59 | 
gripper_pts_max_x = 0.12 60 | gripper_pts_min_x = -0.12 61 | if args.fingers_3d: 62 | gripper_pts_max_y = 0 63 | gripper_pts_min_y = -0.1 64 | else: 65 | gripper_pts_max_y = 0.015 66 | gripper_pts_min_y = -0.045 67 | if args.mode == 'test': 68 | test_dataset = GripperDataset(gripper_pts, gripper_pts_max_x, gripper_pts_min_x, gripper_pts_max_y, gripper_pts_min_y) 69 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, drop_last=True) 70 | else: 71 | train_dataset = GripperDataset(gripper_pts[train_ids, ...], gripper_pts_max_x, gripper_pts_min_x, gripper_pts_max_y, gripper_pts_min_y) 72 | val_dataset = GripperDataset(gripper_pts[val_ids, ...], gripper_pts_max_x, gripper_pts_min_x, gripper_pts_max_y, gripper_pts_min_y) 73 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, drop_last=False) 74 | val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, drop_last=False) 75 | 76 | input_spline_dim = 1 77 | num_spline_points = args.ctrlpts_dim # for 2d, ctrlpts_dim=14, for 3d, ctrlpts_dim=42 78 | pts_x_dim = args.ctrlpts_x_dim 79 | pts_z_dim = args.ctrlpts_z_dim 80 | unet = ConditionalUnet1D(input_dim=input_spline_dim, global_cond_dim=0, down_dims=[128, 256], diffusion_step_embed_dim=32) 81 | mode = 'point_3d' if args.fingers_3d else 'point' 82 | input_dim = input_spline_dim 83 | scheduler = DDIMScheduler(num_train_timesteps=args.num_train_timesteps, beta_schedule='squaredcos_cap_v2', clip_sample=True, prediction_type='epsilon') # squared cosine beta schedule 84 | if args.classifier_guidance: 85 | if args.fingers_3d: 86 | classifier_model = nn.DataParallel(ProfileForward3DModel(output_ch=3, params_ch=num_spline_points).cuda()) 87 | else: 88 | classifier_model = nn.DataParallel(ProfileForward2DModel(output_ch=3, params_ch=num_spline_points, object_ch=2*args.object_max_num_vertices).cuda()) 89 | print('loading classifier checkpoint from', args.checkpoint_path) 90 | classifier_model.load_state_dict(torch.load(args.checkpoint_path)) 91 | for param in classifier_model.parameters(): 92 | param.requires_grad = False 93 | if args.fingers_3d: 94 | object_pts_max_x = 0.1 95 | object_pts_min_x = -0.1 96 | object_pts_max_y = 0.1 97 | object_pts_min_y = -0.1 98 | object_pts_max_z = 0.12 99 | object_pts_min_z = 0.0 100 | object_ids = read_object_names(test=True) 101 | object_vertices = [] 102 | for object_name in object_ids: 103 | mesh_file = os.path.join(args.object_dir, object_name, 'model.obj') 104 | pts = sample_pts_from_mesh(mesh_file, args.object_max_num_vertices) 105 | object_vertices.append(torch.from_numpy(pts).float()) 106 | object_vertices = torch.stack(object_vertices, dim=0) 107 | object_vertices[..., 0] = (object_vertices[..., 0] - object_pts_min_x) / (object_pts_max_x - object_pts_min_x) * 2.0 - 1.0 108 | object_vertices[..., 1] = (object_vertices[..., 1] - object_pts_min_y) / (object_pts_max_y - object_pts_min_y) * 2.0 - 1.0 109 | object_vertices[..., 2] = (object_vertices[..., 2] - object_pts_min_z) / (object_pts_max_z - object_pts_min_z) * 2.0 - 1.0 110 | else: 111 | object_pts_max_x = 0.05 112 | object_pts_min_x = -0.05 113 | object_pts_max_y = 0.05 114 | object_pts_min_y = -0.05 115 | object_vertices = [] 116 | object_image = np.load(args.object_dir, allow_pickle=True).item()['image'] 117 | object_ids = OBJECT_IDS 118 | for object_idx in object_ids: 119 | contour = 
extract_contours(object_image[object_idx].transpose((1, 2, 0))) 120 | contour = torch.from_numpy(contour).float() 121 | object_vertices.append(contour) 122 | object_vertices = torch.stack(object_vertices, dim=0) 123 | object_vertices[..., 0] = (object_vertices[..., 0] - object_pts_min_x) / (object_pts_max_x - object_pts_min_x) * 2.0 - 1.0 124 | object_vertices[..., 1] = (object_vertices[..., 1] - object_pts_min_y) / (object_pts_max_y - object_pts_min_y) * 2.0 - 1.0 125 | else: 126 | classifier_model = None 127 | object_vertices = None 128 | object_ids = None 129 | diffusion_model = Diffusion(noise_pred_net=unet, noise_scheduler=scheduler, num_inference_steps=args.num_inference_steps, mode=mode, input_dim=input_dim, num_points=num_spline_points, learning_rate=args.learning_rate, lr_warmup_steps=args.lr_warmup_steps, ema_power=args.ema_power, class_cond=args.classifier_guidance, classifier_model=classifier_model, grid_size=args.grid_size, num_pos=args.num_pos, object_vertices=object_vertices, object_ids=object_ids, num_cpus=args.num_cpus, pts_x_dim=pts_x_dim, pts_z_dim=pts_z_dim, sub_batch_size=args.sub_bs, render_video=args.render_video, seed=args.seed) 130 | 131 | os.makedirs(args.save_dir, exist_ok=True) 132 | project_name = 'classifier_guidance_fixed' if args.classifier_guidance else 'gripper_diffusion' 133 | wandb_logger = WandbLogger(project=project_name, log_model='all', save_dir=args.save_dir, name=mode) 134 | callbacks = [] 135 | if rank_idx == 0: 136 | callbacks.extend( 137 | ( 138 | ModelCheckpoint( 139 | dirpath=f"{args.save_dir}/checkpoints/", # type: ignore 140 | filename="{epoch:04d}", 141 | every_n_epochs=1, 142 | save_last=True, 143 | save_top_k=10, 144 | monitor="epoch", 145 | mode="max", 146 | save_weights_only=False, 147 | ), 148 | RichProgressBar(leave=True), 149 | LearningRateMonitor(logging_interval="step"), 150 | ) 151 | ) 152 | trainer = LightningTrainer(accelerator='gpu', devices=-1, check_val_every_n_epoch=args.val_step, log_every_n_steps=1, max_epochs=args.num_epochs, logger=wandb_logger, default_root_dir=args.save_dir, callbacks=callbacks, inference_mode=False) 153 | # shortcut for inference only 154 | if args.mode == 'test': 155 | trainer.validate(diffusion_model, test_loader, ckpt_path=args.diffusion_checkpoint_path) 156 | return 157 | 158 | if args.diffusion_checkpoint_path is not None: 159 | print('loading diffusion checkpoint from', args.diffusion_checkpoint_path) 160 | trainer.fit(diffusion_model, train_loader, val_loader, ckpt_path=args.diffusion_checkpoint_path) 161 | else: 162 | trainer.fit(diffusion_model, train_loader, val_loader) 163 | 164 | if __name__ == "__main__": 165 | args = parse() 166 | train(args) -------------------------------------------------------------------------------- /assets/object_names.txt: -------------------------------------------------------------------------------- 1 | Perricone_MD_AcylGlutathione_Eye_Lid_Serum 2 | Top_Paw_Dog_Bow_Bone_Ceramic_13_fl_oz_total 3 | Schleich_Lion_Action_Figure 4 | SHAPE_MATCHING_NxacpAY9jDt 5 | Cole_Hardware_Saucer_Electric 6 | Office_Depot_Dell_Series_1_Remanufactured_Ink_Cartridge_Black 7 | Perricone_MD_The_Cold_Plasma_Face_Eyes_Duo 8 | Cole_Hardware_Antislip_Surfacing_Material_White 9 | Office_Depot_Canon_CLI_221BK_Ink_Cartridge_Black_2946B001 10 | Shurtape_Tape_Purple_CP28 11 | Magnifying_Glassassrt 12 | BIA_Porcelain_Ramekin_With_Glazed_Rim_35_45_oz_cup 13 | BUNNY_RACER 14 | Phillips_Colon_Health_Probiotic_Capsule 15 | Room_Essentials_Bowl_Turquiose 16 | 
Granimals_20_Wooden_ABC_Blocks_Wagon_85VdSftGsLi 17 | Cole_Hardware_Butter_Dish_Square_Red 18 | Pokémon_Yellow_Special_Pikachu_Edition_Nintendo_Game_Boy_Color 19 | Epson_Ink_Cartridge_126_Yellow 20 | JBL_Charge_Speaker_portable_wireless_wired_Green 21 | Swiss_Miss_Hot_Cocoa_KCups_Milk_Chocolate_12_count 22 | Beta_Glucan 23 | Granimals_20_Wooden_ABC_Blocks_Wagon 24 | 5_HTP 25 | Ecoforms_Planter_Pot_GP12AAvocado 26 | CoQ10_wSSVoxVppVD 27 | Grreatv_Choice_Dog_Bowl_Gray_Bones_Plastic_20_fl_oz_total 28 | Ecoforms_Plant_Container_QP_Harvest 29 | Office_Depot_HP_932XL_Ink_Cartridge_Black_CN053A 30 | Nestle_Candy_19_oz_Butterfinger_Singles_116567 31 | Ecoforms_Plant_Container_FB6_Tur 32 | Shaxon_100_Molded_Category_6_RJ45RJ45_Shielded_Patch_Cord_White 33 | Schleich_Hereford_Bull 34 | Office_Depot_HP_920XL_920_High_Yield_Black_and_Standard_CMY_Color_Ink_Cartridges 35 | Android_Figure_Chrome 36 | BUNNY_RATTLE 37 | OXO_Soft_Works_Can_Opener_SnapLock 38 | Perricone_MD_Hypoallergenic_Firming_Eye_Cream_05_oz 39 | Circo_Fish_Toothbrush_Holder_14995988 40 | Twinlab_Nitric_Fuel 41 | JarroSil_Activated_Silicon_5exdZHIeLAp 42 | Perricoen_MD_No_Concealer_Concealer 43 | FemDophilus 44 | CHICKEN_RACER 45 | Canon_Pixma_Ink_Cartridge_251_M 46 | Perricone_MD_AcylGlutathione_Deep_Crease_Serum 47 | Threshold_Porcelain_Coffee_Mug_All_Over_Bead_White 48 | Spectrum_Wall_Mount 49 | Ecoforms_Plant_Saucer_SQ8COR 50 | Schleich_Bald_Eagle 51 | Sea_to_Summit_Xl_Bowl 52 | Weisshai_Great_White_Shark 53 | Dixie_10_ounce_Bowls_35_ct 54 | Super_Mario_3D_World_Wii_U_Game 55 | Office_Depot_Dell_Series_9_Ink_Cartridge_Black_MK992 56 | PETS_ACCESSORIES 57 | Razer_Taipan_Black_Ambidextrous_Gaming_Mouse 58 | SNAIL_MEASURING_TAPE 59 | Epson_UltraChrome_T0543_Ink_Cartridge_Magenta_1pack 60 | Ecoforms_Plant_Container_QP_Turquoise 61 | TWISTED_PUZZLE 62 | Nestle_Raisinets_Milk_Chocolate_35_oz_992_g 63 | Office_Depot_Dell_Series_5_Remanufactured_Ink_Cartridge_Black 64 | Canon_Pixma_Ink_Cartridge_8_Red 65 | DIM_CDG 66 | Lutein 67 | Central_Garden_Flower_Pot_Goo_425 68 | Sapota_Threshold_4_Ceramic_Round_Planter_Red 69 | Canon_Pixma_Ink_Cartridge_8 70 | Cole_Hardware_Saucer_Glazed_6 71 | Beetle_Adventure_Racing_Nintendo_64 72 | Crayola_Crayons_Washable_24_crayons 73 | Nickelodeon_Teenage_Mutant_Ninja_Turtles_Michelangelo 74 | Nintendo_2DS_Crimson_Red 75 | Office_Depot_Canon_CLI_8Y_Ink_Cartridge_Yellow_0623B002 76 | PhosphOmega 77 | Theanine 78 | Perricone_MD_Cold_Plasma 79 | Ecoforms_Plant_Pot_GP9_SAND 80 | DINNING_ROOM_FURNITURE_SET_1 81 | Guardians_of_the_Galaxy_Galactic_Battlers_Rocket_Raccoon_Figure 82 | Mastic_Gum 83 | Perricone_MD_The_Crease_Cure_Duo 84 | Perricone_MD_Skin_Clear_Supplements 85 | Prostate_Optimizer 86 | GEOMETRIC_PEG_BOARD 87 | Jarrow_Glucosamine_Chondroitin_Combination_120_Caps 88 | Ecoforms_Plant_Pot_GP9AAvocado 89 | Perricone_MD_No_Lipstick_Lipstick 90 | Philips_EcoVantage_43_W_Light_Bulbs_Natural_Light_2_pack 91 | Krill_Oil 92 | Cole_Hardware_Bowl_Scirocco_YellowBlue 93 | Top_Paw_Dog_Bowl_Blue_Paw_Bone_Ceramic_25_fl_oz_total 94 | Threshold_Bead_Cereal_Bowl_White 95 | Android_Figure_Panda 96 | MK7 97 | BAGEL_WITH_CHEESE 98 | Object 99 | BlackBlack_Nintendo_3DSXL 100 | Bifidus_Balance_FOS 101 | Lactoferrin 102 | Schleich_S_Bayala_Unicorn_70432 103 | Luigis_Mansion_Dark_Moon_Nintendo_3DS_Game 104 | Room_Essentials_Kitchen_Towels_16_x_26_2_count 105 | CARSII 106 | Cole_Hardware_Mini_Honey_Dipper 107 | 3D_Dollhouse_Happy_Brother 108 | QAbsorb_CoQ10_53iUqjWjW3O 109 | Ecoforms_Plant_Plate_S11Turquoise 110 | NattoMax 111 | 
Beyonc_Life_is_But_a_Dream_DVD 112 | Hyaluronic_Acid 113 | Ecoforms_Plant_Saucer_S17MOCHA 114 | Shurtape_Gaffers_Tape_Silver_2_x_60_yd 115 | LACING_SHEEP 116 | Crayola_Crayons_24_count 117 | Ecoforms_Quadra_Saucer_SQ1_Avocado 118 | Lalaloopsy_Peanut_Big_Top_Tricycle 119 | LTyrosine 120 | Perricone_MD_Blue_Plasma_Orbital 121 | FIRE_ENGINE 122 | Canon_Pixma_Ink_Cartridge_8_Green 123 | Kingston_DT4000MR_G2_Management_Ready_USB_64GB 124 | Down_To_Earth_Orchid_Pot_Ceramic_Red 125 | CITY_TAXI_POLICE_CAR 126 | Perricone_MD_Face_Finishing_Moisturizer_4_oz 127 | Nickelodeon_Teenage_Mutant_Ninja_Turtles_Leonardo 128 | Twinlab_Premium_Creatine_Fuel_Powder 129 | Schleich_African_Black_Rhino 130 | LADYBUG_BEAD 131 | Threshold_Ramekin_White_Porcelain 132 | Office_Depot_Canon_PGI5BK_Remanufactured_Ink_Cartridge_Black 133 | Pokmon_X_Nintendo_3DS_Game 134 | Paper_Mario_Sticker_Star_Nintendo_3DS_Game 135 | Mario_Luigi_Dream_Team_Nintendo_3DS_Game 136 | Pennington_Electric_Pot_Cabana_4 137 | Germanium_GE132 138 | YumYum_D3_Liquid 139 | Android_Lego 140 | FIRE_TRUCK 141 | Office_Depot_Canon_PG_240XL_Ink_Cartridge_Black_5206B001 142 | 3D_Dollhouse_TablePurple 143 | BIRD_RATTLE 144 | Utana_5_Porcelain_Ramekin_Large 145 | Office_Depot_Dell_Series_9_Color_Ink_Ink_Cartridge_MK991_MK993 146 | JarroSil_Activated_Silicon 147 | COAST_GUARD_BOAT 148 | SANDWICH_MEAL 149 | Phillips_Caplets_Size_24 150 | Thomas_Friends_Woodan_Railway_Henry 151 | Jawbone_UP24_Wireless_Activity_Tracker_Pink_Coral_L 152 | Office_Depot_Canon_CL211XL_Remanufactured_Ink_Cartridge_TriColor 153 | Room_Essentials_Mug_White_Yellow 154 | 3M_Vinyl_Tape_Green_1_x_36_yd 155 | Perricone_MD_No_Bronzer_Bronzer 156 | Perricone_MD_Face_Finishing_Moisturizer 157 | Great_Dinos_Triceratops_Toy 158 | Pet_Dophilus_powder 159 | Perricone_MD_Firming_Neck_Therapy_Treatment 160 | Now_Designs_Bowl_Akita_Black 161 | My_First_Wiggle_Crocodile 162 | Ubisoft_RockSmith_Real_Tone_Cable_Xbox_360 163 | Snack_Catcher_Snack_Dispenser 164 | HP_1800_Tablet_8GB_7 165 | Shurtape_30_Day_Removal_UV_Delct_15 166 | RedBlack_Nintendo_3DSXL 167 | Big_O_Sponges_Assorted_Cellulose_12_pack 168 | ROAD_CONSTRUCTION_SET 169 | Star_Wars_Rogue_Squadron_Nintendo_64 170 | HELICOPTER 171 | BlueBlack_Nintendo_3DSXL 172 | FRACTION_FUN_n4h4qte23QR 173 | Kotex_U_Tween_Pads_16_pads 174 | Thomas_Friends_Wooden_Railway_Porter_5JzRhMm3a9o 175 | Android_Figure_Orange 176 | Squirtin_Barnyard_Friends_4pk 177 | Office_Depot_Canon_PGI22_Remanufactured_Ink_Cartridge_Black 178 | Animal_Crossing_New_Leaf_Nintendo_3DS_Game 179 | CoQ10 180 | Blackcurrant_Lutein 181 | Ecoforms_Plant_Saucer_S14NATURAL 182 | Epson_Ink_Cartridge_Black_200 183 | Thomas_Friends_Wooden_Railway_Talking_Thomas_z7yi7UFHJRj 184 | New_Super_Mario_BrosWii_Wii_Game 185 | Cole_Hardware_Mug_Classic_Blue 186 | Razer_Abyssus_Ambidextrous_Gaming_Mouse 187 | Super_Mario_Kart_Super_Nintendo_Entertainment_System 188 | Canon_Pixma_Ink_Cartridge_Cyan_251 189 | GoPro_HERO3_Composite_Cable 190 | WHALE_WHISTLE_6PCS_SET 191 | Office_Depot_Canon_CL_41_Remanufactured_Ink_Cartridge_TriColor 192 | Mario_Party_9_Wii_Game 193 | Phillips_Stool_Softener_Liquid_Gels_30_liquid_gels 194 | 3M_Antislip_Surfacing_Light_Duty_White 195 | Office_Depot_HP_71_Remanufactured_Ink_Cartridge_Black 196 | Perricone_MD_Photo_Plasma 197 | Office_Depot_Dell_Series_11_Remanufactured_Ink_Cartridge_Black 198 | Office_Depot_Canon_PGI35_Remanufactured_Ink_Cartridge_Black 199 | Epson_T5803_Ink_Cartridge_Magenta_1pack 200 | Office_Depot_HP_564XL_Ink_Cartridge_Black_CN684WN 201 | QAbsorb_CoQ10 
202 | Hey_You_Pikachu_Nintendo_64 203 | Office_Depot_Dell_Series_1_Remanufactured_Ink_Cartridge_TriColor 204 | Kong_Puppy_Teething_Rubber_Small_Pink 205 | Perricone_MD_OVM 206 | Office_Depot_HP_61Tricolor_Ink_Cartridge 207 | Home_Fashions_Washcloth_Olive_Green 208 | Jarrow_Formulas_Glucosamine_Hci_Mega_1000_100_ct 209 | SCHOOL_BUS 210 | Office_Depot_HP_950XL_Ink_Cartridge_Black_CN045AN 211 | Blue_Jasmine_Includes_Digital_Copy_UltraViolet_DVD 212 | Perricone_MD_Neuropeptide_Firming_Moisturizer 213 | Pokmon_Y_Nintendo_3DS_Game 214 | Cole_Hardware_Antislip_Surfacing_White_2_x_60 215 | KITCHEN_FURNITURE_SET_1 216 | Polar_Herring_Fillets_Smoked_Peppered_705_oz_total 217 | Toysmith_Windem_Up_Flippin_Animals_Dog 218 | AllergenFree_JarroDophilus 219 | Down_To_Earth_Ceramic_Orchid_Pot_Asst_Blue 220 | Epson_UltraChrome_T0548_Ink_Cartridge_Matte_Black_1pack 221 | Ecoforms_Plant_Container_S14Turquoise 222 | Kotobuki_Saucer_Dragon_Fly 223 | Office_Depot_HP_74XL75_Remanufactured_Ink_Cartridges_BlackTriColor_2_count 224 | Ecoforms_Plant_Saucer_S14MOCHA 225 | Canon_Pixma_Chromalife_100_Magenta_8 226 | QHPomegranate 227 | PINEAPPLE_MARACA_6_PCSSET 228 | Office_Depot_Canon_PG21XL_Remanufactured_Ink_Cartridge_Black 229 | Ecoforms_Plant_Saucer_SQ1HARVEST 230 | Ultra_JarroDophilus 231 | Nikon_1_AW1_w11275mm_Lens_Silver 232 | TURBOPROP_AIRPLANE_WITH_PILOT 233 | RESCUE_CREW 234 | Ecoforms_Cup_B4_SAN 235 | Tag_Dishtowel_Waffle_Gray_Checks_18_x_26 236 | Home_Fashions_Washcloth_Linen 237 | Perricone_MD_Chia_Serum 238 | Dino_4 239 | Razer_Taipan_White_Ambidextrous_Gaming_Mouse 240 | Nintendo_Yoshi_Action_Figure 241 | Borage_GLA240Gamma_Tocopherol 242 | 45oz_RAMEKIN_ASST_DEEP_COLORS 243 | MINI_ROLLER 244 | Pokmon_Conquest_Nintendo_DS_Game 245 | Nintendo_Mario_Action_Figure 246 | PEEKABOO_ROLLER 247 | MOVING_MOUSE_PW_6PCSSET 248 | Granimals_20_Wooden_ABC_Blocks_Wagon_g2TinmUGGHI 249 | GEOMETRIC_SORTING_BOARD 250 | Perricone_MD_Vitamin_C_Ester_Serum 251 | CoQ10_BjTLbuRVt1t 252 | -------------------------------------------------------------------------------- /sim/sim_3d.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from os.path import join as pjoin 4 | BASEPATH = os.path.dirname(__file__) 5 | sys.path.insert(0, BASEPATH) 6 | sys.path.insert(0, pjoin(BASEPATH, '..')) 7 | import glob 8 | from typing import Optional 9 | import shutil 10 | 11 | import mujoco 12 | from transforms3d import euler, quaternions 13 | import numpy as np 14 | import subprocess 15 | from mujoco import viewer 16 | import ray 17 | import subprocess 18 | 19 | from assets.finger_3d import generate_3d_gripper, save_3d_gripper, generate_gripper_3d_xml, generate_scene_3d_xml 20 | from dynamics.utils import continuous_signed_delta 21 | from assets.scan_object_process import read_object_names, generate_object_3d_xml 22 | 23 | OBJECT_DIR = '/mujoco_scanned_objects/models' 24 | 25 | def compute_collision(mesh_path, num_retries: int = 2): 26 | """ 27 | Computes the convex decomposition of a mesh using v-hacd. 28 | Convention: the input mesh is assumed to be in the same folder as the output mesh, 29 | with only the name change from `xyz.obj` to `xyz_collision.obj`. 30 | 31 | V-HACD help: 32 | ``` 33 | -h : Maximum number of output convex hulls. Default is 32 34 | -r : Total number of voxels to use. Default is 100,000 35 | -e : Volume error allowed as a percentage. Default is 1%. Valid range is 0.001 to 10 36 | -d : Maximum recursion depth. Default value is 10. 
37 | -s : Whether or not to shrinkwrap output to source mesh. Default is true. 38 | -f : Fill mode. Default is 'flood', also 'surface' and 'raycast' are valid. 39 | -v : Maximum number of vertices in the output convex hull. Default value is 64 40 | -a : Whether or not to run asynchronously. Default is 'true' 41 | -l : Minimum size of a voxel edge. Default value is 2 voxels. 42 | -p : If false, splits hulls in the middle. If true, tries to find optimal split plane location. False by default. 43 | -o : Export the convex hulls as a series of wavefront OBJ files, STL files, or a single USDA. 44 | -g : If set to false, no logging will be displayed. 45 | ``` 46 | """ 47 | COMMAND = [ 48 | "TestVHACD", 49 | mesh_path, 50 | "-r", 51 | "100000", 52 | "-o", 53 | "obj", 54 | "-g", 55 | "false", 56 | "-h", 57 | "32", 58 | "-v", 59 | "32", 60 | ] 61 | output: Optional[subprocess.CompletedProcess] = None 62 | assert num_retries > 1 63 | for _ in range(num_retries): 64 | try: 65 | output = subprocess.run(COMMAND, check=True) 66 | except subprocess.CalledProcessError as e: 67 | print("V-HACD failed to run on %s, retrying..." % mesh_path) 68 | continue 69 | if output is None or output.returncode != 0: 70 | raise RuntimeError("V-HACD failed to run on %s" % mesh_path) 71 | 72 | def prepare_gripper(gripper_idx: int, model_root: str): 73 | rs = np.random.RandomState(gripper_idx) 74 | yl = rs.uniform(-0.1, 0, size=(21)) 75 | yr = rs.uniform(-0.1, 0, size=(21)) 76 | save_gripper_dir = os.path.join(model_root, 'grippers', str(gripper_idx)) 77 | if not os.path.exists(save_gripper_dir): 78 | ctrlpts, allpts = save_3d_gripper( 79 | yl, 80 | yr, 81 | width=0.1, 82 | sample_size=25, 83 | save_gripper_dir=save_gripper_dir, 84 | ) 85 | meshl_path = os.path.join(save_gripper_dir, "fingerl.obj") 86 | compute_collision(meshl_path) 87 | meshr_path = os.path.join(save_gripper_dir, "fingerr.obj") 88 | compute_collision(meshr_path) 89 | generate_gripper_3d_xml(len(glob.glob(os.path.join(save_gripper_dir, "fingerl0*.obj"))), len(glob.glob(os.path.join(save_gripper_dir, "fingerr0*.obj"))), gripper_idx, os.path.join(model_root, 'gripper_%d.xml' % gripper_idx)) 90 | 91 | else: 92 | ctrlpts, allpts = generate_3d_gripper( 93 | yl, 94 | yr, 95 | sample_size=25, 96 | ) 97 | return ctrlpts, allpts 98 | 99 | def prepare_object(object_name: str, object_idx: int, model_root: str): 100 | object_model_dir = os.path.join(OBJECT_DIR, object_name) 101 | object_model_new = os.path.join(model_root, 'object_%d.xml' % object_idx) 102 | if not os.path.exists(object_model_new): 103 | shutil.copytree(object_model_dir, os.path.join(model_root, 'objects', str(object_idx)), dirs_exist_ok = True) 104 | generate_object_3d_xml(len(glob.glob(os.path.join(model_root, 'objects', str(object_idx), "model_collision_*.obj"))), object_idx, object_model_new) 105 | return os.path.join(model_root, 'objects', str(object_idx)) 106 | 107 | # @profile 108 | @ray.remote(num_cpus=2) 109 | def main(model_root, gripper_idx: int=0, object_name: str='BUNNY_RACER', object_idx: int=0, save_dir: str="sim", gui: bool = False): 110 | ctrlpts, allpts = prepare_gripper(gripper_idx, model_root) 111 | prepare_object(object_name, object_idx, model_root) 112 | scene_path = os.path.join(model_root, 'scene_%d_%d.xml' % (object_idx, gripper_idx)) 113 | generate_scene_3d_xml(object_idx, gripper_idx, scene_path) 114 | model = mujoco.MjModel.from_xml_path(scene_path) 115 | data = mujoco.MjData(model) 116 | reset_qpos = data.qpos.copy() 117 | reset_qvel = data.qvel.copy() 118 | reset_force = 
data.qfrc_applied.copy() 119 | handle = viewer.launch_passive(model, data) if gui else None 120 | 121 | obj_root_idx = [model.joint(jointid).name for jointid in range(model.njnt)].index( 122 | "object_root" 123 | ) 124 | obj_jnt = model.joint(obj_root_idx) 125 | assert obj_jnt.type == 0 # freejoint 126 | 127 | z_rots = np.arange(0.0, 2 * np.pi, 2 * np.pi / 360) 128 | x_locs = -0.03+0.06*np.arange(5)/4 129 | y_locs = -0.03+0.06*np.arange(5)/4 130 | init_poses = np.zeros((len(z_rots), len(x_locs), len(y_locs), 7)) 131 | final_poses = np.zeros((len(z_rots), len(x_locs), len(y_locs), 7)) 132 | for i, x_loc in enumerate(x_locs): 133 | for j, y_loc in enumerate(y_locs): 134 | for k, z_rot in enumerate(z_rots): 135 | data.qpos[:] = reset_qpos[:] 136 | data.qvel[:] = reset_qvel[:] 137 | data.qfrc_applied[:] = reset_force 138 | data.qpos[obj_jnt.qposadr[0] : obj_jnt.qposadr[0] + 3] = [ 139 | x_loc, 140 | y_loc, 141 | 0, 142 | ] 143 | data.qpos[ 144 | obj_jnt.qposadr[0] + 3 : obj_jnt.qposadr[0] + 7 145 | ] = euler.euler2quat(0, 0, z_rot) 146 | init_poses[k, i, j, :] = data.qpos[ 147 | obj_jnt.qposadr[0] : obj_jnt.qposadr[0] + 7 148 | ] 149 | data.ctrl[0] = 0.5 150 | data.ctrl[1] = -0.5 151 | for t in range(800): 152 | if handle is not None and t % 10 == 0: 153 | handle.sync() 154 | input(f"Press Enter to continue..., {t}") 155 | mujoco.mj_step(model, data) 156 | final_poses[k, i, j, :] = data.qpos[ 157 | obj_jnt.qposadr[0] : obj_jnt.qposadr[0] + 7 158 | ] 159 | if not np.isclose(data.qpos[obj_jnt.qposadr[0] + 4], 0.0, atol=1e-2) or not np.isclose(data.qpos[obj_jnt.qposadr[0] + 5], 0.0, atol=1e-2): 160 | print("give up: object not upright") 161 | return 162 | save_data = { 163 | "ctrlpts": ctrlpts, 164 | "allpts": allpts, 165 | "object_name": object_name, 166 | "obj_pos": init_poses[..., :3].reshape((-1, 3)), 167 | "obj_theta": np.asarray([quaternions.quat2axangle(quat)[-1] for quat in init_poses[..., 3:].reshape((-1, 4))], dtype=np.float32), 168 | "delta_theta": np.asarray([continuous_signed_delta(quaternions.quat2axangle(last_quat)[-1], quaternions.quat2axangle(quat)[-1]) for last_quat, quat in zip(init_poses[..., 3:].reshape((-1, 4)), final_poses[..., 3:].reshape((-1, 4)))], dtype=np.float32), 169 | "delta_pos": (final_poses[..., :3] - init_poses[..., :3]).reshape((-1, 3)), 170 | } 171 | os.makedirs(save_dir, exist_ok=True) 172 | np.savez_compressed(os.path.join(save_dir, "%d_%d.npz" % (object_idx, gripper_idx)), save_data) 173 | 174 | if __name__ == "__main__": 175 | model_root = sys.argv[1] 176 | gripper_idx = int(sys.argv[2]) 177 | object_idx = int(sys.argv[3]) 178 | num_gripper_parallel = int(sys.argv[4]) 179 | num_object_parallel = int(sys.argv[5]) 180 | save_dir = sys.argv[6] 181 | num_cpus = int(sys.argv[7]) 182 | object_names = read_object_names() 183 | 184 | ray.init(num_cpus=num_cpus, log_to_driver=False) 185 | ray_tasks = [main.remote(model_root=model_root, gripper_idx=g_idx, object_name=object_names[object_idx], object_idx=o_idx, save_dir=save_dir, gui=False) for g_idx in range(gripper_idx, gripper_idx+num_gripper_parallel) for o_idx in range(object_idx, object_idx+num_object_parallel)] 186 | while len(ray_tasks) > 0: 187 | ready, ray_tasks = ray.wait(ray_tasks, num_returns=1) 188 | try: 189 | ray.get(ready[0], timeout=1) 190 | except Exception as e: 191 | print(e) 192 | continue 193 | -------------------------------------------------------------------------------- /sim/sim_2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | import 
sys 3 | from os.path import join as pjoin 4 | BASEPATH = os.path.dirname(__file__) 5 | sys.path.insert(0, BASEPATH) 6 | sys.path.insert(0, pjoin(BASEPATH, '..')) 7 | import glob 8 | from typing import Optional 9 | 10 | import mujoco 11 | from transforms3d import euler, quaternions 12 | import numpy as np 13 | import subprocess 14 | from mujoco import viewer 15 | import ray 16 | import subprocess 17 | import time 18 | 19 | from assets.finger_sampler import generate_gripper, save_gripper, generate_xml, generate_scene_xml 20 | from assets.object_sampler import generate_object_xml 21 | from assets.icon_process import save_icon_mesh, extract_contours 22 | from dynamics.utils import continuous_signed_delta 23 | 24 | OBJECT_DIR = '/Icons-50.npy' 25 | 26 | def compute_collision(mesh_path, num_retries: int = 2): 27 | """ 28 | Computes the convex decomposition of a mesh using v-hacd. 29 | Convention: the input mesh is assumed to be in the same folder as the output mesh, 30 | with only the name change from `xyz.obj` to `xyz_collision.obj`. 31 | 32 | V-HACD help: 33 | ``` 34 | -h : Maximum number of output convex hulls. Default is 32 35 | -r : Total number of voxels to use. Default is 100,000 36 | -e : Volume error allowed as a percentage. Default is 1%. Valid range is 0.001 to 10 37 | -d : Maximum recursion depth. Default value is 10. 38 | -s : Whether or not to shrinkwrap output to source mesh. Default is true. 39 | -f : Fill mode. Default is 'flood', also 'surface' and 'raycast' are valid. 40 | -v : Maximum number of vertices in the output convex hull. Default value is 64 41 | -a : Whether or not to run asynchronously. Default is 'true' 42 | -l : Minimum size of a voxel edge. Default value is 2 voxels. 43 | -p : If false, splits hulls in the middle. If true, tries to find optimal split plane location. False by default. 44 | -o : Export the convex hulls as a series of wavefront OBJ files, STL files, or a single USDA. 45 | -g : If set to false, no logging will be displayed. 46 | ``` 47 | """ 48 | COMMAND = [ 49 | "TestVHACD", 50 | mesh_path, 51 | "-r", 52 | "100000", 53 | "-o", 54 | "obj", 55 | "-g", 56 | "false", 57 | "-h", 58 | "16", 59 | "-v", 60 | "32", 61 | ] 62 | output: Optional[subprocess.CompletedProcess] = None 63 | assert num_retries > 1 64 | for _ in range(num_retries): 65 | try: 66 | output = subprocess.run(COMMAND, check=True) 67 | except subprocess.CalledProcessError as e: 68 | print("V-HACD failed to run on %s, retrying..." 
% mesh_path) 69 | continue 70 | if output is None or output.returncode != 0: 71 | raise RuntimeError("V-HACD failed to run on %s" % mesh_path) 72 | 73 | def prepare_gripper(gripper_idx: int, model_root: str): 74 | rs = np.random.RandomState(gripper_idx) 75 | x = np.linspace(-0.12, 0.12, 7) 76 | yl = rs.uniform(-0.045, 0.015, size=(7)) 77 | yr = rs.uniform(-0.045, 0.015, size=(7)) 78 | save_gripper_dir = os.path.join(model_root, 'grippers', str(gripper_idx)) 79 | if not os.path.exists(save_gripper_dir): 80 | ctrlpts, allpts = save_gripper( 81 | x, 82 | yl, 83 | yr, 84 | width=0.03, 85 | height=0.02, 86 | num_points=200, 87 | save_gripper_dir=save_gripper_dir, 88 | ) 89 | meshl_path = os.path.join(save_gripper_dir, "fingerl.obj") 90 | compute_collision(meshl_path) 91 | meshr_path = os.path.join(save_gripper_dir, "fingerr.obj") 92 | compute_collision(meshr_path) 93 | generate_xml(len(glob.glob(os.path.join(save_gripper_dir, "fingerl0*.obj"))), len(glob.glob(os.path.join(save_gripper_dir, "fingerr0*.obj"))), gripper_idx, os.path.join(model_root, 'gripper_%d.xml' % gripper_idx)) 94 | else: 95 | ctrlpts, allpts = generate_gripper( 96 | x, 97 | yl, 98 | yr, 99 | num_points=200, 100 | ) 101 | return ctrlpts, allpts 102 | 103 | def prepare_icon_object(object_idx, image, model_root): 104 | save_object_dir = os.path.join(model_root, 'objects', str(object_idx)) 105 | if not os.path.exists(save_object_dir): 106 | contour, mesh_path = save_icon_mesh(image, 0.02, 100, save_object_dir) 107 | compute_collision(mesh_path) 108 | generate_object_xml(len(glob.glob(os.path.join(save_object_dir, "object0*.obj"))), object_idx, os.path.join(model_root, 'object_%d.xml' % object_idx)) 109 | else: 110 | contour = extract_contours(image) 111 | return contour 112 | 113 | @ray.remote(num_cpus=2) 114 | def main(model_root, object_image, gripper_idx: int=0, object_idx: int=0, save_dir: str="sim", gui: bool = False): 115 | ctrlpts, allpts = prepare_gripper(gripper_idx, model_root) 116 | object_vertices = prepare_icon_object(object_idx, object_image, model_root) 117 | scene_path = os.path.join(model_root, 'scene_%d_%d.xml' % (object_idx, gripper_idx)) 118 | generate_scene_xml(object_idx, gripper_idx, scene_path) 119 | 120 | timeout = 1 121 | start_time = time.time() 122 | while not (os.path.exists(os.path.join(model_root, 'object_%d.xml' % object_idx)) and os.path.getsize(os.path.join(model_root, 'object_%d.xml' % object_idx))>0 and os.path.exists(os.path.join(model_root, 'gripper_%d.xml' % gripper_idx)) and os.path.getsize(os.path.join(model_root, 'gripper_%d.xml' % gripper_idx))>0): 123 | if time.time() - start_time > timeout: 124 | raise RuntimeError("Timeout waiting for object_%d.xml and gripper_%d.xml" % (object_idx, gripper_idx)) 125 | time.sleep(0.1) 126 | model = mujoco.MjModel.from_xml_path(scene_path) 127 | data = mujoco.MjData(model) 128 | reset_qpos = data.qpos.copy() 129 | reset_qvel = data.qvel.copy() 130 | reset_force = data.qfrc_applied.copy() 131 | handle = viewer.launch_passive(model, data) if gui else None 132 | 133 | obj_root_idx = [model.joint(jointid).name for jointid in range(model.njnt)].index( 134 | "object_root" 135 | ) 136 | obj_jnt = model.joint(obj_root_idx) 137 | assert obj_jnt.type == 0 # freejoint 138 | 139 | z_rots = np.arange(0.0, 2 * np.pi, 2 * np.pi / 360) 140 | x_locs = -0.03+0.06*np.arange(5)/4 141 | y_locs = -0.03+0.06*np.arange(5)/4 142 | init_poses = np.zeros((len(z_rots), len(x_locs), len(y_locs), 7)) 143 | final_poses = np.zeros((len(z_rots), len(x_locs), len(y_locs), 7)) 144 | for 
i, x_loc in enumerate(x_locs): 145 | for j, y_loc in enumerate(y_locs): 146 | for k, z_rot in enumerate(z_rots): 147 | data.qpos[:] = reset_qpos[:] 148 | data.qvel[:] = reset_qvel[:] 149 | data.qfrc_applied[:] = reset_force 150 | data.qpos[obj_jnt.qposadr[0] : obj_jnt.qposadr[0] + 3] = [ 151 | x_loc, 152 | y_loc, 153 | 0, 154 | ] 155 | data.qpos[ 156 | obj_jnt.qposadr[0] + 3 : obj_jnt.qposadr[0] + 7 157 | ] = euler.euler2quat(0, 0, z_rot) 158 | init_poses[k, i, j, :] = data.qpos[ 159 | obj_jnt.qposadr[0] : obj_jnt.qposadr[0] + 7 160 | ] 161 | data.ctrl[0] = 0.2 162 | data.ctrl[1] = -0.2 163 | # step for 1 second 164 | for t in range(200): 165 | if handle is not None and t % 10 == 0: 166 | handle.sync() 167 | input(f"Press Enter to continue..., {t}") 168 | mujoco.mj_step(model, data) 169 | final_poses[k, i, j, :] = data.qpos[ 170 | obj_jnt.qposadr[0] : obj_jnt.qposadr[0] + 7 171 | ] 172 | save_data = { 173 | "ctrlpts": ctrlpts, 174 | "allpts": allpts, 175 | "object_vertices": object_vertices, 176 | "obj_pos": init_poses[..., :3].reshape((-1, 3)), 177 | "obj_theta": np.asarray([quaternions.quat2axangle(quat)[-1] for quat in init_poses[..., 3:].reshape((-1, 4))], dtype=np.float32), 178 | "delta_theta": np.asarray([continuous_signed_delta(quaternions.quat2axangle(last_quat)[-1], quaternions.quat2axangle(quat)[-1]) for last_quat, quat in zip(init_poses[..., 3:].reshape((-1, 4)), final_poses[..., 3:].reshape((-1, 4)))], dtype=np.float32), 179 | "delta_pos": (final_poses[..., :3] - init_poses[..., :3]).reshape((-1, 3)), 180 | } 181 | os.makedirs(save_dir, exist_ok=True) 182 | np.savez_compressed(os.path.join(save_dir, "%d_%d.npz" % (object_idx, gripper_idx)), save_data) 183 | 184 | if __name__ == "__main__": 185 | model_root = sys.argv[1] 186 | gripper_idx = int(sys.argv[2]) 187 | object_idx = int(sys.argv[3]) 188 | num_gripper_parallel = int(sys.argv[4]) 189 | num_object_parallel = int(sys.argv[5]) 190 | save_dir = sys.argv[6] 191 | num_cpus = int(sys.argv[7]) 192 | object_image = np.load(OBJECT_DIR, allow_pickle=True).item()['image'][object_idx].transpose((1, 2, 0)) 193 | 194 | ray.init(num_cpus=num_cpus, log_to_driver=False) 195 | ray_tasks = [main.remote(model_root=model_root, object_image=object_image, gripper_idx=g_idx, object_idx=o_idx, save_dir=save_dir, gui=False) for g_idx in range(gripper_idx, gripper_idx+num_gripper_parallel) for o_idx in range(object_idx, object_idx+num_object_parallel)] 196 | while len(ray_tasks) > 0: 197 | ready, ray_tasks = ray.wait(ray_tasks, num_returns=1) 198 | try: 199 | ray.get(ready[0], timeout=1) 200 | except Exception as e: 201 | print(e) 202 | continue -------------------------------------------------------------------------------- /generator/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Taken from https://diffusion-policy.cs.columbia.edu/ 2 | import logging 3 | import math 4 | from typing import List, Optional, Tuple, Union 5 | import typing 6 | 7 | import torch 8 | from torch import nn 9 | 10 | # @markdown ### **Network** 11 | # @markdown 12 | # @markdown Defines a 1D UNet architecture `ConditionalUnet1D` 13 | # @markdown as the noise prediction network 14 | # @markdown 15 | # @markdown Components 16 | # @markdown - `SinusoidalPosEmb` Positional encoding for the diffusion iteration k 17 | # @markdown - `Downsample1d` Strided convolution to reduce temporal resolution 18 | # @markdown - `Upsample1d` Transposed convolution to increase temporal resolution 19 | # @markdown - `Conv1dBlock` Conv1d 
--> GroupNorm --> Mish 20 | # @markdown - `ConditionalResidualBlock1D` Takes two inputs `x` and `cond`. \ 21 | # @markdown `x` is passed through 2 `Conv1dBlock` stacked together with residual connection. 22 | # @markdown `cond` is applied to `x` with [FiLM](https://arxiv.org/abs/1709.07871) conditioning. 23 | 24 | 25 | class SinusoidalPosEmb(nn.Module): 26 | def __init__(self, dim): 27 | super().__init__() 28 | self.dim = dim 29 | 30 | def forward(self, x): 31 | device = x.device 32 | half_dim = self.dim // 2 33 | emb = math.log(10000) / (half_dim - 1) 34 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 35 | emb = x[:, None] * emb[None, :] 36 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 37 | return emb 38 | 39 | class Downsample1d(nn.Module): 40 | def __init__(self, dim): 41 | super().__init__() 42 | self.conv = nn.Conv1d(dim, dim, 3, 2, 1) 43 | 44 | def forward(self, x): 45 | return self.conv(x) 46 | 47 | 48 | class Upsample1d(nn.Module): 49 | def __init__(self, dim): 50 | super().__init__() 51 | self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) 52 | 53 | def forward(self, x): 54 | return self.conv(x) 55 | 56 | 57 | class Conv1dBlock(nn.Module): 58 | """ 59 | Conv1d --> GroupNorm --> Mish 60 | """ 61 | 62 | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): 63 | super().__init__() 64 | 65 | self.block = nn.Sequential( 66 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), 67 | nn.GroupNorm(n_groups, out_channels), 68 | nn.Mish(), 69 | ) 70 | 71 | def forward(self, x): 72 | return self.block(x) 73 | 74 | 75 | class ConditionalResidualBlock1D(nn.Module): 76 | def __init__(self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8): 77 | super().__init__() 78 | 79 | self.blocks = nn.ModuleList( 80 | [ 81 | Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups), 82 | Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups), 83 | ] 84 | ) 85 | 86 | # FiLM modulation https://arxiv.org/abs/1709.07871 87 | # predicts per-channel scale and bias 88 | cond_channels = out_channels * 2 89 | self.out_channels = out_channels 90 | self.cond_encoder = nn.Sequential( 91 | nn.Mish(), nn.Linear(cond_dim, cond_channels), nn.Unflatten(-1, (-1, 1)) 92 | ) 93 | 94 | # make sure dimensions compatible 95 | self.residual_conv = ( 96 | nn.Conv1d(in_channels, out_channels, 1) 97 | if in_channels != out_channels 98 | else nn.Identity() 99 | ) 100 | 101 | def forward(self, x, cond): 102 | """ 103 | TODO: no horizon 104 | x : [ batch_size x in_channels x horizon ] 105 | cond : [ batch_size x cond_dim] 106 | 107 | returns: 108 | out : [ batch_size x out_channels x horizon ] 109 | """ 110 | out = self.blocks[0](x) 111 | embed = self.cond_encoder(cond) 112 | 113 | embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1) 114 | scale = embed[:, 0, ...] 115 | bias = embed[:, 1, ...] 116 | out = scale * out + bias 117 | 118 | out = self.blocks[1](out) 119 | out = out + self.residual_conv(x) 120 | return out 121 | 122 | 123 | class ConditionalUnet1D(nn.Module): 124 | def __init__( 125 | self, 126 | input_dim: int, 127 | global_cond_dim: int, 128 | down_dims: List[int], 129 | diffusion_step_embed_dim: int, 130 | kernel_size: int = 5, 131 | n_groups: int = 8, 132 | ): 133 | """ 134 | input_dim: Dim of actions. 135 | global_cond_dim: Dim of global conditioning applied with FiLM 136 | in addition to diffusion step embedding. 
This is usually obs_horizon * obs_dim 137 | diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k 138 | down_dims: Channel size for each UNet level. 139 | The length of this array determines number of levels. 140 | kernel_size: Conv kernel size 141 | n_groups: Number of groups for GroupNorm 142 | """ 143 | 144 | super().__init__() 145 | all_dims = [input_dim] + list(down_dims) 146 | start_dim = down_dims[0] 147 | 148 | dsed = diffusion_step_embed_dim 149 | diffusion_step_encoder = nn.Sequential( 150 | SinusoidalPosEmb(dsed), 151 | nn.Linear(dsed, dsed * 4), 152 | nn.Mish(), 153 | nn.Linear(dsed * 4, dsed), 154 | ) 155 | cond_dim = dsed + global_cond_dim 156 | 157 | in_out = list(zip(all_dims[:-1], all_dims[1:])) 158 | mid_dim = all_dims[-1] 159 | self.mid_modules = nn.ModuleList( 160 | [ 161 | ConditionalResidualBlock1D( 162 | mid_dim, 163 | mid_dim, 164 | cond_dim=cond_dim, 165 | kernel_size=kernel_size, 166 | n_groups=n_groups, 167 | ), 168 | ConditionalResidualBlock1D( 169 | mid_dim, 170 | mid_dim, 171 | cond_dim=cond_dim, 172 | kernel_size=kernel_size, 173 | n_groups=n_groups, 174 | ), 175 | ] 176 | ) 177 | 178 | down_modules = nn.ModuleList([]) 179 | for ind, (dim_in, dim_out) in enumerate(in_out): 180 | is_last = ind >= (len(in_out) - 1) 181 | down_modules.append( 182 | nn.ModuleList( 183 | [ 184 | ConditionalResidualBlock1D( 185 | dim_in, 186 | dim_out, 187 | cond_dim=cond_dim, 188 | kernel_size=kernel_size, 189 | n_groups=n_groups, 190 | ), 191 | ConditionalResidualBlock1D( 192 | dim_out, 193 | dim_out, 194 | cond_dim=cond_dim, 195 | kernel_size=kernel_size, 196 | n_groups=n_groups, 197 | ), 198 | Downsample1d(dim_out) if not is_last else nn.Identity(), 199 | ] 200 | ) 201 | ) 202 | 203 | up_modules = nn.ModuleList([]) 204 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): 205 | is_last = ind >= (len(in_out) - 1) 206 | up_modules.append( 207 | nn.ModuleList( 208 | [ 209 | ConditionalResidualBlock1D( 210 | dim_out * 2, 211 | dim_in, 212 | cond_dim=cond_dim, 213 | kernel_size=kernel_size, 214 | n_groups=n_groups, 215 | ), 216 | ConditionalResidualBlock1D( 217 | dim_in, 218 | dim_in, 219 | cond_dim=cond_dim, 220 | kernel_size=kernel_size, 221 | n_groups=n_groups, 222 | ), 223 | Upsample1d(dim_in) if not is_last else nn.Identity(), 224 | ] 225 | ) 226 | ) 227 | 228 | final_conv = nn.Sequential( 229 | Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size), 230 | nn.Conv1d(start_dim, input_dim, 1), 231 | ) 232 | 233 | self.diffusion_step_encoder = diffusion_step_encoder 234 | self.up_modules = up_modules 235 | self.down_modules = down_modules 236 | self.final_conv = final_conv 237 | 238 | def forward( 239 | self, 240 | sample: torch.Tensor, 241 | timestep: torch.Tensor, 242 | global_cond: Optional[torch.Tensor] = None, 243 | ): 244 | """ 245 | x: (B, num_points, input_dim) 246 | timestep: (B,) or int, diffusion step 247 | global_cond: (B, global_cond_dim) 248 | output: (B, num_points, input_dim) 249 | """ 250 | sample = sample.moveaxis(-1, -2) 251 | # 1. 
time 252 | timesteps = timestep 253 | if timesteps.shape == 0: 254 | timesteps = timesteps[None].to(sample.device) 255 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 256 | timesteps = timesteps.expand(sample.shape[0]) 257 | 258 | global_feature = self.diffusion_step_encoder(timesteps) 259 | 260 | if global_cond is not None: 261 | global_feature = torch.cat([global_feature, global_cond], axis=-1) 262 | 263 | x = sample 264 | h = [] 265 | for resnet, resnet2, downsample in self.down_modules: 266 | x = resnet(x, global_feature) 267 | x = resnet2(x, global_feature) 268 | h.append(x) 269 | x = downsample(x) 270 | 271 | for mid_module in self.mid_modules: 272 | x = mid_module(x, global_feature) 273 | 274 | for resnet, resnet2, upsample in self.up_modules: 275 | x = torch.cat((x, h.pop()), dim=1) 276 | x = resnet(x, global_feature) 277 | x = resnet2(x, global_feature) 278 | x = upsample(x) 279 | 280 | x = self.final_conv(x) 281 | 282 | # (B, input_dim, num_points) 283 | x = x.moveaxis(-1, -2) 284 | # (B, num_points, input_dim) 285 | return x -------------------------------------------------------------------------------- /dynamics/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from os.path import join as pjoin 4 | BASEPATH = os.path.dirname(__file__) 5 | sys.path.insert(0, BASEPATH) 6 | sys.path.insert(0, pjoin(BASEPATH, '..')) 7 | 8 | import torch 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm 11 | import wandb 12 | 13 | from dynamics.dataloader import DynamicsDataset 14 | from dynamics.trainer import Trainer 15 | from dynamics.parser import parse 16 | 17 | def validate(args, val_loader, trainer, threshold_std=[0.641, 0.625, 0.3846]): 18 | print('validation:') 19 | average_val_loss = 0 20 | average_val_accuracy = 0 21 | average_val_accuracy_x = 0 22 | average_val_accuracy_y = 0 23 | with torch.no_grad(): 24 | for batch in val_loader: 25 | score = batch['scores'] # [batch_size, num_ori*num_pos, 3] 26 | input_ori = batch['input_ori'].reshape((-1, 1)).cuda() 27 | input_pos = batch['input_pos'].reshape((-1, 2)).cuda() 28 | object_vertices = None 29 | if args.fingers_3d: 30 | ctrlpts = torch.cat([batch['ctrlpts'] for _ in range(score.size(1))], 0).moveaxis(-1, -2).cuda() 31 | object_vertices = torch.cat([batch['object_vertices'] for _ in range(score.size(1))], 0).moveaxis(-1, -2).cuda() 32 | else: 33 | ctrlpts = torch.cat([batch['ctrlpts'][..., 1] for _ in range(score.size(1))], 1).reshape((input_ori.shape[0], -1)).cuda() 34 | object_vertices = torch.cat([batch['object_vertices'] for _ in range(score.size(1))], 1).reshape((input_ori.shape[0], -1)).cuda() 35 | score = score.reshape((-1, 3)).cuda() 36 | pred, loss = trainer.inference(None, ctrlpts, score, input_ori, input_pos, object_vertices) 37 | accuracy = torch.mean(torch.Tensor([2 if score_ori > threshold_std[0] else 0 if score_ori < -threshold_std[0] else 1 for score_ori in score[..., 0]]) == torch.Tensor([2 if pred_ori > threshold_std[0] else 0 if pred_ori < -threshold_std[0] else 1 for pred_ori in pred[..., 0]]), dtype=torch.float32) 38 | accuracy_x = torch.mean(torch.Tensor([2 if score_x > threshold_std[1] else 0 if score_x < -threshold_std[1] else 1 for score_x in score[..., 1]]) == torch.Tensor([2 if pred_x > threshold_std[1] else 0 if pred_x < -threshold_std[1] else 1 for pred_x in pred[..., 1]]), dtype=torch.float32) 39 | accuracy_y = torch.mean(torch.Tensor([2 if score_y > threshold_std[2] else 0 if score_y < 
-threshold_std[2] else 1 for score_y in score[..., 2]]) == torch.Tensor([2 if pred_y > threshold_std[2] else 0 if pred_y < -threshold_std[2] else 1 for pred_y in pred[..., 2]]), dtype=torch.float32) 40 | average_val_accuracy_x += accuracy_x 41 | average_val_accuracy_y += accuracy_y 42 | average_val_loss += loss 43 | average_val_accuracy += accuracy 44 | average_val_loss /= len(val_loader) 45 | print('average val loss:', average_val_loss) 46 | average_val_accuracy /= len(val_loader) 47 | print('average val accuracy:', average_val_accuracy) 48 | average_val_accuracy_x /= len(val_loader) 49 | print('average val accuracy x:', average_val_accuracy_x) 50 | average_val_accuracy_y /= len(val_loader) 51 | print('average val accuracy y:', average_val_accuracy_y) 52 | return average_val_loss, average_val_accuracy, average_val_accuracy_x, average_val_accuracy_y 53 | 54 | def train(args): 55 | wandb.init( 56 | project='dynamics model', 57 | config=args, 58 | dir=args.save_dir, 59 | name=args.wandb_id, 60 | ) 61 | gripper_pts_max_x = 0.12 62 | gripper_pts_min_x = -0.12 63 | if args.fingers_3d: 64 | gripper_pts_max_y = 0 65 | gripper_pts_min_y = -0.1 66 | object_pts_max_x = 0.1 67 | object_pts_min_x = -0.1 68 | object_pts_max_y = 0.1 69 | object_pts_min_y = -0.1 70 | else: 71 | gripper_pts_max_y = 0.015 72 | gripper_pts_min_y = -0.045 73 | object_pts_max_x = 0.05 74 | object_pts_min_x = -0.05 75 | object_pts_max_y = 0.05 76 | object_pts_min_y = -0.05 77 | gripper_pts_max_z = 0.12 78 | gripper_pts_min_z = 0.0 79 | object_pts_max_z = 0.12 80 | object_pts_min_z = 0.0 81 | train_dataset = DynamicsDataset( 82 | dataset_dir=args.data_dir, 83 | object_mesh_dir=args.object_mesh_dir, 84 | fingers_3d=args.fingers_3d, 85 | gripper_pts_max_x=gripper_pts_max_x, 86 | gripper_pts_min_x=gripper_pts_min_x, 87 | gripper_pts_max_y=gripper_pts_max_y, 88 | gripper_pts_min_y=gripper_pts_min_y, 89 | gripper_pts_max_z=gripper_pts_max_z, 90 | gripper_pts_min_z=gripper_pts_min_z, 91 | object_max_num_vertices=args.object_max_num_vertices, 92 | object_pts_max_x=object_pts_max_x, 93 | object_pts_min_x=object_pts_min_x, 94 | object_pts_max_y=object_pts_max_y, 95 | object_pts_min_y=object_pts_min_y, 96 | object_pts_max_z=object_pts_max_z, 97 | object_pts_min_z=object_pts_min_z) 98 | threshold_std = train_dataset.threshold / train_dataset.std 99 | val_dataset = DynamicsDataset( 100 | dataset_dir=args.test_data_dir, 101 | object_mesh_dir=args.object_mesh_dir, 102 | fingers_3d=args.fingers_3d, 103 | gripper_pts_max_x=gripper_pts_max_x, 104 | gripper_pts_min_x=gripper_pts_min_x, 105 | gripper_pts_max_y=gripper_pts_max_y, 106 | gripper_pts_min_y=gripper_pts_min_y, 107 | gripper_pts_max_z=gripper_pts_max_z, 108 | gripper_pts_min_z=gripper_pts_min_z, 109 | object_max_num_vertices=args.object_max_num_vertices, 110 | object_pts_max_x=object_pts_max_x, 111 | object_pts_min_x=object_pts_min_x, 112 | object_pts_max_y=object_pts_max_y, 113 | object_pts_min_y=object_pts_min_y, 114 | object_pts_max_z=object_pts_max_z, 115 | object_pts_min_z=object_pts_min_z) 116 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, drop_last=False) 117 | val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, drop_last=False) 118 | trainer = Trainer(args) 119 | trainer.create_model() 120 | 121 | # short-cut for testing 122 | if args.mode == 'validate': 123 | if args.checkpoint_path is None: 124 | raise ValueError('checkpoint path is not specified') 125 
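# The accuracy terms above bin each score into three classes by comparing it
# against +/-threshold with per-element list comprehensions. A vectorized
# equivalent is sketched below; `three_class` is a hypothetical helper, not
# something this repository defines.
import torch

def three_class(values: torch.Tensor, threshold: float) -> torch.Tensor:
    # 0: below -threshold, 1: within [-threshold, threshold], 2: above +threshold
    return (values > threshold).long() * 2 + ((values >= -threshold) & (values <= threshold)).long()

scores = torch.tensor([-1.0, 0.1, 0.9])
preds = torch.tensor([-0.8, 0.7, 0.95])
accuracy = (three_class(scores, 0.641) == three_class(preds, 0.641)).float().mean()
print(accuracy)  # tensor(0.6667): only the middle element lands in a different bin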
| validate(args, val_loader, trainer, threshold_std=threshold_std) 126 | return 127 | 128 | # train model 129 | if args.mode == 'train': 130 | best_val_loss = float('inf') 131 | last_best_epoch = 0 132 | for epoch in tqdm(range(args.num_epochs)): 133 | average_loss = 0 134 | average_accuracy = 0 135 | average_accuracy_x = 0 136 | average_accuracy_y = 0 137 | for idx_batch, batch in enumerate(tqdm(train_loader)): 138 | score = batch['scores'] # [batch_size, num_ori*num_pos, 3] 139 | input_ori = batch['input_ori'].reshape((-1, 1)).cuda() 140 | input_pos = batch['input_pos'].reshape((-1, 2)).cuda() 141 | object_vertices = None 142 | if args.fingers_3d: 143 | ctrlpts = torch.cat([batch['ctrlpts'] for _ in range(score.size(1))], 0).moveaxis(-1, -2).cuda() 144 | object_vertices = torch.cat([batch['object_vertices'] for _ in range(score.size(1))], 0).moveaxis(-1, -2).cuda() 145 | else: 146 | ctrlpts = torch.cat([batch['ctrlpts'][..., 1] for _ in range(score.size(1))], 1).reshape((input_ori.shape[0], -1)).cuda() 147 | object_vertices = torch.cat([batch['object_vertices'] for _ in range(score.size(1))], 1).reshape((input_ori.shape[0], -1)).cuda() 148 | score = score.reshape((-1, 3)).cuda() 149 | loss, pred = trainer.step(ctrlpts, score, input_ori, input_pos, object_vertices) 150 | 151 | accuracy = torch.mean(torch.Tensor([2 if score_ori > threshold_std[0] else 0 if score_ori < -threshold_std[0] else 1 for score_ori in score[..., 0]]) == torch.Tensor([2 if pred_ori > threshold_std[0] else 0 if pred_ori < -threshold_std[0] else 1 for pred_ori in pred[..., 0]]), dtype=torch.float32) 152 | accuracy_x = torch.mean(torch.Tensor([2 if score_x > threshold_std[1] else 0 if score_x < -threshold_std[1] else 1 for score_x in score[..., 1]]) == torch.Tensor([2 if pred_x > threshold_std[1] else 0 if pred_x < -threshold_std[1] else 1 for pred_x in pred[..., 1]]), dtype=torch.float32) 153 | accuracy_y = torch.mean(torch.Tensor([2 if score_y > threshold_std[2] else 0 if score_y < -threshold_std[2] else 1 for score_y in score[..., 2]]) == torch.Tensor([2 if pred_y > threshold_std[2] else 0 if pred_y < -threshold_std[2] else 1 for pred_y in pred[..., 2]]), dtype=torch.float32) 154 | average_accuracy_x += accuracy_x 155 | average_accuracy_y += accuracy_y 156 | average_loss += loss 157 | average_accuracy += accuracy 158 | wandb.log({ 159 | 'train/lr': trainer.optimizer.param_groups[0]['lr'], 160 | 'train/batch loss': loss, 161 | 'train/batch accuracy ori': accuracy, 162 | 'train/batch accuracy x': accuracy_x, 163 | 'train/batch accuracy y': accuracy_y, 164 | }) 165 | if idx_batch % args.save_ckpt_step == 0: 166 | os.makedirs(args.save_dir, exist_ok=True) 167 | trainer.save_checkpoint(os.path.join(args.save_dir, '%d_%d.pt' % (epoch, idx_batch))) 168 | trainer.lr_scheduler.step() 169 | average_loss /= len(train_loader) 170 | print('epoch:', epoch, 'loss:', average_loss) 171 | average_accuracy /= len(train_loader) 172 | print('epoch:', epoch, 'accuracy:', average_accuracy) 173 | average_accuracy_x /= len(train_loader) 174 | print('epoch:', epoch, 'accuracy x:', average_accuracy_x) 175 | average_accuracy_y /= len(train_loader) 176 | print('epoch:', epoch, 'accuracy y:', average_accuracy_y) 177 | wandb.log({ 178 | 'train/average loss': average_loss, 179 | 'train/average accuracy ori': average_accuracy, 180 | 'train/average accuracy x': average_accuracy_x, 181 | 'train/average accuracy y': average_accuracy_y, 182 | }) 183 | if epoch % args.val_step == 0: 184 | val_loss, val_accuracy, val_accuracy_x, val_accuracy_y = 
validate(args, val_loader, trainer, threshold_std=threshold_std) 185 | wandb.log({ 186 | 'val/average loss': val_loss, 187 | 'val/average accuracy ori': val_accuracy, 188 | 'val/average accuracy x': val_accuracy_x, 189 | 'val/average accuracy y': val_accuracy_y, 190 | }) 191 | if val_loss < best_val_loss: 192 | best_val_loss = val_loss 193 | os.makedirs(args.save_dir, exist_ok=True) 194 | trainer.save_checkpoint(os.path.join(args.save_dir, 'best.pt')) 195 | last_best_epoch = epoch 196 | else: 197 | if epoch - last_best_epoch >= args.patience: 198 | print('early stopping...') 199 | break 200 | wandb.finish() 201 | 202 | if __name__ == '__main__': 203 | args = parse() 204 | os.makedirs(args.save_dir, exist_ok=True) 205 | train(args) 206 | 207 | 208 | 209 | -------------------------------------------------------------------------------- /dynamics/models/pointnet2_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @article{Pytorch_Pointnet_Pointnet2, 3 | Author = {Xu Yan}, 4 | Title = {Pointnet/Pointnet++ Pytorch}, 5 | Journal = {https://github.com/yanx27/Pointnet_Pointnet2_pytorch}, 6 | Year = {2019} 7 | } 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from time import time 13 | import numpy as np 14 | 15 | def timeit(tag, t): 16 | print("{}: {}s".format(tag, time() - t)) 17 | return time() 18 | 19 | def pc_normalize(pc): 20 | l = pc.shape[0] 21 | centroid = np.mean(pc, axis=0) 22 | pc = pc - centroid 23 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 24 | pc = pc / m 25 | return pc 26 | 27 | def square_distance(src, dst): 28 | """ 29 | Calculate Euclid distance between each two points. 30 | 31 | src^T * dst = xn * xm + yn * ym + zn * zm; 32 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 33 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 34 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 35 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 36 | 37 | Input: 38 | src: source points, [B, N, C] 39 | dst: target points, [B, M, C] 40 | Output: 41 | dist: per-point square distance, [B, N, M] 42 | """ 43 | B, N, _ = src.shape 44 | _, M, _ = dst.shape 45 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 46 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 47 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 48 | return dist 49 | 50 | 51 | def index_points(points, idx): 52 | """ 53 | 54 | Input: 55 | points: input points data, [B, N, C] 56 | idx: sample index data, [B, S] 57 | Return: 58 | new_points:, indexed points data, [B, S, C] 59 | """ 60 | device = points.device 61 | B = points.shape[0] 62 | view_shape = list(idx.shape) 63 | view_shape[1:] = [1] * (len(view_shape) - 1) 64 | repeat_shape = list(idx.shape) 65 | repeat_shape[0] = 1 66 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 67 | new_points = points[batch_indices, idx, :] 68 | return new_points 69 | 70 | 71 | def farthest_point_sample(xyz, npoint): 72 | """ 73 | Input: 74 | xyz: pointcloud data, [B, N, 3] 75 | npoint: number of samples 76 | Return: 77 | centroids: sampled pointcloud index, [B, npoint] 78 | """ 79 | device = xyz.device 80 | B, N, C = xyz.shape 81 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 82 | distance = torch.ones(B, N).to(device) * 1e10 83 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 84 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 85 | for i in range(npoint): 86 | centroids[:, i] = farthest 87 | centroid = 
xyz[batch_indices, farthest, :].view(B, 1, 3) 88 | dist = torch.sum((xyz - centroid) ** 2, -1) 89 | mask = dist < distance 90 | distance[mask] = dist[mask] 91 | farthest = torch.max(distance, -1)[1] 92 | return centroids 93 | 94 | 95 | def query_ball_point(radius, nsample, xyz, new_xyz): 96 | """ 97 | Input: 98 | radius: local region radius 99 | nsample: max sample number in local region 100 | xyz: all points, [B, N, 3] 101 | new_xyz: query points, [B, S, 3] 102 | Return: 103 | group_idx: grouped points index, [B, S, nsample] 104 | """ 105 | device = xyz.device 106 | B, N, C = xyz.shape 107 | _, S, _ = new_xyz.shape 108 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 109 | sqrdists = square_distance(new_xyz, xyz) 110 | group_idx[sqrdists > radius ** 2] = N 111 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 112 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 113 | mask = group_idx == N 114 | group_idx[mask] = group_first[mask] 115 | return group_idx 116 | 117 | 118 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False): 119 | """ 120 | Input: 121 | npoint: 122 | radius: 123 | nsample: 124 | xyz: input points position data, [B, N, 3] 125 | points: input points data, [B, N, D] 126 | Return: 127 | new_xyz: sampled points position data, [B, npoint, nsample, 3] 128 | new_points: sampled points data, [B, npoint, nsample, 3+D] 129 | """ 130 | B, N, C = xyz.shape 131 | S = npoint 132 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C] 133 | new_xyz = index_points(xyz, fps_idx) 134 | idx = query_ball_point(radius, nsample, xyz, new_xyz) 135 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] 136 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) 137 | 138 | if points is not None: 139 | grouped_points = index_points(points, idx) 140 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] 141 | else: 142 | new_points = grouped_xyz_norm 143 | if returnfps: 144 | return new_xyz, new_points, grouped_xyz, fps_idx 145 | else: 146 | return new_xyz, new_points 147 | 148 | 149 | def sample_and_group_all(xyz, points): 150 | """ 151 | Input: 152 | xyz: input points position data, [B, N, 3] 153 | points: input points data, [B, N, D] 154 | Return: 155 | new_xyz: sampled points position data, [B, 1, 3] 156 | new_points: sampled points data, [B, 1, N, 3+D] 157 | """ 158 | device = xyz.device 159 | B, N, C = xyz.shape 160 | new_xyz = torch.zeros(B, 1, C).to(device) 161 | grouped_xyz = xyz.view(B, 1, N, C) 162 | if points is not None: 163 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) 164 | else: 165 | new_points = grouped_xyz 166 | return new_xyz, new_points 167 | 168 | 169 | class PointNetSetAbstraction(nn.Module): 170 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all): 171 | super(PointNetSetAbstraction, self).__init__() 172 | self.npoint = npoint 173 | self.radius = radius 174 | self.nsample = nsample 175 | self.mlp_convs = nn.ModuleList() 176 | self.mlp_bns = nn.ModuleList() 177 | last_channel = in_channel 178 | for out_channel in mlp: 179 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) 180 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 181 | last_channel = out_channel 182 | self.group_all = group_all 183 | 184 | def forward(self, xyz, points): 185 | """ 186 | Input: 187 | xyz: input points position data, [B, C, N] 188 | points: input points data, [B, D, N] 189 | Return: 
190 | new_xyz: sampled points position data, [B, C, S] 191 | new_points_concat: sample points feature data, [B, D', S] 192 | """ 193 | xyz = xyz.permute(0, 2, 1) 194 | if points is not None: 195 | points = points.permute(0, 2, 1) 196 | 197 | if self.group_all: 198 | new_xyz, new_points = sample_and_group_all(xyz, points) 199 | else: 200 | new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points) 201 | # new_xyz: sampled points position data, [B, npoint, C] 202 | # new_points: sampled points data, [B, npoint, nsample, C+D] 203 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] 204 | for i, conv in enumerate(self.mlp_convs): 205 | bn = self.mlp_bns[i] 206 | new_points = F.relu(bn(conv(new_points))) 207 | 208 | new_points = torch.max(new_points, 2)[0] 209 | new_xyz = new_xyz.permute(0, 2, 1) 210 | return new_xyz, new_points 211 | 212 | 213 | class PointNetSetAbstractionMsg(nn.Module): 214 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list): 215 | super(PointNetSetAbstractionMsg, self).__init__() 216 | self.npoint = npoint 217 | self.radius_list = radius_list 218 | self.nsample_list = nsample_list 219 | self.conv_blocks = nn.ModuleList() 220 | self.bn_blocks = nn.ModuleList() 221 | for i in range(len(mlp_list)): 222 | convs = nn.ModuleList() 223 | bns = nn.ModuleList() 224 | last_channel = in_channel + 3 225 | for out_channel in mlp_list[i]: 226 | convs.append(nn.Conv2d(last_channel, out_channel, 1)) 227 | bns.append(nn.BatchNorm2d(out_channel)) 228 | last_channel = out_channel 229 | self.conv_blocks.append(convs) 230 | self.bn_blocks.append(bns) 231 | 232 | def forward(self, xyz, points): 233 | """ 234 | Input: 235 | xyz: input points position data, [B, C, N] 236 | points: input points data, [B, D, N] 237 | Return: 238 | new_xyz: sampled points position data, [B, C, S] 239 | new_points_concat: sample points feature data, [B, D', S] 240 | """ 241 | xyz = xyz.permute(0, 2, 1) 242 | if points is not None: 243 | points = points.permute(0, 2, 1) 244 | 245 | B, N, C = xyz.shape 246 | S = self.npoint 247 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S)) 248 | new_points_list = [] 249 | for i, radius in enumerate(self.radius_list): 250 | K = self.nsample_list[i] 251 | group_idx = query_ball_point(radius, K, xyz, new_xyz) 252 | grouped_xyz = index_points(xyz, group_idx) 253 | grouped_xyz -= new_xyz.view(B, S, 1, C) 254 | if points is not None: 255 | grouped_points = index_points(points, group_idx) 256 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) 257 | else: 258 | grouped_points = grouped_xyz 259 | 260 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S] 261 | for j in range(len(self.conv_blocks[i])): 262 | conv = self.conv_blocks[i][j] 263 | bn = self.bn_blocks[i][j] 264 | grouped_points = F.relu(bn(conv(grouped_points))) 265 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S] 266 | new_points_list.append(new_points) 267 | 268 | new_xyz = new_xyz.permute(0, 2, 1) 269 | new_points_concat = torch.cat(new_points_list, dim=1) 270 | return new_xyz, new_points_concat 271 | 272 | 273 | class PointNetFeaturePropagation(nn.Module): 274 | def __init__(self, in_channel, mlp): 275 | super(PointNetFeaturePropagation, self).__init__() 276 | self.mlp_convs = nn.ModuleList() 277 | self.mlp_bns = nn.ModuleList() 278 | last_channel = in_channel 279 | for out_channel in mlp: 280 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) 281 | 
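# A small shape check for PointNetSetAbstraction (a sketch, assuming the module
# path dynamics/models/pointnet2_utils.py from this repository is importable).
# Point clouds are passed channel-first, [B, C, N], as the forward() docstrings state.
import torch
from dynamics.models.pointnet2_utils import PointNetSetAbstraction

xyz = torch.rand(2, 3, 1024)   # two random clouds of 1024 points
sa = PointNetSetAbstraction(npoint=128, radius=0.2, nsample=32,
                            in_channel=3, mlp=[64, 64, 128], group_all=False)
new_xyz, new_points = sa(xyz, None)
assert new_xyz.shape == (2, 3, 128)       # 128 FPS centroids
assert new_points.shape == (2, 128, 128)  # 128-dim feature per centroid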
self.mlp_bns.append(nn.BatchNorm1d(out_channel)) 282 | last_channel = out_channel 283 | 284 | def forward(self, xyz1, xyz2, points1, points2): 285 | """ 286 | Input: 287 | xyz1: input points position data, [B, C, N] 288 | xyz2: sampled input points position data, [B, C, S] 289 | points1: input points data, [B, D, N] 290 | points2: input points data, [B, D, S] 291 | Return: 292 | new_points: upsampled points data, [B, D', N] 293 | """ 294 | xyz1 = xyz1.permute(0, 2, 1) 295 | xyz2 = xyz2.permute(0, 2, 1) 296 | 297 | points2 = points2.permute(0, 2, 1) 298 | B, N, C = xyz1.shape 299 | _, S, _ = xyz2.shape 300 | 301 | if S == 1: 302 | interpolated_points = points2.repeat(1, N, 1) 303 | else: 304 | dists = square_distance(xyz1, xyz2) 305 | dists, idx = dists.sort(dim=-1) 306 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 307 | 308 | dist_recip = 1.0 / (dists + 1e-8) 309 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 310 | weight = dist_recip / norm 311 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) 312 | 313 | if points1 is not None: 314 | points1 = points1.permute(0, 2, 1) 315 | new_points = torch.cat([points1, interpolated_points], dim=-1) 316 | else: 317 | new_points = interpolated_points 318 | 319 | new_points = new_points.permute(0, 2, 1) 320 | for i, conv in enumerate(self.mlp_convs): 321 | bn = self.mlp_bns[i] 322 | new_points = F.relu(bn(conv(new_points))) 323 | return new_points 324 | -------------------------------------------------------------------------------- /dynamics/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def convergence_mode(profile): 5 | ''' 6 | return the lengths of sequences in a tensor that are consecutive 1s followed by consecutive 0s, e.g. 1,0/1,1,1,0,0,0/1,1,0,0,0,0. also wrap around the tensor to handle wrapped sequences. 7 | return the indices of the convergence points, i.e. the indices of the last 1s in the sequences. 
8 | ''' 9 | profile = torch.where(profile > 0, 1.0, 0.0) 10 | if torch.all(profile == 0): 11 | return torch.tensor([len(profile)], device=profile.device), torch.tensor([0], device=profile.device) 12 | elif torch.all(profile == 1): 13 | return torch.tensor([len(profile)], device=profile.device), torch.tensor([len(profile)-1], device=profile.device) 14 | profile = torch.cat((profile, profile), dim=0) 15 | diff = torch.diff(profile) 16 | convergence_points = torch.where(diff < 0)[0] 17 | convergence_points = convergence_points[convergence_points < (len(profile) // 2)] 18 | sequence_start = torch.where(diff > 0)[0] 19 | sequence_lengths = torch.diff(torch.cat((torch.tensor([0], device=sequence_start.device), sequence_start[sequence_start > convergence_points[0]], torch.tensor([len(profile)], device=sequence_start.device)))) 20 | sequence_lengths = sequence_lengths[:len(convergence_points)] 21 | return sequence_lengths, convergence_points 22 | 23 | def convergence_mode_three_class(profile): 24 | profile_binary_ids = torch.where(profile != 1)[0] 25 | if len(profile_binary_ids) == 0: 26 | return torch.tensor([0], device=profile.device), torch.tensor([0], device=profile.device) 27 | profile_binary = profile[profile != 1] 28 | sequence_lengths, convergence_points = convergence_mode(profile_binary) 29 | convergence_points = profile_binary_ids[convergence_points] 30 | return sequence_lengths, convergence_points 31 | 32 | def slicer(a, lower, upper): 33 | if lower < 0: 34 | return torch.cat((a[lower:], a[:upper])) 35 | elif upper > len(a): 36 | return torch.cat((a[lower:], a[:upper-len(a)])) 37 | else: 38 | return a[lower: upper] 39 | 40 | def convergence_range_from_finals(finals, threshold=0.1): 41 | ''' 42 | given the final orientations, return the convergence range (consecutive range where the finals are close to each other in a threshold) 43 | finals: [num_ori] 44 | ''' 45 | convergence_range = [] 46 | start = 0 47 | end = 0 48 | min_consecutive_range = 1 49 | min_range_final = finals[0] 50 | max_range_final = finals[0] 51 | for i in range(1, len(finals)): 52 | min_range_final = min(min_range_final, finals[i]) 53 | max_range_final = max(max_range_final, finals[i]) 54 | if (max_range_final - min_range_final) <= threshold: 55 | end = i 56 | else: 57 | if end - start >= min_consecutive_range: 58 | convergence_range.append((start, end)) 59 | start = i 60 | end = i 61 | min_range_final = finals[i] 62 | max_range_final = finals[i] 63 | if end - start >= min_consecutive_range: 64 | convergence_range.append((start, end)) 65 | return convergence_range 66 | 67 | def metric2objective(metric, objective): 68 | if objective == 'rotate': 69 | return { 70 | 'success_rate': np.mean((metric['profile'] == 0) | (metric['profile'] == 2), dtype=np.float32), 71 | 'num_zero_classes': np.sum(metric['profile']==1, dtype=np.int16), 72 | 'delta_theta_abs': np.mean(np.abs(metric['delta_theta'])), 73 | 'final_delta_theta_abs': np.mean(np.abs(metric['final_delta_theta'])), 74 | } 75 | elif objective == 'rotate_clockwise': 76 | return{ 77 | 'success_rate': np.mean(metric['profile'] == 0, dtype=np.float32), 78 | 'num_clockwise_classes': np.sum(metric['profile']==0, dtype=np.int16), 79 | 'delta_theta': np.mean(metric['delta_theta']), 80 | 'final_delta_theta': np.mean(metric['final_delta_theta']), 81 | } 82 | elif objective == 'rotate_counterclockwise': 83 | return{ 84 | 'success_rate': np.mean(metric['profile'] == 2, dtype=np.float32), 85 | 'num_counterclockwise_classes': np.sum(metric['profile']==2, dtype=np.int16), 86 | 
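# A small worked example for convergence_range_from_finals above (a sketch with
# made-up final orientations, in degrees): consecutive indices are grouped while
# the spread of the running window stays within `threshold`, and each
# (start, end) pair is inclusive.
from dynamics.metrics import convergence_range_from_finals

ranges = convergence_range_from_finals([0.0, 1.0, 2.0, 50.0, 51.0, 52.0], threshold=3)
print(ranges)  # [(0, 2), (3, 5)]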
'delta_theta': np.mean(metric['delta_theta']), 87 | 'final_delta_theta': np.mean(metric['final_delta_theta']), 88 | } 89 | elif objective == 'shift_up': # shift up: x negative 90 | return { 91 | 'success_rate': np.mean(metric['profile_x'] == 0, dtype=np.float32), 92 | 'num_up_classes': np.sum(metric['profile_x']==0, dtype=np.int16), 93 | 'delta_pos_x': np.mean(metric['delta_pos'][:, 0]), 94 | 'final_pos_x': np.mean(metric['final_pos'][:, 0]), 95 | } 96 | elif objective == 'shift_down': # shift down: x positive 97 | return { 98 | 'success_rate': np.mean(metric['profile_x'] == 2, dtype=np.float32), 99 | 'num_down_classes': np.sum(metric['profile_x']==2, dtype=np.int16), 100 | 'delta_pos_x': np.mean(metric['delta_pos'][:, 0]), 101 | 'final_pos_x': np.mean(metric['final_pos'][:, 0]), 102 | } 103 | elif objective == 'shift_left': # shift left: y negative 104 | return { 105 | 'success_rate': np.mean(metric['profile_y'] == 0, dtype=np.float32), 106 | 'num_left_classes': np.sum(metric['profile_y']==0, dtype=np.int16), 107 | 'delta_pos_y': np.mean(metric['delta_pos'][:, 1]), 108 | 'final_pos_y': np.mean(metric['final_pos'][:, 1]), 109 | } 110 | elif objective == 'shift_right':# shift right: y positive 111 | return { 112 | 'success_rate': np.mean(metric['profile_y'] == 2, dtype=np.float32), 113 | 'num_right_classes': np.sum(metric['profile_y']==2, dtype=np.int16), 114 | 'delta_pos_y': np.mean(metric['delta_pos'][:, 1]), 115 | 'final_pos_y': np.mean(metric['final_pos'][:, 1]), 116 | } 117 | elif objective == 'convergence': 118 | convergence_range_3deg = convergence_range_from_finals(metric['final_theta'], threshold=3) 119 | max_convergence_range_3deg = np.max([end - start for start, end in convergence_range_3deg]) if len(convergence_range_3deg) > 0 else 0 120 | convergence_range_5deg = convergence_range_from_finals(metric['final_theta'], threshold=5) 121 | max_convergence_range_5deg = np.max([end - start for start, end in convergence_range_5deg]) if len(convergence_range_5deg) > 0 else 0 122 | convergence_range_10deg = convergence_range_from_finals(metric['final_theta'], threshold=10) 123 | max_convergence_range_10deg = np.max([end - start for start, end in convergence_range_10deg]) if len(convergence_range_10deg) > 0 else 0 124 | return { 125 | 'max_convergence_range_3deg': max_convergence_range_3deg, 126 | 'max_convergence_range_5deg': max_convergence_range_5deg, 127 | 'max_convergence_range_10deg': max_convergence_range_10deg, 128 | } 129 | elif objective == 'clockwise_up': 130 | num_clockwise_classes = np.sum(metric['profile']==0, dtype=np.int16) 131 | num_up_classes = np.sum(metric['profile_x']==0, dtype=np.int16) 132 | return { 133 | 'success_rate': np.mean((metric['profile'] == 0) & (metric['profile_x'] == 0), dtype=np.float32), 134 | 'num_clockwise_up_classes': num_clockwise_classes + num_up_classes, 135 | 'num_clockwise_classes': num_clockwise_classes, 136 | 'delta_theta': np.mean(metric['delta_theta']), 137 | 'final_delta_theta': np.mean(metric['final_delta_theta']), 138 | 'num_up_classes': num_up_classes, 139 | 'delta_pos_x': np.mean(metric['delta_pos'][:, 0]), 140 | 'final_pos_x': np.mean(metric['final_pos'][:, 0]), 141 | } 142 | elif objective == 'clockwise_down': 143 | num_clockwise_classes = np.sum(metric['profile']==0, dtype=np.int16) 144 | num_down_classes = np.sum(metric['profile_x']==2, dtype=np.int16) 145 | return { 146 | 'success_rate': np.mean((metric['profile'] == 0) & (metric['profile_x'] == 2), dtype=np.float32), 147 | 'num_clockwise_down_classes': num_clockwise_classes + 
num_down_classes, 148 | 'num_clockwise_classes': num_clockwise_classes, 149 | 'delta_theta': np.mean(metric['delta_theta']), 150 | 'final_delta_theta': np.mean(metric['final_delta_theta']), 151 | 'num_down_classes': num_down_classes, 152 | 'delta_pos_x': np.mean(metric['delta_pos'][:, 0]), 153 | 'final_pos_x': np.mean(metric['final_pos'][:, 0]), 154 | } 155 | elif objective == 'clockwise_right': 156 | num_clockwise_classes = np.sum(metric['profile']==0, dtype=np.int16) 157 | num_right_classes = np.sum(metric['profile_y']==2, dtype=np.int16) 158 | return { 159 | 'success_rate': np.mean((metric['profile'] == 0) & (metric['profile_y'] == 2), dtype=np.float32), 160 | 'num_clockwise_right_classes': num_clockwise_classes + num_right_classes, 161 | 'num_clockwise_classes': num_clockwise_classes, 162 | 'delta_theta': np.mean(metric['delta_theta']), 163 | 'final_delta_theta': np.mean(metric['final_delta_theta']), 164 | 'num_right_classes': num_right_classes, 165 | 'delta_pos_y': np.mean(metric['delta_pos'][:, 1]), 166 | 'final_pos_y': np.mean(metric['final_pos'][:, 1]), 167 | } 168 | elif objective == 'clockwise_left': 169 | num_clockwise_classes = np.sum(metric['profile']==0, dtype=np.int16) 170 | num_left_classes = np.sum(metric['profile_y']==0, dtype=np.int16) 171 | return { 172 | 'success_rate': np.mean((metric['profile'] == 0) & (metric['profile_y'] == 0), dtype=np.float32), 173 | 'num_clockwise_left_classes': num_clockwise_classes + num_left_classes, 174 | 'num_clockwise_classes': num_clockwise_classes, 175 | 'delta_theta': np.mean(metric['delta_theta']), 176 | 'final_delta_theta': np.mean(metric['final_delta_theta']), 177 | 'num_left_classes': num_left_classes, 178 | 'delta_pos_y': np.mean(metric['delta_pos'][:, 1]), 179 | 'final_pos_y': np.mean(metric['final_pos'][:, 1]), 180 | } 181 | elif objective == 'counterclockwise_up': 182 | num_counterclockwise_classes = np.sum(metric['profile']==2, dtype=np.int16) 183 | num_up_classes = np.sum(metric['profile_x']==0, dtype=np.int16) 184 | return { 185 | 'success_rate': np.mean((metric['profile'] == 2) & (metric['profile_x'] == 0), dtype=np.float32), 186 | 'num_counterclockwise_up_classes': num_counterclockwise_classes + num_up_classes, 187 | 'num_counterclockwise_classes': num_counterclockwise_classes, 188 | 'delta_theta': np.mean(metric['delta_theta']), 189 | 'final_delta_theta': np.mean(metric['final_delta_theta']), 190 | 'num_up_classes': num_up_classes, 191 | 'delta_pos_x': np.mean(metric['delta_pos'][:, 0]), 192 | 'final_pos_x': np.mean(metric['final_pos'][:, 0]), 193 | } 194 | elif objective == 'counterclockwise_down': 195 | num_counterclockwise_classes = np.sum(metric['profile']==2, dtype=np.int16) 196 | num_down_classes = np.sum(metric['profile_x']==2, dtype=np.int16) 197 | return { 198 | 'success_rate': np.mean((metric['profile'] == 2) & (metric['profile_x'] == 2), dtype=np.float32), 199 | 'num_counterclockwise_down_classes': num_counterclockwise_classes + num_down_classes, 200 | 'num_counterclockwise_classes': num_counterclockwise_classes, 201 | 'delta_theta': np.mean(metric['delta_theta']), 202 | 'final_delta_theta': np.mean(metric['final_delta_theta']), 203 | 'num_down_classes': num_down_classes, 204 | 'delta_pos_x': np.mean(metric['delta_pos'][:, 0]), 205 | 'final_pos_x': np.mean(metric['final_pos'][:, 0]), 206 | } 207 | elif objective == 'counterclockwise_right': 208 | num_counterclockwise_classes = np.sum(metric['profile']==2, dtype=np.int16) 209 | num_right_classes = np.sum(metric['profile_y']==2, dtype=np.int16) 210 | return { 211 
| 'success_rate': np.mean((metric['profile'] == 2) & (metric['profile_y'] == 2), dtype=np.float32), 212 | 'num_counterclockwise_right_classes': num_counterclockwise_classes + num_right_classes, 213 | 'num_counterclockwise_classes': num_counterclockwise_classes, 214 | 'delta_theta': np.mean(metric['delta_theta']), 215 | 'final_delta_theta': np.mean(metric['final_delta_theta']), 216 | 'num_right_classes': num_right_classes, 217 | 'delta_pos_y': np.mean(metric['delta_pos'][:, 1]), 218 | 'final_pos_y': np.mean(metric['final_pos'][:, 1]), 219 | } 220 | elif objective == 'counterclockwise_left': 221 | num_counterclockwise_classes = np.sum(metric['profile']==2, dtype=np.int16) 222 | num_left_classes = np.sum(metric['profile_y']==0, dtype=np.int16) 223 | return { 224 | 'success_rate': np.mean((metric['profile'] == 2) & (metric['profile_y'] == 0), dtype=np.float32), 225 | 'num_counterclockwise_left_classes': num_counterclockwise_classes + num_left_classes, 226 | 'num_counterclockwise_classes': num_counterclockwise_classes, 227 | 'delta_theta': np.mean(metric['delta_theta']), 228 | 'final_delta_theta': np.mean(metric['final_delta_theta']), 229 | 'num_left_classes': num_left_classes, 230 | 'delta_pos_y': np.mean(metric['delta_pos'][:, 1]), 231 | 'final_pos_y': np.mean(metric['final_pos'][:, 1]), 232 | } 233 | else: 234 | raise NotImplementedError -------------------------------------------------------------------------------- /dynamics/sim_test_mj_3d.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | from os.path import join as pjoin 5 | BASEPATH = os.path.dirname(__file__) 6 | sys.path.insert(0, BASEPATH) 7 | sys.path.insert(0, pjoin(BASEPATH, '..')) 8 | from typing import Optional, final 9 | 10 | import mujoco 11 | from transforms3d import euler, quaternions 12 | import numpy as np 13 | import subprocess 14 | from mujoco import viewer 15 | import ray 16 | import subprocess 17 | import time 18 | import imageio 19 | import cv2 20 | 21 | from dynamics.utils import continuous_signed_delta, visualize_profile, visualize_finals 22 | from sim.sim_3d import prepare_object 23 | from assets.finger_3d import save_3d_gripper, generate_gripper_3d_xml, generate_scene_3d_xml 24 | from sim.render_mesh import render_mesh, render_object_mesh 25 | 26 | threshold = np.array([0.02, 0.001, 0.001]) 27 | 28 | def compute_collision(mesh_path, num_retries: int = 2): 29 | """ 30 | Computes the convex decomposition of a mesh using v-hacd. 31 | Convention: the input mesh is assumed to be in the same folder as the output mesh, 32 | with only the name change from `xyz.obj` to `xyz_collision.obj`. 33 | 34 | V-HACD help: 35 | ``` 36 | -h : Maximum number of output convex hulls. Default is 32 37 | -r : Total number of voxels to use. Default is 100,000 38 | -e : Volume error allowed as a percentage. Default is 1%. Valid range is 0.001 to 10 39 | -d : Maximum recursion depth. Default value is 10. 40 | -s : Whether or not to shrinkwrap output to source mesh. Default is true. 41 | -f : Fill mode. Default is 'flood', also 'surface' and 'raycast' are valid. 42 | -v : Maximum number of vertices in the output convex hull. Default value is 64 43 | -a : Whether or not to run asynchronously. Default is 'true' 44 | -l : Minimum size of a voxel edge. Default value is 2 voxels. 45 | -p : If false, splits hulls in the middle. If true, tries to find optimal split plane location. False by default. 
46 | -o : Export the convex hulls as a series of wavefront OBJ files, STL files, or a single USDA. 47 | -g : If set to false, no logging will be displayed. 48 | ``` 49 | """ 50 | COMMAND = [ 51 | "TestVHACD", 52 | mesh_path, 53 | "-r", 54 | "100000", 55 | "-o", 56 | "obj", 57 | "-g", 58 | "false", 59 | "-h", 60 | "32", 61 | "-v", 62 | "32", 63 | ] 64 | output: Optional[subprocess.CompletedProcess] = None 65 | assert num_retries > 1 66 | for _ in range(num_retries): 67 | try: 68 | output = subprocess.run(COMMAND, check=True) 69 | except subprocess.CalledProcessError as e: 70 | print("V-HACD failed to run on %s, retrying..." % mesh_path) 71 | continue 72 | if output is None or output.returncode != 0: 73 | raise RuntimeError("V-HACD failed to run on %s" % mesh_path) 74 | 75 | def prepare_gripper(gripper_idx: int, ctrlpts, model_root: str): 76 | save_gripper_dir = os.path.join(model_root, 'grippers', str(gripper_idx)) 77 | if os.path.exists(save_gripper_dir): 78 | return save_gripper_dir 79 | else: 80 | save_3d_gripper( 81 | ctrlpts[:len(ctrlpts) // 2], 82 | ctrlpts[len(ctrlpts) // 2:], 83 | width=0.1, 84 | sample_size=25, 85 | save_gripper_dir=save_gripper_dir, 86 | ) 87 | meshl_path = os.path.join(save_gripper_dir, "fingerl.obj") 88 | compute_collision(meshl_path) 89 | meshr_path = os.path.join(save_gripper_dir, "fingerr.obj") 90 | compute_collision(meshr_path) 91 | generate_gripper_3d_xml(len(glob.glob(os.path.join(save_gripper_dir, "fingerl0*.obj"))), len(glob.glob(os.path.join(save_gripper_dir, "fingerr0*.obj"))), gripper_idx, os.path.join(model_root, 'gripper_%d.xml' % gripper_idx)) 92 | return save_gripper_dir 93 | 94 | @ray.remote(num_cpus=2) 95 | def sim_test(ctrlpts, object_name: str, gripper_idx: int=0, object_idx: int=0, object_order_idx: int=0, model_root: str="assets", save_dir: str="sim", gui: bool = False, render: bool = True, num_rot: int = 360, ori_range: list = [-1.0, 1.0], render_last: bool = False): 96 | save_gripper_dir = prepare_gripper(gripper_idx, ctrlpts, model_root) 97 | while not (os.path.exists(os.path.join(model_root, 'gripper_%d.xml' % gripper_idx)) and os.path.getsize(os.path.join(model_root, 'gripper_%d.xml' % gripper_idx))>0): 98 | time.sleep(0.1) 99 | gripper_img = render_mesh(save_gripper_dir) 100 | gripper_img_path = os.path.join(save_dir, '%d_%d_gripper.png' % (object_idx, gripper_idx)) 101 | cv2.imwrite(gripper_img_path, gripper_img) 102 | 103 | save_object_dir = prepare_object(object_name, object_idx, model_root) 104 | while not (os.path.exists(os.path.join(model_root, 'object_%d.xml' % object_idx)) and os.path.getsize(os.path.join(model_root, 'object_%d.xml' % object_idx))>0): 105 | time.sleep(0.1) 106 | contours = render_object_mesh(save_object_dir, np.linspace(ori_range[0], ori_range[1], num_rot//36) * np.pi + np.pi) 107 | 108 | scene_path = os.path.join(model_root, 'scene_%d_%d.xml' % (object_idx, gripper_idx)) 109 | generate_scene_3d_xml(object_idx, gripper_idx, scene_path) 110 | 111 | model = mujoco.MjModel.from_xml_path(scene_path) 112 | data = mujoco.MjData(model) 113 | reset_qpos = data.qpos.copy() 114 | reset_qvel = data.qvel.copy() 115 | reset_force = data.qfrc_applied.copy() 116 | handle = viewer.launch_passive(model, data) if gui else None 117 | obj_root_idx = [model.joint(jointid).name for jointid in range(model.njnt)].index("object_root") 118 | obj_jnt = model.joint(obj_root_idx) 119 | assert obj_jnt.type == 0 # freejoint 120 | 121 | left_grip_idx = [model.joint(jointid).name for jointid in range(model.njnt)].index("left_grip") 122 | 
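# The joint lookups here scan every joint name and call .index(); the MuJoCo
# Python bindings also allow direct named access, so an equivalent lookup
# (a sketch, assuming the same scene XML with joints named "object_root" and
# "left_grip") would be:
#     obj_jnt = model.joint("object_root")
#     left_grip_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_JOINT, "left_grip")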
left_grip_jnt = model.joint(left_grip_idx) 123 | right_grip_idx = [model.joint(jointid).name for jointid in range(model.njnt)].index("right_grip") 124 | right_grip_jnt = model.joint(right_grip_idx) 125 | 126 | if render or render_last: 127 | renderer = mujoco.Renderer(model, 128, 128) 128 | # renderer.enable_segmentation_rendering() 129 | camera = mujoco.MjvCamera() 130 | camera.lookat[:] = [0.0, 0.0, 0.0] 131 | camera.distance = 0.8 132 | camera.azimuth = 135 133 | camera.elevation = -45 134 | 135 | z_rots = np.linspace(ori_range[0], ori_range[1], num_rot) * np.pi + np.pi 136 | init_poses = np.zeros((len(z_rots), 7)) 137 | final_poses = np.zeros((len(z_rots), 7)) 138 | imgs = np.zeros((len(z_rots) // 36, 800, 128, 128, 3), dtype=np.int8) 139 | # segs = np.zeros((len(z_rots) // 36, 100, 128, 128), dtype=np.int16) 140 | final_final_poses = np.zeros((len(z_rots), 7)) 141 | for k, z_rot in enumerate(z_rots): 142 | data.qpos[:] = reset_qpos[:].copy() 143 | data.qvel[:] = reset_qvel[:].copy() 144 | data.qfrc_applied[:] = reset_force 145 | data.qpos[obj_jnt.qposadr[0] : obj_jnt.qposadr[0] + 3] = [0, 0, 0,] 146 | data.qpos[ 147 | obj_jnt.qposadr[0] + 3 : obj_jnt.qposadr[0] + 7 148 | ] = euler.euler2quat(0, 0, z_rot) 149 | init_poses[k, :] = data.qpos[ 150 | obj_jnt.qposadr[0] : obj_jnt.qposadr[0] + 7 151 | ] 152 | data.ctrl[0] = 0.5 153 | data.ctrl[1] = -0.5 154 | for t in range(32000): 155 | if handle is not None and t % 10 == 0: 156 | handle.sync() 157 | input(f"Press Enter to continue..., {t}") 158 | if t % 800 == 0 and t > 0: 159 | # reset the positions velocities forces of the gripper 160 | data.qpos[left_grip_jnt.qposadr[0]] = reset_qpos[left_grip_jnt.qposadr[0]] 161 | data.qpos[right_grip_jnt.qposadr[0]] = reset_qpos[right_grip_jnt.qposadr[0]] 162 | data.qvel[:] = reset_qvel[:] 163 | data.qfrc_applied[:] = reset_force[:] 164 | mujoco.mj_step(model, data) 165 | if (render and k % 36 == 0 and t % 40 == 0) or (render_last and k % 36 ==0 and t == 7999): 166 | renderer.update_scene(data, camera) 167 | img = renderer.render() 168 | # seg = renderer.render()[..., 0] 169 | # segs[k // 36, t // 40, ...] = seg 170 | # img = color_maps[seg] 171 | imgs[k // 36, t // 40, ...] 
= img 172 | if t == 800: 173 | final_poses[k, :] = data.qpos[ 174 | obj_jnt.qposadr[0] : obj_jnt.qposadr[0] + 7 175 | ] 176 | final_final_poses[k, :] = data.qpos[ 177 | obj_jnt.qposadr[0] : obj_jnt.qposadr[0] + 7 178 | ] 179 | 180 | save_data = { 181 | "obj_pos": init_poses[..., :3].reshape((-1, 3)), 182 | "obj_theta": np.asarray([quaternions.quat2axangle(quat)[-1] for quat in init_poses[..., 3:].reshape((-1, 4))], dtype=np.float32), 183 | "delta_theta": np.asarray([continuous_signed_delta(quaternions.quat2axangle(last_quat)[-1], quaternions.quat2axangle(quat)[-1]) for last_quat, quat in zip(init_poses[..., 3:].reshape((-1, 4)), final_poses[..., 3:].reshape((-1, 4)))], dtype=np.float32), # shape: (num_rot,) 184 | "delta_pos": (final_poses[..., :3] - init_poses[..., :3]).reshape((-1, 3)), 185 | } 186 | os.makedirs(os.path.join(save_dir, '%d_%d' % (object_idx, gripper_idx)), exist_ok=True) 187 | np.savez_compressed(os.path.join(save_dir, "%d_%d.npz" % (object_idx, gripper_idx)), save_data) 188 | # visualize and save profile 189 | profile = np.asarray([1 if delta_theta > threshold[0] else -1 if delta_theta < -threshold[0] else 0 for delta_theta in save_data['delta_theta']]) 190 | profile_x = np.asarray([1 if delta_pos[0] > threshold[1] else -1 if delta_pos[0] < -threshold[1] else 0 for delta_pos in save_data['delta_pos']]) 191 | profile_y = np.asarray([1 if delta_pos[1] > threshold[2] else -1 if delta_pos[1] < -threshold[2] else 0 for delta_pos in save_data['delta_pos']]) 192 | visualize_profile(profile, os.path.join(save_dir, '%d_%d_profile.png' % (object_idx, gripper_idx)), ori_range=ori_range) 193 | visualize_profile(profile_x, os.path.join(save_dir, '%d_%d_profile_x.png' % (object_idx, gripper_idx)), ori_range=ori_range) 194 | visualize_profile(profile_y, os.path.join(save_dir, '%d_%d_profile_y.png' % (object_idx, gripper_idx)), ori_range=ori_range) 195 | final_thetas = np.asarray([quaternions.quat2axangle(quat)[-1] for quat in final_final_poses[:, 3:].reshape((-1, 4))], dtype=np.float32) 196 | final_delta_thetas = np.asarray([continuous_signed_delta(init_theta, final_theta) for final_theta, init_theta in zip(final_thetas, save_data['obj_theta'])], dtype=np.float32) 197 | visualize_finals(final_thetas, os.path.join(save_dir, '%d_%d_final.png' % (object_idx, gripper_idx))) 198 | metrics = { 199 | 'delta_theta': save_data['delta_theta']*180/np.pi, 200 | 'delta_pos': save_data['delta_pos']*100, 201 | 'profile': profile + 1, 202 | 'profile_x': profile_x + 1, 203 | 'profile_y': profile_y + 1, 204 | 'final_theta': final_thetas*180/np.pi, 205 | 'final_delta_theta': final_delta_thetas*180/np.pi, 206 | 'final_pos': final_final_poses[:, :3]*100, 207 | } 208 | 209 | if render: 210 | videos = [] 211 | for video_idx, video in enumerate(imgs): 212 | with imageio.get_writer(os.path.join(save_dir, '%d_%d' % (object_idx, gripper_idx), '%d.mp4' % video_idx), fps=20) as writer: 213 | for frame_idx, frame in enumerate(video): 214 | cv2.drawContours(frame, [contours[video_idx]], -1, (38, 80, 115), 1) 215 | writer.append_data(frame.astype(np.uint8)) 216 | videos.append(os.path.join(save_dir, '%d_%d' % (object_idx, gripper_idx), '%d.mp4' % video_idx)) 217 | return gripper_img_path, metrics, os.path.join(save_dir, '%d_%d_profile.png' % (object_idx, gripper_idx)), os.path.join(save_dir, '%d_%d_profile_x.png' % (object_idx, gripper_idx)), os.path.join(save_dir, '%d_%d_profile_y.png' % (object_idx, gripper_idx)), os.path.join(save_dir, '%d_%d_final.png' % (object_idx, gripper_idx)), videos, gripper_idx, 
object_order_idx, save_gripper_dir 218 | elif render_last: 219 | last_imgs = [] 220 | for video_idx, video in enumerate(imgs): 221 | img_last = video[-1].copy() 222 | cv2.drawContours(img_last, [contours[video_idx]], -1, (38, 80, 115), 1) 223 | cv2.imwrite(os.path.join(save_dir, '%d_%d' % (object_idx, gripper_idx), '%d.png' % video_idx), img_last) 224 | last_imgs.append(os.path.join(save_dir, '%d_%d' % (object_idx, gripper_idx), '%d.png' % video_idx)) 225 | return gripper_img_path, metrics, os.path.join(save_dir, '%d_%d_profile.png' % (object_idx, gripper_idx)), os.path.join(save_dir, '%d_%d_profile_x.png' % (object_idx, gripper_idx)), os.path.join(save_dir, '%d_%d_profile_y.png' % (object_idx, gripper_idx)), os.path.join(save_dir, '%d_%d_final.png' % (object_idx, gripper_idx)), last_imgs, gripper_idx, object_order_idx, save_gripper_dir 226 | else: 227 | return gripper_img_path, metrics, os.path.join(save_dir, '%d_%d_profile.png' % (object_idx, gripper_idx)), os.path.join(save_dir, '%d_%d_profile_x.png' % (object_idx, gripper_idx)), os.path.join(save_dir, '%d_%d_profile_y.png' % (object_idx, gripper_idx)), os.path.join(save_dir, '%d_%d_final.png' % (object_idx, gripper_idx)), gripper_idx, object_order_idx, save_gripper_dir 228 | 229 | def sim_test_batch_3d(ctrlpts_y, object_names, save_dir, num_cpus=32, num_rot=360, ori_range=[-1.0, 1.0], render=True, render_last=False): 230 | model_root = os.path.join(save_dir, 'sim_model') 231 | num_gripper = ctrlpts_y.shape[0] 232 | ray.init(num_cpus=num_cpus, log_to_driver=False) 233 | ray_tasks = [] 234 | for i, object_name in enumerate(object_names): 235 | for idx, p_y in enumerate(ctrlpts_y): 236 | p_y = p_y.reshape(-1) 237 | p_y = p_y * 0.05 - 0.05 # scale p_y from [-1, 1] to [-0.1, 0] 238 | ray_tasks.append(sim_test.remote(ctrlpts=p_y, object_name=object_name, gripper_idx=idx, object_idx=i, object_order_idx=i, model_root=model_root, save_dir=save_dir, gui=False, render=render, num_rot=num_rot, ori_range=ori_range, render_last=render_last)) 239 | gripper_imgs, metrics, profiles, profiles_x, profiles_y, finals, videos, save_gripper_dirs = {}, {}, {}, {}, {}, {}, {}, {} 240 | while len(ray_tasks) > 0: 241 | ready, ray_tasks = ray.wait(ray_tasks, num_returns=1) 242 | try: 243 | if render or render_last: 244 | gripper_img_path, metric, profile, profile_x, profile_y, final, video, gripper_idx, object_idx, save_gripper_dir = ray.get(ready[0]) 245 | gripper_imgs[object_idx * num_gripper + gripper_idx] = gripper_img_path 246 | metrics[object_idx * num_gripper + gripper_idx] = metric 247 | profiles[object_idx * num_gripper + gripper_idx] = profile 248 | profiles_x[object_idx * num_gripper + gripper_idx] = profile_x 249 | profiles_y[object_idx * num_gripper + gripper_idx] = profile_y 250 | finals[object_idx * num_gripper + gripper_idx] = final 251 | videos[object_idx * num_gripper + gripper_idx] = video 252 | save_gripper_dirs[object_idx * num_gripper + gripper_idx] = save_gripper_dir 253 | else: 254 | gripper_img_path, metric, profile, profile_x, profile_y, final, gripper_idx, object_idx, save_gripper_dir = ray.get(ready[0]) 255 | gripper_imgs[object_idx * num_gripper + gripper_idx] = gripper_img_path 256 | metrics[object_idx * num_gripper + gripper_idx] = metric 257 | profiles[object_idx * num_gripper + gripper_idx] = profile 258 | profiles_x[object_idx * num_gripper + gripper_idx] = profile_x 259 | profiles_y[object_idx * num_gripper + gripper_idx] = profile_y 260 | finals[object_idx * num_gripper + gripper_idx] = final 261 | save_gripper_dirs[object_idx 
* num_gripper + gripper_idx] = save_gripper_dir 262 | except Exception as e: 263 | print(e) 264 | continue 265 | ray.shutdown() 266 | gripper_imgs = list(map(lambda x: x[1], sorted(gripper_imgs.items(), key=lambda x: x[0]))) 267 | metrics = list(map(lambda x: x[1], sorted(metrics.items(), key=lambda x: x[0]))) 268 | profiles = list(map(lambda x: x[1], sorted(profiles.items(), key=lambda x: x[0]))) 269 | profiles_x = list(map(lambda x: x[1], sorted(profiles_x.items(), key=lambda x: x[0]))) 270 | profiles_y = list(map(lambda x: x[1], sorted(profiles_y.items(), key=lambda x: x[0]))) 271 | finals = list(map(lambda x: x[1], sorted(finals.items(), key=lambda x: x[0]))) 272 | save_gripper_dirs = list(map(lambda x: x[1], sorted(save_gripper_dirs.items(), key=lambda x: x[0]))) 273 | if render or render_last: 274 | videos = list(map(lambda x: x[1], sorted(videos.items(), key=lambda x: x[0]))) 275 | return gripper_imgs, metrics, profiles, profiles_x, profiles_y, finals, videos, save_gripper_dirs 276 | else: 277 | return gripper_imgs, metrics, profiles, profiles_x, profiles_y, finals, [], save_gripper_dirs -------------------------------------------------------------------------------- /dynamics/sim_test_mj.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | from os.path import join as pjoin 5 | BASEPATH = os.path.dirname(__file__) 6 | sys.path.insert(0, BASEPATH) 7 | sys.path.insert(0, pjoin(BASEPATH, '..')) 8 | from typing import Optional 9 | 10 | import mujoco 11 | from transforms3d import euler, quaternions 12 | import numpy as np 13 | import subprocess 14 | from mujoco import viewer 15 | import ray 16 | import subprocess 17 | import time 18 | import imageio 19 | import cv2 20 | 21 | from assets.finger_sampler import generate_xml, generate_scene_xml, save_gripper 22 | from assets.icon_process import extract_contours 23 | from dynamics.utils import continuous_signed_delta 24 | from dynamics.utils import visualize_profile, visualize_finals, visualize_ctrlpts 25 | from sim.sim_2d import OBJECT_DIR, prepare_icon_object 26 | 27 | threshold = np.array([0.03, 0.002, 0.003]) 28 | # map segments shape [128, 128] to colors [128, 128, 3] 29 | color_map = np.asarray([ 30 | [155, 184, 205], 31 | [238, 199, 89], 32 | [177, 195, 129], 33 | # [255, 255, 255], 34 | [255, 247, 212], 35 | ], dtype=np.uint8) 36 | color_maps = np.concatenate([color_map for _ in range(32)], axis=0) 37 | 38 | def compute_collision(mesh_path, num_retries: int = 2): 39 | """ 40 | Computes the convex decomposition of a mesh using v-hacd. 41 | Convention: the input mesh is assumed to be in the same folder as the output mesh, 42 | with only the name change from `xyz.obj` to `xyz_collision.obj`. 43 | 44 | V-HACD help: 45 | ``` 46 | -h : Maximum number of output convex hulls. Default is 32 47 | -r : Total number of voxels to use. Default is 100,000 48 | -e : Volume error allowed as a percentage. Default is 1%. Valid range is 0.001 to 10 49 | -d : Maximum recursion depth. Default value is 10. 50 | -s : Whether or not to shrinkwrap output to source mesh. Default is true. 51 | -f : Fill mode. Default is 'flood', also 'surface' and 'raycast' are valid. 52 | -v : Maximum number of vertices in the output convex hull. Default value is 64 53 | -a : Whether or not to run asynchronously. Default is 'true' 54 | -l : Minimum size of a voxel edge. Default value is 2 voxels. 55 | -p : If false, splits hulls in the middle. If true, tries to find optimal split plane location. 
False by default. 56 | -o : Export the convex hulls as a series of wavefront OBJ files, STL files, or a single USDA. 57 | -g : If set to false, no logging will be displayed. 58 | ``` 59 | """ 60 | COMMAND = [ 61 | "TestVHACD", 62 | mesh_path, 63 | "-r", 64 | "100000", 65 | "-o", 66 | "obj", 67 | "-g", 68 | "false", 69 | "-h", 70 | "16", 71 | "-v", 72 | "32", 73 | ] 74 | output: Optional[subprocess.CompletedProcess] = None 75 | assert num_retries > 1 76 | for _ in range(num_retries): 77 | try: 78 | output = subprocess.run(COMMAND, check=True) 79 | except subprocess.CalledProcessError as e: 80 | print("V-HACD failed to run on %s, retrying..." % mesh_path) 81 | continue 82 | if output is None or output.returncode != 0: 83 | raise RuntimeError("V-HACD failed to run on %s" % mesh_path) 84 | 85 | def prepare_finger(idx: int, ctrlpts, model_root: str): 86 | save_gripper_dir = os.path.join(model_root, 'grippers', str(idx)) 87 | if os.path.exists(save_gripper_dir): 88 | return save_gripper_dir 89 | else: 90 | save_gripper( 91 | ctrlpts[:ctrlpts.shape[0]//2, 0], 92 | ctrlpts[:ctrlpts.shape[0]//2, 1], 93 | ctrlpts[ctrlpts.shape[0]//2:, 1], 94 | width=0.03, 95 | height=0.02, 96 | num_points=200, 97 | save_gripper_dir=save_gripper_dir, 98 | ) 99 | meshl_path = os.path.join(save_gripper_dir, "fingerl.obj") 100 | compute_collision(meshl_path) 101 | meshr_path = os.path.join(save_gripper_dir, "fingerr.obj") 102 | compute_collision(meshr_path) 103 | generate_xml(len(glob.glob(os.path.join(save_gripper_dir, "fingerl0*.obj"))), len(glob.glob(os.path.join(save_gripper_dir, "fingerr0*.obj"))), idx, os.path.join(model_root, 'gripper_%d.xml' % idx)) 104 | return save_gripper_dir 105 | 106 | # @profile 107 | @ray.remote(num_cpus=2) 108 | def sim_test(ctrlpts, object_image, gripper_idx: int=0, object_idx: int=0, object_order_idx: int=0, model_root: str="assets", save_dir: str="sim", gui: bool = False, render: bool = True, num_rot: int = 360, ori_range: list = [-1.0, 1.0], render_last: bool = True): 109 | save_gripper_dir = prepare_finger(gripper_idx, ctrlpts, model_root) 110 | prepare_icon_object(object_idx, object_image, model_root) 111 | 112 | scene_path = os.path.join(model_root, 'scene_%d_%d.xml' % (object_idx, gripper_idx)) 113 | generate_scene_xml(object_idx, gripper_idx, scene_path) 114 | 115 | while not (os.path.exists(os.path.join(model_root, 'object_%d.xml' % object_idx)) and os.path.getsize(os.path.join(model_root, 'object_%d.xml' % object_idx))>0 and os.path.exists(os.path.join(model_root, 'gripper_%d.xml' % gripper_idx)) and os.path.getsize(os.path.join(model_root, 'gripper_%d.xml' % gripper_idx))>0): 116 | time.sleep(0.1) 117 | 118 | model = mujoco.MjModel.from_xml_path(scene_path) 119 | data = mujoco.MjData(model) 120 | reset_qpos = data.qpos.copy() 121 | reset_qvel = data.qvel.copy() 122 | reset_force = data.qfrc_applied.copy() 123 | handle = viewer.launch_passive(model, data) if gui else None 124 | obj_root_idx = [model.joint(jointid).name for jointid in range(model.njnt)].index("object_root") 125 | obj_jnt = model.joint(obj_root_idx) 126 | assert obj_jnt.type == 0 # freejoint 127 | 128 | left_grip_idx = [model.joint(jointid).name for jointid in range(model.njnt)].index("left_grip") 129 | left_grip_jnt = model.joint(left_grip_idx) 130 | right_grip_idx = [model.joint(jointid).name for jointid in range(model.njnt)].index("right_grip") 131 | right_grip_jnt = model.joint(right_grip_idx) 132 | 133 | if render or render_last: 134 | renderer = mujoco.Renderer(model, 128, 128) 135 | 
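# Renderer takes (model, height, width); with segmentation rendering enabled on
# the next line, renderer.render() returns an integer image whose channel 0 is a
# per-pixel object id and channel 1 the object type, which is why the code below
# keeps only [..., 0] as `seg` and maps it to colors via color_maps.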
renderer.enable_segmentation_rendering() 136 | camera = mujoco.MjvCamera() 137 | camera.lookat[:] = [0.0, 0.0, 0.0] 138 | camera.distance = 0.45 139 | camera.azimuth = 180 140 | camera.elevation = -90 141 | 142 | z_rots = np.linspace(ori_range[0], ori_range[1], num_rot) * np.pi + np.pi 143 | init_poses = np.zeros((len(z_rots), 7)) 144 | final_poses = np.zeros((len(z_rots), 7)) 145 | # imgs = np.zeros((len(z_rots) // 36, 200, 128, 128, 3)) 146 | segs = np.zeros((max(1, len(z_rots) // 36), 400, 128, 128), dtype=np.int16) 147 | final_final_poses = np.zeros((len(z_rots), 7)) 148 | for k, z_rot in enumerate(z_rots): 149 | data.qpos[:] = reset_qpos[:] 150 | data.qvel[:] = reset_qvel[:] 151 | data.qfrc_applied[:] = reset_force 152 | data.qpos[obj_jnt.qposadr[0] : obj_jnt.qposadr[0] + 3] = [0, 0, 0,] 153 | data.qpos[ 154 | obj_jnt.qposadr[0] + 3 : obj_jnt.qposadr[0] + 7 155 | ] = euler.euler2quat(0, 0, z_rot) 156 | init_poses[k, :] = data.qpos[ 157 | obj_jnt.qposadr[0] : obj_jnt.qposadr[0] + 7 158 | ] 159 | data.ctrl[0] = 0.2 160 | data.ctrl[1] = -0.2 161 | for t in range(8000): 162 | if handle is not None and t % 10 == 0: 163 | handle.sync() 164 | input(f"Press Enter to continue..., {t}") 165 | if t % 200 == 0 and t > 0: 166 | # reset the positions velocities forces of the gripper 167 | data.qpos[left_grip_jnt.qposadr[0]] = reset_qpos[left_grip_jnt.qposadr[0]] 168 | data.qpos[right_grip_jnt.qposadr[0]] = reset_qpos[right_grip_jnt.qposadr[0]] 169 | data.qvel[:] = reset_qvel[:] 170 | data.qfrc_applied[:] = reset_force[:] 171 | mujoco.mj_step(model, data) 172 | if (render and k % 36 == 0 and t % 20 == 0) or (render_last and k % 36 == 0 and (t == 1999 or t == 0)): 173 | renderer.update_scene(data, camera) 174 | # img = renderer.render() 175 | seg = renderer.render()[..., 0] 176 | segs[k // 36, t // 20, ...] = seg 177 | # img = color_maps[seg] 178 | # imgs[k // 36, t // 20, ...] 
= img 179 | if t == 200: 180 | final_poses[k, :] = data.qpos[ 181 | obj_jnt.qposadr[0] : obj_jnt.qposadr[0] + 7 182 | ] 183 | final_final_poses[k, :] = data.qpos[ 184 | obj_jnt.qposadr[0] : obj_jnt.qposadr[0] + 7 185 | ] 186 | 187 | save_data = { 188 | "ctrlpts": ctrlpts, 189 | "obj_pos": init_poses[..., :3].reshape((-1, 3)), 190 | "obj_theta": np.asarray([quaternions.quat2axangle(quat)[-1] for quat in init_poses[..., 3:].reshape((-1, 4))], dtype=np.float32), 191 | "delta_theta": np.asarray([continuous_signed_delta(quaternions.quat2axangle(last_quat)[-1], quaternions.quat2axangle(quat)[-1]) for last_quat, quat in zip(init_poses[..., 3:].reshape((-1, 4)), final_poses[..., 3:].reshape((-1, 4)))], dtype=np.float32), # shape: (num_rot,) 192 | "delta_pos": (final_poses[..., :3] - init_poses[..., :3]).reshape((-1, 3)), 193 | } 194 | os.makedirs(os.path.join(save_dir, '%d_%d' % (object_idx, gripper_idx)), exist_ok=True) 195 | np.savez_compressed(os.path.join(save_dir, "%d_%d.npz" % (object_idx, gripper_idx)), save_data) 196 | # visualize and save profile 197 | visualize_ctrlpts(ctrlpts, os.path.join(save_dir, '%d_%d_ctrlpts.png' % (object_idx, gripper_idx))) 198 | profile = np.asarray([1 if delta_theta > threshold[0] else -1 if delta_theta < -threshold[0] else 0 for delta_theta in save_data['delta_theta']]) 199 | profile_x = np.asarray([1 if delta_pos[0] > threshold[1] else -1 if delta_pos[0] < -threshold[1] else 0 for delta_pos in save_data['delta_pos']]) 200 | profile_y = np.asarray([1 if delta_pos[1] > threshold[2] else -1 if delta_pos[1] < -threshold[2] else 0 for delta_pos in save_data['delta_pos']]) 201 | visualize_profile(profile, os.path.join(save_dir, '%d_%d_profile.png' % (object_idx, gripper_idx)), ori_range=ori_range) 202 | visualize_profile(profile_x, os.path.join(save_dir, '%d_%d_profile_x.png' % (object_idx, gripper_idx)), ori_range=ori_range) 203 | visualize_profile(profile_y, os.path.join(save_dir, '%d_%d_profile_y.png' % (object_idx, gripper_idx)), ori_range=ori_range) 204 | final_thetas = np.asarray([quaternions.quat2axangle(quat)[-1] for quat in final_final_poses[:, 3:].reshape((-1, 4))], dtype=np.float32) 205 | final_delta_thetas = np.asarray([continuous_signed_delta(init_theta, final_theta) for final_theta, init_theta in zip(final_thetas, save_data['obj_theta'])], dtype=np.float32) 206 | visualize_finals(final_thetas, os.path.join(save_dir, '%d_%d_final.png' % (object_idx, gripper_idx))) 207 | # columns = ['video', 'obj_theta', 'delta_pos', 'delta_theta', 'final_theta'] 208 | # table = wandb.Table(columns=columns) 209 | metrics = { 210 | 'delta_theta': save_data['delta_theta']*180/np.pi, 211 | 'delta_pos': save_data['delta_pos']*100, 212 | 'profile': profile + 1, 213 | 'profile_x': profile_x + 1, 214 | 'profile_y': profile_y + 1, 215 | 'final_theta': final_thetas*180/np.pi, 216 | 'final_delta_theta': final_delta_thetas*180/np.pi, 217 | 'final_pos': final_final_poses[:, :3]*100, 218 | } 219 | if render: 220 | videos = [] 221 | for video_idx, video in enumerate(segs): 222 | with imageio.get_writer(os.path.join(save_dir, '%d_%d' % (object_idx, gripper_idx), '%d.mp4' % video_idx), fps=20) as writer: 223 | init_contour = None 224 | for frame_idx, frame in enumerate(video): 225 | img = color_maps[frame] 226 | if frame_idx == 0: 227 | img_cp = img.copy() 228 | img_cp[frame % 4 != 0, :] = 255 229 | init_contour = extract_contours(img_cp, num_points=100, rescale=False) 230 | assert init_contour.shape == (100, 2) 231 | cv2.drawContours(img, [init_contour], -1, (38, 80, 115), 1) 232 | 
writer.append_data(img.astype(np.uint8)) 233 | videos.append(os.path.join(save_dir, '%d_%d' % (object_idx, gripper_idx), '%d.mp4' % video_idx)) 234 | return os.path.join(save_dir, '%d_%d_ctrlpts.png' % (object_idx, gripper_idx)), metrics, os.path.join(save_dir, '%d_%d_profile.png' % (object_idx, gripper_idx)), os.path.join(save_dir, '%d_%d_profile_x.png' % (object_idx, gripper_idx)), os.path.join(save_dir, '%d_%d_profile_y.png' % (object_idx, gripper_idx)), os.path.join(save_dir, '%d_%d_final.png' % (object_idx, gripper_idx)), videos, gripper_idx, object_order_idx, save_gripper_dir 235 | elif render_last: 236 | last_imgs = [] 237 | for seg_idx, seg in enumerate(segs): 238 | img_cp = color_maps[seg[0]].copy() 239 | img_cp[seg[0] % 4 != 0, :] = 255 240 | init_contour = extract_contours(img_cp, num_points=100, rescale=False) 241 | img = color_maps[seg[-1]] 242 | cv2.drawContours(img, [init_contour], -1, (38, 80, 115), 1) 243 | cv2.imwrite(os.path.join(save_dir, '%d_%d' % (object_idx, gripper_idx), '%d.png' % seg_idx), img) 244 | last_imgs.append(os.path.join(save_dir, '%d_%d' % (object_idx, gripper_idx), '%d.png' % seg_idx)) 245 | return os.path.join(save_dir, '%d_%d_ctrlpts.png' % (object_idx, gripper_idx)), metrics, os.path.join(save_dir, '%d_%d_profile.png' % (object_idx, gripper_idx)), os.path.join(save_dir, '%d_%d_profile_x.png' % (object_idx, gripper_idx)), os.path.join(save_dir, '%d_%d_profile_y.png' % (object_idx, gripper_idx)), os.path.join(save_dir, '%d_%d_final.png' % (object_idx, gripper_idx)), last_imgs, gripper_idx, object_order_idx, save_gripper_dir 246 | else: 247 | return os.path.join(save_dir, '%d_%d_ctrlpts.png' % (object_idx, gripper_idx)), metrics, os.path.join(save_dir, '%d_%d_profile.png' % (object_idx, gripper_idx)), os.path.join(save_dir, '%d_%d_profile_x.png' % (object_idx, gripper_idx)), os.path.join(save_dir, '%d_%d_profile_y.png' % (object_idx, gripper_idx)), os.path.join(save_dir, '%d_%d_final.png' % (object_idx, gripper_idx)), gripper_idx, object_order_idx, save_gripper_dir 248 | 249 | def sim_test_batch(pts_y, object_ids, save_dir, num_cpus=32, num_rot=360, ori_range=[-1.0, 1.0], render=True, render_last=False): 250 | model_root = os.path.join(save_dir, 'sim_model') 251 | num_gripper = pts_y.shape[0] 252 | object_images = np.load(OBJECT_DIR, allow_pickle=True).item()['image'][object_ids] 253 | ray.init(num_cpus=num_cpus, log_to_driver=False) 254 | ray_tasks = [] 255 | for i, obj_idx in enumerate(object_ids): 256 | for idx, p_y in enumerate(pts_y): 257 | p_x = np.linspace(-0.12, 0.12, p_y.shape[0] // 2) 258 | p_x = np.concatenate([p_x, p_x], axis=0) 259 | p_x = np.expand_dims(p_x, axis=-1) 260 | # scale p_y from [-1,1] to [-0.045,0.015] 261 | p_y = p_y * 0.03 - 0.015 262 | pts = np.concatenate([p_x, p_y], axis=-1) 263 | ray_tasks.append(sim_test.remote(ctrlpts=pts, object_image=object_images[i].transpose((1, 2, 0)), gripper_idx=idx, object_idx=obj_idx, object_order_idx=i, model_root=model_root, save_dir=save_dir, gui=False, render=render, num_rot=num_rot, ori_range=ori_range, render_last=render_last)) 264 | imgs, metrics, profiles, profiles_x, profiles_y, finals, videos, save_gripper_dirs = {}, {}, {}, {}, {}, {}, {}, {} 265 | while len(ray_tasks) > 0: 266 | ready, ray_tasks = ray.wait(ray_tasks, num_returns=1) 267 | try: 268 | if render or render_last: 269 | img, metric, profile, profile_x, profile_y, final, video, gripper_idx, object_idx, save_gripper_dir = ray.get(ready[0]) 270 | videos[object_idx * num_gripper + gripper_idx] = video 271 | else: 272 | img, 
metric, profile, profile_x, profile_y, final, gripper_idx, object_idx, save_gripper_dir = ray.get(ready[0]) 273 | imgs[object_idx * num_gripper + gripper_idx] = img 274 | metrics[object_idx * num_gripper + gripper_idx] = metric 275 | profiles[object_idx * num_gripper + gripper_idx] = profile 276 | profiles_x[object_idx * num_gripper + gripper_idx] = profile_x 277 | profiles_y[object_idx * num_gripper + gripper_idx] = profile_y 278 | finals[object_idx * num_gripper + gripper_idx] = final 279 | save_gripper_dirs[object_idx * num_gripper + gripper_idx] = save_gripper_dir 280 | except Exception as e: 281 | print(e) 282 | continue 283 | ray.shutdown() 284 | imgs = list(map(lambda x: x[1], sorted(imgs.items(), key=lambda x: x[0]))) 285 | metrics = list(map(lambda x: x[1], sorted(metrics.items(), key=lambda x: x[0]))) 286 | profiles = list(map(lambda x: x[1], sorted(profiles.items(), key=lambda x: x[0]))) 287 | profiles_x = list(map(lambda x: x[1], sorted(profiles_x.items(), key=lambda x: x[0]))) 288 | profiles_y = list(map(lambda x: x[1], sorted(profiles_y.items(), key=lambda x: x[0]))) 289 | finals = list(map(lambda x: x[1], sorted(finals.items(), key=lambda x: x[0]))) 290 | save_gripper_dirs = list(map(lambda x: x[1], sorted(save_gripper_dirs.items(), key=lambda x: x[0]))) 291 | if render or render_last: 292 | videos = list(map(lambda x: x[1], sorted(videos.items(), key=lambda x: x[0]))) 293 | return imgs, metrics, profiles, profiles_x, profiles_y, finals, videos, save_gripper_dirs 294 | else: 295 | return imgs, metrics, profiles, profiles_x, profiles_y, finals, [], save_gripper_dirs --------------------------------------------------------------------------------
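A minimal usage sketch for the `compute_collision` wrapper above, assuming the `TestVHACD` binary is on `PATH` and that, as the glob in `prepare_finger` implies, the convex-hull pieces land next to the input mesh as `fingerl000.obj`, `fingerl001.obj`, ...; the mesh path below is hypothetical:

# vhacd_usage_sketch.py (illustrative only, not part of the repository)
import glob
import os

from dynamics.sim_test_mj import compute_collision

mesh_path = "outputs/grippers/0/fingerl.obj"  # hypothetical finger mesh exported by save_gripper
compute_collision(mesh_path, num_retries=2)   # shells out to the TestVHACD CLI

# prepare_finger() counts the hull pieces with this same pattern when writing the gripper MJCF.
pieces = sorted(glob.glob(os.path.join(os.path.dirname(mesh_path), "fingerl0*.obj")))
print("%d convex pieces" % len(pieces))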
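And a sketch of driving `sim_test_batch` end to end, assuming the 2D icon setting (`ctrlpts_dim=14`, so `pts_y` has shape `(num_grippers, 14, 1)` with values in `[-1, 1]`), that `OBJECT_DIR` in `sim/sim_2d.py` points at the Icons-50 `.npy` dump, and that the output directory below is hypothetical:

# sim_batch_sketch.py (illustrative only, not part of the repository)
import numpy as np

from dynamics.sim_test_mj import sim_test_batch

rng = np.random.default_rng(0)
pts_y = rng.uniform(-1.0, 1.0, size=(8, 14, 1))  # 8 random finger profiles, normalized control-point heights
object_ids = np.array([0, 1])                    # first two icons in the Icons-50 dump

imgs, metrics, profiles, profiles_x, profiles_y, finals, last_imgs, gripper_dirs = sim_test_batch(
    pts_y,
    object_ids,
    save_dir="outputs/sim_debug",  # hypothetical output directory
    num_cpus=8,
    num_rot=36,                    # coarser orientation sweep than the default 360
    render=False,
    render_last=True,              # save a before/after overlay instead of full videos
)
# metrics[i]['delta_theta'] holds per-orientation rotation changes in degrees, shape (num_rot,)
print(len(metrics), metrics[0]["delta_theta"].shape)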