├── cap-planet ├── safety_envs.py ├── utils.py ├── memory.py ├── planner.py ├── env.py ├── models.py └── run_cap_planet.py ├── README.md ├── cap-pets ├── utils.py ├── run_cap_pets.py ├── models.py └── ccem.py └── environment.yml /cap-planet/safety_envs.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple, Any 2 | import numpy as np 3 | 4 | import gym 5 | import dm_env 6 | from dm_env import StepType 7 | from dm_control import suite 8 | 9 | CONSTRAINED_GYM_ENVS = ['CarRacingSkiddingConstrained-v0'] 10 | CONSTRAINED_CONTROL_SUITE_ENVS = ['cartpole-swingup-constrained'] 11 | 12 | class TimeStepWithCost(NamedTuple): 13 | step_type: Any 14 | reward: Any 15 | discount: Any 16 | observation: Any 17 | cost: Any 18 | 19 | def first(self) -> bool: 20 | return self.step_type == StepType.FIRST 21 | 22 | def mid(self) -> bool: 23 | return self.step_type == StepType.MID 24 | 25 | def last(self) -> bool: 26 | return self.step_type == StepType.LAST 27 | 28 | class CartpoleConstrainedWrapper(dm_env.Environment): 29 | def __init__(self, env): 30 | self._env = env 31 | 32 | def step(self, action): 33 | state = self._env.step(action) 34 | pos = state.observation["position"][1:] 35 | angle = np.degrees(np.arctan2(*pos)) 36 | cost = 1 if 20 < angle and angle < 50 else 0 37 | return TimeStepWithCost(*state, cost) 38 | 39 | def reset(self): 40 | return self._env.reset() 41 | 42 | def observation_spec(self): 43 | return self._env.observation_spec() 44 | 45 | def action_spec(self): 46 | return self._env.action_spec() 47 | 48 | def __getattr__(self, name): 49 | return getattr(self._env, name) 50 | 51 | 52 | def load_suite_env(env_name, seed): 53 | spec = env_name.split('-') 54 | domain, task = spec[:2] 55 | is_safety_constrained = len(spec) > 2 and spec[2] == "constrained" 56 | 57 | env = suite.load(domain_name=domain, task_name=task, task_kwargs={'random': seed}) 58 | 59 | if is_safety_constrained and domain == "cartpole": 60 | env = CartpoleConstrainedWrapper(env) 61 | 62 | return env 63 | 64 | class SkiddingConstrainedCarRacing(gym.Wrapper): 65 | def __init__(self): 66 | env = gym.make("CarRacing-v0") 67 | super(SkiddingConstrainedCarRacing, self).__init__(env) 68 | 69 | def step(self, action): 70 | obs, rew, done, info = self.env.step(action) 71 | cost = 0 72 | car = self.env.car 73 | for wheel in car.wheels: 74 | if wheel.skid_start is not None or wheel.skid_particle is not None: 75 | cost = 1 76 | info["cost"] = cost 77 | return obs, rew, done, info 78 | 79 | gym.register("CarRacingSkiddingConstrained-v0", entry_point="safety_envs:SkiddingConstrainedCarRacing") -------------------------------------------------------------------------------- /cap-planet/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import plotly 5 | from plotly.graph_objs import Scatter, Violin 6 | from plotly.graph_objs.scatter import Line 7 | 8 | 9 | # Plots min, max and mean + standard deviation bars of a population over time 10 | def lineplot(xs, ys_population, title, path='', xaxis='episode'): 11 | max_colour, mean_colour, std_colour, transparent = 'rgb(0, 132, 180)', 'rgb(0, 172, 237)', 'rgba(29, 202, 255, 0.2)', 'rgba(0, 0, 0, 0)' 12 | 13 | if isinstance(ys_population[0], list) or isinstance(ys_population[0], tuple) or isinstance(ys_population[0], np.ndarray): 14 | ys = np.asarray(ys_population, dtype=np.float32) 15 | ys_min, ys_max, ys_mean, ys_std, ys_median = ys.min(1), 
ys.max(1), ys.mean(1), ys.std(1), np.median(ys, 1) 16 | ys_upper, ys_lower = ys_mean + ys_std, ys_mean - ys_std 17 | 18 | trace_max = Scatter(x=xs, y=ys_max, line=Line(color=max_colour, dash='dash'), name='Max') 19 | trace_upper = Scatter(x=xs, y=ys_upper, line=Line(color=transparent), name='+1 Std. Dev.', showlegend=False) 20 | trace_mean = Scatter(x=xs, y=ys_mean, fill='tonexty', fillcolor=std_colour, line=Line(color=mean_colour), name='Mean') 21 | trace_lower = Scatter(x=xs, y=ys_lower, fill='tonexty', fillcolor=std_colour, line=Line(color=transparent), name='-1 Std. Dev.', showlegend=False) 22 | trace_min = Scatter(x=xs, y=ys_min, line=Line(color=max_colour, dash='dash'), name='Min') 23 | trace_median = Scatter(x=xs, y=ys_median, line=Line(color=max_colour), name='Median') 24 | data = [trace_upper, trace_mean, trace_lower, trace_min, trace_max, trace_median] 25 | else: 26 | data = [Scatter(x=xs, y=ys_population, line=Line(color=mean_colour))] 27 | plotly.offline.plot({ 28 | 'data': data, 29 | 'layout': dict(title=title, xaxis={'title': xaxis}, yaxis={'title': title}) 30 | }, filename=os.path.join(path, title + '.html'), auto_open=False) 31 | 32 | 33 | def violinplot(xs, ys, title, path='', xaxis='cost'): 34 | x = np.concatenate(xs) 35 | y = np.concatenate(ys) 36 | population = Violin(x=x, y=y, points="all") 37 | data = [population] 38 | plotly.offline.plot({ 39 | 'data': data, 40 | 'layout': dict(title=title, xaxis={'title': xaxis}, yaxis={'title': title}) 41 | }, filename=os.path.join(path, title + '.html'), auto_open=False) 42 | 43 | 44 | def write_video(frames, title, path=''): 45 | frames = np.multiply(np.stack(frames, axis=0).transpose(0, 2, 3, 1), 255).clip(0, 255).astype(np.uint8)[:, :, :, ::-1] # VideoWrite expects H x W x C in BGR 46 | _, H, W, _ = frames.shape 47 | writer = cv2.VideoWriter(os.path.join(path, '%s.mp4' % title), cv2.VideoWriter_fourcc(*'mp4v'), 30., (W, H), True) 48 | for frame in frames: 49 | writer.write(frame) 50 | writer.release() 51 | -------------------------------------------------------------------------------- /cap-planet/memory.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from env import postprocess_observation, preprocess_observation_ 4 | 5 | 6 | class ExperienceReplay(): 7 | def __init__(self, size, symbolic_env, observation_size, action_size, bit_depth, device): 8 | self.device = device 9 | self.symbolic_env = symbolic_env 10 | self.size = size 11 | self.observations = np.empty((size, observation_size) if symbolic_env else (size, 3, 64, 64), dtype=np.float32 if symbolic_env else np.uint8) 12 | self.actions = np.empty((size, action_size), dtype=np.float32) 13 | self.rewards = np.empty((size, ), dtype=np.float32) 14 | self.costs = np.empty((size, ), dtype=np.float32) 15 | self.nonterminals = np.empty((size, 1), dtype=np.float32) 16 | self.idx = 0 17 | self.full = False # Tracks if memory has been filled/all slots are valid 18 | self.steps, self.episodes = 0, 0 # Tracks how much experience has been used in total 19 | self.bit_depth = bit_depth 20 | 21 | def append(self, observation, action, reward, cost, done): 22 | if self.symbolic_env: 23 | self.observations[self.idx] = observation.numpy() 24 | else: 25 | self.observations[self.idx] = postprocess_observation(observation.numpy(), self.bit_depth) # Decentre and discretise visual observations (to save memory) 26 | self.actions[self.idx] = action.numpy() 27 | self.rewards[self.idx] = reward 28 | self.costs[self.idx] = 
cost 29 | self.nonterminals[self.idx] = not done 30 | self.idx = (self.idx + 1) % self.size 31 | self.full = self.full or self.idx == 0 32 | self.steps, self.episodes = self.steps + 1, self.episodes + (1 if done else 0) 33 | 34 | # Returns an index for a valid single sequence chunk uniformly sampled from the memory 35 | def _sample_idx(self, L): 36 | valid_idx = False 37 | while not valid_idx: 38 | idx = np.random.randint(0, self.size if self.full else self.idx - L) 39 | idxs = np.arange(idx, idx + L) % self.size 40 | valid_idx = not self.idx in idxs[1:] # Make sure data does not cross the memory index 41 | return idxs 42 | 43 | def _retrieve_batch(self, idxs, n, L): 44 | vec_idxs = idxs.transpose().reshape(-1) # Unroll indices 45 | observations = torch.as_tensor(self.observations[vec_idxs].astype(np.float32)) 46 | if not self.symbolic_env: 47 | preprocess_observation_(observations, self.bit_depth) # Undo discretisation for visual observations 48 | return observations.reshape(L, n, *observations.shape[1:]), self.actions[vec_idxs].reshape(L, n, -1), self.rewards[vec_idxs].reshape(L, n), self.costs[vec_idxs].reshape(L, n), self.nonterminals[vec_idxs].reshape(L, n, 1) 49 | 50 | # Returns a batch of sequence chunks uniformly sampled from the memory 51 | def sample(self, n, L): 52 | batch = self._retrieve_batch(np.asarray([self._sample_idx(L) for _ in range(n)]), n, L) 53 | return [torch.as_tensor(item).to(device=self.device) for item in batch] 54 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Conservative and Adaptive Penalty for Model-Based Safe Reinforcement Learning 2 | 3 | This is the official repository for Conservative and Adaptive Penalty for Model-Based Safe Reinforcement Learning. 4 | We provide the commands to run the PETS and PlaNet experiments included in the paper. This repository is made minimal for ease of experimentation. 
5 |
6 | ## Installation
7 | This repository requires Python 3.6 and PyTorch (version 1.3 or above).
8 | Run the following command to create a conda environment (tested with CUDA 10.2):
9 | ```
10 | conda env create -f environment.yml
11 | ```
12 | ## Experiments
13 |
14 | ### To run the PETS experiments on the HalfCheetah environment used in our ablation study, run:
15 |
16 | ```
17 | cd cap-pets
18 | ```
19 |
20 | **CAP**
21 |
22 | ```
23 | python run_cap_pets.py --algo cem --env HalfCheetah-v3 --cost_lim 152 \
24 | --cost_constrained --penalize_uncertainty --learn_kappa --seed 1
25 | ```
26 |
27 | **CAP with fixed kappa**
28 |
29 | ```
30 | python run_cap_pets.py --algo cem --env HalfCheetah-v3 --cost_lim 152 \
31 | --cost_constrained --penalize_uncertainty --kappa 1.0 --seed 1
32 | ```
33 |
34 | **CCEM**
35 |
36 | ```
37 | python run_cap_pets.py --algo cem --env HalfCheetah-v3 --cost_lim 152 \
38 | --cost_constrained --seed 1
39 | ```
40 |
41 | **CEM**
42 |
43 | ```
44 | python run_cap_pets.py --algo cem --env HalfCheetah-v3 --cost_lim 152 \
45 | --seed 1
46 | ```
47 |
48 | ### The commands for the PlaNet experiment on the CarRacing environment (run from the repository root) are:
49 |
50 | **CAP**
51 |
52 | ```
53 | python cap-planet/run_cap_planet.py --env CarRacingSkiddingConstrained-v0 \
54 | --cost-limit 0 --binary-cost \
55 | --cost-constrained --penalize-uncertainty \
56 | --learn-kappa --penalty-kappa 0.1 \
57 | --id CarRacing-cap --seed 1
58 | ```
59 |
60 | **CAP with fixed kappa**
61 |
62 | ```
63 | python cap-planet/run_cap_planet.py --env CarRacingSkiddingConstrained-v0 \
64 | --cost-limit 0 --binary-cost \
65 | --cost-constrained --penalize-uncertainty \
66 | --penalty-kappa 1.0 \
67 | --id CarRacing-kappa1 --seed 1
68 | ```
69 |
70 | **CCEM**
71 |
72 | ```
73 | python cap-planet/run_cap_planet.py --env CarRacingSkiddingConstrained-v0 \
74 | --cost-limit 0 --binary-cost \
75 | --cost-constrained \
76 | --id CarRacing-ccem --seed 1
77 | ```
78 |
79 | **CEM**
80 |
81 | ```
82 | python cap-planet/run_cap_planet.py --env CarRacingSkiddingConstrained-v0 \
83 | --cost-limit 0 --binary-cost \
84 | --id CarRacing-cem --seed 1
85 | ```
86 |
87 | ## Contact
88 | If you have any questions regarding the code or paper, feel free to contact jasonyma@seas.upenn.edu or open an issue on this repository.
89 |
90 | ## Acknowledgement
91 | This repository contains code adapted from the
92 | following repositories: [PETS](https://github.com/quanvuong/handful-of-trials-pytorch) and
93 | [PlaNet](https://github.com/Kaixhin/PlaNet). We thank the
94 | authors and contributors for open-sourcing their code.
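## Note: how the penalty is applied

During planning, the predicted cost of each candidate action sequence is inflated by the penalty coefficient kappa scaled by the dynamics ensemble's uncertainty before being compared against a per-step cost budget (see the `--penalize_uncertainty` / `--penalize-uncertainty` flags above). Two pieces of this logic live in the code: the conversion of the episode-level cost limit into a per-step budget (`MPCPlanner.set_cost_limit` in `cap-planet/planner.py`) and the gradient update of kappa based on the observed constraint violation (`optimize_kappa` in `cap-pets/ccem.py`, `optimize_penalty_kappa` in `cap-planet/planner.py`). The snippet below is a minimal standalone sketch of those two updates for orientation only; it is not part of the training code, and the numbers are illustrative.

```
import torch

def cost_limit_per_step(cost_limit, cost_discount, action_repeat, max_length):
    # Spread an episode-level cost budget over the planner's decision steps
    # (mirrors MPCPlanner.set_cost_limit in cap-planet/planner.py).
    steps = max_length / action_repeat
    if cost_discount == 1:
        return cost_limit / steps
    return cost_limit * (1 - cost_discount ** action_repeat) / (1 - cost_discount ** max_length)

# Adaptive penalty coefficient: one gradient-ascent step on kappa * (cost - limit),
# so kappa grows after a constraint violation and shrinks otherwise
# (mirrors ConstrainedCEM.optimize_kappa and MPCPlanner.optimize_penalty_kappa).
kappa = torch.tensor([0.1], requires_grad=True)
kappa_optim = torch.optim.Adam([kappa], lr=0.1)

def update_kappa(episode_cost, permissible_cost):
    kappa_loss = -(kappa * (episode_cost - permissible_cost))
    kappa_optim.zero_grad()
    kappa_loss.backward()
    kappa_optim.step()

# Illustrative numbers only (cap-planet defaults: action_repeat=2, max_length=1000).
print(cost_limit_per_step(cost_limit=152, cost_discount=0.99, action_repeat=2, max_length=1000))
update_kappa(episode_cost=180.0, permissible_cost=152.0)  # violation -> kappa increases
print(kappa.item())
```

In the PlaNet planner the uncertainty term is additionally scaled by a fixed multiplier and, for binary costs, the penalized cost is passed through a sigmoid threshold; see `MPCPlanner.forward` in `cap-planet/planner.py` and `ConstrainedCEM._predict_next` in `cap-pets/ccem.py` for the exact usage.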
95 | -------------------------------------------------------------------------------- /cap-pets/utils.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import random 3 | import numpy as np 4 | from operator import itemgetter 5 | 6 | class ReplayMemory(Dataset): 7 | def __init__(self, capacity): 8 | self.capacity = capacity 9 | self.buffer = [] 10 | self.position = 0 11 | 12 | def __len__(self): 13 | return len(self.buffer) 14 | 15 | def __getitem__(self, idx): 16 | return self.buffer[idx] 17 | 18 | def push(self, state, action, reward, next_state, done): 19 | if len(self.buffer) < self.capacity: 20 | self.buffer.append(None) 21 | self.buffer[self.position] = (state, action, reward, next_state, done) 22 | self.position = (self.position + 1) % self.capacity 23 | 24 | def push_batch(self, batch): 25 | if len(self.buffer) < self.capacity: 26 | append_len = min(self.capacity - len(self.buffer), len(batch)) 27 | self.buffer.extend([None] * append_len) 28 | 29 | if self.position + len(batch) < self.capacity: 30 | self.buffer[self.position : self.position + len(batch)] = batch 31 | self.position += len(batch) 32 | else: 33 | self.buffer[self.position : len(self.buffer)] = batch[:len(self.buffer) - self.position] 34 | self.buffer[:len(batch) - len(self.buffer) + self.position] = batch[len(self.buffer) - self.position:] 35 | self.position = len(batch) - len(self.buffer) + self.position 36 | 37 | def sample(self, batch_size): 38 | if batch_size > len(self.buffer): 39 | batch_size = len(self.buffer) 40 | batch = random.sample(self.buffer, int(batch_size)) 41 | state, action, reward, next_state, done = map(np.stack, zip(*batch)) 42 | return state, action, reward, next_state, done 43 | 44 | def sample_all_batch(self, batch_size): 45 | idxes = np.random.randint(0, len(self.buffer), batch_size) 46 | batch = list(itemgetter(*idxes)(self.buffer)) 47 | state, action, reward, next_state, done = map(np.stack, zip(*batch)) 48 | return state, action, reward, next_state, done 49 | 50 | def return_all(self): 51 | state, action, reward, next_state, done = map(np.stack, zip(*self.buffer)) 52 | return state, action, reward, next_state, done 53 | 54 | def __len__(self): 55 | return len(self.buffer) 56 | 57 | class EnvSampler(): 58 | def __init__(self, env, max_path_length=1000): 59 | self.env = env 60 | 61 | self.path_length = 0 62 | self.current_state = None 63 | self.max_path_length = max_path_length 64 | self.path_rewards = [] 65 | self.sum_reward = 0 66 | 67 | def sample(self, agent, eval_t=False, random_explore=False): 68 | if self.current_state is None: 69 | self.current_state = self.env.reset() 70 | 71 | cur_state = self.current_state 72 | if not random_explore: 73 | action = agent.select_action(self.current_state, eval_t) 74 | else: 75 | action = self.env.action_space.sample() 76 | action = action.astype(float) 77 | 78 | next_state, reward, terminal, info = self.env.step(action) 79 | self.path_length += 1 80 | self.sum_reward += reward 81 | 82 | # add the cost 83 | if "cost" in info: 84 | cost = info["cost"] 85 | elif "x_velocity" in info: 86 | if "y_velocity" in info: 87 | cost = np.sqrt(info["y_velocity"] ** 2 + info["x_velocity"] ** 2) 88 | else: 89 | cost = np.abs(info["x_velocity"]) 90 | else: 91 | cost = 0 92 | reward = np.array([reward, cost]) 93 | 94 | # TODO: Save the path to the env_pool 95 | if terminal or self.path_length >= self.max_path_length: 96 | self.current_state = None 97 | self.path_length = 0 98 | 
self.path_rewards.append(self.sum_reward) 99 | self.sum_reward = 0 100 | else: 101 | self.current_state = next_state 102 | 103 | return cur_state, action, next_state, reward, terminal, info -------------------------------------------------------------------------------- /cap-planet/planner.py: -------------------------------------------------------------------------------- 1 | from math import inf 2 | import torch 3 | from torch import jit 4 | import numpy as np 5 | import math 6 | 7 | DEFAULT_UNCERTAINTY_MULTIPLIER = 1000 8 | DEFAULT_UNCERTAINTY_MULTIPLIER_BINARY = 10000 9 | 10 | # Model-predictive control planner with cross-entropy method and learned transition model 11 | # class MPCPlanner(jit.ScriptModule): 12 | class MPCPlanner(torch.nn.Module): 13 | __constants__ = ['action_size', 'planning_horizon', 'optimisation_iters', 'candidates', 'top_candidates', 'min_action', 'max_action', 'cost_constrained', 'cost_limit_per_step'] 14 | 15 | def __init__(self, action_size, planning_horizon, optimisation_iters, candidates, top_candidates, transition_model, reward_model, cost_model, one_step_ensemble, 16 | min_action=-inf, max_action=inf, cost_constrained=False, penalize_uncertainty=True, 17 | cost_limit=0, action_repeat=2, max_length=1000, binary_cost=False, cost_discount=0.99, penalty_kappa=0, lr=0.01): 18 | super().__init__() 19 | self.transition_model, self.reward_model, self.cost_model, self.one_step_ensemble = transition_model, reward_model, cost_model, one_step_ensemble 20 | self.action_size, self.min_action, self.max_action = action_size, min_action, max_action 21 | self.planning_horizon = planning_horizon 22 | self.optimisation_iters = optimisation_iters 23 | self.candidates, self.top_candidates = candidates, top_candidates 24 | self.cost_constrained = cost_constrained 25 | self.penalize_uncertainty = penalize_uncertainty 26 | self.binary_cost = binary_cost 27 | self.set_cost_limit(cost_limit, cost_discount, action_repeat, max_length) 28 | self.penalty_kappa = torch.tensor([float(penalty_kappa)], requires_grad=True) 29 | self.kappa_optim = torch.optim.Adam([self.penalty_kappa], lr=lr) 30 | self._fixed_cost_penalty = 0 31 | if self.binary_cost: 32 | self.uncertainty_multiplier = DEFAULT_UNCERTAINTY_MULTIPLIER_BINARY 33 | else: 34 | self.uncertainty_multiplier = DEFAULT_UNCERTAINTY_MULTIPLIER 35 | 36 | def set_cost_limit(self, cost_limit, cost_discount, action_repeat, max_length): 37 | self.cost_limit = cost_limit 38 | steps = max_length / action_repeat 39 | if cost_discount == 1: 40 | self.cost_limit_per_step = cost_limit / steps 41 | else: 42 | self.cost_limit_per_step = cost_limit * (1 - cost_discount ** action_repeat) / (1 - cost_discount ** max_length) 43 | 44 | def optimize_penalty_kappa(self, episode_cost, cost_limit=None): 45 | if cost_limit is None: 46 | cost_limit = self.cost_limit 47 | kappa_loss = -(self.penalty_kappa * (episode_cost - cost_limit)) 48 | 49 | self.kappa_optim.zero_grad() 50 | kappa_loss.backward() 51 | self.kappa_optim.step() 52 | 53 | # @jit.script_method 54 | def forward(self, belief, state): 55 | B, H, Z = belief.size(0), belief.size(1), state.size(1) 56 | belief, state = belief.unsqueeze(dim=1).expand(B, self.candidates, H).reshape(-1, H), state.unsqueeze(dim=1).expand(B, self.candidates, Z).reshape(-1, Z) 57 | # Initialize factorized belief over action sequences q(a_t:t+H) ~ N(0, I) 58 | action_mean, action_std_dev = torch.zeros(self.planning_horizon, B, 1, self.action_size, device=belief.device), torch.ones(self.planning_horizon, B, 1, self.action_size, 
device=belief.device) 59 | for _ in range(self.optimisation_iters): 60 | # Evaluate J action sequences from the current belief (over entire sequence at once, batched over particles) 61 | actions = (action_mean + action_std_dev * torch.randn(self.planning_horizon, B, self.candidates, self.action_size, device=action_mean.device)).view(self.planning_horizon, B * self.candidates, self.action_size) # Sample actions (time x (batch x candidates) x actions) 62 | actions.clamp_(min=self.min_action, max=self.max_action) # Clip action range 63 | # Sample next states 64 | beliefs, states, _, states_std = self.transition_model(state, actions, belief) 65 | # Calculate expected returns (technically sum of rewards over planning horizon) 66 | returns = self.reward_model(beliefs.view(-1, H), states.view(-1, Z)).view(self.planning_horizon, -1).mean(dim=0) 67 | objective = returns 68 | if self.cost_constrained: 69 | costs = self.cost_model(beliefs.view(-1, H), states.view(-1, Z)).view(self.planning_horizon, -1) 70 | costs += self._fixed_cost_penalty 71 | uncertainty = self.one_step_ensemble.compute_uncertainty(beliefs, actions) 72 | if self.penalize_uncertainty: 73 | penalty_kappa = self.penalty_kappa.detach().to(costs.device) 74 | costs += penalty_kappa * self.uncertainty_multiplier * uncertainty 75 | if self.binary_cost: 76 | logits = costs 77 | costs = (torch.sigmoid(costs) > 0.5).float() 78 | avg_costs = costs.mean(dim=0) 79 | feasible_samples = (avg_costs <= self.cost_limit_per_step) 80 | objective[~feasible_samples] = - np.inf 81 | if feasible_samples.sum() < self.top_candidates: 82 | objective = - avg_costs 83 | # Re-fit belief to the K best action sequences 84 | _, topk = objective.reshape(B, self.candidates).topk(self.top_candidates, dim=1, largest=True, sorted=False) 85 | topk += self.candidates * torch.arange(0, B, dtype=torch.int64, device=topk.device).unsqueeze(dim=1) # Fix indices for unrolled actions 86 | best_actions = actions[:, topk.view(-1)].reshape(self.planning_horizon, B, self.top_candidates, self.action_size) 87 | # Update belief with new means and standard deviations 88 | action_mean, action_std_dev = best_actions.mean(dim=2, keepdim=True), best_actions.std(dim=2, unbiased=False, keepdim=True) 89 | # Return first action mean µ_t 90 | if self.penalize_uncertainty: 91 | self.uncertainty_last_step = uncertainty[:, topk.view(-1)].reshape(self.planning_horizon, B, self.top_candidates).mean(2).cpu() 92 | return action_mean[0].squeeze(dim=1) 93 | -------------------------------------------------------------------------------- /cap-pets/run_cap_pets.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | import numpy as np 6 | from tqdm import tqdm 7 | import gym 8 | from gym.wrappers import Monitor 9 | 10 | from utils import EnvSampler, ReplayMemory 11 | from ccem import ConstrainedCEM 12 | from models import ProbEnsemble 13 | 14 | 15 | def readParser(): 16 | parser = argparse.ArgumentParser(description='CAP') 17 | 18 | parser.add_argument('--env', default="HalfCheetah-v3") 19 | parser.add_argument('--algo', default="cem") 20 | parser.add_argument('--monitor_gym', default=False, action='store_true') 21 | parser.add_argument('--penalize_uncertainty', action='store_true') 22 | parser.add_argument('--kappa', type=float, default=1.0) 23 | parser.add_argument('--binary_cost', action='store_true') 24 | parser.add_argument('--learn_kappa', action='store_true') 25 | parser.add_argument('--gamma', default=0.99) 26 
| parser.add_argument('--cost_constrained', dest='cost_constrained', action='store_true') 27 | parser.add_argument('--cost_limit', type=float, default=0., 28 | help='constraint threshold') 29 | parser.add_argument('--permissible_cost', type=float, default=0., 30 | help='constraint threshold') 31 | parser.add_argument('--plan_hor', type=int, default=30) 32 | parser.add_argument('--seed', type=int, default=0, metavar='N', 33 | help='random seed (default: 0)') 34 | parser.add_argument('--model_retain_epochs', type=int, default=20, metavar='A', 35 | help='retain epochs') 36 | parser.add_argument('--model_train_freq', type=int, default=1000, metavar='A', 37 | help='frequency of training') 38 | parser.add_argument('--epoch_length', type=int, default=1000, metavar='A', 39 | help='steps per epoch') 40 | 41 | parser.add_argument('--num_epoch', type=int, default=100, metavar='A', 42 | help='total number of epochs') 43 | parser.add_argument('--eval_n_episodes', type=int, default=10, metavar='A', 44 | help='number of evaluation episodes') 45 | parser.add_argument('--policy_train_batch_size', type=int, default=256, metavar='A', 46 | help='batch size for training policy') 47 | 48 | parser.add_argument('--hidden_size', type=int, default=200, metavar='A', 49 | help='ensemble model hidden dimension') 50 | parser.add_argument('--cuda', default=True, action="store_true", 51 | help='run on CUDA (default: True)') 52 | 53 | args = parser.parse_args() 54 | if args.permissible_cost < args.cost_limit: 55 | args.permissible_cost = args.cost_limit 56 | args.learn_cost = True 57 | if args.binary_cost: 58 | args.c_gamma = 1 59 | else: 60 | args.c_gamma = args.gamma 61 | if not torch.cuda.is_available(): 62 | args.cuda = False 63 | return args 64 | 65 | 66 | def train_env_model(args, env_pool, model): 67 | state, action, reward, next_state, done = env_pool.return_all() 68 | 69 | delta_state = next_state - state 70 | inputs = np.concatenate((state, action), axis=-1) 71 | 72 | reward = np.reshape(reward, (reward.shape[0], -1)) 73 | labels = { 74 | "state": delta_state, 75 | "reward": reward[:, :1], 76 | "cost": reward[:, 1:2], 77 | } 78 | 79 | model.train(inputs, labels, batch_size=256) 80 | 81 | # Save trained dynamics model 82 | if args.learn_cost: 83 | model_path = f'saved_models/{args.env}-ensemble-h{args.hidden_size}.pt' 84 | else: 85 | model_path = f'saved_models/{args.env}-ensemble-nocost-h{args.hidden_size}.pt' 86 | os.makedirs('saved_models', exist_ok=True) 87 | torch.save(model.state_dict(), model_path) 88 | 89 | def train(args, env_sampler, env_model, cem_agent, env_pool): 90 | reward_sum = 0 91 | total_violation = 0 92 | environment_step = 0 93 | learner_update_step = 0 94 | eps_idx = 0 95 | env = env_sampler.env 96 | 97 | for epoch_step in tqdm(range(args.num_epoch)): 98 | # Record agent behaviour 99 | if args.monitor_gym: 100 | monitor = Monitor(env, f"videos/{args.run_name}", force=True) 101 | if epoch_step % 10 == 0: 102 | env.reset() 103 | env_sampler.env = monitor 104 | env_sampler.current_state = None 105 | monitor.render() 106 | 107 | epoch_rewards = [0] 108 | epoch_costs = [0] 109 | epoch_lens = [0] 110 | 111 | for i in range(args.epoch_length): 112 | cur_state, action, next_state, reward, done, info = env_sampler.sample(cem_agent) 113 | epoch_rewards[-1] += reward[0] 114 | epoch_costs[-1] += args.c_gamma ** i * reward[1] 115 | epoch_lens[-1] += 1 116 | 117 | env_pool.push(cur_state, action, reward, next_state, done) 118 | 119 | environment_step += 1 120 | 121 | if done and i != args.epoch_length - 
1: 122 | epoch_rewards.append(0) 123 | epoch_costs.append(0) 124 | epoch_lens.append(0) 125 | eps_idx += 1 126 | 127 | if (i + 1) % args.model_train_freq == 0: 128 | train_env_model(args, env_pool, env_model) 129 | if args.algo != "random": 130 | cem_agent.set_model(env_model) 131 | 132 | epoch_reward = np.mean(epoch_rewards) 133 | epoch_cost = np.mean(epoch_costs) 134 | epoch_len = np.mean(epoch_lens) 135 | 136 | if args.monitor_gym: 137 | monitor.close() 138 | env_sampler.env = env 139 | 140 | # Track total number of violations 141 | if epoch_cost > args.cost_limit: 142 | total_violation += 1 143 | 144 | print("") 145 | print(f'Epoch {epoch_step} Reward {epoch_reward:.2f} Cost {epoch_cost:.2f} Total_Violations {total_violation}') 146 | 147 | if args.learn_kappa: 148 | cem_agent.optimize_kappa(epoch_cost, args.permissible_cost) 149 | 150 | def main(): 151 | args = readParser() 152 | spec = [] 153 | if not args.cost_constrained: 154 | spec.append('NoConstraint') 155 | else: 156 | if args.penalize_uncertainty: 157 | spec.extend([f'P{args.kappa}', f'T{args.learn_kappa}']) 158 | if args.learn_kappa: 159 | spec.append('CAP') 160 | spec.append(f'C{args.cost_limit}') 161 | 162 | spec = '-'.join(spec) 163 | 164 | run_name = f"{args.algo}-{spec}-{args.seed}" 165 | args.run_name = run_name 166 | 167 | print(f"Starting run {run_name}") 168 | 169 | if args.learn_kappa: 170 | args.penalize_uncertainty = True 171 | 172 | env = gym.make(args.env) 173 | state_size = np.prod(env.observation_space.shape) 174 | action_size = np.prod(env.action_space.shape) 175 | 176 | # Set random seed 177 | torch.manual_seed(args.seed) 178 | np.random.seed(args.seed) 179 | env.seed(args.seed) 180 | 181 | # Ensemble Dynamics Model 182 | env_model = ProbEnsemble(state_size, action_size, network_size=5, cuda=args.cuda, 183 | cost=args.learn_cost, binary_cost=args.binary_cost, hidden_size=args.hidden_size) 184 | if args.cuda: 185 | env_model.to('cuda') 186 | 187 | # CEM Agent 188 | cem_agent = ConstrainedCEM(env, 189 | plan_hor=args.plan_hor, 190 | gamma=args.gamma, 191 | cost_limit=args.cost_limit, 192 | cost_constrained=args.cost_constrained, 193 | penalize_uncertainty=args.penalize_uncertainty, 194 | learn_kappa=args.learn_kappa, 195 | kappa=args.kappa, 196 | binary_cost=args.binary_cost, 197 | cuda=args.cuda, 198 | ) 199 | 200 | # Sampler Environment 201 | env_sampler = EnvSampler(env, max_path_length=args.epoch_length) 202 | 203 | # Experience Buffer 204 | env_pool = ReplayMemory(args.epoch_length * args.num_epoch) 205 | 206 | # Train 207 | train(args, env_sampler, env_model, cem_agent, env_pool) 208 | 209 | 210 | if __name__ == '__main__': 211 | main() -------------------------------------------------------------------------------- /cap-planet/env.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from gym import Wrapper 5 | from safety_envs import load_suite_env, CONSTRAINED_GYM_ENVS, CONSTRAINED_CONTROL_SUITE_ENVS 6 | 7 | GYM_ENVS = ['Pendulum-v0', 'MountainCarContinuous-v0', 'Ant-v2', 'HalfCheetah-v2', 'Hopper-v2', 'Humanoid-v2', 'HumanoidStandup-v2', 'InvertedDoublePendulum-v2', 'InvertedPendulum-v2', 'Reacher-v2', 'Swimmer-v2', 'Walker2d-v2', 'CarRacing-v0', *CONSTRAINED_GYM_ENVS] 8 | CONTROL_SUITE_ENVS = ['cartpole-balance', 'cartpole-swingup', 'reacher-easy', 'finger-spin', 'cheetah-run', 'ball_in_cup-catch', 'walker-walk', *CONSTRAINED_CONTROL_SUITE_ENVS] 9 | CONTROL_SUITE_ACTION_REPEATS = {'cartpole': 8, 'reacher': 4, 
'finger': 2, 'cheetah': 4, 'ball_in_cup': 6, 'walker': 2} 10 | 11 | 12 | # Preprocesses an observation inplace (from float32 Tensor [0, 255] to [-0.5, 0.5]) 13 | def preprocess_observation_(observation, bit_depth): 14 | observation.div_(2 ** (8 - bit_depth)).floor_().div_(2 ** bit_depth).sub_(0.5) # Quantise to given bit depth and centre 15 | observation.add_(torch.rand_like(observation).div_(2 ** bit_depth)) # Dequantise (to approx. match likelihood of PDF of continuous images vs. PMF of discrete images) 16 | 17 | 18 | # Postprocess an observation for storage (from float32 numpy array [-0.5, 0.5] to uint8 numpy array [0, 255]) 19 | def postprocess_observation(observation, bit_depth): 20 | return np.clip(np.floor((observation + 0.5) * 2 ** bit_depth) * 2 ** (8 - bit_depth), 0, 2 ** 8 - 1).astype(np.uint8) 21 | 22 | 23 | def _images_to_observation(images, bit_depth): 24 | images = torch.tensor(cv2.resize(images, (64, 64), interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1), dtype=torch.float32) # Resize and put channel first 25 | preprocess_observation_(images, bit_depth) # Quantise, centre and dequantise inplace 26 | return images.unsqueeze(dim=0) # Add batch dimension 27 | 28 | 29 | class ControlSuiteEnv(): 30 | def __init__(self, env, symbolic, seed, max_episode_length, action_repeat, bit_depth, cost_reduce="max"): 31 | from dm_control.suite.wrappers import pixels 32 | domain = env.split('-')[0] 33 | self._env = load_suite_env(env, seed) 34 | self.symbolic = symbolic 35 | if not symbolic: 36 | self._env = pixels.Wrapper(self._env) 37 | self.max_episode_length = max_episode_length 38 | self.action_repeat = action_repeat 39 | if action_repeat != CONTROL_SUITE_ACTION_REPEATS[domain]: 40 | print('Using action repeat %d; recommended action repeat for domain is %d' % (action_repeat, CONTROL_SUITE_ACTION_REPEATS[domain])) 41 | self.bit_depth = bit_depth 42 | if cost_reduce == "sum": 43 | self.cost_reduce = sum 44 | if cost_reduce == "max": 45 | self.cost_reduce = max 46 | 47 | def reset(self): 48 | self.t = 0 # Reset internal timer 49 | state = self._env.reset() 50 | if self.symbolic: 51 | return torch.tensor(np.concatenate([np.asarray([obs]) if isinstance(obs, float) else obs for obs in state.observation.values()], axis=0), dtype=torch.float32).unsqueeze(dim=0) 52 | else: 53 | return _images_to_observation(self._env.physics.render(camera_id=0), self.bit_depth) 54 | 55 | def step(self, action): 56 | action = action.detach().numpy() 57 | reward = 0 58 | cost = 0 59 | for k in range(self.action_repeat): 60 | state = self._env.step(action) 61 | reward += state.reward 62 | cost = self.cost_reduce([getattr(state, 'cost', 0), cost]) 63 | self.t += 1 # Increment internal timer 64 | done = state.last() or self.t == self.max_episode_length 65 | if done: 66 | break 67 | if self.symbolic: 68 | observation = torch.tensor(np.concatenate([np.asarray([obs]) if isinstance(obs, float) else obs for obs in state.observation.values()], axis=0), dtype=torch.float32).unsqueeze(dim=0) 69 | else: 70 | observation = _images_to_observation(self._env.physics.render(camera_id=0), self.bit_depth) 71 | return observation, reward, cost, done 72 | 73 | def render(self): 74 | cv2.imshow('screen', self._env.physics.render(camera_id=0)[:, :, ::-1]) 75 | cv2.waitKey(1) 76 | 77 | def close(self): 78 | cv2.destroyAllWindows() 79 | self._env.close() 80 | 81 | @property 82 | def observation_size(self): 83 | return sum([(1 if len(obs.shape) == 0 else obs.shape[0]) for obs in self._env.observation_spec().values()]) if self.symbolic else 
(3, 64, 64) 84 | 85 | @property 86 | def action_size(self): 87 | return self._env.action_spec().shape[0] 88 | 89 | @property 90 | def action_range(self): 91 | return float(self._env.action_spec().minimum[0]), float(self._env.action_spec().maximum[0]) 92 | 93 | # Sample an action randomly from a uniform distribution over all valid actions 94 | def sample_random_action(self): 95 | spec = self._env.action_spec() 96 | return torch.from_numpy(np.random.uniform(spec.minimum, spec.maximum, spec.shape)) 97 | 98 | 99 | 100 | class GymEnv(): 101 | def __init__(self, env, symbolic, seed, max_episode_length, action_repeat, bit_depth, cost_reduce="max"): 102 | import logging 103 | import gym 104 | gym.logger.set_level(logging.ERROR) # Ignore warnings from Gym logger 105 | self.symbolic = symbolic 106 | self._env = gym.make(env) 107 | if seed is not None: 108 | self._env.seed(seed) 109 | self.max_episode_length = max_episode_length 110 | self.action_repeat = action_repeat 111 | self.bit_depth = bit_depth 112 | if cost_reduce == "sum": 113 | self.cost_reduce = sum 114 | if cost_reduce == "max": 115 | self.cost_reduce = max 116 | 117 | def reset(self): 118 | self.t = 0 # Reset internal timer 119 | state = self._env.reset() 120 | if self.symbolic: 121 | return torch.tensor(state, dtype=torch.float32).unsqueeze(dim=0) 122 | else: 123 | return _images_to_observation(self._env.render(mode='rgb_array'), self.bit_depth) 124 | 125 | def step(self, action): 126 | action = action.detach().numpy() 127 | reward = 0 128 | cost = 0 129 | for k in range(self.action_repeat): 130 | state, reward_k, done, info = self._env.step(action) 131 | reward += reward_k 132 | cost = self.cost_reduce([info.get('cost', 0), cost]) 133 | self.t += 1 # Increment internal timer 134 | done = done or self.t == self.max_episode_length 135 | if done: 136 | break 137 | if self.symbolic: 138 | observation = torch.tensor(state, dtype=torch.float32).unsqueeze(dim=0) 139 | else: 140 | observation = _images_to_observation(self._env.render(mode='rgb_array'), self.bit_depth) 141 | return observation, reward, cost, done 142 | 143 | def render(self): 144 | self._env.render() 145 | 146 | def close(self): 147 | self._env.close() 148 | 149 | @property 150 | def observation_size(self): 151 | return self._env.observation_space.shape[0] if self.symbolic else (3, 64, 64) 152 | 153 | @property 154 | def action_size(self): 155 | return self._env.action_space.shape[0] 156 | 157 | @property 158 | def action_range(self): 159 | return float(self._env.action_space.low[0]), float(self._env.action_space.high[0]) 160 | 161 | # Sample an action randomly from a uniform distribution over all valid actions 162 | def sample_random_action(self): 163 | return torch.from_numpy(self._env.action_space.sample()) 164 | 165 | 166 | class ClipActionWrapper(Wrapper): 167 | def __init__(self, env): 168 | super(ClipActionWrapper, self).__init__(env) 169 | 170 | def step(self, action): 171 | act_space = self.env.action_space 172 | action = np.clip(action, act_space.low, act_space.high) 173 | return self.env.step(action) 174 | 175 | 176 | def Env(env_name, symbolic, seed, max_episode_length, action_repeat, bit_depth, cost_reduce="max"): 177 | if env_name in GYM_ENVS: 178 | env = GymEnv(env_name, symbolic, seed, max_episode_length, action_repeat, bit_depth, cost_reduce=cost_reduce) 179 | elif env_name in CONTROL_SUITE_ENVS: 180 | env = ControlSuiteEnv(env_name, symbolic, seed, max_episode_length, action_repeat, bit_depth, cost_reduce=cost_reduce) 181 | else: 182 | env = None 183 | 184 | 
return env 185 | 186 | 187 | # Wrapper for batching environments together 188 | class EnvBatcher(): 189 | def __init__(self, env_class, env_args, env_kwargs, n): 190 | self.n = n 191 | self.envs = [env_class(*env_args, **env_kwargs) for _ in range(n)] 192 | self.dones = [True] * n 193 | 194 | # Resets every environment and returns observation 195 | def reset(self): 196 | observations = [env.reset() for env in self.envs] 197 | self.dones = [False] * self.n 198 | return torch.cat(observations) 199 | 200 | # Steps/resets every environment and returns (observation, reward, done) 201 | def step(self, actions): 202 | done_mask = torch.nonzero(torch.tensor(self.dones))[:, 0] # Done mask to blank out observations and zero rewards for previously terminated environments 203 | observations, rewards, costs, dones = zip(*[env.step(action) for env, action in zip(self.envs, actions)]) 204 | dones = [d or prev_d for d, prev_d in zip(dones, self.dones)] # Env should remain terminated if previously terminated 205 | self.dones = dones 206 | observations, rewards, costs, dones = torch.cat(observations), torch.tensor(rewards, dtype=torch.float32), torch.tensor(costs, dtype=torch.float32), torch.tensor(dones, dtype=torch.uint8) 207 | observations[done_mask] = 0 208 | rewards[done_mask] = 0 209 | costs[done_mask] = 0 210 | return observations, rewards, costs, dones 211 | 212 | def close(self): 213 | [env.close() for env in self.envs] 214 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: cap 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _pytorch_select=0.2=gpu_0 8 | - _tflow_select=2.3.0=mkl 9 | - absl-py=0.13.0=py36h06a4308_0 10 | - aiohttp=2.3.9=py36_0 11 | - anyio=2.2.0=py36h06a4308_1 12 | - argon2-cffi=20.1.0=py36h27cfd23_1 13 | - astor=0.8.1=py36h06a4308_0 14 | - astunparse=1.6.3=py_0 15 | - async-timeout=3.0.1=py36h06a4308_0 16 | - async_generator=1.10=py36h28b3542_0 17 | - attrs=20.3.0=pyhd3eb1b0_0 18 | - babel=2.9.0=pyhd3eb1b0_0 19 | - backcall=0.2.0=pyhd3eb1b0_0 20 | - blas=1.0=mkl 21 | - bleach=3.3.0=pyhd3eb1b0_0 22 | - blinker=1.4=py36h06a4308_0 23 | - blosc=1.21.0=h8c45485_0 24 | - brotli=1.0.9=he6710b0_2 25 | - brotlipy=0.7.0=py36h27cfd23_1003 26 | - bzip2=1.0.8=h7b6447c_0 27 | - c-ares=1.17.1=h27cfd23_0 28 | - ca-certificates=2021.7.5=h06a4308_1 29 | - certifi=2021.5.30=py36h06a4308_0 30 | - cffi=1.14.5=py36h261ae71_0 31 | - chardet=4.0.0=py36h06a4308_1003 32 | - charls=2.1.0=he6710b0_2 33 | - click=7.1.2=pyhd3eb1b0_0 34 | - colorcet=2.0.6=pyhd3eb1b0_0 35 | - contextvars=2.4=py_0 36 | - coverage=5.5=py36h27cfd23_2 37 | - cryptography=3.4.7=py36hd23ed53_0 38 | - cudatoolkit=10.0.130=0 39 | - cudnn=7.6.5=cuda10.0_0 40 | - cycler=0.10.0=py36_0 41 | - cytoolz=0.11.0=py36h7b6447c_0 42 | - dask-core=2021.3.0=pyhd3eb1b0_0 43 | - dataclasses=0.8=pyh4f3eec9_6 44 | - datashape=0.5.4=py36h06a4308_1 45 | - dbus=1.13.18=hb2f20db_0 46 | - defusedxml=0.7.1=pyhd3eb1b0_0 47 | - dill=0.3.3=pyhd3eb1b0_0 48 | - distributed=2021.3.0=py36h06a4308_0 49 | - entrypoints=0.3=py36_0 50 | - expat=2.3.0=h2531618_2 51 | - fontconfig=2.13.1=h6c09931_0 52 | - freetype=2.10.4=h5ab3b9f_0 53 | - fsspec=0.9.0=pyhd3eb1b0_0 54 | - gast=0.3.3=py_0 55 | - giflib=5.1.4=h14c3975_1 56 | - glib=2.68.1=h36276a3_0 57 | - google-auth-oauthlib=0.4.4=pyhd3eb1b0_0 58 | - google-pasta=0.2.0=py_0 59 | - grpcio=1.36.1=py36h2157cd5_1 60 | - 
gst-plugins-base=1.14.0=h8213a91_2 61 | - gstreamer=1.14.0=h28cd5cc_2 62 | - hdf5=1.10.6=hb1b8bf9_0 63 | - heapdict=1.0.1=py_0 64 | - icu=58.2=he6710b0_3 65 | - idna=2.10=pyhd3eb1b0_0 66 | - imagecodecs=2020.5.30=py36hfa7d478_2 67 | - imageio=2.9.0=pyhd3eb1b0_0 68 | - immutables=0.15=py36h27cfd23_0 69 | - importlib-metadata=3.10.0=py36h06a4308_0 70 | - importlib_metadata=3.10.0=hd3eb1b0_0 71 | - intel-openmp=2020.2=254 72 | - ipython=7.16.1=py36h5ca1d4c_0 73 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 74 | - ipywidgets=7.6.3=pyhd3deb0d_0 75 | - jedi=0.17.0=py36_0 76 | - jinja2=2.11.3=pyhd3eb1b0_0 77 | - jpeg=9b=h024ee3a_2 78 | - json5=0.9.5=py_0 79 | - jsonschema=3.2.0=py_2 80 | - jupyter-packaging=0.7.12=pyhd3eb1b0_0 81 | - jupyter_client=6.1.12=pyhd3eb1b0_0 82 | - jupyter_core=4.7.1=py36h06a4308_0 83 | - jupyter_server=1.4.1=py36h06a4308_0 84 | - jupyterlab_pygments=0.1.2=py_0 85 | - jupyterlab_server=2.4.0=pyhd3eb1b0_0 86 | - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1 87 | - jxrlib=1.1=h7b6447c_2 88 | - keras-preprocessing=1.1.2=pyhd3eb1b0_0 89 | - kiwisolver=1.3.1=py36h2531618_0 90 | - lcms2=2.12=h3be6417_0 91 | - ld_impl_linux-64=2.33.1=h53a641e_7 92 | - libaec=1.0.4=he6710b0_1 93 | - libblas=3.8.0=21_mkl 94 | - libcblas=3.8.0=21_mkl 95 | - libffi=3.3=he6710b0_2 96 | - libgcc-ng=9.1.0=hdf63c60_0 97 | - libgfortran-ng=7.3.0=hdf63c60_0 98 | - liblapack=3.8.0=21_mkl 99 | - libllvm10=10.0.1=hbcb73fb_5 100 | - libpng=1.6.37=hbc83047_0 101 | - libprotobuf=3.14.0=h8c45485_0 102 | - libsodium=1.0.18=h7b6447c_0 103 | - libstdcxx-ng=9.1.0=hdf63c60_0 104 | - libtiff=4.1.0=h2733197_1 105 | - libuuid=1.0.3=h1bed415_2 106 | - libwebp=1.0.1=h8e7db2f_0 107 | - libxcb=1.14=h7b6447c_0 108 | - libxml2=2.9.10=hb55368b_3 109 | - libzopfli=1.0.3=he6710b0_0 110 | - llvmlite=0.36.0=py36h612dafd_4 111 | - locket=0.2.1=py36h06a4308_1 112 | - lz4-c=1.9.3=h2531618_0 113 | - markdown=3.3.4=py36h06a4308_0 114 | - markupsafe=1.1.1=py36h7b6447c_0 115 | - matplotlib=3.3.4=py36h06a4308_0 116 | - matplotlib-base=3.3.4=py36h62a2d02_0 117 | - mistune=0.8.4=py36h7b6447c_0 118 | - mkl=2020.2=256 119 | - mkl-service=2.3.0=py36he8ac12f_0 120 | - mkl_fft=1.3.0=py36h54f3939_0 121 | - mkl_random=1.1.1=py36h0573a6f_0 122 | - msgpack-python=1.0.2=py36hff7bd54_1 123 | - multidict=5.1.0=py36h27cfd23_2 124 | - multipledispatch=0.6.0=py36_0 125 | - nbclassic=0.2.6=pyhd3eb1b0_0 126 | - nbclient=0.5.3=pyhd3eb1b0_0 127 | - nbconvert=6.0.7=py36_0 128 | - nbformat=5.1.3=pyhd3eb1b0_0 129 | - ncurses=6.2=he6710b0_1 130 | - nest-asyncio=1.5.1=pyhd3eb1b0_0 131 | - networkx=2.5=py_0 132 | - ninja=1.10.2=hff7bd54_1 133 | - notebook=6.3.0=py36h06a4308_0 134 | - numba=0.53.1=py36ha9443f7_0 135 | - oauthlib=3.1.1=pyhd3eb1b0_0 136 | - olefile=0.46=py36_0 137 | - openjpeg=2.3.0=h05c96fa_1 138 | - openssl=1.1.1l=h7f8727e_0 139 | - opt_einsum=3.3.0=pyhd3eb1b0_1 140 | - packaging=20.9=pyhd3eb1b0_0 141 | - pandoc=2.12=h06a4308_0 142 | - pandocfilters=1.4.3=py36h06a4308_1 143 | - param=1.10.1=pyhd3eb1b0_0 144 | - parso=0.8.2=pyhd3eb1b0_0 145 | - partd=1.2.0=pyhd3eb1b0_0 146 | - pcre=8.44=he6710b0_0 147 | - pexpect=4.8.0=pyhd3eb1b0_3 148 | - pickleshare=0.7.5=pyhd3eb1b0_1003 149 | - pip=21.0.1=py36h06a4308_0 150 | - prometheus_client=0.10.1=pyhd3eb1b0_0 151 | - prompt-toolkit=3.0.17=pyh06a4308_0 152 | - psutil=5.8.0=py36h27cfd23_1 153 | - ptyprocess=0.7.0=pyhd3eb1b0_2 154 | - pyasn1=0.4.8=py_0 155 | - pycparser=2.20=py_2 156 | - pyct=0.4.8=py36_0 157 | - pygments=2.8.1=pyhd3eb1b0_0 158 | - pyjwt=2.1.0=py36h06a4308_0 159 | - pynndescent=0.5.2=pyhd3eb1b0_0 160 | - 
pyopenssl=20.0.1=pyhd3eb1b0_1 161 | - pyparsing=2.4.7=pyhd3eb1b0_0 162 | - pyqt=5.9.2=py36h05f1152_2 163 | - pyrsistent=0.17.3=py36h7b6447c_0 164 | - pysocks=1.7.1=py36h06a4308_0 165 | - python=3.6.13=hdb3f193_0 166 | - python-dateutil=2.8.1=pyhd3eb1b0_0 167 | - python_abi=3.6=1_cp36m 168 | - pytorch=1.3.1=cuda100py36h53c1284_0 169 | - pytz=2021.1=pyhd3eb1b0_0 170 | - pyviz_comms=2.0.1=pyhd3eb1b0_0 171 | - pywavelets=1.1.1=py36h7b6447c_2 172 | - pyyaml=5.4.1=py36h27cfd23_1 173 | - pyzmq=20.0.0=py36h2531618_1 174 | - qt=5.9.7=h5867ecd_1 175 | - readline=8.1=h27cfd23_0 176 | - requests=2.25.1=pyhd3eb1b0_0 177 | - requests-oauthlib=1.3.0=py_0 178 | - rope=0.18.0=py_0 179 | - rsa=4.7.2=pyhd3eb1b0_1 180 | - scikit-image=0.17.2=py36hdf5156a_0 181 | - scikit-learn=0.24.1=py36ha9443f7_0 182 | - seaborn=0.11.2=pyhd3eb1b0_0 183 | - send2trash=1.5.0=pyhd3eb1b0_1 184 | - setuptools=52.0.0=py36h06a4308_0 185 | - sip=4.19.8=py36hf484d3e_0 186 | - six=1.15.0=py36h06a4308_0 187 | - snappy=1.1.8=he6710b0_0 188 | - sniffio=1.2.0=py36h06a4308_1 189 | - sortedcontainers=2.3.0=pyhd3eb1b0_0 190 | - sqlite=3.35.4=hdfb4753_0 191 | - tbb=2020.3=hfd86e86_0 192 | - tblib=1.7.0=py_0 193 | - tensorboard=2.4.0=pyhc547734_0 194 | - tensorboard-plugin-wit=1.6.0=py_0 195 | - tensorflow-base=2.2.0=mkl_py36hd506778_0 196 | - tensorflow-estimator=2.5.0=pyh7b7c402_0 197 | - termcolor=1.1.0=py36h06a4308_1 198 | - terminado=0.9.4=py36h06a4308_0 199 | - testpath=0.4.4=pyhd3eb1b0_0 200 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 201 | - tifffile=2021.3.17=pyhd3eb1b0_1 202 | - tk=8.6.10=hbc83047_0 203 | - toolz=0.11.1=pyhd3eb1b0_0 204 | - torchvision=0.4.2=cuda100py36hecfc37a_0 205 | - tornado=6.1=py36h27cfd23_0 206 | - tqdm=4.59.0=pyhd3eb1b0_1 207 | - traitlets=4.3.3=py36_0 208 | - typing_extensions=3.7.4.3=pyha847dfd_0 209 | - umap-learn=0.5.1=py36h5fab9bb_0 210 | - urllib3=1.26.4=pyhd3eb1b0_0 211 | - wcwidth=0.2.5=py_0 212 | - webencodings=0.5.1=py36_1 213 | - werkzeug=1.0.1=pyhd3eb1b0_0 214 | - wheel=0.36.2=pyhd3eb1b0_0 215 | - widgetsnbextension=3.5.1=py36_0 216 | - wrapt=1.12.1=py36h7b6447c_1 217 | - xarray=0.17.0=pyhd3eb1b0_0 218 | - xz=5.2.5=h7b6447c_0 219 | - yaml=0.2.5=h7b6447c_0 220 | - yarl=1.6.3=py36h27cfd23_0 221 | - zeromq=4.3.4=h2531618_0 222 | - zict=2.0.0=pyhd3eb1b0_0 223 | - zipp=3.4.1=pyhd3eb1b0_0 224 | - zlib=1.2.11=h7b6447c_3 225 | - zstd=1.4.5=h9ceee32_0 226 | - pip: 227 | - atari-py==0.2.6 228 | - blessings==1.7 229 | - box2d==2.3.10 230 | - cached-property==1.5.2 231 | - cachetools==4.2.1 232 | - cloudpickle==1.3.0 233 | - configparser==5.0.2 234 | - cython==0.29.23 235 | - decorator==4.4.2 236 | - dm-control==0.0.364896371 237 | - dm-env==1.4 238 | - dm-tree==0.1.5 239 | - docker-pycreds==0.4.0 240 | - fasteners==0.16 241 | - future==0.18.2 242 | - gitdb==4.0.5 243 | - gitpython==3.1.14 244 | - glfw==2.1.0 245 | - google-api-core==1.26.3 246 | - google-api-python-client==2.2.0 247 | - google-auth==1.28.1 248 | - google-auth-httplib2==0.1.0 249 | - googleapis-common-protos==1.53.0 250 | - gpustat==0.6.0 251 | - gym==0.17.0 252 | - h5py==3.1.0 253 | - httplib2==0.19.1 254 | - imageio-ffmpeg==0.4.4 255 | - imutils==0.5.1 256 | - ipykernel==5.5.3 257 | - joblib==0.14.1 258 | - labmaze==1.0.3 259 | - lxml==4.6.2 260 | - moviepy==1.0.3 261 | - mujoco-py==2.0.2.7 262 | - numpy==1.17.5 263 | - nvidia-ml-py3==7.352.0 264 | - opencv-python==4.5.1.48 265 | - pandas==1.1.5 266 | - pathtools==0.1.2 267 | - pillow==7.2.0 268 | - plotly==5.1.0 269 | - proglog==0.1.9 270 | - promise==2.3 271 | - protobuf==3.15.6 272 | - 
pyasn1-modules==0.2.8 273 | - pybullet==3.1.0 274 | - pygame==2.0.1 275 | - pyglet==1.5.0 276 | - pyopengl==3.1.5 277 | - scipy==1.5.4 278 | - sentry-sdk==1.0.0 279 | - shortuuid==1.0.1 280 | - smmap==3.0.5 281 | - stable-baselines3==1.0 282 | - subprocess32==3.5.4 283 | - tenacity==8.0.1 284 | - torch==1.10.0 285 | - uritemplate==3.0.1 286 | - wandb==0.10.28 287 | - xmltodict==0.12.0 288 | -------------------------------------------------------------------------------- /cap-pets/models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gym 3 | import itertools 4 | 5 | import torch 6 | from torch import nn as nn 7 | from torch.nn import functional as F 8 | from torch.utils.data import TensorDataset, DataLoader 9 | 10 | 11 | def swish(x): 12 | return x * torch.sigmoid(x) 13 | 14 | 15 | def truncated_normal(size, std): 16 | val = torch.fmod(torch.randn(size),2) * std 17 | return torch.tensor(val, dtype=torch.float32) 18 | 19 | 20 | def get_affine_params(ensemble_size, in_features, out_features): 21 | 22 | w = truncated_normal(size=(ensemble_size, in_features, out_features), 23 | std=1.0 / (2.0 * np.sqrt(in_features))) 24 | w = nn.Parameter(w) 25 | 26 | b = nn.Parameter(torch.zeros(ensemble_size, 1, out_features, dtype=torch.float32)) 27 | 28 | return w, b 29 | 30 | class EnsembleLayer(nn.Module): 31 | def __init__(self, network_size, in_size, out_size): 32 | super().__init__() 33 | self.w, self.b = get_affine_params(network_size, in_size, out_size) 34 | 35 | def forward(self, inputs): 36 | inputs = inputs.matmul(self.w) + self.b 37 | inputs = swish(inputs) 38 | return inputs 39 | 40 | def decays(self): 41 | return (self.w ** 2).sum() / 2.0 42 | 43 | class GaussianEnsembleLayer(nn.Module): 44 | def __init__(self, network_size, in_size, out_size): 45 | super().__init__() 46 | self.w, self.b = get_affine_params(network_size, in_size, out_size * 2) 47 | self.out_size = out_size 48 | self.max_logvar = nn.Parameter(torch.ones(1, self.out_size, dtype=torch.float32) / 2.0) 49 | self.min_logvar = nn.Parameter(-torch.ones(1, self.out_size, dtype=torch.float32) * 10.0) 50 | 51 | def forward(self, inputs, sample=True): 52 | outputs = inputs.matmul(self.w) + self.b 53 | mean = outputs[:, :, :self.out_size] 54 | logvar = outputs[:, :, self.out_size:] 55 | logvar = self.max_logvar - F.softplus(self.max_logvar - logvar) 56 | logvar = self.min_logvar + F.softplus(logvar - self.min_logvar) 57 | var = torch.exp(logvar) 58 | if sample: 59 | noise = torch.randn_like(mean, device=inputs.device) * var.sqrt() 60 | return mean + noise * var.sqrt(), var 61 | else: 62 | return mean, var 63 | 64 | def decays(self): 65 | return (self.w ** 2).sum() / 2.0 66 | 67 | def compute_loss(self, output, target): 68 | train_loss = 0.01 * (self.max_logvar.sum() - self.min_logvar.sum()) 69 | mean, var = output 70 | logvar = torch.log(var) 71 | 72 | inv_var = torch.pow(var, -1) 73 | train_losses = ((mean - target) ** 2) * inv_var + logvar 74 | train_losses = train_losses.mean(-1).mean(-1).sum() 75 | train_loss += train_losses 76 | return train_loss 77 | 78 | class LogisticEnsembleLayer(nn.Module): 79 | def __init__(self, network_size, in_size, out_size): 80 | super().__init__() 81 | self.w, self.b = get_affine_params(network_size, in_size, out_size) 82 | 83 | def forward(self, inputs, sample=True): 84 | logits = inputs.matmul(self.w) + self.b 85 | return logits, None 86 | 87 | def decays(self): 88 | return (self.w ** 2).sum() / 2.0 89 | 90 | def compute_loss(self, 
output, target): 91 | logits, _ = output 92 | mean = torch.sigmoid(logits) 93 | 94 | train_loss = F.binary_cross_entropy(mean, target, reduce=False) 95 | train_loss = train_loss.mean(-1).mean(-1).sum() 96 | return train_loss 97 | 98 | class ProbEnsemble(nn.Module): 99 | 100 | def __init__(self, state_size, action_size, 101 | network_size=7, elite_size=5, cuda=True, 102 | cost=False, binary_cost=False, hidden_size=200, lr=0.001): 103 | super().__init__() 104 | self.network_size = network_size 105 | self.num_nets = network_size 106 | self.state_size = state_size 107 | self.action_size = action_size 108 | self.binary_cost = binary_cost 109 | self.elite_size = elite_size 110 | self.elite_model_idxes = [] 111 | 112 | self.in_features = state_size + action_size 113 | 114 | self.layer0 = EnsembleLayer(network_size, self.in_features, hidden_size) 115 | self.layer1 = EnsembleLayer(network_size, hidden_size, hidden_size) 116 | self.layer2 = EnsembleLayer(network_size, hidden_size, hidden_size) 117 | self.layer3 = EnsembleLayer(network_size, hidden_size, hidden_size) 118 | 119 | self.state_model = GaussianEnsembleLayer(network_size, hidden_size, state_size) 120 | self.reward_model = GaussianEnsembleLayer(network_size, hidden_size, 1) 121 | if binary_cost: 122 | self.cost_model = LogisticEnsembleLayer(network_size, hidden_size, 1) 123 | else: 124 | self.cost_model = GaussianEnsembleLayer(network_size, hidden_size, 1) 125 | 126 | self.inputs_mu = nn.Parameter(torch.zeros(1, self.in_features), requires_grad=False) 127 | self.inputs_sigma = nn.Parameter(torch.zeros(1, self.in_features), requires_grad=False) 128 | 129 | self.fit_input = False 130 | 131 | self.optimizer = torch.optim.Adam(self.parameters(), lr=lr) 132 | self.grad_update = 0 133 | 134 | self.device = 'cuda' if cuda else 'cpu' 135 | 136 | def compute_decays(self): 137 | 138 | lin0_decays = 0.000025 * self.layer0.decays() 139 | lin1_decays = 0.00005 * self.layer1.decays() 140 | lin2_decays = 0.000075 * self.layer2.decays() 141 | lin3_decays = 0.000075 * self.layer3.decays() 142 | lin4_decays = 0.0001 * (self.state_model.decays() + self.reward_model.decays() + self.cost_model.decays()) 143 | 144 | return lin0_decays + lin1_decays + lin2_decays + lin3_decays + lin4_decays 145 | 146 | def fit_input_stats(self, data): 147 | self.fit_input = True 148 | mu = np.mean(data, axis=0, keepdims=True) 149 | sigma = np.std(data, axis=0, keepdims=True) 150 | sigma[sigma < 1e-12] = 1.0 151 | 152 | self.inputs_mu.data = torch.from_numpy(mu).to(self.device).float() 153 | self.inputs_sigma.data = torch.from_numpy(sigma).to(self.device).float() 154 | 155 | def forward(self, inputs, sample=False, return_state_variance=False): 156 | # Transform inputs 157 | if self.fit_input: 158 | inputs = (inputs - self.inputs_mu) / self.inputs_sigma 159 | 160 | hidden = self.layer0(inputs) 161 | hidden = self.layer1(hidden) 162 | hidden = self.layer2(hidden) 163 | hidden = self.layer3(hidden) 164 | 165 | output = { 166 | "state": self.state_model(hidden, sample=sample), 167 | "reward": self.reward_model(hidden, sample=sample), 168 | "cost": self.cost_model(hidden, sample=sample), 169 | } 170 | return output 171 | 172 | def compute_loss(self, input, target): 173 | output = self(input) 174 | train_loss = self.state_model.compute_loss(output["state"], target["state"]) 175 | train_loss += self.reward_model.compute_loss(output["reward"], target["reward"]) 176 | train_loss += self.cost_model.compute_loss(output["cost"], target["cost"]) 177 | train_loss += self.compute_decays() 178 | 
return train_loss 179 | 180 | def _save_best(self, epoch, holdout_losses): 181 | updated = False 182 | updated_count = 0 183 | for i in range(len(holdout_losses)): 184 | current = holdout_losses[i] 185 | _, best = self._snapshots[i] 186 | improvement = (best - current) / abs(best) 187 | if improvement > 0.01: 188 | self._snapshots[i] = (epoch, current) 189 | updated = True 190 | updated_count += 1 191 | improvement = (best - current) / best 192 | 193 | if updated: 194 | self._epochs_since_update = 0 195 | else: 196 | self._epochs_since_update += 1 197 | 198 | if self._epochs_since_update > self._max_epochs_since_update: 199 | return True 200 | else: 201 | return False 202 | 203 | def train(self, inputs, targets, batch_size=256, max_epochs_since_update=5, max_epochs=5): 204 | self._max_epochs_since_update = max_epochs_since_update 205 | self._snapshots = {i: (None, 1e10) for i in range(self.num_nets)} 206 | self._epochs_since_update = 0 207 | 208 | def shuffle_rows(arr): 209 | idxs = np.argsort(np.random.uniform(size=arr.shape), axis=-1) 210 | return arr[np.arange(arr.shape[0])[:, None], idxs] 211 | 212 | self.fit_input_stats(inputs) 213 | 214 | idxs = np.random.randint(inputs.shape[0], size=[self.num_nets, inputs.shape[0]]) 215 | 216 | if max_epochs is not None: 217 | epoch_iter = range(max_epochs) 218 | else: 219 | epoch_iter = itertools.count() 220 | 221 | for epoch in epoch_iter: 222 | for batch_num in range(int(np.ceil(idxs.shape[-1] / batch_size))): 223 | batch_idxs = idxs[:, batch_num * batch_size:(batch_num + 1) * batch_size] 224 | 225 | input = torch.from_numpy(inputs[batch_idxs]).float().to(self.device) 226 | target = { key: torch.from_numpy(value[batch_idxs]).float().to(self.device) 227 | for key, value in targets.items() } 228 | train_loss = self.compute_loss(input, target) 229 | 230 | self.optimizer.zero_grad() 231 | train_loss.backward() 232 | self.optimizer.step() 233 | 234 | self.grad_update += 1 235 | 236 | idxs = shuffle_rows(idxs) 237 | 238 | def predict(self, state, action, variance=False): 239 | input = torch.cat([state, action], dim=-1) 240 | with torch.no_grad(): 241 | output = self(input, sample=True) 242 | output["state"] = (output["state"][0] + state, output["state"][1]) 243 | return output 244 | -------------------------------------------------------------------------------- /cap-pets/ccem.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch.optim import Adam 5 | 6 | import scipy.stats as stats 7 | 8 | STATE_MAX = 100 9 | 10 | class ConstrainedCEM: 11 | def __init__(self, 12 | env, 13 | epoch_length=1000, 14 | plan_hor=30, 15 | gamma=0.99, 16 | c_gamma=0.99, 17 | kappa=0, 18 | binary_cost=False, 19 | cost_limit=0, 20 | learn_kappa=False, 21 | cost_constrained=True, 22 | penalize_uncertainty=True, 23 | cuda=True, 24 | ): 25 | self.dO, self.dU = env.observation_space.shape[0], env.action_space.shape[0] 26 | self.ac_ub, self.ac_lb = env.action_space.high, env.action_space.low 27 | self.gamma = gamma 28 | self.c_gamma = c_gamma 29 | 30 | self.learn_kappa = learn_kappa 31 | if not self.learn_kappa: 32 | self.kappa = torch.tensor([kappa], requires_grad=False) 33 | else: 34 | self.kappa = torch.tensor([float(kappa)], requires_grad=True) 35 | self.kappa_optim = Adam([self.kappa], lr=0.1) 36 | self.binary_cost = binary_cost 37 | 38 | self.cost_limit = cost_limit 39 | self.epoch_length = epoch_length 40 | self.cost_constrained = cost_constrained 41 | 
self.penalize_uncertainty = penalize_uncertainty 42 | self.device = 'cuda' if cuda else 'cpu' 43 | 44 | # CEM parameters 45 | self.per = 1 46 | self.npart = 20 47 | self.plan_hor = plan_hor 48 | self.popsize = 500 49 | self.num_elites = 50 50 | self.max_iters = 5 51 | self.alpha = 0.1 52 | self.epsilon = 0.001 53 | self.lb = np.tile(self.ac_lb, [self.plan_hor]) 54 | self.ub = np.tile(self.ac_ub, [self.plan_hor]) 55 | self.decay = 1.25 56 | self.elite_fraction = 0.3 57 | self.elites = None 58 | 59 | self.ac_buf = np.array([]).reshape(0, self.dU) 60 | self.prev_sol = np.tile((self.ac_lb + self.ac_ub) / 2, [self.plan_hor]) 61 | self.init_var = np.tile(np.square(self.ac_ub - self.ac_lb) / 16, [self.plan_hor]) 62 | 63 | self.model = None 64 | self.step = 0 65 | 66 | def set_model(self, model): 67 | self.model = model 68 | 69 | def select_action(self, obs, eval_t=False): 70 | if self.model is None: 71 | return np.random.uniform(self.ac_lb, self.ac_ub, self.ac_lb.shape) 72 | if self.ac_buf.shape[0] > 0: 73 | action, self.ac_buf = self.ac_buf[0], self.ac_buf[1:] 74 | return action 75 | 76 | soln = self.obtain_solution(obs, self.prev_sol, self.init_var) 77 | self.prev_sol = np.concatenate([np.copy(soln)[self.per * self.dU:], np.zeros(self.per * self.dU)]) 78 | self.ac_buf = soln[:self.per * self.dU].reshape(-1, self.dU) 79 | 80 | return self.select_action(obs) 81 | 82 | def obtain_solution(self, obs, init_mean, init_var): 83 | mean, var, t = init_mean, init_var, 0 84 | X = stats.truncnorm(-2, 2, loc=np.zeros_like(mean), scale=np.ones_like(var)) 85 | 86 | while (t < self.max_iters) and np.max(var) > self.epsilon: 87 | lb_dist, ub_dist = mean - self.lb, self.ub - mean 88 | constrained_var = np.minimum(np.minimum(np.square(lb_dist / 2), np.square(ub_dist / 2)), var) 89 | 90 | noise = X.rvs(size=[self.popsize, self.plan_hor * self.dU]) 91 | 92 | samples = noise * np.sqrt(constrained_var) + mean 93 | samples = samples.astype(np.float32) 94 | 95 | rewards, costs, eps_lens = self.rollout(obs, samples) 96 | epoch_ratio = np.ones_like(eps_lens) * self.epoch_length / self.plan_hor 97 | terminated = eps_lens != self.plan_hor 98 | if self.c_gamma == 1: 99 | c_gamma_discount = epoch_ratio 100 | else: 101 | c_gamma_discount = (1 - self.c_gamma ** (epoch_ratio * self.plan_hor)) / (1 - self.c_gamma) / self.plan_hor 102 | rewards = rewards * epoch_ratio 103 | costs = costs * c_gamma_discount 104 | 105 | feasible_ids = ((costs <= self.cost_limit) & (~terminated)).nonzero()[0] 106 | if self.cost_constrained: 107 | if feasible_ids.shape[0] >= self.num_elites: 108 | elite_ids = feasible_ids[np.argsort(-rewards[feasible_ids])][:self.num_elites] 109 | else: 110 | elite_ids = np.argsort(costs)[:self.num_elites] 111 | else: 112 | elite_ids = np.argsort(-rewards)[:self.num_elites] 113 | self.elites = samples[elite_ids] 114 | new_mean = np.mean(self.elites, axis=0) 115 | new_var = np.var(self.elites, axis=0) 116 | 117 | mean = self.alpha * mean + (1 - self.alpha) * new_mean 118 | var = self.alpha * var + (1 - self.alpha) * new_var 119 | 120 | average_reward = rewards.mean().item() 121 | average_cost = costs.mean().item() 122 | average_len = eps_lens.mean().item() 123 | average_elite_reward = rewards[elite_ids].mean().item() 124 | average_elite_cost = costs[elite_ids].mean().item() 125 | average_elite_len = eps_lens[elite_ids].mean().item() 126 | if t == 0: 127 | start_reward = average_reward 128 | start_cost = average_cost 129 | t += 1 130 | 131 | self.step += 1 132 | return mean 133 | 134 | @torch.no_grad() 135 | def 
rollout(self, obs, ac_seqs): 136 | nopt = ac_seqs.shape[0] 137 | 138 | ac_seqs = torch.from_numpy(ac_seqs).float().to(self.device) 139 | 140 | # Reshape ac_seqs so that it's amenable to parallel compute 141 | # Before, ac seqs has dimension (400, 25) which are pop size and sol dim coming from CEM 142 | ac_seqs = ac_seqs.view(-1, self.plan_hor, self.dU) 143 | # After, ac seqs has dimension (400, 25, 1) 144 | 145 | transposed = ac_seqs.transpose(0, 1) 146 | # Then, (25, 400, 1) 147 | 148 | expanded = transposed[:, :, None] 149 | # Then, (25, 400, 1, 1) 150 | 151 | tiled = expanded.expand(-1, -1, self.npart, -1) 152 | # Then, (25, 400, 20, 1) 153 | 154 | ac_seqs = tiled.contiguous().view(self.plan_hor, -1, self.dU) 155 | # Then, (25, 8000, 1) 156 | 157 | # Expand current observation 158 | cur_obs = torch.from_numpy(obs).float().to(self.device) 159 | cur_obs = cur_obs[None] 160 | cur_obs = cur_obs.expand(nopt * self.npart, -1) 161 | 162 | rewards = torch.zeros(nopt, self.npart, device=self.device) 163 | costs = torch.zeros(nopt, self.npart, device=self.device) 164 | length = torch.zeros(nopt, self.npart, device=self.device) 165 | 166 | for t in range(self.plan_hor): 167 | cur_acs = ac_seqs[t] 168 | 169 | cur_obs, reward, cost = self._predict_next(cur_obs, cur_acs) 170 | # Clip state value 171 | cur_obs = torch.clamp(cur_obs, -STATE_MAX, STATE_MAX) 172 | reward = reward.view(-1, self.npart) 173 | cost = cost.view(-1, self.npart) 174 | 175 | rewards += reward 176 | costs += cost 177 | length += 1 178 | 179 | if t == 0: 180 | start_reward = reward 181 | start_cost = cost 182 | 183 | # Replace nan with high cost 184 | rewards[rewards != rewards] = -1e6 185 | costs[costs != costs] = 1e6 186 | 187 | return rewards.mean(dim=1).detach().cpu().numpy(), costs.mean(dim=1).detach().cpu().numpy(), length.mean(dim=1).detach().cpu().numpy() 188 | 189 | def optimize_kappa(self, episode_cost, permissible_cost=None): 190 | if permissible_cost is None: 191 | permissible_cost = self.cost_limit 192 | kappa_loss = -(self.kappa * (episode_cost - permissible_cost)) 193 | 194 | self.kappa_optim.zero_grad() 195 | kappa_loss.backward() 196 | self.kappa_optim.step() 197 | 198 | def _predict_next(self, obs, acs): 199 | proc_obs = self._expand_to_ts_format(obs) 200 | proc_acs = self._expand_to_ts_format(acs) 201 | 202 | output = self.model.predict(proc_obs, proc_acs) 203 | next_obs, var = output["state"] 204 | reward, _ = output["reward"] 205 | cost, _ = output["cost"] 206 | 207 | next_obs = self._flatten_to_matrix(next_obs) 208 | reward = self._flatten_to_matrix(reward) 209 | cost = self._flatten_to_matrix(cost) 210 | 211 | obs = obs.detach().cpu().numpy() 212 | acs = acs.detach().cpu().numpy() 213 | 214 | if self.cost_constrained and self.penalize_uncertainty: 215 | cost_penalty = var.sqrt().norm(dim=2).max(0)[0] 216 | cost_penalty = cost_penalty.repeat_interleave(self.model.num_nets).view(cost.shape) 217 | cost += self.kappa.to(cost_penalty.device) * cost_penalty 218 | if self.binary_cost: 219 | cost = (torch.sigmoid(cost) > 0.5).float() 220 | 221 | return next_obs, reward, cost 222 | 223 | def _expand_to_ts_format(self, mat): 224 | dim = mat.shape[-1] 225 | 226 | # Before, [10, 5] in case of proc_obs 227 | reshaped = mat.view(-1, self.model.num_nets, self.npart // self.model.num_nets, dim) 228 | # After, [2, 5, 1, 5] 229 | 230 | transposed = reshaped.transpose(0, 1) 231 | # After, [5, 2, 1, 5] 232 | 233 | reshaped = transposed.contiguous().view(self.model.num_nets, -1, dim) 234 | # After. 
[5, 2, 5] 235 | 236 | return reshaped 237 | 238 | def _flatten_to_matrix(self, ts_fmt_arr): 239 | dim = ts_fmt_arr.shape[-1] 240 | 241 | reshaped = ts_fmt_arr.view(self.model.num_nets, -1, self.npart // self.model.num_nets, dim) 242 | 243 | transposed = reshaped.transpose(0, 1) 244 | 245 | reshaped = transposed.contiguous().view(-1, dim) 246 | 247 | return reshaped 248 | 249 | -------------------------------------------------------------------------------- /cap-planet/models.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | import torch 3 | from torch import jit, nn 4 | from torch.nn import functional as F 5 | 6 | 7 | # Wraps the input tuple for a function to process a time x batch x features sequence in batch x features (assumes one output) 8 | def bottle(f, x_tuple): 9 | x_sizes = tuple(map(lambda x: x.size(), x_tuple)) 10 | y = f(*map(lambda x: x[0].view(x[1][0] * x[1][1], *x[1][2:]), zip(x_tuple, x_sizes))) 11 | y_size = y.size() 12 | return y.view(x_sizes[0][0], x_sizes[0][1], *y_size[1:]) 13 | 14 | 15 | class TransitionModel(jit.ScriptModule): 16 | __constants__ = ['min_std_dev'] 17 | 18 | def __init__(self, belief_size, state_size, action_size, hidden_size, embedding_size, activation_function='relu', min_std_dev=0.1): 19 | super().__init__() 20 | self.act_fn = getattr(F, activation_function) 21 | self.min_std_dev = min_std_dev 22 | self.fc_embed_state_action = nn.Linear(state_size + action_size, belief_size) 23 | self.rnn = nn.GRUCell(belief_size, belief_size) 24 | self.fc_embed_belief_prior = nn.Linear(belief_size, hidden_size) 25 | self.fc_state_prior = nn.Linear(hidden_size, 2 * state_size) 26 | self.fc_embed_belief_posterior = nn.Linear(belief_size + embedding_size, hidden_size) 27 | self.fc_state_posterior = nn.Linear(hidden_size, 2 * state_size) 28 | 29 | # Operates over (previous) state, (previous) actions, (previous) belief, (previous) nonterminals (mask), and (current) observations 30 | # Diagram of expected inputs and outputs for T = 5 (-x- signifying beginning of output belief/state that gets sliced off): 31 | # t : 0 1 2 3 4 5 32 | # o : -X--X--X--X--X- 33 | # a : -X--X--X--X--X- 34 | # n : -X--X--X--X--X- 35 | # pb: -X- 36 | # ps: -X- 37 | # b : -x--X--X--X--X--X- 38 | # s : -x--X--X--X--X--X- 39 | @jit.script_method 40 | def forward(self, prev_state:torch.Tensor, actions:torch.Tensor, prev_belief:torch.Tensor, observations:Optional[torch.Tensor]=None, nonterminals:Optional[torch.Tensor]=None) -> List[torch.Tensor]: 41 | # Create lists for hidden states (cannot use single tensor as buffer because autograd won't work with inplace writes) 42 | T = actions.size(0) + 1 43 | beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T 44 | beliefs[0], prior_states[0], posterior_states[0] = prev_belief, prev_state, prev_state 45 | # Loop over time sequence 46 | for t in range(T - 1): 47 | _state = prior_states[t] if observations is None else posterior_states[t] # Select appropriate previous state 48 | _state = _state if nonterminals is None else _state * nonterminals[t] # Mask if previous transition was terminal 49 | # Compute belief (deterministic hidden state) 50 | hidden = self.act_fn(self.fc_embed_state_action(torch.cat([_state, actions[t]], dim=1))) 51 | beliefs[t + 1] = self.rnn(hidden, beliefs[t]) 
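# Note: the two lines above are the deterministic core of the recurrent state-space model:
# the previous (masked) stochastic state and action are embedded, then a GRUCell carries the
# belief forward, b_{t+1} = GRU(fc([s_t, a_t]), b_t). The branches below map b_{t+1} to a
# diagonal-Gaussian prior over s_{t+1} (softplus std-dev floored at min_std_dev) and, when
# observations are provided, to a posterior that additionally conditions on the encoded
# observation for that step.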
52 | # Compute state prior by applying transition dynamics 53 | hidden = self.act_fn(self.fc_embed_belief_prior(beliefs[t + 1])) 54 | prior_means[t + 1], _prior_std_dev = torch.chunk(self.fc_state_prior(hidden), 2, dim=1) 55 | prior_std_devs[t + 1] = F.softplus(_prior_std_dev) + self.min_std_dev 56 | prior_states[t + 1] = prior_means[t + 1] + prior_std_devs[t + 1] * torch.randn_like(prior_means[t + 1]) 57 | if observations is not None: 58 | # Compute state posterior by applying transition dynamics and using current observation 59 | t_ = t - 1 # Use t_ to deal with different time indexing for observations 60 | hidden = self.act_fn(self.fc_embed_belief_posterior(torch.cat([beliefs[t + 1], observations[t_ + 1]], dim=1))) 61 | posterior_means[t + 1], _posterior_std_dev = torch.chunk(self.fc_state_posterior(hidden), 2, dim=1) 62 | posterior_std_devs[t + 1] = F.softplus(_posterior_std_dev) + self.min_std_dev 63 | posterior_states[t + 1] = posterior_means[t + 1] + posterior_std_devs[t + 1] * torch.randn_like(posterior_means[t + 1]) 64 | # Return new hidden states 65 | hidden = [torch.stack(beliefs[1:], dim=0), torch.stack(prior_states[1:], dim=0), torch.stack(prior_means[1:], dim=0), torch.stack(prior_std_devs[1:], dim=0)] 66 | if observations is not None: 67 | hidden += [torch.stack(posterior_states[1:], dim=0), torch.stack(posterior_means[1:], dim=0), torch.stack(posterior_std_devs[1:], dim=0)] 68 | return hidden 69 | 70 | class OneStepModel(jit.ScriptModule): 71 | def __init__(self, belief_size, action_size, output_size): 72 | super(OneStepModel, self).__init__() 73 | self.layers = nn.ModuleList([ 74 | nn.Linear(belief_size + action_size, output_size), 75 | nn.Linear(output_size + action_size, output_size), 76 | ]) 77 | self.mean_layer = nn.Linear(output_size + action_size, output_size) 78 | 79 | @jit.script_method 80 | def forward(self, prev_belief:torch.Tensor, action:torch.Tensor) -> torch.Tensor: 81 | hidden = torch.cat([prev_belief, action], dim=-1) 82 | for layer in self.layers: 83 | hidden = F.relu(layer(hidden)) 84 | hidden = torch.cat([hidden, action], dim=-1) 85 | mean = self.mean_layer(hidden) 86 | return mean 87 | 88 | class OneStepEnsemble(nn.Module): 89 | def __init__(self, belief_size, action_size, embedding_size, num_models=5): 90 | super().__init__() 91 | self.num_models = num_models 92 | self.embedding_size = embedding_size 93 | self.models = nn.ModuleList([ 94 | OneStepModel(belief_size, action_size, embedding_size) for _ in range(num_models) 95 | ]) 96 | self.last_sample = None 97 | 98 | def sample_with_replacement(self, batch_size:int) -> torch.Tensor: 99 | sample = torch.randint(batch_size, [self.num_models, batch_size]) 100 | self.last_sample = sample 101 | return sample 102 | 103 | def loss(self, preds:List[torch.Tensor], target:torch.Tensor) -> torch.Tensor: 104 | target = target.detach() 105 | losses = [] 106 | for i in range(self.num_models): 107 | losses.append(F.mse_loss(preds[i], target[:, self.last_sample[i]])) 108 | return sum(losses) 109 | 110 | def forward(self, prev_beliefs:torch.Tensor, actions:torch.Tensor) -> List[torch.Tensor]: 111 | prev_beliefs, actions = prev_beliefs.detach(), actions.detach() 112 | batch_size = actions.shape[1] 113 | outputs = [] 114 | sample = self.sample_with_replacement(batch_size) 115 | for i in range(self.num_models): 116 | output = self.models[i](prev_beliefs[:, sample[i]], actions[:, sample[i]]) 117 | outputs.append(output) 118 | return outputs 119 | 120 | def compute_uncertainty(self, prev_beliefs:torch.Tensor, 
actions:torch.Tensor) -> torch.Tensor: 121 | batch_shape = actions.shape[:-1] 122 | with torch.no_grad(): 123 | outputs = torch.zeros([self.num_models, *batch_shape, self.embedding_size]).to(actions.device) 124 | for i in range(self.num_models): 125 | outputs[i] = self.models[i](prev_beliefs, actions) 126 | # Calculate variance in place 127 | n = outputs.shape[0] 128 | outputs = outputs - outputs.mean(0) 129 | outputs.pow_(2) 130 | var = outputs.sum(0) / (n - 1) 131 | uncertainty = var.mean(-1) 132 | return uncertainty 133 | 134 | class SymbolicObservationModel(jit.ScriptModule): 135 | def __init__(self, observation_size, belief_size, state_size, embedding_size, activation_function='relu'): 136 | super().__init__() 137 | self.act_fn = getattr(F, activation_function) 138 | self.fc1 = nn.Linear(belief_size + state_size, embedding_size) 139 | self.fc2 = nn.Linear(embedding_size, embedding_size) 140 | self.fc3 = nn.Linear(embedding_size, observation_size) 141 | 142 | @jit.script_method 143 | def forward(self, belief, state): 144 | hidden = self.act_fn(self.fc1(torch.cat([belief, state], dim=1))) 145 | hidden = self.act_fn(self.fc2(hidden)) 146 | observation = self.fc3(hidden) 147 | return observation 148 | 149 | 150 | class VisualObservationModel(jit.ScriptModule): 151 | __constants__ = ['embedding_size'] 152 | 153 | def __init__(self, belief_size, state_size, embedding_size, activation_function='relu'): 154 | super().__init__() 155 | self.act_fn = getattr(F, activation_function) 156 | self.embedding_size = embedding_size 157 | self.fc1 = nn.Linear(belief_size + state_size, embedding_size) 158 | self.conv1 = nn.ConvTranspose2d(embedding_size, 128, 5, stride=2) 159 | self.conv2 = nn.ConvTranspose2d(128, 64, 5, stride=2) 160 | self.conv3 = nn.ConvTranspose2d(64, 32, 6, stride=2) 161 | self.conv4 = nn.ConvTranspose2d(32, 3, 6, stride=2) 162 | 163 | @jit.script_method 164 | def forward(self, belief, state): 165 | hidden = self.fc1(torch.cat([belief, state], dim=1)) # No nonlinearity here 166 | hidden = hidden.view(-1, self.embedding_size, 1, 1) 167 | hidden = self.act_fn(self.conv1(hidden)) 168 | hidden = self.act_fn(self.conv2(hidden)) 169 | hidden = self.act_fn(self.conv3(hidden)) 170 | observation = self.conv4(hidden) 171 | return observation 172 | 173 | 174 | def ObservationModel(symbolic, observation_size, belief_size, state_size, embedding_size, activation_function='relu'): 175 | if symbolic: 176 | return SymbolicObservationModel(observation_size, belief_size, state_size, embedding_size, activation_function) 177 | else: 178 | return VisualObservationModel(belief_size, state_size, embedding_size, activation_function) 179 | 180 | 181 | class RewardModel(jit.ScriptModule): 182 | def __init__(self, belief_size, state_size, hidden_size, activation_function='relu'): 183 | super().__init__() 184 | self.act_fn = getattr(F, activation_function) 185 | self.fc1 = nn.Linear(belief_size + state_size, hidden_size) 186 | self.fc2 = nn.Linear(hidden_size, hidden_size) 187 | self.fc3 = nn.Linear(hidden_size, 1) 188 | 189 | @jit.script_method 190 | def forward(self, belief, state): 191 | hidden = self.act_fn(self.fc1(torch.cat([belief, state], dim=1))) 192 | hidden = self.act_fn(self.fc2(hidden)) 193 | reward = self.fc3(hidden).squeeze(dim=1) 194 | return reward 195 | 196 | class CostModel(jit.ScriptModule): 197 | def __init__(self, belief_size, state_size, hidden_size, activation_function='relu'): 198 | super().__init__() 199 | self.act_fn = getattr(F, activation_function) 200 | self.fc1 = 
nn.Linear(belief_size + state_size, hidden_size) 201 | self.fc2 = nn.Linear(hidden_size, hidden_size) 202 | self.fc3 = nn.Linear(hidden_size, 1) 203 | self.loss = F.mse_loss 204 | 205 | @jit.script_method 206 | def forward(self, belief, state): 207 | hidden = self.act_fn(self.fc1(torch.cat([belief, state], dim=1))) 208 | hidden = self.act_fn(self.fc2(hidden)) 209 | cost = self.fc3(hidden).squeeze(dim=1) 210 | return cost 211 | 212 | class LogisticCostModel(jit.ScriptModule): 213 | def __init__(self, belief_size, state_size, hidden_size, activation_function='relu'): 214 | super().__init__() 215 | self.act_fn = getattr(F, activation_function) 216 | self.fc1 = nn.Linear(belief_size + state_size, hidden_size) 217 | self.fc2 = nn.Linear(hidden_size, hidden_size) 218 | self.fc3 = nn.Linear(hidden_size, 1) 219 | self.loss = F.binary_cross_entropy_with_logits 220 | 221 | @jit.script_method 222 | def forward(self, belief, state): 223 | hidden = self.act_fn(self.fc1(torch.cat([belief, state], dim=1))) 224 | hidden = self.act_fn(self.fc2(hidden)) 225 | logit = self.fc3(hidden).squeeze(dim=1) 226 | return logit 227 | 228 | 229 | class SymbolicEncoder(jit.ScriptModule): 230 | def __init__(self, observation_size, embedding_size, activation_function='relu'): 231 | super().__init__() 232 | self.act_fn = getattr(F, activation_function) 233 | self.fc1 = nn.Linear(observation_size, embedding_size) 234 | self.fc2 = nn.Linear(embedding_size, embedding_size) 235 | self.fc3 = nn.Linear(embedding_size, embedding_size) 236 | 237 | @jit.script_method 238 | def forward(self, observation): 239 | hidden = self.act_fn(self.fc1(observation)) 240 | hidden = self.act_fn(self.fc2(hidden)) 241 | hidden = self.fc3(hidden) 242 | return hidden 243 | 244 | 245 | class VisualEncoder(jit.ScriptModule): 246 | __constants__ = ['embedding_size'] 247 | 248 | def __init__(self, embedding_size, activation_function='relu'): 249 | super().__init__() 250 | self.act_fn = getattr(F, activation_function) 251 | self.embedding_size = embedding_size 252 | self.conv1 = nn.Conv2d(3, 32, 4, stride=2) 253 | self.conv2 = nn.Conv2d(32, 64, 4, stride=2) 254 | self.conv3 = nn.Conv2d(64, 128, 4, stride=2) 255 | self.conv4 = nn.Conv2d(128, 256, 4, stride=2) 256 | self.fc = nn.Identity() if embedding_size == 1024 else nn.Linear(1024, embedding_size) 257 | 258 | @jit.script_method 259 | def forward(self, observation): 260 | hidden = self.act_fn(self.conv1(observation)) 261 | hidden = self.act_fn(self.conv2(hidden)) 262 | hidden = self.act_fn(self.conv3(hidden)) 263 | hidden = self.act_fn(self.conv4(hidden)) 264 | hidden = hidden.view(-1, 1024) 265 | hidden = self.fc(hidden) # Identity if embedding size is 1024 else linear projection 266 | return hidden 267 | 268 | 269 | def Encoder(symbolic, observation_size, embedding_size, activation_function='relu'): 270 | if symbolic: 271 | return SymbolicEncoder(observation_size, embedding_size, activation_function) 272 | else: 273 | return VisualEncoder(embedding_size, activation_function) 274 | -------------------------------------------------------------------------------- /cap-planet/run_cap_planet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from math import inf 3 | import os 4 | import json 5 | import numpy as np 6 | import torch 7 | from torch import nn, optim 8 | from torch.distributions import Normal 9 | from torch.distributions.kl import kl_divergence 10 | from torch.nn import functional as F 11 | from torchvision.utils import make_grid, 
save_image 12 | from tqdm import tqdm 13 | from env import CONTROL_SUITE_ENVS, Env, GYM_ENVS, EnvBatcher 14 | from memory import ExperienceReplay 15 | from models import bottle, Encoder, ObservationModel, RewardModel, CostModel, LogisticCostModel, TransitionModel, OneStepEnsemble 16 | from planner import MPCPlanner 17 | from utils import lineplot, violinplot, write_video 18 | 19 | 20 | # Hyperparameters 21 | parser = argparse.ArgumentParser(description='PlaNet') 22 | parser.add_argument('--id', type=str, default='default', help='Experiment ID') 23 | parser.add_argument('--seed', type=int, default=1, metavar='S', help='Random seed') 24 | parser.add_argument('--disable-cuda', action='store_true', help='Disable CUDA') 25 | parser.add_argument('--env', type=str, default='Pendulum-v0', choices=GYM_ENVS + CONTROL_SUITE_ENVS, help='Gym/Control Suite environment') 26 | parser.add_argument('--cost-constrained', action='store_true', help='Follow cost constraint') 27 | parser.add_argument('--penalize-uncertainty', action='store_true', help='Penalize model uncertainty as an additional cost') 28 | parser.add_argument('--penalty-kappa', type=float, default=0, metavar='λ', help='The penalty kappa') 29 | parser.add_argument('--learn-kappa', action='store_true', help='Learn the penalty kappa') 30 | parser.add_argument('--binary-cost', action='store_true') 31 | parser.add_argument('--cost-limit', type=float, default=0, metavar='CL', help='Discounted cost limit') 32 | parser.add_argument('--permissible-cost', type=float, default=None, metavar='CL_', help='Cost allowed for exploration') 33 | parser.add_argument('--symbolic-env', action='store_true', help='Symbolic features') 34 | parser.add_argument('--discount', type=float, default=0.99, metavar='γ', help='Reward discount factor') 35 | parser.add_argument('--cost-discount', type=float, default=0.99, metavar='γ', help='Cost discount factor') 36 | parser.add_argument('--max-episode-length', type=int, default=1000, metavar='T', help='Max episode length') 37 | parser.add_argument('--experience-size', type=int, default=1000000, metavar='D', help='Experience replay size') # Original implementation has an unlimited buffer size, but 1 million is the max experience collected anyway 38 | parser.add_argument('--activation-function', type=str, default='relu', choices=dir(F), help='Model activation function') 39 | parser.add_argument('--embedding-size', type=int, default=1024, metavar='E', help='Observation embedding size') # Note that the default encoder for visual observations outputs a 1024D vector; for other embedding sizes an additional fully-connected layer is used 40 | parser.add_argument('--hidden-size', type=int, default=200, metavar='H', help='Hidden size') 41 | parser.add_argument('--belief-size', type=int, default=200, metavar='H', help='Belief/hidden size') 42 | parser.add_argument('--state-size', type=int, default=30, metavar='Z', help='State/latent size') 43 | parser.add_argument('--action-repeat', type=int, default=2, metavar='R', help='Action repeat') 44 | parser.add_argument('--action-noise', type=float, default=0.0, metavar='ε', help='Action noise') 45 | parser.add_argument('--episodes', type=int, default=1000, metavar='E', help='Total number of episodes') 46 | parser.add_argument('--seed-episodes', type=int, default=5, metavar='S', help='Seed episodes') 47 | parser.add_argument('--collect-interval', type=int, default=100, metavar='C', help='Collect interval') 48 | parser.add_argument('--batch-size', type=int, default=50, metavar='B', help='Batch 
size') 49 | parser.add_argument('--chunk-size', type=int, default=50, metavar='L', help='Chunk size') 50 | parser.add_argument('--overshooting-distance', type=int, default=50, metavar='D', help='Latent overshooting distance/latent overshooting weight for t = 1') 51 | parser.add_argument('--overshooting-kl-beta', type=float, default=0, metavar='β>1', help='Latent overshooting KL weight for t > 1 (0 to disable)') 52 | parser.add_argument('--overshooting-reward-scale', type=float, default=0, metavar='R>1', help='Latent overshooting reward prediction weight for t > 1 (0 to disable)') 53 | parser.add_argument('--global-kl-beta', type=float, default=0, metavar='βg', help='Global KL weight (0 to disable)') 54 | parser.add_argument('--free-nats', type=float, default=3, metavar='F', help='Free nats') 55 | parser.add_argument('--bit-depth', type=int, default=5, metavar='B', help='Image bit depth (quantisation)') 56 | parser.add_argument('--learning-rate', type=float, default=1e-3, metavar='α', help='Learning rate') 57 | parser.add_argument('--learning-rate-schedule', type=int, default=0, metavar='αS', help='Linear learning rate schedule (optimisation steps from 0 to final learning rate; 0 to disable)') 58 | parser.add_argument('--adam-epsilon', type=float, default=1e-4, metavar='ε', help='Adam optimiser epsilon value') 59 | # Note that original has a linear learning rate decay, but it seems unlikely that this makes a significant difference 60 | parser.add_argument('--grad-clip-norm', type=float, default=1000, metavar='C', help='Gradient clipping norm') 61 | parser.add_argument('--planning-horizon', type=int, default=12, metavar='H', help='Planning horizon distance') 62 | parser.add_argument('--optimisation-iters', type=int, default=10, metavar='I', help='Planning optimisation iterations') 63 | parser.add_argument('--candidates', type=int, default=1000, metavar='J', help='Candidate samples per iteration') 64 | parser.add_argument('--top-candidates', type=int, default=100, metavar='K', help='Number of top candidates to fit') 65 | parser.add_argument('--test', action='store_true', help='Test only') 66 | parser.add_argument('--test-interval', type=int, default=25, metavar='I', help='Test interval (episodes)') 67 | parser.add_argument('--test-episodes', type=int, default=10, metavar='E', help='Number of test episodes') 68 | parser.add_argument('--checkpoint-interval', type=int, default=50, metavar='I', help='Checkpoint interval (episodes)') 69 | parser.add_argument('--checkpoint-experience', action='store_true', help='Checkpoint experience replay') 70 | parser.add_argument('--models', type=str, default='', metavar='M', help='Load model checkpoint') 71 | parser.add_argument('--experience-replay', type=str, default='', metavar='ER', help='Load experience replay') 72 | parser.add_argument('--render', action='store_true', help='Render environment') 73 | args = parser.parse_args() 74 | args.overshooting_distance = min(args.chunk_size, args.overshooting_distance) # Overshooting distance cannot be greater than chunk size 75 | if args.permissible_cost is None: 76 | args.permissible_cost = args.cost_limit 77 | if args.binary_cost: 78 | print("Using cost discount of 1 for binary cost") 79 | args.cost_discount = 1 80 | print(' ' * 26 + 'Options') 81 | for k, v in vars(args).items(): 82 | print(' ' * 26 + k + ': ' + str(v)) 83 | 84 | 85 | # Setup 86 | results_dir = os.path.join('results', args.id) 87 | os.makedirs(results_dir, exist_ok=True) 88 | with open(os.path.join(results_dir, 'config.txt'), 'w') as 
config_file: 89 | json.dump(args.__dict__, config_file, indent=2) 90 | 91 | np.random.seed(args.seed) 92 | torch.manual_seed(args.seed) 93 | if torch.cuda.is_available() and not args.disable_cuda: 94 | args.device = torch.device('cuda') 95 | torch.cuda.manual_seed(args.seed) 96 | else: 97 | args.device = torch.device('cpu') 98 | metrics = {'steps': [], 'episodes': [], 'train_rewards': [], 'train_costs': [], 'test_episodes': [], 'test_rewards': [], 99 | 'test_costs': [], 'observation_loss': [], 'reward_loss': [], 'cost_loss': [], 'kl_loss': [], 'one_step_loss': [], 100 | 'penalty_kappa': [], 'test_eps_costs': [], 'test_eps_uncertainty': [], 'cost_violations': [], 'test_cost_violations': []} 101 | 102 | 103 | # Initialise training environment and experience replay memory 104 | env = Env(args.env, args.symbolic_env, args.seed, args.max_episode_length, args.action_repeat, args.bit_depth, cost_reduce="max" if args.binary_cost else "sum") 105 | test_envs = EnvBatcher(Env, (args.env, args.symbolic_env, None, args.max_episode_length, args.action_repeat, args.bit_depth, "max" if args.binary_cost else "sum"), {}, args.test_episodes) 106 | if args.experience_replay is not '' and os.path.exists(args.experience_replay): 107 | D = torch.load(args.experience_replay) 108 | metrics['steps'], metrics['episodes'] = [D.steps] * D.episodes, list(range(1, D.episodes + 1)) 109 | elif not args.test: 110 | D = ExperienceReplay(args.experience_size, args.symbolic_env, env.observation_size, env.action_size, args.bit_depth, args.device) 111 | # Initialise dataset D with S random seed episodes 112 | for s in range(1, args.seed_episodes + 1): 113 | observation, done, t = env.reset(), False, 0 114 | total_cost = 0 115 | while not done: 116 | action = env.sample_random_action() 117 | next_observation, reward, cost, done = env.step(action) 118 | total_cost += cost * args.cost_discount ** (args.action_repeat * (t + 0.5)) 119 | D.append(observation, action, reward, cost, done) 120 | observation = next_observation 121 | t += 1 122 | metrics['steps'].append(t * args.action_repeat + (0 if len(metrics['steps']) == 0 else metrics['steps'][-1])) 123 | metrics['episodes'].append(s) 124 | metrics['cost_violations'].append(total_cost > args.cost_limit) 125 | 126 | 127 | # Initialise model parameters randomly 128 | transition_model = TransitionModel(args.belief_size, args.state_size, env.action_size, args.hidden_size, args.embedding_size, args.activation_function).to(device=args.device) 129 | observation_model = ObservationModel(args.symbolic_env, env.observation_size, args.belief_size, args.state_size, args.embedding_size, args.activation_function).to(device=args.device) 130 | reward_model = RewardModel(args.belief_size, args.state_size, args.hidden_size, args.activation_function).to(device=args.device) 131 | if args.binary_cost: 132 | cost_model = LogisticCostModel(args.belief_size, args.state_size, args.hidden_size, args.activation_function).to(device=args.device) 133 | else: 134 | cost_model = CostModel(args.belief_size, args.state_size, args.hidden_size, args.activation_function).to(device=args.device) 135 | encoder = Encoder(args.symbolic_env, env.observation_size, args.embedding_size, args.activation_function).to(device=args.device) 136 | one_step_ensemble = OneStepEnsemble(args.belief_size, env.action_size, args.state_size) 137 | param_list = list(transition_model.parameters()) + list(observation_model.parameters()) + list(reward_model.parameters()) + list(cost_model.parameters()) + list(encoder.parameters()) + 
list(one_step_ensemble.parameters()) 138 | optimiser = optim.Adam(param_list, lr=0 if args.learning_rate_schedule != 0 else args.learning_rate, eps=args.adam_epsilon) 139 | planner = MPCPlanner(env.action_size, args.planning_horizon, args.optimisation_iters, args.candidates, args.top_candidates, transition_model, reward_model, cost_model, one_step_ensemble, 140 | env.action_range[0], env.action_range[1], 141 | cost_constrained=args.cost_constrained, penalize_uncertainty=args.penalize_uncertainty, 142 | cost_limit=args.cost_limit, cost_discount=args.cost_discount, action_repeat=args.action_repeat, max_length=args.max_episode_length, 143 | penalty_kappa=args.penalty_kappa, binary_cost=args.binary_cost).to(device=args.device) 144 | if args.models is not '' and os.path.exists(args.models): 145 | model_dicts = torch.load(args.models) 146 | transition_model.load_state_dict(model_dicts['transition_model']) 147 | observation_model.load_state_dict(model_dicts['observation_model']) 148 | reward_model.load_state_dict(model_dicts['reward_model']) 149 | cost_model.load_state_dict(model_dicts['cost_model']) 150 | encoder.load_state_dict(model_dicts['encoder']) 151 | one_step_ensemble.load_state_dict(model_dicts['one_step_ensemble']) 152 | optimiser.load_state_dict(model_dicts['optimiser']) 153 | planner_state = planner.state_dict() 154 | planner_state.update(model_dicts['planner']) 155 | planner.load_state_dict(planner_state) 156 | metrics = torch.load(os.path.join(os.path.dirname(args.models), 'metrics.pth')) 157 | global_prior = Normal(torch.zeros(args.batch_size, args.state_size, device=args.device), torch.ones(args.batch_size, args.state_size, device=args.device)) # Global prior N(0, I) 158 | free_nats = torch.full((1, ), args.free_nats, dtype=torch.float32, device=args.device) # Allowed deviation in KL divergence 159 | 160 | 161 | def update_belief_and_act(args, env, planner, transition_model, encoder, belief, posterior_state, action, observation, min_action=-inf, max_action=inf, explore=False): 162 | # Infer belief over current state q(s_t|o≤t,a 0 224 | if args.overshooting_kl_beta != 0: 225 | overshooting_vars = [] # Collect variables for overshooting to process in batch 226 | for t in range(1, args.chunk_size - 1): 227 | d = min(t + args.overshooting_distance, args.chunk_size - 1) # Overshooting distance 228 | t_, d_ = t - 1, d - 1 # Use t_ and d_ to deal with different time indexing for latent states 229 | seq_pad = (0, 0, 0, 0, 0, t - d + args.overshooting_distance) # Calculate sequence padding so overshooting terms can be calculated in one batch 230 | # Store (0) actions, (1) nonterminals, (2) rewards, (3) beliefs, (4) prior states, (5) posterior means, (6) posterior standard deviations and (7) sequence masks 231 | overshooting_vars.append((F.pad(actions[t:d], seq_pad), F.pad(nonterminals[t:d], seq_pad), F.pad(rewards[t:d], seq_pad[2:]), beliefs[t_], prior_states[t_], F.pad(posterior_means[t_ + 1:d_ + 1].detach(), seq_pad), F.pad(posterior_std_devs[t_ + 1:d_ + 1].detach(), seq_pad, value=1), F.pad(torch.ones(d - t, args.batch_size, args.state_size, device=args.device), seq_pad))) # Posterior standard deviations must be padded with > 0 to prevent infinite KL divergences 232 | overshooting_vars = tuple(zip(*overshooting_vars)) 233 | # Update belief/state using prior from previous belief/state and previous action (over entire sequence at once) 234 | beliefs, prior_states, prior_means, prior_std_devs = transition_model(torch.cat(overshooting_vars[4], dim=0), torch.cat(overshooting_vars[0], 
dim=1), torch.cat(overshooting_vars[3], dim=0), None, torch.cat(overshooting_vars[1], dim=1)) 235 | seq_mask = torch.cat(overshooting_vars[7], dim=1) 236 | # Calculate overshooting KL loss with sequence mask 237 | kl_loss += (1 / args.overshooting_distance) * args.overshooting_kl_beta * torch.max((kl_divergence(Normal(torch.cat(overshooting_vars[5], dim=1), torch.cat(overshooting_vars[6], dim=1)), Normal(prior_means, prior_std_devs)) * seq_mask).sum(dim=2), free_nats).mean(dim=(0, 1)) * (args.chunk_size - 1) # Update KL loss (compensating for extra average over each overshooting/open loop sequence) 238 | # Calculate overshooting reward prediction loss with sequence mask 239 | if args.overshooting_reward_scale != 0: 240 | reward_loss += (1 / args.overshooting_distance) * args.overshooting_reward_scale * F.mse_loss(bottle(reward_model, (beliefs, prior_states)) * seq_mask[:, :, 0], torch.cat(overshooting_vars[2], dim=1), reduction='none').mean(dim=(0, 1)) * (args.chunk_size - 1) # Update reward loss (compensating for extra average over each overshooting/open loop sequence) 241 | 242 | # Apply linearly ramping learning rate schedule 243 | if args.learning_rate_schedule != 0: 244 | for group in optimiser.param_groups: 245 | group['lr'] = min(group['lr'] + args.learning_rate / args.learning_rate_schedule, args.learning_rate) 246 | # Update model parameters 247 | optimiser.zero_grad() 248 | (observation_loss + reward_loss + cost_loss + one_step_loss + kl_loss).backward() 249 | nn.utils.clip_grad_norm_(param_list, args.grad_clip_norm, norm_type=2) 250 | optimiser.step() 251 | # Store (0) observation loss (1) reward loss (2) cost loss (3) one step ensemble loss (4) KL loss 252 | losses.append([observation_loss.item(), reward_loss.item(), cost_loss.item(), one_step_loss.item(), kl_loss.item()]) 253 | 254 | # Update and plot loss metrics 255 | losses = tuple(zip(*losses)) 256 | metrics['observation_loss'].append(losses[0]) 257 | metrics['reward_loss'].append(losses[1]) 258 | metrics['cost_loss'].append(losses[2]) 259 | metrics['one_step_loss'].append(losses[3]) 260 | metrics['kl_loss'].append(losses[4]) 261 | lineplot(metrics['episodes'][-len(metrics['observation_loss']):], metrics['observation_loss'], 'observation_loss', results_dir) 262 | lineplot(metrics['episodes'][-len(metrics['reward_loss']):], metrics['reward_loss'], 'reward_loss', results_dir) 263 | lineplot(metrics['episodes'][-len(metrics['cost_loss']):], metrics['cost_loss'], 'cost_loss', results_dir) 264 | lineplot(metrics['episodes'][-len(metrics['one_step_loss']):], metrics['one_step_loss'], 'one_step_loss', results_dir) 265 | lineplot(metrics['episodes'][-len(metrics['kl_loss']):], metrics['kl_loss'], 'kl_loss', results_dir) 266 | 267 | 268 | # Data collection 269 | with torch.no_grad(): 270 | observation, total_reward, total_cost = env.reset(), 0, 0 271 | belief, posterior_state, action = torch.zeros(1, args.belief_size, device=args.device), torch.zeros(1, args.state_size, device=args.device), torch.zeros(1, env.action_size, device=args.device) 272 | pbar = tqdm(range(args.max_episode_length // args.action_repeat)) 273 | for t in pbar: 274 | belief, posterior_state, action, next_observation, reward, cost, done = update_belief_and_act(args, env, planner, transition_model, encoder, belief, posterior_state, action, observation.to(device=args.device), env.action_range[0], env.action_range[1], explore=True) 275 | D.append(observation, action.cpu(), reward, cost, done) 276 | total_reward += reward * args.discount ** (args.action_repeat * (t 
+ 0.5)) 277 | total_cost += cost * args.cost_discount ** (args.action_repeat * (t + 0.5)) 278 | observation = next_observation 279 | if args.render: 280 | env.render() 281 | if done: 282 | pbar.close() 283 | break 284 | 285 | 286 | if args.learn_kappa: 287 | planner.optimize_penalty_kappa(total_cost, args.permissible_cost) 288 | # Update and plot train reward metrics 289 | metrics['steps'].append(t * args.action_repeat + metrics['steps'][-1]) 290 | metrics['episodes'].append(episode) 291 | metrics['train_rewards'].append(total_reward) 292 | metrics['train_costs'].append(total_cost) 293 | metrics['penalty_kappa'].append(planner.penalty_kappa.item()) 294 | metrics['cost_violations'].append(total_cost > args.cost_limit) 295 | 296 | lineplot(metrics['episodes'][-len(metrics['train_rewards']):], metrics['train_rewards'], 'train_rewards', results_dir) 297 | lineplot(metrics['episodes'][-len(metrics['train_costs']):], metrics['train_costs'], 'train_costs', results_dir) 298 | lineplot(metrics['episodes'][-len(metrics['penalty_kappa']):], metrics['penalty_kappa'], 'penalty_kappa', results_dir) 299 | 300 | 301 | # Test model 302 | if episode % args.test_interval == 0: 303 | # Set models to eval mode 304 | transition_model.eval() 305 | observation_model.eval() 306 | reward_model.eval() 307 | cost_model.eval() 308 | encoder.eval() 309 | one_step_ensemble.eval() 310 | planner.eval() 311 | # Initialise parallelised test environments 312 | 313 | with torch.no_grad(): 314 | observation, total_rewards, total_costs, video_frames = test_envs.reset(), np.zeros((args.test_episodes, )), np.zeros((args.test_episodes, )), [] 315 | belief, posterior_state, action = torch.zeros(args.test_episodes, args.belief_size, device=args.device), torch.zeros(args.test_episodes, args.state_size, device=args.device), torch.zeros(args.test_episodes, env.action_size, device=args.device) 316 | pbar = tqdm(range(args.max_episode_length // args.action_repeat)) 317 | for t in pbar: 318 | belief, posterior_state, action, next_observation, reward, cost, done = update_belief_and_act(args, test_envs, planner, transition_model, encoder, belief, posterior_state, action, observation.to(device=args.device), env.action_range[0], env.action_range[1]) 319 | total_rewards += reward.numpy() * args.discount ** (args.action_repeat * (t + 0.5)) 320 | total_costs += cost.numpy() * args.cost_discount ** (args.action_repeat * (t + 0.5)) 321 | if not args.symbolic_env: # Collect real vs. 
predicted frames for video 322 | video_frames.append(make_grid(torch.cat([observation, observation_model(belief, posterior_state).cpu()], dim=3) + 0.5, nrow=5).numpy()) # Decentre 323 | metrics['test_eps_costs'].append(cost.numpy()) 324 | if args.penalize_uncertainty: 325 | uncertainty = planner.uncertainty_last_step[0] 326 | metrics['test_eps_uncertainty'].append(uncertainty.numpy()) 327 | observation = next_observation 328 | if done.sum().item() == args.test_episodes: 329 | test_envs.reset() 330 | t += 1 331 | pbar.close() 332 | break 333 | 334 | # Update and plot reward metrics (and write video if applicable) and save metrics 335 | metrics['test_episodes'].append(episode) 336 | metrics['test_rewards'].append(total_rewards.tolist()) 337 | metrics['test_costs'].append(total_costs.tolist()) 338 | metrics['test_cost_violations'].append((total_costs > args.cost_limit).tolist()) 339 | lineplot(metrics['test_episodes'], metrics['test_rewards'], 'test_rewards', results_dir) 340 | lineplot(metrics['test_episodes'], metrics['test_costs'], 'test_costs', results_dir) 341 | lineplot(np.asarray(metrics['steps'])[np.asarray(metrics['test_episodes']) - 1], metrics['test_rewards'], 'test_rewards_steps', results_dir, xaxis='step') 342 | lineplot(np.asarray(metrics['steps'])[np.asarray(metrics['test_episodes']) - 1], metrics['test_costs'], 'test_costs_steps', results_dir, xaxis='step') 343 | episode_str = str(episode).zfill(len(str(args.episodes))) 344 | if args.penalize_uncertainty: 345 | violinplot(metrics['test_eps_costs'][-t:], metrics['test_eps_uncertainty'][-t:], f'test_episode_{episode_str}_cost_uncertainty', results_dir, xaxis='cost') 346 | if not args.symbolic_env: 347 | write_video(video_frames, 'test_episode_%s' % episode_str, results_dir) # Lossy compression 348 | save_image(torch.as_tensor(video_frames[-1]), os.path.join(results_dir, 'test_episode_%s.png' % episode_str)) 349 | torch.save(metrics, os.path.join(results_dir, 'metrics.pth')) 350 | 351 | # Set models to train mode 352 | transition_model.train() 353 | observation_model.train() 354 | reward_model.train() 355 | cost_model.train() 356 | encoder.train() 357 | one_step_ensemble.train() 358 | planner.train() 359 | 360 | 361 | # Checkpoint models 362 | if episode % args.checkpoint_interval == 0: 363 | planner_state = {k: v for k, v in planner.state_dict().items() if 'model' not in k} 364 | torch.save({'transition_model': transition_model.state_dict(), 'observation_model': observation_model.state_dict(), 365 | 'reward_model': reward_model.state_dict(), 'cost_model': cost_model.state_dict(), 'encoder': encoder.state_dict(), 366 | 'one_step_ensemble': one_step_ensemble.state_dict(), 'optimiser': optimiser.state_dict(), 'planner': planner_state}, 367 | os.path.join(results_dir, 'models_%d.pth' % episode)) 368 | if args.checkpoint_experience: 369 | torch.save(D, os.path.join(results_dir, 'experience.pth')) # Warning: will fail with MemoryError with large memory sizes 370 | 371 | 372 | # Close training environment 373 | env.close() 374 | --------------------------------------------------------------------------------
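# Usage sketch (not part of the repository): a minimal, hypothetical example of how the
# cap-pets components shown above compose. A ProbEnsemble dynamics/reward/cost model is fit
# on collected transitions (the state head is trained on next-state deltas, matching
# ProbEnsemble.predict, which adds the current state back) and is then handed to the
# ConstrainedCEM planner via set_model(). The environment name, the use of
# info.get("cost", 0.0) as the per-step cost signal, and the buffer handling are illustrative
# assumptions; the actual training loop lives in cap-pets/run_cap_pets.py and may differ.
# Assumes it is run from inside cap-pets/.
import gym
import numpy as np

from models import ProbEnsemble
from ccem import ConstrainedCEM

env = gym.make("Pendulum-v0")  # assumed environment, chosen only for illustration
planner = ConstrainedCEM(env, cost_limit=1.0, kappa=0.1, cuda=False)
model = ProbEnsemble(env.observation_space.shape[0], env.action_space.shape[0],
                     network_size=5, elite_size=5, cuda=False)

# Collect transitions with the planner policy (uniform random until a model is set)
states, actions, next_states, rewards, costs = [], [], [], [], []
obs = env.reset()
for _ in range(1000):
    act = planner.select_action(obs)
    next_obs, rew, done, info = env.step(act)
    states.append(obs); actions.append(act); next_states.append(next_obs)
    rewards.append([rew]); costs.append([info.get("cost", 0.0)])  # assumed cost source
    obs = env.reset() if done else next_obs

# Fit the ensemble on (state, action) -> (delta state, reward, cost) targets
inputs = np.concatenate([np.asarray(states), np.asarray(actions)], axis=-1).astype(np.float32)
targets = {
    "state": (np.asarray(next_states) - np.asarray(states)).astype(np.float32),
    "reward": np.asarray(rewards, dtype=np.float32),
    "cost": np.asarray(costs, dtype=np.float32),
}
model.train(inputs, targets, batch_size=256, max_epochs=5)

# Subsequent calls plan with cost-constrained CEM through the learned ensemble
planner.set_model(model)
action = planner.select_action(env.reset())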