├── .gitignore ├── LICENSE.txt ├── README.md ├── experiments └── train.py ├── maddpg ├── __init__.py ├── common │ ├── distributions.py │ └── tf_util.py └── trainer │ ├── maddpg.py │ └── replay_buffer.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | .static_storage/ 57 | .media/ 58 | local_settings.py 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **Status:** Archive (code is provided as-is, no updates expected) 2 | 3 | # Multi-Agent Deep Deterministic Policy Gradient (MADDPG) 4 | 5 | This is the code for implementing the MADDPG algorithm presented in the paper: 6 | [Multi-Agent Actor-Critic for Mixed Cooperative-Competitive Environments](https://arxiv.org/pdf/1706.02275.pdf). 7 | It is configured to be run in conjunction with environments from the 8 | [Multi-Agent Particle Environments (MPE)](https://github.com/openai/multiagent-particle-envs). 9 | Note: this codebase has been restructured since the original paper, and the results may 10 | vary from those reported in the paper. 11 | 12 | **Update:** the original implementation for policy ensemble and policy estimation can be found [here](https://www.dropbox.com/s/jlc6dtxo580lpl2/maddpg_ensemble_and_approx_code.zip?dl=0). The code is provided as-is. 13 | 14 | ## Installation 15 | 16 | - To install, `cd` into the root directory and type `pip install -e .` 17 | 18 | - Known dependencies: Python (3.5.4), OpenAI gym (0.10.5), tensorflow (1.8.0), numpy (1.14.5) 19 | 20 | ## Case study: Multi-Agent Particle Environments 21 | 22 | We demonstrate here how the code can be used in conjunction with the 23 | [Multi-Agent Particle Environments (MPE)](https://github.com/openai/multiagent-particle-envs). 24 | 25 | - Download and install the MPE code [here](https://github.com/openai/multiagent-particle-envs) 26 | by following the `README`. 27 | 28 | - Ensure that `multiagent-particle-envs` has been added to your `PYTHONPATH` (e.g. in `~/.bashrc` or `~/.bash_profile`). 29 | 30 | - To run the code, `cd` into the `experiments` directory and run `train.py`: 31 | 32 | ``python train.py --scenario simple`` 33 | 34 | - You can replace `simple` with any environment in the MPE you'd like to run. 
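- As a fuller, illustrative example, the following first trains on the `simple` scenario under a hypothetical experiment name, then replays the policy saved in the default `--save-dir` (all flags are described under "Command-line options" below):

  ``python train.py --scenario simple --num-episodes 20000 --exp-name my_experiment``

  ``python train.py --scenario simple --display``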
35 | 36 | ## Command-line options 37 | 38 | ### Environment options 39 | 40 | - `--scenario`: defines which environment in the MPE is to be used (default: `"simple"`) 41 | 42 | - `--max-episode-len` maximum length of each episode for the environment (default: `25`) 43 | 44 | - `--num-episodes` total number of training episodes (default: `60000`) 45 | 46 | - `--num-adversaries`: number of adversaries in the environment (default: `0`) 47 | 48 | - `--good-policy`: algorithm used for the 'good' (non adversary) policies in the environment 49 | (default: `"maddpg"`; options: {`"maddpg"`, `"ddpg"`}) 50 | 51 | - `--adv-policy`: algorithm used for the adversary policies in the environment 52 | (default: `"maddpg"`; options: {`"maddpg"`, `"ddpg"`}) 53 | 54 | ### Core training parameters 55 | 56 | - `--lr`: learning rate (default: `1e-2`) 57 | 58 | - `--gamma`: discount factor (default: `0.95`) 59 | 60 | - `--batch-size`: batch size (default: `1024`) 61 | 62 | - `--num-units`: number of units in the MLP (default: `64`) 63 | 64 | ### Checkpointing 65 | 66 | - `--exp-name`: name of the experiment, used as the file name to save all results (default: `None`) 67 | 68 | - `--save-dir`: directory where intermediate training results and model will be saved (default: `"/tmp/policy/"`) 69 | 70 | - `--save-rate`: model is saved every time this number of episodes has been completed (default: `1000`) 71 | 72 | - `--load-dir`: directory where training state and model are loaded from (default: `""`) 73 | 74 | ### Evaluation 75 | 76 | - `--restore`: restores previous training state stored in `load-dir` (or in `save-dir` if no `load-dir` 77 | has been provided), and continues training (default: `False`) 78 | 79 | - `--display`: displays to the screen the trained policy stored in `load-dir` (or in `save-dir` if no `load-dir` 80 | has been provided), but does not continue training (default: `False`) 81 | 82 | - `--benchmark`: runs benchmarking evaluations on saved policy, saves results to `benchmark-dir` folder (default: `False`) 83 | 84 | - `--benchmark-iters`: number of iterations to run benchmarking for (default: `100000`) 85 | 86 | - `--benchmark-dir`: directory where benchmarking data is saved (default: `"./benchmark_files/"`) 87 | 88 | - `--plots-dir`: directory where training curves are saved (default: `"./learning_curves/"`) 89 | 90 | ## Code structure 91 | 92 | - `./experiments/train.py`: contains code for training MADDPG on the MPE 93 | 94 | - `./maddpg/trainer/maddpg.py`: core code for the MADDPG algorithm 95 | 96 | - `./maddpg/trainer/replay_buffer.py`: replay buffer code for MADDPG 97 | 98 | - `./maddpg/common/distributions.py`: useful distributions used in `maddpg.py` 99 | 100 | - `./maddpg/common/tf_util.py`: useful tensorflow functions used in `maddpg.py` 101 | 102 | 103 | 104 | ## Paper citation 105 | 106 | If you used this code for your experiments or found it helpful, consider citing the following paper: 107 | 108 |
109 | @article{lowe2017multi,
110 |   title={Multi-Agent Actor-Critic for Mixed Cooperative-Competitive Environments},
111 |   author={Lowe, Ryan and Wu, Yi and Tamar, Aviv and Harb, Jean and Abbeel, Pieter and Mordatch, Igor},
112 |   journal={Neural Information Processing Systems (NIPS)},
113 |   year={2017}
114 | }
115 | 
116 | -------------------------------------------------------------------------------- /experiments/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import tensorflow as tf 4 | import time 5 | import pickle 6 | 7 | import maddpg.common.tf_util as U 8 | from maddpg.trainer.maddpg import MADDPGAgentTrainer 9 | import tensorflow.contrib.layers as layers 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser("Reinforcement Learning experiments for multiagent environments") 13 | # Environment 14 | parser.add_argument("--scenario", type=str, default="simple", help="name of the scenario script") 15 | parser.add_argument("--max-episode-len", type=int, default=25, help="maximum episode length") 16 | parser.add_argument("--num-episodes", type=int, default=60000, help="number of episodes") 17 | parser.add_argument("--num-adversaries", type=int, default=0, help="number of adversaries") 18 | parser.add_argument("--good-policy", type=str, default="maddpg", help="policy for good agents") 19 | parser.add_argument("--adv-policy", type=str, default="maddpg", help="policy of adversaries") 20 | # Core training parameters 21 | parser.add_argument("--lr", type=float, default=1e-2, help="learning rate for Adam optimizer") 22 | parser.add_argument("--gamma", type=float, default=0.95, help="discount factor") 23 | parser.add_argument("--batch-size", type=int, default=1024, help="number of episodes to optimize at the same time") 24 | parser.add_argument("--num-units", type=int, default=64, help="number of units in the mlp") 25 | # Checkpointing 26 | parser.add_argument("--exp-name", type=str, default=None, help="name of the experiment") 27 | parser.add_argument("--save-dir", type=str, default="/tmp/policy/", help="directory in which training state and model should be saved") 28 | parser.add_argument("--save-rate", type=int, default=1000, help="save model once every time this many episodes are completed") 29 | parser.add_argument("--load-dir", type=str, default="", help="directory in which training state and model are loaded") 30 | # Evaluation 31 | parser.add_argument("--restore", action="store_true", default=False) 32 | parser.add_argument("--display", action="store_true", default=False) 33 | parser.add_argument("--benchmark", action="store_true", default=False) 34 | parser.add_argument("--benchmark-iters", type=int, default=100000, help="number of iterations run for benchmarking") 35 | parser.add_argument("--benchmark-dir", type=str, default="./benchmark_files/", help="directory where benchmark data is saved") 36 | parser.add_argument("--plots-dir", type=str, default="./learning_curves/", help="directory where plot data is saved") 37 | return parser.parse_args() 38 | 39 | def mlp_model(input, num_outputs, scope, reuse=False, num_units=64, rnn_cell=None): 40 | # This model takes as input an observation and returns values of all actions 41 | with tf.variable_scope(scope, reuse=reuse): 42 | out = input 43 | out = layers.fully_connected(out, num_outputs=num_units, activation_fn=tf.nn.relu) 44 | out = layers.fully_connected(out, num_outputs=num_units, activation_fn=tf.nn.relu) 45 | out = layers.fully_connected(out, num_outputs=num_outputs, activation_fn=None) 46 | return out 47 | 48 | def make_env(scenario_name, arglist, benchmark=False): 49 | from multiagent.environment import MultiAgentEnv 50 | import multiagent.scenarios as scenarios 51 | 52 | # load scenario from script 53 | scenario = scenarios.load(scenario_name + 
".py").Scenario() 54 | # create world 55 | world = scenario.make_world() 56 | # create multiagent environment 57 | if benchmark: 58 | env = MultiAgentEnv(world, scenario.reset_world, scenario.reward, scenario.observation, scenario.benchmark_data) 59 | else: 60 | env = MultiAgentEnv(world, scenario.reset_world, scenario.reward, scenario.observation) 61 | return env 62 | 63 | def get_trainers(env, num_adversaries, obs_shape_n, arglist): 64 | trainers = [] 65 | model = mlp_model 66 | trainer = MADDPGAgentTrainer 67 | for i in range(num_adversaries): 68 | trainers.append(trainer( 69 | "agent_%d" % i, model, obs_shape_n, env.action_space, i, arglist, 70 | local_q_func=(arglist.adv_policy=='ddpg'))) 71 | for i in range(num_adversaries, env.n): 72 | trainers.append(trainer( 73 | "agent_%d" % i, model, obs_shape_n, env.action_space, i, arglist, 74 | local_q_func=(arglist.good_policy=='ddpg'))) 75 | return trainers 76 | 77 | 78 | def train(arglist): 79 | with U.single_threaded_session(): 80 | # Create environment 81 | env = make_env(arglist.scenario, arglist, arglist.benchmark) 82 | # Create agent trainers 83 | obs_shape_n = [env.observation_space[i].shape for i in range(env.n)] 84 | num_adversaries = min(env.n, arglist.num_adversaries) 85 | trainers = get_trainers(env, num_adversaries, obs_shape_n, arglist) 86 | print('Using good policy {} and adv policy {}'.format(arglist.good_policy, arglist.adv_policy)) 87 | 88 | # Initialize 89 | U.initialize() 90 | 91 | # Load previous results, if necessary 92 | if arglist.load_dir == "": 93 | arglist.load_dir = arglist.save_dir 94 | if arglist.display or arglist.restore or arglist.benchmark: 95 | print('Loading previous state...') 96 | U.load_state(arglist.load_dir) 97 | 98 | episode_rewards = [0.0] # sum of rewards for all agents 99 | agent_rewards = [[0.0] for _ in range(env.n)] # individual agent reward 100 | final_ep_rewards = [] # sum of rewards for training curve 101 | final_ep_ag_rewards = [] # agent rewards for training curve 102 | agent_info = [[[]]] # placeholder for benchmarking info 103 | saver = tf.train.Saver() 104 | obs_n = env.reset() 105 | episode_step = 0 106 | train_step = 0 107 | t_start = time.time() 108 | 109 | print('Starting iterations...') 110 | while True: 111 | # get action 112 | action_n = [agent.action(obs) for agent, obs in zip(trainers,obs_n)] 113 | # environment step 114 | new_obs_n, rew_n, done_n, info_n = env.step(action_n) 115 | episode_step += 1 116 | done = all(done_n) 117 | terminal = (episode_step >= arglist.max_episode_len) 118 | # collect experience 119 | for i, agent in enumerate(trainers): 120 | agent.experience(obs_n[i], action_n[i], rew_n[i], new_obs_n[i], done_n[i], terminal) 121 | obs_n = new_obs_n 122 | 123 | for i, rew in enumerate(rew_n): 124 | episode_rewards[-1] += rew 125 | agent_rewards[i][-1] += rew 126 | 127 | if done or terminal: 128 | obs_n = env.reset() 129 | episode_step = 0 130 | episode_rewards.append(0) 131 | for a in agent_rewards: 132 | a.append(0) 133 | agent_info.append([[]]) 134 | 135 | # increment global step counter 136 | train_step += 1 137 | 138 | # for benchmarking learned policies 139 | if arglist.benchmark: 140 | for i, info in enumerate(info_n): 141 | agent_info[-1][i].append(info_n['n']) 142 | if train_step > arglist.benchmark_iters and (done or terminal): 143 | file_name = arglist.benchmark_dir + arglist.exp_name + '.pkl' 144 | print('Finished benchmarking, now saving...') 145 | with open(file_name, 'wb') as fp: 146 | pickle.dump(agent_info[:-1], fp) 147 | break 148 | continue 149 
| 150 | # for displaying learned policies 151 | if arglist.display: 152 | time.sleep(0.1) 153 | env.render() 154 | continue 155 | 156 | # update all trainers, if not in display or benchmark mode 157 | loss = None 158 | for agent in trainers: 159 | agent.preupdate() 160 | for agent in trainers: 161 | loss = agent.update(trainers, train_step) 162 | 163 | # save model, display training output 164 | if terminal and (len(episode_rewards) % arglist.save_rate == 0): 165 | U.save_state(arglist.save_dir, saver=saver) 166 | # print statement depends on whether or not there are adversaries 167 | if num_adversaries == 0: 168 | print("steps: {}, episodes: {}, mean episode reward: {}, time: {}".format( 169 | train_step, len(episode_rewards), np.mean(episode_rewards[-arglist.save_rate:]), round(time.time()-t_start, 3))) 170 | else: 171 | print("steps: {}, episodes: {}, mean episode reward: {}, agent episode reward: {}, time: {}".format( 172 | train_step, len(episode_rewards), np.mean(episode_rewards[-arglist.save_rate:]), 173 | [np.mean(rew[-arglist.save_rate:]) for rew in agent_rewards], round(time.time()-t_start, 3))) 174 | t_start = time.time() 175 | # Keep track of final episode reward 176 | final_ep_rewards.append(np.mean(episode_rewards[-arglist.save_rate:])) 177 | for rew in agent_rewards: 178 | final_ep_ag_rewards.append(np.mean(rew[-arglist.save_rate:])) 179 | 180 | # saves final episode reward for plotting training curve later 181 | if len(episode_rewards) > arglist.num_episodes: 182 | rew_file_name = arglist.plots_dir + arglist.exp_name + '_rewards.pkl' 183 | with open(rew_file_name, 'wb') as fp: 184 | pickle.dump(final_ep_rewards, fp) 185 | agrew_file_name = arglist.plots_dir + arglist.exp_name + '_agrewards.pkl' 186 | with open(agrew_file_name, 'wb') as fp: 187 | pickle.dump(final_ep_ag_rewards, fp) 188 | print('...Finished total of {} episodes.'.format(len(episode_rewards))) 189 | break 190 | 191 | if __name__ == '__main__': 192 | arglist = parse_args() 193 | train(arglist) 194 | -------------------------------------------------------------------------------- /maddpg/__init__.py: -------------------------------------------------------------------------------- 1 | class AgentTrainer(object): 2 | def __init__(self, name, model, obs_shape, act_space, args): 3 | raise NotImplementedError() 4 | 5 | def action(self, obs): 6 | raise NotImplementedError() 7 | 8 | def process_experience(self, obs, act, rew, new_obs, done, terminal): 9 | raise NotImplementedError() 10 | 11 | def preupdate(self): 12 | raise NotImplementedError() 13 | 14 | def update(self, agents): 15 | raise NotImplementedError() -------------------------------------------------------------------------------- /maddpg/common/distributions.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import maddpg.common.tf_util as U 4 | from tensorflow.python.ops import math_ops 5 | from multiagent.multi_discrete import MultiDiscrete 6 | from tensorflow.python.ops import nn 7 | 8 | class Pd(object): 9 | """ 10 | A particular probability distribution 11 | """ 12 | def flatparam(self): 13 | raise NotImplementedError 14 | def mode(self): 15 | raise NotImplementedError 16 | def logp(self, x): 17 | raise NotImplementedError 18 | def kl(self, other): 19 | raise NotImplementedError 20 | def entropy(self): 21 | raise NotImplementedError 22 | def sample(self): 23 | raise NotImplementedError 24 | 25 | class PdType(object): 26 | """ 27 | Parametrized family of probability distributions 28 | """ 
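# Usage sketch (illustrative note, not code used by this module): p_train() in
# maddpg/trainer/maddpg.py builds a PdType from each agent's gym action space via
# make_pdtype() below and wraps the policy network's flat output in the matching Pd:
#   pdtype = make_pdtype(act_space)                            # act_space: a gym action space
#   flat = policy_mlp(obs_ph, int(pdtype.param_shape()[0]))    # policy_mlp / obs_ph are stand-ins
#   act_sample = pdtype.pdfromflat(flat).sample()              # action op fed to the critic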
29 | def pdclass(self): 30 | raise NotImplementedError 31 | def pdfromflat(self, flat): 32 | return self.pdclass()(flat) 33 | def param_shape(self): 34 | raise NotImplementedError 35 | def sample_shape(self): 36 | raise NotImplementedError 37 | def sample_dtype(self): 38 | raise NotImplementedError 39 | 40 | def param_placeholder(self, prepend_shape, name=None): 41 | return tf.placeholder(dtype=tf.float32, shape=prepend_shape+self.param_shape(), name=name) 42 | def sample_placeholder(self, prepend_shape, name=None): 43 | return tf.placeholder(dtype=self.sample_dtype(), shape=prepend_shape+self.sample_shape(), name=name) 44 | 45 | class CategoricalPdType(PdType): 46 | def __init__(self, ncat): 47 | self.ncat = ncat 48 | def pdclass(self): 49 | return CategoricalPd 50 | def param_shape(self): 51 | return [self.ncat] 52 | def sample_shape(self): 53 | return [] 54 | def sample_dtype(self): 55 | return tf.int32 56 | 57 | class SoftCategoricalPdType(PdType): 58 | def __init__(self, ncat): 59 | self.ncat = ncat 60 | def pdclass(self): 61 | return SoftCategoricalPd 62 | def param_shape(self): 63 | return [self.ncat] 64 | def sample_shape(self): 65 | return [self.ncat] 66 | def sample_dtype(self): 67 | return tf.float32 68 | 69 | class MultiCategoricalPdType(PdType): 70 | def __init__(self, low, high): 71 | self.low = low 72 | self.high = high 73 | self.ncats = high - low + 1 74 | def pdclass(self): 75 | return MultiCategoricalPd 76 | def pdfromflat(self, flat): 77 | return MultiCategoricalPd(self.low, self.high, flat) 78 | def param_shape(self): 79 | return [sum(self.ncats)] 80 | def sample_shape(self): 81 | return [len(self.ncats)] 82 | def sample_dtype(self): 83 | return tf.int32 84 | 85 | class SoftMultiCategoricalPdType(PdType): 86 | def __init__(self, low, high): 87 | self.low = low 88 | self.high = high 89 | self.ncats = high - low + 1 90 | def pdclass(self): 91 | return SoftMultiCategoricalPd 92 | def pdfromflat(self, flat): 93 | return SoftMultiCategoricalPd(self.low, self.high, flat) 94 | def param_shape(self): 95 | return [sum(self.ncats)] 96 | def sample_shape(self): 97 | return [sum(self.ncats)] 98 | def sample_dtype(self): 99 | return tf.float32 100 | 101 | class DiagGaussianPdType(PdType): 102 | def __init__(self, size): 103 | self.size = size 104 | def pdclass(self): 105 | return DiagGaussianPd 106 | def param_shape(self): 107 | return [2*self.size] 108 | def sample_shape(self): 109 | return [self.size] 110 | def sample_dtype(self): 111 | return tf.float32 112 | 113 | class BernoulliPdType(PdType): 114 | def __init__(self, size): 115 | self.size = size 116 | def pdclass(self): 117 | return BernoulliPd 118 | def param_shape(self): 119 | return [self.size] 120 | def sample_shape(self): 121 | return [self.size] 122 | def sample_dtype(self): 123 | return tf.int32 124 | 125 | # WRONG SECOND DERIVATIVES 126 | # class CategoricalPd(Pd): 127 | # def __init__(self, logits): 128 | # self.logits = logits 129 | # self.ps = tf.nn.softmax(logits) 130 | # @classmethod 131 | # def fromflat(cls, flat): 132 | # return cls(flat) 133 | # def flatparam(self): 134 | # return self.logits 135 | # def mode(self): 136 | # return U.argmax(self.logits, axis=1) 137 | # def logp(self, x): 138 | # return -tf.nn.sparse_softmax_cross_entropy_with_logits(self.logits, x) 139 | # def kl(self, other): 140 | # return tf.nn.softmax_cross_entropy_with_logits(other.logits, self.ps) \ 141 | # - tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps) 142 | # def entropy(self): 143 | # return 
tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps) 144 | # def sample(self): 145 | # u = tf.random_uniform(tf.shape(self.logits)) 146 | # return U.argmax(self.logits - tf.log(-tf.log(u)), axis=1) 147 | 148 | class CategoricalPd(Pd): 149 | def __init__(self, logits): 150 | self.logits = logits 151 | def flatparam(self): 152 | return self.logits 153 | def mode(self): 154 | return U.argmax(self.logits, axis=1) 155 | def logp(self, x): 156 | return -tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x) 157 | def kl(self, other): 158 | a0 = self.logits - U.max(self.logits, axis=1, keepdims=True) 159 | a1 = other.logits - U.max(other.logits, axis=1, keepdims=True) 160 | ea0 = tf.exp(a0) 161 | ea1 = tf.exp(a1) 162 | z0 = U.sum(ea0, axis=1, keepdims=True) 163 | z1 = U.sum(ea1, axis=1, keepdims=True) 164 | p0 = ea0 / z0 165 | return U.sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), axis=1) 166 | def entropy(self): 167 | a0 = self.logits - U.max(self.logits, axis=1, keepdims=True) 168 | ea0 = tf.exp(a0) 169 | z0 = U.sum(ea0, axis=1, keepdims=True) 170 | p0 = ea0 / z0 171 | return U.sum(p0 * (tf.log(z0) - a0), axis=1) 172 | def sample(self): 173 | u = tf.random_uniform(tf.shape(self.logits)) 174 | return U.argmax(self.logits - tf.log(-tf.log(u)), axis=1) 175 | @classmethod 176 | def fromflat(cls, flat): 177 | return cls(flat) 178 | 179 | class SoftCategoricalPd(Pd): 180 | def __init__(self, logits): 181 | self.logits = logits 182 | def flatparam(self): 183 | return self.logits 184 | def mode(self): 185 | return U.softmax(self.logits, axis=-1) 186 | def logp(self, x): 187 | return -tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=x) 188 | def kl(self, other): 189 | a0 = self.logits - U.max(self.logits, axis=1, keepdims=True) 190 | a1 = other.logits - U.max(other.logits, axis=1, keepdims=True) 191 | ea0 = tf.exp(a0) 192 | ea1 = tf.exp(a1) 193 | z0 = U.sum(ea0, axis=1, keepdims=True) 194 | z1 = U.sum(ea1, axis=1, keepdims=True) 195 | p0 = ea0 / z0 196 | return U.sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), axis=1) 197 | def entropy(self): 198 | a0 = self.logits - U.max(self.logits, axis=1, keepdims=True) 199 | ea0 = tf.exp(a0) 200 | z0 = U.sum(ea0, axis=1, keepdims=True) 201 | p0 = ea0 / z0 202 | return U.sum(p0 * (tf.log(z0) - a0), axis=1) 203 | def sample(self): 204 | u = tf.random_uniform(tf.shape(self.logits)) 205 | return U.softmax(self.logits - tf.log(-tf.log(u)), axis=-1) 206 | @classmethod 207 | def fromflat(cls, flat): 208 | return cls(flat) 209 | 210 | class MultiCategoricalPd(Pd): 211 | def __init__(self, low, high, flat): 212 | self.flat = flat 213 | self.low = tf.constant(low, dtype=tf.int32) 214 | self.categoricals = list(map(CategoricalPd, tf.split(flat, high - low + 1, axis=len(flat.get_shape()) - 1))) 215 | def flatparam(self): 216 | return self.flat 217 | def mode(self): 218 | return self.low + tf.cast(tf.stack([p.mode() for p in self.categoricals], axis=-1), tf.int32) 219 | def logp(self, x): 220 | return tf.add_n([p.logp(px) for p, px in zip(self.categoricals, tf.unstack(x - self.low, axis=len(x.get_shape()) - 1))]) 221 | def kl(self, other): 222 | return tf.add_n([ 223 | p.kl(q) for p, q in zip(self.categoricals, other.categoricals) 224 | ]) 225 | def entropy(self): 226 | return tf.add_n([p.entropy() for p in self.categoricals]) 227 | def sample(self): 228 | return self.low + tf.cast(tf.stack([p.sample() for p in self.categoricals], axis=-1), tf.int32) 229 | @classmethod 230 | def fromflat(cls, flat): 231 | return cls(flat) 232 | 
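# Note on the samplers above (illustrative): with u ~ Uniform(0, 1), -log(-log(u)) is standard
# Gumbel noise, so CategoricalPd.sample() is the Gumbel-max trick (an exact categorical draw),
# while SoftCategoricalPd.sample() swaps the argmax for a softmax, i.e. a Gumbel-softmax
# relaxation (temperature 1) whose output remains differentiable w.r.t. the logits; this is what
# lets the policy gradient in maddpg.py flow through sampled discrete actions. A quick NumPy
# check of the hard version:
#   import numpy as np
#   logits = np.log([0.2, 0.3, 0.5])
#   u = np.random.rand(100000, 3)
#   counts = np.bincount(np.argmax(logits - np.log(-np.log(u)), axis=1), minlength=3)
#   print(counts / counts.sum())   # roughly [0.2, 0.3, 0.5]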
233 | class SoftMultiCategoricalPd(Pd): # doesn't work yet 234 | def __init__(self, low, high, flat): 235 | self.flat = flat 236 | self.low = tf.constant(low, dtype=tf.float32) 237 | self.categoricals = list(map(SoftCategoricalPd, tf.split(flat, high - low + 1, axis=len(flat.get_shape()) - 1))) 238 | def flatparam(self): 239 | return self.flat 240 | def mode(self): 241 | x = [] 242 | for i in range(len(self.categoricals)): 243 | x.append(self.low[i] + self.categoricals[i].mode()) 244 | return tf.concat(x, axis=-1) 245 | def logp(self, x): 246 | return tf.add_n([p.logp(px) for p, px in zip(self.categoricals, tf.unstack(x - self.low, axis=len(x.get_shape()) - 1))]) 247 | def kl(self, other): 248 | return tf.add_n([ 249 | p.kl(q) for p, q in zip(self.categoricals, other.categoricals) 250 | ]) 251 | def entropy(self): 252 | return tf.add_n([p.entropy() for p in self.categoricals]) 253 | def sample(self): 254 | x = [] 255 | for i in range(len(self.categoricals)): 256 | x.append(self.low[i] + self.categoricals[i].sample()) 257 | return tf.concat(x, axis=-1) 258 | @classmethod 259 | def fromflat(cls, flat): 260 | return cls(flat) 261 | 262 | class DiagGaussianPd(Pd): 263 | def __init__(self, flat): 264 | self.flat = flat 265 | mean, logstd = tf.split(axis=1, num_or_size_splits=2, value=flat) 266 | self.mean = mean 267 | self.logstd = logstd 268 | self.std = tf.exp(logstd) 269 | def flatparam(self): 270 | return self.flat 271 | def mode(self): 272 | return self.mean 273 | def logp(self, x): 274 | return - 0.5 * U.sum(tf.square((x - self.mean) / self.std), axis=1) \ 275 | - 0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[1]) \ 276 | - U.sum(self.logstd, axis=1) 277 | def kl(self, other): 278 | assert isinstance(other, DiagGaussianPd) 279 | return U.sum(other.logstd - self.logstd + (tf.square(self.std) + tf.square(self.mean - other.mean)) / (2.0 * tf.square(other.std)) - 0.5, axis=1) 280 | def entropy(self): 281 | return U.sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e), 1) 282 | def sample(self): 283 | return self.mean + self.std * tf.random_normal(tf.shape(self.mean)) 284 | @classmethod 285 | def fromflat(cls, flat): 286 | return cls(flat) 287 | 288 | class BernoulliPd(Pd): 289 | def __init__(self, logits): 290 | self.logits = logits 291 | self.ps = tf.sigmoid(logits) 292 | def flatparam(self): 293 | return self.logits 294 | def mode(self): 295 | return tf.round(self.ps) 296 | def logp(self, x): 297 | return - U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=tf.to_float(x)), axis=1) 298 | def kl(self, other): 299 | return U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=other.logits, labels=self.ps), axis=1) - U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=1) 300 | def entropy(self): 301 | return U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=1) 302 | def sample(self): 303 | p = tf.sigmoid(self.logits) 304 | u = tf.random_uniform(tf.shape(p)) 305 | return tf.to_float(math_ops.less(u, p)) 306 | @classmethod 307 | def fromflat(cls, flat): 308 | return cls(flat) 309 | 310 | def make_pdtype(ac_space): 311 | from gym import spaces 312 | if isinstance(ac_space, spaces.Box): 313 | assert len(ac_space.shape) == 1 314 | return DiagGaussianPdType(ac_space.shape[0]) 315 | elif isinstance(ac_space, spaces.Discrete): 316 | # return CategoricalPdType(ac_space.n) 317 | return SoftCategoricalPdType(ac_space.n) 318 | elif isinstance(ac_space, MultiDiscrete): 319 | #return 
MultiCategoricalPdType(ac_space.low, ac_space.high) 320 | return SoftMultiCategoricalPdType(ac_space.low, ac_space.high) 321 | elif isinstance(ac_space, spaces.MultiBinary): 322 | return BernoulliPdType(ac_space.n) 323 | else: 324 | raise NotImplementedError 325 | 326 | def shape_el(v, i): 327 | maybe = v.get_shape()[i] 328 | if maybe is not None: 329 | return maybe 330 | else: 331 | return tf.shape(v)[i] 332 | -------------------------------------------------------------------------------- /maddpg/common/tf_util.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import numpy as np 3 | import os 4 | import tensorflow as tf 5 | 6 | def sum(x, axis=None, keepdims=False): 7 | return tf.reduce_sum(x, axis=None if axis is None else [axis], keep_dims = keepdims) 8 | def mean(x, axis=None, keepdims=False): 9 | return tf.reduce_mean(x, axis=None if axis is None else [axis], keep_dims = keepdims) 10 | def var(x, axis=None, keepdims=False): 11 | meanx = mean(x, axis=axis, keepdims=keepdims) 12 | return mean(tf.square(x - meanx), axis=axis, keepdims=keepdims) 13 | def std(x, axis=None, keepdims=False): 14 | return tf.sqrt(var(x, axis=axis, keepdims=keepdims)) 15 | def max(x, axis=None, keepdims=False): 16 | return tf.reduce_max(x, axis=None if axis is None else [axis], keep_dims = keepdims) 17 | def min(x, axis=None, keepdims=False): 18 | return tf.reduce_min(x, axis=None if axis is None else [axis], keep_dims = keepdims) 19 | def concatenate(arrs, axis=0): 20 | return tf.concat(axis=axis, values=arrs) 21 | def argmax(x, axis=None): 22 | return tf.argmax(x, axis=axis) 23 | def softmax(x, axis=None): 24 | return tf.nn.softmax(x, axis=axis) 25 | 26 | # ================================================================ 27 | # Misc 28 | # ================================================================ 29 | 30 | 31 | def is_placeholder(x): 32 | return type(x) is tf.Tensor and len(x.op.inputs) == 0 33 | 34 | # ================================================================ 35 | # Inputs 36 | # ================================================================ 37 | 38 | 39 | class TfInput(object): 40 | def __init__(self, name="(unnamed)"): 41 | """Generalized Tensorflow placeholder. The main differences are: 42 | - possibly uses multiple placeholders internally and returns multiple values 43 | - can apply light postprocessing to the value feed to placeholder. 44 | """ 45 | self.name = name 46 | 47 | def get(self): 48 | """Return the tf variable(s) representing the possibly postprocessed value 49 | of placeholder(s). 
50 | """ 51 | raise NotImplementedError() 52 | 53 | def make_feed_dict(self, data): 54 | """Given data, input it to the placeholder(s).""" 55 | raise NotImplementedError() 56 | 57 | 58 | class PlacholderTfInput(TfInput): 59 | def __init__(self, placeholder): 60 | """Wrapper for regular tensorflow placeholder.""" 61 | super().__init__(placeholder.name) 62 | self._placeholder = placeholder 63 | 64 | def get(self): 65 | return self._placeholder 66 | 67 | def make_feed_dict(self, data): 68 | return {self._placeholder: data} 69 | 70 | 71 | class BatchInput(PlacholderTfInput): 72 | def __init__(self, shape, dtype=tf.float32, name=None): 73 | """Creates a placeholder for a batch of tensors of a given shape and dtype 74 | 75 | Parameters 76 | ---------- 77 | shape: [int] 78 | shape of a single element of the batch 79 | dtype: tf.dtype 80 | number representation used for tensor contents 81 | name: str 82 | name of the underlying placeholder 83 | """ 84 | super().__init__(tf.placeholder(dtype, [None] + list(shape), name=name)) 85 | 86 | 87 | class Uint8Input(PlacholderTfInput): 88 | def __init__(self, shape, name=None): 89 | """Takes input in uint8 format which is cast to float32 and divided by 255 90 | before passing it to the model. 91 | 92 | On GPU this ensures lower data transfer times. 93 | 94 | Parameters 95 | ---------- 96 | shape: [int] 97 | shape of the tensor. 98 | name: str 99 | name of the underlying placeholder 100 | """ 101 | 102 | super().__init__(tf.placeholder(tf.uint8, [None] + list(shape), name=name)) 103 | self._shape = shape 104 | self._output = tf.cast(super().get(), tf.float32) / 255.0 105 | 106 | def get(self): 107 | return self._output 108 | 109 | 110 | def ensure_tf_input(thing): 111 | """Takes either tf.placeholder or TfInput and outputs equivalent TfInput""" 112 | if isinstance(thing, TfInput): 113 | return thing 114 | elif is_placeholder(thing): 115 | return PlacholderTfInput(thing) 116 | else: 117 | raise ValueError("Must be a placeholder or TfInput") 118 | 119 | # ================================================================ 120 | # Mathematical utils 121 | # ================================================================ 122 | 123 | 124 | def huber_loss(x, delta=1.0): 125 | """Reference: https://en.wikipedia.org/wiki/Huber_loss""" 126 | return tf.where( 127 | tf.abs(x) < delta, 128 | tf.square(x) * 0.5, 129 | delta * (tf.abs(x) - 0.5 * delta) 130 | ) 131 | 132 | # ================================================================ 133 | # Optimizer utils 134 | # ================================================================ 135 | 136 | 137 | def minimize_and_clip(optimizer, objective, var_list, clip_val=10): 138 | """Minimizes `objective` using `optimizer` w.r.t. 
variables in 139 | `var_list` while ensuring the norm of the gradients for each 140 | variable is clipped to `clip_val` 141 | """ 142 | if clip_val is None: 143 | return optimizer.minimize(objective, var_list=var_list) 144 | else: 145 | gradients = optimizer.compute_gradients(objective, var_list=var_list) 146 | for i, (grad, var) in enumerate(gradients): 147 | if grad is not None: 148 | gradients[i] = (tf.clip_by_norm(grad, clip_val), var) 149 | return optimizer.apply_gradients(gradients) 150 | 151 | 152 | # ================================================================ 153 | # Global session 154 | # ================================================================ 155 | 156 | def get_session(): 157 | """Returns the most recently made TensorFlow session""" 158 | return tf.get_default_session() 159 | 160 | 161 | def make_session(num_cpu): 162 | """Returns a session that will use CPUs only""" 163 | tf_config = tf.ConfigProto( 164 | inter_op_parallelism_threads=num_cpu, 165 | intra_op_parallelism_threads=num_cpu) 166 | return tf.Session(config=tf_config) 167 | 168 | 169 | def single_threaded_session(): 170 | """Returns a session which will only use a single CPU""" 171 | return make_session(1) 172 | 173 | 174 | ALREADY_INITIALIZED = set() 175 | 176 | 177 | def initialize(): 178 | """Initialize all the uninitialized variables in the global scope.""" 179 | new_variables = set(tf.global_variables()) - ALREADY_INITIALIZED 180 | get_session().run(tf.variables_initializer(new_variables)) 181 | ALREADY_INITIALIZED.update(new_variables) 182 | 183 | 184 | # ================================================================ 185 | # Scopes 186 | # ================================================================ 187 | 188 | 189 | def scope_vars(scope, trainable_only=False): 190 | """ 191 | Get variables inside a scope 192 | The scope can be specified as a string 193 | 194 | Parameters 195 | ---------- 196 | scope: str or VariableScope 197 | scope in which the variables reside. 198 | trainable_only: bool 199 | whether or not to return only the variables that were marked as trainable. 200 | 201 | Returns 202 | ------- 203 | vars: [tf.Variable] 204 | list of variables in `scope`. 205 | """ 206 | return tf.get_collection( 207 | tf.GraphKeys.TRAINABLE_VARIABLES if trainable_only else tf.GraphKeys.GLOBAL_VARIABLES, 208 | scope=scope if isinstance(scope, str) else scope.name 209 | ) 210 | 211 | 212 | def scope_name(): 213 | """Returns the name of current scope as a string, e.g. 
deepq/q_func""" 214 | return tf.get_variable_scope().name 215 | 216 | 217 | def absolute_scope_name(relative_scope_name): 218 | """Appends parent scope name to `relative_scope_name`""" 219 | return scope_name() + "/" + relative_scope_name 220 | 221 | # ================================================================ 222 | # Saving variables 223 | # ================================================================ 224 | 225 | 226 | def load_state(fname, saver=None): 227 | """Load all the variables to the current session from the location """ 228 | if saver is None: 229 | saver = tf.train.Saver() 230 | saver.restore(get_session(), fname) 231 | return saver 232 | 233 | 234 | def save_state(fname, saver=None): 235 | """Save all the variables in the current session to the location """ 236 | os.makedirs(os.path.dirname(fname), exist_ok=True) 237 | if saver is None: 238 | saver = tf.train.Saver() 239 | saver.save(get_session(), fname) 240 | return saver 241 | 242 | # ================================================================ 243 | # Theano-like Function 244 | # ================================================================ 245 | 246 | 247 | def function(inputs, outputs, updates=None, givens=None): 248 | """Just like Theano function. Takes a bunch of tensorflow placeholders and expressions 249 | computed based on those placeholders and produces f(inputs) -> outputs. Function f takes 250 | values to be fed to the inputs placeholders and produces the values of the expressions 251 | in outputs. 252 | 253 | Input values can be passed in the same order as inputs or can be provided as kwargs based 254 | on placeholder name (passed to constructor or accessible via placeholder.op.name). 255 | 256 | Example: 257 | x = tf.placeholder(tf.int32, (), name="x") 258 | y = tf.placeholder(tf.int32, (), name="y") 259 | z = 3 * x + 2 * y 260 | lin = function([x, y], z, givens={y: 0}) 261 | 262 | with single_threaded_session(): 263 | initialize() 264 | 265 | assert lin(2) == 6 266 | assert lin(x=3) == 9 267 | assert lin(2, 2) == 10 268 | assert lin(x=2, y=3) == 12 269 | 270 | Parameters 271 | ---------- 272 | inputs: [tf.placeholder or TfInput] 273 | list of input arguments 274 | outputs: [tf.Variable] or tf.Variable 275 | list of outputs or a single output to be returned from function. Returned 276 | value will also have the same shape. 
277 | """ 278 | if isinstance(outputs, list): 279 | return _Function(inputs, outputs, updates, givens=givens) 280 | elif isinstance(outputs, (dict, collections.OrderedDict)): 281 | f = _Function(inputs, outputs.values(), updates, givens=givens) 282 | return lambda *args, **kwargs: type(outputs)(zip(outputs.keys(), f(*args, **kwargs))) 283 | else: 284 | f = _Function(inputs, [outputs], updates, givens=givens) 285 | return lambda *args, **kwargs: f(*args, **kwargs)[0] 286 | 287 | 288 | class _Function(object): 289 | def __init__(self, inputs, outputs, updates, givens, check_nan=False): 290 | for inpt in inputs: 291 | if not issubclass(type(inpt), TfInput): 292 | assert len(inpt.op.inputs) == 0, "inputs should all be placeholders of rl_algs.common.TfInput" 293 | self.inputs = inputs 294 | updates = updates or [] 295 | self.update_group = tf.group(*updates) 296 | self.outputs_update = list(outputs) + [self.update_group] 297 | self.givens = {} if givens is None else givens 298 | self.check_nan = check_nan 299 | 300 | def _feed_input(self, feed_dict, inpt, value): 301 | if issubclass(type(inpt), TfInput): 302 | feed_dict.update(inpt.make_feed_dict(value)) 303 | elif is_placeholder(inpt): 304 | feed_dict[inpt] = value 305 | 306 | def __call__(self, *args, **kwargs): 307 | assert len(args) <= len(self.inputs), "Too many arguments provided" 308 | feed_dict = {} 309 | # Update the args 310 | for inpt, value in zip(self.inputs, args): 311 | self._feed_input(feed_dict, inpt, value) 312 | # Update the kwargs 313 | kwargs_passed_inpt_names = set() 314 | for inpt in self.inputs[len(args):]: 315 | inpt_name = inpt.name.split(':')[0] 316 | inpt_name = inpt_name.split('/')[-1] 317 | assert inpt_name not in kwargs_passed_inpt_names, \ 318 | "this function has two arguments with the same name \"{}\", so kwargs cannot be used.".format(inpt_name) 319 | if inpt_name in kwargs: 320 | kwargs_passed_inpt_names.add(inpt_name) 321 | self._feed_input(feed_dict, inpt, kwargs.pop(inpt_name)) 322 | else: 323 | assert inpt in self.givens, "Missing argument " + inpt_name 324 | assert len(kwargs) == 0, "Function got extra arguments " + str(list(kwargs.keys())) 325 | # Update feed dict with givens. 
326 | for inpt in self.givens: 327 | feed_dict[inpt] = feed_dict.get(inpt, self.givens[inpt]) 328 | results = get_session().run(self.outputs_update, feed_dict=feed_dict)[:-1] 329 | if self.check_nan: 330 | if any(np.isnan(r).any() for r in results): 331 | raise RuntimeError("Nan detected") 332 | return results 333 | -------------------------------------------------------------------------------- /maddpg/trainer/maddpg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import tensorflow as tf 4 | import maddpg.common.tf_util as U 5 | 6 | from maddpg.common.distributions import make_pdtype 7 | from maddpg import AgentTrainer 8 | from maddpg.trainer.replay_buffer import ReplayBuffer 9 | 10 | 11 | def discount_with_dones(rewards, dones, gamma): 12 | discounted = [] 13 | r = 0 14 | for reward, done in zip(rewards[::-1], dones[::-1]): 15 | r = reward + gamma*r 16 | r = r*(1.-done) 17 | discounted.append(r) 18 | return discounted[::-1] 19 | 20 | def make_update_exp(vals, target_vals): 21 | polyak = 1.0 - 1e-2 22 | expression = [] 23 | for var, var_target in zip(sorted(vals, key=lambda v: v.name), sorted(target_vals, key=lambda v: v.name)): 24 | expression.append(var_target.assign(polyak * var_target + (1.0-polyak) * var)) 25 | expression = tf.group(*expression) 26 | return U.function([], [], updates=[expression]) 27 | 28 | def p_train(make_obs_ph_n, act_space_n, p_index, p_func, q_func, optimizer, grad_norm_clipping=None, local_q_func=False, num_units=64, scope="trainer", reuse=None): 29 | with tf.variable_scope(scope, reuse=reuse): 30 | # create distribtuions 31 | act_pdtype_n = [make_pdtype(act_space) for act_space in act_space_n] 32 | 33 | # set up placeholders 34 | obs_ph_n = make_obs_ph_n 35 | act_ph_n = [act_pdtype_n[i].sample_placeholder([None], name="action"+str(i)) for i in range(len(act_space_n))] 36 | 37 | p_input = obs_ph_n[p_index] 38 | 39 | p = p_func(p_input, int(act_pdtype_n[p_index].param_shape()[0]), scope="p_func", num_units=num_units) 40 | p_func_vars = U.scope_vars(U.absolute_scope_name("p_func")) 41 | 42 | # wrap parameters in distribution 43 | act_pd = act_pdtype_n[p_index].pdfromflat(p) 44 | 45 | act_sample = act_pd.sample() 46 | p_reg = tf.reduce_mean(tf.square(act_pd.flatparam())) 47 | 48 | act_input_n = act_ph_n + [] 49 | act_input_n[p_index] = act_pd.sample() 50 | q_input = tf.concat(obs_ph_n + act_input_n, 1) 51 | if local_q_func: 52 | q_input = tf.concat([obs_ph_n[p_index], act_input_n[p_index]], 1) 53 | q = q_func(q_input, 1, scope="q_func", reuse=True, num_units=num_units)[:,0] 54 | pg_loss = -tf.reduce_mean(q) 55 | 56 | loss = pg_loss + p_reg * 1e-3 57 | 58 | optimize_expr = U.minimize_and_clip(optimizer, loss, p_func_vars, grad_norm_clipping) 59 | 60 | # Create callable functions 61 | train = U.function(inputs=obs_ph_n + act_ph_n, outputs=loss, updates=[optimize_expr]) 62 | act = U.function(inputs=[obs_ph_n[p_index]], outputs=act_sample) 63 | p_values = U.function([obs_ph_n[p_index]], p) 64 | 65 | # target network 66 | target_p = p_func(p_input, int(act_pdtype_n[p_index].param_shape()[0]), scope="target_p_func", num_units=num_units) 67 | target_p_func_vars = U.scope_vars(U.absolute_scope_name("target_p_func")) 68 | update_target_p = make_update_exp(p_func_vars, target_p_func_vars) 69 | 70 | target_act_sample = act_pdtype_n[p_index].pdfromflat(target_p).sample() 71 | target_act = U.function(inputs=[obs_ph_n[p_index]], outputs=target_act_sample) 72 | 73 | return act, train, 
update_target_p, {'p_values': p_values, 'target_act': target_act} 74 | 75 | def q_train(make_obs_ph_n, act_space_n, q_index, q_func, optimizer, grad_norm_clipping=None, local_q_func=False, scope="trainer", reuse=None, num_units=64): 76 | with tf.variable_scope(scope, reuse=reuse): 77 | # create distribtuions 78 | act_pdtype_n = [make_pdtype(act_space) for act_space in act_space_n] 79 | 80 | # set up placeholders 81 | obs_ph_n = make_obs_ph_n 82 | act_ph_n = [act_pdtype_n[i].sample_placeholder([None], name="action"+str(i)) for i in range(len(act_space_n))] 83 | target_ph = tf.placeholder(tf.float32, [None], name="target") 84 | 85 | q_input = tf.concat(obs_ph_n + act_ph_n, 1) 86 | if local_q_func: 87 | q_input = tf.concat([obs_ph_n[q_index], act_ph_n[q_index]], 1) 88 | q = q_func(q_input, 1, scope="q_func", num_units=num_units)[:,0] 89 | q_func_vars = U.scope_vars(U.absolute_scope_name("q_func")) 90 | 91 | q_loss = tf.reduce_mean(tf.square(q - target_ph)) 92 | 93 | # viscosity solution to Bellman differential equation in place of an initial condition 94 | q_reg = tf.reduce_mean(tf.square(q)) 95 | loss = q_loss #+ 1e-3 * q_reg 96 | 97 | optimize_expr = U.minimize_and_clip(optimizer, loss, q_func_vars, grad_norm_clipping) 98 | 99 | # Create callable functions 100 | train = U.function(inputs=obs_ph_n + act_ph_n + [target_ph], outputs=loss, updates=[optimize_expr]) 101 | q_values = U.function(obs_ph_n + act_ph_n, q) 102 | 103 | # target network 104 | target_q = q_func(q_input, 1, scope="target_q_func", num_units=num_units)[:,0] 105 | target_q_func_vars = U.scope_vars(U.absolute_scope_name("target_q_func")) 106 | update_target_q = make_update_exp(q_func_vars, target_q_func_vars) 107 | 108 | target_q_values = U.function(obs_ph_n + act_ph_n, target_q) 109 | 110 | return train, update_target_q, {'q_values': q_values, 'target_q_values': target_q_values} 111 | 112 | class MADDPGAgentTrainer(AgentTrainer): 113 | def __init__(self, name, model, obs_shape_n, act_space_n, agent_index, args, local_q_func=False): 114 | self.name = name 115 | self.n = len(obs_shape_n) 116 | self.agent_index = agent_index 117 | self.args = args 118 | obs_ph_n = [] 119 | for i in range(self.n): 120 | obs_ph_n.append(U.BatchInput(obs_shape_n[i], name="observation"+str(i)).get()) 121 | 122 | # Create all the functions necessary to train the model 123 | self.q_train, self.q_update, self.q_debug = q_train( 124 | scope=self.name, 125 | make_obs_ph_n=obs_ph_n, 126 | act_space_n=act_space_n, 127 | q_index=agent_index, 128 | q_func=model, 129 | optimizer=tf.train.AdamOptimizer(learning_rate=args.lr), 130 | grad_norm_clipping=0.5, 131 | local_q_func=local_q_func, 132 | num_units=args.num_units 133 | ) 134 | self.act, self.p_train, self.p_update, self.p_debug = p_train( 135 | scope=self.name, 136 | make_obs_ph_n=obs_ph_n, 137 | act_space_n=act_space_n, 138 | p_index=agent_index, 139 | p_func=model, 140 | q_func=model, 141 | optimizer=tf.train.AdamOptimizer(learning_rate=args.lr), 142 | grad_norm_clipping=0.5, 143 | local_q_func=local_q_func, 144 | num_units=args.num_units 145 | ) 146 | # Create experience buffer 147 | self.replay_buffer = ReplayBuffer(1e6) 148 | self.max_replay_buffer_len = args.batch_size * args.max_episode_len 149 | self.replay_sample_index = None 150 | 151 | def action(self, obs): 152 | return self.act(obs[None])[0] 153 | 154 | def experience(self, obs, act, rew, new_obs, done, terminal): 155 | # Store transition in the replay buffer. 
156 | self.replay_buffer.add(obs, act, rew, new_obs, float(done)) 157 | 158 | def preupdate(self): 159 | self.replay_sample_index = None 160 | 161 | def update(self, agents, t): 162 | if len(self.replay_buffer) < self.max_replay_buffer_len: # replay buffer is not large enough 163 | return 164 | if not t % 100 == 0: # only update every 100 steps 165 | return 166 | 167 | self.replay_sample_index = self.replay_buffer.make_index(self.args.batch_size) 168 | # collect replay sample from all agents 169 | obs_n = [] 170 | obs_next_n = [] 171 | act_n = [] 172 | index = self.replay_sample_index 173 | for i in range(self.n): 174 | obs, act, rew, obs_next, done = agents[i].replay_buffer.sample_index(index) 175 | obs_n.append(obs) 176 | obs_next_n.append(obs_next) 177 | act_n.append(act) 178 | obs, act, rew, obs_next, done = self.replay_buffer.sample_index(index) 179 | 180 | # train q network 181 | num_sample = 1 182 | target_q = 0.0 183 | for i in range(num_sample): 184 | target_act_next_n = [agents[i].p_debug['target_act'](obs_next_n[i]) for i in range(self.n)] 185 | target_q_next = self.q_debug['target_q_values'](*(obs_next_n + target_act_next_n)) 186 | target_q += rew + self.args.gamma * (1.0 - done) * target_q_next 187 | target_q /= num_sample 188 | q_loss = self.q_train(*(obs_n + act_n + [target_q])) 189 | 190 | # train p network 191 | p_loss = self.p_train(*(obs_n + act_n)) 192 | 193 | self.p_update() 194 | self.q_update() 195 | 196 | return [q_loss, p_loss, np.mean(target_q), np.mean(rew), np.mean(target_q_next), np.std(target_q)] 197 | -------------------------------------------------------------------------------- /maddpg/trainer/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | 4 | class ReplayBuffer(object): 5 | def __init__(self, size): 6 | """Create a replay buffer. 7 | 8 | Parameters 9 | ---------- 10 | size: int 11 | Max number of transitions to store in the buffer. When the buffer 12 | overflows, the oldest memories are dropped. 
13 | """ 14 | self._storage = [] 15 | self._maxsize = int(size) 16 | self._next_idx = 0 17 | 18 | def __len__(self): 19 | return len(self._storage) 20 | 21 | def clear(self): 22 | self._storage = [] 23 | self._next_idx = 0 24 | 25 | def add(self, obs_t, action, reward, obs_tp1, done): 26 | data = (obs_t, action, reward, obs_tp1, done) 27 | 28 | if self._next_idx >= len(self._storage): 29 | self._storage.append(data) 30 | else: 31 | self._storage[self._next_idx] = data 32 | self._next_idx = (self._next_idx + 1) % self._maxsize 33 | 34 | def _encode_sample(self, idxes): 35 | obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], [] 36 | for i in idxes: 37 | data = self._storage[i] 38 | obs_t, action, reward, obs_tp1, done = data 39 | obses_t.append(np.array(obs_t, copy=False)) 40 | actions.append(np.array(action, copy=False)) 41 | rewards.append(reward) 42 | obses_tp1.append(np.array(obs_tp1, copy=False)) 43 | dones.append(done) 44 | return np.array(obses_t), np.array(actions), np.array(rewards), np.array(obses_tp1), np.array(dones) 45 | 46 | def make_index(self, batch_size): 47 | return [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)] 48 | 49 | def make_latest_index(self, batch_size): 50 | idx = [(self._next_idx - 1 - i) % self._maxsize for i in range(batch_size)] 51 | np.random.shuffle(idx) 52 | return idx 53 | 54 | def sample_index(self, idxes): 55 | return self._encode_sample(idxes) 56 | 57 | def sample(self, batch_size): 58 | """Sample a batch of experiences. 59 | 60 | Parameters 61 | ---------- 62 | batch_size: int 63 | How many transitions to sample. 64 | 65 | Returns 66 | ------- 67 | obs_batch: np.array 68 | batch of observations 69 | act_batch: np.array 70 | batch of actions executed given obs_batch 71 | rew_batch: np.array 72 | rewards received as results of executing act_batch 73 | next_obs_batch: np.array 74 | next set of observations seen after executing act_batch 75 | done_mask: np.array 76 | done_mask[i] = 1 if executing act_batch[i] resulted in 77 | the end of an episode and 0 otherwise. 78 | """ 79 | if batch_size > 0: 80 | idxes = self.make_index(batch_size) 81 | else: 82 | idxes = range(0, len(self._storage)) 83 | return self._encode_sample(idxes) 84 | 85 | def collect(self): 86 | return self.sample(-1) 87 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name='maddpg', 4 | version='0.0.1', 5 | description='Multi-Agent Deep Deterministic Policy Gradient', 6 | url='https://github.com/openai/maddpg', 7 | author='Igor Mordatch', 8 | author_email='mordatch@openai.com', 9 | packages=find_packages(), 10 | include_package_data=True, 11 | zip_safe=False, 12 | install_requires=['gym', 'numpy-stl'] 13 | ) 14 | --------------------------------------------------------------------------------
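For reference, the update performed by `MADDPGAgentTrainer.update()` and `make_update_exp()` above can be restated schematically; the NumPy sketch below is an illustration of the same arithmetic (using the repository's default `--gamma` and Polyak coefficient), not code used by the library. Each agent's centralized critic is regressed toward the target y = r + gamma * (1 - done) * Q_target(o'_1..o'_N, a'_1..a'_N), where the next actions a' come from every agent's target policy; the actor is trained to maximize the centralized Q value, and both target networks track the learned networks by Polyak averaging.

import numpy as np

gamma = 0.95          # --gamma default in train.py
polyak = 1.0 - 1e-2   # coefficient used in make_update_exp()

def critic_target(rew, done, target_q_next):
    # TD target formed in MADDPGAgentTrainer.update() (with num_sample = 1)
    return rew + gamma * (1.0 - done) * target_q_next

def polyak_average(target_weights, weights):
    # soft target-network update applied by make_update_exp()
    return [polyak * wt + (1.0 - polyak) * w for wt, w in zip(target_weights, weights)]

print(critic_target(np.array([1.0, 0.5]), np.array([0.0, 1.0]), np.array([2.0, 2.0])))  # [2.9 0.5]
print(polyak_average([np.zeros(2)], [np.ones(2)]))                                      # [array([0.01, 0.01])]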