├── README.md ├── algos ├── extract.py ├── pca.py └── ppo.py ├── cassie ├── __init__.py ├── cassie.py ├── cassiemujoco │ ├── __init__.py │ ├── cassie-stl-meshes │ │ ├── achilles-rod.stl │ │ ├── foot-crank.stl │ │ ├── foot.stl │ │ ├── heel-spring.stl │ │ ├── hip-pitch.stl │ │ ├── hip-roll.stl │ │ ├── hip-yaw.stl │ │ ├── knee-spring.stl │ │ ├── knee.stl │ │ ├── pelvis.stl │ │ ├── plantar-rod.stl │ │ ├── shin.stl │ │ └── tarsus.stl │ ├── cassie.xml │ ├── cassieUDP.py │ ├── cassie_soft.xml │ ├── cassie_stiff.xml │ ├── cassiemujoco.py │ ├── cassiemujoco_ctypes.py │ ├── include │ │ ├── CassieCoreSim.h │ │ ├── PdInput.h │ │ ├── StateOutput.h │ │ ├── cassie_in_t.h │ │ ├── cassie_out_t.h │ │ ├── cassie_user_in_t.h │ │ ├── cassiemujoco.h │ │ ├── pd_in_t.h │ │ ├── state_out_t.h │ │ └── udp.h │ └── libcassiemujoco.so ├── trajectory │ ├── __init__.py │ ├── stepdata.bin │ └── trajectory.py └── udp.py ├── main.py ├── nn ├── actor.py ├── base.py ├── critic.py └── fit.py └── util.py /README.md: -------------------------------------------------------------------------------- 1 | # Learning Memory-Based Control for Human-Scale Bipedal Locomotion 2 | 3 | ## Purpose 4 | 5 | This repo is intended to serve as a foundation with which you can reproduce the results of the experiments detailed in our RSS 2020 paper, [Learning Memory-Based Control for Human-Scale Bipedal Locomotion](https://arxiv.org/abs/2006.02402). 6 | 7 | ## First-time setup 8 | This repo requires [MuJoCo 2.0](http://www.mujoco.org/). We recommend that you use Ubuntu 18.04. 9 | 10 | You will probably need to install the following packages: 11 | ```bash 12 | pip3 install --user torch numpy ray tensorboard 13 | sudo apt-get install -y curl git libgl1-mesa-dev libgl1-mesa-glx libglew-dev libosmesa6-dev net-tools unzip vim wget xpra xserver-xorg-dev patchelf 14 | ``` 15 | 16 | If you don't already have it, you will need to install MuJoCo. 
You will also need to obtain a license key `mjkey.txt` from the [official website](https://www.roboti.us/license.html). You can get a free 30-day trial if necessary. 17 | 18 | ```bash 19 | wget https://www.roboti.us/download/mujoco200_linux.zip 20 | unzip mujoco200_linux.zip 21 | mkdir ~/.mujoco 22 | mv mujoco200_linux ~/.mujoco/mujoco200 23 | cp [YOUR KEY FILE] ~/.mujoco/mjkey.txt 24 | ``` 25 | 26 | You will need to create an environment variable `LD_LIBRARY_PATH` to allow mujoco-py to find your mujoco directory. You can add it to your `~/.bashrc` or just enter it into the terminal every time you wish to use mujoco. 27 | ```bash 28 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/.mujoco/mujoco200/bin 29 | ``` 30 | 31 | ## Reproducing experiments 32 | 33 | ### Basics 34 | 35 | To train a policy in accordance with the hyperparameters used in the paper, execute this command: 36 | 37 | ```bash 38 | python3 main.py ppo --batch_size 64 --sample 50000 --epochs 8 --traj_len 300 --timesteps 60000000 --discount 0.95 --workers 56 --recurrent --randomize --layers 128,128 --std 0.13 --logdir LOG_DIRECTORY 39 | ``` 40 | 41 | To train a FF policy, simply remove the `--recurrent` argument. To train without dynamics randomization, remove the `--randomize` argument. 42 | 43 | 44 | ### Logging details / Monitoring live training progress 45 | Tensorboard logging is enabled by default. 
After initiating an experiment, your directory structure would look like this: 46 | 47 | ``` 48 | logs/ 49 | ├── [algo] 50 | │ └── [New Experiment Logdir] 51 | ``` 52 | 53 | To see live training progress, run ```$ tensorboard --logdir=logs``` then navigate to ```http://localhost:6006/``` in your browser 54 | -------------------------------------------------------------------------------- /algos/extract.py: -------------------------------------------------------------------------------- 1 | import ray 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import numpy as np 6 | import locale, os, time 7 | 8 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler 9 | from copy import deepcopy 10 | from nn.fit import Model 11 | from util import env_factory 12 | 13 | def get_hiddens(policy): 14 | """ 15 | A helper function for flattening the memory of a recurrent 16 | policy into a vector. 17 | """ 18 | hiddens = [] 19 | if hasattr(policy, 'hidden'): 20 | hiddens += [h.data for h in policy.hidden] 21 | 22 | if hasattr(policy, 'cells'): 23 | hiddens += [c.data for c in policy.cells] 24 | 25 | if hasattr(policy, 'latent'): 26 | hiddens += [l for l in policy.latent] 27 | 28 | return torch.cat([layer.view(-1) for layer in hiddens]).numpy() 29 | 30 | def collect_point(policy, max_traj_len): 31 | """ 32 | A helper function which collects a single memory-dynamics parameter pair 33 | from a trajectory. 
34 | """ 35 | env = env_factory(True)() 36 | 37 | chosen_timestep = np.random.randint(15, max_traj_len) 38 | timesteps = 0 39 | done = False 40 | 41 | if hasattr(policy, 'init_hidden_state'): 42 | policy.init_hidden_state() 43 | 44 | state = env.reset() 45 | while not done and timesteps < chosen_timestep: 46 | 47 | action = policy(state).numpy() 48 | state, _, done, _ = env.step(action) 49 | timesteps += 1 50 | 51 | return get_hiddens(policy), env.get_damping(), env.get_mass(), env.get_ipos() 52 | 53 | @ray.remote 54 | def collect_data(policy, max_traj_len=45, points=500): 55 | """ 56 | A ray remote function which collects a series of memory-dynamics pairs 57 | and returns a dataset. 58 | """ 59 | policy = deepcopy(policy) 60 | torch.set_num_threads(1) 61 | with torch.no_grad(): 62 | done = True 63 | 64 | damps = [] 65 | masses = [] 66 | ipos = [] 67 | 68 | latent = [] 69 | ts = [] 70 | 71 | last = time.time() 72 | while len(latent) < points: 73 | x, d, m, q = collect_point(policy, max_traj_len) 74 | damps += [d] 75 | masses += [m] 76 | ipos += [q] 77 | latent += [x] 78 | return damps, masses, ipos, latent 79 | 80 | def concat(datalist): 81 | """ 82 | Concatenates several datasets into one larger 83 | dataset. 84 | """ 85 | damps = [] 86 | masses = [] 87 | ipos = [] 88 | latents = [] 89 | for l in datalist: 90 | damp, mass, quat, latent = l 91 | damps += damp 92 | masses += mass 93 | ipos += quat 94 | 95 | latents += latent 96 | damps = torch.tensor(damps).float() 97 | masses = torch.tensor(masses).float() 98 | ipos = torch.tensor(ipos).float() 99 | latents = torch.tensor(latents).float() 100 | return damps, masses, ipos, latents 101 | 102 | def run_experiment(args): 103 | """ 104 | The entry point for the dynamics extraction algorithm. 
105 | """ 106 | from util import create_logger 107 | 108 | locale.setlocale(locale.LC_ALL, '') 109 | 110 | policy = torch.load(args.policy) 111 | 112 | env_fn = env_factory(True) 113 | 114 | layers = [int(x) for x in args.layers.split(',')] 115 | 116 | env = env_fn() 117 | policy.init_hidden_state() 118 | policy(torch.tensor(env.reset()).float()) 119 | latent_dim = get_hiddens(policy).shape[0] 120 | 121 | models = [] 122 | opts = [] 123 | for fn in [env.get_damping, env.get_mass, env.get_ipos]: 124 | output_dim = fn().shape[0] 125 | model = Model(latent_dim, output_dim, layers=layers) 126 | 127 | models += [model] 128 | opts += [optim.Adam(model.parameters(), lr=args.lr, eps=1e-5)] 129 | 130 | model.policy_path = args.policy 131 | 132 | logger = create_logger(args) 133 | 134 | best_loss = None 135 | actor_dir = os.path.split(args.policy)[0] 136 | create_new = True 137 | if os.path.exists(os.path.join(logger.dir, 'test_latents.pt')): 138 | x = torch.load(os.path.join(logger.dir, 'train_latents.pt')) 139 | test_x = torch.load(os.path.join(logger.dir, 'test_latents.pt')) 140 | 141 | damps = torch.load(os.path.join(logger.dir, 'train_damps.pt')) 142 | test_damps = torch.load(os.path.join(logger.dir, 'test_damps.pt')) 143 | 144 | masses = torch.load(os.path.join(logger.dir, 'train_masses.pt')) 145 | test_masses = torch.load(os.path.join(logger.dir, 'test_masses.pt')) 146 | 147 | ipos = torch.load(os.path.join(logger.dir, 'train_ipos.pt')) 148 | test_ipos = torch.load(os.path.join(logger.dir, 'test_ipos.pt')) 149 | 150 | if args.points > len(x) + len(test_x): 151 | create_new = True 152 | else: 153 | create_new = False 154 | 155 | if create_new: 156 | if not ray.is_initialized(): 157 | ray.init(num_cpus=args.workers) 158 | 159 | print("Collecting {:4d} timesteps of data.".format(args.points)) 160 | points_per_worker = max(args.points // args.workers, 1) 161 | start = time.time() 162 | 163 | damps, masses, ipos, x = concat(ray.get([collect_data.remote(policy, 
points=points_per_worker) for _ in range(args.workers)])) 164 | 165 | split = int(0.8 * len(x)) 166 | 167 | test_x = x[split:] 168 | x = x[:split] 169 | 170 | test_damps = damps[split:] 171 | damps = damps[:split] 172 | 173 | test_masses = masses[split:] 174 | masses = masses[:split] 175 | 176 | test_ipos = ipos[split:] 177 | ipos = ipos[:split] 178 | 179 | print("{:3.2f} to collect {} timesteps. Training set is {}, test set is {}".format(time.time() - start, len(x)+len(test_x), len(x), len(test_x))) 180 | torch.save(x, os.path.join(logger.dir, 'train_latents.pt')) 181 | torch.save(test_x, os.path.join(logger.dir, 'test_latents.pt')) 182 | 183 | torch.save(damps, os.path.join(logger.dir, 'train_damps.pt')) 184 | torch.save(test_damps, os.path.join(logger.dir, 'test_damps.pt')) 185 | 186 | torch.save(masses, os.path.join(logger.dir, 'train_masses.pt')) 187 | torch.save(test_masses, os.path.join(logger.dir, 'test_masses.pt')) 188 | 189 | torch.save(ipos, os.path.join(logger.dir, 'train_ipos.pt')) 190 | torch.save(test_ipos, os.path.join(logger.dir, 'test_ipos.pt')) 191 | 192 | for epoch in range(args.epochs): 193 | 194 | random_indices = SubsetRandomSampler(range(len(x)-1)) 195 | sampler = BatchSampler(random_indices, args.batch_size, drop_last=False) 196 | 197 | for j, batch_idx in enumerate(sampler): 198 | batch_x = x[batch_idx]#.float() 199 | batch = [damps[batch_idx], masses[batch_idx], ipos[batch_idx]] 200 | 201 | losses = [] 202 | for model, batch_y, opt in zip(models, batch, opts): 203 | loss = 0.5 * (batch_y - model(batch_x)).pow(2).mean() 204 | 205 | opt.zero_grad() 206 | loss.backward() 207 | opt.step() 208 | 209 | losses.append(loss.item()) 210 | 211 | print("Epoch {:3d} batch {:4d}/{:4d} ".format(epoch, j, len(sampler)-1), end='\r') 212 | 213 | train_y = [damps, masses, ipos] 214 | test_y = [test_damps, test_masses, test_ipos] 215 | order = ['damping', 'mass', 'com'] 216 | 217 | with torch.no_grad(): 218 | print("\nEpoch {:3d} losses:".format(epoch)) 219 
| for model, y_tr, y_te, name in zip(models, train_y, test_y, order): 220 | loss_total = 0.5 * (y_tr - model(x)).pow(2).mean().item() 221 | 222 | preds = model(test_x) 223 | test_loss = 0.5 * (y_te - preds).pow(2).mean().item() 224 | pce = torch.mean(torch.abs((y_te - preds) / (y_te + 1e-5))) 225 | err = torch.mean(torch.abs((y_te - preds))) 226 | 227 | logger.add_scalar(logger.arg_hash + '/' + name + '_loss', test_loss, epoch) 228 | logger.add_scalar(logger.arg_hash + '/' + name + '_percenterr', pce, epoch) 229 | logger.add_scalar(logger.arg_hash + '/' + name + '_abserr', err, epoch) 230 | model.dyn_parameter = name 231 | torch.save(model, os.path.join(logger.dir, name + '_extractor.pt')) 232 | print("\t{:16s}: train loss {:7.6f} test loss {:7.6f}, err {:5.4f}, percent err {:3.2f}".format(name, loss_total, test_loss, err, pce)) 233 | 234 | -------------------------------------------------------------------------------- /algos/pca.py: -------------------------------------------------------------------------------- 1 | import ray 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import locale, os, time, sys 6 | from pathlib import Path 7 | 8 | from util import env_factory 9 | import matplotlib.pyplot as plt 10 | 11 | def get_hiddens(policy): 12 | hiddens = [] 13 | if hasattr(policy, 'hidden'): 14 | hiddens += [h.data for h in policy.hidden] 15 | 16 | if hasattr(policy, 'cells'): 17 | hiddens += [c.data for c in policy.cells] 18 | 19 | if hasattr(policy, 'latent'): 20 | hiddens += [l for l in policy.latent] 21 | 22 | return torch.cat([layer.view(-1) for layer in hiddens]).numpy() 23 | 24 | def run_pca(policy): 25 | max_traj_len = 1000 26 | from sklearn.decomposition import PCA 27 | with torch.no_grad(): 28 | env = env_factory(False)() 29 | state = env.reset() 30 | 31 | done = False 32 | timesteps = 0 33 | eval_reward = 0 34 | 35 | if hasattr(policy, 'init_hidden_state'): 36 | policy.init_hidden_state() 37 | 38 | mems = [] 39 | while not done and 
timesteps < max_traj_len: 40 | 41 | env.speed = 0.5 42 | 43 | action = policy.forward(torch.Tensor(state)).numpy() 44 | state, reward, done, _ = env.step(action) 45 | env.render() 46 | eval_reward += reward 47 | timesteps += 1 48 | 49 | memory = get_hiddens(policy) 50 | mems.append(memory) 51 | 52 | data = np.vstack(mems) 53 | 54 | pca = PCA(n_components=2) 55 | 56 | fig = plt.figure() 57 | plt.axis('off') 58 | base = (0.05, 0.05, 0.05) 59 | 60 | components = pca.fit_transform(data) 61 | 62 | x = components[:,0] 63 | y = components[:,1] 64 | c = [] 65 | for i in range(len(x)): 66 | c.append(np.hstack([base, (len(x) - i/2) / len(x)])) 67 | 68 | plt.scatter(x, y, color=c, s=0.8) 69 | plt.show() 70 | plt.close() 71 | -------------------------------------------------------------------------------- /algos/ppo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains the code implementing Proximal Policy Optimization (PPO), 3 | with objective clipping and early termination if a KL threshold 4 | is exceeded. 5 | """ 6 | 7 | import os 8 | import ray 9 | import torch 10 | import numpy as np 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler 15 | from torch.distributions import kl_divergence 16 | from torch.nn.utils.rnn import pad_sequence 17 | from copy import deepcopy 18 | from time import time 19 | 20 | class Buffer: 21 | """ 22 | A class representing the replay buffer used in PPO. Used 23 | to states, actions, and reward/return, and then calculate 24 | advantage from discounted sum of returns. 
25 | """ 26 | def __init__(self, discount=0.99): 27 | self.discount = discount 28 | self.states = [] 29 | self.actions = [] 30 | self.rewards = [] 31 | self.values = [] 32 | self.returns = [] 33 | self.advantages = [] 34 | 35 | self.size = 0 36 | 37 | self.traj_idx = [0] 38 | self.buffer_ready = False 39 | 40 | def __len__(self): 41 | return len(self.states) 42 | 43 | def push(self, state, action, reward, value, done=False): 44 | self.states += [state] 45 | self.actions += [action] 46 | self.rewards += [reward] 47 | self.values += [value] 48 | 49 | self.size += 1 50 | 51 | def end_trajectory(self, terminal_value=0): 52 | self.traj_idx += [self.size] 53 | rewards = self.rewards[self.traj_idx[-2]:self.traj_idx[-1]] 54 | 55 | returns = [] 56 | 57 | R = terminal_value 58 | for reward in reversed(rewards): 59 | R = self.discount * R + reward 60 | returns.insert(0, R) 61 | 62 | self.returns += returns 63 | 64 | def _finish_buffer(self): 65 | self.states = torch.Tensor(self.states) 66 | self.actions = torch.Tensor(self.actions) 67 | self.rewards = torch.Tensor(self.rewards) 68 | self.returns = torch.Tensor(self.returns) 69 | self.values = torch.Tensor(self.values) 70 | 71 | a = self.returns - self.values 72 | a = (a - a.mean()) / (a.std() + 1e-4) 73 | self.advantages = a 74 | self.buffer_ready = True 75 | 76 | def sample(self, batch_size=64, recurrent=False): 77 | if not self.buffer_ready: 78 | self._finish_buffer() 79 | 80 | if recurrent: 81 | """ 82 | If we are returning a sample for a recurrent network, we should 83 | return a zero-padded tensor of size [traj_len, batch_size, dim], 84 | or a trajectory of batched states/actions/returns. 
85 | """ 86 | random_indices = SubsetRandomSampler(range(len(self.traj_idx)-1)) 87 | sampler = BatchSampler(random_indices, batch_size, drop_last=True) 88 | 89 | for traj_indices in sampler: 90 | states = [self.states[self.traj_idx[i]:self.traj_idx[i+1]] for i in traj_indices] 91 | actions = [self.actions[self.traj_idx[i]:self.traj_idx[i+1]] for i in traj_indices] 92 | returns = [self.returns[self.traj_idx[i]:self.traj_idx[i+1]] for i in traj_indices] 93 | advantages = [self.advantages[self.traj_idx[i]:self.traj_idx[i+1]] for i in traj_indices] 94 | 95 | traj_mask = [torch.ones_like(r) for r in returns] 96 | 97 | states = pad_sequence(states, batch_first=False) 98 | actions = pad_sequence(actions, batch_first=False) 99 | returns = pad_sequence(returns, batch_first=False) 100 | advantages = pad_sequence(advantages, batch_first=False) 101 | traj_mask = pad_sequence(traj_mask, batch_first=False) 102 | 103 | yield states, actions, returns, advantages, traj_mask 104 | 105 | else: 106 | """ 107 | If we are returning a sample for a conventional network, we should 108 | return a tensor of size [batch_size, dim], or a batch of timesteps. 109 | """ 110 | random_indices = SubsetRandomSampler(range(self.size)) 111 | sampler = BatchSampler(random_indices, batch_size, drop_last=True) 112 | 113 | for i, idxs in enumerate(sampler): 114 | states = self.states[idxs] 115 | actions = self.actions[idxs] 116 | returns = self.returns[idxs] 117 | advantages = self.advantages[idxs] 118 | 119 | yield states, actions, returns, advantages, 1 120 | 121 | @ray.remote 122 | class PPO_Worker: 123 | """ 124 | A class representing a parallel worker used to explore the 125 | environment. 
126 | """ 127 | def __init__(self, actor, critic, env_fn, gamma): 128 | torch.set_num_threads(1) 129 | self.gamma = gamma 130 | self.actor = deepcopy(actor) 131 | self.critic = deepcopy(critic) 132 | self.env = env_fn() 133 | 134 | def sync_policy(self, new_actor_params, new_critic_params, input_norm=None): 135 | for p, new_p in zip(self.actor.parameters(), new_actor_params): 136 | p.data.copy_(new_p) 137 | 138 | for p, new_p in zip(self.critic.parameters(), new_critic_params): 139 | p.data.copy_(new_p) 140 | 141 | if input_norm is not None: 142 | self.actor.state_mean, self.actor.state_mean_diff, self.actor.state_n = input_norm 143 | 144 | def collect_experience(self, max_traj_len, min_steps): 145 | with torch.no_grad(): 146 | start = time() 147 | 148 | num_steps = 0 149 | memory = Buffer(self.gamma) 150 | actor = self.actor 151 | critic = self.critic 152 | 153 | while num_steps < min_steps: 154 | state = torch.Tensor(self.env.reset()) 155 | 156 | done = False 157 | value = 0 158 | traj_len = 0 159 | 160 | if hasattr(actor, 'init_hidden_state'): 161 | actor.init_hidden_state() 162 | 163 | if hasattr(critic, 'init_hidden_state'): 164 | critic.init_hidden_state() 165 | 166 | while not done and traj_len < max_traj_len: 167 | state = torch.Tensor(state) 168 | action = actor(state, False) 169 | value = critic(state) 170 | 171 | next_state, reward, done, _ = self.env.step(action.numpy()) 172 | 173 | reward = np.array([reward]) 174 | 175 | memory.push(state.numpy(), action.numpy(), reward, value.numpy()) 176 | 177 | state = next_state 178 | 179 | traj_len += 1 180 | num_steps += 1 181 | 182 | value = (not done) * critic(torch.Tensor(state)).numpy() 183 | memory.end_trajectory(terminal_value=value) 184 | 185 | return memory 186 | 187 | class PPO: 188 | def __init__(self, actor, critic, env_fn, args): 189 | 190 | self.actor = actor 191 | self.old_actor = deepcopy(actor) 192 | self.critic = critic 193 | 194 | if actor.is_recurrent or critic.is_recurrent: 195 | 
self.recurrent = True 196 | else: 197 | self.recurrent = False 198 | 199 | self.actor_optim = optim.Adam(self.actor.parameters(), lr=args.a_lr, eps=args.eps) 200 | self.critic_optim = optim.Adam(self.critic.parameters(), lr=args.c_lr, eps=args.eps) 201 | self.env_fn = env_fn 202 | self.discount = args.discount 203 | self.grad_clip = args.grad_clip 204 | 205 | if not ray.is_initialized(): 206 | ray.init(num_cpus=args.workers) 207 | 208 | self.workers = [PPO_Worker.remote(actor, critic, env_fn, args.discount) for _ in range(args.workers)] 209 | 210 | def sync_policy(self, states, actions, returns, advantages, mask): 211 | with torch.no_grad(): 212 | old_pdf = self.old_actor.pdf(states) 213 | old_log_probs = old_pdf.log_prob(actions).sum(-1, keepdim=True) 214 | 215 | pdf = self.actor.pdf(states) 216 | log_probs = pdf.log_prob(actions).sum(-1, keepdim=True) 217 | 218 | ratio = ((log_probs - old_log_probs)).exp() 219 | cpi_loss = ratio * advantages * mask 220 | clip_loss = ratio.clamp(0.8, 1.2) * advantages * mask 221 | actor_loss = -torch.min(cpi_loss, clip_loss).mean() 222 | 223 | critic_loss = 0.5 * ((returns - self.critic(states)) * mask).pow(2).mean() 224 | 225 | self.actor_optim.zero_grad() 226 | self.critic_optim.zero_grad() 227 | 228 | actor_loss.backward() 229 | critic_loss.backward() 230 | 231 | torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=self.grad_clip) 232 | torch.nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=self.grad_clip) 233 | self.actor_optim.step() 234 | self.critic_optim.step() 235 | 236 | with torch.no_grad(): 237 | return kl_divergence(pdf, old_pdf).mean().numpy(), ((actor_loss).item(), critic_loss.item()) 238 | 239 | def merge_buffers(self, buffers): 240 | memory = Buffer() 241 | 242 | for b in buffers: 243 | offset = len(memory) 244 | 245 | memory.states += b.states 246 | memory.actions += b.actions 247 | memory.rewards += b.rewards 248 | memory.values += b.values 249 | memory.returns += b.returns 250 | 251 | 
memory.traj_idx += [offset + i for i in b.traj_idx[1:]] 252 | memory.size += b.size 253 | return memory 254 | 255 | def do_iteration(self, num_steps, max_traj_len, epochs, kl_thresh=0.02, verbose=True, batch_size=64): 256 | self.old_actor.load_state_dict(self.actor.state_dict()) 257 | 258 | start = time() 259 | actor_param_id = ray.put(list(self.actor.parameters())) 260 | critic_param_id = ray.put(list(self.critic.parameters())) 261 | norm_id = ray.put([self.actor.state_mean, self.actor.state_mean_diff, self.actor.state_n]) 262 | 263 | steps = max(num_steps // len(self.workers), max_traj_len) 264 | 265 | for w in self.workers: 266 | w.sync_policy.remote(actor_param_id, critic_param_id, input_norm=norm_id) 267 | 268 | if verbose: 269 | print("\t{:5.4f}s to copy policy params to workers.".format(time() - start)) 270 | 271 | start = time() 272 | buffers = ray.get([w.collect_experience.remote(max_traj_len, steps) for w in self.workers]) 273 | memory = self.merge_buffers(buffers) 274 | 275 | total_steps = len(memory) 276 | elapsed = time() - start 277 | if verbose: 278 | print("\t{:3.2f}s to collect {:6n} timesteps | {:3.2}k/s.".format(elapsed, total_steps, (total_steps/1000)/elapsed)) 279 | 280 | start = time() 281 | kls = [] 282 | done = False 283 | for epoch in range(epochs): 284 | a_loss = [] 285 | c_loss = [] 286 | for batch in memory.sample(batch_size=batch_size, recurrent=self.recurrent): 287 | states, actions, returns, advantages, mask = batch 288 | 289 | kl, losses = self.sync_policy(states, actions, returns, advantages, mask) 290 | kls += [kl] 291 | a_loss += [losses[0]] 292 | c_loss += [losses[1]] 293 | 294 | if max(kls) > kl_thresh: 295 | done = True 296 | print("\t\tbatch had kl of {} (threshold {}), stopping optimization early.".format(max(kls), kl_thresh)) 297 | break 298 | 299 | if verbose: 300 | print("\t\tepoch {:2d} kl {:4.3f}, actor loss {:6.3f}, critic loss {:6.3f}".format(epoch+1, np.mean(kls), np.mean(a_loss), np.mean(c_loss))) 301 | 302 | if 
done: 303 | break 304 | 305 | if verbose: 306 | print("\t{:3.2f}s to update policy.".format(time() - start)) 307 | return np.mean(kls), np.mean(a_loss), np.mean(c_loss), len(memory) 308 | 309 | def run_experiment(args): 310 | torch.set_num_threads(1) 311 | 312 | from util import create_logger, env_factory, eval_policy, train_normalizer 313 | 314 | from nn.critic import FF_V, LSTM_V 315 | from nn.actor import FF_Stochastic_Actor, LSTM_Stochastic_Actor 316 | 317 | import locale, os 318 | locale.setlocale(locale.LC_ALL, '') 319 | 320 | # wrapper function for creating parallelized envs 321 | env_fn = env_factory(args.randomize) 322 | obs_dim = env_fn().observation_space.shape[0] 323 | action_dim = env_fn().action_space.shape[0] 324 | layers = [int(x) for x in args.layers.split(',')] 325 | 326 | # Set seeds 327 | torch.manual_seed(args.seed) 328 | np.random.seed(args.seed) 329 | 330 | if args.recurrent: 331 | policy = LSTM_Stochastic_Actor(obs_dim, action_dim,\ 332 | layers=layers, 333 | dynamics_randomization=args.randomize, 334 | fixed_std=torch.ones(action_dim)*args.std) 335 | critic = LSTM_V(obs_dim, layers=layers) 336 | else: 337 | policy = FF_Stochastic_Actor(obs_dim, action_dim,\ 338 | layers=layers, 339 | dynamics_randomization=args.randomize, 340 | fixed_std=torch.ones(action_dim)*args.std) 341 | critic = FF_V(obs_dim, layers=layers) 342 | 343 | env = env_fn() 344 | 345 | policy.train(0) 346 | critic.train(0) 347 | 348 | print("Collecting normalization statistics with {} states...".format(args.prenormalize_steps)) 349 | train_normalizer(policy, args.prenormalize_steps, max_traj_len=args.traj_len, noise=1) 350 | critic.copy_normalizer_stats(policy) 351 | 352 | algo = PPO(policy, critic, env_fn, args) 353 | 354 | # create a tensorboard logging object 355 | if not args.nolog: 356 | logger = create_logger(args) 357 | else: 358 | logger = None 359 | 360 | if args.save_actor is None and logger is not None: 361 | args.save_actor = os.path.join(logger.dir, 'actor.pt') 
362 | 363 | if args.save_critic is None and logger is not None: 364 | args.save_critic = os.path.join(logger.dir, 'critic.pt') 365 | 366 | print() 367 | print("Proximal Policy Optimization:") 368 | print("\tseed: {}".format(args.seed)) 369 | print("\ttimesteps: {:n}".format(int(args.timesteps))) 370 | print("\titeration steps: {:n}".format(int(args.sample))) 371 | print("\tprenormalize steps: {}".format(int(args.prenormalize_steps))) 372 | print("\ttraj_len: {}".format(args.traj_len)) 373 | print("\tdiscount: {}".format(args.discount)) 374 | print("\tactor_lr: {}".format(args.a_lr)) 375 | print("\tcritic_lr: {}".format(args.c_lr)) 376 | print("\tgrad clip: {}".format(args.grad_clip)) 377 | print("\tbatch size: {}".format(args.batch_size)) 378 | print("\tepochs: {}".format(args.epochs)) 379 | print("\trecurrent: {}".format(args.recurrent)) 380 | print("\tdynamics rand: {}".format(args.randomize)) 381 | print("\tworkers: {}".format(args.workers)) 382 | print() 383 | 384 | itr = 0 385 | timesteps = 0 386 | best_reward = None 387 | while timesteps < args.timesteps: 388 | kl, a_loss, c_loss, steps = algo.do_iteration(args.sample, args.traj_len, args.epochs, batch_size=args.batch_size, kl_thresh=args.kl) 389 | eval_reward = eval_policy(algo.actor, env, episodes=5, max_traj_len=args.traj_len, verbose=False, visualize=False) 390 | 391 | timesteps += steps 392 | print("iter {:4d} | return: {:5.2f} | KL {:5.4f} | timesteps {:n}".format(itr, eval_reward, kl, timesteps)) 393 | 394 | if best_reward is None or eval_reward > best_reward: 395 | print("\t(best policy so far! 
saving to {})".format(args.save_actor)) 396 | best_reward = eval_reward 397 | if args.save_actor is not None: 398 | torch.save(algo.actor, args.save_actor) 399 | 400 | if args.save_critic is not None: 401 | torch.save(algo.critic, args.save_critic) 402 | 403 | if logger is not None: 404 | logger.add_scalar('cassie/kl', kl, itr) 405 | logger.add_scalar('cassie/return', eval_reward, itr) 406 | logger.add_scalar('cassie/actor_loss', a_loss, itr) 407 | logger.add_scalar('cassie/critic_loss', c_loss, itr) 408 | itr += 1 409 | print("Finished ({} of {}).".format(timesteps, args.timesteps)) 410 | -------------------------------------------------------------------------------- /cassie/__init__.py: -------------------------------------------------------------------------------- 1 | from .cassie import CassieEnv 2 | from .cassiemujoco import * 3 | -------------------------------------------------------------------------------- /cassie/cassie.py: -------------------------------------------------------------------------------- 1 | from .cassiemujoco import pd_in_t, state_out_t, CassieSim, CassieVis 2 | 3 | from .trajectory import CassieTrajectory 4 | 5 | from math import floor 6 | 7 | import numpy as np 8 | import os 9 | import random 10 | 11 | import pickle 12 | 13 | class CassieEnv: 14 | def __init__(self, dynamics_randomization=False): 15 | self.sim = CassieSim("./cassie/cassiemujoco/cassie.xml") 16 | self.vis = None 17 | 18 | self.dynamics_randomization = dynamics_randomization 19 | 20 | state_est_size = 46 21 | clock_size = 2 22 | speed_size = 1 23 | 24 | self.observation_space = np.zeros(state_est_size + clock_size + speed_size) 25 | self.action_space = np.zeros(10) 26 | 27 | dirname = os.path.dirname(__file__) 28 | 29 | traj_path = os.path.join(dirname, "trajectory", "stepdata.bin") 30 | 31 | self.trajectory = CassieTrajectory(traj_path) 32 | 33 | self.P = np.array([100, 100, 88, 96, 50]) 34 | self.D = np.array([10.0, 10.0, 8.0, 9.6, 5.0]) 35 | 36 | self.u = pd_in_t() 
37 | 38 | self.cassie_state = state_out_t() 39 | 40 | self.simrate = 60 # simulate X mujoco steps with same pd target 41 | # 60 brings simulation from 2000Hz to roughly 30Hz 42 | self.time = 0 # number of time steps in current episode 43 | self.phase = 0 # portion of the phase the robot is in 44 | self.counter = 0 # number of phase cycles completed in episode 45 | 46 | # NOTE: a reference trajectory represents ONE phase cycle 47 | self.phaselen = floor(len(self.trajectory) / self.simrate) - 1 48 | 49 | # see include/cassiemujoco.h for meaning of these indices 50 | self.pos_idx = [7, 8, 9, 14, 20, 21, 22, 23, 28, 34] 51 | self.vel_idx = [6, 7, 8, 12, 18, 19, 20, 21, 25, 31] 52 | self.offset = np.array([0.0045, 0.0, 0.4973, -1.1997, -1.5968, 0.0045, 0.0, 0.4973, -1.1997, -1.5968]) 53 | self.speed = 0 54 | self.phase_add = 1 55 | 56 | # Record default dynamics parameters 57 | self.default_damping = self.sim.get_dof_damping() 58 | self.default_mass = self.sim.get_body_mass() 59 | self.default_ipos = self.sim.get_body_ipos() 60 | 61 | def step_simulation(self, action): 62 | 63 | target = action + self.offset 64 | 65 | self.u = pd_in_t() 66 | for i in range(5): 67 | # TODO: move setting gains out of the loop? 68 | # maybe write a wrapper for pd_in_t ? 
69 | self.u.leftLeg.motorPd.pGain[i] = self.P[i] 70 | self.u.rightLeg.motorPd.pGain[i] = self.P[i] 71 | 72 | self.u.leftLeg.motorPd.dGain[i] = self.D[i] 73 | self.u.rightLeg.motorPd.dGain[i] = self.D[i] 74 | 75 | self.u.leftLeg.motorPd.torque[i] = 0 # Feedforward torque 76 | self.u.rightLeg.motorPd.torque[i] = 0 77 | 78 | self.u.leftLeg.motorPd.pTarget[i] = target[i] 79 | self.u.rightLeg.motorPd.pTarget[i] = target[i + 5] 80 | 81 | self.u.leftLeg.motorPd.dTarget[i] = 0 82 | self.u.rightLeg.motorPd.dTarget[i] = 0 83 | 84 | self.cassie_state = self.sim.step_pd(self.u) 85 | 86 | def step(self, action): 87 | for _ in range(self.simrate): 88 | self.step_simulation(action) 89 | 90 | height = self.sim.qpos()[2] 91 | 92 | self.time += 1 93 | self.phase += self.phase_add 94 | 95 | if self.phase > self.phaselen: 96 | self.phase = 0 97 | self.counter += 1 98 | 99 | # Early termination 100 | done = not(height > 0.4 and height < 3.0) 101 | 102 | reward = self.compute_reward() 103 | 104 | if reward < 0.3: 105 | done = True 106 | 107 | return self.get_full_state(), reward, done, {} 108 | 109 | def reset(self): 110 | self.phase = random.randint(0, self.phaselen) 111 | self.time = 0 112 | self.counter = 0 113 | 114 | qpos, qvel = self.get_ref_state(self.phase) 115 | 116 | self.sim.set_qpos(qpos) 117 | self.sim.set_qvel(qvel) 118 | 119 | # Randomize dynamics: 120 | if self.dynamics_randomization: 121 | damp = self.default_damping 122 | weak_factor = 0.5 123 | strong_factor = 1.5 124 | pelvis_damp_range = [[damp[0], damp[0]], 125 | [damp[1], damp[1]], 126 | [damp[2], damp[2]], 127 | [damp[3], damp[3]], 128 | [damp[4], damp[4]], 129 | [damp[5], damp[5]]] # 0->5 130 | 131 | hip_damp_range = [[damp[6]*weak_factor, damp[6]*strong_factor], 132 | [damp[7]*weak_factor, damp[7]*strong_factor], 133 | [damp[8]*weak_factor, damp[8]*strong_factor]] # 6->8 and 19->21 134 | 135 | achilles_damp_range = [[damp[9]*weak_factor, damp[9]*strong_factor], 136 | [damp[10]*weak_factor, 
damp[10]*strong_factor], 137 | [damp[11]*weak_factor, damp[11]*strong_factor]] # 9->11 and 22->24 138 | 139 | knee_damp_range = [[damp[12]*weak_factor, damp[12]*strong_factor]] # 12 and 25 140 | shin_damp_range = [[damp[13]*weak_factor, damp[13]*strong_factor]] # 13 and 26 141 | tarsus_damp_range = [[damp[14], damp[14]]] # 14 and 27 142 | heel_damp_range = [[damp[15], damp[15]]] # 15 and 28 143 | fcrank_damp_range = [[damp[16]*weak_factor, damp[16]*strong_factor]] # 16 and 29 144 | prod_damp_range = [[damp[17], damp[17]]] # 17 and 30 145 | foot_damp_range = [[damp[18]*weak_factor, damp[18]*strong_factor]] # 18 and 31 146 | 147 | side_damp = hip_damp_range + achilles_damp_range + knee_damp_range + shin_damp_range + tarsus_damp_range + heel_damp_range + fcrank_damp_range + prod_damp_range + foot_damp_range 148 | damp_range = pelvis_damp_range + side_damp + side_damp 149 | damp_noise = [np.random.uniform(a, b) for a, b in damp_range] 150 | 151 | hi = 1.3 152 | lo = 0.7 153 | m = self.default_mass 154 | pelvis_mass_range = [[lo*m[1], hi*m[1]]] # 1 155 | hip_mass_range = [[lo*m[2], hi*m[2]], # 2->4 and 14->16 156 | [lo*m[3], hi*m[3]], 157 | [lo*m[4], hi*m[4]]] 158 | 159 | achilles_mass_range = [[lo*m[5], hi*m[5]]] # 5 and 17 160 | knee_mass_range = [[lo*m[6], hi*m[6]]] # 6 and 18 161 | knee_spring_mass_range = [[lo*m[7], hi*m[7]]] # 7 and 19 162 | shin_mass_range = [[lo*m[8], hi*m[8]]] # 8 and 20 163 | tarsus_mass_range = [[lo*m[9], hi*m[9]]] # 9 and 21 164 | heel_spring_mass_range = [[lo*m[10], hi*m[10]]] # 10 and 22 165 | fcrank_mass_range = [[lo*m[11], hi*m[11]]] # 11 and 23 166 | prod_mass_range = [[lo*m[12], hi*m[12]]] # 12 and 24 167 | foot_mass_range = [[lo*m[13], hi*m[13]]] # 13 and 25 168 | 169 | side_mass = hip_mass_range + achilles_mass_range \ 170 | + knee_mass_range + knee_spring_mass_range \ 171 | + shin_mass_range + tarsus_mass_range \ 172 | + heel_spring_mass_range + fcrank_mass_range \ 173 | + prod_mass_range + foot_mass_range 174 | 175 | mass_range = 
[[0, 0]] + pelvis_mass_range + side_mass + side_mass 176 | mass_noise = [np.random.uniform(a, b) for a, b in mass_range] 177 | 178 | delta_y_min, delta_y_max = self.default_ipos[4] - 0.07, self.default_ipos[4] + 0.07 179 | delta_z_min, delta_z_max = self.default_ipos[5] - 0.04, self.default_ipos[5] + 0.04 180 | com_noise = [0, 0, 0] + [np.random.uniform(-0.25, 0.06)] + [np.random.uniform(delta_y_min, delta_y_max)] + [np.random.uniform(delta_z_min, delta_z_max)] + list(self.default_ipos[6:]) 181 | 182 | self.sim.set_dof_damping(np.clip(damp_noise, 0, None)) 183 | self.sim.set_body_mass(np.clip(mass_noise, 0, None)) 184 | self.sim.set_body_ipos(com_noise) 185 | else: 186 | self.sim.set_dof_damping(self.default_damping) 187 | self.sim.set_body_mass(self.default_mass) 188 | self.sim.set_body_ipos(self.default_ipos) 189 | 190 | self.sim.set_const() 191 | 192 | self.cassie_state = self.sim.step_pd(self.u) 193 | self.speed = np.random.uniform(-0.15, 0.8) 194 | 195 | return self.get_full_state() 196 | 197 | def compute_reward(self): 198 | qpos = np.copy(self.sim.qpos()) 199 | qvel = np.copy(self.sim.qvel()) 200 | 201 | ref_pos, _ = self.get_ref_state(self.phase) 202 | 203 | joint_error = 0 204 | com_error = 0 205 | orientation_error = 0 206 | spring_error = 0 207 | 208 | # each joint pos 209 | weight = [0.15, 0.15, 0.1, 0.05, 0.05, 0.15, 0.15, 0.1, 0.05, 0.05] 210 | for i, j in enumerate(self.pos_idx): 211 | target = ref_pos[j] 212 | actual = qpos[j] 213 | 214 | joint_error += 30 * weight[i] * (target - actual) ** 2 215 | 216 | forward_diff = np.abs(qvel[0] - self.speed) 217 | if forward_diff < 0.05: 218 | forward_diff = 0 219 | 220 | y_vel = np.abs(qvel[1]) 221 | if y_vel < 0.03: 222 | y_vel = 0 223 | 224 | straight_diff = np.abs(qpos[1]) 225 | if straight_diff < 0.05: 226 | straight_diff = 0 227 | 228 | actual_q = qpos[3:7] 229 | target_q = [1, 0, 0, 0] 230 | orientation_error = 5 * (1 - np.inner(actual_q, target_q) ** 2) 231 | 232 | # left and right shin springs 233 | 
for i in [15, 29]: 234 | target = ref_pos[i] 235 | actual = qpos[i] 236 | 237 | spring_error += 1000 * (target - actual) ** 2 238 | 239 | reward = 0.000 + \ 240 | 0.300 * np.exp(-orientation_error) + \ 241 | 0.200 * np.exp(-joint_error) + \ 242 | 0.200 * np.exp(-forward_diff) + \ 243 | 0.200 * np.exp(-y_vel) + \ 244 | 0.050 * np.exp(-straight_diff) + \ 245 | 0.050 * np.exp(-spring_error) 246 | 247 | return reward 248 | 249 | def get_damping(self): 250 | return np.array(self.sim.get_dof_damping()) 251 | 252 | def get_mass(self): 253 | return np.array(self.sim.get_body_mass()) 254 | 255 | def get_ipos(self): 256 | return np.array(self.sim.get_body_ipos()[3:6]) 257 | 258 | # get the corresponding state from the reference trajectory for the current phase 259 | def get_ref_state(self, phase=None): 260 | if phase is None: 261 | phase = self.phase 262 | 263 | if phase > self.phaselen: 264 | phase = 0 265 | 266 | pos = np.copy(self.trajectory.qpos[phase * self.simrate]) 267 | 268 | ###### Setting variable speed ######### 269 | pos[0] *= self.speed 270 | pos[0] += (self.trajectory.qpos[-1, 0] - self.trajectory.qpos[0, 0]) * self.counter * self.speed 271 | ###### ######## 272 | 273 | # setting lateral distance target to 0 274 | # regardless of reference trajectory 275 | pos[1] = 0 276 | 277 | vel = np.copy(self.trajectory.qvel[phase * self.simrate]) 278 | vel[0] *= self.speed 279 | 280 | return pos, vel 281 | 282 | def get_full_state(self): 283 | qpos = np.copy(self.sim.qpos()) 284 | qvel = np.copy(self.sim.qvel()) 285 | 286 | ref_pos, _ = self.get_ref_state(self.phase + self.phase_add) 287 | 288 | clock = [np.sin(2 * np.pi * self.phase / self.phaselen), 289 | np.cos(2 * np.pi * self.phase / self.phaselen)] 290 | 291 | ext_state = np.concatenate((clock, [self.speed])) 292 | 293 | # Use state estimator 294 | robot_state = np.concatenate([ 295 | [self.cassie_state.pelvis.position[2] - self.cassie_state.terrain.height], # pelvis height 296 | 
self.cassie_state.pelvis.orientation[:], # pelvis orientation 297 | self.cassie_state.motor.position[:], # actuated joint positions 298 | self.cassie_state.pelvis.translationalVelocity[:], # pelvis translational velocity 299 | self.cassie_state.pelvis.rotationalVelocity[:], # pelvis rotational velocity 300 | self.cassie_state.motor.velocity[:], # actuated joint velocities 301 | self.cassie_state.pelvis.translationalAcceleration[:], # pelvis translational acceleration 302 | self.cassie_state.joint.position[:], # unactuated joint positions 303 | self.cassie_state.joint.velocity[:] # unactuated joint velocities 304 | ]) 305 | 306 | return np.concatenate([robot_state, ext_state]) 307 | 308 | def render(self): 309 | if self.vis is None: 310 | self.vis = CassieVis(self.sim, "./cassie/cassiemujoco/cassie.xml") 311 | 312 | return self.vis.draw(self.sim) 313 | -------------------------------------------------------------------------------- /cassie/cassiemujoco/__init__.py: -------------------------------------------------------------------------------- 1 | from .cassiemujoco import * -------------------------------------------------------------------------------- /cassie/cassiemujoco/cassie-stl-meshes/achilles-rod.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/osudrl/RSS-2020-learning-memory-based-control/99a133c7d53c7caa8187b1c7058dfc5cd9a81507/cassie/cassiemujoco/cassie-stl-meshes/achilles-rod.stl -------------------------------------------------------------------------------- /cassie/cassiemujoco/cassie-stl-meshes/foot-crank.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/osudrl/RSS-2020-learning-memory-based-control/99a133c7d53c7caa8187b1c7058dfc5cd9a81507/cassie/cassiemujoco/cassie-stl-meshes/foot-crank.stl -------------------------------------------------------------------------------- 
/cassie/cassiemujoco/cassie-stl-meshes/foot.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/osudrl/RSS-2020-learning-memory-based-control/99a133c7d53c7caa8187b1c7058dfc5cd9a81507/cassie/cassiemujoco/cassie-stl-meshes/foot.stl -------------------------------------------------------------------------------- /cassie/cassiemujoco/cassie-stl-meshes/heel-spring.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/osudrl/RSS-2020-learning-memory-based-control/99a133c7d53c7caa8187b1c7058dfc5cd9a81507/cassie/cassiemujoco/cassie-stl-meshes/heel-spring.stl -------------------------------------------------------------------------------- /cassie/cassiemujoco/cassie-stl-meshes/hip-pitch.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/osudrl/RSS-2020-learning-memory-based-control/99a133c7d53c7caa8187b1c7058dfc5cd9a81507/cassie/cassiemujoco/cassie-stl-meshes/hip-pitch.stl -------------------------------------------------------------------------------- /cassie/cassiemujoco/cassie-stl-meshes/hip-roll.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/osudrl/RSS-2020-learning-memory-based-control/99a133c7d53c7caa8187b1c7058dfc5cd9a81507/cassie/cassiemujoco/cassie-stl-meshes/hip-roll.stl -------------------------------------------------------------------------------- /cassie/cassiemujoco/cassie-stl-meshes/hip-yaw.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/osudrl/RSS-2020-learning-memory-based-control/99a133c7d53c7caa8187b1c7058dfc5cd9a81507/cassie/cassiemujoco/cassie-stl-meshes/hip-yaw.stl -------------------------------------------------------------------------------- /cassie/cassiemujoco/cassie-stl-meshes/knee-spring.stl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/osudrl/RSS-2020-learning-memory-based-control/99a133c7d53c7caa8187b1c7058dfc5cd9a81507/cassie/cassiemujoco/cassie-stl-meshes/knee-spring.stl -------------------------------------------------------------------------------- /cassie/cassiemujoco/cassie-stl-meshes/knee.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/osudrl/RSS-2020-learning-memory-based-control/99a133c7d53c7caa8187b1c7058dfc5cd9a81507/cassie/cassiemujoco/cassie-stl-meshes/knee.stl -------------------------------------------------------------------------------- /cassie/cassiemujoco/cassie-stl-meshes/pelvis.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/osudrl/RSS-2020-learning-memory-based-control/99a133c7d53c7caa8187b1c7058dfc5cd9a81507/cassie/cassiemujoco/cassie-stl-meshes/pelvis.stl -------------------------------------------------------------------------------- /cassie/cassiemujoco/cassie-stl-meshes/plantar-rod.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/osudrl/RSS-2020-learning-memory-based-control/99a133c7d53c7caa8187b1c7058dfc5cd9a81507/cassie/cassiemujoco/cassie-stl-meshes/plantar-rod.stl -------------------------------------------------------------------------------- /cassie/cassiemujoco/cassie-stl-meshes/shin.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/osudrl/RSS-2020-learning-memory-based-control/99a133c7d53c7caa8187b1c7058dfc5cd9a81507/cassie/cassiemujoco/cassie-stl-meshes/shin.stl -------------------------------------------------------------------------------- /cassie/cassiemujoco/cassie-stl-meshes/tarsus.stl: 
# ----------------------------------------------------------------------------
# https://raw.githubusercontent.com/osudrl/RSS-2020-learning-memory-based-control/99a133c7d53c7caa8187b1c7058dfc5cd9a81507/cassie/cassiemujoco/cassie-stl-meshes/tarsus.stl
# ----------------------------------------------------------------------------
# /cassie/cassiemujoco/cassie.xml
# ----------------------------------------------------------------------------
# 1 | 2 | 3 | 4 | 5 | 275 |   (XML body not present in this export)
# ----------------------------------------------------------------------------
# /cassie/cassiemujoco/cassieUDP.py
# ----------------------------------------------------------------------------
from .cassiemujoco_ctypes import *
import os
import ctypes
import numpy as np

class CassieUdp:
    """UDP client that exchanges packed Cassie packets with a remote peer.

    Two message families are supported:

    * full interface: send ``cassie_user_in_t`` / receive ``cassie_out_t``
    * PD interface:   send ``pd_in_t``          / receive ``state_out_t``

    Each raw buffer reserves 2 leading bytes ahead of the payload
    (presumably the header handled by ``process_packet_header``), which is
    why the payload views ``inbuf`` / ``outbuf`` start at offset 2.
    """

    def __init__(self, remote_addr='127.0.0.1', remote_port='25000',
                 local_addr='0.0.0.0', local_port='25001'):
        self.sock = udp_init_client(str.encode(remote_addr),
                                    str.encode(remote_port),
                                    str.encode(local_addr),
                                    str.encode(local_port))
        self.packet_header_info = packet_header_info_t()
        # Packet sizes in bytes: 2 reserved bytes + payload.
        self.recvlen = 2 + 697
        self.sendlen = 2 + 58
        self.recvlen_pd = 2 + 493
        self.sendlen_pd = 2 + 476
        # One buffer per direction, large enough for either message family.
        self.recvbuf = (ctypes.c_ubyte * max(self.recvlen, self.recvlen_pd))()
        self.sendbuf = (ctypes.c_ubyte * max(self.sendlen, self.sendlen_pd))()
        # Payload views that skip the 2 reserved bytes.
        self.inbuf = ctypes.cast(ctypes.byref(self.recvbuf, 2),
                                 ctypes.POINTER(ctypes.c_ubyte))
        self.outbuf = ctypes.cast(ctypes.byref(self.sendbuf, 2),
                                  ctypes.POINTER(ctypes.c_ubyte))

    def _receive(self, length, out_type, unpack, wait):
        """Fetch the newest packet of ``length`` bytes and unpack its payload.

        Args:
            length: expected total packet size (reserved bytes included).
            out_type: ctypes struct class to unpack into.
            unpack: library function ``unpack(inbuf, out)``.
            wait: if true, keep polling until a packet of exactly ``length``
                bytes arrives; otherwise return None when the newest packet
                has a different size.
        """
        nbytes = get_newest_packet(self.sock, self.recvbuf, length,
                                   None, None)
        while wait and nbytes != length:
            nbytes = get_newest_packet(self.sock, self.recvbuf, length,
                                       None, None)
        if nbytes != length:
            return None
        process_packet_header(self.packet_header_info,
                              self.recvbuf, self.sendbuf)
        out = out_type()
        unpack(self.inbuf, out)
        return out

    def send(self, u):
        """Pack and send a ``cassie_user_in_t`` command."""
        pack_cassie_user_in_t(u, self.outbuf)
        send_packet(self.sock, self.sendbuf, self.sendlen, None, 0)

    def send_pd(self, u):
        """Pack and send a ``pd_in_t`` command."""
        pack_pd_in_t(u, self.outbuf)
        send_packet(self.sock, self.sendbuf, self.sendlen_pd, None, 0)

    def recv_wait(self):
        """Block until a full-size packet arrives; return it as ``cassie_out_t``."""
        return self._receive(self.recvlen, cassie_out_t,
                             unpack_cassie_out_t, wait=True)

    def recv_wait_pd(self):
        """Block until a PD-size packet arrives; return it as ``state_out_t``."""
        return self._receive(self.recvlen_pd, state_out_t,
                             unpack_state_out_t, wait=True)

    def recv_newest(self):
        """Return the newest packet as ``cassie_out_t``, or None if its size differs."""
        return self._receive(self.recvlen, cassie_out_t,
                             unpack_cassie_out_t, wait=False)

    def recv_newest_pd(self):
        """Return the newest packet as ``state_out_t``, or None if its size differs."""
        return self._receive(self.recvlen_pd, state_out_t,
                             unpack_state_out_t, wait=False)

    def delay(self):
        return ord(self.packet_header_info.delay)

    def seq_num_in_diff(self):
        return ord(self.packet_header_info.seq_num_in_diff)

    def __del__(self):
        # __init__ may have raised before self.sock was assigned; avoid a
        # spurious AttributeError during garbage collection in that case.
        sock = getattr(self, 'sock', None)
        if sock is not None:
            udp_close(sock)

# ----------------------------------------------------------------------------
# /cassie/cassiemujoco/cassie_soft.xml
# ----------------------------------------------------------------------------
# 1 | 2 | 3 | 4 | 5 | 271 |   (XML body not present in this export)
# ----------------------------------------------------------------------------
# /cassie/cassiemujoco/cassie_stiff.xml
# ----------------------------------------------------------------------------
# 1 | 2 | 3 | 4 | 5 | 287 |   (XML body not present in this export)
-------------------------------------------------------------------------------- /cassie/cassiemujoco/cassiemujoco.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 Dynamic Robotics Laboratory 2 | # 3 | # Permission to use, copy, modify, and distribute this software for any 4 | # purpose with or without fee is hereby granted, provided that the above 5 | # copyright notice and this permission notice appear in all copies. 6 | # 7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
from .cassiemujoco_ctypes import *
import os
import ctypes
import numpy as np

# Get base directory
_dir_path = os.path.dirname(os.path.realpath(__file__))

# Initialize libcassiesim
cassie_mujoco_init(str.encode(_dir_path+"/cassie.xml"))


def _copy_to_numpy(ptr, count):
    """Return a new float64 array holding the first ``count`` values at ``ptr``."""
    return np.array(ptr[:count], dtype=np.float64)


def _to_c_array(data, count, ctype, label):
    """Copy ``data`` into a fresh ctypes array of ``count`` ``ctype`` values.

    Raises:
        ValueError: if ``len(data) != count``.  Raising (instead of printing
            and calling ``exit``) lets callers see a traceback or handle it,
            and does not depend on the ``site``-provided ``exit`` builtin.
    """
    if len(data) != count:
        raise ValueError("SIZE MISMATCH %s(): expected %d values, got %d"
                         % (label, count, len(data)))
    return (ctype * count)(*data)


# Interface classes
class CassieSim:
    """Wrapper around one libcassiemujoco simulator instance.

    Model sizes are hard-coded: ``nq`` = 35 positions, ``nv`` = 32 velocity
    DOFs, ``nbody`` = 26 bodies, ``ngeom`` = 35 geoms.
    """

    def __init__(self, modelfile):
        self.c = cassie_sim_init(modelfile.encode('utf-8'))
        self.nv = 32
        self.nbody = 26
        self.nq = 35
        self.ngeom = 35

    def step(self, u):
        """Advance one step with a low-level command; return a ``cassie_out_t``."""
        y = cassie_out_t()
        cassie_sim_step(self.c, y, u)
        return y

    def step_pd(self, u):
        """Advance one step with a PD command; return a ``state_out_t``."""
        y = state_out_t()
        cassie_sim_step_pd(self.c, y, u)
        return y

    def get_state(self):
        """Snapshot the simulator state into a new ``CassieState``."""
        s = CassieState()
        cassie_get_state(self.c, s.s)
        return s

    def set_state(self, s):
        """Restore a state captured with ``get_state``."""
        cassie_set_state(self.c, s.s)

    def time(self):
        timep = cassie_sim_time(self.c)
        return timep[0]

    def qpos(self):
        """Return the ``nq`` generalized positions as a Python list."""
        qposp = cassie_sim_qpos(self.c)
        return qposp[:self.nq]

    def qvel(self):
        """Return the ``nv`` generalized velocities as a Python list."""
        qvelp = cassie_sim_qvel(self.c)
        return qvelp[:self.nv]

    def set_time(self, time):
        timep = cassie_sim_time(self.c)
        timep[0] = time

    def set_qpos(self, qpos):
        """Overwrite up to ``nq`` leading positions; extra entries are ignored."""
        qposp = cassie_sim_qpos(self.c)
        for i in range(min(len(qpos), self.nq)):
            qposp[i] = qpos[i]

    def set_qvel(self, qvel):
        """Overwrite up to ``nv`` leading velocities; extra entries are ignored."""
        qvelp = cassie_sim_qvel(self.c)
        for i in range(min(len(qvel), self.nv)):
            qvelp[i] = qvel[i]

    def hold(self):
        cassie_sim_hold(self.c)

    def release(self):
        cassie_sim_release(self.c)

    def apply_force(self, xfrc, body=1):
        """Apply a 6-element external force vector to ``body``.

        Missing trailing entries stay 0.
        """
        xfrc_array = (ctypes.c_double * 6)()
        for i in range(len(xfrc)):
            xfrc_array[i] = xfrc[i]
        cassie_sim_apply_force(self.c, xfrc_array, body)

    def foot_force(self, force):
        """Write the 12 foot-force values into the caller-supplied ``force``."""
        frc_array = (ctypes.c_double * 12)()
        cassie_sim_foot_forces(self.c, frc_array)
        for i in range(12):
            force[i] = frc_array[i]

    def foot_pos(self, pos):
        """Write the 6 foot-position values into the caller-supplied ``pos``."""
        pos_array = (ctypes.c_double * 6)()
        cassie_sim_foot_positions(self.c, pos_array)
        for i in range(6):
            pos[i] = pos_array[i]

    def clear_forces(self):
        cassie_sim_clear_forces(self.c)

    def get_foot_forces(self):
        """Return entries 2 and 8 of the 12-value foot-force vector.

        Presumably the vertical force on each foot -- TODO confirm against
        ``cassie_sim_foot_forces``.
        """
        force = np.zeros(12)
        self.foot_force(force)
        return force[[2, 8]]

    def get_dof_damping(self):
        """Return a copy of the ``nv`` per-DOF damping coefficients."""
        return _copy_to_numpy(cassie_sim_dof_damping(self.c), self.nv)

    def get_body_mass(self):
        """Return a copy of the ``nbody`` body masses."""
        return _copy_to_numpy(cassie_sim_body_mass(self.c), self.nbody)

    def get_body_ipos(self):
        """Return a copy of the flattened ``nbody * 3`` body COM offsets."""
        return _copy_to_numpy(cassie_sim_body_ipos(self.c), self.nbody * 3)

    def get_ground_friction(self):
        """Return a copy of the 3 ground friction coefficients."""
        return _copy_to_numpy(cassie_sim_ground_friction(self.c), 3)

    def get_geom_rgba(self):
        """Return a copy of the flattened ``ngeom * 4`` geom RGBA values."""
        return _copy_to_numpy(cassie_sim_geom_rgba(self.c), self.ngeom * 4)

    def set_dof_damping(self, data):
        """Set per-DOF damping; ``data`` must have exactly ``nv`` entries."""
        cassie_sim_set_dof_damping(
            self.c, _to_c_array(data, self.nv, ctypes.c_double, "SET_DOF_DAMPING"))

    def set_body_mass(self, data):
        """Set body masses; ``data`` must have exactly ``nbody`` entries."""
        cassie_sim_set_body_mass(
            self.c, _to_c_array(data, self.nbody, ctypes.c_double, "SET_BODY_MASS"))

    def set_body_ipos(self, data):
        """Set body COM offsets; ``data`` must have exactly ``nbody * 3`` entries."""
        cassie_sim_set_body_ipos(
            self.c, _to_c_array(data, self.nbody * 3, ctypes.c_double, "SET_BODY_IPOS"))

    def set_ground_friction(self, data):
        """Set ground friction; ``data`` must have exactly 3 entries."""
        cassie_sim_set_ground_friction(
            self.c, _to_c_array(data, 3, ctypes.c_double, "SET_GROUND_FRICTION"))

    def set_geom_rgba(self, data):
        """Set geom colors; ``data`` must have exactly ``ngeom * 4`` entries (stored as float32)."""
        cassie_sim_set_geom_rgba(
            self.c, _to_c_array(data, self.ngeom * 4, ctypes.c_float, "SET_GEOM_RGBA"))

    def set_const(self):
        cassie_sim_set_const(self.c)

    def __del__(self):
        cassie_sim_free(self.c)

class CassieVis:
    """Viewer window attached to a ``CassieSim``."""

    def __init__(self, c, modelfile):
        self.v = cassie_vis_init(c.c, modelfile.encode('utf-8'))

    def draw(self, c):
        """Render simulator ``c``; returns the library's draw status."""
        state = cassie_vis_draw(self.v, c.c)
        return state

    def valid(self):
        return cassie_vis_valid(self.v)

    def ispaused(self):
        return cassie_vis_paused(self.v)

    def __del__(self):
        cassie_vis_free(self.v)

class CassieState:
    """Simulator state buffer allocated by the library (35 qpos, 32 qvel)."""

    def __init__(self):
        self.s = cassie_state_alloc()

    def time(self):
        timep = cassie_state_time(self.s)
        return timep[0]

    def qpos(self):
        qposp = cassie_state_qpos(self.s)
        return qposp[:35]

    def qvel(self):
        qvelp = cassie_state_qvel(self.s)
        return qvelp[:32]

    def set_time(self, time):
        timep = cassie_state_time(self.s)
        timep[0] = time

    def set_qpos(self, qpos):
        qposp = cassie_state_qpos(self.s)
        for i in range(min(len(qpos), 35)):
            qposp[i] = qpos[i]

    def set_qvel(self, qvel):
        qvelp = cassie_state_qvel(self.s)
        for i in range(min(len(qvel), 32)):
            qvelp[i] = qvel[i]

    def __del__(self):
        cassie_state_free(self.s)

class CassieUdp:
    """UDP client that exchanges packed Cassie packets with a remote peer.

    Full interface: send ``cassie_user_in_t`` / receive ``cassie_out_t``.
    PD interface:   send ``pd_in_t``          / receive ``state_out_t``.
    Each raw buffer reserves 2 leading bytes ahead of the payload, so the
    payload views ``inbuf`` / ``outbuf`` start at offset 2.
    """

    def __init__(self, remote_addr='127.0.0.1', remote_port='25000',
                 local_addr='0.0.0.0', local_port='25001'):
        self.sock = udp_init_client(str.encode(remote_addr),
                                    str.encode(remote_port),
                                    str.encode(local_addr),
                                    str.encode(local_port))
        self.packet_header_info = packet_header_info_t()
        self.recvlen = 2 + 697
        self.sendlen = 2 + 58
        self.recvlen_pd = 2 + 493
        self.sendlen_pd = 2 + 476
        self.recvbuf = (ctypes.c_ubyte * max(self.recvlen, self.recvlen_pd))()
        self.sendbuf = (ctypes.c_ubyte * max(self.sendlen, self.sendlen_pd))()
        self.inbuf = ctypes.cast(ctypes.byref(self.recvbuf, 2),
                                 ctypes.POINTER(ctypes.c_ubyte))
        self.outbuf = ctypes.cast(ctypes.byref(self.sendbuf, 2),
                                  ctypes.POINTER(ctypes.c_ubyte))

    def _receive(self, length, out_type, unpack, wait):
        """Fetch the newest ``length``-byte packet and unpack it into ``out_type``.

        With ``wait`` true, poll until a packet of exactly ``length`` bytes
        arrives.  Otherwise return None when the newest packet has a
        different size.
        """
        nbytes = get_newest_packet(self.sock, self.recvbuf, length,
                                   None, None)
        while wait and nbytes != length:
            nbytes = get_newest_packet(self.sock, self.recvbuf, length,
                                       None, None)
        if nbytes != length:
            return None
        process_packet_header(self.packet_header_info,
                              self.recvbuf, self.sendbuf)
        out = out_type()
        unpack(self.inbuf, out)
        return out

    def send(self, u):
        pack_cassie_user_in_t(u, self.outbuf)
        send_packet(self.sock, self.sendbuf, self.sendlen, None, 0)

    def send_pd(self, u):
        pack_pd_in_t(u, self.outbuf)
        send_packet(self.sock, self.sendbuf, self.sendlen_pd, None, 0)

    def recv_wait(self):
        return self._receive(self.recvlen, cassie_out_t,
                             unpack_cassie_out_t, wait=True)

    def recv_wait_pd(self):
        return self._receive(self.recvlen_pd, state_out_t,
                             unpack_state_out_t, wait=True)

    def recv_newest(self):
        return self._receive(self.recvlen, cassie_out_t,
                             unpack_cassie_out_t, wait=False)

    def recv_newest_pd(self):
        return self._receive(self.recvlen_pd, state_out_t,
                             unpack_state_out_t, wait=False)

    def delay(self):
        return ord(self.packet_header_info.delay)

    def seq_num_in_diff(self):
        return ord(self.packet_header_info.seq_num_in_diff)

    def __del__(self):
        # __init__ may have raised before self.sock existed.
        sock = getattr(self, 'sock', None)
        if sock is not None:
            udp_close(sock)

# ----------------------------------------------------------------------------
# /cassie/cassiemujoco/cassiemujoco_ctypes.py
# ----------------------------------------------------------------------------
# -*- coding: utf-8 -*-
#
# TARGET arch is: ['-I/usr/include/clang/6.0/include', '-Iinclude']
# WORD_SIZE is: 8
# POINTER_SIZE is: 8
# LONGDOUBLE_SIZE is: 16
#
import ctypes
import os
_dir_path = os.path.dirname(os.path.realpath(__file__))


_libraries = {}
_libraries['./libcassiemujoco.so'] = ctypes.CDLL(_dir_path + '/libcassiemujoco.so')
# if local wordsize is same as target, keep ctypes pointer function.
# The declarations below look machine-generated (ctypeslib-style names such as
# struct_c__SA_* and "# source:False" markers); regenerate rather than hand-edit.
#
# POINTER_T is plain ctypes.POINTER on hosts with 8-byte pointers.  Otherwise a
# stand-in class is built per pointee; its __init__ and contents() raise
# TypeError, so it is not usable as a real pointer.
if ctypes.sizeof(ctypes.c_void_p) == 8:
    POINTER_T = ctypes.POINTER
else:
    # required to access _ctypes
    import _ctypes
    # Emulate a pointer class using the approriate c_int32/c_int64 type
    # The new class should have :
    # ['__module__', 'from_param', '_type_', '__dict__', '__weakref__', '__doc__']
    # but the class should be submitted to a unique instance for each base type
    # to that if A == B, POINTER_T(A) == POINTER_T(B)
    ctypes._pointer_t_type_cache = {}
    def POINTER_T(pointee):
        # a pointer should have the same length as LONG
        # NOTE(review): fake_ptr_base_type is assigned but never used below.
        fake_ptr_base_type = ctypes.c_uint64
        # specific case for c_void_p
        if pointee is None: # VOID pointer type. c_void_p.
            pointee = type(None) # ctypes.c_void_p # ctypes.c_ulong
            clsname = 'c_void'
        else:
            clsname = pointee.__name__
        # Cache by pointee class name so repeated calls return the same class.
        if clsname in ctypes._pointer_t_type_cache:
            return ctypes._pointer_t_type_cache[clsname]
        # make template
        class _T(_ctypes._SimpleCData,):
            _type_ = 'L'  # stored as a C unsigned long
            _subtype_ = pointee
            def _sub_addr_(self):
                return self.value
            def __repr__(self):
                return '%s(%d)'%(clsname, self.value)
            def contents(self):
                raise TypeError('This is not a ctypes pointer.')
            def __init__(self, **args):
                raise TypeError('This is not a ctypes pointer. It is not instanciable.')
        _class = type('LP_%d_%s'%(8, clsname), (_T,),{})
        ctypes._pointer_t_type_cache[clsname] = _class
        return _class

c_int128 = ctypes.c_ubyte*16
c_uint128 = c_int128
void = None
if ctypes.sizeof(ctypes.c_longdouble) == 16:
    c_long_double_t = ctypes.c_longdouble
else:
    c_long_double_t = ctypes.c_ubyte*16



size_t = ctypes.c_uint64
socklen_t = ctypes.c_uint32
# Throughout this file `_pack_ = True` (== 1) disables implicit alignment
# padding; the required padding is spelled out as explicit PADDING_* fields.
class struct_sockaddr(ctypes.Structure):
    _pack_ = True # source:False
    _fields_ = [
        ('sa_family', ctypes.c_uint16),
        ('sa_data', ctypes.c_char * 14),
    ]

ssize_t = ctypes.c_int64
# No _fields_ are declared for this struct; it is only referenced through
# POINTER_T below.
class struct_CassieCoreSim(ctypes.Structure):
    pass

cassie_core_sim_t = struct_CassieCoreSim
cassie_core_sim_alloc = _libraries['./libcassiemujoco.so'].cassie_core_sim_alloc
cassie_core_sim_alloc.restype = POINTER_T(struct_CassieCoreSim)
cassie_core_sim_alloc.argtypes = []
cassie_core_sim_copy = _libraries['./libcassiemujoco.so'].cassie_core_sim_copy
cassie_core_sim_copy.restype = None
cassie_core_sim_copy.argtypes = [POINTER_T(struct_CassieCoreSim), POINTER_T(struct_CassieCoreSim)]
cassie_core_sim_free = _libraries['./libcassiemujoco.so'].cassie_core_sim_free
cassie_core_sim_free.restype = None
cassie_core_sim_free.argtypes = [POINTER_T(struct_CassieCoreSim)]
cassie_core_sim_setup = _libraries['./libcassiemujoco.so'].cassie_core_sim_setup
cassie_core_sim_setup.restype = None
cassie_core_sim_setup.argtypes = [POINTER_T(struct_CassieCoreSim)]
# Forward declarations: declared empty so they can appear in argtypes now;
# their _pack_/_fields_ are attached further down in this file.
class struct_c__SA_cassie_user_in_t(ctypes.Structure):
    pass

class struct_c__SA_cassie_out_t(ctypes.Structure):
    pass

class struct_c__SA_cassie_in_t(ctypes.Structure):
    pass

cassie_core_sim_step = _libraries['./libcassiemujoco.so'].cassie_core_sim_step
cassie_core_sim_step.restype = None
cassie_core_sim_step.argtypes = [POINTER_T(struct_CassieCoreSim), POINTER_T(struct_c__SA_cassie_user_in_t), POINTER_T(struct_c__SA_cassie_out_t), POINTER_T(struct_c__SA_cassie_in_t)]
class struct_c__SA_elmo_in_t(ctypes.Structure):
    _pack_ = True # source:False
    _fields_ = [
        ('controlWord', ctypes.c_uint16),
        ('PADDING_0', ctypes.c_ubyte * 6),
        ('torque', ctypes.c_double),
    ]

elmo_in_t = struct_c__SA_elmo_in_t
class struct_c__SA_cassie_leg_in_t(ctypes.Structure):
    _pack_ = True # source:False
    _fields_ = [
        ('hipRollDrive', elmo_in_t),
        ('hipYawDrive', elmo_in_t),
        ('hipPitchDrive', elmo_in_t),
        ('kneeDrive', elmo_in_t),
        ('footDrive', elmo_in_t),
    ]

cassie_leg_in_t = struct_c__SA_cassie_leg_in_t
class struct_c__SA_radio_in_t(ctypes.Structure):
    _pack_ = True # source:False
    _fields_ = [
        ('channel', ctypes.c_int16 * 14),
    ]

radio_in_t = struct_c__SA_radio_in_t
class struct_c__SA_cassie_pelvis_in_t(ctypes.Structure):
    _pack_ = True # source:False
    _fields_ = [
        ('radio', radio_in_t),
        ('sto', ctypes.c_bool),
        ('piezoState', ctypes.c_bool),
        ('piezoTone', ctypes.c_ubyte),
        ('PADDING_0', ctypes.c_ubyte),
    ]

cassie_pelvis_in_t = struct_c__SA_cassie_pelvis_in_t
# Complete the forward-declared cassie_in_t.
struct_c__SA_cassie_in_t._pack_ = True # source:False
struct_c__SA_cassie_in_t._fields_ = [
    ('pelvis', cassie_pelvis_in_t),
    ('leftLeg', cassie_leg_in_t),
    ('rightLeg', cassie_leg_in_t),
]

cassie_in_t = struct_c__SA_cassie_in_t
# pack_X(struct*, byte buffer) / unpack_X(byte buffer, struct*), per argtypes.
pack_cassie_in_t = _libraries['./libcassiemujoco.so'].pack_cassie_in_t
pack_cassie_in_t.restype = None
pack_cassie_in_t.argtypes = [POINTER_T(struct_c__SA_cassie_in_t), POINTER_T(ctypes.c_ubyte)]
unpack_cassie_in_t = _libraries['./libcassiemujoco.so'].unpack_cassie_in_t
unpack_cassie_in_t.restype = None
unpack_cassie_in_t.argtypes = [POINTER_T(ctypes.c_ubyte), POINTER_T(struct_c__SA_cassie_in_t)]
DiagnosticCodes = ctypes.c_int16
class struct_c__SA_battery_out_t(ctypes.Structure):
    _pack_ = True # source:False
    _fields_ = [
        ('dataGood', ctypes.c_bool),
        ('PADDING_0', ctypes.c_ubyte * 7),
        ('stateOfCharge', ctypes.c_double),
        ('voltage', ctypes.c_double * 12),
        ('current', ctypes.c_double),
        ('temperature', ctypes.c_double * 4),
    ]

battery_out_t = struct_c__SA_battery_out_t
class struct_c__SA_cassie_joint_out_t(ctypes.Structure):
    _pack_ = True # source:False
    _fields_ = [
        ('position', ctypes.c_double),
        ('velocity', ctypes.c_double),
    ]

cassie_joint_out_t = struct_c__SA_cassie_joint_out_t
class struct_c__SA_elmo_out_t(ctypes.Structure):
    _pack_ = True # source:False
    _fields_ = [
        ('statusWord', ctypes.c_uint16),
        ('PADDING_0', ctypes.c_ubyte * 6),
        ('position', ctypes.c_double),
        ('velocity', ctypes.c_double),
        ('torque', ctypes.c_double),
        ('driveTemperature', ctypes.c_double),
        ('dcLinkVoltage', ctypes.c_double),
        ('torqueLimit', ctypes.c_double),
        ('gearRatio', ctypes.c_double),
    ]

elmo_out_t = struct_c__SA_elmo_out_t
class struct_c__SA_cassie_leg_out_t(ctypes.Structure):
    _pack_ = True # source:False
    _fields_ = [
        ('hipRollDrive', elmo_out_t),
        ('hipYawDrive', elmo_out_t),
        ('hipPitchDrive', elmo_out_t),
        ('kneeDrive', elmo_out_t),
        ('footDrive', elmo_out_t),
        ('shinJoint', cassie_joint_out_t),
        ('tarsusJoint', cassie_joint_out_t),
        ('footJoint', cassie_joint_out_t),
        ('medullaCounter', ctypes.c_ubyte),
        ('PADDING_0', ctypes.c_ubyte),
        ('medullaCpuLoad', ctypes.c_uint16),
        ('reedSwitchState', ctypes.c_bool),
        ('PADDING_1', ctypes.c_ubyte * 3),
    ]

cassie_leg_out_t = struct_c__SA_cassie_leg_out_t
class struct_c__SA_radio_out_t(ctypes.Structure):
    _pack_ = True # source:False
    _fields_ = [
        ('radioReceiverSignalGood', ctypes.c_bool),
        ('receiverMedullaSignalGood', ctypes.c_bool),
        ('PADDING_0', ctypes.c_ubyte * 6),
        ('channel', ctypes.c_double * 16),
    ]

radio_out_t = struct_c__SA_radio_out_t
class struct_c__SA_target_pc_out_t(ctypes.Structure):
    _pack_ = True # source:False
    _fields_ = [
        ('etherCatStatus', ctypes.c_int32 * 6),
        ('etherCatNotifications', ctypes.c_int32 * 21),
        ('PADDING_0', ctypes.c_ubyte * 4),
        ('taskExecutionTime', ctypes.c_double),
        ('overloadCounter', ctypes.c_uint32),
        ('PADDING_1', ctypes.c_ubyte * 4),
        ('cpuTemperature', ctypes.c_double),
    ]

target_pc_out_t = struct_c__SA_target_pc_out_t
class struct_c__SA_vectornav_out_t(ctypes.Structure):
    _pack_ = True # source:False
    _fields_ = [
        ('dataGood', ctypes.c_bool),
        ('PADDING_0', ctypes.c_ubyte),
        ('vpeStatus', ctypes.c_uint16),
        ('PADDING_1', ctypes.c_ubyte * 4),
        ('pressure', ctypes.c_double),
        ('temperature', ctypes.c_double),
        ('magneticField', ctypes.c_double * 3),
        ('angularVelocity', ctypes.c_double * 3),
        ('linearAcceleration', ctypes.c_double * 3),
        ('orientation', ctypes.c_double * 4),
    ]

vectornav_out_t = struct_c__SA_vectornav_out_t
class struct_c__SA_cassie_pelvis_out_t(ctypes.Structure):
    _pack_ = True # source:False
    _fields_ = [
        ('targetPc', target_pc_out_t),
        ('battery', battery_out_t),
        ('radio', radio_out_t),
        ('vectorNav', vectornav_out_t),
        ('medullaCounter', ctypes.c_ubyte),
        ('PADDING_0', ctypes.c_ubyte),
        ('medullaCpuLoad', ctypes.c_uint16),
        ('bleederState', ctypes.c_bool),
        ('leftReedSwitchState', ctypes.c_bool),
        ('rightReedSwitchState', ctypes.c_bool),
        ('PADDING_1', ctypes.c_ubyte),
        ('vtmTemperature', ctypes.c_double),
    ]

cassie_pelvis_out_t = struct_c__SA_cassie_pelvis_out_t
# Complete the forward-declared cassie_out_t.
struct_c__SA_cassie_out_t._pack_ = True # source:False
struct_c__SA_cassie_out_t._fields_ = [
    ('pelvis', cassie_pelvis_out_t),
    ('leftLeg', cassie_leg_out_t),
    ('rightLeg', cassie_leg_out_t),
    ('isCalibrated', ctypes.c_bool),
    ('PADDING_0', ctypes.c_ubyte),
    ('messages', ctypes.c_int16 * 4),
    ('PADDING_1', ctypes.c_ubyte * 6),
]

cassie_out_t = struct_c__SA_cassie_out_t
pack_cassie_out_t = _libraries['./libcassiemujoco.so'].pack_cassie_out_t
pack_cassie_out_t.restype = None
pack_cassie_out_t.argtypes = [POINTER_T(struct_c__SA_cassie_out_t), POINTER_T(ctypes.c_ubyte)]
unpack_cassie_out_t = _libraries['./libcassiemujoco.so'].unpack_cassie_out_t
unpack_cassie_out_t.restype = None
unpack_cassie_out_t.argtypes = [POINTER_T(ctypes.c_ubyte), POINTER_T(struct_c__SA_cassie_out_t)]
# Complete the forward-declared cassie_user_in_t.
struct_c__SA_cassie_user_in_t._pack_ = True # source:False
struct_c__SA_cassie_user_in_t._fields_ = [
    ('torque', ctypes.c_double * 10),
    ('telemetry', ctypes.c_int16 * 9),
    ('PADDING_0', ctypes.c_ubyte * 6),
]

cassie_user_in_t = struct_c__SA_cassie_user_in_t
pack_cassie_user_in_t = _libraries['./libcassiemujoco.so'].pack_cassie_user_in_t
pack_cassie_user_in_t.restype = None
pack_cassie_user_in_t.argtypes = [POINTER_T(struct_c__SA_cassie_user_in_t), POINTER_T(ctypes.c_ubyte)]
unpack_cassie_user_in_t = _libraries['./libcassiemujoco.so'].unpack_cassie_user_in_t
unpack_cassie_user_in_t.restype = None
unpack_cassie_user_in_t.argtypes = [POINTER_T(ctypes.c_ubyte), POINTER_T(struct_c__SA_cassie_user_in_t)]
# No _fields_ are declared for these structs in this excerpt; presumably they
# are only handled through pointers.
class struct_cassie_sim(ctypes.Structure):
    pass

cassie_sim_t = struct_cassie_sim
class struct_cassie_vis(ctypes.Structure):
    pass

cassie_vis_t = struct_cassie_vis
306 | class struct_cassie_state(ctypes.Structure): 307 | pass 308 | 309 | cassie_state_t = struct_cassie_state 310 | cassie_mujoco_init = _libraries['./libcassiemujoco.so'].cassie_mujoco_init 311 | cassie_mujoco_init.restype = ctypes.c_bool 312 | cassie_mujoco_init.argtypes = [POINTER_T(ctypes.c_char)] 313 | cassie_cleanup = _libraries['./libcassiemujoco.so'].cassie_cleanup 314 | cassie_cleanup.restype = None 315 | cassie_cleanup.argtypes = [] 316 | cassie_sim_init = _libraries['./libcassiemujoco.so'].cassie_sim_init 317 | cassie_sim_init.restype = POINTER_T(struct_cassie_sim) 318 | cassie_sim_init.argtypes = [ctypes.c_char_p] 319 | cassie_sim_duplicate = _libraries['./libcassiemujoco.so'].cassie_sim_duplicate 320 | cassie_sim_duplicate.restype = POINTER_T(struct_cassie_sim) 321 | cassie_sim_duplicate.argtypes = [POINTER_T(struct_cassie_sim)] 322 | cassie_sim_copy = _libraries['./libcassiemujoco.so'].cassie_sim_copy 323 | cassie_sim_copy.restype = None 324 | cassie_sim_copy.argtypes = [POINTER_T(struct_cassie_sim), POINTER_T(struct_cassie_sim)] 325 | cassie_sim_free = _libraries['./libcassiemujoco.so'].cassie_sim_free 326 | cassie_sim_free.restype = None 327 | cassie_sim_free.argtypes = [POINTER_T(struct_cassie_sim)] 328 | cassie_sim_step_ethercat = _libraries['./libcassiemujoco.so'].cassie_sim_step_ethercat 329 | cassie_sim_step_ethercat.restype = None 330 | cassie_sim_step_ethercat.argtypes = [POINTER_T(struct_cassie_sim), POINTER_T(struct_c__SA_cassie_out_t), POINTER_T(struct_c__SA_cassie_in_t)] 331 | cassie_sim_step = _libraries['./libcassiemujoco.so'].cassie_sim_step 332 | cassie_sim_step.restype = None 333 | cassie_sim_step.argtypes = [POINTER_T(struct_cassie_sim), POINTER_T(struct_c__SA_cassie_out_t), POINTER_T(struct_c__SA_cassie_user_in_t)] 334 | class struct_c__SA_state_out_t(ctypes.Structure): 335 | pass 336 | 337 | class struct_c__SA_pd_in_t(ctypes.Structure): 338 | pass 339 | 340 | cassie_sim_step_pd = 
_libraries['./libcassiemujoco.so'].cassie_sim_step_pd 341 | cassie_sim_step_pd.restype = None 342 | cassie_sim_step_pd.argtypes = [POINTER_T(struct_cassie_sim), POINTER_T(struct_c__SA_state_out_t), POINTER_T(struct_c__SA_pd_in_t)] 343 | cassie_sim_time = _libraries['./libcassiemujoco.so'].cassie_sim_time 344 | cassie_sim_time.restype = POINTER_T(ctypes.c_double) 345 | cassie_sim_time.argtypes = [POINTER_T(struct_cassie_sim)] 346 | cassie_sim_qpos = _libraries['./libcassiemujoco.so'].cassie_sim_qpos 347 | cassie_sim_qpos.restype = POINTER_T(ctypes.c_double) 348 | cassie_sim_qpos.argtypes = [POINTER_T(struct_cassie_sim)] 349 | cassie_sim_qvel = _libraries['./libcassiemujoco.so'].cassie_sim_qvel 350 | cassie_sim_qvel.restype = POINTER_T(ctypes.c_double) 351 | cassie_sim_qvel.argtypes = [POINTER_T(struct_cassie_sim)] 352 | cassie_sim_mjmodel = _libraries['./libcassiemujoco.so'].cassie_sim_mjmodel 353 | cassie_sim_mjmodel.restype = POINTER_T(None) 354 | cassie_sim_mjmodel.argtypes = [POINTER_T(struct_cassie_sim)] 355 | cassie_sim_mjdata = _libraries['./libcassiemujoco.so'].cassie_sim_mjdata 356 | cassie_sim_mjdata.restype = POINTER_T(None) 357 | cassie_sim_mjdata.argtypes = [POINTER_T(struct_cassie_sim)] 358 | cassie_sim_check_obstacle_collision = _libraries['./libcassiemujoco.so'].cassie_sim_check_obstacle_collision 359 | cassie_sim_check_obstacle_collision.restype = ctypes.c_bool 360 | cassie_sim_check_obstacle_collision.argtypes = [POINTER_T(struct_cassie_sim)] 361 | cassie_sim_check_self_collision = _libraries['./libcassiemujoco.so'].cassie_sim_check_self_collision 362 | cassie_sim_check_self_collision.restype = ctypes.c_bool 363 | cassie_sim_check_self_collision.argtypes = [POINTER_T(struct_cassie_sim)] 364 | cassie_sim_foot_forces = _libraries['./libcassiemujoco.so'].cassie_sim_foot_forces 365 | cassie_sim_foot_forces.restype = None 366 | cassie_sim_foot_forces.argtypes = [POINTER_T(struct_cassie_sim), ctypes.c_double * 12] 367 | cassie_sim_foot_positions = 
_libraries['./libcassiemujoco.so'].cassie_sim_foot_positions 368 | cassie_sim_foot_positions.restype = None 369 | cassie_sim_foot_positions.argtypes = [POINTER_T(struct_cassie_sim), ctypes.c_double * 6] 370 | cassie_sim_apply_force = _libraries['./libcassiemujoco.so'].cassie_sim_apply_force 371 | cassie_sim_apply_force.restype = None 372 | cassie_sim_apply_force.argtypes = [POINTER_T(struct_cassie_sim), ctypes.c_double * 6, ctypes.c_int32] 373 | cassie_sim_clear_forces = _libraries['./libcassiemujoco.so'].cassie_sim_clear_forces 374 | cassie_sim_clear_forces.restype = None 375 | cassie_sim_clear_forces.argtypes = [POINTER_T(struct_cassie_sim)] 376 | cassie_sim_hold = _libraries['./libcassiemujoco.so'].cassie_sim_hold 377 | cassie_sim_hold.restype = None 378 | cassie_sim_hold.argtypes = [POINTER_T(struct_cassie_sim)] 379 | cassie_sim_release = _libraries['./libcassiemujoco.so'].cassie_sim_release 380 | cassie_sim_release.restype = None 381 | cassie_sim_release.argtypes = [POINTER_T(struct_cassie_sim)] 382 | cassie_sim_radio = _libraries['./libcassiemujoco.so'].cassie_sim_radio 383 | cassie_sim_radio.restype = None 384 | cassie_sim_radio.argtypes = [POINTER_T(struct_cassie_sim), ctypes.c_double * 16] 385 | cassie_vis_init = _libraries['./libcassiemujoco.so'].cassie_vis_init 386 | cassie_vis_init.restype = POINTER_T(struct_cassie_vis) 387 | cassie_vis_init.argtypes = [POINTER_T(struct_cassie_sim), ctypes.c_char_p] 388 | cassie_vis_close = _libraries['./libcassiemujoco.so'].cassie_vis_close 389 | cassie_vis_close.restype = None 390 | cassie_vis_close.argtypes = [POINTER_T(struct_cassie_vis)] 391 | cassie_vis_free = _libraries['./libcassiemujoco.so'].cassie_vis_free 392 | cassie_vis_free.restype = None 393 | cassie_vis_free.argtypes = [POINTER_T(struct_cassie_vis)] 394 | cassie_vis_draw = _libraries['./libcassiemujoco.so'].cassie_vis_draw 395 | cassie_vis_draw.restype = ctypes.c_bool 396 | cassie_vis_draw.argtypes = [POINTER_T(struct_cassie_vis), 
POINTER_T(struct_cassie_sim)] 397 | cassie_vis_valid = _libraries['./libcassiemujoco.so'].cassie_vis_valid 398 | cassie_vis_valid.restype = ctypes.c_bool 399 | cassie_vis_valid.argtypes = [POINTER_T(struct_cassie_vis)] 400 | cassie_vis_paused = _libraries['./libcassiemujoco.so'].cassie_vis_paused 401 | cassie_vis_paused.restype = ctypes.c_bool 402 | cassie_vis_paused.argtypes = [POINTER_T(struct_cassie_vis)] 403 | cassie_state_alloc = _libraries['./libcassiemujoco.so'].cassie_state_alloc 404 | cassie_state_alloc.restype = POINTER_T(struct_cassie_state) 405 | cassie_state_alloc.argtypes = [] 406 | cassie_state_duplicate = _libraries['./libcassiemujoco.so'].cassie_state_duplicate 407 | cassie_state_duplicate.restype = POINTER_T(struct_cassie_state) 408 | cassie_state_duplicate.argtypes = [POINTER_T(struct_cassie_state)] 409 | cassie_state_copy = _libraries['./libcassiemujoco.so'].cassie_state_copy 410 | cassie_state_copy.restype = None 411 | cassie_state_copy.argtypes = [POINTER_T(struct_cassie_state), POINTER_T(struct_cassie_state)] 412 | cassie_state_free = _libraries['./libcassiemujoco.so'].cassie_state_free 413 | cassie_state_free.restype = None 414 | cassie_state_free.argtypes = [POINTER_T(struct_cassie_state)] 415 | cassie_state_time = _libraries['./libcassiemujoco.so'].cassie_state_time 416 | cassie_state_time.restype = POINTER_T(ctypes.c_double) 417 | cassie_state_time.argtypes = [POINTER_T(struct_cassie_state)] 418 | cassie_state_qpos = _libraries['./libcassiemujoco.so'].cassie_state_qpos 419 | cassie_state_qpos.restype = POINTER_T(ctypes.c_double) 420 | cassie_state_qpos.argtypes = [POINTER_T(struct_cassie_state)] 421 | cassie_state_qvel = _libraries['./libcassiemujoco.so'].cassie_state_qvel 422 | cassie_state_qvel.restype = POINTER_T(ctypes.c_double) 423 | cassie_state_qvel.argtypes = [POINTER_T(struct_cassie_state)] 424 | cassie_get_state = _libraries['./libcassiemujoco.so'].cassie_get_state 425 | cassie_get_state.restype = None 426 | 
cassie_get_state.argtypes = [POINTER_T(struct_cassie_sim), POINTER_T(struct_cassie_state)] 427 | cassie_set_state = _libraries['./libcassiemujoco.so'].cassie_set_state 428 | cassie_set_state.restype = None 429 | cassie_set_state.argtypes = [POINTER_T(struct_cassie_sim), POINTER_T(struct_cassie_state)] 430 | 431 | #cassie_sim_foot_positions.argtypes = [POINTER_T(struct_cassie_sim), ctypes.c_double * 6] 432 | 433 | cassie_sim_dof_damping = _libraries['./libcassiemujoco.so'].cassie_sim_dof_damping 434 | cassie_sim_dof_damping.restype = POINTER_T(ctypes.c_double) 435 | cassie_sim_dof_damping.argtypes = [POINTER_T(struct_cassie_sim)] 436 | 437 | cassie_sim_set_dof_damping = _libraries['./libcassiemujoco.so'].cassie_sim_set_dof_damping 438 | cassie_sim_set_dof_damping.restype = None 439 | cassie_sim_set_dof_damping.argtypes = [POINTER_T(struct_cassie_sim), POINTER_T(ctypes.c_double)] 440 | 441 | cassie_sim_body_mass = _libraries['./libcassiemujoco.so'].cassie_sim_body_mass 442 | cassie_sim_body_mass.restype = POINTER_T(ctypes.c_double) 443 | cassie_sim_body_mass.argtypes = [POINTER_T(struct_cassie_sim)] 444 | 445 | cassie_sim_set_body_mass = _libraries['./libcassiemujoco.so'].cassie_sim_set_body_mass 446 | cassie_sim_set_body_mass.restype = None 447 | cassie_sim_set_body_mass.argtypes = [POINTER_T(struct_cassie_sim), POINTER_T(ctypes.c_double)] 448 | 449 | cassie_sim_body_ipos = _libraries['./libcassiemujoco.so'].cassie_sim_body_ipos 450 | cassie_sim_body_ipos.restype = POINTER_T(ctypes.c_double) 451 | cassie_sim_body_ipos.argtypes = [POINTER_T(struct_cassie_sim)] 452 | 453 | cassie_sim_set_body_ipos = _libraries['./libcassiemujoco.so'].cassie_sim_set_body_ipos 454 | cassie_sim_set_body_ipos.restype = None 455 | cassie_sim_set_body_ipos.argtypes = [POINTER_T(struct_cassie_sim), POINTER_T(ctypes.c_double)] 456 | 457 | cassie_sim_ground_friction = _libraries['./libcassiemujoco.so'].cassie_sim_ground_friction 458 | cassie_sim_ground_friction.restype = 
POINTER_T(ctypes.c_double) 459 | cassie_sim_ground_friction.argtypes = [POINTER_T(struct_cassie_sim)] 460 | 461 | cassie_sim_set_ground_friction = _libraries['./libcassiemujoco.so'].cassie_sim_set_ground_friction 462 | cassie_sim_set_ground_friction.restype = None 463 | cassie_sim_set_ground_friction.argtypes = [POINTER_T(struct_cassie_sim), POINTER_T(ctypes.c_double)] 464 | 465 | cassie_sim_geom_rgba = _libraries['./libcassiemujoco.so'].cassie_sim_geom_rgba 466 | cassie_sim_geom_rgba.restype = POINTER_T(ctypes.c_float) 467 | cassie_sim_geom_rgba.argtypes = [POINTER_T(struct_cassie_sim)] 468 | 469 | cassie_sim_set_geom_rgba = _libraries['./libcassiemujoco.so'].cassie_sim_set_geom_rgba 470 | cassie_sim_set_geom_rgba.restype = None 471 | cassie_sim_set_geom_rgba.argtypes = [POINTER_T(struct_cassie_sim), POINTER_T(ctypes.c_float)] 472 | 473 | cassie_sim_set_const = _libraries['./libcassiemujoco.so'].cassie_sim_set_const 474 | cassie_sim_set_const.restype = None 475 | cassie_sim_set_const.argtypes = [POINTER_T(struct_cassie_sim)] 476 | 477 | class struct_c__SA_pd_motor_in_t(ctypes.Structure): 478 | _pack_ = True # source:False 479 | _fields_ = [ 480 | ('torque', ctypes.c_double * 5), 481 | ('pTarget', ctypes.c_double * 5), 482 | ('dTarget', ctypes.c_double * 5), 483 | ('pGain', ctypes.c_double * 5), 484 | ('dGain', ctypes.c_double * 5), 485 | ] 486 | 487 | pd_motor_in_t = struct_c__SA_pd_motor_in_t 488 | class struct_c__SA_pd_task_in_t(ctypes.Structure): 489 | _pack_ = True # source:False 490 | _fields_ = [ 491 | ('torque', ctypes.c_double * 6), 492 | ('pTarget', ctypes.c_double * 6), 493 | ('dTarget', ctypes.c_double * 6), 494 | ('pGain', ctypes.c_double * 6), 495 | ('dGain', ctypes.c_double * 6), 496 | ] 497 | 498 | pd_task_in_t = struct_c__SA_pd_task_in_t 499 | class struct_c__SA_pd_leg_in_t(ctypes.Structure): 500 | _pack_ = True # source:False 501 | _fields_ = [ 502 | ('taskPd', pd_task_in_t), 503 | ('motorPd', pd_motor_in_t), 504 | ] 505 | 506 | pd_leg_in_t = 
struct_c__SA_pd_leg_in_t 507 | struct_c__SA_pd_in_t._pack_ = True # source:False 508 | struct_c__SA_pd_in_t._fields_ = [ 509 | ('leftLeg', pd_leg_in_t), 510 | ('rightLeg', pd_leg_in_t), 511 | ('telemetry', ctypes.c_double * 9), 512 | ] 513 | 514 | pd_in_t = struct_c__SA_pd_in_t 515 | pack_pd_in_t = _libraries['./libcassiemujoco.so'].pack_pd_in_t 516 | pack_pd_in_t.restype = None 517 | pack_pd_in_t.argtypes = [POINTER_T(struct_c__SA_pd_in_t), POINTER_T(ctypes.c_ubyte)] 518 | unpack_pd_in_t = _libraries['./libcassiemujoco.so'].unpack_pd_in_t 519 | unpack_pd_in_t.restype = None 520 | unpack_pd_in_t.argtypes = [POINTER_T(ctypes.c_ubyte), POINTER_T(struct_c__SA_pd_in_t)] 521 | class struct_PdInput(ctypes.Structure): 522 | pass 523 | 524 | pd_input_t = struct_PdInput 525 | pd_input_alloc = _libraries['./libcassiemujoco.so'].pd_input_alloc 526 | pd_input_alloc.restype = POINTER_T(struct_PdInput) 527 | pd_input_alloc.argtypes = [] 528 | pd_input_copy = _libraries['./libcassiemujoco.so'].pd_input_copy 529 | pd_input_copy.restype = None 530 | pd_input_copy.argtypes = [POINTER_T(struct_PdInput), POINTER_T(struct_PdInput)] 531 | pd_input_free = _libraries['./libcassiemujoco.so'].pd_input_free 532 | pd_input_free.restype = None 533 | pd_input_free.argtypes = [POINTER_T(struct_PdInput)] 534 | pd_input_setup = _libraries['./libcassiemujoco.so'].pd_input_setup 535 | pd_input_setup.restype = None 536 | pd_input_setup.argtypes = [POINTER_T(struct_PdInput)] 537 | pd_input_step = _libraries['./libcassiemujoco.so'].pd_input_step 538 | pd_input_step.restype = None 539 | pd_input_step.argtypes = [POINTER_T(struct_PdInput), POINTER_T(struct_c__SA_pd_in_t), POINTER_T(struct_c__SA_cassie_out_t), POINTER_T(struct_c__SA_cassie_user_in_t)] 540 | class struct_c__SA_state_battery_out_t(ctypes.Structure): 541 | _pack_ = True # source:False 542 | _fields_ = [ 543 | ('stateOfCharge', ctypes.c_double), 544 | ('current', ctypes.c_double), 545 | ] 546 | 547 | state_battery_out_t = 
struct_c__SA_state_battery_out_t 548 | class struct_c__SA_state_foot_out_t(ctypes.Structure): 549 | _pack_ = True # source:False 550 | _fields_ = [ 551 | ('position', ctypes.c_double * 3), 552 | ('orientation', ctypes.c_double * 4), 553 | ('footRotationalVelocity', ctypes.c_double * 3), 554 | ('footTranslationalVelocity', ctypes.c_double * 3), 555 | ('toeForce', ctypes.c_double * 3), 556 | ('heelForce', ctypes.c_double * 3), 557 | ] 558 | 559 | state_foot_out_t = struct_c__SA_state_foot_out_t 560 | class struct_c__SA_state_joint_out_t(ctypes.Structure): 561 | _pack_ = True # source:False 562 | _fields_ = [ 563 | ('position', ctypes.c_double * 6), 564 | ('velocity', ctypes.c_double * 6), 565 | ] 566 | 567 | state_joint_out_t = struct_c__SA_state_joint_out_t 568 | class struct_c__SA_state_motor_out_t(ctypes.Structure): 569 | _pack_ = True # source:False 570 | _fields_ = [ 571 | ('position', ctypes.c_double * 10), 572 | ('velocity', ctypes.c_double * 10), 573 | ('torque', ctypes.c_double * 10), 574 | ] 575 | 576 | state_motor_out_t = struct_c__SA_state_motor_out_t 577 | class struct_c__SA_state_pelvis_out_t(ctypes.Structure): 578 | _pack_ = True # source:False 579 | _fields_ = [ 580 | ('position', ctypes.c_double * 3), 581 | ('orientation', ctypes.c_double * 4), 582 | ('rotationalVelocity', ctypes.c_double * 3), 583 | ('translationalVelocity', ctypes.c_double * 3), 584 | ('translationalAcceleration', ctypes.c_double * 3), 585 | ('externalMoment', ctypes.c_double * 3), 586 | ('externalForce', ctypes.c_double * 3), 587 | ] 588 | 589 | state_pelvis_out_t = struct_c__SA_state_pelvis_out_t 590 | class struct_c__SA_state_radio_out_t(ctypes.Structure): 591 | _pack_ = True # source:False 592 | _fields_ = [ 593 | ('channel', ctypes.c_double * 16), 594 | ('signalGood', ctypes.c_bool), 595 | ('PADDING_0', ctypes.c_ubyte * 7), 596 | ] 597 | 598 | state_radio_out_t = struct_c__SA_state_radio_out_t 599 | class struct_c__SA_state_terrain_out_t(ctypes.Structure): 600 | _pack_ = True 
# source:False 601 | _fields_ = [ 602 | ('height', ctypes.c_double), 603 | ('slope', ctypes.c_double * 2), 604 | ] 605 | 606 | state_terrain_out_t = struct_c__SA_state_terrain_out_t 607 | struct_c__SA_state_out_t._pack_ = True # source:False 608 | struct_c__SA_state_out_t._fields_ = [ 609 | ('pelvis', state_pelvis_out_t), 610 | ('leftFoot', state_foot_out_t), 611 | ('rightFoot', state_foot_out_t), 612 | ('terrain', state_terrain_out_t), 613 | ('motor', state_motor_out_t), 614 | ('joint', state_joint_out_t), 615 | ('radio', state_radio_out_t), 616 | ('battery', state_battery_out_t), 617 | ] 618 | 619 | state_out_t = struct_c__SA_state_out_t 620 | pack_state_out_t = _libraries['./libcassiemujoco.so'].pack_state_out_t 621 | pack_state_out_t.restype = None 622 | pack_state_out_t.argtypes = [POINTER_T(struct_c__SA_state_out_t), POINTER_T(ctypes.c_ubyte)] 623 | unpack_state_out_t = _libraries['./libcassiemujoco.so'].unpack_state_out_t 624 | unpack_state_out_t.restype = None 625 | unpack_state_out_t.argtypes = [POINTER_T(ctypes.c_ubyte), POINTER_T(struct_c__SA_state_out_t)] 626 | class struct_StateOutput(ctypes.Structure): 627 | pass 628 | 629 | state_output_t = struct_StateOutput 630 | state_output_alloc = _libraries['./libcassiemujoco.so'].state_output_alloc 631 | state_output_alloc.restype = POINTER_T(struct_StateOutput) 632 | state_output_alloc.argtypes = [] 633 | state_output_copy = _libraries['./libcassiemujoco.so'].state_output_copy 634 | state_output_copy.restype = None 635 | state_output_copy.argtypes = [POINTER_T(struct_StateOutput), POINTER_T(struct_StateOutput)] 636 | state_output_free = _libraries['./libcassiemujoco.so'].state_output_free 637 | state_output_free.restype = None 638 | state_output_free.argtypes = [POINTER_T(struct_StateOutput)] 639 | state_output_setup = _libraries['./libcassiemujoco.so'].state_output_setup 640 | state_output_setup.restype = None 641 | state_output_setup.argtypes = [POINTER_T(struct_StateOutput)] 642 | state_output_step = 
_libraries['./libcassiemujoco.so'].state_output_step 643 | state_output_step.restype = None 644 | state_output_step.argtypes = [POINTER_T(struct_StateOutput), POINTER_T(struct_c__SA_cassie_out_t), POINTER_T(struct_c__SA_state_out_t)] 645 | class struct_c__SA_packet_header_info_t(ctypes.Structure): 646 | _pack_ = True # source:False 647 | _fields_ = [ 648 | ('seq_num_out', ctypes.c_char), 649 | ('seq_num_in_last', ctypes.c_char), 650 | ('delay', ctypes.c_char), 651 | ('seq_num_in_diff', ctypes.c_char), 652 | ] 653 | 654 | packet_header_info_t = struct_c__SA_packet_header_info_t 655 | process_packet_header = _libraries['./libcassiemujoco.so'].process_packet_header 656 | process_packet_header.restype = None 657 | process_packet_header.argtypes = [POINTER_T(struct_c__SA_packet_header_info_t), POINTER_T(ctypes.c_ubyte), POINTER_T(ctypes.c_ubyte)] 658 | udp_init_host = _libraries['./libcassiemujoco.so'].udp_init_host 659 | udp_init_host.restype = ctypes.c_int32 660 | udp_init_host.argtypes = [POINTER_T(ctypes.c_char), POINTER_T(ctypes.c_char)] 661 | udp_init_client = _libraries['./libcassiemujoco.so'].udp_init_client 662 | udp_init_client.restype = ctypes.c_int32 663 | udp_init_client.argtypes = [POINTER_T(ctypes.c_char), POINTER_T(ctypes.c_char), POINTER_T(ctypes.c_char), POINTER_T(ctypes.c_char)] 664 | udp_close = _libraries['./libcassiemujoco.so'].udp_close 665 | udp_close.restype = None 666 | udp_close.argtypes = [ctypes.c_int32] 667 | get_newest_packet = _libraries['./libcassiemujoco.so'].get_newest_packet 668 | get_newest_packet.restype = ssize_t 669 | get_newest_packet.argtypes = [ctypes.c_int32, POINTER_T(None), size_t, POINTER_T(struct_sockaddr), POINTER_T(ctypes.c_uint32)] 670 | wait_for_packet = _libraries['./libcassiemujoco.so'].wait_for_packet 671 | wait_for_packet.restype = ssize_t 672 | wait_for_packet.argtypes = [ctypes.c_int32, POINTER_T(None), size_t, POINTER_T(struct_sockaddr), POINTER_T(ctypes.c_uint32)] 673 | send_packet = 
_libraries['./libcassiemujoco.so'].send_packet 674 | send_packet.restype = ssize_t 675 | send_packet.argtypes = [ctypes.c_int32, POINTER_T(None), size_t, POINTER_T(struct_sockaddr), socklen_t] 676 | __all__ = \ 677 | ['cassie_pelvis_in_t', 'struct_StateOutput', 'cassie_state_t', 678 | 'cassie_sim_check_self_collision', 'cassie_vis_free', 679 | 'cassie_in_t', 'state_terrain_out_t', 'struct_c__SA_pd_leg_in_t', 680 | 'cassie_state_free', 'struct_c__SA_state_battery_out_t', 681 | 'elmo_in_t', 'state_joint_out_t', 'send_packet', 682 | 'cassie_pelvis_out_t', 'cassie_cleanup', 683 | 'struct_c__SA_state_radio_out_t', 'cassie_vis_valid', 684 | 'pd_input_setup', 'pd_leg_in_t', 'cassie_mujoco_init', 685 | 'cassie_state_copy', 'cassie_core_sim_setup', 'battery_out_t', 686 | 'cassie_sim_hold', 'struct_CassieCoreSim', 'cassie_core_sim_step', 687 | 'pack_cassie_out_t', 'cassie_out_t', 'radio_in_t', 688 | 'unpack_cassie_out_t', 'struct_c__SA_pd_task_in_t', 689 | 'struct_PdInput', 'udp_init_client', 'pd_motor_in_t', 690 | 'cassie_sim_t', 'cassie_core_sim_alloc', 'get_newest_packet', 691 | 'size_t', 'struct_c__SA_vectornav_out_t', 692 | 'struct_c__SA_pd_motor_in_t', 'cassie_get_state', 693 | 'state_battery_out_t', 'struct_c__SA_state_pelvis_out_t', 694 | 'cassie_state_qpos', 'cassie_state_qvel', 'state_radio_out_t', 695 | 'struct_c__SA_pd_in_t', 'udp_close', 'state_output_free', 696 | 'cassie_core_sim_free', 'pd_task_in_t', 'packet_header_info_t', 697 | 'pd_in_t', 'struct_cassie_vis', 'struct_c__SA_elmo_out_t', 698 | 'pack_pd_in_t', 'struct_c__SA_radio_out_t', 'pd_input_alloc', 699 | 'DiagnosticCodes', 'unpack_state_out_t', 'target_pc_out_t', 700 | 'cassie_sim_duplicate', 'cassie_state_alloc', 'cassie_sim_init', 701 | 'struct_c__SA_cassie_user_in_t', 'struct_c__SA_radio_in_t', 702 | 'socklen_t', 'cassie_vis_init', 'state_out_t', 703 | 'struct_c__SA_cassie_in_t', 'pd_input_free', 'state_output_alloc', 704 | 'struct_c__SA_cassie_leg_out_t', 705 | 'struct_c__SA_cassie_pelvis_in_t', 
'unpack_pd_in_t', 706 | 'cassie_user_in_t', 'cassie_sim_clear_forces', 'cassie_vis_t', 707 | 'struct_c__SA_target_pc_out_t', 'pd_input_step', 708 | 'cassie_set_state', 'struct_c__SA_battery_out_t', 709 | 'vectornav_out_t', 'struct_c__SA_packet_header_info_t', 710 | 'cassie_sim_step_pd', 'struct_sockaddr', 'cassie_vis_draw', 711 | 'cassie_core_sim_copy', 'unpack_cassie_in_t', 'struct_cassie_sim', 712 | 'unpack_cassie_user_in_t', 'cassie_sim_step', 'udp_init_host', 713 | 'state_motor_out_t', 'cassie_core_sim_t', 'pack_state_out_t', 714 | 'cassie_sim_mjdata', 'state_output_setup', 'cassie_sim_mjmodel', 715 | 'state_foot_out_t', 'state_output_t', 'cassie_sim_time', 716 | 'cassie_sim_step_ethercat', 'cassie_sim_check_obstacle_collision', 717 | 'elmo_out_t', 'pack_cassie_in_t', 'cassie_sim_apply_force', 718 | 'cassie_leg_out_t', 'wait_for_packet', 719 | 'struct_c__SA_cassie_leg_in_t', 'struct_c__SA_state_joint_out_t', 720 | 'process_packet_header', 'cassie_sim_release', 'cassie_sim_foot_forces', 721 | 'cassie_sim_foot_positions', 'struct_c__SA_state_foot_out_t', 722 | 'pd_input_t', 'pack_cassie_user_in_t', 'cassie_state_duplicate', 723 | 'state_pelvis_out_t', 'struct_c__SA_state_terrain_out_t', 724 | 'cassie_sim_free', 'ssize_t', 'state_output_copy', 725 | 'cassie_sim_radio', 'cassie_vis_close', 'cassie_vis_paused', 'radio_out_t', 726 | 'state_output_step', 'struct_c__SA_state_motor_out_t', 727 | 'struct_cassie_state', 'cassie_state_time', 'cassie_sim_qvel', 728 | 'cassie_sim_qpos', 'struct_c__SA_elmo_in_t', 'cassie_joint_out_t', 729 | 'cassie_leg_in_t', 'struct_c__SA_cassie_joint_out_t', 730 | 'struct_c__SA_state_out_t', 'struct_c__SA_cassie_pelvis_out_t', 731 | 'pd_input_copy', 'cassie_sim_copy', 'struct_c__SA_cassie_out_t', 732 | 'cassie_sim_dof_damping', 'cassie_sim_set_dof_damping', 733 | 'cassie_sim_body_mass', 'cassie_sim_set_body_mass', 734 | 'cassie_sim_body_ipos', 'cassie_sim_set_body_ipos', 735 | 'cassie_sim_ground_friction', 'cassie_sim_set_ground_friction', 
736 | 'cassie_sim_set_const', 'cassie_sim_geom_rgba', 'cassie_sim_set_geom_rgba'] 737 | -------------------------------------------------------------------------------- /cassie/cassiemujoco/include/CassieCoreSim.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2018 Agility Robotics 3 | * 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose with or without fee is hereby granted, provided that the above 6 | * copyright notice and this permission notice appear in all copies. 7 | * 8 | * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 9 | * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 10 | * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 11 | * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 12 | * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 13 | * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 14 | * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
15 | */ 16 | 17 | #ifndef CASSIECORESIM_H 18 | #define CASSIECORESIM_H 19 | 20 | #include "cassie_user_in_t.h" 21 | #include "cassie_out_t.h" 22 | #include "cassie_in_t.h" 23 | 24 | typedef struct CassieCoreSim CassieCoreSim; 25 | 26 | #ifdef __cplusplus 27 | extern "C" { 28 | #endif 29 | 30 | CassieCoreSim* CassieCoreSim_alloc(void); 31 | void CassieCoreSim_copy(CassieCoreSim *dst, const CassieCoreSim *src); 32 | void CassieCoreSim_free(CassieCoreSim *sys); 33 | void CassieCoreSim_setup(CassieCoreSim *sys); 34 | void CassieCoreSim_step(CassieCoreSim *sys, const cassie_user_in_t *in1, 35 | const cassie_out_t *in2, cassie_in_t *out1); 36 | 37 | #ifdef __cplusplus 38 | } 39 | #endif 40 | #endif // CASSIECORESIM_H 41 | -------------------------------------------------------------------------------- /cassie/cassiemujoco/include/PdInput.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2018 Agility Robotics 3 | * 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose with or without fee is hereby granted, provided that the above 6 | * copyright notice and this permission notice appear in all copies. 7 | * 8 | * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 9 | * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 10 | * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 11 | * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 12 | * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 13 | * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 14 | * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
15 | */ 16 | 17 | #ifndef PDINPUT_H 18 | #define PDINPUT_H 19 | 20 | #include "pd_in_t.h" 21 | #include "cassie_out_t.h" 22 | #include "cassie_user_in_t.h" 23 | 24 | typedef struct PdInput PdInput; 25 | 26 | #ifdef __cplusplus 27 | extern "C" { 28 | #endif 29 | 30 | PdInput* PdInput_alloc(void); 31 | void PdInput_copy(PdInput *dst, const PdInput *src); 32 | void PdInput_free(PdInput *sys); 33 | void PdInput_setup(PdInput *sys); 34 | void PdInput_step(PdInput *sys, const pd_in_t *in1, const cassie_out_t 35 | *in2, cassie_user_in_t *out1); 36 | 37 | #ifdef __cplusplus 38 | } 39 | #endif 40 | #endif // PDINPUT_H 41 | -------------------------------------------------------------------------------- /cassie/cassiemujoco/include/StateOutput.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2018 Agility Robotics 3 | * 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose with or without fee is hereby granted, provided that the above 6 | * copyright notice and this permission notice appear in all copies. 7 | * 8 | * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 9 | * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 10 | * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 11 | * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 12 | * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 13 | * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 14 | * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
15 | */ 16 | 17 | #ifndef STATEOUTPUT_H 18 | #define STATEOUTPUT_H 19 | 20 | #include "cassie_out_t.h" 21 | #include "state_out_t.h" 22 | 23 | typedef struct StateOutput StateOutput; 24 | 25 | #ifdef __cplusplus 26 | extern "C" { 27 | #endif 28 | 29 | StateOutput* StateOutput_alloc(void); 30 | void StateOutput_copy(StateOutput *dst, const StateOutput *src); 31 | void StateOutput_free(StateOutput *sys); 32 | void StateOutput_setup(StateOutput *sys); 33 | void StateOutput_step(StateOutput *sys, const cassie_out_t *in1, 34 | state_out_t *out1); 35 | 36 | #ifdef __cplusplus 37 | } 38 | #endif 39 | #endif // STATEOUTPUT_H 40 | -------------------------------------------------------------------------------- /cassie/cassiemujoco/include/cassie_in_t.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2018 Agility Robotics 3 | * 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose with or without fee is hereby granted, provided that the above 6 | * copyright notice and this permission notice appear in all copies. 7 | * 8 | * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 9 | * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 10 | * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 11 | * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 12 | * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 13 | * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 14 | * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
15 | */ 16 | 17 | #ifndef CASSIE_IN_T_H 18 | #define CASSIE_IN_T_H 19 | 20 | #define CASSIE_IN_T_PACKED_LEN 91 21 | 22 | #include 23 | 24 | typedef struct { 25 | unsigned short controlWord; 26 | double torque; 27 | } elmo_in_t; 28 | 29 | typedef struct { 30 | elmo_in_t hipRollDrive; 31 | elmo_in_t hipYawDrive; 32 | elmo_in_t hipPitchDrive; 33 | elmo_in_t kneeDrive; 34 | elmo_in_t footDrive; 35 | } cassie_leg_in_t; 36 | 37 | typedef struct { 38 | short channel[14]; 39 | } radio_in_t; 40 | 41 | typedef struct { 42 | radio_in_t radio; 43 | bool sto; 44 | bool piezoState; 45 | unsigned char piezoTone; 46 | } cassie_pelvis_in_t; 47 | 48 | typedef struct { 49 | cassie_pelvis_in_t pelvis; 50 | cassie_leg_in_t leftLeg; 51 | cassie_leg_in_t rightLeg; 52 | } cassie_in_t; 53 | 54 | 55 | #ifdef __cplusplus 56 | extern "C" { 57 | #endif 58 | 59 | void pack_cassie_in_t(const cassie_in_t *bus, unsigned char *bytes); 60 | void unpack_cassie_in_t(const unsigned char *bytes, cassie_in_t *bus); 61 | 62 | #ifdef __cplusplus 63 | } 64 | #endif 65 | #endif // CASSIE_IN_T_H 66 | -------------------------------------------------------------------------------- /cassie/cassiemujoco/include/cassie_out_t.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2018 Agility Robotics 3 | * 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose with or without fee is hereby granted, provided that the above 6 | * copyright notice and this permission notice appear in all copies. 7 | * 8 | * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 9 | * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 10 | * MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 11 | * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 12 | * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 13 | * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 14 | * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 15 | */ 16 | 17 | #ifndef CASSIE_OUT_T_H 18 | #define CASSIE_OUT_T_H 19 | 20 | #define CASSIE_OUT_T_PACKED_LEN 697 21 | 22 | #include 23 | 24 | typedef short DiagnosticCodes; 25 | 26 | 27 | typedef struct { 28 | bool dataGood; 29 | double stateOfCharge; 30 | double voltage[12]; 31 | double current; 32 | double temperature[4]; 33 | } battery_out_t; 34 | 35 | typedef struct { 36 | double position; 37 | double velocity; 38 | } cassie_joint_out_t; 39 | 40 | typedef struct { 41 | unsigned short statusWord; 42 | double position; 43 | double velocity; 44 | double torque; 45 | double driveTemperature; 46 | double dcLinkVoltage; 47 | double torqueLimit; 48 | double gearRatio; 49 | } elmo_out_t; 50 | 51 | typedef struct { 52 | elmo_out_t hipRollDrive; 53 | elmo_out_t hipYawDrive; 54 | elmo_out_t hipPitchDrive; 55 | elmo_out_t kneeDrive; 56 | elmo_out_t footDrive; 57 | cassie_joint_out_t shinJoint; 58 | cassie_joint_out_t tarsusJoint; 59 | cassie_joint_out_t footJoint; 60 | unsigned char medullaCounter; 61 | unsigned short medullaCpuLoad; 62 | bool reedSwitchState; 63 | } cassie_leg_out_t; 64 | 65 | typedef struct { 66 | bool radioReceiverSignalGood; 67 | bool receiverMedullaSignalGood; 68 | double channel[16]; 69 | } radio_out_t; 70 | 71 | typedef struct { 72 | int etherCatStatus[6]; 73 | int etherCatNotifications[21]; 74 | double taskExecutionTime; 75 | unsigned int overloadCounter; 76 | double cpuTemperature; 77 | } target_pc_out_t; 78 | 79 | typedef struct { 80 | bool dataGood; 81 | unsigned short vpeStatus; 82 | double pressure; 83 | double temperature; 84 | double magneticField[3]; 85 | double angularVelocity[3]; 86 | double 
linearAcceleration[3]; 87 | double orientation[4]; 88 | } vectornav_out_t; 89 | 90 | typedef struct { 91 | target_pc_out_t targetPc; 92 | battery_out_t battery; 93 | radio_out_t radio; 94 | vectornav_out_t vectorNav; 95 | unsigned char medullaCounter; 96 | unsigned short medullaCpuLoad; 97 | bool bleederState; 98 | bool leftReedSwitchState; 99 | bool rightReedSwitchState; 100 | double vtmTemperature; 101 | } cassie_pelvis_out_t; 102 | 103 | typedef struct { 104 | cassie_pelvis_out_t pelvis; 105 | cassie_leg_out_t leftLeg; 106 | cassie_leg_out_t rightLeg; 107 | bool isCalibrated; 108 | DiagnosticCodes messages[4]; 109 | } cassie_out_t; 110 | 111 | #define EMPTY ((DiagnosticCodes)0) 112 | #define LEFT_HIP_NOT_CALIB ((DiagnosticCodes)5) 113 | #define LEFT_KNEE_NOT_CALIB ((DiagnosticCodes)6) 114 | #define RIGHT_HIP_NOT_CALIB ((DiagnosticCodes)7) 115 | #define RIGHT_KNEE_NOT_CALIB ((DiagnosticCodes)8) 116 | #define LOW_BATTERY_CHARGE ((DiagnosticCodes)200) 117 | #define HIGH_CPU_TEMP ((DiagnosticCodes)205) 118 | #define HIGH_VTM_TEMP ((DiagnosticCodes)210) 119 | #define HIGH_ELMO_DRIVE_TEMP ((DiagnosticCodes)215) 120 | #define HIGH_STATOR_TEMP ((DiagnosticCodes)220) 121 | #define LOW_ELMO_LINK_VOLTAGE ((DiagnosticCodes)221) 122 | #define HIGH_BATTERY_TEMP ((DiagnosticCodes)225) 123 | #define RADIO_DATA_BAD ((DiagnosticCodes)230) 124 | #define RADIO_SIGNAL_BAD ((DiagnosticCodes)231) 125 | #define BMS_DATA_BAD ((DiagnosticCodes)235) 126 | #define VECTORNAV_DATA_BAD ((DiagnosticCodes)236) 127 | #define VPE_GYRO_SATURATION ((DiagnosticCodes)240) 128 | #define VPE_MAG_SATURATION ((DiagnosticCodes)241) 129 | #define VPE_ACC_SATURATION ((DiagnosticCodes)242) 130 | #define VPE_ATTITUDE_BAD ((DiagnosticCodes)245) 131 | #define VPE_ATTITUDE_NOT_TRACKING ((DiagnosticCodes)246) 132 | #define ETHERCAT_DC_ERROR ((DiagnosticCodes)400) 133 | #define ETHERCAT_ERROR ((DiagnosticCodes)410) 134 | #define LOAD_CALIB_DATA_ERROR ((DiagnosticCodes)590) 135 | #define CRITICAL_BATTERY_CHARGE 
((DiagnosticCodes)600) 136 | #define CRITICAL_CPU_TEMP ((DiagnosticCodes)605) 137 | #define CRITICAL_VTM_TEMP ((DiagnosticCodes)610) 138 | #define CRITICAL_ELMO_DRIVE_TEMP ((DiagnosticCodes)615) 139 | #define CRITICAL_STATOR_TEMP ((DiagnosticCodes)620) 140 | #define CRITICAL_BATTERY_TEMP ((DiagnosticCodes)625) 141 | #define TORQUE_LIMIT_REACHED ((DiagnosticCodes)630) 142 | #define JOINT_LIMIT_REACHED ((DiagnosticCodes)635) 143 | #define ENCODER_FAILURE ((DiagnosticCodes)640) 144 | #define SPRING_FAILURE ((DiagnosticCodes)645) 145 | #define LEFT_LEG_MEDULLA_HANG ((DiagnosticCodes)700) 146 | #define RIGHT_LEG_MEDULLA_HANG ((DiagnosticCodes)701) 147 | #define PELVIS_MEDULLA_HANG ((DiagnosticCodes)703) 148 | #define CPU_OVERLOAD ((DiagnosticCodes)704) 149 | 150 | #ifdef __cplusplus 151 | extern "C" { 152 | #endif 153 | 154 | void pack_cassie_out_t(const cassie_out_t *bus, unsigned char *bytes); 155 | void unpack_cassie_out_t(const unsigned char *bytes, cassie_out_t *bus); 156 | 157 | #ifdef __cplusplus 158 | } 159 | #endif 160 | #endif // CASSIE_OUT_T_H 161 | -------------------------------------------------------------------------------- /cassie/cassiemujoco/include/cassie_user_in_t.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2018 Agility Robotics 3 | * 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose with or without fee is hereby granted, provided that the above 6 | * copyright notice and this permission notice appear in all copies. 7 | * 8 | * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 9 | * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 10 | * MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 11 | * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 12 | * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 13 | * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 14 | * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 15 | */ 16 | 17 | #ifndef CASSIE_USER_IN_T_H 18 | #define CASSIE_USER_IN_T_H 19 | 20 | #define CASSIE_USER_IN_T_PACKED_LEN 58 21 | 22 | #include 23 | 24 | typedef struct { 25 | double torque[10]; 26 | short telemetry[9]; 27 | } cassie_user_in_t; 28 | 29 | 30 | #ifdef __cplusplus 31 | extern "C" { 32 | #endif 33 | 34 | void pack_cassie_user_in_t(const cassie_user_in_t *bus, unsigned char *bytes); 35 | void unpack_cassie_user_in_t(const unsigned char *bytes, cassie_user_in_t *bus); 36 | 37 | #ifdef __cplusplus 38 | } 39 | #endif 40 | #endif // CASSIE_USER_IN_T_H 41 | -------------------------------------------------------------------------------- /cassie/cassiemujoco/include/cassiemujoco.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2018 Dynamic Robotics Laboratory 3 | * 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose with or without fee is hereby granted, provided that the above 6 | * copyright notice and this permission notice appear in all copies. 7 | * 8 | * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 9 | * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 10 | * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 11 | * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 12 | * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 13 | * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 14 | * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
15 | */ 16 | 17 | #ifndef CASSIEMUJOCO_H 18 | #define CASSIEMUJOCO_H 19 | 20 | #include 21 | #include "cassie_out_t.h" 22 | #include "cassie_in_t.h" 23 | #include "cassie_user_in_t.h" 24 | #include "state_out_t.h" 25 | #include "pd_in_t.h" 26 | 27 | 28 | typedef struct cassie_sim cassie_sim_t; 29 | typedef struct cassie_vis cassie_vis_t; 30 | typedef struct cassie_state cassie_state_t; 31 | 32 | 33 | #ifdef __cplusplus 34 | extern "C" { 35 | #endif 36 | 37 | // Pass a null-terminated string containing the path to the directory 38 | // containing cassie.xml, mjpro150/, mjkey.txt, etc. If NULL is 39 | // passed, the directory containing the current executable is used 40 | // instead. Returns true if loading was successful, false otherwise. 41 | bool cassie_mujoco_init(const char *basedir); 42 | 43 | // Unloads the MuJoCo library and Cassie model. After calling this 44 | // function, cassie_mujoco_init can be called again. 45 | void cassie_cleanup(void); 46 | 47 | 48 | /******************************************************************************* 49 | * Cassie simulator functions 50 | ******************************************************************************/ 51 | 52 | // Creates an instance of the Cassie simulator. If called before 53 | // cassie_mujoco_init, cassie_mujoco_init is called with the parameter 54 | // NULL. 55 | cassie_sim_t *cassie_sim_init(void); 56 | 57 | // Creates an instance of the Cassie simulator with the same state as 58 | // an existing instance. 59 | cassie_sim_t *cassie_sim_duplicate(const cassie_sim_t *sim); 60 | 61 | // Copies the state of one Cassie simulator to another. 62 | void cassie_sim_copy(cassie_sim_t *dst, const cassie_sim_t *src); 63 | 64 | // Destroys an instance of the Cassie simulator. 65 | void cassie_sim_free(cassie_sim_t *sim); 66 | 67 | // Simulates one step of the Cassie simulator at the lowest level of 68 | // input and output. 
Only one cassie_sim_step_* function should be 69 | // called on a given Cassie simulator instance. 70 | void cassie_sim_step_ethercat(cassie_sim_t *sim, cassie_out_t *y, const cassie_in_t *u); 71 | 72 | // Simulates one step of the Cassie simulator including software 73 | // safeties. Only one cassie_sim_step_* function should be called on a 74 | // given Cassie simulator instance. 75 | void cassie_sim_step(cassie_sim_t *sim, cassie_out_t *y, const cassie_user_in_t *u); 76 | 77 | // Simulates one step of the Cassie simulator with PD input and state 78 | // estimator output. Only one cassie_sim_step_* function should be 79 | // called on a given Cassie simulator instance. 80 | void cassie_sim_step_pd(cassie_sim_t *sim, state_out_t *y, const pd_in_t *u); 81 | 82 | // Returns a read-write pointer to the simulator time. 83 | double *cassie_sim_time(cassie_sim_t *sim); 84 | 85 | // Returns a read-write pointer to the simulator joint positions. 86 | // The order of the values are as follows: 87 | // [ 0] Pelvis x 88 | // [ 1] Pelvis y 89 | // [ 2] Pelvis z 90 | // [ 3] Pelvis orientation qw 91 | // [ 4] Pelvis orientation qx 92 | // [ 5] Pelvis orientation qy 93 | // [ 6] Pelvis orientation qz 94 | // [ 7] Left hip roll (Motor [0]) 95 | // [ 8] Left hip yaw (Motor [1]) 96 | // [ 9] Left hip pitch (Motor [2]) 97 | // [10] Left achilles rod qw 98 | // [11] Left achilles rod qx 99 | // [12] Left achilles rod qy 100 | // [13] Left achilles rod qz 101 | // [14] Left knee (Motor [3]) 102 | // [15] Left shin (Joint [0]) 103 | // [16] Left tarsus (Joint [1]) 104 | // [17] Left heel spring 105 | // [18] Left foot crank 106 | // [19] Left plantar rod 107 | // [20] Left foot (Motor [4], Joint [2]) 108 | // [21] Right hip roll (Motor [5]) 109 | // [22] Right hip yaw (Motor [6]) 110 | // [23] Right hip pitch (Motor [7]) 111 | // [24] Right achilles rod qw 112 | // [25] Right achilles rod qx 113 | // [26] Right achilles rod qy 114 | // [27] Right achilles rod qz 115 | // [28] Right 
knee (Motor [8]) 116 | // [29] Right shin (Joint [3]) 117 | // [30] Right tarsus (Joint [4]) 118 | // [31] Right heel spring 119 | // [32] Right foot crank 120 | // [33] Right plantar rod 121 | // [34] Right foot (Motor [9], Joint [5]) 122 | double *cassie_sim_qpos(cassie_sim_t *sim); 123 | 124 | // Returns a read-write pointer to the simulator joint velocities. 125 | // The order of the values are as follows: 126 | // [ 0] Pelvis x 127 | // [ 1] Pelvis y 128 | // [ 2] Pelvis z 129 | // [ 3] Pelvis orientation wx 130 | // [ 4] Pelvis orientation wy 131 | // [ 5] Pelvis orientation wz 132 | // [ 6] Left hip roll (Motor [0]) 133 | // [ 7] Left hip yaw (Motor [1]) 134 | // [ 8] Left hip pitch (Motor [2]) 135 | // [ 9] Left achilles rod wx 136 | // [10] Left achilles rod wy 137 | // [11] Left achilles rod wz 138 | // [12] Left knee (Motor [3]) 139 | // [13] Left shin (Joint [0]) 140 | // [14] Left tarsus (Joint [1]) 141 | // [15] Left heel spring 142 | // [16] Left foot crank 143 | // [17] Left plantar rod 144 | // [18] Left foot (Motor [4], Joint [2]) 145 | // [19] Right hip roll (Motor [5]) 146 | // [20] Right hip yaw (Motor [6]) 147 | // [21] Right hip pitch (Motor [7]) 148 | // [22] Right achilles rod wx 149 | // [23] Right achilles rod wy 150 | // [24] Right achilles rod wz 151 | // [25] Right knee (Motor [8]) 152 | // [26] Right shin (Joint [3]) 153 | // [27] Right tarsus (Joint [4]) 154 | // [28] Right heel spring 155 | // [29] Right foot crank 156 | // [30] Right plantar rod 157 | // [31] Right foot (Motor [9], Joint [5]) 158 | double *cassie_sim_qvel(cassie_sim_t *sim); 159 | 160 | // Returns the mjModel* used by the simulator 161 | void *cassie_sim_mjmodel(cassie_sim_t *sim); 162 | 163 | // Returns the mjData* used by the simulator 164 | void *cassie_sim_mjdata(cassie_sim_t *sim); 165 | 166 | // Returns true if any of the collision bodies in Cassie are in 167 | // contact with an object with the obstacle class. 
168 | bool cassie_sim_check_obstacle_collision(const cassie_sim_t *sim); 169 | 170 | // Returns true if any of the collision bodies in Cassie are in 171 | // contact with each other (i.e. right and left leg collide). 172 | bool cassie_sim_check_self_collision(const cassie_sim_t *sim); 173 | 174 | // Returns the contact forces on the left and right feet 175 | // cfrc[0-2]: Contact force acting on the left foot, in world coordinates 176 | // cfrc[3-5]: Currently zero, reserved for torque acting on the left foot 177 | // cfrc[6-8]: Contact force acting on the right foot, in world coordinates 178 | // cfrc[9-11]: Currently zero, reserved for torque acting on the right foot 179 | void cassie_sim_foot_forces(const cassie_sim_t *c, double cfrc[12]); 180 | 181 | // Applies an external force to a specified body. 182 | void cassie_sim_apply_force(cassie_sim_t *sim, double xfrc[6], int body); 183 | 184 | // Sets all external forces to zero. 185 | void cassie_sim_clear_forces(cassie_sim_t *sim); 186 | 187 | // Holds the pelvis stationary in the current position. 188 | void cassie_sim_hold(cassie_sim_t *sim); 189 | 190 | // Releases a held pelvis. 191 | void cassie_sim_release(cassie_sim_t *sim); 192 | 193 | // Sets the values reported by the radio receiver in Cassie, which 194 | // should be doubles in the range [-1, 1]. Channel 8 must be set to 1 195 | // to enable the motors, which is the default state. 196 | void cassie_sim_radio(cassie_sim_t *sim, double channels[16]); 197 | 198 | 199 | /******************************************************************************* 200 | * Cassie visualizer functions 201 | ******************************************************************************/ 202 | 203 | // Creates an instance of the Cassie simulation visualizer. If called 204 | // before cassie_mujoco_init, cassie_mujoco_init is called with the 205 | // parameter NULL.
206 | cassie_vis_t *cassie_vis_init(void); 207 | 208 | // Closes the visualization window without freeing the instance. After 209 | // calling this, cassie_vis_draw can still be called, but the 210 | // visualizer will remain closed. 211 | void cassie_vis_close(cassie_vis_t *vis); 212 | 213 | // Closes and frees the visualization window. 214 | void cassie_vis_free(cassie_vis_t *vis); 215 | 216 | // Visualizes the state of the given Cassie simulator. 217 | bool cassie_vis_draw(cassie_vis_t *vis, cassie_sim_t *sim); 218 | 219 | // Returns true if the visualizer has been closed but not freed. 220 | bool cassie_vis_valid(cassie_vis_t *vis); 221 | 222 | 223 | /******************************************************************************* 224 | * Cassie simulation state functions 225 | ******************************************************************************/ 226 | 227 | // Allocates storage for a Cassie simulation state object. This allows 228 | // the state of a simulator to be recorded and restored without 229 | // duplicating the entire simulator. A simulation state can only be 230 | // restored to the exact simulator instance it was recorded from. 231 | cassie_state_t *cassie_state_alloc(void); 232 | 233 | // Creates an instance of a simulation state object with the same 234 | // state as an existing instance. 235 | cassie_state_t *cassie_state_duplicate(const cassie_state_t *src); 236 | 237 | // Copies the state of one simulation state object into another. 238 | void cassie_state_copy(cassie_state_t *dst, const cassie_state_t *src); 239 | 240 | // Destroys a Cassie simulation state object 241 | void cassie_state_free(cassie_state_t *state); 242 | 243 | // Returns a read/write pointer to the simulation state time. 244 | double *cassie_state_time(cassie_state_t *state); 245 | 246 | // Returns a read/write pointer to the simulation state joint positions. 
247 | double *cassie_state_qpos(cassie_state_t *state); 248 | 249 | // Returns a read/write pointer to the simulation state joint velocities. 250 | double *cassie_state_qvel(cassie_state_t *state); 251 | 252 | // Copies the state of a Cassie simulator into a simulation state object. 253 | void cassie_get_state(const cassie_sim_t *sim, cassie_state_t *state); 254 | 255 | // Copies the state of a simulation state object into a Cassie simulator. 256 | void cassie_set_state(cassie_sim_t *sim, const cassie_state_t *state); 257 | 258 | 259 | #ifdef __cplusplus 260 | } 261 | #endif 262 | 263 | #endif // CASSIEMUJOCO_H 264 | -------------------------------------------------------------------------------- /cassie/cassiemujoco/include/pd_in_t.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2018 Agility Robotics 3 | * 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose with or without fee is hereby granted, provided that the above 6 | * copyright notice and this permission notice appear in all copies. 7 | * 8 | * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 9 | * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 10 | * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 11 | * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 12 | * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 13 | * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 14 | * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
15 | */ 16 | 17 | #ifndef PD_IN_T_H 18 | #define PD_IN_T_H 19 | 20 | #define PD_IN_T_PACKED_LEN 476 21 | 22 | #include 23 | 24 | typedef struct { 25 | double torque[5]; 26 | double pTarget[5]; 27 | double dTarget[5]; 28 | double pGain[5]; 29 | double dGain[5]; 30 | } pd_motor_in_t; 31 | 32 | typedef struct { 33 | double torque[6]; 34 | double pTarget[6]; 35 | double dTarget[6]; 36 | double pGain[6]; 37 | double dGain[6]; 38 | } pd_task_in_t; 39 | 40 | typedef struct { 41 | pd_task_in_t taskPd; 42 | pd_motor_in_t motorPd; 43 | } pd_leg_in_t; 44 | 45 | typedef struct { 46 | pd_leg_in_t leftLeg; 47 | pd_leg_in_t rightLeg; 48 | double telemetry[9]; 49 | } pd_in_t; 50 | 51 | 52 | #ifdef __cplusplus 53 | extern "C" { 54 | #endif 55 | 56 | void pack_pd_in_t(const pd_in_t *bus, unsigned char *bytes); 57 | void unpack_pd_in_t(const unsigned char *bytes, pd_in_t *bus); 58 | 59 | #ifdef __cplusplus 60 | } 61 | #endif 62 | #endif // PD_IN_T_H 63 | -------------------------------------------------------------------------------- /cassie/cassiemujoco/include/state_out_t.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2018 Agility Robotics 3 | * 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose with or without fee is hereby granted, provided that the above 6 | * copyright notice and this permission notice appear in all copies. 7 | * 8 | * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 9 | * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 10 | * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 11 | * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 12 | * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 13 | * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 14 | * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
15 | */ 16 | 17 | #ifndef STATE_OUT_T_H 18 | #define STATE_OUT_T_H 19 | 20 | #define STATE_OUT_T_PACKED_LEN 493 21 | 22 | #include 23 | 24 | typedef struct { 25 | double stateOfCharge; 26 | double current; 27 | } state_battery_out_t; 28 | 29 | typedef struct { 30 | double position[3]; 31 | double orientation[4]; 32 | double footRotationalVelocity[3]; 33 | double footTranslationalVelocity[3]; 34 | double toeForce[3]; 35 | double heelForce[3]; 36 | } state_foot_out_t; 37 | 38 | typedef struct { 39 | double position[6]; 40 | double velocity[6]; 41 | } state_joint_out_t; 42 | 43 | typedef struct { 44 | double position[10]; 45 | double velocity[10]; 46 | double torque[10]; 47 | } state_motor_out_t; 48 | 49 | typedef struct { 50 | double position[3]; 51 | double orientation[4]; 52 | double rotationalVelocity[3]; 53 | double translationalVelocity[3]; 54 | double translationalAcceleration[3]; 55 | double externalMoment[3]; 56 | double externalForce[3]; 57 | } state_pelvis_out_t; 58 | 59 | typedef struct { 60 | double channel[16]; 61 | bool signalGood; 62 | } state_radio_out_t; 63 | 64 | typedef struct { 65 | double height; 66 | double slope[2]; 67 | } state_terrain_out_t; 68 | 69 | typedef struct { 70 | state_pelvis_out_t pelvis; 71 | state_foot_out_t leftFoot; 72 | state_foot_out_t rightFoot; 73 | state_terrain_out_t terrain; 74 | state_motor_out_t motor; 75 | state_joint_out_t joint; 76 | state_radio_out_t radio; 77 | state_battery_out_t battery; 78 | } state_out_t; 79 | 80 | 81 | #ifdef __cplusplus 82 | extern "C" { 83 | #endif 84 | 85 | void pack_state_out_t(const state_out_t *bus, unsigned char *bytes); 86 | void unpack_state_out_t(const unsigned char *bytes, state_out_t *bus); 87 | 88 | #ifdef __cplusplus 89 | } 90 | #endif 91 | #endif // STATE_OUT_T_H 92 | -------------------------------------------------------------------------------- /cassie/cassiemujoco/include/udp.h: -------------------------------------------------------------------------------- 1 | /* 2 | * 
Copyright (c) 2018 Agility Robotics 3 | * 4 | * Permission to use, copy, modify, and distribute this software for any 5 | * purpose with or without fee is hereby granted, provided that the above 6 | * copyright notice and this permission notice appear in all copies. 7 | * 8 | * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 9 | * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 10 | * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 11 | * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 12 | * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 13 | * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 14 | * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 15 | */ 16 | 17 | #ifndef UDP_H 18 | #define UDP_H 19 | 20 | #define PACKET_HEADER_LEN 2 21 | 22 | // Data and results for processing packet header 23 | typedef struct { 24 | char seq_num_out; 25 | char seq_num_in_last; 26 | char delay; 27 | char seq_num_in_diff; 28 | } packet_header_info_t; 29 | 30 | 31 | // Process packet header used to measure delay and skipped packets 32 | void process_packet_header(packet_header_info_t *info, 33 | const unsigned char *header_in, 34 | unsigned char *header_out); 35 | 36 | #ifndef _WIN32 37 | #include 38 | 39 | // Create a UDP socket listening at a specific address/port 40 | int udp_init_host(const char *addr_str, const char *port_str); 41 | 42 | // Create a UDP socket connected and listening to specific addresses/ports 43 | int udp_init_client(const char *remote_addr_str, const char *remote_port_str, 44 | const char *local_addr_str, const char *local_port_str); 45 | 46 | // Close a UDP socket 47 | void udp_close(int sock); 48 | 49 | // Get newest valid packet in RX buffer 50 | ssize_t get_newest_packet(int sock, void *recvbuf, size_t recvlen, 51 | struct sockaddr *src_addr, socklen_t *addrlen); 52 | 53 | // Wait for a new valid packet 54 
| ssize_t wait_for_packet(int sock, void *recvbuf, size_t recvlen, 55 | struct sockaddr *src_addr, socklen_t *addrlen); 56 | 57 | // Send a packet 58 | ssize_t send_packet(int sock, void *sendbuf, size_t sendlen, 59 | struct sockaddr *dst_addr, socklen_t addrlen); 60 | 61 | #endif // _WIN32 62 | #endif // UDP_H 63 | -------------------------------------------------------------------------------- /cassie/cassiemujoco/libcassiemujoco.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/osudrl/RSS-2020-learning-memory-based-control/99a133c7d53c7caa8187b1c7058dfc5cd9a81507/cassie/cassiemujoco/libcassiemujoco.so -------------------------------------------------------------------------------- /cassie/trajectory/__init__.py: -------------------------------------------------------------------------------- 1 | from .trajectory import * -------------------------------------------------------------------------------- /cassie/trajectory/stepdata.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/osudrl/RSS-2020-learning-memory-based-control/99a133c7d53c7caa8187b1c7058dfc5cd9a81507/cassie/trajectory/stepdata.bin -------------------------------------------------------------------------------- /cassie/trajectory/trajectory.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | 4 | 5 | class CassieTrajectory: 6 | def __init__(self, filepath): 7 | n = 1 + 35 + 32 + 10 + 10 + 10 8 | data = np.fromfile(filepath, dtype=np.double).reshape((-1, n)) 9 | 10 | # states 11 | self.time = data[:, 0] 12 | self.qpos = data[:, 1:36] 13 | self.qvel = data[:, 36:68] 14 | 15 | # actions 16 | self.torque = data[:, 68:78] 17 | self.mpos = data[:, 78:88] 18 | self.mvel = data[:, 88:98] 19 | 20 | def __len__(self): 21 | return len(self.time) 22 | 
# --------------------------------------------------------------------------------
# /cassie/udp.py
# --------------------------------------------------------------------------------
import time
import torch
import pickle
import platform

import sys
import datetime

import select, termios, tty

from cassie.cassiemujoco.cassieUDP import *
from cassie.cassiemujoco.cassiemujoco_ctypes import *

import math
import numpy as np


def inverse_quaternion(quaternion):
    """Return the conjugate of a quaternion stored as (w, x, y, z).

    For unit quaternions the conjugate is also the inverse. The input is
    copied and never modified.
    """
    result = np.copy(quaternion)
    result[1:4] = -result[1:4]
    return result


def quaternion_product(q1, q2):
    """Return the Hamilton product q1 * q2 of two (w, x, y, z) quaternions."""
    result = np.zeros(4)
    result[0] = q1[0]*q2[0]-q1[1]*q2[1]-q1[2]*q2[2]-q1[3]*q2[3]
    result[1] = q1[0]*q2[1]+q2[0]*q1[1]+q1[2]*q2[3]-q1[3]*q2[2]
    result[2] = q1[0]*q2[2]-q1[1]*q2[3]+q1[2]*q2[0]+q1[3]*q2[1]
    result[3] = q1[0]*q2[3]+q1[1]*q2[2]-q1[2]*q2[1]+q1[3]*q2[0]
    return result


def rotate_by_quaternion(vector, quaternion):
    """Rotate a 3-vector by a (w, x, y, z) quaternion, computing q * (0, v) * q^-1."""
    pure = np.zeros(4)
    pure[1:4] = vector
    rotated = quaternion_product(pure, inverse_quaternion(quaternion))
    rotated = quaternion_product(np.copy(quaternion), rotated)
    return rotated[1:4]


def quaternion2euler(quaternion):
    """Convert a (w, x, y, z) quaternion into [roll, pitch, yaw] in radians.

    The angles are computed directly in radians. The previous version went
    through math.degrees and back, which was a no-op apart from rounding.
    """
    w, x, y, z = quaternion[0], quaternion[1], quaternion[2], quaternion[3]
    ysqr = y * y

    t0 = +2.0 * (w * x + y * z)
    t1 = +1.0 - 2.0 * (x * x + ysqr)
    roll = math.atan2(t0, t1)

    # Clamp so that rounding error cannot push asin outside its [-1, 1] domain.
    t2 = +2.0 * (w * y - z * x)
    t2 = +1.0 if t2 > +1.0 else t2
    t2 = -1.0 if t2 < -1.0 else t2
    pitch = math.asin(t2)

    t3 = +2.0 * (w * z + x * y)
    t4 = +1.0 - 2.0 * (ysqr + z * z)
    yaw = math.atan2(t3, t4)

    result = np.zeros(3)
    result[0] = roll
    result[1] = pitch
    result[2] = yaw
    return result


def euler2quat(z=0, y=0, x=0):
    """Build a (w, x, y, z) quaternion from Euler angles z, y, x (radians).

    The sign is chosen so that the scalar part w is non-negative.
    """
    z = z/2.0
    y = y/2.0
    x = x/2.0
    cz = math.cos(z)
    sz = math.sin(z)
    cy = math.cos(y)
    sy = math.sin(y)
    cx = math.cos(x)
    sx = math.sin(x)
    result = np.array([
        cx*cy*cz - sx*sy*sz,
        cx*sy*sz + cy*cz*sx,
        cx*cz*sy - sx*cy*sz,
        cx*cy*sz + sx*cz*sy])
    if result[0] < 0:
        result = -result
    return result


def check_stdin():
    """Return True if a character can be read from stdin without blocking."""
    return select.select([sys.stdin], [], [], 0) == ([sys.stdin], [], [])


def _resolve_policy_path(args):
    """Extract the policy file path from whatever the caller passed to run_udp.

    Accepts an argparse-style namespace with a ``policy`` attribute, a plain
    path string, or a list of argv strings. For a list, the first entry is
    used; this is how main.py calls run_udp.

    Raises:
        ValueError: if an empty list is given.
    """
    if isinstance(args, str):
        return args
    if isinstance(args, (list, tuple)):
        if not args:
            raise ValueError("No policy path was given to run_udp.")
        return args[0]
    return args.policy


def _set_pd_gains(u, p_gains, d_gains):
    """Write the first five P and D gains to both legs of the pd_in_t ``u``."""
    for i in range(5):
        u.leftLeg.motorPd.pGain[i] = p_gains[i]
        u.leftLeg.motorPd.dGain[i] = d_gains[i]
        u.rightLeg.motorPd.pGain[i] = p_gains[i]
        u.rightLeg.motorPd.dGain[i] = d_gains[i]


def _send_zero_targets(cassie, u):
    """Zero every motor position target in ``u`` and send it over UDP."""
    for i in range(5):
        u.leftLeg.motorPd.pTarget[i] = 0.0
        u.rightLeg.motorPd.pTarget[i] = 0.0
    cassie.send_pd(u)


def _policy_observation(state, orient_add, phase, speed):
    """Build the (un-normalized) policy input vector from a state packet.

    The pelvis orientation and translational velocity are re-expressed
    relative to the commanded heading ``orient_add``, which is the yaw
    passed to euler2quat(z=...). The two clock features and the speed
    command are appended at the end.
    """
    # NOTE(review): phase runs 0..27 while the clock period is 27, so phases 0 and 27
    # map to the same clock value. Presumably this matches the training env; confirm.
    clock = [np.sin(2 * np.pi * phase / 27), np.cos(2 * np.pi * phase / 27)]
    iheading = inverse_quaternion(euler2quat(z=orient_add, y=0, x=0))

    new_orient = quaternion_product(iheading, state.pelvis.orientation[:])
    if new_orient[0] < 0:
        new_orient = -new_orient
    new_translational_velocity = rotate_by_quaternion(state.pelvis.translationalVelocity[:], iheading)

    ext_state = np.concatenate((clock, [speed]))
    robot_state = np.concatenate([
        [state.pelvis.position[2] - state.terrain.height],  # pelvis height
        new_orient,                                          # pelvis orientation
        state.motor.position[:],                             # actuated joint positions

        new_translational_velocity[:],                       # pelvis translational velocity
        state.pelvis.rotationalVelocity[:],                  # pelvis rotational velocity
        state.motor.velocity[:],                             # actuated joint velocities

        state.pelvis.translationalAcceleration[:],           # pelvis translational acceleration

        state.joint.position[:],                             # unactuated joint positions
        state.joint.velocity[:]                              # unactuated joint velocities
    ])
    return np.concatenate([robot_state, ext_state])


def run_udp(args):
    """Run a trained policy on Cassie (or a local simulator) over UDP.

    Args:
        args: a namespace with a ``policy`` attribute, a policy path string,
            or a list of argv strings whose first entry is the policy path.

    Raises:
        RuntimeError: if the policy's environment was not trained with
            state estimation.

    On the robot (hostname 'cassie'), the radio selects one of three operation
    modes and supplies the speed/heading commands. Locally, the keys w/s/a/d/r
    adjust speed and heading. The loop runs until interrupted. The terminal
    settings changed for single-key input are always restored.
    """
    # util is a single module (util.py), not a package; `util.env` does not exist.
    from util import env_factory

    policy = torch.load(_resolve_policy_path(args))

    env = env_factory(policy.env_name)()
    if not env.state_est:
        print("This policy was not trained with state estimation and cannot be run on the robot.")
        raise RuntimeError("policy was not trained with state estimation")

    print("This policy is: {}".format(policy.__class__.__name__))
    time.sleep(1)

    time_log = []    # time stamp
    input_log = []   # network inputs
    output_log = []  # network outputs
    state_log = []   # cassie state
    target_log = []  # PD target log

    no_delta = env.no_delta

    u = pd_in_t()
    _set_pd_gains(u, env.P, env.D)

    if platform.node() == 'cassie':
        cassie = CassieUdp(remote_addr='10.10.10.3', remote_port='25010',
                           local_addr='10.10.10.100', local_port='25011')
    else:
        cassie = CassieUdp()  # local testing

    print('Connecting...')
    y = None
    while y is None:
        cassie.send_pd(pd_in_t())
        time.sleep(0.001)
        y = cassie.recv_newest_pd()

    print('Connected!\n')

    # Whether or not STO has been TOGGLED (i.e. it does not count the initial STO condition).
    # STO = True means that STO is ON (i.e. robot is not running) and STO = False means that STO is
    # OFF (i.e. robot *is* running).
    sto = True
    sto_count = 0

    orient_add = 0

    # We have multiple modes of operation
    # 0: Normal operation, walking with policy
    # 1: Start up, Standing Pose with variable height (no balance)
    # 2: Stop Drop and hopefully not roll, Damping Mode with no P gain
    operation_mode = 0
    standing_height = 0.7
    MAX_HEIGHT = 0.8
    MIN_HEIGHT = 0.4
    D_mult = 1  # Reaaaaaally bad stability problems if this is pushed higher as a multiplier
                # Might be worth tuning by joint but something else if probably needed
    phase = 0
    counter = 0
    phase_add = 1
    speed = 0

    max_speed = 2
    min_speed = -1

    cycle_time = 60 / 2000  # seconds per control step (30 ms)

    # Start from a clean recurrent state rather than whatever was pickled with the policy.
    if hasattr(policy, 'init_hidden_state'):
        policy.init_hidden_state()

    old_settings = termios.tcgetattr(sys.stdin)

    try:
        tty.setcbreak(sys.stdin.fileno())

        while True:
            t = time.monotonic()

            # Get newest state
            state = cassie.recv_newest_pd()

            if state is None:
                print('Missed a cycle!')
                continue

            if platform.node() == 'cassie':

                # Radio control
                orient_add -= state.radio.channel[3] / 60.0

                # Reset orientation on STO
                if state.radio.channel[8] < 0:
                    orient_add = quaternion2euler(state.pelvis.orientation[:])[2]

                    # Save log files after STO toggle (skipping first STO)
                    if sto is False:
                        #log(sto_count)
                        sto_count += 1
                        sto = True
                        # Clear out logs
                        time_log = []    # time stamp
                        input_log = []   # network inputs
                        output_log = []  # network outputs
                        state_log = []   # cassie state
                        target_log = []  # PD target log

                        if hasattr(policy, 'init_hidden_state'):
                            print("RESETTING HIDDEN STATES TO ZERO!")
                            policy.init_hidden_state()

                else:
                    sto = False

                if state.radio.channel[15] < 0 and hasattr(policy, 'init_hidden_state'):
                    print("(TOGGLE SWITCH) RESETTING HIDDEN STATES TO ZERO!")
                    policy.init_hidden_state()

                # Switch the operation mode based on the toggle next to STO
                if state.radio.channel[9] < -0.5:  # towards operator means damping shutdown mode
                    operation_mode = 2
                elif state.radio.channel[9] > 0.5:  # away from the operator means reset states
                    operation_mode = 1
                    standing_height = MIN_HEIGHT + (MAX_HEIGHT - MIN_HEIGHT)*0.5*(state.radio.channel[6] + 1)
                else:  # Middle means normal walking
                    operation_mode = 0

                # Raw speed command; clamped to [min_speed, max_speed] below. (Previously the
                # upper clamp overwrote the lower one, so speeds below min_speed slipped through.)
                curr_max = max_speed / 2
                speed_add = (max_speed / 2) * state.radio.channel[4]
                speed = state.radio.channel[0] * curr_max + speed_add

                print('\tCH5: ' + str(state.radio.channel[5]))
                phase_add = 1  # + state.radio.channel[5]
            else:
                # Keyboard control for local testing
                if check_stdin():
                    c = sys.stdin.read(1)
                    if c == 'w':
                        speed += 0.1
                    elif c == 's':
                        speed -= 0.1
                    elif c == 'a':
                        orient_add -= 0.1
                    elif c == 'd':
                        orient_add += 0.1
                    elif c == 'r':
                        speed = 0.5
                        orient_add = 0

            speed = max(min_speed, speed)
            speed = min(max_speed, speed)

            #------------------------------- Normal Walking ---------------------------
            if operation_mode == 0:
                print("\tspeed: {:3.2f} | orientation {:3.2f}".format(speed, orient_add))

                # Reassign because it might have been changed by the damping mode
                _set_pd_gains(u, env.P, env.D)

                RL_state = _policy_observation(state, orient_add, phase, speed)

                if no_delta:
                    offset = env.offset
                else:
                    offset = env.get_ref_state(phase=phase)

                # The policy normalizes its own input (Stochastic_Actor._get_dist_params),
                # so the raw state is passed in. no_grad keeps a recurrent policy from growing
                # an autograd graph across every control step.
                with torch.no_grad():
                    action = policy(torch.Tensor(RL_state))
                env_action = action.numpy()
                target = env_action + offset

                # Send action
                for i in range(5):
                    u.leftLeg.motorPd.pTarget[i] = target[i]
                    u.rightLeg.motorPd.pTarget[i] = target[i+5]
                cassie.send_pd(u)

                # Logging
                if not sto:
                    time_log.append(time.time())
                    state_log.append(state)
                    input_log.append(RL_state)
                    output_log.append(env_action)
                    target_log.append(target)

            #------------------------------- Start Up Standing ---------------------------
            elif operation_mode == 1:
                print('Startup Standing. Height = ' + str(standing_height))
                # Do nothing: all gains and targets zeroed.
                _set_pd_gains(u, [0.0] * 5, [0.0] * 5)
                _send_zero_targets(cassie, u)

            #------------------------------- Shutdown Damping ---------------------------
            elif operation_mode == 2:
                print('Shutdown Damping. Multiplier = ' + str(D_mult))
                # Reassign with new multiplier on damping, no position gain
                _set_pd_gains(u, [0.0] * 5, [D_mult * env.D[i] for i in range(5)])
                _send_zero_targets(cassie, u)

            #---------------------------- Other, should not happen -----------------------
            else:
                print('Error, In bad operation_mode with value: ' + str(operation_mode))

            # Wait until next cycle time, then report the measured cycle duration
            while time.monotonic() - t < cycle_time:
                time.sleep(0.001)
            print('\tdelay: {:6.1f} ms'.format((time.monotonic() - t) * 1000))

            # Track phase
            phase += phase_add
            if phase >= 28:
                phase = 0
                counter += 1
    finally:
        termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)

# --------------------------------------------------------------------------------
# /main.py
# --------------------------------------------------------------------------------
if __name__ == "__main__":
    import sys, argparse, time, os
    parser = argparse.ArgumentParser()
    print("RSS 2020: Learning Memory-Based Control for Human-Scale Bipedal Locomotion")
    print("\tJonah Siekmann, Srikar Valluri, Jeremy Dao, Lorenzo Bermillo, Helei Duan, Alan Fern, Jonathan Hurst")

    if len(sys.argv) < 2:
        print("Usage: python main.py [option]", sys.argv)
        print("\t potential options are: 'ppo', 'extract', 'eval', 'cassie', 'pca'")
        sys.exit(1)

    # Pop the sub-command so argparse only sees the remaining flags.
    # (del by index; list.remove(value) would drop the first *matching* entry.)
    option = sys.argv[1]
    del sys.argv[1]

    if option == 'eval':
        from util import eval_policy
        import torch

        if len(sys.argv) < 2:
            print("Usage: python main.py eval POLICY_PATH [--traj_len N]")
            sys.exit(1)
        model = sys.argv[1]
        del sys.argv[1]

        parser.add_argument("--traj_len", default=300, type=int)
        args = parser.parse_args()

        model = torch.load(model)

        eval_policy(model, max_traj_len=args.traj_len, visualize=True, verbose=True)
        sys.exit()

    if option == 'cassie':
        from cassie.udp import run_udp

        # run_udp uses the first path in this list as the policy file.
        policies = sys.argv[1:]

        run_udp(policies)
        sys.exit()

    if option == 'extract':
        from algos.extract import run_experiment

        parser.add_argument("--policy", "-p", default=None, type=str)
        parser.add_argument("--layers", default="256,256", type=str)
        parser.add_argument("--logdir", default='logs/extract', type=str)

        parser.add_argument("--workers", default=4, type=int)
        parser.add_argument("--points", default=5000, type=int)
        parser.add_argument("--batch_size", default=16, type=int)
        parser.add_argument("--epochs", default=500, type=int)

        parser.add_argument("--lr", default=1e-5, type=float)
        args = parser.parse_args()
        if args.policy is None:
            print("Please provide a --policy argument.")
            sys.exit(1)
        run_experiment(args)
        sys.exit()

    # Utility for running Proximal Policy Optimization.
    elif option == 'ppo':
        from algos.ppo import run_experiment
        parser.add_argument("--timesteps", default=1e6, type=float)          # timesteps to run experiment for
        parser.add_argument('--discount', default=0.99, type=float)          # the discount factor
        parser.add_argument('--std', default=0.13, type=float)               # the fixed exploration std
        parser.add_argument("--a_lr", default=1e-4, type=float)              # adam learning rate for actor
        parser.add_argument("--c_lr", default=1e-4, type=float)              # adam learning rate for critic
        parser.add_argument("--eps", default=1e-6, type=float)               # adam eps
        parser.add_argument("--kl", default=0.02, type=float)                # kl abort threshold
        parser.add_argument("--grad_clip", default=0.05, type=float)         # gradient norm clip

        parser.add_argument("--batch_size", default=64, type=int)            # batch size for policy update
        parser.add_argument("--epochs", default=3, type=int)                 # number of updates per iter
        parser.add_argument("--workers", default=4, type=int)                # how many workers to use for exploring in parallel
        parser.add_argument("--seed", default=0, type=int)                   # random seed for reproducibility
        parser.add_argument("--traj_len", default=1000, type=int)            # max trajectory length for environment
        parser.add_argument("--prenormalize_steps", default=10000, type=int) # number of samples to get normalization stats
        parser.add_argument("--sample", default=5000, type=int)              # how many samples to do every iteration

        parser.add_argument("--layers", default="128,128", type=str)         # hidden layer sizes in policy
        parser.add_argument("--save_actor", default=None, type=str)          # where to save the actor (default=logdir)
        parser.add_argument("--save_critic", default=None, type=str)         # where to save the critic (default=logdir)
        parser.add_argument("--logdir", default="./logs/ppo/", type=str)     # where to store log information
        parser.add_argument("--nolog", action='store_true')                  # store log data or not.
        parser.add_argument("--recurrent", action='store_true')              # recurrent policy or not
        parser.add_argument("--randomize", action='store_true')              # randomize dynamics or not
        args = parser.parse_args()

        run_experiment(args)

    elif option == 'pca':
        from algos.pca import run_pca
        import torch

        if len(sys.argv) < 2:
            print("Usage: python main.py pca POLICY_PATH")
            sys.exit(1)
        model = sys.argv[1]
        del sys.argv[1]

        model = torch.load(model)

        run_pca(model)
        sys.exit()

    else:
        print("Invalid option '{}'".format(option))

# --------------------------------------------------------------------------------
# /nn/actor.py
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

from nn.base import FF_Base, LSTM_Base


class Stochastic_Actor:
    """
    The base class for stochastic actors: a linear Gaussian-mean head with a
    fixed standard deviation. Subclasses must also inherit from a Net base
    that provides ``normalize_state`` and ``_base_forward``.
    """
    def __init__(self, latent, action_dim, dynamics_randomization, fixed_std):

        self.action_dim = action_dim
        self.dynamics_randomization = dynamics_randomization
        self.means = nn.Linear(latent, action_dim)
        self.nn_type = 'policy'

        self.fixed_std = fixed_std

    def _get_dist_params(self, state, update=False):
        """Normalize ``state``, run the body network, and return (mean, std)."""
        state = self.normalize_state(state, update=update)
        x = self._base_forward(state)

        mu = self.means(x)

        std = self.fixed_std

        return mu, std

    def stochastic_forward(self, state, deterministic=True, update=False, log_probs=False):
        """Return the mean action (deterministic) or a reparameterized sample.

        NOTE(review): ``log_probs`` only forces a sample to be drawn; the log
        probability itself is never returned. Confirm against callers in algos/ppo.py.
        """
        mu, sd = self._get_dist_params(state, update=update)

        if not deterministic or log_probs:
            dist = torch.distributions.Normal(mu, sd)
            sample = dist.rsample()

        action = mu if deterministic else sample

        return action

    def pdf(self, state):
        """Return the Normal action distribution for ``state``."""
        mu, sd = self._get_dist_params(state)
        return torch.distributions.Normal(mu, sd)


class FF_Stochastic_Actor(FF_Base, Stochastic_Actor):
    """
    A class inheriting from FF_Base and Stochastic_Actor
    which implements a feedforward stochastic policy.
    """
    def __init__(self, input_dim, action_dim, layers=(256, 256), dynamics_randomization=False, nonlinearity=torch.tanh, fixed_std=None):

        FF_Base.__init__(self, input_dim, layers, nonlinearity)
        Stochastic_Actor.__init__(self, layers[-1], action_dim, dynamics_randomization, fixed_std)

    def forward(self, x, deterministic=True, update_norm=False, return_log_probs=False):
        return self.stochastic_forward(x, deterministic=deterministic, update=update_norm, log_probs=return_log_probs)


class LSTM_Stochastic_Actor(LSTM_Base, Stochastic_Actor):
    """
    A class inheriting from LSTM_Base and Stochastic_Actor
    which implements a recurrent stochastic policy.
    """
    def __init__(self, input_dim, action_dim, layers=(128, 128), dynamics_randomization=False, fixed_std=None):

        LSTM_Base.__init__(self, input_dim, layers)
        Stochastic_Actor.__init__(self, layers[-1], action_dim, dynamics_randomization, fixed_std)

        self.is_recurrent = True
        self.init_hidden_state()

    def forward(self, x, deterministic=True, update_norm=False, return_log_probs=False):
        return self.stochastic_forward(x, deterministic=deterministic, update=update_norm, log_probs=return_log_probs)

# --------------------------------------------------------------------------------
# /nn/base.py
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from torch import sqrt


def create_layers(layer_fn, input_dim, layer_sizes):
    """
    Create an nn.ModuleList of modules built by ``layer_fn`` (e.g. nn.Linear or
    nn.LSTMCell), chaining input_dim -> layer_sizes[0] -> ... -> layer_sizes[-1].
    """
    sizes = [input_dim] + list(layer_sizes)
    return nn.ModuleList(layer_fn(sizes[i], sizes[i + 1]) for i in range(len(layer_sizes)))


class Net(nn.Module):
    """
    The base class which all policy networks inherit from. It includes methods
    for normalizing states.
    """
    def __init__(self):
        super(Net, self).__init__()
        self.is_recurrent = False

        self.state_mean = torch.zeros(1)
        self.state_mean_diff = torch.ones(1)
        self.state_n = 1

        self.env_name = None

        self.calculate_norm = False

    def normalize_state(self, state, update=True):
        """
        Use Welford's algorithm to normalize a state, and optionally update the statistics
        for normalizing states using the new state, online.

        Raises:
            RuntimeError: if ``update`` is requested for anything but a single 1-D state.
        """
        state = torch.Tensor(state)

        if self.state_n == 1:
            # No statistics yet: start from zero mean and a unit accumulator.
            self.state_mean = torch.zeros(state.size(-1))
            self.state_mean_diff = torch.ones(state.size(-1))

        if update:
            if len(state.size()) != 1:
                raise RuntimeError("normalize_state can only update statistics from a single state vector")
            # Snapshot the old mean: state_mean is updated in place below (in place so that
            # networks sharing it via copy_normalizer_stats keep seeing the same tensor).
            # Previously state_old aliased it, so the variance term used the *new* mean twice.
            state_old = self.state_mean.clone()
            self.state_mean += (state - state_old) / self.state_n
            self.state_mean_diff += (state - state_old) * (state - self.state_mean)
            self.state_n += 1
        return (state - self.state_mean) / sqrt(self.state_mean_diff / self.state_n)

    def copy_normalizer_stats(self, net):
        """Share ``net``'s normalization tensors (by reference) and copy its count."""
        self.state_mean = net.state_mean
        self.state_mean_diff = net.state_mean_diff
        self.state_n = net.state_n


class FF_Base(Net):
    """
    The base class for feedforward networks.
    """
    def __init__(self, in_dim, layers, nonlinearity):
        super(FF_Base, self).__init__()
        self.layers = create_layers(nn.Linear, in_dim, layers)
        self.nonlinearity = nonlinearity

    def _base_forward(self, x):
        for layer in self.layers:
            x = self.nonlinearity(layer(x))
        return x


class LSTM_Base(Net):
    """
    The base class for LSTM networks.
    """
    def __init__(self, in_dim, layers):
        super(LSTM_Base, self).__init__()
        self.layers = create_layers(nn.LSTMCell, in_dim, layers)

    def init_hidden_state(self, batch_size=1):
        """Reset per-layer hidden and cell states to zeros of shape (batch_size, hidden_size)."""
        self.hidden = [torch.zeros(batch_size, l.hidden_size) for l in self.layers]
        self.cells = [torch.zeros(batch_size, l.hidden_size) for l in self.layers]

    def _base_forward(self, x):
        """Run the LSTM stack.

        A 3-D input is treated as (time, batch, features): the hidden state is
        reset and the whole sequence is unrolled. 1-D/2-D inputs are single
        timesteps that continue from the stored hidden state.
        """
        dims = len(x.size())

        if dims == 3:  # if we get a batch of trajectories
            self.init_hidden_state(batch_size=x.size(1))

            if self.calculate_norm:
                self.latent_norm = 0

            y = []
            for x_t in x:
                for idx, layer in enumerate(self.layers):
                    c, h = self.cells[idx], self.hidden[idx]
                    self.hidden[idx], self.cells[idx] = layer(x_t, (h, c))
                    x_t = self.hidden[idx]

                    if self.calculate_norm:
                        self.latent_norm += (torch.mean(torch.abs(x_t)) + torch.mean(torch.abs(self.cells[idx])))

                y.append(x_t)
            x = torch.stack(y)

            if self.calculate_norm:
                self.latent_norm /= len(x) * len(self.layers)

        else:
            if dims == 1:  # if we get a single timestep (if not, assume we got a batch of single timesteps)
                x = x.view(1, -1)

            for idx, layer in enumerate(self.layers):
                h, c = self.hidden[idx], self.cells[idx]
                self.hidden[idx], self.cells[idx] = layer(x, (h, c))
                x = self.hidden[idx]

            if dims == 1:
                x = x.view(-1)
        return x
# --------------------------------------------------------------------------------
# /nn/critic.py
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

from nn.base import FF_Base, LSTM_Base


class V:
    """
    The base class for Value functions: a linear scalar head on top of a
    Net body (which supplies ``normalize_state`` and ``_base_forward``).
    """
    def __init__(self, latent, env_name):
        self.env_name = env_name
        self.network_out = nn.Linear(latent, 1)

    def v_forward(self, state, update=False):
        """Normalize ``state``, run the body, and return the value estimate."""
        state = self.normalize_state(state, update=update)
        x = self._base_forward(state)
        return self.network_out(x)


class FF_V(FF_Base, V):
    """
    A class inheriting from FF_Base and V
    which implements a feedforward value function.
    """
    def __init__(self, input_dim, layers=(256, 256), env_name=None):
        FF_Base.__init__(self, input_dim, layers, F.relu)
        V.__init__(self, layers[-1], env_name)

    def forward(self, state):
        return self.v_forward(state)


class LSTM_V(LSTM_Base, V):
    """
    A class inheriting from LSTM_Base and V
    which implements a recurrent value function.
    """
    def __init__(self, input_dim, layers=(128, 128), env_name=None):
        LSTM_Base.__init__(self, input_dim, layers)
        V.__init__(self, layers[-1], env_name)

        self.is_recurrent = True
        self.init_hidden_state()

    def forward(self, state):
        return self.v_forward(state)

# --------------------------------------------------------------------------------
# /nn/fit.py
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

from nn.base import FF_Base


class Model(FF_Base):
    """
    A very simple feedforward network to be used for
    vanilla supervised learning problems.
    """
    def __init__(self, state_dim, output_dim, layers=(512, 256), nonlinearity=torch.tanh):
        # FF_Base already stores ``nonlinearity`` on self.
        super(Model, self).__init__(state_dim, layers, nonlinearity)

        self.network_out = nn.Linear(layers[-1], output_dim)
        self.output_dim = output_dim
        self.nn_type = 'extractor'

    def forward(self, x):
        x = self._base_forward(x)
        return self.network_out(x)

# --------------------------------------------------------------------------------
# /util.py
# --------------------------------------------------------------------------------
import torch
import hashlib
import os
import numpy as np
from collections import OrderedDict


def create_logger(args):
    """Use hyperparams to set a directory to output diagnostic files.

    The directory is ``<logdir>/<hash>[-seed<seed>]``, where the hash is taken
    over every argument except ``seed`` and ``logdir``. An ``experiment.info``
    file listing those arguments is written there.

    Returns:
        A tensorboard SummaryWriter with extra ``dir`` and ``arg_hash`` attributes.
    """
    from torch.utils.tensorboard import SummaryWriter

    arg_dict = args.__dict__
    assert "logdir" in arg_dict, \
        "You must provide a 'logdir' key in your command line arguments."

    # sort the keys so the same hyperparameters will always have the same hash
    # (this builds a new dict, so the pops below do not mutate ``args``)
    arg_dict = OrderedDict(sorted(arg_dict.items(), key=lambda t: t[0]))

    # remove seed and logdir so they don't get hashed; keep seed for the dirname
    seed = str(arg_dict.pop("seed")) if 'seed' in arg_dict else None
    logdir = str(arg_dict.pop('logdir'))

    # get a unique hash for the hyperparameter settings, truncated at 6 chars
    arg_hash = hashlib.md5(str(arg_dict).encode('ascii')).hexdigest()[0:6]
    if seed is not None:
        arg_hash += '-seed' + seed

    output_dir = os.path.join(logdir, arg_hash)

    # create a directory with the hyperparam hash as its name, if it doesn't
    # already exist.
    os.makedirs(output_dir, exist_ok=True)

    # Create a file with all the hyperparam settings in plaintext
    # (closed deterministically; the handle used to be leaked).
    info_path = os.path.join(output_dir, "experiment.info")
    with open(info_path, 'w') as info_file:
        for key, val in arg_dict.items():
            info_file.write("%s: %s" % (key, val))
            info_file.write('\n')

    logger = SummaryWriter(output_dir, flush_secs=0.1)
    logger.dir = output_dir
    logger.arg_hash = arg_hash
    return logger


def train_normalizer(policy, min_timesteps, max_traj_len=1000, noise=0.5):
    """Collect at least ``min_timesteps`` env steps to fit the policy's state normalizer.

    Actions are the policy output plus N(0, noise) exploration noise.
    Dynamics randomization is disabled. No gradients are tracked.
    """
    with torch.no_grad():
        env = env_factory(policy.env_name)()
        env.dynamics_randomization = False

        total_t = 0
        while total_t < min_timesteps:
            state = env.reset()
            done = False
            timesteps = 0

            if hasattr(policy, 'init_hidden_state'):
                policy.init_hidden_state()

            while not done and timesteps < max_traj_len:
                action = policy.forward(state, update_norm=True).numpy() + np.random.normal(0, noise, size=policy.action_dim)
                state, _, done, _ = env.step(action)
                timesteps += 1
                total_t += 1


def eval_policy(model, env=None, episodes=5, max_traj_len=400, verbose=True, visualize=False):
    """Roll out a policy deterministically and return the mean episode return.

    Args:
        model: a policy (``nn_type == 'policy'``) or an extractor
            (``nn_type == 'extractor'``) whose ``policy_path`` names the policy file.
        env: environment instance; built via ``env_factory(False)()`` when omitted.
        episodes: number of episodes to run.
        max_traj_len: maximum steps per episode.
        verbose: print each episode's return.
        visualize: call ``env.render()`` after every step.

    Returns:
        Mean undiscounted return over the episodes.

    Raises:
        ValueError: if ``model.nn_type`` is neither 'policy' nor 'extractor'
            (previously this fell through to an UnboundLocalError).
    """
    if env is None:
        env = env_factory(False)()

    if model.nn_type == 'policy':
        policy = model
    elif model.nn_type == 'extractor':
        policy = torch.load(model.policy_path)
    else:
        raise ValueError("Unsupported nn_type '{}'".format(model.nn_type))

    with torch.no_grad():
        ep_returns = []
        for _ in range(episodes):
            env.dynamics_randomization = False
            state = torch.Tensor(env.reset())
            done = False
            traj_len = 0
            ep_return = 0

            if hasattr(policy, 'init_hidden_state'):
                policy.init_hidden_state()

            while not done and traj_len < max_traj_len:
                action = policy(state)
                # NOTE(review): the speed command is pinned to 1 on every step; confirm intent.
                env.speed = 1
                next_state, reward, done, _ = env.step(action.numpy())
                if visualize:
                    env.render()
                state = torch.Tensor(next_state)
                ep_return += reward
                traj_len += 1

            ep_returns.append(ep_return)
            if verbose:
                print('Return: {:6.2f}'.format(ep_return))

    return np.mean(ep_returns)


def env_factory(dynamics_randomization, verbose=False, **kwargs):
    """
    Returns an *uninstantiated* environment constructor.

    Since environments containing cpointers (e.g. Mujoco envs) can't be serialized,
    this allows us to pass their constructors to Ray remote functions instead.
    Extra ``kwargs`` are accepted but currently ignored.
    """
    from functools import partial
    from cassie.cassie import CassieEnv

    if verbose:
        print("Created cassie env with arguments:")
        print("\tdynamics randomization: {}".format(dynamics_randomization))
    return partial(CassieEnv, dynamics_randomization=dynamics_randomization)