├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── agents ├── __init__.py ├── baseline_agent.py ├── ddt_agent.py ├── djinn_helpers.py ├── heuristic_agent.py ├── lstm_agent.py ├── prolonet_agent.py ├── py_djinn_agent.py ├── vectorized_prolonet.py └── vectorized_prolonet_helpers.py ├── gym_requirements.txt ├── models └── README.md ├── opt_helpers ├── __init__.py ├── ppo_update.py ├── prolo_from_language.py └── replay_buffer.py ├── python38.txt ├── runfiles ├── gym_runner.py ├── minigame_runner.py ├── sc_build_hellions.py ├── sc_build_hellions_helpers.py ├── sc_helpers.py ├── sc_replays │ ├── .gitignore │ └── .gitkeep ├── sc_runner.py └── visuals │ ├── figures │ └── .gitignore │ ├── visualize_build_hellions.py │ ├── visualize_prolonet.py │ └── visualize_sc_runner.py ├── sc2_requirements.txt └── txts └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | txts/*.txt 134 | models/*.tar 135 | 136 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | Copyright 2020 Andrew Silva 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. 
For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. 
You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
152 | appropriateness of using or redistributing the Work and assume any
153 | risks associated with Your exercise of permissions under this License.
154 |
155 | 8. Limitation of Liability. In no event and under no legal theory,
156 | whether in tort (including negligence), contract, or otherwise,
157 | unless required by applicable law (such as deliberate and grossly
158 | negligent acts) or agreed to in writing, shall any Contributor be
159 | liable to You for damages, including any direct, indirect, special,
160 | incidental, or consequential damages of any character arising as a
161 | result of this License or out of the use or inability to use the
162 | Work (including but not limited to damages for loss of goodwill,
163 | work stoppage, computer failure or malfunction, or any and all
164 | other commercial damages or losses), even if such Contributor
165 | has been advised of the possibility of such damages.
166 |
167 | 9. Accepting Warranty or Additional Liability. While redistributing
168 | the Work or Derivative Works thereof, You may choose to offer,
169 | and charge a fee for, acceptance of support, warranty, indemnity,
170 | or other liability obligations and/or rights consistent with this
171 | License. However, in accepting such obligations, You may act only
172 | on Your own behalf and on Your sole responsibility, not on behalf
173 | of any other Contributor, and only if You agree to indemnify,
174 | defend, and hold each Contributor harmless for any liability
175 | incurred by, or claims asserted against, such Contributor by reason
176 | of your accepting any such warranty or additional liability.
177 |
178 | END OF TERMS AND CONDITIONS
179 |
180 | Licensed under the Apache License, Version 2.0 (the "License");
181 | you may not use this file except in compliance with the License.
182 | You may obtain a copy of the License at
183 |
184 | http://www.apache.org/licenses/LICENSE-2.0
185 |
186 | Unless required by applicable law or agreed to in writing, software
187 | distributed under the License is distributed on an "AS IS" BASIS,
188 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
189 | See the License for the specific language governing permissions and
190 | limitations under the License.
191 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ProLoNets
2 | Public implementation of "Encoding Human Domain Knowledge to Warm Start Reinforcement Learning" from AAAI'21
3 |
4 | ### Requirements
5 |
6 | Refer to the `python38.txt` file for the OpenAI gym environments and the `sc2_requirements.txt` file for the StarCraft II environments. As the name suggests, `python38.txt` builds on Python 3.8.10. In order to work with the SC2 environments, you must have Python >= 3.6; installing the requirements in the `sc2_requirements.txt` file should then be sufficient.
7 |
8 | ### Running Experiments
9 |
10 | All of the code to run various domains lives in the `runfiles/` directory.
11 | All files involve a few command-line arguments, which I'll review now:
12 |
13 | * `-a` or `--agent_type`: Which agent should play through the domain. Details below. Default: `prolo`
14 | * `-e` or `--episodes`: How many episodes to run for. Default: 1000
15 | * `-s` or `--sl_init`: Should the agent be trained via imitation learning first?
Only applies if `agent_type` is `fc`. Default: False
16 | * `-rand`: Should the ProLoNet agent be randomly-initialized? Include the flag to set to `True`.
17 | * `-deep`: Should the ProLoNet include dynamic growth? Include the flag to set to `True`.
18 | * `-adv`: Should the ProLoNet be an "M-Mistake" agent? Include the flag to set to `True`. The probability itself is hard-coded in the ProLoNet file at line 35.
19 | * `--reproduce`: Use pre-specified random seeds for lunar lander and cart pole? Include to indicate `True`, omit for `False`.
20 |
21 | For the `-a` or `--agent_type` flag, valid options are:
22 | * `prolo` for a normal ProLoNet agent
23 | * `random` for random actions (not available in the full game of StarCraft II)
24 | * `heuristic` for the heuristic only (not available in the full game of StarCraft II)
25 | * `fc` for a fully-connected agent
26 | * `lstm` for an LSTM agent
27 | * `djinn` for a DJINN agent
28 |
29 | #### gym_runner.py
30 |
31 | This file runs both of the OpenAI gym domains from the paper, namely cart pole and lunar lander. It has one additional command-line argument:
32 | * `-env` or `--env_type`: Which environment to run. Valid options are `cart` and `lunar`. Default: `cart`
33 |
34 | This script will run with most any version of Python 3 and the required packages. To ensure consistent results with the `--reproduce` flag, you _must_ use Python 3.8.10 with the included `python38.txt` requirements and be on Ubuntu 20.04. Other operating systems have not been tested and may require additional tinkering or random seeding to reproduce results faithfully.
35 |
36 |
37 | Running a ProLoNet agent on lunar lander for 1500 episodes looks like:
38 | ```
39 | python gym_runner.py -a prolo -deep -e 1500 -env lunar
40 | ```
41 | For the _LOKI_ agent:
42 | ```
43 | python gym_runner.py -a fc -e 1500 -env lunar -s
44 | ```
45 |
46 | #### minigame_runner.py
47 |
48 | This file runs the FindAndDefeatZerglings minigame from the SC2LE. Running this is exactly the same as the `gym_runner.py` runfile, except that no `--env_type` flag exists for this domain. You must also have all of the StarCraft II setup complete, which means having a valid copy of StarCraft II, having Python >= 3.6, and installing the requirements from the `sc2_requirements.txt` file. For information on setting up StarCraft II, refer to [Blizzard's Documentation](https://github.com/Blizzard/s2client-proto), and for the minigame itself, you'll need the map from [DeepMind's repo](https://github.com/deepmind/pysc2).
49 |
50 | Running a ProLoNet agent:
51 | ```
52 | python minigame_runner.py -a prolo -deep -e 1000
53 | ```
54 | And a fully-connected agent:
55 | ```
56 | python minigame_runner.py -a fc -e 1000
57 | ```
58 | And an LSTM agent:
59 | ```
60 | python minigame_runner.py -a lstm -e 1000
61 | ```
62 | And a DJINN agent:
63 | ```
64 | python minigame_runner.py -a djinn -e 1000
65 | ```
66 |
67 | #### sc_runner.py
68 |
69 | This file runs the full SC2 game against the in-game AI. The in-game AI difficulty is set on lines 836-838: simply change "Difficulty.VeryEasy" to "Difficulty.Easy", "Difficulty.Medium", or "Difficulty.Hard". Again, you'll need SC2 and all of the requirements for the appropriate Python environment, as discussed above.
70 | Running a ProLoNet agent:
71 | ```
72 | python sc_runner.py -a prolo -e 500
73 | ```
74 | And a random ProLoNet agent:
75 | ```
76 | python sc_runner.py -a prolo -rand -e 500
77 | ```
78 |
79 | #### Citation
80 | If you use this project, please cite our work!
Bibtex below: 81 | ``` 82 | @article{prolonets, 83 | title={Encoding Human Domain Knowledge to Warm Start Reinforcement Learning}, 84 | volume={35}, 85 | url={https://ojs.aaai.org/index.php/AAAI/article/view/16638}, 86 | abstractNote={Deep reinforcement learning has been successful in a variety of tasks, such as game playing and robotic manipulation. However, attempting to learn tabula rasa disregards the logical structure of many domains as well as the wealth of readily available knowledge from domain experts that could help "warm start" the learning process. We present a novel reinforcement learning technique that allows for intelligent initialization of a neural network weights and architecture. Our approach permits the encoding domain knowledge directly into a neural decision tree, and improves upon that knowledge with policy gradient updates. We empirically validate our approach on two OpenAI Gym tasks and two modified StarCraft 2 tasks, showing that our novel architecture outperforms multilayer-perceptron and recurrent architectures. Our knowledge-based framework finds superior policies compared to imitation learning-based and prior knowledge-based approaches. Importantly, we demonstrate that our approach can be used by untrained humans to initially provide >80\% increase in expected reward relative to baselines prior to training (p < 0.001), which results in a >60\% increase in expected reward after policy optimization (p = 0.011).}, 87 | number={6}, 88 | journal={Proceedings of the AAAI Conference on Artificial Intelligence}, 89 | author={Silva, Andrew and Gombolay, Matthew}, 90 | year={2021}, 91 | month={5}, 92 | pages={5042-5050} } 93 | ``` -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # Created by Andrew Silva on 12/27/18 -------------------------------------------------------------------------------- /agents/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CORE-Robotics-Lab/ProLoNets/b1f138ed3520e623a84f2a7d5b1969fa45845700/agents/__init__.py -------------------------------------------------------------------------------- /agents/baseline_agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.distributions import Categorical 4 | import sys 5 | sys.path.insert(0, '../') 6 | from opt_helpers import replay_buffer, ppo_update 7 | import copy 8 | from agents.heuristic_agent import CartPoleHeuristic, LunarHeuristic, \ 9 | StarCraftMacroHeuristic, StarCraftMicroHeuristic 10 | from agents.prolonet_agent import DeepProLoNet 11 | 12 | 13 | class BaselineFCNet(nn.Module): 14 | def __init__(self, input_dim, is_value=False, output_dim=2, hidden_layers=1): 15 | super(BaselineFCNet, self).__init__() 16 | self.lin1 = nn.Linear(input_dim, input_dim) 17 | self.lin2 = None 18 | self.lin3 = nn.Linear(input_dim, output_dim) 19 | self.sig = nn.ReLU() 20 | self.input_dim = input_dim 21 | modules = [] 22 | for h in range(hidden_layers): 23 | modules.append(nn.Linear(input_dim, input_dim)) 24 | if len(modules) > 0: 25 | self.lin2 = nn.Sequential(*modules) 26 | self.softmax = nn.Softmax(dim=1) 27 | self.is_value = is_value 28 | 29 | def forward(self, input_data): 30 | if self.lin2 is not None: 31 | act_out = self.lin3(self.sig(self.lin2(self.sig(self.lin1(input_data))))) 32 | else: 33 | act_out = 
self.lin3(self.sig(self.lin1(input_data))) 34 | if self.is_value: 35 | return act_out 36 | else: 37 | return self.softmax(act_out) 38 | 39 | 40 | class FCNet: 41 | def __init__(self, 42 | bot_name='FCNet', 43 | input_dim=4, 44 | output_dim=2, 45 | sl_init=False, 46 | num_hidden=1 47 | ): 48 | self.bot_name = bot_name + str(num_hidden) + '_hid' 49 | self.sl_init = sl_init 50 | self.input_dim = input_dim 51 | self.output_dim = output_dim 52 | self.num_hidden = num_hidden 53 | self.replay_buffer = replay_buffer.ReplayBufferSingleAgent() 54 | self.action_network = BaselineFCNet(input_dim=input_dim, 55 | output_dim=output_dim, 56 | is_value=False, 57 | hidden_layers=num_hidden) 58 | self.value_network = BaselineFCNet(input_dim=input_dim, 59 | output_dim=output_dim, 60 | is_value=True, 61 | hidden_layers=num_hidden) 62 | if self.sl_init: 63 | if input_dim == 4: 64 | self.teacher = CartPoleHeuristic() 65 | self.teacher = DeepProLoNet(distribution='one_hot', 66 | input_dim=input_dim, 67 | output_dim=output_dim, 68 | use_gpu=False, 69 | vectorized=False, 70 | randomized=False, 71 | adversarial=False, 72 | deepen=False, 73 | deterministic=True, 74 | ) 75 | self.action_loss_threshold = 50 76 | elif input_dim == 6: # Fire Sim 77 | self.teacher = DeepProLoNet(distribution='one_hot', 78 | input_dim=input_dim, 79 | output_dim=output_dim, 80 | use_gpu=False, 81 | vectorized=False, 82 | randomized=False, 83 | adversarial=False, 84 | deepen=False, 85 | deterministic=True, 86 | ) 87 | self.action_loss_threshold = 50 88 | elif input_dim == 12: # Build Marines 89 | self.teacher = DeepProLoNet(distribution='one_hot', 90 | input_dim=input_dim, 91 | output_dim=output_dim, 92 | use_gpu=False, 93 | vectorized=False, 94 | randomized=False, 95 | adversarial=False, 96 | deepen=False, 97 | deterministic=True, 98 | ) 99 | self.action_loss_threshold = 50 100 | elif input_dim == 8: 101 | self.teacher = LunarHeuristic() 102 | self.action_loss_threshold = 100 103 | elif input_dim == 28: 104 | self.teacher = DeepProLoNet(distribution='one_hot', 105 | input_dim=input_dim, 106 | output_dim=output_dim, 107 | use_gpu=False, 108 | vectorized=False, 109 | randomized=False, 110 | adversarial=False, 111 | deepen=False, 112 | deterministic=True, 113 | ) 114 | self.teacher.load() 115 | self.action_loss_threshold = 50 116 | elif input_dim == 37: 117 | self.teacher = StarCraftMicroHeuristic() 118 | self.action_loss_threshold = 50 119 | elif input_dim > 100: 120 | self.teacher = StarCraftMacroHeuristic() 121 | self.action_loss_threshold = 1000 122 | self.bot_name += '_SLtoRL_' 123 | self.ppo = ppo_update.PPO([self.action_network, self.value_network], two_nets=True) 124 | self.actor_opt = torch.optim.RMSprop(self.action_network.parameters(), lr=5e-3) 125 | self.value_opt = torch.optim.RMSprop(self.value_network.parameters(), lr=5e-3) 126 | # self.ppo.actor_opt = self.actor_opt 127 | # self.ppo.critic_opt = self.value_opt 128 | 129 | self.last_state = [0, 0, 0, 0] 130 | self.last_action = 0 131 | self.last_action_probs = torch.Tensor([0]) 132 | self.last_value_pred = torch.Tensor([[0, 0]]) 133 | self.last_deep_action_probs = torch.Tensor([0]) 134 | self.last_deep_value_pred = torch.Tensor([[0, 0]]) 135 | self.full_probs = None 136 | self.reward_history = [] 137 | self.num_steps = 0 138 | 139 | def get_action(self, observation): 140 | with torch.no_grad(): 141 | obs = torch.Tensor(observation) 142 | obs = obs.view(1, -1) 143 | self.last_state = obs 144 | probs = self.action_network(obs) 145 | value_pred = self.value_network(obs) 146 | probs 
= probs.view(-1) 147 | self.full_probs = probs 148 | if self.action_network.input_dim > 30: 149 | probs, inds = torch.topk(probs, 3) 150 | m = Categorical(probs) 151 | action = m.sample() 152 | # # Epsilon learning 153 | # action = torch.argmax(probs) 154 | # if random.random() < self.epsilon: 155 | # action = torch.LongTensor([random.choice(np.arange(0, self.output_dim, dtype=np.int))]) 156 | # # End Epsilon learning 157 | log_probs = m.log_prob(action) 158 | self.last_action_probs = log_probs 159 | self.last_value_pred = value_pred.view(-1).cpu() 160 | 161 | if self.action_network.input_dim > 30: 162 | self.last_action = inds[action] 163 | else: 164 | self.last_action = action 165 | if self.action_network.input_dim > 30: 166 | action = inds[action].item() 167 | else: 168 | action = action.item() 169 | return action 170 | 171 | def save_reward(self, reward): 172 | self.replay_buffer.insert(obs=[self.last_state], 173 | action_log_probs=self.last_action_probs, 174 | value_preds=self.last_value_pred[self.last_action.item()], 175 | last_action=self.last_action, 176 | full_probs_vector=self.full_probs, 177 | rewards=reward) 178 | return True 179 | 180 | def end_episode(self, timesteps, num_processes=1): 181 | self.reward_history.append(timesteps) 182 | if self.sl_init and self.num_steps < self.action_loss_threshold: 183 | action_loss = self.ppo.sl_updates(self.replay_buffer, self, self.teacher) 184 | else: 185 | value_loss, action_loss = self.ppo.batch_updates(self.replay_buffer, self) 186 | bot_name = '../txts/' + self.bot_name + str(num_processes) + '_processes' 187 | self.num_steps += 1 188 | with open(bot_name + '_rewards.txt', 'a') as myfile: 189 | myfile.write(str(timesteps) + '\n') 190 | 191 | def lower_lr(self): 192 | for param_group in self.ppo.actor_opt.param_groups: 193 | param_group['lr'] = param_group['lr'] * 0.5 194 | for param_group in self.ppo.critic_opt.param_groups: 195 | param_group['lr'] = param_group['lr'] * 0.5 196 | 197 | def reset(self): 198 | self.replay_buffer.clear() 199 | 200 | def deepen_networks(self): 201 | pass 202 | 203 | def save(self, fn='last'): 204 | checkpoint = dict() 205 | checkpoint['actor'] = self.action_network.state_dict() 206 | checkpoint['value'] = self.value_network.state_dict() 207 | torch.save(checkpoint, fn+self.bot_name+'.pth.tar') 208 | 209 | def load(self, fn='last'): 210 | # fn = fn + self.bot_name + '.pth.tar' 211 | model_checkpoint = torch.load(fn, map_location='cpu') 212 | actor_data = model_checkpoint['actor'] 213 | value_data = model_checkpoint['value'] 214 | self.action_network.load_state_dict(actor_data) 215 | self.value_network.load_state_dict(value_data) 216 | 217 | def __getstate__(self): 218 | return { 219 | # 'replay_buffer': self.replay_buffer, 220 | 'action_network': self.action_network, 221 | 'value_network': self.value_network, 222 | 'ppo': self.ppo, 223 | 'actor_opt': self.actor_opt, 224 | 'value_opt': self.value_opt, 225 | 'num_hidden': self.num_hidden 226 | } 227 | 228 | def __setstate__(self, state): 229 | self.action_network = copy.deepcopy(state['action_network']) 230 | self.value_network = copy.deepcopy(state['value_network']) 231 | self.ppo = copy.deepcopy(state['ppo']) 232 | self.actor_opt = copy.deepcopy(state['actor_opt']) 233 | self.value_opt = copy.deepcopy(state['value_opt']) 234 | self.num_hidden = copy.deepcopy(state['num_hidden']) 235 | 236 | def duplicate(self): 237 | new_agent = FCNet( 238 | bot_name=self.bot_name, 239 | input_dim=self.input_dim, 240 | output_dim=self.output_dim, 241 | 
sl_init=self.sl_init, 242 | num_hidden=self.num_hidden 243 | ) 244 | new_agent.__setstate__(self.__getstate__()) 245 | return new_agent 246 | -------------------------------------------------------------------------------- /agents/ddt_agent.py: -------------------------------------------------------------------------------- 1 | # Created by Andrew Silva on 8/28/19 2 | import torch 3 | from torch.distributions import Categorical 4 | import sys 5 | sys.path.insert(0, '../') 6 | from agents.vectorized_prolonet import ProLoNet 7 | from opt_helpers import replay_buffer, ppo_update 8 | import os 9 | import numpy as np 10 | 11 | 12 | def save_ddt(fn, model): 13 | checkpoint = dict() 14 | mdl_data = dict() 15 | mdl_data['weights'] = model.layers 16 | mdl_data['comparators'] = model.comparators 17 | mdl_data['leaf_init_information'] = model.leaf_init_information 18 | mdl_data['action_probs'] = model.action_probs 19 | mdl_data['alpha'] = model.alpha 20 | mdl_data['input_dim'] = model.input_dim 21 | mdl_data['is_value'] = model.is_value 22 | checkpoint['model_data'] = mdl_data 23 | torch.save(checkpoint, fn) 24 | 25 | 26 | def load_ddt(fn): 27 | model_checkpoint = torch.load(fn, map_location='cpu') 28 | model_data = model_checkpoint['model_data'] 29 | init_weights = [weight.detach().clone().data.cpu().numpy() for weight in model_data['weights']] 30 | init_comparators = [comp.detach().clone().data.cpu().numpy() for comp in model_data['comparators']] 31 | 32 | new_model = ProLoNet(input_dim=model_data['input_dim'], 33 | weights=init_weights, 34 | comparators=init_comparators, 35 | leaves=model_data['leaf_init_information'], 36 | alpha=model_data['alpha'].item(), 37 | is_value=model_data['is_value']) 38 | new_model.action_probs = model_data['action_probs'] 39 | return new_model 40 | 41 | 42 | def init_rule_list(num_rules, dim_in, dim_out): 43 | weights = np.random.rand(num_rules, dim_in) 44 | leaves = [] 45 | comparators = np.random.rand(num_rules, 1) 46 | for leaf_index in range(num_rules): 47 | leaves.append([[leaf_index], np.arange(0, leaf_index).tolist(), np.random.rand(dim_out)]) 48 | leaves.append([[], np.arange(0, num_rules).tolist(), np.random.rand(dim_out)]) 49 | return weights, comparators, leaves 50 | 51 | 52 | class DDTAgent: 53 | def __init__(self, 54 | bot_name='DDT', 55 | input_dim=4, 56 | output_dim=2, 57 | rule_list=False, 58 | num_rules=4): 59 | self.replay_buffer = replay_buffer.ReplayBufferSingleAgent() 60 | self.bot_name = bot_name 61 | self.rule_list = rule_list 62 | self.output_dim = output_dim 63 | self.input_dim = input_dim 64 | self.num_rules = num_rules 65 | if rule_list: 66 | self.bot_name += str(num_rules)+'_rules' 67 | init_weights, init_comparators, init_leaves = init_rule_list(num_rules, input_dim, output_dim) 68 | else: 69 | init_weights = None 70 | init_comparators = None 71 | init_leaves = num_rules 72 | self.bot_name += str(num_rules) + '_leaves' 73 | self.action_network = ProLoNet(input_dim=input_dim, 74 | output_dim=output_dim, 75 | weights=init_weights, 76 | comparators=init_comparators, 77 | leaves=init_leaves, 78 | alpha=1, 79 | is_value=False, 80 | use_gpu=False) 81 | self.value_network = ProLoNet(input_dim=input_dim, 82 | output_dim=output_dim, 83 | weights=init_weights, 84 | comparators=init_comparators, 85 | leaves=init_leaves, 86 | alpha=1, 87 | is_value=True, 88 | use_gpu=False) 89 | 90 | self.ppo = ppo_update.PPO([self.action_network, self.value_network], two_nets=True, use_gpu=False) 91 | 92 | self.last_state = [0, 0, 0, 0] 93 | self.last_action = 0 94 | 
self.last_action_probs = torch.Tensor([0]) 95 | self.last_value_pred = torch.Tensor([[0, 0]]) 96 | self.last_deep_action_probs = None 97 | self.last_deep_value_pred = [None]*output_dim 98 | self.full_probs = None 99 | self.deeper_full_probs = None 100 | self.reward_history = [] 101 | self.num_steps = 0 102 | 103 | def get_action(self, observation): 104 | with torch.no_grad(): 105 | obs = torch.Tensor(observation) 106 | obs = obs.view(1, -1) 107 | self.last_state = obs 108 | # if self.use_gpu: 109 | # obs = obs.cuda() 110 | probs = self.action_network(obs) 111 | value_pred = self.value_network(obs) 112 | probs = probs.view(-1).cpu() 113 | self.full_probs = probs 114 | if self.action_network.input_dim > 10: 115 | probs, inds = torch.topk(probs, 3) 116 | m = Categorical(probs) 117 | action = m.sample() 118 | log_probs = m.log_prob(action) 119 | self.last_action_probs = log_probs.cpu() 120 | self.last_value_pred = value_pred.view(-1).cpu() 121 | 122 | if self.action_network.input_dim > 10: 123 | self.last_action = inds[action].cpu() 124 | else: 125 | self.last_action = action.cpu() 126 | if self.action_network.input_dim > 10: 127 | action = inds[action].item() 128 | else: 129 | action = action.item() 130 | return action 131 | 132 | def save_reward(self, reward): 133 | self.replay_buffer.insert(obs=[self.last_state], 134 | action_log_probs=self.last_action_probs, 135 | value_preds=self.last_value_pred[self.last_action.item()], 136 | deeper_action_log_probs=self.last_deep_action_probs, 137 | deeper_value_pred=self.last_deep_value_pred[self.last_action.item()], 138 | last_action=self.last_action, 139 | full_probs_vector=self.full_probs, 140 | deeper_full_probs_vector=self.deeper_full_probs, 141 | rewards=reward) 142 | return True 143 | 144 | def end_episode(self, timesteps, num_processes): 145 | value_loss, action_loss = self.ppo.batch_updates(self.replay_buffer, self,) 146 | self.num_steps += 1 147 | # Copy over new decision node params from shallower network to deeper network 148 | bot_name = '../txts/' + self.bot_name 149 | with open(bot_name + '_rewards.txt', 'a') as myfile: 150 | myfile.write(str(timesteps) + '\n') 151 | 152 | def lower_lr(self): 153 | for param_group in self.ppo.actor_opt.param_groups: 154 | param_group['lr'] = param_group['lr'] * 0.5 155 | for param_group in self.ppo.critic_opt.param_groups: 156 | param_group['lr'] = param_group['lr'] * 0.5 157 | 158 | def reset(self): 159 | self.replay_buffer.clear() 160 | 161 | def save(self, fn='last'): 162 | act_fn = fn + self.bot_name + '_actor_' + '.pth.tar' 163 | val_fn = fn + self.bot_name + '_critic_' + '.pth.tar' 164 | 165 | save_ddt(act_fn, self.action_network) 166 | save_ddt(val_fn, self.value_network) 167 | 168 | def load(self, fn='last'): 169 | act_fn = fn + self.bot_name + '_actor_' + '.pth.tar' 170 | val_fn = fn + self.bot_name + '_critic_' + '.pth.tar' 171 | 172 | if os.path.exists(act_fn): 173 | self.action_network = load_ddt(act_fn) 174 | self.value_network = load_ddt(val_fn) 175 | 176 | def deepen_networks(self): 177 | pass 178 | 179 | def __getstate__(self): 180 | return { 181 | 'action_network': self.action_network, 182 | 'value_network': self.value_network, 183 | 'ppo': self.ppo, 184 | # 'actor_opt': self.actor_opt, 185 | # 'value_opt': self.value_opt, 186 | 'bot_name': self.bot_name, 187 | 'rule_list': self.rule_list, 188 | 'output_dim': self.output_dim, 189 | 'input_dim': self.input_dim, 190 | 'num_rules': self.num_rules 191 | } 192 | 193 | def __setstate__(self, state): 194 | for key in state: 195 | setattr(self, 
key, state[key]) 196 | 197 | def duplicate(self): 198 | new_agent = DDTAgent(bot_name=self.bot_name, 199 | input_dim=self.input_dim, 200 | output_dim=self.output_dim, 201 | rule_list=self.rule_list, 202 | num_rules=self.num_rules 203 | ) 204 | new_agent.__setstate__(self.__getstate__()) 205 | return new_agent 206 | -------------------------------------------------------------------------------- /agents/djinn_helpers.py: -------------------------------------------------------------------------------- 1 | # Created by Andrew Silva on 10/9/18 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | # from sklearn.metrics import mean_squared_error 6 | # from sklearn.model_selection import train_test_split 7 | from torch.autograd import Variable 8 | 9 | use_gpu = torch.cuda.is_available() 10 | 11 | DJINN_TREE_DATA = { 12 | 'cart': { 13 | 'feature': np.array([[0], [0], [2], [-2], [-2], [2], [-2], [1], [2], [-2], [-2], [3], 14 | [-2], [-2], [2], [1], [3], [-2], [-2], [2], [-2], [-2], [-2]]), 15 | 'children_left': np.array([1, 2, 3, -1, -1, 6, -1, 8, 9, -1, -1, 12, -1, -1, 15, 16, 17, -1, -1, 20, -1, -1, -1]), 16 | 'children_right': np.array([14, 5, 4, -1, -1, 7, -1, 11, 10, -1, -1, 13, -1, -1, 22, 19, 18, -1, -1, 21, -1, -1, -1]) 17 | 18 | # 'children_right': np.array( 19 | # [1, 2, 3, -1, -1, 6, -1, 8, 9, -1, -1, 12, -1, -1, 15, 16, 17, -1, -1, 20, -1, -1, -1]), 20 | # 'children_left': np.array( 21 | # [14, 5, 4, -1, -1, 7, -1, 11, 10, -1, -1, 13, -1, -1, 22, 19, 18, -1, -1, 21, -1, -1, -1]) 22 | }, 23 | 'lunar': { 24 | 'feature': np.array([[1], [3], [5], [6, 7], [-2], [4], [-2], [-2], [-2], [6, 7], [-2], [0], [-2], [0], [-2], 25 | [-2], [5], [5], [6, 7], [-2], [-2], [0], [-2], [0], [-2], [-2], [6, 7], [-2], [-2]]), 26 | 'children_left': np.array([1, 2, 3, 4, -1, 6, -1, -1, -1, 10, -1, 12, -1, 14, -1, -1, 27 | 17, 18, 19, -1, -1, 22, -1, 24, -1, -1, 27, -1, -1]), 28 | 'children_right': np.array([16, 9, 8, 5, -1, 7, -1, -1, -1, 11, -1, 13, -1, 15, -1, 29 | -1, 26, 21, 20, -1, -1, 23, -1, 25, -1, -1, 28, -1, -1]) 30 | # 'children_right': np.array([1, 2, 3, 4, -1, 6, -1, -1, -1, 10, -1, 12, -1, 14, -1, 31 | # -1, 17, 18, 19, -1, -1, 22, -1, 24, -1, -1, 27, -1, -1]), 32 | # 'children_left': np.array([16, 9, 8, 5, -1, 7, -1, -1, -1, 11, -1, 13, -1, 15, -1, 33 | # -1, 26, 21, 20, -1, -1, 23, -1, 25, -1, -1, 28, -1, -1]) 34 | 35 | }, 36 | 'sc_micro': { 37 | 'feature': np.array([[14], [-2], [1], [0], [-2], [0], [-2], [-2], [1], [-2], [0], [-2], [-2]]), 38 | # 'children_left': np.array([1, -1, 3, 4, -1, 6, -1, -1, 9, -1, 11, -1, -1]), 39 | # 'children_right': np.array([2, -1, 8, 5, -1, 7, -1, -1, 10, -1, 12, -1, -1]) 40 | 'children_right': np.array([1, -1, 3, 4, -1, 6, -1, -1, 9, -1, 11, -1, -1]), 41 | 'children_left': np.array([2, -1, 8, 5, -1, 7, -1, -1, 10, -1, 12, -1, -1]) 42 | 43 | }, 44 | 'sc_macro': { 45 | 'feature': np.array([[10, 12], [45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 46 | 62, 63, 64, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 47 | 96, 97, 98, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 48 | 130, 131, 132, 133, 134, 135, 136, 137, 138], 49 | [-2], [65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 99, 100, 50 | 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 51 | 117, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 52 | 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 53 | 170, 171, 172, 173, 174, 175, 176, 177, 
178, 179, 180, 181], 54 | [-2], [-2], [4], [-2], [3, 2], [-2], [9, 157], [-2], [30, 178], [38, 186], [22, 170], [-2], 55 | [-2], [38, 186], [-2], [-2], [31, 179], [10, 158], [34, 182], [-2], [-2], [-2], [-2]]), 56 | 'children_left': np.array([1, 2, -1, 4, -1, -1, 7, -1, 9, -1, 11, -1, 13, 14, 15, 57 | -1, -1, 18, -1, -1, 21, 22, 23, -1, -1, -1, -1]), 58 | 'children_right': np.array([6, 3, -1, 5, -1, -1, 8, -1, 10, -1, 12, -1, 20, 17, 59 | 16, -1, -1, 19, -1, -1, 26, 25, 24, -1, -1, -1, -1]) 60 | }, 61 | 'fire': { 62 | 'feature': np.array([[4], [0], [-2], [0], [-2], [1], [-2], [1], [-2], [-2], [2], 63 | [-2], [2], [-2], [3], [-2], [3], [-2], [-2]]), 64 | # 'children_left': np.array([1, 2, -1, 4, -1, 6, -1, 8, -1, -1, 11, -1, 13, -1, 15, -1, 17, -1, -1]), 65 | # 'children_right': np.array([10, 3, -1, 5, -1, 7, -1, 9, -1, -1, 12, -1, 14, -1, 16, -1, 18, -1, -1]) 66 | 'children_right': np.array([1, 2, -1, 4, -1, 6, -1, 8, -1, -1, 11, -1, 13, -1, 15, -1, 17, -1, -1]), 67 | 'children_left': np.array([10, 3, -1, 5, -1, 7, -1, 9, -1, -1, 12, -1, 14, -1, 16, -1, 18, -1, -1]) 68 | } 69 | } 70 | 71 | 72 | def save_checkpoint(djinn_model, filename='djinn.pth.tar'): 73 | """ 74 | Helper function to save checkpoints of PyTorch models 75 | :param state: Everything to save with the checkpoint (params/weights, optimizer state, epoch, loss function, etc.) 76 | :param filename: Filename to save the checkpoint under 77 | :return: None 78 | """ 79 | if type(djinn_model) is not list: 80 | djinn_model = [djinn_model] 81 | master_list = [] 82 | for model in djinn_model: 83 | state_dict = model.state_dict() 84 | master_list.append(state_dict) 85 | torch.save(master_list, filename) 86 | 87 | 88 | def load_checkpoint(filename='tree0.pth.tar', drop_prob=0.0): 89 | if not use_gpu: 90 | model_checkpoint_master = torch.load(filename, map_location='cpu') 91 | else: 92 | model_checkpoint_master = torch.load(filename) 93 | master_model_list = [] 94 | for model_checkpoint in model_checkpoint_master: 95 | num_layers = len(model_checkpoint.keys())//2 96 | fake_weights = [] 97 | for i in range(num_layers): 98 | fake_weights.append(np.random.random_sample([4, 4])) 99 | new_net = PyDJINN(5, fake_weights, fake_weights) 100 | for index in range(0, num_layers-1): 101 | # weight_key = 'weight'+str(index) 102 | # bias_key = 'bias'+str(index) 103 | weight_key_value_pair = model_checkpoint.popitem(last=False) 104 | bias_key_value_pair = model_checkpoint.popitem(last=False) 105 | new_layer = nn.Linear(4, 4) 106 | new_layer.weight.data = weight_key_value_pair[1] 107 | new_layer.bias.data = bias_key_value_pair[1] 108 | new_seq = nn.Sequential( 109 | new_layer, 110 | nn.ReLU(), 111 | nn.Dropout(drop_prob) 112 | ) 113 | new_net.layers[index] = new_seq 114 | new_net.final_layer.weight.data = model_checkpoint['final_layer.weight'] 115 | new_net.final_layer.bias.data = model_checkpoint['final_layer.bias'] 116 | master_model_list.append(new_net) 117 | return master_model_list 118 | 119 | def xavier_init(dim_in, dim_out): 120 | dist = np.random.normal(0.0, scale=np.sqrt(3.0/(dim_in+dim_out))) 121 | return dist 122 | 123 | 124 | def tree_to_nn_weights(dim_in, dim_out, tree_dict): 125 | """ 126 | :param dim_in: input data (batch first) 127 | :param dim_out: output data (batch first) 128 | :param tree_dict: dictionary of features, left children, right children. 
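        (Format appears to follow the DJINN_TREE_DATA entries above: 'feature' lists the feature
        indices tested at each node, with [-2] marking a leaf; 'children_left' / 'children_right'
        hold child node ids, with -1 at leaves.)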
129 | :return: 130 | """ 131 | 132 | tree_to_net = { 133 | 'input_dim': dim_in, 134 | 'output_dim': dim_out, 135 | 'net_shape': {}, 136 | 'weights': {}, 137 | 'biases': {} 138 | } 139 | 140 | features = tree_dict['feature'] 141 | children_left = tree_dict['children_left'] 142 | children_right = tree_dict['children_right'] 143 | num_nodes = len(features) 144 | 145 | node_depth = np.zeros(num_nodes, dtype=np.int64) 146 | is_leaves = np.zeros(num_nodes, dtype=np.int64) 147 | stack = [(0, -1)] # node id and parent depth of the root (0th id, no parents means -1...) 148 | while len(stack) > 0: 149 | # Recurse through all nodes, adding all left nodes and then right nodes to the stack 150 | # (we go all the way left before going down any right,then we go right from the bottom-up) 151 | node_id, parent_depth = stack.pop() 152 | node_depth[node_id] = parent_depth + 1 153 | if children_left[node_id] != children_right[node_id]: 154 | stack.append((children_left[node_id], parent_depth+1)) 155 | stack.append((children_right[node_id], parent_depth+1)) 156 | else: 157 | # If left == right, they're both -1 and it's a leaf 158 | is_leaves[node_id] = 1 159 | 160 | # This bit appears to just recreate the information above, but more centralized (attach all nodes to their depth 161 | # and make sure they have children info attached) 162 | node_dict = {} 163 | for i in range(len(features)): 164 | node_dict[i] = {} 165 | node_dict[i]['depth'] = node_depth[i] 166 | if np.any([f > 0 for f in features[i]]): 167 | # if features[i] >= 0: 168 | node_dict[i]['features'] = features[i] 169 | else: 170 | node_dict[i]['features'] = [-2] 171 | node_dict[i]['child_left'] = features[children_left[i]] 172 | node_dict[i]['child_right'] = features[children_right[i]] 173 | 174 | num_layers = len(np.unique(node_depth)) 175 | nodes_per_level = np.zeros(num_layers) 176 | leaves_per_level = np.zeros(num_layers) 177 | 178 | # For each layer, get number of nodes at that layer. If feature is below zero, it's a leaf. Otherwise, it's a node. 179 | for i in range(num_layers): 180 | ind = np.where(node_depth == i)[0] 181 | all_feats = [] 182 | for f in features[ind]: 183 | all_feats.extend(f) 184 | all_feats = np.array(all_feats) 185 | nodes_per_level[i] = len(np.where(all_feats >= 0)[0]) 186 | leaves_per_level[i] = len(np.where(all_feats < 0)[0]) 187 | 188 | max_depth_feature = np.zeros(dim_in) 189 | # Find the deepest feature... for some reason? 
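    # (The per-feature max depth appears to be used below to decide how many layers must keep the
    #  identity pass-through weights, djinn_weights[layer][f, f] = 1.0, so that raw input f remains
    #  available to the deepest tree node that still splits on it.)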
190 | for i in range(len(max_depth_feature)): 191 | ind = np.where(features == i)[0] 192 | if len(ind) > 0: 193 | max_depth_feature[i] = np.max(node_depth[ind]) 194 | 195 | djinn_arch = np.zeros(num_layers, dtype=np.int64) 196 | 197 | djinn_arch[0] = dim_in 198 | for i in range(1, num_layers): 199 | djinn_arch[i] = djinn_arch[i-1] + nodes_per_level[i] 200 | djinn_arch[-1] = dim_out 201 | 202 | djinn_weights = {} 203 | for i in range(num_layers-1): 204 | djinn_weights[i] = np.zeros((djinn_arch[i+1], djinn_arch[i])) 205 | 206 | new_indices = [] 207 | for i in range(num_layers-1): 208 | input_dim = djinn_arch[i] 209 | output_dim = djinn_arch[i+1] 210 | new_indices.append(np.arange(input_dim, output_dim)) 211 | for f in range(dim_in): 212 | if i < max_depth_feature[f]-1: 213 | djinn_weights[i][f, f] = 1.0 214 | input_index = 0 215 | output_index = 0 216 | for index, node in node_dict.items(): 217 | if node['depth'] != i or node['features'][0] < 0: 218 | continue 219 | feature = node['features'] 220 | left = node['child_left'] 221 | right = node['child_right'] 222 | if index == 0 and (left[0] < 0 or right[0] < 0): 223 | for j in range(i, num_layers-2): 224 | djinn_weights[j][feature, feature] = 1.0 225 | djinn_weights[num_layers-2][:, feature] = 1.0 226 | if left[0] >= 0: 227 | if i == 0: 228 | djinn_weights[i][new_indices[i][input_index], 229 | feature] = xavier_init(input_dim, output_dim) 230 | else: 231 | djinn_weights[i][new_indices[i][input_index], 232 | new_indices[i-1][output_index]] = xavier_init(input_dim, output_dim) 233 | 234 | djinn_weights[i][new_indices[i][input_index], left] = xavier_init(input_dim, output_dim) 235 | input_index += 1 236 | # TODO: comment below? 237 | if output_index >= len(new_indices[i-1]): 238 | output_index = 0 239 | 240 | if left[0] < 0 and index != 0: 241 | leaf_ind = new_indices[i-1][output_index] 242 | for j in range(i, num_layers-2): 243 | djinn_weights[j][leaf_ind, leaf_ind] = 1.0 244 | djinn_weights[num_layers-2][:, leaf_ind] = 1.0 245 | 246 | if right[0] >= 0: 247 | if i == 0: 248 | djinn_weights[i][new_indices[i][input_index], 249 | feature] = xavier_init(input_dim, output_dim) 250 | else: 251 | djinn_weights[i][new_indices[i][input_index], 252 | new_indices[i-1][output_index]] = xavier_init(input_dim, output_dim) 253 | 254 | djinn_weights[i][new_indices[i][input_index], right] = xavier_init(input_dim, output_dim) 255 | input_index += 1 256 | if output_index >= len(new_indices[i-1]): 257 | output_index = 0 258 | 259 | if right[0] < 0 and index != 0: 260 | leaf_ind = new_indices[i-1][output_index] 261 | for j in range(i, num_layers-2): 262 | djinn_weights[j][leaf_ind:leaf_ind] = 1.0 263 | djinn_weights[num_layers-2][:, leaf_ind] = 1.0 264 | output_index += 1 265 | 266 | m = len(new_indices[-2]) 267 | ind = np.where(abs(djinn_weights[num_layers-3][:, -m:]) > 0)[0] 268 | for indices in range(len(djinn_weights[num_layers-2][:, ind])): 269 | djinn_weights[num_layers-2][indices, ind] = xavier_init(input_dim, output_dim) 270 | 271 | # Convert into a single array because we're not ensembling 272 | n_hidden = {} 273 | for i in range(1, len(djinn_arch) - 1): 274 | n_hidden[i] = djinn_arch[i] 275 | 276 | # existing weights from above, biases could be improved... 
ignoring for now 277 | w = [] 278 | b = [] 279 | for i in range(0, len(djinn_arch) - 1): 280 | w.append(djinn_weights[i].astype(np.float32)) 281 | # b.append(xavier_init(n_hidden[i], n_hidden[i+1])) # This might be wrong 282 | tree_to_net['net_shape'] = djinn_arch 283 | tree_to_net['weights'] = w 284 | tree_to_net['biases'] = [] # biases? 285 | 286 | return tree_to_net 287 | 288 | 289 | class PyDJINN(nn.Module): 290 | def __init__(self, input_dim, weights, biases, drop_prob=0.5, is_value=False): 291 | super(PyDJINN, self).__init__() 292 | self.layers = nn.ModuleList() 293 | self.input_dim = input_dim 294 | self.layers.append(nn.Linear(input_dim, weights[0].shape[0])) 295 | weight_inits = torch.Tensor(weights[0]) 296 | weight_inits.requires_grad = True 297 | self.layers[0].weight.data = weight_inits 298 | # self.layers[0].bias.data.fill_(biases[0]) 299 | last_dim = weights[0].shape[0] 300 | for index in range(1, len(weights)-1): 301 | new_linear_layer = nn.Linear(last_dim, weights[index].shape[0]) 302 | weight_inits = torch.Tensor(weights[index]) 303 | weight_inits.requires_grad = True 304 | new_linear_layer.weight.data = weight_inits 305 | # new_linear_layer.bias.data.fill_(biases[index]) 306 | new_layer = nn.Sequential( 307 | new_linear_layer, 308 | nn.ReLU(), 309 | nn.Dropout(drop_prob) 310 | ) 311 | self.layers.append(new_layer) 312 | last_dim = weights[index].shape[0] 313 | self.final_layer = nn.Linear(last_dim, weights[-1].shape[0]) 314 | weight_inits = torch.Tensor(weights[-1]) 315 | weight_inits.requires_grad = True 316 | self.final_layer.weight.data = weight_inits 317 | # self.final_layer.bias.data.fill_(biases['out']) 318 | self.softmax = nn.Softmax(dim=1) 319 | self.is_value = is_value 320 | 321 | def forward(self, input_data): 322 | for layer in self.layers: 323 | input_data = layer(input_data) 324 | # print(input_data) 325 | if self.is_value: 326 | return self.final_layer(input_data) 327 | else: 328 | return self.softmax(self.final_layer(input_data)) 329 | -------------------------------------------------------------------------------- /agents/heuristic_agent.py: -------------------------------------------------------------------------------- 1 | import random 2 | import sys 3 | sys.path.insert(0, '../') 4 | 5 | class CartPoleHeuristic: 6 | def __init__(self, 7 | bot_name='CartPoleHeuristic', 8 | params=None): 9 | self.bot_name = '../txts/'+bot_name 10 | from opt_helpers import replay_buffer 11 | self.replay_buffer = replay_buffer.ReplayBufferSingleAgent() 12 | 13 | def get_action(self, observation): 14 | ## THIS IS THE REAL PROLO INIT 15 | action = random.choice([0, 1]) 16 | if observation[0] > -1: 17 | if observation[0] < 1: 18 | if observation[2] < 0: 19 | action = 0 20 | else: 21 | action = 1 22 | else: 23 | if observation[2] < 0: 24 | action = 0 25 | else: 26 | if observation[1] > 0: 27 | if observation[3] < 0: 28 | action = 0 29 | else: 30 | action = 1 31 | else: 32 | if observation[2] < 0: 33 | action = 0 34 | else: 35 | action = 1 36 | else: 37 | if observation[2] < 0: 38 | if observation[1] < 0: 39 | if observation[3] > 0: 40 | action = 1 41 | else: 42 | action = 0 43 | else: 44 | if observation[2] < 0: 45 | action = 0 46 | else: 47 | action = 1 48 | return action 49 | 50 | def end_episode(self, timesteps=0, num_procs=None): 51 | with open(self.bot_name + '_rewards.txt', 'a') as myfile: 52 | myfile.write(str(timesteps) + '\n') 53 | self.reset() 54 | 55 | def save(self, fn=None): 56 | pass 57 | 58 | def save_reward(self, reward): 59 | 
self.replay_buffer.insert(value_preds=10, 60 | rewards=reward) 61 | return True 62 | 63 | def reset(self): 64 | self.replay_buffer.clear() 65 | 66 | def duplicate(self): 67 | return self 68 | 69 | def lower_lr(self): 70 | pass 71 | 72 | 73 | class LunarHeuristic: 74 | def __init__(self, 75 | bot_name='Lunar_Heuristic'): 76 | self.bot_name = '../txts/'+bot_name 77 | from opt_helpers import replay_buffer 78 | self.replay_buffer = replay_buffer.ReplayBufferSingleAgent() 79 | 80 | def get_action(self, observation): 81 | if observation[1] < 1.1: # 0 82 | if observation[3] < 0.2: # 1 83 | if observation[5] < 0: # 3 84 | if (observation[6] + observation[7]) > 1.2: # 7 85 | action = 0 86 | else: 87 | if observation[4] > -0.1: # 11 88 | action = 2 89 | else: 90 | action = 1 91 | else: 92 | action = 3 93 | else: 94 | if (observation[6] + observation[7]) > 1.2: # 4 95 | action = 0 96 | else: 97 | if observation[0] > 0.2: # 8 98 | action = 1 99 | else: 100 | if observation[0] < -0.2: # 12 101 | action = 3 102 | else: 103 | action = 0 104 | else: 105 | if observation[5] > 0.1: # 2 106 | if observation[5] < -0.1: # 5 107 | if (observation[6] + observation[7]) > 1.2: # 9 108 | action = 0 109 | else: 110 | action = 1 111 | else: 112 | if observation[0] > 0.2: # 10 113 | action = 1 114 | else: 115 | if observation[0] < -0.2: # 13 116 | action = 3 117 | else: 118 | action = 0 119 | else: 120 | if (observation[6] + observation[7]) > 1.2: # 6 121 | action = 0 122 | else: 123 | action = 3 124 | return action 125 | 126 | def lower_lr(self): 127 | pass 128 | 129 | def end_episode(self, timesteps=0, num_procs=None): 130 | with open(self.bot_name + '_rewards.txt', 'a') as myfile: 131 | myfile.write(str(timesteps) + '\n') 132 | self.reset() 133 | 134 | def save(self, fn=None): 135 | pass 136 | 137 | def save_reward(self, reward): 138 | self.replay_buffer.insert(value_preds=10, 139 | rewards=reward) 140 | return True 141 | 142 | def reset(self): 143 | self.replay_buffer.clear() 144 | 145 | def duplicate(self): 146 | return self 147 | 148 | 149 | class StarCraftMacroHeuristic: 150 | def __init__(self, 151 | bot_name='SC_Macro_Heuristic'): 152 | from opt_helpers import replay_buffer 153 | self.replay_buffer = replay_buffer.ReplayBufferSingleAgent() 154 | self.bot_name = '../txts/'+bot_name 155 | 156 | def get_action(self, observation): 157 | if observation[10] + observation[12] > 12: # Attackers > 12 158 | if sum(observation[45:65])+sum(observation[82:99])+sum(observation[118:139]) > 4: # enemy units 159 | action = 41 160 | else: 161 | if sum(observation[65:82])+sum(observation[99:118])+sum(observation[139:157]) > 0: # enemy buildings 162 | action = 39 163 | else: 164 | action = 42 165 | else: 166 | if observation[4] > 0.5: # idle workers 167 | action = 40 168 | else: 169 | if observation[3]-observation[2] < 4: # low supply 170 | action = 1 171 | else: 172 | if observation[9]+observation[157] < 15: # few probes 173 | action = 16 174 | else: 175 | if observation[30]+observation[178] > 0: # no assimilators 176 | if observation[38]+observation[186] > 1.5: # stargates 1.5 177 | if observation[22]+observation[170] > 7: # voidrays 178 | action = 39 179 | else: 180 | action = 29 181 | else: 182 | if observation[38] + observation[186] > 0.5: # stargates 0.5 183 | action = 0 184 | else: 185 | action = 10 186 | else: 187 | if observation[31]+observation[179] > 0.5: # gateway > 0.5 188 | if observation[10]+observation[158] > 6: 189 | if observation[34]+observation[182] > 0.5: 190 | action = 2 191 | else: 192 | action = 6 193 | 
else: 194 | action = 17 195 | else: 196 | action = 3 197 | return action 198 | 199 | def end_episode(self, timesteps=0, num_procs=None): 200 | with open(self.bot_name + '_rewards.txt', 'a') as myfile: 201 | myfile.write(str(timesteps) + '\n') 202 | self.reset() 203 | 204 | def save(self, fn=None): 205 | pass 206 | 207 | def save_reward(self, reward): 208 | self.replay_buffer.insert(value_preds=10, 209 | rewards=reward) 210 | return True 211 | 212 | def lower_lr(self): 213 | pass 214 | 215 | def reset(self): 216 | self.replay_buffer.clear() 217 | 218 | def duplicate(self): 219 | return self 220 | 221 | 222 | class StarCraftMicroHeuristic: 223 | def __init__(self, 224 | bot_name='SC_Micro_Heuristic'): 225 | from opt_helpers import replay_buffer 226 | self.replay_buffer = replay_buffer.ReplayBufferSingleAgent() 227 | self.bot_name = '../txts/'+bot_name 228 | 229 | def get_action(self, observation): 230 | if observation[14] > 0: 231 | action = 4 232 | else: 233 | if observation[1] > 30: 234 | if -observation[0] > -20: 235 | action = 2 236 | else: 237 | if observation[0] > 40: 238 | action = 0 239 | else: 240 | action = 3 241 | else: 242 | if observation[1] > 18: 243 | action = 2 244 | else: 245 | if -observation[0] > -40: 246 | action = 1 247 | else: 248 | action = 0 249 | return action 250 | 251 | def end_episode(self, timesteps=0, num_procs=None): 252 | with open(self.bot_name + '_rewards.txt', 'a') as myfile: 253 | myfile.write(str(timesteps) + '\n') 254 | self.reset() 255 | 256 | def save(self, fn=None): 257 | pass 258 | 259 | def lower_lr(self): 260 | pass 261 | 262 | def save_reward(self, reward): 263 | self.replay_buffer.insert(value_preds=10, 264 | rewards=reward) 265 | return True 266 | 267 | def reset(self): 268 | self.replay_buffer.clear() 269 | 270 | def duplicate(self): 271 | return self 272 | 273 | 274 | class RandomHeuristic: 275 | def __init__(self, 276 | bot_name='RandomBot', 277 | action_dim=2): 278 | self.action_dim = action_dim 279 | from opt_helpers import replay_buffer 280 | self.replay_buffer = replay_buffer.ReplayBufferSingleAgent() 281 | self.bot_name = '../txts/'+bot_name 282 | 283 | def get_action(self, observation): 284 | return random.randint(0, self.action_dim-1) 285 | 286 | def end_episode(self, timesteps=0, num_procs=None): 287 | with open(self.bot_name + '_rewards.txt', 'a') as myfile: 288 | myfile.write(str(timesteps) + '\n') 289 | self.reset() 290 | 291 | def save(self, fn=None): 292 | pass 293 | 294 | def save_reward(self, reward): 295 | self.replay_buffer.insert(value_preds=10, 296 | rewards=reward) 297 | return True 298 | 299 | def lower_lr(self): 300 | pass 301 | 302 | def reset(self): 303 | self.replay_buffer.clear() 304 | 305 | def duplicate(self): 306 | return self 307 | -------------------------------------------------------------------------------- /agents/lstm_agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.distributions import Categorical 4 | import sys 5 | sys.path.insert(0, '../') 6 | from opt_helpers import replay_buffer, ppo_update 7 | import copy 8 | 9 | use_gpu = False 10 | 11 | 12 | class BaselineLSTMNet(nn.Module): 13 | def __init__(self, input_dim, output_dim, is_value=False): 14 | super(BaselineLSTMNet, self).__init__() 15 | self.num_layers = 1 16 | self.input_dim = input_dim 17 | self.batch_size = 1 18 | self.hidden_dim = input_dim 19 | self.lin1 = nn.Linear(input_dim, input_dim) 20 | self.rnn = nn.LSTM(input_dim, self.hidden_dim, 
self.num_layers, batch_first=True) 21 | self.lin2 = nn.Linear(self.hidden_dim, input_dim) 22 | self.lin3 = nn.Linear(input_dim, output_dim) 23 | self.sig = nn.ReLU() 24 | if input_dim > 10: 25 | self.lin1 = nn.Sequential( 26 | nn.Linear(input_dim, 256), 27 | nn.ReLU(), 28 | nn.Linear(256, 256), 29 | ) 30 | self.softmax = nn.Softmax(dim=2) 31 | self.is_value = is_value 32 | 33 | def init_hidden(self, batch_size=1): 34 | """ 35 | Initialize hidden state so that it can be clean for each new series of inputs 36 | :return: Variable of zeros of shape (num_layers, minibatch_size, hidden_dim) 37 | """ 38 | first_dim = self.num_layers 39 | second_dim = batch_size 40 | self.batch_size = batch_size 41 | third_dim = self.hidden_dim 42 | if use_gpu: 43 | return (torch.zeros(first_dim, second_dim, third_dim).cuda(), 44 | torch.zeros(first_dim, second_dim, third_dim).cuda()) 45 | else: 46 | return (torch.zeros(first_dim, second_dim, third_dim), 47 | torch.zeros(first_dim, second_dim, third_dim)) 48 | 49 | def forward(self, input_data, hidden_state): 50 | act_out = self.sig(self.lin1(input_data)) 51 | act_out, hidden_out = self.rnn(act_out, hidden_state) 52 | act_out = self.lin3(self.sig(self.lin2(self.sig(act_out)))) 53 | if self.is_value: 54 | return act_out, hidden_out 55 | else: 56 | return self.softmax(act_out), hidden_out 57 | 58 | 59 | class LSTMNet: 60 | def __init__(self, 61 | bot_name='LSTMNet', 62 | input_dim=4, 63 | output_dim=2, 64 | epsilon=0.9, 65 | epsilon_decay=0.95, 66 | epsilon_min=0.05): 67 | self.epsilon = epsilon 68 | self.epsilon_decay = epsilon_decay 69 | self.epsilon_min = epsilon_min 70 | self.bot_name = bot_name 71 | self.input_dim = input_dim 72 | self.output_dim = output_dim 73 | self.replay_buffer = replay_buffer.ReplayBufferSingleAgent() 74 | self.action_network = BaselineLSTMNet(input_dim=input_dim, 75 | output_dim=output_dim, 76 | is_value=False) 77 | self.value_network = BaselineLSTMNet(input_dim=input_dim, 78 | output_dim=output_dim, 79 | is_value=True) 80 | self.ppo = ppo_update.PPO([self.action_network, self.value_network], two_nets=True) 81 | self.actor_opt = torch.optim.RMSprop(self.action_network.parameters(), lr=1e-5) 82 | self.value_opt = torch.optim.RMSprop(self.value_network.parameters(), lr=1e-5) 83 | 84 | self.last_state = [0, 0, 0, 0] 85 | self.last_action = 0 86 | self.last_action_probs = torch.Tensor([0]) 87 | self.last_value_pred = torch.Tensor([[0, 0]]) 88 | self.action_hidden_state = None 89 | self.full_probs = None 90 | self.value_hidden_state = None 91 | self.new_action_hidden = None 92 | self.new_value_hidden = None 93 | self.reward_history = [] 94 | self.num_steps = 0 95 | 96 | def get_action(self, observation): 97 | if self.action_hidden_state is None: 98 | self.action_hidden_state = self.action_network.init_hidden() 99 | self.value_hidden_state = self.value_network.init_hidden() 100 | self.new_action_hidden = self.action_network.init_hidden() 101 | self.new_value_hidden = self.value_network.init_hidden() 102 | with torch.no_grad(): 103 | obs = torch.Tensor(observation) 104 | obs = obs.view(1, 1, -1) 105 | self.last_state = obs 106 | probs, action_hidden = self.action_network(obs, self.action_hidden_state) 107 | value_pred, value_hidden = self.value_network(obs, self.value_hidden_state) 108 | probs = probs.view(-1) 109 | self.full_probs = probs 110 | if self.action_network.input_dim > 30: 111 | probs, inds = torch.topk(probs, 3) 112 | m = Categorical(probs) 113 | action = m.sample() 114 | # # Epsilon learning 115 | # action = torch.argmax(probs) 116 
| # if random.random() < self.epsilon: 117 | # action = torch.LongTensor([random.choice(np.arange(0, self.output_dim, dtype=np.int))]) 118 | # # End Epsilon learning 119 | log_probs = m.log_prob(action) 120 | self.new_action_hidden = action_hidden 121 | self.new_value_hidden = value_hidden 122 | self.last_action_probs = log_probs 123 | self.last_value_pred = value_pred.view(-1).cpu() 124 | 125 | if self.action_network.input_dim > 30: 126 | self.last_action = inds[action] 127 | else: 128 | self.last_action = action 129 | if self.action_network.input_dim > 30: 130 | action = inds[action].item() 131 | else: 132 | action = action.item() 133 | return action 134 | 135 | def save_reward(self, reward): 136 | self.replay_buffer.insert(obs=[self.last_state], 137 | recurrent_hidden_states=(self.action_hidden_state, self.value_hidden_state), 138 | action_log_probs=self.last_action_probs, 139 | value_preds=self.last_value_pred[self.last_action.item()], 140 | last_action=self.last_action, 141 | full_probs_vector=self.full_probs, 142 | rewards=reward) 143 | self.action_hidden_state = self.new_action_hidden 144 | self.value_hidden_state = self.new_value_hidden 145 | return True 146 | 147 | def end_episode(self, timesteps, num_processes=1): 148 | self.reward_history.append(timesteps) 149 | value_loss, action_loss = self.ppo.batch_updates(self.replay_buffer, self) 150 | bot_name = '../txts/' + self.bot_name + str(num_processes) + '_processes' 151 | # with open(bot_name + "_losses.txt", "a") as myfile: 152 | # myfile.write(str(value_loss + action_loss) + '\n') 153 | with open(bot_name + '_rewards.txt', 'a') as myfile: 154 | myfile.write(str(timesteps) + '\n') 155 | self.epsilon = max(self.epsilon*self.epsilon_decay, self.epsilon_min) 156 | 157 | def lower_lr(self): 158 | for param_group in self.ppo.actor_opt.param_groups: 159 | param_group['lr'] = param_group['lr'] * 0.5 160 | for param_group in self.ppo.critic_opt.param_groups: 161 | param_group['lr'] = param_group['lr'] * 0.5 162 | 163 | def reset(self): 164 | self.replay_buffer.clear() 165 | 166 | def deepen_networks(self): 167 | pass 168 | 169 | def save(self, fn='last'): 170 | checkpoint = dict() 171 | checkpoint['actor'] = self.action_network.state_dict() 172 | checkpoint['value'] = self.value_network.state_dict() 173 | torch.save(checkpoint, fn+self.bot_name+'.pth.tar') 174 | 175 | def load(self, fn='last'): 176 | fn = fn+self.bot_name+'.pth.tar' 177 | model_checkpoint = torch.load(fn, map_location='cpu') 178 | actor_data = model_checkpoint['actor'] 179 | value_data = model_checkpoint['value'] 180 | self.action_network.load_state_dict(actor_data) 181 | self.value_network.load_state_dict(value_data) 182 | 183 | def __getstate__(self): 184 | return { 185 | 'action_network': self.action_network, 186 | 'value_network': self.value_network, 187 | 'ppo': self.ppo, 188 | 'actor_opt': self.actor_opt, 189 | 'value_opt': self.value_opt, 190 | 'epsilon': self.epsilon, 191 | 'epsilon_decay': self.epsilon_decay, 192 | 'epsilon_min': self.epsilon_min, 193 | } 194 | 195 | def __setstate__(self, state): 196 | self.action_network = copy.deepcopy(state['action_network']) 197 | self.value_network = copy.deepcopy(state['value_network']) 198 | self.ppo = copy.deepcopy(state['ppo']) 199 | self.actor_opt = copy.deepcopy(state['actor_opt']) 200 | self.value_opt = copy.deepcopy(state['value_opt']) 201 | self.epsilon_min = copy.deepcopy(state['epsilon_min']) 202 | self.epsilon_decay = copy.deepcopy(state['epsilon_decay']) 203 | self.epsilon = copy.deepcopy(state['epsilon']) 204 
| 205 | def duplicate(self): 206 | new_agent = LSTMNet(bot_name=self.bot_name, 207 | input_dim=self.input_dim, 208 | output_dim=self.output_dim, 209 | epsilon=self.epsilon, 210 | epsilon_decay=self.epsilon_decay, 211 | epsilon_min=self.epsilon_min) 212 | new_agent.__setstate__(self.__getstate__()) 213 | return new_agent 214 | -------------------------------------------------------------------------------- /agents/prolonet_agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributions import Categorical 3 | import sys 4 | sys.path.insert(0, '../') 5 | from opt_helpers import replay_buffer, ppo_update 6 | from agents.vectorized_prolonet_helpers import init_cart_nets, swap_in_node, add_level, \ 7 | init_lander_nets, init_micro_net, init_adversarial_net, init_sc_nets, \ 8 | init_sc_build_hellions_net, save_prolonet, load_prolonet, init_fire_nets 9 | import copy 10 | import os 11 | 12 | class DeepProLoNet: 13 | def __init__(self, 14 | distribution='one_hot', 15 | bot_name='ProLoNet', 16 | input_dim=4, 17 | output_dim=2, 18 | use_gpu=False, 19 | vectorized=False, 20 | randomized=False, 21 | adversarial=False, 22 | deepen=True, 23 | epsilon=0.9, 24 | epsilon_decay=0.95, 25 | epsilon_min=0.05, 26 | deterministic=False): 27 | self.replay_buffer = replay_buffer.ReplayBufferSingleAgent() 28 | self.bot_name = bot_name 29 | self.use_gpu = use_gpu 30 | self.vectorized = vectorized 31 | self.randomized = randomized 32 | self.adversarial = adversarial 33 | self.deepen = deepen 34 | self.output_dim = output_dim 35 | self.input_dim = input_dim 36 | self.adv_prob = .05 37 | self.epsilon = epsilon 38 | self.epsilon_decay = epsilon_decay 39 | self.epsilon_min = epsilon_min 40 | self.deterministic = deterministic 41 | self.lr = 2e-2 42 | if vectorized: 43 | self.bot_name += '_vect' 44 | if randomized: 45 | self.bot_name += '_rand' 46 | if use_gpu: 47 | self.bot_name += '_gpu' 48 | if deepen: 49 | self.bot_name += '_deepening' 50 | if input_dim == 4 and output_dim == 2: # CartPole 51 | self.action_network, self.value_network = init_cart_nets(distribution, use_gpu, vectorized, randomized) 52 | if adversarial: 53 | self.action_network, self.value_network = init_adversarial_net(adv_type='cart', 54 | distribution_in=distribution, 55 | adv_prob=self.adv_prob) 56 | self.bot_name += '_adversarial' + str(self.adv_prob) 57 | elif input_dim == 8 and output_dim == 4: # Lunar Lander 58 | self.lr = 2e-2 59 | self.action_network, self.value_network = init_lander_nets(distribution, use_gpu, vectorized, randomized) 60 | if adversarial: 61 | self.action_network, self.value_network = init_adversarial_net(adv_type='lunar', 62 | distribution_in=distribution, 63 | adv_prob=self.adv_prob) 64 | self.bot_name += '_adversarial' + str(self.adv_prob) 65 | elif input_dim == 194 and output_dim == 44: # SC Macro 66 | self.action_network, self.value_network = init_sc_nets(distribution, use_gpu, vectorized, randomized) 67 | elif input_dim == 37 and output_dim == 10: # SC Micro 68 | self.action_network, self.value_network = init_micro_net(distribution, use_gpu, vectorized, randomized) 69 | if adversarial: 70 | self.action_network, self.value_network = init_adversarial_net(adv_type='micro', 71 | distribution_in=distribution, 72 | adv_prob=self.adv_prob) 73 | self.bot_name += '_adversarial' + str(self.adv_prob) 74 | elif input_dim == 32 and output_dim == 12: # SC Build Hellions 75 | self.action_network, self.value_network = init_sc_build_hellions_net(distribution, use_gpu, 
vectorized, randomized) 76 | elif input_dim == 6 and output_dim == 5: # Fire Sim 77 | self.action_network, self.value_network = init_fire_nets(distribution, 78 | use_gpu, 79 | vectorized, 80 | randomized, 81 | bot_name.split('_')[0]) 82 | 83 | self.ppo = ppo_update.PPO([self.action_network, self.value_network], two_nets=True, use_gpu=use_gpu) 84 | self.actor_opt = torch.optim.RMSprop(self.action_network.parameters(), lr=1e-5) 85 | self.value_opt = torch.optim.RMSprop(self.value_network.parameters(), lr=1e-5) 86 | if self.deepen: 87 | self.deeper_action_network = add_level(self.action_network, use_gpu=use_gpu) 88 | self.deeper_value_network = add_level(self.value_network, use_gpu=use_gpu) 89 | 90 | self.deeper_actor_opt = torch.optim.RMSprop(self.deeper_action_network.parameters(), lr=self.lr) 91 | self.deeper_value_opt = torch.optim.RMSprop(self.deeper_value_network.parameters(), lr=self.lr) 92 | else: 93 | self.deeper_value_network = None 94 | self.deeper_action_network = None 95 | self.deeper_actor_opt = None 96 | self.deeper_value_opt = None 97 | self.num_times_deepened = 0 98 | self.last_state = [0, 0, 0, 0] 99 | self.last_action = 0 100 | self.last_action_probs = torch.Tensor([0]) 101 | self.last_value_pred = torch.Tensor([[0, 0]]) 102 | self.last_deep_action_probs = None 103 | self.last_deep_value_pred = [None]*output_dim 104 | self.full_probs = None 105 | self.deeper_full_probs = None 106 | self.reward_history = [] 107 | self.num_steps = 0 108 | 109 | def get_action(self, observation): 110 | with torch.no_grad(): 111 | obs = torch.Tensor(observation) 112 | obs = obs.view(1, -1) 113 | self.last_state = obs 114 | if self.use_gpu: 115 | obs = obs.cuda() 116 | probs = self.action_network(obs) 117 | value_pred = self.value_network(obs) 118 | probs = probs.view(-1).cpu() 119 | self.full_probs = probs 120 | if self.action_network.input_dim >= 33: 121 | probs, inds = torch.topk(probs, 3) 122 | m = Categorical(probs) 123 | action = m.sample() 124 | if self.deterministic: 125 | action = torch.argmax(probs) 126 | 127 | log_probs = m.log_prob(action) 128 | self.last_action_probs = log_probs.cpu() 129 | self.last_value_pred = value_pred.view(-1).cpu() 130 | 131 | if self.deepen: 132 | deeper_probs = self.deeper_action_network(obs) 133 | deeper_value_pred = self.deeper_value_network(obs) 134 | deeper_probs = deeper_probs.view(-1).cpu() 135 | self.deeper_full_probs = deeper_probs 136 | if self.action_network.input_dim >= 33: 137 | deeper_probs, _ = torch.topk(probs, 3) 138 | deep_m = Categorical(deeper_probs) 139 | deep_log_probs = deep_m.log_prob(action) 140 | self.last_deep_action_probs = deep_log_probs.cpu() 141 | self.last_deep_value_pred = deeper_value_pred.view(-1).cpu() 142 | if self.action_network.input_dim >= 33: 143 | self.last_action = inds[action].cpu() 144 | else: 145 | self.last_action = action.cpu() 146 | if self.action_network.input_dim >= 33: 147 | action = inds[action].item() 148 | else: 149 | action = action.item() 150 | return action 151 | 152 | def save_reward(self, reward): 153 | self.replay_buffer.insert(obs=[self.last_state], 154 | action_log_probs=self.last_action_probs, 155 | value_preds=self.last_value_pred[self.last_action.item()], 156 | deeper_action_log_probs=self.last_deep_action_probs, 157 | deeper_value_pred=self.last_deep_value_pred[self.last_action.item()], 158 | last_action=self.last_action, 159 | full_probs_vector=self.full_probs, 160 | deeper_full_probs_vector=self.deeper_full_probs, 161 | rewards=reward) 162 | return True 163 | 164 | def end_episode(self, 
timesteps, num_processes): 165 | value_loss, action_loss = self.ppo.batch_updates(self.replay_buffer, self, go_deeper=self.deepen) 166 | self.num_steps += 1 167 | # Copy over new decision node params from shallower network to deeper network 168 | bot_name = '../txts/' + self.bot_name + str(num_processes) + '_processes' 169 | with open(bot_name + '_rewards.txt', 'a') as myfile: 170 | myfile.write(str(timesteps) + '\n') 171 | self.epsilon = max(self.epsilon*self.epsilon_decay, self.epsilon_min) 172 | 173 | def lower_lr(self): 174 | self.lr = self.lr * 0.5 175 | for param_group in self.ppo.actor_opt.param_groups: 176 | param_group['lr'] = param_group['lr'] * 0.5 177 | for param_group in self.ppo.critic_opt.param_groups: 178 | param_group['lr'] = param_group['lr'] * 0.5 179 | 180 | def reset(self): 181 | self.replay_buffer.clear() 182 | 183 | def save(self, fn='last'): 184 | act_fn = fn + self.bot_name + '_actor_' + '.pth.tar' 185 | val_fn = fn + self.bot_name + '_critic_' + '.pth.tar' 186 | 187 | deep_act_fn = fn + self.bot_name + '_deep_actor_' + '.pth.tar' 188 | deep_val_fn = fn + self.bot_name + '_deep_critic_' + '.pth.tar' 189 | save_prolonet(act_fn, self.action_network) 190 | save_prolonet(val_fn, self.value_network) 191 | if self.deepen: 192 | save_prolonet(deep_act_fn, self.deeper_action_network) 193 | save_prolonet(deep_val_fn, self.deeper_value_network) 194 | 195 | def load(self, fn='last', fn_botname=None): 196 | if fn_botname == None: 197 | act_fn = fn + self.bot_name + '_actor_' + '.pth.tar' 198 | val_fn = fn + self.bot_name + '_critic_' + '.pth.tar' 199 | else: 200 | act_fn = fn + '_actor_' + '.pth.tar' 201 | val_fn = fn + '_critic_' + '.pth.tar' 202 | 203 | deep_act_fn = fn + self.bot_name + '_deep_actor_' + '.pth.tar' 204 | deep_val_fn = fn + self.bot_name + '_deep_critic_' + '.pth.tar' 205 | if os.path.exists(act_fn): 206 | self.action_network = load_prolonet(act_fn) 207 | self.value_network = load_prolonet(val_fn) 208 | if self.deepen: 209 | self.deeper_action_network = load_prolonet(deep_act_fn) 210 | self.deeper_value_network = load_prolonet(deep_val_fn) 211 | else: 212 | return False 213 | return True 214 | 215 | def deepen_networks(self): 216 | if not self.deepen or self.num_times_deepened > 8: 217 | return 218 | self.entropy_leaf_checks() 219 | # Copy over shallow params to deeper network 220 | for weight_index in range(len(self.action_network.layers)): 221 | new_act_weight = torch.Tensor(self.action_network.layers[weight_index].cpu().data.numpy()) 222 | new_act_comp = torch.Tensor(self.action_network.comparators[weight_index].cpu().data.numpy()) 223 | 224 | if self.use_gpu: 225 | new_act_weight = new_act_weight.cuda() 226 | new_act_comp = new_act_comp.cuda() 227 | 228 | self.deeper_action_network.layers[weight_index].data = new_act_weight 229 | self.deeper_action_network.comparators[weight_index].data = new_act_comp 230 | for weight_index in range(len(self.value_network.layers)): 231 | new_val_weight = torch.Tensor(self.value_network.layers[weight_index].cpu().data.numpy()) 232 | new_val_comp = torch.Tensor(self.value_network.comparators[weight_index].cpu().data.numpy()) 233 | if self.use_gpu: 234 | new_val_weight = new_val_weight.cuda() 235 | new_val_comp = new_val_comp.cuda() 236 | self.deeper_value_network.layers[weight_index].data = new_val_weight 237 | self.deeper_value_network.comparators[weight_index].data = new_val_comp 238 | 239 | def entropy_leaf_checks(self): 240 | leaf_max = torch.nn.Softmax(dim=0) 241 | new_action_network = 
copy.deepcopy(self.action_network) 242 | changes_made = [] 243 | for leaf_index in range(len(self.action_network.action_probs)): 244 | existing_leaf = leaf_max(self.action_network.action_probs[leaf_index]) 245 | new_leaf_1 = leaf_max(self.deeper_action_network.action_probs[2*leaf_index+1]) 246 | new_leaf_2 = leaf_max(self.deeper_action_network.action_probs[2*leaf_index]) 247 | existing_entropy = Categorical(existing_leaf).entropy().item() 248 | new_entropy = Categorical(new_leaf_1).entropy().item() + \ 249 | Categorical(new_leaf_2).entropy().item() 250 | 251 | if new_entropy+0.1 <= existing_entropy: 252 | with open('../txts/' + self.bot_name + '_entropy_splits.txt', 'a') as myfile: 253 | myfile.write('Split at ' + str(self.num_steps) + ' steps' + ': \n') 254 | myfile.write('Leaf: ' + str(leaf_index) + '\n') 255 | myfile.write('Prior Probs: ' + str(self.action_network.action_probs[leaf_index]) + '\n') 256 | myfile.write('New Probs 1: ' + str(self.deeper_action_network.action_probs[leaf_index*2]) + '\n') 257 | myfile.write('New Probs 2: ' + str(self.deeper_action_network.action_probs[leaf_index*2+1]) + '\n') 258 | 259 | new_action_network = swap_in_node(new_action_network, self.deeper_action_network, leaf_index, use_gpu=self.use_gpu) 260 | changes_made.append(leaf_index) 261 | if len(changes_made) > 0: 262 | self.action_network = new_action_network 263 | 264 | if self.action_network.input_dim > 100: 265 | new_actor_opt = torch.optim.RMSprop(self.action_network.parameters(), lr=1e-5) 266 | elif self.action_network.input_dim >= 8: 267 | new_actor_opt = torch.optim.RMSprop(self.action_network.parameters(), lr=self.lr) 268 | else: 269 | new_actor_opt = torch.optim.RMSprop(self.action_network.parameters(), lr=self.lr) 270 | 271 | self.ppo.actor = self.action_network 272 | self.ppo.actor_opt = new_actor_opt 273 | 274 | for change in changes_made[::-1]: 275 | self.num_times_deepened += 1 276 | self.deeper_action_network = swap_in_node(self.deeper_action_network, None, change*2+1, use_gpu=self.use_gpu) 277 | self.deeper_action_network = swap_in_node(self.deeper_action_network, None, change*2, use_gpu=self.use_gpu) 278 | 279 | self.deeper_actor_opt = torch.optim.RMSprop(self.deeper_action_network.parameters(), lr=self.lr) 280 | 281 | def __getstate__(self): 282 | return { 283 | 'action_network': self.action_network, 284 | 'value_network': self.value_network, 285 | 'ppo': self.ppo, 286 | 'deeper_action_network': self.deeper_action_network, 287 | 'deeper_value_network': self.deeper_value_network, 288 | 'actor_opt': self.actor_opt, 289 | 'value_opt': self.value_opt, 290 | 'deeper_actor_opt': self.deeper_actor_opt, 291 | 'deeper_value_opt': self.deeper_value_opt, 292 | 'bot_name': self.bot_name, 293 | 'use_gpu': self.use_gpu, 294 | 'vectorized': self.vectorized, 295 | 'randomized': self.randomized, 296 | 'adversarial':self.adversarial, 297 | 'deepen': self.deepen, 298 | 'output_dim': self.output_dim, 299 | 'input_dim': self.input_dim, 300 | 'epsilon': self.epsilon, 301 | 'epsilon_decay': self.epsilon_decay, 302 | 'epsilon_min': self.epsilon_min, 303 | 'num_times_deepened': self.num_times_deepened, 304 | 'deterministic': self.deterministic, 305 | } 306 | 307 | def __setstate__(self, state): 308 | for key in state: 309 | setattr(self, key, state[key]) 310 | 311 | def duplicate(self): 312 | new_agent = DeepProLoNet(distribution='one_hot', 313 | bot_name=self.bot_name, 314 | input_dim=self.input_dim, 315 | output_dim=self.output_dim, 316 | use_gpu=self.use_gpu, 317 | vectorized=self.vectorized, 318 | 
randomized=self.randomized, 319 | adversarial=self.adversarial, 320 | deepen=self.deepen, 321 | epsilon=self.epsilon, 322 | epsilon_decay=self.epsilon_decay, 323 | epsilon_min=self.epsilon_min, 324 | deterministic=self.deterministic 325 | ) 326 | new_agent.__setstate__(self.__getstate__()) 327 | return new_agent 328 | -------------------------------------------------------------------------------- /agents/py_djinn_agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import sys 4 | sys.path.insert(0, '../') 5 | from torch.distributions import Categorical 6 | from opt_helpers import replay_buffer, ppo_update 7 | import copy 8 | from agents.djinn_helpers import PyDJINN, DJINN_TREE_DATA, tree_to_nn_weights 9 | 10 | 11 | class DJINNAgent: 12 | def __init__(self, 13 | bot_name='DJINNAgent', 14 | input_dim=4, 15 | output_dim=2, 16 | drop_prob=0.0 17 | ): 18 | self.bot_name = bot_name 19 | self.input_dim = input_dim 20 | self.output_dim = output_dim 21 | self.replay_buffer = replay_buffer.ReplayBufferSingleAgent() 22 | 23 | if input_dim == 4: 24 | self.env_name = 'cart' 25 | elif input_dim == 6: # Fire Sim 26 | self.env_name = 'fire' 27 | elif input_dim == 8: 28 | self.env_name = 'lunar' 29 | elif input_dim == 37: 30 | self.env_name = 'sc_micro' 31 | elif input_dim > 100: 32 | self.env_name = 'sc_macro' 33 | tree_dict = DJINN_TREE_DATA[self.env_name] 34 | action_param_dict = tree_to_nn_weights(input_dim, output_dim, tree_dict) 35 | # value_param_dict = tree_to_nn_weights(input_dim, 1, tree_dict) 36 | self.action_network = PyDJINN(input_dim, 37 | weights=action_param_dict['weights'], 38 | biases=[], drop_prob=drop_prob, 39 | is_value=False 40 | ) 41 | self.value_network = PyDJINN(input_dim, 42 | weights=action_param_dict['weights'], 43 | biases=[], 44 | drop_prob=drop_prob, 45 | is_value=True 46 | ) 47 | self.ppo = ppo_update.PPO([self.action_network, self.value_network], two_nets=True) 48 | self.actor_opt = torch.optim.RMSprop(self.action_network.parameters(), lr=1e-2) 49 | self.value_opt = torch.optim.RMSprop(self.value_network.parameters(), lr=1e-2) 50 | 51 | self.last_state = [0, 0, 0, 0] 52 | self.last_action = 0 53 | self.last_action_probs = torch.Tensor([0]) 54 | self.last_value_pred = torch.Tensor([[0, 0]]) 55 | self.last_deep_action_probs = torch.Tensor([0]) 56 | self.last_deep_value_pred = torch.Tensor([[0, 0]]) 57 | self.full_probs = None 58 | self.reward_history = [] 59 | self.num_steps = 0 60 | 61 | def get_action(self, observation): 62 | with torch.no_grad(): 63 | obs = torch.Tensor(observation) 64 | obs = obs.view(1, -1) 65 | self.last_state = obs 66 | probs = self.action_network(obs) 67 | value_pred = self.value_network(obs) 68 | probs = probs.view(-1) 69 | self.full_probs = probs 70 | if self.action_network.input_dim > 30: 71 | probs, inds = torch.topk(probs, 3) 72 | m = Categorical(probs) 73 | action = m.sample() 74 | # # Epsilon learning 75 | # action = torch.argmax(probs) 76 | # if random.random() < self.epsilon: 77 | # action = torch.LongTensor([random.choice(np.arange(0, self.output_dim, dtype=np.int))]) 78 | # # End Epsilon learning 79 | log_probs = m.log_prob(action) 80 | self.last_action_probs = log_probs 81 | self.last_value_pred = value_pred.view(-1).cpu() 82 | 83 | if self.action_network.input_dim > 30: 84 | self.last_action = inds[action] 85 | else: 86 | self.last_action = action 87 | if self.action_network.input_dim > 30: 88 | action = inds[action].item() 89 | else: 90 | action = 
action.item() 91 | return action 92 | 93 | def save_reward(self, reward): 94 | self.replay_buffer.insert(obs=[self.last_state], 95 | action_log_probs=self.last_action_probs, 96 | value_preds=self.last_value_pred[self.last_action.item()], 97 | # value_preds=self.last_value_pred.item(), 98 | last_action=self.last_action, 99 | full_probs_vector=self.full_probs, 100 | rewards=reward) 101 | return True 102 | 103 | def end_episode(self, timesteps, num_processes=1): 104 | self.reward_history.append(timesteps) 105 | value_loss, action_loss = self.ppo.batch_updates(self.replay_buffer, self) 106 | bot_name = '../txts/' + self.bot_name + str(num_processes) + '_processes' 107 | self.num_steps += 1 108 | with open(bot_name + '_rewards.txt', 'a') as myfile: 109 | myfile.write(str(timesteps) + '\n') 110 | 111 | def lower_lr(self): 112 | for param_group in self.ppo.actor_opt.param_groups: 113 | param_group['lr'] = param_group['lr'] * 0.5 114 | for param_group in self.ppo.critic_opt.param_groups: 115 | param_group['lr'] = param_group['lr'] * 0.5 116 | 117 | def reset(self): 118 | self.replay_buffer.clear() 119 | 120 | def deepen_networks(self): 121 | pass 122 | 123 | def save(self, fn='last'): 124 | checkpoint = dict() 125 | checkpoint['actor'] = self.action_network.state_dict() 126 | checkpoint['value'] = self.value_network.state_dict() 127 | torch.save(checkpoint, fn+self.bot_name+'.pth.tar') 128 | 129 | def load(self, fn='last'): 130 | # fn = fn + self.bot_name + '.pth.tar' 131 | model_checkpoint = torch.load(fn, map_location='cpu') 132 | actor_data = model_checkpoint['actor'] 133 | value_data = model_checkpoint['value'] 134 | self.action_network.load_state_dict(actor_data) 135 | self.value_network.load_state_dict(value_data) 136 | 137 | def __getstate__(self): 138 | return { 139 | 'action_network': self.action_network, 140 | 'value_network': self.value_network, 141 | 'ppo': self.ppo, 142 | 'actor_opt': self.actor_opt, 143 | 'value_opt': self.value_opt, 144 | } 145 | 146 | def __setstate__(self, state): 147 | self.action_network = copy.deepcopy(state['action_network']) 148 | self.value_network = copy.deepcopy(state['value_network']) 149 | self.ppo = copy.deepcopy(state['ppo']) 150 | self.actor_opt = copy.deepcopy(state['actor_opt']) 151 | self.value_opt = copy.deepcopy(state['value_opt']) 152 | 153 | def duplicate(self): 154 | new_agent = DJINNAgent( 155 | bot_name=self.bot_name, 156 | input_dim=self.input_dim, 157 | output_dim=self.output_dim, 158 | ) 159 | new_agent.__setstate__(self.__getstate__()) 160 | return new_agent 161 | -------------------------------------------------------------------------------- /agents/vectorized_prolonet.py: -------------------------------------------------------------------------------- 1 | # Created by Andrew Silva on 2/21/19 2 | import torch.nn as nn 3 | import torch 4 | import numpy as np 5 | import typing as t 6 | 7 | 8 | class ProLoNet(nn.Module): 9 | def __init__(self, 10 | input_dim: int, 11 | weights: t.Union[t.List[np.array], np.array, None], 12 | comparators: t.Union[t.List[np.array], np.array, None], 13 | leaves: t.Union[None, int, t.List], # t.List[t.Tuple[t.List[int], t.List[int], np.array]]] 14 | selectors: t.Optional[t.List[np.array]] = None, 15 | output_dim: t.Optional[int] = None, 16 | alpha: float = 1.0, 17 | is_value: bool = False, 18 | device: str = 'cpu', 19 | vectorized: bool = False): 20 | super(ProLoNet, self).__init__() 21 | """ 22 | Initialize the ProLoNet, taking in premade weights for inputs to comparators and sigmoids 23 | Alternatively, 
pass in None to everything except for input_dim and output_dim, and you will get a randomly 24 | initialized tree. If you pass an int to leaves, it must be 2**N so that we can build a balanced tree 25 | :param input_dim: int. always required for input dimensionality 26 | :param weights: None or a list of lists, where each sub-list is a weight vector for each node 27 | :param comparators: None or a list of lists, where each sub-list is a comparator vector for each node 28 | :param leaves: None, int, or truple of [[left turn indices], [right turn indices], [final_probs]]. If int, must be 2**N 29 | :param output_dim: None or int, must be an int if weights and comparators are None 30 | :param alpha: int. Strictness of the tree, default 1sh 31 | :param is_value: if False, outputs are passed through a Softmax final layer. Default: False 32 | :param device: Which device should this network run on? Default: 'cpu' 33 | :param vectorized: Use a vectorized comparator? Default: False 34 | """ 35 | self.device = device 36 | self.vectorized = vectorized 37 | self.leaf_init_information = leaves 38 | 39 | self.input_dim = input_dim 40 | self.output_dim = output_dim 41 | self.layers = None 42 | self.comparators = None 43 | self.selector = None 44 | 45 | self.init_comparators(comparators) 46 | self.init_weights(weights) 47 | self.init_alpha(alpha) 48 | if self.vectorized: 49 | self.init_selector(selectors, weights) 50 | self.init_paths() 51 | self.init_leaves() 52 | self.added_levels = nn.Sequential() 53 | 54 | self.sig = nn.Sigmoid() 55 | self.softmax = nn.Softmax(dim=-1) 56 | self.is_value = is_value 57 | 58 | def init_comparators(self, comparators): 59 | if comparators is None: 60 | comparators = [] 61 | if type(self.leaf_init_information) is int: 62 | depth = int(np.floor(np.log2(self.leaf_init_information))) 63 | else: 64 | depth = 4 65 | for level in range(depth): 66 | for node in range(2**level): 67 | if self.vectorized: 68 | comparators.append(np.random.rand(self.input_dim)) 69 | else: 70 | comparators.append(np.random.normal(0, 1.0, 1)) 71 | new_comps = torch.tensor(comparators, dtype=torch.float).to(self.device) 72 | new_comps.requires_grad = True 73 | self.comparators = nn.Parameter(new_comps, requires_grad=True) 74 | 75 | def init_weights(self, weights): 76 | if weights is None: 77 | weights = [] 78 | if type(self.leaf_init_information) is int: 79 | depth = int(np.floor(np.log2(self.leaf_init_information))) 80 | else: 81 | depth = 4 82 | for level in range(depth): 83 | for node in range(2**level): 84 | weights.append(np.random.rand(self.input_dim)) 85 | 86 | new_weights = torch.tensor(weights, dtype=torch.float).to(self.device) 87 | new_weights.requires_grad = True 88 | self.layers = nn.Parameter(new_weights, requires_grad=True) 89 | 90 | def init_alpha(self, alpha): 91 | self.alpha = torch.tensor([alpha], dtype=torch.float).to(self.device) 92 | self.alpha.requires_grad = True 93 | self.alpha = nn.Parameter(self.alpha, requires_grad=True) 94 | 95 | def init_selector(self, selector, weights): 96 | if selector is None: 97 | if weights is None: 98 | selector = np.random.rand(*self.layers.size()) 99 | else: 100 | selector = [] 101 | for layer in self.layers: 102 | new_sel = np.zeros(layer.size()) + 1e-4 103 | max_ind = torch.argmax(torch.abs(layer)).item() 104 | new_sel[max_ind] = 1 105 | selector.append(new_sel) 106 | selector = torch.tensor(selector, dtype=torch.float).to(self.device) 107 | selector.requires_grad = True 108 | self.selector = nn.Parameter(selector, requires_grad=True) 109 | 110 | 
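    # Note on the decision-node arithmetic used in forward() below: a node with
    # weights w, comparator c, and temperature alpha computes
    #     sigmoid(alpha * (w . x - c))        (non-vectorized: dot product vs. a scalar c)
    # while the vectorized variant compares every feature against its own comparator and
    # lets the learned selector softly choose which feature the node listens to.
    # A minimal per-node sketch (w, x, c, c_vec, selector here are illustrative tensors,
    # not attributes of this class):
    #
    #     sig = torch.sigmoid(alpha * (torch.dot(w, x) - c))                 # scalar node
    #     sig = (torch.sigmoid(alpha * (w * x - c_vec)) * selector).sum()    # vectorized node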
def init_paths(self): 111 | if type(self.leaf_init_information) is list: 112 | left_branches = torch.zeros((len(self.layers), len(self.leaf_init_information)), dtype=torch.float) 113 | right_branches = torch.zeros((len(self.layers), len(self.leaf_init_information)), dtype=torch.float) 114 | for n in range(0, len(self.leaf_init_information)): 115 | for i in self.leaf_init_information[n][0]: 116 | left_branches[i][n] = 1.0 117 | for j in self.leaf_init_information[n][1]: 118 | right_branches[j][n] = 1.0 119 | else: 120 | if type(self.leaf_init_information) is int: 121 | depth = int(np.floor(np.log2(self.leaf_init_information))) 122 | else: 123 | depth = 4 124 | left_branches = torch.zeros((2 ** depth - 1, 2 ** depth), dtype=torch.float) 125 | for n in range(0, depth): 126 | row = 2 ** n - 1 127 | for i in range(0, 2 ** depth): 128 | col = 2 ** (depth - n) * i 129 | end_col = col + 2 ** (depth - 1 - n) 130 | if row + i >= len(left_branches) or end_col >= len(left_branches[row]): 131 | break 132 | left_branches[row + i, col:end_col] = 1.0 133 | right_branches = torch.zeros((2 ** depth - 1, 2 ** depth), dtype=torch.float) 134 | left_turns = np.where(left_branches == 1) 135 | for row in np.unique(left_turns[0]): 136 | cols = left_turns[1][left_turns[0] == row] 137 | start_pos = cols[-1] + 1 138 | end_pos = start_pos + len(cols) 139 | right_branches[row, start_pos:end_pos] = 1.0 140 | left_branches.requires_grad = False 141 | right_branches.requires_grad = False 142 | self.left_path_sigs = left_branches.to(self.device) 143 | self.right_path_sigs = right_branches.to(self.device) 144 | 145 | def init_leaves(self): 146 | if type(self.leaf_init_information) is list: 147 | new_leaves = [leaf[-1] for leaf in self.leaf_init_information] 148 | else: 149 | new_leaves = [] 150 | if type(self.leaf_init_information) is int: 151 | depth = int(np.floor(np.log2(self.leaf_init_information))) 152 | else: 153 | depth = 4 154 | 155 | last_level = np.arange(2**(depth-1)-1, 2**depth-1) 156 | going_left = True 157 | leaf_index = 0 158 | self.leaf_init_information = [] 159 | for level in range(2**depth): 160 | curr_node = last_level[leaf_index] 161 | turn_left = going_left 162 | left_path = [] 163 | right_path = [] 164 | while curr_node >= 0: 165 | if turn_left: 166 | left_path.append(int(curr_node)) 167 | else: 168 | right_path.append(int(curr_node)) 169 | prev_node = np.ceil(curr_node / 2) - 1 170 | if curr_node // 2 > prev_node: 171 | turn_left = False 172 | else: 173 | turn_left = True 174 | curr_node = prev_node 175 | if going_left: 176 | going_left = False 177 | else: 178 | going_left = True 179 | leaf_index += 1 180 | new_probs = np.random.uniform(0, 1, self.output_dim) # *(1.0/self.output_dim) 181 | self.leaf_init_information.append([sorted(left_path), sorted(right_path), new_probs]) 182 | new_leaves.append(new_probs) 183 | 184 | labels = torch.tensor(new_leaves, dtype=torch.float).to(self.device) 185 | labels.requires_grad = True 186 | self.action_probs = nn.Parameter(labels, requires_grad=True) 187 | 188 | def forward(self, input_data, embedding_list=None): 189 | 190 | input_data = input_data.t().expand(self.layers.size(0), *input_data.t().size()) 191 | 192 | input_data = input_data.permute(2, 0, 1) 193 | comp = self.layers.mul(input_data) 194 | if not self.vectorized: 195 | comp = comp.sum(dim=2).unsqueeze(-1) 196 | comp = comp.sub(self.comparators.expand(input_data.size(0), *self.comparators.size())) 197 | comp = comp.mul(self.alpha) 198 | sig_vals = self.sig(comp) 199 | if self.vectorized: 200 | 
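            # sig_vals is (batch, n_nodes, input_dim) at this point; weighting it by the
            # selector and summing over the feature dimension collapses it to one soft
            # decision per node, i.e. shape (batch, n_nodes).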
s_temp_main = self.selector.expand(input_data.size(0), *self.selector.size()) 201 | sig_vals = sig_vals.mul(s_temp_main) 202 | sig_vals = sig_vals.sum(dim=2) 203 | 204 | sig_vals = sig_vals.view(input_data.size(0), -1) 205 | one_minus_sig = torch.ones(sig_vals.size()).to(self.device) 206 | one_minus_sig = torch.sub(one_minus_sig, sig_vals) 207 | 208 | if input_data.size(0) > 1: 209 | left_path_probs = self.left_path_sigs.t() 210 | right_path_probs = self.right_path_sigs.t() 211 | left_path_probs = left_path_probs.expand(input_data.size(0), *left_path_probs.size()) * sig_vals.unsqueeze( 212 | 1) 213 | right_path_probs = right_path_probs.expand(input_data.size(0), 214 | *right_path_probs.size()) * one_minus_sig.unsqueeze(1) 215 | left_path_probs = left_path_probs.permute(0, 2, 1) 216 | right_path_probs = right_path_probs.permute(0, 2, 1) 217 | 218 | # We don't want 0s to ruin leaf probabilities, so replace them with 1s so they don't affect the product 219 | left_filler = torch.zeros(self.left_path_sigs.size()).to(self.device) 220 | left_filler[self.left_path_sigs == 0] = 1 221 | right_filler = torch.zeros(self.right_path_sigs.size()).to(self.device) 222 | right_filler[self.right_path_sigs == 0] = 1 223 | 224 | left_path_probs = left_path_probs.add(left_filler) 225 | right_path_probs = right_path_probs.add(right_filler) 226 | 227 | probs = torch.cat((left_path_probs, right_path_probs), dim=1) 228 | probs = probs.prod(dim=1) 229 | actions = probs.mm(self.action_probs) 230 | else: 231 | left_path_probs = self.left_path_sigs * sig_vals.t() 232 | right_path_probs = self.right_path_sigs * one_minus_sig.t() 233 | # We don't want 0s to ruin leaf probabilities, so replace them with 1s so they don't affect the product 234 | left_filler = torch.zeros(self.left_path_sigs.size(), dtype=torch.float).to(self.device) 235 | left_filler[self.left_path_sigs == 0] = 1 236 | right_filler = torch.zeros(self.right_path_sigs.size(), dtype=torch.float).to(self.device) 237 | right_filler[self.right_path_sigs == 0] = 1 238 | 239 | left_path_probs = torch.add(left_path_probs, left_filler) 240 | right_path_probs = torch.add(right_path_probs, right_filler) 241 | 242 | probs = torch.cat((left_path_probs, right_path_probs), dim=0) 243 | probs = probs.prod(dim=0) 244 | 245 | actions = (self.action_probs * probs.view(1, -1).t()).sum(dim=0) 246 | if not self.is_value: 247 | return self.softmax(actions) 248 | else: 249 | return actions 250 | -------------------------------------------------------------------------------- /gym_requirements.txt: -------------------------------------------------------------------------------- 1 | cffi==1.13.0 2 | cloudpickle==1.2.2 3 | cycler==0.10.0 4 | Cython==0.29.13 5 | fasteners==0.15 6 | future==0.18.0 7 | glfw==1.8.3 8 | gym==0.15.3 9 | imageio==2.6.1 10 | joblib==0.14.0 11 | kiwisolver==1.1.0 12 | matplotlib==3.1.1 13 | monotonic==1.5 14 | numpy==1.17.2 15 | pandas==0.25.1 16 | Pillow>=7.1.0 17 | pycparser==2.19 18 | pyglet==1.3.2 19 | pyparsing==2.4.2 20 | python-dateutil==2.8.0 21 | pytz==2019.3 22 | scikit-learn==0.21.3 23 | # scipy==1.3.1 24 | six==1.12.0 25 | sklearn==0.0 26 | torch 27 | # ==1.3.0 28 | torchvision 29 | # ==0.4.1 30 | -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | Models are saved here by default. 
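Agents write their checkpoints here as `.pth.tar` files through their `save()` methods and restore them with `load()`. A minimal sketch, assuming the `agents` package is importable from your working directory and using the default CartPole-sized configuration; the exact filenames are built from the prefix you pass plus the agent's `bot_name` and an actor/critic suffix:

```python
from agents.prolonet_agent import DeepProLoNet

agent = DeepProLoNet(input_dim=4, output_dim=2)  # defaults to the CartPole tree
agent.save(fn='../models/last')                  # writes ..._actor_.pth.tar and ..._critic_.pth.tar
agent.load(fn='../models/last')                  # restores actor/critic (and the deeper nets when deepening)
```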
-------------------------------------------------------------------------------- /opt_helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CORE-Robotics-Lab/ProLoNets/b1f138ed3520e623a84f2a7d5b1969fa45845700/opt_helpers/__init__.py -------------------------------------------------------------------------------- /opt_helpers/ppo_update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from torch.distributions import Categorical 6 | import numpy as np 7 | 8 | 9 | class PPO: 10 | def __init__(self, actor_critic_arr, two_nets=True, use_gpu=False): 11 | 12 | lr = 1e-3 13 | eps = 1e-5 14 | self.clip_param = 0.2 15 | self.ppo_epoch = 32 16 | self.num_mini_batch = 4 17 | self.value_loss_coef = 0.5 18 | self.entropy_coef = 0.01 19 | self.max_grad_norm = 0.5 20 | self.use_gpu = use_gpu 21 | if two_nets: 22 | self.actor = actor_critic_arr[0] 23 | self.critic = actor_critic_arr[1] 24 | if self.actor.input_dim > 100: 25 | self.actor_opt = optim.RMSprop(self.actor.parameters(), lr=1e-5) 26 | self.critic_opt = optim.RMSprop(self.critic.parameters(), lr=1e-5) 27 | elif self.actor.input_dim < 8: 28 | self.actor_opt = optim.RMSprop(self.actor.parameters(), lr=2e-2) 29 | self.critic_opt = optim.RMSprop(self.critic.parameters(), lr=2e-2) 30 | else: 31 | self.actor_opt = optim.RMSprop(self.actor.parameters(), lr=2e-2) 32 | self.critic_opt = optim.RMSprop(self.critic.parameters(), lr=2e-2) 33 | else: 34 | self.actor = actor_critic_arr 35 | self.actor_opt = optim.Adam(self.actor.parameters(), lr=lr, eps=eps) 36 | self.two_nets = two_nets 37 | self.epoch_counter = 0 38 | 39 | def sl_updates(self, rollouts, agent_in, heuristic_teacher): 40 | if self.actor.input_dim < 10: 41 | batch_size = max(rollouts.step // 32, 2) 42 | num_iters = rollouts.step // batch_size 43 | else: 44 | num_iters = 4 45 | batch_size = 32 46 | aggregate_actor_loss = 0 47 | for iteration in range(num_iters): 48 | total_action_loss = torch.Tensor([0]) 49 | total_value_loss = torch.Tensor([0]) 50 | for b in range(batch_size): 51 | sample = rollouts.sample() 52 | if not sample: 53 | break 54 | state = sample['state'] 55 | reward = sample['reward'] 56 | if np.isnan(reward): 57 | continue 58 | new_action_probs = self.actor(*state).view(1, -1) 59 | new_value = self.critic(*state) 60 | label = torch.LongTensor([heuristic_teacher.get_action(state[0].detach().clone().data.cpu().numpy()[0])]) 61 | action_loss = torch.nn.functional.cross_entropy(new_action_probs, label) 62 | new_value = new_value.view(-1, 1) 63 | new_value = new_value[label] 64 | reward = torch.Tensor([reward]).view(-1, 1) 65 | value_loss = F.mse_loss(reward, new_value) 66 | 67 | total_value_loss = total_value_loss + value_loss 68 | total_action_loss = total_action_loss + action_loss 69 | if total_value_loss != 0: 70 | self.critic_opt.zero_grad() 71 | total_value_loss.backward() 72 | self.critic_opt.step() 73 | if total_action_loss != 0: 74 | self.actor_opt.zero_grad() 75 | total_action_loss.backward() 76 | self.actor_opt.step() 77 | aggregate_actor_loss += total_action_loss.item() 78 | # aggregate_actor_loss /= float(num_iters*batch_size) 79 | agent_in.reset() 80 | return aggregate_actor_loss 81 | 82 | def batch_updates(self, rollouts, agent_in, go_deeper=False): 83 | if self.actor.input_dim == 8: 84 | batch_size = max(rollouts.step // 32, 1) 85 | num_iters 
= rollouts.step // batch_size 86 | elif self.actor.input_dim == 4: 87 | batch_size = max(rollouts.step // 32, 2) 88 | num_iters = rollouts.step // batch_size 89 | else: 90 | num_iters = 4 91 | batch_size = 32 92 | total_action_loss = torch.Tensor([0]) 93 | total_value_loss = torch.Tensor([0]) 94 | for iteration in range(num_iters): 95 | total_action_loss = torch.Tensor([0]) 96 | total_value_loss = torch.Tensor([0]) 97 | if self.use_gpu: 98 | total_action_loss = total_action_loss.cuda() 99 | total_value_loss = total_value_loss.cuda() 100 | if go_deeper: 101 | deep_total_action_loss = torch.Tensor([0]) 102 | deep_total_value_loss = torch.Tensor([0]) 103 | if self.use_gpu: 104 | deep_total_value_loss = deep_total_value_loss.cuda() 105 | deep_total_action_loss = deep_total_action_loss.cuda() 106 | samples = [rollouts.sample() for _ in range(batch_size)] 107 | samples = [sample for sample in samples if sample != False] 108 | if len(samples) <= 1: 109 | continue 110 | state = torch.cat([sample['state'][0] for sample in samples], dim=0) 111 | action_probs = torch.Tensor([sample['action_prob'] for sample in samples]) 112 | adv_targ = torch.Tensor([sample['advantage'] for sample in samples]) 113 | reward = torch.Tensor([sample['reward'] for sample in samples]) 114 | old_action_probs = torch.cat([sample['full_prob_vector'].unsqueeze(0) for sample in samples], dim=0) 115 | if True in np.array(np.isnan(adv_targ).tolist()) or \ 116 | True in np.array(np.isnan(reward).tolist()) or \ 117 | True in np.array(np.isnan(old_action_probs).tolist()): 118 | continue 119 | action_taken = torch.Tensor([sample['action_taken'] for sample in samples]) 120 | if self.use_gpu: 121 | action_taken = action_taken.cuda() 122 | state = state.cuda() 123 | action_probs = action_probs.cuda() 124 | old_action_probs = old_action_probs.cuda() 125 | adv_targ = adv_targ.cuda() 126 | reward = reward.cuda() 127 | if samples[0]['hidden_state'] is not None: 128 | actor_hidden_state_batch0 = torch.cat([sample['hidden_state'][0][0] for sample in samples], dim=1) 129 | actor_hidden_state_batch1 = torch.cat([sample['hidden_state'][0][1] for sample in samples], dim=1) 130 | actor_hidden_state = (actor_hidden_state_batch0, actor_hidden_state_batch1) 131 | critic_hidden_state_batch0 = torch.cat([sample['hidden_state'][1][0] for sample in samples], dim=1) 132 | critic_hidden_state_batch1 = torch.cat([sample['hidden_state'][1][1] for sample in samples], dim=1) 133 | critic_hidden_state = (critic_hidden_state_batch0, critic_hidden_state_batch1) 134 | 135 | new_action_probs, _ = self.actor(state, actor_hidden_state) 136 | new_value, _ = self.critic(state, critic_hidden_state) 137 | new_value = new_value.squeeze(1) 138 | new_action_probs = new_action_probs.squeeze(1) 139 | else: 140 | new_action_probs = self.actor(state) 141 | new_value = self.critic(state) 142 | 143 | if go_deeper: 144 | deep_action_probs = torch.Tensor([sample['deeper_action_prob'] for sample in samples]) 145 | deep_adv = torch.Tensor([sample['deeper_advantage'] for sample in samples]) 146 | deeper_old_probs = torch.cat([sample['deeper_full_prob_vector'].unsqueeze(0) for sample in samples], dim=0) 147 | if self.use_gpu: 148 | deep_action_probs = deep_action_probs.cuda() 149 | deeper_old_probs = deeper_old_probs.cuda() 150 | deep_adv = deep_adv.cuda() 151 | 152 | new_deep_probs = agent_in.deeper_action_network(state) 153 | new_deep_vals = agent_in.deeper_value_network(state) 154 | deep_dist = Categorical(new_deep_probs) 155 | deeper_probs = deep_dist.log_prob(action_taken) 156 | 
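                # Both the deeper networks (here) and the shallow networks (further below) are
                # trained with the standard PPO clipped surrogate: with probability ratio
                # r = exp(log_pi_new - log_pi_old) and advantage A, the loss is
                #     -mean(min(r * A, clamp(r, 1 - clip, 1 + clip) * A)) - entropy_bonus.
                # The same objective in isolation (variable names are illustrative, not from
                # this method):
                #
                #     ratio = torch.exp(new_log_probs - old_log_probs)
                #     surr1 = ratio * advantages
                #     surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
                #     loss = -torch.min(surr1, surr2).mean() - entropy_coef * entropy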
deeper_action_indices = [int(action_ind.item()) for action_ind in action_taken] 157 | deeper_val = new_deep_vals[np.arange(0, len(new_deep_vals)), deeper_action_indices] 158 | deeper_entropy = deep_dist.entropy().mean() * self.entropy_coef 159 | # deep_ratio = torch.nn.functional.kl_div(new_deep_probs, old_action_probs, reduction='batchmean').pow(-1) 160 | # # 161 | # deep_clipped = torch.clamp(deep_ratio, 1.0 - self.clip_param, 162 | # 1.0 + self.clip_param).mul(adv_targ).mul(deeper_probs) 163 | # # 164 | # deep_ratio = deep_ratio.mul(adv_targ).mul(deeper_probs) 165 | # deep_action_loss = -torch.min(deep_ratio, deep_clipped).mean() 166 | # 167 | deep_ratio = torch.exp(deeper_probs - deep_action_probs) 168 | deep_surr1 = deep_ratio * deep_adv 169 | deep_surr2 = torch.clamp(deep_ratio, 1.0-self.clip_param, 1+self.clip_param) * deep_adv 170 | deep_action_loss = -torch.min(deep_surr1, deep_surr2).mean() 171 | deep_total_action_loss = deep_total_action_loss + deep_action_loss - deeper_entropy 172 | deeper_value_loss = F.mse_loss(reward, deeper_val) 173 | 174 | deep_total_value_loss = deep_total_value_loss + deeper_value_loss 175 | # Copy over shallow params to deeper network 176 | for weight_index in range(len(self.actor.layers)): 177 | new_act_weight = torch.Tensor(self.actor.layers[weight_index].cpu().data.numpy()) 178 | new_act_comp = torch.Tensor(self.actor.comparators[weight_index].cpu().data.numpy()) 179 | 180 | if self.use_gpu: 181 | new_act_weight = new_act_weight.cuda() 182 | new_act_comp = new_act_comp.cuda() 183 | 184 | agent_in.deeper_action_network.layers[weight_index].data = new_act_weight 185 | agent_in.deeper_action_network.comparators[weight_index].data = new_act_comp 186 | for weight_index in range(len(self.critic.layers)): 187 | new_val_weight = torch.Tensor(self.critic.layers[weight_index].cpu().data.numpy()) 188 | new_val_comp = torch.Tensor(self.critic.comparators[weight_index].cpu().data.numpy()) 189 | if self.use_gpu: 190 | new_val_weight = new_val_weight.cuda() 191 | new_val_comp = new_val_comp.cuda() 192 | agent_in.deeper_value_network.layers[weight_index].data = new_val_weight 193 | agent_in.deeper_value_network.comparators[weight_index].data = new_val_comp 194 | 195 | update_m = Categorical(new_action_probs) 196 | update_log_probs = update_m.log_prob(action_taken) 197 | action_indices = [int(action_ind.item()) for action_ind in action_taken] 198 | new_value = new_value[np.arange(0, len(new_value)), action_indices] 199 | entropy = update_m.entropy().mean().mul(self.entropy_coef) 200 | # Fake PPO Updates: 201 | # ratio = torch.div(update_log_probs, action_probs) 202 | # 203 | # ratio = torch.nn.functional.kl_div(new_action_probs, old_action_probs, reduction='batchmean').pow(-1) 204 | # # # 205 | # clipped = torch.clamp(ratio, 1.0 - self.clip_param, 206 | # 1.0 + self.clip_param).mul(adv_targ).mul(update_log_probs) 207 | # # # 208 | # ratio = ratio.mul(adv_targ).mul(update_log_probs) 209 | # action_loss = -torch.min(ratio, clipped).mean() 210 | # 211 | # Real PPO Updates 212 | ratio = torch.exp(update_log_probs - action_probs) 213 | surr1 = ratio * adv_targ 214 | surr2 = torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ 215 | action_loss = -torch.min(surr1, surr2).mean() 216 | # Policy Gradient: 217 | # action_loss = (torch.sum(torch.mul(update_log_probs, adv_targ).mul(-1), -1)) 218 | value_loss = F.mse_loss(reward, new_value) 219 | 220 | total_value_loss = total_value_loss.add(value_loss) 221 | total_action_loss = 
total_action_loss.add(action_loss).sub(entropy) 222 | nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm) 223 | self.critic_opt.zero_grad() 224 | total_value_loss.backward() 225 | self.critic_opt.step() 226 | nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm) 227 | self.actor_opt.zero_grad() 228 | total_action_loss.backward() 229 | self.actor_opt.step() 230 | if go_deeper: 231 | nn.utils.clip_grad_norm_(agent_in.deeper_value_network.parameters(), self.max_grad_norm) 232 | agent_in.deeper_value_opt.zero_grad() 233 | deep_total_value_loss.backward() 234 | agent_in.deeper_value_opt.step() 235 | nn.utils.clip_grad_norm_(agent_in.deeper_action_network.parameters(), self.max_grad_norm) 236 | agent_in.deeper_actor_opt.zero_grad() 237 | deep_total_action_loss.backward() 238 | agent_in.deeper_actor_opt.step() 239 | agent_in.deepen_networks() 240 | agent_in.reset() 241 | self.epoch_counter += 1 242 | return total_action_loss.item(), total_value_loss.item() 243 | -------------------------------------------------------------------------------- /opt_helpers/prolo_from_language.py: -------------------------------------------------------------------------------- 1 | # Created by Andrew Silva on 7/30/19 2 | import numpy as np 3 | import sys 4 | sys.path.insert(0, '../') 5 | from agents.vectorized_prolonet import ProLoNet 6 | from agents.vectorized_prolonet_helpers import save_prolonet, load_prolonet 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import matplotlib.lines as mlines 10 | import matplotlib.patches as mpatches 11 | from matplotlib.collections import PatchCollection 12 | from tkinter import * 13 | import copy 14 | 15 | global user_command 16 | 17 | 18 | DOMAIN = 'FireSim' # options are: 'FireSim' or 'Build_Marines' or 'robot' 19 | FIRE_OPTIONS = [ 20 | 'If Fire 1 is to my south ', 21 | 'If Fire 1 is to my north ', 22 | 'If Fire 1 is to my east ', 23 | 'If Fire 1 is to my west ', 24 | 'If I am the closest drone to Fire 1', 25 | 'If Fire 2 is to my south ', 26 | 'If Fire 2 is to my north ', 27 | 'If Fire 2 is to my east ', 28 | 'If Fire 2 is to my west ', 29 | 'If I am the closest drone to Fire 2', 30 | 'Move north', 31 | 'Move south', 32 | 'Move east', 33 | 'Move west' 34 | ] 35 | 36 | def construct_tree(node_data_in): 37 | """ 38 | takaes in a data dump from gen_init_params's loop 39 | :param node_data_in: 40 | :return: 41 | """ 42 | weights = [] 43 | comparators = [] 44 | leaves = [] 45 | current_right_path = [] 46 | current_left_path = [] 47 | current_node_index = 0 48 | # Variables necessary for drawing 49 | drawing_info = {} 50 | last_node_leaf = False 51 | node_x = 0 52 | node_y = 0 53 | max_path_len = 1 54 | for node_index in range(len(node_data_in)): 55 | data = node_data_in[node_index][0] 56 | is_leaf = node_data_in[node_index][1] 57 | command_to_write = node_data_in[node_index][2] 58 | if len(current_left_path) + len(current_right_path) > max_path_len: 59 | max_path_len = len(current_left_path)+len(current_right_path) 60 | if is_leaf: 61 | leaves.append( 62 | [current_left_path.copy(), current_right_path.copy(), data]) # paths all set, just throw the leaf on there 63 | if not current_left_path: # If we're at the far right end of the tree, we're done. 
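                    # (The weights/comparators gathered above, together with `leaves` entries of
                    #  the form [left_turn_indices, right_turn_indices, leaf_probs], are exactly
                    #  the format ProLoNet expects, so the return below can be fed straight into
                    #  init_actor_and_critic().)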
64 | drawing_info['lrfinal'] = [command_to_write, node_x, node_y] # save for drawing 65 | draw_tree(drawing_info) 66 | 67 | return np.array(weights), np.array(comparators), leaves 68 | last_node_index = current_node_index - 1 69 | if last_node_index in current_left_path: # If from a left split, just back up a step 70 | current_right_path.append(last_node_index) 71 | current_left_path.remove(last_node_index) 72 | new_node_x = node_x + 15 * 2 # * (1 + max_path_len - (len(current_right_path)+len(current_left_path))) 73 | node_y = node_y 74 | drawing_info['ll' + str(last_node_index)] = [command_to_write, node_x, node_y, node_x+15, node_y+15, new_node_x, node_y] # save for drawing 75 | node_x = new_node_x 76 | else: # If from a right split, back alllll the way up to the last left split. 77 | last_node_index = max(current_left_path) 78 | prev_path_len = len(current_left_path) + len(current_right_path) 79 | current_right_path = [i for i in current_right_path if i < last_node_index] 80 | current_left_path.remove(last_node_index) 81 | current_right_path.append(last_node_index) 82 | # TODO: Fix these x / y previous coords... 83 | new_node_x = drawing_info[last_node_index][1] + (15 * (1 + max_path_len - (len(current_right_path)+len(current_left_path)))) 84 | drawing_node_y = -15*(len(current_left_path)+len(current_right_path)-1) # 15*(len(current_right_path)) - 15*(len(current_left_path)+1) 85 | drawing_node_x = drawing_info[last_node_index][1] # 15*len(current_right_path) - 15*(len(current_left_path)+1) 86 | new_node_y = drawing_info[last_node_index][2] - 15 87 | drawing_info['lr' + str(last_node_index)] = [command_to_write, node_x, node_y, drawing_node_x, drawing_node_y, new_node_x, new_node_y] # save for drawing 88 | node_x = new_node_x 89 | node_y = new_node_y 90 | last_node_leaf = True 91 | else: 92 | weights.append(data[0]) 93 | comparators.append(data[1]) 94 | current_left_path.append(current_node_index) 95 | if last_node_leaf: 96 | new_node_x = node_x - 15 97 | new_node_y = node_y - 15 98 | else: 99 | new_node_x = node_x - 15 100 | new_node_y = node_y - 15 101 | drawing_info[current_node_index] = [command_to_write, node_x, node_y, node_x, node_y, new_node_x, new_node_y] 102 | node_x = new_node_x 103 | node_y = new_node_y 104 | 105 | current_node_index += 1 106 | last_node_leaf = False 107 | draw_tree(drawing_info) 108 | return False 109 | 110 | 111 | def gen_init_params(): 112 | 113 | data_dump = [] 114 | while True: 115 | # user_command = get_node_data_from_user() 116 | gui = make_gui(DOMAIN) 117 | 118 | # start the GUI 119 | gui.mainloop() 120 | if user_command == 'Undo': 121 | data_dump.pop(-1) 122 | else: 123 | data, is_leaf = data_from_command(user_command, DOMAIN) 124 | data_dump.append([data, is_leaf, copy.deepcopy(user_command)]) 125 | finished_outcome = construct_tree(data_dump) 126 | if finished_outcome: 127 | return finished_outcome[0], finished_outcome[1], finished_outcome[2] 128 | 129 | 130 | def make_gui(domain_in): 131 | if domain_in == 'FireSim': 132 | gui = Tk() 133 | button0 = Button(gui, text='Undo', fg='black', bg='white', 134 | command=lambda: press(-1, gui, domain_in), height=2, width=30) 135 | button0.grid(row=1, column=0) 136 | 137 | button1 = Button(gui, text=FIRE_OPTIONS[0], fg='black', bg='white', 138 | command=lambda: press(0, gui, domain_in), height=2, width=30) 139 | button1.grid(row=2, column=0) 140 | 141 | button2 = Button(gui, text=FIRE_OPTIONS[1], fg='black', bg='white', 142 | command=lambda: press(1, gui, domain_in), height=2, width=30) 143 | 
button2.grid(row=3, column=0) 144 | 145 | button3 = Button(gui, text=FIRE_OPTIONS[2], fg='black', bg='white', 146 | command=lambda: press(2, gui, domain_in), height=2, width=30) 147 | button3.grid(row=4, column=0) 148 | 149 | button4 = Button(gui, text=FIRE_OPTIONS[3], fg='black', bg='white', 150 | command=lambda: press(3, gui, domain_in), height=2, width=30) 151 | button4.grid(row=5, column=0) 152 | 153 | button5 = Button(gui, text=FIRE_OPTIONS[4], fg='black', bg='white', 154 | command=lambda: press(4, gui, domain_in), height=2, width=30) 155 | button5.grid(row=6, column=0) 156 | 157 | button6 = Button(gui, text=FIRE_OPTIONS[5], fg='black', bg='white', 158 | command=lambda: press(5, gui, domain_in), height=2, width=30) 159 | button6.grid(row=7, column=0) 160 | 161 | button7= Button(gui, text=FIRE_OPTIONS[6], fg='black', bg='white', 162 | command=lambda: press(6, gui, domain_in), height=2, width=30) 163 | button7.grid(row=8, column=0) 164 | 165 | button8 = Button(gui, text=FIRE_OPTIONS[7], fg='black', bg='white', 166 | command=lambda: press(7, gui, domain_in), height=2, width=30) 167 | button8.grid(row=9, column=0) 168 | 169 | button9 = Button(gui, text=FIRE_OPTIONS[8], fg='black', bg='white', 170 | command=lambda: press(8, gui, domain_in), height=2, width=30) 171 | button9.grid(row=10, column=0) 172 | 173 | button10 = Button(gui, text=FIRE_OPTIONS[9], fg='black', bg='white', 174 | command=lambda: press(9, gui, domain_in), height=2, width=30) 175 | button10.grid(row=11, column=0) 176 | 177 | button11 = Button(gui, text=FIRE_OPTIONS[10], fg='black', bg='white', 178 | command=lambda: press(10, gui, domain_in), height=2, width=30) 179 | button11.grid(row=1, column=1) 180 | 181 | button12 = Button(gui, text=FIRE_OPTIONS[11], fg='black', bg='white', 182 | command=lambda: press(11, gui, domain_in), height=2, width=30) 183 | button12.grid(row=2, column=1) 184 | 185 | button13 = Button(gui, text=FIRE_OPTIONS[12], fg='black', bg='white', 186 | command=lambda: press(12, gui, domain_in), height=2, width=30) 187 | button13.grid(row=3, column=1) 188 | 189 | button14 = Button(gui, text=FIRE_OPTIONS[13], fg='black', bg='white', 190 | command=lambda: press(13, gui, domain_in), height=2, width=30) 191 | button14.grid(row=4, column=1) 192 | return gui 193 | 194 | 195 | def close_event(): 196 | plt.close() # timer calls this function after 10 seconds and closes the window 197 | 198 | 199 | def press(int_in, gui_in, domain_in): 200 | if domain_in == 'FireSim': 201 | options = FIRE_OPTIONS 202 | global user_command 203 | user_command = options[int_in] 204 | if int_in == -1: 205 | user_command = 'Undo' 206 | gui_in.destroy() 207 | close_event() 208 | 209 | 210 | def data_from_command(command_in, domain="fire_sim"): 211 | """ 212 | Interpret some user command into weights/comparators or a leaf 213 | :param command_in: text from speech to text? 214 | :param domain: which domain should I be looking through? fire_sim, sc2, or sawyer 215 | :return: data (tuple of some sort), is_leaf=True/False 216 | """ 217 | print(f"Command: {command_in}") 218 | if domain == 'FireSim': 219 | weights = np.zeros(6) 220 | comparator = [5.] 
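        # (Each non-leaf command below becomes a (weights, comparator) node that tests,
        #  roughly, weights . observation > comparator; with the very large alpha passed in
        #  prolo_from_language_main() the sigmoid acts as a hard threshold. Leaf commands
        #  instead return a one-hot prior over the 5 FireSim actions.)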
221 | leaf = np.zeros(5) 222 | if command_in == FIRE_OPTIONS[0]: # if fire 1 south 223 | weights[0] = 1 224 | return (weights, comparator), False 225 | elif command_in == FIRE_OPTIONS[1]: # if fire 1 north 226 | weights[0] = -1 227 | return (weights, comparator), False 228 | elif command_in == FIRE_OPTIONS[2]: # if fire 1 east 229 | weights[1] = -1 230 | return (weights, comparator), False 231 | elif command_in == FIRE_OPTIONS[3]: # if fire 1 west 232 | weights[1] = 1 233 | return (weights, comparator), False 234 | elif command_in == FIRE_OPTIONS[4]: # if closer to fire 1 235 | weights[4] = 5 236 | comparator = [1.] 237 | return (weights, comparator), False 238 | elif command_in == FIRE_OPTIONS[5]: # if fire 2 south 239 | weights[2] = 1 240 | return (weights, comparator), False 241 | elif command_in == FIRE_OPTIONS[6]: # if fire 2 north 242 | weights[2] = -1 243 | return (weights, comparator), False 244 | elif command_in == FIRE_OPTIONS[7]: # if fire 2 east 245 | weights[3] = -1 246 | return (weights, comparator), False 247 | elif command_in == FIRE_OPTIONS[8]: # if fire 2 west 248 | weights[3] = 1 249 | return (weights, comparator), False 250 | elif command_in == FIRE_OPTIONS[9]: # if closer to fire 2 251 | weights[5] = 5 252 | comparator = [1.] 253 | return (weights, comparator), False 254 | elif command_in == FIRE_OPTIONS[10]: # move north 255 | leaf[3] = 1 256 | return leaf, True 257 | elif command_in == FIRE_OPTIONS[11]: # move south 258 | leaf[1] = 1 259 | return leaf, True 260 | elif command_in == FIRE_OPTIONS[12]: # move east 261 | leaf[2] = 1 262 | return leaf, True 263 | elif command_in == FIRE_OPTIONS[13]: # move west 264 | leaf[0] = 1 265 | return leaf, True 266 | print("Something went wrong...") 267 | 268 | 269 | def init_actor_and_critic(weights, comparators, leaves, use_gpu=False, alpha=1.0): 270 | """ 271 | Take in some np arrays of initialization params from gen_init_params 272 | :param weights: 273 | :param comparators: 274 | :param leaves: 275 | :param use_gpu: 276 | :param alpha: 277 | :return: 278 | """ 279 | dim_in = weights.shape[-1] 280 | dim_out = leaves[0][-1].shape[-1] 281 | if len(comparators[0]) == 1: 282 | vectorized = False 283 | else: 284 | vectorized = True 285 | action_network = ProLoNet(input_dim=dim_in, 286 | output_dim=dim_out, 287 | weights=weights, 288 | comparators=comparators, 289 | leaves=leaves, 290 | alpha=alpha, 291 | is_value=False, 292 | use_gpu=use_gpu, 293 | vectorized=vectorized) 294 | value_network = ProLoNet(input_dim=dim_in, 295 | output_dim=dim_out, 296 | weights=weights, 297 | comparators=comparators, 298 | leaves=leaves, 299 | alpha=alpha, 300 | is_value=True, 301 | use_gpu=use_gpu, 302 | vectorized=vectorized) 303 | return action_network, value_network 304 | 305 | 306 | def prolo_from_language_main(): 307 | weights, comparators, leaves = gen_init_params() 308 | actor, critic = init_actor_and_critic(weights, comparators, leaves, alpha=99999) 309 | save_prolonet('../study_models/usermadeprolo' + DOMAIN + "_actor_.pth.tar", actor) 310 | save_prolonet('../study_models/usermadeprolo' + DOMAIN + "_critic_.pth.tar", critic) 311 | 312 | 313 | def draw_tree(draw_info): 314 | patches = [] 315 | 316 | fig, ax = plt.subplots(figsize=(18, 12)) 317 | 318 | timer = fig.canvas.new_timer(interval=10000) # creating a timer object and setting an interval of 10 sec 319 | timer.add_callback(close_event) 320 | lines = [] 321 | text_size = 14 322 | box_width = 10 323 | for node in draw_info.values(): 324 | x_coord = node[1] 325 | y_coord = node[2] 326 | 
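        # Each draw_info value comes from construct_tree above: it always starts with
        # [label_text, x, y]; entries longer than three elements also carry connector-line
        # endpoints [start_x, start_y, end_x, end_y], which are drawn as Line2D segments below.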
text = node[0] 327 | if len(node) > 3: 328 | l1_x_coord = node[3] 329 | l1_y_coord = node[4] 330 | l2_x_coord = node[5] 331 | l2_y_coord = node[6] 332 | this_node = mpatches.FancyBboxPatch(xy=[x_coord, y_coord], 333 | width=box_width, 334 | height=1, 335 | boxstyle=mpatches.BoxStyle("Round", pad=4)) 336 | patches.append(this_node) 337 | # last_node = last_node_coords[node//2] 338 | line = mlines.Line2D([l1_x_coord+(box_width//2), l2_x_coord+(box_width//2)], 339 | [l1_y_coord, l2_y_coord+1]) 340 | lines.append(line) 341 | plt.text(x_coord+5, y_coord, text, ha="center", family='sans-serif', size=text_size) 342 | 343 | colors = np.linspace(0, 1, len(patches)) 344 | collection = PatchCollection(patches, cmap=plt.cm.hsv, alpha=0.3) 345 | collection.set_array(np.array(colors)) 346 | ax.add_collection(collection) 347 | for line in lines: 348 | ax.add_line(line) 349 | 350 | plt.axis('equal') 351 | plt.axis('off') 352 | plt.tight_layout() 353 | plt.ion() 354 | # timer.start() 355 | plt.savefig(f'{DOMAIN}_expert_policy.png') 356 | 357 | plt.show() 358 | # timer.stop() 359 | plt.pause(0.01) 360 | 361 | prolo_from_language_main() 362 | -------------------------------------------------------------------------------- /opt_helpers/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import numpy as np 4 | 5 | 6 | class ReplayBufferSingleAgent(object): 7 | def __init__(self): 8 | self.states_list = [] 9 | self.action_probs_list = [] 10 | self.value_list = [] 11 | self.hidden_state_list = [] 12 | self.rewards_list = [] 13 | self.deeper_value_list = [] 14 | self.deeper_action_list = [] 15 | self.deeper_advantage_list = [] 16 | self.action_taken_list = [] 17 | self.advantage_list = [] 18 | self.full_probs_list = [] 19 | self.deeper_full_probs_list = [] 20 | self.step = -1 21 | 22 | def __getstate__(self): 23 | all_data = { 24 | 'states': self.states_list, 25 | 'actions': self.action_probs_list, 26 | 'values': self.value_list, 27 | 'hidden_states': self.hidden_state_list, 28 | 'rewards': self.rewards_list, 29 | 'steps': self.step, 30 | 'deeper_values': self.deeper_value_list, 31 | 'deeper_actions': self.deeper_action_list, 32 | 'actions_taken': self.action_taken_list, 33 | 'advantage_list': self.advantage_list, 34 | 'deeper_advantage_list': self.deeper_advantage_list, 35 | 'full_probs_list': self.full_probs_list, 36 | 'deeper_full_probs_list': self.deeper_full_probs_list 37 | } 38 | return all_data 39 | 40 | def __setstate__(self, state): 41 | self.states_list = state['states'] 42 | self.action_probs_list = state['actions'] 43 | self.value_list = state['values'] 44 | self.hidden_state_list = state['hidden_states'] 45 | self.rewards_list = state['rewards'] 46 | self.step = state['steps'] 47 | self.deeper_value_list = state['deeper_values'] 48 | self.deeper_action_list = state['deeper_actions'] 49 | self.action_taken_list = state['actions_taken'] 50 | self.advantage_list = state['advantage_list'] 51 | self.deeper_advantage_list = state['deeper_advantage_list'] 52 | self.full_probs_list = state['full_probs_list'] 53 | self.deeper_full_probs_list = state['deeper_full_probs_list'] 54 | 55 | def extend(self, state): 56 | self.states_list.extend(state['states']) 57 | self.action_probs_list.extend(state['actions']) 58 | self.value_list.extend(state['values']) 59 | self.hidden_state_list.extend(state['hidden_states']) 60 | self.rewards_list.extend(state['rewards']) 61 | self.step += state['steps'] 62 | 
self.deeper_value_list.extend(state['deeper_values']) 63 | self.deeper_action_list.extend(state['deeper_actions']) 64 | self.action_taken_list.extend(state['actions_taken']) 65 | self.advantage_list.extend(state['advantage_list']) 66 | self.deeper_advantage_list.extend(state['deeper_advantage_list']) 67 | self.full_probs_list.extend(state['full_probs_list']) 68 | self.deeper_full_probs_list.extend(state['deeper_full_probs_list']) 69 | 70 | def trim(self, amount): 71 | self.states_list = self.states_list[-amount:] 72 | self.action_probs_list = self.action_probs_list[-amount:] 73 | self.value_list = self.value_list[-amount:] 74 | self.hidden_state_list = self.hidden_state_list[-amount:] 75 | self.rewards_list = self.rewards_list[-amount:] 76 | self.step = min(self.step, amount) 77 | self.deeper_value_list = self.deeper_value_list[-amount:] 78 | self.deeper_action_list = self.deeper_action_list[-amount:] 79 | self.action_taken_list = self.action_taken_list[-amount:] 80 | self.advantage_list = self.advantage_list[-amount:] 81 | self.deeper_advantage_list = self.deeper_advantage_list[-amount:] 82 | self.full_probs_list = self.full_probs_list[-amount:] 83 | self.deeper_full_probs_list = self.deeper_full_probs_list[-amount:] 84 | self.step = amount 85 | 86 | def insert(self, 87 | obs=None, 88 | recurrent_hidden_states=None, 89 | action_log_probs=None, 90 | value_preds=None, 91 | deeper_action_log_probs=None, 92 | deeper_value_pred=None, 93 | last_action=None, 94 | full_probs_vector=None, 95 | deeper_full_probs_vector=None, 96 | rewards=None): 97 | self.states_list.append(obs) 98 | self.hidden_state_list.append(recurrent_hidden_states) 99 | self.action_probs_list.append(action_log_probs) 100 | self.value_list.append(value_preds) 101 | self.rewards_list.append(rewards) 102 | self.deeper_action_list.append(deeper_action_log_probs) 103 | self.deeper_value_list.append(deeper_value_pred) 104 | self.action_taken_list.append(last_action) 105 | self.full_probs_list.append(full_probs_vector) 106 | self.deeper_full_probs_list.append(deeper_full_probs_vector) 107 | # self.done_list.append(done) 108 | self.step += 1 109 | 110 | def clear(self): 111 | del self.states_list[:] 112 | del self.hidden_state_list[:] 113 | del self.value_list[:] 114 | del self.action_probs_list[:] 115 | del self.rewards_list[:] 116 | del self.deeper_value_list[:] 117 | del self.deeper_action_list[:] 118 | del self.action_taken_list[:] 119 | del self.deeper_advantage_list[:] 120 | del self.advantage_list[:] 121 | del self.full_probs_list[:] 122 | del self.deeper_full_probs_list[:] 123 | self.step = 0 124 | 125 | def sample(self): 126 | # randomly sample a time step 127 | if len(self.states_list) <= 0: 128 | return False 129 | t = random.randint(0, len(self.states_list)-1) 130 | sample_back = { 131 | 'state': self.states_list[t], 132 | 'hidden_state': self.hidden_state_list[t], 133 | 'action_prob': self.action_probs_list[t], 134 | 'value_pred': self.value_list[t], 135 | 'deeper_action_prob': self.deeper_action_list[t], 136 | 'deeper_value_pred': self.deeper_value_list[t], 137 | 'reward': self.rewards_list[t], 138 | 'advantage': self.advantage_list[t], 139 | 'action_taken': self.action_taken_list[t], 140 | 'deeper_advantage': self.deeper_advantage_list[t], 141 | 'full_prob_vector': self.full_probs_list[t], 142 | 'deeper_full_prob_vector': self.deeper_full_probs_list[t] 143 | } 144 | return sample_back 145 | 146 | 147 | def discount_reward(reward, value, deeper_value): 148 | R = 0 149 | rewards = [] 150 | all_rewards = reward 151 | 
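    # Note: reward_sum below appears to be unused within this function; callers such as
    # gym_runner.run_episode compute their own episode totals directly from the rewards list.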
reward_sum = sum(all_rewards) 152 | all_values = value 153 | deeper_all_values = deeper_value 154 | # Discount future rewards back to the present using gamma 155 | advantages = [] 156 | deeper_advantages = [] 157 | 158 | for r, v, d_v in zip(all_rewards[::-1], all_values[::-1], deeper_all_values[::-1]): 159 | R = r + 0.99 * R 160 | rewards.insert(0, R) 161 | advantages.insert(0, R - v) 162 | if d_v is not None: 163 | deeper_advantages.insert(0, R - d_v) 164 | advantages = torch.Tensor(advantages) 165 | rewards = torch.Tensor(rewards) 166 | 167 | if len(deeper_advantages) > 0: 168 | deeper_advantages = torch.Tensor(deeper_advantages) 169 | deeper_advantages = (deeper_advantages - deeper_advantages.mean()) / ( 170 | deeper_advantages.std() + torch.Tensor([np.finfo(np.float32).eps])) 171 | deeper_advantage_list = deeper_advantages.detach().clone().cpu().numpy().tolist() 172 | else: 173 | deeper_advantage_list = [None] * len(all_rewards) 174 | # Scale rewards 175 | rewards = (rewards - rewards.mean()) / (rewards.std() + torch.Tensor([np.finfo(np.float32).eps])) 176 | advantages = (advantages - advantages.mean()) / (advantages.std() + torch.Tensor([np.finfo(np.float32).eps])) 177 | rewards_list = rewards.detach().clone().cpu().numpy().tolist() 178 | advantage_list = advantages.detach().clone().cpu().numpy().tolist() 179 | return rewards_list, advantage_list, deeper_advantage_list 180 | -------------------------------------------------------------------------------- /python38.txt: -------------------------------------------------------------------------------- 1 | argon2-cffi==21.3.0 2 | argon2-cffi-bindings==21.2.0 3 | asttokens==2.0.5 4 | attrs==21.4.0 5 | backcall==0.2.0 6 | bleach==4.1.0 7 | Box2D==2.3.10 8 | cffi==1.13.0 9 | cloudpickle==1.2.2 10 | cycler==0.10.0 11 | Cython==0.29.13 12 | debugpy==1.5.1 13 | decorator==5.1.1 14 | defusedxml==0.7.1 15 | entrypoints==0.4 16 | executing==0.8.3 17 | fasteners==0.15 18 | future==0.18.0 19 | glfw==1.8.3 20 | gym==0.15.3 21 | imageio==2.6.1 22 | importlib-resources==5.4.0 23 | ipykernel==6.9.1 24 | ipython==8.1.1 25 | ipython-genutils==0.2.0 26 | ipywidgets==7.6.5 27 | jedi==0.18.1 28 | Jinja2==3.0.3 29 | joblib==0.14.0 30 | jsonschema==4.4.0 31 | jupyter==1.0.0 32 | jupyter-client==7.1.2 33 | jupyter-console==6.4.0 34 | jupyter-core==4.9.2 35 | jupyterlab-pygments==0.1.2 36 | jupyterlab-widgets==1.0.2 37 | kiwisolver==1.1.0 38 | MarkupSafe==2.1.0 39 | matplotlib==3.1.1 40 | matplotlib-inline==0.1.3 41 | mistune==0.8.4 42 | monotonic==1.5 43 | nbclient==0.5.11 44 | nbconvert==6.4.2 45 | nbformat==5.1.3 46 | nest-asyncio==1.5.4 47 | notebook==6.4.8 48 | numpy==1.21.2 49 | packaging==21.3 50 | pandas==1.3.3 51 | pandocfilters==1.5.0 52 | parso==0.8.3 53 | pexpect==4.8.0 54 | pickleshare==0.7.5 55 | Pillow==8.3.1 56 | prometheus-client==0.13.1 57 | prompt-toolkit==3.0.28 58 | ptyprocess==0.7.0 59 | pure-eval==0.2.2 60 | pycparser==2.19 61 | pyglet==1.3.2 62 | Pygments==2.11.2 63 | pyparsing==2.4.2 64 | pyrsistent==0.18.1 65 | python-dateutil==2.8.0 66 | pytz==2019.3 67 | pyzmq==22.3.0 68 | qtconsole==5.2.2 69 | QtPy==2.0.1 70 | scikit-learn==0.21.3 71 | scipy==1.7.0 72 | Send2Trash==1.8.0 73 | six==1.12.0 74 | sklearn==0.0 75 | stack-data==0.2.0 76 | terminado==0.13.2 77 | testpath==0.6.0 78 | torch==1.9.0 79 | torchvision==0.10.0 80 | tornado==6.1 81 | traitlets==5.1.1 82 | typing-extensions==3.10.0.0 83 | wcwidth==0.2.5 84 | webencodings==0.5.1 85 | widgetsnbextension==3.5.2 86 | zipp==3.7.0 87 | 
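# Assumed usage (not documented in this file): install into a Python 3.8 environment with
#   pip install -r python38.txt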
-------------------------------------------------------------------------------- /runfiles/gym_runner.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import sys 4 | import torch 5 | import os 6 | sys.path.insert(0, '../') 7 | from agents.prolonet_agent import DeepProLoNet 8 | # from agents.non_deep_prolonet_agent import ProLoLoki 9 | # from agents.random_prolonet_agent import RandomProLoNet 10 | from agents.heuristic_agent import LunarHeuristic, CartPoleHeuristic, RandomHeuristic 11 | from agents.lstm_agent import LSTMNet 12 | from agents.baseline_agent import FCNet 13 | from agents.py_djinn_agent import DJINNAgent 14 | import random 15 | from opt_helpers.replay_buffer import discount_reward 16 | import torch.multiprocessing as mp 17 | import argparse 18 | import copy 19 | 20 | 21 | def run_episode(q, agent_in, ENV_NAME, env_seed_in=42): 22 | agent = agent_in.duplicate() 23 | if ENV_NAME == 'lunar': 24 | env = gym.make('LunarLander-v2') 25 | elif ENV_NAME == 'cart': 26 | env = gym.make('CartPole-v1') 27 | else: 28 | raise Exception('No valid environment selected') 29 | if seed_in is not None: 30 | env.seed(env_seed_in) 31 | np.random.seed(env_seed_in) 32 | torch.random.manual_seed(env_seed_in) 33 | random.seed(env_seed_in) 34 | 35 | state = env.reset() # Reset environment and record the starting state 36 | done = False 37 | 38 | while not done: 39 | action = agent.get_action(state) 40 | # Step through environment using chosen action 41 | state, reward, done, _ = env.step(action) 42 | # Save reward 43 | agent.save_reward(reward) 44 | if done: 45 | break 46 | reward_sum = np.sum(agent.replay_buffer.rewards_list) 47 | rewards_list, advantage_list, deeper_advantage_list = discount_reward(agent.replay_buffer.rewards_list, 48 | agent.replay_buffer.value_list, 49 | agent.replay_buffer.deeper_value_list) 50 | agent.replay_buffer.rewards_list = rewards_list 51 | agent.replay_buffer.advantage_list = advantage_list 52 | agent.replay_buffer.deeper_advantage_list = deeper_advantage_list 53 | 54 | to_return = [reward_sum, copy.deepcopy(agent.replay_buffer.__getstate__())] 55 | if q is not None: 56 | try: 57 | q.put(to_return) 58 | except RuntimeError as e: 59 | print(e) 60 | return to_return 61 | return to_return 62 | 63 | 64 | def main(episodes, agent, num_processes, ENV_NAME, random_seed): 65 | running_reward_array = [] 66 | for episode in range(episodes): 67 | master_reward = 0 68 | if random_seed is not None: 69 | random_seed_in = random_seed + episode 70 | else: 71 | random_seed_in = random_seed 72 | returned_object = run_episode(None, agent_in=agent, ENV_NAME=ENV_NAME, env_seed_in=random_seed_in) 73 | master_reward += returned_object[0] 74 | running_reward_array.append(returned_object[0]) 75 | agent.replay_buffer.extend(returned_object[1]) 76 | 77 | reward = master_reward / float(num_processes) 78 | agent.end_episode(reward, num_processes) 79 | 80 | running_reward = sum(running_reward_array[-100:]) / float(min(100.0, len(running_reward_array))) 81 | if episode % 50 == 0: 82 | print(f'Episode {episode} Last Reward: {reward} Average Reward: {running_reward}') 83 | if episode % 250 == 0: 84 | agent.lower_lr() 85 | 86 | return running_reward_array 87 | 88 | 89 | if __name__ == "__main__": 90 | parser = argparse.ArgumentParser() 91 | parser.add_argument("-a", "--agent_type", help="architecture of agent to run", type=str, default='djinn') 92 | parser.add_argument("-e", "--episodes", help="how many episodes", type=int, 
default=1000) 93 | parser.add_argument("-p", "--processes", help="how many processes?", type=int, default=1) 94 | parser.add_argument("-env", "--env_type", help="environment to run on", type=str, default='cart') 95 | parser.add_argument("-gpu", help="run on GPU?", action='store_true') 96 | parser.add_argument("-vec", help="Vectorized ProLoNet?", action='store_true') 97 | parser.add_argument("-adv", help="Adversarial ProLoNet?", action='store_true') 98 | parser.add_argument("-rand", help="Random ProLoNet?", action='store_true') 99 | parser.add_argument("-deep", help="Deepening?", action='store_true') 100 | parser.add_argument("-s", "--sl_init", help="sl to rl for fc net?", action='store_true') 101 | parser.add_argument("--reproduce", help="use saved random seeds and make deterministic?", action="store_true") 102 | 103 | args = parser.parse_args() 104 | AGENT_TYPE = args.agent_type # 'shallow_prolo', 'prolo', 'random', 'fc', 'lstm' 105 | ADVERSARIAL = args.adv # Adversarial prolo, applies for AGENT_TYPE=='shallow_prolo' 106 | SL_INIT = args.sl_init # SL->RL fc, applies only for AGENT_TYPE=='fc' 107 | NUM_EPS = args.episodes # num episodes Default 1000 108 | NUM_PROCS = args.processes # num concurrent processes Default 1 109 | ENV_TYPE = args.env_type # 'cart' or 'lunar' Default 'cart' 110 | USE_GPU = args.gpu # Applies for 'prolo' only. use gpu? Default false 111 | VECTORIZED = args.vec # Applies for 'prolo' vectorized or no? Default false 112 | RANDOM = args.rand # Applies for 'prolo' random init or no? Default false 113 | DEEPEN = args.deep # Applies for 'prolo' deepen or no? Default false 114 | EPSILON = 1.0 # Chance to take random action 115 | EPSILON_DECAY = 0.99 # Multiplicative decay on epsilon 116 | EPSILON_MIN = 0.1 # Floor for epsilon 117 | 118 | REPRODUCE = args.reproduce 119 | # torch.set_num_threads(NUM_PROCS) 120 | torch.use_deterministic_algorithms(True) 121 | for NUM_PROCS in [NUM_PROCS]: 122 | if ENV_TYPE == 'lunar': 123 | init_env = gym.make('LunarLander-v2') 124 | dim_in = init_env.observation_space.shape[0] 125 | dim_out = init_env.action_space.n 126 | elif ENV_TYPE == 'cart': 127 | init_env = gym.make('CartPole-v1') 128 | dim_in = init_env.observation_space.shape[0] 129 | dim_out = init_env.action_space.n 130 | else: 131 | raise Exception('No valid environment selected') 132 | 133 | print(f"Agent {AGENT_TYPE} on {ENV_TYPE} with {NUM_PROCS} runners") 134 | # mp.set_start_method('spawn') 135 | # mp.set_sharing_strategy('file_system') 136 | if REPRODUCE: 137 | if ENV_TYPE == 'cart': 138 | seeds = [42, 43, 45, 46, 47] 139 | if AGENT_TYPE == 'fc' and SL_INIT: 140 | seeds = [42, 43, 48, 49, 50] 141 | if AGENT_TYPE == 'prolo' and RANDOM: 142 | seeds = [100, 62, 65, 104, 112] 143 | elif ENV_TYPE == 'lunar': 144 | if AGENT_TYPE == 'djinn': 145 | seeds = [62, 51, 69, 60, 57] 146 | elif AGENT_TYPE == 'lstm': 147 | seeds = [42, 54, 49, 45, 53] 148 | elif AGENT_TYPE == 'fc' and SL_INIT: 149 | seeds = [42, 43, 44, 45, 46] 150 | elif AGENT_TYPE == 'fc': 151 | seeds = [42, 60, 54, 82, 75] 152 | elif AGENT_TYPE == 'prolo' and RANDOM: 153 | seeds = [45, 48, 49, 53, 67] 154 | else: 155 | seeds = [42, 46, 48, 50, 51] 156 | for i in range(5): 157 | if REPRODUCE: 158 | seed_in = seeds[i] 159 | 160 | init_env.seed(seed_in) 161 | np.random.seed(seed_in) 162 | torch.random.manual_seed(seed_in) 163 | random.seed(seed_in) 164 | os.environ['PYTHONHASHSEED'] = str(seed_in) 165 | torch.backends.cudnn.deterministic = True 166 | torch.backends.cudnn.benchmark = False 167 | torch.backends.cudnn.enabled 
= False 168 | torch.cuda.manual_seed_all(seed_in) 169 | torch.cuda.manual_seed(seed_in) 170 | else: 171 | seed_in = None 172 | bot_name = AGENT_TYPE + ENV_TYPE 173 | if USE_GPU: 174 | bot_name += 'GPU' 175 | if AGENT_TYPE == 'prolo': 176 | policy_agent = DeepProLoNet(distribution='one_hot', 177 | bot_name=bot_name, 178 | input_dim=dim_in, 179 | output_dim=dim_out, 180 | use_gpu=USE_GPU, 181 | vectorized=VECTORIZED, 182 | randomized=RANDOM, 183 | adversarial=ADVERSARIAL, 184 | deepen=DEEPEN, 185 | epsilon=EPSILON, 186 | epsilon_decay=EPSILON_DECAY, 187 | epsilon_min=EPSILON_MIN 188 | ) 189 | elif AGENT_TYPE == 'fc': 190 | policy_agent = FCNet(input_dim=dim_in, 191 | bot_name=bot_name, 192 | output_dim=dim_out, 193 | num_hidden=1, 194 | sl_init=SL_INIT) 195 | 196 | elif AGENT_TYPE == 'lstm': 197 | policy_agent = LSTMNet(input_dim=dim_in, 198 | bot_name=bot_name, 199 | output_dim=dim_out) 200 | elif AGENT_TYPE == 'random': 201 | policy_agent = RandomHeuristic(bot_name=bot_name, 202 | action_dim=dim_out) 203 | elif AGENT_TYPE == 'heuristic': 204 | if ENV_TYPE == 'lunar': 205 | policy_agent = LunarHeuristic(bot_name=bot_name) 206 | elif ENV_TYPE == 'cart': 207 | policy_agent = CartPoleHeuristic(bot_name=bot_name) 208 | elif AGENT_TYPE == 'djinn': 209 | policy_agent = DJINNAgent(bot_name=bot_name, 210 | input_dim=dim_in, 211 | output_dim=dim_out) 212 | else: 213 | raise Exception('No valid network selected') 214 | 215 | # if SL_INIT and i == 0: 216 | # NUM_EPS += policy_agent.action_loss_threshold 217 | num_procs = NUM_PROCS 218 | reward_array = main(NUM_EPS, policy_agent, num_procs, ENV_TYPE, seed_in) 219 | -------------------------------------------------------------------------------- /runfiles/minigame_runner.py: -------------------------------------------------------------------------------- 1 | import sc2 2 | from sc2 import Race, Difficulty 3 | import os 4 | import sys 5 | 6 | from sc2.constants import * 7 | from sc2.position import Pointlike, Point2 8 | from sc2.player import Bot, Computer 9 | from sc2.unit import Unit as sc2Unit 10 | sys.path.insert(0, os.path.abspath('../')) 11 | sys.path.insert(0, '../../') 12 | import torch 13 | from agents.prolonet_agent import DeepProLoNet 14 | from agents.heuristic_agent import RandomHeuristic, StarCraftMicroHeuristic 15 | from agents.py_djinn_agent import DJINNAgent 16 | from agents.lstm_agent import LSTMNet 17 | from agents.baseline_agent import FCNet 18 | from opt_helpers.replay_buffer import discount_reward 19 | from runfiles import sc_helpers 20 | import numpy as np 21 | import torch.multiprocessing as mp 22 | import argparse 23 | 24 | DEBUG = False 25 | SUPER_DEBUG = False 26 | if SUPER_DEBUG: 27 | DEBUG = True 28 | 29 | FAILED_REWARD = -0.0 30 | SUCCESS_BUILD_REWARD = 1. 31 | SUCCESS_TRAIN_REWARD = 1. 32 | SUCCESS_SCOUT_REWARD = 1. 33 | SUCCESS_ATTACK_REWARD = 1. 34 | SUCCESS_MINING_REWARD = 1. 
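# Note: FAILED_REWARD is written as -0.0, which compares equal to 0.0 in Python, so it reads like a
# placeholder for a small failure penalty. The SUCCESS_* constants above appear to be carried over
# from the build-order scripts and are not referenced by the micro-management bot defined below.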
35 | 36 | 37 | class SC2MicroBot(sc2.BotAI): 38 | def __init__(self, rl_agent, kill_reward=1): 39 | super(SC2MicroBot, self).__init__() 40 | self.agent = rl_agent 41 | self.kill_reward = kill_reward 42 | self.action_buffer = [] 43 | self.prev_state = None 44 | self.last_known_enemy_units = [] 45 | self.itercount = 0 46 | self.last_reward = 0 47 | self.my_tags = None 48 | self.agent_list = [] 49 | self.dead_agent_list = [] 50 | self.dead_index_mover = 0 51 | self.dead_enemies = 0 52 | 53 | async def on_step(self, iteration): 54 | 55 | if iteration == 0: 56 | self.my_tags = [unit.tag for unit in self.units] 57 | for unit in self.units: 58 | self.agent_list.append(self.agent.duplicate()) 59 | else: 60 | self.last_reward = 0 61 | for unit in self.state.dead_units: 62 | if unit in self.my_tags: 63 | self.last_reward -= 1 64 | self.dead_agent_list.append(self.agent_list[self.my_tags.index(unit)]) 65 | del self.agent_list[self.my_tags.index(unit)] 66 | del self.my_tags[self.my_tags.index(unit)] 67 | self.dead_agent_list[-1].save_reward(self.last_reward) 68 | else: 69 | self.last_reward += self.kill_reward 70 | self.dead_enemies += 1 71 | # if len(self.state.dead_units) > 0: 72 | for agent in self.agent_list: 73 | agent.save_reward(self.last_reward) 74 | for unit in self.units: 75 | if unit.tag not in self.my_tags: 76 | self.my_tags.append(unit.tag) 77 | self.agent_list.append(self.agent.duplicate()) 78 | # if iteration % 20 != 0: 79 | # return 80 | all_unit_data = [] 81 | for unit in self.units: 82 | all_unit_data.append(sc_helpers.get_unit_data(unit)) 83 | while len(all_unit_data) < 3: 84 | all_unit_data.append([-1, -1, -1, -1]) 85 | for unit, agent in zip(self.units, self.agent_list): 86 | nearest_allies = sc_helpers.get_nearest_enemies(unit, self.units) 87 | unit_data = sc_helpers.get_unit_data(unit) 88 | nearest_enemies = sc_helpers.get_nearest_enemies(unit, self.known_enemy_units) 89 | unit_data = np.array(unit_data).reshape(-1) 90 | enemy_data = [] 91 | allied_data = [] 92 | for enemy in nearest_enemies: 93 | enemy_data.extend(sc_helpers.get_enemy_unit_data(enemy)) 94 | for ally in nearest_allies[1:3]: 95 | allied_data.extend(sc_helpers.get_unit_data(ally)) 96 | enemy_data = np.array(enemy_data).reshape(-1) 97 | allied_data = np.array(allied_data).reshape(-1) 98 | state_in = np.concatenate((unit_data, allied_data, enemy_data)) 99 | action = agent.get_action(state_in) 100 | await self.execute_unit_action(unit, action, nearest_enemies) 101 | try: 102 | await self.do_actions(self.action_buffer) 103 | except sc2.protocol.ProtocolError: 104 | print("Not in game?") 105 | self.action_buffer = [] 106 | return 107 | self.action_buffer = [] 108 | 109 | async def execute_unit_action(self, unit_in, action_in, nearest_enemies): 110 | if action_in < 4: 111 | await self.move_unit(unit_in, action_in) 112 | elif action_in < 9: 113 | await self.attack_nearest(unit_in, action_in, nearest_enemies) 114 | else: 115 | pass 116 | 117 | async def move_unit(self, unit_to_move, direction): 118 | current_pos = unit_to_move.position 119 | target_destination = current_pos 120 | if direction == 0: 121 | target_destination = [current_pos.x, current_pos.y + 5] 122 | elif direction == 1: 123 | target_destination = [current_pos.x + 5, current_pos.y] 124 | elif direction == 2: 125 | target_destination = [current_pos.x, current_pos.y - 5] 126 | elif direction == 3: 127 | target_destination = [current_pos.x - 5, current_pos.y] 128 | self.action_buffer.append(unit_to_move.move(Point2(Pointlike(target_destination)))) 129 | 130 | 
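    # Direction convention used by execute_unit_action/move_unit above (inferred from the offsets):
    # 0 = +y, 1 = +x, 2 = -y, 3 = -x, each a fixed 5-unit step from the unit's current position.
    # A minimal equivalent sketch (an assumption, not the author's code; behaviour matches the if/elif chain above):
    #   offsets = {0: (0, 5), 1: (5, 0), 2: (0, -5), 3: (-5, 0)}
    #   dx, dy = offsets[direction]
    #   target_destination = [current_pos.x + dx, current_pos.y + dy]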
async def attack_nearest(self, unit_to_attack, action_in, nearest_enemies_list): 131 | if len(nearest_enemies_list) > action_in-4: 132 | target = nearest_enemies_list[action_in-4] 133 | if target is None: 134 | return -1 135 | self.action_buffer.append(unit_to_attack.attack(target)) 136 | else: 137 | return -1 138 | 139 | def finish_episode(self, game_result): 140 | print("Game over!") 141 | if game_result == sc2.Result.Defeat: 142 | for index in range(len(self.agent_list), 0, -1): 143 | self.dead_agent_list.append(self.agent_list[index-1]) 144 | self.dead_agent_list[-1].save_reward(-1) 145 | del self.agent_list[:] 146 | elif game_result == sc2.Result.Tie: 147 | reward = 0 148 | elif game_result == sc2.Result.Victory: 149 | reward = 0 # - min(self.itercount/500.0, 900) + self.units.amount 150 | else: 151 | # ??? 152 | return -13 153 | if len(self.agent_list) > 0: 154 | reward_sum = sum(self.agent_list[0].replay_buffer.rewards_list) 155 | else: 156 | reward_sum = sum(self.dead_agent_list[-1].replay_buffer.rewards_list) 157 | 158 | for agent_index in range(len(self.agent_list)): 159 | rewards_list, advantage_list, deeper_advantage_list = discount_reward( 160 | self.agent_list[agent_index].replay_buffer.rewards_list, 161 | self.agent_list[agent_index].replay_buffer.value_list, 162 | self.agent_list[agent_index].replay_buffer.deeper_value_list) 163 | self.agent_list[agent_index].replay_buffer.rewards_list = rewards_list 164 | self.agent_list[agent_index].replay_buffer.advantage_list = advantage_list 165 | self.agent_list[agent_index].replay_buffer.deeper_advantage_list = deeper_advantage_list 166 | for dead_agent_index in range(len(self.dead_agent_list)): 167 | rewards_list, advantage_list, deeper_advantage_list = discount_reward( 168 | self.dead_agent_list[dead_agent_index].replay_buffer.rewards_list, 169 | self.dead_agent_list[dead_agent_index].replay_buffer.value_list, 170 | self.dead_agent_list[dead_agent_index].replay_buffer.deeper_value_list) 171 | self.dead_agent_list[dead_agent_index].replay_buffer.rewards_list = rewards_list 172 | self.dead_agent_list[dead_agent_index].replay_buffer.advantage_list = advantage_list 173 | self.dead_agent_list[dead_agent_index].replay_buffer.deeper_advantage_list = deeper_advantage_list 174 | return self.dead_enemies*self.kill_reward - len(self.dead_agent_list) 175 | 176 | 177 | def run_episode(q, main_agent, game_mode): 178 | result = None 179 | agent_in = main_agent 180 | kill_reward = 1 181 | if game_mode == 'DefeatRoaches': 182 | kill_reward = 10 183 | elif game_mode == 'DefeatZerglingsAndBanelings': 184 | kill_reward = 5 185 | bot = SC2MicroBot(rl_agent=agent_in, kill_reward=kill_reward) 186 | 187 | try: 188 | result = sc2.run_game(sc2.maps.get(game_mode), 189 | [Bot(Race.Terran, bot)], 190 | realtime=False) 191 | except KeyboardInterrupt: 192 | result = [-1, -1] 193 | except Exception as e: 194 | print(str(e)) 195 | print("No worries", e, " carry on please") 196 | if type(result) == list and len(result) > 1: 197 | result = result[0] 198 | reward_sum = bot.finish_episode(result) 199 | for agent in bot.agent_list+bot.dead_agent_list: 200 | agent_in.replay_buffer.extend(agent.replay_buffer.__getstate__()) 201 | if q is not None: 202 | try: 203 | q.put([reward_sum, agent_in.replay_buffer.__getstate__()]) 204 | except RuntimeError as e: 205 | print(e) 206 | return [reward_sum, agent_in.replay_buffer.__getstate__()] 207 | return [reward_sum, agent_in.replay_buffer.__getstate__()] 208 | 209 | 210 | def main(episodes, agent, num_processes, game_mode): 211 
| running_reward_array = [] 212 | # lowered = False 213 | for episode in range(1, episodes+1): 214 | successful_runs = 0 215 | master_reward, reward, running_reward = 0, 0, 0 216 | processes = [] 217 | try: 218 | returned_object = run_episode(None, main_agent=agent, game_mode=game_mode) 219 | master_reward += returned_object[0] 220 | running_reward_array.append(returned_object[0]) 221 | # agent.replay_buffer.extend(returned_object[1]) 222 | successful_runs += 1 223 | except MemoryError as e: 224 | print(e) 225 | continue 226 | reward = master_reward / float(successful_runs) 227 | agent.end_episode(reward, num_processes) 228 | running_reward = sum(running_reward_array[-100:]) / float(min(100.0, len(running_reward_array))) 229 | if episode % 50 == 0: 230 | print(f'Episode {episode} Last Reward: {reward} Average Reward: {running_reward}') 231 | print(f"Running {num_processes} concurrent simulations per episode") 232 | if episode % 300 == 0: 233 | agent.save('../models/' + str(episode) + 'th') 234 | agent.lower_lr() 235 | return running_reward_array 236 | 237 | 238 | if __name__ == '__main__': 239 | parser = argparse.ArgumentParser() 240 | parser.add_argument("-a", "--agent_type", help="architecture of agent to run", type=str, default='prolo') 241 | parser.add_argument("-env", "--env_type", help="FindAndDefeatZerglings, DefeatRoaches, DefeatZerglingsAndBanelings", 242 | type=str, default='FindAndDefeatZerglings') 243 | parser.add_argument("-e", "--episodes", help="how many episodes", type=int, default=1000) 244 | parser.add_argument("-p", "--processes", help="how many processes?", type=int, default=1) 245 | parser.add_argument("-gpu", help="run on GPU?", action='store_true') 246 | parser.add_argument("-vec", help="Vectorized ProLoNet?", action='store_true') 247 | parser.add_argument("-adv", help="Adversarial ProLoNet?", action='store_true') 248 | parser.add_argument("-rand", help="Random ProLoNet?", action='store_true') 249 | parser.add_argument("-deep", help="Deepening?", action='store_true') 250 | parser.add_argument("-s", "--sl_init", help="sl to rl for fc net?", action='store_true') 251 | 252 | args = parser.parse_args() 253 | AGENT_TYPE = args.agent_type # 'shallow_prolo', 'prolo', 'random', 'fc', 'lstm' 254 | ADVERSARIAL = args.adv # Adversarial prolo, applies for AGENT_TYPE=='shallow_prolo' 255 | SL_INIT = args.sl_init # SL->RL fc, applies only for AGENT_TYPE=='fc' 256 | NUM_EPS = args.episodes # num episodes Default 1000 257 | NUM_PROCS = args.processes # num concurrent processes Default 1 258 | USE_GPU = args.gpu # Applies for 'prolo' only. use gpu? Default false 259 | VECTORIZED = args.vec # Applies for 'prolo' vectorized or no? Default false 260 | RANDOM = args.rand # Applies for 'prolo' random init or no? Default false 261 | DEEPEN = args.deep # Applies for 'prolo' deepen or no? 
Default false 262 | ENV_TYPE = args.env_type 263 | torch.set_num_threads(NUM_PROCS) 264 | dim_in = 37 265 | dim_out = 10 266 | bot_name = AGENT_TYPE + ENV_TYPE 267 | # mp.set_start_method('spawn') 268 | mp.set_sharing_strategy('file_system') 269 | for _ in range(5): 270 | 271 | if AGENT_TYPE == 'prolo': 272 | policy_agent = DeepProLoNet(distribution='one_hot', 273 | bot_name=bot_name, 274 | input_dim=dim_in, 275 | output_dim=dim_out, 276 | use_gpu=USE_GPU, 277 | vectorized=VECTORIZED, 278 | randomized=RANDOM, 279 | adversarial=ADVERSARIAL, 280 | deepen=DEEPEN) 281 | 282 | elif AGENT_TYPE == 'fc': 283 | policy_agent = FCNet(input_dim=dim_in, 284 | bot_name=bot_name, 285 | output_dim=dim_out, 286 | sl_init=SL_INIT, 287 | num_hidden=1) 288 | elif AGENT_TYPE == 'lstm': 289 | policy_agent = LSTMNet(input_dim=dim_in, 290 | bot_name=bot_name, 291 | output_dim=dim_out) 292 | elif AGENT_TYPE == 'random': 293 | policy_agent = RandomHeuristic(bot_name=bot_name, 294 | action_dim=dim_out) 295 | elif AGENT_TYPE == 'heuristic': 296 | policy_agent = StarCraftMicroHeuristic(bot_name=bot_name) 297 | elif AGENT_TYPE == 'djinn': 298 | policy_agent = DJINNAgent(bot_name=bot_name, 299 | input_dim=dim_in, 300 | output_dim=dim_out) 301 | else: 302 | raise Exception('No valid network selected') 303 | main(episodes=NUM_EPS, agent=policy_agent, num_processes=NUM_PROCS, game_mode=ENV_TYPE) 304 | -------------------------------------------------------------------------------- /runfiles/sc_build_hellions_helpers.py: -------------------------------------------------------------------------------- 1 | from sc2.constants import * 2 | from sc2.ids.unit_typeid import * 3 | import numpy as np 4 | from sc2 import Race 5 | from enum import Enum 6 | 7 | class EXPANDED_TYPES(Enum): 8 | ARMY_COUNT = 0 9 | FOOD_ARMY = 1 10 | FOOD_CAP = 2 11 | FOOD_USED = 3 12 | IDLE_WORKER_COUNT = 4 13 | LARVA_COUNT = 5 14 | MINERALS = 6 15 | VESPENE = 7 16 | WARP_GATE_COUNT = 8 17 | COMMAND_CENTER = 9 18 | SUPPLY_DEPOT = 10 19 | REFINERY = 11 20 | BARRACKS = 12 21 | ENGINEERING_BAY = 13 22 | ARMORY = 14 23 | FACTORY = 15 24 | STARPORT = 16 25 | SCV = 17 26 | MARINE = 18 27 | PENDING_COMMAND_CENTER = 19 28 | PENDING_SUPPLY_DEPOT = 20 29 | PENDING_REFINERY = 21 30 | PENDING_BARRACKS = 22 31 | PENDING_ENGINEERING_BAY = 23 32 | PENDING_ARMORY = 24 33 | PENDING_FACTORY = 25 34 | PENDING_STARPORT = 26 35 | PENDING_SCV = 27 36 | PENDING_MARINE = 28 37 | PENDING_HELLION = 29 38 | HELLION = 30 39 | LAST_ACTION = 31 40 | 41 | MY_POSSIBLES = [COMMANDCENTER, 42 | SUPPLYDEPOT, 43 | REFINERY, 44 | BARRACKS, 45 | ENGINEERINGBAY, 46 | ARMORY, 47 | FACTORY, 48 | STARPORT, 49 | SCV, 50 | MARINE, 51 | HELLION] 52 | 53 | def my_units_to_str(unit_idx): 54 | return str(MY_POSSIBLES[unit_idx]) 55 | 56 | def my_units_to_type_count(unit_array_in): 57 | """ 58 | Take in current units owned by player and return a 36-dim list of how many of each type there are 59 | :param unit_array_in: self.units from a python-sc2 bot 60 | :return: 1x36 where each element is the count of units of that type 61 | """ 62 | type_counts = np.zeros(len(MY_POSSIBLES)) 63 | for unit in unit_array_in: 64 | # print(type_counts) 65 | # print(unit) 66 | # print(unit.type_id) 67 | # print(MY_POSSIBLES) 68 | # print(MY_POSSIBLES.index(unit.type_id)) 69 | type_counts[MY_POSSIBLES.index(unit.type_id)] += 1 70 | return type_counts 71 | 72 | 73 | def get_player_state(state_in): 74 | 75 | sorted_observed_player_state = [ 76 | state_in.observation.player_common.army_count, 77 | 
state_in.observation.player_common.food_army, 78 | state_in.observation.player_common.food_cap, 79 | state_in.observation.player_common.food_used, 80 | state_in.observation.player_common.idle_worker_count, 81 | state_in.observation.player_common.larva_count, 82 | state_in.observation.player_common.minerals, 83 | state_in.observation.player_common.vespene, 84 | state_in.observation.player_common.warp_gate_count 85 | ] 86 | return sorted_observed_player_state 87 | 88 | 89 | def get_human_readable_mapping(): 90 | idx_to_name = {} 91 | name_to_idx = {} 92 | 93 | idx = 0 94 | 95 | # purpose is to make mappings between human-readable names and indices in the statespace: 96 | # self.prev_state = np.concatenate((current_state, 97 | # my_unit_type_arr, 98 | # enemy_unit_type_arr, (not using) 99 | # pending, 100 | # last_act)) 101 | 102 | # current_state 103 | current_state_names = [ 104 | 'army_count', 105 | 'food_army', 106 | 'food_cap', 107 | 'food_used', 108 | 'idle_worker_count', 109 | 'larva_count', 110 | 'minerals', 111 | 'vespene', 112 | 'warp_gate_count' 113 | ] 114 | 115 | my_unit_type_names = [] 116 | for e in MY_POSSIBLES: 117 | e = str(e) 118 | e = e.replace('UnitTypeId.', '') 119 | my_unit_type_names.append(e) 120 | 121 | 122 | 123 | for e in current_state_names + my_unit_type_names: 124 | e = str(e) 125 | e = e.replace('UnitTypeId.', '') 126 | # print(e) 127 | idx_to_name[idx] = e 128 | name_to_idx[e] = idx 129 | idx += 1 130 | 131 | # pending 132 | for e in my_unit_type_names: 133 | e = "PENDING_" + e 134 | idx_to_name[idx] = e 135 | name_to_idx[e] = idx 136 | idx += 1 137 | 138 | # last act 139 | e = 'last_act' 140 | idx_to_name[idx] = e 141 | name_to_idx[e] = idx 142 | idx += 1 143 | 144 | return idx_to_name, name_to_idx 145 | 146 | def get_human_readable_action_mapping(): 147 | idx_to_action = {} 148 | idx = 0 149 | for e in MY_POSSIBLES: 150 | e = str(e) 151 | e = e.replace('UnitTypeId.', '') 152 | idx_to_action[idx] = e 153 | idx += 1 154 | 155 | idx_to_action[idx] = 'back_to_mining' 156 | idx += 1 157 | 158 | return idx_to_action 159 | 160 | def get_unit_data(unit_in): 161 | if unit_in is None: 162 | return [-1, -1, -1, -1] 163 | extracted_data = [ 164 | float(unit_in.position.x), 165 | float(unit_in.position.y), 166 | float(unit_in.health), 167 | float(unit_in.weapon_cooldown), 168 | ] 169 | return extracted_data 170 | 171 | 172 | def get_enemy_unit_data(unit_in): 173 | data = get_unit_data(unit_in) 174 | if unit_in is None: 175 | data.append(-1) 176 | else: 177 | data.append(float(unit_in.type_id == UnitTypeId.BANELING)) 178 | return data 179 | 180 | 181 | def get_nearest_enemies(unit_in, enemy_list): 182 | my_pos = [unit_in.position.x, unit_in.position.y] 183 | distances = [] 184 | enemies = [] 185 | for enemy in enemy_list: 186 | enemy_pos = [enemy.position.x, enemy.position.y] 187 | distances.append(dist(my_pos, enemy_pos)) 188 | enemies.append(enemy) 189 | sorted_results = [x for _, x in sorted(zip(distances,enemies), key=lambda pair: pair[0])] 190 | sorted_results.extend([None, None, None, None, None]) 191 | return sorted_results[:5] 192 | 193 | 194 | def dist(pos1, pos2): 195 | return np.sqrt((pos1[0]-pos2[0])**2 + (pos1[1]-pos2[1])**2) 196 | -------------------------------------------------------------------------------- /runfiles/sc_helpers.py: -------------------------------------------------------------------------------- 1 | from sc2.constants import * 2 | from sc2.ids.unit_typeid import * 3 | import numpy as np 4 | from sc2 import Race 5 | MY_POSSIBLES = [PROBE, 
ZEALOT, STALKER, SENTRY, ADEPT, HIGHTEMPLAR, DARKTEMPLAR, OBSERVER, WARPPRISM, 6 | IMMORTAL, COLOSSUS, DISRUPTOR, PHOENIX, VOIDRAY, ORACLE, TEMPEST, CARRIER, INTERCEPTOR, 7 | MOTHERSHIP, NEXUS, PYLON, ASSIMILATOR, GATEWAY, WARPGATE, FORGE, CYBERNETICSCORE, PHOTONCANNON, 8 | SHIELDBATTERY, ROBOTICSFACILITY, STARGATE, TWILIGHTCOUNCIL, ROBOTICSBAY, FLEETBEACON, 9 | TEMPLARARCHIVE, DARKSHRINE, ORACLESTASISTRAP] 10 | 11 | ENEMY_POSSIBLES = [PROBE, ZEALOT, STALKER, SENTRY, ADEPT, HIGHTEMPLAR, DARKTEMPLAR, OBSERVER, ARCHON, WARPPRISM, 12 | IMMORTAL, COLOSSUS, DISRUPTOR, PHOENIX, VOIDRAY, ORACLE, TEMPEST, CARRIER, INTERCEPTOR, 13 | MOTHERSHIP, NEXUS, PYLON, ASSIMILATOR, GATEWAY, WARPGATE, FORGE, CYBERNETICSCORE, PHOTONCANNON, 14 | SHIELDBATTERY, ROBOTICSFACILITY, STARGATE, TWILIGHTCOUNCIL, ROBOTICSBAY, FLEETBEACON, 15 | TEMPLARARCHIVE, DARKSHRINE, ORACLESTASISTRAP, SCV, MULE, MARINE, REAPER, MARAUDER, GHOST, HELLION, WIDOWMINE, 16 | CYCLONE, SIEGETANKSIEGED, THOR, VIKINGFIGHTER, MEDIVAC, LIBERATOR, RAVEN, BANSHEE, BATTLECRUISER, COMMANDCENTER, 17 | SUPPLYDEPOT, REFINERY, ENGINEERINGBAY, BUNKER, MISSILETURRET, SENSORTOWER, GHOSTACADEMY, FACTORY, BARRACKS, STARPORT, 18 | FUSIONCORE, TECHLAB, REACTOR, AUTOTURRET, ORBITALCOMMAND, PLANETARYFORTRESS, HELLIONTANK, LARVA, DRONE, OVERLORD, 19 | QUEEN, ZERGLING, BANELING, ROACH, RAVAGER, OVERSEER, CHANGELING, HYDRALISK, LURKER, MUTALISK, CORRUPTOR, 20 | SWARMHOSTMP, LOCUSTMP, INFESTOR, INFESTEDTERRAN, VIPER, ULTRALISK, BROODLORD, BROODLING, HATCHERY, EXTRACTOR, 21 | SPAWNINGPOOL, SPINECRAWLER, SPORECRAWLER, EVOLUTIONCHAMBER, ROACHWARREN, BANELINGNEST, HYDRALISKDEN, LURKERDENMP, 22 | SPIRE, NYDUSNETWORK, INFESTATIONPIT, ULTRALISKCAVERN, CREEPTUMOR, LAIR, HIVE, GREATERSPIRE] 23 | ENEMY_MAPPINGS = { 24 | WIDOWMINEBURROWED: WIDOWMINE, # widow mine / burrowed widow mine 25 | SIEGETANK: SIEGETANKSIEGED, # siege tank siege / tank 26 | VIKINGASSAULT: VIKINGFIGHTER, # viking assault / viking fighter 27 | COMMANDCENTERFLYING: COMMANDCENTER, # flying command center / command center 28 | ORBITALCOMMANDFLYING: ORBITALCOMMAND, # flying orbital command / command center 29 | SUPPLYDEPOTLOWERED: SUPPLYDEPOT, # supply depot / lowered 30 | BARRACKSFLYING: BARRACKS, # barracks / flying barracks 31 | FACTORYFLYING: FACTORY, # factory / factory flying 32 | STARPORTFLYING: STARPORT, # starport / starport flying 33 | BARRACKSTECHLAB: TECHLAB, # barracks tech lab / tech lab 34 | FACTORYTECHLAB: TECHLAB, # factory tech lab / tech lab 35 | STARPORTTECHLAB: TECHLAB, # starport tech lab / tech lab 36 | BARRACKSREACTOR: REACTOR, # barracks reactor / reactor 37 | FACTORYREACTOR: REACTOR, # factory reactor / reactor 38 | STARPORTREACTOR: REACTOR, # starport reactor / reactor 39 | DRONEBURROWED: DRONE, 40 | BANELINGBURROWED: BANELING, 41 | HYDRALISKBURROWED: HYDRALISK, 42 | ROACHBURROWED: ROACH, 43 | ZERGLINGBURROWED: ZERGLING, 44 | QUEENBURROWED: QUEEN, 45 | RAVAGERBURROWED: RAVAGER, 46 | CHANGELINGZEALOT: CHANGELING, 47 | CHANGELINGMARINESHIELD: CHANGELING, 48 | CHANGELINGMARINE: CHANGELING, 49 | CHANGELINGZERGLINGWINGS: CHANGELING, 50 | CHANGELINGZERGLING: CHANGELING, 51 | LURKERBURROWED: LURKER, 52 | SWARMHOSTBURROWEDMP: SWARMHOSTMP, 53 | LOCUSTMPFLYING: LOCUSTMP, 54 | INFESTORTERRANBURROWED: INFESTOR, 55 | INFESTORBURROWED: INFESTOR, 56 | INFESTORTERRAN: INFESTOR, 57 | ULTRALISKBURROWED: ULTRALISK, 58 | SPINECRAWLERUPROOTED: SPINECRAWLER, 59 | SPORECRAWLERUPROOTED: SPORECRAWLER, 60 | LURKERDEN: LURKERDENMP, 61 | } 62 | 63 | index_to_upgrade = { 64 | 34: "ground_attacks", 65 | 35: 
"air_attacks", 66 | 36: "ground_armor", 67 | 37: "air_armor", 68 | 38: "shields", 69 | 39: "speed", 70 | 40: "range", 71 | 41: "spells", 72 | 42: "misc" 73 | } 74 | 75 | index_to_unit = { 76 | 0: UnitTypeId.NEXUS, 77 | 1: UnitTypeId.PYLON, 78 | 2: UnitTypeId.ASSIMILATOR, 79 | 3: UnitTypeId.GATEWAY, 80 | 4: UnitTypeId.WARPGATE, 81 | 5: UnitTypeId.FORGE, 82 | 6: UnitTypeId.CYBERNETICSCORE, 83 | 7: UnitTypeId.PHOTONCANNON, 84 | 8: UnitTypeId.SHIELDBATTERY, 85 | 9: UnitTypeId.ROBOTICSFACILITY, 86 | 10: UnitTypeId.STARGATE, 87 | 11: UnitTypeId.TWILIGHTCOUNCIL, 88 | 12: UnitTypeId.ROBOTICSBAY, 89 | 13: UnitTypeId.FLEETBEACON, 90 | 14: UnitTypeId.TEMPLARARCHIVE, 91 | 15: UnitTypeId.DARKSHRINE 92 | } 93 | action_to_unit = { 94 | 16: UnitTypeId.PROBE, 95 | 17: UnitTypeId.ZEALOT, 96 | 18: UnitTypeId.STALKER, 97 | 19: UnitTypeId.SENTRY, 98 | 20: UnitTypeId.ADEPT, 99 | 21: UnitTypeId.HIGHTEMPLAR, 100 | 22: UnitTypeId.DARKTEMPLAR, 101 | 23: UnitTypeId.OBSERVER, 102 | 24: UnitTypeId.WARPPRISM, 103 | 25: UnitTypeId.IMMORTAL, 104 | 26: UnitTypeId.COLOSSUS, 105 | 27: UnitTypeId.DISRUPTOR, 106 | 28: UnitTypeId.PHOENIX, 107 | 29: UnitTypeId.VOIDRAY, 108 | 30: UnitTypeId.ORACLE, 109 | 31: UnitTypeId.TEMPEST, 110 | 32: UnitTypeId.CARRIER, 111 | 33: UnitTypeId.MOTHERSHIP, 112 | # 34: UnitTypeId.INTERCEPTOR, # TRAIN BY DEFAULT, DONT NEED TO TRAIN 113 | # 35: UnitTypeId.ARCHON, # CURRENT IMPOSSIBLE WITH THE API 114 | } 115 | 116 | def get_human_readable_mapping(): 117 | idx_to_name = {} 118 | name_to_idx = {} 119 | 120 | idx = 0 121 | 122 | # purpose is to make mappings between human-readable names and indices in the statespace: 123 | # self.prev_state = np.concatenate((current_state, 124 | # my_unit_type_arr, 125 | # enemy_unit_type_arr, 126 | # pending, 127 | # last_act)) 128 | 129 | # current_state 130 | current_state_names = [ 131 | 'army_count', 132 | 'food_army', 133 | 'food_cap', 134 | 'food_used', 135 | 'idle_worker_count', 136 | 'larva_count', 137 | 'minerals', 138 | 'vespene', 139 | 'warp_gate_count' 140 | ] 141 | 142 | my_unit_type_names = [] 143 | for e in MY_POSSIBLES: 144 | e = str(e) 145 | e = e.replace('UnitTypeId.', '') 146 | my_unit_type_names.append(e) 147 | 148 | enemy_unit_type_names = [] 149 | for e in ENEMY_POSSIBLES: 150 | e = str(e) 151 | e = e.replace('UnitTypeId.', '') 152 | enemy_unit_type_names.append(e) 153 | 154 | for e in current_state_names + my_unit_type_names: 155 | e = str(e) 156 | e = e.replace('UnitTypeId.', '') 157 | # print(e) 158 | idx_to_name[idx] = e 159 | name_to_idx[e] = idx 160 | idx += 1 161 | 162 | # for i in index_to_upgrade.keys(): 163 | # idx_to_name[i] = index_to_upgrade[i] 164 | # name_to_idx[index_to_upgrade[i]] = i 165 | # idx += 1 166 | 167 | for e in enemy_unit_type_names: 168 | e = str(e) 169 | e = e.replace('UnitTypeId.', '') 170 | # print(e) 171 | idx_to_name[idx] = e 172 | name_to_idx[e] = idx 173 | idx += 1 174 | 175 | # nop 176 | # idx_to_name[idx] = 'nop' 177 | # name_to_idx['nop'] = idx 178 | # idx += 1 179 | 180 | # pending 181 | for e in my_unit_type_names: 182 | e = "PENDING_" + e 183 | idx_to_name[idx] = e 184 | name_to_idx[e] = idx 185 | idx += 1 186 | 187 | # last act 188 | e = 'last_act' 189 | idx_to_name[idx] = e 190 | name_to_idx[e] = idx 191 | idx += 1 192 | 193 | # print(idx_to_name) 194 | # quit() 195 | 196 | return idx_to_name, name_to_idx 197 | 198 | def get_human_readable_action_mapping(): 199 | idx_to_action = {} 200 | idx = 0 201 | for e in index_to_unit: 202 | e = index_to_unit[e] 203 | e = str(e) 204 | e = e.replace('UnitTypeId.', 
'') 205 | idx_to_action[idx] = e 206 | idx += 1 207 | for e in action_to_unit: 208 | e = action_to_unit[e] 209 | e = str(e) 210 | e = e.replace('UnitTypeId.', '') 211 | idx_to_action[idx] = e 212 | idx += 1 213 | for i in index_to_upgrade.keys(): 214 | idx_to_action[i] = index_to_upgrade[i] 215 | idx += 1 216 | idx_to_action[idx] = 'nop' 217 | idx += 1 218 | return idx_to_action 219 | 220 | def my_units_to_type_count(unit_array_in): 221 | """ 222 | Take in current units owned by player and return a 36-dim list of how many of each type there are 223 | :param unit_array_in: self.units from a python-sc2 bot 224 | :return: 1x36 where each element is the count of units of that type 225 | """ 226 | type_counts = np.zeros(len(MY_POSSIBLES)) 227 | for unit in unit_array_in: 228 | type_counts[MY_POSSIBLES.index(unit.type_id)] += 1 229 | return type_counts 230 | 231 | 232 | def enemy_units_to_type_count(enemy_array_in): 233 | """ 234 | Take in enemy units and map them to type counts, as in my army. But here I consider all 3 races 235 | :param enemy_array_in: self.known_enemy_units (or my all_enemy_units list) from python-sc2 bot 236 | :return: 1x111 where each element is the count of units of that type 237 | """ 238 | type_counts = np.zeros(len(ENEMY_POSSIBLES)) 239 | for unit in enemy_array_in: 240 | if unit.type_id in ENEMY_POSSIBLES: 241 | type_counts[ENEMY_POSSIBLES.index(unit.type_id)] += 1 242 | elif unit.type_id in ENEMY_MAPPINGS.keys(): 243 | type_counts[ENEMY_POSSIBLES.index(ENEMY_MAPPINGS[unit.type_id])] += 1 244 | else: 245 | continue 246 | return type_counts 247 | 248 | 249 | def get_player_state(state_in): 250 | 251 | sorted_observed_player_state = [ 252 | state_in.observation.player_common.army_count, 253 | state_in.observation.player_common.food_army, 254 | state_in.observation.player_common.food_cap, 255 | state_in.observation.player_common.food_used, 256 | state_in.observation.player_common.idle_worker_count, 257 | state_in.observation.player_common.larva_count, 258 | state_in.observation.player_common.minerals, 259 | state_in.observation.player_common.vespene, 260 | state_in.observation.player_common.warp_gate_count 261 | ] 262 | return sorted_observed_player_state 263 | 264 | 265 | def get_unit_data(unit_in): 266 | if unit_in is None: 267 | return [-1, -1, -1, -1] 268 | extracted_data = [ 269 | float(unit_in.position.x), 270 | float(unit_in.position.y), 271 | float(unit_in.health), 272 | float(unit_in.weapon_cooldown), 273 | ] 274 | return extracted_data 275 | 276 | 277 | def get_enemy_unit_data(unit_in): 278 | data = get_unit_data(unit_in) 279 | if unit_in is None: 280 | data.append(-1) 281 | else: 282 | data.append(float(unit_in.type_id == UnitTypeId.BANELING)) 283 | return data 284 | 285 | 286 | def get_nearest_enemies(unit_in, enemy_list): 287 | my_pos = [unit_in.position.x, unit_in.position.y] 288 | distances = [] 289 | enemies = [] 290 | for enemy in enemy_list: 291 | enemy_pos = [enemy.position.x, enemy.position.y] 292 | distances.append(dist(my_pos, enemy_pos)) 293 | enemies.append(enemy) 294 | sorted_results = [x for _, x in sorted(zip(distances,enemies), key=lambda pair: pair[0])] 295 | sorted_results.extend([None, None, None, None, None]) 296 | return sorted_results[:5] 297 | 298 | 299 | def dist(pos1, pos2): 300 | return np.sqrt((pos1[0]-pos2[0])**2 + (pos1[1]-pos2[1])**2) 301 | -------------------------------------------------------------------------------- /runfiles/sc_replays/.gitignore: 
-------------------------------------------------------------------------------- 1 | *.SC2REPLAY 2 | -------------------------------------------------------------------------------- /runfiles/sc_replays/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CORE-Robotics-Lab/ProLoNets/b1f138ed3520e623a84f2a7d5b1969fa45845700/runfiles/sc_replays/.gitkeep -------------------------------------------------------------------------------- /runfiles/visuals/figures/.gitignore: -------------------------------------------------------------------------------- 1 | *.png 2 | -------------------------------------------------------------------------------- /runfiles/visuals/visualize_build_hellions.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | 5 | sys.path.insert(0, os.path.abspath('../../')) 6 | 7 | from agents.prolonet_agent import DeepProLoNet 8 | from runfiles.sc_build_hellions_helpers import get_human_readable_mapping, get_human_readable_action_mapping 9 | import visualize_prolonet 10 | 11 | 12 | if __name__ == '__main__': 13 | idx_to_name, name_to_idx = get_human_readable_mapping() 14 | idx_to_action = get_human_readable_action_mapping() 15 | dim_in = len(idx_to_name) 16 | dim_out = len(idx_to_action) 17 | bot_name = 'prolo' + '_hellions' 18 | 19 | policy_agent = DeepProLoNet(distribution='one_hot', 20 | bot_name=bot_name, 21 | input_dim=dim_in, 22 | output_dim=dim_out, 23 | use_gpu=False, 24 | vectorized=False, 25 | randomized=False, 26 | adversarial=False, 27 | deepen=False, 28 | deterministic=True) 29 | 30 | visualize_prolonet.visualize_prolonet(prolonet=policy_agent.value_network, raw_indices=True, 31 | idx_to_names=idx_to_name, 32 | idx_to_actions=idx_to_action, max_lines=4, node_size=30000, 33 | node_color='#d9d9d9', 34 | show=False, 35 | save_fig_filename='figures/1thprolo_hellions_value_network.png', 36 | save_fig_dimensions=(14, 15)) 37 | 38 | visualize_prolonet.visualize_prolonet(prolonet=policy_agent.action_network, raw_indices=True, 39 | idx_to_names=idx_to_name, 40 | idx_to_actions=idx_to_action, max_lines=4, node_size=30000, 41 | node_color='#d9d9d9', 42 | show=False, 43 | save_fig_filename='figures/1thprolo_hellions_action_network.png', 44 | save_fig_dimensions=(14, 15)) 45 | 46 | # Can load model from save 47 | load_result = policy_agent.load(fn_botname='../../models/1000thprolo_hellions_gpu') 48 | if not load_result: 49 | print("model file not found") 50 | quit() 51 | 52 | # Can then go on to another visualization with this loaded model... 
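    # For example (an illustrative sketch only: the output filename and figure size here are
    # assumptions, while the call itself mirrors the pre-load calls above):
    # visualize_prolonet.visualize_prolonet(prolonet=policy_agent.action_network, raw_indices=True,
    #                                       idx_to_names=idx_to_name,
    #                                       idx_to_actions=idx_to_action, max_lines=4, node_size=30000,
    #                                       node_color='#d9d9d9',
    #                                       show=False,
    #                                       save_fig_filename='figures/1000thprolo_hellions_action_network.png',
    #                                       save_fig_dimensions=(14, 15))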
53 | 54 | -------------------------------------------------------------------------------- /runfiles/visuals/visualize_prolonet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | import networkx as nx 4 | import numpy as np 5 | # import pydot 6 | from networkx.drawing.nx_pydot import graphviz_layout 7 | 8 | def visualize_prolonet(prolonet, raw_indices=False, idx_to_names=None, 9 | dont_care_about_weights_smaller_than=0.005, 10 | idx_to_actions=None, max_lines=5, node_size=8000, node_color='white', show=True, 11 | save_fig_filename=None, save_fig_dimensions=None): 12 | 13 | g = nx.DiGraph() 14 | 15 | node_relabel = {} 16 | 17 | attached_node_indices = set() 18 | for leaf in prolonet.leaf_init_information: 19 | for left_parent in leaf[0]: 20 | attached_node_indices.add(left_parent) 21 | for right_parent in leaf[1]: 22 | attached_node_indices.add(right_parent) 23 | 24 | for leaf in attached_node_indices: 25 | deepest_node = leaf 26 | 27 | weights = prolonet.layers[deepest_node] 28 | comparator = prolonet.comparators[deepest_node] 29 | 30 | readable_weights = [] 31 | 32 | weights_descending = torch.argsort(torch.abs(weights)).detach().numpy() 33 | for idx in weights_descending[::-1]: 34 | if abs(float(weights[idx])) > dont_care_about_weights_smaller_than: 35 | new_str = '({:.2f})'.format(float(weights[idx])) + str(idx_to_names[int(idx)]) 36 | if raw_indices: 37 | new_str += '[' + str(int(idx)) + ']' 38 | readable_weights.append(new_str) 39 | 40 | comparator_string = '' 41 | num_lines = 0 42 | for i in range(len(readable_weights)): 43 | comparator_string += readable_weights[i] 44 | if num_lines > max_lines: 45 | comparator_string += '...' 46 | break 47 | num_lines += 1 48 | if i < len(readable_weights) - 1: 49 | comparator_string += ' +\n' 50 | 51 | comparator_string += '\n > ' 52 | comparator_string += '{:.2f}'.format(float(comparator)) 53 | 54 | summary_string = 'Comp [' 55 | summary_string += str(deepest_node) + ']\n' 56 | summary_string += comparator_string 57 | 58 | node_relabel[deepest_node] = summary_string 59 | 60 | g.add_node(deepest_node, color=node_color) 61 | 62 | leaf_idx = 0 63 | actual_leaves = prolonet.action_probs 64 | 65 | 66 | for leaf_left, leaf_right, leaf_actions in prolonet.leaf_init_information: 67 | leaf_actions = actual_leaves[leaf_idx].data.cpu().numpy() 68 | 69 | readable_actions = [] 70 | leaf_actions_descending = np.argsort(np.abs(leaf_actions)) 71 | for i in leaf_actions_descending[::-1]: 72 | if abs(float(leaf_actions[i])) > dont_care_about_weights_smaller_than: 73 | new_str = '({:.2f})'.format(leaf_actions[i]) + str(idx_to_actions[i]) 74 | if raw_indices: 75 | new_str += '[' + str(int(i)) + ']' 76 | readable_actions.append(new_str) 77 | 78 | actions_string = 'Action [' + str(leaf_idx) + ']\n' 79 | num_lines = 0 80 | for action in readable_actions: 81 | if num_lines > max_lines: 82 | actions_string += '...' 
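                # stop once max_lines actions have been appended; any remaining actions are elided with '...'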
83 | break 84 | num_lines += 1 85 | actions_string += str(action) + '\n' 86 | 87 | g.add_node(actions_string, color=node_color) 88 | 89 | for left_parent in leaf_left: 90 | g.add_edge(left_parent, actions_string, edge_color="green", width=4) 91 | 92 | for right_parent in leaf_right: 93 | g.add_edge(right_parent, actions_string, edge_color="red", width=2) 94 | 95 | leaf_idx += 1 96 | 97 | g = nx.relabel_nodes(g, node_relabel) 98 | 99 | edge_colors = [g.edges[edge]["edge_color"] for edge in g.edges()] 100 | 101 | pos = nx.drawing.nx_agraph.graphviz_layout(g, prog='dot', args='-Grankdir=LR') 102 | color_list = [g.nodes[node]["color"] for node in g.nodes()] 103 | edge_width_list = [g.edges[edge]["width"] for edge in g.edges()] 104 | nx.draw(g, pos, node_color=color_list, edge_color=edge_colors, width=edge_width_list, 105 | with_labels=True, node_size=node_size) 106 | figure = plt.gcf() 107 | if save_fig_filename != None: 108 | figure.set_size_inches(save_fig_dimensions) 109 | plt.savefig(save_fig_filename, dpi=200) 110 | if show: 111 | plt.show() 112 | 113 | plt.cla() -------------------------------------------------------------------------------- /runfiles/visuals/visualize_sc_runner.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | 5 | sys.path.insert(0, os.path.abspath('../../')) 6 | 7 | from agents.prolonet_agent import DeepProLoNet 8 | from runfiles.sc_helpers import get_human_readable_mapping, get_human_readable_action_mapping 9 | import visualize_prolonet 10 | 11 | 12 | if __name__ == '__main__': 13 | idx_to_name, name_to_idx = get_human_readable_mapping() 14 | idx_to_action = get_human_readable_action_mapping() 15 | dim_in = len(idx_to_name) 16 | dim_out = len(idx_to_action) 17 | print(dim_in, dim_out) 18 | bot_name = 'prolo' + '_sc_runner' 19 | 20 | policy_agent = DeepProLoNet(distribution='one_hot', 21 | bot_name=bot_name, 22 | input_dim=dim_in, 23 | output_dim=dim_out, 24 | use_gpu=False, 25 | vectorized=False, 26 | randomized=False, 27 | adversarial=False, 28 | deepen=False, 29 | deterministic=True) 30 | 31 | visualize_prolonet.visualize_prolonet(prolonet=policy_agent.value_network, raw_indices=True, 32 | idx_to_names=idx_to_name, 33 | idx_to_actions=idx_to_action, max_lines=4, node_size=30000, 34 | node_color='#d9d9d9', 35 | show=False, 36 | save_fig_filename='figures/1thprolo_sc_runner_value_network.png', 37 | save_fig_dimensions=(14, 20)) 38 | 39 | visualize_prolonet.visualize_prolonet(prolonet=policy_agent.action_network, raw_indices=True, 40 | idx_to_names=idx_to_name, 41 | idx_to_actions=idx_to_action, max_lines=4, node_size=30000, 42 | node_color='#d9d9d9', 43 | show=False, 44 | save_fig_filename='figures/1thprolo_sc_runner_action_network.png', 45 | save_fig_dimensions=(14, 20)) 46 | -------------------------------------------------------------------------------- /sc2_requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.4.0 2 | aiohttp==3.7.4 3 | async-timeout==3.0.0 4 | attrs==18.1.0 5 | backcall==0.1.0 6 | bleach==3.3.0 7 | certifi==2018.8.24 8 | chardet==3.0.4 9 | cycler==0.10.0 10 | decorator==4.3.0 11 | entrypoints==0.2.3 12 | enum34==1.1.6 13 | future==0.16.0 14 | html5lib==1.0.1 15 | idna==2.7 16 | idna-ssl==1.1.0 17 | ipykernel==4.8.2 18 | ipython==6.5.0 19 | ipython-genutils==0.2.0 20 | ipywidgets==7.4.0 21 | jedi==0.12.1 22 | Jinja2==2.11.3 23 | jsonschema==2.6.0 24 | jupyter==1.0.0 25 | jupyter-client==5.2.3 26 | 
jupyter-console==5.2.0 27 | jupyter-core==4.4.0 28 | kiwisolver==1.0.1 29 | MarkupSafe==1.0 30 | matplotlib==2.2.3 31 | mistune==0.8.3 32 | mock==2.0.0 33 | mpyq==0.2.5 34 | multidict==4.3.1 35 | nbconvert==5.3.1 36 | nbformat==4.4.0 37 | notebook>=6.1.5 38 | numpy==1.15.1 39 | opencv-contrib-python==3.4.3.18 40 | pandas==0.23.4 41 | pandocfilters==1.4.2 42 | parso==0.3.1 43 | pbr==4.2.0 44 | pexpect==4.6.0 45 | pickleshare==0.7.4 46 | Pillow>=7.1.0 47 | portpicker==1.2.0 48 | prometheus-client==0.3.1 49 | prompt-toolkit==1.0.15 50 | protobuf==3.6.1 51 | ptyprocess==0.6.0 52 | pygame==1.9.4 53 | Pygments==2.7.4 54 | pyparsing==2.2.0 55 | PySC2==2.0.1 56 | python-dateutil==2.7.3 57 | pytz==2018.5 58 | pyzmq==17.1.2 59 | qtconsole==4.4.1 60 | requests==2.21.0 61 | s2clientprotocol==4.5.1.67344.0 62 | sc2==0.10.6 63 | sc2reader==1.1.0 64 | scipy==1.1.0 65 | Send2Trash==1.5.0 66 | simplegeneric==0.8.1 67 | six==1.11.0 68 | sk-video==1.1.10 69 | terminado==0.8.1 70 | testpath==0.3.1 71 | torch==1.0.0 72 | torchvision==0.2.1 73 | tornado==5.1 74 | traitlets==4.3.2 75 | urllib3==1.26.5 76 | wcwidth==0.1.7 77 | webencodings==0.5.1 78 | websocket-client==0.51.0 79 | whichcraft==0.4.1 80 | widgetsnbextension==3.4.0 81 | yarl==1.2.6 -------------------------------------------------------------------------------- /txts/README.md: -------------------------------------------------------------------------------- 1 | Reward files are saved here by default. --------------------------------------------------------------------------------