├── .gitignore ├── CONTRIBUTING.md ├── LICENSE.md ├── README.md ├── env.py ├── environment.yml ├── main.py ├── memory.py ├── models.py ├── planner.py ├── requirements.txt └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # Environments 92 | .env 93 | .venv 94 | env/ 95 | venv/ 96 | ENV/ 97 | env.bak/ 98 | venv.bak/ 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | .dmypy.json 113 | dmypy.json 114 | 115 | # Pyre type checker 116 | .pyre/ 117 | 118 | 119 | # Results 120 | results/ 121 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Contributing Guidelines 2 | ======================= 3 | 4 | This project is an open source version of code that I use in my work, released (under this [license](LICENSE.md)) for the benefit of others. It is developed and maintained in my own time, and hence I cannot guarantee responses. While contributions are welcome, please keep the following points in mind: 5 | 6 | - Please be civil to myself and other contributors. 7 | - Do raise issues for bugs and other implementation problems/improvements. 8 | - If you have studied the repo and have a question, raise an issue (it could possibly be a bug). 9 | - Bug fixes and small improvements are very welcome. 10 | - Raise an issue before developing a large contribution to a) discuss b) see if I would be willing to merge it (you're obviously free to keep your own fork) c) prevent overlap with others. 11 | - All code contributions should adhere to the existing style (e.g., 2-space indent, no max line length, etc.) 
12 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Kai Arulkumaran 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | PlaNet 2 | ====== 3 | 4 | [![MIT License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE.md) 5 | 6 | PlaNet: A Deep Planning Network for Reinforcement Learning [[1]](#references). Supports symbolic/visual observation spaces. Supports some Gym environments (including classic control/non-MuJoCo environments, so DeepMind Control Suite/MuJoCo are optional dependencies). Hyperparameters have been taken from the original work and are tuned for the DeepMind Control Suite, so they would need tuning for other domains (such as the Gym environments). 7 | 8 | Run with `python main.py`. For best performance with DeepMind Control Suite, try setting environment variable `MUJOCO_GL=egl` (see instructions and details [here](https://github.com/deepmind/dm_control#rendering)). 9 | 10 | 11 | Results and pretrained models can be found in the [releases](https://github.com/Kaixhin/PlaNet/releases). 12 | 13 | Requirements 14 | ------------ 15 | 16 | - Python 3 17 | - [DeepMind Control Suite](https://github.com/deepmind/dm_control) (optional) 18 | - [Gym](https://gym.openai.com/) 19 | - [OpenCV Python](https://pypi.python.org/pypi/opencv-python) 20 | - [Plotly](https://plot.ly/) 21 | - [PyTorch](http://pytorch.org/) 22 | 23 | To install all dependencies with Anaconda, run `conda env create -f environment.yml` and use `source activate planet` to activate the environment.
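As a rough illustration (not part of the repository), a pretrained checkpoint saved by `main.py` can also be restored outside of training by rebuilding the models with matching hyperparameters. The environment, checkpoint path and hyperparameter values below are assumptions (the defaults from `main.py`) and must match the run that produced the checkpoint; equivalently, the checkpoint can simply be passed to `main.py` via `--models` (optionally together with `--test`).

```python
import torch
from env import Env
from models import Encoder, ObservationModel, RewardModel, TransitionModel

# Rebuild models with the default hyperparameters from main.py
# (belief size 200, state size 30, hidden size 200, embedding size 1024)
env = Env('Pendulum-v0', False, 1, 1000, 2, 5)  # env, symbolic, seed, max episode length, action repeat, bit depth
transition_model = TransitionModel(200, 30, env.action_size, 200, 1024)
observation_model = ObservationModel(False, env.observation_size, 200, 30, 1024)
reward_model = RewardModel(200, 30, 200)
encoder = Encoder(False, env.observation_size, 1024)

# Hypothetical checkpoint path; main.py saves checkpoints as results/<id>/models_<episode>.pth
model_dicts = torch.load('results/default/models_1000.pth', map_location='cpu')
transition_model.load_state_dict(model_dicts['transition_model'])
observation_model.load_state_dict(model_dicts['observation_model'])
reward_model.load_state_dict(model_dicts['reward_model'])
encoder.load_state_dict(model_dicts['encoder'])
env.close()
```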
24 | 25 | Links 26 | ----- 27 | 28 | - [Introducing PlaNet: A Deep Planning Network for Reinforcement Learning](https://ai.googleblog.com/2019/02/introducing-planet-deep-planning.html) 29 | - [google-research/planet](https://github.com/google-research/planet) 30 | 31 | Acknowledgements 32 | ---------------- 33 | 34 | - [@danijar](https://github.com/danijar) for [google-research/planet](https://github.com/google-research/planet) and [help reproducing results](https://github.com/google-research/planet/issues/28) 35 | - [@sg2](https://github.com/sg2) for [running experiments](https://github.com/Kaixhin/PlaNet/issues/9) 36 | 37 | References 38 | ---------- 39 | 40 | [1] [Learning Latent Dynamics for Planning from Pixels](https://arxiv.org/abs/1811.04551) 41 | -------------------------------------------------------------------------------- /env.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | 5 | 6 | GYM_ENVS = ['Pendulum-v0', 'MountainCarContinuous-v0', 'Ant-v2', 'HalfCheetah-v2', 'Hopper-v2', 'Humanoid-v2', 'HumanoidStandup-v2', 'InvertedDoublePendulum-v2', 'InvertedPendulum-v2', 'Reacher-v2', 'Swimmer-v2', 'Walker2d-v2'] 7 | CONTROL_SUITE_ENVS = ['cartpole-balance', 'cartpole-swingup', 'reacher-easy', 'finger-spin', 'cheetah-run', 'ball_in_cup-catch', 'walker-walk'] 8 | CONTROL_SUITE_ACTION_REPEATS = {'cartpole': 8, 'reacher': 4, 'finger': 2, 'cheetah': 4, 'ball_in_cup': 6, 'walker': 2} 9 | 10 | 11 | # Preprocesses an observation inplace (from float32 Tensor [0, 255] to [-0.5, 0.5]) 12 | def preprocess_observation_(observation, bit_depth): 13 | observation.div_(2 ** (8 - bit_depth)).floor_().div_(2 ** bit_depth).sub_(0.5) # Quantise to given bit depth and centre 14 | observation.add_(torch.rand_like(observation).div_(2 ** bit_depth)) # Dequantise (to approx. match likelihood of PDF of continuous images vs. 
PMF of discrete images) 15 | 16 | 17 | # Postprocess an observation for storage (from float32 numpy array [-0.5, 0.5] to uint8 numpy array [0, 255]) 18 | def postprocess_observation(observation, bit_depth): 19 | return np.clip(np.floor((observation + 0.5) * 2 ** bit_depth) * 2 ** (8 - bit_depth), 0, 2 ** 8 - 1).astype(np.uint8) 20 | 21 | 22 | def _images_to_observation(images, bit_depth): 23 | images = torch.tensor(cv2.resize(images, (64, 64), interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1), dtype=torch.float32) # Resize and put channel first 24 | preprocess_observation_(images, bit_depth) # Quantise, centre and dequantise inplace 25 | return images.unsqueeze(dim=0) # Add batch dimension 26 | 27 | 28 | class ControlSuiteEnv(): 29 | def __init__(self, env, symbolic, seed, max_episode_length, action_repeat, bit_depth): 30 | from dm_control import suite 31 | from dm_control.suite.wrappers import pixels 32 | domain, task = env.split('-') 33 | self.symbolic = symbolic 34 | self._env = suite.load(domain_name=domain, task_name=task, task_kwargs={'random': seed}) 35 | if not symbolic: 36 | self._env = pixels.Wrapper(self._env) 37 | self.max_episode_length = max_episode_length 38 | self.action_repeat = action_repeat 39 | if action_repeat != CONTROL_SUITE_ACTION_REPEATS[domain]: 40 | print('Using action repeat %d; recommended action repeat for domain is %d' % (action_repeat, CONTROL_SUITE_ACTION_REPEATS[domain])) 41 | self.bit_depth = bit_depth 42 | 43 | def reset(self): 44 | self.t = 0 # Reset internal timer 45 | state = self._env.reset() 46 | if self.symbolic: 47 | return torch.tensor(np.concatenate([np.asarray([obs]) if isinstance(obs, float) else obs for obs in state.observation.values()], axis=0), dtype=torch.float32).unsqueeze(dim=0) 48 | else: 49 | return _images_to_observation(self._env.physics.render(camera_id=0), self.bit_depth) 50 | 51 | def step(self, action): 52 | action = action.detach().numpy() 53 | reward = 0 54 | for k in range(self.action_repeat): 55 | state = self._env.step(action) 56 | reward += state.reward 57 | self.t += 1 # Increment internal timer 58 | done = state.last() or self.t == self.max_episode_length 59 | if done: 60 | break 61 | if self.symbolic: 62 | observation = torch.tensor(np.concatenate([np.asarray([obs]) if isinstance(obs, float) else obs for obs in state.observation.values()], axis=0), dtype=torch.float32).unsqueeze(dim=0) 63 | else: 64 | observation = _images_to_observation(self._env.physics.render(camera_id=0), self.bit_depth) 65 | return observation, reward, done 66 | 67 | def render(self): 68 | cv2.imshow('screen', self._env.physics.render(camera_id=0)[:, :, ::-1]) 69 | cv2.waitKey(1) 70 | 71 | def close(self): 72 | cv2.destroyAllWindows() 73 | self._env.close() 74 | 75 | @property 76 | def observation_size(self): 77 | return sum([(1 if len(obs.shape) == 0 else obs.shape[0]) for obs in self._env.observation_spec().values()]) if self.symbolic else (3, 64, 64) 78 | 79 | @property 80 | def action_size(self): 81 | return self._env.action_spec().shape[0] 82 | 83 | @property 84 | def action_range(self): 85 | return float(self._env.action_spec().minimum[0]), float(self._env.action_spec().maximum[0]) 86 | 87 | # Sample an action randomly from a uniform distribution over all valid actions 88 | def sample_random_action(self): 89 | spec = self._env.action_spec() 90 | return torch.from_numpy(np.random.uniform(spec.minimum, spec.maximum, spec.shape)) 91 | 92 | 93 | 94 | class GymEnv(): 95 | def __init__(self, env, symbolic, seed, max_episode_length, action_repeat, 
bit_depth): 96 | import logging 97 | import gym 98 | gym.logger.set_level(logging.ERROR) # Ignore warnings from Gym logger 99 | self.symbolic = symbolic 100 | self._env = gym.make(env) 101 | self._env.seed(seed) 102 | self.max_episode_length = max_episode_length 103 | self.action_repeat = action_repeat 104 | self.bit_depth = bit_depth 105 | 106 | def reset(self): 107 | self.t = 0 # Reset internal timer 108 | state = self._env.reset() 109 | if self.symbolic: 110 | return torch.tensor(state, dtype=torch.float32).unsqueeze(dim=0) 111 | else: 112 | return _images_to_observation(self._env.render(mode='rgb_array'), self.bit_depth) 113 | 114 | def step(self, action): 115 | action = action.detach().numpy() 116 | reward = 0 117 | for k in range(self.action_repeat): 118 | state, reward_k, done, _ = self._env.step(action) 119 | reward += reward_k 120 | self.t += 1 # Increment internal timer 121 | done = done or self.t == self.max_episode_length 122 | if done: 123 | break 124 | if self.symbolic: 125 | observation = torch.tensor(state, dtype=torch.float32).unsqueeze(dim=0) 126 | else: 127 | observation = _images_to_observation(self._env.render(mode='rgb_array'), self.bit_depth) 128 | return observation, reward, done 129 | 130 | def render(self): 131 | self._env.render() 132 | 133 | def close(self): 134 | self._env.close() 135 | 136 | @property 137 | def observation_size(self): 138 | return self._env.observation_space.shape[0] if self.symbolic else (3, 64, 64) 139 | 140 | @property 141 | def action_size(self): 142 | return self._env.action_space.shape[0] 143 | 144 | @property 145 | def action_range(self): 146 | return float(self._env.action_space.low[0]), float(self._env.action_space.high[0]) 147 | 148 | # Sample an action randomly from a uniform distribution over all valid actions 149 | def sample_random_action(self): 150 | return torch.from_numpy(self._env.action_space.sample()) 151 | 152 | 153 | def Env(env, symbolic, seed, max_episode_length, action_repeat, bit_depth): 154 | if env in GYM_ENVS: 155 | return GymEnv(env, symbolic, seed, max_episode_length, action_repeat, bit_depth) 156 | elif env in CONTROL_SUITE_ENVS: 157 | return ControlSuiteEnv(env, symbolic, seed, max_episode_length, action_repeat, bit_depth) 158 | 159 | 160 | # Wrapper for batching environments together 161 | class EnvBatcher(): 162 | def __init__(self, env_class, env_args, env_kwargs, n): 163 | self.n = n 164 | self.envs = [env_class(*env_args, **env_kwargs) for _ in range(n)] 165 | self.dones = [True] * n 166 | 167 | # Resets every environment and returns observation 168 | def reset(self): 169 | observations = [env.reset() for env in self.envs] 170 | self.dones = [False] * self.n 171 | return torch.cat(observations) 172 | 173 | # Steps/resets every environment and returns (observation, reward, done) 174 | def step(self, actions): 175 | done_mask = torch.nonzero(torch.tensor(self.dones))[:, 0] # Done mask to blank out observations and zero rewards for previously terminated environments 176 | observations, rewards, dones = zip(*[env.step(action) for env, action in zip(self.envs, actions)]) 177 | dones = [d or prev_d for d, prev_d in zip(dones, self.dones)] # Env should remain terminated if previously terminated 178 | self.dones = dones 179 | observations, rewards, dones = torch.cat(observations), torch.tensor(rewards, dtype=torch.float32), torch.tensor(dones, dtype=torch.uint8) 180 | observations[done_mask] = 0 181 | rewards[done_mask] = 0 182 | return observations, rewards, dones 183 | 184 | def close(self): 185 | [env.close() 
for env in self.envs] 186 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: planet 2 | channels: 3 | - pytorch 4 | dependencies: 5 | - plotly 6 | - pytorch 7 | - torchvision 8 | - tqdm 9 | - pip 10 | - pip: 11 | - gym 12 | - opencv-python 13 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from math import inf 3 | import os 4 | import numpy as np 5 | import torch 6 | from torch import nn, optim 7 | from torch.distributions import Normal 8 | from torch.distributions.kl import kl_divergence 9 | from torch.nn import functional as F 10 | from torchvision.utils import make_grid, save_image 11 | from tqdm import tqdm 12 | from env import CONTROL_SUITE_ENVS, Env, GYM_ENVS, EnvBatcher 13 | from memory import ExperienceReplay 14 | from models import bottle, Encoder, ObservationModel, RewardModel, TransitionModel 15 | from planner import MPCPlanner 16 | from utils import lineplot, write_video 17 | 18 | 19 | # Hyperparameters 20 | parser = argparse.ArgumentParser(description='PlaNet') 21 | parser.add_argument('--id', type=str, default='default', help='Experiment ID') 22 | parser.add_argument('--seed', type=int, default=1, metavar='S', help='Random seed') 23 | parser.add_argument('--disable-cuda', action='store_true', help='Disable CUDA') 24 | parser.add_argument('--env', type=str, default='Pendulum-v0', choices=GYM_ENVS + CONTROL_SUITE_ENVS, help='Gym/Control Suite environment') 25 | parser.add_argument('--symbolic-env', action='store_true', help='Symbolic features') 26 | parser.add_argument('--max-episode-length', type=int, default=1000, metavar='T', help='Max episode length') 27 | parser.add_argument('--experience-size', type=int, default=1000000, metavar='D', help='Experience replay size') # Original implementation has an unlimited buffer size, but 1 million is the max experience collected anyway 28 | parser.add_argument('--activation-function', type=str, default='relu', choices=dir(F), help='Model activation function') 29 | parser.add_argument('--embedding-size', type=int, default=1024, metavar='E', help='Observation embedding size') # Note that the default encoder for visual observations outputs a 1024D vector; for other embedding sizes an additional fully-connected layer is used 30 | parser.add_argument('--hidden-size', type=int, default=200, metavar='H', help='Hidden size') 31 | parser.add_argument('--belief-size', type=int, default=200, metavar='H', help='Belief/hidden size') 32 | parser.add_argument('--state-size', type=int, default=30, metavar='Z', help='State/latent size') 33 | parser.add_argument('--action-repeat', type=int, default=2, metavar='R', help='Action repeat') 34 | parser.add_argument('--action-noise', type=float, default=0.3, metavar='ε', help='Action noise') 35 | parser.add_argument('--episodes', type=int, default=1000, metavar='E', help='Total number of episodes') 36 | parser.add_argument('--seed-episodes', type=int, default=5, metavar='S', help='Seed episodes') 37 | parser.add_argument('--collect-interval', type=int, default=100, metavar='C', help='Collect interval') 38 | parser.add_argument('--batch-size', type=int, default=50, metavar='B', help='Batch size') 39 | parser.add_argument('--chunk-size', type=int, default=50, metavar='L', help='Chunk size') 40 | 
parser.add_argument('--overshooting-distance', type=int, default=50, metavar='D', help='Latent overshooting distance/latent overshooting weight for t = 1') 41 | parser.add_argument('--overshooting-kl-beta', type=float, default=0, metavar='β>1', help='Latent overshooting KL weight for t > 1 (0 to disable)') 42 | parser.add_argument('--overshooting-reward-scale', type=float, default=0, metavar='R>1', help='Latent overshooting reward prediction weight for t > 1 (0 to disable)') 43 | parser.add_argument('--global-kl-beta', type=float, default=0, metavar='βg', help='Global KL weight (0 to disable)') 44 | parser.add_argument('--free-nats', type=float, default=3, metavar='F', help='Free nats') 45 | parser.add_argument('--bit-depth', type=int, default=5, metavar='B', help='Image bit depth (quantisation)') 46 | parser.add_argument('--learning-rate', type=float, default=1e-3, metavar='α', help='Learning rate') 47 | parser.add_argument('--learning-rate-schedule', type=int, default=0, metavar='αS', help='Linear learning rate schedule (optimisation steps from 0 to final learning rate; 0 to disable)') 48 | parser.add_argument('--adam-epsilon', type=float, default=1e-4, metavar='ε', help='Adam optimiser epsilon value') 49 | # Note that original has a linear learning rate decay, but it seems unlikely that this makes a significant difference 50 | parser.add_argument('--grad-clip-norm', type=float, default=1000, metavar='C', help='Gradient clipping norm') 51 | parser.add_argument('--planning-horizon', type=int, default=12, metavar='H', help='Planning horizon distance') 52 | parser.add_argument('--optimisation-iters', type=int, default=10, metavar='I', help='Planning optimisation iterations') 53 | parser.add_argument('--candidates', type=int, default=1000, metavar='J', help='Candidate samples per iteration') 54 | parser.add_argument('--top-candidates', type=int, default=100, metavar='K', help='Number of top candidates to fit') 55 | parser.add_argument('--test', action='store_true', help='Test only') 56 | parser.add_argument('--test-interval', type=int, default=25, metavar='I', help='Test interval (episodes)') 57 | parser.add_argument('--test-episodes', type=int, default=10, metavar='E', help='Number of test episodes') 58 | parser.add_argument('--checkpoint-interval', type=int, default=50, metavar='I', help='Checkpoint interval (episodes)') 59 | parser.add_argument('--checkpoint-experience', action='store_true', help='Checkpoint experience replay') 60 | parser.add_argument('--models', type=str, default='', metavar='M', help='Load model checkpoint') 61 | parser.add_argument('--experience-replay', type=str, default='', metavar='ER', help='Load experience replay') 62 | parser.add_argument('--render', action='store_true', help='Render environment') 63 | args = parser.parse_args() 64 | args.overshooting_distance = min(args.chunk_size, args.overshooting_distance) # Overshooting distance cannot be greater than chunk size 65 | print(' ' * 26 + 'Options') 66 | for k, v in vars(args).items(): 67 | print(' ' * 26 + k + ': ' + str(v)) 68 | 69 | 70 | # Setup 71 | results_dir = os.path.join('results', args.id) 72 | os.makedirs(results_dir, exist_ok=True) 73 | np.random.seed(args.seed) 74 | torch.manual_seed(args.seed) 75 | if torch.cuda.is_available() and not args.disable_cuda: 76 | args.device = torch.device('cuda') 77 | torch.cuda.manual_seed(args.seed) 78 | else: 79 | args.device = torch.device('cpu') 80 | metrics = {'steps': [], 'episodes': [], 'train_rewards': [], 'test_episodes': [], 'test_rewards': [], 
'observation_loss': [], 'reward_loss': [], 'kl_loss': []} 81 | 82 | 83 | # Initialise training environment and experience replay memory 84 | env = Env(args.env, args.symbolic_env, args.seed, args.max_episode_length, args.action_repeat, args.bit_depth) 85 | if args.experience_replay is not '' and os.path.exists(args.experience_replay): 86 | D = torch.load(args.experience_replay) 87 | metrics['steps'], metrics['episodes'] = [D.steps] * D.episodes, list(range(1, D.episodes + 1)) 88 | elif not args.test: 89 | D = ExperienceReplay(args.experience_size, args.symbolic_env, env.observation_size, env.action_size, args.bit_depth, args.device) 90 | # Initialise dataset D with S random seed episodes 91 | for s in range(1, args.seed_episodes + 1): 92 | observation, done, t = env.reset(), False, 0 93 | while not done: 94 | action = env.sample_random_action() 95 | next_observation, reward, done = env.step(action) 96 | D.append(observation, action, reward, done) 97 | observation = next_observation 98 | t += 1 99 | metrics['steps'].append(t * args.action_repeat + (0 if len(metrics['steps']) == 0 else metrics['steps'][-1])) 100 | metrics['episodes'].append(s) 101 | 102 | 103 | # Initialise model parameters randomly 104 | transition_model = TransitionModel(args.belief_size, args.state_size, env.action_size, args.hidden_size, args.embedding_size, args.activation_function).to(device=args.device) 105 | observation_model = ObservationModel(args.symbolic_env, env.observation_size, args.belief_size, args.state_size, args.embedding_size, args.activation_function).to(device=args.device) 106 | reward_model = RewardModel(args.belief_size, args.state_size, args.hidden_size, args.activation_function).to(device=args.device) 107 | encoder = Encoder(args.symbolic_env, env.observation_size, args.embedding_size, args.activation_function).to(device=args.device) 108 | param_list = list(transition_model.parameters()) + list(observation_model.parameters()) + list(reward_model.parameters()) + list(encoder.parameters()) 109 | optimiser = optim.Adam(param_list, lr=0 if args.learning_rate_schedule != 0 else args.learning_rate, eps=args.adam_epsilon) 110 | if args.models is not '' and os.path.exists(args.models): 111 | model_dicts = torch.load(args.models) 112 | transition_model.load_state_dict(model_dicts['transition_model']) 113 | observation_model.load_state_dict(model_dicts['observation_model']) 114 | reward_model.load_state_dict(model_dicts['reward_model']) 115 | encoder.load_state_dict(model_dicts['encoder']) 116 | optimiser.load_state_dict(model_dicts['optimiser']) 117 | planner = MPCPlanner(env.action_size, args.planning_horizon, args.optimisation_iters, args.candidates, args.top_candidates, transition_model, reward_model, env.action_range[0], env.action_range[1]) 118 | global_prior = Normal(torch.zeros(args.batch_size, args.state_size, device=args.device), torch.ones(args.batch_size, args.state_size, device=args.device)) # Global prior N(0, I) 119 | free_nats = torch.full((1, ), args.free_nats, dtype=torch.float32, device=args.device) # Allowed deviation in KL divergence 120 | 121 | 122 | def update_belief_and_act(args, env, planner, transition_model, encoder, belief, posterior_state, action, observation, min_action=-inf, max_action=inf, explore=False): 123 | # Infer belief over current state q(s_t|o≤t,a 0 177 | if args.overshooting_kl_beta != 0: 178 | overshooting_vars = [] # Collect variables for overshooting to process in batch 179 | for t in range(1, args.chunk_size - 1): 180 | d = min(t + args.overshooting_distance, 
args.chunk_size - 1) # Overshooting distance 181 | t_, d_ = t - 1, d - 1 # Use t_ and d_ to deal with different time indexing for latent states 182 | seq_pad = (0, 0, 0, 0, 0, t - d + args.overshooting_distance) # Calculate sequence padding so overshooting terms can be calculated in one batch 183 | # Store (0) actions, (1) nonterminals, (2) rewards, (3) beliefs, (4) posterior states, (5) posterior means, (6) posterior standard deviations and (7) sequence masks 184 | overshooting_vars.append((F.pad(actions[t:d], seq_pad), F.pad(nonterminals[t:d], seq_pad), F.pad(rewards[t:d], seq_pad[2:]), beliefs[t_], posterior_states[t_].detach(), F.pad(posterior_means[t_ + 1:d_ + 1].detach(), seq_pad), F.pad(posterior_std_devs[t_ + 1:d_ + 1].detach(), seq_pad, value=1), F.pad(torch.ones(d - t, args.batch_size, args.state_size, device=args.device), seq_pad))) # Posterior standard deviations must be padded with > 0 to prevent infinite KL divergences 185 | overshooting_vars = tuple(zip(*overshooting_vars)) 186 | # Update belief/state using prior from previous belief/state and previous action (over entire sequence at once) 187 | beliefs, prior_states, prior_means, prior_std_devs = transition_model(torch.cat(overshooting_vars[4], dim=0), torch.cat(overshooting_vars[0], dim=1), torch.cat(overshooting_vars[3], dim=0), None, torch.cat(overshooting_vars[1], dim=1)) 188 | seq_mask = torch.cat(overshooting_vars[7], dim=1) 189 | # Calculate overshooting KL loss with sequence mask 190 | kl_loss += (1 / args.overshooting_distance) * args.overshooting_kl_beta * torch.max((kl_divergence(Normal(torch.cat(overshooting_vars[5], dim=1), torch.cat(overshooting_vars[6], dim=1)), Normal(prior_means, prior_std_devs)) * seq_mask).sum(dim=2), free_nats).mean(dim=(0, 1)) * (args.chunk_size - 1) # Update KL loss (compensating for extra average over each overshooting/open loop sequence) 191 | # Calculate overshooting reward prediction loss with sequence mask 192 | if args.overshooting_reward_scale != 0: 193 | reward_loss += (1 / args.overshooting_distance) * args.overshooting_reward_scale * F.mse_loss(bottle(reward_model, (beliefs, prior_states)) * seq_mask[:, :, 0], torch.cat(overshooting_vars[2], dim=1), reduction='none').mean(dim=(0, 1)) * (args.chunk_size - 1) # Update reward loss (compensating for extra average over each overshooting/open loop sequence) 194 | 195 | # Apply linearly ramping learning rate schedule 196 | if args.learning_rate_schedule != 0: 197 | for group in optimiser.param_groups: 198 | group['lr'] = min(group['lr'] + args.learning_rate / args.learning_rate_schedule, args.learning_rate) 199 | # Update model parameters 200 | optimiser.zero_grad() 201 | (observation_loss + reward_loss + kl_loss).backward() 202 | nn.utils.clip_grad_norm_(param_list, args.grad_clip_norm, norm_type=2) 203 | optimiser.step() 204 | # Store (0) observation loss (1) reward loss (2) KL loss 205 | losses.append([observation_loss.item(), reward_loss.item(), kl_loss.item()]) 206 | 207 | # Update and plot loss metrics 208 | losses = tuple(zip(*losses)) 209 | metrics['observation_loss'].append(losses[0]) 210 | metrics['reward_loss'].append(losses[1]) 211 | metrics['kl_loss'].append(losses[2]) 212 | lineplot(metrics['episodes'][-len(metrics['observation_loss']):], metrics['observation_loss'], 'observation_loss', results_dir) 213 | lineplot(metrics['episodes'][-len(metrics['reward_loss']):], metrics['reward_loss'], 'reward_loss', results_dir) 214 | lineplot(metrics['episodes'][-len(metrics['kl_loss']):], metrics['kl_loss'], 'kl_loss', 
results_dir) 215 | 216 | 217 | # Data collection 218 | with torch.no_grad(): 219 | observation, total_reward = env.reset(), 0 220 | belief, posterior_state, action = torch.zeros(1, args.belief_size, device=args.device), torch.zeros(1, args.state_size, device=args.device), torch.zeros(1, env.action_size, device=args.device) 221 | pbar = tqdm(range(args.max_episode_length // args.action_repeat)) 222 | for t in pbar: 223 | belief, posterior_state, action, next_observation, reward, done = update_belief_and_act(args, env, planner, transition_model, encoder, belief, posterior_state, action, observation.to(device=args.device), env.action_range[0], env.action_range[1], explore=True) 224 | D.append(observation, action.cpu(), reward, done) 225 | total_reward += reward 226 | observation = next_observation 227 | if args.render: 228 | env.render() 229 | if done: 230 | pbar.close() 231 | break 232 | 233 | # Update and plot train reward metrics 234 | metrics['steps'].append(t + metrics['steps'][-1]) 235 | metrics['episodes'].append(episode) 236 | metrics['train_rewards'].append(total_reward) 237 | lineplot(metrics['episodes'][-len(metrics['train_rewards']):], metrics['train_rewards'], 'train_rewards', results_dir) 238 | 239 | 240 | # Test model 241 | if episode % args.test_interval == 0: 242 | # Set models to eval mode 243 | transition_model.eval() 244 | observation_model.eval() 245 | reward_model.eval() 246 | encoder.eval() 247 | # Initialise parallelised test environments 248 | test_envs = EnvBatcher(Env, (args.env, args.symbolic_env, args.seed, args.max_episode_length, args.action_repeat, args.bit_depth), {}, args.test_episodes) 249 | 250 | with torch.no_grad(): 251 | observation, total_rewards, video_frames = test_envs.reset(), np.zeros((args.test_episodes, )), [] 252 | belief, posterior_state, action = torch.zeros(args.test_episodes, args.belief_size, device=args.device), torch.zeros(args.test_episodes, args.state_size, device=args.device), torch.zeros(args.test_episodes, env.action_size, device=args.device) 253 | pbar = tqdm(range(args.max_episode_length // args.action_repeat)) 254 | for t in pbar: 255 | belief, posterior_state, action, next_observation, reward, done = update_belief_and_act(args, test_envs, planner, transition_model, encoder, belief, posterior_state, action, observation.to(device=args.device), env.action_range[0], env.action_range[1]) 256 | total_rewards += reward.numpy() 257 | if not args.symbolic_env: # Collect real vs. 
predicted frames for video 258 | video_frames.append(make_grid(torch.cat([observation, observation_model(belief, posterior_state).cpu()], dim=3) + 0.5, nrow=5).numpy()) # Decentre 259 | observation = next_observation 260 | if done.sum().item() == args.test_episodes: 261 | pbar.close() 262 | break 263 | 264 | # Update and plot reward metrics (and write video if applicable) and save metrics 265 | metrics['test_episodes'].append(episode) 266 | metrics['test_rewards'].append(total_rewards.tolist()) 267 | lineplot(metrics['test_episodes'], metrics['test_rewards'], 'test_rewards', results_dir) 268 | lineplot(np.asarray(metrics['steps'])[np.asarray(metrics['test_episodes']) - 1], metrics['test_rewards'], 'test_rewards_steps', results_dir, xaxis='step') 269 | if not args.symbolic_env: 270 | episode_str = str(episode).zfill(len(str(args.episodes))) 271 | write_video(video_frames, 'test_episode_%s' % episode_str, results_dir) # Lossy compression 272 | save_image(torch.as_tensor(video_frames[-1]), os.path.join(results_dir, 'test_episode_%s.png' % episode_str)) 273 | torch.save(metrics, os.path.join(results_dir, 'metrics.pth')) 274 | 275 | # Set models to train mode 276 | transition_model.train() 277 | observation_model.train() 278 | reward_model.train() 279 | encoder.train() 280 | # Close test environments 281 | test_envs.close() 282 | 283 | 284 | # Checkpoint models 285 | if episode % args.checkpoint_interval == 0: 286 | torch.save({'transition_model': transition_model.state_dict(), 'observation_model': observation_model.state_dict(), 'reward_model': reward_model.state_dict(), 'encoder': encoder.state_dict(), 'optimiser': optimiser.state_dict()}, os.path.join(results_dir, 'models_%d.pth' % episode)) 287 | if args.checkpoint_experience: 288 | torch.save(D, os.path.join(results_dir, 'experience.pth')) # Warning: will fail with MemoryError with large memory sizes 289 | 290 | 291 | # Close training environment 292 | env.close() 293 | -------------------------------------------------------------------------------- /memory.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from env import postprocess_observation, preprocess_observation_ 4 | 5 | 6 | class ExperienceReplay(): 7 | def __init__(self, size, symbolic_env, observation_size, action_size, bit_depth, device): 8 | self.device = device 9 | self.symbolic_env = symbolic_env 10 | self.size = size 11 | self.observations = np.empty((size, observation_size) if symbolic_env else (size, 3, 64, 64), dtype=np.float32 if symbolic_env else np.uint8) 12 | self.actions = np.empty((size, action_size), dtype=np.float32) 13 | self.rewards = np.empty((size, ), dtype=np.float32) 14 | self.nonterminals = np.empty((size, 1), dtype=np.float32) 15 | self.idx = 0 16 | self.full = False # Tracks if memory has been filled/all slots are valid 17 | self.steps, self.episodes = 0, 0 # Tracks how much experience has been used in total 18 | self.bit_depth = bit_depth 19 | 20 | def append(self, observation, action, reward, done): 21 | if self.symbolic_env: 22 | self.observations[self.idx] = observation.numpy() 23 | else: 24 | self.observations[self.idx] = postprocess_observation(observation.numpy(), self.bit_depth) # Decentre and discretise visual observations (to save memory) 25 | self.actions[self.idx] = action.numpy() 26 | self.rewards[self.idx] = reward 27 | self.nonterminals[self.idx] = not done 28 | self.idx = (self.idx + 1) % self.size 29 | self.full = self.full or self.idx == 0 30 | self.steps, 
self.episodes = self.steps + 1, self.episodes + (1 if done else 0) 31 | 32 | # Returns an index for a valid single sequence chunk uniformly sampled from the memory 33 | def _sample_idx(self, L): 34 | valid_idx = False 35 | while not valid_idx: 36 | idx = np.random.randint(0, self.size if self.full else self.idx - L) 37 | idxs = np.arange(idx, idx + L) % self.size 38 | valid_idx = not self.idx in idxs[1:] # Make sure data does not cross the memory index 39 | return idxs 40 | 41 | def _retrieve_batch(self, idxs, n, L): 42 | vec_idxs = idxs.transpose().reshape(-1) # Unroll indices 43 | observations = torch.as_tensor(self.observations[vec_idxs].astype(np.float32)) 44 | if not self.symbolic_env: 45 | preprocess_observation_(observations, self.bit_depth) # Undo discretisation for visual observations 46 | return observations.reshape(L, n, *observations.shape[1:]), self.actions[vec_idxs].reshape(L, n, -1), self.rewards[vec_idxs].reshape(L, n), self.nonterminals[vec_idxs].reshape(L, n, 1) 47 | 48 | # Returns a batch of sequence chunks uniformly sampled from the memory 49 | def sample(self, n, L): 50 | batch = self._retrieve_batch(np.asarray([self._sample_idx(L) for _ in range(n)]), n, L) 51 | return [torch.as_tensor(item).to(device=self.device) for item in batch] 52 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | import torch 3 | from torch import jit, nn 4 | from torch.nn import functional as F 5 | 6 | 7 | # Wraps the input tuple for a function to process a time x batch x features sequence in batch x features (assumes one output) 8 | def bottle(f, x_tuple): 9 | x_sizes = tuple(map(lambda x: x.size(), x_tuple)) 10 | y = f(*map(lambda x: x[0].view(x[1][0] * x[1][1], *x[1][2:]), zip(x_tuple, x_sizes))) 11 | y_size = y.size() 12 | return y.view(x_sizes[0][0], x_sizes[0][1], *y_size[1:]) 13 | 14 | 15 | class TransitionModel(jit.ScriptModule): 16 | __constants__ = ['min_std_dev'] 17 | 18 | def __init__(self, belief_size, state_size, action_size, hidden_size, embedding_size, activation_function='relu', min_std_dev=0.1): 19 | super().__init__() 20 | self.act_fn = getattr(F, activation_function) 21 | self.min_std_dev = min_std_dev 22 | self.fc_embed_state_action = nn.Linear(state_size + action_size, belief_size) 23 | self.rnn = nn.GRUCell(belief_size, belief_size) 24 | self.fc_embed_belief_prior = nn.Linear(belief_size, hidden_size) 25 | self.fc_state_prior = nn.Linear(hidden_size, 2 * state_size) 26 | self.fc_embed_belief_posterior = nn.Linear(belief_size + embedding_size, hidden_size) 27 | self.fc_state_posterior = nn.Linear(hidden_size, 2 * state_size) 28 | 29 | # Operates over (previous) state, (previous) actions, (previous) belief, (previous) nonterminals (mask), and (current) observations 30 | # Diagram of expected inputs and outputs for T = 5 (-x- signifying beginning of output belief/state that gets sliced off): 31 | # t : 0 1 2 3 4 5 32 | # o : -X--X--X--X--X- 33 | # a : -X--X--X--X--X- 34 | # n : -X--X--X--X--X- 35 | # pb: -X- 36 | # ps: -X- 37 | # b : -x--X--X--X--X--X- 38 | # s : -x--X--X--X--X--X- 39 | @jit.script_method 40 | def forward(self, prev_state:torch.Tensor, actions:torch.Tensor, prev_belief:torch.Tensor, observations:Optional[torch.Tensor]=None, nonterminals:Optional[torch.Tensor]=None) -> List[torch.Tensor]: 41 | # Create lists for hidden states (cannot use single tensor as buffer because autograd won't work with 
inplace writes) 42 | T = actions.size(0) + 1 43 | beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T 44 | beliefs[0], prior_states[0], posterior_states[0] = prev_belief, prev_state, prev_state 45 | # Loop over time sequence 46 | for t in range(T - 1): 47 | _state = prior_states[t] if observations is None else posterior_states[t] # Select appropriate previous state 48 | _state = _state if nonterminals is None else _state * nonterminals[t] # Mask if previous transition was terminal 49 | # Compute belief (deterministic hidden state) 50 | hidden = self.act_fn(self.fc_embed_state_action(torch.cat([_state, actions[t]], dim=1))) 51 | beliefs[t + 1] = self.rnn(hidden, beliefs[t]) 52 | # Compute state prior by applying transition dynamics 53 | hidden = self.act_fn(self.fc_embed_belief_prior(beliefs[t + 1])) 54 | prior_means[t + 1], _prior_std_dev = torch.chunk(self.fc_state_prior(hidden), 2, dim=1) 55 | prior_std_devs[t + 1] = F.softplus(_prior_std_dev) + self.min_std_dev 56 | prior_states[t + 1] = prior_means[t + 1] + prior_std_devs[t + 1] * torch.randn_like(prior_means[t + 1]) 57 | if observations is not None: 58 | # Compute state posterior by applying transition dynamics and using current observation 59 | t_ = t - 1 # Use t_ to deal with different time indexing for observations 60 | hidden = self.act_fn(self.fc_embed_belief_posterior(torch.cat([beliefs[t + 1], observations[t_ + 1]], dim=1))) 61 | posterior_means[t + 1], _posterior_std_dev = torch.chunk(self.fc_state_posterior(hidden), 2, dim=1) 62 | posterior_std_devs[t + 1] = F.softplus(_posterior_std_dev) + self.min_std_dev 63 | posterior_states[t + 1] = posterior_means[t + 1] + posterior_std_devs[t + 1] * torch.randn_like(posterior_means[t + 1]) 64 | # Return new hidden states 65 | hidden = [torch.stack(beliefs[1:], dim=0), torch.stack(prior_states[1:], dim=0), torch.stack(prior_means[1:], dim=0), torch.stack(prior_std_devs[1:], dim=0)] 66 | if observations is not None: 67 | hidden += [torch.stack(posterior_states[1:], dim=0), torch.stack(posterior_means[1:], dim=0), torch.stack(posterior_std_devs[1:], dim=0)] 68 | return hidden 69 | 70 | 71 | class SymbolicObservationModel(jit.ScriptModule): 72 | def __init__(self, observation_size, belief_size, state_size, embedding_size, activation_function='relu'): 73 | super().__init__() 74 | self.act_fn = getattr(F, activation_function) 75 | self.fc1 = nn.Linear(belief_size + state_size, embedding_size) 76 | self.fc2 = nn.Linear(embedding_size, embedding_size) 77 | self.fc3 = nn.Linear(embedding_size, observation_size) 78 | 79 | @jit.script_method 80 | def forward(self, belief, state): 81 | hidden = self.act_fn(self.fc1(torch.cat([belief, state], dim=1))) 82 | hidden = self.act_fn(self.fc2(hidden)) 83 | observation = self.fc3(hidden) 84 | return observation 85 | 86 | 87 | class VisualObservationModel(jit.ScriptModule): 88 | __constants__ = ['embedding_size'] 89 | 90 | def __init__(self, belief_size, state_size, embedding_size, activation_function='relu'): 91 | super().__init__() 92 | self.act_fn = getattr(F, activation_function) 93 | self.embedding_size = embedding_size 94 | self.fc1 = nn.Linear(belief_size + state_size, embedding_size) 95 | self.conv1 = nn.ConvTranspose2d(embedding_size, 128, 5, stride=2) 96 | self.conv2 = nn.ConvTranspose2d(128, 64, 5, stride=2) 97 | self.conv3 = 
nn.ConvTranspose2d(64, 32, 6, stride=2) 98 | self.conv4 = nn.ConvTranspose2d(32, 3, 6, stride=2) 99 | 100 | @jit.script_method 101 | def forward(self, belief, state): 102 | hidden = self.fc1(torch.cat([belief, state], dim=1)) # No nonlinearity here 103 | hidden = hidden.view(-1, self.embedding_size, 1, 1) 104 | hidden = self.act_fn(self.conv1(hidden)) 105 | hidden = self.act_fn(self.conv2(hidden)) 106 | hidden = self.act_fn(self.conv3(hidden)) 107 | observation = self.conv4(hidden) 108 | return observation 109 | 110 | 111 | def ObservationModel(symbolic, observation_size, belief_size, state_size, embedding_size, activation_function='relu'): 112 | if symbolic: 113 | return SymbolicObservationModel(observation_size, belief_size, state_size, embedding_size, activation_function) 114 | else: 115 | return VisualObservationModel(belief_size, state_size, embedding_size, activation_function) 116 | 117 | 118 | class RewardModel(jit.ScriptModule): 119 | def __init__(self, belief_size, state_size, hidden_size, activation_function='relu'): 120 | super().__init__() 121 | self.act_fn = getattr(F, activation_function) 122 | self.fc1 = nn.Linear(belief_size + state_size, hidden_size) 123 | self.fc2 = nn.Linear(hidden_size, hidden_size) 124 | self.fc3 = nn.Linear(hidden_size, 1) 125 | 126 | @jit.script_method 127 | def forward(self, belief, state): 128 | hidden = self.act_fn(self.fc1(torch.cat([belief, state], dim=1))) 129 | hidden = self.act_fn(self.fc2(hidden)) 130 | reward = self.fc3(hidden).squeeze(dim=1) 131 | return reward 132 | 133 | 134 | class SymbolicEncoder(jit.ScriptModule): 135 | def __init__(self, observation_size, embedding_size, activation_function='relu'): 136 | super().__init__() 137 | self.act_fn = getattr(F, activation_function) 138 | self.fc1 = nn.Linear(observation_size, embedding_size) 139 | self.fc2 = nn.Linear(embedding_size, embedding_size) 140 | self.fc3 = nn.Linear(embedding_size, embedding_size) 141 | 142 | @jit.script_method 143 | def forward(self, observation): 144 | hidden = self.act_fn(self.fc1(observation)) 145 | hidden = self.act_fn(self.fc2(hidden)) 146 | hidden = self.fc3(hidden) 147 | return hidden 148 | 149 | 150 | class VisualEncoder(jit.ScriptModule): 151 | __constants__ = ['embedding_size'] 152 | 153 | def __init__(self, embedding_size, activation_function='relu'): 154 | super().__init__() 155 | self.act_fn = getattr(F, activation_function) 156 | self.embedding_size = embedding_size 157 | self.conv1 = nn.Conv2d(3, 32, 4, stride=2) 158 | self.conv2 = nn.Conv2d(32, 64, 4, stride=2) 159 | self.conv3 = nn.Conv2d(64, 128, 4, stride=2) 160 | self.conv4 = nn.Conv2d(128, 256, 4, stride=2) 161 | self.fc = nn.Identity() if embedding_size == 1024 else nn.Linear(1024, embedding_size) 162 | 163 | @jit.script_method 164 | def forward(self, observation): 165 | hidden = self.act_fn(self.conv1(observation)) 166 | hidden = self.act_fn(self.conv2(hidden)) 167 | hidden = self.act_fn(self.conv3(hidden)) 168 | hidden = self.act_fn(self.conv4(hidden)) 169 | hidden = hidden.view(-1, 1024) 170 | hidden = self.fc(hidden) # Identity if embedding size is 1024 else linear projection 171 | return hidden 172 | 173 | 174 | def Encoder(symbolic, observation_size, embedding_size, activation_function='relu'): 175 | if symbolic: 176 | return SymbolicEncoder(observation_size, embedding_size, activation_function) 177 | else: 178 | return VisualEncoder(embedding_size, activation_function) 179 | -------------------------------------------------------------------------------- /planner.py: 
-------------------------------------------------------------------------------- 1 | from math import inf 2 | import torch 3 | from torch import jit 4 | 5 | 6 | # Model-predictive control planner with cross-entropy method and learned transition model 7 | class MPCPlanner(jit.ScriptModule): 8 | __constants__ = ['action_size', 'planning_horizon', 'optimisation_iters', 'candidates', 'top_candidates', 'min_action', 'max_action'] 9 | 10 | def __init__(self, action_size, planning_horizon, optimisation_iters, candidates, top_candidates, transition_model, reward_model, min_action=-inf, max_action=inf): 11 | super().__init__() 12 | self.transition_model, self.reward_model = transition_model, reward_model 13 | self.action_size, self.min_action, self.max_action = action_size, min_action, max_action 14 | self.planning_horizon = planning_horizon 15 | self.optimisation_iters = optimisation_iters 16 | self.candidates, self.top_candidates = candidates, top_candidates 17 | 18 | @jit.script_method 19 | def forward(self, belief, state): 20 | B, H, Z = belief.size(0), belief.size(1), state.size(1) 21 | belief, state = belief.unsqueeze(dim=1).expand(B, self.candidates, H).reshape(-1, H), state.unsqueeze(dim=1).expand(B, self.candidates, Z).reshape(-1, Z) 22 | # Initialize factorized belief over action sequences q(a_t:t+H) ~ N(0, I) 23 | action_mean, action_std_dev = torch.zeros(self.planning_horizon, B, 1, self.action_size, device=belief.device), torch.ones(self.planning_horizon, B, 1, self.action_size, device=belief.device) 24 | for _ in range(self.optimisation_iters): 25 | # Evaluate J action sequences from the current belief (over entire sequence at once, batched over particles) 26 | actions = (action_mean + action_std_dev * torch.randn(self.planning_horizon, B, self.candidates, self.action_size, device=action_mean.device)).view(self.planning_horizon, B * self.candidates, self.action_size) # Sample actions (time x (batch x candidates) x actions) 27 | actions.clamp_(min=self.min_action, max=self.max_action) # Clip action range 28 | # Sample next states 29 | beliefs, states, _, _ = self.transition_model(state, actions, belief) 30 | # Calculate expected returns (technically sum of rewards over planning horizon) 31 | returns = self.reward_model(beliefs.view(-1, H), states.view(-1, Z)).view(self.planning_horizon, -1).sum(dim=0) 32 | # Re-fit belief to the K best action sequences 33 | _, topk = returns.reshape(B, self.candidates).topk(self.top_candidates, dim=1, largest=True, sorted=False) 34 | topk += self.candidates * torch.arange(0, B, dtype=torch.int64, device=topk.device).unsqueeze(dim=1) # Fix indices for unrolled actions 35 | best_actions = actions[:, topk.view(-1)].reshape(self.planning_horizon, B, self.top_candidates, self.action_size) 36 | # Update belief with new means and standard deviations 37 | action_mean, action_std_dev = best_actions.mean(dim=2, keepdim=True), best_actions.std(dim=2, unbiased=False, keepdim=True) 38 | # Return first action mean µ_t 39 | return action_mean[0].squeeze(dim=1) 40 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gym 2 | opencv-python 3 | plotly 4 | torch 5 | tqdm 6 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import plotly 5 | from plotly.graph_objs 
import Scatter 6 | from plotly.graph_objs.scatter import Line 7 | 8 | 9 | # Plots min, max and mean + standard deviation bars of a population over time 10 | def lineplot(xs, ys_population, title, path='', xaxis='episode'): 11 | max_colour, mean_colour, std_colour, transparent = 'rgb(0, 132, 180)', 'rgb(0, 172, 237)', 'rgba(29, 202, 255, 0.2)', 'rgba(0, 0, 0, 0)' 12 | 13 | if isinstance(ys_population[0], list) or isinstance(ys_population[0], tuple): 14 | ys = np.asarray(ys_population, dtype=np.float32) 15 | ys_min, ys_max, ys_mean, ys_std, ys_median = ys.min(1), ys.max(1), ys.mean(1), ys.std(1), np.median(ys, 1) 16 | ys_upper, ys_lower = ys_mean + ys_std, ys_mean - ys_std 17 | 18 | trace_max = Scatter(x=xs, y=ys_max, line=Line(color=max_colour, dash='dash'), name='Max') 19 | trace_upper = Scatter(x=xs, y=ys_upper, line=Line(color=transparent), name='+1 Std. Dev.', showlegend=False) 20 | trace_mean = Scatter(x=xs, y=ys_mean, fill='tonexty', fillcolor=std_colour, line=Line(color=mean_colour), name='Mean') 21 | trace_lower = Scatter(x=xs, y=ys_lower, fill='tonexty', fillcolor=std_colour, line=Line(color=transparent), name='-1 Std. Dev.', showlegend=False) 22 | trace_min = Scatter(x=xs, y=ys_min, line=Line(color=max_colour, dash='dash'), name='Min') 23 | trace_median = Scatter(x=xs, y=ys_median, line=Line(color=max_colour), name='Median') 24 | data = [trace_upper, trace_mean, trace_lower, trace_min, trace_max, trace_median] 25 | else: 26 | data = [Scatter(x=xs, y=ys_population, line=Line(color=mean_colour))] 27 | plotly.offline.plot({ 28 | 'data': data, 29 | 'layout': dict(title=title, xaxis={'title': xaxis}, yaxis={'title': title}) 30 | }, filename=os.path.join(path, title + '.html'), auto_open=False) 31 | 32 | 33 | def write_video(frames, title, path=''): 34 | frames = np.multiply(np.stack(frames, axis=0).transpose(0, 2, 3, 1), 255).clip(0, 255).astype(np.uint8)[:, :, :, ::-1] # VideoWrite expects H x W x C in BGR 35 | _, H, W, _ = frames.shape 36 | writer = cv2.VideoWriter(os.path.join(path, '%s.mp4' % title), cv2.VideoWriter_fourcc(*'mp4v'), 30., (W, H), True) 37 | for frame in frames: 38 | writer.write(frame) 39 | writer.release() 40 | --------------------------------------------------------------------------------
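For reference, a minimal usage sketch of `lineplot` and `write_video` (the episode numbers, rewards, frames and output directory below are made up purely for illustration):

import os
import numpy as np
from utils import lineplot, write_video

os.makedirs('results/demo', exist_ok=True)  # hypothetical output directory

# lineplot accepts either a flat list of values or a population (list of lists) per x value
episodes = list(range(1, 11))
test_rewards = [list(np.random.uniform(0, 100, size=5)) for _ in episodes]  # 5 test rewards per episode
lineplot(episodes, test_rewards, 'test_rewards', path='results/demo')  # writes results/demo/test_rewards.html

# write_video expects C x H x W frames with values in [0, 1]
frames = [np.random.rand(3, 64, 128).astype(np.float32) for _ in range(30)]
write_video(frames, 'demo_episode', path='results/demo')  # writes results/demo/demo_episode.mp4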