├── .gitignore
├── README.md
├── envs
│   ├── README.md
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── ant.cpython-36.pyc
│   │   ├── ant.cpython-37.pyc
│   │   ├── ant_maze_env.cpython-36.pyc
│   │   ├── ant_maze_env.cpython-37.pyc
│   │   ├── create_maze_env.cpython-36.pyc
│   │   ├── create_maze_env.cpython-37.pyc
│   │   ├── maze_env.cpython-36.pyc
│   │   ├── maze_env.cpython-37.pyc
│   │   ├── maze_env_utils.cpython-36.pyc
│   │   └── maze_env_utils.cpython-37.pyc
│   ├── ant.py
│   ├── ant_maze_env.py
│   ├── assets
│   │   └── ant.xml
│   ├── create_maze_env.py
│   ├── maze_env.py
│   └── maze_env_utils.py
├── hiro
│   ├── __pycache__
│   │   ├── models.cpython-36.pyc
│   │   ├── models.cpython-37.pyc
│   │   ├── nn_rl.cpython-36.pyc
│   │   └── nn_rl.cpython-37.pyc
│   ├── hiro_utils.py
│   ├── models.py
│   └── utils.py
├── main.py
├── media
│   ├── Success_Rate.svg
│   ├── demo.gif
│   ├── loss_actor_loss_high.svg
│   ├── loss_actor_loss_low.svg
│   ├── loss_actor_loss_td3.svg
│   ├── loss_critic_loss_high.svg
│   ├── loss_critic_loss_low.svg
│   ├── loss_critic_loss_td3.svg
│   ├── reward_Intrinsic_Reward.svg
│   └── reward_Reward.svg
├── requirements.txt
└── test
    ├── test_agent.py
    ├── test_env.py
    └── test_models.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.h5
2 | *.pyc
3 | .vscode
4 | __pycache__/
5 | test/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Overview
2 | An implementation of [Data-Efficient Hierarchical Reinforcement Learning](https://arxiv.org/pdf/1805.08296.pdf) (HIRO) in PyTorch.
3 | 
4 |
5 | # Installation
6 | 1. Follow the installation instructions for [mujoco-py](https://github.com/openai/mujoco-py):
7 | ```
8 | 1. Obtain a 30-day free trial on the MuJoCo website or free license if you are a student. The license key will arrive in an email with your username and password.
9 | 2. Download the MuJoCo version 2.0 binaries for Linux or OSX.
10 | 3. Unzip the downloaded mujoco200 directory into ~/.mujoco/mujoco200, and place your license key (the mjkey.txt file from your email) at ~/.mujoco/mjkey.txt.
11 | ```
12 | 2. Install Dependencies
13 | ```
14 | pip install -r requirements.txt
15 | ```
16 |
17 | # Run
18 | For `HIRO`,
19 | ```
20 | python main.py --train
21 | ```
22 |
23 | For `TD3`,
24 | ```
25 | python main.py --train --td3
26 | ```
27 | # Evaluate Trained Model
28 | Passing the `--eval` argument loads the most recently saved model parameters and runs the policy. The goal is to reach the position (0, 16), the top-left corner of the maze.
29 |
30 | For `HIRO`,
31 | ```
32 | python main.py --eval
33 | ```
34 |
35 | For `TD3`,
36 | ```
37 | python main.py --eval --td3
38 | ```
39 |
40 |
41 | # Training Results
42 | Blue is HIRO and orange is TD3.
43 |
44 | ## Success Rate
45 |
46 |
47 | ## Reward
48 |
49 |
50 | ## Intrinsic Reward
51 |
52 |
53 | ## Losses
54 | Higher Controller Actor
55 |
56 |
57 | Higher Controller Critic
58 |
59 |
60 | Lower Controller Actor
61 |
62 |
63 | Lower Controller Critic
64 |
65 |
66 | TD3 Controller Actor
67 |
68 |
69 | TD3 Controller Critic
70 |
71 |
72 |
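73 | # Environment Sanity Check
74 | `envs/__init__.py` also bundles a random-policy runner, `run_environment`. The snippet below is a minimal sketch of how to call it (the episode settings are illustrative); it renders every step, so it assumes a working MuJoCo display setup.
75 | ```
76 | from envs import run_environment
77 | 
78 | # Roll out a uniform random policy and print per-episode reward and success.
79 | run_environment('AntMaze', episode_length=500, num_episodes=5)
80 | ```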
--------------------------------------------------------------------------------
/envs/README.md:
--------------------------------------------------------------------------------
1 | # Ant Environments for RL
2 |
3 | Taken almost entirely from the [Tensorflow Models](https://github.com/tensorflow/models/tree/master/research/efficient-hrl/environments) repository, which is itself inspired by the work done by [RlLab](https://github.com/rll/rllab/blob/master/rllab/envs/mujoco/). Minor edit: visual observations, which we needed for some of our own work.
4 |
5 |
--------------------------------------------------------------------------------
/envs/__init__.py:
--------------------------------------------------------------------------------
1 | """Random policy on an environment."""
2 |
3 | import numpy as np
4 | import argparse
5 |
6 | from envs import create_maze_env
7 |
8 |
9 | def get_goal_sample_fn(env_name, evaluate):
10 | if env_name == 'AntMaze':
11 | # NOTE: When evaluating (i.e., for the metrics reported in the paper),
12 | # we use the fixed goal below. The uniformly sampled goal is only used
13 | # during training.
14 | if evaluate:
15 | return lambda: np.array([0., 16.])
16 | else:
17 | return lambda: np.random.uniform((-4, -4), (20, 20))
18 | elif env_name == 'AntPush':
19 | return lambda: np.array([0., 19.])
20 | elif env_name == 'AntFall':
21 | return lambda: np.array([0., 27., 4.5])
22 | else:
23 | assert False, 'Unknown env'
24 |
25 |
26 | def get_reward_fn(env_name):
27 | if env_name == 'AntMaze':
28 | return lambda obs, goal: -np.sum(np.square(obs[:2] - goal)) ** 0.5
29 | elif env_name == 'AntPush':
30 | return lambda obs, goal: -np.sum(np.square(obs[:2] - goal)) ** 0.5
31 | elif env_name == 'AntFall':
32 | return lambda obs, goal: -np.sum(np.square(obs[:3] - goal)) ** 0.5
33 | else:
34 | assert False, 'Unknown env'
35 |
36 |
37 | def success_fn(last_reward):
38 | return last_reward > -5.0
39 |
40 |
41 | class EnvWithGoal(object):
42 | def __init__(self, base_env, env_name):
43 | self.base_env = base_env
44 | self.env_name = env_name
45 | self.evaluate = False
46 | self.reward_fn = get_reward_fn(env_name)
47 | self.goal = None
48 | self.distance_threshold = 5
49 | self.count = 0
50 | self.state_dim = self.base_env.observation_space.shape[0] + 1
51 | self.action_dim = self.base_env.action_space.shape[0]
52 |
53 | def seed(self, seed):
54 | self.base_env.seed(seed)
55 |
56 | def reset(self):
57 | # self.viewer_setup()
58 | self.goal_sample_fn = get_goal_sample_fn(self.env_name, self.evaluate)
59 | obs = self.base_env.reset()
60 | self.count = 0
61 | self.goal = self.goal_sample_fn()
62 | return {
63 | # add timestep
64 | 'observation': np.r_[obs.copy(), self.count],
65 | 'achieved_goal': obs[:2],
66 | 'desired_goal': self.goal,
67 | }
68 |
69 | def step(self, a):
70 | obs, _, done, info = self.base_env.step(a)
71 | reward = self.reward_fn(obs, self.goal)
72 | self.count += 1
73 | next_obs = {
74 | # add timestep
75 | 'observation': np.r_[obs.copy(), self.count],
76 | 'achieved_goal': obs[:2],
77 | 'desired_goal': self.goal,
78 | }
79 | return next_obs, reward, done or self.count >= 500, info
80 |
81 | def render(self):
82 | self.base_env.render()
83 |
84 | def get_image(self):
85 | self.render()
86 | data = self.base_env.viewer.get_image()
87 |
88 | img_data = data[0]
89 | width = data[1]
90 | height = data[2]
91 |
92 | tmp = np.frombuffer(img_data, dtype=np.uint8)
93 | image_obs = np.reshape(tmp, [height, width, 3])
94 | image_obs = np.flipud(image_obs)
95 |
96 | return image_obs
97 |
98 | @property
99 | def action_space(self):
100 | return self.base_env.action_space
101 |
102 | @property
103 | def observation_space(self):
104 | return self.base_env.observation_space
105 |
106 | def run_environment(env_name, episode_length, num_episodes):
107 | env = EnvWithGoal(
108 | create_maze_env.create_maze_env(env_name),
109 | env_name)
110 |
111 | def action_fn(obs):
112 | action_space = env.action_space
113 | action_space_mean = (action_space.low + action_space.high) / 2.0
114 | action_space_magn = (action_space.high - action_space.low) / 2.0
115 | random_action = (action_space_mean +
116 | action_space_magn *
117 | np.random.uniform(low=-1.0, high=1.0,
118 | size=action_space.shape))
119 |
120 | return random_action
121 |
122 | rewards = []
123 | successes = []
124 | for ep in range(num_episodes):
125 | rewards.append(0.0)
126 | successes.append(False)
127 | obs = env.reset()
128 | for _ in range(episode_length):
129 | env.render()
130 | print(env.get_image().shape)
131 | obs, reward, done, _ = env.step(action_fn(obs))
132 | rewards[-1] += reward
133 | successes[-1] = success_fn(reward)
134 | if done:
135 | break
136 |
137 | print('Episode {} reward: {}, Success: {}'.format(ep + 1, rewards[-1], successes[-1]))
138 |
139 | print('Average Reward over {} episodes: {}'.format(num_episodes, np.mean(rewards)))
140 | print('Average Success over {} episodes: {}'.format(num_episodes, np.mean(successes)))
141 |
142 |
143 | if __name__ == '__main__':
144 | parser = argparse.ArgumentParser()
145 | parser.add_argument("--env_name", default="AntMaze", type=str)
146 | parser.add_argument("--episode_length", default=500, type=int)
147 | parser.add_argument("--num_episodes", default=100, type=int)
148 |
149 | args = parser.parse_args()
150 | run_environment(args.env_name, args.episode_length, args.num_episodes)
--------------------------------------------------------------------------------
/envs/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watakandai/hiro_pytorch/5a92a286a334eb8480a1a4c781abde22c3b67cdf/envs/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/envs/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watakandai/hiro_pytorch/5a92a286a334eb8480a1a4c781abde22c3b67cdf/envs/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/envs/__pycache__/ant.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watakandai/hiro_pytorch/5a92a286a334eb8480a1a4c781abde22c3b67cdf/envs/__pycache__/ant.cpython-36.pyc
--------------------------------------------------------------------------------
/envs/__pycache__/ant.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watakandai/hiro_pytorch/5a92a286a334eb8480a1a4c781abde22c3b67cdf/envs/__pycache__/ant.cpython-37.pyc
--------------------------------------------------------------------------------
/envs/__pycache__/ant_maze_env.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watakandai/hiro_pytorch/5a92a286a334eb8480a1a4c781abde22c3b67cdf/envs/__pycache__/ant_maze_env.cpython-36.pyc
--------------------------------------------------------------------------------
/envs/__pycache__/ant_maze_env.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watakandai/hiro_pytorch/5a92a286a334eb8480a1a4c781abde22c3b67cdf/envs/__pycache__/ant_maze_env.cpython-37.pyc
--------------------------------------------------------------------------------
/envs/__pycache__/create_maze_env.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watakandai/hiro_pytorch/5a92a286a334eb8480a1a4c781abde22c3b67cdf/envs/__pycache__/create_maze_env.cpython-36.pyc
--------------------------------------------------------------------------------
/envs/__pycache__/create_maze_env.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watakandai/hiro_pytorch/5a92a286a334eb8480a1a4c781abde22c3b67cdf/envs/__pycache__/create_maze_env.cpython-37.pyc
--------------------------------------------------------------------------------
/envs/__pycache__/maze_env.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watakandai/hiro_pytorch/5a92a286a334eb8480a1a4c781abde22c3b67cdf/envs/__pycache__/maze_env.cpython-36.pyc
--------------------------------------------------------------------------------
/envs/__pycache__/maze_env.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watakandai/hiro_pytorch/5a92a286a334eb8480a1a4c781abde22c3b67cdf/envs/__pycache__/maze_env.cpython-37.pyc
--------------------------------------------------------------------------------
/envs/__pycache__/maze_env_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watakandai/hiro_pytorch/5a92a286a334eb8480a1a4c781abde22c3b67cdf/envs/__pycache__/maze_env_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/envs/__pycache__/maze_env_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watakandai/hiro_pytorch/5a92a286a334eb8480a1a4c781abde22c3b67cdf/envs/__pycache__/maze_env_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/envs/ant.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The TensorFlow Authors All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Wrapper for creating the ant environment in gym_mujoco."""
17 |
18 | import math
19 | import numpy as np
20 | from gym import utils
21 | from gym.envs.mujoco import mujoco_env
22 |
23 |
24 | class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
25 | FILE = "ant.xml"
26 |
27 | def __init__(self, file_path=None, expose_all_qpos=True,
28 | expose_body_coms=None, expose_body_comvels=None):
29 | self._expose_all_qpos = expose_all_qpos
30 | self._expose_body_coms = expose_body_coms
31 | self._expose_body_comvels = expose_body_comvels
32 | self._body_com_indices = {}
33 | self._body_comvel_indices = {}
34 |
35 | mujoco_env.MujocoEnv.__init__(self, file_path, 5)
36 | utils.EzPickle.__init__(self)
37 |
38 | @property
39 | def physics(self):
40 | return self.model
41 |
42 | def _step(self, a):
43 | return self.step(a)
44 |
45 | def step(self, a):
46 | xposbefore = self.get_body_com("torso")[0]
47 | self.do_simulation(a, self.frame_skip)
48 | xposafter = self.get_body_com("torso")[0]
49 | forward_reward = (xposafter - xposbefore) / self.dt
50 | ctrl_cost = .5 * np.square(a).sum()
51 | survive_reward = 1.0
52 | reward = forward_reward - ctrl_cost + survive_reward
53 | state = self.state_vector()
54 | done = False
55 | ob = self._get_obs()
56 | return ob, reward, done, dict(
57 | reward_forward=forward_reward,
58 | reward_ctrl=-ctrl_cost,
59 | reward_survive=survive_reward)
60 |
61 | def _get_obs(self):
62 | # No cfrc observation
63 | if self._expose_all_qpos:
64 | obs = np.concatenate([
65 | self.data.qpos.flat[:15], # Ensures only ant obs.
66 | self.data.qvel.flat[:14],
67 | ])
68 | else:
69 | obs = np.concatenate([
70 | self.data.qpos.flat[2:15],
71 | self.data.qvel.flat[:14],
72 | ])
73 |
74 | if self._expose_body_coms is not None:
75 | for name in self._expose_body_coms:
76 | com = self.get_body_com(name)
77 | if name not in self._body_com_indices:
78 | indices = range(len(obs), len(obs) + len(com))
79 | self._body_com_indices[name] = indices
80 | obs = np.concatenate([obs, com])
81 |
82 | if self._expose_body_comvels is not None:
83 | for name in self._expose_body_comvels:
84 | comvel = self.get_body_comvel(name)
85 | if name not in self._body_comvel_indices:
86 | indices = range(len(obs), len(obs) + len(comvel))
87 | self._body_comvel_indices[name] = indices
88 | obs = np.concatenate([obs, comvel])
89 | return obs
90 |
91 | def reset_model(self):
92 | qpos = self.init_qpos + self.np_random.uniform(
93 | size=self.model.nq, low=-.1, high=.1)
94 | qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1
95 |
96 | # Set everything other than ant to original position and 0 velocity.
97 | qpos[15:] = self.init_qpos[15:]
98 | qvel[14:] = 0.
99 | self.set_state(qpos, qvel)
100 | return self._get_obs()
101 |
102 | def viewer_setup(self):
103 | self.viewer.cam.trackbodyid = -1
104 | self.viewer.cam.distance = 50
105 | self.viewer.cam.elevation = -90
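106 | 
107 | # Note: with expose_all_qpos=True the observation is qpos[:15] followed by
108 | # qvel[:14] (29 dimensions); the maze wrapper later appends a scaled timestep.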
--------------------------------------------------------------------------------
/envs/ant_maze_env.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The TensorFlow Authors All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | from .maze_env import MazeEnv
17 | from .ant import AntEnv
18 |
19 |
20 | class AntMazeEnv(MazeEnv):
21 | MODEL_CLASS = AntEnv
22 |
--------------------------------------------------------------------------------
/envs/assets/ant.xml:
--------------------------------------------------------------------------------
1 | <!-- envs/assets/ant.xml: MuJoCo model definition for the ant robot; the XML markup was not preserved in this dump. -->
--------------------------------------------------------------------------------
/envs/create_maze_env.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The TensorFlow Authors All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | from .ant_maze_env import AntMazeEnv
17 |
18 |
19 | def create_maze_env(env_name=None):
20 | maze_id = None
21 | if env_name.startswith('AntMaze'):
22 | maze_id = 'Maze'
23 | elif env_name.startswith('AntPush'):
24 | maze_id = 'Push'
25 | elif env_name.startswith('AntFall'):
26 | maze_id = 'Fall'
27 | else:
28 | raise ValueError('Unknown maze environment %s' % env_name)
29 |
30 | return AntMazeEnv(maze_id=maze_id)
31 |
--------------------------------------------------------------------------------
/envs/maze_env.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The TensorFlow Authors All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Adapted from rllab maze_env.py."""
17 |
18 | import os
19 | import tempfile
20 | import xml.etree.ElementTree as ET
21 | import math
22 | import numpy as np
23 | import gym
24 |
25 | from envs import maze_env_utils
26 |
27 | # Directory that contains mujoco xml files.
28 | MODEL_DIR = 'assets'
29 |
30 |
31 | class MazeEnv(gym.Env):
32 | MODEL_CLASS = None
33 |
34 | MAZE_HEIGHT = None
35 | MAZE_SIZE_SCALING = None
36 |
37 | def __init__(
38 | self,
39 | maze_id=None,
40 | maze_height=0.5,
41 | maze_size_scaling=8,
42 | *args,
43 | **kwargs):
44 | self._maze_id = maze_id
45 | self.t = 0
46 |
47 | model_cls = self.__class__.MODEL_CLASS
48 | if model_cls is None:
49 | raise "MODEL_CLASS unspecified!"
50 | xml_path = os.path.join("envs", MODEL_DIR, model_cls.FILE)
51 | tree = ET.parse(xml_path)
52 | worldbody = tree.find(".//worldbody")
53 |
54 | self.MAZE_HEIGHT = height = maze_height
55 | self.MAZE_SIZE_SCALING = size_scaling = maze_size_scaling
56 | self.MAZE_STRUCTURE = structure = maze_env_utils.construct_maze(maze_id=self._maze_id)
57 | self.elevated = any(-1 in row for row in structure) # Elevate the maze to allow for falling.
58 | self.blocks = any(
59 | any(maze_env_utils.can_move(r) for r in row)
60 | for row in structure) # Are there any movable blocks?
61 |
62 | torso_x, torso_y = self._find_robot()
63 | self._init_torso_x = torso_x
64 | self._init_torso_y = torso_y
65 |
66 | height_offset = 0.
67 | if self.elevated:
68 | # Increase initial z-pos of ant.
69 | height_offset = height * size_scaling
70 | torso = tree.find(".//body[@name='torso']")
71 | torso.set('pos', '0 0 %.2f' % (0.75 + height_offset))
72 | if self.blocks:
73 | # If there are movable blocks, change simulation settings to perform
74 | # better contact detection.
75 | default = tree.find(".//default")
76 | default.find('.//geom').set('solimp', '.995 .995 .01')
77 |
78 | for i in range(len(structure)):
79 | for j in range(len(structure[0])):
80 | if self.elevated and structure[i][j] not in [-1]:
81 | # Create elevated platform.
82 | ET.SubElement(
83 | worldbody, "geom",
84 | name="elevated_%d_%d" % (i, j),
85 | pos="%f %f %f" % (j * size_scaling - torso_x,
86 | i * size_scaling - torso_y,
87 | height / 2 * size_scaling),
88 | size="%f %f %f" % (0.5 * size_scaling,
89 | 0.5 * size_scaling,
90 | height / 2 * size_scaling),
91 | type="box",
92 | material="",
93 | contype="1",
94 | conaffinity="1",
95 | rgba="0.9 0.9 0.9 1",
96 | )
97 | if structure[i][j] == 1: # Unmovable block.
98 | # Offset all coordinates so that robot starts at the origin.
99 | ET.SubElement(
100 | worldbody, "geom",
101 | name="block_%d_%d" % (i, j),
102 | pos="%f %f %f" % (j * size_scaling - torso_x,
103 | i * size_scaling - torso_y,
104 | height_offset +
105 | height / 2 * size_scaling),
106 | size="%f %f %f" % (0.5 * size_scaling,
107 | 0.5 * size_scaling,
108 | height / 2 * size_scaling),
109 | type="box",
110 | material="",
111 | contype="1",
112 | conaffinity="1",
113 | rgba="0.4 0.4 0.4 1",
114 | )
115 | elif maze_env_utils.can_move(structure[i][j]): # Movable block.
116 | # The "falling" blocks are shrunk slightly and increased in mass to
117 | # ensure that it can fall easily through a gap in the platform blocks.
118 | falling = maze_env_utils.can_move_z(structure[i][j])
119 | shrink = 0.99 if falling else 1.0
120 | moveable_body = ET.SubElement(
121 | worldbody, "body",
122 | name="moveable_%d_%d" % (i, j),
123 | pos="%f %f %f" % (j * size_scaling - torso_x,
124 | i * size_scaling - torso_y,
125 | height_offset +
126 | height / 2 * size_scaling),
127 | )
128 | ET.SubElement(
129 | moveable_body, "geom",
130 | name="block_%d_%d" % (i, j),
131 | pos="0 0 0",
132 | size="%f %f %f" % (0.5 * size_scaling * shrink,
133 | 0.5 * size_scaling * shrink,
134 | height / 2 * size_scaling),
135 | type="box",
136 | material="",
137 | mass="0.001" if falling else "0.0002",
138 | contype="1",
139 | conaffinity="1",
140 | rgba="0.9 0.1 0.1 1"
141 | )
142 | if maze_env_utils.can_move_x(structure[i][j]):
143 | ET.SubElement(
144 | moveable_body, "joint",
145 | armature="0",
146 | axis="1 0 0",
147 | damping="0.0",
148 | limited="true" if falling else "false",
149 | range="%f %f" % (-size_scaling, size_scaling),
150 | margin="0.01",
151 | name="moveable_x_%d_%d" % (i, j),
152 | pos="0 0 0",
153 | type="slide"
154 | )
155 | if maze_env_utils.can_move_y(structure[i][j]):
156 | ET.SubElement(
157 | moveable_body, "joint",
158 | armature="0",
159 | axis="0 1 0",
160 | damping="0.0",
161 | limited="true" if falling else "false",
162 | range="%f %f" % (-size_scaling, size_scaling),
163 | margin="0.01",
164 | name="moveable_y_%d_%d" % (i, j),
165 | pos="0 0 0",
166 | type="slide"
167 | )
168 | if maze_env_utils.can_move_z(structure[i][j]):
169 | ET.SubElement(
170 | moveable_body, "joint",
171 | armature="0",
172 | axis="0 0 1",
173 | damping="0.0",
174 | limited="true",
175 | range="%f 0" % (-height_offset),
176 | margin="0.01",
177 | name="moveable_z_%d_%d" % (i, j),
178 | pos="0 0 0",
179 | type="slide"
180 | )
181 |
182 | torso = tree.find(".//body[@name='torso']")
183 | geoms = torso.findall(".//geom")
184 | for geom in geoms:
185 | if 'name' not in geom.attrib:
186 | raise Exception("Every geom of the torso must have a name "
187 | "defined")
188 |
189 | _, file_path = tempfile.mkstemp(text=True, suffix=".xml")
190 | tree.write(file_path)
191 |
192 | self.wrapped_env = model_cls(*args, file_path=file_path, **kwargs)
193 |
194 | def _get_obs(self):
195 | return np.concatenate([self.wrapped_env._get_obs(),
196 | [self.t * 0.001]])
197 |
198 | def reset(self):
199 | self.t = 0
200 | self.wrapped_env.reset()
201 | return self._get_obs()
202 |
203 | @property
204 | def viewer(self):
205 | return self.wrapped_env.viewer
206 |
207 | def render(self, *args, **kwargs):
208 | return self.wrapped_env.render(*args, **kwargs)
209 |
210 | @property
211 | def observation_space(self):
212 | shape = self._get_obs().shape
213 | high = np.inf * np.ones(shape)
214 | low = -high
215 | return gym.spaces.Box(low, high)
216 |
217 | @property
218 | def action_space(self):
219 | return self.wrapped_env.action_space
220 |
221 | def _find_robot(self):
222 | structure = self.MAZE_STRUCTURE
223 | size_scaling = self.MAZE_SIZE_SCALING
224 | for i in range(len(structure)):
225 | for j in range(len(structure[0])):
226 | if structure[i][j] == 'r':
227 | return j * size_scaling, i * size_scaling
228 | assert False, 'No robot in maze specification.'
229 |
230 | def step(self, action):
231 | self.t += 1
232 | inner_next_obs, inner_reward, done, info = self.wrapped_env.step(action)
233 | next_obs = self._get_obs()
234 | done = False
235 | return next_obs, inner_reward, done, info
236 |
--------------------------------------------------------------------------------
/envs/maze_env_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The TensorFlow Authors All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Adapted from rllab maze_env_utils.py."""
17 | import numpy as np
18 | import math
19 |
20 |
21 | class Move(object):
22 | X = 11
23 | Y = 12
24 | Z = 13
25 | XY = 14
26 | XZ = 15
27 | YZ = 16
28 | XYZ = 17
29 |
30 |
31 | def can_move_x(movable):
32 | return movable in [Move.X, Move.XY, Move.XZ, Move.XYZ]
33 |
34 |
35 | def can_move_y(movable):
36 | return movable in [Move.Y, Move.XY, Move.YZ, Move.XYZ]
37 |
38 |
39 | def can_move_z(movable):
40 | return movable in [Move.Z, Move.XZ, Move.YZ, Move.XYZ]
41 |
42 |
43 | def can_move(movable):
44 | return can_move_x(movable) or can_move_y(movable) or can_move_z(movable)
45 |
46 |
47 | def construct_maze(maze_id='Maze'):
48 | if maze_id == 'Maze':
49 | structure = [
50 | [1, 1, 1, 1, 1],
51 | [1, 'r', 0, 0, 1],
52 | [1, 1, 1, 0, 1],
53 | [1, 0, 0, 0, 1],
54 | [1, 1, 1, 1, 1],
55 | ]
56 | elif maze_id == 'Push':
57 | structure = [
58 | [1, 1, 1, 1, 1],
59 | [1, 0, 'r', 1, 1],
60 | [1, 0, Move.XY, 0, 1],
61 | [1, 1, 0, 1, 1],
62 | [1, 1, 1, 1, 1],
63 | ]
64 | elif maze_id == 'Fall':
65 | structure = [
66 | [1, 1, 1, 1],
67 | [1, 'r', 0, 1],
68 | [1, 0, Move.YZ, 1],
69 | [1, -1, -1, 1],
70 | [1, 0, 0, 1],
71 | [1, 1, 1, 1],
72 | ]
73 | else:
74 | raise NotImplementedError('The provided MazeId %s is not recognized' % maze_id)
75 |
76 | return structure
77 |
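78 | 
79 | # Usage sketch (illustrative): in the layouts above, 1 is a wall, 0 is a free
80 | # cell, 'r' is the robot start, -1 marks a gap the ant can fall through, and
81 | # Move.* codes mark movable blocks.
82 | if __name__ == '__main__':
83 |     for row in construct_maze('Push'):
84 |         print(row)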
--------------------------------------------------------------------------------
/hiro/__pycache__/models.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watakandai/hiro_pytorch/5a92a286a334eb8480a1a4c781abde22c3b67cdf/hiro/__pycache__/models.cpython-36.pyc
--------------------------------------------------------------------------------
/hiro/__pycache__/models.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watakandai/hiro_pytorch/5a92a286a334eb8480a1a4c781abde22c3b67cdf/hiro/__pycache__/models.cpython-37.pyc
--------------------------------------------------------------------------------
/hiro/__pycache__/nn_rl.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watakandai/hiro_pytorch/5a92a286a334eb8480a1a4c781abde22c3b67cdf/hiro/__pycache__/nn_rl.cpython-36.pyc
--------------------------------------------------------------------------------
/hiro/__pycache__/nn_rl.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watakandai/hiro_pytorch/5a92a286a334eb8480a1a4c781abde22c3b67cdf/hiro/__pycache__/nn_rl.cpython-37.pyc
--------------------------------------------------------------------------------
/hiro/hiro_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5 |
6 |
7 | class ReplayBuffer():
8 | def __init__(self, state_dim, goal_dim, action_dim, buffer_size, batch_size):
9 | self.buffer_size = buffer_size
10 | self.batch_size = batch_size
11 | self.ptr = 0
12 | self.size = 0
13 | self.state = np.zeros((buffer_size, state_dim))
14 | self.goal = np.zeros((buffer_size, goal_dim))
15 | self.action = np.zeros((buffer_size, action_dim))
16 | self.n_state = np.zeros((buffer_size, state_dim))
17 | self.reward = np.zeros((buffer_size, 1))
18 | self.not_done = np.zeros((buffer_size, 1))
19 |
20 | self.device = device
21 |
22 | def append(self, state, goal, action, n_state, reward, done):
23 | self.state[self.ptr] = state
24 | self.goal[self.ptr] = goal
25 | self.action[self.ptr] = action
26 | self.n_state[self.ptr] = n_state
27 | self.reward[self.ptr] = reward
28 | self.not_done[self.ptr] = 1. - done
29 |
30 | self.ptr = (self.ptr+1) % self.buffer_size
31 | self.size = min(self.size+1, self.buffer_size)
32 |
33 | def sample(self):
34 | ind = np.random.randint(0, self.size, size=self.batch_size)
35 |
36 | return (
37 | torch.FloatTensor(self.state[ind]).to(self.device),
38 | torch.FloatTensor(self.goal[ind]).to(self.device),
39 | torch.FloatTensor(self.action[ind]).to(self.device),
40 | torch.FloatTensor(self.n_state[ind]).to(self.device),
41 | torch.FloatTensor(self.reward[ind]).to(self.device),
42 | torch.FloatTensor(self.not_done[ind]).to(self.device),
43 | )
44 |
45 | class LowReplayBuffer(ReplayBuffer):
46 | def __init__(self, state_dim, goal_dim, action_dim, buffer_size, batch_size):
47 | super(LowReplayBuffer, self).__init__(state_dim, goal_dim, action_dim, buffer_size, batch_size)
48 | self.n_goal = np.zeros((buffer_size, goal_dim))
49 |
50 | def append(self, state, goal, action, n_state, n_goal, reward, done):
51 | self.state[self.ptr] = state
52 | self.goal[self.ptr] = goal
53 | self.action[self.ptr] = action
54 | self.n_state[self.ptr] = n_state
55 | self.n_goal[self.ptr] = n_goal
56 | self.reward[self.ptr] = reward
57 | self.not_done[self.ptr] = 1. - done
58 |
59 | self.ptr = (self.ptr+1) % self.buffer_size
60 | self.size = min(self.size+1, self.buffer_size)
61 |
62 | def sample(self):
63 | ind = np.random.randint(0, self.size, size=self.batch_size)
64 |
65 | return (
66 | torch.FloatTensor(self.state[ind]).to(self.device),
67 | torch.FloatTensor(self.goal[ind]).to(self.device),
68 | torch.FloatTensor(self.action[ind]).to(self.device),
69 | torch.FloatTensor(self.n_state[ind]).to(self.device),
70 | torch.FloatTensor(self.n_goal[ind]).to(self.device),
71 | torch.FloatTensor(self.reward[ind]).to(self.device),
72 | torch.FloatTensor(self.not_done[ind]).to(self.device),
73 | )
74 |
75 | class HighReplayBuffer(ReplayBuffer):
76 | def __init__(self, state_dim, goal_dim, subgoal_dim, action_dim, buffer_size, batch_size, freq):
77 | super(HighReplayBuffer, self).__init__(state_dim, goal_dim, action_dim, buffer_size, batch_size)
78 | self.action = np.zeros((buffer_size, subgoal_dim))
79 | self.state_arr = np.zeros((buffer_size, freq, state_dim))
80 | self.action_arr = np.zeros((buffer_size, freq, action_dim))
81 |
82 | def append(self, state, goal, action, n_state, reward, done, state_arr, action_arr):
83 | self.state[self.ptr] = state
84 | self.goal[self.ptr] = goal
85 | self.action[self.ptr] = action
86 | self.n_state[self.ptr] = n_state
87 | self.reward[self.ptr] = reward
88 | self.not_done[self.ptr] = 1. - done
89 | self.state_arr[self.ptr,:,:] = state_arr
90 | self.action_arr[self.ptr,:,:] = action_arr
91 |
92 | self.ptr = (self.ptr+1) % self.buffer_size
93 | self.size = min(self.size+1, self.buffer_size)
94 |
95 | def sample(self):
96 | ind = np.random.randint(0, self.size, size=self.batch_size)
97 |
98 | return (
99 | torch.FloatTensor(self.state[ind]).to(self.device),
100 | torch.FloatTensor(self.goal[ind]).to(self.device),
101 | torch.FloatTensor(self.action[ind]).to(self.device),
102 | torch.FloatTensor(self.n_state[ind]).to(self.device),
103 | torch.FloatTensor(self.reward[ind]).to(self.device),
104 | torch.FloatTensor(self.not_done[ind]).to(self.device),
105 | torch.FloatTensor(self.state_arr[ind]).to(self.device),
106 | torch.FloatTensor(self.action_arr[ind]).to(self.device)
107 | )
108 |
109 | class SubgoalActionSpace(object):
110 | def __init__(self, dim):
111 | limits = np.array([-10, -10, -0.5, -1, -1, -1, -1,
112 | -0.5, -0.3, -0.5, -0.3, -0.5, -0.3, -0.5, -0.3])
113 | self.shape = (dim,1)
114 | self.low = limits[:dim]
115 | self.high = -self.low
116 |
117 | def sample(self):
118 | return (self.high - self.low) * np.random.sample(self.high.shape) + self.low
119 |
120 | class Subgoal(object):
121 | def __init__(self, dim=15):
122 | self.action_space = SubgoalActionSpace(dim)
123 | self.action_dim = self.action_space.shape[0]
124 |
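125 | 
126 | # Usage sketch (illustrative, with toy dimensions): transitions are appended
127 | # one at a time and sampled back as torch tensors on the active device.
128 | if __name__ == '__main__':
129 |     buf = ReplayBuffer(state_dim=3, goal_dim=2, action_dim=1,
130 |                        buffer_size=100, batch_size=4)
131 |     for _ in range(10):
132 |         buf.append(np.ones(3), np.zeros(2), np.zeros(1), np.ones(3), 0.5, False)
133 |     state, goal, action, n_state, reward, not_done = buf.sample()
134 |     print(state.shape, reward.shape)  # torch.Size([4, 3]) torch.Size([4, 1])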
--------------------------------------------------------------------------------
/hiro/models.py:
--------------------------------------------------------------------------------
1 | ##################################################
2 | # @copyright Kandai Watanabe
3 | # @email kandai.wata@gmail.com
4 | # @institute University of Colorado Boulder
5 | #
6 | # NN Models for HIRO
7 | # (Data-Efficient Hierarchical Reinforcement Learning)
8 | # Parameters can be found in the original paper
9 | import os
10 | import copy
11 | import time
12 | import glob
13 | import numpy as np
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | from .utils import get_tensor
18 | from hiro.hiro_utils import LowReplayBuffer, HighReplayBuffer, ReplayBuffer, Subgoal
19 | from hiro.utils import _is_update
20 |
21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22 |
23 | class TD3Actor(nn.Module):
24 | def __init__(self, state_dim, goal_dim, action_dim, scale=None):
25 | super(TD3Actor, self).__init__()
26 | if scale is None:
27 | scale = torch.ones(action_dim)
28 | else:
29 | scale = get_tensor(scale)
30 | self.scale = nn.Parameter(scale.clone().detach().float(), requires_grad=False)
31 |
32 | self.l1 = nn.Linear(state_dim + goal_dim, 300)
33 | self.l2 = nn.Linear(300, 300)
34 | self.l3 = nn.Linear(300, action_dim)
35 |
36 | def forward(self, state, goal):
37 | a = F.relu(self.l1(torch.cat([state, goal], 1)))
38 | a = F.relu(self.l2(a))
39 | return self.scale * torch.tanh(self.l3(a))
40 |
41 | class TD3Critic(nn.Module):
42 | def __init__(self, state_dim, goal_dim, action_dim):
43 | super(TD3Critic, self).__init__()
44 | # Q1
45 | self.l1 = nn.Linear(state_dim + goal_dim + action_dim, 300)
46 | self.l2 = nn.Linear(300, 300)
47 | self.l3 = nn.Linear(300, 1)
48 | # Q2
49 | self.l4 = nn.Linear(state_dim + goal_dim + action_dim, 300)
50 | self.l5 = nn.Linear(300, 300)
51 | self.l6 = nn.Linear(300, 1)
52 |
53 | def forward(self, state, goal, action):
54 | sa = torch.cat([state, goal, action], 1)
55 |
56 | q = F.relu(self.l1(sa))
57 | q = F.relu(self.l2(q))
58 | q = self.l3(q)
59 |
60 | return q
61 |
62 | class TD3Controller(object):
63 | def __init__(
64 | self,
65 | state_dim,
66 | goal_dim,
67 | action_dim,
68 | scale,
69 | model_path,
70 | actor_lr=0.0001,
71 | critic_lr=0.001,
72 | expl_noise=0.1,
73 | policy_noise=0.2,
74 | noise_clip=0.5,
75 | gamma=0.99,
76 | policy_freq=2,
77 | tau=0.005):
78 | self.name = 'td3'
79 | self.scale = scale
80 | self.model_path = model_path
81 |
82 | # parameters
83 | self.expl_noise = expl_noise
84 | self.policy_noise = policy_noise
85 | self.noise_clip = noise_clip
86 | self.gamma = gamma
87 | self.policy_freq = policy_freq
88 | self.tau = tau
89 |
90 | self.actor = TD3Actor(state_dim, goal_dim, action_dim, scale=scale).to(device)
91 | self.actor_target = TD3Actor(state_dim, goal_dim, action_dim, scale=scale).to(device)
92 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
93 |
94 | self.critic1 = TD3Critic(state_dim, goal_dim, action_dim).to(device)
95 | self.critic2 = TD3Critic(state_dim, goal_dim, action_dim).to(device)
96 | self.critic1_target = TD3Critic(state_dim, goal_dim, action_dim).to(device)
97 | self.critic2_target = TD3Critic(state_dim, goal_dim, action_dim).to(device)
98 |
99 | self.critic1_optimizer = torch.optim.Adam(self.critic1.parameters(), lr=critic_lr)
100 | self.critic2_optimizer = torch.optim.Adam(self.critic2.parameters(), lr=critic_lr)
101 | self._initialize_target_networks()
102 |
103 | self._initialized = False
104 | self.total_it = 0
105 |
106 | def _initialize_target_networks(self):
107 | self._update_target_network(self.critic1_target, self.critic1, 1.0)
108 | self._update_target_network(self.critic2_target, self.critic2, 1.0)
109 | self._update_target_network(self.actor_target, self.actor, 1.0)
110 | self._initialized = True
111 |
112 | def _update_target_network(self, target, origin, tau):
113 | for target_param, origin_param in zip(target.parameters(), origin.parameters()):
114 | target_param.data.copy_(tau * origin_param.data + (1.0 - tau) * target_param.data)
115 |
116 | def save(self, episode):
117 | # create episode directory. (e.g. model/2000)
118 | model_path = os.path.join(self.model_path, str(episode))
119 | if not os.path.exists(model_path):
120 | os.makedirs(model_path)
121 |
122 | # save file (e.g. model/2000/high_actor.h5)
123 | torch.save(
124 | self.actor.state_dict(),
125 | os.path.join(model_path, self.name+"_actor.h5")
126 | )
127 | torch.save(
128 | self.critic1.state_dict(),
129 | os.path.join(model_path, self.name+"_critic1.h5")
130 | )
131 | torch.save(
132 | self.critic2.state_dict(),
133 | os.path.join(model_path, self.name+"_critic2.h5")
134 | )
135 |
136 | def load(self, episode):
137 | # if episode is negative, load the most recent checkpoint
138 | if episode<0:
139 | episode_list = map(int, os.listdir(self.model_path))
140 | episode = max(episode_list)
141 |
142 | model_path = os.path.join(self.model_path, str(episode))
143 |
144 | self.actor.load_state_dict(torch.load(
145 | os.path.join(model_path, self.name+"_actor.h5"))
146 | )
147 | self.critic1.load_state_dict(torch.load(
148 | os.path.join(model_path, self.name+"_critic1.h5"))
149 | )
150 | self.critic2.load_state_dict(torch.load(
151 | os.path.join(model_path, self.name+"_critic2.h5"))
152 | )
153 |
154 | def _train(self, states, goals, actions, rewards, n_states, n_goals, not_done):
155 | self.total_it += 1
156 | with torch.no_grad():
157 | noise = (
158 | torch.randn_like(actions) * self.policy_noise
159 | ).clamp(-self.noise_clip, self.noise_clip)
160 |
161 | n_actions = self.actor_target(n_states, n_goals) + noise
162 | n_actions = torch.min(n_actions, self.actor.scale)
163 | n_actions = torch.max(n_actions, -self.actor.scale)
164 |
165 | target_Q1 = self.critic1_target(n_states, n_goals, n_actions)
166 | target_Q2 = self.critic2_target(n_states, n_goals, n_actions)
167 | target_Q = torch.min(target_Q1, target_Q2)
168 | target_Q_detached = (rewards + not_done * self.gamma * target_Q).detach()
169 |
170 | current_Q1 = self.critic1(states, goals, actions)
171 | current_Q2 = self.critic2(states, goals, actions)
172 |
173 | critic1_loss = F.smooth_l1_loss(current_Q1, target_Q_detached)
174 | critic2_loss = F.smooth_l1_loss(current_Q2, target_Q_detached)
175 | critic_loss = critic1_loss + critic2_loss
176 |
177 | td_error = (target_Q_detached - current_Q1).mean().cpu().data.numpy()
178 |
179 | self.critic1_optimizer.zero_grad()
180 | self.critic2_optimizer.zero_grad()
181 | critic_loss.backward()
182 | self.critic1_optimizer.step()
183 | self.critic2_optimizer.step()
184 |
185 | if self.total_it % self.policy_freq == 0:
186 | a = self.actor(states, goals)
187 | Q1 = self.critic1(states, goals, a)
188 | actor_loss = -Q1.mean()  # negated because we perform gradient ascent on Q
189 |
190 | self.actor_optimizer.zero_grad()
191 | actor_loss.backward()
192 | self.actor_optimizer.step()
193 |
194 | self._update_target_network(self.critic1_target, self.critic1, self.tau)
195 | self._update_target_network(self.critic2_target, self.critic2, self.tau)
196 | self._update_target_network(self.actor_target, self.actor, self.tau)
197 |
198 | return {'actor_loss_'+self.name: actor_loss, 'critic_loss_'+self.name: critic_loss}, \
199 | {'td_error_'+self.name: td_error}
200 |
201 | return {'critic_loss_'+self.name: critic_loss}, \
202 | {'td_error_'+self.name: td_error}
203 |
204 | def train(self, replay_buffer, iterations=1):
205 | states, goals, actions, n_states, rewards, not_done = replay_buffer.sample()
206 | return self._train(states, goals, actions, rewards, n_states, goals, not_done)
207 |
208 | def policy(self, state, goal, to_numpy=True):
209 | state = get_tensor(state)
210 | goal = get_tensor(goal)
211 | action = self.actor(state, goal)
212 |
213 | if to_numpy:
214 | return action.cpu().data.numpy().squeeze()
215 |
216 | return action.squeeze()
217 |
218 | def policy_with_noise(self, state, goal, to_numpy=True):
219 | state = get_tensor(state)
220 | goal = get_tensor(goal)
221 | action = self.actor(state, goal)
222 |
223 | action = action + self._sample_exploration_noise(action)
224 | action = torch.min(action, self.actor.scale)
225 | action = torch.max(action, -self.actor.scale)
226 |
227 | if to_numpy:
228 | return action.cpu().data.numpy().squeeze()
229 |
230 | return action.squeeze()
231 |
232 | def _sample_exploration_noise(self, actions):
233 | mean = torch.zeros(actions.size()).to(device)
234 | var = torch.ones(actions.size()).to(device)
235 | #expl_noise = self.expl_noise - (self.expl_noise/1200) * (self.total_it//10000)
236 | return torch.normal(mean, self.expl_noise*var)
237 |
238 | class HigherController(TD3Controller):
239 | def __init__(
240 | self,
241 | state_dim,
242 | goal_dim,
243 | action_dim,
244 | scale,
245 | model_path,
246 | actor_lr=0.0001,
247 | critic_lr=0.001,
248 | expl_noise=1.0,
249 | policy_noise=0.2,
250 | noise_clip=0.5,
251 | gamma=0.99,
252 | policy_freq=2,
253 | tau=0.005):
254 | super(HigherController, self).__init__(
255 | state_dim, goal_dim, action_dim, scale, model_path,
256 | actor_lr, critic_lr, expl_noise, policy_noise,
257 | noise_clip, gamma, policy_freq, tau
258 | )
259 | self.name = 'high'
260 | self.action_dim = action_dim
261 |
262 | def off_policy_corrections(self, low_con, batch_size, sgoals, states, actions, candidate_goals=8):
263 | first_s = [s[0] for s in states] # First x
264 | last_s = [s[-1] for s in states] # Last x
265 |
266 | # Shape: (batch_size, 1, subgoal_dim)
267 | # diff = 1
268 | diff_goal = (np.array(last_s) -
269 | np.array(first_s))[:, np.newaxis, :self.action_dim]
270 |
271 | # Shape: (batch_size, 1, subgoal_dim)
272 | # original = 1
273 | # random = candidate_goals
274 | original_goal = np.array(sgoals)[:, np.newaxis, :]
275 | random_goals = np.random.normal(loc=diff_goal, scale=.5*self.scale[None, None, :],
276 | size=(batch_size, candidate_goals, original_goal.shape[-1]))
277 | random_goals = random_goals.clip(-self.scale, self.scale)
278 |
279 | # Shape: (batch_size, 10, subgoal_dim)
280 | candidates = np.concatenate([original_goal, diff_goal, random_goals], axis=1)
281 | #states = np.array(states)[:, :-1, :]
282 | actions = np.array(actions)
283 | seq_len = len(states[0])
284 |
285 | # For ease
286 | new_batch_sz = seq_len * batch_size
287 | action_dim = actions[0][0].shape
288 | obs_dim = states[0][0].shape
289 | ncands = candidates.shape[1]
290 |
291 | true_actions = actions.reshape((new_batch_sz,) + action_dim)
292 | observations = states.reshape((new_batch_sz,) + obs_dim)
293 | goal_shape = (new_batch_sz, self.action_dim)
294 | # observations = get_obs_tensor(observations, sg_corrections=True)
295 |
296 | # batched_candidates = np.tile(candidates, [seq_len, 1, 1])
297 | # batched_candidates = batched_candidates.transpose(1, 0, 2)
298 |
299 | policy_actions = np.zeros((ncands, new_batch_sz) + action_dim)
300 |
301 | for c in range(ncands):
302 | subgoal = candidates[:,c]
303 | candidate = (subgoal + states[:, 0, :self.action_dim])[:, None] - states[:, :, :self.action_dim]
304 | candidate = candidate.reshape(*goal_shape)
305 | policy_actions[c] = low_con.policy(observations, candidate)
306 |
307 | difference = (policy_actions - true_actions)
308 | difference = np.where(difference != -np.inf, difference, 0)
309 | difference = difference.reshape((ncands, batch_size, seq_len) + action_dim).transpose(1, 0, 2, 3)
310 |
311 | logprob = -0.5*np.sum(np.linalg.norm(difference, axis=-1)**2, axis=-1)
312 | max_indices = np.argmax(logprob, axis=-1)
313 |
314 | return candidates[np.arange(batch_size), max_indices]
315 |
316 | def train(self, replay_buffer, low_con):
317 | if not self._initialized:
318 | self._initialize_target_networks()
319 |
320 | states, goals, actions, n_states, rewards, not_done, states_arr, actions_arr = replay_buffer.sample()
321 |
322 | actions = self.off_policy_corrections(
323 | low_con,
324 | replay_buffer.batch_size,
325 | actions.cpu().data.numpy(),
326 | states_arr.cpu().data.numpy(),
327 | actions_arr.cpu().data.numpy())
328 |
329 | actions = get_tensor(actions)
330 | return self._train(states, goals, actions, rewards, n_states, goals, not_done)
331 |
332 | class LowerController(TD3Controller):
333 | def __init__(
334 | self,
335 | state_dim,
336 | goal_dim,
337 | action_dim,
338 | scale,
339 | model_path,
340 | actor_lr=0.0001,
341 | critic_lr=0.001,
342 | expl_noise=1.0,
343 | policy_noise=0.2,
344 | noise_clip=0.5,
345 | gamma=0.99,
346 | policy_freq=2,
347 | tau=0.005):
348 | super(LowerController, self).__init__(
349 | state_dim, goal_dim, action_dim, scale, model_path,
350 | actor_lr, critic_lr, expl_noise, policy_noise,
351 | noise_clip, gamma, policy_freq, tau
352 | )
353 | self.name = 'low'
354 |
355 | def train(self, replay_buffer):
356 | if not self._initialized:
357 | self._initialize_target_networks()
358 |
359 | states, sgoals, actions, n_states, n_sgoals, rewards, not_done = replay_buffer.sample()
360 |
361 | return self._train(states, sgoals, actions, rewards, n_states, n_sgoals, not_done)
362 |
363 | class Agent():
364 | def __init__(self):
365 | pass
366 |
367 | def set_final_goal(self, fg):
368 | self.fg = fg
369 |
370 | def step(self, s, env, step, global_step=0, explore=False):
371 | raise NotImplementedError
372 |
373 | def append(self, step, s, a, n_s, r, d):
374 | raise NotImplementedError
375 |
376 | def train(self, global_step):
377 | raise NotImplementedError
378 |
379 | def end_step(self):
380 | raise NotImplementedError
381 |
382 | def end_episode(self, episode, logger=None):
383 | raise NotImplementedError
384 |
385 | def evaluate_policy(self, env, eval_episodes=10, render=False, save_video=False, sleep=-1):
386 | if save_video:
387 | import gym; from OpenGL import GL  # only needed for video recording
388 | env = gym.wrappers.Monitor(env, directory='video',
389 | write_upon_reset=True, force=True, resume=True, mode='evaluation')
390 | render = False
391 |
392 | success = 0
393 | rewards = []
394 | env.evaluate = True
395 | for e in range(eval_episodes):
396 | obs = env.reset()
397 | fg = obs['desired_goal']
398 | s = obs['observation']
399 | done = False
400 | reward_episode_sum = 0
401 | step = 0
402 |
403 | self.set_final_goal(fg)
404 |
405 | while not done:
406 | if render:
407 | env.render()
408 | if sleep>0:
409 | time.sleep(sleep)
410 |
411 | a, r, n_s, done = self.step(s, env, step)
412 | reward_episode_sum += r
413 |
414 | s = n_s
415 | step += 1
416 | self.end_step()
417 | else:
418 | error = np.sqrt(np.sum(np.square(fg-s[:2])))
419 | print('Goal, Curr: (%02.2f, %02.2f, %02.2f, %02.2f) Error:%.2f'%(fg[0], fg[1], s[0], s[1], error))
420 | rewards.append(reward_episode_sum)
421 | success += 1 if error <=5 else 0
422 | self.end_episode(e)
423 |
424 | env.evaluate = False
425 | return np.array(rewards), success/eval_episodes
426 |
427 | class TD3Agent(Agent):
428 | def __init__(
429 | self,
430 | state_dim,
431 | action_dim,
432 | goal_dim,
433 | scale,
434 | model_path,
435 | model_save_freq,
436 | buffer_size,
437 | batch_size,
438 | start_training_steps):
439 |
440 | self.con = TD3Controller(
441 | state_dim=state_dim,
442 | goal_dim=goal_dim,
443 | action_dim=action_dim,
444 | scale=scale,
445 | model_path=model_path
446 | )
447 |
448 | self.replay_buffer = ReplayBuffer(
449 | state_dim=state_dim,
450 | goal_dim=goal_dim,
451 | action_dim=action_dim,
452 | buffer_size=buffer_size,
453 | batch_size=batch_size
454 | )
455 | self.model_save_freq = model_save_freq
456 | self.start_training_steps = start_training_steps
457 |
458 | def step(self, s, env, step, global_step=0, explore=False):
459 | if explore:
460 | if global_step < self.start_training_steps:
461 | a = env.action_space.sample()
462 | else:
463 | a = self._choose_action_with_noise(s)
464 | else:
465 | a = self._choose_action(s)
466 |
467 | obs, r, done, _ = env.step(a)
468 | n_s = obs['observation']
469 |
470 | return a, r, n_s, done
471 |
472 | def append(self, step, s, a, n_s, r, d):
473 | self.replay_buffer.append(s, self.fg, a, n_s, r, d)
474 |
475 | def train(self, global_step):
476 | return self.con.train(self.replay_buffer)
477 |
478 | def _choose_action(self, s):
479 | return self.con.policy(s, self.fg)
480 |
481 | def _choose_action_with_noise(self, s):
482 | return self.con.policy_with_noise(s, self.fg)
483 |
484 | def end_step(self):
485 | pass
486 |
487 | def end_episode(self, episode, logger=None):
488 | if logger:
489 | if _is_update(episode, self.model_save_freq):
490 | self.save(episode=episode)
491 |
492 | def save(self, episode):
493 | self.con.save(episode)
494 |
495 | def load(self, episode):
496 | self.con.load(episode)
497 |
498 | class HiroAgent(Agent):
499 | def __init__(
500 | self,
501 | state_dim,
502 | action_dim,
503 | goal_dim,
504 | subgoal_dim,
505 | scale_low,
506 | start_training_steps,
507 | model_save_freq,
508 | model_path,
509 | buffer_size,
510 | batch_size,
511 | buffer_freq,
512 | train_freq,
513 | reward_scaling,
514 | policy_freq_high,
515 | policy_freq_low):
516 |
517 | self.subgoal = Subgoal(subgoal_dim)
518 | scale_high = self.subgoal.action_space.high * np.ones(subgoal_dim)
519 |
520 | self.model_save_freq = model_save_freq
521 |
522 | self.high_con = HigherController(
523 | state_dim=state_dim,
524 | goal_dim=goal_dim,
525 | action_dim=subgoal_dim,
526 | scale=scale_high,
527 | model_path=model_path,
528 | policy_freq=policy_freq_high
529 | )
530 |
531 | self.low_con = LowerController(
532 | state_dim=state_dim,
533 | goal_dim=subgoal_dim,
534 | action_dim=action_dim,
535 | scale=scale_low,
536 | model_path=model_path,
537 | policy_freq=policy_freq_low
538 | )
539 |
540 | self.replay_buffer_low = LowReplayBuffer(
541 | state_dim=state_dim,
542 | goal_dim=subgoal_dim,
543 | action_dim=action_dim,
544 | buffer_size=buffer_size,
545 | batch_size=batch_size
546 | )
547 |
548 | self.replay_buffer_high = HighReplayBuffer(
549 | state_dim=state_dim,
550 | goal_dim=goal_dim,
551 | subgoal_dim=subgoal_dim,
552 | action_dim=action_dim,
553 | buffer_size=buffer_size,
554 | batch_size=batch_size,
555 | freq=buffer_freq
556 | )
557 |
558 | self.buffer_freq = buffer_freq
559 | self.train_freq = train_freq
560 | self.reward_scaling = reward_scaling
561 | self.episode_subreward = 0
562 | self.sr = 0
563 |
564 | self.buf = [None, None, None, 0, None, None, [], []]
565 | self.fg = np.array([0,0])
566 | self.sg = self.subgoal.action_space.sample()
567 |
568 | self.start_training_steps = start_training_steps
569 |
570 | def step(self, s, env, step, global_step=0, explore=False):
571 | ## Lower Level Controller
572 | if explore:
573 | # Take random action for start_training_steps
574 | if global_step < self.start_training_steps:
575 | a = env.action_space.sample()
576 | else:
577 | a = self._choose_action_with_noise(s, self.sg)
578 | else:
579 | a = self._choose_action(s, self.sg)
580 |
581 | # Take action
582 | obs, r, done, _ = env.step(a)
583 | n_s = obs['observation']
584 |
585 | ## Higher Level Controller
586 | # Take random action for start_training steps
587 | if explore:
588 | if global_step < self.start_training_steps:
589 | n_sg = self.subgoal.action_space.sample()
590 | else:
591 | n_sg = self._choose_subgoal_with_noise(step, s, self.sg, n_s)
592 | else:
593 | n_sg = self._choose_subgoal(step, s, self.sg, n_s)
594 |
595 | self.n_sg = n_sg
596 |
597 | return a, r, n_s, done
598 |
599 | def append(self, step, s, a, n_s, r, d):
600 | self.sr = self.low_reward(s, self.sg, n_s)
601 |
602 | # Low Replay Buffer
603 | self.replay_buffer_low.append(
604 | s, self.sg, a, n_s, self.n_sg, self.sr, float(d))
605 |
606 | # High Replay Buffer
607 | if _is_update(step, self.buffer_freq, rem=1):
608 | if len(self.buf[6]) == self.buffer_freq:
609 | self.buf[4] = s
610 | self.buf[5] = float(d)
611 | self.replay_buffer_high.append(
612 | state=self.buf[0],
613 | goal=self.buf[1],
614 | action=self.buf[2],
615 | n_state=self.buf[4],
616 | reward=self.buf[3],
617 | done=self.buf[5],
618 | state_arr=np.array(self.buf[6]),
619 | action_arr=np.array(self.buf[7])
620 | )
621 | self.buf = [s, self.fg, self.sg, 0, None, None, [], []]
622 |
623 | self.buf[3] += self.reward_scaling * r
624 | self.buf[6].append(s)
625 | self.buf[7].append(a)
626 |
627 | def train(self, global_step):
628 | losses = {}
629 | td_errors = {}
630 |
631 | if global_step >= self.start_training_steps:
632 | loss, td_error = self.low_con.train(self.replay_buffer_low)
633 | losses.update(loss)
634 | td_errors.update(td_error)
635 |
636 | if global_step % self.train_freq == 0:
637 | loss, td_error = self.high_con.train(self.replay_buffer_high, self.low_con)
638 | losses.update(loss)
639 | td_errors.update(td_error)
640 |
641 | return losses, td_errors
642 |
643 | def _choose_action_with_noise(self, s, sg):
644 | return self.low_con.policy_with_noise(s, sg)
645 |
646 | def _choose_subgoal_with_noise(self, step, s, sg, n_s):
647 | if step % self.buffer_freq == 0:  # resample a new subgoal every buffer_freq steps
648 | sg = self.high_con.policy_with_noise(s, self.fg)
649 | else:
650 | sg = self.subgoal_transition(s, sg, n_s)
651 |
652 | return sg
653 |
654 | def _choose_action(self, s, sg):
655 | return self.low_con.policy(s, sg)
656 |
657 | def _choose_subgoal(self, step, s, sg, n_s):
658 | if step % self.buffer_freq == 0:
659 | sg = self.high_con.policy(s, self.fg)
660 | else:
661 | sg = self.subgoal_transition(s, sg, n_s)
662 |
663 | return sg
664 |
665 | def subgoal_transition(self, s, sg, n_s):
666 | return s[:sg.shape[0]] + sg - n_s[:sg.shape[0]]
667 |
668 | def low_reward(self, s, sg, n_s):
669 | abs_s = s[:sg.shape[0]] + sg
670 | return -np.sqrt(np.sum((abs_s - n_s[:sg.shape[0]])**2))
671 |
672 | def end_step(self):
673 | self.episode_subreward += self.sr
674 | self.sg = self.n_sg
675 |
676 | def end_episode(self, episode, logger=None):
677 | if logger:
678 | # log
679 | logger.write('reward/Intrinsic Reward', self.episode_subreward, episode)
680 |
681 | # Save Model
682 | if _is_update(episode, self.model_save_freq):
683 | self.save(episode=episode)
684 |
685 | self.episode_subreward = 0
686 | self.sr = 0
687 | self.buf = [None, None, None, 0, None, None, [], []]
688 |
689 | def save(self, episode):
690 | self.low_con.save(episode)
691 | self.high_con.save(episode)
692 |
693 | def load(self, episode):
694 | self.low_con.load(episode)
695 | self.high_con.load(episode)
696 |
--------------------------------------------------------------------------------
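The two helpers at the bottom of `HiroAgent`, `subgoal_transition` and `low_reward`, implement the goal relabeling and intrinsic reward from the HIRO paper: the subgoal is stored as a relative offset, so when the agent moves it is re-expressed with respect to the new state while the absolute target point `s + sg` stays fixed, and the low-level reward is the negative Euclidean distance from the reached state to that target. Below is a minimal, self-contained numeric sketch of the same two formulas (the vectors and dimensions are invented for illustration; the real Ant observations are much larger):

```
import numpy as np

def subgoal_transition(s, sg, n_s):
    # keep the absolute target point s[:k] + sg fixed while the agent moves to n_s
    return s[:sg.shape[0]] + sg - n_s[:sg.shape[0]]

def low_reward(s, sg, n_s):
    # negative L2 distance between the reached state and the absolute target point
    abs_target = s[:sg.shape[0]] + sg
    return -np.sqrt(np.sum((abs_target - n_s[:sg.shape[0]]) ** 2))

s   = np.zeros(4)                     # current state (only the first k dims matter)
sg  = np.array([5.0, 5.0])            # relative subgoal: "move +5 in x and +5 in y"
n_s = np.array([1.0, 2.0, 0.0, 0.0])  # state actually reached one step later

print(subgoal_transition(s, sg, n_s))  # [4. 3.] -> same absolute target, seen from n_s
print(low_reward(s, sg, n_s))          # -5.0    -> still 5 units away from the target
```

`test/test_agent.py` checks exactly this invariant: the absolute target `s + sg` before the transition equals `n_s + sg'` after it.
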
/hiro/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 | import numpy as np
4 | import torch
5 | from torch.utils.tensorboard import SummaryWriter
6 |
7 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8 |
9 | def var(tensor):
10 | if torch.cuda.is_available():
11 | return tensor.cuda()
12 | else:
13 | return tensor
14 |
15 | def get_tensor(z):
16 | if len(z.shape) == 1:
17 | return var(torch.FloatTensor(z.copy())).unsqueeze(0)
18 | else:
19 | return var(torch.FloatTensor(z.copy()))
20 |
21 | class Logger():
22 | def __init__(self, log_path):
23 | self.writer = SummaryWriter(log_path)
24 |
25 | def print(self, name, value, episode=-1, step=-1):
26 | string = "{} is {}".format(name, value)
27 | if episode > 0:
28 | print('Episode:{}, {}'.format(episode, string))
29 | if step > 0:
30 | print('Step:{}, {}'.format(step, string))
31 |
32 | def write(self, name, value, index):
33 | self.writer.add_scalar(name, value, index)
34 |
35 | def _is_update(episode, freq, ignore=0, rem=0):
36 | if episode!=ignore and episode%freq==rem:
37 | return True
38 | return False
39 |
40 |
41 | class ReplayBuffer():
42 | def __init__(self, state_dim, action_dim, buffer_size, batch_size):
43 | self.buffer_size = buffer_size
44 | self.batch_size = batch_size
45 | self.ptr = 0
46 | self.size = 0
47 | self.state = np.zeros((buffer_size, state_dim))
48 | self.action = np.zeros((buffer_size, action_dim))
49 | self.n_state = np.zeros((buffer_size, state_dim))
50 | self.reward = np.zeros((buffer_size, 1))
51 | self.not_done = np.zeros((buffer_size, 1))
52 |
53 | self.device = device
54 |
55 | def append(self, state, action, n_state, reward, done):
56 | self.state[self.ptr] = state
57 | self.action[self.ptr] = action
58 | self.n_state[self.ptr] = n_state
59 | self.reward[self.ptr] = reward
60 | self.not_done[self.ptr] = 1. - done
61 |
62 | self.ptr = (self.ptr+1) % self.buffer_size
63 | self.size = min(self.size+1, self.buffer_size)
64 |
65 | def sample(self):
66 | ind = np.random.randint(0, self.size, size=self.batch_size)
67 |
68 | return (
69 | torch.FloatTensor(self.state[ind]).to(self.device),
70 | torch.FloatTensor(self.action[ind]).to(self.device),
71 | torch.FloatTensor(self.n_state[ind]).to(self.device),
72 | torch.FloatTensor(self.reward[ind]).to(self.device),
73 | torch.FloatTensor(self.not_done[ind]).to(self.device),
74 | )
75 |
76 | def record_experience_to_csv(args, experiment_name, csv_name='experiments.csv'):
77 |     # add the experiment name (timestamp) to the args dict
78 | d = vars(args)
79 | d['date'] = experiment_name
80 |
81 | if os.path.exists(csv_name):
82 | # Save Dictionary to a csv
83 | with open(csv_name, 'a') as f:
84 | w = csv.DictWriter(f, list(d.keys()))
85 | w.writerow(d)
86 | else:
87 | # Save Dictionary to a csv
88 | with open(csv_name, 'w') as f:
89 | w = csv.DictWriter(f, list(d.keys()))
90 | w.writeheader()
91 | w.writerow(d)
92 |
93 | def listdirs(directory):
94 | return [name for name in os.listdir(directory) if os.path.isdir(os.path.join(directory, name))]
--------------------------------------------------------------------------------
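`ReplayBuffer` above is a plain circular buffer: `ptr` wraps around `buffer_size`, `size` saturates at `buffer_size`, and `sample()` draws a uniform mini-batch and moves it to `device`. A short usage sketch follows (the dimensions and the sampling frequency are made up; it assumes it is run from the repository root so that `hiro.utils` is importable):

```
import numpy as np
from hiro.utils import ReplayBuffer, _is_update

buffer = ReplayBuffer(state_dim=4, action_dim=2, buffer_size=1000, batch_size=32)

for step in range(2000):  # more steps than buffer_size, so the oldest entries get overwritten
    s, a, n_s = np.random.rand(4), np.random.rand(2), np.random.rand(4)
    buffer.append(s, a, n_s, reward=0.0, done=False)

    # _is_update(step, freq) is True every freq-th step (and never at step == ignore, default 0)
    if _is_update(step, 10) and buffer.size >= buffer.batch_size:
        state, action, n_state, reward, not_done = buffer.sample()
        # each tensor has shape (batch_size, dim) and already lives on `device`
```
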
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 | import datetime
5 | import copy
6 | from envs import EnvWithGoal
7 | from envs.create_maze_env import create_maze_env
8 | from hiro.hiro_utils import Subgoal
9 | from hiro.utils import Logger, _is_update, record_experience_to_csv, listdirs
10 | from hiro.models import HiroAgent, TD3Agent
11 |
12 | def run_evaluation(args, env, agent):
13 | agent.load(args.load_episode)
14 |
15 | rewards, success_rate = agent.evaluate_policy(env, args.eval_episodes, args.render, args.save_video, args.sleep)
16 |
17 | print('mean:{mean:.2f}, \
18 | std:{std:.2f}, \
19 | median:{median:.2f}, \
20 | success:{success:.2f}'.format(
21 | mean=np.mean(rewards),
22 | std=np.std(rewards),
23 | median=np.median(rewards),
24 | success=success_rate))
25 |
26 | class Trainer():
27 | def __init__(self, args, env, agent, experiment_name):
28 | self.args = args
29 | self.env = env
30 | self.agent = agent
31 | log_path = os.path.join(args.log_path, experiment_name)
32 | self.logger = Logger(log_path=log_path)
33 |
34 | def train(self):
35 | global_step = 0
36 |
37 | for e in np.arange(self.args.num_episode)+1:
38 | obs = self.env.reset()
39 | fg = obs['desired_goal']
40 | s = obs['observation']
41 | done = False
42 |
43 | step = 0
44 | episode_reward = 0
45 |
46 | self.agent.set_final_goal(fg)
47 |
48 | while not done:
49 | # Take action
50 | a, r, n_s, done = self.agent.step(s, self.env, step, global_step, explore=True)
51 |
52 | # Append
53 | self.agent.append(step, s, a, n_s, r, done)
54 |
55 | # Train
56 | losses, td_errors = self.agent.train(global_step)
57 |
58 | # Log
59 | self.log(global_step, [losses, td_errors])
60 |
61 | # Updates
62 | s = n_s
63 | episode_reward += r
64 | step += 1
65 | global_step += 1
66 | self.agent.end_step()
67 |
68 | self.agent.end_episode(e, self.logger)
69 | self.logger.write('reward/Reward', episode_reward, e)
70 | self.evaluate(e)
71 |
72 | def log(self, global_step, data):
73 | losses, td_errors = data[0], data[1]
74 |
75 | # Logs
76 |         if global_step >= self.args.start_training_steps and _is_update(global_step, self.args.writer_freq):
77 | for k, v in losses.items():
78 | self.logger.write('loss/%s'%(k), v, global_step)
79 |
80 | for k, v in td_errors.items():
81 | self.logger.write('td_error/%s'%(k), v, global_step)
82 |
83 | def evaluate(self, e):
84 | # Print
85 |         if _is_update(e, self.args.print_freq):
86 | agent = copy.deepcopy(self.agent)
87 | rewards, success_rate = agent.evaluate_policy(self.env)
88 | #rewards, success_rate = self.agent.evaluate_policy(self.env)
89 | self.logger.write('Success Rate', success_rate, e)
90 |
91 | print('episode:{episode:05d}, mean:{mean:.2f}, std:{std:.2f}, median:{median:.2f}, success:{success:.2f}'.format(
92 | episode=e,
93 | mean=np.mean(rewards),
94 | std=np.std(rewards),
95 | median=np.median(rewards),
96 | success=success_rate))
97 |
98 | if __name__ == '__main__':
99 | parser = argparse.ArgumentParser()
100 |
101 | # Across All
102 | parser.add_argument('--train', action='store_true')
103 | parser.add_argument('--eval', action='store_true')
104 | parser.add_argument('--render', action='store_true')
105 | parser.add_argument('--save_video', action='store_true')
106 | parser.add_argument('--sleep', type=float, default=-1)
107 |     parser.add_argument('--eval_episodes', type=int, default=5, help='Unit = Episode')
108 | parser.add_argument('--env', default='AntMaze', type=str)
109 | parser.add_argument('--td3', action='store_true')
110 |
111 | # Training
112 | parser.add_argument('--num_episode', default=25000, type=int)
113 | parser.add_argument('--start_training_steps', default=2500, type=int, help='Unit = Global Step')
114 | parser.add_argument('--writer_freq', default=25, type=int, help='Unit = Global Step')
115 | # Training (Model Saving)
116 | parser.add_argument('--subgoal_dim', default=15, type=int)
117 | parser.add_argument('--load_episode', default=-1, type=int)
118 | parser.add_argument('--model_save_freq', default=2000, type=int, help='Unit = Episodes')
119 | parser.add_argument('--print_freq', default=250, type=int, help='Unit = Episode')
120 | parser.add_argument('--exp_name', default=None, type=str)
121 | # Model
122 | parser.add_argument('--model_path', default='model', type=str)
123 | parser.add_argument('--log_path', default='log', type=str)
124 | parser.add_argument('--policy_freq_low', default=2, type=int)
125 | parser.add_argument('--policy_freq_high', default=2, type=int)
126 | # Replay Buffer
127 | parser.add_argument('--buffer_size', default=200000, type=int)
128 | parser.add_argument('--batch_size', default=100, type=int)
129 | parser.add_argument('--buffer_freq', default=10, type=int)
130 | parser.add_argument('--train_freq', default=10, type=int)
131 | parser.add_argument('--reward_scaling', default=0.1, type=float)
132 | args = parser.parse_args()
133 |
134 | # Select or Generate a name for this experiment
135 | if args.exp_name:
136 | experiment_name = args.exp_name
137 | else:
138 | if args.eval:
139 | # choose most updated experiment for evaluation
140 | dirs_str = listdirs(args.model_path)
141 |             dirs = np.array([int(d.replace('_', '')) for d in dirs_str])  # timestamp names contain '_', strip it before the int cast
142 | experiment_name = dirs_str[np.argmax(dirs)]
143 | else:
144 | experiment_name = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
145 | print(experiment_name)
146 |
147 | # Environment and its attributes
148 | env = EnvWithGoal(create_maze_env(args.env), args.env)
149 | goal_dim = 2
150 | state_dim = env.state_dim
151 | action_dim = env.action_dim
152 | scale = env.action_space.high * np.ones(action_dim)
153 |
154 | # Spawn an agent
155 | if args.td3:
156 | agent = TD3Agent(
157 | state_dim=state_dim,
158 | action_dim=action_dim,
159 | goal_dim=goal_dim,
160 | scale=scale,
161 | model_save_freq=args.model_save_freq,
162 | model_path=os.path.join(args.model_path, experiment_name),
163 | buffer_size=args.buffer_size,
164 | batch_size=args.batch_size,
165 | start_training_steps=args.start_training_steps
166 | )
167 | else:
168 | agent = HiroAgent(
169 | state_dim=state_dim,
170 | action_dim=action_dim,
171 | goal_dim=goal_dim,
172 | subgoal_dim=args.subgoal_dim,
173 | scale_low=scale,
174 | start_training_steps=args.start_training_steps,
175 | model_path=os.path.join(args.model_path, experiment_name),
176 | model_save_freq=args.model_save_freq,
177 | buffer_size=args.buffer_size,
178 | batch_size=args.batch_size,
179 | buffer_freq=args.buffer_freq,
180 | train_freq=args.train_freq,
181 | reward_scaling=args.reward_scaling,
182 | policy_freq_high=args.policy_freq_high,
183 | policy_freq_low=args.policy_freq_low
184 | )
185 |
186 | # Run training or evaluation
187 | if args.train:
188 | # Record this experiment with arguments to a CSV file
189 | record_experience_to_csv(args, experiment_name)
190 | # Start training
191 | trainer = Trainer(args, env, agent, experiment_name)
192 | trainer.train()
193 | if args.eval:
194 | run_evaluation(args, env, agent)
--------------------------------------------------------------------------------
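When `--eval` is passed without `--exp_name`, `main.py` picks the most recent experiment directory under `model_path` by converting the timestamp-style directory names to integers and taking the argmax. The default names follow `%Y%m%d_%H%M%S` and therefore contain an underscore, which is why it is stripped before the cast; since zero-padded timestamps also sort chronologically as plain strings, a `max()` over the names gives the same answer. A small sketch of that selection logic (the directory names are invented examples):

```
import numpy as np

dirs_str = ['20201101_093015', '20210203_181200', '20201231_235959']  # invented examples

# variant used in main.py: strip the underscore, cast to int, take the argmax
dirs = np.array([int(d.replace('_', '')) for d in dirs_str])
print(dirs_str[np.argmax(dirs)])  # 20210203_181200

# equivalent shortcut: zero-padded timestamps sort chronologically as strings
print(max(dirs_str))              # 20210203_181200
```
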
/media/Success_Rate.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/media/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/watakandai/hiro_pytorch/5a92a286a334eb8480a1a4c781abde22c3b67cdf/media/demo.gif
--------------------------------------------------------------------------------
/media/loss_actor_loss_td3.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/media/loss_critic_loss_high.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | gym
2 | torch
3 | numpy
4 | mujoco_py
5 | tensorboard
--------------------------------------------------------------------------------
/test/test_agent.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 | import sys, os
4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
5 | import torch
6 | from hiro.models import HiroAgent
7 | from hiro.hiro_utils import Subgoal, spawn_dims
8 |
9 | from envs import EnvWithGoal
10 | from envs.create_maze_env import create_maze_env
11 |
12 | ENV_NAME = 'AntMaze'
13 |
14 | class AgentTest(unittest.TestCase):
15 | def test_low_reward(self):
16 | env = EnvWithGoal(create_maze_env(ENV_NAME), ENV_NAME)
17 | subgoal = Subgoal()
18 |
19 | subgoal_dim = subgoal.action_dim
20 | state_dim, goal_dim, action_dim, scale_low = spawn_dims(env)
21 | scale_high = subgoal.action_space.high * np.ones(subgoal_dim)
22 |
23 | agent = HiroAgent(
24 | state_dim=state_dim,
25 | action_dim=action_dim,
26 | goal_dim=goal_dim,
27 | subgoal_dim=subgoal_dim,
28 | scale_low=scale_low,
29 | scale_high=scale_high)
30 |
31 |         goal = np.array([5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
32 |
33 | state = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
34 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
35 | next_state = np.array([1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
36 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
37 | reward1 = agent.low_reward(state, goal, next_state)
38 |
39 | state = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
40 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
41 | next_state = np.array([1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
42 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
43 | reward2 = agent.low_reward(state, goal, next_state)
44 |
45 | self.assertTrue(reward2 > reward1)
46 |
47 | def test_low_reward_negative(self):
48 | env = EnvWithGoal(create_maze_env(ENV_NAME), ENV_NAME)
49 | subgoal = Subgoal()
50 |
51 | subgoal_dim = subgoal.action_dim
52 | state_dim, goal_dim, action_dim, scale_low = spawn_dims(env)
53 | scale_high = subgoal.action_space.high * np.ones(subgoal_dim)
54 |
55 | agent = HiroAgent(
56 | state_dim=state_dim,
57 | action_dim=action_dim,
58 | goal_dim=goal_dim,
59 | subgoal_dim=subgoal_dim,
60 | scale_low=scale_low,
61 | scale_high=scale_high)
62 |
63 |         goal = np.array([5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
64 |
65 | state = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
66 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
67 | next_state = np.array([1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
68 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
69 | reward1 = agent.low_reward(state, goal, next_state)
70 |
71 | state = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
72 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
73 | next_state = np.array([-1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
74 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
75 | reward2 = agent.low_reward(state, goal, next_state)
76 |
77 | self.assertTrue(reward1 > reward2)
78 |
79 | def test_subgoal_transition(self):
80 | env = EnvWithGoal(create_maze_env(ENV_NAME), ENV_NAME)
81 | subgoal = Subgoal()
82 |
83 | subgoal_dim = subgoal.action_dim
84 | state_dim, goal_dim, action_dim, scale_low = spawn_dims(env)
85 | scale_high = subgoal.action_space.high * np.ones(subgoal_dim)
86 |
87 | agent = HiroAgent(
88 | state_dim=state_dim,
89 | action_dim=action_dim,
90 | goal_dim=goal_dim,
91 | subgoal_dim=subgoal_dim,
92 | scale_low=scale_low,
93 | scale_high=scale_high)
94 |
95 |         goal = np.array([5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
96 |
97 | state = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
98 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
99 | next_state = np.array([1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
100 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
101 | subgoal = agent.subgoal_transition(state, goal, next_state)
102 |
103 |         # the absolute target point (state + subgoal) should be preserved by the transition
104 |         self.assertTrue((state[:goal.shape[0]] + goal == next_state[:goal.shape[0]] + subgoal).all())
105 |
106 |
107 | if __name__ == '__main__':
108 | unittest.main(verbosity=2)
109 |
--------------------------------------------------------------------------------
/test/test_env.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 | import sys, os
4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
5 | import torch
6 | from hiro.models import HiroAgent
7 | from hiro.hiro_utils import Subgoal, spawn_dims
8 |
9 | from envs import EnvWithGoal
10 | from envs.create_maze_env import create_maze_env
11 |
12 | ENV_NAME = 'AntMaze'
13 |
14 | class EnvTest(unittest.TestCase):
15 | def test_dimensions(self):
16 | env = EnvWithGoal(create_maze_env(ENV_NAME), ENV_NAME)
17 | subgoal = Subgoal()
18 | subgoal_dim = subgoal.action_dim
19 | state_dim, goal_dim, action_dim, _ = spawn_dims(env)
20 |
21 | # {xyz=3, orientation (quaternion)=4, limb angles=8} * {pos, vel}
22 | # = (3+4+8)*2 = 15*2 = 30
23 | # states + time (1)
24 | self.assertEqual(state_dim, 31)
25 | # num of limbs
26 | self.assertEqual(action_dim, 8)
27 | # {xyz=3, orientation (quaternion)=4, limb angles=8}
28 | # = 3+4+8 = 15
29 | self.assertEqual(subgoal_dim, 15)
30 | # x, y
31 | self.assertEqual(goal_dim, 2)
32 |
33 | def test_low_action_limit(self):
34 | env = EnvWithGoal(create_maze_env(ENV_NAME), ENV_NAME)
35 | subgoal = Subgoal()
36 |
37 | subgoal_dim = subgoal.action_dim
38 | _, _, _, action_lim = spawn_dims(env)
39 | action_lim_given = np.array([30]*15)
40 |
41 | self.assertTrue((action_lim == action_lim_given).all())
42 |
43 | def test_high_action_limit(self):
44 | subgoal = Subgoal()
45 | subgoal_dim = subgoal.action_dim
46 | action_lim = subgoal.action_space.high * np.ones(subgoal_dim)
47 |
48 | action_lim_given = np.array([
49 | 10, 10, 0.5, 0.5, 1, 1, 1, 1,
50 | 0.5, 0.3, 0.5, 0.3, 0.5, 0.3, 0.5, 0.3
51 | ])
52 |
53 | self.assertTrue((action_lim == action_lim_given).all())
54 |
55 | def test_goal_does_not_change(self):
56 | env = EnvWithGoal(create_maze_env(ENV_NAME), ENV_NAME)
57 |
58 | obs = env.reset()
59 | goal = obs['desired_goal']
60 |
61 | for i in range(100):
62 |             a = env.action_space.sample()
63 | obs, reward, done, info = env.step(a)
64 | g = obs['desired_goal']
65 |
66 |             self.assertTrue((goal == g).all())
67 |
68 | def test_state_does_change(self):
69 | env = EnvWithGoal(create_maze_env(ENV_NAME), ENV_NAME)
70 | max_action = float(env.action_space.high[0])
71 |
72 | obs = env.reset()
73 | state = obs['observation']
74 |
75 | for i in range(100):
76 |             a = env.action_space.sample()
77 |             a = np.clip(a, -max_action, max_action)
78 | obs, reward, done, info = env.step(a)
79 | s = obs['observation']
80 |
81 |             self.assertFalse((state == s).all())
82 |
83 | def test_reward_equation(self):
84 | env = EnvWithGoal(create_maze_env(ENV_NAME), ENV_NAME)
85 |
86 | obs = env.reset()
87 | goal = obs['desired_goal']
88 | state = obs['observation']
89 |
90 |         a = env.action_space.sample()
91 | obs, reward, done, info = env.step(a)
92 |
93 | goal = obs['desired_goal']
94 | state = obs['observation']
95 |
96 | diff = state[:2] - goal
97 | squared = np.square(diff)
98 | sum_squared = np.sum(squared)
99 |         distance = np.sqrt(sum_squared)  # Euclidean distance to the goal
100 |         hand_computed_reward = -distance
101 |
102 | self.assertEqual(reward, hand_computed_reward)
103 |
104 | def test_goal_range(self):
105 | env = EnvWithGoal(create_maze_env(ENV_NAME), ENV_NAME)
106 |
107 | obs = env.reset()
108 | goal = obs['desired_goal']
109 |
110 | goals = np.zeros((1000, goal.shape[0]))
111 |
112 | for i in range(1000):
113 | obs = env.reset()
114 | goal = obs['desired_goal']
115 | goals[i,:] = goal
116 |
117 |         self.assertAlmostEqual(np.min(goals[:,0]), -4)
118 |         self.assertAlmostEqual(np.min(goals[:,1]), -4)
119 |         self.assertAlmostEqual(np.max(goals[:,0]), 20)
120 |         self.assertAlmostEqual(np.max(goals[:,1]), 20)
121 |
122 | if __name__ == '__main__':
123 | unittest.main(verbosity=2)
124 |
--------------------------------------------------------------------------------
/test/test_models.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 | import sys, os
4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
5 | import torch
6 | from hiro.models import TD3Actor, TD3, get_tensor
7 |
8 | class ModelsTest(unittest.TestCase):
9 | def test_td3actor_output_size(self):
10 | max_action = get_tensor(np.random.randint(5, size=5))
11 | actor = TD3Actor(10, 3, 5, max_action)
12 | state = get_tensor(np.random.rand(10))
13 | goal = get_tensor(np.random.rand(3))
14 | y = actor(state, goal)
15 |
16 | self.assertEqual(y.shape[1], 5)
17 |
18 | def test_td3actor_output_type(self):
19 | max_action = get_tensor(np.random.randint(5, size=5))
20 | actor = TD3Actor(10, 3, 5, max_action)
21 | state = get_tensor(np.random.rand(10))
22 | goal = get_tensor(np.random.rand(3))
23 | y = actor(state, goal)
24 |
25 | self.assertEqual(type(y), torch.Tensor)
26 |
27 | def test_td3actor_output_minmax(self):
28 | random_value = 100
29 | sdim = 10
30 | gdim = 3
31 | adim = 5
32 | max_action = get_tensor(np.random.randint(random_value, size=adim))
33 | actor = TD3Actor(sdim, gdim, adim, max_action)
34 |
35 | x = np.inf * np.ones(adim)
36 | x = np.array([x, -x])
37 | x = torch.tensor(x)
38 | out = actor.scale * actor.max_action * torch.tanh(x)
39 |
40 | for i in range(adim):
41 | self.assertEqual(torch.max(out[0,i]), max_action[0,i])
42 | self.assertEqual(torch.min(out[1,i]), -max_action[0,i])
43 |
44 |
45 | if __name__ == '__main__':
46 | unittest.main(verbosity=2)
47 |
--------------------------------------------------------------------------------