├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── data
│   ├── sample_planar.py
│   └── sample_pole.py
├── datasets.py
├── ilqr.py
├── ilqr_config
│   ├── balance.json
│   ├── cartpole.json
│   ├── plane.json
│   ├── swing.json
│   └── threepole.json
├── ilqr_utils.py
├── latent_map_pendulum.py
├── latent_map_planar.py
├── losses.py
├── mdp
│   ├── cartpole_mdp.py
│   ├── common.py
│   ├── pendulum_mdp.py
│   ├── plane_obstacles_mdp.py
│   ├── pole_base.py
│   └── three_pole_mdp.py
├── networks.py
├── pcc.yml
├── pcc_model.py
├── sample_results
│   ├── cartpole_1.gif
│   ├── cartpole_2.gif
│   ├── latent_map_pend.png
│   ├── latent_map_sample.png
│   ├── pendulum_1.gif
│   ├── pendulum_2.gif
│   ├── planar_1.gif
│   └── planar_2.gif
├── train_pcc.py
└── true_map.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea/
__pycache__
.vscode
iwae_result/
*.pyc
result
logs
planar
pendulum
.ipynb_checkpoints/
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 VinAI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Prediction, Consistency and Curvature

This is a PyTorch implementation of the paper "[Prediction, Consistency, Curvature: Representation Learning for Locally-Linear Control](https://arxiv.org/abs/1909.01506)". The work was done during the residency at [VinAI Research](https://vinai.io), Hanoi, Vietnam.

## Installing

First, clone the repository:

```
git clone https://github.com/VinAIResearch/PCC-pytorch.git
```

Then create the conda environment (with all dependencies) from `pcc.yml` and activate it:

```
conda env create -f pcc.yml

conda activate pcc
```

## Training

The code currently supports training for the ``planar``, ``pendulum``, ``cartpole`` and ``threepole`` environments. Run ``train_pcc.py`` with your own settings.
For example:

```
python train_pcc.py \
    --env=planar \
    --armotized=False \
    --log_dir=planar_1 \
    --seed=1 \
    --data_size=5000 \
    --noise=0 \
    --batch_size=128 \
    --lam_p=1.0 \
    --lam_c=8.0 \
    --lam_cur=8.0 \
    --vae_coeff=0.01 \
    --determ_coeff=0.3 \
    --lr=0.0005 \
    --decay=0.001 \
    --num_iter=5000 \
    --iter_save=1000 \
    --save_map=True
```

First, data is sampled according to the given data size and noise level; the PCC model is then trained with the specified settings.

If the argument ``save_map`` is set to True, the latent map is drawn every 10 epochs (for planar only), and the resulting gif is saved in the same directory as the trained model.

You can also visualize the training process by running ``tensorboard --logdir={path_to_log_file}``, where ``path_to_log_file`` has the form ``logs/{env}/{log_dir}``. The trained model will be saved at ``result/{env}/{log_dir}``.

### Visualizing latent maps

You can visualize the latent map for both the planar and pendulum tasks. To do that, simply run:

```
python latent_map_planar.py --log_path={path_to_trained_model} --epoch={epoch}
or
python latent_map_pendulum.py --log_path={path_to_trained_model} --epoch={epoch}
```

## Sampling data

You can generate the training images for visualization by simply running:

```
cd data

python sample_planar.py --sample_size={sample_size} --noise={noise}
or
python sample_pole.py --env={env_name} --sample_size={sample_size} --noise={noise}
```

Currently the code supports simulating 4 environments: `planar`, `pendulum`, `cartpole` and `threepole` (`sample_planar.py` handles `planar`; `sample_pole.py` handles the other three).

The raw data (images) is saved in ``data/{env_name}/raw_{sample_size}_{noise}``.

## Running iLQR on latent space

The configuration file for running iLQR for each task is in the ``ilqr_config`` folder; you can modify it with your own settings. Run ``python ilqr.py --task={task} --setting_path={path_to_trained_models}``, where ``task`` is one of ``{plane, swing, balance, cartpole, threepole}``.

The code will run iLQR for all models trained for that specific task and compute some statistics. The result is saved in ``iLQR_result``.

## Result
We evaluate the PCC model in 2 ways: the quality of the latent map, and the percentage of time the agent spends in the goal region.
### Planar system

#### Latent map
Below is a sample latent map produced by PCC. You can watch a video clip comparing how the latent maps produced by E2C and PCC evolve at this link: https://www.youtube.com/watch?v=pBmzFvvE2bo.

![Latent space learned by PCC](sample_results/latent_map_sample.png)

#### Control result
The agent spends around 48% of the time in the goal region on average, and around 76% for the best model. Below are 2 sample trajectories of the agent.

![Sample planar trajectory 1](sample_results/planar_1.gif)

![Sample planar trajectory 2](sample_results/planar_2.gif)

### Inverted pendulum

#### Latent map

Below is a sample latent map produced by PCC.

![Latent space learned by PCC](sample_results/latent_map_pend.png)

#### Control result

The agent spends around 60.7% of the time in the goal region on average, and around 80.65% for the best model. Below are 2 sample trajectories of the inverted pendulum.
109 | 110 | ![Sample inverted pendulum trajectory 1](sample_results/pendulum_1.gif) 111 | 112 | ![Sample inverted pendulum trajectory 2](sample_results/pendulum_2.gif) 113 | ### Cartpole 114 | 115 | ![Sample cartpole trajectory 1](sample_results/cartpole_1.gif) 116 | 117 | ![Sample cartpole trajectory 2](sample_results/cartpole_2.gif) 118 | 119 | ### Acknowledgment 120 | 121 | Many thanks to Nir Levine and Yinlam Chow for their help in answering the questions related to the PCC paper. 122 | 123 | ### Citation 124 | 125 | If you find this implementation useful for your work, please consider starring this repository. 126 | -------------------------------------------------------------------------------- /data/sample_planar.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from datetime import datetime 5 | from os import path 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | from mdp.plane_obstacles_mdp import PlanarObstaclesMDP 10 | from PIL import Image 11 | from tqdm import trange 12 | 13 | 14 | root_path = str(Path(os.path.dirname(os.path.abspath(__file__))).parent) 15 | np.random.seed(1) 16 | 17 | 18 | def get_all_pos(mdp): 19 | start = 0 20 | end = mdp.width 21 | state_list = [] 22 | for x in range(start, end): 23 | for y in range(start, end): 24 | state = np.array([x, y]) 25 | if mdp.is_valid_state(state): 26 | state_list.append(state) 27 | return state_list 28 | 29 | 30 | def sample(sample_size=5000, noise=0.0): 31 | """ 32 | return [(s, u, s_next)] 33 | """ 34 | mdp = PlanarObstaclesMDP(noise=noise) 35 | 36 | # place holder for data 37 | x_data = np.zeros((sample_size, mdp.width, mdp.height), dtype="float32") 38 | u_data = np.zeros((sample_size, mdp.action_dim), dtype="float32") 39 | x_next_data = np.zeros((sample_size, mdp.width, mdp.height), dtype="float32") 40 | state_data = np.zeros((sample_size, 2), dtype="float32") 41 | state_next_data = np.zeros((sample_size, 2), dtype="float32") 42 | 43 | # get all possible states (discretized on integer grid) 44 | state_list = get_all_pos(mdp) 45 | 46 | for j in trange(sample_size, desc="Sampling remaining data"): 47 | state_data[j] = state_list[j % len(state_list)] 48 | x_data[j] = mdp.render(state_data[j]) 49 | u_data[j] = mdp.sample_valid_random_action(state_data[j]) 50 | state_next_data[j] = mdp.transition_function(state_data[j], u_data[j]) 51 | x_next_data[j] = mdp.render(state_next_data[j]) 52 | return x_data, u_data, x_next_data, state_data, state_next_data 53 | 54 | 55 | def write_to_file(noise, sample_size): 56 | """ 57 | write [(x, u, x_next)] to output dir 58 | """ 59 | output_dir = root_path + "/data/planar/raw_{:d}_{:.0f}".format(sample_size, noise) 60 | if not path.exists(output_dir): 61 | os.makedirs(output_dir) 62 | 63 | x_data, u_data, x_next_data, state_data, state_next_data = sample(sample_size, noise) 64 | 65 | samples = [] 66 | 67 | for i, _ in enumerate(x_data): 68 | before_file = "before-{:05d}.png".format(i) 69 | Image.fromarray(x_data[i] * 255.0).convert("L").save(path.join(output_dir, before_file)) 70 | 71 | after_file = "after-{:05d}.png".format(i) 72 | Image.fromarray(x_next_data[i] * 255.0).convert("L").save(path.join(output_dir, after_file)) 73 | 74 | initial_state = state_data[i] 75 | after_state = state_next_data[i] 76 | u = u_data[i] 77 | 78 | samples.append( 79 | { 80 | "before_state": initial_state.tolist(), 81 | "after_state": after_state.tolist(), 82 | "before": before_file, 83 | "after": after_file, 84 | 
"control": u.tolist(), 85 | } 86 | ) 87 | 88 | with open(path.join(output_dir, "data.json"), "wt") as outfile: 89 | json.dump( 90 | { 91 | "metadata": {"num_samples": sample_size, "time_created": str(datetime.now()), "version": 1}, 92 | "samples": samples, 93 | }, 94 | outfile, 95 | indent=2, 96 | ) 97 | 98 | 99 | def main(args): 100 | sample_size = args.sample_size 101 | noise = args.noise 102 | write_to_file(noise, sample_size) 103 | 104 | 105 | if __name__ == "__main__": 106 | parser = argparse.ArgumentParser(description="sample planar data") 107 | 108 | parser.add_argument("--sample_size", required=True, type=int, help="the number of samples") 109 | parser.add_argument("--noise", default=0, type=int, help="level of noise") 110 | 111 | args = parser.parse_args() 112 | 113 | main(args) 114 | -------------------------------------------------------------------------------- /data/sample_pole.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import argparse 4 | import json 5 | import os 6 | import os.path as path 7 | from datetime import datetime 8 | from pathlib import Path 9 | 10 | import numpy as np 11 | from mdp.cartpole_mdp import CartPoleMDP 12 | from mdp.pendulum_mdp import PendulumMDP 13 | from mdp.three_pole_mdp import ThreePoleMDP 14 | from PIL import Image 15 | from tqdm import trange 16 | 17 | 18 | root_path = str(Path(os.path.dirname(os.path.abspath(__file__))).parent) 19 | 20 | widths = {"pendulum": 48, "cartpole": 80, "threepole": 80} 21 | heights = {"pendulum": 48, "cartpole": 80, "threepole": 80} 22 | state_dims = {"pendulum": 2, "cartpole": 4, "threepole": 6} 23 | frequencies = {"pendulum": 50, "cartpole": 50, "threepole": 50} 24 | mdps = {"pendulum": PendulumMDP, "cartpole": CartPoleMDP, "threepole": ThreePoleMDP} 25 | 26 | 27 | def sample(env_name, sample_size, noise): 28 | """ 29 | return [(x, u, x_next, s, s_next)] 30 | """ 31 | width, height, frequency = widths[env_name], heights[env_name], frequencies[env_name] 32 | s_dim = state_dims[env_name] 33 | mdp = mdps[env_name](width=width, height=height, frequency=frequency, noise=noise) 34 | 35 | # Data buffers to fill. 36 | x_data = np.zeros((sample_size, width, height, 2), dtype="float32") 37 | u_data = np.zeros((sample_size, mdp.action_dim), dtype="float32") 38 | x_next_data = np.zeros((sample_size, width, height, 2), dtype="float32") 39 | state_data = np.zeros((sample_size, s_dim, 2), dtype="float32") 40 | state_next_data = np.zeros((sample_size, s_dim, 2), dtype="float32") 41 | 42 | # Generate interaction tuples (random states and actions). 43 | for sample in trange(sample_size, desc="Sampling " + env_name + " data"): 44 | s0 = mdp.sample_random_state() 45 | x0 = mdp.render(s0) 46 | a0 = mdp.sample_random_action() 47 | s1 = mdp.transition_function(s0, a0) 48 | 49 | x1 = mdp.render(s1) 50 | a1 = mdp.sample_random_action() 51 | s2 = mdp.transition_function(s1, a1) 52 | x2 = mdp.render(s2) 53 | # Store interaction tuple. 54 | # Current state (w/ history). 55 | x_data[sample, :, :, 0] = x0[:, :, 0] 56 | x_data[sample, :, :, 1] = x1[:, :, 0] 57 | state_data[sample, :, 0] = s0 58 | state_data[sample, :, 1] = s1 59 | # Action. 60 | u_data[sample] = a1 61 | # Next state (w/ history). 
62 | x_next_data[sample, :, :, 0] = x1[:, :, 0] 63 | x_next_data[sample, :, :, 1] = x2[:, :, 0] 64 | state_next_data[sample, :, 0] = s1 65 | state_next_data[sample, :, 1] = s2 66 | 67 | return x_data, u_data, x_next_data, state_data, state_next_data 68 | 69 | 70 | def write_to_file(env_name, sample_size, noise): 71 | """ 72 | write [(x, u, x_next)] to output dir 73 | """ 74 | output_dir = root_path + "/data/" + env_name + "/raw_{:d}_{:.0f}".format(sample_size, noise) 75 | if not path.exists(output_dir): 76 | os.makedirs(output_dir) 77 | 78 | samples = [] 79 | data = sample(env_name=env_name, sample_size=sample_size, noise=noise) 80 | x_data, u_data, x_next_data, state_data, state_next_data = data 81 | 82 | for i in range(x_data.shape[0]): 83 | x_1 = x_data[i, :, :, 0] 84 | x_2 = x_data[i, :, :, 1] 85 | before = np.hstack((x_1, x_2)) 86 | before_file = "before-{:05d}.png".format(i) 87 | Image.fromarray(before * 255.0).convert("L").save(path.join(output_dir, before_file)) 88 | 89 | after_file = "after-{:05d}.png".format(i) 90 | x_next_1 = x_next_data[i, :, :, 0] 91 | x_next_2 = x_next_data[i, :, :, 1] 92 | after = np.hstack((x_next_1, x_next_2)) 93 | Image.fromarray(after * 255.0).convert("L").save(path.join(output_dir, after_file)) 94 | 95 | initial_state = state_data[i] 96 | after_state = state_next_data[i] 97 | 98 | samples.append( 99 | { 100 | "before_state": initial_state.tolist(), 101 | "after_state": after_state.tolist(), 102 | "before": before_file, 103 | "after": after_file, 104 | "control": u_data[i].tolist(), 105 | } 106 | ) 107 | 108 | with open(path.join(output_dir, "data.json"), "wt") as outfile: 109 | json.dump( 110 | { 111 | "metadata": {"num_samples": x_data.shape[0], "time_created": str(datetime.now()), "version": 1}, 112 | "samples": samples, 113 | }, 114 | outfile, 115 | indent=2, 116 | ) 117 | 118 | 119 | def main(args): 120 | sample_size = args.sample_size 121 | noise = args.noise 122 | env_name = args.env 123 | assert env_name in ["pendulum", "cartpole", "threepole"] 124 | write_to_file(env_name, sample_size, noise) 125 | 126 | 127 | if __name__ == "__main__": 128 | parser = argparse.ArgumentParser(description="sample pendulum data") 129 | 130 | parser.add_argument("--sample_size", required=True, type=int, help="the number of samples") 131 | parser.add_argument("--noise", default=0, type=int, help="level of noise") 132 | parser.add_argument("--env", required=True, type=str, help="pendulum or cartpole or threepole") 133 | 134 | args = parser.parse_args() 135 | 136 | main(args) 137 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | 4 | import numpy as np 5 | import torch 6 | from data import sample_planar, sample_pole 7 | from torch.utils.data import Dataset 8 | 9 | 10 | torch.set_default_dtype(torch.float64) 11 | 12 | 13 | class BaseDataset(Dataset): 14 | def __init__(self, data_path, sample_size, noise): 15 | self.sample_size = sample_size 16 | self.noise = noise 17 | self.data_path = data_path 18 | if not os.path.exists(self.data_path): 19 | os.makedirs(self.data_path) 20 | self._process() 21 | self.data_x, self.data_u, self.data_x_next = torch.load( 22 | self.data_path + "{:d}_{:.0f}.pt".format(self.sample_size, self.noise) 23 | ) 24 | 25 | def __len__(self): 26 | return len(self.data_x) 27 | 28 | def __getitem__(self, index): 29 | return self.data_x[index], self.data_u[index], self.data_x_next[index] 
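    # Note (illustrative; not part of the original file): each dataset yields
    # (x, u, x_next) triples, so a typical consumer would wrap it in a standard
    # PyTorch DataLoader, e.g.:
    #
    #     from torch.utils.data import DataLoader
    #     loader = DataLoader(PlanarDataset(sample_size=5000, noise=0), batch_size=128, shuffle=True)
    #     for x, u, x_next in loader:
    #         ...  # feed each triple to the PCC losses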
30 | 31 | def _process_image(self, img): 32 | pass 33 | 34 | def check_exists(self): 35 | return path.exists(self.data_path + "{:d}_{:.0f}.pt".format(self.sample_size, self.noise)) 36 | 37 | def _process(self): 38 | pass 39 | 40 | 41 | class PlanarDataset(BaseDataset): 42 | width = 40 43 | height = 40 44 | action_dim = 2 45 | 46 | def __init__(self, sample_size, noise): 47 | data_path = "data/planar/" 48 | super(PlanarDataset, self).__init__(data_path, sample_size, noise) 49 | 50 | def _process_image(self, img): 51 | return torch.from_numpy(img.flatten()).unsqueeze(0) 52 | 53 | def _process(self): 54 | if self.check_exists(): 55 | return 56 | else: 57 | ( 58 | x_numpy_data, 59 | u_numpy_data, 60 | x_next_numpy_data, 61 | state_numpy_data, 62 | state_next_numpy_data, 63 | ) = sample_planar.sample(sample_size=self.sample_size, noise=self.noise) 64 | data_len = len(x_numpy_data) 65 | 66 | # place holder for data 67 | data_x = torch.zeros(data_len, self.width * self.height) 68 | data_u = torch.zeros(data_len, self.action_dim) 69 | data_x_next = torch.zeros(data_len, self.width * self.height) 70 | 71 | for i in range(data_len): 72 | data_x[i] = self._process_image(x_numpy_data[i]) 73 | data_u[i] = torch.from_numpy(u_numpy_data[i]) 74 | data_x_next[i] = self._process_image(x_next_numpy_data[i]) 75 | 76 | data_set = (data_x, data_u, data_x_next) 77 | 78 | with open(self.data_path + "{:d}_{:.0f}.pt".format(self.sample_size, self.noise), "wb") as f: 79 | torch.save(data_set, f) 80 | 81 | 82 | class PendulumDataset(BaseDataset): 83 | width = 48 84 | height = 48 * 2 85 | action_dim = 1 86 | 87 | def __init__(self, sample_size, noise): 88 | data_path = "data/pendulum/" 89 | super(PendulumDataset, self).__init__(data_path, sample_size, noise) 90 | 91 | def _process_image(self, img): 92 | x = np.vstack((img[:, :, 0], img[:, :, 1])).flatten() 93 | return torch.from_numpy(x).unsqueeze(0) 94 | 95 | def _process(self): 96 | if self.check_exists(): 97 | return 98 | else: 99 | ( 100 | x_numpy_data, 101 | u_numpy_data, 102 | x_next_numpy_data, 103 | state_numpy_data, 104 | state_next_numpy_data, 105 | ) = sample_pole.sample(env_name="pendulum", sample_size=self.sample_size, noise=self.noise) 106 | data_len = len(x_numpy_data) 107 | 108 | # place holder for data 109 | data_x = torch.zeros(data_len, self.width * self.height) 110 | data_u = torch.zeros(data_len, self.action_dim) 111 | data_x_next = torch.zeros(data_len, self.width * self.height) 112 | 113 | for i in range(data_len): 114 | data_x[i] = self._process_image(x_numpy_data[i]) 115 | data_u[i] = torch.from_numpy(u_numpy_data[i]) 116 | data_x_next[i] = self._process_image(x_next_numpy_data[i]) 117 | 118 | data_set = (data_x, data_u, data_x_next) 119 | 120 | with open(self.data_path + "{:d}_{:.0f}.pt".format(self.sample_size, self.noise), "wb") as f: 121 | torch.save(data_set, f) 122 | 123 | 124 | class CartPoleDataset(BaseDataset): 125 | width = 80 126 | height = 80 * 2 127 | action_dim = 1 128 | 129 | def __init__(self, sample_size, noise): 130 | data_path = "data/cartpole/" 131 | super(CartPoleDataset, self).__init__(data_path, sample_size, noise) 132 | 133 | def _process_image(self, img): 134 | x = torch.zeros(size=(2, self.width, self.width)) 135 | x[0, :, :] = torch.from_numpy(img[:, :, 0]) 136 | x[1, :, :] = torch.from_numpy(img[:, :, 1]) 137 | return x.unsqueeze(0) 138 | 139 | def _process(self): 140 | if self.check_exists(): 141 | return 142 | else: 143 | ( 144 | x_numpy_data, 145 | u_numpy_data, 146 | x_next_numpy_data, 147 | state_numpy_data, 
148 | state_next_numpy_data, 149 | ) = sample_pole.sample(env_name="cartpole", sample_size=self.sample_size, noise=self.noise) 150 | data_len = len(x_numpy_data) 151 | 152 | # place holder for data 153 | data_x = torch.zeros(data_len, 2, self.width, self.width) 154 | data_u = torch.zeros(data_len, self.action_dim) 155 | data_x_next = torch.zeros(data_len, 2, self.width, self.width) 156 | 157 | for i in range(data_len): 158 | data_x[i] = self._process_image(x_numpy_data[i]) 159 | data_u[i] = torch.from_numpy(u_numpy_data[i]) 160 | data_x_next[i] = self._process_image(x_next_numpy_data[i]) 161 | 162 | data_set = (data_x, data_u, data_x_next) 163 | 164 | with open(self.data_path + "{:d}_{:.0f}.pt".format(self.sample_size, self.noise), "wb") as f: 165 | torch.save(data_set, f) 166 | 167 | 168 | class ThreePoleDataset(BaseDataset): 169 | width = 80 170 | height = 80 * 2 171 | action_dim = 3 172 | 173 | def __init__(self, sample_size, noise): 174 | data_path = "data/threepole/" 175 | super(ThreePoleDataset, self).__init__(data_path, sample_size, noise) 176 | 177 | def _process_image(self, img): 178 | x = torch.zeros(size=(2, self.width, self.width)) 179 | x[0, :, :] = torch.from_numpy(img[:, :, 0]) 180 | x[1, :, :] = torch.from_numpy(img[:, :, 1]) 181 | return x.unsqueeze(0) 182 | 183 | def _process(self): 184 | if self.check_exists(): 185 | return 186 | else: 187 | ( 188 | x_numpy_data, 189 | u_numpy_data, 190 | x_next_numpy_data, 191 | state_numpy_data, 192 | state_next_numpy_data, 193 | ) = sample_pole.sample(env_name="threepole", sample_size=self.sample_size, noise=self.noise) 194 | data_len = len(x_numpy_data) 195 | 196 | # place holder for data 197 | data_x = torch.zeros(data_len, 2, self.width, self.width) 198 | data_u = torch.zeros(data_len, self.action_dim) 199 | data_x_next = torch.zeros(data_len, 2, self.width, self.width) 200 | 201 | for i in range(data_len): 202 | data_x[i] = self._process_image(x_numpy_data[i]) 203 | data_u[i] = torch.from_numpy(u_numpy_data[i]) 204 | data_x_next[i] = self._process_image(x_next_numpy_data[i]) 205 | 206 | data_set = (data_x, data_u, data_x_next) 207 | 208 | with open(self.data_path + "{:d}_{:.0f}.pt".format(self.sample_size, self.noise), "wb") as f: 209 | torch.save(data_set, f) 210 | -------------------------------------------------------------------------------- /ilqr.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | 6 | import numpy as np 7 | import torch 8 | from ilqr_utils import ( 9 | backward, 10 | compute_latent_traj, 11 | forward, 12 | get_x_data, 13 | latent_cost, 14 | random_actions_trajs, 15 | refresh_actions_trajs, 16 | save_traj, 17 | seq_jacobian, 18 | update_horizon_start, 19 | ) 20 | from mdp.cartpole_mdp import CartPoleMDP 21 | from mdp.pendulum_mdp import PendulumMDP 22 | from mdp.plane_obstacles_mdp import PlanarObstaclesMDP 23 | from mdp.three_pole_mdp import ThreePoleMDP 24 | from pcc_model import PCC 25 | 26 | 27 | seed = 2020 28 | random.seed(seed) 29 | os.environ["PYTHONHASHSEED"] = str(seed) 30 | np.random.seed(seed) 31 | torch.manual_seed(seed) 32 | torch.cuda.manual_seed(seed) 33 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_default_dtype(torch.float64)

config_path = {
    "plane": "ilqr_config/plane.json",
    "swing": "ilqr_config/swing.json",
    "balance": "ilqr_config/balance.json",
    "cartpole": "ilqr_config/cartpole.json",
    "threepole": "ilqr_config/threepole.json",
}
env_data_dim = {
    "planar": (1600, 2, 2),
    "pendulum": ((2, 48, 48), 3, 1),
    "cartpole": ((2, 80, 80), 8, 1),
    "threepole": ((2, 80, 80), 8, 3),
}


def main(args):
    task_name = args.task
    assert task_name in ["plane", "balance", "swing", "cartpole", "threepole"]
    # map the task name to its environment: "plane" runs on the planar env,
    # "balance" and "swing" both run on the pendulum env
    if task_name in ["balance", "swing"]:
        env_name = "pendulum"
    elif task_name == "plane":
        env_name = "planar"
    else:
        env_name = task_name

    setting_path = args.setting_path
    setting = os.path.basename(os.path.normpath(setting_path))
    noise = args.noise
    epoch = args.epoch
    x_dim, z_dim, u_dim = env_data_dim[env_name]
    if env_name in ["planar", "pendulum"]:
        x_dim = np.prod(x_dim)

    ilqr_result_path = "iLQR_result/" + "_".join([task_name, str(setting), str(noise), str(epoch)])
    if not os.path.exists(ilqr_result_path):
        os.makedirs(ilqr_result_path)
    with open(ilqr_result_path + "/settings", "w") as f:
        json.dump(args.__dict__, f, indent=2)

    # each trained model will perform 10 random tasks
    all_task_configs = []
    for task_counter in range(10):
        # config for this task
        with open(config_path[task_name]) as f:
            config = json.load(f)

        # sample random start and goal state
        s_start_min, s_start_max = config["start_min"], config["start_max"]
        config["s_start"] = np.random.uniform(low=s_start_min, high=s_start_max)
        s_goal = config["goal"][np.random.choice(len(config["goal"]))]
        config["s_goal"] = np.array(s_goal)

        all_task_configs.append(config)

    # the folder where all trained models are saved
    log_folders = [
        os.path.join(setting_path, dI)
        for dI in os.listdir(setting_path)
        if os.path.isdir(os.path.join(setting_path, dI))
    ]
    log_folders.sort()

    # statistics on all trained models
    avg_model_percent = 0.0
    best_model_percent = 0.0
    for log in log_folders:
        with open(log + "/settings", "r") as f:
            settings = json.load(f)
            armotized = settings["armotized"]

        log_base = os.path.basename(os.path.normpath(log))
        model_path = ilqr_result_path + "/" + log_base
        if not os.path.exists(model_path):
            os.makedirs(model_path)
        print("iLQR for " + log_base)

        # load the trained model
        model = PCC(armotized, x_dim, z_dim, u_dim, env_name)
        model.load_state_dict(torch.load(log + "/model_" + str(epoch), map_location="cpu"))
        model.eval()
        dynamics = model.dynamics
        encoder = model.encoder

        # run the task with 10 different start and goal states for a particular model
        avg_percent = 0.0
        for task_counter, config in enumerate(all_task_configs):

            print("Performing task %d: " % (task_counter) + str(config["task"]))

            # environment specification
            horizon = config["horizon_prob"]
            plan_len = config["plan_len"]

            # ilqr specification
            R_z = config["q_weight"] * np.eye(z_dim)
            R_u = config["r_weight"] * np.eye(u_dim)
            num_uniform = config["uniform_trajs"]
            num_extreme = config["extreme_trajs"]
            ilqr_iters = config["ilqr_iters"]
            inv_regulator_init = config["pinv_init"]
            inv_regulator_multi = config["pinv_mult"]
            inv_regulator_max = config["pinv_max"]
            alpha_init = config["alpha_init"]
            alpha_mult = config["alpha_mult"]
            alpha_min = config["alpha_min"]

            s_start = config["s_start"]
            s_goal = config["s_goal"]

            # mdp
            if env_name == "planar":
                mdp = PlanarObstaclesMDP(goal=s_goal, goal_thres=config["distance_thresh"], noise=noise)
            elif env_name == "pendulum":
                mdp = PendulumMDP(frequency=config["frequency"], noise=noise, torque=config["torque"])
            elif env_name == "cartpole":
                mdp = CartPoleMDP(frequency=config["frequency"], noise=noise)
            elif env_name == "threepole":
                mdp = ThreePoleMDP(frequency=config["frequency"], noise=noise, torque=config["torque"])
            # get z_start and z_goal
            x_start = get_x_data(mdp, s_start, config)
            x_goal = get_x_data(mdp, s_goal, config)
            with torch.no_grad():
                z_start = encoder(x_start).mean
                z_goal = encoder(x_goal).mean
            z_start = z_start.squeeze().numpy()
            z_goal = z_goal.squeeze().numpy()

            # initialize actions trajectories
            all_actions_trajs = random_actions_trajs(mdp, num_uniform, num_extreme, plan_len)

            # perform receding-horizon iLQR
            s_start_horizon = np.copy(s_start)  # s_start and z_start are updated at each horizon step
            z_start_horizon = np.copy(z_start)
            obs_traj = [mdp.render(s_start).squeeze()]
            goal_counter = 0.0
            for plan_iter in range(1, horizon + 1):
                latent_cost_list = [None] * len(all_actions_trajs)
                # iterate over all trajectories
                for traj_id in range(len(all_actions_trajs)):
                    # initialize the inverse regulator
                    inv_regulator = inv_regulator_init
                    for iter in range(1, ilqr_iters + 1):
                        u_seq = all_actions_trajs[traj_id]
                        z_seq = compute_latent_traj(z_start_horizon, u_seq, dynamics)
                        # compute the linearization matrices
                        A_seq, B_seq = seq_jacobian(dynamics, z_seq, u_seq)
                        # run backward
                        k_small, K_big = backward(R_z, R_u, z_seq, u_seq, z_goal, A_seq, B_seq, inv_regulator)
                        current_cost = latent_cost(R_z, R_u, z_seq, z_goal, u_seq)
                        # forward using line search
                        alpha = alpha_init
                        accept = False  # whether any alpha is accepted
                        while alpha > alpha_min:
                            z_seq_cand, u_seq_cand = forward(
                                z_seq, all_actions_trajs[traj_id], k_small, K_big, dynamics, alpha
                            )
                            cost_cand = latent_cost(R_z, R_u, z_seq_cand, z_goal, u_seq_cand)
                            if cost_cand < current_cost:  # accept the trajectory candidate
                                accept = True
                                all_actions_trajs[traj_id] = u_seq_cand
                                latent_cost_list[traj_id] = cost_cand
                                break
                            else:
                                alpha *= alpha_mult
                        if accept:
                            inv_regulator = inv_regulator_init
                        else:
                            inv_regulator *= inv_regulator_multi
                            if inv_regulator > inv_regulator_max:
                                break

                for i in range(len(latent_cost_list)):
                    if latent_cost_list[i] is None:
                        latent_cost_list[i] = np.inf
                traj_opt_id = np.argmin(latent_cost_list)
                action_chosen = all_actions_trajs[traj_opt_id][0]
                s_start_horizon, z_start_horizon = update_horizon_start(
                    mdp, s_start_horizon, action_chosen, encoder, config
                )

                obs_traj.append(mdp.render(s_start_horizon).squeeze())
                goal_counter += mdp.reward_function(s_start_horizon)

                all_actions_trajs = refresh_actions_trajs(
                    all_actions_trajs,
                    traj_opt_id,
mdp, 220 | np.min([plan_len, horizon - plan_iter]), 221 | num_uniform, 222 | num_extreme, 223 | ) 224 | 225 | # compute the percentage close to goal 226 | success_rate = goal_counter / horizon 227 | print("Success rate: %.2f" % (success_rate)) 228 | percent = success_rate 229 | avg_percent += success_rate 230 | with open(model_path + "/result.txt", "a+") as f: 231 | f.write(config["task"] + ": " + str(percent) + "\n") 232 | 233 | # save trajectory as gif file 234 | gif_path = model_path + "/task_{:01d}.gif".format(task_counter + 1) 235 | save_traj(obs_traj, mdp.render(s_goal).squeeze(), gif_path, config["task"]) 236 | 237 | avg_percent = avg_percent / 10 238 | print("Average success rate: " + str(avg_percent)) 239 | print("====================================") 240 | avg_model_percent += avg_percent 241 | if avg_percent > best_model_percent: 242 | best_model = log_base 243 | best_model_percent = avg_percent 244 | with open(model_path + "/result.txt", "a+") as f: 245 | f.write("Average percentage: " + str(avg_percent)) 246 | 247 | avg_model_percent = avg_model_percent / len(log_folders) 248 | with open(ilqr_result_path + "/result.txt", "w") as f: 249 | f.write("Average percentage of all models: " + str(avg_model_percent) + "\n") 250 | f.write("Best model: " + best_model + ", best percentage: " + str(best_model_percent)) 251 | 252 | 253 | if __name__ == "__main__": 254 | parser = argparse.ArgumentParser(description="run iLQR") 255 | parser.add_argument("--task", required=True, type=str, help="task to perform") 256 | parser.add_argument("--setting_path", required=True, type=str, help="path to load trained models") 257 | parser.add_argument("--noise", type=float, default=0.0, help="noise level for mdp") 258 | parser.add_argument("--epoch", type=int, default=2000, help="number of epochs to load model") 259 | args = parser.parse_args() 260 | 261 | main(args) 262 | -------------------------------------------------------------------------------- /ilqr_config/balance.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "balance", 3 | 4 | "pinv_init": 1e-5, 5 | "pinv_mult": 2.0, 6 | "pinv_max": 10.0, 7 | "alpha_init": 1.0, 8 | "alpha_mult": 0.5, 9 | "alpha_min": 1e-4, 10 | 11 | "start_min": [0, 0], 12 | "start_max": [0, 0], 13 | "goal": [[0, 0]], 14 | 15 | "q_weight": 50, 16 | "r_weight": 1, 17 | 18 | "frequency": 50, 19 | "noise": 0.0, 20 | "torque": 0.65, 21 | "epoch": "5000", 22 | 23 | "ilqr_iters": 4, 24 | "horizon_prob": 100, 25 | "plan_len": 10, 26 | "uniform_trajs": 3, 27 | "extreme_trajs": 3, 28 | 29 | "obs_shape": [2, 48, 48], 30 | "action_dim": 1, 31 | "latent_dim": 3 32 | } 33 | -------------------------------------------------------------------------------- /ilqr_config/cartpole.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "cartpole", 3 | 4 | "pinv_init": 1e-5, 5 | "pinv_mult": 2.0, 6 | "pinv_max": 10.0, 7 | "alpha_init": 1.0, 8 | "alpha_mult": 0.5, 9 | "alpha_min": 1e-4, 10 | 11 | "start_min": [-0.39269908169872414, 0, 0, 0], 12 | "start_max": [0.39269908169872414, 0, 0, 0], 13 | "goal": [[0, 0, 0, 0]], 14 | 15 | "q_weight": 100, 16 | "r_weight": 1, 17 | 18 | "frequency": 50, 19 | "noise": 0.0, 20 | "epoch": "5000", 21 | 22 | "ilqr_iters": 4, 23 | "horizon_prob": 50, 24 | "plan_len": 5, 25 | "uniform_trajs": 3, 26 | "extreme_trajs": 3, 27 | 28 | "obs_shape": [2, 80, 80], 29 | "action_dim": 1, 30 | "latent_dim": 8 31 | } 
--------------------------------------------------------------------------------
/ilqr_config/plane.json:
--------------------------------------------------------------------------------
{
    "task": "plane",

    "pinv_init": 1e-5,
    "pinv_mult": 2.0,
    "pinv_max": 10.0,
    "alpha_init": 1.0,
    "alpha_mult": 0.5,
    "alpha_min": 1e-4,

    "start_min": [3, 3],
    "start_max": [7, 7],
    "goal": [[37, 37], [37, 3], [3, 37]],
    "distance_thresh": 2,

    "q_weight": 10,
    "r_weight": 1,

    "noise": 0.0,
    "epoch": "5000",

    "ilqr_iters": 4,
    "horizon_prob": 40,
    "plan_len": 10,
    "uniform_trajs": 3,
    "extreme_trajs": 3,

    "obs_shape": [40, 40],
    "action_dim": 2,
    "latent_dim": 2
}
--------------------------------------------------------------------------------
/ilqr_config/swing.json:
--------------------------------------------------------------------------------
{
    "task": "swing",

    "pinv_init": 1e-5,
    "pinv_mult": 2.0,
    "pinv_max": 10.0,
    "alpha_init": 1.0,
    "alpha_mult": 0.5,
    "alpha_min": 1e-4,

    "start_min": [2.827433388230814, 0],
    "start_max": [2.9845130209103035, 0],
    "goal": [[0, 0]],

    "q_weight": 50,
    "r_weight": 1,

    "frequency": 50,
    "noise": 0.0,
    "torque": 0.65,
    "epoch": "5000",

    "ilqr_iters": 4,
    "horizon_prob": 100,
    "plan_len": 10,
    "uniform_trajs": 3,
    "extreme_trajs": 3,

    "obs_shape": [2, 48, 48],
    "action_dim": 1,
    "latent_dim": 3
}
--------------------------------------------------------------------------------
/ilqr_config/threepole.json:
--------------------------------------------------------------------------------
{
    "task": "threepole",

    "pinv_init": 1e-5,
    "pinv_mult": 2.0,
    "pinv_max": 10.0,
    "alpha_init": 1.0,
    "alpha_mult": 0.5,
    "alpha_min": 1e-4,

    "start_min": [2.9845130209103035, 0, 2.0420352248333655, 0, 1.0367255756846319, 0],
    "start_max": [2.9845130209103035, 0, 2.0420352248333655, 0, 1.0367255756846319, 0],
    "goal": [[0, 0, 0, 0, 0, 0]],

    "q_weight": 50,
    "r_weight": 1,

    "frequency": 50,
    "noise": 0.0,
    "torque": 1.0,

    "ilqr_iters": 4,
    "horizon_prob": 200,
    "plan_len": 20,
    "uniform_trajs": 3,
    "extreme_trajs": 3,

    "obs_shape": [2, 80, 80],
    "action_dim": 3,
    "latent_dim": 8
}
--------------------------------------------------------------------------------
/ilqr_utils.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.animation import FuncAnimation, writers


np.random.seed(0)


def cost_dz(R_z, z, z_goal):
    # compute the first-order derivative of the latent cost w.r.t. z
    z_diff = np.expand_dims(z - z_goal, axis=-1)
    return np.squeeze(2 * np.matmul(R_z, z_diff))


def cost_du(R_u, u):
    # compute the first-order derivative of the latent cost w.r.t. u
    return np.atleast_1d(np.squeeze(2 * np.matmul(R_u, np.expand_dims(u, axis=-1))))


def cost_dzz(R_z):
    # compute the second-order derivative of the latent cost w.r.t. z
    return 2 * R_z


def cost_duu(R_u):
    # compute the second-order derivative of the latent cost w.r.t. u
    return 2 * R_u


def cost_duz(z, u):
    # compute the second-order derivative of the latent cost w.r.t. u and z
    return np.zeros((u.shape[-1], z.shape[-1]))


def latent_cost(R_z, R_u, z_seq, z_goal, u_seq):
    z_diff = np.expand_dims(z_seq - z_goal, axis=-1)
    cost_z = np.squeeze(np.matmul(np.matmul(z_diff.transpose((0, 2, 1)), R_z), z_diff))
    u_seq_reshaped = np.expand_dims(u_seq, axis=-1)
    cost_u = np.squeeze(np.matmul(np.matmul(u_seq_reshaped.transpose((0, 2, 1)), R_u), u_seq_reshaped))
    return np.sum(cost_z) + np.sum(cost_u)


def one_step_back(R_z, R_u, z, u, z_goal, A, B, V_prime_next_z, V_prime_next_zz, mu_inv_regulator):
    """
    V_prime_next_z: first-order derivative of the value function at time step t+1
    V_prime_next_zz: second-order derivative of the value function at time step t+1
    A: derivative of F(z, u) w.r.t z at z_bar_t, u_bar_t
    B: derivative of F(z, u) w.r.t u at z_bar_t, u_bar_t
    """
    # compute Q_z, Q_u, Q_zz, Q_uu, Q_uz using cost function, A, B and V
    Q_z = cost_dz(R_z, z, z_goal) + np.matmul(A.transpose(), V_prime_next_z)
    Q_u = cost_du(R_u, u) + np.matmul(B.transpose(), V_prime_next_z)
    Q_zz = cost_dzz(R_z) + np.matmul(np.matmul(A.transpose(), V_prime_next_zz), A)
    Q_uz = cost_duz(z, u) + np.matmul(np.matmul(B.transpose(), V_prime_next_zz), A)
    Q_uu = cost_duu(R_u) + np.matmul(np.matmul(B.transpose(), V_prime_next_zz), B)

    # compute k and K matrix, add regularization to Q_uu
    Q_uu_regularized = Q_uu + mu_inv_regulator * np.eye(Q_uu.shape[0])
    Q_uu_in = np.linalg.inv(Q_uu_regularized)
    k = -np.matmul(Q_uu_in, Q_u)
    K = -np.matmul(Q_uu_in, Q_uz)

    # compute V_z and V_zz using k and K
    V_prime_z = Q_z + np.matmul(Q_uz.transpose(), k)
    V_prime_zz = Q_zz + np.matmul(Q_uz.transpose(), K)
    return k, K, V_prime_z, V_prime_zz


def backward(R_z, R_u, z_seq, u_seq, z_goal, A_seq, B_seq, inv_regulator):
    """
    do the backward pass
    return a sequence of k and K matrices
    """
    # first- and second-order derivative of the value function at the last time step
    V_prime_next_z = cost_dz(R_z, z_seq[-1], z_goal)
    V_prime_next_zz = cost_dzz(R_z)
    k, K = [], []
    act_seq_len = len(u_seq)
    for t in reversed(range(act_seq_len)):
        k_t, K_t, V_prime_z, V_prime_zz = one_step_back(
            R_z, R_u, z_seq[t], u_seq[t], z_goal, A_seq[t], B_seq[t], V_prime_next_z, V_prime_next_zz, inv_regulator
        )
        k.insert(0, k_t)
        K.insert(0, K_t)
        V_prime_next_z, V_prime_next_zz = V_prime_z, V_prime_zz
    return k, K


def forward(z_seq, u_seq, k, K, dynamics, alpha):
    """
    update the trajectory, given k and K
    !!!! update using the linearization matrices (A and B), not the learned dynamics
    """
    z_seq_new = []
    z_seq_new.append(z_seq[0])
    u_seq_new = []
    for i in range(0, len(u_seq)):
        u_new = u_seq[i] + alpha * k[i] + np.matmul(K[i], z_seq_new[i] - z_seq[i])
        u_seq_new.append(u_new)
        with torch.no_grad():
            z_new = dynamics(torch.from_numpy(z_seq_new[i]).unsqueeze(0), torch.from_numpy(u_new).unsqueeze(0))[0].mean
        z_seq_new.append(z_new.squeeze().numpy())
    return np.array(z_seq_new), np.array(u_seq_new)


# def forward(u_seq, k_seq, K_seq, A_seq, B_seq, alpha):
#     """
#     update the trajectory, given k and K
#     !!!! update using the linearization matrices (A and B), not the learned dynamics
#     """
#     u_new_seq = []
#     plan_len = len(u_seq)
#     z_dim = K_seq[0].shape[1]
#     for i in range(0, plan_len):
#         if i == 0:
#             z_delta = np.zeros(z_dim)
#         else:
#             z_delta = np.matmul(A_seq[i-1], z_delta) + np.matmul(B_seq[i-1], u_delta)
#         u_delta = alpha * (k_seq[i] + np.matmul(K_seq[i], z_delta))
#         u_new_seq.append(u_seq[i] + u_delta)
#     return np.array(u_new_seq)


def get_x_data(mdp, state, config):
    image_data = mdp.render(state).squeeze()
    x_dim = config["obs_shape"]
    if config["task"] == "plane":
        x_dim = np.prod(x_dim)
        x_data = torch.from_numpy(image_data).double().view(x_dim).unsqueeze(0)
    elif config["task"] in ["swing", "balance"]:
        x_dim = np.prod(x_dim)
        x_data = np.vstack((image_data, image_data))
        x_data = torch.from_numpy(x_data).double().view(x_dim).unsqueeze(0)
    elif config["task"] in ["cartpole", "threepole"]:
        x_data = torch.zeros(size=(2, 80, 80))
        x_data[0, :, :] = torch.from_numpy(image_data)
        x_data[1, :, :] = torch.from_numpy(image_data)
        x_data = x_data.unsqueeze(0)
    return x_data


def update_horizon_start(mdp, s, u, encoder, config):
    s_next = mdp.transition_function(s, u)
    if config["task"] == "plane":
        x_next = get_x_data(mdp, s_next, config)
    elif config["task"] in ["swing", "balance"]:
        obs = mdp.render(s).squeeze()
        obs_next = mdp.render(s_next).squeeze()
        obs_stacked = np.vstack((obs, obs_next))
        x_dim = np.prod(config["obs_shape"])
        x_next = torch.from_numpy(obs_stacked).view(x_dim).unsqueeze(0).double()
    elif config["task"] in ["cartpole", "threepole"]:
        obs = mdp.render(s).squeeze()
        obs_next = mdp.render(s_next).squeeze()
        x_next = torch.zeros(size=config["obs_shape"])
        x_next[0, :, :] = torch.from_numpy(obs)
        x_next[1, :, :] = torch.from_numpy(obs_next)
        x_next = x_next.unsqueeze(0)
    with torch.no_grad():
        z_next = encoder(x_next).mean
    return s_next, z_next.squeeze().numpy()


def random_uniform_actions(mdp, plan_len):
    # create a trajectory of random actions
    random_actions = []
    for i in range(plan_len):
        action = mdp.sample_random_action()
        random_actions.append(action)
    return np.array(random_actions)


def random_extreme_actions(mdp, plan_len):
    # create a trajectory of extreme actions
    extreme_actions = []
    for i in range(plan_len):
        action = mdp.sample_extreme_action()
        extreme_actions.append(action)
    return np.array(extreme_actions)


def random_actions_trajs(mdp, num_uniform, num_extreme, plan_len):
    actions_trajs = []
    for i in range(num_uniform):
        actions_trajs.append(random_uniform_actions(mdp, plan_len))
    for j in range(num_extreme):
        actions_trajs.append(random_extreme_actions(mdp, plan_len))
    return actions_trajs


def refresh_actions_trajs(actions_trajs, traj_opt_id, mdp, length, num_uniform, num_extreme):
    for traj_id in range(len(actions_trajs)):
        if traj_id == traj_opt_id:
            actions_trajs[traj_id] = actions_trajs[traj_id][1:]
            if len(actions_trajs[traj_id]) < length:
                # Duplicate last action.
198 | actions_trajs[traj_id] = np.append( 199 | actions_trajs[traj_id], actions_trajs[traj_id][-1].reshape(1, -1), axis=0 200 | ) 201 | continue 202 | if traj_id < num_uniform: 203 | actions_trajs[traj_id] = random_uniform_actions(mdp, length) 204 | else: 205 | actions_trajs[traj_id] = random_extreme_actions(mdp, length) 206 | return actions_trajs 207 | 208 | 209 | def update_seq_act(z_seq, z_start, u_seq, k, K, dynamics): 210 | """ 211 | update the trajectory, given k and K 212 | """ 213 | z_new = z_start 214 | u_seq_new = [] 215 | for i in range(0, len(u_seq)): 216 | u_new = u_seq[i] + k[i] + np.matmul(K[i], (z_new - z_seq[i])) 217 | with torch.no_grad(): 218 | z_new = dynamics(torch.from_numpy(z_new).view(1, -1), torch.from_numpy(u_new).view(1, -1))[0].mean 219 | z_new = z_new.squeeze().numpy() 220 | u_seq_new.append(u_new) 221 | return np.array(u_seq_new) 222 | 223 | 224 | def compute_latent_traj(z_start, u_seq, dynamics): 225 | plan_len = len(u_seq) 226 | z_seq = [z_start] 227 | for i in range(plan_len): 228 | z = torch.from_numpy(z_seq[i]).view(1, -1).double() 229 | u = torch.from_numpy(u_seq[i]).view(1, -1).double() 230 | with torch.no_grad(): 231 | z_next = dynamics(z, u)[0].mean 232 | z_seq.append(z_next.squeeze().numpy()) 233 | return z_seq 234 | 235 | 236 | def jacobian(dynamics, z, u): 237 | """ 238 | compute the jacobian of F(z,u) w.r.t z, u 239 | """ 240 | z_dim = z.shape[0] 241 | u_dim = u.shape[0] 242 | z_tensor = torch.from_numpy(z).view(1, -1).double() 243 | u_tensor = torch.from_numpy(u).view(1, -1).double() 244 | if dynamics.armotized: 245 | _, A, B = dynamics(z_tensor, u_tensor) 246 | return A.squeeze().view(z_dim, z_dim).numpy(), B.squeeze().view(z_dim, u_dim).numpy() 247 | z_tensor, u_tensor = z_tensor.squeeze().repeat(z_dim, 1), u_tensor.squeeze().repeat(z_dim, 1) 248 | z_tensor = z_tensor.detach().requires_grad_(True) 249 | u_tensor = u_tensor.detach().requires_grad_(True) 250 | z_next = dynamics(z_tensor, u_tensor)[0].mean 251 | grad_inp = torch.eye(z_dim) 252 | A, B = torch.autograd.grad(z_next, [z_tensor, u_tensor], [grad_inp, grad_inp]) 253 | return A.numpy(), B.numpy() 254 | 255 | 256 | def seq_jacobian(dynamics, z_seq, u_seq): 257 | """ 258 | compute the jacobian w.r.t each pair in the trajectory 259 | """ 260 | A_seq, B_seq = [], [] 261 | horizon = len(u_seq) 262 | for i in range(horizon): 263 | z, u = z_seq[i], u_seq[i] 264 | A, B = jacobian(dynamics, z, u) 265 | A_seq.append(A) 266 | B_seq.append(B) 267 | return A_seq, B_seq 268 | 269 | 270 | def save_traj(images, image_goal, gif_path, task): 271 | # save trajectory as gif file 272 | fig, aa = plt.subplots(1, 2) 273 | m1 = aa[0].matshow(images[0], cmap=plt.cm.gray, vmin=0.0, vmax=1.0) 274 | aa[0].set_title("Time step 0") 275 | aa[0].set_yticklabels([]) 276 | aa[0].set_xticklabels([]) 277 | m2 = aa[1].matshow(image_goal, cmap=plt.cm.gray, vmin=0.0, vmax=1.0) 278 | aa[1].set_title("goal") 279 | aa[1].set_yticklabels([]) 280 | aa[1].set_xticklabels([]) 281 | fig.tight_layout() 282 | 283 | def updatemat2(t): 284 | m1.set_data(images[t]) 285 | aa[0].set_title("Time step " + str(t)) 286 | m2.set_data(image_goal) 287 | return m1, m2 288 | 289 | frames = len(images) 290 | if task == "plane": 291 | fps = 2 292 | else: 293 | fps = 20 294 | 295 | anim = FuncAnimation(fig, updatemat2, frames=frames, interval=200, blit=True, repeat=True) 296 | Writer = writers["imagemagick"] # animation.writers.avail 297 | writer = Writer(fps=fps, metadata=dict(artist="Me"), bitrate=1800) 298 | 299 | anim.save(gif_path, writer=writer) 
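# ---------------------------------------------------------------------------
# Illustrative smoke test (added for exposition; not part of the original
# file). It runs one backward/forward iLQR sweep on a hand-made linear system
# and checks that the latent cost decreases. `_LinearDynamics` is a
# hypothetical stand-in for the learned PCC dynamics: it only mimics the
# interface used above, i.e. it is callable as dynamics(z, u)[0].mean and
# exposes the same `armotized` attribute that jacobian() reads.
if __name__ == "__main__":
    torch.set_default_dtype(torch.float64)

    class _Gaussian:
        def __init__(self, mean):
            self.mean = mean

    class _LinearDynamics:
        armotized = False  # attribute name matches the one used by jacobian()

        def __init__(self, A, B):
            self.A, self.B = torch.from_numpy(A), torch.from_numpy(B)

        def __call__(self, z, u):
            # z_next = z A^T + u B^T, wrapped to look like (distribution,)
            return (_Gaussian(z @ self.A.T + u @ self.B.T),)

    A = np.array([[1.0, 0.1], [0.0, 1.0]])  # double-integrator-like dynamics
    B = np.array([[0.0], [0.1]])
    dyn = _LinearDynamics(A, B)
    R_z, R_u = np.eye(2), 0.1 * np.eye(1)
    z_start, z_goal = np.array([1.0, 0.0]), np.zeros(2)
    u_seq = np.zeros((10, 1))  # start from an all-zero control sequence

    z_seq = compute_latent_traj(z_start, u_seq, dyn)
    A_seq, B_seq = seq_jacobian(dyn, z_seq, u_seq)
    k, K = backward(R_z, R_u, z_seq, u_seq, z_goal, A_seq, B_seq, inv_regulator=1e-5)
    z_new, u_new = forward(z_seq, u_seq, k, K, dyn, alpha=1.0)
    cost_before = latent_cost(R_z, R_u, np.array(z_seq), z_goal, u_seq)
    cost_after = latent_cost(R_z, R_u, z_new, z_goal, u_new)
    print("cost before: %.4f, after one sweep: %.4f" % (cost_before, cost_after))
    assert cost_after < cost_before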
--------------------------------------------------------------------------------
/latent_map_pendulum.py:
--------------------------------------------------------------------------------
import argparse
import json

import matplotlib.pyplot as plt
import numpy as np
import torch
from colour import Color
from mdp.pendulum_mdp import PendulumMDP
from pcc_model import PCC
from torchvision.transforms import ToTensor


red = Color("red")
blue = Color("blue")
num_angles = 100
num_each_angle = 20

np.random.seed(0)
torch.manual_seed(0)


def map_angle_color(num_angles, mdp):
    colors = list(red.range_to(blue, num_angles))
    colors_rgb = [color.rgb for color in colors]
    all_angles = np.linspace(start=mdp.angle_range[0], stop=mdp.angle_range[1], num=num_angles)
    angle_color_map = dict(zip(all_angles, colors_rgb))
    return angle_color_map, colors_rgb


def assign_latent_color(model, angle, mdp):
    # the same angle corresponds to multiple states -> multiple latent vectors
    # map an angle to multiple latent vectors corresponding to that angle
    angle_vels = np.linspace(
        start=mdp.angular_velocity_range[0], stop=mdp.angular_velocity_range[1], num=num_each_angle
    )
    all_z_for_angle = []
    for i in range(num_each_angle):
        ang_velocity = angle_vels[i]
        s = np.array([angle, ang_velocity])
        x = mdp.render(s).squeeze()
        # take a random action
        u = mdp.sample_random_action()
        s_next = mdp.transition_function(s, u)
        x_next = mdp.render(s_next).squeeze()
        # reverse order: the state we want to represent is x, not x_next
        x_with_history = np.vstack((x_next, x))
        x_with_history = ToTensor()(x_with_history).double()
        with torch.no_grad():
            z = model.encode(x_with_history.view(-1, x_with_history.shape[-1] * x_with_history.shape[-2])).mean
        all_z_for_angle.append(z.detach().squeeze().numpy())
    return all_z_for_angle


def show_latent_map(model, mdp):
    angle_color_map, colors_rgb = map_angle_color(num_angles, mdp)
    colors_list = []
    for color in colors_rgb:
        for i in range(num_each_angle):
            colors_list.append(list(color))
    all_z = []

    for angle in angle_color_map:
        all_z_for_angle = assign_latent_color(model, angle, mdp)
        all_z += all_z_for_angle
    all_z = np.array(all_z)

    z_min = np.min(all_z, axis=0)
    z_max = np.max(all_z, axis=0)
    all_z = 2 * (all_z - z_min) / (z_max - z_min) - 1.0
    all_z = all_z * 35

    ax = plt.axes(projection="3d")
    ax.set_xlim([-100, 100])
    ax.set_ylim([-100, 100])
    ax.set_zlim([-100, 100])
    xdata = all_z[:, 0]
    ydata = all_z[:, 1]
    zdata = all_z[:, 2]

    ax.scatter(xdata, ydata, zdata, c=colors_list, marker="o", s=10)
    plt.show()


def main(args):
    log_path = args.log_path
    epoch = args.epoch

    mdp = PendulumMDP()

    # load the specified model
    with open(log_path + "/settings", "r") as f:
        settings = json.load(f)
        armotized = settings["armotized"]
    model = PCC(armotized=armotized, x_dim=4608, z_dim=3, u_dim=1, env="pendulum")
    model.load_state_dict(torch.load(log_path + "/model_" + str(epoch), map_location="cpu"))
    model.eval()

    show_latent_map(model, mdp)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="visualize the pendulum latent map")

parser.add_argument("--log_path", required=True, type=str, help="path to trained model") 105 | parser.add_argument("--epoch", required=True, type=int, help="load model corresponding to this epoch") 106 | args = parser.parse_args() 107 | 108 | main(args) 109 | -------------------------------------------------------------------------------- /latent_map_planar.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | from colour import Color 8 | from mdp.plane_obstacles_mdp import PlanarObstaclesMDP 9 | from pcc_model import PCC 10 | from PIL import Image, ImageDraw 11 | 12 | 13 | blue = Color("blue") 14 | colors = list(blue.range_to(Color("red"), 40)) 15 | colors_rgb = [color.rgb for color in colors] 16 | start, end = 0, 40 17 | width, height = 40, 40 18 | 19 | 20 | # states corresponding to obstacles' positions 21 | def get_invalid_state(mdp): 22 | invalid_pos = [] 23 | for x in range(start, end): 24 | for y in range(start, end): 25 | s = [x, y] 26 | if not mdp.is_valid_state(np.array(s)): 27 | invalid_pos.append(s) 28 | return invalid_pos 29 | 30 | 31 | def color_gradient(): 32 | img = Image.new("RGB", (width, height), "#FFFFFF") 33 | draw = ImageDraw.Draw(img) 34 | 35 | for i, color in zip(range(start, end), colors_rgb): 36 | r1, g1, b1 = color[0] * 255.0, color[1] * 255.0, color[2] * 255.0 37 | draw.line((i, start, i, end), fill=(int(r1), int(g1), int(b1))) 38 | 39 | return img 40 | 41 | 42 | def get_true_map(mdp): 43 | invalid_pos = get_invalid_state(mdp) 44 | color_gradient_img = color_gradient() 45 | img_scaled = Image.new("RGB", (width * 10, height * 10), "#FFFFFF") 46 | draw = ImageDraw.Draw(img_scaled) 47 | for y in range(start, end): 48 | for x in range(start, end): 49 | if [y, x] in invalid_pos: 50 | continue 51 | else: 52 | x_scaled, y_scaled = x * 10, y * 10 53 | draw.ellipse( 54 | (x_scaled - 2, y_scaled - 2, x_scaled + 2, y_scaled + 2), fill=color_gradient_img.getpixel((x, y)) 55 | ) 56 | img_arr_scaled = np.array(img_scaled) / 255.0 57 | return img_arr_scaled 58 | 59 | 60 | def draw_latent_map(model, mdp): 61 | invalid_pos = get_invalid_state(mdp) 62 | img = color_gradient() 63 | # compute latent z 64 | all_z = [] 65 | for x in range(start, end): 66 | for y in range(start, end): 67 | s = np.array([x, y]) 68 | if [x, y] in invalid_pos: 69 | all_z.append(np.zeros(2)) 70 | else: 71 | with torch.no_grad(): 72 | obs = torch.Tensor(mdp.render(s)).unsqueeze(0).view(-1, 1600).double() 73 | if next(model.parameters()).is_cuda: 74 | obs = obs.cuda() 75 | mu = model.encode(obs).mean 76 | z = mu.squeeze().cpu().numpy() 77 | all_z.append(np.copy(z)) 78 | all_z = np.array(all_z) 79 | 80 | avg_norm_2 = np.mean(np.sum(all_z ** 2, axis=1)) 81 | print("avg norm 2: " + str(avg_norm_2)) 82 | 83 | # normalize and scale to plot 84 | z_min = np.min(all_z, axis=0) 85 | all_z = np.round(20 * (all_z - z_min) + 30).astype(np.int) 86 | 87 | # plot 88 | latent_map = {} 89 | i = 0 90 | for x in range(start, end): 91 | for y in range(start, end): 92 | latent_map[(x, y)] = all_z[i] 93 | i += 1 94 | 95 | img_latent = Image.new("RGB", (mdp.width * 10, mdp.height * 10), "#FFFFFF") 96 | draw = ImageDraw.Draw(img_latent) 97 | for k in latent_map: 98 | x, y = k 99 | if [x, y] in invalid_pos: 100 | continue 101 | else: 102 | x_scaled, y_scaled = latent_map[k][1], latent_map[k][0] 103 | draw.ellipse((x_scaled - 2, y_scaled - 2, x_scaled + 2, y_scaled + 2), 
fill=img.getpixel((y, x))) 104 | return img_latent 105 | 106 | 107 | def show_latent_map(model, mdp): 108 | true_map = get_true_map(mdp) 109 | latent_map = draw_latent_map(model, mdp) 110 | latent_map = np.array(latent_map) / 255.0 111 | 112 | f, axarr = plt.subplots(1, 2, figsize=(15, 15)) 113 | axarr[0].imshow(true_map) 114 | axarr[1].imshow(latent_map) 115 | plt.show() 116 | 117 | 118 | def main(args): 119 | log_path = args.log_path 120 | epoch = args.epoch 121 | 122 | mdp = PlanarObstaclesMDP() 123 | 124 | # load the specified model 125 | with open(log_path + "/settings", "r") as f: 126 | settings = json.load(f) 127 | armotized = settings["armotized"] 128 | model = PCC(armotized, 1600, 2, 2, "planar") 129 | model.load_state_dict(torch.load(log_path + "/model_" + str(epoch), map_location="cpu")) 130 | model.eval() 131 | 132 | show_latent_map(model, mdp) 133 | 134 | 135 | if __name__ == "__main__": 136 | parser = argparse.ArgumentParser(description="train pcc model") 137 | 138 | parser.add_argument("--log_path", required=True, type=str, help="path to trained model") 139 | parser.add_argument("--epoch", required=True, type=int, help="load model corresponding to this epoch") 140 | args = parser.parse_args() 141 | 142 | main(args) 143 | 144 | # from mdp.plane_obstacles_mdp import PlanarObstaclesMDP 145 | # from pcc_model import PCC 146 | # mdp = PlanarObstaclesMDP() 147 | # start = 0 148 | # end = 39 149 | # invalid_pos = get_invalid_state(mdp, start, end) 150 | # img_arr, img = random_gradient(start, end, mdp.width, mdp.height, invalid_pos) 151 | # get_true_map(mdp, start, end, mdp.width, mdp.height, img) 152 | 153 | # mdp = PlanarObstaclesMDP() 154 | # model = PCC(armotized=False, x_dim=1600, z_dim=2, u_dim=2, env = 'planar').cuda() 155 | # model.load_state_dict(torch.load('./new_mdp_result/planar/log_10/model_5000')) 156 | # latent_map = draw_latent_map(model, mdp) 157 | # latent_map.show() 158 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from networks import MultivariateNormalDiag 3 | from torch.distributions.kl import kl_divergence 4 | 5 | 6 | torch.set_default_dtype(torch.float64) 7 | 8 | 9 | def bernoulli(x, p): 10 | p = p.probs 11 | log_p_x = torch.sum(x * torch.log(1e-15 + p) + (1 - x) * torch.log(1e-15 + 1 - p), dim=-1) 12 | log_p_x = torch.mean(log_p_x) 13 | return log_p_x 14 | 15 | 16 | def KL(normal_1, normal_2): 17 | kl = kl_divergence(normal_1, normal_2) 18 | kl = torch.mean(kl) 19 | return kl 20 | 21 | 22 | def entropy(p): 23 | H = p.entropy() 24 | H = torch.mean(H) 25 | return H 26 | 27 | 28 | def gaussian(z, p): 29 | log_p_z = p.log_prob(z) 30 | log_p_z = torch.mean(log_p_z) 31 | return log_p_z 32 | 33 | 34 | def vae_bound(x, p_x, p_z): 35 | recon_loss = -bernoulli(x, p_x) 36 | regularization_loss = KL(p_z, MultivariateNormalDiag(torch.zeros_like(p_z.mean), torch.ones_like(p_z.stddev))) 37 | return recon_loss + regularization_loss 38 | 39 | 40 | def ae_loss(x, p_x): 41 | recon_loss = -bernoulli(x, p_x) 42 | return recon_loss 43 | 44 | 45 | def curvature(model, z, u, delta, armotized): 46 | z_alias = z.detach().requires_grad_(True) 47 | u_alias = u.detach().requires_grad_(True) 48 | eps_z = torch.normal(mean=torch.zeros_like(z), std=torch.empty_like(z).fill_(delta)) 49 | eps_u = torch.normal(mean=torch.zeros_like(u), std=torch.empty_like(u).fill_(delta)) 50 | 51 | z_bar = z_alias + eps_z 52 | u_bar = u_alias + eps_u 53 | 54 
54 | f_z_bar, A, B = model.transition(z_bar, u_bar)
55 | f_z_bar = f_z_bar.mean
56 | f_z, A, B = model.transition(z_alias, u_alias)
57 | f_z = f_z.mean
58 |
59 | z_dim, u_dim = z.size(1), u.size(1)
60 | if not armotized:
61 | _, B = get_jacobian(model.dynamics, z_alias, u_alias)
62 | else:
63 | A = A.view(-1, z_dim, z_dim)
64 | B = B.view(-1, z_dim, u_dim)
65 |
66 | (grad_z,) = torch.autograd.grad(f_z, z_alias, grad_outputs=eps_z, retain_graph=True, create_graph=True)
67 | grad_u = torch.bmm(B, eps_u.view(-1, u_dim, 1)).squeeze()
68 | taylor_error = f_z_bar - (grad_z + grad_u) - f_z  # error of the first-order Taylor expansion of f around (z, u)
69 | cur_loss = torch.mean(torch.sum(taylor_error.pow(2), dim=1))
70 | return cur_loss
71 |
72 |
73 | def get_jacobian(dynamics, batched_z, batched_u):
74 | """
75 | Compute the Jacobian of F(z, u) w.r.t. z, u.
76 | """
77 | batch_size = batched_z.size(0)
78 | z_dim = batched_z.size(-1)
79 | # u_dim = batched_u.size(-1)
80 |
81 | z, u = batched_z.unsqueeze(1), batched_u.unsqueeze(1) # batch_size, 1, input_dim
82 | z, u = z.repeat(1, z_dim, 1), u.repeat(1, z_dim, 1) # batch_size, output_dim, input_dim
83 | z_next = dynamics(z, u)[0].mean
84 | grad_inp = torch.eye(z_dim).reshape(1, z_dim, z_dim).repeat(batch_size, 1, 1).cuda()
85 | all_A, all_B = torch.autograd.grad(z_next, [z, u], [grad_inp, grad_inp], create_graph=True, retain_graph=True)
86 | return all_A, all_B
87 |
-------------------------------------------------------------------------------- /mdp/cartpole_mdp.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | from mdp.common import StateIndex, wrap
3 | from mdp.pole_base import PoleBase
4 | from PIL import Image, ImageDraw
5 | from scipy.integrate import solve_ivp
6 |
7 |
8 | class CartPoleMDP(PoleBase):
9 | goal_range = [-np.pi / 10, np.pi / 10]
10 |
11 | # state range
12 | angle_range = [-np.pi, np.pi]
13 | angular_velocity_range = [-2 * np.pi, 2 * np.pi]
14 | position_range = [-2.4, 2.4]
15 | velocity_range = [-6.0, 6.0]
16 |
17 | # sampling range
18 | angle_samp_range = 2 * goal_range  # NOTE: goal_range is a Python list, so this repeats it rather than scaling it; only entries [0] and [1] are used below
19 |
20 | # action range
21 | action_dim = 1
22 | action_range = np.array([-10.0, 10.0])
23 |
24 | def __init__(self, width=80, height=80, frequency=50, noise=0.0, render_width=6):
25 | """
26 | Args:
27 | width: width of the rendered image.
28 | height: height of the rendered image.
29 | frequency: the simulator frequency, i.e., the number of steps in 1 second.
30 | noise: noise level
31 | render_width: width of the pole in the rendered image.
32 | """
33 | self.width = width
34 | self.height = height
35 | self.time_interval = 1 / frequency
36 | self.noise = noise
37 |
38 | self.render_width = render_width
39 | self.render_length = height * 0.65
40 | self.cart_render_size = (width / 10.0, height / 20.0)
41 |
42 | super(CartPoleMDP, self).__init__()
43 |
44 | def take_step(self, s, u):
45 | # clip the action
46 | u = np.clip(u, self.action_range[0], self.action_range[1])
47 |
48 | # concatenate s and u to pass through ds_dt
49 | s_aug = np.append(s, u)
50 |
51 | # solve the differential equation to compute the next state
52 | s_next = solve_ivp(self.ds_dt, (0.0, self.time_interval), s_aug).y[0:4, -1] # last index is the action applied
53 |
54 | # project state to the valid space.
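# [Editor's note] A quick, hedged illustration of the projection below; the numbers
# are illustrative only and not part of the original file. `wrap` (from mdp/common.py)
# folds the angle back into [-pi, pi] by repeatedly adding or subtracting the range
# width 2*pi, so an integration step that overshoots the upper boundary re-enters
# from below:
#
#     wrap(np.pi + 0.5, -np.pi, np.pi)  # = np.pi + 0.5 - 2 * np.pi, roughly -2.64
#
# The velocity components, by contrast, are simply clipped to their allowed ranges.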
55 | s_next[StateIndex.THETA] = wrap(s_next[StateIndex.THETA], self.angle_range[0], self.angle_range[1]) 56 | s_next[StateIndex.THETA_DOT] = np.clip( 57 | s_next[StateIndex.THETA_DOT], self.angular_velocity_range[0], self.angular_velocity_range[1] 58 | ) 59 | s_next[StateIndex.X_DOT] = np.clip(s_next[StateIndex.X_DOT], self.velocity_range[0], self.velocity_range[1]) 60 | 61 | return s_next 62 | 63 | def ds_dt(self, t, s_augmented): 64 | mass_combined = self.cart_mass + self.pend_mass 65 | 66 | theta = s_augmented[StateIndex.THETA] 67 | theta_dot = s_augmented[StateIndex.THETA_DOT] 68 | x_dot = s_augmented[StateIndex.X_DOT] 69 | force = s_augmented[StateIndex.CARTPOLE_ACTION] 70 | 71 | sin_theta = np.sin(theta) 72 | cos_theta = np.cos(theta) 73 | calc_help = force + self.pend_mass * self.length * theta_dot ** 2 * sin_theta 74 | 75 | # derivative of theta_dot 76 | theta_double_dot_num = mass_combined * self.earth_gravity * sin_theta - cos_theta * calc_help 77 | theta_double_dot_denum = 4.0 / 3 * mass_combined * self.length - self.pend_mass * self.length * cos_theta ** 2 78 | theta_double_dot = theta_double_dot_num / theta_double_dot_denum 79 | 80 | # derivative of x_dot 81 | x_double_dot_num = calc_help - self.pend_mass * self.length * theta_double_dot * cos_theta 82 | x_double_dot = x_double_dot_num / mass_combined 83 | 84 | return np.array([theta_dot, theta_double_dot, x_dot, x_double_dot, 0.0]) 85 | 86 | def render(self, s): 87 | # black background. 88 | im = Image.new("L", (self.width, self.height)) 89 | draw = ImageDraw.Draw(im) 90 | 91 | draw.rectangle((0, 0, self.width, self.height), fill=0) 92 | 93 | # cart location. 94 | x_center_image = im.size[0] / 2.0 95 | y_center_cart = im.size[1] - 2 * self.cart_render_size[1] - 2 96 | x_center_cart = x_center_image + (s[StateIndex.X] / self.position_range[1]) * ( 97 | self.width / 2.0 - 1.0 * self.cart_render_size[0] 98 | ) 99 | 100 | # pole location. 101 | x_pole_end = x_center_cart + np.sin([s[StateIndex.THETA]]) * self.render_length 102 | y_pole_end = y_center_cart - np.cos([s[StateIndex.THETA]]) * self.render_length 103 | 104 | # draw cart. 105 | draw.rectangle( 106 | ( 107 | x_center_cart - self.cart_render_size[0], 108 | y_center_cart - self.cart_render_size[1], 109 | x_center_cart + self.cart_render_size[0], 110 | y_center_cart + self.cart_render_size[1], 111 | ), 112 | fill=255, 113 | ) 114 | 115 | # draw pole. 
116 | draw.line((x_center_cart, y_center_cart, x_pole_end, y_pole_end), width=self.render_width, fill=255)
117 |
118 | return np.expand_dims(np.asarray(im) / 255.0, axis=-1)
119 |
120 | def is_fail(self, s):
121 | """Check whether the state is a failure: the pole left the goal angle range or the cart left the track."""
122 | angle = s[StateIndex.THETA]
123 | position = s[StateIndex.X]
124 | return not (
125 | (self.goal_range[0] < angle < self.goal_range[1])
126 | and (self.position_range[0] < position < self.position_range[1])
127 | )
128 |
129 | def sample_random_state(self):
130 | angle = np.random.uniform(self.angle_samp_range[0], self.angle_samp_range[1])
131 | angle_rate = np.random.uniform(self.angular_velocity_range[0], self.angular_velocity_range[1])
132 | pos = np.random.uniform(self.position_range[0], self.position_range[1])
133 | vel = np.random.uniform(self.velocity_range[0], self.velocity_range[1])
134 | true_state = np.array([angle, angle_rate, pos, vel])
135 | return true_state
136 |
-------------------------------------------------------------------------------- /mdp/common.py: --------------------------------------------------------------------------------
1 | """Common functions for the MDP package."""
2 |
3 | from __future__ import absolute_import, division, print_function
4 |
5 |
6 | def wrap(x, low, high):
7 | """Wraps data between low and high boundaries."""
8 | diff = high - low
9 | while x > high:
10 | x -= diff
11 | while x < low:
12 | x += diff
13 | return x
14 |
15 |
16 | class StateIndex(object):
17 | """Flexible way to index states in the pole domains (pendulum, cartpole, three-pole).
18 |
19 | This class enumerates the different indices used when indexing the state.
20 | e.g. ``s[StateIndex.THETA]`` is guaranteed to return the angle state.
21 | """
22 |
23 | THETA, THETA_DOT = 0, 1
24 | X, X_DOT = 2, 3
25 | PEND_ACTION = 2
26 | CARTPOLE_ACTION = 4
27 |
28 | THETA_1, THETA_2, THETA_3 = 0, 2, 4
29 | THETA_1_DOT, THETA_2_DOT, THETA_3_DOT = 1, 3, 5
30 | TORQUE_3_1, TORQUE_3_2, TORQUE_3_3 = 6, 7, 8
31 |
-------------------------------------------------------------------------------- /mdp/pendulum_mdp.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | from mdp.common import StateIndex, wrap
3 | from mdp.pole_base import PoleBase
4 | from PIL import Image, ImageDraw
5 | from scipy.integrate import solve_ivp
6 |
7 |
8 | class PendulumMDP(PoleBase):
9 | # goal range
10 | goal_range = [-np.pi / 6, np.pi / 6]
11 |
12 | # state range
13 | angle_range = [-np.pi, np.pi]
14 | angular_velocity_range = [-3 * np.pi, 3 * np.pi]
15 |
16 | action_dim = 1
17 |
18 | def __init__(self, width=48, height=48, frequency=50, noise=0.0, torque=1.0, render_width=4):
19 | """
20 | Args:
21 | width: width of the rendered image.
22 | height: height of the rendered image.
23 | frequency: the simulator frequency, i.e., the number of steps in 1 second.
24 | noise: noise level
25 | torque: the maximal torque which can be applied
26 | render_width: width of the pendulum in the rendered image.
27 | """ 28 | self.width = width 29 | self.height = height 30 | self.time_interval = 1.0 / frequency 31 | self.noise = noise 32 | self.action_range = np.array([-torque, torque]) 33 | 34 | self.render_width = render_width 35 | self.render_length = (width / 2) - 2 36 | 37 | super(PendulumMDP, self).__init__() 38 | 39 | def take_step(self, s, u): 40 | # clip the action 41 | u = np.clip(u, self.action_range[0], self.action_range[1]) 42 | 43 | # concatenate s and u to pass through ds_dt 44 | s_aug = np.append(s, u) 45 | 46 | # solve the differientable equation to compute next state 47 | s_next = solve_ivp(self.ds_dt, (0.0, self.time_interval), s_aug).y[0:2, -1] # last index is the action applied 48 | 49 | # project state to the valid range 50 | s_next[StateIndex.THETA] = wrap(s_next[StateIndex.THETA], self.angle_range[0], self.angle_range[1]) 51 | s_next[StateIndex.THETA_DOT] = np.clip( 52 | s_next[StateIndex.THETA_DOT], self.angular_velocity_range[0], self.angular_velocity_range[1] 53 | ) 54 | 55 | return s_next 56 | 57 | def ds_dt(self, t, s_augmented): 58 | theta = s_augmented[StateIndex.THETA] 59 | theta_dot = s_augmented[StateIndex.THETA_DOT] 60 | torque = s_augmented[StateIndex.PEND_ACTION] 61 | 62 | # theta is w.r.t the upside vertical position, which is = pi - theta in tedrake's note 63 | sine = np.sin(np.pi - theta) 64 | theta_prime_num = self.pend_mass * self.earth_gravity * self.length * sine - torque 65 | theta_prime_denum = 1.0 / 3.0 * self.pend_mass * self.length ** 2 # the moment of inertia 66 | theta_double_dot = theta_prime_num / theta_prime_denum 67 | 68 | return np.array([theta_dot, theta_double_dot, 0.0]) 69 | 70 | def render(self, s): 71 | im = Image.new("L", (self.width, self.height)) 72 | draw = ImageDraw.Draw(im) 73 | # black background 74 | draw.rectangle((0, 0, self.width, self.height), fill=0) 75 | 76 | # pendulum location. 
77 | x_center = im.size[0] / 2.0
78 | y_center = im.size[1] / 2.0
79 | x_end = x_center + np.sin(s[0]) * self.render_length
80 | y_end = y_center - np.cos(s[0]) * self.render_length
81 |
82 | # white pendulum
83 | draw.line((x_center, y_center, x_end, y_end), width=self.render_width, fill=255)
84 |
85 | return np.expand_dims(np.asarray(im) / 255.0, axis=-1)
86 |
87 | def is_fail(self, s):
88 | return False
89 |
90 | def sample_random_state(self):
91 | angle = np.random.uniform(self.angle_range[0], self.angle_range[1])
92 | angle_rate = np.random.uniform(self.angular_velocity_range[0], self.angular_velocity_range[1])
93 | true_state = np.array([angle, angle_rate])
94 | return true_state
95 |
-------------------------------------------------------------------------------- /mdp/plane_obstacles_mdp.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image, ImageDraw
3 |
4 |
5 | class PlanarObstaclesMDP(object):
6 | width = 40
7 | height = 40
8 |
9 | obstacles = np.array([[20.5, 5.5], [20.5, 13.5], [20.5, 27.5], [20.5, 35.5], [10.5, 20.5], [30.5, 20.5]])
10 | obstacles_r = 2.5 # radius of the obstacles when rendering
11 | half_agent_size = 1.5 # robot half-width
12 |
13 | position_range = np.array([half_agent_size, width - half_agent_size])
14 |
15 | action_dim = 2
16 |
17 | def __init__(self, rw_rendered=1, max_step=3, goal=[37, 37], goal_thres=2, noise=0):
18 | self.rw_rendered = rw_rendered
19 | self.max_step = max_step
20 | self.action_range = np.array([-max_step, max_step])
21 | self.goal = goal
22 | self.goal_thres = goal_thres
23 | self.noise = noise
24 | super(PlanarObstaclesMDP, self).__init__()
25 |
26 | def is_valid_state(self, s):
27 | # check if the agent has moved off the map
28 | if np.any(s < self.position_range[0]) or np.any(s > self.position_range[1]):
29 | return False
30 |
31 | # check if the agent crosses any obstacle (the obstacle is inside the agent)
32 | top, bot = s[0] - self.half_agent_size, s[0] + self.half_agent_size
33 | left, right = s[1] - self.half_agent_size, s[1] + self.half_agent_size
34 | for obs in self.obstacles:
35 | if top <= obs[0] <= bot and left <= obs[1] <= right:
36 | return False
37 | return True
38 |
39 | def take_step(self, s, u, anneal_ratio=0.9): # compute the next state given the current state and action (anneal_ratio is unused)
40 | u = np.clip(u, self.action_range[0], self.action_range[1])
41 |
42 | s_next = np.clip(s + u, self.position_range[0], self.position_range[1])
43 | if not self.is_valid_state(s_next):
44 | return s
45 | return s_next
46 |
47 | def transition_function(self, s, u): # compute next state and add noise
48 | s_next = self.take_step(s, u)
49 | # add Gaussian noise, then clip back into the valid position range (no resampling is done)
50 | sample_noise = self.noise * np.random.randn(*s_next.shape)
51 | return np.clip(s_next + sample_noise, self.position_range[0], self.position_range[1])
52 |
53 | def render(self, s):
54 | top, bottom, left, right = self.get_pixel_location(s)
55 | x = self.generate_env()
56 | x[top:bottom, left:right] = 1.0 # robot is white on black background
57 | return x
58 |
59 | def get_pixel_location(self, s):
60 | # return the location of agent when rendered
61 | center_x, center_y = int(round(s[0])), int(round(s[1]))
62 | top = center_x - self.rw_rendered
63 | bottom = center_x + self.rw_rendered
64 | left = center_y - self.rw_rendered
65 | right = center_y + self.rw_rendered
66 | return top, bottom, left, right
67 |
68 | def generate_env(self):
69 | """
70 | return the image with 6 obstacles
71 | """
72 | img_arr = np.zeros(shape=(self.width, self.height))
73 |
74 | img_env = Image.fromarray(img_arr)
75 | draw = ImageDraw.Draw(img_env)
76 | for y, x in self.obstacles:
77 | draw.ellipse(
78 | (
79 | int(x) - int(self.obstacles_r),
80 | int(y) - int(self.obstacles_r),
81 | int(x) + int(self.obstacles_r),
82 | int(y) + int(self.obstacles_r),
83 | ),
84 | fill=255,
85 | )
86 | img_env = img_env.convert("L")
87 |
88 | img_arr = np.array(img_env) / 255.0
89 | return img_arr
90 |
91 | def is_goal(self, s):
92 | return np.sqrt(np.sum((s - self.goal) ** 2)) <= self.goal_thres
93 |
94 | def is_fail(self, s):
95 | return False
96 |
97 | def reward_function(self, s):
98 | if self.is_goal(s):
99 | reward = 1
100 | else:
101 | reward = 0
102 | return reward
103 |
104 | def sample_random_state(self):
105 | while True:
106 | s = np.random.uniform(self.half_agent_size, self.width - self.half_agent_size, size=2)
107 | if self.is_valid_state(s):
108 | return s
109 |
110 | def is_low_error(self, u, epsilon=0.1):
111 | rounded_u = np.round(u)
112 | diff = np.abs(u - rounded_u)
113 | return np.all(diff <= epsilon)
114 |
115 | def is_valid_action(self, s, u):
116 | return self.is_low_error(u) and self.is_valid_state(s + u)
117 |
118 | def sample_valid_random_action(self, s):
119 | while True:
120 | u = np.random.uniform(self.action_range[0], self.action_range[1], size=self.action_dim)
121 | if self.is_valid_action(s, u):
122 | return u
123 |
124 | def sample_random_action(self):
125 | return np.random.uniform(self.action_range[0], self.action_range[1], size=self.action_dim)
126 |
127 | def sample_extreme_action(self):
128 | x_direction = np.random.choice([self.action_range[0], self.action_range[1]])
129 | y_direction = np.random.choice([self.action_range[0], self.action_range[1]])
130 | return np.array([x_direction, y_direction])
131 |
-------------------------------------------------------------------------------- /mdp/pole_base.py: --------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | import numpy as np
5 | from mdp.common import StateIndex
6 | from numpy import pi
7 |
8 |
9 | root_path = str(Path(os.path.dirname(os.path.abspath(__file__))).parent)
10 | os.sys.path.append(root_path)
11 |
12 |
13 | class PoleBase(object):
14 | """
15 | The base class storing common attributes for the pendulum and cartpole tasks.
16 | Basically, the MDP works as follows:
17 | - Define a function ds_dt, which computes the derivative of the state w.r.t. time t (following Tedrake's textbook)
18 | - The time interval between 2 consecutive time steps is determined by frequency, time_interval = 1./frequency
19 | - Use scipy's solve_ivp to solve the differential equation and compute the next state
20 | """
21 |
22 | # environment specifications
23 | earth_gravity = 9.81
24 | pend_mass = 0.1
25 | cart_mass = 1.0
26 | length = 0.5
27 | # reward if close to goal
28 | goal_reward = 1
29 |
30 | def __init__(self):
31 | assert np.all(
32 | 2 * pi / self.time_interval > np.abs(self.angular_velocity_range)
33 | ), """
34 | WARNING: Your step size is too large or the angular rate limit is too large.
35 | This could lead to a situation in which the pole is at the same state in 2
36 | consecutive time steps (i.e., the pole completes a full revolution within one step).
37 | """ 38 | 39 | def take_step(self, s, u): # compute the next state given the current state and action 40 | pass 41 | 42 | def ds_dt(self, t, s): # derivative of s w.r.t t 43 | pass 44 | 45 | def transition_function(self, s, u): # compute next state and add noise 46 | s_next = self.take_step(s, u) 47 | # add noise 48 | s_next += self.noise * np.random.randn(*s_next.shape) 49 | return s_next 50 | 51 | def render(self, s): 52 | pass 53 | 54 | def is_goal(self, s): 55 | """Check if the pendulum is in goal region""" 56 | angle = s[StateIndex.THETA] 57 | return self.goal_range[0] <= angle <= self.goal_range[1] 58 | 59 | def is_fail(self, s): 60 | pass 61 | 62 | def reward_function(self, s): 63 | """Reward function.""" 64 | return int(self.is_goal(s)) * self.goal_reward 65 | 66 | def sample_random_state(self): 67 | pass 68 | 69 | def sample_random_action(self): 70 | """Sample a random action from action range.""" 71 | return np.array([np.random.uniform(self.action_range[0], self.action_range[1])]) 72 | 73 | def sample_extreme_action(self): 74 | """Sample a random extreme action from action range.""" 75 | return np.array([np.random.choice([self.action_range[0], self.action_range[1]])]) 76 | -------------------------------------------------------------------------------- /mdp/three_pole_mdp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mdp.common import StateIndex, wrap 3 | from mdp.pole_base import PoleBase 4 | from PIL import Image, ImageDraw 5 | from scipy.integrate import solve_ivp 6 | 7 | 8 | class ThreePoleMDP(PoleBase): 9 | # goal_range 10 | goal_range = [-np.pi / 6, np.pi / 6] 11 | 12 | # state range 13 | angle_1_range = [-np.pi, np.pi] 14 | angle_2_range = [-2.0 * np.pi / 3, 2.0 * np.pi / 3] 15 | angle_3_range = [-np.pi / 3.0, np.pi / 3.0] 16 | angular_velocity_range = [-0.5 * np.pi, 0.5 * np.pi] 17 | 18 | mass_pend_1 = 0.1 19 | mass_pend_2 = 0.1 20 | mass_pend_3 = 0.1 21 | length_1 = 0.5 22 | length_2 = 0.5 23 | length_3 = 0.5 24 | 25 | action_dim = 3 26 | 27 | def __init__(self, width=80, height=80, frequency=50, noise=0.0, torque=1.0, line_width=5): 28 | """ 29 | Args: 30 | width: width of the rendered image. 31 | height: height of the rendered image. 32 | frequency: the simulator frequency, i.e., the number of steps in 1 second. 33 | noise: noise level 34 | torque: the maximal torque which can be applied 35 | render_width: width of the line in the rendered image. 36 | """ 37 | self.width = width 38 | self.height = height 39 | self.time_interval = 1.0 / frequency 40 | self.noise = noise 41 | self.action_range = np.array([-torque, torque]) 42 | 43 | self.line_width = line_width 44 | self.visual_length = (width / 6) - 2 45 | 46 | super(ThreePoleMDP, self).__init__() 47 | 48 | def take_step(self, s, a): 49 | """Computes the next state from the current state and the action.""" 50 | torque_1_action = a[0] 51 | # Clip the action to valid values. 52 | torque_1_action = np.clip(torque_1_action, self.action_range[0], self.action_range[1]) 53 | 54 | torque_2_action = a[1] 55 | # Clip the action to valid values. 56 | torque_2_action = np.clip(torque_2_action, self.action_range[0], self.action_range[1]) 57 | 58 | torque_3_action = a[2] 59 | # Clip the action to valid values. 60 | torque_3_action = np.clip(torque_3_action, self.action_range[0], self.action_range[1]) 61 | 62 | # Add the action to the state so it can be passed to _dsdt. 
63 | s_aug = np.append(s, np.array([torque_1_action, torque_2_action, torque_3_action])) 64 | 65 | # Compute next state. 66 | # Type of integration and integration step. 67 | dt_in = self.time_interval 68 | if ( 69 | self.goal_range[0] < s[StateIndex.THETA_1] < self.goal_range[1] 70 | and self.goal_range[0] < s[StateIndex.THETA_1] + s[StateIndex.THETA_2] < self.goal_range[1] 71 | and self.goal_range[0] 72 | < s[StateIndex.THETA_1] + s[StateIndex.THETA_2] + s[StateIndex.THETA_3] 73 | < self.goal_range[1] 74 | ): 75 | dt_in = self.time_interval 76 | else: 77 | dt_in = self.time_interval * 2.5 78 | 79 | ns = solve_ivp(self.ds_dt, (0.0, dt_in), s_aug).y[0:6, -1] 80 | 81 | # Project variables to valid space. 82 | theta_1 = wrap(ns[StateIndex.THETA_1], self.angle_1_range[0], self.angle_1_range[1]) 83 | ns[StateIndex.THETA_1] = np.clip(theta_1, self.angle_1_range[0], self.angle_1_range[1]) 84 | ns[StateIndex.THETA_1_DOT] = np.clip( 85 | ns[StateIndex.THETA_1_DOT], self.angular_velocity_range[0], self.angular_velocity_range[1] 86 | ) 87 | 88 | theta_2 = wrap(ns[StateIndex.THETA_2], self.angle_2_range[0], self.angle_2_range[1]) 89 | ns[StateIndex.THETA_2] = np.clip(theta_2, self.angle_2_range[0], self.angle_2_range[1]) 90 | ns[StateIndex.THETA_2_DOT] = np.clip( 91 | ns[StateIndex.THETA_2_DOT], self.angular_velocity_range[0], self.angular_velocity_range[1] 92 | ) 93 | 94 | theta_3 = wrap(ns[StateIndex.THETA_3], self.angle_3_range[0], self.angle_3_range[1]) 95 | ns[StateIndex.THETA_3] = np.clip(theta_3, self.angle_3_range[0], self.angle_3_range[1]) 96 | ns[StateIndex.THETA_3_DOT] = np.clip( 97 | ns[StateIndex.THETA_3_DOT], self.angular_velocity_range[0], self.angular_velocity_range[1] 98 | ) 99 | 100 | return ns 101 | 102 | def ds_dt(self, t, s_augmented): 103 | """Calculates derivatives at a given state.""" 104 | # Unused. 105 | del t 106 | 107 | # Extracting current state and action. 108 | theta_1 = s_augmented[StateIndex.THETA_1] 109 | theta_1_dot = s_augmented[StateIndex.THETA_1_DOT] 110 | theta_2 = s_augmented[StateIndex.THETA_2] 111 | theta_2_dot = s_augmented[StateIndex.THETA_2_DOT] 112 | theta_3 = s_augmented[StateIndex.THETA_3] 113 | theta_3_dot = s_augmented[StateIndex.THETA_3_DOT] 114 | 115 | theta_dot = np.array([theta_1_dot, theta_2_dot, theta_3_dot]) 116 | 117 | torque_1 = s_augmented[StateIndex.TORQUE_3_1] 118 | torque_2 = s_augmented[StateIndex.TORQUE_3_2] 119 | torque_3 = s_augmented[StateIndex.TORQUE_3_3] 120 | torque = np.array([torque_1, torque_2, torque_3]) 121 | 122 | # Useful mid-calculation. 
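# [Editor's note] A reading aid for the block below, assuming the standard
# manipulator form from Tedrake's notes: the code assembles the terms of
#
#     H(theta) * theta_ddot + C(theta, theta_dot) * theta_dot + G(theta) = B * tau
#
# where h_mat is the symmetric mass matrix, c_mat collects the Coriolis and
# centrifugal terms, g_mat is the gravity vector and b_mat the input map. The
# regularized pseudo-inverse solve at the end recovers theta_ddot, with signs
# following the clockwise angle convention noted just below.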
123 | # NOTE: the angle here is clock-wise
124 | # which is -\theta from tedrake's reference
125 |
126 | sine_1 = np.sin(-theta_1)
127 | sine_2 = np.sin(-theta_2)
128 | sine_3 = np.sin(-theta_3)
129 | sine_2_3 = np.sin(-(theta_2 + theta_3))
130 |
131 | # cosine_1 = np.cos(np.pi - theta_1)
132 | cosine_2 = np.cos(-theta_2)
133 | cosine_3 = np.cos(-theta_3)
134 | cosine_2_3 = np.cos(-(theta_2 + theta_3))
135 |
136 | sine_1_2 = np.sin(-(theta_1 + theta_2))
137 | # cosine_1_2 = np.cos(np.pi - (theta_1 + theta_2))
138 | sine_1_2_3 = np.sin(-(theta_1 + theta_2 + theta_3))
139 |
140 | i_1 = 1.0 / 3.0 * self.mass_pend_1 * self.length_1 ** 2
141 | i_2 = 1.0 / 3.0 * self.mass_pend_2 * self.length_2 ** 2
142 | i_3 = 1.0 / 3.0 * self.mass_pend_3 * self.length_3 ** 2
143 |
144 | length_c1 = self.length_1 / 2.0
145 | length_c2 = self.length_2 / 2.0
146 | length_c3 = self.length_3 / 2.0
147 |
148 | # inertial constants of the manipulator equation; each link is modeled as a uniform rod (moment of inertia i_k, center of mass at length_ck), not a point mass
149 | alpha_1 = i_1 + (self.mass_pend_2 + self.mass_pend_3) * self.length_1 ** 2
150 | alpha_2 = i_2 + self.mass_pend_3 * self.length_2 ** 2
151 | alpha_3 = (self.mass_pend_2 * length_c2 + self.mass_pend_3 * self.length_2) * self.length_1
152 | alpha_4 = i_3
153 | alpha_5 = self.mass_pend_3 * self.length_1 * length_c3
154 | alpha_6 = self.mass_pend_3 * self.length_2 * length_c3
155 |
156 | h_11 = alpha_1 + alpha_2 + alpha_4 + 2 * alpha_5 * cosine_2_3 + 2 * alpha_3 * cosine_2 + 2 * alpha_6 * cosine_3
157 | h_12 = alpha_2 + alpha_4 + alpha_3 * cosine_2 + alpha_5 * cosine_2_3 + 2 * alpha_6 * cosine_3
158 | h_13 = alpha_4 + alpha_5 * cosine_2_3 + alpha_6 * cosine_3
159 |
160 | h_21 = h_12
161 | h_22 = alpha_2 + alpha_4 + 2 * alpha_6 * cosine_3
162 | h_23 = alpha_4 + alpha_6 * cosine_3
163 |
164 | h_31 = h_13
165 | h_32 = h_23
166 | h_33 = alpha_4
167 | h_mat = np.array([[h_11, h_12, h_13], [h_21, h_22, h_23], [h_31, h_32, h_33]])
168 |
169 | beta_1 = (
170 | self.mass_pend_1 * length_c1 + self.mass_pend_2 * self.length_1 + self.mass_pend_3 * self.length_1
171 | ) * self.earth_gravity
172 | beta_2 = (self.mass_pend_2 * length_c2 + self.mass_pend_3 * self.length_2) * self.earth_gravity
173 | beta_3 = self.mass_pend_3 * self.earth_gravity * length_c3
174 |
175 | c_11 = (
176 | alpha_5 * (theta_2_dot + theta_3_dot) * sine_2_3
177 | + alpha_3 * theta_2_dot * sine_2
178 | + alpha_6 * theta_3_dot * sine_3
179 | )
180 | c_12 = (
181 | alpha_5 * (theta_1_dot + theta_2_dot + theta_3_dot) * sine_2_3
182 | + alpha_3 * (theta_1_dot + theta_2_dot) * sine_2
183 | + alpha_6 * theta_3_dot * sine_3
184 | )
185 | c_13 = (theta_1_dot + theta_2_dot + theta_3_dot) * (alpha_5 * sine_2_3 + alpha_6 * sine_3)
186 | c_21 = -alpha_5 * theta_1_dot * sine_2_3 - alpha_3 * theta_1_dot * sine_2 + alpha_6 * theta_3_dot * sine_3
187 | c_22 = alpha_6 * theta_3_dot * sine_3
188 | c_23 = alpha_6 * (theta_1_dot + theta_2_dot + theta_3_dot) * sine_3
189 | c_31 = -alpha_5 * theta_1_dot * sine_2_3 - alpha_6 * (theta_1_dot + theta_2_dot) * sine_3
190 | c_32 = -alpha_6 * (theta_1_dot + theta_2_dot) * sine_3
191 | c_33 = 0.0
192 | c_mat = np.array([[c_11, c_12, c_13], [c_21, c_22, c_23], [c_31, c_32, c_33]])
193 |
194 | g_1 = -beta_1 * sine_1 - beta_2 * sine_1_2 - beta_3 * sine_1_2_3
195 | g_2 = -beta_2 * sine_1_2 - beta_3 * sine_1_2_3
196 | g_3 = -beta_3 * sine_1_2_3
197 | g_mat = np.array([g_1, g_2, g_3])
198 |
199 | b_mat = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
200 |
201 | theta_double_dot = -np.linalg.pinv(h_mat + 1e-6 * np.eye(len(h_mat))).dot(
202 | b_mat.dot(torque) - c_mat.dot(-theta_dot) - g_mat
203 | )
204 |
205 | theta_1_double_dot = theta_double_dot[0]
206 | theta_2_double_dot = theta_double_dot[1]
207 | theta_3_double_dot = theta_double_dot[2]
208 |
209 | # Derivatives.
210 | return np.array(
211 | [
212 | theta_1_dot,
213 | theta_1_double_dot,
214 | theta_2_dot,
215 | theta_2_double_dot,
216 | theta_3_dot,
217 | theta_3_double_dot,
218 | 0.0,
219 | 0.0,
220 | 0.0,
221 | ]
222 | )
223 |
224 | def render(self, s):
225 | im = Image.new("L", (self.width, self.height))
226 | draw = ImageDraw.Draw(im)
227 | # Draw background.
228 | draw.rectangle((0, 0, self.width, self.height), fill=0)
229 |
230 | # Pole 1 location.
231 | xstart_1 = im.size[0] / 2.0
232 | ystart_1 = im.size[1] / 2.0
233 | xend_1 = xstart_1 + np.sin(s[0]) * self.visual_length
234 | yend_1 = ystart_1 - np.cos(s[0]) * self.visual_length
235 |
236 | # Draw pole 1.
237 | draw.line((xstart_1, ystart_1, xend_1, yend_1), width=self.line_width, fill=255)
238 |
239 | # Pole 2 location.
240 | xstart_2 = xend_1
241 | ystart_2 = yend_1
242 | xend_2 = xstart_2 + np.sin(s[0] + s[2]) * self.visual_length
243 | yend_2 = ystart_2 - np.cos(s[0] + s[2]) * self.visual_length
244 |
245 | # Draw pole 2.
246 | draw.line((xstart_2, ystart_2, xend_2, yend_2), width=self.line_width, fill=255)
247 |
248 | # Pole 3 location.
249 | xstart_3 = xend_2
250 | ystart_3 = yend_2
251 | xend_3 = xstart_3 + np.sin(s[0] + s[2] + s[4]) * self.visual_length
252 | yend_3 = ystart_3 - np.cos(s[0] + s[2] + s[4]) * self.visual_length
253 |
254 | # Draw pole 3.
255 | draw.line((xstart_3, ystart_3, xend_3, yend_3), width=self.line_width, fill=255)
256 |
257 | return np.expand_dims(np.asarray(im) / 255.0, axis=-1)
258 |
259 | def is_fail(self, s):
260 | """Indicates whether the state results in failure."""
261 | # Unused.
262 | del s
263 | return False
264 |
265 | def is_goal(self, s):
266 | """Indicates whether the state achieves the goal."""
267 | angle_1 = s[StateIndex.THETA_1]
268 | angle_2 = s[StateIndex.THETA_2]
269 | angle_3 = s[StateIndex.THETA_3]
270 | if (
271 | self.goal_range[0] < angle_1 < self.goal_range[1]
272 | and self.goal_range[0] < angle_1 + angle_2 < self.goal_range[1]
273 | and self.goal_range[0] < angle_1 + angle_2 + angle_3 < self.goal_range[1]
274 | ):
275 | return True
276 | else:
277 | return False
278 |
279 | def sample_random_action(self):
280 | """Sample a random action from available force."""
281 | return np.atleast_1d(np.random.uniform(self.action_range[0], self.action_range[1], self.action_dim))
282 |
283 | def sample_extreme_action(self):
284 | """Sample a random extreme action from available force."""
285 | return np.atleast_1d(np.random.choice([self.action_range[0], self.action_range[1]], self.action_dim))
286 |
287 | def sample_random_state(self):
288 | """Sample a random state."""
289 | angle_1 = np.random.uniform(self.angle_1_range[0], self.angle_1_range[1])
290 | angle_1_rate = np.random.uniform(self.angular_velocity_range[0], self.angular_velocity_range[1])
291 | angle_2 = np.random.uniform(self.angle_2_range[0], self.angle_2_range[1])
292 | angle_2_rate = np.random.uniform(self.angular_velocity_range[0], self.angular_velocity_range[1])
293 | angle_3 = np.random.uniform(self.angle_3_range[0], self.angle_3_range[1])
294 | angle_3_rate = np.random.uniform(self.angular_velocity_range[0], self.angular_velocity_range[1])
295 | true_state = np.array([angle_1, angle_1_rate, angle_2, angle_2_rate, angle_3, angle_3_rate])
296 | return true_state
297 |
-------------------------------------------------------------------------------- /networks.py: --------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.distributions.bernoulli import Bernoulli
4 | from torch.distributions.independent import Independent
5 | from torch.distributions.normal import Normal
6 |
7 |
8 | torch.set_default_dtype(torch.float64)
9 |
10 |
11 | def MultivariateNormalDiag(loc, scale_diag):
12 | if loc.dim() < 1:
13 | raise ValueError("loc must be at least one-dimensional.")
14 | return Independent(Normal(loc, scale_diag), 1)
15 |
16 |
17 | class Encoder(nn.Module):
18 | # P(z_t | x_t) and Q(z^_t+1 | x_t+1)
19 | def __init__(self, net_hidden, net_mean, net_logstd, x_dim, z_dim):
20 | super(Encoder, self).__init__()
21 | self.net_hidden = net_hidden
22 | self.net_mean = net_mean
23 | self.net_logstd = net_logstd
24 | self.x_dim = x_dim
25 | self.z_dim = z_dim
26 |
27 | def forward(self, x):
28 | # mean and standard deviation of p(z|x)
29 | hidden_neurons = self.net_hidden(x)
30 | mean = self.net_mean(hidden_neurons)
31 | logstd = self.net_logstd(hidden_neurons)
32 | return MultivariateNormalDiag(mean, torch.exp(logstd))
33 |
34 |
35 | class Decoder(nn.Module):
36 | # P(x_t+1 | z^_t+1)
37 | def __init__(self, net_hidden, net_logits, z_dim, x_dim):
38 | super(Decoder, self).__init__()
39 | self.net_hidden = net_hidden
40 | self.net_logits = net_logits
41 | self.z_dim = z_dim
42 | self.x_dim = x_dim
43 |
44 | def forward(self, z):
45 | hidden_neurons = self.net_hidden(z)
46 | logits = self.net_logits(hidden_neurons)
47 | return Bernoulli(logits=logits)
48 |
49 |
50 | class Dynamics(nn.Module):
51 | # P(z^_t+1 | z_t, u_t)
52 | def __init__(self, net_hidden, net_mean, net_logstd, net_A, net_B, z_dim, u_dim, armotized):
53 | super(Dynamics,
self).__init__() 54 | self.net_hidden = net_hidden 55 | self.net_mean = net_mean 56 | self.net_logstd = net_logstd 57 | self.net_A = net_A 58 | self.net_B = net_B 59 | self.z_dim = z_dim 60 | self.u_dim = u_dim 61 | self.armotized = armotized 62 | 63 | def forward(self, z_t, u_t): 64 | z_u_t = torch.cat((z_t, u_t), dim=-1) 65 | hidden_neurons = self.net_hidden(z_u_t) 66 | mean = self.net_mean(hidden_neurons) + z_t # skip connection 67 | logstd = self.net_logstd(hidden_neurons) 68 | if self.armotized: 69 | A = self.net_A(hidden_neurons) 70 | B = self.net_B(hidden_neurons) 71 | else: 72 | A, B = None, None 73 | return MultivariateNormalDiag(mean, torch.exp(logstd)), A, B 74 | 75 | 76 | class BackwardDynamics(nn.Module): 77 | # Q(z_t | z^_t+1, x_t, u_t) 78 | def __init__(self, net_z, net_u, net_x, net_joint_hidden, net_joint_mean, net_joint_logstd, z_dim, u_dim, x_dim): 79 | super(BackwardDynamics, self).__init__() 80 | self.net_z = net_z 81 | self.net_u = net_u 82 | self.net_x = net_x 83 | self.net_joint_hidden = net_joint_hidden 84 | self.net_joint_mean = net_joint_mean 85 | self.net_joint_logstd = net_joint_logstd 86 | self.z_dim = z_dim 87 | self.u_dim = u_dim 88 | self.x_dim = x_dim 89 | 90 | def forward(self, z_t, u_t, x_t): 91 | z_t_out = self.net_z(z_t) 92 | u_t_out = self.net_u(u_t) 93 | x_t_out = self.net_x(x_t) 94 | 95 | hidden_neurons = self.net_joint_hidden(torch.cat((z_t_out, u_t_out, x_t_out), dim=-1)) 96 | mean = self.net_joint_mean(hidden_neurons) 97 | logstd = self.net_joint_logstd(hidden_neurons) 98 | return MultivariateNormalDiag(mean, torch.exp(logstd)) 99 | 100 | 101 | class PlanarEncoder(Encoder): 102 | def __init__(self, x_dim=1600, z_dim=2): 103 | net_hidden = nn.Sequential( 104 | nn.Linear(x_dim, 300), 105 | nn.ReLU(), 106 | nn.Linear(300, 300), 107 | nn.ReLU(), 108 | ) 109 | net_mean = nn.Linear(300, z_dim) 110 | net_logstd = nn.Linear(300, z_dim) 111 | super(PlanarEncoder, self).__init__(net_hidden, net_mean, net_logstd, x_dim, z_dim) 112 | 113 | 114 | class PlanarDecoder(Decoder): 115 | def __init__(self, z_dim=2, x_dim=1600): 116 | net_hidden = nn.Sequential( 117 | nn.Linear(z_dim, 300), 118 | nn.ReLU(), 119 | nn.Linear(300, 300), 120 | nn.ReLU(), 121 | ) 122 | net_logits = nn.Linear(300, x_dim) 123 | super(PlanarDecoder, self).__init__(net_hidden, net_logits, z_dim, x_dim) 124 | 125 | 126 | class PlanarDynamics(Dynamics): 127 | def __init__(self, armotized, z_dim=2, u_dim=2): 128 | net_hidden = nn.Sequential(nn.Linear(z_dim + u_dim, 20), nn.ReLU(), nn.Linear(20, 20), nn.ReLU()) 129 | net_mean = nn.Linear(20, z_dim) 130 | net_logstd = nn.Linear(20, z_dim) 131 | if armotized: 132 | net_A = nn.Linear(20, z_dim ** 2) 133 | net_B = nn.Linear(20, u_dim * z_dim) 134 | else: 135 | net_A, net_B = None, None 136 | super(PlanarDynamics, self).__init__(net_hidden, net_mean, net_logstd, net_A, net_B, z_dim, u_dim, armotized) 137 | 138 | 139 | class PlanarBackwardDynamics(BackwardDynamics): 140 | def __init__(self, z_dim=2, u_dim=2, x_dim=1600): 141 | net_z = nn.Linear(z_dim, 5) 142 | net_u = nn.Linear(u_dim, 5) 143 | net_x = nn.Linear(x_dim, 100) 144 | net_joint_hidden = nn.Sequential( 145 | nn.Linear(5 + 5 + 100, 100), 146 | nn.ReLU(), 147 | ) 148 | net_joint_mean = nn.Linear(100, z_dim) 149 | net_joint_logstd = nn.Linear(100, z_dim) 150 | super(PlanarBackwardDynamics, self).__init__( 151 | net_z, net_u, net_x, net_joint_hidden, net_joint_mean, net_joint_logstd, z_dim, u_dim, x_dim 152 | ) 153 | 154 | 155 | class PendulumEncoder(Encoder): 156 | def __init__(self, 
x_dim=4608, z_dim=3): 157 | net_hidden = nn.Sequential( 158 | nn.Linear(x_dim, 500), 159 | nn.ReLU(), 160 | nn.Linear(500, 500), 161 | nn.ReLU(), 162 | ) 163 | net_mean = nn.Linear(500, z_dim) 164 | net_logstd = nn.Linear(500, z_dim) 165 | super(PendulumEncoder, self).__init__(net_hidden, net_mean, net_logstd, x_dim, z_dim) 166 | 167 | 168 | class PendulumDecoder(Decoder): 169 | def __init__(self, z_dim=3, x_dim=4608): 170 | net_hidden = nn.Sequential( 171 | nn.Linear(z_dim, 500), 172 | nn.ReLU(), 173 | nn.Linear(500, 500), 174 | nn.ReLU(), 175 | ) 176 | net_logits = nn.Linear(500, x_dim) 177 | super(PendulumDecoder, self).__init__(net_hidden, net_logits, z_dim, x_dim) 178 | 179 | 180 | class PendulumDynamics(Dynamics): 181 | def __init__(self, armotized, z_dim=3, u_dim=1): 182 | net_hidden = nn.Sequential(nn.Linear(z_dim + u_dim, 30), nn.ReLU(), nn.Linear(30, 30), nn.ReLU()) 183 | net_mean = nn.Linear(30, z_dim) 184 | net_logstd = nn.Linear(30, z_dim) 185 | if armotized: 186 | net_A = nn.Linear(30, z_dim * z_dim) 187 | net_B = nn.Linear(30, u_dim * z_dim) 188 | else: 189 | net_A, net_B = None, None 190 | super(PendulumDynamics, self).__init__(net_hidden, net_mean, net_logstd, net_A, net_B, z_dim, u_dim, armotized) 191 | 192 | 193 | class PendulumBackwardDynamics(BackwardDynamics): 194 | def __init__(self, z_dim=3, u_dim=1, x_dim=4608): 195 | net_z = nn.Linear(z_dim, 10) 196 | net_u = nn.Linear(u_dim, 10) 197 | net_x = nn.Linear(x_dim, 200) 198 | net_joint_hidden = nn.Sequential( 199 | nn.Linear(10 + 10 + 200, 200), 200 | nn.ReLU(), 201 | ) 202 | net_joint_mean = nn.Linear(200, z_dim) 203 | net_joint_logstd = nn.Linear(200, z_dim) 204 | super(PendulumBackwardDynamics, self).__init__( 205 | net_z, net_u, net_x, net_joint_hidden, net_joint_mean, net_joint_logstd, z_dim, u_dim, x_dim 206 | ) 207 | 208 | 209 | class Flatten(nn.Module): 210 | def __init__(self): 211 | super(Flatten, self).__init__() 212 | 213 | def forward(self, x): 214 | return x.view(x.size(0), -1) 215 | 216 | 217 | class View(nn.Module): 218 | def __init__(self, shape): 219 | super(View, self).__init__() 220 | self.shape = shape 221 | 222 | def forward(self, x): 223 | return x.view(*self.shape) 224 | 225 | 226 | class CartPoleEncoder(Encoder): 227 | def __init__(self, x_dim=(2, 80, 80), z_dim=8): 228 | x_channels = x_dim[0] 229 | net_hidden = nn.Sequential( 230 | nn.Conv2d(in_channels=x_channels, out_channels=32, kernel_size=5, stride=1, padding=2), 231 | nn.ReLU(), 232 | nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=2, padding=2), 233 | nn.ReLU(), 234 | nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=2, padding=2), 235 | nn.ReLU(), 236 | nn.Conv2d(in_channels=32, out_channels=10, kernel_size=5, stride=2, padding=2), 237 | nn.ReLU(), 238 | Flatten(), 239 | nn.Linear(10 * 10 * 10, 200), 240 | nn.ReLU(), 241 | ) 242 | net_mean = nn.Linear(200, z_dim) 243 | net_logstd = nn.Linear(200, z_dim) 244 | super(CartPoleEncoder, self).__init__(net_hidden, net_mean, net_logstd, x_dim, z_dim) 245 | 246 | 247 | class CartPoleDecoder(Decoder): 248 | def __init__(self, z_dim=8, x_dim=(2, 80, 80)): 249 | x_channels = x_dim[0] 250 | net_hidden = nn.Sequential( 251 | nn.Linear(z_dim, 200), 252 | nn.ReLU(), 253 | nn.Linear(200, 1000), 254 | nn.ReLU(), 255 | View((-1, 10, 10, 10)), 256 | nn.ConvTranspose2d(in_channels=10, out_channels=32, kernel_size=5, stride=1, padding=2), 257 | nn.Upsample(scale_factor=2), 258 | nn.ReLU(), 259 | nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, 
padding=2), 260 | nn.Upsample(scale_factor=2), 261 | nn.ReLU(), 262 | nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2), 263 | nn.Upsample(scale_factor=2), 264 | nn.ReLU(), 265 | ) 266 | net_logits = nn.Sequential( 267 | nn.ConvTranspose2d(in_channels=32, out_channels=x_channels, kernel_size=5, stride=1, padding=2), 268 | Flatten(), 269 | ) 270 | super(CartPoleDecoder, self).__init__(net_hidden, net_logits, z_dim, x_dim) 271 | 272 | 273 | class CartPoleDynamics(Dynamics): 274 | def __init__(self, armotized, z_dim=8, u_dim=1): 275 | net_hidden = nn.Sequential(nn.Linear(z_dim + u_dim, 40), nn.ReLU(), nn.Linear(40, 40), nn.ReLU()) 276 | net_mean = nn.Linear(40, z_dim) 277 | net_logstd = nn.Linear(40, z_dim) 278 | if armotized: 279 | net_A = nn.Linear(40, z_dim * z_dim) 280 | net_B = nn.Linear(40, u_dim * z_dim) 281 | else: 282 | net_A, net_B = None, None 283 | super(CartPoleDynamics, self).__init__(net_hidden, net_mean, net_logstd, net_A, net_B, z_dim, u_dim, armotized) 284 | 285 | 286 | class CartPoleBackwardDynamics(BackwardDynamics): 287 | def __init__(self, z_dim=8, u_dim=1, x_dim=(2, 80, 80)): 288 | net_z = nn.Linear(z_dim, 10) 289 | net_u = nn.Linear(u_dim, 10) 290 | net_x = nn.Sequential(Flatten(), nn.Linear(x_dim[0] * x_dim[1] * x_dim[2], 300)) 291 | 292 | net_joint_hidden = nn.Sequential( 293 | nn.Linear(10 + 10 + 300, 300), 294 | nn.ReLU(), 295 | ) 296 | net_joint_mean = nn.Linear(300, z_dim) 297 | net_joint_logstd = nn.Linear(300, z_dim) 298 | super(CartPoleBackwardDynamics, self).__init__( 299 | net_z, net_u, net_x, net_joint_hidden, net_joint_mean, net_joint_logstd, z_dim, u_dim, x_dim 300 | ) 301 | 302 | 303 | class ThreePoleEncoder(Encoder): 304 | def __init__(self, x_dim=(2, 80, 80), z_dim=8): 305 | x_channels = x_dim[0] 306 | net_hidden = nn.Sequential( 307 | nn.Conv2d(in_channels=x_channels, out_channels=32, kernel_size=5, stride=1, padding=2), 308 | nn.ReLU(), 309 | nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=2, padding=2), 310 | nn.ReLU(), 311 | nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=2, padding=2), 312 | nn.ReLU(), 313 | nn.Conv2d(in_channels=32, out_channels=10, kernel_size=5, stride=2, padding=2), 314 | nn.ReLU(), 315 | Flatten(), 316 | nn.Linear(10 * 10 * 10, 200), 317 | nn.ReLU(), 318 | ) 319 | net_mean = nn.Linear(200, z_dim) 320 | net_logstd = nn.Linear(200, z_dim) 321 | super(ThreePoleEncoder, self).__init__(net_hidden, net_mean, net_logstd, x_dim, z_dim) 322 | 323 | 324 | class ThreePoleDecoder(Decoder): 325 | def __init__(self, z_dim=8, x_dim=(2, 80, 80)): 326 | x_channels = x_dim[0] 327 | net_hidden = nn.Sequential( 328 | nn.Linear(z_dim, 200), 329 | nn.ReLU(), 330 | nn.Linear(200, 1000), 331 | nn.ReLU(), 332 | View((-1, 10, 10, 10)), 333 | nn.ConvTranspose2d(in_channels=10, out_channels=32, kernel_size=5, stride=1, padding=2), 334 | nn.Upsample(scale_factor=2), 335 | nn.ReLU(), 336 | nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2), 337 | nn.Upsample(scale_factor=2), 338 | nn.ReLU(), 339 | nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2), 340 | nn.Upsample(scale_factor=2), 341 | nn.ReLU(), 342 | ) 343 | net_logits = nn.Sequential( 344 | nn.ConvTranspose2d(in_channels=32, out_channels=x_channels, kernel_size=5, stride=1, padding=2), 345 | Flatten(), 346 | ) 347 | super(ThreePoleDecoder, self).__init__(net_hidden, net_logits, z_dim, x_dim) 348 | 349 | 350 | class ThreePoleDynamics(Dynamics): 351 | def 
__init__(self, armotized, z_dim=8, u_dim=1): 352 | net_hidden = nn.Sequential(nn.Linear(z_dim + u_dim, 40), nn.ReLU(), nn.Linear(40, 40), nn.ReLU()) 353 | net_mean = nn.Linear(40, z_dim) 354 | net_logstd = nn.Linear(40, z_dim) 355 | if armotized: 356 | net_A = nn.Linear(40, z_dim * z_dim) 357 | net_B = nn.Linear(40, u_dim * z_dim) 358 | else: 359 | net_A, net_B = None, None 360 | super(ThreePoleDynamics, self).__init__( 361 | net_hidden, net_mean, net_logstd, net_A, net_B, z_dim, u_dim, armotized 362 | ) 363 | 364 | 365 | class ThreePoleBackwardDynamics(BackwardDynamics): 366 | def __init__(self, z_dim=8, u_dim=1, x_dim=(2, 80, 80)): 367 | net_z = nn.Linear(z_dim, 10) 368 | net_u = nn.Linear(u_dim, 10) 369 | net_x = nn.Sequential(Flatten(), nn.Linear(x_dim[0] * x_dim[1] * x_dim[2], 300)) 370 | 371 | net_joint_hidden = nn.Sequential( 372 | nn.Linear(10 + 10 + 300, 300), 373 | nn.ReLU(), 374 | ) 375 | net_joint_mean = nn.Linear(300, z_dim) 376 | net_joint_logstd = nn.Linear(300, z_dim) 377 | super(ThreePoleBackwardDynamics, self).__init__( 378 | net_z, net_u, net_x, net_joint_hidden, net_joint_mean, net_joint_logstd, z_dim, u_dim, x_dim 379 | ) 380 | 381 | 382 | CONFIG = { 383 | "planar": (PlanarEncoder, PlanarDecoder, PlanarDynamics, PlanarBackwardDynamics), 384 | "pendulum": (PendulumEncoder, PendulumDecoder, PendulumDynamics, PendulumBackwardDynamics), 385 | "cartpole": (CartPoleEncoder, CartPoleDecoder, CartPoleDynamics, CartPoleBackwardDynamics), 386 | "threepole": (ThreePoleEncoder, ThreePoleDecoder, ThreePoleDynamics, ThreePoleBackwardDynamics), 387 | } 388 | 389 | 390 | def load_config(name): 391 | return CONFIG[name] 392 | 393 | 394 | __all__ = ["load_config"] 395 | 396 | # device = torch.device("cuda") 397 | # cartpole_encoder = CartPoleEncoder() 398 | # cartpole_encoder.to(device) 399 | # # cartpole_encoder.net[0].to(device) 400 | # print (next(cartpole_encoder.net[0].parameters()).is_cuda) 401 | -------------------------------------------------------------------------------- /pcc.yml: -------------------------------------------------------------------------------- 1 | name: pcc 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - anaconda 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _tflow_select=2.1.0=gpu 10 | - absl-py=0.7.1=py36_0 11 | - asn1crypto=0.24.0=py36_0 12 | - astor=0.8.0=py36_0 13 | - attrs=19.3.0=py_0 14 | - backcall=0.1.0=py_0 15 | - blas=1.0=mkl 16 | - bleach=3.1.0=py_0 17 | - c-ares=1.15.0=h7b6447c_1001 18 | - ca-certificates=2019.9.11=hecc5488_0 19 | - certifi=2019.9.11=py36_0 20 | - cffi=1.12.3=py36h2e261b9_0 21 | - chardet=3.0.4=py36_1003 22 | - colour=0.1.5=py_0 23 | - cryptography=2.7=py36h1ba5d50_0 24 | - cudatoolkit=10.0.130=0 25 | - cudnn=7.6.0=cuda10.0_0 26 | - cupti=10.0.130=0 27 | - cycler=0.10.0=py_1 28 | - dbus=1.13.6=h746ee38_0 29 | - decorator=4.4.0=py_0 30 | - defusedxml=0.6.0=py_0 31 | - entrypoints=0.3=py36_1000 32 | - expat=2.2.6=he6710b0_0 33 | - fontconfig=2.13.1=he4413a7_1000 34 | - freetype=2.9.1=h8a8886c_1 35 | - gast=0.2.2=py36_0 36 | - gettext=0.19.8.1=hc5be6a0_1002 37 | - glib=2.56.2=had28632_1001 38 | - google-pasta=0.1.7=py_0 39 | - grpcio=1.16.1=py36hf8bcb03_1 40 | - gst-plugins-base=1.14.0=hbbd80ab_1 41 | - gstreamer=1.14.0=hb453b48_1 42 | - h5py=2.9.0=py36h7918eee_0 43 | - hdf5=1.10.4=hb1b8bf9_0 44 | - icu=58.2=hf484d3e_1000 45 | - idna=2.8=py36_0 46 | - imageio=2.6.1=py36_0 47 | - importlib_metadata=0.23=py36_0 48 | - intel-openmp=2019.4=243 49 | - ipykernel=5.1.3=py36h5ca1d4c_0 50 | - 
ipython=7.8.0=py36h5ca1d4c_0 51 | - ipython_genutils=0.2.0=py_1 52 | - jedi=0.15.1=py36_0 53 | - jinja2=2.10.3=py_0 54 | - jpeg=9b=h024ee3a_2 55 | - jsonschema=3.1.1=py36_0 56 | - jupyter_client=5.3.3=py36_1 57 | - jupyter_core=4.4.0=py_0 58 | - keras-applications=1.0.8=py_0 59 | - keras-preprocessing=1.1.0=py_1 60 | - kiwisolver=1.1.0=py36hc9558a2_0 61 | - libedit=3.1.20181209=hc058e9b_0 62 | - libffi=3.2.1=hd88cf55_4 63 | - libgcc-ng=9.1.0=hdf63c60_0 64 | - libgfortran-ng=7.3.0=hdf63c60_0 65 | - libiconv=1.15=h516909a_1005 66 | - libpng=1.6.37=hbc83047_0 67 | - libprotobuf=3.8.0=hd408876_0 68 | - libsodium=1.0.17=h516909a_0 69 | - libstdcxx-ng=9.1.0=hdf63c60_0 70 | - libtiff=4.0.10=h2733197_2 71 | - libuuid=2.32.1=h14c3975_1000 72 | - libxcb=1.13=h14c3975_1002 73 | - libxml2=2.9.9=h13577e0_2 74 | - markdown=3.1.1=py36_0 75 | - markupsafe=1.1.1=py36h14c3975_0 76 | - matplotlib=3.1.1=py36h5429711_0 77 | - mistune=0.8.4=py36h14c3975_1000 78 | - mkl=2019.4=243 79 | - mkl-service=2.3.0=py36he904b0f_0 80 | - mkl_fft=1.0.14=py36ha843d7b_0 81 | - mkl_random=1.0.2=py36hd81dba3_0 82 | - more-itertools=7.2.0=py_0 83 | - nbconvert=5.6.0=py36_1 84 | - nbformat=4.4.0=py_1 85 | - ncurses=6.1=he6710b0_1 86 | - ninja=1.9.0=py36hfd86e86_0 87 | - notebook=6.0.1=py36_0 88 | - numpy=1.16.5=py36h7e9f1db_0 89 | - numpy-base=1.16.5=py36hde5b4d6_0 90 | - olefile=0.46=py36_0 91 | - openssl=1.1.1=h7b6447c_0 92 | - pandoc=2.7.3=0 93 | - pandocfilters=1.4.2=py_1 94 | - parso=0.5.1=py_0 95 | - pcre=8.43=he6710b0_0 96 | - pexpect=4.7.0=py36_0 97 | - pickleshare=0.7.5=py36_1000 98 | - pillow=6.1.0=py36h34e0f95_0 99 | - pip=19.2.2=py36_0 100 | - pixman-cos6-x86_64=0.32.8=h7062e45_0 101 | - prometheus_client=0.7.1=py_0 102 | - prompt_toolkit=2.0.10=py_0 103 | - protobuf=3.8.0=py36he6710b0_0 104 | - pthread-stubs=0.4=h14c3975_1001 105 | - ptyprocess=0.6.0=py_1001 106 | - pycparser=2.19=py36_0 107 | - pyglet=1.3.2=py36_1000 108 | - pygments=2.4.2=py_0 109 | - pyopengl=3.1.1a1=py36_0 110 | - pyopenssl=19.0.0=py36_0 111 | - pyparsing=2.4.2=py_0 112 | - pyqt=5.9.2=py36hcca6a23_4 113 | - pyrsistent=0.15.4=py36h516909a_0 114 | - pysocks=1.7.1=py36_0 115 | - python=3.6.8=h0371630_0 116 | - python-dateutil=2.8.0=py_0 117 | - pytorch=1.2.0=py3.6_cuda10.0.130_cudnn7.6.2_0 118 | - pytz=2019.2=py_0 119 | - pyzmq=18.1.0=py36h1768529_0 120 | - qt=5.9.7=h5867ecd_1 121 | - readline=7.0=h7b6447c_5 122 | - requests=2.22.0=py36_0 123 | - scipy=1.3.1=py36h7c811a0_0 124 | - send2trash=1.5.0=py_0 125 | - setuptools=41.2.0=py36_0 126 | - sip=4.19.8=py36hf484d3e_1000 127 | - six=1.12.0=py36_0 128 | - sqlite=3.29.0=h7b6447c_0 129 | - tensorboard=1.14.0=py36hf484d3e_0 130 | - tensorboardx=1.8=py_0 131 | - tensorflow=1.14.0=gpu_py36h57aa796_0 132 | - tensorflow-base=1.14.0=gpu_py36h8d69cac_0 133 | - tensorflow-estimator=1.14.0=py_0 134 | - tensorflow-gpu=1.14.0=h0d30ee6_0 135 | - termcolor=1.1.0=py36_1 136 | - terminado=0.8.2=py36_0 137 | - testpath=0.4.2=py_1001 138 | - tk=8.6.8=hbc83047_0 139 | - torchvision=0.4.0=py36_cu100 140 | - tornado=6.0.3=py36h516909a_0 141 | - tqdm=4.36.1=py_0 142 | - traitlets=4.3.3=py36_0 143 | - urllib3=1.24.2=py36_0 144 | - wcwidth=0.1.7=py_1 145 | - webencodings=0.5.1=py_1 146 | - werkzeug=0.15.5=py_0 147 | - wheel=0.33.6=py36_0 148 | - wrapt=1.11.2=py36h7b6447c_0 149 | - xorg-libxau=1.0.9=h14c3975_0 150 | - xorg-libxdmcp=1.1.3=h516909a_0 151 | - xorg-x11-server-common-cos6-x86_64=1.17.4=he6f580c_0 152 | - xorg-x11-server-xvfb-cos6-x86_64=1.17.4=h5c27f9d_0 153 | - xz=5.2.4=h14c3975_4 154 | - zeromq=4.3.2=he1b5a44_2 
155 | - zipp=0.6.0=py_0
156 | - zlib=1.2.11=h7b6447c_3
157 | - zstd=1.3.7=h0b5b093_0
158 | - pip:
159 | - future==0.17.1
160 | prefix: /home/tungnd13/miniconda3/envs/pcc
161 |
162 |
-------------------------------------------------------------------------------- /pcc_model.py: --------------------------------------------------------------------------------
1 | import torch
2 | from networks import load_config
3 | from torch import nn
4 |
5 |
6 | torch.set_default_dtype(torch.float64)
7 | # torch.manual_seed(0)
8 |
9 |
10 | class PCC(nn.Module):
11 | def __init__(self, armotized, x_dim, z_dim, u_dim, env="planar"):
12 | super(PCC, self).__init__()
13 | enc, dec, dyn, back_dyn = load_config(env)
14 |
15 | self.x_dim = x_dim
16 | self.z_dim = z_dim
17 | self.u_dim = u_dim
18 | self.armotized = armotized
19 |
20 | self.encoder = enc(x_dim, z_dim)
21 | self.decoder = dec(z_dim, x_dim)
22 | self.dynamics = dyn(armotized, z_dim, u_dim)
23 | self.backward_dynamics = back_dyn(z_dim, u_dim, x_dim)
24 |
25 | def encode(self, x):
26 | return self.encoder(x)
27 |
28 | def decode(self, z):
29 | return self.decoder(z)
30 |
31 | def transition(self, z, u):
32 | return self.dynamics(z, u)
33 |
34 | def back_dynamics(self, z, u, x):
35 | return self.backward_dynamics(z, u, x)
36 |
37 | def reparam(self, mean, std):
38 | # reparameterization trick: z = mean + std * eps, with eps ~ N(0, I)
39 | epsilon = torch.randn_like(std)
40 | return mean + torch.mul(epsilon, std)
41 |
42 | def forward(self, x, u, x_next):
43 | # prediction and consistency loss
44 | # 1st and 3rd terms
45 | q_z_next = self.encode(x_next) # Q(z^_t+1 | x_t+1)
46 | z_next = self.reparam(q_z_next.mean, q_z_next.stddev) # sample z^_t+1
47 | p_x_next = self.decode(z_next) # P(x_t+1 | z^_t+1)
48 | # 2nd term
49 | q_z_backward = self.back_dynamics(z_next, u, x) # Q(z_t | z^_t+1, u_t, x_t)
50 | p_z = self.encode(x) # P(z_t | x_t)
51 |
52 | # 4th term
53 | z_q = self.reparam(q_z_backward.mean, q_z_backward.stddev) # samples from Q(z_t | z^_t+1, u_t, x_t)
54 | p_z_next, _, _ = self.transition(z_q, u) # P(z^_t+1 | z_t, u_t)
55 |
56 | # additional VAE loss
57 | z_p = self.reparam(p_z_next.mean, p_z_next.stddev) # samples from P(z^_t+1 | z_t, u_t)
58 | p_x = self.decode(z_p) # for additional vae loss
59 |
60 | # additional deterministic loss
61 | mu_z_next_determ = self.transition(p_z.mean, u)[0].mean
62 | p_x_next_determ = self.decode(mu_z_next_determ)
63 |
64 | return p_x_next, q_z_backward, p_z, q_z_next, z_next, p_z_next, z_p, u, p_x, p_x_next_determ
65 |
66 | def predict(self, x, u):
67 | q_z = self.encode(x) # the encoder returns a distribution, not a (mu, logvar) pair
68 | z = self.reparam(q_z.mean, q_z.stddev)
69 | x_recon = self.decode(z)
70 |
71 | p_z_next, _, _ = self.transition(z, u) # transition returns (distribution, A, B)
72 | z_next = self.reparam(p_z_next.mean, p_z_next.stddev)
73 | x_next_pred = self.decode(z_next)
74 | return x_recon, x_next_pred
75 |
76 |
77 | # def reparam(mean, logvar):
78 | # sigma = (logvar / 2).exp()
79 | # epsilon = torch.randn_like(sigma)
80 | # return mean + torch.mul(epsilon, sigma)
81 |
82 | # def jacobian_1(dynamics, z, u):
83 | # """
84 | # compute the jacobian of F(z,u) w.r.t z, u
85 | # """
86 | # z_dim, u_dim = z.size(1), u.size(1)
87 | # z, u = z.squeeze().repeat(z_dim, 1), u.squeeze().repeat(z_dim, 1)
88 | # z = z.detach().requires_grad_(True)
89 | # u = u.detach().requires_grad_(True)
90 | # z_next, _, _, _ = dynamics(z, u)
91 | # grad_inp = torch.eye(z_dim)
92 | # A = torch.autograd.grad(z_next, z, grad_inp, retain_graph=True)[0]
93 | # B = torch.autograd.grad(z_next, u, grad_inp, retain_graph=True)[0]
94 | # return A, B
95 |
96 | # def
jacobian_2(dynamics, z, u): 97 | # """ 98 | # compute the jacobian of F(z,u) w.r.t z, u 99 | # """ 100 | # z_dim, u_dim = z.size(1), u.size(1) 101 | # z = z.detach().requires_grad_(True) 102 | # u = u.detach().requires_grad_(True) 103 | # z_next, _, _, _ = dynamics(z, u) 104 | # A = torch.empty(size=(z_dim, z_dim)) 105 | # B = torch.empty(size=(z_dim, u_dim)) 106 | # for i in range(A.size(0)): # for each row 107 | # grad_inp = torch.zeros(size=(1, A.size(0))) 108 | # grad_inp[0][i] = 1 109 | # A[i] = torch.autograd.grad(z_next, z, grad_inp, retain_graph=True)[0] 110 | # for i in range(B.size(0)): # for each row 111 | # grad_inp = torch.zeros(size=(1, B.size(0))) 112 | # grad_inp[0][i] = 1 113 | # B[i] = torch.autograd.grad(z_next, u, grad_inp, retain_graph=True)[0] 114 | # return A, B 115 | 116 | # enc, dec, dyn, back_dyn = load_config('planar') 117 | # dynamics = dyn(armotized=False, z_dim=2, u_dim=2) 118 | # dynamics.eval() 119 | 120 | # import torch.optim as optim 121 | # optimizer = optim.Adam(dynamics.parameters(), betas=(0.9, 0.999), eps=1e-8, lr=0.001) 122 | 123 | # z = torch.randn(size=(1, 2)) 124 | # z.requires_grad = True 125 | # u = torch.randn(size=(1, 2)) 126 | # u.requires_grad = True 127 | 128 | # eps_z = torch.normal(0.0, 0.1, size=z.size()) 129 | # eps_u = torch.normal(0.0, 0.1, size=u.size()) 130 | 131 | # mean, logvar, _, _ = dynamics(z, u) 132 | # grad_z = torch.autograd.grad(mean, z, grad_outputs=eps_z, retain_graph=True, create_graph=True) 133 | # grad_u = torch.autograd.grad(mean, u, grad_outputs=eps_u, retain_graph=True, create_graph=True) 134 | # print ('AAA') 135 | # print (grad_z, grad_u) 136 | # A, B = jacobian_1(dynamics, z, u) 137 | # grad_z, grad_u = eps_z.mm(A), eps_u.mm(B) 138 | # print ('BBBB') 139 | # print (grad_z, grad_u) 140 | # A, B = jacobian_1(dynamics, z, u) 141 | # grad_z, grad_u = eps_z.mm(A), eps_u.mm(B) 142 | # print ('BBBB') 143 | # print (grad_z, grad_u) 144 | -------------------------------------------------------------------------------- /sample_results/cartpole_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VinAIResearch/PCC-pytorch/1127f8314cdd45faeca8bd429777c50fcfc192a5/sample_results/cartpole_1.gif -------------------------------------------------------------------------------- /sample_results/cartpole_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VinAIResearch/PCC-pytorch/1127f8314cdd45faeca8bd429777c50fcfc192a5/sample_results/cartpole_2.gif -------------------------------------------------------------------------------- /sample_results/latent_map_pend.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VinAIResearch/PCC-pytorch/1127f8314cdd45faeca8bd429777c50fcfc192a5/sample_results/latent_map_pend.png -------------------------------------------------------------------------------- /sample_results/latent_map_sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VinAIResearch/PCC-pytorch/1127f8314cdd45faeca8bd429777c50fcfc192a5/sample_results/latent_map_sample.png -------------------------------------------------------------------------------- /sample_results/pendulum_1.gif: -------------------------------------------------------------------------------- 
--------------------------------------------------------------------------------
/sample_results/cartpole_1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VinAIResearch/PCC-pytorch/1127f8314cdd45faeca8bd429777c50fcfc192a5/sample_results/cartpole_1.gif
--------------------------------------------------------------------------------
/sample_results/cartpole_2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VinAIResearch/PCC-pytorch/1127f8314cdd45faeca8bd429777c50fcfc192a5/sample_results/cartpole_2.gif
--------------------------------------------------------------------------------
/sample_results/latent_map_pend.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VinAIResearch/PCC-pytorch/1127f8314cdd45faeca8bd429777c50fcfc192a5/sample_results/latent_map_pend.png
--------------------------------------------------------------------------------
/sample_results/latent_map_sample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VinAIResearch/PCC-pytorch/1127f8314cdd45faeca8bd429777c50fcfc192a5/sample_results/latent_map_sample.png
--------------------------------------------------------------------------------
/sample_results/pendulum_1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VinAIResearch/PCC-pytorch/1127f8314cdd45faeca8bd429777c50fcfc192a5/sample_results/pendulum_1.gif
--------------------------------------------------------------------------------
/sample_results/pendulum_2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VinAIResearch/PCC-pytorch/1127f8314cdd45faeca8bd429777c50fcfc192a5/sample_results/pendulum_2.gif
--------------------------------------------------------------------------------
/sample_results/planar_1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VinAIResearch/PCC-pytorch/1127f8314cdd45faeca8bd429777c50fcfc192a5/sample_results/planar_1.gif
--------------------------------------------------------------------------------
/sample_results/planar_2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VinAIResearch/PCC-pytorch/1127f8314cdd45faeca8bd429777c50fcfc192a5/sample_results/planar_2.gif
--------------------------------------------------------------------------------
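The training script that follows assembles its objective in `compute_loss` as a weighted sum of prediction, consistency, curvature, VAE, and deterministic terms. As a quick illustration (a toy sketch, not repository code; the five loss values are dummy numbers), the default weights combine as:

```python
# Schematic of the objective assembled by compute_loss in train_pcc.py,
# using the default weights; the loss arguments below are dummy floats.
lam_p, lam_c, lam_cur = 1.0, 8.0, 8.0  # prediction / consistency / curvature
vae_coeff, determ_coeff = 0.01, 0.3    # auxiliary VAE / deterministic terms


def total_loss(pred, consis, cur, vae, determ):
    # weighted sum used as the training objective
    return (lam_p * pred + lam_c * consis + lam_cur * cur
            + vae_coeff * vae + determ_coeff * determ)


print(total_loss(1.0, 0.5, 0.2, 3.0, 2.0))  # 7.23 with these dummy values
```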
--------------------------------------------------------------------------------
/train_pcc.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import random
5 | import time
6 | from os import path
7 | 
8 | import numpy as np
9 | import torch
10 | import torch.optim as optim
11 | from datasets import CartPoleDataset, PendulumDataset, PlanarDataset, ThreePoleDataset
12 | from latent_map_planar import draw_latent_map
13 | from losses import KL, ae_loss, bernoulli, curvature, entropy, gaussian, vae_bound
14 | from mdp.plane_obstacles_mdp import PlanarObstaclesMDP
15 | from pcc_model import PCC
16 | from tensorboardX import SummaryWriter
17 | from torch.utils.data import DataLoader
18 | 
19 | 
20 | torch.set_default_dtype(torch.float64)
21 | 
22 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # fall back to CPU when no GPU is available
23 | datasets = {
24 |     "planar": PlanarDataset,
25 |     "pendulum": PendulumDataset,
26 |     "cartpole": CartPoleDataset,
27 |     "threepole": ThreePoleDataset,
28 | }
29 | dims = {
30 |     "planar": (1600, 2, 2),
31 |     "pendulum": (4608, 3, 1),
32 |     "cartpole": ((2, 80, 80), 8, 1),
33 |     "threepole": ((2, 80, 80), 8, 3),
34 | }
35 | 
36 | 
37 | def seed_torch(seed):
38 |     random.seed(seed)
39 |     os.environ["PYTHONHASHSEED"] = str(seed)
40 |     np.random.seed(seed)
41 |     torch.manual_seed(seed)
42 |     torch.cuda.manual_seed(seed)
43 |     torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU
44 |     torch.backends.cudnn.benchmark = False
45 |     torch.backends.cudnn.deterministic = True
46 | 
47 | 
48 | def compute_loss(
49 |     model,
50 |     armotized,
51 |     x,
52 |     u,
53 |     x_next,
54 |     p_x_next,
55 |     q_z_backward,
56 |     p_z,
57 |     q_z_next,
58 |     z_next,
59 |     p_z_next,
60 |     z,
61 |     p_x,
62 |     p_x_next_determ,
63 |     lam=(1.0, 8.0, 8.0),
64 |     delta=0.1,
65 |     vae_coeff=0.01,
66 |     determ_coeff=0.3,
67 | ):
68 |     # prediction and consistency loss
69 |     pred_loss = -bernoulli(x_next, p_x_next) + KL(q_z_backward, p_z) - entropy(q_z_next) - gaussian(z_next, p_z_next)
70 | 
71 |     consis_loss = -entropy(q_z_next) - gaussian(z_next, p_z_next) + KL(q_z_backward, p_z)
72 | 
73 |     # curvature loss
74 |     cur_loss = curvature(model, z, u, delta, armotized)
75 | 
76 |     # additional vae loss
77 |     vae_loss = vae_bound(x, p_x, p_z)
78 | 
79 |     # additional deterministic loss
80 |     determ_loss = -bernoulli(x_next, p_x_next_determ)
81 | 
82 |     lam_p, lam_c, lam_cur = lam
83 |     return (
84 |         pred_loss,
85 |         consis_loss,
86 |         cur_loss,
87 |         lam_p * pred_loss
88 |         + lam_c * consis_loss
89 |         + lam_cur * cur_loss
90 |         + vae_coeff * vae_loss
91 |         + determ_coeff * determ_loss,
92 |     )
93 | 
94 | 
95 | def train(model, env_name, train_loader, lam, vae_coeff, determ_coeff, optimizer, armotized, epoch):
96 |     avg_pred_loss = 0.0
97 |     avg_consis_loss = 0.0
98 |     avg_cur_loss = 0.0
99 |     avg_loss = 0.0
100 |     avg_ae_loss = 0.0
101 | 
102 |     num_batches = len(train_loader)
103 |     model.train()
104 | 
105 |     start = time.time()
106 |     for x, u, x_next in train_loader:
107 |         x = x.to(device).double()
108 |         u = u.to(device).double()
109 |         x_next = x_next.to(device).double()
110 |         optimizer.zero_grad()
111 | 
112 |         p_x_next, q_z_backward, p_z, q_z_next, z_next, p_z_next, z_p, u, p_x, p_x_next_determ = model(x, u, x_next)
113 | 
114 |         x = x.view(x.size(0), -1)
115 |         x_next = x_next.view(x_next.size(0), -1)
116 | 
117 |         if env_name == "planar" and epoch < 100:  # warm up the planar model with a plain autoencoder objective
118 |             pred_loss, consis_loss, cur_loss = torch.zeros(1), torch.zeros(1), torch.zeros(1)
119 |             loss = ae_loss(x, p_x)
120 |             avg_ae_loss += loss.item()
121 |         else:
122 |             pred_loss, consis_loss, cur_loss, loss = compute_loss(
123 |                 model,
124 |                 armotized,
125 |                 x,
126 |                 u,
127 |                 x_next,
128 |                 p_x_next,
129 |                 q_z_backward,
130 |                 p_z,
131 |                 q_z_next,
132 |                 z_next,
133 |                 p_z_next,
134 |                 z_p,
135 |                 p_x,
136 |                 p_x_next_determ,
137 |                 lam=lam,
138 |                 vae_coeff=vae_coeff,
139 |                 determ_coeff=determ_coeff,
140 |             )
141 | 
142 |         loss.backward()
143 |         # clip_grad_norm_(model.parameters(), 1.0)
144 |         optimizer.step()
145 | 
146 |         avg_pred_loss += pred_loss.item()
147 |         avg_consis_loss += consis_loss.item()
148 |         avg_cur_loss += cur_loss.item()
149 |         avg_loss += loss.item()
150 | 
151 |     avg_pred_loss /= num_batches
152 |     avg_consis_loss /= num_batches
153 |     avg_cur_loss /= num_batches
154 |     avg_loss /= num_batches
155 | 
156 |     end = time.time()
157 |     print("Training time: %f" % (end - start))
158 | 
159 |     if (epoch + 1) % 1 == 0:  # print every epoch; raise the modulus to print less often
160 |         if env_name == "planar" and epoch < 100:
161 |             print("AE loss epoch %d: %f" % (epoch + 1, avg_ae_loss / num_batches))  # per-batch average
162 |             print("---------------------------")
163 |         else:
164 |             print("Epoch %d" % (epoch + 1))
165 |             print("Prediction loss: %f" % (avg_pred_loss))
166 |             print("Consistency loss: %f" % (avg_consis_loss))
167 |             print("Curvature loss: %f" % (avg_cur_loss))
168 |             print("Training loss: %f" % (avg_loss))
169 |             print("--------------------------------------")
170 | 
171 |     return avg_pred_loss, avg_consis_loss, avg_cur_loss, avg_loss
172 | 
173 | 
174 | def main(args):
175 |     env_name = args.env
176 |     assert env_name in ["planar", "pendulum", "cartpole", "threepole"]
177 |     armotized = args.armotized
178 |     log_dir = args.log_dir
179 |     seed = args.seed
180 |     data_size = args.data_size
181 |     noise_level = args.noise
182 |     batch_size = args.batch_size
183 |     lam_p = args.lam_p
184 |     lam_c = args.lam_c
185 |     lam_cur = args.lam_cur
186 |     lam = (lam_p, lam_c, lam_cur)
187 |     vae_coeff = args.vae_coeff
188 |     determ_coeff = args.determ_coeff
189 |     lr = args.lr
190 |     weight_decay = args.decay
191 |     epochs = args.num_iter
192 |     iter_save = args.iter_save
193 |     save_map = args.save_map
194 | 
195 |     seed_torch(seed)
196 | 
197 |     def _init_fn(worker_id):
198 |         np.random.seed(int(seed))
199 | 
200 |     dataset = datasets[env_name]
201 |     data = dataset(sample_size=data_size, noise=noise_level)
202 |     data_loader = DataLoader(
203 |         data, batch_size=batch_size, shuffle=True, drop_last=False, num_workers=4, worker_init_fn=_init_fn
204 |     )
205 | 
206 |     x_dim, z_dim, u_dim = dims[env_name]
207 |     model = PCC(armotized=armotized, x_dim=x_dim, z_dim=z_dim, u_dim=u_dim, env=env_name).to(device)
208 | 
209 |     if env_name == "planar" and save_map:
210 |         mdp = PlanarObstaclesMDP(noise=noise_level)
211 | 
212 |     optimizer = optim.Adam(model.parameters(), betas=(0.9, 0.999), eps=1e-8, lr=lr, weight_decay=weight_decay)
213 | 
214 |     log_path = "logs/" + env_name + "/" + log_dir
215 |     if not path.exists(log_path):
216 |         os.makedirs(log_path)
217 |     writer = SummaryWriter(log_path)
218 | 
219 |     result_path = "result/" + env_name + "/" + log_dir
220 |     if not path.exists(result_path):
221 |         os.makedirs(result_path)
222 |     with open(result_path + "/settings", "w") as f:
223 |         json.dump(args.__dict__, f, indent=2)
224 | 
225 |     if env_name == "planar" and save_map:
226 |         latent_maps = [draw_latent_map(model, mdp)]
227 |     for i in range(epochs):
228 |         avg_pred_loss, avg_consis_loss, avg_cur_loss, avg_loss = train(
229 |             model, env_name, data_loader, lam, vae_coeff, determ_coeff, optimizer, armotized, i
230 |         )
231 | 
232 |         # log the running losses
233 |         writer.add_scalar("prediction loss", avg_pred_loss, i)
234 |         writer.add_scalar("consistency loss", avg_consis_loss, i)
235 |         writer.add_scalar("curvature loss", avg_cur_loss, i)
236 |         writer.add_scalar("training loss", avg_loss, i)
237 |         if env_name == "planar" and save_map:
238 |             if (i + 1) % 10 == 0:
239 |                 map_i = draw_latent_map(model, mdp)
240 |                 latent_maps.append(map_i)
241 |         # save model
242 |         if (i + 1) % iter_save == 0:
243 |             print("Saving the model.............")
244 | 
245 |             torch.save(model.state_dict(), result_path + "/model_" + str(i + 1))
246 |             with open(result_path + "/loss_" + str(i + 1), "w") as f:
247 |                 f.write(
248 |                     "\n".join(
249 |                         [
250 |                             "Prediction loss: " + str(avg_pred_loss),
251 |                             "Consistency loss: " + str(avg_consis_loss),
252 |                             "Curvature loss: " + str(avg_cur_loss),
253 |                             "Training loss: " + str(avg_loss),
254 |                         ]
255 |                     )
256 |                 )
257 |     if env_name == "planar" and save_map:
258 |         latent_maps[0].save(
259 |             result_path + "/latent_map.gif",
260 |             format="GIF",
261 |             append_images=latent_maps[1:],
262 |             save_all=True,
263 |             duration=100,
264 |             loop=0,
265 |         )
266 |     writer.close()
267 | 
268 | 
269 | def str2bool(v):
270 |     if isinstance(v, bool):
271 |         return v
272 |     if v.lower() in ("yes", "true", "t", "y", "1"):
273 |         return True
274 |     elif v.lower() in ("no", "false", "f", "n", "0"):
275 |         return False
276 |     else:
277 |         raise argparse.ArgumentTypeError("Boolean value expected.")
278 | 
279 | 
280 | if __name__ == "__main__":
281 |     parser = argparse.ArgumentParser(description="train pcc model")
282 | 
283 |     parser.add_argument("--env", required=True, type=str, help="environment used for training")
284 |     parser.add_argument(
285 |         "--armotized",
286 |         required=True,
287 |         type=str2bool,
288 |         nargs="?",
289 |         const=True,
290 |         default=False,
291 |         help="type of dynamics model",
292 |     )
293 |     parser.add_argument("--log_dir", required=True, type=str, help="directory to save training log")
294 |     parser.add_argument("--seed", required=True, type=int, help="seed number")
295 |     parser.add_argument("--data_size", required=True, type=int, help="the number of data points used for training")
296 |     parser.add_argument("--noise", default=0, type=float, help="the level of noise")
297 |     parser.add_argument("--batch_size", default=128, type=int, help="batch size")
298 |     parser.add_argument("--lam_p", default=1.0, type=float, help="weight of prediction loss")
299 |     parser.add_argument("--lam_c", default=8.0, type=float, help="weight of consistency loss")
300 |     parser.add_argument("--lam_cur", default=8.0, type=float, help="weight of curvature loss")
301 |     parser.add_argument("--vae_coeff", default=0.01, type=float, help="coefficient of additional vae loss")
302 |     parser.add_argument("--determ_coeff", default=0.3, type=float, help="coefficient of additional deterministic loss")
303 |     parser.add_argument("--lr", default=0.0005, type=float, help="learning rate")
304 |     parser.add_argument("--decay", default=0.001, type=float, help="L2 regularization")
305 |     parser.add_argument("--num_iter", default=2000, type=int, help="number of epochs")
306 |     parser.add_argument(
307 |         "--iter_save", default=1000, type=int, help="save model and result after this number of iterations"
308 |     )
309 |     parser.add_argument("--save_map", default=False, type=str2bool, help="save the latent map during training or not")
310 |     args = parser.parse_args()
311 | 
312 |     main(args)
313 | 
--------------------------------------------------------------------------------
/true_map.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VinAIResearch/PCC-pytorch/1127f8314cdd45faeca8bd429777c50fcfc192a5/true_map.png
--------------------------------------------------------------------------------
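To close, a hedged sketch (not repository code) of reloading a checkpoint written by `train_pcc.py` and running a one-step prediction with `PCC.predict`. The path follows the `result/{env}/{log_dir}/model_{iteration}` convention used above; `planar_1` and `model_5000` are example values to substitute with your own run.

```python
# Hypothetical sketch: load a trained planar model and predict one step ahead.
# "result/planar/planar_1/model_5000" is only an example of train_pcc.py's
# saving convention; substitute your own env, log_dir and saved iteration.
import torch

from pcc_model import PCC

torch.set_default_dtype(torch.float64)

model = PCC(armotized=False, x_dim=1600, z_dim=2, u_dim=2, env="planar")
model.load_state_dict(torch.load("result/planar/planar_1/model_5000", map_location="cpu"))
model.eval()

x = torch.rand(1, 1600)  # one flattened 40x40 observation (dummy data)
u = torch.randn(1, 2)    # one action
with torch.no_grad():
    x_recon, x_next_pred = model.predict(x, u)
```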