├── video_prediction.gif
├── figures
│   ├── cheetah_run.png
│   ├── finger_spin.png
│   ├── walker_walk.png
│   ├── reacher_easy.png
│   ├── cartpole_swingup.png
│   └── ball_in_cup_catch.png
├── requirements.txt
├── LICENSE
├── viewer.py
├── README.md
├── .gitignore
├── wrappers.py
├── test.py
├── utils.py
├── agent.py
├── video_prediction.py
├── model.py
└── train.py
/video_prediction.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cross32768/PlaNet_PyTorch/HEAD/video_prediction.gif
--------------------------------------------------------------------------------
/figures/cheetah_run.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cross32768/PlaNet_PyTorch/HEAD/figures/cheetah_run.png
--------------------------------------------------------------------------------
/figures/finger_spin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cross32768/PlaNet_PyTorch/HEAD/figures/finger_spin.png
--------------------------------------------------------------------------------
/figures/walker_walk.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cross32768/PlaNet_PyTorch/HEAD/figures/walker_walk.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | dm_control
2 | gym
3 | opencv-python
4 | matplotlib
5 | numpy
6 | tensorboard
7 | torch
8 | 
--------------------------------------------------------------------------------
/figures/reacher_easy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cross32768/PlaNet_PyTorch/HEAD/figures/reacher_easy.png
--------------------------------------------------------------------------------
/figures/cartpole_swingup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cross32768/PlaNet_PyTorch/HEAD/figures/cartpole_swingup.png
--------------------------------------------------------------------------------
/figures/ball_in_cup_catch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cross32768/PlaNet_PyTorch/HEAD/figures/ball_in_cup_catch.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Kaito Suzuki
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/viewer.py:
--------------------------------------------------------------------------------
1 | """
2 | A simple OpenCV based viewer for rendering
3 | from https://github.com/zuoxingdong/dm2gym
4 | """
5 | import uuid
6 | import cv2
7 | 
8 | 
9 | class OpenCVImageViewer:
10 |     """
11 |     A simple OpenCV highgui based dm_control image viewer
12 |     This class is meant to be a drop-in replacement for
13 |     `gym.envs.classic_control.rendering.SimpleImageViewer`
14 |     """
15 |     def __init__(self, *, escape_to_exit=False):
16 |         """
17 |         Construct the viewing window
18 |         """
19 |         self._escape_to_exit = escape_to_exit
20 |         self._window_name = str(uuid.uuid4())
21 |         cv2.namedWindow(self._window_name, cv2.WINDOW_AUTOSIZE)
22 |         self._isopen = True
23 | 
24 |     def __del__(self):
25 |         """
26 |         Close the window
27 |         """
28 |         cv2.destroyWindow(self._window_name)
29 |         self._isopen = False
30 | 
31 |     def imshow(self, img):
32 |         """
33 |         Show an image
34 |         """
35 |         cv2.imshow(self._window_name, img[:, :, [2, 1, 0]])
36 |         if cv2.waitKey(1) in [27] and self._escape_to_exit:
37 |             exit()
38 | 
39 |     @property
40 |     def isopen(self):
41 |         """
42 |         Is the window open?
43 |         """
44 |         return self._isopen
45 | 
46 |     def close(self):
47 |         pass
48 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PlaNet_PyTorch
2 | Unofficial PyTorch re-implementation of "Learning Latent Dynamics for Planning from Pixels" (https://arxiv.org/abs/1811.04551)
3 | 
4 | ## Instructions
5 | To train, install the requirements (see below) and run the following (the default environment is cheetah run):
6 | ```bash
7 | python3 train.py
8 | ```
9 | 
10 | To test a learned model, run
11 | ```bash
12 | python3 test.py dir
13 | ```
14 | 
15 | To predict video frames with a learned model, run
16 | ```bash
17 | python3 video_prediction.py dir
18 | ```
19 | `dir` should be the log directory created by train.py, and you need to specify the environment that matches that log via command-line arguments (see the example below).
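For example, assuming a model was trained with `python3 train.py --domain-name walker --task-name walk` and its logs ended up in a directory like `log/walker_walk/20200601_1234` (the timestamp part here is only illustrative), testing and video prediction could look like this:
```bash
# evaluate the learned model for 5 episodes with on-screen rendering
python3 test.py log/walker_walk/20200601_1234 --domain-name walker --task-name walk --episodes 5 --render

# open-loop video prediction with the same model
python3 video_prediction.py log/walker_walk/20200601_1234 --domain-name walker --task-name walk
```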
20 | 
21 | 
22 | 
23 | ## Requirements
24 | * Python3
25 | * MuJoCo (for the DeepMind Control Suite)
26 | 
27 | See requirements.txt for the required Python libraries.
28 | 
29 | ## Qualitative result
30 | Example of video frames predicted by the learned model:
31 | ![](https://github.com/cross32768/PlaNet_PyTorch/blob/master/video_prediction.gif)
32 | 
33 | ## Quantitative result
34 | ### cartpole swingup
35 | ![](https://github.com/cross32768/PlaNet_PyTorch/blob/master/figures/cartpole_swingup.png)
36 | 
37 | ### reacher easy
38 | ![](https://github.com/cross32768/PlaNet_PyTorch/blob/master/figures/reacher_easy.png)
39 | 
40 | ### cheetah run
41 | ![](https://github.com/cross32768/PlaNet_PyTorch/blob/master/figures/cheetah_run.png)
42 | 
43 | ### finger spin
44 | ![](https://github.com/cross32768/PlaNet_PyTorch/blob/master/figures/finger_spin.png)
45 | 
46 | ### ball_in_cup catch
47 | ![](https://github.com/cross32768/PlaNet_PyTorch/blob/master/figures/ball_in_cup_catch.png)
48 | 
49 | ### walker walk
50 | ![](https://github.com/cross32768/PlaNet_PyTorch/blob/master/figures/walker_walk.png)
51 | 
52 | Work in progress.
53 | 
54 | I'm going to add results from at least three runs for each environment used in the original paper.
55 | 
56 | All results are test scores (without exploration noise), measured every 10 episodes.
57 | 
58 | A moving average with window size 5 is applied.
59 | 
60 | ## References
61 | * [Learning Latent Dynamics for Planning from Pixels](https://arxiv.org/abs/1811.04551)
62 | * [Official Implementation](https://github.com/google-research/planet)
63 | 
64 | 
65 | ## TODO
66 | * Speed up training
67 | * Add more qualitative results (at least 3 experiments for each environment with different random seeds)
68 | * Generalize the code to other environments
69 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
--------------------------------------------------------------------------------
/wrappers.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 | from viewer import OpenCVImageViewer
4 | 
5 | 
6 | class GymWrapper(object):
7 |     """
8 |     Gym interface wrapper for a dm_control env wrapped by pixels.Wrapper
9 |     """
10 |     metadata = {'render.modes': ['human', 'rgb_array']}
11 |     reward_range = (-np.inf, np.inf)
12 | 
13 |     def __init__(self, env):
14 |         self._env = env
15 |         self._viewer = None
16 | 
17 |     def __getattr__(self, name):
18 |         return getattr(self._env, name)
19 | 
20 |     @property
21 |     def observation_space(self):
22 |         obs_spec = self._env.observation_spec()
23 |         return gym.spaces.Box(0, 255, obs_spec['pixels'].shape, dtype=np.uint8)
24 | 
25 |     @property
26 |     def action_space(self):
27 |         action_spec = self._env.action_spec()
28 |         return gym.spaces.Box(action_spec.minimum, action_spec.maximum, dtype=np.float32)
29 | 
30 |     def step(self, action):
31 |         time_step = self._env.step(action)
32 |         obs = time_step.observation['pixels']
33 |         reward = time_step.reward or 0
34 |         done = time_step.last()
35 |         info = {'discount': time_step.discount}
36 |         return obs, reward, done, info
37 | 
38 |     def reset(self):
39 |         time_step = self._env.reset()
40 |         obs = time_step.observation['pixels']
41 |         return obs
42 | 
43 |     def render(self, mode='human', **kwargs):
44 |         if not kwargs:
45 |             kwargs = self._env._render_kwargs
46 | 
47 |         img = self._env.physics.render(**kwargs)
48 |         if mode == 'rgb_array':
49 |             return img
50 |         elif mode == 'human':
51 |             if self._viewer is None:
52 |                 self._viewer = OpenCVImageViewer()
53 |             self._viewer.imshow(img)
54 |             return self._viewer.isopen
55 |         else:
56 |             raise NotImplementedError
57 | 
58 | 
59 | class RepeatAction(gym.Wrapper):
60 |     """
61 |     Action repeat wrapper that repeats the same action for `skip` steps
62 |     """
63 |     def __init__(self, env, skip=4):
64 |         gym.Wrapper.__init__(self, env)
65 |         self._skip = skip
66 | 
67 |     def reset(self):
68 |         return self.env.reset()
69 | 
70 |     def step(self, action):
71 |         total_reward = 0.0
72 |         for _ in range(self._skip):
73 |             obs, reward, done, info = self.env.step(action)
74 |             total_reward += reward
75 |             if done:
76 |                 break
77 |         return obs, total_reward, done, info
78 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import torch
5 | from dm_control import suite
6 | from dm_control.suite.wrappers import pixels
7 | from agent import CEMAgent
8 | from model import Encoder, RecurrentStateSpaceModel, RewardModel
9 | from wrappers import GymWrapper, RepeatAction
10 | 
11 | 
12 | def main():
13 |     parser = argparse.ArgumentParser(description='Test learned model')
14 |     parser.add_argument('dir', type=str, help='log directory to load learned model')
15 |     parser.add_argument('--render', action='store_true')
16 |     parser.add_argument('--domain-name', type=str, default='cheetah')
17 |     parser.add_argument('--task-name', type=str, default='run')
18 |     parser.add_argument('-R', '--action-repeat', type=int, default=4)
19 |     parser.add_argument('--episodes', type=int, default=1)
20 |     parser.add_argument('-H', '--horizon', type=int, default=12)
21 |     parser.add_argument('-I', '--N-iterations', type=int, default=10)
22 |     parser.add_argument('-J', '--N-candidates', type=int, default=1000)
23 |     parser.add_argument('-K', '--N-top-candidates', type=int, default=100)
24 |     args = parser.parse_args()
25 | 
26 |     # define environment and apply wrapper
27 |     env = suite.load(args.domain_name, args.task_name)
28 |     env = pixels.Wrapper(env, render_kwargs={'height': 64,
29 |                                              'width': 64,
30 |                                              'camera_id': 0})
31 |     env = GymWrapper(env)
32 |     env = RepeatAction(env, skip=args.action_repeat)
33 | 
34 |     # define models
35 |     with open(os.path.join(args.dir, 'args.json'), 'r') as f:
36 |         train_args = json.load(f)
37 | 
38 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39 |     encoder = Encoder().to(device)
40 |     rssm = RecurrentStateSpaceModel(train_args['state_dim'],
41 |                                     env.action_space.shape[0],
42 |                                     train_args['rnn_hidden_dim']).to(device)
43 |     reward_model = RewardModel(train_args['state_dim'],
44 |                                train_args['rnn_hidden_dim']).to(device)
45 | 
46 |     # load learned parameters
47 |     encoder.load_state_dict(torch.load(os.path.join(args.dir, 'encoder.pth')))
48 |     rssm.load_state_dict(torch.load(os.path.join(args.dir, 'rssm.pth')))
49 |     reward_model.load_state_dict(torch.load(os.path.join(args.dir, 'reward_model.pth')))
50 | 
51 |     # define agent
52 |     cem_agent = CEMAgent(encoder, rssm, reward_model,
53 |                          args.horizon, args.N_iterations,
54 |                          args.N_candidates, args.N_top_candidates)
55 | 
56 |     # test learned model in the environment
57 |     for episode in range(args.episodes):
58 |         cem_agent.reset()
59 |         obs = env.reset()
60 |         done = False
61 |         total_reward = 0
62 |         while not done:
63 |             action = cem_agent(obs)
64 |             obs, reward, done, _ = env.step(action)
65 |             total_reward += reward
66 |             if args.render:
67 |                 env.render(height=256, width=256, camera_id=0)
68 | 
69 |         print('Total test reward at episode [%4d/%4d] is %f' %
70 |               (episode+1, args.episodes, total_reward))
71 | 
72 | 
73 | if __name__ == '__main__':
74 |     main()
75 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | class ReplayBuffer(object):
5 |     """
6 |     Replay buffer for training with RNN
7 |     """
8 |     def __init__(self, capacity, observation_shape, action_dim):
9 |         self.capacity = capacity
10 | 
11 |         self.observations = np.zeros((capacity, *observation_shape), dtype=np.uint8)
12 |         self.actions = np.zeros((capacity, action_dim), dtype=np.float32)
13 |         self.rewards = np.zeros((capacity, 1), dtype=np.float32)
14 |         self.done = np.zeros((capacity, 1), dtype=bool)
15 | 
16 |         self.index = 0
17 |         self.is_filled = False
18 | 
19 |     def push(self, observation, action, reward, done):
20 |         """
21 |         Add experience to replay buffer
22 |         NOTE: observation should be transformed to
np.uint8 before push 23 | """ 24 | self.observations[self.index] = observation 25 | self.actions[self.index] = action 26 | self.rewards[self.index] = reward 27 | self.done[self.index] = done 28 | 29 | if self.index == self.capacity - 1: 30 | self.is_filled = True 31 | self.index = (self.index + 1) % self.capacity 32 | 33 | def sample(self, batch_size, chunk_length): 34 | """ 35 | Sample experiences from replay buffer (almost) uniformly 36 | The resulting array will be of the form (batch_size, chunk_length) 37 | and each batch is consecutive sequence 38 | NOTE: too large chunk_length for the length of episode will cause problems 39 | """ 40 | episode_borders = np.where(self.done)[0] 41 | sampled_indexes = [] 42 | for _ in range(batch_size): 43 | cross_border = True 44 | while cross_border: 45 | initial_index = np.random.randint(len(self) - chunk_length + 1) 46 | final_index = initial_index + chunk_length - 1 47 | cross_border = np.logical_and(initial_index <= episode_borders, 48 | episode_borders < final_index).any() 49 | sampled_indexes += list(range(initial_index, final_index + 1)) 50 | 51 | sampled_observations = self.observations[sampled_indexes].reshape( 52 | batch_size, chunk_length, *self.observations.shape[1:]) 53 | sampled_actions = self.actions[sampled_indexes].reshape( 54 | batch_size, chunk_length, self.actions.shape[1]) 55 | sampled_rewards = self.rewards[sampled_indexes].reshape( 56 | batch_size, chunk_length, 1) 57 | sampled_done = self.done[sampled_indexes].reshape( 58 | batch_size, chunk_length, 1) 59 | return sampled_observations, sampled_actions, sampled_rewards, sampled_done 60 | 61 | def __len__(self): 62 | return self.capacity if self.is_filled else self.index 63 | 64 | 65 | def preprocess_obs(obs, bit_depth=5): 66 | """ 67 | Reduces the bit depth of image for the ease of training 68 | and convert to [-0.5, 0.5] 69 | In addition, add uniform random noise same as original implementation 70 | """ 71 | obs = obs.astype(np.float32) 72 | reduced_obs = np.floor(obs / 2 ** (8 - bit_depth)) 73 | normalized_obs = reduced_obs / 2**bit_depth - 0.5 74 | normalized_obs += np.random.uniform(0.0, 1.0 / 2**bit_depth, normalized_obs.shape) 75 | return normalized_obs 76 | -------------------------------------------------------------------------------- /agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributions import Normal 3 | from utils import preprocess_obs 4 | 5 | 6 | class CEMAgent: 7 | """ 8 | Action planning by Cross Entropy Method (CEM) in learned RSSM Model 9 | """ 10 | def __init__(self, encoder, rssm, reward_model, 11 | horizon, N_iterations, N_candidates, N_top_candidates): 12 | self.encoder = encoder 13 | self.rssm = rssm 14 | self.reward_model = reward_model 15 | 16 | self.horizon = horizon 17 | self.N_iterations = N_iterations 18 | self.N_candidates = N_candidates 19 | self.N_top_candidates = N_top_candidates 20 | 21 | self.device = next(self.reward_model.parameters()).device 22 | self.rnn_hidden = torch.zeros(1, rssm.rnn_hidden_dim, device=self.device) 23 | 24 | def __call__(self, obs): 25 | # Preprocess observation and transpose for torch style (channel-first) 26 | obs = preprocess_obs(obs) 27 | obs = torch.as_tensor(obs, device=self.device) 28 | obs = obs.transpose(1, 2).transpose(0, 1).unsqueeze(0) 29 | 30 | with torch.no_grad(): 31 | # Compute starting state for planning 32 | # while taking information from current observation (posterior) 33 | embedded_obs = self.encoder(obs) 34 | 
state_posterior = self.rssm.posterior(self.rnn_hidden, embedded_obs) 35 | 36 | # Initialize action distribution 37 | action_dist = Normal( 38 | torch.zeros((self.horizon, self.rssm.action_dim), device=self.device), 39 | torch.ones((self.horizon, self.rssm.action_dim), device=self.device) 40 | ) 41 | 42 | # Iteratively improve action distribution with CEM 43 | for itr in range(self.N_iterations): 44 | # Sample action candidates and transpose to 45 | # (self.horizon, self.N_candidates, action_dim) for parallel exploration 46 | action_candidates = \ 47 | action_dist.sample([self.N_candidates]).transpose(0, 1) 48 | 49 | # Initialize reward, state, and rnn hidden state 50 | # The size of state is (self.N_acndidates, state_dim) 51 | # The size of rnn hidden is (self.N_candidates, rnn_hidden_dim) 52 | # These are for parallel exploration 53 | total_predicted_reward = torch.zeros(self.N_candidates, device=self.device) 54 | state = state_posterior.sample([self.N_candidates]).squeeze() 55 | rnn_hidden = self.rnn_hidden.repeat([self.N_candidates, 1]) 56 | 57 | # Compute total predicted reward by open-loop prediction using prior 58 | for t in range(self.horizon): 59 | next_state_prior, rnn_hidden = \ 60 | self.rssm.prior(state, action_candidates[t], rnn_hidden) 61 | state = next_state_prior.sample() 62 | total_predicted_reward += self.reward_model(state, rnn_hidden).squeeze() 63 | 64 | # update action distribution using top-k samples 65 | top_indexes = \ 66 | total_predicted_reward.argsort(descending=True)[: self.N_top_candidates] 67 | top_action_candidates = action_candidates[:, top_indexes, :] 68 | mean = top_action_candidates.mean(dim=1) 69 | stddev = (top_action_candidates - mean.unsqueeze(1) 70 | ).abs().sum(dim=1) / (self.N_top_candidates - 1) 71 | action_dist = Normal(mean, stddev) 72 | 73 | # Return only first action (replan each state based on new observation) 74 | action = mean[0] 75 | 76 | # update rnn hidden state for next step planning 77 | with torch.no_grad(): 78 | _, self.rnn_hidden = self.rssm.prior(state_posterior.sample(), 79 | action.unsqueeze(0), 80 | self.rnn_hidden) 81 | return action.cpu().numpy() 82 | 83 | def reset(self): 84 | self.rnn_hidden = torch.zeros(1, self.rssm.rnn_hidden_dim, device=self.device) 85 | -------------------------------------------------------------------------------- /video_prediction.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import matplotlib.pyplot as plt 5 | from matplotlib import animation 6 | import numpy as np 7 | import torch 8 | from dm_control import suite 9 | from dm_control.suite.wrappers import pixels 10 | from agent import CEMAgent 11 | from model import Encoder, RecurrentStateSpaceModel, ObservationModel, RewardModel 12 | from utils import preprocess_obs 13 | from wrappers import GymWrapper, RepeatAction 14 | 15 | 16 | def save_video_as_gif(frames): 17 | """ 18 | make video with given frames and save as "video_prediction.gif" 19 | """ 20 | plt.figure() 21 | patch = plt.imshow(frames[0]) 22 | plt.axis('off') 23 | 24 | def animate(i): 25 | patch.set_data(frames[i]) 26 | plt.title('Left: GT frame' + ' '*20 + 'Right: predicted frame \n Step %d' % (i)) 27 | 28 | anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=150) 29 | anim.save('video_prediction.gif', writer='imagemagick') 30 | 31 | 32 | def main(): 33 | parser = argparse.ArgumentParser(description='Open-loop video prediction with learned model') 34 | 
parser.add_argument('dir', type=str, help='log directory to load learned model') 35 | parser.add_argument('--length', type=int, default=50, 36 | help='the length of video prediction') 37 | parser.add_argument('--domain-name', type=str, default='cheetah') 38 | parser.add_argument('--task-name', type=str, default='run') 39 | parser.add_argument('-R', '--action-repeat', type=int, default=4) 40 | parser.add_argument('-H', '--horizon', type=int, default=12) 41 | parser.add_argument('-I', '--N-iterations', type=int, default=10) 42 | parser.add_argument('-J', '--N-candidates', type=int, default=1000) 43 | parser.add_argument('-K', '--N-top-candidates', type=int, default=100) 44 | args = parser.parse_args() 45 | 46 | # define environment and apply wrapper 47 | env = suite.load(args.domain_name, args.task_name) 48 | env = pixels.Wrapper(env, render_kwargs={'height': 64, 49 | 'width': 64, 50 | 'camera_id': 0}) 51 | env = GymWrapper(env) 52 | env = RepeatAction(env, skip=args.action_repeat) 53 | 54 | # define models 55 | with open(os.path.join(args.dir, 'args.json'), 'r') as f: 56 | train_args = json.load(f) 57 | 58 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 59 | encoder = Encoder().to(device) 60 | rssm = RecurrentStateSpaceModel(train_args['state_dim'], 61 | env.action_space.shape[0], 62 | train_args['rnn_hidden_dim']).to(device) 63 | obs_model = ObservationModel(train_args['state_dim'], 64 | train_args['rnn_hidden_dim']).to(device) 65 | reward_model = RewardModel(train_args['state_dim'], 66 | train_args['rnn_hidden_dim']).to(device) 67 | 68 | # load learned parameters 69 | encoder.load_state_dict(torch.load(os.path.join(args.dir, 'encoder.pth'))) 70 | rssm.load_state_dict(torch.load(os.path.join(args.dir, 'rssm.pth'))) 71 | obs_model.load_state_dict(torch.load(os.path.join(args.dir, 'obs_model.pth'))) 72 | reward_model.load_state_dict(torch.load(os.path.join(args.dir, 'reward_model.pth'))) 73 | 74 | # define agent 75 | cem_agent = CEMAgent(encoder, rssm, reward_model, 76 | args.horizon, args.N_iterations, 77 | args.N_candidates, args.N_top_candidates) 78 | 79 | # open-loop video prediction 80 | # select starting point of open-loop prediction randomly 81 | starting_point = torch.randint(1000 // args.action_repeat - args.length, (1,)).item() 82 | # interact in environment until starting point and charge context in cem_agent.rnn_hidden 83 | obs = env.reset() 84 | for _ in range(starting_point): 85 | action = cem_agent(obs) 86 | obs, _, _, _ = env.step(action) 87 | 88 | # preprocess observatin and embed by encoder 89 | preprocessed_obs = preprocess_obs(obs) 90 | preprocessed_obs = torch.as_tensor(preprocessed_obs, device=device) 91 | preprocessed_obs = preprocessed_obs.transpose(1, 2).transpose(0, 1).unsqueeze(0) 92 | with torch.no_grad(): 93 | embedded_obs = encoder(preprocessed_obs) 94 | 95 | # compute state using embedded observation 96 | # NOTE: after this, state is updated only using prior, 97 | # it means model doesn't see observation 98 | rnn_hidden = cem_agent.rnn_hidden 99 | state = rssm.posterior(rnn_hidden, embedded_obs).sample() 100 | frame = np.zeros((64, 128, 3)) 101 | frames = [] 102 | for _ in range(args.length): 103 | # action is selected same as training time (closed-loop) 104 | action = cem_agent(obs) 105 | obs, _, _, _ = env.step(action) 106 | 107 | # update state and reconstruct observation with same action 108 | action = torch.as_tensor(action, device=device).unsqueeze(0) 109 | with torch.no_grad(): 110 | state_prior, rnn_hidden = rssm.prior(state, 
action, rnn_hidden) 111 | state = state_prior.sample() 112 | predicted_obs = obs_model(state, rnn_hidden) 113 | 114 | # arrange GT frame and predicted frame in parallel 115 | frame[:, :64, :] = preprocess_obs(obs) 116 | frame[:, 64:, :] = predicted_obs.squeeze().transpose(0, 1).transpose(1, 2).cpu().numpy() 117 | frames.append((frame + 0.5).clip(0.0, 1.0)) 118 | 119 | save_video_as_gif(frames) 120 | 121 | if __name__ == '__main__': 122 | main() 123 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from torch.distributions import Normal 5 | 6 | 7 | class Encoder(nn.Module): 8 | """ 9 | Encoder to embed image observation (3, 64, 64) to vector (1024,) 10 | """ 11 | def __init__(self): 12 | super(Encoder, self).__init__() 13 | self.cv1 = nn.Conv2d(3, 32, kernel_size=4, stride=2) 14 | self.cv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2) 15 | self.cv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2) 16 | self.cv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2) 17 | 18 | def forward(self, obs): 19 | hidden = F.relu(self.cv1(obs)) 20 | hidden = F.relu(self.cv2(hidden)) 21 | hidden = F.relu(self.cv3(hidden)) 22 | embedded_obs = F.relu(self.cv4(hidden)).reshape(hidden.size(0), -1) 23 | return embedded_obs 24 | 25 | 26 | class RecurrentStateSpaceModel(nn.Module): 27 | """ 28 | This class includes multiple components 29 | Deterministic state model: h_t+1 = f(h_t, s_t, a_t) 30 | Stochastic state model (prior): p(s_t+1 | h_t+1) 31 | State posterior: q(s_t | h_t, o_t) 32 | NOTE: actually, this class takes embedded observation by Encoder class 33 | min_stddev is added to stddev same as original implementation 34 | Activation function for this class is F.relu same as original implementation 35 | """ 36 | def __init__(self, state_dim, action_dim, rnn_hidden_dim, 37 | hidden_dim=200, min_stddev=0.1, act=F.relu): 38 | super(RecurrentStateSpaceModel, self).__init__() 39 | self.state_dim = state_dim 40 | self.action_dim = action_dim 41 | self.rnn_hidden_dim = rnn_hidden_dim 42 | self.fc_state_action = nn.Linear(state_dim + action_dim, hidden_dim) 43 | self.fc_rnn_hidden = nn.Linear(rnn_hidden_dim, hidden_dim) 44 | self.fc_state_mean_prior = nn.Linear(hidden_dim, state_dim) 45 | self.fc_state_stddev_prior = nn.Linear(hidden_dim, state_dim) 46 | self.fc_rnn_hidden_embedded_obs = nn.Linear(rnn_hidden_dim + 1024, hidden_dim) 47 | self.fc_state_mean_posterior = nn.Linear(hidden_dim, state_dim) 48 | self.fc_state_stddev_posterior = nn.Linear(hidden_dim, state_dim) 49 | self.rnn = nn.GRUCell(hidden_dim, rnn_hidden_dim) 50 | self._min_stddev = min_stddev 51 | self.act = act 52 | 53 | def forward(self, state, action, rnn_hidden, embedded_next_obs): 54 | """ 55 | h_t+1 = f(h_t, s_t, a_t) 56 | Return prior p(s_t+1 | h_t+1) and posterior p(s_t+1 | h_t+1, o_t+1) 57 | for model training 58 | """ 59 | next_state_prior, rnn_hidden = self.prior(state, action, rnn_hidden) 60 | next_state_posterior = self.posterior(rnn_hidden, embedded_next_obs) 61 | return next_state_prior, next_state_posterior, rnn_hidden 62 | 63 | def prior(self, state, action, rnn_hidden): 64 | """ 65 | h_t+1 = f(h_t, s_t, a_t) 66 | Compute prior p(s_t+1 | h_t+1) 67 | """ 68 | hidden = self.act(self.fc_state_action(torch.cat([state, action], dim=1))) 69 | rnn_hidden = self.rnn(hidden, rnn_hidden) 70 | hidden = self.act(self.fc_rnn_hidden(rnn_hidden)) 71 
| 72 | mean = self.fc_state_mean_prior(hidden) 73 | stddev = F.softplus(self.fc_state_stddev_prior(hidden)) + self._min_stddev 74 | return Normal(mean, stddev), rnn_hidden 75 | 76 | def posterior(self, rnn_hidden, embedded_obs): 77 | """ 78 | Compute posterior q(s_t | h_t, o_t) 79 | """ 80 | hidden = self.act(self.fc_rnn_hidden_embedded_obs( 81 | torch.cat([rnn_hidden, embedded_obs], dim=1))) 82 | mean = self.fc_state_mean_posterior(hidden) 83 | stddev = F.softplus(self.fc_state_stddev_posterior(hidden)) + self._min_stddev 84 | return Normal(mean, stddev) 85 | 86 | 87 | class ObservationModel(nn.Module): 88 | """ 89 | p(o_t | s_t, h_t) 90 | Observation model to reconstruct image observation (3, 64, 64) 91 | from state and rnn hidden state 92 | """ 93 | def __init__(self, state_dim, rnn_hidden_dim): 94 | super(ObservationModel, self).__init__() 95 | self.fc = nn.Linear(state_dim + rnn_hidden_dim, 1024) 96 | self.dc1 = nn.ConvTranspose2d(1024, 128, kernel_size=5, stride=2) 97 | self.dc2 = nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2) 98 | self.dc3 = nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2) 99 | self.dc4 = nn.ConvTranspose2d(32, 3, kernel_size=6, stride=2) 100 | 101 | def forward(self, state, rnn_hidden): 102 | hidden = self.fc(torch.cat([state, rnn_hidden], dim=1)) 103 | hidden = hidden.view(hidden.size(0), 1024, 1, 1) 104 | hidden = F.relu(self.dc1(hidden)) 105 | hidden = F.relu(self.dc2(hidden)) 106 | hidden = F.relu(self.dc3(hidden)) 107 | obs = self.dc4(hidden) 108 | return obs 109 | 110 | 111 | class RewardModel(nn.Module): 112 | """ 113 | p(r_t | s_t, h_t) 114 | Reward model to predict reward from state and rnn hidden state 115 | """ 116 | def __init__(self, state_dim, rnn_hidden_dim, hidden_dim=300, act=F.relu): 117 | super(RewardModel, self).__init__() 118 | self.fc1 = nn.Linear(state_dim + rnn_hidden_dim, hidden_dim) 119 | self.fc2 = nn.Linear(hidden_dim, hidden_dim) 120 | self.fc3 = nn.Linear(hidden_dim, hidden_dim) 121 | self.fc4 = nn.Linear(hidden_dim, 1) 122 | self.act = act 123 | 124 | def forward(self, state, rnn_hidden): 125 | hidden = self.act(self.fc1(torch.cat([state, rnn_hidden], dim=1))) 126 | hidden = self.act(self.fc2(hidden)) 127 | hidden = self.act(self.fc3(hidden)) 128 | reward = self.fc4(hidden) 129 | return reward 130 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datetime import datetime 3 | import json 4 | import os 5 | from pprint import pprint 6 | import time 7 | import numpy as np 8 | import torch 9 | from torch.distributions.kl import kl_divergence 10 | from torch.nn.functional import mse_loss 11 | from torch.nn.utils import clip_grad_norm_ 12 | from torch.optim import Adam 13 | from torch.utils.tensorboard import SummaryWriter 14 | from dm_control import suite 15 | from dm_control.suite.wrappers import pixels 16 | from agent import CEMAgent 17 | from model import Encoder, RecurrentStateSpaceModel, ObservationModel, RewardModel 18 | from utils import ReplayBuffer, preprocess_obs 19 | from wrappers import GymWrapper, RepeatAction 20 | 21 | 22 | def main(): 23 | parser = argparse.ArgumentParser(description='PlaNet for DM control') 24 | parser.add_argument('--seed', type=int, default=0) 25 | parser.add_argument('--log-dir', type=str, default='log') 26 | parser.add_argument('--test-interval', type=int, default=10) 27 | parser.add_argument('--domain-name', type=str, default='cheetah') 28 | 
parser.add_argument('--task-name', type=str, default='run') 29 | parser.add_argument('-R', '--action-repeat', type=int, default=4) 30 | parser.add_argument('--state-dim', type=int, default=30) 31 | parser.add_argument('--rnn-hidden-dim', type=int, default=200) 32 | parser.add_argument('--buffer-capacity', type=int, default=1000000) 33 | parser.add_argument('--all-episodes', type=int, default=1000) 34 | parser.add_argument('-S', '--seed-episodes', type=int, default=5) 35 | parser.add_argument('-C', '--collect-interval', type=int, default=100) 36 | parser.add_argument('-B', '--batch-size', type=int, default=50) 37 | parser.add_argument('-L', '--chunk-length', type=int, default=50) 38 | parser.add_argument('--lr', type=float, default=1e-3) 39 | parser.add_argument('--eps', type=float, default=1e-4) 40 | parser.add_argument('--clip-grad-norm', type=int, default=1000) 41 | parser.add_argument('--free-nats', type=int, default=3) 42 | parser.add_argument('-H', '--horizon', type=int, default=12) 43 | parser.add_argument('-I', '--N-iterations', type=int, default=10) 44 | parser.add_argument('-J', '--N-candidates', type=int, default=1000) 45 | parser.add_argument('-K', '--N-top-candidates', type=int, default=100) 46 | parser.add_argument('--action-noise-var', type=float, default=0.3) 47 | args = parser.parse_args() 48 | 49 | # Prepare logging 50 | log_dir = os.path.join(args.log_dir, args.domain_name + '_' + args.task_name) 51 | log_dir = os.path.join(log_dir, datetime.now().strftime('%Y%m%d_%H%M')) 52 | os.makedirs(log_dir) 53 | with open(os.path.join(log_dir, 'args.json'), 'w') as f: 54 | json.dump(vars(args), f) 55 | pprint(vars(args)) 56 | writer = SummaryWriter(log_dir=log_dir) 57 | 58 | # set seed (NOTE: some randomness is still remaining (e.g. cuDNN's randomness)) 59 | np.random.seed(args.seed) 60 | torch.manual_seed(args.seed) 61 | if torch.cuda.is_available(): 62 | torch.cuda.manual_seed(args.seed) 63 | 64 | # define env and apply wrappers 65 | env = suite.load(args.domain_name, args.task_name, task_kwargs={'random': args.seed}) 66 | env = pixels.Wrapper(env, render_kwargs={'height': 64, 67 | 'width': 64, 68 | 'camera_id': 0}) 69 | env = GymWrapper(env) 70 | env = RepeatAction(env, skip=args.action_repeat) 71 | 72 | # define replay buffer 73 | replay_buffer = ReplayBuffer(capacity=args.buffer_capacity, 74 | observation_shape=env.observation_space.shape, 75 | action_dim=env.action_space.shape[0]) 76 | 77 | # define models and optimizer 78 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 79 | encoder = Encoder().to(device) 80 | rssm = RecurrentStateSpaceModel(args.state_dim, 81 | env.action_space.shape[0], 82 | args.rnn_hidden_dim).to(device) 83 | obs_model = ObservationModel(args.state_dim, args.rnn_hidden_dim).to(device) 84 | reward_model = RewardModel(args.state_dim, args.rnn_hidden_dim).to(device) 85 | all_params = (list(encoder.parameters()) + 86 | list(rssm.parameters()) + 87 | list(obs_model.parameters()) + 88 | list(reward_model.parameters())) 89 | optimizer = Adam(all_params, lr=args.lr, eps=args.eps) 90 | 91 | # collect initial experience with random action 92 | for episode in range(args.seed_episodes): 93 | obs = env.reset() 94 | done = False 95 | while not done: 96 | action = env.action_space.sample() 97 | next_obs, reward, done, _ = env.step(action) 98 | replay_buffer.push(obs, action, reward, done) 99 | obs = next_obs 100 | 101 | # main training loop 102 | for episode in range(args.seed_episodes, args.all_episodes): 103 | # collect experiences 104 | start 
= time.time() 105 | cem_agent = CEMAgent(encoder, rssm, reward_model, 106 | args.horizon, args.N_iterations, 107 | args.N_candidates, args.N_top_candidates) 108 | 109 | obs = env.reset() 110 | done = False 111 | total_reward = 0 112 | while not done: 113 | action = cem_agent(obs) 114 | action += np.random.normal(0, np.sqrt(args.action_noise_var), 115 | env.action_space.shape[0]) 116 | next_obs, reward, done, _ = env.step(action) 117 | replay_buffer.push(obs, action, reward, done) 118 | obs = next_obs 119 | total_reward += reward 120 | 121 | writer.add_scalar('total reward at train', total_reward, episode) 122 | print('episode [%4d/%4d] is collected. Total reward is %f' % 123 | (episode+1, args.all_episodes, total_reward)) 124 | print('elasped time for interaction: %.2fs' % (time.time() - start)) 125 | 126 | # update model parameters 127 | start = time.time() 128 | for update_step in range(args.collect_interval): 129 | observations, actions, rewards, _ = \ 130 | replay_buffer.sample(args.batch_size, args.chunk_length) 131 | 132 | # preprocess observations and transpose tensor for RNN training 133 | observations = preprocess_obs(observations) 134 | observations = torch.as_tensor(observations, device=device) 135 | observations = observations.transpose(3, 4).transpose(2, 3) 136 | observations = observations.transpose(0, 1) 137 | actions = torch.as_tensor(actions, device=device).transpose(0, 1) 138 | rewards = torch.as_tensor(rewards, device=device).transpose(0, 1) 139 | 140 | # embed observations with CNN 141 | embedded_observations = encoder( 142 | observations.reshape(-1, 3, 64, 64)).view(args.chunk_length, args.batch_size, -1) 143 | 144 | # prepare Tensor to maintain states sequence and rnn hidden states sequence 145 | states = torch.zeros( 146 | args.chunk_length, args.batch_size, args.state_dim, device=device) 147 | rnn_hiddens = torch.zeros( 148 | args.chunk_length, args.batch_size, args.rnn_hidden_dim, device=device) 149 | 150 | # initialize state and rnn hidden state with 0 vector 151 | state = torch.zeros(args.batch_size, args.state_dim, device=device) 152 | rnn_hidden = torch.zeros(args.batch_size, args.rnn_hidden_dim, device=device) 153 | 154 | # compute state and rnn hidden sequences and kl loss 155 | kl_loss = 0 156 | for l in range(args.chunk_length-1): 157 | next_state_prior, next_state_posterior, rnn_hidden = \ 158 | rssm(state, actions[l], rnn_hidden, embedded_observations[l+1]) 159 | state = next_state_posterior.rsample() 160 | states[l+1] = state 161 | rnn_hiddens[l+1] = rnn_hidden 162 | kl = kl_divergence(next_state_prior, next_state_posterior).sum(dim=1) 163 | kl_loss += kl.clamp(min=args.free_nats).mean() 164 | kl_loss /= (args.chunk_length - 1) 165 | 166 | # compute reconstructed observations and predicted rewards 167 | flatten_states = states.view(-1, args.state_dim) 168 | flatten_rnn_hiddens = rnn_hiddens.view(-1, args.rnn_hidden_dim) 169 | recon_observations = obs_model(flatten_states, flatten_rnn_hiddens).view( 170 | args.chunk_length, args.batch_size, 3, 64, 64) 171 | predicted_rewards = reward_model(flatten_states, flatten_rnn_hiddens).view( 172 | args.chunk_length, args.batch_size, 1) 173 | 174 | # compute loss for observation and reward 175 | obs_loss = 0.5 * mse_loss( 176 | recon_observations[1:], observations[1:], reduction='none').mean([0, 1]).sum() 177 | reward_loss = 0.5 * mse_loss(predicted_rewards[1:], rewards[:-1]) 178 | 179 | # add all losses and update model parameters with gradient descent 180 | loss = kl_loss + obs_loss + reward_loss 181 | 
optimizer.zero_grad() 182 | loss.backward() 183 | clip_grad_norm_(all_params, args.clip_grad_norm) 184 | optimizer.step() 185 | 186 | # print losses and add tensorboard 187 | print('update_step: %3d loss: %.5f, kl_loss: %.5f, obs_loss: %.5f, reward_loss: % .5f' 188 | % (update_step+1, 189 | loss.item(), kl_loss.item(), obs_loss.item(), reward_loss.item())) 190 | total_update_step = episode * args.collect_interval + update_step 191 | writer.add_scalar('overall loss', loss.item(), total_update_step) 192 | writer.add_scalar('kl loss', kl_loss.item(), total_update_step) 193 | writer.add_scalar('obs loss', obs_loss.item(), total_update_step) 194 | writer.add_scalar('reward loss', reward_loss.item(), total_update_step) 195 | 196 | print('elasped time for update: %.2fs' % (time.time() - start)) 197 | 198 | # test to get score without exploration noise 199 | if (episode + 1) % args.test_interval == 0: 200 | start = time.time() 201 | cem_agent = CEMAgent(encoder, rssm, reward_model, 202 | args.horizon, args.N_iterations, 203 | args.N_candidates, args.N_top_candidates) 204 | obs = env.reset() 205 | done = False 206 | total_reward = 0 207 | while not done: 208 | action = cem_agent(obs) 209 | obs, reward, done, _ = env.step(action) 210 | total_reward += reward 211 | 212 | writer.add_scalar('total reward at test', total_reward, episode) 213 | print('Total test reward at episode [%4d/%4d] is %f' % 214 | (episode+1, args.all_episodes, total_reward)) 215 | print('elasped time for test: %.2fs' % (time.time() - start)) 216 | 217 | # save learned model parameters 218 | torch.save(encoder.state_dict(), os.path.join(log_dir, 'encoder.pth')) 219 | torch.save(rssm.state_dict(), os.path.join(log_dir, 'rssm.pth')) 220 | torch.save(obs_model.state_dict(), os.path.join(log_dir, 'obs_model.pth')) 221 | torch.save(reward_model.state_dict(), os.path.join(log_dir, 'reward_model.pth')) 222 | writer.close() 223 | 224 | if __name__ == '__main__': 225 | main() 226 | --------------------------------------------------------------------------------