├── .gitignore ├── LICENSE ├── README.md ├── assets ├── PickCube-v1_eval_return_4090.png └── PickCube-v1_eval_success_once_4090.png ├── records ├── 10312024_cudagraphs │ ├── ppo.py │ └── script.sh └── baseline │ ├── ppo.py │ └── script.sh ├── rl_robotics_speedrun_colab.ipynb └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | runs/ 2 | .vscode/ 3 | wandb/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Stone Tao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RL Robotics Speedrun 2 | 3 | Speed-running solving robot manipulation tasks in [ManiSkill](https://github.com/haosulab/ManiSkill). The goal is simply to solve a list of tasks as fast as possible with RL + fixed dense rewards, starting from scratch. 4 | 5 | 6 | Inspired by the [great speedrunning work done for LLMs by Jordan et al.](https://github.com/KellerJordan/modded-nanogpt). 7 | 8 | ## Getting Started 9 | 10 | You can run the speedrun/benchmark via your local machine or [Google Colab](https://colab.research.google.com/github/StoneT2000/rl-robotics-speedrun/blob/main/rl_robotics_speedrun_colab.ipynb). To set up locally, we recommend using conda/mamba to manage dependencies: 11 | 12 | ```bash 13 | conda create -n rl-robotics-speedrun python=3.11 14 | conda activate rl-robotics-speedrun 15 | git clone https://github.com/StoneT2000/rl-robotics-speedrun 16 | cd rl-robotics-speedrun 17 | pip install -e . 18 | ``` 19 | 20 | ## Benchmarking 21 | 22 | To run the benchmark, `cd` into one of the folders under `records/` and run its `script.sh` file. 23 | 24 | This by default logs results to tensorboard and wandb, so you will need to set up a [wandb account](https://wandb.ai/). Wandb helps display aggregated results more clearly. 25 | 26 | ```bash 27 | cd records/baseline && bash script.sh 28 | ``` 29 | 30 | The current standard is to run PPO initialized with random weights on 10 different seeds. By default this does not record evaluation videos; remove `--no-capture_video` to record videos, and remove `--track` to not use wandb. Each of the 10 runs takes about 2 minutes on a 4090.
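For reference, a single seed can also be launched directly without wandb tracking. The command below is only a minimal sketch using the flags from `records/10312024_cudagraphs/script.sh`; you may need to lower `--num_envs` on GPUs with less memory than a 4090.

```bash
# Minimal single-seed run (sketch): flags taken from records/10312024_cudagraphs/script.sh,
# with --track removed so results only go to tensorboard under runs/.
# Drop --no-capture_video if you want evaluation videos saved.
cd records/10312024_cudagraphs
python ppo.py --env_id="PickCube-v1" --seed=9351 --total_timesteps=4_000_000 --eval-freq=10 \
    --num_envs=4096 --update_epochs=12 --num_minibatches=32 --num-steps=4 \
    --cudagraphs --no-capture_video
```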
31 | 32 | The current best is `records/10312024_cudagraphs`, which is standard PPO + GPU simulation with very few steps per environment during rollouts and CUDA graphs enabled, based on [LeanRL](https://github.com/pytorch-labs/LeanRL/). It achieves a >= 95% success rate after ~80s. 33 | 34 | 35 | | Environment | Best Time | Wandb Results | 36 | |------------|-----------|---------------| 37 | | PickCube-v1 (state) | 80s | [Link](https://wandb.ai/stonet2000/PPO-ManiSkill-GPU-SpeedRun/workspace?nw=qgul0t4vstq) | 38 | | PickCube-v1 (visual) | ~ | ~ | 39 | | PushT-v1 (state) | ~ | ~ | 40 | | PushT-v1 (visual) | ~ | ~ | 41 | 42 | All results on 4090: [Wandb Link](https://wandb.ai/stonet2000/PPO-ManiSkill-GPU-SpeedRun/workspace?nw=qgul0t4vstq) 43 | 44 | All results on L4 (Google Colab GPU): [Wandb Link](https://wandb.ai/stonet2000/PPO-ManiSkill-GPU-SpeedRun/workspace?nw=i9kpqaqywjd) 45 | 46 | Figures of results on a 4090 for PickCube-v1 across 10 seeds. The shaded area is the standard error: 47 | ![](./assets/PickCube-v1_eval_return_4090.png) 48 | ![](./assets/PickCube-v1_eval_success_once_4090.png) -------------------------------------------------------------------------------- /assets/PickCube-v1_eval_return_4090.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StoneT2000/rl-robotics-speedrun/5a9cf05ec0ecec2cf39f4469200e2446b3cd5a6f/assets/PickCube-v1_eval_return_4090.png -------------------------------------------------------------------------------- /assets/PickCube-v1_eval_success_once_4090.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StoneT2000/rl-robotics-speedrun/5a9cf05ec0ecec2cf39f4469200e2446b3cd5a6f/assets/PickCube-v1_eval_success_once_4090.png -------------------------------------------------------------------------------- /records/10312024_cudagraphs/ppo.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from mani_skill.utils import gym_utils 4 | from mani_skill.utils.wrappers.flatten import FlattenActionSpaceWrapper 5 | from mani_skill.utils.wrappers.record import RecordEpisode 6 | from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv 7 | 8 | os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "1" 9 | 10 | import math 11 | import os 12 | import random 13 | import time 14 | from collections import defaultdict, deque 15 | from dataclasses import dataclass 16 | from typing import Optional, Tuple 17 | 18 | import gymnasium as gym 19 | import numpy as np 20 | import tensordict 21 | import torch 22 | import torch.nn as nn 23 | import torch.optim as optim 24 | import tqdm 25 | import tyro 26 | from torch.utils.tensorboard import SummaryWriter 27 | import wandb 28 | from tensordict import from_module 29 | from tensordict.nn import CudaGraphModule 30 | from torch.distributions.normal import Normal 31 | 32 | 33 | @dataclass 34 | class Args: 35 | exp_name: Optional[str] = None 36 | """the name of this experiment""" 37 | seed: int = 1 38 | """seed of the experiment""" 39 | torch_deterministic: bool = True 40 | """if toggled, `torch.backends.cudnn.deterministic=False`""" 41 | cuda: bool = True 42 | """if toggled, cuda will be enabled by default""" 43 | track: bool = False 44 | """if toggled, this experiment will be tracked with Weights and Biases""" 45 | wandb_project_name: str = "ManiSkill" 46 | """the wandb's project name""" 47 | wandb_entity: Optional[str] = None 48 | """the entity (team) of wandb's 
project""" 49 | wandb_group: str = "PPO" 50 | """the group of the run for wandb""" 51 | capture_video: bool = True 52 | """whether to capture videos of the agent performances (check out `videos` folder)""" 53 | save_model: bool = False 54 | """whether to save model into the `runs/{run_name}` folder""" 55 | evaluate: bool = False 56 | """if toggled, only runs evaluation with the given model checkpoint and saves the evaluation trajectories""" 57 | checkpoint: Optional[str] = None 58 | """path to a pretrained checkpoint file to start evaluation/training from""" 59 | 60 | # Environment specific arguments 61 | env_id: str = "PickCube-v1" 62 | """the id of the environment""" 63 | env_vectorization: str = "gpu" 64 | """the type of environment vectorization to use""" 65 | num_envs: int = 512 66 | """the number of parallel environments""" 67 | num_eval_envs: int = 16 68 | """the number of parallel evaluation environments""" 69 | partial_reset: bool = True 70 | """whether to let parallel environments reset upon termination instead of truncation""" 71 | num_steps: int = 50 72 | """the number of steps to run in each environment per policy rollout""" 73 | num_eval_steps: int = 50 74 | """the number of steps to run in each evaluation environment during evaluation""" 75 | reconfiguration_freq: Optional[int] = None 76 | """how often to reconfigure the environment during training""" 77 | eval_reconfiguration_freq: Optional[int] = 1 78 | """for benchmarking purposes we want to reconfigure the eval environment each reset to ensure objects are randomized in some tasks""" 79 | eval_freq: int = 25 80 | """evaluation frequency in terms of iterations""" 81 | save_train_video_freq: Optional[int] = None 82 | """frequency to save training videos in terms of iterations""" 83 | 84 | # Algorithm specific arguments 85 | total_timesteps: int = 10000000 86 | """total timesteps of the experiments""" 87 | learning_rate: float = 3e-4 88 | """the learning rate of the optimizer""" 89 | anneal_lr: bool = False 90 | """Toggle learning rate annealing for policy and value networks""" 91 | gamma: float = 0.8 92 | """the discount factor gamma""" 93 | gae_lambda: float = 0.9 94 | """the lambda for the general advantage estimation""" 95 | num_minibatches: int = 32 96 | """the number of mini-batches""" 97 | update_epochs: int = 4 98 | """the K epochs to update the policy""" 99 | norm_adv: bool = True 100 | """Toggles advantages normalization""" 101 | clip_coef: float = 0.2 102 | """the surrogate clipping coefficient""" 103 | clip_vloss: bool = False 104 | """Toggles whether or not to use a clipped loss for the value function, as per the paper.""" 105 | ent_coef: float = 0.0 106 | """coefficient of the entropy""" 107 | vf_coef: float = 0.5 108 | """coefficient of the value function""" 109 | max_grad_norm: float = 0.5 110 | """the maximum norm for the gradient clipping""" 111 | target_kl: float = 0.1 112 | """the target KL divergence threshold""" 113 | reward_scale: float = 1.0 114 | """Scale the reward by this factor""" 115 | finite_horizon_gae: bool = False 116 | 117 | # to be filled in runtime 118 | batch_size: int = 0 119 | """the batch size (computed in runtime)""" 120 | minibatch_size: int = 0 121 | """the mini-batch size (computed in runtime)""" 122 | num_iterations: int = 0 123 | """the number of iterations (computed in runtime)""" 124 | 125 | # Torch optimizations 126 | compile: bool = False 127 | """whether to use torch.compile.""" 128 | cudagraphs: bool = False 129 | """whether to use cudagraphs on top of compile.""" 130 | 131 
| def layer_init(layer, std=np.sqrt(2), bias_const=0.0): 132 | torch.nn.init.orthogonal_(layer.weight, std) 133 | torch.nn.init.constant_(layer.bias, bias_const) 134 | return layer 135 | 136 | 137 | class Agent(nn.Module): 138 | def __init__(self, n_obs, n_act, device=None): 139 | super().__init__() 140 | self.critic = nn.Sequential( 141 | layer_init(nn.Linear(n_obs, 256, device=device)), 142 | nn.Tanh(), 143 | layer_init(nn.Linear(256, 256, device=device)), 144 | nn.Tanh(), 145 | layer_init(nn.Linear(256, 256, device=device)), 146 | nn.Tanh(), 147 | layer_init(nn.Linear(256, 1, device=device)), 148 | ) 149 | self.actor_mean = nn.Sequential( 150 | layer_init(nn.Linear(n_obs, 256, device=device)), 151 | nn.Tanh(), 152 | layer_init(nn.Linear(256, 256, device=device)), 153 | nn.Tanh(), 154 | layer_init(nn.Linear(256, 256, device=device)), 155 | nn.Tanh(), 156 | layer_init(nn.Linear(256, n_act, device=device), std=0.01*np.sqrt(2)), 157 | ) 158 | self.actor_logstd = nn.Parameter(torch.zeros(1, n_act, device=device)) 159 | 160 | def get_value(self, x): 161 | return self.critic(x) 162 | 163 | def get_action_and_value(self, obs, action=None): 164 | action_mean = self.actor_mean(obs) 165 | action_logstd = self.actor_logstd.expand_as(action_mean) 166 | action_std = torch.exp(action_logstd) 167 | probs = Normal(action_mean, action_std) 168 | if action is None: 169 | action = action_mean + action_std * torch.randn_like(action_mean) 170 | return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(obs) 171 | 172 | class Logger: 173 | def __init__(self, log_wandb=False, tensorboard: SummaryWriter = None) -> None: 174 | self.writer = tensorboard 175 | self.log_wandb = log_wandb 176 | def add_scalar(self, tag, scalar_value, step): 177 | if self.log_wandb: 178 | wandb.log({tag: scalar_value}, step=step) 179 | self.writer.add_scalar(tag, scalar_value, step) 180 | def close(self): 181 | self.writer.close() 182 | 183 | def gae(next_obs, next_done, container, final_values): 184 | # bootstrap value if not done 185 | next_value = get_value(next_obs).reshape(-1) 186 | lastgaelam = 0 187 | nextnonterminals = (~container["dones"]).float().unbind(0) 188 | vals = container["vals"] 189 | vals_unbind = vals.unbind(0) 190 | rewards = container["rewards"].unbind(0) 191 | 192 | advantages = [] 193 | nextnonterminal = (~next_done).float() 194 | nextvalues = next_value 195 | for t in range(args.num_steps - 1, -1, -1): 196 | cur_val = vals_unbind[t] 197 | # real_next_values = nextvalues * nextnonterminal 198 | real_next_values = nextnonterminal * nextvalues + final_values[t] # t instead of t+1 199 | delta = rewards[t] + args.gamma * real_next_values - cur_val 200 | advantages.append(delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam) 201 | lastgaelam = advantages[-1] 202 | 203 | nextnonterminal = nextnonterminals[t] 204 | nextvalues = cur_val 205 | 206 | advantages = container["advantages"] = torch.stack(list(reversed(advantages))) 207 | container["returns"] = advantages + vals 208 | return container 209 | 210 | 211 | def rollout(obs, done): 212 | ts = [] 213 | final_values = torch.zeros((args.num_steps, args.num_envs), device=device) 214 | for step in range(args.num_steps): 215 | # ALGO LOGIC: action logic 216 | action, logprob, _, value = policy(obs=obs) 217 | 218 | # TRY NOT TO MODIFY: execute the game and log data. 
219 | next_obs, reward, next_done, infos = step_func(action) 220 | 221 | if "final_info" in infos: 222 | final_info = infos["final_info"] 223 | done_mask = infos["_final_info"] 224 | for k, v in final_info["episode"].items(): 225 | logger.add_scalar(f"train/{k}", v[done_mask].float().mean(), global_step) 226 | with torch.no_grad(): 227 | final_values[step, torch.arange(args.num_envs, device=device)[done_mask]] = agent.get_value(infos["final_observation"][done_mask]).view(-1) 228 | 229 | ts.append( 230 | tensordict.TensorDict._new_unsafe( 231 | obs=obs, 232 | # cleanrl ppo examples associate the done with the previous obs (not the done resulting from action) 233 | dones=done, 234 | vals=value.flatten(), 235 | actions=action, 236 | logprobs=logprob, 237 | rewards=reward, 238 | batch_size=(args.num_envs,), 239 | ) 240 | ) 241 | # NOTE (stao): change here for gpu env 242 | obs = next_obs = next_obs 243 | done = next_done 244 | # NOTE (stao): need to do .to(device) i think? otherwise container.device is None, not sure if this affects anything 245 | container = torch.stack(ts, 0).to(device) 246 | return next_obs, done, container, final_values 247 | 248 | 249 | def update(obs, actions, logprobs, advantages, returns, vals): 250 | optimizer.zero_grad() 251 | _, newlogprob, entropy, newvalue = agent.get_action_and_value(obs, actions) 252 | logratio = newlogprob - logprobs 253 | ratio = logratio.exp() 254 | 255 | with torch.no_grad(): 256 | # calculate approx_kl http://joschu.net/blog/kl-approx.html 257 | old_approx_kl = (-logratio).mean() 258 | approx_kl = ((ratio - 1) - logratio).mean() 259 | clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean() 260 | 261 | if args.norm_adv: 262 | advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) 263 | 264 | # Policy loss 265 | pg_loss1 = -advantages * ratio 266 | pg_loss2 = -advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) 267 | pg_loss = torch.max(pg_loss1, pg_loss2).mean() 268 | 269 | # Value loss 270 | newvalue = newvalue.view(-1) 271 | if args.clip_vloss: 272 | v_loss_unclipped = (newvalue - returns) ** 2 273 | v_clipped = vals + torch.clamp( 274 | newvalue - vals, 275 | -args.clip_coef, 276 | args.clip_coef, 277 | ) 278 | v_loss_clipped = (v_clipped - returns) ** 2 279 | v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) 280 | v_loss = 0.5 * v_loss_max.mean() 281 | else: 282 | v_loss = 0.5 * ((newvalue - returns) ** 2).mean() 283 | 284 | entropy_loss = entropy.mean() 285 | loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef 286 | 287 | loss.backward() 288 | gn = nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) 289 | optimizer.step() 290 | 291 | return approx_kl, v_loss.detach(), pg_loss.detach(), entropy_loss.detach(), old_approx_kl, clipfrac, gn 292 | 293 | 294 | update = tensordict.nn.TensorDictModule( 295 | update, 296 | in_keys=["obs", "actions", "logprobs", "advantages", "returns", "vals"], 297 | out_keys=["approx_kl", "v_loss", "pg_loss", "entropy_loss", "old_approx_kl", "clipfrac", "gn"], 298 | ) 299 | 300 | if __name__ == "__main__": 301 | args = tyro.cli(Args) 302 | 303 | batch_size = int(args.num_envs * args.num_steps) 304 | args.minibatch_size = batch_size // args.num_minibatches 305 | args.batch_size = args.num_minibatches * args.minibatch_size 306 | args.num_iterations = args.total_timesteps // args.batch_size 307 | if args.exp_name is None: 308 | args.exp_name = os.path.basename(__file__)[: -len(".py")] 309 | run_name = 
f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" 310 | else: 311 | run_name = args.exp_name 312 | 313 | # TRY NOT TO MODIFY: seeding 314 | random.seed(args.seed) 315 | np.random.seed(args.seed) 316 | torch.manual_seed(args.seed) 317 | torch.backends.cudnn.deterministic = args.torch_deterministic 318 | 319 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 320 | 321 | ####### Environment setup ####### 322 | env_kwargs = dict(obs_mode="state", control_mode="pd_joint_delta_pos", render_mode="rgb_array", sim_backend="gpu") 323 | envs = gym.make(args.env_id, num_envs=args.num_envs if not args.evaluate else 1, reconfiguration_freq=args.reconfiguration_freq, **env_kwargs) 324 | eval_envs = gym.make(args.env_id, num_envs=args.num_eval_envs, reconfiguration_freq=args.eval_reconfiguration_freq, human_render_camera_configs=dict(shader_pack="default"),**env_kwargs) 325 | if isinstance(envs.action_space, gym.spaces.Dict): 326 | envs = FlattenActionSpaceWrapper(envs) 327 | eval_envs = FlattenActionSpaceWrapper(eval_envs) 328 | if args.capture_video: 329 | eval_output_dir = f"runs/{run_name}/videos" 330 | if args.evaluate: 331 | eval_output_dir = f"{os.path.dirname(args.checkpoint)}/test_videos" 332 | print(f"Saving eval videos to {eval_output_dir}") 333 | if args.save_train_video_freq is not None: 334 | save_video_trigger = lambda x : (x // args.num_steps) % args.save_train_video_freq == 0 335 | envs = RecordEpisode(envs, output_dir=f"runs/{run_name}/train_videos", save_trajectory=False, save_video_trigger=save_video_trigger, max_steps_per_video=args.num_steps, video_fps=30) 336 | eval_envs = RecordEpisode(eval_envs, output_dir=eval_output_dir, save_trajectory=args.evaluate, trajectory_name="trajectory", max_steps_per_video=args.num_eval_steps, video_fps=30) 337 | envs = ManiSkillVectorEnv(envs, args.num_envs, ignore_terminations=not args.partial_reset, record_metrics=True) 338 | eval_envs = ManiSkillVectorEnv(eval_envs, args.num_eval_envs, ignore_terminations=True, record_metrics=True) 339 | assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" 340 | 341 | max_episode_steps = gym_utils.find_max_episode_steps_value(envs._env) 342 | logger = None 343 | if not args.evaluate: 344 | print("Running training") 345 | if args.track: 346 | import wandb 347 | config = vars(args) 348 | config["env_cfg"] = dict(**env_kwargs, num_envs=args.num_envs, env_id=args.env_id, reward_mode="normalized_dense", env_horizon=max_episode_steps, partial_reset=args.partial_reset) 349 | config["eval_env_cfg"] = dict(**env_kwargs, num_envs=args.num_eval_envs, env_id=args.env_id, reward_mode="normalized_dense", env_horizon=max_episode_steps, partial_reset=False) 350 | wandb.init( 351 | project=args.wandb_project_name, 352 | entity=args.wandb_entity, 353 | sync_tensorboard=False, 354 | config=config, 355 | name=run_name, 356 | save_code=True, 357 | group=args.wandb_group, 358 | tags=["ppo", "walltime_efficient", f"GPU:{torch.cuda.get_device_name()}"] 359 | ) 360 | writer = SummaryWriter(f"runs/{run_name}") 361 | writer.add_text( 362 | "hyperparameters", 363 | "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), 364 | ) 365 | logger = Logger(log_wandb=args.track, tensorboard=writer) 366 | else: 367 | print("Running evaluation") 368 | n_act = math.prod(envs.single_action_space.shape) 369 | n_obs = math.prod(envs.single_observation_space.shape) 370 | assert 
isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" 371 | 372 | # Register step as a special op not to graph break 373 | # @torch.library.custom_op("mylib::step", mutates_args=()) 374 | def step_func(action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 375 | # NOTE (stao): change here for gpu env 376 | next_obs, reward, terminations, truncations, info = envs.step(action) 377 | next_done = torch.logical_or(terminations, truncations) 378 | return next_obs, reward, next_done, info 379 | 380 | ####### Agent ####### 381 | agent = Agent(n_obs, n_act, device=device) 382 | if args.checkpoint: 383 | agent.load_state_dict(torch.load(args.checkpoint)) 384 | # Make a version of agent with detached params 385 | agent_inference = Agent(n_obs, n_act, device=device) 386 | agent_inference_p = from_module(agent).data 387 | agent_inference_p.to_module(agent_inference) 388 | 389 | ####### Optimizer ####### 390 | optimizer = optim.Adam( 391 | agent.parameters(), 392 | lr=torch.tensor(args.learning_rate, device=device), 393 | eps=1e-5, 394 | capturable=args.cudagraphs and not args.compile, 395 | ) 396 | 397 | ####### Executables ####### 398 | # Define networks: wrapping the policy in a TensorDictModule allows us to use CudaGraphModule 399 | policy = agent_inference.get_action_and_value 400 | get_value = agent_inference.get_value 401 | 402 | # Compile policy 403 | if args.compile: 404 | policy = torch.compile(policy) 405 | gae = torch.compile(gae, fullgraph=True) 406 | update = torch.compile(update) 407 | 408 | if args.cudagraphs: 409 | policy = CudaGraphModule(policy) 410 | gae = CudaGraphModule(gae) 411 | update = CudaGraphModule(update) 412 | 413 | global_step = 0 414 | start_time = time.time() 415 | container_local = None 416 | next_obs = envs.reset()[0] 417 | next_done = torch.zeros(args.num_envs, device=device, dtype=torch.bool) 418 | pbar = tqdm.tqdm(range(1, args.num_iterations + 1)) 419 | 420 | cumulative_times = defaultdict(float) 421 | 422 | for iteration in pbar: 423 | agent.eval() 424 | if iteration % args.eval_freq == 1: 425 | stime = time.perf_counter() 426 | eval_obs, _ = eval_envs.reset() 427 | eval_metrics = defaultdict(list) 428 | num_episodes = 0 429 | for _ in range(args.num_eval_steps): 430 | with torch.no_grad(): 431 | eval_obs, eval_rew, eval_terminations, eval_truncations, eval_infos = eval_envs.step(agent.actor_mean(eval_obs)) 432 | if "final_info" in eval_infos: 433 | mask = eval_infos["_final_info"] 434 | num_episodes += mask.sum() 435 | for k, v in eval_infos["final_info"]["episode"].items(): 436 | eval_metrics[k].append(v) 437 | eval_metrics_mean = {} 438 | for k, v in eval_metrics.items(): 439 | mean = torch.stack(v).float().mean() 440 | eval_metrics_mean[k] = mean 441 | if logger is not None: 442 | logger.add_scalar(f"eval/{k}", mean, global_step) 443 | pbar.set_description( 444 | f"success_once: {eval_metrics_mean['success_once']:.2f}, " 445 | f"return: {eval_metrics_mean['return']:.2f}" 446 | ) 447 | if logger is not None: 448 | eval_time = time.perf_counter() - stime 449 | cumulative_times["eval_time"] += eval_time 450 | logger.add_scalar("time/eval_time", eval_time, global_step) 451 | if args.evaluate: 452 | break 453 | if args.save_model and iteration % args.eval_freq == 1: 454 | model_path = f"runs/{run_name}/ckpt_{iteration}.pt" 455 | torch.save(agent.state_dict(), model_path) 456 | print(f"model saved to {model_path}") 457 | # Annealing the rate if instructed to do so. 
458 | if args.anneal_lr: 459 | frac = 1.0 - (iteration - 1.0) / args.num_iterations 460 | lrnow = frac * args.learning_rate 461 | optimizer.param_groups[0]["lr"].copy_(lrnow) 462 | 463 | torch.compiler.cudagraph_mark_step_begin() 464 | rollout_time = time.perf_counter() 465 | next_obs, next_done, container, final_values = rollout(next_obs, next_done) 466 | rollout_time = time.perf_counter() - rollout_time 467 | cumulative_times["rollout_time"] += rollout_time 468 | global_step += container.numel() 469 | 470 | update_time = time.perf_counter() 471 | container = gae(next_obs, next_done, container, final_values) 472 | container_flat = container.view(-1) 473 | 474 | # Optimizing the policy and value network 475 | clipfracs = [] 476 | for epoch in range(args.update_epochs): 477 | b_inds = torch.randperm(container_flat.shape[0], device=device).split(args.minibatch_size) 478 | for b in b_inds: 479 | container_local = container_flat[b] 480 | 481 | out = update(container_local, tensordict_out=tensordict.TensorDict()) 482 | clipfracs.append(out["clipfrac"]) 483 | if args.target_kl is not None and out["approx_kl"] > args.target_kl: 484 | break 485 | else: 486 | continue 487 | break 488 | update_time = time.perf_counter() - update_time 489 | cumulative_times["update_time"] += update_time 490 | 491 | logger.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) 492 | logger.add_scalar("losses/value_loss", out["v_loss"].item(), global_step) 493 | logger.add_scalar("losses/policy_loss", out["pg_loss"].item(), global_step) 494 | logger.add_scalar("losses/entropy", out["entropy_loss"].item(), global_step) 495 | logger.add_scalar("losses/old_approx_kl", out["old_approx_kl"].item(), global_step) 496 | logger.add_scalar("losses/approx_kl", out["approx_kl"].item(), global_step) 497 | logger.add_scalar("losses/clipfrac", torch.stack(clipfracs).mean().cpu().item(), global_step) 498 | # logger.add_scalar("losses/explained_variance", explained_var, global_step) 499 | logger.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) 500 | logger.add_scalar("time/step", global_step, global_step) 501 | logger.add_scalar("time/update_time", update_time, global_step) 502 | logger.add_scalar("time/rollout_time", rollout_time, global_step) 503 | logger.add_scalar("time/rollout_fps", args.num_envs * args.num_steps / rollout_time, global_step) 504 | for k, v in cumulative_times.items(): 505 | logger.add_scalar(f"time/total_{k}", v, global_step) 506 | logger.add_scalar("time/total_rollout+update_time", cumulative_times["rollout_time"] + cumulative_times["update_time"], global_step) 507 | 508 | envs.close() 509 | eval_envs.close() 510 | -------------------------------------------------------------------------------- /records/10312024_cudagraphs/script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | seeds=(9351 4796 1788 2371 3462 4553 5644 6735 7826 8917) 3 | num_envs=4096 4 | update_epochs=12 5 | num_steps=4 6 | num_minibatches=32 7 | exp_name="10312024_cudagraphs" 8 | for seed in ${seeds[@]} 9 | do 10 | python ppo.py --env_id="PickCube-v1" --seed=${seed} --total_timesteps=4_000_000 --eval-freq=10 \ 11 | --num_envs=${num_envs} --update_epochs=${update_epochs} --num_minibatches=${num_minibatches} \ 12 | --num-steps=${num_steps} --cudagraphs \ 13 | --exp-name="ppo-PickCube-v1-state-${seed}-${exp_name}" \ 14 | --track --wandb_project_name="PPO-ManiSkill-GPU-SpeedRun" --wandb_group=${exp_name} --no-capture_video 15 | 
done -------------------------------------------------------------------------------- /records/baseline/ppo.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from mani_skill.utils import gym_utils 4 | from mani_skill.utils.wrappers.flatten import FlattenActionSpaceWrapper 5 | from mani_skill.utils.wrappers.record import RecordEpisode 6 | from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv 7 | 8 | os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "1" 9 | 10 | import math 11 | import os 12 | import random 13 | import time 14 | from collections import defaultdict, deque 15 | from dataclasses import dataclass 16 | from typing import Optional, Tuple 17 | 18 | import gymnasium as gym 19 | import numpy as np 20 | import tensordict 21 | import torch 22 | import torch.nn as nn 23 | import torch.optim as optim 24 | import tqdm 25 | import tyro 26 | from torch.utils.tensorboard import SummaryWriter 27 | import wandb 28 | from tensordict import from_module 29 | from tensordict.nn import CudaGraphModule 30 | from torch.distributions.normal import Normal 31 | 32 | 33 | @dataclass 34 | class Args: 35 | exp_name: Optional[str] = None 36 | """the name of this experiment""" 37 | seed: int = 1 38 | """seed of the experiment""" 39 | torch_deterministic: bool = True 40 | """if toggled, `torch.backends.cudnn.deterministic=False`""" 41 | cuda: bool = True 42 | """if toggled, cuda will be enabled by default""" 43 | track: bool = False 44 | """if toggled, this experiment will be tracked with Weights and Biases""" 45 | wandb_project_name: str = "ManiSkill" 46 | """the wandb's project name""" 47 | wandb_entity: Optional[str] = None 48 | """the entity (team) of wandb's project""" 49 | wandb_group: str = "PPO" 50 | """the group of the run for wandb""" 51 | capture_video: bool = True 52 | """whether to capture videos of the agent performances (check out `videos` folder)""" 53 | save_model: bool = False 54 | """whether to save model into the `runs/{run_name}` folder""" 55 | evaluate: bool = False 56 | """if toggled, only runs evaluation with the given model checkpoint and saves the evaluation trajectories""" 57 | checkpoint: Optional[str] = None 58 | """path to a pretrained checkpoint file to start evaluation/training from""" 59 | 60 | # Environment specific arguments 61 | env_id: str = "PickCube-v1" 62 | """the id of the environment""" 63 | env_vectorization: str = "gpu" 64 | """the type of environment vectorization to use""" 65 | num_envs: int = 512 66 | """the number of parallel environments""" 67 | num_eval_envs: int = 16 68 | """the number of parallel evaluation environments""" 69 | partial_reset: bool = True 70 | """whether to let parallel environments reset upon termination instead of truncation""" 71 | num_steps: int = 50 72 | """the number of steps to run in each environment per policy rollout""" 73 | num_eval_steps: int = 50 74 | """the number of steps to run in each evaluation environment during evaluation""" 75 | reconfiguration_freq: Optional[int] = None 76 | """how often to reconfigure the environment during training""" 77 | eval_reconfiguration_freq: Optional[int] = 1 78 | """for benchmarking purposes we want to reconfigure the eval environment each reset to ensure objects are randomized in some tasks""" 79 | eval_freq: int = 25 80 | """evaluation frequency in terms of iterations""" 81 | save_train_video_freq: Optional[int] = None 82 | """frequency to save training videos in terms of iterations""" 83 | 84 | # Algorithm specific 
arguments 85 | total_timesteps: int = 10000000 86 | """total timesteps of the experiments""" 87 | learning_rate: float = 3e-4 88 | """the learning rate of the optimizer""" 89 | anneal_lr: bool = False 90 | """Toggle learning rate annealing for policy and value networks""" 91 | gamma: float = 0.8 92 | """the discount factor gamma""" 93 | gae_lambda: float = 0.9 94 | """the lambda for the general advantage estimation""" 95 | num_minibatches: int = 32 96 | """the number of mini-batches""" 97 | update_epochs: int = 4 98 | """the K epochs to update the policy""" 99 | norm_adv: bool = True 100 | """Toggles advantages normalization""" 101 | clip_coef: float = 0.2 102 | """the surrogate clipping coefficient""" 103 | clip_vloss: bool = False 104 | """Toggles whether or not to use a clipped loss for the value function, as per the paper.""" 105 | ent_coef: float = 0.0 106 | """coefficient of the entropy""" 107 | vf_coef: float = 0.5 108 | """coefficient of the value function""" 109 | max_grad_norm: float = 0.5 110 | """the maximum norm for the gradient clipping""" 111 | target_kl: float = 0.1 112 | """the target KL divergence threshold""" 113 | reward_scale: float = 1.0 114 | """Scale the reward by this factor""" 115 | finite_horizon_gae: bool = False 116 | 117 | # to be filled in runtime 118 | batch_size: int = 0 119 | """the batch size (computed in runtime)""" 120 | minibatch_size: int = 0 121 | """the mini-batch size (computed in runtime)""" 122 | num_iterations: int = 0 123 | """the number of iterations (computed in runtime)""" 124 | 125 | # Torch optimizations 126 | compile: bool = False 127 | """whether to use torch.compile.""" 128 | cudagraphs: bool = False 129 | """whether to use cudagraphs on top of compile.""" 130 | 131 | def layer_init(layer, std=np.sqrt(2), bias_const=0.0): 132 | torch.nn.init.orthogonal_(layer.weight, std) 133 | torch.nn.init.constant_(layer.bias, bias_const) 134 | return layer 135 | 136 | 137 | class Agent(nn.Module): 138 | def __init__(self, n_obs, n_act, device=None): 139 | super().__init__() 140 | self.critic = nn.Sequential( 141 | layer_init(nn.Linear(n_obs, 256, device=device)), 142 | nn.Tanh(), 143 | layer_init(nn.Linear(256, 256, device=device)), 144 | nn.Tanh(), 145 | layer_init(nn.Linear(256, 256, device=device)), 146 | nn.Tanh(), 147 | layer_init(nn.Linear(256, 1, device=device)), 148 | ) 149 | self.actor_mean = nn.Sequential( 150 | layer_init(nn.Linear(n_obs, 256, device=device)), 151 | nn.Tanh(), 152 | layer_init(nn.Linear(256, 256, device=device)), 153 | nn.Tanh(), 154 | layer_init(nn.Linear(256, 256, device=device)), 155 | nn.Tanh(), 156 | layer_init(nn.Linear(256, n_act, device=device), std=0.01*np.sqrt(2)), 157 | ) 158 | self.actor_logstd = nn.Parameter(torch.zeros(1, n_act, device=device)) 159 | 160 | def get_value(self, x): 161 | return self.critic(x) 162 | 163 | def get_action_and_value(self, obs, action=None): 164 | action_mean = self.actor_mean(obs) 165 | action_logstd = self.actor_logstd.expand_as(action_mean) 166 | action_std = torch.exp(action_logstd) 167 | probs = Normal(action_mean, action_std) 168 | if action is None: 169 | action = action_mean + action_std * torch.randn_like(action_mean) 170 | return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(obs) 171 | 172 | class Logger: 173 | def __init__(self, log_wandb=False, tensorboard: SummaryWriter = None) -> None: 174 | self.writer = tensorboard 175 | self.log_wandb = log_wandb 176 | def add_scalar(self, tag, scalar_value, step): 177 | if self.log_wandb: 178 | 
wandb.log({tag: scalar_value}, step=step) 179 | self.writer.add_scalar(tag, scalar_value, step) 180 | def close(self): 181 | self.writer.close() 182 | 183 | def gae(next_obs, next_done, container, final_values): 184 | # bootstrap value if not done 185 | next_value = get_value(next_obs).reshape(-1) 186 | lastgaelam = 0 187 | nextnonterminals = (~container["dones"]).float().unbind(0) 188 | vals = container["vals"] 189 | vals_unbind = vals.unbind(0) 190 | rewards = container["rewards"].unbind(0) 191 | 192 | advantages = [] 193 | nextnonterminal = (~next_done).float() 194 | nextvalues = next_value 195 | for t in range(args.num_steps - 1, -1, -1): 196 | cur_val = vals_unbind[t] 197 | # real_next_values = nextvalues * nextnonterminal 198 | real_next_values = nextnonterminal * nextvalues + final_values[t] # t instead of t+1 199 | delta = rewards[t] + args.gamma * real_next_values - cur_val 200 | advantages.append(delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam) 201 | lastgaelam = advantages[-1] 202 | 203 | nextnonterminal = nextnonterminals[t] 204 | nextvalues = cur_val 205 | 206 | advantages = container["advantages"] = torch.stack(list(reversed(advantages))) 207 | container["returns"] = advantages + vals 208 | return container 209 | 210 | 211 | def rollout(obs, done): 212 | ts = [] 213 | final_values = torch.zeros((args.num_steps, args.num_envs), device=device) 214 | for step in range(args.num_steps): 215 | # ALGO LOGIC: action logic 216 | action, logprob, _, value = policy(obs=obs) 217 | 218 | # TRY NOT TO MODIFY: execute the game and log data. 219 | next_obs, reward, next_done, infos = step_func(action) 220 | 221 | if "final_info" in infos: 222 | final_info = infos["final_info"] 223 | done_mask = infos["_final_info"] 224 | for k, v in final_info["episode"].items(): 225 | logger.add_scalar(f"train/{k}", v[done_mask].float().mean(), global_step) 226 | with torch.no_grad(): 227 | final_values[step, torch.arange(args.num_envs, device=device)[done_mask]] = agent.get_value(infos["final_observation"][done_mask]).view(-1) 228 | 229 | ts.append( 230 | tensordict.TensorDict._new_unsafe( 231 | obs=obs, 232 | # cleanrl ppo examples associate the done with the previous obs (not the done resulting from action) 233 | dones=done, 234 | vals=value.flatten(), 235 | actions=action, 236 | logprobs=logprob, 237 | rewards=reward, 238 | batch_size=(args.num_envs,), 239 | ) 240 | ) 241 | # NOTE (stao): change here for gpu env 242 | obs = next_obs = next_obs 243 | done = next_done 244 | # NOTE (stao): need to do .to(device) i think? 
otherwise container.device is None, not sure if this affects anything 245 | container = torch.stack(ts, 0).to(device) 246 | return next_obs, done, container, final_values 247 | 248 | 249 | def update(obs, actions, logprobs, advantages, returns, vals): 250 | optimizer.zero_grad() 251 | _, newlogprob, entropy, newvalue = agent.get_action_and_value(obs, actions) 252 | logratio = newlogprob - logprobs 253 | ratio = logratio.exp() 254 | 255 | with torch.no_grad(): 256 | # calculate approx_kl http://joschu.net/blog/kl-approx.html 257 | old_approx_kl = (-logratio).mean() 258 | approx_kl = ((ratio - 1) - logratio).mean() 259 | clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean() 260 | 261 | if args.norm_adv: 262 | advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) 263 | 264 | # Policy loss 265 | pg_loss1 = -advantages * ratio 266 | pg_loss2 = -advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) 267 | pg_loss = torch.max(pg_loss1, pg_loss2).mean() 268 | 269 | # Value loss 270 | newvalue = newvalue.view(-1) 271 | if args.clip_vloss: 272 | v_loss_unclipped = (newvalue - returns) ** 2 273 | v_clipped = vals + torch.clamp( 274 | newvalue - vals, 275 | -args.clip_coef, 276 | args.clip_coef, 277 | ) 278 | v_loss_clipped = (v_clipped - returns) ** 2 279 | v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) 280 | v_loss = 0.5 * v_loss_max.mean() 281 | else: 282 | v_loss = 0.5 * ((newvalue - returns) ** 2).mean() 283 | 284 | entropy_loss = entropy.mean() 285 | loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef 286 | 287 | loss.backward() 288 | gn = nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) 289 | optimizer.step() 290 | 291 | return approx_kl, v_loss.detach(), pg_loss.detach(), entropy_loss.detach(), old_approx_kl, clipfrac, gn 292 | 293 | 294 | update = tensordict.nn.TensorDictModule( 295 | update, 296 | in_keys=["obs", "actions", "logprobs", "advantages", "returns", "vals"], 297 | out_keys=["approx_kl", "v_loss", "pg_loss", "entropy_loss", "old_approx_kl", "clipfrac", "gn"], 298 | ) 299 | 300 | if __name__ == "__main__": 301 | args = tyro.cli(Args) 302 | 303 | batch_size = int(args.num_envs * args.num_steps) 304 | args.minibatch_size = batch_size // args.num_minibatches 305 | args.batch_size = args.num_minibatches * args.minibatch_size 306 | args.num_iterations = args.total_timesteps // args.batch_size 307 | if args.exp_name is None: 308 | args.exp_name = os.path.basename(__file__)[: -len(".py")] 309 | run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" 310 | else: 311 | run_name = args.exp_name 312 | 313 | # TRY NOT TO MODIFY: seeding 314 | random.seed(args.seed) 315 | np.random.seed(args.seed) 316 | torch.manual_seed(args.seed) 317 | torch.backends.cudnn.deterministic = args.torch_deterministic 318 | 319 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 320 | 321 | ####### Environment setup ####### 322 | env_kwargs = dict(obs_mode="state", control_mode="pd_joint_delta_pos", render_mode="rgb_array", sim_backend="gpu") 323 | envs = gym.make(args.env_id, num_envs=args.num_envs if not args.evaluate else 1, reconfiguration_freq=args.reconfiguration_freq, **env_kwargs) 324 | eval_envs = gym.make(args.env_id, num_envs=args.num_eval_envs, reconfiguration_freq=args.eval_reconfiguration_freq, human_render_camera_configs=dict(shader_pack="default"),**env_kwargs) 325 | if isinstance(envs.action_space, gym.spaces.Dict): 326 | envs = 
FlattenActionSpaceWrapper(envs) 327 | eval_envs = FlattenActionSpaceWrapper(eval_envs) 328 | if args.capture_video: 329 | eval_output_dir = f"runs/{run_name}/videos" 330 | if args.evaluate: 331 | eval_output_dir = f"{os.path.dirname(args.checkpoint)}/test_videos" 332 | print(f"Saving eval videos to {eval_output_dir}") 333 | if args.save_train_video_freq is not None: 334 | save_video_trigger = lambda x : (x // args.num_steps) % args.save_train_video_freq == 0 335 | envs = RecordEpisode(envs, output_dir=f"runs/{run_name}/train_videos", save_trajectory=False, save_video_trigger=save_video_trigger, max_steps_per_video=args.num_steps, video_fps=30) 336 | eval_envs = RecordEpisode(eval_envs, output_dir=eval_output_dir, save_trajectory=args.evaluate, trajectory_name="trajectory", max_steps_per_video=args.num_eval_steps, video_fps=30) 337 | envs = ManiSkillVectorEnv(envs, args.num_envs, ignore_terminations=not args.partial_reset, record_metrics=True) 338 | eval_envs = ManiSkillVectorEnv(eval_envs, args.num_eval_envs, ignore_terminations=True, record_metrics=True) 339 | assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" 340 | 341 | max_episode_steps = gym_utils.find_max_episode_steps_value(envs._env) 342 | logger = None 343 | if not args.evaluate: 344 | print("Running training") 345 | if args.track: 346 | import wandb 347 | config = vars(args) 348 | config["env_cfg"] = dict(**env_kwargs, num_envs=args.num_envs, env_id=args.env_id, reward_mode="normalized_dense", env_horizon=max_episode_steps, partial_reset=args.partial_reset) 349 | config["eval_env_cfg"] = dict(**env_kwargs, num_envs=args.num_eval_envs, env_id=args.env_id, reward_mode="normalized_dense", env_horizon=max_episode_steps, partial_reset=False) 350 | wandb.init( 351 | project=args.wandb_project_name, 352 | entity=args.wandb_entity, 353 | sync_tensorboard=False, 354 | config=config, 355 | name=run_name, 356 | save_code=True, 357 | group=args.wandb_group, 358 | tags=["ppo", "walltime_efficient", f"GPU:{torch.cuda.get_device_name()}"] 359 | ) 360 | writer = SummaryWriter(f"runs/{run_name}") 361 | writer.add_text( 362 | "hyperparameters", 363 | "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), 364 | ) 365 | logger = Logger(log_wandb=args.track, tensorboard=writer) 366 | else: 367 | print("Running evaluation") 368 | n_act = math.prod(envs.single_action_space.shape) 369 | n_obs = math.prod(envs.single_observation_space.shape) 370 | assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" 371 | 372 | # Register step as a special op not to graph break 373 | # @torch.library.custom_op("mylib::step", mutates_args=()) 374 | def step_func(action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 375 | # NOTE (stao): change here for gpu env 376 | next_obs, reward, terminations, truncations, info = envs.step(action) 377 | next_done = torch.logical_or(terminations, truncations) 378 | return next_obs, reward, next_done, info 379 | 380 | ####### Agent ####### 381 | agent = Agent(n_obs, n_act, device=device) 382 | if args.checkpoint: 383 | agent.load_state_dict(torch.load(args.checkpoint)) 384 | # Make a version of agent with detached params 385 | agent_inference = Agent(n_obs, n_act, device=device) 386 | agent_inference_p = from_module(agent).data 387 | agent_inference_p.to_module(agent_inference) 388 | 389 | ####### Optimizer ####### 390 | optimizer = optim.Adam( 391 | 
agent.parameters(), 392 | lr=torch.tensor(args.learning_rate, device=device), 393 | eps=1e-5, 394 | capturable=args.cudagraphs and not args.compile, 395 | ) 396 | 397 | ####### Executables ####### 398 | # Define networks: wrapping the policy in a TensorDictModule allows us to use CudaGraphModule 399 | policy = agent_inference.get_action_and_value 400 | get_value = agent_inference.get_value 401 | 402 | # Compile policy 403 | if args.compile: 404 | policy = torch.compile(policy) 405 | gae = torch.compile(gae, fullgraph=True) 406 | update = torch.compile(update) 407 | 408 | if args.cudagraphs: 409 | policy = CudaGraphModule(policy) 410 | gae = CudaGraphModule(gae) 411 | update = CudaGraphModule(update) 412 | 413 | global_step = 0 414 | start_time = time.time() 415 | container_local = None 416 | next_obs = envs.reset()[0] 417 | next_done = torch.zeros(args.num_envs, device=device, dtype=torch.bool) 418 | pbar = tqdm.tqdm(range(1, args.num_iterations + 1)) 419 | 420 | cumulative_times = defaultdict(float) 421 | 422 | for iteration in pbar: 423 | agent.eval() 424 | if iteration % args.eval_freq == 1: 425 | stime = time.perf_counter() 426 | eval_obs, _ = eval_envs.reset() 427 | eval_metrics = defaultdict(list) 428 | num_episodes = 0 429 | for _ in range(args.num_eval_steps): 430 | with torch.no_grad(): 431 | eval_obs, eval_rew, eval_terminations, eval_truncations, eval_infos = eval_envs.step(agent.actor_mean(eval_obs)) 432 | if "final_info" in eval_infos: 433 | mask = eval_infos["_final_info"] 434 | num_episodes += mask.sum() 435 | for k, v in eval_infos["final_info"]["episode"].items(): 436 | eval_metrics[k].append(v) 437 | eval_metrics_mean = {} 438 | for k, v in eval_metrics.items(): 439 | mean = torch.stack(v).float().mean() 440 | eval_metrics_mean[k] = mean 441 | if logger is not None: 442 | logger.add_scalar(f"eval/{k}", mean, global_step) 443 | pbar.set_description( 444 | f"success_once: {eval_metrics_mean['success_once']:.2f}, " 445 | f"return: {eval_metrics_mean['return']:.2f}" 446 | ) 447 | if logger is not None: 448 | eval_time = time.perf_counter() - stime 449 | cumulative_times["eval_time"] += eval_time 450 | logger.add_scalar("time/eval_time", eval_time, global_step) 451 | if args.evaluate: 452 | break 453 | if args.save_model and iteration % args.eval_freq == 1: 454 | model_path = f"runs/{run_name}/ckpt_{iteration}.pt" 455 | torch.save(agent.state_dict(), model_path) 456 | print(f"model saved to {model_path}") 457 | # Annealing the rate if instructed to do so. 
458 | if args.anneal_lr: 459 | frac = 1.0 - (iteration - 1.0) / args.num_iterations 460 | lrnow = frac * args.learning_rate 461 | optimizer.param_groups[0]["lr"].copy_(lrnow) 462 | 463 | torch.compiler.cudagraph_mark_step_begin() 464 | rollout_time = time.perf_counter() 465 | next_obs, next_done, container, final_values = rollout(next_obs, next_done) 466 | rollout_time = time.perf_counter() - rollout_time 467 | cumulative_times["rollout_time"] += rollout_time 468 | global_step += container.numel() 469 | 470 | update_time = time.perf_counter() 471 | container = gae(next_obs, next_done, container, final_values) 472 | container_flat = container.view(-1) 473 | 474 | # Optimizing the policy and value network 475 | clipfracs = [] 476 | for epoch in range(args.update_epochs): 477 | b_inds = torch.randperm(container_flat.shape[0], device=device).split(args.minibatch_size) 478 | for b in b_inds: 479 | container_local = container_flat[b] 480 | 481 | out = update(container_local, tensordict_out=tensordict.TensorDict()) 482 | clipfracs.append(out["clipfrac"]) 483 | if args.target_kl is not None and out["approx_kl"] > args.target_kl: 484 | break 485 | else: 486 | continue 487 | break 488 | update_time = time.perf_counter() - update_time 489 | cumulative_times["update_time"] += update_time 490 | 491 | logger.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) 492 | logger.add_scalar("losses/value_loss", out["v_loss"].item(), global_step) 493 | logger.add_scalar("losses/policy_loss", out["pg_loss"].item(), global_step) 494 | logger.add_scalar("losses/entropy", out["entropy_loss"].item(), global_step) 495 | logger.add_scalar("losses/old_approx_kl", out["old_approx_kl"].item(), global_step) 496 | logger.add_scalar("losses/approx_kl", out["approx_kl"].item(), global_step) 497 | logger.add_scalar("losses/clipfrac", torch.stack(clipfracs).mean().cpu().item(), global_step) 498 | # logger.add_scalar("losses/explained_variance", explained_var, global_step) 499 | logger.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) 500 | logger.add_scalar("time/step", global_step, global_step) 501 | logger.add_scalar("time/update_time", update_time, global_step) 502 | logger.add_scalar("time/rollout_time", rollout_time, global_step) 503 | logger.add_scalar("time/rollout_fps", args.num_envs * args.num_steps / rollout_time, global_step) 504 | for k, v in cumulative_times.items(): 505 | logger.add_scalar(f"time/total_{k}", v, global_step) 506 | logger.add_scalar("time/total_rollout+update_time", cumulative_times["rollout_time"] + cumulative_times["update_time"], global_step) 507 | 508 | envs.close() 509 | eval_envs.close() 510 | -------------------------------------------------------------------------------- /records/baseline/script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | seeds=(9351 4796 1788 2371 3462 4553 5644 6735 7826 8917) 3 | num_envs=4096 4 | update_epochs=12 5 | num_steps=4 6 | num_minibatches=32 7 | exp_name="baseline" 8 | for seed in ${seeds[@]} 9 | do 10 | python ppo.py --env_id="PickCube-v1" --seed=${seed} --total_timesteps=4_000_000 --eval-freq=10 \ 11 | --num_envs=${num_envs} --update_epochs=${update_epochs} --num_minibatches=${num_minibatches} \ 12 | --num-steps=${num_steps} \ 13 | --exp-name="ppo-PickCube-v1-state-${seed}-${exp_name}" \ 14 | --track --wandb_project_name="PPO-ManiSkill-GPU-SpeedRun" --wandb_group=${exp_name} --no-capture_video 15 | done 
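Aside from the experiment name, the baseline launcher above differs from `records/10312024_cudagraphs/script.sh` only in that it omits the `--cudagraphs` flag, and the two `ppo.py` files appear identical as well. A quick way to confirm this (a sketch, run from the repository root):

```bash
# Compare the two records; only exp_name and the --cudagraphs flag should differ.
diff records/baseline/script.sh records/10312024_cudagraphs/script.sh
diff records/baseline/ppo.py records/10312024_cudagraphs/ppo.py
```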
-------------------------------------------------------------------------------- /rl_robotics_speedrun_colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "collapsed_sections": [ 8 | "cj-pTw-Les0o" 9 | ], 10 | "machine_shape": "hm", 11 | "gpuType": "L4", 12 | "authorship_tag": "ABX9TyNsoFWzLLgUSaUtkeYRPI4F", 13 | "include_colab_link": true 14 | }, 15 | "kernelspec": { 16 | "name": "python3", 17 | "display_name": "Python 3" 18 | }, 19 | "language_info": { 20 | "name": "python" 21 | }, 22 | "accelerator": "GPU" 23 | }, 24 | "cells": [ 25 | { 26 | "cell_type": "markdown", 27 | "metadata": { 28 | "id": "view-in-github", 29 | "colab_type": "text" 30 | }, 31 | "source": [ 32 | "\"Open" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": { 38 | "id": "cj-pTw-Les0o" 39 | }, 40 | "source": [ 41 | "# Setup Code\n", 42 | "\n", 43 | "To begin, prepare the colab environment by switching to a GPU environment (Runtime -> Change Runtime Type)\n", 44 | "\n", 45 | "Then click the play button below. This will install all dependencies for the future code, namely ManiSkill, Torch, TorchRL, and Tensordict" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 1, 51 | "metadata": { 52 | "id": "pURa3mYwevk0", 53 | "outputId": "e4a70a8f-1d54-4b3a-848d-e987500b54c0", 54 | "colab": { 55 | "base_uri": "https://localhost:8080/" 56 | } 57 | }, 58 | "outputs": [ 59 | { 60 | "output_type": "stream", 61 | "name": "stdout", 62 | "text": [ 63 | "Reading package lists... Done\n", 64 | "Building dependency tree... Done\n", 65 | "Reading state information... Done\n", 66 | "The following additional packages will be installed:\n", 67 | " libvulkan1\n", 68 | "Recommended packages:\n", 69 | " mesa-vulkan-drivers | vulkan-icd\n", 70 | "The following NEW packages will be installed:\n", 71 | " libvulkan-dev libvulkan1\n", 72 | "0 upgraded, 2 newly installed, 0 to remove and 49 not upgraded.\n", 73 | "Need to get 1,020 kB of archives.\n", 74 | "After this operation, 17.2 MB of additional disk space will be used.\n", 75 | "Get:1 http://archive.ubuntu.com/ubuntu jammy/main amd64 libvulkan1 amd64 1.3.204.1-2 [128 kB]\n", 76 | "Get:2 http://archive.ubuntu.com/ubuntu jammy/main amd64 libvulkan-dev amd64 1.3.204.1-2 [892 kB]\n", 77 | "Fetched 1,020 kB in 1s (706 kB/s)\n", 78 | "Selecting previously unselected package libvulkan1:amd64.\n", 79 | "(Reading database ... 
123623 files and directories currently installed.)\n", 80 | "Preparing to unpack .../libvulkan1_1.3.204.1-2_amd64.deb ...\n", 81 | "Unpacking libvulkan1:amd64 (1.3.204.1-2) ...\n", 82 | "Selecting previously unselected package libvulkan-dev:amd64.\n", 83 | "Preparing to unpack .../libvulkan-dev_1.3.204.1-2_amd64.deb ...\n", 84 | "Unpacking libvulkan-dev:amd64 (1.3.204.1-2) ...\n", 85 | "Setting up libvulkan1:amd64 (1.3.204.1-2) ...\n", 86 | "Setting up libvulkan-dev:amd64 (1.3.204.1-2) ...\n", 87 | "Processing triggers for libc-bin (2.35-0ubuntu3.4) ...\n", 88 | "/sbin/ldconfig.real: /usr/local/lib/libtcm.so.1 is not a symbolic link\n", 89 | "\n", 90 | "/sbin/ldconfig.real: /usr/local/lib/libur_adapter_opencl.so.0 is not a symbolic link\n", 91 | "\n", 92 | "/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc.so.2 is not a symbolic link\n", 93 | "\n", 94 | "/sbin/ldconfig.real: /usr/local/lib/libtbb.so.12 is not a symbolic link\n", 95 | "\n", 96 | "/sbin/ldconfig.real: /usr/local/lib/libhwloc.so.15 is not a symbolic link\n", 97 | "\n", 98 | "/sbin/ldconfig.real: /usr/local/lib/libur_loader.so.0 is not a symbolic link\n", 99 | "\n", 100 | "/sbin/ldconfig.real: /usr/local/lib/libur_adapter_level_zero.so.0 is not a symbolic link\n", 101 | "\n", 102 | "/sbin/ldconfig.real: /usr/local/lib/libtbbbind.so.3 is not a symbolic link\n", 103 | "\n", 104 | "/sbin/ldconfig.real: /usr/local/lib/libumf.so.0 is not a symbolic link\n", 105 | "\n", 106 | "/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_5.so.3 is not a symbolic link\n", 107 | "\n", 108 | "/sbin/ldconfig.real: /usr/local/lib/libtcm_debug.so.1 is not a symbolic link\n", 109 | "\n", 110 | "/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc_proxy.so.2 is not a symbolic link\n", 111 | "\n", 112 | "/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_0.so.3 is not a symbolic link\n", 113 | "\n", 114 | "Cloning into 'rl-robotics-speedrun'...\n", 115 | "remote: Enumerating objects: 37, done.\u001b[K\n", 116 | "remote: Counting objects: 100% (37/37), done.\u001b[K\n", 117 | "remote: Compressing objects: 100% (25/25), done.\u001b[K\n", 118 | "remote: Total 37 (delta 9), reused 33 (delta 9), pack-reused 0 (from 0)\u001b[K\n", 119 | "Receiving objects: 100% (37/37), 15.73 KiB | 236.00 KiB/s, done.\n", 120 | "Resolving deltas: 100% (9/9), done.\n", 121 | "Obtaining file:///content/rl-robotics-speedrun\n", 122 | " Preparing metadata (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", 123 | "Requirement already satisfied: wandb in /usr/local/lib/python3.10/dist-packages (from rl-robotics-speedrun==0.0.1) (0.18.5)\n", 124 | "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from rl-robotics-speedrun==0.0.1) (4.66.6)\n", 125 | "Collecting mani_skill (from rl-robotics-speedrun==0.0.1)\n", 126 | " Downloading mani_skill-3.0.0b12-py3-none-any.whl.metadata (3.2 kB)\n", 127 | "Requirement already satisfied: tensorboard in /usr/local/lib/python3.10/dist-packages (from rl-robotics-speedrun==0.0.1) (2.17.0)\n", 128 | "Collecting torchrl (from rl-robotics-speedrun==0.0.1)\n", 129 | " Downloading torchrl-0.6.0-cp310-cp310-manylinux1_x86_64.whl.metadata (39 kB)\n", 130 | "Collecting tensordict (from rl-robotics-speedrun==0.0.1)\n", 131 | " Downloading tensordict-0.6.0-cp310-cp310-manylinux1_x86_64.whl.metadata (8.9 kB)\n", 132 | "Requirement already satisfied: numpy<2.0.0,>=1.22 in /usr/local/lib/python3.10/dist-packages (from mani_skill->rl-robotics-speedrun==0.0.1) (1.26.4)\n", 133 | "Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from mani_skill->rl-robotics-speedrun==0.0.1) (1.13.1)\n", 134 | "Collecting dacite (from mani_skill->rl-robotics-speedrun==0.0.1)\n", 135 | " Downloading dacite-1.8.1-py3-none-any.whl.metadata (15 kB)\n", 136 | "Collecting gymnasium==0.29.1 (from mani_skill->rl-robotics-speedrun==0.0.1)\n", 137 | " Downloading gymnasium-0.29.1-py3-none-any.whl.metadata (10 kB)\n", 138 | "Collecting sapien==3.0.0.b1 (from mani_skill->rl-robotics-speedrun==0.0.1)\n", 139 | " Downloading sapien-3.0.0b1-cp310-cp310-manylinux2014_x86_64.whl.metadata (10 kB)\n", 140 | "Requirement already satisfied: h5py in /usr/local/lib/python3.10/dist-packages (from mani_skill->rl-robotics-speedrun==0.0.1) (3.12.1)\n", 141 | "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from mani_skill->rl-robotics-speedrun==0.0.1) (6.0.2)\n", 142 | "Requirement already satisfied: GitPython in /usr/local/lib/python3.10/dist-packages (from mani_skill->rl-robotics-speedrun==0.0.1) (3.1.43)\n", 143 | "Requirement already satisfied: tabulate in /usr/local/lib/python3.10/dist-packages (from mani_skill->rl-robotics-speedrun==0.0.1) (0.9.0)\n", 144 | "Collecting transforms3d (from mani_skill->rl-robotics-speedrun==0.0.1)\n", 145 | " Downloading transforms3d-0.4.2-py3-none-any.whl.metadata (2.8 kB)\n", 146 | "Collecting trimesh (from mani_skill->rl-robotics-speedrun==0.0.1)\n", 147 | " Downloading trimesh-4.5.1-py3-none-any.whl.metadata (18 kB)\n", 148 | "Requirement already satisfied: imageio in /usr/local/lib/python3.10/dist-packages (from mani_skill->rl-robotics-speedrun==0.0.1) (2.36.0)\n", 149 | "Requirement already satisfied: IPython in /usr/local/lib/python3.10/dist-packages (from mani_skill->rl-robotics-speedrun==0.0.1) (7.34.0)\n", 150 | "Collecting pytorch-kinematics==0.7.4 (from mani_skill->rl-robotics-speedrun==0.0.1)\n", 151 | " Downloading pytorch_kinematics-0.7.4-py3-none-any.whl.metadata (14 kB)\n", 152 | "Collecting pynvml (from mani_skill->rl-robotics-speedrun==0.0.1)\n", 153 | " Downloading pynvml-11.5.3-py3-none-any.whl.metadata (8.8 kB)\n", 154 | "Collecting tyro>=0.8.5 (from mani_skill->rl-robotics-speedrun==0.0.1)\n", 155 | " Downloading tyro-0.8.14-py3-none-any.whl.metadata (8.4 kB)\n", 156 | "Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.10/dist-packages (from mani_skill->rl-robotics-speedrun==0.0.1) (0.24.7)\n", 157 | "Collecting 
mplib==0.1.1 (from mani_skill->rl-robotics-speedrun==0.0.1)\n", 158 | " Downloading mplib-0.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.7 kB)\n", 159 | "Collecting fast-kinematics==0.2.2 (from mani_skill->rl-robotics-speedrun==0.0.1)\n", 160 | " Downloading fast_kinematics-0.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.3 kB)\n", 161 | "Requirement already satisfied: cloudpickle>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from gymnasium==0.29.1->mani_skill->rl-robotics-speedrun==0.0.1) (3.1.0)\n", 162 | "Requirement already satisfied: typing-extensions>=4.3.0 in /usr/local/lib/python3.10/dist-packages (from gymnasium==0.29.1->mani_skill->rl-robotics-speedrun==0.0.1) (4.12.2)\n", 163 | "Collecting farama-notifications>=0.0.1 (from gymnasium==0.29.1->mani_skill->rl-robotics-speedrun==0.0.1)\n", 164 | " Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)\n", 165 | "Collecting toppra>=0.4.0 (from mplib==0.1.1->mani_skill->rl-robotics-speedrun==0.0.1)\n", 166 | " Downloading toppra-0.6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (3.4 kB)\n", 167 | "Requirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (1.4.0)\n", 168 | "Requirement already satisfied: lxml in /usr/local/lib/python3.10/dist-packages (from pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (5.3.0)\n", 169 | "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (2.5.0+cu121)\n", 170 | "Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (from pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (3.8.0)\n", 171 | "Collecting pytorch-seed (from pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1)\n", 172 | " Downloading pytorch_seed-0.2.0-py3-none-any.whl.metadata (3.8 kB)\n", 173 | "Collecting arm-pytorch-utilities (from pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1)\n", 174 | " Downloading arm_pytorch_utilities-0.4.3-py3-none-any.whl.metadata (2.6 kB)\n", 175 | "Requirement already satisfied: requests>=2.22 in /usr/local/lib/python3.10/dist-packages (from sapien==3.0.0.b1->mani_skill->rl-robotics-speedrun==0.0.1) (2.32.3)\n", 176 | "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from sapien==3.0.0.b1->mani_skill->rl-robotics-speedrun==0.0.1) (3.4.2)\n", 177 | "Requirement already satisfied: pyperclip in /usr/local/lib/python3.10/dist-packages (from sapien==3.0.0.b1->mani_skill->rl-robotics-speedrun==0.0.1) (1.9.0)\n", 178 | "Requirement already satisfied: opencv-python>=4.0 in /usr/local/lib/python3.10/dist-packages (from sapien==3.0.0.b1->mani_skill->rl-robotics-speedrun==0.0.1) (4.10.0.84)\n", 179 | "Requirement already satisfied: grpcio>=1.48.2 in /usr/local/lib/python3.10/dist-packages (from tensorboard->rl-robotics-speedrun==0.0.1) (1.64.1)\n", 180 | "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.10/dist-packages (from tensorboard->rl-robotics-speedrun==0.0.1) (3.7)\n", 181 | "Requirement already satisfied: protobuf!=4.24.0,<5.0.0,>=3.19.6 in /usr/local/lib/python3.10/dist-packages (from tensorboard->rl-robotics-speedrun==0.0.1) (3.20.3)\n", 182 | "Requirement already satisfied: setuptools>=41.0.0 in 
/usr/local/lib/python3.10/dist-packages (from tensorboard->rl-robotics-speedrun==0.0.1) (75.1.0)\n", 183 | "Requirement already satisfied: six>1.9 in /usr/local/lib/python3.10/dist-packages (from tensorboard->rl-robotics-speedrun==0.0.1) (1.16.0)\n", 184 | "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard->rl-robotics-speedrun==0.0.1) (0.7.2)\n", 185 | "Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from tensorboard->rl-robotics-speedrun==0.0.1) (3.0.6)\n", 186 | "Requirement already satisfied: orjson in /usr/local/lib/python3.10/dist-packages (from tensordict->rl-robotics-speedrun==0.0.1) (3.10.10)\n", 187 | "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from torchrl->rl-robotics-speedrun==0.0.1) (24.1)\n", 188 | "Requirement already satisfied: click!=8.0.0,>=7.1 in /usr/local/lib/python3.10/dist-packages (from wandb->rl-robotics-speedrun==0.0.1) (8.1.7)\n", 189 | "Requirement already satisfied: docker-pycreds>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from wandb->rl-robotics-speedrun==0.0.1) (0.4.0)\n", 190 | "Requirement already satisfied: platformdirs in /usr/local/lib/python3.10/dist-packages (from wandb->rl-robotics-speedrun==0.0.1) (4.3.6)\n", 191 | "Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb->rl-robotics-speedrun==0.0.1) (5.9.5)\n", 192 | "Requirement already satisfied: sentry-sdk>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb->rl-robotics-speedrun==0.0.1) (2.17.0)\n", 193 | "Requirement already satisfied: setproctitle in /usr/local/lib/python3.10/dist-packages (from wandb->rl-robotics-speedrun==0.0.1) (1.3.3)\n", 194 | "Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.10/dist-packages (from GitPython->mani_skill->rl-robotics-speedrun==0.0.1) (4.0.11)\n", 195 | "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.22->sapien==3.0.0.b1->mani_skill->rl-robotics-speedrun==0.0.1) (3.4.0)\n", 196 | "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.22->sapien==3.0.0.b1->mani_skill->rl-robotics-speedrun==0.0.1) (3.10)\n", 197 | "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.22->sapien==3.0.0.b1->mani_skill->rl-robotics-speedrun==0.0.1) (2.2.3)\n", 198 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.22->sapien==3.0.0.b1->mani_skill->rl-robotics-speedrun==0.0.1) (2024.8.30)\n", 199 | "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch->pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (3.16.1)\n", 200 | "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (3.1.4)\n", 201 | "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch->pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (2024.10.0)\n", 202 | "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch->pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (1.13.1)\n", 203 | "Requirement already satisfied: mpmath<1.4,>=1.1.0 in 
/usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch->pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (1.3.0)\n", 204 | "Requirement already satisfied: docstring-parser>=0.16 in /usr/local/lib/python3.10/dist-packages (from tyro>=0.8.5->mani_skill->rl-robotics-speedrun==0.0.1) (0.16)\n", 205 | "Requirement already satisfied: rich>=11.1.0 in /usr/local/lib/python3.10/dist-packages (from tyro>=0.8.5->mani_skill->rl-robotics-speedrun==0.0.1) (13.9.3)\n", 206 | "Collecting shtab>=1.5.6 (from tyro>=0.8.5->mani_skill->rl-robotics-speedrun==0.0.1)\n", 207 | " Downloading shtab-1.7.1-py3-none-any.whl.metadata (7.3 kB)\n", 208 | "Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.10/dist-packages (from werkzeug>=1.0.1->tensorboard->rl-robotics-speedrun==0.0.1) (3.0.2)\n", 209 | "Requirement already satisfied: pillow>=8.3.2 in /usr/local/lib/python3.10/dist-packages (from imageio->mani_skill->rl-robotics-speedrun==0.0.1) (10.4.0)\n", 210 | "Requirement already satisfied: imageio-ffmpeg in /usr/local/lib/python3.10/dist-packages (from imageio[ffmpeg]->mani_skill->rl-robotics-speedrun==0.0.1) (0.5.1)\n", 211 | "Collecting jedi>=0.16 (from IPython->mani_skill->rl-robotics-speedrun==0.0.1)\n", 212 | " Downloading jedi-0.19.1-py2.py3-none-any.whl.metadata (22 kB)\n", 213 | "Requirement already satisfied: decorator in /usr/local/lib/python3.10/dist-packages (from IPython->mani_skill->rl-robotics-speedrun==0.0.1) (4.4.2)\n", 214 | "Requirement already satisfied: pickleshare in /usr/local/lib/python3.10/dist-packages (from IPython->mani_skill->rl-robotics-speedrun==0.0.1) (0.7.5)\n", 215 | "Requirement already satisfied: traitlets>=4.2 in /usr/local/lib/python3.10/dist-packages (from IPython->mani_skill->rl-robotics-speedrun==0.0.1) (5.7.1)\n", 216 | "Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from IPython->mani_skill->rl-robotics-speedrun==0.0.1) (3.0.48)\n", 217 | "Requirement already satisfied: pygments in /usr/local/lib/python3.10/dist-packages (from IPython->mani_skill->rl-robotics-speedrun==0.0.1) (2.18.0)\n", 218 | "Requirement already satisfied: backcall in /usr/local/lib/python3.10/dist-packages (from IPython->mani_skill->rl-robotics-speedrun==0.0.1) (0.2.0)\n", 219 | "Requirement already satisfied: matplotlib-inline in /usr/local/lib/python3.10/dist-packages (from IPython->mani_skill->rl-robotics-speedrun==0.0.1) (0.1.7)\n", 220 | "Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.10/dist-packages (from IPython->mani_skill->rl-robotics-speedrun==0.0.1) (4.9.0)\n", 221 | "Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.10/dist-packages (from gitdb<5,>=4.0.1->GitPython->mani_skill->rl-robotics-speedrun==0.0.1) (5.0.1)\n", 222 | "Requirement already satisfied: parso<0.9.0,>=0.8.3 in /usr/local/lib/python3.10/dist-packages (from jedi>=0.16->IPython->mani_skill->rl-robotics-speedrun==0.0.1) (0.8.4)\n", 223 | "Requirement already satisfied: ptyprocess>=0.5 in /usr/local/lib/python3.10/dist-packages (from pexpect>4.3->IPython->mani_skill->rl-robotics-speedrun==0.0.1) (0.7.0)\n", 224 | "Requirement already satisfied: wcwidth in /usr/local/lib/python3.10/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->IPython->mani_skill->rl-robotics-speedrun==0.0.1) (0.2.13)\n", 225 | "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from 
rich>=11.1.0->tyro>=0.8.5->mani_skill->rl-robotics-speedrun==0.0.1) (3.0.0)\n", 226 | "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (1.3.0)\n", 227 | "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib->pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (0.12.1)\n", 228 | "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (4.54.1)\n", 229 | "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (1.4.7)\n", 230 | "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (3.2.0)\n", 231 | "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib->pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (2.8.2)\n", 232 | "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich>=11.1.0->tyro>=0.8.5->mani_skill->rl-robotics-speedrun==0.0.1) (0.1.2)\n", 233 | "Downloading mani_skill-3.0.0b12-py3-none-any.whl (79.3 MB)\n", 234 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m79.3/79.3 MB\u001b[0m \u001b[31m29.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 235 | "\u001b[?25hDownloading fast_kinematics-0.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (623 kB)\n", 236 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m623.8/623.8 kB\u001b[0m \u001b[31m45.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 237 | "\u001b[?25hDownloading gymnasium-0.29.1-py3-none-any.whl (953 kB)\n", 238 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m953.9/953.9 kB\u001b[0m \u001b[31m60.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 239 | "\u001b[?25hDownloading mplib-0.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.0 MB)\n", 240 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m12.0/12.0 MB\u001b[0m \u001b[31m63.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 241 | "\u001b[?25hDownloading pytorch_kinematics-0.7.4-py3-none-any.whl (59 kB)\n", 242 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.6/59.6 kB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 243 | "\u001b[?25hDownloading sapien-3.0.0b1-cp310-cp310-manylinux2014_x86_64.whl (49.6 MB)\n", 244 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m49.6/49.6 MB\u001b[0m \u001b[31m45.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 245 | "\u001b[?25hDownloading tensordict-0.6.0-cp310-cp310-manylinux1_x86_64.whl (353 kB)\n", 246 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m353.9/353.9 kB\u001b[0m \u001b[31m33.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 247 | "\u001b[?25hDownloading torchrl-0.6.0-cp310-cp310-manylinux1_x86_64.whl (1.0 MB)\n", 248 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m65.2 
MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 249 | "\u001b[?25hDownloading transforms3d-0.4.2-py3-none-any.whl (1.4 MB)\n", 250 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.4/1.4 MB\u001b[0m \u001b[31m76.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 251 | "\u001b[?25hDownloading tyro-0.8.14-py3-none-any.whl (109 kB)\n", 252 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m109.8/109.8 kB\u001b[0m \u001b[31m13.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 253 | "\u001b[?25hDownloading dacite-1.8.1-py3-none-any.whl (14 kB)\n", 254 | "Downloading pynvml-11.5.3-py3-none-any.whl (53 kB)\n", 255 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m53.1/53.1 kB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 256 | "\u001b[?25hDownloading trimesh-4.5.1-py3-none-any.whl (703 kB)\n", 257 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m703.7/703.7 kB\u001b[0m \u001b[31m49.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 258 | "\u001b[?25hDownloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)\n", 259 | "Downloading jedi-0.19.1-py2.py3-none-any.whl (1.6 MB)\n", 260 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m77.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 261 | "\u001b[?25hDownloading shtab-1.7.1-py3-none-any.whl (14 kB)\n", 262 | "Downloading toppra-0.6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (638 kB)\n", 263 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m638.2/638.2 kB\u001b[0m \u001b[31m52.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 264 | "\u001b[?25hDownloading arm_pytorch_utilities-0.4.3-py3-none-any.whl (40 kB)\n", 265 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m41.0/41.0 kB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 266 | "\u001b[?25hDownloading pytorch_seed-0.2.0-py3-none-any.whl (4.2 kB)\n", 267 | "Installing collected packages: farama-notifications, trimesh, transforms3d, shtab, pynvml, jedi, gymnasium, fast-kinematics, dacite, sapien, tyro, toppra, tensordict, pytorch-seed, torchrl, mplib, arm-pytorch-utilities, pytorch-kinematics, mani_skill, rl-robotics-speedrun\n", 268 | " Running setup.py develop for rl-robotics-speedrun\n", 269 | "Successfully installed arm-pytorch-utilities-0.4.3 dacite-1.8.1 farama-notifications-0.0.4 fast-kinematics-0.2.2 gymnasium-0.29.1 jedi-0.19.1 mani_skill-3.0.0b12 mplib-0.1.1 pynvml-11.5.3 pytorch-kinematics-0.7.4 pytorch-seed-0.2.0 rl-robotics-speedrun-0.0.1 sapien-3.0.0b1 shtab-1.7.1 tensordict-0.6.0 toppra-0.6.0 torchrl-0.6.0 transforms3d-0.4.2 trimesh-4.5.1 tyro-0.8.14\n" 270 | ] 271 | } 272 | ], 273 | "source": [ 274 | "# setup vulkan\n", 275 | "!mkdir -p /usr/share/vulkan/icd.d\n", 276 | "!wget -q https://raw.githubusercontent.com/haosulab/ManiSkill/main/docker/nvidia_icd.json\n", 277 | "!wget -q https://raw.githubusercontent.com/haosulab/ManiSkill/main/docker/10_nvidia.json\n", 278 | "!mv nvidia_icd.json /usr/share/vulkan/icd.d\n", 279 | "!mv 10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json\n", 280 | "!apt-get install -y --no-install-recommends libvulkan-dev\n", 281 | "# dependencies\n", 282 | "!git clone https://github.com/StoneT2000/rl-robotics-speedrun\n", 283 | "!cd rl-robotics-speedrun && pip install -e ." 
284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": 2, 289 | "metadata": { 290 | "id": "_LC6CX2LeJu6" 291 | }, 292 | "outputs": [], 293 | "source": [ 294 | "try:\n", 295 | " import google.colab\n", 296 | " IN_COLAB = True\n", 297 | "except:\n", 298 | " IN_COLAB = False\n", 299 | "\n", 300 | "if IN_COLAB:\n", 301 | " import site\n", 302 | " site.main() # run this so local pip installs are recognized" 303 | ] 304 | }, 305 | { 306 | "cell_type": "markdown", 307 | "source": [ 308 | "# Benchmarking" 309 | ], 310 | "metadata": { 311 | "id": "VAOJ_Y9w9Lgf" 312 | } 313 | }, 314 | { 315 | "cell_type": "code", 316 | "source": [ 317 | "# login to upload / track aggregated results, you need to provide your wandb API key\n", 318 | "import wandb\n", 319 | "wandb.login()" 320 | ], 321 | "metadata": { 322 | "colab": { 323 | "base_uri": "https://localhost:8080/", 324 | "height": 86 325 | }, 326 | "id": "9fqoUMxR9N4z", 327 | "outputId": "81f123fd-3bc2-4e81-e1c9-ea894ef7e492" 328 | }, 329 | "execution_count": 3, 330 | "outputs": [ 331 | { 332 | "output_type": "stream", 333 | "name": "stderr", 334 | "text": [ 335 | "\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n" 336 | ] 337 | }, 338 | { 339 | "output_type": "display_data", 340 | "data": { 341 | "text/plain": [ 342 | "" 343 | ], 344 | "application/javascript": [ 345 | "\n", 346 | " window._wandbApiKey = new Promise((resolve, reject) => {\n", 347 | " function loadScript(url) {\n", 348 | " return new Promise(function(resolve, reject) {\n", 349 | " let newScript = document.createElement(\"script\");\n", 350 | " newScript.onerror = reject;\n", 351 | " newScript.onload = resolve;\n", 352 | " document.body.appendChild(newScript);\n", 353 | " newScript.src = url;\n", 354 | " });\n", 355 | " }\n", 356 | " loadScript(\"https://cdn.jsdelivr.net/npm/postmate/build/postmate.min.js\").then(() => {\n", 357 | " const iframe = document.createElement('iframe')\n", 358 | " iframe.style.cssText = \"width:0;height:0;border:none\"\n", 359 | " document.body.appendChild(iframe)\n", 360 | " const handshake = new Postmate({\n", 361 | " container: iframe,\n", 362 | " url: 'https://wandb.ai/authorize'\n", 363 | " });\n", 364 | " const timeout = setTimeout(() => reject(\"Couldn't auto authenticate\"), 5000)\n", 365 | " handshake.then(function(child) {\n", 366 | " child.on('authorize', data => {\n", 367 | " clearTimeout(timeout)\n", 368 | " resolve(data)\n", 369 | " });\n", 370 | " });\n", 371 | " })\n", 372 | " });\n", 373 | " " 374 | ] 375 | }, 376 | "metadata": {} 377 | }, 378 | { 379 | "output_type": "stream", 380 | "name": "stderr", 381 | "text": [ 382 | "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n" 383 | ] 384 | }, 385 | { 386 | "output_type": "execute_result", 387 | "data": { 388 | "text/plain": [ 389 | "True" 390 | ] 391 | }, 392 | "metadata": {}, 393 | "execution_count": 3 394 | } 395 | ] 396 | }, 397 | { 398 | "cell_type": "code", 399 | "source": [ 400 | "# look at the training script, make modifications if needed\n", 401 | "!cat rl-robotics-speedrun/records/10312024_cudagraphs/script.sh\n", 402 | "# ppo code at rl-robotics-speedrun/records/10312024_cudagraphs/ppo.py" 403 | ], 404 | "metadata": { 405 | "colab": { 406 | "base_uri": "https://localhost:8080/" 407 | }, 408 | "id": "imJpNrbSF5UM", 409 | "outputId": "8935418d-f194-48be-9c78-8a9f8f6a21ab" 410 | }, 411 | "execution_count": 5, 412 | "outputs": [ 
413 | { 414 | "output_type": "stream", 415 | "name": "stdout", 416 | "text": [ 417 | "#!/bin/bash\n", 418 | "seeds=(9351 4796 1788 2371 3462 4553 5644 6735 7826 8917)\n", 419 | "num_envs=4096\n", 420 | "update_epochs=12\n", 421 | "num_steps=4\n", 422 | "num_minibatches=32\n", 423 | "exp_name=\"10312024_cudagraphs\"\n", 424 | "for seed in ${seeds[@]}\n", 425 | "do\n", 426 | " python ppo.py --env_id=\"PickCube-v1\" --seed=${seed} --total_timesteps=4_000_000 --eval-freq=10 \\\n", 427 | " --num_envs=${num_envs} --update_epochs=${update_epochs} --num_minibatches=${num_minibatches} \\\n", 428 | " --num-steps=${num_steps} --cudagraphs \\\n", 429 | " --exp-name=\"ppo-PickCube-v1-state-${seed}-${exp_name}\" \\\n", 430 | " --track --wandb_project_name=\"PPO-ManiSkill-GPU-SpeedRun\" --wandb_group=${exp_name} --no-capture_video\n", 431 | "done" 432 | ] 433 | } 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "source": [ 439 | "!cd rl-robotics-speedrun/records/10312024_cudagraphs/ && bash script.sh" 440 | ], 441 | "metadata": { 442 | "colab": { 443 | "base_uri": "https://localhost:8080/" 444 | }, 445 | "id": "owjFU8HIClgW", 446 | "outputId": "a2904e53-7cc3-4386-dceb-f0989fc0b801" 447 | }, 448 | "execution_count": null, 449 | "outputs": [ 450 | { 451 | "output_type": "stream", 452 | "name": "stdout", 453 | "text": [ 454 | "2024-10-31 19:41:14.749861: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", 455 | "2024-10-31 19:41:14.765622: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", 456 | "2024-10-31 19:41:14.787173: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", 457 | "2024-10-31 19:41:14.793659: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", 458 | "2024-10-31 19:41:14.809284: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", 459 | "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", 460 | "2024-10-31 19:41:16.055538: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n", 461 | "Downloading PhysX GPU library to /root/.sapien/physx/105.1-physx-5.3.1.patch0 from Github. This can take several minutes. 
If it fails to download, please manually download fhttps://github.com/sapien-sim/physx-precompiled/releases/download/105.1-physx-5.3.1.patch0/linux-so.zip and unzip at /root/.sapien/physx/105.1-physx-5.3.1.patch0.\n", 462 | "Download complete.\n", 463 | "/usr/local/lib/python3.10/dist-packages/gymnasium/core.py:311: UserWarning: \u001b[33mWARN: env.single_observation_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.single_observation_space` for environment variables or `env.get_wrapper_attr('single_observation_space')` that will search the reminding wrappers.\u001b[0m\n", 464 | " logger.warn(\n", 465 | "/usr/local/lib/python3.10/dist-packages/gymnasium/core.py:311: UserWarning: \u001b[33mWARN: env.single_action_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.single_action_space` for environment variables or `env.get_wrapper_attr('single_action_space')` that will search the reminding wrappers.\u001b[0m\n", 466 | " logger.warn(\n", 467 | "/usr/local/lib/python3.10/dist-packages/gymnasium/core.py:311: UserWarning: \u001b[33mWARN: env.max_episode_steps to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.max_episode_steps` for environment variables or `env.get_wrapper_attr('max_episode_steps')` that will search the reminding wrappers.\u001b[0m\n", 468 | " logger.warn(\n", 469 | "Running training\n", 470 | "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mstonet2000\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n", 471 | "\u001b[34m\u001b[1mwandb\u001b[0m: Tracking run with wandb version 0.18.5\n", 472 | "\u001b[34m\u001b[1mwandb\u001b[0m: Run data is saved locally in \u001b[35m\u001b[1m/content/rl-robotics-speedrun/records/10312024_cudagraphs/wandb/run-20241031_194142-tir5o14p\u001b[0m\n", 473 | "\u001b[34m\u001b[1mwandb\u001b[0m: Run \u001b[1m`wandb offline`\u001b[0m to turn off syncing.\n", 474 | "\u001b[34m\u001b[1mwandb\u001b[0m: Syncing run \u001b[33mppo-PickCube-v1-state-9351-10312024_cudagraphs\u001b[0m\n", 475 | "\u001b[34m\u001b[1mwandb\u001b[0m: ⭐️ View project at \u001b[34m\u001b[4mhttps://wandb.ai/stonet2000/PPO-ManiSkill-GPU-SpeedRun\u001b[0m\n", 476 | "\u001b[34m\u001b[1mwandb\u001b[0m: 🚀 View run at \u001b[34m\u001b[4mhttps://wandb.ai/stonet2000/PPO-ManiSkill-GPU-SpeedRun/runs/tir5o14p\u001b[0m\n", 477 | "/usr/local/lib/python3.10/dist-packages/tensordict/nn/cudagraphs.py:194: UserWarning: Tensordict is registered in PyTree. This is incompatible with CudaGraphModule. Removing TDs from PyTree. To silence this warning, call tensordict.nn.functional_module._exclude_td_from_pytree().set() or set the environment variable `EXCLUDE_TD_FROM_PYTREE=1`. This operation is irreversible.\n", 478 | " warnings.warn(\n", 479 | "success_once: 0.88, return: 37.81: 100% 244/244 [03:58<00:00, 1.02it/s]\n", 480 | "2024-10-31 19:45:53.083376: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. 
To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", 481 | "2024-10-31 19:45:53.101255: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", 482 | "2024-10-31 19:45:53.122746: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", 483 | "2024-10-31 19:45:53.129584: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", 484 | "2024-10-31 19:45:53.145547: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", 485 | "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", 486 | "2024-10-31 19:45:54.411087: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n", 487 | "/usr/local/lib/python3.10/dist-packages/gymnasium/core.py:311: UserWarning: \u001b[33mWARN: env.single_observation_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.single_observation_space` for environment variables or `env.get_wrapper_attr('single_observation_space')` that will search the reminding wrappers.\u001b[0m\n", 488 | " logger.warn(\n", 489 | "/usr/local/lib/python3.10/dist-packages/gymnasium/core.py:311: UserWarning: \u001b[33mWARN: env.single_action_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.single_action_space` for environment variables or `env.get_wrapper_attr('single_action_space')` that will search the reminding wrappers.\u001b[0m\n", 490 | " logger.warn(\n", 491 | "/usr/local/lib/python3.10/dist-packages/gymnasium/core.py:311: UserWarning: \u001b[33mWARN: env.max_episode_steps to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.max_episode_steps` for environment variables or `env.get_wrapper_attr('max_episode_steps')` that will search the reminding wrappers.\u001b[0m\n", 492 | " logger.warn(\n", 493 | "Running training\n", 494 | "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mstonet2000\u001b[0m. 
Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n", 495 | "\u001b[34m\u001b[1mwandb\u001b[0m: Tracking run with wandb version 0.18.5\n", 496 | "\u001b[34m\u001b[1mwandb\u001b[0m: Run data is saved locally in \u001b[35m\u001b[1m/content/rl-robotics-speedrun/records/10312024_cudagraphs/wandb/run-20241031_194616-mf31p5np\u001b[0m\n", 497 | "\u001b[34m\u001b[1mwandb\u001b[0m: Run \u001b[1m`wandb offline`\u001b[0m to turn off syncing.\n", 498 | "\u001b[34m\u001b[1mwandb\u001b[0m: Syncing run \u001b[33mppo-PickCube-v1-state-4796-10312024_cudagraphs\u001b[0m\n", 499 | "\u001b[34m\u001b[1mwandb\u001b[0m: ⭐️ View project at \u001b[34m\u001b[4mhttps://wandb.ai/stonet2000/PPO-ManiSkill-GPU-SpeedRun\u001b[0m\n", 500 | "\u001b[34m\u001b[1mwandb\u001b[0m: 🚀 View run at \u001b[34m\u001b[4mhttps://wandb.ai/stonet2000/PPO-ManiSkill-GPU-SpeedRun/runs/mf31p5np\u001b[0m\n", 501 | "/usr/local/lib/python3.10/dist-packages/tensordict/nn/cudagraphs.py:194: UserWarning: Tensordict is registered in PyTree. This is incompatible with CudaGraphModule. Removing TDs from PyTree. To silence this warning, call tensordict.nn.functional_module._exclude_td_from_pytree().set() or set the environment variable `EXCLUDE_TD_FROM_PYTREE=1`. This operation is irreversible.\n", 502 | " warnings.warn(\n", 503 | "success_once: 0.00, return: 19.96: 22% 53/244 [01:06<04:15, 1.34s/it]" 504 | ] 505 | } 506 | ] 507 | }, 508 | { 509 | "cell_type": "code", 510 | "source": [], 511 | "metadata": { 512 | "id": "pCmwy4__GIEz" 513 | }, 514 | "execution_count": null, 515 | "outputs": [] 516 | } 517 | ] 518 | } -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="rl-robotics-speedrun", 5 | version="0.0.1", 6 | packages=find_packages(), 7 | install_requires=[ 8 | "wandb", 9 | "tqdm", 10 | "mani_skill", 11 | "tensorboard", 12 | "torchrl", 13 | "tensordict" 14 | ], 15 | author="Stone Tao", 16 | description="Speed-running solving robot manipulation tasks", 17 | long_description=open("README.md").read(), 18 | long_description_content_type="text/markdown", 19 | url="https://github.com/StoneT2000/rl-robotics-speedrun", 20 | python_requires=">=3.10", 21 | ) 22 | --------------------------------------------------------------------------------
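For reference, here is a minimal sketch (not a file in the repository) of how the hyperparameters in `records/10312024_cudagraphs/script.sh` translate into PPO batch sizes and the iteration count visible in the Colab progress bar above ("244/244"). The `total_timesteps // batch_size` iteration convention is assumed from standard CleanRL-style PPO rather than read from `ppo.py`:

```python
# Sanity-check sketch (assumption: CleanRL-style PPO bookkeeping, not taken from ppo.py).
# Values are copied from records/10312024_cudagraphs/script.sh.

num_envs = 4096              # --num_envs: parallel GPU-simulated environments
num_steps = 4                # --num-steps: very short rollouts per environment
num_minibatches = 32         # --num_minibatches
update_epochs = 12           # --update_epochs: PPO epochs per rollout batch
total_timesteps = 4_000_000  # --total_timesteps

batch_size = num_envs * num_steps               # 16,384 transitions collected per rollout
minibatch_size = batch_size // num_minibatches  # 512 transitions per gradient step
num_iterations = total_timesteps // batch_size  # 244 rollout/update cycles

print(f"batch_size={batch_size}, minibatch_size={minibatch_size}, num_iterations={num_iterations}")
# batch_size=16384, minibatch_size=512, num_iterations=244
```

The short rollout length (`num_steps=4`) combined with many parallel environments keeps each rollout cheap while still filling a large batch, which is part of why each seeded run finishes in a few minutes on the Colab L4 and roughly 80 seconds on a 4090.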