├── .github └── workflows │ └── workflow.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── agent_zoo ├── __init__.py ├── hybrid.py ├── neurips23_start_kit │ ├── __init__.py │ ├── baseline_policy.py │ └── reward_wrapper.py ├── takeru │ ├── README.md │ ├── __init__.py │ ├── policy.py │ └── reward_wrapper.py └── yaofeng │ ├── __init__.py │ ├── policy.py │ └── reward_wrapper.py ├── analysis ├── proc_eval_result.py ├── proc_task_cond_result.py └── run_task_conditioning.py ├── config.yaml ├── curriculum_generation ├── __init__.py ├── curriculum_tutorial.py ├── curriculum_with_embedding.pkl ├── custom_curriculum_with_embedding.pkl ├── manual_curriculum.py ├── task_encoder.py └── task_sampler.py ├── evaluate.py ├── neurips23_evaluation ├── export_embeddings.py ├── heldout_evaluation_task.py ├── heldout_task_with_embedding.pkl ├── sample_eval_task_with_embedding.pkl └── sample_evaluation_task.py ├── policies ├── README.md ├── baseline_10M.pt ├── elo.db ├── eval_pvp_1.json ├── eval_pvp_12196392.json ├── eval_pvp_19525770.json ├── eval_pvp_28034063.json ├── eval_pvp_31128942.json ├── eval_pvp_42720373.json ├── eval_pvp_47914166.json ├── eval_pvp_97868113.json ├── eval_pvp_99672462.json ├── eval_pvp_99910280.json ├── score_by_seed.tsv ├── score_by_task_seed.tsv ├── score_category_summary.tsv ├── score_summary.tsv ├── score_task_summary.tsv ├── takeru_100M.pt ├── takeru_200M.pt ├── takeru_25M.pt ├── takeru_50M.pt ├── yaofeng_100M.pt ├── yaofeng_200M.pt ├── yaofeng_25M.pt └── yaofeng_50M.pt ├── pyproject.toml ├── reinforcement_learning ├── __init__.py ├── clean_pufferl.py ├── environment.py └── stat_wrapper.py ├── scripts ├── evaluate_policies.sh ├── pre-git-check.sh ├── slurm_jobs.sh ├── slurm_run.sh ├── slurm_run_cpu.sh ├── train_baseline.sh ├── upload_checkpoints.sh └── upload_latest_checkpoint.sh ├── syllabus_wrapper.py ├── tests └── test_task_encoder.py ├── train.py └── train_helper.py /.github/workflows/workflow.yml: -------------------------------------------------------------------------------- 1 | name: tox 2 | on: [push, pull_request] 3 | 4 | jobs: 5 | test: 6 | runs-on: ubuntu-latest 7 | strategy: 8 | fail-fast: false 9 | matrix: 10 | py: ["3.8", "3.9", "3.10"] 11 | steps: 12 | - name: Setup python for test ${{ matrix.py }} 13 | uses: actions/setup-python@v5 14 | with: 15 | python-version: ${{ matrix.py }} 16 | - uses: actions/checkout@v3 17 | - name: Upgrade pip 18 | run: python -m pip install -U pip setuptools wheel cython 19 | - name: Install 20 | run: python -m pip install -e '.[dev]' 21 | - name: Check formatting 22 | run: ruff format . 23 | - name: Check lint 24 | run: ruff check . -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | experiments 2 | maps 3 | runs 4 | replay*/ 5 | checkpoints 6 | wandb 7 | maps 8 | pool 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | pip-wheel-metadata/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 103 | __pypackages__/ 104 | 105 | # Celery stuff 106 | celerybeat-schedule 107 | celerybeat.pid 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | venv/ 116 | # ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | 138 | # IDEs 139 | .idea/ 140 | .vscode/ 141 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | # Ruff version. 4 | rev: v0.3.2 5 | hooks: 6 | # Run the linter. 7 | - id: ruff 8 | # Run the formatter. 9 | - id: ruff-format -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Neural MMO 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![figure](https://neuralmmo.github.io/_static/banner.jpg) 2 | 3 | # ![icon](https://neuralmmo.github.io/_build/html/_images/icon.png) Welcome to the Platform! 4 | 5 | [![PyPI version](https://badge.fury.io/py/nmmo.svg)](https://badge.fury.io/py/nmmo) 6 | [![](https://dcbadge.vercel.app/api/server/BkMmFUC?style=plastic)](https://discord.gg/BkMmFUC) 7 | [![Twitter](https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow%20%40jsuarez5341)](https://twitter.com/jsuarez5341) 8 | 9 | [Documentation](https://neuralmmo.github.io "Neural MMO Documentation") is hosted by github.io. 10 | 11 | ## Installation 12 | 13 | After cloning this repo, run: 14 | 15 | ``` 16 | pip install -e .[dev] 17 | ``` 18 | 19 | ## Training 20 | 21 | To test if the installation was successful (with the `--debug` mode), run the following command: 22 | 23 | ``` 24 | python train.py --debug --no-track 25 | ``` 26 | 27 | To log the training process, edit the wandb section in `config.yaml` and remove `--no-track` from the command line. The `config.yaml` file contains various configuration settings for the project. 28 | 29 | ### Agent zoo and your custom policy 30 | 31 | This baseline comes with four different models under the `agent_zoo` directory: `neurips23_start_kit`, `yaofeng`, `takeru`, and `hybrid`. You can use any of these models by specifying the `-a` argument. 32 | 33 | ``` 34 | python train.py -a hybrid 35 | ``` 36 | 37 | You can also create your own policy by creating a new module under the `agent_zoo` directory, which should contain `Policy`, `Recurrent`, and `RewardWrapper` classes. 38 | 39 | ### Curriculum Learning using Syllabus 40 | 41 | The training script supports automatic curriculum learning using the [Syllabus](https://github.com/RyanNavillus/Syllabus) library. To use it, add `--syllabus` to the command line. 42 | 43 | ``` 44 | python train.py --syllabus 45 | ``` 46 | 47 | ## Replay generation 48 | 49 | The `policies` directory contains a set of trained policies. For your models, create a directory and copy the checkpoint files to it. To generate a replay, run the following command: 50 | 51 | ``` 52 | python train.py -m replay -p policies 53 | ``` 54 | 55 | The replay file ends with `.replay.lzma`. You can view the replay using the [web viewer](https://kywch.github.io/nmmo-client/). 56 | 57 | ## Evaluation 58 | 59 | The evaluation script supports the pvp and pve modes. The pve mode spawns all agents using only one policy. The pvp mode spawns groups of agents, each controlled by a different policy. 60 | 61 | To evaluate models in the `policies` directory, run the following command: 62 | 63 | ``` 64 | python evaluate.py policies pvp -r 10 65 | ``` 66 | 67 | This generates 10 results json files in the same directory (by using `-r 10`), each of which contains the results from 200 episodes. Then the task completion metrics can be viewed using: 68 | 69 | ``` 70 | python analysis/proc_eval_result.py policies 71 | ``` 72 | -------------------------------------------------------------------------------- /agent_zoo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralMMO/baselines/9a4ec9ccac6f7e91d01708ceff8386ecfd52abc6/agent_zoo/__init__.py -------------------------------------------------------------------------------- /agent_zoo/hybrid.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa 2 | from .takeru import Policy 3 | from .takeru import Recurrent 4 | from .yaofeng import RewardWrapper 5 | -------------------------------------------------------------------------------- /agent_zoo/neurips23_start_kit/__init__.py: -------------------------------------------------------------------------------- 1 | from .baseline_policy import Baseline as Policy 2 | from .baseline_policy import Recurrent 3 | from .reward_wrapper import RewardWrapper 4 | -------------------------------------------------------------------------------- /agent_zoo/neurips23_start_kit/baseline_policy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | import pufferlib 5 | import pufferlib.models 6 | import pufferlib.emulation 7 | from pufferlib.emulation import unpack_batched_obs 8 | 9 | from nmmo.entity.entity import EntityState 10 | 11 | EntityId = EntityState.State.attr_name_to_col["id"] 12 | 13 | # NOTE: a workaround for the torch.complier problem. TODO: try torch 2.2 14 | # unpack_batched_obs = torch.compiler.disable(unpack_batched_obs) 15 | 16 | 17 | class Recurrent(pufferlib.models.RecurrentWrapper): 18 | def __init__(self, env, policy, input_size=256, hidden_size=256, num_layers=1): 19 | super().__init__(env, policy, input_size, hidden_size, num_layers) 20 | 21 | 22 | class Baseline(pufferlib.models.Policy): 23 | """Improved baseline policy by JimyhZhu""" 24 | 25 | def __init__(self, env, input_size=256, hidden_size=256, task_size=2048): 26 | super().__init__(env) 27 | 28 | self.unflatten_context = env.unflatten_context 29 | 30 | self.tile_encoder = TileEncoder(input_size) 31 | self.player_encoder = PlayerEncoder(input_size, hidden_size) 32 | self.item_encoder = ItemEncoder(input_size, hidden_size) 33 | self.inventory_encoder = InventoryEncoder(input_size, hidden_size) 34 | self.market_encoder = MarketEncoder(input_size, hidden_size) 35 | self.task_encoder = TaskEncoder(input_size, hidden_size, task_size) 36 | self.proj_fc = torch.nn.Linear(6 * input_size, input_size) 37 | self.action_decoder = ActionDecoder(input_size, hidden_size) 38 | self.value_head = torch.nn.Linear(hidden_size, 1) 39 | 40 | def encode_observations(self, flat_observations): 41 | env_outputs = unpack_batched_obs(flat_observations, self.unflatten_context) 42 | tile = self.tile_encoder(env_outputs["Tile"]) 43 | player_embeddings, my_agent = self.player_encoder( 44 | env_outputs["Entity"], env_outputs["AgentId"][:, 0] 45 | ) 46 | 47 | item_embeddings = self.item_encoder(env_outputs["Inventory"]) 48 | market_embeddings = self.item_encoder(env_outputs["Market"]) # no_pooling 49 | market = self.market_encoder(market_embeddings) # fc +mean pooling already applied 50 | task = self.task_encoder(env_outputs["Task"]) 51 | pooled_item_embeddings = item_embeddings.mean(dim=1) 52 | pooled_player_embeddings = player_embeddings.mean(dim=1) 53 | obs = torch.cat( 54 | [tile, my_agent, pooled_player_embeddings, pooled_item_embeddings, market, task], dim=-1 55 | ) 56 | obs = self.proj_fc(obs) 57 | 58 | # Pad the embeddings to make them the same size to the action_decoder 59 | # This is a workaround for the fact that the action_decoder expects the same number of actions including no-op 60 | embeddings = [player_embeddings, item_embeddings, market_embeddings] 61 | padded_embeddings = [] 62 | for embedding in embeddings: 63 | padding_size = 1 # The size of padding to be added 64 | padding = torch.zeros( 65 | embedding.size(0), padding_size, embedding.size(2), device=embedding.device 66 | ) 67 | padded_embedding = torch.cat([embedding, padding], dim=1) 68 | padded_embeddings.append(padded_embedding) 69 | # Replace the original embeddings with the padded versions 70 | player_embeddings, item_embeddings, market_embeddings = padded_embeddings 71 | 72 | return obs, ( 73 | player_embeddings, 74 | item_embeddings, 75 | market_embeddings, 76 | env_outputs["ActionTargets"], 77 | ) 78 | 79 | def decode_actions(self, hidden, lookup): 80 | actions = self.action_decoder(hidden, lookup) 81 | value = self.value_head(hidden) 82 | return actions, value 83 | 84 | 85 | class TileEncoder(torch.nn.Module): 86 | def __init__(self, input_size): 87 | super().__init__() 88 | self.tile_offset = torch.tensor([i * 256 for i in range(3)]) 89 | self.embedding = torch.nn.Embedding(3 * 256, 32) 90 | 91 | self.tile_conv_1 = torch.nn.Conv2d(96, 32, 3) 92 | self.tile_conv_2 = torch.nn.Conv2d(32, 8, 3) 93 | self.tile_fc = torch.nn.Linear(8 * 11 * 11, input_size) 94 | 95 | def forward(self, tile): 96 | tile[:, :, :2] -= tile[:, 112:113, :2].clone() 97 | tile[:, :, :2] += 7 98 | tile = self.embedding(tile.long().clip(0, 255) + self.tile_offset.to(tile.device)) 99 | 100 | agents, tiles, features, embed = tile.shape 101 | tile = ( 102 | tile.view(agents, tiles, features * embed) 103 | .transpose(1, 2) 104 | .view(agents, features * embed, 15, 15) 105 | ) 106 | 107 | tile = F.relu(self.tile_conv_1(tile)) 108 | tile = F.relu(self.tile_conv_2(tile)) 109 | tile = tile.contiguous().view(agents, -1) 110 | tile = F.relu(self.tile_fc(tile)) 111 | 112 | return tile 113 | 114 | 115 | class PlayerEncoder(torch.nn.Module): 116 | def __init__(self, input_size, hidden_size): 117 | super().__init__() 118 | self.entity_dim = 31 119 | self.num_classes_npc_type = 5 # only using npc_type for one hot (0,4) 120 | self.agent_fc = torch.nn.Linear(self.entity_dim + 5 - 1, hidden_size) 121 | self.my_agent_fc = torch.nn.Linear(self.entity_dim + 5 - 1, input_size) 122 | 123 | def forward(self, agents, my_id): 124 | npc_type = agents[:, :, 1] 125 | one_hot_npc_type = F.one_hot( 126 | npc_type.long(), num_classes=self.num_classes_npc_type 127 | ).float() # Subtract 1 if npc_type starts from 1 128 | one_hot_agents = torch.cat( 129 | [agents[:, :, :1], one_hot_npc_type, agents[:, :, 2:]], dim=-1 130 | ).float() 131 | 132 | agent_ids = one_hot_agents[:, :, EntityId] 133 | mask = (agent_ids == my_id.unsqueeze(1)) & (agent_ids != 0) 134 | mask = mask.int() 135 | row_indices = torch.where( 136 | mask.any(dim=1), mask.argmax(dim=1), torch.zeros_like(mask.sum(dim=1)) 137 | ) 138 | 139 | # batch, agent, attrs, embed = agent_embeddings.shape 140 | my_agent_embeddings = one_hot_agents[torch.arange(one_hot_agents.shape[0]), row_indices] 141 | agent_embeddings = self.agent_fc(one_hot_agents.cuda()) 142 | my_agent_embeddings = self.my_agent_fc(my_agent_embeddings) 143 | my_agent_embeddings = F.relu(my_agent_embeddings) 144 | 145 | return agent_embeddings, my_agent_embeddings 146 | 147 | 148 | class ItemEncoder(torch.nn.Module): 149 | def __init__(self, input_size, hidden_size): 150 | super().__init__() 151 | self.fc = torch.nn.Linear(18 + 2 + 14 - 2, hidden_size) 152 | self.discrete_idxs = [1, 14] 153 | self.discrete_offset = torch.Tensor([2, 0]) 154 | self.continuous_idxs = [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15] 155 | self.continuous_scale = torch.Tensor([1 / 100] * 12) 156 | 157 | def forward(self, items): 158 | if self.discrete_offset.device != items.device: 159 | self.discrete_offset = self.discrete_offset.to(items.device) 160 | self.continuous_scale = self.continuous_scale.to(items.device) 161 | 162 | one_hot_discrete_equipped = F.one_hot(items[:, :, 14].long(), num_classes=2).float() 163 | one_hot_discrete_type_id = F.one_hot(items[:, :, 1].long(), num_classes=18).float() 164 | one_hot_discrete = torch.concat( 165 | [one_hot_discrete_type_id, one_hot_discrete_equipped], dim=-1 166 | ) # 167 | continuous = items[:, :, self.continuous_idxs] * self.continuous_scale 168 | item_embeddings = torch.cat([one_hot_discrete, continuous], dim=-1).float() 169 | item_embeddings = self.fc(item_embeddings) 170 | return item_embeddings 171 | 172 | 173 | class InventoryEncoder(torch.nn.Module): 174 | def __init__(self, input_size, hidden_size): 175 | super().__init__() 176 | self.fc = torch.nn.Linear(12 * hidden_size, input_size) 177 | 178 | def forward(self, inventory): 179 | agents, items, hidden = inventory.shape 180 | inventory = inventory.view(agents, items * hidden) 181 | return self.fc(inventory) 182 | 183 | 184 | class MarketEncoder(torch.nn.Module): 185 | def __init__(self, input_size, hidden_size): 186 | super().__init__() 187 | self.fc = torch.nn.Linear(hidden_size, input_size) 188 | 189 | def forward(self, market): 190 | return self.fc(market).mean(-2) 191 | 192 | 193 | class TaskEncoder(torch.nn.Module): 194 | def __init__(self, input_size, hidden_size, task_size): 195 | super().__init__() 196 | self.fc = torch.nn.Linear(task_size, input_size) 197 | 198 | def forward(self, task): 199 | return self.fc(task.clone().float()) 200 | 201 | 202 | class ActionDecoder(torch.nn.Module): 203 | def __init__(self, input_size, hidden_size): 204 | super().__init__() 205 | self.layers = torch.nn.ModuleDict( 206 | { 207 | "attack_style": torch.nn.Linear(hidden_size, 3), 208 | "attack_target": torch.nn.Linear(hidden_size, hidden_size), 209 | "market_buy": torch.nn.Linear(hidden_size, hidden_size), 210 | "inventory_destroy": torch.nn.Linear(hidden_size, hidden_size), 211 | "inventory_give_item": torch.nn.Linear(hidden_size, hidden_size), 212 | "inventory_give_player": torch.nn.Linear(hidden_size, hidden_size), 213 | "gold_quantity": torch.nn.Linear(hidden_size, 99), 214 | "gold_target": torch.nn.Linear(hidden_size, hidden_size), 215 | "move": torch.nn.Linear(hidden_size, 5), 216 | "inventory_sell": torch.nn.Linear(hidden_size, hidden_size), 217 | "inventory_price": torch.nn.Linear(hidden_size, 99), 218 | "inventory_use": torch.nn.Linear(hidden_size, hidden_size), 219 | } 220 | ) 221 | 222 | def apply_layer(self, layer, embeddings, mask, hidden): 223 | hidden = layer(hidden) 224 | if hidden.dim() == 2 and embeddings is not None: 225 | hidden = torch.matmul(embeddings, hidden.unsqueeze(-1)).squeeze(-1) 226 | 227 | if mask is not None: 228 | hidden = hidden.masked_fill(mask == 0, -1e9) 229 | 230 | return hidden 231 | 232 | def forward(self, hidden, lookup): 233 | ( 234 | player_embeddings, 235 | inventory_embeddings, 236 | market_embeddings, 237 | action_targets, 238 | ) = lookup 239 | 240 | embeddings = { 241 | "attack_target": player_embeddings, 242 | "market_buy": market_embeddings, 243 | "inventory_destroy": inventory_embeddings, 244 | "inventory_give_item": inventory_embeddings, 245 | "inventory_give_player": player_embeddings, 246 | "gold_target": player_embeddings, 247 | "inventory_sell": inventory_embeddings, 248 | "inventory_use": inventory_embeddings, 249 | } 250 | 251 | action_targets = { 252 | "attack_style": action_targets["Attack"]["Style"], 253 | "attack_target": action_targets["Attack"]["Target"], 254 | "market_buy": action_targets["Buy"]["MarketItem"], 255 | "inventory_destroy": action_targets["Destroy"]["InventoryItem"], 256 | "inventory_give_item": action_targets["Give"]["InventoryItem"], 257 | "inventory_give_player": action_targets["Give"]["Target"], 258 | "gold_quantity": action_targets["GiveGold"]["Price"], 259 | "gold_target": action_targets["GiveGold"]["Target"], 260 | "move": action_targets["Move"]["Direction"], 261 | "inventory_sell": action_targets["Sell"]["InventoryItem"], 262 | "inventory_price": action_targets["Sell"]["Price"], 263 | "inventory_use": action_targets["Use"]["InventoryItem"], 264 | } 265 | 266 | actions = [] 267 | for key, layer in self.layers.items(): 268 | mask = None 269 | mask = action_targets[key] 270 | embs = embeddings.get(key) 271 | 272 | # NOTE: SHOULD not hit this 273 | # if embs is not None and embs.shape[1] != mask.shape[1]: 274 | # b, _, f = embs.shape 275 | # zeros = torch.zeros([b, 1, f], dtype=embs.dtype, device=embs.device) 276 | # embs = torch.cat([embs, zeros], dim=1) 277 | 278 | action = self.apply_layer(layer, embs, mask, hidden) 279 | actions.append(action) 280 | 281 | return actions 282 | -------------------------------------------------------------------------------- /agent_zoo/neurips23_start_kit/reward_wrapper.py: -------------------------------------------------------------------------------- 1 | from reinforcement_learning.stat_wrapper import BaseStatWrapper 2 | 3 | 4 | class RewardWrapper(BaseStatWrapper): 5 | def __init__( 6 | # BaseStatWrapper args 7 | self, 8 | env, 9 | eval_mode=False, 10 | early_stop_agent_num=0, 11 | stat_prefix=None, 12 | use_custom_reward=True, 13 | # Custom reward wrapper args 14 | heal_bonus_weight=0, 15 | explore_bonus_weight=0, 16 | clip_unique_event=3, 17 | ): 18 | super().__init__(env, eval_mode, early_stop_agent_num, stat_prefix, use_custom_reward) 19 | self.stat_prefix = stat_prefix 20 | 21 | self.heal_bonus_weight = heal_bonus_weight 22 | self.explore_bonus_weight = explore_bonus_weight 23 | self.clip_unique_event = clip_unique_event 24 | 25 | def reset(self, **kwargs): 26 | """Called at the start of each episode""" 27 | self._reset_reward_vars() 28 | return super().reset(**kwargs) 29 | 30 | def _reset_reward_vars(self): 31 | self._history = { 32 | agent_id: { 33 | "prev_price": 0, 34 | "prev_moves": [], 35 | } 36 | for agent_id in self.env.possible_agents 37 | } 38 | 39 | """ 40 | @functools.cached_property 41 | def observation_space(self): 42 | '''If you modify the shape of features, you need to specify the new obs space''' 43 | return super().observation_space 44 | """ 45 | 46 | def observation(self, agent_id, agent_obs): 47 | """Called before observations are returned from the environment 48 | 49 | Use this to define custom featurizers. Changing the space itself requires you to 50 | define the observation space again (i.e. Gym.spaces.Dict(gym.spaces....)) 51 | """ 52 | # Mask the price of the previous action, to encourage agents to explore new prices 53 | agent_obs["ActionTargets"]["Sell"]["Price"][self._history[agent_id]["prev_price"]] = 0 54 | return agent_obs 55 | 56 | def action(self, agent_id, agent_atn): 57 | """Called before actions are passed from the model to the environment""" 58 | # Keep track of the previous price and moves for each agent 59 | self._history[agent_id]["prev_price"] = agent_atn["Sell"]["Price"] 60 | self._history[agent_id]["prev_moves"].append(agent_atn["Move"]["Direction"]) 61 | return agent_atn 62 | 63 | def reward_terminated_truncated_info(self, agent_id, reward, terminated, truncated, info): 64 | realm = self.env.realm 65 | 66 | # Add "Healing" score based on health increase and decrease, due to food and water 67 | healing_bonus = 0 68 | if self.heal_bonus_weight > 0 and agent_id in realm.players: 69 | if realm.players[agent_id].resources.health_restore > 0: 70 | healing_bonus = self.heal_bonus_weight 71 | 72 | # Unique event-based rewards, similar to exploration bonus 73 | # The number of unique events are available in self._unique_events[agent_id] 74 | uniq = self._unique_events[agent_id] 75 | explore_bonus = 0 76 | if self.explore_bonus_weight > 0 and uniq["curr_count"] > uniq["prev_count"]: 77 | explore_bonus = min(self.clip_unique_event, uniq["curr_count"] - uniq["prev_count"]) 78 | explore_bonus *= self.explore_bonus_weight 79 | 80 | reward += healing_bonus + explore_bonus 81 | 82 | return reward, terminated, truncated, info 83 | -------------------------------------------------------------------------------- /agent_zoo/takeru/README.md: -------------------------------------------------------------------------------- 1 | Please see the participants' repo for the full training code: https://github.com/Netease-Games-OPD-AI/neural-mmo-2023 2 | -------------------------------------------------------------------------------- /agent_zoo/takeru/__init__.py: -------------------------------------------------------------------------------- 1 | from .policy import ReducedModelV2 as Policy 2 | from .policy import Recurrent 3 | from .reward_wrapper import RewardWrapper 4 | -------------------------------------------------------------------------------- /agent_zoo/takeru/policy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | import pufferlib 5 | import pufferlib.emulation 6 | import pufferlib.models 7 | 8 | from nmmo.entity.entity import EntityState 9 | 10 | EVAL_MODE = False 11 | # print(f"** EVAL_MODE {EVAL_MODE}") 12 | 13 | EntityId = EntityState.State.attr_name_to_col["id"] 14 | 15 | 16 | class Recurrent(pufferlib.models.RecurrentWrapper): 17 | def __init__(self, env, policy, input_size=256, hidden_size=256, num_layers=0): 18 | super().__init__(env, policy, input_size, hidden_size, num_layers) 19 | 20 | 21 | class ReducedModelV2(pufferlib.models.Policy): 22 | """Reduce observation space""" 23 | 24 | def __init__(self, env, input_size=256, hidden_size=256, task_size=2048): 25 | super().__init__(env) 26 | 27 | self.unflatten_context = env.unflatten_context 28 | 29 | self.tile_encoder = ReducedTileEncoder(input_size) 30 | self.player_encoder = ReducedPlayerEncoder(input_size, hidden_size) 31 | self.item_encoder = ReducedItemEncoder(input_size, hidden_size) 32 | self.inventory_encoder = InventoryEncoder(input_size, hidden_size) 33 | self.market_encoder = MarketEncoder(input_size, hidden_size) 34 | self.task_encoder = TaskEncoder(input_size, hidden_size, task_size) 35 | self.proj_fc = torch.nn.Linear(5 * input_size, hidden_size) 36 | self.action_decoder = ReducedActionDecoder(input_size, hidden_size) 37 | self.value_head = torch.nn.Linear(hidden_size, 1) 38 | 39 | def encode_observations(self, flat_observations): 40 | env_outputs = pufferlib.emulation.unpack_batched_obs( 41 | flat_observations, self.unflatten_context 42 | ) 43 | tile = self.tile_encoder(env_outputs["Tile"]) 44 | player_embeddings, my_agent = self.player_encoder( 45 | env_outputs["Entity"], env_outputs["AgentId"][:, 0] 46 | ) 47 | 48 | item_embeddings = self.item_encoder(env_outputs["Inventory"]) 49 | inventory = self.inventory_encoder(item_embeddings) 50 | 51 | market_embeddings = self.item_encoder(env_outputs["Market"]) 52 | market = self.market_encoder(market_embeddings) 53 | 54 | task = self.task_encoder(env_outputs["Task"]) 55 | 56 | obs = torch.cat([tile, my_agent, inventory, market, task], dim=-1) 57 | obs = self.proj_fc(obs) 58 | 59 | return obs, ( 60 | player_embeddings, 61 | item_embeddings, 62 | market_embeddings, 63 | env_outputs["ActionTargets"], 64 | ) 65 | 66 | def no_explore_post_processing(self, logits): 67 | # logits shape (BS, n sub-action dim) 68 | max_index = torch.argmax(logits, dim=-1) 69 | ret = torch.full_like(logits, fill_value=-1e9) 70 | ret[torch.arange(logits.shape[0]), max_index] = 0 71 | 72 | return ret 73 | 74 | def decode_actions(self, hidden, lookup): 75 | actions = self.action_decoder(hidden, lookup) 76 | value = self.value_head(hidden) 77 | 78 | if EVAL_MODE: 79 | actions = [self.no_explore_post_processing(logits) for logits in actions] 80 | # TODO: skip value 81 | 82 | return actions, value 83 | 84 | 85 | class ReducedTileEncoder(torch.nn.Module): 86 | def __init__(self, input_size): 87 | super().__init__() 88 | self.embedding = torch.nn.Embedding(256, 32) 89 | 90 | self.tile_conv_1 = torch.nn.Conv2d(32, 16, 3) 91 | self.tile_conv_2 = torch.nn.Conv2d(16, 8, 3) 92 | self.tile_fc = torch.nn.Linear(8 * 11 * 11, input_size) 93 | 94 | def forward(self, tile): 95 | # tile: row, col, material_id 96 | tile = tile[:, :, 2:] 97 | 98 | tile = self.embedding(tile.long().clip(0, 255)) 99 | 100 | agents, tiles, features, embed = tile.shape 101 | tile = ( 102 | tile.view(agents, tiles, features * embed) 103 | .transpose(1, 2) 104 | .view(agents, features * embed, 15, 15) 105 | ) 106 | 107 | tile = F.relu(self.tile_conv_1(tile)) 108 | tile = F.relu(self.tile_conv_2(tile)) 109 | tile = tile.contiguous().view(agents, -1) 110 | tile = F.relu(self.tile_fc(tile)) 111 | 112 | return tile 113 | 114 | 115 | class ReducedPlayerEncoder(torch.nn.Module): 116 | """ """ 117 | 118 | def __init__(self, input_size, hidden_size): 119 | super().__init__() 120 | 121 | discrete_attr = [ 122 | "id", # pos player entity id & neg npc entity id 123 | "npc_type", 124 | "attacker_id", # just pos player entity id 125 | "message", 126 | ] 127 | self.discrete_idxs = [EntityState.State.attr_name_to_col[key] for key in discrete_attr] 128 | self.discrete_offset = torch.Tensor([i * 256 for i in range(len(discrete_attr))]) 129 | 130 | _max_exp = 100 131 | _max_level = 10 132 | 133 | continuous_attr_and_scale = [ 134 | ("row", 256), 135 | ("col", 256), 136 | ("damage", 100), 137 | ("time_alive", 1024), 138 | ("freeze", 3), 139 | ("item_level", 50), 140 | ("latest_combat_tick", 1024), 141 | ("gold", 100), 142 | ("health", 100), 143 | ("food", 100), 144 | ("water", 100), 145 | ("melee_level", _max_level), 146 | ("melee_exp", _max_exp), 147 | ("range_level", _max_level), 148 | ("range_exp", _max_exp), 149 | ("mage_level", _max_level), 150 | ("mage_exp", _max_exp), 151 | ("fishing_level", _max_level), 152 | ("fishing_exp", _max_exp), 153 | ("herbalism_level", _max_level), 154 | ("herbalism_exp", _max_exp), 155 | ("prospecting_level", _max_level), 156 | ("prospecting_exp", _max_exp), 157 | ("carving_level", _max_level), 158 | ("carving_exp", _max_exp), 159 | ("alchemy_level", _max_exp), 160 | ("alchemy_exp", _max_level), 161 | ] 162 | self.continuous_idxs = [ 163 | EntityState.State.attr_name_to_col[key] for key, _ in continuous_attr_and_scale 164 | ] 165 | self.continuous_scale = torch.Tensor([scale for _, scale in continuous_attr_and_scale]) 166 | 167 | self.embedding = torch.nn.Embedding(len(discrete_attr) * 256, 32) 168 | 169 | emb_dim = len(discrete_attr) * 32 + len(continuous_attr_and_scale) 170 | self.agent_fc = torch.nn.Linear(emb_dim, hidden_size) 171 | self.my_agent_fc = torch.nn.Linear(emb_dim, input_size) 172 | 173 | def forward(self, agents, my_id): 174 | # self._debug(agents) 175 | 176 | # Pull out rows corresponding to the agent 177 | agent_ids = agents[:, :, EntityId] 178 | mask = (agent_ids == my_id.unsqueeze(1)) & (agent_ids != 0) 179 | mask = mask.int() 180 | row_indices = torch.where( 181 | mask.any(dim=1), mask.argmax(dim=1), torch.zeros_like(mask.sum(dim=1)) 182 | ) 183 | 184 | if self.discrete_offset.device != agents.device: 185 | self.discrete_offset = self.discrete_offset.to(agents.device) 186 | self.continuous_scale = self.continuous_scale.to(agents.device) 187 | 188 | # Embed each feature separately 189 | # agents shape (BS, agents, n of states) 190 | discrete = agents[:, :, self.discrete_idxs] + self.discrete_offset 191 | discrete = self.embedding(discrete.long().clip(0, 255)) 192 | batch, item, attrs, embed = discrete.shape 193 | discrete = discrete.view(batch, item, attrs * embed) 194 | 195 | continuous = agents[:, :, self.continuous_idxs] / self.continuous_scale 196 | 197 | # shape (BS, agents, x) 198 | agent_embeddings = torch.cat([discrete, continuous], dim=-1).float() 199 | 200 | my_agent_embeddings = agent_embeddings[torch.arange(agents.shape[0]), row_indices] 201 | 202 | # Project to input of recurrent size 203 | agent_embeddings = self.agent_fc(agent_embeddings) 204 | my_agent_embeddings = self.my_agent_fc(my_agent_embeddings) 205 | my_agent_embeddings = F.relu(my_agent_embeddings) 206 | 207 | return agent_embeddings, my_agent_embeddings 208 | 209 | def _debug(self, agents): 210 | agents_max, _ = torch.max(agents, dim=-2) 211 | agents_max, _ = torch.max(agents_max, dim=-2) 212 | print(f"agents_max {agents_max.tolist()}") 213 | 214 | 215 | class ReducedItemEncoder(torch.nn.Module): 216 | def __init__(self, input_size, hidden_size): 217 | super().__init__() 218 | self.item_offset = torch.tensor([i * 256 for i in range(16)]) 219 | self.embedding = torch.nn.Embedding(256, 32) 220 | 221 | self.fc = torch.nn.Linear(2 * 32 + 12, hidden_size) 222 | 223 | self.discrete_idxs = [1, 14] 224 | self.discrete_offset = torch.Tensor([2, 0]) 225 | self.continuous_idxs = [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15] 226 | self.continuous_scale = torch.Tensor( 227 | [ 228 | 10, 229 | 10, 230 | 10, 231 | 100, 232 | 100, 233 | 100, 234 | 40, 235 | 40, 236 | 40, 237 | 100, 238 | 100, 239 | 100, 240 | ] 241 | ) 242 | 243 | def forward(self, items): 244 | if self.discrete_offset.device != items.device: 245 | self.discrete_offset = self.discrete_offset.to(items.device) 246 | self.continuous_scale = self.continuous_scale.to(items.device) 247 | 248 | # Embed each feature separately 249 | discrete = items[:, :, self.discrete_idxs] + self.discrete_offset 250 | discrete = self.embedding(discrete.long().clip(0, 255)) 251 | batch, item, attrs, embed = discrete.shape 252 | discrete = discrete.view(batch, item, attrs * embed) 253 | 254 | continuous = items[:, :, self.continuous_idxs] / self.continuous_scale 255 | 256 | item_embeddings = torch.cat([discrete, continuous], dim=-1).float() 257 | item_embeddings = self.fc(item_embeddings) 258 | return item_embeddings 259 | 260 | 261 | class InventoryEncoder(torch.nn.Module): 262 | def __init__(self, input_size, hidden_size): 263 | super().__init__() 264 | self.fc = torch.nn.Linear(12 * hidden_size, input_size) 265 | 266 | def forward(self, inventory): 267 | agents, items, hidden = inventory.shape 268 | inventory = inventory.view(agents, items * hidden) 269 | return self.fc(inventory) 270 | 271 | 272 | class MarketEncoder(torch.nn.Module): 273 | def __init__(self, input_size, hidden_size): 274 | super().__init__() 275 | self.fc = torch.nn.Linear(hidden_size, input_size) 276 | 277 | def forward(self, market): 278 | return self.fc(market).mean(-2) 279 | 280 | 281 | class TaskEncoder(torch.nn.Module): 282 | def __init__(self, input_size, hidden_size, task_size): 283 | super().__init__() 284 | self.fc = torch.nn.Linear(task_size, input_size) 285 | 286 | def forward(self, task): 287 | return self.fc(task.clone().float()) 288 | 289 | 290 | class ReducedActionDecoder(torch.nn.Module): 291 | def __init__(self, input_size, hidden_size): 292 | super().__init__() 293 | # order corresponding to action space 294 | self.sub_action_keys = [ 295 | "attack_style", 296 | "attack_target", 297 | "market_buy", 298 | "inventory_destroy", 299 | "inventory_give_item", 300 | "inventory_give_player", 301 | "gold_quantity", 302 | "gold_target", 303 | "move", 304 | "inventory_sell", 305 | "inventory_price", 306 | "inventory_use", 307 | ] 308 | self.layers = torch.nn.ModuleDict( 309 | { 310 | "attack_style": torch.nn.Linear(hidden_size, 3), 311 | "attack_target": torch.nn.Linear(hidden_size, hidden_size), 312 | "market_buy": torch.nn.Linear(hidden_size, hidden_size), 313 | "inventory_destroy": torch.nn.Linear(hidden_size, hidden_size), 314 | "inventory_give_item": torch.nn.Linear( 315 | hidden_size, hidden_size 316 | ), # TODO: useful for Inventory Management? 317 | "inventory_give_player": torch.nn.Linear(hidden_size, hidden_size), 318 | "gold_quantity": torch.nn.Linear(hidden_size, 99), 319 | "gold_target": torch.nn.Linear(hidden_size, hidden_size), 320 | "move": torch.nn.Linear(hidden_size, 5), 321 | "inventory_sell": torch.nn.Linear(hidden_size, hidden_size), 322 | "inventory_price": torch.nn.Linear(hidden_size, 99), 323 | "inventory_use": torch.nn.Linear(hidden_size, hidden_size), 324 | } 325 | ) 326 | 327 | def apply_layer(self, layer, embeddings, mask, hidden): 328 | hidden = layer(hidden) 329 | if hidden.dim() == 2 and embeddings is not None: 330 | hidden = torch.matmul(embeddings, hidden.unsqueeze(-1)).squeeze(-1) 331 | 332 | if mask is not None: 333 | hidden = hidden.masked_fill(mask == 0, -1e9) 334 | 335 | return hidden 336 | 337 | # NOTE: Disabling give/give_gold was moved to the reward wrapper 338 | # def act_noob_action(self, key, mask): 339 | # if key in ("inventory_give_item", "inventory_give_player", "gold_target"): 340 | # noob_action_index = -1 341 | # elif key in ("gold_quantity",): 342 | # noob_action_index = 0 343 | # else: 344 | # raise NotImplementedError(key) 345 | 346 | # logits = torch.full_like(mask, fill_value=-1e9) 347 | # logits[:, noob_action_index] = 0 348 | 349 | # return logits 350 | 351 | def forward(self, hidden, lookup): 352 | ( 353 | player_embeddings, 354 | inventory_embeddings, 355 | market_embeddings, 356 | action_targets, 357 | ) = lookup 358 | 359 | embeddings = { 360 | "attack_target": player_embeddings, 361 | "market_buy": market_embeddings, 362 | "inventory_destroy": inventory_embeddings, 363 | "inventory_give_item": inventory_embeddings, 364 | "inventory_give_player": player_embeddings, 365 | "gold_target": player_embeddings, 366 | "inventory_sell": inventory_embeddings, 367 | "inventory_use": inventory_embeddings, 368 | } 369 | 370 | action_targets = { 371 | "attack_style": action_targets["Attack"]["Style"], 372 | "attack_target": action_targets["Attack"]["Target"], 373 | "market_buy": action_targets["Buy"]["MarketItem"], 374 | "inventory_destroy": action_targets["Destroy"]["InventoryItem"], 375 | "inventory_give_item": action_targets["Give"]["InventoryItem"], 376 | "inventory_give_player": action_targets["Give"]["Target"], 377 | "gold_quantity": action_targets["GiveGold"]["Price"], 378 | "gold_target": action_targets["GiveGold"]["Target"], 379 | "move": action_targets["Move"]["Direction"], 380 | "inventory_sell": action_targets["Sell"]["InventoryItem"], 381 | "inventory_price": action_targets["Sell"]["Price"], 382 | "inventory_use": action_targets["Use"]["InventoryItem"], 383 | } 384 | 385 | actions = [] 386 | for key in self.sub_action_keys: 387 | mask = action_targets[key] 388 | 389 | if key in self.layers: 390 | layer = self.layers[key] 391 | embs = embeddings.get(key) 392 | if embs is not None and embs.shape[1] != mask.shape[1]: 393 | b, _, f = embs.shape 394 | zeros = torch.zeros([b, 1, f], dtype=embs.dtype, device=embs.device) 395 | embs = torch.cat([embs, zeros], dim=1) 396 | 397 | action = self.apply_layer(layer, embs, mask, hidden) 398 | 399 | # NOTE: see act_noob_action() 400 | # else: 401 | # action = self.act_noob_action(key, mask) 402 | 403 | actions.append(action) 404 | 405 | return actions 406 | -------------------------------------------------------------------------------- /agent_zoo/takeru/reward_wrapper.py: -------------------------------------------------------------------------------- 1 | from reinforcement_learning.stat_wrapper import BaseStatWrapper 2 | 3 | 4 | class RewardWrapper(BaseStatWrapper): 5 | def __init__( 6 | # BaseStatWrapper args 7 | self, 8 | env, 9 | eval_mode=False, 10 | early_stop_agent_num=0, 11 | stat_prefix=None, 12 | use_custom_reward=True, 13 | # Custom reward wrapper args 14 | explore_bonus_weight=0, 15 | clip_unique_event=3, 16 | disable_give=True, 17 | ): 18 | super().__init__(env, eval_mode, early_stop_agent_num, stat_prefix, use_custom_reward) 19 | self.stat_prefix = stat_prefix 20 | 21 | self.explore_bonus_weight = explore_bonus_weight 22 | self.clip_unique_event = clip_unique_event 23 | self.disable_give = disable_give 24 | 25 | def observation(self, agent_id, agent_obs): 26 | """Called before observations are returned from the environment 27 | 28 | Use this to define custom featurizers. Changing the space itself requires you to 29 | define the observation space again (i.e. Gym.spaces.Dict(gym.spaces....)) 30 | """ 31 | if self.disable_give is True: 32 | agent_obs["ActionTargets"]["Give"]["InventoryItem"][:-1] = 0 33 | agent_obs["ActionTargets"]["Give"]["Target"][:-1] = 0 34 | agent_obs["ActionTargets"]["GiveGold"]["Target"][:-1] = 0 35 | agent_obs["ActionTargets"]["GiveGold"]["Price"][1:] = 0 36 | 37 | return agent_obs 38 | 39 | def reward_terminated_truncated_info(self, agent_id, reward, terminated, truncated, info): 40 | if not (terminated or truncated): 41 | # Unique event-based rewards, similar to exploration bonus 42 | # The number of unique events are available in self._unique_events[agent_id] 43 | uniq = self._unique_events[agent_id] 44 | explore_bonus = 0 45 | if self.explore_bonus_weight > 0 and uniq["curr_count"] > uniq["prev_count"]: 46 | explore_bonus = min(self.clip_unique_event, uniq["curr_count"] - uniq["prev_count"]) 47 | explore_bonus *= self.explore_bonus_weight 48 | 49 | reward += explore_bonus 50 | 51 | return reward, terminated, truncated, info 52 | -------------------------------------------------------------------------------- /agent_zoo/yaofeng/__init__.py: -------------------------------------------------------------------------------- 1 | from .policy import Policy 2 | from .policy import Recurrent 3 | from .reward_wrapper import RewardWrapper 4 | -------------------------------------------------------------------------------- /agent_zoo/yaofeng/policy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | import pufferlib 5 | import pufferlib.models 6 | import pufferlib.emulation 7 | 8 | from nmmo.entity.entity import EntityState 9 | 10 | EntityId = EntityState.State.attr_name_to_col["id"] 11 | 12 | 13 | class Recurrent(pufferlib.models.RecurrentWrapper): 14 | def __init__(self, env, policy, input_size=256, hidden_size=256, num_layers=2): 15 | super().__init__(env, policy, input_size, hidden_size, num_layers) 16 | 17 | 18 | def orthogonal_init(layer, gain=1.0): 19 | torch.nn.init.orthogonal_(layer.weight, gain=gain) 20 | torch.nn.init.constant_(layer.bias, 0) 21 | 22 | 23 | class Policy(pufferlib.models.Policy): 24 | def __init__(self, env, input_size=256, hidden_size=256, task_size=2048): 25 | super().__init__(env) 26 | 27 | self.unflatten_context = env.unflatten_context 28 | 29 | self.tile_encoder = TileEncoder(input_size) 30 | self.player_encoder = PlayerEncoder(input_size, hidden_size) 31 | self.item_encoder = ItemEncoder(input_size, hidden_size) 32 | self.inventory_encoder = InventoryEncoder(input_size, hidden_size) 33 | self.market_encoder = MarketEncoder(input_size, hidden_size) 34 | self.task_encoder = TaskEncoder(input_size, hidden_size, task_size) 35 | self.proj_fc = torch.nn.Linear(5 * input_size, hidden_size) 36 | self.action_decoder = ActionDecoder(input_size, hidden_size) 37 | self.value_head = torch.nn.Linear(hidden_size, 1) 38 | orthogonal_init(self.proj_fc) 39 | orthogonal_init(self.value_head) 40 | 41 | def encode_observations(self, flat_observations): 42 | env_outputs = pufferlib.emulation.unpack_batched_obs( 43 | flat_observations, self.unflatten_context 44 | ) 45 | tile = self.tile_encoder(env_outputs["Tile"]) 46 | player_embeddings, my_agent = self.player_encoder( 47 | env_outputs["Entity"], env_outputs["AgentId"][:, 0] 48 | ) 49 | 50 | item_embeddings = self.item_encoder(env_outputs["Inventory"]) 51 | inventory = self.inventory_encoder(item_embeddings) 52 | 53 | market_embeddings = self.item_encoder(env_outputs["Market"]) 54 | market = self.market_encoder(market_embeddings) 55 | 56 | task = self.task_encoder(env_outputs["Task"]) 57 | 58 | obs = torch.cat([tile, my_agent, inventory, market, task], dim=-1) 59 | obs = F.relu(self.proj_fc(obs)) 60 | 61 | return obs, ( 62 | player_embeddings, 63 | item_embeddings, 64 | market_embeddings, 65 | env_outputs["ActionTargets"], 66 | ) 67 | 68 | def decode_actions(self, hidden, lookup): 69 | actions = self.action_decoder(hidden, lookup) 70 | value = self.value_head(hidden) 71 | return actions, value 72 | 73 | 74 | class ResnetBlock(torch.nn.Module): 75 | def __init__(self, in_planes, img_size=(15, 15)): 76 | super().__init__() 77 | self.model = torch.nn.Sequential( 78 | torch.nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=1, padding=1), 79 | torch.nn.LayerNorm((in_planes, *img_size)), 80 | torch.nn.ReLU(), 81 | torch.nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=1, padding=1), 82 | torch.nn.LayerNorm((in_planes, *img_size)), 83 | ) 84 | 85 | def forward(self, x): 86 | out = self.model(x) 87 | out += x 88 | return out 89 | 90 | 91 | class TileEncoder(torch.nn.Module): 92 | def __init__(self, input_size): 93 | super().__init__() 94 | self.type_embedding = torch.nn.Embedding(16, 62) 95 | 96 | self.tile_resnet = ResnetBlock(64) 97 | self.tile_conv_1 = torch.nn.Conv2d(64, 32, 3) 98 | self.tile_conv_2 = torch.nn.Conv2d(32, 8, 3) 99 | self.tile_fc = torch.nn.Linear(8 * 11 * 11, input_size) 100 | self.tile_norm = torch.nn.LayerNorm(input_size) 101 | orthogonal_init(self.tile_fc) 102 | 103 | def forward(self, tile): 104 | tile_position = tile[:, :, :2] / 128 - 0.5 105 | tile_type = tile[:, :, 2].long().clip(0, 15) 106 | tile = torch.cat((tile_position, self.type_embedding(tile_type)), dim=-1) 107 | agents, _, features = tile.shape 108 | tile = tile.transpose(1, 2).view(agents, features, 15, 15).float() 109 | 110 | tile = F.relu(self.tile_resnet(tile)) 111 | tile = F.relu(self.tile_conv_1(tile)) 112 | tile = F.relu(self.tile_conv_2(tile)) 113 | tile = tile.contiguous().view(agents, -1) 114 | tile = F.relu(self.tile_norm(self.tile_fc(tile))) 115 | return tile 116 | 117 | 118 | class MLPBlock(torch.nn.Module): 119 | def __init__(self, input_size, hidden_size, output_size, num_layers=2): 120 | super().__init__() 121 | self.model = [ 122 | torch.nn.Linear(input_size, hidden_size), 123 | torch.nn.ReLU(), 124 | ] 125 | for _ in range(num_layers - 2): 126 | self.model += [torch.nn.Linear(hidden_size, hidden_size), torch.nn.ReLU()] 127 | self.model.append(torch.nn.Linear(hidden_size, output_size)) 128 | for layer in self.model: 129 | if isinstance(layer, torch.nn.Linear): 130 | orthogonal_init(layer) 131 | self.model = torch.nn.Sequential(*self.model) 132 | 133 | def forward(self, x): 134 | out = self.model(x) 135 | return out 136 | 137 | 138 | class PlayerEncoder(torch.nn.Module): 139 | def __init__(self, input_size, hidden_size): 140 | super().__init__() 141 | self.entity_dim = 31 142 | self.player_offset = torch.tensor([i * 256 for i in range(self.entity_dim)]) 143 | self.embedding = torch.nn.Embedding(self.entity_dim * 256, 32) 144 | 145 | self.EntityId = EntityState.State.attr_name_to_col["id"] 146 | self.EntityAttackerId = EntityState.State.attr_name_to_col["attacker_id"] 147 | self.EntityMessage = EntityState.State.attr_name_to_col["message"] 148 | self.id_embedding = torch.nn.Embedding(512, 64) 149 | self.embedding_idx = [self.EntityId, self.EntityAttackerId] 150 | self.no_embedding_idx = [i for i in range(self.entity_dim)] 151 | self.no_embedding_idx.remove(self.EntityId) 152 | self.no_embedding_idx.remove(self.EntityAttackerId) 153 | self.no_embedding_idx.remove(self.EntityMessage) 154 | 155 | self.agent_mlp = MLPBlock(64 + self.entity_dim - 3, hidden_size, hidden_size) 156 | self.agent_fc = torch.nn.Linear(hidden_size, hidden_size) 157 | self.my_agent_fc = torch.nn.Linear(hidden_size, input_size) 158 | self.agent_norm = torch.nn.LayerNorm(hidden_size) 159 | self.my_agent_norm = torch.nn.LayerNorm(hidden_size) 160 | orthogonal_init(self.agent_fc) 161 | orthogonal_init(self.my_agent_fc) 162 | 163 | def forward(self, agents, my_id): 164 | # Pull out rows corresponding to the agent 165 | agent_ids = agents[:, :, EntityId] 166 | mask = (agent_ids == my_id.unsqueeze(1)) & (agent_ids != 0) 167 | mask = mask.int() 168 | row_indices = torch.where( 169 | mask.any(dim=1), mask.argmax(dim=1), torch.zeros_like(mask.sum(dim=1)) 170 | ) 171 | 172 | batch, agent, _ = agents.shape 173 | agent_embeddings = self.embedding( 174 | (agents[:, :, self.embedding_idx].long() + 256).clip(0, 511) 175 | ).reshape(batch, agent, -1) 176 | agent_embeddings = torch.cat( 177 | (agent_embeddings, agents[:, :, self.no_embedding_idx]), dim=-1 178 | ).float() 179 | agent_embeddings = F.relu(self.agent_mlp(agent_embeddings)) 180 | 181 | my_agent_embeddings = agent_embeddings[torch.arange(agents.shape[0]), row_indices] 182 | agent_embeddings = F.relu(self.agent_norm(self.agent_fc(agent_embeddings))) 183 | my_agent_embeddings = F.relu(self.my_agent_norm(self.my_agent_fc(my_agent_embeddings))) 184 | return agent_embeddings, my_agent_embeddings 185 | 186 | 187 | class ItemEncoder(torch.nn.Module): 188 | def __init__(self, input_size, hidden_size): 189 | super().__init__() 190 | self.embedding = torch.nn.Embedding(256, 32) 191 | self.item_mlp = MLPBlock(2 * 32 + 12, hidden_size, hidden_size) 192 | self.item_norm = torch.nn.LayerNorm(hidden_size) 193 | 194 | self.discrete_idxs = [1, 14] 195 | self.discrete_offset = torch.Tensor([2, 0]) 196 | self.continuous_idxs = [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15] 197 | self.continuous_scale = torch.Tensor( 198 | [ 199 | 1 / 10, 200 | 1 / 10, 201 | 1 / 10, 202 | 1 / 100, 203 | 1 / 100, 204 | 1 / 100, 205 | 1 / 40, 206 | 1 / 40, 207 | 1 / 40, 208 | 1 / 100, 209 | 1 / 100, 210 | 1 / 100, 211 | ] 212 | ) 213 | 214 | def forward(self, items): 215 | if self.discrete_offset.device != items.device: 216 | self.discrete_offset = self.discrete_offset.to(items.device) 217 | self.continuous_scale = self.continuous_scale.to(items.device) 218 | 219 | # Embed each feature separately 220 | discrete = items[:, :, self.discrete_idxs] + self.discrete_offset 221 | discrete = self.embedding(discrete.long().clip(0, 255)) 222 | batch, item, attrs, embed = discrete.shape 223 | discrete = discrete.view(batch, item, attrs * embed) 224 | 225 | continuous = items[:, :, self.continuous_idxs] / self.continuous_scale 226 | 227 | item_embeddings = torch.cat([discrete, continuous], dim=-1).float() 228 | item_embeddings = F.relu(self.item_norm(self.item_mlp(item_embeddings))) 229 | return item_embeddings 230 | 231 | 232 | class InventoryEncoder(torch.nn.Module): 233 | def __init__(self, input_size, hidden_size): 234 | super().__init__() 235 | self.fc = torch.nn.Linear(12 * hidden_size, input_size) 236 | self.norm = torch.nn.LayerNorm(input_size) 237 | orthogonal_init(self.fc) 238 | 239 | def forward(self, inventory): 240 | agents, items, hidden = inventory.shape 241 | inventory = inventory.view(agents, items * hidden) 242 | return F.relu(self.norm(self.fc(inventory))) 243 | 244 | 245 | class MarketEncoder(torch.nn.Module): 246 | def __init__(self, input_size, hidden_size): 247 | super().__init__() 248 | self.fc = torch.nn.Linear(hidden_size, input_size) 249 | self.norm = torch.nn.LayerNorm(input_size) 250 | orthogonal_init(self.fc) 251 | 252 | def forward(self, market): 253 | return F.relu(self.norm(self.fc(market).mean(-2))) 254 | 255 | 256 | class TaskEncoder(torch.nn.Module): 257 | def __init__(self, input_size, hidden_size, task_size): 258 | super().__init__() 259 | self.fc = torch.nn.Linear(task_size, input_size) 260 | self.norm = torch.nn.LayerNorm(input_size) 261 | orthogonal_init(self.fc) 262 | 263 | def forward(self, task): 264 | return F.relu(self.norm(self.fc(task.clone().float()))) 265 | 266 | 267 | class ActionDecoder(torch.nn.Module): 268 | def __init__(self, input_size, hidden_size): 269 | super().__init__() 270 | self.layers = { 271 | "attack_style": torch.nn.Linear(hidden_size, 3), 272 | "attack_target": torch.nn.Linear(hidden_size, hidden_size), 273 | "market_buy": torch.nn.Linear(hidden_size, hidden_size), 274 | "inventory_destroy": torch.nn.Linear(hidden_size, hidden_size), 275 | "inventory_give_item": torch.nn.Linear(hidden_size, hidden_size), 276 | "inventory_give_player": torch.nn.Linear(hidden_size, hidden_size), 277 | "gold_quantity": torch.nn.Linear(hidden_size, 99), 278 | "gold_target": torch.nn.Linear(hidden_size, hidden_size), 279 | "move": torch.nn.Linear(hidden_size, 5), 280 | "inventory_sell": torch.nn.Linear(hidden_size, hidden_size), 281 | "inventory_price": torch.nn.Linear(hidden_size, 99), 282 | "inventory_use": torch.nn.Linear(hidden_size, hidden_size), 283 | } 284 | for _, v in self.layers.items(): 285 | orthogonal_init(v, gain=0.1) 286 | self.layers = torch.nn.ModuleDict(self.layers) 287 | 288 | def apply_layer(self, layer, embeddings, mask, hidden): 289 | hidden = layer(hidden) 290 | if hidden.dim() == 2 and embeddings is not None: 291 | hidden = torch.matmul(embeddings, hidden.unsqueeze(-1)).squeeze(-1) 292 | 293 | if mask is not None: 294 | hidden = hidden.masked_fill(mask == 0, -1e9) 295 | 296 | return hidden 297 | 298 | def forward(self, hidden, lookup): 299 | ( 300 | player_embeddings, 301 | inventory_embeddings, 302 | market_embeddings, 303 | action_targets, 304 | ) = lookup 305 | 306 | embeddings = { 307 | "attack_target": player_embeddings, 308 | "market_buy": market_embeddings, 309 | "inventory_destroy": inventory_embeddings, 310 | "inventory_give_item": inventory_embeddings, 311 | "inventory_give_player": player_embeddings, 312 | "gold_target": player_embeddings, 313 | "inventory_sell": inventory_embeddings, 314 | "inventory_use": inventory_embeddings, 315 | } 316 | 317 | action_targets = { 318 | "attack_style": action_targets["Attack"]["Style"], 319 | "attack_target": action_targets["Attack"]["Target"], 320 | "market_buy": action_targets["Buy"]["MarketItem"], 321 | "inventory_destroy": action_targets["Destroy"]["InventoryItem"], 322 | "inventory_give_item": action_targets["Give"]["InventoryItem"], 323 | "inventory_give_player": action_targets["Give"]["Target"], 324 | "gold_quantity": action_targets["GiveGold"]["Price"], 325 | "gold_target": action_targets["GiveGold"]["Target"], 326 | "move": action_targets["Move"]["Direction"], 327 | "inventory_sell": action_targets["Sell"]["InventoryItem"], 328 | "inventory_price": action_targets["Sell"]["Price"], 329 | "inventory_use": action_targets["Use"]["InventoryItem"], 330 | } 331 | 332 | # Pass the LSTM output through a ReLU 333 | # NOTE: The original implementation had relu after both LSTM layers 334 | hidden = F.relu(hidden) 335 | 336 | actions = [] 337 | # assert action_targets["inventory_give_item"].sum() <= hidden.shape[0] 338 | # assert action_targets["inventory_give_player"].sum() <= hidden.shape[0] 339 | # assert action_targets["gold_target"].sum() <= hidden.shape[0] 340 | # assert action_targets["gold_quantity"].sum() <= hidden.shape[0] 341 | for key, layer in self.layers.items(): 342 | mask = None 343 | mask = action_targets[key] 344 | embs = embeddings.get(key) 345 | if embs is not None and embs.shape[1] != mask.shape[1]: 346 | b, _, f = embs.shape 347 | zeros = torch.zeros([b, 1, f], dtype=embs.dtype, device=embs.device) 348 | embs = torch.cat([embs, zeros], dim=1) 349 | action = self.apply_layer(layer, embs, mask, hidden) 350 | actions.append(action) 351 | 352 | return actions 353 | -------------------------------------------------------------------------------- /agent_zoo/yaofeng/reward_wrapper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from nmmo.entity.entity import EntityState 4 | 5 | from reinforcement_learning.stat_wrapper import BaseStatWrapper 6 | 7 | EntityAttr = EntityState.State.attr_name_to_col 8 | SKILL_LIST = ["melee", "range", "mage", "fishing", "herbalism", "prospecting", "carving", "alchemy"] 9 | 10 | 11 | class RewardWrapper(BaseStatWrapper): 12 | def __init__( 13 | # BaseStatWrapper args 14 | self, 15 | env, 16 | eval_mode=False, 17 | early_stop_agent_num=0, 18 | stat_prefix=None, 19 | use_custom_reward=True, 20 | # Custom reward wrapper args 21 | hp_bonus_weight=0, 22 | exp_bonus_weight=0, 23 | defense_bonus_weight=0, 24 | attack_bonus_weight=0, 25 | gold_bonus_weight=0, 26 | custom_bonus_scale=1, 27 | randomize_spawn_immunity=False, 28 | disable_give=True, 29 | donot_attack_dangerous_npc=True, 30 | ): 31 | super().__init__(env, eval_mode, early_stop_agent_num, stat_prefix, use_custom_reward) 32 | self.stat_prefix = stat_prefix 33 | self.default_spawn_immunity = env.config.COMBAT_SPAWN_IMMUNITY 34 | 35 | self.hp_bonus_weight = hp_bonus_weight 36 | self.exp_bonus_weight = exp_bonus_weight 37 | self.defense_bonus_weight = defense_bonus_weight 38 | self.attack_bonus_weight = attack_bonus_weight 39 | self.gold_bonus_weight = gold_bonus_weight 40 | self.custom_bonus_scale = custom_bonus_scale 41 | 42 | # TODO: implement random spawn immunity, while not breaking determinism 43 | self.randomize_spawn_immunity = randomize_spawn_immunity 44 | 45 | self.disable_give = disable_give 46 | self.donot_attack_dangerous_npc = donot_attack_dangerous_npc 47 | 48 | def reset(self, **kwargs): 49 | """Called at the start of each episode""" 50 | self._reset_reward_vars() 51 | return super().reset(**kwargs) 52 | 53 | def _reset_reward_vars(self): 54 | self._data = { 55 | agent_id: { 56 | "hp": 100, 57 | "exp": 0, 58 | "damage_received": 0, 59 | "damage_inflicted": 0, 60 | "gold": 0, 61 | } 62 | for agent_id in self.env.possible_agents 63 | } 64 | 65 | def observation(self, agent_id, agent_obs): 66 | """Called before observations are returned from the environment 67 | 68 | Use this to define custom featurizers. Changing the space itself requires you to 69 | define the observation space again (i.e. Gym.spaces.Dict(gym.spaces....)) 70 | """ 71 | if self.disable_give is True: 72 | agent_obs["ActionTargets"]["Give"]["InventoryItem"][:-1] = 0 73 | agent_obs["ActionTargets"]["Give"]["Target"][:-1] = 0 74 | agent_obs["ActionTargets"]["GiveGold"]["Target"][:-1] = 0 75 | agent_obs["ActionTargets"]["GiveGold"]["Price"][1:] = 0 76 | 77 | if self.donot_attack_dangerous_npc is True: 78 | # npc type: 1: passive, 2: neutral, 3: hostile 79 | dangerours_npc_idxs = np.where(agent_obs["Entity"][:, EntityAttr["npc_type"]] > 1) 80 | agent_obs["ActionTargets"]["Attack"]["Target"][dangerours_npc_idxs] = 0 81 | 82 | return agent_obs 83 | 84 | def reward_terminated_truncated_info(self, agent_id, reward, terminated, truncated, info): 85 | if not (terminated or truncated): 86 | assert agent_id in self.env.realm.players, f"agent_id {agent_id} not in realm.players" 87 | agent_info = self.env.realm.players[agent_id] 88 | 89 | # HP bonus 90 | current_hp = agent_info.health.val 91 | hp_bonus = (current_hp - self._data[agent_id]["hp"]) * self.hp_bonus_weight 92 | self._data[agent_id]["hp"] = current_hp 93 | 94 | # Experience bonus 95 | current_exps = np.array( 96 | [getattr(agent_info, f"{skill}_exp").val for skill in SKILL_LIST] 97 | ) 98 | current_exp = np.max(current_exps) 99 | exp_bonus = (current_exp - self._data[agent_id]["exp"]) * self.exp_bonus_weight 100 | assert exp_bonus >= 0, "exp bonus error" 101 | self._data[agent_id]["exp"] = current_exp 102 | 103 | # Defense bonus 104 | current_damage_received = agent_info.history.damage_received 105 | equipment = agent_info.inventory.equipment 106 | defense = ( 107 | equipment.melee_defense + equipment.range_defense + equipment.mage_defense 108 | ) / (15 * 3) 109 | defense_bonus = self.defense_bonus_weight * defense 110 | self._data[agent_id]["damage_received"] = current_damage_received 111 | 112 | # Attack bonus 113 | current_damage_inflicted = agent_info.history.damage_inflicted 114 | attack_bonus = ( 115 | current_damage_inflicted - self._data[agent_id]["damage_inflicted"] 116 | ) * self.attack_bonus_weight 117 | assert attack_bonus >= 0, "attack bonus error" 118 | self._data[agent_id]["damage_inflicted"] = current_damage_inflicted 119 | 120 | # Gold bonus 121 | current_gold = agent_info.gold.val 122 | gold_bonus = (current_gold - self._data[agent_id]["gold"]) * self.gold_bonus_weight 123 | self._data[agent_id]["gold"] = current_gold 124 | 125 | reward += ( 126 | hp_bonus + exp_bonus + defense_bonus + attack_bonus + gold_bonus 127 | ) * self.custom_bonus_scale 128 | 129 | return reward, terminated, truncated, info 130 | -------------------------------------------------------------------------------- /analysis/proc_eval_result.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import logging 4 | import argparse 5 | 6 | import numpy as np 7 | import polars as pl 8 | 9 | # Make the table output simpler 10 | pl.Config.set_tbl_hide_dataframe_shape(True) 11 | pl.Config.set_tbl_formatting("NOTHING") 12 | pl.Config.set_tbl_hide_column_data_types(True) 13 | 14 | # string matching for task names 15 | WEIGHT_DICT = { 16 | "TickGE": ("survival", 100 / 6), # 1 survival task 17 | "PLAYER_KILL": ("combat", 100 / (6 * 3)), # 3 combat tasks 18 | "DefeatEntity": ("combat", 100 / (6 * 3)), 19 | "GO_FARTHEST": ("exploration", 100 / (6 * 2)), # 2 exploration tasks 20 | "OccupyTile": ("exploration", 100 / (6 * 2)), 21 | "AttainSkill": ("skill", 100 / (6 * 8)), # 8 skill tasks 22 | "HarvestItem": ("item", 100 / (6 * 44)), # 44 item tasks 23 | "ConsumeItem": ("item", 100 / (6 * 44)), 24 | "EquipItem": ("item", 100 / (6 * 44)), 25 | "FullyArmed": ("item", 100 / (6 * 44)), 26 | "EARN_GOLD": ("market", 100 / (6 * 5)), # 5 market tasks 27 | "BUY_ITEM": ("market", 100 / (6 * 5)), 28 | "EarnGold": ("market", 100 / (6 * 5)), 29 | "HoardGold": ("market", 100 / (6 * 5)), 30 | "MakeProfit": ("market", 100 / (6 * 5)), 31 | } 32 | 33 | 34 | def get_task_weight(task_name): 35 | for key, val in WEIGHT_DICT.items(): 36 | if key in task_name: 37 | return val 38 | logging.warning(f"Task name {task_name} not found in weight dict") 39 | return "etc", 0 40 | 41 | 42 | def get_summary_dict(progress, key): 43 | # progress = vals if key == "length" else [v[0] for v in vals] 44 | summ = {"count": len(progress), "mean": np.mean(progress), "median": np.median(progress)} 45 | 46 | if key == "length": 47 | progress = np.array(progress) / 1023 # full episode length 48 | 49 | summ["completed"] = np.mean([1 if v >= 1 else 0 for v in progress]) 50 | summ["over30pcnt"] = np.mean([1 if v >= 0.3 else 0 for v in progress]) 51 | return summ 52 | 53 | 54 | def summarize_single_eval(data, weighted_score=False): 55 | summary = {} 56 | 57 | # task-level info 58 | for key, vals in data.items(): 59 | if key.startswith("curriculum") or key == "length": 60 | summary[key] = get_summary_dict(vals, key) 61 | 62 | if weighted_score and key.startswith("curriculum"): 63 | category, weight = get_task_weight(key) 64 | summary[key]["category"] = category 65 | summary[key]["weight"] = weight 66 | summary[key]["weighted_score"] = summary[key]["mean"] * weight 67 | 68 | # meta info 69 | summary["avg_progress"] = np.mean( 70 | [v["mean"] for k, v in summary.items() if k.startswith("curriculum")] 71 | ) 72 | if weighted_score: 73 | summary["weighted_score"] = np.sum( 74 | [v["weighted_score"] for k, v in summary.items() if k.startswith("curriculum")] 75 | ) 76 | return summary 77 | 78 | 79 | def process_eval_files(policy_store_dir, eval_prefix): 80 | summ_policy = [] 81 | summ_task = [] 82 | 83 | for file in os.listdir(policy_store_dir): 84 | # NOTE: assumes the file naming convention is 'eval__.json' 85 | if not file.startswith(eval_prefix) or not file.endswith(".json"): 86 | continue 87 | 88 | mode = file.split("_")[1] 89 | random_seed = file.split("_")[2].replace(".json", "") 90 | 91 | with open(os.path.join(policy_store_dir, file), "r") as f: 92 | data = json.load(f) 93 | 94 | for pol_name, pol_data in data.items(): 95 | if len(pol_data) == 0: 96 | continue 97 | 98 | summary = summarize_single_eval(pol_data, weighted_score=True) 99 | summ_policy.append( 100 | { 101 | "policy_name": pol_name, 102 | "mode": mode, 103 | "seed": random_seed, 104 | "count": summary["length"]["count"], 105 | "length": summary["length"]["mean"], 106 | "task_progress": summary["avg_progress"], 107 | "weighted_score": summary["weighted_score"], 108 | } 109 | ) 110 | 111 | # also gather the results across random seeds for each task, then average 112 | for task_name, task_data in summary.items(): 113 | if not task_name.startswith("curriculum"): 114 | continue 115 | summ_task.append( 116 | { 117 | "category": task_data["category"], 118 | "task_name": task_name, 119 | "weight": task_data["weight"], 120 | "policy_name": pol_name, 121 | "mode": mode, 122 | "seed": random_seed, 123 | "count": task_data["count"], 124 | "task_progress": task_data["mean"], 125 | } 126 | ) 127 | 128 | summ_df = pl.DataFrame(summ_policy).sort(["policy_name", "mode", "seed"]) 129 | summ_grp = summ_df.group_by(["policy_name", "mode"]).agg( 130 | pl.col("task_progress").mean(), 131 | pl.col("weighted_score").mean(), 132 | ) 133 | summ_grp = summ_grp.sort("weighted_score", descending=True) 134 | summ_grp.write_csv( 135 | os.path.join(policy_store_dir, "score_summary.tsv"), separator="\t", float_precision=6 136 | ) 137 | print("\nPolicy score summary, sorted by weighted_score:") 138 | print(summ_grp) 139 | 140 | task_df = pl.DataFrame(summ_task).sort(["mode", "category", "task_name", "policy_name", "seed"]) 141 | task_grp = task_df.group_by(["mode", "category", "task_name", "policy_name"]).agg( 142 | pl.col("task_progress").mean() 143 | ) 144 | task_grp = task_grp.sort(["mode", "category", "task_name", "policy_name"]) 145 | task_grp.write_csv( 146 | os.path.join(policy_store_dir, "score_task_summary.tsv"), separator="\t", float_precision=6 147 | ) 148 | cate_grp = task_df.group_by(["mode", "category", "policy_name"]).agg( 149 | pl.col("task_progress").mean() 150 | ) 151 | cate_grp = cate_grp.sort(["mode", "category", "policy_name"]) 152 | cate_grp.write_csv( 153 | os.path.join(policy_store_dir, "score_category_summary.tsv"), 154 | separator="\t", 155 | float_precision=6, 156 | ) 157 | 158 | if len(summ_df["seed"].unique()) > 1: 159 | summ_df.write_csv( 160 | os.path.join(policy_store_dir, "score_by_seed.tsv"), separator="\t", float_precision=6 161 | ) 162 | task_df.write_csv( 163 | os.path.join(policy_store_dir, "score_by_task_seed.tsv"), 164 | separator="\t", 165 | float_precision=6, 166 | ) 167 | 168 | return summ_df, summ_grp, task_df, task_grp, cate_grp 169 | 170 | 171 | if __name__ == "__main__": 172 | parser = argparse.ArgumentParser(description="Process the evaluation result files") 173 | parser.add_argument("policy_store_dir", type=str, help="Path to the policy directory") 174 | parser.add_argument( 175 | "-p", "--prefix", type=str, default="eval_", help="Prefix of the evaluation result files" 176 | ) 177 | args = parser.parse_args() 178 | 179 | process_eval_files(args.policy_store_dir, args.prefix) 180 | -------------------------------------------------------------------------------- /analysis/proc_task_cond_result.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from collections import defaultdict 4 | 5 | import dill 6 | import numpy as np 7 | import polars as pl 8 | from tqdm import tqdm 9 | 10 | from nmmo.lib.event_code import EventCode 11 | from nmmo.systems.item import ALL_ITEM 12 | from nmmo.systems.skill import COMBAT_SKILL, HARVEST_SKILL 13 | 14 | CODE_TO_EVENT = {v: k for k, v in EventCode.__dict__.items() if not k.startswith("_")} 15 | 16 | ITEM_ID_TO_NAME = {item.ITEM_TYPE_ID: item.__name__ for item in ALL_ITEM} 17 | 18 | SKILL_ID_TO_NAME = {skill.SKILL_ID: skill.__name__ for skill in COMBAT_SKILL + HARVEST_SKILL} 19 | 20 | 21 | # event tuple key to string 22 | def event_key_to_str(event_key): 23 | if event_key[0] == EventCode.LEVEL_UP: 24 | return f"LEVEL_{SKILL_ID_TO_NAME[event_key[1]]}" 25 | 26 | elif event_key[0] == EventCode.SCORE_HIT: 27 | return f"ATTACK_NUM_{SKILL_ID_TO_NAME[event_key[1]]}" 28 | 29 | elif event_key[0] in [ 30 | EventCode.HARVEST_ITEM, 31 | EventCode.CONSUME_ITEM, 32 | EventCode.EQUIP_ITEM, 33 | EventCode.LIST_ITEM, 34 | EventCode.BUY_ITEM, 35 | ]: 36 | return f"{CODE_TO_EVENT[event_key[0]]}_{ITEM_ID_TO_NAME[event_key[1]]}" 37 | 38 | elif event_key[0] == EventCode.GO_FARTHEST: 39 | return "3_PROGRESS_TO_CENTER" 40 | 41 | elif event_key[0] == EventCode.AGENT_CULLED: 42 | return "2_AGENT_LIFESPAN" 43 | 44 | else: 45 | return CODE_TO_EVENT[event_key[0]] 46 | 47 | 48 | def extract_task_name(task_str): 49 | name = task_str.split("Task_eval_fn:(")[1].split(")_assignee:")[0] 50 | # then take out (agent_id,) 51 | return name.split("_(")[0] + "_" + name.split(")_")[1] 52 | 53 | 54 | def gather_agent_events_by_task(data_dir): 55 | data_by_task = defaultdict(list) 56 | file_list = [f for f in os.listdir(data_dir) if f.endswith(".metadata.pkl")] 57 | for file_name in tqdm(file_list): 58 | data = dill.load(open(f"{data_dir}/{file_name}", "rb")) 59 | final_tick = data["tick"] 60 | 61 | for agent_id, vals in data["event_stats"].items(): 62 | task_name = extract_task_name(data["task"][agent_id]) 63 | 64 | # Agent survived until the end 65 | if EventCode.AGENT_CULLED not in vals: 66 | vals[(EventCode.AGENT_CULLED,)] = final_tick 67 | data_by_task[task_name].append(vals) 68 | 69 | return data_by_task 70 | 71 | 72 | def get_event_stats(task_name, task_data): 73 | num_agents = len(task_data) 74 | assert num_agents > 0, "There should be at least one agent" 75 | 76 | cnt_attack = 0 77 | cnt_buy = 0 78 | cnt_consume = 0 79 | cnt_equip = 0 80 | cnt_harvest = 0 81 | cnt_list = 0 82 | 83 | results = {"0_NAME": task_name, "1_COUNT": num_agents} 84 | event_data = defaultdict(list) 85 | for data in task_data: 86 | for event, val in data.items(): 87 | event_data[event].append(val) 88 | 89 | for event, vals in event_data.items(): 90 | if event[0] == EventCode.LEVEL_UP: 91 | # Base skill level is 1 92 | vals += [1] * (num_agents - len(vals)) 93 | results[event_key_to_str(event)] = np.mean(vals) # AVG skill level 94 | elif event[0] == EventCode.AGENT_CULLED: 95 | life_span = np.mean(vals) 96 | results["2_AGENT_LIFESPAN_AVG"] = life_span 97 | results["2_AGENT_LIFESPAN_SD"] = np.std(vals) 98 | elif event[0] == EventCode.GO_FARTHEST: 99 | results["3_PROGRESS_TO_CENTER_AVG"] = np.mean(vals) 100 | results["3_PROGRESS_TO_CENTER_SD"] = np.std(vals) 101 | else: 102 | results[event_key_to_str(event)] = sum(vals) / num_agents 103 | 104 | if event[0] == EventCode.SCORE_HIT: 105 | cnt_attack += sum(vals) 106 | if event[0] == EventCode.BUY_ITEM: 107 | cnt_buy += sum(vals) 108 | if event[0] == EventCode.CONSUME_ITEM: 109 | cnt_consume += sum(vals) 110 | if event[0] == EventCode.EQUIP_ITEM: 111 | cnt_equip += sum(vals) 112 | if event[0] == EventCode.HARVEST_ITEM: 113 | cnt_harvest += sum(vals) 114 | if event[0] == EventCode.LIST_ITEM: 115 | cnt_list += sum(vals) 116 | 117 | results["4_NORM_ATTACK"] = cnt_attack / life_span 118 | results["4_NORM_BUY"] = cnt_buy / life_span 119 | results["4_NORM_CONSUME"] = cnt_consume / life_span 120 | results["4_NORM_EQUIP"] = cnt_equip / life_span 121 | results["4_NORM_HARVEST"] = cnt_harvest / life_span 122 | results["4_NORM_LIST"] = cnt_list / life_span 123 | 124 | return results 125 | 126 | 127 | if __name__ == "__main__": 128 | parser = argparse.ArgumentParser(description="Process replay data") 129 | parser.add_argument("policy_store_dir", type=str, help="Path to the policy directory") 130 | args = parser.parse_args() 131 | 132 | # Gather the event data by tasks, across multiple replays 133 | data_by_task = gather_agent_events_by_task(args.policy_store_dir) 134 | 135 | task_results = [ 136 | get_event_stats(task_name, task_data) for task_name, task_data in data_by_task.items() 137 | ] 138 | 139 | task_df = pl.DataFrame(task_results).fill_null(0).sort("0_NAME") 140 | task_df = task_df.select(sorted(task_df.columns)) 141 | task_df.write_csv("task_conditioning.tsv", separator="\t", float_precision=5) 142 | 143 | print("Result file saved as task_conditioning.tsv") 144 | print("Done.") 145 | -------------------------------------------------------------------------------- /analysis/run_task_conditioning.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import random 4 | 5 | 6 | from train import load_from_config, combine_config_args, update_args 7 | from train_helper import generate_replay 8 | 9 | from evaluate import make_env_creator, make_agent_creator 10 | 11 | CURRICULUM_FILE = "neurips23_evaluation/heldout_task_with_embedding.pkl" 12 | 13 | 14 | if __name__ == "__main__": 15 | logging.basicConfig(level=logging.INFO) 16 | parser = argparse.ArgumentParser(description="Parse environment argument", add_help=False) 17 | parser.add_argument("eval_model_path", type=str, default=None, help="Path to model to evaluate") 18 | parser.add_argument( 19 | "-c", "--curriculum", type=str, default=CURRICULUM_FILE, help="Path to curriculum file" 20 | ) 21 | parser.add_argument( 22 | "-t", 23 | "--task-to-assign", 24 | type=int, 25 | default=None, 26 | help="The index of the task to assign in the curriculum file", 27 | ) 28 | parser.add_argument( 29 | "-r", "--repeat", type=int, default=1, help="Number of times to repeat the evaluation" 30 | ) 31 | clean_parser = argparse.ArgumentParser(parents=[parser]) 32 | args = parser.parse_known_args()[0].__dict__ 33 | 34 | # required args when using train.py's helper functions 35 | args["no_track"] = True 36 | args["no_recurrence"] = False 37 | args["vectorization"] = "serial" 38 | args["debug"] = False 39 | 40 | # Generate argparse menu from config 41 | config = load_from_config("neurips23_start_kit") # dummy learner 42 | args = combine_config_args(parser, args, config) 43 | args = update_args(args, mode="replay") 44 | 45 | agent_creator = make_agent_creator() 46 | env_creator = make_env_creator(args["curriculum"], "pvp") 47 | 48 | # Generate replay 49 | for i in range(args.repeat): 50 | generate_replay( 51 | args, 52 | env_creator, 53 | agent_creator, 54 | stop_when_all_complete_task=False, 55 | seed=random.randint(10000000, 99999999), 56 | ) 57 | print(f"Generated replay for task {i+1}/{args.repeat}...") 58 | 59 | print("Done!") 60 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | wandb: 2 | project: nmmo-baselines 3 | entity: kywch 4 | group: ~ 5 | 6 | debug: 7 | train: 8 | num_envs: 1 9 | envs_per_batch: 1 # batching envs work? 10 | envs_per_worker: 1 11 | batch_size: 1024 12 | total_timesteps: 10000 13 | pool_kernel: [0, 1] 14 | checkpoint_interval: 3 15 | verbose: True 16 | 17 | train: 18 | seed: 1 19 | torch_deterministic: True 20 | device: cuda 21 | total_timesteps: 10_000_000 22 | learning_rate: 1.5e-4 23 | anneal_lr: True 24 | gamma: 0.99 25 | gae_lambda: 0.95 26 | update_epochs: 3 27 | norm_adv: True 28 | clip_coef: 0.1 29 | clip_vloss: True 30 | ent_coef: 0.01 31 | vf_coef: 0.5 32 | max_grad_norm: 0.5 33 | target_kl: ~ 34 | 35 | num_envs: 15 36 | envs_per_worker: 1 37 | envs_per_batch: 6 38 | env_pool: True 39 | verbose: True 40 | data_dir: runs 41 | checkpoint_interval: 200 42 | pool_kernel: [0] 43 | batch_size: 32768 44 | batch_rows: 128 45 | bptt_horizon: 8 46 | vf_clip_coef: 0.1 47 | compile: False 48 | compile_mode: reduce-overhead 49 | 50 | sweep: 51 | method: random 52 | name: sweep 53 | metric: 54 | goal: maximize 55 | name: episodic_return 56 | # Nested parameters name required by WandB API 57 | parameters: 58 | train: 59 | parameters: 60 | learning_rate: { 61 | 'distribution': 'log_uniform_values', 62 | 'min': 1e-4, 63 | 'max': 1e-1, 64 | } 65 | batch_size: { 66 | 'values': [128, 256, 512, 1024, 2048], 67 | } 68 | batch_rows: { 69 | 'values': [16, 32, 64, 128, 256], 70 | } 71 | bptt_horizon: { 72 | 'values': [4, 8, 16, 32], 73 | } 74 | 75 | env: 76 | num_agents: 128 77 | num_npcs: 256 78 | max_episode_length: 1024 79 | maps_path: 'maps/train/' 80 | map_size: 128 81 | num_maps: 256 82 | map_force_generation: False 83 | death_fog_tick: ~ 84 | task_size: 2048 85 | spawn_immunity: 20 86 | resilient_population: 0.2 87 | 88 | policy: 89 | input_size: 256 90 | hidden_size: 256 91 | task_size: 2048 # must match env task_size 92 | 93 | recurrent: 94 | input_size: 256 95 | hidden_size: 256 96 | num_layers: 1 97 | 98 | reward_wrapper: 99 | eval_mode: False 100 | early_stop_agent_num: 8 101 | use_custom_reward: True 102 | 103 | neurips23_start_kit: 104 | reward_wrapper: 105 | heal_bonus_weight: 0.03 106 | explore_bonus_weight: 0.01 107 | 108 | yaofeng: 109 | env: 110 | maps_path: 'maps/train_yaofeng/' 111 | num_maps: 1024 112 | resilient_population: 0 113 | train: 114 | update_epochs: 2 115 | learning_rate: 1.0e-4 116 | recurrent: 117 | num_layers: 2 118 | reward_wrapper: 119 | hp_bonus_weight: 0.03 120 | exp_bonus_weight: 0.002 121 | defense_bonus_weight: 0.04 122 | attack_bonus_weight: 0.0 123 | gold_bonus_weight: 0.001 124 | custom_bonus_scale: 0.1 125 | disable_give: True 126 | donot_attack_dangerous_npc: True 127 | 128 | takeru: 129 | env: 130 | maps_path: 'maps/train_takeru/' 131 | num_maps: 1280 132 | resilient_population: 0 133 | train: 134 | update_epochs: 1 135 | recurrent: 136 | num_layers: 0 137 | reward_wrapper: 138 | early_stop_agent_num: 0 139 | explore_bonus_weight: 0.01 140 | disable_give: True 141 | 142 | hybrid: 143 | env: 144 | maps_path: 'maps/train_yaofeng/' 145 | num_maps: 1024 146 | resilient_population: 0 147 | train: 148 | update_epochs: 1 149 | recurrent: 150 | num_layers: 1 151 | reward_wrapper: 152 | hp_bonus_weight: 0.03 153 | exp_bonus_weight: 0.002 154 | defense_bonus_weight: 0.04 155 | attack_bonus_weight: 0.0 156 | gold_bonus_weight: 0.001 157 | custom_bonus_scale: 0.1 158 | disable_give: True 159 | donot_attack_dangerous_npc: True 160 | -------------------------------------------------------------------------------- /curriculum_generation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralMMO/baselines/9a4ec9ccac6f7e91d01708ceff8386ecfd52abc6/curriculum_generation/__init__.py -------------------------------------------------------------------------------- /curriculum_generation/curriculum_tutorial.py: -------------------------------------------------------------------------------- 1 | """This file explains how you can manually create your curriculum. 2 | 1. Use pre-built evaluation functions and TaskSpec to define training tasks. 3 | 2. Define your own evaluation functions. 4 | 3. Check if your training task tasks are valid and pickable. Must satisfy both. 5 | 4. Generate the task embedding file using the task encoder. 6 | 5. Train agents using the task embedding file. 7 | 6. Extract the training task stats. 8 | """ 9 | 10 | # allow custom functions to use pre-built eval functions without prefix 11 | from nmmo.task.base_predicates import CountEvent, InventorySpaceGE, TickGE, norm 12 | from nmmo.task.task_spec import TaskSpec, check_task_spec 13 | 14 | 15 | ############################################################################## 16 | # Use pre-built eval functions and TaskSpec class to define each training task 17 | # See manual_curriculum.py for detailed examples based on pre-built eval fns 18 | 19 | curriculum = [] 20 | 21 | # Make training tasks for each of the following events 22 | # Agents have completed the task if they have done the event N times 23 | essential_events = [ 24 | "GO_FARTHEST", 25 | "EAT_FOOD", 26 | "DRINK_WATER", 27 | "SCORE_HIT", 28 | "HARVEST_ITEM", 29 | "LEVEL_UP", 30 | ] 31 | 32 | for event_code in essential_events: 33 | curriculum.append( 34 | TaskSpec( 35 | eval_fn=CountEvent, # is a pre-built eval function 36 | eval_fn_kwargs={"event": event_code, "N": 10}, # kwargs for CountEvent 37 | ) 38 | ) 39 | 40 | 41 | ############################################################################## 42 | # Create training tasks using custom evaluation functions 43 | 44 | 45 | def PracticeEating(gs, subject): 46 | """The progress, the max of which is 1, should 47 | * increase small for each eating 48 | * increase big for the 1st and 3rd eating 49 | * reach 1 with 10 eatings 50 | """ 51 | num_eat = len(subject.event.EAT_FOOD) 52 | progress = num_eat * 0.06 53 | if num_eat >= 1: 54 | progress += 0.1 55 | if num_eat >= 3: 56 | progress += 0.3 57 | return norm(progress) # norm is a helper function to normalize the value to [0, 1] 58 | 59 | 60 | curriculum.append(TaskSpec(eval_fn=PracticeEating, eval_fn_kwargs={})) 61 | 62 | 63 | # You can also use pre-built eval functions to define your own eval functions 64 | def PracticeInventoryManagement(gs, subject, space, num_tick): 65 | return norm(InventorySpaceGE(gs, subject, space) * TickGE(gs, subject, num_tick)) 66 | 67 | 68 | for space in [2, 4, 8]: 69 | curriculum.append( 70 | TaskSpec( 71 | eval_fn=PracticeInventoryManagement, 72 | eval_fn_kwargs={"space": space, "num_tick": 500}, 73 | ) 74 | ) 75 | 76 | 77 | if __name__ == "__main__": 78 | # Import the custom curriculum 79 | print("------------------------------------------------------------") 80 | import curriculum_tutorial # which is this file 81 | 82 | CURRICULUM = curriculum_tutorial.curriculum 83 | print("The number of training tasks in the curriculum:", len(CURRICULUM)) 84 | 85 | # Check if these task specs are valid in the nmmo environment 86 | # Invalid tasks will crash your agent training 87 | print("------------------------------------------------------------") 88 | print("Checking whether the task specs are valid ...") 89 | results = check_task_spec(CURRICULUM) 90 | num_error = 0 91 | for result in results: 92 | if result["runnable"] is False: 93 | print("ERROR: ", result["spec_name"]) 94 | num_error += 1 95 | assert num_error == 0, "Invalid task specs will crash training. Please fix them." 96 | print("All training tasks are valid.") 97 | 98 | # The task_spec must be picklable to be used for agent training 99 | print("------------------------------------------------------------") 100 | print("Checking if the training tasks are picklable ...") 101 | CURRICULUM_FILE_PATH = "custom_curriculum_with_embedding.pkl" 102 | with open(CURRICULUM_FILE_PATH, "wb") as f: 103 | import dill 104 | 105 | dill.dump(CURRICULUM, f) 106 | print("All training tasks are picklable.") 107 | 108 | # To use the curriculum for agent training, the curriculum, task_spec, should be 109 | # saved to a file with the embeddings using the task encoder. The task encoder uses 110 | # a coding LLM to encode the task_spec into a vector. 111 | print("------------------------------------------------------------") 112 | print("Generating the task spec with embedding file ...") 113 | from task_encoder import TaskEncoder 114 | 115 | LLM_CHECKPOINT = "deepseek-ai/deepseek-coder-1.3b-instruct" 116 | 117 | # Get the task embeddings for the training tasks and save to file 118 | # You need to provide the curriculum file as a module to the task encoder 119 | with TaskEncoder(LLM_CHECKPOINT, curriculum_tutorial) as task_encoder: 120 | task_encoder.get_task_embedding(CURRICULUM, save_to_file=CURRICULUM_FILE_PATH) 121 | print("Done.") 122 | 123 | # TODO: MAKE THE BELOW CODE WORK 124 | 125 | # # Initialize the trainer with the custom curriculum 126 | # # These lines are the same as the RL track. If these don't run, please see train.py 127 | # from reinforcement_learning import config 128 | # from train import setup_env 129 | # args = config.create_config(config.Config) 130 | # args.tasks_path = CURRICULUM_FILE_PATH # This is the curriculum file saved by the task encoder 131 | 132 | # # Remove below lines if you want to use the default training config 133 | # local_mode = True 134 | # if local_mode: 135 | # args.num_envs = 1 136 | # args.num_buffers = 1 137 | # args.use_serial_vecenv = True 138 | # args.rollout_batch_size = 2**12 139 | 140 | # print("------------------------------------------------------------") 141 | # print("Setting up the agent training env ...") 142 | # trainer = setup_env(args) 143 | 144 | # # Train agents using the curriculum file 145 | # # NOTE: this is basically the same as the reinforcement_learning_track function in the train.py 146 | # while not trainer.done_training(): 147 | # print("------------------------------------------------------------") 148 | # print("Evaluating the agents ...") 149 | # _, _, infos = trainer.evaluate() 150 | # # The training task stats are available in infos, which then can be use for training task selection 151 | # if len(infos) > 0: 152 | # print("------------------------------------------------------------") 153 | # print("Training task stats:") 154 | # curri_keys = [key for key in infos.keys() if key.startswith("curriculum/")] 155 | # for key in curri_keys: 156 | # completed = [] 157 | # max_progress = [] 158 | # reward_signal_count = [] 159 | # for sub_list in infos[key]: 160 | # for prog, rcnt in sub_list: 161 | # completed.append(int(prog>=1)) # progress >= 1 is considered task complete 162 | # max_progress.append(prog) 163 | # reward_signal_count.append(rcnt) 164 | # print(f"{key} -- task tried: {len(completed)}, completed: {sum(completed)}, " + 165 | # f"avg max progress: {sum(max_progress)/len(max_progress):.3f}, " + 166 | # f"avg reward signal count: {sum(reward_signal_count)/len(reward_signal_count):.3f}") 167 | 168 | # print("------------------------------------------------------------") 169 | # print("The tutorial is done.") 170 | # break 171 | 172 | # print("------------------------------------------------------------") 173 | # print("Training the agents ...") 174 | # trainer.train( 175 | # update_epochs=args.ppo_update_epochs, 176 | # bptt_horizon=args.bptt_horizon, 177 | # batch_rows=args.ppo_training_batch_size // args.bptt_horizon, 178 | # ) 179 | -------------------------------------------------------------------------------- /curriculum_generation/curriculum_with_embedding.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralMMO/baselines/9a4ec9ccac6f7e91d01708ceff8386ecfd52abc6/curriculum_generation/curriculum_with_embedding.pkl -------------------------------------------------------------------------------- /curriculum_generation/custom_curriculum_with_embedding.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralMMO/baselines/9a4ec9ccac6f7e91d01708ceff8386ecfd52abc6/curriculum_generation/custom_curriculum_with_embedding.pkl -------------------------------------------------------------------------------- /curriculum_generation/manual_curriculum.py: -------------------------------------------------------------------------------- 1 | """Manual test for creating learning curriculum manually""" 2 | 3 | from typing import List 4 | 5 | import nmmo.lib.material as m 6 | import nmmo.systems.item as Item 7 | import nmmo.systems.skill as Skill 8 | from nmmo.task.base_predicates import ( 9 | AttainSkill, 10 | BuyItem, 11 | CanSeeAgent, 12 | CanSeeGroup, 13 | CanSeeTile, 14 | ConsumeItem, 15 | CountEvent, 16 | EarnGold, 17 | EquipItem, 18 | GainExperience, 19 | HarvestItem, 20 | HoardGold, 21 | InventorySpaceGE, 22 | ListItem, 23 | MakeProfit, 24 | OccupyTile, 25 | OwnItem, 26 | ScoreHit, 27 | SpendGold, 28 | TickGE, 29 | ) 30 | from nmmo.task.task_spec import TaskSpec, check_task_spec 31 | 32 | EVENT_NUMBER_GOAL = [1, 2, 3, 5, 7, 9, 12, 15, 20, 30, 50] 33 | INFREQUENT_GOAL = list(range(1, 10)) 34 | STAY_ALIVE_GOAL = EXP_GOAL = [50, 100, 150, 200, 300, 500, 700] 35 | LEVEL_GOAL = list(range(2, 10)) # TODO: get config 36 | AGENT_NUM_GOAL = ITEM_NUM_GOAL = [1, 2, 3, 4, 5] # competition team size: 8 37 | SKILLS = Skill.COMBAT_SKILL + Skill.HARVEST_SKILL 38 | COMBAT_STYLE = Skill.COMBAT_SKILL 39 | ALL_ITEM = Item.ALL_ITEM 40 | EQUIP_ITEM = Item.ARMOR + Item.WEAPON + Item.TOOL + Item.AMMUNITION 41 | HARVEST_ITEM = Item.WEAPON + Item.AMMUNITION + Item.CONSUMABLE 42 | TOOL_FOR_SKILL = { 43 | Skill.Melee: Item.Spear, 44 | Skill.Range: Item.Bow, 45 | Skill.Mage: Item.Wand, 46 | Skill.Fishing: Item.Rod, 47 | Skill.Herbalism: Item.Gloves, 48 | Skill.Carving: Item.Axe, 49 | Skill.Prospecting: Item.Pickaxe, 50 | Skill.Alchemy: Item.Chisel, 51 | } 52 | 53 | curriculum: List[TaskSpec] = [] 54 | 55 | # The default task: stay alive for the whole episode 56 | curriculum.append(TaskSpec(eval_fn=TickGE, eval_fn_kwargs={"num_tick": 1024})) 57 | 58 | # explore, eat, drink, attack any agent, harvest any item, level up any skill 59 | # which can happen frequently 60 | most_essentials = [ 61 | "EAT_FOOD", 62 | "DRINK_WATER", 63 | ] 64 | for event_code in most_essentials: 65 | for cnt in range(1, 10): 66 | curriculum.append( 67 | TaskSpec( 68 | eval_fn=CountEvent, 69 | eval_fn_kwargs={"event": event_code, "N": cnt}, 70 | sampling_weight=100, 71 | ) 72 | ) 73 | 74 | essential_skills = [ 75 | "SCORE_HIT", 76 | "PLAYER_KILL", 77 | "HARVEST_ITEM", 78 | "EQUIP_ITEM", 79 | "CONSUME_ITEM", 80 | "LEVEL_UP", 81 | "EARN_GOLD", 82 | "LIST_ITEM", 83 | "BUY_ITEM", 84 | ] 85 | for event_code in essential_skills: 86 | for cnt in EVENT_NUMBER_GOAL: 87 | curriculum.append( 88 | TaskSpec( 89 | eval_fn=CountEvent, 90 | eval_fn_kwargs={"event": event_code, "N": cnt}, 91 | sampling_weight=20, 92 | ) 93 | ) 94 | 95 | # item/market skills, which happen less frequently or should not do too much 96 | item_skills = [ 97 | "GIVE_ITEM", 98 | "DESTROY_ITEM", 99 | "GIVE_GOLD", 100 | ] 101 | for event_code in item_skills: 102 | curriculum += [ 103 | TaskSpec(eval_fn=CountEvent, eval_fn_kwargs={"event": event_code, "N": cnt}) 104 | for cnt in INFREQUENT_GOAL 105 | ] # less than 10 106 | 107 | # find resource tiles 108 | for resource in m.Harvestable: 109 | curriculum.append( 110 | TaskSpec( 111 | eval_fn=CanSeeTile, 112 | eval_fn_kwargs={"tile_type": resource}, 113 | sampling_weight=10, 114 | ) 115 | ) 116 | 117 | 118 | # practice particular skill with a tool 119 | def PracticeSkillWithTool(gs, subject, skill, exp): 120 | return 0.3 * EquipItem( 121 | gs, subject, item=TOOL_FOR_SKILL[skill], level=1, num_agent=1 122 | ) + 0.7 * GainExperience(gs, subject, skill, exp, num_agent=1) 123 | 124 | 125 | for skill in SKILLS: 126 | # level up a skill 127 | for level in LEVEL_GOAL[1:]: 128 | # since this is an agent task, num_agent must be 1 129 | curriculum.append( 130 | TaskSpec( 131 | eval_fn=AttainSkill, 132 | eval_fn_kwargs={"skill": skill, "level": level, "num_agent": 1}, 133 | sampling_weight=10 * (6 - level) if level < 6 else 5, 134 | ) 135 | ) 136 | 137 | # gain experience on particular skill 138 | for exp in EXP_GOAL: 139 | curriculum.append( 140 | TaskSpec( 141 | eval_fn=PracticeSkillWithTool, 142 | eval_fn_kwargs={"skill": skill, "exp": exp}, 143 | sampling_weight=50, 144 | ) 145 | ) 146 | 147 | # stay alive ... like ... for 300 ticks 148 | # i.e., getting incremental reward for each tick alive as an individual or a team 149 | for num_tick in STAY_ALIVE_GOAL: 150 | curriculum.append(TaskSpec(eval_fn=TickGE, eval_fn_kwargs={"num_tick": num_tick})) 151 | 152 | # occupy the center tile, assuming the Medium map size 153 | # TODO: it'd be better to have some intermediate targets toward the center 154 | curriculum.append(TaskSpec(eval_fn=OccupyTile, eval_fn_kwargs={"row": 80, "col": 80})) 155 | 156 | # find the other team leader 157 | for target in ["left_team_leader", "right_team_leader"]: 158 | curriculum.append(TaskSpec(eval_fn=CanSeeAgent, eval_fn_kwargs={"target": target})) 159 | 160 | # find the other team (any agent) 161 | for target in ["left_team", "right_team"]: 162 | curriculum.append(TaskSpec(eval_fn=CanSeeGroup, eval_fn_kwargs={"target": target})) 163 | 164 | # practice specific combat style 165 | for style in COMBAT_STYLE: 166 | for cnt in EVENT_NUMBER_GOAL: 167 | curriculum.append( 168 | TaskSpec( 169 | eval_fn=ScoreHit, 170 | eval_fn_kwargs={"combat_style": style, "N": cnt}, 171 | sampling_weight=5, 172 | ) 173 | ) 174 | 175 | # hoarding gold -- evaluated on the current gold 176 | for amount in EVENT_NUMBER_GOAL: 177 | curriculum.append( 178 | TaskSpec(eval_fn=HoardGold, eval_fn_kwargs={"amount": amount}, sampling_weight=10) 179 | ) 180 | 181 | # earning gold -- evaluated on the total gold earned by selling items 182 | for amount in EVENT_NUMBER_GOAL: 183 | curriculum.append( 184 | TaskSpec(eval_fn=EarnGold, eval_fn_kwargs={"amount": amount}, sampling_weight=10) 185 | ) 186 | 187 | # spending gold, by buying items 188 | for amount in EVENT_NUMBER_GOAL: 189 | curriculum.append( 190 | TaskSpec(eval_fn=SpendGold, eval_fn_kwargs={"amount": amount}, sampling_weight=5) 191 | ) 192 | 193 | # making profits by trading -- only buying and selling are counted 194 | for amount in EVENT_NUMBER_GOAL: 195 | curriculum.append( 196 | TaskSpec(eval_fn=MakeProfit, eval_fn_kwargs={"amount": amount}, sampling_weight=3) 197 | ) 198 | 199 | 200 | # managing inventory space 201 | def PracticeInventoryManagement(gs, subject, space, num_tick): 202 | return InventorySpaceGE(gs, subject, space) * TickGE(gs, subject, num_tick) 203 | 204 | 205 | for space in [2, 4, 8]: 206 | curriculum += [ 207 | TaskSpec( 208 | eval_fn=PracticeInventoryManagement, 209 | eval_fn_kwargs={"space": space, "num_tick": num_tick}, 210 | ) 211 | for num_tick in STAY_ALIVE_GOAL 212 | ] 213 | 214 | # own item, evaluated on the current inventory 215 | for item in ALL_ITEM: 216 | for level in LEVEL_GOAL: 217 | # agent task 218 | for quantity in ITEM_NUM_GOAL: 219 | if level + quantity <= 6 or quantity == 1: # heuristic prune 220 | curriculum.append( 221 | TaskSpec( 222 | eval_fn=OwnItem, 223 | eval_fn_kwargs={ 224 | "item": item, 225 | "level": level, 226 | "quantity": quantity, 227 | }, 228 | sampling_weight=4 - level if level < 4 else 1, 229 | ) 230 | ) 231 | 232 | # equip item, evaluated on the current inventory and equipment status 233 | for item in EQUIP_ITEM: 234 | for level in LEVEL_GOAL: 235 | # agent task 236 | curriculum.append( 237 | TaskSpec( 238 | eval_fn=EquipItem, 239 | eval_fn_kwargs={"item": item, "level": level, "num_agent": 1}, 240 | sampling_weight=4 - level if level < 4 else 1, 241 | ) 242 | ) 243 | 244 | # consume items (ration, potion), evaluated based on the event log 245 | for item in Item.CONSUMABLE: 246 | for level in LEVEL_GOAL: 247 | # agent task 248 | for quantity in ITEM_NUM_GOAL: 249 | if level + quantity <= 6 or quantity == 1: # heuristic prune 250 | curriculum.append( 251 | TaskSpec( 252 | eval_fn=ConsumeItem, 253 | eval_fn_kwargs={ 254 | "item": item, 255 | "level": level, 256 | "quantity": quantity, 257 | }, 258 | sampling_weight=4 - level if level < 4 else 1, 259 | ) 260 | ) 261 | 262 | # harvest items, evaluated based on the event log 263 | for item in HARVEST_ITEM: 264 | for level in LEVEL_GOAL: 265 | # agent task 266 | for quantity in ITEM_NUM_GOAL: 267 | if level + quantity <= 6 or quantity == 1: # heuristic prune 268 | curriculum.append( 269 | TaskSpec( 270 | eval_fn=HarvestItem, 271 | eval_fn_kwargs={ 272 | "item": item, 273 | "level": level, 274 | "quantity": quantity, 275 | }, 276 | sampling_weight=4 - level if level < 4 else 1, 277 | ) 278 | ) 279 | 280 | # list items, evaluated based on the event log 281 | for item in ALL_ITEM: 282 | for level in LEVEL_GOAL: 283 | # agent task 284 | for quantity in ITEM_NUM_GOAL: 285 | if level + quantity <= 6 or quantity == 1: # heuristic prune 286 | curriculum.append( 287 | TaskSpec( 288 | eval_fn=ListItem, 289 | eval_fn_kwargs={ 290 | "item": item, 291 | "level": level, 292 | "quantity": quantity, 293 | }, 294 | sampling_weight=4 - level if level < 4 else 1, 295 | ) 296 | ) 297 | 298 | # buy items, evaluated based on the event log 299 | for item in ALL_ITEM: 300 | for level in LEVEL_GOAL: 301 | # agent task 302 | for quantity in ITEM_NUM_GOAL: 303 | if level + quantity <= 6 or quantity == 1: # heuristic prune 304 | curriculum.append( 305 | TaskSpec( 306 | eval_fn=BuyItem, 307 | eval_fn_kwargs={ 308 | "item": item, 309 | "level": level, 310 | "quantity": quantity, 311 | }, 312 | sampling_weight=4 - level if level < 4 else 1, 313 | ) 314 | ) 315 | 316 | if __name__ == "__main__": 317 | import multiprocessing as mp 318 | from contextlib import contextmanager 319 | 320 | import dill 321 | import numpy as np 322 | import psutil 323 | 324 | @contextmanager 325 | def create_pool(num_proc): 326 | pool = mp.Pool(processes=num_proc) 327 | yield pool 328 | pool.close() 329 | pool.join() 330 | 331 | # 1609 task specs: divide the specs into chunks 332 | num_workers = round(psutil.cpu_count(logical=False) * 0.7) 333 | spec_chunks = np.array_split(curriculum, num_workers) 334 | with create_pool(num_workers) as pool: 335 | pool.map(check_task_spec, spec_chunks) 336 | 337 | # test if the task spec is pickalable 338 | with open("pickle_test.pkl", "wb") as f: 339 | dill.dump(curriculum, f) 340 | -------------------------------------------------------------------------------- /curriculum_generation/task_encoder.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import inspect 3 | import os 4 | import gc 5 | from types import ModuleType 6 | from typing import List 7 | 8 | import dill 9 | import torch 10 | import numpy as np 11 | from nmmo.task import task_spec as ts 12 | from tqdm import tqdm 13 | from transformers import AutoTokenizer, AutoModelForCausalLM 14 | 15 | 16 | def extract_module_fn(module: ModuleType): 17 | fn_dict = {} 18 | for name, fn in module.__dict__.items(): 19 | if inspect.isfunction(fn) and not inspect.isbuiltin(fn) and not name.startswith("_"): 20 | fn_dict[name] = fn 21 | return fn_dict 22 | 23 | 24 | class TaskEncoder: 25 | """A class for encoding tasks into embeddings using a pretrained model.""" 26 | 27 | def __init__( 28 | self, 29 | checkpoint: str, 30 | context: ModuleType, 31 | batch_size=2, 32 | tmp_file_path="tmp_task_encoder.pkl", 33 | ): 34 | """ 35 | Initialize the TaskEncoder. 36 | 37 | Args: 38 | checkpoint: Path to the pretrained model. 39 | context: Python module context in which tasks are defined. 40 | batch_size: Size of each batch during embedding computation. 41 | tmp_file_path: Temporary file path for saving intermediate data. 42 | """ 43 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 44 | self.tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True) 45 | self.tokenizer.pad_token = self.tokenizer.eos_token 46 | if self.device == "cuda": 47 | self.model = AutoModelForCausalLM.from_pretrained( 48 | checkpoint, trust_remote_code=True, device_map="auto", torch_dtype=torch.bfloat16 49 | ) 50 | else: 51 | self.model = AutoModelForCausalLM.from_pretrained( 52 | checkpoint, trust_remote_code=True 53 | ).to(self.device) 54 | self.model.eval() 55 | self.batch_size = batch_size 56 | self.temp_file_path = tmp_file_path 57 | self._fn_dict = extract_module_fn(context) 58 | 59 | blank_embedding = self._get_embedding(["# just to get the embedding size"]) 60 | self.embed_dim = len(blank_embedding[0]) 61 | 62 | def update_context(self, context: ModuleType): 63 | """Update the module context, extracting function dictionary.""" 64 | self._fn_dict = extract_module_fn(context) 65 | 66 | def _get_embedding(self, prompts: List[str]) -> list: 67 | """ 68 | Compute the embeddings of tasks. 69 | 70 | Args: 71 | prompts: List of tasks defined as prompts. 72 | 73 | Returns: 74 | A list of embeddings corresponding to input tasks. 75 | """ 76 | all_embeddings = [] 77 | with torch.no_grad(): 78 | for i in tqdm(range(0, len(prompts), self.batch_size)): 79 | batch = prompts[i : i + self.batch_size] 80 | tokens = self.tokenizer( 81 | batch, return_tensors="pt", padding=True, truncation=True 82 | ).to(self.device) 83 | outputs = self.model(**tokens, output_hidden_states=True) 84 | embeddings = ( 85 | outputs.hidden_states[-1].mean(dim=1).detach().cpu().to(torch.float32).numpy() 86 | ) 87 | all_embeddings.extend(embeddings.astype(np.float16)) 88 | return all_embeddings 89 | 90 | def _get_task_deps_src(self, eval_fn) -> tuple: 91 | """ 92 | Extract source code and dependent functions of the evaluation function. 93 | 94 | Args: 95 | eval_fn: Function for task evaluation. 96 | 97 | Returns: 98 | A tuple with source code and dependencies of eval_fn. 99 | """ 100 | eval_src = inspect.getsource(eval_fn) 101 | deps_fns = [ 102 | node.func.id 103 | for node in ast.walk(ast.parse(eval_src)) 104 | if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) 105 | ] 106 | deps_src = "\n".join( 107 | [ 108 | inspect.getsource(self._fn_dict[fn_name]) 109 | for fn_name in deps_fns 110 | if fn_name in self._fn_dict 111 | ] 112 | ) 113 | return eval_src, deps_src 114 | 115 | def _construct_prompt(self, reward_to, eval_fn, eval_fn_kwargs) -> str: 116 | """ 117 | Construct a task-specific prompt. 118 | 119 | Args: 120 | reward_to: Reward given to the agent upon successful completion of the task. 121 | eval_fn: Function for task evaluation. 122 | eval_fn_kwargs: Keyword arguments for eval_fn. 123 | 124 | Returns: 125 | A string representing the task prompt. 126 | """ 127 | eval_src, deps_src = self._get_task_deps_src(eval_fn) 128 | task_specific_prompt = f"""Your goal is to explain what an agent must accomplish in the neural MMO. 129 | Neural MMO is a research platform simulating populations of agents in virtual worlds. 130 | The reward from this function goes to {reward_to}. 131 | The function name is {eval_fn.__name__}. These are the arguments that the function takes {eval_fn_kwargs}. 132 | The function source code is \n####\n{eval_src}#### . 133 | This function calls these other functions \n####\n{deps_src}#### . 134 | The agent's goal is""" 135 | return task_specific_prompt 136 | 137 | def get_task_embedding(self, task_spec_list: List[ts.TaskSpec], save_to_file: str = None): 138 | """ 139 | Compute embeddings for given task specifications and save them to file. 140 | 141 | Args: 142 | task_spec_list: List of task specifications. 143 | save_to_file: Name of the file where the results should be saved. 144 | 145 | Returns: 146 | Updated task specifications with embeddings. 147 | """ 148 | assert self.model is not None, "Model has been unloaded. Re-initialize the TaskEncoder." 149 | prompts = [ 150 | self._construct_prompt( 151 | single_spec.reward_to, single_spec.eval_fn, single_spec.eval_fn_kwargs 152 | ) 153 | for single_spec in task_spec_list 154 | ] 155 | embeddings = self._get_embedding(prompts) 156 | 157 | for single_spec, embedding in zip(task_spec_list, embeddings): 158 | single_spec.embedding = embedding 159 | 160 | if save_to_file: # use save_to_file as the file name 161 | with open(self.temp_file_path, "wb") as f: 162 | dill.dump(task_spec_list, f) 163 | os.replace(self.temp_file_path, save_to_file) 164 | 165 | return task_spec_list 166 | 167 | def close(self): 168 | # free up gpu memory 169 | self.model = None 170 | self.tokenizer = None 171 | gc.collect() 172 | torch.cuda.empty_cache() 173 | 174 | def __enter__(self): 175 | return self 176 | 177 | def __exit__(self, exc_type, exc_value, traceback): 178 | self.close() 179 | 180 | 181 | if __name__ == "__main__": 182 | import curriculum_generation.manual_curriculum as curriculum 183 | 184 | LLM_CHECKPOINT = "deepseek-ai/deepseek-coder-1.3b-instruct" 185 | CURRICULUM_FILE_PATH = "curriculum_generation/curriculum_with_embedding.pkl" 186 | 187 | with TaskEncoder(LLM_CHECKPOINT, curriculum, batch_size=6) as task_encoder: 188 | task_encoder.get_task_embedding(curriculum.curriculum, save_to_file=CURRICULUM_FILE_PATH) 189 | -------------------------------------------------------------------------------- /curriculum_generation/task_sampler.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from collections import defaultdict 3 | import numpy as np 4 | 5 | from nmmo.task import task_spec as ts 6 | 7 | 8 | class LearnableTaskSampler: 9 | def __init__(self, task_spec: List[ts.TaskSpec], average_window=50): 10 | self.task_spec = task_spec 11 | self.name_to_spec = {single_spec.name: single_spec for single_spec in self.task_spec} 12 | self.task_stats = {} 13 | self.average_window = average_window 14 | 15 | def reset(self): 16 | self.task_stats = {} 17 | 18 | def add_tasks(self, task_spec: List[ts.TaskSpec]): 19 | # so that the stats for the new tasks can be tracked 20 | for new_spec in task_spec: 21 | if new_spec.name not in self.name_to_spec: 22 | self.task_spec.append(new_spec) 23 | self.name_to_spec[new_spec.name] = new_spec 24 | 25 | def update(self, infos, prefix="curriculum/"): 26 | for key, val in infos.items(): 27 | # Process the new infos 28 | if key.startswith(prefix): 29 | spec_name = key.replace(prefix, "") 30 | completed, _, rcnt_over_2 = [], [], [] 31 | for sublist in val: 32 | for prog, rcnt in sublist: 33 | completed.append(float(prog >= 1)) 34 | rcnt_over_2.append(float(rcnt >= 2)) # rewarded >= 2 times 35 | 36 | # Add to the task_stats 37 | if spec_name not in self.task_stats: 38 | self.task_stats[spec_name] = defaultdict(list) 39 | self.task_stats[spec_name]["completed"] += completed 40 | self.task_stats[spec_name]["rcnt_over_2"] += rcnt_over_2 41 | 42 | # Keep only the recent values -- self.average_window (50) 43 | for key, vals in self.task_stats[spec_name].items(): 44 | self.task_stats[spec_name][key] = vals[-self.average_window :] 45 | 46 | def get_learnable_tasks( 47 | self, 48 | num_tasks, 49 | max_completed=0.8, # filter out easy tasks 50 | min_completed=0.05, # filter out harder tasks 51 | min_rcnt_rate=0.1, # reward signal generating 52 | ) -> List[ts.TaskSpec]: 53 | learnable = [] 54 | for spec_name, stat in self.task_stats.items(): 55 | completion_rate = np.mean(stat["completed"]) 56 | rcnt_over2_rate = np.mean(stat["rcnt_over_2"]) 57 | if completion_rate < max_completed and ( 58 | completion_rate >= min_completed or rcnt_over2_rate >= min_rcnt_rate 59 | ): 60 | learnable.append(self.name_to_spec[spec_name]) 61 | 62 | if len(learnable) > num_tasks: 63 | return list(np.random.choice(learnable, num_tasks)) 64 | return learnable 65 | 66 | def sample_tasks( 67 | self, 68 | num_tasks, 69 | random_ratio=0.5, 70 | reset_sampling_weight=True, 71 | ) -> List[ts.TaskSpec]: 72 | task_spec = [] 73 | if 0 <= random_ratio < 1: 74 | num_learnable = round(num_tasks * (1 - random_ratio)) 75 | task_spec = self.get_learnable_tasks(num_learnable) 76 | 77 | # fill in with the randomly-sampled tasks 78 | # TODO: sample more "less-sampled" tasks (i.e., no or little stats) 79 | num_sample = num_tasks - len(task_spec) 80 | task_spec += list(np.random.choice(self.task_spec, num_sample)) 81 | 82 | if reset_sampling_weight: 83 | for single_spec in task_spec: 84 | single_spec.sampling_weight = 1 85 | 86 | return task_spec 87 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import logging 5 | import argparse 6 | from collections import defaultdict 7 | 8 | import nmmo 9 | import nmmo.core.config as nc 10 | 11 | import pufferlib 12 | import pufferlib.policy_pool as pp 13 | 14 | from reinforcement_learning import clean_pufferl 15 | import agent_zoo.neurips23_start_kit as default_learner 16 | 17 | from train import get_init_args 18 | 19 | NUM_AGENTS = 128 20 | EVAL_TASK_FILE = "neurips23_evaluation/heldout_task_with_embedding.pkl" 21 | NUM_PVE_EVAL_EPISODE = 32 22 | NUM_PVP_EVAL_EPISODE = 200 # TODO: cannot do more due to memory leak 23 | 24 | 25 | def get_eval_config(debug=False): 26 | return { 27 | "device": "cuda", 28 | "num_envs": 6 if not debug else 1, 29 | "batch_size": 2**15 if not debug else 2**12, 30 | } 31 | 32 | 33 | class EvalConfig( 34 | nc.Medium, 35 | nc.Terrain, 36 | nc.Resource, 37 | nc.Combat, 38 | nc.NPC, 39 | nc.Progression, 40 | nc.Item, 41 | nc.Equipment, 42 | nc.Profession, 43 | nc.Exchange, 44 | ): 45 | """NMMO config for NeurIPS 2023 competition evaluation. 46 | Hardcoded to keep the eval config independent from the training config. 47 | """ 48 | 49 | def __init__(self, task_file, mode): 50 | super().__init__() 51 | self.set("GAME_PACKS", [(nmmo.core.game_api.AgentTraining, 1)]) 52 | self.set("CURRICULUM_FILE_PATH", task_file) 53 | self.set("TASK_EMBED_DIM", 2048) # must match the task file 54 | 55 | # Eval constants 56 | self.set("PROVIDE_ACTION_TARGETS", True) 57 | self.set("PROVIDE_NOOP_ACTION_TARGET", True) 58 | self.set("PLAYER_N", NUM_AGENTS) 59 | self.set("HORIZON", 1024) 60 | self.set("PLAYER_DEATH_FOG", None) 61 | self.set("NPC_N", 256) 62 | self.set("RESOURCE_RESILIENT_POPULATION", 0) 63 | self.set("COMBAT_SPAWN_IMMUNITY", 20) 64 | 65 | # Map related 66 | self.set("TERRAIN_FLIP_SEED", True) 67 | self.set("MAP_CENTER", 128) 68 | self.set("MAP_FORCE_GENERATION", False) 69 | self.set("MAP_GENERATE_PREVIEWS", True) 70 | if mode not in ["pve", "pvp"]: 71 | raise ValueError(f"Invalid eval_mode: {mode}") 72 | if mode == "pve": 73 | self.set("MAP_N", 4) 74 | self.set("PATH_MAPS", "maps/pve_eval/") 75 | else: 76 | self.set("MAP_N", 256) 77 | self.set("PATH_MAPS", "maps/pvp_eval/") 78 | 79 | 80 | def make_env_creator(task_file, mode): 81 | def env_creator(*args, **kwargs): # dummy args 82 | env = nmmo.Env(EvalConfig(task_file, mode)) 83 | # Reward wrapper is for the learner, which is not used in evaluation 84 | env = default_learner.RewardWrapper( 85 | env, 86 | **{ 87 | "eval_mode": True, 88 | "early_stop_agent_num": 0, 89 | }, 90 | ) 91 | env = pufferlib.emulation.PettingZooPufferEnv(env) 92 | return env 93 | 94 | return env_creator 95 | 96 | 97 | def make_agent_creator(): 98 | # NOTE: Assuming all policies are recurrent, which may not be true 99 | policy_args = get_init_args(default_learner.Policy.__init__) 100 | recurrent_args = get_init_args(default_learner.Recurrent.__init__) 101 | 102 | def agent_creator(env, args=None): 103 | policy = default_learner.Policy(env, **policy_args) 104 | policy = default_learner.Recurrent(env, policy, **recurrent_args) 105 | policy = pufferlib.frameworks.cleanrl.RecurrentPolicy(policy) 106 | return policy.to(get_eval_config()["device"]) 107 | 108 | return agent_creator 109 | 110 | 111 | class EvalRunner: 112 | def __init__(self, policy_store_dir, debug=False): 113 | self.policy_store_dir = policy_store_dir 114 | self._debug = debug 115 | 116 | def set_debug(self, debug): 117 | self._debug = debug 118 | 119 | def setup_evaluator(self, mode, task_file, seed): 120 | policies = pp.get_policy_names(self.policy_store_dir) 121 | assert len(policies) > 0, "No policies found in eval_model_path" 122 | if mode == "pve": 123 | assert len(policies) == 1, "PvE mode requires only one policy" 124 | logging.info(f"Policies to evaluate: {policies}") 125 | 126 | # pool_kernel determines policy-agent mapping 127 | pool_kernel = pp.create_kernel(NUM_AGENTS, len(policies), shuffle_with_seed=seed) 128 | 129 | config = self.get_pufferl_config(self._debug) 130 | config.seed = seed 131 | config.data_dir = self.policy_store_dir 132 | config.pool_kernel = pool_kernel 133 | 134 | vectorization = ( 135 | pufferlib.vectorization.Serial 136 | if self._debug 137 | else pufferlib.vectorization.Multiprocessing 138 | ) 139 | 140 | return clean_pufferl.create( 141 | config=config, 142 | agent_creator=make_agent_creator(), 143 | env_creator=make_env_creator(task_file, mode), 144 | vectorization=vectorization, 145 | eval_mode=True, 146 | eval_model_path=self.policy_store_dir, 147 | policy_selector=pp.AllPolicySelector(seed), 148 | ) 149 | 150 | @staticmethod 151 | def get_pufferl_config(debug=False): 152 | config = get_eval_config(debug) 153 | # add required configs 154 | config["torch_deterministic"] = True 155 | config["total_timesteps"] = 100_000_000 # arbitrarily large, but will end much earlier 156 | config["envs_per_batch"] = config["num_envs"] 157 | config["envs_per_worker"] = 1 158 | config["env_pool"] = False 159 | config["learning_rate"] = 1e-4 160 | config["compile"] = False 161 | config["verbose"] = True # not debug 162 | return pufferlib.namespace(**config) 163 | 164 | def perform_eval(self, mode, task_file, seed, num_eval_episode, save_file_prefix): 165 | pufferl_data = self.setup_evaluator(mode, task_file, seed) 166 | # this is a hack 167 | pufferl_data.policy_pool.mask[:] = 1 # policy_pool.mask is valid for training only 168 | 169 | eval_results = {} 170 | cnt_episode = 0 171 | while cnt_episode < num_eval_episode: 172 | _, infos = clean_pufferl.evaluate(pufferl_data) 173 | 174 | for pol, vals in infos.items(): 175 | cnt_episode += sum(infos[pol]["episode_done"]) 176 | if pol not in eval_results: 177 | eval_results[pol] = defaultdict(list) 178 | for k, v in vals.items(): 179 | if k == "length": 180 | eval_results[pol][k] += v # length is a plain list 181 | if k.startswith("curriculum"): 182 | eval_results[pol][k] += [vv[0] for vv in v] 183 | 184 | pufferl_data.sort_keys = [] # TODO: check if this solves memory leak 185 | 186 | print(f"\nSeed: {seed}, evaluated {cnt_episode} episodes.\n") 187 | 188 | file_name = f"{save_file_prefix}_{seed}.json" 189 | self._save_results(eval_results, file_name) 190 | clean_pufferl.close(pufferl_data) 191 | return eval_results, file_name 192 | 193 | def _save_results(self, results, file_name): 194 | with open(os.path.join(self.policy_store_dir, file_name), "w") as f: 195 | json.dump(results, f) 196 | 197 | def run( 198 | self, mode, task_file=EVAL_TASK_FILE, seed=None, num_episode=None, save_file_prefix=None 199 | ): 200 | assert mode in ["pve", "pvp"], f"Invalid mode: {mode}" 201 | if mode == "pve": 202 | num_episode = num_episode or NUM_PVE_EVAL_EPISODE 203 | save_file_prefix = save_file_prefix or "eval_pve" 204 | else: 205 | num_episode = num_episode or NUM_PVP_EVAL_EPISODE 206 | save_file_prefix = save_file_prefix or "eval_pvp" 207 | 208 | if self._debug: 209 | num_episode = 4 210 | 211 | if seed is None: 212 | seed = random.randint(10000000, 99999999) 213 | 214 | logging.info(f"Evaluating {self.policy_store_dir} in the {mode} mode with seed: {seed}") 215 | logging.info(f"Using the task file: {task_file}") 216 | 217 | _, file_name = self.perform_eval(mode, task_file, seed, num_episode, save_file_prefix) 218 | 219 | print(f"Saved the result file to: {file_name}.") 220 | 221 | 222 | if __name__ == "__main__": 223 | logging.basicConfig(level=logging.INFO) 224 | parser = argparse.ArgumentParser(description="Evaluate a policy store") 225 | parser.add_argument("policy_store_dir", type=str, help="Path to the policy directory") 226 | parser.add_argument("mode", type=str, choices=["pve", "pvp"], help="Evaluation mode") 227 | parser.add_argument( 228 | "-t", "--task-file", type=str, default=EVAL_TASK_FILE, help="Path to the task file" 229 | ) 230 | parser.add_argument("-s", "--seed", type=int, default=1, help="Random seed") 231 | parser.add_argument( 232 | "-n", "--num-episode", type=int, default=None, help="Number of episodes to evaluate" 233 | ) 234 | parser.add_argument( 235 | "-r", "--repeat", type=int, default=1, help="Number of times to repeat the evaluation" 236 | ) 237 | parser.add_argument( 238 | "--save-file-prefix", type=str, default=None, help="Prefix for the save file" 239 | ) 240 | parser.add_argument("--debug", action="store_true", help="Debug mode") 241 | args = parser.parse_args() 242 | 243 | runner = EvalRunner(args.policy_store_dir, args.debug) 244 | for i in range(args.repeat): 245 | if i > 0: 246 | args.seed = None # this will sample new seed 247 | runner.run(args.mode, args.task_file, args.seed, args.num_episode, args.save_file_prefix) 248 | -------------------------------------------------------------------------------- /neurips23_evaluation/export_embeddings.py: -------------------------------------------------------------------------------- 1 | import dill 2 | import polars as pl 3 | 4 | # load the curriculum and evaluation files 5 | with open("curriculum_generation/curriculum_with_embedding.pkl", "rb") as f: 6 | curriculum = dill.load(f) 7 | with open("neurips23_evaluation/heldout_task_with_embedding.pkl", "rb") as f: 8 | eval_tasks = dill.load(f) 9 | 10 | # metadata: task name (including full info), predicate, kwargs, sampling weights, training vs evaluation 11 | # group by 12 | # - train vs eval 13 | # - predicate 14 | 15 | # embedding projector needs a tsv file of vectors only and metadata files 16 | 17 | embeddings = [] 18 | metadata = [] 19 | 20 | 21 | def get_task_predicate(spec): 22 | name = spec.name.split("_")[1] 23 | if name == "CountEvent": 24 | return name + "=" + spec.eval_fn_kwargs["event"] 25 | return name 26 | 27 | 28 | for spec in curriculum: 29 | embeddings.append(spec.embedding) 30 | metadata.append( 31 | { 32 | "task_name": spec.name.replace("Task_", "").replace("_reward_to:agent", ""), 33 | "predicate": get_task_predicate(spec), 34 | "used_for": "train", 35 | "sampling_weight": spec.sampling_weight, 36 | } 37 | ) 38 | 39 | for spec in eval_tasks: 40 | embeddings.append(spec.embedding) 41 | metadata.append( 42 | { 43 | "task_name": spec.name.replace("Task_", "").replace("_reward_to:agent", ""), 44 | "predicate": get_task_predicate(spec), 45 | "used_for": "eval", 46 | "sampling_weight": spec.sampling_weight, 47 | } 48 | ) 49 | 50 | 51 | embed_df = pl.DataFrame(embeddings) 52 | embed_df.write_csv("task_embeddings.tsv", separator="\t", include_header=False, float_precision=6) 53 | 54 | meta_df = pl.DataFrame(metadata) 55 | meta_df.write_csv("task_metadata.tsv", separator="\t") 56 | -------------------------------------------------------------------------------- /neurips23_evaluation/heldout_evaluation_task.py: -------------------------------------------------------------------------------- 1 | """Held-out evaluation tasks for NeurIPS 2023 competition.""" 2 | 3 | from typing import List 4 | 5 | from nmmo.systems import skill as s 6 | from nmmo.systems import item as i 7 | from nmmo.task.base_predicates import ( 8 | AttainSkill, 9 | ConsumeItem, 10 | CountEvent, 11 | DefeatEntity, 12 | EarnGold, 13 | EquipItem, 14 | FullyArmed, 15 | HarvestItem, 16 | HoardGold, 17 | MakeProfit, 18 | OccupyTile, 19 | TickGE, 20 | ) 21 | from nmmo.task.task_spec import TaskSpec, check_task_spec 22 | 23 | 24 | CURRICULUM_FILE_PATH = "neurips23_evaluation/heldout_task_with_embedding.pkl" 25 | 26 | EVENT_GOAL = 20 27 | LEVEL_GOAL = [1, 3] 28 | GOLD_GOAL = 100 29 | 30 | curriculum: List[TaskSpec] = [] 31 | 32 | # Survive to the end 33 | curriculum.append(TaskSpec(eval_fn=TickGE, eval_fn_kwargs={"num_tick": 1024})) 34 | 35 | # Kill 20 players/npcs 36 | curriculum.append( 37 | TaskSpec( 38 | eval_fn=CountEvent, 39 | eval_fn_kwargs={"event": "PLAYER_KILL", "N": EVENT_GOAL}, 40 | ) 41 | ) 42 | 43 | # Kill npcs of level 1+, 3+ 44 | for level in LEVEL_GOAL: 45 | curriculum.append( 46 | TaskSpec( 47 | eval_fn=DefeatEntity, 48 | eval_fn_kwargs={"agent_type": "npc", "level": level, "num_agent": EVENT_GOAL}, 49 | ) 50 | ) 51 | 52 | # Explore and reach the center (80, 80) 53 | curriculum.append( 54 | TaskSpec( 55 | eval_fn=CountEvent, 56 | eval_fn_kwargs={"event": "GO_FARTHEST", "N": 64}, 57 | ) 58 | ) 59 | 60 | curriculum.append( 61 | TaskSpec( 62 | eval_fn=OccupyTile, 63 | eval_fn_kwargs={"row": 80, "col": 80}, 64 | ) 65 | ) 66 | 67 | # Reach skill level 10 68 | for skill in s.COMBAT_SKILL + s.HARVEST_SKILL: 69 | curriculum.append( 70 | TaskSpec( 71 | eval_fn=AttainSkill, 72 | eval_fn_kwargs={"skill": skill, "level": 10, "num_agent": 1}, 73 | ) 74 | ) 75 | 76 | # Harvest 20 ammos of level 1+ or 3+ 77 | for ammo in i.AMMUNITION: 78 | for level in LEVEL_GOAL: 79 | curriculum.append( 80 | TaskSpec( 81 | eval_fn=HarvestItem, 82 | eval_fn_kwargs={"item": ammo, "level": level, "quantity": EVENT_GOAL}, 83 | ) 84 | ) 85 | 86 | # Consume 10 ration/potions of level 1+ or 3+ 87 | for item in i.CONSUMABLE: 88 | for level in LEVEL_GOAL: 89 | curriculum.append( 90 | TaskSpec( 91 | eval_fn=ConsumeItem, 92 | eval_fn_kwargs={"item": item, "level": level, "quantity": EVENT_GOAL}, 93 | ) 94 | ) 95 | 96 | # Equip armour, weapons, tools, and ammos 97 | for item in i.ARMOR + i.WEAPON + i.TOOL + i.AMMUNITION: 98 | for level in LEVEL_GOAL: 99 | curriculum.append( 100 | TaskSpec( 101 | eval_fn=EquipItem, 102 | eval_fn_kwargs={"item": item, "level": level, "num_agent": 1}, 103 | ) 104 | ) 105 | 106 | # Fully armed, level 1+ or 3+ 107 | for skill in s.COMBAT_SKILL: 108 | for level in LEVEL_GOAL: 109 | curriculum.append( 110 | TaskSpec( 111 | eval_fn=FullyArmed, 112 | eval_fn_kwargs={"combat_style": skill, "level": level, "num_agent": 1}, 113 | ) 114 | ) 115 | 116 | # Buy and Sell 10 items (of any kind) 117 | curriculum.append( 118 | TaskSpec( 119 | eval_fn=CountEvent, 120 | eval_fn_kwargs={"event": "EARN_GOLD", "N": EVENT_GOAL}, # item sold 121 | ) 122 | ) 123 | 124 | curriculum.append( 125 | TaskSpec( 126 | eval_fn=CountEvent, 127 | eval_fn_kwargs={"event": "BUY_ITEM", "N": EVENT_GOAL}, # item bought 128 | ) 129 | ) 130 | 131 | # Earn 100 gold (revenue), just by trading 132 | curriculum.append(TaskSpec(eval_fn=EarnGold, eval_fn_kwargs={"amount": GOLD_GOAL})) 133 | 134 | # Own and protect 100 gold by any means (looting or trading) 135 | curriculum.append(TaskSpec(eval_fn=HoardGold, eval_fn_kwargs={"amount": GOLD_GOAL})) 136 | 137 | # Make profit of 100 gold by any means 138 | curriculum.append(TaskSpec(eval_fn=MakeProfit, eval_fn_kwargs={"amount": GOLD_GOAL})) 139 | 140 | 141 | if __name__ == "__main__": 142 | # Import the custom curriculum 143 | print("------------------------------------------------------------") 144 | from neurips23_evaluation import heldout_evaluation_task # which is this file 145 | 146 | CURRICULUM = heldout_evaluation_task.curriculum 147 | print("The number of training tasks in the curriculum:", len(CURRICULUM)) 148 | 149 | # Check if these task specs are valid in the nmmo environment 150 | # Invalid tasks will crash your agent training 151 | print("------------------------------------------------------------") 152 | print("Checking whether the task specs are valid ...") 153 | results = check_task_spec(CURRICULUM) 154 | num_error = 0 155 | for result in results: 156 | if result["runnable"] is False: 157 | print("ERROR: ", result["spec_name"]) 158 | num_error += 1 159 | assert num_error == 0, "Invalid task specs will crash training. Please fix them." 160 | print("All training tasks are valid.") 161 | 162 | # The task_spec must be picklable to be used for agent training 163 | print("------------------------------------------------------------") 164 | print("Checking if the training tasks are picklable ...") 165 | with open(CURRICULUM_FILE_PATH, "wb") as f: 166 | import dill 167 | 168 | dill.dump(CURRICULUM, f) 169 | print("All training tasks are picklable.") 170 | 171 | # To use the curriculum for agent training, the curriculum, task_spec, should be 172 | # saved to a file with the embeddings using the task encoder. The task encoder uses 173 | # a coding LLM to encode the task_spec into a vector. 174 | print("------------------------------------------------------------") 175 | print("Generating the task spec with embedding file ...") 176 | from curriculum_generation.task_encoder import TaskEncoder 177 | 178 | LLM_CHECKPOINT = "deepseek-ai/deepseek-coder-1.3b-instruct" 179 | 180 | # Get the task embeddings for the training tasks and save to file 181 | # You need to provide the curriculum file as a module to the task encoder 182 | with TaskEncoder(LLM_CHECKPOINT, heldout_evaluation_task) as task_encoder: 183 | task_encoder.get_task_embedding(CURRICULUM, save_to_file=CURRICULUM_FILE_PATH) 184 | print("Done.") 185 | -------------------------------------------------------------------------------- /neurips23_evaluation/heldout_task_with_embedding.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralMMO/baselines/9a4ec9ccac6f7e91d01708ceff8386ecfd52abc6/neurips23_evaluation/heldout_task_with_embedding.pkl -------------------------------------------------------------------------------- /neurips23_evaluation/sample_eval_task_with_embedding.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralMMO/baselines/9a4ec9ccac6f7e91d01708ceff8386ecfd52abc6/neurips23_evaluation/sample_eval_task_with_embedding.pkl -------------------------------------------------------------------------------- /neurips23_evaluation/sample_evaluation_task.py: -------------------------------------------------------------------------------- 1 | """Manual test for creating learning curriculum manually""" 2 | 3 | from typing import List 4 | 5 | from nmmo.systems import skill as s 6 | from nmmo.task.base_predicates import AttainSkill, CountEvent, EarnGold, TickGE 7 | from nmmo.task.task_spec import TaskSpec 8 | 9 | 10 | CURRICULUM_FILE_PATH = "neurips23_evaluation/sample_eval_task_with_embedding.pkl" 11 | 12 | curriculum: List[TaskSpec] = [] 13 | 14 | # Stay alive as long as possible 15 | curriculum.append(TaskSpec(eval_fn=TickGE, eval_fn_kwargs={"num_tick": 1024})) 16 | 17 | # Perform these 10 times 18 | essential_skills = [ 19 | "EAT_FOOD", 20 | "DRINK_WATER", 21 | "SCORE_HIT", 22 | "PLAYER_KILL", 23 | "HARVEST_ITEM", 24 | "EQUIP_ITEM", 25 | "CONSUME_ITEM", 26 | "LEVEL_UP", 27 | "EARN_GOLD", 28 | "LIST_ITEM", 29 | "BUY_ITEM", 30 | "GIVE_ITEM", 31 | "DESTROY_ITEM", 32 | "GIVE_GOLD", 33 | ] 34 | for event_code in essential_skills: 35 | curriculum.append( 36 | TaskSpec( 37 | eval_fn=CountEvent, 38 | eval_fn_kwargs={"event": event_code, "N": 10}, 39 | ) 40 | ) 41 | 42 | # Reach skill level 10 43 | for skill in s.COMBAT_SKILL + s.HARVEST_SKILL: 44 | curriculum.append( 45 | TaskSpec( 46 | eval_fn=AttainSkill, 47 | eval_fn_kwargs={"skill": skill, "level": 10, "num_agent": 1}, 48 | ) 49 | ) 50 | 51 | # Earn gold 50 52 | curriculum.append(TaskSpec(eval_fn=EarnGold, eval_fn_kwargs={"amount": 50})) 53 | 54 | if __name__ == "__main__": 55 | from neurips23_evaluation import sample_evaluation_task as curriculum 56 | from curriculum_generation.task_encoder import TaskEncoder 57 | 58 | LLM_CHECKPOINT = "deepseek-ai/deepseek-coder-1.3b-instruct" 59 | 60 | with TaskEncoder(LLM_CHECKPOINT, curriculum, batch_size=6) as task_encoder: 61 | task_encoder.get_task_embedding(curriculum.curriculum, save_to_file=CURRICULUM_FILE_PATH) 62 | -------------------------------------------------------------------------------- /policies/README.md: -------------------------------------------------------------------------------- 1 | # Training logs 2 | * Baseline (`neurips23_start_kit`) 10M steps: https://wandb.ai/kywch/nmmo-baselines/runs/test_01 3 | * Yaofeng 25M, 50M, 100M, 200M steps: https://wandb.ai/kywch/nmmo-baselines/runs/23t7ga2i 4 | * Takeru 25M, 50M, 100M, 200M steps: https://wandb.ai/kywch/nmmo-baselines/runs/3jxf93gp 5 | 6 | # Evaluation script 7 | * The eval_pvp json files were obtained by running `(.venv) $ python evaluate.py policies pvp -r 10` 8 | * The results were summarized by running `(.venv) $ python analysis/proc_eval_result.py policies` 9 | -------------------------------------------------------------------------------- /policies/baseline_10M.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralMMO/baselines/9a4ec9ccac6f7e91d01708ceff8386ecfd52abc6/policies/baseline_10M.pt -------------------------------------------------------------------------------- /policies/elo.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralMMO/baselines/9a4ec9ccac6f7e91d01708ceff8386ecfd52abc6/policies/elo.db -------------------------------------------------------------------------------- /policies/score_by_seed.tsv: -------------------------------------------------------------------------------- 1 | policy_name mode seed count length task_progress weighted_score 2 | baseline_10M pvp 1 2875 113.036870 0.057293 7.711149 3 | baseline_10M pvp 12196392 2856 113.172619 0.053504 6.621138 4 | baseline_10M pvp 19525770 2856 111.803221 0.051267 7.079278 5 | baseline_10M pvp 28034063 2891 113.072639 0.058612 6.802867 6 | baseline_10M pvp 31128942 2896 113.251036 0.049771 6.449095 7 | baseline_10M pvp 42720373 2856 112.990896 0.049595 6.890654 8 | baseline_10M pvp 47914166 2856 115.503151 0.061916 7.388279 9 | baseline_10M pvp 97868113 2856 112.744048 0.052768 7.327276 10 | baseline_10M pvp 99672462 2879 111.319208 0.047929 6.403828 11 | baseline_10M pvp 99910280 2856 111.749300 0.048945 7.105112 12 | learner pvp 1 417 33.431655 0.007722 1.593511 13 | learner pvp 12196392 408 34.080882 0.007041 1.788326 14 | learner pvp 19525770 408 32.973039 0.008720 1.594650 15 | learner pvp 28034063 420 32.695238 0.012624 1.329694 16 | learner pvp 31128942 419 32.730310 0.009790 1.354010 17 | learner pvp 42720373 408 34.433824 0.015673 1.462101 18 | learner pvp 47914166 408 32.987745 0.007112 1.603758 19 | learner pvp 97868113 408 32.691176 0.019604 1.616171 20 | learner pvp 99672462 419 34.639618 0.013562 1.788460 21 | learner pvp 99910280 408 33.882353 0.004185 1.365756 22 | takeru_100M pvp 1 2882 250.350104 0.121535 15.618457 23 | takeru_100M pvp 12196392 2856 271.406162 0.144111 16.949041 24 | takeru_100M pvp 19525770 2856 252.528361 0.125749 16.176245 25 | takeru_100M pvp 28034063 2876 245.837969 0.120900 14.367364 26 | takeru_100M pvp 31128942 2875 256.180870 0.121279 15.395644 27 | takeru_100M pvp 42720373 2856 262.454832 0.123755 15.880736 28 | takeru_100M pvp 47914166 2856 253.642157 0.114422 14.218429 29 | takeru_100M pvp 97868113 2856 239.824230 0.120914 14.178165 30 | takeru_100M pvp 99672462 2869 271.960962 0.131653 17.312564 31 | takeru_100M pvp 99910280 2856 249.379552 0.127160 14.150382 32 | takeru_200M pvp 1 2875 316.104000 0.151924 18.709205 33 | takeru_200M pvp 12196392 2856 339.818978 0.151841 20.426849 34 | takeru_200M pvp 19525770 2856 316.747199 0.139267 18.577009 35 | takeru_200M pvp 28034063 2881 324.179104 0.151727 19.935177 36 | takeru_200M pvp 31128942 2880 335.050347 0.158917 21.702642 37 | takeru_200M pvp 42720373 2856 324.113445 0.155242 21.043062 38 | takeru_200M pvp 47914166 2856 337.428221 0.152481 20.125692 39 | takeru_200M pvp 97868113 2856 334.341737 0.152232 19.192452 40 | takeru_200M pvp 99672462 2864 328.941341 0.155119 18.984392 41 | takeru_200M pvp 99910280 2856 323.931723 0.140493 17.695197 42 | takeru_25M pvp 1 2884 180.718100 0.089929 11.114902 43 | takeru_25M pvp 12196392 2856 177.607143 0.099868 10.870050 44 | takeru_25M pvp 19525770 2856 197.280462 0.100289 12.289034 45 | takeru_25M pvp 28034063 2877 181.195342 0.086204 11.439433 46 | takeru_25M pvp 31128942 2876 183.363004 0.094024 11.418434 47 | takeru_25M pvp 42720373 2856 175.254902 0.093796 11.112426 48 | takeru_25M pvp 47914166 2856 185.682073 0.101088 12.007849 49 | takeru_25M pvp 97868113 2856 180.685924 0.086907 11.268605 50 | takeru_25M pvp 99672462 2867 191.282874 0.097065 13.397202 51 | takeru_25M pvp 99910280 2856 180.972339 0.091033 11.780795 52 | takeru_50M pvp 1 2880 187.513194 0.094340 11.108544 53 | takeru_50M pvp 12196392 2856 183.173319 0.085969 9.765318 54 | takeru_50M pvp 19525770 2856 195.420868 0.088209 10.601355 55 | takeru_50M pvp 28034063 2882 189.238376 0.102966 11.641201 56 | takeru_50M pvp 31128942 2877 190.757039 0.088876 10.537141 57 | takeru_50M pvp 42720373 2856 204.296569 0.095444 11.299242 58 | takeru_50M pvp 47914166 2856 187.340336 0.092127 11.501803 59 | takeru_50M pvp 97868113 2856 200.723389 0.098814 12.097092 60 | takeru_50M pvp 99672462 2871 190.401254 0.089566 10.872617 61 | takeru_50M pvp 99910280 2856 204.198529 0.103643 13.027134 62 | yaofeng_100M pvp 1 2873 521.517229 0.213351 28.023256 63 | yaofeng_100M pvp 12196392 2856 522.893908 0.215434 29.544247 64 | yaofeng_100M pvp 19525770 2856 522.175420 0.213550 27.830269 65 | yaofeng_100M pvp 28034063 2876 521.185327 0.218254 27.658662 66 | yaofeng_100M pvp 31128942 2863 524.315403 0.216067 28.792953 67 | yaofeng_100M pvp 42720373 2856 523.355392 0.217156 29.740841 68 | yaofeng_100M pvp 47914166 2856 519.951331 0.206967 28.194194 69 | yaofeng_100M pvp 97868113 2856 527.612745 0.215622 31.332508 70 | yaofeng_100M pvp 99672462 2861 530.077945 0.221404 29.723764 71 | yaofeng_100M pvp 99910280 2856 505.139006 0.200635 26.478715 72 | yaofeng_200M pvp 1 2867 604.292640 0.257036 33.722996 73 | yaofeng_200M pvp 12196392 2856 631.487045 0.264420 33.374962 74 | yaofeng_200M pvp 19525770 2856 606.517157 0.252681 33.309315 75 | yaofeng_200M pvp 28034063 2867 610.524939 0.251194 32.512235 76 | yaofeng_200M pvp 31128942 2865 631.468063 0.263421 34.195978 77 | yaofeng_200M pvp 42720373 2856 621.339986 0.268310 34.821742 78 | yaofeng_200M pvp 47914166 2856 631.818277 0.268357 35.441499 79 | yaofeng_200M pvp 97868113 2856 607.034664 0.255898 35.460509 80 | yaofeng_200M pvp 99672462 2864 608.230098 0.265366 34.140322 81 | yaofeng_200M pvp 99910280 2856 605.400210 0.259645 34.211521 82 | yaofeng_25M pvp 1 2886 305.388080 0.129862 18.244865 83 | yaofeng_25M pvp 12196392 2856 313.843137 0.129536 18.494977 84 | yaofeng_25M pvp 19525770 2856 303.464286 0.137456 18.455156 85 | yaofeng_25M pvp 28034063 2884 312.892510 0.138521 19.129880 86 | yaofeng_25M pvp 31128942 2871 315.831766 0.144807 21.178642 87 | yaofeng_25M pvp 42720373 2856 299.599090 0.130901 18.187198 88 | yaofeng_25M pvp 47914166 2856 323.201681 0.151005 21.174788 89 | yaofeng_25M pvp 97868113 2856 316.847339 0.144221 18.992866 90 | yaofeng_25M pvp 99672462 2868 319.894003 0.145005 20.106129 91 | yaofeng_25M pvp 99910280 2856 310.024860 0.139246 18.866955 92 | yaofeng_50M pvp 1 2878 372.162265 0.163365 21.912547 93 | yaofeng_50M pvp 12196392 2856 389.990196 0.162783 21.662320 94 | yaofeng_50M pvp 19525770 2856 380.175770 0.163746 22.486598 95 | yaofeng_50M pvp 28034063 2876 398.583102 0.162805 23.229189 96 | yaofeng_50M pvp 31128942 2868 387.937238 0.168452 22.369273 97 | yaofeng_50M pvp 42720373 2856 403.100490 0.174496 24.020460 98 | yaofeng_50M pvp 47914166 2856 400.049020 0.163483 24.476792 99 | yaofeng_50M pvp 97868113 2856 379.010154 0.172073 21.580274 100 | yaofeng_50M pvp 99672462 2867 392.014301 0.171406 22.375796 101 | yaofeng_50M pvp 99910280 2856 395.121499 0.169452 22.586875 102 | -------------------------------------------------------------------------------- /policies/score_category_summary.tsv: -------------------------------------------------------------------------------- 1 | mode category policy_name task_progress 2 | pvp combat baseline_10M 0.060323 3 | pvp combat learner 0.022414 4 | pvp combat takeru_100M 0.141534 5 | pvp combat takeru_200M 0.179071 6 | pvp combat takeru_25M 0.103009 7 | pvp combat takeru_50M 0.108632 8 | pvp combat yaofeng_100M 0.249735 9 | pvp combat yaofeng_200M 0.298348 10 | pvp combat yaofeng_25M 0.168542 11 | pvp combat yaofeng_50M 0.204556 12 | pvp exploration baseline_10M 0.104342 13 | pvp exploration learner 0.018378 14 | pvp exploration takeru_100M 0.098290 15 | pvp exploration takeru_200M 0.121019 16 | pvp exploration takeru_25M 0.082401 17 | pvp exploration takeru_50M 0.099748 18 | pvp exploration yaofeng_100M 0.140990 19 | pvp exploration yaofeng_200M 0.163010 20 | pvp exploration yaofeng_25M 0.117004 21 | pvp exploration yaofeng_50M 0.127692 22 | pvp item baseline_10M 0.051819 23 | pvp item learner 0.011147 24 | pvp item takeru_100M 0.118230 25 | pvp item takeru_200M 0.139118 26 | pvp item takeru_25M 0.088768 27 | pvp item takeru_50M 0.092103 28 | pvp item yaofeng_100M 0.191561 29 | pvp item yaofeng_200M 0.234871 30 | pvp item yaofeng_25M 0.118466 31 | pvp item yaofeng_50M 0.150575 32 | pvp market baseline_10M 0.064515 33 | pvp market learner 0.007175 34 | pvp market takeru_100M 0.251776 35 | pvp market takeru_200M 0.316602 36 | pvp market takeru_25M 0.198542 37 | pvp market takeru_50M 0.153054 38 | pvp market yaofeng_100M 0.484392 39 | pvp market yaofeng_200M 0.600675 40 | pvp market yaofeng_25M 0.354815 41 | pvp market yaofeng_50M 0.357944 42 | pvp skill baseline_10M 0.031309 43 | pvp skill learner 0.000465 44 | pvp skill takeru_100M 0.069499 45 | pvp skill takeru_200M 0.086037 46 | pvp skill takeru_25M 0.046213 47 | pvp skill takeru_50M 0.051050 48 | pvp skill yaofeng_100M 0.133303 49 | pvp skill yaofeng_200M 0.158695 50 | pvp skill yaofeng_25M 0.090882 51 | pvp skill yaofeng_50M 0.104352 52 | pvp survival baseline_10M 0.106365 53 | pvp survival learner 0.033430 54 | pvp survival takeru_100M 0.246154 55 | pvp survival takeru_200M 0.336503 56 | pvp survival takeru_25M 0.181260 57 | pvp survival takeru_50M 0.170122 58 | pvp survival yaofeng_100M 0.523936 59 | pvp survival yaofeng_200M 0.591548 60 | pvp survival yaofeng_25M 0.307280 61 | pvp survival yaofeng_50M 0.415081 62 | -------------------------------------------------------------------------------- /policies/score_summary.tsv: -------------------------------------------------------------------------------- 1 | policy_name mode task_progress weighted_score 2 | yaofeng_200M pvp 0.260633 34.119108 3 | yaofeng_100M pvp 0.213844 28.731941 4 | yaofeng_50M pvp 0.167206 22.670012 5 | takeru_200M pvp 0.150924 19.639168 6 | yaofeng_25M pvp 0.139056 19.283145 7 | takeru_100M pvp 0.125148 15.424703 8 | takeru_25M pvp 0.094020 11.669873 9 | takeru_50M pvp 0.093995 11.245145 10 | baseline_10M pvp 0.053160 6.977868 11 | learner pvp 0.010603 1.549644 12 | -------------------------------------------------------------------------------- /policies/takeru_100M.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralMMO/baselines/9a4ec9ccac6f7e91d01708ceff8386ecfd52abc6/policies/takeru_100M.pt -------------------------------------------------------------------------------- /policies/takeru_200M.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralMMO/baselines/9a4ec9ccac6f7e91d01708ceff8386ecfd52abc6/policies/takeru_200M.pt -------------------------------------------------------------------------------- /policies/takeru_25M.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralMMO/baselines/9a4ec9ccac6f7e91d01708ceff8386ecfd52abc6/policies/takeru_25M.pt -------------------------------------------------------------------------------- /policies/takeru_50M.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralMMO/baselines/9a4ec9ccac6f7e91d01708ceff8386ecfd52abc6/policies/takeru_50M.pt -------------------------------------------------------------------------------- /policies/yaofeng_100M.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralMMO/baselines/9a4ec9ccac6f7e91d01708ceff8386ecfd52abc6/policies/yaofeng_100M.pt -------------------------------------------------------------------------------- /policies/yaofeng_200M.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralMMO/baselines/9a4ec9ccac6f7e91d01708ceff8386ecfd52abc6/policies/yaofeng_200M.pt -------------------------------------------------------------------------------- /policies/yaofeng_25M.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralMMO/baselines/9a4ec9ccac6f7e91d01708ceff8386ecfd52abc6/policies/yaofeng_25M.pt -------------------------------------------------------------------------------- /policies/yaofeng_50M.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralMMO/baselines/9a4ec9ccac6f7e91d01708ceff8386ecfd52abc6/policies/yaofeng_50M.pt -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["pip>=23.0", "setuptools>=61.0", "wheel"] 3 | 4 | [project] 5 | name = "nmmo2-baselines" 6 | version = "0.1.0" 7 | description = "Neural MMO 2023 competition baselines" 8 | keywords = [] 9 | classifiers = [ 10 | "Natural Language :: English", 11 | "Operating System :: POSIX :: Linux", 12 | "Operating System :: MacOS :: MacOS X", 13 | "Programming Language :: Python", 14 | "Programming Language :: Python :: 3.10", 15 | "Programming Language :: Python :: Implementation :: CPython", 16 | ] 17 | dependencies = [ 18 | "accelerate", 19 | "nmmo>=2.1,<2.2", 20 | "polars", 21 | "pufferlib[nmmo]==0.7.3", 22 | "psutil<6", 23 | "syllabus-rl@git+https://github.com/kywch/Syllabus@nmmo", # To replace with pip later 24 | "torch>2", 25 | "transformers", 26 | "wandb", 27 | ] 28 | 29 | [tool.setuptools.packages.find] 30 | where = ["."] 31 | exclude = ["tests"] 32 | 33 | [project.optional-dependencies] 34 | monitoring = [ 35 | "nvitop" 36 | ] 37 | dev = [ 38 | "pre-commit", 39 | "ruff" 40 | ] 41 | 42 | [tool.distutils.bdist_wheel] 43 | universal = true 44 | 45 | [tool.ruff] 46 | # Exclude a variety of commonly ignored directories. 47 | exclude = [ 48 | ".bzr", 49 | ".direnv", 50 | ".eggs", 51 | ".git", 52 | ".git-rewrite", 53 | ".hg", 54 | ".ipynb_checkpoints", 55 | ".mypy_cache", 56 | ".nox", 57 | ".pants.d", 58 | ".pyenv", 59 | ".pytest_cache", 60 | ".pytype", 61 | ".ruff_cache", 62 | ".svn", 63 | ".tox", 64 | ".venv", 65 | ".vscode", 66 | "__pypackages__", 67 | "_build", 68 | "buck-out", 69 | "build", 70 | "dist", 71 | "node_modules", 72 | "site-packages", 73 | "venv", 74 | ] 75 | 76 | # Same as Black. 77 | line-length = 100 78 | indent-width = 4 79 | 80 | # Assume Python 3.10 81 | target-version = "py310" 82 | 83 | [tool.ruff.lint] 84 | # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. 85 | # Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or 86 | # McCabe complexity (`C901`) by default. 87 | select = ["E4", "E7", "E9", "F"] 88 | ignore = [] 89 | 90 | # Allow fix for all enabled rules (when `--fix`) is provided. 91 | fixable = ["ALL"] 92 | unfixable = [] 93 | 94 | # Allow unused variables when underscore-prefixed. 95 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 96 | 97 | [tool.ruff.lint.per-file-ignores] 98 | "__init__.py" = ["F401"] # Ignore imported but unused 99 | 100 | [tool.ruff.format] 101 | # Like Black, use double quotes for strings. 102 | quote-style = "double" 103 | 104 | # Like Black, indent with spaces, rather than tabs. 105 | indent-style = "space" 106 | 107 | # Like Black, respect magic trailing commas. 108 | skip-magic-trailing-comma = false 109 | 110 | # Like Black, automatically detect the appropriate line ending. 111 | line-ending = "auto" 112 | 113 | # Enable auto-formatting of code examples in docstrings. Markdown, 114 | # reStructuredText code/literal blocks and doctests are all supported. 115 | # 116 | # This is currently disabled by default, but it is planned for this 117 | # to be opt-out in the future. 118 | docstring-code-format = false 119 | 120 | # Set the line length limit used when formatting code snippets in 121 | # docstrings. 122 | # 123 | # This only has an effect when the `docstring-code-format` setting is 124 | # enabled. 125 | docstring-code-line-length = "dynamic" 126 | -------------------------------------------------------------------------------- /reinforcement_learning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuralMMO/baselines/9a4ec9ccac6f7e91d01708ceff8386ecfd52abc6/reinforcement_learning/__init__.py -------------------------------------------------------------------------------- /reinforcement_learning/clean_pufferl.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: E722, F841 2 | 3 | # Copied from: https://github.com/PufferAI/PufferLib/blob/0.7/clean_pufferl.py 4 | # from pdb import set_trace as T 5 | import os 6 | import random 7 | import time 8 | import uuid 9 | 10 | from collections import defaultdict 11 | from datetime import timedelta 12 | 13 | import numpy as np 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.optim as optim 18 | 19 | import pufferlib 20 | import pufferlib.utils 21 | import pufferlib.emulation 22 | import pufferlib.vectorization 23 | import pufferlib.frameworks.cleanrl 24 | import pufferlib.policy_pool 25 | 26 | SKIP_LOG_KEYS = ["curriculum/Task_", "env_id"] 27 | 28 | 29 | @pufferlib.dataclass 30 | class Performance: 31 | total_uptime = 0 32 | total_updates = 0 33 | total_global_step = 0 34 | total_agent_steps = 0 35 | epoch_time = 0 36 | epoch_sps = 0 37 | env_time = 0 38 | env_sps = 0 39 | inference_time = 0 40 | inference_sps = 0 41 | train_time = 0 42 | train_sps = 0 43 | train_memory = 0 44 | train_pytorch_memory = 0 45 | misc_time = 0 46 | 47 | 48 | @pufferlib.dataclass 49 | class Losses: 50 | policy_loss = 0 51 | value_loss = 0 52 | entropy = 0 53 | old_approx_kl = 0 54 | approx_kl = 0 55 | clipfrac = 0 56 | explained_variance = 0 57 | 58 | 59 | @pufferlib.dataclass 60 | class Charts: 61 | global_step = 0 62 | SPS = 0 63 | learning_rate = 0 64 | agent_step = 0 65 | agent_SPS = 0 66 | 67 | 68 | def create( 69 | self: object = None, 70 | config: pufferlib.namespace = None, 71 | exp_name: str = None, 72 | track: bool = False, 73 | # Agent 74 | agent: nn.Module = None, 75 | agent_creator: callable = None, 76 | agent_kwargs: dict = None, 77 | # Environment 78 | env_creator: callable = None, 79 | env_creator_kwargs: dict = None, 80 | vectorization: ... = pufferlib.vectorization.Serial, 81 | # Evaluation or replay mode 82 | eval_mode: bool = False, 83 | eval_model_path: str = None, 84 | # Policy Pool options 85 | policy_selector: callable = None, 86 | ): 87 | if config is None: 88 | config = pufferlib.args.CleanPuffeRL() 89 | 90 | if exp_name is None: 91 | exp_name = str(uuid.uuid4())[:8] 92 | 93 | wandb = None 94 | if track: 95 | import wandb 96 | 97 | start_time = time.time() 98 | seed_everything(config.seed, config.torch_deterministic) 99 | total_updates = config.total_timesteps // config.batch_size 100 | 101 | device = config.device 102 | 103 | # Create environments, agent, and optimizer 104 | init_profiler = pufferlib.utils.Profiler(memory=True) 105 | with init_profiler: 106 | pool = vectorization( 107 | env_creator, 108 | env_kwargs=env_creator_kwargs, 109 | num_envs=config.num_envs, 110 | envs_per_worker=config.envs_per_worker, 111 | envs_per_batch=config.envs_per_batch, 112 | env_pool=config.env_pool, 113 | mask_agents=True, 114 | ) 115 | 116 | obs_shape = pool.single_observation_space.shape 117 | atn_shape = pool.single_action_space.shape 118 | num_agents = pool.agents_per_env 119 | total_agents = num_agents * config.num_envs 120 | 121 | # If data_dir is provided, load the resume state 122 | resume_state = {} 123 | path = os.path.join(config.data_dir, exp_name) 124 | if os.path.exists(path): 125 | trainer_path = os.path.join(path, "trainer_state.pt") 126 | resume_state = torch.load(trainer_path) 127 | model_path = os.path.join(path, resume_state["model_name"]) 128 | agent = torch.load(model_path, map_location=device) 129 | print( 130 | f'Resumed from update {resume_state["update"]} ' 131 | f'with policy {resume_state["model_name"]}' 132 | ) 133 | else: 134 | agent = pufferlib.emulation.make_object( 135 | agent, agent_creator, [pool.driver_env], agent_kwargs 136 | ) 137 | 138 | global_step = resume_state.get("global_step", 0) 139 | agent_step = resume_state.get("agent_step", 0) 140 | update = resume_state.get("update", 0) 141 | 142 | optimizer = optim.Adam(agent.parameters(), lr=config.learning_rate, eps=1e-5) 143 | 144 | uncompiled_agent = agent # Needed to save the model 145 | if config.compile: 146 | agent = torch.compile(agent, mode=config.compile_mode) 147 | 148 | if config.verbose: 149 | n_params = sum(p.numel() for p in agent.parameters() if p.requires_grad) 150 | print(f"Model Size: {n_params//1000} K parameters") 151 | print(f"Obs size: {obs_shape[0]} -- {pool.driver_env.obs_sz}") 152 | 153 | opt_state = resume_state.get("optimizer_state_dict", None) 154 | if opt_state is not None: 155 | optimizer.load_state_dict(resume_state["optimizer_state_dict"]) 156 | 157 | # Create policy pool 158 | pool_agents = num_agents * pool.envs_per_batch 159 | pool_path = eval_model_path if eval_mode else path 160 | if policy_selector is None: 161 | policy_selector = pufferlib.policy_pool.RandomPolicySelector(config.seed) 162 | policy_pool = pufferlib.policy_pool.PolicyPool( 163 | agent, 164 | pool_agents, 165 | atn_shape, 166 | device, 167 | pool_path, 168 | config.pool_kernel, 169 | policy_selector, 170 | ) 171 | 172 | # Allocate Storage 173 | storage_profiler = pufferlib.utils.Profiler(memory=True, pytorch_memory=True).start() 174 | next_lstm_state = [] 175 | pool.async_reset(config.seed) 176 | next_lstm_state = None 177 | if hasattr(agent, "lstm"): 178 | shape = (agent.lstm.num_layers, total_agents, agent.lstm.hidden_size) 179 | next_lstm_state = ( 180 | torch.zeros(shape).to(device), 181 | torch.zeros(shape).to(device), 182 | ) 183 | obs = torch.zeros(config.batch_size + 1, *obs_shape) 184 | actions = torch.zeros(config.batch_size + 1, *atn_shape, dtype=int) 185 | logprobs = torch.zeros(config.batch_size + 1) 186 | rewards = torch.zeros(config.batch_size + 1) 187 | dones = torch.zeros(config.batch_size + 1) 188 | truncateds = torch.zeros(config.batch_size + 1) 189 | values = torch.zeros(config.batch_size + 1) 190 | 191 | obs_ary = np.asarray(obs) 192 | actions_ary = np.asarray(actions) 193 | logprobs_ary = np.asarray(logprobs) 194 | rewards_ary = np.asarray(rewards) 195 | dones_ary = np.asarray(dones) 196 | truncateds_ary = np.asarray(truncateds) 197 | values_ary = np.asarray(values) 198 | 199 | storage_profiler.stop() 200 | 201 | # "charts/actions": wandb.Histogram(b_actions.cpu().numpy()), 202 | init_performance = pufferlib.namespace( 203 | init_time=time.time() - start_time, 204 | init_env_time=init_profiler.elapsed, 205 | init_env_memory=init_profiler.memory, 206 | tensor_memory=storage_profiler.memory, 207 | tensor_pytorch_memory=storage_profiler.pytorch_memory, 208 | ) 209 | 210 | return pufferlib.namespace( 211 | self, 212 | # Agent, Optimizer, and Environment 213 | config=config, 214 | pool=pool, 215 | agent=agent, 216 | uncompiled_agent=uncompiled_agent, 217 | optimizer=optimizer, 218 | policy_pool=policy_pool, 219 | # Logging 220 | exp_name=exp_name, 221 | wandb=wandb, 222 | learning_rate=config.learning_rate, 223 | losses=Losses(), 224 | init_performance=init_performance, 225 | performance=Performance(), 226 | # Storage 227 | sort_keys=[], 228 | next_lstm_state=next_lstm_state, 229 | obs=obs, 230 | actions=actions, 231 | logprobs=logprobs, 232 | rewards=rewards, 233 | dones=dones, 234 | values=values, 235 | obs_ary=obs_ary, 236 | actions_ary=actions_ary, 237 | logprobs_ary=logprobs_ary, 238 | rewards_ary=rewards_ary, 239 | dones_ary=dones_ary, 240 | truncateds_ary=truncateds_ary, 241 | values_ary=values_ary, 242 | # Misc 243 | total_updates=total_updates, 244 | update=update, 245 | global_step=global_step, 246 | agent_step=agent_step, 247 | device=device, 248 | start_time=start_time, 249 | eval_mode=eval_mode, 250 | ) 251 | 252 | 253 | @pufferlib.utils.profile 254 | def evaluate(data): 255 | config = data.config 256 | # TODO: Handle update on resume 257 | if data.wandb is not None and data.performance.total_uptime > 0: 258 | data.wandb.log( 259 | { 260 | "agent_SPS": data.agent_SPS, 261 | "agent_step": data.agent_step, 262 | "global_step": data.global_step, 263 | "learning_rate": data.optimizer.param_groups[0]["lr"], 264 | **{f"losses/{k}": v for k, v in data.losses.items()}, 265 | **{f"performance/{k}": v for k, v in data.performance.items()}, 266 | **{f"{k}": v for k, v in data.stats.items()}, # comes with stats/ prefix 267 | **{ 268 | f"skillrank/{policy}": elo 269 | for policy, elo in data.policy_pool.ranker.ratings.items() 270 | }, 271 | } 272 | ) 273 | 274 | # update_policies() changes the policy id (in kernel) - policy mapping 275 | # It's good for training but not wanted for replay or evaluation, so we skip it 276 | if not data.eval_mode: 277 | data.policy_pool.update_policies() 278 | 279 | # performance = defaultdict(list) 280 | env_profiler = pufferlib.utils.Profiler() 281 | inference_profiler = pufferlib.utils.Profiler() 282 | eval_profiler = pufferlib.utils.Profiler(memory=True, pytorch_memory=True).start() 283 | misc_profiler = pufferlib.utils.Profiler() 284 | 285 | ptr = step = padded_steps_collected = agent_steps_collected = 0 286 | infos = defaultdict(lambda: defaultdict(list)) 287 | while True: 288 | step += 1 289 | if ptr == config.batch_size + 1: 290 | break 291 | 292 | with env_profiler: 293 | o, r, d, t, i, env_id, mask = data.pool.recv() 294 | 295 | with misc_profiler: 296 | i = data.policy_pool.update_scores(i, "return") 297 | # TODO: Update this for policy pool 298 | for ii, ee in zip(i["learner"], env_id): 299 | ii["env_id"] = ee 300 | 301 | with inference_profiler, torch.no_grad(): 302 | o = torch.as_tensor(o) 303 | r = torch.as_tensor(r).float().to(data.device).view(-1) 304 | d = torch.as_tensor(d).float().to(data.device).view(-1) 305 | 306 | agent_steps_collected += sum(mask) 307 | padded_steps_collected += len(mask) 308 | 309 | # Multiple policies will not work with new envpool 310 | next_lstm_state = data.next_lstm_state 311 | if next_lstm_state is not None: 312 | next_lstm_state = ( 313 | next_lstm_state[0][:, env_id], 314 | next_lstm_state[1][:, env_id], 315 | ) 316 | 317 | actions, logprob, value, next_lstm_state = data.policy_pool.forwards( 318 | o.to(data.device), next_lstm_state 319 | ) 320 | 321 | if next_lstm_state is not None: 322 | h, c = next_lstm_state 323 | data.next_lstm_state[0][:, env_id] = h 324 | data.next_lstm_state[1][:, env_id] = c 325 | 326 | value = value.flatten() 327 | 328 | with misc_profiler: 329 | actions = actions.cpu().numpy() 330 | 331 | # Index alive mask with policy pool idxs... 332 | # TODO: Find a way to avoid having to do this 333 | learner_mask = torch.Tensor(mask * data.policy_pool.mask) 334 | 335 | # Ensure indices do not exceed batch size 336 | indices = torch.where(learner_mask)[0][: config.batch_size - ptr + 1].numpy() 337 | end = ptr + len(indices) 338 | 339 | # Batch indexing 340 | data.obs_ary[ptr:end] = o.cpu().numpy()[indices] 341 | data.values_ary[ptr:end] = value.cpu().numpy()[indices] 342 | data.actions_ary[ptr:end] = actions[indices] 343 | data.logprobs_ary[ptr:end] = logprob.cpu().numpy()[indices] 344 | data.rewards_ary[ptr:end] = r.cpu().numpy()[indices] 345 | data.dones_ary[ptr:end] = d.cpu().numpy()[indices] 346 | data.sort_keys.extend([(env_id[i], step) for i in indices]) 347 | 348 | # Update pointer 349 | ptr += len(indices) 350 | 351 | for policy_name, policy_i in i.items(): 352 | for agent_i in policy_i: 353 | for name, dat in unroll_nested_dict(agent_i): 354 | infos[policy_name][name].append(dat) 355 | 356 | with env_profiler: 357 | data.pool.send(actions) 358 | 359 | eval_profiler.stop() 360 | 361 | data.agent_step += agent_steps_collected 362 | data.global_step += padded_steps_collected 363 | data.reward = float(torch.mean(data.rewards)) 364 | data.SPS = int(padded_steps_collected / eval_profiler.elapsed) 365 | data.agent_SPS = int(agent_steps_collected / eval_profiler.elapsed) 366 | 367 | perf = data.performance 368 | perf.total_uptime = int(time.time() - data.start_time) 369 | perf.total_global_step = data.global_step 370 | perf.total_agent_steps = data.agent_step 371 | perf.env_time = env_profiler.elapsed 372 | perf.env_sps = int(agent_steps_collected / env_profiler.elapsed) 373 | perf.inference_time = inference_profiler.elapsed 374 | perf.inference_sps = int(padded_steps_collected / inference_profiler.elapsed) 375 | perf.eval_time = eval_profiler.elapsed 376 | perf.eval_sps = int(agent_steps_collected / eval_profiler.elapsed) 377 | perf.eval_memory = eval_profiler.end_mem 378 | perf.eval_pytorch_memory = eval_profiler.end_torch_mem 379 | perf.misc_time = misc_profiler.elapsed 380 | 381 | data.stats = {} 382 | 383 | # Get stats only from the learner 384 | for k, v in infos["learner"].items(): 385 | try: # TODO: Better checks on log data types 386 | # Skip the unnecessary info from the stats 387 | if not any(skip in k for skip in SKIP_LOG_KEYS): 388 | data.stats[k] = np.mean(v) 389 | except: 390 | continue 391 | 392 | if config.verbose: 393 | print_dashboard(data.stats, data.init_performance, data.performance) 394 | 395 | return data.stats, infos 396 | 397 | 398 | @pufferlib.utils.profile 399 | def train(data): 400 | if done_training(data): 401 | raise RuntimeError(f"Max training updates {data.total_updates} already reached") 402 | 403 | config = data.config 404 | # assert data.num_steps % bptt_horizon == 0, "num_steps must be divisible by bptt_horizon" 405 | train_profiler = pufferlib.utils.Profiler(memory=True, pytorch_memory=True) 406 | train_profiler.start() 407 | 408 | if config.anneal_lr: 409 | frac = 1.0 - (data.update - 1.0) / data.total_updates 410 | lrnow = frac * config.learning_rate 411 | data.optimizer.param_groups[0]["lr"] = lrnow 412 | 413 | num_minibatches = config.batch_size // config.bptt_horizon // config.batch_rows 414 | idxs = sorted(range(len(data.sort_keys)), key=data.sort_keys.__getitem__) 415 | data.sort_keys = [] 416 | b_idxs = ( 417 | torch.Tensor(idxs) 418 | .long()[:-1] 419 | .reshape(config.batch_rows, num_minibatches, config.bptt_horizon) 420 | .transpose(0, 1) 421 | ) 422 | 423 | # bootstrap value if not done 424 | with torch.no_grad(): 425 | advantages = torch.zeros(config.batch_size, device=data.device) 426 | lastgaelam = 0 427 | for t in reversed(range(config.batch_size)): 428 | i, i_nxt = idxs[t], idxs[t + 1] 429 | nextnonterminal = 1.0 - data.dones[i_nxt] 430 | nextvalues = data.values[i_nxt] 431 | delta = ( 432 | data.rewards[i_nxt] + config.gamma * nextvalues * nextnonterminal - data.values[i] 433 | ) 434 | advantages[t] = lastgaelam = ( 435 | delta + config.gamma * config.gae_lambda * nextnonterminal * lastgaelam 436 | ) 437 | 438 | # Flatten the batch 439 | data.b_obs = b_obs = torch.Tensor(data.obs_ary[b_idxs]) 440 | b_actions = torch.Tensor(data.actions_ary[b_idxs]).to(data.device, non_blocking=True) 441 | b_logprobs = torch.Tensor(data.logprobs_ary[b_idxs]).to(data.device, non_blocking=True) 442 | b_dones = torch.Tensor(data.dones_ary[b_idxs]).to(data.device, non_blocking=True) 443 | b_values = torch.Tensor(data.values_ary[b_idxs]).to(data.device, non_blocking=True) 444 | b_advantages = advantages.reshape( 445 | config.batch_rows, num_minibatches, config.bptt_horizon 446 | ).transpose(0, 1) 447 | b_returns = b_advantages + b_values 448 | 449 | # Optimizing the policy and value network 450 | train_time = time.time() 451 | pg_losses, entropy_losses, v_losses, clipfracs, old_kls, kls = [], [], [], [], [], [] 452 | mb_obs_buffer = torch.zeros_like(b_obs[0], pin_memory=(data.device == "cuda")) 453 | 454 | for epoch in range(config.update_epochs): 455 | lstm_state = None 456 | for mb in range(num_minibatches): 457 | mb_obs_buffer.copy_(b_obs[mb], non_blocking=True) 458 | mb_obs = mb_obs_buffer.to(data.device, non_blocking=True) 459 | # mb_obs = b_obs[mb].to(data.device, non_blocking=True) 460 | mb_actions = b_actions[mb].contiguous() 461 | mb_values = b_values[mb].reshape(-1) 462 | mb_advantages = b_advantages[mb].reshape(-1) 463 | mb_returns = b_returns[mb].reshape(-1) 464 | 465 | if hasattr(data.agent, "lstm"): 466 | _, newlogprob, entropy, newvalue, lstm_state = data.agent( 467 | mb_obs, state=lstm_state, action=mb_actions 468 | ) 469 | lstm_state = (lstm_state[0].detach(), lstm_state[1].detach()) 470 | else: 471 | _, newlogprob, entropy, newvalue = data.agent( 472 | mb_obs.reshape(-1, *data.pool.single_observation_space.shape), 473 | action=mb_actions, 474 | ) 475 | 476 | logratio = newlogprob - b_logprobs[mb].reshape(-1) 477 | ratio = logratio.exp() 478 | 479 | with torch.no_grad(): 480 | # calculate approx_kl http://joschu.net/blog/kl-approx.html 481 | old_approx_kl = (-logratio).mean() 482 | old_kls.append(old_approx_kl.item()) 483 | approx_kl = ((ratio - 1) - logratio).mean() 484 | kls.append(approx_kl.item()) 485 | clipfracs += [((ratio - 1.0).abs() > config.clip_coef).float().mean().item()] 486 | 487 | mb_advantages = mb_advantages.reshape(-1) 488 | if config.norm_adv: 489 | mb_advantages = (mb_advantages - mb_advantages.mean()) / ( 490 | mb_advantages.std() + 1e-8 491 | ) 492 | 493 | # Policy loss 494 | pg_loss1 = -mb_advantages * ratio 495 | pg_loss2 = -mb_advantages * torch.clamp( 496 | ratio, 1 - config.clip_coef, 1 + config.clip_coef 497 | ) 498 | pg_loss = torch.max(pg_loss1, pg_loss2).mean() 499 | pg_losses.append(pg_loss.item()) 500 | 501 | # Value loss 502 | newvalue = newvalue.view(-1) 503 | if config.clip_vloss: 504 | v_loss_unclipped = (newvalue - mb_returns) ** 2 505 | v_clipped = mb_values + torch.clamp( 506 | newvalue - mb_values, 507 | -config.vf_clip_coef, 508 | config.vf_clip_coef, 509 | ) 510 | v_loss_clipped = (v_clipped - mb_returns) ** 2 511 | v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) 512 | v_loss = 0.5 * v_loss_max.mean() 513 | else: 514 | v_loss = 0.5 * ((newvalue - mb_returns) ** 2).mean() 515 | v_losses.append(v_loss.item()) 516 | 517 | entropy_loss = entropy.mean() 518 | entropy_losses.append(entropy_loss.item()) 519 | 520 | loss = pg_loss - config.ent_coef * entropy_loss + v_loss * config.vf_coef 521 | data.optimizer.zero_grad() 522 | loss.backward() 523 | nn.utils.clip_grad_norm_(data.agent.parameters(), config.max_grad_norm) 524 | data.optimizer.step() 525 | 526 | if config.target_kl is not None: 527 | if approx_kl > config.target_kl: 528 | break 529 | 530 | train_profiler.stop() 531 | y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() 532 | var_y = np.var(y_true) 533 | explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y 534 | 535 | losses = data.losses 536 | losses.policy_loss = np.mean(pg_losses) 537 | losses.value_loss = np.mean(v_losses) 538 | losses.entropy = np.mean(entropy_losses) 539 | losses.old_approx_kl = np.mean(old_kls) 540 | losses.approx_kl = np.mean(kls) 541 | losses.clipfrac = np.mean(clipfracs) 542 | losses.explained_variance = explained_var 543 | 544 | perf = data.performance 545 | perf.total_uptime = int(time.time() - data.start_time) 546 | perf.total_updates = data.update + 1 547 | perf.train_time = time.time() - train_time 548 | perf.train_sps = int(config.batch_size / perf.train_time) 549 | perf.train_memory = train_profiler.end_mem 550 | perf.train_pytorch_memory = train_profiler.end_torch_mem 551 | perf.epoch_time = perf.eval_time + perf.train_time 552 | perf.epoch_sps = int(config.batch_size / perf.epoch_time) 553 | 554 | if config.verbose: 555 | print_dashboard(data.stats, data.init_performance, data.performance) 556 | 557 | data.update += 1 558 | if data.update % config.checkpoint_interval == 0 or done_training(data): 559 | save_checkpoint(data) 560 | 561 | 562 | def close(data): 563 | data.pool.close() 564 | 565 | if data.wandb is not None: 566 | artifact_name = f"{data.exp_name}_model" 567 | artifact = data.wandb.Artifact(artifact_name, type="model") 568 | model_path = save_checkpoint(data) 569 | artifact.add_file(model_path) 570 | data.wandb.run.log_artifact(artifact) 571 | data.wandb.finish() 572 | 573 | 574 | def done_training(data): 575 | return data.update >= data.total_updates 576 | 577 | 578 | def save_checkpoint(data): 579 | path = os.path.join(data.config.data_dir, data.exp_name) 580 | if not os.path.exists(path): 581 | os.makedirs(path) 582 | 583 | model_name = f"model_{data.update:06d}.pt" 584 | model_path = os.path.join(path, model_name) 585 | 586 | # Already saved 587 | if os.path.exists(model_path): 588 | return model_path 589 | 590 | torch.save(data.uncompiled_agent, model_path) 591 | 592 | state = { 593 | "optimizer_state_dict": data.optimizer.state_dict(), 594 | "global_step": data.global_step, 595 | "agent_step": data.agent_step, 596 | "update": data.update, 597 | "model_name": model_name, 598 | } 599 | 600 | if data.wandb: 601 | state["exp_name"] = data.exp_name 602 | 603 | state_path = os.path.join(path, "trainer_state.pt") 604 | torch.save(state, state_path + ".tmp") 605 | os.rename(state_path + ".tmp", state_path) 606 | 607 | if data.config.verbose: 608 | print(f"Model saved to {model_path}") 609 | 610 | return model_path 611 | 612 | 613 | def seed_everything(seed, torch_deterministic): 614 | random.seed(seed) 615 | np.random.seed(seed) 616 | if seed is not None: 617 | torch.manual_seed(seed) 618 | 619 | # NOTE: Deterministic torch operations tend to have worse performance than nondeterministic ones 620 | # https://pytorch.org/docs/2.1/generated/torch.use_deterministic_algorithms.html 621 | if torch_deterministic: 622 | torch.backends.cudnn.deterministic = torch_deterministic 623 | # os env var CUBLAS_WORKSPACE_CONFIG was set to :4096:8 624 | # See https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility 625 | torch.use_deterministic_algorithms(torch_deterministic) 626 | # With torch >= 2.2, check also https://pytorch.org/docs/2.2/deterministic.html 627 | # torch.utils.deterministic.fill_uninitialized_memory = torch_deterministic 628 | 629 | 630 | def unroll_nested_dict(d): 631 | if not isinstance(d, dict): 632 | return d 633 | 634 | for k, v in d.items(): 635 | if isinstance(v, dict): 636 | for k2, v2 in unroll_nested_dict(v): 637 | yield f"{k}/{k2}", v2 638 | else: 639 | yield k, v 640 | 641 | 642 | def print_dashboard(stats, init_performance, performance): 643 | output = [] 644 | data = {**init_performance, **performance} 645 | # Only show these stats in the dashboard 646 | if "length" in stats: 647 | data["length"] = stats["length"] 648 | 649 | grouped_data = defaultdict(dict) 650 | 651 | for k, v in data.items(): 652 | if k == "total_uptime": 653 | v = timedelta(seconds=v) 654 | if "memory" in k: 655 | v = pufferlib.utils.format_bytes(v) 656 | elif "time" in k: 657 | try: 658 | v = f"{v:.2f} s" 659 | except: 660 | pass 661 | 662 | first_word, *rest_words = k.split("_") 663 | rest_words = " ".join(rest_words).title() 664 | 665 | grouped_data[first_word][rest_words] = v 666 | 667 | for main_key, sub_dict in grouped_data.items(): 668 | output.append(f"{main_key.title()}") 669 | for sub_key, sub_value in sub_dict.items(): 670 | output.append(f" {sub_key}: {sub_value}") 671 | 672 | print("\033c", end="") 673 | print("\n".join(output)) 674 | time.sleep(1 / 20) 675 | -------------------------------------------------------------------------------- /reinforcement_learning/environment.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | import nmmo 4 | import nmmo.core.config as nc 5 | import nmmo.core.game_api as ng 6 | import pufferlib 7 | import pufferlib.emulation 8 | from pettingzoo.utils.wrappers.base_parallel import BaseParallelWrapper 9 | from syllabus.core import PettingZooMultiProcessingSyncWrapper as SyllabusSyncWrapper 10 | 11 | from syllabus_wrapper import SyllabusTaskWrapper 12 | 13 | 14 | class Config( 15 | nc.Medium, 16 | nc.Terrain, 17 | nc.Resource, 18 | nc.Combat, 19 | nc.NPC, 20 | nc.Progression, 21 | nc.Item, 22 | nc.Equipment, 23 | nc.Profession, 24 | nc.Exchange, 25 | ): 26 | """Configuration for Neural MMO.""" 27 | 28 | def __init__(self, env_args: Namespace): 29 | super().__init__() 30 | 31 | self.set("PROVIDE_ACTION_TARGETS", True) 32 | self.set("PROVIDE_NOOP_ACTION_TARGET", True) 33 | self.set("MAP_FORCE_GENERATION", env_args.map_force_generation) 34 | self.set("PLAYER_N", env_args.num_agents) 35 | self.set("HORIZON", env_args.max_episode_length) 36 | self.set("MAP_N", env_args.num_maps) 37 | self.set( 38 | "PLAYER_DEATH_FOG", 39 | env_args.death_fog_tick if isinstance(env_args.death_fog_tick, int) else None, 40 | ) 41 | self.set("PATH_MAPS", f"{env_args.maps_path}/{env_args.map_size}/") 42 | self.set("MAP_CENTER", env_args.map_size) 43 | self.set("NPC_N", env_args.num_npcs) 44 | self.set("TASK_EMBED_DIM", env_args.task_size) 45 | self.set("RESOURCE_RESILIENT_POPULATION", env_args.resilient_population) 46 | self.set("COMBAT_SPAWN_IMMUNITY", env_args.spawn_immunity) 47 | 48 | self.set("GAME_PACKS", [(ng.AgentTraining, 1)]) 49 | self.set("CURRICULUM_FILE_PATH", env_args.curriculum_file_path) 50 | 51 | 52 | def make_env_creator( 53 | reward_wrapper_cls: BaseParallelWrapper, syllabus_wrapper=False, syllabus=None 54 | ): 55 | def env_creator(*args, **kwargs): 56 | """Create an environment.""" 57 | env = nmmo.Env(Config(kwargs["env"])) # args.env is provided as kwargs 58 | env = reward_wrapper_cls(env, **kwargs["reward_wrapper"]) 59 | 60 | # Add Syllabus task wrapper 61 | if syllabus_wrapper or syllabus is not None: 62 | env = SyllabusTaskWrapper(env) 63 | 64 | # Use syllabus curriculum if provided 65 | if syllabus is not None: 66 | env = SyllabusSyncWrapper( 67 | env, 68 | syllabus.get_components(), 69 | update_on_step=False, 70 | task_space=env.task_space, 71 | ) 72 | 73 | env = pufferlib.emulation.PettingZooPufferEnv(env) 74 | return env 75 | 76 | return env_creator 77 | -------------------------------------------------------------------------------- /reinforcement_learning/stat_wrapper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pettingzoo.utils.wrappers.base_parallel import BaseParallelWrapper 4 | 5 | from nmmo.lib.event_code import EventCode 6 | import nmmo.systems.item as Item 7 | 8 | 9 | class BaseStatWrapper(BaseParallelWrapper): 10 | def __init__( 11 | self, 12 | env, 13 | eval_mode=False, 14 | early_stop_agent_num=0, 15 | stat_prefix=None, 16 | use_custom_reward=True, 17 | ): 18 | super().__init__(env) 19 | self.env_done = False 20 | self.early_stop_agent_num = early_stop_agent_num 21 | self.eval_mode = eval_mode 22 | self._reset_episode_stats() 23 | self._stat_prefix = stat_prefix 24 | self.use_custom_reward = use_custom_reward 25 | 26 | def seed(self, seed): 27 | self.env.seed(seed) 28 | 29 | def observation(self, agent_id, agent_obs): 30 | """Called before observations are returned from the environment 31 | Use this to define custom featurizers. Changing the space itself requires you to 32 | define the observation space again (i.e. Gym.spaces.Dict(gym.spaces....))""" 33 | return agent_obs 34 | 35 | def action(self, agent_id, agent_atn): 36 | """Called before actions are passed from the model to the environment""" 37 | return agent_atn 38 | 39 | def reward_terminated_truncated_info(self, agent_id, reward, terminated, truncated, info): 40 | """Called on reward, terminated, truncated, and info before they are returned from the environment 41 | Use this to define custom reward shaping.""" 42 | return reward, terminated, truncated, info 43 | 44 | @property 45 | def agents(self): 46 | return [] if self.env_done else self.env.agents 47 | 48 | def reset(self, **kwargs): 49 | """Called at the start of each episode""" 50 | self._reset_episode_stats() 51 | obs, info = self.env.reset(**kwargs) 52 | 53 | for agent_id in self.env.agents: 54 | obs[agent_id] = self.observation(agent_id, obs[agent_id]) 55 | return obs, info 56 | 57 | def step(self, action): 58 | assert len(self.env.agents) > 0, "No agents in the environment" # xcxc: sanity check 59 | 60 | # Modify actions before they are passed to the environment 61 | for agent_id in self.env.agents: 62 | action[agent_id] = self.action(agent_id, action[agent_id]) 63 | 64 | obs, rewards, terms, truncs, infos = self.env.step(action) 65 | 66 | # Stop early if there are too few agents generating the training data 67 | # Also env.agents is empty when the tick reaches the config horizon 68 | if len(self.env.agents) <= self.early_stop_agent_num: 69 | self.env_done = True 70 | 71 | # Modify reward and observation after they are returned from the environment 72 | agent_list = list(obs.keys()) 73 | for agent_id in agent_list: 74 | trunc, info = self._process_stats_and_early_stop( 75 | agent_id, rewards[agent_id], terms[agent_id], truncs[agent_id], infos[agent_id] 76 | ) 77 | 78 | if self.use_custom_reward is True: 79 | rew, term, trunc, info = self.reward_terminated_truncated_info( 80 | agent_id, rewards[agent_id], terms[agent_id], trunc, info 81 | ) 82 | else: 83 | # NOTE: Also disable death penalty, which is not from the task 84 | rew = 0 if terms[agent_id] is True else rewards[agent_id] 85 | term = terms[agent_id] 86 | 87 | rewards[agent_id] = rew 88 | terms[agent_id] = term 89 | truncs[agent_id] = trunc 90 | infos[agent_id] = info 91 | obs[agent_id] = self.observation(agent_id, obs[agent_id]) 92 | 93 | if self.env_done: 94 | # To mark the end of the episode. Only one agent's done flag is enough. 95 | infos[agent_id]["episode_done"] = True 96 | 97 | return obs, rewards, terms, truncs, infos 98 | 99 | # def reward_done_truncated_info(self, agent_id, reward, don 100 | def _reset_episode_stats(self): 101 | self.env_done = False 102 | self.cum_rewards = {agent_id: 0 for agent_id in self.env.possible_agents} 103 | self._unique_events = { 104 | agent_id: { 105 | "experienced": set(), 106 | "prev_count": 0, 107 | "curr_count": 0, 108 | } 109 | for agent_id in self.env.possible_agents 110 | } 111 | 112 | def _process_stats_and_early_stop(self, agent_id, reward, terminated, truncated, info): 113 | """Update stats + info and save replays.""" 114 | # Remove the task from info. Curriculum info is processed in _update_stats() 115 | info.pop("task", None) 116 | 117 | # Handle early stopping 118 | if self.env_done and not terminated: 119 | truncated = True 120 | 121 | # Count and store unique event counts for easier use 122 | realm = self.env.realm 123 | tick_log = realm.event_log.get_data(agents=[agent_id], tick=-1) # get only the last tick 124 | uniq = self._unique_events[agent_id] 125 | uniq["prev_count"] = uniq["curr_count"] 126 | uniq["curr_count"] += count_unique_events(tick_log, uniq["experienced"]) 127 | 128 | if not (terminated or truncated): 129 | self.cum_rewards[agent_id] += reward 130 | return truncated, info 131 | 132 | # The agent is terminated or truncated, so recoding the episode stats 133 | if "stats" not in info: 134 | info["stats"] = {} 135 | 136 | agent = realm.players.dead_this_tick.get(agent_id, realm.players.get(agent_id)) 137 | assert agent is not None 138 | 139 | # NOTE: this may not be true when players can be resurrected. Check back later 140 | info["length"] = realm.tick 141 | 142 | info["return"] = self.cum_rewards[agent_id] 143 | 144 | # Cause of Deaths 145 | if terminated: 146 | info["stats"]["cod/attacked"] = 1.0 if agent.damage.val > 0 else 0.0 147 | info["stats"]["cod/starved"] = 1.0 if agent.food.val == 0 else 0.0 148 | info["stats"]["cod/dehydrated"] = 1.0 if agent.water.val == 0 else 0.0 149 | else: 150 | info["stats"]["cod/attacked"] = 0 151 | info["stats"]["cod/starved"] = 0 152 | info["stats"]["cod/dehydrated"] = 0 153 | 154 | # Task-related stats 155 | task = self.env.agent_task_map[agent_id][0] # consider only the first task 156 | info["stats"]["task/completed"] = 1.0 if task.completed else 0.0 157 | info["stats"]["task/pcnt_2_reward_signal"] = 1.0 if task.reward_signal_count >= 2 else 0.0 158 | info["stats"]["task/pcnt_0p2_max_progress"] = 1.0 if task._max_progress >= 0.2 else 0.0 159 | info["curriculum"] = {task.spec_name: (task._max_progress, task.reward_signal_count)} 160 | 161 | if self.eval_mode: 162 | # 'return' is used for ranking in the eval mode, so put the task progress here 163 | info["return"] = task._max_progress # this is 1 if done 164 | 165 | # Max combat/harvest level achieved 166 | info["stats"]["achieved/max_combat_level"] = agent.attack_level 167 | info["stats"]["achieved/max_harvest_skill_ammo"] = max( 168 | agent.prospecting_level.val, 169 | agent.carving_level.val, 170 | agent.alchemy_level.val, 171 | ) 172 | info["stats"]["achieved/max_harvest_skill_consum"] = max( 173 | agent.fishing_level.val, 174 | agent.herbalism_level.val, 175 | ) 176 | 177 | # Event-based stats 178 | achieved, performed, _ = process_event_log(realm, [agent_id]) 179 | for key, val in list(achieved.items()) + list(performed.items()): 180 | info["stats"][key] = float(val) 181 | 182 | if self._stat_prefix: 183 | info = {self._stat_prefix: info} 184 | 185 | return truncated, info 186 | 187 | 188 | ################################################################################ 189 | # Event processing utilities for Neural MMO. 190 | 191 | INFO_KEY_TO_EVENT_CODE = { 192 | "event/" + evt.lower(): val for evt, val in EventCode.__dict__.items() if isinstance(val, int) 193 | } 194 | 195 | # convert the numbers into binary (performed or not) for the key events 196 | KEY_EVENT = [ 197 | "eat_food", 198 | "drink_water", 199 | "score_hit", 200 | "player_kill", 201 | "consume_item", 202 | "harvest_item", 203 | "list_item", 204 | "buy_item", 205 | ] 206 | 207 | ITEM_TYPE = { 208 | "armor": [item.ITEM_TYPE_ID for item in Item.ARMOR], 209 | "weapon": [item.ITEM_TYPE_ID for item in Item.WEAPON], 210 | "tool": [item.ITEM_TYPE_ID for item in Item.TOOL], 211 | "ammo": [item.ITEM_TYPE_ID for item in Item.AMMUNITION], 212 | "consumable": [item.ITEM_TYPE_ID for item in Item.CONSUMABLE], 213 | } 214 | 215 | 216 | def process_event_log(realm, agent_list): 217 | """Process the event log and extract performed actions and achievements.""" 218 | log = realm.event_log.get_data(agents=agent_list) 219 | attr_to_col = realm.event_log.attr_to_col 220 | 221 | # count the number of events 222 | event_cnt = {} 223 | for key, code in INFO_KEY_TO_EVENT_CODE.items(): 224 | # count the freq of each event 225 | event_cnt[key] = int(sum(log[:, attr_to_col["event"]] == code)) 226 | 227 | # record true or false for each event 228 | performed = {} 229 | for evt in KEY_EVENT: 230 | key = "event/" + evt 231 | performed[key] = event_cnt[key] > 0 232 | 233 | # check if tools, weapons, ammos, ammos were equipped 234 | for item_type, item_ids in ITEM_TYPE.items(): 235 | if item_type == "consumable": 236 | continue 237 | key = "event/equip_" + item_type 238 | idx = (log[:, attr_to_col["event"]] == EventCode.EQUIP_ITEM) & np.in1d( 239 | log[:, attr_to_col["item_type"]], item_ids 240 | ) 241 | performed[key] = sum(idx) > 0 242 | 243 | # check if weapon was harvested 244 | key = "event/harvest_weapon" 245 | idx = (log[:, attr_to_col["event"]] == EventCode.HARVEST_ITEM) & np.in1d( 246 | log[:, attr_to_col["item_type"]], ITEM_TYPE["weapon"] 247 | ) 248 | performed[key] = sum(idx) > 0 249 | 250 | # record important achievements 251 | achieved = {} 252 | 253 | # get progress to center 254 | idx = log[:, attr_to_col["event"]] == EventCode.GO_FARTHEST 255 | achieved["achieved/max_progress_to_center"] = ( 256 | int(max(log[idx, attr_to_col["distance"]])) if sum(idx) > 0 else 0 257 | ) 258 | 259 | # get earned gold 260 | idx = log[:, attr_to_col["event"]] == EventCode.EARN_GOLD 261 | achieved["achieved/earned_gold"] = int(sum(log[idx, attr_to_col["gold"]])) 262 | 263 | # get max damage 264 | idx = log[:, attr_to_col["event"]] == EventCode.SCORE_HIT 265 | achieved["achieved/max_damage"] = ( 266 | int(max(log[idx, attr_to_col["damage"]])) if sum(idx) > 0 else 0 267 | ) 268 | 269 | # get max possessed item levels: from harvesting, looting, buying 270 | idx = np.in1d( 271 | log[:, attr_to_col["event"]], 272 | [EventCode.HARVEST_ITEM, EventCode.LOOT_ITEM, EventCode.BUY_ITEM], 273 | ) 274 | if sum(idx) > 0: 275 | for item_type, item_ids in ITEM_TYPE.items(): 276 | idx_item = np.in1d(log[idx, attr_to_col["item_type"]], item_ids) 277 | if sum(idx_item) > 0: 278 | achieved["achieved/max_" + item_type + "_level"] = int( 279 | max(log[idx][idx_item, attr_to_col["level"]]) 280 | ) 281 | 282 | # other notable achievements 283 | idx = log[:, attr_to_col["event"]] == EventCode.PLAYER_KILL 284 | achieved["achieved/agent_kill_count"] = int(sum(idx & (log[:, attr_to_col["target_ent"]] > 0))) 285 | achieved["achieved/npc_kill_count"] = int(sum(idx & (log[:, attr_to_col["target_ent"]] < 0))) 286 | achieved["achieved/unique_events"] = count_unique_events(log, set()) 287 | 288 | return achieved, performed, event_cnt 289 | 290 | 291 | # These events are important, so count them even though they are not unique 292 | EVERY_EVENT_TO_COUNT = set([EventCode.PLAYER_KILL, EventCode.EARN_GOLD]) 293 | 294 | 295 | def count_unique_events(tick_log, experienced, every_event_to_count=EVERY_EVENT_TO_COUNT): 296 | cnt_unique = 0 297 | if len(tick_log) == 0: 298 | return cnt_unique 299 | 300 | for row in tick_log[:, 3:6]: # only taking the event, type, level cols 301 | event = tuple(row) 302 | if event not in experienced: 303 | experienced.add(event) 304 | cnt_unique += 1 305 | 306 | elif row[0] in every_event_to_count: 307 | # These events are important, so count them even though they are not unique 308 | cnt_unique += 1 309 | 310 | return cnt_unique 311 | -------------------------------------------------------------------------------- /scripts/evaluate_policies.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -u -O -m tools.evaluate \ 4 | --model.policy_pool=/fsx/home-daveey/experiments/pool.json \ 5 | --env.num_npcs=256 \ 6 | --eval.num_rounds=1000000 \ 7 | --eval.num_policies=8 \ 8 | "${@}" 9 | -------------------------------------------------------------------------------- /scripts/pre-git-check.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo 4 | echo "Checking pytest, pylint, xcxc without touching git" 5 | echo 6 | 7 | # Run unit tests 8 | echo 9 | echo "--------------------------------------------------------------------" 10 | echo "Running unit tests..." 11 | if ! pytest; then 12 | echo "Unit tests failed. Exiting." 13 | exit 1 14 | fi 15 | 16 | # Run linter 17 | echo "--------------------------------------------------------------------" 18 | echo "Running linter..." 19 | files=$(git ls-files -m -o --exclude-standard '*.py') 20 | for file in $files; do 21 | if test -e $file; then 22 | echo $file 23 | if ! pylint --score=no --fail-under=10 $file; then 24 | echo "Lint failed. Exiting." 25 | exit 1 26 | fi 27 | fi 28 | done 29 | 30 | # if ! pylint --recursive=y config feature_extractor model tests; then 31 | # echo "Lint failed. Exiting." 32 | # exit 1 33 | # fi 34 | 35 | # Check if there are any "xcxc" strings in the code 36 | echo "--------------------------------------------------------------------" 37 | echo "Looking for xcxc..." 38 | files=$(find . -name '*.py') 39 | for file in $files; do 40 | if grep -q 'xcxc' $file; then 41 | echo "Found xcxc in $file!" >&2 42 | read -p "Do you like to stop here? (y/n) " ans 43 | if [ "$ans" = "y" ]; then 44 | exit 1 45 | fi 46 | fi 47 | done 48 | 49 | echo 50 | echo "Pre-git checks look good!" 51 | echo -------------------------------------------------------------------------------- /scripts/slurm_jobs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Get all running jobs 4 | running_jobs=$(sacct | grep RUN | grep -v .batch | awk '{print $1}') 5 | 6 | # Loop through each job ID 7 | for job_id in $running_jobs; do 8 | # Find the log file 9 | log_file=$(find sbatch/ -name "${job_id}.log") 10 | 11 | # Grep the experiment directory 12 | experiment_dir=$(grep -m 1 "Training run" "$log_file" | awk '{print $3}') 13 | 14 | # Print job_id and experiment_dir 15 | echo "$job_id, $experiment_dir" 16 | done 17 | -------------------------------------------------------------------------------- /scripts/slurm_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Example ussage: 4 | # 5 | # sbatch ./scripts/slurm_run.sh scripts/train_baseline.sh \ 6 | # --run-name=test --wandb-project=nmmo --wandb-entity=kywch 7 | 8 | #SBATCH --account=carperai 9 | #SBATCH --partition=g40x 10 | #SBATCH --nodes=1 11 | #SBATCH --gpus=1 12 | #SBATCH --cpus-per-gpu=8 13 | #__SBATCH --mem=80G 14 | #SBATCH --chdir=/weka/proj-nmmo/nmmo-baselines/ 15 | #SBATCH --output=sbatch/%j.log 16 | #SBATCH --error=sbatch/%j.log 17 | #SBATCH --requeue 18 | #SBATCH --export=PYTHONUNBUFFERED=1,WANDB_DIR=/weka/proj-nmmo/tmp/wandb,WANDB_CONFIG_DIR=/weka/proj-nmmo/tmp/wandb 19 | 20 | source /weka/proj-nmmo/venv/bin/activate && \ 21 | ulimit -c unlimited && \ 22 | ulimit -s unlimited && \ 23 | ulimit -a 24 | 25 | wandb login --cloud 26 | 27 | # Extract run_name from the arguments 28 | run_name="" 29 | args=() 30 | for i in "$@" 31 | do 32 | case $i in 33 | --train.run_name=*) 34 | run_name="${i#*=}" 35 | args+=("$i") 36 | shift 37 | ;; 38 | *) 39 | args+=("$i") 40 | shift 41 | ;; 42 | esac 43 | done 44 | 45 | # Create symlink to the log file 46 | if [ ! -z "$run_name" ]; then 47 | logfile="$SLURM_JOB_ID.log" 48 | symlink="sbatch/${run_name}.log" 49 | if [ -L "$symlink" ]; then 50 | rm "$symlink" 51 | fi 52 | ln -s "$logfile" "$symlink" 53 | fi 54 | 55 | max_retries=50 56 | retry_count=0 57 | 58 | while true; do 59 | stdbuf -oL -eL "${args[@]}" 60 | 61 | exit_status=$? 62 | echo "Job exited with status $exit_status." 63 | 64 | if [ $exit_status -eq 0 ]; then 65 | echo "Job completed successfully." 66 | break 67 | elif [ $exit_status -eq 101 ]; then 68 | echo "Job failed due to torch.cuda.OutOfMemoryError." 69 | elif [ $exit_status -eq 137 ]; then 70 | echo "Job failed due to OOM. Killing child processes..." 71 | 72 | # Killing child processes 73 | child_pids=$(pgrep -P $$) # This fetches all child processes of the current process 74 | if [ "$child_pids" != "" ]; then 75 | echo "The following child processes will be killed:" 76 | for pid in $child_pids; do 77 | echo "Child PID $pid: $(ps -p $pid -o cmd=)" 78 | done 79 | kill $child_pids # This kills the child processes 80 | fi 81 | 82 | # Killing processes that have the run name in their command line 83 | run_pids=$(pgrep -f "python.*$run_name") 84 | if [ "$run_pids" != "" ]; then 85 | echo "The following processes with '$run_name' will be killed:" 86 | for pid in $run_pids; do 87 | echo "Experiment PID $pid: $(ps -p $pid -o cmd=)" 88 | done 89 | kill $run_pids # This kills the processes 90 | fi 91 | elif [ $exit_status -eq 143 ]; then 92 | echo "Killing Zombie processes..." 93 | pids=$(pgrep -P $$) 94 | for pid in $pids; do 95 | if [ $(ps -o stat= -p $pid) == "Z" ]; then 96 | kill -9 $pid 97 | echo "Killed zombie process $pid" 98 | fi 99 | done 100 | fi 101 | retry_count=$((retry_count + 1)) 102 | if [ $retry_count -gt $max_retries ]; then 103 | echo "Job failed with exit status $exit_status. Maximum retries exceeded. Exiting..." 104 | break 105 | fi 106 | echo "Job failed with exit status $exit_status. Retrying in 10 seconds..." 107 | sleep 10 108 | done 109 | 110 | echo "Slurm Job completed." 111 | -------------------------------------------------------------------------------- /scripts/slurm_run_cpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Example ussage: 4 | # 5 | # sbatch ./scripts/slurm_run.sh scripts/train_baseline.sh \ 6 | # --train.experiment_name=realikun_16x8_0001 7 | 8 | #SBATCH --comment=carperai 9 | #SBATCH --partition=cpu128 10 | #SBATCH --nodes=1 11 | #SBATCH --mem=40G 12 | #SBATCH --chdir=/fsx/home-daveey/nmmo-baselines/ 13 | #SBATCH --output=sbatch/%j.log 14 | #SBATCH --error=sbatch/%j.log 15 | #SBATCH --requeue 16 | #SBATCH --export=PYTHONUNBUFFERED=1 17 | 18 | source /fsx/home-daveey/miniconda3/etc/profile.d/conda.sh && \ 19 | conda activate nmmo && \ 20 | ulimit -c unlimited && \ 21 | ulimit -s unlimited && \ 22 | ulimit -a 23 | 24 | # Extract experiment_name from the arguments 25 | experiment_name="" 26 | args=() 27 | for i in "$@" 28 | do 29 | case $i in 30 | --train.experiment_name=*) 31 | experiment_name="${i#*=}" 32 | args+=("$i") 33 | shift 34 | ;; 35 | *) 36 | args+=("$i") 37 | shift 38 | ;; 39 | esac 40 | done 41 | 42 | # Create symlink to the log file 43 | if [ ! -z "$experiment_name" ]; then 44 | logfile="$SLURM_JOB_ID.log" 45 | symlink="sbatch/${experiment_name}.log" 46 | if [ -L "$symlink" ]; then 47 | rm "$symlink" 48 | fi 49 | ln -s "$logfile" "$symlink" 50 | fi 51 | 52 | while true; do 53 | stdbuf -oL -eL "${args[@]}" 54 | 55 | exit_status=$? 56 | echo "Job exited with status $exit_status." 57 | 58 | if [ $exit_status -eq 0 ]; then 59 | echo "Job completed successfully." 60 | break 61 | elif [ $exit_status -eq 101 ]; then 62 | echo "Job failed due to torch.cuda.OutOfMemoryError." 63 | break 64 | else 65 | echo "Job failed with exit status $exit_status. Retrying..." 66 | fi 67 | done 68 | 69 | echo "Slurm Job completed." 70 | -------------------------------------------------------------------------------- /scripts/train_baseline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -u -m train \ 4 | --train.data_dir=/weka/proj-nmmo/runs/ \ 5 | "${@}" 6 | -------------------------------------------------------------------------------- /scripts/upload_checkpoints.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # define the root directory 4 | root_dir="../experiments" 5 | 6 | # define the S3 bucket 7 | s3_bucket="s3://nmmo" 8 | 9 | # iterate over all subdirectories 10 | for dir in $(find $root_dir -type d); do 11 | 12 | # check if there are any .pt files in the directory 13 | if ls $dir/*.pt > /dev/null 2>&1; then 14 | 15 | # get the .pt files, sorted by name 16 | all_files=$(ls -v $dir/*.pt) 17 | files_to_keep=$(ls -v $dir/*.pt | tail -n 3) 18 | 19 | # iterate over all the files 20 | for file in $all_files; do 21 | 22 | # define the object path on S3 23 | s3_object_path="${s3_bucket}/$(basename $dir)/$(basename $file)" 24 | 25 | # check if the file is one of the files to keep 26 | if echo $files_to_keep | grep -q $file; then 27 | 28 | # check if the file exists on S3 29 | if ! aws s3 ls $s3_object_path > /dev/null 2>&1; then 30 | 31 | # upload the file to S3 32 | aws s3 cp $file $s3_object_path 33 | echo "Uploaded $file to $s3_object_path" 34 | 35 | else 36 | echo "File $s3_object_path already exists in S3, not uploading" 37 | 38 | fi 39 | 40 | else 41 | 42 | # remove the file 43 | rm $file 44 | echo "Deleted $file" 45 | 46 | fi 47 | done 48 | fi 49 | done 50 | -------------------------------------------------------------------------------- /scripts/upload_latest_checkpoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Check if experiment name is provided 4 | if [ -z "$1" ] 5 | then 6 | echo "No experiment name provided. Usage: ./scriptname.sh " 7 | exit 1 8 | fi 9 | 10 | # Define the directory where the experiment results are 11 | EXP_NAME=$1 12 | EXP_DIR="../experiments/$EXP_NAME" 13 | 14 | # Define the target s3 bucket 15 | S3_BUCKET="s3://nmmo/model_weights" 16 | 17 | # Find the latest checkpoint file 18 | LATEST_FILE=$(find $EXP_DIR -name "*.pt" -type f -printf "%T@ %p\n" | sort -n | tail -1 | cut -f2- -d" ") 19 | 20 | # If a checkpoint file was found 21 | if [ -n "$LATEST_FILE" ]; then 22 | # Extract filename 23 | FILENAME=$(basename -- "$LATEST_FILE") 24 | 25 | # Construct the target S3 path 26 | S3_PATH="$S3_BUCKET/$EXP_NAME.$FILENAME" 27 | 28 | # Copy the latest file to s3 29 | aws s3 cp $LATEST_FILE $S3_PATH 30 | echo "Copied $LATEST_FILE to $S3_PATH" 31 | else 32 | echo "No checkpoint files found." 33 | fi 34 | -------------------------------------------------------------------------------- /tests/test_task_encoder.py: -------------------------------------------------------------------------------- 1 | import random 2 | import unittest 3 | 4 | import curriculum_generation.manual_curriculum 5 | from curriculum_generation.task_encoder import TaskEncoder 6 | 7 | LLM_CHECKPOINT = "deepseek-ai/deepseek-coder-1.3b-instruct" 8 | CURRICULUM_FILE_PATH = "curriculum_with_embedding.pkl" 9 | 10 | # NOTE: different LLMs will give different embedding dimensions 11 | EMBEDDING_DIM = 4096 12 | 13 | 14 | class TestTaskEncoder(unittest.TestCase): 15 | @classmethod 16 | def setUpClass(cls): 17 | cls.task_encoder = TaskEncoder( 18 | LLM_CHECKPOINT, curriculum_generation.manual_curriculum, batch_size=4 19 | ) 20 | 21 | @classmethod 22 | def tearDownClass(cls): 23 | cls.task_encoder.close() 24 | 25 | def test_embed_dim(self): 26 | self.assertEqual(self.task_encoder.embed_dim, EMBEDDING_DIM) 27 | 28 | def test_task_encoder_api(self): 29 | task_spec_with_embedding = self.task_encoder.get_task_embedding( 30 | curriculum_generation.manual_curriculum.curriculum, save_to_file=CURRICULUM_FILE_PATH 31 | ) 32 | 33 | for single_spec in task_spec_with_embedding: 34 | self.assertFalse(sum(single_spec.embedding) == 0) 35 | 36 | def test_get_task_deps_src(self): 37 | custom_fn = curriculum_generation.manual_curriculum.PracticeInventoryManagement 38 | fn_src, deps_src = self.task_encoder._get_task_deps_src(custom_fn) 39 | 40 | self.assertEqual( 41 | fn_src, 42 | "def PracticeInventoryManagement(gs, subject, space, num_tick):\n " 43 | + "return InventorySpaceGE(gs, subject, space) * TickGE(gs, subject, num_tick)\n", 44 | ) 45 | self.assertTrue("def InventorySpaceGE(" in deps_src) 46 | self.assertTrue("def TickGE(" in deps_src) 47 | 48 | def test_contruct_prompt(self): 49 | single_spec = random.choice(curriculum_generation.manual_curriculum.curriculum) 50 | prompt = self.task_encoder._construct_prompt( 51 | single_spec.reward_to, single_spec.eval_fn, single_spec.eval_fn_kwargs 52 | ) 53 | print(prompt) 54 | 55 | 56 | if __name__ == "__main__": 57 | unittest.main() 58 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import importlib 3 | import inspect 4 | import logging 5 | import sys 6 | import time 7 | 8 | import pufferlib 9 | import yaml 10 | 11 | from reinforcement_learning import environment 12 | from train_helper import generate_replay, init_wandb, sweep, train 13 | import syllabus_wrapper 14 | 15 | DEBUG = False 16 | # See curriculum_generation/manual_curriculum.py for details 17 | BASELINE_CURRICULUM = "curriculum_generation/curriculum_with_embedding.pkl" 18 | 19 | 20 | def load_from_config(agent, debug=False): 21 | with open("config.yaml") as f: 22 | config = yaml.safe_load(f) 23 | default_keys = ( 24 | "env train policy recurrent sweep_metadata sweep_metric sweep wandb reward_wrapper".split() 25 | ) 26 | defaults = {key: config.get(key, {}) for key in default_keys} 27 | 28 | debug_config = config.get("debug", {}) if debug else {} 29 | agent_config = config[agent] 30 | 31 | combined_config = {} 32 | for key in default_keys: 33 | agent_subconfig = agent_config.get(key, {}) 34 | debug_subconfig = debug_config.get(key, {}) 35 | combined_config[key] = {**defaults[key], **agent_subconfig, **debug_subconfig} 36 | 37 | return pufferlib.namespace(**combined_config) 38 | 39 | 40 | def get_init_args(fn): 41 | if fn is None: 42 | return {} 43 | sig = inspect.signature(fn) 44 | args = {} 45 | for name, param in sig.parameters.items(): 46 | if name in ("self", "env", "policy"): 47 | continue 48 | if name in ("agent_id", "is_multiagent"): # Postprocessor args 49 | continue 50 | if param.kind == inspect.Parameter.VAR_POSITIONAL: 51 | continue 52 | elif param.kind == inspect.Parameter.VAR_KEYWORD: 53 | continue 54 | else: 55 | args[name] = param.default if param.default is not inspect.Parameter.empty else None 56 | return args 57 | 58 | 59 | def setup_agent(module_name): 60 | try: 61 | agent_module = importlib.import_module(f"agent_zoo.{module_name}") 62 | except ModuleNotFoundError: 63 | raise ValueError(f"Agent module {module_name} not found under the agent_zoo directory.") 64 | 65 | recurrent_policy = getattr(agent_module, "Recurrent", None) 66 | 67 | def agent_creator(env, args): 68 | policy = agent_module.Policy(env, **args.policy) 69 | if not args.no_recurrence and recurrent_policy is not None: 70 | policy = recurrent_policy(env, policy, **args.recurrent) 71 | policy = pufferlib.frameworks.cleanrl.RecurrentPolicy(policy) 72 | else: 73 | policy = pufferlib.frameworks.cleanrl.Policy(policy) 74 | return policy.to(args.train.device) 75 | 76 | init_args = { 77 | "policy": get_init_args(agent_module.Policy.__init__), 78 | "recurrent": get_init_args(agent_module.Recurrent.__init__), 79 | "reward_wrapper": get_init_args(agent_module.RewardWrapper.__init__), 80 | } 81 | 82 | return agent_module, agent_creator, init_args 83 | 84 | 85 | def combine_config_args(parser, args, config): 86 | clean_parser = argparse.ArgumentParser(parents=[parser]) 87 | for name, sub_config in config.items(): 88 | args[name] = {} 89 | for key, value in sub_config.items(): 90 | data_key = f"{name}.{key}" 91 | cli_key = f"--{data_key}".replace("_", "-") 92 | if isinstance(value, bool) and value is False: 93 | parser.add_argument(cli_key, default=value, action="store_true") 94 | clean_parser.add_argument(cli_key, default=value, action="store_true") 95 | elif isinstance(value, bool) and value is True: 96 | data_key = f"{name}.no_{key}" 97 | cli_key = f"--{data_key}".replace("_", "-") 98 | parser.add_argument(cli_key, default=value, action="store_false") 99 | clean_parser.add_argument(cli_key, default=value, action="store_false") 100 | else: 101 | parser.add_argument(cli_key, default=value, type=type(value)) 102 | clean_parser.add_argument(cli_key, default=value, metavar="", type=type(value)) 103 | 104 | args[name][key] = getattr(parser.parse_known_args()[0], data_key) 105 | args[name] = pufferlib.namespace(**args[name]) 106 | 107 | clean_parser.parse_args(sys.argv[1:]) 108 | return args 109 | 110 | 111 | def update_args(args, mode=None): 112 | args = pufferlib.namespace(**args) 113 | 114 | args.track = not args.no_track 115 | args.env.curriculum_file_path = args.curriculum 116 | 117 | vec = args.vectorization 118 | if vec == "serial" or args.debug: 119 | args.vectorization = pufferlib.vectorization.Serial 120 | elif vec == "multiprocessing": 121 | args.vectorization = pufferlib.vectorization.Multiprocessing 122 | elif vec == "ray": 123 | args.vectorization = pufferlib.vectorization.Ray 124 | else: 125 | raise ValueError("Invalid --vectorization (serial/multiprocessing/ray).") 126 | 127 | # TODO: load the trained baseline from wandb 128 | # elif args.baseline: 129 | # args.track = True 130 | # version = '.'.join(pufferlib.__version__.split('.')[:2]) 131 | # args.exp_name = f'puf-{version}-nmmo' 132 | # args.wandb_group = f'puf-{version}-baseline' 133 | # shutil.rmtree(f'experiments/{args.exp_name}', ignore_errors=True) 134 | # run = init_wandb(args, resume=False) 135 | # if args.mode == 'evaluate': 136 | # model_name = f'puf-{version}-nmmo_model:latest' 137 | # artifact = run.use_artifact(model_name) 138 | # data_dir = artifact.download() 139 | # model_file = max(os.listdir(data_dir)) 140 | # args.eval_model_path = os.path.join(data_dir, model_file) 141 | 142 | if mode in ["evaluate", "replay"]: 143 | assert args.eval_model_path is not None, "Eval mode requires a path to checkpoints" 144 | args.track = False 145 | # Disable env pool - see the comment about next_lstm_state in clean_pufferl.evaluate() 146 | args.train.env_pool = False 147 | args.env.resilient_population = 0 148 | args.reward_wrapper.eval_mode = True 149 | args.reward_wrapper.early_stop_agent_num = 0 150 | 151 | if mode == "replay": 152 | args.train.num_envs = args.train.envs_per_worker = args.train.envs_per_batch = 1 153 | args.vectorization = pufferlib.vectorization.Serial 154 | 155 | return args 156 | 157 | 158 | if __name__ == "__main__": 159 | logging.basicConfig(level=logging.INFO) 160 | parser = argparse.ArgumentParser(description="Parse environment argument", add_help=False) 161 | parser.add_argument( 162 | "-m", "--mode", type=str, default="train", choices="train sweep replay".split() 163 | ) 164 | parser.add_argument( 165 | "-a", "--agent", type=str, default="neurips23_start_kit", help="Agent module to use" 166 | ) 167 | parser.add_argument( 168 | "-n", "--exp-name", type=str, default=None, help="Need exp name to resume the experiment" 169 | ) 170 | parser.add_argument( 171 | "-p", "--eval-model-path", type=str, default=None, help="Path to model to evaluate" 172 | ) 173 | parser.add_argument( 174 | "-c", "--curriculum", type=str, default=BASELINE_CURRICULUM, help="Path to curriculum file" 175 | ) 176 | parser.add_argument( 177 | "-t", 178 | "--task-to-assign", 179 | type=int, 180 | default=None, 181 | help="The index of the task to assign in the curriculum file", 182 | ) 183 | # parser.add_argument( 184 | # "--test-curriculum", type=str, default=BASELINE_CURRICULUM, help="Path to curriculum file" 185 | # ) 186 | parser.add_argument("--syllabus", action="store_true", help="Use Syllabus for curriculum") 187 | # parser.add_argument('--baseline', action='store_true', help='Baseline run') 188 | parser.add_argument( 189 | "--vectorization", 190 | type=str, 191 | default="multiprocessing", 192 | choices="serial multiprocessing ray".split(), 193 | ) 194 | parser.add_argument("--no-recurrence", action="store_true", help="Do not use recurrence") 195 | if DEBUG: 196 | parser.add_argument("--no-track", default=True, help="Do NOT track on WandB") 197 | parser.add_argument("--debug", default=True, help="Debug mode") 198 | else: 199 | parser.add_argument("--no-track", action="store_true", help="Do NOT track on WandB") 200 | parser.add_argument("--debug", action="store_true", help="Debug mode") 201 | 202 | args = parser.parse_known_args()[0].__dict__ 203 | config = load_from_config(args["agent"], debug=args.get("debug", False)) 204 | agent_module, agent_creator, init_args = setup_agent(args["agent"]) 205 | 206 | # Update config with environment defaults 207 | config.policy = {**init_args["policy"], **config.policy} 208 | config.recurrent = {**init_args["recurrent"], **config.recurrent} 209 | config.reward_wrapper = {**init_args["reward_wrapper"], **config.reward_wrapper} 210 | 211 | # Generate argparse menu from config 212 | args = combine_config_args(parser, args, config) 213 | 214 | # Perform mode-specific updates 215 | args = update_args(args, mode=args["mode"]) 216 | 217 | # Make default or syllabus-based env_creator 218 | syllabus = None 219 | if args.syllabus is True: 220 | # NOTE: Setting use_custom_reward to False will ignore the agent's custom reward 221 | # and only use the env-provided reward from the curriculum tasks 222 | args.reward_wrapper.use_custom_reward = False 223 | syllabus, env_creator = syllabus_wrapper.make_syllabus_env_creator(args, agent_module) 224 | else: 225 | args.env.curriculum_file_path = args.curriculum 226 | env_creator = environment.make_env_creator(reward_wrapper_cls=agent_module.RewardWrapper) 227 | 228 | if args.train.env_pool is True: 229 | logging.warning( 230 | "Env_pool is enabled. This may increase training speed but break determinism." 231 | ) 232 | 233 | if args.track: 234 | args.exp_name = init_wandb(args).id 235 | else: 236 | args.exp_name = f"nmmo_{time.strftime('%Y%m%d_%H%M%S')}" 237 | 238 | if args.mode == "train": 239 | train(args, env_creator, agent_creator, syllabus) 240 | exit(0) 241 | elif args.mode == "sweep": 242 | sweep(args, env_creator, agent_creator) 243 | exit(0) 244 | elif args.mode == "replay": 245 | generate_replay(args, env_creator, agent_creator) 246 | exit(0) 247 | else: 248 | raise ValueError("Mode must be one of train, sweep, or evaluate") 249 | -------------------------------------------------------------------------------- /train_helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import logging 4 | 5 | import dill 6 | import wandb 7 | import torch 8 | import numpy as np 9 | 10 | import pufferlib.policy_pool as pp 11 | from nmmo.render.replay_helper import FileReplayHelper 12 | from nmmo.task.task_spec import make_task_from_spec 13 | 14 | from reinforcement_learning import clean_pufferl 15 | 16 | # Related to torch.use_deterministic_algorithms(True) 17 | # See also https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility 18 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" 19 | 20 | 21 | def init_wandb(args, resume=True): 22 | if args.no_track: 23 | return None 24 | assert args.wandb.project is not None, "Please set the wandb project in config.yaml" 25 | assert args.wandb.entity is not None, "Please set the wandb entity in config.yaml" 26 | wandb_kwargs = { 27 | "id": args.exp_name or wandb.util.generate_id(), 28 | "project": args.wandb.project, 29 | "entity": args.wandb.entity, 30 | "config": { 31 | "cleanrl": args.train, 32 | "env": args.env, 33 | "agent_zoo": args.agent, 34 | "policy": args.policy, 35 | "recurrent": args.recurrent, 36 | "reward_wrapper": args.reward_wrapper, 37 | "syllabus": args.syllabus, 38 | }, 39 | "name": args.exp_name, 40 | "monitor_gym": True, 41 | "save_code": True, 42 | "resume": resume, 43 | } 44 | if args.wandb.group is not None: 45 | wandb_kwargs["group"] = args.wandb.group 46 | return wandb.init(**wandb_kwargs) 47 | 48 | 49 | def train(args, env_creator, agent_creator, syllabus=None): 50 | data = clean_pufferl.create( 51 | config=args.train, 52 | agent_creator=agent_creator, 53 | agent_kwargs={"args": args}, 54 | env_creator=env_creator, 55 | env_creator_kwargs={"env": args.env, "reward_wrapper": args.reward_wrapper}, 56 | vectorization=args.vectorization, 57 | exp_name=args.exp_name, 58 | track=args.track, 59 | ) 60 | 61 | while not clean_pufferl.done_training(data): 62 | clean_pufferl.evaluate(data) 63 | clean_pufferl.train(data) 64 | if syllabus is not None: 65 | syllabus.log_metrics(data.wandb, step=data.global_step) 66 | 67 | print("Done training. Saving data...") 68 | clean_pufferl.close(data) 69 | print("Run complete.") 70 | 71 | 72 | def sweep(args, env_creator, agent_creator): 73 | sweep_id = wandb.sweep(sweep=args.sweep, project=args.wandb.project) 74 | 75 | def main(): 76 | try: 77 | args.exp_name = init_wandb(args).id 78 | if hasattr(wandb.config, "train"): 79 | # TODO: Add update method to namespace 80 | print(args.train.__dict__) 81 | print(wandb.config.train) 82 | args.train.__dict__.update(dict(wandb.config.train)) 83 | train(args, env_creator, agent_creator) 84 | except Exception as e: # noqa: F841 85 | import traceback 86 | 87 | traceback.print_exc() 88 | 89 | wandb.agent(sweep_id, main, count=20) 90 | 91 | 92 | def generate_replay(args, env_creator, agent_creator, stop_when_all_complete_task=True, seed=None): 93 | assert args.eval_model_path is not None, "eval_model_path must be set for replay generation" 94 | policies = pp.get_policy_names(args.eval_model_path) 95 | assert len(policies) > 0, "No policies found in eval_model_path" 96 | logging.info(f"Policies to generate replay: {policies}") 97 | 98 | save_dir = args.eval_model_path 99 | logging.info("Replays will be saved to %s", save_dir) 100 | 101 | if seed is not None: 102 | args.train.seed = seed 103 | logging.info("Seed: %d", args.train.seed) 104 | 105 | # Set the train config for replay 106 | args.train.num_envs = 1 107 | args.train.envs_per_batch = 1 108 | args.train.envs_per_worker = 1 109 | 110 | # Set the reward wrapper for replay 111 | args.reward_wrapper.eval_mode = True 112 | args.reward_wrapper.early_stop_agent_num = 0 113 | 114 | # Use the policy pool helper functions to create kernel (policy-agent mapping) 115 | args.train.pool_kernel = pp.create_kernel( 116 | args.env.num_agents, len(policies), shuffle_with_seed=args.train.seed 117 | ) 118 | 119 | data = clean_pufferl.create( 120 | config=args.train, 121 | agent_creator=agent_creator, 122 | agent_kwargs={"args": args}, 123 | env_creator=env_creator, 124 | env_creator_kwargs={"env": args.env, "reward_wrapper": args.reward_wrapper}, 125 | eval_mode=True, 126 | eval_model_path=args.eval_model_path, 127 | policy_selector=pp.AllPolicySelector(args.train.seed), 128 | ) 129 | 130 | # Set up the replay helper 131 | o, r, d, t, i, env_id, mask = data.pool.recv() # This resets the env 132 | replay_helper = FileReplayHelper() 133 | nmmo_env = data.pool.multi_envs[0].envs[0].env.env 134 | nmmo_env.realm.record_replay(replay_helper) 135 | 136 | # Sanity checks for replay generation 137 | assert len(policies) == len(data.policy_pool.current_policies), "Policy count mismatch" 138 | assert len(data.policy_pool.kernel) == nmmo_env.max_num_agents, "Agent count mismatch" 139 | 140 | # Add the policy names to agent names 141 | if len(policies) > 1: 142 | for policy_id, samp in data.policy_pool.sample_idxs.items(): 143 | policy_name = "learner" 144 | if policy_id in data.policy_pool.current_policies: 145 | policy_name = data.policy_pool.current_policies[policy_id]["name"] 146 | for idx in samp: 147 | agent_id = idx + 1 # agents are 0-indexed in policy_pool, but 1-indexed in nmmo 148 | nmmo_env.realm.players[agent_id].name = f"{policy_name}_{agent_id}" 149 | 150 | # Assign the specified task to the agents, if provided 151 | if args.task_to_assign is not None: 152 | with open(args.curriculum, "rb") as f: 153 | task_with_embedding = dill.load(f) # a list of TaskSpec 154 | assert 0 <= args.task_to_assign < len(task_with_embedding), "Task index out of range" 155 | select_task = task_with_embedding[args.task_to_assign] 156 | tasks = make_task_from_spec( 157 | nmmo_env.possible_agents, [select_task] * len(nmmo_env.possible_agents) 158 | ) 159 | 160 | # Reassign the task to the agents 161 | nmmo_env.tasks = tasks 162 | nmmo_env._map_task_to_agent() # update agent_task_map 163 | for agent_id in nmmo_env.possible_agents: 164 | # task_spec must have tasks for all agents, otherwise it will cause an error 165 | task_embedding = nmmo_env.agent_task_map[agent_id][0].embedding 166 | nmmo_env.obs[agent_id].gym_obs.reset(task_embedding) 167 | 168 | print(f"All agents are assigned: {nmmo_env.tasks[0].spec_name}\n") 169 | 170 | # Generate the replay 171 | replay_helper.reset() 172 | while True: 173 | with torch.no_grad(): 174 | o = torch.as_tensor(o) 175 | r = torch.as_tensor(r).float().to(data.device).view(-1) 176 | d = torch.as_tensor(d).float().to(data.device).view(-1) 177 | 178 | # env_pool must be false for the lstm to work 179 | next_lstm_state = data.next_lstm_state 180 | if next_lstm_state is not None: 181 | next_lstm_state = ( 182 | next_lstm_state[0][:, env_id], 183 | next_lstm_state[1][:, env_id], 184 | ) 185 | 186 | actions, logprob, value, next_lstm_state = data.policy_pool.forwards( 187 | o.to(data.device), next_lstm_state 188 | ) 189 | 190 | if next_lstm_state is not None: 191 | h, c = next_lstm_state 192 | data.next_lstm_state[0][:, env_id] = h 193 | data.next_lstm_state[1][:, env_id] = c 194 | 195 | value = value.flatten() 196 | 197 | data.pool.send(actions.cpu().numpy()) 198 | o, r, d, t, i, env_id, mask = data.pool.recv() 199 | 200 | num_alive = len(nmmo_env.agents) 201 | task_done = sum(1 for task in nmmo_env.tasks if task.completed) 202 | alive_done = sum( 203 | 1 204 | for task in nmmo_env.tasks 205 | if task.completed and task.assignee[0] in nmmo_env.realm.players 206 | ) 207 | print("Tick:", nmmo_env.realm.tick, ", alive agents:", num_alive, ", task done:", task_done) 208 | if num_alive == alive_done: 209 | print("All alive agents completed the task.") 210 | break 211 | if num_alive == 0 or nmmo_env.realm.tick == args.env.max_episode_length: 212 | print("All agents died or reached the max episode length.") 213 | break 214 | 215 | # Count how many agents completed the task 216 | print("--------------------------------------------------") 217 | print("Task:", nmmo_env.tasks[0].spec_name) 218 | num_completed = sum(1 for task in nmmo_env.tasks if task.completed) 219 | print("Number of agents completed the task:", num_completed) 220 | avg_progress = np.mean([task.progress_info["max_progress"] for task in nmmo_env.tasks]) 221 | print(f"Average maximum progress (max=1): {avg_progress:.3f}") 222 | avg_completed_tick = 0 223 | if num_completed > 0: 224 | avg_completed_tick = np.mean( 225 | [task.progress_info["completed_tick"] for task in nmmo_env.tasks if task.completed] 226 | ) 227 | print(f"Average completed tick: {avg_completed_tick:.1f}") 228 | 229 | # Save the replay file 230 | replay_file = f"replay_seed_{args.train.seed}_" 231 | if args.task_to_assign is not None: 232 | replay_file += f"task_{args.task_to_assign}_" 233 | replay_file = os.path.join(save_dir, replay_file + time.strftime("%Y%m%d_%H%M%S")) 234 | print(f"Saving replay to {replay_file}") 235 | replay_helper.save(replay_file, compress=True) 236 | clean_pufferl.close(data) 237 | 238 | return replay_file 239 | --------------------------------------------------------------------------------