├── .dockerignore
├── Dockerfile
├── LICENSE
├── build.sh
├── code
│   ├── comm.lua
│   ├── game
│   │   ├── ColorDigit.lua
│   │   └── Switch.lua
│   ├── include
│   │   ├── kwargs.lua
│   │   ├── log.lua
│   │   └── util.lua
│   ├── model
│   │   ├── ColorDigit.lua
│   │   └── Switch.lua
│   ├── module
│   │   ├── Binarize.lua
│   │   ├── GRU.lua
│   │   ├── GaussianNoise.lua
│   │   ├── LSTM.lua
│   │   ├── LinearO.lua
│   │   ├── Print.lua
│   │   ├── rmsprop.lua
│   │   └── rmspropm.lua
│   ├── results
│   │   └── .gitkeep
│   ├── run_colordigit-dial.sh
│   ├── run_colordigit-rial.sh
│   ├── run_colordigit_many_steps-dial.sh
│   ├── run_colordigit_many_steps-rial.sh
│   ├── run_switch_3-dial.sh
│   ├── run_switch_3-rial.sh
│   ├── run_switch_4-dial.sh
│   └── run_switch_4-rial.sh
├── readme.md
└── run.sh
/.dockerignore: -------------------------------------------------------------------------------- 1 | code -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:8.0-cudnn5-devel-ubuntu16.04 2 | # FROM ubuntu:16.04 3 | MAINTAINER Yannis Assael, Jakob Foerster 4 | 5 | # CUDA includes 6 | ENV CUDA_PATH /usr/local/cuda 7 | ENV CUDA_INCLUDE_PATH /usr/local/cuda/include 8 | ENV CUDA_LIBRARY_PATH /usr/local/cuda/lib64 9 | 10 | # Ubuntu Packages 11 | RUN apt-get update -y && apt-get install software-properties-common -y && \ 12 | add-apt-repository -y multiverse && apt-get update -y && apt-get upgrade -y && \ 13 | apt-get install -y apt-utils nano vim man build-essential wget sudo && \ 14 | rm -rf /var/lib/apt/lists/* 15 | 16 | # Install curl and other dependencies 17 | RUN apt-get update -y && apt-get install -y curl libssl-dev openssl libopenblas-dev \ 18 | libhdf5-dev hdf5-helpers hdf5-tools libhdf5-serial-dev libprotobuf-dev protobuf-compiler && \ 19 | curl -sk https://raw.githubusercontent.com/torch/distro/master/install-deps | bash && \ 20 | rm -rf /var/lib/apt/lists/* 21 | 22 | # Clone torch (and package) repos: 23 | RUN mkdir -p /opt && git clone https://github.com/torch/distro.git /opt/torch --recursive 24 | 25 | # Run installation script 26 | RUN cd /opt/torch && ./install.sh -b 27 | 28 | # Export environment variables manually 29 | ENV TORCH_DIR /opt/torch/pkg/torch/build/cmake-exports/ 30 | ENV LUA_PATH '/root/.luarocks/share/lua/5.1/?.lua;/root/.luarocks/share/lua/5.1/?/init.lua;/opt/torch/install/share/lua/5.1/?.lua;/opt/torch/install/share/lua/5.1/?/init.lua;./?.lua;/opt/torch/install/share/luajit-2.1.0-beta1/?.lua;/usr/local/share/lua/5.1/?.lua;/usr/local/share/lua/5.1/?/init.lua' 31 | ENV LUA_CPATH '/root/.luarocks/lib/lua/5.1/?.so;/opt/torch/install/lib/lua/5.1/?.so;./?.so;/usr/local/lib/lua/5.1/?.so;/usr/local/lib/lua/5.1/loadall.so' 32 | ENV PATH /opt/torch/install/bin:$PATH 33 | ENV LD_LIBRARY_PATH /opt/torch/install/lib:$LD_LIBRARY_PATH 34 | ENV DYLD_LIBRARY_PATH /opt/torch/install/lib:$DYLD_LIBRARY_PATH 35 | ENV LUA_CPATH '/opt/torch/install/lib/?.so;'$LUA_CPATH 36 | 37 | # Install torch packages 38 | RUN luarocks install totem && \ 39 | luarocks install https://raw.githubusercontent.com/deepmind/torch-hdf5/master/hdf5-0-0.rockspec && \ 40 | luarocks install unsup && \ 41 | luarocks install csvigo && \ 42 | luarocks install loadcaffe && \ 43 | luarocks install classic && \ 44 | luarocks install pprint && \ 45 | luarocks install class && \ 46 | luarocks install image && \ 47 | luarocks install mnist && \ 48 | luarocks install https://raw.githubusercontent.com/deepmind/torch-distributions/master/distributions-0-0.rockspec 49 | 50 | # Cleanup 51 | RUN rm -rf /var/lib/apt/lists/*
/tmp/* /var/tmp/* 52 | 53 | WORKDIR /project 54 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /build.sh: -------------------------------------------------------------------------------- 1 | nvidia-docker build -t $USER/comm . 2 | -------------------------------------------------------------------------------- /code/comm.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | 3 | Learning to Communicate with Deep Multi-Agent Reinforcement Learning 4 | 5 | @article{foerster2016learning, 6 | title={Learning to Communicate with Deep Multi-Agent Reinforcement Learning}, 7 | author={Foerster, Jakob N and Assael, Yannis M and de Freitas, Nando and Whiteson, Shimon}, 8 | journal={arXiv preprint arXiv:1605.06676}, 9 | year={2016} 10 | } 11 | 12 | ]] -- 13 | 14 | 15 | -- Configuration 16 | cmd = torch.CmdLine() 17 | cmd:text() 18 | cmd:text('Learning to Communicate with Deep Multi-Agent Reinforcement Learning') 19 | cmd:text() 20 | cmd:text('Options') 21 | 22 | -- general options: 23 | cmd:option('-seed', -1, 'initial random seed') 24 | cmd:option('-threads', 1, 'number of threads') 25 | 26 | -- gpu 27 | cmd:option('-cuda', 0, 'cuda') 28 | 29 | -- rl 30 | cmd:option('-gamma', 1, 'discount factor') 31 | cmd:option('-eps', 0.05, 'epsilon-greedy policy') 32 | 33 | -- model 34 | cmd:option('-model_rnn', 'gru', 'rnn type') 35 | cmd:option('-model_dial', 0, 'use dial (1) or rial (0)') 36 | cmd:option('-model_comm_narrow', 1, 'combines comm bits') 37 | cmd:option('-model_know_share', 1, 'knowledge sharing') 38 | cmd:option('-model_action_aware', 1, 'last action used as input') 39 | cmd:option('-model_rnn_size', 128, 'rnn size') 40 | cmd:option('-model_rnn_layers', 2, 'rnn layers') 41 | cmd:option('-model_dropout', 0, 'dropout') 42 | cmd:option('-model_bn', 1, 'batch normalisation') 43 | cmd:option('-model_target', 1, 'use a target network') 44 | cmd:option('-model_avg_q', 1, 'average q functions') 45 | 46 | -- training 47 | cmd:option('-bs', 32, 'batch size') 48 | cmd:option('-learningrate', 5e-4, 'learning rate') 49 | cmd:option('-nepisodes', 1e+6, 'number of episodes') 50 | cmd:option('-nsteps', 10, 'number of steps') 51 | 52 | cmd:option('-step', 1000, 'print stats every n episodes') 53 | cmd:option('-step_test', 10, 'test every n episodes') 54 | cmd:option('-step_target', 100, 'target network updates') 55 | 56 |
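-- The run scripts in the repository root (run_switch_3-dial.sh and friends)
-- drive these options from the command line. A hypothetical invocation, for
-- illustration only (the exact flags used by the scripts may differ):
--
--   th comm.lua -game Switch -game_nagents 3 -model_dial 1 -bs 32
--
-- Any option declared with cmd:option here can be overridden this way.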
cmd:option('-filename', '', '') 57 | 58 | -- games 59 | -- ColorDigit 60 | cmd:option('-game', 'ColorDigit', 'game name') 61 | cmd:option('-game_dim', 28, '') 62 | cmd:option('-game_bias', 0, '') 63 | cmd:option('-game_colors', 2, '') 64 | cmd:option('-game_use_mnist', 1, '') 65 | cmd:option('-game_use_digits', 0, '') 66 | cmd:option('-game_nagents', 2, '') 67 | cmd:option('-game_action_space', 2, '') 68 | cmd:option('-game_comm_limited', 0, '') 69 | cmd:option('-game_comm_bits', 1, '') 70 | cmd:option('-game_comm_sigma', 0, '') 71 | cmd:option('-game_coop', 1, '') 72 | cmd:option('-game_bottleneck', 10, '') 73 | cmd:option('-game_level', 'extra_hard', '') 74 | cmd:option('-game_vision_net', 'mlp', 'mlp or cnn') 75 | cmd:option('-nsteps', 2, 'number of steps') 76 | -- Switch 77 | cmd:option('-game', 'Switch', 'game name') 78 | cmd:option('-game_nagents', 3, '') 79 | cmd:option('-game_action_space', 2, '') 80 | cmd:option('-game_comm_limited', 1, '') 81 | cmd:option('-game_comm_bits', 2, '') 82 | cmd:option('-game_comm_sigma', 0, '') 83 | cmd:option('-nsteps', 6, 'number of steps') 84 | 85 | cmd:text() 86 | 87 | local opt = cmd:parse(arg) 88 | 89 | -- Custom options 90 | if opt.seed == -1 then opt.seed = torch.random(1000000) end 91 | opt.model_comm_narrow = opt.model_dial 92 | 93 | if opt.model_rnn == 'lstm' then 94 | opt.model_rnn_states = 2 * opt.model_rnn_layers 95 | elseif opt.model_rnn == 'gru' then 96 | opt.model_rnn_states = opt.model_rnn_layers 97 | end 98 | 99 | -- Requirements 100 | require 'nn' 101 | require 'optim' 102 | local kwargs = require 'include.kwargs' 103 | local log = require 'include.log' 104 | local util = require 'include.util' 105 | 106 | -- Set float as default type 107 | torch.manualSeed(opt.seed) 108 | torch.setnumthreads(opt.threads) 109 | torch.setdefaulttensortype('torch.FloatTensor') 110 | 111 | -- Cuda initialisation 112 | if opt.cuda == 1 then 113 | require 'cutorch' 114 | require 'cunn' 115 | cutorch.setDevice(1) 116 | opt.dtype = 'torch.CudaTensor' 117 | print(cutorch.getDeviceProperties(1)) 118 | else 119 | opt.dtype = 'torch.FloatTensor' 120 | end 121 | 122 | if opt.model_comm_narrow == 0 and opt.game_comm_bits > 0 then 123 | opt.game_comm_bits = 2 ^ opt.game_comm_bits 124 | end 125 | 126 | -- Initialise game 127 | local game = (require('game.' .. opt.game))(opt) 128 | 129 | if opt.game_comm_bits > 0 and opt.game_nagents > 1 then 130 | -- Without dial we add the communication actions to the action space 131 | opt.game_action_space_total = opt.game_action_space + opt.game_comm_bits 132 | else 133 | opt.game_action_space_total = opt.game_action_space 134 | end 135 | 136 | -- Initialise models 137 | local model = (require('model.' .. 
opt.game))(opt) 138 | 139 | -- Print options 140 | util.sprint(opt) 141 | 142 | -- Model target evaluate 143 | model.evaluate(model.agent_target) 144 | 145 | -- Get parameters 146 | local params, gradParams, params_target, _ = model.getParameters() 147 | 148 | -- Optimisation function 149 | local optim_func, optim_config = model.optim() 150 | local optim_state = {} 151 | 152 | -- Initialise agents 153 | local agent = {} 154 | for i = 1, opt.game_nagents do 155 | agent[i] = {} 156 | 157 | agent[i].id = torch.Tensor():type(opt.dtype):resize(opt.bs):fill(i) 158 | 159 | -- Populate init state 160 | agent[i].input = {} 161 | agent[i].input_target = {} 162 | agent[i].state = {} 163 | agent[i].state_target = {} 164 | agent[i].state[0] = {} 165 | agent[i].state_target[0] = {} 166 | for j = 1, opt.model_rnn_states do 167 | agent[i].state[0][j] = torch.zeros(opt.bs, opt.model_rnn_size):type(opt.dtype) 168 | agent[i].state_target[0][j] = torch.zeros(opt.bs, opt.model_rnn_size):type(opt.dtype) 169 | end 170 | 171 | agent[i].d_state = {} 172 | agent[i].d_state[0] = {} 173 | for j = 1, opt.model_rnn_states do 174 | agent[i].d_state[0][j] = torch.zeros(opt.bs, opt.model_rnn_size):type(opt.dtype) 175 | end 176 | 177 | -- Store q values 178 | agent[i].q_next_max = {} 179 | agent[i].q_comm_next_max = {} 180 | end 181 | 182 | local episode = {} 183 | 184 | -- Initialise aux vectors 185 | local d_err = torch.Tensor(opt.bs, opt.game_action_space_total):type(opt.dtype) 186 | local td_err = torch.Tensor(opt.bs):type(opt.dtype) 187 | local td_comm_err = torch.Tensor(opt.bs):type(opt.dtype) 188 | local stats = { 189 | r_episode = torch.zeros(opt.nsteps), 190 | td_err = torch.zeros(opt.step), 191 | td_comm = torch.zeros(opt.step), 192 | train_r = torch.zeros(opt.step, opt.game_nagents), 193 | steps = torch.zeros(opt.step / opt.step_test), 194 | test_r = torch.zeros(opt.step / opt.step_test, opt.game_nagents), 195 | comm_per = torch.zeros(opt.step / opt.step_test), 196 | te = torch.zeros(opt.step) 197 | } 198 | 199 | local replay = {} 200 | 201 | -- Run episode 202 | local function run_episode(opt, game, model, agent, test_mode) 203 | 204 | -- Test mode 205 | test_mode = test_mode or false 206 | 207 | -- Reset game 208 | game:reset() 209 | 210 | -- Initialise episode 211 | local step = 1 212 | local episode = { 213 | comm_per = torch.zeros(opt.bs), 214 | r = torch.zeros(opt.bs, opt.game_nagents), 215 | steps = torch.zeros(opt.bs), 216 | ended = torch.zeros(opt.bs), 217 | comm_count = 0, 218 | non_comm_count = 0 219 | } 220 | episode[step] = { 221 | s_t = game:getState(), 222 | terminal = torch.zeros(opt.bs) 223 | } 224 | if opt.game_comm_bits > 0 and opt.game_nagents > 1 then 225 | episode[step].comm = torch.zeros(opt.bs, opt.game_nagents, opt.game_comm_bits):type(opt.dtype) 226 | if opt.model_dial == 1 and opt.model_target == 1 then 227 | episode[step].comm_target = episode[step].comm:clone() 228 | end 229 | episode[step].d_comm = torch.zeros(opt.bs, opt.game_nagents, opt.game_comm_bits):type(opt.dtype) 230 | end 231 | 232 | 233 | -- Run for N steps 234 | local steps = test_mode and opt.nsteps or opt.nsteps + 1 235 | while step <= steps and episode.ended:sum() < opt.bs do 236 | 237 | -- Initialise next step 238 | episode[step + 1] = {} 239 | 240 | -- Initialise comm channel 241 | if opt.game_comm_bits > 0 and opt.game_nagents > 1 then 242 | episode[step + 1].comm = torch.zeros(opt.bs, opt.game_nagents, opt.game_comm_bits):type(opt.dtype) 243 | episode[step + 1].d_comm = torch.zeros(opt.bs, opt.game_nagents, 
opt.game_comm_bits):type(opt.dtype) 244 | if opt.model_dial == 1 and opt.model_target == 1 then 245 | episode[step + 1].comm_target = torch.zeros(opt.bs, opt.game_nagents, opt.game_comm_bits):type(opt.dtype) 246 | end 247 | end 248 | 249 | -- Forward pass 250 | episode[step].a_t = torch.zeros(opt.bs, opt.game_nagents):type(opt.dtype) 251 | if opt.model_dial == 0 then 252 | episode[step].a_comm_t = torch.zeros(opt.bs, opt.game_nagents):type(opt.dtype) 253 | end 254 | 255 | -- Iterate agents 256 | for i = 1, opt.game_nagents do 257 | agent[i].input[step] = { 258 | episode[step].s_t[i]:type(opt.dtype), 259 | agent[i].id, 260 | agent[i].state[step - 1] 261 | } 262 | 263 | -- Communication enabled 264 | if opt.game_comm_bits > 0 and opt.game_nagents > 1 then 265 | local comm_limited = game:getCommLimited(step, i) 266 | local comm = episode[step].comm:clone():type(opt.dtype) 267 | if comm_limited then 268 | -- Create limited communication channel nbits 269 | local comm_lim = torch.zeros(opt.bs, 1, opt.game_comm_bits):type(opt.dtype) 270 | for b = 1, opt.bs do 271 | if comm_limited[b] == 0 then 272 | comm_lim[{ { b } }]:zero() 273 | else 274 | comm_lim[{ { b } }] = comm[{ { b }, unpack(comm_limited[b]) }] 275 | end 276 | end 277 | table.insert(agent[i].input[step], comm_lim) 278 | else 279 | -- zero out own communication if not action aware 280 | comm[{ {}, { i } }]:zero() 281 | table.insert(agent[i].input[step], comm) 282 | end 283 | end 284 | 285 | -- Last action enabled 286 | if opt.model_action_aware == 1 then 287 | -- If comm always then use both action 288 | if opt.model_dial == 0 then 289 | local la = { torch.ones(opt.bs):type(opt.dtype), torch.ones(opt.bs):type(opt.dtype) } 290 | if step > 1 then 291 | for b = 1, opt.bs do 292 | -- Last action 293 | if episode[step - 1].a_t[b][i] > 0 then 294 | la[1][{ { b } }] = episode[step - 1].a_t[b][i] + 1 295 | end 296 | -- Last comm action 297 | if episode[step - 1].a_comm_t[b][i] > 0 then 298 | la[2][{ { b } }] = episode[step - 1].a_comm_t[b][i] - opt.game_action_space + 1 299 | end 300 | end 301 | end 302 | table.insert(agent[i].input[step], la) 303 | else 304 | -- Action aware for single a, comm action 305 | local la = torch.ones(opt.bs):type(opt.dtype) 306 | if step > 1 then 307 | for b = 1, opt.bs do 308 | if episode[step - 1].a_t[b][i] > 0 then 309 | la[{ { b } }] = episode[step - 1].a_t[b][i] + 1 310 | end 311 | end 312 | end 313 | table.insert(agent[i].input[step], la) 314 | end 315 | end 316 | 317 | -- Compute Q values 318 | local comm, state, q_t 319 | agent[i].state[step], q_t = unpack(model.agent[model.id(step, i)]:forward(agent[i].input[step])) 320 | 321 | 322 | -- If dial split out the comm values from q values 323 | if opt.model_dial == 1 then 324 | q_t, comm = DRU(q_t, test_mode) 325 | end 326 | 327 | -- Pick an action (epsilon-greedy) 328 | local action_range, action_range_comm 329 | local max_value, max_a, max_a_comm 330 | if opt.model_dial == 0 then 331 | action_range, action_range_comm = game:getActionRange(step, i) 332 | else 333 | action_range = game:getActionRange(step, i) 334 | end 335 | 336 | -- If Limited action range 337 | if action_range then 338 | agent[i].range = agent[i].range or torch.range(1, opt.game_action_space_total) 339 | max_value = torch.Tensor(opt.bs, 1) 340 | max_a = torch.zeros(opt.bs, 1) 341 | if opt.model_dial == 0 then 342 | max_a_comm = torch.zeros(opt.bs, 1) 343 | end 344 | for b = 1, opt.bs do 345 | -- If comm always fetch range for comm and actions 346 | if opt.model_dial == 0 then 347 | -- If 
action was taken 348 | if action_range[b][2][1] > 0 then 349 | local v, a = torch.max(q_t[action_range[b]], 2) 350 | max_value[b] = v:squeeze() 351 | max_a[b] = agent[i].range[{ action_range[b][2] }][a:squeeze()] 352 | end 353 | -- If comm action was taken 354 | if action_range_comm[b][2][1] > 0 then 355 | local v, a = torch.max(q_t[action_range_comm[b]], 2) 356 | max_a_comm[b] = agent[i].range[{ action_range_comm[b][2] }][a:squeeze()] 357 | end 358 | else 359 | local v, a = torch.max(q_t[action_range[b]], 2) 360 | max_a[b] = agent[i].range[{ action_range[b][2] }][a:squeeze()] 361 | end 362 | end 363 | else 364 | -- If comm always pick max_a and max_comm 365 | if opt.model_dial == 0 and opt.game_comm_bits > 0 then 366 | _, max_a = torch.max(q_t[{ {}, { 1, opt.game_action_space } }], 2) 367 | _, max_a_comm = torch.max(q_t[{ {}, { opt.game_action_space + 1, opt.game_action_space_total } }], 2) 368 | max_a_comm = max_a_comm + opt.game_action_space 369 | else 370 | _, max_a = torch.max(q_t, 2) 371 | end 372 | end 373 | 374 | -- Store actions 375 | episode[step].a_t[{ {}, { i } }] = max_a:type(opt.dtype) 376 | if opt.model_dial == 0 and opt.game_comm_bits > 0 then 377 | episode[step].a_comm_t[{ {}, { i } }] = max_a_comm:type(opt.dtype) 378 | end 379 | 380 | for b = 1, opt.bs do 381 | 382 | -- Epsilon-greedy action picking 383 | if not test_mode then 384 | if opt.model_dial == 0 then 385 | -- Random action 386 | if torch.uniform() < opt.eps then 387 | if action_range then 388 | if action_range[b][2][1] > 0 then 389 | local a_range = agent[i].range[{ action_range[b][2] }] 390 | local a_idx = torch.random(a_range:nElement()) 391 | episode[step].a_t[b][i] = agent[i].range[{ action_range[b][2] }][a_idx] 392 | end 393 | else 394 | episode[step].a_t[b][i] = torch.random(opt.game_action_space) 395 | end 396 | end 397 | 398 | -- Random communication 399 | if opt.game_comm_bits > 0 and torch.uniform() < opt.eps then 400 | if action_range then 401 | if action_range_comm[b][2][1] > 0 then 402 | local a_range = agent[i].range[{ action_range_comm[b][2] }] 403 | local a_idx = torch.random(a_range:nElement()) 404 | episode[step].a_comm_t[b][i] = agent[i].range[{ action_range_comm[b][2] }][a_idx] 405 | end 406 | else 407 | episode[step].a_comm_t[b][i] = torch.random(opt.game_action_space + 1, opt.game_action_space_total) 408 | end 409 | end 410 | 411 | else 412 | if torch.uniform() < opt.eps then 413 | if action_range then 414 | local a_range = agent[i].range[{ action_range[b][2] }] 415 | local a_idx = torch.random(a_range:nElement()) 416 | episode[step].a_t[b][i] = agent[i].range[{ action_range[b][2] }][a_idx] 417 | else 418 | episode[step].a_t[b][i] = torch.random(q_t[b]:size(1)) 419 | end 420 | end 421 | end 422 | end 423 | 424 | -- If communication action populate channel 425 | if step <= opt.nsteps then 426 | -- For dial we 'forward' the direct activation otherwise we shift the a_t into the 1-game_comm_bits range 427 | if opt.model_dial == 1 then 428 | episode[step + 1].comm[b][i] = comm[b] 429 | else 430 | local a_t = episode[step].a_comm_t[b][i] - opt.game_action_space 431 | if a_t > 0 then 432 | episode[step + 1].comm[b][{ { i }, { a_t } }] = 1 433 | end 434 | end 435 | 436 | if episode.ended[b] == 0 then 437 | episode.comm_per[{ { b } }]:add(1 / opt.game_nagents) 438 | end 439 | episode.comm_count = episode.comm_count + 1 440 | 441 | else 442 | episode.non_comm_count = episode.non_comm_count + 1 443 | end 444 | end 445 | end 446 | 447 | -- Compute reward for current state-action pair 448 | 
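-- game:step(a_t) consumes the (bs x nagents) action tensor assembled above
-- and returns a (bs x nagents) reward tensor together with a (bs) 0/1
-- terminal vector; see game/Switch.lua and game/ColorDigit.lua. A minimal
-- stub honouring that contract (an illustrative sketch, not part of the
-- repo) would be:
--
--   local Stub = class('Stub')
--   function Stub:step(a_t)
--       local reward = torch.zeros(self.opt.bs, self.opt.game_nagents)
--       local terminal = torch.zeros(self.opt.bs)
--       return reward, terminal
--   end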
episode[step].r_t, episode[step].terminal = game:step(episode[step].a_t) 449 | 450 | 451 | -- Accumulate steps (not for +1 step) 452 | if step <= opt.nsteps then 453 | for b = 1, opt.bs do 454 | if episode.ended[b] == 0 then 455 | 456 | -- Keep steps and rewards 457 | episode.steps[{ { b } }]:add(1) 458 | episode.r[{ { b } }]:add(episode[step].r_t[b]) 459 | 460 | -- Check if terminal 461 | if episode[step].terminal[b] == 1 then 462 | episode.ended[{ { b } }] = 1 463 | end 464 | end 465 | end 466 | end 467 | 468 | 469 | -- Target Network, for look-ahead 470 | if opt.model_target == 1 and not test_mode then 471 | for i = 1, opt.game_nagents do 472 | local comm = agent[i].input[step][4] 473 | 474 | if opt.game_comm_bits > 0 and opt.game_nagents > 1 and opt.model_dial == 1 then 475 | local comm_limited = game:getCommLimited(step, i) 476 | comm = episode[step].comm_target:clone():type(opt.dtype) 477 | 478 | -- Create limited communication channel nbits 479 | if comm_limited then 480 | local comm_lim = torch.zeros(opt.bs, 1, opt.game_comm_bits):type(opt.dtype) 481 | for b = 1, opt.bs do 482 | if comm_limited[b] == 0 then 483 | comm_lim[{ { b } }] = 0 484 | else 485 | comm_lim[{ { b } }] = comm[{ { b }, unpack(comm_limited[b]) }] 486 | end 487 | end 488 | comm = comm_lim 489 | else 490 | comm[{ {}, { i } }] = 0 491 | end 492 | end 493 | 494 | -- Target input 495 | agent[i].input_target[step] = { 496 | agent[i].input[step][1], 497 | agent[i].input[step][2], 498 | agent[i].state_target[step - 1], 499 | comm, 500 | agent[i].input[step][5], 501 | } 502 | 503 | -- Forward target 504 | local state, q_t_target = unpack(model.agent_target[model.id(step, i)]:forward(agent[i].input_target[step])) 505 | agent[i].state_target[step] = state 506 | if opt.model_dial == 1 then 507 | q_t_target, comm = DRU(q_t_target, test_mode) 508 | end 509 | 510 | -- Limit actions 511 | if opt.model_dial == 0 and opt.game_comm_bits > 0 then 512 | local action_range, action_range_comm = game:getActionRange(step, i) 513 | if action_range then 514 | agent[i].q_next_max[step] = torch.zeros(opt.bs):type(opt.dtype) 515 | agent[i].q_comm_next_max[step] = torch.zeros(opt.bs):type(opt.dtype) 516 | for b = 1, opt.bs do 517 | if action_range[b][2][1] > 0 then 518 | agent[i].q_next_max[step][b], _ = torch.max(q_t_target[action_range[b]], 2) 519 | else 520 | error('Not implemented') 521 | end 522 | 523 | -- If comm not available pick from None 524 | if action_range_comm[b][2][1] > 0 then 525 | agent[i].q_comm_next_max[step][b], _ = torch.max(q_t_target[action_range_comm[b]], 2) 526 | else 527 | agent[i].q_comm_next_max[step][b], _ = torch.max(q_t_target[action_range[b]], 2) 528 | end 529 | end 530 | else 531 | agent[i].q_next_max[step], _ = torch.max(q_t_target[{ {}, { 1, opt.game_action_space } }], 2) 532 | agent[i].q_comm_next_max[step], _ = torch.max(q_t_target[{ {}, { opt.game_action_space + 1, opt.game_action_space_total } }], 2) 533 | end 534 | else 535 | local action_range = game:getActionRange(step, i) 536 | if action_range then 537 | agent[i].q_next_max[step] = torch.zeros(opt.bs):type(opt.dtype) 538 | for b = 1, opt.bs do 539 | if action_range[b][2][1] > 0 then 540 | agent[i].q_next_max[step][b], _ = torch.max(q_t_target[action_range[b]], 2) 541 | end 542 | end 543 | else 544 | agent[i].q_next_max[step], _ = torch.max(q_t_target, 2) 545 | end 546 | end 547 | 548 | if opt.model_dial == 1 then 549 | for b = 1, opt.bs do 550 | episode[step + 1].comm_target[b][i] = comm[b] 551 | end 552 | end 553 | end 554 | end 555 | 556 | -- 
Forward next step 557 | step = step + 1 558 | if episode.ended:sum() < opt.bs then 559 | episode[step].s_t = game:getState() 560 | end 561 | end 562 | 563 | -- Update stats 564 | episode.nsteps = episode.steps:max() 565 | episode.comm_per:cdiv(episode.steps) 566 | 567 | return episode, agent 568 | end 569 | 570 | 571 | -- Split out the communication bits and add noise (DRU: discretise during test, regularise with noise during training). 572 | function DRU(q_t, test_mode) 573 | if opt.model_dial == 0 then error('Warning!! Should only be used in DIAL') end 574 | local bound = opt.game_action_space 575 | 576 | local q_t_n = q_t[{ {}, { 1, bound } }]:clone() 577 | local comm = q_t[{ {}, { bound + 1, opt.game_action_space_total } }]:clone() 578 | if test_mode then 579 | if opt.model_comm_narrow == 0 then 580 | local ind 581 | _, ind = torch.max(comm, 2) 582 | comm:zero() 583 | for b = 1, opt.bs do 584 | comm[b][ind[b][1]] = 20 585 | end 586 | else 587 | comm = comm:gt(0.5):type(opt.dtype):add(-0.5):mul(2 * 20) 588 | end 589 | end 590 | if opt.game_comm_sigma > 0 and not test_mode then 591 | local noise_vect = torch.randn(comm:size()):type(opt.dtype):mul(opt.game_comm_sigma) 592 | comm = comm + noise_vect 593 | end 594 | return q_t_n, comm 595 | end 596 | 597 | -- Start time 598 | local beginning_time = torch.tic() 599 | 600 | -- Iterate episodes 601 | for e = 1, opt.nepisodes do 602 | 603 | stats.e = e 604 | 605 | -- Initialise clock 606 | local time = sys.clock() 607 | 608 | -- Model training 609 | model.training(model.agent) 610 | 611 | -- Run episode 612 | episode, agent = run_episode(opt, game, model, agent) 613 | 614 | -- Rewards stats 615 | stats.train_r[(e - 1) % opt.step + 1] = episode.r:mean(1) 616 | 617 | -- Reset parameters 618 | if e == 1 then 619 | gradParams:zero() 620 | end 621 | 622 | -- Backward pass 623 | local step_back = 1 624 | for step = episode.nsteps, 1, -1 do 625 | stats.td_err[(e - 1) % opt.step + 1] = 0 626 | stats.td_comm[(e - 1) % opt.step + 1] = 0 627 | 628 | -- Iterate agents 629 | for i = 1, opt.game_nagents do 630 | 631 | -- Compute Q values 632 | local state, q_t = unpack(model.agent[model.id(step, i)].output) 633 | 634 | -- Compute td error 635 | td_err:zero() 636 | td_comm_err:zero() 637 | d_err:zero() 638 | 639 | for b = 1, opt.bs do 640 | if step >= episode.steps[b] then 641 | -- on the first backward step, zero-initialise the RNN state gradient 642 | for j = 1, opt.model_rnn_states do 643 | agent[i].d_state[step_back - 1][j][b]:zero() 644 | end 645 | end 646 | 647 | if step <= episode.steps[b] then 648 | 649 | -- if terminal state or end state => no future rewards 650 | if episode[step].a_t[b][i] > 0 then 651 | if episode[step].terminal[b] == 1 then 652 | td_err[b] = episode[step].r_t[b][i] - q_t[b][episode[step].a_t[b][i]] 653 | else 654 | local q_next_max 655 | if opt.model_avg_q == 1 and opt.model_dial == 0 and episode[step].a_comm_t[b][i] > 0 then 656 | q_next_max = (agent[i].q_next_max[step + 1]:squeeze() + agent[i].q_comm_next_max[step + 1]:squeeze()) / 2 657 | else 658 | q_next_max = agent[i].q_next_max[step + 1]:squeeze() 659 | end 660 | td_err[b] = episode[step].r_t[b][i] + opt.gamma * q_next_max[b] - q_t[b][episode[step].a_t[b][i]] 661 | end 662 | d_err[{ { b }, { episode[step].a_t[b][i] } }] = -td_err[b] 663 | 664 | else 665 | error('Error!') 666 | end 667 | 668 | -- Delta Q for communication 669 | if opt.model_dial == 0 then 670 | if episode[step].a_comm_t[b][i] > 0 then 671 | if episode[step].terminal[b] == 1 then 672 | td_comm_err[b] = episode[step].r_t[b][i] - q_t[b][episode[step].a_comm_t[b][i]] 673 | else 674 | local q_next_max 675 | if opt.model_avg_q == 1 and episode[step].a_t[b][i] > 0 then 676 | q_next_max = (agent[i].q_next_max[step + 1]:squeeze() + agent[i].q_comm_next_max[step + 1]:squeeze()) / 2 677 | else 678 | q_next_max = agent[i].q_comm_next_max[step + 1]:squeeze() 679 | end 680 | td_comm_err[b] = episode[step].r_t[b][i] + opt.gamma * q_next_max[b] - q_t[b][episode[step].a_comm_t[b][i]] 681 | end 682 | d_err[{ { b }, { episode[step].a_comm_t[b][i] } }] = -td_comm_err[b] 683 | end 684 | end 685 | 686 | -- If we use dial and the message was sent before the final step, then we receive incoming derivatives for it 687 | if opt.model_dial == 1 and step < episode.steps[b] then 688 | -- Derivatives with respect to agent_i's message are stored in d_comm[b][i] 689 | local bound = opt.game_action_space 690 | d_err[{ { b }, { bound + 1, opt.game_action_space_total } }]:add(episode[step + 1].d_comm[b][i]) 691 | end 692 | end 693 | end 694 | 695 | -- Track td-err 696 | stats.td_err[(e - 1) % opt.step + 1] = stats.td_err[(e - 1) % opt.step + 1] + 0.5 * td_err:clone():pow(2):mean() 697 | if opt.model_dial == 0 then 698 | stats.td_comm[(e - 1) % opt.step + 1] = stats.td_comm[(e - 1) % opt.step + 1] + 0.5 * td_comm_err:clone():pow(2):mean() 699 | end 700 | 701 | -- Track the amplitude of the dial-derivatives 702 | if opt.model_dial == 1 then 703 | local bound = opt.game_action_space 704 | stats.td_comm[(e - 1) % opt.step + 1] = stats.td_comm[(e - 1) % opt.step + 1] + 0.5 * d_err[{ {}, { bound + 1, opt.game_action_space_total } }]:clone():pow(2):mean() 705 | end 706 | 707 | -- Backward pass 708 | local grad = model.agent[model.id(step, i)]:backward(agent[i].input[step], { 709 | agent[i].d_state[step_back - 1], 710 | d_err 711 | }) 712 | 713 | -- 'state' is the 3rd input, so we can extract d_state 714 | agent[i].d_state[step_back] = grad[3] 715 | 716 | -- For dial we need to add the derivatives with respect to the incoming messages to the d_comm tracker 717 | if opt.model_dial == 1 then 718 | local comm_limited = game:getCommLimited(step, i) 719 | local comm_grad = grad[4] 720 | 721 | if comm_limited then 722 | for b = 1, opt.bs do 723 | -- Agent could only receive the message if they were active 724 | if comm_limited[b] ~= 0 then 725 | episode[step].d_comm[{ { b }, unpack(comm_limited[b]) }]:add(comm_grad[b]) 726 | end 727 | end 728 | else 729 | -- zero out own communication unless it's part of the switch riddle 730 | comm_grad[{ {}, { i } }]:zero() 731 | episode[step].d_comm:add(comm_grad) 732 | end 733 | end 734 | end 735 | 736 | -- Count backward steps 737 | step_back = step_back + 1 738 | end 739 | 740 | -- Update gradients 741 | local feval = function(x) 742 | 743 | -- Normalise Gradients 744 | gradParams:div(opt.game_nagents * opt.bs) 745 | 746 | -- Clip Gradients 747 | gradParams:clamp(-10, 10) 748 | 749 | return nil, gradParams 750 | end 751 | 752 | optim_func(feval, params, optim_config, optim_state) 753 | 754 | -- Gradient statistics 755 | if e % opt.step == 0 then 756 | stats.grad_norm = gradParams:norm() / gradParams:nElement() * 1000 757 | end 758 | 759 | -- Reset parameters 760 | gradParams:zero() 761 | 762 | -- Update target network 763 | if e % opt.step_target == 0 then 764 | params_target:copy(params) 765 | end 766 | 767 | -- Test 768 | if e % opt.step_test == 0 then 769 | local test_idx = (e / opt.step_test - 1) % (opt.step / opt.step_test) + 1 770 | 771 | local episode, _ = run_episode(opt, game, model, agent, true) 772 | stats.test_r[test_idx] = episode.r:mean(1) 773 |
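-- test_idx cycles through opt.step / opt.step_test slots, so these test
-- statistics act as ring buffers over the most recent window of episodes;
-- comm_per below is the fraction of forward passes that wrote to the
-- communication channel during the greedy test run.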
stats.steps[test_idx] = episode.steps:mean() 774 | stats.comm_per[test_idx] = episode.comm_count / (episode.comm_count + episode.non_comm_count) 775 | end 776 | 777 | -- Compute statistics 778 | stats.te[(e - 1) % opt.step + 1] = sys.clock() - time 779 | 780 | if e == opt.step then 781 | stats.td_err_avg = stats.td_err:mean() 782 | stats.td_comm_avg = stats.td_comm:mean() 783 | stats.train_r_avg = stats.train_r:mean(1) 784 | stats.test_r_avg = stats.test_r:mean(1) 785 | stats.steps_avg = stats.steps:mean() 786 | stats.comm_per_avg = stats.comm_per:mean() 787 | stats.te_avg = stats.te:mean() 788 | elseif e % opt.step == 0 then 789 | local coef = 0.9 790 | stats.td_err_avg = stats.td_err_avg * coef + stats.td_err:mean() * (1 - coef) 791 | stats.td_comm_avg = stats.td_comm_avg * coef + stats.td_comm:mean() * (1 - coef) 792 | stats.train_r_avg = stats.train_r_avg * coef + stats.train_r:mean(1) * (1 - coef) 793 | stats.test_r_avg = stats.test_r_avg * coef + stats.test_r:mean(1) * (1 - coef) 794 | stats.steps_avg = stats.steps_avg * coef + stats.steps:mean() * (1 - coef) 795 | stats.comm_per_avg = stats.comm_per_avg * coef + stats.comm_per:mean() * (1 - coef) 796 | stats.te_avg = stats.te_avg * coef + stats.te:mean() * (1 - coef) 797 | end 798 | 799 | -- Print statistics 800 | if e % opt.step == 0 then 801 | log.infof('e=%d, td_err=%.3f, td_err_avg=%.3f, td_comm=%.3f, td_comm_avg=%.3f, tr_r=%.2f, tr_r_avg=%.2f, te_r=%.2f, te_r_avg=%.2f, st=%.1f, comm=%.1f%%, grad=%.3f, t/s=%.2f s, t=%d m', 802 | stats.e, 803 | stats.td_err:mean(), 804 | stats.td_err_avg, 805 | stats.td_comm:mean(), 806 | stats.td_comm_avg, 807 | stats.train_r:mean(), 808 | stats.train_r_avg:mean(), 809 | stats.test_r:mean(), 810 | stats.test_r_avg:mean(), 811 | stats.steps_avg, 812 | stats.comm_per_avg * 100, 813 | stats.grad_norm, 814 | stats.te_avg * opt.step, 815 | torch.toc(beginning_time) / 60) 816 | 817 | collectgarbage() 818 | end 819 | 820 | -- run model specific statistics 821 | model.stats(opt, game, stats, e) 822 | 823 | -- save the model and statistics 824 | model.save(opt, stats, model) 825 | end 826 | -------------------------------------------------------------------------------- /code/game/ColorDigit.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | local class = require 'class' 3 | 4 | local log = require 'include.log' 5 | local kwargs = require 'include.kwargs' 6 | local util = require 'include.util' 7 | 8 | 9 | local ColorDigit = class('ColorDigit') 10 | 11 | function ColorDigit:__init(opt) 12 | self.opt = opt 13 | 14 | self.step_counter = 1 15 | -- Preprocess data 16 | local dataset = (require 'mnist').traindataset() 17 | local data = {} 18 | local lookup = {} 19 | for i = 1, dataset.size do 20 | -- Shift labels by one so digit 0 maps to class 1 21 | local y = dataset[i].y + 1 22 | -- Create array 23 | if not data[y] then 24 | data[y] = {} 25 | end 26 | -- Move data 27 | data[y][#data[y] + 1] = { 28 | x = dataset[i].x, 29 | y = y 30 | } 31 | lookup[i] = y 32 | end 33 | 34 | self.mnist = data 35 | 36 | self.mnist.lookup = lookup 37 | 38 | -- Rewards 39 | self.reward = torch.zeros(self.opt.bs) 40 | self.terminal = torch.zeros(self.opt.bs) 41 | 42 | -- Spawn new game 43 | self:reset() 44 | end 45 | 46 | function ColorDigit:loadDigit() 47 | -- Pick random digit and color 48 | local color_id = torch.zeros(self.opt.bs) 49 | local number = torch.zeros(self.opt.bs) 50 | local x = torch.zeros(self.opt.bs, self.opt.game_colors, self.opt.game_dim, self.opt.game_dim):type(self.opt.dtype) 51 |
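-- x is laid out as (bs x game_colors x game_dim x game_dim); the loop below
-- writes each sampled digit image into the channel picked by its random
-- colour, so colour is encoded purely by the channel index.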
for b = 1, self.opt.bs do 52 | -- Pick number 53 | local num 54 | if self.opt.game_use_mnist == 1 then 55 | local index = torch.random(#self.mnist.lookup) 56 | num = self.mnist.lookup[index] 57 | elseif torch.uniform() < self.opt.game_bias then 58 | num = 1 59 | else 60 | num = torch.random(10) 61 | end 62 | 63 | number[b] = num 64 | 65 | -- Pick color 66 | color_id[b] = torch.random(self.opt.game_colors) 67 | 68 | -- Pick dataset id 69 | local id = torch.random(#self.mnist[num]) 70 | x[b][color_id[b]] = self.mnist[num][id].x 71 | end 72 | return { x, color_id, number } 73 | end 74 | 75 | 76 | function ColorDigit:reset() 77 | 78 | -- Load images 79 | self.state = { self:loadDigit(), self:loadDigit() } 80 | 81 | -- Reset rewards 82 | self.reward:zero() 83 | self.terminal:zero() 84 | 85 | -- Reset counter 86 | self.step_counter = 1 87 | 88 | return self 89 | end 90 | 91 | function ColorDigit:getActionRange() 92 | return nil 93 | end 94 | 95 | function ColorDigit:getCommLimited() 96 | return nil 97 | end 98 | 99 | function ColorDigit:getReward(a) 100 | 101 | local color_1 = self.state[1][2] 102 | local color_2 = self.state[2][2] 103 | local digit_1 = self.state[1][3] 104 | local digit_2 = self.state[2][3] 105 | 106 | local reward = torch.zeros(self.opt.bs, self.opt.game_nagents) 107 | 108 | for b = 1, self.opt.bs do 109 | if self.opt.game_level == "extra_hard_local" then 110 | if a[b][2] <= self.opt.game_action_space and self.step_counter > 1 then 111 | reward[b] = 2 * (-1) ^ (digit_1[b] + a[b][2] + color_2[b]) + (-1) ^ (digit_2[b] + a[b][2] + color_1[b]) 112 | end 113 | if a[b][1] <= self.opt.game_action_space and self.step_counter > 1 then 114 | reward[b] = reward[b] + 2 * (-1) ^ (digit_2[b] + a[b][1] + color_1[b]) + (-1) ^ (digit_1[b] + a[b][1] + color_2[b]) 115 | end 116 | elseif self.opt.game_level == "many_bits" then 117 | if a[b][1] <= self.opt.game_action_space and self.step_counter == self.opt.nsteps then 118 | if digit_2[b] == a[b][1] then 119 | reward[b] = reward[b] + 0.5 120 | end 121 | end 122 | 123 | if a[b][2] <= self.opt.game_action_space and self.step_counter == self.opt.nsteps then 124 | if digit_1[b] == a[b][2] then 125 | reward[b] = reward[b] + 0.5 126 | end 127 | end 128 | else 129 | error("[ColorDigit] wrong level") 130 | end 131 | end 132 | 133 | local reward_coop = torch.zeros(self.opt.bs, self.opt.game_nagents) 134 | reward_coop[{ {}, { 2 } }] = (reward[{ {}, { 2 } }] + reward[{ {}, { 1 } }] * self.opt.game_coop) / (1 + self.opt.game_coop) 135 | reward_coop[{ {}, { 1 } }] = (reward[{ {}, { 1 } }] + reward[{ {}, { 2 } }] * self.opt.game_coop) / (1 + self.opt.game_coop) 136 | 137 | return reward_coop 138 | end 139 | 140 | function ColorDigit:step(a) 141 | local reward, terminal 142 | 143 | reward = self:getReward(a) 144 | 145 | if self.step_counter == self.opt.nsteps then 146 | self.terminal:fill(1) 147 | end 148 | 149 | self.step_counter = self.step_counter + 1 150 | 151 | return reward, self.terminal:clone() 152 | end 153 | 154 | 155 | function ColorDigit:getState() 156 | if self.opt.game_use_digits == 1 then 157 | return { self.state[1][3], self.state[2][3] } 158 | else 159 | return { self.state[1][1], self.state[2][1] } 160 | end 161 | end 162 | 163 | return ColorDigit 164 | 165 | -------------------------------------------------------------------------------- /code/game/Switch.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | local class = require 'class' 3 | 4 | local log = require 'include.log' 5 | local 
kwargs = require 'include.kwargs' 6 | local util = require 'include.util' 7 | 8 | 9 | local Switch = class('Switch') 10 | 11 | -- Actions (with the default game_action_space = 2) 12 | -- 1 = none 13 | -- 2 = tell, i.e. claim that every agent has already visited the room 14 | 15 | 16 | 17 | function Switch:__init(opt) 18 | local opt_game = kwargs(_, { 19 | { 'game_action_space', type = 'int-pos', default = 2 }, 20 | { 'game_reward_shift', type = 'int', default = 0 }, 21 | { 'game_comm_bits', type = 'int', default = 0 }, 22 | { 'game_comm_sigma', type = 'number', default = 2 }, 23 | }) 24 | 25 | -- Steps max override 26 | opt.nsteps = 4 * opt.game_nagents - 6 27 | 28 | for k, v in pairs(opt_game) do 29 | if not opt[k] then 30 | opt[k] = v 31 | end 32 | end 33 | self.opt = opt 34 | 35 | -- Rewards 36 | self.reward_all_live = 1 + self.opt.game_reward_shift 37 | self.reward_all_die = -1 + self.opt.game_reward_shift 38 | 39 | -- Spawn new game 40 | self:reset() 41 | end 42 | 43 | function Switch:reset() 44 | 45 | -- Reset rewards 46 | self.reward = torch.zeros(self.opt.bs, self.opt.game_nagents) 47 | 48 | -- Has been in 49 | self.has_been = torch.zeros(self.opt.bs, self.opt.nsteps, self.opt.game_nagents) 50 | 51 | -- Reached end 52 | self.terminal = torch.zeros(self.opt.bs) 53 | 54 | -- Step counter 55 | self.step_counter = 1 56 | 57 | -- Who is in 58 | self.active_agent = torch.zeros(self.opt.bs, self.opt.nsteps) 59 | for b = 1, self.opt.bs do 60 | for step = 1, self.opt.nsteps do 61 | local id = torch.random(self.opt.game_nagents) 62 | self.active_agent[{ { b }, { step } }] = id 63 | self.has_been[{ { b }, { step }, { id } }] = 1 64 | end 65 | end 66 | 67 | return self 68 | end 69 | 70 | function Switch:getActionRange(step, agent) 71 | local range = {} 72 | if self.opt.model_dial == 1 then 73 | local bound = self.opt.game_action_space 74 | 75 | for i = 1, self.opt.bs do 76 | if self.active_agent[i][step] == agent then 77 | range[i] = { { i }, { 1, bound } } 78 | else 79 | range[i] = { { i }, { 1 } } 80 | end 81 | end 82 | return range 83 | else 84 | local comm_range = {} 85 | for i = 1, self.opt.bs do 86 | if self.active_agent[i][step] == agent then 87 | range[i] = { { i }, { 1, self.opt.game_action_space } } 88 | comm_range[i] = { { i }, { self.opt.game_action_space + 1, self.opt.game_action_space_total } } 89 | else 90 | range[i] = { { i }, { 1 } } 91 | comm_range[i] = { { i }, { 0, 0 } } 92 | end 93 | end 94 | return range, comm_range 95 | end 96 | end 97 | 98 | 99 | function Switch:getCommLimited(step, i) 100 | if self.opt.game_comm_limited then 101 | 102 | local range = {} 103 | 104 | -- Get range per batch 105 | for b = 1, self.opt.bs do 106 | -- if agent is active read from field of previous agent 107 | if step > 1 and i == self.active_agent[b][step] then 108 | range[b] = { self.active_agent[b][step - 1], {} } 109 | else 110 | range[b] = 0 111 | end 112 | end 113 | return range 114 | else 115 | return nil 116 | end 117 | end 118 | 119 | function Switch:getReward(a_t) 120 | 121 | for b = 1, self.opt.bs do 122 | local active_agent = self.active_agent[b][self.step_counter] 123 | if (a_t[b][active_agent] == 2 and self.terminal[b] == 0) then 124 | local has_been = self.has_been[{ { b }, { 1, self.step_counter }, {} }]:sum(2):squeeze(2):gt(0):float():sum() 125 | if has_been == self.opt.game_nagents then 126 | self.reward[b] = self.reward_all_live 127 | else 128 | self.reward[b] = self.reward_all_die 129 | end 130 | self.terminal[b] = 1 131 | elseif self.step_counter == self.opt.nsteps and self.terminal[b] == 0 then 132 | self.terminal[b] = 1 133 | end 134 | end 135 |
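-- Worked example with the defaults (3 agents, game_action_space = 2): when
-- the agent in the room picks action 2 ("tell") and all agents have visited
-- the room at least once, every agent in that batch entry gets
-- reward_all_live (+1); if someone has not yet visited, they get
-- reward_all_die (-1). If nobody ever tells, the episode times out after
-- nsteps = 4 * nagents - 6 = 6 steps with zero reward.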
136 | return self.reward:clone(), self.terminal:clone() 137 | end 138 | 139 | function Switch:step(a_t) 140 | 141 | -- Get rewards 142 | local reward, terminal = self:getReward(a_t) 143 | 144 | -- Make step 145 | self.step_counter = self.step_counter + 1 146 | 147 | return reward, terminal 148 | end 149 | 150 | 151 | function Switch:getState() 152 | local state = {} 153 | 154 | for agent = 1, self.opt.game_nagents do 155 | state[agent] = torch.Tensor(self.opt.bs) 156 | 157 | for b = 1, self.opt.bs do 158 | if self.active_agent[b][self.step_counter] == agent then 159 | state[agent][{ { b } }] = 1 160 | else 161 | state[agent][{ { b } }] = 2 162 | end 163 | end 164 | end 165 | 166 | return state 167 | end 168 | 169 | return Switch 170 | 171 | -------------------------------------------------------------------------------- /code/include/kwargs.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | 3 | Based on the code of Brendan Shillingford [bitbucket.org/bshillingford/nnob](https://bitbucket.org/bshillingford/nnob). 4 | 5 | Argument type checker for keyword arguments, i.e. arguments 6 | specified as a key-value table to constructors/functions. 7 | 8 | ## Valid typespecs: 9 | - `"number"` 10 | - `"string"` 11 | - `"boolean"` 12 | - `"function"` 13 | - `"table"` 14 | - `"tensor"` 15 | - torch class (not a string) like nn.Module 16 | - `"class:x"` 17 | - e.g. x=`nn.Module`; can be comma-separated to OR them 18 | - specific `torch.*Tensor` 19 | - specific `torch.*Storage` 20 | - `"int"`: number that is integer 21 | - `"int-pos"`: integer, and `> 0` 22 | - `"int-nonneg"`: integer, and `>= 0` 23 | --]] 24 | require 'torch' 25 | local math = require 'math' 26 | 27 | local function assert_type(val, typespec, argname) 28 | local typename = torch.typename(val) or type(val) 29 | -- handles number, boolean, string, table; but passes through if needed 30 | if typespec == typename then return true end 31 | 32 | -- try to parse, nil if no match 33 | local classnames = type(typespec) == 'string' and string.match(typespec, 'class: *(.*)') 34 | 35 | -- isTypeOf for table-typed specs (see below for class:x version) 36 | if type(typespec) == 'table' and typespec.__typename then 37 | if torch.isTypeOf(val, typespec) then 38 | return true 39 | else 40 | error(string.format('argument %s should be instance of %s, but is type %s', 41 | argname, typespec.__typename, typename)) 42 | end 43 | elseif typespec == 'tensor' then 44 | if torch.isTensor(val) then return true end 45 | elseif classnames then 46 | for _, classname in pairs(string.split(classnames, ' *, *')) do 47 | if torch.isTypeOf(val, classname) then return true end 48 | end 49 | elseif typespec == 'int' or typespec == 'integer' then 50 | if math.floor(val) == val then return true end 51 | elseif typespec == 'int-pos' then 52 | if val > 0 and math.floor(val) == val then return true end 53 | elseif typespec == 'int-nonneg' then 54 | if val >= 0 and math.floor(val) == val then return true end 55 | else 56 | error('invalid type spec (' .. tostring(typespec) .. ') for arg ' ..
argname) 57 | end 58 | error(string.format('argument %s must be of type %s, given type %s', 59 | argname, typespec, typename)) 60 | end 61 | 62 | return function(args, settings) 63 | local result = {} 64 | local unprocessed = {} 65 | 66 | if not args then 67 | args = {} 68 | end 69 | 70 | if type(args) ~= 'table' then 71 | error('args must be non-nil and must be a table') 72 | end 73 | 74 | for k, _ in pairs(args) do 75 | unprocessed[k] = true 76 | end 77 | 78 | -- Use ipairs, so we skip named settings 79 | for _, setting in ipairs(settings) do 80 | -- allow name to either be the only non-named element 81 | -- e.g. {'name', type='...'}, or named 82 | local name = setting.name or setting[1] 83 | 84 | -- get value or default 85 | local val 86 | if args[name] ~= nil then 87 | val = args[name] 88 | elseif setting.default ~= nil then 89 | val = setting.default 90 | elseif not setting.optional then 91 | error('required argument: ' .. name) 92 | end 93 | -- check types 94 | if val ~= nil and not setting.optional and setting.type ~= nil then 95 | assert_type(val, setting.type, name) 96 | end 97 | 98 | result[name] = val 99 | unprocessed[name] = nil 100 | end 101 | 102 | if settings.ignore_extras then 103 | -- unprocessed is keyed by argument name, so iterate over the keys 104 | for name in pairs(unprocessed) do 105 | result[name] = args[name] 106 | end 107 | elseif next(unprocessed) ~= nil then 108 | -- the length operator ignores hash keys, so use next() and collect the names 109 | local names = {} 110 | for name in pairs(unprocessed) do names[#names + 1] = name end 111 | error('extra unprocessed arguments: ' .. table.concat(names, ', ')) 112 | end 113 | return result 114 | end -------------------------------------------------------------------------------- /code/include/log.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | 3 | Based on the code of Brendan Shillingford [bitbucket.org/bshillingford/nnob](https://bitbucket.org/bshillingford/nnob). 4 | 5 | Based on [github.com/rxi/log.lua](https://github.com/rxi/log.lua/commit/93cfbe0c91bced6d3d061a58e7129979441eb200). 6 | 7 | Usage: 8 | ``` 9 | local log = require 'nnob.log' 10 | 11 | -- ... 12 | log.info('hello world!') 13 | log.infof('some number: %3.3f', 123) 14 | ``` 15 | 16 | Pasted from README.md of `github.com/rxi/log.lua`: 17 | # log.lua 18 | A tiny logging module for Lua. 19 | 20 |  21 | 22 | 23 | ## Usage 24 | log.lua provides 6 functions, each function takes all its arguments, 25 | concatenates them into a string then outputs the string to the console and -- 26 | if one is set -- the log file: 27 | 28 | * **log.trace(...)** 29 | * **log.debug(...)** 30 | * **log.info(...)** 31 | * **log.warn(...)** 32 | * **log.error(...)** 33 | * **log.fatal(...)** 34 | 35 | 36 | ### Additional options 37 | log.lua provides variables for setting additional options: 38 | 39 | #### log.usecolor 40 | Whether colors should be used when outputting to the console, this is `true` by 41 | default. If you're using a console which does not support ANSI color escape 42 | codes then this should be disabled. 43 | 44 | #### log.outfile 45 | The name of the file where the log should be written, log files do not contain 46 | ANSI colors and always use the full date rather than just the time. By default 47 | `log.outfile` is `nil` (no log file is used). If a file which does not exist is 48 | set as the `log.outfile` then it is created on the first message logged. If the 49 | file already exists it is appended to. 50 | 51 | #### log.level 52 | The minimum level to log, any logging function called with a lower level than 53 | the `log.level` is ignored and no text is outputted or written. By default this
By default this 54 | value is set to `"trace"`, the lowest log level, such that no log messages are 55 | ignored. 56 | 57 | The level of each log mode, starting with the lowest log level is as follows: 58 | `"trace"` `"debug"` `"info"` `"warn"` `"error"` `"fatal"` 59 | 60 | 61 | ## License 62 | This library is free software; you can redistribute it and/or modify it under 63 | the terms of the MIT license. See [LICENSE](LICENSE) for details. 64 | --]] 65 | 66 | 67 | 68 | -- 69 | -- log.lua 70 | -- 71 | -- Copyright (c) 2016 rxi 72 | -- 73 | -- This library is free software; you can redistribute it and/or modify it 74 | -- under the terms of the MIT license. See LICENSE for details. 75 | -- 76 | 77 | local log = { _version = "0.1.0" } 78 | 79 | log.usecolor = true 80 | log.outfile = nil 81 | log.level = "trace" 82 | 83 | 84 | local modes = { 85 | { name = "trace", color = "\27[34m", }, 86 | { name = "debug", color = "\27[36m", }, 87 | { name = "info", color = "\27[32m", }, 88 | { name = "warn", color = "\27[33m", }, 89 | { name = "error", color = "\27[31m", }, 90 | { name = "fatal", color = "\27[35m", }, 91 | } 92 | 93 | 94 | local levels = {} 95 | for i, v in ipairs(modes) do 96 | levels[v.name] = i 97 | end 98 | 99 | 100 | local round = function(x, increment) 101 | increment = increment or 1 102 | x = x / increment 103 | return (x > 0 and math.floor(x + .5) or math.ceil(x - .5)) * increment 104 | end 105 | 106 | 107 | local _tostring = tostring 108 | 109 | local tostring = function(...) 110 | local t = {} 111 | for i = 1, select('#', ...) do 112 | local x = select(i, ...) 113 | if type(x) == "number" then 114 | x = round(x, .01) 115 | end 116 | t[#t + 1] = _tostring(x) 117 | end 118 | return table.concat(t, " ") 119 | end 120 | 121 | 122 | for i, x in ipairs(modes) do 123 | local nameupper = x.name:upper() 124 | log[x.name] = function(...) 125 | 126 | -- Return early if we're below the log level 127 | if i < levels[log.level] then 128 | return 129 | end 130 | 131 | local msg = tostring(...) 132 | local info = debug.getinfo(2, "Sl") 133 | local lineinfo = info.short_src .. ":" .. info.currentline 134 | 135 | -- Output to console 136 | print(string.format("%s[%-6s%s]%s %s: %s", 137 | log.usecolor and x.color or "", 138 | nameupper, 139 | os.date("%H:%M:%S"), 140 | log.usecolor and "\27[0m" or "", 141 | lineinfo, 142 | msg)) 143 | 144 | -- Output to log file 145 | if log.outfile then 146 | local fp = io.open(log.outfile, "a") 147 | local str = string.format("[%-6s%s] %s: %s\n", 148 | nameupper, os.date(), lineinfo, msg) 149 | fp:write(str) 150 | fp:close() 151 | end 152 | end 153 | 154 | -- bshillingford: add formatted versions, 155 | -- e.g. log.infof(...) as alias for log.info(string.format(...) 156 | log[x.name .. 'f'] = function(...) 157 | -- Return early if we're below the log level 158 | if i < levels[log.level] then 159 | return 160 | end 161 | 162 | local fmt = ... -- i.e. first arg; note: select(2, ...) gets everything after 163 | 164 | local info = debug.getinfo(2, "Sl") 165 | local lineinfo = info.short_src .. ":" .. info.currentline 166 | 167 | -- Output to console 168 | print(string.format("%s[%-6s%s]%s %s: " .. fmt, 169 | log.usecolor and x.color or "", 170 | nameupper, 171 | os.date("%H:%M:%S"), 172 | log.usecolor and "\27[0m" or "", 173 | lineinfo, 174 | select(2, ...))) 175 | 176 | -- Output to log file 177 | if log.outfile then 178 | local fp = io.open(log.outfile, "a") 179 | local str = string.format("[%-6s%s] %s: " .. fmt .. 
"\n", 180 | nameupper, os.date(), lineinfo, select(2, ...)) 181 | fp:write(str) 182 | fp:close() 183 | end 184 | end 185 | end 186 | 187 | 188 | return log 189 | -------------------------------------------------------------------------------- /code/include/util.lua: -------------------------------------------------------------------------------- 1 | local util = {} 2 | local log = require 'include.log' 3 | 4 | function util.euclidean_dist(p, q) 5 | return math.sqrt(util.euclidean_dist2(p, q)) 6 | end 7 | 8 | function util.euclidean_dist2(p, q) 9 | assert(#p == #q, 'vectors must have the same length') 10 | local sum = 0 11 | for i in ipairs(p) do 12 | sum = sum + (p[i] - q[i]) ^ 2 13 | end 14 | return sum 15 | end 16 | 17 | function util.resetParams(cur_module) 18 | if cur_module.modules then 19 | for i, module in ipairs(cur_module.modules) do 20 | util.resetParams(module) 21 | end 22 | else 23 | cur_module:reset() 24 | end 25 | 26 | return cur_module 27 | end 28 | 29 | function util.dc(orig) 30 | local orig_type = torch.type(orig) 31 | local copy 32 | if orig_type == 'table' then 33 | copy = {} 34 | for orig_key, orig_value in next, orig, nil do 35 | copy[util.dc(orig_key)] = util.dc(orig_value) 36 | end 37 | setmetatable(copy, util.dc(getmetatable(orig))) 38 | elseif orig_type == 'torch.FloatTensor' or orig_type == 'torch.DoubleTensor' or orig_type == 'torch.CudaTensor' then 39 | -- Torch tensor 40 | copy = orig:clone() 41 | else 42 | -- number, string, boolean, etc 43 | copy = orig 44 | end 45 | 46 | return copy 47 | end 48 | 49 | function util.copyManyTimes(net, n) 50 | local nets = {} 51 | 52 | for i = 1, n do 53 | nets[#nets + 1] = util.resetParams(net:clone()) 54 | end 55 | 56 | return nets 57 | end 58 | 59 | function util.cloneManyTimes(net, T) 60 | local clones = {} 61 | local params, gradParams = net:parameters() 62 | if params == nil then 63 | params = {} 64 | end 65 | local paramsNoGrad 66 | if net.parametersNoGrad then 67 | paramsNoGrad = net:parametersNoGrad() 68 | end 69 | local mem = torch.MemoryFile("w"):binary() 70 | mem:writeObject(net) 71 | for t = 1, T do 72 | -- We need to use a new reader for each clone. 73 | -- We don't want to use the pointers to already read objects. 
74 | local reader = torch.MemoryFile(mem:storage(), "r"):binary() 75 | local clone = reader:readObject() 76 | reader:close() 77 | local cloneParams, cloneGradParams = clone:parameters() 78 | local cloneParamsNoGrad 79 | for i = 1, #params do 80 | cloneParams[i]:set(params[i]) 81 | cloneGradParams[i]:set(gradParams[i]) 82 | end 83 | if paramsNoGrad then 84 | cloneParamsNoGrad = clone:parametersNoGrad() 85 | for i = 1, #paramsNoGrad do 86 | cloneParamsNoGrad[i]:set(paramsNoGrad[i]) 87 | end 88 | end 89 | clones[t] = clone 90 | collectgarbage() 91 | end 92 | mem:close() 93 | return clones 94 | end 95 | 96 | function util.spairs(t, order) 97 | -- collect the keys 98 | local keys = {} 99 | for k in pairs(t) do keys[#keys + 1] = k end 100 | 101 | -- if order function given, sort by it by passing the table and keys a, b, 102 | -- otherwise just sort the keys 103 | if order then 104 | table.sort(keys, function(a, b) return order(t, a, b) end) 105 | else 106 | table.sort(keys) 107 | end 108 | 109 | -- return the iterator function 110 | local i = 0 111 | return function() 112 | i = i + 1 113 | if keys[i] then 114 | return keys[i], t[keys[i]] 115 | end 116 | end 117 | end 118 | 119 | function util.sprint(t) 120 | for k, v in util.spairs(t) do 121 | log.debugf('opt[\'%s\'] = %s', k, v) 122 | end 123 | end 124 | 125 | function util.f2(f) 126 | return string.format("%.2f", f) 127 | end 128 | 129 | function util.f3(f) 130 | return string.format("%.3f", f) 131 | end 132 | 133 | function util.f4(f) 134 | return string.format("%.4f", f) 135 | end 136 | 137 | function util.f5(f) 138 | return string.format("%.5f", f) 139 | end 140 | 141 | function util.f6(f) 142 | return string.format("%.6f", f) 143 | end 144 | 145 | function util.d(f) 146 | return string.format("%d", torch.round(f)) 147 | end 148 | 149 | 150 | 151 | return util -------------------------------------------------------------------------------- /code/model/ColorDigit.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'nngraph' 3 | require 'optim' 4 | require 'csvigo' 5 | 6 | local kwargs = require 'include.kwargs' 7 | local log = require 'include.log' 8 | local util = require 'include.util' 9 | local LSTM = require 'module.LSTM' 10 | local GRU = require 'module.GRU' 11 | require 'module.LinearO' 12 | require 'module.Binarize' 13 | require 'module.Print' 14 | require 'module.GaussianNoise' 15 | require 'module.rmsprop' 16 | 17 | return function(opt) 18 | 19 | local exp = {} 20 | 21 | function exp.optim(iter) 22 | -- iter can be used for learning rate decay 23 | -- local optimfunc = optim.adam 24 | local optimfunc = optim.rmsprop 25 | local optimconfig = { learningRate = opt.learningrate } 26 | return optimfunc, optimconfig 27 | end 28 | 29 | function exp.save(opt, stats, model) 30 | if stats.e % opt.step == 0 then 31 | if opt.filename == '' then 32 | exp.save_path = exp.save_path or paths.concat('results', opt.game .. '_' .. opt.game_nagents .. 33 | (opt.model_dial == 1 and '_dial' or '') .. '_' .. string.upper(string.format("%x", opt.seed))) 34 | else 35 | exp.save_path = exp.save_path or paths.concat('results', opt.game .. '_' .. opt.game_nagents .. 36 | (opt.model_dial == 1 and '_dial' or '') .. '_' .. opt.filename .. '_' .. string.upper(string.format("%x", opt.seed))) 37 | end 38 | 39 | 40 | -- Save opt 41 | if stats.e == opt.step then 42 | os.execute('mkdir -p ' .. 
exp.save_path) 43 | local opt_csv = {} 44 | for k, v in util.spairs(opt) do 45 | table.insert(opt_csv, { k, v }) 46 | end 47 | 48 | csvigo.save({ 49 | path = paths.concat(exp.save_path, 'opt.csv'), 50 | data = opt_csv, 51 | verbose = false 52 | }) 53 | end 54 | 55 | -- keep stats 56 | stats.history = stats.history or { { 'e', 'td_err', 'td_comm', 'train_r', 'test_r', 'steps', 'comm_per', 'te' } } 57 | table.insert(stats.history, { 58 | stats.e, 59 | stats.td_err:mean(), 60 | stats.td_comm:mean(), 61 | stats.train_r:mean(), 62 | stats.test_r:mean(), 63 | stats.steps:mean(), 64 | stats.comm_per:mean(), 65 | stats.te:mean() 66 | }) 67 | 68 | -- Save stats csv 69 | csvigo.save({ 70 | path = paths.concat(exp.save_path, 'stats.csv'), 71 | data = stats.history, 72 | verbose = false 73 | }) 74 | 75 | -- Save action histogram 76 | if opt.hist_action == 1 then 77 | -- Append to memory 78 | stats.history_hist_action = stats.history_hist_action or {} 79 | table.insert(stats.history_hist_action, 80 | stats.hist_action_avg:totable()[1]) 81 | 82 | -- save csv 83 | csvigo.save({ 84 | path = paths.concat(exp.save_path, 'hist_action.csv'), 85 | data = stats.history_hist_action, 86 | verbose = false 87 | }) 88 | end 89 | 90 | -- Save action histogram 91 | if opt.hist_comm == 1 then 92 | -- Append to memory 93 | stats.history_hist_comm = stats.history_hist_comm or {} 94 | table.insert(stats.history_hist_comm, 95 | stats.hist_comm_avg:totable()[1]) 96 | 97 | -- save csv 98 | csvigo.save({ 99 | path = paths.concat(exp.save_path, 'hist_comm.csv'), 100 | data = stats.history_hist_comm, 101 | verbose = false 102 | }) 103 | end 104 | 105 | 106 | -- save model 107 | if stats.e % (opt.step * 10) == 0 then 108 | log.debug('Saving model') 109 | 110 | -- clear state 111 | -- exp.clearState(model.agent) 112 | 113 | -- save model 114 | local filename = paths.concat(exp.save_path, 'exp.t7') 115 | torch.save(filename, { opt, stats, model.agent }) 116 | end 117 | end 118 | end 119 | 120 | function exp.load() 121 | end 122 | 123 | function exp.clearState(model) 124 | for i = 1, #model do 125 | model[i]:clearState() 126 | end 127 | end 128 | 129 | function exp.training(model) 130 | for i = 1, #model do 131 | model[i]:training() 132 | end 133 | end 134 | 135 | function exp.evaluate(model) 136 | for i = 1, #model do 137 | model[i]:evaluate() 138 | end 139 | end 140 | 141 | function exp.getParameters() 142 | -- Get model params 143 | local a = nn.Container() 144 | for i = 1, #exp.agent do 145 | a:add(exp.agent[i]) 146 | end 147 | local params, gradParams = a:getParameters() 148 | 149 | log.infof('Creating model(s), params=%d', params:nElement()) 150 | 151 | -- Get target model params 152 | local a = nn.Container() 153 | for i = 1, #exp.agent_target do 154 | a:add(exp.agent_target[i]) 155 | end 156 | local params_target, gradParams_target = a:getParameters() 157 | 158 | log.infof('Creating target model(s), params=%d', params_target:nElement()) 159 | 160 | return params, gradParams, params_target, gradParams_target 161 | end 162 | 163 | function exp.id(step_i, agent_i) 164 | return (step_i - 1) * opt.game_nagents + agent_i 165 | end 166 | 167 | function exp.stats(opt, game, stats, e) 168 | if e % opt.step == 0 then 169 | end 170 | end 171 | 172 | local function create_agent() 173 | 174 | -- Sizes 175 | local comm_size = 0 176 | 177 | if (opt.game_comm_bits > 0) and (opt.game_nagents > 1) then 178 | comm_size = opt.game_comm_bits * opt.game_nagents 179 | end 180 | 181 | local action_aware_size = 0 182 | if opt.model_action_aware == 
1 then 183 | action_aware_size = opt.game_action_space_total 184 | end 185 | 186 | -- Networks 187 | local model_mnist = nn.Sequential() 188 | 189 | if opt.game_use_digits == 1 then 190 | model_mnist:add(nn.LookupTable(10, opt.model_rnn_size)) 191 | else 192 | if opt.game_vision_net == 'mlp' then 193 | model_mnist:add(nn.View(-1, opt.game_colors * 28 * 28)) 194 | if opt.model_bn == 1 then model_mnist:add(nn.BatchNormalization(opt.game_colors * 28 * 28)) end 195 | model_mnist:add(nn.Linear(opt.game_colors * 28 * 28, 128)) -- fully connected layer (matrix multiplication between input and weights) 196 | model_mnist:add(nn.ReLU(true)) 197 | if opt.model_bn == 1 then model_mnist:add(nn.BatchNormalization(128)) end 198 | model_mnist:add(nn.Linear(128, opt.model_rnn_size)) 199 | model_mnist:add(nn.View(-1, opt.model_rnn_size)) 200 | else 201 | model_mnist:add(nn.SpatialConvolutionMM(opt.game_colors, 32, 5, 5)) -- 1 input image channel, 6 output channels, 5x5 convolution kernel 202 | model_mnist:add(nn.ReLU(true)) 203 | model_mnist:add(nn.SpatialMaxPooling(3, 3, 3, 3)) -- A max-pooling operation that looks at 2x2 windows and finds the max. 204 | model_mnist:add(nn.SpatialConvolutionMM(32, 64, 5, 5)) 205 | model_mnist:add(nn.ReLU(true)) 206 | model_mnist:add(nn.SpatialMaxPooling(2, 2, 2, 2)) 207 | 208 | local conv_out = model_mnist:forward(torch.zeros(1, opt.game_colors, 28, 28)):nElement() 209 | 210 | model_mnist:add(nn.View(-1, conv_out)) 211 | model_mnist:add(nn.Linear(conv_out, opt.model_rnn_size)) 212 | model_mnist:add(nn.View(-1, opt.model_rnn_size)) 213 | end 214 | end 215 | 216 | 217 | -- Process inputs 218 | local model_input = nn.Sequential() 219 | model_input:add(nn.CAddTable(2)) 220 | 221 | -- RNN 222 | local model_rnn 223 | if opt.model_rnn == 'lstm' then 224 | model_rnn = LSTM(opt.model_rnn_size, 225 | opt.model_rnn_size, 226 | opt.model_rnn_layers, 227 | opt.model_dropout, 228 | opt.model_bn == 1) 229 | elseif opt.model_rnn == 'gru' then 230 | model_rnn = GRU(opt.model_rnn_size, 231 | opt.model_rnn_size, 232 | opt.model_rnn_layers, 233 | opt.model_dropout) 234 | end 235 | 236 | -- use default initialization for convnet, but uniform -0.08 to .08 for RNN: 237 | -- double parens necessary 238 | for _, param in ipairs((model_rnn:parameters())) do 239 | param:uniform(-0.08, 0.08) 240 | end 241 | 242 | -- Output 243 | local model_out = nn.Sequential() 244 | if opt.model_dropout > 0 then model_out:add(nn.Dropout(opt.model_dropout)) end 245 | model_out:add(nn.Linear(opt.model_rnn_size, opt.model_rnn_size)) 246 | model_out:add(nn.ReLU(true)) 247 | model_out:add(nn.Linear(opt.model_rnn_size, opt.game_action_space_total)) 248 | 249 | -- Construct Graph 250 | 251 | local in_state = nn.Identity()() 252 | local in_id = nn.Identity()() 253 | local in_rnn_state = nn.Identity()() 254 | 255 | local in_comm, in_action 256 | 257 | local in_all = { 258 | model_mnist(in_state), 259 | nn.LookupTable(opt.game_nagents, opt.model_rnn_size)(in_id) 260 | } 261 | 262 | -- Communication enabled 263 | if opt.game_comm_bits > 0 and opt.game_nagents > 1 then 264 | in_comm = nn.Identity()() 265 | -- Process comm 266 | local model_comm = nn.Sequential() 267 | model_comm:add(nn.View(-1, comm_size)) 268 | if opt.model_dial == 1 then 269 | if opt.model_comm_narrow == 1 then 270 | model_comm:add(nn.Sigmoid()) 271 | else 272 | model_comm:add(nn.SoftMax()) 273 | end 274 | end 275 | if opt.model_bn == 1 and opt.model_dial == 1 then 276 | model_comm:add(nn.BatchNormalization(comm_size)) 277 | end 278 | 
model_comm:add(nn.Linear(comm_size, opt.model_rnn_size)) 279 | if opt.model_comm_narrow == 1 then 280 | model_comm:add(nn.ReLU(true)) 281 | end 282 | 283 | -- Process inputs node 284 | table.insert(in_all, model_comm(in_comm)) 285 | end 286 | 287 | -- Last action enabled 288 | if opt.model_action_aware == 1 then 289 | in_action = nn.Identity()() 290 | 291 | -- Process action node (+1 for no-action at 0-step) 292 | if opt.model_dial == 0 then 293 | local in_action_aware = nn.CAddTable(2)({ 294 | nn.LookupTable(opt.game_action_space + 1, opt.model_rnn_size)(nn.SelectTable(1)(in_action)), 295 | nn.LookupTable(opt.game_comm_bits + 1, opt.model_rnn_size)(nn.SelectTable(2)(in_action)) 296 | }) 297 | table.insert(in_all, in_action_aware) 298 | else 299 | table.insert(in_all, nn.LookupTable(action_aware_size + 1, opt.model_rnn_size)(in_action)) 300 | end 301 | end 302 | 303 | -- Process inputs 304 | local proc_input = model_input(in_all) 305 | 306 | -- 2*n+1 rnn inputs 307 | local rnn_input = {} 308 | table.insert(rnn_input, proc_input) 309 | 310 | -- Restore state 311 | for i = 1, opt.model_rnn_states do 312 | table.insert(rnn_input, nn.SelectTable(i)(in_rnn_state)) 313 | end 314 | 315 | local rnn_output = model_rnn(rnn_input) 316 | 317 | -- Split state and out 318 | local rnn_state = rnn_output 319 | local rnn_out = nn.SelectTable(opt.model_rnn_states)(rnn_output) 320 | 321 | -- Process out 322 | local proc_out = model_out(rnn_out) 323 | 324 | -- Create model 325 | local model_inputs = { in_state, in_id, in_rnn_state } 326 | local model_outputs = { rnn_state, proc_out } 327 | 328 | if opt.game_comm_bits > 0 and opt.game_nagents > 1 then 329 | table.insert(model_inputs, in_comm) 330 | end 331 | 332 | if opt.model_action_aware == 1 then 333 | table.insert(model_inputs, in_action) 334 | end 335 | 336 | nngraph.annotateNodes() 337 | 338 | local model = nn.gModule(model_inputs, model_outputs) 339 | 340 | return model:type(opt.dtype) 341 | end 342 | 343 | -- Create model 344 | local agent = create_agent() 345 | local agent_target = agent:clone() 346 | 347 | -- Knowledge sharing 348 | if opt.model_know_share == 1 then 349 | exp.agent = util.cloneManyTimes(agent, opt.game_nagents * (opt.nsteps + 1)) 350 | exp.agent_target = util.cloneManyTimes(agent_target, opt.game_nagents * (opt.nsteps + 1)) 351 | else 352 | exp.agent = {} 353 | exp.agent_target = {} 354 | 355 | local agent_copies = util.copyManyTimes(agent, opt.game_nagents) 356 | local agent_target_copies = util.copyManyTimes(agent_target, opt.game_nagents) 357 | 358 | for i = 1, opt.game_nagents do 359 | local unrolled = util.cloneManyTimes(agent_copies[i], opt.nsteps + 1) 360 | local unrolled_target = util.cloneManyTimes(agent_target_copies[i], opt.nsteps + 1) 361 | for s = 1, opt.nsteps + 1 do 362 | exp.agent[exp.id(s, i)] = unrolled[s] 363 | exp.agent_target[exp.id(s, i)] = unrolled_target[s] 364 | end 365 | end 366 | end 367 | 368 | return exp 369 | end -------------------------------------------------------------------------------- /code/model/Switch.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'nngraph' 3 | require 'optim' 4 | require 'csvigo' 5 | 6 | local kwargs = require 'include.kwargs' 7 | local log = require 'include.log' 8 | local util = require 'include.util' 9 | local LSTM = require 'module.LSTM' 10 | local GRU = require 'module.GRU' 11 | require 'module.rmsprop' 12 | require 'module.GaussianNoise' 13 | require 'module.Binarize' 14 | 15 | return function(opt) 16 | 
17 | local exp = {} 18 | 19 | function exp.optim(iter) 20 | -- iter can be used for learning rate decay 21 | -- local optimfunc = optim.adam 22 | local optimfunc = optim.rmsprop 23 | local optimconfig = { learningRate = opt.learningrate } 24 | return optimfunc, optimconfig 25 | end 26 | 27 | function exp.save(opt, stats, model) 28 | if stats.e % opt.step == 0 then 29 | if opt.filename == '' then 30 | exp.save_path = exp.save_path or paths.concat('results', opt.game .. '_' .. opt.game_nagents .. 31 | (opt.model_dial == 1 and '_dial' or '') .. '_' .. string.upper(string.format("%x", opt.seed))) 32 | else 33 | exp.save_path = exp.save_path or paths.concat('results', opt.game .. '_' .. opt.game_nagents .. 34 | (opt.model_dial == 1 and '_dial' or '') .. '_' .. opt.filename .. '_' .. string.upper(string.format("%x", opt.seed))) 35 | end 36 | 37 | 38 | -- Save opt 39 | if stats.e == opt.step then 40 | os.execute('mkdir -p ' .. exp.save_path) 41 | local opt_csv = {} 42 | for k, v in util.spairs(opt) do 43 | table.insert(opt_csv, { k, v }) 44 | end 45 | 46 | csvigo.save({ 47 | path = paths.concat(exp.save_path, 'opt.csv'), 48 | data = opt_csv, 49 | verbose = false 50 | }) 51 | end 52 | 53 | -- keep stats 54 | stats.history = stats.history or { { 'e', 'td_err', 'td_comm', 'train_r', 'test_r', 'test_opt', 'test_god', 'steps', 'comm_per', 'te' } } 55 | table.insert(stats.history, { 56 | stats.e, 57 | stats.td_err:mean(), 58 | stats.td_comm:mean(), 59 | stats.train_r:mean(), 60 | stats.test_r:mean(), 61 | stats.test_opt:mean(), 62 | stats.test_god:mean(), 63 | stats.steps:mean(), 64 | stats.comm_per:mean(), 65 | stats.te:mean() 66 | }) 67 | 68 | -- Save stats csv 69 | csvigo.save({ 70 | path = paths.concat(exp.save_path, 'stats.csv'), 71 | data = stats.history, 72 | verbose = false 73 | }) 74 | 75 | -- Save action histogram 76 | if opt.hist_action == 1 then 77 | -- Append to memory 78 | stats.history_hist_action = stats.history_hist_action or {} 79 | table.insert(stats.history_hist_action, 80 | stats.hist_action_avg:totable()[1]) 81 | 82 | -- save csv 83 | csvigo.save({ 84 | path = paths.concat(exp.save_path, 'hist_action.csv'), 85 | data = stats.history_hist_action, 86 | verbose = false 87 | }) 88 | end 89 | 90 | -- Save action histogram 91 | if opt.hist_comm == 1 then 92 | -- Append to memory 93 | stats.history_hist_comm = stats.history_hist_comm or {} 94 | table.insert(stats.history_hist_comm, 95 | stats.hist_comm_avg:totable()[1]) 96 | 97 | -- save csv 98 | csvigo.save({ 99 | path = paths.concat(exp.save_path, 'hist_comm.csv'), 100 | data = stats.history_hist_comm, 101 | verbose = false 102 | }) 103 | end 104 | -- save model 105 | if stats.e % (opt.step * 10) == 0 then 106 | log.debug('Saving model') 107 | 108 | -- clear state 109 | -- exp.clearState(model.agent) 110 | 111 | -- save model 112 | local filename = paths.concat(exp.save_path, 'exp.t7') 113 | torch.save(filename, { opt, stats, model.agent }) 114 | end 115 | end 116 | end 117 | 118 | function exp.load() 119 | end 120 | 121 | function exp.clearState(model) 122 | for i = 1, #model do 123 | model[i]:clearState() 124 | end 125 | end 126 | 127 | function exp.training(model) 128 | for i = 1, #model do 129 | model[i]:training() 130 | end 131 | end 132 | 133 | function exp.evaluate(model) 134 | for i = 1, #model do 135 | model[i]:evaluate() 136 | end 137 | end 138 | 139 | function exp.getParameters() 140 | -- Get model params 141 | local a = nn.Container() 142 | for i = 1, #exp.agent do 143 | a:add(exp.agent[i]) 144 | end 145 | local params, 
gradParams = a:getParameters() 146 | 147 | log.infof('Creating model(s), params=%d', params:nElement()) 148 | 149 | -- Get target model params 150 | local a = nn.Container() 151 | for i = 1, #exp.agent_target do 152 | a:add(exp.agent_target[i]) 153 | end 154 | local params_target, gradParams_target = a:getParameters() 155 | 156 | log.infof('Creating target model(s), params=%d', params_target:nElement()) 157 | 158 | return params, gradParams, params_target, gradParams_target 159 | end 160 | 161 | function exp.id(step_i, agent_i) 162 | return (step_i - 1) * opt.game_nagents + agent_i 163 | end 164 | 165 | function exp.stats(opt, game, stats, e) 166 | 167 | if e % opt.step_test == 0 then 168 | local test_idx = (e / opt.step_test - 1) % (opt.step / opt.step_test) + 1 169 | 170 | -- Initialise 171 | stats.test_opt = stats.test_opt or torch.zeros(opt.step / opt.step_test, opt.game_nagents) 172 | stats.test_god = stats.test_god or torch.zeros(opt.step / opt.step_test, opt.game_nagents) 173 | 174 | -- Naive strategy 175 | local r_naive = 0 176 | for b = 1, opt.bs do 177 | local has_been = game.has_been[{ { b }, { 1, opt.nsteps }, {} }]:sum(2):squeeze(2):gt(0):float():sum() 178 | if has_been == opt.game_nagents then 179 | r_naive = r_naive + game.reward_all_live 180 | else 181 | r_naive = r_naive + game.reward_all_die 182 | end 183 | end 184 | stats.test_opt[test_idx] = r_naive / opt.bs 185 | 186 | -- God strategy 187 | local r_god = 0 188 | for b = 1, opt.bs do 189 | local has_been = game.has_been[{ { b }, { 1, opt.nsteps }, {} }]:sum(2):squeeze(2):gt(0):float():sum() 190 | if has_been == opt.game_nagents then 191 | r_god = r_god + game.reward_all_live 192 | end 193 | end 194 | stats.test_god[test_idx] = r_god / opt.bs 195 | end 196 | 197 | -- Keep stats 198 | if e == opt.step then 199 | stats.test_opt_avg = stats.test_opt:mean() 200 | stats.test_god_avg = stats.test_god:mean() 201 | elseif e % opt.step == 0 then 202 | local coef = 0.9 203 | stats.test_opt_avg = stats.test_opt_avg * coef + stats.test_opt:mean() * (1 - coef) 204 | stats.test_god_avg = stats.test_god_avg * coef + stats.test_god:mean() * (1 - coef) 205 | end 206 | 207 | -- Print statistics 208 | if e % opt.step == 0 then 209 | log.infof('te_opt=%.2f, te_opt_avg=%.2f, te_god=%.2f, te_god_avg=%.2f', 210 | stats.test_opt:mean(), 211 | stats.test_opt_avg, 212 | stats.test_god:mean(), 213 | stats.test_god_avg) 214 | end 215 | end 216 | 217 | local function create_agent() 218 | 219 | -- Sizes 220 | local comm_size = 0 221 | 222 | if (opt.game_comm_bits > 0) and (opt.game_nagents > 1) then 223 | if opt.game_comm_limited then 224 | comm_size = opt.game_comm_bits 225 | else 226 | error('game_comm_limited is required') 227 | end 228 | end 229 | 230 | local action_aware_size = 0 231 | if opt.model_action_aware == 1 then 232 | action_aware_size = opt.game_action_space_total 233 | end 234 | 235 | 236 | -- Process inputs 237 | local model_input = nn.Sequential() 238 | model_input:add(nn.CAddTable(2)) 239 | -- if opt.model_bn == 1 then model_input:add(nn.BatchNormalization(opt.model_rnn_size)) end 240 | 241 | local model_state = nn.Sequential() 242 | model_state:add(nn.LookupTable(2, opt.model_rnn_size)) 243 | 244 | -- RNN 245 | local model_rnn 246 | if opt.model_rnn == 'lstm' then 247 | model_rnn = LSTM(opt.model_rnn_size, 248 | opt.model_rnn_size, 249 | opt.model_rnn_layers, 250 | opt.model_dropout, 251 | opt.model_bn == 1) 252 | elseif opt.model_rnn == 'gru' then 253 | model_rnn = GRU(opt.model_rnn_size, 254 | opt.model_rnn_size, 255 | 
opt.model_rnn_layers, 256 | opt.model_dropout) 257 | end 258 | 259 | -- use default initialization for convnet, but uniform -0.08 to .08 for RNN: 260 | -- double parens necessary 261 | for _, param in ipairs((model_rnn:parameters())) do 262 | param:uniform(-0.08, 0.08) 263 | end 264 | 265 | -- Output 266 | local model_out = nn.Sequential() 267 | if opt.model_dropout > 0 then model_out:add(nn.Dropout(opt.model_dropout)) end 268 | model_out:add(nn.Linear(opt.model_rnn_size, opt.model_rnn_size)) 269 | model_out:add(nn.ReLU(true)) 270 | model_out:add(nn.Linear(opt.model_rnn_size, opt.game_action_space_total)) 271 | 272 | -- Construct Graph 273 | local in_state = nn.Identity()() 274 | local in_id = nn.Identity()() 275 | local in_rnn_state = nn.Identity()() 276 | 277 | local in_comm, in_action 278 | 279 | local in_all = { 280 | model_state(in_state), 281 | nn.LookupTable(opt.game_nagents, opt.model_rnn_size)(in_id) 282 | } 283 | 284 | -- Communication enabled 285 | if opt.game_comm_bits > 0 and opt.game_nagents > 1 then 286 | in_comm = nn.Identity()() 287 | -- Process comm 288 | local model_comm = nn.Sequential() 289 | model_comm:add(nn.View(-1, comm_size)) 290 | if opt.model_dial == 1 then 291 | if opt.model_comm_narrow == 1 then 292 | model_comm:add(nn.Sigmoid()) 293 | else 294 | model_comm:add(nn.SoftMax()) 295 | end 296 | end 297 | if opt.model_bn == 1 and opt.model_dial == 1 then 298 | model_comm:add(nn.BatchNormalization(comm_size)) 299 | end 300 | model_comm:add(nn.Linear(comm_size, opt.model_rnn_size)) 301 | if opt.model_comm_narrow == 1 then 302 | model_comm:add(nn.ReLU(true)) 303 | end 304 | 305 | -- Process inputs node 306 | table.insert(in_all, model_comm(in_comm)) 307 | end 308 | 309 | -- Last action enabled 310 | if opt.model_action_aware == 1 then 311 | in_action = nn.Identity()() 312 | 313 | -- Process action node (+1 for no-action at 0-step) 314 | if opt.model_dial == 0 then 315 | local in_action_aware = nn.CAddTable(2)({ 316 | nn.LookupTable(opt.game_action_space + 1, opt.model_rnn_size)(nn.SelectTable(1)(in_action)), 317 | nn.LookupTable(opt.game_comm_bits + 1, opt.model_rnn_size)(nn.SelectTable(2)(in_action)) 318 | }) 319 | table.insert(in_all, in_action_aware) 320 | else 321 | table.insert(in_all, nn.LookupTable(action_aware_size + 1, opt.model_rnn_size)(in_action)) 322 | end 323 | end 324 | 325 | -- Process inputs 326 | local proc_input = model_input(in_all) 327 | 328 | -- 2*n+1 rnn inputs 329 | local rnn_input = {} 330 | table.insert(rnn_input, proc_input) 331 | 332 | -- Restore state 333 | for i = 1, opt.model_rnn_states do 334 | table.insert(rnn_input, nn.SelectTable(i)(in_rnn_state)) 335 | end 336 | 337 | local rnn_output = model_rnn(rnn_input) 338 | 339 | -- Split state and out 340 | local rnn_state = rnn_output 341 | local rnn_out = nn.SelectTable(opt.model_rnn_states)(rnn_output) 342 | 343 | -- Process out 344 | local proc_out = model_out(rnn_out) 345 | 346 | -- Create model 347 | local model_inputs = { in_state, in_id, in_rnn_state } 348 | local model_outputs = { rnn_state, proc_out } 349 | 350 | if opt.game_comm_bits > 0 and opt.game_nagents > 1 then 351 | table.insert(model_inputs, in_comm) 352 | end 353 | 354 | if opt.model_action_aware == 1 then 355 | table.insert(model_inputs, in_action) 356 | end 357 | 358 | nngraph.annotateNodes() 359 | 360 | local model = nn.gModule(model_inputs, model_outputs) 361 | 362 | return model:type(opt.dtype) 363 | end 364 | 365 | -- Create model 366 | local agent = create_agent() 367 | local agent_target = agent:clone() 368 | 369 | 
-- Knowledge sharing 370 | if opt.model_know_share == 1 then 371 | exp.agent = util.cloneManyTimes(agent, opt.game_nagents * (opt.nsteps + 1)) 372 | exp.agent_target = util.cloneManyTimes(agent_target, opt.game_nagents * (opt.nsteps + 1)) 373 | else 374 | exp.agent = {} 375 | exp.agent_target = {} 376 | 377 | local agent_copies = util.copyManyTimes(agent, opt.game_nagents) 378 | local agent_target_copies = util.copyManyTimes(agent_target, opt.game_nagents) 379 | 380 | for i = 1, opt.game_nagents do 381 | local unrolled = util.cloneManyTimes(agent_copies[i], opt.nsteps + 1) 382 | local unrolled_target = util.cloneManyTimes(agent_target_copies[i], opt.nsteps + 1) 383 | for s = 1, opt.nsteps + 1 do 384 | exp.agent[exp.id(s, i)] = unrolled[s] 385 | exp.agent_target[exp.id(s, i)] = unrolled_target[s] 386 | end 387 | end 388 | end 389 | 390 | return exp 391 | end -------------------------------------------------------------------------------- /code/module/Binarize.lua: -------------------------------------------------------------------------------- 1 | local Binarize, parent = torch.class('nn.Binarize', 'nn.Module') 2 | 3 | function Binarize:__init(stcFlag) 4 | parent.__init(self) 5 | self.stcFlag = stcFlag 6 | self.randmat = torch.Tensor(); 7 | self.outputR = torch.Tensor(); 8 | end 9 | 10 | function Binarize:updateOutput(input) 11 | self.randmat:resizeAs(input); 12 | self.outputR:resizeAs(input); 13 | self.output:resizeAs(input); 14 | self.outputR:copy(input):add(1):div(2) 15 | if self.train and self.stcFlag then 16 | local mask = self.outputR - self.randmat:rand(self.randmat:size()) 17 | self.output = mask:sign() 18 | else 19 | self.output:copy(self.outputR):add(-0.5):sign() 20 | end 21 | return self.output 22 | end 23 | 24 | function Binarize:updateGradInput(input, gradOutput) 25 | self.gradInput:resizeAs(gradOutput) 26 | self.gradInput:copy(gradOutput) --:mul(0.5) 27 | return self.gradInput 28 | end 29 | -------------------------------------------------------------------------------- /code/module/GRU.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | 3 | Recurrent Batch Normalization 4 | Tim Cooijmans, Nicolas Ballas, César Laurent, Çağlar Gülçehre, Aaron Courville 5 | http://arxiv.org/abs/1603.09025 6 | 7 | Implemented by Yannis M. Assael (www.yannisassael.com), 2016. 8 | 9 | Based on 10 | https://github.com/wojciechz/learning_to_execute, 11 | https://github.com/karpathy/char-rnn/blob/master/model/LSTM.lua, 12 | and Brendan Shillingford. 
13 | 
14 | Usage:
15 |     local rnn = GRU(input_size, rnn_size, n, dropout)
16 | 
17 | ]] --
18 | 
19 | require 'nn'
20 | require 'nngraph'
21 | 
22 | local function GRU(input_size, rnn_size, n, dropout)
23 |     dropout = dropout or 0
24 |     -- there are n+1 inputs (hiddens on each layer and x)
25 |     local inputs = {}
26 |     table.insert(inputs, nn.Identity()()) -- x
27 |     for L = 1, n do
28 |         table.insert(inputs, nn.Identity()()) -- prev_h[L]
29 |     end
30 | 
31 |     local function new_input_sum(insize, xv, hv)
32 |         local i2h = nn.Linear(insize, rnn_size)(xv)
33 |         local h2h = nn.Linear(rnn_size, rnn_size)(hv)
34 |         return nn.CAddTable()({ i2h, h2h })
35 |     end
36 | 
37 |     local x, input_size_L
38 |     local outputs = {}
39 |     for L = 1, n do
40 | 
41 |         local prev_h = inputs[L + 1]
42 |         -- the input to this layer
43 |         if L == 1 then
44 |             x = inputs[1]
45 |             input_size_L = input_size
46 |         else
47 |             x = outputs[(L - 1)]
48 |             if dropout > 0 then x = nn.Dropout(dropout)(x) end -- apply dropout, if any
49 |             input_size_L = rnn_size
50 |         end
51 |         -- GRU tick
52 |         -- forward the update and reset gates
53 |         local update_gate = nn.Sigmoid()(new_input_sum(input_size_L, x, prev_h))
54 |         local reset_gate = nn.Sigmoid()(new_input_sum(input_size_L, x, prev_h))
55 |         -- compute candidate hidden state
56 |         local gated_hidden = nn.CMulTable()({ reset_gate, prev_h })
57 |         local p2 = nn.Linear(rnn_size, rnn_size)(gated_hidden)
58 |         local p1 = nn.Linear(input_size_L, rnn_size)(x)
59 |         local hidden_candidate = nn.Tanh()(nn.CAddTable()({ p1, p2 }))
60 |         -- compute new interpolated hidden state, based on the update gate
61 |         local zh = nn.CMulTable()({ update_gate, hidden_candidate })
62 |         local zhm1 = nn.CMulTable()({ nn.AddConstant(1, false)(nn.MulConstant(-1, false)(update_gate)), prev_h })
63 |         local next_h = nn.CAddTable()({ zh, zhm1 })
64 | 
65 |         table.insert(outputs, next_h)
66 |     end
67 | 
68 |     nngraph.annotateNodes()
69 | 
70 |     return nn.gModule(inputs, outputs)
71 | end
72 | 
73 | return GRU
74 | 
--------------------------------------------------------------------------------
/code/module/GaussianNoise.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Implemented by Yannis M. Assael (www.yannisassael.com), 2016.
3 | --]] 4 | 5 | local GaussianNoise, parent = torch.class('nn.GaussianNoise', 'nn.Module') 6 | 7 | function GaussianNoise:__init(std, ip) 8 | parent.__init(self) 9 | assert(type(std) == 'number', 'input is not scalar!') 10 | self.std = std 11 | 12 | -- default for inplace is false 13 | self.inplace = ip or false 14 | if (ip and type(ip) ~= 'boolean') then 15 | error('in-place flag must be boolean') 16 | end 17 | 18 | self.noise = torch.Tensor() 19 | self.train = true 20 | end 21 | 22 | function GaussianNoise:training() 23 | self.train = true 24 | end 25 | 26 | function GaussianNoise:evaluate() 27 | self.train = false 28 | end 29 | 30 | function GaussianNoise:updateOutput(input) 31 | if self.train and self.std > 0 then 32 | -- Generate noise 33 | self.noise:resizeAs(input):normal(0, self.std) 34 | 35 | if self.inplace then 36 | input:add(self.noise) 37 | self.output:set(input) 38 | else 39 | self.output:resizeAs(input) 40 | self.output:copy(input) 41 | self.output:add(self.noise) 42 | end 43 | else 44 | if self.inplace then 45 | self.output:set(input) 46 | else 47 | self.output:resizeAs(input) 48 | self.output:copy(input) 49 | end 50 | end 51 | return self.output 52 | end 53 | 54 | function GaussianNoise:updateGradInput(input, gradOutput) 55 | if self.inplace and self.train and self.std > 0 then 56 | self.gradInput:set(gradOutput) 57 | -- restore previous input value 58 | input:add(-1, self.noise) 59 | else 60 | self.gradInput:resizeAs(gradOutput) 61 | self.gradInput:copy(gradOutput) 62 | end 63 | return self.gradInput 64 | end 65 | -------------------------------------------------------------------------------- /code/module/LSTM.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | 3 | Recurrent Batch Normalization 4 | Tim Cooijmans, Nicolas Ballas, César Laurent, Çağlar Gülçehre, Aaron Courville 5 | http://arxiv.org/abs/1603.09025 6 | 7 | Implemented by Yannis M. Assael (www.yannisassael.com), 2016. 8 | 9 | Based on 10 | https://github.com/wojciechz/learning_to_execute, 11 | https://github.com/karpathy/char-rnn/blob/master/model/LSTM.lua, 12 | and Brendan Shillingford. 
13 | 14 | Usage: 15 | local rnn = LSTM(input_size, rnn_size, n, dropout, bn) 16 | 17 | ]] -- 18 | 19 | require 'nn' 20 | require 'nngraph' 21 | 22 | local function LSTM(input_size, rnn_size, n, dropout, bn) 23 | dropout = dropout or 0 24 | 25 | -- there will be 2*n+1 inputs 26 | local inputs = {} 27 | table.insert(inputs, nn.Identity()()) -- x 28 | for L = 1, n do 29 | table.insert(inputs, nn.Identity()()) -- prev_c[L] 30 | table.insert(inputs, nn.Identity()()) -- prev_h[L] 31 | end 32 | 33 | local x, input_size_L 34 | local outputs = {} 35 | for L = 1, n do 36 | -- c,h from previos timesteps 37 | local prev_h = inputs[L * 2 + 1] 38 | local prev_c = inputs[L * 2] 39 | -- the input to this layer 40 | if L == 1 then 41 | x = inputs[1] 42 | input_size_L = input_size 43 | else 44 | x = outputs[(L - 1) * 2] 45 | if dropout > 0 then x = nn.Dropout(dropout)(x) end -- apply dropout, if any 46 | input_size_L = rnn_size 47 | end 48 | -- recurrent batch normalization 49 | -- http://arxiv.org/abs/1603.09025 50 | local bn_wx, bn_wh, bn_c 51 | if bn then 52 | bn_wx = nn.BatchNormalization(4 * rnn_size, 1e-5, 0.1, true) 53 | bn_wh = nn.BatchNormalization(4 * rnn_size, 1e-5, 0.1, true) 54 | bn_c = nn.BatchNormalization(rnn_size, 1e-5, 0.1, true) 55 | else 56 | bn_wx = nn.Identity() 57 | bn_wh = nn.Identity() 58 | bn_c = nn.Identity() 59 | end 60 | -- evaluate the input sums at once for efficiency 61 | local i2h = bn_wx(nn.Linear(input_size_L, 4 * rnn_size)(x):annotate { name = 'i2h_' .. L }):annotate { name = 'bn_wx_' .. L } 62 | local h2h = bn_wh(nn.Linear(rnn_size, 4 * rnn_size, false)(prev_h):annotate { name = 'h2h_' .. L }):annotate { name = 'bn_wh_' .. L } 63 | local all_input_sums = nn.CAddTable()({ i2h, h2h }) 64 | 65 | local reshaped = nn.Reshape(4, rnn_size)(all_input_sums) 66 | local n1, n2, n3, n4 = nn.SplitTable(2)(reshaped):split(4) 67 | -- decode the gates 68 | local in_gate = nn.Sigmoid()(n1) 69 | local forget_gate = nn.Sigmoid()(n2) 70 | local out_gate = nn.Sigmoid()(n3) 71 | -- decode the write inputs 72 | local in_transform = nn.Tanh()(n4) 73 | -- perform the LSTM update 74 | local next_c = nn.CAddTable()({ 75 | nn.CMulTable()({ forget_gate, prev_c }), 76 | nn.CMulTable()({ in_gate, in_transform }) 77 | }) 78 | -- gated cells form the output 79 | local next_h = nn.CMulTable()({ out_gate, nn.Tanh()(bn_c(next_c):annotate { name = 'bn_c_' .. L }) }) 80 | 81 | table.insert(outputs, next_c) 82 | table.insert(outputs, next_h) 83 | end 84 | 85 | nngraph.annotateNodes() 86 | 87 | return nn.gModule(inputs, outputs) 88 | end 89 | 90 | return LSTM 91 | -------------------------------------------------------------------------------- /code/module/LinearO.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Torch Linear Unit with Orthogonal Weight Initialization 3 | 4 | Exact solutions to the nonlinear dynamics of learning in deep linear neural networks 5 | http://arxiv.org/abs/1312.6120 6 | 7 | Implemented by Yannis M. Assael (www.yannisassael.com), 2016. 
8 | 9 | ]] -- 10 | 11 | local LinearO, parent = torch.class('nn.LinearO', 'nn.Linear') 12 | 13 | function LinearO:__init(inputSize, outputSize) 14 | parent.__init(self, inputSize, outputSize) 15 | self:reset() 16 | end 17 | 18 | function LinearO:reset() 19 | local initScale = 1.1 -- math.sqrt(2) 20 | 21 | local M1 = torch.randn(self.weight:size(1), self.weight:size(1)) 22 | local M2 = torch.randn(self.weight:size(2), self.weight:size(2)) 23 | 24 | local n_min = math.min(self.weight:size(1), self.weight:size(2)) 25 | 26 | -- QR decomposition of random matrices ~ N(0, 1) 27 | local Q1, R1 = torch.qr(M1) 28 | local Q2, R2 = torch.qr(M2) 29 | 30 | self.weight:copy(Q1:narrow(2, 1, n_min) * Q2:narrow(1, 1, n_min)):mul(initScale) 31 | 32 | self.bias:zero() 33 | end 34 | 35 | -------------------------------------------------------------------------------- /code/module/Print.lua: -------------------------------------------------------------------------------- 1 | local Print, parent = torch.class('nn.Print', 'nn.Module') 2 | 3 | function Print:__init(stcFlag) 4 | parent.__init(self) 5 | end 6 | 7 | function Print:updateOutput(input) 8 | print(input) 9 | self.output = input 10 | return self.output 11 | end 12 | 13 | function Print:updateGradInput(input, gradOutput) 14 | self.gradInput = gradOutput 15 | return self.gradInput 16 | end 17 | -------------------------------------------------------------------------------- /code/module/rmsprop.lua: -------------------------------------------------------------------------------- 1 | --[[ An implementation of RMSprop 2 | 3 | ARGS: 4 | 5 | - 'opfunc' : a function that takes a single input (X), the point 6 | of a evaluation, and returns f(X) and df/dX 7 | - 'x' : the initial point 8 | - 'config` : a table with configuration parameters for the optimizer 9 | - 'config.learningRate' : learning rate 10 | - 'config.alpha' : smoothing constant 11 | - 'config.epsilon' : value with which to initialise m 12 | - 'config.weightDecay' : weight decay 13 | - 'state' : a table describing the state of the optimizer; 14 | after each call the state is modified 15 | - 'state.m' : leaky sum of squares of parameter gradients, 16 | - 'state.tmp' : and the square root (with epsilon smoothing) 17 | 18 | RETURN: 19 | - `x` : the new x vector 20 | - `f(x)` : the function, evaluated before the update 21 | 22 | ]] 23 | 24 | function optim.rmsprop(opfunc, x, config, state) 25 | -- (0) get/update state 26 | local config = config or {} 27 | local state = state or config 28 | local lr = config.learningRate or 1e-2 29 | local alpha = config.alpha or 0.99 30 | local epsilon = config.epsilon or 1e-8 31 | local wd = config.weightDecay or 0 32 | 33 | -- (1) evaluate f(x) and df/dx 34 | local fx, dfdx = opfunc(x) 35 | 36 | -- (2) weight decay 37 | if wd ~= 0 then 38 | dfdx:add(wd, x) 39 | end 40 | 41 | -- (3) initialize mean square values and square gradient storage 42 | if not state.m then 43 | -- This line kills the performance 44 | -- state.m = torch.Tensor():typeAs(x):resizeAs(dfdx):fill(1) 45 | state.m = torch.Tensor():typeAs(x):resizeAs(dfdx):zero() 46 | state.tmp = torch.Tensor():typeAs(x):resizeAs(dfdx) 47 | end 48 | 49 | -- (4) calculate new (leaky) mean squared values 50 | state.m:mul(alpha) 51 | state.m:addcmul(1.0-alpha, dfdx, dfdx) 52 | 53 | -- (5) perform update 54 | state.tmp:sqrt(state.m):add(epsilon) 55 | x:addcdiv(-lr, dfdx, state.tmp) 56 | 57 | -- return x*, f(x) before optimization 58 | return x, {fx} 59 | end 60 | 
-------------------------------------------------------------------------------- /code/module/rmspropm.lua: -------------------------------------------------------------------------------- 1 | -- RMSProp with momentum as found in "Generating Sequences With Recurrent Neural Networks" 2 | function optim.rmspropm(opfunc, x, config, state) 3 | -- Get state 4 | local config = config or {} 5 | local state = state or config 6 | local lr = config.learningRate or 1e-2 7 | local momentum = config.momentum or 0.95 8 | local epsilon = config.epsilon or 0.01 9 | 10 | -- Evaluate f(x) and df/dx 11 | local fx, dfdx = opfunc(x) 12 | 13 | -- Initialise storage 14 | if not state.g then 15 | state.g = torch.Tensor():typeAs(x):resizeAs(dfdx):zero() 16 | state.gSq = torch.Tensor():typeAs(x):resizeAs(dfdx):zero() 17 | state.tmp = torch.Tensor():typeAs(x):resizeAs(dfdx) 18 | end 19 | 20 | -- g = αg + (1 - α)df/dx 21 | state.g:mul(momentum):add(1 - momentum, dfdx) -- Calculate momentum 22 | -- tmp = df/dx . df/dx 23 | state.tmp:cmul(dfdx, dfdx) 24 | -- gSq = αgSq + (1 - α)df/dx 25 | state.gSq:mul(momentum):add(1 - momentum, state.tmp) -- Calculate "squared" momentum 26 | -- tmp = g . g 27 | state.tmp:cmul(state.g, state.g) 28 | -- tmp = (-tmp + gSq + ε)^0.5 29 | state.tmp:neg():add(state.gSq):add(epsilon):sqrt() 30 | 31 | -- Update x = x - lr x df/dx / tmp 32 | x:addcdiv(-lr, dfdx, state.tmp) 33 | 34 | -- Return x*, f(x) before optimisation 35 | return x, { fx } 36 | end 37 | -------------------------------------------------------------------------------- /code/results/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iassael/learning-to-communicate/1cdfc235e07be9de2af38fb23d78f61b2fa7c99b/code/results/.gitkeep -------------------------------------------------------------------------------- /code/run_colordigit-dial.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | th comm.lua \ 3 | -bs 32 \ 4 | -game ColorDigit \ 5 | -game_level extra_hard_local \ 6 | -game_nagents 2 \ 7 | -game_action_space 2 \ 8 | -game_comm_limited 0 \ 9 | -game_comm_bits 1 \ 10 | -game_comm_sigma 2 \ 11 | -nsteps 2 \ 12 | -gamma 1 \ 13 | -model_dial 1 \ 14 | -model_know_share 1 \ 15 | -model_action_aware 1 \ 16 | -model_rnn_size 128 \ 17 | -learningrate 0.0005 \ 18 | -nepisodes 20000 \ 19 | -step 100 \ 20 | -step_test 10 \ 21 | -step_target 100 \ 22 | -cuda 0 -------------------------------------------------------------------------------- /code/run_colordigit-rial.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | th comm.lua \ 3 | -bs 32 \ 4 | -game ColorDigit \ 5 | -game_level extra_hard_local \ 6 | -game_nagents 2 \ 7 | -game_action_space 2 \ 8 | -game_comm_limited 0 \ 9 | -game_comm_bits 1 \ 10 | -game_comm_sigma 2 \ 11 | -nsteps 2 \ 12 | -gamma 1 \ 13 | -model_dial 0 \ 14 | -model_know_share 1 \ 15 | -model_action_aware 1 \ 16 | -model_rnn_size 128 \ 17 | -learningrate 0.0005 \ 18 | -nepisodes 20000 \ 19 | -step 100 \ 20 | -step_test 10 \ 21 | -step_target 100 \ 22 | -cuda 0 -------------------------------------------------------------------------------- /code/run_colordigit_many_steps-dial.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | th comm.lua \ 3 | -game ColorDigit \ 4 | -game_level many_bits \ 5 | -game_nagents 2 \ 6 | -game_colors 1 \ 7 | -game_action_space 10 \ 8 | -game_comm_limited 0 \ 9 
| -game_comm_bits 1 \ 10 | -game_comm_sigma 0.5 \ 11 | -nsteps 5 \ 12 | -eps 0.05 \ 13 | -gamma 1 \ 14 | -model_dial 1 \ 15 | -model_bn 1 \ 16 | -model_know_share 1 \ 17 | -model_action_aware 1 \ 18 | -model_rnn_size 128 \ 19 | -game_vision_net mlp \ 20 | -bs 32 \ 21 | -learningrate 0.0005 \ 22 | -nepisodes 50000 \ 23 | -step 100 \ 24 | -step_test 10 \ 25 | -step_target 100 \ 26 | -cuda 0 -------------------------------------------------------------------------------- /code/run_colordigit_many_steps-rial.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | th comm.lua \ 3 | -game ColorDigit \ 4 | -game_level many_bits \ 5 | -game_nagents 2 \ 6 | -game_colors 1 \ 7 | -game_action_space 10 \ 8 | -game_comm_limited 0 \ 9 | -game_comm_bits 1 \ 10 | -game_comm_sigma 0.5 \ 11 | -nsteps 5 \ 12 | -eps 0.05 \ 13 | -gamma 1 \ 14 | -model_dial 0 \ 15 | -model_bn 1 \ 16 | -model_know_share 1 \ 17 | -model_action_aware 1 \ 18 | -model_rnn_size 128 \ 19 | -game_vision_net mlp \ 20 | -bs 32 \ 21 | -learningrate 0.0005 \ 22 | -nepisodes 50000 \ 23 | -step 100 \ 24 | -step_test 10 \ 25 | -step_target 100 \ 26 | -cuda 0 -------------------------------------------------------------------------------- /code/run_switch_3-dial.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | th comm.lua \ 3 | -game Switch \ 4 | -game_nagents 3 \ 5 | -game_action_space 2 \ 6 | -game_comm_limited 1 \ 7 | -game_comm_bits 1 \ 8 | -game_comm_sigma 2 \ 9 | -nsteps 6 \ 10 | -gamma 1 \ 11 | -model_dial 1 \ 12 | -model_bn 1 \ 13 | -model_know_share 1 \ 14 | -model_action_aware 1 \ 15 | -model_rnn_size 128 \ 16 | -bs 32 \ 17 | -learningrate 0.0005 \ 18 | -nepisodes 5000 \ 19 | -step 100 \ 20 | -step_test 10 \ 21 | -step_target 100 \ 22 | -cuda 0 23 | -------------------------------------------------------------------------------- /code/run_switch_3-rial.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | th comm.lua \ 3 | -game Switch \ 4 | -game_nagents 3 \ 5 | -game_action_space 2 \ 6 | -game_comm_limited 1 \ 7 | -game_comm_bits 1 \ 8 | -game_comm_sigma 2 \ 9 | -nsteps 6 \ 10 | -gamma 1 \ 11 | -model_dial 0 \ 12 | -model_bn 1 \ 13 | -model_know_share 1 \ 14 | -model_action_aware 1 \ 15 | -model_rnn_size 128 \ 16 | -bs 32 \ 17 | -learningrate 0.0005 \ 18 | -nepisodes 5000 \ 19 | -step 100 \ 20 | -step_test 10 \ 21 | -step_target 100 \ 22 | -cuda 0 23 | -------------------------------------------------------------------------------- /code/run_switch_4-dial.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | th comm.lua \ 3 | -game Switch \ 4 | -game_nagents 4 \ 5 | -game_action_space 2 \ 6 | -game_comm_limited 1 \ 7 | -game_comm_bits 1 \ 8 | -game_comm_sigma 2 \ 9 | -nsteps 6 \ 10 | -gamma 1 \ 11 | -model_dial 1 \ 12 | -model_bn 1 \ 13 | -model_know_share 1 \ 14 | -model_action_aware 1 \ 15 | -model_rnn_size 128 \ 16 | -bs 32 \ 17 | -learningrate 0.0005 \ 18 | -nepisodes 50000 \ 19 | -step 100 \ 20 | -step_test 10 \ 21 | -step_target 100 \ 22 | -cuda 0 23 | -------------------------------------------------------------------------------- /code/run_switch_4-rial.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | th comm.lua \ 3 | -game Switch \ 4 | -game_nagents 4 \ 5 | -game_action_space 2 \ 6 | -game_comm_limited 1 \ 7 | -game_comm_bits 1 \ 8 | -game_comm_sigma 2 \ 9 | 
-nsteps 6 \
10 | -gamma 1 \
11 | -model_dial 0 \
12 | -model_bn 1 \
13 | -model_know_share 1 \
14 | -model_action_aware 1 \
15 | -model_rnn_size 128 \
16 | -bs 32 \
17 | -learningrate 0.0005 \
18 | -nepisodes 50000 \
19 | -step 100 \
20 | -step_test 10 \
21 | -step_target 100 \
22 | -cuda 0
23 | 
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | 
2 | # Learning to Communicate with Deep Multi-Agent Reinforcement Learning
3 | 
4 | Jakob N. Foerster, Yannis M. Assael, Nando de Freitas, Shimon Whiteson
5 | 
6 | ## PyTorch
7 | 
8 | - [PyTorch Implementation by @minqi](https://github.com/minqi/learning-to-communicate-pytorch)
9 | 
10 | - [Simplified PyTorch implementation in a colab by @JainMoksh](https://colab.research.google.com/gist/MJ10/2c0d1972f3dd1edcc3cd17c636aac8d2/dial.ipynb)
11 | 
12 | ## Abstract
13 | 
14 | 
15 |
16 |
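
## Saved results

Each run periodically writes `opt.csv`, `stats.csv`, and a Torch checkpoint
`exp.t7` (containing `{ opt, stats, model.agent }`) under `code/results/`, as
implemented by the `exp.save` functions in `code/model/`. A minimal sketch for
inspecting a finished run (the results directory name here is hypothetical;
see `exp.save_path` for the actual naming scheme):

```lua
require 'torch'
require 'nn'
require 'nngraph'

-- exp.t7 holds { opt, stats, model.agent } as written by exp.save
local opt, stats, agents = unpack(torch.load('results/Switch_3_dial_ABC123/exp.t7'))

print(opt.game, opt.game_nagents, opt.model_rnn_size)
print(#agents) -- game_nagents * (nsteps + 1) unrolled clones
print(unpack(stats.history[#stats.history])) -- latest row of stats.csv
```

Depending on the run, deserialising the agents may also require the custom
modules under `code/module/` to be on the Lua path.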