├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── PickCube-v1_eval_return_4090.png
│   └── PickCube-v1_eval_success_once_4090.png
├── records
│   ├── 10312024_cudagraphs
│   │   ├── ppo.py
│   │   └── script.sh
│   └── baseline
│       ├── ppo.py
│       └── script.sh
├── rl_robotics_speedrun_colab.ipynb
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | runs/
2 | .vscode/
3 | wandb/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Stone Tao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RL Robotics Speedrun
2 |
3 | Speed-running robot manipulation tasks in [ManiSkill](https://github.com/haosulab/ManiSkill). The goal is simply to solve a list of tasks as fast as possible with RL and fixed dense rewards, starting from scratch.
4 |
5 |
6 | Inspired by the [great speedrunning work done for LLMs by Jordan et al.](https://github.com/KellerJordan/modded-nanogpt).
7 |
8 | ## Getting Started
9 |
10 | You can run the speedrun/benchmark on your local machine or in [Google Colab](https://colab.research.google.com/github/StoneT2000/rl-robotics-speedrun/blob/main/rl_robotics_speedrun_colab.ipynb). To set up locally, we recommend using conda/mamba to manage dependencies:
11 |
12 | ```bash
13 | conda create -n rl-robotics-speedrun python=3.11
14 | conda activate rl-robotics-speedrun
15 | git clone https://github.com/StoneT2000/rl-robotics-speedrun
16 | cd rl-robotics-speedrun
17 | pip install -e .
18 | ```
19 |
20 | ## Benchmarking
21 |
22 | To run the benchmark, `cd` into one of the folders under `records/` and run its `script.sh` file.
23 |
24 | By default this logs results to TensorBoard and wandb. You will need to set up a [wandb account](https://wandb.ai/); wandb makes it easier to view aggregated results across runs.
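If you have not used wandb on the machine before, a one-time login should be enough (this assumes the standard wandb CLI that ships with the `wandb` package):

```bash
wandb login  # paste the API key from https://wandb.ai/authorize when prompted
```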
25 |
26 | ```bash
27 | cd records/baseline && bash script.sh
28 | ```
29 |
30 | The current standard is to run PPO initialized with random weights on 10 different seeds. By default this does not record evaluation videos; remove `--no-capture_video` to record videos, and remove `--track` to disable wandb. Each of the 10 runs takes about 2 minutes on a 4090.
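For a single run, a hypothetical command (flags taken from `records/baseline/script.sh`, with `--track` and `--no-capture_video` removed so that videos are saved and wandb is skipped) could look like:

```bash
cd records/baseline
python ppo.py --env_id="PickCube-v1" --seed=9351 --total_timesteps=4_000_000 --eval-freq=10 \
    --num_envs=4096 --update_epochs=12 --num_minibatches=32 --num-steps=4 \
    --exp-name="ppo-PickCube-v1-state-9351-demo"
```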
31 |
32 | The current best is `records/10312024_cudagraphs`: standard PPO with GPU simulation, very few steps per environment during rollouts, and CUDA graphs enabled, following [leanrl](https://github.com/pytorch-labs/LeanRL/). It achieves a >= 95% success rate after ~80s.
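For context, the CUDA graphs piece amounts to wrapping the rollout policy callable in tensordict's `CudaGraphModule` (see `records/10312024_cudagraphs/ppo.py`). Below is a minimal standalone sketch; the toy network and shapes are purely illustrative assumptions, and it needs a CUDA device:

```python
import torch
import torch.nn as nn
from tensordict.nn import CudaGraphModule

if torch.cuda.is_available():
    device = torch.device("cuda")
    # toy stand-in for the actor network; the actual record uses Agent.get_action_and_value
    policy_net = nn.Sequential(nn.Linear(8, 64), nn.Tanh(), nn.Linear(64, 2)).to(device)

    def policy(obs: torch.Tensor) -> torch.Tensor:
        return policy_net(obs)

    # CudaGraphModule captures the callable into a CUDA graph during its warmup calls,
    # then replays the graph on later calls instead of re-launching kernels each step
    policy = CudaGraphModule(policy)

    obs = torch.randn(4096, 8, device=device)
    for _ in range(3):  # first calls warm up / capture; later calls replay the graph
        action = policy(obs)
    print(action.shape)  # torch.Size([4096, 2])
```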
33 |
34 |
35 | | Environment | Best Time | Wandb Results |
36 | |------------|-----------|---------------|
37 | | PickCube-v1 (state) | 80s | [Link](https://wandb.ai/stonet2000/PPO-ManiSkill-GPU-SpeedRun/workspace?nw=qgul0t4vstq) |
38 | | PickCube-v1 (visual) | ~ | ~ |
39 | | PushT-v1 (state) | ~ | ~ |
40 | | PushT-v1 (visual) | ~ | ~ |
41 |
42 | All results on 4090: [Wandb Link](https://wandb.ai/stonet2000/PPO-ManiSkill-GPU-SpeedRun/workspace?nw=qgul0t4vstq)
43 |
44 | All results on L4 (Google Colab GPU): [Wandb Link](https://wandb.ai/stonet2000/PPO-ManiSkill-GPU-SpeedRun/workspace?nw=i9kpqaqywjd)
45 |
46 | Figures of results on a 4090 for PickCube-v1 across 10 seeds; the shaded area is the standard error:
47 | ![PickCube-v1 evaluation return on a 4090](assets/PickCube-v1_eval_return_4090.png)
48 | ![PickCube-v1 evaluation success_once on a 4090](assets/PickCube-v1_eval_success_once_4090.png)
--------------------------------------------------------------------------------
/assets/PickCube-v1_eval_return_4090.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StoneT2000/rl-robotics-speedrun/5a9cf05ec0ecec2cf39f4469200e2446b3cd5a6f/assets/PickCube-v1_eval_return_4090.png
--------------------------------------------------------------------------------
/assets/PickCube-v1_eval_success_once_4090.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StoneT2000/rl-robotics-speedrun/5a9cf05ec0ecec2cf39f4469200e2446b3cd5a6f/assets/PickCube-v1_eval_success_once_4090.png
--------------------------------------------------------------------------------
/records/10312024_cudagraphs/ppo.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from mani_skill.utils import gym_utils
4 | from mani_skill.utils.wrappers.flatten import FlattenActionSpaceWrapper
5 | from mani_skill.utils.wrappers.record import RecordEpisode
6 | from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
7 |
8 | os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "1"
9 |
10 | import math
11 | import os
12 | import random
13 | import time
14 | from collections import defaultdict, deque
15 | from dataclasses import dataclass
16 | from typing import Optional, Tuple
17 |
18 | import gymnasium as gym
19 | import numpy as np
20 | import tensordict
21 | import torch
22 | import torch.nn as nn
23 | import torch.optim as optim
24 | import tqdm
25 | import tyro
26 | from torch.utils.tensorboard import SummaryWriter
27 | import wandb
28 | from tensordict import from_module
29 | from tensordict.nn import CudaGraphModule
30 | from torch.distributions.normal import Normal
31 |
32 |
33 | @dataclass
34 | class Args:
35 | exp_name: Optional[str] = None
36 | """the name of this experiment"""
37 | seed: int = 1
38 | """seed of the experiment"""
39 | torch_deterministic: bool = True
40 | """if toggled, `torch.backends.cudnn.deterministic=False`"""
41 | cuda: bool = True
42 | """if toggled, cuda will be enabled by default"""
43 | track: bool = False
44 | """if toggled, this experiment will be tracked with Weights and Biases"""
45 | wandb_project_name: str = "ManiSkill"
46 | """the wandb's project name"""
47 | wandb_entity: Optional[str] = None
48 | """the entity (team) of wandb's project"""
49 | wandb_group: str = "PPO"
50 | """the group of the run for wandb"""
51 | capture_video: bool = True
52 | """whether to capture videos of the agent performances (check out `videos` folder)"""
53 | save_model: bool = False
54 | """whether to save model into the `runs/{run_name}` folder"""
55 | evaluate: bool = False
56 | """if toggled, only runs evaluation with the given model checkpoint and saves the evaluation trajectories"""
57 | checkpoint: Optional[str] = None
58 | """path to a pretrained checkpoint file to start evaluation/training from"""
59 |
60 | # Environment specific arguments
61 | env_id: str = "PickCube-v1"
62 | """the id of the environment"""
63 | env_vectorization: str = "gpu"
64 | """the type of environment vectorization to use"""
65 | num_envs: int = 512
66 | """the number of parallel environments"""
67 | num_eval_envs: int = 16
68 | """the number of parallel evaluation environments"""
69 | partial_reset: bool = True
70 | """whether to let parallel environments reset upon termination instead of truncation"""
71 | num_steps: int = 50
72 | """the number of steps to run in each environment per policy rollout"""
73 | num_eval_steps: int = 50
74 | """the number of steps to run in each evaluation environment during evaluation"""
75 | reconfiguration_freq: Optional[int] = None
76 | """how often to reconfigure the environment during training"""
77 | eval_reconfiguration_freq: Optional[int] = 1
78 | """for benchmarking purposes we want to reconfigure the eval environment each reset to ensure objects are randomized in some tasks"""
79 | eval_freq: int = 25
80 | """evaluation frequency in terms of iterations"""
81 | save_train_video_freq: Optional[int] = None
82 | """frequency to save training videos in terms of iterations"""
83 |
84 | # Algorithm specific arguments
85 | total_timesteps: int = 10000000
86 | """total timesteps of the experiments"""
87 | learning_rate: float = 3e-4
88 | """the learning rate of the optimizer"""
89 | anneal_lr: bool = False
90 | """Toggle learning rate annealing for policy and value networks"""
91 | gamma: float = 0.8
92 | """the discount factor gamma"""
93 | gae_lambda: float = 0.9
94 | """the lambda for the general advantage estimation"""
95 | num_minibatches: int = 32
96 | """the number of mini-batches"""
97 | update_epochs: int = 4
98 | """the K epochs to update the policy"""
99 | norm_adv: bool = True
100 | """Toggles advantages normalization"""
101 | clip_coef: float = 0.2
102 | """the surrogate clipping coefficient"""
103 | clip_vloss: bool = False
104 | """Toggles whether or not to use a clipped loss for the value function, as per the paper."""
105 | ent_coef: float = 0.0
106 | """coefficient of the entropy"""
107 | vf_coef: float = 0.5
108 | """coefficient of the value function"""
109 | max_grad_norm: float = 0.5
110 | """the maximum norm for the gradient clipping"""
111 | target_kl: float = 0.1
112 | """the target KL divergence threshold"""
113 | reward_scale: float = 1.0
114 | """Scale the reward by this factor"""
115 | finite_horizon_gae: bool = False
116 |
117 | # to be filled in runtime
118 | batch_size: int = 0
119 | """the batch size (computed in runtime)"""
120 | minibatch_size: int = 0
121 | """the mini-batch size (computed in runtime)"""
122 | num_iterations: int = 0
123 | """the number of iterations (computed in runtime)"""
124 |
125 | # Torch optimizations
126 | compile: bool = False
127 | """whether to use torch.compile."""
128 | cudagraphs: bool = False
129 | """whether to use cudagraphs on top of compile."""
130 |
131 | def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
132 | torch.nn.init.orthogonal_(layer.weight, std)
133 | torch.nn.init.constant_(layer.bias, bias_const)
134 | return layer
135 |
136 |
137 | class Agent(nn.Module):
138 | def __init__(self, n_obs, n_act, device=None):
139 | super().__init__()
140 | self.critic = nn.Sequential(
141 | layer_init(nn.Linear(n_obs, 256, device=device)),
142 | nn.Tanh(),
143 | layer_init(nn.Linear(256, 256, device=device)),
144 | nn.Tanh(),
145 | layer_init(nn.Linear(256, 256, device=device)),
146 | nn.Tanh(),
147 | layer_init(nn.Linear(256, 1, device=device)),
148 | )
149 | self.actor_mean = nn.Sequential(
150 | layer_init(nn.Linear(n_obs, 256, device=device)),
151 | nn.Tanh(),
152 | layer_init(nn.Linear(256, 256, device=device)),
153 | nn.Tanh(),
154 | layer_init(nn.Linear(256, 256, device=device)),
155 | nn.Tanh(),
156 | layer_init(nn.Linear(256, n_act, device=device), std=0.01*np.sqrt(2)),
157 | )
158 | self.actor_logstd = nn.Parameter(torch.zeros(1, n_act, device=device))
159 |
160 | def get_value(self, x):
161 | return self.critic(x)
162 |
163 | def get_action_and_value(self, obs, action=None):
164 | action_mean = self.actor_mean(obs)
165 | action_logstd = self.actor_logstd.expand_as(action_mean)
166 | action_std = torch.exp(action_logstd)
167 | probs = Normal(action_mean, action_std)
168 | if action is None:
169 | action = action_mean + action_std * torch.randn_like(action_mean)
170 | return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(obs)
171 |
172 | class Logger:
173 | def __init__(self, log_wandb=False, tensorboard: SummaryWriter = None) -> None:
174 | self.writer = tensorboard
175 | self.log_wandb = log_wandb
176 | def add_scalar(self, tag, scalar_value, step):
177 | if self.log_wandb:
178 | wandb.log({tag: scalar_value}, step=step)
179 | self.writer.add_scalar(tag, scalar_value, step)
180 | def close(self):
181 | self.writer.close()
182 |
183 | def gae(next_obs, next_done, container, final_values):
184 | # bootstrap value if not done
185 | next_value = get_value(next_obs).reshape(-1)
186 | lastgaelam = 0
187 | nextnonterminals = (~container["dones"]).float().unbind(0)
188 | vals = container["vals"]
189 | vals_unbind = vals.unbind(0)
190 | rewards = container["rewards"].unbind(0)
191 |
192 | advantages = []
193 | nextnonterminal = (~next_done).float()
194 | nextvalues = next_value
195 | for t in range(args.num_steps - 1, -1, -1):
196 | cur_val = vals_unbind[t]
197 | # real_next_values = nextvalues * nextnonterminal
198 | real_next_values = nextnonterminal * nextvalues + final_values[t] # t instead of t+1
199 | delta = rewards[t] + args.gamma * real_next_values - cur_val
200 | advantages.append(delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam)
201 | lastgaelam = advantages[-1]
202 |
203 | nextnonterminal = nextnonterminals[t]
204 | nextvalues = cur_val
205 |
206 | advantages = container["advantages"] = torch.stack(list(reversed(advantages)))
207 | container["returns"] = advantages + vals
208 | return container
209 |
210 |
211 | def rollout(obs, done):
212 | ts = []
213 | final_values = torch.zeros((args.num_steps, args.num_envs), device=device)
214 | for step in range(args.num_steps):
215 | # ALGO LOGIC: action logic
216 | action, logprob, _, value = policy(obs=obs)
217 |
218 | # TRY NOT TO MODIFY: execute the game and log data.
219 | next_obs, reward, next_done, infos = step_func(action)
220 |
221 | if "final_info" in infos:
222 | final_info = infos["final_info"]
223 | done_mask = infos["_final_info"]
224 | for k, v in final_info["episode"].items():
225 | logger.add_scalar(f"train/{k}", v[done_mask].float().mean(), global_step)
226 | with torch.no_grad():
227 | final_values[step, torch.arange(args.num_envs, device=device)[done_mask]] = agent.get_value(infos["final_observation"][done_mask]).view(-1)
228 |
229 | ts.append(
230 | tensordict.TensorDict._new_unsafe(
231 | obs=obs,
232 | # cleanrl ppo examples associate the done with the previous obs (not the done resulting from action)
233 | dones=done,
234 | vals=value.flatten(),
235 | actions=action,
236 | logprobs=logprob,
237 | rewards=reward,
238 | batch_size=(args.num_envs,),
239 | )
240 | )
241 | # NOTE (stao): change here for gpu env
242 | obs = next_obs
243 | done = next_done
244 | # NOTE (stao): need to do .to(device) i think? otherwise container.device is None, not sure if this affects anything
245 | container = torch.stack(ts, 0).to(device)
246 | return next_obs, done, container, final_values
247 |
248 |
249 | def update(obs, actions, logprobs, advantages, returns, vals):
250 | optimizer.zero_grad()
251 | _, newlogprob, entropy, newvalue = agent.get_action_and_value(obs, actions)
252 | logratio = newlogprob - logprobs
253 | ratio = logratio.exp()
254 |
255 | with torch.no_grad():
256 | # calculate approx_kl http://joschu.net/blog/kl-approx.html
257 | old_approx_kl = (-logratio).mean()
258 | approx_kl = ((ratio - 1) - logratio).mean()
259 | clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean()
260 |
261 | if args.norm_adv:
262 | advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
263 |
264 | # Policy loss
265 | pg_loss1 = -advantages * ratio
266 | pg_loss2 = -advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
267 | pg_loss = torch.max(pg_loss1, pg_loss2).mean()
268 |
269 | # Value loss
270 | newvalue = newvalue.view(-1)
271 | if args.clip_vloss:
272 | v_loss_unclipped = (newvalue - returns) ** 2
273 | v_clipped = vals + torch.clamp(
274 | newvalue - vals,
275 | -args.clip_coef,
276 | args.clip_coef,
277 | )
278 | v_loss_clipped = (v_clipped - returns) ** 2
279 | v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
280 | v_loss = 0.5 * v_loss_max.mean()
281 | else:
282 | v_loss = 0.5 * ((newvalue - returns) ** 2).mean()
283 |
284 | entropy_loss = entropy.mean()
285 | loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
286 |
287 | loss.backward()
288 | gn = nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
289 | optimizer.step()
290 |
291 | return approx_kl, v_loss.detach(), pg_loss.detach(), entropy_loss.detach(), old_approx_kl, clipfrac, gn
292 |
293 |
294 | update = tensordict.nn.TensorDictModule(
295 | update,
296 | in_keys=["obs", "actions", "logprobs", "advantages", "returns", "vals"],
297 | out_keys=["approx_kl", "v_loss", "pg_loss", "entropy_loss", "old_approx_kl", "clipfrac", "gn"],
298 | )
299 |
300 | if __name__ == "__main__":
301 | args = tyro.cli(Args)
302 |
303 | batch_size = int(args.num_envs * args.num_steps)
304 | args.minibatch_size = batch_size // args.num_minibatches
305 | args.batch_size = args.num_minibatches * args.minibatch_size
306 | args.num_iterations = args.total_timesteps // args.batch_size
307 | if args.exp_name is None:
308 | args.exp_name = os.path.basename(__file__)[: -len(".py")]
309 | run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
310 | else:
311 | run_name = args.exp_name
312 |
313 | # TRY NOT TO MODIFY: seeding
314 | random.seed(args.seed)
315 | np.random.seed(args.seed)
316 | torch.manual_seed(args.seed)
317 | torch.backends.cudnn.deterministic = args.torch_deterministic
318 |
319 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
320 |
321 | ####### Environment setup #######
322 | env_kwargs = dict(obs_mode="state", control_mode="pd_joint_delta_pos", render_mode="rgb_array", sim_backend="gpu")
323 | envs = gym.make(args.env_id, num_envs=args.num_envs if not args.evaluate else 1, reconfiguration_freq=args.reconfiguration_freq, **env_kwargs)
324 | eval_envs = gym.make(args.env_id, num_envs=args.num_eval_envs, reconfiguration_freq=args.eval_reconfiguration_freq, human_render_camera_configs=dict(shader_pack="default"),**env_kwargs)
325 | if isinstance(envs.action_space, gym.spaces.Dict):
326 | envs = FlattenActionSpaceWrapper(envs)
327 | eval_envs = FlattenActionSpaceWrapper(eval_envs)
328 | if args.capture_video:
329 | eval_output_dir = f"runs/{run_name}/videos"
330 | if args.evaluate:
331 | eval_output_dir = f"{os.path.dirname(args.checkpoint)}/test_videos"
332 | print(f"Saving eval videos to {eval_output_dir}")
333 | if args.save_train_video_freq is not None:
334 | save_video_trigger = lambda x : (x // args.num_steps) % args.save_train_video_freq == 0
335 | envs = RecordEpisode(envs, output_dir=f"runs/{run_name}/train_videos", save_trajectory=False, save_video_trigger=save_video_trigger, max_steps_per_video=args.num_steps, video_fps=30)
336 | eval_envs = RecordEpisode(eval_envs, output_dir=eval_output_dir, save_trajectory=args.evaluate, trajectory_name="trajectory", max_steps_per_video=args.num_eval_steps, video_fps=30)
337 | envs = ManiSkillVectorEnv(envs, args.num_envs, ignore_terminations=not args.partial_reset, record_metrics=True)
338 | eval_envs = ManiSkillVectorEnv(eval_envs, args.num_eval_envs, ignore_terminations=True, record_metrics=True)
339 | assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"
340 |
341 | max_episode_steps = gym_utils.find_max_episode_steps_value(envs._env)
342 | logger = None
343 | if not args.evaluate:
344 | print("Running training")
345 | if args.track:
346 | import wandb
347 | config = vars(args)
348 | config["env_cfg"] = dict(**env_kwargs, num_envs=args.num_envs, env_id=args.env_id, reward_mode="normalized_dense", env_horizon=max_episode_steps, partial_reset=args.partial_reset)
349 | config["eval_env_cfg"] = dict(**env_kwargs, num_envs=args.num_eval_envs, env_id=args.env_id, reward_mode="normalized_dense", env_horizon=max_episode_steps, partial_reset=False)
350 | wandb.init(
351 | project=args.wandb_project_name,
352 | entity=args.wandb_entity,
353 | sync_tensorboard=False,
354 | config=config,
355 | name=run_name,
356 | save_code=True,
357 | group=args.wandb_group,
358 | tags=["ppo", "walltime_efficient", f"GPU:{torch.cuda.get_device_name()}"]
359 | )
360 | writer = SummaryWriter(f"runs/{run_name}")
361 | writer.add_text(
362 | "hyperparameters",
363 | "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
364 | )
365 | logger = Logger(log_wandb=args.track, tensorboard=writer)
366 | else:
367 | print("Running evaluation")
368 | n_act = math.prod(envs.single_action_space.shape)
369 | n_obs = math.prod(envs.single_observation_space.shape)
370 | assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"
371 |
372 | # Register step as a special op not to graph break
373 | # @torch.library.custom_op("mylib::step", mutates_args=())
374 | def step_func(action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]:
375 | # NOTE (stao): change here for gpu env
376 | next_obs, reward, terminations, truncations, info = envs.step(action)
377 | next_done = torch.logical_or(terminations, truncations)
378 | return next_obs, reward, next_done, info
379 |
380 | ####### Agent #######
381 | agent = Agent(n_obs, n_act, device=device)
382 | if args.checkpoint:
383 | agent.load_state_dict(torch.load(args.checkpoint))
384 | # Make a version of agent with detached params
385 | agent_inference = Agent(n_obs, n_act, device=device)
386 | agent_inference_p = from_module(agent).data
387 | agent_inference_p.to_module(agent_inference)
388 |
389 | ####### Optimizer #######
390 | optimizer = optim.Adam(
391 | agent.parameters(),
392 | lr=torch.tensor(args.learning_rate, device=device),
393 | eps=1e-5,
394 | capturable=args.cudagraphs and not args.compile,
395 | )
396 |
397 | ####### Executables #######
398 | # Define networks: wrapping the policy in a TensorDictModule allows us to use CudaGraphModule
399 | policy = agent_inference.get_action_and_value
400 | get_value = agent_inference.get_value
401 |
402 | # Compile policy
403 | if args.compile:
404 | policy = torch.compile(policy)
405 | gae = torch.compile(gae, fullgraph=True)
406 | update = torch.compile(update)
407 |
408 | if args.cudagraphs:
409 | policy = CudaGraphModule(policy)
410 | gae = CudaGraphModule(gae)
411 | update = CudaGraphModule(update)
412 |
413 | global_step = 0
414 | start_time = time.time()
415 | container_local = None
416 | next_obs = envs.reset()[0]
417 | next_done = torch.zeros(args.num_envs, device=device, dtype=torch.bool)
418 | pbar = tqdm.tqdm(range(1, args.num_iterations + 1))
419 |
420 | cumulative_times = defaultdict(float)
421 |
422 | for iteration in pbar:
423 | agent.eval()
424 | if iteration % args.eval_freq == 1:
425 | stime = time.perf_counter()
426 | eval_obs, _ = eval_envs.reset()
427 | eval_metrics = defaultdict(list)
428 | num_episodes = 0
429 | for _ in range(args.num_eval_steps):
430 | with torch.no_grad():
431 | eval_obs, eval_rew, eval_terminations, eval_truncations, eval_infos = eval_envs.step(agent.actor_mean(eval_obs))
432 | if "final_info" in eval_infos:
433 | mask = eval_infos["_final_info"]
434 | num_episodes += mask.sum()
435 | for k, v in eval_infos["final_info"]["episode"].items():
436 | eval_metrics[k].append(v)
437 | eval_metrics_mean = {}
438 | for k, v in eval_metrics.items():
439 | mean = torch.stack(v).float().mean()
440 | eval_metrics_mean[k] = mean
441 | if logger is not None:
442 | logger.add_scalar(f"eval/{k}", mean, global_step)
443 | pbar.set_description(
444 | f"success_once: {eval_metrics_mean['success_once']:.2f}, "
445 | f"return: {eval_metrics_mean['return']:.2f}"
446 | )
447 | if logger is not None:
448 | eval_time = time.perf_counter() - stime
449 | cumulative_times["eval_time"] += eval_time
450 | logger.add_scalar("time/eval_time", eval_time, global_step)
451 | if args.evaluate:
452 | break
453 | if args.save_model and iteration % args.eval_freq == 1:
454 | model_path = f"runs/{run_name}/ckpt_{iteration}.pt"
455 | torch.save(agent.state_dict(), model_path)
456 | print(f"model saved to {model_path}")
457 | # Annealing the rate if instructed to do so.
458 | if args.anneal_lr:
459 | frac = 1.0 - (iteration - 1.0) / args.num_iterations
460 | lrnow = frac * args.learning_rate
461 | optimizer.param_groups[0]["lr"].copy_(lrnow)
462 |
463 | torch.compiler.cudagraph_mark_step_begin()
464 | rollout_time = time.perf_counter()
465 | next_obs, next_done, container, final_values = rollout(next_obs, next_done)
466 | rollout_time = time.perf_counter() - rollout_time
467 | cumulative_times["rollout_time"] += rollout_time
468 | global_step += container.numel()
469 |
470 | update_time = time.perf_counter()
471 | container = gae(next_obs, next_done, container, final_values)
472 | container_flat = container.view(-1)
473 |
474 | # Optimizing the policy and value network
475 | clipfracs = []
476 | for epoch in range(args.update_epochs):
477 | b_inds = torch.randperm(container_flat.shape[0], device=device).split(args.minibatch_size)
478 | for b in b_inds:
479 | container_local = container_flat[b]
480 |
481 | out = update(container_local, tensordict_out=tensordict.TensorDict())
482 | clipfracs.append(out["clipfrac"])
483 | if args.target_kl is not None and out["approx_kl"] > args.target_kl:
484 | break
485 | else:
486 | continue
487 | break
488 | update_time = time.perf_counter() - update_time
489 | cumulative_times["update_time"] += update_time
490 |
491 | logger.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
492 | logger.add_scalar("losses/value_loss", out["v_loss"].item(), global_step)
493 | logger.add_scalar("losses/policy_loss", out["pg_loss"].item(), global_step)
494 | logger.add_scalar("losses/entropy", out["entropy_loss"].item(), global_step)
495 | logger.add_scalar("losses/old_approx_kl", out["old_approx_kl"].item(), global_step)
496 | logger.add_scalar("losses/approx_kl", out["approx_kl"].item(), global_step)
497 | logger.add_scalar("losses/clipfrac", torch.stack(clipfracs).mean().cpu().item(), global_step)
498 | # logger.add_scalar("losses/explained_variance", explained_var, global_step)
499 | logger.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
500 | logger.add_scalar("time/step", global_step, global_step)
501 | logger.add_scalar("time/update_time", update_time, global_step)
502 | logger.add_scalar("time/rollout_time", rollout_time, global_step)
503 | logger.add_scalar("time/rollout_fps", args.num_envs * args.num_steps / rollout_time, global_step)
504 | for k, v in cumulative_times.items():
505 | logger.add_scalar(f"time/total_{k}", v, global_step)
506 | logger.add_scalar("time/total_rollout+update_time", cumulative_times["rollout_time"] + cumulative_times["update_time"], global_step)
507 |
508 | envs.close()
509 | eval_envs.close()
510 |
--------------------------------------------------------------------------------
/records/10312024_cudagraphs/script.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | seeds=(9351 4796 1788 2371 3462 4553 5644 6735 7826 8917)
3 | num_envs=4096
4 | update_epochs=12
5 | num_steps=4
6 | num_minibatches=32
7 | exp_name="10312024_cudagraphs"
8 | for seed in ${seeds[@]}
9 | do
10 | python ppo.py --env_id="PickCube-v1" --seed=${seed} --total_timesteps=4_000_000 --eval-freq=10 \
11 | --num_envs=${num_envs} --update_epochs=${update_epochs} --num_minibatches=${num_minibatches} \
12 | --num-steps=${num_steps} --cudagraphs \
13 | --exp-name="ppo-PickCube-v1-state-${seed}-${exp_name}" \
14 | --track --wandb_project_name="PPO-ManiSkill-GPU-SpeedRun" --wandb_group=${exp_name} --no-capture_video
15 | done
--------------------------------------------------------------------------------
/records/baseline/ppo.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from mani_skill.utils import gym_utils
4 | from mani_skill.utils.wrappers.flatten import FlattenActionSpaceWrapper
5 | from mani_skill.utils.wrappers.record import RecordEpisode
6 | from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
7 |
8 | os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "1"
9 |
10 | import math
11 | import os
12 | import random
13 | import time
14 | from collections import defaultdict, deque
15 | from dataclasses import dataclass
16 | from typing import Optional, Tuple
17 |
18 | import gymnasium as gym
19 | import numpy as np
20 | import tensordict
21 | import torch
22 | import torch.nn as nn
23 | import torch.optim as optim
24 | import tqdm
25 | import tyro
26 | from torch.utils.tensorboard import SummaryWriter
27 | import wandb
28 | from tensordict import from_module
29 | from tensordict.nn import CudaGraphModule
30 | from torch.distributions.normal import Normal
31 |
32 |
33 | @dataclass
34 | class Args:
35 | exp_name: Optional[str] = None
36 | """the name of this experiment"""
37 | seed: int = 1
38 | """seed of the experiment"""
39 | torch_deterministic: bool = True
40 | """if toggled, `torch.backends.cudnn.deterministic=False`"""
41 | cuda: bool = True
42 | """if toggled, cuda will be enabled by default"""
43 | track: bool = False
44 | """if toggled, this experiment will be tracked with Weights and Biases"""
45 | wandb_project_name: str = "ManiSkill"
46 | """the wandb's project name"""
47 | wandb_entity: Optional[str] = None
48 | """the entity (team) of wandb's project"""
49 | wandb_group: str = "PPO"
50 | """the group of the run for wandb"""
51 | capture_video: bool = True
52 | """whether to capture videos of the agent performances (check out `videos` folder)"""
53 | save_model: bool = False
54 | """whether to save model into the `runs/{run_name}` folder"""
55 | evaluate: bool = False
56 | """if toggled, only runs evaluation with the given model checkpoint and saves the evaluation trajectories"""
57 | checkpoint: Optional[str] = None
58 | """path to a pretrained checkpoint file to start evaluation/training from"""
59 |
60 | # Environment specific arguments
61 | env_id: str = "PickCube-v1"
62 | """the id of the environment"""
63 | env_vectorization: str = "gpu"
64 | """the type of environment vectorization to use"""
65 | num_envs: int = 512
66 | """the number of parallel environments"""
67 | num_eval_envs: int = 16
68 | """the number of parallel evaluation environments"""
69 | partial_reset: bool = True
70 | """whether to let parallel environments reset upon termination instead of truncation"""
71 | num_steps: int = 50
72 | """the number of steps to run in each environment per policy rollout"""
73 | num_eval_steps: int = 50
74 | """the number of steps to run in each evaluation environment during evaluation"""
75 | reconfiguration_freq: Optional[int] = None
76 | """how often to reconfigure the environment during training"""
77 | eval_reconfiguration_freq: Optional[int] = 1
78 | """for benchmarking purposes we want to reconfigure the eval environment each reset to ensure objects are randomized in some tasks"""
79 | eval_freq: int = 25
80 | """evaluation frequency in terms of iterations"""
81 | save_train_video_freq: Optional[int] = None
82 | """frequency to save training videos in terms of iterations"""
83 |
84 | # Algorithm specific arguments
85 | total_timesteps: int = 10000000
86 | """total timesteps of the experiments"""
87 | learning_rate: float = 3e-4
88 | """the learning rate of the optimizer"""
89 | anneal_lr: bool = False
90 | """Toggle learning rate annealing for policy and value networks"""
91 | gamma: float = 0.8
92 | """the discount factor gamma"""
93 | gae_lambda: float = 0.9
94 | """the lambda for the general advantage estimation"""
95 | num_minibatches: int = 32
96 | """the number of mini-batches"""
97 | update_epochs: int = 4
98 | """the K epochs to update the policy"""
99 | norm_adv: bool = True
100 | """Toggles advantages normalization"""
101 | clip_coef: float = 0.2
102 | """the surrogate clipping coefficient"""
103 | clip_vloss: bool = False
104 | """Toggles whether or not to use a clipped loss for the value function, as per the paper."""
105 | ent_coef: float = 0.0
106 | """coefficient of the entropy"""
107 | vf_coef: float = 0.5
108 | """coefficient of the value function"""
109 | max_grad_norm: float = 0.5
110 | """the maximum norm for the gradient clipping"""
111 | target_kl: float = 0.1
112 | """the target KL divergence threshold"""
113 | reward_scale: float = 1.0
114 | """Scale the reward by this factor"""
115 | finite_horizon_gae: bool = False
116 |
117 | # to be filled in runtime
118 | batch_size: int = 0
119 | """the batch size (computed in runtime)"""
120 | minibatch_size: int = 0
121 | """the mini-batch size (computed in runtime)"""
122 | num_iterations: int = 0
123 | """the number of iterations (computed in runtime)"""
124 |
125 | # Torch optimizations
126 | compile: bool = False
127 | """whether to use torch.compile."""
128 | cudagraphs: bool = False
129 | """whether to use cudagraphs on top of compile."""
130 |
131 | def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
132 | torch.nn.init.orthogonal_(layer.weight, std)
133 | torch.nn.init.constant_(layer.bias, bias_const)
134 | return layer
135 |
136 |
137 | class Agent(nn.Module):
138 | def __init__(self, n_obs, n_act, device=None):
139 | super().__init__()
140 | self.critic = nn.Sequential(
141 | layer_init(nn.Linear(n_obs, 256, device=device)),
142 | nn.Tanh(),
143 | layer_init(nn.Linear(256, 256, device=device)),
144 | nn.Tanh(),
145 | layer_init(nn.Linear(256, 256, device=device)),
146 | nn.Tanh(),
147 | layer_init(nn.Linear(256, 1, device=device)),
148 | )
149 | self.actor_mean = nn.Sequential(
150 | layer_init(nn.Linear(n_obs, 256, device=device)),
151 | nn.Tanh(),
152 | layer_init(nn.Linear(256, 256, device=device)),
153 | nn.Tanh(),
154 | layer_init(nn.Linear(256, 256, device=device)),
155 | nn.Tanh(),
156 | layer_init(nn.Linear(256, n_act, device=device), std=0.01*np.sqrt(2)),
157 | )
158 | self.actor_logstd = nn.Parameter(torch.zeros(1, n_act, device=device))
159 |
160 | def get_value(self, x):
161 | return self.critic(x)
162 |
163 | def get_action_and_value(self, obs, action=None):
164 | action_mean = self.actor_mean(obs)
165 | action_logstd = self.actor_logstd.expand_as(action_mean)
166 | action_std = torch.exp(action_logstd)
167 | probs = Normal(action_mean, action_std)
168 | if action is None:
169 | action = action_mean + action_std * torch.randn_like(action_mean)
170 | return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(obs)
171 |
172 | class Logger:
173 | def __init__(self, log_wandb=False, tensorboard: SummaryWriter = None) -> None:
174 | self.writer = tensorboard
175 | self.log_wandb = log_wandb
176 | def add_scalar(self, tag, scalar_value, step):
177 | if self.log_wandb:
178 | wandb.log({tag: scalar_value}, step=step)
179 | self.writer.add_scalar(tag, scalar_value, step)
180 | def close(self):
181 | self.writer.close()
182 |
183 | def gae(next_obs, next_done, container, final_values):
184 | # bootstrap value if not done
185 | next_value = get_value(next_obs).reshape(-1)
186 | lastgaelam = 0
187 | nextnonterminals = (~container["dones"]).float().unbind(0)
188 | vals = container["vals"]
189 | vals_unbind = vals.unbind(0)
190 | rewards = container["rewards"].unbind(0)
191 |
192 | advantages = []
193 | nextnonterminal = (~next_done).float()
194 | nextvalues = next_value
195 | for t in range(args.num_steps - 1, -1, -1):
196 | cur_val = vals_unbind[t]
197 | # real_next_values = nextvalues * nextnonterminal
198 | real_next_values = nextnonterminal * nextvalues + final_values[t] # t instead of t+1
199 | delta = rewards[t] + args.gamma * real_next_values - cur_val
200 | advantages.append(delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam)
201 | lastgaelam = advantages[-1]
202 |
203 | nextnonterminal = nextnonterminals[t]
204 | nextvalues = cur_val
205 |
206 | advantages = container["advantages"] = torch.stack(list(reversed(advantages)))
207 | container["returns"] = advantages + vals
208 | return container
209 |
210 |
211 | def rollout(obs, done):
212 | ts = []
213 | final_values = torch.zeros((args.num_steps, args.num_envs), device=device)
214 | for step in range(args.num_steps):
215 | # ALGO LOGIC: action logic
216 | action, logprob, _, value = policy(obs=obs)
217 |
218 | # TRY NOT TO MODIFY: execute the game and log data.
219 | next_obs, reward, next_done, infos = step_func(action)
220 |
221 | if "final_info" in infos:
222 | final_info = infos["final_info"]
223 | done_mask = infos["_final_info"]
224 | for k, v in final_info["episode"].items():
225 | logger.add_scalar(f"train/{k}", v[done_mask].float().mean(), global_step)
226 | with torch.no_grad():
227 | final_values[step, torch.arange(args.num_envs, device=device)[done_mask]] = agent.get_value(infos["final_observation"][done_mask]).view(-1)
228 |
229 | ts.append(
230 | tensordict.TensorDict._new_unsafe(
231 | obs=obs,
232 | # cleanrl ppo examples associate the done with the previous obs (not the done resulting from action)
233 | dones=done,
234 | vals=value.flatten(),
235 | actions=action,
236 | logprobs=logprob,
237 | rewards=reward,
238 | batch_size=(args.num_envs,),
239 | )
240 | )
241 | # NOTE (stao): change here for gpu env
242 | obs = next_obs
243 | done = next_done
244 | # NOTE (stao): need to do .to(device) i think? otherwise container.device is None, not sure if this affects anything
245 | container = torch.stack(ts, 0).to(device)
246 | return next_obs, done, container, final_values
247 |
248 |
249 | def update(obs, actions, logprobs, advantages, returns, vals):
250 | optimizer.zero_grad()
251 | _, newlogprob, entropy, newvalue = agent.get_action_and_value(obs, actions)
252 | logratio = newlogprob - logprobs
253 | ratio = logratio.exp()
254 |
255 | with torch.no_grad():
256 | # calculate approx_kl http://joschu.net/blog/kl-approx.html
257 | old_approx_kl = (-logratio).mean()
258 | approx_kl = ((ratio - 1) - logratio).mean()
259 | clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean()
260 |
261 | if args.norm_adv:
262 | advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
263 |
264 | # Policy loss
265 | pg_loss1 = -advantages * ratio
266 | pg_loss2 = -advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
267 | pg_loss = torch.max(pg_loss1, pg_loss2).mean()
268 |
269 | # Value loss
270 | newvalue = newvalue.view(-1)
271 | if args.clip_vloss:
272 | v_loss_unclipped = (newvalue - returns) ** 2
273 | v_clipped = vals + torch.clamp(
274 | newvalue - vals,
275 | -args.clip_coef,
276 | args.clip_coef,
277 | )
278 | v_loss_clipped = (v_clipped - returns) ** 2
279 | v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
280 | v_loss = 0.5 * v_loss_max.mean()
281 | else:
282 | v_loss = 0.5 * ((newvalue - returns) ** 2).mean()
283 |
284 | entropy_loss = entropy.mean()
285 | loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
286 |
287 | loss.backward()
288 | gn = nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
289 | optimizer.step()
290 |
291 | return approx_kl, v_loss.detach(), pg_loss.detach(), entropy_loss.detach(), old_approx_kl, clipfrac, gn
292 |
293 |
294 | update = tensordict.nn.TensorDictModule(
295 | update,
296 | in_keys=["obs", "actions", "logprobs", "advantages", "returns", "vals"],
297 | out_keys=["approx_kl", "v_loss", "pg_loss", "entropy_loss", "old_approx_kl", "clipfrac", "gn"],
298 | )
299 |
300 | if __name__ == "__main__":
301 | args = tyro.cli(Args)
302 |
303 | batch_size = int(args.num_envs * args.num_steps)
304 | args.minibatch_size = batch_size // args.num_minibatches
305 | args.batch_size = args.num_minibatches * args.minibatch_size
306 | args.num_iterations = args.total_timesteps // args.batch_size
307 | if args.exp_name is None:
308 | args.exp_name = os.path.basename(__file__)[: -len(".py")]
309 | run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
310 | else:
311 | run_name = args.exp_name
312 |
313 | # TRY NOT TO MODIFY: seeding
314 | random.seed(args.seed)
315 | np.random.seed(args.seed)
316 | torch.manual_seed(args.seed)
317 | torch.backends.cudnn.deterministic = args.torch_deterministic
318 |
319 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
320 |
321 | ####### Environment setup #######
322 | env_kwargs = dict(obs_mode="state", control_mode="pd_joint_delta_pos", render_mode="rgb_array", sim_backend="gpu")
323 | envs = gym.make(args.env_id, num_envs=args.num_envs if not args.evaluate else 1, reconfiguration_freq=args.reconfiguration_freq, **env_kwargs)
324 | eval_envs = gym.make(args.env_id, num_envs=args.num_eval_envs, reconfiguration_freq=args.eval_reconfiguration_freq, human_render_camera_configs=dict(shader_pack="default"),**env_kwargs)
325 | if isinstance(envs.action_space, gym.spaces.Dict):
326 | envs = FlattenActionSpaceWrapper(envs)
327 | eval_envs = FlattenActionSpaceWrapper(eval_envs)
328 | if args.capture_video:
329 | eval_output_dir = f"runs/{run_name}/videos"
330 | if args.evaluate:
331 | eval_output_dir = f"{os.path.dirname(args.checkpoint)}/test_videos"
332 | print(f"Saving eval videos to {eval_output_dir}")
333 | if args.save_train_video_freq is not None:
334 | save_video_trigger = lambda x : (x // args.num_steps) % args.save_train_video_freq == 0
335 | envs = RecordEpisode(envs, output_dir=f"runs/{run_name}/train_videos", save_trajectory=False, save_video_trigger=save_video_trigger, max_steps_per_video=args.num_steps, video_fps=30)
336 | eval_envs = RecordEpisode(eval_envs, output_dir=eval_output_dir, save_trajectory=args.evaluate, trajectory_name="trajectory", max_steps_per_video=args.num_eval_steps, video_fps=30)
337 | envs = ManiSkillVectorEnv(envs, args.num_envs, ignore_terminations=not args.partial_reset, record_metrics=True)
338 | eval_envs = ManiSkillVectorEnv(eval_envs, args.num_eval_envs, ignore_terminations=True, record_metrics=True)
339 | assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"
340 |
341 | max_episode_steps = gym_utils.find_max_episode_steps_value(envs._env)
342 | logger = None
343 | if not args.evaluate:
344 | print("Running training")
345 | if args.track:
346 | import wandb
347 | config = vars(args)
348 | config["env_cfg"] = dict(**env_kwargs, num_envs=args.num_envs, env_id=args.env_id, reward_mode="normalized_dense", env_horizon=max_episode_steps, partial_reset=args.partial_reset)
349 | config["eval_env_cfg"] = dict(**env_kwargs, num_envs=args.num_eval_envs, env_id=args.env_id, reward_mode="normalized_dense", env_horizon=max_episode_steps, partial_reset=False)
350 | wandb.init(
351 | project=args.wandb_project_name,
352 | entity=args.wandb_entity,
353 | sync_tensorboard=False,
354 | config=config,
355 | name=run_name,
356 | save_code=True,
357 | group=args.wandb_group,
358 | tags=["ppo", "walltime_efficient", f"GPU:{torch.cuda.get_device_name()}"]
359 | )
360 | writer = SummaryWriter(f"runs/{run_name}")
361 | writer.add_text(
362 | "hyperparameters",
363 | "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
364 | )
365 | logger = Logger(log_wandb=args.track, tensorboard=writer)
366 | else:
367 | print("Running evaluation")
368 | n_act = math.prod(envs.single_action_space.shape)
369 | n_obs = math.prod(envs.single_observation_space.shape)
370 | assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"
371 |
372 | # Register step as a special op not to graph break
373 | # @torch.library.custom_op("mylib::step", mutates_args=())
374 | def step_func(action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]:
375 | # NOTE (stao): change here for gpu env
376 | next_obs, reward, terminations, truncations, info = envs.step(action)
377 | next_done = torch.logical_or(terminations, truncations)
378 | return next_obs, reward, next_done, info
379 |
380 | ####### Agent #######
381 | agent = Agent(n_obs, n_act, device=device)
382 | if args.checkpoint:
383 | agent.load_state_dict(torch.load(args.checkpoint))
384 | # Make a version of agent with detached params
385 | agent_inference = Agent(n_obs, n_act, device=device)
386 | agent_inference_p = from_module(agent).data
387 | agent_inference_p.to_module(agent_inference)
388 |
389 | ####### Optimizer #######
390 | optimizer = optim.Adam(
391 | agent.parameters(),
392 | lr=torch.tensor(args.learning_rate, device=device),
393 | eps=1e-5,
394 | capturable=args.cudagraphs and not args.compile,
395 | )
396 |
397 | ####### Executables #######
398 | # Define networks: wrapping the policy in a TensorDictModule allows us to use CudaGraphModule
399 | policy = agent_inference.get_action_and_value
400 | get_value = agent_inference.get_value
401 |
402 | # Compile policy
403 | if args.compile:
404 | policy = torch.compile(policy)
405 | gae = torch.compile(gae, fullgraph=True)
406 | update = torch.compile(update)
407 |
408 | if args.cudagraphs:
409 | policy = CudaGraphModule(policy)
410 | gae = CudaGraphModule(gae)
411 | update = CudaGraphModule(update)
412 |
413 | global_step = 0
414 | start_time = time.time()
415 | container_local = None
416 | next_obs = envs.reset()[0]
417 | next_done = torch.zeros(args.num_envs, device=device, dtype=torch.bool)
418 | pbar = tqdm.tqdm(range(1, args.num_iterations + 1))
419 |
420 | cumulative_times = defaultdict(float)
421 |
422 | for iteration in pbar:
423 | agent.eval()
424 | if iteration % args.eval_freq == 1:
425 | stime = time.perf_counter()
426 | eval_obs, _ = eval_envs.reset()
427 | eval_metrics = defaultdict(list)
428 | num_episodes = 0
429 | for _ in range(args.num_eval_steps):
430 | with torch.no_grad():
431 | eval_obs, eval_rew, eval_terminations, eval_truncations, eval_infos = eval_envs.step(agent.actor_mean(eval_obs))
432 | if "final_info" in eval_infos:
433 | mask = eval_infos["_final_info"]
434 | num_episodes += mask.sum()
435 | for k, v in eval_infos["final_info"]["episode"].items():
436 | eval_metrics[k].append(v)
437 | eval_metrics_mean = {}
438 | for k, v in eval_metrics.items():
439 | mean = torch.stack(v).float().mean()
440 | eval_metrics_mean[k] = mean
441 | if logger is not None:
442 | logger.add_scalar(f"eval/{k}", mean, global_step)
443 | pbar.set_description(
444 | f"success_once: {eval_metrics_mean['success_once']:.2f}, "
445 | f"return: {eval_metrics_mean['return']:.2f}"
446 | )
447 | if logger is not None:
448 | eval_time = time.perf_counter() - stime
449 | cumulative_times["eval_time"] += eval_time
450 | logger.add_scalar("time/eval_time", eval_time, global_step)
451 | if args.evaluate:
452 | break
453 | if args.save_model and iteration % args.eval_freq == 1:
454 | model_path = f"runs/{run_name}/ckpt_{iteration}.pt"
455 | torch.save(agent.state_dict(), model_path)
456 | print(f"model saved to {model_path}")
457 | # Annealing the rate if instructed to do so.
458 | if args.anneal_lr:
459 | frac = 1.0 - (iteration - 1.0) / args.num_iterations
460 | lrnow = frac * args.learning_rate
461 | optimizer.param_groups[0]["lr"].copy_(lrnow)
462 |
463 | torch.compiler.cudagraph_mark_step_begin()
464 | rollout_time = time.perf_counter()
465 | next_obs, next_done, container, final_values = rollout(next_obs, next_done)
466 | rollout_time = time.perf_counter() - rollout_time
467 | cumulative_times["rollout_time"] += rollout_time
468 | global_step += container.numel()
469 |
470 | update_time = time.perf_counter()
471 | container = gae(next_obs, next_done, container, final_values)
472 | container_flat = container.view(-1)
473 |
474 | # Optimizing the policy and value network
475 | clipfracs = []
476 | for epoch in range(args.update_epochs):
477 | b_inds = torch.randperm(container_flat.shape[0], device=device).split(args.minibatch_size)
478 | for b in b_inds:
479 | container_local = container_flat[b]
480 |
481 | out = update(container_local, tensordict_out=tensordict.TensorDict())
482 | clipfracs.append(out["clipfrac"])
483 | if args.target_kl is not None and out["approx_kl"] > args.target_kl:
484 | break
485 | else:
486 | continue
487 | break
488 | update_time = time.perf_counter() - update_time
489 | cumulative_times["update_time"] += update_time
490 |
491 | logger.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
492 | logger.add_scalar("losses/value_loss", out["v_loss"].item(), global_step)
493 | logger.add_scalar("losses/policy_loss", out["pg_loss"].item(), global_step)
494 | logger.add_scalar("losses/entropy", out["entropy_loss"].item(), global_step)
495 | logger.add_scalar("losses/old_approx_kl", out["old_approx_kl"].item(), global_step)
496 | logger.add_scalar("losses/approx_kl", out["approx_kl"].item(), global_step)
497 | logger.add_scalar("losses/clipfrac", torch.stack(clipfracs).mean().cpu().item(), global_step)
498 | # logger.add_scalar("losses/explained_variance", explained_var, global_step)
499 | logger.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
500 | logger.add_scalar("time/step", global_step, global_step)
501 | logger.add_scalar("time/update_time", update_time, global_step)
502 | logger.add_scalar("time/rollout_time", rollout_time, global_step)
503 | logger.add_scalar("time/rollout_fps", args.num_envs * args.num_steps / rollout_time, global_step)
504 | for k, v in cumulative_times.items():
505 | logger.add_scalar(f"time/total_{k}", v, global_step)
506 | logger.add_scalar("time/total_rollout+update_time", cumulative_times["rollout_time"] + cumulative_times["update_time"], global_step)
507 |
508 | envs.close()
509 | eval_envs.close()
510 |
--------------------------------------------------------------------------------
/records/baseline/script.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | seeds=(9351 4796 1788 2371 3462 4553 5644 6735 7826 8917)
3 | num_envs=4096
4 | update_epochs=12
5 | num_steps=4
6 | num_minibatches=32
7 | exp_name="baseline"
8 | for seed in ${seeds[@]}
9 | do
10 | python ppo.py --env_id="PickCube-v1" --seed=${seed} --total_timesteps=4_000_000 --eval-freq=10 \
11 | --num_envs=${num_envs} --update_epochs=${update_epochs} --num_minibatches=${num_minibatches} \
12 | --num-steps=${num_steps} \
13 | --exp-name="ppo-PickCube-v1-state-${seed}-${exp_name}" \
14 | --track --wandb_project_name="PPO-ManiSkill-GPU-SpeedRun" --wandb_group=${exp_name} --no-capture_video
15 | done
--------------------------------------------------------------------------------
/rl_robotics_speedrun_colab.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "collapsed_sections": [
8 | "cj-pTw-Les0o"
9 | ],
10 | "machine_shape": "hm",
11 | "gpuType": "L4",
12 | "authorship_tag": "ABX9TyNsoFWzLLgUSaUtkeYRPI4F",
13 | "include_colab_link": true
14 | },
15 | "kernelspec": {
16 | "name": "python3",
17 | "display_name": "Python 3"
18 | },
19 | "language_info": {
20 | "name": "python"
21 | },
22 | "accelerator": "GPU"
23 | },
24 | "cells": [
25 | {
26 | "cell_type": "markdown",
27 | "metadata": {
28 | "id": "view-in-github",
29 | "colab_type": "text"
30 | },
31 | "source": [
32 | "
"
33 | ]
34 | },
35 | {
36 | "cell_type": "markdown",
37 | "metadata": {
38 | "id": "cj-pTw-Les0o"
39 | },
40 | "source": [
41 | "# Setup Code\n",
42 | "\n",
43 | "To begin, prepare the colab environment by switching to a GPU environment (Runtime -> Change Runtime Type)\n",
44 | "\n",
45 | "Then click the play button below. This will install all dependencies for the future code, namely ManiSkill, Torch, TorchRL, and Tensordict"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": 1,
51 | "metadata": {
52 | "id": "pURa3mYwevk0",
53 | "outputId": "e4a70a8f-1d54-4b3a-848d-e987500b54c0",
54 | "colab": {
55 | "base_uri": "https://localhost:8080/"
56 | }
57 | },
58 | "outputs": [
59 | {
60 | "output_type": "stream",
61 | "name": "stdout",
62 | "text": [
63 | "Reading package lists... Done\n",
64 | "Building dependency tree... Done\n",
65 | "Reading state information... Done\n",
66 | "The following additional packages will be installed:\n",
67 | " libvulkan1\n",
68 | "Recommended packages:\n",
69 | " mesa-vulkan-drivers | vulkan-icd\n",
70 | "The following NEW packages will be installed:\n",
71 | " libvulkan-dev libvulkan1\n",
72 | "0 upgraded, 2 newly installed, 0 to remove and 49 not upgraded.\n",
73 | "Need to get 1,020 kB of archives.\n",
74 | "After this operation, 17.2 MB of additional disk space will be used.\n",
75 | "Get:1 http://archive.ubuntu.com/ubuntu jammy/main amd64 libvulkan1 amd64 1.3.204.1-2 [128 kB]\n",
76 | "Get:2 http://archive.ubuntu.com/ubuntu jammy/main amd64 libvulkan-dev amd64 1.3.204.1-2 [892 kB]\n",
77 | "Fetched 1,020 kB in 1s (706 kB/s)\n",
78 | "Selecting previously unselected package libvulkan1:amd64.\n",
79 | "(Reading database ... 123623 files and directories currently installed.)\n",
80 | "Preparing to unpack .../libvulkan1_1.3.204.1-2_amd64.deb ...\n",
81 | "Unpacking libvulkan1:amd64 (1.3.204.1-2) ...\n",
82 | "Selecting previously unselected package libvulkan-dev:amd64.\n",
83 | "Preparing to unpack .../libvulkan-dev_1.3.204.1-2_amd64.deb ...\n",
84 | "Unpacking libvulkan-dev:amd64 (1.3.204.1-2) ...\n",
85 | "Setting up libvulkan1:amd64 (1.3.204.1-2) ...\n",
86 | "Setting up libvulkan-dev:amd64 (1.3.204.1-2) ...\n",
87 | "Processing triggers for libc-bin (2.35-0ubuntu3.4) ...\n",
88 | "/sbin/ldconfig.real: /usr/local/lib/libtcm.so.1 is not a symbolic link\n",
89 | "\n",
90 | "/sbin/ldconfig.real: /usr/local/lib/libur_adapter_opencl.so.0 is not a symbolic link\n",
91 | "\n",
92 | "/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc.so.2 is not a symbolic link\n",
93 | "\n",
94 | "/sbin/ldconfig.real: /usr/local/lib/libtbb.so.12 is not a symbolic link\n",
95 | "\n",
96 | "/sbin/ldconfig.real: /usr/local/lib/libhwloc.so.15 is not a symbolic link\n",
97 | "\n",
98 | "/sbin/ldconfig.real: /usr/local/lib/libur_loader.so.0 is not a symbolic link\n",
99 | "\n",
100 | "/sbin/ldconfig.real: /usr/local/lib/libur_adapter_level_zero.so.0 is not a symbolic link\n",
101 | "\n",
102 | "/sbin/ldconfig.real: /usr/local/lib/libtbbbind.so.3 is not a symbolic link\n",
103 | "\n",
104 | "/sbin/ldconfig.real: /usr/local/lib/libumf.so.0 is not a symbolic link\n",
105 | "\n",
106 | "/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_5.so.3 is not a symbolic link\n",
107 | "\n",
108 | "/sbin/ldconfig.real: /usr/local/lib/libtcm_debug.so.1 is not a symbolic link\n",
109 | "\n",
110 | "/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc_proxy.so.2 is not a symbolic link\n",
111 | "\n",
112 | "/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_0.so.3 is not a symbolic link\n",
113 | "\n",
114 | "Cloning into 'rl-robotics-speedrun'...\n",
115 | "remote: Enumerating objects: 37, done.\u001b[K\n",
116 | "remote: Counting objects: 100% (37/37), done.\u001b[K\n",
117 | "remote: Compressing objects: 100% (25/25), done.\u001b[K\n",
118 | "remote: Total 37 (delta 9), reused 33 (delta 9), pack-reused 0 (from 0)\u001b[K\n",
119 | "Receiving objects: 100% (37/37), 15.73 KiB | 236.00 KiB/s, done.\n",
120 | "Resolving deltas: 100% (9/9), done.\n",
121 | "Obtaining file:///content/rl-robotics-speedrun\n",
122 | " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
123 | "Requirement already satisfied: wandb in /usr/local/lib/python3.10/dist-packages (from rl-robotics-speedrun==0.0.1) (0.18.5)\n",
124 | "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from rl-robotics-speedrun==0.0.1) (4.66.6)\n",
125 | "Collecting mani_skill (from rl-robotics-speedrun==0.0.1)\n",
126 | " Downloading mani_skill-3.0.0b12-py3-none-any.whl.metadata (3.2 kB)\n",
127 | "Requirement already satisfied: tensorboard in /usr/local/lib/python3.10/dist-packages (from rl-robotics-speedrun==0.0.1) (2.17.0)\n",
128 | "Collecting torchrl (from rl-robotics-speedrun==0.0.1)\n",
129 | " Downloading torchrl-0.6.0-cp310-cp310-manylinux1_x86_64.whl.metadata (39 kB)\n",
130 | "Collecting tensordict (from rl-robotics-speedrun==0.0.1)\n",
131 | " Downloading tensordict-0.6.0-cp310-cp310-manylinux1_x86_64.whl.metadata (8.9 kB)\n",
132 | "Requirement already satisfied: numpy<2.0.0,>=1.22 in /usr/local/lib/python3.10/dist-packages (from mani_skill->rl-robotics-speedrun==0.0.1) (1.26.4)\n",
133 | "Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from mani_skill->rl-robotics-speedrun==0.0.1) (1.13.1)\n",
134 | "Collecting dacite (from mani_skill->rl-robotics-speedrun==0.0.1)\n",
135 | " Downloading dacite-1.8.1-py3-none-any.whl.metadata (15 kB)\n",
136 | "Collecting gymnasium==0.29.1 (from mani_skill->rl-robotics-speedrun==0.0.1)\n",
137 | " Downloading gymnasium-0.29.1-py3-none-any.whl.metadata (10 kB)\n",
138 | "Collecting sapien==3.0.0.b1 (from mani_skill->rl-robotics-speedrun==0.0.1)\n",
139 | " Downloading sapien-3.0.0b1-cp310-cp310-manylinux2014_x86_64.whl.metadata (10 kB)\n",
140 | "Requirement already satisfied: h5py in /usr/local/lib/python3.10/dist-packages (from mani_skill->rl-robotics-speedrun==0.0.1) (3.12.1)\n",
141 | "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from mani_skill->rl-robotics-speedrun==0.0.1) (6.0.2)\n",
142 | "Requirement already satisfied: GitPython in /usr/local/lib/python3.10/dist-packages (from mani_skill->rl-robotics-speedrun==0.0.1) (3.1.43)\n",
143 | "Requirement already satisfied: tabulate in /usr/local/lib/python3.10/dist-packages (from mani_skill->rl-robotics-speedrun==0.0.1) (0.9.0)\n",
144 | "Collecting transforms3d (from mani_skill->rl-robotics-speedrun==0.0.1)\n",
145 | " Downloading transforms3d-0.4.2-py3-none-any.whl.metadata (2.8 kB)\n",
146 | "Collecting trimesh (from mani_skill->rl-robotics-speedrun==0.0.1)\n",
147 | " Downloading trimesh-4.5.1-py3-none-any.whl.metadata (18 kB)\n",
148 | "Requirement already satisfied: imageio in /usr/local/lib/python3.10/dist-packages (from mani_skill->rl-robotics-speedrun==0.0.1) (2.36.0)\n",
149 | "Requirement already satisfied: IPython in /usr/local/lib/python3.10/dist-packages (from mani_skill->rl-robotics-speedrun==0.0.1) (7.34.0)\n",
150 | "Collecting pytorch-kinematics==0.7.4 (from mani_skill->rl-robotics-speedrun==0.0.1)\n",
151 | " Downloading pytorch_kinematics-0.7.4-py3-none-any.whl.metadata (14 kB)\n",
152 | "Collecting pynvml (from mani_skill->rl-robotics-speedrun==0.0.1)\n",
153 | " Downloading pynvml-11.5.3-py3-none-any.whl.metadata (8.8 kB)\n",
154 | "Collecting tyro>=0.8.5 (from mani_skill->rl-robotics-speedrun==0.0.1)\n",
155 | " Downloading tyro-0.8.14-py3-none-any.whl.metadata (8.4 kB)\n",
156 | "Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.10/dist-packages (from mani_skill->rl-robotics-speedrun==0.0.1) (0.24.7)\n",
157 | "Collecting mplib==0.1.1 (from mani_skill->rl-robotics-speedrun==0.0.1)\n",
158 | " Downloading mplib-0.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.7 kB)\n",
159 | "Collecting fast-kinematics==0.2.2 (from mani_skill->rl-robotics-speedrun==0.0.1)\n",
160 | " Downloading fast_kinematics-0.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.3 kB)\n",
161 | "Requirement already satisfied: cloudpickle>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from gymnasium==0.29.1->mani_skill->rl-robotics-speedrun==0.0.1) (3.1.0)\n",
162 | "Requirement already satisfied: typing-extensions>=4.3.0 in /usr/local/lib/python3.10/dist-packages (from gymnasium==0.29.1->mani_skill->rl-robotics-speedrun==0.0.1) (4.12.2)\n",
163 | "Collecting farama-notifications>=0.0.1 (from gymnasium==0.29.1->mani_skill->rl-robotics-speedrun==0.0.1)\n",
164 | " Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)\n",
165 | "Collecting toppra>=0.4.0 (from mplib==0.1.1->mani_skill->rl-robotics-speedrun==0.0.1)\n",
166 | " Downloading toppra-0.6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (3.4 kB)\n",
167 | "Requirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (1.4.0)\n",
168 | "Requirement already satisfied: lxml in /usr/local/lib/python3.10/dist-packages (from pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (5.3.0)\n",
169 | "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (2.5.0+cu121)\n",
170 | "Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (from pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (3.8.0)\n",
171 | "Collecting pytorch-seed (from pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1)\n",
172 | " Downloading pytorch_seed-0.2.0-py3-none-any.whl.metadata (3.8 kB)\n",
173 | "Collecting arm-pytorch-utilities (from pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1)\n",
174 | " Downloading arm_pytorch_utilities-0.4.3-py3-none-any.whl.metadata (2.6 kB)\n",
175 | "Requirement already satisfied: requests>=2.22 in /usr/local/lib/python3.10/dist-packages (from sapien==3.0.0.b1->mani_skill->rl-robotics-speedrun==0.0.1) (2.32.3)\n",
176 | "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from sapien==3.0.0.b1->mani_skill->rl-robotics-speedrun==0.0.1) (3.4.2)\n",
177 | "Requirement already satisfied: pyperclip in /usr/local/lib/python3.10/dist-packages (from sapien==3.0.0.b1->mani_skill->rl-robotics-speedrun==0.0.1) (1.9.0)\n",
178 | "Requirement already satisfied: opencv-python>=4.0 in /usr/local/lib/python3.10/dist-packages (from sapien==3.0.0.b1->mani_skill->rl-robotics-speedrun==0.0.1) (4.10.0.84)\n",
179 | "Requirement already satisfied: grpcio>=1.48.2 in /usr/local/lib/python3.10/dist-packages (from tensorboard->rl-robotics-speedrun==0.0.1) (1.64.1)\n",
180 | "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.10/dist-packages (from tensorboard->rl-robotics-speedrun==0.0.1) (3.7)\n",
181 | "Requirement already satisfied: protobuf!=4.24.0,<5.0.0,>=3.19.6 in /usr/local/lib/python3.10/dist-packages (from tensorboard->rl-robotics-speedrun==0.0.1) (3.20.3)\n",
182 | "Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard->rl-robotics-speedrun==0.0.1) (75.1.0)\n",
183 | "Requirement already satisfied: six>1.9 in /usr/local/lib/python3.10/dist-packages (from tensorboard->rl-robotics-speedrun==0.0.1) (1.16.0)\n",
184 | "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard->rl-robotics-speedrun==0.0.1) (0.7.2)\n",
185 | "Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from tensorboard->rl-robotics-speedrun==0.0.1) (3.0.6)\n",
186 | "Requirement already satisfied: orjson in /usr/local/lib/python3.10/dist-packages (from tensordict->rl-robotics-speedrun==0.0.1) (3.10.10)\n",
187 | "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from torchrl->rl-robotics-speedrun==0.0.1) (24.1)\n",
188 | "Requirement already satisfied: click!=8.0.0,>=7.1 in /usr/local/lib/python3.10/dist-packages (from wandb->rl-robotics-speedrun==0.0.1) (8.1.7)\n",
189 | "Requirement already satisfied: docker-pycreds>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from wandb->rl-robotics-speedrun==0.0.1) (0.4.0)\n",
190 | "Requirement already satisfied: platformdirs in /usr/local/lib/python3.10/dist-packages (from wandb->rl-robotics-speedrun==0.0.1) (4.3.6)\n",
191 | "Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb->rl-robotics-speedrun==0.0.1) (5.9.5)\n",
192 | "Requirement already satisfied: sentry-sdk>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb->rl-robotics-speedrun==0.0.1) (2.17.0)\n",
193 | "Requirement already satisfied: setproctitle in /usr/local/lib/python3.10/dist-packages (from wandb->rl-robotics-speedrun==0.0.1) (1.3.3)\n",
194 | "Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.10/dist-packages (from GitPython->mani_skill->rl-robotics-speedrun==0.0.1) (4.0.11)\n",
195 | "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.22->sapien==3.0.0.b1->mani_skill->rl-robotics-speedrun==0.0.1) (3.4.0)\n",
196 | "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.22->sapien==3.0.0.b1->mani_skill->rl-robotics-speedrun==0.0.1) (3.10)\n",
197 | "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.22->sapien==3.0.0.b1->mani_skill->rl-robotics-speedrun==0.0.1) (2.2.3)\n",
198 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.22->sapien==3.0.0.b1->mani_skill->rl-robotics-speedrun==0.0.1) (2024.8.30)\n",
199 | "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch->pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (3.16.1)\n",
200 | "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (3.1.4)\n",
201 | "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch->pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (2024.10.0)\n",
202 | "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch->pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (1.13.1)\n",
203 | "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch->pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (1.3.0)\n",
204 | "Requirement already satisfied: docstring-parser>=0.16 in /usr/local/lib/python3.10/dist-packages (from tyro>=0.8.5->mani_skill->rl-robotics-speedrun==0.0.1) (0.16)\n",
205 | "Requirement already satisfied: rich>=11.1.0 in /usr/local/lib/python3.10/dist-packages (from tyro>=0.8.5->mani_skill->rl-robotics-speedrun==0.0.1) (13.9.3)\n",
206 | "Collecting shtab>=1.5.6 (from tyro>=0.8.5->mani_skill->rl-robotics-speedrun==0.0.1)\n",
207 | " Downloading shtab-1.7.1-py3-none-any.whl.metadata (7.3 kB)\n",
208 | "Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.10/dist-packages (from werkzeug>=1.0.1->tensorboard->rl-robotics-speedrun==0.0.1) (3.0.2)\n",
209 | "Requirement already satisfied: pillow>=8.3.2 in /usr/local/lib/python3.10/dist-packages (from imageio->mani_skill->rl-robotics-speedrun==0.0.1) (10.4.0)\n",
210 | "Requirement already satisfied: imageio-ffmpeg in /usr/local/lib/python3.10/dist-packages (from imageio[ffmpeg]->mani_skill->rl-robotics-speedrun==0.0.1) (0.5.1)\n",
211 | "Collecting jedi>=0.16 (from IPython->mani_skill->rl-robotics-speedrun==0.0.1)\n",
212 | " Downloading jedi-0.19.1-py2.py3-none-any.whl.metadata (22 kB)\n",
213 | "Requirement already satisfied: decorator in /usr/local/lib/python3.10/dist-packages (from IPython->mani_skill->rl-robotics-speedrun==0.0.1) (4.4.2)\n",
214 | "Requirement already satisfied: pickleshare in /usr/local/lib/python3.10/dist-packages (from IPython->mani_skill->rl-robotics-speedrun==0.0.1) (0.7.5)\n",
215 | "Requirement already satisfied: traitlets>=4.2 in /usr/local/lib/python3.10/dist-packages (from IPython->mani_skill->rl-robotics-speedrun==0.0.1) (5.7.1)\n",
216 | "Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from IPython->mani_skill->rl-robotics-speedrun==0.0.1) (3.0.48)\n",
217 | "Requirement already satisfied: pygments in /usr/local/lib/python3.10/dist-packages (from IPython->mani_skill->rl-robotics-speedrun==0.0.1) (2.18.0)\n",
218 | "Requirement already satisfied: backcall in /usr/local/lib/python3.10/dist-packages (from IPython->mani_skill->rl-robotics-speedrun==0.0.1) (0.2.0)\n",
219 | "Requirement already satisfied: matplotlib-inline in /usr/local/lib/python3.10/dist-packages (from IPython->mani_skill->rl-robotics-speedrun==0.0.1) (0.1.7)\n",
220 | "Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.10/dist-packages (from IPython->mani_skill->rl-robotics-speedrun==0.0.1) (4.9.0)\n",
221 | "Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.10/dist-packages (from gitdb<5,>=4.0.1->GitPython->mani_skill->rl-robotics-speedrun==0.0.1) (5.0.1)\n",
222 | "Requirement already satisfied: parso<0.9.0,>=0.8.3 in /usr/local/lib/python3.10/dist-packages (from jedi>=0.16->IPython->mani_skill->rl-robotics-speedrun==0.0.1) (0.8.4)\n",
223 | "Requirement already satisfied: ptyprocess>=0.5 in /usr/local/lib/python3.10/dist-packages (from pexpect>4.3->IPython->mani_skill->rl-robotics-speedrun==0.0.1) (0.7.0)\n",
224 | "Requirement already satisfied: wcwidth in /usr/local/lib/python3.10/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->IPython->mani_skill->rl-robotics-speedrun==0.0.1) (0.2.13)\n",
225 | "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1.0->tyro>=0.8.5->mani_skill->rl-robotics-speedrun==0.0.1) (3.0.0)\n",
226 | "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (1.3.0)\n",
227 | "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib->pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (0.12.1)\n",
228 | "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (4.54.1)\n",
229 | "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (1.4.7)\n",
230 | "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (3.2.0)\n",
231 | "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib->pytorch-kinematics==0.7.4->mani_skill->rl-robotics-speedrun==0.0.1) (2.8.2)\n",
232 | "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich>=11.1.0->tyro>=0.8.5->mani_skill->rl-robotics-speedrun==0.0.1) (0.1.2)\n",
233 | "Downloading mani_skill-3.0.0b12-py3-none-any.whl (79.3 MB)\n",
234 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m79.3/79.3 MB\u001b[0m \u001b[31m29.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
235 | "\u001b[?25hDownloading fast_kinematics-0.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (623 kB)\n",
236 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m623.8/623.8 kB\u001b[0m \u001b[31m45.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
237 | "\u001b[?25hDownloading gymnasium-0.29.1-py3-none-any.whl (953 kB)\n",
238 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m953.9/953.9 kB\u001b[0m \u001b[31m60.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
239 | "\u001b[?25hDownloading mplib-0.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.0 MB)\n",
240 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m12.0/12.0 MB\u001b[0m \u001b[31m63.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
241 | "\u001b[?25hDownloading pytorch_kinematics-0.7.4-py3-none-any.whl (59 kB)\n",
242 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.6/59.6 kB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
243 | "\u001b[?25hDownloading sapien-3.0.0b1-cp310-cp310-manylinux2014_x86_64.whl (49.6 MB)\n",
244 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m49.6/49.6 MB\u001b[0m \u001b[31m45.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
245 | "\u001b[?25hDownloading tensordict-0.6.0-cp310-cp310-manylinux1_x86_64.whl (353 kB)\n",
246 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m353.9/353.9 kB\u001b[0m \u001b[31m33.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
247 | "\u001b[?25hDownloading torchrl-0.6.0-cp310-cp310-manylinux1_x86_64.whl (1.0 MB)\n",
248 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m65.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
249 | "\u001b[?25hDownloading transforms3d-0.4.2-py3-none-any.whl (1.4 MB)\n",
250 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.4/1.4 MB\u001b[0m \u001b[31m76.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
251 | "\u001b[?25hDownloading tyro-0.8.14-py3-none-any.whl (109 kB)\n",
252 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m109.8/109.8 kB\u001b[0m \u001b[31m13.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
253 | "\u001b[?25hDownloading dacite-1.8.1-py3-none-any.whl (14 kB)\n",
254 | "Downloading pynvml-11.5.3-py3-none-any.whl (53 kB)\n",
255 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m53.1/53.1 kB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
256 | "\u001b[?25hDownloading trimesh-4.5.1-py3-none-any.whl (703 kB)\n",
257 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m703.7/703.7 kB\u001b[0m \u001b[31m49.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
258 | "\u001b[?25hDownloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)\n",
259 | "Downloading jedi-0.19.1-py2.py3-none-any.whl (1.6 MB)\n",
260 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m77.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
261 | "\u001b[?25hDownloading shtab-1.7.1-py3-none-any.whl (14 kB)\n",
262 | "Downloading toppra-0.6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (638 kB)\n",
263 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m638.2/638.2 kB\u001b[0m \u001b[31m52.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
264 | "\u001b[?25hDownloading arm_pytorch_utilities-0.4.3-py3-none-any.whl (40 kB)\n",
265 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m41.0/41.0 kB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
266 | "\u001b[?25hDownloading pytorch_seed-0.2.0-py3-none-any.whl (4.2 kB)\n",
267 | "Installing collected packages: farama-notifications, trimesh, transforms3d, shtab, pynvml, jedi, gymnasium, fast-kinematics, dacite, sapien, tyro, toppra, tensordict, pytorch-seed, torchrl, mplib, arm-pytorch-utilities, pytorch-kinematics, mani_skill, rl-robotics-speedrun\n",
268 | " Running setup.py develop for rl-robotics-speedrun\n",
269 | "Successfully installed arm-pytorch-utilities-0.4.3 dacite-1.8.1 farama-notifications-0.0.4 fast-kinematics-0.2.2 gymnasium-0.29.1 jedi-0.19.1 mani_skill-3.0.0b12 mplib-0.1.1 pynvml-11.5.3 pytorch-kinematics-0.7.4 pytorch-seed-0.2.0 rl-robotics-speedrun-0.0.1 sapien-3.0.0b1 shtab-1.7.1 tensordict-0.6.0 toppra-0.6.0 torchrl-0.6.0 transforms3d-0.4.2 trimesh-4.5.1 tyro-0.8.14\n"
270 | ]
271 | }
272 | ],
273 | "source": [
274 | "# setup vulkan\n",
275 | "!mkdir -p /usr/share/vulkan/icd.d\n",
276 | "!wget -q https://raw.githubusercontent.com/haosulab/ManiSkill/main/docker/nvidia_icd.json\n",
277 | "!wget -q https://raw.githubusercontent.com/haosulab/ManiSkill/main/docker/10_nvidia.json\n",
278 | "!mv nvidia_icd.json /usr/share/vulkan/icd.d\n",
279 | "!mv 10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json\n",
280 | "!apt-get install -y --no-install-recommends libvulkan-dev\n",
281 | "# dependencies\n",
282 | "!git clone https://github.com/StoneT2000/rl-robotics-speedrun\n",
283 | "!cd rl-robotics-speedrun && pip install -e ."
284 | ]
285 | },
286 | {
287 | "cell_type": "code",
288 | "execution_count": 2,
289 | "metadata": {
290 | "id": "_LC6CX2LeJu6"
291 | },
292 | "outputs": [],
293 | "source": [
294 | "try:\n",
295 | " import google.colab\n",
296 | " IN_COLAB = True\n",
297 | "except:\n",
298 | " IN_COLAB = False\n",
299 | "\n",
300 | "if IN_COLAB:\n",
301 | " import site\n",
302 | " site.main() # run this so local pip installs are recognized"
303 | ]
304 | },
305 | {
306 | "cell_type": "markdown",
307 | "source": [
308 | "# Benchmarking"
309 | ],
310 | "metadata": {
311 | "id": "VAOJ_Y9w9Lgf"
312 | }
313 | },
314 | {
315 | "cell_type": "code",
316 | "source": [
317 | "# login to upload / track aggregated results, you need to provide your wandb API key\n",
318 | "import wandb\n",
319 | "wandb.login()"
320 | ],
321 | "metadata": {
322 | "colab": {
323 | "base_uri": "https://localhost:8080/",
324 | "height": 86
325 | },
326 | "id": "9fqoUMxR9N4z",
327 | "outputId": "81f123fd-3bc2-4e81-e1c9-ea894ef7e492"
328 | },
329 | "execution_count": 3,
330 | "outputs": [
331 | {
332 | "output_type": "stream",
333 | "name": "stderr",
334 | "text": [
335 | "\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n"
336 | ]
337 | },
338 | {
339 | "output_type": "display_data",
340 | "data": {
341 | "text/plain": [
342 | ""
343 | ],
344 | "application/javascript": [
345 | "\n",
346 | " window._wandbApiKey = new Promise((resolve, reject) => {\n",
347 | " function loadScript(url) {\n",
348 | " return new Promise(function(resolve, reject) {\n",
349 | " let newScript = document.createElement(\"script\");\n",
350 | " newScript.onerror = reject;\n",
351 | " newScript.onload = resolve;\n",
352 | " document.body.appendChild(newScript);\n",
353 | " newScript.src = url;\n",
354 | " });\n",
355 | " }\n",
356 | " loadScript(\"https://cdn.jsdelivr.net/npm/postmate/build/postmate.min.js\").then(() => {\n",
357 | " const iframe = document.createElement('iframe')\n",
358 | " iframe.style.cssText = \"width:0;height:0;border:none\"\n",
359 | " document.body.appendChild(iframe)\n",
360 | " const handshake = new Postmate({\n",
361 | " container: iframe,\n",
362 | " url: 'https://wandb.ai/authorize'\n",
363 | " });\n",
364 | " const timeout = setTimeout(() => reject(\"Couldn't auto authenticate\"), 5000)\n",
365 | " handshake.then(function(child) {\n",
366 | " child.on('authorize', data => {\n",
367 | " clearTimeout(timeout)\n",
368 | " resolve(data)\n",
369 | " });\n",
370 | " });\n",
371 | " })\n",
372 | " });\n",
373 | " "
374 | ]
375 | },
376 | "metadata": {}
377 | },
378 | {
379 | "output_type": "stream",
380 | "name": "stderr",
381 | "text": [
382 | "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n"
383 | ]
384 | },
385 | {
386 | "output_type": "execute_result",
387 | "data": {
388 | "text/plain": [
389 | "True"
390 | ]
391 | },
392 | "metadata": {},
393 | "execution_count": 3
394 | }
395 | ]
396 | },
397 | {
398 | "cell_type": "code",
399 | "source": [
400 | "# look at the training script, make modifications if needed\n",
401 | "!cat rl-robotics-speedrun/records/10312024_cudagraphs/script.sh\n",
402 | "# ppo code at rl-robotics-speedrun/records/10312024_cudagraphs/ppo.py"
403 | ],
404 | "metadata": {
405 | "colab": {
406 | "base_uri": "https://localhost:8080/"
407 | },
408 | "id": "imJpNrbSF5UM",
409 | "outputId": "8935418d-f194-48be-9c78-8a9f8f6a21ab"
410 | },
411 | "execution_count": 5,
412 | "outputs": [
413 | {
414 | "output_type": "stream",
415 | "name": "stdout",
416 | "text": [
417 | "#!/bin/bash\n",
418 | "seeds=(9351 4796 1788 2371 3462 4553 5644 6735 7826 8917)\n",
419 | "num_envs=4096\n",
420 | "update_epochs=12\n",
421 | "num_steps=4\n",
422 | "num_minibatches=32\n",
423 | "exp_name=\"10312024_cudagraphs\"\n",
424 | "for seed in ${seeds[@]}\n",
425 | "do\n",
426 | " python ppo.py --env_id=\"PickCube-v1\" --seed=${seed} --total_timesteps=4_000_000 --eval-freq=10 \\\n",
427 | " --num_envs=${num_envs} --update_epochs=${update_epochs} --num_minibatches=${num_minibatches} \\\n",
428 | " --num-steps=${num_steps} --cudagraphs \\\n",
429 | " --exp-name=\"ppo-PickCube-v1-state-${seed}-${exp_name}\" \\\n",
430 | " --track --wandb_project_name=\"PPO-ManiSkill-GPU-SpeedRun\" --wandb_group=${exp_name} --no-capture_video\n",
431 | "done"
432 | ]
433 | }
434 | ]
435 | },
436 | {
437 | "cell_type": "code",
438 | "source": [
439 | "!cd rl-robotics-speedrun/records/10312024_cudagraphs/ && bash script.sh"
440 | ],
441 | "metadata": {
442 | "colab": {
443 | "base_uri": "https://localhost:8080/"
444 | },
445 | "id": "owjFU8HIClgW",
446 | "outputId": "a2904e53-7cc3-4386-dceb-f0989fc0b801"
447 | },
448 | "execution_count": null,
449 | "outputs": [
450 | {
451 | "output_type": "stream",
452 | "name": "stdout",
453 | "text": [
454 | "2024-10-31 19:41:14.749861: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
455 | "2024-10-31 19:41:14.765622: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
456 | "2024-10-31 19:41:14.787173: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
457 | "2024-10-31 19:41:14.793659: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
458 | "2024-10-31 19:41:14.809284: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
459 | "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
460 | "2024-10-31 19:41:16.055538: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
461 | "Downloading PhysX GPU library to /root/.sapien/physx/105.1-physx-5.3.1.patch0 from Github. This can take several minutes. If it fails to download, please manually download fhttps://github.com/sapien-sim/physx-precompiled/releases/download/105.1-physx-5.3.1.patch0/linux-so.zip and unzip at /root/.sapien/physx/105.1-physx-5.3.1.patch0.\n",
462 | "Download complete.\n",
463 | "/usr/local/lib/python3.10/dist-packages/gymnasium/core.py:311: UserWarning: \u001b[33mWARN: env.single_observation_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.single_observation_space` for environment variables or `env.get_wrapper_attr('single_observation_space')` that will search the reminding wrappers.\u001b[0m\n",
464 | " logger.warn(\n",
465 | "/usr/local/lib/python3.10/dist-packages/gymnasium/core.py:311: UserWarning: \u001b[33mWARN: env.single_action_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.single_action_space` for environment variables or `env.get_wrapper_attr('single_action_space')` that will search the reminding wrappers.\u001b[0m\n",
466 | " logger.warn(\n",
467 | "/usr/local/lib/python3.10/dist-packages/gymnasium/core.py:311: UserWarning: \u001b[33mWARN: env.max_episode_steps to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.max_episode_steps` for environment variables or `env.get_wrapper_attr('max_episode_steps')` that will search the reminding wrappers.\u001b[0m\n",
468 | " logger.warn(\n",
469 | "Running training\n",
470 | "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mstonet2000\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n",
471 | "\u001b[34m\u001b[1mwandb\u001b[0m: Tracking run with wandb version 0.18.5\n",
472 | "\u001b[34m\u001b[1mwandb\u001b[0m: Run data is saved locally in \u001b[35m\u001b[1m/content/rl-robotics-speedrun/records/10312024_cudagraphs/wandb/run-20241031_194142-tir5o14p\u001b[0m\n",
473 | "\u001b[34m\u001b[1mwandb\u001b[0m: Run \u001b[1m`wandb offline`\u001b[0m to turn off syncing.\n",
474 | "\u001b[34m\u001b[1mwandb\u001b[0m: Syncing run \u001b[33mppo-PickCube-v1-state-9351-10312024_cudagraphs\u001b[0m\n",
475 | "\u001b[34m\u001b[1mwandb\u001b[0m: ⭐️ View project at \u001b[34m\u001b[4mhttps://wandb.ai/stonet2000/PPO-ManiSkill-GPU-SpeedRun\u001b[0m\n",
476 | "\u001b[34m\u001b[1mwandb\u001b[0m: 🚀 View run at \u001b[34m\u001b[4mhttps://wandb.ai/stonet2000/PPO-ManiSkill-GPU-SpeedRun/runs/tir5o14p\u001b[0m\n",
477 | "/usr/local/lib/python3.10/dist-packages/tensordict/nn/cudagraphs.py:194: UserWarning: Tensordict is registered in PyTree. This is incompatible with CudaGraphModule. Removing TDs from PyTree. To silence this warning, call tensordict.nn.functional_module._exclude_td_from_pytree().set() or set the environment variable `EXCLUDE_TD_FROM_PYTREE=1`. This operation is irreversible.\n",
478 | " warnings.warn(\n",
479 | "success_once: 0.88, return: 37.81: 100% 244/244 [03:58<00:00, 1.02it/s]\n",
480 | "2024-10-31 19:45:53.083376: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
481 | "2024-10-31 19:45:53.101255: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
482 | "2024-10-31 19:45:53.122746: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
483 | "2024-10-31 19:45:53.129584: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
484 | "2024-10-31 19:45:53.145547: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
485 | "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
486 | "2024-10-31 19:45:54.411087: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
487 | "/usr/local/lib/python3.10/dist-packages/gymnasium/core.py:311: UserWarning: \u001b[33mWARN: env.single_observation_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.single_observation_space` for environment variables or `env.get_wrapper_attr('single_observation_space')` that will search the reminding wrappers.\u001b[0m\n",
488 | " logger.warn(\n",
489 | "/usr/local/lib/python3.10/dist-packages/gymnasium/core.py:311: UserWarning: \u001b[33mWARN: env.single_action_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.single_action_space` for environment variables or `env.get_wrapper_attr('single_action_space')` that will search the reminding wrappers.\u001b[0m\n",
490 | " logger.warn(\n",
491 | "/usr/local/lib/python3.10/dist-packages/gymnasium/core.py:311: UserWarning: \u001b[33mWARN: env.max_episode_steps to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.max_episode_steps` for environment variables or `env.get_wrapper_attr('max_episode_steps')` that will search the reminding wrappers.\u001b[0m\n",
492 | " logger.warn(\n",
493 | "Running training\n",
494 | "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mstonet2000\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n",
495 | "\u001b[34m\u001b[1mwandb\u001b[0m: Tracking run with wandb version 0.18.5\n",
496 | "\u001b[34m\u001b[1mwandb\u001b[0m: Run data is saved locally in \u001b[35m\u001b[1m/content/rl-robotics-speedrun/records/10312024_cudagraphs/wandb/run-20241031_194616-mf31p5np\u001b[0m\n",
497 | "\u001b[34m\u001b[1mwandb\u001b[0m: Run \u001b[1m`wandb offline`\u001b[0m to turn off syncing.\n",
498 | "\u001b[34m\u001b[1mwandb\u001b[0m: Syncing run \u001b[33mppo-PickCube-v1-state-4796-10312024_cudagraphs\u001b[0m\n",
499 | "\u001b[34m\u001b[1mwandb\u001b[0m: ⭐️ View project at \u001b[34m\u001b[4mhttps://wandb.ai/stonet2000/PPO-ManiSkill-GPU-SpeedRun\u001b[0m\n",
500 | "\u001b[34m\u001b[1mwandb\u001b[0m: 🚀 View run at \u001b[34m\u001b[4mhttps://wandb.ai/stonet2000/PPO-ManiSkill-GPU-SpeedRun/runs/mf31p5np\u001b[0m\n",
501 | "/usr/local/lib/python3.10/dist-packages/tensordict/nn/cudagraphs.py:194: UserWarning: Tensordict is registered in PyTree. This is incompatible with CudaGraphModule. Removing TDs from PyTree. To silence this warning, call tensordict.nn.functional_module._exclude_td_from_pytree().set() or set the environment variable `EXCLUDE_TD_FROM_PYTREE=1`. This operation is irreversible.\n",
502 | " warnings.warn(\n",
503 | "success_once: 0.00, return: 19.96: 22% 53/244 [01:06<04:15, 1.34s/it]"
504 | ]
505 | }
506 | ]
507 | },
508 | {
509 | "cell_type": "code",
510 | "source": [],
511 | "metadata": {
512 | "id": "pCmwy4__GIEz"
513 | },
514 | "execution_count": null,
515 | "outputs": []
516 | }
517 | ]
518 | }
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name="rl-robotics-speedrun",
5 | version="0.0.1",
6 | packages=find_packages(),
7 | install_requires=[
8 | "wandb",
9 | "tqdm",
10 | "mani_skill",
11 | "tensorboard",
12 | "torchrl",
13 | "tensordict"
14 | ],
15 | author="Stone Tao",
16 | description="Speed-running solving robot manipulation tasks",
17 | long_description=open("README.md").read(),
18 | long_description_content_type="text/markdown",
19 | url="https://github.com/StoneT2000/rl-robotics-speedrun",
20 | python_requires=">=3.10",
21 | )
22 |
--------------------------------------------------------------------------------