├── .gitignore
├── requirements.txt
├── gif
│   ├── Pendulum-v0-2level.gif
│   ├── MountainCarContinuous-v0.gif
│   └── MountainCarContinuous-v0-3level.gif
├── preTrained
│   ├── Pendulum-h-v1
│   │   └── 2level
│   │       ├── HAC_Pendulum-h-v1_level_0_actor.pth
│   │       ├── HAC_Pendulum-h-v1_level_0_crtic.pth
│   │       ├── HAC_Pendulum-h-v1_level_1_actor.pth
│   │       ├── HAC_Pendulum-h-v1_level_1_crtic.pth
│   │       ├── HAC_Pendulum-h-v1_solved_level_0_actor.pth
│   │       ├── HAC_Pendulum-h-v1_solved_level_0_crtic.pth
│   │       ├── HAC_Pendulum-h-v1_solved_level_1_actor.pth
│   │       └── HAC_Pendulum-h-v1_solved_level_1_crtic.pth
│   └── MountainCarContinuous-h-v1
│       ├── 2level
│       │   ├── HAC_MountainCarContinuous-h-v1_level_0_actor.pth
│       │   ├── HAC_MountainCarContinuous-h-v1_level_0_crtic.pth
│       │   ├── HAC_MountainCarContinuous-h-v1_level_1_actor.pth
│       │   ├── HAC_MountainCarContinuous-h-v1_level_1_crtic.pth
│       │   ├── HAC_MountainCarContinuous-h-v1_solved_level_0_actor.pth
│       │   ├── HAC_MountainCarContinuous-h-v1_solved_level_0_crtic.pth
│       │   ├── HAC_MountainCarContinuous-h-v1_solved_level_1_actor.pth
│       │   └── HAC_MountainCarContinuous-h-v1_solved_level_1_crtic.pth
│       └── 3level
│           ├── HAC_MountainCarContinuous-h-v1_level_0_actor.pth
│           ├── HAC_MountainCarContinuous-h-v1_level_0_crtic.pth
│           ├── HAC_MountainCarContinuous-h-v1_level_1_actor.pth
│           ├── HAC_MountainCarContinuous-h-v1_level_1_crtic.pth
│           ├── HAC_MountainCarContinuous-h-v1_level_2_actor.pth
│           ├── HAC_MountainCarContinuous-h-v1_level_2_crtic.pth
│           ├── HAC_MountainCarContinuous-h-v1_solved_level_0_actor.pth
│           ├── HAC_MountainCarContinuous-h-v1_solved_level_0_crtic.pth
│           ├── HAC_MountainCarContinuous-h-v1_solved_level_1_actor.pth
│           ├── HAC_MountainCarContinuous-h-v1_solved_level_1_crtic.pth
│           ├── HAC_MountainCarContinuous-h-v1_solved_level_2_actor.pth
│           └── HAC_MountainCarContinuous-h-v1_solved_level_2_crtic.pth
├── asset
│   ├── __init__.py
│   ├── pendulum.py
│   ├── rendering.py
│   └── continuous_mountain_car.py
├── LICENSE
├── utils.py
├── README.md
├── test.py
├── train.py
├── DDPG.py
└── HAC.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | gym
2 | torch
3 | pyglet
4 | six
--------------------------------------------------------------------------------
/gif/Pendulum-v0-2level.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/gif/Pendulum-v0-2level.gif
--------------------------------------------------------------------------------
/gif/MountainCarContinuous-v0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/gif/MountainCarContinuous-v0.gif
--------------------------------------------------------------------------------
/gif/MountainCarContinuous-v0-3level.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/gif/MountainCarContinuous-v0-3level.gif
--------------------------------------------------------------------------------
/preTrained/Pendulum-h-v1/2level/HAC_Pendulum-h-v1_level_0_actor.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/Pendulum-h-v1/2level/HAC_Pendulum-h-v1_level_0_actor.pth -------------------------------------------------------------------------------- /preTrained/Pendulum-h-v1/2level/HAC_Pendulum-h-v1_level_0_crtic.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/Pendulum-h-v1/2level/HAC_Pendulum-h-v1_level_0_crtic.pth -------------------------------------------------------------------------------- /preTrained/Pendulum-h-v1/2level/HAC_Pendulum-h-v1_level_1_actor.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/Pendulum-h-v1/2level/HAC_Pendulum-h-v1_level_1_actor.pth -------------------------------------------------------------------------------- /preTrained/Pendulum-h-v1/2level/HAC_Pendulum-h-v1_level_1_crtic.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/Pendulum-h-v1/2level/HAC_Pendulum-h-v1_level_1_crtic.pth -------------------------------------------------------------------------------- /preTrained/Pendulum-h-v1/2level/HAC_Pendulum-h-v1_solved_level_0_actor.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/Pendulum-h-v1/2level/HAC_Pendulum-h-v1_solved_level_0_actor.pth -------------------------------------------------------------------------------- /preTrained/Pendulum-h-v1/2level/HAC_Pendulum-h-v1_solved_level_0_crtic.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/Pendulum-h-v1/2level/HAC_Pendulum-h-v1_solved_level_0_crtic.pth -------------------------------------------------------------------------------- /preTrained/Pendulum-h-v1/2level/HAC_Pendulum-h-v1_solved_level_1_actor.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/Pendulum-h-v1/2level/HAC_Pendulum-h-v1_solved_level_1_actor.pth -------------------------------------------------------------------------------- /preTrained/Pendulum-h-v1/2level/HAC_Pendulum-h-v1_solved_level_1_crtic.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/Pendulum-h-v1/2level/HAC_Pendulum-h-v1_solved_level_1_crtic.pth -------------------------------------------------------------------------------- /preTrained/MountainCarContinuous-h-v1/2level/HAC_MountainCarContinuous-h-v1_level_0_actor.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/MountainCarContinuous-h-v1/2level/HAC_MountainCarContinuous-h-v1_level_0_actor.pth -------------------------------------------------------------------------------- 
/preTrained/MountainCarContinuous-h-v1/2level/HAC_MountainCarContinuous-h-v1_level_0_crtic.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/MountainCarContinuous-h-v1/2level/HAC_MountainCarContinuous-h-v1_level_0_crtic.pth -------------------------------------------------------------------------------- /preTrained/MountainCarContinuous-h-v1/2level/HAC_MountainCarContinuous-h-v1_level_1_actor.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/MountainCarContinuous-h-v1/2level/HAC_MountainCarContinuous-h-v1_level_1_actor.pth -------------------------------------------------------------------------------- /preTrained/MountainCarContinuous-h-v1/2level/HAC_MountainCarContinuous-h-v1_level_1_crtic.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/MountainCarContinuous-h-v1/2level/HAC_MountainCarContinuous-h-v1_level_1_crtic.pth -------------------------------------------------------------------------------- /preTrained/MountainCarContinuous-h-v1/3level/HAC_MountainCarContinuous-h-v1_level_0_actor.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/MountainCarContinuous-h-v1/3level/HAC_MountainCarContinuous-h-v1_level_0_actor.pth -------------------------------------------------------------------------------- /preTrained/MountainCarContinuous-h-v1/3level/HAC_MountainCarContinuous-h-v1_level_0_crtic.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/MountainCarContinuous-h-v1/3level/HAC_MountainCarContinuous-h-v1_level_0_crtic.pth -------------------------------------------------------------------------------- /preTrained/MountainCarContinuous-h-v1/3level/HAC_MountainCarContinuous-h-v1_level_1_actor.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/MountainCarContinuous-h-v1/3level/HAC_MountainCarContinuous-h-v1_level_1_actor.pth -------------------------------------------------------------------------------- /preTrained/MountainCarContinuous-h-v1/3level/HAC_MountainCarContinuous-h-v1_level_1_crtic.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/MountainCarContinuous-h-v1/3level/HAC_MountainCarContinuous-h-v1_level_1_crtic.pth -------------------------------------------------------------------------------- /preTrained/MountainCarContinuous-h-v1/3level/HAC_MountainCarContinuous-h-v1_level_2_actor.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/MountainCarContinuous-h-v1/3level/HAC_MountainCarContinuous-h-v1_level_2_actor.pth 
-------------------------------------------------------------------------------- /preTrained/MountainCarContinuous-h-v1/3level/HAC_MountainCarContinuous-h-v1_level_2_crtic.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/MountainCarContinuous-h-v1/3level/HAC_MountainCarContinuous-h-v1_level_2_crtic.pth -------------------------------------------------------------------------------- /preTrained/MountainCarContinuous-h-v1/2level/HAC_MountainCarContinuous-h-v1_solved_level_0_actor.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/MountainCarContinuous-h-v1/2level/HAC_MountainCarContinuous-h-v1_solved_level_0_actor.pth -------------------------------------------------------------------------------- /preTrained/MountainCarContinuous-h-v1/2level/HAC_MountainCarContinuous-h-v1_solved_level_0_crtic.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/MountainCarContinuous-h-v1/2level/HAC_MountainCarContinuous-h-v1_solved_level_0_crtic.pth -------------------------------------------------------------------------------- /preTrained/MountainCarContinuous-h-v1/2level/HAC_MountainCarContinuous-h-v1_solved_level_1_actor.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/MountainCarContinuous-h-v1/2level/HAC_MountainCarContinuous-h-v1_solved_level_1_actor.pth -------------------------------------------------------------------------------- /preTrained/MountainCarContinuous-h-v1/2level/HAC_MountainCarContinuous-h-v1_solved_level_1_crtic.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/MountainCarContinuous-h-v1/2level/HAC_MountainCarContinuous-h-v1_solved_level_1_crtic.pth -------------------------------------------------------------------------------- /preTrained/MountainCarContinuous-h-v1/3level/HAC_MountainCarContinuous-h-v1_solved_level_0_actor.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/MountainCarContinuous-h-v1/3level/HAC_MountainCarContinuous-h-v1_solved_level_0_actor.pth -------------------------------------------------------------------------------- /preTrained/MountainCarContinuous-h-v1/3level/HAC_MountainCarContinuous-h-v1_solved_level_0_crtic.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/MountainCarContinuous-h-v1/3level/HAC_MountainCarContinuous-h-v1_solved_level_0_crtic.pth -------------------------------------------------------------------------------- /preTrained/MountainCarContinuous-h-v1/3level/HAC_MountainCarContinuous-h-v1_solved_level_1_actor.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/MountainCarContinuous-h-v1/3level/HAC_MountainCarContinuous-h-v1_solved_level_1_actor.pth -------------------------------------------------------------------------------- /preTrained/MountainCarContinuous-h-v1/3level/HAC_MountainCarContinuous-h-v1_solved_level_1_crtic.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/MountainCarContinuous-h-v1/3level/HAC_MountainCarContinuous-h-v1_solved_level_1_crtic.pth -------------------------------------------------------------------------------- /preTrained/MountainCarContinuous-h-v1/3level/HAC_MountainCarContinuous-h-v1_solved_level_2_actor.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/MountainCarContinuous-h-v1/3level/HAC_MountainCarContinuous-h-v1_solved_level_2_actor.pth -------------------------------------------------------------------------------- /preTrained/MountainCarContinuous-h-v1/3level/HAC_MountainCarContinuous-h-v1_solved_level_2_crtic.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/HEAD/preTrained/MountainCarContinuous-h-v1/3level/HAC_MountainCarContinuous-h-v1_solved_level_2_crtic.pth -------------------------------------------------------------------------------- /asset/__init__.py: -------------------------------------------------------------------------------- 1 | from asset.continuous_mountain_car import Continuous_MountainCarEnv 2 | from asset.pendulum import PendulumEnv 3 | 4 | from gym.envs.registration import register 5 | 6 | register( 7 | id="MountainCarContinuous-h-v1", 8 | entry_point="asset:Continuous_MountainCarEnv", 9 | ) 10 | 11 | register( 12 | id="Pendulum-h-v1", 13 | entry_point="asset:PendulumEnv", 14 | ) 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Nikhil Barhate 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | class ReplayBuffer:
4 |     def __init__(self, max_size=5e5):
5 |         self.buffer = []
6 |         self.max_size = int(max_size)
7 |         self.size = 0
8 | 
9 |     def add(self, transition):
10 |         assert len(transition) == 7, "transition must have length = 7"
11 | 
12 |         # transition is a tuple of (state, action, reward, next_state, goal, gamma, done)
13 |         self.buffer.append(transition)
14 |         self.size +=1
15 | 
16 |     def sample(self, batch_size):
17 |         # delete 1/5th of the buffer when full
18 |         if self.size > self.max_size:
19 |             del self.buffer[0:int(self.size/5)]
20 |             self.size = len(self.buffer)
21 | 
22 |         indexes = np.random.randint(0, len(self.buffer), size=batch_size)
23 |         states, actions, rewards, next_states, goals, gamma, dones = [], [], [], [], [], [], []
24 | 
25 |         for i in indexes:
26 |             states.append(np.array(self.buffer[i][0], copy=False))
27 |             actions.append(np.array(self.buffer[i][1], copy=False))
28 |             rewards.append(np.array(self.buffer[i][2], copy=False))
29 |             next_states.append(np.array(self.buffer[i][3], copy=False))
30 |             goals.append(np.array(self.buffer[i][4], copy=False))
31 |             gamma.append(np.array(self.buffer[i][5], copy=False))
32 |             dones.append(np.array(self.buffer[i][6], copy=False))
33 | 
34 |         return np.array(states), np.array(actions), np.array(rewards), np.array(next_states), np.array(goals), np.array(gamma), np.array(dones)
35 | 
36 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Hierarchical-Actor-Critic-HAC-PyTorch
2 | 
3 | This is a PyTorch implementation of the Hierarchical Actor Critic (HAC) algorithm described in the paper [Learning Multi-Level Hierarchies with Hindsight](https://arxiv.org/abs/1712.00948) (ICLR 2019), for OpenAI Gym environments. The algorithm learns to reach a goal state by dividing the task into short-horizon intermediate goals (subgoals).
4 | 
5 | 
6 | 
7 | ## Usage
8 | - All the hyperparameters are contained in the `train.py` file.
9 | - To train a new network, run `train.py`
10 | - To test a preTrained network, run `test.py`
11 | - For a detailed explanation of offsets and bounds, refer to [issue #2](https://github.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/issues/2); a small worked example is included under Implementation Details below.
12 | - For the hyperparameters used for preTraining the pendulum policy, refer to [issue #3](https://github.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/issues/3)
13 | 
14 | ## Implementation Details
15 | 
16 | - The code is implemented as described in the appendix section of the paper and the official repository, i.e. without target networks and with bounded Q-values.
17 | - The Actor and Critic networks have 2 hidden layers of size 64.
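As a quick illustration of the `action = ( network output (Tanh) * bounds ) + offset` scaling described in `train.py` (see the offsets-and-bounds issue linked above), the sketch below plugs in the MountainCarContinuous subgoal values from `train.py`; the `scale_subgoal` helper is purely illustrative and not part of the repository.

```python
import numpy as np

# Subgoal space for MountainCarContinuous-h-v1 (values taken from train.py)
state_bounds = np.array([0.9, 0.07])      # half-range of [position, velocity]
state_offset = np.array([-0.3, 0.0])      # centre of the subgoal space
state_clip_low = np.array([-1.2, -0.07])
state_clip_high = np.array([0.6, 0.07])

def scale_subgoal(tanh_output):
    """Illustrative helper: map a Tanh output in [-1, 1]^2 onto the valid state region."""
    subgoal = (tanh_output * state_bounds) + state_offset
    return np.clip(subgoal, state_clip_low, state_clip_high)

print(scale_subgoal(np.array([ 1.0,  1.0])))    # [ 0.6   0.07] -> upper corner of the state space
print(scale_subgoal(np.array([-1.0, -1.0])))    # [-1.2  -0.07] -> lower corner of the state space
```

With these values a Tanh output of ±1 lands exactly on the corners of the environment's position/velocity box, so the higher level can propose any reachable state as a subgoal; the same formula with `action_bounds`/`action_offset` produces the primitive action at the lowest level.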
18 | 19 | ## Citing 20 | 21 | Please use this bibtex if you want to cite this repository in your publications : 22 | 23 | @misc{pytorch_hac, 24 | author = {Barhate, Nikhil}, 25 | title = {PyTorch Implementation of Hierarchical Actor-Critic}, 26 | year = {2021}, 27 | publisher = {GitHub}, 28 | journal = {GitHub repository}, 29 | howpublished = {\url{https://github.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch}}, 30 | } 31 | 32 | ## Requirements 33 | 34 | - Python 3.6 35 | - [PyTorch](https://pytorch.org/) 36 | - [OpenAI gym](https://gym.openai.com/) 37 | 38 | 39 | 40 | ## Results 41 | 42 | ### MountainCarContinuous-v0 43 | (2 levels, H = 20, 200 episodes) | (3 levels, H = 5, 200 episodes) | 44 | :-----------------------------------:|:-----------------------------------:| 45 | ![](https://github.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/blob/master/gif/MountainCarContinuous-v0.gif) | ![](https://github.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/blob/master/gif/MountainCarContinuous-v0-3level.gif) | 46 | 47 | (2 levels, H = 20, 200 episodes) | 48 | :---------------------------------:| 49 | ![](https://github.com/nikhilbarhate99/Hierarchical-Actor-Critic-HAC-PyTorch/blob/master/gif/Pendulum-v0-2level.gif) | 50 | 51 | 52 | ## References 53 | 54 | - Official [Paper](https://arxiv.org/abs/1712.00948) and [Code (TensorFlow)](https://github.com/andrew-j-levy/Hierarchical-Actor-Critc-HAC-) 55 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import gym 3 | import asset 4 | import numpy as np 5 | from HAC import HAC 6 | 7 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 8 | 9 | def test(): 10 | 11 | #################### Hyperparameters #################### 12 | env_name = "MountainCarContinuous-h-v1"#"MountainCarContinuous-v0" 13 | save_episode = 10 # keep saving every n episodes 14 | max_episodes = 5 # max num of training episodes 15 | random_seed = 0 16 | render = False 17 | 18 | env = gym.make(env_name) 19 | state_dim = env.observation_space.shape[0] 20 | action_dim = env.action_space.shape[0] 21 | 22 | """ 23 | Actions (both primitive and subgoal) are implemented as follows: 24 | action = ( network output (Tanh) * bounds ) + offset 25 | clip_high and clip_low bound the exploration noise 26 | """ 27 | 28 | # primitive action bounds and offset 29 | action_bounds = env.action_space.high[0] 30 | action_offset = np.array([0.0]) 31 | action_offset = torch.FloatTensor(action_offset.reshape(1, -1)).to(device) 32 | action_clip_low = np.array([-1.0 * action_bounds]) 33 | action_clip_high = np.array([action_bounds]) 34 | 35 | # state bounds and offset 36 | state_bounds_np = np.array([0.9, 0.07]) 37 | state_bounds = torch.FloatTensor(state_bounds_np.reshape(1, -1)).to(device) 38 | state_offset = np.array([-0.3, 0.0]) 39 | state_offset = torch.FloatTensor(state_offset.reshape(1, -1)).to(device) 40 | state_clip_low = np.array([-1.2, -0.07]) 41 | state_clip_high = np.array([0.6, 0.07]) 42 | 43 | # exploration noise std for primitive action and subgoals 44 | exploration_action_noise = np.array([0.1]) 45 | exploration_state_noise = np.array([0.02, 0.01]) 46 | 47 | goal_state = np.array([0.48, 0.04]) # final goal state to be achived 48 | threshold = np.array([0.01, 0.02]) # threshold value to check if goal state is achieved 49 | 50 | # HAC parameters: 51 | k_level = 2 # num of levels in hierarchy 52 | H = 
20 # time horizon to achieve subgoal 53 | lamda = 0.3 # subgoal testing parameter 54 | 55 | # DDPG parameters: 56 | gamma = 0.95 # discount factor for future rewards 57 | n_iter = 100 # update policy n_iter times in one DDPG update 58 | batch_size = 100 # num of transitions sampled from replay buffer 59 | lr = 0.001 60 | 61 | # save trained models 62 | directory = "./preTrained/{}/{}level/".format(env_name, k_level) 63 | filename = "HAC_{}".format(env_name) 64 | ######################################################### 65 | 66 | if random_seed: 67 | print("Random Seed: {}".format(random_seed)) 68 | env.seed(random_seed) 69 | torch.manual_seed(random_seed) 70 | np.random.seed(random_seed) 71 | 72 | # creating HAC agent and setting parameters 73 | agent = HAC(k_level, H, state_dim, action_dim, render, threshold, 74 | action_bounds, action_offset, state_bounds, state_offset, lr) 75 | 76 | agent.set_parameters(lamda, gamma, action_clip_low, action_clip_high, 77 | state_clip_low, state_clip_high, exploration_action_noise, exploration_state_noise) 78 | 79 | # load agent 80 | agent.load(directory, filename) 81 | 82 | # Evaluation 83 | for i_episode in range(1, max_episodes+1): 84 | 85 | agent.reward = 0 86 | agent.timestep = 0 87 | 88 | state = env.reset() 89 | agent.run_HAC(env, k_level-1, state, goal_state, True) 90 | 91 | print("Episode: {}\t Reward: {}\t len: {}".format(i_episode, agent.reward, agent.timestep)) 92 | 93 | env.close() 94 | 95 | 96 | if __name__ == '__main__': 97 | test() 98 | 99 | 100 | 101 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import gym 3 | import asset 4 | import numpy as np 5 | from HAC import HAC 6 | 7 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 8 | 9 | def train(): 10 | #################### Hyperparameters #################### 11 | env_name = "MountainCarContinuous-h-v1" 12 | save_episode = 10 # keep saving every n episodes 13 | max_episodes = 1000 # max num of training episodes 14 | random_seed = 0 15 | render = False 16 | 17 | env = gym.make(env_name) 18 | state_dim = env.observation_space.shape[0] 19 | action_dim = env.action_space.shape[0] 20 | 21 | """ 22 | Actions (both primitive and subgoal) are implemented as follows: 23 | action = ( network output (Tanh) * bounds ) + offset 24 | clip_high and clip_low bound the exploration noise 25 | """ 26 | 27 | # primitive action bounds and offset 28 | action_bounds = env.action_space.high[0] 29 | action_offset = np.array([0.0]) 30 | action_offset = torch.FloatTensor(action_offset.reshape(1, -1)).to(device) 31 | action_clip_low = np.array([-1.0 * action_bounds]) 32 | action_clip_high = np.array([action_bounds]) 33 | 34 | # state bounds and offset 35 | state_bounds_np = np.array([0.9, 0.07]) 36 | state_bounds = torch.FloatTensor(state_bounds_np.reshape(1, -1)).to(device) 37 | state_offset = np.array([-0.3, 0.0]) 38 | state_offset = torch.FloatTensor(state_offset.reshape(1, -1)).to(device) 39 | state_clip_low = np.array([-1.2, -0.07]) 40 | state_clip_high = np.array([0.6, 0.07]) 41 | 42 | # exploration noise std for primitive action and subgoals 43 | exploration_action_noise = np.array([0.1]) 44 | exploration_state_noise = np.array([0.02, 0.01]) 45 | 46 | goal_state = np.array([0.48, 0.04]) # final goal state to be achived 47 | threshold = np.array([0.01, 0.02]) # threshold value to check if goal state is achieved 48 | 49 | # HAC parameters: 50 
| k_level = 2 # num of levels in hierarchy 51 | H = 20 # time horizon to achieve subgoal 52 | lamda = 0.3 # subgoal testing parameter 53 | 54 | # DDPG parameters: 55 | gamma = 0.95 # discount factor for future rewards 56 | n_iter = 100 # update policy n_iter times in one DDPG update 57 | batch_size = 100 # num of transitions sampled from replay buffer 58 | lr = 0.001 59 | 60 | # save trained models 61 | directory = "./preTrained/{}/{}level/".format(env_name, k_level) 62 | filename = "HAC_{}".format(env_name) 63 | ######################################################### 64 | 65 | 66 | if random_seed: 67 | print("Random Seed: {}".format(random_seed)) 68 | env.seed(random_seed) 69 | torch.manual_seed(random_seed) 70 | np.random.seed(random_seed) 71 | 72 | # creating HAC agent and setting parameters 73 | agent = HAC(k_level, H, state_dim, action_dim, render, threshold, 74 | action_bounds, action_offset, state_bounds, state_offset, lr) 75 | 76 | agent.set_parameters(lamda, gamma, action_clip_low, action_clip_high, 77 | state_clip_low, state_clip_high, exploration_action_noise, exploration_state_noise) 78 | 79 | # logging file: 80 | log_f = open("log.txt","w+") 81 | 82 | # training procedure 83 | for i_episode in range(1, max_episodes+1): 84 | agent.reward = 0 85 | agent.timestep = 0 86 | 87 | state = env.reset() 88 | # collecting experience in environment 89 | last_state, done = agent.run_HAC(env, k_level-1, state, goal_state, False) 90 | 91 | if agent.check_goal(last_state, goal_state, threshold): 92 | print("################ Solved! ################ ") 93 | name = filename + '_solved' 94 | agent.save(directory, name) 95 | 96 | # update all levels 97 | agent.update(n_iter, batch_size) 98 | 99 | # logging updates: 100 | log_f.write('{},{}\n'.format(i_episode, agent.reward)) 101 | log_f.flush() 102 | 103 | if i_episode % save_episode == 0: 104 | agent.save(directory, filename) 105 | 106 | print("Episode: {}\t Reward: {}".format(i_episode, agent.reward)) 107 | 108 | 109 | if __name__ == '__main__': 110 | train() 111 | 112 | -------------------------------------------------------------------------------- /DDPG.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | 5 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 6 | 7 | class Actor(nn.Module): 8 | def __init__(self, state_dim, action_dim, action_bounds, offset): 9 | super(Actor, self).__init__() 10 | # actor 11 | self.actor = nn.Sequential( 12 | nn.Linear(state_dim + state_dim, 64), 13 | nn.ReLU(), 14 | nn.Linear(64, 64), 15 | nn.ReLU(), 16 | nn.Linear(64, action_dim), 17 | nn.Tanh() 18 | ) 19 | # max value of actions 20 | self.action_bounds = action_bounds 21 | self.offset = offset 22 | 23 | def forward(self, state, goal): 24 | return (self.actor(torch.cat([state, goal], 1)) * self.action_bounds) + self.offset 25 | 26 | class Critic(nn.Module): 27 | def __init__(self, state_dim, action_dim, H): 28 | super(Critic, self).__init__() 29 | # UVFA critic 30 | self.critic = nn.Sequential( 31 | nn.Linear(state_dim + action_dim + state_dim, 64), 32 | nn.ReLU(), 33 | nn.Linear(64, 64), 34 | nn.ReLU(), 35 | nn.Linear(64, 1), 36 | nn.Sigmoid() 37 | ) 38 | self.H = H 39 | 40 | def forward(self, state, action, goal): 41 | # rewards are in range [-H, 0] 42 | return -self.critic(torch.cat([state, action, goal], 1)) * self.H 43 | 44 | class DDPG: 45 | def __init__(self, state_dim, action_dim, action_bounds, offset, lr, H): 46 | 
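        # Note: HAC (see HAC.py) creates one DDPG instance per level of the
        # hierarchy. The lowest level outputs primitive actions, while higher
        # levels output subgoal states, so their action_dim / bounds / offset
        # are state-sized. Q-values are kept in [-H, 0] by the Sigmoid critic
        # defined above.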
47 | self.actor = Actor(state_dim, action_dim, action_bounds, offset).to(device) 48 | self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr) 49 | 50 | self.critic = Critic(state_dim, action_dim, H).to(device) 51 | self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr) 52 | 53 | self.mseLoss = torch.nn.MSELoss() 54 | 55 | def select_action(self, state, goal): 56 | state = torch.FloatTensor(state.reshape(1, -1)).to(device) 57 | goal = torch.FloatTensor(goal.reshape(1, -1)).to(device) 58 | return self.actor(state, goal).detach().cpu().data.numpy().flatten() 59 | 60 | def update(self, buffer, n_iter, batch_size): 61 | 62 | for i in range(n_iter): 63 | # Sample a batch of transitions from replay buffer: 64 | state, action, reward, next_state, goal, gamma, done = buffer.sample(batch_size) 65 | 66 | # convert np arrays into tensors 67 | state = torch.FloatTensor(state).to(device) 68 | action = torch.FloatTensor(action).to(device) 69 | reward = torch.FloatTensor(reward).reshape((batch_size,1)).to(device) 70 | next_state = torch.FloatTensor(next_state).to(device) 71 | goal = torch.FloatTensor(goal).to(device) 72 | gamma = torch.FloatTensor(gamma).reshape((batch_size,1)).to(device) 73 | done = torch.FloatTensor(done).reshape((batch_size,1)).to(device) 74 | 75 | # select next action 76 | next_action = self.actor(next_state, goal).detach() 77 | 78 | # Compute target Q-value: 79 | target_Q = self.critic(next_state, next_action, goal).detach() 80 | target_Q = reward + ((1-done) * gamma * target_Q) 81 | 82 | # Optimize Critic: 83 | critic_loss = self.mseLoss(self.critic(state, action, goal), target_Q) 84 | self.critic_optimizer.zero_grad() 85 | critic_loss.backward() 86 | self.critic_optimizer.step() 87 | 88 | # Compute actor loss: 89 | actor_loss = -self.critic(state, self.actor(state, goal), goal).mean() 90 | 91 | # Optimize the actor 92 | self.actor_optimizer.zero_grad() 93 | actor_loss.backward() 94 | self.actor_optimizer.step() 95 | 96 | 97 | def save(self, directory, name): 98 | torch.save(self.actor.state_dict(), '%s/%s_actor.pth' % (directory, name)) 99 | torch.save(self.critic.state_dict(), '%s/%s_crtic.pth' % (directory, name)) 100 | 101 | def load(self, directory, name): 102 | self.actor.load_state_dict(torch.load('%s/%s_actor.pth' % (directory, name), map_location='cpu')) 103 | self.critic.load_state_dict(torch.load('%s/%s_crtic.pth' % (directory, name), map_location='cpu')) 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | -------------------------------------------------------------------------------- /HAC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from DDPG import DDPG 4 | from utils import ReplayBuffer 5 | 6 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 7 | 8 | class HAC: 9 | def __init__(self, k_level, H, state_dim, action_dim, render, threshold, 10 | action_bounds, action_offset, state_bounds, state_offset, lr): 11 | 12 | # adding lowest level 13 | self.HAC = [DDPG(state_dim, action_dim, action_bounds, action_offset, lr, H)] 14 | self.replay_buffer = [ReplayBuffer()] 15 | 16 | # adding remaining levels 17 | for _ in range(k_level-1): 18 | self.HAC.append(DDPG(state_dim, state_dim, state_bounds, state_offset, lr, H)) 19 | self.replay_buffer.append(ReplayBuffer()) 20 | 21 | # set some parameters 22 | self.k_level = k_level 23 | self.H = H 24 | self.action_dim = action_dim 25 | self.state_dim = state_dim 26 | self.threshold = threshold 27 | self.render = 
render 28 | 29 | # logging parameters 30 | self.goals = [None]*self.k_level 31 | self.reward = 0 32 | self.timestep = 0 33 | 34 | def set_parameters(self, lamda, gamma, action_clip_low, action_clip_high, 35 | state_clip_low, state_clip_high, exploration_action_noise, exploration_state_noise): 36 | 37 | self.lamda = lamda 38 | self.gamma = gamma 39 | self.action_clip_low = action_clip_low 40 | self.action_clip_high = action_clip_high 41 | self.state_clip_low = state_clip_low 42 | self.state_clip_high = state_clip_high 43 | self.exploration_action_noise = exploration_action_noise 44 | self.exploration_state_noise = exploration_state_noise 45 | 46 | 47 | def check_goal(self, state, goal, threshold): 48 | for i in range(self.state_dim): 49 | if abs(state[i]-goal[i]) > threshold[i]: 50 | return False 51 | return True 52 | 53 | 54 | def run_HAC(self, env, i_level, state, goal, is_subgoal_test): 55 | next_state = None 56 | done = None 57 | goal_transitions = [] 58 | 59 | # logging updates 60 | self.goals[i_level] = goal 61 | 62 | # H attempts 63 | for _ in range(self.H): 64 | # if this is a subgoal test, then next/lower level goal has to be a subgoal test 65 | is_next_subgoal_test = is_subgoal_test 66 | 67 | action = self.HAC[i_level].select_action(state, goal) 68 | 69 | # <================ high level policy ================> 70 | if i_level > 0: 71 | # add noise or take random action if not subgoal testing 72 | if not is_subgoal_test: 73 | if np.random.random_sample() > 0.2: 74 | action = action + np.random.normal(0, self.exploration_state_noise) 75 | action = action.clip(self.state_clip_low, self.state_clip_high) 76 | else: 77 | action = np.random.uniform(self.state_clip_low, self.state_clip_high) 78 | 79 | # Determine whether to test subgoal (action) 80 | if np.random.random_sample() < self.lamda: 81 | is_next_subgoal_test = True 82 | 83 | # Pass subgoal to lower level 84 | next_state, done = self.run_HAC(env, i_level-1, state, action, is_next_subgoal_test) 85 | 86 | # if subgoal was tested but not achieved, add subgoal testing transition 87 | if is_next_subgoal_test and not self.check_goal(action, next_state, self.threshold): 88 | self.replay_buffer[i_level].add((state, action, -self.H, next_state, goal, 0.0, float(done))) 89 | 90 | # for hindsight action transition 91 | action = next_state 92 | 93 | # <================ low level policy ================> 94 | else: 95 | # add noise or take random action if not subgoal testing 96 | if not is_subgoal_test: 97 | if np.random.random_sample() > 0.2: 98 | action = action + np.random.normal(0, self.exploration_action_noise) 99 | action = action.clip(self.action_clip_low, self.action_clip_high) 100 | else: 101 | action = np.random.uniform(self.action_clip_low, self.action_clip_high) 102 | 103 | # take primitive action 104 | next_state, rew, done, _ = env.step(action) 105 | 106 | if self.render: 107 | 108 | # env.render() ########## 109 | 110 | if self.k_level == 2: 111 | env.unwrapped.render_goal(self.goals[0], self.goals[1]) 112 | elif self.k_level == 3: 113 | env.unwrapped.render_goal_2(self.goals[0], self.goals[1], self.goals[2]) 114 | 115 | 116 | # this is for logging 117 | self.reward += rew 118 | self.timestep +=1 119 | 120 | # <================ finish one step/transition ================> 121 | 122 | # check if goal is achieved 123 | goal_achieved = self.check_goal(next_state, goal, self.threshold) 124 | 125 | # hindsight action transition 126 | if goal_achieved: 127 | self.replay_buffer[i_level].add((state, action, 0.0, next_state, goal, 0.0, 
float(done))) 128 | else: 129 | self.replay_buffer[i_level].add((state, action, -1.0, next_state, goal, self.gamma, float(done))) 130 | 131 | # copy for goal transition 132 | goal_transitions.append([state, action, -1.0, next_state, None, self.gamma, float(done)]) 133 | 134 | state = next_state 135 | 136 | if done or goal_achieved: 137 | break 138 | 139 | 140 | # <================ finish H attempts ================> 141 | 142 | # hindsight goal transition 143 | # last transition reward and discount is 0 144 | goal_transitions[-1][2] = 0.0 145 | goal_transitions[-1][5] = 0.0 146 | for transition in goal_transitions: 147 | # last state is goal for all transitions 148 | transition[4] = next_state 149 | self.replay_buffer[i_level].add(tuple(transition)) 150 | 151 | return next_state, done 152 | 153 | 154 | def update(self, n_iter, batch_size): 155 | for i in range(self.k_level): 156 | self.HAC[i].update(self.replay_buffer[i], n_iter, batch_size) 157 | 158 | 159 | def save(self, directory, name): 160 | for i in range(self.k_level): 161 | self.HAC[i].save(directory, name+'_level_{}'.format(i)) 162 | 163 | 164 | def load(self, directory, name): 165 | for i in range(self.k_level): 166 | self.HAC[i].load(directory, name+'_level_{}'.format(i)) 167 | 168 | 169 | 170 | 171 | 172 | 173 | -------------------------------------------------------------------------------- /asset/pendulum.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym import spaces 3 | from gym.utils import seeding 4 | import numpy as np 5 | from os import path 6 | 7 | class PendulumEnv(gym.Env): 8 | metadata = { 9 | 'render.modes' : ['human', 'rgb_array'], 10 | 'video.frames_per_second' : 30 11 | } 12 | 13 | def __init__(self): 14 | self.max_speed=8 15 | self.max_torque=2. 16 | self.dt=.05 17 | self.viewer = None 18 | 19 | high = np.array([1., 1., self.max_speed]) 20 | self.action_space = spaces.Box(low=-self.max_torque, high=self.max_torque, shape=(1,), dtype=np.float32) 21 | self.observation_space = spaces.Box(low=-high, high=high, dtype=np.float32) 22 | 23 | self.seed() 24 | 25 | def seed(self, seed=None): 26 | self.np_random, seed = seeding.np_random(seed) 27 | return [seed] 28 | 29 | def step(self,u): 30 | th, thdot = self.state # th := theta 31 | 32 | g = 10. 33 | m = 1. 34 | l = 1. 
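        # Below: clip the torque u, compute the quadratic cost on angle, angular
        # velocity and torque, then Euler-integrate
        #     newthdot = thdot + (-3*g/(2*l) * sin(th + pi) + 3/(m*l**2) * u) * dt
        # and wrap the resulting angle. Unlike the stock gym Pendulum-v0, this
        # modified env returns [theta, thetadot] directly rather than the
        # (cos(theta), sin(theta), thetadot) observation.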
35 | dt = self.dt 36 | 37 | u = np.clip(u, -self.max_torque, self.max_torque)[0] 38 | self.last_u = u # for rendering 39 | costs = angle_normalize(th)**2 + .1*thdot**2 + .001*(u**2) 40 | 41 | newthdot = thdot + (-3*g/(2*l) * np.sin(th + np.pi) + 3./(m*l**2)*u) * dt 42 | 43 | newth = angle_normalize(th + newthdot*dt) ##### 44 | 45 | newthdot = np.clip(newthdot, -self.max_speed, self.max_speed) #pylint: disable=E1111 46 | 47 | self.state = np.array([newth, newthdot]) 48 | # return self._get_obs(), -costs, False, {} 49 | return self.state, -costs, False, {} ##### 50 | 51 | def reset(self): 52 | high = np.array([np.pi, 1]) 53 | self.state = self.np_random.uniform(low=-high, high=high) 54 | self.last_u = None 55 | # return self._get_obs() 56 | return self.state ##### 57 | 58 | def _get_obs(self): 59 | theta, thetadot = self.state 60 | return np.array([np.cos(theta), np.sin(theta), thetadot]) 61 | 62 | def render(self, mode='human'): 63 | 64 | if self.viewer is None: 65 | from asset import rendering 66 | self.viewer = rendering.Viewer(500,500) 67 | self.viewer.set_bounds(-2.2,2.2,-2.2,2.2) 68 | rod = rendering.make_capsule(1, .2) 69 | rod.set_color(.8, .3, .3) 70 | self.pole_transform = rendering.Transform() 71 | rod.add_attr(self.pole_transform) 72 | self.viewer.add_geom(rod) 73 | axle = rendering.make_circle(.05) 74 | axle.set_color(0,0,0) 75 | self.viewer.add_geom(axle) 76 | fname = path.join(path.dirname(__file__), "assets/clockwise.png") 77 | self.img = rendering.Image(fname, 1., 1.) 78 | self.imgtrans = rendering.Transform() 79 | self.img.add_attr(self.imgtrans) 80 | 81 | self.viewer.add_onetime(self.img) 82 | self.pole_transform.set_rotation(self.state[0] + np.pi/2) 83 | if self.last_u: 84 | self.imgtrans.scale = (-self.last_u/2, np.abs(self.last_u)/2) 85 | 86 | return self.viewer.render(return_rgb_array = mode=='rgb_array') 87 | 88 | 89 | def render_goal(self, goal, end_goal, mode='human'): 90 | 91 | if self.viewer is None: 92 | from asset import rendering 93 | self.viewer = rendering.Viewer(500,500) 94 | self.viewer.set_bounds(-2.2,2.2,-2.2,2.2) 95 | 96 | rod = rendering.make_capsule(1, .2) 97 | rod.set_color(.8, .3, .3) 98 | self.pole_transform = rendering.Transform() 99 | rod.add_attr(self.pole_transform) 100 | self.viewer.add_geom(rod) 101 | 102 | ################ goal ################ 103 | rod1 = rendering.make_goal_circ(1, .1) 104 | rod1.set_color(.8, .8, .3) 105 | self.pole_transform1 = rendering.Transform() 106 | rod1.add_attr(self.pole_transform1) 107 | self.viewer.add_geom(rod1) 108 | ###################################### 109 | 110 | ############## End Goal ############## 111 | rod2 = rendering.make_goal_circ(1, .1) 112 | rod2.set_color(.3, .3, .8) 113 | self.pole_transform2 = rendering.Transform() 114 | rod2.add_attr(self.pole_transform2) 115 | self.viewer.add_geom(rod2) 116 | ###################################### 117 | 118 | axle = rendering.make_circle(.05) 119 | axle.set_color(0,0,0) 120 | self.viewer.add_geom(axle) 121 | fname = path.join(path.dirname(__file__), "assets/clockwise.png") 122 | self.img = rendering.Image(fname, 1., 1.) 
123 | self.imgtrans = rendering.Transform() 124 | self.img.add_attr(self.imgtrans) 125 | 126 | # self.viewer.add_onetime(self.img) 127 | self.pole_transform.set_rotation(self.state[0] + np.pi/2) 128 | 129 | self.pole_transform1.set_rotation(goal[0] + np.pi/2) 130 | 131 | self.pole_transform2.set_rotation(end_goal[0] + np.pi/2) 132 | 133 | if self.last_u: 134 | self.imgtrans.scale = (-self.last_u/2, np.abs(self.last_u)/2) 135 | 136 | return self.viewer.render(return_rgb_array = mode=='rgb_array') 137 | 138 | def render_goal_2(self, goal1, goal2, end_goal, mode='human'): 139 | 140 | if self.viewer is None: 141 | from asset import rendering 142 | self.viewer = rendering.Viewer(500,500) 143 | self.viewer.set_bounds(-2.2,2.2,-2.2,2.2) 144 | 145 | rod = rendering.make_capsule(1, .2) 146 | rod.set_color(.8, .3, .3) 147 | self.pole_transform = rendering.Transform() 148 | rod.add_attr(self.pole_transform) 149 | self.viewer.add_geom(rod) 150 | 151 | 152 | ################ goal 1 ################ 153 | rod1 = rendering.make_goal_circ(1, .1) 154 | rod1.set_color(.8, .8, .3) 155 | self.pole_transform1 = rendering.Transform() 156 | rod1.add_attr(self.pole_transform1) 157 | self.viewer.add_geom(rod1) 158 | ######################################## 159 | 160 | 161 | ################ goal 2 ################ 162 | rod2 = rendering.make_goal_circ(1, .1) 163 | rod2.set_color(.3, .8, .3) 164 | self.pole_transform2 = rendering.Transform() 165 | rod2.add_attr(self.pole_transform2) 166 | self.viewer.add_geom(rod2) 167 | ######################################## 168 | 169 | 170 | ############### End Goal ############### 171 | rod3 = rendering.make_goal_circ(1, .1) 172 | rod3.set_color(.3, .3, .8) 173 | self.pole_transform3 = rendering.Transform() 174 | rod3.add_attr(self.pole_transform3) 175 | self.viewer.add_geom(rod3) 176 | ######################################## 177 | 178 | axle = rendering.make_circle(.05) 179 | axle.set_color(0,0,0) 180 | self.viewer.add_geom(axle) 181 | fname = path.join(path.dirname(__file__), "assets/clockwise.png") 182 | self.img = rendering.Image(fname, 1., 1.) 183 | self.imgtrans = rendering.Transform() 184 | self.img.add_attr(self.imgtrans) 185 | 186 | # self.viewer.add_onetime(self.img) 187 | 188 | self.pole_transform.set_rotation(self.state[0] + np.pi/2) 189 | 190 | self.pole_transform1.set_rotation(goal1[0] + np.pi/2) 191 | 192 | self.pole_transform2.set_rotation(goal2[0] + np.pi/2) 193 | 194 | self.pole_transform3.set_rotation(end_goal[0] + np.pi/2) 195 | 196 | if self.last_u: 197 | self.imgtrans.scale = (-self.last_u/2, np.abs(self.last_u)/2) 198 | 199 | return self.viewer.render(return_rgb_array = mode=='rgb_array') 200 | 201 | 202 | def close(self): 203 | if self.viewer: 204 | self.viewer.close() 205 | self.viewer = None 206 | 207 | 208 | def angle_normalize(x): 209 | return (((x+np.pi) % (2*np.pi)) - np.pi) 210 | -------------------------------------------------------------------------------- /asset/rendering.py: -------------------------------------------------------------------------------- 1 | """ 2 | 2D rendering framework 3 | """ 4 | from __future__ import division 5 | import os 6 | import six 7 | import sys 8 | 9 | if "Apple" in sys.version: 10 | if 'DYLD_FALLBACK_LIBRARY_PATH' in os.environ: 11 | os.environ['DYLD_FALLBACK_LIBRARY_PATH'] += ':/usr/lib' 12 | # (JDS 2016/04/15): avoid bug on Anaconda 2.3.0 / Yosemite 13 | 14 | from gym import error 15 | 16 | try: 17 | import pyglet 18 | except ImportError as e: 19 | raise ImportError(''' 20 | Cannot import pyglet. 
21 | HINT: you can install pyglet directly via 'pip install pyglet'. 22 | But if you really just want to install all Gym dependencies and not have to think about it, 23 | 'pip install -e .[all]' or 'pip install gym[all]' will do it. 24 | ''') 25 | 26 | try: 27 | from pyglet.gl import * 28 | except ImportError as e: 29 | raise ImportError(''' 30 | Error occured while running `from pyglet.gl import *` 31 | HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get install python-opengl'. 32 | If you're running on a server, you may need a virtual frame buffer; something like this should work: 33 | 'xvfb-run -s \"-screen 0 1400x900x24\" python ' 34 | ''') 35 | 36 | import math 37 | import numpy as np 38 | 39 | RAD2DEG = 57.29577951308232 40 | 41 | def get_display(spec): 42 | """Convert a display specification (such as :0) into an actual Display 43 | object. 44 | 45 | Pyglet only supports multiple Displays on Linux. 46 | """ 47 | if spec is None: 48 | return None 49 | elif isinstance(spec, six.string_types): 50 | return pyglet.canvas.Display(spec) 51 | else: 52 | raise error.Error('Invalid display specification: {}. (Must be a string like :0 or None.)'.format(spec)) 53 | 54 | class Viewer(object): 55 | def __init__(self, width, height, display=None): 56 | display = get_display(display) 57 | 58 | self.width = width 59 | self.height = height 60 | self.window = pyglet.window.Window(width=width, height=height, display=display) 61 | self.window.on_close = self.window_closed_by_user 62 | self.isopen = True 63 | self.geoms = [] 64 | self.onetime_geoms = [] 65 | self.transform = Transform() 66 | 67 | glEnable(GL_BLEND) 68 | glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA) 69 | 70 | def close(self): 71 | self.window.close() 72 | 73 | def window_closed_by_user(self): 74 | self.isopen = False 75 | 76 | def set_bounds(self, left, right, bottom, top): 77 | assert right > left and top > bottom 78 | scalex = self.width/(right-left) 79 | scaley = self.height/(top-bottom) 80 | self.transform = Transform( 81 | translation=(-left*scalex, -bottom*scaley), 82 | scale=(scalex, scaley)) 83 | 84 | def add_geom(self, geom): 85 | self.geoms.append(geom) 86 | 87 | def add_onetime(self, geom): 88 | self.onetime_geoms.append(geom) 89 | 90 | def render(self, return_rgb_array=False): 91 | glClearColor(1,1,1,1) 92 | self.window.clear() 93 | self.window.switch_to() 94 | self.window.dispatch_events() 95 | self.transform.enable() 96 | for geom in self.geoms: 97 | geom.render() 98 | for geom in self.onetime_geoms: 99 | geom.render() 100 | self.transform.disable() 101 | arr = None 102 | if return_rgb_array: 103 | buffer = pyglet.image.get_buffer_manager().get_color_buffer() 104 | image_data = buffer.get_image_data() 105 | arr = np.frombuffer(image_data.data, dtype=np.uint8) 106 | # In https://github.com/openai/gym-http-api/issues/2, we 107 | # discovered that someone using Xmonad on Arch was having 108 | # a window of size 598 x 398, though a 600 x 400 window 109 | # was requested. (Guess Xmonad was preserving a pixel for 110 | # the boundary.) So we use the buffer height/width rather 111 | # than the requested one. 
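            # The colour buffer is RGBA with its origin at the bottom-left
            # (OpenGL convention), so keep only the RGB channels and flip the
            # array vertically before returning it.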
112 | arr = arr.reshape(buffer.height, buffer.width, 4) 113 | arr = arr[::-1,:,0:3] 114 | self.window.flip() 115 | self.onetime_geoms = [] 116 | return arr if return_rgb_array else self.isopen 117 | 118 | # Convenience 119 | def draw_circle(self, radius=10, res=30, filled=True, **attrs): 120 | geom = make_circle(radius=radius, res=res, filled=filled) 121 | _add_attrs(geom, attrs) 122 | self.add_onetime(geom) 123 | return geom 124 | 125 | def draw_polygon(self, v, filled=True, **attrs): 126 | geom = make_polygon(v=v, filled=filled) 127 | _add_attrs(geom, attrs) 128 | self.add_onetime(geom) 129 | return geom 130 | 131 | def draw_polyline(self, v, **attrs): 132 | geom = make_polyline(v=v) 133 | _add_attrs(geom, attrs) 134 | self.add_onetime(geom) 135 | return geom 136 | 137 | def draw_line(self, start, end, **attrs): 138 | geom = Line(start, end) 139 | _add_attrs(geom, attrs) 140 | self.add_onetime(geom) 141 | return geom 142 | 143 | def get_array(self): 144 | self.window.flip() 145 | image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data() 146 | self.window.flip() 147 | arr = np.fromstring(image_data.data, dtype=np.uint8, sep='') 148 | arr = arr.reshape(self.height, self.width, 4) 149 | return arr[::-1,:,0:3] 150 | 151 | def __del__(self): 152 | self.close() 153 | 154 | def _add_attrs(geom, attrs): 155 | if "color" in attrs: 156 | geom.set_color(*attrs["color"]) 157 | if "linewidth" in attrs: 158 | geom.set_linewidth(attrs["linewidth"]) 159 | 160 | class Geom(object): 161 | def __init__(self): 162 | self._color=Color((0, 0, 0, 1.0)) 163 | self.attrs = [self._color] 164 | def render(self): 165 | for attr in reversed(self.attrs): 166 | attr.enable() 167 | self.render1() 168 | for attr in self.attrs: 169 | attr.disable() 170 | def render1(self): 171 | raise NotImplementedError 172 | def add_attr(self, attr): 173 | self.attrs.append(attr) 174 | def set_color(self, r, g, b): 175 | self._color.vec4 = (r, g, b, 1) 176 | 177 | class Attr(object): 178 | def enable(self): 179 | raise NotImplementedError 180 | def disable(self): 181 | pass 182 | 183 | class Transform(Attr): 184 | def __init__(self, translation=(0.0, 0.0), rotation=0.0, scale=(1,1)): 185 | self.set_translation(*translation) 186 | self.set_rotation(rotation) 187 | self.set_scale(*scale) 188 | def enable(self): 189 | glPushMatrix() 190 | glTranslatef(self.translation[0], self.translation[1], 0) # translate to GL loc ppint 191 | glRotatef(RAD2DEG * self.rotation, 0, 0, 1.0) 192 | glScalef(self.scale[0], self.scale[1], 1) 193 | def disable(self): 194 | glPopMatrix() 195 | def set_translation(self, newx, newy): 196 | self.translation = (float(newx), float(newy)) 197 | def set_rotation(self, new): 198 | self.rotation = float(new) 199 | def set_scale(self, newx, newy): 200 | self.scale = (float(newx), float(newy)) 201 | 202 | class Color(Attr): 203 | def __init__(self, vec4): 204 | self.vec4 = vec4 205 | def enable(self): 206 | glColor4f(*self.vec4) 207 | 208 | class LineStyle(Attr): 209 | def __init__(self, style): 210 | self.style = style 211 | def enable(self): 212 | glEnable(GL_LINE_STIPPLE) 213 | glLineStipple(1, self.style) 214 | def disable(self): 215 | glDisable(GL_LINE_STIPPLE) 216 | 217 | class LineWidth(Attr): 218 | def __init__(self, stroke): 219 | self.stroke = stroke 220 | def enable(self): 221 | glLineWidth(self.stroke) 222 | 223 | class Point(Geom): 224 | def __init__(self): 225 | Geom.__init__(self) 226 | def render1(self): 227 | glBegin(GL_POINTS) # draw point 228 | glVertex3f(0.0, 0.0, 0.0) 229 | 
glEnd() 230 | 231 | class FilledPolygon(Geom): 232 | def __init__(self, v): 233 | Geom.__init__(self) 234 | self.v = v 235 | def render1(self): 236 | if len(self.v) == 4 : glBegin(GL_QUADS) 237 | elif len(self.v) > 4 : glBegin(GL_POLYGON) 238 | else: glBegin(GL_TRIANGLES) 239 | for p in self.v: 240 | glVertex3f(p[0], p[1],0) # draw each vertex 241 | glEnd() 242 | 243 | def make_circle(radius=10, res=30, filled=True): 244 | points = [] 245 | for i in range(res): 246 | ang = 2*math.pi*i / res 247 | points.append((math.cos(ang)*radius, math.sin(ang)*radius)) 248 | if filled: 249 | return FilledPolygon(points) 250 | else: 251 | return PolyLine(points, True) 252 | 253 | def make_polygon(v, filled=True): 254 | if filled: return FilledPolygon(v) 255 | else: return PolyLine(v, True) 256 | 257 | def make_polyline(v): 258 | return PolyLine(v, False) 259 | 260 | def make_capsule(length, width): 261 | l, r, t, b = 0, length, width/2, -width/2 262 | box = make_polygon([(l,b), (l,t), (r,t), (r,b)]) 263 | circ0 = make_circle(width/2) 264 | circ1 = make_circle(width/2) 265 | circ1.add_attr(Transform(translation=(length, 0))) 266 | geom = Compound([box, circ0, circ1]) 267 | return geom 268 | 269 | def make_goal_circ(length, radius, width=0.00001): 270 | l, r, t, b = 0, length, width/2, -width/2 271 | box = make_polygon([(l,b), (l,t), (r,t), (r,b)]) 272 | circ0 = make_circle(width/2) 273 | circ1 = make_circle(radius) 274 | circ1.add_attr(Transform(translation=(length, 0))) 275 | geom = Compound([box, circ0, circ1]) 276 | return geom 277 | 278 | class Compound(Geom): 279 | def __init__(self, gs): 280 | Geom.__init__(self) 281 | self.gs = gs 282 | for g in self.gs: 283 | g.attrs = [a for a in g.attrs if not isinstance(a, Color)] 284 | def render1(self): 285 | for g in self.gs: 286 | g.render() 287 | 288 | class PolyLine(Geom): 289 | def __init__(self, v, close): 290 | Geom.__init__(self) 291 | self.v = v 292 | self.close = close 293 | self.linewidth = LineWidth(1) 294 | self.add_attr(self.linewidth) 295 | def render1(self): 296 | glBegin(GL_LINE_LOOP if self.close else GL_LINE_STRIP) 297 | for p in self.v: 298 | glVertex3f(p[0], p[1],0) # draw each vertex 299 | glEnd() 300 | def set_linewidth(self, x): 301 | self.linewidth.stroke = x 302 | 303 | class Line(Geom): 304 | def __init__(self, start=(0.0, 0.0), end=(0.0, 0.0)): 305 | Geom.__init__(self) 306 | self.start = start 307 | self.end = end 308 | self.linewidth = LineWidth(1) 309 | self.add_attr(self.linewidth) 310 | 311 | def render1(self): 312 | glBegin(GL_LINES) 313 | glVertex2f(*self.start) 314 | glVertex2f(*self.end) 315 | glEnd() 316 | 317 | class Image(Geom): 318 | def __init__(self, fname, width, height): 319 | Geom.__init__(self) 320 | self.width = width 321 | self.height = height 322 | img = pyglet.image.load(fname) 323 | self.img = img 324 | self.flip = False 325 | def render1(self): 326 | self.img.blit(-self.width/2, -self.height/2, width=self.width, height=self.height) 327 | 328 | # ================================================================ 329 | 330 | class SimpleImageViewer(object): 331 | def __init__(self, display=None, maxwidth=500): 332 | self.window = None 333 | self.isopen = False 334 | self.display = display 335 | self.maxwidth = maxwidth 336 | def imshow(self, arr): 337 | if self.window is None: 338 | height, width, _channels = arr.shape 339 | if width > self.maxwidth: 340 | scale = self.maxwidth / width 341 | width = int(scale * width) 342 | height = int(scale * height) 343 | self.window = pyglet.window.Window(width=width, 
                                               height=height,
                                               display=self.display, vsync=False, resizable=True)
            self.width = width
            self.height = height
            self.isopen = True

            @self.window.event
            def on_resize(width, height):
                self.width = width
                self.height = height

            @self.window.event
            def on_close():
                self.isopen = False

        assert len(arr.shape) == 3, "You passed in an image with the wrong number of dimensions"
        image = pyglet.image.ImageData(arr.shape[1], arr.shape[0],
                                       'RGB', arr.tobytes(), pitch=arr.shape[1]*-3)
        gl.glTexParameteri(gl.GL_TEXTURE_2D,
                           gl.GL_TEXTURE_MAG_FILTER, gl.GL_NEAREST)
        texture = image.get_texture()
        texture.width = self.width
        texture.height = self.height
        self.window.clear()
        self.window.switch_to()
        self.window.dispatch_events()
        texture.blit(0, 0)  # draw
        self.window.flip()
    def close(self):
        if self.isopen and sys.meta_path:
            # ^^^ check sys.meta_path to avoid 'ImportError: sys.meta_path is None, Python is likely shutting down'
            self.window.close()
            self.isopen = False

    def __del__(self):
        self.close()
--------------------------------------------------------------------------------
/asset/continuous_mountain_car.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
@author: Olivier Sigaud

A merge between two sources:

* Adaptation of the MountainCar Environment from the "FAReinforcement" library
of Jose Antonio Martin H. (version 1.0), adapted by 'Tom Schaul, tom@idsia.ch'
and then modified by Arnaud de Broissia

* the OpenAI/gym MountainCar environment
itself from
http://incompleteideas.net/sutton/MountainCar/MountainCar1.cp
permalink: https://perma.cc/6Z2N-PFWC
"""

import math

import numpy as np

import gym
from gym import spaces
from gym.utils import seeding

class Continuous_MountainCarEnv(gym.Env):
    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 30
    }

    def __init__(self):
        self.min_action = -1.0
        self.max_action = 1.0
        self.min_position = -1.2
        self.max_position = 0.6
        self.max_speed = 0.07
        self.goal_position = 0.45  # was 0.5 in gym, 0.45 in Arnaud de Broissia's version
        self.power = 0.0015

        self.low_state = np.array([self.min_position, -self.max_speed])
        self.high_state = np.array([self.max_position, self.max_speed])

        self.viewer = None

        self.action_space = spaces.Box(low=self.min_action, high=self.max_action,
                                       shape=(1,), dtype=np.float32)
        self.observation_space = spaces.Box(low=self.low_state, high=self.high_state,
                                            dtype=np.float32)

        self.seed()
        self.reset()

    def seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def step(self, action):

        position = self.state[0]
        velocity = self.state[1]
        force = min(max(action[0], -1.0), 1.0)

        velocity += force * self.power - 0.0025 * math.cos(3 * position)
        if (velocity > self.max_speed): velocity = self.max_speed
        if (velocity < -self.max_speed): velocity = -self.max_speed
        position += velocity
        if (position > self.max_position): position = self.max_position
        if (position < self.min_position): position = self.min_position
        if (position == self.min_position and
                velocity < 0): velocity = 0

        done = bool(position >= self.goal_position)

        reward = 0
        if done:
            reward = 100.0
        reward -= math.pow(action[0], 2) * 0.1

        self.state = np.array([position, velocity])
        return self.state, reward, done, {}

    def reset(self):
        self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0])
        return np.array(self.state)

    # def get_state(self):
    #     return self.state

    def _height(self, xs):
        return np.sin(3 * xs) * .45 + .55

    def render(self, mode='human'):
        screen_width = 600
        screen_height = 400

        world_width = self.max_position - self.min_position
        scale = screen_width / world_width
        carwidth = 40
        carheight = 20

        if self.viewer is None:
            from asset import rendering
            self.viewer = rendering.Viewer(screen_width, screen_height)
            xs = np.linspace(self.min_position, self.max_position, 100)
            ys = self._height(xs)
            xys = list(zip((xs-self.min_position)*scale, ys*scale))

            self.track = rendering.make_polyline(xys)
            self.track.set_linewidth(4)
            self.viewer.add_geom(self.track)

            clearance = 10

            l, r, t, b = -carwidth/2, carwidth/2, carheight, 0
            car = rendering.FilledPolygon([(l,b), (l,t), (r,t), (r,b)])
            car.add_attr(rendering.Transform(translation=(0, clearance)))
            self.cartrans = rendering.Transform()
            car.add_attr(self.cartrans)
            self.viewer.add_geom(car)
            frontwheel = rendering.make_circle(carheight/2.5)
            frontwheel.set_color(.5, .5, .5)
            frontwheel.add_attr(rendering.Transform(translation=(carwidth/4, clearance)))
            frontwheel.add_attr(self.cartrans)
            self.viewer.add_geom(frontwheel)
            backwheel = rendering.make_circle(carheight/2.5)
            backwheel.add_attr(rendering.Transform(translation=(-carwidth/4, clearance)))
            backwheel.add_attr(self.cartrans)
            backwheel.set_color(.5, .5, .5)
            self.viewer.add_geom(backwheel)
            flagx = (self.goal_position-self.min_position)*scale
            flagy1 = self._height(self.goal_position)*scale
            flagy2 = flagy1 + 50
            flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2))
            self.viewer.add_geom(flagpole)
            flag = rendering.FilledPolygon([(flagx, flagy2), (flagx, flagy2-10), (flagx+25, flagy2-5)])
            flag.set_color(.8, .8, 0)
            self.viewer.add_geom(flag)

        pos = self.state[0]
        self.cartrans.set_translation((pos-self.min_position)*scale, self._height(pos)*scale)
        self.cartrans.set_rotation(math.cos(3 * pos))

        return self.viewer.render(return_rgb_array = mode=='rgb_array')


    # Variant of render() that additionally draws a subgoal (red car) and the end goal (blue car).
    def render_goal(self, goal, end_goal, mode='human'):
        screen_width = 600
        screen_height = 400

        world_width = self.max_position - self.min_position
        scale = screen_width / world_width
        carwidth = 40
        carheight = 20

        if self.viewer is None:
            from asset import rendering
            self.viewer = rendering.Viewer(screen_width, screen_height)
            xs = np.linspace(self.min_position, self.max_position, 100)
            ys = self._height(xs)
            xys = list(zip((xs-self.min_position)*scale, ys*scale))

            self.track = rendering.make_polyline(xys)
            self.track.set_linewidth(4)
            self.viewer.add_geom(self.track)

            clearance = 10

            l, r, t, b = -carwidth/2, carwidth/2, carheight, 0
            car = rendering.FilledPolygon([(l,b), (l,t), (r,t), (r,b)])
            car.add_attr(rendering.Transform(translation=(0, clearance)))
            self.cartrans = rendering.Transform()
            car.add_attr(self.cartrans)
            self.viewer.add_geom(car)
            frontwheel = rendering.make_circle(carheight/2.5)
            frontwheel.set_color(.5, .5, .5)
            frontwheel.add_attr(rendering.Transform(translation=(carwidth/4, clearance)))
            frontwheel.add_attr(self.cartrans)
            self.viewer.add_geom(frontwheel)
            backwheel = rendering.make_circle(carheight/2.5)
            backwheel.add_attr(rendering.Transform(translation=(-carwidth/4, clearance)))
            backwheel.add_attr(self.cartrans)
            backwheel.set_color(.5, .5, .5)
            self.viewer.add_geom(backwheel)

            ################ Goal ################
            car1 = rendering.FilledPolygon([(l,b), (l,t), (r,t), (r,b)])
            car1.set_color(1, 0.2, 0.2)
            car1.add_attr(rendering.Transform(translation=(0, clearance)))
            self.cartrans1 = rendering.Transform()
            car1.add_attr(self.cartrans1)
            self.viewer.add_geom(car1)
            ######################################

            ############## End Goal ##############
            car2 = rendering.FilledPolygon([(l,b), (l,t), (r,t), (r,b)])
            car2.set_color(0.2, 0.2, 1)
            car2.add_attr(rendering.Transform(translation=(0, clearance)))
            self.cartrans2 = rendering.Transform()
            car2.add_attr(self.cartrans2)
            self.viewer.add_geom(car2)
            ######################################

            flagx = (self.goal_position-self.min_position)*scale
            flagy1 = self._height(self.goal_position)*scale
            flagy2 = flagy1 + 50
            flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2))
            self.viewer.add_geom(flagpole)
            flag = rendering.FilledPolygon([(flagx, flagy2), (flagx, flagy2-10), (flagx+25, flagy2-5)])
            flag.set_color(.8, .8, 0)
            self.viewer.add_geom(flag)

        pos = self.state[0]
        self.cartrans.set_translation((pos-self.min_position)*scale, self._height(pos)*scale)
        self.cartrans.set_rotation(math.cos(3 * pos))

        pos1 = goal[0]
        self.cartrans1.set_translation((pos1-self.min_position)*scale, self._height(pos1)*scale)
        self.cartrans1.set_rotation(math.cos(3 * pos1))

        pos2 = end_goal[0]
        self.cartrans2.set_translation((pos2-self.min_position)*scale, self._height(pos2)*scale)
        self.cartrans2.set_rotation(math.cos(3 * pos2))

        return self.viewer.render(return_rgb_array = mode=='rgb_array')


    # Variant of render() that additionally draws two subgoals (red and green cars) and the end goal (blue car).
    def render_goal_2(self, goal1, goal2, end_goal, mode='human'):
        screen_width = 600
        screen_height = 400

        world_width = self.max_position - self.min_position
        scale = screen_width / world_width
        carwidth = 40
        carheight = 20


        if self.viewer is None:
            from asset import rendering
            self.viewer = rendering.Viewer(screen_width, screen_height)
            xs = np.linspace(self.min_position, self.max_position, 100)
            ys = self._height(xs)
            xys = list(zip((xs-self.min_position)*scale, ys*scale))

            self.track = rendering.make_polyline(xys)
            self.track.set_linewidth(4)
            self.viewer.add_geom(self.track)

            clearance = 10

            l, r, t, b = -carwidth/2, carwidth/2, carheight, 0
            car = rendering.FilledPolygon([(l,b), (l,t), (r,t), (r,b)])
            car.add_attr(rendering.Transform(translation=(0, clearance)))
            self.cartrans = rendering.Transform()
            car.add_attr(self.cartrans)
            self.viewer.add_geom(car)
            frontwheel = rendering.make_circle(carheight/2.5)
            frontwheel.set_color(.5, .5, .5)
            frontwheel.add_attr(rendering.Transform(translation=(carwidth/4, clearance)))
            frontwheel.add_attr(self.cartrans)
            self.viewer.add_geom(frontwheel)
            backwheel = rendering.make_circle(carheight/2.5)
            backwheel.add_attr(rendering.Transform(translation=(-carwidth/4, clearance)))
            backwheel.add_attr(self.cartrans)
            backwheel.set_color(.5, .5, .5)
            self.viewer.add_geom(backwheel)

            ############### Goal 1 ###############
            car1 = rendering.FilledPolygon([(l,b), (l,t), (r,t), (r,b)])
            car1.set_color(1, 0.2, 0.2)
            car1.add_attr(rendering.Transform(translation=(0, clearance)))
            self.cartrans1 = rendering.Transform()
            car1.add_attr(self.cartrans1)
            self.viewer.add_geom(car1)
            ######################################

            ############### Goal 2 ###############
            car2 = rendering.FilledPolygon([(l,b), (l,t), (r,t), (r,b)])
            car2.set_color(0.2, 1, 0.2)
            car2.add_attr(rendering.Transform(translation=(0, clearance)))
            self.cartrans2 = rendering.Transform()
            car2.add_attr(self.cartrans2)
            self.viewer.add_geom(car2)
            ######################################

            ############## End Goal ##############
            car3 = rendering.FilledPolygon([(l,b), (l,t), (r,t), (r,b)])
            car3.set_color(0.2, 0.2, 1)
            car3.add_attr(rendering.Transform(translation=(0, clearance)))
            self.cartrans3 = rendering.Transform()
            car3.add_attr(self.cartrans3)
            self.viewer.add_geom(car3)
            ######################################

            flagx = (self.goal_position-self.min_position)*scale
            flagy1 = self._height(self.goal_position)*scale
            flagy2 = flagy1 + 50
            flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2))
            self.viewer.add_geom(flagpole)
            flag = rendering.FilledPolygon([(flagx, flagy2), (flagx, flagy2-10), (flagx+25, flagy2-5)])
            flag.set_color(.8, .8, 0)
            self.viewer.add_geom(flag)

        pos = self.state[0]
        self.cartrans.set_translation((pos-self.min_position)*scale, self._height(pos)*scale)
        self.cartrans.set_rotation(math.cos(3 * pos))

        pos1 = goal1[0]
        self.cartrans1.set_translation((pos1-self.min_position)*scale, self._height(pos1)*scale)
        self.cartrans1.set_rotation(math.cos(3 * pos1))

        pos2 = goal2[0]
        self.cartrans2.set_translation((pos2-self.min_position)*scale, self._height(pos2)*scale)
        self.cartrans2.set_rotation(math.cos(3 * pos2))

        pos3 = end_goal[0]
        self.cartrans3.set_translation((pos3-self.min_position)*scale, self._height(pos3)*scale)
        self.cartrans3.set_rotation(math.cos(3 * pos3))

        return self.viewer.render(return_rgb_array = mode=='rgb_array')


    def close(self):
        if self.viewer:
            self.viewer.close()
            self.viewer = None
--------------------------------------------------------------------------------
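A minimal usage sketch (an editor's addition, not a file in the repository): it drives the custom environment above with random actions and exercises the render_goal hook, assuming pyglet can open a window on the current machine. The goal arrays here are hypothetical placeholders; in the actual project the subgoals come from the higher levels of the trained HAC agent.

# Hypothetical usage sketch (not part of the repository): random actions plus a placeholder subgoal.
import numpy as np
from asset.continuous_mountain_car import Continuous_MountainCarEnv

env = Continuous_MountainCarEnv()
state = env.reset()                                     # state = [position, velocity]
end_goal = np.array([0.46, 0.04])                       # hypothetical end goal near the flag at x = 0.45
for _ in range(200):
    action = env.action_space.sample()                  # random force in [-1, 1]
    subgoal = np.array([min(state[0] + 0.1, 0.6), 0.0]) # placeholder subgoal, not from a trained agent
    env.render_goal(subgoal, end_goal)                  # draws the car, the subgoal (red) and the end goal (blue)
    state, reward, done, _ = env.step(action)
    if done:
        break
env.close()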