├── .gitignore ├── .gitmodules ├── .style.yapf ├── LICENSE ├── README.md ├── fabfile.py ├── heuristics ├── __init__.py ├── multi_walker.py ├── pursuit.py └── waterworld.py ├── lessons ├── multiant │ └── env.yaml └── multiwalker │ └── env.yaml ├── madrl_environments ├── __init__.py ├── box_carrying.py ├── hostage.py ├── mujoco │ ├── __init__.py │ └── ant │ │ ├── __init__.py │ │ ├── ant_og.xml │ │ ├── multi_ant.py │ │ └── multi_ant.xml ├── pursuit │ ├── __init__.py │ ├── eval_scripts │ │ ├── policy_eval.py │ │ ├── speed_pursuit.py │ │ └── stationary_eval.py │ ├── pursuit_evade.py │ ├── test_pursuit.py │ ├── utils │ │ ├── AgentLayer.py │ │ ├── Controllers.py │ │ ├── DiscreteAgent.py │ │ ├── TwoDMaps.py │ │ ├── __init__.py │ │ └── agent_utils.py │ ├── vis_policy.py │ └── waterworld.py └── walker │ ├── __init__.py │ ├── multi_walker.py │ ├── test_walker.py │ ├── train_multi_walker.py │ └── train_single_walker.py ├── maps └── map_pool16.npy ├── pipelines ├── __init__.py ├── cont_pipeline.py ├── disc_pipeline.py ├── host_pipeline.py ├── pipeline.py ├── run_pipeline.py └── waterworld.py ├── pursuit_policy.py ├── rllabwrapper ├── __init__.py └── rllab_gru_test.py ├── runners ├── __init__.py ├── archs.py ├── curriculum.py ├── old │ ├── rllab │ │ ├── pursuit.sh │ │ ├── pursuit_cnn.sh │ │ ├── pursuit_test.sh │ │ ├── run_hostage.py │ │ ├── run_pursuit.py │ │ ├── run_pursuit_theano.py │ │ ├── run_walker.py │ │ └── run_waterworld.py │ └── rltools │ │ ├── __init__.py │ │ ├── pursuit.sh │ │ ├── run_con_hostage.py │ │ ├── run_con_waterworld.py │ │ ├── run_hostage.py │ │ ├── run_pursuit.py │ │ └── run_waterworld.py ├── run_hostage.py ├── run_multiant.py ├── run_multiwalker.py ├── run_pursuit.py ├── run_waterworld.py ├── rurllab.py └── rurltools.py ├── sample_spec.yaml └── vis ├── __init__.py ├── bar_plot.py ├── max_bar_plot.py ├── rllab ├── showlog.py ├── vis_pursuit.py └── vis_waterworld.py ├── rltools ├── vis_hostage.py ├── vis_pursuit.py └── vis_waterworld.py ├── vis_multiant.py ├── vis_multiwalker.py ├── vis_pursuit.py ├── vis_waterworld.py ├── wilco2.py └── wilcoxon.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.swp 3 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "rllab"] 2 | path = rllab 3 | url = git@github.com:rejuvyesh/rllab.git 4 | branch = multiagent 5 | [submodule "rltools"] 6 | path = rltools 7 | url = git@github.com:sisl/rltools.git 8 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | split_before_named_assigns = False 3 | based_on_style = chromium 4 | indent_width = 4 5 | column_limit = 100 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Stanford Intelligent Systems Laboratory 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | 
furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MADRL 2 | 3 | **Note**: 4 | **A maintained version of the first three environments with various fixes is included with PettingZoo (https://github.com/PettingZoo-Team/PettingZoo, https://pettingzoo.farama.org/environments/sisl/)** 5 | 6 | This package provides implementations of the following multi-agent reinforcement learning environments: 7 | 8 | - [Pursuit Evasion](https://github.com/sisl/MADRL/blob/master/madrl_environments/pursuit/pursuit_evade.py) 9 | - [Waterworld](https://github.com/sisl/MADRL/blob/master/madrl_environments/pursuit/waterworld.py) 10 | - [Multi-Agent Walker](https://github.com/sisl/MADRL/blob/master/madrl_environments/walker/multi_walker.py) 11 | - [Multi-Ant](https://github.com/sisl/MADRL/blob/master/madrl_environments/mujoco/ant/multi_ant.py) 12 | 13 | 14 | ## Requirements 15 | 16 | This package requires both [OpenAI Gym](https://github.com/openai/gym) and a forked version of [rllab](https://github.com/rejuvyesh/rllab/tree/multiagent) (the multiagent branch). A number of other requirements can be found 17 | in the `rllab/environment.yml` file if you are using the `anaconda` distribution. 18 | 19 | ## Setup 20 | 21 | The easiest way to install MADRL and its dependencies is to perform a recursive clone of this repository. 22 | ```bash 23 | git clone --recursive git@github.com:sisl/MADRL.git 24 | ``` 25 | 26 | Then, add the directories to `PYTHONPATH`: 27 | ```bash 28 | export PYTHONPATH=$(pwd):$(pwd)/rltools:$(pwd)/rllab:$PYTHONPATH 29 | ``` 30 | 31 | Install the required dependencies. It is a good idea to look at the `rllab/environment.yml` file if you are using the `anaconda` distribution. 32 | 33 | ## Usage 34 | 35 | Example run with a curriculum: 36 | 37 | ```bash 38 | python3 runners/run_multiwalker.py rllab \ # Use rllab for training 39 | --control decentralized \ # Decentralized training protocol 40 | --policy_hidden 100,50,25 \ # Set MLP policy hidden layer sizes 41 | --n_iter 200 \ # Number of iterations 42 | --n_walkers 2 \ # Starting number of walkers 43 | --batch_size 24000 \ # Number of rollout timesteps per iteration 44 | --curriculum lessons/multiwalker/env.yaml 45 | ``` 46 | 47 | ## Details 48 | 49 | Policy definitions exist in `rllab/sandbox/rocky/tf/policies`. 
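For a quick sanity check outside the training scripts, the environments can also be driven directly with random actions. The sketch below assumes the `PYTHONPATH` setup above and mirrors the constructor arguments used in `madrl_environments/pursuit/test_pursuit.py` and `heuristics/pursuit.py`; the agent counts, horizon, and observation range are illustrative values, not recommended settings.

```python
# Minimal random-action rollout of the pursuit-evasion environment.
# Constructor arguments follow test_pursuit.py; the values here are illustrative.
import numpy as np
from madrl_environments.pursuit import PursuitEvade, TwoDMaps

map_mat = TwoDMaps.rectangle_map(16, 16)  # 16x16 grid with a central rectangular obstacle
env = PursuitEvade([map_mat], n_evaders=5, n_pursuers=4, obs_range=7, n_catch=2,
                   surround=False, reward_mech='local')

obs = env.reset()  # one observation per pursuer
total_reward = 0.0
for _ in range(500):
    # a trained policy would replace the random sampling below
    actions = [agent.action_space.sample() for agent in env.agents]
    obs, rewards, done, info = env.step(actions)
    total_reward += np.mean(rewards)
    if done:
        break
print(total_reward)
```

The waterworld and walker environments follow the same list-of-agents convention; see the `__main__` blocks of `heuristics/waterworld.py` and `heuristics/multi_walker.py`.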
50 | 51 | ## Citation 52 | 53 | Please cite the accompanying paper if you find this useful: 54 | 55 | ``` 56 | @inproceedings{gupta2017cooperative, 57 | title={Cooperative multi-agent control using deep reinforcement learning}, 58 | author={Gupta, Jayesh K and Egorov, Maxim and Kochenderfer, Mykel}, 59 | booktitle={International Conference on Autonomous Agents and Multiagent Systems}, 60 | pages={66--83}, 61 | year={2017}, 62 | organization={Springer} 63 | } 64 | ``` 65 | -------------------------------------------------------------------------------- /fabfile.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: fabfile.py 4 | # 5 | # Created: Wednesday, August 24 2016 by rejuvyesh 6 | # License: GNU GPL 3 7 | # 8 | from fabric.api import cd, put, path, task, shell_env, run, env, local, settings 9 | from fabric.contrib.project import rsync_project 10 | import os.path 11 | from time import sleep 12 | 13 | env.use_ssh_config = True 14 | 15 | RLTOOLS_LOC = '/home/{}/src/python/rltools'.format(env.user) 16 | MADRL_LOC = '/home/{}/src/python/MADRL'.format(env.user) 17 | 18 | 19 | class Tmux(object): 20 | 21 | def __init__(self, name): 22 | self._name = name 23 | with settings(warn_only=True): 24 | test = run('tmux has-session -t {}'.format(self._name)) 25 | if test.failed: 26 | run('tmux new-session -d -s {}'.format(self._name)) 27 | 28 | def run(self, cmd, window=0): 29 | run('tmux send -t {}.{} "{}" ENTER'.format(self._name, window, cmd)) 30 | 31 | 32 | @task 33 | def githash(): 34 | git_hash = local('git rev-parse HEAD', capture=True) 35 | return git_hash 36 | 37 | @task 38 | def sync(): 39 | rsync_project(remote_dir=os.path.split(MADRL_LOC)[0], exclude=['*.h5']) 40 | 41 | @task(alias='pipe') 42 | def runpipeline(script, fname): 43 | git_hash = githash() 44 | sync() 45 | pipetm = Tmux('pipeline') 46 | pipetm.run('export PYTHONPATH={}:{}'.format(RLTOOLS_LOC, MADRL_LOC)) 47 | pipetm.run('cd {}'.format(MADRL_LOC)) 48 | pipetm.run('python {} {} {}'.format(script, fname, git_hash)) 49 | sleep(0.5) 50 | pipetm.run('y') 51 | -------------------------------------------------------------------------------- /heuristics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sisl/MADRL/9ea39a0fe8695b391008a4eb7bda9fe4438a96de/heuristics/__init__.py -------------------------------------------------------------------------------- /heuristics/multi_walker.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from rltools.policy import Policy 4 | 5 | STAY_ON_ONE_LEG, PUT_OTHER_DOWN, PUSH_OFF = 1, 2, 3 6 | SPEED = 0.29 # Will fall forward at higher speeds 7 | SUPPORT_KNEE_ANGLE = +0.1 8 | 9 | 10 | class MultiWalkerHeuristicPolicy(Policy): 11 | 12 | def __init__(self, observation_space, action_space): 13 | super(MultiWalkerHeuristicPolicy, self).__init__(observation_space, action_space) 14 | 15 | def sample_actions(self, obs_B_Do, deterministic=True): 16 | 17 | n_agents = obs_B_Do.shape[0] 18 | actions = np.zeros((n_agents, 4)) 19 | 20 | for i in xrange(n_agents): 21 | a = np.zeros(4) 22 | s = obs_B_Do[i] 23 | state = STAY_ON_ONE_LEG 24 | moving_leg = 0 25 | supporting_leg = 1 - moving_leg 26 | supporting_knee_angle = SUPPORT_KNEE_ANGLE 27 | 28 | contact0 = s[8] 29 | contact1 = s[13] 30 | moving_s_base = 4 + 5 * moving_leg 31 | supporting_s_base = 4 + 5 * supporting_leg 32 | 33 | hip_targ = [None, None] 
# -0.8 .. +1.1 34 | knee_targ = [None, None] # -0.6 .. +0.9 35 | hip_todo = [0.0, 0.0] 36 | knee_todo = [0.0, 0.0] 37 | 38 | if state == STAY_ON_ONE_LEG: 39 | hip_targ[moving_leg] = 1.1 40 | knee_targ[moving_leg] = -0.6 41 | supporting_knee_angle += 0.03 42 | if s[2] > SPEED: 43 | supporting_knee_angle += 0.03 44 | supporting_knee_angle = min(supporting_knee_angle, SUPPORT_KNEE_ANGLE) 45 | knee_targ[supporting_leg] = supporting_knee_angle 46 | if s[supporting_s_base + 0] < 0.10: # supporting leg is behind 47 | state = PUT_OTHER_DOWN 48 | if state == PUT_OTHER_DOWN: 49 | hip_targ[moving_leg] = +0.1 50 | knee_targ[moving_leg] = SUPPORT_KNEE_ANGLE 51 | knee_targ[supporting_leg] = supporting_knee_angle 52 | if s[moving_s_base + 4]: 53 | state = PUSH_OFF 54 | supporting_knee_angle = min(s[moving_s_base + 2], SUPPORT_KNEE_ANGLE) 55 | if state == PUSH_OFF: 56 | knee_targ[moving_leg] = supporting_knee_angle 57 | knee_targ[supporting_leg] = +1.0 58 | if s[supporting_s_base + 2] > 0.88 or s[2] > 1.2 * SPEED: 59 | state = STAY_ON_ONE_LEG 60 | moving_leg = 1 - moving_leg 61 | supporting_leg = 1 - moving_leg 62 | 63 | if hip_targ[0]: 64 | hip_todo[0] = 0.9 * (hip_targ[0] - s[4]) - 0.25 * s[5] 65 | if hip_targ[1]: 66 | hip_todo[1] = 0.9 * (hip_targ[1] - s[9]) - 0.25 * s[10] 67 | if knee_targ[0]: 68 | knee_todo[0] = 4.0 * (knee_targ[0] - s[6]) - 0.25 * s[7] 69 | if knee_targ[1]: 70 | knee_todo[1] = 4.0 * (knee_targ[1] - s[11]) - 0.25 * s[12] 71 | 72 | hip_todo[0] -= 0.9 * (0 - s[0]) - 1.5 * s[1] # PID to keep head strait 73 | hip_todo[1] -= 0.9 * (0 - s[0]) - 1.5 * s[1] 74 | knee_todo[0] -= 15.0 * s[3] # vertical speed, to damp oscillations 75 | knee_todo[1] -= 15.0 * s[3] 76 | 77 | a[0] = hip_todo[0] 78 | a[1] = knee_todo[0] 79 | a[2] = hip_todo[1] 80 | a[3] = knee_todo[1] 81 | a = np.clip(0.5 * a, -1.0, 1.0) 82 | 83 | actions[i, :] = a 84 | 85 | fake_actiondist = np.concatenate([np.zeros((n_agents, 4)), np.ones((n_agents, 4))]) 86 | return actions, fake_actiondist 87 | 88 | 89 | if __name__ == '__main__': 90 | from madrl_environments.walker.multi_walker import MultiWalkerEnv 91 | from vis import Visualizer 92 | import pprint 93 | env = MultiWalkerEnv(n_walkers=3) 94 | train_args = {'discount': 0.99, 'control': 'decentralized'} 95 | vis = Visualizer(env, train_args, 500, 1, True, 'heuristic') 96 | 97 | rew, info = vis(None, hpolicy=MultiWalkerHeuristicPolicy(env.agents[0].observation_space, 98 | env.agents[0].observation_space)) 99 | 100 | pprint.pprint(rew) 101 | pprint.pprint(info) 102 | -------------------------------------------------------------------------------- /heuristics/pursuit.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math as m 3 | 4 | from rltools.policy import Policy 5 | 6 | LEFT = 0 7 | RIGHT = 1 8 | UP = 2 9 | DOWN = 3 10 | STAY = 4 11 | 12 | 13 | class PursuitHeuristicPolicy(Policy): 14 | 15 | def __init__(self, observation_space, action_space): 16 | super(PursuitHeuristicPolicy, self).__init__(observation_space, action_space) 17 | 18 | def sample_actions(self, obs_B_Do, deterministic=True): 19 | n_ev = np.sum(obs_B_Do[..., 2]) 20 | n_pr = np.sum(obs_B_Do[..., 1]) 21 | xs, ys = obs_B_Do.shape[0], obs_B_Do.shape[1] 22 | 23 | x, y = xs / 2, ys / 2 24 | 25 | # if see evader move to it 26 | if n_ev > 0: 27 | xev, yev = np.nonzero(obs_B_Do[..., 2]) 28 | d = np.sqrt((xev - x)**2 + (yev - y)**2) 29 | midx = np.argmin(d) 30 | xc, yc = xev[midx], yev[midx] 31 | if xc == x and yc == y: 32 | return STAY, None 33 | ang = 
m.atan2(yc - y, xc - x) 34 | ang = (ang + np.pi) % (2 * np.pi) - np.pi 35 | # Right 36 | if -np.pi / 4 <= ang < np.pi / 4: 37 | return RIGHT, None 38 | # Up 39 | elif np.pi / 4 <= ang < 3 / 4. * np.pi: 40 | return UP, None 41 | # Left 42 | elif ang >= 3 / 4. * np.pi or ang < -3 / 4. * np.pi: 43 | return LEFT, None 44 | # Down 45 | elif -3 / 4. * np.pi <= ang < -np.pi / 4: 46 | return DOWN, None 47 | else: 48 | return self.action_space.sample(), None 49 | 50 | return self.action_space.sample(), None 51 | 52 | def get_state(self): 53 | return [] 54 | 55 | def set_state(self, *args): 56 | pass 57 | 58 | 59 | if __name__ == '__main__': 60 | from madrl_environments.pursuit import PursuitEvade 61 | from madrl_environments.pursuit.utils import * 62 | from vis import Visualizer 63 | import pprint 64 | map_mat = TwoDMaps.rectangle_map(16, 16) 65 | 66 | env = PursuitEvade([map_mat], n_evaders=30, n_pursuers=8, obs_range=7, n_catch=4, 67 | surround=False, flatten=False) 68 | 69 | policy = PursuitHeuristicPolicy(env.agents[0].observation_space, env.agents[0].action_space) 70 | 71 | for i in range(20): 72 | rew = 0.0 73 | obs = env.reset() 74 | infolist = [] 75 | for _ in xrange(500): 76 | # env.render() 77 | act_list = [] 78 | for o in obs: 79 | a, _ = policy.sample_actions(o) 80 | act_list.append(a) 81 | obs, r, done, info = env.step(act_list) 82 | rew += np.mean(r) 83 | infolist.append(info) 84 | if done: 85 | break 86 | 87 | pprint.pprint(rew / 20) 88 | pprint.pprint(info) 89 | -------------------------------------------------------------------------------- /heuristics/waterworld.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from rltools.policy import Policy 4 | 5 | 6 | class WaterworldHeuristicPolicy(Policy): 7 | 8 | def __init__(self, observation_space, action_space): 9 | super(WaterworldHeuristicPolicy, self).__init__(observation_space, action_space) 10 | 11 | def sample_actions(self, obs_B_Do, deterministic=True): 12 | # Obs space 13 | # 0:K -> obdist 14 | # K:2K -> evdist 15 | # 2K:3K -> evspeed 16 | # 3K:4K -> podist 17 | # 4K:5K -> pospeed 18 | # 5K:6K -> pudist (allies) 19 | # 6K:7K -> puspeed (allies) 20 | # 7K -> colliding with ev 21 | # 7K+1 -> colliding with po 22 | # [7K+2] -> id 23 | 24 | # Act space 25 | # 0 -> x, 1-> y acceleration 26 | B = obs_B_Do.shape[0] 27 | K = obs_B_Do.shape[1] // 7 28 | angles_K = np.linspace(0., 2. 
* np.pi, K + 1)[:-1] 29 | 30 | vecs_K = np.c_[np.cos(angles_K), np.sin(angles_K)] 31 | 32 | obs_avoidance_action_B_2 = -np.sum(obs_B_Do[:, 0:K][..., None] * np.expand_dims(vecs_K, 0), 33 | axis=1) 34 | 35 | ev_catch_action_B_2 = np.sum(obs_B_Do[:, K:2 * K][..., None] * np.expand_dims(vecs_K, 0), 36 | axis=1) 37 | 38 | po_avoidance_action_B_2 = -np.sum(obs_B_Do[:, 3 * K:4 * K][..., None] * 39 | np.expand_dims(vecs_K, 0), axis=1) 40 | 41 | pu_closer_action_B_2 = np.sum(obs_B_Do[:, 5 * K:6 * K][..., None] * 42 | np.expand_dims(vecs_K, 0), axis=1) / 2 43 | 44 | ev_catch_action_B_2[obs_B_Do[:, 7 * K] > 0] *= 1.5 45 | po_avoidance_action_B_2[obs_B_Do[:, 7 * K + 1] > 0] *= 1.5 46 | 47 | actions_B_2 = obs_avoidance_action_B_2 + ev_catch_action_B_2 + po_avoidance_action_B_2 + pu_closer_action_B_2 48 | norm = np.linalg.norm(actions_B_2) 49 | if norm > 0: 50 | actions_B_2 /= norm 51 | else: 52 | actions_B_2 = np.zeros((B, 2)) 53 | 54 | # actions_B_2 = np.random.randn(B, 2) 55 | fake_actiondist = np.concatenate([np.zeros((B, 2)), np.ones((B, 2))]) 56 | return actions_B_2, fake_actiondist 57 | 58 | def get_state(self): 59 | return [] 60 | 61 | def set_state(self, *args): 62 | pass 63 | 64 | 65 | if __name__ == '__main__': 66 | from madrl_environments.pursuit import MAWaterWorld 67 | from vis import Visualizer 68 | import pprint 69 | env = MAWaterWorld(n_evaders=10, n_pursuers=8, n_poison=10, n_coop=4, n_sensors=30, 70 | food_reward=10, poison_reward=-1, encounter_reward=0.01) 71 | train_args = {'discount': 0.99, 'control': 'decentralized'} 72 | 73 | vis = Visualizer(env, train_args, 500, 1, True, 'heuristic') 74 | 75 | rew, info = vis(None, hpolicy=WaterworldHeuristicPolicy(env.agents[0].observation_space, 76 | env.agents[0].action_space)) 77 | pprint.pprint(rew) 78 | pprint.pprint(info) 79 | -------------------------------------------------------------------------------- /lessons/multiant/env.yaml: -------------------------------------------------------------------------------- 1 | tasks: 2 | l1: 3 | n_legs: 3 4 | 5 | l2: 6 | n_legs: 4 7 | 8 | l3: 9 | n_legs: 5 10 | 11 | l4: 12 | n_legs: 6 13 | 14 | l5: 15 | n_legs: 7 16 | 17 | l6: 18 | n_legs: 8 19 | 20 | l7: 21 | n_legs: 9 22 | 23 | l8: 24 | n_legs: 10 25 | 26 | thresholds: 27 | lesson: 200 28 | stop: 450 29 | 30 | n_trials: 1 31 | eval_trials: 20 32 | metric: ret 33 | -------------------------------------------------------------------------------- /lessons/multiwalker/env.yaml: -------------------------------------------------------------------------------- 1 | tasks: 2 | l1: 3 | n_walkers: 2 4 | 5 | l2: 6 | n_walkers: 3 7 | 8 | l3: 9 | n_walkers: 4 10 | 11 | l4: 12 | n_walkers: 5 13 | 14 | l5: 15 | n_walkers: 6 16 | 17 | l6: 18 | n_walkers: 7 19 | 20 | l7: 21 | n_walkers: 8 22 | 23 | l8: 24 | n_walkers: 9 25 | 26 | l9: 27 | n_walkers: 10 28 | 29 | thresholds: 30 | lesson: 10 31 | stop: 20 32 | 33 | n_trials: 20 34 | eval_trials: 20 35 | metric: ret 36 | -------------------------------------------------------------------------------- /madrl_environments/mujoco/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sisl/MADRL/9ea39a0fe8695b391008a4eb7bda9fe4438a96de/madrl_environments/mujoco/__init__.py -------------------------------------------------------------------------------- /madrl_environments/mujoco/ant/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sisl/MADRL/9ea39a0fe8695b391008a4eb7bda9fe4438a96de/madrl_environments/mujoco/ant/__init__.py -------------------------------------------------------------------------------- /madrl_environments/mujoco/ant/ant_og.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 81 | -------------------------------------------------------------------------------- /madrl_environments/mujoco/ant/multi_ant.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 81 | -------------------------------------------------------------------------------- /madrl_environments/pursuit/__init__.py: -------------------------------------------------------------------------------- 1 | from .pursuit_evade import PursuitEvade 2 | from .utils import RandomPolicy, SingleActionPolicy, TwoDMaps 3 | from .waterworld import MAWaterWorld 4 | -------------------------------------------------------------------------------- /madrl_environments/pursuit/eval_scripts/policy_eval.py: -------------------------------------------------------------------------------- 1 | # compare each policy (including fully obs) 2 | 3 | from madrl_environments import CentralizedPursuitEvade 4 | from madrl_environments.pursuit import TwoDMaps 5 | 6 | from os.path import join 7 | import matplotlib.pyplot as plt 8 | import tensorflow as tf 9 | import numpy as np 10 | 11 | 12 | from rltools.algs import TRPOSolver 13 | from rltools.utils import evaluate 14 | from rltools.utils import simulate 15 | from rltools.models import softmax_mlp 16 | 17 | ################################################################# 18 | ######################## PARAMS ################################# 19 | ################################################################# 20 | layers = [128,64,64] # or [128,128] 21 | model_path = "../data/obs_range_sweep_3layer" 22 | plot_title = "Evaluation of Three Layer Centralized Controller" 23 | save_path = "results/three_layer_centralized_eval.pdf" 24 | 25 | 26 | xs = 10 27 | ys = 10 28 | n_evaders = 5 29 | n_pursuers = 2 30 | 31 | map_mat = TwoDMaps.rectangle_map(xs, ys) 32 | 33 | runs = [1] 34 | ranges = [3, 5, 7, 9] 35 | oranges = np.array([ranges for r in runs]).flatten() 36 | model_paths = [join("../data/obs_range_sweep_2layer_random_evaders_urgency", "obs_range_"+str(o), "run"+str(r), "final_model.ckpt") for o in oranges for r in runs] 37 | 38 | n_traj = 100 39 | max_steps = 250 40 | 41 | results = np.zeros((len(ranges), len(runs), 2, 3)) 42 | 43 | 44 | best_stds_stoch = np.zeros(len(ranges)) 45 | best_means_stoch = np.zeros(len(ranges)) 46 | best_stds_det = np.zeros(len(ranges)) 47 | best_means_det = np.zeros(len(ranges)) 48 | pidx = 0 49 | for i, ran in enumerate(ranges): 50 | for j in xrange(len(runs)): 51 | print model_paths[pidx] 52 | 53 | env = CentralizedPursuitEvade(map_mat, n_evaders=n_evaders, n_pursuers=n_pursuers, obs_range=ran, n_catch=2) 54 | 55 | input_obs = tf.placeholder(tf.float32, shape=(None,) + env.observation_space.shape, name="obs") 56 | net = softmax_mlp(input_obs, env.action_space.n, layers=[128,128], activation=tf.nn.tanh) 57 | solver = TRPOSolver(env, policy_net=net, input_layer=input_obs) 58 | solver.load(model_paths[pidx]) 59 | 60 | solver.train = True # stochastic policy 61 | r = evaluate(env, solver, max_steps, n_traj) 62 | smean, sstd = r.mean(), r.std() 63 | results[i, j, 0] = [n_traj, r.mean(), r.std()] 64 | 65 | solver.train = False # deterministic policy 66 | r = evaluate(env, 
solver, max_steps, n_traj) 67 | dmean, dstd = r.mean(), r.std() 68 | results[i, j, 1] = [n_traj, r.mean(), r.std()] 69 | if smean > best_means_stoch[i]: 70 | best_means_stoch[i] = smean 71 | best_stds_stoch[i] = sstd 72 | best_means_det[i] = dmean 73 | best_stds_det[i] = dstd 74 | 75 | pidx += 1 76 | tf.reset_default_graph() 77 | 78 | 79 | plt.figure(1) 80 | 81 | plt.plot(ranges, best_means_stoch, 'k', color='#CC4F1B', label='Stochastic Policy') 82 | plt.fill_between(ranges, best_means_stoch-best_stds_stoch, best_means_stoch+best_stds_stoch, alpha=0.5, edgecolor='#CC4F1B', facecolor='#FF9848') 83 | 84 | plt.plot(ranges, best_means_det, 'k', color='#1B2ACC', label='Deterministic Policy') 85 | plt.fill_between(ranges, best_means_det-best_stds_det, best_means_det+best_stds_det, alpha=0.2, edgecolor='#1B2ACC', facecolor='#089FFF', linewidth=4, linestyle='dashdot', antialiased=True) 86 | 87 | #plt.ylim([3,5.5]) 88 | 89 | plt.xlabel('Observation Ranges') 90 | plt.ylabel('Average Rewards') 91 | plt.title('Centralized Pursuit Policy Evaluation') 92 | plt.grid() 93 | plt.legend(loc=4) 94 | plt.savefig('results/policy_eval/two_layer_centralized_eval_random_evaders_urgency_best.pdf') 95 | 96 | """ 97 | stoch_mean = [] 98 | stoch_best = [] 99 | stoch_worst = [] 100 | stoch_std = [] 101 | det_mean = [] 102 | det_best = [] 103 | det_worst = [] 104 | det_std = [] 105 | for i in xrange(len(ranges)): 106 | stoch_mean.append(results[i,:,0,1].mean()) 107 | stoch_std.append(results[i,:,0,1].std()/np.sqrt(n_traj)) 108 | stoch_best.append(results[i,:,0,1].max()) 109 | stoch_worst.append(results[i,:,0,1].min()) 110 | det_mean.append(results[i,:,1,1].mean()) 111 | det_std.append(results[i,:,1,1].std()/np.sqrt(n_traj)) 112 | det_best.append(results[i,:,1,1].max()) 113 | det_worst.append(results[i,:,1,1].min()) 114 | 115 | stoch_mean = np.array(stoch_mean) 116 | stoch_std = np.array(stoch_std) 117 | det_mean = np.array(det_mean) 118 | det_std = np.array(det_std) 119 | 120 | plt.plot(ranges, stoch_mean, 'k', color='#CC4F1B', label='Stochastic Policy') 121 | plt.fill_between(ranges, stoch_mean-stoch_std, stoch_mean+stoch_std, alpha=0.5, edgecolor='#CC4F1B', facecolor='#FF9848') 122 | 123 | plt.plot(ranges, det_mean, 'k', color='#1B2ACC', label='Deterministic Policy') 124 | plt.fill_between(ranges, det_mean-det_std, det_mean+det_std, alpha=0.2, edgecolor='#1B2ACC', facecolor='#089FFF', linewidth=4, linestyle='dashdot', antialiased=True) 125 | 126 | plt.ylim([3,5.5]) 127 | 128 | plt.xlabel('Observation Ranges') 129 | plt.ylabel('Average Rewards') 130 | plt.title('Evaluation of Three Layer Centralized Controller') 131 | plt.grid() 132 | plt.legend() 133 | 134 | plt.savefig('results/three_layer_centralized_eval.pdf') 135 | """ 136 | -------------------------------------------------------------------------------- /madrl_environments/pursuit/eval_scripts/speed_pursuit.py: -------------------------------------------------------------------------------- 1 | # test average timesteps to catch one evader 2 | 3 | from madrl_environments import CentralizedPursuitEvade 4 | from madrl_environments.pursuit import TwoDMaps 5 | 6 | from os.path import join 7 | import matplotlib.pyplot as plt 8 | import tensorflow as tf 9 | import numpy as np 10 | 11 | 12 | from rltools.algs import TRPOSolver 13 | from rltools.utils import evaluate_time 14 | from rltools.utils import simulate 15 | from rltools.models import softmax_mlp 16 | 17 | xs = 10 18 | ys = 10 19 | n_evaders = 1 20 | n_pursuers = 2 21 | 22 | map_mat = TwoDMaps.rectangle_map(xs, 
ys) 23 | 24 | runs = [1, 2, 3, 4, 5] 25 | ranges = [3, 5, 7, 9] 26 | save_path = "results/speed_eval/two_layer_centralized_eval_random_evaders_urgency_best.pdf" 27 | oranges = np.array([ranges for r in runs]).flatten() 28 | model_paths = [join("../data/obs_range_sweep_2layer_random_evaders_urgency", "obs_range_"+str(o), "run"+str(r), "final_model.ckpt") for o in oranges for r in runs] 29 | 30 | n_traj = 100 31 | max_steps = 500 32 | 33 | results = np.zeros((len(ranges), len(runs), 2, 3)) 34 | 35 | 36 | pidx = 0 37 | best_stds_stoch = np.zeros(len(ranges)) + 1e6 38 | best_means_stoch = np.zeros(len(ranges)) + 1e6 39 | best_stds_det = np.zeros(len(ranges)) + 1e6 40 | best_means_det = np.zeros(len(ranges)) + 1e6 41 | 42 | 43 | for i, ran in enumerate(ranges): 44 | for j in xrange(len(runs)): 45 | print model_paths[pidx] 46 | 47 | env = CentralizedPursuitEvade(map_mat, n_evaders=n_evaders, n_pursuers=n_pursuers, obs_range=ran, n_catch=2) 48 | 49 | input_obs = tf.placeholder(tf.float32, shape=(None,) + env.observation_space.shape, name="obs") 50 | net = softmax_mlp(input_obs, env.action_space.n, layers=[128,128], activation=tf.nn.tanh) 51 | solver = TRPOSolver(env, policy_net=net, input_layer=input_obs) 52 | solver.load(model_paths[pidx]) 53 | 54 | solver.train = True # stochastic policy 55 | r, t = evaluate_time(env, solver, max_steps, n_traj) 56 | smean, sstd = t.mean(), t.std() 57 | results[i, j, 0] = [n_traj, t.mean(), t.std()] 58 | solver.train = False # deterministic policy 59 | r, t = evaluate_time(env, solver, max_steps, n_traj) 60 | dmean, dstd = t.mean(), t.std() 61 | results[i, j, 1] = [n_traj, t.mean(), t.std()] 62 | if smean < best_means_stoch[i]: 63 | best_means_stoch[i] = smean 64 | best_stds_stoch[i] = sstd 65 | best_means_det[i] = dmean 66 | best_stds_det[i] = dstd 67 | 68 | pidx += 1 69 | tf.reset_default_graph() 70 | 71 | plt.figure(1) 72 | 73 | plt.plot(ranges, best_means_stoch, 'k', color='#CC4F1B', label='Stochastic Policy') 74 | plt.fill_between(ranges, best_means_stoch-best_stds_stoch, best_means_stoch+best_stds_stoch, alpha=0.5, edgecolor='#CC4F1B', facecolor='#FF9848') 75 | 76 | plt.plot(ranges, best_means_det, 'k', color='#1B2ACC', label='Deterministic Policy') 77 | plt.fill_between(ranges, best_means_det-best_stds_det, best_means_det+best_stds_det, alpha=0.2, edgecolor='#1B2ACC', facecolor='#089FFF', linewidth=4, linestyle='dashdot', antialiased=True) 78 | 79 | #plt.ylim([3,5.5]) 80 | 81 | plt.xlabel('Observation Ranges') 82 | plt.ylabel('Time to Catch Single Evader') 83 | plt.title('Centralized Pursuit Policy Speed Evaluation') 84 | plt.grid() 85 | plt.legend() 86 | plt.ylim([0,90]) 87 | plt.savefig(save_path) 88 | 89 | plt.clf() 90 | 91 | 92 | """ 93 | stoch_mean = [] 94 | stoch_best = [] 95 | stoch_worst = [] 96 | stoch_std = [] 97 | det_mean = [] 98 | det_best = [] 99 | det_worst = [] 100 | det_std = [] 101 | for i in xrange(len(ranges)): 102 | stoch_mean.append(results[i,:,0,1].mean()) 103 | stoch_std.append(results[i,:,0,1].std()/np.sqrt(n_traj)) 104 | stoch_best.append(results[i,:,0,1].max()) 105 | stoch_worst.append(results[i,:,0,1].min()) 106 | det_mean.append(results[i,:,1,1].mean()) 107 | det_std.append(results[i,:,1,1].std()/np.sqrt(n_traj)) 108 | det_best.append(results[i,:,1,1].max()) 109 | det_worst.append(results[i,:,1,1].min()) 110 | 111 | stoch_mean = np.array(stoch_mean) 112 | stoch_std = np.array(stoch_std) 113 | det_mean = np.array(det_mean) 114 | det_std = np.array(det_std) 115 | 116 | plt.figure(2) 117 | 118 | plt.plot(ranges, stoch_mean, 'k', 
color='#CC4F1B', label='Stochastic Policy') 119 | plt.fill_between(ranges, stoch_mean-stoch_std, stoch_mean+stoch_std, alpha=0.5, edgecolor='#CC4F1B', facecolor='#FF9848') 120 | 121 | plt.plot(ranges, det_mean, 'k', color='#1B2ACC', label='Deterministic Policy') 122 | plt.fill_between(ranges, det_mean-det_std, det_mean+det_std, alpha=0.2, edgecolor='#1B2ACC', facecolor='#089FFF', linewidth=4, linestyle='dashdot', antialiased=True) 123 | 124 | #plt.ylim([3,5.5]) 125 | 126 | plt.xlabel('Observation Ranges') 127 | plt.ylabel('Time to Catch Single Evader') 128 | plt.title('Centralized Pursuit Policy Speed Evaluation') 129 | plt.grid() 130 | plt.legend() 131 | 132 | #plt.savefig('results/three_layer_centralized_eval.pdf') 133 | """ 134 | -------------------------------------------------------------------------------- /madrl_environments/pursuit/eval_scripts/stationary_eval.py: -------------------------------------------------------------------------------- 1 | # evalaute against stationary evaders 2 | 3 | from madrl_environments import CentralizedPursuitEvade 4 | from madrl_environments.pursuit import TwoDMaps 5 | from madrl_environments.pursuit import SingleActionPolicy 6 | 7 | from os.path import join 8 | import matplotlib.pyplot as plt 9 | import tensorflow as tf 10 | import numpy as np 11 | 12 | 13 | from rltools.algs import TRPOSolver 14 | from rltools.utils import evaluate_time 15 | from rltools.utils import simulate 16 | from rltools.models import softmax_mlp 17 | 18 | xs = 10 19 | ys = 10 20 | n_evaders = 1 21 | n_pursuers = 2 22 | 23 | map_mat = TwoDMaps.rectangle_map(xs, ys) 24 | 25 | evader_policy = SingleActionPolicy(4) # stationary action 26 | 27 | runs = [4, 1, 3, 1] 28 | ranges = [3, 5, 7, 9] 29 | oranges = np.array([ranges for r in runs]).flatten() 30 | model_paths = [join("../data/obs_range_sweep_2layer", "obs_range_"+str(o), "run"+str(r), "final_model.ckpt") for o, r in zip(oranges, runs)] 31 | 32 | n_traj = 200 33 | max_steps = 500 34 | 35 | results = np.zeros((len(ranges), 2, 3)) 36 | 37 | count = 0 38 | for ran, path in zip(ranges, model_paths): 39 | print path 40 | 41 | env = CentralizedPursuitEvade(map_mat, n_evaders=n_evaders, n_pursuers=n_pursuers, obs_range=ran, n_catch=2, 42 | evader_controller=evader_policy) 43 | 44 | input_obs = tf.placeholder(tf.float32, shape=(None,) + env.observation_space.shape, name="obs") 45 | net = softmax_mlp(input_obs, env.action_space.n, layers=[128,128], activation=tf.nn.tanh) 46 | solver = TRPOSolver(env, policy_net=net, input_layer=input_obs) 47 | solver.load(path) 48 | 49 | solver.train = True # stochastic policy 50 | r, t = evaluate_time(env, solver, max_steps, n_traj) 51 | results[count, 0] = [n_traj, t.mean(), t.std()] 52 | solver.train = False # deterministic policy 53 | r, t = evaluate_time(env, solver, max_steps, n_traj) 54 | results[count, 1] = [n_traj, t.mean(), t.std()] 55 | 56 | count += 1 57 | tf.reset_default_graph() 58 | 59 | m = results[:,0,1] 60 | sig = results[:,0,2] / np.sqrt(n_traj) 61 | plt.plot(ranges, m, 'k', color='#CC4F1B', label='Stochastic Policy') 62 | plt.fill_between(ranges, m-sig, m+sig, alpha=0.5, edgecolor='#CC4F1B', facecolor='#FF9848') 63 | 64 | m = results[:,1,1] 65 | sig = results[:,1,2] / np.sqrt(n_traj) 66 | plt.plot(ranges, m, 'k', color='#1B2ACC', label='Deterministic Policy') 67 | plt.fill_between(ranges, m-sig, m+sig, alpha=0.2, edgecolor='#1B2ACC', facecolor='#089FFF', linewidth=4, linestyle='dashdot', antialiased=True) 68 | 69 | #plt.ylim([3,5.5]) 70 | 71 | plt.xlabel('Observation Ranges') 72 
| plt.ylabel('Time to Find Evader') 73 | plt.title('Centralized Pursuit Policy with Stationary Evader') 74 | 75 | plt.ylim([0,160]) 76 | plt.grid() 77 | plt.legend() 78 | plt.savefig('results/stationary_eval/two_layer_centralized_eval_best.pdf') 79 | -------------------------------------------------------------------------------- /madrl_environments/pursuit/test_pursuit.py: -------------------------------------------------------------------------------- 1 | from pursuit_evade import PursuitEvade 2 | import gym 3 | 4 | from utils import * 5 | 6 | xs = 5 7 | ys = 5 8 | n_evaders = 1 9 | n_pursuers = 2 10 | 11 | map_mat = TwoDMaps.rectangle_map(xs, ys) 12 | 13 | # obs_range should be odd 3, 5, 7, etc 14 | env = PursuitEvade([map_mat], n_evaders=n_evaders, n_pursuers=n_pursuers, obs_range=3, n_catch=2, surround=False, 15 | reward_mech='local') 16 | 17 | o = env.reset() 18 | 19 | """ 20 | a = [4]*n_pursuers 21 | 22 | env.pursuer_layer.set_position(0, 7, 1) 23 | env.pursuer_layer.set_position(1, 8, 0) 24 | env.pursuer_layer.set_position(2, 9, 1) 25 | env.pursuer_layer.set_position(3, 8, 2) 26 | 27 | 28 | env.pursuer_layer.set_position(4, 0, 2) 29 | env.pursuer_layer.set_position(5, 0, 4) 30 | env.pursuer_layer.set_position(6, 1, 3) 31 | 32 | env.pursuer_layer.set_position(7, 3, 4) 33 | env.pursuer_layer.set_position(8, 2, 5) 34 | env.pursuer_layer.set_position(9, 3, 6) 35 | 36 | 37 | env.pursuer_layer.set_position(10, 9, 9) 38 | env.pursuer_layer.set_position(11, 9, 9) 39 | 40 | env.evader_layer.set_position(0, 8, 1) 41 | env.evader_layer.set_position(1, 8, 1) 42 | 43 | env.evader_layer.set_position(2, 0, 3) 44 | #env.evader_layer.set_position(3, 0, 3) 45 | 46 | env.evader_layer.set_position(3, 3, 5) 47 | #env.evader_layer.set_position(4, 3, 5) 48 | #env.evader_layer.set_position(5, 3, 5) 49 | 50 | 51 | env.evader_layer.set_position(4,9,9) 52 | 53 | r = env.reward() 54 | 55 | o, r, done, info = env.step(a) 56 | 57 | #map_mat = multi_scale_map(32, 32) 58 | map_mat = multi_scale_map(32, 32, scales=[(4, [0.2,0.3]), (10, [0.1,0.2])]) 59 | 60 | map_pool = [map_mat] 61 | map_pool = np.load('../../runners/maps/map_pool16.npy') 62 | 63 | env = PursuitEvade(map_pool, n_evaders=10, n_pursuers=10, obs_range=5, n_catch=2, surround=True, reward_mech='local', 64 | sample_maps=True, constraint_window=1.0) 65 | #env.reset() 66 | #env.render() 67 | 68 | 69 | env = PursuitEvade(map_pool, n_evaders=10, n_pursuers=3, obs_range=9, n_catch=2, surround=False, 70 | reward_mech='local') 71 | 72 | """ 73 | 74 | -------------------------------------------------------------------------------- /madrl_environments/pursuit/utils/AgentLayer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | ################################################################# 4 | # Implements a Cooperating Agent Layer for 2D problems 5 | ################################################################# 6 | 7 | class AgentLayer(): 8 | 9 | # constructor 10 | def __init__(self, 11 | xs, # x size of map 12 | ys, # y size of map 13 | allies, # list of ally agents 14 | seed=1): # should we have a seeds array for each agent? 
15 | """ 16 | Each ally agent must support: 17 | - move(action) 18 | - current_position() 19 | - nactions() 20 | - set_position(x, y) 21 | """ 22 | 23 | self.allies = allies 24 | self.nagents = len(allies) 25 | self.global_state = np.zeros((xs, ys), dtype=np.int32) 26 | 27 | def n_agents(self): 28 | return self.nagents 29 | 30 | def move_agent(self, agent_idx, action): 31 | return self.allies[agent_idx].step(action) 32 | 33 | def set_position(self, agent_idx, x, y): 34 | self.allies[agent_idx].set_position(x,y) 35 | 36 | def get_position(self, agent_idx): 37 | """ 38 | Returns the position of the given agent 39 | """ 40 | return self.allies[agent_idx].current_position() 41 | 42 | def get_nactions(self, agent_idx): 43 | return self.allies[agent_idx].nactions() 44 | 45 | def remove_agent(self, agent_idx): 46 | # idx is between zero and nagents 47 | self.allies.pop(agent_idx) 48 | self.nagents -= 1 49 | 50 | def get_state_matrix(self): 51 | """ 52 | Returns a matrix representing the positions of all allies 53 | Example: matrix contains the number of allies at give (x,y) position 54 | 0 0 0 1 0 0 0 55 | 0 2 0 2 0 0 0 56 | 0 0 0 0 0 0 1 57 | 1 0 0 0 0 0 5 58 | """ 59 | gs = self.global_state 60 | gs.fill(0) 61 | for ally in self.allies: 62 | x, y = ally.current_position() 63 | gs[x,y] += 1 64 | return gs 65 | 66 | def get_state(self): 67 | pos = np.zeros(2*len(self.allies)) 68 | idx = 0 69 | for ally in self.allies: 70 | pos[idx:(idx+2)] = ally.get_state() 71 | idx += 2 72 | return pos 73 | 74 | -------------------------------------------------------------------------------- /madrl_environments/pursuit/utils/Controllers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | ################################################################# 4 | # Implements multi-agent controllers 5 | ################################################################# 6 | 7 | 8 | class RandomPolicy(object): 9 | 10 | # constructor 11 | def __init__(self, n_actions, rng = np.random.RandomState()): 12 | self.rng = rng 13 | self.n_actions = n_actions 14 | 15 | def act(self, state): 16 | return self.rng.randint(self.n_actions) 17 | 18 | 19 | class SingleActionPolicy(object): 20 | 21 | def __init__(self, a): 22 | self.action = a 23 | 24 | 25 | def act(self, state): 26 | return self.action 27 | -------------------------------------------------------------------------------- /madrl_environments/pursuit/utils/DiscreteAgent.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from gym import spaces 4 | from madrl_environments import Agent 5 | 6 | 7 | ################################################################# 8 | # Implements the Single 2D Agent Dynamics 9 | ################################################################# 10 | 11 | class DiscreteAgent(Agent): 12 | 13 | # constructor 14 | def __init__(self, 15 | xs, 16 | ys, 17 | map_matrix, # the map of the environemnt (-1 are buildings) 18 | obs_range=3, 19 | n_channels=3, # number of observation channels 20 | seed=1, 21 | flatten=False): 22 | 23 | self.random_state = np.random.RandomState(seed) 24 | 25 | self.xs = xs 26 | self.ys = ys 27 | 28 | self.eactions = [0, # move left 29 | 1, # move right 30 | 2, # move up 31 | 3, # move down 32 | 4] # stay 33 | 34 | self.motion_range = [[-1, 0], 35 | [1, 0], 36 | [0, 1], 37 | [0, -1], 38 | [0, 0]] 39 | 40 | self.current_pos = np.zeros(2, dtype=np.int32) # x and y position 41 | self.last_pos = 
np.zeros(2, dtype=np.int32) 42 | self.temp_pos = np.zeros(2, dtype=np.int32) 43 | 44 | self.map_matrix = map_matrix 45 | 46 | self.terminal = False 47 | 48 | self._obs_range = obs_range 49 | 50 | if flatten: 51 | self._obs_shape = (n_channels * obs_range**2 + 1,) 52 | else: 53 | self._obs_shape = (obs_range, obs_range, 4) 54 | #self._obs_shape = (4, obs_range, obs_range) 55 | 56 | 57 | @property 58 | def observation_space(self): 59 | return spaces.Box(low=-np.inf, high=np.inf, shape=self._obs_shape) 60 | 61 | @property 62 | def action_space(self): 63 | return spaces.Discrete(5) 64 | 65 | 66 | ################################################################# 67 | # Dynamics Functions 68 | ################################################################# 69 | def step(self, a): 70 | cpos = self.current_pos 71 | lpos = self.last_pos 72 | # if dead or reached goal dont move 73 | if self.terminal: 74 | return cpos 75 | # if in building, dead, and stay there 76 | if self.inbuilding(cpos[0], cpos[1]): 77 | self.terminal = True 78 | return cpos 79 | tpos = self.temp_pos 80 | tpos[0] = cpos[0] 81 | tpos[1] = cpos[1] 82 | # transition is deterministic 83 | tpos += self.motion_range[a] 84 | x = tpos[0] 85 | y = tpos[1] 86 | # check bounds 87 | if not self.inbounds(x, y): 88 | return cpos 89 | # if bumped into building, then stay 90 | if self.inbuilding(x, y): 91 | return cpos 92 | else: 93 | lpos[0] = cpos[0] 94 | lpos[1] = cpos[1] 95 | cpos[0] = x 96 | cpos[1] = y 97 | return cpos 98 | 99 | def get_state(self): 100 | return self.current_pos 101 | 102 | ################################################################# 103 | # Helper Functions 104 | ################################################################# 105 | def inbounds(self, x, y): 106 | if 0 <= x < self.xs and 0 <= y < self.ys: 107 | return True 108 | return False 109 | 110 | def inbuilding(self, x, y): 111 | if self.map_matrix[x,y] == -1: 112 | return True 113 | return False 114 | 115 | def nactions(self): 116 | return len(self.eactions) 117 | 118 | def set_position(self, xs, ys): 119 | self.current_pos[0] = xs 120 | self.current_pos[1] = ys 121 | 122 | def current_position(self): 123 | return self.current_pos 124 | 125 | def last_position(self): 126 | return self.last_pos 127 | 128 | -------------------------------------------------------------------------------- /madrl_environments/pursuit/utils/TwoDMaps.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from six.moves import xrange 4 | 5 | from scipy.ndimage import zoom 6 | 7 | 8 | def rectangle_map(xs, ys, xb=0.3, yb=0.2): 9 | """ 10 | Returns a 2D 'map' with a rectangle building centered in the middle 11 | Map is a 2D numpy array 12 | xb and yb are buffers for each dim representing the raio of the map to leave open on each side 13 | """ 14 | rmap = np.zeros((xs, ys), dtype=np.int32) 15 | for i in xrange(xs): 16 | for j in xrange(ys): 17 | # are we in the rectnagle in x dim? 18 | if (float(i) / xs) > xb and (float(i) / xs) < (1.0 - xb): 19 | # are we in the rectangle in y dim? 
20 | if (float(j) / ys) > yb and (float(j) / ys) < (1.0 - yb): 21 | rmap[i, j] = -1 # -1 is building pixel flag 22 | return rmap 23 | 24 | 25 | def complex_map(xs, ys): 26 | """ 27 | Returns a 2D 'map' with a four different obstacles 28 | Map is a 2D numpy array 29 | """ 30 | cmap = np.zeros((xs, ys), dtype=np.int32) 31 | cmap = add_rectangle(cmap, xc=0.8, yc=0.5, xl=0.1, yl=0.8) 32 | cmap = add_rectangle(cmap, xc=0.4, yc=0.8, xl=0.5, yl=0.2) 33 | cmap = add_rectangle(cmap, xc=0.5, yc=0.5, xl=0.4, yl=0.2) 34 | cmap = add_rectangle(cmap, xc=0.3, yc=0.1, xl=0.5, yl=0.1) 35 | cmap = add_rectangle(cmap, xc=0.1, yc=0.3, xl=0.1, yl=0.5) 36 | return cmap 37 | 38 | 39 | def gen_map(xs, ys, n_obs, center_bounds=[0.0, 1.0], length_bounds=[0.1,0.5], gmap=None): 40 | cl, cu = center_bounds 41 | ll, lu = length_bounds 42 | if gmap is None: gmap = np.zeros((xs, ys), dtype=np.int32) 43 | for _ in xrange(n_obs): 44 | xc = np.random.uniform(cl, cu) 45 | yc = np.random.uniform(cl, cu) 46 | xl = np.random.uniform(ll, lu) 47 | yl = np.random.uniform(ll, lu) 48 | gmap = add_rectangle(gmap, xc=xc, yc=yc, xl=xl, yl=yl) 49 | return gmap 50 | 51 | def multi_scale_map(xs, ys, scales=[(3, [0.2, 0.3]), (10, [0.1, 0.2]), (30, [0.05, 0.1]), (150, [0.01, 0.05])]): 52 | gmap = np.zeros((xs, ys), dtype=np.int32) 53 | for scale in scales: 54 | n, lb = scale 55 | gmap = gen_map(xs, ys, n, length_bounds=lb, gmap=gmap) 56 | return gmap 57 | 58 | 59 | def add_rectangle(input_map, xc, yc, xl, yl): 60 | """ 61 | Add a rectangle to the input map 62 | centered a xc, yc with dimensions xl, yl. 63 | Input specs are normalized wrt the map. 64 | """ 65 | assert len(input_map.shape) == 2, "input_map must be a numpy matrix" 66 | 67 | xs, ys = input_map.shape 68 | xcc, ycc = int(round(xs * xc)), int(round(ys * yc)) 69 | xll, yll = int(round(xs * xl)), int(round(ys * yl)) 70 | if xll <= 1: 71 | x_lbound, x_upbound = xcc, xcc+1 72 | else: 73 | x_lbound, x_upbound = xcc - xll/2, xcc + xll/2 74 | if yll <= 1: 75 | y_lbound, y_upbound = ycc, ycc+1 76 | else: 77 | y_lbound, y_upbound = ycc - yll/2, ycc + yll/2 78 | 79 | #assert x_lbound >= 0 and x_upbound < xs, "Invalid rectangel config, x out of bounds" 80 | #assert y_lbound >= 0 and y_upbound < ys, "Invalid rectangel config, y out of bounds" 81 | 82 | x_lbound, x_upbound = np.clip([x_lbound, x_upbound], 0, xs) 83 | y_lbound, y_upbound = np.clip([y_lbound, y_upbound], 0, ys) 84 | 85 | for i in xrange(x_lbound, x_upbound): 86 | for j in xrange(y_lbound, y_upbound): 87 | input_map[j,i] = -1 88 | return input_map 89 | 90 | def resize(scale, old_mats): 91 | new_mats = [] 92 | for mat in old_mats: 93 | new_mats.append(zoom(mat, scale, order=0)) 94 | return np.array(new_mats) 95 | 96 | 97 | def simple_soccer_map(xs=6, ys=9): 98 | assert xs % 2 == 0, "xs must be even" 99 | smap = np.zeros((xs, ys), dtype=np.int32) 100 | smap[0:xs/2-1,0] = -1 101 | smap[xs/2+1:xs,0] = -1 102 | smap[0:xs/2-1,ys-1] = -1 103 | smap[xs/2+1:xs,ys-1] = -1 104 | return smap 105 | 106 | 107 | 108 | def cross_map(xs, ys): 109 | pass 110 | 111 | 112 | 113 | -------------------------------------------------------------------------------- /madrl_environments/pursuit/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Multi-agent utilities 3 | """ 4 | 5 | from .AgentLayer import AgentLayer 6 | from .Controllers import * 7 | from .DiscreteAgent import DiscreteAgent 8 | from .TwoDMaps import * 9 | from .agent_utils import * 10 | 
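Taken together, these utilities provide the building blocks that `pursuit_evade.py` presumably assembles: `TwoDMaps` generates the occupancy grid, `create_agents` from `agent_utils.py` (next file) places `DiscreteAgent`s on it, and `AgentLayer` tracks them as a group. A minimal standalone sketch of that composition; the map size, agent count, and observation range below are illustrative values.

```python
# Sketch: composing the pursuit utilities by hand. Values are illustrative;
# the PursuitEvade environment performs a similar setup internally.
from madrl_environments.pursuit.utils import AgentLayer, TwoDMaps, create_agents

map_mat = TwoDMaps.rectangle_map(10, 10)               # 10x10 grid, -1 marks the building
pursuers = create_agents(4, map_mat, obs_range=5, randinit=True)
layer = AgentLayer(10, 10, pursuers)

print(layer.n_agents())           # 4
print(layer.get_state_matrix())   # number of agents at each (x, y) cell
layer.move_agent(0, 1)            # action 1 = move right, per DiscreteAgent's action list
print(layer.get_position(0))
```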
-------------------------------------------------------------------------------- /madrl_environments/pursuit/utils/agent_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from six.moves import xrange 4 | 5 | from .DiscreteAgent import DiscreteAgent 6 | 7 | ################################################################# 8 | # Implements utility functions for multi-agent DRL 9 | ################################################################# 10 | 11 | 12 | def create_agents(nagents, map_matrix, obs_range, flatten=False, randinit=False, constraints=None): 13 | """ 14 | Initializes the agents on a map (map_matrix) 15 | -nagents: the number of agents to put on the map 16 | -randinit: if True will place agents in random, feasible locations 17 | if False will place all agents at 0 18 | """ 19 | xs, ys = map_matrix.shape 20 | agents = [] 21 | for i in xrange(nagents): 22 | xinit, yinit = (0, 0) 23 | if randinit: 24 | xinit, yinit = feasible_position(map_matrix, constraints=constraints) 25 | agent = DiscreteAgent(xs, ys, map_matrix, obs_range=obs_range, flatten=flatten) 26 | agent.set_position(xinit, yinit) 27 | agents.append(agent) 28 | return agents 29 | 30 | 31 | def feasible_position(map_matrix, constraints=None): 32 | """ 33 | Returns a feasible position on map (map_matrix) 34 | """ 35 | xs, ys = map_matrix.shape 36 | loop_count = 0 37 | while True: 38 | if constraints is None: 39 | x = np.random.randint(xs) 40 | y = np.random.randint(ys) 41 | else: 42 | xl, xu = constraints[0] 43 | yl, yu = constraints[1] 44 | x = np.random.randint(xl, xu) 45 | y = np.random.randint(yl, yu) 46 | if map_matrix[x, y] != -1: 47 | return (x, y) 48 | 49 | 50 | def set_agents(agent_matrix, map_matrix): 51 | # check input sizes 52 | if agent_matrix.shape != map_matrix.shape: 53 | raise ValueError("Agent configuration and map matrix have mis-matched sizes") 54 | 55 | agents = [] 56 | xs, ys = agent_matrix.shape 57 | for i in xrange(xs): 58 | for j in xrange(ys): 59 | n_agents = agent_matrix[i, j] 60 | if n_agents > 0: 61 | if map_matrix[i, j] == -1: 62 | raise ValueError( 63 | "Trying to place an agent into a building: check map matrix and agent configuration") 64 | agent = DiscreteAgent(xs, ys, map_matrix) 65 | agent.set_position(i, j) 66 | agents.append(agent) 67 | return agents 68 | -------------------------------------------------------------------------------- /madrl_environments/pursuit/vis_policy.py: -------------------------------------------------------------------------------- 1 | from centralized_pursuit_evade import CentralizedPursuitEvade 2 | import gym 3 | from os.path import join 4 | import matplotlib.pyplot as plt 5 | 6 | from utils import * 7 | 8 | import tensorflow as tf 9 | 10 | from rltools.algs import TRPOSolver 11 | from rltools.utils import simulate 12 | from rltools.models import softmax_mlp 13 | 14 | 15 | xs = 10 16 | ys = 10 17 | n_evaders = 1 18 | n_pursuers = 2 19 | 20 | map_mat = TwoDMaps.rectangle_map(xs, ys) 21 | 22 | # obs_range should be odd 3, 5, 7, etc 23 | #env = CentralizedPursuitEvade(map_mat, n_evaders=n_evaders, n_pursuers=n_pursuers, obs_range=9, n_catch=2) 24 | 25 | 26 | config = {} 27 | config["train_iterations"] = 1000 # number of trpo iterations 28 | config["max_pathlength"] = 250 # maximum length of an env trajecotry 29 | config["timesteps_per_batch"] = 1000 30 | config["eval_trajectories"] = 50 31 | config["eval_every"] = 50 32 | config["gamma"] = 0.95 # discount factor 33 | 34 | ran = 9 35 | 36 
| env = CentralizedPursuitEvade(map_mat, n_evaders=n_evaders, n_pursuers=n_pursuers, obs_range=ran, n_catch=2) 37 | 38 | input_obs = tf.placeholder(tf.float32, shape=(None,) + env.observation_space.shape, name="obs" + str(ran)) 39 | net = softmax_mlp(input_obs, env.action_space.n, layers=[128,128], activation=tf.nn.tanh) 40 | 41 | solver = TRPOSolver(env, config=config, policy_net=net, input_layer=input_obs) 42 | 43 | #model_path = "data/obs_range_sweep_2layer/obs_range_9/run1/" 44 | model_path = "data/obs_range_sweep_2layer_two_evaders/obs_range_" + str(ran) + "/run2/" 45 | 46 | solver.load(model_path+"final_model.ckpt") 47 | d = solver.load_stats(model_path+"final_stats.txt") 48 | 49 | #ims = env.animate(solver, 100, "eval_scripts/results/animations/one_evader_two_pursuers_or_5.mp4", interval=500) 50 | solver.train = False # deterministic policy 51 | env.animate(solver, 100, "eval_scripts/results/animations/one_evader_two_pursuers_two_evader_policy_range_" + str(ran) + ".mp4", rate=1.5) 52 | #env.animate(solver, 100, "eval_scripts/results/animations/temp" + str(ran) + ".mp4", rate=1.5) 53 | 54 | #simulate(env, solver, 100, render=True) 55 | -------------------------------------------------------------------------------- /madrl_environments/walker/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sisl/MADRL/9ea39a0fe8695b391008a4eb7bda9fe4438a96de/madrl_environments/walker/__init__.py -------------------------------------------------------------------------------- /madrl_environments/walker/test_walker.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | env = gym.make('BipedalWalker-v2') 4 | 5 | o = env.reset() 6 | 7 | for i in xrange(50): 8 | env.render() 9 | a = env.action_space.sample() 10 | o, r, done, _ = env.step(a) 11 | print o, r, done 12 | -------------------------------------------------------------------------------- /madrl_environments/walker/train_multi_walker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: simple_continuous.py 4 | # 5 | # Created: Thursday, July 14 2016 by rejuvyesh 6 | # 7 | from __future__ import absolute_import, print_function 8 | 9 | import argparse 10 | import json 11 | import sys 12 | import pprint 13 | sys.path.append('../rltools/') 14 | 15 | import numpy as np 16 | import tensorflow as tf 17 | import h5py 18 | 19 | import gym 20 | import rltools.algos.policyopt 21 | import rltools.log 22 | import rltools.util 23 | from rltools.samplers.serial import SimpleSampler, ImportanceWeightedSampler, DecSampler 24 | from rltools.samplers.parallel import ThreadedSampler 25 | from rltools.baselines.linear import LinearFeatureBaseline 26 | from rltools.baselines.mlp import MLPBaseline 27 | from rltools.baselines.zero import ZeroBaseline 28 | from rltools.policy.gaussian import GaussianMLPPolicy 29 | 30 | from multi_walker import MultiWalkerEnv 31 | 32 | 33 | GAE_ARCH = '''[ 34 | {"type": "fc", "n": 200}, 35 | {"type": "nonlin", "func": "tanh"}, 36 | {"type": "fc", "n": 100}, 37 | {"type": "nonlin", "func": "tanh"}, 38 | {"type": "fc", "n": 50}, 39 | {"type": "nonlin", "func": "tanh"} 40 | ] 41 | ''' 42 | 43 | 44 | def main(): 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument('--discount', type=float, default=0.95) 47 | parser.add_argument('--gae_lambda', type=float, default=0.99) 48 | 49 | parser.add_argument('--n_iter', type=int, 
default=250) 50 | parser.add_argument('--sampler', type=str, default='simple') 51 | parser.add_argument('--max_traj_len', type=int, default=500) 52 | parser.add_argument('--n_timesteps', type=int, default=8000) # number of traj in an iteration 53 | 54 | parser.add_argument('--n_workers', type=int, default=4) # number of parallel workers for sampling 55 | 56 | parser.add_argument('--policy_hidden_spec', type=str, default=GAE_ARCH) 57 | 58 | parser.add_argument('--baseline_type', type=str, default='mlp') 59 | parser.add_argument('--baseline_hidden_spec', type=str, default=GAE_ARCH) 60 | 61 | parser.add_argument('--max_kl', type=float, default=0.01) 62 | parser.add_argument('--vf_max_kl', type=float, default=0.01) 63 | parser.add_argument('--vf_cg_damping', type=float, default=0.01) 64 | 65 | parser.add_argument('--n_walkers', type=int, default=2) 66 | 67 | parser.add_argument('--save_freq', type=int, default=20) 68 | parser.add_argument('--log', type=str, required=False) 69 | parser.add_argument('--tblog', type=str, default='/tmp/madrl_tb') 70 | parser.add_argument('--debug', dest='debug', action='store_true') 71 | parser.add_argument('--no-debug', dest='debug', action='store_false') 72 | 73 | parser.add_argument('--load_checkpoint', type=str, default='none') 74 | 75 | parser.set_defaults(debug=True) 76 | 77 | args = parser.parse_args() 78 | 79 | env = MultiWalkerEnv(n_walkers=args.n_walkers) 80 | 81 | if args.load_checkpoint is not 'none': 82 | filename, file_key = rltools.util.split_h5_name(args.load_checkpoint) 83 | print('Loading parameters from {} in {}'.format(file_key, filename)) 84 | with h5py.File(filename, 'r') as f: 85 | train_args = json.loads(f.attrs['args']) 86 | dset = f[file_key] 87 | 88 | pprint.pprint(dict(dset.attrs)) 89 | policy = GaussianMLPPolicy(env.observation_space, env.action_space, 90 | hidden_spec=train_args['policy_hidden_spec'], enable_obsnorm=True, 91 | min_stdev=0., init_logstdev=0., tblog=train_args['tblog'], 92 | varscope_name='gaussmlp_policy') 93 | else: 94 | 95 | policy = GaussianMLPPolicy(env.observation_space, env.action_space, 96 | hidden_spec=args.policy_hidden_spec, enable_obsnorm=True, 97 | min_stdev=0., init_logstdev=0., tblog=args.tblog, 98 | varscope_name='gaussmlp_policy') 99 | if args.baseline_type == 'linear': 100 | baseline = LinearFeatureBaseline(env.observation_space, enable_obsnorm=True, 101 | varscope_name='pursuit_linear_baseline') 102 | elif args.baseline_type == 'mlp': 103 | baseline = MLPBaseline(env.observation_space, args.baseline_hidden_spec, True, True, 104 | max_kl=args.vf_max_kl, damping=args.vf_cg_damping, 105 | time_scale=1. 
/ args.max_traj_len, 106 | varscope_name='pursuit_mlp_baseline') 107 | else: 108 | baseline = ZeroBaseline(env.observation_space) 109 | 110 | if args.sampler == 'simple': 111 | sampler_cls = DecSampler 112 | sampler_args = dict(max_traj_len=args.max_traj_len, n_timesteps=args.n_timesteps, 113 | n_timesteps_min=4000, n_timesteps_max=64000, timestep_rate=40, 114 | adaptive=False) 115 | elif args.sampler == 'parallel': 116 | sampler_cls = ParallelSampler 117 | sampler_args = dict(max_traj_len=args.max_traj_len, n_timesteps=args.n_timesteps, 118 | n_timesteps_min=4000, n_timesteps_max=64000, timestep_rate=40, 119 | adaptive=False, n_workers=args.sampler_workers) 120 | else: 121 | raise NotImplementedError() 122 | step_func = rltools.algos.policyopt.TRPO(max_kl=args.max_kl) 123 | popt = rltools.algos.policyopt.SamplingPolicyOptimizer(env=env, policy=policy, 124 | baseline=baseline, step_func=step_func, 125 | discount=args.discount, 126 | gae_lambda=args.gae_lambda, 127 | sampler_cls=sampler_cls, 128 | sampler_args=sampler_args, 129 | n_iter=args.n_iter) 130 | argstr = json.dumps(vars(args), separators=(',', ':'), indent=2) 131 | rltools.util.header(argstr) 132 | log_f = rltools.log.TrainingLog(args.log, [('args', argstr)], debug=args.debug) 133 | 134 | with tf.Session() as sess: 135 | sess.run(tf.initialize_all_variables()) 136 | if args.load_checkpoint is not 'none': 137 | policy.load_h5(sess, filename, file_key) 138 | popt.train(sess, log_f, args.save_freq) 139 | 140 | 141 | if __name__ == '__main__': 142 | main() 143 | -------------------------------------------------------------------------------- /madrl_environments/walker/train_single_walker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: simple_continuous.py 4 | # 5 | # Created: Thursday, July 14 2016 by rejuvyesh 6 | # 7 | from __future__ import absolute_import, print_function 8 | 9 | import argparse 10 | import json 11 | import sys 12 | sys.path.append('../rltools/') 13 | 14 | import numpy as np 15 | import tensorflow as tf 16 | 17 | import gym 18 | import rltools.algos.policyopt 19 | import rltools.log 20 | import rltools.util 21 | from rltools.samplers.serial import SimpleSampler, ImportanceWeightedSampler, DecSampler 22 | from rltools.samplers.parallel import ThreadedSampler 23 | from rltools.baselines.linear import LinearFeatureBaseline 24 | from rltools.baselines.mlp import MLPBaseline 25 | from rltools.baselines.zero import ZeroBaseline 26 | from rltools.policy.gaussian import GaussianMLPPolicy 27 | 28 | 29 | GAE_ARCH = '''[ 30 | {"type": "fc", "n": 100}, 31 | {"type": "nonlin", "func": "tanh"}, 32 | {"type": "fc", "n": 50}, 33 | {"type": "nonlin", "func": "tanh"}, 34 | {"type": "fc", "n": 25}, 35 | {"type": "nonlin", "func": "tanh"} 36 | ] 37 | ''' 38 | 39 | 40 | def main(): 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument('--discount', type=float, default=0.95) 43 | parser.add_argument('--gae_lambda', type=float, default=0.99) 44 | 45 | parser.add_argument('--n_iter', type=int, default=250) 46 | parser.add_argument('--sampler', type=str, default='simple') 47 | parser.add_argument('--max_traj_len', type=int, default=500) 48 | parser.add_argument('--n_timesteps', type=int, default=8000) # number of traj in an iteration 49 | 50 | parser.add_argument('--n_workers', type=int, default=4) # number of parallel workers for sampling 51 | 52 | parser.add_argument('--policy_hidden_spec', type=str, default=GAE_ARCH) 53 | 54 | 
parser.add_argument('--baseline_type', type=str, default='mlp') 55 | parser.add_argument('--baseline_hidden_spec', type=str, default=GAE_ARCH) 56 | 57 | parser.add_argument('--max_kl', type=float, default=0.01) 58 | parser.add_argument('--vf_max_kl', type=float, default=0.01) 59 | parser.add_argument('--vf_cg_damping', type=float, default=0.01) 60 | 61 | parser.add_argument('--save_freq', type=int, default=20) 62 | parser.add_argument('--log', type=str, required=False) 63 | parser.add_argument('--tblog', type=str, default='/tmp/madrl_tb') 64 | parser.add_argument('--debug', dest='debug', action='store_true') 65 | parser.add_argument('--no-debug', dest='debug', action='store_false') 66 | parser.set_defaults(debug=True) 67 | 68 | args = parser.parse_args() 69 | 70 | env = gym.make('BipedalWalker-v2') 71 | 72 | policy = GaussianMLPPolicy(env.observation_space, env.action_space, 73 | hidden_spec=args.policy_hidden_spec, enable_obsnorm=True, 74 | min_stdev=0., init_logstdev=0., tblog=args.tblog, 75 | varscope_name='gaussmlp_policy') 76 | if args.baseline_type == 'linear': 77 | baseline = LinearFeatureBaseline(env.observation_space, enable_obsnorm=True, 78 | varscope_name='pursuit_linear_baseline') 79 | elif args.baseline_type == 'mlp': 80 | baseline = MLPBaseline(env.observation_space, args.baseline_hidden_spec, True, True, 81 | max_kl=args.vf_max_kl, damping=args.vf_cg_damping, 82 | time_scale=1. / args.max_traj_len, 83 | varscope_name='pursuit_mlp_baseline') 84 | else: 85 | baseline = ZeroBaseline(env.observation_space) 86 | 87 | if args.sampler == 'simple': 88 | sampler_cls = SimpleSampler 89 | sampler_args = dict(max_traj_len=args.max_traj_len, n_timesteps=args.n_timesteps, 90 | n_timesteps_min=4000, n_timesteps_max=64000, timestep_rate=40, 91 | adaptive=False) 92 | elif args.sampler == 'parallel': 93 | sampler_cls = ParallelSampler 94 | sampler_args = dict(max_traj_len=args.max_traj_len, n_timesteps=args.n_timesteps, 95 | n_timesteps_min=4000, n_timesteps_max=64000, timestep_rate=40, 96 | adaptive=False, n_workers=args.sampler_workers) 97 | else: 98 | raise NotImplementedError() 99 | step_func = rltools.algos.policyopt.TRPO(max_kl=args.max_kl) 100 | popt = rltools.algos.policyopt.SamplingPolicyOptimizer(env=env, policy=policy, 101 | baseline=baseline, step_func=step_func, 102 | discount=args.discount, 103 | gae_lambda=args.gae_lambda, 104 | sampler_cls=sampler_cls, 105 | sampler_args=sampler_args, 106 | n_iter=args.n_iter) 107 | argstr = json.dumps(vars(args), separators=(',', ':'), indent=2) 108 | rltools.util.header(argstr) 109 | log_f = rltools.log.TrainingLog(args.log, [('args', argstr)], debug=args.debug) 110 | 111 | with tf.Session() as sess: 112 | sess.run(tf.initialize_all_variables()) 113 | popt.train(sess, log_f, args.save_freq) 114 | 115 | 116 | if __name__ == '__main__': 117 | main() 118 | -------------------------------------------------------------------------------- /maps/map_pool16.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sisl/MADRL/9ea39a0fe8695b391008a4eb7bda9fe4438a96de/maps/map_pool16.npy -------------------------------------------------------------------------------- /pipelines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sisl/MADRL/9ea39a0fe8695b391008a4eb7bda9fe4438a96de/pipelines/__init__.py -------------------------------------------------------------------------------- /pipelines/cont_pipeline.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: cont_pipeline.py 4 | # 5 | # Created: Friday, July 15 2016 by rejuvyesh 6 | # 7 | 8 | import argparse 9 | import os 10 | import yaml 11 | import shutil 12 | import rltools 13 | from pipelines import pipeline 14 | 15 | # Fix python 2.x 16 | try: 17 | input = raw_input 18 | except NameError: 19 | pass 20 | 21 | 22 | def phase_train(spec, spec_file): 23 | rltools.util.header('=== Running {} ==='.format(spec_file)) 24 | 25 | # Make checkpoint dir. All outputs go here 26 | storagedir = spec['options']['storagedir'] 27 | n_workers = spec['options']['n_workers'] 28 | checkptdir = os.path.join(spec['options']['storagedir'], spec['options']['checkpt_subdir']) 29 | rltools.util.mkdir_p(checkptdir) 30 | assert not os.listdir(checkptdir), 'Checkpoint directory {} is not empty!'.format(checkptdir) 31 | 32 | cmd_templates, output_filenames, argdicts = [], [], [] 33 | for alg in spec['training']['algorithms']: 34 | for bline in spec['training']['baselines']: 35 | for n_ev in spec['n_evaders']: 36 | for n_pu in spec['n_pursuers']: 37 | for n_se in spec['n_sensors']: 38 | for n_co in spec['n_coop']: 39 | # Number of cooperating agents can't be greater than pursuers 40 | if n_co > n_pu: 41 | continue 42 | for f_rew in spec['food_reward']: 43 | for p_rew in spec['poison_reward']: 44 | for e_rew in spec['encounter_reward']: 45 | for disc in spec['discounts']: 46 | for gae in spec['gae_lambdas']: 47 | for run in range(spec['training']['runs']): 48 | strid = 'alg={},bline={},n_ev={},n_pu={},n_se={},n_co={},f_rew={},p_rew={},e_rew={},disc={},gae={},run={}'.format( 49 | alg['name'], bline, n_ev, n_pu, n_se, n_co, 50 | f_rew, p_rew, e_rew, disc, gae, run) 51 | cmd_templates.append(alg['cmd'].replace( 52 | '\n', ' ').strip()) 53 | output_filenames.append(strid + '.txt') 54 | argdicts.append({ 55 | 'baseline_type': bline, 56 | 'n_evaders': n_ev, 57 | 'n_pursuers': n_pu, 58 | 'n_sensors': n_se, 59 | 'n_coop': n_co, 60 | 'discount': disc, 61 | 'food_reward': f_rew, 62 | 'poison_reward': p_rew, 63 | 'encounter_reward': e_rew, 64 | 'gae_lambda': gae, 65 | 'log': os.path.join(checkptdir, 66 | strid + '.h5') 67 | }) 68 | 69 | rltools.util.ok('{} jobs to run...'.format(len(cmd_templates))) 70 | rltools.util.warn('Continue? 
y/n') 71 | if input() == 'y': 72 | pipeline.run_jobs(cmd_templates, output_filenames, argdicts, storagedir, 73 | n_workers=n_workers) 74 | else: 75 | rltools.util.failure('Canceled.') 76 | sys.exit(1) 77 | 78 | # Copy the pipeline yaml file to the output dir too 79 | shutil.copyfile(spec_file, os.path.join(checkptdir, 'pipeline.yaml')) 80 | # Keep git commit 81 | import subprocess 82 | git_hash = subprocess.check_output('git rev-parse HEAD', shell=True).strip() 83 | with open(os.path.join(checkptdir, 'git_hash.txt'), 'w') as f: 84 | f.write(git_hash + '\n') 85 | 86 | 87 | def main(): 88 | parser = argparse.ArgumentParser() 89 | parser.add_argument('spec', type=str) 90 | args = parser.parse_args() 91 | 92 | with open(args.spec, 'r') as f: 93 | spec = yaml.load(f) 94 | 95 | phase_train(spec, args.spec) 96 | 97 | 98 | if __name__ == '__main__': 99 | main() 100 | -------------------------------------------------------------------------------- /pipelines/disc_pipeline.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: disc_pipeline.py 4 | # 5 | # Created: Friday, July 15 2016 by rejuvyesh 6 | # 7 | import sys 8 | sys.path.append('../rltools/') 9 | import argparse 10 | import os 11 | import shutil 12 | import subprocess 13 | 14 | import yaml 15 | import rltools.util 16 | import pipeline 17 | 18 | # Fix python 2.x 19 | try: 20 | input = raw_input 21 | except NameError: 22 | pass 23 | 24 | def phase_train(spec, spec_file): 25 | rltools.util.header('=== Running {} ==='.format(spec_file)) 26 | 27 | # Make checkpoint dir. All outputs go here 28 | storagedir = spec['options']['storagedir'] 29 | n_workers = spec['options']['n_workers'] 30 | checkptdir = os.path.join(spec['options']['storagedir'], spec['options']['checkpt_subdir']) 31 | rltools.util.mkdir_p(checkptdir) 32 | assert not os.listdir(checkptdir), 'Checkpoint directory {} is not empty!'.format(checkptdir) 33 | 34 | cmd_templates, output_filenames, argdicts = [], [], [] 35 | for alg in spec['training']['algorithms']: 36 | for bline in spec['training']['baselines']: 37 | for rect in spec['rectangles']: 38 | for n_ev in spec['n_evaders']: 39 | for n_pu in spec['n_pursuers']: 40 | for orng in spec['obs_ranges']: 41 | # observation range can't be bigger than the board 42 | if orng > max(map(int, rect.split(','))): 43 | continue 44 | for n_ca in spec['n_catches']: 45 | # number of simulataneous catches can't be bigger than numer of pusuers 46 | if n_ca > n_pu: 47 | continue 48 | for disc in spec['discounts']: 49 | for gae in spec['gae_lambdas']: 50 | for run in range(spec['training']['runs']): 51 | strid = 'alg={},bline={},rect={},n_ev={},n_pu={},orng={},n_ca={},disc={},gae={},run={}'.format(alg['name'], bline, rect, n_ev, n_pu, orng, n_ca, disc, gae, run) 52 | cmd_templates.append(alg['cmd'].replace('\n', ' ').strip()) 53 | output_filenames.append(strid + '.txt') 54 | argdicts.append({ 55 | 'baseline_type': bline, 56 | 'rectangle': rect, 57 | 'n_evaders': n_ev, 58 | 'n_pursuers': n_pu, 59 | 'obs_range': orng, 60 | 'n_catch': n_ca, 61 | 'discount': disc, 62 | 'gae_lambda': gae, 63 | 'log': os.path.join(checkptdir, strid+'.h5') 64 | }) 65 | 66 | rltools.util.ok('{} jobs to run...'.format(len(cmd_templates))) 67 | rltools.util.warn('Continue? 
y/n') 68 | if input() == 'y': 69 | pipeline.run_jobs(cmd_templates, output_filenames, argdicts, storagedir, n_workers=n_workers) 70 | else: 71 | rltools.util.failure('Canceled.') 72 | sys.exit(1) 73 | 74 | # Copy the pipeline yaml file to the output dir too 75 | shutil.copyfile(spec_file, os.path.join(checkptdir, 'pipeline.yaml')) 76 | # Keep git commit 77 | git_hash = subprocess.check_output('git rev-parse HEAD', shell=True).strip() 78 | with open(os.path.join(checkptdir, 'git_hash.txt'), 'w') as f: 79 | f.write(git_hash + '\n') 80 | 81 | 82 | 83 | def main(): 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument('spec', type=str) 86 | args = parser.parse_args() 87 | 88 | with open(args.spec, 'r') as f: 89 | spec = yaml.load(f) 90 | 91 | phase_train(spec, args.spec) 92 | 93 | if __name__ == '__main__': 94 | main() 95 | 96 | -------------------------------------------------------------------------------- /pipelines/host_pipeline.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: host_pipeline.py 4 | # 5 | # Created: Monday, August 1 2016 by rejuvyesh 6 | # License: GNU GPL 3 7 | # 8 | import argparse 9 | import os 10 | import yaml 11 | import shutil 12 | import rltools 13 | from pipelines import pipeline 14 | 15 | # Fix python 2.x 16 | try: 17 | input = raw_input 18 | except NameError: 19 | pass 20 | 21 | 22 | def phase_train(spec, spec_file): 23 | rltools.util.header('=== Running {} ==='.format(spec_file)) 24 | # Make checkpoint dir. All outputs go here 25 | storagedir = spec['options']['storagedir'] 26 | n_workers = spec['options']['n_workers'] 27 | checkptdir = os.path.join(spec['options']['storagedir'], spec['options']['checkpt_subdir']) 28 | rltools.util.mkdir_p(checkptdir) 29 | assert not os.listdir(checkptdir), 'Checkpoint directory {} is not empty!'.format(checkptdir) 30 | 31 | cmd_templates, output_filenames, argdicts = [], [], [] 32 | for alg in spec['training']['algorithms']: 33 | for bline in spec['training']['baselines']: 34 | for n_g in spec['n_good']: 35 | for n_h in spec['n_hostages']: 36 | for n_b in spec['n_bad']: 37 | for n_cs in spec['n_coop_save']: 38 | if n_cs > n_g: 39 | continue 40 | for n_ca in spec['n_coop_avoid']: 41 | if n_ca > n_g: 42 | continue 43 | for n_se in spec['n_sensors']: 44 | for srange in spec['sensor_range']: 45 | for srew in spec['save_reward']: 46 | for hrew in spec['hit_reward']: 47 | for erew in spec['encounter_reward']: 48 | for borew in spec['bomb_reward']: 49 | for disc in spec['discounts']: 50 | for gae in spec['gae_lambdas']: 51 | for run in range(spec['training'][ 52 | 'runs']): 53 | strid = 'alg={},bline={},n_g={},n_h={},n_b={},n_cs={},n_ca={},n_se={},srange={},srew={},hrew={},erew={},borew={},disc={},gae={},run={}'.format( 54 | alg['name'], bline, n_g, 55 | n_h, n_b, n_cs, n_ca, n_se, 56 | srange, srew, hrew, erew, 57 | borew, disc, gae, run) 58 | cmd_templates.append(alg[ 59 | 'cmd'].replace('\n', 60 | ' ').strip()) 61 | output_filenames.append(strid + 62 | '.txt') 63 | argdicts.append({ 64 | 'baseline_type': bline, 65 | 'n_good': n_g, 66 | 'n_hostage': n_h, 67 | 'n_bad': n_b, 68 | 'n_coop_save': n_cs, 69 | 'n_coop_avoid': n_ca, 70 | 'n_sensors': n_se, 71 | 'sensor_range': srange, 72 | 'save_reward': srew, 73 | 'hit_reward': hrew, 74 | 'encounter_reward': erew, 75 | 'bomb_reward': borew, 76 | 'discount': disc, 77 | 'gae_lambda': gae, 78 | 'log': os.path.join( 79 | checkptdir, 80 | strid + '.h5') 81 | }) 82 | 83 | rltools.util.ok('{} jobs to 
run...'.format(len(cmd_templates))) 84 | rltools.util.warn('Continue? y/n') 85 | if input() == 'y': 86 | pipeline.run_jobs(cmd_templates, output_filenames, argdicts, storagedir, 87 | n_workers=n_workers) 88 | else: 89 | rltools.util.failure('Canceled.') 90 | sys.exit(1) 91 | 92 | # Copy the pipeline yaml file to the output dir too 93 | shutil.copyfile(spec_file, os.path.join(checkptdir, 'pipeline.yaml')) 94 | # Keep git commit 95 | import subprocess 96 | git_hash = subprocess.check_output('git rev-parse HEAD', shell=True).strip() 97 | with open(os.path.join(checkptdir, 'git_hash.txt'), 'w') as f: 98 | f.write(git_hash + '\n') 99 | 100 | 101 | def main(): 102 | parser = argparse.ArgumentParser() 103 | parser.add_argument('spec', type=str) 104 | args = parser.parse_args() 105 | 106 | with open(args.spec, 'r') as f: 107 | spec = yaml.load(f) 108 | 109 | phase_train(spec, args.spec) 110 | 111 | 112 | if __name__ == '__main__': 113 | main() 114 | -------------------------------------------------------------------------------- /pipelines/pipeline.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: pipeline.py 4 | # 5 | # Created: Tuesday, July 12 2016 by rejuvyesh 6 | # 7 | import argparse 8 | import datetime 9 | import multiprocessing as mp 10 | import numpy as np 11 | 12 | import os 13 | import shutil 14 | import subprocess 15 | import sys 16 | import rltools.util 17 | 18 | 19 | def runcommand(cmd): 20 | try: 21 | return subprocess.check_output(cmd, shell=True).strip() 22 | except: 23 | return "Error executing command {}".format(cmd) 24 | 25 | 26 | class Worker(mp.Process): 27 | 28 | def __init__(self, work_queue, result_queue): 29 | # base class initialization 30 | mp.Process.__init__(self) 31 | self.work_queue = work_queue 32 | self.result_queue = result_queue 33 | self.kill_received = False 34 | 35 | def run(self): 36 | while (not (self.kill_received)) and (self.work_queue.empty() == False): 37 | try: 38 | job = self.work_queue.get_nowait() 39 | outfile = self.result_queue.get_nowait() 40 | except: 41 | break 42 | print('Starting job: {}'.format(job)) 43 | rtn_val = runcommand(job) 44 | with open(outfile, 'w') as f: 45 | f.write(rtn_val + '\n') 46 | 47 | 48 | def run_jobs(cmd_templates, output_filenames, argdicts, storage_dir, outputfile_dir=None, 49 | jobname=None, n_workers=4): 50 | assert len(cmd_templates) == len(output_filenames) == len(argdicts) 51 | num_cmds = len(cmd_templates) 52 | outputfile_dir = outputfile_dir if outputfile_dir is not None else 'logs_%s_%s' % ( 53 | jobname, datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')) 54 | rltools.util.mkdir_p(os.path.join(storage_dir, outputfile_dir)) 55 | 56 | cmds, outputfiles = [], [] 57 | for i in range(num_cmds): 58 | cmds.append(cmd_templates[i].format(**argdicts[i])) 59 | outputfiles.append( 60 | os.path.join(storage_dir, outputfile_dir, '{:04d}_{}'.format(i + 1, output_filenames[ 61 | i]))) 62 | 63 | work_queue = mp.Queue() 64 | res_queue = mp.Queue() 65 | for cmd, ofile in zip(cmds, outputfiles): 66 | work_queue.put(cmd) 67 | res_queue.put(ofile) 68 | 69 | worker = [] 70 | for i in range(n_workers): 71 | worker.append(Worker(work_queue, res_queue)) 72 | worker[i].start() 73 | 74 | 75 | def create_slurm_script(commands, outputfiles, jobname=None, nodes=4, cpus=6): 76 | assert len(commands) == len(outputfiles) 77 | template = '''#!/bin/bash 78 | # 79 | #all commands that start with SBATCH contain commands that are just used by SLURM for scheduling 80 
| ################# 81 | #set a job name 82 | #SBATCH --job-name={jobname} 83 | ################# 84 | #time you think you need; default is one hour 85 | #in minutes in this case, hh:mm:ss 86 | #SBATCH --time=24:00:00 87 | ################# 88 | #quality of service; think of it as job priority 89 | #SBATCH --qos=normal 90 | ################# 91 | #number of nodes you are requesting 92 | #SBATCH --nodes={nodes} 93 | ################# 94 | #tasks to run per node; a "task" is usually mapped to a MPI processes. 95 | # for local parallelism (OpenMP or threads), use "--ntasks-per-node=1 --cpus-per-task=16" instead 96 | #SBATCH --ntasks-per-node=1 --cpus-per-task={cpus} 97 | ################# 98 | module load singularity 99 | export PYTHONPATH=/scratch/PI/mykel/src/python/rltools/:$PI_SCRATCH/src/python/rllab3:$PI_SCRATCH/src/python/MADRL:$PYTHONPATH 100 | 101 | read -r -d '' COMMANDS << END 102 | {cmds_str} 103 | END 104 | cmd=$(echo "$COMMANDS" | awk "NR == $SLURM_ARRAY_TASK_ID") 105 | echo $cmd 106 | 107 | read -r -d '' OUTPUTFILES << END 108 | {outputfiles_str} 109 | END 110 | outputfile=$SLURM_SUBMIT_DIR/$(echo "$OUTPUTFILES" | awk "NR == $SLURM_ARRAY_TASK_ID") 111 | echo $outputfile 112 | # Make sure output directory exists 113 | mkdir -p "`dirname \"$outputfile\"`" 2>/dev/null 114 | 115 | echo $cmd >$outputfile 116 | eval $cmd >>$outputfile 2>&1 117 | ''' 118 | return template.format(jobname=jobname, nodes=nodes, cpus=cpus, cmds_str='\n'.join(commands), 119 | outputfiles_str='\n'.join(outputfiles)) 120 | 121 | 122 | def run_slurm(cmd_templates, output_filenames, argdicts, storage_dir, outputfile_dir=None, 123 | jobname=None, n_workers=4, slurm_script_copy=None): 124 | assert len(cmd_templates) == len(output_filenames) == len(argdicts) 125 | num_cmds = len(cmd_templates) 126 | 127 | outputfile_dir = outputfile_dir if outputfile_dir is not None else 'logs_%s_%s' % ( 128 | jobname, datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')) 129 | 130 | cmds, outputfiles = [], [] 131 | for i in range(num_cmds): 132 | cmds.append(cmd_templates[i].format(**argdicts[i])) 133 | # outputfile_name = outputfile_prefixes[i] + ','.join('{}={}'.format(k,v) for k,v in sorted(argdicts[i].items())) + outputfile_suffix 134 | outputfiles.append( 135 | os.path.join(outputfile_dir, '{:04d}_{}'.format(i + 1, output_filenames[i]))) 136 | 137 | script = create_slurm_script(cmds, outputfiles, jobname, nodes=n_workers) 138 | print(script) 139 | 140 | import tempfile 141 | with tempfile.NamedTemporaryFile(mode='w', suffix='.sh') as f: 142 | f.write(script) 143 | f.flush() 144 | 145 | cmd = 'sbatch --array %d-%d %s' % (1, len(cmds), f.name) 146 | print('Running command: {}'.format(cmd)) 147 | print('ok ({} jobs)? 
y/n'.format(num_cmds)) 148 | if input() == 'y': 149 | # Write a copy of the script 150 | if slurm_script_copy is not None: 151 | assert not os.path.exists(slurm_script_copy) 152 | with open(slurm_script_copy, 'w') as fcopy: 153 | fcopy.write(script) 154 | print('slurm script written to {}'.format(slurm_script_copy)) 155 | # Run slurm 156 | subprocess.check_call(cmd, shell=True) 157 | else: 158 | raise RuntimeError('Canceled.') 159 | 160 | 161 | from vis import Evaluator 162 | import json 163 | 164 | 165 | def envname2env(envname, args): 166 | from madrl_environments.pursuit import PursuitEvade 167 | from madrl_environments.pursuit import MAWaterWorld 168 | from madrl_environments.walker.multi_walker import MultiWalkerEnv 169 | 170 | # XXX 171 | # Will generalize later 172 | if envname == 'multiwalker': 173 | env = MultiWalkerEnv( 174 | args['n_walkers'], 175 | args['position_noise'], 176 | args['angle_noise'], 177 | reward_mech='global',) 178 | 179 | elif envname == 'waterworld': 180 | env = MAWaterWorld( 181 | args['n_pursuers'], 182 | args['n_evaders'], 183 | args['n_coop'], 184 | args['n_poison'], 185 | n_sensors=args['n_sensors'], 186 | food_reward=args['food_reward'], 187 | poison_reward=args['poison_reward'], 188 | encounter_reward=args['encounter_reward'], 189 | reward_mech='global',) 190 | 191 | elif envname == 'pursuit': 192 | env = PursuitEvade() 193 | else: 194 | raise NotImplementedError() 195 | 196 | return env 197 | 198 | 199 | def eval_snapshot(envname, checkptfile, last_snapshot_idx, n_trajs, mode): 200 | import tensorflow as tf 201 | if mode == 'rltools': 202 | import h5py 203 | with h5py.File(checkptfile, 'r') as f: 204 | args = json.loads(f.attrs['args']) 205 | elif mode == 'rllab': 206 | params_file = os.path.join(checkptfile, 'params.json') 207 | with open(params_file, 'r') as df: 208 | args = json.load(df) 209 | 210 | env = envname2env(envname, args) 211 | bestidx = 0 212 | bestret = -np.inf 213 | bestevr = {} 214 | for idx in range((last_snapshot_idx - 10), (last_snapshot_idx + 1)): 215 | tf.reset_default_graph() 216 | minion = Evaluator(env, args, args['max_traj_len'] if mode == 'rltools' else 217 | args['max_path_length'], n_trajs, False, mode) 218 | if mode == 'rltools': 219 | evr = minion(checkptfile, file_key='snapshots/iter%07d' % idx) 220 | elif mode == 'rllab': 221 | evr = minion(os.path.join(checkptfile, 'itr_{}.pkl'.format(idx))) 222 | 223 | if np.mean(evr['ret']) > bestret: 224 | bestret = np.mean(evr['ret']) 225 | bestevr = evr 226 | bestidx = idx 227 | return bestevr, bestidx 228 | 229 | 230 | def eval_heuristic_for_snapshot(envname, checkptfile, last_snapshot_idx, n_trajs): 231 | import h5py 232 | with h5py.File(checkptfile, 'r') as f: 233 | args = json.loads(f.attrs['args']) 234 | 235 | env = envname2env(envname, args) 236 | minion = Evaluator(env, args, args['max_traj_len'], n_trajs, False, 'heuristic') 237 | evr = minion(checkptfile, file_key='iter%07d' % last_snapshot_idx) 238 | return evr 239 | -------------------------------------------------------------------------------- /pipelines/waterworld.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: waterworld.py 4 | # 5 | # Created: Wednesday, August 24 2016 by rejuvyesh 6 | # 7 | import argparse 8 | import os 9 | import yaml 10 | import shutil 11 | import sys 12 | from rltools import util 13 | from pipelines import pipeline 14 | 15 | # Fix python 2.x 16 | try: 17 | input = raw_input 18 | except NameError: 19 | pass 20 | 21 
| 22 | def phase_train(spec, spec_file, git_hash): 23 | util.header('=== Running {} ==='.format(spec_file)) 24 | 25 | # Make checkpoint dir. All outputs go here 26 | storagedir = spec['options']['storagedir'] 27 | n_workers = spec['options']['n_workers'] 28 | checkptdir = os.path.join(spec['options']['storagedir'], spec['options']['checkpt_subdir']) 29 | util.mkdir_p(checkptdir) 30 | assert not os.listdir(checkptdir), 'Checkpoint directory {} is not empty!'.format(checkptdir) 31 | 32 | cmd_templates, output_filenames, argdicts = [], [], [] 33 | train_spec = spec['training'] 34 | arg_spec = spec['arguments'] 35 | for alg in train_spec['algorithms']: 36 | for bline in train_spec['baselines']: 37 | for parch in train_spec['policy_archs']: 38 | for barch in train_spec['baseline_archs']: 39 | for rad in arg_spec['radius']: 40 | for n_se in arg_spec['n_sensors']: 41 | for srange in arg_spec['sensor_ranges']: 42 | for n_ev in arg_spec['n_evaders']: 43 | for n_pu in arg_spec['n_pursuers']: 44 | for n_co in arg_spec['n_coop']: 45 | if n_co > n_pu: 46 | continue 47 | for n_po in arg_spec['n_poison']: 48 | for f_rew in arg_spec['food_reward']: 49 | for p_rew in arg_spec['poison_reward']: 50 | for e_rew in arg_spec['encounter_reward']: 51 | for disc in arg_spec['discounts']: 52 | for gae in arg_spec['gae_lambdas']: 53 | for run in range(train_spec[ 54 | 'runs']): 55 | strid = ( 56 | 'alg={},bline={},parch={},barch={},'. 57 | format(alg['name'], 58 | bline, parch, 59 | barch) + 60 | 'rad={},n_se={},srange={},n_ev={},n_pu={},n_co={},n_po={},'. 61 | format(rad, n_se, 62 | srange, n_ev, 63 | n_pu, n_co, n_po) 64 | + 65 | 'f_rew={},p_rew={},e_rew={},'. 66 | format(f_rew, p_rew, 67 | e_rew) + 68 | 'disc={},gae={},run={}'. 69 | format(disc, gae, run)) 70 | cmd_templates.append(alg[ 71 | 'cmd'].replace( 72 | '\n', ' ').strip()) 73 | output_filenames.append( 74 | strid + '.txt') 75 | argdicts.append({ 76 | 'baseline_type': bline, 77 | 'radius': rad, 78 | 'sensor_range': srange, 79 | 'n_sensors': n_se, 80 | 'n_pursuers': n_pu, 81 | 'n_evaders': n_ev, 82 | 'n_coop': n_co, 83 | 'n_poison': n_po, 84 | 'discount': disc, 85 | 'food_reward': f_rew, 86 | 'poison_reward': p_rew, 87 | 'encounter_reward': 88 | e_rew, 89 | 'gae_lambda': gae, 90 | 'policy_arch': parch, 91 | 'baseline_arch': barch, 92 | 'log': os.path.join( 93 | checkptdir, 94 | strid + '.h5') 95 | }) 96 | 97 | util.ok('{} jobs to run...'.format(len(cmd_templates))) 98 | util.warn('Continue? 
y/n') 99 | if input() == 'y': 100 | pipeline.run_jobs(cmd_templates, output_filenames, argdicts, storagedir, 101 | jobname=os.path.split(spec_file)[-1], n_workers=n_workers) 102 | sys.exit(0) 103 | else: 104 | util.failure('Canceled.') 105 | sys.exit(1) 106 | 107 | # Copy the pipeline yaml file to the output dir too 108 | shutil.copyfile(spec_file, os.path.join(checkptdir, 'pipeline.yaml')) 109 | with open(os.path.join(checkptdir, 'git_hash.txt'), 'w') as f: 110 | f.write(git_hash + '\n') 111 | 112 | 113 | def main(): 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument('spec', type=str) 116 | parser.add_argument('git_hash', type=str) 117 | args = parser.parse_args() 118 | 119 | with open(args.spec, 'r') as f: 120 | spec = yaml.load(f) 121 | 122 | if args.git_hash is None: 123 | args.git_hash = '000000' 124 | phase_train(spec, args.spec, args.git_hash) 125 | 126 | 127 | if __name__ == '__main__': 128 | main() 129 | -------------------------------------------------------------------------------- /pursuit_policy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from rltools import nn, tfutil 5 | from rltools.distributions import Distribution 6 | from rltools.policy.stochastic import StochasticPolicy 7 | 8 | 9 | class FactoredCategorical(Distribution): 10 | 11 | def __init__(self, dim): 12 | self._dim = dim 13 | 14 | @property 15 | def dim(self): 16 | return self._dim 17 | 18 | def entropy(self, probs_N_H_K): 19 | tmp = -probs_N_H_K * np.log(probs_N_H_K) 20 | tmp[~np.isfinite(tmp)] = 0 21 | return tmp.sum(axis=2) 22 | 23 | def sample(self, probs_N_H_K): 24 | """Sample from N factored categorical distributions""" 25 | N, H, K = probs_N_H_K.shape 26 | return np.array( 27 | [[np.random.choice(K, p=probs_N_H_K[i, j, :]) for j in range(H)] for i in range(N)]) 28 | 29 | def kl_expr(self, logprobs1_B_N_A, logprobs2_B_N_A, name=None): 30 | """KL divergence between factored categorical distributions""" 31 | with tf.op_scope([logprobs1_B_N_A, logprobs2_B_N_A], name, 'fac_categorical_kl') as scope: 32 | kl_B = tf.reduce_sum( 33 | tf.reduce_sum( 34 | tf.exp(logprobs1_B_N_A) * (logprobs1_B_N_A - logprobs2_B_N_A), 2), 1, 35 | name=scope) 36 | return kl_B 37 | 38 | 39 | class PursuitCentralMLPPolicy(StochasticPolicy): 40 | 41 | def __init__(self, obsfeat_space, action_space, n_agents, hidden_spec, enable_obsnorm, tblog, 42 | varscope_name): 43 | self.hidden_spec = hidden_spec 44 | self._n_agents = n_agents 45 | self._dist = FactoredCategorical(action_space.n) 46 | super(PursuitCentralMLPPolicy, self).__init__(obsfeat_space, action_space, action_space.n, 47 | enable_obsnorm, tblog, varscope_name) 48 | 49 | @property 50 | def distribution(self): 51 | return self._dist 52 | 53 | def _make_actiondist_ops(self, obsfeat_B_Df): 54 | with tf.variable_scope('hidden'): 55 | net = nn.FeedforwardNet(obsfeat_B_Df, self.obsfeat_space.shape, self.hidden_spec) 56 | with tf.variable_scope('out'): 57 | out_layer = nn.AffineLayer(net.output, net.output_shape, (self.action_space.n,), 58 | initializer=tf.zeros_initializer) 59 | 60 | scores_B_NPa = out_layer.output 61 | scores_B_N_Pa = tf.reshape(scores_B_NPa, 62 | (-1, self._n_agents, self.action_space.n / self._n_agents)) 63 | actiondist_B_N_Pa = scores_B_N_Pa - tfutil.logsumexp(scores_B_N_Pa, axis=2) 64 | actiondist_B_NPa = tf.reshape(actiondist_B_N_Pa, (-1, self.action_space.n)) 65 | return actiondist_B_NPa 66 | 67 | def _make_actiondist_logprobs_ops(self, actiondist_B_NPa,
input_actions_B_N): 68 | actiondist_B_N_Pa = tf.reshape(actiondist_B_NPa, 69 | (-1, self._n_agents, self.action_space.n / self._n_agents)) 70 | logprob_B_N = tfutil.lookup_last_idx(actiondist_B_N_Pa, input_actions_B_N) 71 | return tf.reduce_sum(logprob_B_N, 1) # Product of probabilities 72 | 73 | def _make_actiondist_kl_ops(self, proposal_actiondist_B_NPa, actiondist_B_NPa): 74 | proposal_actiondist_B_N_Pa = tf.reshape(proposal_actiondist_B_NPa, 75 | (-1, self._n_agents, 76 | self.action_space.n / self._n_agents)) 77 | actiondist_B_N_Pa = tf.reshape(actiondist_B_NPa, 78 | (-1, self._n_agents, self.action_space.n / self._n_agents)) 79 | return self.distribution.kl_expr(proposal_actiondist_B_N_Pa, actiondist_B_N_Pa) 80 | 81 | def _sample_from_actiondist(self, actiondist_B_NPa, deterministic=False): 82 | actiondist_B_N_Pa = np.reshape(actiondist_B_NPa, 83 | (-1, self._n_agents, self.action_space.n / self._n_agents)) 84 | probs_B_N_A = np.exp(actiondist_B_N_Pa) 85 | assert probs_B_N_A.ndim == 3 86 | assert probs_B_N_A.shape[2] == self.action_space.n / self._n_agents 87 | if deterministic: 88 | action_B_N = np.argmax(probs_B_N_A, axis=2) 89 | else: 90 | action_B_N = self.distribution.sample(probs_B_N_A) 91 | assert action_B_N.ndim == 2 and action_B_N.shape[-1] == self._n_agents 92 | 93 | return action_B_N 94 | 95 | def _compute_actiondist_entropy(self, actiondist_B_NPa): 96 | actiondist_B_N_Pa = actiondist_B_NPa.reshape( 97 | (-1, self._n_agents, self.action_space.n / self._n_agents)) 98 | return self.distribution.entropy(np.exp(actiondist_B_N_Pa)) 99 | -------------------------------------------------------------------------------- /rllabwrapper/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import absolute_import 3 | 4 | import gym 5 | import gym.envs 6 | import gym.spaces 7 | 8 | from rllab.envs.base import Env, Step 9 | from rllab.core.serializable import Serializable 10 | from rllab.spaces.box import Box 11 | from rllab.spaces.discrete import Discrete 12 | 13 | import numpy as np 14 | 15 | 16 | def convert_gym_space(space, n_agents=1): 17 | if isinstance(space, gym.spaces.Box) or isinstance(space, Box): 18 | if len(space.shape) > 1: 19 | assert n_agents == 1, "multi-dimensional inputs for centralized agents not supported" 20 | return Box(low=np.min(space.low), high=np.max(space.high), shape=space.shape) 21 | else: 22 | return Box(low=np.min(space.low), high=np.max(space.high), 23 | shape=(space.shape[0] * n_agents,)) 24 | elif isinstance(space, gym.spaces.Discrete) or isinstance(space, Discrete): 25 | return Discrete(n=space.n**n_agents) 26 | else: 27 | raise NotImplementedError 28 | 29 | 30 | class RLLabEnv(Env, Serializable): 31 | 32 | def __init__(self, env, ma_mode): 33 | Serializable.quick_init(self, locals()) 34 | 35 | self.env = env 36 | if hasattr(env, 'id'): 37 | self.env_id = env.id 38 | else: 39 | self.env_id = 'MA-Wrapper-v0' 40 | 41 | if ma_mode == 'centralized': 42 | obsfeat_space = convert_gym_space(env.agents[0].observation_space, 43 | n_agents=len(env.agents)) 44 | action_space = convert_gym_space(env.agents[0].action_space, n_agents=len(env.agents)) 45 | elif ma_mode in ['decentralized', 'concurrent']: 46 | obsfeat_space = convert_gym_space(env.agents[0].observation_space, n_agents=1) 47 | action_space = convert_gym_space(env.agents[0].action_space, n_agents=1) 48 | 49 | else: 50 | raise NotImplementedError 51 | 52 | self._observation_space = obsfeat_space 53 | 
self._action_space = action_space 54 | if hasattr(env, 'timestep_limit'): 55 | self._horizon = env.timestep_limit 56 | else: 57 | self._horizon = 250 58 | 59 | @property 60 | def agents(self): 61 | return self.env.agents 62 | 63 | @property 64 | def observation_space(self): 65 | return self._observation_space 66 | 67 | @property 68 | def action_space(self): 69 | return self._action_space 70 | 71 | @property 72 | def horizon(self): 73 | return self._horizon 74 | 75 | def reset(self): 76 | return self.env.reset() 77 | 78 | def step(self, action): 79 | next_obs, reward, done, info = self.env.step(action) 80 | if info is None: 81 | info = dict() 82 | return Step(next_obs, reward, done, **info) 83 | 84 | def render(self): 85 | self.env.render() 86 | 87 | def set_param_values(self, *args, **kwargs): 88 | self.env.set_param_values(*args, **kwargs) 89 | 90 | def get_param_values(self, *args, **kwargs): 91 | return self.env.get_param_values(*args, **kwargs) 92 | -------------------------------------------------------------------------------- /rllabwrapper/rllab_gru_test.py: -------------------------------------------------------------------------------- 1 | from madrl_environments import StandardizedEnv 2 | from madrl_environments.pursuit import MAWaterWorld 3 | from rllabwrapper import RLLabEnv 4 | from rllab.sampler import parallel_sampler 5 | from sandbox.rocky.tf.algos.ma_trpo import MATRPO 6 | from sandbox.rocky.tf.envs.base import MATfEnv 7 | from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline 8 | 9 | from sandbox.rocky.tf.policies.gaussian_gru_policy import GaussianGRUPolicy 10 | from sandbox.rocky.tf.optimizers.conjugate_gradient_optimizer import ConjugateGradientOptimizer, FiniteDifferenceHvp 11 | 12 | parallel_sampler.initialize(n_parallel=2) 13 | env = StandardizedEnv(MAWaterWorld(3, 10, 2, 5)) 14 | env = MATfEnv(RLLabEnv(env, ma_mode='decentralized')) 15 | 16 | policy = GaussianGRUPolicy(env_spec=env.spec, name='policy') 17 | 18 | baseline = LinearFeatureBaseline(env_spec=env.spec) 19 | 20 | algo = MATRPO(env=env, policy_or_policies=policy, baseline_or_baselines=baseline, batch_size=8000, 21 | max_path_length=200, n_itr=500, discount=0.99, step_size=0.01, 22 | optimizer=ConjugateGradientOptimizer(hvp_approach=FiniteDifferenceHvp(base_eps=1e-5)), 23 | ma_mode='decentralized') 24 | # policies = [GaussianGRUPolicy(env_spec=env.spec, name='policy_{}'.format(i)) for i in range(3)] 25 | # baselines = [LinearFeatureBaseline(env_spec=env.spec) for _ in range(3)] 26 | # algo = MATRPO(env=env, policy_or_policies=policies, baseline_or_baselines=baselines, 27 | # batch_size=8000, max_path_length=200, n_itr=500, discount=0.99, step_size=0.01, 28 | # optimizer=ConjugateGradientOptimizer(hvp_approach=FiniteDifferenceHvp(base_eps=1e-5)), 29 | # ma_mode='concurrent') 30 | 31 | algo.train() 32 | -------------------------------------------------------------------------------- /runners/__init__.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import argparse 4 | import sys 5 | import datetime 6 | import dateutil 7 | import dateutil.tz 8 | import uuid 9 | import ast 10 | 11 | from .
import archs 12 | 13 | 14 | def tonamedtuple(dictionary): 15 | for key, value in dictionary.items(): 16 | if isinstance(value, dict): 17 | dictionary[key] = tonamedtuple(value) 18 | return namedtuple('GenericDict', dictionary.keys())(**dictionary) 19 | 20 | 21 | def get_arch(name): 22 | constructor = getattr(archs, name) 23 | return constructor 24 | 25 | 26 | def comma_sep_ints(s): 27 | if s: 28 | return list(map(int, s.split(","))) 29 | else: 30 | return [] 31 | 32 | 33 | class RunnerParser(object): 34 | 35 | DEFAULT_OPTS = [ 36 | ('discount', float, 0.95, ''), 37 | ('gae_lambda', float, 0.99, ''), 38 | ('n_iter', int, 500, ''), 39 | ] 40 | 41 | DEFAULT_POLICY_OPTS = [ 42 | ('control', str, 'decentralized', ''), 43 | ('recurrent', str, None, ''), 44 | ('baseline_type', str, 'linear', ''), 45 | ] 46 | 47 | def __init__(self, env_options, **kwargs): 48 | self._env_options = env_options 49 | parser = argparse.ArgumentParser(description='Runner') 50 | 51 | parser.add_argument('mode', help='rllab or rltools') 52 | args = parser.parse_args(sys.argv[1:2]) 53 | if not hasattr(self, args.mode): 54 | print('Unrecognized command') 55 | parser.print_help() 56 | exit(1) 57 | 58 | self._mode = args.mode 59 | getattr(self, args.mode)(self._env_options, **kwargs) 60 | 61 | def update_argument_parser(self, parser, options, **kwargs): 62 | kwargs = kwargs.copy() 63 | for (name, typ, default, desc) in options: 64 | flag = "--" + name 65 | if flag in parser._option_string_actions.keys(): #pylint: disable=W0212 66 | print("warning: already have option %s. skipping" % name) 67 | else: 68 | parser.add_argument(flag, type=typ, default=kwargs.pop(name, default), help=desc or 69 | " ") 70 | if kwargs: 71 | raise ValueError("options %s ignored" % kwargs) 72 | 73 | def rllab(self, env_options, **kwargs): 74 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 75 | rand_id = str(uuid.uuid4())[:5] 76 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z') 77 | default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id) 78 | 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument('--exp_name', type=str, default=default_exp_name) 81 | self.update_argument_parser(parser, self.DEFAULT_OPTS) 82 | self.update_argument_parser(parser, self.DEFAULT_POLICY_OPTS) 83 | 84 | parser.add_argument( 85 | '--algo', type=str, default='tftrpo', 86 | help='Add tf or th to the algo name to run tensorflow or theano version') 87 | 88 | parser.add_argument('--max_path_length', type=int, default=500) 89 | parser.add_argument('--batch_size', type=int, default=12000) 90 | parser.add_argument('--n_parallel', type=int, default=1) 91 | parser.add_argument('--resume_from', type=str, default=None, 92 | help='Name of the pickle file to resume experiment from.') 93 | 94 | parser.add_argument('--epoch_length', type=int, default=1000) 95 | parser.add_argument('--min_pool_size', type=int, default=10000) 96 | parser.add_argument('--replay_pool_size', type=int, default=500000) 97 | parser.add_argument('--eval_samples', type=int, default=50000) 98 | parser.add_argument('--qfunc_lr', type=float, default=1e-3) 99 | parser.add_argument('--policy_lr', type=float, default=1e-4) 100 | 101 | parser.add_argument('--feature_net', type=str, default=None) 102 | parser.add_argument('--feature_output', type=int, default=16) 103 | parser.add_argument('--feature_hidden', type=comma_sep_ints, default='128,64,32') 104 | parser.add_argument('--policy_hidden', type=comma_sep_ints, default='32') 105 | parser.add_argument('--conv', type=str, default='') 106 | 
parser.add_argument('--conv_filters', type=comma_sep_ints, default='3,3') 107 | parser.add_argument('--conv_channels', type=comma_sep_ints, default='4,8') 108 | parser.add_argument('--conv_strides', type=comma_sep_ints, default='1,1') 109 | parser.add_argument('--min_std', type=float, default=1e-6) 110 | parser.add_argument('--exp_strategy', type=str, default='ou') 111 | parser.add_argument('--exp_noise', type=float, default=0.3) 112 | 113 | parser.add_argument('--step_size', type=float, default=0.01, help='max kl wall limit') 114 | 115 | parser.add_argument('--log_dir', type=str, required=False) 116 | parser.add_argument('--tabular_log_file', type=str, default='progress.csv', 117 | help='Name of the tabular log file (in csv).') 118 | parser.add_argument('--text_log_file', type=str, default='debug.log', 119 | help='Name of the text log file (in pure text).') 120 | parser.add_argument('--params_log_file', type=str, default='params.json', 121 | help='Name of the parameter log file (in json).') 122 | parser.add_argument('--seed', type=int, help='Random seed for numpy') 123 | parser.add_argument('--args_data', type=str, help='Pickled data for stub objects') 124 | parser.add_argument('--snapshot_mode', type=str, default='all', 125 | help='Mode to save the snapshot. Can be either "all" ' 126 | '(all iterations will be saved), "last" (only ' 127 | 'the last iteration will be saved), or "none" ' 128 | '(do not save snapshots)') 129 | parser.add_argument( 130 | '--log_tabular_only', type=ast.literal_eval, default=False, 131 | help='Whether to only print the tabular log information (in a horizontal format)') 132 | 133 | self.update_argument_parser(parser, env_options, **kwargs) 134 | self.args = parser.parse_known_args( 135 | [arg for arg in sys.argv[2:] if arg not in ('-h', '--help')])[0] 136 | 137 | def rltools(self, env_options, **kwargs): 138 | parser = argparse.ArgumentParser() 139 | self.update_argument_parser(parser, self.DEFAULT_OPTS) 140 | self.update_argument_parser(parser, self.DEFAULT_POLICY_OPTS) 141 | 142 | parser.add_argument('--sampler', type=str, default='simple') 143 | parser.add_argument('--sampler_workers', type=int, default=1) 144 | parser.add_argument('--max_traj_len', type=int, default=500) 145 | parser.add_argument('--n_timesteps', type=int, default=12000) 146 | 147 | parser.add_argument('--adaptive_batch', action='store_true', default=False) 148 | parser.add_argument('--n_timesteps_min', type=int, default=4000) 149 | parser.add_argument('--n_timesteps_max', type=int, default=64000) 150 | parser.add_argument('--timestep_rate', type=int, default=20) 151 | 152 | parser.add_argument('--policy_hidden_spec', type=get_arch, default='GAE_ARCH') 153 | parser.add_argument('--baseline_hidden_spec', type=get_arch, default='GAE_ARCH') 154 | parser.add_argument('--min_std', type=float, default=1e-6) 155 | parser.add_argument('--max_kl', type=float, default=0.01) 156 | parser.add_argument('--vf_max_kl', type=float, default=0.01) 157 | parser.add_argument('--vf_cg_damping', type=float, default=0.01) 158 | parser.add_argument('--enable_obsnorm', action='store_true') 159 | parser.add_argument('--enable_rewnorm', action='store_true') 160 | parser.add_argument('--enable_vnorm', action='store_true') 161 | 162 | parser.add_argument('--interp_alpha', type=float, default=0.1) 163 | parser.add_argument('--blend_freq', type=int, default=0) 164 | parser.add_argument('--blend_eval_trajs', type=int, default=50) 165 | parser.add_argument('--keep_kmax', type=int, default=0) 166 | 167 | 
parser.add_argument('--save_freq', type=int, default=10) 168 | parser.add_argument('--log', type=str, required=False) 169 | parser.add_argument('--tblog', type=str, default='/tmp/madrl_tb_{}'.format(uuid.uuid4())) 170 | parser.add_argument('--no-debug', dest='debug', action='store_false') 171 | parser.set_defaults(debug=True) 172 | self.update_argument_parser(parser, env_options, **kwargs) 173 | self.args = parser.parse_known_args( 174 | [arg for arg in sys.argv[2:] if arg not in ('-h', '--help')])[0] 175 | -------------------------------------------------------------------------------- /runners/archs.py: -------------------------------------------------------------------------------- 1 | SIMPLE_POLICY_ARCH = '''[ 2 | {"type": "fc", "n": 128}, 3 | {"type": "nonlin", "func": "tanh"}, 4 | {"type": "fc", "n": 128}, 5 | {"type": "nonlin", "func": "tanh"} 6 | ] 7 | ''' 8 | 9 | TINY_VAL_ARCH = '''[ 10 | {"type": "fc", "n": 32}, 11 | {"type": "nonlin", "func": "relu"}, 12 | {"type": "fc", "n": 32}, 13 | {"type": "nonlin", "func": "relu"} 14 | ] 15 | ''' 16 | 17 | SIMPLE_VAL_ARCH = '''[ 18 | {"type": "fc", "n": 128}, 19 | {"type": "nonlin", "func": "tanh"}, 20 | {"type": "fc", "n": 128}, 21 | {"type": "nonlin", "func": "tanh"} 22 | ] 23 | ''' 24 | 25 | GAE_TYPE_VAL_ARCH = '''[ 26 | {"type": "fc", "n": 128}, 27 | {"type": "nonlin", "func": "tanh"}, 28 | {"type": "fc", "n": 64}, 29 | {"type": "nonlin", "func": "tanh"}, 30 | {"type": "fc", "n": 32}, 31 | {"type": "nonlin", "func": "tanh"} 32 | ] 33 | ''' 34 | 35 | GAE_ARCH = '''[ 36 | {"type": "fc", "n": 100}, 37 | {"type": "nonlin", "func": "tanh"}, 38 | {"type": "fc", "n": 50}, 39 | {"type": "nonlin", "func": "tanh"}, 40 | {"type": "fc", "n": 25}, 41 | {"type": "nonlin", "func": "tanh"} 42 | ] 43 | ''' 44 | 45 | MED_POLICY_ARCH = '''[ 46 | {"type": "fc", "n": 256}, 47 | {"type": "nonlin", "func": "tanh"}, 48 | {"type": "fc", "n": 128}, 49 | {"type": "nonlin", "func": "tanh"}, 50 | {"type": "fc", "n": 64}, 51 | {"type": "nonlin", "func": "tanh"} 52 | ] 53 | ''' 54 | 55 | LARGE_POLICY_ARCH = '''[ 56 | {"type": "fc", "n": 512}, 57 | {"type": "nonlin", "func": "tanh"}, 58 | {"type": "fc", "n": 256}, 59 | {"type": "nonlin", "func": "tanh"}, 60 | {"type": "fc", "n": 128}, 61 | {"type": "nonlin", "func": "tanh"} 62 | ] 63 | ''' 64 | 65 | LARGE_VAL_ARCH = '''[ 66 | {"type": "fc", "n": 512}, 67 | {"type": "nonlin", "func": "tanh"}, 68 | {"type": "fc", "n": 256}, 69 | {"type": "nonlin", "func": "tanh"}, 70 | {"type": "fc", "n": 128}, 71 | {"type": "nonlin", "func": "tanh"} 72 | ] 73 | ''' 74 | 75 | HUGE_POLICY_ARCH = '''[ 76 | {"type": "fc", "n": 1028}, 77 | {"type": "nonlin", "func": "tanh"}, 78 | {"type": "fc", "n": 512}, 79 | {"type": "nonlin", "func": "tanh"}, 80 | {"type": "fc", "n": 256}, 81 | {"type": "nonlin", "func": "tanh"}, 82 | {"type": "fc", "n": 128}, 83 | {"type": "nonlin", "func": "tanh"} 84 | ] 85 | ''' 86 | 87 | HUGE_VAL_ARCH = '''[ 88 | {"type": "fc", "n": 1028}, 89 | {"type": "nonlin", "func": "tanh"}, 90 | {"type": "fc", "n": 512}, 91 | {"type": "nonlin", "func": "tanh"}, 92 | {"type": "fc", "n": 256}, 93 | {"type": "nonlin", "func": "tanh"}, 94 | {"type": "fc", "n": 128}, 95 | {"type": "nonlin", "func": "tanh"} 96 | ] 97 | ''' 98 | 99 | SIMPLE_CONV_ARCH = '''[ 100 | {"type": "conv", "chanout": 16, "filtsize": 3, "stride": 1, "padding": "VALID"}, 101 | {"type": "nonlin", "func": "relu"}, 102 | {"type": "conv", "chanout": 8, "filtsize": 3, "stride": 1, "padding": "VALID"}, 103 | {"type": "nonlin", "func": "relu"}, 104 | {"type": 
"flatten"} 105 | ] 106 | ''' 107 | 108 | SIMPLE_GRU_ARCH = '''{"gru_hidden_dim": 32, "gru_hidden_nonlin": "tanh", "gru_hidden_init_trainable": false}''' 109 | 110 | SIMPLE_GAE_FEAT_GRU_ARCH = '''{ 111 | "feature_network": [ 112 | {"type": "flatten"}, 113 | {"type": "fc", "n": 100}, 114 | {"type": "nonlin", "func": "tanh"}, 115 | {"type": "fc", "n": 50}, 116 | {"type": "nonlin", "func": "tanh"}, 117 | {"type": "fc", "n": 25}, 118 | {"type": "nonlin", "func": "tanh"} 119 | ], 120 | "gru_hidden_dim": 32, "gru_hidden_nonlin": "tanh", "gru_hidden_init_trainable": false 121 | }''' 122 | -------------------------------------------------------------------------------- /runners/curriculum.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import yaml.constructor 3 | from collections import OrderedDict 4 | 5 | 6 | class OrderedDictYAMLLoader(yaml.Loader): 7 | """ 8 | A YAML loader that loads mappings into ordered dictionaries. 9 | """ 10 | 11 | def __init__(self, *args, **kwargs): 12 | yaml.Loader.__init__(self, *args, **kwargs) 13 | 14 | self.add_constructor(u'tag:yaml.org,2002:map', type(self).construct_yaml_map) 15 | self.add_constructor(u'tag:yaml.org,2002:omap', type(self).construct_yaml_map) 16 | 17 | def construct_yaml_map(self, node): 18 | data = OrderedDict() 19 | yield data 20 | value = self.construct_mapping(node) 21 | data.update(value) 22 | 23 | def construct_mapping(self, node, deep=False): 24 | if isinstance(node, yaml.MappingNode): 25 | self.flatten_mapping(node) 26 | else: 27 | raise yaml.constructor.ConstructorError(None, None, 28 | 'expected a mapping node, but found %s' % 29 | node.id, node.start_mark) 30 | 31 | mapping = OrderedDict() 32 | for key_node, value_node in node.value: 33 | key = self.construct_object(key_node, deep=deep) 34 | try: 35 | hash(key) 36 | except TypeError as exc: 37 | raise yaml.constructor.ConstructorError('while constructing a mapping', 38 | node.start_mark, 39 | 'found unacceptable key (%s)' % exc, 40 | key_node.start_mark) 41 | value = self.construct_object(value_node, deep=deep) 42 | mapping[key] = value 43 | return mapping 44 | 45 | 46 | class Task(object): 47 | 48 | def __init__(self, name, prop): 49 | self.name = name 50 | self.prop = prop 51 | 52 | def __hash__(self): 53 | return hash(self.name) 54 | 55 | 56 | class Curriculum(object): 57 | 58 | def __init__(self, config): 59 | with open(config, 'r') as f: 60 | self.config = yaml.load(f, OrderedDictYAMLLoader) 61 | 62 | self._tasks = list([Task(k, v) for k, v in self.config['tasks'].items()]) 63 | self._lesson_threshold = self.config['thresholds']['lesson'] 64 | self._stop_threshold = self.config['thresholds']['stop'] 65 | self._n_trials = self.config['n_trials'] 66 | self._metric = self.config['metric'] 67 | self._eval_trials = self.config['eval_trials'] 68 | 69 | @property 70 | def tasks(self): 71 | return self._tasks 72 | 73 | @property 74 | def lesson_threshold(self): 75 | return self._lesson_threshold 76 | 77 | @property 78 | def stop_threshold(self): 79 | return self._stop_threshold 80 | 81 | @property 82 | def n_trials(self): 83 | return self._n_trials 84 | 85 | @property 86 | def metric(self): 87 | return self._metric 88 | 89 | @property 90 | def eval_trials(self): 91 | return self._eval_trials 92 | -------------------------------------------------------------------------------- /runners/old/rllab/pursuit.sh: -------------------------------------------------------------------------------- 1 | python run_pursuit.py --log_dir 
pursuit_mlp_cont --n_iter 500 --n_evaders 50 --n_pursuers 30 --control decentralized --n_timesteps 60000 --obs_range 11 --sample_maps --map_file ../../maps/map_pool32.npy --policy_hidden_sizes 256,128,64 --flatten --sampler_workers 1 --surround --discount 0.99 --max_kl 0.05 --reward_mech local --max_traj_len 750 --checkpoint pursuit_mlp/itr_65.pkl 2 | -------------------------------------------------------------------------------- /runners/old/rllab/pursuit_cnn.sh: -------------------------------------------------------------------------------- 1 | python run_pursuit.py --log_dir pursuit_cnn_huge --n_iter 500 --n_evaders 300 --n_pursuers 100 --control decentralized --n_timesteps 40000 --obs_range 21 --sample_maps --map_file ../../maps/map_pool128.npy --sampler_workers 1 --surround --discount 0.99 --max_kl 0.1 --reward_mech local --conv --max_traj_len 500 --baseline_type zero 2 | -------------------------------------------------------------------------------- /runners/old/rllab/pursuit_test.sh: -------------------------------------------------------------------------------- 1 | python run_pursuit.py --log_dir gru_test --n_iter 500 --n_evaders 5 --n_pursuers 4 --control decentralized --n_timesteps 10000 --obs_range 5 --map_file ../maps/map_pool32.npy --policy_hidden_sizes 32 --flatten --recurrent --sampler_workers 1 --surround --discount 0.99 --max_kl 0.1 2 | -------------------------------------------------------------------------------- /runners/old/rllab/run_hostage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: run_pursuit.py 4 | # 5 | # Created: Wednesday, July 6 2016 by megorov 6 | # 7 | from __future__ import absolute_import, print_function 8 | 9 | import argparse 10 | import json 11 | import uuid 12 | import datetime 13 | import dateutil.tz 14 | import os.path as osp 15 | import ast 16 | 17 | import gym 18 | import numpy as np 19 | import tensorflow as tf 20 | from gym import spaces 21 | 22 | from madrl_environments.pursuit import MAWaterWorld 23 | from madrl_environments import StandardizedEnv 24 | from rllabwrapper import RLLabEnv 25 | 26 | from rllab.algos.trpo import TRPO 27 | from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline 28 | from rllab.envs.normalized_env import normalize 29 | from rllab.policies.gaussian_mlp_policy import GaussianMLPPolicy 30 | from rllab.policies.gaussian_gru_policy import GaussianGRUPolicy 31 | from rllab.sampler import parallel_sampler 32 | import rllab.misc.logger as logger 33 | from rllab import config 34 | 35 | def main(): 36 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 37 | rand_id = str(uuid.uuid4())[:5] 38 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z') 39 | default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id) 40 | 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument( 43 | '--exp_name', type=str, default=default_exp_name, help='Name of the experiment.') 44 | 45 | parser.add_argument('--discount', type=float, default=0.95) 46 | parser.add_argument('--gae_lambda', type=float, default=0.99) 47 | 48 | parser.add_argument('--n_iter', type=int, default=250) 49 | parser.add_argument('--sampler_workers', type=int, default=1) 50 | parser.add_argument('--max_traj_len', type=int, default=250) 51 | parser.add_argument('--update_curriculum', action='store_true', default=False) 52 | parser.add_argument('--n_timesteps', type=int, default=8000) 53 | parser.add_argument('--control', type=str, default='centralized') 54 | 55 | 
56 | parser.add_argument('--buffer_size', type=int, default=1) 57 | parser.add_argument('--n_good', type=int, default=3) 58 | parser.add_argument('--n_hostage', type=int, default=5) 59 | parser.add_argument('--n_bad', type=int, default=5) 60 | parser.add_argument('--n_coop_save', type=int, default=2) 61 | parser.add_argument('--n_coop_avoid', type=int, default=2) 62 | parser.add_argument('--n_sensors', type=int, default=20) 63 | parser.add_argument('--sensor_range', type=float, default=0.2) 64 | parser.add_argument('--save_reward', type=float, default=3) 65 | parser.add_argument('--hit_reward', type=float, default=-1) 66 | parser.add_argument('--encounter_reward', type=float, default=0.01) 67 | parser.add_argument('--bomb_reward', type=float, default=-10.) 68 | 69 | parser.add_argument('--recurrent', action='store_true', default=False) 70 | parser.add_argument('--baseline_type', type=str, default='linear') 71 | parser.add_argument('--policy_hidden_sizes', type=str, default='128,128') 72 | parser.add_argument('--baseline_hidden_sizes', type=str, default='128,128') 73 | 74 | parser.add_argument('--max_kl', type=float, default=0.01) 75 | 76 | parser.add_argument('--log_dir', type=str, required=False) 77 | parser.add_argument('--tabular_log_file', type=str, default='progress.csv', 78 | help='Name of the tabular log file (in csv).') 79 | parser.add_argument('--text_log_file', type=str, default='debug.log', 80 | help='Name of the text log file (in pure text).') 81 | parser.add_argument('--params_log_file', type=str, default='params.json', 82 | help='Name of the parameter log file (in json).') 83 | parser.add_argument('--seed', type=int, 84 | help='Random seed for numpy') 85 | parser.add_argument('--args_data', type=str, 86 | help='Pickled data for stub objects') 87 | parser.add_argument('--snapshot_mode', type=str, default='all', 88 | help='Mode to save the snapshot.
Can be either "all" ' 89 | '(all iterations will be saved), "last" (only ' 90 | 'the last iteration will be saved), or "none" ' 91 | '(do not save snapshots)') 92 | parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False, 93 | help='Whether to only print the tabular log information (in a horizontal format)') 94 | 95 | 96 | args = parser.parse_args() 97 | 98 | parallel_sampler.initialize(n_parallel=args.sampler_workers) 99 | 100 | if args.seed is not None: 101 | set_seed(args.seed) 102 | parallel_sampler.set_seed(args.seed) 103 | 104 | args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(','))) 105 | 106 | centralized = True if args.control == 'centralized' else False 107 | 108 | 109 | 110 | 111 | env = ContinuousHostageWorld(args.n_good, args.n_hostage, args.n_bad, args.n_coop_save, 112 | args.n_coop_avoid, n_sensors=args.n_sensors, 113 | sensor_range=args.sensor_range, save_reward=args.save_reward, 114 | hit_reward=args.hit_reward, encounter_reward=args.encounter_reward, 115 | bomb_reward=args.bomb_reward) 116 | 117 | env = RLLabEnv(StandardizedEnv(env), mode=args.control) 118 | 119 | if args.buffer_size > 1: 120 | env = ObservationBuffer(env, args.buffer_size) 121 | 122 | if args.recurrent: 123 | policy = GaussianGRUPolicy(env_spec=env.spec, hidden_sizes=args.hidden_sizes) 124 | else: 125 | policy = GaussianMLPPolicy(env_spec=env.spec, hidden_sizes=args.hidden_sizes) 126 | 127 | if args.baseline_type == 'linear': 128 | baseline = LinearFeatureBaseline(env_spec=env.spec) 129 | else: 130 | baseline = ZeroBaseline(env_spec=env.spec) 131 | 132 | # logger 133 | default_log_dir = config.LOG_DIR 134 | if args.log_dir is None: 135 | log_dir = osp.join(default_log_dir, args.exp_name) 136 | else: 137 | log_dir = args.log_dir 138 | tabular_log_file = osp.join(log_dir, args.tabular_log_file) 139 | text_log_file = osp.join(log_dir, args.text_log_file) 140 | params_log_file = osp.join(log_dir, args.params_log_file) 141 | 142 | logger.log_parameters_lite(params_log_file, args) 143 | logger.add_text_output(text_log_file) 144 | logger.add_tabular_output(tabular_log_file) 145 | prev_snapshot_dir = logger.get_snapshot_dir() 146 | prev_mode = logger.get_snapshot_mode() 147 | logger.set_snapshot_dir(log_dir) 148 | logger.set_snapshot_mode(args.snapshot_mode) 149 | logger.set_log_tabular_only(args.log_tabular_only) 150 | logger.push_prefix("[%s] " % args.exp_name) 151 | 152 | algo = TRPO(env=env, 153 | policy=policy, 154 | baseline=baseline, 155 | batch_size=args.n_timesteps, 156 | max_path_length=args.max_traj_len, 157 | n_itr=args.n_iter, 158 | discount=args.discount, 159 | step_size=args.max_kl, 160 | mode=args.control,) 161 | 162 | algo.train() 163 | 164 | 165 | 166 | if __name__ == '__main__': 167 | main() 168 | -------------------------------------------------------------------------------- /runners/old/rllab/run_walker.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, print_function 2 | 3 | import argparse 4 | import json 5 | import uuid 6 | import datetime 7 | import dateutil.tz 8 | import os.path as osp 9 | import ast 10 | import joblib 11 | 12 | import gym 13 | import numpy as np 14 | import tensorflow as tf 15 | from gym import spaces 16 | 17 | from madrl_environments.walker.multi_walker import MultiWalkerEnv 18 | from madrl_environments import StandardizedEnv, 
ObservationBuffer 19 | from rllabwrapper import RLLabEnv 20 | 21 | from sandbox.rocky.tf.algos.trpo import TRPO 22 | from sandbox.rocky.tf.envs.base import TfEnv 23 | from sandbox.rocky.tf.core.network import MLP 24 | from sandbox.rocky.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy 25 | from sandbox.rocky.tf.policies.gaussian_gru_policy import GaussianGRUPolicy 26 | from sandbox.rocky.tf.policies.gaussian_lstm_policy import GaussianLSTMPolicy 27 | from sandbox.rocky.tf.optimizers.conjugate_gradient_optimizer import ConjugateGradientOptimizer, FiniteDifferenceHvp 28 | 29 | # from rllab.algos.trpo import TRPO 30 | from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline 31 | # from rllab.baselines.gaussian_mlp_baseline import GaussianMLPBaseline 32 | from rllab.baselines.zero_baseline import ZeroBaseline 33 | # from rllab.envs.normalized_env import normalize 34 | # from rllab.policies.gaussian_mlp_policy import GaussianMLPPolicy 35 | # from rllab.policies.gaussian_gru_policy import GaussianGRUPolicy 36 | from rllab.sampler import parallel_sampler 37 | import rllab.misc.logger as logger 38 | from rllab.misc.ext import set_seed 39 | from rllab import config 40 | 41 | 42 | def main(): 43 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 44 | rand_id = str(uuid.uuid4())[:5] 45 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z') 46 | default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id) 47 | 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument('--exp_name', type=str, default=default_exp_name, 50 | help='Name of the experiment.') 51 | 52 | parser.add_argument('--discount', type=float, default=0.99) 53 | parser.add_argument('--gae_lambda', type=float, default=1.0) 54 | parser.add_argument('--reward_scale', type=float, default=1.0) 55 | 56 | parser.add_argument('--n_iter', type=int, default=250) 57 | parser.add_argument('--sampler_workers', type=int, default=1) 58 | parser.add_argument('--max_traj_len', type=int, default=500) 59 | parser.add_argument('--update_curriculum', action='store_true', default=False) 60 | 61 | parser.add_argument('--n_timesteps', type=int, default=10000) 62 | 63 | parser.add_argument('--control', type=str, default='centralized') 64 | 65 | parser.add_argument('--n_walkers', type=int, default=2) 66 | parser.add_argument('--reward_mech', type=str, default='local') 67 | 68 | parser.add_argument('--recurrent', type=str, default=None) 69 | parser.add_argument('--baseline_type', type=str, default='linear') 70 | parser.add_argument('--policy_hidden_sizes', type=str, default='128,128') 71 | parser.add_argument('--baseline_hidden_sizes', type=str, default='128,128') 72 | 73 | parser.add_argument('--max_kl', type=float, default=0.01) 74 | 75 | parser.add_argument('--log_dir', type=str, required=False) 76 | parser.add_argument('--tabular_log_file', type=str, default='progress.csv', 77 | help='Name of the tabular log file (in csv).') 78 | parser.add_argument('--text_log_file', type=str, default='debug.log', 79 | help='Name of the text log file (in pure text).') 80 | parser.add_argument('--params_log_file', type=str, default='params.json', 81 | help='Name of the parameter log file (in json).') 82 | parser.add_argument('--seed', type=int, help='Random seed for numpy') 83 | parser.add_argument('--args_data', type=str, help='Pickled data for stub objects') 84 | parser.add_argument('--snapshot_mode', type=str, default='all', 85 | help='Mode to save the snapshot. 
Can be either "all" ' 86 | '(all iterations will be saved), "last" (only ' 87 | 'the last iteration will be saved), or "none" ' 88 | '(do not save snapshots)') 89 | parser.add_argument( 90 | '--log_tabular_only', type=ast.literal_eval, default=False, 91 | help='Whether to only print the tabular log information (in a horizontal format)') 92 | parser.add_argument('--checkpoint', type=str, default=None) 93 | 94 | args = parser.parse_args() 95 | 96 | parallel_sampler.initialize(n_parallel=args.sampler_workers) 97 | 98 | if args.seed is not None: 99 | set_seed(args.seed) 100 | parallel_sampler.set_seed(args.seed) 101 | 102 | args.hidden_sizes = tuple(map(int, args.policy_hidden_sizes.split(','))) 103 | 104 | centralized = True if args.control == 'centralized' else False 105 | 106 | with tf.Session() as sess: 107 | if args.checkpoint: 108 | data = joblib.load(args.checkpoint) 109 | policy = data['policy'] 110 | env = data['env'] 111 | 112 | import IPython 113 | IPython.embed() 114 | else: 115 | env = MultiWalkerEnv(n_walkers=args.n_walkers) 116 | 117 | env = TfEnv( 118 | RLLabEnv( 119 | StandardizedEnv(env, scale_reward=args.reward_scale, enable_obsnorm=True), 120 | mode=args.control)) 121 | 122 | if args.recurrent: 123 | feature_network = MLP( 124 | name='feature_net', 125 | input_shape=(env.spec.observation_space.flat_dim + env.spec.action_space.flat_dim,), 126 | output_dim=4, hidden_sizes=(128,64,32), hidden_nonlinearity=tf.nn.tanh, 127 | output_nonlinearity=None) 128 | if args.recurrent == 'gru': 129 | policy = GaussianGRUPolicy(env_spec=env.spec, feature_network=feature_network, 130 | hidden_dim=int(args.policy_hidden_sizes), name='policy') 131 | elif args.recurrent == 'lstm': 132 | policy = GaussianLSTMPolicy(env_spec=env.spec, feature_network=feature_network, 133 | hidden_dim=int(args.policy_hidden_sizes), name='policy') 134 | else: 135 | policy = GaussianMLPPolicy(name='policy', 136 | env_spec=env.spec, hidden_sizes=tuple(map(int, args.policy_hidden_sizes.split(',')))) 137 | 138 | if args.baseline_type == 'linear': 139 | baseline = LinearFeatureBaseline(env_spec=env.spec) 140 | elif args.baseline_type == 'mlp': 141 | raise NotImplementedError() 142 | # baseline = GaussianMLPBaseline( 143 | # env_spec=env.spec, hidden_sizes=tuple(map(int, args.baseline_hidden_sizes.split(',')))) 144 | else: 145 | baseline = ZeroBaseline(env_spec=env.spec) 146 | 147 | # logger 148 | default_log_dir = config.LOG_DIR 149 | if args.log_dir is None: 150 | log_dir = osp.join(default_log_dir, args.exp_name) 151 | else: 152 | log_dir = args.log_dir 153 | tabular_log_file = osp.join(log_dir, args.tabular_log_file) 154 | text_log_file = osp.join(log_dir, args.text_log_file) 155 | params_log_file = osp.join(log_dir, args.params_log_file) 156 | 157 | logger.log_parameters_lite(params_log_file, args) 158 | logger.add_text_output(text_log_file) 159 | logger.add_tabular_output(tabular_log_file) 160 | prev_snapshot_dir = logger.get_snapshot_dir() 161 | prev_mode = logger.get_snapshot_mode() 162 | logger.set_snapshot_dir(log_dir) 163 | logger.set_snapshot_mode(args.snapshot_mode) 164 | logger.set_log_tabular_only(args.log_tabular_only) 165 | logger.push_prefix("[%s] " % args.exp_name) 166 | 167 | algo = TRPO( 168 | env=env, 169 | policy=policy, 170 | baseline=baseline, 171 | batch_size=args.n_timesteps, 172 | max_path_length=args.max_traj_len, 173 | n_itr=args.n_iter, 174 | discount=args.discount, 175 | gae_lambda=args.gae_lambda, 176 | step_size=args.max_kl, 177 | 
optimizer=ConjugateGradientOptimizer(hvp_approach=FiniteDifferenceHvp(base_eps=1e-5)) if 178 | args.recurrent else None, 179 | mode=args.control,) 180 | 181 | algo.train() 182 | 183 | 184 | if __name__ == '__main__': 185 | main() 186 | -------------------------------------------------------------------------------- /runners/old/rltools/__init__.py: -------------------------------------------------------------------------------- 1 | import runners.archs 2 | 3 | 4 | def get_arch(name): 5 | constructor = getattr(runners.archs, name) 6 | return constructor 7 | -------------------------------------------------------------------------------- /runners/old/rltools/pursuit.sh: -------------------------------------------------------------------------------- 1 | python run_pursuit.py --log pursuit_med_32_local_catchr0.h5 --n_iter 500 --n_evaders 30 --n_pursuers 30 --control decentralized --n_timesteps 10000 --obs_range 11 --sample_maps --sampler parallel --map_file ../maps/map_pool32.npy --policy_hidden_spec MED_POLICY_ARCH --baseline_hidden_spec MED_POLICY_ARCH --sampler_workers 6 --flatten --reward_mech local --surround --cur_remove 15 --cur_shaping 250 --update_curriculum --load_checkpoint pursuit_med_32_local_30_catchr0.h5/snapshots/iter0000025 --catchr 0.0 --term_pursuit 5.0 --max_traj_len 1000 2 | -------------------------------------------------------------------------------- /runners/old/rltools/run_con_hostage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: run_con_hostage.py 4 | # 5 | # Created: Sunday, August 14 2016 by rejuvyesh 6 | # 7 | from __future__ import absolute_import, print_function 8 | 9 | import argparse 10 | import json 11 | 12 | import numpy as np 13 | import tensorflow as tf 14 | 15 | from gym import spaces 16 | import rltools.algos.policyopt 17 | import rltools.log 18 | import rltools.util 19 | from rltools.samplers.serial import SimpleSampler, ImportanceWeightedSampler, DecSampler 20 | from rltools.samplers.parallel import ThreadedSampler, ParallelSampler 21 | from madrl_environments import ObservationBuffer 22 | from madrl_environments.hostage import ContinuousHostageWorld 23 | from rltools.baselines.linear import LinearFeatureBaseline 24 | from rltools.baselines.mlp import MLPBaseline 25 | from rltools.baselines.zero import ZeroBaseline 26 | from rltools.policy.gaussian import GaussianMLPPolicy 27 | 28 | from runners.archs import * 29 | 30 | 31 | def main(): 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--discount', type=float, default=0.95) 34 | parser.add_argument('--gae_lambda', type=float, default=0.99) 35 | 36 | parser.add_argument('--interp_alpha', type=float, default=0.5) 37 | parser.add_argument('--policy_avg_weights', type=str, default='0.3333333,0.3333333,0.3333333') 38 | 39 | parser.add_argument('--n_iter', type=int, default=250) 40 | parser.add_argument('--sampler', type=str, default='simple') 41 | parser.add_argument('--sampler_workers', type=int, default=4) 42 | parser.add_argument('--max_traj_len', type=int, default=500) 43 | parser.add_argument('--adaptive_batch', action='store_true', default=False) 44 | 45 | parser.add_argument('--n_timesteps', type=int, default=8000) 46 | parser.add_argument('--n_timesteps_min', type=int, default=1000) 47 | parser.add_argument('--n_timesteps_max', type=int, default=64000) 48 | parser.add_argument('--timestep_rate', type=int, default=20) 49 | 50 | parser.add_argument('--is_n_backtrack', type=int, default=1) 51 | 
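# NOTE: the is_* options below presumably belong to the ImportanceWeightedSampler imported at the top of this file; main() below only wires up the 'parallel' sampler, so these flags are parsed but otherwise unused here.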
parser.add_argument('--is_randomize_draw', action='store_true', default=False) 52 | parser.add_argument('--is_n_pretrain', type=int, default=0) 53 | parser.add_argument('--is_skip_is', action='store_true', default=False) 54 | parser.add_argument('--is_max_is_ratio', type=float, default=0) 55 | 56 | parser.add_argument('--buffer_size', type=int, default=1) 57 | parser.add_argument('--n_good', type=int, default=3) 58 | parser.add_argument('--n_hostage', type=int, default=5) 59 | parser.add_argument('--n_bad', type=int, default=5) 60 | parser.add_argument('--n_coop_save', type=int, default=2) 61 | parser.add_argument('--n_coop_avoid', type=int, default=2) 62 | parser.add_argument('--n_sensors', type=int, default=20) 63 | parser.add_argument('--sensor_range', type=float, default=0.2) 64 | parser.add_argument('--save_reward', type=float, default=3) 65 | parser.add_argument('--hit_reward', type=float, default=-1) 66 | parser.add_argument('--encounter_reward', type=float, default=0.01) 67 | parser.add_argument('--bomb_reward', type=float, default=-10.) 68 | 69 | parser.add_argument('--policy_hidden_spec', type=str, default=GAE_ARCH) 70 | parser.add_argument('--min_std', type=float, default=0) 71 | parser.add_argument('--blend_freq', type=int, default=20) 72 | 73 | parser.add_argument('--baseline_type', type=str, default='mlp') 74 | parser.add_argument('--baseline_hidden_spec', type=str, default=GAE_ARCH) 75 | 76 | parser.add_argument('--max_kl', type=float, default=0.01) 77 | parser.add_argument('--vf_max_kl', type=float, default=0.01) 78 | parser.add_argument('--vf_cg_damping', type=float, default=0.01) 79 | 80 | parser.add_argument('--save_freq', type=int, default=20) 81 | parser.add_argument('--log', type=str, required=False) 82 | parser.add_argument('--tblog', type=str, default='/tmp/madrl_tb') 83 | parser.add_argument('--debug', dest='debug', action='store_true') 84 | parser.add_argument('--no-debug', dest='debug', action='store_false') 85 | parser.set_defaults(debug=True) 86 | 87 | args = parser.parse_args() 88 | 89 | policy_avg_weights = np.array(map(float, args.policy_avg_weights.split(','))) 90 | assert len(policy_avg_weights) == args.n_good 91 | 92 | env = ContinuousHostageWorld(args.n_good, args.n_hostage, args.n_bad, args.n_coop_save, 93 | args.n_coop_avoid, n_sensors=args.n_sensors, 94 | sensor_range=args.sensor_range, save_reward=args.save_reward, 95 | hit_reward=args.hit_reward, encounter_reward=args.encounter_reward, 96 | bomb_reward=args.bomb_reward) 97 | 98 | if args.buffer_size > 1: 99 | env = ObservationBuffer(env, args.buffer_size) 100 | 101 | policies = [GaussianMLPPolicy(agent.observation_space, agent.action_space, 102 | hidden_spec=args.policy_hidden_spec, enable_obsnorm=True, 103 | min_stdev=args.min_std, init_logstdev=0., tblog=args.tblog, 104 | varscope_name='gaussmlp_policy_{}'.format(agid)) 105 | for agid, agent in enumerate(env.agents)] 106 | if args.blend_freq: 107 | assert all( 108 | [agent.observation_space == env.agents[0].observation_space for agent in env.agents]) 109 | target_policy = GaussianMLPPolicy(env.agents[0].observation_space, 110 | env.agents[0].action_space, 111 | hidden_spec=args.policy_hidden_spec, enable_obsnorm=True, 112 | min_stdev=0., init_logstdev=0., tblog=args.tblog, 113 | varscope_name='targetgaussmlp_policy') 114 | else: 115 | target_policy = None 116 | 117 | if args.baseline_type == 'linear': 118 | baselines = [LinearFeatureBaseline(agent.observation_space, enable_obsnorm=True, 119 | varscope_name='linear_baseline_{}'.format(agid)) 120 | 
for agid, agent in enumerate(env.agents)] 121 | elif args.baseline_type == 'mlp': 122 | baselines = [MLPBaseline(agent.observation_space, args.baseline_hidden_spec, 123 | enable_obsnorm=True, enable_vnorm=True, max_kl=args.vf_max_kl, 124 | damping=args.vf_cg_damping, time_scale=1. / args.max_traj_len, 125 | varscope_name='mlp_baseline_{}'.format(agid)) 126 | for agid, agent in enumerate(env.agents)] 127 | else: 128 | baselines = [ZeroBaseline(agent.observation_space) for agent in env.agents] 129 | 130 | if args.sampler == 'parallel': 131 | sampler_cls = ParallelSampler 132 | sampler_args = dict(max_traj_len=args.max_traj_len, n_timesteps=args.n_timesteps, 133 | n_timesteps_min=args.n_timesteps_min, 134 | n_timesteps_max=args.n_timesteps_max, timestep_rate=args.timestep_rate, 135 | adaptive=args.adaptive_batch, enable_rewnorm=True, 136 | n_workers=args.sampler_workers, mode='concurrent') 137 | 138 | else: 139 | raise NotImplementedError() 140 | 141 | step_func = rltools.algos.policyopt.TRPO(max_kl=args.max_kl) 142 | 143 | popt = rltools.algos.policyopt.ConcurrentPolicyOptimizer( 144 | env=env, policies=policies, baselines=baselines, step_func=step_func, 145 | discount=args.discount, gae_lambda=args.gae_lambda, sampler_cls=sampler_cls, 146 | sampler_args=sampler_args, n_iter=args.n_iter, target_policy=target_policy, 147 | weights=policy_avg_weights, interp_alpha=args.interp_alpha) 148 | 149 | argstr = json.dumps(vars(args), separators=(',', ':'), indent=2) 150 | rltools.util.header(argstr) 151 | log_f = rltools.log.TrainingLog(args.log, [('args', argstr)], debug=args.debug) 152 | 153 | with tf.Session() as sess: 154 | sess.run(tf.initialize_all_variables()) 155 | popt.train(sess, log_f, args.blend_freq, args.save_freq) 156 | 157 | 158 | if __name__ == '__main__': 159 | main() 160 | -------------------------------------------------------------------------------- /runners/old/rltools/run_con_waterworld.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: run_con_waterworld.py 4 | # 5 | # Created: Friday, August 12 2016 by rejuvyesh 6 | # 7 | from __future__ import absolute_import, print_function 8 | 9 | import argparse 10 | import json 11 | 12 | import numpy as np 13 | import tensorflow as tf 14 | 15 | from gym import spaces 16 | import rltools.algos.policyopt 17 | import rltools.log 18 | import rltools.util 19 | from rltools.samplers.serial import SimpleSampler, ImportanceWeightedSampler, DecSampler 20 | from rltools.samplers.parallel import ThreadedSampler, ParallelSampler 21 | from madrl_environments import ObservationBuffer 22 | from madrl_environments.pursuit import MAWaterWorld 23 | from rltools.baselines.linear import LinearFeatureBaseline 24 | from rltools.baselines.mlp import MLPBaseline 25 | from rltools.baselines.zero import ZeroBaseline 26 | from rltools.policy.gaussian import GaussianMLPPolicy 27 | 28 | from runners.archs import * 29 | 30 | 31 | def main(): 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--discount', type=float, default=0.95) 34 | parser.add_argument('--gae_lambda', type=float, default=0.99) 35 | 36 | parser.add_argument('--interp_alpha', type=float, default=0.5) 37 | parser.add_argument('--policy_avg_weights', type=str, default='0.33,0.33,0.33') 38 | 39 | parser.add_argument('--n_iter', type=int, default=250) 40 | parser.add_argument('--sampler', type=str, default='simple') 41 | parser.add_argument('--sampler_workers', type=int, default=4) 42 | 
parser.add_argument('--max_traj_len', type=int, default=500) 43 | parser.add_argument('--adaptive_batch', action='store_true', default=False) 44 | 45 | parser.add_argument('--n_timesteps', type=int, default=8000) 46 | parser.add_argument('--n_timesteps_min', type=int, default=1000) 47 | parser.add_argument('--n_timesteps_max', type=int, default=64000) 48 | parser.add_argument('--timestep_rate', type=int, default=20) 49 | 50 | parser.add_argument('--is_n_backtrack', type=int, default=1) 51 | parser.add_argument('--is_randomize_draw', action='store_true', default=False) 52 | parser.add_argument('--is_n_pretrain', type=int, default=0) 53 | parser.add_argument('--is_skip_is', action='store_true', default=False) 54 | parser.add_argument('--is_max_is_ratio', type=float, default=0) 55 | 56 | parser.add_argument('--buffer_size', type=int, default=1) 57 | parser.add_argument('--n_evaders', type=int, default=5) 58 | parser.add_argument('--n_pursuers', type=int, default=3) 59 | parser.add_argument('--n_poison', type=int, default=10) 60 | parser.add_argument('--n_coop', type=int, default=2) 61 | parser.add_argument('--n_sensors', type=int, default=30) 62 | parser.add_argument('--sensor_range', type=str, default='0.2,0.2,0.2') 63 | parser.add_argument('--food_reward', type=float, default=3) 64 | parser.add_argument('--poison_reward', type=float, default=-1) 65 | parser.add_argument('--encounter_reward', type=float, default=0.05) 66 | 67 | parser.add_argument('--policy_hidden_spec', type=str, default=GAE_ARCH) 68 | parser.add_argument('--blend_freq', type=int, default=20) 69 | 70 | parser.add_argument('--baseline_type', type=str, default='mlp') 71 | parser.add_argument('--baseline_hidden_spec', type=str, default=GAE_ARCH) 72 | 73 | parser.add_argument('--max_kl', type=float, default=0.01) 74 | parser.add_argument('--vf_max_kl', type=float, default=0.01) 75 | parser.add_argument('--vf_cg_damping', type=float, default=0.01) 76 | 77 | parser.add_argument('--save_freq', type=int, default=20) 78 | parser.add_argument('--log', type=str, required=False) 79 | parser.add_argument('--tblog', type=str, default='/tmp/madrl_tb') 80 | parser.add_argument('--debug', dest='debug', action='store_true') 81 | parser.add_argument('--no-debug', dest='debug', action='store_false') 82 | parser.set_defaults(debug=True) 83 | 84 | args = parser.parse_args() 85 | 86 | sensor_range = np.array(map(float, args.sensor_range.split(','))) 87 | assert sensor_range.shape == (args.n_pursuers,) 88 | 89 | policy_avg_weights = np.array(map(float, args.policy_avg_weights.split(','))) 90 | assert len(policy_avg_weights) == args.n_pursuers 91 | 92 | env = MAWaterWorld(args.n_pursuers, args.n_evaders, args.n_coop, args.n_poison, 93 | n_sensors=args.n_sensors, food_reward=args.food_reward, 94 | poison_reward=args.poison_reward, encounter_reward=args.encounter_reward, 95 | sensor_range=sensor_range, obstacle_loc=None) 96 | 97 | if args.buffer_size > 1: 98 | env = ObservationBuffer(env, args.buffer_size) 99 | 100 | policies = [GaussianMLPPolicy(agent.observation_space, agent.action_space, 101 | hidden_spec=args.policy_hidden_spec, enable_obsnorm=True, 102 | min_stdev=0., init_logstdev=0., tblog=args.tblog, 103 | varscope_name='gaussmlp_policy_{}'.format(agid)) 104 | for agid, agent in enumerate(env.agents)] 105 | 106 | if args.blend_freq: 107 | assert all( 108 | [agent.observation_space == env.agents[0].observation_space for agent in env.agents]) 109 | target_policy = GaussianMLPPolicy(env.agents[0].observation_space, 110 | 
env.agents[0].action_space, 111 | hidden_spec=args.policy_hidden_spec, enable_obsnorm=True, 112 | min_stdev=0., init_logstdev=0., tblog=args.tblog, 113 | varscope_name='targetgaussmlp_policy') 114 | else: 115 | target_policy = None 116 | 117 | if args.baseline_type == 'linear': 118 | baselines = [LinearFeatureBaseline(agent.observation_space, enable_obsnorm=True, 119 | varscope_name='linear_baseline_{}'.format(agid)) 120 | for agid, agent in enumerate(env.agents)] 121 | elif args.baseline_type == 'mlp': 122 | baselines = [MLPBaseline(agent.observation_space, args.baseline_hidden_spec, 123 | enable_obsnorm=True, enable_vnorm=True, max_kl=args.vf_max_kl, 124 | damping=args.vf_cg_damping, time_scale=1. / args.max_traj_len, 125 | varscope_name='mlp_baseline_{}'.format(agid)) 126 | for agid, agent in enumerate(env.agents)] 127 | else: 128 | baselines = [ZeroBaseline(agent.observation_space) for agent in env.agents] 129 | 130 | if args.sampler == 'parallel': 131 | sampler_cls = ParallelSampler 132 | sampler_args = dict(max_traj_len=args.max_traj_len, n_timesteps=args.n_timesteps, 133 | n_timesteps_min=args.n_timesteps_min, 134 | n_timesteps_max=args.n_timesteps_max, timestep_rate=args.timestep_rate, 135 | adaptive=args.adaptive_batch, enable_rewnorm=True, 136 | n_workers=args.sampler_workers, mode='concurrent') 137 | 138 | else: 139 | raise NotImplementedError() 140 | 141 | step_func = rltools.algos.policyopt.TRPO(max_kl=args.max_kl) 142 | 143 | popt = rltools.algos.policyopt.ConcurrentPolicyOptimizer( 144 | env=env, policies=policies, baselines=baselines, step_func=step_func, 145 | discount=args.discount, gae_lambda=args.gae_lambda, sampler_cls=sampler_cls, 146 | sampler_args=sampler_args, n_iter=args.n_iter, target_policy=target_policy, 147 | weights=policy_avg_weights, interp_alpha=args.interp_alpha) 148 | 149 | argstr = json.dumps(vars(args), separators=(',', ':'), indent=2) 150 | rltools.util.header(argstr) 151 | log_f = rltools.log.TrainingLog(args.log, [('args', argstr)], debug=args.debug) 152 | 153 | with tf.Session() as sess: 154 | sess.run(tf.initialize_all_variables()) 155 | popt.train(sess, log_f, args.blend_freq, args.save_freq) 156 | 157 | 158 | if __name__ == '__main__': 159 | main() 160 | -------------------------------------------------------------------------------- /runners/run_hostage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: run_hostage.py 4 | # 5 | # Created: Friday, September 2 2016 by rejuvyesh 6 | # 7 | from runners import RunnerParser, comma_sep_ints 8 | 9 | from madrl_environments.hostage import ContinuousHostageWorld 10 | from madrl_environments import StandardizedEnv, ObservationBuffer 11 | 12 | # yapf: disable 13 | ENV_OPTIONS = [ 14 | ('n_good', int, 8, ''), 15 | ('n_hostages', int, 16, ''), 16 | ('n_bad', int, 16, ''), 17 | ('n_coop_save', int, 4, ''), 18 | ('n_coop_avoid', int, 2, ''), 19 | ('radius', float, 0.015, ''), 20 | ('bomb_radius', float, 0.03, ''), 21 | ('key_loc', comma_sep_ints, None, ''), 22 | ('bad_speed', float, 0.01, ''), 23 | ('n_sensors', int, 30, ''), 24 | ('sensor_range', float, 0.2, ''), 25 | ('save_reward', float, 10, ''), 26 | ('hit_reward', float, -1, ''), 27 | ('encounter_reward', float, 0.01, ''), 28 | ('not_saved_reward', float, -3, ''), 29 | ('bomb_reward', float, -20, ''), 30 | ('control_penalty', float, -0.1, ''), 31 | ('reward_mech', str, 'local', ''), 32 | ('buffer_size', int, 1, ''), 33 | ] 34 | # yapf: enable 35 | 36 | def main(parser): 37 | 
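# RunnerParser (instantiated with ENV_OPTIONS at the bottom of this file) appears to expose the chosen training backend ('rllab' or 'rltools') as parser._mode and the parsed options as parser.args, which are used below to construct the environment and pick a runner.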
mode = parser._mode 38 | args = parser.args 39 | 40 | env = ContinuousHostageWorld(n_good=args.n_good, 41 | n_hostages=args.n_hostages, 42 | n_bad=args.n_bad, 43 | n_coop_save=args.n_coop_save, 44 | n_coop_avoid=args.n_coop_avoid, 45 | radius=args.radius, 46 | key_loc=args.key_loc, 47 | bad_speed=args.bad_speed, 48 | n_sensors=args.n_sensors, 49 | sensor_range=args.sensor_range, 50 | save_reward=args.save_reward, 51 | hit_reward=args.hit_reward, 52 | encounter_reward=args.encounter_reward, 53 | bomb_reward=args.bomb_reward, 54 | bomb_radius=args.bomb_radius, 55 | control_penalty=args.control_penalty, 56 | reward_mech=args.reward_mech,) 57 | 58 | if args.buffer_size > 1: 59 | env = ObservationBuffer(env, args.buffer_size) 60 | 61 | if mode == 'rllab': 62 | from runners.rurllab import RLLabRunner 63 | run = RLLabRunner(env, args) 64 | elif mode == 'rltools': 65 | from runners.rurltools import RLToolsRunner 66 | run = RLToolsRunner(env, args) 67 | else: 68 | raise NotImplementedError() 69 | 70 | run() 71 | 72 | 73 | if __name__ == '__main__': 74 | main(RunnerParser(ENV_OPTIONS)) 75 | -------------------------------------------------------------------------------- /runners/run_multiant.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: run_multiwalker.py 4 | # 5 | # Created: Friday, September 2 2016 by rejuvyesh 6 | # 7 | from runners import RunnerParser 8 | from runners.curriculum import Curriculum 9 | 10 | from madrl_environments.walker.multi_walker import MultiWalkerEnv 11 | from madrl_environments.mujoco.ant.multi_ant import MultiAnt 12 | from madrl_environments import StandardizedEnv, ObservationBuffer 13 | 14 | # yapf: disable 15 | ENV_OPTIONS = [ 16 | ('n_legs', int, 4, ''), 17 | ('ts', float, 0.02, ''), 18 | ('integrator', str, 'RK4', ''), 19 | ('leg_length', float, 0.282, ''), 20 | ('out_file', str, 'multi_ant.xml', ''), 21 | ('base_file', str, 'ant_og.xml', ''), 22 | ('reward_mech', str, 'local', ''), 23 | ('buffer_size', int, 1, ''), 24 | ('curriculum', str, None, ''), 25 | ] 26 | # yapf: enable 27 | 28 | def main(parser): 29 | mode = parser._mode 30 | args = parser.args 31 | 32 | env = MultiAnt(n_legs=args.n_legs, ts=args.ts, integrator=args.integrator, 33 | leg_length=args.leg_length, out_file=args.out_file, 34 | base_file=args.base_file, reward_mech=args.reward_mech) 35 | 36 | if args.buffer_size > 1: 37 | env = ObservationBuffer(env, args.buffer_size) 38 | 39 | if mode == 'rllab': 40 | from runners.rurllab import RLLabRunner 41 | run = RLLabRunner(env, args) 42 | elif mode == 'rltools': 43 | from runners.rurltools import RLToolsRunner 44 | run = RLToolsRunner(env, args) 45 | else: 46 | raise NotImplementedError() 47 | 48 | if args.curriculum: 49 | curr = Curriculum(args.curriculum) 50 | run(curr) 51 | else: 52 | run() 53 | 54 | if __name__ == '__main__': 55 | main(RunnerParser(ENV_OPTIONS)) 56 | -------------------------------------------------------------------------------- /runners/run_multiwalker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: run_multiwalker.py 4 | # 5 | # Created: Friday, September 2 2016 by rejuvyesh 6 | # 7 | from runners import RunnerParser 8 | from runners.curriculum import Curriculum 9 | 10 | from madrl_environments.walker.multi_walker import MultiWalkerEnv 11 | from madrl_environments import StandardizedEnv, ObservationBuffer 12 | 13 | # yapf: disable 14 | ENV_OPTIONS = [ 15 | ('n_walkers', int, 
2, ''), 16 | ('position_noise', float, 1e-3, ''), 17 | ('angle_noise', float, 1e-3, ''), 18 | ('reward_mech', str, 'local', ''), 19 | ('forward_reward', float, 1.0, ''), 20 | ('fall_reward', float, -100.0, ''), 21 | ('drop_reward', float, -100.0, ''), 22 | ('terminate_on_fall', int, 1, ''), 23 | ('buffer_size', int, 1, ''), 24 | ('one_hot', int, 0, ''), 25 | ('curriculum', str, None, ''), 26 | ] 27 | # yapf: enable 28 | 29 | 30 | def main(parser): 31 | mode = parser._mode 32 | args = parser.args 33 | env_config = dict(n_walkers=args.n_walkers, position_noise=args.position_noise, 34 | angle_noise=args.angle_noise, reward_mech=args.reward_mech, 35 | forward_reward=args.forward_reward, fall_reward=args.fall_reward, 36 | drop_reward=args.drop_reward, terminate_on_fall=bool(args.terminate_on_fall), 37 | one_hot=bool(args.one_hot)) 38 | env = MultiWalkerEnv(**env_config) 39 | if args.buffer_size > 1: 40 | env = ObservationBuffer(env, args.buffer_size) 41 | 42 | if mode == 'rllab': 43 | from runners.rurllab import RLLabRunner 44 | run = RLLabRunner(env, args) 45 | elif mode == 'rltools': 46 | from runners.rurltools import RLToolsRunner 47 | run = RLToolsRunner(env, args) 48 | else: 49 | raise NotImplementedError() 50 | 51 | if args.curriculum: 52 | curr = Curriculum(args.curriculum) 53 | run(curr) 54 | else: 55 | run() 56 | 57 | 58 | if __name__ == '__main__': 59 | main(RunnerParser(ENV_OPTIONS)) 60 | -------------------------------------------------------------------------------- /runners/run_pursuit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: run_multiwalker.py 4 | # 5 | # Created: Friday, September 2 2016 by rejuvyesh 6 | # 7 | import numpy as np 8 | from runners import RunnerParser 9 | 10 | from madrl_environments.pursuit import PursuitEvade 11 | from madrl_environments.pursuit.utils import TwoDMaps 12 | from madrl_environments import StandardizedEnv, ObservationBuffer 13 | 14 | # yapf: disable 15 | ENV_OPTIONS = [ 16 | ('n_evaders', int, 2, ''), 17 | ('n_pursuers', int, 2, ''), 18 | ('obs_range', int, 3, ''), 19 | ('map_size', str, '10,10', ''), 20 | ('map_type', str, 'rectangle', ''), 21 | ('n_catch', int, 2, ''), 22 | ('urgency', float, 0.0, ''), 23 | ('surround', int, 1, ''), 24 | ('map_file', str, None, ''), 25 | ('sample_maps', int, 0, ''), 26 | ('flatten', int, 1, ''), 27 | ('reward_mech', str, 'local', ''), 28 | ('catchr', float, 0.1, ''), 29 | ('term_pursuit', float, 5.0, ''), 30 | ('buffer_size', int, 1, ''), 31 | ('noid', str, None, ''), 32 | ] 33 | # yapf: enable 34 | 35 | 36 | def main(parser): 37 | mode = parser._mode 38 | args = parser.args 39 | 40 | if args.map_file: 41 | map_pool = np.load(args.map_file) 42 | else: 43 | if args.map_type == 'rectangle': 44 | env_map = TwoDMaps.rectangle_map(*map(int, args.map_size.split(','))) 45 | elif args.map_type == 'complex': 46 | env_map = TwoDMaps.complex_map(*map(int, args.map_size.split(','))) 47 | else: 48 | raise NotImplementedError() 49 | map_pool = [env_map] 50 | 51 | env = PursuitEvade(map_pool, n_evaders=args.n_evaders, n_pursuers=args.n_pursuers, 52 | obs_range=args.obs_range, n_catch=args.n_catch, urgency_reward=args.urgency, 53 | surround=bool(args.surround), sample_maps=bool(args.sample_maps), 54 | flatten=bool(args.flatten), reward_mech=args.reward_mech, catchr=args.catchr, 55 | term_pursuit=args.term_pursuit, include_id=not bool(args.noid)) 56 | 57 | if args.buffer_size > 1: 58 | env = ObservationBuffer(env, args.buffer_size) 59 | 60 | if 
mode == 'rllab': 61 | from runners.rurllab import RLLabRunner 62 | run = RLLabRunner(env, args) 63 | elif mode == 'rltools': 64 | from runners.rurltools import RLToolsRunner 65 | run = RLToolsRunner(env, args) 66 | else: 67 | raise NotImplementedError() 68 | 69 | run() 70 | 71 | 72 | if __name__ == '__main__': 73 | main(RunnerParser(ENV_OPTIONS)) 74 | -------------------------------------------------------------------------------- /runners/run_waterworld.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: run_waterworld.py 4 | # 5 | # Created: Wednesday, August 31 2016 by rejuvyesh 6 | # 7 | from runners import RunnerParser 8 | from runners.curriculum import Curriculum 9 | 10 | from madrl_environments.pursuit import MAWaterWorld 11 | from madrl_environments import StandardizedEnv, ObservationBuffer 12 | 13 | # yapf: disable 14 | ENV_OPTIONS = [ 15 | ('radius', float, 0.015, 'Radius of agents'), 16 | ('n_evaders', int, 10, ''), 17 | ('n_pursuers', int, 8, ''), 18 | ('n_poison', int, 10, ''), 19 | ('n_coop', int, 4, ''), 20 | ('n_sensors', int, 30, ''), 21 | ('sensor_range', int, 0.2, ''), 22 | ('food_reward', float, 10, ''), 23 | ('poison_reward', float, -1, ''), 24 | ('encounter_reward', float, 0.01, ''), 25 | ('reward_mech', str, 'local', ''), 26 | ('noid', str, None, ''), 27 | ('speed_features', int, 1, ''), 28 | ('buffer_size', int, 1, ''), 29 | ('curriculum', str, None, ''), 30 | ] 31 | # yapf: enable 32 | 33 | 34 | def main(parser): 35 | mode = parser._mode 36 | args = parser.args 37 | env = MAWaterWorld(args.n_pursuers, args.n_evaders, args.n_coop, args.n_poison, 38 | radius=args.radius, n_sensors=args.n_sensors, food_reward=args.food_reward, 39 | poison_reward=args.poison_reward, encounter_reward=args.encounter_reward, 40 | reward_mech=args.reward_mech, sensor_range=args.sensor_range, 41 | obstacle_loc=None, addid=True if not args.noid else False, 42 | speed_features=bool(args.speed_features)) 43 | 44 | if args.buffer_size > 1: 45 | env = ObservationBuffer(env, args.buffer_size) 46 | 47 | if mode == 'rllab': 48 | from runners.rurllab import RLLabRunner 49 | run = RLLabRunner(env, args) 50 | elif mode == 'rltools': 51 | from runners.rurltools import RLToolsRunner 52 | run = RLToolsRunner(env, args) 53 | else: 54 | raise NotImplementedError() 55 | 56 | if args.curriculum: 57 | curr = Curriculum(args.curriculum) 58 | run(curr) 59 | else: 60 | run() 61 | 62 | 63 | if __name__ == '__main__': 64 | main(RunnerParser(ENV_OPTIONS)) 65 | -------------------------------------------------------------------------------- /sample_spec.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | runs: 1 3 | baselines: ['linear', 'mlp'] 4 | algorithms: 5 | - name: centralized 6 | cmd: > 7 | python simple.py --baseline_type {baseline_type} 8 | --discount {discount} 9 | --gae_lambda {gae_lambda} 10 | --rectangle {rectangle} 11 | --n_evaders {n_evaders} 12 | --n_pursuers {n_pursuers} 13 | --obs_range {obs_range} 14 | --n_catch {n_catch} 15 | --log {log} 16 | --no-debug 17 | - name: decentralized 18 | cmd: > 19 | python simple.py --baseline_type {baseline_type} 20 | --control decentralized 21 | --discount {discount} 22 | --gae_lambda {gae_lambda} 23 | --rectangle {rectangle} 24 | --n_evaders {n_evaders} 25 | --n_pursuers {n_pursuers} 26 | --obs_range {obs_range} 27 | --n_catch {n_catch} 28 | --log {log} 29 | --no-debug 30 | 31 | rectangles: ['10,10', '16,16', '32,32'] 32 | n_evaders: 
[5, 10, 15] 33 | n_pursuers: [2, 5, 10] 34 | obs_ranges: [3, 5, 7, 9] 35 | n_catches: [2, 3] 36 | discounts: [0.99, 0.95] 37 | gae_lambdas: [1, 0.99, 0.97] 38 | 39 | 40 | options: 41 | storagedir: /tmp/spec/ 42 | checkpt_subdir: checkpoints 43 | n_workers: 4 44 | -------------------------------------------------------------------------------- /vis/bar_plot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: bar_plot.py 4 | # 5 | import argparse 6 | import os 7 | import pickle 8 | import matplotlib 9 | 10 | params = { 11 | 'axes.labelsize': 12, 12 | 'font.size': 16, 13 | 'font.family': 'serif', 14 | 'legend.fontsize': 12, 15 | 'xtick.labelsize': 12, 16 | 'ytick.labelsize': 12, 17 | #'text.usetex': True, 18 | 'figure.figsize': [4.5, 4.5] 19 | } 20 | matplotlib.rcParams.update(params) 21 | 22 | import matplotlib.pyplot as plt 23 | plt.style.use('seaborn-colorblind') 24 | 25 | import pandas as pd 26 | 27 | 28 | def main(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('dir', type=str) 31 | parser.add_argument('--random', type=float, default=None) 32 | args = parser.parse_args() 33 | 34 | if args.random: 35 | random = args.random 36 | else: 37 | random = 0 38 | with open(os.path.join(args.dir, 'results.pkl'), 'rb') as f: 39 | res = pickle.load(f)['retlist'] 40 | control_params = ['decentralized', 'concurrent', 'centralized'] 41 | cp2name = {'decentralized': 'PS', 'concurrent': 'Conc.', 'centralized': 'Cent.'} 42 | nn_params = ['mlp', 'gru', 'heuristic'] 43 | 44 | header = ['training'] 45 | for nnp in nn_params: 46 | header.append(nnp) 47 | header.append(nnp + '_error') 48 | 49 | mat = [] 50 | for cp in control_params: 51 | row = [cp2name[cp]] 52 | for nnp in nn_params: 53 | key = cp + '-' + nnp 54 | if key in res: 55 | if res[key]: 56 | row.append(res[key]['mean'] - random) 57 | row.append(res[key]['std']) 58 | else: 59 | row.append(None) 60 | row.append(None) 61 | elif 'heuristic' in key: 62 | row.append(res['heuristic']['mean'] - random) 63 | row.append(res['heuristic']['std']) 64 | mat.append(row) 65 | 66 | dat = pd.DataFrame(mat, columns=header) 67 | print(dat.to_csv(index=False, float_format='%.3f', na_rep='nan')) 68 | 69 | csv_errors = dat[['mlp_error', 'gru_error']].rename( 70 | columns={'mlp_error': 'mlp', 71 | 'gru_error': 'gru'}) 72 | ax = dat[['mlp', 'gru']].plot(kind='bar', title='', legend=True, yerr=csv_errors, alpha=0.7) 73 | ax.plot(dat['heuristic'], linewidth=1.2, linestyle='--', alpha=0.7) 74 | leg = ax.legend(['Heuristic', 'MLP', 'GRU']) 75 | leg.get_frame().set_alpha(0.7) 76 | ax.set_xticklabels(dat.training, rotation=0) 77 | ax.set_ylabel('Normalized Returns') 78 | plt.savefig(os.path.join(args.dir, 'bar.pdf'), bbox_inches='tight') 79 | 80 | 81 | if __name__ == '__main__': 82 | main() 83 | -------------------------------------------------------------------------------- /vis/max_bar_plot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: bar_plot.py 4 | # 5 | import argparse 6 | import os 7 | import pickle 8 | import matplotlib 9 | import numpy as np 10 | 11 | params = { 12 | 'axes.labelsize': 12, 13 | 'font.size': 16, 14 | 'font.family': 'serif', 15 | 'legend.fontsize': 12, 16 | 'xtick.labelsize': 12, 17 | 'ytick.labelsize': 12, 18 | #'text.usetex': True, 19 | 'figure.figsize': [4.5, 4.5] 20 | } 21 | matplotlib.rcParams.update(params) 22 | 23 | import matplotlib.pyplot as plt 24 | 
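# NOTE: matplotlib 3.6+ renamed the bundled seaborn styles, so the 'seaborn-colorblind' style used below may need to be 'seaborn-v0_8-colorblind' on newer installations.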
plt.style.use('seaborn-colorblind') 25 | 26 | import pandas as pd 27 | 28 | 29 | def main(): 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('dir', type=str) 32 | parser.add_argument('--random', type=float, default=None) 33 | args = parser.parse_args() 34 | 35 | if args.random: 36 | random = args.random 37 | else: 38 | random = 0 39 | with open(os.path.join(args.dir, 'results.pkl'), 'rb') as f: 40 | res = pickle.load(f) 41 | control_params = ['decentralized', 'concurrent', 'centralized'] 42 | nn_params = ['mlp', 'gru', 'heuristic'] 43 | 44 | header = ['training'] 45 | for nnp in nn_params: 46 | header.append(nnp) 47 | header.append(nnp + '_error') 48 | 49 | mat = [] 50 | for cp in control_params: 51 | row = [cp] 52 | for nnp in nn_params: 53 | key = cp + '-' + nnp 54 | if key in res: 55 | if res[key]: 56 | row.append(res[key]['retlist'].mean() - random) 57 | row.append(res[key]['retlist'].std()) 58 | else: 59 | row.append(None) 60 | row.append(None) 61 | elif 'heuristic' in key: 62 | row.append(res['heuristic']['retlist'].mean() - random) 63 | row.append(res['heuristic']['retlist'].std()) 64 | mat.append(row) 65 | 66 | dat = pd.DataFrame(mat, columns=header) 67 | print(dat.to_csv(index=False, float_format='%.3f', na_rep='nan')) 68 | 69 | csv_errors = dat[['mlp_error', 'gru_error']].rename( 70 | columns={'mlp_error': 'mlp', 71 | 'gru_error': 'gru'}) 72 | ax = dat[['mlp', 'gru']].plot(kind='bar', title='', legend=True, yerr=csv_errors, alpha=0.7) 73 | ax.plot(dat['heuristic'], linewidth=1.2, linestyle='--', alpha=0.7) 74 | leg = ax.legend(['Heuristic', 'MLP', 'GRU']) 75 | leg.get_frame().set_alpha(0.7) 76 | ax.set_xticklabels(dat.training, rotation=0) 77 | ax.set_ylabel('Normalized Returns') 78 | plt.savefig(os.path.join(args.dir, 'bar.pdf'), bbox_inches='tight') 79 | 80 | 81 | if __name__ == '__main__': 82 | main() 83 | -------------------------------------------------------------------------------- /vis/rllab/showlog.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # File: showlog.py 4 | # 5 | import argparse 6 | import pandas as pd 7 | 8 | 9 | def runplot(args): 10 | assert len(set(args.logfiles)) == len(args.logfiles), 'Log files must be unique' 11 | 12 | fname2log = {} 13 | for fname in args.logfiles: 14 | df = pd.read_csv(fname) 15 | if 'Iteration' in df.keys(): 16 | df.set_index('Iteration', inplace=True) 17 | elif 'Epoch' in df.keys(): 18 | df.set_index('Epoch', inplace=True) 19 | else: 20 | raise NotImplementedError() 21 | if not args.fields == 'all': 22 | df = df.loc[:, args.fields.split(',')] 23 | fname2log[fname] = df 24 | 25 | if not args.noplot: 26 | import matplotlib 27 | if args.plotfile is not None: 28 | matplotlib.use('Agg') 29 | 30 | import matplotlib.pyplot as plt 31 | plt.style.use('seaborn-colorblind') 32 | 33 | ax = None 34 | for fname, df in fname2log.items(): 35 | with pd.option_context('display.max_rows', 9999): 36 | print(fname) 37 | print(df[-1:]) 38 | 39 | if not args.noplot: 40 | if ax is None: 41 | ax = df.plot(subplots=True, title=','.join(args.logfiles)) 42 | else: 43 | df.plot(subplots=True, title=','.join(args.logfiles), ax=ax, legend=False) 44 | 45 | if args.plotfile is not None: 46 | plt.savefig(args.plotfile, transparent=True, bbox_inches='tight', dpi=300) 47 | else: 48 | plt.show() 49 | 50 | 51 | if __name__ == '__main__': 52 | parser = argparse.ArgumentParser() 53 | parser.add_argument('logfiles', type=str, nargs='+') 54 | parser.add_argument('--noplot', 
action='store_true') 55 | parser.add_argument('--fields', type=str, default='all') 56 | parser.add_argument('--plotfile', type=str, default=None) 57 | args = parser.parse_args() 58 | runplot(args) 59 | -------------------------------------------------------------------------------- /vis/rllab/vis_pursuit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: vis_traj.py 4 | # 5 | # Created: Wednesday, July 13 2016 by rejuvyesh 6 | # 7 | from __future__ import absolute_import, print_function 8 | 9 | import argparse 10 | import json 11 | import pprint 12 | import sys 13 | import joblib 14 | import os.path as osp 15 | import uuid 16 | import os 17 | 18 | from gym import spaces 19 | import numpy as np 20 | import tensorflow as tf 21 | 22 | from rllab.sampler.utils import rollout, decrollout 23 | 24 | from madrl_environments.pursuit import PursuitEvade 25 | from madrl_environments.pursuit.utils import TwoDMaps 26 | 27 | 28 | def main(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('policy_file', type=str) 31 | parser.add_argument('--vid', type=str, default='/tmp/madrl.mp4') 32 | parser.add_argument('--verbose', action='store_true', default=False) 33 | parser.add_argument('--n_steps', type=int, default=200) 34 | parser.add_argument('--map_file', type=str, default='') 35 | args = parser.parse_args() 36 | 37 | policy_dir = osp.dirname(args.policy_file) 38 | params_file = osp.join(policy_dir, 'params.json') 39 | 40 | # Load file 41 | with open(params_file) as data_file: 42 | train_args = json.load(data_file) 43 | print('Loading parameters from {} in {}'.format(policy_dir, 'params.json')) 44 | 45 | 46 | with tf.Session() as sess: 47 | 48 | data = joblib.load(args.policy_file) 49 | 50 | policy = data['policy'] 51 | env = data['env'] 52 | 53 | if train_args['control'] == 'centralized': 54 | paths = rollout(env, policy, max_path_length=args.n_steps, animated=True) 55 | elif train_args['control'] == 'decentralized': 56 | paths = decrollout(env, policy, max_path_length=args.n_steps, animated=True) 57 | 58 | 59 | """ 60 | if train_args['control'] == 'centralized': 61 | act_fn = lambda o: policy.get_action(o)[0] 62 | elif train_args['control'] == 'decentralized': 63 | def act_fn(o): 64 | action_list = [] 65 | for agent_obs in o: 66 | a, adist = policy.get_action(agent_obs) 67 | action_list.append(a) 68 | return action_list 69 | env.animate(act_fn=act_fn, nsteps=args.n_steps, file_name=args.vid, verbose=args.verbose) 70 | """ 71 | 72 | if __name__ == '__main__': 73 | main() 74 | -------------------------------------------------------------------------------- /vis/rllab/vis_waterworld.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: vis_traj.py 4 | # 5 | # Created: Wednesday, July 13 2016 by rejuvyesh 6 | # 7 | from __future__ import absolute_import, print_function 8 | 9 | import argparse 10 | import json 11 | import pprint 12 | import sys 13 | import joblib 14 | import os.path as osp 15 | import uuid 16 | import os 17 | 18 | from gym import spaces 19 | import numpy as np 20 | import tensorflow as tf 21 | from rllab.sampler.utils import rollout, decrollout 22 | 23 | from madrl_environments import ObservationBuffer 24 | from madrl_environments.pursuit import MAWaterWorld 25 | 26 | 27 | def main(): 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('policy_file', type=str) 30 | parser.add_argument('--vid', type=str, 
default='/tmp/madrl.mp4') 31 | parser.add_argument('--verbose', action='store_true', default=False) 32 | parser.add_argument('--n_steps', type=int, default=200) 33 | parser.add_argument('--map_file', type=str, default='') 34 | args = parser.parse_args() 35 | 36 | policy_dir = osp.dirname(args.policy_file) 37 | params_file = osp.join(policy_dir, 'params.json') 38 | 39 | # Load file 40 | with open(params_file) as data_file: 41 | train_args = json.load(data_file) 42 | print('Loading parameters from {} in {}'.format(policy_dir, 'params.json')) 43 | with tf.Session() as sess: 44 | data = joblib.load(args.policy_file) 45 | 46 | policy = data['policy'] 47 | env = data['env'] 48 | 49 | if train_args['control'] == 'centralized': 50 | paths = rollout(env, policy, max_path_length=args.n_steps, animated=True) 51 | elif train_args['control'] == 'decentralized': 52 | paths = decrollout(env, policy, max_path_length=args.n_steps, animated=True) 53 | """ 54 | if train_args['control'] == 'centralized': 55 | act_fn = lambda o: policy.get_action(o)[0] 56 | elif train_args['control'] == 'decentralized': 57 | def act_fn(o): 58 | action_list = [] 59 | for agent_obs in o: 60 | a, adist = policy.get_action(agent_obs) 61 | action_list.append(a) 62 | return action_list 63 | env.animate(act_fn=act_fn, nsteps=args.n_steps, file_name=args.vid, verbose=args.verbose) 64 | """ 65 | 66 | 67 | if __name__ == '__main__': 68 | main() 69 | -------------------------------------------------------------------------------- /vis/rltools/vis_hostage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: vis_hostage.py 4 | # 5 | # Created: Monday, August 1 2016 by rejuvyesh 6 | # 7 | from __future__ import absolute_import, print_function 8 | 9 | import argparse 10 | import json 11 | import pprint 12 | 13 | from gym import spaces 14 | import h5py 15 | import numpy as np 16 | import tensorflow as tf 17 | 18 | import rltools.algos 19 | import rltools.log 20 | import rltools.util 21 | from madrl_environments import ObservationBuffer 22 | from madrl_environments.hostage import ContinuousHostageWorld 23 | from rltools.baselines.linear import LinearFeatureBaseline 24 | from rltools.baselines.mlp import MLPBaseline 25 | from rltools.baselines.zero import ZeroBaseline 26 | from rltools.policy.gaussian import GaussianMLPPolicy 27 | 28 | 29 | def main(): 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('filename', type=str) # defaultIS.h5/snapshots/iter0000480 32 | parser.add_argument('--vid', type=str, default='/tmp/madrl.mp4') 33 | parser.add_argument('--deterministic', action='store_true', default=False) 34 | parser.add_argument('--n_steps', type=int, default=1000) 35 | args = parser.parse_args() 36 | 37 | # Load file 38 | filename, file_key = rltools.util.split_h5_name(args.filename) 39 | print('Loading parameters from {} in {}'.format(file_key, filename)) 40 | with h5py.File(filename, 'r') as f: 41 | train_args = json.loads(f.attrs['args']) 42 | dset = f[file_key] 43 | 44 | pprint.pprint(dict(dset.attrs)) 45 | 46 | centralized = True if train_args['control'] == 'centralized' else False 47 | env = ContinuousHostageWorld(train_args['n_good'], 48 | train_args['n_hostage'], 49 | train_args['n_bad'], 50 | train_args['n_coop_save'], 51 | train_args['n_coop_avoid'], 52 | n_sensors=train_args['n_sensors'], 53 | sensor_range=train_args['sensor_range'], 54 | save_reward=train_args['save_reward'], 55 | hit_reward=train_args['hit_reward'], 56 | 
encounter_reward=train_args['encounter_reward'], 57 | bomb_reward=train_args['bomb_reward'],) 58 | 59 | if train_args['buffer_size'] > 1: 60 | env = ObservationBuffer(env, train_args['buffer_size']) 61 | 62 | if centralized: 63 | obsfeat_space = spaces.Box(low=env.agents[0].observation_space.low[0], 64 | high=env.agents[0].observation_space.high[0], 65 | shape=(env.agents[0].observation_space.shape[0] * 66 | len(env.agents),)) # XXX 67 | action_space = spaces.Box(low=env.agents[0].action_space.low[0], 68 | high=env.agents[0].action_space.high[0], 69 | shape=(env.agents[0].action_space.shape[0] * 70 | len(env.agents),)) # XXX 71 | else: 72 | obsfeat_space = env.agents[0].observation_space 73 | action_space = env.agents[0].action_space 74 | 75 | policy = GaussianMLPPolicy(obsfeat_space, action_space, 76 | hidden_spec=train_args['policy_hidden_spec'], enable_obsnorm=True, 77 | min_stdev=0., init_logstdev=0., tblog=train_args['tblog'], 78 | varscope_name='gaussmlp_policy') 79 | 80 | with tf.Session() as sess: 81 | sess.run(tf.initialize_all_variables()) 82 | policy.load_h5(sess, filename, file_key) 83 | 84 | rew = env.animate( 85 | act_fn=lambda o: policy.sample_actions(sess, o[None, ...], deterministic=args.deterministic), 86 | nsteps=args.n_steps, file_name=args.vid) 87 | print(rew) 88 | 89 | 90 | if __name__ == '__main__': 91 | main() 92 | -------------------------------------------------------------------------------- /vis/rltools/vis_pursuit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: vis_traj.py 4 | # 5 | # Created: Wednesday, July 13 2016 by rejuvyesh 6 | # 7 | from __future__ import absolute_import, print_function 8 | 9 | import argparse 10 | import json 11 | import pprint 12 | import sys 13 | sys.path.append('../rltools/') 14 | 15 | import gym 16 | from gym import spaces 17 | import h5py 18 | import numpy as np 19 | import tensorflow as tf 20 | 21 | import rltools.algos 22 | import rltools.log 23 | import rltools.util 24 | 25 | from madrl_environments.pursuit import PursuitEvade 26 | from madrl_environments.pursuit.utils import TwoDMaps 27 | 28 | from rltools.policy.categorical import CategoricalMLPPolicy 29 | 30 | from pursuit_policy import PursuitCentralMLPPolicy 31 | 32 | 33 | def main(): 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('filename', type=str) # defaultIS.h5/snapshots/iter0000480 36 | parser.add_argument('--vid', type=str, default='/tmp/madrl.mp4') 37 | parser.add_argument('--deterministic', action='store_true', default=False) 38 | parser.add_argument('--verbose', action='store_true', default=False) 39 | parser.add_argument('--n_steps', type=int, default=200) 40 | parser.add_argument('--map_file', type=str, default='') 41 | args = parser.parse_args() 42 | 43 | # Load file 44 | filename, file_key = rltools.util.split_h5_name(args.filename) 45 | print('Loading parameters from {} in {}'.format(file_key, filename)) 46 | with h5py.File(filename, 'r') as f: 47 | train_args = json.loads(f.attrs['args']) 48 | dset = f[file_key] 49 | 50 | pprint.pprint(dict(dset.attrs)) 51 | 52 | pprint.pprint(train_args) 53 | if train_args['sample_maps']: 54 | map_pool = np.load(args.map_file) 55 | else: 56 | if train_args['map_type'] == 'rectangle': 57 | env_map = TwoDMaps.rectangle_map(*map(int, train_args['rectangle'].split(','))) 58 | elif train_args['map_type'] == 'complex': 59 | env_map = TwoDMaps.complex_map(*map(int, train_args['rectangle'].split(','))) 60 | else: 61 | raise 
NotImplementedError() 62 | map_pool = [env_map] 63 | 64 | env = PursuitEvade(map_pool, 65 | #n_evaders=train_args['n_evaders'], 66 | #n_pursuers=train_args['n_pursuers'], 67 | n_evaders=50, 68 | n_pursuers=50, 69 | obs_range=train_args['obs_range'], 70 | n_catch=train_args['n_catch'], 71 | urgency_reward=train_args['urgency'], 72 | surround=train_args['surround'], 73 | sample_maps=train_args['sample_maps'], 74 | flatten=train_args['flatten'], 75 | reward_mech=train_args['reward_mech'] 76 | ) 77 | 78 | if train_args['control'] == 'decentralized': 79 | obsfeat_space = env.agents[0].observation_space 80 | action_space = env.agents[0].action_space 81 | elif train_args['control'] == 'centralized': 82 | obsfeat_space = spaces.Box(low=env.agents[0].observation_space.low[0], 83 | high=env.agents[0].observation_space.high[0], 84 | shape=(env.agents[0].observation_space.shape[0] * 85 | len(env.agents),)) # XXX 86 | action_space = spaces.Discrete(env.agents[0].action_space.n * len(env.agents)) 87 | 88 | else: 89 | raise NotImplementedError() 90 | 91 | 92 | policy = CategoricalMLPPolicy(obsfeat_space, action_space, 93 | hidden_spec=train_args['policy_hidden_spec'], 94 | enable_obsnorm=True, 95 | tblog=train_args['tblog'], varscope_name='pursuit_catmlp_policy') 96 | 97 | 98 | with tf.Session() as sess: 99 | sess.run(tf.initialize_all_variables()) 100 | policy.load_h5(sess, filename, file_key) 101 | if train_args['control'] == 'centralized': 102 | act_fn = lambda o: policy.sample_actions(np.expand_dims(np.array(o).flatten(),0), deterministic=args.deterministic)[0][0,0] 103 | elif train_args['control'] == 'decentralized': 104 | def act_fn(o): 105 | action_list = [] 106 | for agent_obs in o: 107 | a, adist = policy.sample_actions(np.expand_dims(agent_obs,0), deterministic=args.deterministic) 108 | action_list.append(a[0, 0]) 109 | return action_list 110 | #import IPython 111 | #IPython.embed() 112 | env.animate(act_fn=act_fn, nsteps=args.n_steps, file_name=args.vid, verbose=args.verbose) 113 | 114 | if __name__ == '__main__': 115 | main() 116 | -------------------------------------------------------------------------------- /vis/rltools/vis_waterworld.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: vis_waterworld.py 4 | # 5 | # Created: Thursday, July 14 2016 by rejuvyesh 6 | # 7 | from __future__ import absolute_import, print_function 8 | 9 | import argparse 10 | import json 11 | import pprint 12 | 13 | from gym import spaces 14 | import h5py 15 | import numpy as np 16 | import tensorflow as tf 17 | 18 | import rltools.algos 19 | import rltools.log 20 | import rltools.util 21 | import rltools.samplers 22 | from madrl_environments import ObservationBuffer 23 | from madrl_environments.pursuit import MAWaterWorld 24 | from rltools.baselines.linear import LinearFeatureBaseline 25 | from rltools.baselines.mlp import MLPBaseline 26 | from rltools.baselines.zero import ZeroBaseline 27 | from rltools.policy.gaussian import GaussianMLPPolicy 28 | 29 | 30 | def main(): 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('filename', type=str) # defaultIS.h5/snapshots/iter0000480 33 | parser.add_argument('--vid', type=str, default='/tmp/madrl.mp4') 34 | parser.add_argument('--deterministic', action='store_true', default=False) 35 | parser.add_argument('--evaluate', action='store_true', default=False) 36 | parser.add_argument('--n_steps', type=int, default=500) 37 | args = parser.parse_args() 38 | 39 | # Load file 40 | 
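# The checkpoint argument is expected to look like 'defaultIS.h5/snapshots/iter0000480' (see the argparse comment above); split_h5_name presumably separates the on-disk .h5 path from the dataset key that is then opened with h5py below.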
filename, file_key = rltools.util.split_h5_name(args.filename) 41 | print('Loading parameters from {} in {}'.format(file_key, filename)) 42 | with h5py.File(filename, 'r') as f: 43 | train_args = json.loads(f.attrs['args']) 44 | dset = f[file_key] 45 | 46 | pprint.pprint(dict(dset.attrs)) 47 | 48 | centralized = True if train_args['control'] == 'centralized' else False 49 | env = MAWaterWorld(train_args['n_pursuers'], train_args['n_evaders'], train_args['n_coop'], 50 | train_args['n_poison'], n_sensors=train_args['n_sensors'], 51 | food_reward=train_args['food_reward'], 52 | poison_reward=train_args['poison_reward'], obstacle_location=None, 53 | encounter_reward=train_args['encounter_reward'], addid=False if 54 | not centralized else True, speed_features=bool(train_args['speed_features'])) 55 | 56 | if train_args['buffer_size'] > 1: 57 | env = ObservationBuffer(env, train_args['buffer_size']) 58 | 59 | if centralized: 60 | obsfeat_space = spaces.Box(low=env.agents[0].observation_space.low[0], 61 | high=env.agents[0].observation_space.high[0], 62 | shape=(env.agents[0].observation_space.shape[0] * 63 | len(env.agents),)) # XXX 64 | action_space = spaces.Box(low=env.agents[0].action_space.low[0], 65 | high=env.agents[0].action_space.high[0], 66 | shape=(env.agents[0].action_space.shape[0] * 67 | len(env.agents),)) # XXX 68 | else: 69 | obsfeat_space = env.agents[0].observation_space 70 | action_space = env.agents[0].action_space 71 | 72 | policy = GaussianMLPPolicy(obsfeat_space, action_space, 73 | hidden_spec=train_args['policy_hidden_spec'], 74 | enable_obsnorm=train_args['enable_obsnorm'], min_stdev=0., 75 | init_logstdev=0., tblog=train_args['tblog'], varscope_name='policy') 76 | 77 | with tf.Session() as sess: 78 | sess.run(tf.initialize_all_variables()) 79 | policy.load_h5(sess, filename, file_key) 80 | # Evaluate 81 | if args.evaluate: 82 | rltools.util.ok("Evaluating...") 83 | evr = rltools.util.evaluate_policy(env, policy, deterministic=args.deterministic, 84 | disc=train_args['discount'], 85 | mode=train_args['control'], 86 | max_traj_len=args.n_steps, n_trajs=100) 87 | from tabulate import tabulate 88 | print(tabulate(evr, headers='keys')) 89 | else: 90 | rew, trajinfo = env.animate( 91 | act_fn=lambda o: policy.sample_actions(o[None, ...], deterministic=args.deterministic)[0], 92 | nsteps=args.n_steps) 93 | info = {key: np.sum(value) for key, value in trajinfo.items()} 94 | print(rew, info) 95 | 96 | 97 | if __name__ == '__main__': 98 | main() 99 | -------------------------------------------------------------------------------- /vis/vis_multiant.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: vis_multiant.py 4 | # 5 | # Created: Wednesday, September 7 2016 by rejuvyesh 6 | # 7 | from __future__ import absolute_import, print_function 8 | 9 | import argparse 10 | import json 11 | import pprint 12 | import os 13 | import os.path 14 | import pickle 15 | 16 | from gym import spaces 17 | import h5py 18 | import numpy as np 19 | import tensorflow as tf 20 | 21 | import rltools.algos 22 | import rltools.log 23 | import rltools.util 24 | import rltools.samplers 25 | from madrl_environments import ObservationBuffer 26 | from madrl_environments.mujoco.ant.multi_ant import MultiAnt 27 | from rltools.baselines.linear import LinearFeatureBaseline 28 | from rltools.baselines.mlp import MLPBaseline 29 | from rltools.baselines.zero import ZeroBaseline 30 | from rltools.policy.gaussian import GaussianMLPPolicy 31 | 32
| from vis import Evaluator, Visualizer, FileHandler 33 | 34 | 35 | def main(): 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('filename', type=str) # defaultIS.h5/snapshots/iter0000480 38 | parser.add_argument('--vid', type=str, default='/tmp/madrl.mp4') 39 | parser.add_argument('--deterministic', action='store_true', default=False) 40 | parser.add_argument('--heuristic', action='store_true', default=False) 41 | parser.add_argument('--evaluate', action='store_true', default=False) 42 | parser.add_argument('--save_file', type=str, default=None) 43 | parser.add_argument('--n_trajs', type=int, default=20) 44 | parser.add_argument('--n_steps', type=int, default=500) 45 | parser.add_argument('--same_con_pol', action='store_true') 46 | args = parser.parse_args() 47 | 48 | fh = FileHandler(args.filename) 49 | 50 | env = MultiAnt(n_legs=fh.train_args['n_legs'], 51 | ts=fh.train_args['ts'], 52 | integrator=fh.train_args['integrator'], 53 | out_file=fh.train_args['out_file'], 54 | base_file=fh.train_args['base_file'], 55 | reward_mech=fh.train_args['reward_mech']) 56 | 57 | if fh.train_args['buffer_size'] > 1: 58 | env = ObservationBuffer(env, fh.train_args['buffer_size']) 59 | 60 | hpolicy = None 61 | if args.evaluate: 62 | minion = Evaluator(env, fh.train_args, args.n_steps, args.n_trajs, args.deterministic, 63 | 'heuristic' if args.heuristic else fh.mode) 64 | evr = minion(fh.filename, file_key=fh.file_key, same_con_pol=args.same_con_pol, 65 | hpolicy=hpolicy) 66 | if args.save_file: 67 | pickle.dump(evr, open(args.save_file, "wb")) 68 | from tabulate import tabulate 69 | #print(tabulate(evr, headers='keys')) 70 | 71 | 72 | else: 73 | minion = Visualizer(env, fh.train_args, args.n_steps, args.n_trajs, args.deterministic, 74 | fh.mode) 75 | rew, info = minion(fh.filename, file_key=fh.file_key, vid=args.vid) 76 | pprint.pprint(rew) 77 | pprint.pprint(info) 78 | 79 | 80 | if __name__ == '__main__': 81 | main() 82 | -------------------------------------------------------------------------------- /vis/vis_multiwalker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: vis_multiwalker.py 4 | # 5 | # Created: Wednesday, September 7 2016 by rejuvyesh 6 | # 7 | from __future__ import absolute_import, print_function 8 | 9 | import argparse 10 | import json 11 | import pprint 12 | import os 13 | import os.path 14 | 15 | from gym import spaces 16 | import h5py 17 | import numpy as np 18 | import tensorflow as tf 19 | 20 | import rltools.algos 21 | import rltools.log 22 | import rltools.util 23 | import rltools.samplers 24 | from madrl_environments import ObservationBuffer 25 | from madrl_environments.walker.multi_walker import MultiWalkerEnv 26 | from rltools.baselines.linear import LinearFeatureBaseline 27 | from rltools.baselines.mlp import MLPBaseline 28 | from rltools.baselines.zero import ZeroBaseline 29 | from rltools.policy.gaussian import GaussianMLPPolicy 30 | 31 | from vis import Evaluator, Visualizer, FileHandler 32 | 33 | 34 | def main(): 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument('filename', type=str) # defaultIS.h5/snapshots/iter0000480 37 | parser.add_argument('--vid', type=str, default='/tmp/madrl.mp4') 38 | parser.add_argument('--deterministic', action='store_true', default=False) 39 | parser.add_argument('--heuristic', action='store_true', default=False) 40 | parser.add_argument('--evaluate', action='store_true', default=False) 41 | parser.add_argument('--n_trajs', type=int, 
default=20) 42 | parser.add_argument('--n_steps', type=int, default=500) 43 | parser.add_argument('--same_con_pol', action='store_true') 44 | args = parser.parse_args() 45 | 46 | fh = FileHandler(args.filename) 47 | 48 | env = MultiWalkerEnv(fh.train_args['n_walkers'], fh.train_args['position_noise'], 49 | fh.train_args['angle_noise'], 50 | reward_mech='global') #fh.train_args['reward_mech']) 51 | 52 | if fh.train_args['buffer_size'] > 1: 53 | env = ObservationBuffer(env, fh.train_args['buffer_size']) 54 | 55 | hpolicy = None 56 | if args.heuristic: 57 | from heuristics.multi_walker import MultiwalkerHeuristicPolicy 58 | hpolicy = MultiwalkerHeuristicPolicy(env.agents[0].observation_space, 59 | env.agents[0].action_space) 60 | 61 | if args.evaluate: 62 | minion = Evaluator(env, fh.train_args, args.n_steps, args.n_trajs, args.deterministic, 63 | 'heuristic' if args.heuristic else fh.mode) 64 | evr = minion(fh.filename, file_key=fh.file_key, same_con_pol=args.same_con_pol, 65 | hpolicy=hpolicy) 66 | from tabulate import tabulate 67 | print(tabulate(evr, headers='keys')) 68 | else: 69 | minion = Visualizer(env, fh.train_args, args.n_steps, args.n_trajs, args.deterministic, 70 | fh.mode) 71 | rew, info = minion(fh.filename, file_key=fh.file_key, vid=args.vid) 72 | pprint.pprint(rew) 73 | pprint.pprint(info) 74 | 75 | 76 | if __name__ == '__main__': 77 | main() 78 | -------------------------------------------------------------------------------- /vis/vis_pursuit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: vis_pursuit.py 4 | # 5 | # Created: Wednesday, September 14 2016 by rejuvyesh 6 | # 7 | from __future__ import absolute_import, print_function 8 | 9 | import argparse 10 | import json 11 | import pprint 12 | import os 13 | import os.path 14 | import numpy as np 15 | 16 | from madrl_environments.pursuit import PursuitEvade 17 | from madrl_environments.pursuit.utils import TwoDMaps 18 | from madrl_environments import StandardizedEnv, ObservationBuffer 19 | 20 | from vis import Evaluator, Visualizer, FileHandler 21 | 22 | 23 | def main(): 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('filename', type=str) # defaultIS.h5/snapshots/iter0000480 26 | parser.add_argument('--vid', type=str, default='/tmp/madrl.mp4') 27 | parser.add_argument('--deterministic', action='store_true', default=False) 28 | parser.add_argument('--heuristic', action='store_true', default=False) 29 | parser.add_argument('--evaluate', action='store_true', default=False) 30 | parser.add_argument('--n_trajs', type=int, default=20) 31 | parser.add_argument('--n_steps', type=int, default=500) 32 | parser.add_argument('--same_con_pol', action='store_true') 33 | args = parser.parse_args() 34 | 35 | fh = FileHandler(args.filename) 36 | 37 | map_pool = np.load( 38 | os.path.join('/scratch/megorov/deeprl/MADRL/runners/maps/', os.path.basename(fh.train_args[ 39 | 'map_file']))) 40 | env = PursuitEvade(map_pool, n_evaders=fh.train_args['n_evaders'], 41 | n_pursuers=fh.train_args['n_pursuers'], obs_range=fh.train_args['obs_range'], 42 | n_catch=fh.train_args['n_catch'], urgency_reward=fh.train_args['urgency'], 43 | surround=bool(fh.train_args['surround']), 44 | sample_maps=bool(fh.train_args['sample_maps']), 45 | flatten=bool(fh.train_args['flatten']), reward_mech='global', 46 | catchr=fh.train_args['catchr'], term_pursuit=fh.train_args['term_pursuit']) 47 | 48 | if fh.train_args['buffer_size'] > 1: 49 | env = ObservationBuffer(env,
fh.train_args['buffer_size']) 50 | 51 | hpolicy = None 52 | if args.heuristic: 53 | from heuristics.pursuit import PursuitHeuristicPolicy 54 | hpolicy = PursuitHeuristicPolicy(env.agents[0].observation_space, 55 | env.agents[0].action_space) 56 | 57 | if args.evaluate: 58 | minion = Evaluator(env, fh.train_args, args.n_steps, args.n_trajs, args.deterministic, 59 | 'heuristic' if args.heuristic else fh.mode) 60 | evr = minion(fh.filename, file_key=fh.file_key, same_con_pol=args.same_con_pol, 61 | hpolicy=hpolicy) 62 | from tabulate import tabulate 63 | print(evr) 64 | print(tabulate(evr, headers='keys')) 65 | else: 66 | minion = Visualizer(env, fh.train_args, args.n_steps, args.n_trajs, args.deterministic, 67 | fh.mode) 68 | rew, info = minion(fh.filename, file_key=fh.file_key, vid=args.vid) 69 | pprint.pprint(rew) 70 | pprint.pprint(info) 71 | 72 | 73 | if __name__ == '__main__': 74 | main() 75 | -------------------------------------------------------------------------------- /vis/vis_waterworld.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: vis_waterworld.py 4 | # 5 | # Created: Thursday, July 14 2016 by rejuvyesh 6 | # 7 | from __future__ import absolute_import, print_function 8 | 9 | import argparse 10 | import json 11 | import pprint 12 | import os 13 | import os.path 14 | 15 | from gym import spaces 16 | import h5py 17 | import numpy as np 18 | import tensorflow as tf 19 | 20 | import rltools.algos 21 | import rltools.log 22 | import rltools.util 23 | import rltools.samplers 24 | from madrl_environments import ObservationBuffer 25 | from madrl_environments.pursuit import MAWaterWorld 26 | from rltools.baselines.linear import LinearFeatureBaseline 27 | from rltools.baselines.mlp import MLPBaseline 28 | from rltools.baselines.zero import ZeroBaseline 29 | from rltools.policy.gaussian import GaussianMLPPolicy 30 | 31 | from vis import Evaluator, Visualizer, FileHandler 32 | 33 | 34 | def main(): 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument('filename', type=str) # defaultIS.h5/snapshots/iter0000480 37 | parser.add_argument('--vid', type=str, default='/tmp/madrl.mp4') 38 | parser.add_argument('--deterministic', action='store_true', default=False) 39 | parser.add_argument('--heuristic', action='store_true', default=False) 40 | parser.add_argument('--evaluate', action='store_true', default=False) 41 | parser.add_argument('--n_trajs', type=int, default=10) 42 | parser.add_argument('--n_steps', type=int, default=500) 43 | parser.add_argument('--same_con_pol', action='store_true') 44 | args = parser.parse_args() 45 | 46 | fh = FileHandler(args.filename) 47 | 48 | env = MAWaterWorld(fh.train_args['n_pursuers'], 49 | fh.train_args['n_evaders'], 50 | fh.train_args['n_coop'], 51 | fh.train_args['n_poison'], 52 | n_sensors=fh.train_args['n_sensors'], 53 | food_reward=fh.train_args['food_reward'], 54 | poison_reward=fh.train_args['poison_reward'], 55 | reward_mech='global', 56 | encounter_reward=0, #fh.train_args['encounter_reward'], 57 | addid=True,) 58 | 59 | if fh.train_args['buffer_size'] > 1: 60 | env = ObservationBuffer(env, fh.train_args['buffer_size']) 61 | 62 | hpolicy = None 63 | if args.heuristic: 64 | from heuristics.waterworld import WaterworldHeuristicPolicy 65 | hpolicy = WaterworldHeuristicPolicy(env.agents[0].observation_space, 66 | env.agents[0].action_space) 67 | 68 | if args.evaluate: 69 | minion = Evaluator(env, fh.train_args, args.n_steps, args.n_trajs, args.deterministic, 70 | 
'heuristic' if args.heuristic else fh.mode) 71 | evr = minion(fh.filename, file_key=fh.file_key, same_con_pol=args.same_con_pol, 72 | hpolicy=hpolicy) 73 | from tabulate import tabulate 74 | print(tabulate(evr, headers='keys')) 75 | else: 76 | minion = Visualizer(env, fh.train_args, args.n_steps, args.n_trajs, args.deterministic, 77 | 'heuristic' if args.heuristic else fh.mode) 78 | rew, info = minion(fh.filename, file_key=fh.file_key, vid=args.vid, hpolicy=hpolicy) 79 | pprint.pprint(rew) 80 | pprint.pprint(info) 81 | 82 | 83 | if __name__ == '__main__': 84 | main() 85 | -------------------------------------------------------------------------------- /vis/wilco2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: wilcoxon.py 4 | # 5 | import argparse 6 | import pickle 7 | import os.path 8 | from pprint import pprint 9 | from itertools import combinations 10 | import numpy as np 11 | from scipy.stats import wilcoxon 12 | 13 | 14 | def filter_outlier(data): 15 | new_dat = [] 16 | for d in data: 17 | if d < -50: 18 | d = -50 19 | new_dat.append(d) 20 | return new_dat 21 | 22 | 23 | def main(): 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('dir', type=str) 26 | parser.add_argument('--filter', action='store_true') 27 | parser.add_argument('--all', action='store_true') 28 | args = parser.parse_args() 29 | 30 | control_params = ['centralized', 'decentralized', 'concurrent'] 31 | nn_params = ['gru', 'mlp'] 32 | total_params = ['{}-{}'.format(cp, np) for cp in control_params 33 | for np in nn_params] + ['heuristic'] 34 | 35 | retlist = dict.fromkeys(total_params) 36 | for tp in total_params: 37 | pkl_file = os.path.join(args.dir, tp + '.pkl') 38 | try: 39 | with open(pkl_file, 'rb') as f: 40 | retl = pickle.load(f, encoding='latin1')['retlist'][:50] 41 | 42 | if args.filter: 43 | retl = filter_outlier(retl) 44 | 45 | retlist[tp] = { 46 | 'retl': retl, 47 | 'mean': np.mean(retl), 48 | 'std': np.std(retl) / np.sqrt(50), 49 | } 50 | except Exception as e: 51 | print(e) 52 | 53 | pvals = {} 54 | for p1, p2 in combinations(total_params, 2): 55 | if retlist[p1] and retlist[p2]: 56 | _, p_val = wilcoxon(retlist[p1]['retl'], retlist[p2]['retl']) 57 | pvals[p1, p2] = p_val 58 | 59 | with open(os.path.join(args.dir, 'results.pkl'), 'wb') as f: 60 | pickle.dump({'retlist': retlist, 'pvals': pvals}, f) 61 | 62 | for ret in retlist: 63 | if retlist[ret]: 64 | print('{}: {}, {}'.format(ret, retlist[ret]['mean'], retlist[ret]['std'])) 65 | if args.all: 66 | pprint(retlist[ret]['retl']) 67 | print('######') 68 | pprint(pvals) 69 | 70 | index = [] 71 | comp = [] 72 | for row in total_params: 73 | index.append(row) 74 | rlist = [] 75 | for col in total_params: 76 | if (row, col) in pvals: 77 | rlist.append(pvals[row, col]) 78 | else: 79 | rlist.append(None) 80 | comp.append(rlist) 81 | 82 | import pandas as pd 83 | pd.set_option('expand_frame_repr', False) 84 | print(pd.DataFrame(comp, index=index, columns=index)) 85 | 86 | 87 | if __name__ == '__main__': 88 | main() 89 | -------------------------------------------------------------------------------- /vis/wilcoxon.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File: wilcoxon.py 4 | # 5 | import argparse 6 | import pickle 7 | from pprint import pprint 8 | import numpy as np 9 | from scipy.stats import wilcoxon 10 | 11 | 12 | def filter_outlier(data): 13 | new_dat = [] 14 | for d in data: 15 | if 
d < -50: 16 | d = -50 17 | new_dat.append(d) 18 | return new_dat 19 | 20 | 21 | def main(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('file1', type=str) 24 | parser.add_argument('file2', type=str) 25 | parser.add_argument('--filter', action='store_true') 26 | args = parser.parse_args() 27 | 28 | with open(args.file1, 'rb') as f: 29 | retlist1 = pickle.load(f, encoding='latin1')['retlist'][:50] 30 | if args.filter: 31 | retlist1 = filter_outlier(retlist1) 32 | 33 | with open(args.file2, 'rb') as f: 34 | retlist2 = pickle.load(f, encoding='latin1')['retlist'][:50] 35 | if args.filter: 36 | retlist2 = filter_outlier(retlist2) 37 | 38 | mean1 = np.mean(retlist1) 39 | std1 = np.std(retlist1) / np.sqrt(50) 40 | mean2 = np.mean(retlist2) 41 | std2 = np.std(retlist2) / np.sqrt(50) 42 | z_stat, p_val = wilcoxon(retlist1, retlist2) 43 | 44 | pprint({ 45 | args.file1: { 46 | 'mean': mean1, 47 | 'std': std1 48 | }, 49 | args.file2: { 50 | 'mean': mean2, 51 | 'std': std2 52 | }, 53 | 'p_val': p_val 54 | }) 55 | import ipdb 56 | ipdb.set_trace() 57 | 58 | 59 | if __name__ == '__main__': 60 | main() 61 | --------------------------------------------------------------------------------
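For reference, the evaluation scripts and the Wilcoxon utilities above can be chained: an evaluator script such as `vis/vis_multiant.py`, run with `--evaluate --save_file`, pickles its evaluation results, and `vis/wilcoxon.py` (or `vis/wilco2.py` for a whole directory of runs) then compares two such result files with a paired Wilcoxon signed-rank test. A minimal usage sketch follows; the snapshot and output paths are hypothetical, and it assumes the saved results contain the `retlist` key that `vis/wilcoxon.py` reads.

```bash
# Evaluate two trained snapshots (paths are hypothetical) and save their results
python3 vis/vis_multiant.py results/runA.h5/snapshots/iter0000480 \
    --evaluate --n_trajs 50 --save_file results/runA.pkl
python3 vis/vis_multiant.py results/runB.h5/snapshots/iter0000480 \
    --evaluate --n_trajs 50 --save_file results/runB.pkl

# Paired Wilcoxon signed-rank test between the two runs (the script uses the first 50 returns)
python3 vis/wilcoxon.py results/runA.pkl results/runB.pkl
```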