├── .dockerignore
├── Dockerfile
├── LICENSE
├── build.sh
├── code
│   ├── comm.lua
│   ├── game
│   │   ├── ColorDigit.lua
│   │   └── Switch.lua
│   ├── include
│   │   ├── kwargs.lua
│   │   ├── log.lua
│   │   └── util.lua
│   ├── model
│   │   ├── ColorDigit.lua
│   │   └── Switch.lua
│   ├── module
│   │   ├── Binarize.lua
│   │   ├── GRU.lua
│   │   ├── GaussianNoise.lua
│   │   ├── LSTM.lua
│   │   ├── LinearO.lua
│   │   ├── Print.lua
│   │   ├── rmsprop.lua
│   │   └── rmspropm.lua
│   ├── results
│   │   └── .gitkeep
│   ├── run_colordigit-dial.sh
│   ├── run_colordigit-rial.sh
│   ├── run_colordigit_many_steps-dial.sh
│   ├── run_colordigit_many_steps-rial.sh
│   ├── run_switch_3-dial.sh
│   ├── run_switch_3-rial.sh
│   ├── run_switch_4-dial.sh
│   └── run_switch_4-rial.sh
├── readme.md
└── run.sh

/.dockerignore:
--------------------------------------------------------------------------------
1 | code
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:8.0-cudnn5-devel-ubuntu16.04
2 | # FROM ubuntu:16.04
3 | MAINTAINER Yannis Assael, Jakob Foerster
4 | 
5 | # CUDA includes
6 | ENV CUDA_PATH /usr/local/cuda
7 | ENV CUDA_INCLUDE_PATH /usr/local/cuda/include
8 | ENV CUDA_LIBRARY_PATH /usr/local/cuda/lib64
9 | 
10 | # Ubuntu Packages
11 | RUN apt-get update -y && apt-get install software-properties-common -y && \
12 | add-apt-repository -y multiverse && apt-get update -y && apt-get upgrade -y && \
13 | apt-get install -y apt-utils nano vim man build-essential wget sudo && \
14 | rm -rf /var/lib/apt/lists/*
15 | 
16 | # Install curl and other dependencies
17 | RUN apt-get update -y && apt-get install -y curl libssl-dev openssl libopenblas-dev \
18 | libhdf5-dev hdf5-helpers hdf5-tools libhdf5-serial-dev libprotobuf-dev protobuf-compiler && \
19 | curl -sk https://raw.githubusercontent.com/torch/distro/master/install-deps | bash && \
20 | rm -rf /var/lib/apt/lists/*
21 | 
22 | # Clone torch (and package) repos:
23 | RUN mkdir -p /opt && git clone https://github.com/torch/distro.git /opt/torch --recursive
24 | 
25 | # Run installation script
26 | RUN cd /opt/torch && ./install.sh -b
27 | 
28 | # Export environment variables manually
29 | ENV TORCH_DIR /opt/torch/pkg/torch/build/cmake-exports/
30 | ENV LUA_PATH '/root/.luarocks/share/lua/5.1/?.lua;/root/.luarocks/share/lua/5.1/?/init.lua;/opt/torch/install/share/lua/5.1/?.lua;/opt/torch/install/share/lua/5.1/?/init.lua;./?.lua;/opt/torch/install/share/luajit-2.1.0-beta1/?.lua;/usr/local/share/lua/5.1/?.lua;/usr/local/share/lua/5.1/?/init.lua'
31 | ENV LUA_CPATH '/root/.luarocks/lib/lua/5.1/?.so;/opt/torch/install/lib/lua/5.1/?.so;./?.so;/usr/local/lib/lua/5.1/?.so;/usr/local/lib/lua/5.1/loadall.so'
32 | ENV PATH /opt/torch/install/bin:$PATH
33 | ENV LD_LIBRARY_PATH /opt/torch/install/lib:$LD_LIBRARY_PATH
34 | ENV DYLD_LIBRARY_PATH /opt/torch/install/lib:$DYLD_LIBRARY_PATH
35 | ENV LUA_CPATH '/opt/torch/install/lib/?.so;'$LUA_CPATH
36 | 
37 | # Install torch packages
38 | RUN luarocks install totem && \
39 | luarocks install https://raw.githubusercontent.com/deepmind/torch-hdf5/master/hdf5-0-0.rockspec && \
40 | luarocks install unsup && \
41 | luarocks install csvigo && \
42 | luarocks install loadcaffe && \
43 | luarocks install classic && \
44 | luarocks install pprint && \
45 | luarocks install class && \
46 | luarocks install image && \
47 | luarocks install mnist && \
48 | luarocks install https://raw.githubusercontent.com/deepmind/torch-distributions/master/distributions-0-0.rockspec
49 | 
50 | # Cleanup
51 | RUN rm -rf /var/lib/apt/lists/*
/tmp/* /var/tmp/* 52 | 53 | WORKDIR /project 54 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /build.sh: -------------------------------------------------------------------------------- 1 | nvidia-docker build -t $USER/comm . 2 | -------------------------------------------------------------------------------- /code/comm.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | 3 | Learning to Communicate with Deep Multi-Agent Reinforcement Learning 4 | 5 | @article{foerster2016learning, 6 | title={Learning to Communicate with Deep Multi-Agent Reinforcement Learning}, 7 | author={Foerster, Jakob N and Assael, Yannis M and de Freitas, Nando and Whiteson, Shimon}, 8 | journal={arXiv preprint arXiv:1605.06676}, 9 | year={2016} 10 | } 11 | 12 | ]] -- 13 | 14 | 15 | -- Configuration 16 | cmd = torch.CmdLine() 17 | cmd:text() 18 | cmd:text('Learning to Communicate with Deep Multi-Agent Reinforcement Learning') 19 | cmd:text() 20 | cmd:text('Options') 21 | 22 | -- general options: 23 | cmd:option('-seed', -1, 'initial random seed') 24 | cmd:option('-threads', 1, 'number of threads') 25 | 26 | -- gpu 27 | cmd:option('-cuda', 0, 'cuda') 28 | 29 | -- rl 30 | cmd:option('-gamma', 1, 'discount factor') 31 | cmd:option('-eps', 0.05, 'epsilon-greedy policy') 32 | 33 | -- model 34 | cmd:option('-model_rnn', 'gru', 'rnn type') 35 | cmd:option('-model_dial', 0, 'use dial connection or rial') 36 | cmd:option('-model_comm_narrow', 1, 'combines comm bits') 37 | cmd:option('-model_know_share', 1, 'knowledge sharing') 38 | cmd:option('-model_action_aware', 1, 'last action used as input') 39 | cmd:option('-model_rnn_size', 128, 'rnn size') 40 | cmd:option('-model_rnn_layers', 2, 'rnn layers') 41 | cmd:option('-model_dropout', 0, 'dropout') 42 | cmd:option('-model_bn', 1, 'batch normalisation') 43 | cmd:option('-model_target', 1, 'use a target network') 44 | cmd:option('-model_avg_q', 1, 'avearge q functions') 45 | 46 | -- training 47 | cmd:option('-bs', 32, 'batch size') 48 | cmd:option('-learningrate', 5e-4, 'learningrate') 49 | cmd:option('-nepisodes', 1e+6, 'number of episodes') 50 | cmd:option('-nsteps', 10, 'number of steps') 51 | 52 | cmd:option('-step', 1000, 'print every episodes') 53 | cmd:option('-step_test', 10, 'print every episodes') 54 | cmd:option('-step_target', 100, 'target network updates') 55 | 56 | 
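--[[
Illustrative usage (editor's sketch, not the exact contents of the run_*.sh scripts,
which are not reproduced in this dump): with the Torch binaries from the Dockerfile on
the PATH, a DIAL run on the 3-agent Switch riddle can be launched by overriding the
defaults declared in this file, e.g.

    th comm.lua -game Switch -game_nagents 3 -model_dial 1 -cuda 0

and the RIAL baseline by passing -model_dial 0 instead; every flag corresponds to a
cmd:option declaration above or below.
--]]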
cmd:option('-filename', '', '') 57 | 58 | -- games 59 | -- ColorDigit 60 | cmd:option('-game', 'ColorDigit', 'game name') 61 | cmd:option('-game_dim', 28, '') 62 | cmd:option('-game_bias', 0, '') 63 | cmd:option('-game_colors', 2, '') 64 | cmd:option('-game_use_mnist', 1, '') 65 | cmd:option('-game_use_digits', 0, '') 66 | cmd:option('-game_nagents', 2, '') 67 | cmd:option('-game_action_space', 2, '') 68 | cmd:option('-game_comm_limited', 0, '') 69 | cmd:option('-game_comm_bits', 1, '') 70 | cmd:option('-game_comm_sigma', 0, '') 71 | cmd:option('-game_coop', 1, '') 72 | cmd:option('-game_bottleneck', 10, '') 73 | cmd:option('-game_level', 'extra_hard', '') 74 | cmd:option('-game_vision_net', 'mlp', 'mlp or cnn') 75 | cmd:option('-nsteps', 2, 'number of steps') 76 | -- Switch 77 | cmd:option('-game', 'Switch', 'game name') 78 | cmd:option('-game_nagents', 3, '') 79 | cmd:option('-game_action_space', 2, '') 80 | cmd:option('-game_comm_limited', 1, '') 81 | cmd:option('-game_comm_bits', 2, '') 82 | cmd:option('-game_comm_sigma', 0, '') 83 | cmd:option('-nsteps', 6, 'number of steps') 84 | 85 | cmd:text() 86 | 87 | local opt = cmd:parse(arg) 88 | 89 | -- Custom options 90 | if opt.seed == -1 then opt.seed = torch.random(1000000) end 91 | opt.model_comm_narrow = opt.model_dial 92 | 93 | if opt.model_rnn == 'lstm' then 94 | opt.model_rnn_states = 2 * opt.model_rnn_layers 95 | elseif opt.model_rnn == 'gru' then 96 | opt.model_rnn_states = opt.model_rnn_layers 97 | end 98 | 99 | -- Requirements 100 | require 'nn' 101 | require 'optim' 102 | local kwargs = require 'include.kwargs' 103 | local log = require 'include.log' 104 | local util = require 'include.util' 105 | 106 | -- Set float as default type 107 | torch.manualSeed(opt.seed) 108 | torch.setnumthreads(opt.threads) 109 | torch.setdefaulttensortype('torch.FloatTensor') 110 | 111 | -- Cuda initialisation 112 | if opt.cuda == 1 then 113 | require 'cutorch' 114 | require 'cunn' 115 | cutorch.setDevice(1) 116 | opt.dtype = 'torch.CudaTensor' 117 | print(cutorch.getDeviceProperties(1)) 118 | else 119 | opt.dtype = 'torch.FloatTensor' 120 | end 121 | 122 | if opt.model_comm_narrow == 0 and opt.game_comm_bits > 0 then 123 | opt.game_comm_bits = 2 ^ opt.game_comm_bits 124 | end 125 | 126 | -- Initialise game 127 | local game = (require('game.' .. opt.game))(opt) 128 | 129 | if opt.game_comm_bits > 0 and opt.game_nagents > 1 then 130 | -- Without dial we add the communication actions to the action space 131 | opt.game_action_space_total = opt.game_action_space + opt.game_comm_bits 132 | else 133 | opt.game_action_space_total = opt.game_action_space 134 | end 135 | 136 | -- Initialise models 137 | local model = (require('model.' .. 
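--[[
Worked example of the bookkeeping above (editor's sketch, using the Switch defaults
declared in this file: game_action_space = 2, game_comm_bits = 2):

  * DIAL (model_dial = 1, hence model_comm_narrow = 1): game_comm_bits stays 2, so
    game_action_space_total = 2 + 2 = 4. The last two outputs are real-valued message
    channels; DRU() below splits them off the Q-values, adds N(0, game_comm_sigma)
    noise during training (when game_comm_sigma > 0) and discretises them at test time.
  * RIAL (model_dial = 0, hence model_comm_narrow = 0): game_comm_bits is widened to
    2^2 = 4 discrete "speak" actions, so game_action_space_total = 2 + 4 = 6 and the
    communication action is chosen epsilon-greedily like any other action.
--]]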
opt.game))(opt) 138 | 139 | -- Print options 140 | util.sprint(opt) 141 | 142 | -- Model target evaluate 143 | model.evaluate(model.agent_target) 144 | 145 | -- Get parameters 146 | local params, gradParams, params_target, _ = model.getParameters() 147 | 148 | -- Optimisation function 149 | local optim_func, optim_config = model.optim() 150 | local optim_state = {} 151 | 152 | -- Initialise agents 153 | local agent = {} 154 | for i = 1, opt.game_nagents do 155 | agent[i] = {} 156 | 157 | agent[i].id = torch.Tensor():type(opt.dtype):resize(opt.bs):fill(i) 158 | 159 | -- Populate init state 160 | agent[i].input = {} 161 | agent[i].input_target = {} 162 | agent[i].state = {} 163 | agent[i].state_target = {} 164 | agent[i].state[0] = {} 165 | agent[i].state_target[0] = {} 166 | for j = 1, opt.model_rnn_states do 167 | agent[i].state[0][j] = torch.zeros(opt.bs, opt.model_rnn_size):type(opt.dtype) 168 | agent[i].state_target[0][j] = torch.zeros(opt.bs, opt.model_rnn_size):type(opt.dtype) 169 | end 170 | 171 | agent[i].d_state = {} 172 | agent[i].d_state[0] = {} 173 | for j = 1, opt.model_rnn_states do 174 | agent[i].d_state[0][j] = torch.zeros(opt.bs, opt.model_rnn_size):type(opt.dtype) 175 | end 176 | 177 | -- Store q values 178 | agent[i].q_next_max = {} 179 | agent[i].q_comm_next_max = {} 180 | end 181 | 182 | local episode = {} 183 | 184 | -- Initialise aux vectors 185 | local d_err = torch.Tensor(opt.bs, opt.game_action_space_total):type(opt.dtype) 186 | local td_err = torch.Tensor(opt.bs):type(opt.dtype) 187 | local td_comm_err = torch.Tensor(opt.bs):type(opt.dtype) 188 | local stats = { 189 | r_episode = torch.zeros(opt.nsteps), 190 | td_err = torch.zeros(opt.step), 191 | td_comm = torch.zeros(opt.step), 192 | train_r = torch.zeros(opt.step, opt.game_nagents), 193 | steps = torch.zeros(opt.step / opt.step_test), 194 | test_r = torch.zeros(opt.step / opt.step_test, opt.game_nagents), 195 | comm_per = torch.zeros(opt.step / opt.step_test), 196 | te = torch.zeros(opt.step) 197 | } 198 | 199 | local replay = {} 200 | 201 | -- Run episode 202 | local function run_episode(opt, game, model, agent, test_mode) 203 | 204 | -- Test mode 205 | test_mode = test_mode or false 206 | 207 | -- Reset game 208 | game:reset() 209 | 210 | -- Initialise episode 211 | local step = 1 212 | local episode = { 213 | comm_per = torch.zeros(opt.bs), 214 | r = torch.zeros(opt.bs, opt.game_nagents), 215 | steps = torch.zeros(opt.bs), 216 | ended = torch.zeros(opt.bs), 217 | comm_count = 0, 218 | non_comm_count = 0 219 | } 220 | episode[step] = { 221 | s_t = game:getState(), 222 | terminal = torch.zeros(opt.bs) 223 | } 224 | if opt.game_comm_bits > 0 and opt.game_nagents > 1 then 225 | episode[step].comm = torch.zeros(opt.bs, opt.game_nagents, opt.game_comm_bits):type(opt.dtype) 226 | if opt.model_dial == 1 and opt.model_target == 1 then 227 | episode[step].comm_target = episode[step].comm:clone() 228 | end 229 | episode[step].d_comm = torch.zeros(opt.bs, opt.game_nagents, opt.game_comm_bits):type(opt.dtype) 230 | end 231 | 232 | 233 | -- Run for N steps 234 | local steps = test_mode and opt.nsteps or opt.nsteps + 1 235 | while step <= steps and episode.ended:sum() < opt.bs do 236 | 237 | -- Initialise next step 238 | episode[step + 1] = {} 239 | 240 | -- Initialise comm channel 241 | if opt.game_comm_bits > 0 and opt.game_nagents > 1 then 242 | episode[step + 1].comm = torch.zeros(opt.bs, opt.game_nagents, opt.game_comm_bits):type(opt.dtype) 243 | episode[step + 1].d_comm = torch.zeros(opt.bs, opt.game_nagents, 
opt.game_comm_bits):type(opt.dtype) 244 | if opt.model_dial == 1 and opt.model_target == 1 then 245 | episode[step + 1].comm_target = torch.zeros(opt.bs, opt.game_nagents, opt.game_comm_bits):type(opt.dtype) 246 | end 247 | end 248 | 249 | -- Forward pass 250 | episode[step].a_t = torch.zeros(opt.bs, opt.game_nagents):type(opt.dtype) 251 | if opt.model_dial == 0 then 252 | episode[step].a_comm_t = torch.zeros(opt.bs, opt.game_nagents):type(opt.dtype) 253 | end 254 | 255 | -- Iterate agents 256 | for i = 1, opt.game_nagents do 257 | agent[i].input[step] = { 258 | episode[step].s_t[i]:type(opt.dtype), 259 | agent[i].id, 260 | agent[i].state[step - 1] 261 | } 262 | 263 | -- Communication enabled 264 | if opt.game_comm_bits > 0 and opt.game_nagents > 1 then 265 | local comm_limited = game:getCommLimited(step, i) 266 | local comm = episode[step].comm:clone():type(opt.dtype) 267 | if comm_limited then 268 | -- Create limited communication channel nbits 269 | local comm_lim = torch.zeros(opt.bs, 1, opt.game_comm_bits):type(opt.dtype) 270 | for b = 1, opt.bs do 271 | if comm_limited[b] == 0 then 272 | comm_lim[{ { b } }]:zero() 273 | else 274 | comm_lim[{ { b } }] = comm[{ { b }, unpack(comm_limited[b]) }] 275 | end 276 | end 277 | table.insert(agent[i].input[step], comm_lim) 278 | else 279 | -- zero out own communication if not action aware 280 | comm[{ {}, { i } }]:zero() 281 | table.insert(agent[i].input[step], comm) 282 | end 283 | end 284 | 285 | -- Last action enabled 286 | if opt.model_action_aware == 1 then 287 | -- If comm always then use both action 288 | if opt.model_dial == 0 then 289 | local la = { torch.ones(opt.bs):type(opt.dtype), torch.ones(opt.bs):type(opt.dtype) } 290 | if step > 1 then 291 | for b = 1, opt.bs do 292 | -- Last action 293 | if episode[step - 1].a_t[b][i] > 0 then 294 | la[1][{ { b } }] = episode[step - 1].a_t[b][i] + 1 295 | end 296 | -- Last comm action 297 | if episode[step - 1].a_comm_t[b][i] > 0 then 298 | la[2][{ { b } }] = episode[step - 1].a_comm_t[b][i] - opt.game_action_space + 1 299 | end 300 | end 301 | end 302 | table.insert(agent[i].input[step], la) 303 | else 304 | -- Action aware for single a, comm action 305 | local la = torch.ones(opt.bs):type(opt.dtype) 306 | if step > 1 then 307 | for b = 1, opt.bs do 308 | if episode[step - 1].a_t[b][i] > 0 then 309 | la[{ { b } }] = episode[step - 1].a_t[b][i] + 1 310 | end 311 | end 312 | end 313 | table.insert(agent[i].input[step], la) 314 | end 315 | end 316 | 317 | -- Compute Q values 318 | local comm, state, q_t 319 | agent[i].state[step], q_t = unpack(model.agent[model.id(step, i)]:forward(agent[i].input[step])) 320 | 321 | 322 | -- If dial split out the comm values from q values 323 | if opt.model_dial == 1 then 324 | q_t, comm = DRU(q_t, test_mode) 325 | end 326 | 327 | -- Pick an action (epsilon-greedy) 328 | local action_range, action_range_comm 329 | local max_value, max_a, max_a_comm 330 | if opt.model_dial == 0 then 331 | action_range, action_range_comm = game:getActionRange(step, i) 332 | else 333 | action_range = game:getActionRange(step, i) 334 | end 335 | 336 | -- If Limited action range 337 | if action_range then 338 | agent[i].range = agent[i].range or torch.range(1, opt.game_action_space_total) 339 | max_value = torch.Tensor(opt.bs, 1) 340 | max_a = torch.zeros(opt.bs, 1) 341 | if opt.model_dial == 0 then 342 | max_a_comm = torch.zeros(opt.bs, 1) 343 | end 344 | for b = 1, opt.bs do 345 | -- If comm always fetch range for comm and actions 346 | if opt.model_dial == 0 then 347 | -- If 
action was taken 348 | if action_range[b][2][1] > 0 then 349 | local v, a = torch.max(q_t[action_range[b]], 2) 350 | max_value[b] = v:squeeze() 351 | max_a[b] = agent[i].range[{ action_range[b][2] }][a:squeeze()] 352 | end 353 | -- If comm action was taken 354 | if action_range_comm[b][2][1] > 0 then 355 | local v, a = torch.max(q_t[action_range_comm[b]], 2) 356 | max_a_comm[b] = agent[i].range[{ action_range_comm[b][2] }][a:squeeze()] 357 | end 358 | else 359 | local v, a = torch.max(q_t[action_range[b]], 2) 360 | max_a[b] = agent[i].range[{ action_range[b][2] }][a:squeeze()] 361 | end 362 | end 363 | else 364 | -- If comm always pick max_a and max_comm 365 | if opt.model_dial == 0 and opt.game_comm_bits > 0 then 366 | _, max_a = torch.max(q_t[{ {}, { 1, opt.game_action_space } }], 2) 367 | _, max_a_comm = torch.max(q_t[{ {}, { opt.game_action_space + 1, opt.game_action_space_total } }], 2) 368 | max_a_comm = max_a_comm + opt.game_action_space 369 | else 370 | _, max_a = torch.max(q_t, 2) 371 | end 372 | end 373 | 374 | -- Store actions 375 | episode[step].a_t[{ {}, { i } }] = max_a:type(opt.dtype) 376 | if opt.model_dial == 0 and opt.game_comm_bits > 0 then 377 | episode[step].a_comm_t[{ {}, { i } }] = max_a_comm:type(opt.dtype) 378 | end 379 | 380 | for b = 1, opt.bs do 381 | 382 | -- Epsilon-greedy action picking 383 | if not test_mode then 384 | if opt.model_dial == 0 then 385 | -- Random action 386 | if torch.uniform() < opt.eps then 387 | if action_range then 388 | if action_range[b][2][1] > 0 then 389 | local a_range = agent[i].range[{ action_range[b][2] }] 390 | local a_idx = torch.random(a_range:nElement()) 391 | episode[step].a_t[b][i] = agent[i].range[{ action_range[b][2] }][a_idx] 392 | end 393 | else 394 | episode[step].a_t[b][i] = torch.random(opt.game_action_space) 395 | end 396 | end 397 | 398 | -- Random communication 399 | if opt.game_comm_bits > 0 and torch.uniform() < opt.eps then 400 | if action_range then 401 | if action_range_comm[b][2][1] > 0 then 402 | local a_range = agent[i].range[{ action_range_comm[b][2] }] 403 | local a_idx = torch.random(a_range:nElement()) 404 | episode[step].a_comm_t[b][i] = agent[i].range[{ action_range_comm[b][2] }][a_idx] 405 | end 406 | else 407 | episode[step].a_comm_t[b][i] = torch.random(opt.game_action_space + 1, opt.game_action_space_total) 408 | end 409 | end 410 | 411 | else 412 | if torch.uniform() < opt.eps then 413 | if action_range then 414 | local a_range = agent[i].range[{ action_range[b][2] }] 415 | local a_idx = torch.random(a_range:nElement()) 416 | episode[step].a_t[b][i] = agent[i].range[{ action_range[b][2] }][a_idx] 417 | else 418 | episode[step].a_t[b][i] = torch.random(q_t[b]:size(1)) 419 | end 420 | end 421 | end 422 | end 423 | 424 | -- If communication action populate channel 425 | if step <= opt.nsteps then 426 | -- For dial we 'forward' the direct activation otherwise we shift the a_t into the 1-game_comm_bits range 427 | if opt.model_dial == 1 then 428 | episode[step + 1].comm[b][i] = comm[b] 429 | else 430 | local a_t = episode[step].a_comm_t[b][i] - opt.game_action_space 431 | if a_t > 0 then 432 | episode[step + 1].comm[b][{ { i }, { a_t } }] = 1 433 | end 434 | end 435 | 436 | if episode.ended[b] == 0 then 437 | episode.comm_per[{ { b } }]:add(1 / opt.game_nagents) 438 | end 439 | episode.comm_count = episode.comm_count + 1 440 | 441 | else 442 | episode.non_comm_count = episode.non_comm_count + 1 443 | end 444 | end 445 | end 446 | 447 | -- Compute reward for current state-action pair 448 | 
episode[step].r_t, episode[step].terminal = game:step(episode[step].a_t) 449 | 450 | 451 | -- Accumulate steps (not for +1 step) 452 | if step <= opt.nsteps then 453 | for b = 1, opt.bs do 454 | if episode.ended[b] == 0 then 455 | 456 | -- Keep steps and rewards 457 | episode.steps[{ { b } }]:add(1) 458 | episode.r[{ { b } }]:add(episode[step].r_t[b]) 459 | 460 | -- Check if terminal 461 | if episode[step].terminal[b] == 1 then 462 | episode.ended[{ { b } }] = 1 463 | end 464 | end 465 | end 466 | end 467 | 468 | 469 | -- Target Network, for look-ahead 470 | if opt.model_target == 1 and not test_mode then 471 | for i = 1, opt.game_nagents do 472 | local comm = agent[i].input[step][4] 473 | 474 | if opt.game_comm_bits > 0 and opt.game_nagents > 1 and opt.model_dial == 1 then 475 | local comm_limited = game:getCommLimited(step, i) 476 | comm = episode[step].comm_target:clone():type(opt.dtype) 477 | 478 | -- Create limited communication channel nbits 479 | if comm_limited then 480 | local comm_lim = torch.zeros(opt.bs, 1, opt.game_comm_bits):type(opt.dtype) 481 | for b = 1, opt.bs do 482 | if comm_limited[b] == 0 then 483 | comm_lim[{ { b } }] = 0 484 | else 485 | comm_lim[{ { b } }] = comm[{ { b }, unpack(comm_limited[b]) }] 486 | end 487 | end 488 | comm = comm_lim 489 | else 490 | comm[{ {}, { i } }] = 0 491 | end 492 | end 493 | 494 | -- Target input 495 | agent[i].input_target[step] = { 496 | agent[i].input[step][1], 497 | agent[i].input[step][2], 498 | agent[i].state_target[step - 1], 499 | comm, 500 | agent[i].input[step][5], 501 | } 502 | 503 | -- Forward target 504 | local state, q_t_target = unpack(model.agent_target[model.id(step, i)]:forward(agent[i].input_target[step])) 505 | agent[i].state_target[step] = state 506 | if opt.model_dial == 1 then 507 | q_t_target, comm = DRU(q_t_target, test_mode) 508 | end 509 | 510 | -- Limit actions 511 | if opt.model_dial == 0 and opt.game_comm_bits > 0 then 512 | local action_range, action_range_comm = game:getActionRange(step, i) 513 | if action_range then 514 | agent[i].q_next_max[step] = torch.zeros(opt.bs):type(opt.dtype) 515 | agent[i].q_comm_next_max[step] = torch.zeros(opt.bs):type(opt.dtype) 516 | for b = 1, opt.bs do 517 | if action_range[b][2][1] > 0 then 518 | agent[i].q_next_max[step][b], _ = torch.max(q_t_target[action_range[b]], 2) 519 | else 520 | error('Not implemented') 521 | end 522 | 523 | -- If comm not available pick from None 524 | if action_range_comm[b][2][1] > 0 then 525 | agent[i].q_comm_next_max[step][b], _ = torch.max(q_t_target[action_range_comm[b]], 2) 526 | else 527 | agent[i].q_comm_next_max[step][b], _ = torch.max(q_t_target[action_range[b]], 2) 528 | end 529 | end 530 | else 531 | agent[i].q_next_max[step], _ = torch.max(q_t_target[{ {}, { 1, opt.game_action_space } }], 2) 532 | agent[i].q_comm_next_max[step], _ = torch.max(q_t_target[{ {}, { opt.game_action_space + 1, opt.game_action_space_total } }], 2) 533 | end 534 | else 535 | local action_range = game:getActionRange(step, i) 536 | if action_range then 537 | agent[i].q_next_max[step] = torch.zeros(opt.bs):type(opt.dtype) 538 | for b = 1, opt.bs do 539 | if action_range[b][2][1] > 0 then 540 | agent[i].q_next_max[step][b], _ = torch.max(q_t_target[action_range[b]], 2) 541 | end 542 | end 543 | else 544 | agent[i].q_next_max[step], _ = torch.max(q_t_target, 2) 545 | end 546 | end 547 | 548 | if opt.model_dial == 1 then 549 | for b = 1, opt.bs do 550 | episode[step + 1].comm_target[b][i] = comm[b] 551 | end 552 | end 553 | end 554 | end 555 | 556 | -- 
Forward next step 557 | step = step + 1 558 | if episode.ended:sum() < opt.bs then 559 | episode[step].s_t = game:getState() 560 | end 561 | end 562 | 563 | -- Update stats 564 | episode.nsteps = episode.steps:max() 565 | episode.comm_per:cdiv(episode.steps) 566 | 567 | return episode, agent 568 | end 569 | 570 | 571 | -- split out the communication bits and add noise. 572 | function DRU(q_t, test_mode) 573 | if opt.model_dial == 0 then error('Warning!! Should only be used in DIAL') end 574 | local bound = opt.game_action_space 575 | 576 | local q_t_n = q_t[{ {}, { 1, bound } }]:clone() 577 | local comm = q_t[{ {}, { bound + 1, opt.game_action_space_total } }]:clone() 578 | if test_mode then 579 | if opt.model_comm_narrow == 0 then 580 | local ind 581 | _, ind = torch.max(comm, 2) 582 | comm:zero() 583 | for b = 1, opt.bs do 584 | comm[b][ind[b][1]] = 20 585 | end 586 | else 587 | comm = comm:gt(0.5):type(opt.dtype):add(-0.5):mul(2 * 20) 588 | end 589 | end 590 | if opt.game_comm_sigma > 0 and not test_mode then 591 | local noise_vect = torch.randn(comm:size()):type(opt.dtype):mul(opt.game_comm_sigma) 592 | comm = comm + noise_vect 593 | end 594 | return q_t_n, comm 595 | end 596 | 597 | -- Start time 598 | local beginning_time = torch.tic() 599 | 600 | -- Iterate episodes 601 | for e = 1, opt.nepisodes do 602 | 603 | stats.e = e 604 | 605 | -- Initialise clock 606 | local time = sys.clock() 607 | 608 | -- Model training 609 | model.training(model.agent) 610 | 611 | -- Run episode 612 | episode, agent = run_episode(opt, game, model, agent) 613 | 614 | -- Rewards stats 615 | stats.train_r[(e - 1) % opt.step + 1] = episode.r:mean(1) 616 | 617 | -- Reset parameters 618 | if e == 1 then 619 | gradParams:zero() 620 | end 621 | 622 | -- Backwawrd pass 623 | local step_back = 1 624 | for step = episode.nsteps, 1, -1 do 625 | stats.td_err[(e - 1) % opt.step + 1] = 0 626 | stats.td_comm[(e - 1) % opt.step + 1] = 0 627 | 628 | -- Iterate agents 629 | for i = 1, opt.game_nagents do 630 | 631 | -- Compute Q values 632 | local state, q_t = unpack(model.agent[model.id(step, i)].output) 633 | 634 | -- Compute td error 635 | td_err:zero() 636 | td_comm_err:zero() 637 | d_err:zero() 638 | 639 | for b = 1, opt.bs do 640 | if step >= episode.steps[b] then 641 | -- if first backward init RNN 642 | for j = 1, opt.model_rnn_states do 643 | agent[i].d_state[step_back - 1][j][b]:zero() 644 | end 645 | end 646 | 647 | if step <= episode.steps[b] then 648 | 649 | -- if terminal state or end state => no future rewards 650 | if episode[step].a_t[b][i] > 0 then 651 | if episode[step].terminal[b] == 1 then 652 | td_err[b] = episode[step].r_t[b][i] - q_t[b][episode[step].a_t[b][i]] 653 | else 654 | local q_next_max 655 | if opt.model_avg_q == 1 and opt.model_dial == 0 and episode[step].a_comm_t[b][i] > 0 then 656 | q_next_max = (agent[i].q_next_max[step + 1]:squeeze() + agent[i].q_comm_next_max[step + 1]:squeeze()) / 2 657 | else 658 | q_next_max = agent[i].q_next_max[step + 1]:squeeze() 659 | end 660 | td_err[b] = episode[step].r_t[b][i] + opt.gamma * q_next_max[b] - q_t[b][episode[step].a_t[b][i]] 661 | end 662 | d_err[{ { b }, { episode[step].a_t[b][i] } }] = -td_err[b] 663 | 664 | else 665 | error('Error!') 666 | end 667 | 668 | -- Delta Q for communication 669 | if opt.model_dial == 0 then 670 | if episode[step].a_comm_t[b][i] > 0 then 671 | if episode[step].terminal[b] == 1 then 672 | td_comm_err[b] = episode[step].r_t[b][i] - q_t[b][episode[step].a_comm_t[b][i]] 673 | else 674 | local q_next_max 675 | if 
opt.model_avg_q == 1 and episode[step].a_t[b][i] > 0 then 676 | q_next_max = (agent[i].q_next_max[step + 1]:squeeze() + agent[i].q_comm_next_max[step + 1]:squeeze()) / 2 677 | else 678 | q_next_max = agent[i].q_comm_next_max[step + 1]:squeeze() 679 | end 680 | td_comm_err[b] = episode[step].r_t[b][i] + opt.gamma * q_next_max[b] - q_t[b][episode[step].a_comm_t[b][i]] 681 | end 682 | d_err[{ { b }, { episode[step].a_comm_t[b][i] } }] = -td_comm_err[b] 683 | end 684 | end 685 | 686 | -- If we use dial and the agent took the umbrella comm action and the messsage happened before last round, the we get incoming derivaties 687 | if opt.model_dial == 1 and step < episode.steps[b] then 688 | -- Derivatives with respect to agent_i's message are stored in d_comm[b][i] 689 | local bound = opt.game_action_space 690 | d_err[{ { b }, { bound + 1, opt.game_action_space_total } }]:add(episode[step + 1].d_comm[b][i]) 691 | end 692 | end 693 | end 694 | 695 | -- Track td-err 696 | stats.td_err[(e - 1) % opt.step + 1] = stats.td_err[(e - 1) % opt.step + 1] + 0.5 * td_err:clone():pow(2):mean() 697 | if opt.model_dial == 0 then 698 | stats.td_comm[(e - 1) % opt.step + 1] = stats.td_comm[(e - 1) % opt.step + 1] + 0.5 * td_comm_err:clone():pow(2):mean() 699 | end 700 | 701 | -- Track the amplitude of the dial-derivatives 702 | if opt.model_dial == 1 then 703 | local bound = opt.game_action_space 704 | stats.td_comm[(e - 1) % opt.step + 1] = stats.td_comm[(e - 1) % opt.step + 1] + 0.5 * d_err[{ {}, { bound + 1, opt.game_action_space_total } }]:clone():pow(2):mean() 705 | end 706 | 707 | -- Backward pass 708 | local grad = model.agent[model.id(step, i)]:backward(agent[i].input[step], { 709 | agent[i].d_state[step_back - 1], 710 | d_err 711 | }) 712 | 713 | --'state' is the 3rd input, so we can extract d_state 714 | agent[i].d_state[step_back] = grad[3] 715 | 716 | --For dial we need to write add the derivatives w/ respect to the incoming messages to the d_comm tracker 717 | if opt.model_dial == 1 then 718 | local comm_limited = game:getCommLimited(step, i) 719 | local comm_grad = grad[4] 720 | 721 | if comm_limited then 722 | for b = 1, opt.bs do 723 | -- Agent could only receive the message if they were active 724 | if comm_limited[b] ~= 0 then 725 | episode[step].d_comm[{ { b }, unpack(comm_limited[b]) }]:add(comm_grad[b]) 726 | end 727 | end 728 | else 729 | -- zero out own communication unless it's part of the switch riddle 730 | comm_grad[{ {}, { i } }]:zero() 731 | episode[step].d_comm:add(comm_grad) 732 | end 733 | end 734 | end 735 | 736 | -- Count backward steps 737 | step_back = step_back + 1 738 | end 739 | 740 | -- Update gradients 741 | local feval = function(x) 742 | 743 | -- Normalise Gradients 744 | gradParams:div(opt.game_nagents * opt.bs) 745 | 746 | -- Clip Gradients 747 | gradParams:clamp(-10, 10) 748 | 749 | return nil, gradParams 750 | end 751 | 752 | optim_func(feval, params, optim_config, optim_state) 753 | 754 | -- Gradient statistics 755 | if e % opt.step == 0 then 756 | stats.grad_norm = gradParams:norm() / gradParams:nElement() * 1000 757 | end 758 | 759 | -- Reset parameters 760 | gradParams:zero() 761 | 762 | -- Update target network 763 | if e % opt.step_target == 0 then 764 | params_target:copy(params) 765 | end 766 | 767 | -- Test 768 | if e % opt.step_test == 0 then 769 | local test_idx = (e / opt.step_test - 1) % (opt.step / opt.step_test) + 1 770 | 771 | local episode, _ = run_episode(opt, game, model, agent, true) 772 | stats.test_r[test_idx] = episode.r:mean(1) 773 | 
stats.steps[test_idx] = episode.steps:mean() 774 | stats.comm_per[test_idx] = episode.comm_count / (episode.comm_count + episode.non_comm_count) 775 | end 776 | 777 | -- Compute statistics 778 | stats.te[(e - 1) % opt.step + 1] = sys.clock() - time 779 | 780 | if e == opt.step then 781 | stats.td_err_avg = stats.td_err:mean() 782 | stats.td_comm_avg = stats.td_comm:mean() 783 | stats.train_r_avg = stats.train_r:mean(1) 784 | stats.test_r_avg = stats.test_r:mean(1) 785 | stats.steps_avg = stats.steps:mean() 786 | stats.comm_per_avg = stats.comm_per:mean() 787 | stats.te_avg = stats.te:mean() 788 | elseif e % opt.step == 0 then 789 | local coef = 0.9 790 | stats.td_err_avg = stats.td_err_avg * coef + stats.td_err:mean() * (1 - coef) 791 | stats.td_comm_avg = stats.td_comm_avg * coef + stats.td_comm:mean() * (1 - coef) 792 | stats.train_r_avg = stats.train_r_avg * coef + stats.train_r:mean(1) * (1 - coef) 793 | stats.test_r_avg = stats.test_r_avg * coef + stats.test_r:mean(1) * (1 - coef) 794 | stats.steps_avg = stats.steps_avg * coef + stats.steps:mean() * (1 - coef) 795 | stats.comm_per_avg = stats.comm_per_avg * coef + stats.comm_per:mean() * (1 - coef) 796 | stats.te_avg = stats.te_avg * coef + stats.te:mean() * (1 - coef) 797 | end 798 | 799 | -- Print statistics 800 | if e % opt.step == 0 then 801 | log.infof('e=%d, td_err=%.3f, td_err_avg=%.3f, td_comm=%.3f, td_comm_avg=%.3f, tr_r=%.2f, tr_r_avg=%.2f, te_r=%.2f, te_r_avg=%.2f, st=%.1f, comm=%.1f%%, grad=%.3f, t/s=%.2f s, t=%d m', 802 | stats.e, 803 | stats.td_err:mean(), 804 | stats.td_err_avg, 805 | stats.td_comm:mean(), 806 | stats.td_comm_avg, 807 | stats.train_r:mean(), 808 | stats.train_r_avg:mean(), 809 | stats.test_r:mean(), 810 | stats.test_r_avg:mean(), 811 | stats.steps_avg, 812 | stats.comm_per_avg * 100, 813 | stats.grad_norm, 814 | stats.te_avg * opt.step, 815 | torch.toc(beginning_time) / 60) 816 | 817 | collectgarbage() 818 | end 819 | 820 | -- run model specific statistics 821 | model.stats(opt, game, stats, e) 822 | 823 | -- run model specific statistics 824 | model.save(opt, stats, model) 825 | end 826 | -------------------------------------------------------------------------------- /code/game/ColorDigit.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | local class = require 'class' 3 | 4 | local log = require 'include.log' 5 | local kwargs = require 'include.kwargs' 6 | local util = require 'include.util' 7 | 8 | 9 | local ColorDigit = class('ColorDigit') 10 | 11 | function ColorDigit:__init(opt) 12 | self.opt = opt 13 | 14 | self.step_counter = 1 15 | -- Preprocess data 16 | local dataset = (require 'mnist').traindataset() 17 | local data = {} 18 | local lookup = {} 19 | for i = 1, dataset.size do 20 | -- Shift 0 class 21 | local y = dataset[i].y + 1 22 | -- Create array 23 | if not data[y] then 24 | data[y] = {} 25 | end 26 | -- Move data 27 | data[y][#data[y] + 1] = { 28 | x = dataset[i].x, 29 | y = y 30 | } 31 | lookup[i] = y 32 | end 33 | 34 | self.mnist = data 35 | 36 | self.mnist.lookup = lookup 37 | 38 | -- Rewards 39 | self.reward = torch.zeros(self.opt.bs) 40 | self.terminal = torch.zeros(self.opt.bs) 41 | 42 | -- Spawn new game 43 | self:reset() 44 | end 45 | 46 | function ColorDigit:loadDigit() 47 | -- Pick random digit and color 48 | local color_id = torch.zeros(self.opt.bs) 49 | local number = torch.zeros(self.opt.bs) 50 | local x = torch.zeros(self.opt.bs, self.opt.game_colors, self.opt.game_dim, self.opt.game_dim):type(self.opt.dtype) 51 | 
for b = 1, self.opt.bs do 52 | -- Pick number 53 | local num 54 | if self.opt.game_use_mnist == 1 then 55 | local index = torch.random(#self.mnist.lookup) 56 | num = self.mnist.lookup[index] 57 | elseif torch.uniform() < self.opt.game_bias then 58 | num = 1 59 | else 60 | num = torch.random(10) 61 | end 62 | 63 | number[b] = num 64 | 65 | -- Pick color 66 | color_id[b] = torch.random(self.opt.game_colors) 67 | 68 | -- Pick dataset id 69 | local id = torch.random(#self.mnist[num]) 70 | x[b][color_id[b]] = self.mnist[num][id].x 71 | end 72 | return { x, color_id, number } 73 | end 74 | 75 | 76 | function ColorDigit:reset() 77 | 78 | -- Load images 79 | self.state = { self:loadDigit(), self:loadDigit() } 80 | 81 | -- Reset rewards 82 | self.reward:zero() 83 | self.terminal:zero() 84 | 85 | -- Reset counter 86 | self.step_counter = 1 87 | 88 | return self 89 | end 90 | 91 | function ColorDigit:getActionRange() 92 | return nil 93 | end 94 | 95 | function ColorDigit:getCommLimited() 96 | return nil 97 | end 98 | 99 | function ColorDigit:getReward(a) 100 | 101 | local color_1 = self.state[1][2] 102 | local color_2 = self.state[2][2] 103 | local digit_1 = self.state[1][3] 104 | local digit_2 = self.state[2][3] 105 | 106 | local reward = torch.zeros(self.opt.bs, self.opt.game_nagents) 107 | 108 | for b = 1, self.opt.bs do 109 | if self.opt.game_level == "extra_hard_local" then 110 | if a[b][2] <= self.opt.game_action_space and self.step_counter > 1 then 111 | reward[b] = 2 * (-1) ^ (digit_1[b] + a[b][2] + color_2[b]) + (-1) ^ (digit_2[b] + a[b][2] + color_1[b]) 112 | end 113 | if a[b][1] <= self.opt.game_action_space and self.step_counter > 1 then 114 | reward[b] = reward[b] + 2 * (-1) ^ (digit_2[b] + a[b][1] + color_1[b]) + (-1) ^ (digit_1[b] + a[b][1] + color_2[b]) 115 | end 116 | elseif self.opt.game_level == "many_bits" then 117 | if a[b][1] <= self.opt.game_action_space and self.step_counter == self.opt.nsteps then 118 | if digit_2[b] == a[b][1] then 119 | reward[b] = reward[b] + 0.5 120 | end 121 | end 122 | 123 | if a[b][2] <= self.opt.game_action_space and self.step_counter == self.opt.nsteps then 124 | if digit_1[b] == a[b][2] then 125 | reward[b] = reward[b] + 0.5 126 | end 127 | end 128 | else 129 | error("[ColorDigit] wrong level") 130 | end 131 | end 132 | 133 | local reward_coop = torch.zeros(self.opt.bs, self.opt.game_nagents) 134 | reward_coop[{ {}, { 2 } }] = (reward[{ {}, { 2 } }] + reward[{ {}, { 1 } }] * self.opt.game_coop) / (1 + self.opt.game_coop) 135 | reward_coop[{ {}, { 1 } }] = (reward[{ {}, { 1 } }] + reward[{ {}, { 2 } }] * self.opt.game_coop) / (1 + self.opt.game_coop) 136 | 137 | return reward_coop 138 | end 139 | 140 | function ColorDigit:step(a) 141 | local reward, terminal 142 | 143 | reward = self:getReward(a) 144 | 145 | if self.step_counter == self.opt.nsteps then 146 | self.terminal:fill(1) 147 | end 148 | 149 | self.step_counter = self.step_counter + 1 150 | 151 | return reward, self.terminal:clone() 152 | end 153 | 154 | 155 | function ColorDigit:getState() 156 | if self.opt.game_use_digits == 1 then 157 | return { self.state[1][3], self.state[2][3] } 158 | else 159 | return { self.state[1][1], self.state[2][1] } 160 | end 161 | end 162 | 163 | return ColorDigit 164 | 165 | -------------------------------------------------------------------------------- /code/game/Switch.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | local class = require 'class' 3 | 4 | local log = require 'include.log' 5 | local 
kwargs = require 'include.kwargs' 6 | local util = require 'include.util' 7 | 8 | 9 | local Switch = class('Switch') 10 | 11 | -- Actions 12 | -- 1 = on 13 | -- 2 = off 14 | -- 3 = tell 15 | -- 4* = none 16 | 17 | function Switch:__init(opt) 18 | local opt_game = kwargs(_, { 19 | { 'game_action_space', type = 'int-pos', default = 2 }, 20 | { 'game_reward_shift', type = 'int', default = 0 }, 21 | { 'game_comm_bits', type = 'int', default = 0 }, 22 | { 'game_comm_sigma', type = 'number', default = 2 }, 23 | }) 24 | 25 | -- Steps max override 26 | opt.nsteps = 4 * opt.game_nagents - 6 27 | 28 | for k, v in pairs(opt_game) do 29 | if not opt[k] then 30 | opt[k] = v 31 | end 32 | end 33 | self.opt = opt 34 | 35 | -- Rewards 36 | self.reward_all_live = 1 + self.opt.game_reward_shift 37 | self.reward_all_die = -1 + self.opt.game_reward_shift 38 | 39 | -- Spawn new game 40 | self:reset() 41 | end 42 | 43 | function Switch:reset() 44 | 45 | -- Reset rewards 46 | self.reward = torch.zeros(self.opt.bs, self.opt.game_nagents) 47 | 48 | -- Has been in 49 | self.has_been = torch.zeros(self.opt.bs, self.opt.nsteps, self.opt.game_nagents) 50 | 51 | -- Reached end 52 | self.terminal = torch.zeros(self.opt.bs) 53 | 54 | -- Step counter 55 | self.step_counter = 1 56 | 57 | -- Who is in 58 | self.active_agent = torch.zeros(self.opt.bs, self.opt.nsteps) 59 | for b = 1, self.opt.bs do 60 | for step = 1, self.opt.nsteps do 61 | local id = torch.random(self.opt.game_nagents) 62 | self.active_agent[{ { b }, { step } }] = id 63 | self.has_been[{ { b }, { step }, { id } }] = 1 64 | end 65 | end 66 | 67 | return self 68 | end 69 | 70 | function Switch:getActionRange(step, agent) 71 | local range = {} 72 | if self.opt.model_dial == 1 then 73 | local bound = self.opt.game_action_space 74 | 75 | for i = 1, self.opt.bs do 76 | if self.active_agent[i][step] == agent then 77 | range[i] = { { i }, { 1, bound } } 78 | else 79 | range[i] = { { i }, { 1 } } 80 | end 81 | end 82 | return range 83 | else 84 | local comm_range = {} 85 | for i = 1, self.opt.bs do 86 | if self.active_agent[i][step] == agent then 87 | range[i] = { { i }, { 1, self.opt.game_action_space } } 88 | comm_range[i] = { { i }, { self.opt.game_action_space + 1, self.opt.game_action_space_total } } 89 | else 90 | range[i] = { { i }, { 1 } } 91 | comm_range[i] = { { i }, { 0, 0 } } 92 | end 93 | end 94 | return range, comm_range 95 | end 96 | end 97 | 98 | 99 | function Switch:getCommLimited(step, i) 100 | if self.opt.game_comm_limited then 101 | 102 | local range = {} 103 | 104 | -- Get range per batch 105 | for b = 1, self.opt.bs do 106 | -- if agent is active read from field of previous agent 107 | if step > 1 and i == self.active_agent[b][step] then 108 | range[b] = { self.active_agent[b][step - 1], {} } 109 | else 110 | range[b] = 0 111 | end 112 | end 113 | return range 114 | else 115 | return nil 116 | end 117 | end 118 | 119 | function Switch:getReward(a_t) 120 | 121 | for b = 1, self.opt.bs do 122 | local active_agent = self.active_agent[b][self.step_counter] 123 | if (a_t[b][active_agent] == 2 and self.terminal[b] == 0) then 124 | local has_been = self.has_been[{ { b }, { 1, self.step_counter }, {} }]:sum(2):squeeze(2):gt(0):float():sum() 125 | if has_been == self.opt.game_nagents then 126 | self.reward[b] = self.reward_all_live 127 | else 128 | self.reward[b] = self.reward_all_die 129 | end 130 | self.terminal[b] = 1 131 | elseif self.step_counter == self.opt.nsteps and self.terminal[b] == 0 then 132 | self.terminal[b] = 1 133 | end 134 | end 135 | 
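--[[
Editor's illustration of the reward rule implemented in the loop above (derived only
from this file): one random agent per step is "in the room" (active_agent). With
game_action_space = 2 only actions 1 and 2 exist, so action 2 plays the role of the
"tell" action from the legend at the top of this file. If the active agent picks
action 2, the episode ends with reward_all_live (+1 for game_reward_shift = 0) when
every agent has been in the room at least once so far, and reward_all_die (-1)
otherwise; if the step limit is reached without a "tell", the episode ends with the
rewards left at zero. The horizon is nsteps = 4 * game_nagents - 6, i.e. 6 steps for
3 agents and 10 steps for 4 agents (the run_switch_3/4 configurations).
--]]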
136 | return self.reward:clone(), self.terminal:clone() 137 | end 138 | 139 | function Switch:step(a_t) 140 | 141 | -- Get rewards 142 | local reward, terminal = self:getReward(a_t) 143 | 144 | -- Make step 145 | self.step_counter = self.step_counter + 1 146 | 147 | return reward, terminal 148 | end 149 | 150 | 151 | function Switch:getState() 152 | local state = {} 153 | 154 | for agent = 1, self.opt.game_nagents do 155 | state[agent] = torch.Tensor(self.opt.bs) 156 | 157 | for b = 1, self.opt.bs do 158 | if self.active_agent[b][self.step_counter] == agent then 159 | state[agent][{ { b } }] = 1 160 | else 161 | state[agent][{ { b } }] = 2 162 | end 163 | end 164 | end 165 | 166 | return state 167 | end 168 | 169 | return Switch 170 | 171 | -------------------------------------------------------------------------------- /code/include/kwargs.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | 3 | Based on the code of Brendan Shillingford [bitbucket.org/bshillingford/nnob](https://bitbucket.org/bshillingford/nnob). 4 | 5 | Argument type checker for keyword arguments, i.e. arguments 6 | specified as a key-value table to constructors/functions. 7 | 8 | ## Valid typespecs: 9 | - `"number"` 10 | - `"string"` 11 | - `"boolean"` 12 | - `"function"` 13 | - `"table"` 14 | - `"tensor"` 15 | - torch class (not a string) like nn.Module 16 | - `"class:x"` 17 | - e.g. x=`nn.Module`; can be comma-separated to OR them 18 | - specific `torch.*Tensor` 19 | - specific `torch.*Storage` 20 | - `"int"`: number that is integer 21 | - `"int-pos"`: integer, and `> 0` 22 | - `"int-nonneg"`: integer, and `>= 0` 23 | --]] 24 | require 'torch' 25 | local math = require 'math' 26 | 27 | local function assert_type(val, typespec, argname) 28 | local typename = torch.typename(val) or type(val) 29 | -- handles number, boolean, string, table; but passes through if needed 30 | if typespec == typename then return true end 31 | 32 | -- try to parse, nil if no match 33 | local classnames = type(typespec) == string and string.match(typespec, 'class: *(.*)') 34 | 35 | -- isTypeOf for table-typed specs (see below for class:x version) 36 | if type(typespec) == 'table' and typespec.__typename then 37 | if torch.isTypeOf(val, typespec) then 38 | return true 39 | else 40 | error(string.format('argument %s should be instance of %s, but is type %s', 41 | argname, typespec.__typename, typename)) 42 | end 43 | elseif typespec == 'tensor' then 44 | if torch.isTensor(val) then return true end 45 | elseif classnames then 46 | for _, classname in pairs(string.split(classnames, ' *, *')) do 47 | if torch.isTypeOf(val, classname) then return true end 48 | end 49 | elseif typespec == 'int' or typespec == 'integer' then 50 | if math.floor(val) == val then return true end 51 | elseif typespec == 'int-pos' then 52 | if val > 0 and math.floor(val) == val then return true end 53 | elseif typespec == 'int-nonneg' then 54 | if val >= 0 and math.floor(val) == val then return true end 55 | else 56 | error('invalid type spec (' .. tostring(typespec) .. ') for arg ' .. 
argname) 57 | end 58 | error(string.format('argument %s must be of type %s, given type %s', 59 | argname, typespec, typename)) 60 | end 61 | 62 | return function(args, settings) 63 | local result = {} 64 | local unprocessed = {} 65 | 66 | if not args then 67 | args = {} 68 | end 69 | 70 | if type(args) ~= 'table' then 71 | error('args must be non-nil and must be a table') 72 | end 73 | 74 | for k, _ in pairs(args) do 75 | unprocessed[k] = true 76 | end 77 | 78 | -- Use ipairs, so we skip named settings 79 | for _, setting in ipairs(settings) do 80 | -- allow name to either be the only non-named element 81 | -- e.g. {'name', type='...'}, or named 82 | local name = setting.name or setting[1] 83 | 84 | -- get value or default 85 | local val 86 | if args[name] ~= nil then 87 | val = args[name] 88 | elseif setting.default ~= nil then 89 | val = setting.default 90 | elseif not setting.optional then 91 | error('required argument: ' .. name) 92 | end 93 | -- check types 94 | if val ~= nil and not setting.optional and setting.type ~= nil then 95 | assert_type(val, setting.type, name) 96 | end 97 | 98 | result[name] = val 99 | unprocessed[name] = nil 100 | end 101 | 102 | if settings.ignore_extras then 103 | for _, name in pairs(unprocessed) do 104 | result[name] = args[name] 105 | end 106 | elseif #unprocessed > 0 then 107 | error('extra unprocessed arguments: ' 108 | .. table.concat(unprocessed, ', ')) 109 | end 110 | return result 111 | end 112 | -------------------------------------------------------------------------------- /code/include/log.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | 3 | Based on the code of Brendan Shillingford [bitbucket.org/bshillingford/nnob](https://bitbucket.org/bshillingford/nnob). 4 | 5 | Based on [github.com/rxi/log.lua](https://github.com/rxi/log.lua/commit/93cfbe0c91bced6d3d061a58e7129979441eb200). 6 | 7 | Usage: 8 | ``` 9 | local log = require 'nnob.log' 10 | 11 | -- ... 12 | log.info('hello world!') 13 | log.infof('some number: %3.3f', 123) 14 | ``` 15 | 16 | Pasted from README.md of `github.com/rxi/log.lua`: 17 | # log.lua 18 | A tiny logging module for Lua. 19 | 20 | ![screenshot from 2014-07-04 19 55 55](https://cloud.githubusercontent.com/assets/3920290/3484524/2ea2a9c6-03ad-11e4-9ed5-a9744c6fd75d.png) 21 | 22 | 23 | ## Usage 24 | log.lua provides 6 functions, each function takes all its arguments, 25 | concatenates them into a string then outputs the string to the console and -- 26 | if one is set -- the log file: 27 | 28 | * **log.trace(...)** 29 | * **log.debug(...)** 30 | * **log.info(...)** 31 | * **log.warn(...)** 32 | * **log.error(...)** 33 | * **log.fatal(...)** 34 | 35 | 36 | ### Additional options 37 | log.lua provides variables for setting additional options: 38 | 39 | #### log.usecolor 40 | Whether colors should be used when outputting to the console, this is `true` by 41 | default. If you're using a console which does not support ANSI color escape 42 | codes then this should be disabled. 43 | 44 | #### log.outfile 45 | The name of the file where the log should be written, log files do not contain 46 | ANSI colors and always use the full date rather than just the time. By default 47 | `log.outfile` is `nil` (no log file is used). If a file which does not exist is 48 | set as the `log.outfile` then it is created on the first message logged. If the 49 | file already exists it is appended to. 
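For example (an illustrative sketch; `include.log` is how the rest of this codebase
requires the module, and the file name is arbitrary):

```
local log = require 'include.log'
log.outfile = 'results/comm.log'  -- also append every message to this file
log.level = 'info'                -- silence trace/debug output
log.infof('episode %d finished, reward %.2f', 100, 0.75)
```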
50 | 51 | #### log.level 52 | The minimum level to log, any logging function called with a lower level than 53 | the `log.level` is ignored and no text is outputted or written. By default this 54 | value is set to `"trace"`, the lowest log level, such that no log messages are 55 | ignored. 56 | 57 | The level of each log mode, starting with the lowest log level is as follows: 58 | `"trace"` `"debug"` `"info"` `"warn"` `"error"` `"fatal"` 59 | 60 | 61 | ## License 62 | This library is free software; you can redistribute it and/or modify it under 63 | the terms of the MIT license. See [LICENSE](LICENSE) for details. 64 | --]] 65 | 66 | 67 | 68 | -- 69 | -- log.lua 70 | -- 71 | -- Copyright (c) 2016 rxi 72 | -- 73 | -- This library is free software; you can redistribute it and/or modify it 74 | -- under the terms of the MIT license. See LICENSE for details. 75 | -- 76 | 77 | local log = { _version = "0.1.0" } 78 | 79 | log.usecolor = true 80 | log.outfile = nil 81 | log.level = "trace" 82 | 83 | 84 | local modes = { 85 | { name = "trace", color = "\27[34m", }, 86 | { name = "debug", color = "\27[36m", }, 87 | { name = "info", color = "\27[32m", }, 88 | { name = "warn", color = "\27[33m", }, 89 | { name = "error", color = "\27[31m", }, 90 | { name = "fatal", color = "\27[35m", }, 91 | } 92 | 93 | 94 | local levels = {} 95 | for i, v in ipairs(modes) do 96 | levels[v.name] = i 97 | end 98 | 99 | 100 | local round = function(x, increment) 101 | increment = increment or 1 102 | x = x / increment 103 | return (x > 0 and math.floor(x + .5) or math.ceil(x - .5)) * increment 104 | end 105 | 106 | 107 | local _tostring = tostring 108 | 109 | local tostring = function(...) 110 | local t = {} 111 | for i = 1, select('#', ...) do 112 | local x = select(i, ...) 113 | if type(x) == "number" then 114 | x = round(x, .01) 115 | end 116 | t[#t + 1] = _tostring(x) 117 | end 118 | return table.concat(t, " ") 119 | end 120 | 121 | 122 | for i, x in ipairs(modes) do 123 | local nameupper = x.name:upper() 124 | log[x.name] = function(...) 125 | 126 | -- Return early if we're below the log level 127 | if i < levels[log.level] then 128 | return 129 | end 130 | 131 | local msg = tostring(...) 132 | local info = debug.getinfo(2, "Sl") 133 | local lineinfo = info.short_src .. ":" .. info.currentline 134 | 135 | -- Output to console 136 | print(string.format("%s[%-6s%s]%s %s: %s", 137 | log.usecolor and x.color or "", 138 | nameupper, 139 | os.date("%H:%M:%S"), 140 | log.usecolor and "\27[0m" or "", 141 | lineinfo, 142 | msg)) 143 | 144 | -- Output to log file 145 | if log.outfile then 146 | local fp = io.open(log.outfile, "a") 147 | local str = string.format("[%-6s%s] %s: %s\n", 148 | nameupper, os.date(), lineinfo, msg) 149 | fp:write(str) 150 | fp:close() 151 | end 152 | end 153 | 154 | -- bshillingford: add formatted versions, 155 | -- e.g. log.infof(...) as alias for log.info(string.format(...) 156 | log[x.name .. 'f'] = function(...) 157 | -- Return early if we're below the log level 158 | if i < levels[log.level] then 159 | return 160 | end 161 | 162 | local fmt = ... -- i.e. first arg; note: select(2, ...) gets everything after 163 | 164 | local info = debug.getinfo(2, "Sl") 165 | local lineinfo = info.short_src .. ":" .. info.currentline 166 | 167 | -- Output to console 168 | print(string.format("%s[%-6s%s]%s %s: " .. 
fmt, 169 | log.usecolor and x.color or "", 170 | nameupper, 171 | os.date("%H:%M:%S"), 172 | log.usecolor and "\27[0m" or "", 173 | lineinfo, 174 | select(2, ...))) 175 | 176 | -- Output to log file 177 | if log.outfile then 178 | local fp = io.open(log.outfile, "a") 179 | local str = string.format("[%-6s%s] %s: " .. fmt .. "\n", 180 | nameupper, os.date(), lineinfo, select(2, ...)) 181 | fp:write(str) 182 | fp:close() 183 | end 184 | end 185 | end 186 | 187 | 188 | return log 189 | -------------------------------------------------------------------------------- /code/include/util.lua: -------------------------------------------------------------------------------- 1 | local util = {} 2 | local log = require 'include.log' 3 | 4 | function util.euclidean_dist(p, q) 5 | return math.sqrt(util.euclidean_dist2(p, q)) 6 | end 7 | 8 | function util.euclidean_dist2(p, q) 9 | assert(#p == #q, 'vectors must have the same length') 10 | local sum = 0 11 | for i in ipairs(p) do 12 | sum = sum + (p[i] - q[i]) ^ 2 13 | end 14 | return sum 15 | end 16 | 17 | function util.resetParams(cur_module) 18 | if cur_module.modules then 19 | for i, module in ipairs(cur_module.modules) do 20 | util.resetParams(module) 21 | end 22 | else 23 | cur_module:reset() 24 | end 25 | 26 | return cur_module 27 | end 28 | 29 | function util.dc(orig) 30 | local orig_type = torch.type(orig) 31 | local copy 32 | if orig_type == 'table' then 33 | copy = {} 34 | for orig_key, orig_value in next, orig, nil do 35 | copy[util.dc(orig_key)] = util.dc(orig_value) 36 | end 37 | setmetatable(copy, util.dc(getmetatable(orig))) 38 | elseif orig_type == 'torch.FloatTensor' or orig_type == 'torch.DoubleTensor' or orig_type == 'torch.CudaTensor' then 39 | -- Torch tensor 40 | copy = orig:clone() 41 | else 42 | -- number, string, boolean, etc 43 | copy = orig 44 | end 45 | 46 | return copy 47 | end 48 | 49 | function util.copyManyTimes(net, n) 50 | local nets = {} 51 | 52 | for i = 1, n do 53 | nets[#nets + 1] = util.resetParams(net:clone()) 54 | end 55 | 56 | return nets 57 | end 58 | 59 | function util.cloneManyTimes(net, T) 60 | local clones = {} 61 | local params, gradParams = net:parameters() 62 | if params == nil then 63 | params = {} 64 | end 65 | local paramsNoGrad 66 | if net.parametersNoGrad then 67 | paramsNoGrad = net:parametersNoGrad() 68 | end 69 | local mem = torch.MemoryFile("w"):binary() 70 | mem:writeObject(net) 71 | for t = 1, T do 72 | -- We need to use a new reader for each clone. 73 | -- We don't want to use the pointers to already read objects. 
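        -- Each clone is deserialised from the same serialised network, and its parameter
        -- tensors are then re-pointed (via :set) at the originals, so every clone shares
        -- the same underlying weights and gradient storage.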
74 | local reader = torch.MemoryFile(mem:storage(), "r"):binary() 75 | local clone = reader:readObject() 76 | reader:close() 77 | local cloneParams, cloneGradParams = clone:parameters() 78 | local cloneParamsNoGrad 79 | for i = 1, #params do 80 | cloneParams[i]:set(params[i]) 81 | cloneGradParams[i]:set(gradParams[i]) 82 | end 83 | if paramsNoGrad then 84 | cloneParamsNoGrad = clone:parametersNoGrad() 85 | for i = 1, #paramsNoGrad do 86 | cloneParamsNoGrad[i]:set(paramsNoGrad[i]) 87 | end 88 | end 89 | clones[t] = clone 90 | collectgarbage() 91 | end 92 | mem:close() 93 | return clones 94 | end 95 | 96 | function util.spairs(t, order) 97 | -- collect the keys 98 | local keys = {} 99 | for k in pairs(t) do keys[#keys + 1] = k end 100 | 101 | -- if order function given, sort by it by passing the table and keys a, b, 102 | -- otherwise just sort the keys 103 | if order then 104 | table.sort(keys, function(a, b) return order(t, a, b) end) 105 | else 106 | table.sort(keys) 107 | end 108 | 109 | -- return the iterator function 110 | local i = 0 111 | return function() 112 | i = i + 1 113 | if keys[i] then 114 | return keys[i], t[keys[i]] 115 | end 116 | end 117 | end 118 | 119 | function util.sprint(t) 120 | for k, v in util.spairs(t) do 121 | log.debugf('opt[\'%s\'] = %s', k, v) 122 | end 123 | end 124 | 125 | function util.f2(f) 126 | return string.format("%.2f", f) 127 | end 128 | 129 | function util.f3(f) 130 | return string.format("%.3f", f) 131 | end 132 | 133 | function util.f4(f) 134 | return string.format("%.4f", f) 135 | end 136 | 137 | function util.f5(f) 138 | return string.format("%.5f", f) 139 | end 140 | 141 | function util.f6(f) 142 | return string.format("%.6f", f) 143 | end 144 | 145 | function util.d(f) 146 | return string.format("%d", torch.round(f)) 147 | end 148 | 149 | 150 | 151 | return util -------------------------------------------------------------------------------- /code/model/ColorDigit.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'nngraph' 3 | require 'optim' 4 | require 'csvigo' 5 | 6 | local kwargs = require 'include.kwargs' 7 | local log = require 'include.log' 8 | local util = require 'include.util' 9 | local LSTM = require 'module.LSTM' 10 | local GRU = require 'module.GRU' 11 | require 'module.LinearO' 12 | require 'module.Binarize' 13 | require 'module.Print' 14 | require 'module.GaussianNoise' 15 | require 'module.rmsprop' 16 | 17 | return function(opt) 18 | 19 | local exp = {} 20 | 21 | function exp.optim(iter) 22 | -- iter can be used for learning rate decay 23 | -- local optimfunc = optim.adam 24 | local optimfunc = optim.rmsprop 25 | local optimconfig = { learningRate = opt.learningrate } 26 | return optimfunc, optimconfig 27 | end 28 | 29 | function exp.save(opt, stats, model) 30 | if stats.e % opt.step == 0 then 31 | if opt.filename == '' then 32 | exp.save_path = exp.save_path or paths.concat('results', opt.game .. '_' .. opt.game_nagents .. 33 | (opt.model_dial == 1 and '_dial' or '') .. '_' .. string.upper(string.format("%x", opt.seed))) 34 | else 35 | exp.save_path = exp.save_path or paths.concat('results', opt.game .. '_' .. opt.game_nagents .. 36 | (opt.model_dial == 1 and '_dial' or '') .. '_' .. opt.filename .. '_' .. string.upper(string.format("%x", opt.seed))) 37 | end 38 | 39 | 40 | -- Save opt 41 | if stats.e == opt.step then 42 | os.execute('mkdir -p ' .. 
exp.save_path) 43 | local opt_csv = {} 44 | for k, v in util.spairs(opt) do 45 | table.insert(opt_csv, { k, v }) 46 | end 47 | 48 | csvigo.save({ 49 | path = paths.concat(exp.save_path, 'opt.csv'), 50 | data = opt_csv, 51 | verbose = false 52 | }) 53 | end 54 | 55 | -- keep stats 56 | stats.history = stats.history or { { 'e', 'td_err', 'td_comm', 'train_r', 'test_r', 'steps', 'comm_per', 'te' } } 57 | table.insert(stats.history, { 58 | stats.e, 59 | stats.td_err:mean(), 60 | stats.td_comm:mean(), 61 | stats.train_r:mean(), 62 | stats.test_r:mean(), 63 | stats.steps:mean(), 64 | stats.comm_per:mean(), 65 | stats.te:mean() 66 | }) 67 | 68 | -- Save stats csv 69 | csvigo.save({ 70 | path = paths.concat(exp.save_path, 'stats.csv'), 71 | data = stats.history, 72 | verbose = false 73 | }) 74 | 75 | -- Save action histogram 76 | if opt.hist_action == 1 then 77 | -- Append to memory 78 | stats.history_hist_action = stats.history_hist_action or {} 79 | table.insert(stats.history_hist_action, 80 | stats.hist_action_avg:totable()[1]) 81 | 82 | -- save csv 83 | csvigo.save({ 84 | path = paths.concat(exp.save_path, 'hist_action.csv'), 85 | data = stats.history_hist_action, 86 | verbose = false 87 | }) 88 | end 89 | 90 | -- Save action histogram 91 | if opt.hist_comm == 1 then 92 | -- Append to memory 93 | stats.history_hist_comm = stats.history_hist_comm or {} 94 | table.insert(stats.history_hist_comm, 95 | stats.hist_comm_avg:totable()[1]) 96 | 97 | -- save csv 98 | csvigo.save({ 99 | path = paths.concat(exp.save_path, 'hist_comm.csv'), 100 | data = stats.history_hist_comm, 101 | verbose = false 102 | }) 103 | end 104 | 105 | 106 | -- save model 107 | if stats.e % (opt.step * 10) == 0 then 108 | log.debug('Saving model') 109 | 110 | -- clear state 111 | -- exp.clearState(model.agent) 112 | 113 | -- save model 114 | local filename = paths.concat(exp.save_path, 'exp.t7') 115 | torch.save(filename, { opt, stats, model.agent }) 116 | end 117 | end 118 | end 119 | 120 | function exp.load() 121 | end 122 | 123 | function exp.clearState(model) 124 | for i = 1, #model do 125 | model[i]:clearState() 126 | end 127 | end 128 | 129 | function exp.training(model) 130 | for i = 1, #model do 131 | model[i]:training() 132 | end 133 | end 134 | 135 | function exp.evaluate(model) 136 | for i = 1, #model do 137 | model[i]:evaluate() 138 | end 139 | end 140 | 141 | function exp.getParameters() 142 | -- Get model params 143 | local a = nn.Container() 144 | for i = 1, #exp.agent do 145 | a:add(exp.agent[i]) 146 | end 147 | local params, gradParams = a:getParameters() 148 | 149 | log.infof('Creating model(s), params=%d', params:nElement()) 150 | 151 | -- Get target model params 152 | local a = nn.Container() 153 | for i = 1, #exp.agent_target do 154 | a:add(exp.agent_target[i]) 155 | end 156 | local params_target, gradParams_target = a:getParameters() 157 | 158 | log.infof('Creating target model(s), params=%d', params_target:nElement()) 159 | 160 | return params, gradParams, params_target, gradParams_target 161 | end 162 | 163 | function exp.id(step_i, agent_i) 164 | return (step_i - 1) * opt.game_nagents + agent_i 165 | end 166 | 167 | function exp.stats(opt, game, stats, e) 168 | if e % opt.step == 0 then 169 | end 170 | end 171 | 172 | local function create_agent() 173 | 174 | -- Sizes 175 | local comm_size = 0 176 | 177 | if (opt.game_comm_bits > 0) and (opt.game_nagents > 1) then 178 | comm_size = opt.game_comm_bits * opt.game_nagents 179 | end 180 | 181 | local action_aware_size = 0 182 | if opt.model_action_aware == 
1 then 183 | action_aware_size = opt.game_action_space_total 184 | end 185 | 186 | -- Networks 187 | local model_mnist = nn.Sequential() 188 | 189 | if opt.game_use_digits == 1 then 190 | model_mnist:add(nn.LookupTable(10, opt.model_rnn_size)) 191 | else 192 | if opt.game_vision_net == 'mlp' then 193 | model_mnist:add(nn.View(-1, opt.game_colors * 28 * 28)) 194 | if opt.model_bn == 1 then model_mnist:add(nn.BatchNormalization(opt.game_colors * 28 * 28)) end 195 | model_mnist:add(nn.Linear(opt.game_colors * 28 * 28, 128)) -- fully connected layer (matrix multiplication between input and weights) 196 | model_mnist:add(nn.ReLU(true)) 197 | if opt.model_bn == 1 then model_mnist:add(nn.BatchNormalization(128)) end 198 | model_mnist:add(nn.Linear(128, opt.model_rnn_size)) 199 | model_mnist:add(nn.View(-1, opt.model_rnn_size)) 200 | else 201 | model_mnist:add(nn.SpatialConvolutionMM(opt.game_colors, 32, 5, 5)) -- 1 input image channel, 6 output channels, 5x5 convolution kernel 202 | model_mnist:add(nn.ReLU(true)) 203 | model_mnist:add(nn.SpatialMaxPooling(3, 3, 3, 3)) -- A max-pooling operation that looks at 2x2 windows and finds the max. 204 | model_mnist:add(nn.SpatialConvolutionMM(32, 64, 5, 5)) 205 | model_mnist:add(nn.ReLU(true)) 206 | model_mnist:add(nn.SpatialMaxPooling(2, 2, 2, 2)) 207 | 208 | local conv_out = model_mnist:forward(torch.zeros(1, opt.game_colors, 28, 28)):nElement() 209 | 210 | model_mnist:add(nn.View(-1, conv_out)) 211 | model_mnist:add(nn.Linear(conv_out, opt.model_rnn_size)) 212 | model_mnist:add(nn.View(-1, opt.model_rnn_size)) 213 | end 214 | end 215 | 216 | 217 | -- Process inputs 218 | local model_input = nn.Sequential() 219 | model_input:add(nn.CAddTable(2)) 220 | 221 | -- RNN 222 | local model_rnn 223 | if opt.model_rnn == 'lstm' then 224 | model_rnn = LSTM(opt.model_rnn_size, 225 | opt.model_rnn_size, 226 | opt.model_rnn_layers, 227 | opt.model_dropout, 228 | opt.model_bn == 1) 229 | elseif opt.model_rnn == 'gru' then 230 | model_rnn = GRU(opt.model_rnn_size, 231 | opt.model_rnn_size, 232 | opt.model_rnn_layers, 233 | opt.model_dropout) 234 | end 235 | 236 | -- use default initialization for convnet, but uniform -0.08 to .08 for RNN: 237 | -- double parens necessary 238 | for _, param in ipairs((model_rnn:parameters())) do 239 | param:uniform(-0.08, 0.08) 240 | end 241 | 242 | -- Output 243 | local model_out = nn.Sequential() 244 | if opt.model_dropout > 0 then model_out:add(nn.Dropout(opt.model_dropout)) end 245 | model_out:add(nn.Linear(opt.model_rnn_size, opt.model_rnn_size)) 246 | model_out:add(nn.ReLU(true)) 247 | model_out:add(nn.Linear(opt.model_rnn_size, opt.game_action_space_total)) 248 | 249 | -- Construct Graph 250 | 251 | local in_state = nn.Identity()() 252 | local in_id = nn.Identity()() 253 | local in_rnn_state = nn.Identity()() 254 | 255 | local in_comm, in_action 256 | 257 | local in_all = { 258 | model_mnist(in_state), 259 | nn.LookupTable(opt.game_nagents, opt.model_rnn_size)(in_id) 260 | } 261 | 262 | -- Communication enabled 263 | if opt.game_comm_bits > 0 and opt.game_nagents > 1 then 264 | in_comm = nn.Identity()() 265 | -- Process comm 266 | local model_comm = nn.Sequential() 267 | model_comm:add(nn.View(-1, comm_size)) 268 | if opt.model_dial == 1 then 269 | if opt.model_comm_narrow == 1 then 270 | model_comm:add(nn.Sigmoid()) 271 | else 272 | model_comm:add(nn.SoftMax()) 273 | end 274 | end 275 | if opt.model_bn == 1 and opt.model_dial == 1 then 276 | model_comm:add(nn.BatchNormalization(comm_size)) 277 | end 278 | 
model_comm:add(nn.Linear(comm_size, opt.model_rnn_size)) 279 | if opt.model_comm_narrow == 1 then 280 | model_comm:add(nn.ReLU(true)) 281 | end 282 | 283 | -- Process inputs node 284 | table.insert(in_all, model_comm(in_comm)) 285 | end 286 | 287 | -- Last action enabled 288 | if opt.model_action_aware == 1 then 289 | in_action = nn.Identity()() 290 | 291 | -- Process action node (+1 for no-action at 0-step) 292 | if opt.model_dial == 0 then 293 | local in_action_aware = nn.CAddTable(2)({ 294 | nn.LookupTable(opt.game_action_space + 1, opt.model_rnn_size)(nn.SelectTable(1)(in_action)), 295 | nn.LookupTable(opt.game_comm_bits + 1, opt.model_rnn_size)(nn.SelectTable(2)(in_action)) 296 | }) 297 | table.insert(in_all, in_action_aware) 298 | else 299 | table.insert(in_all, nn.LookupTable(action_aware_size + 1, opt.model_rnn_size)(in_action)) 300 | end 301 | end 302 | 303 | -- Process inputs 304 | local proc_input = model_input(in_all) 305 | 306 | -- 2*n+1 rnn inputs 307 | local rnn_input = {} 308 | table.insert(rnn_input, proc_input) 309 | 310 | -- Restore state 311 | for i = 1, opt.model_rnn_states do 312 | table.insert(rnn_input, nn.SelectTable(i)(in_rnn_state)) 313 | end 314 | 315 | local rnn_output = model_rnn(rnn_input) 316 | 317 | -- Split state and out 318 | local rnn_state = rnn_output 319 | local rnn_out = nn.SelectTable(opt.model_rnn_states)(rnn_output) 320 | 321 | -- Process out 322 | local proc_out = model_out(rnn_out) 323 | 324 | -- Create model 325 | local model_inputs = { in_state, in_id, in_rnn_state } 326 | local model_outputs = { rnn_state, proc_out } 327 | 328 | if opt.game_comm_bits > 0 and opt.game_nagents > 1 then 329 | table.insert(model_inputs, in_comm) 330 | end 331 | 332 | if opt.model_action_aware == 1 then 333 | table.insert(model_inputs, in_action) 334 | end 335 | 336 | nngraph.annotateNodes() 337 | 338 | local model = nn.gModule(model_inputs, model_outputs) 339 | 340 | return model:type(opt.dtype) 341 | end 342 | 343 | -- Create model 344 | local agent = create_agent() 345 | local agent_target = agent:clone() 346 | 347 | -- Knowledge sharing 348 | if opt.model_know_share == 1 then 349 | exp.agent = util.cloneManyTimes(agent, opt.game_nagents * (opt.nsteps + 1)) 350 | exp.agent_target = util.cloneManyTimes(agent_target, opt.game_nagents * (opt.nsteps + 1)) 351 | else 352 | exp.agent = {} 353 | exp.agent_target = {} 354 | 355 | local agent_copies = util.copyManyTimes(agent, opt.game_nagents) 356 | local agent_target_copies = util.copyManyTimes(agent_target, opt.game_nagents) 357 | 358 | for i = 1, opt.game_nagents do 359 | local unrolled = util.cloneManyTimes(agent_copies[i], opt.nsteps + 1) 360 | local unrolled_target = util.cloneManyTimes(agent_target_copies[i], opt.nsteps + 1) 361 | for s = 1, opt.nsteps + 1 do 362 | exp.agent[exp.id(s, i)] = unrolled[s] 363 | exp.agent_target[exp.id(s, i)] = unrolled_target[s] 364 | end 365 | end 366 | end 367 | 368 | return exp 369 | end -------------------------------------------------------------------------------- /code/model/Switch.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'nngraph' 3 | require 'optim' 4 | require 'csvigo' 5 | 6 | local kwargs = require 'include.kwargs' 7 | local log = require 'include.log' 8 | local util = require 'include.util' 9 | local LSTM = require 'module.LSTM' 10 | local GRU = require 'module.GRU' 11 | require 'module.rmsprop' 12 | require 'module.GaussianNoise' 13 | require 'module.Binarize' 14 | 15 | return function(opt) 16 | 
17 | local exp = {} 18 | 19 | function exp.optim(iter) 20 | -- iter can be used for learning rate decay 21 | -- local optimfunc = optim.adam 22 | local optimfunc = optim.rmsprop 23 | local optimconfig = { learningRate = opt.learningrate } 24 | return optimfunc, optimconfig 25 | end 26 | 27 | function exp.save(opt, stats, model) 28 | if stats.e % opt.step == 0 then 29 | if opt.filename == '' then 30 | exp.save_path = exp.save_path or paths.concat('results', opt.game .. '_' .. opt.game_nagents .. 31 | (opt.model_dial == 1 and '_dial' or '') .. '_' .. string.upper(string.format("%x", opt.seed))) 32 | else 33 | exp.save_path = exp.save_path or paths.concat('results', opt.game .. '_' .. opt.game_nagents .. 34 | (opt.model_dial == 1 and '_dial' or '') .. '_' .. opt.filename .. '_' .. string.upper(string.format("%x", opt.seed))) 35 | end 36 | 37 | 38 | -- Save opt 39 | if stats.e == opt.step then 40 | os.execute('mkdir -p ' .. exp.save_path) 41 | local opt_csv = {} 42 | for k, v in util.spairs(opt) do 43 | table.insert(opt_csv, { k, v }) 44 | end 45 | 46 | csvigo.save({ 47 | path = paths.concat(exp.save_path, 'opt.csv'), 48 | data = opt_csv, 49 | verbose = false 50 | }) 51 | end 52 | 53 | -- keep stats 54 | stats.history = stats.history or { { 'e', 'td_err', 'td_comm', 'train_r', 'test_r', 'test_opt', 'test_god', 'steps', 'comm_per', 'te' } } 55 | table.insert(stats.history, { 56 | stats.e, 57 | stats.td_err:mean(), 58 | stats.td_comm:mean(), 59 | stats.train_r:mean(), 60 | stats.test_r:mean(), 61 | stats.test_opt:mean(), 62 | stats.test_god:mean(), 63 | stats.steps:mean(), 64 | stats.comm_per:mean(), 65 | stats.te:mean() 66 | }) 67 | 68 | -- Save stats csv 69 | csvigo.save({ 70 | path = paths.concat(exp.save_path, 'stats.csv'), 71 | data = stats.history, 72 | verbose = false 73 | }) 74 | 75 | -- Save action histogram 76 | if opt.hist_action == 1 then 77 | -- Append to memory 78 | stats.history_hist_action = stats.history_hist_action or {} 79 | table.insert(stats.history_hist_action, 80 | stats.hist_action_avg:totable()[1]) 81 | 82 | -- save csv 83 | csvigo.save({ 84 | path = paths.concat(exp.save_path, 'hist_action.csv'), 85 | data = stats.history_hist_action, 86 | verbose = false 87 | }) 88 | end 89 | 90 | -- Save action histogram 91 | if opt.hist_comm == 1 then 92 | -- Append to memory 93 | stats.history_hist_comm = stats.history_hist_comm or {} 94 | table.insert(stats.history_hist_comm, 95 | stats.hist_comm_avg:totable()[1]) 96 | 97 | -- save csv 98 | csvigo.save({ 99 | path = paths.concat(exp.save_path, 'hist_comm.csv'), 100 | data = stats.history_hist_comm, 101 | verbose = false 102 | }) 103 | end 104 | -- save model 105 | if stats.e % (opt.step * 10) == 0 then 106 | log.debug('Saving model') 107 | 108 | -- clear state 109 | -- exp.clearState(model.agent) 110 | 111 | -- save model 112 | local filename = paths.concat(exp.save_path, 'exp.t7') 113 | torch.save(filename, { opt, stats, model.agent }) 114 | end 115 | end 116 | end 117 | 118 | function exp.load() 119 | end 120 | 121 | function exp.clearState(model) 122 | for i = 1, #model do 123 | model[i]:clearState() 124 | end 125 | end 126 | 127 | function exp.training(model) 128 | for i = 1, #model do 129 | model[i]:training() 130 | end 131 | end 132 | 133 | function exp.evaluate(model) 134 | for i = 1, #model do 135 | model[i]:evaluate() 136 | end 137 | end 138 | 139 | function exp.getParameters() 140 | -- Get model params 141 | local a = nn.Container() 142 | for i = 1, #exp.agent do 143 | a:add(exp.agent[i]) 144 | end 145 | local params, 
gradParams = a:getParameters() 146 | 147 | log.infof('Creating model(s), params=%d', params:nElement()) 148 | 149 | -- Get target model params 150 | local a = nn.Container() 151 | for i = 1, #exp.agent_target do 152 | a:add(exp.agent_target[i]) 153 | end 154 | local params_target, gradParams_target = a:getParameters() 155 | 156 | log.infof('Creating target model(s), params=%d', params_target:nElement()) 157 | 158 | return params, gradParams, params_target, gradParams_target 159 | end 160 | 161 | function exp.id(step_i, agent_i) 162 | return (step_i - 1) * opt.game_nagents + agent_i 163 | end 164 | 165 | function exp.stats(opt, game, stats, e) 166 | 167 | if e % opt.step_test == 0 then 168 | local test_idx = (e / opt.step_test - 1) % (opt.step / opt.step_test) + 1 169 | 170 | -- Initialise 171 | stats.test_opt = stats.test_opt or torch.zeros(opt.step / opt.step_test, opt.game_nagents) 172 | stats.test_god = stats.test_god or torch.zeros(opt.step / opt.step_test, opt.game_nagents) 173 | 174 | -- Naive strategy 175 | local r_naive = 0 176 | for b = 1, opt.bs do 177 | local has_been = game.has_been[{ { b }, { 1, opt.nsteps }, {} }]:sum(2):squeeze(2):gt(0):float():sum() 178 | if has_been == opt.game_nagents then 179 | r_naive = r_naive + game.reward_all_live 180 | else 181 | r_naive = r_naive + game.reward_all_die 182 | end 183 | end 184 | stats.test_opt[test_idx] = r_naive / opt.bs 185 | 186 | -- God strategy 187 | local r_god = 0 188 | for b = 1, opt.bs do 189 | local has_been = game.has_been[{ { b }, { 1, opt.nsteps }, {} }]:sum(2):squeeze(2):gt(0):float():sum() 190 | if has_been == opt.game_nagents then 191 | r_god = r_god + game.reward_all_live 192 | end 193 | end 194 | stats.test_god[test_idx] = r_god / opt.bs 195 | end 196 | 197 | -- Keep stats 198 | if e == opt.step then 199 | stats.test_opt_avg = stats.test_opt:mean() 200 | stats.test_god_avg = stats.test_god:mean() 201 | elseif e % opt.step == 0 then 202 | local coef = 0.9 203 | stats.test_opt_avg = stats.test_opt_avg * coef + stats.test_opt:mean() * (1 - coef) 204 | stats.test_god_avg = stats.test_god_avg * coef + stats.test_god:mean() * (1 - coef) 205 | end 206 | 207 | -- Print statistics 208 | if e % opt.step == 0 then 209 | log.infof('te_opt=%.2f, te_opt_avg=%.2f, te_god=%.2f, te_god_avg=%.2f', 210 | stats.test_opt:mean(), 211 | stats.test_opt_avg, 212 | stats.test_god:mean(), 213 | stats.test_god_avg) 214 | end 215 | end 216 | 217 | local function create_agent() 218 | 219 | -- Sizes 220 | local comm_size = 0 221 | 222 | if (opt.game_comm_bits > 0) and (opt.game_nagents > 1) then 223 | if opt.game_comm_limited then 224 | comm_size = opt.game_comm_bits 225 | else 226 | error('game_comm_limited is required') 227 | end 228 | end 229 | 230 | local action_aware_size = 0 231 | if opt.model_action_aware == 1 then 232 | action_aware_size = opt.game_action_space_total 233 | end 234 | 235 | 236 | -- Process inputs 237 | local model_input = nn.Sequential() 238 | model_input:add(nn.CAddTable(2)) 239 | -- if opt.model_bn == 1 then model_input:add(nn.BatchNormalization(opt.model_rnn_size)) end 240 | 241 | local model_state = nn.Sequential() 242 | model_state:add(nn.LookupTable(2, opt.model_rnn_size)) 243 | 244 | -- RNN 245 | local model_rnn 246 | if opt.model_rnn == 'lstm' then 247 | model_rnn = LSTM(opt.model_rnn_size, 248 | opt.model_rnn_size, 249 | opt.model_rnn_layers, 250 | opt.model_dropout, 251 | opt.model_bn == 1) 252 | elseif opt.model_rnn == 'gru' then 253 | model_rnn = GRU(opt.model_rnn_size, 254 | opt.model_rnn_size, 255 | 
opt.model_rnn_layers, 256 | opt.model_dropout) 257 | end 258 | 259 | -- use default initialization for convnet, but uniform -0.08 to .08 for RNN: 260 | -- double parens necessary 261 | for _, param in ipairs((model_rnn:parameters())) do 262 | param:uniform(-0.08, 0.08) 263 | end 264 | 265 | -- Output 266 | local model_out = nn.Sequential() 267 | if opt.model_dropout > 0 then model_out:add(nn.Dropout(opt.model_dropout)) end 268 | model_out:add(nn.Linear(opt.model_rnn_size, opt.model_rnn_size)) 269 | model_out:add(nn.ReLU(true)) 270 | model_out:add(nn.Linear(opt.model_rnn_size, opt.game_action_space_total)) 271 | 272 | -- Construct Graph 273 | local in_state = nn.Identity()() 274 | local in_id = nn.Identity()() 275 | local in_rnn_state = nn.Identity()() 276 | 277 | local in_comm, in_action 278 | 279 | local in_all = { 280 | model_state(in_state), 281 | nn.LookupTable(opt.game_nagents, opt.model_rnn_size)(in_id) 282 | } 283 | 284 | -- Communication enabled 285 | if opt.game_comm_bits > 0 and opt.game_nagents > 1 then 286 | in_comm = nn.Identity()() 287 | -- Process comm 288 | local model_comm = nn.Sequential() 289 | model_comm:add(nn.View(-1, comm_size)) 290 | if opt.model_dial == 1 then 291 | if opt.model_comm_narrow == 1 then 292 | model_comm:add(nn.Sigmoid()) 293 | else 294 | model_comm:add(nn.SoftMax()) 295 | end 296 | end 297 | if opt.model_bn == 1 and opt.model_dial == 1 then 298 | model_comm:add(nn.BatchNormalization(comm_size)) 299 | end 300 | model_comm:add(nn.Linear(comm_size, opt.model_rnn_size)) 301 | if opt.model_comm_narrow == 1 then 302 | model_comm:add(nn.ReLU(true)) 303 | end 304 | 305 | -- Process inputs node 306 | table.insert(in_all, model_comm(in_comm)) 307 | end 308 | 309 | -- Last action enabled 310 | if opt.model_action_aware == 1 then 311 | in_action = nn.Identity()() 312 | 313 | -- Process action node (+1 for no-action at 0-step) 314 | if opt.model_dial == 0 then 315 | local in_action_aware = nn.CAddTable(2)({ 316 | nn.LookupTable(opt.game_action_space + 1, opt.model_rnn_size)(nn.SelectTable(1)(in_action)), 317 | nn.LookupTable(opt.game_comm_bits + 1, opt.model_rnn_size)(nn.SelectTable(2)(in_action)) 318 | }) 319 | table.insert(in_all, in_action_aware) 320 | else 321 | table.insert(in_all, nn.LookupTable(action_aware_size + 1, opt.model_rnn_size)(in_action)) 322 | end 323 | end 324 | 325 | -- Process inputs 326 | local proc_input = model_input(in_all) 327 | 328 | -- 2*n+1 rnn inputs 329 | local rnn_input = {} 330 | table.insert(rnn_input, proc_input) 331 | 332 | -- Restore state 333 | for i = 1, opt.model_rnn_states do 334 | table.insert(rnn_input, nn.SelectTable(i)(in_rnn_state)) 335 | end 336 | 337 | local rnn_output = model_rnn(rnn_input) 338 | 339 | -- Split state and out 340 | local rnn_state = rnn_output 341 | local rnn_out = nn.SelectTable(opt.model_rnn_states)(rnn_output) 342 | 343 | -- Process out 344 | local proc_out = model_out(rnn_out) 345 | 346 | -- Create model 347 | local model_inputs = { in_state, in_id, in_rnn_state } 348 | local model_outputs = { rnn_state, proc_out } 349 | 350 | if opt.game_comm_bits > 0 and opt.game_nagents > 1 then 351 | table.insert(model_inputs, in_comm) 352 | end 353 | 354 | if opt.model_action_aware == 1 then 355 | table.insert(model_inputs, in_action) 356 | end 357 | 358 | nngraph.annotateNodes() 359 | 360 | local model = nn.gModule(model_inputs, model_outputs) 361 | 362 | return model:type(opt.dtype) 363 | end 364 | 365 | -- Create model 366 | local agent = create_agent() 367 | local agent_target = agent:clone() 368 | 369 | 
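    -- agent_target is a separate clone used as the target network for the
    -- Q-learning targets (cf. the -step_target option in the run scripts,
    -- which presumably sets how often it is synced with agent).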
-- Knowledge sharing 370 | if opt.model_know_share == 1 then 371 | exp.agent = util.cloneManyTimes(agent, opt.game_nagents * (opt.nsteps + 1)) 372 | exp.agent_target = util.cloneManyTimes(agent_target, opt.game_nagents * (opt.nsteps + 1)) 373 | else 374 | exp.agent = {} 375 | exp.agent_target = {} 376 | 377 | local agent_copies = util.copyManyTimes(agent, opt.game_nagents) 378 | local agent_target_copies = util.copyManyTimes(agent_target, opt.game_nagents) 379 | 380 | for i = 1, opt.game_nagents do 381 | local unrolled = util.cloneManyTimes(agent_copies[i], opt.nsteps + 1) 382 | local unrolled_target = util.cloneManyTimes(agent_target_copies[i], opt.nsteps + 1) 383 | for s = 1, opt.nsteps + 1 do 384 | exp.agent[exp.id(s, i)] = unrolled[s] 385 | exp.agent_target[exp.id(s, i)] = unrolled_target[s] 386 | end 387 | end 388 | end 389 | 390 | return exp 391 | end -------------------------------------------------------------------------------- /code/module/Binarize.lua: -------------------------------------------------------------------------------- 1 | local Binarize, parent = torch.class('nn.Binarize', 'nn.Module') 2 | 3 | function Binarize:__init(stcFlag) 4 | parent.__init(self) 5 | self.stcFlag = stcFlag 6 | self.randmat = torch.Tensor(); 7 | self.outputR = torch.Tensor(); 8 | end 9 | 10 | function Binarize:updateOutput(input) 11 | self.randmat:resizeAs(input); 12 | self.outputR:resizeAs(input); 13 | self.output:resizeAs(input); 14 | self.outputR:copy(input):add(1):div(2) 15 | if self.train and self.stcFlag then 16 | local mask = self.outputR - self.randmat:rand(self.randmat:size()) 17 | self.output = mask:sign() 18 | else 19 | self.output:copy(self.outputR):add(-0.5):sign() 20 | end 21 | return self.output 22 | end 23 | 24 | function Binarize:updateGradInput(input, gradOutput) 25 | self.gradInput:resizeAs(gradOutput) 26 | self.gradInput:copy(gradOutput) --:mul(0.5) 27 | return self.gradInput 28 | end 29 | -------------------------------------------------------------------------------- /code/module/GRU.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | 3 | Recurrent Batch Normalization 4 | Tim Cooijmans, Nicolas Ballas, César Laurent, Çağlar Gülçehre, Aaron Courville 5 | http://arxiv.org/abs/1603.09025 6 | 7 | Implemented by Yannis M. Assael (www.yannisassael.com), 2016. 8 | 9 | Based on 10 | https://github.com/wojciechz/learning_to_execute, 11 | https://github.com/karpathy/char-rnn/blob/master/model/LSTM.lua, 12 | and Brendan Shillingford. 
13 | 14 | Usage: 15 | local rnn = LSTM(input_size, rnn_size, n, dropout, bn) 16 | 17 | ]] -- 18 | 19 | require 'nn' 20 | require 'nngraph' 21 | 22 | local function GRU(input_size, rnn_size, n, dropout) 23 | dropout = dropout or 0 24 | -- there are n+1 inputs (hiddens on each layer and x) 25 | local inputs = {} 26 | table.insert(inputs, nn.Identity()()) -- x 27 | for L = 1, n do 28 | table.insert(inputs, nn.Identity()()) -- prev_h[L] 29 | end 30 | 31 | function new_input_sum(insize, xv, hv) 32 | local i2h = nn.Linear(insize, rnn_size)(xv) 33 | local h2h = nn.Linear(rnn_size, rnn_size)(hv) 34 | return nn.CAddTable()({ i2h, h2h }) 35 | end 36 | 37 | local x, input_size_L 38 | local outputs = {} 39 | for L = 1, n do 40 | 41 | local prev_h = inputs[L + 1] 42 | -- the input to this layer 43 | if L == 1 then 44 | x = inputs[1] 45 | input_size_L = input_size 46 | else 47 | x = outputs[(L - 1)] 48 | if dropout > 0 then x = nn.Dropout(dropout)(x) end -- apply dropout, if any 49 | input_size_L = rnn_size 50 | end 51 | -- GRU tick 52 | -- forward the update and reset gates 53 | local update_gate = nn.Sigmoid()(new_input_sum(input_size_L, x, prev_h)) 54 | local reset_gate = nn.Sigmoid()(new_input_sum(input_size_L, x, prev_h)) 55 | -- compute candidate hidden state 56 | local gated_hidden = nn.CMulTable()({ reset_gate, prev_h }) 57 | local p2 = nn.Linear(rnn_size, rnn_size)(gated_hidden) 58 | local p1 = nn.Linear(input_size_L, rnn_size)(x) 59 | local hidden_candidate = nn.Tanh()(nn.CAddTable()({ p1, p2 })) 60 | -- compute new interpolated hidden state, based on the update gate 61 | local zh = nn.CMulTable()({ update_gate, hidden_candidate }) 62 | local zhm1 = nn.CMulTable()({ nn.AddConstant(1, false)(nn.MulConstant(-1, false)(update_gate)), prev_h }) 63 | local next_h = nn.CAddTable()({ zh, zhm1 }) 64 | 65 | table.insert(outputs, next_h) 66 | end 67 | 68 | nngraph.annotateNodes() 69 | 70 | return nn.gModule(inputs, outputs) 71 | end 72 | 73 | return GRU 74 | -------------------------------------------------------------------------------- /code/module/GaussianNoise.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Implemented by Yannis M. Assael (www.yannisassael.com), 2016. 
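Adds zero-mean Gaussian noise with standard deviation `std` to its input during
training (optionally in place); at evaluation time, and in the backward pass,
the input and gradients are passed through unchanged.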
3 | --]] 4 | 5 | local GaussianNoise, parent = torch.class('nn.GaussianNoise', 'nn.Module') 6 | 7 | function GaussianNoise:__init(std, ip) 8 | parent.__init(self) 9 | assert(type(std) == 'number', 'input is not scalar!') 10 | self.std = std 11 | 12 | -- default for inplace is false 13 | self.inplace = ip or false 14 | if (ip and type(ip) ~= 'boolean') then 15 | error('in-place flag must be boolean') 16 | end 17 | 18 | self.noise = torch.Tensor() 19 | self.train = true 20 | end 21 | 22 | function GaussianNoise:training() 23 | self.train = true 24 | end 25 | 26 | function GaussianNoise:evaluate() 27 | self.train = false 28 | end 29 | 30 | function GaussianNoise:updateOutput(input) 31 | if self.train and self.std > 0 then 32 | -- Generate noise 33 | self.noise:resizeAs(input):normal(0, self.std) 34 | 35 | if self.inplace then 36 | input:add(self.noise) 37 | self.output:set(input) 38 | else 39 | self.output:resizeAs(input) 40 | self.output:copy(input) 41 | self.output:add(self.noise) 42 | end 43 | else 44 | if self.inplace then 45 | self.output:set(input) 46 | else 47 | self.output:resizeAs(input) 48 | self.output:copy(input) 49 | end 50 | end 51 | return self.output 52 | end 53 | 54 | function GaussianNoise:updateGradInput(input, gradOutput) 55 | if self.inplace and self.train and self.std > 0 then 56 | self.gradInput:set(gradOutput) 57 | -- restore previous input value 58 | input:add(-1, self.noise) 59 | else 60 | self.gradInput:resizeAs(gradOutput) 61 | self.gradInput:copy(gradOutput) 62 | end 63 | return self.gradInput 64 | end 65 | -------------------------------------------------------------------------------- /code/module/LSTM.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | 3 | Recurrent Batch Normalization 4 | Tim Cooijmans, Nicolas Ballas, César Laurent, Çağlar Gülçehre, Aaron Courville 5 | http://arxiv.org/abs/1603.09025 6 | 7 | Implemented by Yannis M. Assael (www.yannisassael.com), 2016. 8 | 9 | Based on 10 | https://github.com/wojciechz/learning_to_execute, 11 | https://github.com/karpathy/char-rnn/blob/master/model/LSTM.lua, 12 | and Brendan Shillingford. 
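When `bn` is true, batch normalisation (as in the paper above) is applied to the
input-to-hidden and hidden-to-hidden pre-activations and to the cell state.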
13 | 14 | Usage: 15 | local rnn = LSTM(input_size, rnn_size, n, dropout, bn) 16 | 17 | ]] -- 18 | 19 | require 'nn' 20 | require 'nngraph' 21 | 22 | local function LSTM(input_size, rnn_size, n, dropout, bn) 23 | dropout = dropout or 0 24 | 25 | -- there will be 2*n+1 inputs 26 | local inputs = {} 27 | table.insert(inputs, nn.Identity()()) -- x 28 | for L = 1, n do 29 | table.insert(inputs, nn.Identity()()) -- prev_c[L] 30 | table.insert(inputs, nn.Identity()()) -- prev_h[L] 31 | end 32 | 33 | local x, input_size_L 34 | local outputs = {} 35 | for L = 1, n do 36 | -- c,h from previos timesteps 37 | local prev_h = inputs[L * 2 + 1] 38 | local prev_c = inputs[L * 2] 39 | -- the input to this layer 40 | if L == 1 then 41 | x = inputs[1] 42 | input_size_L = input_size 43 | else 44 | x = outputs[(L - 1) * 2] 45 | if dropout > 0 then x = nn.Dropout(dropout)(x) end -- apply dropout, if any 46 | input_size_L = rnn_size 47 | end 48 | -- recurrent batch normalization 49 | -- http://arxiv.org/abs/1603.09025 50 | local bn_wx, bn_wh, bn_c 51 | if bn then 52 | bn_wx = nn.BatchNormalization(4 * rnn_size, 1e-5, 0.1, true) 53 | bn_wh = nn.BatchNormalization(4 * rnn_size, 1e-5, 0.1, true) 54 | bn_c = nn.BatchNormalization(rnn_size, 1e-5, 0.1, true) 55 | else 56 | bn_wx = nn.Identity() 57 | bn_wh = nn.Identity() 58 | bn_c = nn.Identity() 59 | end 60 | -- evaluate the input sums at once for efficiency 61 | local i2h = bn_wx(nn.Linear(input_size_L, 4 * rnn_size)(x):annotate { name = 'i2h_' .. L }):annotate { name = 'bn_wx_' .. L } 62 | local h2h = bn_wh(nn.Linear(rnn_size, 4 * rnn_size, false)(prev_h):annotate { name = 'h2h_' .. L }):annotate { name = 'bn_wh_' .. L } 63 | local all_input_sums = nn.CAddTable()({ i2h, h2h }) 64 | 65 | local reshaped = nn.Reshape(4, rnn_size)(all_input_sums) 66 | local n1, n2, n3, n4 = nn.SplitTable(2)(reshaped):split(4) 67 | -- decode the gates 68 | local in_gate = nn.Sigmoid()(n1) 69 | local forget_gate = nn.Sigmoid()(n2) 70 | local out_gate = nn.Sigmoid()(n3) 71 | -- decode the write inputs 72 | local in_transform = nn.Tanh()(n4) 73 | -- perform the LSTM update 74 | local next_c = nn.CAddTable()({ 75 | nn.CMulTable()({ forget_gate, prev_c }), 76 | nn.CMulTable()({ in_gate, in_transform }) 77 | }) 78 | -- gated cells form the output 79 | local next_h = nn.CMulTable()({ out_gate, nn.Tanh()(bn_c(next_c):annotate { name = 'bn_c_' .. L }) }) 80 | 81 | table.insert(outputs, next_c) 82 | table.insert(outputs, next_h) 83 | end 84 | 85 | nngraph.annotateNodes() 86 | 87 | return nn.gModule(inputs, outputs) 88 | end 89 | 90 | return LSTM 91 | -------------------------------------------------------------------------------- /code/module/LinearO.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Torch Linear Unit with Orthogonal Weight Initialization 3 | 4 | Exact solutions to the nonlinear dynamics of learning in deep linear neural networks 5 | http://arxiv.org/abs/1312.6120 6 | 7 | Implemented by Yannis M. Assael (www.yannisassael.com), 2016. 
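The weight matrix is initialised from the QR decompositions of two random
Gaussian matrices (a scaled product of orthogonal factors), and the bias is
initialised to zero.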
8 | 9 | ]] -- 10 | 11 | local LinearO, parent = torch.class('nn.LinearO', 'nn.Linear') 12 | 13 | function LinearO:__init(inputSize, outputSize) 14 | parent.__init(self, inputSize, outputSize) 15 | self:reset() 16 | end 17 | 18 | function LinearO:reset() 19 | local initScale = 1.1 -- math.sqrt(2) 20 | 21 | local M1 = torch.randn(self.weight:size(1), self.weight:size(1)) 22 | local M2 = torch.randn(self.weight:size(2), self.weight:size(2)) 23 | 24 | local n_min = math.min(self.weight:size(1), self.weight:size(2)) 25 | 26 | -- QR decomposition of random matrices ~ N(0, 1) 27 | local Q1, R1 = torch.qr(M1) 28 | local Q2, R2 = torch.qr(M2) 29 | 30 | self.weight:copy(Q1:narrow(2, 1, n_min) * Q2:narrow(1, 1, n_min)):mul(initScale) 31 | 32 | self.bias:zero() 33 | end 34 | 35 | -------------------------------------------------------------------------------- /code/module/Print.lua: -------------------------------------------------------------------------------- 1 | local Print, parent = torch.class('nn.Print', 'nn.Module') 2 | 3 | function Print:__init(stcFlag) 4 | parent.__init(self) 5 | end 6 | 7 | function Print:updateOutput(input) 8 | print(input) 9 | self.output = input 10 | return self.output 11 | end 12 | 13 | function Print:updateGradInput(input, gradOutput) 14 | self.gradInput = gradOutput 15 | return self.gradInput 16 | end 17 | -------------------------------------------------------------------------------- /code/module/rmsprop.lua: -------------------------------------------------------------------------------- 1 | --[[ An implementation of RMSprop 2 | 3 | ARGS: 4 | 5 | - 'opfunc' : a function that takes a single input (X), the point 6 | of a evaluation, and returns f(X) and df/dX 7 | - 'x' : the initial point 8 | - 'config` : a table with configuration parameters for the optimizer 9 | - 'config.learningRate' : learning rate 10 | - 'config.alpha' : smoothing constant 11 | - 'config.epsilon' : value with which to initialise m 12 | - 'config.weightDecay' : weight decay 13 | - 'state' : a table describing the state of the optimizer; 14 | after each call the state is modified 15 | - 'state.m' : leaky sum of squares of parameter gradients, 16 | - 'state.tmp' : and the square root (with epsilon smoothing) 17 | 18 | RETURN: 19 | - `x` : the new x vector 20 | - `f(x)` : the function, evaluated before the update 21 | 22 | ]] 23 | 24 | function optim.rmsprop(opfunc, x, config, state) 25 | -- (0) get/update state 26 | local config = config or {} 27 | local state = state or config 28 | local lr = config.learningRate or 1e-2 29 | local alpha = config.alpha or 0.99 30 | local epsilon = config.epsilon or 1e-8 31 | local wd = config.weightDecay or 0 32 | 33 | -- (1) evaluate f(x) and df/dx 34 | local fx, dfdx = opfunc(x) 35 | 36 | -- (2) weight decay 37 | if wd ~= 0 then 38 | dfdx:add(wd, x) 39 | end 40 | 41 | -- (3) initialize mean square values and square gradient storage 42 | if not state.m then 43 | -- This line kills the performance 44 | -- state.m = torch.Tensor():typeAs(x):resizeAs(dfdx):fill(1) 45 | state.m = torch.Tensor():typeAs(x):resizeAs(dfdx):zero() 46 | state.tmp = torch.Tensor():typeAs(x):resizeAs(dfdx) 47 | end 48 | 49 | -- (4) calculate new (leaky) mean squared values 50 | state.m:mul(alpha) 51 | state.m:addcmul(1.0-alpha, dfdx, dfdx) 52 | 53 | -- (5) perform update 54 | state.tmp:sqrt(state.m):add(epsilon) 55 | x:addcdiv(-lr, dfdx, state.tmp) 56 | 57 | -- return x*, f(x) before optimization 58 | return x, {fx} 59 | end 60 | 
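As a usage sketch for the optimiser above (illustrative only, assuming it is run from the repo's `code` directory so that `module.rmsprop` is on the package path; the toy objective and sizes are made up):

```
require 'torch'
require 'optim'
require 'module.rmsprop'  -- replaces optim.rmsprop with the implementation above

-- Illustrative flattened parameter vector and toy objective f(x) = 0.5 * ||x||^2
local x = torch.randn(10)
local config = { learningRate = 5e-4, alpha = 0.99, epsilon = 1e-8 }

local function opfunc(params)
    local f = 0.5 * params:dot(params)  -- loss at the current parameters
    local dfdx = params:clone()         -- gradient of the toy objective
    return f, dfdx
end

for i = 1, 100 do
    x = optim.rmsprop(opfunc, x, config)  -- returns the updated x and { f(x) }
end
```

The model files above select this optimiser through `exp.optim`, with the flattened agent parameters from `exp.getParameters` in place of this toy vector.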
-------------------------------------------------------------------------------- /code/module/rmspropm.lua: -------------------------------------------------------------------------------- 1 | -- RMSProp with momentum as found in "Generating Sequences With Recurrent Neural Networks" 2 | function optim.rmspropm(opfunc, x, config, state) 3 | -- Get state 4 | local config = config or {} 5 | local state = state or config 6 | local lr = config.learningRate or 1e-2 7 | local momentum = config.momentum or 0.95 8 | local epsilon = config.epsilon or 0.01 9 | 10 | -- Evaluate f(x) and df/dx 11 | local fx, dfdx = opfunc(x) 12 | 13 | -- Initialise storage 14 | if not state.g then 15 | state.g = torch.Tensor():typeAs(x):resizeAs(dfdx):zero() 16 | state.gSq = torch.Tensor():typeAs(x):resizeAs(dfdx):zero() 17 | state.tmp = torch.Tensor():typeAs(x):resizeAs(dfdx) 18 | end 19 | 20 | -- g = αg + (1 - α)df/dx 21 | state.g:mul(momentum):add(1 - momentum, dfdx) -- Calculate momentum 22 | -- tmp = df/dx . df/dx 23 | state.tmp:cmul(dfdx, dfdx) 24 | -- gSq = αgSq + (1 - α)df/dx 25 | state.gSq:mul(momentum):add(1 - momentum, state.tmp) -- Calculate "squared" momentum 26 | -- tmp = g . g 27 | state.tmp:cmul(state.g, state.g) 28 | -- tmp = (-tmp + gSq + ε)^0.5 29 | state.tmp:neg():add(state.gSq):add(epsilon):sqrt() 30 | 31 | -- Update x = x - lr x df/dx / tmp 32 | x:addcdiv(-lr, dfdx, state.tmp) 33 | 34 | -- Return x*, f(x) before optimisation 35 | return x, { fx } 36 | end 37 | -------------------------------------------------------------------------------- /code/results/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iassael/learning-to-communicate/1cdfc235e07be9de2af38fb23d78f61b2fa7c99b/code/results/.gitkeep -------------------------------------------------------------------------------- /code/run_colordigit-dial.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | th comm.lua \ 3 | -bs 32 \ 4 | -game ColorDigit \ 5 | -game_level extra_hard_local \ 6 | -game_nagents 2 \ 7 | -game_action_space 2 \ 8 | -game_comm_limited 0 \ 9 | -game_comm_bits 1 \ 10 | -game_comm_sigma 2 \ 11 | -nsteps 2 \ 12 | -gamma 1 \ 13 | -model_dial 1 \ 14 | -model_know_share 1 \ 15 | -model_action_aware 1 \ 16 | -model_rnn_size 128 \ 17 | -learningrate 0.0005 \ 18 | -nepisodes 20000 \ 19 | -step 100 \ 20 | -step_test 10 \ 21 | -step_target 100 \ 22 | -cuda 0 -------------------------------------------------------------------------------- /code/run_colordigit-rial.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | th comm.lua \ 3 | -bs 32 \ 4 | -game ColorDigit \ 5 | -game_level extra_hard_local \ 6 | -game_nagents 2 \ 7 | -game_action_space 2 \ 8 | -game_comm_limited 0 \ 9 | -game_comm_bits 1 \ 10 | -game_comm_sigma 2 \ 11 | -nsteps 2 \ 12 | -gamma 1 \ 13 | -model_dial 0 \ 14 | -model_know_share 1 \ 15 | -model_action_aware 1 \ 16 | -model_rnn_size 128 \ 17 | -learningrate 0.0005 \ 18 | -nepisodes 20000 \ 19 | -step 100 \ 20 | -step_test 10 \ 21 | -step_target 100 \ 22 | -cuda 0 -------------------------------------------------------------------------------- /code/run_colordigit_many_steps-dial.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | th comm.lua \ 3 | -game ColorDigit \ 4 | -game_level many_bits \ 5 | -game_nagents 2 \ 6 | -game_colors 1 \ 7 | -game_action_space 10 \ 8 | -game_comm_limited 0 \ 9 
| -game_comm_bits 1 \ 10 | -game_comm_sigma 0.5 \ 11 | -nsteps 5 \ 12 | -eps 0.05 \ 13 | -gamma 1 \ 14 | -model_dial 1 \ 15 | -model_bn 1 \ 16 | -model_know_share 1 \ 17 | -model_action_aware 1 \ 18 | -model_rnn_size 128 \ 19 | -game_vision_net mlp \ 20 | -bs 32 \ 21 | -learningrate 0.0005 \ 22 | -nepisodes 50000 \ 23 | -step 100 \ 24 | -step_test 10 \ 25 | -step_target 100 \ 26 | -cuda 0 -------------------------------------------------------------------------------- /code/run_colordigit_many_steps-rial.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | th comm.lua \ 3 | -game ColorDigit \ 4 | -game_level many_bits \ 5 | -game_nagents 2 \ 6 | -game_colors 1 \ 7 | -game_action_space 10 \ 8 | -game_comm_limited 0 \ 9 | -game_comm_bits 1 \ 10 | -game_comm_sigma 0.5 \ 11 | -nsteps 5 \ 12 | -eps 0.05 \ 13 | -gamma 1 \ 14 | -model_dial 0 \ 15 | -model_bn 1 \ 16 | -model_know_share 1 \ 17 | -model_action_aware 1 \ 18 | -model_rnn_size 128 \ 19 | -game_vision_net mlp \ 20 | -bs 32 \ 21 | -learningrate 0.0005 \ 22 | -nepisodes 50000 \ 23 | -step 100 \ 24 | -step_test 10 \ 25 | -step_target 100 \ 26 | -cuda 0 -------------------------------------------------------------------------------- /code/run_switch_3-dial.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | th comm.lua \ 3 | -game Switch \ 4 | -game_nagents 3 \ 5 | -game_action_space 2 \ 6 | -game_comm_limited 1 \ 7 | -game_comm_bits 1 \ 8 | -game_comm_sigma 2 \ 9 | -nsteps 6 \ 10 | -gamma 1 \ 11 | -model_dial 1 \ 12 | -model_bn 1 \ 13 | -model_know_share 1 \ 14 | -model_action_aware 1 \ 15 | -model_rnn_size 128 \ 16 | -bs 32 \ 17 | -learningrate 0.0005 \ 18 | -nepisodes 5000 \ 19 | -step 100 \ 20 | -step_test 10 \ 21 | -step_target 100 \ 22 | -cuda 0 23 | -------------------------------------------------------------------------------- /code/run_switch_3-rial.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | th comm.lua \ 3 | -game Switch \ 4 | -game_nagents 3 \ 5 | -game_action_space 2 \ 6 | -game_comm_limited 1 \ 7 | -game_comm_bits 1 \ 8 | -game_comm_sigma 2 \ 9 | -nsteps 6 \ 10 | -gamma 1 \ 11 | -model_dial 0 \ 12 | -model_bn 1 \ 13 | -model_know_share 1 \ 14 | -model_action_aware 1 \ 15 | -model_rnn_size 128 \ 16 | -bs 32 \ 17 | -learningrate 0.0005 \ 18 | -nepisodes 5000 \ 19 | -step 100 \ 20 | -step_test 10 \ 21 | -step_target 100 \ 22 | -cuda 0 23 | -------------------------------------------------------------------------------- /code/run_switch_4-dial.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | th comm.lua \ 3 | -game Switch \ 4 | -game_nagents 4 \ 5 | -game_action_space 2 \ 6 | -game_comm_limited 1 \ 7 | -game_comm_bits 1 \ 8 | -game_comm_sigma 2 \ 9 | -nsteps 6 \ 10 | -gamma 1 \ 11 | -model_dial 1 \ 12 | -model_bn 1 \ 13 | -model_know_share 1 \ 14 | -model_action_aware 1 \ 15 | -model_rnn_size 128 \ 16 | -bs 32 \ 17 | -learningrate 0.0005 \ 18 | -nepisodes 50000 \ 19 | -step 100 \ 20 | -step_test 10 \ 21 | -step_target 100 \ 22 | -cuda 0 23 | -------------------------------------------------------------------------------- /code/run_switch_4-rial.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | th comm.lua \ 3 | -game Switch \ 4 | -game_nagents 4 \ 5 | -game_action_space 2 \ 6 | -game_comm_limited 1 \ 7 | -game_comm_bits 1 \ 8 | -game_comm_sigma 2 \ 9 | 
-nsteps 6 \ 10 | -gamma 1 \ 11 | -model_dial 0 \ 12 | -model_bn 1 \ 13 | -model_know_share 1 \ 14 | -model_action_aware 1 \ 15 | -model_rnn_size 128 \ 16 | -bs 32 \ 17 | -learningrate 0.0005 \ 18 | -nepisodes 50000 \ 19 | -step 100 \ 20 | -step_test 10 \ 21 | -step_target 100 \ 22 | -cuda 0 23 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | 2 | # Learning to Communicate with Deep Multi-Agent Reinforcement Learning 3 | 4 | Jakob N. Foerster, Yannis M. Assael, Nando de Freitas, Shimon Whiteson 5 | 6 | ## PyTorch 7 | 8 | \- [PyTorch Implementation by @minqi](https://github.com/minqi/learning-to-communicate-pytorch) 9 | 10 | \- [Simplified PyTorch implementation in a colab by @JainMoksh](https://colab.research.google.com/gist/MJ10/2c0d1972f3dd1edcc3cd17c636aac8d2/dial.ipynb) 11 | 12 | ## Abstract 13 | 14 |
15 | [Figure: Learning to Communicate] 16 |
17 | 18 | We consider the problem of multiple agents sensing and acting in environments with the goal of maximising their shared utility. In these environments, agents must learn communication protocols in order to share information that is needed to solve the tasks. By embracing deep neural networks, we are able to demonstrate end-to-end learning of protocols in complex environments inspired by communication riddles and multi-agent computer vision problems with partial observability. We propose two approaches for learning in these domains: Reinforced Inter-Agent Learning (RIAL) and Differentiable Inter-Agent Learning (DIAL). The former uses deep Q-learning, while the latter exploits the fact that, during learning, agents can backpropagate error derivatives through (noisy) communication channels. Hence, this approach uses centralised learning but decentralised execution. Our experiments introduce new environments for studying the learning of communication protocols and present a set of engineering innovations that are essential for success in these domains. 19 | 20 | ## Links 21 | 22 | \- [PDF](https://papers.nips.cc/paper/6042-learning-to-communicate-with-deep-multi-agent-reinforcement-learning) 23 | 24 | \- [Montreal Deep Learning Summer School 2016 talk](http://videolectures.net/deeplearning2016_foerster_learning_communicate/) 25 | 26 | ## Execution 27 | ``` 28 | $ # Requirements: nvidia-docker 29 | $ # Build docker instance (takes a while) 30 | $ ./build.sh 31 | $ # Run docker instance 32 | $ ./run.sh 33 | $ # Run experiment e.g. 34 | $ ./run_switch_3-dial.sh 35 | ``` 36 | 37 | ## Bibtex 38 | @inproceedings{foerster2016learning, 39 | title={Learning to communicate with deep multi-agent reinforcement learning}, 40 | author={Foerster, Jakob and Assael, Yannis M and de Freitas, Nando and Whiteson, Shimon}, 41 | booktitle={Advances in Neural Information Processing Systems}, 42 | pages={2137--2145}, 43 | year={2016} 44 | } 45 | 46 | 47 | ## License 48 | 49 | Code licensed under the Apache License v2.0 50 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Launches a docker container using our image, and runs torch 4 | gpu=$1 5 | shift 6 | 7 | NV_GPU=$gpu nvidia-docker run --rm -ti \ 8 | -v `pwd`/code:/project \ 9 | $USER/comm \ 10 | $@ 11 | --------------------------------------------------------------------------------