├── .gitignore ├── README.md ├── ablations ├── baseline.yaml ├── baseline_lowmem.yaml ├── module_cost_no_map_size_no_msdamage_no.yaml ├── module_cost_no_map_size_no_msdamage_yes.yaml ├── module_cost_no_map_size_yes_msdamage_no.yaml ├── module_cost_no_map_size_yes_msdamage_yes.yaml ├── module_cost_yes_map_size_no_msdamage_no.yaml ├── module_cost_yes_map_size_no_msdamage_yes.yaml ├── module_cost_yes_map_size_yes_msdamage_no.yaml ├── ominiscient_value_function.yaml ├── rotational_invariance.yaml ├── small_network.yaml └── sparse_reward.yaml ├── adr.py ├── codecraft.py ├── gather.py ├── gym_codecraft ├── __init__.py └── envs │ ├── __init__.py │ └── codecraft_vec_env.py ├── hyper_params.py ├── list_net.py ├── main.py ├── multihead_attention.py ├── plot_results.py ├── policy_t2.py ├── policy_t3.py ├── policy_t4.py ├── policy_t5.py ├── policy_t6.py ├── policy_t7.py ├── policy_t8.py ├── progress.ipynb ├── requirements.txt ├── reset-drivers.sh ├── runner.py ├── schedule.py ├── setup-remote.sh ├── setup-system.sh ├── showmatch.py ├── spatial.py └── test_spatial_scatter.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ 3 | params*.yaml 4 | wandb 5 | verify 6 | plots 7 | plotspng 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep CodeCraft 2 | 3 | Hacky research code that trains policies for the [CodeCraft](http://codecraftgame.org/) real-time strategy game with proximal policy optimization. 
4 |
5 | Blog post: [Mastering Real-Time Strategy Games with Deep Reinforcement Learning: Mere Mortal Edition](https://clemenswinter.com/2021/03/24/mastering-real-time-strategy-games-with-deep-reinforcement-learning-mere-mortal-edition/)
6 |
7 | ## Requirements
8 |
9 | - Python >= 3.7, pip
10 | - [CodeCraft Server](https://github.com/cswinter/CodeCraftServer/)
11 |
12 | ## Setup
13 |
14 | Install dependencies with
15 |
16 | ```
17 | pip install -r requirements.txt
18 | pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.6.0+${CUDA}.html
19 | ```
20 |
21 | where `${CUDA}` should be replaced by either `cpu`, `cu92`, `cu101` or `cu102` depending on your PyTorch installation.
22 |
23 | If you want the training code to record metrics to [Weights & Biases](https://www.wandb.com/), run `wandb login`.
24 |
25 | ## Usage
26 |
27 | The first step is to set up and run [CodeCraft Server](https://github.com/cswinter/CodeCraftServer/).
28 |
29 | ### Training
30 |
31 | To train a policy with the default set of hyperparameters, run:
32 |
33 | ```
34 | EVAL_MODELS_PATH=/path/to/golden-models python main.py --hpset=standard --out-dir=${OUT_DIR}
35 | ```
36 |
37 | Logs and model checkpoints will be written to the `${OUT_DIR}` directory.
38 | If you want policies to be evaluated against a set of fixed opponents during training, download the required checkpoints [available here](https://www.dropbox.com/sh/h0f4faf7cbubn3t/AACfYYYY9kwPwNjm_TeCahxAa/golden-models?dl=0) into the matching subfolder of the folder specified by `EVAL_MODELS_PATH`.
39 | For evaluations with the standard config, you need `standard/curious-galaxy-40M.pt` and `standard/graceful-frog-100M.pt`.
40 | To disable evaluation of the policy during training, set `--eval_envs=0`.
41 | To see additional options, run `python main.py --help` and consult [hyper_params.py](https://github.com/cswinter/DeepCodeCraft/blob/master/hyper_params.py).
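Before starting a long training run with evaluation enabled, it can help to confirm that the checkpoints named above are actually in place. The following is a minimal sketch, not part of this repository: the helper `missing_checkpoints` and the constant `REQUIRED_STANDARD` are hypothetical names, and the sketch assumes the folder is given by the `EVAL_MODELS_PATH` variable used in the training command.

```python
import os
from pathlib import Path
from typing import Iterable, List

# Checkpoints the README lists for evaluations with the standard config.
REQUIRED_STANDARD = [
    "standard/curious-galaxy-40M.pt",
    "standard/graceful-frog-100M.pt",
]


def missing_checkpoints(root: str, required: Iterable[str] = REQUIRED_STANDARD) -> List[str]:
    """Return the required checkpoint paths (relative to root) that are not regular files."""
    base = Path(root)
    return [name for name in required if not (base / name).is_file()]


# Assumption: the folder is passed via EVAL_MODELS_PATH; fall back to the current directory.
root = os.environ.get("EVAL_MODELS_PATH", ".")
for name in missing_checkpoints(root):
    print(f"missing checkpoint: {os.path.join(root, name)}")
```

If the loop prints nothing, both files were found. If it lists files, either download them or pass `--eval_envs=0` to disable evaluation.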
42 |
43 | ### Showmatch
44 |
45 | To run games with already trained policies, run:
46 |
47 | ```
48 | python showmatch.py /path/to/policy1.pt /path/to/policy2.pt --task=STANDARD --num_envs=64
49 | ```
50 |
51 | You can then watch the games at .
52 |
53 | ### Job Runner
54 |
55 | The job runner allows you to schedule and execute many runs in parallel.
56 | The command
57 |
58 | ```
59 | python runner.py --jobfile-dir=${JOB_DIR} --out-dir=${OUT_DIR} --concurrency=${CONCURRENCY}
60 | ```
61 |
62 | starts a job runner that watches the `${JOB_DIR}` directory for new jobs, writes results to folders created in `${OUT_DIR}`, and runs up to `${CONCURRENCY}` experiments in parallel.
63 |
64 | You can then schedule jobs with
65 |
66 | ```
67 | python schedule.py --repo-path=https://github.com/cswinter/DeepCodeCraft.git --queue-dir=${JOB_DIR} --params-file=params.yaml
68 | ```
69 |
70 | where `params.yaml` is a file that specifies the set of hyperparameters to use, for example:
71 |
72 | ```
73 | - hpset: standard
74 |   adr_variety: [0.5, 0.3]
75 |   lr: [0.001, 0.0003]
76 | - hpset: standard
77 |   repeat: 4
78 |   steps: 300e6
79 | ```
80 |
81 | The `repeat` parameter tells the job runner to spawn multiple runs.
82 | When a hyperparameter is set to a list of different values, one experiment is spawned for each combination.
83 | So the `params.yaml` above spawns a total of 8 experiment runs: 4 runs of 300 million samples with the default set of hyperparameters (the second entry, with `repeat: 4`), plus one run for each of the 4 combinations of the `adr_variety` and `lr` values in the first entry.
84 |
85 | The `${JOB_DIR}` may be on a remote machine that you can access via ssh/rsync, e.g. `--queue-dir=192.168.0.101:/home/clemens/xprun/queue`.
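The expansion rule for `params.yaml` (list values are combined as a grid, `repeat` multiplies the resulting runs) can be illustrated with a short sketch. This is not the code used by `schedule.py` or `runner.py`. The function `expand_params` is a hypothetical name, the entries are written as Python literals instead of being parsed from YAML, and the sketch assumes every list-valued setting is a sweep, which would not hold for hyperparameters whose value is itself a list.

```python
from itertools import product
from typing import Any, Dict, List


def expand_params(entries: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Expand params.yaml-style entries into one hyperparameter dict per run."""
    runs = []
    for entry in entries:
        entry = dict(entry)  # copy so the caller's data is left untouched
        repeat = int(entry.pop("repeat", 1))
        keys = list(entry.keys())
        # Treat scalars as single-element lists so product() handles them uniformly.
        choices = [v if isinstance(v, list) else [v] for v in entry.values()]
        for combo in product(*choices):
            for _ in range(repeat):
                runs.append(dict(zip(keys, combo)))
    return runs


# The example params.yaml from above, written as Python literals.
entries = [
    {"hpset": "standard", "adr_variety": [0.5, 0.3], "lr": [0.001, 0.0003]},
    {"hpset": "standard", "repeat": 4, "steps": 300e6},
]
runs = expand_params(entries)
print(len(runs))  # 8: four grid combinations plus four repeats
```

The first entry contributes 2 x 2 = 4 runs, and the second contributes 4 identical runs, which matches the total of 8 described above.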
86 | 87 | ### Citation 88 | 89 | ``` 90 | @misc{DeepCodeCraft2020, 91 | author = {Winter, Clemens}, 92 | title = {Deep CodeCraft}, 93 | year = {2020}, 94 | publisher = {GitHub}, 95 | journal = {GitHub repository}, 96 | howpublished = {\url{https://github.com/cswinter/DeepCodeCraft}} 97 | } 98 | ``` 99 | -------------------------------------------------------------------------------- /ablations/baseline.yaml: -------------------------------------------------------------------------------- 1 | # Baseline 2 | - hpset: standard 3 | repeat: 8 4 | -------------------------------------------------------------------------------- /ablations/baseline_lowmem.yaml: -------------------------------------------------------------------------------- 1 | # Baseline but reduce memory usage 2 | - hpset: standard 3 | repeat: 8 4 | batches_per_update: 64 5 | bs: 256 6 | -------------------------------------------------------------------------------- /ablations/module_cost_no_map_size_no_msdamage_no.yaml: -------------------------------------------------------------------------------- 1 | # No mothership damage modifier curriculum, map curriculum, or cost automatic domain randomization 2 | - hpset: standard 3 | repeat: 8 4 | mothership_damage_scale: 0.0 5 | mothership_damage_scale_schedule: '' 6 | adr_variety: 0.0 7 | adr_variety_schedule: '' 8 | task_hardness: 150 9 | linear_hardness: False 10 | adr_hstepsize: 0.0 11 | -------------------------------------------------------------------------------- /ablations/module_cost_no_map_size_no_msdamage_yes.yaml: -------------------------------------------------------------------------------- 1 | # No map curriculum or cost automatic domain randomization 2 | - hpset: standard 3 | repeat: 8 4 | adr_variety: 0.0 5 | adr_variety_schedule: '' 6 | task_hardness: 150 7 | linear_hardness: False 8 | adr_hstepsize: 0.0 9 | -------------------------------------------------------------------------------- 
/ablations/module_cost_no_map_size_yes_msdamage_no.yaml: -------------------------------------------------------------------------------- 1 | # No mothership damage modifier curriculum or cost automatic domain randomization 2 | - hpset: standard 3 | repeat: 8 4 | mothership_damage_scale: 0.0 5 | mothership_damage_scale_schedule: '' 6 | adr_variety: 0.0 7 | adr_variety_schedule: '' 8 | -------------------------------------------------------------------------------- /ablations/module_cost_no_map_size_yes_msdamage_yes.yaml: -------------------------------------------------------------------------------- 1 | # No cost automatic domain randomization 2 | - hpset: standard 3 | repeat: 8 4 | adr_variety: 0.0 5 | adr_variety_schedule: '' 6 | batches_per_update: 64 7 | bs: 256 8 | -------------------------------------------------------------------------------- /ablations/module_cost_yes_map_size_no_msdamage_no.yaml: -------------------------------------------------------------------------------- 1 | # No mothership damage modifier curriculum or map curriculum 2 | - hpset: standard 3 | repeat: 8 4 | mothership_damage_scale: 0.0 5 | mothership_damage_scale_schedule: '' 6 | task_hardness: 150 7 | linear_hardness: False 8 | adr_hstepsize: 0.0 9 | -------------------------------------------------------------------------------- /ablations/module_cost_yes_map_size_no_msdamage_yes.yaml: -------------------------------------------------------------------------------- 1 | - hpset: standard 2 | repeat: 8 3 | task_hardness: 150 4 | linear_hardness: False 5 | adr_hstepsize: 0.0 6 | -------------------------------------------------------------------------------- /ablations/module_cost_yes_map_size_yes_msdamage_no.yaml: -------------------------------------------------------------------------------- 1 | # No mothership damage modifier curriculum 2 | - hpset: standard 3 | repeat: 8 4 | mothership_damage_scale: 0.0 5 | mothership_damage_scale_schedule: '' 6 | 
-------------------------------------------------------------------------------- /ablations/ominiscient_value_function.yaml: -------------------------------------------------------------------------------- 1 | # Value function can only see objects visible to the player 2 | - hpset: standard 3 | repeat: 8 4 | use_privileged: False 5 | -------------------------------------------------------------------------------- /ablations/rotational_invariance.yaml: -------------------------------------------------------------------------------- 1 | # No rotational invariance 2 | - hpset: standard 3 | repeat: 8 4 | rotational_invariance: False 5 | batches_per_update: 64 6 | bs: 256 7 | -------------------------------------------------------------------------------- /ablations/small_network.yaml: -------------------------------------------------------------------------------- 1 | # Smaller network 2 | - hpset: standard 3 | repeat: 8 4 | d_agent: 128 5 | d_item: 64 6 | -------------------------------------------------------------------------------- /ablations/sparse_reward.yaml: -------------------------------------------------------------------------------- 1 | # Only reward is +2 when winning match 2 | - hpset: standard 3 | partial_score: 0.0 4 | repeat: 8 5 | -------------------------------------------------------------------------------- /adr.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import defaultdict 3 | from gym_codecraft.envs.codecraft_vec_env import Rules 4 | 5 | 6 | class ADR: 7 | def __init__(self, 8 | hstepsize, 9 | stepsize=0.02, 10 | warmup=10, 11 | initial_hardness=0.0, 12 | ruleset: Rules = None, 13 | linear_hardness: bool = False, 14 | max_hardness: float = 200, 15 | hardness_offset: float = 0, 16 | variety: float = 0.7, 17 | step: int = 0, 18 | average_cost_target: float = 0.8): 19 | if ruleset is None: 20 | ruleset = Rules( 21 | cost_modifier_size=4 * [average_cost_target], 22 | 
cost_modifier_missiles=average_cost_target, 23 | cost_modifier_shields=average_cost_target, 24 | cost_modifier_storage=average_cost_target, 25 | cost_modifier_constructor=average_cost_target, 26 | cost_modifier_engines=average_cost_target, 27 | ) 28 | self.ruleset = ruleset 29 | self.variety = variety 30 | self.target_fractions = normalize({b: 1.0 for b in [ 31 | '1m', '1s', '1m1p', '2m', '1s1c', '2m1e1p', '3m1p', '2m2p', '2s2c', '2s1c1e', '2s1m1c' 32 | ]}) 33 | 34 | self.target_modifier = average_cost_target 35 | self.stepsize = stepsize 36 | self.warmup = warmup 37 | self.step = step 38 | 39 | self.w_ema = 0.5 40 | self.counts = defaultdict(lambda: 0.0) 41 | 42 | self.hardness = initial_hardness 43 | self.max_hardness = max_hardness 44 | self.linear_hardness = linear_hardness 45 | self.hardness_offset = hardness_offset 46 | self.stepsize_hardness = hstepsize 47 | self.target_elimination_rate = 0.97 48 | 49 | def target_eplenmean(self): 50 | if self.hardness < 25: 51 | return 250 + 6 * self.hardness 52 | elif self.hardness < 50: 53 | return 400 + 4 * (self.hardness - 25) 54 | elif self.hardness < 100: 55 | return 500 + 2 * (self.hardness - 50) 56 | else: 57 | return 600 58 | 59 | def adjust(self, counts, elimination_rate, eplenmean, step) -> float: 60 | self.step += 1 61 | stepsize = self.stepsize * min(1.0, self.step / self.warmup) 62 | for build, bfraction in counts.items(): 63 | self.counts[build] = (1 - self.w_ema) * bfraction + self.w_ema * self.counts[build] 64 | 65 | gradient = defaultdict(lambda: 0.0) 66 | weight = defaultdict(lambda: 0.0) 67 | for build, bfraction in normalize(self.counts).items(): 68 | if bfraction == 0: 69 | loss = -100 70 | else: 71 | loss = -self.variety * math.log(self.target_fractions[build] / bfraction) 72 | 73 | for module, mfraction in module_norm(build).items(): 74 | gradient[module] += mfraction * loss 75 | weight[module] += mfraction 76 | size_key = f'size{size(build)}' 77 | gradient[size_key] += 0.3 * loss 78 | 
weight[size_key] += 1 79 | for key in gradient.keys(): 80 | gradient[key] /= weight[key] 81 | 82 | modifier_decay = 1 - self.variety 83 | gradient['m'] += modifier_decay * math.log(self.target_modifier / self.ruleset.cost_modifier_missiles) 84 | gradient['s'] += modifier_decay * math.log(self.target_modifier / self.ruleset.cost_modifier_storage) 85 | gradient['p'] += modifier_decay * math.log(self.target_modifier / self.ruleset.cost_modifier_shields) 86 | gradient['c'] += modifier_decay * math.log(self.target_modifier / self.ruleset.cost_modifier_constructor) 87 | gradient['e'] += modifier_decay * math.log(self.target_modifier / self.ruleset.cost_modifier_engines) 88 | gradient['size1'] += modifier_decay * math.log(self.target_modifier / self.ruleset.cost_modifier_size[0]) 89 | gradient['size2'] += modifier_decay * math.log(self.target_modifier / self.ruleset.cost_modifier_size[1]) 90 | gradient['size4'] += modifier_decay * math.log(self.target_modifier / self.ruleset.cost_modifier_size[3]) 91 | 92 | size_weighted_counts = normalize({build: count * size(build) for build, count in self.counts.items()}) 93 | average_modifier = 0.0 94 | for build, bfraction in size_weighted_counts.items(): 95 | modifier = 0.0 96 | for module, mfraction in module_norm(build).items(): 97 | if module == 'm': 98 | modifier += self.ruleset.cost_modifier_missiles * mfraction 99 | if module == 's': 100 | modifier += self.ruleset.cost_modifier_storage * mfraction 101 | if module == 'p': 102 | modifier += self.ruleset.cost_modifier_shields * mfraction 103 | if module == 'c': 104 | modifier += self.ruleset.cost_modifier_constructor * mfraction 105 | if module == 'e': 106 | modifier += self.ruleset.cost_modifier_engines * mfraction 107 | size_modifier = self.ruleset.cost_modifier_size[size(build) - 1] 108 | average_modifier += modifier * size_modifier * bfraction 109 | 110 | if average_modifier == 0: 111 | return 112 | 113 | average_cost_grad = 10 * math.log(self.target_modifier / 
average_modifier) 114 | for key, grad in gradient.items(): 115 | exponent = stepsize * min(10.0, max(-10.0, grad + average_cost_grad)) 116 | multiplier = math.exp(exponent) 117 | if key == 'm': 118 | self.ruleset.cost_modifier_missiles *= multiplier 119 | if key == 's': 120 | self.ruleset.cost_modifier_storage *= multiplier 121 | if key == 'p': 122 | self.ruleset.cost_modifier_shields *= multiplier 123 | if key == 'c': 124 | self.ruleset.cost_modifier_constructor *= multiplier 125 | if key == 'e': 126 | self.ruleset.cost_modifier_engines *= multiplier 127 | if key == 'size1': 128 | self.ruleset.cost_modifier_size[0] *= multiplier 129 | if key == 'size2': 130 | self.ruleset.cost_modifier_size[1] *= multiplier 131 | if key == 'size3': 132 | self.ruleset.cost_modifier_size[2] *= multiplier 133 | if key == 'size4': 134 | self.ruleset.cost_modifier_size[3] *= multiplier 135 | 136 | if step > self.hardness_offset: 137 | if self.linear_hardness: 138 | self.hardness = min((step - self.hardness_offset) * self.stepsize_hardness, self.max_hardness) 139 | else: 140 | if eplenmean is not None: 141 | self.hardness += self.stepsize_hardness * (self.target_eplenmean() - eplenmean) 142 | self.hardness = max(0.0, self.hardness) 143 | 144 | return average_modifier 145 | 146 | def metrics(self): 147 | return { 148 | 'adr_missile_cost': self.ruleset.cost_modifier_missiles, 149 | 'adr_storage_cost': self.ruleset.cost_modifier_storage, 150 | 'adr_constructor_cost': self.ruleset.cost_modifier_constructor, 151 | 'adr_engine_cost': self.ruleset.cost_modifier_engines, 152 | 'adr_shield_cost': self.ruleset.cost_modifier_shields, 153 | 'adr_size1_cost': self.ruleset.cost_modifier_size[0], 154 | 'adr_size2_cost': self.ruleset.cost_modifier_size[1], 155 | 'adr_size4_cost': self.ruleset.cost_modifier_size[3], 156 | } 157 | 158 | 159 | def size(build): 160 | modules = 0 161 | for module in [build[i:i+2] for i in range(0, len(build), 2)]: 162 | modules += int(module[:1]) 163 | return modules 164 | 
165 | 166 | def module_norm(build): 167 | weights = defaultdict(lambda: 0.0) 168 | for module in [build[i:i+2] for i in range(0, len(build), 2)]: 169 | weights[module[1:]] = float(module[:1]) 170 | return normalize(weights) 171 | 172 | 173 | def normalize(weights): 174 | total = sum(weights.values()) 175 | if total == 0: 176 | total = 1 177 | return {key: weight / total for key, weight in weights.items()} 178 | 179 | -------------------------------------------------------------------------------- /codecraft.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import logging 3 | import time 4 | 5 | import orjson 6 | import numpy as np 7 | 8 | from dataclasses import dataclass, field 9 | from typing import List, Tuple 10 | 11 | 12 | RETRIES = 100 13 | 14 | 15 | @dataclass 16 | class ObsConfig: 17 | allies: int 18 | drones: int 19 | minerals: int 20 | tiles: int 21 | global_drones: int = 0 22 | relative_positions: bool = True 23 | feat_last_seen: bool = False 24 | feat_map_size: bool = False 25 | feat_is_visible: bool = False 26 | feat_abstime: bool = False 27 | v2: bool = False 28 | feat_rule_msdm: bool = False 29 | feat_rule_costs: bool = False 30 | feat_mineral_claims: bool = False 31 | harvest_action: bool = False 32 | lock_build_action: bool = False 33 | feat_dist_to_wall: bool = False 34 | 35 | def global_features(self): 36 | gf = 2 37 | if self.feat_map_size: 38 | gf += 2 39 | if self.feat_abstime: 40 | gf += 2 41 | if self.feat_rule_msdm: 42 | gf += 1 43 | if self.feat_rule_costs: 44 | gf += 9 45 | return gf 46 | 47 | def dstride(self): 48 | ds = 15 49 | if self.feat_last_seen: 50 | ds += 2 51 | if self.feat_is_visible: 52 | ds += 1 53 | if self.lock_build_action: 54 | ds += 1 55 | if self.feat_dist_to_wall: 56 | ds += 5 57 | return ds 58 | 59 | def mstride(self): 60 | return 4 if self.feat_mineral_claims else 3 61 | 62 | def tstride(self): 63 | return 4 64 | 65 | def nonobs_features(self): 66 | return 5 
67 | 68 | def enemies(self): 69 | return self.drones - self.allies 70 | 71 | def total_drones(self): 72 | return 2 * self.drones - self.allies 73 | 74 | def stride(self): 75 | return self.global_features() \ 76 | + self.total_drones() * self.dstride() \ 77 | + self.minerals * self.mstride() \ 78 | + self.tiles * self.tstride() 79 | 80 | def endglobals(self): 81 | return self.global_features() 82 | 83 | def endallies(self): 84 | return self.global_features() + self.dstride() * self.allies 85 | 86 | def endenemies(self): 87 | return self.global_features() + self.dstride() * self.drones 88 | 89 | def endmins(self): 90 | return self.endenemies() + self.mstride() * self.minerals 91 | 92 | def endtiles(self): 93 | return self.endmins() + self.tstride() * self.tiles 94 | 95 | def endallenemies(self): 96 | return self.endtiles() + self.dstride() * self.enemies() 97 | 98 | def extra_actions(self): 99 | if self.lock_build_action: 100 | return 2 101 | else: 102 | return 0 103 | 104 | 105 | @dataclass 106 | class Rules: 107 | mothership_damage_multiplier: float = 1.0 108 | cost_modifier_size: List[float] = field(default_factory=lambda: [1.0, 1.0, 1.0, 1.0]) 109 | cost_modifier_missiles: float = 1.0 110 | cost_modifier_shields: float = 1.0 111 | cost_modifier_storage: float = 1.0 112 | cost_modifier_constructor: float = 1.0 113 | cost_modifier_engines: float = 1.0 114 | 115 | 116 | def create_game(game_length: int = None, 117 | action_delay: int = 0, 118 | self_play: bool = False, 119 | custom_map=None, 120 | scripted_opponent: str = 'none', 121 | rules=Rules(), 122 | allowHarvesting: bool = True, 123 | forceHarvesting: bool = True, 124 | randomizeIdle: bool = True) -> int: 125 | if custom_map is None: 126 | custom_map = '' 127 | try: 128 | if game_length: 129 | response = requests.post(f'http://localhost:9000/start-game' 130 | f'?maxTicks={game_length}' 131 | f'&actionDelay={action_delay}' 132 | f'&scriptedOpponent={scripted_opponent}' 133 | 
                                     f'&mothershipDamageMultiplier={rules.mothership_damage_multiplier}'
134 |                                      f'&costModifierSize1={rules.cost_modifier_size[0]}'
135 |                                      f'&costModifierSize2={rules.cost_modifier_size[1]}'
136 |                                      f'&costModifierSize3={rules.cost_modifier_size[2]}'
137 |                                      f'&costModifierSize4={rules.cost_modifier_size[3]}'
138 |                                      f'&costModifierConstructor={rules.cost_modifier_constructor}'
139 |                                      f'&costModifierStorage={rules.cost_modifier_storage}'
140 |                                      f'&costModifierShields={rules.cost_modifier_shields}'
141 |                                      f'&costModifierMissiles={rules.cost_modifier_missiles}'
142 |                                      f'&costModifierEngines={rules.cost_modifier_engines}'
143 |                                      f'&allowHarvesting={scalabool(allowHarvesting)}'
144 |                                      f'&forceHarvesting={scalabool(forceHarvesting)}'
145 |                                      f'&randomizeIdle={scalabool(randomizeIdle)}',
146 |                                      json=custom_map).json()
147 |         else:
148 |             response = requests.post(f'http://localhost:9000/start-game?actionDelay={action_delay}').json()
149 |         return int(response['id'])
150 |     except requests.exceptions.ConnectionError:
151 |         logging.info(f"Connection error on create_game, retrying")
152 |         time.sleep(1)
153 |         return create_game(game_length, action_delay, self_play, custom_map, scripted_opponent, rules, allowHarvesting, forceHarvesting, randomizeIdle)  # retry with the full original configuration
154 |
155 |
156 | def act(game_id: int, action):
157 |     retries = 100
158 |     while retries > 0:
159 |         try:
160 |             requests.post(f'http://localhost:9000/act?gameID={game_id}&playerID=0', json=action).raise_for_status()
161 |             return
162 |         except requests.exceptions.ConnectionError:
163 |             # For some reason, a small percentage of requests fails with
164 |             # "connection error (errno 98, address already in use)"
165 |             # Just retry
166 |             retries -= 1
167 |             logging.info(f"Connection error on act({game_id}), retrying")
168 |             time.sleep(1)
169 |
170 |
171 | def act_batch(actions):
172 |     payload = {}
173 |     for game_id, player_id, player_actions in actions:
174 |         player_actions_json = []
175 |         for move, turn, buildSpec, harvest, lockBuildAction, unlockBuildAction in player_actions:
176 |             player_actions_json.append({
177 |                 "buildDrone": buildSpec,
178 |
"move": move, 179 | "harvest": harvest, 180 | "transfer": False, 181 | "turn": turn, 182 | "lockBuildAction": lockBuildAction, 183 | "unlockBuildAction": unlockBuildAction 184 | }) 185 | payload[f'{game_id}.{player_id}'] = player_actions_json 186 | 187 | retries = 100 188 | while retries > 0: 189 | try: 190 | requests.post( 191 | f'http://localhost:9000/batch-act', 192 | data=orjson.dumps(payload), 193 | headers={'Content-Type': 'application/json'}, 194 | ).raise_for_status() 195 | return 196 | except requests.exceptions.ConnectionError: 197 | # For some reason, a small percentage of requests fails with 198 | # "connection error (errno 98, address already in use)" 199 | # Just retry 200 | retries -= 1 201 | logging.info(f"Connection error on act_batch(), retrying") 202 | time.sleep(1) 203 | 204 | 205 | def observe(game_id: int, player_id: int = 0): 206 | try: 207 | return requests.get(f'http://localhost:9000/observation?gameID={game_id}&playerID={player_id}').json() 208 | except requests.exceptions.ConnectionError: 209 | logging.info(f"Connection error on observe({game_id}.{player_id}), retrying") 210 | time.sleep(1) 211 | return observe(game_id, player_id) 212 | 213 | 214 | def observe_batch(game_ids): 215 | retries = RETRIES 216 | while retries > 0: 217 | try: 218 | return requests.get(f'http://localhost:9000/batch-observation', json=[game_ids, []]).json() 219 | except requests.exceptions.ConnectionError: 220 | retries -= 1 221 | logging.info(f"Connection error on observe_batch(), retrying") 222 | time.sleep(10) 223 | 224 | 225 | def scalabool(b: bool) -> str: 226 | return 'true' if b else 'false' 227 | 228 | 229 | def observe_batch_raw(obs_config: ObsConfig, 230 | game_ids: List[Tuple[int, int]], 231 | allies: int, 232 | drones: int, 233 | minerals: int, 234 | global_drones: int, 235 | tiles: int, 236 | relative_positions: bool, 237 | v2: bool, 238 | extra_build_actions: List[List[int]], 239 | map_size: bool = False, 240 | last_seen: bool = False, 241 | 
is_visible: bool = False, 242 | abstime: bool = False, 243 | rule_msdm: bool = False, 244 | rule_costs: bool = False) -> object: 245 | retries = RETRIES 246 | ebcstr = '' 247 | url = f'http://localhost:9000/batch-observation?' \ 248 | f'json=false&' \ 249 | f'allies={allies}&' \ 250 | f'drones={drones}&' \ 251 | f'minerals={minerals}&' \ 252 | f'tiles={tiles}&' \ 253 | f'globalDrones={global_drones}&' \ 254 | f'relativePositions={"true" if relative_positions else "false"}&' \ 255 | f'lastSeen={"true" if last_seen else "false"}&' \ 256 | f'isVisible={"true" if is_visible else "false"}&' \ 257 | f'abstime={"true" if abstime else "false"}&' \ 258 | f'mapSize={"true" if map_size else "false"}&' \ 259 | f'v2={"true" if v2 else "false"}&' \ 260 | f'mineralClaims={scalabool(obs_config.feat_mineral_claims)}&' \ 261 | f'harvestAction={scalabool(obs_config.harvest_action)}&' \ 262 | f'ruleMsdm={scalabool(rule_msdm)}&' \ 263 | f'ruleCosts={scalabool(rule_costs)}&' \ 264 | f'lockBuildAction={scalabool(obs_config.lock_build_action)}&' \ 265 | f'distanceToWall={scalabool(obs_config.feat_dist_to_wall)}' + ebcstr 266 | while retries > 0: 267 | json = [game_ids, extra_build_actions] 268 | try: 269 | response = requests.get(url, 270 | json=json, 271 | stream=True) 272 | response.raise_for_status() 273 | response_bytes = response.content 274 | return np.frombuffer(response_bytes, dtype=np.float32) 275 | except requests.exceptions.ConnectionError as e: 276 | retries -= 1 277 | logging.info(f"Connection error on {url} with json={json}, retrying: {e}") 278 | time.sleep(10) 279 | 280 | 281 | def one_hot_to_action(action): 282 | # 0-5: turn/movement (4 is no turn, no movement) 283 | # 6: build [0,1,0,0,0] drone (if minerals > 5) 284 | # 7: harvest 285 | move = False 286 | harvest = False 287 | turn = 0 288 | build = [] 289 | if action == 0 or action == 1 or action == 2: 290 | move = True 291 | if action == 0 or action == 3: 292 | turn = -1 293 | if action == 2 or action == 5: 294 | turn = 
1 295 | if action == 6: 296 | build = [[0, 1, 0, 0, 0]] 297 | if action == 7: 298 | harvest = True 299 | 300 | return { 301 | "buildDrone": build, 302 | "move": move, 303 | "harvest": harvest, 304 | "transfer": False, 305 | "turn": turn, 306 | } 307 | 308 | 309 | def observation_to_np(observation): 310 | o = [] 311 | x = float(observation['alliedDrones'][0]['xPos']) 312 | y = float(observation['alliedDrones'][0]['yPos']) 313 | o.append(x / 1000.0) 314 | o.append(y / 1000.0) 315 | o.append(np.sin(float(observation['alliedDrones'][0]['orientation']))) 316 | o.append(np.cos(float(observation['alliedDrones'][0]['orientation']))) 317 | o.append(float(observation['alliedDrones'][0]['storedResources']) / 50.0) 318 | o.append(1.0 if observation['alliedDrones'][0]['isConstructing'] else -1.0) 319 | o.append(1.0 if observation['alliedDrones'][0]['isHarvesting'] else -1.0) 320 | minerals = sorted(observation['minerals'], key=lambda m: dist2(m['xPos'], m['yPos'], x, y)) 321 | for m in range(0, 10): 322 | if m < len(minerals): 323 | mx = float(minerals[m]['xPos']) 324 | my = float(minerals[m]['yPos']) 325 | o.append((mx - x) / 1000.0) 326 | o.append((my - y) / 1000.0) 327 | o.append(dist(mx, my, x, y) / 1000.0) 328 | o.append(float(minerals[m]['size'] / 100.0)) 329 | else: 330 | o.extend([0.0, 0.0, 0.0, 0.0]) 331 | return np.array(o, dtype=np.float32) 332 | 333 | 334 | def dist(x1, y1, x2, y2): 335 | dx = x1 - x2 336 | dy = y1 - y2 337 | return np.sqrt(dx * dx + dy * dy) 338 | 339 | 340 | def dist2(x1, y1, x2, y2): 341 | dx = x1 - x2 342 | dy = y1 - y2 343 | return dx * dx + dy * dy 344 | -------------------------------------------------------------------------------- /gather.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def topk_by(values, vdim, keys, kdim, k): 5 | indices = keys.topk(k=k, dim=kdim, sorted=True).indices 6 | indices = indices.unsqueeze(-1).expand(indices.size() + values.size()[vdim+1:]) 7 | 
values_topk = values.gather(dim=vdim, index=indices) 8 | return values_topk 9 | 10 | 11 | def topk_and_index_by(values, vdim, keys, kdim, k): 12 | indices = keys.topk(k=k, dim=kdim, sorted=True).indices 13 | indices = indices.unsqueeze(-1).expand(indices.size() + values.size()[vdim+1:]) 14 | values_topk = values.gather(dim=vdim, index=indices) 15 | return values_topk, indices 16 | 17 | 18 | if __name__ == '__main__': 19 | a = torch.tensor([[[ 1.0, 0.0, 3.0 ], [0.0, 0.0, 2.0], [ 3.0, 3.0, 3.0], [4.0, 4.0, 4.0]], 20 | [[-1.0, 0.3, 0.23], [1.0, 0.0, -2.0], [-3.0, -0.3, -3.0], [0.44, -0.44, 4.04]]]) 21 | keys = a.sum(dim=2) 22 | topk_by(a, 1, keys, 1, 2) 23 | 24 | -------------------------------------------------------------------------------- /gym_codecraft/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cswinter/DeepCodeCraft/999c580a03a3c177f60eb6a6b8d1cc2b6dce373b/gym_codecraft/__init__.py -------------------------------------------------------------------------------- /gym_codecraft/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from gym_codecraft.envs.codecraft_vec_env import CodeCraftVecEnv 2 | from gym_codecraft.envs.codecraft_vec_env import Objective 3 | -------------------------------------------------------------------------------- /hyper_params.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | from abc import ABC, abstractmethod 4 | from typing import List, Tuple, Optional 5 | from gym_codecraft import envs 6 | 7 | 8 | class HyperParams: 9 | def __init__(self): 10 | # Optimizer 11 | self.optimizer = 'Adam' # Optimizer ("SGD" or "RMSProp" or "Adam") 12 | self.lr = 0.0003 # Learning rate 13 | self.final_lr = 0.0001 # Learning rate floor when using cosine schedule 14 | self.lr_schedule = 'none' # Learning rate schedule ("none" or "cosine") 15 | 
self.momentum = 0.9 # Momentum 16 | self.weight_decay = 0.0001 17 | self.bs = 2048 # Batch size during optimization 18 | self.batches_per_update = 1 # Accumulate gradients over this many batches before applying gradients 19 | self.batches_per_update_schedule = '' 20 | self.shuffle = True # Shuffle samples collected during rollout before optimization 21 | self.vf_coef = 1.0 # Weighting of value function loss in optimization objective 22 | self.entropy_bonus = 0.0 # Weighting of entropy bonus in loss function 23 | self.entropy_bonus_schedule = '' 24 | self.max_grad_norm = 20.0 # Maximum gradient norm for gradient clipping 25 | self.epochs = 2 # Number of optimizer passes over samples collected during rollout 26 | self.lr_ratios = 1.0 # Learning rate multiplier applied to earlier layers 27 | self.warmup = 0 # Learning rate is increased linearly from 0 during first n samples 28 | 29 | # Policy 30 | self.d_agent = 256 31 | self.d_item = 128 32 | self.dff_ratio = 2 33 | self.nhead = 8 34 | self.item_item_attn_layers = 0 35 | self.dropout = 0.0 # Try 0.1? 
36 | self.nearby_map = True # Construct map of nearby objects populated with scatter connections 37 | self.nm_ring_width = 60 # Width of circles on nearby map 38 | self.nm_nrays = 8 # Number of rays on nearby map 39 | self.nm_nrings = 8 # Number of rings on nearby map 40 | self.map_conv = False # Whether to perform convolution on nearby map 41 | self.mc_kernel_size = 3 # Size of convolution kernel for nearby map 42 | self.map_embed_offset = False # Whether the nearby map has 2 channels corresponding to the offset of objects within the tile 43 | self.item_ff = True # Adds itemwise ff resblock after initial embedding before transformer 44 | self.agents = 1 # Max number of simultaneously controllable drones 45 | self.nally = 1 # Max number of allies observed by each drone 46 | self.nenemy = 0 # Max number of enemies observed by each drone 47 | self.nmineral = 10 # Max number of minerals observed by each drone 48 | self.ntile = 0 # Number of map tiles observed by each drone 49 | self.nconstant = 0 # Number of learnable constant valued items observed by each drone 50 | self.ally_enemy_same = False # Use same weights for processing ally and enemy drones 51 | self.norm = 'layernorm' # Normalization layers ("none", "batchnorm", "layernorm") 52 | self.fp16 = False # Whether to use half-precision floating point 53 | self.zero_init_vf = True # Set all initial weights for value function head to zero 54 | self.small_init_pi = False # Set initial weights for policy head to small values and biases to zero 55 | self.rotational_invariance = True 56 | 57 | self.resume_from = '' # Filepath to saved policy 58 | 59 | # Data parallel 60 | self.rank = 0 61 | self.parallelism = 1 # Number of data parallel processes. Must be set explicitly when using schedule.py, otherwise runner.py will just spawn a single process. 
62 | 63 | # Observations 64 | self.obs_allies = 10 # Max number of allied drones returned by the env 65 | self.obs_enemies = 10 # Max number of enemy drones returned by the env 66 | self.obs_minerals = 10 # Max number of minerals returned by the env 67 | self.obs_map_tiles = 10 # Max number of map tiles returned by the env 68 | self.obs_keep_abspos = False # Have features for both absolute and relative positions on each object 69 | self.use_privileged = True # Whether value function has access to hidden information 70 | self.feat_map_size = True # Global features for width/height of map 71 | self.feat_last_seen = False # Remember last position/time each enemy was seen + missile cooldown feat 72 | self.feat_is_visible = True # Feature for whether drone is currently visible 73 | self.feat_abstime = True # Global features for absolute remaining/elapsed number of timesteps 74 | self.feat_mineral_claims = False # Feature for whether another drone is currently harvesting a mineral 75 | self.harvest_action = False # Harvest action that will freeze drone until one resource has been harvested 76 | self.lock_build_action = False # Pair of actions to disable/enable all build actions 77 | self.feat_dist_to_wall = False # Five features giving distance to closest wall in movement direction, and in movement direction offset by +-pi/2 and +-pi/4 78 | 79 | # Eval 80 | self.eval_envs = 256 81 | self.eval_timesteps = 360 82 | self.eval_frequency = 1e5 83 | self.model_save_frequency = 10 84 | self.eval_symmetric = True 85 | 86 | self.extra_checkpoint_steps = [] 87 | 88 | # RL 89 | self.steps = 10e6 # Total number of timesteps 90 | self.num_envs = 64 # Number of environments 91 | self.num_self_play = 32 # Number of self-play environments (each provides two environments) 92 | self.num_vs_replicator = 0 # Number of environments played vs scripted replicator AI 93 | self.num_vs_aggro_replicator = 0 # Number of environments played vs scripted aggressive replicator AI 94 | 
self.num_vs_destroyer = 0 # Number of environments played vs scripted destroyer AI 95 | self.num_self_play_schedule = '' 96 | self.seq_rosteps = 256 # Number of sequential steps per rollout 97 | self.gamma = 0.99 # Discount factor 98 | self.gamma_schedule = '' 99 | self.lamb = 0.95 # Generalized advantage estimation parameter lambda 100 | self.norm_advs = True # Normalize advantage values 101 | self.rewscale = 1.0 # Scaling of reward values 102 | self.ppo = True # Use PPO-clip instead of vanilla policy gradients objective 103 | self.cliprange = 0.2 # PPO cliprange 104 | self.clip_vf = True # Use clipped value function objective 105 | self.split_reward = False # Split reward evenly amongst all active agents 106 | self.liveness_penalty = 0.0 # Negative reward applied at each timestep 107 | self.build_variety_bonus = 0.0 # Extra reward for building a drone type at least once during episode 108 | self.win_bonus = 0.0 # Reward received when winning game by eliminating opponent 109 | self.loss_penalty = 0.0 # Negative reward received when losing game by being eliminated 110 | self.partial_score = 1.0 # Instantaneous reward received from change in relative amount of resources under allied control 111 | self.attac = 0.0 # Fraction of shaped reward awarded for minimum health of enemy mothership during episode 112 | self.protec = 0.0 # Fraction of shaped reward awarded for maximum health of allied mothership during episode 113 | self.rewnorm = False # Rescale reward values by ema of mean and variance 114 | self.rewnorm_emaw = 0.97 115 | self.max_army_size_score = 9999999 116 | self.max_enemy_army_size_score = 9999999 117 | 118 | # Task/Curriculum 119 | self.objective = envs.Objective.ARENA_TINY_2V2 120 | self.action_delay = 0 121 | self.use_action_masks = True 122 | self.task_hardness = 0.0 123 | self.max_game_length = 0 # Max length of games, or default game length for map if 0. 
124 | self.max_hardness = 150 # Maximum map area 125 | self.hardness_offset = 1e6 # Number of timesteps after which hardness starts to increase 126 | self.task_randomize = True 127 | self.symmetric_map = 0.0 # Percentage of maps which are symmetric 128 | self.symmetry_increase = 2e-8 # Linearly increase env symmetry parameter with this slope for every step 129 | self.mix_mp = 0.0 # Fraction of maps that use MICRO_PRACTICE instead of the main objective 130 | self.rule_rng_fraction = 0.0 # Fraction of maps that use a randomized ruleset 131 | self.rule_rng_amount = 1.0 # Amount of rule randomization 132 | self.rule_cost_rng = 0.0 133 | self.adr = False # Automatically adjust environment rules 134 | self.adr_hstepsize = 2.0e-6 # Amount by which task difficulty/map size is increased for each processed frame 135 | self.linear_hardness = True # Linearly increase task difficulty/map size 136 | self.mothership_damage_scale = 4.0 137 | self.mothership_damage_scale_schedule = 'lin 50e6:1.0,150e6:0.0' 138 | self.adr_average_cost_target = 1.0 # Target value for average module cost 139 | self.adr_avg_cost_schedule = '' 140 | 141 | self.adr_variety = 0.8 142 | self.adr_variety_schedule = '60e6:0.5,120e6:0.4,140e6:0.3' 143 | 144 | # Testing 145 | self.verify_create_golden = False 146 | self.verify = False 147 | 148 | 149 | @staticmethod 150 | def micro_practice(): 151 | hps = HyperParams() 152 | hps.objective = envs.Objective.MICRO_PRACTICE 153 | 154 | hps.steps = 40e6 155 | 156 | hps.agents = 8 157 | hps.nenemy = 7 158 | hps.nally = 7 159 | hps.nmineral = 5 160 | 161 | hps.batches_per_update = 2 162 | hps.bs = 1024 163 | hps.seq_rosteps = 256 164 | hps.num_envs = 64 165 | hps.num_self_play = 32 166 | 167 | hps.eval_envs = 256 168 | hps.eval_frequency = 1e6 169 | hps.eval_timesteps = 500 170 | 171 | hps.gamma = 0.997 172 | hps.entropy_bonus = 0.001 173 | 174 | hps.symmetric_map = 0.0 175 | hps.eval_symmetric = False 176 | 177 | return hps 178 | 179 | 180 | @staticmethod 181 | 
def standard(): 182 | hps = HyperParams() 183 | hps.objective = envs.Objective.STANDARD 184 | 185 | hps.steps = 125e6 186 | 187 | hps.agents = 15 188 | hps.nenemy = 15 189 | hps.nally = 15 190 | hps.nmineral = 5 191 | hps.ntile = 5 192 | 193 | hps.obs_minerals = 5 194 | hps.obs_allies = 15 195 | hps.obs_map_tiles = 5 196 | hps.obs_enemies = 15 197 | hps.feat_last_seen = True 198 | hps.feat_mineral_claims = True 199 | hps.harvest_action = True 200 | hps.feat_dist_to_wall = True 201 | hps.nearby_map = False 202 | 203 | hps.lr = 0.0005 204 | hps.final_lr = 0.00005 205 | hps.lr_schedule = 'cosine' 206 | hps.win_bonus = 2.0 207 | hps.partial_score = 1.0 208 | hps.vf_coef = 1.0 209 | hps.rule_rng_fraction = 1.0 210 | hps.rule_rng_amount = 1.0 211 | hps.adr = True 212 | hps.gamma = 0.999 213 | hps.entropy_bonus = 0.2 214 | hps.entropy_bonus_schedule = 'lin 15e6:0.1,60e6:0.0' 215 | hps.mothership_damage_scale = 4.0 216 | hps.mothership_damage_scale_schedule = 'lin 50e6:0.0' 217 | hps.adr_hstepsize = 3.0e-6 218 | 219 | hps.batches_per_update = 32 220 | hps.bs = 512 221 | hps.seq_rosteps = 128 222 | hps.num_envs = 128 223 | hps.num_self_play = 64 224 | 225 | hps.model_save_frequency = 1 226 | hps.eval_envs = 128 227 | hps.eval_frequency = 5e6 228 | hps.eval_timesteps = 5000 229 | 230 | hps.extra_checkpoint_steps = [1e6, 2.5e6] 231 | 232 | return hps 233 | 234 | # Equivalent to `standard` config when run dataparallel across 2 processes. 
235 | @staticmethod 236 | def standard_2dataparallel(): 237 | hps = HyperParams.standard() 238 | hps.batches_per_update //= 2 239 | hps.num_envs //= 2 240 | hps.num_self_play //= 2 241 | return hps 242 | 243 | @staticmethod 244 | def standard_dataparallel(): 245 | hps = HyperParams.standard() 246 | 247 | hps.steps = 300e6 248 | 249 | hps.batches_per_update = 16 250 | hps.num_envs = 64 251 | hps.num_self_play = 32 252 | 253 | hps.entropy_bonus_schedule = 'lin 30e6:0.1,120e6:0.0' 254 | hps.mothership_damage_scale_schedule = 'lin 100e6:0.0' 255 | hps.hardness_offset *= 2.0 256 | hps.adr_hstepsize *= 0.5 257 | hps.mothership_damage_scale_schedule = 'lin 100e6:1.0,300e6:0.0' 258 | hps.adr_variety_schedule = '120e6:0.5,240e6:0.4,280e6:0.3' 259 | 260 | return hps 261 | 262 | @staticmethod 263 | def arena(): 264 | hps = HyperParams() 265 | hps.objective = envs.Objective.ARENA 266 | 267 | hps.steps = 25e6 268 | 269 | hps.agents = 6 270 | hps.nenemy = 5 271 | hps.nally = 5 272 | hps.nmineral = 5 273 | 274 | hps.batches_per_update = 2 275 | hps.bs = 1024 276 | hps.seq_rosteps = 256 277 | hps.num_envs = 64 278 | hps.num_self_play = 32 279 | 280 | hps.eval_envs = 256 281 | hps.eval_frequency = 5e5 282 | hps.eval_timesteps = 1100 283 | 284 | hps.gamma = 0.997 285 | hps.entropy_bonus = 0.001 286 | 287 | hps.symmetric_map = 1.0 288 | hps.task_hardness = 4 289 | 290 | return hps 291 | 292 | 293 | @staticmethod 294 | def arena_medium(): 295 | hps = HyperParams() 296 | hps.objective = envs.Objective.ARENA_MEDIUM 297 | 298 | hps.steps = 50e6 299 | 300 | hps.agents = 4 301 | hps.nenemy = 5 302 | hps.nally = 5 303 | hps.nmineral = 5 304 | 305 | hps.batches_per_update = 1 306 | hps.batches_per_update_schedule = '15e6:2,30e6:4' 307 | hps.bs = 1024 308 | hps.seq_rosteps = 256 309 | hps.num_envs = 64 310 | hps.num_self_play = 32 311 | 312 | hps.model_save_frequency = 1 313 | hps.eval_envs = 512 314 | hps.eval_frequency = 5e6 315 | hps.eval_timesteps = 2000 316 | 317 | hps.gamma = 0.997 318 
| hps.entropy_bonus = 0.002 319 | hps.entropy_bonus_schedule = '15e6:0.0005,30e6:0.0' 320 | 321 | hps.symmetric_map = 1.0 322 | hps.task_hardness = 0 323 | 324 | return hps 325 | 326 | @staticmethod 327 | def arena_medium_large_ms(): 328 | hps = HyperParams.arena_medium() 329 | hps.objective = envs.Objective.ARENA_MEDIUM_LARGE_MS 330 | hps.task_hardness = 1 331 | hps.win_bonus = 2 332 | hps.vf_coef = 0.5 333 | hps.rule_rng_fraction = 1.0 334 | hps.rule_rng_amount = 1.0 335 | hps.agents = 7 336 | hps.gamma = 0.999 337 | hps.eval_envs = 256 338 | hps.nenemy = 7 339 | hps.nally = 7 340 | hps.obs_allies = 15 341 | hps.obs_enemies = 15 342 | hps.batches_per_update_schedule = '20e6:2,35e6:4,45e6:8' 343 | hps.entropy_bonus = 0.01 344 | hps.entropy_bonus_schedule = '15e6:0.003,40e6:0.001' 345 | return hps 346 | 347 | @staticmethod 348 | def arena_tiny_2v2(): 349 | hps = HyperParams() 350 | hps.objective = envs.Objective.ARENA_TINY_2V2 351 | 352 | hps.steps = 25e6 353 | 354 | hps.entropy_bonus = 0.001 355 | 356 | hps.agents = 2 357 | hps.nally = 2 358 | hps.nenemy = 2 359 | hps.nmineral = 1 360 | hps.obs_allies = 2 361 | hps.obs_enemies = 2 362 | hps.obs_minerals = 1 # Could be 0, currently incompatible with ally_enemy_same=False 363 | 364 | hps.eval_envs = 256 365 | hps.eval_timesteps = 360 366 | hps.eval_frequency = 1e5 367 | hps.eval_symmetric = False 368 | 369 | return hps 370 | 371 | @staticmethod 372 | def arena_tiny(): 373 | hps = HyperParams() 374 | hps.objective = envs.Objective.ARENA_TINY 375 | 376 | hps.steps = 2e6 377 | 378 | hps.d_agent = 128 379 | hps.d_item = 64 380 | 381 | hps.agents = 1 382 | hps.nally = 1 383 | hps.nenemy = 1 384 | hps.nmineral = 1 385 | hps.obs_allies = 1 386 | hps.obs_enemies = 1 387 | hps.obs_minerals = 1 # Could be 0, currently incompatible with ally_enemy_same=False 388 | 389 | hps.eval_envs = 256 390 | hps.eval_frequency = 1e5 391 | hps.eval_timesteps = 360 392 | 393 | hps.num_envs = 64 394 | hps.num_self_play = 32 395 | 
hps.seq_rosteps = 256 396 | hps.eval_symmetric = False 397 | 398 | return hps 399 | 400 | 401 | @staticmethod 402 | def scout(): 403 | hps = HyperParams() 404 | hps.objective = envs.Objective.SCOUT 405 | 406 | hps.steps = 1e6 407 | 408 | hps.agents = 5 409 | hps.nenemy = 5 410 | hps.nally = 5 411 | hps.nmineral = 0 412 | hps.ntile = 5 413 | hps.obs_map_tiles = 10 414 | hps.use_privileged = False 415 | 416 | hps.batches_per_update = 1 417 | hps.bs = 256 418 | hps.seq_rosteps = 64 419 | hps.num_envs = 64 420 | hps.num_self_play = 0 421 | 422 | hps.eval_envs = 0 423 | 424 | hps.gamma = 0.99 425 | 426 | return hps 427 | 428 | 429 | @staticmethod 430 | def allied_wealth(): 431 | hps = HyperParams() 432 | hps.clip_vf = True 433 | hps.steps = 1.5e6 434 | hps.dff_ratio = 2 435 | hps.eval_envs = 0 436 | hps.gamma = 0.99 437 | hps.lamb = 0.95 438 | hps.lr = 0.0003 439 | hps.max_grad_norm = 20.0 440 | hps.momentum = 0.9 441 | hps.norm = 'layernorm' 442 | hps.norm_advs = True 443 | hps.num_envs = 64 444 | hps.num_self_play = 0 445 | hps.objective = envs.Objective.ALLIED_WEALTH 446 | hps.nally = 1 447 | hps.nmineral = 10 448 | hps.obs_allies = 1 449 | hps.obs_map_tiles = 0 450 | hps.obs_enemies = 0 451 | hps.obs_global_drones = 0 452 | hps.optimizer = 'Adam' 453 | hps.epochs = 2 454 | hps.small_init_pi = False 455 | hps.transformer_layers = 1 456 | hps.use_action_masks = True 457 | hps.use_privileged = False 458 | hps.vf_coef = 1.0 459 | hps.weight_decay = 0.0001 460 | hps.zero_init_vf = True 461 | 462 | return hps 463 | 464 | 465 | @staticmethod 466 | def distance_to_origin(): 467 | hps = HyperParams() 468 | hps.objective = envs.Objective.DISTANCE_TO_ORIGIN 469 | hps.num_self_play = 0 470 | hps.eval_envs = 0 471 | hps.agents = 1 472 | hps.obs_allies = 1 473 | hps.obs_enemies = 0 474 | hps.use_privileged = False 475 | 476 | return hps 477 | 478 | 479 | @staticmethod 480 | def distance_to_mineral(): 481 | hps = HyperParams() 482 | hps.objective = 
envs.Objective.DISTANCE_TO_CRYSTAL 483 | hps.num_self_play = 0 484 | hps.eval_envs = 0 485 | hps.agents = 1 486 | hps.obs_allies = 1 487 | hps.obs_enemies = 0 488 | hps.use_privileged = False 489 | 490 | return hps 491 | 492 | 493 | @property 494 | def rosteps(self): 495 | return self.num_envs * self.seq_rosteps 496 | 497 | def get_num_self_play_schedule(self): 498 | return parse_int_schedule(self.num_self_play_schedule) 499 | 500 | def get_entropy_bonus_schedule(self): 501 | return parse_float_schedule(self.entropy_bonus_schedule) 502 | 503 | def get_batches_per_update_schedule(self): 504 | return parse_int_schedule(self.batches_per_update_schedule) 505 | 506 | def get_variety_schedule(self) -> List[Tuple[float, float]]: 507 | return parse_float_schedule(self.adr_variety_schedule) 508 | 509 | def args_parser(self) -> argparse.ArgumentParser: 510 | parser = argparse.ArgumentParser() 511 | for name, value in vars(self).items(): 512 | if isinstance(value, bool): 513 | parser.add_argument(f"--no-{name}", action='store_const', const=False, dest=name) 514 | parser.add_argument(f"--{name}", action='store_const', const=True, dest=name) 515 | else: 516 | parser.add_argument(f"--{name}", type=type(value)) 517 | return parser 518 | 519 | 520 | class HPSchedule(ABC): 521 | @abstractmethod 522 | def value_at(self, step: int) -> float: 523 | pass 524 | 525 | 526 | class LinearHPSchedule(HPSchedule): 527 | def __init__(self, segments: List[Tuple[int, float]]): 528 | self.segments = segments 529 | 530 | def value_at(self, step: int) -> float: 531 | left, right = find_adjacent(self.segments, step) 532 | if right is None: 533 | return left[1] 534 | return left[1] + (step - left[0]) * (right[1] - left[1]) / (right[0] - left[0]) 535 | 536 | 537 | class CosineSchedule(HPSchedule): 538 | def __init__(self, initial_value: float, final_value: float, steps: int): 539 | self.initial_value = initial_value 540 | self.final_value = final_value 541 | self.steps = steps 542 | 543 | def 
value_at(self, step: int) -> float: 544 | return (self.initial_value - self.final_value) * 0.5 * (math.cos(math.pi * step / self.steps) + 1) \ 545 | + self.final_value 546 | 547 | 548 | class StepHPSchedule(HPSchedule): 549 | def __init__(self, segments: List[Tuple[int, float]]): 550 | self.segments = segments 551 | 552 | def value_at(self, step: int) -> float: 553 | left, _ = find_adjacent(self.segments, step) 554 | return left[1] 555 | 556 | 557 | class ConstantSchedule(HPSchedule): 558 | def __init__(self, value): 559 | self.value = value 560 | 561 | def value_at(self, step) -> float: 562 | return self.value 563 | 564 | 565 | def parse_schedule(schedule: str, initial_value: float, steps: int) -> HPSchedule: 566 | if schedule == '': 567 | return ConstantSchedule(initial_value) 568 | elif schedule.startswith('lin '): 569 | segments = [(0, initial_value)] 570 | for kv in schedule[len('lin '):].split(","): 571 | [k, v] = kv.split(":") 572 | segments.append((int(float(k)), float(v))) 573 | return LinearHPSchedule(segments) 574 | elif schedule.startswith('cos'): 575 | if schedule == 'cos': 576 | final_value = 0.0 577 | else: 578 | final_value = float(schedule[len('cos '):]) 579 | return CosineSchedule(initial_value, final_value, steps) 580 | else: 581 | segments = [(0, initial_value)] 582 | for kv in schedule.split(","): 583 | [k, v] = kv.split(":") 584 | segments.append((int(float(k)), float(v))) 585 | return StepHPSchedule(segments) 586 | 587 | 588 | def find_adjacent(segments: List[Tuple[int, float]], step: int) -> Tuple[Tuple[int, float], Optional[Tuple[int, float]]]: 589 | left = None 590 | right: Optional[Tuple[int, float]] = None 591 | for s, v in segments: 592 | if s <= step: 593 | left = (s, v) 594 | if step < s: 595 | right = (s, v) 596 | break 597 | assert left is not None, f"invalid inputs to find_adjacent: segments={segments}, step={step}" 598 | return left, right 599 | 600 | 601 | def parse_int_schedule(schedule): 602 | if schedule == '': 603 | return [] 
604 | else: 605 | items = [] 606 | for kv in schedule.split(","): 607 | [k, v] = kv.split(":") 608 | items.append((float(k), int(v))) 609 | return list(reversed(items)) 610 | 611 | 612 | def parse_float_schedule(schedule) -> List[Tuple[float, float]]: 613 | if schedule == '': 614 | return [] 615 | else: 616 | items = [] 617 | for kv in schedule.split(","): 618 | [k, v] = kv.split(":") 619 | items.append((float(k), float(v))) 620 | return list(reversed(items)) 621 | -------------------------------------------------------------------------------- /list_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ListNet(nn.Module): 7 | def __init__(self, in_features, width, items, groups, pooling, norm, resblocks=1): 8 | super(ListNet, self).__init__() 9 | 10 | assert(pooling in ['max', 'avg', 'both']) 11 | 12 | self.in_features = in_features 13 | self.width = width // 2 if pooling == 'both' else width 14 | self.output_width = width 15 | self.items = items 16 | self.groups = groups 17 | self.pooling = pooling 18 | self.norm = norm 19 | 20 | self.layer0 = nn.Conv1d(in_channels=1, out_channels=self.width, kernel_size=in_features) 21 | 22 | if norm == 'none': 23 | self.layer0_norm = nn.Sequential() 24 | elif norm == 'batchnorm': 25 | self.layer0_norm = nn.BatchNorm1d(self.width) 26 | elif norm == 'layernorm': 27 | self.layer0_norm = nn.LayerNorm([self.width, 1]) 28 | else: 29 | raise Exception(f'Unexpected normalization layer {norm}') 30 | 31 | self.net = nn.Sequential( 32 | *[ResBlock(self.width, norm) for _ in range(resblocks)] 33 | ) 34 | 35 | def forward(self, x): 36 | batch_size = x.shape[0] 37 | x = x.reshape(batch_size * self.groups * self.items, 1, self.in_features) 38 | x = self.layer0_norm(F.relu(self.layer0(x))) 39 | x = self.net(x) 40 | x = x.view(batch_size, self.groups, self.items, self.width) 41 | x = x.permute(0, 1, 3, 
2).reshape(batch_size * self.groups, self.width, self.items) 42 | 43 | if self.pooling == 'max': 44 | x = F.max_pool1d(x, kernel_size=self.items) 45 | elif self.pooling == 'avg': 46 | x = F.avg_pool1d(x, kernel_size=self.items) 47 | elif self.pooling == 'both': 48 | x_max = F.max_pool1d(x, kernel_size=self.items) 49 | x_avg = F.avg_pool1d(x, kernel_size=self.items) 50 | x = torch.cat([x_max, x_avg], dim=1) 51 | else: 52 | raise Exception(f'Invalid pooling variant {self.pooling}') 53 | 54 | return x.reshape(batch_size, self.groups, self.output_width, 1).permute(0, 2, 1, 3) 55 | 56 | 57 | class ResBlock(nn.Module): 58 | def __init__(self, channels, norm): 59 | super(ResBlock, self).__init__() 60 | if norm == 'none': 61 | self.convs = nn.Sequential( 62 | nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=1), 63 | nn.ReLU(), 64 | nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=1), 65 | nn.ReLU(), 66 | ) 67 | elif norm == 'batchnorm': 68 | self.convs = nn.Sequential( 69 | nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=1), 70 | nn.ReLU(), 71 | nn.BatchNorm1d(channels), 72 | nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=1), 73 | nn.ReLU(), 74 | nn.BatchNorm1d(channels), 75 | ) 76 | elif norm == 'layernorm': 77 | self.convs = nn.Sequential( 78 | nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=1), 79 | nn.ReLU(), 80 | nn.LayerNorm([channels, 1]), 81 | nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=1), 82 | nn.ReLU(), 83 | nn.LayerNorm([channels, 1]), 84 | ) 85 | else: 86 | raise Exception(f'Unexpected normalization layer {norm}') 87 | 88 | def forward(self, x): 89 | return x + self.convs(x) 90 | 91 | -------------------------------------------------------------------------------- /multihead_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Linear, Parameter, Module 3 | from 
torch.nn.functional import softmax, dropout, linear 4 | from torch.nn.init import xavier_uniform_, constant_, xavier_normal_ 5 | 6 | 7 | # Copied from https://github.com/pytorch/pytorch/blob/458353b5b6c5ebbc43a1e52faef732d9f2c64671/torch/nn/functional.py#L3138 8 | # Modified to return full attention weights tensor, rather than attention weights averaged over all heads 9 | def multi_head_attention_forward(query, # type: Tensor 10 | key, # type: Tensor 11 | value, # type: Tensor 12 | embed_dim_to_check, # type: int 13 | num_heads, # type: int 14 | in_proj_weight, # type: Tensor 15 | in_proj_bias, # type: Tensor 16 | bias_k, # type: Optional[Tensor] 17 | bias_v, # type: Optional[Tensor] 18 | add_zero_attn, # type: bool 19 | dropout_p, # type: float 20 | out_proj_weight, # type: Tensor 21 | out_proj_bias, # type: Tensor 22 | training=True, # type: bool 23 | key_padding_mask=None, # type: Optional[Tensor] 24 | need_weights=True, # type: bool 25 | attn_mask=None, # type: Optional[Tensor] 26 | use_separate_proj_weight=False, # type: bool 27 | q_proj_weight=None, # type: Optional[Tensor] 28 | k_proj_weight=None, # type: Optional[Tensor] 29 | v_proj_weight=None, # type: Optional[Tensor] 30 | static_k=None, # type: Optional[Tensor] 31 | static_v=None # type: Optional[Tensor] 32 | ): 33 | # type: (...) -> Tuple[Tensor, Optional[Tensor]] 34 | r""" 35 | Args: 36 | query, key, value: map a query and a set of key-value pairs to an output. 37 | See "Attention Is All You Need" for more details. 38 | embed_dim_to_check: total dimension of the model. 39 | num_heads: parallel attention heads. 40 | in_proj_weight, in_proj_bias: input projection weight and bias. 41 | bias_k, bias_v: bias of the key and value sequences to be added at dim=0. 42 | add_zero_attn: add a new batch of zeros to the key and 43 | value sequences at dim=1. 44 | dropout_p: probability of an element to be zeroed. 45 | out_proj_weight, out_proj_bias: the output projection weight and bias. 
46 | training: apply dropout if ``True``. 47 | key_padding_mask: if provided, specified padding elements in the key will 48 | be ignored by the attention. This is a binary mask. When the value is True, 49 | the corresponding value on the attention layer will be filled with -inf. 50 | need_weights: output attn_output_weights. 51 | attn_mask: mask that prevents attention to certain positions. This is an additive mask 52 | (i.e. the values will be added to the attention layer). 53 | use_separate_proj_weight: the function accepts the proj. weights for query, key, 54 | and value in different forms. If false, in_proj_weight will be used, which is 55 | a combination of q_proj_weight, k_proj_weight, v_proj_weight. 56 | q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. 57 | static_k, static_v: static key and value used for attention operators. 58 | Shape: 59 | Inputs: 60 | - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is 61 | the embedding dimension. 62 | - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is 63 | the embedding dimension. 64 | - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is 65 | the embedding dimension. 66 | - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length. 67 | - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length. 68 | - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, 69 | N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. 70 | - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, 71 | N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. 
72 | Outputs: 73 | - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, 74 | E is the embedding dimension. 75 | - attn_output_weights: :math:`(N, L, S)` where N is the batch size, 76 | L is the target sequence length, S is the source sequence length. 77 | """ 78 | 79 | tgt_len, bsz, embed_dim = query.size() 80 | assert embed_dim == embed_dim_to_check 81 | assert key.size() == value.size() 82 | 83 | head_dim = embed_dim // num_heads 84 | assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" 85 | scaling = float(head_dim) ** -0.5 86 | 87 | if not use_separate_proj_weight: 88 | if torch.equal(query, key) and torch.equal(key, value): 89 | # self-attention 90 | q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) 91 | 92 | elif torch.equal(key, value): 93 | # encoder-decoder attention 94 | # This is inline in_proj function with in_proj_weight and in_proj_bias 95 | _b = in_proj_bias 96 | _start = 0 97 | _end = embed_dim 98 | _w = in_proj_weight[_start:_end, :] 99 | if _b is not None: 100 | _b = _b[_start:_end] 101 | q = linear(query, _w, _b) 102 | 103 | if key is None: 104 | assert value is None 105 | k = None 106 | v = None 107 | else: 108 | 109 | # This is inline in_proj function with in_proj_weight and in_proj_bias 110 | _b = in_proj_bias 111 | _start = embed_dim 112 | _end = None 113 | _w = in_proj_weight[_start:, :] 114 | if _b is not None: 115 | _b = _b[_start:] 116 | k, v = linear(key, _w, _b).chunk(2, dim=-1) 117 | 118 | else: 119 | # This is inline in_proj function with in_proj_weight and in_proj_bias 120 | _b = in_proj_bias 121 | _start = 0 122 | _end = embed_dim 123 | _w = in_proj_weight[_start:_end, :] 124 | if _b is not None: 125 | _b = _b[_start:_end] 126 | q = linear(query, _w, _b) 127 | 128 | # This is inline in_proj function with in_proj_weight and in_proj_bias 129 | _b = in_proj_bias 130 | _start = embed_dim 131 | _end = embed_dim * 2 132 | _w = 
in_proj_weight[_start:_end, :] 133 | if _b is not None: 134 | _b = _b[_start:_end] 135 | k = linear(key, _w, _b) 136 | 137 | # This is inline in_proj function with in_proj_weight and in_proj_bias 138 | _b = in_proj_bias 139 | _start = embed_dim * 2 140 | _end = None 141 | _w = in_proj_weight[_start:, :] 142 | if _b is not None: 143 | _b = _b[_start:] 144 | v = linear(value, _w, _b) 145 | else: 146 | q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight) 147 | len1, len2 = q_proj_weight_non_opt.size() 148 | assert len1 == embed_dim and len2 == query.size(-1) 149 | 150 | k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight) 151 | len1, len2 = k_proj_weight_non_opt.size() 152 | assert len1 == embed_dim and len2 == key.size(-1) 153 | 154 | v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight) 155 | len1, len2 = v_proj_weight_non_opt.size() 156 | assert len1 == embed_dim and len2 == value.size(-1) 157 | 158 | if in_proj_bias is not None: 159 | q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim]) 160 | k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)]) 161 | v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):]) 162 | else: 163 | q = linear(query, q_proj_weight_non_opt, in_proj_bias) 164 | k = linear(key, k_proj_weight_non_opt, in_proj_bias) 165 | v = linear(value, v_proj_weight_non_opt, in_proj_bias) 166 | q = q * scaling 167 | 168 | if bias_k is not None and bias_v is not None: 169 | if static_k is None and static_v is None: 170 | k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) 171 | v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) 172 | if attn_mask is not None: 173 | attn_mask = torch.cat([attn_mask, 174 | torch.zeros((attn_mask.size(0), 1), 175 | dtype=attn_mask.dtype, 176 | device=attn_mask.device)], dim=1) 177 | if key_padding_mask is not None: 178 | key_padding_mask = torch.cat( 179 | [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1), 180 | 
dtype=key_padding_mask.dtype, 181 | device=key_padding_mask.device)], dim=1) 182 | else: 183 | assert static_k is None, "bias cannot be added to static key." 184 | assert static_v is None, "bias cannot be added to static value." 185 | else: 186 | assert bias_k is None 187 | assert bias_v is None 188 | 189 | q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) 190 | if k is not None: 191 | k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 192 | if v is not None: 193 | v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 194 | 195 | if static_k is not None: 196 | assert static_k.size(0) == bsz * num_heads 197 | assert static_k.size(2) == head_dim 198 | k = static_k 199 | 200 | if static_v is not None: 201 | assert static_v.size(0) == bsz * num_heads 202 | assert static_v.size(2) == head_dim 203 | v = static_v 204 | 205 | src_len = k.size(1) 206 | 207 | if key_padding_mask is not None: 208 | assert key_padding_mask.size(0) == bsz 209 | assert key_padding_mask.size(1) == src_len 210 | 211 | if add_zero_attn: 212 | src_len += 1 213 | k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1) 214 | v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1) 215 | if attn_mask is not None: 216 | attn_mask = torch.cat([attn_mask, torch.zeros((attn_mask.size(0), 1), 217 | dtype=attn_mask.dtype, 218 | device=attn_mask.device)], dim=1) 219 | if key_padding_mask is not None: 220 | key_padding_mask = torch.cat( 221 | [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1), 222 | dtype=key_padding_mask.dtype, 223 | device=key_padding_mask.device)], dim=1) 224 | 225 | attn_output_weights = torch.bmm(q, k.transpose(1, 2)) 226 | assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] 227 | 228 | if attn_mask is not None: 229 | attn_mask = attn_mask.unsqueeze(0) 230 | attn_output_weights += attn_mask 231 | 232 
| if key_padding_mask is not None: 233 | attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) 234 | attn_output_weights = attn_output_weights.masked_fill( 235 | key_padding_mask.unsqueeze(1).unsqueeze(2), 236 | float('-inf'), 237 | ) 238 | attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) 239 | 240 | attn_output_weights = softmax( 241 | attn_output_weights, dim=-1) 242 | attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training) 243 | 244 | attn_output = torch.bmm(attn_output_weights, v) 245 | assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] 246 | attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 247 | attn_output = linear(attn_output, out_proj_weight, out_proj_bias) 248 | 249 | if need_weights: 250 | # ORIGINAL CODE 251 | # average attention weights over heads 252 | # attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) 253 | # return attn_output, attn_output_weights.sum(dim=1) / num_heads 254 | 255 | # Return full attention weights instead 256 | return attn_output, attn_output_weights 257 | else: 258 | return attn_output, None 259 | 260 | 261 | # Copied from https://github.com/pytorch/pytorch/blob/v1.4.0/torch/nn/modules/activation.py#L673 262 | class MultiheadAttention(Module): 263 | r"""Allows the model to jointly attend to information 264 | from different representation subspaces. 265 | See reference: Attention Is All You Need 266 | 267 | .. math:: 268 | \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O 269 | \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) 270 | 271 | Args: 272 | embed_dim: total dimension of the model. 273 | num_heads: parallel attention heads. 274 | dropout: a Dropout layer on attn_output_weights. Default: 0.0. 275 | bias: add bias as module parameter. Default: True. 276 | add_bias_kv: add bias to the key and value sequences at dim=0. 
277 | add_zero_attn: add a new batch of zeros to the key and 278 | value sequences at dim=1. 279 | kdim: total number of features in key. Default: None. 280 | vdim: total number of features in value. Default: None. 281 | 282 | Note: if kdim and vdim are None, they will be set to embed_dim such that 283 | query, key, and value have the same number of features. 284 | 285 | Examples:: 286 | 287 | >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) 288 | >>> attn_output, attn_output_weights = multihead_attn(query, key, value) 289 | """ 290 | __annotations__ = { 291 | 'bias_k': torch._jit_internal.Optional[torch.Tensor], 292 | 'bias_v': torch._jit_internal.Optional[torch.Tensor], 293 | } 294 | __constants__ = ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight'] 295 | 296 | def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None): 297 | super(MultiheadAttention, self).__init__() 298 | self.embed_dim = embed_dim 299 | self.kdim = kdim if kdim is not None else embed_dim 300 | self.vdim = vdim if vdim is not None else embed_dim 301 | self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim 302 | 303 | self.num_heads = num_heads 304 | self.dropout = dropout 305 | self.head_dim = embed_dim // num_heads 306 | assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" 307 | 308 | if self._qkv_same_embed_dim is False: 309 | self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) 310 | self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim)) 311 | self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim)) 312 | self.register_parameter('in_proj_weight', None) 313 | else: 314 | self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim)) 315 | self.register_parameter('q_proj_weight', None) 316 | self.register_parameter('k_proj_weight', None) 317 | self.register_parameter('v_proj_weight', None) 
318 | 319 | if bias: 320 | self.in_proj_bias = Parameter(torch.empty(3 * embed_dim)) 321 | else: 322 | self.register_parameter('in_proj_bias', None) 323 | self.out_proj = Linear(embed_dim, embed_dim, bias=bias) 324 | 325 | if add_bias_kv: 326 | self.bias_k = Parameter(torch.empty(1, 1, embed_dim)) 327 | self.bias_v = Parameter(torch.empty(1, 1, embed_dim)) 328 | else: 329 | self.bias_k = self.bias_v = None 330 | 331 | self.add_zero_attn = add_zero_attn 332 | 333 | self._reset_parameters() 334 | 335 | def _reset_parameters(self): 336 | if self._qkv_same_embed_dim: 337 | xavier_uniform_(self.in_proj_weight) 338 | else: 339 | xavier_uniform_(self.q_proj_weight) 340 | xavier_uniform_(self.k_proj_weight) 341 | xavier_uniform_(self.v_proj_weight) 342 | 343 | if self.in_proj_bias is not None: 344 | constant_(self.in_proj_bias, 0.) 345 | constant_(self.out_proj.bias, 0.) 346 | if self.bias_k is not None: 347 | xavier_normal_(self.bias_k) 348 | if self.bias_v is not None: 349 | xavier_normal_(self.bias_v) 350 | 351 | def __setstate__(self, state): 352 | super(MultiheadAttention, self).__setstate__(state) 353 | 354 | # Support loading old MultiheadAttention checkpoints generated by v1.1.0 355 | if '_qkv_same_embed_dim' not in self.__dict__: 356 | self._qkv_same_embed_dim = True 357 | 358 | def forward(self, query, key, value, key_padding_mask=None, 359 | need_weights=True, attn_mask=None): 360 | # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]] 361 | r""" 362 | Args: 363 | query, key, value: map a query and a set of key-value pairs to an output. 364 | See "Attention Is All You Need" for more details. 365 | key_padding_mask: if provided, specified padding elements in the key will 366 | be ignored by the attention. This is a binary mask. When the value is True, 367 | the corresponding value on the attention layer will be filled with -inf. 368 | need_weights: output attn_output_weights. 
369 | attn_mask: mask that prevents attention to certain positions. This is an additive mask 370 | (i.e. the values will be added to the attention layer). 371 | 372 | Shape: 373 | - Inputs: 374 | - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is 375 | the embedding dimension. 376 | - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is 377 | the embedding dimension. 378 | - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is 379 | the embedding dimension. 380 | - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length. 381 | - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length. 382 | 383 | - Outputs: 384 | - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, 385 | E is the embedding dimension. 386 | - attn_output_weights: :math:`(N \cdot H, L, S)` where N is the batch size, H is the number of heads, 387 | L is the target sequence length, S is the source sequence length (per-head weights, not averaged over heads). 
388 | """ 389 | if not self._qkv_same_embed_dim: 390 | return multi_head_attention_forward( 391 | query, key, value, self.embed_dim, self.num_heads, 392 | self.in_proj_weight, self.in_proj_bias, 393 | self.bias_k, self.bias_v, self.add_zero_attn, 394 | self.dropout, self.out_proj.weight, self.out_proj.bias, 395 | training=self.training, 396 | key_padding_mask=key_padding_mask, need_weights=need_weights, 397 | attn_mask=attn_mask, use_separate_proj_weight=True, 398 | q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, 399 | v_proj_weight=self.v_proj_weight) 400 | else: 401 | return multi_head_attention_forward( 402 | query, key, value, self.embed_dim, self.num_heads, 403 | self.in_proj_weight, self.in_proj_bias, 404 | self.bias_k, self.bias_v, self.add_zero_attn, 405 | self.dropout, self.out_proj.weight, self.out_proj.bias, 406 | training=self.training, 407 | key_padding_mask=key_padding_mask, need_weights=need_weights, 408 | attn_mask=attn_mask) 409 | 410 | -------------------------------------------------------------------------------- /plot_results.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import os 3 | import math 4 | import matplotlib 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | from typing import Tuple, List, Dict, Union 8 | from functools import lru_cache 9 | from dataclasses import dataclass 10 | 11 | EVAL_METRICS = { 12 | "Replicator": "eval_mean_score_vs_replicator", 13 | "Destroyer": "eval_mean_score_vs_destroyer", 14 | "Curious Galaxy 40M": "eval_mean_score_vs_curious-galaxy-40", 15 | "Graceful Frog 100M": "eval_mean_score_vs_graceful-frog-100", 16 | } 17 | 18 | @dataclass 19 | class Experiment: 20 | descriptor: str 21 | label: str 22 | 23 | @lru_cache(maxsize=None) 24 | def fetch_run_data(descriptor: str, metrics: Union[List[str], str]) -> List[Tuple[np.array, np.array]]: 25 | if isinstance(metrics, str): 26 | metrics = [metrics] 27 | else: 28 | metrics = 
list(metrics) 29 | api = wandb.Api() 30 | runs = api.runs("cswinter/deep-codecraft-ablations", {"config.descriptor": descriptor}) 31 | 32 | curves = [] 33 | for run in runs: 34 | step = [] 35 | value = [] 36 | vals = run.history(keys=metrics, samples=100, pandas=False) 37 | for entry in vals: 38 | if metrics[0] in entry: 39 | step.append(entry['_step'] * 1e-6) 40 | meanvalue = np.array([entry[metric] for metric in metrics]).mean() 41 | value.append(meanvalue) 42 | curves.append((np.array(step), np.array(value))) 43 | return curves 44 | 45 | def final_score(descriptor: str) -> Tuple[float, float]: 46 | runs = fetch_run_data(descriptor, tuple(EVAL_METRICS.values())) 47 | runs = [run for run in runs if len(run[0]) == 26] 48 | if len(runs) < 8: 49 | print(f"Only {len(runs)} for {descriptor}") 50 | values = np.array([[run[1][i] for run in runs] for i in range(len(runs[0][0]))]) 51 | return values.mean(axis=1)[-1], (values.std(axis=1, ddof=1)/math.sqrt(len(runs)))[-1] 52 | 53 | def errplot3(ax, xps: List[Experiment], metrics: Union[List[str], str], title: str): 54 | colors = ['tab:blue', 'tab:orange'] 55 | markers = ['x', '+'] 56 | 57 | for i, xp in enumerate(xps): 58 | curves = fetch_run_data(xp.descriptor, metrics) 59 | curves = [curve for curve in curves if len(curve[0]) == 26] 60 | 61 | samples = curves[0][0] 62 | values = np.array([[curve[1][i] for curve in curves] for i in range(len(samples))]) 63 | ax.errorbar( 64 | samples, 65 | values.mean(axis=1), 66 | yerr=values.std(axis=1, ddof=1)/math.sqrt(len(curves)), 67 | color=colors[i], 68 | alpha=0.75, 69 | capsize=3, 70 | capthick=1, 71 | linestyle=":", 72 | label=xp.label, 73 | ) 74 | ax.fill_between( 75 | samples, 76 | values.min(axis=1), 77 | values.max(axis=1), 78 | alpha=.25 79 | ) 80 | 81 | 82 | ax.set(xlabel='million samples', ylabel='eval score', title=title, xlim=(0, 125.35), ylim=(-1, 1)) 83 | ax.set_yticks([-1.0, -0.5, 0, 0.5, 1]) 84 | ax.set_xticks([0, 25, 50, 75, 100, 125]) 85 | ax.legend(loc='upper 
left') 86 | ax.grid() 87 | 88 | 89 | def plot_drone_types(ax, runid: str): 90 | run = wandb.Api().run(f'cswinter/deep-codecraft-ablations/{runid}') 91 | frac_metrics = sorted([key for key in run.summary.keys() if key.startswith('frac')]) 92 | fracs = [ 93 | {row['_step']: row[fm] for row in run.scan_history(keys=['_step', fm], page_size=int(1e9))} 94 | for fm in frac_metrics 95 | ] 96 | steps = set() 97 | for frac in fracs: 98 | steps.update(frac.keys()) 99 | steps = sorted(list(steps)) 100 | fixed_fracs = [] 101 | for frac in fracs: 102 | fixed_frac = [] 103 | for step in steps: 104 | fixed_frac.append(frac.get(step, 0.0)) 105 | fixed_fracs.append(np.array(fixed_frac)) 106 | 107 | binned_fracs = [] 108 | bins = np.linspace(0, 125e6, 250) 109 | digitized = np.digitize(steps, bins) 110 | for frac in fixed_fracs: 111 | bin_means = [frac[digitized == i].mean() for i in range(1, len(bins)+1)] 112 | binned_fracs.append(bin_means) 113 | 114 | labels = [m[len('frac_'):] for m in frac_metrics] 115 | ax.stackplot([0.5 * i + 0.25 for i in range(250)], binned_fracs, labels=labels) 116 | ax.set(xlabel='million samples', ylabel='drone type fraction', xlim=(0, 125), ylim=(0, 1), title=' '.join(run.name.split('-')[:2])) 117 | ax.set_xticks([0, 25, 50, 75, 100, 125]) 118 | ax.legend(reversed(plt.legend().legendHandles), reversed(labels), loc='lower right') 119 | 120 | 121 | def plot2dt(runid1: str, runid2: str): 122 | fig, axs = plt.subplots(1, 2, figsize=(12, 6)) 123 | plot_drone_types(axs[0], runid1) 124 | plot_drone_types(axs[1], runid2) 125 | fig.savefig(f"plots/dronetypes.svg") 126 | fig.savefig(f"plotspng/dronetypes.png") 127 | plt.show() 128 | 129 | def plot3dt(runid1: str, runid2: str, runid3: str): 130 | fig, axs = plt.subplots(1, 3, figsize=(18, 6)) 131 | plot_drone_types(axs[0], runid1) 132 | plot_drone_types(axs[1], runid2) 133 | plot_drone_types(axs[2], runid3) 134 | fig.savefig(f"plots/dronetypes3.svg") 135 | fig.savefig(f"plotspng/dronetypes3.png") 136 | plt.show() 
137 | 138 | 139 | def plot(xps: List[Experiment], metrics: List[str], title: str, name: str): 140 | fig, ax = plt.subplots(figsize=(12, 9)) 141 | errplot3(ax, xps, metrics, title) 142 | fig.savefig(f"plots/{name}.svg") 143 | fig.savefig(f"plotspng/{name}.png") 144 | plt.show() 145 | 146 | 147 | def plot4(descriptors: List[Experiment], metrics: Dict[str, str], name: str): 148 | assert len(metrics) == 4 149 | 150 | fig, axs = plt.subplots(2, 2, figsize=(12, 9)) 151 | #fig.suptitle(name) 152 | for i, (metric_title, metric_name) in enumerate(metrics.items()): 153 | print(f"{i}/{len(metrics)} {len(axs)}") 154 | errplot3(axs[i // 2, i % 2], descriptors, metric_name, metric_title) 155 | fig.savefig(f"plots/{name}.svg") 156 | fig.savefig(f"plotspng/{name}.png") 157 | plt.show() 158 | 159 | 160 | if __name__ == '__main__': 161 | if not os.path.exists('plots'): 162 | os.makedirs('plots') 163 | if not os.path.exists('plotspng'): 164 | os.makedirs('plotspng') 165 | 166 | plot2dt('i17gv7pw', 'sidk0gu4') 167 | plot3dt('lbspx7ok', 'i17gv7pw', 'sidk0gu4') 168 | 169 | baseline = Experiment(descriptor="f2034f-hpsetstandard", label="baseline") 170 | adr_ablations = [ 171 | Experiment("f2034f-hpsetstandard-mothership_damage_scale0.0-mothership_damage_scale_schedule", "module cost, map curriculum"), 172 | Experiment("f2034f-adr_variety0.0-adr_variety_schedule-hpsetstandard", "mothership damage, map curriculum"), 173 | Experiment("f2034f-adr_variety0.0-adr_variety_schedule-hpsetstandard-mothership_damage_scale0.0-mothership_damage_scale_schedule", "map curriculum"), 174 | 175 | Experiment("f2034f-adr_hstepsize0.0-hpsetstandard-linear_hardnessFalse-task_hardness150", "mothership damage, module cost, map randomization"), 176 | Experiment("f2034f-adr_hstepsize0.0-hpsetstandard-linear_hardnessFalse-mothership_damage_scale0.0-mothership_damage_scale_schedule-task_hardness150", "module cost, map randomization"), 177 | 
Experiment("f2034f-adr_hstepsize0.0-adr_variety0.0-adr_variety_schedule-hpsetstandard-linear_hardnessFalse-task_hardness150", "mothership damage, map randomization"), 178 | Experiment("f2034f-adr_hstepsize0.0-adr_variety0.0-adr_variety_schedule-hpsetstandard-linear_hardnessFalse-mothership_damage_scale0.0-mothership_damage_scale_schedule-task_hardness150", "map randomization"), 179 | 180 | Experiment("049430-batches_per_update64-bs256-hpsetstandard", "mothership damage, module cost, fixed map"), 181 | Experiment("049430-batches_per_update64-bs256-hpsetstandard-mothership_damage_scale0.0-mothership_damage_scale_schedule", "module cost, fixed map"), 182 | Experiment("049430-adr_variety0.0-adr_variety_schedule-batches_per_update64-bs256-hpsetstandard", "mothership damage, fixed map"), 183 | Experiment("049430-adr_variety0.0-adr_variety_schedule-batches_per_update64-bs256-hpsetstandard-mothership_damage_scale0.0-mothership_damage_scale_schedule", "fixed map"), 184 | 185 | Experiment("d06bdd-hpsetstandard", "mothership damage, random module cost, map curriculum"), 186 | ] 187 | ablations = [ 188 | Experiment("f2034f-hpsetstandard-partial_score0.0", "sparse reward"), 189 | Experiment("f2034f-hpsetstandard-use_privilegedFalse", "non-omniscient value function"), 190 | Experiment("f2034f-d_agent128-d_item64-hpsetstandard", "smaller network"), 191 | Experiment("f2034f-batches_per_update64-bs256-hpsetstandard-rotational_invarianceFalse", "no rotational invariance"), 192 | Experiment("7a9d92-hpsetstandard", "no shared spatial embeddings"), 193 | *adr_ablations, 194 | ] 195 | 196 | 197 | for xp in [baseline] + adr_ablations: 198 | label = xp.label 199 | score_mean, score_sem = final_score(xp.descriptor) 200 | print(f"{label} {score_mean} {score_sem}") 201 | 202 | plot([baseline], tuple(EVAL_METRICS.values()), "Mean score against all opponents", "baseline") 203 | plot4([baseline], EVAL_METRICS, "breakdown") 204 | plot4([baseline, ablations[3]], EVAL_METRICS, "breakdown cost 
adr") 205 | 206 | 207 | for xp in ablations: 208 | print(f"plotting {xp.label}") 209 | plot([baseline, xp], tuple(EVAL_METRICS.values()), "Mean score against all opponents", xp.label) 210 | plot4([baseline, xp], EVAL_METRICS, f"breakdown {xp.label}") 211 | 212 | -------------------------------------------------------------------------------- /policy_t2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.distributions as distributions 5 | 6 | from gather import topk_by 7 | from multihead_attention import MultiheadAttention 8 | import spatial 9 | from gym_codecraft.envs.codecraft_vec_env import DEFAULT_OBS_CONFIG 10 | 11 | 12 | GLOBAL_FEATURES_V2 = 2 13 | DSTRIDE_V2 = 15 14 | MSTRIDE_V2 = 3 15 | NONOBS_FEATURES_V2 = 3 16 | 17 | 18 | class TransformerPolicy2(nn.Module): 19 | def __init__(self, 20 | d_agent, 21 | d_item, 22 | dff_ratio, 23 | nhead, 24 | dropout, 25 | small_init_pi, 26 | zero_init_vf, 27 | fp16, 28 | norm, 29 | agents, 30 | nally, 31 | nenemy, 32 | nmineral, 33 | obs_config=DEFAULT_OBS_CONFIG, 34 | use_privileged=False, 35 | nearby_map=False, 36 | ring_width=40, 37 | nrays=8, 38 | nrings=8, 39 | map_conv=False, 40 | map_conv_kernel_size=3, 41 | map_embed_offset=False, 42 | item_ff=True, 43 | keep_abspos=False, 44 | ally_enemy_same=False, 45 | naction=8, 46 | ): 47 | super(TransformerPolicy2, self).__init__() 48 | assert obs_config.drones > 0 or obs_config.minerals > 0,\ 49 | 'Must have at least one mineral or drones observation' 50 | assert obs_config.drones >= obs_config.allies 51 | assert not use_privileged or (nmineral > 0 and nally > 0 and (nenemy > 0 or ally_enemy_same)) 52 | 53 | self.version = 'transformer_v2' 54 | 55 | self.kwargs = dict( 56 | d_agent=d_agent, 57 | d_item=d_item, 58 | dff_ratio=dff_ratio, 59 | nhead=nhead, 60 | dropout=dropout, 61 | small_init_pi=small_init_pi, 62 | zero_init_vf=zero_init_vf, 63 | 
fp16=fp16, 64 | use_privileged=use_privileged, 65 | norm=norm, 66 | obs_config=obs_config, 67 | agents=agents, 68 | nally=nally, 69 | nenemy=nenemy, 70 | nmineral=nmineral, 71 | nearby_map=nearby_map, 72 | 73 | ring_width=ring_width, 74 | nrays=nrays, 75 | nrings=nrings, 76 | map_conv=map_conv, 77 | map_conv_kernel_size=map_conv_kernel_size, 78 | map_embed_offset=map_embed_offset, 79 | item_ff=item_ff, 80 | keep_abspos=keep_abspos, 81 | ally_enemy_same=ally_enemy_same, 82 | naction=naction, 83 | ) 84 | 85 | self.obs_config = obs_config 86 | self.agents = agents 87 | self.nally = nally 88 | self.nenemy = nenemy 89 | self.nmineral = nmineral 90 | self.nitem = nally + nenemy + nmineral 91 | if hasattr(obs_config, 'global_drones'): 92 | self.global_drones = obs_config.global_drones 93 | else: 94 | self.global_drones = 0 95 | 96 | self.d_agent = d_agent 97 | self.d_item = d_item 98 | self.dff_ratio = dff_ratio 99 | self.nhead = nhead 100 | self.dropout = dropout 101 | self.nearby_map = nearby_map 102 | self.ring_width = ring_width 103 | self.nrays = nrays 104 | self.nrings = nrings 105 | self.map_conv = map_conv 106 | self.map_conv_kernel_size = map_conv_kernel_size 107 | self.map_embed_offset = map_embed_offset 108 | self.item_ff = item_ff 109 | self.naction = naction 110 | 111 | self.fp16 = fp16 112 | self.use_privileged = use_privileged 113 | self.ally_enemy_same = ally_enemy_same 114 | 115 | if norm == 'none': 116 | norm_fn = lambda x: nn.Sequential() 117 | elif norm == 'batchnorm': 118 | norm_fn = lambda n: nn.BatchNorm2d(n) 119 | elif norm == 'layernorm': 120 | norm_fn = lambda n: nn.LayerNorm(n) 121 | else: 122 | raise Exception(f'Unexpected normalization layer {norm}') 123 | 124 | self.agent_embedding = ItemBlock( 125 | DSTRIDE_V2 + GLOBAL_FEATURES_V2, d_agent, d_agent * dff_ratio, norm_fn, True, 126 | keep_abspos=True, 127 | mask_feature=7, # Feature 7 is hitpoints 128 | relpos=False, 129 | ) 130 | if ally_enemy_same: 131 | self.drone_net = ItemBlock( 132 | 
DSTRIDE_V2, d_item, d_item * dff_ratio, norm_fn, self.item_ff, 133 | keep_abspos=keep_abspos, 134 | mask_feature=7, # Feature 7 is hitpoints 135 | topk=nally+nenemy, 136 | ) 137 | else: 138 | self.ally_net = ItemBlock( 139 | DSTRIDE_V2, d_item, d_item * dff_ratio, norm_fn, self.item_ff, 140 | keep_abspos=keep_abspos, 141 | mask_feature=7, # Feature 7 is hitpoints 142 | topk=nally, 143 | ) 144 | self.enemy_net = ItemBlock( 145 | DSTRIDE_V2, d_item, d_item * dff_ratio, norm_fn, self.item_ff, 146 | keep_abspos=keep_abspos, 147 | mask_feature=7, # Feature 7 is hitpoints 148 | topk=nenemy, 149 | ) 150 | self.mineral_net = ItemBlock( 151 | MSTRIDE_V2, d_item, d_item * dff_ratio, norm_fn, self.item_ff, 152 | keep_abspos=keep_abspos, 153 | mask_feature=2, # Feature 2 is size 154 | topk=nmineral, 155 | ) 156 | 157 | if use_privileged: 158 | self.pmineral_net = ItemBlock( 159 | MSTRIDE_V2, d_item, d_item * dff_ratio, norm_fn, self.item_ff, 160 | keep_abspos=True, relpos=False, mask_feature=2, 161 | ) 162 | if ally_enemy_same: 163 | self.pdrone_net = ItemBlock( 164 | DSTRIDE_V2, d_item, d_item * dff_ratio, norm_fn, self.item_ff, 165 | keep_abspos=True, relpos=False, mask_feature=7, 166 | ) 167 | else: 168 | self.pally_net = ItemBlock( 169 | DSTRIDE_V2, d_item, d_item * dff_ratio, norm_fn, self.item_ff, 170 | keep_abspos=True, relpos=False, mask_feature=7, 171 | ) 172 | self.penemy_net = ItemBlock( 173 | DSTRIDE_V2, d_item, d_item * dff_ratio, norm_fn, self.item_ff, 174 | keep_abspos=True, relpos=False, mask_feature=7, 175 | ) 176 | 177 | self.multihead_attention = MultiheadAttention( 178 | embed_dim=d_agent, 179 | kdim=d_item, 180 | vdim=d_item, 181 | num_heads=nhead, 182 | dropout=dropout, 183 | ) 184 | self.linear1 = nn.Linear(d_agent, d_agent * dff_ratio) 185 | self.linear2 = nn.Linear(d_agent * dff_ratio, d_agent) 186 | self.norm1 = nn.LayerNorm(d_agent) 187 | self.norm2 = nn.LayerNorm(d_agent) 188 | 189 | self.map_channels = d_agent // (nrings * nrays) 190 | 
map_item_channels = self.map_channels - 2 if self.map_embed_offset else self.map_channels 191 | self.downscale = nn.Linear(d_item, map_item_channels) 192 | self.norm_map = norm_fn(map_item_channels) 193 | self.conv1 = spatial.ZeroPaddedCylindricalConv2d( 194 | self.map_channels, dff_ratio * self.map_channels, kernel_size=map_conv_kernel_size) 195 | self.conv2 = spatial.ZeroPaddedCylindricalConv2d( 196 | dff_ratio * self.map_channels, self.map_channels, kernel_size=map_conv_kernel_size) 197 | self.norm_conv = norm_fn(self.map_channels) 198 | 199 | final_width = d_agent 200 | if nearby_map: 201 | final_width += d_agent 202 | self.final_layer = nn.Sequential( 203 | nn.Linear(final_width, d_agent * dff_ratio), 204 | nn.ReLU(), 205 | ) 206 | 207 | self.policy_head = nn.Linear(d_agent * dff_ratio, naction) 208 | if small_init_pi: 209 | self.policy_head.weight.data *= 0.01 210 | self.policy_head.bias.data.fill_(0.0) 211 | 212 | if self.use_privileged: 213 | self.value_head = nn.Linear(d_agent * dff_ratio + 2 * d_item, 1) 214 | else: 215 | self.value_head = nn.Linear(d_agent * dff_ratio, 1) 216 | if zero_init_vf: 217 | self.value_head.weight.data.fill_(0.0) 218 | self.value_head.bias.data.fill_(0.0) 219 | 220 | self.epsilon = 1e-4 if fp16 else 1e-8 221 | 222 | def evaluate(self, observation, action_masks, privileged_obs): 223 | if self.fp16: 224 | action_masks = action_masks.half() 225 | action_masks = action_masks[:, :self.agents, :] 226 | probs, v = self.forward(observation, privileged_obs) 227 | probs = probs.view(-1, self.agents, self.naction) 228 | probs = probs * action_masks + self.epsilon # Add small value to prevent crash when no action is possible 229 | # We get device-side assert when using fp16 here (needs more investigation) 230 | action_dist = distributions.Categorical(probs.float() if self.fp16 else probs) 231 | actions = action_dist.sample() 232 | entropy = action_dist.entropy()[action_masks.sum(2) != 0] 233 | return actions, action_dist.log_prob(actions), 
entropy, v.detach().view(-1).cpu().numpy(), probs.detach().cpu().numpy() 234 | 235 | def backprop(self, 236 | hps, 237 | obs, 238 | actions, 239 | old_logprobs, 240 | returns, 241 | value_loss_scale, 242 | advantages, 243 | old_values, 244 | action_masks, 245 | old_probs, 246 | privileged_obs, 247 | split_reward): 248 | if self.fp16: 249 | advantages = advantages.half() 250 | returns = returns.half() 251 | action_masks = action_masks.half() 252 | old_logprobs = old_logprobs.half() 253 | 254 | action_masks = action_masks[:, :self.agents, :] 255 | x, (pitems, pmask) = self.latents(obs, privileged_obs) 256 | batch_size = x.size()[0] 257 | 258 | vin = x.max(dim=1).values.view(batch_size, self.d_agent * self.dff_ratio) 259 | if self.use_privileged: 260 | pitems_max = pitems.max(dim=1).values 261 | pitems_avg = pitems.sum(dim=1) / torch.clamp_min((~pmask).float().sum(dim=1), min=1).unsqueeze(-1) 262 | vin = torch.cat([vin, pitems_max, pitems_avg], dim=1) 263 | values = self.value_head(vin).view(-1) 264 | 265 | logits = self.policy_head(x) 266 | probs = F.softmax(logits, dim=2) 267 | probs = probs.view(-1, self.agents, self.naction) 268 | 269 | # add small value to prevent degenerate probability distribution when no action is possible 270 | # gradients still get blocked by the action mask 271 | # TODO: mask actions by setting logits to -inf? 
272 | probs = probs * action_masks + self.epsilon 273 | 274 | active_agents = torch.clamp_min((action_masks.sum(dim=2) > 0).float().sum(dim=1), min=1) 275 | 276 | dist = distributions.Categorical(probs) 277 | entropy = dist.entropy() 278 | logprobs = dist.log_prob(actions) 279 | ratios = torch.exp(logprobs - old_logprobs) 280 | advantages = advantages.view(-1, 1) 281 | if split_reward: 282 | advantages = advantages / active_agents.view(-1, 1) 283 | vanilla_policy_loss = advantages * ratios 284 | clipped_policy_loss = advantages * torch.clamp(ratios, 1 - hps.cliprange, 1 + hps.cliprange) 285 | if hps.ppo: 286 | policy_loss = -torch.min(vanilla_policy_loss, clipped_policy_loss).mean() 287 | else: 288 | policy_loss = -vanilla_policy_loss.mean() 289 | 290 | # TODO: do over full distribution, not just selected actions? 291 | approxkl = 0.5 * (old_logprobs - logprobs).pow(2).mean() 292 | clipfrac = ((ratios - 1.0).abs() > hps.cliprange).sum().type(torch.float32) / ratios.numel() 293 | 294 | clipped_values = old_values + torch.clamp(values - old_values, -hps.cliprange, hps.cliprange) 295 | vanilla_value_loss = (values - returns) ** 2 296 | clipped_value_loss = (clipped_values - returns) ** 2 297 | if hps.clip_vf: 298 | value_loss = torch.max(vanilla_value_loss, clipped_value_loss).mean() 299 | else: 300 | value_loss = vanilla_value_loss.mean() 301 | 302 | entropy_loss = -hps.entropy_bonus * entropy.mean() 303 | 304 | loss = policy_loss + value_loss_scale * value_loss + entropy_loss 305 | loss /= hps.batches_per_update 306 | loss.backward() 307 | return policy_loss.data.tolist(), value_loss.data.tolist(), approxkl.data.tolist(), clipfrac.data.tolist() 308 | 309 | def forward(self, x, x_privileged): 310 | batch_size = x.size()[0] 311 | x, (pitems, pmask) = self.latents(x, x_privileged) 312 | 313 | vin = x.max(dim=1).values.view(batch_size, self.d_agent * self.dff_ratio) 314 | if self.use_privileged: 315 | pitems_max = pitems.max(dim=1).values 316 | pitems_avg = 
pitems.sum(dim=1) / torch.clamp_min((~pmask).float().sum(dim=1), min=1).unsqueeze(-1) 317 | vin = torch.cat([vin, pitems_max, pitems_avg], dim=1) 318 | values = self.value_head(vin).view(-1) 319 | 320 | logits = self.policy_head(x) 321 | probs = F.softmax(logits, dim=2) 322 | 323 | # return probs.view(batch_size, 8, self.allies).permute(0, 2, 1), values 324 | return probs, values 325 | 326 | def logits(self, x, x_privileged): 327 | x, x_privileged = self.latents(x, x_privileged) 328 | return self.policy_head(x) 329 | 330 | def latents(self, x, x_privileged): 331 | if self.fp16: 332 | # Normalization layers perform fp16 conversion for x after normalization 333 | x_privileged = x_privileged.half() 334 | 335 | batch_size = x.size()[0] 336 | 337 | endglobals = GLOBAL_FEATURES_V2 338 | endallies = GLOBAL_FEATURES_V2 + DSTRIDE_V2 * self.obs_config.allies 339 | endenemies = GLOBAL_FEATURES_V2 + DSTRIDE_V2 * self.obs_config.drones 340 | endmins = endenemies + MSTRIDE_V2 * self.obs_config.minerals 341 | endallenemies = endmins + DSTRIDE_V2 * (self.obs_config.drones - self.obs_config.allies) 342 | 343 | globals = x[:, :endglobals] 344 | 345 | # properties of the drone controlled by this network 346 | xagent = x[:, endglobals:endallies].view(batch_size, self.obs_config.allies, DSTRIDE_V2)[:, :self.agents, :] 347 | globals = globals.view(batch_size, 1, GLOBAL_FEATURES_V2) \ 348 | .expand(batch_size, self.agents, GLOBAL_FEATURES_V2) 349 | xagent = torch.cat([xagent, globals], dim=2) 350 | agents, _, mask_agent = self.agent_embedding(xagent) 351 | 352 | origin = xagent[:, :, 0:2].clone() 353 | direction = xagent[:, :, 2:4].clone() 354 | 355 | if self.ally_enemy_same: 356 | xdrone = x[:, endglobals:endenemies].view(batch_size, self.obs_config.drones, DSTRIDE_V2) 357 | items, relpos, mask = self.drone_net(xdrone, origin, direction) 358 | else: 359 | xally = x[:, endglobals:endallies].view(batch_size, self.obs_config.allies, DSTRIDE_V2) 360 | items, relpos, mask = 
self.ally_net(xally, origin, direction) 361 | # Ensure that at least one item is not masked out to prevent NaN in transformer softmax 362 | mask[:, :, 0] = 0 363 | 364 | if self.nenemy > 0 and not self.ally_enemy_same: 365 | eobs = self.obs_config.drones - self.obs_config.allies 366 | xe = x[:, endallies:endenemies].view(batch_size, eobs, DSTRIDE_V2) 367 | 368 | items_e, relpos_e, mask_e = self.enemy_net(xe, origin, direction) 369 | items = torch.cat([items, items_e], dim=2) 370 | mask = torch.cat([mask, mask_e], dim=2) 371 | relpos = torch.cat([relpos, relpos_e], dim=2) 372 | 373 | if self.nmineral > 0: 374 | xm = x[:, endenemies:endmins].view(batch_size, self.obs_config.minerals, MSTRIDE_V2) 375 | 376 | items_m, relpos_m, mask_m = self.mineral_net(xm, origin, direction) 377 | items = torch.cat([items, items_m], dim=2) 378 | mask = torch.cat([mask, mask_m], dim=2) 379 | relpos = torch.cat([relpos, relpos_m], dim=2) 380 | 381 | if self.use_privileged: 382 | # TODO: use hidden enemies 383 | xally = x[:, endglobals:endallies].view(batch_size, self.obs_config.allies, DSTRIDE_V2) 384 | eobs = self.obs_config.drones - self.obs_config.allies 385 | xenemy = x[:, endmins:endallenemies].view(batch_size, eobs, DSTRIDE_V2) 386 | if self.ally_enemy_same: 387 | xdrone = torch.cat([xally, xenemy], dim=1) 388 | pitems, _, pmask = self.pdrone_net(xdrone) 389 | else: 390 | pitems, _, pmask = self.pally_net(xally) 391 | pitems_e, _, pmask_e = self.penemy_net(xenemy) 392 | pitems = torch.cat([pitems, pitems_e], dim=1) 393 | pmask = torch.cat([pmask, pmask_e], dim=1) 394 | xm = x[:, endenemies:endmins].view(batch_size, self.obs_config.minerals, MSTRIDE_V2) 395 | pitems_m, _, pmask_m = self.pmineral_net(xm) 396 | pitems = torch.cat([pitems, pitems_m], dim=1) 397 | pmask = torch.cat([pmask, pmask_m], dim=1) 398 | else: 399 | pitems = None 400 | pmask = None 401 | 402 | # Transformer input dimensions are: Sequence length, Batch size, Embedding size 403 | source = items.view(batch_size * 
self.agents, self.nitem, self.d_item).permute(1, 0, 2) 404 | target = agents.view(1, batch_size * self.agents, self.d_agent) 405 | x, attn_weights = self.multihead_attention( 406 | query=target, 407 | key=source, 408 | value=source, 409 | key_padding_mask=mask.view(batch_size * self.agents, self.nitem), 410 | ) 411 | x = self.norm1(x + target) 412 | x2 = self.linear2(F.relu(self.linear1(x))) 413 | x = self.norm2(x + x2) 414 | x = x.view(batch_size, self.agents, self.d_agent) 415 | 416 | if self.nearby_map: 417 | items = self.norm_map(F.relu(self.downscale(items))) 418 | items = items * (1 - mask.float().unsqueeze(-1)) 419 | nearby_map = spatial.spatial_scatter( 420 | items=items, 421 | positions=relpos, 422 | nray=self.nrays, 423 | nring=self.nrings, 424 | inner_radius=self.ring_width, 425 | embed_offsets=self.map_embed_offset, 426 | ).view(batch_size * self.agents, self.map_channels, self.nrings, self.nrays) 427 | if self.map_conv: 428 | nearby_map2 = self.conv2(F.relu(self.conv1(nearby_map))) 429 | nearby_map2 = nearby_map2.permute(0, 3, 2, 1) 430 | nearby_map = nearby_map.permute(0, 3, 2, 1) 431 | nearby_map = self.norm_conv(nearby_map + nearby_map2) 432 | nearby_map = nearby_map.reshape(batch_size, self.agents, self.d_agent) 433 | x = torch.cat([x, nearby_map], dim=2) 434 | 435 | x = self.final_layer(x) 436 | x = x.view(batch_size, self.agents, self.d_agent * self.dff_ratio) 437 | x = x * (~mask_agent).float().unsqueeze(-1) 438 | 439 | return x, (pitems, pmask) 440 | 441 | def param_groups(self): 442 | # TODO? 443 | pass 444 | 445 | 446 | # Computes a running mean/variance of input features and performs normalization. 
447 | # https://www.johndcook.com/blog/standard_deviation/ 448 | class InputNorm(nn.Module): 449 | def __init__(self, num_features, cliprange=5): 450 | super(InputNorm, self).__init__() 451 | 452 | self.cliprange = cliprange 453 | self.register_buffer('count', torch.tensor(0)) 454 | self.register_buffer('mean', torch.zeros(num_features)) 455 | self.register_buffer('squares_sum', torch.zeros(num_features)) 456 | self.fp16 = False 457 | 458 | def update(self, input, mask): 459 | if mask is not None: 460 | if len(input.size()) == 4: 461 | batch_size, nally, nitem, features = input.size() 462 | assert (batch_size, nally, nitem) == mask.size() 463 | elif len(input.size()) == 3: 464 | batch_size, nally, features = input.size() 465 | assert (batch_size, nally) == mask.size() 466 | else: 467 | raise Exception(f'Expecting 3 or 4 dimensions, actual: {len(input.size())}') 468 | input = input[mask, :] 469 | else: 470 | features = input.size()[-1] 471 | input = input.reshape(-1, features) 472 | 473 | count = input.numel() // features # integer division: `count` is an integer buffer 474 | if count == 0: 475 | return 476 | mean = input.mean(dim=0) 477 | if self.count == 0: 478 | self.count += count 479 | self.mean = mean 480 | self.squares_sum = ((input - mean) * (input - mean)).sum(dim=0) 481 | else: 482 | self.count += count 483 | new_mean = self.mean + (mean - self.mean) * count / self.count 484 | # Batched update: summing (x - old_mean) * (x - new_mean) over the batch is exactly Chan et al.'s pairwise variance merge. 
485 | self.squares_sum = self.squares_sum + ((input - self.mean) * (input - new_mean)).sum(dim=0) 486 | self.mean = new_mean 487 | 488 | def forward(self, input, mask=None): 489 | with torch.no_grad(): 490 | if self.training: 491 | self.update(input, mask=mask) 492 | if self.count > 1: 493 | input = (input - self.mean) / self.stddev() 494 | input = torch.clamp(input, -self.cliprange, self.cliprange) 495 | if (input == float('-inf')).sum() > 0 \ 496 | or (input == float('inf')).sum() > 0 \ 497 | or (input != input).sum() > 0: 498 | print(input) 499 | print(self.squares_sum) 500 | print(self.stddev()) 501 | print(input) 502 | raise Exception("OVER/UNDERFLOW DETECTED!") 503 | 504 | return input.half() if self.fp16 else input 505 | 506 | def enable_fp16(self): 507 | # Convert buffers back to fp32, fp16 has insufficient precision and runs into overflow on squares_sum 508 | self.float() 509 | self.fp16 = True 510 | 511 | def stddev(self): 512 | sd = torch.sqrt(self.squares_sum / (self.count - 1)) 513 | sd[sd == 0] = 1 514 | return sd 515 | 516 | 517 | class InputEmbedding(nn.Module): 518 | def __init__(self, d_in, d_model, norm_fn): 519 | super(InputEmbedding, self).__init__() 520 | 521 | self.normalize = InputNorm(d_in) 522 | self.linear = nn.Linear(d_in, d_model) 523 | self.norm = norm_fn(d_model) 524 | 525 | def forward(self, x, mask=None): 526 | x = self.normalize(x, mask) 527 | x = F.relu(self.linear(x)) 528 | x = self.norm(x) 529 | return x 530 | 531 | 532 | class FFResblock(nn.Module): 533 | def __init__(self, d_model, d_ff, norm_fn): 534 | super(FFResblock, self).__init__() 535 | 536 | self.linear_1 = nn.Linear(d_model, d_ff) 537 | self.linear_2 = nn.Linear(d_ff, d_model) 538 | self.norm = norm_fn(d_model) 539 | 540 | # self.linear_2.weight.data.fill_(0.0) 541 | # self.linear_2.bias.data.fill_(0.0) 542 | 543 | def forward(self, x, mask=None): 544 | x2 = F.relu(self.linear_1(x)) 545 | x = x + F.relu(self.linear_2(x2)) 546 | x = self.norm(x) 547 | return x 548 | 
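`InputNorm.update` above folds a whole batch into the running statistics in one step, using the mean from before the batch and the mean from after it for every sample. Expanding that sum gives the batch's own sum of squared deviations plus a correction term proportional to the squared gap between the batch mean and the old mean, which is the standard pairwise merge of two variance accumulators. The sketch below checks this numerically for a single scalar feature in plain Python; the helper names `batched_update` and `two_pass` are illustrative and not part of the repository.

```python
import random


def batched_update(count, mean, squares_sum, batch):
    """One batched step in the style of InputNorm.update, for a scalar feature."""
    m = len(batch)
    batch_mean = sum(batch) / m
    new_count = count + m
    new_mean = mean + (batch_mean - mean) * m / new_count
    # Every sample is paired with the mean before and after the whole batch.
    squares_sum += sum((x - mean) * (x - new_mean) for x in batch)
    return new_count, new_mean, squares_sum


def two_pass(values):
    """Reference: mean and sum of squared deviations computed directly."""
    mean = sum(values) / len(values)
    return mean, sum((x - mean) ** 2 for x in values)


random.seed(0)
batches = [[random.gauss(3.0, 2.0) for _ in range(n)] for n in (5, 1, 17, 8)]

count, mean, squares_sum = 0, 0.0, 0.0
for batch in batches:
    count, mean, squares_sum = batched_update(count, mean, squares_sum, batch)

ref_mean, ref_squares_sum = two_pass([x for batch in batches for x in batch])
print(abs(mean - ref_mean), abs(squares_sum - ref_squares_sum))
```

Starting from `count = 0` and `mean = 0.0`, the general formula also reproduces the special first-batch branch of `update`, so the sketch needs no separate case for it.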
549 | 550 | class ItemBlock(nn.Module): 551 | def __init__(self, d_in, d_model, d_ff, norm_fn, resblock, keep_abspos, mask_feature, relpos=True, topk=None): 552 | super(ItemBlock, self).__init__() 553 | 554 | if relpos: 555 | if keep_abspos: 556 | d_in += 3 557 | else: 558 | d_in += 1 559 | self.embedding = InputEmbedding(d_in, d_model, norm_fn) 560 | self.mask_feature = mask_feature 561 | self.keep_abspos = keep_abspos 562 | self.topk = topk 563 | # Always define self.resblock so that forward() can test it against None 564 | self.resblock = FFResblock(d_model, d_ff, norm_fn) if resblock else None 565 | 566 | def forward(self, x, origin=None, direction=None): 567 | batch_size, items, features = x.size() 568 | 569 | if origin is not None: 570 | _, agents, _ = origin.size() 571 | 572 | pos = x[:, :, 0:2] 573 | relpos = spatial.relative_positions(origin, direction, pos) 574 | dist = relpos.norm(p=2, dim=3) 575 | direction = relpos / (dist.unsqueeze(-1) + 1e-8) 576 | 577 | x = x.view(batch_size, 1, items, features)\ 578 | .expand(batch_size, agents, items, features) 579 | if self.keep_abspos: 580 | x = torch.cat([x, direction, torch.sqrt(dist.unsqueeze(-1))], dim=3) 581 | else: 582 | x = torch.cat([direction, x[:, :, :, 2:], torch.sqrt(dist.unsqueeze(-1))], dim=3) 583 | 584 | if self.topk is not None: 585 | empty = (x[:, :, :, self.mask_feature] == 0).float() 586 | key = -dist - empty * 1e8 587 | x = topk_by(values=x, vdim=2, keys=key, kdim=2, k=self.topk) 588 | relpos = topk_by(values=relpos, vdim=2, keys=key, kdim=2, k=self.topk) 589 | 590 | mask = x[:, :, :, self.mask_feature] == 0 591 | else: 592 | relpos = None 593 | mask = x[:, :, self.mask_feature] == 0 594 | 595 | x = self.embedding(x, ~mask) 596 | if self.resblock is not None: 597 | x = self.resblock(x) 598 | x = x * (~mask).unsqueeze(-1).float() 599 | 600 | return x, relpos, mask 601 | 602 | -------------------------------------------------------------------------------- /policy_t3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.distributions as distributions 5 | 6 | from gather import topk_by 7 | from multihead_attention import MultiheadAttention 8 | import spatial 9 | 10 | 11 | class TransformerPolicy3(nn.Module): 12 | def __init__(self, hps, obs_config): 13 | super(TransformerPolicy3, self).__init__() 14 | assert obs_config.drones > 0 or obs_config.minerals > 0,\ 15 | 'Must have at least one mineral or drones observation' 16 | assert obs_config.drones >= obs_config.allies 17 | assert not hps.use_privileged or (hps.nmineral > 0 and hps.nally > 0 and (hps.nenemy > 0 or hps.ally_enemy_same)) 18 | 19 | self.version = 'transformer_v3' 20 | 21 | self.kwargs = dict( 22 | hps=hps, 23 | obs_config=obs_config 24 | ) 25 | 26 | self.hps = hps 27 | self.obs_config = obs_config 28 | self.agents = hps.agents 29 | self.nally = hps.nally 30 | self.nenemy = hps.nenemy 31 | self.nmineral = hps.nmineral 32 | self.nitem = hps.nally + hps.nenemy + hps.nmineral 33 | self.fp16 = hps.fp16 34 | self.d_agent = hps.d_agent 35 | self.d_item = hps.d_item 36 | self.naction = hps.objective.naction() 37 | 38 | if hasattr(obs_config, 'global_drones'): 39 | self.global_drones = obs_config.global_drones 40 | else: 41 | self.global_drones = 0 42 | 43 | if hps.norm == 'none': 44 | norm_fn = lambda x: nn.Sequential() 45 | elif hps.norm == 'batchnorm': 46 | norm_fn = lambda n: nn.BatchNorm2d(n) 47 | elif hps.norm == 'layernorm': 48 | norm_fn = lambda n: nn.LayerNorm(n) 49 | else: 50 | raise Exception(f'Unexpected normalization layer {hps.norm}') 51 | 52 | self.agent_embedding = ItemBlock( 53 | obs_config.dstride() + obs_config.global_features(), 54 | hps.d_agent, hps.d_agent * hps.dff_ratio, norm_fn, True, 55 | keep_abspos=True, 56 | mask_feature=7, # Feature 7 is hitpoints 57 | relpos=False, 58 | ) 59 | if hps.ally_enemy_same: 60 | self.drone_net = ItemBlock( 61 | obs_config.dstride(), 62 | hps.d_item, hps.d_item * hps.dff_ratio, norm_fn, hps.item_ff, 
63 | keep_abspos=hps.obs_keep_abspos, 64 | mask_feature=7, # Feature 7 is hitpoints 65 | topk=hps.nally+hps.nenemy, 66 | ) 67 | else: 68 | self.ally_net = ItemBlock( 69 | obs_config.dstride(), hps.d_item, hps.d_item * hps.dff_ratio, norm_fn, hps.item_ff, 70 | keep_abspos=hps.obs_keep_abspos, 71 | mask_feature=7, # Feature 7 is hitpoints 72 | topk=hps.nally, 73 | ) 74 | self.enemy_net = ItemBlock( 75 | obs_config.dstride(), hps.d_item, hps.d_item * hps.dff_ratio, norm_fn, hps.item_ff, 76 | keep_abspos=hps.obs_keep_abspos, 77 | mask_feature=7, # Feature 7 is hitpoints 78 | topk=hps.nenemy, 79 | ) 80 | self.mineral_net = ItemBlock( 81 | obs_config.mstride(), hps.d_item, hps.d_item * hps.dff_ratio, norm_fn, hps.item_ff, 82 | keep_abspos=hps.obs_keep_abspos, 83 | mask_feature=2, # Feature 2 is size 84 | topk=hps.nmineral, 85 | ) 86 | 87 | if hps.use_privileged: 88 | self.pmineral_net = ItemBlock( 89 | obs_config.mstride(), hps.d_item, hps.d_item * hps.dff_ratio, norm_fn, hps.item_ff, 90 | keep_abspos=True, relpos=False, mask_feature=2, 91 | ) 92 | if hps.ally_enemy_same: 93 | self.pdrone_net = ItemBlock( 94 | obs_config.dstride(), hps.d_item, hps.d_item * hps.dff_ratio, norm_fn, hps.item_ff, 95 | keep_abspos=True, relpos=False, mask_feature=7, 96 | ) 97 | else: 98 | self.pally_net = ItemBlock( 99 | obs_config.dstride(), hps.d_item, hps.d_item * hps.dff_ratio, norm_fn, hps.item_ff, 100 | keep_abspos=True, relpos=False, mask_feature=7, 101 | ) 102 | self.penemy_net = ItemBlock( 103 | obs_config.dstride(), hps.d_item, hps.d_item * hps.dff_ratio, norm_fn, hps.item_ff, 104 | keep_abspos=True, relpos=False, mask_feature=7, 105 | ) 106 | 107 | if hps.item_item_attn_layers > 0: 108 | encoder_layer = nn.TransformerEncoderLayer(d_model=hps.d_item, nhead=8) 109 | self.item_item_attn = nn.TransformerEncoder(encoder_layer, num_layers=hps.item_item_attn_layers) 110 | else: 111 | self.item_item_attn = None 112 | 113 | self.multihead_attention = MultiheadAttention( 114 | 
embed_dim=hps.d_agent, 115 | kdim=hps.d_item, 116 | vdim=hps.d_item, 117 | num_heads=hps.nhead, 118 | dropout=hps.dropout, 119 | ) 120 | self.linear1 = nn.Linear(hps.d_agent, hps.d_agent * hps.dff_ratio) 121 | self.linear2 = nn.Linear(hps.d_agent * hps.dff_ratio, hps.d_agent) 122 | self.norm1 = nn.LayerNorm(hps.d_agent) 123 | self.norm2 = nn.LayerNorm(hps.d_agent) 124 | 125 | self.map_channels = hps.d_agent // (hps.nm_nrings * hps.nm_nrays) 126 | map_item_channels = self.map_channels - 2 if self.hps.map_embed_offset else self.map_channels 127 | self.downscale = nn.Linear(hps.d_item, map_item_channels) 128 | self.norm_map = norm_fn(map_item_channels) 129 | self.conv1 = spatial.ZeroPaddedCylindricalConv2d( 130 | self.map_channels, hps.dff_ratio * self.map_channels, kernel_size=3) 131 | self.conv2 = spatial.ZeroPaddedCylindricalConv2d( 132 | hps.dff_ratio * self.map_channels, self.map_channels, kernel_size=3) 133 | self.norm_conv = norm_fn(self.map_channels) 134 | 135 | final_width = hps.d_agent 136 | if hps.nearby_map: 137 | final_width += hps.d_agent 138 | self.final_layer = nn.Sequential( 139 | nn.Linear(final_width, hps.d_agent * hps.dff_ratio), 140 | nn.ReLU(), 141 | ) 142 | 143 | self.policy_head = nn.Linear(hps.d_agent * hps.dff_ratio, self.naction) 144 | if hps.small_init_pi: 145 | self.policy_head.weight.data *= 0.01 146 | self.policy_head.bias.data.fill_(0.0) 147 | 148 | if hps.use_privileged: 149 | self.value_head = nn.Linear(hps.d_agent * hps.dff_ratio + 2 * hps.d_item, 1) 150 | else: 151 | self.value_head = nn.Linear(hps.d_agent * hps.dff_ratio, 1) 152 | if hps.zero_init_vf: 153 | self.value_head.weight.data.fill_(0.0) 154 | self.value_head.bias.data.fill_(0.0) 155 | 156 | self.epsilon = 1e-4 if hps.fp16 else 1e-8 157 | 158 | def evaluate(self, observation, action_masks, privileged_obs): 159 | if self.fp16: 160 | action_masks = action_masks.half() 161 | action_masks = action_masks[:, :self.agents, :] 162 | probs, v = self.forward(observation, 
privileged_obs) 163 | probs = probs.view(-1, self.agents, self.naction) 164 | probs = probs * action_masks + self.epsilon # Add small value to prevent crash when no action is possible 165 | # We get device-side assert when using fp16 here (needs more investigation) 166 | action_dist = distributions.Categorical(probs.float() if self.fp16 else probs) 167 | actions = action_dist.sample() 168 | entropy = action_dist.entropy()[action_masks.sum(2) != 0] 169 | return actions, action_dist.log_prob(actions), entropy, v.detach().view(-1).cpu().numpy(), probs.detach().cpu().numpy() 170 | 171 | def backprop(self, 172 | hps, 173 | obs, 174 | actions, 175 | old_logprobs, 176 | returns, 177 | value_loss_scale, 178 | advantages, 179 | old_values, 180 | action_masks, 181 | old_probs, 182 | privileged_obs, 183 | split_reward): 184 | if self.fp16: 185 | advantages = advantages.half() 186 | returns = returns.half() 187 | action_masks = action_masks.half() 188 | old_logprobs = old_logprobs.half() 189 | 190 | action_masks = action_masks[:, :self.agents, :] 191 | x, (pitems, pmask) = self.latents(obs, privileged_obs) 192 | batch_size = x.size()[0] 193 | 194 | vin = x.max(dim=1).values.view(batch_size, self.d_agent * self.hps.dff_ratio) 195 | if self.hps.use_privileged: 196 | pitems_max = pitems.max(dim=1).values 197 | pitems_avg = pitems.sum(dim=1) / torch.clamp_min((~pmask).float().sum(dim=1), min=1).unsqueeze(-1) 198 | vin = torch.cat([vin, pitems_max, pitems_avg], dim=1) 199 | values = self.value_head(vin).view(-1) 200 | 201 | logits = self.policy_head(x) 202 | probs = F.softmax(logits, dim=2) 203 | probs = probs.view(-1, self.agents, self.naction) 204 | 205 | # add small value to prevent degenerate probability distribution when no action is possible 206 | # gradients still get blocked by the action mask 207 | # TODO: mask actions by setting logits to -inf? 
208 | probs = probs * action_masks + self.epsilon 209 | 210 | active_agents = torch.clamp_min((action_masks.sum(dim=2) > 0).float().sum(dim=1), min=1) 211 | 212 | dist = distributions.Categorical(probs) 213 | entropy = dist.entropy() 214 | logprobs = dist.log_prob(actions) 215 | ratios = torch.exp(logprobs - old_logprobs) 216 | advantages = advantages.view(-1, 1) 217 | if split_reward: 218 | advantages = advantages / active_agents.view(-1, 1) 219 | vanilla_policy_loss = advantages * ratios 220 | clipped_policy_loss = advantages * torch.clamp(ratios, 1 - hps.cliprange, 1 + hps.cliprange) 221 | if hps.ppo: 222 | policy_loss = -torch.min(vanilla_policy_loss, clipped_policy_loss).mean() 223 | else: 224 | policy_loss = -vanilla_policy_loss.mean() 225 | 226 | # TODO: do over full distribution, not just selected actions? 227 | approxkl = 0.5 * (old_logprobs - logprobs).pow(2).mean() 228 | clipfrac = ((ratios - 1.0).abs() > hps.cliprange).sum().type(torch.float32) / ratios.numel() 229 | 230 | clipped_values = old_values + torch.clamp(values - old_values, -hps.cliprange, hps.cliprange) 231 | vanilla_value_loss = (values - returns) ** 2 232 | clipped_value_loss = (clipped_values - returns) ** 2 233 | if hps.clip_vf: 234 | value_loss = torch.max(vanilla_value_loss, clipped_value_loss).mean() 235 | else: 236 | value_loss = vanilla_value_loss.mean() 237 | 238 | entropy_loss = -hps.entropy_bonus * entropy.mean() 239 | 240 | loss = policy_loss + value_loss_scale * value_loss + entropy_loss 241 | loss /= hps.batches_per_update 242 | loss.backward() 243 | return policy_loss.data.tolist(), value_loss.data.tolist(), approxkl.data.tolist(), clipfrac.data.tolist() 244 | 245 | def forward(self, x, x_privileged): 246 | batch_size = x.size()[0] 247 | x, (pitems, pmask) = self.latents(x, x_privileged) 248 | 249 | vin = x.max(dim=1).values.view(batch_size, self.d_agent * self.hps.dff_ratio) 250 | if self.hps.use_privileged: 251 | pitems_max = pitems.max(dim=1).values 252 | pitems_avg = 
pitems.sum(dim=1) / torch.clamp_min((~pmask).float().sum(dim=1), min=1).unsqueeze(-1) 253 | vin = torch.cat([vin, pitems_max, pitems_avg], dim=1) 254 | values = self.value_head(vin).view(-1) 255 | 256 | logits = self.policy_head(x) 257 | probs = F.softmax(logits, dim=2) 258 | 259 | # return probs.view(batch_size, 8, self.allies).permute(0, 2, 1), values 260 | return probs, values 261 | 262 | def logits(self, x, x_privileged): 263 | x, x_privileged = self.latents(x, x_privileged) 264 | return self.policy_head(x) 265 | 266 | def latents(self, x, x_privileged): 267 | if self.fp16: 268 | # Normalization layers perform fp16 conversion for x after normalization 269 | x_privileged = x_privileged.half() 270 | 271 | batch_size = x.size()[0] 272 | 273 | endglobals = self.obs_config.endglobals() 274 | endallies = self.obs_config.endallies() 275 | endenemies = self.obs_config.endenemies() 276 | endmins = self.obs_config.endmins() 277 | endallenemies = self.obs_config.endallenemies() 278 | 279 | globals = x[:, :endglobals] 280 | 281 | # properties of the drone controlled by this network 282 | xagent = x[:, endglobals:endallies]\ 283 | .view(batch_size, self.obs_config.allies, self.obs_config.dstride())[:, :self.agents, :] 284 | globals = globals.view(batch_size, 1, self.obs_config.global_features()) \ 285 | .expand(batch_size, self.agents, self.obs_config.global_features()) 286 | xagent = torch.cat([xagent, globals], dim=2) 287 | agents, _, mask_agent = self.agent_embedding(xagent) 288 | 289 | origin = xagent[:, :, 0:2].clone() 290 | direction = xagent[:, :, 2:4].clone() 291 | 292 | if self.hps.ally_enemy_same: 293 | xdrone = x[:, endglobals:endenemies].view(batch_size, self.obs_config.drones, self.obs_config.dstride()) 294 | items, relpos, mask = self.drone_net(xdrone, origin, direction) 295 | else: 296 | xally = x[:, endglobals:endallies].view(batch_size, self.obs_config.allies, self.obs_config.dstride()) 297 | items, relpos, mask = self.ally_net(xally, origin, direction) 298 
| # Ensure that at least one item is not masked out to prevent NaN in transformer softmax 299 | mask[:, :, 0] = 0 300 | 301 | if self.nenemy > 0 and not self.hps.ally_enemy_same: 302 | eobs = self.obs_config.drones - self.obs_config.allies 303 | xe = x[:, endallies:endenemies].view(batch_size, eobs, self.obs_config.dstride()) 304 | 305 | items_e, relpos_e, mask_e = self.enemy_net(xe, origin, direction) 306 | items = torch.cat([items, items_e], dim=2) 307 | mask = torch.cat([mask, mask_e], dim=2) 308 | relpos = torch.cat([relpos, relpos_e], dim=2) 309 | 310 | if self.nmineral > 0: 311 | xm = x[:, endenemies:endmins].view(batch_size, self.obs_config.minerals, self.obs_config.mstride()) 312 | 313 | items_m, relpos_m, mask_m = self.mineral_net(xm, origin, direction) 314 | items = torch.cat([items, items_m], dim=2) 315 | mask = torch.cat([mask, mask_m], dim=2) 316 | relpos = torch.cat([relpos, relpos_m], dim=2) 317 | 318 | if self.hps.use_privileged: 319 | xally = x[:, endglobals:endallies].view(batch_size, self.obs_config.allies, self.obs_config.dstride()) 320 | eobs = self.obs_config.drones - self.obs_config.allies 321 | xenemy = x[:, endmins:endallenemies].view(batch_size, eobs, self.obs_config.dstride()) 322 | if self.hps.ally_enemy_same: 323 | xdrone = torch.cat([xally, xenemy], dim=1) 324 | pitems, _, pmask = self.pdrone_net(xdrone) 325 | else: 326 | pitems, _, pmask = self.pally_net(xally) 327 | pitems_e, _, pmask_e = self.penemy_net(xenemy) 328 | pitems = torch.cat([pitems, pitems_e], dim=1) 329 | pmask = torch.cat([pmask, pmask_e], dim=1) 330 | xm = x[:, endenemies:endmins].view(batch_size, self.obs_config.minerals, self.obs_config.mstride()) 331 | pitems_m, _, pmask_m = self.pmineral_net(xm) 332 | pitems = torch.cat([pitems, pitems_m], dim=1) 333 | pmask = torch.cat([pmask, pmask_m], dim=1) 334 | if self.item_item_attn: 335 | pmask_nonzero = pmask.clone() 336 | pmask_nonzero[:, 0] = False 337 | pitems = self.item_item_attn( 338 | pitems.permute(1, 0, 2), 339 | 
src_key_padding_mask=pmask_nonzero, 340 | ).permute(1, 0, 2) 341 | if (pitems != pitems).sum() > 0: 342 | print(pmask) 343 | print(pitems) 344 | raise Exception("NaN!") 345 | else: 346 | pitems = None 347 | pmask = None 348 | 349 | # Transformer input dimensions are: Sequence length, Batch size, Embedding size 350 | source = items.view(batch_size * self.agents, self.nitem, self.d_item).permute(1, 0, 2) 351 | target = agents.view(1, batch_size * self.agents, self.d_agent) 352 | x, attn_weights = self.multihead_attention( 353 | query=target, 354 | key=source, 355 | value=source, 356 | key_padding_mask=mask.view(batch_size * self.agents, self.nitem), 357 | ) 358 | x = self.norm1(x + target) 359 | x2 = self.linear2(F.relu(self.linear1(x))) 360 | x = self.norm2(x + x2) 361 | x = x.view(batch_size, self.agents, self.d_agent) 362 | 363 | if self.hps.nearby_map: 364 | items = self.norm_map(F.relu(self.downscale(items))) 365 | items = items * (1 - mask.float().unsqueeze(-1)) 366 | nearby_map = spatial.spatial_scatter( 367 | items=items, 368 | positions=relpos, 369 | nray=self.hps.nm_nrays, 370 | nring=self.hps.nm_nrings, 371 | inner_radius=self.hps.nm_ring_width, 372 | embed_offsets=self.hps.map_embed_offset, 373 | ).view(batch_size * self.agents, self.map_channels, self.hps.nm_nrings, self.hps.nm_nrays) 374 | if self.hps.map_conv: 375 | nearby_map2 = self.conv2(F.relu(self.conv1(nearby_map))) 376 | nearby_map2 = nearby_map2.permute(0, 3, 2, 1) 377 | nearby_map = nearby_map.permute(0, 3, 2, 1) 378 | nearby_map = self.norm_conv(nearby_map + nearby_map2) 379 | nearby_map = nearby_map.reshape(batch_size, self.agents, self.d_agent) 380 | x = torch.cat([x, nearby_map], dim=2) 381 | 382 | x = self.final_layer(x) 383 | x = x.view(batch_size, self.agents, self.d_agent * self.hps.dff_ratio) 384 | x = x * (~mask_agent).float().unsqueeze(-1) 385 | 386 | return x, (pitems, pmask) 387 | 388 | def param_groups(self): 389 | # TODO? 
390 | pass 391 | 392 | 393 | # Computes a running mean/variance of input features and performs normalization. 394 | # https://www.johndcook.com/blog/standard_deviation/ 395 | class InputNorm(nn.Module): 396 | def __init__(self, num_features, cliprange=5): 397 | super(InputNorm, self).__init__() 398 | 399 | self.cliprange = cliprange 400 | self.register_buffer('count', torch.tensor(0)) 401 | self.register_buffer('mean', torch.zeros(num_features)) 402 | self.register_buffer('squares_sum', torch.zeros(num_features)) 403 | self.fp16 = False 404 | 405 | def update(self, input, mask): 406 | if mask is not None: 407 | if len(input.size()) == 4: 408 | batch_size, nally, nitem, features = input.size() 409 | assert (batch_size, nally, nitem) == mask.size() 410 | elif len(input.size()) == 3: 411 | batch_size, nally, features = input.size() 412 | assert (batch_size, nally) == mask.size() 413 | else: 414 | raise Exception(f'Expecting 3 or 4 dimensions, actual: {len(input.size())}') 415 | input = input[mask, :] 416 | else: 417 | features = input.size()[-1] 418 | input = input.reshape(-1, features) 419 | 420 | count = input.numel() // features # integer division: `count` is an integer buffer 421 | if count == 0: 422 | return 423 | mean = input.mean(dim=0) 424 | if self.count == 0: 425 | self.count += count 426 | self.mean = mean 427 | self.squares_sum = ((input - mean) * (input - mean)).sum(dim=0) 428 | else: 429 | self.count += count 430 | new_mean = self.mean + (mean - self.mean) * count / self.count 431 | # Batched update: summing (x - old_mean) * (x - new_mean) over the batch is exactly Chan et al.'s pairwise variance merge. 
432 | self.squares_sum = self.squares_sum + ((input - self.mean) * (input - new_mean)).sum(dim=0) 433 | self.mean = new_mean 434 | 435 | def forward(self, input, mask=None): 436 | with torch.no_grad(): 437 | if self.training: 438 | self.update(input, mask=mask) 439 | if self.count > 1: 440 | input = (input - self.mean) / self.stddev() 441 | input = torch.clamp(input, -self.cliprange, self.cliprange) 442 | if (input == float('-inf')).sum() > 0 \ 443 | or (input == float('inf')).sum() > 0 \ 444 | or (input != input).sum() > 0: 445 | print(input) 446 | print(self.squares_sum) 447 | print(self.stddev()) 448 | print(input) 449 | raise Exception("OVER/UNDERFLOW DETECTED!") 450 | 451 | return input.half() if self.fp16 else input 452 | 453 | def enable_fp16(self): 454 | # Convert buffers back to fp32, fp16 has insufficient precision and runs into overflow on squares_sum 455 | self.float() 456 | self.fp16 = True 457 | 458 | def stddev(self): 459 | sd = torch.sqrt(self.squares_sum / (self.count - 1)) 460 | sd[sd == 0] = 1 461 | return sd 462 | 463 | 464 | class InputEmbedding(nn.Module): 465 | def __init__(self, d_in, d_model, norm_fn): 466 | super(InputEmbedding, self).__init__() 467 | 468 | self.normalize = InputNorm(d_in) 469 | self.linear = nn.Linear(d_in, d_model) 470 | self.norm = norm_fn(d_model) 471 | 472 | def forward(self, x, mask=None): 473 | x = self.normalize(x, mask) 474 | x = F.relu(self.linear(x)) 475 | x = self.norm(x) 476 | return x 477 | 478 | 479 | class FFResblock(nn.Module): 480 | def __init__(self, d_model, d_ff, norm_fn): 481 | super(FFResblock, self).__init__() 482 | 483 | self.linear_1 = nn.Linear(d_model, d_ff) 484 | self.linear_2 = nn.Linear(d_ff, d_model) 485 | self.norm = norm_fn(d_model) 486 | 487 | # self.linear_2.weight.data.fill_(0.0) 488 | # self.linear_2.bias.data.fill_(0.0) 489 | 490 | def forward(self, x, mask=None): 491 | x2 = F.relu(self.linear_1(x)) 492 | x = x + F.relu(self.linear_2(x2)) 493 | x = self.norm(x) 494 | return x 495 | 
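The policy loss in `TransformerPolicy3.backprop` earlier in this file multiplies each advantage by the probability ratio and by a clamped copy of that ratio, then keeps the smaller of the two products. The sketch below restates that rule in plain Python for a flat list of sampled actions so the effect of the clip can be checked by hand. It is a simplified illustration, not the repository's code: the function name `ppo_policy_loss` is made up here, and the per-agent broadcasting and `split_reward` scaling of the original are left out.

```python
import math


def ppo_policy_loss(logprobs, old_logprobs, advantages, cliprange):
    """Mean clipped surrogate loss over a flat batch of sampled actions."""
    losses = []
    for lp, old_lp, adv in zip(logprobs, old_logprobs, advantages):
        ratio = math.exp(lp - old_lp)
        clipped_ratio = min(max(ratio, 1 - cliprange), 1 + cliprange)
        # Pessimistic bound: keep the smaller of the two surrogate objectives.
        losses.append(-min(adv * ratio, adv * clipped_ratio))
    return sum(losses) / len(losses)


# Ratio 1.5 with a positive advantage: the gain is capped at 1 + cliprange.
capped = ppo_policy_loss([math.log(1.5)], [0.0], [1.0], cliprange=0.2)

# Ratio 1.5 with a negative advantage: the unclipped penalty is kept.
uncapped = ppo_policy_loss([math.log(1.5)], [0.0], [-1.0], cliprange=0.2)

print(capped, uncapped)
```

The asymmetry is the point of the `min`: clipping removes the incentive to push the ratio further in a direction that already helps, but it never hides a move that makes the objective worse.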
496 | 497 | class ItemBlock(nn.Module): 498 | def __init__(self, d_in, d_model, d_ff, norm_fn, resblock, keep_abspos, mask_feature, relpos=True, topk=None): 499 | super(ItemBlock, self).__init__() 500 | 501 | if relpos: 502 | if keep_abspos: 503 | d_in += 3 504 | else: 505 | d_in += 1 506 | self.embedding = InputEmbedding(d_in, d_model, norm_fn) 507 | self.mask_feature = mask_feature 508 | self.keep_abspos = keep_abspos 509 | self.topk = topk 510 | # Always define self.resblock so that forward() can test it against None 511 | self.resblock = FFResblock(d_model, d_ff, norm_fn) if resblock else None 512 | 513 | def forward(self, x, origin=None, direction=None): 514 | batch_size, items, features = x.size() 515 | 516 | if origin is not None: 517 | _, agents, _ = origin.size() 518 | 519 | pos = x[:, :, 0:2] 520 | relpos = spatial.relative_positions(origin, direction, pos) 521 | dist = relpos.norm(p=2, dim=3) 522 | direction = relpos / (dist.unsqueeze(-1) + 1e-8) 523 | 524 | x = x.view(batch_size, 1, items, features)\ 525 | .expand(batch_size, agents, items, features) 526 | if self.keep_abspos: 527 | x = torch.cat([x, direction, torch.sqrt(dist.unsqueeze(-1))], dim=3) 528 | else: 529 | x = torch.cat([direction, x[:, :, :, 2:], torch.sqrt(dist.unsqueeze(-1))], dim=3) 530 | 531 | if self.topk is not None: 532 | empty = (x[:, :, :, self.mask_feature] == 0).float() 533 | key = -dist - empty * 1e8 534 | x = topk_by(values=x, vdim=2, keys=key, kdim=2, k=self.topk) 535 | relpos = topk_by(values=relpos, vdim=2, keys=key, kdim=2, k=self.topk) 536 | 537 | mask = x[:, :, :, self.mask_feature] == 0 538 | else: 539 | relpos = None 540 | mask = x[:, :, self.mask_feature] == 0 541 | 542 | x = self.embedding(x, ~mask) 543 | if self.resblock is not None: 544 | x = self.resblock(x) 545 | x = x * (~mask).unsqueeze(-1).float() 546 | 547 | return x, relpos, mask 548 | 549 | -------------------------------------------------------------------------------- /policy_t4.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.distributions as distributions 5 | 6 | from gather import topk_by 7 | from multihead_attention import MultiheadAttention 8 | import spatial 9 | 10 | 11 | class TransformerPolicy4(nn.Module): 12 | def __init__(self, hps, obs_config): 13 | super(TransformerPolicy4, self).__init__() 14 | assert obs_config.drones > 0 or obs_config.minerals > 0,\ 15 | 'Must have at least one mineral or drones observation' 16 | assert obs_config.drones >= obs_config.allies 17 | assert not hps.use_privileged or (hps.nmineral > 0 and hps.nally > 0 and (hps.nenemy > 0 or hps.ally_enemy_same)) 18 | 19 | self.version = 'transformer_v4' 20 | 21 | self.kwargs = dict( 22 | hps=hps, 23 | obs_config=obs_config 24 | ) 25 | 26 | self.hps = hps 27 | self.obs_config = obs_config 28 | self.agents = hps.agents 29 | self.nally = hps.nally 30 | self.nenemy = hps.nenemy 31 | self.nmineral = hps.nmineral 32 | self.nconstant = hps.nconstant 33 | self.nitem = hps.nally + hps.nenemy + hps.nmineral + hps.nconstant 34 | self.fp16 = hps.fp16 35 | self.d_agent = hps.d_agent 36 | self.d_item = hps.d_item 37 | self.naction = hps.objective.naction() 38 | 39 | if hasattr(obs_config, 'global_drones'): 40 | self.global_drones = obs_config.global_drones 41 | else: 42 | self.global_drones = 0 43 | 44 | if hps.norm == 'none': 45 | norm_fn = lambda x: nn.Sequential() 46 | elif hps.norm == 'batchnorm': 47 | norm_fn = lambda n: nn.BatchNorm2d(n) 48 | elif hps.norm == 'layernorm': 49 | norm_fn = lambda n: nn.LayerNorm(n) 50 | else: 51 | raise Exception(f'Unexpected normalization layer {hps.norm}') 52 | 53 | self.agent_embedding = ItemBlock( 54 | obs_config.dstride() + obs_config.global_features(), 55 | hps.d_agent, hps.d_agent * hps.dff_ratio, norm_fn, True, 56 | keep_abspos=True, 57 | mask_feature=7, # Feature 7 is hitpoints 58 | relpos=False, 59 | ) 60 | if hps.ally_enemy_same: 61 | self.drone_net = ItemBlock( 62 | obs_config.dstride(), 63 | 
hps.d_item, hps.d_item * hps.dff_ratio, norm_fn, hps.item_ff, 64 | keep_abspos=hps.obs_keep_abspos, 65 | mask_feature=7, # Feature 7 is hitpoints 66 | topk=hps.nally+hps.nenemy, 67 | ) 68 | else: 69 | self.ally_net = ItemBlock( 70 | obs_config.dstride(), hps.d_item, hps.d_item * hps.dff_ratio, norm_fn, hps.item_ff, 71 | keep_abspos=hps.obs_keep_abspos, 72 | mask_feature=7, # Feature 7 is hitpoints 73 | topk=hps.nally, 74 | ) 75 | self.enemy_net = ItemBlock( 76 | obs_config.dstride(), hps.d_item, hps.d_item * hps.dff_ratio, norm_fn, hps.item_ff, 77 | keep_abspos=hps.obs_keep_abspos, 78 | mask_feature=7, # Feature 7 is hitpoints 79 | topk=hps.nenemy, 80 | ) 81 | self.mineral_net = ItemBlock( 82 | obs_config.mstride(), hps.d_item, hps.d_item * hps.dff_ratio, norm_fn, hps.item_ff, 83 | keep_abspos=hps.obs_keep_abspos, 84 | mask_feature=2, # Feature 2 is size 85 | topk=hps.nmineral, 86 | ) 87 | if hps.nconstant > 0: 88 | self.constant_items = nn.Parameter(torch.normal(0, 1, (hps.nconstant, hps.d_item))) 89 | 90 | if hps.use_privileged: 91 | self.pmineral_net = ItemBlock( 92 | obs_config.mstride(), hps.d_item, hps.d_item * hps.dff_ratio, norm_fn, hps.item_ff, 93 | keep_abspos=True, relpos=False, mask_feature=2, 94 | ) 95 | if hps.ally_enemy_same: 96 | self.pdrone_net = ItemBlock( 97 | obs_config.dstride(), hps.d_item, hps.d_item * hps.dff_ratio, norm_fn, hps.item_ff, 98 | keep_abspos=True, relpos=False, mask_feature=7, 99 | ) 100 | else: 101 | self.pally_net = ItemBlock( 102 | obs_config.dstride(), hps.d_item, hps.d_item * hps.dff_ratio, norm_fn, hps.item_ff, 103 | keep_abspos=True, relpos=False, mask_feature=7, 104 | ) 105 | self.penemy_net = ItemBlock( 106 | obs_config.dstride(), hps.d_item, hps.d_item * hps.dff_ratio, norm_fn, hps.item_ff, 107 | keep_abspos=True, relpos=False, mask_feature=7, 108 | ) 109 | 110 | if hps.item_item_attn_layers > 0: 111 | encoder_layer = nn.TransformerEncoderLayer(d_model=hps.d_item, nhead=8) 112 | self.item_item_attn = 
nn.TransformerEncoder(encoder_layer, num_layers=hps.item_item_attn_layers) 113 | else: 114 | self.item_item_attn = None 115 | 116 | self.multihead_attention = MultiheadAttention( 117 | embed_dim=hps.d_agent, 118 | kdim=hps.d_item, 119 | vdim=hps.d_item, 120 | num_heads=hps.nhead, 121 | dropout=hps.dropout, 122 | ) 123 | self.linear1 = nn.Linear(hps.d_agent, hps.d_agent * hps.dff_ratio) 124 | self.linear2 = nn.Linear(hps.d_agent * hps.dff_ratio, hps.d_agent) 125 | self.norm1 = nn.LayerNorm(hps.d_agent) 126 | self.norm2 = nn.LayerNorm(hps.d_agent) 127 | 128 | self.map_channels = hps.d_agent // (hps.nm_nrings * hps.nm_nrays) 129 | map_item_channels = self.map_channels - 2 if self.hps.map_embed_offset else self.map_channels 130 | self.downscale = nn.Linear(hps.d_item, map_item_channels) 131 | self.norm_map = norm_fn(map_item_channels) 132 | self.conv1 = spatial.ZeroPaddedCylindricalConv2d( 133 | self.map_channels, hps.dff_ratio * self.map_channels, kernel_size=3) 134 | self.conv2 = spatial.ZeroPaddedCylindricalConv2d( 135 | hps.dff_ratio * self.map_channels, self.map_channels, kernel_size=3) 136 | self.norm_conv = norm_fn(self.map_channels) 137 | 138 | final_width = hps.d_agent 139 | if hps.nearby_map: 140 | final_width += hps.d_agent 141 | self.final_layer = nn.Sequential( 142 | nn.Linear(final_width, hps.d_agent * hps.dff_ratio), 143 | nn.ReLU(), 144 | ) 145 | 146 | self.policy_head = nn.Linear(hps.d_agent * hps.dff_ratio, self.naction) 147 | if hps.small_init_pi: 148 | self.policy_head.weight.data *= 0.01 149 | self.policy_head.bias.data.fill_(0.0) 150 | 151 | if hps.use_privileged: 152 | self.value_head = nn.Linear(hps.d_agent * hps.dff_ratio + 2 * hps.d_item, 1) 153 | else: 154 | self.value_head = nn.Linear(hps.d_agent * hps.dff_ratio, 1) 155 | if hps.zero_init_vf: 156 | self.value_head.weight.data.fill_(0.0) 157 | self.value_head.bias.data.fill_(0.0) 158 | 159 | self.epsilon = 1e-4 if hps.fp16 else 1e-8 160 | 161 | def evaluate(self, observation, action_masks, 
privileged_obs): 162 | if self.fp16: 163 | action_masks = action_masks.half() 164 | action_masks = action_masks[:, :self.agents, :] 165 | probs, v = self.forward(observation, privileged_obs) 166 | probs = probs.view(-1, self.agents, self.naction) 167 | if action_masks.size(2) != self.naction: 168 | nbatch, nagent, naction = action_masks.size() 169 | zeros = torch.zeros(nbatch, nagent, self.naction - naction).to(observation.device) 170 | action_masks = torch.cat([action_masks, zeros], dim=2) 171 | probs = probs * action_masks + self.epsilon # Add small value to prevent crash when no action is possible 172 | # We get device-side assert when using fp16 here (needs more investigation) 173 | action_dist = distributions.Categorical(probs.float() if self.fp16 else probs) 174 | actions = action_dist.sample() 175 | entropy = action_dist.entropy()[action_masks.sum(2) != 0] 176 | return actions, action_dist.log_prob(actions), entropy, v.detach().view(-1).cpu().numpy(), probs.detach().cpu().numpy() 177 | 178 | def backprop(self, 179 | hps, 180 | obs, 181 | actions, 182 | old_logprobs, 183 | returns, 184 | value_loss_scale, 185 | advantages, 186 | old_values, 187 | action_masks, 188 | old_probs, 189 | privileged_obs, 190 | split_reward): 191 | if self.fp16: 192 | advantages = advantages.half() 193 | returns = returns.half() 194 | action_masks = action_masks.half() 195 | old_logprobs = old_logprobs.half() 196 | 197 | action_masks = action_masks[:, :self.agents, :] 198 | x, (pitems, pmask) = self.latents(obs, privileged_obs) 199 | batch_size = x.size()[0] 200 | 201 | vin = x.max(dim=1).values.view(batch_size, self.d_agent * self.hps.dff_ratio) 202 | if self.hps.use_privileged: 203 | pitems_max = pitems.max(dim=1).values 204 | pitems_avg = pitems.sum(dim=1) / torch.clamp_min((~pmask).float().sum(dim=1), min=1).unsqueeze(-1) 205 | vin = torch.cat([vin, pitems_max, pitems_avg], dim=1) 206 | values = self.value_head(vin).view(-1) 207 | 208 | logits = self.policy_head(x) 209 | probs = 
F.softmax(logits, dim=2) 210 | probs = probs.view(-1, self.agents, self.naction) 211 | 212 | # add small value to prevent degenerate probability distribution when no action is possible 213 | # gradients still get blocked by the action mask 214 | # TODO: mask actions by setting logits to -inf? 215 | probs = probs * action_masks + self.epsilon 216 | 217 | active_agents = torch.clamp_min((action_masks.sum(dim=2) > 0).float().sum(dim=1), min=1) 218 | 219 | dist = distributions.Categorical(probs) 220 | entropy = dist.entropy() 221 | logprobs = dist.log_prob(actions) 222 | ratios = torch.exp(logprobs - old_logprobs) 223 | advantages = advantages.view(-1, 1) 224 | if split_reward: 225 | advantages = advantages / active_agents.view(-1, 1) 226 | vanilla_policy_loss = advantages * ratios 227 | clipped_policy_loss = advantages * torch.clamp(ratios, 1 - hps.cliprange, 1 + hps.cliprange) 228 | if hps.ppo: 229 | policy_loss = -torch.min(vanilla_policy_loss, clipped_policy_loss).mean() 230 | else: 231 | policy_loss = -vanilla_policy_loss.mean() 232 | 233 | # TODO: do over full distribution, not just selected actions? 
234 | approxkl = 0.5 * (old_logprobs - logprobs).pow(2).mean() 235 | clipfrac = ((ratios - 1.0).abs() > hps.cliprange).sum().type(torch.float32) / ratios.numel() 236 | 237 | clipped_values = old_values + torch.clamp(values - old_values, -hps.cliprange, hps.cliprange) 238 | vanilla_value_loss = (values - returns) ** 2 239 | clipped_value_loss = (clipped_values - returns) ** 2 240 | if hps.clip_vf: 241 | value_loss = torch.max(vanilla_value_loss, clipped_value_loss).mean() 242 | else: 243 | value_loss = vanilla_value_loss.mean() 244 | 245 | entropy_loss = -hps.entropy_bonus * entropy.mean() 246 | 247 | loss = policy_loss + value_loss_scale * value_loss + entropy_loss 248 | loss /= hps.batches_per_update 249 | loss.backward() 250 | return policy_loss.data.tolist(), value_loss.data.tolist(), approxkl.data.tolist(), clipfrac.data.tolist() 251 | 252 | def forward(self, x, x_privileged): 253 | batch_size = x.size()[0] 254 | x, (pitems, pmask) = self.latents(x, x_privileged) 255 | 256 | vin = x.max(dim=1).values.view(batch_size, self.d_agent * self.hps.dff_ratio) 257 | if self.hps.use_privileged: 258 | pitems_max = pitems.max(dim=1).values 259 | pitems_avg = pitems.sum(dim=1) / torch.clamp_min((~pmask).float().sum(dim=1), min=1).unsqueeze(-1) 260 | vin = torch.cat([vin, pitems_max, pitems_avg], dim=1) 261 | values = self.value_head(vin).view(-1) 262 | 263 | logits = self.policy_head(x) 264 | probs = F.softmax(logits, dim=2) 265 | 266 | # return probs.view(batch_size, 8, self.allies).permute(0, 2, 1), values 267 | return probs, values 268 | 269 | def logits(self, x, x_privileged): 270 | x, x_privileged = self.latents(x, x_privileged) 271 | return self.policy_head(x) 272 | 273 | def latents(self, x, x_privileged): 274 | if self.fp16: 275 | # Normalization layers perform fp16 conversion for x after normalization 276 | x_privileged = x_privileged.half() 277 | 278 | batch_size = x.size()[0] 279 | 280 | endglobals = self.obs_config.endglobals() 281 | endallies = 
self.obs_config.endallies() 282 | endenemies = self.obs_config.endenemies() 283 | endmins = self.obs_config.endmins() 284 | endallenemies = self.obs_config.endallenemies() 285 | 286 | globals = x[:, :endglobals] 287 | 288 | # properties of the drone controlled by this network 289 | xagent = x[:, endglobals:endallies]\ 290 | .view(batch_size, self.obs_config.allies, self.obs_config.dstride())[:, :self.agents, :] 291 | globals = globals.view(batch_size, 1, self.obs_config.global_features()) \ 292 | .expand(batch_size, self.agents, self.obs_config.global_features()) 293 | xagent = torch.cat([xagent, globals], dim=2) 294 | agents, _, mask_agent = self.agent_embedding(xagent) 295 | 296 | origin = xagent[:, :, 0:2].clone() 297 | direction = xagent[:, :, 2:4].clone() 298 | 299 | if self.hps.ally_enemy_same: 300 | xdrone = x[:, endglobals:endenemies].view(batch_size, self.obs_config.drones, self.obs_config.dstride()) 301 | items, relpos, mask = self.drone_net(xdrone, origin, direction) 302 | else: 303 | xally = x[:, endglobals:endallies].view(batch_size, self.obs_config.allies, self.obs_config.dstride()) 304 | items, relpos, mask = self.ally_net(xally, origin, direction) 305 | # Ensure that at least one item is not masked out to prevent NaN in transformer softmax 306 | mask[:, :, 0] = 0 307 | 308 | if self.nenemy > 0 and not self.hps.ally_enemy_same: 309 | eobs = self.obs_config.drones - self.obs_config.allies 310 | xe = x[:, endallies:endenemies].view(batch_size, eobs, self.obs_config.dstride()) 311 | 312 | items_e, relpos_e, mask_e = self.enemy_net(xe, origin, direction) 313 | items = torch.cat([items, items_e], dim=2) 314 | mask = torch.cat([mask, mask_e], dim=2) 315 | relpos = torch.cat([relpos, relpos_e], dim=2) 316 | 317 | if self.nmineral > 0: 318 | xm = x[:, endenemies:endmins].view(batch_size, self.obs_config.minerals, self.obs_config.mstride()) 319 | 320 | items_m, relpos_m, mask_m = self.mineral_net(xm, origin, direction) 321 | items = torch.cat([items, 
items_m], dim=2) 322 | mask = torch.cat([mask, mask_m], dim=2) 323 | relpos = torch.cat([relpos, relpos_m], dim=2) 324 | 325 | if self.nconstant > 0: 326 | items_c = self.constant_items\ 327 | .view(1, 1, self.nconstant, self.hps.d_item)\ 328 | .repeat((batch_size, self.agents, 1, 1)) 329 | mask_c = torch.zeros(batch_size, self.agents, self.nconstant).bool().to(x.device) 330 | items = torch.cat([items, items_c], dim=2) 331 | mask = torch.cat([mask, mask_c], dim=2) 332 | 333 | if self.hps.use_privileged: 334 | xally = x[:, endglobals:endallies].view(batch_size, self.obs_config.allies, self.obs_config.dstride()) 335 | eobs = self.obs_config.drones - self.obs_config.allies 336 | xenemy = x[:, endmins:endallenemies].view(batch_size, eobs, self.obs_config.dstride()) 337 | if self.hps.ally_enemy_same: 338 | xdrone = torch.cat([xally, xenemy], dim=1) 339 | pitems, _, pmask = self.pdrone_net(xdrone) 340 | else: 341 | pitems, _, pmask = self.pally_net(xally) 342 | pitems_e, _, pmask_e = self.penemy_net(xenemy) 343 | pitems = torch.cat([pitems, pitems_e], dim=1) 344 | pmask = torch.cat([pmask, pmask_e], dim=1) 345 | xm = x[:, endenemies:endmins].view(batch_size, self.obs_config.minerals, self.obs_config.mstride()) 346 | pitems_m, _, pmask_m = self.pmineral_net(xm) 347 | pitems = torch.cat([pitems, pitems_m], dim=1) 348 | pmask = torch.cat([pmask, pmask_m], dim=1) 349 | if self.item_item_attn: 350 | pmask_nonzero = pmask.clone() 351 | pmask_nonzero[:, 0] = False 352 | pitems = self.item_item_attn( 353 | pitems.permute(1, 0, 2), 354 | src_key_padding_mask=pmask_nonzero, 355 | ).permute(1, 0, 2) 356 | if (pitems != pitems).sum() > 0: 357 | print(pmask) 358 | print(pitems) 359 | raise Exception("NaN!") 360 | else: 361 | pitems = None 362 | pmask = None 363 | 364 | # Transformer input dimensions are: Sequence length, Batch size, Embedding size 365 | source = items.view(batch_size * self.agents, self.nitem, self.d_item).permute(1, 0, 2) 366 | target = agents.view(1, batch_size * 
self.agents, self.d_agent) 367 | x, attn_weights = self.multihead_attention( 368 | query=target, 369 | key=source, 370 | value=source, 371 | key_padding_mask=mask.view(batch_size * self.agents, self.nitem), 372 | ) 373 | x = self.norm1(x + target) 374 | x2 = self.linear2(F.relu(self.linear1(x))) 375 | x = self.norm2(x + x2) 376 | x = x.view(batch_size, self.agents, self.d_agent) 377 | 378 | if self.hps.nearby_map: 379 | items = self.norm_map(F.relu(self.downscale(items))) 380 | items = items * (1 - mask.float().unsqueeze(-1)) 381 | nearby_map = spatial.spatial_scatter( 382 | items=items[:, :, :(self.nitem - self.nconstant), :], 383 | positions=relpos, 384 | nray=self.hps.nm_nrays, 385 | nring=self.hps.nm_nrings, 386 | inner_radius=self.hps.nm_ring_width, 387 | embed_offsets=self.hps.map_embed_offset, 388 | ).view(batch_size * self.agents, self.map_channels, self.hps.nm_nrings, self.hps.nm_nrays) 389 | if self.hps.map_conv: 390 | nearby_map2 = self.conv2(F.relu(self.conv1(nearby_map))) 391 | nearby_map2 = nearby_map2.permute(0, 3, 2, 1) 392 | nearby_map = nearby_map.permute(0, 3, 2, 1) 393 | nearby_map = self.norm_conv(nearby_map + nearby_map2) 394 | nearby_map = nearby_map.reshape(batch_size, self.agents, self.d_agent) 395 | x = torch.cat([x, nearby_map], dim=2) 396 | 397 | x = self.final_layer(x) 398 | x = x.view(batch_size, self.agents, self.d_agent * self.hps.dff_ratio) 399 | x = x * (~mask_agent).float().unsqueeze(-1) 400 | 401 | return x, (pitems, pmask) 402 | 403 | def param_groups(self): 404 | # TODO? 405 | pass 406 | 407 | 408 | # Computes a running mean/variance of input features and performs normalization. 
409 | # https://www.johndcook.com/blog/standard_deviation/ 410 | class InputNorm(nn.Module): 411 | def __init__(self, num_features, cliprange=5): 412 | super(InputNorm, self).__init__() 413 | 414 | self.cliprange = cliprange 415 | self.register_buffer('count', torch.tensor(0)) 416 | self.register_buffer('mean', torch.zeros(num_features)) 417 | self.register_buffer('squares_sum', torch.zeros(num_features)) 418 | self.fp16 = False 419 | 420 | def update(self, input, mask): 421 | if mask is not None: 422 | if len(input.size()) == 4: 423 | batch_size, nally, nitem, features = input.size() 424 | assert (batch_size, nally, nitem) == mask.size() 425 | elif len(input.size()) == 3: 426 | batch_size, nally, features = input.size() 427 | assert (batch_size, nally) == mask.size() 428 | else: 429 | raise Exception(f'Expecting 3 or 4 dimensions, actual: {len(input.size())}') 430 | input = input[mask, :] 431 | else: 432 | features = input.size()[-1] 433 | input = input.reshape(-1, features) 434 | 435 | count = input.numel() / features 436 | if count == 0: 437 | return 438 | mean = input.mean(dim=0) 439 | if self.count == 0: 440 | self.count += count 441 | self.mean = mean 442 | self.squares_sum = ((input - mean) * (input - mean)).sum(dim=0) 443 | else: 444 | self.count += count 445 | new_mean = self.mean + (mean - self.mean) * count / self.count 446 | # Batched Welford update: exact up to rounding, since the deviations of `input` from its own mean sum to zero. 
447 | self.squares_sum = self.squares_sum + ((input - self.mean) * (input - new_mean)).sum(dim=0) 448 | self.mean = new_mean 449 | 450 | def forward(self, input, mask=None): 451 | with torch.no_grad(): 452 | if self.training: 453 | self.update(input, mask=mask) 454 | if self.count > 1: 455 | input = (input - self.mean) / self.stddev() 456 | input = torch.clamp(input, -self.cliprange, self.cliprange) 457 | if (input == float('-inf')).sum() > 0 \ 458 | or (input == float('inf')).sum() > 0 \ 459 | or (input != input).sum() > 0: 460 | print(input) 461 | print(self.squares_sum) 462 | print(self.stddev()) 463 | print(input) 464 | raise Exception("OVER/UNDERFLOW DETECTED!") 465 | 466 | return input.half() if self.fp16 else input 467 | 468 | def enable_fp16(self): 469 | # Convert buffers back to fp32, fp16 has insufficient precision and runs into overflow on squares_sum 470 | self.float() 471 | self.fp16 = True 472 | 473 | def stddev(self): 474 | sd = torch.sqrt(self.squares_sum / (self.count - 1)) 475 | sd[sd == 0] = 1 476 | return sd 477 | 478 | 479 | class InputEmbedding(nn.Module): 480 | def __init__(self, d_in, d_model, norm_fn): 481 | super(InputEmbedding, self).__init__() 482 | 483 | self.normalize = InputNorm(d_in) 484 | self.linear = nn.Linear(d_in, d_model) 485 | self.norm = norm_fn(d_model) 486 | 487 | def forward(self, x, mask=None): 488 | x = self.normalize(x, mask) 489 | x = F.relu(self.linear(x)) 490 | x = self.norm(x) 491 | return x 492 | 493 | 494 | class FFResblock(nn.Module): 495 | def __init__(self, d_model, d_ff, norm_fn): 496 | super(FFResblock, self).__init__() 497 | 498 | self.linear_1 = nn.Linear(d_model, d_ff) 499 | self.linear_2 = nn.Linear(d_ff, d_model) 500 | self.norm = norm_fn(d_model) 501 | 502 | # self.linear_2.weight.data.fill_(0.0) 503 | # self.linear_2.bias.data.fill_(0.0) 504 | 505 | def forward(self, x, mask=None): 506 | x2 = F.relu(self.linear_1(x)) 507 | x = x + F.relu(self.linear_2(x2)) 508 | x = self.norm(x) 509 | return x 510 | 
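`InputNorm.update` above folds a whole batch into the running statistics at once. A minimal sketch, assuming scalar features and plain Python lists instead of tensors (`RunningStats` is a hypothetical name, not part of this repo), suggests that this batched form agrees with a two-pass computation: expanding `(x - old_mean) * (x - new_mean)` around the batch mean leaves the within-batch squared deviations plus the usual pairwise merge term `n * m / (n + m) * (batch_mean - old_mean) ** 2`.

```python
import random
import statistics


class RunningStats:
    """Running mean/variance updated one batch at a time (scalar features for brevity)."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.squares_sum = 0.0  # sum of squared deviations from the running mean

    def update(self, batch):
        m = len(batch)
        if m == 0:
            return
        batch_mean = sum(batch) / m
        old_mean = self.mean
        self.count += m
        new_mean = old_mean + (batch_mean - old_mean) * m / self.count
        # Same shape as the update in InputNorm: deviation from the old mean
        # times deviation from the new mean, summed over the batch.
        self.squares_sum += sum((x - old_mean) * (x - new_mean) for x in batch)
        self.mean = new_mean

    def variance(self):
        # Sample variance (divides by count - 1), matching InputNorm.stddev.
        return self.squares_sum / (self.count - 1)


random.seed(0)
data = [random.gauss(3.0, 2.0) for _ in range(1000)]
stats = RunningStats()
for start in range(0, len(data), 64):
    stats.update(data[start:start + 64])

# Compare against the two-pass reference from the standard library.
assert abs(stats.mean - statistics.mean(data)) < 1e-9
assert abs(stats.variance() - statistics.variance(data)) < 1e-9
```

The sketch only checks the arithmetic. It says nothing about fp16 behavior or about the masking that `InputNorm.update` applies before computing the statistics.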
511 | 512 | class ItemBlock(nn.Module): 513 | def __init__(self, d_in, d_model, d_ff, norm_fn, resblock, keep_abspos, mask_feature, relpos=True, topk=None): 514 | super(ItemBlock, self).__init__() 515 | 516 | if relpos: 517 | if keep_abspos: 518 | d_in += 3 519 | else: 520 | d_in += 1 521 | self.embedding = InputEmbedding(d_in, d_model, norm_fn) 522 | self.mask_feature = mask_feature 523 | self.keep_abspos = keep_abspos 524 | self.topk = topk 525 | # Always define the attribute so that forward() can compare it against None. 526 | self.resblock = FFResblock(d_model, d_ff, norm_fn) if resblock else None 527 | 528 | def forward(self, x, origin=None, direction=None): 529 | batch_size, items, features = x.size() 530 | 531 | if origin is not None: 532 | _, agents, _ = origin.size() 533 | 534 | pos = x[:, :, 0:2] 535 | relpos = spatial.relative_positions(origin, direction, pos) 536 | dist = relpos.norm(p=2, dim=3) 537 | direction = relpos / (dist.unsqueeze(-1) + 1e-8) 538 | 539 | x = x.view(batch_size, 1, items, features)\ 540 | .expand(batch_size, agents, items, features) 541 | if self.keep_abspos: 542 | x = torch.cat([x, direction, torch.sqrt(dist.unsqueeze(-1))], dim=3) 543 | else: 544 | x = torch.cat([direction, x[:, :, :, 2:], torch.sqrt(dist.unsqueeze(-1))], dim=3) 545 | 546 | if self.topk is not None: 547 | empty = (x[:, :, :, self.mask_feature] == 0).float() 548 | key = -dist - empty * 1e8 549 | x = topk_by(values=x, vdim=2, keys=key, kdim=2, k=self.topk) 550 | relpos = topk_by(values=relpos, vdim=2, keys=key, kdim=2, k=self.topk) 551 | 552 | mask = x[:, :, :, self.mask_feature] == 0 553 | else: 554 | relpos = None 555 | mask = x[:, :, self.mask_feature] == 0 556 | 557 | x = self.embedding(x, ~mask) 558 | if self.resblock is not None: 559 | x = self.resblock(x) 560 | x = x * (~mask).unsqueeze(-1).float() 561 | 562 | return x, relpos, mask 563 | 564 | -------------------------------------------------------------------------------- /progress.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "metadata": 
{ 3 | "language_info": { 4 | "codemirror_mode": { 5 | "name": "ipython", 6 | "version": 3 7 | }, 8 | "file_extension": ".py", 9 | "mimetype": "text/x-python", 10 | "name": "python", 11 | "nbconvert_exporter": "python", 12 | "pygments_lexer": "ipython3", 13 | "version": "3.7.5-final" 14 | }, 15 | "orig_nbformat": 2, 16 | "kernelspec": { 17 | "name": "python3", 18 | "display_name": "Python 3.7.5 64-bit ('dcc': conda)", 19 | "metadata": { 20 | "interpreter": { 21 | "hash": "d660374ac31277b9ea7ee26abd64205cafc644f5ebc9b3efcbdb7eb83107acd0" 22 | } 23 | } 24 | } 25 | }, 26 | "nbformat": 4, 27 | "nbformat_minor": 2, 28 | "cells": [ 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "from plot_results import fetch_run_data\n", 36 | "import matplotlib.pyplot as plt\n", 37 | "import numpy as np\n", 38 | "import wandb\n", 39 | "from functools import lru_cache\n", 40 | "import matplotlib.dates as mdates\n", 41 | "from datetime import datetime \n" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "\n", 51 | "@lru_cache(maxsize=None)\n", 52 | "def fetch_run_data(descriptor: str, metrics):\n", 53 | " if isinstance(metrics, str):\n", 54 | " metrics = [metrics]\n", 55 | " else:\n", 56 | " metrics = list(metrics)\n", 57 | " api = wandb.Api()\n", 58 | " runs = api.runs(\"cswinter/deep-codecraft-vs\", {\"config.descriptor\": descriptor})\n", 59 | " \n", 60 | " curves = []\n", 61 | " for run in runs:\n", 62 | " step = []\n", 63 | " value = []\n", 64 | " vals = run.history(keys=metrics, samples=100, pandas=False)\n", 65 | " for entry in vals:\n", 66 | " if metrics[0] in entry:\n", 67 | " step.append(entry['_step'] * 1e-6)\n", 68 | " meanvalue = np.array([entry[metric] for metric in metrics]).mean()\n", 69 | " value.append(meanvalue)\n", 70 | " curves.append((np.array(step), np.array(value)))\n", 71 | " return curves, 
runs[0].summary[\"_timestamp\"]" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "runs = [\n", 81 | " \"154506-agents15-hpsetstandard-steps150e6\",\n", 82 | " \"24e131-agents15-hpsetstandard-steps150e6\",\n", 83 | " \"613056-agents15-hpsetstandard-steps150e6\",\n", 84 | " \"87c1ab-hpsetstandard\",\n", 85 | " \"8af81d-hpsetstandard-num_self_play30-num_vs_aggro_replicator1-num_vs_destroyer2-num_vs_replicator1\",\n", 86 | " \"d33903-batches_per_update32-batches_per_update_schedule-hpsetstandard-lr0.001-lr_schedulecosine-steps150e6\",\n", 87 | " \"49b7fa-entropy_bonus0.02-entropy_bonus_schedulelin 20e6:0.005,60e6:0.0-hpsetstandard\",\n", 88 | " \"49b7fa-feat_dist_to_wallTrue-hpsetstandard\",\n", 89 | " \"b9bab7-hpsetstandard-max_hardness150\",\n", 90 | " \"46e0b2-hpsetstandard-spatial_attnFalse\",\n", 91 | " \"2d9e29-hpsetstandard\",\n", 92 | " \"30ed5b-hpsetstandard-max_hardness175\",\n", 93 | " \"fc244e-hpsetstandard-spatial_attnTrue-spatial_attn_lr_multiplier10.0\",\n", 94 | " \"0a5940-hpsetstandard-item_item_attn_layers1-item_item_spatial_attnTrue-item_item_spatial_attn_vfFalse-max_grad_norm200\",\n", 95 | " \"0a5940-hpsetstandard-mothership_damage_scale4.0-mothership_damage_scale_schedulelin 50e6:1.0,150:0.0\",\n", 96 | " \"83a3af-hpsetstandard-mothership_damage_scale4.0-mothership_damage_scale_schedulelin 50e6:0.0\",\n", 97 | " \"667ac7-hpsetstandard\",\n", 98 | " \"80a87d-entropy_bonus0.15-entropy_bonus_schedulelin 15e6:0.07,60e6:0.0-hpsetstandard\",\n", 99 | " \"80a87d-entropy_bonus0.2-entropy_bonus_schedulelin 15e6:0.1,60e6:0.0-final_lr5e-05-hpsetstandard-lr0.0005-vf_coef1.0\",\n", 100 | " \"c0b3b4-hpsetstandard-partial_score0\",\n", 101 | " \"9fc3de-hpsetstandard\",\n", 102 | " \"9fc3de-adr_hstepsize0.001-hpsetstandard-linear_hardnessFalse\",\n", 103 | " \"ac84c0-gamma0.9997-hpsetstandard\",\n", 104 | " \"a1210b-gamma_schedulecos 1.0-hpsetstandard\",\n", 105 | " 
\"b9f907-adr_average_cost_target1-hpsetstandard\",\n", 106 | " \"5fb181-hpsetstandard\",\n", 107 | " \"5fb181-hpsetstandard-steps150e6\",\n", 108 | " \"3c69a5-adr_average_cost_target0.5-adr_avg_cost_schedulelin 80e6:1.0-hpsetstandard\",\n", 109 | " \"35b3a7-hpsetstandard-nearby_mapFalse-steps150e6\",\n", 110 | " \"152ec3-hpsetstandard-nearby_mapFalse-steps125e6\",\n", 111 | "]" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": { 118 | "tags": [] 119 | }, 120 | "outputs": [], 121 | "source": [ 122 | "fig, ax = plt.subplots(figsize=(19, 10))\n", 123 | "cmap = plt.get_cmap('viridis')\n", 124 | "\n", 125 | "t0 = 1593959023.8568478\n", 126 | "tn = 1607756232\n", 127 | "ts = []\n", 128 | "for ri, run in enumerate(runs):\n", 129 | " #print(f\"Fetching {run}\")\n", 130 | " curves, date = fetch_run_data(run, \"eval_mean_score\")\n", 131 | " samples = []\n", 132 | " values = []\n", 133 | " for curve in curves:\n", 134 | " ax.plot(curve[0], curve[1], color=cmap((date-t0)/(tn-t0)), marker='o')\n", 135 | " for i, value in enumerate(curve[1]):\n", 136 | " if len(values) <= i:\n", 137 | " samples.append(curve[0][i])\n", 138 | " values.append([value])\n", 139 | " else:\n", 140 | " values[i].append(value)\n", 141 | " #values = np.array([np.array(vals).mean() for vals in values])\n", 142 | " #ax.plot(samples, values, color=cmap((date-t0)/(tn-t0)), marker='o')\n", 143 | " #ts.append(mdates.date2num(datetime.fromtimestamp(date)))\n", 144 | "\n", 145 | "from matplotlib.cm import ScalarMappable\n", 146 | "from matplotlib.colors import Normalize\n", 147 | "loc = mdates.AutoDateLocator()\n", 148 | "def dateformatter(x, pos=None):\n", 149 | " return datetime.fromtimestamp(x*(tn-t0)+t0).strftime('%Y-%m-%d')\n", 150 | "fig.colorbar(ScalarMappable(cmap=cmap), ticks=loc, format=dateformatter)\n", 151 | "\n", 152 | "ax.set_yticks([-1.0, -0.5, 0, 0.5, 1])\n", 153 | "ax.set_xlim(0, 200)\n", 154 | "ax.grid()\n", 155 | "fig.show()" 156 | ] 157 | }, 
158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "fig, ax = plt.subplots(figsize=(20, 15))\n", 165 | "cmap = plt.get_cmap('viridis')\n", 166 | "\n", 167 | "t0 = 1593959023.8568478\n", 168 | "tn = 1607756232\n", 169 | "ts = []\n", 170 | "for ri, run in enumerate(runs):\n", 171 | " #print(f\"Fetching {run}\")\n", 172 | " curves, date = fetch_run_data(run, \"eval_mean_score\")\n", 173 | " samples = []\n", 174 | " values = []\n", 175 | " for curve in curves:\n", 176 | " for i, value in enumerate(curve[1]):\n", 177 | " if len(values) <= i:\n", 178 | " samples.append(curve[0][i])\n", 179 | " values.append([value])\n", 180 | " else:\n", 181 | " values[i].append(value)\n", 182 | " values = np.array([np.array(vals).mean() for vals in values])\n", 183 | " ax.plot(samples, values)#, color=cmap((date-t0)/(tn-t0)))\n", 184 | " #ts.append(mdates.date2num(datetime.fromtimestamp(date)))\n", 185 | "\n", 186 | "#from matplotlib.cm import ScalarMappable\n", 187 | "#from matplotlib.colors import Normalize\n", 188 | "#loc = mdates.AutoDateLocator()\n", 189 | "#fig.colorbar(ScalarMappable(norm=Normalize(t0, tn), cmap=cmap))#, ticks=loc, format=mdates.AutoDateFormatter(loc))\n", 190 | "\n", 191 | "ax.set(xlabel='million samples', ylim=(-1, 1))\n", 192 | "ax.set_yticks([-1.0, -0.5, 0, 0.5, 1])\n", 193 | "ax.set_xlim(0, 200)\n", 194 | "#ax.set_xticks([0, 25, 50, 75, 100, 125])\n", 195 | "ax.legend(loc='upper left')\n", 196 | "ax.grid()\n", 197 | "fig.show()" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": null, 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [ 206 | "api = wandb.Api()\n", 207 | "runs = api.runs(\"cswinter/deep-codecraft-vs\", {\"config.descriptor\": runs[0]})" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [ 216 | "runs" 217 | ] 218 | }, 219 | { 220 | 
"cell_type": "code", 221 | "execution_count": null, 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "fetch_run_data(runs[-1], 'eval_mean_score')[1]" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": null, 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": null, 238 | "metadata": {}, 239 | "outputs": [], 240 | "source": [ 241 | "#help(runs[0])\n", 242 | "{metric: values for metric, values in runs[0].summary.items() if metric.startswith('eval')}" 243 | ] 244 | } 245 | ] 246 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | logger==1.4 2 | matplotlib 3 | numpy==1.16.3 4 | orjson==3.0.2 5 | requests==2.22.0 6 | torch==1.6.0 7 | torchprof 8 | wandb==0.10.9 9 | 10 | # torch-scatter cannot be installed with requirements.txt, see https://github.com/rusty1s/pytorch_scatter for installation instructions 11 | # torch-scatter==2.0.5 12 | -------------------------------------------------------------------------------- /reset-drivers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | sudo rmmod nvidia_uvm 4 | sudo rmmod nvidia_drm 5 | sudo rmmod nvidia_modeset 6 | sudo rmmod nvidia 7 | sudo modprobe nvidia 8 | sudo modprobe nvidia_modeset 9 | sudo modprobe nvidia_drm 10 | sudo modprobe nvidia_uvm 11 | -------------------------------------------------------------------------------- /runner.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import os 4 | import pathlib 5 | import queue 6 | import subprocess 7 | import threading 8 | import tempfile 9 | import time 10 | import yaml 11 | import click 12 | 13 | 14 | logging.basicConfig(format='%(asctime)s [%(levelname)s] %(message)s', 
datefmt='%Y-%m-%d %H:%M:%S', level=logging.INFO) 15 | 16 | 17 | class JobQueue: 18 | def __init__(self, queue_dir, concurrency, devices, out_dir): 19 | self.queue_dir = queue_dir 20 | self.concurrency = concurrency 21 | self.devices = devices 22 | self.out_dir = out_dir 23 | self.known_jobs = {} 24 | self.queue = queue.Queue() 25 | self.active_jobs = 0 26 | self.active_jobs_per_device = {device: 0 for device in range(devices)} 27 | self.lock = threading.Lock() 28 | self.port_offset = 0 29 | for device in os.environ.get("GPU_DENYLIST", default="").split(","): 30 | if device != '': 31 | self.active_jobs_per_device.pop(int(device)) 32 | devices -= 1 33 | 34 | def run(self): 35 | logging.info(f"Watching {self.queue_dir} for new jobs...") 36 | 37 | while True: 38 | for job_file in os.listdir(self.queue_dir): 39 | if job_file not in self.known_jobs: 40 | if job_file.startswith("."): 41 | logging.info(f"Ignoring hidden file {job_file}") 42 | continue 43 | logging.info(f"Found new job file {job_file}") 44 | self.process_job_file(job_file) 45 | 46 | while self.queue.qsize() > 0: 47 | job = self.queue.get() 48 | required_devices = min(self.devices, job.parallelism) 49 | if job.parallelism > self.devices and job.parallelism % self.devices != 0: 50 | logging.error(f"Can't evenly distribute {job.parallelism} processes across {self.devices} GPUs, dropping job.") 51 | continue 52 | required_slots_per_device = job.parallelism // self.devices if job.parallelism > self.devices else 1 53 | while True: 54 | selected_devices = [] 55 | with self.lock: 56 | min_load = self.concurrency + 1 57 | for device, load in self.active_jobs_per_device.items(): 58 | if load + required_slots_per_device <= self.concurrency // self.devices: 59 | if load < min_load: 60 | selected_devices = [device] 61 | min_load = load 62 | elif load == min_load: 63 | selected_devices.append(device) 64 | if len(selected_devices) >= required_devices: 65 | rank = 0 66 | for device in selected_devices[:required_devices]: 
67 | for _ in range(required_slots_per_device): 68 | job_copy = copy.deepcopy(job) 69 | job_copy.set_device(device, rank, 29000 + self.port_offset) 70 | self.active_jobs_per_device[job_copy.device] += 1 71 | threading.Thread(target=self.run_job, args=(job_copy,)).start() 72 | rank += 1 73 | self.active_jobs += 1 74 | self.port_offset = (self.port_offset + 1) % 1000 75 | logging.info(f"In queue: {self.queue.qsize()} Running: {self.active_jobs_per_device}") 76 | break 77 | time.sleep(0.1) 78 | 79 | time.sleep(0.1) 80 | 81 | time.sleep(0.1) 82 | 83 | def run_job(self, job): 84 | try: 85 | with tempfile.TemporaryDirectory() as dir: 86 | 87 | def git(args, workdir=dir): 88 | # Discard git output via subprocess.DEVNULL instead of opening os.devnull on every call. 89 | cmd = ["git"] 90 | if workdir is not None: 91 | cmd.extend(["-C", workdir]) 92 | cmd.extend(args) 93 | subprocess.check_call(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT) 94 | 95 | git(["clone", job.repo_path, dir], workdir=None) 96 | git(["reset", "--hard", "HEAD"]) 97 | git(["clean", "-fd"]) 98 | try: 99 | git(["checkout", job.revision]) 100 | except subprocess.CalledProcessError: 101 | logging.error(f"Failed to checkout revision {job.revision}! 
Aborting.") 102 | return 103 | 104 | revision = subprocess.check_output( 105 | ["git", "-C", dir, "describe", "--tags", "--always", "--dirty"]).decode("UTF-8")[:-1] 106 | 107 | out_dir = os.path.join(self.out_dir, f'{time.strftime("%Y-%m-%d~%H:%M:%S")}-{revision}') 108 | for name, value in job.params.items(): 109 | out_dir += f"-{name}{value}" 110 | pathlib.Path(out_dir).mkdir(parents=True, exist_ok=True) 111 | 112 | job_desc = f"{job.repo_path} at {job.revision} with {job.params}" 113 | args = [] 114 | for name, value in job.params.items(): 115 | if isinstance(value, bool): 116 | if value: 117 | args.append(f'--{name}') 118 | else: 119 | args.append(f'--no-{name}') 120 | else: 121 | args.append(f"--{name}={value}") 122 | args.append(f"--descriptor={job.descriptor}") 123 | 124 | logpath = os.path.join(out_dir, f"out{job.rank}.txt") 125 | 126 | logging.info(f"Running {job_desc}") 127 | logging.info(f"Output in {logpath}") 128 | 129 | with open(logpath, "w+") as outfile: 130 | retcode = subprocess.call( 131 | ["python3", "main.py", "--out-dir", out_dir] + args, 132 | env=dict( 133 | os.environ, 134 | CUDA_VISIBLE_DEVICES=str(job.device), 135 | MASTER_ADDR='localhost', 136 | MASTER_PORT=str(job.discovery_port), 137 | ), 138 | stdout=outfile, stderr=outfile, cwd=dir 139 | ) 140 | if retcode != 0: 141 | logging.warning(f"Command {job_desc} returned non-zero exit status {retcode}. 
Logs: {logpath}") 142 | else: 143 | logging.info(f"Success: {job_desc}") 144 | finally: 145 | with self.lock: 146 | if job.rank == 0: 147 | self.active_jobs -= 1 148 | self.known_jobs[job.handle] -= 1 149 | if self.known_jobs[job.handle] == 0: 150 | del self.known_jobs[job.handle] 151 | self.active_jobs_per_device[job.device] -= 1 152 | 153 | def process_job_file(self, job_file): 154 | filepath = os.path.join(self.queue_dir, job_file) 155 | job = yaml.safe_load(open(filepath, "r")) 156 | param_sets = [] 157 | for param_set in job["params"]: 158 | param_sets.extend(self.all_combinations(param_set)) 159 | 160 | logging.info(f"Enqueuing {len(param_sets)} jobs") 161 | self.known_jobs[job_file] = len(param_sets) 162 | 163 | for param_set in param_sets: 164 | self.queue.put(Job(job["repo-path"], job["revision"], param_set, job_file, param_set.get("parallelism", 1))) 165 | os.remove(filepath) 166 | 167 | def all_combinations(self, params_dict): 168 | param_sets = [{}] 169 | if 'repeat' in params_dict: 170 | repetitions = params_dict['repeat'] 171 | del(params_dict['repeat']) 172 | else: 173 | repetitions = 1 174 | 175 | for name, values in params_dict.items(): 176 | if type(values) is list: 177 | new_sets = [] 178 | for value in values: 179 | for param_set in param_sets: 180 | ps = copy.deepcopy(param_set) 181 | ps[name] = value 182 | new_sets.append(ps) 183 | param_sets = new_sets 184 | else: 185 | for param_set in param_sets: 186 | param_set[name] = values 187 | 188 | result = [] 189 | for param_set in param_sets: 190 | for _ in range(repetitions): 191 | result.append(param_set.copy()) 192 | return result 193 | 194 | 195 | class Job: 196 | def __init__(self, repo_path, revision, params, handle, parallelism): 197 | self.repo_path = repo_path 198 | self.revision = revision 199 | self.params = params 200 | self.handle = handle 201 | self.device = None 202 | self.parallelism = parallelism 203 | self.descriptor = "-".join([revision[:6]] + [f'{k}{v}' for k, v in 
params.items()]) 204 | self.rank = 0 205 | self.discovery_port = None 206 | 207 | def set_device(self, device, rank, discovery_port): 208 | self.device = device 209 | self.rank = rank 210 | self.discovery_port = discovery_port 211 | self.params['device'] = device 212 | self.params['rank'] = rank 213 | 214 | 215 | @click.command() 216 | @click.option("--jobfile-dir", default="/home/clemens/xprun/queue", help="Directory to watch for new job files.") 217 | @click.option("--concurrency", default=8, help="Maximum number of jobs running at the same time.") 218 | @click.option("--out-dir", default="/home/clemens/Dropbox/artifacts/DeepCodeCraft", help="Root of output directories given to jobs.") 219 | def main(jobfile_dir, concurrency, out_dir): 220 | gpus = len(subprocess.check_output(["nvidia-smi", "-L"]).decode("UTF-8").split("\n")) - 1 221 | job_queue = JobQueue(jobfile_dir, concurrency, gpus, out_dir) 222 | job_queue.run() 223 | 224 | 225 | if __name__ == "__main__": 226 | main() 227 | 228 | -------------------------------------------------------------------------------- /schedule.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import tempfile 4 | import time 5 | 6 | import click 7 | import yaml 8 | 9 | 10 | @click.command() 11 | @click.option("--repo-path", default="git@github.com:cswinter/DeepCodeCraft.git", help="Path to git code repository to execute.") 12 | @click.option("--revision", default="HEAD", help="Git revision to execute.") 13 | @click.option("--params-file", default=None, help="Path to parameter file.") 14 | @click.option("--hps", default=None, help="List of hyperparameters in format name1:value1,name2:value2") 15 | @click.option("--queue-dir", default="192.168.0.101:/home/clemens/xprun/queue") 16 | def main(repo_path, revision, params_file, hps, queue_dir): 17 | commit = subprocess.check_output(["git", "rev-parse", revision]).decode("UTF-8")[:-1] 18 | 19 | if params_file: 20 | 
with open(params_file, "r") as f: 21 | params = yaml.safe_load(f) 22 | elif hps: 23 | params = [{}] 24 | for param in hps.split(","): 25 | key, value = param.split(":") 26 | params[0][key] = value 27 | else: 28 | params = [{}] 29 | 30 | job = { 31 | "repo-path": repo_path, 32 | "revision": commit, 33 | "params": params, 34 | } 35 | 36 | fd, path = tempfile.mkstemp() 37 | with open(fd, 'w') as f: 38 | f.write(yaml.dump(job)) 39 | subprocess.check_call(["rsync", path, os.path.join(queue_dir, f"{int(time.time())}.yaml")]) 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | 45 | -------------------------------------------------------------------------------- /setup-remote.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -euxo pipefail 4 | 5 | 6 | PORT="$1" 7 | INSTANCE="$2" 8 | 9 | rsync -azP -e "ssh -p $PORT" . root@ssh$INSTANCE.vast.ai:src/DeepCodeCraft 10 | rsync -azP -e "ssh -p $PORT" ../CodeCraftServer/ root@ssh$INSTANCE.vast.ai:src/CodeCraftServer 11 | rsync -azP -e "ssh -p $PORT" /home/clemens/Dropbox/artifacts/DeepCodeCraft/golden-models/standard/noble-sky-145M.pt root@ssh$INSTANCE.vast.ai:/home/clemens/Dropbox/artifacts/DeepCodeCraft/golden-models/standard/ 12 | rsync -azP -e "ssh -p $PORT" /home/clemens/Dropbox/artifacts/DeepCodeCraft/golden-models/standard/radiant-sun-35M.pt root@ssh$INSTANCE.vast.ai:/home/clemens/Dropbox/artifacts/DeepCodeCraft/golden-models/standard/ 13 | -------------------------------------------------------------------------------- /setup-system.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -euxo pipefail 4 | 5 | mkdir src 6 | mkdir -p /home/clemens/Dropbox/artifacts/DeepCodeCraft/golden-models/standard 7 | mkdir -p /home/clemens/xprun/queue 8 | 9 | apt-get update 10 | apt-get install --yes gnupg curl software-properties-common htop git rsync vim g++ 11 | 12 | #add-apt-repository ppa:graphics-drivers 
--yes 13 | #apt-get update 14 | #apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub 15 | #bash -c 'echo "deb http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/cuda.list' 16 | #bash -c 'echo "deb http://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/cuda_learn.list' 17 | #apt-get update 18 | #apt-get install --yes cuda-10-1 libcudnn7 19 | 20 | echo "deb https://dl.bintray.com/sbt/debian /" | tee -a /etc/apt/sources.list.d/sbt.list 21 | curl -sL "https://keyserver.ubuntu.com/pks/lookup?op=get&search=0x2EE0EA64E40A89B84B2DF73499E82A75642AC823" | apt-key add 22 | apt-get update 23 | apt-get install --yes openjdk-8-jdk sbt=0.13.16 24 | 25 | pip install torch-scatter==2.0.5+cu101 -f https://pytorch-geometric.com/whl/torch-1.6.0.html 26 | 27 | cd src 28 | git clone https://github.com/cswinter/CodeCraftGame.git 29 | cd CodeCraftGame 30 | git checkout deepcodecraft 31 | sbt publishLocal 32 | 33 | -------------------------------------------------------------------------------- /showmatch.py: -------------------------------------------------------------------------------- 1 | import click 2 | import torch 3 | import numpy as np 4 | 5 | from main import load_policy, eval 6 | from gym_codecraft import envs 7 | 8 | 9 | @click.command() 10 | @click.argument('model_paths', nargs=-1) 11 | @click.option('--task', default='ARENA_TINY_2V2') 12 | @click.option('--randomize/--no-randomize', default=False) 13 | @click.option('--hardness', default=0) 14 | @click.option('--num_envs', default=4) 15 | @click.option('--symmetric/--no-symmetric', default=True) 16 | @click.option('--random_rules', default=0.0) 17 | def showmatch(model_paths, task, randomize, hardness, num_envs, symmetric, random_rules): 18 | if torch.cuda.is_available(): 19 | device = torch.device("cuda:0") 20 | else: 21 | print("Running on CPU") 22 | 
device = "cpu" 23 | 24 | if len(model_paths) == 1: 25 | opponents = None 26 | elif len(model_paths) == 2: 27 | opponents = {'player2': {'model_file': model_paths[1]}} 28 | else: 29 | raise Exception("Invalid args") 30 | objective = envs.Objective(task) 31 | policy1, _, _, _, _ = load_policy(model_paths[0], device) 32 | eval( 33 | policy=policy1, 34 | num_envs=num_envs, 35 | device=device, 36 | objective=objective, 37 | eval_steps=int(1e20), 38 | opponents=opponents, 39 | printerval=500, 40 | randomize=randomize, 41 | hardness=hardness, 42 | symmetric=symmetric, 43 | random_rules=random_rules, 44 | ) 45 | 46 | 47 | if __name__ == "__main__": 48 | showmatch() 49 | -------------------------------------------------------------------------------- /spatial.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch_scatter import scatter_add, scatter_mean 6 | 7 | 8 | # N: Batch size 9 | # L_s: number of controllable drones 10 | # L: max number of visible objects 11 | # C: number of channels/features on each object 12 | def relative_positions( 13 | origin, # (N, L_s, 2) 14 | direction, # (N, L_s, 2) 15 | positions, # (N, L, 2) 16 | ): # (N, L_s, L, 2) 17 | n, ls, _ = origin.size() 18 | _, l, _ = positions.size() 19 | 20 | origin = origin.view(n, ls, 1, 2) 21 | direction = direction.view(n, ls, 1, 2) 22 | positions = positions.view(n, 1, l, 2) 23 | 24 | positions = positions - origin 25 | 26 | angle = -torch.atan2(direction[:, :, :, 1], direction[:, :, :, 0]) 27 | rotation = torch.cat( 28 | [ 29 | torch.cat( 30 | [angle.cos().view(n, ls, 1, 1, 1), angle.sin().view(n, ls, 1, 1, 1)], 31 | dim=3, 32 | ), 33 | torch.cat( 34 | [-angle.sin().view(n, ls, 1, 1, 1), angle.cos().view(n, ls, 1, 1, 1)], 35 | dim=3, 36 | ), 37 | ], 38 | dim=4, 39 | ) 40 | 41 | positions_rotated = torch.matmul(rotation, positions.view(n, ls, l, 2, 1)).view(n, ls, l, 2) 
42 | 43 | return positions_rotated 44 | 45 | 46 | def polar_indices( 47 | positions, # (N, L_s, L, 2) 48 | nray, 49 | nring, 50 | inner_radius 51 | ): # (N, L_s, L), (N, L_s, L), (N, L_s, L), (N, L_s, L) 52 | distances = torch.sqrt(positions[:, :, :, 0] ** 2 + positions[:, :, :, 1] ** 2) 53 | distance_indices = torch.clamp(distances / inner_radius, min=0, max=nring-1).floor().long() 54 | angles = torch.atan2(positions[:, :, :, 1], positions[:, :, :, 0]) + math.pi 55 | # There is one angle value that can result in index of exactly nray, clamp it to nray-1 56 | angular_indices = torch.clamp_max((angles / (2 * math.pi) * nray).floor().long(), nray-1) 57 | 58 | distance_offsets = torch.clamp_max(distances / inner_radius - distance_indices.float() - 0.5, max=2) 59 | angular_offsets = angles / (2 * math.pi) * nray - angular_indices.float() - 0.5 60 | 61 | assert angular_indices.min() >= 0, f'Negative angular index: {angular_indices.min()}' 62 | assert angular_indices.max() < nray, f'invalid angular index: {angular_indices.max()} >= {nray}' 63 | assert distance_indices.min() >= 0, f'Negative distance index: {distance_indices.min()}' 64 | assert distance_indices.max() < nring, f'invalid distance index: {distance_indices.max()} >= {nring}' 65 | 66 | return distance_indices, angular_indices, distance_offsets, angular_offsets 67 | 68 | 69 | # N: Batch size 70 | # L: max number of visible objects 71 | # C: number of channels/features on each object 72 | def unbatched_relative_positions( 73 | origin, # (N, 2) 74 | direction, # (N, 2) 75 | positions, # (N, L, 2) 76 | rotate: bool = True, 77 | ): # (N, L, 2) 78 | n, _ = origin.size() 79 | _, l, _ = positions.size() 80 | 81 | origin = origin.view(n, 1, 2) 82 | direction = direction.view(n, 1, 2) 83 | positions = positions.view(n, l, 2) 84 | 85 | positions = positions - origin 86 | 87 | if not rotate: 88 | return positions 89 | 90 | angle = -torch.atan2(direction[:, :, 1], direction[:, :, 0]) 91 | rotation = torch.cat( 92 | [ 93 | 
torch.cat( 94 | [angle.cos().view(n, 1, 1, 1), angle.sin().view(n, 1, 1, 1)], 95 | dim=2, 96 | ), 97 | torch.cat( 98 | [-angle.sin().view(n, 1, 1, 1), angle.cos().view(n, 1, 1, 1)], 99 | dim=2, 100 | ), 101 | ], 102 | dim=3, 103 | ) 104 | 105 | positions_rotated = torch.matmul(rotation, positions.view(n, l, 2, 1)).view(n, l, 2) 106 | 107 | return positions_rotated 108 | 109 | 110 | def varlength_polar_indices( 111 | positions, # (N, L_s, L, 2) 112 | indices, 113 | nray, 114 | nring, 115 | inner_radius 116 | ): # (N, L_s, L), (N, L_s, L), (N, L_s, L), (N, L_s, L) 117 | distances = torch.sqrt(positions[:, :, :, 0] ** 2 + positions[:, :, :, 1] ** 2) 118 | distance_indices = torch.clamp(distances / inner_radius, min=0, max=nring-1).floor().long() 119 | angles = torch.atan2(positions[:, :, :, 1], positions[:, :, :, 0]) + math.pi 120 | # There is one angle value that can result in index of exactly nray, clamp it to nray-1 121 | angular_indices = torch.clamp_max((angles / (2 * math.pi) * nray).floor().long(), nray-1) 122 | 123 | distance_offsets = torch.clamp_max(distances / inner_radius - distance_indices.float() - 0.5, max=2) 124 | angular_offsets = angles / (2 * math.pi) * nray - angular_indices.float() - 0.5 125 | 126 | assert angular_indices.min() >= 0, f'Negative angular index: {angular_indices.min()}' 127 | assert angular_indices.max() < nray, f'invalid angular index: {angular_indices.max()} >= {nray}' 128 | assert distance_indices.min() >= 0, f'Negative distance index: {distance_indices.min()}' 129 | assert distance_indices.max() < nring, f'invalid distance index: {distance_indices.max()} >= {nring}' 130 | 131 | return distance_indices, angular_indices, distance_offsets, angular_offsets 132 | 133 | 134 | def spatial_scatter( 135 | items, # (N, L_s, L, C) 136 | positions, # (N, L_s, L, 2) 137 | nray, 138 | nring, 139 | inner_radius, 140 | embed_offsets=False, 141 | ): # (N, L_s, C', nring, nray) where C' = C + 2 if embed_offsets else C 142 | n, ls, l, c = 
items.size() 143 | assert (n, ls, l, 2) == positions.size(), f'Expect size {(n, ls, l, 2)} for positions, actual: {positions.size()}' 144 | 145 | distance_index, angular_index, distance_offsets, angular_offsets = \ 146 | polar_indices(positions, nray, nring, inner_radius) 147 | index = distance_index * nray + angular_index 148 | index = index.unsqueeze(-1) 149 | scattered_items = scatter_add(items, index, dim=2, dim_size=nray * nring) \ 150 | .permute(0, 1, 3, 2) \ 151 | .reshape(n, ls, c, nring, nray) 152 | 153 | if embed_offsets: 154 | offsets = torch.cat([distance_offsets.unsqueeze(-1), angular_offsets.unsqueeze(-1)], dim=3) 155 | scattered_nonshared = scatter_mean(offsets, index, dim=2, dim_size=nray * nring) \ 156 | .permute(0, 1, 3, 2) \ 157 | .reshape(n, ls, 2, nring, nray) 158 | return torch.cat([scattered_nonshared, scattered_items], dim=2) 159 | else: 160 | return scattered_items 161 | 162 | 163 | def single_batch_dim_spatial_scatter( 164 | items, # (N, L, C) 165 | positions, # (N, L, 2) 166 | nray, 167 | nring, 168 | inner_radius, 169 | embed_offsets=False, 170 | ): # (N, C', nring, nray) where C' = C + 2 if embed_offsets else C 171 | n, l, c = items.size() 172 | assert (n, l, 2) == positions.size(), f'Expect size {(n, l, 2)} for positions, actual: {positions.size()}' 173 | 174 | distance_index, angular_index, distance_offsets, angular_offsets = \ 175 | single_batch_dim_polar_indices(positions, nray, nring, inner_radius) 176 | index = distance_index * nray + angular_index 177 | index = index.unsqueeze(-1) 178 | scattered_items = scatter_add(items, index, dim=1, dim_size=nray * nring) \ 179 | .permute(0, 2, 1) \ 180 | .reshape(n, c, nring, nray) 181 | 182 | if embed_offsets: 183 | offsets = torch.cat([distance_offsets.unsqueeze(-1), angular_offsets.unsqueeze(-1)], dim=2) 184 | scattered_nonshared = scatter_mean(offsets, index, dim=1, dim_size=nray * nring) \ 185 | .permute(0, 2, 1) \ 186 | .reshape(n, 2, nring, nray) 187 | return 
torch.cat([scattered_nonshared, scattered_items], dim=1) 188 | else: 189 | return scattered_items 190 | 191 | 192 | def single_batch_dim_polar_indices( 193 | positions, # (N, L, 2) 194 | nray, 195 | nring, 196 | inner_radius 197 | ): # (N, L), (N, L), (N, L), (N, L) 198 | distances = torch.sqrt(positions[:, :, 0] ** 2 + positions[:, :, 1] ** 2) 199 | distance_indices = torch.clamp(distances / inner_radius, min=0, max=nring-1).floor().long() 200 | angles = torch.atan2(positions[:, :, 1], positions[:, :, 0]) + math.pi 201 | # There is one angle value that can result in index of exactly nray, clamp it to nray-1 202 | angular_indices = torch.clamp_max((angles / (2 * math.pi) * nray).floor().long(), nray-1) 203 | 204 | distance_offsets = torch.clamp_max(distances / inner_radius - distance_indices.float() - 0.5, max=2) 205 | angular_offsets = angles / (2 * math.pi) * nray - angular_indices.float() - 0.5 206 | 207 | assert angular_indices.min() >= 0, f'Negative angular index: {angular_indices.min()}' 208 | assert angular_indices.max() < nray, f'invalid angular index: {angular_indices.max()} >= {nray}' 209 | assert distance_indices.min() >= 0, f'Negative distance index: {distance_indices.min()}' 210 | assert distance_indices.max() < nring, f'invalid distance index: {distance_indices.max()} >= {nring}' 211 | 212 | return distance_indices, angular_indices, distance_offsets, angular_offsets 213 | 214 | 215 | class ZeroPaddedCylindricalConv2d(nn.Module): 216 | def __init__(self, in_channels, out_channels, kernel_size): 217 | super(ZeroPaddedCylindricalConv2d, self).__init__() 218 | 219 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size) 220 | self.padding = kernel_size // 2 221 | 222 | # input should be of dims (N, C, H, W) 223 | # applies dimension-preserving conv2d (for odd kernel sizes); F.pad lists pads for the last dimension first, so as written H is padded circularly and W is zero-padded 224 | def forward(self, input): 225 | input = F.pad(input, [0, 0, self.padding, self.padding], mode='circular') 226 | input = F.pad(input, 
[self.padding, self.padding, 0, 0], mode='constant') 227 | return self.conv(input) 228 | 229 | -------------------------------------------------------------------------------- /test_spatial_scatter.py: -------------------------------------------------------------------------------- 1 | from spatial import spatial_scatter, relative_positions 2 | import torch 3 | 4 | items = torch.tensor([ 5 | [[1.0, -1.0], [5.0, -5.0], [3., 3.]], 6 | [[5., 5.], [7., -7.], [2., 2.]], 7 | ]) 8 | positions = torch.tensor([ 9 | [[0., 0.1], [1., 0.], [10., 10.]], 10 | [[0., 0.1], [1., 0.], [99.5, 9.5]], 11 | ]) 12 | 13 | origin = torch.tensor([ 14 | [[0., 0.], [0., 1.]], 15 | [[0., 0.], [100., 10.]], 16 | ]) 17 | direction = torch.tensor([ 18 | [[0., 1.], [0., -1.]], 19 | [[0., 1.], [0., -1.]], 20 | ]) 21 | 22 | relpos = relative_positions(origin, direction, positions) 23 | map = spatial_scatter( 24 | items, 25 | relpos, 26 | nray=8, 27 | nring=5, 28 | inner_radius=1, 29 | ) 30 | print(map.size()) 31 | print(map) 32 | 33 | assert((map == torch.tensor([ 34 | [[[[ 0., 0., 0., 0., 1., 0., 0., 0.], 35 | [ 0., 0., 5., 0., 0., 0., 0., 0.], 36 | [ 0., 0., 0., 0., 0., 0., 0., 0.], 37 | [ 0., 0., 0., 0., 0., 0., 0., 0.], 38 | [ 0., 0., 0., 3., 0., 0., 0., 0.]], 39 | 40 | [[ 0., 0., 0., 0., -1., 0., 0., 0.], 41 | [ 0., 0., -5., 0., 0., 0., 0., 0.], 42 | [ 0., 0., 0., 0., 0., 0., 0., 0.], 43 | [ 0., 0., 0., 0., 0., 0., 0., 0.], 44 | [ 0., 0., 0., 3., 0., 0., 0., 0.]]], 45 | 46 | 47 | [[[ 0., 0., 0., 0., 1., 0., 0., 0.], 48 | [ 0., 0., 0., 0., 0., 5., 0., 0.], 49 | [ 0., 0., 0., 0., 0., 0., 0., 0.], 50 | [ 0., 0., 0., 0., 0., 0., 0., 0.], 51 | [ 0., 0., 0., 0., 0., 0., 3., 0.]], 52 | 53 | [[ 0., 0., 0., 0., -1., 0., 0., 0.], 54 | [ 0., 0., 0., 0., 0., -5., 0., 0.], 55 | [ 0., 0., 0., 0., 0., 0., 0., 0.], 56 | [ 0., 0., 0., 0., 0., 0., 0., 0.], 57 | [ 0., 0., 0., 0., 0., 0., 3., 0.]]]], 58 | 59 | 60 | 61 | [[[[ 0., 0., 0., 0., 5., 0., 0., 0.], 62 | [ 0., 0., 7., 0., 0., 0., 0., 0.], 63 | [ 0., 0., 
0., 0., 0., 0., 0., 0.], 64 | [ 0., 0., 0., 0., 0., 0., 0., 0.], 65 | [ 0., 0., 2., 0., 0., 0., 0., 0.]], 66 | 67 | [[ 0., 0., 0., 0., 5., 0., 0., 0.], 68 | [ 0., 0., -7., 0., 0., 0., 0., 0.], 69 | [ 0., 0., 0., 0., 0., 0., 0., 0.], 70 | [ 0., 0., 0., 0., 0., 0., 0., 0.], 71 | [ 0., 0., 2., 0., 0., 0., 0., 0.]]], 72 | 73 | 74 | [[[ 0., 0., 0., 2., 0., 0., 0., 0.], 75 | [ 0., 0., 0., 0., 0., 0., 0., 0.], 76 | [ 0., 0., 0., 0., 0., 0., 0., 0.], 77 | [ 0., 0., 0., 0., 0., 0., 0., 0.], 78 | [ 0., 0., 12., 0., 0., 0., 0., 0.]], 79 | 80 | [[ 0., 0., 0., 2., 0., 0., 0., 0.], 81 | [ 0., 0., 0., 0., 0., 0., 0., 0.], 82 | [ 0., 0., 0., 0., 0., 0., 0., 0.], 83 | [ 0., 0., 0., 0., 0., 0., 0., 0.], 84 | [ 0., 0., -2., 0., 0., 0., 0., 0.]]]]]) 85 | ).all()) 86 | --------------------------------------------------------------------------------
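The binning performed by `polar_indices` in `spatial.py` can be easier to follow for a single point. The sketch below restates the same arithmetic in plain Python with no PyTorch dependency. The names `polar_bin` and `flat_index` are hypothetical helpers introduced only for illustration and are not part of the repository.

```python
import math


def polar_bin(x, y, nray, nring, inner_radius):
    """Return (ring, ray) for a position relative to the drone.

    Hypothetical scalar restatement of the arithmetic in polar_indices.
    """
    distance = math.sqrt(x * x + y * y)
    # Rings have width inner_radius; anything further out lands in the last ring.
    ring = int(math.floor(min(max(distance / inner_radius, 0.0), nring - 1)))
    # Shift atan2 from [-pi, pi] to [0, 2*pi] before dividing into nray sectors.
    angle = math.atan2(y, x) + math.pi
    # An angle of exactly 2*pi would give index nray, so clamp to nray - 1.
    ray = min(int(math.floor(angle / (2 * math.pi) * nray)), nray - 1)
    return ring, ray


def flat_index(ring, ray, nray):
    """Row-major cell index, matching distance_index * nray + angular_index."""
    return ring * nray + ray


if __name__ == "__main__":
    ring, ray = polar_bin(2.0, 1.0, nray=8, nring=5, inner_radius=1.0)
    print(ring, ray, flat_index(ring, ray, nray=8))  # prints: 2 4 20
```

The flat index is what `spatial_scatter` passes to `scatter_add`, so items falling into the same (ring, ray) cell are summed before the result is reshaped to `(nring, nray)`.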