├── LICENSE
├── README.md
├── incredibot-sct.py
├── load-train-mlpp.py
├── sc2env.py
├── test_model.py
└── trainppo.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Harrison

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SC2RL
Reinforcement Learning + StarCraft II

Trained model file: https://www.dropbox.com/s/k8bomuzapmxychm/models.zip?dl=0

Project associated with the following video: https://youtu.be/q59wap1ELQ4
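
## How the pieces fit

`sc2env.py` wraps the game in a gym interface: its `reset()` launches `incredibot-sct.py` as a subprocess, and the two processes exchange state, reward, and action through `state_rwd_action.pkl`. Training is then ordinary Stable-Baselines3 PPO. A minimal sketch of the training entrypoint (the save path is illustrative; `trainppo.py` has the full loop):

```python
from stable_baselines3 import PPO
from sc2env import Sc2Env

env = Sc2Env()                       # reset() will spawn the bot process
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)  # one training chunk, as in trainppo.py
model.save("models/example")         # illustrative path
```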
--------------------------------------------------------------------------------
/incredibot-sct.py:
--------------------------------------------------------------------------------
from sc2.bot_ai import BotAI  # parent class we inherit from
from sc2.data import Difficulty, Race  # difficulty for computer opponents, race for each of the 3 races
from sc2.main import run_game  # function that facilitates actually running the agents in games
from sc2.player import Bot, Computer  # wrappers for whether an agent is one of your bots or a "computer" player
from sc2 import maps  # maps method for loading maps to play in.
from sc2.ids.unit_typeid import UnitTypeId
import random
import cv2
import math
import numpy as np
import os
import sys
import pickle
import time


SAVE_REPLAY = True
if SAVE_REPLAY:
    os.makedirs("replays", exist_ok=True)  # cv2.imwrite fails silently if this directory is missing

total_steps = 10000
steps_for_pun = np.linspace(0, 1, total_steps)
step_punishment = ((np.exp(steps_for_pun**3)/10) - 0.1)*10  # slowly-ramping penalty curve over the first 10k steps

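# NOTE: step_punishment is computed above but never applied anywhere in
# on_step() below. It was presumably intended as a per-step penalty that grows
# the longer the game drags on; a hypothetical use (not part of the original
# logic) would be something like:
#     reward -= step_punishment[min(iteration, total_steps - 1)]
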
class IncrediBot(BotAI):  # inherits from BotAI (part of BurnySC2)
    async def on_step(self, iteration: int):  # on_step is called on every game step.
        # block until the environment has written an action for us:
        no_action = True
        while no_action:
            try:
                with open('state_rwd_action.pkl', 'rb') as f:
                    state_rwd_action = pickle.load(f)

                if state_rwd_action['action'] is None:
                    #print("No action yet")
                    no_action = True
                else:
                    #print("Action found")
                    no_action = False
            except Exception:
                pass

        await self.distribute_workers()  # put idle workers back to work

        action = state_rwd_action['action']
        '''
        0: expand (i.e. move to the next spot, or build up to 16 mineral workers + 3 per assimilator)
        1: build stargate (plus the gateway and cybernetics core it requires, near each nexus)
        2: build voidray (at any idle stargate)
        3: send scout (a random probe, at most once every 200 iterations)
        4: attack (nearby units, then nearby structures, then any known enemy, then the enemy base)
        5: voidray flee (back to base)
        '''

        # 0: expand
        if action == 0:
            try:
                found_something = False
                if self.supply_left < 4:
                    # build pylons.
                    if self.already_pending(UnitTypeId.PYLON) == 0:
                        if self.can_afford(UnitTypeId.PYLON):
                            await self.build(UnitTypeId.PYLON, near=random.choice(self.townhalls))
                            found_something = True

                if not found_something:
                    for nexus in self.townhalls:
                        # get the worker count for this nexus:
                        worker_count = len(self.workers.closer_than(10, nexus))
                        if worker_count < 22:  # 16 mineral workers + 3 per assimilator (x2)
                            if nexus.is_idle and self.can_afford(UnitTypeId.PROBE):
                                nexus.train(UnitTypeId.PROBE)
                                found_something = True

                        # have we built enough assimilators?
                        # find vespene geysers near this nexus:
                        for geyser in self.vespene_geyser.closer_than(10, nexus):
                            # build an assimilator if there isn't one already:
                            if not self.can_afford(UnitTypeId.ASSIMILATOR):
                                break
                            if not self.structures(UnitTypeId.ASSIMILATOR).closer_than(2.0, geyser).exists:
                                await self.build(UnitTypeId.ASSIMILATOR, geyser)
                                found_something = True

                if not found_something:
                    if self.already_pending(UnitTypeId.NEXUS) == 0 and self.can_afford(UnitTypeId.NEXUS):
                        await self.expand_now()

            except Exception as e:
                print(e)

        # 1: build stargate (plus the tech it requires)
        elif action == 1:
            try:
                # iterate through all nexuses and see if these buildings are close:
                for nexus in self.townhalls:
                    # if there is not a gateway close:
                    if not self.structures(UnitTypeId.GATEWAY).closer_than(10, nexus).exists:
                        # if we can afford it:
                        if self.can_afford(UnitTypeId.GATEWAY) and self.already_pending(UnitTypeId.GATEWAY) == 0:
                            # build gateway:
                            await self.build(UnitTypeId.GATEWAY, near=nexus)

                    # if there is not a cybernetics core close:
                    if not self.structures(UnitTypeId.CYBERNETICSCORE).closer_than(10, nexus).exists:
                        # if we can afford it:
                        if self.can_afford(UnitTypeId.CYBERNETICSCORE) and self.already_pending(UnitTypeId.CYBERNETICSCORE) == 0:
                            # build cybernetics core:
                            await self.build(UnitTypeId.CYBERNETICSCORE, near=nexus)

                    # if there is not a stargate close:
                    if not self.structures(UnitTypeId.STARGATE).closer_than(10, nexus).exists:
                        # if we can afford it:
                        if self.can_afford(UnitTypeId.STARGATE) and self.already_pending(UnitTypeId.STARGATE) == 0:
                            # build stargate:
                            await self.build(UnitTypeId.STARGATE, near=nexus)

            except Exception as e:
                print(e)

        # 2: build voidray (at any idle, ready stargate)
        elif action == 2:
            try:
                if self.can_afford(UnitTypeId.VOIDRAY):
                    for sg in self.structures(UnitTypeId.STARGATE).ready.idle:
                        if self.can_afford(UnitTypeId.VOIDRAY):
                            sg.train(UnitTypeId.VOIDRAY)

            except Exception as e:
                print(e)

        # 3: send scout
        elif action == 3:
            # initialize self.last_sent if it doesn't exist yet:
            try:
                self.last_sent
            except AttributeError:
                self.last_sent = 0

            # only send a scout at most once every 200 iterations:
            if (iteration - self.last_sent) > 200:
                try:
                    # prefer an idle probe; otherwise pick any probe:
                    if self.units(UnitTypeId.PROBE).idle.exists:
                        probe = random.choice(self.units(UnitTypeId.PROBE).idle)
                    else:
                        probe = random.choice(self.units(UnitTypeId.PROBE))
                    # send the probe towards the enemy base:
                    probe.attack(self.enemy_start_locations[0])
                    self.last_sent = iteration

                except Exception:
                    pass

        # 4: attack
        elif action == 4:
            try:
                # take all idle void rays and attack!
                for voidray in self.units(UnitTypeId.VOIDRAY).idle:
                    # enemy units in range:
                    if self.enemy_units.closer_than(10, voidray):
                        voidray.attack(random.choice(self.enemy_units.closer_than(10, voidray)))
                    # enemy structures in range:
                    elif self.enemy_structures.closer_than(10, voidray):
                        voidray.attack(random.choice(self.enemy_structures.closer_than(10, voidray)))
                    # any known enemy units:
                    elif self.enemy_units:
                        voidray.attack(random.choice(self.enemy_units))
                    # any known enemy structures:
                    elif self.enemy_structures:
                        voidray.attack(random.choice(self.enemy_structures))
                    # otherwise head for the enemy start location:
                    elif self.enemy_start_locations:
                        voidray.attack(self.enemy_start_locations[0])

            except Exception as e:
                print(e)

        # 5: voidray flee (back to base)
        elif action == 5:
            if self.units(UnitTypeId.VOIDRAY).amount > 0:
                for vr in self.units(UnitTypeId.VOIDRAY):
                    vr.attack(self.start_location)

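        # -------------------------------------------------------------------
        # Build the observation image: one pixel per entity, with colors in
        # BGR (OpenCV's channel order) and the pixel's brightness scaled by a
        # fraction (remaining minerals/gas, or health / max health), so that
        # intensity encodes "how much is left".
        # -------------------------------------------------------------------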
        map = np.zeros((self.game_info.map_size[0], self.game_info.map_size[1], 3), dtype=np.uint8)

        # draw the minerals:
        for mineral in self.mineral_field:
            pos = mineral.position
            c = [175, 255, 255]
            fraction = mineral.mineral_contents / 1800
            if mineral.is_visible:
                #print(mineral.mineral_contents)
                map[math.ceil(pos.y)][math.ceil(pos.x)] = [int(fraction*i) for i in c]
            else:
                map[math.ceil(pos.y)][math.ceil(pos.x)] = [20, 75, 50]

        # draw the enemy start location:
        for enemy_start_location in self.enemy_start_locations:
            pos = enemy_start_location
            c = [0, 0, 255]
            map[math.ceil(pos.y)][math.ceil(pos.x)] = c

        # draw the enemy units:
        for enemy_unit in self.enemy_units:
            pos = enemy_unit.position
            c = [100, 0, 255]
            # get unit health fraction:
            fraction = enemy_unit.health / enemy_unit.health_max if enemy_unit.health_max > 0 else 0.0001
            map[math.ceil(pos.y)][math.ceil(pos.x)] = [int(fraction*i) for i in c]

        # draw the enemy structures:
        for enemy_structure in self.enemy_structures:
            pos = enemy_structure.position
            c = [0, 100, 255]
            # get structure health fraction:
            fraction = enemy_structure.health / enemy_structure.health_max if enemy_structure.health_max > 0 else 0.0001
            map[math.ceil(pos.y)][math.ceil(pos.x)] = [int(fraction*i) for i in c]

        # draw our structures:
        for our_structure in self.structures:
            pos = our_structure.position
            # nexuses get their own color:
            if our_structure.type_id == UnitTypeId.NEXUS:
                c = [255, 255, 175]
            else:
                c = [0, 255, 175]
            # get structure health fraction:
            fraction = our_structure.health / our_structure.health_max if our_structure.health_max > 0 else 0.0001
            map[math.ceil(pos.y)][math.ceil(pos.x)] = [int(fraction*i) for i in c]

        # draw the vespene geysers:
        for vespene in self.vespene_geyser:
            # draw these after buildings, since assimilators are placed on top of them.
            # tried to find a way to denote that an assimilator was on top, but couldn't
            # come up with anything. Tried matching positions, but the positions aren't identical, e.g.:
            # vesp position: (50.5, 63.5)
            # bldg positions: [(64.369873046875, 58.982421875), (52.85693359375, 51.593505859375), ...]
            pos = vespene.position
            c = [255, 175, 255]
            fraction = vespene.vespene_contents / 2250

            if vespene.is_visible:
                map[math.ceil(pos.y)][math.ceil(pos.x)] = [int(fraction*i) for i in c]
            else:
                map[math.ceil(pos.y)][math.ceil(pos.x)] = [50, 20, 75]

        # draw our units:
        for our_unit in self.units:
            pos = our_unit.position
            # void rays get their own color:
            if our_unit.type_id == UnitTypeId.VOIDRAY:
                c = [255, 75, 75]
            else:
                c = [175, 255, 0]
            # get health fraction:
            fraction = our_unit.health / our_unit.health_max if our_unit.health_max > 0 else 0.0001
            map[math.ceil(pos.y)][math.ceil(pos.x)] = [int(fraction*i) for i in c]

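        # NOTE (not in the original logic): Sc2Env advertises an observation
        # space of (224, 224, 3), while `map` here has the game map's own
        # dimensions. If the chosen map is not 224x224, a pad/crop step along
        # the lines of this hypothetical sketch would be needed before the
        # state is written out:
        #     padded = np.zeros((224, 224, 3), dtype=np.uint8)
        #     h, w = min(map.shape[0], 224), min(map.shape[1], 224)
        #     padded[:h, :w] = map[:h, :w]
        #     map = padded
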
        # show the map with OpenCV, resized to be larger and flipped
        # vertically so it matches the in-game orientation:
        cv2.imshow('map', cv2.flip(cv2.resize(map, None, fx=4, fy=4, interpolation=cv2.INTER_NEAREST), 0))
        cv2.waitKey(1)

        if SAVE_REPLAY:
            # save the map image into the "replays" dir:
            cv2.imwrite(f"replays/{int(time.time())}-{iteration}.png", map)

        reward = 0

        try:
            attack_count = 0
            # iterate through our void rays:
            for voidray in self.units(UnitTypeId.VOIDRAY):
                # if the voidray is attacking and is in range of an enemy unit or structure:
                if voidray.is_attacking and voidray.target_in_range:
                    if self.enemy_units.closer_than(8, voidray) or self.enemy_structures.closer_than(8, voidray):
                        # reward += 0.005  # original was 0.005, decent results, but let's 3x it.
                        reward += 0.015
                        attack_count += 1

        except Exception as e:
            print("reward", e)
            reward = 0

        if iteration % 100 == 0:
            print(f"Iter: {iteration}. RWD: {reward}. VR: {self.units(UnitTypeId.VOIDRAY).amount}")

        # write the file back with an empty action, waiting for the next one:
        data = {"state": map, "reward": reward, "action": None, "done": False}

        with open('state_rwd_action.pkl', 'wb') as f:
            pickle.dump(data, f)

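# ---------------------------------------------------------------------------
# How this script talks to sc2env.Sc2Env: the two processes share the file
# state_rwd_action.pkl. The env writes a dict whose "action" key is an int;
# on_step() above blocks until it sees a non-None action, executes it, then
# writes the new state/reward back with "action" reset to None -- the signal
# Sc2Env.step() waits for before submitting the next action.
# ---------------------------------------------------------------------------
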
result = run_game(  # run_game is a function that runs the game.
    maps.get("2000AtmospheresAIE"),  # the map we are playing on
    [Bot(Race.Protoss, IncrediBot()),  # runs our coded bot, protoss race, and we pass our bot object
     Computer(Race.Zerg, Difficulty.Hard)],  # runs a pre-made computer agent, zerg race, with a hard difficulty.
    realtime=False,  # when set to True, the agent is limited in how long each step can take to process.
)

if str(result) == "Result.Victory":
    rwd = 500
else:
    rwd = -500

with open("results.txt", "a") as f:
    f.write(f"{result}\n")

# send a final terminal transition: blank state, +/-500 reward, done=True.
map = np.zeros((224, 224, 3), dtype=np.uint8)
observation = map
data = {"state": map, "reward": rwd, "action": None, "done": True}
with open('state_rwd_action.pkl', 'wb') as f:
    pickle.dump(data, f)

cv2.destroyAllWindows()
cv2.waitKey(1)
time.sleep(3)
sys.exit()
--------------------------------------------------------------------------------
/load-train-mlpp.py:
--------------------------------------------------------------------------------
# $ source ~/Desktop/sc2env/bin/activate

# loads a previously saved model and continues training. This works, so far.

from stable_baselines3 import PPO
import os
from sc2env import Sc2Env
import time
from wandb.integration.sb3 import WandbCallback
import wandb


LOAD_MODEL = "models/1647915989/1647915989.zip"
# Environment:
env = Sc2Env()

# load the model:
model = PPO.load(LOAD_MODEL, env=env)

model_name = f"{int(time.time())}"

models_dir = f"models/{model_name}/"
logdir = f"logs/{model_name}/"


conf_dict = {"Model": "load-v16s",
             "Machine": "Puget/Desktop/v18/2",
             "policy": "MlpPolicy",
             "model_save_name": model_name,
             "load_model": LOAD_MODEL}

run = wandb.init(
    project='SC2RLv6',
    entity="sentdex",
    config=conf_dict,
    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
    save_code=True,  # save source code
)


# further train:
TIMESTEPS = 10000
iters = 0
while True:
    print("On iteration: ", iters)
    iters += 1
    model.learn(total_timesteps=TIMESTEPS, reset_num_timesteps=False, tb_log_name="PPO")
    model.save(f"{models_dir}/{TIMESTEPS*iters}")
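
Before kicking off a long PPO run, the environment can be smoke-tested end to end with random actions. This snippet is not part of the repo; it is a minimal sketch, assuming StarCraft II and the `2000AtmospheresAIE` map are installed:

```python
from sc2env import Sc2Env

# drive one episode with random actions to verify the pickle handshake works
env = Sc2Env()
obs = env.reset()  # spawns incredibot-sct.py as a subprocess
done = False
while not done:
    action = env.action_space.sample()  # random action in {0..5}
    obs, reward, done, info = env.step(action)
print("episode finished, terminal reward:", reward)
```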
--------------------------------------------------------------------------------
/sc2env.py:
--------------------------------------------------------------------------------
import gym
from gym import spaces
import numpy as np
import subprocess
import pickle
import time
import os


class Sc2Env(gym.Env):
    """Custom environment that follows the gym interface."""
    def __init__(self):
        super(Sc2Env, self).__init__()
        # Define the action and observation space.
        # They must be gym.spaces objects.
        # Example when using discrete actions:
        self.action_space = spaces.Discrete(6)
        self.observation_space = spaces.Box(low=0, high=255,
                                            shape=(224, 224, 3), dtype=np.uint8)

    def step(self, action):
        wait_for_action = True
        # wait until the bot has consumed the previous action:
        while wait_for_action:
            #print("waiting for action")
            try:
                with open('state_rwd_action.pkl', 'rb') as f:
                    state_rwd_action = pickle.load(f)

                if state_rwd_action['action'] is not None:
                    #print("previous action not consumed yet")
                    wait_for_action = True
                else:
                    #print("bot needs an action")
                    wait_for_action = False
                    state_rwd_action['action'] = action
                    with open('state_rwd_action.pkl', 'wb') as f:
                        # now we've added the action:
                        pickle.dump(state_rwd_action, f)
            except Exception as e:
                #print(str(e))
                pass

        # wait for the new state to come back (map and reward; no new action yet):
        wait_for_state = True
        while wait_for_state:
            try:
                if os.path.getsize('state_rwd_action.pkl') > 0:
                    with open('state_rwd_action.pkl', 'rb') as f:
                        state_rwd_action = pickle.load(f)
                    if state_rwd_action['action'] is None:
                        #print("No state yet")
                        wait_for_state = True
                    else:
                        #print("Got state")
                        state = state_rwd_action['state']
                        reward = state_rwd_action['reward']
                        done = state_rwd_action['done']
                        wait_for_state = False

            except Exception as e:
                wait_for_state = True
                map = np.zeros((224, 224, 3), dtype=np.uint8)
                observation = map
                # if reading keeps failing, inject action 3 (scout) to unblock the bot:
                data = {"state": map, "reward": 0, "action": 3, "done": False}
                with open('state_rwd_action.pkl', 'wb') as f:
                    pickle.dump(data, f)

                state = map
                reward = 0
                done = False
                action = 3

        info = {}
        observation = state
        return observation, reward, done, info

    def reset(self):
        print("RESETTING ENVIRONMENT!!!!!!!!!!!!!")
        map = np.zeros((224, 224, 3), dtype=np.uint8)
        observation = map
        data = {"state": map, "reward": 0, "action": None, "done": False}  # empty action, waiting for the next one.
        with open('state_rwd_action.pkl', 'wb') as f:
            pickle.dump(data, f)

        # run incredibot-sct.py non-blocking:
        subprocess.Popen(['python3', 'incredibot-sct.py'])
        return observation  # reward, done, info can't be included
--------------------------------------------------------------------------------
/test_model.py:
--------------------------------------------------------------------------------
from stable_baselines3 import PPO
from sc2env import Sc2Env


LOAD_MODEL = "models/1647915989/2880000.zip"
# Environment:
env = Sc2Env()

# load the model:
model = PPO.load(LOAD_MODEL)


# Play the game until the episode ends:
obs = env.reset()
done = False
while not done:
    action, _states = model.predict(obs)
    obs, reward, done, info = env.step(action)
--------------------------------------------------------------------------------
/trainppo.py:
--------------------------------------------------------------------------------
from stable_baselines3 import PPO
import os
from sc2env import Sc2Env
import time
from wandb.integration.sb3 import WandbCallback
import wandb


model_name = f"{int(time.time())}"

models_dir = f"models/{model_name}/"
logdir = f"logs/{model_name}/"


conf_dict = {"Model": "v19",
             "Machine": "Main",
             "policy": "MlpPolicy",
             "model_save_name": model_name}


run = wandb.init(
    project='SC2RLv6',
    entity="sentdex",
    config=conf_dict,
    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
    save_code=True,  # optional
)


if not os.path.exists(models_dir):
    os.makedirs(models_dir)

if not os.path.exists(logdir):
    os.makedirs(logdir)

env = Sc2Env()

model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=logdir)

TIMESTEPS = 10000
iters = 0
while True:
    print("On iteration: ", iters)
    iters += 1
    model.learn(total_timesteps=TIMESTEPS, reset_num_timesteps=False, tb_log_name="PPO")
    model.save(f"{models_dir}/{TIMESTEPS*iters}")
--------------------------------------------------------------------------------
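
Both `trainppo.py` and `load-train-mlpp.py` import `WandbCallback` but never pass it to `model.learn()`, so only the TensorBoard sync reaches wandb. A sketch of wiring the callback in, using the `model`, `run`, and `TIMESTEPS` names from `trainppo.py` (the save path is illustrative):

```python
model.learn(
    total_timesteps=TIMESTEPS,
    reset_num_timesteps=False,
    tb_log_name="PPO",
    # checkpoint the model to wandb as training runs; path is illustrative
    callback=WandbCallback(model_save_path=f"models/{run.id}", verbose=2),
)
```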