├── .gitignore ├── README.md ├── algos ├── __init__.py ├── offrl │ ├── .DS_Store │ ├── bcq │ │ ├── __init__.py │ │ ├── bcq.py │ │ ├── module.py │ │ └── storage.py │ ├── iql │ │ ├── __init__.py │ │ ├── iql.py │ │ ├── module.py │ │ └── storage.py │ ├── ppo_collect │ │ ├── __init__.py │ │ ├── module.py │ │ ├── ppo_collect.py │ │ └── storage.py │ └── td3_bc │ │ ├── __init__.py │ │ ├── module.py │ │ ├── storage.py │ │ └── td3_bc.py ├── planner │ ├── approach.py │ └── base.py ├── rl │ └── ppo │ │ ├── __init__.py │ │ ├── module.py │ │ ├── ppo.py │ │ └── storage.py └── utils │ ├── act.py │ ├── cnn.py │ ├── distributions.py │ ├── mlp.py │ ├── rnn.py │ └── util.py ├── assets └── readme │ └── teaser.png ├── cfgs ├── algo │ └── ppo │ │ ├── config.yaml │ │ └── manipulation.yaml ├── plan │ └── config.yaml ├── repre │ └── ag2manip │ │ └── config.yaml └── task │ ├── frankakitchen │ ├── close_hingecabinet.yaml │ ├── close_microwave.yaml │ ├── close_slidecabinet.yaml │ ├── move_kettle.yaml │ ├── open_hingecabinet.yaml │ ├── open_microwave.yaml │ ├── open_slidecabinet.yaml │ ├── pickup_kettle.yaml │ ├── turnoff_switch.yaml │ └── turnon_switch.yaml │ ├── maniskill │ ├── close_door.yaml │ ├── insert_peg.yaml │ ├── open_door.yaml │ ├── pickup_clutterycb.yaml │ ├── pickup_cube.yaml │ ├── stack_cube.yaml │ ├── turn_leftfaucet.yaml │ └── turn_rightfaucet.yaml │ └── partmanip │ ├── close_dishwasher.yaml │ ├── lift_lid.yaml │ ├── open_dishwasher.yaml │ ├── press_button.yaml │ ├── pull_wooddrawer.yaml │ └── push_wooddrawer.yaml ├── plan.py ├── repre_trainer ├── cfgs │ ├── default.yaml │ ├── model │ │ ├── ag2manip.yaml │ │ ├── r3m.yaml │ │ └── vip.yaml │ └── task │ │ └── epic_kitchen.yaml ├── datasets │ ├── __init__.py │ ├── base.py │ ├── epic_kitchen.py │ └── misc.py ├── models │ ├── __init__.py │ ├── base.py │ ├── evaluator.py │ └── model │ │ ├── ag2manip.py │ │ ├── r3m.py │ │ └── vip.py ├── train.py ├── train_ddp.py └── utils │ ├── io.py │ ├── misc.py │ ├── plot.py │ └── registry.py ├── repres ├── ag2manip.py ├── base │ ├── __init__.py │ └── base_repre.py ├── r3m.py └── vip.py ├── requirements.txt ├── tasks ├── base │ ├── __init__.py │ ├── base_task.py │ └── vec_task.py ├── frankakitchen │ ├── CloseHingecabinet.py │ ├── CloseMicrowave.py │ ├── CloseSlidecabinet.py │ ├── MoveKettle.py │ ├── OpenHingecabinet.py │ ├── OpenMicrowave.py │ ├── OpenSlidecabinet.py │ ├── PickupKettle.py │ ├── TurnoffSwitch.py │ ├── TurnonSwitch.py │ ├── goals_image │ │ ├── CloseHingecabinet@default@woa.png │ │ ├── CloseHingecabinet@left@woa.png │ │ ├── CloseHingecabinet@right@woa.png │ │ ├── CloseMicrowave@default@woa.png │ │ ├── CloseMicrowave@left@woa.png │ │ ├── CloseMicrowave@right@woa.png │ │ ├── CloseSlidecabinet@default@woa.png │ │ ├── CloseSlidecabinet@left@woa.png │ │ ├── CloseSlidecabinet@right@woa.png │ │ ├── MoveKettle@default@woa.png │ │ ├── MoveKettle@left@woa.png │ │ ├── MoveKettle@right@woa.png │ │ ├── OpenHingecabinet@default@woa.png │ │ ├── OpenHingecabinet@left@woa.png │ │ ├── OpenHingecabinet@right@woa.png │ │ ├── OpenMicrowave@default@woa.png │ │ ├── OpenMicrowave@left@woa.png │ │ ├── OpenMicrowave@right@woa.png │ │ ├── OpenSlidecabinet@default@woa.png │ │ ├── OpenSlidecabinet@left@woa.png │ │ ├── OpenSlidecabinet@right@woa.png │ │ ├── PickupKettle@default@woa.png │ │ ├── PickupKettle@left@woa.png │ │ ├── PickupKettle@right@woa.png │ │ ├── TurnoffSwitch@default@woa.png │ │ ├── TurnoffSwitch@left@woa.png │ │ ├── TurnoffSwitch@right@woa.png │ │ ├── TurnonSwitch@default@woa.png │ │ ├── TurnonSwitch@left@woa.png │ │ └── 
TurnonSwitch@right@woa.png │ └── goals_image_wa │ │ ├── CloseHingecabinet@default@wa.png │ │ ├── CloseHingecabinet@left@wa.png │ │ ├── CloseHingecabinet@right@wa.png │ │ ├── CloseMicrowave@default@wa.png │ │ ├── CloseMicrowave@left@wa.png │ │ ├── CloseMicrowave@right@wa.png │ │ ├── CloseSlidecabinet@default@wa.png │ │ ├── CloseSlidecabinet@left@wa.png │ │ ├── CloseSlidecabinet@right@wa.png │ │ ├── MoveKettle@default@wa.png │ │ ├── MoveKettle@left@wa.png │ │ ├── MoveKettle@right@wa.png │ │ ├── OpenHingecabinet@default@wa.png │ │ ├── OpenHingecabinet@left@wa.png │ │ ├── OpenHingecabinet@right@wa.png │ │ ├── OpenMicrowave@default@wa.png │ │ ├── OpenMicrowave@left@wa.png │ │ ├── OpenMicrowave@right@wa.png │ │ ├── OpenSlidecabinet@default@wa.png │ │ ├── OpenSlidecabinet@left@wa.png │ │ ├── OpenSlidecabinet@right@wa.png │ │ ├── PickupKettle@default@wa.png │ │ ├── PickupKettle@left@wa.png │ │ ├── PickupKettle@right@wa.png │ │ ├── TurnoffSwitch@default@wa.png │ │ ├── TurnoffSwitch@left@wa.png │ │ ├── TurnoffSwitch@right@wa.png │ │ ├── TurnonSwitch@default@wa.png │ │ ├── TurnonSwitch@left@wa.png │ │ └── TurnonSwitch@right@wa.png ├── maniskill │ ├── CloseDoor.py │ ├── InsertPeg.py │ ├── OpenDoor.py │ ├── PickupClutterycb.py │ ├── PickupCube.py │ ├── StackCube.py │ ├── TurnLeftfaucet.py │ ├── TurnRightfaucet.py │ ├── goals_image │ │ ├── CloseDoor@default@woa.png │ │ ├── CloseDoor@left@woa.png │ │ ├── CloseDoor@right@woa.png │ │ ├── InsertPeg@default@woa.png │ │ ├── InsertPeg@left@woa.png │ │ ├── InsertPeg@right@woa.png │ │ ├── OpenDoor@default@woa.png │ │ ├── OpenDoor@left@woa.png │ │ ├── OpenDoor@right@woa.png │ │ ├── PickupClutterycb@default@woa.png │ │ ├── PickupClutterycb@left@woa.png │ │ ├── PickupClutterycb@right@woa.png │ │ ├── PickupCube@default@woa.png │ │ ├── PickupCube@left@woa.png │ │ ├── PickupCube@right@woa.png │ │ ├── StackCube@default@woa.png │ │ ├── StackCube@left@woa.png │ │ ├── StackCube@right@woa.png │ │ ├── TurnLeftfaucet@default@woa.png │ │ ├── TurnLeftfaucet@left@woa.png │ │ ├── TurnLeftfaucet@right@woa.png │ │ ├── TurnRightfaucet@default@woa.png │ │ ├── TurnRightfaucet@left@woa.png │ │ └── TurnRightfaucet@right@woa.png │ └── goals_image_wa │ │ ├── CloseDoor@default@wa.png │ │ ├── CloseDoor@left@wa.png │ │ ├── CloseDoor@right@wa.png │ │ ├── InsertPeg@default@wa.png │ │ ├── InsertPeg@left@wa.png │ │ ├── InsertPeg@right@wa.png │ │ ├── OpenDoor@default@wa.png │ │ ├── OpenDoor@left@wa.png │ │ ├── OpenDoor@right@wa.png │ │ ├── PickCube@right@wa.png │ │ ├── PickupClutterycb@default@wa.png │ │ ├── PickupClutterycb@left@wa.png │ │ ├── PickupClutterycb@right@wa.png │ │ ├── PickupCube@default@wa.png │ │ ├── PickupCube@left@wa.png │ │ ├── StackCube@default@wa.png │ │ ├── StackCube@left@wa.png │ │ ├── StackCube@right@wa.png │ │ ├── TurnLeftfaucet@default@wa.png │ │ ├── TurnLeftfaucet@left@wa.png │ │ ├── TurnLeftfaucet@right@wa.png │ │ ├── TurnRightfaucet@default@wa.png │ │ ├── TurnRightfaucet@left@wa.png │ │ └── TurnRightfaucet@right@wa.png └── partmanip │ ├── CloseDishwasher.py │ ├── LiftLid.py │ ├── OpenDishwasher.py │ ├── PressButton.py │ ├── PullWooddrawer.py │ ├── PushWooddrawer.py │ ├── goals_image │ ├── CloseDishwasher@default@woa.png │ ├── CloseDishwasher@left@woa.png │ ├── CloseDishwasher@right@woa.png │ ├── LiftLid@default@woa.png │ ├── LiftLid@left@woa.png │ ├── LiftLid@right@woa.png │ ├── OpenDishwasher@default@woa.png │ ├── OpenDishwasher@left@woa.png │ ├── OpenDishwasher@right@woa.png │ ├── PressButton@default@woa.png │ ├── PressButton@left@woa.png │ ├── 
PressButton@right@woa.png │ ├── PullWooddrawer@default@woa.png │ ├── PullWooddrawer@left@woa.png │ ├── PullWooddrawer@right@woa.png │ ├── PushWooddrawer@default@woa.png │ ├── PushWooddrawer@left@woa.png │ └── PushWooddrawer@right@woa.png │ └── goals_image_wa │ ├── CloseDishwasher@default@wa.png │ ├── CloseDishwasher@left@wa.png │ ├── CloseDishwasher@right@wa.png │ ├── LiftLid@default@wa.png │ ├── LiftLid@left@wa.png │ ├── LiftLid@right@wa.png │ ├── OpenDishwasher@default@wa.png │ ├── OpenDishwasher@left@wa.png │ ├── OpenDishwasher@right@wa.png │ ├── PressButton@default@wa.png │ ├── PressButton@left@wa.png │ ├── PressButton@right@wa.png │ ├── PullWooddrawer@default@wa.png │ ├── PullWooddrawer@left@wa.png │ ├── PullWooddrawer@right@wa.png │ ├── PushWooddrawer@default@wa.png │ ├── PushWooddrawer@left@wa.png │ └── PushWooddrawer@right@wa.png ├── train.py └── utils ├── __init__.py ├── config.py ├── logger ├── plotter.py └── tools.py ├── o3dviewer.py ├── package_utils.py ├── parse_task.py ├── parse_task_plan.py ├── process_offrl.py ├── process_sarl.py ├── torch_jit_utils.py └── util.py /algos/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/algos/__init__.py -------------------------------------------------------------------------------- /algos/offrl/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/algos/offrl/.DS_Store -------------------------------------------------------------------------------- /algos/offrl/bcq/__init__.py: -------------------------------------------------------------------------------- 1 | from .storage import ReplayBuffer 2 | from .module import BCQ_Model 3 | from .bcq import BCQ -------------------------------------------------------------------------------- /algos/offrl/bcq/bcq.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | 9 | from bidexhands.algorithms.offrl.bcq import BCQ_Model 10 | from bidexhands.algorithms.offrl.bcq import ReplayBuffer 11 | 12 | class BCQ: 13 | 14 | def __init__(self, 15 | vec_env, 16 | device='cpu', 17 | discount = 0.99, 18 | tau = 0.005, 19 | lmbda = 0.75, 20 | phi = 0.05, 21 | batch_size = 100, 22 | max_timesteps = 1000000, 23 | iterations = 10000, 24 | log_dir = '', 25 | datatype = 'expert', 26 | algo = 'bcq'): 27 | 28 | self.observation_space = vec_env.observation_space 29 | self.action_space = vec_env.action_space 30 | self.vec_env = vec_env 31 | self.device = device 32 | self.discount = discount 33 | self.tau = tau 34 | self.lmbda = lmbda 35 | self.phi = phi 36 | self.batch_size = batch_size 37 | self.max_timesteps = max_timesteps 38 | self.iterations = iterations 39 | self.log_dir = log_dir 40 | self.datatype = datatype 41 | self.algo = algo 42 | self.log_dir = self.log_dir.split(self.algo)[0]+self.algo+'/'+self.datatype+'/' 43 | self.data_dir = self.log_dir.split(self.algo)[0].split('logs') 44 | self.data_dir = self.data_dir[0]+'data'+self.data_dir[1]+self.datatype+'/' 45 | self.test_step = 40000 46 | 47 | if not os.path.exists(self.log_dir): 48 | os.makedirs(self.log_dir) 49 | time.sleep(np.random.rand()*2) 50 | order = len(os.listdir(self.log_dir)) 51 | self.reward_log = 
open(self.log_dir+str(order)+'.log','w') 52 | 53 | def run(self, num_learning_iterations, log_interval=1): 54 | 55 | current_obs = self.vec_env.reset() 56 | state_dim = self.observation_space.shape[0] 57 | action_dim = self.action_space.shape[0] 58 | max_action = float(self.action_space.high[0]) 59 | 60 | policy = BCQ_Model(state_dim, action_dim, max_action, self.device, self.discount, self.tau, self.lmbda, self.phi) 61 | 62 | replay_buffer = ReplayBuffer(state_dim, action_dim, self.device) 63 | replay_buffer.convert(self.data_dir) 64 | 65 | for t in range(int(self.max_timesteps/self.iterations)+1): 66 | 67 | policy.train(replay_buffer, self.iterations, self.batch_size) 68 | 69 | reward_sum = [] 70 | cur_reward_sum = torch.zeros(self.vec_env.num_envs, dtype=torch.float, device=self.device) 71 | current_obs = self.vec_env.reset() 72 | for _ in range(int(self.test_step/self.vec_env.num_envs)): 73 | actions = policy.select_action(current_obs) 74 | next_obs, rews, dones, infos = self.vec_env.step(actions) 75 | current_obs.copy_(next_obs) 76 | cur_reward_sum[:] += rews 77 | new_ids = (dones > 0).nonzero(as_tuple=False) 78 | reward_sum.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist()) 79 | cur_reward_sum[new_ids] = 0 80 | 81 | self.reward_log.write(str(sum(reward_sum)/len(reward_sum))+'\n') 82 | self.reward_log.flush() 83 | 84 | 85 | -------------------------------------------------------------------------------- /algos/offrl/bcq/module.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | class Actor(nn.Module): 8 | def __init__(self, state_dim, action_dim, max_action, phi=0.05): 9 | super(Actor, self).__init__() 10 | self.l1 = nn.Linear(state_dim + action_dim, 400) 11 | self.l2 = nn.Linear(400, 300) 12 | self.l3 = nn.Linear(300, action_dim) 13 | 14 | self.max_action = max_action 15 | self.phi = phi 16 | 17 | def forward(self, state, action): 18 | a = F.relu(self.l1(torch.cat([state, action], 1))) 19 | a = F.relu(self.l2(a)) 20 | a = self.phi * self.max_action * torch.tanh(self.l3(a)) 21 | return (a + action).clamp(-self.max_action, self.max_action) 22 | 23 | class Critic(nn.Module): 24 | def __init__(self, state_dim, action_dim): 25 | super(Critic, self).__init__() 26 | self.l1 = nn.Linear(state_dim + action_dim, 400) 27 | self.l2 = nn.Linear(400, 300) 28 | self.l3 = nn.Linear(300, 1) 29 | 30 | self.l4 = nn.Linear(state_dim + action_dim, 400) 31 | self.l5 = nn.Linear(400, 300) 32 | self.l6 = nn.Linear(300, 1) 33 | 34 | def forward(self, state, action): 35 | q1 = F.relu(self.l1(torch.cat([state, action], 1))) 36 | q1 = F.relu(self.l2(q1)) 37 | q1 = self.l3(q1) 38 | 39 | q2 = F.relu(self.l4(torch.cat([state, action], 1))) 40 | q2 = F.relu(self.l5(q2)) 41 | q2 = self.l6(q2) 42 | return q1, q2 43 | 44 | def q1(self, state, action): 45 | q1 = F.relu(self.l1(torch.cat([state, action], 1))) 46 | q1 = F.relu(self.l2(q1)) 47 | q1 = self.l3(q1) 48 | return q1 49 | 50 | class VAE(nn.Module): 51 | def __init__(self, state_dim, action_dim, latent_dim, max_action, device): 52 | super(VAE, self).__init__() 53 | self.e1 = nn.Linear(state_dim + action_dim, 750) 54 | self.e2 = nn.Linear(750, 750) 55 | 56 | self.mean = nn.Linear(750, latent_dim) 57 | self.log_std = nn.Linear(750, latent_dim) 58 | 59 | self.d1 = nn.Linear(state_dim + latent_dim, 750) 60 | self.d2 = nn.Linear(750, 750) 61 | self.d3 = nn.Linear(750, action_dim) 62 | 63 | 
self.max_action = max_action 64 | self.latent_dim = latent_dim 65 | self.device = device 66 | 67 | def forward(self, state, action): 68 | z = F.relu(self.e1(torch.cat([state, action], 1))) 69 | z = F.relu(self.e2(z)) 70 | 71 | mean = self.mean(z) 72 | log_std = self.log_std(z).clamp(-4, 15) 73 | std = torch.exp(log_std) 74 | z = mean + std * torch.randn_like(std) 75 | u = self.decode(state, z) 76 | 77 | return u, mean, std 78 | 79 | def decode(self, state, z=None): 80 | if z is None: 81 | z = torch.randn((state.shape[0], self.latent_dim)).to(self.device).clamp(-0.5,0.5) 82 | 83 | a = F.relu(self.d1(torch.cat([state, z], 1))) 84 | a = F.relu(self.d2(a)) 85 | return self.max_action * torch.tanh(self.d3(a)) 86 | 87 | class BCQ_Model(object): 88 | def __init__(self, state_dim, action_dim, max_action, device, discount=0.99, tau=0.005, lmbda=0.75, phi=0.05): 89 | latent_dim = action_dim * 2 90 | 91 | self.actor = Actor(state_dim, action_dim, max_action, phi).to(device) 92 | self.actor_target = copy.deepcopy(self.actor) 93 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=1e-3) 94 | 95 | self.critic = Critic(state_dim, action_dim).to(device) 96 | self.critic_target = copy.deepcopy(self.critic) 97 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=1e-3) 98 | 99 | self.vae = VAE(state_dim, action_dim, latent_dim, max_action, device).to(device) 100 | self.vae_optimizer = torch.optim.Adam(self.vae.parameters()) 101 | 102 | self.max_action = max_action 103 | self.action_dim = action_dim 104 | self.discount = discount 105 | self.tau = tau 106 | self.lmbda = lmbda 107 | self.device = device 108 | 109 | 110 | def select_action(self, state): 111 | with torch.no_grad(): 112 | lenth = state.shape[0] 113 | state = state.unsqueeze(1).repeat(1,100, 1).reshape(-1,state.shape[-1]) 114 | action = self.actor(state, self.vae.decode(state)) 115 | q1 = self.critic.q1(state, action).reshape(lenth,100,1) 116 | ind = q1.argmax(1) 117 | action = action.reshape(lenth,100,action.shape[-1]) 118 | action = torch.stack([action[i][ind[i]] for i in range(lenth)]).squeeze(1) 119 | return action 120 | 121 | 122 | def train(self, replay_buffer, iterations, batch_size=100): 123 | 124 | for it in range(iterations): 125 | state, action, next_state, reward, not_done = replay_buffer.sample(batch_size) 126 | 127 | recon, mean, std = self.vae(state, action) 128 | recon_loss = F.mse_loss(recon, action) 129 | KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean() 130 | vae_loss = recon_loss + 0.5 * KL_loss 131 | 132 | self.vae_optimizer.zero_grad() 133 | vae_loss.backward() 134 | self.vae_optimizer.step() 135 | 136 | with torch.no_grad(): 137 | next_state = torch.repeat_interleave(next_state, 10, 0) 138 | 139 | target_Q1, target_Q2 = self.critic_target(next_state, self.actor_target(next_state, self.vae.decode(next_state))) 140 | target_Q = self.lmbda * torch.min(target_Q1, target_Q2) + (1. 
- self.lmbda) * torch.max(target_Q1, target_Q2) 141 | target_Q = target_Q.reshape(batch_size, -1).max(1)[0].reshape(-1, 1) 142 | target_Q = reward + not_done * self.discount * target_Q 143 | 144 | current_Q1, current_Q2 = self.critic(state, action) 145 | critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q) 146 | 147 | self.critic_optimizer.zero_grad() 148 | critic_loss.backward() 149 | self.critic_optimizer.step() 150 | 151 | sampled_actions = self.vae.decode(state) 152 | perturbed_actions = self.actor(state, sampled_actions) 153 | 154 | actor_loss = -self.critic.q1(state, perturbed_actions).mean() 155 | 156 | self.actor_optimizer.zero_grad() 157 | actor_loss.backward() 158 | self.actor_optimizer.step() 159 | 160 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 161 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 162 | 163 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): 164 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) -------------------------------------------------------------------------------- /algos/offrl/bcq/storage.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class ReplayBuffer(object): 6 | def __init__(self, state_dim, action_dim, device, max_size=int(1e6)): 7 | self.max_size = max_size 8 | self.size = 0 9 | 10 | self.state = np.zeros((max_size, state_dim)) 11 | self.action = np.zeros((max_size, action_dim)) 12 | self.next_state = np.zeros((max_size, state_dim)) 13 | self.reward = np.zeros((max_size, 1)) 14 | self.not_done = np.zeros((max_size, 1)) 15 | 16 | self.device = device 17 | 18 | 19 | def sample(self, batch_size): 20 | ind = np.random.randint(0, self.size, size=batch_size) 21 | 22 | return ( 23 | torch.FloatTensor(self.state[ind]).to(self.device), 24 | torch.FloatTensor(self.action[ind]).to(self.device), 25 | torch.FloatTensor(self.next_state[ind]).to(self.device), 26 | torch.FloatTensor(self.reward[ind]).to(self.device), 27 | torch.FloatTensor(self.not_done[ind]).to(self.device) 28 | ) 29 | 30 | 31 | def convert(self, data_dir): 32 | self.state = np.load(data_dir+'states.npy') 33 | self.action = np.load(data_dir+'actions.npy') 34 | self.next_state = np.load(data_dir+'next_states.npy') 35 | self.reward = np.load(data_dir+'rewards.npy') 36 | self.not_done = 1. 
- np.load(data_dir+'dones.npy') 37 | self.size = self.state.shape[0] 38 | 39 | -------------------------------------------------------------------------------- /algos/offrl/iql/__init__.py: -------------------------------------------------------------------------------- 1 | from .storage import ReplayBuffer 2 | from .module import IQL_Model 3 | from .iql import IQL -------------------------------------------------------------------------------- /algos/offrl/iql/iql.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | 9 | from bidexhands.algorithms.offrl.iql import IQL_Model 10 | from bidexhands.algorithms.offrl.iql import ReplayBuffer 11 | 12 | class IQL: 13 | 14 | def __init__(self, 15 | vec_env, 16 | device='cpu', 17 | discount = 0.99, 18 | tau = 0.005, 19 | expectile = 0.7, 20 | batch_size = 100, 21 | max_timesteps = 1000000, 22 | iterations = 10000, 23 | log_dir = '', 24 | datatype = 'expert', 25 | algo = 'iql'): 26 | 27 | self.observation_space = vec_env.observation_space 28 | self.action_space = vec_env.action_space 29 | self.vec_env = vec_env 30 | self.device = device 31 | self.discount = discount 32 | self.tau = tau 33 | self.expectile = expectile 34 | self.batch_size = batch_size 35 | self.max_timesteps = max_timesteps 36 | self.iterations = iterations 37 | self.beta=3.0 38 | self.log_dir = log_dir 39 | self.datatype = datatype 40 | self.algo = algo 41 | self.log_dir = self.log_dir.split(self.algo)[0]+self.algo+'/'+self.datatype+'/' 42 | self.data_dir = self.log_dir.split(self.algo)[0].split('logs') 43 | self.data_dir = self.data_dir[0]+'data'+self.data_dir[1]+self.datatype+'/' 44 | self.test_step = 40000 45 | 46 | if not os.path.exists(self.log_dir): 47 | os.makedirs(self.log_dir) 48 | time.sleep(np.random.rand()*2) 49 | order = len(os.listdir(self.log_dir)) 50 | self.reward_log = open(self.log_dir+str(order)+'.log','w') 51 | 52 | def run(self, num_learning_iterations, log_interval=1): 53 | 54 | current_obs = self.vec_env.reset() 55 | state_dim = self.observation_space.shape[0] 56 | action_dim = self.action_space.shape[0] 57 | max_action = float(self.action_space.high[0]) 58 | 59 | replay_buffer = ReplayBuffer(state_dim, action_dim, self.device) 60 | replay_buffer.convert(self.data_dir) 61 | 62 | policy = IQL_Model(state_dim, action_dim, max_action, self.device, self.discount, self.tau, self.expectile, self.beta) 63 | 64 | for t in range(int(self.max_timesteps/self.iterations)+1): 65 | policy.train(replay_buffer, self.iterations, self.batch_size) 66 | 67 | reward_sum = [] 68 | cur_reward_sum = torch.zeros(self.vec_env.num_envs, dtype=torch.float, device=self.device) 69 | current_obs = self.vec_env.reset() 70 | for _ in range(int(self.test_step/self.vec_env.num_envs)): 71 | actions = policy.select_action(current_obs) 72 | next_obs, rews, dones, infos = self.vec_env.step(actions) 73 | current_obs.copy_(next_obs) 74 | cur_reward_sum[:] += rews 75 | new_ids = (dones > 0).nonzero(as_tuple=False) 76 | reward_sum.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist()) 77 | cur_reward_sum[new_ids] = 0 78 | 79 | self.reward_log.write(str(sum(reward_sum)/len(reward_sum))+'\n') 80 | self.reward_log.flush() -------------------------------------------------------------------------------- /algos/offrl/iql/storage.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 
import torch 3 | 4 | 5 | class ReplayBuffer(object): 6 | def __init__(self, state_dim, action_dim, device, max_size=int(1e6)): 7 | self.max_size = max_size 8 | self.size = 0 9 | 10 | self.state = np.zeros((max_size, state_dim)) 11 | self.action = np.zeros((max_size, action_dim)) 12 | self.next_state = np.zeros((max_size, state_dim)) 13 | self.reward = np.zeros((max_size, 1)) 14 | self.not_done = np.zeros((max_size, 1)) 15 | 16 | self.device = device 17 | 18 | 19 | def sample(self, batch_size): 20 | ind = np.random.randint(0, self.size, size=batch_size) 21 | 22 | return ( 23 | torch.FloatTensor(self.state[ind]).to(self.device), 24 | torch.FloatTensor(self.action[ind]).to(self.device), 25 | torch.FloatTensor(self.next_state[ind]).to(self.device), 26 | torch.FloatTensor(self.reward[ind]).to(self.device), 27 | torch.FloatTensor(self.not_done[ind]).to(self.device) 28 | ) 29 | 30 | 31 | def convert(self, data_dir): 32 | self.state = np.load(data_dir+'states.npy') 33 | self.action = np.load(data_dir+'actions.npy') 34 | self.next_state = np.load(data_dir+'next_states.npy') 35 | self.reward = np.load(data_dir+'rewards.npy') 36 | self.not_done = 1. - np.load(data_dir+'dones.npy') 37 | self.size = self.state.shape[0] 38 | 39 | -------------------------------------------------------------------------------- /algos/offrl/ppo_collect/__init__.py: -------------------------------------------------------------------------------- 1 | from .storage import RolloutStorage 2 | from .module import ActorCritic 3 | from .ppo_collect import PPO 4 | -------------------------------------------------------------------------------- /algos/offrl/ppo_collect/module.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.distributions import MultivariateNormal 6 | 7 | 8 | class ActorCritic(nn.Module): 9 | 10 | def __init__(self, obs_shape, states_shape, actions_shape, initial_std, model_cfg, asymmetric=False): 11 | super(ActorCritic, self).__init__() 12 | 13 | self.asymmetric = asymmetric 14 | 15 | if model_cfg is None: 16 | actor_hidden_dim = [256, 256, 256] 17 | critic_hidden_dim = [256, 256, 256] 18 | activation = get_activation("selu") 19 | else: 20 | actor_hidden_dim = model_cfg['pi_hid_sizes'] 21 | critic_hidden_dim = model_cfg['vf_hid_sizes'] 22 | activation = get_activation(model_cfg['activation']) 23 | 24 | # Policy 25 | actor_layers = [] 26 | actor_layers.append(nn.Linear(*obs_shape, actor_hidden_dim[0])) 27 | actor_layers.append(activation) 28 | for l in range(len(actor_hidden_dim)): 29 | if l == len(actor_hidden_dim) - 1: 30 | actor_layers.append(nn.Linear(actor_hidden_dim[l], *actions_shape)) 31 | else: 32 | actor_layers.append(nn.Linear(actor_hidden_dim[l], actor_hidden_dim[l + 1])) 33 | actor_layers.append(activation) 34 | self.actor = nn.Sequential(*actor_layers) 35 | 36 | # Value function 37 | critic_layers = [] 38 | if self.asymmetric: 39 | critic_layers.append(nn.Linear(*states_shape, critic_hidden_dim[0])) 40 | else: 41 | critic_layers.append(nn.Linear(*obs_shape, critic_hidden_dim[0])) 42 | critic_layers.append(activation) 43 | for l in range(len(critic_hidden_dim)): 44 | if l == len(critic_hidden_dim) - 1: 45 | critic_layers.append(nn.Linear(critic_hidden_dim[l], 1)) 46 | else: 47 | critic_layers.append(nn.Linear(critic_hidden_dim[l], critic_hidden_dim[l + 1])) 48 | critic_layers.append(activation) 49 | self.critic = nn.Sequential(*critic_layers) 50 | 51 | print(self.actor) 52 | 
print(self.critic) 53 | 54 | # Action noise 55 | self.log_std = nn.Parameter(np.log(initial_std) * torch.ones(*actions_shape)) 56 | 57 | # Initialize the weights like in stable baselines 58 | actor_weights = [np.sqrt(2)] * len(actor_hidden_dim) 59 | actor_weights.append(0.01) 60 | critic_weights = [np.sqrt(2)] * len(critic_hidden_dim) 61 | critic_weights.append(1.0) 62 | self.init_weights(self.actor, actor_weights) 63 | self.init_weights(self.critic, critic_weights) 64 | 65 | @staticmethod 66 | def init_weights(sequential, scales): 67 | [torch.nn.init.orthogonal_(module.weight, gain=scales[idx]) for idx, module in 68 | enumerate(mod for mod in sequential if isinstance(mod, nn.Linear))] 69 | 70 | def forward(self): 71 | raise NotImplementedError 72 | 73 | def act(self, observations, states): 74 | actions_mean = self.actor(observations) 75 | 76 | covariance = torch.diag(self.log_std.exp() * self.log_std.exp()) 77 | distribution = MultivariateNormal(actions_mean, scale_tril=covariance) 78 | 79 | actions = distribution.sample() 80 | actions_log_prob = distribution.log_prob(actions) 81 | 82 | if self.asymmetric: 83 | value = self.critic(states) 84 | else: 85 | value = self.critic(observations) 86 | 87 | return actions.detach(), actions_log_prob.detach(), value.detach(), actions_mean.detach(), self.log_std.repeat(actions_mean.shape[0], 1).detach() 88 | 89 | def act_inference(self, observations): 90 | actions_mean = self.actor(observations) 91 | return actions_mean 92 | 93 | def evaluate(self, observations, states, actions): 94 | actions_mean = self.actor(observations) 95 | 96 | covariance = torch.diag(self.log_std.exp() * self.log_std.exp()) 97 | distribution = MultivariateNormal(actions_mean, scale_tril=covariance) 98 | 99 | actions_log_prob = distribution.log_prob(actions) 100 | entropy = distribution.entropy() 101 | 102 | if self.asymmetric: 103 | value = self.critic(states) 104 | else: 105 | value = self.critic(observations) 106 | 107 | return actions_log_prob, entropy, value, actions_mean, self.log_std.repeat(actions_mean.shape[0], 1) 108 | 109 | 110 | def get_activation(act_name): 111 | if act_name == "elu": 112 | return nn.ELU() 113 | elif act_name == "selu": 114 | return nn.SELU() 115 | elif act_name == "relu": 116 | return nn.ReLU() 117 | elif act_name == "crelu": 118 | return nn.ReLU() 119 | elif act_name == "lrelu": 120 | return nn.LeakyReLU() 121 | elif act_name == "tanh": 122 | return nn.Tanh() 123 | elif act_name == "sigmoid": 124 | return nn.Sigmoid() 125 | else: 126 | print("invalid activation function!") 127 | return None 128 | -------------------------------------------------------------------------------- /algos/offrl/ppo_collect/storage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data.sampler import BatchSampler, SequentialSampler, SubsetRandomSampler 3 | 4 | 5 | class RolloutStorage: 6 | 7 | def __init__(self, num_envs, num_transitions_per_env, obs_shape, states_shape, actions_shape, device='cpu', sampler='sequential'): 8 | 9 | self.device = device 10 | self.sampler = sampler 11 | 12 | # Core 13 | self.observations = torch.zeros(num_transitions_per_env, num_envs, *obs_shape, device=self.device) 14 | self.states = torch.zeros(num_transitions_per_env, num_envs, *states_shape, device=self.device) 15 | self.rewards = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) 16 | self.actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device) 17 | 
self.dones = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device).byte() 18 | 19 | # For PPO 20 | self.actions_log_prob = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) 21 | self.values = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) 22 | self.returns = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) 23 | self.advantages = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) 24 | self.mu = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device) 25 | self.sigma = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device) 26 | 27 | self.num_transitions_per_env = num_transitions_per_env 28 | self.num_envs = num_envs 29 | 30 | self.step = 0 31 | 32 | def add_transitions(self, observations, states, actions, rewards, dones, values, actions_log_prob, mu, sigma): 33 | if self.step >= self.num_transitions_per_env: 34 | raise AssertionError("Rollout buffer overflow") 35 | 36 | self.observations[self.step].copy_(observations) 37 | self.states[self.step].copy_(states) 38 | self.actions[self.step].copy_(actions) 39 | self.rewards[self.step].copy_(rewards.view(-1, 1)) 40 | self.dones[self.step].copy_(dones.view(-1, 1)) 41 | self.values[self.step].copy_(values) 42 | self.actions_log_prob[self.step].copy_(actions_log_prob.view(-1, 1)) 43 | self.mu[self.step].copy_(mu) 44 | self.sigma[self.step].copy_(sigma) 45 | 46 | self.step += 1 47 | 48 | def clear(self): 49 | self.step = 0 50 | 51 | def compute_returns(self, last_values, gamma, lam): 52 | advantage = 0 53 | for step in reversed(range(self.num_transitions_per_env)): 54 | if step == self.num_transitions_per_env - 1: 55 | next_values = last_values 56 | else: 57 | next_values = self.values[step + 1] 58 | next_is_not_terminal = 1.0 - self.dones[step].float() 59 | delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step] 60 | advantage = delta + next_is_not_terminal * gamma * lam * advantage 61 | self.returns[step] = advantage + self.values[step] 62 | 63 | # Compute and normalize the advantages 64 | self.advantages = self.returns - self.values 65 | self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8) 66 | 67 | def get_statistics(self): 68 | done = self.dones.cpu() 69 | done[-1] = 1 70 | flat_dones = done.permute(1, 0, 2).reshape(-1, 1) 71 | done_indices = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero(as_tuple=False)[:, 0])) 72 | trajectory_lengths = (done_indices[1:] - done_indices[:-1]) 73 | return trajectory_lengths.float().mean(), self.rewards.mean() 74 | 75 | def mini_batch_generator(self, num_mini_batches): 76 | batch_size = self.num_envs * self.num_transitions_per_env 77 | mini_batch_size = batch_size // num_mini_batches 78 | 79 | if self.sampler == "sequential": 80 | # For physics-based RL, each environment is already randomized. There is no value to doing random sampling 81 | # but a lot of CPU overhead during the PPO process. 
So, we can just switch to a sequential sampler instead 82 | subset = SequentialSampler(range(batch_size)) 83 | elif self.sampler == "random": 84 | subset = SubsetRandomSampler(range(batch_size)) 85 | 86 | batch = BatchSampler(subset, mini_batch_size, drop_last=True) 87 | return batch 88 | -------------------------------------------------------------------------------- /algos/offrl/td3_bc/__init__.py: -------------------------------------------------------------------------------- 1 | from .storage import ReplayBuffer 2 | from .module import TD3_BC_Model 3 | from .td3_bc import TD3_BC -------------------------------------------------------------------------------- /algos/offrl/td3_bc/module.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | class Actor(nn.Module): 8 | def __init__(self, state_dim, action_dim, max_action): 9 | super(Actor, self).__init__() 10 | 11 | self.l1 = nn.Linear(state_dim, 256) 12 | self.l2 = nn.Linear(256, 256) 13 | self.l3 = nn.Linear(256, action_dim) 14 | 15 | self.max_action = max_action 16 | 17 | 18 | def forward(self, state): 19 | a = F.relu(self.l1(state)) 20 | a = F.relu(self.l2(a)) 21 | return self.max_action * torch.tanh(self.l3(a)) 22 | 23 | 24 | class Critic(nn.Module): 25 | def __init__(self, state_dim, action_dim): 26 | super(Critic, self).__init__() 27 | 28 | self.l1 = nn.Linear(state_dim + action_dim, 256) 29 | self.l2 = nn.Linear(256, 256) 30 | self.l3 = nn.Linear(256, 1) 31 | 32 | self.l4 = nn.Linear(state_dim + action_dim, 256) 33 | self.l5 = nn.Linear(256, 256) 34 | self.l6 = nn.Linear(256, 1) 35 | 36 | 37 | def forward(self, state, action): 38 | sa = torch.cat([state, action], 1) 39 | 40 | q1 = F.relu(self.l1(sa)) 41 | q1 = F.relu(self.l2(q1)) 42 | q1 = self.l3(q1) 43 | 44 | q2 = F.relu(self.l4(sa)) 45 | q2 = F.relu(self.l5(q2)) 46 | q2 = self.l6(q2) 47 | return q1, q2 48 | 49 | 50 | def Q1(self, state, action): 51 | sa = torch.cat([state, action], 1) 52 | 53 | q1 = F.relu(self.l1(sa)) 54 | q1 = F.relu(self.l2(q1)) 55 | q1 = self.l3(q1) 56 | return q1 57 | 58 | 59 | class TD3_BC_Model(object): 60 | def __init__( 61 | self, 62 | state_dim, 63 | action_dim, 64 | max_action, 65 | device, 66 | discount=0.99, 67 | tau=0.005, 68 | policy_noise=0.2, 69 | noise_clip=0.5, 70 | policy_freq=2, 71 | alpha=2.5, 72 | ): 73 | 74 | self.actor = Actor(state_dim, action_dim, max_action).to(device) 75 | self.actor_target = copy.deepcopy(self.actor) 76 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4) 77 | 78 | self.critic = Critic(state_dim, action_dim).to(device) 79 | self.critic_target = copy.deepcopy(self.critic) 80 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4) 81 | 82 | self.max_action = max_action 83 | self.discount = discount 84 | self.tau = tau 85 | self.policy_noise = policy_noise 86 | self.noise_clip = noise_clip 87 | self.policy_freq = policy_freq 88 | self.alpha = alpha 89 | 90 | self.total_it = 0 91 | 92 | def select_action(self, state): 93 | with torch.no_grad(): 94 | action = self.actor(state) 95 | return action 96 | 97 | def train(self, replay_buffer, interaction, batch_size=256): 98 | 99 | for _ in range(interaction): 100 | self.total_it += 1 101 | state, action, next_state, reward, not_done = replay_buffer.sample(batch_size) 102 | 103 | with torch.no_grad(): 104 | 105 | noise = ( 106 | torch.randn_like(action) * self.policy_noise 107 | 
).clamp(-self.noise_clip, self.noise_clip) 108 | 109 | next_action = ( 110 | self.actor_target(next_state) + noise 111 | ).clamp(-self.max_action, self.max_action) 112 | 113 | target_Q1, target_Q2 = self.critic_target(next_state, next_action) 114 | target_Q = torch.min(target_Q1, target_Q2) 115 | target_Q = reward + not_done * self.discount * target_Q 116 | 117 | current_Q1, current_Q2 = self.critic(state, action) 118 | critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q) 119 | self.critic_optimizer.zero_grad() 120 | critic_loss.backward() 121 | self.critic_optimizer.step() 122 | 123 | if self.total_it % self.policy_freq == 0: 124 | 125 | pi = self.actor(state) 126 | Q = self.critic.Q1(state, pi) 127 | lmbda = self.alpha/Q.abs().mean().detach() 128 | 129 | actor_loss = -lmbda * Q.mean() + F.mse_loss(pi, action) 130 | self.actor_optimizer.zero_grad() 131 | actor_loss.backward() 132 | self.actor_optimizer.step() 133 | 134 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 135 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 136 | 137 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): 138 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 139 | 140 | -------------------------------------------------------------------------------- /algos/offrl/td3_bc/storage.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class ReplayBuffer(object): 6 | def __init__(self, state_dim, action_dim, device, max_size=int(1e6)): 7 | self.max_size = max_size 8 | self.size = 0 9 | 10 | self.state = np.zeros((max_size, state_dim)) 11 | self.action = np.zeros((max_size, action_dim)) 12 | self.next_state = np.zeros((max_size, state_dim)) 13 | self.reward = np.zeros((max_size, 1)) 14 | self.not_done = np.zeros((max_size, 1)) 15 | 16 | self.device = device 17 | 18 | 19 | def sample(self, batch_size): 20 | ind = np.random.randint(0, self.size, size=batch_size) 21 | 22 | return ( 23 | torch.FloatTensor(self.state[ind]).to(self.device), 24 | torch.FloatTensor(self.action[ind]).to(self.device), 25 | torch.FloatTensor(self.next_state[ind]).to(self.device), 26 | torch.FloatTensor(self.reward[ind]).to(self.device), 27 | torch.FloatTensor(self.not_done[ind]).to(self.device) 28 | ) 29 | 30 | 31 | def convert(self, data_dir): 32 | self.state = np.load(data_dir+'states.npy') 33 | self.action = np.load(data_dir+'actions.npy') 34 | self.next_state = np.load(data_dir+'next_states.npy') 35 | self.reward = np.load(data_dir+'rewards.npy') 36 | self.not_done = 1. 
- np.load(data_dir+'dones.npy') 37 | self.size = self.state.shape[0] 38 | 39 | -------------------------------------------------------------------------------- /algos/offrl/td3_bc/td3_bc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | 9 | from bidexhands.algorithms.offrl.td3_bc import TD3_BC_Model 10 | from bidexhands.algorithms.offrl.td3_bc import ReplayBuffer 11 | 12 | class TD3_BC: 13 | 14 | def __init__(self, 15 | vec_env, 16 | device='cpu', 17 | discount = 0.99, 18 | tau = 0.005, 19 | alpha = 2.5, 20 | policy_freq = 2, 21 | batch_size = 250, 22 | max_timesteps = 1000000, 23 | iterations = 10000, 24 | log_dir = '', 25 | datatype = 'expert', 26 | algo = 'td3_bc'): 27 | 28 | self.observation_space = vec_env.observation_space 29 | self.action_space = vec_env.action_space 30 | self.vec_env = vec_env 31 | self.device = device 32 | self.discount = discount 33 | self.tau = tau 34 | self.policy_freq = policy_freq 35 | self.alpha = alpha 36 | self.batch_size = batch_size 37 | self.max_timesteps = max_timesteps 38 | self.iterations = iterations 39 | self.log_dir = log_dir 40 | self.datatype = datatype 41 | self.algo = algo 42 | self.log_dir = self.log_dir.split(self.algo)[0]+self.algo+'/'+self.datatype+'/' 43 | self.data_dir = self.log_dir.split(self.algo)[0].split('logs') 44 | self.data_dir = self.data_dir[0]+'data'+self.data_dir[1]+self.datatype+'/' 45 | self.test_step = 40000 46 | 47 | if not os.path.exists(self.log_dir): 48 | os.makedirs(self.log_dir) 49 | time.sleep(np.random.rand()*2) 50 | order = len(os.listdir(self.log_dir)) 51 | self.reward_log = open(self.log_dir+str(order)+'.log','w') 52 | 53 | 54 | def run(self, num_learning_iterations, log_interval=1): 55 | 56 | current_obs = self.vec_env.reset() 57 | state_dim = self.observation_space.shape[0] 58 | action_dim = self.action_space.shape[0] 59 | max_action = float(self.action_space.high[0]) 60 | policy_noise = 0.2 * max_action 61 | noise_clip = 0.5 * max_action 62 | 63 | policy = TD3_BC_Model(state_dim, action_dim, max_action, self.device, self.discount, self.tau, policy_noise, noise_clip, self.policy_freq, self.alpha) 64 | 65 | replay_buffer = ReplayBuffer(state_dim, action_dim,self.device) 66 | replay_buffer.convert(self.data_dir) 67 | 68 | for t in range(int(self.max_timesteps/self.iterations)+1): 69 | policy.train(replay_buffer, self.iterations, self.batch_size) 70 | 71 | reward_sum = [] 72 | cur_reward_sum = torch.zeros(self.vec_env.num_envs, dtype=torch.float, device=self.device) 73 | current_obs = self.vec_env.reset() 74 | for _ in range(int(self.test_step/self.vec_env.num_envs)): 75 | actions = policy.select_action(current_obs) 76 | next_obs, rews, dones, infos = self.vec_env.step(actions) 77 | current_obs.copy_(next_obs) 78 | cur_reward_sum[:] += rews 79 | new_ids = (dones > 0).nonzero(as_tuple=False) 80 | reward_sum.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist()) 81 | cur_reward_sum[new_ids] = 0 82 | 83 | self.reward_log.write(str(sum(reward_sum)/len(reward_sum))+'\n') 84 | self.reward_log.flush() 85 | 86 | 87 | -------------------------------------------------------------------------------- /algos/planner/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | from isaacgym import gymapi 5 | from scipy.spatial.transform import Rotation 6 | from 
scipy.spatial.transform import Slerp 7 | 8 | class BASE: 9 | def __init__(self, vec_env, cfg, ): 10 | self.vec_env = vec_env 11 | self.env_name = cfg['name'].rsplit('@', 1)[0] 12 | self.traj_path = cfg['traj_path'] 13 | self.traj_basedir = os.path.dirname(self.traj_path) 14 | self.traj_name = os.path.basename(self.traj_path).split('.')[0][7:] 15 | self.dummy_traj = pickle.load(open(self.traj_path, 'rb'))['trajectory'] 16 | 17 | def transfer(self, env_id=0): 18 | if self.env_name in ['frankakitchen_v1@gripper_open_hingecabinet_left', 'frankakitchen_v1@gripper_close_hingecabinet_left', 19 | 'frankakitchen_v1@gripper_open_microwave', 'frankakitchen_v1@gripper_close_microwave', ]: 20 | quat_default = np.array([0., 1., 0., 0.]) 21 | elif self.env_name in ['frankakitchen_v1@gripper_open_slidecabinet', 'frankakitchen_v1@gripper_close_slidecabinet', 22 | 'frankakitchen_v1@gripper_push_kettle', 'frankakitchen_v1@gripper_pickup_kettle', 23 | 'partmanip@gripper_open_wooddrawer_middle', 'partmanip@gripper_close_wooddrawer_middle', 24 | 'maniskill@gripper_turn_leftfaucet', 'maniskill@gripper_turn_rightfaucet']: 25 | quat_default = np.array([0.707, 0.707, 0., 0.]) 26 | else: 27 | raise NotImplementedError 28 | is_attached = False 29 | related_tm = None 30 | self.execute_traj = [] ## pandahand-atrtactor pose, gripper effort (positive leads to open) 31 | for i_step, act in enumerate(self.dummy_traj): # print(f"slidecabinet: {act['rigid_bodies'][env_id, act['attached_body_handle'][env_id], :7]}") 32 | if act['attached_info_indices'] == -1: 33 | pos = act['panda_hand'][env_id, :3] 34 | # quat = act['panda_hand'][env_id, 3:7] 35 | quat = quat_default 36 | attractor_pose = gymapi.Transform() 37 | attractor_pose.p = gymapi.Vec3(*pos) 38 | 39 | # rot_mat = Rotation.from_quat(quat).as_matrix() @ Rotation.from_euler('xyz', [3.14, -1.57, 3.14]).as_matrix() 40 | rot_mat = Rotation.from_quat(quat).as_matrix() 41 | rot_quat = Rotation.from_matrix(rot_mat).as_quat() 42 | attractor_pose.r = gymapi.Quat(*rot_quat) 43 | gripper_effort = 100. 
44 | else: 45 | attached_info = act['attach_info'][act['attached_info_indices'][env_id]] 46 | actor_trans = act['attach_info'][act['attached_info_indices'][env_id]]['object_trans'] 47 | actor_rot = act['attach_info'][act['attached_info_indices'][env_id]]['object_rot'] 48 | attach_trans = attached_info['attach_info']['translation'] 49 | attach_rot = attached_info['attach_info']['rotation_matrix'] @ Rotation.from_euler('xyz', [0., 1.57, 0.]).as_matrix() 50 | 51 | attach_trans = actor_rot @ attach_trans + actor_trans 52 | attach_rot = actor_rot @ attach_rot 53 | attach_trans = attach_trans - attach_rot @ np.array([0., 0., 0.058]) 54 | attach_tm = np.eye(4) 55 | attach_tm[:3, :3] = attach_rot 56 | attach_tm[:3, 3] = attach_trans 57 | attach_rot = Rotation.from_matrix(attach_rot).as_quat() 58 | if not is_attached: 59 | is_attached = True 60 | body_trans = act['rigid_bodies'][env_id, act['attached_body_handle'][env_id], :3] 61 | body_rot = act['rigid_bodies'][env_id, act['attached_body_handle'][env_id], 3:7] 62 | body_rot = Rotation.from_quat(body_rot).as_matrix() 63 | body_tm = np.eye(4) 64 | body_tm[:3, :3] = body_rot 65 | body_tm[:3, 3] = body_trans 66 | related_tm = attach_tm @ np.linalg.inv(body_tm) 67 | else: 68 | body_trans = act['rigid_bodies'][env_id, act['attached_body_handle'][env_id], :3] 69 | body_rot = act['rigid_bodies'][env_id, act['attached_body_handle'][env_id], 3:7] 70 | body_rot = Rotation.from_quat(body_rot).as_matrix() 71 | # attach_trans = body_trans + related_rot @ related_trans 72 | # attach_rot = Rotation.from_matrix(related_rot @ body_rot).as_quat() 73 | body_tm_t = np.eye(4) 74 | body_tm_t[:3, :3] = body_rot 75 | body_tm_t[:3, 3] = body_trans 76 | body_tm_t = body_tm_t @ np.linalg.inv(body_tm) 77 | attach_tm = body_tm_t @ related_tm @ body_tm 78 | attach_trans = attach_tm[:3, 3] 79 | attach_rot = Rotation.from_matrix(attach_tm[:3, :3]).as_quat() 80 | 81 | attractor_pose = gymapi.Transform() 82 | attractor_pose.p = gymapi.Vec3(*attach_trans) 83 | attractor_pose.r = gymapi.Quat(*attach_rot) 84 | gripper_effort = -500. 
85 | self.execute_traj.append([attractor_pose, gripper_effort]) 86 | 87 | # smooth the excute_traj on attractor_pose 88 | SMOOTH_STEPS = 10 89 | transition_flag = False 90 | execute_traj_smooth = [] 91 | for i_step in range(len(self.execute_traj) - 1): 92 | lower_pos = np.array([self.execute_traj[i_step][0].p.x, self.execute_traj[i_step][0].p.y, self.execute_traj[i_step][0].p.z]) 93 | upper_pos = np.array([self.execute_traj[i_step + 1][0].p.x, self.execute_traj[i_step + 1][0].p.y, self.execute_traj[i_step + 1][0].p.z]) 94 | lower_quat = np.array([self.execute_traj[i_step][0].r.x, self.execute_traj[i_step][0].r.y, self.execute_traj[i_step][0].r.z, self.execute_traj[i_step][0].r.w]) 95 | upper_quat = np.array([self.execute_traj[i_step + 1][0].r.x, self.execute_traj[i_step + 1][0].r.y, self.execute_traj[i_step + 1][0].r.z, self.execute_traj[i_step + 1][0].r.w]) 96 | interp_rot = Slerp([0, 1], Rotation.from_quat([lower_quat, upper_quat])) 97 | smooth_steps = SMOOTH_STEPS 98 | if self.dummy_traj[i_step + 1]['attached_info_indices'] != -1 and not transition_flag: 99 | smooth_steps = 100 100 | transition_flag = True 101 | for i_smooth in range(smooth_steps): 102 | i_smooth_pos = lower_pos + (upper_pos - lower_pos) * i_smooth / smooth_steps 103 | i_smooth_quat = interp_rot(i_smooth / smooth_steps).as_quat() 104 | attractor_pose = gymapi.Transform() 105 | attractor_pose.p = gymapi.Vec3(*i_smooth_pos) 106 | attractor_pose.r = gymapi.Quat(*i_smooth_quat) 107 | gripper_effort = self.execute_traj[i_step][1] 108 | execute_traj_smooth.append([attractor_pose, gripper_effort]) 109 | self.execute_traj = execute_traj_smooth 110 | 111 | def run(self): 112 | # transfer to get the excute_traj 113 | self.transfer() 114 | scores = [] 115 | for i_step, act in enumerate(self.execute_traj): 116 | infos = self.vec_env.task.step_plan(act) 117 | print(f'Step: {i_step} | Score: {infos["success_scores"].cpu().item()}') 118 | scores.append(infos["success_scores"].cpu().item()) 119 | execres = { 120 | 'scores': scores, 121 | } 122 | execres_path = os.path.join(self.traj_basedir, f'execres_{self.traj_name}.pkl') 123 | pickle.dump(execres, open(execres_path, 'wb')) 124 | print(f'Save execres to {execres_path}') -------------------------------------------------------------------------------- /algos/rl/ppo/__init__.py: -------------------------------------------------------------------------------- 1 | from .storage import RolloutStorage 2 | from .module import ActorCritic 3 | from .ppo import PPO 4 | -------------------------------------------------------------------------------- /algos/rl/ppo/module.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.distributions import MultivariateNormal 6 | 7 | 8 | class ActorCritic(nn.Module): 9 | 10 | def __init__(self, obs_shape, states_shape, actions_shape, initial_std, model_cfg, asymmetric=False): 11 | super(ActorCritic, self).__init__() 12 | 13 | self.asymmetric = asymmetric 14 | 15 | if model_cfg is None: 16 | actor_hidden_dim = [256, 256, 256] 17 | critic_hidden_dim = [256, 256, 256] 18 | activation = get_activation("selu") 19 | else: 20 | actor_hidden_dim = model_cfg['pi_hid_sizes'] 21 | critic_hidden_dim = model_cfg['vf_hid_sizes'] 22 | activation = get_activation(model_cfg['activation']) 23 | 24 | # Policy 25 | actor_layers = [] 26 | actor_layers.append(nn.Linear(*obs_shape, actor_hidden_dim[0])) 27 | actor_layers.append(activation) 28 | for l in 
range(len(actor_hidden_dim)): 29 | if l == len(actor_hidden_dim) - 1: 30 | actor_layers.append(nn.Linear(actor_hidden_dim[l], *actions_shape)) 31 | else: 32 | actor_layers.append(nn.Linear(actor_hidden_dim[l], actor_hidden_dim[l + 1])) 33 | actor_layers.append(activation) 34 | self.actor = nn.Sequential(*actor_layers) 35 | 36 | # Value function 37 | critic_layers = [] 38 | if self.asymmetric: 39 | critic_layers.append(nn.Linear(*states_shape, critic_hidden_dim[0])) 40 | else: 41 | critic_layers.append(nn.Linear(*obs_shape, critic_hidden_dim[0])) 42 | critic_layers.append(activation) 43 | for l in range(len(critic_hidden_dim)): 44 | if l == len(critic_hidden_dim) - 1: 45 | critic_layers.append(nn.Linear(critic_hidden_dim[l], 1)) 46 | else: 47 | critic_layers.append(nn.Linear(critic_hidden_dim[l], critic_hidden_dim[l + 1])) 48 | critic_layers.append(activation) 49 | self.critic = nn.Sequential(*critic_layers) 50 | 51 | print(self.actor) 52 | print(self.critic) 53 | 54 | # Action noise 55 | self.log_std = nn.Parameter(np.log(initial_std) * torch.ones(*actions_shape)) 56 | 57 | # Initialize the weights like in stable baselines 58 | actor_weights = [np.sqrt(2)] * len(actor_hidden_dim) 59 | actor_weights.append(0.01) 60 | critic_weights = [np.sqrt(2)] * len(critic_hidden_dim) 61 | critic_weights.append(1.0) 62 | self.init_weights(self.actor, actor_weights) 63 | self.init_weights(self.critic, critic_weights) 64 | 65 | @staticmethod 66 | def init_weights(sequential, scales): 67 | [torch.nn.init.orthogonal_(module.weight, gain=scales[idx]) for idx, module in 68 | enumerate(mod for mod in sequential if isinstance(mod, nn.Linear))] 69 | 70 | def forward(self): 71 | raise NotImplementedError 72 | 73 | def act(self, observations, states): 74 | actions_mean = self.actor(observations) 75 | 76 | covariance = torch.diag(self.log_std.exp() * self.log_std.exp()) 77 | distribution = MultivariateNormal(actions_mean, scale_tril=covariance) 78 | 79 | actions = distribution.sample() 80 | actions_log_prob = distribution.log_prob(actions) 81 | 82 | if self.asymmetric: 83 | value = self.critic(states) 84 | else: 85 | value = self.critic(observations) 86 | 87 | return actions.detach(), actions_log_prob.detach(), value.detach(), actions_mean.detach(), self.log_std.repeat(actions_mean.shape[0], 1).detach() 88 | 89 | def act_inference(self, observations): 90 | actions_mean = self.actor(observations) 91 | return actions_mean 92 | 93 | def evaluate(self, observations, states, actions): 94 | actions_mean = self.actor(observations) 95 | 96 | covariance = torch.diag(self.log_std.exp() * self.log_std.exp()) 97 | distribution = MultivariateNormal(actions_mean, scale_tril=covariance) 98 | 99 | actions_log_prob = distribution.log_prob(actions) 100 | entropy = distribution.entropy() 101 | 102 | if self.asymmetric: 103 | value = self.critic(states) 104 | else: 105 | value = self.critic(observations) 106 | 107 | return actions_log_prob, entropy, value, actions_mean, self.log_std.repeat(actions_mean.shape[0], 1) 108 | 109 | 110 | def get_activation(act_name): 111 | if act_name == "elu": 112 | return nn.ELU() 113 | elif act_name == "selu": 114 | return nn.SELU() 115 | elif act_name == "relu": 116 | return nn.ReLU() 117 | elif act_name == "crelu": 118 | return nn.ReLU() 119 | elif act_name == "lrelu": 120 | return nn.LeakyReLU() 121 | elif act_name == "tanh": 122 | return nn.Tanh() 123 | elif act_name == "sigmoid": 124 | return nn.Sigmoid() 125 | else: 126 | print("invalid activation function!") 127 | return None 128 | 
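A minimal usage sketch of how the ActorCritic above and the RolloutStorage defined in the next file are typically wired together for a single PPO rollout. This is illustrative only and not a file from the repository: the import paths are inferred from the directory tree above, and the dimensions, hyper-parameters, and the random stand-in for vec_env.step() are made-up values.

import torch

from algos.rl.ppo.module import ActorCritic      # module paths assumed from the tree above
from algos.rl.ppo.storage import RolloutStorage

num_envs, steps, obs_dim, act_dim = 4, 8, 16, 7  # hypothetical sizes

model = ActorCritic(obs_shape=(obs_dim,), states_shape=(obs_dim,), actions_shape=(act_dim,),
                    initial_std=1.0, model_cfg=None, asymmetric=False)
storage = RolloutStorage(num_envs, steps, (obs_dim,), (obs_dim,), (act_dim,), device='cpu')

obs = torch.zeros(num_envs, obs_dim)
for _ in range(steps):
    # act() samples from a diagonal Gaussian built from the actor mean and the learned log_std
    actions, log_prob, values, mu, sigma = model.act(obs, obs)
    next_obs = torch.randn(num_envs, obs_dim)     # stand-in for vec_env.step(actions)
    rewards = torch.randn(num_envs)
    dones = torch.zeros(num_envs)
    storage.add_transitions(obs, obs, actions, rewards, dones, values, log_prob, mu, sigma)
    obs = next_obs

with torch.no_grad():
    last_values = model.critic(obs)
# GAE(gamma, lambda) over the stored rewards and values, followed by advantage normalization
storage.compute_returns(last_values, gamma=0.99, lam=0.95)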
-------------------------------------------------------------------------------- /algos/rl/ppo/storage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data.sampler import BatchSampler, SequentialSampler, SubsetRandomSampler 3 | 4 | 5 | class RolloutStorage: 6 | 7 | def __init__(self, num_envs, num_transitions_per_env, obs_shape, states_shape, actions_shape, device='cpu', sampler='sequential'): 8 | 9 | self.device = device 10 | self.sampler = sampler 11 | 12 | # Core 13 | self.observations = torch.zeros(num_transitions_per_env, num_envs, *obs_shape, device=self.device) 14 | self.states = torch.zeros(num_transitions_per_env, num_envs, *states_shape, device=self.device) 15 | self.rewards = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) 16 | self.actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device) 17 | self.dones = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device).byte() 18 | 19 | # For PPO 20 | self.actions_log_prob = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) 21 | self.values = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) 22 | self.returns = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) 23 | self.advantages = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) 24 | self.mu = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device) 25 | self.sigma = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device) 26 | 27 | self.num_transitions_per_env = num_transitions_per_env 28 | self.num_envs = num_envs 29 | 30 | self.step = 0 31 | 32 | def add_transitions(self, observations, states, actions, rewards, dones, values, actions_log_prob, mu, sigma): 33 | if self.step >= self.num_transitions_per_env: 34 | raise AssertionError("Rollout buffer overflow") 35 | 36 | self.observations[self.step].copy_(observations) 37 | self.states[self.step].copy_(states) 38 | self.actions[self.step].copy_(actions) 39 | self.rewards[self.step].copy_(rewards.view(-1, 1)) 40 | self.dones[self.step].copy_(dones.view(-1, 1)) 41 | self.values[self.step].copy_(values) 42 | self.actions_log_prob[self.step].copy_(actions_log_prob.view(-1, 1)) 43 | self.mu[self.step].copy_(mu) 44 | self.sigma[self.step].copy_(sigma) 45 | 46 | self.step += 1 47 | 48 | def clear(self): 49 | self.step = 0 50 | 51 | def compute_returns(self, last_values, gamma, lam): 52 | advantage = 0 53 | for step in reversed(range(self.num_transitions_per_env)): 54 | if step == self.num_transitions_per_env - 1: 55 | next_values = last_values 56 | else: 57 | next_values = self.values[step + 1] 58 | next_is_not_terminal = 1.0 - self.dones[step].float() 59 | delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step] 60 | advantage = delta + next_is_not_terminal * gamma * lam * advantage 61 | self.returns[step] = advantage + self.values[step] 62 | 63 | # Compute and normalize the advantages 64 | self.advantages = self.returns - self.values 65 | self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8) 66 | 67 | def get_statistics(self): 68 | done = self.dones.cpu() 69 | done[-1] = 1 70 | flat_dones = done.permute(1, 0, 2).reshape(-1, 1) 71 | done_indices = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero(as_tuple=False)[:, 0])) 72 | trajectory_lengths = (done_indices[1:] - 
done_indices[:-1]) 73 | return trajectory_lengths.float().mean(), self.rewards.mean() 74 | 75 | def mini_batch_generator(self, num_mini_batches): 76 | batch_size = self.num_envs * self.num_transitions_per_env 77 | mini_batch_size = batch_size // num_mini_batches 78 | 79 | if self.sampler == "sequential": 80 | # For physics-based RL, each environment is already randomized. There is no value to doing random sampling 81 | # but a lot of CPU overhead during the PPO process. So, we can just switch to a sequential sampler instead 82 | subset = SequentialSampler(range(batch_size)) 83 | elif self.sampler == "random": 84 | subset = SubsetRandomSampler(range(batch_size)) 85 | 86 | batch = BatchSampler(subset, mini_batch_size, drop_last=True) 87 | return batch 88 | -------------------------------------------------------------------------------- /algos/utils/cnn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .util import init 3 | 4 | """CNN Modules and utils.""" 5 | 6 | class Flatten(nn.Module): 7 | def forward(self, x): 8 | return x.view(x.size(0), -1) 9 | 10 | 11 | class CNNLayer(nn.Module): 12 | def __init__(self, obs_shape, hidden_size, use_orthogonal, use_ReLU, kernel_size=3, stride=1): 13 | super(CNNLayer, self).__init__() 14 | 15 | active_func = [nn.Tanh(), nn.ReLU()][use_ReLU] 16 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal] 17 | gain = nn.init.calculate_gain(['tanh', 'relu'][use_ReLU]) 18 | 19 | def init_(m): 20 | return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain=gain) 21 | 22 | input_channel = obs_shape[0] 23 | input_width = obs_shape[1] 24 | input_height = obs_shape[2] 25 | 26 | self.cnn = nn.Sequential( 27 | init_(nn.Conv2d(in_channels=input_channel, 28 | out_channels=hidden_size // 2, 29 | kernel_size=kernel_size, 30 | stride=stride) 31 | ), 32 | active_func, 33 | Flatten(), 34 | init_(nn.Linear(hidden_size // 2 * (input_width - kernel_size + stride) * (input_height - kernel_size + stride), 35 | hidden_size) 36 | ), 37 | active_func, 38 | init_(nn.Linear(hidden_size, hidden_size)), active_func) 39 | 40 | def forward(self, x): 41 | x = x / 255.0 42 | x = self.cnn(x) 43 | return x 44 | 45 | 46 | class CNNBase(nn.Module): 47 | def __init__(self, args, obs_shape): 48 | super(CNNBase, self).__init__() 49 | 50 | self._use_orthogonal = args.use_orthogonal 51 | self._use_ReLU = args.use_ReLU 52 | self.hidden_size = args.hidden_size 53 | 54 | self.cnn = CNNLayer(obs_shape, self.hidden_size, self._use_orthogonal, self._use_ReLU) 55 | 56 | def forward(self, x): 57 | x = self.cnn(x) 58 | return x 59 | -------------------------------------------------------------------------------- /algos/utils/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .util import init 4 | 5 | """ 6 | Modify standard PyTorch distributions so they to make compatible with this codebase. 
7 | """ 8 | 9 | # 10 | # Standardize distribution interfaces 11 | # 12 | 13 | # Categorical 14 | class FixedCategorical(torch.distributions.Categorical): 15 | def sample(self): 16 | return super().sample().unsqueeze(-1) 17 | 18 | def log_probs(self, actions): 19 | return ( 20 | super() 21 | .log_prob(actions.squeeze(-1)) 22 | .view(actions.size(0), -1) 23 | .sum(-1) 24 | .unsqueeze(-1) 25 | ) 26 | 27 | def mode(self): 28 | return self.probs.argmax(dim=-1, keepdim=True) 29 | 30 | 31 | # Normal 32 | class FixedNormal(torch.distributions.Normal): 33 | def log_probs(self, actions): 34 | return super().log_prob(actions) 35 | # return super().log_prob(actions).sum(-1, keepdim=True) 36 | 37 | def entrop(self): 38 | return super.entropy().sum(-1) 39 | 40 | def mode(self): 41 | return self.mean 42 | 43 | 44 | # Bernoulli 45 | class FixedBernoulli(torch.distributions.Bernoulli): 46 | def log_probs(self, actions): 47 | return super.log_prob(actions).view(actions.size(0), -1).sum(-1).unsqueeze(-1) 48 | 49 | def entropy(self): 50 | return super().entropy().sum(-1) 51 | 52 | def mode(self): 53 | return torch.gt(self.probs, 0.5).float() 54 | 55 | 56 | class Categorical(nn.Module): 57 | def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01): 58 | super(Categorical, self).__init__() 59 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal] 60 | def init_(m): 61 | return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain) 62 | 63 | self.linear = init_(nn.Linear(num_inputs, num_outputs)) 64 | 65 | def forward(self, x, available_actions=None): 66 | x = self.linear(x) 67 | if available_actions is not None: 68 | x[available_actions == 0] = -1e10 69 | return FixedCategorical(logits=x) 70 | 71 | 72 | # class DiagGaussian(nn.Module): 73 | # def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01): 74 | # super(DiagGaussian, self).__init__() 75 | # 76 | # init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal] 77 | # def init_(m): 78 | # return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain) 79 | # 80 | # self.fc_mean = init_(nn.Linear(num_inputs, num_outputs)) 81 | # self.logstd = AddBias(torch.zeros(num_outputs)) 82 | # 83 | # def forward(self, x, available_actions=None): 84 | # action_mean = self.fc_mean(x) 85 | # 86 | # # An ugly hack for my KFAC implementation. 87 | # zeros = torch.zeros(action_mean.size()) 88 | # if x.is_cuda: 89 | # zeros = zeros.cuda() 90 | # 91 | # action_logstd = self.logstd(zeros) 92 | # return FixedNormal(action_mean, action_logstd.exp()) 93 | 94 | class DiagGaussian(nn.Module): 95 | def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01, config=None): 96 | super(DiagGaussian, self).__init__() 97 | gain = config["actor_gain"] 98 | 99 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal] 100 | 101 | def init_(m): 102 | return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain) 103 | 104 | if config is not None: 105 | self.std_x_coef = config["std_x_coef"] 106 | self.std_y_coef = config["std_y_coef"] 107 | else: 108 | self.std_x_coef = 1. 
109 | self.std_y_coef = 0.5 110 | self.fc_mean = init_(nn.Linear(num_inputs, num_outputs)) 111 | log_std = torch.ones(num_outputs) * self.std_x_coef 112 | self.log_std = torch.nn.Parameter(log_std) 113 | 114 | def forward(self, x, available_actions=None): 115 | action_mean = self.fc_mean(x) 116 | action_std = torch.sigmoid(self.log_std / self.std_x_coef) * self.std_y_coef 117 | return FixedNormal(action_mean, action_std) 118 | 119 | class Bernoulli(nn.Module): 120 | def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01): 121 | super(Bernoulli, self).__init__() 122 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal] 123 | def init_(m): 124 | return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain) 125 | 126 | self.linear = init_(nn.Linear(num_inputs, num_outputs)) 127 | 128 | def forward(self, x): 129 | x = self.linear(x) 130 | return FixedBernoulli(logits=x) 131 | 132 | class AddBias(nn.Module): 133 | def __init__(self, bias): 134 | super(AddBias, self).__init__() 135 | self._bias = nn.Parameter(bias.unsqueeze(1)) 136 | 137 | def forward(self, x): 138 | if x.dim() == 2: 139 | bias = self._bias.t().view(1, -1) 140 | else: 141 | bias = self._bias.t().view(1, -1, 1, 1) 142 | 143 | return x + bias 144 | -------------------------------------------------------------------------------- /algos/utils/mlp.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .util import init, get_clones 3 | 4 | """MLP modules.""" 5 | 6 | class MLPLayer(nn.Module): 7 | def __init__(self, input_dim, hidden_size, layer_N, use_orthogonal, use_ReLU): 8 | super(MLPLayer, self).__init__() 9 | self._layer_N = layer_N 10 | 11 | # active_func = [nn.Tanh(), nn.ReLU()][use_ReLU] 12 | active_func = [nn.ELU(), nn.ELU()][use_ReLU] 13 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal] 14 | gain = nn.init.calculate_gain(['tanh', 'relu'][use_ReLU]) 15 | 16 | def init_(m): 17 | return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain=gain) 18 | 19 | self.fc1 = nn.Sequential( 20 | init_(nn.Linear(input_dim, hidden_size)), active_func, nn.LayerNorm(hidden_size)) 21 | # self.fc1 = nn.Sequential( 22 | # init_(nn.Linear(input_dim, hidden_size)), active_func) 23 | # self.fc_h = nn.Sequential(init_( 24 | # nn.Linear(hidden_size, hidden_size)), active_func, nn.LayerNorm(hidden_size)) 25 | # self.fc2 = get_clones(self.fc_h, self._layer_N) 26 | self.fc2 = nn.ModuleList([nn.Sequential(init_( 27 | nn.Linear(hidden_size, hidden_size)), active_func, nn.LayerNorm(hidden_size)) for i in range(self._layer_N)]) 28 | # self.fc2 = nn.ModuleList([nn.Sequential(init_( 29 | # nn.Linear(hidden_size, hidden_size)), active_func) for i in range(self._layer_N)]) 30 | 31 | def forward(self, x): 32 | x = self.fc1(x) 33 | for i in range(self._layer_N): 34 | x = self.fc2[i](x) 35 | return x 36 | 37 | 38 | class MLPBase(nn.Module): 39 | def __init__(self, config, obs_shape, cat_self=True, attn_internal=False): 40 | super(MLPBase, self).__init__() 41 | 42 | self._use_feature_normalization = config["use_feature_normalization"] 43 | self._use_orthogonal = config["use_orthogonal"] 44 | self._use_ReLU = config["use_ReLU"] 45 | self._stacked_frames = config["stacked_frames"] 46 | self._layer_N = config["layer_N"] 47 | self.hidden_size = config["hidden_size"] 48 | 49 | obs_dim = obs_shape[0] 50 | 51 | if self._use_feature_normalization: 52 | self.feature_norm = nn.LayerNorm(obs_dim) 53 | 54 | self.mlp = 
MLPLayer(obs_dim, self.hidden_size, 55 | self._layer_N, self._use_orthogonal, self._use_ReLU) 56 | # self.mlp_middle_layer = MLPLayer(self.hidden_size, self.hidden_size, 57 | # self._layer_N, self._use_orthogonal, self._use_ReLU) 58 | 59 | def forward(self, x): 60 | if self._use_feature_normalization: 61 | x = self.feature_norm(x) 62 | 63 | x = self.mlp(x) 64 | # x = self.mlp_middle_layer(x) 65 | 66 | return x -------------------------------------------------------------------------------- /algos/utils/rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | """RNN modules.""" 5 | 6 | 7 | class RNNLayer(nn.Module): 8 | def __init__(self, inputs_dim, outputs_dim, recurrent_N, use_orthogonal): 9 | super(RNNLayer, self).__init__() 10 | self._recurrent_N = recurrent_N 11 | self._use_orthogonal = use_orthogonal 12 | 13 | self.rnn = nn.GRU(inputs_dim, outputs_dim, num_layers=self._recurrent_N) 14 | for name, param in self.rnn.named_parameters(): 15 | if 'bias' in name: 16 | nn.init.constant_(param, 0) 17 | elif 'weight' in name: 18 | if self._use_orthogonal: 19 | nn.init.orthogonal_(param) 20 | else: 21 | nn.init.xavier_uniform_(param) 22 | self.norm = nn.LayerNorm(outputs_dim) 23 | 24 | def forward(self, x, hxs, masks): 25 | if x.size(0) == hxs.size(0): 26 | x, hxs = self.rnn(x.unsqueeze(0), 27 | (hxs * masks.repeat(1, self._recurrent_N).unsqueeze(-1)).transpose(0, 1).contiguous()) 28 | x = x.squeeze(0) 29 | hxs = hxs.transpose(0, 1) 30 | else: 31 | # x is a (T, N, -1) tensor that has been flatten to (T * N, -1) 32 | N = hxs.size(0) 33 | T = int(x.size(0) / N) 34 | 35 | # unflatten 36 | x = x.view(T, N, x.size(1)) 37 | 38 | # Same deal with masks 39 | masks = masks.view(T, N) 40 | 41 | # Let's figure out which steps in the sequence have a zero for any agent 42 | # We will always assume t=0 has a zero in it as that makes the logic cleaner 43 | has_zeros = ((masks[1:] == 0.0) 44 | .any(dim=-1) 45 | .nonzero() 46 | .squeeze() 47 | .cpu()) 48 | 49 | # +1 to correct the masks[1:] 50 | if has_zeros.dim() == 0: 51 | # Deal with scalar 52 | has_zeros = [has_zeros.item() + 1] 53 | else: 54 | has_zeros = (has_zeros + 1).numpy().tolist() 55 | 56 | # add t=0 and t=T to the list 57 | has_zeros = [0] + has_zeros + [T] 58 | 59 | hxs = hxs.transpose(0, 1) 60 | 61 | outputs = [] 62 | for i in range(len(has_zeros) - 1): 63 | # We can now process steps that don't have any zeros in masks together! 
64 | # This is much faster 65 | start_idx = has_zeros[i] 66 | end_idx = has_zeros[i + 1] 67 | temp = (hxs * masks[start_idx].view(1, -1, 1).repeat(self._recurrent_N, 1, 1)).contiguous() 68 | rnn_scores, hxs = self.rnn(x[start_idx:end_idx], temp) 69 | outputs.append(rnn_scores) 70 | 71 | # assert len(outputs) == T 72 | # x is a (T, N, -1) tensor 73 | x = torch.cat(outputs, dim=0) 74 | 75 | # flatten 76 | x = x.reshape(T * N, -1) 77 | hxs = hxs.transpose(0, 1) 78 | 79 | x = self.norm(x) 80 | return x, hxs 81 | -------------------------------------------------------------------------------- /algos/utils/util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | def init(module, weight_init, bias_init, gain=1): 8 | weight_init(module.weight.data, gain=gain) 9 | bias_init(module.bias.data) 10 | return module 11 | 12 | def get_clones(module, N): 13 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 14 | 15 | def check(input): 16 | output = torch.from_numpy(input) if type(input) == np.ndarray else input 17 | return output 18 | -------------------------------------------------------------------------------- /assets/readme/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/assets/readme/teaser.png -------------------------------------------------------------------------------- /cfgs/algo/ppo/config.yaml: -------------------------------------------------------------------------------- 1 | seed: -1 2 | 3 | clip_observations: 5.0 4 | clip_actions: 1.0 5 | 6 | policy: # only works for MlpPolicy right now 7 | pi_hid_sizes: [1024, 1024, 512] 8 | vf_hid_sizes: [1024, 1024, 512] 9 | activation: elu # can be elu, relu, selu, crelu, lrelu, tanh, sigmoid 10 | learn: 11 | agent_name: shadow_hand 12 | test: False 13 | resume: 0 14 | save_interval: 1000 # check for potential saves every this many iterations 15 | print_log: True 16 | 17 | # rollout params 18 | max_iterations: 6500 19 | 20 | # training params 21 | cliprange: 0.2 22 | ent_coef: 0 23 | nsteps: 8 24 | noptepochs: 5 25 | nminibatches: 4 # this is per agent 26 | max_grad_norm: 1 27 | optim_stepsize: 3.e-4 # 3e-4 is default for single agent training with constant schedule 28 | schedule: adaptive # could be adaptive or linear or fixed 29 | desired_kl: 0.016 30 | gamma: 0.96 31 | lam: 0.95 32 | init_noise_std: 0.8 33 | 34 | log_interval: 1 35 | asymmetric: False -------------------------------------------------------------------------------- /cfgs/algo/ppo/manipulation.yaml: -------------------------------------------------------------------------------- 1 | seed: -1 2 | 3 | clip_observations: 5.0 4 | clip_actions: 0.2 5 | 6 | policy: # only works for MlpPolicy right now 7 | pi_hid_sizes: [128, 128, 128] 8 | vf_hid_sizes: [128, 128, 128] 9 | activation: elu # can be elu, relu, selu, crelu, lrelu, tanh, sigmoid 10 | learn: 11 | agent_name: franka 12 | test: False 13 | resume: 0 14 | save_interval: 10 # check for potential saves every this many iterations 15 | print_log: True 16 | 17 | # rollout params 18 | max_iterations: 200 19 | 20 | # training params 21 | cliprange: 0.2 22 | total_loss_coef: 1.e-3 23 | ent_coef: 0 24 | nsteps: 75 25 | noptepochs: 20 26 | nminibatches: 32 # this is per agent 27 | max_grad_norm: 1 28 | optim_stepsize: 3.e-4 # 3e-4 is default for single agent training 
with constant schedule 29 | schedule: fixed # could be adaptive or linear or fixed 30 | desired_kl: 0.016 31 | gamma: 0.998 32 | lam: 0.95 33 | init_noise_std: 0.8 34 | 35 | log_interval: 1 36 | asymmetric: False -------------------------------------------------------------------------------- /cfgs/plan/config.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/cfgs/plan/config.yaml -------------------------------------------------------------------------------- /cfgs/repre/ag2manip/config.yaml: -------------------------------------------------------------------------------- 1 | type: "ag2manip" 2 | desc: "use R3M checkpoint pre-trained with epic-kitchen agent-agnostic dataset" 3 | 4 | model: AG2MANIP 5 | d_emb: 1024 6 | batchsize: 128 7 | 8 | backbone_type: resnet50 9 | similarity_type: l2 10 | 11 | ckpt_dir: repre_trainer/logs/ag2manip/ckpts -------------------------------------------------------------------------------- /cfgs/task/partmanip/pull_wooddrawer.yaml: -------------------------------------------------------------------------------- 1 | # if given, will override the device setting in gym. 2 | env: 3 | env_name: "partmanip@pull_wooddrawer" 4 | numEnvs: 128 #! default is 128 5 | envSpacing: 0.0 6 | episodeLength: 75 7 | enableDebugVis: False 8 | cameraDebug: True 9 | pointCloudDebug: True 10 | aggregateMode: 1 11 | 12 | stiffnessScale: 1.0 13 | forceLimitScale: 1.0 14 | useRelativeControl: True #! default is True 15 | dofSpeedScale: 20.0 16 | attractorSpeedScale: 20.0 17 | actionsMovingAverage: 1.0 18 | controlFrequencyInv: 1 # 60 Hz 19 | 20 | startPositionNoise: 0.0 21 | startRotationNoise: 0.0 22 | 23 | resetPositionNoise: 0.0 24 | resetRotationNoise: 0.0 25 | resetDofPosRandomInterval: 0.0 26 | resetDofVelRandomInterval: 0.0 27 | 28 | distRewardScale: 50 29 | rotRewardScale: 1.0 30 | rotEps: 0.1 31 | actionPenaltyScale: -0.0002 32 | reachGoalBonus: 250 33 | fallDistance: 0.4 34 | fallPenalty: 0.0 35 | 36 | observationType: "full_state" # full_state or robot_state 37 | asymmetric_observations: False 38 | successTolerance: 0.1 39 | printNumSuccesses: False 40 | maxConsecutiveSuccesses: 0 41 | 42 | actionType: "dummy_interaction_sphere" 43 | asset: 44 | agent: 45 | assetRoot: "./assets/agent" 46 | panda-wovis: "panda_gripper/robots/panda-wovis.urdf" 47 | panda-wvis: "panda_gripper/robots/panda.urdf" 48 | sphere-wovis: "sphere/sphere-wovis.urdf" 49 | sphere-wvis: "sphere/sphere.urdf" 50 | franka: "franka_urdf/robots/franka_panda.urdf" 51 | franka-wovis: "franka_urdf/robots/franka_panda-wovis.urdf" 52 | object: 53 | assetRoot: "./assets/env/partmanip" 54 | floor: "floor.xml" 55 | room_0: "room/room.urdf" 56 | room_1: "room/room.urdf" 57 | wood_drawer: "wood_drawer/mobility.urdf" 58 | placement: 59 | floor: 60 | pos: [0, 0, -0.01] 61 | rot: [0, 0, 0] ## rot as euler_xyz 62 | room_0: 63 | pos: [-2.0, 1.0, 0.0] 64 | rot: [0, 0, 0] 65 | room_1: 66 | pos: [2.0, -3.0, 0.0] 67 | rot: [0.0, 0.0, 3.14] 68 | wood_drawer: 69 | pos: [0.0, 0.0, 1.1] 70 | rot: [0.0, 0.0, 1.57] 71 | 72 | 73 | task: 74 | randomize: False 75 | randomization_params: 76 | frequency: 600 # Define how many simulation steps between generating new randomizations 77 | observations: 78 | range: [0., 0.] # range for the white noise 79 | range_correlated: [0., 0.] 
# range for correlated noise, refreshed with freq `frequency` 80 | operation: "additive" 81 | distribution: "gaussian" 82 | schedule: "linear" # "constant" is to turn on noise after `schedule_steps` num steps 83 | schedule_steps: 40000 84 | actions: 85 | range: [0., .05] 86 | range_correlated: [0, .015] # range for correlated noise, refreshed with freq `frequency` 87 | operation: "additive" 88 | distribution: "gaussian" 89 | schedule: "linear" # "linear" will linearly interpolate between no rand and max rand 90 | schedule_steps: 40000 91 | sim_params: 92 | gravity: 93 | range: [0, 0.4] 94 | operation: "additive" 95 | distribution: "gaussian" 96 | schedule: "linear" # "linear" will linearly interpolate between no rand and max rand 97 | schedule_steps: 40000 98 | actor_params: 99 | hand: 100 | color: True 101 | tendon_properties: 102 | damping: 103 | range: [0.3, 3.0] 104 | operation: "scaling" 105 | distribution: "loguniform" 106 | schedule: "linear" # "linear" will scale the current random sample by `min(current num steps, schedule_steps) / schedule_steps` 107 | schedule_steps: 30000 108 | stiffness: 109 | range: [0.75, 1.5] 110 | operation: "scaling" 111 | distribution: "loguniform" 112 | schedule: "linear" # "linear" will scale the current random sample by `min(current num steps, schedule_steps) / schedule_steps` 113 | schedule_steps: 30000 114 | dof_properties: 115 | damping: 116 | range: [0.3, 3.0] 117 | operation: "scaling" 118 | distribution: "loguniform" 119 | schedule: "linear" # "linear" will scale the current random sample by `min(current num steps, schedule_steps) / schedule_steps` 120 | schedule_steps: 30000 121 | stiffness: 122 | range: [0.75, 1.5] 123 | operation: "scaling" 124 | distribution: "loguniform" 125 | schedule: "linear" # "linear" will scale the current random sample by `min(current num steps, schedule_steps) / schedule_steps` 126 | schedule_steps: 30000 127 | lower: 128 | range: [0, 0.01] 129 | operation: "additive" 130 | distribution: "gaussian" 131 | schedule: "linear" # "linear" will scale the current random sample by `min(current num steps, schedule_steps) / schedule_steps` 132 | schedule_steps: 30000 133 | upper: 134 | range: [0, 0.01] 135 | operation: "additive" 136 | distribution: "gaussian" 137 | schedule: "linear" # "linear" will scale the current random sample by `min(current num steps, schedule_steps) / schedule_steps` 138 | schedule_steps: 30000 139 | rigid_body_properties: 140 | mass: 141 | range: [0.5, 1.5] 142 | operation: "scaling" 143 | distribution: "uniform" 144 | schedule: "linear" # "linear" will scale the current random sample by `min(current num steps, schedule_steps) / schedule_steps` 145 | schedule_steps: 30000 146 | rigid_shape_properties: 147 | friction: 148 | num_buckets: 250 149 | range: [0.7, 1.3] 150 | operation: "scaling" 151 | distribution: "uniform" 152 | schedule: "linear" # "linear" will scale the current random sample by `min(current num steps, schedule_steps) / schedule_steps` 153 | schedule_steps: 30000 154 | object: 155 | scale: 156 | range: [0.95, 1.05] 157 | operation: "scaling" 158 | distribution: "uniform" 159 | schedule: "linear" # "linear" will scale the current random sample by ``min(current num steps, schedule_steps) / schedule_steps` 160 | schedule_steps: 30000 161 | rigid_body_properties: 162 | mass: 163 | range: [0.5, 1.5] 164 | operation: "scaling" 165 | distribution: "uniform" 166 | schedule: "linear" # "linear" will scale the current random sample by ``min(current num steps, schedule_steps) / schedule_steps` 
167 | schedule_steps: 30000 168 | rigid_shape_properties: 169 | friction: 170 | num_buckets: 250 171 | range: [0.7, 1.3] 172 | operation: "scaling" 173 | distribution: "uniform" 174 | schedule: "linear" # "linear" will scale the current random sample by `min(current num steps, schedule_steps) / schedule_steps` 175 | schedule_steps: 30000 176 | 177 | repre: None 178 | # repre: 179 | # handcrafted: 180 | # model: None # representation model name 181 | # checkpoints: "./checkpoints" # path to the checkpoints 182 | 183 | 184 | sim: 185 | substeps: 2 186 | physx: 187 | num_threads: 4 188 | solver_type: 1 # 0: pgs, 1: tgs 189 | num_position_iterations: 8 190 | num_velocity_iterations: 0 191 | contact_offset: 0.002 192 | rest_offset: 0.0 193 | bounce_threshold_velocity: 0.2 194 | max_depenetration_velocity: 1000.0 195 | default_buffer_size_multiplier: 5.0 196 | flex: 197 | num_outer_iterations: 5 198 | num_inner_iterations: 20 199 | warm_start: 0.8 200 | relaxation: 0.75 201 | -------------------------------------------------------------------------------- /plan.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import pickle 3 | import numpy as np 4 | 5 | from utils.config import set_np_formatting, set_seed, get_args, parse_sim_params, load_plan_cfg 6 | from utils.parse_task_plan import parse_task_plan 7 | from algos.planner.base import BASE 8 | from algos.planner.approach import APPROACH 9 | 10 | 11 | def plan(): 12 | print(f'Planner: Default') 13 | 14 | task, env = parse_task_plan(args, cfg, sim_params) 15 | planner = APPROACH(env, cfg, save_goal=args.save_goal, save_video=args.save_video) 16 | planner.run() 17 | 18 | 19 | if __name__ == '__main__': 20 | set_np_formatting() 21 | args = get_args() 22 | cfg, logdir = load_plan_cfg(args) 23 | sim_params = parse_sim_params(args, cfg, None) 24 | set_seed(42, True) 25 | plan() -------------------------------------------------------------------------------- /repre_trainer/cfgs/default.yaml: -------------------------------------------------------------------------------- 1 | # config/default.yaml 2 | hydra: 3 | run: 4 | dir: ${exp_dir} 5 | output_subdir: null 6 | 7 | defaults: 8 | - _self_ 9 | - model: null 10 | - task: null 11 | - eval: null 12 | # - optimizer: null 13 | # - planner: null 14 | 15 | ckpt: null 16 | output_dir: logs 17 | exp_name: default 18 | exp_dir: ${output_dir}/${now:%Y-%m-%d_%H-%M-%S}_${exp_name} 19 | tb_dir: ${exp_dir}/tb_logs 20 | vis_dir: ${exp_dir}/visual 21 | ckpt_dir: ${exp_dir}/ckpts 22 | 23 | slurm: false 24 | gpu: 0 25 | 26 | ## for saving model, interval for epoch loop 27 | save_model_interval: 1 28 | save_model_seperately: false 29 | save_scene_model: false # save scene model or not, important!!! 30 | -------------------------------------------------------------------------------- /repre_trainer/cfgs/model/ag2manip.yaml: -------------------------------------------------------------------------------- 1 | name: AG2MANIP 2 | data_type: rgb #! optional list ['rgb', 'agentago'], which is important 3 | 4 | d_emb: 1024 5 | backbone_type: resnet50 6 | similarity_type: l2 # optional list: [l2, cosine] 7 | num_negatives: 3 8 | 9 | learning_rate: 1e-4 10 | 11 | loss_weight: 12 | tcn: 1.0 13 | l1norm: 0.00001 14 | l2norm: 0.00001 15 | -------------------------------------------------------------------------------- /repre_trainer/cfgs/model/r3m.yaml: -------------------------------------------------------------------------------- 1 | name: R3M 2 | data_type: rgb #! 
optional list ['rgb', 'agentago'], which is important 3 | 4 | d_emb: 1024 5 | backbone_type: resnet50 6 | similarity_type: l2 # optional list: [l2, cosine] 7 | num_negatives: 3 8 | 9 | learning_rate: 1e-4 10 | 11 | loss_weight: 12 | tcn: 1.0 13 | l1norm: 0.00001 14 | l2norm: 0.00001 15 | -------------------------------------------------------------------------------- /repre_trainer/cfgs/model/vip.yaml: -------------------------------------------------------------------------------- 1 | name: VIP 2 | data_type: rgb #! optional list ['rgb', 'agentago'], which is important 3 | 4 | d_emb: 1024 5 | backbone_type: resnet50 6 | reward_type: sparse # optional list: [sparse, dense] 7 | similarity_type: l2 # optional list: [l2, cosine] 8 | num_negatives: 3 # assert: num_negatives < ${task.dataset.bacth_size} 9 | 10 | learning_rate: 1e-4 11 | 12 | loss_weight: 13 | gamma: 0.98 14 | l1norm: 0.00001 15 | l2norm: 0.00001 16 | -------------------------------------------------------------------------------- /repre_trainer/cfgs/task/epic_kitchen.yaml: -------------------------------------------------------------------------------- 1 | # task: pose generation 2 | name: epic_kitchen 3 | lr: ${model.learning_rate} 4 | clip_grad: 0.0 # 0.0 means no clip 5 | eval_interval: 1 6 | eval_visualize: 100 7 | 8 | train: 9 | batch_size: 64 10 | num_workers: 8 11 | num_epochs: 10000000 12 | log_step: 100 ## orignal 100 13 | 14 | test: 15 | epoch: null 16 | batch_size: 8 17 | num_workers: 0 18 | 19 | dataset: 20 | name: EpicKitchen 21 | desc: '[Epic Kitchen]' 22 | model_type: ${model.name} 23 | resolution_height: 256 24 | resolution_width: 256 25 | aug_window_size: 0.6 26 | # data_type: rgb 27 | data_type: ${model.data_type} 28 | item_type: ${model.name} 29 | device: cuda 30 | data_dir_local: ${path_to_your_dataset_folder} 31 | data_dir_slurm: null 32 | # train_transforms: ['NumpyToTensor'] 33 | # test_transforms: ['NumpyToTensor'] 34 | # transform_cfg: {} 35 | 36 | visualizer: 37 | visualize: false 38 | -------------------------------------------------------------------------------- /repre_trainer/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .epic_kitchen import EpicKitchen -------------------------------------------------------------------------------- /repre_trainer/datasets/base.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from torch.utils.data import Dataset 3 | from utils.registry import Registry 4 | DATASET = Registry('Dataset') 5 | 6 | def create_dataset(cfg: dict, phase: str, slurm: bool, **kwargs: Dict) -> Dataset: 7 | """ Create a `torch.utils.data.Dataset` object from configuration. 8 | 9 | Args: 10 | cfg: configuration object, dataset configuration 11 | phase: phase string, can be 'train' and 'test' 12 | slurm: on slurm platform or not. This field is used to specify the data path 13 | 14 | Return: 15 | A Dataset object that has loaded the designated dataset. 
16 | """ 17 | return DATASET.get(cfg.name)(cfg, phase, slurm, **kwargs) 18 | -------------------------------------------------------------------------------- /repre_trainer/datasets/epic_kitchen.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple, Dict 2 | import os 3 | import json 4 | import glob 5 | from tqdm import tqdm 6 | import pickle 7 | import trimesh 8 | 9 | import torch 10 | import torchvision 11 | import pandas as pd 12 | import numpy as np 13 | from PIL import Image 14 | from torch.utils.data import Dataset, DataLoader 15 | from omegaconf import DictConfig 16 | 17 | from datasets.base import DATASET 18 | 19 | @DATASET.register() 20 | class EpicKitchen(Dataset): 21 | """ Dataset for representation learning, 22 | training with EPIC-KITCHEN agent-aware/agent-agonostic dataset 23 | """ 24 | _STR_FRAME_LENGTH = 10 25 | _train_split = [] 26 | _test_split = [] 27 | _all_split = [] 28 | 29 | def __init__(self, cfg: DictConfig, phase: str, slurm: bool, **kwargs: Dict) -> None: 30 | super(EpicKitchen, self).__init__() 31 | self.phase = phase 32 | self.slurm = slurm 33 | if self.phase == 'train': 34 | self.split = self._train_split 35 | elif self.phase == 'test': 36 | self.split = self._test_split 37 | elif self.phase == 'all': 38 | self.split = self._all_split 39 | else: 40 | raise Exception(f"Unsupported phase: {self.phase}") 41 | self.data_type = cfg.data_type 42 | self.model_type = cfg.model_type 43 | self.data_dir = cfg.data_dir_slurm if self.slurm else cfg.data_dir_local 44 | self.resolution = (cfg.resolution_height, cfg.resolution_width) 45 | self.aug_sidewindow_size = (1 - cfg.aug_window_size) / 2 46 | self.to_tensor = torchvision.transforms.ToTensor() 47 | self.preprocess = torch.nn.Sequential( 48 | torchvision.transforms.Resize(self.resolution, antialias=True),) 49 | 50 | #* for specify getitem func. 51 | self.item_type = cfg.item_type.lower() 52 | #* load data 53 | self._pre_load_data() 54 | 55 | def _pre_load_data(self) -> None: 56 | """ Load metadata from json files 57 | """ 58 | # self.indices = [] 59 | self.info = json.load(open(os.path.join(self.data_dir, 'info.json'), 'r')) 60 | self.metadata = pd.read_csv(os.path.join(self.data_dir, 'EPIC100_annotations.csv')) 61 | 62 | def __len__(self): 63 | return len(self.metadata) 64 | 65 | def __getitem__(self, index: Tuple) -> Tuple: 66 | if self.item_type.lower() in ['vip']: 67 | # return self._getitem_vip(index) 68 | #! 
use r3m sample way 69 | return self._getitem_vip(index) 70 | elif self.item_type.lower() in ['r3m', 'ag2manip']: 71 | return self._getitem_r3m(index) 72 | else: 73 | raise NotImplementedError 74 | 75 | def _getitem_r3m(self, index: Any) -> Tuple: 76 | mdata = self.metadata.iloc[index] 77 | start_frame = mdata['start_frame'] 78 | stop_frame = mdata['stop_frame'] 79 | part_id = mdata['participant_id'] 80 | video_id = mdata['video_id'] 81 | 82 | #* do augmentation and observation sampling 83 | # clip_length = stop_frame - start_frame + 1 84 | # start_ind = np.random.randint(start_frame, 85 | # start_frame + int(clip_length * self.aug_sidewindow_size) + 1) 86 | # stop_ind = np.random.randint(stop_frame - int(clip_length * self.aug_sidewindow_size), 87 | # stop_frame + 1) 88 | sample_indices = np.random.permutation(np.arange(start_frame, stop_frame + 1))[:3] 89 | s0_ind_r3m, s1_ind_r3m, s2_ind_r3m = np.sort(sample_indices) 90 | 91 | img_s0 = self._load_frame(part_id, video_id, s0_ind_r3m) 92 | img_s1 = self._load_frame(part_id, video_id, s1_ind_r3m) 93 | img_s2 = self._load_frame(part_id, video_id, s2_ind_r3m) 94 | imgs = torch.stack([img_s0, img_s1, img_s2], dim=0) 95 | imgs = self.preprocess(imgs) 96 | 97 | #* dict a data sample 98 | data = { 99 | 'imgs': imgs, 100 | 's0_ind': s0_ind_r3m, 101 | 's1_ind': s1_ind_r3m, 102 | 's2_ind': s2_ind_r3m 103 | } 104 | return data 105 | 106 | def _getitem_vip(self, index: Any) -> Tuple: 107 | mdata = self.metadata.iloc[index] 108 | #? do random crop??? 109 | start_frame = mdata['start_frame'] 110 | stop_frame = mdata['stop_frame'] 111 | part_id = mdata['participant_id'] 112 | video_id = mdata['video_id'] 113 | 114 | #* do augmentation and observation sampling 115 | clip_length = stop_frame - start_frame + 1 116 | start_ind = np.random.randint(start_frame, 117 | start_frame + int(clip_length * self.aug_sidewindow_size) + 1) 118 | stop_ind = np.random.randint(stop_frame - int(clip_length * self.aug_sidewindow_size), 119 | stop_frame + 1) 120 | s0_ind_vip = np.random.randint(start_ind, stop_ind) 121 | s1_ind_vip = np.random.randint(s0_ind_vip + 1, stop_ind + 1) 122 | 123 | #* load images 124 | #! 
should be start_ind and stop_ind 125 | img_start = self._load_frame(part_id, video_id, start_ind) 126 | img_goal = self._load_frame(part_id, video_id, stop_ind) 127 | img_s0 = self._load_frame(part_id, video_id, s0_ind_vip) 128 | img_s1 = self._load_frame(part_id, video_id, s1_ind_vip) 129 | imgs = torch.stack([img_start, img_goal, img_s0, img_s1], dim=0) 130 | imgs = self.preprocess(imgs) 131 | 132 | #* dict a data sample 133 | data = { 134 | 'imgs': imgs, 135 | 'start_ind': start_ind, 136 | 'stop_ind': stop_ind, 137 | 's0_ind': s0_ind_vip, 138 | 's1_ind': s1_ind_vip,} 139 | return data 140 | 141 | def _load_frame(self, part_id: str, video_id: str, frame_id: int) -> torch.Tensor: 142 | if self.data_type == 'rgb': 143 | vid = os.path.join(self.data_dir, part_id, 'rgb_frames', video_id, f"frame_{frame_id:010d}.jpg") 144 | elif self.data_type == 'agentago': 145 | vid = os.path.join(self.data_dir, part_id, 'agentago_frames', video_id, f"frame_{frame_id:010d}.jpg") 146 | else: 147 | raise NotImplementedError 148 | return self.to_tensor(Image.open(vid).convert('RGB')) 149 | 150 | def get_dataloader(self, **kwargs): 151 | return DataLoader(self, **kwargs) 152 | -------------------------------------------------------------------------------- /repre_trainer/datasets/misc.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | import torch 3 | from einops import rearrange 4 | 5 | def collate_fn_general(batch: List) -> Dict: 6 | """ General collate function used for dataloader. 7 | """ 8 | batch_data = {key: [d[key] for d in batch] for key in batch[0]} 9 | 10 | # for key in batch_data: 11 | # if torch.is_tensor(batch_data[key][0]): 12 | # batch_data[key] = torch.stack(batch_data[key]) 13 | return batch_data 14 | 15 | def collate_fn_epic_vip(batch: List) -> Dict: 16 | """ Collate function used for EPIC-KITCHENS dataset. 17 | """ 18 | batch_data = {key: [d[key] for d in batch] for key in batch[0]} 19 | batch_data['imgs'] = torch.stack(batch_data['imgs']) 20 | batch_data['start_ind'] = torch.tensor(batch_data['start_ind'], dtype=torch.long) 21 | batch_data['stop_ind'] = torch.tensor(batch_data['stop_ind'], dtype=torch.long) 22 | batch_data['s0_ind'] = torch.tensor(batch_data['s0_ind'], dtype=torch.long) 23 | batch_data['s1_ind'] = torch.tensor(batch_data['s1_ind'], dtype=torch.long) 24 | 25 | return batch_data 26 | 27 | def collate_fn_epic_r3m(batch: List) -> Dict: 28 | """ Collate function used for EPIC-KITCHENS dataset. 29 | """ 30 | batch_data = {key: [d[key] for d in batch] for key in batch[0]} 31 | batch_data['imgs'] = torch.stack(batch_data['imgs']) 32 | batch_data['s0_ind'] = torch.tensor(batch_data['s0_ind'], dtype=torch.long) 33 | batch_data['s1_ind'] = torch.tensor(batch_data['s1_ind'], dtype=torch.long) 34 | batch_data['s2_ind'] = torch.tensor(batch_data['s2_ind'], dtype=torch.long) 35 | 36 | return batch_data 37 | 38 | def collate_fn_epic_clip(batch: List) -> Dict: 39 | """ Collate function used for EPIC-KITCHEN Clips dataset. 40 | """ 41 | assert len(batch) == 1, "batch size must be exactly 1" 42 | batch_data = batch[0] 43 | 44 | return batch_data 45 | 46 | def collate_fn_arnold_clip(batch: List) -> Dict: 47 | """ Collate function used for EPIC-KITCHEN Clips dataset. 
48 | """ 49 | assert len(batch) == 1, "batch size must be exactly 1" 50 | batch_data = batch[0] 51 | 52 | return batch_data 53 | -------------------------------------------------------------------------------- /repre_trainer/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .model.r3m import R3M 2 | from .model.vip import VIP 3 | from .model.ag2manip import AG2MANIP -------------------------------------------------------------------------------- /repre_trainer/models/base.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | import torch.nn as nn 3 | from omegaconf import DictConfig 4 | 5 | from utils.registry import Registry 6 | 7 | MODEL = Registry('Model') 8 | 9 | def create_model(cfg: DictConfig, *args: List, **kwargs: Dict) -> nn.Module: 10 | """ Create a `torch.nn.Module` object from configuration. 11 | 12 | Args: 13 | cfg: configuration object, model configuration 14 | args: arguments to initialize the model 15 | kwargs: keyword arguments to initialize the model 16 | 17 | Return: 18 | A Module object that has loaded the designated model. 19 | """ 20 | return MODEL.get(cfg.model.name)(cfg.model, *args, **kwargs) 21 | -------------------------------------------------------------------------------- /repre_trainer/models/evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import torch.nn as nn 5 | import numpy as np 6 | from scipy.stats import spearmanr 7 | from collections import defaultdict 8 | from omegaconf import DictConfig 9 | 10 | from utils.registry import Registry 11 | from loguru import logger 12 | 13 | EVALUATOR = Registry('Evaluator') 14 | 15 | def create_evaluator(cfg: DictConfig) -> nn.Module: 16 | """ Create a evaluator for quantitative evaluation 17 | Args: 18 | cfg: configuration object 19 | 20 | Return: 21 | A evaluator 22 | """ 23 | return EVALUATOR.get(cfg.name)(cfg) 24 | 25 | @EVALUATOR.register() 26 | class RewardRanker(): 27 | def __init__(self, cfg: DictConfig) -> None: 28 | """ Evaluator class for reward ranking test 29 | 30 | Args: 31 | cfg: evaluator configuration 32 | """ 33 | self.log_step = cfg.log_step 34 | 35 | -------------------------------------------------------------------------------- /repre_trainer/models/model/ag2manip.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from einops import rearrange 3 | import torch 4 | import torchvision 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torchvision import transforms 8 | from omegaconf import DictConfig 9 | 10 | from models.base import MODEL 11 | 12 | @MODEL.register() 13 | class AG2MANIP(nn.Module): 14 | # a copy for r3m model architecture 15 | def __init__(self, cfg: DictConfig, *args, **kwargs) -> None: 16 | super(AG2MANIP, self).__init__() 17 | self.d_emb = cfg.d_emb 18 | self.backbone_type = cfg.backbone_type 19 | self.similarity_type = cfg.similarity_type 20 | self.num_negatives = cfg.num_negatives 21 | self.loss_weight = cfg.loss_weight 22 | 23 | self.normlayer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 24 | if self.backbone_type == 'resnet50': 25 | self.backbone = torchvision.models.resnet50(pretrained=False) 26 | self.backbone.fc = nn.Linear(2048, self.d_emb) 27 | else: 28 | raise NotImplementedError 29 | 30 | def forward(self, data: Dict) -> 
torch.Tensor: 31 | """ Forward 32 | Args: 33 | data: input data dict 34 | { 35 | 'imgs': imgs, [B, T, C, H, W] (e.g., [32, 3, 3, 256, 256]), must be float32 36 | 's0_ind': s0_ind, 37 | 's1_ind': s1_ind, 38 | 's2_ind': s2_ind} 39 | return: 40 | dict { 41 | 'loss': full_loss, 42 | 'metrics': metrics for logs} 43 | """ 44 | imgs = data['imgs'] 45 | s0_ind = data['s0_ind'] 46 | s1_ind = data['s1_ind'] 47 | s2_ind = data['s2_ind'] 48 | 49 | if imgs.shape[2:] != (3, 256, 256): 50 | preprocess = nn.Sequential( 51 | transforms.Resize(256, antialias=True), 52 | self.normlayer, 53 | ) 54 | else: 55 | preprocess = nn.Sequential( 56 | self.normlayer, 57 | ) 58 | imgs = preprocess(imgs) 59 | B, T = imgs.shape[:2] 60 | imgs = imgs.reshape(B*T, *imgs.shape[2:]) 61 | embs = self.backbone(imgs) 62 | embs = embs.reshape(B, T, *embs.shape[1:]) 63 | emb_s0 = embs[:, 0] 64 | emb_s1 = embs[:, 1] 65 | emb_s2 = embs[:, 2] 66 | 67 | #* compute metrics and full loss 68 | full_loss = 0 69 | metrics = dict() 70 | 71 | #* 1. Embdedding Norm loss 72 | loss_l1 = torch.linalg.norm(embs, ord=1, dim=-1).mean() 73 | loss_l2 = torch.linalg.norm(embs, ord=2, dim=-1).mean() 74 | full_loss += self.loss_weight.l1norm * loss_l1 75 | full_loss += self.loss_weight.l2norm * loss_l2 76 | metrics['loss_l1'] = loss_l1.item() 77 | metrics['loss_l2'] = loss_l2.item() 78 | 79 | #* 2. TCN Loss 80 | sim_0_1 = self.similarity(emb_s0, emb_s1) 81 | sim_1_2 = self.similarity(emb_s1, emb_s2) 82 | sim_0_2 = self.similarity(emb_s0, emb_s2) 83 | 84 | # negative samples 85 | sim_s0_neg = [] 86 | sim_s2_neg = [] 87 | perm = [i for i in range(B)] 88 | for _ in range(self.num_negatives): 89 | perm = [(i_perm + 1) % B for i_perm in perm] 90 | emb_s0_shuf = emb_s0[perm] 91 | emb_s2_shuf = emb_s2[perm] 92 | sim_s0_neg.append(self.similarity(emb_s0_shuf, emb_s0)) 93 | sim_s2_neg.append(self.similarity(emb_s2_shuf, emb_s2)) 94 | sim_s0_neg = torch.stack(sim_s0_neg, dim=-1) 95 | sim_s2_neg = torch.stack(sim_s2_neg, dim=-1) 96 | 97 | tcn_loss_1 = -torch.log(1e-6 + (torch.exp(sim_1_2) / (1e-6 + torch.exp(sim_0_2) + torch.exp(sim_1_2) + torch.exp(sim_s2_neg).sum(-1)))) 98 | tcn_loss_2 = -torch.log(1e-6 + (torch.exp(sim_0_1) / (1e-6 + torch.exp(sim_0_1) + torch.exp(sim_0_2) + torch.exp(sim_s0_neg).sum(-1)))) 99 | 100 | tcn_loss = ((tcn_loss_1 + tcn_loss_2) / 2.0).mean() 101 | metrics['loss_tcn'] = tcn_loss.item() 102 | metrics['alignment'] = (1.0 * (sim_0_2 < sim_1_2) * (sim_0_1 > sim_0_2)).float().mean().item() 103 | 104 | #* compute full loss 105 | full_loss += self.loss_weight.tcn * tcn_loss 106 | metrics['full_loss'] = full_loss.item() 107 | 108 | return {'loss': full_loss, 'metrics': metrics} 109 | 110 | def embedding(self, imgs: torch.Tensor) -> torch.Tensor: 111 | """ Embedding function 112 | """ 113 | if imgs.shape[1:] != (3, 256, 256): 114 | preprocess = nn.Sequential( 115 | transforms.Resize(256, antialias=True), 116 | self.normlayer, 117 | ) 118 | else: 119 | preprocess = nn.Sequential( 120 | self.normlayer, 121 | ) 122 | imgs = preprocess(imgs) 123 | embs = self.backbone(imgs) 124 | return embs 125 | 126 | def similarity(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 127 | """ Similarity function 128 | """ 129 | if self.similarity_type == 'l2': 130 | d = -torch.linalg.norm(x - y, dim=-1) 131 | return d 132 | elif self.similarity_type == 'cosine': 133 | x = F.normalize(x, dim=-1) 134 | y = F.normalize(y, dim=-1) 135 | d = torch.einsum('...i,...i->...', x, y) 136 | return d 137 | else: 138 | raise NotImplementedError 139 | 
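Editor's note on the model above: as the in-class comment says, AG2MANIP reuses the R3M architecture. A ResNet-50 backbone maps each frame to a `d_emb`-dimensional embedding, and the TCN term is an InfoNCE-style loss that pulls temporally ordered frame pairs together while pushing away shuffled in-batch negatives. Below is an illustrative usage sketch (not part of the repository): it instantiates the model from a hand-written config mirroring `repre_trainer/cfgs/model/ag2manip.yaml` and runs one forward/backward pass on random frames shaped like the r3m-style `EpicKitchen` samples. It assumes `repre_trainer/` is on `PYTHONPATH`; the batch size and frame indices are made up.

import torch
from omegaconf import OmegaConf
from models.model.ag2manip import AG2MANIP

cfg = OmegaConf.create({
    'd_emb': 1024,
    'backbone_type': 'resnet50',
    'similarity_type': 'l2',
    'num_negatives': 3,
    'loss_weight': {'tcn': 1.0, 'l1norm': 1e-5, 'l2norm': 1e-5},
})
model = AG2MANIP(cfg)

batch = {
    'imgs': torch.rand(4, 3, 3, 256, 256),   # B=4 clips, 3 frames (s0, s1, s2) each, float32
    's0_ind': torch.tensor([0, 0, 0, 0]),    # indices are carried through but unused by the loss
    's1_ind': torch.tensor([5, 5, 5, 5]),
    's2_ind': torch.tensor([9, 9, 9, 9]),
}
out = model(batch)          # {'loss': tensor, 'metrics': {'loss_l1', 'loss_l2', 'loss_tcn', 'alignment', 'full_loss'}}
out['loss'].backward()

The embedding and similarity code is shared verbatim with `r3m.py` (next file); `vip.py` follows the same pattern but optimizes a goal-conditioned value objective over 4-frame samples. What distinguishes AG2MANIP in the configs is the `data_type` field selecting ordinary 'rgb' frames or agent-agnostic 'agentago' frames for training.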
-------------------------------------------------------------------------------- /repre_trainer/models/model/r3m.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from einops import rearrange 3 | import torch 4 | import torchvision 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torchvision import transforms 8 | from omegaconf import DictConfig 9 | 10 | from models.base import MODEL 11 | 12 | @MODEL.register() 13 | class R3M(nn.Module): 14 | def __init__(self, cfg: DictConfig, *args, **kwargs) -> None: 15 | super(R3M, self).__init__() 16 | self.d_emb = cfg.d_emb 17 | self.backbone_type = cfg.backbone_type 18 | self.similarity_type = cfg.similarity_type 19 | self.num_negatives = cfg.num_negatives 20 | self.loss_weight = cfg.loss_weight 21 | 22 | self.normlayer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 23 | if self.backbone_type == 'resnet50': 24 | self.backbone = torchvision.models.resnet50(pretrained=False) 25 | self.backbone.fc = nn.Linear(2048, self.d_emb) 26 | else: 27 | raise NotImplementedError 28 | 29 | def forward(self, data: Dict) -> torch.Tensor: 30 | """ Forward 31 | Args: 32 | data: input data dict 33 | { 34 | 'imgs': imgs, [B, T, C, H, W] (e.g., [32, 3, 3, 256, 256]), must be float32 35 | 's0_ind': s0_ind, 36 | 's1_ind': s1_ind, 37 | 's2_ind': s2_ind} 38 | return: 39 | dict { 40 | 'loss': full_loss, 41 | 'metrics': metrics for logs} 42 | """ 43 | imgs = data['imgs'] 44 | s0_ind = data['s0_ind'] 45 | s1_ind = data['s1_ind'] 46 | s2_ind = data['s2_ind'] 47 | 48 | if imgs.shape[2:] != (3, 256, 256): 49 | preprocess = nn.Sequential( 50 | transforms.Resize(256, antialias=True), 51 | self.normlayer, 52 | ) 53 | else: 54 | preprocess = nn.Sequential( 55 | self.normlayer, 56 | ) 57 | imgs = preprocess(imgs) 58 | B, T = imgs.shape[:2] 59 | imgs = imgs.reshape(B*T, *imgs.shape[2:]) 60 | embs = self.backbone(imgs) 61 | embs = embs.reshape(B, T, *embs.shape[1:]) 62 | emb_s0 = embs[:, 0] 63 | emb_s1 = embs[:, 1] 64 | emb_s2 = embs[:, 2] 65 | 66 | #* compute metrics and full loss 67 | full_loss = 0 68 | metrics = dict() 69 | 70 | #* 1. Embdedding Norm loss 71 | loss_l1 = torch.linalg.norm(embs, ord=1, dim=-1).mean() 72 | loss_l2 = torch.linalg.norm(embs, ord=2, dim=-1).mean() 73 | full_loss += self.loss_weight.l1norm * loss_l1 74 | full_loss += self.loss_weight.l2norm * loss_l2 75 | metrics['loss_l1'] = loss_l1.item() 76 | metrics['loss_l2'] = loss_l2.item() 77 | 78 | #* 2. 
TCN Loss 79 | sim_0_1 = self.similarity(emb_s0, emb_s1) 80 | sim_1_2 = self.similarity(emb_s1, emb_s2) 81 | sim_0_2 = self.similarity(emb_s0, emb_s2) 82 | 83 | # negative samples 84 | sim_s0_neg = [] 85 | sim_s2_neg = [] 86 | perm = [i for i in range(B)] 87 | for _ in range(self.num_negatives): 88 | perm = [(i_perm + 1) % B for i_perm in perm] 89 | emb_s0_shuf = emb_s0[perm] 90 | emb_s2_shuf = emb_s2[perm] 91 | sim_s0_neg.append(self.similarity(emb_s0_shuf, emb_s0)) 92 | sim_s2_neg.append(self.similarity(emb_s2_shuf, emb_s2)) 93 | sim_s0_neg = torch.stack(sim_s0_neg, dim=-1) 94 | sim_s2_neg = torch.stack(sim_s2_neg, dim=-1) 95 | 96 | tcn_loss_1 = -torch.log(1e-6 + (torch.exp(sim_1_2) / (1e-6 + torch.exp(sim_0_2) + torch.exp(sim_1_2) + torch.exp(sim_s2_neg).sum(-1)))) 97 | tcn_loss_2 = -torch.log(1e-6 + (torch.exp(sim_0_1) / (1e-6 + torch.exp(sim_0_1) + torch.exp(sim_0_2) + torch.exp(sim_s0_neg).sum(-1)))) 98 | 99 | tcn_loss = ((tcn_loss_1 + tcn_loss_2) / 2.0).mean() 100 | metrics['loss_tcn'] = tcn_loss.item() 101 | metrics['alignment'] = (1.0 * (sim_0_2 < sim_1_2) * (sim_0_1 > sim_0_2)).float().mean().item() 102 | 103 | #* compute full loss 104 | full_loss += self.loss_weight.tcn * tcn_loss 105 | metrics['full_loss'] = full_loss.item() 106 | 107 | return {'loss': full_loss, 'metrics': metrics} 108 | 109 | def embedding(self, imgs: torch.Tensor) -> torch.Tensor: 110 | """ Embedding function 111 | """ 112 | if imgs.shape[1:] != (3, 256, 256): 113 | preprocess = nn.Sequential( 114 | transforms.Resize(256, antialias=True), 115 | self.normlayer, 116 | ) 117 | else: 118 | preprocess = nn.Sequential( 119 | self.normlayer, 120 | ) 121 | imgs = preprocess(imgs) 122 | embs = self.backbone(imgs) 123 | return embs 124 | 125 | def similarity(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 126 | """ Similarity function 127 | """ 128 | if self.similarity_type == 'l2': 129 | d = -torch.linalg.norm(x - y, dim=-1) 130 | return d 131 | elif self.similarity_type == 'cosine': 132 | x = F.normalize(x, dim=-1) 133 | y = F.normalize(y, dim=-1) 134 | d = torch.einsum('...i,...i->...', x, y) 135 | return d 136 | else: 137 | raise NotImplementedError 138 | -------------------------------------------------------------------------------- /repre_trainer/models/model/vip.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from einops import rearrange 3 | import torch 4 | import torchvision 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torchvision import transforms 8 | from omegaconf import DictConfig 9 | 10 | from models.base import MODEL 11 | 12 | @MODEL.register() 13 | class VIP(nn.Module): 14 | def __init__(self, cfg: DictConfig, *args, **kwargs) -> None: 15 | super(VIP, self).__init__() 16 | self.d_emb = cfg.d_emb 17 | self.backbone_type = cfg.backbone_type 18 | self.reward_type = cfg.reward_type 19 | self.similarity_type = cfg.similarity_type 20 | self.num_negatives = cfg.num_negatives 21 | self.loss_weight = cfg.loss_weight 22 | 23 | self.normlayer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 24 | if self.backbone_type == 'resnet50': 25 | self.backbone = torchvision.models.resnet50(pretrained=False) 26 | self.backbone.fc = nn.Linear(2048, self.d_emb) 27 | else: 28 | raise NotImplementedError 29 | 30 | def forward(self, data: Dict) -> torch.Tensor: 31 | """ Forward 32 | Args: 33 | data: input data dict 34 | { 35 | 'imgs': imgs, [B, T, C, H, W] (e.g., [32, 4, 3, 256, 256]), must be 
float32 36 | 'start_ind': start_ind, 37 | 'stop_ind': stop_ind, 38 | 's0_ind': s0_ind, 39 | 's1_ind': s1_ind} 40 | return: 41 | dict { 42 | 'loss': full_loss, 43 | 'metrics': metrics for logs} 44 | """ 45 | #* forward process 46 | imgs = data['imgs'] 47 | start_ind = data['start_ind'] 48 | stop_ind = data['stop_ind'] 49 | s0_ind = data['s0_ind'] 50 | s1_ind = data['s1_ind'] 51 | 52 | if imgs.shape[2:] != (3, 256, 256): 53 | preprocess = nn.Sequential( 54 | transforms.Resize(256, antialias=True), 55 | self.normlayer, 56 | ) 57 | else: 58 | preprocess = nn.Sequential( 59 | self.normlayer, 60 | ) 61 | imgs = preprocess(imgs) 62 | B, T = imgs.shape[:2] 63 | imgs = imgs.reshape(B*T, *imgs.shape[2:]) 64 | embs = self.backbone(imgs) 65 | embs = embs.reshape(B, T, embs.shape[-1]) # [B, T, d_emb] 66 | emb_start = embs[:, 0] 67 | emb_goal = embs[:, 1] 68 | emb_s0 = embs[:, 2] 69 | emb_s1 = embs[:, 3] 70 | 71 | #* compute metrics and full_loss 72 | full_loss = 0 73 | metrics = dict() 74 | 75 | #* 1. Embedding Norm Loss 76 | loss_l1 = torch.linalg.norm(embs, ord=1, dim=-1).mean() 77 | loss_l2 = torch.linalg.norm(embs, ord=2, dim=-1).mean() 78 | full_loss += self.loss_weight.l1norm * loss_l1 79 | full_loss += self.loss_weight.l2norm * loss_l2 80 | metrics['loss_l1'] = loss_l1.item() 81 | metrics['loss_l2'] = loss_l2.item() 82 | 83 | #* 2. VIP Loss 84 | v_o = self.similarity(emb_start, emb_goal) 85 | v_s0 = self.similarity(emb_s0, emb_goal) 86 | v_s1 = self.similarity(emb_s1, emb_goal) 87 | # compute reward (sparse version) 88 | reward = self.reward(start_ind, stop_ind, s0_ind, s1_ind) 89 | loss_vip = (1 - self.loss_weight.gamma) * ( - v_o.mean()) \ 90 | + torch.log(1e-6 + torch.mean(torch.exp( - (reward + self.loss_weight.gamma * v_s1 - v_s0)))) 91 | 92 | #* 3. 
Additional negative observations 93 | v_s0_neg = [] 94 | v_s1_neg = [] 95 | perm = [i for i in range(B)] 96 | for _ in range(self.num_negatives): 97 | perm = [(i_perm + 1) % B for i_perm in perm] 98 | emb_s0_shuf = emb_s0[perm] 99 | emb_s1_shuf = emb_s1[perm] 100 | v_s0_neg.append(self.similarity(emb_s0_shuf, emb_goal)) 101 | v_s1_neg.append(self.similarity(emb_s1_shuf, emb_goal)) 102 | if self.num_negatives > 0: 103 | v_s0_neg = torch.cat(v_s0_neg) 104 | v_s1_neg = torch.cat(v_s1_neg) 105 | reward_neg = - torch.ones_like(v_s0_neg, device=v_s0_neg.device) 106 | loss_vip += torch.log(1e-6 + torch.mean(torch.exp( - (reward_neg + self.loss_weight.gamma * v_s1_neg - v_s0_neg)))) 107 | metrics['loss_vip'] = loss_vip.item() 108 | # metrics['alignment'] = (1.0 * (v_s0 > v_o) * (v_s1 > v_s0)).float().mean().item() 109 | metrics['alignment'] = (0.5 * (v_s0 > v_o) + 0.5 * (v_s1 > v_o)).float().mean().item() 110 | 111 | #* compute full loss 112 | full_loss += loss_vip 113 | metrics['full_loss'] = full_loss.item() 114 | 115 | return {'loss': full_loss, 'metrics': metrics} 116 | 117 | def embedding(self, imgs: torch.Tensor) -> torch.Tensor: 118 | """ Embedding function 119 | Args: 120 | imgs: input tensor [B, C, H, W] 121 | """ 122 | if imgs.shape[1:] != (3, 256, 256): 123 | preprocess = nn.Sequential( 124 | transforms.Resize(256, antialias=True), 125 | self.normlayer, 126 | ) 127 | else: 128 | preprocess = nn.Sequential( 129 | self.normlayer, 130 | ) 131 | imgs = preprocess(imgs) 132 | embs = self.backbone(imgs) 133 | return embs 134 | 135 | def similarity(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 136 | """ Similarity function 137 | """ 138 | if self.similarity_type == 'l2': 139 | d = -torch.linalg.norm(x - y, dim=-1) 140 | return d 141 | elif self.similarity_type == 'cosine': 142 | x = F.normalize(x, dim=-1) 143 | y = F.normalize(y, dim=-1) 144 | d = torch.einsum('...i,...i->...', x, y) 145 | return d 146 | else: 147 | raise NotImplementedError 148 | 149 | def reward(self, start_ind, stop_ind, s0_ind, s1_ind) -> torch.Tensor: 150 | """ Reward function 151 | """ 152 | if self.reward_type == 'sparse': 153 | reward = (s0_ind == stop_ind).float() - 1 154 | return reward 155 | elif self.reward_type == 'dense': 156 | raise NotImplementedError 157 | else: 158 | raise NotImplementedError 159 | -------------------------------------------------------------------------------- /repre_trainer/utils/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def mkdir_if_not_exists(dir_name: str, recursive: bool=False) -> None: 4 | """ Make directory with the given dir_name 5 | Args: 6 | dir_name: input directory name that can be a path 7 | recursive: recursive directory creation 8 | """ 9 | if os.path.exists(dir_name): 10 | return 11 | 12 | if recursive: 13 | os.makedirs(dir_name) 14 | else: 15 | print(f'current path: {os.getcwd()}') 16 | os.mkdir(dir_name) -------------------------------------------------------------------------------- /repre_trainer/utils/misc.py: -------------------------------------------------------------------------------- 1 | import string 2 | import random 3 | from datetime import datetime 4 | from omegaconf import DictConfig 5 | 6 | def timestamp_str() -> str: 7 | """ Get current time stamp string 8 | """ 9 | now = datetime.now() 10 | return now.strftime("%Y-%m-%d_%H-%M-%S") 11 | 12 | def random_str(length: int=4) -> str: 13 | """ Generate random string with given length 14 | """ 15 | return 
''.join(random.choices(string.ascii_letters + string.digits, k=4)) 16 | 17 | -------------------------------------------------------------------------------- /repre_trainer/utils/plot.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | 3 | def singleton(cls): 4 | _instance = {} 5 | 6 | def inner(*args, **kwargs): 7 | if cls not in _instance: 8 | _instance[cls] = cls(*args, **kwargs) 9 | return _instance[cls] 10 | return inner 11 | 12 | @singleton 13 | class _Writer(): 14 | """ A singleton class that can hold the SummaryWriter Object. 15 | So we can initialize it once and use it everywhere. 16 | """ 17 | def __init__(self) -> None: 18 | self.writer = None 19 | 20 | def write(self, write_dict: dict) -> None: 21 | """ Write the input dict data into writer object. 22 | 23 | Args: 24 | write_dict: a dict object containing data that need to be plotted. 25 | Format is ```{key1: {'plot': bool, 'value': float, 'step': long}}```. 26 | `plot` means this value corresponding to this key needs to be plotted or not. 27 | `value` is the specific value. `step` is the training step. 28 | """ 29 | if self.writer is None: 30 | raise Exception('[ERR-CFG] Writer is None!') 31 | 32 | for key in write_dict.keys(): 33 | if write_dict[key]['plot']: 34 | self.writer.add_scalar(key, write_dict[key]['value'], write_dict[key]['step']) 35 | 36 | def setWriter(self, writer: SummaryWriter) -> None: 37 | self.writer = writer 38 | 39 | class Ploter(): 40 | """ Ploter class for providing static methods to write data into SummaryWriter. 41 | """ 42 | def __init__(self) -> None: 43 | pass 44 | 45 | @staticmethod 46 | def setWriter(writer: SummaryWriter) -> None: 47 | w = _Writer() 48 | w.setWriter(writer) 49 | 50 | @staticmethod 51 | def write(write_dict: dict) -> None: 52 | """ Plot input dict data. 53 | 54 | Args: 55 | write_dict: a dict object containing data that need to be plotted. 56 | Format is ```{key1: {'plot': bool, 'value': float, 'step': long}}```. 57 | `plot` means this value corresponding to this key needs to be plotted or not. 58 | `value` is the specific value. `step` is the training step. 59 | """ 60 | w = _Writer() 61 | w.write(write_dict) -------------------------------------------------------------------------------- /repre_trainer/utils/registry.py: -------------------------------------------------------------------------------- 1 | ## Acknowledgment: This code is modified from https://github.com/facebookresearch/fvcore/blob/main/fvcore/common/registry.py 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 4 | # pyre-ignore-all-errors[2,3] 5 | from typing import Any, Dict, Iterable, Iterator, Tuple 6 | 7 | from tabulate import tabulate 8 | 9 | 10 | class Registry(Iterable[Tuple[str, Any]]): 11 | """ 12 | The registry that provides name -> object mapping, to support third-party 13 | users' custom modules. 14 | 15 | To create a registry (e.g. a backbone registry): 16 | 17 | .. code-block:: python 18 | 19 | BACKBONE_REGISTRY = Registry('BACKBONE') 20 | 21 | To register an object: 22 | 23 | .. code-block:: python 24 | 25 | @BACKBONE_REGISTRY.register() 26 | class MyBackbone(): 27 | ... 28 | 29 | Or: 30 | 31 | .. 
code-block:: python 32 | 33 | BACKBONE_REGISTRY.register(MyBackbone) 34 | """ 35 | 36 | def __init__(self, name: str) -> None: 37 | """ 38 | Args: 39 | name (str): the name of this registry 40 | """ 41 | self._name: str = name 42 | self._obj_map: Dict[str, Any] = {} 43 | 44 | def _do_register(self, name: str, obj: Any) -> None: 45 | assert ( 46 | name not in self._obj_map 47 | ), "An object named '{}' was already registered in '{}' registry!".format( 48 | name, self._name 49 | ) 50 | self._obj_map[name] = obj 51 | 52 | def register(self, obj: Any = None) -> Any: 53 | """ 54 | Register the given object under the the name `obj.__name__`. 55 | Can be used as either a decorator or not. See docstring of this class for usage. 56 | """ 57 | if obj is None: 58 | # used as a decorator 59 | def deco(func_or_class: Any) -> Any: 60 | name = func_or_class.__name__ 61 | self._do_register(name, func_or_class) 62 | return func_or_class 63 | 64 | return deco 65 | 66 | # used as a function call 67 | name = obj.__name__ 68 | self._do_register(name, obj) 69 | 70 | def get(self, name: str) -> Any: 71 | ret = self._obj_map.get(name) 72 | if ret is None: 73 | raise KeyError( 74 | "No object named '{}' found in '{}' registry!".format(name, self._name) 75 | ) 76 | return ret 77 | 78 | def __contains__(self, name: str) -> bool: 79 | return name in self._obj_map 80 | 81 | def __repr__(self) -> str: 82 | table_headers = ["Names", "Objects"] 83 | table = tabulate( 84 | self._obj_map.items(), headers=table_headers, tablefmt="fancy_grid" 85 | ) 86 | return "Registry of {}:\n".format(self._name) + table 87 | 88 | def __iter__(self) -> Iterator[Tuple[str, Any]]: 89 | return iter(self._obj_map.items()) 90 | 91 | # pyre-fixme[4]: Attribute must be annotated. 92 | __str__ = __repr__ -------------------------------------------------------------------------------- /repres/ag2manip.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torchvision 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torchvision import transforms 8 | 9 | from repres.base.base_repre import BaseRepre 10 | 11 | class AG2MANIP(BaseRepre): 12 | 13 | def __init__(self, cfg_repre) -> None: 14 | super(AG2MANIP, self).__init__() 15 | self.goal_image = cfg_repre["goal_image"] 16 | self.device = cfg_repre["device"] 17 | self.batchsize = cfg_repre["batchsize"] 18 | 19 | self.d_emb = cfg_repre["d_emb"] 20 | self.backbone_type = cfg_repre['backbone_type'] 21 | self.similarity_type = cfg_repre['similarity_type'] 22 | 23 | if self.goal_image.dtype != torch.float32: 24 | raise TypeError("cfg_repre.goal_image.dtype must be torch.float32") 25 | self.goal_image = torch.tensor(self.goal_image, dtype=torch.float32) 26 | 27 | self.normlayer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 28 | if self.backbone_type == 'resnet50': 29 | self.backbone = torchvision.models.resnet50(pretrained=False) 30 | self.backbone.fc = nn.Linear(2048, self.d_emb) 31 | else: 32 | raise NotImplementedError 33 | 34 | #* load pre-trained ckpts 35 | if cfg_repre['ckpt_dir']: 36 | print(f'Require a pre-trained ckpt dir for representation model {self.__class__.__name__}') 37 | self.ckpt_dir = cfg_repre['ckpt_dir'] 38 | print(f'Loading ckpt from {self.ckpt_dir}') 39 | checkpoint = torch.load(os.path.join(self.ckpt_dir, 'model.pth'))['model'] 40 | self.load_state_dict(checkpoint) 41 | self.to(self.device) 42 | self.eval() 43 | 44 | #* compute goal image 
embedding 45 | self.goal_image = self.goal_image.to(self.device) 46 | self.goal_emb = self.embedding(self.goal_image.unsqueeze(0).permute(0, 3, 1, 2)) # (1, 1000) 47 | 48 | @torch.no_grad() 49 | def forward(self, x): 50 | """ 51 | x: [to torch.float32] (batch_size, 256, 256, 3) 52 | """ 53 | x = x.to(self.device) 54 | if x.dtype != torch.float32: 55 | raise TypeError("x.dtype must be torch.float32") 56 | 57 | x = x.permute(0, 3, 1, 2) # (batch_size, 3, 256, 256) 58 | embs = [] 59 | for i in range(0, x.shape[0], self.batchsize): 60 | embs.append(self.embedding(x[i:i+self.batchsize])) 61 | embs = torch.cat(embs, dim=0) 62 | value = self.similarity(embs, self.goal_emb) 63 | 64 | return value 65 | 66 | @torch.no_grad() 67 | def embedding(self, imgs: torch.Tensor) -> torch.Tensor: 68 | """ Embedding function 69 | """ 70 | if imgs.shape[1:] != (3, 256, 256): 71 | preprocess = nn.Sequential( 72 | transforms.Resize(256, antialias=True), 73 | self.normlayer, 74 | ) 75 | else: 76 | preprocess = nn.Sequential( 77 | self.normlayer, 78 | ) 79 | imgs = preprocess(imgs) 80 | embs = self.backbone(imgs) 81 | return embs 82 | 83 | def similarity(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 84 | """ Similarity function #! nagative similarity 85 | """ 86 | if self.similarity_type == 'l2': 87 | d = -torch.linalg.norm(x - y, dim=-1) 88 | return -d 89 | elif self.similarity_type == 'cosine': 90 | x = F.normalize(x, dim=-1) 91 | y = F.normalize(y, dim=-1) 92 | d = torch.einsum('...i,...i->...', x, y) 93 | return -d 94 | else: 95 | raise NotImplementedError -------------------------------------------------------------------------------- /repres/base/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/repres/base/__init__.py -------------------------------------------------------------------------------- /repres/base/base_repre.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class BaseRepre(nn.Module): 7 | 8 | def __init__(self) -> None: 9 | super(BaseRepre, self).__init__() 10 | 11 | def forward(self, x): 12 | raise NotImplementedError 13 | -------------------------------------------------------------------------------- /repres/r3m.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torchvision 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torchvision import transforms 8 | 9 | from repres.base.base_repre import BaseRepre 10 | 11 | class R3M(BaseRepre): 12 | 13 | def __init__(self, cfg_repre) -> None: 14 | super(R3M, self).__init__() 15 | self.goal_image = cfg_repre["goal_image"] 16 | self.device = cfg_repre["device"] 17 | self.batchsize = cfg_repre["batchsize"] 18 | 19 | self.d_emb = cfg_repre["d_emb"] 20 | self.backbone_type = cfg_repre['backbone_type'] 21 | self.similarity_type = cfg_repre['similarity_type'] 22 | 23 | if self.goal_image.dtype != torch.float32: 24 | raise TypeError("cfg_repre.goal_image.dtype must be torch.float32") 25 | self.goal_image = torch.tensor(self.goal_image, dtype=torch.float32) 26 | 27 | self.normlayer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 28 | if self.backbone_type == 'resnet50': 29 | self.backbone = torchvision.models.resnet50(pretrained=False) 30 | 
self.backbone.fc = nn.Linear(2048, self.d_emb) 31 | else: 32 | raise NotImplementedError 33 | 34 | #* load pre-trained ckpts 35 | if cfg_repre['ckpt_dir']: 36 | print(f'Require a pre-trained ckpt dir for representation model {self.__class__.__name__}') 37 | self.ckpt_dir = os.path.join('ckpts', cfg_repre['ckpt_dir'], 'ckpts') 38 | print(f'Loading ckpt from {self.ckpt_dir}') 39 | checkpoint = torch.load(os.path.join(self.ckpt_dir, 'model.pth'))['model'] 40 | self.load_state_dict(checkpoint) 41 | self.to(self.device) 42 | self.eval() 43 | 44 | #* compute goal image embedding 45 | self.goal_image = self.goal_image.to(self.device) 46 | self.goal_emb = self.embedding(self.goal_image.unsqueeze(0).permute(0, 3, 1, 2)) # (1, 1000) 47 | 48 | @torch.no_grad() 49 | def forward(self, x): 50 | """ 51 | x: [to torch.float32] (batch_size, 256, 256, 3) 52 | """ 53 | x = x.to(self.device) 54 | if x.dtype != torch.float32: 55 | raise TypeError("x.dtype must be torch.float32") 56 | 57 | x = x.permute(0, 3, 1, 2) # (batch_size, 3, 256, 256) 58 | embs = [] 59 | for i in range(0, x.shape[0], self.batchsize): 60 | embs.append(self.embedding(x[i:i+self.batchsize])) 61 | embs = torch.cat(embs, dim=0) 62 | value = self.similarity(embs, self.goal_emb) 63 | 64 | return value 65 | 66 | @torch.no_grad() 67 | def embedding(self, imgs: torch.Tensor) -> torch.Tensor: 68 | """ Embedding function 69 | """ 70 | if imgs.shape[1:] != (3, 256, 256): 71 | preprocess = nn.Sequential( 72 | transforms.Resize(256, antialias=True), 73 | self.normlayer, 74 | ) 75 | else: 76 | preprocess = nn.Sequential( 77 | self.normlayer, 78 | ) 79 | imgs = preprocess(imgs) 80 | embs = self.backbone(imgs) 81 | return embs 82 | 83 | def similarity(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 84 | """ Similarity function #! 
nagative similarity 85 | """ 86 | if self.similarity_type == 'l2': 87 | d = -torch.linalg.norm(x - y, dim=-1) 88 | return -d 89 | elif self.similarity_type == 'cosine': 90 | x = F.normalize(x, dim=-1) 91 | y = F.normalize(y, dim=-1) 92 | d = torch.einsum('...i,...i->...', x, y) 93 | return -d 94 | else: 95 | raise NotImplementedError -------------------------------------------------------------------------------- /repres/vip.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torchvision 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torchvision import transforms 8 | 9 | from repres.base.base_repre import BaseRepre 10 | 11 | class VIP(BaseRepre): 12 | 13 | def __init__(self, cfg_repre) -> None: 14 | super(VIP, self).__init__() 15 | self.goal_image = cfg_repre["goal_image"] 16 | self.device = cfg_repre["device"] 17 | self.batchsize = cfg_repre["batchsize"] 18 | 19 | self.d_emb = cfg_repre["d_emb"] 20 | self.backbone_type = cfg_repre['backbone_type'] 21 | self.similarity_type = cfg_repre['similarity_type'] 22 | 23 | if self.goal_image.dtype != torch.float32: 24 | raise TypeError("cfg_repre.goal_image.dtype must be torch.float32") 25 | self.goal_image = torch.tensor(self.goal_image, dtype=torch.float32) 26 | 27 | self.normlayer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 28 | if self.backbone_type == 'resnet50': 29 | self.backbone = torchvision.models.resnet50(pretrained=False) 30 | self.backbone.fc = nn.Linear(2048, self.d_emb) 31 | else: 32 | raise NotImplementedError 33 | 34 | #* load pre-trained ckpts 35 | if cfg_repre['ckpt_dir']: 36 | print(f'Require a pre-trained ckpt dir for representation model {self.__class__.__name__}') 37 | self.ckpt_dir = os.path.join('ckpts', cfg_repre['ckpt_dir'], 'ckpts') 38 | print(f'Loading ckpt from {self.ckpt_dir}') 39 | checkpoint = torch.load(os.path.join(self.ckpt_dir, 'model.pth'))['model'] 40 | self.load_state_dict(checkpoint) 41 | self.to(self.device) 42 | self.eval() 43 | 44 | #* compute goal image embedding 45 | self.goal_image = self.goal_image.to(self.device) 46 | self.goal_emb = self.embedding(self.goal_image.unsqueeze(0).permute(0, 3, 1, 2)) # (1, 1000) 47 | 48 | @torch.no_grad() 49 | def forward(self, x): 50 | """ 51 | x: [to torch.float32] (batch_size, 256, 256, 3) 52 | """ 53 | x = x.to(self.device) 54 | if x.dtype != torch.float32: 55 | raise TypeError("x.dtype must be torch.float32") 56 | 57 | x = x.permute(0, 3, 1, 2) # (batch_size, 3, 256, 256) 58 | embs = [] 59 | for i in range(0, x.shape[0], self.batchsize): 60 | embs.append(self.embedding(x[i:i+self.batchsize])) 61 | embs = torch.cat(embs, dim=0) 62 | value = self.similarity(embs, self.goal_emb) 63 | 64 | return value 65 | 66 | @torch.no_grad() 67 | def embedding(self, imgs: torch.Tensor) -> torch.Tensor: 68 | """ Embedding function 69 | """ 70 | if imgs.shape[1:] != (3, 256, 256): 71 | preprocess = nn.Sequential( 72 | transforms.Resize(256, antialias=True), 73 | self.normlayer, 74 | ) 75 | else: 76 | preprocess = nn.Sequential( 77 | self.normlayer, 78 | ) 79 | imgs = preprocess(imgs) 80 | embs = self.backbone(imgs) 81 | return embs 82 | 83 | def similarity(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 84 | """ Similarity function #! 
nagative similarity 85 | """ 86 | if self.similarity_type == 'l2': 87 | d = -torch.linalg.norm(x - y, dim=-1) 88 | return -d 89 | elif self.similarity_type == 'cosine': 90 | x = F.normalize(x, dim=-1) 91 | y = F.normalize(y, dim=-1) 92 | d = torch.einsum('...i,...i->...', x, y) 93 | return -d 94 | else: 95 | raise NotImplementedError -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | aiosignal==1.3.1 3 | annotated-types==0.6.0 4 | antlr4-python3-runtime==4.9.3 5 | anyio==4.2.0 6 | appdirs==1.4.4 7 | argcomplete==2.0.0 8 | attrs==19.2.0 9 | cachetools==5.3.2 10 | certifi==2024.2.2 11 | charset-normalizer==3.3.2 12 | click==8.1.7 13 | cloudpickle==2.2.1 14 | cmake==3.28.1 15 | distlib==0.3.8 16 | distro==1.9.0 17 | docker-pycreds==0.4.0 18 | einops==0.3.0 19 | exceptiongroup==1.2.0 20 | filelock==3.13.1 21 | flake8-blind-except==0.2.1 22 | flake8-builtins==2.0.0 23 | flake8-class-newline==1.6.0 24 | flake8-comprehensions==3.10.0 25 | flake8-deprecated==2.0.1 26 | flake8-docstrings==1.6.0 27 | flake8-import-order==0.18.1 28 | flake8-quotes==3.3.1 29 | frozenlist==1.4.1 30 | gitdb==4.0.11 31 | GitPython==3.1.41 32 | google-auth==2.27.0 33 | google-auth-oauthlib==1.0.0 34 | grpcio==1.62.1 35 | gym==0.23.1 36 | gym-notices==0.0.8 37 | h11==0.14.0 38 | httpcore==1.0.2 39 | httpx==0.26.0 40 | hydra-core==1.3.2 41 | idna==3.6 42 | imageio==2.34.0 43 | importlib-metadata==7.0.1 44 | importlib-resources==6.1.1 45 | iniconfig==1.1.1 46 | # Editable install with no version control (isaacgym==1.0rc4) 47 | -e /home/puhao/dev/site-packages/isaacgym/python 48 | isaacgym-stubs==1.0rc4 49 | lit==17.0.6 50 | loguru==0.7.0 51 | Markdown==3.5.2 52 | MarkupSafe==2.1.5 53 | mpmath==1.3.0 54 | msgpack==1.0.7 55 | networkx==3.1 56 | ninja==1.11.1.1 57 | numpy==1.22.0 58 | nvidia-cublas-cu11==11.10.3.66 59 | nvidia-cuda-cupti-cu11==11.7.101 60 | nvidia-cuda-nvrtc-cu11==11.7.99 61 | nvidia-cuda-runtime-cu11==11.7.99 62 | nvidia-cudnn-cu11==8.5.0.96 63 | nvidia-cufft-cu11==10.9.0.58 64 | nvidia-curand-cu11==10.2.10.91 65 | nvidia-cusolver-cu11==11.4.0.1 66 | nvidia-cusparse-cu11==11.7.4.91 67 | nvidia-nccl-cu11==2.14.3 68 | nvidia-nvtx-cu11==11.7.91 69 | oauthlib==3.2.2 70 | omegaconf==2.3.0 71 | opencv-python==4.9.0.80 72 | packaging==24.0 73 | pandas==2.0.3 74 | pathtools==0.1.2 75 | pillow==10.3.0 76 | platformdirs==4.2.0 77 | pluggy==1.4.0 78 | promise==2.3 79 | protobuf==4.25.3 80 | psutil==5.9.8 81 | py==1.11.0 82 | pyasn1==0.6.0 83 | pyasn1_modules==0.4.0 84 | pydantic==2.6.0 85 | pydantic_core==2.16.1 86 | pytest==7.1.3 87 | pytest-repeat==0.9.1 88 | pytest-rerunfailures==10.2 89 | python-dateutil==2.9.0.post0 90 | pytz==2024.1 91 | PyVirtualDisplay==3.0 92 | PyYAML==6.0.1 93 | requests==2.31.0 94 | requests-oauthlib==1.3.1 95 | rsa==4.9 96 | scipy==1.10.1 97 | sentry-sdk==1.40.0 98 | setproctitle==1.3.3 99 | shortuuid==1.0.11 100 | six==1.16.0 101 | smmap==5.0.1 102 | sniffio==1.3.0 103 | sympy==1.12 104 | tabulate==0.9.0 105 | tensorboard==2.14.0 106 | tensorboard-data-server==0.7.2 107 | tensorboardX==2.6.2.2 108 | termcolor==2.3.0 109 | tomli==2.0.1 110 | torch==1.13.1+cu117 111 | torchvision==0.14.1+cu117 112 | tqdm==4.66.1 113 | trimesh==3.22.0 114 | triton==2.0.0 115 | typing_extensions==4.10.0 116 | tzdata==2024.1 117 | urllib3==2.2.0 118 | virtualenv==20.25.0 119 | wandb==0.16.5 120 | Werkzeug==3.0.1 121 | zipp==3.18.1 122 | 
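A minimal usage sketch for the planning-time value models under repres/ (AG2MANIP, R3M, and VIP expose the same interface): the constructor takes a config dict carrying the goal image and backbone settings, and forward() maps a batch of (256, 256, 3) float32 frames to goal-distance values. The concrete values below (d_emb, batchsize, device) are illustrative assumptions rather than the repo's defaults, and ckpt_dir is left empty so the sketch runs with random weights; real runs load a trained checkpoint, presumably configured via cfgs/repre/ag2manip/config.yaml.

# Usage sketch only -- config values are assumptions, not the repository's defaults.
import torch
from repres.ag2manip import AG2MANIP

cfg_repre = {
    "goal_image": torch.rand(256, 256, 3, dtype=torch.float32),  # stand-in for a rendered goal frame (H, W, C)
    "device": "cpu",               # assumption: CPU for the sketch; experiments run on CUDA
    "batchsize": 8,                # assumption: chunk size used when embedding frames
    "d_emb": 32,                   # assumption: embedding dimension; a trained checkpoint fixes the real value
    "backbone_type": "resnet50",
    "similarity_type": "l2",
    "ckpt_dir": "",                # empty -> skip checkpoint loading (random weights, shape-checking only)
}

model = AG2MANIP(cfg_repre)
frames = torch.rand(4, 256, 256, 3, dtype=torch.float32)  # batch of channels-last frames, as forward() expects
values = model(frames)             # negated similarity to the goal embedding, shape (4,)
print(values.shape)

Per the "negative similarity" note in these wrappers, the returned values act like a distance to the goal image, so lower means the frame is closer to the goal.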
-------------------------------------------------------------------------------- /tasks/base/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # NVIDIA CORPORATION and its licensors retain all intellectual property 3 | # and proprietary rights in and to this software, related documentation 4 | # and any modifications thereto. Any use, reproduction, disclosure or 5 | # distribution of this software and related documentation without an express 6 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 7 | -------------------------------------------------------------------------------- /tasks/base/vec_task.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # NVIDIA CORPORATION and its licensors retain all intellectual property 3 | # and proprietary rights in and to this software, related documentation 4 | # and any modifications thereto. Any use, reproduction, disclosure or 5 | # distribution of this software and related documentation without an express 6 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 7 | 8 | from gym import spaces 9 | 10 | from isaacgym import gymtorch 11 | from isaacgym.torch_utils import to_torch 12 | import torch 13 | import numpy as np 14 | 15 | 16 | # VecEnv Wrapper for RL training 17 | class VecTask(): 18 | def __init__(self, task, rl_device, clip_observations=5.0, clip_actions=1.0): 19 | self.task = task 20 | 21 | self.num_environments = task.num_envs 22 | self.num_agents = 1 # used for multi-agent environments 23 | self.num_observations = task.num_obs 24 | self.num_states = task.num_states 25 | self.num_actions = task.num_actions 26 | 27 | self.obs_space = spaces.Box(np.ones(self.num_obs) * -np.Inf, np.ones(self.num_obs) * np.Inf) 28 | self.state_space = spaces.Box(np.ones(self.num_states) * -np.Inf, np.ones(self.num_states) * np.Inf) 29 | self.act_space = spaces.Box(np.ones(self.num_actions) * -1., np.ones(self.num_actions) * 1.) 
30 | 31 | self.clip_obs = clip_observations 32 | self.clip_actions = clip_actions 33 | self.rl_device = rl_device 34 | 35 | print("RL device: ", rl_device) 36 | 37 | def step(self, actions): 38 | raise NotImplementedError 39 | 40 | def reset(self): 41 | raise NotImplementedError 42 | 43 | def get_number_of_agents(self): 44 | return self.num_agents 45 | 46 | @property 47 | def observation_space(self): 48 | return self.obs_space 49 | 50 | @property 51 | def action_space(self): 52 | return self.act_space 53 | 54 | @property 55 | def num_envs(self): 56 | return self.num_environments 57 | 58 | @property 59 | def num_acts(self): 60 | return self.num_actions 61 | 62 | @property 63 | def num_obs(self): 64 | return self.num_observations 65 | 66 | 67 | # C++ CPU Class 68 | class VecTaskCPU(VecTask): 69 | def __init__(self, task, rl_device, sync_frame_time=False, clip_observations=5.0, clip_actions=1.0): 70 | super().__init__(task, rl_device, clip_observations=clip_observations, clip_actions=clip_actions) 71 | self.sync_frame_time = sync_frame_time 72 | 73 | def step(self, actions): 74 | actions = actions.cpu().numpy() 75 | self.task.render(self.sync_frame_time) 76 | 77 | obs, rewards, resets, extras = self.task.step(np.clip(actions, -self.clip_actions, self.clip_actions)) 78 | 79 | return (to_torch(np.clip(obs, -self.clip_obs, self.clip_obs), dtype=torch.float, device=self.rl_device), 80 | to_torch(rewards, dtype=torch.float, device=self.rl_device), 81 | to_torch(resets, dtype=torch.uint8, device=self.rl_device), []) 82 | 83 | def reset(self): 84 | actions = 0.01 * (1 - 2 * np.random.rand(self.num_envs, self.num_actions)).astype('f') 85 | 86 | # step the simulator 87 | obs, rewards, resets, extras = self.task.step(actions) 88 | 89 | return to_torch(np.clip(obs, -self.clip_obs, self.clip_obs), dtype=torch.float, device=self.rl_device) 90 | 91 | 92 | # C++ GPU Class 93 | class VecTaskGPU(VecTask): 94 | def __init__(self, task, rl_device, clip_observations=5.0, clip_actions=1.0): 95 | super().__init__(task, rl_device, clip_observations=clip_observations, clip_actions=clip_actions) 96 | 97 | self.obs_tensor = gymtorch.wrap_tensor(self.task.obs_tensor, counts=(self.task.num_envs, self.task.num_obs)) 98 | self.rewards_tensor = gymtorch.wrap_tensor(self.task.rewards_tensor, counts=(self.task.num_envs,)) 99 | self.resets_tensor = gymtorch.wrap_tensor(self.task.resets_tensor, counts=(self.task.num_envs,)) 100 | 101 | def step(self, actions): 102 | self.task.render(False) 103 | actions_clipped = torch.clamp(actions, -self.clip_actions, self.clip_actions) 104 | actions_tensor = gymtorch.unwrap_tensor(actions_clipped) 105 | 106 | self.task.step(actions_tensor) 107 | 108 | return torch.clamp(self.obs_tensor, -self.clip_obs, self.clip_obs), self.rewards_tensor, self.resets_tensor, [] 109 | 110 | def reset(self): 111 | actions = 0.01 * (1 - 2 * torch.rand([self.task.num_envs, self.task.num_actions], dtype=torch.float32, device=self.rl_device)) 112 | actions_tensor = gymtorch.unwrap_tensor(actions) 113 | 114 | # step the simulator 115 | self.task.step(actions_tensor) 116 | 117 | return torch.clamp(self.obs_tensor, -self.clip_obs, self.clip_obs) 118 | 119 | 120 | # Python CPU/GPU Class 121 | class VecTaskPython(VecTask): 122 | 123 | def get_state(self): 124 | return torch.clamp(self.task.states_buf, -self.clip_obs, self.clip_obs).to(self.rl_device) 125 | 126 | def step(self, actions): 127 | actions_tensor = torch.clamp(actions, -self.clip_actions, self.clip_actions) 128 | 129 | self.task.step(actions_tensor) 130 | 
131 | return torch.clamp(self.task.obs_buf, -self.clip_obs, self.clip_obs).to(self.rl_device), self.task.rew_buf.to(self.rl_device), self.task.reset_buf.to(self.rl_device), self.task.extras 132 | 133 | def reset(self): 134 | actions = 0.01 * (1 - 2 * torch.rand([self.task.num_envs, self.task.num_actions], dtype=torch.float32, device=self.rl_device)) 135 | 136 | # step the simulator 137 | self.task.step(actions) 138 | 139 | return torch.clamp(self.task.obs_buf, -self.clip_obs, self.clip_obs).to(self.rl_device) 140 | 141 | class VecTaskPythonArm(VecTask) : 142 | 143 | def get_state(self): 144 | return torch.clamp(self.task.states_buf, -self.clip_obs, self.clip_obs).to(self.rl_device) 145 | 146 | def step(self, actions): 147 | actions_tensor = torch.clamp(actions, -self.clip_actions, self.clip_actions) 148 | 149 | self.task.step(actions_tensor) 150 | 151 | return torch.clamp(self.task.obs_buf, -self.clip_obs, self.clip_obs).to(self.rl_device), self.task.rew_buf.to(self.rl_device), self.task.reset_buf.to(self.rl_device), self.task.extras 152 | 153 | def reset(self): 154 | actions = 0.01 * (1 - 2 * torch.rand([self.task.num_envs, self.task.num_actions], dtype=torch.float32, device=self.rl_device)) 155 | 156 | # step the simulator 157 | self.task.reset() 158 | self.task.step(actions) 159 | 160 | return torch.clamp(self.task.obs_buf, -self.clip_obs, self.clip_obs).to(self.rl_device) -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/CloseHingecabinet@default@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/CloseHingecabinet@default@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/CloseHingecabinet@left@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/CloseHingecabinet@left@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/CloseHingecabinet@right@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/CloseHingecabinet@right@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/CloseMicrowave@default@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/CloseMicrowave@default@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/CloseMicrowave@left@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/CloseMicrowave@left@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/CloseMicrowave@right@woa.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/CloseMicrowave@right@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/CloseSlidecabinet@default@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/CloseSlidecabinet@default@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/CloseSlidecabinet@left@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/CloseSlidecabinet@left@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/CloseSlidecabinet@right@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/CloseSlidecabinet@right@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/MoveKettle@default@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/MoveKettle@default@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/MoveKettle@left@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/MoveKettle@left@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/MoveKettle@right@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/MoveKettle@right@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/OpenHingecabinet@default@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/OpenHingecabinet@default@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/OpenHingecabinet@left@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/OpenHingecabinet@left@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/OpenHingecabinet@right@woa.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/OpenHingecabinet@right@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/OpenMicrowave@default@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/OpenMicrowave@default@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/OpenMicrowave@left@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/OpenMicrowave@left@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/OpenMicrowave@right@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/OpenMicrowave@right@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/OpenSlidecabinet@default@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/OpenSlidecabinet@default@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/OpenSlidecabinet@left@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/OpenSlidecabinet@left@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/OpenSlidecabinet@right@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/OpenSlidecabinet@right@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/PickupKettle@default@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/PickupKettle@default@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/PickupKettle@left@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/PickupKettle@left@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/PickupKettle@right@woa.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/PickupKettle@right@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/TurnoffSwitch@default@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/TurnoffSwitch@default@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/TurnoffSwitch@left@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/TurnoffSwitch@left@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/TurnoffSwitch@right@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/TurnoffSwitch@right@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/TurnonSwitch@default@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/TurnonSwitch@default@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/TurnonSwitch@left@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/TurnonSwitch@left@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image/TurnonSwitch@right@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image/TurnonSwitch@right@woa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/CloseHingecabinet@default@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/CloseHingecabinet@default@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/CloseHingecabinet@left@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/CloseHingecabinet@left@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/CloseHingecabinet@right@wa.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/CloseHingecabinet@right@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/CloseMicrowave@default@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/CloseMicrowave@default@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/CloseMicrowave@left@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/CloseMicrowave@left@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/CloseMicrowave@right@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/CloseMicrowave@right@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/CloseSlidecabinet@default@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/CloseSlidecabinet@default@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/CloseSlidecabinet@left@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/CloseSlidecabinet@left@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/CloseSlidecabinet@right@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/CloseSlidecabinet@right@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/MoveKettle@default@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/MoveKettle@default@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/MoveKettle@left@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/MoveKettle@left@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/MoveKettle@right@wa.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/MoveKettle@right@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/OpenHingecabinet@default@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/OpenHingecabinet@default@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/OpenHingecabinet@left@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/OpenHingecabinet@left@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/OpenHingecabinet@right@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/OpenHingecabinet@right@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/OpenMicrowave@default@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/OpenMicrowave@default@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/OpenMicrowave@left@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/OpenMicrowave@left@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/OpenMicrowave@right@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/OpenMicrowave@right@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/OpenSlidecabinet@default@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/OpenSlidecabinet@default@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/OpenSlidecabinet@left@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/OpenSlidecabinet@left@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/OpenSlidecabinet@right@wa.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/OpenSlidecabinet@right@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/PickupKettle@default@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/PickupKettle@default@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/PickupKettle@left@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/PickupKettle@left@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/PickupKettle@right@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/PickupKettle@right@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/TurnoffSwitch@default@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/TurnoffSwitch@default@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/TurnoffSwitch@left@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/TurnoffSwitch@left@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/TurnoffSwitch@right@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/TurnoffSwitch@right@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/TurnonSwitch@default@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/TurnonSwitch@default@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/TurnonSwitch@left@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/TurnonSwitch@left@wa.png -------------------------------------------------------------------------------- /tasks/frankakitchen/goals_image_wa/TurnonSwitch@right@wa.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/frankakitchen/goals_image_wa/TurnonSwitch@right@wa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image/CloseDoor@default@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image/CloseDoor@default@woa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image/CloseDoor@left@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image/CloseDoor@left@woa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image/CloseDoor@right@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image/CloseDoor@right@woa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image/InsertPeg@default@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image/InsertPeg@default@woa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image/InsertPeg@left@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image/InsertPeg@left@woa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image/InsertPeg@right@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image/InsertPeg@right@woa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image/OpenDoor@default@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image/OpenDoor@default@woa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image/OpenDoor@left@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image/OpenDoor@left@woa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image/OpenDoor@right@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image/OpenDoor@right@woa.png -------------------------------------------------------------------------------- 
/tasks/maniskill/goals_image/PickupClutterycb@default@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image/PickupClutterycb@default@woa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image/PickupClutterycb@left@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image/PickupClutterycb@left@woa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image/PickupClutterycb@right@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image/PickupClutterycb@right@woa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image/PickupCube@default@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image/PickupCube@default@woa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image/PickupCube@left@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image/PickupCube@left@woa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image/PickupCube@right@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image/PickupCube@right@woa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image/StackCube@default@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image/StackCube@default@woa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image/StackCube@left@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image/StackCube@left@woa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image/StackCube@right@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image/StackCube@right@woa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image/TurnLeftfaucet@default@woa.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image/TurnLeftfaucet@default@woa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image/TurnLeftfaucet@left@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image/TurnLeftfaucet@left@woa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image/TurnLeftfaucet@right@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image/TurnLeftfaucet@right@woa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image/TurnRightfaucet@default@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image/TurnRightfaucet@default@woa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image/TurnRightfaucet@left@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image/TurnRightfaucet@left@woa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image/TurnRightfaucet@right@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image/TurnRightfaucet@right@woa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image_wa/CloseDoor@default@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image_wa/CloseDoor@default@wa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image_wa/CloseDoor@left@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image_wa/CloseDoor@left@wa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image_wa/CloseDoor@right@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image_wa/CloseDoor@right@wa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image_wa/InsertPeg@default@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image_wa/InsertPeg@default@wa.png 
-------------------------------------------------------------------------------- /tasks/maniskill/goals_image_wa/InsertPeg@left@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image_wa/InsertPeg@left@wa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image_wa/InsertPeg@right@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image_wa/InsertPeg@right@wa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image_wa/OpenDoor@default@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image_wa/OpenDoor@default@wa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image_wa/OpenDoor@left@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image_wa/OpenDoor@left@wa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image_wa/OpenDoor@right@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image_wa/OpenDoor@right@wa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image_wa/PickCube@right@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image_wa/PickCube@right@wa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image_wa/PickupClutterycb@default@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image_wa/PickupClutterycb@default@wa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image_wa/PickupClutterycb@left@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image_wa/PickupClutterycb@left@wa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image_wa/PickupClutterycb@right@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image_wa/PickupClutterycb@right@wa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image_wa/PickupCube@default@wa.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image_wa/PickupCube@default@wa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image_wa/PickupCube@left@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image_wa/PickupCube@left@wa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image_wa/StackCube@default@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image_wa/StackCube@default@wa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image_wa/StackCube@left@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image_wa/StackCube@left@wa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image_wa/StackCube@right@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image_wa/StackCube@right@wa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image_wa/TurnLeftfaucet@default@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image_wa/TurnLeftfaucet@default@wa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image_wa/TurnLeftfaucet@left@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image_wa/TurnLeftfaucet@left@wa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image_wa/TurnLeftfaucet@right@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image_wa/TurnLeftfaucet@right@wa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image_wa/TurnRightfaucet@default@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image_wa/TurnRightfaucet@default@wa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image_wa/TurnRightfaucet@left@wa.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image_wa/TurnRightfaucet@left@wa.png -------------------------------------------------------------------------------- /tasks/maniskill/goals_image_wa/TurnRightfaucet@right@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/maniskill/goals_image_wa/TurnRightfaucet@right@wa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image/CloseDishwasher@default@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image/CloseDishwasher@default@woa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image/CloseDishwasher@left@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image/CloseDishwasher@left@woa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image/CloseDishwasher@right@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image/CloseDishwasher@right@woa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image/LiftLid@default@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image/LiftLid@default@woa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image/LiftLid@left@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image/LiftLid@left@woa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image/LiftLid@right@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image/LiftLid@right@woa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image/OpenDishwasher@default@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image/OpenDishwasher@default@woa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image/OpenDishwasher@left@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image/OpenDishwasher@left@woa.png 
-------------------------------------------------------------------------------- /tasks/partmanip/goals_image/OpenDishwasher@right@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image/OpenDishwasher@right@woa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image/PressButton@default@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image/PressButton@default@woa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image/PressButton@left@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image/PressButton@left@woa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image/PressButton@right@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image/PressButton@right@woa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image/PullWooddrawer@default@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image/PullWooddrawer@default@woa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image/PullWooddrawer@left@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image/PullWooddrawer@left@woa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image/PullWooddrawer@right@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image/PullWooddrawer@right@woa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image/PushWooddrawer@default@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image/PushWooddrawer@default@woa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image/PushWooddrawer@left@woa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image/PushWooddrawer@left@woa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image/PushWooddrawer@right@woa.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image/PushWooddrawer@right@woa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image_wa/CloseDishwasher@default@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image_wa/CloseDishwasher@default@wa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image_wa/CloseDishwasher@left@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image_wa/CloseDishwasher@left@wa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image_wa/CloseDishwasher@right@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image_wa/CloseDishwasher@right@wa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image_wa/LiftLid@default@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image_wa/LiftLid@default@wa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image_wa/LiftLid@left@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image_wa/LiftLid@left@wa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image_wa/LiftLid@right@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image_wa/LiftLid@right@wa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image_wa/OpenDishwasher@default@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image_wa/OpenDishwasher@default@wa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image_wa/OpenDishwasher@left@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image_wa/OpenDishwasher@left@wa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image_wa/OpenDishwasher@right@wa.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image_wa/OpenDishwasher@right@wa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image_wa/PressButton@default@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image_wa/PressButton@default@wa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image_wa/PressButton@left@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image_wa/PressButton@left@wa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image_wa/PressButton@right@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image_wa/PressButton@right@wa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image_wa/PullWooddrawer@default@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image_wa/PullWooddrawer@default@wa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image_wa/PullWooddrawer@left@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image_wa/PullWooddrawer@left@wa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image_wa/PullWooddrawer@right@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image_wa/PullWooddrawer@right@wa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image_wa/PushWooddrawer@default@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image_wa/PushWooddrawer@default@wa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image_wa/PushWooddrawer@left@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image_wa/PushWooddrawer@left@wa.png -------------------------------------------------------------------------------- /tasks/partmanip/goals_image_wa/PushWooddrawer@right@wa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/tasks/partmanip/goals_image_wa/PushWooddrawer@right@wa.png 
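Note on the goal-image files listed above: each one follows the pattern {TaskClassName}@{camera}@{suffix}.png, with camera one of default/left/right, stored under goals_image (suffix woa) or goals_image_wa (suffix wa); the wa/woa suffixes appear to distinguish goals rendered with and without the agent visible, though that reading is an assumption, not something the listing itself states. A minimal, hypothetical sketch of how such a path could be assembled (the helper name goal_image_path is illustrative and not part of the repo):

import os

def goal_image_path(suite: str, task: str, camera: str = "default", with_agent: bool = False) -> str:
    # Hypothetical helper: reproduces the observed naming convention, e.g.
    # tasks/maniskill/goals_image/PickupCube@default@woa.png
    folder = "goals_image_wa" if with_agent else "goals_image"
    suffix = "wa" if with_agent else "woa"
    return os.path.join("tasks", suite, folder, f"{task}@{camera}@{suffix}.png")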
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import numpy as np 3 | 4 | from utils.config import set_np_formatting, set_seed, get_args, parse_sim_params, load_cfg 5 | from utils.parse_task import parse_task 6 | from utils.process_sarl import process_sarl 7 | from utils.process_offrl import * 8 | 9 | 10 | def train(): 11 | print(f"Algorithm: {args.algo}") 12 | 13 | if args.algo in ['ppo', 'ddpg', 'sac', 'td3', 'trpo']: 14 | if args.save_traj: 15 | cfg['env']['numEnvs'] = 1 16 | task, env = parse_task(args, cfg, cfg_train, sim_params, agent_index=None) 17 | 18 | cfg_train['save_traj'] = args.save_traj 19 | sarl = eval('process_sarl')(args, env, cfg_train, logdir) 20 | 21 | iterations = cfg_train["learn"]["max_iterations"] 22 | if args.max_iterations > 0: 23 | iterations = args.max_iterations 24 | 25 | ## initialize wandb 26 | if not args.disable_wandb and not args.test: 27 | task_env, task_name, repre_name = args.task.split("@") 28 | camera_name = args.camera 29 | wandb.init( 30 | project=f'ag2manip', 31 | name=f'{camera_name}@{repre_name}.seed{env.task.cfg["seed"]}', 32 | config={ 33 | 'cfg': cfg, 34 | 'cfg_train': cfg_train, 35 | 'cfg_repre': cfg_repre, 36 | 'args': args 37 | } 38 | ) 39 | 40 | sarl.run(num_learning_iterations=iterations, log_interval=cfg_train["learn"]["save_interval"]) 41 | 42 | elif args.algo in ["td3_bc", "bcq", "iql", "ppo_collect"]: 43 | raise NotImplementedError 44 | 45 | else: 46 | raise NotImplementedError 47 | 48 | 49 | if __name__ == '__main__': 50 | set_np_formatting() 51 | args = get_args() 52 | cfg, cfg_train, logdir, cfg_repre = load_cfg(args) 53 | sim_params = parse_sim_params(args, cfg, cfg_train) 54 | set_seed(cfg_train.get("seed", -1), cfg_train.get("torch_deterministic", False)) 55 | train() -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoyao-Li/Ag2Manip/ac16340ff1923b91da72329c018fd51d1b8b37e8/utils/__init__.py -------------------------------------------------------------------------------- /utils/logger/tools.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import csv 5 | import os 6 | import re 7 | from collections import defaultdict 8 | 9 | import numpy as np 10 | import tqdm 11 | from tensorboard.backend.event_processing import event_accumulator 12 | 13 | 14 | def find_all_files(root_dir, pattern): 15 | """Find all files under root_dir according to relative pattern.""" 16 | file_list = [] 17 | for dirname, _, files in os.walk(root_dir): 18 | for f in files: 19 | absolute_path = os.path.join(dirname, f) 20 | if re.match(pattern, absolute_path): 21 | file_list.append(absolute_path) 22 | return file_list 23 | 24 | 25 | def group_files(file_list, pattern): 26 | res = defaultdict(list) 27 | for f in file_list: 28 | match = re.search(pattern, f) 29 | key = match.group() if match else '' 30 | res[key].append(f) 31 | return res 32 | 33 | 34 | def csv2numpy(csv_file): 35 | csv_dict = defaultdict(list) 36 | reader = csv.DictReader(open(csv_file)) 37 | for row in reader: 38 | for k, v in row.items(): 39 | csv_dict[k].append(eval(v)) 40 | return {k: np.array(v) for k, v in csv_dict.items()} 41 | 42 | 43 | def convert_tfevents_to_csv(root_dir, alg_type, 
env_num, env_step, refresh=False): 44 | """Recursively convert test/rew from all tfevent files under root_dir to csv. 45 | 46 | This function assumes that there is at most one tfevents file in each directory 47 | and will add suffix to that directory. 48 | 49 | :param bool refresh: re-create csv file under any condition. 50 | """ 51 | if alg_type == 'sarl': 52 | tfevent_files = find_all_files(root_dir, re.compile(r"^.*tfevents.*$")) 53 | elif alg_type == 'marl': 54 | tfevent_files = find_all_files(root_dir, re.compile(r"^.*tfevents.*.13$")) 55 | else: 56 | raise ValueError(f"wrong alg_type: {alg_type}") 57 | 58 | 59 | print(f"Converting {len(tfevent_files)} tfevents files under {root_dir} ...") 60 | result = {} 61 | with tqdm.tqdm(tfevent_files) as t: 62 | for tfevent_file in t: 63 | t.set_postfix(file=tfevent_file) 64 | output_file = os.path.join(os.path.split(tfevent_file)[0], "test_rew.csv") 65 | 66 | if os.path.exists(output_file) and not refresh: 67 | content = list(csv.reader(open(output_file, "r"))) 68 | if content[0] == ["env_step", "rew", "time"]: 69 | for i in range(1, len(content)): 70 | content[i] = list(map(eval, content[i])) 71 | result[output_file] = content 72 | continue 73 | 74 | 75 | ea = event_accumulator.EventAccumulator(tfevent_file) 76 | ea.Reload() 77 | initial_time = ea._first_event_timestamp 78 | content = [["env_step", "rew", "time"]] 79 | 80 | 81 | if alg_type == "sarl": 82 | for i, test_rew in enumerate(ea.scalars.Items("Train/mean_reward")): 83 | content.append( 84 | [ 85 | test_rew.step * env_step * env_num, # if env is to lift a pot, change it as test_rew.step * 20 * 2048 86 | round(test_rew.value, 4), 87 | round(test_rew.wall_time - initial_time, 4), 88 | ] 89 | ) 90 | 91 | elif alg_type == 'marl': 92 | for i, test_rew in enumerate(ea.scalars.Items("train_episode_rewards")): 93 | content.append( 94 | [ 95 | test_rew.step, 96 | round(test_rew.value, 4), 97 | round(test_rew.wall_time - initial_time, 4), 98 | ] 99 | ) 100 | 101 | csv.writer(open(output_file, 'w')).writerows(content) 102 | result[output_file] = content 103 | return result 104 | 105 | 106 | def merge_csv(csv_files, root_dir, remove_zero=False): 107 | """Merge result in csv_files into a single csv file.""" 108 | assert len(csv_files) > 0 109 | if remove_zero: 110 | for v in csv_files.values(): 111 | if v[1][0] == 0: 112 | v.pop(1) 113 | 114 | sorted_keys = sorted(csv_files.keys()) 115 | sorted_values = [csv_files[k][1:] for k in sorted_keys] 116 | content = [ 117 | ["env_step", "rew", "rew:shaded"] + 118 | list(map(lambda f: "rew:" + os.path.relpath(f, root_dir), sorted_keys)) 119 | ] 120 | 121 | for rows in zip(*sorted_values): 122 | array = np.array(rows) 123 | # assert len(set(array[:, 0])) == 1, (set(array[:, 0]), array[:, 0]) 124 | # line = [rows[0][0], round(array[:, 1].mean(), 4), round(array[:, 1].std(), 4)] 125 | line = [round(array[:, 0].mean(), 4), round(array[:, 1].mean(), 4), round(array[:, 1].std(), 4)] 126 | line += array[:, 1].tolist() 127 | content.append(line) 128 | output_path = os.path.join(root_dir, f"test_rew_{len(csv_files)}seeds.csv") 129 | print(f"Output merged csv file to {output_path} with {len(content[1:])} lines.") 130 | csv.writer(open(output_path, "w")).writerows(content) 131 | 132 | 133 | if __name__ == "__main__": 134 | parser = argparse.ArgumentParser() 135 | 136 | parser.add_argument( 137 | '--alg-name', 138 | type=str, 139 | default='happo' 140 | ) 141 | parser.add_argument( 142 | '--alg-type', 143 | type=str, 144 | default='marl', 145 | help="single-agent: sarl; multi-agent: marl" 146
| ) 147 | parser.add_argument( 148 | '--env-num', 149 | type=int, 150 | default=2048, 151 | help="the number of parallel simulations" 152 | ) 153 | parser.add_argument( 154 | '--env-step', 155 | type=int, 156 | default=8, 157 | help="the environment lifting a pot : 20; other environments: 8" 158 | ) 159 | parser.add_argument( 160 | '--refresh', 161 | action="store_true", 162 | help="Re-generate all csv files instead of using existing one." 163 | ) 164 | parser.add_argument( 165 | '--remove-zero', 166 | action="store_true", 167 | help="Remove the data point of env_step == 0." 168 | ) 169 | parser.add_argument('--root-dir', type=str) 170 | args = parser.parse_args() 171 | 172 | args.root_dir = '{}/{}'.format(args.root_dir,args.alg_name) 173 | 174 | csv_files = convert_tfevents_to_csv(args.root_dir, args.alg_type, args.env_num, args.env_step, args.refresh) 175 | merge_csv(csv_files, args.root_dir, args.remove_zero) 176 | -------------------------------------------------------------------------------- /utils/o3dviewer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d as o3d 3 | 4 | class PointcloudVisualizer() : 5 | 6 | def __init__(self) -> None: 7 | self.vis = o3d.visualization.VisualizerWithKeyCallback() 8 | self.vis.create_window() 9 | # self.vis.register_key_callback(key, your_update_function) 10 | 11 | def add_geometry(self, cloud) : 12 | self.vis.add_geometry(cloud) 13 | 14 | def update(self, cloud): 15 | #Your update routine 16 | self.vis.update_geometry(cloud) 17 | self.vis.update_renderer() 18 | self.vis.poll_events() 19 | 20 | if __name__ == "__main__" : 21 | 22 | visualizer = PointcloudVisualizer() 23 | cloud = o3d.io.read_point_cloud("../../../assets/dataset/one_door_cabinet/46145_link_0/point_sample/full_PC.ply") 24 | visualizer.add_geometry(cloud) 25 | while True : 26 | print("update") 27 | visualizer.update(cloud) 28 | xyz = np.asarray(cloud.points) 29 | xyz *= 1.001 -------------------------------------------------------------------------------- /utils/package_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # NVIDIA CORPORATION and its licensors retain all intellectual property 3 | # and proprietary rights in and to this software, related documentation 4 | # and any modifications thereto. Any use, reproduction, disclosure or 5 | # distribution of this software and related documentation without an express 6 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
7 | 8 | from ast import arg 9 | import numpy as np 10 | import random 11 | 12 | from bidexhands.utils.config import set_np_formatting, set_seed, get_args, parse_sim_params, load_cfg 13 | from bidexhands.utils.parse_task import parse_task 14 | from bidexhands.utils.process_sarl import process_sarl 15 | from bidexhands.utils.process_marl import process_MultiAgentRL, get_AgentIndex 16 | from bidexhands.utils.process_mtrl import * 17 | from bidexhands.utils.process_metarl import * 18 | from bidexhands.utils.process_offrl import * 19 | 20 | def make(task_name, algo): 21 | set_np_formatting() 22 | args = get_args(task_name=task_name, algo=algo) 23 | cfg, cfg_train, logdir = load_cfg(args) 24 | sim_params = parse_sim_params(args, cfg, cfg_train) 25 | set_seed(cfg_train.get("seed", -1), cfg_train.get("torch_deterministic", False)) 26 | 27 | print("Algorithm: ", args.algo) 28 | agent_index = get_AgentIndex(cfg) 29 | 30 | if args.algo in ["mappo", "happo", "hatrpo","maddpg","ippo"]: 31 | # maddpg exists a bug now 32 | args.task_type = "MultiAgent" 33 | 34 | task, env = parse_task(args, cfg, cfg_train, sim_params, agent_index) 35 | 36 | elif args.algo in ["ppo","ddpg","sac","td3","trpo"]: 37 | task, env = parse_task(args, cfg, cfg_train, sim_params, agent_index) 38 | 39 | 40 | elif args.algo in ["mtppo", "random"]: 41 | args.task_type = "MultiTask" 42 | 43 | task, env = parse_task(args, cfg, cfg_train, sim_params, agent_index) 44 | 45 | elif args.algo in ["mamlppo"]: 46 | args.task_type = "Meta" 47 | task, env = parse_task(args, cfg, cfg_train, sim_params, agent_index) 48 | 49 | elif args.algo in ["td3_bc", "bcq", "iql", "ppo_collect"]: 50 | task, env = parse_task(args, cfg, cfg_train, sim_params, agent_index) 51 | 52 | else: 53 | print("Unrecognized algorithm!\nAlgorithm should be one of: [happo, hatrpo, mappo,ippo,maddpg,sac,td3,trpo,ppo,ddpg, mtppo, random, mamlppo, td3_bc, bcq, iql, ppo_collect]") 54 | 55 | 56 | return env 57 | 58 | if __name__ == '__main__': 59 | set_np_formatting() 60 | args = get_args() 61 | cfg, cfg_train, logdir = load_cfg(args) 62 | sim_params = parse_sim_params(args, cfg, cfg_train) 63 | set_seed(cfg_train.get("seed", -1), cfg_train.get("torch_deterministic", False)) 64 | 65 | -------------------------------------------------------------------------------- /utils/parse_task.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # NVIDIA CORPORATION and its licensors retain all intellectual property 3 | # and proprietary rights in and to this software, related documentation 4 | # and any modifications thereto. Any use, reproduction, disclosure or 5 | # distribution of this software and related documentation without an express 6 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
7 | 8 | # from bidexhands.tasks.shadow_hand_over import ShadowHandOver 9 | # from bidexhands.tasks.shadow_hand_catch_underarm import ShadowHandCatchUnderarm 10 | # from bidexhands.tasks.shadow_hand_two_catch_underarm import ShadowHandTwoCatchUnderarm 11 | # from bidexhands.tasks.shadow_hand_catch_abreast import ShadowHandCatchAbreast 12 | # from bidexhands.tasks.shadow_hand_lift_underarm import ShadowHandLiftUnderarm 13 | # from bidexhands.tasks.shadow_hand_catch_over2underarm import ShadowHandCatchOver2Underarm 14 | # from bidexhands.tasks.shadow_hand_door_close_inward import ShadowHandDoorCloseInward 15 | # from bidexhands.tasks.shadow_hand_door_close_outward import ShadowHandDoorCloseOutward 16 | # from bidexhands.tasks.shadow_hand_door_open_inward import ShadowHandDoorOpenInward 17 | # from bidexhands.tasks.shadow_hand_door_open_outward import ShadowHandDoorOpenOutward 18 | # from bidexhands.tasks.shadow_hand_bottle_cap import ShadowHandBottleCap 19 | # from bidexhands.tasks.shadow_hand_push_block import ShadowHandPushBlock 20 | # from bidexhands.tasks.shadow_hand_swing_cup import ShadowHandSwingCup 21 | # from bidexhands.tasks.shadow_hand_grasp_and_place import ShadowHandGraspAndPlace 22 | # from bidexhands.tasks.shadow_hand_scissors import ShadowHandScissors 23 | # from bidexhands.tasks.shadow_hand_switch import ShadowHandSwitch 24 | # from bidexhands.tasks.shadow_hand_pen import ShadowHandPen 25 | # from bidexhands.tasks.shadow_hand_re_orientation import ShadowHandReOrientation 26 | # from bidexhands.tasks.shadow_hand_kettle import ShadowHandKettle 27 | # from bidexhands.tasks.shadow_hand_block_stack import ShadowHandBlockStack 28 | 29 | # # Allegro hand 30 | # from bidexhands.tasks.allegro_hand_over import AllegroHandOver 31 | # from bidexhands.tasks.allegro_hand_catch_underarm import AllegroHandCatchUnderarm 32 | 33 | # # Meta 34 | # from bidexhands.tasks.shadow_hand_meta.shadow_hand_meta_mt1 import ShadowHandMetaMT1 35 | # from bidexhands.tasks.shadow_hand_meta.shadow_hand_meta_ml1 import ShadowHandMetaML1 36 | # from bidexhands.tasks.shadow_hand_meta.shadow_hand_meta_mt4 import ShadowHandMetaMT4 37 | 38 | # from bidexhands.tasks.hand_base.vec_task import VecTaskCPU, VecTaskGPU, VecTaskPython, VecTaskPythonArm 39 | # from bidexhands.tasks.hand_base.multi_vec_task import MultiVecTaskPython, SingleVecTaskPythonArm 40 | # from bidexhands.tasks.hand_base.multi_task_vec_task import MultiTaskVecTaskPython 41 | # from bidexhands.tasks.hand_base.meta_vec_task import MetaVecTaskPython 42 | # from bidexhands.tasks.hand_base.vec_task_rlgames import RLgamesVecTaskPython 43 | 44 | from tasks.base.vec_task import VecTaskCPU, VecTaskGPU, VecTaskPython, VecTaskPythonArm 45 | 46 | from importlib import import_module 47 | from utils.config import warn_task_name 48 | import json 49 | 50 | 51 | def parse_task(args, cfg, cfg_train, sim_params, agent_index): 52 | 53 | # create native task and pass custom config 54 | device_id = args.device_id 55 | rl_device = args.rl_device 56 | 57 | cfg["seed"] = cfg_train.get("seed", -1) 58 | cfg_task = cfg["env"] 59 | cfg_task["seed"] = cfg["seed"] 60 | 61 | if args.task_type == "C++": 62 | raise NotImplementedError("C++ task is not supported yet") 63 | 64 | elif args.task_type == "Python": 65 | print("Python") 66 | 67 | task_env, task_name, task_repre = args.task.split("@") 68 | camera = args.camera 69 | 70 | try: 71 | task_name_voc = task_name.split("_") 72 | task_name_voc = [word.capitalize() for word in task_name_voc] 73 | task_name = "".join(task_name_voc) 74 
| Module = import_module(f"tasks.{task_env}.{task_name}") 75 | Task = getattr(Module, task_name) 76 | 77 | task = Task( 78 | cfg=cfg, 79 | sim_params=sim_params, 80 | physics_engine=args.physics_engine, 81 | device_type=args.device, 82 | device_id=device_id, 83 | camera=camera, 84 | headless=args.headless,) 85 | except NameError as e: 86 | print(e) 87 | warn_task_name() 88 | 89 | env = VecTaskPython(task, rl_device) 90 | # if args.task == "OneFrankaCabinet" : 91 | # env = VecTaskPythonArm(task, rl_device) 92 | # else : 93 | # env = VecTaskPython(task, rl_device) 94 | 95 | elif args.task_type == "MultiAgent": 96 | raise NotImplementedError("MultiAgent task is not supported yet") 97 | 98 | elif args.task_type == "MultiTask": 99 | raise NotImplementedError("MultiTask task is not supported yet") 100 | 101 | elif args.task_type == "Meta": 102 | raise NotImplementedError("Meta task is not supported yet") 103 | 104 | elif args.task_type == "RLgames": 105 | raise NotImplementedError("RLgames task is not supported yet") 106 | 107 | return task, env 108 | 109 | -------------------------------------------------------------------------------- /utils/parse_task_plan.py: -------------------------------------------------------------------------------- 1 | from tasks.base.vec_task import VecTaskCPU, VecTaskGPU, VecTaskPython, VecTaskPythonArm 2 | 3 | from importlib import import_module 4 | from utils.config import warn_task_name 5 | import json 6 | 7 | def parse_task_plan(args, cfg, sim_params): 8 | 9 | # create native task and pass custom config 10 | device_id = args.device_id 11 | 12 | if args.task_type == "C++": 13 | raise NotImplementedError("C++ task is not supported yet") 14 | 15 | elif args.task_type == "Python": 16 | print("Python") 17 | 18 | task_env, task_name, task_repre = args.task.split("@") 19 | camera = args.camera 20 | 21 | try: 22 | task_name_voc = task_name.split("_") 23 | task_name_voc = [word.capitalize() for word in task_name_voc] 24 | task_name = "".join(task_name_voc) 25 | Module = import_module(f"tasks.{task_env}.{task_name}") 26 | Task = getattr(Module, task_name) 27 | 28 | task = Task( 29 | cfg=cfg, 30 | sim_params=sim_params, 31 | physics_engine=args.physics_engine, 32 | device_type=args.device, 33 | device_id=device_id, 34 | camera=camera, 35 | headless=args.headless,) 36 | except NameError as e: 37 | print(e) 38 | warn_task_name() 39 | 40 | env = VecTaskPython(task, device_id) 41 | # if args.task == "OneFrankaCabinet" : 42 | # env = VecTaskPythonArm(task, rl_device) 43 | # else : 44 | # env = VecTaskPython(task, rl_device) 45 | 46 | elif args.task_type == "MultiAgent": 47 | raise NotImplementedError("MultiAgent task is not supported yet") 48 | 49 | elif args.task_type == "MultiTask": 50 | raise NotImplementedError("MultiTask task is not supported yet") 51 | 52 | elif args.task_type == "Meta": 53 | raise NotImplementedError("Meta task is not supported yet") 54 | 55 | elif args.task_type == "RLgames": 56 | raise NotImplementedError("RLgames task is not supported yet") 57 | 58 | return task, env 59 | 60 | -------------------------------------------------------------------------------- /utils/process_offrl.py: -------------------------------------------------------------------------------- 1 | 2 | def process_td3_bc(args, env, cfg_train, logdir): 3 | from bidexhands.algorithms.offrl.td3_bc import TD3_BC 4 | learn_cfg = cfg_train["learn"] 5 | is_testing = learn_cfg["test"] 6 | # is_testing = True 7 | # Override resume and testing flags if they are passed as parameters. 
8 | if args.model_dir != "": 9 | is_testing = True 10 | chkpt_path = args.model_dir 11 | 12 | """Set up the PPO system for training or inferencing.""" 13 | td3_bc = TD3_BC(vec_env=env, 14 | device=env.rl_device, 15 | discount = learn_cfg["discount"], 16 | tau = learn_cfg["tau"], 17 | policy_freq = learn_cfg["policy_freq"], 18 | alpha = learn_cfg["alpha"], 19 | batch_size = learn_cfg["batch_size"], 20 | max_timesteps = learn_cfg["max_timesteps"], 21 | iterations = learn_cfg["iterations"], 22 | log_dir=logdir, 23 | datatype = args.datatype) 24 | 25 | if is_testing and args.model_dir != "": 26 | print("Loading model from {}".format(chkpt_path)) 27 | td3_bc.test(chkpt_path) 28 | elif args.model_dir != "": 29 | print("Loading model from {}".format(chkpt_path)) 30 | td3_bc.load(chkpt_path) 31 | 32 | return td3_bc 33 | 34 | def process_bcq(args, env, cfg_train, logdir): 35 | from bidexhands.algorithms.offrl.bcq import BCQ 36 | learn_cfg = cfg_train["learn"] 37 | is_testing = learn_cfg["test"] 38 | # is_testing = True 39 | # Override resume and testing flags if they are passed as parameters. 40 | if args.model_dir != "": 41 | is_testing = True 42 | chkpt_path = args.model_dir 43 | 44 | """Set up the PPO system for training or inferencing.""" 45 | bcq = BCQ(vec_env=env, 46 | device=env.rl_device, 47 | discount = learn_cfg["discount"], 48 | tau = learn_cfg["tau"], 49 | lmbda = learn_cfg["lmbda"], 50 | phi = learn_cfg["phi"], 51 | batch_size = learn_cfg["batch_size"], 52 | max_timesteps = learn_cfg["max_timesteps"], 53 | iterations = learn_cfg["iterations"], 54 | log_dir=logdir, 55 | datatype = args.datatype) 56 | 57 | if is_testing and args.model_dir != "": 58 | print("Loading model from {}".format(chkpt_path)) 59 | bcq.test(chkpt_path) 60 | elif args.model_dir != "": 61 | print("Loading model from {}".format(chkpt_path)) 62 | bcq.load(chkpt_path) 63 | 64 | return bcq 65 | 66 | def process_iql(args, env, cfg_train, logdir): 67 | from bidexhands.algorithms.offrl.iql import IQL 68 | learn_cfg = cfg_train["learn"] 69 | is_testing = learn_cfg["test"] 70 | # is_testing = True 71 | # Override resume and testing flags if they are passed as parameters. 72 | if args.model_dir != "": 73 | is_testing = True 74 | chkpt_path = args.model_dir 75 | 76 | """Set up the PPO system for training or inferencing.""" 77 | iql = IQL(vec_env=env, 78 | device=env.rl_device, 79 | discount = learn_cfg["discount"], 80 | tau = learn_cfg["tau"], 81 | expectile = learn_cfg["expectile"], 82 | beta = learn_cfg["beta"], 83 | scale = learn_cfg["scale"], 84 | batch_size = learn_cfg["batch_size"], 85 | max_timesteps = learn_cfg["max_timesteps"], 86 | iterations = learn_cfg["iterations"], 87 | log_dir=logdir, 88 | datatype = args.datatype) 89 | 90 | if is_testing and args.model_dir != "": 91 | print("Loading model from {}".format(chkpt_path)) 92 | iql.test(chkpt_path) 93 | elif args.model_dir != "": 94 | print("Loading model from {}".format(chkpt_path)) 95 | iql.load(chkpt_path) 96 | 97 | return iql 98 | 99 | def process_ppo_collect(args, env, cfg_train, logdir): 100 | from bidexhands.algorithms.offrl.ppo_collect import PPO, ActorCritic 101 | learn_cfg = cfg_train["learn"] 102 | is_testing = learn_cfg["test"] 103 | # is_testing = True 104 | # Override resume and testing flags if they are passed as parameters. 
105 | if args.model_dir != "": 106 | is_testing = True 107 | chkpt_path = args.model_dir 108 | 109 | logdir = logdir + "_seed{}".format(env.task.cfg["seed"]) 110 | 111 | """Set up the PPO system for training or inferencing.""" 112 | ppo_collect = PPO(vec_env=env, 113 | actor_critic_class=ActorCritic, 114 | num_transitions_per_env=learn_cfg["nsteps"], 115 | num_learning_epochs=learn_cfg["noptepochs"], 116 | num_mini_batches=learn_cfg["nminibatches"], 117 | clip_param=learn_cfg["cliprange"], 118 | gamma=learn_cfg["gamma"], 119 | lam=learn_cfg["lam"], 120 | init_noise_std=learn_cfg.get("init_noise_std", 0.3), 121 | value_loss_coef=learn_cfg.get("value_loss_coef", 2.0), 122 | entropy_coef=learn_cfg["ent_coef"], 123 | learning_rate=learn_cfg["optim_stepsize"], 124 | max_grad_norm=learn_cfg.get("max_grad_norm", 2.0), 125 | use_clipped_value_loss=learn_cfg.get("use_clipped_value_loss", False), 126 | schedule=learn_cfg.get("schedule", "fixed"), 127 | desired_kl=learn_cfg.get("desired_kl", None), 128 | model_cfg=cfg_train["policy"], 129 | device=env.rl_device, 130 | sampler=learn_cfg.get("sampler", 'sequential'), 131 | log_dir=logdir, 132 | is_testing=is_testing, 133 | print_log=learn_cfg["print_log"], 134 | apply_reset=False, 135 | asymmetric=(env.num_states > 0), 136 | data_size=learn_cfg["data_size"] 137 | ) 138 | 139 | # ppo.test("/home/hp-3070/bi-dexhands/bi-dexhands/logs/shadow_hand_lift_underarm2/ppo/ppo_seed2/model_40000.pt") 140 | if is_testing and args.model_dir != "": 141 | print("Loading model from {}".format(chkpt_path)) 142 | ppo_collect.test(chkpt_path) 143 | elif args.model_dir != "": 144 | print("Loading model from {}".format(chkpt_path)) 145 | ppo_collect.load(chkpt_path) 146 | 147 | return ppo_collect -------------------------------------------------------------------------------- /utils/process_sarl.py: -------------------------------------------------------------------------------- 1 | from algos.rl.ppo import PPO 2 | # from algos.rl.sac import SAC 3 | # from algos.rl.td3 import TD3 4 | # from algos.rl.ddpg import DDPG 5 | # from algos.rl.trpo import TRPO 6 | 7 | def process_sarl(args, env, cfg_train, logdir): 8 | learn_cfg = cfg_train["learn"] 9 | is_testing = learn_cfg["test"] 10 | # is_testing = True 11 | # Override resume and testing flags if they are passed as parameters. 12 | if args.model_dir != "": 13 | is_testing = True 14 | chkpt_path = args.model_dir 15 | 16 | if args.max_iterations != -1: 17 | cfg_train["learn"]["max_iterations"] = args.max_iterations 18 | 19 | logdir = logdir + ".{}".format(env.task.cfg["seed"]) 20 | 21 | """Set up the algo system for training or inferencing.""" 22 | model = eval(args.algo.upper())(vec_env=env, 23 | cfg_train = cfg_train, 24 | device=env.rl_device, 25 | sampler=learn_cfg.get("sampler", 'sequential'), 26 | log_dir=logdir, 27 | is_testing=is_testing, 28 | print_log=learn_cfg["print_log"], 29 | apply_reset=False, 30 | asymmetric=(env.num_states > 0) 31 | ) 32 | 33 | # ppo.test("/home/hp-3070/logs/demo/scissors/ppo_seed0/model_6000.pt") 34 | if is_testing and args.model_dir != "": 35 | print("Loading model from {}".format(chkpt_path)) 36 | model.test(chkpt_path) 37 | elif args.model_dir != "": 38 | print("Loading model from {}".format(chkpt_path)) 39 | model.load(chkpt_path) 40 | 41 | return model -------------------------------------------------------------------------------- /utils/torch_jit_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. 
All rights reserved. 2 | # NVIDIA CORPORATION and its licensors retain all intellectual property 3 | # and proprietary rights in and to this software, related documentation 4 | # and any modifications thereto. Any use, reproduction, disclosure or 5 | # distribution of this software and related documentation without an express 6 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 7 | 8 | import numpy as np 9 | from isaacgym import gymtorch 10 | from isaacgym import gymapi 11 | from isaacgym.torch_utils import * 12 | import torch 13 | 14 | 15 | @torch.jit.script 16 | def compute_heading_and_up( 17 | torso_rotation, inv_start_rot, to_target, vec0, vec1, up_idx 18 | ): 19 | # type: (Tensor, Tensor, Tensor, Tensor, Tensor, int) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor] 20 | num_envs = torso_rotation.shape[0] 21 | target_dirs = normalize(to_target) 22 | 23 | torso_quat = quat_mul(torso_rotation, inv_start_rot) 24 | up_vec = get_basis_vector(torso_quat, vec1).view(num_envs, 3) 25 | heading_vec = get_basis_vector(torso_quat, vec0).view(num_envs, 3) 26 | up_proj = up_vec[:, up_idx] 27 | heading_proj = torch.bmm(heading_vec.view( 28 | num_envs, 1, 3), target_dirs.view(num_envs, 3, 1)).view(num_envs) 29 | 30 | return torso_quat, up_proj, heading_proj, up_vec, heading_vec 31 | 32 | 33 | @torch.jit.script 34 | def compute_rot(torso_quat, velocity, ang_velocity, targets, torso_positions): 35 | vel_loc = quat_rotate_inverse(torso_quat, velocity) 36 | angvel_loc = quat_rotate_inverse(torso_quat, ang_velocity) 37 | 38 | roll, pitch, yaw = get_euler_xyz(torso_quat) 39 | 40 | walk_target_angle = torch.atan2(targets[:, 2] - torso_positions[:, 2], 41 | targets[:, 0] - torso_positions[:, 0]) 42 | angle_to_target = walk_target_angle - yaw 43 | 44 | return vel_loc, angvel_loc, roll, pitch, yaw, angle_to_target 45 | 46 | 47 | @torch.jit.script 48 | def quat_axis(q, axis=0): 49 | # type: (Tensor, int) -> Tensor 50 | basis_vec = torch.zeros(q.shape[0], 3, device=q.device) 51 | basis_vec[:, axis] = 1 52 | return quat_rotate(q, basis_vec) 53 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import torch 4 | 5 | def check(input): 6 | if type(input) == np.ndarray: 7 | return torch.from_numpy(input) 8 | 9 | def get_gard_norm(it): 10 | sum_grad = 0 11 | for x in it: 12 | if x.grad is None: 13 | continue 14 | sum_grad += x.grad.norm() ** 2 15 | return math.sqrt(sum_grad) 16 | 17 | def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr): 18 | """Decreases the learning rate linearly""" 19 | lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs))) 20 | for param_group in optimizer.param_groups: 21 | param_group['lr'] = lr 22 | 23 | def huber_loss(e, d): 24 | a = (abs(e) <= d).float() 25 | b = (e > d).float() 26 | return a*e**2/2 + b*d*(abs(e)-d/2) 27 | 28 | def mse_loss(e): 29 | return e**2/2 30 | 31 | def get_shape_from_obs_space(obs_space): 32 | if obs_space.__class__.__name__ == 'Box': 33 | obs_shape = obs_space.shape 34 | elif obs_space.__class__.__name__ == 'list': 35 | obs_shape = obs_space 36 | else: 37 | raise NotImplementedError 38 | return obs_shape 39 | 40 | def get_shape_from_act_space(act_space): 41 | if act_space.__class__.__name__ == 'Discrete': 42 | act_shape = 1 43 | elif act_space.__class__.__name__ == "MultiDiscrete": 44 | act_shape = act_space.shape 45 | elif 
act_space.__class__.__name__ == "Box": 46 | act_shape = act_space.shape[0] 47 | elif act_space.__class__.__name__ == "MultiBinary": 48 | act_shape = act_space.shape[0] 49 | else: # agar 50 | act_shape = act_space[0].shape[0] + 1 51 | return act_shape 52 | 53 | 54 | def tile_images(img_nhwc): 55 | """ 56 | Tile N images into one big PxQ image 57 | (P,Q) are chosen to be as close as possible, and if N 58 | is square, then P=Q. 59 | input: img_nhwc, list or array of images, ndim=4 once turned into array 60 | n = batch index, h = height, w = width, c = channel 61 | returns: 62 | bigim_HWc, ndarray with ndim=3 63 | """ 64 | img_nhwc = np.asarray(img_nhwc) 65 | N, h, w, c = img_nhwc.shape 66 | H = int(np.ceil(np.sqrt(N))) 67 | W = int(np.ceil(float(N)/H)) 68 | img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0]*0 for _ in range(N, H*W)]) 69 | img_HWhwc = img_nhwc.reshape(H, W, h, w, c) 70 | img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4) 71 | img_Hh_Ww_c = img_HhWwc.reshape(H*h, W*w, c) 72 | return img_Hh_Ww_c --------------------------------------------------------------------------------
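A minimal usage sketch for tile_images from utils/util.py above, assuming only that numpy is installed and the repository root is on PYTHONPATH; the batch size of 5 and the 64x64 resolution are arbitrary illustration values:

import numpy as np
from utils.util import tile_images  # assumption: run from the repository root

# Five dummy RGB frames (N, H, W, C); tile_images pads the batch to a near-square grid.
frames = np.random.randint(0, 255, size=(5, 64, 64, 3), dtype=np.uint8)
grid = tile_images(frames)
print(grid.shape)  # (192, 128, 3): a 3x2 grid of 64x64 tiles, with the last tile left blank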