├── constants.py
├── problem_gen
│   ├── cartpole
│   │   ├── A.npy
│   │   ├── B.npy
│   │   ├── C.npy
│   │   ├── D.npy
│   │   └── G.npy
│   └── quadrotor
│       ├── A.npy
│       ├── C.npy
│       ├── G.npy
│       └── linearize_dynamics.py
├── rl
│   ├── arguments.py
│   ├── utils.py
│   ├── distributions.py
│   ├── ppo.py
│   ├── trainer.py
│   ├── rarl_ppo.py
│   ├── model.py
│   └── storage.py
├── sqrtm.py
├── envs
│   ├── random_hinf_env.py
│   ├── random_pldi_env.py
│   ├── random_nldi_env.py
│   ├── ode_env.py
│   ├── rl_wrapper.py
│   ├── microgrid.py
│   ├── quadrotor_env.py
│   └── cartpole.py
├── .gitignore
├── README.md
├── robust_mpc.py
├── disturb_models.py
├── LICENSE
├── plots.py
├── policy_models.py
└── main.py
/constants.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | NUMPY_DTYPE = np.float64
5 | TORCH_DTYPE = torch.float64
6 | 
--------------------------------------------------------------------------------
/problem_gen/cartpole/A.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/robust-nn-control/HEAD/problem_gen/cartpole/A.npy
--------------------------------------------------------------------------------
/problem_gen/cartpole/B.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/robust-nn-control/HEAD/problem_gen/cartpole/B.npy
--------------------------------------------------------------------------------
/problem_gen/cartpole/C.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/robust-nn-control/HEAD/problem_gen/cartpole/C.npy
--------------------------------------------------------------------------------
/problem_gen/cartpole/D.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/robust-nn-control/HEAD/problem_gen/cartpole/D.npy
--------------------------------------------------------------------------------
/problem_gen/cartpole/G.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/robust-nn-control/HEAD/problem_gen/cartpole/G.npy
--------------------------------------------------------------------------------
/problem_gen/quadrotor/A.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/robust-nn-control/HEAD/problem_gen/quadrotor/A.npy
--------------------------------------------------------------------------------
/problem_gen/quadrotor/C.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/robust-nn-control/HEAD/problem_gen/quadrotor/C.npy
--------------------------------------------------------------------------------
/problem_gen/quadrotor/G.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/robust-nn-control/HEAD/problem_gen/quadrotor/G.npy
--------------------------------------------------------------------------------
/rl/arguments.py:
--------------------------------------------------------------------------------
1 | 
2 | def get_args():
3 |     args = type('', (), {})()
4 | 
5 |     args.num_env_steps = 1e7
6 |     args.num_processes = 8
7 |     args.gamma = 0.99
8 | 
9 |     args.clip_param = 0.1
10 |     args.ppo_epoch = 4
11 |     args.num_mini_batch = 4
12 |     args.value_loss_coef = 0.5
13 |     args.entropy_coef = 0.01
14 |     args.lr = 2.5e-4
15 |     args.rms_prop_eps = 1e-5
16 |     args.max_grad_norm = 0.5
17 | 
18 |     args.use_linear_lr_decay = True
19 |     args.use_gae = True
20 |     args.gae_lambda = 0.95
21 |     args.use_proper_time_limits = False
22 |     args.save_interval = 100
23 |     args.log_interval = 10
24 |     args.eval_interval = 100
25 | 
26 |     return args
27 | 
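# A minimal sketch (annotation, not part of the original file) of how these
# defaults are consumed; the update count mirrors the computation in
# rl/trainer.py, and num_episode_steps is an assumed value matching its default:
#
#     args = get_args()
#     num_episode_steps = 100
#     num_updates = int(args.num_env_steps) // num_episode_steps // args.num_processes
#     # 1e7 steps / 100 steps per episode / 8 processes = 12500 PPO updates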
--------------------------------------------------------------------------------
/rl/utils.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | 
4 | import torch
5 | import torch.nn as nn
6 | 
7 | 
8 | # Necessary for my KFAC implementation.
9 | class AddBias(nn.Module):
10 |     def __init__(self, bias):
11 |         super(AddBias, self).__init__()
12 |         self._bias = nn.Parameter(bias.unsqueeze(1))
13 | 
14 |     def forward(self, x):
15 |         if x.dim() == 2:
16 |             bias = self._bias.t().view(1, -1)
17 |         else:
18 |             bias = self._bias.t().view(1, -1, 1, 1)
19 | 
20 |         return x + bias
21 | 
22 | 
23 | def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
24 |     """Decreases the learning rate linearly"""
25 |     lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs)))
26 |     for param_group in optimizer.param_groups:
27 |         param_group['lr'] = lr
28 | 
29 | 
30 | def init(module, weight_init, bias_init, gain=1):
31 |     weight_init(module.weight.data, gain=gain)
32 |     bias_init(module.bias.data)
33 |     return module
34 | 
35 | 
36 | def cleanup_log_dir(log_dir):
37 |     try:
38 |         os.makedirs(log_dir)
39 |     except OSError:
40 |         files = glob.glob(os.path.join(log_dir, '*.monitor.csv'))
41 |         for f in files:
42 |             os.remove(f)
43 | 
--------------------------------------------------------------------------------
/sqrtm.py:
--------------------------------------------------------------------------------
1 | ##### From https://github.com/steveli/pytorch-sqrtm
2 | 
3 | import torch
4 | from torch.autograd import Function
5 | import numpy as np
6 | import scipy.linalg
7 | 
8 | 
9 | class MatrixSquareRoot(Function):
10 |     """Square root of a positive definite matrix.
11 | 
12 |     NOTE: matrix square root is not differentiable for matrices with
13 |     zero eigenvalues.
14 |     """
15 |     @staticmethod
16 |     def forward(ctx, input):
17 |         m = input.detach().cpu().numpy().astype(np.float_)
18 |         sqrtm = torch.from_numpy(scipy.linalg.sqrtm(m).real).to(input)
19 |         ctx.save_for_backward(sqrtm)
20 |         return sqrtm
21 | 
22 |     @staticmethod
23 |     def backward(ctx, grad_output):
24 |         grad_input = None
25 |         if ctx.needs_input_grad[0]:
26 |             sqrtm, = ctx.saved_tensors
27 |             sqrtm = sqrtm.data.cpu().numpy().astype(np.float_)
28 |             gm = grad_output.data.cpu().numpy().astype(np.float_)
29 | 
30 |             # Given a positive semi-definite matrix X,
31 |             # since X = X^{1/2}X^{1/2}, we can compute the gradient of the
32 |             # matrix square root dX^{1/2} by solving the Sylvester equation:
33 |             # dX = d(X^{1/2})X^{1/2} + X^{1/2}d(X^{1/2}).
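            # (Annotation, not in the original file:) scipy.linalg.solve_sylvester(a, b, q)
            # returns the X satisfying a X + X b = q; taking a = b = X^{1/2} and
            # q = dX (the incoming gradient) inverts the linearized map above
            # and yields the gradient with respect to the input.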
34 | grad_sqrtm = scipy.linalg.solve_sylvester(sqrtm, sqrtm, gm) 35 | 36 | grad_input = torch.from_numpy(grad_sqrtm).to(grad_output) 37 | return grad_input 38 | 39 | 40 | sqrtm = MatrixSquareRoot.apply 41 | 42 | 43 | def main(): 44 | from torch.autograd import gradcheck 45 | k = torch.randn(20, 10).double() 46 | # Create a positive definite matrix 47 | pd_mat = (k.t().matmul(k)).requires_grad_() 48 | test = gradcheck(sqrtm, (pd_mat,)) 49 | print(test) 50 | 51 | 52 | if __name__ == '__main__': 53 | main() 54 | -------------------------------------------------------------------------------- /envs/random_hinf_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from envs import ode_env 6 | import disturb_models as dm 7 | from constants import * 8 | 9 | 10 | class RandomHinfEnv(ode_env.HinfEnv): 11 | 12 | def __init__(self, n=5, m=3, wp=2, T=2, random_seed=None, device=None): 13 | if random_seed is not None: 14 | np.random.seed(random_seed) 15 | torch.manual_seed(random_seed+1) 16 | 17 | self.n, self.m, self.wp = n, m, wp 18 | 19 | self.A = torch.tensor(np.random.randn(n, n), dtype=TORCH_DTYPE, device=device) 20 | self.B = torch.tensor(np.random.randn(n, m), dtype=TORCH_DTYPE, device=device) 21 | self.G = torch.tensor(1.5*np.random.randn(n, wp), dtype=TORCH_DTYPE, device=device) 22 | 23 | Q = np.random.randn(n, n) 24 | Q = Q.T @ Q 25 | self.Q = torch.tensor(Q, dtype=TORCH_DTYPE, device=device) 26 | 27 | R = np.random.randn(m, m) 28 | R = R.T @ R 29 | self.R = torch.tensor(R, dtype=TORCH_DTYPE, device=device) 30 | 31 | self.disturb_f = dm.HinfDisturbModel(n, m, wp, T) 32 | self.adversarial_disturb_f = None 33 | 34 | if device is not None: 35 | self.disturb_f.to(device=device, dtype=TORCH_DTYPE) 36 | 37 | def xdot_f(self, x, u, t): 38 | w = self.disturb_f(x, u, t) 39 | return x @ self.A.T + u @ self.B.T + w @ self.G.T 40 | 41 | def xdot_adversarial_f(self, x, u, t): 42 | if self.adversarial_disturb_f is None: 43 | raise ValueError('You must initialize adversarial_disturb_f before running in adversarial mode') 44 | w = self.adversarial_disturb_f(x, u, t) 45 | return x @ self.A.T + u @ self.B.T + w @ self.G.T 46 | 47 | def cost_f(self, x, u, t): 48 | return ((x @ self.Q) * x).sum(-1) + ((u @ self.R) * u).sum(-1) 49 | 50 | def get_hinf_linearization(self): 51 | return self.A, self.B, self.G, self.Q, self.R 52 | 53 | def gen_states(self, num_states, device=None): 54 | return torch.randn((num_states, self.n), device=device, dtype=TORCH_DTYPE) 55 | 56 | def __copy__(self): 57 | new_env = RandomHinfEnv.__new__(RandomHinfEnv) 58 | new_env.__dict__.update(self.__dict__) 59 | return new_env 60 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.gitignore.io/api/python 3 | # Edit at https://www.gitignore.io/?templates=python 4 | 5 | ### Python ### 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # Jupyter notebooks 12 | .ipynb_checkpoints 13 | 14 | # Mac 15 | .DS_Store 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | pip-wheel-metadata/ 35 | share/python-wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | 
*.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .nox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # pipenv 81 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 82 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 83 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 84 | # install all needed dependencies. 85 | #Pipfile.lock 86 | 87 | # celery beat schedule file 88 | celerybeat-schedule 89 | 90 | # SageMath parsed files 91 | *.sage.py 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # Mr Developer 101 | .mr.developer.cfg 102 | .project 103 | .pydevproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | .dmypy.json 111 | dmypy.json 112 | 113 | # Pyre type checker 114 | .pyre/ 115 | 116 | # End of https://www.gitignore.io/api/python 117 | -------------------------------------------------------------------------------- /envs/random_pldi_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from envs import ode_env 6 | import disturb_models as dm 7 | from constants import * 8 | 9 | 10 | class RandomPLDIEnv(ode_env.PLDIEnv): 11 | 12 | def __init__(self, n=5, m=3, L=3, random_seed=None, device=None): 13 | if random_seed is not None: 14 | np.random.seed(random_seed) 15 | torch.manual_seed(random_seed+1) 16 | 17 | self.n, self.m, self.L = n, m, L 18 | 19 | self.A = 3 * torch.tensor(np.random.randn(1, n, n) + 0.5 * np.random.randn(L, n, n), dtype=TORCH_DTYPE, device=device) 20 | self.B = 3 * torch.tensor(np.random.randn(1, n, m) + 0.5 * np.random.randn(L, n, m), dtype=TORCH_DTYPE, device=device) 21 | 22 | Q = np.random.randn(n, n) 23 | Q = Q.T @ Q 24 | self.Q = torch.tensor(Q, dtype=TORCH_DTYPE, device=device) 25 | 26 | R = np.random.randn(m, m) 27 | R = R.T @ R 28 | self.R = torch.tensor(R, dtype=TORCH_DTYPE, device=device) 29 | 30 | self.disturb_f = dm.PLDIDisturbModel(n, m, L) 31 | self.adversarial_disturb_f = None 32 | 33 | if device is not None: 34 | self.disturb_f.to(device=device, dtype=TORCH_DTYPE) 35 | 36 | def xdot_f(self, x, u, t): 37 | a = self.disturb_f(x, u, t) 38 | A = (self.A.unsqueeze(0) * a[:, :, None, None]).sum(1) 39 | B = (self.B.unsqueeze(0) * a[:, :, None, None]).sum(1) 40 | return (A @ x.unsqueeze(2) + B @ u.unsqueeze(2)).squeeze() 41 | 42 | def xdot_adversarial_f(self, x, u, t): 43 | if self.adversarial_disturb_f is None: 44 | raise ValueError('You must initialize adversarial_disturb_f before running in adversarial mode') 45 | a = self.adversarial_disturb_f(x, u, t) 46 | A = (self.A.unsqueeze(0) * a[:, :, None, None]).sum(1) 47 | B = (self.B.unsqueeze(0) * a[:, :, None, 
None]).sum(1) 48 | return (A @ x.unsqueeze(2) + B @ u.unsqueeze(2)).squeeze() 49 | 50 | def cost_f(self, x, u, t): 51 | return ((x @ self.Q) * x).sum(-1) + ((u @ self.R) * u).sum(-1) 52 | 53 | def get_pldi_linearization(self): 54 | return self.A, self.B, self.Q, self.R 55 | 56 | def gen_states(self, num_states, device=None): 57 | return torch.tensor(np.random.rand(num_states, self.n), device=device, dtype=TORCH_DTYPE) 58 | 59 | def __copy__(self): 60 | new_env = RandomPLDIEnv.__new__(RandomPLDIEnv) 61 | new_env.__dict__.update(self.__dict__) 62 | return new_env 63 | -------------------------------------------------------------------------------- /envs/random_nldi_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from envs import ode_env 6 | import disturb_models as dm 7 | from constants import * 8 | 9 | 10 | class RandomNLDIEnv(ode_env.NLDIEnv): 11 | 12 | def __init__(self, n=5, m=3, wp=2, wq=2, isD0=False, random_seed=None, device=None): 13 | if random_seed is not None: 14 | np.random.seed(random_seed) 15 | torch.manual_seed(random_seed+1) 16 | 17 | self.n, self.m, self.wp, self.wq = n, m, wp, wq 18 | self.isD0 = isD0 19 | 20 | self.A = torch.tensor(np.random.randn(n, n), dtype=TORCH_DTYPE, device=device) 21 | self.B = torch.tensor(np.random.randn(n, m), dtype=TORCH_DTYPE, device=device) 22 | self.G = torch.tensor(1.5*np.random.randn(n, wp), dtype=TORCH_DTYPE, device=device) 23 | self.C = torch.tensor(np.random.randn(wq, n), dtype=TORCH_DTYPE, device=device) 24 | if isD0: 25 | self.D = torch.zeros(wq, m, dtype=TORCH_DTYPE, device=device) 26 | else: 27 | self.D = torch.tensor(0.01*np.random.randn(wq, m), dtype=TORCH_DTYPE, device=device) 28 | 29 | Q = np.random.randn(n, n) 30 | Q = Q.T @ Q 31 | self.Q = torch.tensor(Q, dtype=TORCH_DTYPE, device=device) 32 | 33 | R = np.random.randn(m, m) 34 | R = R.T @ R 35 | self.R = torch.tensor(R, dtype=TORCH_DTYPE, device=device) 36 | 37 | self.disturb_f = dm.NLDIDisturbModel(self.C, self.D, n, m, wp) 38 | self.adversarial_disturb_f = None 39 | 40 | if device is not None: 41 | self.disturb_f.to(device=device, dtype=TORCH_DTYPE) 42 | 43 | def xdot_f(self, x, u, t): 44 | p = self.disturb_f(x, u, t) 45 | return x @ self.A.T + u @ self.B.T + p @ self.G.T 46 | 47 | def xdot_adversarial_f(self, x, u, t): 48 | if self.adversarial_disturb_f is None: 49 | raise ValueError('You must initialize adversarial_disturb_f before running in adversarial mode') 50 | p = self.adversarial_disturb_f(x, u, t) 51 | return x @ self.A.T + u @ self.B.T + p @ self.G.T 52 | 53 | def cost_f(self, x, u, t): 54 | return ((x @ self.Q) * x).sum(-1) + ((u @ self.R) * u).sum(-1) 55 | 56 | def get_nldi_linearization(self): 57 | return self.A, self.B, self.G, self.C, self.D, self.Q, self.R 58 | 59 | def gen_states(self, num_states, device=None): 60 | return torch.tensor(np.random.rand(num_states, self.n), device=device, dtype=TORCH_DTYPE) 61 | 62 | def __copy__(self): 63 | new_env = RandomNLDIEnv.__new__(RandomNLDIEnv) 64 | new_env.__dict__.update(self.__dict__) 65 | return new_env 66 | 67 | -------------------------------------------------------------------------------- /envs/ode_env.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | import numpy as np 3 | import torch 4 | from scipy import integrate 5 | 6 | from constants import * 7 | 8 | 9 | class ODEEnv(ABC): 10 | def xdot_f(self, x, u, t): 11 | raise 
NotImplementedError
12 | 
13 |     def xdot_adversarial_f(self, x, u, t):
14 |         raise NotImplementedError
15 | 
16 |     def cost_f(self, x, u, t):
17 |         raise NotImplementedError
18 | 
19 |     def gen_states(self, num_states, device=None):
20 |         raise NotImplementedError
21 | 
22 |     def step(self, x, u, t, dt, step_type, adversarial=False):
23 |         xdot_f = self.xdot_adversarial_f if adversarial else self.xdot_f
24 | 
25 |         if step_type == 'euler':
26 |             x_dot = xdot_f(x, u, t)
27 |             x_next = x + dt * x_dot
28 |             cost = dt * self.cost_f(x, u, t)
29 |         elif step_type == 'RK4':
30 |             # k1 = time_step * xdot_f(x,u)
31 |             # k2 = time_step * xdot_f(x + k1/2, model(x + k1/2))
32 |             # k3 = time_step * xdot_f(x + k2/2, model(x + k2/2))
33 |             # k4 = time_step * xdot_f(x + k3, model(x + k3))
34 |             # x_next = x + (k1 + 2*k2 + 2*k3 + k4)/6
35 |             k1 = xdot_f(x, u, t)
36 |             k2 = xdot_f(x + (dt * k1 / 2), u, t + dt/2)
37 |             k3 = xdot_f(x + (dt * k2 / 2), u, t + dt/2)
38 |             k4 = xdot_f(x + (dt * k3), u, t + dt)
39 |             x_next = x + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6
40 | 
41 |             cost_k1 = self.cost_f(x, u, t)
42 |             cost_k2 = self.cost_f(x + (dt * k1 / 2), u, t + dt/2)
43 |             cost_k3 = self.cost_f(x + (dt * k2 / 2), u, t + dt/2)
44 |             cost_k4 = self.cost_f(x + (dt * k3), u, t + dt)
45 |             cost = dt * (cost_k1 + 2 * cost_k2 + 2 * cost_k3 + cost_k4) / 6
46 |         elif step_type == 'scipy':
47 |             x_next = torch.zeros_like(x, dtype=TORCH_DTYPE, device=x.device)
48 |             num_x, x_dim = x_next.shape
49 |             cost = torch.zeros([num_x], dtype=TORCH_DTYPE, device=x.device)
50 |             for i in range(num_x):
51 |                 u_i = u[i, :].unsqueeze(0)
52 |                 f = lambda y, t_: torch.cat((xdot_f(torch.Tensor([y[0:x_dim]]), u_i, t_).squeeze(),
53 |                                              self.cost_f(torch.Tensor([y[0:x_dim]]), u_i, t_)), 0).numpy()
54 |                 y0 = torch.cat((x[i, :], torch.Tensor([0])), 0).numpy()
55 |                 output = integrate.odeint(f, y0, [0, dt])
56 |                 x_next[i, :] = torch.Tensor(output[1, :x_dim])
57 |                 cost[i] = torch.Tensor(output[1, x_dim:])
58 |         else:
59 |             raise NotImplementedError('Unsupported step type.')
60 | 
61 |         return x_next, cost
62 | 
63 | 
64 | class NLDIEnv(ODEEnv, ABC):
65 |     def get_nldi_linearization(self):
66 |         raise NotImplementedError
67 | 
68 | 
69 | class PLDIEnv(ODEEnv, ABC):
70 |     def get_pldi_linearization(self):
71 |         raise NotImplementedError
72 | 
73 | 
74 | class HinfEnv(ODEEnv, ABC):
75 |     def get_hinf_linearization(self):
76 |         raise NotImplementedError
77 | 
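# A toy sketch (annotation, not part of the repo) of the ODEEnv.step contract;
# the dynamics, dimensions, and dt below are illustrative assumptions:
#
#     import torch
#     from envs import ode_env
#
#     class ToyEnv(ode_env.ODEEnv):
#         def xdot_f(self, x, u, t):
#             return -x + u                             # stable linear dynamics
#         def cost_f(self, x, u, t):
#             return (x * x).sum(-1) + (u * u).sum(-1)  # quadratic stage cost
#
#     env = ToyEnv()
#     x = torch.ones(4, 2, dtype=torch.float64)         # batch of 4 states
#     u = torch.zeros(4, 2, dtype=torch.float64)
#     x_next, cost = env.step(x, u, t=0.0, dt=0.05, step_type='RK4')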
10 | """ 11 | 12 | # 13 | # Standardize distribution interfaces 14 | # 15 | 16 | # Categorical 17 | class FixedCategorical(torch.distributions.Categorical): 18 | def sample(self): 19 | return super().sample().unsqueeze(-1) 20 | 21 | def log_probs(self, actions): 22 | return ( 23 | super() 24 | .log_prob(actions.squeeze(-1)) 25 | .view(actions.size(0), -1) 26 | .sum(-1) 27 | .unsqueeze(-1) 28 | ) 29 | 30 | def mode(self): 31 | return self.probs.argmax(dim=-1, keepdim=True) 32 | 33 | 34 | # Normal 35 | class FixedNormal(torch.distributions.Normal): 36 | def log_probs(self, actions): 37 | return super().log_prob(actions).sum(-1, keepdim=True) 38 | 39 | def entrop(self): 40 | return super.entropy().sum(-1) 41 | 42 | def mode(self): 43 | return self.mean 44 | 45 | 46 | # Bernoulli 47 | class FixedBernoulli(torch.distributions.Bernoulli): 48 | def log_probs(self, actions): 49 | return super.log_prob(actions).view(actions.size(0), -1).sum(-1).unsqueeze(-1) 50 | 51 | def entropy(self): 52 | return super().entropy().sum(-1) 53 | 54 | def mode(self): 55 | return torch.gt(self.probs, 0.5).float() 56 | 57 | 58 | class Categorical(nn.Module): 59 | def __init__(self, num_inputs, num_outputs): 60 | super(Categorical, self).__init__() 61 | 62 | init_ = lambda m: init( 63 | m, 64 | nn.init.orthogonal_, 65 | lambda x: nn.init.constant_(x, 0), 66 | gain=0.01) 67 | 68 | self.linear = init_(nn.Linear(num_inputs, num_outputs)) 69 | 70 | def forward(self, x): 71 | x = self.linear(x) 72 | return FixedCategorical(logits=x) 73 | 74 | 75 | class DiagGaussian(nn.Module): 76 | def __init__(self, num_inputs, num_outputs): 77 | super(DiagGaussian, self).__init__() 78 | 79 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 80 | constant_(x, 0)) 81 | 82 | self.fc_mean = init_(nn.Linear(num_inputs, num_outputs)) 83 | self.logstd = AddBias(torch.zeros(num_outputs)) 84 | 85 | def forward(self, x): 86 | action_mean = self.fc_mean(x) 87 | 88 | # An ugly hack for my KFAC implementation. 89 | zeros = torch.zeros(action_mean.size()) 90 | if x.is_cuda: 91 | zeros = zeros.to(device=x.device) 92 | 93 | action_logstd = self.logstd(zeros) 94 | return FixedNormal(action_mean, action_logstd.exp()) 95 | 96 | 97 | class Bernoulli(nn.Module): 98 | def __init__(self, num_inputs, num_outputs): 99 | super(Bernoulli, self).__init__() 100 | 101 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 102 | constant_(x, 0)) 103 | 104 | self.linear = init_(nn.Linear(num_inputs, num_outputs)) 105 | 106 | def forward(self, x): 107 | x = self.linear(x) 108 | return FixedBernoulli(logits=x) 109 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Enforcing robust control guarantees within neural network policies 2 | 3 | This repository is by 4 | [Priya L. Donti](https://www.priyadonti.com), 5 | [Melrose Roderick](https://melroderick.github.io/), 6 | [Mahyar Fazlyab](https://scholar.google.com/citations?user=Y3bmjJwAAAAJ&hl=en), 7 | and [J. Zico Kolter](http://zicokolter.com), 8 | and contains the [PyTorch](https://pytorch.org) source code to 9 | reproduce the experiments in our paper 10 | "[Enforcing robust control guarantees within neural network policies](https://arxiv.org/abs/2011.08105)." 11 | 12 | If you find this repository helpful in your publications, 13 | please consider citing our paper. 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Enforcing robust control guarantees within neural network policies
2 | 
3 | This repository is by
4 | [Priya L. Donti](https://www.priyadonti.com),
5 | [Melrose Roderick](https://melroderick.github.io/),
6 | [Mahyar Fazlyab](https://scholar.google.com/citations?user=Y3bmjJwAAAAJ&hl=en),
7 | and [J. Zico Kolter](http://zicokolter.com),
8 | and contains the [PyTorch](https://pytorch.org) source code to
9 | reproduce the experiments in our paper
10 | "[Enforcing robust control guarantees within neural network policies](https://arxiv.org/abs/2011.08105)."
11 | 
12 | If you find this repository helpful in your publications,
13 | please consider citing our paper.
14 | 
15 | ```
16 | @inproceedings{donti2021enforcing,
17 |   title={Enforcing robust control guarantees within neural network policies},
18 |   author={Donti, Priya and Roderick, Melrose and Fazlyab, Mahyar and Kolter, J Zico},
19 |   booktitle={International Conference on Learning Representations},
20 |   year={2021}
21 | }
22 | ```
23 | 
24 | 
25 | ## Introduction
26 | 
27 | When designing controllers for safety-critical systems, practitioners often face a challenging tradeoff between robustness and performance. While robust control methods provide rigorous guarantees on system stability under certain worst-case disturbances, they often result in simple controllers that perform poorly in the average (non-worst) case. In contrast, nonlinear control methods trained using deep learning have achieved state-of-the-art performance on many control tasks, but
28 | often lack robustness guarantees. We propose a technique that combines the strengths of these two approaches: a generic nonlinear control policy class, parameterized by neural networks, that nonetheless enforces the same provable robustness criteria as robust control. Specifically, we show that by integrating custom convex-optimization-based projection layers into a nonlinear policy, we can construct a provably robust neural network policy class that outperforms robust control methods in the average (non-adversarial) setting. We demonstrate the power of this approach on several domains, improving in performance over existing robust control methods and in stability over (non-robust) RL methods.
29 | 
30 | ## Dependencies
31 | 
32 | + Python 3.x/numpy/scipy/[cvxpy](http://www.cvxpy.org/en/latest/)
33 | + [PyTorch](https://pytorch.org) 1.5
34 | + OpenAI [Gym](https://gym.openai.com/) 0.15: *A toolkit for reinforcement learning*
35 | + [qpth](https://github.com/locuslab/qpth):
36 |   *A fast differentiable QP solver for PyTorch*
37 | + [block](https://github.com/bamos/block):
38 |   *A block matrix library for numpy and PyTorch*
39 | + [argparse](https://docs.python.org/3/library/argparse.html): *Input argument parsing*
40 | + [setproctitle](https://pypi.org/project/setproctitle/): *Library to set process titles*
41 | + [tqdm](https://tqdm.github.io/): *A library for smart progress bars*
42 | 
43 | 
44 | ## Instructions
45 | 
46 | ### Running experiments
47 | 
48 | Experiments can be run using the following commands for each environment (with the additional optional flag `--gpu [gpunum]` to enable GPU support). To reproduce the results in our paper, append the flag `--envRandomSeed 10` to the commands below.
49 | 
50 | Synthetic NLDI (D=0):
51 | 
52 | ```
53 | python main.py --env random_nldi-d0
54 | ```
55 | 
56 | Synthetic NLDI (D ≠ 0):
57 | 
58 | ```
59 | python main.py --env random_nldi-dnonzero
60 | ```
61 | 
62 | Cart-pole:
63 | 
64 | ```
65 | python main.py --env cartpole --T 10 --dt 0.05
66 | ```
67 | 
68 | Planar quadrotor:
69 | 
70 | ```
71 | python main.py --env quadrotor --T 4 --dt 0.02
72 | ```
73 | 
74 | Microgrid:
75 | 
76 | ```
77 | python main.py --env microgrid
78 | ```
79 | 
80 | Synthetic PLDI:
81 | 
82 | ```
83 | python main.py --env random_pldi_env
84 | ```
85 | 
86 | Synthetic H∞:
87 | 
88 | ```
89 | python main.py --env random_hinf_env
90 | ```
91 | 
92 | ### Generating plots
93 | 
94 | After running the experiments above, plots and tables can then be generated by running:
95 | 
96 | ```
97 | python plots.py
98 | ```
99 | 
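# A hypothetical usage sketch (annotation, not part of the repo) for the
# robust MPC baseline defined in robust_mpc.py below; the environment
# construction and the fallback gain K_init are assumptions:
#
#     import torch
#     from envs.random_nldi_env import RandomNLDIEnv
#     from robust_mpc import RobustNLDIMPC
#
#     env = RandomNLDIEnv(random_seed=10)
#     A, B, G, C, D, Q, R = env.get_nldi_linearization()
#     K_init = torch.zeros(B.shape[1], A.shape[0], dtype=torch.float64)  # assumed fallback gain
#     mpc = RobustNLDIMPC(A, B, G, C, D, Q, R, K_init)
#     x = env.gen_states(4)
#     u = mpc.get_action(x)  # solves a per-state SDP; falls back to K_init on solver failure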
--------------------------------------------------------------------------------
/robust_mpc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cvxpy as cp
3 | import scipy.linalg as la
4 | from constants import *
5 | 
6 | import warnings
7 | 
8 | class RobustMPC():
9 |     def __init__(self):
10 |         pass
11 | 
12 |     def get_action(self, x):
13 |         Ks_new = np.apply_along_axis(self.controller_gain_fn, 1, x.cpu().detach().numpy())
14 |         nan_mask = np.isnan(Ks_new).sum(axis=1).sum(axis=1) > 0
15 | 
16 |         # Accommodate any solver errors by falling back on robust LQR
17 |         Ks_t = torch.tensor(Ks_new, device=self.device, dtype=TORCH_DTYPE)
18 |         Ks_t[nan_mask] = self.K_init
19 | 
20 |         u = Ks_t.bmm(x.unsqueeze(-1)).squeeze(-1)
21 |         return u
22 | 
23 | class RobustNLDIMPC(RobustMPC):
24 |     def __init__(self, A, B, G, C, D, Q, R, K_init, device='cpu'):
25 |         super().__init__()
26 |         Q_sqrt = la.sqrtm(Q.cpu().detach().numpy())
27 |         R_sqrt = la.sqrtm(R.cpu().detach().numpy())
28 |         self.controller_gain_fn = lambda x: get_nldi_controller_gain(
29 |             x, *[v.cpu().detach().numpy() for v in (A, B, G, C, D)], Q_sqrt, R_sqrt)
30 | 
31 |         self.K_init = K_init
32 |         self.device = device
33 | 
34 | class RobustPLDIMPC(RobustMPC):
35 |     def __init__(self, As, Bs, Q, R, K_init, device='cpu'):
36 |         super().__init__()
37 |         Q_sqrt = la.sqrtm(Q.cpu().detach().numpy())
38 |         R_sqrt = la.sqrtm(R.cpu().detach().numpy())
39 |         self.controller_gain_fn = lambda x: get_pldi_controller_gain(
40 |             x, As.cpu().detach().numpy(), Bs.cpu().detach().numpy(), Q_sqrt, R_sqrt)
41 | 
42 |         self.K_init = K_init
43 |         self.device = device
44 | 
45 | def get_nldi_controller_gain(x_in, A, B, G, C, D, Q_sqrt, R_sqrt):
46 |     n = A.shape[1]
47 |     m = B.shape[1]
48 |     w = G.shape[1]
49 |     wq = C.shape[0]
50 |     assert (w <= wq), "wp must be at most wq to use this method"
51 |     if w < wq:
52 |         G = np.concatenate([G, np.zeros([G.shape[0], wq - w])], axis=1)
53 |         w = wq
54 |     x = np.expand_dims(x_in, 1)
55 | 
56 |     S = cp.Variable((n,n), symmetric=True)
57 |     Y = cp.Variable((m,n))
58 |     gam = cp.Variable()
59 |     lam = cp.Variable(w)
60 | 
61 |     m1 = cp.bmat((
62 |         (np.expand_dims(np.array([1]),0), x.T),
63 |         (x, S)
64 |     ))
65 | 
66 |     m2 = cp.bmat((
67 |         (S, Y.T@R_sqrt, S@Q_sqrt, S@C.T + Y.T@D.T, S@A.T + Y.T@B.T),
68 |         (R_sqrt@Y, gam*np.eye(m), np.zeros((m,n)), np.zeros((m,w)), np.zeros((m,n))),
69 |         (Q_sqrt@S, np.zeros((n,m)), gam*np.eye(n), np.zeros((n,w)), np.zeros((n,n))),
70 |         (C@S + D@Y, np.zeros((w,m)), np.zeros((w,n)), cp.diag(lam), np.zeros((w,n))),
71 |         (A@S + B@Y, np.zeros((n,m)), np.zeros((n,n)), np.zeros((n,w)), S - G@cp.diag(lam)@G.T)
72 |     ))
73 | 
74 |     cons = [S >> 0, lam >= 1e-2, m1 >> 0, m2 >> 
0] 75 | 76 | try: 77 | prob = cp.Problem(cp.Minimize(gam), cons) 78 | prob.solve(solver=cp.MOSEK) 79 | if prob.status in ["infeasible", "unbounded"]: 80 | warnings.warn('Infeasible or unbounded SDP for some x. Falling back to K_init for that x.') 81 | K = np.nan * np.ones((m,n)) 82 | else: 83 | K = np.linalg.solve(S.value, Y.value.T).T 84 | except cp.SolverError: 85 | warnings.warn('Solver error for some x. Falling back to K_init for that x.') 86 | K = np.nan * np.ones((m,n)) 87 | 88 | return K 89 | 90 | 91 | def get_pldi_controller_gain(x_in, As, Bs, Q_sqrt, R_sqrt): 92 | n = As.shape[2] 93 | m = Bs.shape[2] 94 | L = As.shape[0] 95 | 96 | x = np.expand_dims(x_in, 1) 97 | 98 | S = cp.Variable((n,n), symmetric=True) 99 | Y = cp.Variable((m,n)) 100 | gam = cp.Variable() 101 | 102 | m1 = cp.bmat(( 103 | (np.expand_dims(np.array([1]),0), x.T), 104 | (x, S) 105 | )) 106 | 107 | m2s = [cp.bmat(( 108 | (S, S@As[i].T + Y.T@Bs[i].T, S@Q_sqrt, Y.T@R_sqrt), 109 | (As[i]@S + Bs[i]@Y, S, np.zeros((n,n)), np.zeros((n,m))), 110 | (Q_sqrt@S, np.zeros((n,n)), gam*np.eye(n), np.zeros((n,m))), 111 | (R_sqrt@Y, np.zeros((m,n)), np.zeros((m,n)), gam*np.eye(m)) 112 | )) for i in range(L)] 113 | 114 | cons = [S >> 0, m1 >> 0] + [m2 >> 0 for m2 in m2s] 115 | 116 | try: 117 | prob = cp.Problem(cp.Minimize(gam), cons) 118 | prob.solve(solver=cp.MOSEK) 119 | if prob.status in ["infeasible", "unbounded"]: 120 | warnings.warn('Infeasible or unbounded SDP for some x. Falling back to K_init for that x.') 121 | K = np.nan * np.ones((m,n)) 122 | else: 123 | K = np.linalg.solve(S.value, Y.value.T).T 124 | except cp.SolverError: 125 | warnings.warn('Solver error for some x. Falling back to K_init for that x.') 126 | K = np.nan * np.ones((m,n)) 127 | 128 | return K -------------------------------------------------------------------------------- /envs/rl_wrapper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | import gym 5 | from gym import spaces 6 | 7 | from constants import * 8 | 9 | 10 | class RLWrapper(gym.Env): 11 | 12 | def __init__(self, env, state_dim, action_dim, 13 | rmax=100., gamma=None, 14 | dt=0.05, step_type='euler', action_transform=None, 15 | num_envs=1, device=None, rarl=False, hinf_loss=False): 16 | self.env = env 17 | self.rmax = rmax 18 | self.dt = dt 19 | self.step_type = step_type 20 | self.action_transform = action_transform 21 | self.num_envs = num_envs 22 | self.device = device 23 | self.gamma = gamma 24 | self.epsilon = 1e-8 25 | self.cliprew = 10 26 | self.hinf_loss = hinf_loss 27 | 28 | self.state_dim = state_dim 29 | self.action_dim = action_dim 30 | 31 | self.max_action_val = 10. 32 | self.action_space = spaces.Box(low=-self.max_action_val, high=self.max_action_val, 33 | shape=(self.action_dim,), dtype=NUMPY_DTYPE) 34 | self.max_state_val = 100. 
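        # (Annotation:) the box below is not a hard state constraint; step()
        # marks an episode done once the state exits [-max_state_val, max_state_val].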
35 |         self.observation_space = spaces.Box(low=-self.max_state_val, high=self.max_state_val,
36 |                                             shape=(self.state_dim,), dtype=NUMPY_DTYPE)
37 |         self.observation_space_low = torch.tensor(self.observation_space.low, device=device)
38 |         self.observation_space_high = torch.tensor(self.observation_space.high, device=device)
39 | 
40 |         self.rarl = rarl
41 |         self.disturb_space = spaces.Box(low=0, high=1,
42 |                                         shape=(self.env.disturb_f.disturb_size,), dtype=NUMPY_DTYPE)
43 | 
44 |         self.x = None
45 |         self.episode_reward = None
46 |         self.episode_cost = None
47 |         self.episode_disturb_norm = None
48 |         self.episode_t = None
49 |         self.reset()
50 | 
51 |     def step(self, u, adversarial=False):
52 |         if adversarial:
53 |             self.env.adversarial_disturb_f.update(self.x)
54 | 
55 |         if self.rarl:
56 |             disturb = u[:, self.action_dim:]
57 |             u = u[:, :self.action_dim]
58 |             self.env.disturb_f.disturbance = disturb
59 |         else:
60 |             self.env.disturb_f.disturbance = None
61 | 
62 |         if self.action_transform is not None:
63 |             u = self.action_transform(u, self.x)
64 | 
65 |         self.x, cost = self.env.step(self.x, u, self.episode_t, self.dt, self.step_type, adversarial=adversarial)
66 |         self.x = self.x.detach()
67 |         cost = cost.detach()
68 |         if self.num_envs == 1:
69 |             self.x = self.x.flatten()
70 |         r = torch.clamp(self.rmax - (cost / self.dt), 0, self.rmax) / self.rmax
71 |         self.episode_reward += r
72 |         self.episode_cost += cost
73 |         if self.hinf_loss:
74 |             self.episode_disturb_norm += torch.norm(self.env.disturb, p=2, dim=1)
75 |         self.episode_t += 1
76 | 
77 |         done = torch.max(self.x <= self.observation_space_low, self.x >= self.observation_space_high)
78 |         if self.num_envs > 1:
79 |             done = done.any(dim=1)
80 |         else:
81 |             done = done.any()
82 | 
83 |         episode_cost = self.episode_cost / self.episode_disturb_norm if self.hinf_loss else self.episode_cost
84 |         return self.x, r, done, {'episode_reward': self.episode_reward, 'episode_cost': episode_cost}
85 | 
86 |     def reset(self, x0=None, index=None):
87 |         if x0 is None:
88 |             if self.num_envs == 1:
89 |                 self.x = self.env.gen_states(1, device=self.device)[0, :]
90 |             elif index is not None:
91 |                 self.x[index, :] = self.env.gen_states(1, device=self.device)[0, :]
92 |             else:
93 |                 self.x = self.env.gen_states(self.num_envs, device=self.device)
94 |         else:
95 |             self.x = x0
96 | 
97 |         if index is not None:
98 |             self.episode_reward[index] = 0
99 |             self.episode_cost[index] = 0
100 |             self.episode_disturb_norm[index] = 0
101 |             self.episode_t[index] = 0
102 |         else:
103 |             self.episode_reward = torch.zeros(self.num_envs, dtype=TORCH_DTYPE, device=self.device)
104 |             self.episode_cost = torch.zeros(self.num_envs, dtype=TORCH_DTYPE, device=self.device)
105 |             self.episode_disturb_norm = torch.zeros(self.num_envs, dtype=TORCH_DTYPE, device=self.device)
106 |             self.episode_t = torch.zeros(self.num_envs, dtype=torch.int64, device=self.device)
107 | 
108 |         return self.x
109 | 
110 |     def render(self, state_i=0, mode='human'):
111 |         if hasattr(self.env, 'render') and callable(getattr(self.env, 'render')):
112 |             return self.env.render(self.x[state_i, :])
113 |         else:
114 |             raise NotImplementedError
115 | 
116 |     def close(self):
117 |         if hasattr(self.env, 'close') and callable(getattr(self.env, 'close')):
118 |             return self.env.close()
119 |         else:
120 |             raise NotImplementedError
121 | 
122 | 
--------------------------------------------------------------------------------
/envs/microgrid.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | from envs import ode_env
5 | import 
disturb_models as dm 6 | from constants import * 7 | 8 | ''' 9 | Adapted from: 10 | Quang Linh Lam, Antoneta Iuliana Bratcu, Delphine Riu. Frequency robust control in stand-alone 11 | microgrids with PV sources : design and sensitivity analysis. Symposium de Génie Electrique, Jun 12 | 2016, Grenoble, France. ffhal-01361556 13 | ''' 14 | 15 | class MicrogridEnv(ode_env.NLDIEnv): 16 | 17 | def __init__(self, random_seed=None, device=None): 18 | if random_seed is not None: 19 | np.random.seed(random_seed) 20 | torch.manual_seed(random_seed+1) 21 | 22 | self.n = 3 23 | self.m = 2 24 | self.wp = 1 25 | self.wq = 1 26 | self.isD0 = True # TODO? 27 | 28 | # for per-unit normalization 29 | v_base = 585 30 | p_base = 1.0015 31 | i_base = p_base / v_base 32 | r_base = v_base / i_base 33 | 34 | omega_b = 3.5 # between 2.61 and 5.22 given in paper 35 | R_dc = 50 / r_base # based on 5-200 ohms range at https://doc.ingeniamc.com/venus/product-manual/installation-and-configuration/motor-output-wiring/shunt-braking-resistor 36 | C_dc = 500e-6 * r_base # based on 200-800 micro-farads given here: https://www.elmomc.com/wp-content/uploads/2019/08/Simple-Capacitor-white-paper.pdf 37 | # and also approx here: https://doc.ingeniamc.com/titan/manuals/titan-go-product-manual/wiring-and-connections/dc-bus-bulk-capacitance 38 | T_diesel = 0.01 # based on plausible value from https://core.ac.uk/download/pdf/52115363.pdf (TODO switch to p.u.?) 39 | s_diesel = 0.04 # arbitrary, based on plausible value from Wikipedia: https://en.wikipedia.org/wiki/Droop_speed_control 40 | H = 0.7305 # based on plausible value from https://core.ac.uk/download/pdf/52115363.pdf (TODO switch to p.u.?) 41 | D_load = 0.9 # arbitrary 42 | alpha_ce = 0.585 43 | beta_de = 0.4 44 | v_sce = 585 / v_base 45 | R_sc = 30e-3 / r_base # based on plausible internal supercapacitor resistence of 30 milli-ohms from https://en.wikipedia.org/wiki/Supercapacitor#Internal_resistance 46 | i_se = -2.5 / i_base 47 | 48 | self.A = 0.001 * torch.tensor(( 49 | (-omega_b/(R_dc*C_dc), 0, 0), 50 | (0, -1/T_diesel, -1/(T_diesel * s_diesel)), 51 | (0, 1/(2*H), -D_load/(2*H)) 52 | ), dtype=TORCH_DTYPE, device=device) 53 | 54 | self.B = torch.tensor(( 55 | (omega_b * alpha_ce / C_dc, -omega_b * beta_de / C_dc), 56 | (0, 0), 57 | ( (v_sce - 2*R_sc*i_se)/(2*H), 0) 58 | ), dtype=TORCH_DTYPE, device=device) 59 | 60 | self.G = torch.tensor(( 61 | (0, ), 62 | (0, ), 63 | (-1/(2*H), ) 64 | ), dtype=TORCH_DTYPE, device=device) 65 | 66 | # Capture some (arbitrary) dependence between voltage/freq variation and disturb 67 | # Note: Doesn't always solve, but solves for random seed 10 68 | # self.C = torch.tensor((-0.05, 0, 0.05), dtype=TORCH_DTYPE, device=device).unsqueeze(0) 69 | self.C = 5.0 * torch.tensor(np.random.randn(3), dtype=TORCH_DTYPE, device=device).unsqueeze(0) 70 | self.D = torch.zeros(self.wq, self.m, dtype=TORCH_DTYPE, device=device) 71 | 72 | # Objective: Assign weight 1 to entries in output vector y, 73 | # and small weight to other values 74 | self.Q = torch.tensor(( 75 | (1, 0, 0), 76 | (0, 0.1, 0), 77 | (0, 0, 1) 78 | ), dtype=TORCH_DTYPE, device=device) 79 | 80 | self.R = torch.tensor(( 81 | (0.1, 0), 82 | (0, 0.1) 83 | ), dtype=TORCH_DTYPE, device=device) 84 | 85 | self.disturb_f = dm.NLDIDisturbModel(self.C, self.D, self.n, self.m, self.wp) 86 | if device is not None: 87 | self.disturb_f.to(device=device, dtype=TORCH_DTYPE) 88 | 89 | self.adversarial_disturb_f = None 90 | 91 | def xdot_f(self, x, u, t): 92 | w = self.disturb_f(x, u, t) 93 | return x @ 
self.A.T + u @ self.B.T + w @ self.G.T 94 | 95 | def xdot_adversarial_f(self, x, u, t): 96 | if self.adversarial_disturb_f is None: 97 | raise ValueError('You must initialize adversarial_disturb_f before running in adversarial mode') 98 | w = self.adversarial_disturb_f(x, u, t) 99 | return x @ self.A.T + u @ self.B.T + w @ self.G.T 100 | 101 | def cost_f(self, x, u, t): 102 | return ((x @ self.Q) * x).sum(-1) + ((u @ self.R) * u).sum(-1) 103 | 104 | def get_nldi_linearization(self): 105 | return self.A, self.B, self.G, self.C, self.D, self.Q, self.R 106 | 107 | def gen_states(self, num_states, device=None): 108 | return torch.tensor(np.random.rand(num_states, self.n), device=device, dtype=TORCH_DTYPE) 109 | 110 | def __copy__(self): 111 | new_env = MicrogridEnv.__new__(MicrogridEnv) 112 | new_env.__dict__.update(self.__dict__) 113 | return new_env 114 | -------------------------------------------------------------------------------- /rl/ppo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import os 5 | 6 | from rl import utils 7 | 8 | 9 | class PPO(): 10 | def __init__(self, 11 | actor_critic, 12 | clip_param, 13 | ppo_epoch, 14 | num_mini_batch, 15 | value_loss_coef, 16 | entropy_coef, 17 | lr=None, 18 | eps=None, 19 | max_grad_norm=None, 20 | use_clipped_value_loss=True, 21 | use_linear_lr_decay=False): 22 | 23 | self.actor_critic = actor_critic 24 | 25 | self.clip_param = clip_param 26 | self.ppo_epoch = ppo_epoch 27 | self.num_mini_batch = num_mini_batch 28 | 29 | self.value_loss_coef = value_loss_coef 30 | self.entropy_coef = entropy_coef 31 | 32 | self.max_grad_norm = max_grad_norm 33 | self.use_clipped_value_loss = use_clipped_value_loss 34 | 35 | self.lr = lr 36 | self.use_linear_lr_decay = use_linear_lr_decay 37 | self.optimizer = optim.Adam(actor_critic.parameters(), lr=lr, eps=eps) 38 | 39 | def act(self, inputs, rnn_hxs, masks): 40 | with torch.no_grad(): 41 | _, action, _, _ = self.train_act(inputs, rnn_hxs, masks, deterministic=True) 42 | return action 43 | 44 | def train_act(self, inputs, rnn_hxs, masks, deterministic=False): 45 | return self.actor_critic.act(inputs, rnn_hxs, masks, deterministic=deterministic) 46 | 47 | def get_value(self, inputs, rnn_hxs, masks): 48 | return self.actor_critic.get_value(inputs, rnn_hxs, masks) 49 | 50 | def update(self, rollouts, step, total_steps): 51 | if self.use_linear_lr_decay: 52 | # decrease learning rate linearly 53 | utils.update_linear_schedule(self.optimizer, step, total_steps, self.lr) 54 | 55 | advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1] 56 | advantages = (advantages - advantages.mean()) / ( 57 | advantages.std() + 1e-5) 58 | 59 | value_loss_epoch = 0 60 | action_loss_epoch = 0 61 | dist_entropy_epoch = 0 62 | 63 | for e in range(self.ppo_epoch): 64 | if self.actor_critic.is_recurrent: 65 | data_generator = rollouts.recurrent_generator( 66 | advantages, self.num_mini_batch) 67 | else: 68 | data_generator = rollouts.feed_forward_generator( 69 | advantages, self.num_mini_batch) 70 | 71 | for sample in data_generator: 72 | obs_batch, recurrent_hidden_states_batch, actions_batch, \ 73 | value_preds_batch, return_batch, masks_batch, old_action_log_probs_batch, \ 74 | adv_targ = sample 75 | 76 | # Reshape to do in a single forward pass for all steps 77 | values, action_log_probs, dist_entropy, _ = self.actor_critic.evaluate_actions( 78 | obs_batch, recurrent_hidden_states_batch, masks_batch, 79 | 
actions_batch) 80 | 81 | ratio = torch.exp(action_log_probs - 82 | old_action_log_probs_batch) 83 | surr1 = ratio * adv_targ 84 | surr2 = torch.clamp(ratio, 1.0 - self.clip_param, 85 | 1.0 + self.clip_param) * adv_targ 86 | action_loss = -torch.min(surr1, surr2).mean() 87 | 88 | if self.use_clipped_value_loss: 89 | value_pred_clipped = value_preds_batch + \ 90 | (values - value_preds_batch).clamp(-self.clip_param, self.clip_param) 91 | value_losses = (values - return_batch).pow(2) 92 | value_losses_clipped = ( 93 | value_pred_clipped - return_batch).pow(2) 94 | value_loss = 0.5 * torch.max(value_losses, 95 | value_losses_clipped).mean() 96 | else: 97 | value_loss = 0.5 * (return_batch - values).pow(2).mean() 98 | 99 | self.optimizer.zero_grad() 100 | (value_loss * self.value_loss_coef + action_loss - 101 | dist_entropy * self.entropy_coef).backward() 102 | nn.utils.clip_grad_norm_(self.actor_critic.parameters(), 103 | self.max_grad_norm) 104 | self.optimizer.step() 105 | 106 | value_loss_epoch += value_loss.item() 107 | action_loss_epoch += action_loss.item() 108 | dist_entropy_epoch += dist_entropy.item() 109 | 110 | num_updates = self.ppo_epoch * self.num_mini_batch 111 | 112 | value_loss_epoch /= num_updates 113 | action_loss_epoch /= num_updates 114 | dist_entropy_epoch /= num_updates 115 | 116 | return value_loss_epoch, action_loss_epoch, dist_entropy_epoch 117 | 118 | def save(self, save_dir): 119 | torch.save(self.actor_critic.state_dict(), os.path.join(save_dir, 'ppo.pt')) 120 | 121 | def load(self, save_dir): 122 | self.actor_critic.load_state_dict(torch.load(os.path.join(save_dir, 'ppo.pt'))) 123 | -------------------------------------------------------------------------------- /envs/quadrotor_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from envs import ode_env 5 | import disturb_models as dm 6 | from constants import * 7 | 8 | import os 9 | 10 | 11 | class QuadrotorEnv(ode_env.NLDIEnv): 12 | 13 | def __init__(self, mass=0.8, moment_arm=0.01, inertia_roll=15.67e-3, random_seed=None, device=None): 14 | if random_seed is not None: 15 | np.random.seed(random_seed) 16 | torch.manual_seed(random_seed+1) 17 | 18 | self.device = device 19 | 20 | # guessing parameters: 21 | # mass 1 kg, moment of inertia for roll axis 0.0093 kg * m^2 22 | # http://www.diva-portal.org/smash/get/diva2:1020192/FULLTEXT02.pdf 23 | # OR mass 0.8 kg, moment of inertia for roll axis 15.67e-3 kg*m^2, arm length of vehicle=0.3m 24 | # https://www.researchgate.net/publication/283241371_Feedback_control_strategies_for_quadrotor-type_aerial_robots_a_survey 25 | 26 | self.g = 9.81 # gravitational acceleration in m/s^2 27 | 28 | self.mass = mass 29 | self.moment_arm = moment_arm 30 | self.inertia_roll = inertia_roll 31 | 32 | # Max and min values for the state: [x, z, roll, xdot, zdot, rolldot] 33 | self.x_max = torch.tensor([1.1, 1.1, 0.06, 0.5, 1.0, 0.8], dtype=TORCH_DTYPE, device=device) 34 | self.x_min = torch.tensor([-1.1, -1.1, -0.06, -0.5, -1.0, -0.8], dtype=TORCH_DTYPE, device=device) 35 | 36 | self.x_0_max = torch.tensor([1.0, 1.0, 0.05, 0.0, 0.0, 0.0], dtype=TORCH_DTYPE, device=device) 37 | self.x_0_min = torch.tensor([-1.0, -1.0, -0.05, -0.0, -0.0, -0.0], dtype=TORCH_DTYPE, device=device) 38 | 39 | self.B = torch.tensor([ 40 | [0, 0], 41 | [0, 0], 42 | [0, 0], 43 | [0, 0], 44 | [1/self.mass, 1/self.mass], 45 | [self.moment_arm/self.inertia_roll, -self.moment_arm/self.inertia_roll] 46 | ], dtype=TORCH_DTYPE, 
device=device) 47 | 48 | # TODO: hacky, assumes call from main.py in top level directory 49 | array_path = os.path.join('problem_gen', 'quadrotor') 50 | self.A = torch.tensor(np.load(os.path.join(array_path, 'A.npy')), dtype=TORCH_DTYPE, device=device) 51 | self.n, self.m = self.A.shape[0], self.B.shape[1] 52 | 53 | self.G_lin = torch.tensor(np.load(os.path.join(array_path, 'G.npy')), dtype=TORCH_DTYPE, device=device) 54 | self.C_lin = torch.tensor(np.load(os.path.join(array_path, 'C.npy')), dtype=TORCH_DTYPE, device=device) 55 | self.D_lin = torch.zeros(self.C_lin.shape[0], self.m, dtype=TORCH_DTYPE, device=device) 56 | 57 | disturb_n = 2 58 | self.G_disturb = 0.1 * torch.tensor(np.random.randn(self.n, disturb_n), dtype=TORCH_DTYPE, device=device) 59 | self.C_disturb = torch.tensor(0.1 * np.random.randn(disturb_n, self.n), dtype=TORCH_DTYPE, device=device) 60 | self.D_disturb = torch.zeros(disturb_n, self.m, dtype=TORCH_DTYPE, device=device) 61 | 62 | self.G = torch.cat([self.G_lin, self.G_disturb], dim=1) 63 | self.C = torch.cat([self.C_lin, self.C_disturb], dim=0) 64 | self.D = torch.cat([self.D_lin, self.D_disturb], dim=0) 65 | 66 | self.wp, self.wq = self.G.shape[1], self.C.shape[0] 67 | 68 | # TODO: have reasonable objective? 69 | Q = np.random.randn(self.n, self.n) 70 | Q = Q.T @ Q 71 | # Q = np.eye(self.n) 72 | self.Q = torch.tensor(Q, dtype=TORCH_DTYPE, device=device) 73 | 74 | R = np.random.randn(self.m, self.m) 75 | R = R.T @ R 76 | # R = np.eye(self.m) 77 | self.R = torch.tensor(R, dtype=TORCH_DTYPE, device=device) 78 | 79 | self.disturb_f = dm.NLDIDisturbModel(self.C_disturb, self.D_disturb, self.n, self.m, self.G_disturb.shape[1]) 80 | if device is not None: 81 | self.disturb_f.to(device=device, dtype=TORCH_DTYPE) 82 | 83 | self.adversarial_disturb_f = None 84 | 85 | def xdot_f(self, x, u, t): 86 | px, pz, phi, vx, vz, phidot = [x[:,i] for i in range(x.shape[1])] 87 | 88 | x_part = torch.stack([ 89 | vx*torch.cos(phi) - vz*torch.sin(phi), 90 | vx*torch.sin(phi) + vz*torch.cos(phi), 91 | phidot, 92 | vz*phidot - self.g*torch.sin(phi), 93 | -vx*phidot - self.g*torch.cos(phi) + self.g, # note: + g = center dynamics by assuming nominal policy [gm/2, gm/2] always applied 94 | torch.zeros(x.shape[0], device=self.device, dtype=TORCH_DTYPE) 95 | ]).T 96 | 97 | p_disturb = self.disturb_f(x, u, t) 98 | return x_part + u@self.B.T + p_disturb @ self.G_disturb.T 99 | 100 | def xdot_adversarial_f(self, x, u, t): 101 | if self.adversarial_disturb_f is None: 102 | raise ValueError('You must initialize adversarial_disturb_f before running in adversarial mode') 103 | p = self.adversarial_disturb_f(x, u, t) 104 | return x @ self.A.T + u @ self.B.T + p @ self.G.T 105 | 106 | def cost_f(self, x, u, t): 107 | return ((x @ self.Q) * x).sum(-1) + ((u @ self.R) * u).sum(-1) 108 | 109 | def get_nldi_linearization(self): 110 | return self.A, self.B, self.G, self.C, self.D, self.Q, self.R 111 | 112 | def gen_states(self, num_states, device=None): 113 | prop = torch.tensor(np.random.rand(num_states, self.n), device=device, dtype=TORCH_DTYPE) 114 | return self.x_0_max.detach()*prop + self.x_0_min.detach()*(1-prop) 115 | 116 | def __copy__(self): 117 | new_env = QuadrotorEnv.__new__(QuadrotorEnv) 118 | new_env.__dict__.update(self.__dict__) 119 | return new_env 120 | 121 | 122 | -------------------------------------------------------------------------------- /rl/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from collections 
import deque
4 | 
5 | from rl import utils
6 | from constants import *
7 | 
8 | 
9 | def train(agent, envs, rollouts, device, args, save_dir,
10 |           eval_envs=None, x_hold=None, x_test=None, save_extension=None,
11 |           num_episode_steps=100):
12 |     obs = envs.reset()
13 |     rollouts.obs[0].copy_(obs)
14 |     rollouts.to(device)
15 | 
16 |     episode_rewards = deque(maxlen=100)
17 | 
18 |     hold_costs = []
19 |     test_costs = []
20 |     adv_test_costs = []
21 | 
22 |     min_cost = np.inf
23 |     if save_extension is not None:
24 |         save_dir = os.path.join(save_dir, save_extension)
25 |     try:
26 |         os.makedirs(save_dir)
27 |     except OSError:
28 |         pass
29 |     agent.save(save_dir)
30 | 
31 |     start = time.time()
32 |     num_updates = int(args.num_env_steps) // num_episode_steps // args.num_processes
33 |     for j in range(num_updates):
34 |         obs = envs.reset()
35 | 
36 |         masks = torch.zeros((envs.num_envs, 1), dtype=TORCH_DTYPE, device=obs.device)
37 |         bad_masks = torch.ones((envs.num_envs, 1), dtype=TORCH_DTYPE, device=obs.device)
38 |         recurrent_hidden_states = torch.ones((envs.num_envs, 1), dtype=TORCH_DTYPE, device=obs.device)
39 |         rollouts.reset(obs, recurrent_hidden_states, masks, bad_masks)
40 | 
41 |         for step in range(num_episode_steps):
42 |             # Sample actions
43 |             with torch.no_grad():
44 |                 value, action, action_log_prob, recurrent_hidden_states = agent.train_act(
45 |                     rollouts.obs[step], rollouts.recurrent_hidden_states[step],
46 |                     rollouts.masks[step])
47 | 
48 |             # Observe reward and next obs
49 |             obs, reward, done, infos = envs.step(action)
50 | 
51 |             for i, d in enumerate(done):
52 |                 if d:
53 |                     episode_rewards.append(infos['episode_reward'][i].cpu().numpy())
54 |                     obs = envs.reset(index=i)
55 | 
56 |             # If done then clean the history of observations.
57 |             masks = torch.tensor([[0.0] if done_ else [1.0] for done_ in done], dtype=TORCH_DTYPE)
58 |             bad_masks = torch.ones((envs.num_envs, 1), dtype=TORCH_DTYPE)
59 |             rollouts.insert(obs, recurrent_hidden_states, action,
60 |                             action_log_prob, value, reward.unsqueeze(1), masks, bad_masks)
61 | 
62 |         with torch.no_grad():
63 |             next_value = agent.get_value(
64 |                 rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
65 |                 rollouts.masks[-1]).detach()
66 | 
67 |         rollouts.compute_returns(next_value, args.use_gae, args.gamma,
68 |                                  args.gae_lambda, args.use_proper_time_limits)
69 | 
70 |         value_loss, action_loss, dist_entropy = agent.update(rollouts, j, num_updates)
71 | 
72 |         rollouts.after_update()
73 | 
74 |         for i in range(envs.num_envs):
75 |             episode_rewards.append(infos['episode_reward'][i].cpu().numpy())
76 | 
77 |         if j % args.log_interval == 0 and len(episode_rewards) > 1:
78 |             total_num_steps = (j + 1) * args.num_processes * num_episode_steps
79 |             end = time.time()
80 |             print(
81 |                 "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, dist entropy {:.3f}, value loss {:.3f}, action loss {:.3f}\n"
82 |                 .format(j, total_num_steps,
83 |                         int(total_num_steps / (end - start)),
84 |                         len(episode_rewards), np.mean(episode_rewards),
85 |                         np.median(episode_rewards), np.min(episode_rewards),
86 |                         np.max(episode_rewards), dist_entropy, value_loss,
87 |                         action_loss))
88 | 
89 |         if eval_envs is not None \
90 |                 and (args.eval_interval is not None and j % args.eval_interval == 0):
91 |             hold_cost = evaluate(agent, eval_envs, device, num_episode_steps, x_test=x_hold, set_name='x_hold')
92 |             hold_costs.append(hold_cost.cpu().numpy())
93 | 
94 |             test_cost = evaluate(agent, eval_envs, device, num_episode_steps, x_test=x_test, set_name='x_test')
95 |             test_costs.append(test_cost.cpu().numpy())
96 | 
97 |             eval_envs.env.adversarial_disturb_f.reset()
98 |             adv_test_cost = evaluate(agent, eval_envs, device, num_episode_steps, x_test=x_test, set_name='adv_x_test', adversarial=True)
99 |             adv_test_costs.append(adv_test_cost.detach().cpu().numpy())
100 | 
101 |             if hold_cost < min_cost:
102 |                 min_cost = hold_cost
103 |                 agent.save(save_dir)
104 | 
105 |     agent.load(save_dir)
106 |     return hold_costs, test_costs, adv_test_costs
107 | 
108 | 
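# A rough sketch (annotation, not part of the repo) of the train/evaluate
# contract, with assumed arguments:
#
#     args = get_args()  # rl/arguments.py
#     hold, test, adv = train(agent, envs, rollouts, device, args, save_dir,
#                             eval_envs=eval_envs, x_hold=x_hold, x_test=x_test)
#     # train() keeps the checkpoint with the lowest hold-set cost and
#     # reloads it before returning the three cost traces.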
109 | def evaluate(agent, eval_env, device, max_num_steps, x_test=None, set_name='', adversarial=False):
110 |     if x_test is not None:
111 |         obs = eval_env.reset(x0=x_test)
112 |     else:
113 |         obs = eval_env.reset()
114 | 
115 |     for step in range(max_num_steps):
116 |         action = agent.act(obs, None, None)
117 | 
118 |         # Observe reward and next obs
119 |         obs, _, done, info = eval_env.step(action, adversarial=adversarial)
120 | 
121 |     episode_rewards = info['episode_reward']
122 |     episode_costs = info['episode_cost']
123 | 
124 |     print("    Evaluating {} using {} episodes: mean/median cost {:.3f}/{:.3f}, min/max cost {:.3f}/{:.3f}, mean/median reward {:.3f}/{:.3f}, min/max reward {:.3f}/{:.3f}\n".format(
125 |         set_name, episode_costs.shape[0],
126 |         torch.mean(episode_costs), torch.median(episode_costs),
127 |         torch.min(episode_costs), torch.max(episode_costs),
128 |         torch.mean(episode_rewards), torch.median(episode_rewards),
129 |         torch.min(episode_rewards), torch.max(episode_rewards)))
130 | 
131 |     return torch.mean(episode_costs)
132 | 
--------------------------------------------------------------------------------
/problem_gen/quadrotor/linearize_dynamics.py:
--------------------------------------------------------------------------------
1 | 
2 | import numpy as np
3 | import cvxpy as cp
4 | import scipy.optimize as sco
5 | import scipy.linalg as sla
6 | import itertools
7 | import torch
8 | 
9 | 
10 | def quadrotor_jacobian(x):
11 |     px, pz, phi, vx, vz, phidot = x
12 |     g = 9.81
13 | 
14 |     jac = np.array([
15 |         [0, 0, -vx*np.sin(phi)-vz*np.cos(phi), np.cos(phi), -np.sin(phi), 0],
16 |         [0, 0, vx*np.cos(phi)-vz*np.sin(phi), np.sin(phi), np.cos(phi), 0],
17 |         [0, 0, 0, 0, 0, 1],
18 |         [0, 0, -g*np.cos(phi), 0, phidot, vz],
19 |         [0, 0, g*np.sin(phi), -phidot, 0, -vx],
20 |         [0, 0, 0, 0, 0, 0]
21 |     ])
22 | 
23 |     return jac
24 | 
25 | def calc_max_jac(x_min, x_max):
26 |     px_max, pz_max, phi_max, vx_max, vz_max, phidot_max = x_max
27 |     px_min, pz_min, phi_min, vx_min, vz_min, phidot_min = x_min
28 |     (sinphi_max, cosphi_max), (sinphi_min, cosphi_min) = get_sinusoid_extrema(phi_min, phi_max)
29 | 
30 |     g = 9.81
31 | 
32 |     jac = np.array([
33 |         [0, 0, get_max_value(lambda x: -x[3]*np.sin(x[2])-x[4]*np.cos(x[2]), x_min, x_max), cosphi_max, -sinphi_min, 0],
34 |         [0, 0, get_max_value(lambda x: x[3]*np.cos(x[2])-x[4]*np.sin(x[2]), x_min, x_max), sinphi_max, cosphi_max, 0],
35 |         [0, 0, 0, 0, 0, 1],
36 |         [0, 0, -g*cosphi_min, 0, phidot_max, vz_max],
37 |         [0, 0, g*sinphi_max, -phidot_min, 0, -vx_min],
38 |         [0, 0, 0, 0, 0, 0]
39 |     ])
40 | 
41 |     return jac
42 | 
43 | def calc_min_jac(x_min, x_max):
44 |     px_max, pz_max, phi_max, vx_max, vz_max, phidot_max = x_max
45 |     px_min, pz_min, phi_min, vx_min, vz_min, phidot_min = x_min
46 |     (sinphi_max, cosphi_max), (sinphi_min, cosphi_min) = get_sinusoid_extrema(phi_min, phi_max)
47 | 
48 |     g = 9.81
49 | 
50 |     jac = np.array([
51 |         [0, 0, get_min_value(lambda x: -x[3]*np.sin(x[2])-x[4]*np.cos(x[2]), x_min, x_max), cosphi_min, -sinphi_max, 0],
52 |         [0, 0, get_min_value(lambda x: x[3]*np.cos(x[2])-x[4]*np.sin(x[2]), x_min, x_max), sinphi_min, cosphi_min, 0],
53 |         [0, 0, 0, 0, 
0, 1], 54 | [0, 0, -g*cosphi_max, 0, phidot_min, vz_min], 55 | [0, 0, g*sinphi_min, -phidot_max, 0, -vx_max], 56 | [0, 0, 0, 0, 0, 0] 57 | ]) 58 | 59 | return jac 60 | 61 | 62 | def is_in(val, v_min, v_max): 63 | return (val >= v_min) and (val <= v_max) 64 | 65 | def get_max_value(fun, x_min, x_max): 66 | res = sco.minimize(lambda x: -fun(x), (x_max + x_min)/2, 67 | constraints=( 68 | {'type': 'ineq', 'fun': lambda x: x - x_min}, 69 | {'type': 'ineq', 'fun': lambda x: -x + x_max})) 70 | return -res['fun'] 71 | 72 | def get_min_value(fun, x_min, x_max): 73 | res = sco.minimize(fun, (x_max + x_min)/2, 74 | constraints=( 75 | {'type': 'ineq', 'fun': lambda x: x - x_min}, 76 | {'type': 'ineq', 'fun': lambda x: -x + x_max})) 77 | return res['fun'] 78 | 79 | def get_sinusoid_extrema(phi_min, phi_max): 80 | cosphi_max = 1 if is_in(0, phi_min, phi_max) \ 81 | else max(np.cos(phi_min), np.cos(phi_max)) 82 | cosphi_min = -1 if is_in(np.pi, phi_min, phi_max) or is_in(-np.pi, phi_min, phi_max) \ 83 | else min(np.cos(phi_min), np.cos(phi_max)) 84 | sinphi_max = 1 if is_in(np.pi/2, phi_min, phi_max) \ 85 | else max(np.sin(phi_min), np.sin(phi_max)) 86 | sinphi_min = -1 if is_in(-np.pi/2, phi_min, phi_max) \ 87 | else min(np.sin(phi_min), np.sin(phi_max)) 88 | 89 | return (sinphi_max, cosphi_max), (sinphi_min, cosphi_min) 90 | 91 | 92 | def xdot_uncontrolled(x): 93 | px, pz, phi, vx, vz, phidot = [x[:,i] for i in range(x.shape[1])] 94 | g = 9.81 95 | 96 | x_part = torch.stack([ 97 | vx*torch.cos(phi) - vz*torch.sin(phi), 98 | vx*torch.sin(phi) + vz*torch.cos(phi), 99 | phidot, 100 | vz*phidot - g*torch.sin(phi), 101 | -vx*phidot - g*torch.cos(phi) + g, 102 | torch.zeros(x.shape[0]) 103 | ]).T 104 | 105 | return x_part.numpy() 106 | 107 | 108 | def main(): 109 | n = 6 110 | x_max = np.array([6, 6, np.pi/16, 0.25, 0.25, np.pi/32]) 111 | x_min = np.array([-6, -6, -np.pi/16, -0.25, -0.25, -np.pi/32]) 112 | 113 | max_jac = calc_max_jac(x_min, x_max) 114 | min_jac = calc_min_jac(x_min, x_max) 115 | 116 | print('constructing polytope') 117 | # construct polytope 118 | non_const = (max_jac != min_jac) 119 | # Aks = iter([np.array(p) for p in itertools.product(*zip(max_jac[non_const],min_jac[non_const]))]) 120 | Aks_nonconst = [np.array(p) for p in itertools.product(*zip(max_jac[non_const],min_jac[non_const]))] 121 | Aks = [max_jac.copy() for i in range(len(Aks_nonconst))] 122 | for i in range(len(Aks)): 123 | np.putmask(Aks[i], non_const, Aks_nonconst[i]) 124 | 125 | print('constructing problem') 126 | V = cp.Variable((n,n), symmetric=True) 127 | W = cp.Variable((n,n), symmetric=True) 128 | A = quadrotor_jacobian(np.zeros(n)) 129 | 130 | obj = cp.trace(V) + cp.trace(W) 131 | cons = [cp.bmat([[V, (Ak-A).T], 132 | [Ak-A, W]]) >> 0 \ 133 | for Ak in Aks] 134 | cons += [W >> 0] 135 | 136 | prob = cp.Problem(cp.Minimize(obj), cons) 137 | 138 | print('solving SDP') 139 | prob.solve(solver=cp.MOSEK, verbose=True) 140 | 141 | # TODO: figure out best way to do this 142 | C = np.linalg.cholesky((V.value).T) 143 | G = np.linalg.cholesky(W.value) 144 | # C = sla.sqrtm(V.value) 145 | # G = sla.sqrtm(W.value) 146 | 147 | # Check correctness 148 | prop = np.random.random((200, n)) 149 | rand_xs = x_max*prop + x_min*(1-prop) 150 | fx = xdot_uncontrolled(torch.Tensor(rand_xs)) 151 | # print(np.linalg.norm((fx - rand_xs@A.T)@np.linalg.inv(G).T, axis=1) <= np.linalg.norm(rand_xs@C.T, axis=1)) 152 | print((np.linalg.norm((fx - rand_xs@A.T)@np.linalg.inv(G).T, axis=1) <= np.linalg.norm(rand_xs@C.T, axis=1)).all()) 153 | 154 | ratio 
= np.linalg.norm(rand_xs@C.T, axis=1)/np.linalg.norm((fx - rand_xs@A.T)@np.linalg.inv(G).T, axis=1) 155 | print(ratio.max()) 156 | print(ratio.mean()) 157 | print(np.median(ratio)) 158 | 159 | # Save matrices 160 | np.save('A.npy', A) 161 | np.save('G.npy', G) 162 | np.save('C.npy', C) 163 | 164 | if __name__ == '__main__': 165 | main() 166 | 167 | -------------------------------------------------------------------------------- /rl/rarl_ppo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import os 5 | 6 | from rl import utils 7 | 8 | 9 | class RARLPPO(): 10 | def __init__(self, 11 | protagonist_ac, 12 | adversarial_ac, 13 | clip_param, 14 | ppo_epoch, 15 | num_mini_batch, 16 | value_loss_coef, 17 | entropy_coef, 18 | lr=None, 19 | eps=None, 20 | max_grad_norm=None, 21 | use_clipped_value_loss=True, 22 | use_linear_lr_decay=False): 23 | 24 | self.protagonist_ac = protagonist_ac 25 | self.adversarial_ac = adversarial_ac 26 | 27 | self.clip_param = clip_param 28 | self.ppo_epoch = ppo_epoch 29 | self.num_mini_batch = num_mini_batch 30 | 31 | self.value_loss_coef = value_loss_coef 32 | self.entropy_coef = entropy_coef 33 | 34 | self.max_grad_norm = max_grad_norm 35 | self.use_clipped_value_loss = use_clipped_value_loss 36 | 37 | self.lr = lr 38 | self.use_linear_lr_decay = use_linear_lr_decay 39 | self.protagonist_optimizer = optim.Adam(protagonist_ac.parameters(), lr=lr, eps=eps) 40 | self.adversarial_optimizer = optim.Adam(adversarial_ac.parameters(), lr=lr, eps=eps) 41 | 42 | self.num_updates = 0 43 | 44 | def act(self, inputs, rnn_hxs, masks): 45 | with torch.no_grad(): 46 | _, action, _, _ = self.protagonist_ac.act(inputs, rnn_hxs, masks, deterministic=True) 47 | return action 48 | 49 | def train_act(self, inputs, rnn_hxs, masks, deterministic=False): 50 | value, action, action_log_probs, rnn_hxs =\ 51 | self.protagonist_ac.act(inputs, rnn_hxs, masks, deterministic=deterministic) 52 | adv_value, adv_action, adv_action_log_probs, adv_rnn_hxs =\ 53 | self.adversarial_ac.act(inputs, rnn_hxs, masks, deterministic=deterministic) 54 | 55 | return torch.cat([value, adv_value], dim=-1), \ 56 | torch.cat([action, adv_action], dim=-1), \ 57 | torch.cat([action_log_probs, adv_action_log_probs], dim=-1), \ 58 | None 59 | 60 | def get_value(self, inputs, rnn_hxs, masks): 61 | value = self.protagonist_ac.get_value(inputs, rnn_hxs, masks) 62 | adv_value = self.adversarial_ac.get_value(inputs, rnn_hxs, masks)  # query the adversary's own critic 63 | return torch.cat([value, adv_value], dim=-1) 64 | 65 | def update(self, rollouts, step, total_steps): 66 | if self.use_linear_lr_decay: 67 | # decrease learning rate linearly 68 | utils.update_linear_schedule(self.protagonist_optimizer, step, total_steps, self.lr) 69 | utils.update_linear_schedule(self.adversarial_optimizer, step, total_steps, self.lr) 70 | 71 | advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1] 72 | advantages = (advantages - advantages.mean(dim=[0, 1], keepdims=True)) / ( 73 | advantages.std(dim=[0, 1], keepdims=True) + 1e-5) 74 | 75 | value_loss_epoch = 0 76 | action_loss_epoch = 0 77 | dist_entropy_epoch = 0 78 | 79 | switch_freq = 10 80 | adversarial_update = self.num_updates % (2 * switch_freq) >= switch_freq 81 | 82 | for e in range(self.ppo_epoch): 83 | data_generator = rollouts.feed_forward_generator(advantages, self.num_mini_batch) 84 | 85 | for sample in data_generator: 86 | obs_batch, recurrent_hidden_states_batch, full_actions_batch, \ 87 
| value_preds_batch, return_batch, masks_batch, old_action_log_probs_batch, \ 88 | adv_targ = sample 89 | 90 | actions_batch = full_actions_batch[:, :self.protagonist_ac.num_outputs] 91 | adv_actions_batch = full_actions_batch[:, self.protagonist_ac.num_outputs:] 92 | 93 | if adversarial_update: 94 | values, action_log_probs, dist_entropy, _ = self.adversarial_ac.evaluate_actions( 95 | obs_batch, recurrent_hidden_states_batch[:, self.protagonist_ac.num_outputs:], masks_batch, 96 | adv_actions_batch) 97 | 98 | ratio = torch.exp(action_log_probs - old_action_log_probs_batch[:, 1].unsqueeze(-1)) 99 | value_preds_batch = value_preds_batch[:, 1].unsqueeze(-1) 100 | return_batch = return_batch[:, 1].unsqueeze(-1) 101 | adv_targ = adv_targ[:, 1].unsqueeze(-1) 102 | else: 103 | values, action_log_probs, dist_entropy, _ = self.protagonist_ac.evaluate_actions( 104 | obs_batch, recurrent_hidden_states_batch[:, :self.protagonist_ac.num_outputs], masks_batch, 105 | actions_batch) 106 | 107 | ratio = torch.exp(action_log_probs - old_action_log_probs_batch[:, 0].unsqueeze(-1)) 108 | value_preds_batch = value_preds_batch[:, 0].unsqueeze(-1) 109 | return_batch = return_batch[:, 0].unsqueeze(-1) 110 | adv_targ = adv_targ[:, 0].unsqueeze(-1) 111 | 112 | surr1 = ratio * adv_targ 113 | surr2 = torch.clamp(ratio, 1.0 - self.clip_param, 114 | 1.0 + self.clip_param) * adv_targ 115 | action_loss = -torch.min(surr1, surr2).mean() 116 | if adversarial_update: 117 | action_loss = -action_loss 118 | 119 | if self.use_clipped_value_loss: 120 | value_pred_clipped = value_preds_batch + \ 121 | (values - value_preds_batch).clamp(-self.clip_param, self.clip_param) 122 | value_losses = (values - return_batch).pow(2) 123 | value_losses_clipped = ( 124 | value_pred_clipped - return_batch).pow(2) 125 | value_loss = 0.5 * torch.max(value_losses, 126 | value_losses_clipped).mean() 127 | else: 128 | value_loss = 0.5 * (return_batch - values).pow(2).mean() 129 | 130 | optimizer = self.adversarial_optimizer if adversarial_update else self.protagonist_optimizer 131 | params = self.adversarial_ac.parameters() if adversarial_update else self.protagonist_ac.parameters() 132 | optimizer.zero_grad() 133 | (value_loss * self.value_loss_coef + action_loss - 134 | dist_entropy * self.entropy_coef).backward() 135 | nn.utils.clip_grad_norm_(params, self.max_grad_norm) 136 | optimizer.step() 137 | 138 | value_loss_epoch += value_loss.item() 139 | action_loss_epoch += action_loss.item() 140 | dist_entropy_epoch += dist_entropy.item() 141 | 142 | num_updates = self.ppo_epoch * self.num_mini_batch 143 | 144 | value_loss_epoch /= num_updates 145 | action_loss_epoch /= num_updates 146 | dist_entropy_epoch /= num_updates 147 | 148 | self.num_updates += 1 149 | 150 | return value_loss_epoch, action_loss_epoch, dist_entropy_epoch 151 | 152 | def save(self, save_dir): 153 | torch.save(self.protagonist_ac.state_dict(), os.path.join(save_dir, 'rarl_ppo.pt')) 154 | torch.save(self.adversarial_ac.state_dict(), os.path.join(save_dir, 'rarl_ppo_adversary.pt')) 155 | 156 | def load(self, save_dir): 157 | self.protagonist_ac.load_state_dict(torch.load(os.path.join(save_dir, 'rarl_ppo.pt'))) 158 | self.adversarial_ac.load_state_dict(torch.load(os.path.join(save_dir, 'rarl_ppo_adversary.pt'))) 159 | -------------------------------------------------------------------------------- /envs/cartpole.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | 5 | from envs import ode_env 
6 | import disturb_models as dm 7 | from constants import * 8 | 9 | 10 | class CartPoleEnv(ode_env.NLDIEnv): 11 | 12 | def __init__(self, l=1, m_cart=1, m_pole=1, g=9.81, Q=None, R=None, random_seed=None, device=None): 13 | if random_seed is not None: 14 | np.random.seed(random_seed) 15 | torch.manual_seed(random_seed+1) 16 | 17 | self.l = l 18 | self.m_cart = m_cart 19 | self.m_pole = m_pole 20 | self.g = g 21 | 22 | self.n, self.m, = 4, 1 23 | 24 | # TODO: have reasonable objective? 25 | self.Q = Q 26 | self.R = R 27 | if Q is None: 28 | Q = np.random.randn(self.n, self.n) 29 | Q = Q.T @ Q 30 | # Q = np.eye(self.n) 31 | self.Q = torch.tensor(Q, dtype=TORCH_DTYPE, device=device) 32 | if R is None: 33 | R = np.random.randn(self.m, self.m) 34 | R = R.T @ R 35 | # R = np.eye(self.m) 36 | self.R = torch.tensor(R, dtype=TORCH_DTYPE, device=device) 37 | 38 | # TODO: hacky, assumes call from main.py in top level directory 39 | array_path = os.path.join('problem_gen', 'cartpole') 40 | self.A = torch.tensor(np.load(os.path.join(array_path, 'A.npy')), dtype=TORCH_DTYPE, device=device) 41 | self.B = torch.tensor(np.load(os.path.join(array_path, 'B.npy')), dtype=TORCH_DTYPE, device=device) 42 | self.G_lin = torch.tensor(np.load(os.path.join(array_path, 'G.npy')), dtype=TORCH_DTYPE, device=device) 43 | self.C_lin = torch.tensor(np.load(os.path.join(array_path, 'C.npy')), dtype=TORCH_DTYPE, device=device) 44 | self.D_lin = torch.tensor(np.load(os.path.join(array_path, 'D.npy')), dtype=TORCH_DTYPE, device=device) 45 | 46 | disturb_n = 2 47 | self.G_disturb = torch.tensor(np.random.randn(self.n, disturb_n), dtype=TORCH_DTYPE, device=device) 48 | self.C_disturb = torch.tensor(0.1 * np.random.randn(disturb_n, self.n), dtype=TORCH_DTYPE, device=device) 49 | self.D_disturb = torch.tensor(0.001 * np.random.randn(disturb_n, self.m), dtype=TORCH_DTYPE, device=device) 50 | 51 | self.G = torch.cat([self.G_lin, self.G_disturb], dim=1) 52 | self.C = torch.cat([self.C_lin, self.C_disturb], dim=0) 53 | self.D = torch.cat([self.D_lin, self.D_disturb], dim=0) 54 | 55 | self.wp, self.wq = self.G.shape[1], self.C.shape[0] 56 | 57 | self.disturb_f = dm.NLDIDisturbModel(self.C_disturb, self.D_disturb, self.n, self.m, self.G_disturb.shape[1]) 58 | if device is not None: 59 | self.disturb_f.to(device=device, dtype=TORCH_DTYPE) 60 | 61 | self.adversarial_disturb_f = None 62 | 63 | # Max and min values for state and action: [x, xdot, theta, thetadot, u] 64 | self.yumax = torch.tensor([1.2, 1.0, 0.1, 1.0, 10], dtype=TORCH_DTYPE, device=device) 65 | self.yumin = torch.tensor([-1.2, -1.0, -0.1, -1.0, -10], dtype=TORCH_DTYPE, device=device) 66 | 67 | self.y_0_max = torch.tensor([1.0, 0.0, 0.1, 0.0], dtype=TORCH_DTYPE, device=device) 68 | self.y_0_min = torch.tensor([-1.0, -0.0, -0.1, -0.0], dtype=TORCH_DTYPE, device=device) 69 | 70 | self.viewer = None 71 | 72 | # Keeping external interface, but renaming internally 73 | def xdot_f(self, state, u_in, t): 74 | # x = state[:, 0] 75 | x_dot = state[:, 1] 76 | theta = state[:, 2] 77 | theta_dot = state[:, 3] 78 | 79 | # limit action magnitude 80 | if self.m == 1: 81 | u = torch.clamp(u_in, self.yumin[-1], self.yumax[-1]).squeeze(1) 82 | else: 83 | raise NotImplementedError() 84 | 85 | sin_theta = torch.sin(theta) 86 | cos_theta = torch.cos(theta) 87 | temp = 1/(self.m_cart + self.m_pole * (sin_theta * sin_theta)) 88 | x_ddot = temp * (u + self.m_pole * sin_theta * (self.l * (theta_dot**2) 89 | - self.g * cos_theta)) 90 | theta_ddot = -(1/self.l) * temp * (u * cos_theta 91 | + 
self.m_pole * self.l * (theta_dot**2) * cos_theta * sin_theta 92 | - (self.m_cart + self.m_pole) * self.g * sin_theta) 93 | 94 | return torch.stack([x_dot, x_ddot, theta_dot, theta_ddot]).T 95 | 96 | def xdot_adversarial_f(self, x, u, t): 97 | if self.adversarial_disturb_f is None: 98 | raise ValueError('You must initialize adversarial_disturb_f before running in adversarial mode') 99 | 100 | # # limit action magnitude 101 | # if self.m == 1: 102 | # u = torch.clamp(u_in, self.yumin[-1], self.yumax[-1]).squeeze(1) 103 | # else: 104 | # raise NotImplementedError() 105 | 106 | p = self.adversarial_disturb_f(x, u, t) 107 | return x @ self.A.T + u @ self.B.T + p @ self.G.T 108 | 109 | def cost_f(self, x, u, t): 110 | return ((x @ self.Q) * x).sum(-1) + ((u @ self.R) * u).sum(-1) 111 | 112 | def get_nldi_linearization(self): 113 | return self.A, self.B, self.G, self.C, self.D, self.Q, self.R 114 | 115 | def gen_states(self, num_states, device=None): 116 | prop = torch.tensor(np.random.rand(num_states, self.n), device=device, dtype=TORCH_DTYPE) 117 | return self.y_0_max[:self.n].detach()*prop + self.y_0_min[:self.n].detach()*(1-prop) 118 | 119 | def __copy__(self): 120 | new_env = CartPoleEnv.__new__(CartPoleEnv) 121 | new_env.__dict__.update(self.__dict__) 122 | return new_env 123 | 124 | # Copied from Open AI gym: https://github.com/openai/gym/blob/master/gym/envs/classic_control/cartpole.py 125 | def render(self, state, mode='human'): 126 | screen_width = 600 127 | screen_height = 400 128 | 129 | world_width = 10 130 | scale = screen_width / world_width 131 | carty = 100 # TOP OF CART 132 | polewidth = 10.0 133 | polelen = scale * (2 * self.l) 134 | cartwidth = 50.0 135 | cartheight = 30.0 136 | 137 | if self.viewer is None: 138 | from gym.envs.classic_control import rendering 139 | self.viewer = rendering.Viewer(screen_width, screen_height) 140 | l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2 141 | axleoffset = cartheight / 4.0 142 | cart = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)]) 143 | self.carttrans = rendering.Transform() 144 | cart.add_attr(self.carttrans) 145 | self.viewer.add_geom(cart) 146 | l, r, t, b = -polewidth / 2, polewidth / 2, polelen - polewidth / 2, -polewidth / 2 147 | pole = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)]) 148 | pole.set_color(.8, .6, .4) 149 | self.poletrans = rendering.Transform(translation=(0, axleoffset)) 150 | pole.add_attr(self.poletrans) 151 | pole.add_attr(self.carttrans) 152 | self.viewer.add_geom(pole) 153 | self.axle = rendering.make_circle(polewidth / 2) 154 | self.axle.add_attr(self.poletrans) 155 | self.axle.add_attr(self.carttrans) 156 | self.axle.set_color(.5, .5, .8) 157 | self.viewer.add_geom(self.axle) 158 | self.track = rendering.Line((0, carty), (screen_width, carty)) 159 | self.track.set_color(0, 0, 0) 160 | self.viewer.add_geom(self.track) 161 | 162 | self._pole_geom = pole 163 | 164 | # Edit the pole polygon vertex 165 | pole = self._pole_geom 166 | l, r, t, b = -polewidth / 2, polewidth / 2, polelen - polewidth / 2, -polewidth / 2 167 | pole.v = [(l, b), (l, t), (r, t), (r, b)] 168 | 169 | cartx = state[0] * scale + screen_width / 2.0 # MIDDLE OF CART 170 | self.carttrans.set_translation(cartx, carty) 171 | self.poletrans.set_rotation(-state[2]) 172 | 173 | return self.viewer.render(return_rgb_array=mode == 'rgb_array') 174 | 175 | def close(self): 176 | if self.viewer: 177 | self.viewer.close() 178 | self.viewer = None 179 | 
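# A minimal usage sketch, assuming the script is run from the repository root (so the
# problem_gen/cartpole/*.npy arrays load) and using a plain explicit-Euler rollout with a
# zero-input placeholder policy -- the actual experiments integrate via ode_env with RK4.
if __name__ == '__main__':
    env = CartPoleEnv(random_seed=0)
    dt, num_steps = 0.05, 100
    x = env.gen_states(num_states=4)  # batch of random initial states
    total_cost = torch.zeros(x.shape[0], dtype=TORCH_DTYPE)
    for t in range(num_steps):
        u = torch.zeros(x.shape[0], env.m, dtype=TORCH_DTYPE)  # "do nothing" control input
        total_cost += env.cost_f(x, u, t) * dt  # accumulate the quadratic stage cost
        x = x + env.xdot_f(x, u, t) * dt  # explicit Euler step of the nonlinear dynamics
    print('mean rollout cost:', total_cost.mean().item())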
-------------------------------------------------------------------------------- /disturb_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from scipy.stats import norm 5 | 6 | from constants import * 7 | 8 | 9 | class NLDIDisturbModel(nn.Module): 10 | def __init__(self, C, D, n, m, wp): 11 | super().__init__() 12 | self.C = C 13 | self.D = D 14 | self.net = nn.Sequential(nn.Linear(n + m, 50), nn.Sigmoid(), 15 | nn.Linear(50, 50), nn.Sigmoid()) 16 | self.disturb_layer = nn.Linear(50, wp) 17 | self.magnitude_layer = nn.Sequential(nn.Linear(50, 1), nn.Tanh()) 18 | list(self.magnitude_layer.parameters())[-1].data *= 10 19 | 20 | self.disturb_size = wp 21 | self.disturbance = None 22 | 23 | def forward(self, x, u, t): 24 | if self.disturbance is None: 25 | y = self.net(torch.cat((x, u), dim=1)) 26 | disturb = self.disturb_layer(y) 27 | magnitude = self.magnitude_layer(y) 28 | else: 29 | disturb = self.disturbance 30 | magnitude = 1 31 | disturb_norm = torch.norm(disturb, dim=1) 32 | max_norm = torch.norm(x @ self.C.T + u @ self.D.T, dim=1) 33 | p = (disturb / disturb_norm.unsqueeze(1)) * max_norm.unsqueeze(1) * magnitude 34 | return p 35 | 36 | 37 | class MultiNLDIDisturbModel(nn.Module): 38 | def __init__(self, bs, C, D, n, m, wp): 39 | super().__init__() 40 | self.C = C 41 | self.D = D 42 | self.bs = bs 43 | self.net = nn.Sequential(nn.Linear(self.bs * (n + m), 50), nn.Sigmoid(), 44 | nn.Linear(50, 50), nn.Sigmoid(), 45 | nn.Linear(50, self.bs * wp)) 46 | 47 | def forward(self, x, u, t): 48 | disturb = self.net(torch.cat((x, u), dim=1).reshape([1, -1])).reshape([self.bs, -1]) 49 | disturb_norm = torch.norm(disturb, dim=1) 50 | max_norm = torch.norm(x @ self.C.T + u @ self.D.T, dim=1) 51 | p = (disturb / disturb_norm.unsqueeze(1)) * max_norm.unsqueeze(1) 52 | return p 53 | 54 | def reset(self): 55 | def weight_reset(m): 56 | if isinstance(m, nn.Linear): 57 | m.reset_parameters() 58 | 59 | self.net.apply(weight_reset) 60 | 61 | 62 | class PLDIDisturbModel(nn.Module): 63 | def __init__(self, n, m, L): 64 | super().__init__() 65 | self.net = nn.Sequential(nn.Linear(n + m, 50), nn.ReLU(), 66 | nn.Linear(50, 50), nn.ReLU(), 67 | nn.Linear(50, L), nn.Softmax(1)) 68 | 69 | self.disturb_size = L 70 | self.disturbance = None 71 | 72 | def forward(self, x, u, t): 73 | if self.disturbance is None: 74 | disturb = self.net(torch.cat((x, u), dim=1)) 75 | else: 76 | disturb = nn.Softmax(1)(self.disturbance) 77 | return disturb 78 | 79 | 80 | class MultiPLDIDisturbModel(nn.Module): 81 | def __init__(self, bs, n, m, L): 82 | super().__init__() 83 | self.bs = bs 84 | self.net = nn.Sequential(nn.Linear(self.bs * (n + m), 50), nn.Sigmoid(), 85 | nn.Linear(50, 50), nn.Sigmoid(), 86 | nn.Linear(50, self.bs * L)) 87 | self.softmax = nn.Softmax(1) 88 | 89 | def forward(self, x, u, t): 90 | return self.softmax(self.net(torch.cat((x, u), dim=1).reshape([1, -1])).reshape([self.bs, -1])) 91 | 92 | def reset(self): 93 | def weight_reset(m): 94 | if isinstance(m, nn.Linear): 95 | m.reset_parameters() 96 | self.net.apply(weight_reset) 97 | 98 | 99 | class HinfDisturbModel(nn.Module): 100 | def __init__(self, n, m, wp, T): 101 | super().__init__() 102 | self.net = nn.Sequential(nn.Linear(n + m, 50), nn.Sigmoid(), 103 | nn.Linear(50, 50), nn.Sigmoid()) 104 | self.disturb_layer = nn.Linear(50, wp) 105 | self.magnitude_layer = nn.Sequential(nn.Linear(50, 1), nn.Tanh()) 106 | 
list(self.magnitude_layer.parameters())[-1].data *= 10 107 | self.T = T 108 | 109 | self.disturb_size = wp 110 | self.disturbance = None 111 | 112 | def forward(self, x, u, t): 113 | if self.disturbance is None: 114 | y = self.net(torch.cat((x, u), dim=1)) 115 | disturb = self.disturb_layer(y) 116 | magnitude = self.magnitude_layer(y) 117 | else: 118 | disturb = self.disturbance 119 | magnitude = 1 120 | 121 | disturb_norm = torch.norm(disturb, dim=1) 122 | if type(t) == torch.Tensor: 123 | t = t.detach().cpu().numpy() 124 | max_norm = torch.tensor(20 * norm.pdf(2 * t/self.T), device=x.device).reshape((-1, 1)) 125 | p = (disturb / disturb_norm.unsqueeze(1)) * max_norm * magnitude 126 | return p 127 | 128 | 129 | class MultiHinfDisturbModel(nn.Module): 130 | def __init__(self, bs, n, m, wp, T): 131 | super().__init__() 132 | self.bs = bs 133 | self.net = nn.Sequential(nn.Linear(self.bs * (n + m), 50), nn.Sigmoid(), 134 | nn.Linear(50, 50), nn.Sigmoid(), 135 | nn.Linear(50, self.bs * wp)) 136 | self.T = T 137 | 138 | def forward(self, x, u, t): 139 | disturb = self.net(torch.cat((x, u), dim=1).reshape([1, -1])).reshape([self.bs, -1]) 140 | disturb_norm = torch.norm(disturb, dim=1) 141 | if type(t) == torch.Tensor: 142 | t = t.detach().cpu().numpy() 143 | max_norm = torch.tensor(20 * norm.pdf(2 * t/self.T), device=x.device).reshape((-1, 1)) 144 | p = (disturb / disturb_norm.unsqueeze(1)) * max_norm 145 | return p 146 | 147 | def reset(self): 148 | def weight_reset(m): 149 | if isinstance(m, nn.Linear): 150 | m.reset_parameters() 151 | 152 | self.net.apply(weight_reset) 153 | 154 | 155 | class MBAdvDisturbModel(nn.Module): 156 | def __init__(self, env, pi, disturb_model, dt, 157 | step_type='euler', lr=0.0025, horizon=100, num_iters=100, change_thresh=0.001, update_freq=100, hinf_loss=False): 158 | super().__init__() 159 | self.dt = dt 160 | self.step_type = step_type 161 | self.lr = lr 162 | self.horizon = horizon 163 | self.num_iters = num_iters 164 | self.change_thresh = change_thresh 165 | self.update_freq = update_freq 166 | self.hinf_loss = hinf_loss 167 | 168 | self.env = env.__copy__() 169 | self.pi = pi 170 | 171 | self.disturb_model = disturb_model 172 | self.num_steps = 0 173 | 174 | def update(self, x_in): 175 | if self.num_steps % self.update_freq == 0: 176 | self.env.adversarial_disturb_f = self.disturb_model 177 | 178 | opt = optim.Adam(self.disturb_model.net.parameters(), lr=self.lr) 179 | x_in = x_in.detach() 180 | 181 | # print('') 182 | # print('Optimizing...') 183 | prev_total_cost = np.inf 184 | for i in range(self.num_iters): 185 | opt.zero_grad() 186 | 187 | x = x_in 188 | total_cost = 0 189 | disturb_norm = 0 190 | for t in range(self.horizon): 191 | u = self.pi(x) 192 | x, cost = self.env.step(x, u, t, self.dt, self.step_type, adversarial=True) 193 | total_cost += cost 194 | 195 | if self.hinf_loss: 196 | disturb_norm += torch.norm(self.env.disturb, p=2, dim=1) 197 | 198 | if self.hinf_loss: 199 | total_cost = (total_cost / disturb_norm).mean() 200 | else: 201 | total_cost = total_cost.mean() 202 | if torch.isnan(total_cost) or torch.abs(prev_total_cost - total_cost)/total_cost < self.change_thresh: 203 | break 204 | prev_total_cost = total_cost 205 | 206 | (-total_cost).backward(retain_graph=True) 207 | opt.step() 208 | 209 | self.num_steps += 1 210 | 211 | def forward(self, x_in, u_in, t): 212 | return self.disturb_model(x_in, u_in, t) 213 | 214 | def set_policy(self, policy): 215 | del self.pi 216 | self.pi = policy 217 | self.reset() 218 | 219 | def reset(self): 
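        # Reset the adversary for a new protagonist policy: re-initialize the
        # disturbance network's weights and restart the step counter, so the
        # next update() re-optimizes the disturbance from scratch.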
220 | self.disturb_model.reset() 221 | self.num_steps = 0 222 | -------------------------------------------------------------------------------- /rl/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import operator 5 | from functools import reduce 6 | 7 | from rl.distributions import Bernoulli, Categorical, DiagGaussian 8 | from rl.utils import init 9 | 10 | 11 | class Flatten(nn.Module): 12 | def forward(self, x): 13 | return x.view(x.size(0), -1) 14 | 15 | 16 | class Policy(nn.Module): 17 | def __init__(self, obs_shape, action_space, base=None, base_kwargs=None): 18 | super(Policy, self).__init__() 19 | if base_kwargs is None: 20 | base_kwargs = {} 21 | if base is None: 22 | if len(obs_shape) == 3: 23 | base = CNNBase 24 | elif len(obs_shape) == 1: 25 | base = MLPBase 26 | else: 27 | raise NotImplementedError 28 | 29 | self.base = base(obs_shape[0], **base_kwargs) 30 | 31 | if action_space.__class__.__name__ == "Discrete": 32 | num_outputs = action_space.n 33 | self.dist = Categorical(self.base.output_size, num_outputs) 34 | elif action_space.__class__.__name__ == "Box": 35 | num_outputs = action_space.shape[0] 36 | self.dist = DiagGaussian(self.base.output_size, num_outputs) 37 | elif action_space.__class__.__name__ == "MultiBinary": 38 | num_outputs = action_space.shape[0] 39 | self.dist = Bernoulli(self.base.output_size, num_outputs) 40 | else: 41 | raise NotImplementedError 42 | 43 | self.num_outputs = num_outputs 44 | 45 | @property 46 | def is_recurrent(self): 47 | return self.base.is_recurrent 48 | 49 | @property 50 | def recurrent_hidden_state_size(self): 51 | """Size of rnn_hx.""" 52 | return self.base.recurrent_hidden_state_size 53 | 54 | def forward(self, inputs, rnn_hxs, masks): 55 | raise NotImplementedError 56 | 57 | def act(self, inputs, rnn_hxs, masks, deterministic=False): 58 | # TODO: integrate the normalization at this point? 
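        # Run the shared base, form the action distribution from its features,
        # and return (value, action, action_log_probs, rnn_hxs);
        # deterministic=True takes the distribution mode instead of sampling.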
59 | value, actor_features, rnn_hxs = self.base(inputs, rnn_hxs, masks) 60 | dist = self.dist(actor_features) 61 | 62 | if deterministic: 63 | action = dist.mode() 64 | else: 65 | action = dist.sample() 66 | 67 | action_log_probs = dist.log_probs(action) 68 | dist_entropy = dist.entropy().mean() 69 | 70 | return value, action, action_log_probs, rnn_hxs 71 | 72 | def get_value(self, inputs, rnn_hxs, masks): 73 | value, _, _ = self.base(inputs, rnn_hxs, masks) 74 | return value 75 | 76 | def evaluate_actions(self, inputs, rnn_hxs, masks, action): 77 | value, actor_features, rnn_hxs = self.base(inputs, rnn_hxs, masks) 78 | dist = self.dist(actor_features) 79 | 80 | action_log_probs = dist.log_probs(action) 81 | dist_entropy = dist.entropy().mean() 82 | 83 | return value, action_log_probs, dist_entropy, rnn_hxs 84 | 85 | 86 | class NNBase(nn.Module): 87 | def __init__(self, recurrent, recurrent_input_size, hidden_size): 88 | super(NNBase, self).__init__() 89 | 90 | self._hidden_size = hidden_size 91 | self._recurrent = recurrent 92 | 93 | if recurrent: 94 | self.gru = nn.GRU(recurrent_input_size, hidden_size) 95 | for name, param in self.gru.named_parameters(): 96 | if 'bias' in name: 97 | nn.init.constant_(param, 0) 98 | elif 'weight' in name: 99 | nn.init.orthogonal_(param) 100 | 101 | @property 102 | def is_recurrent(self): 103 | return self._recurrent 104 | 105 | @property 106 | def recurrent_hidden_state_size(self): 107 | if self._recurrent: 108 | return self._hidden_size 109 | return 1 110 | 111 | @property 112 | def output_size(self): 113 | return self._hidden_size 114 | 115 | def _forward_gru(self, x, hxs, masks): 116 | if x.size(0) == hxs.size(0): 117 | x, hxs = self.gru(x.unsqueeze(0), (hxs * masks).unsqueeze(0)) 118 | x = x.squeeze(0) 119 | hxs = hxs.squeeze(0) 120 | else: 121 | # x is a (T, N, -1) tensor that has been flattened to (T * N, -1) 122 | N = hxs.size(0) 123 | T = int(x.size(0) / N) 124 | 125 | # unflatten 126 | x = x.view(T, N, x.size(1)) 127 | 128 | # Same deal with masks 129 | masks = masks.view(T, N) 130 | 131 | # Let's figure out which steps in the sequence have a zero for any agent 132 | # We will always assume t=0 has a zero in it as that makes the logic cleaner 133 | has_zeros = ((masks[1:] == 0.0) \ 134 | .any(dim=-1) 135 | .nonzero() 136 | .squeeze() 137 | .cpu()) 138 | 139 | # +1 to correct the masks[1:] 140 | if has_zeros.dim() == 0: 141 | # Deal with scalar 142 | has_zeros = [has_zeros.item() + 1] 143 | else: 144 | has_zeros = (has_zeros + 1).numpy().tolist() 145 | 146 | # add t=0 and t=T to the list 147 | has_zeros = [0] + has_zeros + [T] 148 | 149 | hxs = hxs.unsqueeze(0) 150 | outputs = [] 151 | for i in range(len(has_zeros) - 1): 152 | # We can now process steps that don't have any zeros in masks together! 153 | # This is much faster 154 | start_idx = has_zeros[i] 155 | end_idx = has_zeros[i + 1] 156 | 157 | rnn_scores, hxs = self.gru( 158 | x[start_idx:end_idx], 159 | hxs * masks[start_idx].view(1, -1, 1)) 160 | 161 | outputs.append(rnn_scores) 162 | 163 | # assert len(outputs) == T 164 | # x is a (T, N, -1) tensor 165 | x = torch.cat(outputs, dim=0) 166 | # flatten 167 | x = x.view(T * N, -1) 168 | hxs = hxs.squeeze(0) 169 | 170 | return x, hxs 171 | 172 | 173 | class CNNBase(NNBase): 174 | def __init__(self, num_inputs, recurrent=False, hidden_size=512): 175 | super(CNNBase, self).__init__(recurrent, hidden_size, hidden_size) 176 | 177 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 
178 | constant_(x, 0), nn.init.calculate_gain('relu')) 179 | 180 | self.main = nn.Sequential( 181 | init_(nn.Conv2d(num_inputs, 32, 8, stride=4)), nn.ReLU(), 182 | init_(nn.Conv2d(32, 64, 4, stride=2)), nn.ReLU(), 183 | init_(nn.Conv2d(64, 32, 3, stride=1)), nn.ReLU(), Flatten(), 184 | init_(nn.Linear(32 * 7 * 7, hidden_size)), nn.ReLU()) 185 | 186 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 187 | constant_(x, 0)) 188 | 189 | self.critic_linear = init_(nn.Linear(hidden_size, 1)) 190 | 191 | self.train() 192 | 193 | def forward(self, inputs, rnn_hxs, masks): 194 | x = self.main(inputs / 255.0) 195 | 196 | if self.is_recurrent: 197 | x, rnn_hxs = self._forward_gru(x, rnn_hxs, masks) 198 | 199 | return self.critic_linear(x), x, rnn_hxs 200 | 201 | 202 | class MLPBase(NNBase): 203 | def __init__(self, num_inputs, recurrent=False, hidden_size=64): 204 | super(MLPBase, self).__init__(recurrent, num_inputs, hidden_size) 205 | 206 | if recurrent: 207 | num_inputs = hidden_size 208 | 209 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 210 | constant_(x, 0), np.sqrt(2)) 211 | 212 | layer_sizes = [num_inputs, hidden_size, hidden_size] 213 | 214 | layers = reduce(operator.add, 215 | [[nn.Linear(a, b), nn.ReLU()] 216 | for a, b in zip(layer_sizes[0:-1], layer_sizes[1:])]) 217 | self.actor = nn.Sequential(*layers) 218 | 219 | layers = reduce(operator.add, 220 | [[nn.Linear(a, b), nn.ReLU()] 221 | for a, b in zip(layer_sizes[0:-1], layer_sizes[1:])]) 222 | self.critic = nn.Sequential(*layers) 223 | 224 | # self.actor = nn.Sequential( 225 | # init_(nn.Linear(num_inputs, hidden_size)), nn.Tanh(), 226 | # init_(nn.Linear(hidden_size, hidden_size)), nn.Tanh()) 227 | # 228 | # self.critic = nn.Sequential( 229 | # init_(nn.Linear(num_inputs, hidden_size)), nn.Tanh(), 230 | # init_(nn.Linear(hidden_size, hidden_size)), nn.Tanh()) 231 | 232 | self.critic_linear = init_(nn.Linear(hidden_size, 1)) 233 | 234 | self.train() 235 | 236 | def forward(self, inputs, rnn_hxs, masks): 237 | x = inputs 238 | 239 | if self.is_recurrent: 240 | x, rnn_hxs = self._forward_gru(x, rnn_hxs, masks) 241 | 242 | hidden_critic = self.critic(x) 243 | hidden_actor = self.actor(x) 244 | 245 | return self.critic_linear(hidden_critic), hidden_actor, rnn_hxs 246 | -------------------------------------------------------------------------------- /rl/storage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler 3 | 4 | from constants import * 5 | 6 | 7 | def _flatten_helper(T, N, _tensor): 8 | return _tensor.view(T * N, *_tensor.size()[2:]) 9 | 10 | 11 | class RolloutStorage(object): 12 | def __init__(self, num_steps, num_processes, obs_shape, action_space, recurrent_hidden_state_size, rarl=False): 13 | self.rarl = rarl 14 | self.obs = torch.zeros(num_steps + 1, num_processes, *obs_shape, dtype=TORCH_DTYPE) 15 | self.recurrent_hidden_states = torch.zeros( 16 | num_steps + 1, num_processes, recurrent_hidden_state_size, dtype=TORCH_DTYPE) 17 | self.rewards = torch.zeros(num_steps, num_processes, 1, dtype=TORCH_DTYPE) 18 | self.value_preds = torch.zeros(num_steps + 1, num_processes, 2 if rarl else 1, dtype=TORCH_DTYPE) 19 | self.returns = torch.zeros(num_steps + 1, num_processes, 2 if rarl else 1, dtype=TORCH_DTYPE) 20 | self.action_log_probs = torch.zeros(num_steps, num_processes, 2 if rarl else 1, dtype=TORCH_DTYPE) 21 | if action_space.__class__.__name__ == 'Discrete': 22 | 
action_shape = 1 23 | else: 24 | action_shape = action_space.shape[0] 25 | self.actions = torch.zeros(num_steps, num_processes, action_shape) 26 | if action_space.__class__.__name__ == 'Discrete': 27 | self.actions = self.actions.long() 28 | self.masks = torch.ones(num_steps + 1, num_processes, 1, dtype=TORCH_DTYPE) 29 | 30 | # Masks that indicate whether it's a true terminal state 31 | # or time limit end state 32 | self.bad_masks = torch.ones(num_steps + 1, num_processes, 1, dtype=TORCH_DTYPE) 33 | 34 | self.num_steps = num_steps 35 | self.step = 0 36 | 37 | def to(self, device): 38 | self.obs = self.obs.to(device) 39 | self.recurrent_hidden_states = self.recurrent_hidden_states.to(device) 40 | self.rewards = self.rewards.to(device) 41 | self.value_preds = self.value_preds.to(device) 42 | self.returns = self.returns.to(device) 43 | self.action_log_probs = self.action_log_probs.to(device) 44 | self.actions = self.actions.to(device) 45 | self.masks = self.masks.to(device) 46 | self.bad_masks = self.bad_masks.to(device) 47 | 48 | def insert(self, obs, recurrent_hidden_states, actions, action_log_probs, 49 | value_preds, rewards, masks, bad_masks): 50 | self.obs[self.step + 1].copy_(obs) 51 | # self.recurrent_hidden_states[self.step + 1].copy_(recurrent_hidden_states) 52 | self.actions[self.step].copy_(actions) 53 | self.action_log_probs[self.step].copy_(action_log_probs) 54 | self.value_preds[self.step].copy_(value_preds) 55 | self.rewards[self.step].copy_(rewards) 56 | self.masks[self.step + 1].copy_(masks) 57 | self.bad_masks[self.step + 1].copy_(bad_masks) 58 | 59 | self.step += 1 60 | 61 | def reset(self, obs, recurrent_hidden_states, masks, bad_masks): 62 | self.obs[0].copy_(obs) 63 | self.recurrent_hidden_states[0].copy_(recurrent_hidden_states) 64 | self.masks[0].copy_(masks) 65 | self.bad_masks[0].copy_(bad_masks) 66 | 67 | self.step = 0 68 | 69 | def after_update(self): 70 | self.obs[0].copy_(self.obs[-1]) 71 | self.recurrent_hidden_states[0].copy_(self.recurrent_hidden_states[-1]) 72 | self.masks[0].copy_(self.masks[-1]) 73 | self.bad_masks[0].copy_(self.bad_masks[-1]) 74 | 75 | self.step = 0 76 | 77 | def compute_returns(self, 78 | next_value, 79 | use_gae, 80 | gamma, 81 | gae_lambda, 82 | use_proper_time_limits=True): 83 | if use_proper_time_limits: 84 | if use_gae: 85 | self.value_preds[-1] = next_value 86 | gae = 0 87 | for step in reversed(range(self.rewards.size(0))): 88 | delta = self.rewards[step] + gamma * self.value_preds[ 89 | step + 1] * self.masks[step + 90 | 1] - self.value_preds[step] 91 | gae = delta + gamma * gae_lambda * self.masks[step + 92 | 1] * gae 93 | gae = gae * self.bad_masks[step + 1] 94 | self.returns[step] = gae + self.value_preds[step] 95 | else: 96 | self.returns[-1] = next_value 97 | for step in reversed(range(self.rewards.size(0))): 98 | self.returns[step] = (self.returns[step + 1] * \ 99 | gamma * self.masks[step + 1] + self.rewards[step]) * self.bad_masks[step + 1] \ 100 | + (1 - self.bad_masks[step + 1]) * self.value_preds[step] 101 | else: 102 | if use_gae: 103 | self.value_preds[-1] = next_value 104 | gae = 0 105 | for step in reversed(range(self.rewards.size(0))): 106 | delta = self.rewards[step] + gamma * self.value_preds[ 107 | step + 1] * self.masks[step + 108 | 1] - self.value_preds[step] 109 | gae = delta + gamma * gae_lambda * self.masks[step + 110 | 1] * gae 111 | self.returns[step] = gae + self.value_preds[step] 112 | else: 113 | self.returns[-1] = next_value 114 | for step in reversed(range(self.rewards.size(0))): 115 | 
self.returns[step] = self.returns[step + 1] * \ 116 | gamma * self.masks[step + 1] + self.rewards[step] 117 | 118 | def feed_forward_generator(self, 119 | advantages, 120 | num_mini_batch=None, 121 | mini_batch_size=None): 122 | num_steps, num_processes = self.rewards.size()[0:2] 123 | batch_size = num_processes * num_steps 124 | 125 | if mini_batch_size is None: 126 | assert batch_size >= num_mini_batch, ( 127 | "PPO requires the number of processes ({}) " 128 | "* number of steps ({}) = {} " 129 | "to be greater than or equal to the number of PPO mini batches ({})." 130 | "".format(num_processes, num_steps, num_processes * num_steps, 131 | num_mini_batch)) 132 | mini_batch_size = batch_size // num_mini_batch 133 | sampler = BatchSampler( 134 | SubsetRandomSampler(range(batch_size)), 135 | mini_batch_size, 136 | drop_last=True) 137 | for indices in sampler: 138 | obs_batch = self.obs[:-1].view(-1, *self.obs.size()[2:])[indices] 139 | recurrent_hidden_states_batch = self.recurrent_hidden_states[:-1].view( 140 | -1, self.recurrent_hidden_states.size(-1))[indices] 141 | actions_batch = self.actions.view(-1, 142 | self.actions.size(-1))[indices] 143 | value_preds_batch = self.value_preds[:-1].view(-1, 2 if self.rarl else 1)[indices] 144 | return_batch = self.returns[:-1].view(-1, 2 if self.rarl else 1)[indices] 145 | masks_batch = self.masks[:-1].view(-1, 1)[indices] 146 | old_action_log_probs_batch = self.action_log_probs.view(-1, 2 if self.rarl else 1)[indices] 147 | if advantages is None: 148 | adv_targ = None 149 | else: 150 | adv_targ = advantages.view(-1, 2 if self.rarl else 1)[indices] 151 | 152 | yield obs_batch, recurrent_hidden_states_batch, actions_batch, \ 153 | value_preds_batch, return_batch, masks_batch, old_action_log_probs_batch, adv_targ 154 | 155 | def recurrent_generator(self, advantages, num_mini_batch): 156 | num_processes = self.rewards.size(1) 157 | assert num_processes >= num_mini_batch, ( 158 | "PPO requires the number of processes ({}) " 159 | "to be greater than or equal to the number of " 160 | "PPO mini batches ({}).".format(num_processes, num_mini_batch)) 161 | num_envs_per_batch = num_processes // num_mini_batch 162 | perm = torch.randperm(num_processes) 163 | for start_ind in range(0, num_processes, num_envs_per_batch): 164 | obs_batch = [] 165 | recurrent_hidden_states_batch = [] 166 | actions_batch = [] 167 | value_preds_batch = [] 168 | return_batch = [] 169 | masks_batch = [] 170 | old_action_log_probs_batch = [] 171 | adv_targ = [] 172 | 173 | for offset in range(num_envs_per_batch): 174 | ind = perm[start_ind + offset] 175 | obs_batch.append(self.obs[:-1, ind]) 176 | recurrent_hidden_states_batch.append( 177 | self.recurrent_hidden_states[0:1, ind]) 178 | actions_batch.append(self.actions[:, ind]) 179 | value_preds_batch.append(self.value_preds[:-1, ind]) 180 | return_batch.append(self.returns[:-1, ind]) 181 | masks_batch.append(self.masks[:-1, ind]) 182 | old_action_log_probs_batch.append( 183 | self.action_log_probs[:, ind]) 184 | adv_targ.append(advantages[:, ind]) 185 | 186 | T, N = self.num_steps, num_envs_per_batch 187 | # These are all tensors of size (T, N, -1) 188 | obs_batch = torch.stack(obs_batch, 1) 189 | actions_batch = torch.stack(actions_batch, 1) 190 | value_preds_batch = torch.stack(value_preds_batch, 1) 191 | return_batch = torch.stack(return_batch, 1) 192 | masks_batch = torch.stack(masks_batch, 1) 193 | old_action_log_probs_batch = torch.stack( 194 | old_action_log_probs_batch, 1) 195 | adv_targ = torch.stack(adv_targ, 1) 196 | 197 | 
# States is just a (N, -1) tensor 198 | recurrent_hidden_states_batch = torch.stack( 199 | recurrent_hidden_states_batch, 1).view(N, -1) 200 | 201 | # Flatten the (T, N, ...) tensors to (T * N, ...) 202 | obs_batch = _flatten_helper(T, N, obs_batch) 203 | actions_batch = _flatten_helper(T, N, actions_batch) 204 | value_preds_batch = _flatten_helper(T, N, value_preds_batch) 205 | return_batch = _flatten_helper(T, N, return_batch) 206 | masks_batch = _flatten_helper(T, N, masks_batch) 207 | old_action_log_probs_batch = _flatten_helper(T, N, \ 208 | old_action_log_probs_batch) 209 | adv_targ = _flatten_helper(T, N, adv_targ) 210 | 211 | yield obs_batch, recurrent_hidden_states_batch, actions_batch, \ 212 | value_preds_batch, return_batch, masks_batch, old_action_log_probs_batch, adv_targ 213 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /plots.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import matplotlib.patches as patches 3 | from matplotlib.lines import Line2D 4 | from matplotlib import ticker 5 | 6 | from constants import * 7 | 8 | import os 9 | 10 | tableau10 = [(31, 119, 180), (255, 127, 14), 11 | (44, 160, 44), (214, 39, 40), 12 | (148, 103, 189), (140, 86, 75), 13 | (227, 119, 194), (127, 127, 127), 14 | (188, 189, 34), (23, 190, 207)] 15 | 16 | # Scale the RGB values to the [0, 1] range, which is the format matplotlib accepts. 17 | for i in range(len(tableau10)): 18 | r, g, b = tableau10[i] 19 | tableau10[i] = (r / 255., g / 255., b / 255.) 20 | 21 | def plot(save_subs, env_names, file_name, maxes, mins): 22 | n = len(save_subs) 23 | plt.rcParams.update({'font.size': 16}) 24 | 25 | plt.yscale('log') 26 | fig, axes = plt.subplots(n, 2, figsize=(9, 1.4*n + 1)) 27 | axes = axes.reshape((n, 2)) 28 | 29 | for j, title in enumerate(['Non-robust Methods', 'Robust Methods']): 30 | axes[0, j].set_title(title, y=1.05, fontweight='bold', fontsize=17) 31 | axes[-1, j].set_xlabel('Training epochs') 32 | 33 | extra_artists = [] 34 | for i, env_name in enumerate(env_names): 35 | if n > 1: 36 | ax = axes[i, 0] 37 | text = ax.text(-0.32, 0.5, env_name, horizontalalignment='center', 38 | verticalalignment='center', transform=ax.transAxes, 39 | rotation=90, fontweight='bold', fontsize=16) 40 | extra_artists.append(text) 41 | 42 | axes[i, 0].set_ylabel('Loss') 43 | 44 | colors = dict(zip(['LQR', 'MBP', 'PPO', 'Robust LQR', 'RARL', 'Robust MBP$^*$', 'Robust PPO$^*$', 'Robust MPC'], tableau10)) 45 | renaming_dict = dict(zip(['mbp', 'robust_mbp', 'ppo', 'robust_ppo', 'rarl_ppo', 'lqr', 'robust_lqr', 'robust_mpc'], 46 | ['MBP', 'Robust MBP$^*$', 'PPO', 'Robust PPO$^*$', 'RARL', 'LQR', 'Robust LQR', 'Robust MPC'])) 47 | 48 | for i, (save_sub, ax_1, ax_2) in enumerate(zip(save_subs, axes[:, 0], axes[:, 1])): 49 | save = os.path.join('results', save_sub) 50 | 51 | with open(os.path.join(save, 'results.txt'), 'r') as f: 52 | lines = f.readlines() 53 | 54 | test_performances = dict([x.strip().split(': ') for x in lines]) 55 | for key in test_performances.keys(): 56 | test_performances[key] = float(test_performances[key].replace('[', '').replace(']', '')) 57 | 58 | lqr_perf = test_performances['LQR'] 59 | lqr_adv_perf = test_performances['LQR-adv'] 60 | robust_lqr_perf = test_performances['Robust LQR'] 61 | robust_lqr_adv_perf = test_performances['Robust LQR-adv'] 62 | nn_perf = test_performances['MBP'] 63 | nn_adv_perf = test_performances['MBP-adv'] 64 | robust_nn_perf = test_performances['Robust MBP'] 65 | robust_nn_adv_perf = test_performances['Robust MBP-adv'] 66 | ppo_perf = test_performances['PPO'] 67 | ppo_adv_perf = test_performances['PPO-adv'] 68 | robust_ppo_perf = test_performances['Robust PPO'] 69 | robust_ppo_adv_perf = test_performances['Robust PPO-adv'] 70 | rarl_perf = test_performances['RARL PPO'] 71 | rarl_adv_perf = test_performances['RARL PPO-adv'] 72 | mpc_perf = test_performances.get('Robust MPC', None) 73 | mpc_adv_perf = test_performances.get('Robust MPC-adv', None) 74 | 75 | print('Results for %s' % env_names[i]) 76 | 77 | if mpc_perf is not None: 78 | print('& O & %.4g & %.4g & %.4g & %.4g & %.4g & %.4g & %.4g & %.4g \\\\' % ( 79 | lqr_perf, nn_perf, ppo_perf, robust_lqr_perf, mpc_perf, rarl_perf, robust_nn_perf, robust_ppo_perf)) 80 | print('& A & %.5g & 
%.5g & %.5g & %.5g & %.5g & %.5g & %.5g & %.5g \\\\' % ( 81 | lqr_adv_perf, nn_adv_perf, ppo_adv_perf, robust_lqr_adv_perf, mpc_adv_perf, rarl_adv_perf, robust_nn_adv_perf, robust_ppo_adv_perf)) 82 | else: 83 | print('& O & %.4g & %.4g & %.4g & %.4g & N/A & %.4g & %.4g & %.4g \\\\' % ( 84 | lqr_perf, nn_perf, ppo_perf, robust_lqr_perf, rarl_perf, robust_nn_perf, robust_ppo_perf)) 85 | print('& A & %.5g & %.5g & %.5g & %.5g & N/A & %.5g & %.5g & %.5g \\\\' % ( 86 | lqr_adv_perf, nn_adv_perf, ppo_adv_perf, robust_lqr_adv_perf, rarl_adv_perf, robust_nn_adv_perf, robust_ppo_adv_perf)) 87 | 88 | # Separate out performances in nominal and adversarial cases 89 | test_perfs = {} 90 | test_perfs_adv = {} 91 | nan_to_inf = lambda x: np.where(np.isnan(x), np.inf, x) 92 | for key in test_performances.keys(): 93 | red_name = key.replace('-adv', '').replace(' ', '_').lower() 94 | renaming = renaming_dict[red_name] 95 | if 'adv' in key: 96 | test_perfs_adv[renaming] = nan_to_inf(test_performances[key]) 97 | else: 98 | test_perfs[renaming] = nan_to_inf(test_performances[key]) 99 | 100 | # Load losses in nominal and adversarial cases 101 | test_losses = {} 102 | test_losses_adv = {} 103 | for sub_dir in ['mbp', 'robust_mbp', 'ppo', 'robust_ppo', 'rarl_ppo']: 104 | _, _, losses, losses_adv = load_results(save, sub_dir) 105 | renaming = renaming_dict[sub_dir] 106 | test_losses[renaming] = nan_to_inf(losses) 107 | test_losses_adv[renaming] = nan_to_inf(losses_adv) 108 | # Truncate MBP testing curve to appropriate length for plots 109 | num_test_points = len(test_losses['Robust MBP$^*$']) 110 | test_losses['MBP'] = test_losses['MBP'][:num_test_points] 111 | test_losses_adv['MBP'] = test_losses_adv['MBP'][:num_test_points] 112 | 113 | perf_labels = ['LQR'] 114 | loss_labels = ['MBP', 'PPO'] 115 | loss_markers = ['s', 'D'] 116 | test_freqs = [20, 16] 117 | subplot(num_test_points, 118 | perf_labels, test_perfs, test_perfs_adv, 119 | loss_labels, test_losses, test_losses_adv, 120 | maxes[i], mins[i], ax_1, colors, test_frequencies=test_freqs, markers=loss_markers) 121 | 122 | perf_labels = ['Robust LQR', 'Robust MPC'] 123 | loss_labels = ['RARL', 'Robust MBP$^*$', 'Robust PPO$^*$'] 124 | loss_markers = ['s', '+', 'D'] 125 | test_freqs = [16, 20, 16] 126 | subplot(num_test_points, 127 | perf_labels, test_perfs, test_perfs_adv, 128 | loss_labels, test_losses, test_losses_adv, 129 | maxes[i], mins[i], ax_2, colors, test_frequencies=test_freqs, markers=loss_markers) 130 | 131 | ax_1.label_outer() 132 | ax_2.label_outer() 133 | 134 | handles_1, labels_1 = ax_1.get_legend_handles_labels() 135 | handles_2, labels_2 = ax_2.get_legend_handles_labels() 136 | 137 | fig.tight_layout(pad=0.0, w_pad=0.2, h_pad=0.2) 138 | 139 | # legend = axes[-1, 0].legend(handles, labels, loc='lower center', mode='expand', bbox_to_anchor=(-0.25, -0.6, 2.6, 0.5), frameon=True, ncol=4, ) 140 | # legend = fig.legend([handles_1[0], handles_2[0], handles_1[1], handles_2[1], handles_1[2], handles_2[2]], 141 | # [labels_1[0], labels_2[0], labels_1[1], labels_2[1], labels_1[2], labels_2[2]], 142 | # loc='lower center', mode='expand', bbox_to_anchor=(0.05, 0.0, 1.2, 0.0), ncol=6, 143 | # handletextpad=0.15) 144 | height = fig.bbox_inches.y1 145 | legend = fig.legend(handles_1 + handles_2, 146 | labels_1 + labels_2, 147 | loc='lower center', bbox_to_anchor=(-0.015, 0.32 / height, 1.15, 0.0), ncol=5, 148 | handletextpad=0.4, fontsize=15, frameon=False) 149 | text_label = fig.text(0.25, -0.80 / height, 'Setting:', fontweight='bold', fontsize=15) 150 | 
legend2 = fig.legend([Line2D([0], [0], color='black', lw=1), Line2D([0], [0], color='black', lw=1, linestyle='--')], 151 | ['Original', 'Adversarial'], 152 | loc='lower center', mode='expand', bbox_to_anchor=(0.40, -0.08 / height, 0.45, 0.0), ncol=2, 153 | handletextpad=0.4, fontsize=15, frameon=False) 154 | extra_artists += [text_label, legend, legend2] 155 | 156 | # # text_offset = -0.25 157 | # text_offset = -0.08 158 | # text = fig.text(0.0, text_offset, ' ', fontsize=14, verticalalignment='top') 159 | # extra_artists.append(text) 160 | 161 | # fig.show() 162 | fig.savefig('%s.pdf' % file_name, bbox_inches='tight', bbox_extra_artists=extra_artists) 163 | 164 | return fig 165 | 166 | 167 | def subplot(n, perf_labels, test_perfs, test_perfs_adv, 168 | loss_labels, all_losses, all_adv_losses, ma, mi, ax, 169 | colors, test_frequencies, markers, plot_perfs=False): 170 | # blowup = 1e6 171 | # ma = max(np.max(nn_test_losses) if np.max(nn_test_losses) < blowup else 0, 172 | # np.max(robust_nn_test_losses) if np.max(robust_nn_test_losses) < blowup else 0, 173 | # np.max(ppo_test_losses) if np.max(ppo_test_losses) < blowup else 0, 174 | # np.max(robust_ppo_test_losses) if np.max(robust_ppo_test_losses) < blowup else 0, 175 | # lqr_perf if lqr_perf < blowup else 0, 176 | # robust_lqr_perf if robust_lqr_perf < blowup else 0) 177 | # mi = min(np.min(nn_test_losses), np.min(robust_nn_test_losses), lqr_perf, robust_lqr_perf) 178 | # top = 1.05 * ma 179 | # bottom = mi - 0.05 * ma 180 | top = 1.15 * ma 181 | bottom = mi * 0.85 182 | # top = ma 183 | # bottom = mi 184 | linewidth = 1 185 | 186 | if plot_perfs: 187 | for label in perf_labels: 188 | if label in test_perfs: 189 | plot_line(test_perfs[label], ma, ax, n, linewidth=1, label=label, linestyle='-', color=colors[label]) 190 | plot_line(test_perfs_adv[label], ma, ax, n, linewidth=1, linestyle='--', color=colors[label]) 191 | 192 | for label, test_frequency, marker in zip(loss_labels, test_frequencies, markers): 193 | losses = all_losses[label] 194 | adv_losses = all_adv_losses[label] 195 | if losses.shape[0] > 80: 196 | losses = losses[::2] 197 | adv_losses = adv_losses[::2] 198 | plot_losses(losses, ma, ax, 199 | test_frequency=test_frequency, linewidth=linewidth, label=label, linestyle='-', color=colors[label], marker=marker) 200 | plot_losses(adv_losses, ma, ax, 201 | test_frequency=test_frequency, linewidth=linewidth, linestyle='--', color=colors[label], marker=marker) 202 | 203 | ax.set_yscale('log') 204 | ax.set_ylim(bottom=bottom, top=top) 205 | # ax.set_yticks([x for x in [1, 1e1, 1e2, 1e3, 1e4, 1e5, 1e6] if mi <= x <= ma]) 206 | ax.set_yticks([mi, ma]) 207 | locmin = ticker.LogLocator(base=10.0, numticks=12) 208 | ax.yaxis.set_minor_locator(locmin) 209 | ax.yaxis.set_minor_formatter(ticker.NullFormatter()) 210 | # ax.ticklabel_format(scilimits=(-2, 4)) 211 | 212 | 213 | def plot_line(loss, ma, ax, n, linewidth=3, color=None, label=None, linestyle='-', marker=None): 214 | ax.axhline(y=loss, linestyle=linestyle, label=label, linewidth=linewidth, color=color, marker=marker) 215 | if loss > ma: 216 | ax.scatter(-1, ma, marker='x', s=150, color=color, linewidth=3) 217 | 218 | def plot_losses(losses, ma, ax, n=None, test_frequency=1, linewidth=3, color=None, label=None, linestyle='-', marker=None): 219 | i = n if n is not None else losses.shape[0] 220 | if (losses[:i] > ma).any(): 221 | i = np.argmax(losses[:n] > ma) + 1 222 | ax.scatter((i - 1) * test_frequency, ma, marker='x', s=150, color=color, linewidth=3) 223 | losses = 
np.minimum(losses, ma) 224 | ax.plot(range(0, i * test_frequency, test_frequency), losses[:i], label=label, linestyle=linestyle, 225 | linewidth=linewidth, color=color, marker=marker, markersize=2.25, markevery=2) 226 | 227 | def load_results(save_dir, model_name): 228 | model_save_dir = os.path.join(save_dir, model_name) 229 | 230 | plt.close() 231 | # train_losses = np.load(os.path.join(model_save_dir, 'train_losses.npy')) 232 | hold_losses = np.load(os.path.join(model_save_dir, 'hold_losses.npy')) 233 | test_losses = np.load(os.path.join(model_save_dir, 'test_losses.npy')) 234 | test_losses_adv = np.load(os.path.join(model_save_dir, 'test_losses_adv.npy')) 235 | 236 | return None, hold_losses, test_losses, test_losses_adv 237 | 238 | 239 | if __name__ == '__main__': 240 | 241 | # Main plot 242 | save_subs = [ 243 | 'random_nldi-d0+alpha0.001+gamma20+testSz50+holdSz50+trainBatch20+baselr0.001+robustlr0.0001+T2+stepTypeRK4+testStepTypeRK4+seed10+dt0.01', 244 | 'random_nldi-dnonzero+alpha0.001+gamma20+testSz50+holdSz50+trainBatch20+baselr0.001+robustlr0.0001+T2+stepTypeRK4+testStepTypeRK4+seed10+dt0.01', 245 | 'cartpole+alpha0.001+gamma20+testSz50+holdSz50+trainBatch20+baselr0.001+robustlr0.0001+T10.0+stepTypeRK4+testStepTypeRK4+seed10+dt0.05', 246 | 'quadrotor+alpha0.001+gamma20+testSz50+holdSz50+trainBatch20+baselr0.001+robustlr0.0001+T4.0+stepTypeRK4+testStepTypeRK4+seed10+dt0.02', 247 | 'microgrid+alpha0.001+gamma20+testSz50+holdSz50+trainBatch20+baselr0.001+robustlr0.0001+T2+stepTypeRK4+testStepTypeRK4+seed10+dt0.01', 248 | ] 249 | env_names = ['NLDI\n(D = 0)', 'NLDI\n(D ≠ 0)', 'Cartpole', 'Quadrotor', 'Microgrid'] 250 | maxes = [100000, 100000, 1000, 1000, 100] 251 | mins = [10, 10, 1, 1, 0.1] 252 | plot(save_subs, env_names, 'main_results', maxes, mins) 253 | 254 | # Appendix plot 255 | save_subs = [ 256 | 'random_pldi_env+alpha0.001+gamma20+testSz50+holdSz50+trainBatch20+baselr0.001+robustlr0.0001+T2+stepTypeRK4+testStepTypeRK4+seed10+dt0.01', 257 | 'random_hinf_env+alpha0.001+gamma20+testSz50+holdSz50+trainBatch20+baselr0.001+robustlr0.0001+T2+stepTypeRK4+testStepTypeRK4+seed10+dt0.01', 258 | ] 259 | env_names = ['PLDI', 'H$_\mathbf{\infty}$'] 260 | maxes = [1500, 1500] 261 | mins = [1, 10] 262 | plot(save_subs, env_names, 'appendix_results', maxes, mins) 263 | -------------------------------------------------------------------------------- /policy_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | import operator 5 | from functools import reduce 6 | import cvxpy as cp 7 | from qpth.qp import QPFunction 8 | from cvxpylayers.torch import CvxpyLayer 9 | from sqrtm import sqrtm 10 | import warnings 11 | 12 | from constants import * 13 | 14 | 15 | class MBPPolicy(nn.Module): 16 | def __init__(self, K, n, m): 17 | super().__init__() 18 | self.K = K 19 | 20 | layer_sizes = [n, 200, 200] 21 | layers = reduce(operator.add, 22 | [[nn.Linear(a, b), nn.ReLU()] 23 | for a, b in zip(layer_sizes[0:-1], layer_sizes[1:])]) 24 | layers += [nn.Linear(layer_sizes[-1], m)] 25 | self.net = nn.Sequential(*layers) 26 | 27 | def forward(self, x): 28 | return x @ self.K.T + self.net(x) 29 | 30 | 31 | class StableNLDIProjection: 32 | def __init__(self, P, A, B, G, C, D, alpha, isD0=False): 33 | self.P = P 34 | self.A = A 35 | self.B = B 36 | self.G = G 37 | self.C = C 38 | self.D = D 39 | self.alpha = alpha 40 | self.isD0 = isD0 41 | self.proj_layer = None 42 | 43 | if self.isD0: 
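            # Note (added commentary): when D = 0, the Lyapunov-decrease condition on
            # V(x) = x^T P x reduces to a single halfspace constraint g^T u <= h on the
            # action, so the projection below is the closed-form halfspace projection
            # (the epsilon term keeps the result strictly inside the feasible set).
            # With D != 0 the constraint is a second-order cone in u, and the iterative
            # SOC projection defined later in this file is used instead.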
44 |             epsilon = 10e-5
45 |             self.proj_layer = lambda u, g, h: u - nn.ReLU()(torch.div((u * g).sum(-1) - h, (g * g).sum(-1))).unsqueeze(
46 |                 1) * g - epsilon * torch.sign(g)
47 |         else:
48 |             self.proj_layer = SOCProjFast()
49 | 
50 |     def project_action(self, u, x):
51 |         if self.isD0:
52 |             Px = x @ self.P
53 |             g = 2 * Px @ self.B
54 |             neg_h = self.alpha * (Px * x).sum(-1) + \
55 |                     2 * torch.norm(Px @ self.G, dim=1) * torch.norm(x @ self.C.T, dim=1) + \
56 |                     2 * (Px @ self.A * x).sum(-1)
57 |             u = self.proj_layer(u, g, -neg_h)
58 |         else:
59 |             Px = x @ self.P
60 |             const = torch.norm(x @ self.P @ self.G, dim=1)
61 |             A = self.D.expand(x.shape[0], self.D.shape[0], self.D.shape[1])
62 |             b = x @ self.C.T
63 |             c = (1 / const).unsqueeze(1) * (-Px @ self.B)
64 |             d = -((2 * Px @ self.A + self.alpha * Px) * x).sum(-1) / (2 * const)
65 | 
66 |             u = self.proj_layer(u, A, b, c, d)
67 | 
68 |         return u
69 | 
70 |     def __getstate__(self):
71 |         state = [self.P, self.A, self.B, self.G, self.C, self.D, self.alpha, self.isD0]
72 |         return state
73 | 
74 |     def __setstate__(self, state):
75 |         self.__init__(*state)
76 | 
77 | 
78 | class StablePLDIProjection:
79 |     def __init__(self, P, A, B):
80 |         self.P = P
81 |         self.A = A
82 |         self.B = B
83 | 
84 |         self.e = torch.DoubleTensor().to(device=A.device)
85 | 
86 |     def project_action(self, u, x):
87 |         Px = x @ self.P
88 |         G = 2 * Px.expand(self.B.shape[0], Px.shape[0], Px.shape[1]).bmm(self.B).transpose(0, 1)
89 |         h = (Px * x).sum(-1).unsqueeze(1) + \
90 |             2 * Px.expand(self.B.shape[0], Px.shape[0], Px.shape[1]).bmm(self.A).transpose(0, 1).matmul(
91 |                 x.unsqueeze(2)).squeeze(2)
92 | 
93 |         Q = torch.eye(u.shape[-1], device=x.device).unsqueeze(0).expand(u.shape[0], u.shape[-1], u.shape[-1])
94 |         res = QPFunction(verbose=-1)(Q.double(), -u.double(), G.double(), -h.double(), self.e, self.e)
95 |         return res.type(TORCH_DTYPE)
96 | 
97 |     def __getstate__(self):
98 |         state = [self.P, self.A, self.B]
99 |         return state
100 | 
101 |     def __setstate__(self, state):
102 |         self.__init__(*state)
103 | 
104 | 
105 | class StableHinfProjection:
106 |     def __init__(self, P, A, B, G, Q, R, alpha, gamma, sigma):
107 |         self.P = P
108 |         self.A = A
109 |         self.B = B
110 |         self.G = G
111 |         self.Q = Q
112 |         self.R = R
113 |         self.alpha = alpha
114 |         self.gamma = gamma
115 |         self.sigma = sigma
116 | 
117 |     def project_action(self, u, xin):
118 |         x = xin.unsqueeze(-1)
119 | 
120 |         Atilde = sqrtm(self.sigma*self.R)/torch.sqrt(x.transpose(1,2)@(
121 |             self.P@self.B@torch.inverse(self.R)@self.B.T@self.P/self.sigma - \
122 |             self.P@self.A - self.A.T@self.P - self.alpha*self.P - self.sigma * self.Q - \
123 |             self.P@self.G@self.G.T@self.P/(self.sigma*(self.gamma**2)))@x)
124 |         btilde = Atilde@torch.inverse(self.R)@self.B.T@self.P@x/self.sigma
125 | 
126 |         ctilde = torch.zeros(Atilde.shape[0], Atilde.shape[2], device=xin.device, dtype=TORCH_DTYPE)
127 |         dtilde = torch.ones(btilde.shape[0], device=xin.device, dtype=TORCH_DTYPE)
128 | 
129 |         u = SOCProjFast(momentum=False)(u, Atilde, btilde.squeeze(-1), ctilde, dtilde)
130 | 
131 |         return u
132 | 
133 |     def __getstate__(self):
134 |         state = [self.P, self.A, self.B, self.G, self.Q, self.R, self.alpha, self.gamma, self.sigma]  # sigma included so __setstate__ matches __init__
135 |         return state
136 | 
137 |     def __setstate__(self, state):
138 |         self.__init__(*state)
139 | 
140 | 
141 | class StablePolicy(nn.Module):
142 |     def __init__(self, pi, stable_projection):
143 |         super().__init__()
144 |         self.pi = pi
145 |         self.stable_projection = stable_projection
146 | 
147 |     def forward(self, x):
148 |         u = self.pi(x)
149 |         u = 
self.stable_projection.project_action(u, x)
150 |         return u
151 | 
152 | 
153 | # From https://github.com/locuslab/qpth/blob/master/qpth/util.py
154 | def bger(x, y):
155 |     """Batch outer product"""
156 |     return x.unsqueeze(2).bmm(y.unsqueeze(1))
157 | 
158 | 
159 | def SOCProj(tol=1e-5, max_iters=1000000, rho=10):
160 |     """Projection onto a second order cone constraint"""
161 | 
162 |     class SOCProjFn(Function):
163 | 
164 |         @staticmethod
165 |         def forward(ctx, pi, A, b, c, d):
166 |             G = torch.cat([A, c.unsqueeze(1)], dim=1)
167 |             h = torch.cat([b, d.unsqueeze(-1)], dim=1)
168 | 
169 |             xkm1 = pi
170 |             zkm1 = G.bmm(xkm1.unsqueeze(-1)).squeeze(-1) + h
171 |             ukm1 = torch.zeros_like(zkm1, device=zkm1.device, dtype=TORCH_DTYPE)
172 | 
173 |             # precompute inversion matrix for x update
174 |             inv_mat = torch.inverse(
175 |                 torch.eye(pi.shape[-1], device=zkm1.device, dtype=TORCH_DTYPE).unsqueeze(0).expand(pi.shape[0], pi.shape[-1], pi.shape[-1]) + \
176 |                 rho * G.transpose(1, 2).bmm(G))
177 | 
178 |             for i in range(max_iters):
179 |                 xk = inv_mat.bmm(
180 |                     (pi.unsqueeze(-1) - G.transpose(1, 2).bmm((ukm1 - rho * zkm1 + rho * h).unsqueeze(-1)))).squeeze(-1)
181 |                 zk = SOCProjFn.proj_normcone(G.bmm(xk.unsqueeze(-1)).squeeze(-1) + h + ukm1 / rho)
182 |                 uk = ukm1 + rho * (G.bmm(xk.unsqueeze(-1)).squeeze(-1) - zk + h)
183 | 
184 |                 if i % 10 == 0 and (torch.norm(xkm1 - xk, dim=1) < tol).all() and \
185 |                         (torch.norm(zkm1 - zk, dim=1) < tol).all() and (torch.norm(ukm1 - uk, dim=1) < tol).all():
186 |                     ctx.save_for_backward(xk, zk, uk, G, h)
187 |                     # converged within tolerance on all batch elements
188 |                     return xk
189 | 
190 |                 xkm1 = xk
191 |                 zkm1 = zk
192 |                 ukm1 = uk
193 | 
194 |             warnings.warn('Max iterations reached')
195 |             ctx.save_for_backward(xk, zk, uk, G, h)
196 |             return xk
197 | 
198 |         @staticmethod
199 |         def backward(ctx, dl_dx):
200 |             x, z, u, G, h = ctx.saved_tensors
201 |             m = x.shape[-1]
202 |             w = z.shape[-1]  # also equals u.shape[-1]
203 |             loss_vec = torch.cat([dl_dx,
204 |                                   torch.zeros(dl_dx.shape[0], w, device=dl_dx.device, dtype=TORCH_DTYPE),
205 |                                   torch.zeros(dl_dx.shape[0], w, device=dl_dx.device, dtype=TORCH_DTYPE)],
206 |                                  dim=1)
207 | 
208 |             dsoc = SOCProjFn.dproj_normcone(u / rho + G.bmm(x.unsqueeze(-1)).squeeze(-1) + h)
209 |             mat = torch.cat([
210 |                 torch.cat([
211 |                     torch.eye(m, device=dl_dx.device, dtype=TORCH_DTYPE).unsqueeze(0) + rho * G.transpose(1, 2).bmm(G),
212 |                     -rho * G.transpose(1, 2),
213 |                     G.transpose(1, 2)], dim=2),
214 |                 torch.cat([
215 |                     -dsoc.bmm(G),
216 |                     torch.eye(w, device=dl_dx.device, dtype=TORCH_DTYPE).unsqueeze(0).expand(x.shape[0], w, w),
217 |                     -dsoc / rho], dim=2),
218 |                 torch.cat([
219 |                     G,
220 |                     -torch.eye(w, device=dl_dx.device, dtype=TORCH_DTYPE).unsqueeze(0).expand(x.shape[0], w, w),
221 |                     torch.zeros(x.shape[0], w, w, device=dl_dx.device, dtype=TORCH_DTYPE)], dim=2)],
222 |                 dim=1)
223 |             res = torch.inverse(mat.transpose(1, 2)).bmm(loss_vec.unsqueeze(-1)).squeeze(-1)
224 |             d_x = res[:, :m]
225 |             d_z = res[:, m:m + w]
226 |             d_u = res[:, -w:]
227 | 
228 |             dldy = d_x
229 |             dldh = -rho * G.bmm(d_x.unsqueeze(-1)).squeeze(-1) + dsoc.bmm(d_z.unsqueeze(-1)).squeeze(-1) - d_u
230 |             dldG = bger(-rho * G.bmm(x.unsqueeze(-1)).squeeze(-1) - u + rho * z - rho * h, d_x) - \
231 |                    bger(rho * G.bmm(d_x.unsqueeze(-1)).squeeze(-1), x) + \
232 |                    bger(dsoc.bmm(d_z.unsqueeze(-1)).squeeze(-1), x) - \
233 |                    bger(d_u, x)
234 | 
235 |             dldA = dldG[:, :-1, :]
236 |             dldb = dldh[:, :-1]
237 |             dldc = dldG[:, -1, :]
238 |             dldd = dldh[:, -1]
239 | 
240 |             return dldy, dldA, dldb, dldc, dldd
241 | 
242 |         @staticmethod
243 |         def proj_normcone(z_in):
244 |             '''Deals with 3 cases of projections: in 
cone (case 1), in "negative" cone (case 2), other (case 3)''' 245 | z = z_in[:, :-1] 246 | t = z_in[:, -1] 247 | z_norm = torch.norm(z, dim=1) 248 | case1m = (z_norm <= t) 249 | case2m = (z_norm <= -t) 250 | case3v = (z_norm + t).unsqueeze(-1) / 2 * \ 251 | torch.cat([z / z_norm.unsqueeze(-1), torch.ones(t.shape[0], 1, device=z_in.device, dtype=TORCH_DTYPE)], dim=1) 252 | return case1m.unsqueeze(-1).expand_as(z_in) * z_in + \ 253 | ~(case1m | case2m).unsqueeze(-1).expand_as(z_in) * case3v 254 | 255 | @staticmethod 256 | def dproj_normcone(z_in): 257 | '''Deals with 3 cases of projections: in cone (case 1), in "negative" cone (case 2), other (case 3)''' 258 | z = z_in[:, :-1] 259 | t = z_in[:, -1] 260 | 261 | z_norm = torch.norm(z, dim=1) 262 | d1dz = (bger(z, z) + \ 263 | (z_norm + t).unsqueeze(1).unsqueeze(2) * ( 264 | z_norm.unsqueeze(1).unsqueeze(2) * torch.eye(z.shape[1], device=z_in.device, dtype=TORCH_DTYPE).unsqueeze(0).expand( 265 | z.shape[0], z.shape[1], z.shape[1]) 266 | - bger(z, z) / z_norm.unsqueeze(1).unsqueeze(2))) / ( 267 | 2 * z_norm.unsqueeze(1).unsqueeze(2) ** 2) 268 | d1dr = (z.T / (2 * z_norm)).T 269 | case3v = torch.cat([ 270 | torch.cat([d1dz, d1dr.unsqueeze(2)], dim=2), 271 | torch.cat([d1dr.unsqueeze(1), 0.5 * torch.ones(d1dr.shape[0], 1, 1, device=z_in.device, dtype=TORCH_DTYPE)], dim=2)], 272 | dim=1) 273 | 274 | case1m = (z_norm <= t) 275 | case2m = (z_norm <= -t) 276 | 277 | return case1m.unsqueeze(1).unsqueeze(2).expand_as(case3v) * torch.eye(z_in.shape[1], device=z_in.device, dtype=TORCH_DTYPE) + \ 278 | ~(case1m | case2m).unsqueeze(-1).unsqueeze(2).expand_as(case3v) * case3v 279 | 280 | return SOCProjFn.apply 281 | 282 | 283 | def SOCProjFast(tol=1e-5, max_iters=10000, momentum=True): 284 | """Projection onto a second order cone constraint""" 285 | 286 | class SOCProjFastFn(Function): 287 | 288 | @staticmethod 289 | def forward(ctx, pi, A, b, c, d): 290 | G = torch.cat([A, c.unsqueeze(1)], dim=1) 291 | h = torch.cat([b, d.unsqueeze(-1)], dim=1) 292 | 293 | H = G.bmm(G.transpose(1,2)) 294 | eig_H = torch.symeig(H,eigenvectors=False).eigenvalues 295 | 296 | mh = torch.min(eig_H,1)[0] 297 | Lh = torch.max(eig_H,1)[0] 298 | 299 | ## to avoid extremely small but negative mh 300 | threshold = 1e-5 301 | mh = (mh>threshold)*mh 302 | 303 | momentum_param = lambda iter: (momentum)*((mh>0)*((torch.sqrt(Lh)-torch.sqrt(mh))/(torch.sqrt(Lh)+torch.sqrt(mh))) + (mh==0)*((iter)/(iter+3))).unsqueeze(-1) 304 | 305 | step_size = (1/Lh).unsqueeze(-1) 306 | 307 | 308 | ## initial condition 309 | lamk = torch.zeros_like(h) 310 | lamkm1 = lamk 311 | xkm1 = pi 312 | 313 | for i in range(max_iters): 314 | 315 | vk = lamk + momentum_param(i) * (lamk-lamkm1) 316 | lamkp1 = SOCProjFastFn.proj_normcone(vk - step_size * (H.bmm(vk.unsqueeze(-1)) + G.bmm(pi.unsqueeze(-1))+h.unsqueeze(-1)).squeeze(-1)) 317 | 318 | lamkm1 = lamk 319 | lamk = lamkp1 320 | # print(torch.max(rd)) 321 | 322 | xk = pi + G.transpose(1, 2).bmm(lamk.unsqueeze(-1)).squeeze(-1) 323 | if torch.norm(xkm1 - xk, dim=1).max() < tol: 324 | ctx.save_for_backward(xk, -lamk, G, h) 325 | return xk 326 | xkm1 = xk 327 | 328 | warnings.warn('Max iterations reached') 329 | xk = pi + G.transpose(1,2).bmm(lamk.unsqueeze(-1)).squeeze(-1) 330 | ctx.save_for_backward(xk, -lamk, G, h) 331 | return xk 332 | 333 | @staticmethod 334 | def backward(ctx, dl_dx): 335 | x, u, G, h = ctx.saved_tensors 336 | z = G.bmm(x.unsqueeze(-1)).squeeze(-1) + h 337 | m = x.shape[-1] 338 | w = z.shape[-1] # also equals u.shape[-1] 339 | loss_vec = 
torch.cat([dl_dx,
340 |                                   torch.zeros(dl_dx.shape[0], w, device=dl_dx.device, dtype=TORCH_DTYPE),
341 |                                   torch.zeros(dl_dx.shape[0], w, device=dl_dx.device, dtype=TORCH_DTYPE)],
342 |                                  dim=1)
343 | 
344 |             dsoc = SOCProjFastFn.dproj_normcone(u + G.bmm(x.unsqueeze(-1)).squeeze(-1) + h)
345 |             mat = torch.cat([
346 |                 torch.cat([
347 |                     torch.eye(m, device=dl_dx.device, dtype=TORCH_DTYPE).unsqueeze(0) + G.transpose(1, 2).bmm(G),
348 |                     -G.transpose(1, 2),
349 |                     G.transpose(1, 2)], dim=2),
350 |                 torch.cat([
351 |                     -dsoc.bmm(G),
352 |                     torch.eye(w, device=dl_dx.device, dtype=TORCH_DTYPE).unsqueeze(0).expand(x.shape[0], w, w),
353 |                     -dsoc], dim=2),
354 |                 torch.cat([
355 |                     G,
356 |                     -torch.eye(w, device=dl_dx.device, dtype=TORCH_DTYPE).unsqueeze(0).expand(x.shape[0], w, w),
357 |                     torch.zeros(x.shape[0], w, w, device=dl_dx.device, dtype=TORCH_DTYPE)], dim=2)],
358 |                 dim=1)
359 | 
360 |             try:
361 |                 res = torch.inverse(mat.transpose(1, 2)).bmm(loss_vec.unsqueeze(-1)).squeeze(-1)
362 |             except RuntimeError:
363 |                 # if (torch.det(mat.transpose(1, 2)) == 0).any():
364 |                 #     print(dsoc[i, :, :])
365 |                 warnings.warn('Singular matrix in backwards pass')
366 |                 res = torch.zeros(mat.shape[0], loss_vec.shape[-1], device=dl_dx.device, dtype=TORCH_DTYPE)  # could also do 1e-5 * torch.ones
367 |                 i = torch.nonzero(torch.det(mat.transpose(1, 2)) != 0).reshape(-1)  # torch.nonzero keeps the index tensor on the right device (np.argwhere returned a numpy array, which has no .to())
368 |                 if i.shape[0] > 0:
369 |                     res0 = torch.inverse(mat[i].transpose(1, 2)).bmm(loss_vec[i].unsqueeze(-1)).squeeze(-1)
370 |                     res[i] = res0
371 | 
372 |             d_x = res[:, :m]
373 |             d_z = res[:, m:m + w]
374 |             d_u = res[:, -w:]
375 | 
376 |             dldy = d_x
377 |             dldh = -G.bmm(d_x.unsqueeze(-1)).squeeze(-1) + dsoc.bmm(d_z.unsqueeze(-1)).squeeze(-1) - d_u
378 |             dldG = bger(-G.bmm(x.unsqueeze(-1)).squeeze(-1) - u + z - h, d_x) - \
379 |                    bger(G.bmm(d_x.unsqueeze(-1)).squeeze(-1), x) + \
380 |                    bger(dsoc.bmm(d_z.unsqueeze(-1)).squeeze(-1), x) - \
381 |                    bger(d_u, x)
382 | 
383 |             dldA = dldG[:, :-1, :]
384 |             dldb = dldh[:, :-1]
385 |             dldc = dldG[:, -1, :]
386 |             dldd = dldh[:, -1]
387 | 
388 |             return dldy, dldA, dldb, dldc, dldd
389 | 
390 |         @staticmethod
391 |         def proj_normcone(z_in):
392 |             '''Deals with 3 cases of projections: in cone (case 1), in "negative" cone (case 2), other (case 3)'''
393 |             z = z_in[:, :-1]
394 |             t = z_in[:, -1]
395 |             z_norm = torch.norm(z, dim=1)
396 |             case1m = (z_norm <= t)
397 |             case2m = (z_norm <= -t)
398 |             case3v = (z_norm + t).unsqueeze(-1) / 2 * \
399 |                      torch.cat([z / z_norm.unsqueeze(-1), torch.ones(t.shape[0], 1, device=z_in.device, dtype=TORCH_DTYPE)], dim=1)
400 |             return case1m.unsqueeze(-1).expand_as(z_in) * z_in + \
401 |                    ~(case1m | case2m).unsqueeze(-1).expand_as(z_in) * case3v
402 | 
403 |         @staticmethod
404 |         def dproj_normcone(z_in):
405 |             '''Deals with 3 cases of projections: in cone (case 1), in "negative" cone (case 2), other (case 3)'''
406 |             z = z_in[:, :-1]
407 |             t = z_in[:, -1]
408 | 
409 |             z_norm = torch.norm(z, dim=1)
410 |             d1dz = (bger(z, z) +
411 |                     (z_norm + t).unsqueeze(1).unsqueeze(2) * (
412 |                             z_norm.unsqueeze(1).unsqueeze(2) * torch.eye(z.shape[1], device=z_in.device, dtype=TORCH_DTYPE).unsqueeze(0).expand(
413 |                         z.shape[0], z.shape[1], z.shape[1])
414 |                             - bger(z, z) / z_norm.unsqueeze(1).unsqueeze(2))) / (
415 |                            2 * z_norm.unsqueeze(1).unsqueeze(2) ** 2)
416 |             d1dr = (z.T / (2 * z_norm)).T
417 |             case3v = torch.cat([
418 |                 torch.cat([d1dz, d1dr.unsqueeze(2)], dim=2),
419 |                 torch.cat([d1dr.unsqueeze(1), 0.5 * torch.ones(d1dr.shape[0], 1, 1, device=z_in.device, dtype=TORCH_DTYPE)], dim=2)],
420 |                 dim=1)
421 | 
422 | 
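            # Added commentary: the return below blends the three Jacobian cases of the
            # cone projection: identity inside the cone (case 1), zero inside the polar
            # cone (case 2, implicit), and the boundary Jacobian case3v elsewhere (case 3).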
case1m = (z_norm <= t) 423 | case2m = (z_norm <= -t) 424 | 425 | return case1m.unsqueeze(1).unsqueeze(2).expand_as(case3v) * torch.eye(z_in.shape[1], device=z_in.device, dtype=TORCH_DTYPE) + \ 426 | ~(case1m | case2m).unsqueeze(-1).unsqueeze(2).expand_as(case3v) * case3v 427 | 428 | return SOCProjFastFn.apply -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.linalg as la 3 | import cvxpy as cp 4 | import torch 5 | import torch.optim as optim 6 | import argparse 7 | import setproctitle 8 | import os 9 | from gym import spaces 10 | import tqdm 11 | 12 | import policy_models as pm 13 | import disturb_models as dm 14 | import robust_mpc as rmpc 15 | 16 | from envs.random_nldi_env import RandomNLDIEnv 17 | from envs.cartpole import CartPoleEnv 18 | from envs.quadrotor_env import QuadrotorEnv 19 | from envs.random_pldi_env import RandomPLDIEnv 20 | from envs.random_hinf_env import RandomHinfEnv 21 | from envs.microgrid import MicrogridEnv 22 | 23 | from constants import * 24 | 25 | from rl.ppo import PPO 26 | from rl.rarl_ppo import RARLPPO 27 | from rl.model import Policy 28 | from rl.storage import RolloutStorage 29 | from rl import trainer 30 | from rl import arguments 31 | from envs.rl_wrapper import RLWrapper 32 | 33 | # import ipdb 34 | # import sys 35 | # from IPython.core import ultratb 36 | # sys.excepthook = ultratb.FormattedTB(mode='Verbose', 37 | # color_scheme='Linux', call_pdb=1) 38 | 39 | 40 | def main(): 41 | parser = argparse.ArgumentParser( 42 | description='Run robust control experiments.') 43 | parser.add_argument('--baseLR', type=float, default=1e-3, 44 | help='learning rate for non-projected DPS') 45 | parser.add_argument('--robustLR', type=float, default=1e-4, 46 | help='learning rate for projected DPS') 47 | parser.add_argument('--alpha', type=float, default=0.001, 48 | help='exponential stability coefficient') 49 | parser.add_argument('--gamma', type=float, default=20, 50 | help='bound on L2 gain of disturbance-to-output map (for H_inf control)') 51 | parser.add_argument('--epochs', type=int, default=1000, 52 | help='max epochs') 53 | parser.add_argument('--test_frequency', type=int, default=20, 54 | help='frequency of testing during training') 55 | parser.add_argument('--T', type=float, default=2, 56 | help='time horizon in seconds') 57 | parser.add_argument('--dt', type=float, default=0.01, 58 | help='time increment') 59 | parser.add_argument('--testSetSz', type=int, default=50, 60 | help='size of test set') 61 | parser.add_argument('--holdSetSz', type=int, default=50, 62 | help='size of holdout set') 63 | parser.add_argument('--trainBatchSz', type=int, default=20, 64 | help='batch size for training') 65 | parser.add_argument('--stepType', type=str, 66 | choices=['euler', 'RK4', 'scipy'], default='RK4', 67 | help='method for taking steps during training') 68 | parser.add_argument('--testStepType', type=str, 69 | choices=['euler', 'RK4', 'scipy'], default='RK4', 70 | help='method for taking steps during testing') 71 | parser.add_argument('--env', type=str, 72 | choices=['random_nldi-d0', 'random_nldi-dnonzero', 'random_pldi_env', 73 | 'random_hinf_env', 'cartpole', 'quadrotor', 'microgrid'], 74 | default='random_nldi-d0', 75 | help='environment') 76 | parser.add_argument('--envRandomSeed', type=int, default=10, 77 | help='random seed used to construct the environment') 78 | parser.add_argument('--save', 
type=str,
 79 |                         help='prefix to add to save path')
 80 |     parser.add_argument('--gpu', type=int, default=0,
 81 |                         help='id of the GPU to run on')
 82 |     parser.add_argument('--evaluate', type=str,
 83 |                         help='instead of training, evaluate the models from a given directory'
 84 |                              ' (remember to use the same random seed)')
 85 |     args = parser.parse_args()
 86 | 
 87 |     dt = args.dt
 88 |     save_sub = '{}+alpha{}+gamma{}+testSz{}+holdSz{}+trainBatch{}+baselr{}+robustlr{}+T{}+stepType{}+testStepType{}+seed{}+dt{}'.format(
 89 |         args.env, args.alpha, args.gamma, args.testSetSz, args.holdSetSz,
 90 |         args.trainBatchSz, args.baseLR, args.robustLR, args.T,
 91 |         args.stepType, args.testStepType, args.envRandomSeed, dt)
 92 |     if args.save is not None:
 93 |         save = os.path.join('results', '{}+{}'.format(args.save, save_sub))
 94 |     else:
 95 |         save = os.path.join('results', save_sub)
 96 |     trained_model_dir = os.path.join(save, 'trained_models')
 97 |     if not os.path.exists(trained_model_dir):
 98 |         os.makedirs(trained_model_dir)
 99 |     setproctitle.setproctitle(save_sub)
100 | 
101 |     device = torch.device('cuda:%d' % args.gpu if torch.cuda.is_available() else 'cpu')
102 |     if torch.cuda.is_available():
103 |         torch.cuda.set_device(args.gpu)
104 | 
105 |     # Setup
106 |     isD0 = (args.env == 'random_nldi-d0') or (args.env == 'quadrotor')  # no u dependence in disturbance bound
107 |     problem_type = 'nldi'
108 |     if 'random_nldi' in args.env:
109 |         env = RandomNLDIEnv(isD0=isD0, random_seed=args.envRandomSeed, device=device)
110 |     elif args.env == 'random_pldi_env':
111 |         env = RandomPLDIEnv(random_seed=args.envRandomSeed, device=device)
112 |         problem_type = 'pldi'
113 |     elif args.env == 'random_hinf_env':
114 |         env = RandomHinfEnv(T=args.T, random_seed=args.envRandomSeed, device=device)
115 |         problem_type = 'hinf'
116 |     elif args.env == 'cartpole':
117 |         env = CartPoleEnv(random_seed=args.envRandomSeed, device=device)
118 |     elif args.env == 'quadrotor':
119 |         env = QuadrotorEnv(random_seed=args.envRandomSeed, device=device)
120 |     elif args.env == 'microgrid':
121 |         env = MicrogridEnv(random_seed=args.envRandomSeed, device=device)
122 |     else:
123 |         raise ValueError('No environment named %s' % args.env)
124 |     evaluate_dir = args.evaluate
125 |     evaluate = evaluate_dir is not None
126 | 
127 |     # Test and holdout set of states
128 |     torch.manual_seed(17)
129 |     x_test = env.gen_states(num_states=args.testSetSz, device=device)
130 |     x_hold = env.gen_states(num_states=args.holdSetSz, device=device)
131 |     num_episode_steps = int(args.T / dt)
132 | 
133 |     if problem_type == 'nldi':
134 |         A, B, G, C, D, Q, R = env.get_nldi_linearization()
135 |         state_dim = A.shape[0]
136 |         action_dim = B.shape[1]
137 | 
138 |         # Get LQR solutions
139 |         Kct, Pct = get_lqr_tensors(A, B, Q, R, args.alpha, device)
140 | 
141 |         Kr, Sr = get_robust_lqr_sol(*(v.cpu().numpy() for v in (A, B, G, C, D, Q, R)), args.alpha)
142 |         Krt = torch.tensor(Kr, device=device, dtype=TORCH_DTYPE)
143 |         Prt = torch.tensor(np.linalg.inv(Sr), device=device, dtype=TORCH_DTYPE)
144 |         stable_projection = pm.StableNLDIProjection(Prt, A, B, G, C, D, args.alpha, isD0)
145 | 
146 |         disturb_model = dm.MultiNLDIDisturbModel(x_test.shape[0], C, D, state_dim, action_dim, env.wp)
147 |         disturb_model.to(device=device, dtype=TORCH_DTYPE)
148 | 
149 |     elif problem_type == 'pldi':
150 |         A, B, Q, R = env.get_pldi_linearization()
151 |         state_dim = A.shape[1]
152 |         action_dim = B.shape[2]
153 | 
154 |         # Get LQR solutions
155 |         Kct, Pct = get_lqr_tensors(A.mean(0), B.mean(0), Q, R, args.alpha, device)
156 | 
157 |         Kr, 
Sr = get_robust_pldi_policy(*(v.cpu().numpy() for v in (A, B, Q, R)), args.alpha) 158 | Krt = torch.tensor(Kr, device=device, dtype=TORCH_DTYPE) 159 | Prt = torch.tensor(np.linalg.inv(Sr), device=device, dtype=TORCH_DTYPE) 160 | stable_projection = pm.StablePLDIProjection(Prt, A, B) 161 | 162 | disturb_model = dm.MultiPLDIDisturbModel(x_test.shape[0], state_dim, action_dim, env.L) 163 | disturb_model.to(device=device, dtype=TORCH_DTYPE) 164 | 165 | elif problem_type == 'hinf': 166 | A, B, G, Q, R = env.get_hinf_linearization() 167 | state_dim = A.shape[0] 168 | action_dim = B.shape[1] 169 | 170 | # Get LQR solutions 171 | Kct, Pct = get_lqr_tensors(A, B, Q, R, args.alpha, device) 172 | 173 | Kr, Sr, mu = get_robust_hinf_policy(*(v.cpu().numpy() for v in (A, B, G, Q, R)), args.alpha, args.gamma) 174 | Krt = torch.tensor(Kr, device=device, dtype=TORCH_DTYPE) 175 | Prt = torch.tensor(np.linalg.inv(Sr), device=device, dtype=TORCH_DTYPE) 176 | stable_projection = pm.StableHinfProjection(Prt, A, B, G, Q, R, args.alpha, args.gamma, 1/mu) 177 | 178 | disturb_model = dm.MultiHinfDisturbModel(x_test.shape[0], state_dim, action_dim, env.wp, args.T) 179 | disturb_model.to(device=device, dtype=TORCH_DTYPE) 180 | 181 | else: 182 | raise ValueError('No problem type named %s' % problem_type) 183 | 184 | adv_disturb_model = dm.MBAdvDisturbModel(env, None, disturb_model, dt, horizon=num_episode_steps//5, update_freq=num_episode_steps//20) 185 | env.adversarial_disturb_f = adv_disturb_model 186 | 187 | ########################################################### 188 | # LQR baselines 189 | ########################################################### 190 | 191 | ### Vanilla LQR (i.e., non-robust, exponentially stable) 192 | pi_custom_lqr = lambda x: x @ Kct.T 193 | adv_disturb_model.set_policy(pi_custom_lqr) 194 | 195 | custom_lqr_perf = eval_model(x_test, pi_custom_lqr, env, 196 | step_type=args.testStepType, T=args.T, dt=dt) 197 | write_results(custom_lqr_perf, 'LQR', save) 198 | custom_lqr_perf = eval_model(x_test, pi_custom_lqr, env, 199 | step_type=args.testStepType, T=args.T, dt=dt, adversarial=True) 200 | write_results(custom_lqr_perf, 'LQR-adv', save) 201 | 202 | ### Robust LQR 203 | pi_robust_lqr = lambda x: x @ Krt.T 204 | adv_disturb_model.set_policy(pi_robust_lqr) 205 | 206 | robust_lqr_perf = eval_model(x_test, pi_robust_lqr, env, 207 | step_type=args.testStepType, T=args.T, dt=dt) 208 | write_results(robust_lqr_perf, 'Robust LQR', save) 209 | robust_lqr_perf = eval_model(x_test, pi_robust_lqr, env, 210 | step_type=args.testStepType, T=args.T, dt=dt, adversarial=True) 211 | write_results(robust_lqr_perf, 'Robust LQR-adv', save) 212 | 213 | 214 | ########################################################### 215 | # Model-based planning methods 216 | ########################################################### 217 | 218 | ### Non-robust MBP (starting with robust LQR solution) 219 | pi_mbp = pm.MBPPolicy(Krt, state_dim, action_dim) 220 | pi_mbp.to(device=device, dtype=TORCH_DTYPE) 221 | adv_disturb_model.set_policy(pi_mbp) 222 | 223 | if evaluate: 224 | pi_mbp.load_state_dict(torch.load(os.path.join(evaluate_dir, 'mbp.pt'))) 225 | else: 226 | pi_mbp_dict, train_losses, hold_losses, test_losses, test_losses_adv, stop_epoch = \ 227 | train(pi_mbp, x_test, x_hold, env, 228 | lr=args.baseLR, batch_size=args.trainBatchSz, epochs=args.epochs, T=args.T, dt=dt, step_type=args.stepType, 229 | test_frequency=args.test_frequency, save_dir=save, model_name='mbp', device=device) 230 | save_results(train_losses, 
hold_losses, test_losses, test_losses_adv, save, 'mbp', pi_mbp_dict, epoch=stop_epoch, 231 | is_final=True) 232 | torch.save(pi_mbp_dict, os.path.join(trained_model_dir, 'mbp.pt')) 233 | 234 | pi_mbp_perf = eval_model(x_test, pi_mbp, env, 235 | step_type=args.testStepType, T=args.T, dt=dt) 236 | write_results(pi_mbp_perf, 'MBP', save) 237 | pi_mbp_perf = eval_model(x_test, pi_mbp, env, 238 | step_type=args.testStepType, T=args.T, dt=dt, adversarial=True) 239 | write_results(pi_mbp_perf, 'MBP-adv', save) 240 | 241 | 242 | ### Robust MBP (starting with robust LQR solution) 243 | pi_robust_mbp = pm.StablePolicy(pm.MBPPolicy(Krt, state_dim, action_dim), stable_projection) 244 | pi_robust_mbp.to(device=device, dtype=TORCH_DTYPE) 245 | adv_disturb_model.set_policy(pi_robust_mbp) 246 | 247 | if evaluate: 248 | pi_robust_mbp.load_state_dict(torch.load(os.path.join(evaluate_dir, 'robust_mbp.pt'))) 249 | else: 250 | pi_robust_mbp_dict, train_losses, hold_losses, test_losses, test_losses_adv, stop_epoch = \ 251 | train(pi_robust_mbp, x_test, x_hold, env, 252 | lr=args.robustLR, batch_size=args.trainBatchSz, epochs=args.epochs, T=args.T, dt=dt, step_type=args.stepType, 253 | test_frequency=args.test_frequency, save_dir=save, model_name='robust_mbp', device=device) 254 | save_results(train_losses, hold_losses, test_losses, test_losses_adv, save, 'robust_mbp', pi_robust_mbp_dict, epoch=stop_epoch, 255 | is_final=True) 256 | torch.save(pi_robust_mbp_dict, os.path.join(trained_model_dir, 'robust_mbp.pt')) 257 | 258 | pi_robust_mbp_perf = eval_model(x_test, pi_robust_mbp, env, 259 | step_type=args.testStepType, T=args.T, dt=dt) 260 | write_results(pi_robust_mbp_perf, 'Robust MBP', save) 261 | pi_robust_mbp_perf = eval_model(x_test, pi_robust_mbp, env, 262 | step_type=args.testStepType, T=args.T, dt=dt, adversarial=True) 263 | write_results(pi_robust_mbp_perf, 'Robust MBP-adv', save) 264 | 265 | 266 | ########################################################### 267 | # RL methods 268 | ########################################################### 269 | 270 | if 'random_nldi' in args.env: 271 | if isD0: 272 | rmax = 1000 273 | else: 274 | rmax = 1000 275 | elif args.env == 'random_pldi_env': 276 | rmax = 10 277 | elif args.env == 'random_hinf_env': 278 | rmax = 1000 279 | elif args.env == 'cartpole': 280 | rmax = 10 281 | elif args.env == 'quadrotor': 282 | rmax = 1000 283 | elif args.env == 'microgrid': 284 | rmax = 10 285 | else: 286 | raise ValueError('No environment named %s' % args.env) 287 | 288 | rl_args = arguments.get_args() 289 | linear_controller_K = Krt 290 | linear_controller_P = Prt 291 | linear_transform = lambda u, x: u + x @ linear_controller_K.T 292 | 293 | 294 | ### Vanilla and robust PPO 295 | base_ppo_perfs = [] 296 | base_ppo_adv_perfs = [] 297 | robust_ppo_perfs = [] 298 | robust_ppo_adv_perfs = [] 299 | for seed in range(1): 300 | for robust in [False, True]: 301 | torch.manual_seed(seed) 302 | 303 | if robust: 304 | # stable_projection = pm.StableNLDIProjection(linear_controller_P, A, B, G, C, D, args.alpha, isD0=isD0) 305 | action_transform = lambda u, x: stable_projection.project_action(linear_transform(u, x), x) 306 | else: 307 | action_transform = linear_transform 308 | 309 | envs = RLWrapper(env, state_dim, action_dim, gamma=rl_args.gamma, 310 | dt=dt, rmax=rmax, step_type='RK4', action_transform=action_transform, 311 | num_envs=rl_args.num_processes, device=device) 312 | eval_envs = RLWrapper(env, state_dim, action_dim, gamma=rl_args.gamma, 313 | dt=dt, rmax=rmax, step_type='RK4', 
action_transform=action_transform, 314 | num_envs=args.testSetSz, device=device) 315 | 316 | actor_critic = Policy( 317 | envs.observation_space.shape, 318 | envs.action_space, 319 | base_kwargs={'recurrent': False}) 320 | actor_critic.to(device=device, dtype=TORCH_DTYPE) 321 | agent = PPO( 322 | actor_critic, 323 | rl_args.clip_param, 324 | rl_args.ppo_epoch, 325 | rl_args.num_mini_batch, 326 | rl_args.value_loss_coef, 327 | rl_args.entropy_coef, 328 | lr=rl_args.lr, 329 | eps=rl_args.rms_prop_eps, 330 | max_grad_norm=rl_args.max_grad_norm, 331 | use_linear_lr_decay=rl_args.use_linear_lr_decay) 332 | rollouts = RolloutStorage(num_episode_steps, rl_args.num_processes, 333 | envs.observation_space.shape, envs.action_space, 334 | actor_critic.recurrent_hidden_state_size) 335 | 336 | ppo_pi = lambda x: action_transform(actor_critic.act(x, None, None, deterministic=True)[1], x) 337 | adv_disturb_model.set_policy(ppo_pi) 338 | 339 | if evaluate: 340 | actor_critic.load_state_dict(torch.load(os.path.join(evaluate_dir, 341 | 'robust_ppo.pt' if robust else 'ppo.pt'))) 342 | else: 343 | hold_costs, test_costs, adv_test_costs =\ 344 | trainer.train(agent, envs, rollouts, device, rl_args, 345 | eval_envs=eval_envs, x_hold=x_hold, x_test=x_test, num_episode_steps=num_episode_steps, 346 | save_dir=os.path.join(save, 'robust_ppo' if robust else 'ppo'), 347 | save_extension='%d' % seed) 348 | save_results(np.zeros_like(hold_costs), hold_costs, test_costs, adv_test_costs, save, 349 | 'robust_ppo' if robust else 'ppo', actor_critic.state_dict(), 350 | epoch=rl_args.num_env_steps, is_final=True) 351 | torch.save(actor_critic.state_dict(), os.path.join(trained_model_dir, 352 | 'robust_ppo.pt' if robust else 'ppo.pt')) 353 | 354 | ppo_perf = eval_model(x_test, ppo_pi, env, 355 | step_type=args.testStepType, T=args.T, dt=dt) 356 | ppo_adv_perf = eval_model(x_test, ppo_pi, env, 357 | step_type=args.testStepType, T=args.T, dt=dt, adversarial=True) 358 | 359 | if robust: 360 | robust_ppo_perfs.append(ppo_perf.item()) 361 | robust_ppo_adv_perfs.append(ppo_adv_perf.item()) 362 | else: 363 | base_ppo_perfs.append(ppo_perf.item()) 364 | base_ppo_adv_perfs.append(ppo_adv_perf.item()) 365 | 366 | write_results(base_ppo_perfs, 'PPO', save) 367 | write_results(robust_ppo_perfs, 'Robust PPO', save) 368 | write_results(base_ppo_adv_perfs, 'PPO-adv', save) 369 | write_results(robust_ppo_adv_perfs, 'Robust PPO-adv', save) 370 | 371 | 372 | # RARL PPO baseline 373 | adv_ppo_perfs = [] 374 | adv_ppo_adv_perfs = [] 375 | seed = 0 376 | torch.manual_seed(seed) 377 | 378 | action_transform = linear_transform 379 | 380 | envs = RLWrapper(env, state_dim, action_dim, gamma=rl_args.gamma, 381 | dt=dt, rmax=rmax, step_type='RK4', action_transform=action_transform, 382 | num_envs=rl_args.num_processes, device=device, rarl=True) 383 | eval_envs = RLWrapper(env, state_dim, action_dim, gamma=rl_args.gamma, 384 | dt=dt, rmax=rmax, step_type='RK4', action_transform=action_transform, 385 | num_envs=args.testSetSz, device=device) 386 | 387 | protagornist_ac = Policy( 388 | envs.observation_space.shape, 389 | envs.action_space, 390 | base_kwargs={'recurrent': False}) 391 | protagornist_ac.to(device=device, dtype=TORCH_DTYPE) 392 | adversary_ac = Policy( 393 | envs.observation_space.shape, 394 | envs.disturb_space, 395 | base_kwargs={'recurrent': False}) 396 | adversary_ac.to(device=device, dtype=TORCH_DTYPE) 397 | agent = RARLPPO( 398 | protagornist_ac, 399 | adversary_ac, 400 | rl_args.clip_param, 401 | rl_args.ppo_epoch, 402 | 
rl_args.num_mini_batch, 403 | rl_args.value_loss_coef, 404 | rl_args.entropy_coef, 405 | lr=rl_args.lr, 406 | eps=rl_args.rms_prop_eps, 407 | max_grad_norm=rl_args.max_grad_norm, 408 | use_linear_lr_decay=rl_args.use_linear_lr_decay) 409 | action_space = spaces.Box(low=0, high=1, 410 | shape=(envs.action_space.shape[0]+envs.disturb_space.shape[0],), dtype=NUMPY_DTYPE) 411 | rollouts = RolloutStorage(num_episode_steps, rl_args.num_processes, 412 | envs.observation_space.shape, action_space, 413 | protagornist_ac.recurrent_hidden_state_size + adversary_ac.recurrent_hidden_state_size, 414 | rarl=True) 415 | 416 | ppo_pi = lambda x: action_transform(protagornist_ac.act(x, None, None, deterministic=True)[1], x) 417 | adv_disturb_model.set_policy(ppo_pi) 418 | 419 | if evaluate: 420 | agent.load(evaluate_dir) 421 | else: 422 | hold_costs, test_costs, adv_test_costs = \ 423 | trainer.train(agent, envs, rollouts, device, rl_args, 424 | eval_envs=eval_envs, x_hold=x_hold, x_test=x_test, 425 | num_episode_steps=num_episode_steps, 426 | save_dir=os.path.join(save, 'rarl_ppo'), 427 | save_extension='%d' % seed) 428 | save_results(np.zeros_like(hold_costs), hold_costs, test_costs, adv_test_costs, save, 429 | 'rarl_ppo', protagornist_ac.state_dict(), 430 | epoch=rl_args.num_env_steps, is_final=True) 431 | agent.save(trained_model_dir) 432 | env.disturb_f.disturbance = None 433 | 434 | ppo_perf = eval_model(x_test, ppo_pi, env, 435 | step_type=args.testStepType, T=args.T, dt=dt) 436 | ppo_adv_perf = eval_model(x_test, ppo_pi, env, 437 | step_type=args.testStepType, T=args.T, dt=dt, adversarial=True) 438 | 439 | adv_ppo_perfs.append(ppo_perf.item()) 440 | adv_ppo_adv_perfs.append(ppo_adv_perf.item()) 441 | 442 | write_results(adv_ppo_perfs, 'RARL PPO', save) 443 | write_results(adv_ppo_adv_perfs, 'RARL PPO-adv', save) 444 | 445 | 446 | ########################################################### 447 | # MPC baselines 448 | ########################################################### 449 | 450 | ### Robust MPC (not implemented for H_infinity settings) 451 | if problem_type != 'hinf': 452 | if problem_type == 'nldi': 453 | robust_mpc_model = rmpc.RobustNLDIMPC(A, B, G, C, D, Q, R, Krt, device) 454 | else: 455 | robust_mpc_model = rmpc.RobustPLDIMPC(A, B, Q, R, Krt, device) 456 | 457 | pi_robust_mpc = robust_mpc_model.get_action 458 | adv_disturb_model.set_policy(pi_robust_mpc) 459 | 460 | robust_mpc_perf = eval_model(x_test, pi_robust_mpc, env, 461 | step_type=args.testStepType, T=args.T, dt=dt, adversarial=True) 462 | write_results(robust_mpc_perf, 'Robust MPC-adv', save) 463 | 464 | 465 | 466 | def get_lqr_tensors(At, Bt, Qt, Rt, alpha, device): 467 | K, S = get_custom_lqr_sol(*(v.cpu().numpy() for v in (At, Bt, Qt, Rt)), alpha) 468 | Kt = torch.tensor(K, device=device, dtype=TORCH_DTYPE) 469 | Pt = torch.tensor(np.linalg.inv(S), device=device, dtype=TORCH_DTYPE) 470 | 471 | return Kt, Pt 472 | 473 | 474 | def get_custom_lqr_sol(A, B, Q, R, alpha): 475 | n, m = B.shape 476 | S = cp.Variable((n, n), symmetric=True) 477 | Y = cp.Variable((m, n)) 478 | 479 | R_sqrt = la.sqrtm(R) 480 | f = cp.trace(S @ Q) + cp.matrix_frac(Y.T @ R_sqrt, S) 481 | 482 | # Exponential stability constraints from LMI book 483 | cons = [S >> np.eye(n)] # make LMI non-homogeneous 484 | cons += [A @ S + S @ A.T + B @ Y + Y.T @ B.T << -alpha * S] 485 | 486 | cp.Problem(cp.Minimize(f), cons).solve() 487 | K = np.linalg.solve(S.value, Y.value.T).T 488 | S = S.value 489 | 490 | return np.array(K), np.array(S) 491 | 492 | 493 | def 
get_robust_lqr_sol(A, B, G, C, D, Q, R, alpha): 494 | n, m = B.shape 495 | wq = C.shape[0] 496 | 497 | S = cp.Variable((n, n), symmetric=True) 498 | Y = cp.Variable((m, n)) 499 | mu = cp.Variable() 500 | 501 | R_sqrt = la.sqrtm(R) 502 | f = cp.trace(S @ Q) + cp.matrix_frac(Y.T @ R_sqrt, S) 503 | 504 | cons_mat = cp.bmat(( 505 | (A @ S + S @ A.T + cp.multiply(mu, G @ G.T) + B @ Y + Y.T @ B.T + alpha * S, S @ C.T + Y.T @ D.T), 506 | (C @ S + D @ Y, -cp.multiply(mu, np.eye(wq))) 507 | )) 508 | cons = [S >> 0, mu >= 1e-2] + [cons_mat << 0] 509 | 510 | try: 511 | prob = cp.Problem(cp.Minimize(f), cons) 512 | prob.solve(solver=cp.SCS) 513 | except cp.error.SolverError as e: 514 | raise ValueError('Solver failed with error: %s \n Try another environment seed' % e) 515 | K = np.linalg.solve(S.value, Y.value.T).T 516 | 517 | return K, S.value 518 | 519 | 520 | def get_robust_pldi_policy(A, B, Q, R, alpha): 521 | L, n, m = B.shape 522 | S = cp.Variable((n, n), symmetric=True) 523 | Y = cp.Variable((m, n)) 524 | 525 | R_sqrt = la.sqrtm(R) 526 | 527 | f = cp.trace(S @ Q) + cp.matrix_frac(Y.T @ R_sqrt, S) 528 | cons = [S >> np.eye(n)] + [A[i, :, :] @ S + B[i, :, :] @ Y + S @ A[i, :, :].T + Y.T @ B[i, :, :].T << -alpha * S for i in range(A.shape[0])] 529 | prob = cp.Problem(cp.Minimize(f), cons) 530 | prob.solve(solver=cp.MOSEK) 531 | K = np.linalg.solve(S.value, Y.value.T).T 532 | return K, S.value 533 | 534 | 535 | def get_robust_hinf_policy(A, B, G, Q, R, alpha, gamma): 536 | n, m = B.shape 537 | wq = G.shape[1] 538 | 539 | S = cp.Variable((n, n), symmetric=True) 540 | Y = cp.Variable((m, n)) 541 | mu = cp.Variable() 542 | 543 | Q_sqrt = la.sqrtm(Q) 544 | R_sqrt = la.sqrtm(R) 545 | f = cp.trace(S @ Q) + cp.matrix_frac(Y.T @ R_sqrt, S) 546 | 547 | cons_mat = cp.bmat([[S @ A.T + A @ S + Y.T @ B.T + B @ Y + alpha * S + (mu / gamma ** 2) * G @ G.T, 548 | cp.bmat([[S @ Q_sqrt, Y.T @ R_sqrt]])], 549 | [cp.bmat([[Q_sqrt @ S], [R_sqrt @ Y]]), -mu * np.eye(m + n)]]) 550 | cons = [S >> np.eye(n), mu >= 0] + [cons_mat << -1e-3 * np.eye(n+m+n)] 551 | 552 | try: 553 | prob = cp.Problem(cp.Minimize(f), cons) 554 | prob.solve(solver=cp.SCS) #cp.MOSEK) 555 | except cp.error.SolverError as e: 556 | raise ValueError('Solver failed with error: %s \n Try another environment seed' % e) 557 | K = np.linalg.solve(S.value, Y.value.T).T 558 | 559 | assert np.all(np.linalg.eigvals(S.value) > 0) 560 | assert np.all(mu.value > 0) 561 | assert np.all(np.linalg.eigvals(cons_mat.value) <= 0) 562 | 563 | return K, S.value, mu.value 564 | 565 | 566 | def eval_model(x, pi, env, step_type='euler', T=10, dt=0.05, adversarial=False): 567 | if adversarial: 568 | env.adversarial_disturb_f.reset() 569 | loss = 0 570 | # maxes = torch.ones(6, dtype=TORCH_DTYPE) * -np.inf 571 | # mins = torch.ones(6, dtype=TORCH_DTYPE) * np.inf 572 | for t in tqdm.tqdm(range(int(T / dt)), desc='Evaluating agent%s' % (' adversarial' if adversarial else '')): 573 | u = pi(x) 574 | if adversarial: 575 | env.adversarial_disturb_f.update(x) 576 | x, cost = env.step(x, u, t, step_type=step_type, dt=dt, adversarial=adversarial) 577 | loss += cost 578 | 579 | # maxes = torch.max(maxes, torch.max(x, dim=0)[0]) 580 | # mins = torch.min(mins, torch.min(x, dim=0)[0]) 581 | return loss.mean() 582 | 583 | 584 | def train(model, x_test, x_hold, env, batch_size=20, epochs=1000, test_frequency=10, lr=1e-4, T=1, 585 | dt=0.05, step_type='euler', save_dir=None, model_name=None, device=None, hinf_loss=False): 586 | opt = optim.Adam(model.parameters(), lr=lr) 587 | losses = [] 
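    # Added commentary: this is a differentiable-simulation training loop. Each epoch
    # samples a fresh batch of initial states, rolls the policy out for T/dt steps
    # through env.step, and backpropagates the accumulated cost through the dynamics;
    # every test_frequency epochs the policy is also rolled out on the fixed
    # holdout/test sets, with and without the adversarial disturbance.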
588 | hold_losses = [] 589 | test_losses = [] 590 | test_losses_adv = [] 591 | num_episode_steps = int(T / dt) 592 | 593 | for i in range(epochs+1): 594 | opt.zero_grad() 595 | x = env.gen_states(batch_size, device=device) 596 | loss = 0 597 | for t in range(num_episode_steps): 598 | # train 599 | model.train() 600 | u = model(x) 601 | x, cost = env.step(x, u, t, dt=dt, step_type=step_type) 602 | loss += cost 603 | 604 | losses.append(loss.mean().item()) 605 | print('Epoch {}. Loss: mean/median {:.3f}/{:.3f}, min/max {:.3f}/{:.3f}' 606 | .format(i, torch.mean(loss), torch.median(loss), torch.min(loss), torch.max(loss))) 607 | 608 | loss.mean().backward() 609 | opt.step() 610 | 611 | if i % test_frequency == 0: 612 | print('Testing...') 613 | env.adversarial_disturb_f.reset() 614 | xh = x_hold.detach() 615 | xt = x_test.detach() 616 | xta = x_test.detach() 617 | hold_loss = 0 618 | test_loss = 0 619 | test_loss_adv = 0 620 | hold_disturb_norm = 0 621 | test_disturb_norm = 0 622 | test_disturb_norm_adv = 0 623 | for t in range(num_episode_steps): 624 | # holdout 625 | model.eval() 626 | u_hold = model(xh) 627 | xh, cost_h = env.step(xh, u_hold, t, dt=dt, step_type=step_type) 628 | hold_loss += cost_h 629 | if hinf_loss: 630 | hold_disturb_norm += torch.norm(env.disturb, p=2, dim=1) 631 | 632 | # test 633 | model.eval() 634 | u_test = model(xt) 635 | xt, cost_t = env.step(xt, u_test, t, dt=dt, step_type=step_type) 636 | test_loss += cost_t 637 | if hinf_loss: 638 | test_disturb_norm += torch.norm(env.disturb, p=2, dim=1) 639 | 640 | # test adversarial 641 | env.adversarial_disturb_f.update(xta) 642 | model.eval() 643 | u_test_adv = model(xta) 644 | xta, cost_ta = env.step(xta, u_test_adv, t, dt=dt, step_type=step_type, adversarial=True) 645 | test_loss_adv += cost_ta 646 | if hinf_loss: 647 | test_disturb_norm_adv += torch.norm(env.disturb, p=2, dim=1) 648 | 649 | hold_losses.append(hold_loss.mean().item()) 650 | test_losses.append(test_loss.mean().item()) 651 | test_losses_adv.append(test_loss_adv.mean().item()) 652 | 653 | print('Hold Loss: mean/median {:.3f}/{:.3f}, min/max {:.3f}/{:.3f}'.format( 654 | torch.mean(hold_loss), torch.median(hold_loss), 655 | torch.min(hold_loss), torch.max(hold_loss))) 656 | print('Test Loss: mean/median {:.3f}/{:.3f}, min/max {:.3f}/{:.3f}'.format( 657 | torch.mean(test_loss), torch.median(test_loss), 658 | torch.min(test_loss), torch.max(test_loss))) 659 | print('Test Loss Adv: mean/median {:.3f}/{:.3f}, min/max {:.3f}/{:.3f}'.format( 660 | torch.mean(test_loss_adv), torch.median(test_loss_adv), 661 | torch.min(test_loss_adv), torch.max(test_loss_adv))) 662 | print('') 663 | 664 | # Save intermediate results 665 | if i % 100 == 0: 666 | save_results(np.array(losses), np.array(hold_losses), np.array(test_losses), np.array(test_losses_adv), 667 | save_dir, model_name, model.state_dict(), epoch=i) 668 | 669 | return model.state_dict(), losses, hold_losses, test_losses, test_losses_adv, i 670 | 671 | 672 | def save_results(train_losses, hold_losses, test_losses, test_losses_adv, 673 | save_dir, model_name, model_dict, epoch, is_final=False): 674 | model_save_dir = os.path.join(save_dir, model_name) 675 | if not os.path.exists(model_save_dir): 676 | os.makedirs(model_save_dir) 677 | 678 | np.save(os.path.join(model_save_dir, 'train_losses.npy'), np.array(train_losses)) 679 | np.save(os.path.join(model_save_dir, 'hold_losses.npy'), np.array(hold_losses)) 680 | np.save(os.path.join(model_save_dir, 'test_losses.npy'), np.array(test_losses)) 681 | 
np.save(os.path.join(model_save_dir, 'test_losses_adv.npy'), np.array(test_losses_adv)) 682 | torch.save(model_dict, os.path.join(model_save_dir, 'model-{}.pt'.format(epoch))) 683 | if is_final: 684 | torch.save(model_dict, os.path.join(model_save_dir, 'model.pt')) 685 | 686 | 687 | def write_results(test_loss, model_name, save_dir): 688 | if not os.path.exists(save_dir): os.makedirs(save_dir) 689 | result_str = '{}: {}\n'.format(model_name, test_loss) 690 | print(result_str) 691 | with open(os.path.join(save_dir, 'results.txt'), 'a') as f: 692 | f.write(result_str) 693 | 694 | 695 | if __name__ == '__main__': 696 | main() 697 | --------------------------------------------------------------------------------
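A quick way to sanity-check the second-order cone projection layer in policy_models.py is to project a random batch of actions and verify feasibility. The sketch below is illustrative only and is not part of the repository: it assumes the float64 default from constants.py, a PyTorch version that still provides torch.symeig (used inside SOCProjFast), and hypothetical problem sizes. SOCProjFast() returns an autograd Function that, given (u, A, b, c, d), finds the point closest to u satisfying ||A u + b||_2 <= c^T u + d, with gradients supplied by the implicit differentiation in its backward pass.

    import torch
    from policy_models import SOCProjFast

    torch.manual_seed(0)
    batch, m, k = 4, 3, 2  # hypothetical batch size, action dim, rows of A

    u = torch.randn(batch, m, dtype=torch.float64)    # unconstrained actions
    A = torch.randn(batch, k, m, dtype=torch.float64)
    b = torch.randn(batch, k, dtype=torch.float64)
    c = torch.randn(batch, m, dtype=torch.float64)
    d = 5.0 * torch.ones(batch, dtype=torch.float64)  # generous offset keeps the cone nonempty

    u_proj = SOCProjFast()(u, A, b, c, d)

    # Feasibility check: ||A u + b|| <= c^T u + d should hold up to the solver tolerance
    lhs = torch.norm(A.bmm(u_proj.unsqueeze(-1)).squeeze(-1) + b, dim=1)
    rhs = (c * u_proj).sum(-1) + d
    print('max constraint violation:', (lhs - rhs).max().item())

Internally, the layer stacks A and c into G and b and d into h, so the constraint reads G u + h in the norm cone {(z, t) : ||z|| <= t}; the forward pass runs accelerated projected gradient ascent on the dual and recovers the projection as u + G^T lambda.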