├── .gitignore
├── assets
│   ├── dyn_0000.pt
│   ├── teaser.png
│   ├── real_data.pt
│   └── open_manipulator
│       ├── meshes
│       │   ├── chain_link1.stl
│       │   ├── chain_link2.stl
│       │   ├── chain_link3.stl
│       │   ├── chain_link4.stl
│       │   ├── chain_link5.stl
│       │   ├── chain_link_grip_l.stl
│       │   └── chain_link_grip_r.stl
│       └── open_manipulator_joint2_only_v2.urdf
├── requirements.txt
├── configs
│   └── hard_ball.yaml
├── eval.py
├── train.py
├── README.md
├── render_usd.py
└── diffsim.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
experiments/
outputs/
--------------------------------------------------------------------------------
/assets/dyn_0000.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MediosZ/WarpDiffRobot/HEAD/assets/dyn_0000.pt
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MediosZ/WarpDiffRobot/HEAD/assets/teaser.png
--------------------------------------------------------------------------------
/assets/real_data.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MediosZ/WarpDiffRobot/HEAD/assets/real_data.pt
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
warp-lang
torch
hydra-core
omegaconf
tqdm
urdfpy
networkx
matplotlib
usd-core
--------------------------------------------------------------------------------
/assets/open_manipulator/meshes/chain_link1.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MediosZ/WarpDiffRobot/HEAD/assets/open_manipulator/meshes/chain_link1.stl
--------------------------------------------------------------------------------
/assets/open_manipulator/meshes/chain_link2.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MediosZ/WarpDiffRobot/HEAD/assets/open_manipulator/meshes/chain_link2.stl
--------------------------------------------------------------------------------
/assets/open_manipulator/meshes/chain_link3.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MediosZ/WarpDiffRobot/HEAD/assets/open_manipulator/meshes/chain_link3.stl
--------------------------------------------------------------------------------
/assets/open_manipulator/meshes/chain_link4.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MediosZ/WarpDiffRobot/HEAD/assets/open_manipulator/meshes/chain_link4.stl
--------------------------------------------------------------------------------
/assets/open_manipulator/meshes/chain_link5.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MediosZ/WarpDiffRobot/HEAD/assets/open_manipulator/meshes/chain_link5.stl
--------------------------------------------------------------------------------
/assets/open_manipulator/meshes/chain_link_grip_l.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MediosZ/WarpDiffRobot/HEAD/assets/open_manipulator/meshes/chain_link_grip_l.stl
--------------------------------------------------------------------------------
/assets/open_manipulator/meshes/chain_link_grip_r.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MediosZ/WarpDiffRobot/HEAD/assets/open_manipulator/meshes/chain_link_grip_r.stl
--------------------------------------------------------------------------------
/configs/hard_ball.yaml:
--------------------------------------------------------------------------------
trajectory: 'robotis_2_hard_ball'
ckpt: null
ckpt_idx: -2
urdf: 'open_manipulator/open_manipulator_joint2_only_v2'
output_obj: False
training:
  train_iters: 15
  train_rate: 0.01
  mass_diff_config: 'cup_diff_none'
  load_file_path_overwrite: './assets/real_data.pt'
  compare_indices_predictions: [-1, null, 1]
  compare_indices_targets: [35, 36, 1]
  loss_type: null
sim:
  mass_diff: null
  initialization_filename: './assets/dyn_0000.pt'
  gravity: -9.81
  frame_dt: 0.0166666666
  density: -1
  armature: 0.01
  episode_duration: 0.6
  sim_substeps: 5
  collapse_fixed_joints: False
  ignore_inertial_definitions: False
  parse_visuals_as_colliders: True
  modify_object_type: 'hard_ball'
  update_mass_matrix_every: -1
  is_colliding: False
  requires_grad: True
render:
  every_n_frame: 1
--------------------------------------------------------------------------------
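Every field in `hard_ball.yaml` is a Hydra config value, so it can also be overridden from the command line with dotted keys, the same mechanism as the `ckpt=...` override shown in the README. For example (the values here are illustrative, not defaults):

```bash
python train.py --config-name hard_ball training.train_iters=30 training.train_rate=0.005 sim.sim_substeps=10
```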
/eval.py:
--------------------------------------------------------------------------------
import hydra
from omegaconf import DictConfig
import warp as wp
import time

from diffsim import Sim, generate_traj

@hydra.main(config_path="./configs", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    # Access parameters from the configuration file
    traj_list = generate_traj(cfg.trajectory)
    mode = "dataset" if cfg.ckpt is None else "test"

    experiment = Sim(
        cfg,
        traj_list,
        device=wp.get_preferred_device(),
        verbose=True,
        mass_diff=None,
        mode=mode,
    )

    total_start_time = time.time()
    experiment.forward()
    total_elapsed_time = time.time() - total_start_time
    print(f"Total time: {total_elapsed_time:.4f} seconds")

    if mode == "test":
        experiment.load_state()
        experiment.compute_loss()
        experiment.save_testing()
    experiment.render()
    experiment.save_state()

    if experiment.renderer:
        experiment.renderer.save()

if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import hydra
from omegaconf import DictConfig
from tqdm.autonotebook import tqdm, trange
import warp as wp
import time

from diffsim import generate_traj, generate_mass_diff, Sim

@hydra.main(config_path="./configs", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    # Access parameters from the configuration file
    traj_list = generate_traj(cfg.trajectory)
    mode = 'train'
    mass_diff = generate_mass_diff(cfg.training.mass_diff_config)

    experiment = Sim(cfg, traj_list, device=wp.get_preferred_device(), verbose=True, mass_diff=mass_diff, mode=mode)
    experiment.load_state()

    # Start timing before the loop
    total_start_time = time.time()
    for epoch in trange(cfg.training.train_iters):
        experiment.step()
        # experiment.render()
        tqdm.write('[{}]'.format(experiment.msg))

    # Calculate the total elapsed time after the loop ends
    total_elapsed_time = time.time() - total_start_time
    print(f'Total time for {cfg.training.train_iters} iterations: {total_elapsed_time:.4f} seconds')

    experiment.save_state()
    if experiment.renderer:
        experiment.renderer.save()
    experiment.plot_loss()
    experiment.save_training()

if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Differentiable robot-object interaction for learning object properties

![teaser figure](./assets/teaser.png)

This repository is the official implementation of the paper:

> **[Learning Object Properties Using Robot Proprioception via Differentiable Robot-Object Interaction](https://arxiv.org/abs/2410.03920)**\
> [Peter Yichen Chen](https://peterchencyc.com), [Chao Liu](https://chaoliu.tech), [Pingchuan Ma](https://pingchuan.ma), [John Eastman](http://redsweatshirt.github.io), [Daniela Rus](https://danielarus.csail.mit.edu), [Dylan Randle](https://dylanrandle.github.io), [Yuri Ivanov](https://www.linkedin.com/in/yivanov/), [Wojciech Matusik](https://cdfg.mit.edu/wojciech)\
> MIT CSAIL, Amazon Robotics, University of British Columbia\
> **International Conference on Robotics and Automation (ICRA), 2025**

A big shoutout to the [Nvidia Warp team](https://nvidia.github.io/warp/)! Warp integrates effortlessly with Torch, streamlining the use of differentiable simulation for Torch-based optimization workflows.

## Installation

Install the required packages first:

```
pip install -r requirements.txt
```

For visualization, install these optional packages:

- bpy
- blendertoolbox

and these tools:

- Blender
- ffmpeg

## Usage

To calibrate object properties, use the following command:

```bash
python train.py --config-name hard_ball
```

To evaluate the calibrated object property, use the following command:

```bash
python eval.py --config-name hard_ball ckpt=experiments/log/robotis_2_hard_ball/open_manipulator/open_manipulator_joint2_only_v2/train/training_stats.pt ckpt_idx=8
```

To visualize the robot, use the following command:

```bash
python render_usd.py --usd-path experiments/log/robotis_2_hard_ball/open_manipulator/open_manipulator_joint2_only_v2/test/test_ckpt_idx_0008.usd
```
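The `ckpt_idx` argument selects one iteration from the optimization history that `train.py` stores in `training_stats.pt`. A minimal sketch (not part of the repository; it assumes the default output path from the command above) for locating the lowest-loss iteration:

```python
import torch

# training_stats.pt holds per-iteration "losses" (floats) and "masses" (tensors)
stats = torch.load(
    "experiments/log/robotis_2_hard_ball/open_manipulator/"
    "open_manipulator_joint2_only_v2/train/training_stats.pt",
    weights_only=False,
)
best = min(range(len(stats["losses"])), key=stats["losses"].__getitem__)
print("best ckpt_idx:", best, "loss:", stats["losses"][best])
print("calibrated masses:", stats["masses"][best])
```

The printed index can then be passed to `eval.py` via the `ckpt_idx` override.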
## Citation

If this work helps you, please consider citing the paper below.

```
@misc{chen2025learningobjectpropertiesusing,
      title={Learning Object Properties Using Robot Proprioception via Differentiable Robot-Object Interaction},
      author={Peter Yichen Chen and Chao Liu and Pingchuan Ma and John Eastman and Daniela Rus and Dylan Randle and Yuri Ivanov and Wojciech Matusik},
      year={2025},
      eprint={2410.03920},
      archivePrefix={arXiv},
      primaryClass={cs.RO},
      url={https://arxiv.org/abs/2410.03920},
}
```
--------------------------------------------------------------------------------
/render_usd.py:
--------------------------------------------------------------------------------
import os
import argparse
import bpy
import blendertoolbox as bt
import pathlib

parser = argparse.ArgumentParser(description="Render a USD file with Blender")
parser.add_argument("--usd-path", required=True, help="Path to the USD file", nargs='*')
args = parser.parse_args()

imgRes_x, imgRes_y = 1920, 1080
numSamples = 16
exposure = 1.5
use_GPU = True
bt.blenderInit(imgRes_x, imgRes_y, numSamples, exposure, use_GPU)
cwd = pathlib.Path(__file__).parent.absolute()

for path_to_usd in args.usd_path:
    bpy.ops.wm.usd_import(filepath=(cwd / path_to_usd).as_posix())

bt.invisibleGround(location=(0, 0, -.5), shadowBrightness=0.9)

color_dict = {
    "blue": [152, 199, 255, 255],
    "green": [165, 221, 144, 255],
    "red": [255, 154, 156, 255],
    "orange": [243, 163, 124, 255],
    "brown": [216, 176, 107, 255],
}

RGBA = [x / 255.0 for x in color_dict['blue']]
meshColor = bt.colorObj(RGBA, 0.5, 1.0, 1.0, 0.0, 2.0)

# copy the object list first: deleting lights while iterating over
# bpy.context.scene.objects directly can skip entries
for obj in list(bpy.context.scene.objects):
    if obj.type == 'MESH':
        if 'mesh' in obj.name:
            mesh_obj = obj
            bt.setMat_plastic(mesh_obj, meshColor)
        elif obj.name == 'Plane':
            plane_obj = obj
            plane_obj.hide_render = True
        # hide the outer sphere
        elif obj.name == "shape_8":
            sphere_obj = obj
            sphere_obj.hide_render = True
    elif obj.type == 'LIGHT':
        obj.select_set(True)
        bpy.context.view_layer.objects.active = obj
        bpy.ops.object.delete()


# ----- Camera attributes -----
camLocation = (1, -15, 2)
rotation_euler = (90, 0, 0)
cam = bt.setCamera_from_UI(camLocation, rotation_euler, focalLength=45)

# ----- Light attributes -----
lightAngle = (6, -30, -155)
strength = 2
shadowSoftness = 0.3
sun = bt.setLight_sun(lightAngle, strength, shadowSoftness)
bt.setLight_ambient(color=(0.1, 0.1, 0.1, 1))
bt.shadowThreshold(alphaThreshold=0.05, interpolationMode='CARDINAL')

outputFolder = cwd / 'output' / 'usd_anim'
outputFolder.mkdir(parents=True, exist_ok=True)
outputPath = outputFolder / 'frame'
bpy.ops.wm.save_mainfile(filepath=(outputFolder / 'usd_anim.blend').as_posix())

duration = bpy.context.scene.frame_end
bpy.context.scene.render.image_settings.file_format = 'PNG'
bt.renderAnimation(outputPath.as_posix(), cam, duration)

os.system("ffmpeg -r 48 -i 'output/usd_anim/frame%*.png' -c:v libx264 -r 48 -pix_fmt yuv420p output/usd_anim.mp4")
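# Example invocations (paths are illustrative; any USD produced by eval.py
# works). Because --usd-path is declared with nargs='*', several USD files
# can be imported into the same Blender scene and rendered together:
#
#   python render_usd.py --usd-path \
#       experiments/log/robotis_2_hard_ball/open_manipulator/open_manipulator_joint2_only_v2/test/test_ckpt_idx_0008.usd \
#       experiments/log/robotis_2_hard_ball/open_manipulator/open_manipulator_joint2_only_v2/dataset/eval.usd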
--------------------------------------------------------------------------------
/assets/open_manipulator/open_manipulator_joint2_only_v2.urdf:
--------------------------------------------------------------------------------
(XML content stripped in this dump; see the repository for the full URDF.)
--------------------------------------------------------------------------------
/diffsim.py:
--------------------------------------------------------------------------------
import torch
import os
import math
import matplotlib.pyplot as plt
import pathlib
import numpy as np
import warp as wp
import warp.sim
import warp.sim.render
wp.init()

@wp.kernel
def damp_particle_velocity(
    particle_qd: wp.array(dtype=wp.vec3f),
):
    tid = wp.tid()
    particle_qd[tid] = 0.98 * particle_qd[tid]


class ForwardKinematics(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        body_mass,
        model,
        states,
        sim_dt,
        sim_steps,
        controls,
        update_mass_matrix_every,
        is_colliding,
        particle_damping=False,
    ):

        ctx.tape = wp.Tape()
        ctx.model = model
        # NOTE: update mass (torch -> warp)
        ctx.model.body_mass = wp.from_torch(body_mass)
        ctx.states = states
        ctx.sim_dt = sim_dt
        ctx.sim_steps = sim_steps
        ctx.controls = controls

        with ctx.tape:
            ctx.integrator = wp.sim.FeatherstoneIntegrator(
                ctx.model, update_mass_matrix_every=update_mass_matrix_every
            )

            for i in range(ctx.sim_steps):
                ctx.states[i].clear_forces()
                if is_colliding:
                    wp.sim.collide(ctx.model, ctx.states[i])
                ctx.integrator.simulate(
                    ctx.model,
                    ctx.states[i],
                    ctx.states[i + 1],
                    ctx.sim_dt,
                    ctx.controls[i],
                )

                if particle_damping:
                    wp.launch(
                        kernel=damp_particle_velocity,
                        dim=len(ctx.states[i + 1].particle_qd),
                        inputs=[ctx.states[i + 1].particle_qd],
                        device=ctx.model.device,
                    )

        # NOTE: collect computed joint positions
        joint_q_list = []
        for i in range(ctx.sim_steps):
            joint_q_list.append(wp.to_torch(ctx.states[i].joint_q))
        return tuple(joint_q_list)

    @staticmethod
    def backward(ctx, *adj_joint_q_list):
        for i in range(ctx.sim_steps):
            ctx.states[i].joint_q.grad = wp.from_torch(adj_joint_q_list[i])

        ctx.tape.backward()

        # return adjoint w.r.t. inputs
        return (
            wp.to_torch(ctx.tape.gradients[ctx.model.body_mass]),
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )
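# The class above is the Warp <-> Torch bridge: `forward` records every Warp
# op on a `wp.Tape` and returns the per-step `joint_q` arrays to Torch as
# zero-copy tensors; `backward` seeds the Warp-side `joint_q` gradients with
# the incoming Torch adjoints, replays the tape in reverse, and returns one
# adjoint per `forward` argument (only `body_mass` is differentiated, so the
# remaining arguments receive `None`).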
class Sim:
    def __init__(
        self, cfg, traj_list, device=None, verbose=False, mass_diff=None, mode=None
    ):
        self.cfg = cfg
        self.traj_name = cfg.trajectory
        self.train_rate = cfg.training.train_rate
        self.save_dir = pathlib.Path(__file__).parent / "experiments" / "log" / self.traj_name / cfg.urdf / mode
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.save_file_path = self.save_dir / "sim_data.pt"
        if mode == "train" or mode == "test":
            load_dir = pathlib.Path(__file__).parent / "experiments" / "log" / self.traj_name / cfg.urdf / "dataset"
            self.load_file_path = load_dir / "sim_data.pt"
            if mode == "train":
                self.save_stats_path = self.save_dir / "training_stats.pt"
                stage = self.save_dir / "train.usd"
            elif mode == "test":
                self.save_stats_path = self.save_dir / "testing_stats.pt"
                stage = self.save_dir / f"test_ckpt_idx_{cfg.ckpt_idx:04d}.usd"
        else:
            self.load_file_path = self.save_file_path
            stage = self.save_dir / "eval.usd"
        stage = stage.as_posix()
        if cfg.training.load_file_path_overwrite:
            self.load_file_path = cfg.training.load_file_path_overwrite

        self.mode = mode
        self.verbose = verbose
        self.losses = []
        self.masses = []

        articulation_builder = wp.sim.ModelBuilder(gravity=cfg.sim.gravity)

        if hasattr(cfg, "urdf_overwrite"):
            cfg.urdf = cfg.urdf_overwrite

        urdf_path = pathlib.Path(__file__).parent / "assets" / f"{cfg.urdf}.urdf"

        # Import robots unless cfg.sim says not to
        if not hasattr(cfg.sim, "import_robot") or cfg.sim.import_robot:
            wp.sim.parse_urdf(
                urdf_path,
                articulation_builder,
                xform=wp.transform(
                    (0.0, 0.0, 0.0),
                    wp.quat_from_axis_angle(wp.vec3(1.0, 0.0, 0.0), -math.pi * 0.5),
                ),
                floating=False,
                density=cfg.sim.density,
                armature=cfg.sim.armature,
                stiffness=0.0,
                damping=0.0,
                limit_ke=1.0e4,
                limit_kd=1.0e1,
                enable_self_collisions=False,
                parse_visuals_as_colliders=cfg.sim.parse_visuals_as_colliders,
                collapse_fixed_joints=cfg.sim.collapse_fixed_joints,
                ignore_inertial_definitions=cfg.sim.ignore_inertial_definitions,
            )

        builder = wp.sim.ModelBuilder()

        self.sim_time = 0.0
        self.frame_dt = cfg.sim.frame_dt

        episode_duration = cfg.sim.episode_duration  # seconds
        self.episode_frames = int(episode_duration / self.frame_dt)

        self.sim_substeps = cfg.sim.sim_substeps
        self.sim_dt = self.frame_dt / self.sim_substeps

        self.num_envs = traj_list["num_env"]
        num_per_dim = int(math.sqrt(self.num_envs))

        self.control_func = traj_list["control"]

        articulation_builder = modify_builder_with_object(
            articulation_builder, cfg.sim.modify_object_type, cfg
        )

        # tile one articulation per environment on a square grid, 2.0 units apart
        for id in range(self.num_envs):
            i = int(id / num_per_dim)
            j = id % num_per_dim
            articulation_builder.joint_q = traj_list["q"][id]
            builder.add_builder(
                articulation_builder,
                xform=wp.transform(
                    np.array(((i) * 2.0, (j) * 2.0, 0.0)), wp.quat_identity()
                ),
            )
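        # `initialization_filename` (assets/dyn_0000.pt by default) stores a
        # recorded joint configuration; modify_builder_with_joint_data below
        # bakes those angles into the joints' parent transforms so the arm
        # starts from the recorded pose.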
        if cfg.sim.initialization_filename:
            builder = modify_builder_with_joint_data(
                cfg.sim.initialization_filename, builder
            )

        # finalize model
        # use `requires_grad=True` to create a model for differentiable simulation
        self.model = builder.finalize(device, requires_grad=cfg.sim.requires_grad)
        self.model.ground = False

        self.torch_device = wp.device_to_torch(self.model.device)
        self.renderer = wp.sim.render.SimRenderer(
            path=stage, model=self.model, scaling=15.0
        )

        self.render_time = 0.0
        self.joint_q_list = None

        # optimization variable
        self.body_mass = wp.to_torch(self.model.body_mass, requires_grad=False).clone()
        self.body_mass_single = self.body_mass[0 : len(articulation_builder.body_mass)]

        if cfg.sim.mass_diff:
            mass_diff = cfg.sim.mass_diff

        # noise
        if mass_diff is not None:
            mass_diff = torch.tensor(mass_diff, device=self.torch_device)
            self.one_indices = [
                index for index, value in enumerate(mass_diff) if value == 1
            ]
            self.non_one_indices = [
                index for index, value in enumerate(mass_diff) if value != 1
            ]
            print("ground truth: ", self.body_mass_single[self.non_one_indices])
            self.body_mass_single *= mass_diff
            print("after noise : ", self.body_mass_single[self.non_one_indices])
        else:
            self.one_indices = []
            self.non_one_indices = [
                index for index, _ in enumerate(self.body_mass_single)
            ]

        # override variable with values read from ckpt
        if cfg.ckpt:
            ckpt_path = pathlib.Path(__file__).parent / cfg.ckpt
            ckpt = torch.load(ckpt_path, weights_only=False)
            masses = ckpt["masses"]
            self.body_mass_single = (
                masses[cfg.ckpt_idx].detach().clone()
            )  # defaults to the second-to-last iteration (ckpt_idx=-2)
            print("masses: ", self.body_mass_single)

        self.body_mass_single.requires_grad_()

        self.optimizer = torch.optim.Adam([self.body_mass_single], lr=self.train_rate)
        self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=self.optimizer, T_max=cfg.training.train_iters
        )
        self.criterion = torch.nn.MSELoss()

        self.compare_indices_predictions = slice(
            cfg.training.compare_indices_predictions[0],
            cfg.training.compare_indices_predictions[1],
            cfg.training.compare_indices_predictions[2],
        )

        self.compare_indices_targets = slice(
            cfg.training.compare_indices_targets[0],
            cfg.training.compare_indices_targets[1],
            cfg.training.compare_indices_targets[2],
        )

        self.sim_steps = self.episode_frames * self.sim_substeps
        self.update_mass_matrix_every = (
            self.sim_steps
            if cfg.sim.update_mass_matrix_every == -1
            else cfg.sim.update_mass_matrix_every
        )
        self.is_colliding = cfg.sim.is_colliding
        self.particle_damping = (
            True
            if hasattr(cfg.sim, "particle_damping") and cfg.sim.particle_damping
            else False
        )
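    # `body_mass_single` holds one copy of the per-body masses; `forward`
    # tiles it across all environments and refreshes the model's inverse
    # masses outside the tape before rolling out the episode.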
    def forward(self):
        # update all the states
        self.body_mass_all = self.body_mass_single.repeat(self.num_envs)
        with torch.no_grad():
            self.model.body_inv_mass = wp.from_torch(1.0 / self.body_mass_all)

        # allocate sim states for trajectory
        self.states = []
        for i in range(self.sim_steps + 1):
            self.states.append(self.model.state())

        self.controls = []
        for i in range(self.sim_steps):
            self.controls.append(self.model.control())
            self.controls[i].joint_act = self.control_func(i)

        self.joint_q_list = ForwardKinematics.apply(
            self.body_mass_all,
            self.model,
            self.states,
            self.sim_dt,
            self.sim_steps,
            self.controls,
            self.update_mass_matrix_every,
            self.is_colliding,
            self.particle_damping,
        )

    def compute_loss(self):
        predictions = torch.cat(self.joint_q_list)
        targets = torch.cat(self.load_data["joint_q_list"]).to(predictions.device)
        self.loss = self.criterion(
            predictions[self.compare_indices_predictions],
            targets[self.compare_indices_targets],
        )
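    # `predictions` and `targets` are flattened joint-angle histories, so the
    # two slices align simulated samples with observed ones: with the
    # hard_ball config, [-1, null, 1] keeps only the final simulated joint
    # angle and [35, 36, 1] picks the matching entry (index 35) of the
    # recorded trajectory.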
    def step(self):
        def closure():
            self.forward()
            self.compute_loss()
            self.loss.backward()

            self.body_mass_single.grad[self.one_indices] = 0  # fixed object

            self.msg = "loss: {loss}, loss grad: {loss_grad}, masses: {masses}".format(
                loss=self.loss.item(),
                loss_grad=self.body_mass_single.grad[self.non_one_indices],
                masses=self.body_mass_single[self.non_one_indices],
            )

            # Append the info to the list
            self.losses.append(self.loss.item())
            self.masses.append(self.body_mass_single.clone())

            return self.loss.item()  # Return loss value

        # Perform optimization step
        loss = self.optimizer.step(closure)
        self.optimizer.zero_grad()
        self.lr_scheduler.step()

    def render(self):
        if self.renderer is None:
            return
        frame_count = 0
        print("render begin.")
        for i in range(0, self.sim_steps):
            if i % (self.sim_substeps * self.cfg.render.every_n_frame) == 0:
                self.renderer.begin_frame(self.render_time)

                self.renderer.render(self.states[i])

                self.renderer.end_frame()
                self.render_time += self.frame_dt

                cfg = dict(zip(self.model.joint_name, self.states[i].joint_q.numpy()))
                print(self.states[i].joint_q)
                if self.cfg.output_obj:
                    # expects an externally attached OBJ writer when enabled
                    self.output_obj.output(cfg, frame_count)
                frame_count += 1

        frame_idx = 0
        frame_time = 0.0
        for i in range(0, self.sim_steps):
            if i % (self.sim_substeps) == 0:
                frame_idx += 1
                frame_time += self.frame_dt
                name = f"sim_{frame_idx:04d}.pt"
                save_data = {
                    "time": frame_time,
                    "joint_q": wp.to_torch(self.states[i].joint_q),
                }
                save_dir = self.save_dir / "sim_data"
                save_dir.mkdir(parents=True, exist_ok=True)
                torch.save(save_data, save_dir / name)
        print("render finish.")

    def save_state(self, save=True):
        with torch.no_grad():
            self.save_data = {
                "joint_q_list": self.joint_q_list,
            }
            if save:
                torch.save(self.save_data, self.save_file_path)

    def save_training(self):
        # Find the index of the smallest loss
        min_loss_index = self.losses.index(min(self.losses))
        # Get the smallest loss and its corresponding mass
        smallest_loss = self.losses[min_loss_index]
        corresponding_mass = self.masses[min_loss_index]

        print(
            "min_loss_index, smallest_loss, corresponding_mass: ",
            min_loss_index,
            smallest_loss,
            corresponding_mass,
        )

        with torch.no_grad():
            self.save_stats = {
                "losses": self.losses,
                "masses": self.masses,
            }
            torch.save(self.save_stats, self.save_stats_path)
        print("saved training: ", self.save_stats_path)

    def save_testing(self):
        with torch.no_grad():
            self.save_stats = {
                "loss": self.loss,
            }
        torch.save(self.save_stats, self.save_stats_path)
        print("[loss: {loss}]".format(loss=self.loss.item()))
        print("saved testing: ", self.save_stats_path)

    def load_state(self):
        file = pathlib.Path(self.load_file_path)
        if file.is_absolute():
            file = file.as_posix()
        else:
            file = pathlib.Path(__file__).parent / file
        self.load_data = torch.load(file, weights_only=False)

    def plot_loss(self):
        plt.figure(figsize=(6, 4))
        plt.plot(
            self.losses,
            linestyle="--",
            label="Difference between simulation and observation",
        )
        plt.xlabel("Iteration")
        plt.ylabel("Loss")
        plt.legend()
        plt.savefig(
            self.save_dir / f"loss_vs_iter_{self.traj_name}.png",
            dpi=300,
            bbox_inches="tight",
        )

def generate_traj(traj_config):
    if traj_config == 'robotis_2_hard_ball':
        traj_list = {}
        num_env = 1
        q_list = [[0.0]]
        act_list_flat = [-0.49]
        def control_func(step):
            # constant actuation command on the single actuated joint
            act = wp.array(act_list_flat, dtype=float, requires_grad=True)
            return act
        traj_list['num_env'] = num_env
        traj_list['q'] = q_list
        traj_list['control'] = control_func

        return traj_list
    else:
        raise ValueError(f"Invalid traj_config: {traj_config}")

def generate_mass_diff(mass_diff_config):
    if mass_diff_config == 'cup_diff_none':
        # only the last body (the attached object) deviates from 1, so it is
        # the sole mass the optimizer updates
        mass_diff = [1, 1, 1, 1, 1, 1, 1, 1, 1.0001]
        return mass_diff
    else:
        raise ValueError(f'Invalid mass_diff_config: {mass_diff_config}.')

def modify_builder_with_joint_data(file, builder):
    file = pathlib.Path(file)
    if file.is_absolute():
        file = file.as_posix()
    else:
        file = pathlib.Path(__file__).parent / file
    # Load data from the file
    data = torch.load(file, weights_only=False)
    joint_q = data["joint_q"]

    # Apply offset to joint_q[2]
    offset = -np.radians(90.0)
    joint_q[2] += offset

    # Limit joint_q to the first four joints
    joint_q = joint_q[:4]

    # Modify builder's joint_X_p based on joint names
    for idx, transform in enumerate(builder.joint_X_p):
        if builder.joint_name[idx] == "joint1":
            joint_axis = wp.vec3(0.0, 0.0, 1.0)
            angle = joint_q[0]
        elif builder.joint_name[idx] == "joint2":
            joint_axis = wp.vec3(0.0, 1.0, 0.0)
            angle = joint_q[1]
        elif builder.joint_name[idx] == "joint3":
            joint_axis = wp.vec3(0.0, 1.0, 0.0)
            angle = joint_q[2]
        elif builder.joint_name[idx] == "joint4":
            joint_axis = wp.vec3(0.0, 1.0, 0.0)
            angle = joint_q[3]
        else:
            joint_axis = None
            angle = None

        # If joint_axis is valid, update the transform in the builder
        if joint_axis is not None:
            rot = wp.quat_from_axis_angle(joint_axis, float(angle))
            builder.joint_X_p[idx] = wp.transform(
                wp.transform_get_translation(transform), rot
            )

    return builder

def modify_builder_with_object(builder, modify_object_type, cfg):
    if modify_object_type == "hard_ball":
        # attach a rigid ball to the last robot body via a fixed joint
        b = builder.add_body()

        builder.add_shape_sphere(
            body=b, radius=0.0225, density=10, has_shape_collision=False
        )

        builder.add_joint_fixed(
            parent=b - 1,
            child=b,
        )

        return builder
    elif modify_object_type is None:
        return builder
    else:
        raise ValueError(f"Invalid modify_object_type: {modify_object_type}.")
--------------------------------------------------------------------------------