├── models
│   └── .placeholder
├── src
│   ├── data
│   │   ├── __init__.py
│   │   ├── create_database.py
│   │   ├── riotapi.py
│   │   ├── match_pool.py
│   │   ├── champion_info.py
│   │   ├── database_ops.py
│   │   └── query_wiki.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── softmax.py
│   │   ├── inference_model.py
│   │   └── qNetwork.py
│   ├── run_tests.py
│   ├── tests
│   │   └── tests.py
│   ├── requirements.txt
│   ├── update_champions_data.py
│   ├── features
│   │   ├── rewards.py
│   │   ├── experience_replay.py
│   │   ├── draft.py
│   │   ├── match_processing.py
│   │   └── draftstate.py
│   ├── update_match_data.py
│   ├── main.py
│   ├── model_predictions.py
│   └── trainer.py
├── _config.yml
├── common
│   └── images
│       ├── nms_error.png
│       ├── val_outliers.png
│       ├── val_matches_1.png
│       ├── val_matches_2.png
│       ├── discount_factor.png
│       ├── reward_sched_term.png
│       ├── validation_matches.png
│       ├── league_draft_structure.png
│       └── reward_sched_non-term.png
├── data
│   ├── competitiveMatchData.db
│   ├── match_sources.json
│   ├── test_train_split.txt
│   └── patch_info.json
├── LICENSE
└── README.md
/models/.placeholder:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/data/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-leap-day
--------------------------------------------------------------------------------
/src/run_tests.py:
--------------------------------------------------------------------------------
1 | from tests.tests import run
2 | 
3 | run()
4 | 
--------------------------------------------------------------------------------
/common/images/nms_error.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lightd22/swainBot/HEAD/common/images/nms_error.png
--------------------------------------------------------------------------------
/src/tests/tests.py:
--------------------------------------------------------------------------------
1 | from features.draftstate import DraftState
2 | 
3 | def run():
4 |     print("HELLO")
5 | 
--------------------------------------------------------------------------------
/common/images/val_outliers.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lightd22/swainBot/HEAD/common/images/val_outliers.png
--------------------------------------------------------------------------------
/data/competitiveMatchData.db:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lightd22/swainBot/HEAD/data/competitiveMatchData.db
--------------------------------------------------------------------------------
/common/images/val_matches_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lightd22/swainBot/HEAD/common/images/val_matches_1.png
--------------------------------------------------------------------------------
/common/images/val_matches_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lightd22/swainBot/HEAD/common/images/val_matches_2.png -------------------------------------------------------------------------------- /common/images/discount_factor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lightd22/swainBot/HEAD/common/images/discount_factor.png -------------------------------------------------------------------------------- /common/images/reward_sched_term.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lightd22/swainBot/HEAD/common/images/reward_sched_term.png -------------------------------------------------------------------------------- /common/images/validation_matches.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lightd22/swainBot/HEAD/common/images/validation_matches.png -------------------------------------------------------------------------------- /common/images/league_draft_structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lightd22/swainBot/HEAD/common/images/league_draft_structure.png -------------------------------------------------------------------------------- /common/images/reward_sched_non-term.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lightd22/swainBot/HEAD/common/images/reward_sched_non-term.png -------------------------------------------------------------------------------- /src/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.13.3 2 | matplotlib==1.5.3 3 | tensorflow==1.15.2 4 | requests==2.11.1 5 | pandas==0.18.1 6 | luigi==2.8.0 7 | -------------------------------------------------------------------------------- /data/match_sources.json: -------------------------------------------------------------------------------- 1 | { 2 | "match_sources": 3 | { 4 | "tournaments":[], 5 | "patches":["8.16","8.17","8.18","8.19"] 6 | } 7 | } 8 | -------------------------------------------------------------------------------- /src/models/base_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | class BaseModel(): 3 | def __init__(self, name, path): 4 | self._name = name 5 | self._path_to_model = path 6 | self._graph = tf.Graph() 7 | self.sess = tf.Session(graph=self._graph) 8 | 9 | def __del__(self): 10 | try: 11 | self.sess.close() 12 | del self.sess 13 | finally: 14 | print("Model closed..") 15 | 16 | def build_model(self): 17 | raise NotImplementedError 18 | def init_saver(self): 19 | raise NotImplementedError 20 | def save(self): 21 | raise NotImplementedError 22 | def load(self): 23 | raise NotImplementedError 24 | -------------------------------------------------------------------------------- /data/test_train_split.txt: -------------------------------------------------------------------------------- 1 | {"validation_ids": [1111, 238, 2300, 1104, 243, 481, 1422, 220, 1420, 2317, 2318, 219, 2122, 240, 1418, 221, 2334, 2333, 211, 1129], "training_ids": [2129, 1113, 490, 2315, 1097, 465, 1425, 225, 217, 214, 227, 2108, 244, 1419, 476, 1101, 467, 1122, 471, 1416, 469, 1102, 246, 489, 463, 2331, 2330, 1417, 2123, 1125, 2113, 475, 458, 2304, 1437, 229, 2313, 1128, 2115, 1433, 224, 485, 2307, 2325, 216, 488, 241, 2320, 218, 
209, 233, 2332, 1430, 1108, 242, 1134, 1095, 487, 1436, 461, 479, 486, 2116, 466, 477, 1424, 2323, 2302, 1423, 462, 2306, 2327, 1127, 1098, 2321, 2111, 478, 234, 2125, 236, 472, 239, 228, 212, 1130, 1432, 468, 1116, 484, 1123, 1117, 2326, 245, 2109, 470, 2127, 2119, 1434, 473, 1124, 2118, 1121, 1106, 1426, 474, 235, 459, 2128, 480, 2124, 2305, 226, 2121, 215, 2117, 464, 232, 1112, 1120, 1110, 1100, 2324, 1431, 2114, 2314, 2303, 2311, 1114, 1119, 2120, 237, 2110, 2319, 2126, 1135, 1131, 230, 1132, 2329, 2309, 1099, 1096, 482, 1435, 460, 1126, 483, 231, 1109, 1133, 1103, 1118, 1429, 1115, 1107, 213, 223, 1428, 1427, 2312, 2328, 2107, 222, 2322, 2308, 210, 2112, 2316, 1105, 491, 2301, 2310, 1421]} -------------------------------------------------------------------------------- /src/update_champions_data.py: -------------------------------------------------------------------------------- 1 | import luigi 2 | import requests 3 | import json 4 | import time 5 | import os 6 | 7 | class ChampionsDownload(luigi.ExternalTask): 8 | champions_path = luigi.Parameter(default="champions.json") 9 | def output(self): 10 | # output is temporary file to force task to check if patch data matches 11 | return luigi.LocalTarget("tmp/pipeline/champions{}.json".format(time.time())) 12 | 13 | def run(self): 14 | url = "https://ddragon.leagueoflegends.com/api/versions.json" 15 | response = requests.get(url=url).json() 16 | current_patch = response[0] 17 | # Check for local file patch version 18 | try: 19 | with open(self.champions_path, 'r') as infile: 20 | data = json.load(infile) 21 | local_patch = data["version"] 22 | except: 23 | local_patch = None 24 | 25 | # update local file if patches do not match (uses temporary file) 26 | if local_patch != current_patch: 27 | print("Local patch does not match current patch.. Updating") 28 | url = "http://ddragon.leagueoflegends.com/cdn/{current_patch}/data/en_US/champion.json".format(current_patch=current_patch) 29 | response = requests.get(url=url).json() 30 | tmp_file = self.output().path 31 | with open(tmp_file, 'w') as outfile: 32 | json.dump(response, outfile) 33 | os.rename(tmp_file, self.champions_path) 34 | else: 35 | print("Local patch matches current patch.. Skipping") 36 | return 37 | 38 | if __name__ == "__main__": 39 | luigi.run(['ChampionsDownload', "--local-scheduler"]) 40 | -------------------------------------------------------------------------------- /src/data/create_database.py: -------------------------------------------------------------------------------- 1 | import sqlite3 2 | from sqlite3 import Error 3 | import json 4 | from .query_wiki import query_wiki 5 | import re 6 | from . import database_ops as dbo 7 | 8 | def table_col_info(cursor, tableName, printOut=False): 9 | """ 10 | Returns a list of tuples with column informations: 11 | (id, name, type, notnull, default_value, primary_key) 12 | """ 13 | cursor.execute('PRAGMA TABLE_INFO({})'.format(tableName)) 14 | info = cursor.fetchall() 15 | 16 | if printOut: 17 | print("Column Info:\nID, Name, Type, NotNull, DefaultVal, PrimaryKey") 18 | for col in info: 19 | print(col) 20 | return info 21 | 22 | def create_tables(cursor, tableNames, columnInfo, clobber = False): 23 | """ 24 | create_tables attempts to create a table for each table in the list tableNames with 25 | columns as defined by columnInfo. For each if table = tableNames[k] then the columns for 26 | table are defined by columns = columnInfo[k]. 
Note that each element in columnInfo must
27 |     be a list of strings of the form column[j] = "jth_column_name jth_column_data_type"
28 | 
29 |     Args:
30 |         cursor (sqlite cursor): cursor used to execute commands
31 |         tableNames (list(string)): string labels for tableNames
32 |         columnInfo (list(list(string))): list of string labels for each column of each table
33 |         clobber (bool): flag to determine if old tables should be overwritten
34 | 
35 |     Returns:
36 |         status (int): 0 if table creation failed, 1 if table creation was successful
37 |     """
38 | 
39 |     for (table, colInfo) in zip(tableNames, columnInfo):
40 |         columnInfoString = ", ".join(colInfo)
41 |         try:
42 |             if clobber:
43 |                 cursor.execute("DROP TABLE IF EXISTS {tableName}".format(tableName=table))
44 |             cursor.execute("CREATE TABLE {tableName} ({columnInfo})".format(tableName=table,columnInfo=columnInfoString))
45 |         except Error as e:
46 |             print(e)
47 |             print("Table {} already exists! Here's its column info:".format(table))
48 |             table_col_info(cursor, table, True)
49 |             print("***")
50 |             return 0
51 | 
52 |     return 1
53 | 
--------------------------------------------------------------------------------
/src/features/rewards.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .draftstate import DraftState as ds
3 | 
4 | def get_reward(state, match, submitted_action, actual_action):
5 |     """
6 |     Args:
7 |         state (DraftState): Present state of the draft to be checked for reward
8 |         match (dict): record of match which identifies winning team
9 |         submitted_action (tuple(int)): id of action submitted by model
10 |         actual_action (tuple(int)): id of action submitted in observation
11 |     Returns:
12 |         reward (float): value of the reward earned for the draft state.
13 | 
14 |     get_reward takes a draft state and returns the immediate reward for reaching that state. The reward is determined by a simple reward table:
15 |     1) state is invalid -> reward = -10
16 |     2) state is complete, valid, and seen from the winning team's perspective -> reward = +5
17 |     3) state is complete, valid, but seen from the losing team's perspective -> reward = +2.5
18 |     4) state is valid, but incomplete -> reward = 0. In cases 2-4 the reward is further adjusted by +0.5 if the submitted action matches the observed action and by -0.5 otherwise.
19 |     """
20 |     status = state.evaluate()
21 |     if(status in ds.invalid_states):
22 |         return -10.
23 | 
24 |     reward = 0.
25 |     winner = get_winning_team(match)
26 |     if(status == ds.DRAFT_COMPLETE and winner is not None):
27 |         if(state.team == winner):
28 |             reward += 5.
29 |         else:
30 |             reward += 2.5
31 | 
32 |     if(submitted_action == actual_action):
33 |         reward += 0.5
34 |     else:
35 |         reward += -0.5
36 | 
37 |     return reward
38 | 
39 | def get_winning_team(match):
40 |     """
41 |     Args:
42 |         match (dict): match dictionary with pick and ban data for a single game.
43 |     Returns:
44 |         val (int): Integer representing which team won the match.
45 |             val = DraftState.RED_TEAM if the red team won
46 |             val = DraftState.BLUE_TEAM if the blue team won
47 |             val = None if the match does not have data for the winning team
48 | 
49 |     get_winning_team returns the winning team of the input match, encoded as an integer according to DraftState.
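    Matches encode the winner as an integer: match["winner"] == 0 denotes a blue-side win and match["winner"] == 1 a red-side win.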
50 | """ 51 | if match["winner"]==0: 52 | return ds.BLUE_TEAM 53 | elif match["winner"]==1: 54 | return ds.RED_TEAM 55 | return None 56 | -------------------------------------------------------------------------------- /src/features/experience_replay.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | class ExperienceBuffer(): 4 | """ 5 | ExperienceBuffer is a class for storing and adding experiences to be sampled from when batch learning a Qnetwork. An experience is defined as a tuple of the form 6 | (s, a, r, s') where 7 | s = input state 8 | a = action taken from state s 9 | r = reward obtained for taking action a 10 | s' = ending state after taking action a 11 | Args: 12 | max_buffer_size (int): maximum number of experiences to store in the buffer, default value is 300. 13 | """ 14 | def __init__(self, max_buffer_size = 300): 15 | self.buffer = [] 16 | self.buffer_size = max_buffer_size 17 | self.oldest_experience = 0 18 | 19 | def store(self, experiences): 20 | """ 21 | ExperienceBuffer.store stores the input list of experience tuples into the buffer. The expereince is stored in one of two ways: 22 | 1) If the buffer has space remaining, the experience is appended to the end 23 | 2) If the buffer is full, the input experience replaces the oldest experience in the buffer 24 | 25 | Args: 26 | experiences ( list(tuple) ): each experience is a tuple of the form (s, a, r, s') 27 | Returns: 28 | None 29 | """ 30 | for experience in experiences: 31 | if len(self.buffer) < self.buffer_size: 32 | self.buffer.append(experience) 33 | else: 34 | self.buffer[self.oldest_experience] = experience 35 | self.oldest_experience += 1 36 | self.oldest_experience = self.oldest_experience % self.buffer_size 37 | return None 38 | 39 | def sample(self, sample_size): 40 | """ 41 | ExperienceBuffer.sample samples the current buffer using random.sample to return a collection of sample_size experiences from the replay buffer. 42 | random.sample samples without replacement, so sample_size must be no larger than the length of the current buffer. 43 | 44 | Args: 45 | sample_size (int): number of samples to take from buffer 46 | Returns: 47 | sample (list(tuples)): list of experience replay samples. len(sample) = sample_size. 48 | """ 49 | return random.sample(self.buffer,sample_size) 50 | 51 | def get_buffer_size(self): 52 | """ 53 | Returns length of the buffer. 54 | """ 55 | return len(self.buffer) 56 | -------------------------------------------------------------------------------- /src/data/riotapi.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from . import myRiotApiKey 3 | import time 4 | api_versions = { 5 | "staticdata": "v3", 6 | "datadragon": "7.15.1" 7 | } 8 | valid_methods = ["GET", "PUT", "POST"] 9 | region = "na1" 10 | def set_api_key(key): 11 | """ 12 | Set calling user's API Key 13 | 14 | Args: 15 | key (string): the Riot API key desired for use. 16 | """ 17 | myRiotApiKey.api_key = key 18 | return key 19 | 20 | def set_region(reg): 21 | """ 22 | Set region to run API queries through 23 | 24 | Args: 25 | reg (string): region through which we are sending API requests 26 | """ 27 | reg = reg.lower() 28 | regions = ["br1", "eun1", "euw1", "jp1", "kr", "la1", "la2", "na1", "oce1", "tr1", "ru"] 29 | 30 | assert (reg in regions), "Invalid region!" 
31 | region = reg 32 | return region 33 | 34 | def make_request(request, method, params={}): 35 | """ 36 | Makes a rate-limited HTTP request to Riot API and returns the response data 37 | """ 38 | url = "https://{region}.api.riotgames.com/lol/{request}".format(region=region,request=request) 39 | try: 40 | response = execute_request(url, method, params) 41 | if(not response.ok): 42 | response.raise_for_status() 43 | return response.json() 44 | except requests.exceptions.HTTPError as e: 45 | # Wait and try again on 429 (rate limit exceeded) 46 | if response.status_code == 429: 47 | if "X-Rate-Limit-Type" not in e.headers or e.headers["X-Rate-Limit-Type"] == "service": 48 | # Wait 1 second before retrying 49 | time.sleep(1) 50 | else: 51 | retry_after = 1 52 | if response.headers["Retry-After"]: 53 | retry_after += int(e.headers["Retry-After"]) 54 | 55 | time.sleep(retry_after) 56 | return make_request(request, method, params) 57 | else: 58 | raise 59 | 60 | def execute_request(url, method, params={}): 61 | """ 62 | Executes HTTP request using requests library and returns response object. 63 | Args: 64 | url (str): full url string to request 65 | method (str): HTTP method to use (one of "GET", "PUT", or "POST") 66 | params (dict): dictionary of parameters to send along url 67 | 68 | Returns: 69 | response object returned by requests 70 | """ 71 | 72 | response = None 73 | assert(method in valid_methods), "[execute_request] Invalid HTTP method!" 74 | if(method == "GET"): 75 | response = requests.get(url=url, params=params) 76 | return response 77 | -------------------------------------------------------------------------------- /data/patch_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "patch_info": 3 | { 4 | "2018": 5 | { 6 | "NA_LCS": 7 | { 8 | "Spring_Season":["8.1","8.1","8.2","8.2","8.3","8.3","8.4","8.4","8.5"], 9 | "Spring_Playoffs":["8.5"], 10 | "Summer_Season":["8.11","8.12","8.12","8.13","8.13","8.14","8.14","8.15","8.15"], 11 | "Summer_Playoffs":["8.16"], 12 | "Regional_Finals":["8.16"] 13 | }, 14 | "NA_ACA": 15 | { 16 | "Spring_Season":["8.1","8.1","8.2","8.2","8.3","8.3","8.4","8.4","8.5"], 17 | "Spring_Playoffs":["8.5"], 18 | "Summer_Season":["8.11","8.12","8.13","8.13","8.13","8.14","8.14","8.15"], 19 | "Summer_Playoffs":["8.16"] 20 | }, 21 | "EU_LCS": 22 | { 23 | "Spring_Season":["8.1","8.1","8.2","8.2","8.3","8.3","8.4","8.4","8.5"], 24 | "Spring_Playoffs":["8.5"], 25 | "Summer_Season":["8.11","8.12","8.12","8.13","8.13","8.14","8.14","8.15","8.15"], 26 | "Summer_Playoffs":["8.16"], 27 | "Regional_Finals":["8.16"] 28 | }, 29 | "LCK": 30 | { 31 | "Spring_Season":["8.1","8.1","8.1","8.2","8.3","8.3","8.4","8.4","8.5"], 32 | "Spring_Playoffs":["8.5"], 33 | "Summer_Season":["8.11","8.11","8.12","8.13","8.13","8.13","8.14","8.14"], 34 | "Summer_Playoffs":["8.15"], 35 | "Regional_Finals":["8.15"] 36 | }, 37 | "KR_CHAL": 38 | { 39 | "Spring_Season":["8.1","8.1","8.1","8.2","8.2","8.3","8.4","8.4","8.5"], 40 | "Spring_Playoffs":["8.6"], 41 | "Summer_Season":["8.11","8.11","8.12","8.13","8.13","8.13","8.14","8.14","8.15","8.15","8.15"], 42 | "Summer_Playoffs":["8.15"] 43 | }, 44 | "LPL": 45 | { 46 | "Spring_Season":["8.1","8.1","8.1","8.1","8.2","8.2","8.4","8.4","8.5"], 47 | "Spring_Playoffs":["8.5"], 48 | "Summer_Season":["8.11","8.11","8.11","8.13","8.13","8.13","8.14","8.14","8.15","8.15","8.15"], 49 | "Summer_Playoffs":["8.16"], 50 | "Regional_Finals":["8.16"] 51 | }, 52 | "LDL": 53 | { 54 | "Spring_Playoffs":["8.8"], 55 | 
"Grand_Finals":["8.16"] 56 | }, 57 | "LMS": 58 | { 59 | "Spring_Season":["8.1","8.1","8.1","8.2","8.3","8.3","8.4","8.4","8.5","8.5"], 60 | "Spring_Playoffs":["8.6"], 61 | "Summer_Season":["8.11","8.12","8.13","8.13","8.13","8.14","8.14","8.15"], 62 | "Summer_Playoffs":["8.17"], 63 | "Regional_Finals":["8.17"] 64 | }, 65 | "MSI": 66 | { 67 | "Play-In":["8.8"], 68 | "Main_Event":["8.8"] 69 | }, 70 | "WORLDS": 71 | { 72 | "Play-In":["8.19"], 73 | "Main_Event":["8.19"] 74 | } 75 | } 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /src/features/draft.py: -------------------------------------------------------------------------------- 1 | class Draft(object): 2 | BLUE_TEAM = 0 3 | RED_TEAM = 1 4 | BAN = 201 5 | PICK = 202 6 | PHASES = [BAN, PICK] 7 | 8 | # Draft specifcations 9 | # Default specification for drafting structure 10 | default_draft = [ 11 | (BLUE_TEAM, BAN), 12 | (RED_TEAM, BAN), 13 | (BLUE_TEAM, BAN), 14 | (RED_TEAM, BAN), 15 | (BLUE_TEAM, BAN), 16 | (RED_TEAM, BAN), 17 | 18 | (BLUE_TEAM, PICK), 19 | (RED_TEAM, PICK), 20 | (RED_TEAM, PICK), 21 | (BLUE_TEAM, PICK), 22 | (BLUE_TEAM, PICK), 23 | (RED_TEAM, PICK), 24 | 25 | (RED_TEAM, BAN), 26 | (BLUE_TEAM, BAN), 27 | (RED_TEAM, BAN), 28 | (BLUE_TEAM, BAN), 29 | 30 | (RED_TEAM, PICK), 31 | (BLUE_TEAM, PICK), 32 | (BLUE_TEAM, PICK), 33 | (RED_TEAM, PICK), 34 | ] 35 | 36 | no_bans = [ 37 | (BLUE_TEAM, PICK), 38 | (RED_TEAM, PICK), 39 | (RED_TEAM, PICK), 40 | (BLUE_TEAM, PICK), 41 | (BLUE_TEAM, PICK), 42 | (RED_TEAM, PICK), 43 | 44 | (RED_TEAM, PICK), 45 | (BLUE_TEAM, PICK), 46 | (BLUE_TEAM, PICK), 47 | (RED_TEAM, PICK), 48 | ] 49 | # Dictionary mapping draft labels to draft structures 50 | draft_structures = {'default': default_draft, 51 | 'no_bans': no_bans, 52 | } 53 | 54 | def __init__(self, draft_type = 'default'): 55 | self._draft_structure = None 56 | try: 57 | self._draft_structure = Draft.draft_structures[draft_type] 58 | except KeyError: 59 | print("In draft.py: Draft structure not defined") 60 | raise 61 | 62 | self.PHASE_LENGTHS = {} 63 | for phase in Draft.PHASES: 64 | self.PHASE_LENGTHS[phase] = [] 65 | 66 | phase_length = 0 67 | current_phase = None 68 | for (team, phase) in self._draft_structure: 69 | if not current_phase: 70 | current_phase = phase 71 | if phase == current_phase: 72 | phase_length += 1 73 | else: 74 | self.PHASE_LENGTHS[current_phase].append(phase_length) 75 | current_phase = phase 76 | phase_length = 1 77 | self.PHASE_LENGTHS[current_phase].append(phase_length) # don't forget last phase 78 | self.NUM_BANS = sum(self.PHASE_LENGTHS[Draft.BAN]) # Total number of bans in draft 79 | self.NUM_PICKS = sum(self.PHASE_LENGTHS[Draft.PICK]) # Total number of picks in draft 80 | 81 | # submission_dist[k] gives tuple of counts for pick types just before kth submission is made (last element will hold final submission distribution for draft) 82 | self.submission_dist = [(0,0,0)] 83 | for (team, phase) in self._draft_structure: 84 | (cur_ban, cur_blue, cur_red) = self.submission_dist[-1] 85 | if phase == Draft.BAN: 86 | next_dist = (cur_ban+1, cur_blue, cur_red) 87 | elif team == Draft.BLUE_TEAM: 88 | next_dist = (cur_ban, cur_blue+1, cur_red) 89 | else: 90 | next_dist = (cur_ban, cur_blue, cur_red+1) 91 | self.submission_dist += [next_dist] 92 | 93 | def get_active_team(self, submission_count): 94 | """ 95 | Gets the active team in the draft based on the number of submissions currently present 96 | Args: 97 | submission_count (int): number of submissions currently 
submitted to draft
98 |         Returns:
99 |             Draft.BLUE_TEAM if blue is active, Draft.RED_TEAM if red is active, or None if the draft is complete
100 |         """
101 |         if submission_count > len(self._draft_structure):
102 |             raise ValueError("submission_count exceeds draft length")
103 |         elif submission_count == len(self._draft_structure):
104 |             return None
105 |         (team, sub_type) = self._draft_structure[submission_count]
106 |         return team
107 | 
108 |     def get_active_phase(self, submission_count):
109 |         """
110 |         Returns the phase identifier for the current phase of the draft based on the number of submissions made.
111 |         Args:
112 |             submission_count (int): number of submissions currently submitted to draft
113 |         Returns:
114 |             Draft.BAN if the draft is in a banning phase, Draft.PICK if it is in a picking phase, or None if the draft is complete
115 |         """
116 |         if submission_count > len(self._draft_structure):
117 |             raise ValueError("submission_count exceeds draft length")
118 |         elif submission_count == len(self._draft_structure):
119 |             return None
120 |         (team, sub_type) = self._draft_structure[submission_count]
121 |         return sub_type
122 | 
123 | if __name__ == "__main__":
124 |     draft = Draft("default")
125 |     [print(thing) for thing in draft.submission_dist]
126 | 
--------------------------------------------------------------------------------
/src/models/softmax.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | 
4 | from . import base_model
5 | 
6 | class SoftmaxNetwork(base_model.BaseModel):
7 |     """
8 |     Args:
9 |         input_shape (tuple): shape of a single input state presented to the network.
10 |         output_shape (int): number of output nodes for the network.
11 |         filter_sizes (tuple of 2 ints): number of filters in each of the two hidden layers. Defaults to (512,512).
12 |         learning_rate (float): network's willingness to change current weights given a new example
13 |         regularization_coeff (float): strength of the weight regularization term in the loss function
14 | 
15 |     A simple softmax network class responsible for holding and updating the weights and biases used in predicting actions for a given state. This network consists of
16 |     the following layers:
17 |     1) Input- a DraftState state s (an array of bool) representing the current state reshaped into an [n_batch, *input_shape] tensor.
18 |     2-3) Two relu-activated, fully connected hidden layers
19 |     4) Output- softmax-obtained probability of action submission for the output_shape actions available.
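    Invalid actions are masked by assigning their logits a value of -inf before the softmax, so they are predicted with zero probability (see the valid_actions placeholder in build_model).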
20 | 21 | """ 22 | @property 23 | def name(self): 24 | return self._name 25 | 26 | def __init__(self, name, path, input_shape, output_shape, filter_sizes = (512,512), learning_rate=1.e-3, regularization_coeff = 0.01): 27 | super().__init__(name=name, path=path) 28 | self._input_shape = input_shape 29 | self._output_shape = output_shape 30 | self._learning_rate = learning_rate 31 | self._regularization_coeff = regularization_coeff 32 | self._n_hidden_layers = len(filter_sizes) 33 | self._n_layers = self._n_hidden_layers + 2 34 | self._filter_sizes = filter_sizes 35 | 36 | self.ops_dict = self.build_model(name=self._name) 37 | with self._graph.as_default(): 38 | self.ops_dict["init"] = tf.global_variables_initializer() 39 | 40 | self.init_saver() 41 | 42 | def init_saver(self): 43 | with self._graph.as_default(): 44 | self.saver = tf.train.Saver() 45 | 46 | def save(self, path): 47 | self.saver.save(self.sess, save_path=path) 48 | 49 | def load(self, path): 50 | self.saver.restore(self.sess, save_path=path) 51 | 52 | def build_model(self, name): 53 | ops_dict = {} 54 | with self._graph.as_default(): 55 | with tf.variable_scope(name): 56 | ops_dict["learning_rate"] = tf.Variable(self._learning_rate, trainable=False, name="learning_rate") 57 | 58 | # Incoming state matrices are of size input_size = (nChampions, nPos+2) 59 | # 'None' here means the input tensor will flex with the number of training 60 | # examples (aka batch size). 61 | ops_dict["input"] = tf.placeholder(tf.float32, (None,)+self._input_shape, name="inputs") 62 | ops_dict["dropout_keep_prob"] = tf.placeholder_with_default(1.0,shape=()) 63 | 64 | # Fully connected (FC) layers: 65 | fc0 = tf.layers.dense( 66 | ops_dict["input"], 67 | self._filter_sizes[0], 68 | activation=tf.nn.relu, 69 | bias_initializer=tf.constant_initializer(0.1), 70 | name="fc_0") 71 | dropout0 = tf.nn.dropout(fc0, ops_dict["dropout_keep_prob"]) 72 | 73 | fc1 = tf.layers.dense( 74 | dropout0, 75 | self._filter_sizes[1], 76 | activation=tf.nn.relu, 77 | bias_initializer=tf.constant_initializer(0.1), 78 | name="fc_1") 79 | dropout1 = tf.nn.dropout(fc1, ops_dict["dropout_keep_prob"]) 80 | 81 | # Logits layer 82 | ops_dict["logits"] = tf.layers.dense( 83 | dropout1, 84 | self._output_shape, 85 | activation=None, 86 | bias_initializer=tf.constant_initializer(0.1), 87 | kernel_regularizer=tf.contrib.layers.l2_regularizer(scale=self._regularization_coeff), 88 | name="logits") 89 | 90 | # Placeholder for valid actions filter 91 | ops_dict["valid_actions"] = tf.placeholder(tf.bool, shape=ops_dict["logits"].shape, name="valid_actions") 92 | 93 | # Filtered logits 94 | ops_dict["valid_logits"] = tf.where(ops_dict["valid_actions"], ops_dict["logits"], tf.scalar_mul(-np.inf, tf.ones_like(ops_dict["logits"])), name="valid_logits") 95 | 96 | # Predicted optimal action amongst valid actions 97 | ops_dict["probabilities"] = tf.nn.softmax(ops_dict["valid_logits"], name="action_probabilites") 98 | ops_dict["prediction"] = tf.argmax(input=ops_dict["valid_logits"], axis=1, name="predictions") 99 | 100 | ops_dict["actions"] = tf.placeholder(tf.int32, shape=[None], name="submitted_actions") 101 | 102 | ops_dict["loss"] = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=ops_dict["actions"], logits=ops_dict["valid_logits"]), name="loss") 103 | 104 | ops_dict["trainer"] = tf.train.AdamOptimizer(learning_rate = ops_dict["learning_rate"]) 105 | ops_dict["update"] = ops_dict["trainer"].minimize(ops_dict["loss"], name="update") 106 | 107 | return ops_dict 108 | 
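A minimal usage sketch for SoftmaxNetwork, assuming it is run from the src directory (as main.py is); the checkpoint path, the state/action dimensions, and the dummy batch below are placeholders chosen only for illustration and are not part of the repository sources:

import numpy as np
from models import softmax

# Placeholder dimensions: a flattened draft-state vector and one logit per available action.
net = softmax.SoftmaxNetwork(name="softmax", path="../models/softmax_demo.ckpt",
                             input_shape=(1000,), output_shape=500)
net.sess.run(net.ops_dict["init"])  # initialize weights before training

# One supervised training step on a dummy batch of 16 examples (all actions marked valid).
feed = {net.ops_dict["input"]: np.zeros((16, 1000), dtype=np.float32),
        net.ops_dict["valid_actions"]: np.ones((16, 500), dtype=bool),
        net.ops_dict["actions"]: np.zeros(16, dtype=np.int32)}
loss, _ = net.sess.run([net.ops_dict["loss"], net.ops_dict["update"]], feed_dict=feed)

# Action probabilities for a single state; invalid actions would receive zero probability.
probs = net.sess.run(net.ops_dict["probabilities"],
                     feed_dict={net.ops_dict["input"]: np.zeros((1, 1000), dtype=np.float32),
                                net.ops_dict["valid_actions"]: np.ones((1, 500), dtype=bool)})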
-------------------------------------------------------------------------------- /src/models/inference_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from . import base_model 3 | 4 | class QNetInferenceModel(base_model.BaseModel): 5 | def __init__(self, name, path): 6 | super().__init__(name=name, path=path) 7 | self.init_saver() 8 | self.ops_dict = self.build_model() 9 | 10 | def init_saver(self): 11 | with self._graph.as_default(): 12 | self.saver = tf.train.import_meta_graph("{path}.ckpt.meta".format(path=self._path_to_model)) 13 | self.saver.restore(self.sess,"{path}.ckpt".format(path=self._path_to_model)) 14 | def build_model(self): 15 | ops_dict = {} 16 | with self._graph.as_default(): 17 | ops_dict["predict_q"] = tf.get_default_graph().get_tensor_by_name("online/valid_q_vals:0") 18 | ops_dict["prediction"] = tf.get_default_graph().get_tensor_by_name("online/prediction:0") 19 | ops_dict["input"] = tf.get_default_graph().get_tensor_by_name("online/inputs:0") 20 | ops_dict["valid_actions"] = tf.get_default_graph().get_tensor_by_name("online/valid_actions:0") 21 | return ops_dict 22 | 23 | def predict(self, states): 24 | """ 25 | Feeds state into model and returns current predicted Q-values. 26 | Args: 27 | states (list of DraftStates): states to predict from 28 | Returns: 29 | predicted_Q (numpy array): model estimates of Q-values for actions from input states. 30 | predicted_Q[k,:] holds Q-values for state states[k] 31 | """ 32 | inputs = [state.format_state() for state in states] 33 | valid_actions = [state.get_valid_actions() for state in states] 34 | 35 | feed_dict = {self.ops_dict["input"]:inputs, 36 | self.ops_dict["valid_actions"]:valid_actions} 37 | predicted_Q = self.sess.run(self.ops_dict["predict_q"], feed_dict=feed_dict) 38 | return predicted_Q 39 | 40 | def predict_action(self, states): 41 | """ 42 | Feeds state into model and return recommended action to take from input state based on estimated Q-values. 43 | Args: 44 | state (list of DraftStates): states to predict from 45 | Returns: 46 | predicted_action (numpy array): array of integer representations of actions recommended by model. 
47 |         """
48 |         inputs = [state.format_state() for state in states]
49 |         valid_actions = [state.get_valid_actions() for state in states]
50 | 
51 |         feed_dict = {self.ops_dict["input"]:inputs,
52 |                      self.ops_dict["valid_actions"]:valid_actions}
53 |         predicted_actions = self.sess.run(self.ops_dict["prediction"], feed_dict=feed_dict)
54 |         return predicted_actions
55 | 
56 | class SoftmaxInferenceModel(base_model.BaseModel):
57 |     def __init__(self, name, path):
58 |         super().__init__(name=name, path=path)
59 |         self.init_saver()
60 |         self.ops_dict = self.build_model()
61 | 
62 |     def init_saver(self):
63 |         with self._graph.as_default():
64 |             self.saver = tf.train.import_meta_graph("{path}.ckpt.meta".format(path=self._path_to_model))
65 |             self.saver.restore(self.sess,"{path}.ckpt".format(path=self._path_to_model))
66 | 
67 |     def build_model(self):
68 |         ops_dict = {}
69 |         with self._graph.as_default():
70 |             ops_dict["probabilities"] = tf.get_default_graph().get_tensor_by_name("softmax/action_probabilites:0")
71 |             ops_dict["prediction"] = tf.get_default_graph().get_tensor_by_name("softmax/predictions:0")
72 |             ops_dict["input"] = tf.get_default_graph().get_tensor_by_name("softmax/inputs:0")
73 |             ops_dict["valid_actions"] = tf.get_default_graph().get_tensor_by_name("softmax/valid_actions:0")
74 |         return ops_dict
75 | 
76 |     def predict(self, states):
77 |         """
78 |         Feeds states into the model and returns the current predicted probabilities.
79 |         Args:
80 |             states (list of DraftStates): states to predict from
81 |         Returns:
82 |             probabilities (numpy array): model estimates of probabilities for actions from input states.
83 |                 probabilities[k,:] holds action probabilities for state states[k]
84 |         """
85 |         inputs = [state.format_state() for state in states]
86 |         valid_actions = [state.get_valid_actions() for state in states]
87 | 
88 |         feed_dict = {self.ops_dict["input"]:inputs,
89 |                      self.ops_dict["valid_actions"]:valid_actions}
90 |         probabilities = self.sess.run(self.ops_dict["probabilities"], feed_dict=feed_dict)
91 |         return probabilities
92 | 
93 |     def predict_action(self, states):
94 |         """
95 |         Feeds states into the model and returns the recommended action to take from each input state based on the estimated action probabilities.
96 |         Args:
97 |             states (list of DraftStates): states to predict from
98 |         Returns:
99 |             predicted_action (numpy array): array of integer representations of actions recommended by the model.
100 | """ 101 | inputs = [state.format_state() for state in states] 102 | valid_actions = [state.get_valid_actions() for state in states] 103 | 104 | feed_dict = {self.ops_dict["input"]:inputs, 105 | self.ops_dict["valid_actions"]:valid_actions} 106 | predicted_actions = self.sess.run(self.ops_dict["prediction"], feed_dict=feed_dict) 107 | return predicted_actions 108 | -------------------------------------------------------------------------------- /src/update_match_data.py: -------------------------------------------------------------------------------- 1 | import luigi 2 | import requests 3 | import json 4 | import time 5 | import sqlite3 6 | from data.create_database import create_tables 7 | import data.database_ops as dbo 8 | from data.query_wiki import query_wiki 9 | 10 | class CreateMatchDB(luigi.Task): 11 | path_to_db = luigi.Parameter(default="../data/competitiveMatchData.db") 12 | 13 | def output(self): 14 | return luigi.LocalTarget(self.path_to_db) 15 | 16 | def run(self): 17 | tableNames = ["game", "pick", "ban", "team"] 18 | 19 | columnInfo = [] 20 | # Game table columns 21 | columnInfo.append(["id INTEGER PRIMARY KEY", 22 | "tournament TEXT","tourn_game_id INTEGER", "week INTEGER", "patch TEXT", 23 | "blue_teamid INTEGER NOT NULL", "red_teamid INTEGER NOT NULL", 24 | "winning_team INTEGER"]) 25 | # Pick table columns 26 | columnInfo.append(["id INTEGER PRIMARY KEY", 27 | "game_id INTEGER", "champion_id INTEGER","position_id INTEGER", 28 | "selection_order INTEGER", "side_id INTEGER"]) 29 | # Ban table columns 30 | columnInfo.append(["id INTEGER PRIMARY KEY", 31 | "game_id INTEGER", "champion_id INTEGER", "selection_order INTEGER", "side_id INTEGER"]) 32 | # Team table columns 33 | columnInfo.append(["id INTEGER PRIMARY KEY", 34 | "region TEXT", "display_name TEXT"]) 35 | 36 | conn = sqlite3.connect(self.path_to_db) 37 | cur = conn.cursor() 38 | print("Creating tables..") 39 | _ = create_tables(cur, tableNames, columnInfo, clobber = True) 40 | conn.close() 41 | 42 | return 1 43 | 44 | def validate_match_data(match_data): 45 | """ 46 | validate_match_data performs basic match data validation by examining the following: 47 | 1. Number of picks/bans present in data 48 | 2. Presence of dupicate picks/bans 49 | 3. Duplicate roles on a single side 50 | 51 | Args: 52 | match_data (dict): dictionary of formatted match data 53 | 54 | Returns: 55 | bool: True if match_data passes validation checks, False otherwise. 56 | """ 57 | NUM_BANS = 10 58 | NUM_PICKS = 10 59 | 60 | is_valid = True 61 | bans = match_data["bans"]["blue"] + match_data["bans"]["red"] 62 | picks = match_data["picks"]["blue"] + match_data["picks"]["red"] 63 | if(len(bans) != NUM_BANS or len(picks)!= NUM_PICKS): 64 | print("Incorrect number of picks and/or bans found! 
{} picks, {} bans".format(len(picks), len(bans))) 65 | is_valid = False 66 | 67 | # Need to consider edge case where teams fail to submit multiple bans (rare, but possible) 68 | champs = [ban for ban in bans if ban != "none"] + [p for (p,_) in picks] 69 | if len(set(champs)) != len(champs): 70 | print("Duplicate submission(s) encountered.") 71 | counts = {} 72 | for champ in champs: 73 | if champ not in counts: 74 | counts[champ] = 1 75 | else: 76 | counts[champ] += 1 77 | print(sorted([(value, key) for (key, value) in counts.items() if value>1])) 78 | is_valid = False 79 | 80 | for side in ["blue", "red"]: 81 | if len(set([pos for (_,pos) in match_data["picks"][side]])) != len(match_data["picks"][side]): 82 | print("Duplicate position on side {} found.".format(side)) 83 | is_valid = False 84 | 85 | return is_valid 86 | 87 | if __name__ == "__main__": 88 | path_to_db = "../data/competitiveMatchData.db" 89 | luigi.run( 90 | cmdline_args=["--path-to-db={}".format(path_to_db)], 91 | main_task_cls=CreateMatchDB, 92 | local_scheduler=True) 93 | 94 | conn = sqlite3.connect(path_to_db) 95 | cur = conn.cursor() 96 | 97 | # deleted_match_ids = [770] 98 | # dbo.delete_game_from_table(cur, game_ids = deleted_match_ids, table_name="pick") 99 | # dbo.delete_game_from_table(cur, game_ids = deleted_match_ids, table_name="ban") 100 | 101 | year = "2018" 102 | schedule = [] 103 | regions = ["EU_LCS","NA_LCS","LPL","LMS","LCK"]; tournaments = ["Spring_Season", "Spring_Playoffs", "Summer_Season", "Summer_Playoffs", "Regional_Finals"] 104 | schedule.append((regions,tournaments)) 105 | regions = ["NA_ACA","KR_CHAL"]; tournaments = ["Spring_Season", "Spring_Playoffs", "Summer_Season", "Summer_Playoffs"] 106 | schedule.append((regions,tournaments)) 107 | regions = ["LDL"]; tournaments = ["Spring_Playoffs", "Grand_Finals"] 108 | schedule.append((regions,tournaments)) 109 | 110 | NUM_BANS = 10 111 | NUM_PICKS = 10 112 | for regions, tournaments in schedule: 113 | for region in regions: 114 | for tournament in tournaments: 115 | skip_commit = False 116 | print("Querying: {}".format(year+"/"+region+"/"+tournament)) 117 | gameData = query_wiki(year, region, tournament) 118 | print("Found {} games.".format(len(gameData))) 119 | for i,game in enumerate(gameData): 120 | is_valid = validate_match_data(game) 121 | if not is_valid: 122 | skip_commit = True 123 | print("Errors in match: h_id {} tourn_g_id {}: {} vs {}".format(game["header_id"], game["tourn_game_id"], game["blue_team"], game["red_team"])) 124 | 125 | if(not skip_commit): 126 | print("Attempting to insert {} games..".format(len(gameData))) 127 | status = dbo.insert_team(cur,gameData) 128 | status = dbo.insert_game(cur,gameData) 129 | status = dbo.insert_ban(cur,gameData) 130 | status = dbo.insert_pick(cur,gameData) 131 | print("Committing changes to db..") 132 | conn.commit() 133 | else: 134 | print("Errors found in match data.. 
skipping commit") 135 | raise 136 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import json 4 | import sqlite3 5 | import matplotlib.pyplot as plt 6 | import time 7 | 8 | from features.draftstate import DraftState 9 | import data.champion_info as cinfo 10 | import features.match_processing as mp 11 | from data.match_pool import test_train_split 12 | import data.database_ops as dbo 13 | 14 | from models import qNetwork, softmax 15 | from trainer import DDQNTrainer, SoftmaxTrainer 16 | from models.inference_model import QNetInferenceModel, SoftmaxInferenceModel 17 | 18 | import tensorflow as tf 19 | 20 | print("") 21 | print("********************************") 22 | print("** Beginning Swain Bot Run! **") 23 | print("********************************") 24 | 25 | valid_champ_ids = cinfo.get_champion_ids() 26 | print("Number of valid championIds: {}".format(len(valid_champ_ids))) 27 | 28 | LIST_PATH = None#"../data/test_train_split.txt" 29 | LIST_SAVE_PATH = "../data/test_train_split.txt" 30 | PATH_TO_DB = "../data/competitiveMatchData.db" 31 | MODEL_DIR = "../models/" 32 | N_TRAIN = 173 33 | N_VAL = 20 34 | PATCHES = None 35 | PRUNE_PATCHES = None 36 | result = test_train_split(N_TRAIN, N_VAL, PATH_TO_DB, LIST_PATH, LIST_SAVE_PATH) 37 | 38 | validation_ids = result["validation_ids"] 39 | training_ids = result["training_ids"] 40 | print("Found {} training matches and {} validation matches in pool.".format(len(training_ids), len(validation_ids))) 41 | 42 | validation_matches = dbo.get_matches_by_id(validation_ids, PATH_TO_DB) 43 | 44 | print("***") 45 | print("Displaying Validation matches:") 46 | count = 0 47 | for match in validation_matches: 48 | count += 1 49 | print("Match: {:2} id: {:4} {:6} vs {:6} winner: {:2}".format(count, match["id"], match["blue_team"], match["red_team"], match["winner"])) 50 | for team in ["blue", "red"]: 51 | bans = match[team]["bans"] 52 | picks = match[team]["picks"] 53 | pretty_bans = [] 54 | pretty_picks = [] 55 | for ban in bans: 56 | pretty_bans.append(cinfo.champion_name_from_id(ban[0])) 57 | for pick in picks: 58 | pretty_picks.append((cinfo.champion_name_from_id(pick[0]), pick[1])) 59 | print("{} bans:{}".format(team, pretty_bans)) 60 | print("{} picks:{}".format(team, pretty_picks)) 61 | print("") 62 | print("***") 63 | 64 | # Network parameters 65 | state = DraftState(DraftState.BLUE_TEAM, valid_champ_ids) 66 | input_size = state.format_state().shape 67 | output_size = state.num_actions 68 | filter_size = (1024,1024) 69 | regularization_coeff = 7.5e-5#1.5e-4 70 | path_to_model = None#"model_predictions/spring_2018/week_3/model_E{}.ckpt".format(30)#None 71 | load_path = None#"tmp/ddqn_model_E45.ckpt" 72 | 73 | # Training parameters 74 | batch_size = 16#32 75 | buffer_size = 4096#2048 76 | n_epoch = 45 77 | discount_factor = 0.9 78 | learning_rate = 1.0e-4#2.0e-5# 79 | time.sleep(2.) 80 | for i in range(1): 81 | training_matches = dbo.get_matches_by_id(training_ids, PATH_TO_DB) 82 | print("Learning on {} matches for {} epochs. 
lr {:.4e} reg {:4e}".format(len(training_matches),n_epoch, learning_rate, regularization_coeff),flush=True) 83 | break 84 | 85 | tf.reset_default_graph() 86 | name = "softmax" 87 | out_path = "{}{}_model_E{}.ckpt".format(MODEL_DIR, name, n_epoch) 88 | softnet = softmax.SoftmaxNetwork(name, out_path, input_size, output_size, filter_size, learning_rate, regularization_coeff) 89 | trainer = SoftmaxTrainer(softnet, n_epoch, training_matches, validation_matches, batch_size, load_path=None) 90 | summaries = trainer.train() 91 | 92 | tf.reset_default_graph() 93 | name = "ddqn" 94 | out_path = "{}{}_model_E{}.ckpt".format(MODEL_DIR, name, n_epoch) 95 | ddqn = qNetwork.Qnetwork(name, out_path, input_size, output_size, filter_size, learning_rate, regularization_coeff, discount_factor) 96 | trainer = DDQNTrainer(ddqn, n_epoch, training_matches, validation_matches, batch_size, buffer_size, load_path) 97 | summaries = trainer.train() 98 | 99 | print("Learning complete!") 100 | print("..final training accuracy: {:.4f}".format(summaries["train_acc"][-1])) 101 | x = [i+1 for i in range(len(summaries["loss"]))] 102 | fig = plt.figure() 103 | plt.plot(x,summaries["loss"]) 104 | plt.ylabel('loss') 105 | plt.xlabel('epoch') 106 | #plt.ylim([0,2]) 107 | fig_name = "tmp/loss_figures/annuled_rate/loss_E{}_run_{}.pdf".format(n_epoch,i+1) 108 | print("Loss figure saved in:{}".format(fig_name),flush=True) 109 | fig.savefig(fig_name) 110 | 111 | fig = plt.figure() 112 | plt.plot(x, summaries["train_acc"], x, summaries["val_acc"]) 113 | fig_name = "tmp/acc_figs/acc_E{}_run_{}.pdf".format(n_epoch,i+1) 114 | print("Accuracy figure saved in:{}".format(fig_name),flush=True) 115 | fig.savefig(fig_name) 116 | 117 | 118 | # Look at predicted Q values for states in a randomly drawn match 119 | match = random.sample(training_matches,1)[0] 120 | team = DraftState.RED_TEAM if match["winner"]==1 else DraftState.BLUE_TEAM 121 | experiences = mp.process_match(match,team) 122 | count = 0 123 | # x labels for q val plots 124 | xticks = [] 125 | xtick_locs = [] 126 | for a in range(state.num_actions): 127 | cid,pos = state.format_action(a) 128 | if cid not in xticks: 129 | xticks.append(cid) 130 | xtick_locs.append(a) 131 | xtick_labels = [cinfo.champion_name_from_id(cid)[:6] for cid in xticks] 132 | 133 | tf.reset_default_graph() 134 | #path_to_model = "../models/ddqn_model_E{}".format(45)#"tmp/ddqn_model_E45"#"tmp/model_E{}".format(n_epoch) 135 | #model = QNetInferenceModel(name="infer", path=path_to_model) 136 | path_to_model = "../models/softmax_model_E{}".format(45)#"tmp/ddqn_model_E45"#"tmp/model_E{}".format(n_epoch) 137 | model = SoftmaxInferenceModel(name="infer", path=path_to_model) 138 | 139 | for exp in experiences: 140 | state,act,rew,next_state = exp 141 | cid,pos = act 142 | if cid == None: 143 | continue 144 | count += 1 145 | form_act = state.get_action(cid,pos) 146 | pred_act = model.predict_action([state]) 147 | pred_act = pred_act[0] 148 | pred_Q = model.predict([state]) 149 | pred_Q = pred_Q[0,:] 150 | 151 | p_cid,p_pos = state.format_action(pred_act) 152 | actual = (cinfo.champion_name_from_id(cid),pos,pred_Q[form_act]) 153 | pred = (cinfo.champion_name_from_id(p_cid),p_pos,pred_Q[pred_act]) 154 | print("pred:{}, actual:{}".format(pred,actual)) 155 | 156 | # Plot Q-val figure 157 | fig = plt.figure(figsize=(25,5)) 158 | plt.ylabel('$Q(s,a)$') 159 | plt.xlabel('$a$') 160 | plt.xticks(xtick_locs, xtick_labels, rotation=70) 161 | plt.tick_params(axis='x',which='both',labelsize=6) 162 | x = np.arange(len(pred_Q)) 163 
| plt.bar(x,pred_Q, align='center',alpha=0.8,color='b') 164 | plt.bar(pred_act, pred_Q[pred_act],align='center',color='r') 165 | plt.bar(form_act, pred_Q[form_act],align='center',color='g') 166 | 167 | fig_name = "tmp/qval_figs/{}.pdf".format(count) 168 | fig.savefig(fig_name) 169 | 170 | print("") 171 | print("********************************") 172 | print("** Ending Swain Bot Run! **") 173 | print("********************************") 174 | -------------------------------------------------------------------------------- /src/data/match_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import json 3 | import sqlite3 4 | from .database_ops import get_matches_by_id, get_game_ids, get_match_data, get_game_ids_by_tournament, get_tournament_data 5 | 6 | def test_train_split(n_training, n_validation, path_to_db, list_path=None, save_path=None, match_sources=None, prune_patches=None): 7 | """ 8 | test_train_split returns a match_ids split into two nominal groups: a training set and a test set. 9 | Args: 10 | n_training (int): number of match_ids in training split 11 | n_test (int): number of match_ids in test_train_split 12 | path_to_db (str): path to database containing match data 13 | list_path (str, optional): path to existing match ids to either grow or pruned 14 | save_path (str, optional): path to save split match ids 15 | match_sources (dict, optional): dictionary containing "patches" and "tournaments" keys containing lists of tournament and patch identifiers to use for match sources 16 | prune_patches (list, optional): list of patches to prune from existing split match ids 17 | 18 | Returns: 19 | Dictionary {"training_ids":list(int),"validation_ids":list(int)} 20 | """ 21 | save_match_pool = False 22 | validation_ids = [] 23 | training_ids = [] 24 | if list_path: 25 | print("Building list off of match data in {}.".format(list_path)) 26 | with open(list_path,'r') as infile: 27 | data = json.load(infile) 28 | validation_ids = data["validation_ids"] 29 | training_ids = data["training_ids"] 30 | if prune_patches: 31 | pre_prune_id_count = len(validation_ids)+len(training_ids) 32 | validation_ids = prune_match_list(validation_ids, path=path_to_db, patches=prune_patches) 33 | training_ids = prune_match_list(training_ids, path=path_to_db, patches=prune_patches) 34 | post_prune_id_count = len(validation_ids)+len(training_ids) 35 | save_match_pool = True 36 | print("Pruned {} matches from the match list".format(pre_prune_id_count-post_prune_id_count)) 37 | 38 | val_diff = max([n_validation - len(validation_ids),0]) 39 | train_diff = max([n_training - len(training_ids),0]) 40 | 41 | current = [] 42 | current.extend(validation_ids) 43 | current.extend(training_ids) 44 | 45 | count = val_diff + train_diff 46 | if(count > 0): 47 | new_matches = grow_pool(count, current, path_to_db, match_sources) 48 | if(val_diff>0): 49 | print("Insufficient number of validation matches. Attempting to add difference..") 50 | validation_ids.extend(new_matches[:val_diff]) 51 | if(train_diff>0): 52 | print("Insufficient number of training matches. 
Attempting to add difference..") 53 | training_ids.extend(new_matches[val_diff:count]) 54 | print("Successfully added {} matches to validation and {} matches to training.".format(val_diff, train_diff)) 55 | save_match_pool = True 56 | 57 | if(save_match_pool and save_path): 58 | print("Saving pool to {}..".format(save_path)) 59 | with open(save_path,'w') as outfile: 60 | json.dump({"training_ids":training_ids,"validation_ids":validation_ids},outfile) 61 | 62 | return {"training_ids":training_ids,"validation_ids":validation_ids} 63 | 64 | def grow_pool(count, current_pool, path_to_db, match_sources=None): 65 | total = match_pool(0, path_to_db, randomize=False, match_sources=match_sources)["match_ids"] 66 | new = list(set(total)-set(current_pool)) 67 | assert(len(new) >= count), "Not enough new matches to match required count! avail: {} needed: {}".format(len(new), count) 68 | random.shuffle(new) 69 | 70 | return new[:count] 71 | 72 | def prune_match_list(match_ids, path_to_db, patches=None): 73 | """ 74 | Prunes match list by removing matches played on specified patches. 75 | """ 76 | matches = get_matches_by_id(match_ids, path_to_db) 77 | pruned_match_list = [] 78 | for match in matches: 79 | patch = match["patch"] 80 | if patch not in patches: 81 | pruned_match_list.append(match["id"]) 82 | return pruned_match_list 83 | 84 | def match_pool(num_matches, path_to_db, randomize=True, match_sources=None): 85 | """ 86 | Args: 87 | num_matches (int): Number of matches to include in the queue (0 indicates to use the maximum number of matches available) 88 | path_do_db (str): Path to match database to query against 89 | randomize (bool): Flag for randomizing order of output matches. 90 | match_sources (dict(string)): Dict containing "tournaments" and "patches" keys to use when building pool, if None, defaults to using patches/tournaments in data/match_sources.json 91 | Returns: 92 | match_data (dictionary): dictionary containing two keys: 93 | "match_ids": list of match_ids for pooled matches 94 | "matches": list of pooled match data to process 95 | 96 | Builds a set of matchids and match data used during learning phase. If randomize flag is set 97 | to false this returns the first num_matches in order according to match_sources. 98 | """ 99 | if(match_sources is None): 100 | with open("../data/match_sources.json") as infile: 101 | data = json.load(infile) 102 | patches = data["match_sources"]["patches"] 103 | tournaments = data["match_sources"]["tournaments"] 104 | else: 105 | patches = match_sources["patches"] 106 | tournaments = match_sources["tournaments"] 107 | 108 | # If patches or tournaments is empty, grab matches from all patches from specified tournaments or all tournaments from specified matches 109 | if not patches: 110 | patches = [None] 111 | if not tournaments: 112 | tournaments = [None] 113 | 114 | match_pool = [] 115 | conn = sqlite3.connect(path_to_db) 116 | cur = conn.cursor() 117 | # Build list of eligible match ids 118 | for patch in patches: 119 | for tournament in tournaments: 120 | game_ids = get_game_ids(cur, tournament, patch) 121 | match_pool.extend(game_ids) 122 | 123 | print("Number of available matches for training={}".format(len(match_pool))) 124 | if(num_matches == 0): 125 | num_matches = len(match_pool) 126 | assert num_matches <= len(match_pool), "Not enough matches found to sample!" 
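    # With randomize=False the first num_matches ids are returned in source order; otherwise a uniform random sample of the pool is drawn.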
127 | if(randomize): 128 | selected_match_ids = random.sample(match_pool, num_matches) 129 | else: 130 | selected_match_ids = match_pool[:num_matches] 131 | 132 | selected_matches = [] 133 | for match_id in selected_match_ids: 134 | match = get_match_data(cur, match_id) 135 | selected_matches.append(match) 136 | conn.close() 137 | return {"match_ids":selected_match_ids, "matches":selected_matches} 138 | 139 | if __name__ == "__main__": 140 | match_sources = {"patches":[], "tournaments": ["2018/NA/Summer_Season"]} 141 | path_to_db = "../data/competitiveMatchData.db" 142 | out_path = "../data/test_train_split.txt" 143 | res = test_train_split(20, 20, path_to_db, list_path=out_path, save_path=out_path, match_sources=match_sources, prune_patches=None) 144 | print(res) 145 | -------------------------------------------------------------------------------- /src/data/champion_info.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | #from cassiopeia import riotapi 3 | from .riotapi import make_request, api_versions 4 | import requests 5 | import re 6 | import json 7 | from . import myRiotApiKey 8 | 9 | # Box is a vacant class with no initial members. This will be used to hold the champion_id list and champion_id <-> name dictionaries. 10 | 11 | class Box: 12 | pass 13 | __m = Box() 14 | __m.champion_name_from_id = None 15 | __m.champion_id_from_name = None 16 | __m.valid_champion_ids = None 17 | __m.championAliases = { 18 | "blitz": "blitzcrank", 19 | "gp": "gangplank", 20 | "jarvan": "jarvaniv", 21 | "cait": "caitlyn", 22 | "lb": "leblanc", 23 | "cass": "cassiopeia", 24 | "casiopeia": "cassiopeia", 25 | "ori": "orianna", 26 | "lee": "leesin", 27 | "vlad": "vladimir", 28 | "j4": "jarvaniv", 29 | "as": "aurelionsol", # who the fuck thinks this is unique? 30 | "kass": "kassadin", 31 | "tk": "tahmkench", 32 | "malz": "malzahar", 33 | "sej": "sejuani", 34 | "nid": "nidalee", 35 | "aurelion": "aurelionsol", 36 | "mundo": "drmundo", 37 | "tahm": "tahmkench", 38 | "kayne": "kayn", 39 | "zil": "zilean", 40 | "naut": "nautilus", 41 | "morg": "morgana", 42 | "ez": "ezreal", 43 | "nunu": "nunuwillump", 44 | "yi": "masteryi", 45 | "cho":"chogath", 46 | "mao":"maokai", 47 | "morde":"mordekaiser", 48 | "xin":"xinzhao", 49 | "eve":"evelynn", 50 | "sol":"aurelionsol", 51 | "fid":"fiddlesticks", 52 | "fiddle":"fiddlesticks", 53 | "heimer":"heimerdinger", 54 | "noc":"nocturne", 55 | "tf":"twistedfate", 56 | "kog":"kogmaw" 57 | } 58 | 59 | # This is a flag to make champion_info methods look for data stored locally 60 | # rather than query the API. Useful if API is out. 61 | look_local = True 62 | LOCAL_CHAMPION_PATH = "../data/champions.json" 63 | 64 | class Champion(): 65 | def __init__(self,dictionary): 66 | if(look_local): 67 | # Local file is a cached query to data dragon 68 | # Data dragon reverses the meaning of keys and ids from the API. 
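            # For example, a cached Data Dragon entry looks roughly like {"id": "Aatrox", "key": "266"},
            # while the static-data API returns the numeric value as "id" and the name string as "key" (illustrative values).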
69 | self.key = dictionary["id"] 70 | self.id = int(dictionary["key"]) 71 | else: 72 | self.key = dictionary["key"] 73 | self.id = int(dictionary["id"]) 74 | self.name = dictionary["name"] 75 | self.title = dictionary["title"] 76 | 77 | class AliasException(Exception): 78 | def __init__(self, message, errors): 79 | super().__init__(message) 80 | self.errors = errors 81 | self.message = message 82 | 83 | def convert_champion_alias(alias): 84 | """ 85 | Args: 86 | alias (string): lowercase and pruned string alias for a champion 87 | Returns: 88 | name (string): lowercase and pruned string name for champion 89 | 90 | convert_champion_alias converts a given champion alias (ie "blitz") 91 | and returns the version of that champions proper name which is suitable for passing to 92 | champion_id_from_name(). If no such alias can be found, this raises an AliasException. 93 | 94 | Example: name = convert_champion_alias("blitz") will yield name = "blitzcrank" 95 | """ 96 | null_champion = ["none","lossofban"] 97 | if (alias in null_champion): 98 | return "none" 99 | try: 100 | if (alias in __m.championAliases): 101 | return __m.championAliases[alias] 102 | else: 103 | raise AliasException("Champion alias not found!", alias) 104 | except AliasException as e: 105 | print("*****") 106 | print(e.message) 107 | print("Offending alias: {}".format(e.errors)) 108 | print("*****") 109 | raise 110 | 111 | def champion_name_from_id(champion_id): 112 | """ 113 | Args: 114 | champion_id (int): Integer Id corresponding to the desired champion name. 115 | Returns: 116 | name (string): String name of requested champion. If no such champion can be found, returns NULL 117 | 118 | champion_name_from_id takes a requested champion_id number and returns the string name of that champion using a champion_name_from_id dictionary. 119 | If the dictonary has not yet been populated, this creates the dictionary using cassiopeia's interface to Riot's API. 120 | """ 121 | if __m.champion_name_from_id is None: 122 | populate_champion_dictionary() 123 | 124 | if (champion_id in __m.champion_name_from_id): 125 | return __m.champion_name_from_id[champion_id] 126 | return None 127 | 128 | def champion_id_from_name(champion_name): 129 | """ 130 | Args: 131 | champion_name (string): lowercase and pruned string label corresponding to the desired champion id. 132 | Returns: 133 | id (int): id of requested champion. If no such champion can be found, returns NULL 134 | 135 | champion_id_from_name takes a requested champion name and returns the id label of that champion using a champion_id_from_name dictionary. 136 | If the dictonary has not yet been populated, this creates the dictionary using cassiopeia's interface to Riot's API. 137 | Note that champion_name should be all lowercase and have any non-alphanumeric characters (including whitespace) removed. 138 | """ 139 | if __m.champion_id_from_name is None: 140 | populate_champion_dictionary() 141 | 142 | if (champion_name in __m.champion_id_from_name): 143 | return __m.champion_id_from_name[champion_name] 144 | return None 145 | 146 | def valid_champion_id(champion_id): 147 | """ 148 | Checks to see if champion_id corresponds to a valid champion id code. 149 | Returns: True if champion_id is valid. False otherwise. 150 | Args: 151 | champion_id (int): Id of champion to be verified. 
152 | """ 153 | if __m.champion_name_from_id is None: 154 | populate_champion_dictionary() 155 | 156 | return champion_id in __m.valid_champion_ids 157 | 158 | def get_champion_ids(): 159 | """ 160 | Returns a sorted list of valid champion IDs. 161 | Args: 162 | None 163 | Returns: 164 | validIds (list(ints)): sorted list of valid champion IDs. 165 | """ 166 | if __m.valid_champion_ids is None: 167 | populate_champion_dictionary() 168 | 169 | return __m.valid_champion_ids[:] 170 | 171 | def populate_champion_dictionary(): 172 | """ 173 | Args: 174 | None 175 | Returns: 176 | True if succesful, False otherwise 177 | Populates the module dictionary whose keys are champion Ids and values are strings of the corresponding champion's name. 178 | """ 179 | #riotapi.set_region("NA") 180 | #riotapi.set_api_key(myRiotApiKey.api_key) 181 | #champions = riotapi.get_champions() 182 | DISABLED_CHAMPIONS = [] 183 | if(look_local): 184 | with open(LOCAL_CHAMPION_PATH, 'r') as local_data: 185 | response = json.load(local_data) 186 | else: 187 | request = "{static}/{version}/champions".format(static="static-data",version=api_versions["staticdata"]) 188 | params = {"locale":"en_US", "dataById":"true", "api_key":myRiotApiKey.api_key } 189 | response = make_request(request,"GET",params) 190 | data = response["data"] 191 | champions = [] 192 | for value in data.values(): 193 | if(value["name"] in DISABLED_CHAMPIONS): 194 | continue 195 | champion = Champion(value) 196 | champions.append(champion) 197 | 198 | __m.champion_name_from_id = {champion.id: champion.name for champion in champions} 199 | __m.champion_id_from_name = {re.sub("[^A-Za-z0-9]+", "", champion.name.lower()): champion.id for champion in champions} 200 | __m.valid_champion_ids = sorted(__m.champion_name_from_id.keys()) 201 | if not __m.champion_name_from_id: 202 | return False 203 | return True 204 | 205 | def create_Champion_fixture(): 206 | valid_ids = get_champion_ids() 207 | champions = [] 208 | model = 'predict.Champion' 209 | for cid in valid_ids: 210 | champion = {} 211 | champion["model"] = model 212 | champion["pk"] = cid 213 | fields = {} 214 | fields["id"] = cid 215 | fields["display_name"] = champion_name_from_id(cid) 216 | champion["fields"] = fields 217 | champions.append(champion) 218 | with open('champions_fixture.json','w') as outfile: 219 | json.dump(champions,outfile) 220 | 221 | if __name__ == "__main__": 222 | create_Champion_fixture() 223 | -------------------------------------------------------------------------------- /src/features/match_processing.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | from .draftstate import DraftState 3 | from .rewards import get_reward 4 | from copy import deepcopy 5 | 6 | import random 7 | import json 8 | 9 | def process_match(match, team, augment_data=True): 10 | """ 11 | process_match takes an input match and breaks each incremental pick and ban down the draft into experiences (aka "memories"). 12 | 13 | Args: 14 | match (dict): match dictionary with pick and ban data for a single game. 15 | team (DraftState.BLUE_TEAM or DraftState.RED_TEAM): The team perspective that is used to process match 16 | The selected team has the positions for each pick explicitly included with the experience while the 17 | "opposing" team has the assigned positions for its champion picks masked. 
18 | augment_data (optional) (bool): flag controlling the randomized ordering of submissions that do not affect the draft as a whole 19 | Returns: 20 | experiences ( list(tuple) ): list of experience tuples. Each experience is of the form (s, a, r, s') where: 21 | - s and s' are DraftState states before and after a single action 22 | - a is the (stateIndex, position) tuple of selected champion to be banned or picked. position = 0 for submissions 23 | by the opposing team 24 | - r is the integer reward obtained from submitting the action a 25 | 26 | process_match() can take the vantage from both sides of the draft to parse for memories. This means we can ultimately sample from 27 | both winning drafts (positive reinforcement) and losing drafts (negative reinforcement) when training. 28 | """ 29 | experiences = [] 30 | 31 | # This section controls data agumentation of the match. Certain submissions in the draft are 32 | # submitted consecutively by the same team during the same phase (ie team1 pick0 -> team1 pick1). 33 | # Although these submissions were produced in a particular order, from a draft perspective 34 | # there is no difference between submissions of the form 35 | # team1 pick0 -> team1 pick1 vs team1 pick1 -> team0 pickA 36 | # provided that the two picks are from the same phase (both bans or both picks). 37 | # Therefore it is possible to augment the order in which these submissions are processed. 38 | 39 | # Note that we can also augment the banning phase if desired. Although these submissions technically 40 | # fall outside of the conditions listed above, in practice bans made in the same phase are 41 | # interchangable in order. 42 | 43 | # Build queue of actions from match reference (augmenting if desired) 44 | augments_list = [ 45 | ("blue","bans",slice(0,3)), # Blue bans 0,1,2 are augmentable 46 | ("blue","bans",slice(3,5)), # Blue bans 3,4 are augmentable 47 | ("red","bans",slice(0,3)), 48 | ("red","bans",slice(3,5)), 49 | ("blue","picks",slice(1,3)), # Blue picks 1,2 are augmentable 50 | ("blue","picks",slice(3,5)), # Blue picks 3,4 are augmentable 51 | ("red","picks",slice(0,2)) # Red picks 0,1 are augmentable 52 | ] 53 | if(augment_data): 54 | augmented_match = deepcopy(match) # Deepcopy match to avoid side effects 55 | for aug in augments_list: 56 | (k1,k2,aug_range) = aug 57 | count = len(augmented_match[k1][k2][aug_range]) 58 | augmented_match[k1][k2][aug_range] = random.sample(augmented_match[k1][k2][aug_range],count) 59 | 60 | action_queue = build_action_queue(augmented_match) 61 | else: 62 | action_queue = build_action_queue(match) 63 | 64 | # Set up draft state 65 | draft = DraftState(team) 66 | 67 | finish_memory = False 68 | while action_queue: 69 | # Get next pick from deque 70 | submission = action_queue.popleft() 71 | (submitting_team, pick, position) = submission 72 | 73 | # There are two conditions under which we want to finalize a memory: 74 | # 1. Non-designated team has finished submitting picks for this phase (ie next submission belongs to the designated team) 75 | # 2. 
Draft is complete (no further picks in the draft) 76 | if submitting_team == team: 77 | if finish_memory: 78 | # This is case 1 to store memory 79 | r = get_reward(draft, match, a, a) 80 | s_next = deepcopy(draft) 81 | memory = (s, a, r, s_next) 82 | experiences.append(memory) 83 | finish_memory = False 84 | # Memory starts when upcoming pick belongs to designated team 85 | s = deepcopy(draft) 86 | # Store action = (champIndex, pos) 87 | a = (pick, position) 88 | finish_memory = True 89 | else: 90 | # Mask positions for pick submissions belonging to the non-designated team 91 | if position != -1: 92 | position = 0 93 | 94 | draft.update(pick, position) 95 | 96 | # Once the queue is empty, store last memory. This is case 2 above. 97 | # There is always an outstanding memory at the completion of the draft. 98 | # RED_TEAM always gets last pick. Therefore: 99 | # if team = BLUE_TEAM -> There is an outstanding memory from last RED_TEAM submission 100 | # if team = RED_TEAM -> Memory is open from just before our last submission 101 | if(draft.evaluate() == DraftState.DRAFT_COMPLETE): 102 | assert finish_memory == True 103 | r = get_reward(draft, match, a, a) 104 | s_next = deepcopy(draft) 105 | memory = (s, a, r, s_next) 106 | experiences.append(memory) 107 | else: 108 | print("Week {} match_id {} {} vs {}".format(match["week"], match["id"], match["blue_team"],match["red_team"])) 109 | draft.display() 110 | print("Error code {}".format(draft.evaluate())) 111 | print("Number of experiences {}".format(len(experiences))) 112 | for experience in experiences: 113 | _,a,_,_ = experience 114 | print(a) 115 | print("")#raise 116 | 117 | return experiences 118 | 119 | def build_action_queue(match): 120 | """ 121 | Builds queue of champion picks or bans (depending on mode) in selection order. If mode = 'ban' this produces a queue of tuples 122 | Args: 123 | match (dict): dictonary structure of match data to be parsed 124 | Returns: 125 | action_queue (deque(tuple)): deque of pick tuples of the form (side_id, champion_id, position_id). 126 | action_queue is produced in selection order. 127 | """ 128 | winner = match["winner"] 129 | action_queue = deque() 130 | phases = {0:{"phase_type":"bans", "pick_order":["blue", "red", "blue", "red", "blue", "red"]}, # phase 1 bans 131 | 1:{"phase_type":"picks", "pick_order":["blue", "red", "red", "blue", "blue", "red"]}, # phase 1 picks 132 | 2:{"phase_type":"bans", "pick_order":["red", "blue", "red", "blue"]}, # phase 2 bans 133 | 3:{"phase_type":"picks","pick_order":["red", "blue", "blue", "red"]}} # phase 2 picks 134 | ban_index = 0 135 | pick_index = 0 136 | completed_actions = 0 137 | for phase in range(4): 138 | phase_type = phases[phase]["phase_type"] 139 | pick_order = phases[phase]["pick_order"] 140 | 141 | num_actions = len(pick_order) 142 | for pick_num in range(num_actions): 143 | side = pick_order[pick_num] 144 | if side == "blue": 145 | side_id = DraftState.BLUE_TEAM 146 | else: 147 | side_id = DraftState.RED_TEAM 148 | if phase_type == "bans": 149 | position_id = -1 150 | index = ban_index 151 | ban_index += pick_num%2 # Order matters here. index needs to be updated *after* use 152 | else: 153 | position_id = match[side][phase_type][pick_index][1] 154 | index = pick_index 155 | pick_index += pick_num%2 # Order matters here. 
index needs to be updated *after* use 156 | action = (side_id, match[side][phase_type][index][0], position_id) 157 | action_queue.append(action) 158 | completed_actions += 1 159 | 160 | if(completed_actions != 20): 161 | print("Found a match with missing actions!") 162 | print("num_actions = {}".format(num_actions)) 163 | print(json.dumps(match, indent=2, sort_keys=True)) 164 | return action_queue 165 | 166 | if __name__ == "__main__": 167 | data = build_match_pool(1, patches=["8.3"]) 168 | matches = data["matches"] 169 | for match in matches: 170 | print(match["patch"]) 171 | for team in [DraftState.BLUE_TEAM, DraftState.RED_TEAM]: 172 | for augment_data in [False, True]: 173 | experiences = process_match(match, team, augment_data) 174 | count = 0 175 | for exp in experiences: 176 | _,a,_,_ = exp 177 | print("{} - {}".format(count,a)) 178 | count+=1 179 | print("") 180 | 181 | data = build_match_pool(0, randomize=False, patches=["8.4","8.5"]) 182 | # matches = data["matches"] 183 | # for match in matches: 184 | # print("Week {}, Patch {}: {} vs {}. Winner:{}".format(match["week"], match["patch"], match["blue_team"], match["red_team"], match["winner"])) 185 | -------------------------------------------------------------------------------- /src/model_predictions.py: -------------------------------------------------------------------------------- 1 | import experience_replay as er 2 | import match_processing as mp 3 | import champion_info as cinfo 4 | import draft_db_ops as dbo 5 | from draftstate import DraftState 6 | from models.inference_model import QNetInferenceModel, SoftmaxInferenceModel 7 | 8 | import json 9 | import pandas as pd 10 | import numpy as np 11 | import tensorflow as tf 12 | import sqlite3 13 | import math 14 | 15 | #path_to_model = "model_predictions/spring_2018/week_3/run_2/model_E10" 16 | #path_to_model = "tmp/models/model_E10" 17 | #path_to_model = "tmp/model_E{}".format(45) 18 | 19 | path_to_model = "tmp/ddqn_model_E{}".format(45) 20 | model = QNetInferenceModel(name="ddqn", path=path_to_model) 21 | 22 | #path_to_model = "tmp/softmax_model_E{}".format(45) 23 | #model = SoftmaxInferenceModel(name="softmax", path=path_to_model) 24 | print("***") 25 | print("Loading Model From: {}".format(path_to_model)) 26 | print("***") 27 | 28 | out_dir = "model_predictions/dump" 29 | print("***") 30 | print("Outputting To: {}".format(out_dir)) 31 | print("***") 32 | 33 | specific_team = None#"tsm" 34 | print("***") 35 | if(specific_team): 36 | print("Looking at drafts by team:{}".format(specific_team)) 37 | else: 38 | print("Looking at drafts submitted by winning team") 39 | print("***") 40 | 41 | #with open('worlds_matchids_by_stage.txt','r') as infile: 42 | # data = json.load(infile) 43 | #match_ids = data["groups"] 44 | #match_ids.extend(data["knockouts"]) 45 | #match_ids.extend(data["finals"]) 46 | #match_ids.extend(data["play_ins_rd1"]) 47 | #match_ids.extend(data["play_ins_rd2"]) 48 | with open('match_pool.txt','r') as infile: 49 | data = json.load(infile) 50 | match_ids = data['validation_ids'] 51 | #match_ids = data['training_ids'] 52 | #match_ids.extend(data['training_ids']) 53 | dbName = "competitiveGameData.db" 54 | conn = sqlite3.connect("tmp/"+dbName) 55 | cur = conn.cursor() 56 | #match_ids = dbo.get_game_ids_by_tournament(cur,"2017/INTL/WRLDS") 57 | matches = [dbo.get_match_data(cur,match_id) for match_id in match_ids] 58 | conn.close() 59 | if(specific_team): 60 | matches = [match for match in matches if (match["blue_team"]==specific_team or 
match["red_team"]==specific_team)] 61 | 62 | count = 0 63 | print("************************") 64 | print("Match Schedule:") 65 | print("************************") 66 | with open("{}/_match_schedule.txt".format(out_dir),'w') as outfile: 67 | outfile.write("************************\n") 68 | for match in matches: 69 | output_string = "Match {:2}: id: {:5} tourn: {:20} game_no: {:3} {:6} vs {:6} winner: {:2}".format(count, match["id"], match["tournament"], match["tourn_game_id"], match["blue_team"], match["red_team"], match["winner"]) 70 | print(output_string) 71 | outfile.write(output_string+'\n') 72 | count += 1 73 | outfile.write("************************\n") 74 | 75 | with open("{}/match_data.json".format(out_dir),'w') as outfile: 76 | json.dump(matches,outfile) 77 | 78 | count = 0 79 | k = 5 # Rank to look for in topk range 80 | 81 | full_diag = {"top1":0, "topk":0, "target":0, "l2":[],"k":k} 82 | no_rd1_ban_diag = {"top1":0, "topk":0, "target":0, "l2":[],"k":k} 83 | no_ban_diag = {"top1":0, "topk":0, "target":0, "l2":[],"k":k} 84 | second_phase_only = {"top1":0, "topk":0, "target":0, "l2":[],"k":k} 85 | bans_only = {"top1":0, "topk":0, "target":0, "l2":[],"k":k} 86 | model_diagnostics = {"full":full_diag, "no_rd1_ban":no_rd1_ban_diag, "no_bans":no_ban_diag, "phase_2_only":second_phase_only, "bans":bans_only} 87 | 88 | position_distributions = {"phase_1":[0,0,0,0,0], "phase_2":[0,0,0,0,0]} 89 | actual_pos_distributions = {"phase_1":[0,0,0,0,0], "phase_2":[0,0,0,0,0]} 90 | augmentable_picks = {DraftState.BLUE_TEAM:[0,1,4,6,8], DraftState.RED_TEAM:[0,1,3,6]} 91 | targets = [10,10,10,9,8,7,6,6,6,5] 92 | for match in matches: 93 | # if(specific_team): 94 | # team = DraftState.RED_TEAM if match["red_team"]==specific_team else DraftState.BLUE_TEAM 95 | # else: 96 | # team = DraftState.RED_TEAM if match["winner"]==1 else DraftState.BLUE_TEAM 97 | # teams = [DraftState.BLUE_TEAM, DraftState.RED_TEAM] 98 | teams = [DraftState.RED_TEAM if match["winner"]==1 else DraftState.BLUE_TEAM] 99 | for team in teams: 100 | 101 | experiences = mp.process_match(match, team, augment_data=False) 102 | 103 | print("") 104 | print("Match: {:2} {:6} vs {:6} winner: {:2}".format(count, match["blue_team"], match["red_team"], match["winner"])) 105 | for pick_count, exp in enumerate(experiences): 106 | print(" === ") 107 | print(" Match {}, Pick {}".format(count, pick_count)) 108 | print(" === ") 109 | state,act,rew,next_state = exp 110 | cid,pos = act 111 | if cid == None: 112 | continue 113 | 114 | predicted_q_values = model.predict([state]) 115 | predicted_q_values = predicted_q_values[0,:] 116 | submitted_action_id = state.get_action(*act) 117 | 118 | data = [(a,*state.format_action(a),predicted_q_values[a]) for a in range(len(predicted_q_values))] 119 | data = [(a,cinfo.champion_name_from_id(cid),pos,Q) for (a,cid,pos,Q) in data] 120 | df = pd.DataFrame(data, columns=['act_id','cname','pos','Q(s,a)']) 121 | 122 | df.sort_values('Q(s,a)',ascending=False,inplace=True) 123 | df.reset_index(drop=True,inplace=True) 124 | 125 | df['rank'] = df.index 126 | df['error'] = abs(df['Q(s,a)'][0] - df['Q(s,a)'])/abs(df['Q(s,a)'][0]) 127 | 128 | submitted_row = df[df['act_id']==submitted_action_id] 129 | print(" Submitted action:") 130 | print(submitted_row) 131 | 132 | rank = submitted_row['rank'].iloc[0] 133 | err = submitted_row['error'].iloc[0] 134 | 135 | # For picks submitted back-to-back look ahead to next action to see if it was possibly recommended 136 | if (rank >= k and pick_count in augmentable_picks[team]):#if 
False: 137 | _,next_action,_,_ = experiences[pick_count+1] 138 | cid,_ = next_action 139 | if(cid): 140 | next_action_id = state.get_action(*next_action) 141 | next_row = df[df['act_id']==next_action_id] 142 | next_rank = next_row['rank'].iloc[0] 143 | if(next_rank < k): 144 | result = state.update(*next_action) 145 | new_exp = (state, act, rew, None) 146 | experiences[pick_count+1] = new_exp 147 | rank = next_rank 148 | print(" AUGMENTED ACTION:") 149 | print(next_row) 150 | 151 | t = targets[pick_count] 152 | # Norms measuring all submissions 153 | if(rank == 0): 154 | model_diagnostics["full"]["top1"] += 1 155 | if(rank < t): 156 | model_diagnostics["full"]["target"] += 1 157 | if(rank < k): 158 | model_diagnostics["full"]["topk"] += 1 159 | model_diagnostics["full"]["l2"].append(err) 160 | 161 | # Norms excluding round 1 bans 162 | if(pick_count > 2): 163 | if(rank == 0): 164 | model_diagnostics["no_rd1_ban"]["top1"] += 1 165 | if(rank < t): 166 | model_diagnostics["no_rd1_ban"]["target"] += 1 167 | if(rank < k): 168 | model_diagnostics["no_rd1_ban"]["topk"] += 1 169 | model_diagnostics["no_rd1_ban"]["l2"].append(err) 170 | 171 | # Norms excluding round 1 completely 172 | if(pick_count > 5): 173 | if(rank == 0): 174 | model_diagnostics["phase_2_only"]["top1"] += 1 175 | if(rank < t): 176 | model_diagnostics["phase_2_only"]["target"] += 1 177 | if(rank < k): 178 | model_diagnostics["phase_2_only"]["topk"] += 1 179 | model_diagnostics["phase_2_only"]["l2"].append(err) 180 | 181 | # Norms excluding all bans 182 | if(pos != -1): 183 | if(rank == 0): 184 | model_diagnostics["no_bans"]["top1"] += 1 185 | if(rank < t): 186 | model_diagnostics["no_bans"]["target"] += 1 187 | if(rank < k): 188 | model_diagnostics["no_bans"]["topk"] += 1 189 | model_diagnostics["no_bans"]["l2"].append(err) 190 | 191 | # Norms for bans only 192 | if(pos == -1): 193 | if(rank == 0): 194 | model_diagnostics["bans"]["top1"] += 1 195 | if(rank < t): 196 | model_diagnostics["bans"]["target"] += 1 197 | if(rank < k): 198 | model_diagnostics["bans"]["topk"] += 1 199 | model_diagnostics["bans"]["l2"].append(err) 200 | 201 | if(rank >= t): 202 | print(" Top predictions:") 203 | print(df.head()) # Print top 5 choices for network 204 | #df.to_pickle("{}/match{}_pick{}.pkl".format(out_dir,count,pick_count)) 205 | 206 | # Position distribution for picks 207 | if(pos > 0): 208 | top_pos = df.head()["pos"].values.tolist() 209 | if(pick_count <=5): 210 | actual_pos_distributions["phase_1"][pos-1] += 1 211 | for pos in top_pos: 212 | position_distributions["phase_1"][pos-1] += 1 213 | else: 214 | actual_pos_distributions["phase_2"][pos-1] += 1 215 | for pos in top_pos: 216 | position_distributions["phase_2"][pos-1] += 1 217 | 218 | pick_count += 1 219 | count += 1 220 | 221 | print("******************") 222 | print("Pick position distributions:") 223 | for phase in ["phase_1", "phase_2"]: 224 | print("{}: Recommendations".format(phase)) 225 | count = sum(position_distributions[phase]) 226 | for pos in range(len(position_distributions[phase])): 227 | pos_ratio = position_distributions[phase][pos] / count 228 | print(" Position {}: Count {:3}, Ratio {:.3}".format(pos+1, position_distributions[phase][pos], pos_ratio)) 229 | 230 | print("{}: Actual".format(phase)) 231 | count = sum(actual_pos_distributions[phase]) 232 | for pos in range(len(actual_pos_distributions[phase])): 233 | pos_ratio = actual_pos_distributions[phase][pos] / count 234 | print(" Position {}: Count {:3}, Ratio {:.3}".format(pos+1, 
actual_pos_distributions[phase][pos], pos_ratio)) 235 | 236 | print("******************") 237 | print("Norm Information:") 238 | for key in sorted(model_diagnostics.keys()): 239 | print(" {}".format(key)) 240 | err_list = model_diagnostics[key]["l2"] 241 | err = math.sqrt((sum([e**2 for e in err_list])/len(err_list))) 242 | num_predictions = len(err_list) 243 | top1 = model_diagnostics[key]["top1"] 244 | topk = model_diagnostics[key]["topk"] 245 | target = model_diagnostics[key]["target"] 246 | k = model_diagnostics[key]["k"] 247 | print(" Num_predictions = {}".format(num_predictions)) 248 | print(" top 1: count {} -> acc: {:.4}".format(top1, top1/num_predictions)) 249 | print(" top {}: count {} -> acc: {:.4}".format(k, topk, topk/num_predictions)) 250 | print(" target: count {} -> acc: {:.4}".format(target, target/num_predictions)) 251 | print(" l2 error: {:.4}".format(err)) 252 | print("---") 253 | print("******************") 254 | -------------------------------------------------------------------------------- /src/models/qNetwork.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | from . import base_model 5 | 6 | class Qnetwork(base_model.BaseModel): 7 | """ 8 | Args: 9 | name (string): label for model namespace 10 | path (string): path to save/load model 11 | input_shape (tuple): tuple of inputs to network. 12 | output_shape (int): number of output nodes for network. 13 | filter_sizes (tuple of ints): number of filters in each of the two hidden layers. Defaults to (512,512). 14 | learning_rate (float): network's willingness to change current weights given new example 15 | regularization (float): strength of weights regularization term in loss function 16 | discount_factor (float): factor by which future reward after next action is taken are discounted 17 | tau (float): Hyperparameter used in updating target network (if used) 18 | Some notable values: 19 | tau = 1.e-3 -> used in original paper 20 | tau = 0.5 -> average DDQN 21 | tau = 1.0 -> copy online -> target 22 | 23 | A Q-network class which is responsible for holding and updating the weights and biases used in predicing Q-values for a given state. This Q-network will consist of 24 | the following layers: 25 | 1) Input- a DraftState state s (an array of bool) representing the current state reshaped into an [n_batch, *input_shape] tensor. 26 | 2) Two layers of relu-activated hidden fc layers with dropout 27 | 3) Output- linearly activated estimations for Q-values Q(s,a) for each of the output_shape actions a available. 
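    A rough shape sketch of the forward pass (illustrative only, assuming the default filter_sizes=(512,512) and a batch of n_batch states):
        input [n_batch, *input_shape] -> fc_0 [n_batch, 512] -> fc_1 [n_batch, 512] -> q_vals [n_batch, output_shape]
    with dropout applied after each hidden fully-connected layer.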
28 | 29 | """ 30 | @property 31 | def name(self): 32 | return self._name 33 | 34 | @property 35 | def discount_factor(self): 36 | return self._discount_factor 37 | 38 | def __init__(self, name, path, input_shape, output_shape, filter_sizes=(512,512), learning_rate=1.e-5, regularization_coeff=1.e-4, discount_factor=0.9, tau=1.0): 39 | super().__init__(name=name, path=path) 40 | self._input_shape = input_shape 41 | self._output_shape = output_shape 42 | self._filter_sizes = filter_sizes 43 | self._learning_rate = learning_rate 44 | self._regularization_coeff = regularization_coeff 45 | self._discount_factor = discount_factor 46 | self._n_hidden_layers = len(filter_sizes) 47 | self._n_layers = self._n_hidden_layers + 2 48 | self._tau = tau 49 | 50 | self.online_name = "online" 51 | self.target_name = "target" 52 | # Build base Q-network model 53 | self.online_ops = self.build_model(name = self.online_name) 54 | # If using a target network for DDQN network, add related ops to model 55 | if(self.target_name): 56 | self.target_ops = self.build_model(name = self.target_name) 57 | self.target_ops["target_init"] = self.create_target_initialization_ops(self.target_name, self.online_name) 58 | self.target_ops["target_update"] = self.create_target_update_ops(self.target_name, self.online_name, tau=self._tau) 59 | with self._graph.as_default(): 60 | self.online_ops["init"] = tf.global_variables_initializer() 61 | self.init_saver() 62 | 63 | def init_saver(self): 64 | with self._graph.as_default(): 65 | self.saver = tf.train.Saver() 66 | 67 | def save(self, path): 68 | self.saver.save(self.sess, save_path=path) 69 | 70 | def load(self, path): 71 | self.saver.restore(self.sess, save_path=path) 72 | 73 | def build_model(self, name): 74 | ops_dict = {} 75 | with self._graph.as_default(): 76 | with tf.variable_scope(name): 77 | ops_dict["learning_rate"] = tf.Variable(self._learning_rate, trainable=False, name="learning_rate") 78 | 79 | # Incoming state matrices are of size input_size = (nChampions, nPos+2) 80 | # 'None' here means the input tensor will flex with the number of training 81 | # examples (aka batch size). 
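# For example (numbers illustrative only): if input_shape = (n_champions, n_pos + 2) = (140, 7),
# the placeholder below is declared with shape (None, 140, 7), so feeding a batch of 32 DraftState
# vectors produces a (32, 140, 7) float32 tensor.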
82 | ops_dict["input"] = tf.placeholder(tf.float32, (None,)+self._input_shape, name="inputs") 83 | ops_dict["dropout_keep_prob"] = tf.placeholder_with_default(1.0,shape=()) 84 | 85 | # Fully connected (FC) layers: 86 | fc0 = tf.layers.dense( 87 | ops_dict["input"], 88 | self._filter_sizes[0], 89 | activation=tf.nn.relu, 90 | bias_initializer=tf.constant_initializer(0.1), 91 | name="fc_0") 92 | dropout0 = tf.nn.dropout(fc0, ops_dict["dropout_keep_prob"]) 93 | 94 | fc1 = tf.layers.dense( 95 | dropout0, 96 | self._filter_sizes[1], 97 | activation=tf.nn.relu, 98 | bias_initializer=tf.constant_initializer(0.1), 99 | name="fc_1") 100 | dropout1 = tf.nn.dropout(fc1, ops_dict["dropout_keep_prob"]) 101 | 102 | # FC output layer 103 | ops_dict["outQ"] = tf.layers.dense( 104 | dropout1, 105 | self._output_shape, 106 | activation=None, 107 | bias_initializer=tf.constant_initializer(0.1), 108 | kernel_regularizer=tf.contrib.layers.l2_regularizer(scale=self._regularization_coeff), 109 | name="q_vals") 110 | 111 | # Placeholder for valid actions filter 112 | ops_dict["valid_actions"] = tf.placeholder(tf.bool, shape=ops_dict["outQ"].shape, name="valid_actions") 113 | 114 | # Filtered Q-values 115 | ops_dict["valid_outQ"] = tf.where(ops_dict["valid_actions"], ops_dict["outQ"], tf.scalar_mul(-np.inf,tf.ones_like(ops_dict["outQ"])), name="valid_q_vals") 116 | 117 | # Max Q value amongst valid actions 118 | ops_dict["max_Q"] = tf.reduce_max(ops_dict["valid_outQ"], axis=1, name="max_Q") 119 | 120 | # Predicted optimal action amongst valid actions 121 | ops_dict["prediction"] = tf.argmax(ops_dict["valid_outQ"], axis=1, name="prediction") 122 | 123 | # Loss function and optimization: 124 | # The inputs self.target and self.actions are indexed by training example. If 125 | # s[i] = starting state for ith training example (recall that input state s is described by a vector so this will be a matrix) 126 | # a*[i] = action taken from state s[i] during this training sample 127 | # Q*(s[i],a*[i]) = the actual value observed from taking action a*[i] from state s[i] 128 | # outQ[i,-] = estimated values for all actions from state s[i] 129 | # Then we can write the inputs as 130 | # self.target[i] = Q*(s[i],a*[i]) 131 | # self.actions[i] = a*[i] 132 | 133 | ops_dict["target"] = tf.placeholder(tf.float32, shape=[None], name="target_Q") 134 | ops_dict["actions"] = tf.placeholder(tf.int32, shape=[None], name="submitted_action") 135 | 136 | # Since the Qnet outputs a vector Q(s,-) of predicted values for every possible action that can be taken from state s, 137 | # we need to connect each target value with the appropriate predicted Q(s,a*) = Qout[i,a*[i]]. 138 | # Main idea is to get indexes into the outQ tensor based on input actions and gather the resulting Q values 139 | # For some reason this isn't easy for tensorflow to do. So we must manually form the list of 140 | # [i, actions[i]] index pairs for outQ.. 141 | # n_batch = outQ.shape[0] = actions.shape[0] 142 | # n_actions = outQ.shape[1] 143 | ind = tf.stack([tf.range(tf.shape(ops_dict["actions"])[0]),ops_dict["actions"]],axis=1) 144 | # and then "gather" them. 145 | estimatedQ = tf.gather_nd(ops_dict["outQ"], ind) 146 | # Special notes: this is more efficient than indexing into the flattened version of outQ (which I have seen before) 147 | # because the gather operation is applied to outQ directly. 
Apparently this propagates the gradient more efficiently 148 | # under specific sparsity conditions (which tf.Variables like outQ satisfy) 149 | 150 | # Simple sum-of-squares loss (error) function. Note that biases do not 151 | # need to be regularized since they are (generally) not subject to overfitting. 152 | ops_dict["loss"] = tf.reduce_mean(0.5*tf.square(ops_dict["target"]-estimatedQ), name="loss") 153 | 154 | ops_dict["trainer"] = tf.train.AdamOptimizer(learning_rate = ops_dict["learning_rate"]) 155 | ops_dict["update"] = ops_dict["trainer"].minimize(ops_dict["loss"], name="update") 156 | 157 | return ops_dict 158 | 159 | def create_target_update_ops(self, target_scope, online_scope, tau=1e-3, name="target_update"): 160 | """ 161 | Adds operations to the graph which are used to update the target network after a training batch is sent 162 | through the online network. 163 | 164 | This function should be executed only once before training begins. The resulting operations should 165 | be run within a tf.Session() once per training batch. 166 | 167 | In double-Q network learning, the online (primary) network is updated using traditional backpropagation techniques 168 | with target values produced by the target-Q network. 169 | To improve stability, the target-Q network is updated using a linear combination of its current weights 170 | with the current weights of the online network: 171 | Q_target = tau*Q_online + (1-tau)*Q_target 172 | Typical tau values are small (tau ~ 1e-3). For more, see https://arxiv.org/abs/1509.06461 and https://arxiv.org/pdf/1509.02971.pdf. 173 | Args: 174 | target_scope (str): name of scope that the target network occupies 175 | online_scope (str): name of scope that the online network occupies 176 | tau (float32): Hyperparameter for combining target-Q and online-Q networks 177 | name (str): name of operation which updates the target network when run within a session 178 | Returns: Tensorflow operation which updates the target network when run. 179 | """ 180 | with self._graph.as_default(): 181 | target_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=target_scope) 182 | online_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=online_scope) 183 | ops = [target_params[i].assign(tf.add(tf.multiply(tau,online_params[i]),tf.multiply(1.-tau,target_params[i]))) for i in range(len(target_params))] 184 | return tf.group(*ops,name=name) 185 | 186 | def create_target_initialization_ops(self, target_scope, online_scope): 187 | """ 188 | This adds operations to the graph in order to initialize the target Q network to the same values as the 189 | online network. 190 | 191 | This function should be executed only once, just after the online network has been initialized. 192 | 193 | Args: 194 | target_scope (str): name of scope that the target network occupies 195 | online_scope (str): name of scope that the online network occupies 196 | Returns: 197 | Tensorflow operation (named "target_init") which initializes the target network when run. 198 | """ 199 | return self.create_target_update_ops(target_scope, online_scope, tau=1.0, name="target_init") 200 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions.
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/data/database_ops.py: -------------------------------------------------------------------------------- 1 | import sqlite3 2 | import re 3 | from .champion_info import champion_id_from_name,champion_name_from_id, convert_champion_alias, AliasException 4 | 5 | regionsDict = {"NA_LCS":"NA", "EU_LCS":"EU", "LCK":"LCK", "LPL":"LPL", 6 | "LMS":"LMS", "International":"INTL", "NA_ACA": "NA_ACA", "KR_CHAL":"KR_CHAL", "LDL":"LDL"} 7 | internationalEventsDict = {"Mid-Season_Invitational":"MSI", 8 | "Rift_Rivals":"RR","World_Championship":"WRLDS"} 9 | 10 | def get_matches_by_id(match_ids, path): 11 | """ 12 | Returns match data for each match_id in the list match_ids 13 | """ 14 | conn = sqlite3.connect(path) 15 | cur = conn.cursor() 16 | match_data = [] 17 | for match_id in match_ids: 18 | match = get_match_data(cur, match_id) 19 | match_data.append(match) 20 | conn.close() 21 | return match_data 22 | 23 | def get_game_ids_by_tournament(cursor, tournament, patch=None): 24 | """ 25 | getMatchIdsByTournament queries the connected db for game ids which match the 26 | input tournament string. 27 | 28 | Args: 29 | cursor (sqlite cursor): cursor used to execute commmands 30 | tournament (string): id string for tournament (ie "2017/EU/Summer_Split") 31 | patch (string, optional): id string for patch to additionally filter 32 | Returns: 33 | gameIds (list(int)): list of gameIds 34 | """ 35 | if patch: 36 | query = "SELECT id FROM game WHERE tournament=? AND patch=? ORDER BY id" 37 | params = (tournament, patch) 38 | else: 39 | query = "SELECT id FROM game WHERE tournament=? ORDER BY id" 40 | params = (tournament,) 41 | cursor.execute(query, params) 42 | response = cursor.fetchall() 43 | vals = [] 44 | for r in response: 45 | vals.append(r[0]) 46 | return vals 47 | 48 | def get_game_ids(cursor, tournament=None, patch=None): 49 | """ 50 | get_game_ids queries the connected db for game ids which match the 51 | input tournament and patch strings. 
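    For example (illustrative arguments): get_game_ids(cur, tournament="2017/EU/Summer_Split", patch="7.18") returns ids for games from that tournament played on patch 7.18; calling it with neither filter returns an empty list.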
52 | 53 | Args: 54 | cursor (sqlite cursor): cursor used to execute commmands 55 | tournament (string, optional): id string for tournament (ie "2017/EU/Summer_Split") 56 | patch (string, optional): id string for patch to filter for 57 | Returns: 58 | gameIds (list(int)): list of gameIds 59 | """ 60 | if not patch and not tournament: 61 | return [] 62 | 63 | params = () 64 | where_clause = [] 65 | if tournament: 66 | where_clause.append("tournament=?") 67 | params += (tournament,) 68 | if patch: 69 | where_clause.append("patch=?") 70 | params += (patch,) 71 | 72 | query = "SELECT id FROM game WHERE {where_clause} ORDER BY id".format(where_clause=" AND ".join(where_clause)) 73 | cursor.execute(query, params) 74 | response = cursor.fetchall() 75 | vals = [] 76 | for r in response: 77 | vals.append(r[0]) 78 | return vals 79 | 80 | def get_match_data(cursor, gameId): 81 | """ 82 | get_match_data queries the connected db for draft data and organizes it into a more convenient 83 | format. 84 | 85 | Args: 86 | cursor (sqlite cursor): cursor used to execute commmands 87 | gameId (int): primary key of game to process 88 | Returns: 89 | match (dict): formatted pick/ban phase data for game 90 | """ 91 | match = {"id": gameId ,"winner": None, "blue":{}, "red":{}, "blue_team":None, "red_team":None, "header_id":None, "patch":None} 92 | # Get winning team 93 | query = "SELECT tournament, tourn_game_id, week, patch, winning_team FROM game WHERE id=?" 94 | params = (gameId,) 95 | cursor.execute(query, params) 96 | match["tournament"], match["tourn_game_id"], match["header_id"], match["patch"], match["winner"] = cursor.fetchone()#[0] 97 | 98 | # Get ban data 99 | query = "SELECT champion_id, selection_order FROM ban WHERE game_id=? and side_id=? ORDER BY selection_order" 100 | params = (gameId,0) 101 | cursor.execute(query, params) 102 | match["blue"]["bans"] = list(cursor.fetchall()) 103 | 104 | query = "SELECT champion_id, selection_order FROM ban WHERE game_id=? and side_id=? ORDER BY selection_order" 105 | params = (gameId,1) 106 | cursor.execute(query, params) 107 | match["red"]["bans"] = list(cursor.fetchall()) 108 | 109 | # Get pick data 110 | query = "SELECT champion_id, position_id, selection_order FROM pick WHERE game_id=? AND side_id=? ORDER BY selection_order" 111 | params = (gameId,0) 112 | cursor.execute(query, params) 113 | match["blue"]["picks"] = list(cursor.fetchall()) 114 | 115 | query = "SELECT champion_id, position_id, selection_order FROM pick WHERE game_id=? AND side_id=? ORDER BY selection_order" 116 | params = (gameId,1) 117 | cursor.execute(query, params) 118 | match["red"]["picks"] = list(cursor.fetchall()) 119 | 120 | query = "SELECT display_name FROM team JOIN game ON team.id = blue_teamid WHERE game.id = ?" 121 | params = (gameId,) 122 | cursor.execute(query, params) 123 | match["blue_team"] = cursor.fetchone()[0] 124 | 125 | query = "SELECT display_name FROM team JOIN game ON team.id = red_teamid WHERE game.id = ?" 126 | params = (gameId,) 127 | cursor.execute(query, params) 128 | match["red_team"] = cursor.fetchone()[0] 129 | 130 | return match 131 | 132 | def get_tournament_data(gameData): 133 | """ 134 | get_tournament_data cleans up and combines the region/year/tournament fields in gameData for entry into 135 | the game table. When combined with the game_id field it uniquely identifies the match played. 
136 | The format of tournamentData output is 'year/region_abbrv/tournament' (forward slash delimiters) 137 | 138 | Args: 139 | gameData (dict): dictionary output from query_wiki() 140 | Returns: 141 | tournamentData (string): formatted and cleaned region/year/split data 142 | """ 143 | tournamentData = "/".join([gameData["year"], regionsDict[gameData["region"]], gameData["tournament"]]) 144 | return tournamentData 145 | 146 | def get_game_id(cursor,gameData): 147 | """ 148 | get_game_id looks in the game table for an entry whose tournament and tourn_game_id match the input 149 | gameData and returns the id field. If no such entry is found, it adds this game to the game table and returns the 150 | id field. 151 | 152 | Args: 153 | cursor (sqlite cursor): cursor used to execute commands 154 | gameData (dict): dictionary output from query_wiki() 155 | Returns: 156 | gameId (int): Primary key in game table corresponding to this gameData 157 | """ 158 | tournament = get_tournament_data(gameData) 159 | vals = (tournament,gameData["tourn_game_id"]) 160 | gameId = None 161 | while gameId is None: 162 | cursor.execute("SELECT id FROM game WHERE tournament=? AND tourn_game_id=?", vals) 163 | gameId = cursor.fetchone() 164 | if gameId is None: 165 | print("Warning: Game not found. Attempting to add game.") 166 | err = insert_game(cursor,[gameData]) 167 | else: 168 | gameId = gameId[0] 169 | return gameId 170 | 171 | def delete_game_from_table(cursor, game_ids, table_name): 172 | """ 173 | Deletes rows corresponding to game_id from table table_name. 174 | Args: 175 | cursor (sqlite cursor): cursor used to execute commands 176 | game_ids (list(int)): game_ids to be removed from table 177 | table_name (string): name of table to remove rows from 178 | Returns: 179 | status (int): status = 1 if delete was successful, otherwise status = 0 180 | """ 181 | status = 0 182 | assert isinstance(game_ids,list), "game_ids is not a list" 183 | for game_id in game_ids: 184 | query = "SELECT count(*) FROM {table_name} WHERE game_id=?".format(table_name=table_name) 185 | vals = (game_id,) 186 | cursor.execute(query, vals) 187 | print("Found {count} rows for game_id={game_id} to delete from table {table}".format(count=cursor.fetchone()[0], game_id=game_id, table=table_name)) 188 | 189 | query = "DELETE FROM {table_name} WHERE game_id=?".format(table_name=table_name) 190 | cursor.execute(query, vals) 191 | status = 1 192 | return status 193 | 194 | def insert_game(cursor, gameData): 195 | """ 196 | insert_game attempts to format collected gameData from query_wiki() and insert 197 | into the game table in the competitiveGameData.db. 198 | 199 | Args: 200 | cursor (sqlite cursor): cursor used to execute commands 201 | gameData (list(dict)): list of dictionary output from query_wiki() 202 | Returns: 203 | status (int): status = 1 if insert was successful, otherwise status = 0 204 | """ 205 | status = 0 206 | assert isinstance(gameData,list), "gameData is not a list" 207 | for game in gameData: 208 | tournGameId = game["tourn_game_id"] # Which game this is within current tournament 209 | tournamentData = get_tournament_data(game) 210 | 211 | # Check to see if game data is already in table 212 | vals = (tournamentData,tournGameId) 213 | cursor.execute("SELECT id FROM game WHERE tournament=? AND tourn_game_id=?", vals) 214 | result = cursor.fetchone() 215 | if result is not None: 216 | print("game {} already exists in table..
skipping".format(result[0])) 217 | else: 218 | # Get blue and red team_ids 219 | blueTeamId = None 220 | redTeamId = None 221 | while (blueTeamId is None or redTeamId is None): 222 | cursor.execute("SELECT id FROM team WHERE display_name=?",(game["blue_team"],)) 223 | blueTeamId = cursor.fetchone() 224 | cursor.execute("SELECT id FROM team WHERE display_name=?",(game["red_team"],)) 225 | redTeamId = cursor.fetchone() 226 | if (blueTeamId is None) or (redTeamId is None): 227 | print("*WARNING: When inserting game-- team not found. Attempting to add teams") 228 | err = insert_team(cursor, [game]) 229 | else: 230 | blueTeamId = blueTeamId[0] 231 | redTeamId = redTeamId[0] 232 | 233 | winner = game["winning_team"] 234 | header_id = game["header_id"] 235 | patch = game["patch"] 236 | vals = (tournamentData, tournGameId, header_id, patch, blueTeamId, redTeamId, winner) 237 | cursor.execute("INSERT INTO game(tournament, tourn_game_id, week, patch, blue_teamid, red_teamid, winning_team) VALUES(?,?,?,?,?,?,?)", vals) 238 | status = 1 239 | return status 240 | 241 | def insert_team(cursor, gameData): 242 | """ 243 | insert_team attempts to format collected gameData from query_wiki() and insert 244 | into the team table in the competitiveGameData.db. 245 | 246 | Args: 247 | cursor (sqlite cursor): cursor used to execute commmands 248 | wikiGameData (list(dict)): dictionary output from query_wiki() 249 | Returns: 250 | status (int): status = 1 if insert was successful, otherwise status = 0 251 | """ 252 | status = 0 253 | assert isinstance(gameData,list), "gameData is not a list" 254 | for game in gameData: 255 | # We don't track all regions (i.e wildcard regions), but they can still appear at 256 | # international tournaments. When this happens we will track the team, but list their 257 | # region as NULL. 258 | if game["region"] is "Inernational": 259 | region = None 260 | else: 261 | region = regionsDict[game["region"]] 262 | teams = [game["blue_team"], game["red_team"]] 263 | for team in teams: 264 | vals = (region,team) 265 | # This only looks for matching display names.. what happens if theres a 266 | # NA TSM and and EU TSM? 267 | cursor.execute("SELECT * FROM team WHERE display_name=?", (team,)) 268 | result = cursor.fetchone() 269 | if result is None: 270 | cursor.execute("INSERT INTO team(region, display_name) VALUES(?,?)", vals) 271 | status = 1 272 | return status 273 | 274 | def insert_ban(cursor, gameData): 275 | """ 276 | insert_ban attempts to format collected gameData from query_wiki() and insert into the 277 | ban table in the competitiveGameData.db. 278 | 279 | Args: 280 | cursor (sqlite cursor): cursor used to execute commmands 281 | gameData (list(dict)): dictionary output from query_wiki() 282 | Returns: 283 | status (int): status = 1 if insert was successful, otherwise status = 0 284 | """ 285 | status = 0 286 | assert isinstance(gameData,list), "gameData is not a list" 287 | teams = ["blue", "red"] 288 | for game in gameData: 289 | tournament = get_tournament_data(game) 290 | vals = (tournament,game["tourn_game_id"]) 291 | gameId = get_game_id(cursor,game) 292 | # Check for existing entries in table. Skip if they already exist. 293 | cursor.execute("SELECT game_id FROM ban WHERE game_id=?",(gameId,)) 294 | result = cursor.fetchone() 295 | if result is not None: 296 | print("Bans for game {} already exists in table.. 
skipping".format(result[0])) 297 | else: 298 | for k in range(len(teams)): 299 | bans = game["bans"][teams[k]] 300 | selectionOrder = 0 301 | side = k 302 | for ban in bans: 303 | if ban in ["lossofban","none"]: 304 | # Special case if no ban was submitted in game 305 | banId = None 306 | else: 307 | # print("ban={}".format(ban)) 308 | banId = champion_id_from_name(ban) 309 | # If no such champion name is found, try looking for an alias 310 | if banId is None: 311 | banId = champion_id_from_name(convert_champion_alias(ban)) 312 | selectionOrder += 1 313 | vals = (gameId,banId,selectionOrder,side) 314 | cursor.execute("INSERT INTO ban(game_id, champion_id, selection_order, side_id) VALUES(?,?,?,?)", vals) 315 | status = 1 316 | return status 317 | 318 | def insert_pick(cursor, gameData): 319 | """ 320 | insert_pick formats collected gameData from query_wiki() and inserts it into the pick table of the 321 | competitiveGameData.db. 322 | 323 | Args: 324 | cursor (sqlite cursor): cursor used to execute commmands 325 | gameData (list(dict)): list of formatted game data from query_wiki() 326 | Returns: 327 | status (int): status = 1 if insert was successful, otherwise status = 0 328 | """ 329 | status = 0 330 | assert isinstance(gameData,list), "gameData is not a list" 331 | teams = ["blue", "red"] 332 | for game in gameData: 333 | tournament = get_tournament_data(game) 334 | vals = (tournament,game["tourn_game_id"]) 335 | gameId = get_game_id(cursor,game) 336 | # Check for existing entries in table. Skip if they already exist. 337 | cursor.execute("SELECT game_id FROM pick WHERE game_id=?",(gameId,)) 338 | result = cursor.fetchone() 339 | if result is not None: 340 | print("Picks for game {} already exists in table.. skipping".format(result[0])) 341 | else: 342 | for k in range(len(teams)): 343 | picks = game["picks"][teams[k]] 344 | selectionOrder = 0 345 | side = k 346 | for (pick,position) in picks: 347 | if pick in ["lossofpick","none"]: 348 | # Special case if no pick was submitted to game (not really sure what that would mean 349 | # but being consistent with insert_pick()) 350 | pickId = None 351 | else: 352 | pickId = champion_id_from_name(pick) 353 | # If no such champion name is found, try looking for an alias 354 | if pickId is None: 355 | pickId = champion_id_from_name(convert_champion_alias(pick)) 356 | selectionOrder += 1 357 | vals = (gameId,pickId,position,selectionOrder,side) 358 | cursor.execute("INSERT INTO pick(game_id, champion_id, position_id, selection_order, side_id) VALUES(?,?,?,?,?)", vals) 359 | status = 1 360 | return status 361 | -------------------------------------------------------------------------------- /src/data/query_wiki.py: -------------------------------------------------------------------------------- 1 | import json # JSON tools 2 | import requests # URL api tools 3 | import re # regex tools 4 | from .champion_info import convert_champion_alias, champion_id_from_name 5 | 6 | def query_wiki(year, region, tournament): 7 | """ 8 | query_wiki takes identifying sections and subsections for a page title on leaguepedia and formats and executes a set of requests to the 9 | API looking for the pick/ban data corresponding to the specified sections and subsections. This response is then 10 | pruned and formatted into a list of dictionaries. Specified sections and subsections should combine into a unique identifying string 11 | for a specific tournament and query_wiki() will return all games for that tournament. 
12 | 13 | For example, if we are interested in the regular season of the 2017 European Summer Split we would call: 14 | query_wiki("2017", "EU_LCS", "Summer_Season") 15 | 16 | If we were interested in 2016 World Championship we would pass: 17 | query_wiki("2016", "International", "WORLDS/Main_Event") 18 | 19 | Each dictionary corresponds to the pick/ban phase of an LCS game with the following keys: 20 | "region": 21 | "season": 22 | "tournament": 23 | "bans": {"blue":, "red":} 24 | "blue_team": 25 | "blue_team_score" 26 | "red_team": 27 | "red_team_score:" 28 | "tourn_game_id": 29 | "picks": {"blue":, "red":} 30 | 31 | Args: 32 | year (string): year of game data of interest 33 | region (string): region of play for games 34 | tournament (string): which tournament games were played in 35 | Returns: 36 | List of dictionaries containing formatted response data from lol.gamepedia api 37 | """ 38 | # Common root for all requests 39 | url_root = "https://lol.gamepedia.com/api.php" 40 | 41 | # Semi-standardized page suffixes for pick/ban pages 42 | page_suffixes = ["", "/Bracket_Stage", "/3-4", "/5-6", "/5-8", "/4-6", "/4-7", "/7-9", "/7-10", "/8-10", "/7-8", "/9-11"] 43 | 44 | formatted_regions = {"NA_LCS":"League_Championship_Series/North_America", 45 | "NA_ACA":"NA_Academy_League", 46 | "EU_LCS":"League_Championship_Series/Europe", 47 | "LCK":"LCK", 48 | "KR_CHAL":"Challengers_Korea", 49 | "LPL":"LPL", 50 | "LMS":"LMS", 51 | "LDL":"LDL"} 52 | 53 | formatted_international_tournaments = { 54 | "WORLDS/Play-In": "Season_World_Championship/Play-In", 55 | "WORLDS/Main_Event": "Season_World_Championship/Main_Event", 56 | "MSI/Play-In": "Mid-Season_Invitational/Play-In", 57 | "MSI/Main_Event": "Mid-Season_Invitational/Main_Event", 58 | "WORLDS_QUALS/NA": "Season_North_America_Regional_Finals", 59 | "WORLDS_QUALS/EU": "Season_Europe_Regional_Finals", 60 | "WORLDS_QUALS/LCK": "Season_Korea_Regional_Finals", 61 | "WORLDS_QUALS/LPL": "Season_China_Regional_Finals", 62 | "WORLDS_QUALS/LMS": "Season_Taiwan_Regional_Finals", 63 | } 64 | 65 | with open('../data/patch_info.json','r') as infile: 66 | patch_data = json.load(infile) 67 | patches = patch_data["patch_info"][year][region][tournament] 68 | print(patches) 69 | 70 | # Build list of titles of pages to query 71 | if region == "International": 72 | title_root = ["_".join([year,formatted_international_tournaments[tournament]])] 73 | else: 74 | formatted_region = formatted_regions[region] 75 | formatted_year = "_".join([year,"Season"]) 76 | title_root = [formatted_region, formatted_year, tournament] 77 | title_root.append("Picks_and_Bans") 78 | title_root = "/".join(title_root) 79 | 80 | title_list = [] 81 | for suffix in page_suffixes: 82 | title_list.append(title_root+suffix) 83 | formatted_title_list = "|".join(title_list) # Parameter string to pass to API 84 | params = {"action": "query", "titles": formatted_title_list, 85 | "prop":"revisions", "rvprop":"content", "format": "json"} 86 | 87 | response = requests.get(url=url_root, params=params) 88 | print(response.url) 89 | data = json.loads(response.text) 90 | page_data = data['query']['pages'] 91 | # Get list of page keys (actually a list of pageIds.. could be used to identify pages?) 
92 | page_keys = list(sorted(page_data.keys())) 93 | page_keys = [k for k in page_keys if int(k)>=0] # Filter out "invalid page" and "missing page" responses 94 | formatted_data = [] 95 | tournGameId = 0 96 | 97 | for page in page_keys: 98 | # Get the raw text of the most recent revision of the current page 99 | # Note that we remove all space characters from the raw text, including those 100 | # in team or champion names. 101 | raw_text = page_data[page]["revisions"][0]["*"].replace(" ","").replace("\\n"," ") 102 | print(page_data[page]["title"]) 103 | 104 | # week_labels = parse_raw_text("(name=Week[0-9]+)", raw_text) 105 | # week_numbers = [int(i.replace("week","")) for i in week_labels] 106 | # week_data = re.split("(name=Week[0-9]+)", raw_text)[2::2] 107 | 108 | # Get section headers 109 | headers = parse_raw_text("(name=[\w0-9]+)", raw_text) 110 | # Look for indexed headers first. If found, use this index for patch data. If no such index is found, use the first element of patch data 111 | search = [re.search("[0-9]+", header) for header in headers] 112 | header_indices = [int(s.group()) if s else 0 for s in search] 113 | 114 | # If there's only one patch for this tournament, make sure all header indices point to it. 115 | if len(patches) == 1: 116 | header_indices = [0 for header in header_indices] 117 | 118 | # This is an edge case for when some header indices can be read but there are also non-indexable headers 119 | # Good example is when there is a "Tiebreaker" section at the end of a indexed regular split. 120 | # This bit sets the patch data index for such sections to the same one as the most recent non-zero index 121 | # Ex: [1,2,3,4,0] -> [1,2,3,4,4] 122 | if not all(value==0 for value in header_indices): 123 | last_val = 0 124 | for i,val in enumerate(header_indices): 125 | if val!=0: 126 | last_val = val 127 | else: 128 | header_indices[i] = last_val 129 | 130 | section_data = re.split("(name=[\w0-9]+)", raw_text)[2::2] 131 | 132 | num_games_on_page = 0 133 | for i in range(len(section_data)): 134 | data = section_data[i] 135 | 136 | # winning_teams holds which team won for each parsed game 137 | # winner = 1 -> first team won (i.e blue team) 138 | # winner = 2 -> second team won (i.e red team) 139 | winning_teams = parse_raw_text("(winner=[0-9])", data) 140 | winning_teams = [int(i)-1 for i in winning_teams] # Convert string response to int 141 | num_games_in_week = len(winning_teams) 142 | 143 | if(num_games_in_week == 0): 144 | continue 145 | else: 146 | num_games_on_page += num_games_in_week 147 | 148 | # string representation of blue and red teams, ordered by game 149 | blue_teams = parse_raw_text("(team1=[\w\s]+)", data) 150 | red_teams = parse_raw_text("(team2=[\w\s]+)", data) 151 | 152 | blue_scores = parse_raw_text("(team1score=[0-9])", data) 153 | red_scores = parse_raw_text("(team2score=[0-9])", data) 154 | 155 | # bans holds the string identifiers of submitted bans for each team in the parsed game 156 | # ex: bans[k] = list of bans for kth game on the page 157 | all_blue_bans = parse_raw_text("(blueban[0-9]=\w[\w\s',.]+)", data) 158 | all_red_bans = parse_raw_text("(red_ban[0-9]=\w[\w\s',.]+)", data) 159 | assert len(all_blue_bans)==len(all_red_bans), "blue bans: {}, red bans: {}".format(len(all_blue_bans),len(all_red_bans)) 160 | bans_per_team = len(all_blue_bans)//num_games_in_week 161 | 162 | # blue_picks[i] = list of picks for kth game on the page 163 | all_blue_picks = parse_raw_text("(bluepick[0-9]=\w[\w\s',.]+)", data) 164 | all_blue_roles = 
parse_raw_text("(bluerole[0-9]=\w[\w\s',.]?)", data) 165 | all_red_picks = parse_raw_text("(red_pick[0-9]=\w[\w\s',.]+)", data) 166 | all_red_roles = parse_raw_text("(red_role[0-9]=\w[\w\s',.]?)", data) 167 | #print(data) 168 | assert len(all_blue_picks)==len(all_red_picks), "blue picks: {}, red picks: {}".format(len(all_blue_picks),len(all_red_picks)) 169 | assert len(all_blue_roles)==len(all_red_roles), "blue roles: {}, red roles: {}".format(len(all_blue_roles),len(all_red_roles)) 170 | picks_per_team = len(all_blue_picks)//num_games_in_week 171 | 172 | # Clean fields involving chanmpion names, looking for aliases if necessary 173 | all_blue_bans = clean_champion_names(all_blue_bans) 174 | all_red_bans = clean_champion_names(all_red_bans) 175 | all_blue_picks = clean_champion_names(all_blue_picks) 176 | all_red_picks = clean_champion_names(all_red_picks) 177 | 178 | # Format data by match 179 | blue_bans = [] 180 | red_bans = [] 181 | for k in range(num_games_in_week): 182 | blue_bans.append(all_blue_bans[bans_per_team*k:bans_per_team*(k+1)]) 183 | red_bans.append(all_red_bans[bans_per_team*k:bans_per_team*(k+1)]) 184 | 185 | # submissions holds the identifiers of submitted (pick, position) pairs for each team in the parsed game 186 | # string representation for the positions are converted to ints to match DraftState expectations 187 | blue_picks = [] 188 | red_picks = [] 189 | for k in range(num_games_in_week): 190 | picks = all_blue_picks[picks_per_team*k:picks_per_team*(k+1)] 191 | positions = position_string_to_id(all_blue_roles[picks_per_team*k:picks_per_team*(k+1)]) 192 | blue_picks.append(list(zip(picks,positions))) 193 | 194 | picks = all_red_picks[picks_per_team*k:picks_per_team*(k+1)] 195 | positions = position_string_to_id(all_red_roles[picks_per_team*k:picks_per_team*(k+1)]) 196 | red_picks.append(list(zip(picks,positions))) 197 | 198 | total_blue_bans = sum([len(bans) for bans in blue_bans]) 199 | total_red_bans = sum([len(bans) for bans in red_bans]) 200 | total_blue_picks = sum([len(picks) for picks in blue_picks]) 201 | total_red_picks = sum([len(picks) for picks in red_picks]) 202 | print("Total number of games found: {}".format(num_games_in_week)) 203 | print("There should be {} bans. We found {} blue bans and {} red bans".format(num_games_in_week*5,total_blue_bans,total_red_bans)) 204 | print("There should be {} picks. We found {} blue picks and {} red picks".format(num_games_in_week*5,total_blue_picks,total_red_picks)) 205 | assert total_red_bans==total_blue_bans, "Bans don't match!" 206 | assert total_red_picks==total_blue_picks, "Picks don't match!" 
207 | if(num_games_in_week > 0): # At least one game found on current page 208 | for k in range(num_games_in_week): 209 | print("Header_id {}, Game {}: {} vs {}".format(header_indices[i],k+1,blue_teams[k],red_teams[k])) 210 | 211 | tournGameId += 1 212 | bans = {"blue": blue_bans[k], "red":red_bans[k]} 213 | picks = {"blue": blue_picks[k], "red":red_picks[k]} 214 | blue = {"bans": blue_bans[k], "picks":blue_picks[k]} 215 | red = {"bans": red_bans[k], "picks":red_picks[k]} 216 | gameData = {"region": region, "year":year, "tournament": tournament, 217 | "blue_team": blue_teams[k], "red_team": red_teams[k], 218 | "winning_team": winning_teams[k], 219 | "blue_score":blue_scores[k], "red_score":red_scores[k], 220 | "bans": bans, "picks": picks, "blue":blue, "red":red, 221 | "tourn_game_id": tournGameId, "header_id":header_indices[i], 222 | "patch":patches[header_indices[i]-1]} 223 | formatted_data.append(gameData) 224 | return formatted_data 225 | def position_string_to_id(positions): 226 | """ 227 | position_string_to_id converts input position strings to their integer representations defined by: 228 | Position 1 = Primary farm (ADC) 229 | Position 2 = Secondary farm (Mid) 230 | Position 3 = Tertiary farm (Top) 231 | Position 4 = Farming support (Jungle) 232 | Position 5 = Primary support (Support) 233 | Note that because of variable standardization of the string representations for each position 234 | (i.e "jg"="jng"="jungle"), this function only looks at the first character of each string when 235 | assigning integer positions since this seems to be more or less standard. 236 | 237 | Args: 238 | positions (list(string)) 239 | Returns: 240 | list(int) 241 | """ 242 | 243 | d = {"a":1, "m":2, "t":3, "j":4, "s":5} # This is lazy and I know it 244 | out = [] 245 | for position in positions: 246 | char = position[0] # Look at first character for position information 247 | out.append(d[char]) 248 | return out 249 | 250 | def parse_raw_text(regex, rawText): 251 | """ 252 | parse_raw_text is a helper function which outputs a list of matching expressions defined 253 | by the regex input. Note that this function assumes that each regex yields matches of the form 254 | "A=B" which is passed to split_id_strings() for fomatting. 255 | 256 | Args: 257 | regex: desired regex to match with 258 | rawText: raw input string to find matches in 259 | Returns: 260 | List of formatted strings containing the matched data. 261 | """ 262 | # Parse raw text responses for data. Note that a regular expression of the form 263 | # "(match)" will produce result = [stuff_before_match, match, stuff_after_match] 264 | # this means that the list of desired matches will be result[1::2] 265 | out = re.split(regex, rawText) 266 | out = split_id_strings(out[1::2]) # Format matching strings 267 | return out 268 | 269 | def split_id_strings(rawStrings): 270 | """ 271 | split_id_strings takes a list of strings each of the form "A=B" and splits them 272 | along the "=" delimiting character. Returns the list formed by each of the "B" 273 | components of the raw input strings. For standardization purposes, the "B" string 274 | has the following done to it: 275 | 1. replace uppercase characters with lowercase 276 | 2. 
remove special characters (i.e non-alphanumeric) 277 | 278 | Args: 279 | rawStrings (list of strings): list of strings, each of the form "A=B" 280 | Returns: 281 | out (list of strings): list of strings formed by the "B" portion of each of the raw input strings 282 | """ 283 | out = [] 284 | for string in rawStrings: 285 | rightHandString = string.split("=")[1].lower() # Grab "B" part of string, make lowercase 286 | out.append(re.sub("[^A-Za-z0-9,]+", "", rightHandString)) # Remove special chars 287 | return out 288 | 289 | def convert_lcs_positions(index): 290 | """ 291 | Given the index of a pick in LCS order, returns the position id corresponding 292 | to that index. 293 | 294 | LCS picks are submitted in the following order 295 | Index | Role | Position 296 | 0 Top 3 297 | 1 Jng 4 298 | 2 Mid 2 299 | 3 Adc 1 300 | 4 Sup 5 301 | """ 302 | lcsOrderToPos = {i:j for i,j in enumerate([3,4,2,1,5])} 303 | return lcsOrderToPos[index] 304 | 305 | def create_position_dict(picks_in_lcs_order): 306 | """ 307 | Given a list of champions selected in lcs order (ie top,jungle,mid,adc,support) 308 | returns a dictionary which matches pick -> position. 309 | Args: 310 | picks_in_lcs_order (list(string)): list of string identifiers of picks. Assumed to be in LCS order 311 | Returns: 312 | dict (dictionary): dictionary with champion names for keys and position that the key was played in for value. 313 | """ 314 | d = {} 315 | cleaned_names = clean_champion_names(picks_in_lcs_order) 316 | for k in range(len(picks_in_lcs_order)): 317 | pos = convert_lcs_positions(k) 318 | d.update({cleaned_names[k]:pos}) 319 | return d 320 | 321 | 322 | def clean_champion_names(names): 323 | """ 324 | Takes a list of champion names and standarizes them by looking for possible aliases 325 | if necessary. 326 | Args: 327 | names (list(string)): list of champion names to be standardized 328 | Returns: 329 | cleanedNames (list(string)): list of standardized champion names 330 | """ 331 | cleanedNames = [] 332 | for name in names: 333 | if champion_id_from_name(name) is None: 334 | name = convert_champion_alias(name) 335 | cleanedNames.append(name) 336 | return cleanedNames 337 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import random 3 | from copy import deepcopy 4 | 5 | import tensorflow as tf 6 | import pandas as pd 7 | import numpy as np 8 | 9 | import data.match_pool as pool 10 | 11 | from features.draftstate import DraftState 12 | import features.experience_replay as er 13 | import features.match_processing as mp 14 | from features.rewards import get_reward 15 | 16 | class BaseTrainer(): 17 | pass 18 | 19 | class DDQNTrainer(BaseTrainer): 20 | """ 21 | Trainer class for Double DQN networks. 
22 | Args: 23 | q_network (qNetwork): Q-network containing "online" and "target" networks 24 | n_epochs (int): number of times to iterate through given data 25 | training_matches (list(match)): list of matches to be trained on 26 | validation_matches (list(match)): list of matches to validate model against 27 | batch_size (int): size of each training set sampled from the replay buffer which will be used to update Qnet at a time 28 | buffer_size (int): size of replay buffer used 29 | load_path (string): path to reload existing model 30 | """ 31 | def __init__(self, q_network, n_epoch, training_data, validation_data, batch_size, buffer_size, load_path=None): 32 | num_episodes = len(training_data) 33 | print("***") 34 | print("Beginning training..") 35 | print(" train_epochs: {}".format(n_epoch)) 36 | print(" num_episodes: {}".format(num_episodes)) 37 | print(" batch_size: {}".format(batch_size)) 38 | print(" buffer_size: {}".format(buffer_size)) 39 | print("***") 40 | 41 | self.ddq_net = q_network 42 | self.n_epoch = n_epoch 43 | self.training_data = training_data 44 | self.validation_data = validation_data 45 | self.batch_size = batch_size 46 | self.buffer_size = buffer_size 47 | self.load_path = load_path 48 | 49 | self.replay = er.ExperienceBuffer(self.buffer_size) 50 | self.step_count = 0 51 | self.epoch_count = 0 52 | 53 | self.dampen_states = False 54 | self.teams = [DraftState.BLUE_TEAM, DraftState.RED_TEAM] 55 | 56 | self.N_TEMP_TRAIN_MATCHES = 25 57 | self.TEMP_TRAIN_PATCHES = ["8.13","8.14","8.15"] 58 | 59 | def train(self): 60 | """ 61 | Core training loop over epochs 62 | """ 63 | self.target_update_frequency = 10000 # How often to update target network 64 | 65 | stash_model = True # Flag for stashing a copy of the model 66 | model_stash_interval = 10 # Stashes a copy of the model this often 67 | 68 | # Number of steps to take before training. Allows buffer to partially fill. 69 | # Must be at least batch_size to avoid error when sampling from experience replay 70 | self.pre_training_steps = 10*self.batch_size 71 | assert(self.pre_training_steps <= self.buffer_size), "Replay not large enough for pre-training!" 72 | assert(self.pre_training_steps >= self.batch_size), "Buffer not allowed to fill enough before sampling!" 
73 | # Number of steps to force learner to observe submitted actions, rather than submit its own actions 74 | self.observations = 2000 75 | 76 | self.epsilon = 0.5 # Initial probability of letting the learner submit its own action 77 | self.eps_decay_rate = 1./(25*20*len(self.training_data)) # Rate at which epsilon decays per submission 78 | 79 | lr_decay_freq = 10 # Decay learning rate after a set number of epochs 80 | min_learning_rate = 1.e-8 # Minimum learning rate allowed to decay to 81 | 82 | summaries = {} 83 | summaries["loss"] = [] 84 | summaries["train_acc"] = [] 85 | summaries["val_acc"] = [] 86 | # Load existing model 87 | self.ddq_net.sess.run(self.ddq_net.online_ops["init"]) 88 | if(self.load_path): 89 | self.ddq_net.load(self.load_path) 90 | print("\nCheckpoint loaded from {}".format(self.load_path)) 91 | 92 | # Initialize target network 93 | self.ddq_net.sess.run(self.ddq_net.target_ops["target_init"]) 94 | 95 | for self.epoch_count in range(self.n_epoch): 96 | t0 = time.time() 97 | learning_rate = self.ddq_net.online_ops["learning_rate"].eval(self.ddq_net.sess) 98 | if((self.epoch_count>0) and (self.epoch_count % lr_decay_freq == 0) and (learning_rate>= min_learning_rate)): 99 | # Decay learning rate accoring to schedule 100 | learning_rate = 0.5*learning_rate 101 | self.ddq_net.sess.run(self.ddq_net.online_ops["learning_rate"].assign(learning_rate)) 102 | 103 | # Run single epoch of training 104 | loss, train_acc, val_acc = self.train_epoch() 105 | dt = time.time()-t0 106 | 107 | print(" Finished epoch {:2}/{}: lr: {:.4e}, dt {:.2f}, loss {:.6f}, train {:.6f}, val {:.6f}".format(self.epoch_count+1, self.n_epoch, learning_rate, dt, loss, train_acc, val_acc), flush=True) 108 | summaries["loss"].append(loss) 109 | summaries["train_acc"].append(train_acc) 110 | summaries["val_acc"].append(val_acc) 111 | 112 | if(stash_model): 113 | if(self.epoch_count>0 and (self.epoch_count+1)%model_stash_interval==0): 114 | # Stash a copy of the current model 115 | out_path = "tmp/models/{}_model_E{}.ckpt".format(self.ddq_net._name, self.epoch_count+1) 116 | self.ddq_net.save(path=out_path) 117 | print("Stashed a copy of the current model in {}".format(out_path)) 118 | 119 | self.ddq_net.save(path=self.ddq_net._path_to_model) 120 | return summaries 121 | 122 | def train_epoch(self): 123 | """ 124 | Training loop for a single epoch 125 | """ 126 | # We can't validate a winner for submissions generated by the learner, 127 | # so we will use a winner-less match when getting rewards for such states 128 | blank_match = {"winner":None} 129 | 130 | learner_submitted_actions = 0 131 | null_actions = 0 132 | 133 | # Shuffle match presentation order 134 | if(self.N_TEMP_TRAIN_MATCHES): 135 | path_to_db = "../data/competitiveMatchData.db" 136 | sources = {"patches":self.TEMP_TRAIN_PATCHES, "tournaments":[]} 137 | print("Adding {} matches to training pool from {}.".format(self.N_TEMP_TRAIN_MATCHES, path_to_db)) 138 | temp_matches = pool.match_pool(self.N_TEMP_TRAIN_MATCHES, path_to_db, randomize=True, match_sources=sources)["matches"] 139 | else: 140 | temp_matches = [] 141 | data = self.training_data + temp_matches 142 | 143 | shuffled_matches = random.sample(data, len(data)) 144 | for match in shuffled_matches: 145 | for team in self.teams: 146 | # Process match into individual experiences 147 | experiences = mp.process_match(match, team) 148 | for pick_id, experience in enumerate(experiences): 149 | # Some experiences include NULL submissions (usually missing bans) 150 | # The learner isn't allowed 
to submit NULL picks so skip adding these 151 | # to the buffer. 152 | state,actual,_,_ = experience 153 | (cid,pos) = actual 154 | if cid is None: 155 | null_actions += 1 156 | continue 157 | # Store original experience 158 | self.replay.store([experience]) 159 | self.step_count += 1 160 | 161 | # Give model feedback on current estimations 162 | if(self.step_count > self.observations): 163 | # Let the network predict the next action 164 | feed_dict = {self.ddq_net.online_ops["input"]:[state.format_state()], 165 | self.ddq_net.online_ops["valid_actions"]:[state.get_valid_actions()]} 166 | q_vals = self.ddq_net.sess.run(self.ddq_net.online_ops["valid_outQ"], feed_dict=feed_dict) 167 | sorted_actions = q_vals[0,:].argsort()[::-1] 168 | top_actions = sorted_actions[0:4] 169 | 170 | if(random.random() < self.epsilon): 171 | pred_act = random.sample(list(top_actions), 1) 172 | else: 173 | # Use model's top prediction 174 | pred_act = [sorted_actions[0]] 175 | 176 | for action in pred_act: 177 | (cid,pos) = state.format_action(action) 178 | if((cid,pos)!=actual): 179 | pred_state = deepcopy(state) 180 | pred_state.update(cid,pos) 181 | r = get_reward(pred_state, blank_match, (cid,pos), actual) 182 | new_experience = (state, (cid,pos), r, pred_state) 183 | 184 | self.replay.store([new_experience]) 185 | learner_submitted_actions += 1 186 | 187 | if(self.epsilon > 0.1): 188 | # Reduce epsilon over time 189 | self.epsilon -= self.eps_decay_rate 190 | 191 | # Use minibatch sample to update online network 192 | if(self.step_count > self.pre_training_steps): 193 | self.train_step() 194 | 195 | if(self.step_count % self.target_update_frequency == 0): 196 | # After the online network has been updated, update target network 197 | _ = self.ddq_net.sess.run(self.ddq_net.target_ops["target_update"]) 198 | 199 | # Get training loss, training_acc, and val_acc to return 200 | loss, train_acc = self.validate_model(self.training_data) 201 | _, val_acc = self.validate_model(self.validation_data) 202 | return (loss, train_acc, val_acc) 203 | 204 | def train_step(self): 205 | """ 206 | Training logic for a single mini-batch update sampled from replay 207 | """ 208 | # Sample training batch from replay 209 | training_batch = self.replay.sample(self.batch_size) 210 | 211 | # Calculate target Q values for each example: 212 | # For non-terminal states, targetQ is estimated according to 213 | # targetQ = r + gamma*Q'(s',max_a Q(s',a)) 214 | # where Q' denotes the target network. 215 | # For terminating states the target is computed as 216 | # targetQ = r 217 | updates = [] 218 | for exp in training_batch: 219 | start,_,reward,end = exp 220 | if(self.dampen_states): 221 | # To dampen states (usually done after major patches or when the meta shifts) 222 | # we replace winning rewards with 0. 223 | reward = 0. 224 | state_code = end.evaluate() 225 | if(state_code==DraftState.DRAFT_COMPLETE or state_code in DraftState.invalid_states): 226 | # Action moves to terminal state 227 | updates.append(reward) 228 | else: 229 | # Follwing double DQN paper (https://arxiv.org/abs/1509.06461). 230 | # Action is chosen by online network, but the target network is used to evaluate this policy. 231 | # Each row in predicted_Q gives estimated Q(s',a) values for all possible actions for the input state s'. 
232 | feed_dict = {self.ddq_net.online_ops["input"]:[end.format_state()], 233 | self.ddq_net.online_ops["valid_actions"]:[end.get_valid_actions()]} 234 | predicted_action = self.ddq_net.sess.run(self.ddq_net.online_ops["prediction"], feed_dict=feed_dict)[0] 235 | 236 | feed_dict = {self.ddq_net.target_ops["input"]:[end.format_state()]} 237 | predicted_Q = self.ddq_net.sess.run(self.ddq_net.target_ops["outQ"], feed_dict=feed_dict) 238 | 239 | updates.append(reward + self.ddq_net.discount_factor*predicted_Q[0,predicted_action]) 240 | 241 | # Update online net using target Q 242 | # Experience replay stores action = (champion_id, position) pairs 243 | # these need to be converted into the corresponding index of the input vector to the Qnet 244 | actions = np.array([start.get_action(*exp[1]) for exp in training_batch]) 245 | targetQ = np.array(updates) 246 | feed_dict = {self.ddq_net.online_ops["input"]:np.stack([exp[0].format_state() for exp in training_batch],axis=0), 247 | self.ddq_net.online_ops["actions"]:actions, 248 | self.ddq_net.online_ops["target"]:targetQ, 249 | self.ddq_net.online_ops["dropout_keep_prob"]:0.5} 250 | _ = self.ddq_net.sess.run(self.ddq_net.online_ops["update"],feed_dict=feed_dict) 251 | 252 | def validate_model(self, data): 253 | """ 254 | Validates given model by computing loss and absolute accuracy for data using current Qnet. 255 | Args: 256 | data (list(dict)): list of matches to validate against 257 | Returns: 258 | stats (tuple(float)): list of statistical measures of performance. stats = (loss,acc) 259 | """ 260 | buf = [] 261 | for match in data: 262 | # Loss is only computed for winning side of drafts 263 | team = DraftState.RED_TEAM if match["winner"]==1 else DraftState.BLUE_TEAM 264 | # Process match into individual experiences 265 | experiences = mp.process_match(match, team) 266 | for exp in experiences: 267 | _,act,_,_ = exp 268 | (cid,pos) = act 269 | if cid is None: 270 | # Skip null actions such as missing/skipped bans 271 | continue 272 | buf.append(exp) 273 | 274 | n_exp = len(buf) 275 | targets = [] 276 | for exp in buf: 277 | start,_,reward,end = exp 278 | state_code = end.evaluate() 279 | if(state_code==DraftState.DRAFT_COMPLETE or state_code in DraftState.invalid_states): 280 | # Action moves to terminal state 281 | targets.append(reward) 282 | else: 283 | feed_dict = {self.ddq_net.online_ops["input"]:[end.format_state()], 284 | self.ddq_net.online_ops["valid_actions"]:[end.get_valid_actions()]} 285 | predicted_action = self.ddq_net.sess.run(self.ddq_net.online_ops["prediction"], feed_dict=feed_dict)[0] 286 | 287 | feed_dict = {self.ddq_net.target_ops["input"]:[end.format_state()]} 288 | predicted_Q = self.ddq_net.sess.run(self.ddq_net.target_ops["outQ"], feed_dict=feed_dict) 289 | 290 | targets.append(reward + self.ddq_net.discount_factor*predicted_Q[0,predicted_action]) 291 | 292 | actions = np.array([start.get_action(*exp[1]) for exp in buf]) 293 | targets = np.array(targets) 294 | 295 | feed_dict = {self.ddq_net.online_ops["input"]:np.stack([exp[0].format_state() for exp in buf],axis=0), 296 | self.ddq_net.online_ops["actions"]:actions, 297 | self.ddq_net.online_ops["target"]:targets, 298 | self.ddq_net.online_ops["valid_actions"]:np.stack([exp[0].get_valid_actions() for exp in buf],axis=0)} 299 | 300 | loss, pred_q = self.ddq_net.sess.run([self.ddq_net.online_ops["loss"], self.ddq_net.online_ops["valid_outQ"]],feed_dict=feed_dict) 301 | 302 | accurate_predictions = 0 303 | rank_tolerance = 5 304 | for n in range(n_exp): 305 | state,act,_,_ 
= buf[n] 306 | submitted_action_id = state.get_action(*act) 307 | 308 | data = [(a,pred_q[n,a]) for a in range(pred_q.shape[1])] 309 | df = pd.DataFrame(data, columns=['act_id','Q']) 310 | df.sort_values('Q',ascending=False,inplace=True) 311 | df.reset_index(drop=True,inplace=True) 312 | df['rank'] = df.index 313 | submitted_row = df[df['act_id']==submitted_action_id] 314 | rank = submitted_row['rank'].iloc[0] 315 | if rank < rank_tolerance: 316 | accurate_predictions += 1 317 | 318 | accuracy = accurate_predictions/n_exp 319 | return (loss, accuracy) 320 | 321 | class SoftmaxTrainer(BaseTrainer): 322 | def __init__(self, network, n_epoch, training_data, validation_data, batch_size, load_path=None): 323 | num_episodes = len(training_data) 324 | print("***") 325 | print("Beginning training..") 326 | print(" train_epochs: {}".format(n_epoch)) 327 | print(" num_episodes: {}".format(num_episodes)) 328 | print(" batch_size: {}".format(batch_size)) 329 | print("***") 330 | 331 | self.model = network 332 | self.n_epoch = n_epoch 333 | self.training_data = training_data 334 | self.validation_data = validation_data 335 | self.batch_size = batch_size 336 | self.load_path = load_path 337 | 338 | self.step_count = 0 339 | self.epoch_count = 0 340 | 341 | self.teams = [DraftState.BLUE_TEAM, DraftState.RED_TEAM] 342 | 343 | self._buffer = er.ExperienceBuffer(max_buffer_size=20*len(training_data)) 344 | self._val_buffer = er.ExperienceBuffer(max_buffer_size=20*len(validation_data)) 345 | 346 | self.fill_buffer(training_data, self._buffer) 347 | self.fill_buffer(validation_data, self._val_buffer) 348 | 349 | def fill_buffer(self, data, buf): 350 | for match in data: 351 | for team in self.teams: 352 | experiences = mp.process_match(match, team) 353 | # remove null actions (usually missing bans) 354 | for exp in experiences: 355 | _,act,_,_ = exp 356 | cid,pos = act 357 | if(cid): 358 | buf.store([exp]) 359 | 360 | def sample_buffer(self, buf, n_samples): 361 | experiences = buf.sample(n_samples) 362 | states = [] 363 | actions = [] 364 | valid_actions = [] 365 | for (state, action, _, _) in experiences: 366 | states.append(state.format_state()) 367 | valid_actions.append(state.get_valid_actions()) 368 | actions.append(state.get_action(*action)) 369 | 370 | return (states, actions, valid_actions) 371 | 372 | def train(self): 373 | summaries = {} 374 | summaries["loss"] = [] 375 | summaries["train_acc"] = [] 376 | summaries["val_acc"] = [] 377 | 378 | lr_decay_freq = 10 379 | min_learning_rate = 1.e-8 # Minimum learning rate allowed to decay to 380 | 381 | stash_model = True # Flag for stashing a copy of the model 382 | model_stash_interval = 10 # Stashes a copy of the model this often 383 | 384 | # Load existing model 385 | self.model.sess.run(self.model.ops_dict["init"]) 386 | if(self.load_path): 387 | self.model.load(self.load_path) 388 | print("\nCheckpoint loaded from {}".format(self.load_path)) 389 | 390 | for self.epoch_count in range(self.n_epoch): 391 | learning_rate = self.model.ops_dict["learning_rate"].eval(self.model.sess) 392 | if((self.epoch_count>0) and (self.epoch_count % lr_decay_freq == 0) and (learning_rate>= min_learning_rate)): 393 | # Decay learning rate accoring to schedule 394 | learning_rate = 0.5*learning_rate 395 | self.model.sess.run(self.model.ops_dict["learning_rate"].assign(learning_rate)) 396 | 397 | t0 = time.time() 398 | loss, train_acc, val_acc = self.train_epoch() 399 | dt = time.time()-t0 400 | print(" Finished epoch {:2}/{}: lr: {:.4e}, dt {:.2f}, loss {:.6f}, train 
{:.6f}, val {:.6f}".format(self.epoch_count+1, self.n_epoch, learning_rate, dt, loss, train_acc, val_acc), flush=True) 401 | summaries["loss"].append(loss) 402 | summaries["train_acc"].append(train_acc) 403 | summaries["val_acc"].append(val_acc) 404 | 405 | if(stash_model): 406 | if(self.epoch_count>0 and (self.epoch_count+1)%model_stash_interval==0): 407 | # Stash a copy of the current model 408 | out_path = "tmp/models/{}_model_E{}.ckpt".format(self.model._name, self.epoch_count+1) 409 | self.model.save(path=out_path) 410 | print("Stashed a copy of the current model in {}".format(out_path)) 411 | 412 | self.model.save(path=self.model._path_to_model) 413 | return summaries 414 | 415 | def train_epoch(self): 416 | n_iter = self._buffer.buffer_size // self.batch_size 417 | 418 | for it in range(n_iter): 419 | self.train_step() 420 | 421 | loss, train_acc = self.validate_model(self._buffer) 422 | _, val_acc = self.validate_model(self._val_buffer) 423 | 424 | return (loss, train_acc, val_acc) 425 | 426 | def train_step(self): 427 | states, actions, valid_actions = self.sample_buffer(self._buffer, self.batch_size) 428 | 429 | feed_dict = {self.model.ops_dict["input"]:np.stack(states, axis=0), 430 | self.model.ops_dict["valid_actions"]:np.stack(valid_actions, axis=0), 431 | self.model.ops_dict["actions"]:actions, 432 | self.model.ops_dict["dropout_keep_prob"]:0.5} 433 | _ = self.model.sess.run(self.model.ops_dict["update"], feed_dict=feed_dict) 434 | 435 | def validate_model(self, buf): 436 | states, actions, valid_actions = self.sample_buffer(buf, buf.get_buffer_size()) 437 | 438 | feed_dict = {self.model.ops_dict["input"]:np.stack(states, axis=0), 439 | self.model.ops_dict["valid_actions"]:np.stack(valid_actions, axis=0), 440 | self.model.ops_dict["actions"]:actions} 441 | loss, train_probs = self.model.sess.run([self.model.ops_dict["loss"], self.model.ops_dict["probabilities"]], feed_dict=feed_dict) 442 | 443 | THRESHOLD = 5 444 | accurate_predictions = 0 445 | for k in range(len(states)): 446 | probabilities = train_probs[k,:] 447 | data = [(a, probabilities[a]) for a in range(len(probabilities))] 448 | df = pd.DataFrame(data, columns=['act_id','prob']) 449 | 450 | df.sort_values('prob',ascending=False,inplace=True) 451 | df.reset_index(drop=True,inplace=True) 452 | df['rank'] = df.index 453 | 454 | submitted_action_id = actions[k] 455 | submitted_row = df[df['act_id']==submitted_action_id] 456 | 457 | rank = submitted_row['rank'].iloc[0] 458 | if(rank < THRESHOLD): 459 | accurate_predictions += 1 460 | 461 | accuracy = accurate_predictions/len(states) 462 | return (loss, accuracy) 463 | -------------------------------------------------------------------------------- /src/features/draftstate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from data.champion_info import champion_name_from_id, valid_champion_id, get_champion_ids 3 | from .draft import Draft 4 | 5 | class InvalidDraftState(Exception): 6 | pass 7 | 8 | class DraftState: 9 | """ 10 | Args: 11 | team (int) : indicator for which team we are drafting for (RED_TEAM or BLUE_TEAM) 12 | champ_ids (list(int)) : list of valid championids which are available for drafting. 13 | num_positions (int) : number of available positions to draft for. Default is 5 for a standard 5x5 draft. 14 | 15 | DraftState is the class responsible for holding and maintaining the current state of the draft. 
For a given champion with championid c, 16 | that champion's state with respect to the draft is at most one of: 17 | - c is banned from selection. 18 | - c is selected as part of the opponent's team. 19 | - c is selected as one of our team's position. 20 | 21 | The state of the draft will be stored as a (numChampions) x (numRoles+2) numPy array. If state(c,k) = 1 then: 22 | - k = 0 -> champion c is banned from selection. 23 | - k = 1 -> champion c is selected as part of the enemy team. 24 | - 2 <= k = num_positions+1 -> champion c is selected as position k-1 in our draft. 25 | 26 | Default draft positions are interpreted as: 27 | Position 1 -> ADC/Marksman (Primary farm) 28 | Position 2 -> Middle (Secondary farm) 29 | Position 3 -> Top (Tertiary farm) 30 | Position 4 -> Jungle (Farming support) 31 | Position 5 -> Support (Primary support) 32 | """ 33 | # State codes 34 | BAN_AND_SUBMISSION = 101 35 | DUPLICATE_SUBMISSION = 102 36 | DUPLICATE_ROLE = 103 37 | INVALID_SUBMISSION = 104 38 | TOO_MANY_BANS = 105 39 | TOO_MANY_PICKS = 106 40 | invalid_states = [BAN_AND_SUBMISSION, DUPLICATE_ROLE, DUPLICATE_SUBMISSION, INVALID_SUBMISSION, 41 | TOO_MANY_BANS, TOO_MANY_PICKS] 42 | 43 | DRAFT_COMPLETE = 1 44 | BLUE_TEAM = Draft.BLUE_TEAM 45 | RED_TEAM = Draft.RED_TEAM 46 | BAN_PHASE = Draft.BAN 47 | PICK_PHASE = Draft.PICK 48 | 49 | def __init__(self, team, champ_ids = get_champion_ids(), num_positions = 5, draft = Draft('default')): 50 | #TODO (Devin): This should make sure that numChampions >= num_positions 51 | self.num_champions = len(champ_ids) 52 | self.num_positions = num_positions 53 | self.num_actions = (self.num_positions+1)*self.num_champions 54 | self.state_index_to_champ_id = {i:k for i,k in zip(range(self.num_champions),champ_ids)} 55 | self.champ_id_to_state_index = {k:i for i,k in zip(range(self.num_champions),champ_ids)} 56 | self.state = np.zeros((self.num_champions, self.num_positions+2), dtype=bool) 57 | self.picks = [] 58 | self.bans = [] 59 | self.selected_pos = [] 60 | 61 | self.team = team 62 | self.draft_structure = draft 63 | # Get phase information from draft 64 | self.BAN_PHASE_LENGTHS = self.draft_structure.PHASE_LENGTHS[DraftState.BAN_PHASE] 65 | self.PICK_PHASE_LENGTHS = self.draft_structure.PHASE_LENGTHS[DraftState.PICK_PHASE] 66 | 67 | # The dicts pos_to_pos_index and pos_index_to_pos contain the mapping 68 | # from position labels to indices to the state matrix and vice versa. 69 | self.positions = [i-1 for i in range(num_positions+2)] 70 | self.pos_indices = [1,0] 71 | self.pos_indices.extend(range(2,num_positions+2)) 72 | self.pos_to_pos_index = dict(zip(self.positions,self.pos_indices)) 73 | self.pos_index_to_pos = dict(zip(self.pos_indices,self.positions)) 74 | 75 | def reset(self): 76 | """ 77 | Resets draft state back to default values. 78 | Args: 79 | None 80 | Returns: 81 | None 82 | """ 83 | self.state[:] = False 84 | self.picks = [] 85 | self.bans = [] 86 | self.selected_pos = [] 87 | 88 | def get_valid_actions(self, form="mask"): 89 | """ 90 | Returns a valid actions for the current state. 91 | Input: 92 | self.state 93 | form (string): default returns actions as a mask. "list" returns actions as a list of ids 94 | Returns: 95 | action_ids (list[bool/int]): valid actions that can be taken from state. If form = "list" ids are returned as a list of action_ids, 96 | otherwise actions are returned as a boolean mask 97 | 98 | If the draft is complete or in an invalid state, get_valid_actions will return an empty list of actions. 
99 | """ 100 | # Check if draft is complete or invalid 101 | if(self.evaluate()): 102 | if(form == "list"): 103 | return np.array([]) 104 | else: 105 | return np.zeros_like(self.state[:,1:].reshape(-1)) 106 | 107 | sub_count = len(self.bans)+len(self.picks) 108 | phase = self.draft_structure.get_active_phase(sub_count) 109 | champ_available = np.logical_not(np.amax(self.state[:,:],axis=1)) 110 | pos_available = [pos for pos in range(1, self.num_positions+1) if pos not in self.selected_pos] 111 | valid_actions = np.zeros_like(self.state[:,1:]) 112 | if(phase == Draft.BAN): 113 | # only bans are (potentially) valid during ban phase 114 | valid_actions[:,0] = champ_available 115 | else: 116 | # only picks are (potentially) valid during pick phase 117 | for pos in pos_available: 118 | valid_actions[:,pos] = champ_available 119 | 120 | if(form == "list"): 121 | return np.nonzero(valid_actions.reshape(-1)) 122 | else: 123 | return valid_actions.reshape(-1) 124 | 125 | def is_submission_legal(self, champion_id, position): 126 | """ 127 | Checks if submission (champion_id, position) is a valid and legal submission for the current state. 128 | Returns: 129 | True if submission is legal, False otherwise. 130 | """ 131 | if(not self.can_ban(champion_id) or not self.can_pick(champion_id)): 132 | return False 133 | sub_count = len(self.bans)+len(self.picks) 134 | phase = self.draft_structure.get_active_phase(sub_count) 135 | if phase == DraftState.BAN_PHASE and position != -1: 136 | return False 137 | if phase == DraftState.PICK_PHASE: 138 | pos_index = self.get_position_index(position) 139 | is_pos_filled = np.amax(self.state[:,pos_index]) 140 | if(is_pos_filled): 141 | return False 142 | return True 143 | 144 | def get_champ_id(self,index): 145 | """ 146 | get_champ_id returns the valid champion ID corresponding to the given state index. Since champion IDs are not contiguously defined or even necessarily ordered, 147 | this mapping will not be trivial. If index is invalid, returns -1. 148 | Args: 149 | index (int): location index in the state array of the desired champion. 150 | Returns: 151 | champ_id (int): champion ID corresponding to index (as defined by champ_ids) 152 | """ 153 | if index not in self.state_index_to_champ_id.keys(): 154 | return -1 155 | return self.state_index_to_champ_id[index] 156 | 157 | def get_state_index(self,champ_id): 158 | """ 159 | get_state_index returns the state index corresponding to the given champion ID. Since champion IDs are not contiguously defined or even necessarily ordered, 160 | this mapping is non-trivial. If champ_id is invalid, returns -1. 161 | Args: 162 | champ_id (int): id of champion to look up 163 | Returns 164 | index (int): state index of corresponding champion id 165 | """ 166 | if champ_id not in self.champ_id_to_state_index.keys(): 167 | return -1 168 | return self.champ_id_to_state_index[champ_id] 169 | 170 | def get_position_index(self,position): 171 | """ 172 | get_position_index returns the index of the state matrix corresponding to the given position label. 173 | If the position is invalid, returns False. 174 | Args: 175 | position (int): position label to look up 176 | Returns: 177 | index (int): index into the state matrix corresponding to this position 178 | """ 179 | if position not in self.positions: 180 | return False 181 | return self.pos_to_pos_index[position] 182 | 183 | def get_position(self, pos_index): 184 | """ 185 | get_position returns the position label corresponding to the given position index into the state matrix. 
186 | If the position index is invalid, returns False. 187 | Args: 188 | pos_index (int): position index to look up 189 | Returns: 190 | position (int): position label corresponding to this position index 191 | """ 192 | if pos_index not in self.pos_indices: 193 | return False 194 | return self.pos_index_to_pos[pos_index] 195 | 196 | def format_state(self): 197 | """ 198 | Format the state so the Q-network can process it. 199 | Args: 200 | None 201 | Returns: 202 | A copy of self.state 203 | """ 204 | if(self.evaluate() in DraftState.invalid_states): 205 | raise InvalidDraftState("Attempting to format an invalid draft state for network input with code {}".format(self.evaluate())) 206 | 207 | return self.state.reshape(-1) 208 | 209 | def format_secondary_inputs(self): 210 | """ 211 | Produces secondary input information (information about filled positions and draft phase) 212 | to send to Q-network. 213 | Args: 214 | None 215 | Returns: 216 | Numpy vector of secondary network inputs 217 | """ 218 | if(self.evaluate() in DraftState.invalid_states): 219 | raise InvalidDraftState("Attempting to format an invalid draft state for network input with code {}".format(self.evaluate())) 220 | 221 | # First segment of information checks whether each position has been or not filled in the state 222 | # This is done by looking at columns in the subarray corresponding to positions 1 thru 5 223 | start = self.get_position_index(1) 224 | end = self.get_position_index(5) 225 | secondary_inputs = np.amax(self.state[:,start:end+1],axis=0) 226 | 227 | # Second segment checks if the phase corresponding to this state is a pick phase 228 | # This is done by counting the number of bans currently submitted. Note that this assumes 229 | # that state is currently a valid state. If this is not necessarily the case a check can be made using 230 | # evaluate(). 231 | submission_count = len(self.bans)+len(self.picks) 232 | phase = self.draft_structure.get_active_phase(submission_count) 233 | is_pick_phase = phase == DraftState.PICK_PHASE 234 | secondary_inputs = np.append(secondary_inputs, is_pick_phase) 235 | return secondary_inputs 236 | 237 | def format_action(self,action): 238 | """ 239 | Format input action into the corresponding tuple (champ_id, position) which describes the input action. 240 | Args: 241 | action (int): Action to be interpreted. Assumed to be generated as output of ANN. action is the index 242 | of the flattened 'actionable state' matrix 243 | Returns: 244 | (championId, position) (tuple of ints): Tuple of integer values which may be passed as arguments to either 245 | self.add_pick() or self.add_ban() depending on the value of position. If position = -1 -> action is a ban otherwise action 246 | is a pick. 247 | 248 | Note: format_action() explicitly indexes into 'actionable state' matrix which excludes the portion of the state 249 | matrix corresponding to opponent team submission. In practice this means that (cid, pos) = format_action(a) will 250 | never output pos = 0. 251 | """ 252 | # 'actionable state' is the sub-state of the state matrix with 'enemy picks' column removed. 253 | actionable_state = self.state[:,1:] 254 | if(action not in range(actionable_state.size)): 255 | raise "Invalid action to format_action()!" 256 | (state_index, position_index) = np.unravel_index(action,actionable_state.shape) 257 | # Action corresponds to a submission that we are allowed to make, ie. a pick or a ban. 
258 | # We can't make submissions to the enemy team, so the indicies corresponding to these actions are removed. 259 | # position_index needs to be shifted by 1 in order to correctly index into full state array 260 | position_index += 1 261 | position = self.get_position(position_index) 262 | champ_id = self.get_champ_id(state_index) 263 | return (champ_id,position) 264 | 265 | def get_action(self, champion_id, position): 266 | """ 267 | Given a (champion_id, position) submission pair. Return the corresponding action index in the flattened actionable state array. 268 | Args: 269 | champion_id (int): id of a champion to be picked/banned. 270 | position (int): Position of champion to be selected. The value of position determines if championId is interpreted as a pick or ban: 271 | position = -1 -> champion ban submitted. 272 | 0 < position <= num_positions -> champion selection submitted by our team for pos = position 273 | Returns: 274 | action (int): Action to be interpreted as index into the flattened actionable state vector. If no such action can be found, returns -1 275 | 276 | Note: get_action() explicitly indexes into 'actionable state' matrix which excludes the portion of the state 277 | matrix corresponding to opponent team submission. In practice this means that a = format_action(cid,pos) will 278 | produce an invalid action for pos = 0. 279 | """ 280 | state_index = self.get_state_index(champion_id) 281 | pos_index = self.get_position_index(position) 282 | if ((state_index==-1) or (pos_index not in range(1,self.state.shape[1]))): 283 | print("Invalid state index or position out of range!") 284 | print("cid = {}".format(champion_id)) 285 | print("pos = {}".format(position)) 286 | return -1 287 | # Convert position index for full state matrix into index for actionable state 288 | pos_index -= 1 289 | action = np.ravel_multi_index((state_index,pos_index),self.state[:,1:].shape) 290 | return action 291 | 292 | def update(self, champion_id, position): 293 | """ 294 | Attempt to update the current state of the draft and pick/ban lists with a given championId. 295 | Returns: True is selection was successful, False otherwise 296 | Args: 297 | champion_id (int): Id of champion to add to pick list. 298 | position (int): Position of champion to be selected. The value of position determines if championId is interpreted as a pick or ban: 299 | position = -1 -> champion ban submitted. 300 | position = 0 -> champion selection submitted by the opposing team. 301 | 0 < position <= num_positions -> champion selection submitted by our team for pos = position 302 | """ 303 | # Special case for NULL ban submitted. 304 | if (champion_id is None and position == -1): 305 | # Only append NULL bans to ban list (nothing done to state matrix) 306 | self.bans.append(champion_id) 307 | return True 308 | 309 | # Submitted picks of the form (champ_id, pos) correspond with the selection champion = champion_id in position = pos. 310 | # Bans are given pos = -1 and enemy picks pos = 0. However, this is not how they are stored in the state array. 311 | # Finally this doesn't match indexing used for state array and action vector indexing (which follow state indexing). 
312 | if((position < -1) or (position > self.num_positions) or (not valid_champion_id(champion_id))): 313 | return False 314 | 315 | index = self.champ_id_to_state_index[champion_id] 316 | pos_index = self.get_position_index(position) 317 | if(position == -1): 318 | self.bans.append(champion_id) 319 | else: 320 | self.picks.append(champion_id) 321 | self.selected_pos.append(position) 322 | 323 | self.state[index,pos_index] = True 324 | return True 325 | 326 | def display(self): 327 | #TODO (Devin): Clean up display to make it prettier. 328 | print("=== Begin Draft State ===") 329 | print("There are {num_picks} picks and {num_bans} bans completed in this draft. \n".format(num_picks=len(self.picks),num_bans=len(self.bans))) 330 | 331 | print("Banned Champions: {0}".format(list(map(champion_name_from_id, self.bans)))) 332 | print("Picked Champions: {0}".format(list(map(champion_name_from_id, self.picks)))) 333 | pos_index = self.get_position_index(0) 334 | enemy_draft_ids = list(map(self.get_champ_id, list(np.where(self.state[:,pos_index])[0]))) 335 | print("Enemy Draft: {0}".format(list(map(champion_name_from_id,enemy_draft_ids)))) 336 | 337 | print("Ally Draft:") 338 | for pos_index in range(2,len(self.state[0,:])): # Iterate through each position columns in state 339 | champ_index = np.where(self.state[:,pos_index])[0] # Find non-zero index 340 | if not champ_index.size: # No pick is found for this position, create a filler string 341 | draft_name = "--" 342 | else: 343 | draft_name = champion_name_from_id(self.get_champ_id(champ_index[0])) 344 | print("Position {p}: {c}".format(p=pos_index-1,c=draft_name)) 345 | print("=== End Draft State ===") 346 | 347 | def can_pick(self, champion_id): 348 | """ 349 | Check to see if a champion is available to be selected. 350 | Returns: True if champion is a valid selection, False otherwise. 351 | Args: 352 | champion_id (int): Id of champion to check for valid selection. 353 | """ 354 | return ((champion_id not in self.picks) and valid_champion_id(champion_id)) 355 | 356 | def can_ban(self, champion_id): 357 | """ 358 | Check to see if a champion is available to be banned. 359 | Returns: True if champion is a valid ban, False otherwise. 360 | Args: 361 | champion_id (int): Id of champion to check for valid ban. 362 | """ 363 | return ((champion_id not in self.bans) and valid_champion_id(champion_id)) 364 | 365 | def add_pick(self, champion_id, position): 366 | """ 367 | Attempt to add a champion to the selected champion list and update the state. 368 | Returns: True is selection was successful, False otherwise 369 | Args: 370 | champion_id (int): Id of champion to add to pick list. 371 | position (int): Position of champion to be selected. If position = 0 this is interpreted as a selection submitted by the opposing team. 372 | """ 373 | if((position < 0) or (position > self.num_positions) or (not valid_champion_id(champion_id))): 374 | return False 375 | self.picks.append(champion_id) 376 | self.selected_pos.append(position) 377 | index = self.get_state_index(champion_id) 378 | pos_index = self.get_position_index(position) 379 | self.state[index,pos_index] = True 380 | return True 381 | 382 | def add_ban(self, champion_id): 383 | """ 384 | Attempt to add a champion to the banned champion list and update the state. 385 | Returns: True is ban was successful, False otherwise 386 | Args: 387 | champion_id (int): Id of champion to add to bans. 
388 | """ 389 | if(not valid_champion_id(champion_id)): 390 | return False 391 | self.bans.append(champion_id) 392 | index = self.get_state_index(champion_id) 393 | self.state[index,self.get_position_index(-1)] = True 394 | return True 395 | 396 | def evaluate(self): 397 | """ 398 | evaluate checks the current state and determines if the draft as it is currently recorded is valid. 399 | Returns: value (int) - code indicating validitiy of state 400 | Valid codes: 401 | value = 0 -> state is valid but incomplete. 402 | value = DRAFT_COMPLETE -> state is valid and complete. 403 | Invalid codes: 404 | value = BAN_AND_SUBMISSION -> state has a banned champion selected for draft. This will also appear if a ban is submitted which matches an previously submitted champion. 405 | value = DUPLICATE_SUBMISSION -> state has a champion drafted which is already part of the opposing team or has already been selected by our team. 406 | value = DUPLICATE_ROLE -> state has multiple champions selected for a single role 407 | value = INVALID_SUBMISSION -> state has a submission that was included out of the draft phase order (ex pick during ban phase / ban during pick phase) 408 | """ 409 | # Check for duplicate submissions appearing in picks or bans 410 | duplicate_picks = set([cid for cid in self.picks if self.picks.count(cid)>1]) 411 | # Need to remove possible NULL bans as duplicates (since these may be legitimate) 412 | duplicate_bans = set([cid for cid in self.bans if self.bans.count(cid)>1]).difference(set([None])) 413 | if(len(duplicate_picks)>0 or len(duplicate_bans)>0): 414 | return DraftState.DUPLICATE_SUBMISSION 415 | 416 | # Check for submissions appearing in both picks and bans 417 | if(len(set(self.picks).intersection(set(self.bans)))>0): 418 | # Invalid state includes an already banned champion 419 | return DraftState.BAN_AND_SUBMISSION 420 | 421 | # Check for different champions that have been submitted for the same role 422 | for pos in range(2,self.num_positions+2): 423 | loc = np.argwhere(self.state[:,pos]) 424 | if(len(loc)>1): 425 | # Invalid state includes multiple champions intended for the same role. 426 | return DraftState.DUPLICATE_ROLE 427 | 428 | # Check for out of phase submissions 429 | num_bans = len(self.bans) 430 | num_picks = len(self.picks) 431 | sub_count = num_bans+num_picks 432 | 433 | if(num_bans > self.draft_structure.NUM_BANS): 434 | return DraftState.TOO_MANY_BANS 435 | if(num_picks > self.draft_structure.NUM_PICKS): 436 | return DraftState.TOO_MANY_PICKS 437 | 438 | # validation is tuple of form (target_ban_count, target_blue_pick_count, target_red_pick_count) 439 | validation = self.draft_structure.submission_dist[sub_count] 440 | num_opponent_sub = np.count_nonzero(self.state[:,self.get_position_index(0)]) 441 | num_ally_sub = num_picks - num_opponent_sub 442 | if self.team == DraftState.BLUE_TEAM: 443 | dist = (num_bans, num_ally_sub, num_opponent_sub) 444 | else: 445 | dist = (num_bans, num_opponent_sub, num_ally_sub) 446 | if(dist != validation): 447 | return DraftState.INVALID_SUBMISSION 448 | 449 | # State is valid, check if draft is complete 450 | if(num_ally_sub == self.num_positions and num_opponent_sub == self.num_positions): 451 | # Draft is valid and complete. Note that it isn't necessary 452 | # to have the full number of valid bans to register a complete draft. This is 453 | # because teams can be forced to forefit bans due to disciplinary factor (rare) 454 | # or they can elect to not submit a ban (very rare). 
455 | return DraftState.DRAFT_COMPLETE 456 | 457 | # Draft is valid, but not complete 458 | return 0 459 | 460 | if __name__=="__main__": 461 | state = DraftState(DraftState.BLUE_TEAM) 462 | print(state.evaluate()) 463 | print(state.num_actions) 464 | state.display() 465 | actions = state.get_valid_actions() 466 | print(actions) 467 | 468 | state.update(1,-1) 469 | state.update(2,-1) 470 | state.update(3,-1) 471 | state.update(4,-1) 472 | state.update(5,-1) 473 | state.update(6,-1) 474 | 475 | state.update(7,1) 476 | state.update(8,0) 477 | state.update(9,0) 478 | 479 | new_actions = state.get_valid_actions() 480 | print(new_actions) 481 | print("") 482 | for aid in range(len(new_actions)): 483 | if(new_actions[aid]): 484 | print(state.format_action(aid)) 485 | 486 | state.update(10,2) 487 | state.update(11,3) 488 | state.update(12,0) 489 | 490 | state.update(13,-1) 491 | state.update(14,-1) 492 | state.update(15,-1) 493 | state.update(16,-1) 494 | 495 | state.update(17,0) 496 | state.update(18,5) 497 | state.update(19,4) 498 | 499 | state.display() 500 | print(state.evaluate()) 501 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Swain Bot 2 | Created by Devin Light 3 | 4 | ~~A [web version of Swain Bot](https://swainbot.herokuapp.com) (hosted on Heroku) is available to play with. (Note: this model has a limited number of concurrent users and is much slower than the local version because it is hosted using the free tier offered by Heroku)~~ Since this project has been on the back-burner for a while, the Heroku-hosted version of Swain Bot has been taken down and will be updated TBD. 5 | 6 | ## Introduction 7 | ### What is League of Legends? 8 | League of Legends (abbreviated as LoL, or League) is a multiplayer online battle arena (MOBA) game developed by Riot Games which features two teams of five players each competing in head-to-head matches with the ultimate goal of destroying the opposing team's nexus structure. The game boasts millions of monthly players and a large competitive scene involving dozens of teams participating in both national and international tournaments. The game takes place across two broadly defined phases. In the first phase (or Drafting phase), each side takes turns assembling their team by selecting a unique character (called a champion) from a pool of 138 (as of this writing) without replacement. Then, in the second phase (or Game phase), each player in the match takes control of one of the champions chosen by their team and attempts to claim victory. Although not strictly required by the game, over the years players have grown to play their champion in one of five roles named after the location on the map in which they typically start the game, and often corresponding to the amount of in-game resources devoted to that player. The five roles are: 9 | 10 | - Position 1 (primary farm)-> ADC/Marksman1 11 | - Position 2 (secondary farm)-> Middle 12 | - Position 3 (tertiary farm)-> Top 13 | - Position 4 (farming support)-> Jungle 14 | - Position 5 (primary support)-> Support1 15 | 16 | 1 Traditionally the ADC and Support begin the game together in the same lane and are collectively called 'Bottom'. 17 | 18 | Each champion has a distinct set of characteristics and abilities that allows them to excel in certain situations while struggling in others. 
In order to maximize the odds of victory, it is important that the team assembled during the drafting phase simultaneously plays into one cohesive set of strengths and disrupts or plays well against the strengths of the opposing draft. There are two types of submissions made during the drafting phase. In the banning portions of the draft, champions are removed from the pool of allowed submissions, whereas champions are added to the roster of the submitting team during the pick phases. The drafting alternates between banning and picking until both teams have a full roster of five champions, at which point the game is played. The structure of the drafting phase is displayed in Figure 1. Note the asymmetry between teams (for example Blue bans first in ban phase one, while Red bans first in ban phase two) and between the phases themselves (ban phases always alternate sides, while pick phases "snake" between teams). 19 | 20 | ![Figure 1](common/images/league_draft_structure.png "Figure 1") 21 | 22 | ### What is Swain Bot? 23 | Swain Bot (named after the champion Swain, who is associated with being a ruthless master tactician/general) is a machine learning application built in Python using Google's Tensorflow framework. Swain Bot is designed to analyze the drafting phase of competitive League of Legends matches. Given a state of the draft which includes full information of our team's submissions (champions and positions) and partial information of the opponent's submissions (champions only), Swain Bot attempts to suggest picks and bans that are well-suited for the current state of our draft. 24 | 25 | ### What do we hope to do with Swain Bot? 26 | Knowing the best pick for a given draft situation can dramatically improve a team's chance for success. Our objective with Swain Bot is to help provide insight into League's crucial draft phase by attempting to answer questions like: 27 | - Can we estimate how valuable each submission is for a given state of the draft? 28 | - Is there a common structure or theme to how professional League teams draft? 29 | - Can we identify the differences between a winning and a losing draft? 30 | 31 | ## Assumptions and Limitations 32 | Every model tasked with approaching a difficult problem is predicated on some number of assumptions, which in turn define the boundaries within which the model can safely be applied. Swain Bot is no exception, so here we outline and discuss some of the explicit assumptions made going into the construction of the underlying model Swain Bot uses to make its predictions. Some of the assumptions are more impactful than others, and some could be removed in the future to improve Swain Bot's performance but are in place for now for various reasons. 33 | 34 | 1. Swain Bot is limited to data from recorded professionally played games from the "Big 5" regions (NALCS, EULCS, LCK, LPL, and LMS). Limiting potential data sources to competitive leagues is very restrictive when compared to the pool of amateur matches played on servers across the world. However, this assumption is in place as a result of changes in Riot's (otherwise exceptional) API which effectively randomizes the order in which the champion submissions for a draft are presented, rendering it impossible to recover the sequence of draft states that make up the complete information regarding the draft. Should the API be changed in the future, Swain Bot will be capable of learning from amateur matches as well. 35 | 36 | 2. 
Swain Bot does not receive information about either the patch the game was played on or the teams involved in the match. Not including the patch allows us to stretch the data as much as we can given the restricted pool. Although the effectiveness of a champion might change as they are tuned between patches, it is unlikely that they are changed so much that the situations in which the champion would normally be picked are dramatically different. Nevertheless, substantial champion changes have occurred in the past, usually in the form of a total redesign. Additionally, although team data for competitive matches is available during the draft, Swain Bot's primary objective is to identify the most effective submissions for a given draft state rather than predict what a specific team might select in that situation. It would be possible to combine Swain Bot's output with information about a team's drafting tendencies (using ensemble techniques like stacking) to produce a final prediction which both suits the draft and is likely to be chosen by the team. However, we will leave this work for later. 37 | 38 | 3. Swain Bot's objective is to associate the combination of a state and a potential submission with a value and to suggest taking the action which has the highest value. This valuation should be based primarily on what is likely to win the draft (or move us towards a winning state), and partly on what is likely to be done. Although these two goals may be correlated (a champion that is highly valued might also be the one selected most frequently), they are not necessarily the same since, for example, teams may be biased towards or against specific strategies or champions. 39 | 40 | 4. Swain Bot's objective of estimating the value of submissions for a given draft state is commonly approached using techniques from Reinforcement Learning (RL). RL methods have been successfully used in a variety of situations such as teaching robots how to move, playing ATARI games, and even [playing DOTA2](https://blog.openai.com/dota-2/). A common element to most RL applications is the ability to automatically explore and evaluate states as they are encountered in a simulated environment. However, Swain Bot is not capable of automatically playing out the drafts it recommends in order to evaluate them (yet..) and so is dependent on the data it observes originating from games that were previously played. This scenario is reminiscent of a Supervised Learning (SL) problem called behavioral cloning, where the task is to learn and replicate the policy outlined by an expert. However, behavioral cloning does not include the estimation of values associated with actions and instead attempts to directly mimic the expert policy. Swain Bot instead implements an RL algorithm to estimate action values (Q-Learning), but trained using expert-generated data. In practice this means that the predictions made by Swain Bot can only have an expectation of accuracy when following trajectories that are similar to the paths prescribed by the training data. 41 | 42 | ## Methods 43 | This section is not designed to be too technical, but rather to give some insight into how Swain Bot is implemented and some important modifications that helped with the learning process. 
For some excellent and thorough discussions on RL, check out the following: 44 | - [David Silver's course on RL](http://www0.cs.ucl.ac.uk/staff/d.silver/web/Teaching.html) [(with video lectures)](https://www.youtube.com/watch?v=2pWv7GOvuf0) 45 | - [Reinforcement Learning](http://incompleteideas.net/book/the-book-2nd.html) by Sutton and Barto 46 | - [The DeepMind ATARI paper](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf) 47 | - [Dueling DQNs](https://arxiv.org/pdf/1511.06581.pdf) 48 | - And finally a [few](http://outlace.com/rlpart3.html), [useful](https://www.intelnervana.com/demystifying-deep-reinforcement-learning/), [tutorials](https://medium.com/emergent-future/simple-reinforcement-learning-with-tensorflow-part-0-q-learning-with-tables-and-neural-networks-d195264329d0) 49 | 50 | ### Representing Draft States and Actions 51 | Each of the _N_ eligible champions (138 as of this writing) in a draft is represented by a unique `champion_id` integer, and every position in the game (five positions per team plus banned champions) is given by a `position_id`. An _action_ (or _submission_) to the draft is defined as a tuple of the form `(champion_id, position_id) = (i,j)` representing the selection of champion `i` to position `j` in the draft. We can represent the _draft state_ as a boolean matrix _S_ where _S_(i,j) = 1 if the ith champion is present in the draft at the jth position. The number of columns in _S_ is determined by how much information about the positions is available to the drafter: 52 | - In a _completely informed_ draft all position information is known, so _S_ is an `N x 11` matrix (10 positions + bans). 53 | - In a _partially informed_ draft, position information is only known for the drafter's team, whereas only the `champion_id`s are known for the opponent's team. As a result, _S_ is given by an `N x 7` matrix (5 positions + bans + enemy champion ids). 54 | 55 | Note that _S_ is a sparse matrix since for any given state of a draft, there are no more than 10 picks and 10 bans that have been submitted, so there are no more than 20 non-zero entries in _S_ at any given time. Swain Bot operates using partially informed draft states as inputs, which may be obtained by projecting the five columns in the completely informed state corresponding to the positions in the opponent's draft onto a single column. Finally, we define the _actionable state_ to be the submatrix of _S_ corresponding to the actions the drafter may submit -- that is, the columns corresponding to bans as well as the drafter's five submittable positions. 56 | 57 | ### Drafting as a Markov Decision Process (MDP) 58 | An individual experience (sometimes called a memory) observed during the draft can be recorded as a tuple of the form `e_k = (s_k,a_k,r_k,s')` where `s_k` is the initial state of the draft before an action was taken, `a_k` is the action submitted, `r_k` is the reward observed as a result of taking `a_k`, and `s'` is the successor state transitioned to away from `s_k`. For either completely or partially informed states, the draft can be fully recovered using the sequence of experiences `(e_0, e_1,..., e_n)` transitioned through during the draft. This sequence defines a Markov chain because given the current state _s_, the value of the successor state _s'_ is independent of the states that were transitioned through before _s_. In other words, the possible states we are able to transition to away from _s_ depend only on _s_ itself, and not on the states that were seen on the way to _s_. 
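To make the encoding concrete, here is a minimal numpy sketch of a partially informed draft state and one experience tuple. The column layout, the helper names (`empty_state`, `submit`), and the example ids are assumptions made purely for illustration; in the repository the `DraftState` class handles the actual index bookkeeping.

```python
import numpy as np

N_CHAMPIONS = 138  # champion pool size quoted above
N_COLUMNS = 7      # partially informed state: bans + 5 ally positions + 1 enemy column

def empty_state():
    """Boolean draft-state matrix S, with S[i, j] = True if champion i occupies slot j."""
    return np.zeros((N_CHAMPIONS, N_COLUMNS), dtype=bool)

def submit(state, champion_id, column):
    """Apply an action (champion_id, position) to a copy of the state and return the successor."""
    successor = state.copy()
    successor[champion_id, column] = True
    return successor

# One experience e_k = (s_k, a_k, r_k, s') as described above; the reward of 0.0 is a
# placeholder for a non-terminal step.
s_k = empty_state()
a_k = (40, 1)              # hypothetical: champion with id 40 submitted to column 1
s_next = submit(s_k, *a_k)
e_k = (s_k, a_k, 0.0, s_next)
```

However the columns are arranged, no more than 20 entries of the matrix are ever set, so the state stays sparse throughout the draft.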
59 | 60 | To complete the description of drafting as an MDP we need to define a reward schedule and discount factor. The discount factor is a scalar value between 0 and 1 that governs the present value of future expected rewards. Two common reasons to use a discount factor are to express uncertainty about the potential value of the future and to capture the extra value of taking an immediate reward over a delayed one (e.g. if the reward is financial, an immediate reward is worth more than a delayed reward because that immediate reward can then be used to earn additional interest). Typical discount factor values are in the range `0.9` to `0.99`. Swain Bot uses `discount_factor = 0.9`. 61 | 62 | The reward schedule is a vital component of the MDP and determines what policy the model will ultimately converge towards. As previously discussed, Swain Bot's long-term objective is to select actions which move the draft towards a winning state while in the short term placing some value on actions which are likely to be taken. The ideal reward schedule should combine these two goals so that Swain Bot predicts both good and probable actions. We will approach this by associating larger magnitude rewards with _terminal_ states and smaller magnitude rewards with non-terminal states. A terminal state _s_ occurs in one of the three scenarios where _s_ represents: 63 | 1. a valid, complete draft which resulted in a win 64 | 2. a valid, complete draft which resulted in a loss 65 | 3. an invalid draft (which cannot be played) 66 | 67 | All other states are valid, but non-terminal. An invalid state _s_ is one in which one or more of the following conditions are satisfied where _s_ represents: 68 | 1. an incorrect number of picks or bans for the phase of the draft described by that state (e.g. picks submitted during Ban Phase 1, too many picks submitted in a single phase, two consecutive picks associated with red side during Pick Phase 2, etc.) 69 | 2. at least one champion selected in more than one position (e.g. picked and banned, picked by both teams, or picked by a team in more than one role) 70 | 3. at least one non-ban actionable position with more than one champion submitted to that position. For partially complete drafts the opposing team position must also have no more than five submissions represented. 71 | 72 | It's reasonable to have the network (or a secondary network) infer which actions lead to invalid drafts, and use that information to avoid predicting them. This essentially amounts to learning the rules of drafting. However we can get away with a smaller network by filtering out the illegal actions before making predictions. This helps significantly reduce the amount of training time required before observing reasonable results. 73 | 74 | The empirically determined reward schedule is defined in two parts depending on if _s_ is a terminal state. If _s_ is terminal, the reward is given by 75 | 76 | 77 | 78 | If _s_ is non-terminal, the reward has the form 79 | 80 | 81 | 82 | where _a_* is the action taken during the original memory. 83 | 84 | ### Deep Q-Learning (DQN) 85 | With the framework describing drafting as an MDP, we can apply a Q-Learning algorithm to estimate `Q(s,a)`, the maximum expected future return by taking action `a` from state `s`. With 138 total champions and 20 chosen at a time to appear in the final draft state, there are roughly `6.07x10^{23}` possible ending states, making a tabular Q-learning method out of the question. 
Instead, we opt to use a straightforward fully connected neural network to estimate the value function for an arbitrary state. The selected model's architecture consists of 4 total layers: 86 | - 1 One-hot encoded input layer (representing draft state) 87 | - 2 FC + ReLU layers with `1024` nodes each 88 | - 1 FC linearly activated output layer with `n_champ x n_pos+1` nodes (representing actions from the actionable state) 89 | 90 | Regularization is present in the model via dropout in between the FC layers. In order to help reduce the amount of time the model spends learning the underlying draft structure, the first FC layer is given several supplementary boolean inputs corresponding to which positions have been filled in the draft as well as if the draft is currently in a banning phase. These inputs aren't strictly necessary since the value of these variables is directly inferrable from the state matrix, but doing so substantially reduces the amount of training time required before the model begins producing legal submissions for input states. There are also three important modifications to the fundamental DQN algorithm that are necessary in order to ensure that the Q-learning is stable: 91 | 1. Data augmentation 92 | 2. Experience Replay (ER) 93 | 3. Double DQN (DDQN) 94 | 95 | Data augmentation refers to modifying existing training data in order to effectively increase the size of the training set and help reduce model overfitting. The techniques used to modify the data include adding random noise to training data or cropping, mirroring, and rotating images. For the purposes of data augmentation in Swain Bot, we identify two types of augmentable submissions: sequential picks and bans. Referring to the original image of the drafting structure, notice that several of the submissions made during the two pick phases are made sequentially: red side's first two picks, blue side's second and third picks, and blue side's fourth and fifth picks. Swain Bot makes predictions one action at a time, so the order in which these submissions are presented is preserved and the input state for the second submission will include the first submission as a result. However, since these submissions are made back-to-back, in practice sequential picks use the same initial state and the order in which they are presented is arbitrary. A simple augmentation which helps Swain Bot emulate this common initial state is to randomly exchange the order in which sequential picks in the draft are executed. For example, if a sample draft includes two memories corresponding to red side sequentially picking (Kalista, ADC) and then (Leona, Support), half of the time the order in which these picks are learned from would be reversed. 96 | 97 | The second augmented submissions are bans that share a common phase. Technically the drafting structure alternates between teams during bans, meaning that the initial state for each ban is distinct and as a result the order in which they are submitted is important. On the other hand, the purpose behind banning is to either protect your draft (by removing strong submissions against your team) or attack the enemy draft (by removing strong submissions to their team). In either circumstance, the bans we submit do not depend on the opponent's submissions _except in the case where they ban a champion which we would otherwise ban_. Unlike picks, however, our opponents "sniping" a ban actually _benefits_ our draft by effectively allowing our team to make an additional ban. 
This effect can be approximated by randomly shuffling the order in which the bans of a given phase are presented, just like we did with sequential picks. Note that bans cannot be shuffled across phases because bans made during the second phase are made with information about 3/5ths of the picks for each team available. 98 | 99 | Experience replay provides a mechanism for separating the generation of memories from learning from those memories. In experience replay, each experience associated with a draft is stored into a pool of experiences spanning many drafts. The Q-learning update is applied to a minibatch of randomly sampled experiences contained within the pool. This is important because consecutive experiences generated from the same draft are strongly correlated and learning from them all simultaneously using SGD is not only inefficient, but may also lead to a local (suboptimal) minimum. By randomizing the samples drawn from the replay buffer, the correlation between experiences is at least broken. Additionally, each memory is potentially used in multiple updates, improving overall data efficiency. 100 | 101 | The default DQN algorithm selects actions "greedily" by taking the maximum over the estimated action-values. A side effect of this maximization is that the DQN tends to learn overestimated values. Unfortunately this over-optimism is often non-uniformly distributed across actions and can degrade the performance of the learned policy. Furthermore, this overestimation also tends to grow as the number of actions increases. There are 822 (137 champions each selectable in 6 positions) possible actions during each stage of drafting. As a result, it is desirable to control this overestimation as much as possible. The DDQN algorithm proposed by van Hasselt et al. attempts to limit this overestimation by pseudo-decoupling action selection from evaluation by utilizing two networks: an "online" network and a "target" network. The online network represents the most up-to-date parameter estimates, while the target network is a periodic snapshot of the online network. In the simplest terms, the original update for DQN 102 | 103 | `update = reward + discount_factor*max_a'{Q(s',a')}` 104 | 105 | is replaced with 106 | 107 | `update = reward + discount_factor*Q_target(s', max_a'{Q_online(s',a')})`. 108 | 109 | Note that this doesn't truly decouple action selection and evaluation because the target network is a copy of a previous online network. The goal of DDQN is to be the simplest modification of DQN in the direction of a truly decoupled algorithm (like double Q-learning) in order to get most of the benefit with the smallest computational overhead required. 110 | 
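As a rough illustration of the network and the DDQN target described above, here is a minimal sketch using `tf.keras` and numpy. The 1024-node FC+ReLU layers, linear output, and `discount_factor = 0.9` come from the write-up; the dropout rate, the function names, and the omission of the supplementary boolean inputs are assumptions, and the repository's actual implementation is built on a lower-level TensorFlow graph (`qNetwork.py`).

```python
import numpy as np
import tensorflow as tf

def build_q_network(n_state_inputs, n_actions, dropout_rate=0.5):
    """Two 1024-node FC+ReLU layers with dropout and a linear output over the actionable state."""
    return tf.keras.Sequential([
        tf.keras.layers.Dense(1024, activation="relu", input_shape=(n_state_inputs,)),
        tf.keras.layers.Dropout(dropout_rate),
        tf.keras.layers.Dense(1024, activation="relu"),
        tf.keras.layers.Dropout(dropout_rate),
        tf.keras.layers.Dense(n_actions),  # linear activation -> estimated Q(s, a)
    ])

def ddqn_targets(rewards, online_q_next, target_q_next, discount_factor=0.9):
    """DDQN update targets for a minibatch: the online network selects a', the target network evaluates it."""
    best_actions = np.argmax(online_q_next, axis=1)                       # argmax_a' Q_online(s', a')
    selected = target_q_next[np.arange(len(best_actions)), best_actions]  # Q_target(s', a')
    return rewards + discount_factor * selected
```

In a full training loop the target network's weights would be refreshed from the online network periodically, and the bootstrap term would normally be dropped for terminal successor states.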
111 | ## Analysis 112 | ### Evaluating the Model 113 | In addition to the "eyeball test" of Swain Bot's predictions (i.e. no illegal submissions, correct number of roles, overall "meta-feel" of drafts, etc.), we're also interested in a quantitative measure of performance. 114 | 115 | One approach is to treat predictions as we would with a classifier and measure the fraction of predictions which agree with what was actually submitted in a winning draft. However, it's important to recall that our objective is to predict valuable submissions which may not necessarily overlap with what a specific team is likely to submit. It is often the case that multiple submissions are suited for the draft and as a result each have roughly equal value (this is particularly true for bans and early submissions). Selecting an action from amongst these picks is mostly a function of the biases of the drafting team. Since team identities aren't included as part of the input, it is unrealistic to expect the model to match the exact submission made by every team. A simple way to try and compensate for this is to group the top `k` submissions and regard these as a set of "good" picks according to the model. Then we measure accuracy as the fraction of submissions made that are contained in the predicted "good" submission pool for each state. 116 | 117 | Another approach is to examine the difference in estimated Q-values between the top prediction (`max_a{Q(s,a)}`) and the actual submission (`Q(s,a*)`). The difference between these two values estimates how far off the actual action that was submitted is from taking over as the top prediction. If both the top recommendation and `a*` are really good submissions for this state, this difference should be relatively small. If we use this to compute a normalized mean squared error over each state we are predicting for, we should get an estimate of the model performance: 118 | 119 | 120 | 121 | Note that if the model were to assign all actions the same value then this measure of error would be trivially zero. So just like the classification measure of accuracy, this measure of error is not perfect. Nevertheless, the combination gives some insight into how the model performs. 122 | 
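For reference, the two measures could be computed along these lines from a batch of per-state Q-value vectors and the actions that were actually submitted. Because the exact normalization used in the error formula is not reproduced here, normalizing each gap by the spread of Q-values in that state is an assumption (chosen so that a constant value function gives zero error); the function names are likewise illustrative.

```python
import numpy as np

def top_k_accuracy(q_values, submitted_actions, k=5):
    """Fraction of actual submissions that fall inside the model's top-k "good" set."""
    hits = 0
    for q, a in zip(q_values, submitted_actions):
        top_k = np.argsort(q)[-k:]   # indices of the k highest-valued actions
        hits += int(a in top_k)
    return hits / len(submitted_actions)

def normalized_q_error(q_values, submitted_actions):
    """Root-mean-square gap between max_a Q(s,a) and Q(s,a*), normalized per state."""
    gaps = []
    for q, a in zip(q_values, submitted_actions):
        spread = q.max() - q.min()
        gaps.append((q.max() - q[a]) / spread if spread > 0 else 0.0)
    return float(np.sqrt(np.mean(np.square(gaps))))
```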
123 | ### Training Data 124 | Competitive match data was pulled from [lol.gamepedia.com](https://lol.gamepedia.com). 125 | 126 | The model was trained in two stages. For the first stage, data was obtained from matches played during the 2017 Summer Season and 2017 Worlds Championship Qualifiers for the five major international regions: 127 | - North America (NA LCS) 128 | - Europe (EU LCS) 129 | - China (LPL) 130 | - Korea (LCK) 131 | - Taiwan, Hong Kong, & Macau (LMS) 132 | 133 | Between the two training stages the model underwent a dampening iteration (in which the value of all predictions was reduced) in order to simulate a change in metas associated with the gap in time between the Summer Season and Worlds Championships. The second stage of training used data from the 119 matches played during the 2017 World Championship with 11 randomly selected matches from the knockout stages held out for validation. The model learned on each match in the training set for 100 epochs (i.e. each match was seen 100 times in expectation). The model was trained using a smaller learning rate `alpha_0 = 1.0e-05` which was halved every 10 epochs until it reached a minimum value of `alpha_f = 1.0e-08`. 134 | 135 | ### Validation Matches 136 | The figures below illustrate the drafts from each match in the validation set. The left side of the draft represents the submissions made by the blue side, while the right side depicts submissions made by the red side. 137 | 138 | 139 | 140 | 141 | The distribution of the eleven matches according to the stage of Worlds in which they were played was: 142 | - Play-in stage: 2 matches 143 | - Group stage: 5 matches 144 | - Knockout stage: 3 matches 145 | - Finals: 1 match 146 | 147 | The table below lists the classification (top 1), "good" set classification (top 5), and normalized root mean square error (l2 error) for three categories: all submissions predicted (full), all submissions excluding the first phase of bans (no round 1 bans), and picks only (no bans). 148 | 149 | ``` 150 | Norm Information: 151 | Full 152 | Num_predictions = 110 153 | top 1: count 19 -> acc: 0.1727 154 | top 5: count 62 -> acc: 0.5636 155 | l2 error: 0.02775 156 | --- 157 | No Round 1 Bans 158 | Num_predictions = 77 159 | top 1: count 17 -> acc: 0.2208 160 | top 5: count 47 -> acc: 0.6104 161 | l2 error: 0.03004 162 | --- 163 | No Bans 164 | Num_predictions = 55 165 | top 1: count 14 -> acc: 0.2545 166 | top 5: count 39 -> acc: 0.7091 167 | l2 error: 0.02436 168 | ``` 169 | 170 | Starting with the top 1 classification accuracy it's apparent that predicting the exact submission at every stage in the draft is difficult for the model, which achieves an abysmal 17-25% accuracy across the three categories. This is likely due to a combination of lacking information about which teams are drafting (which would allow the model to distinguish between submissions that are either "comfort picks" or picks that a team is biased against) and the weighting difference between winning submissions and likely submissions. For the top 5 classification accuracy the model improves significantly, particularly when the first three bans are ignored. The stable l2 error also indicates that the model at least associates elevated values with the submissions from winning drafts even if they are not the exact submission it predicts. Finally, the jump in performance between the full set of predictions and the predictions excluding the first phase of banning generally holds true. In contrast, the difference in performance after further removing the second phase of bans is smaller. This suggests that earlier submissions in the draft are significantly more difficult to predict than later ones. This might be due to a combination of factors. First, since the submissions are the furthest away from the most rewarding states there is a large uncertainty in associating the reward observed at the end of a draft with the first few selections submitted. Second, even after it's complete, the first phase of bans contributes relatively little information to the draft when compared with the information gained after several picks have been made. Finally, the first phase ban submissions are perhaps the most significantly influenced by team biases since they tend to revolve around removing the picks you and your opponent are likely to make. 171 | 172 | It's also interesting to look at the types of misclassifications the model made on the picks submitted in this data set. Naturally, many of the misclassifications occurred with the model predicting one set of meta picks but the submitting team selecting something within the meta but outside that list (usually within the top 10 or so predictions). In particular the model tended to under-value using a first pick on Janna as a support, instead often predicting ADCs or Junglers in that slot. Then there were mistakes that occurred when teams opted for pocket picks to gain an advantage by either selecting a seldom seen champion or by "flexing" a champion to a secondary role. These submissions were the most surprising to the model in the sense that they often lay way outside the model's top five predictions. Below we highlight a few of these picks: 173 | 174 | 175 | 176 | The C9 v LYN match was one of a very small number of games where Kalista was not removed in the first phase of bans. 
Although she certainly defined the meta (with a staggering 100% presence in the drafting phase at Worlds 2017), there weren't enough matches played with her involved for the model to recommend her over the more commonly seen ADCs like Xayah, Tristana, and Kog'maw. This illustrates that the model does not connect "must ban" with "must pick" at the moment , although this is often the case with the strongest champions in the meta. C9's follow-up pick on Thresh was surprising partly because Thresh was in his own right an uncommon pick (played only 5 times total), but also because the model lacked context for his synergy with Kalista. In SSG v SKT, Zac was a surprising pick because he had only appeared 5 times in the tournament, and this match was his only win. The SKT v MSF game was the only game involving flexing Trundle to support at the tournament and RNG was the only team to pick Soraka (although they were successful with it). Ultimately it is a challenging problem for the model to accurately predict niche picks without overfitting. 177 | 178 | For completeness here is the table for all 2017 Worlds matches (including the play-in stages): 179 | ``` 180 | Norm Information: 181 | Full 182 | Num_predictions = 1190 183 | top 1: count 549 -> acc: 0.4613 184 | top 5: count 905 -> acc: 0.7605 185 | l2 error: 0.01443 186 | --- 187 | No Round 1 Bans 188 | Num_predictions = 833 189 | top 1: count 513 -> acc: 0.6158 190 | top 5: count 744 -> acc: 0.8932 191 | l2 error: 0.01136 192 | --- 193 | No Bans 194 | Num_predictions = 595 195 | top 1: count 371 -> acc: 0.6235 196 | top 5: count 525 -> acc: 0.8824 197 | l2 error: 0.01049 198 | ``` 199 | Obviously since the model was trained on the majority of these matches it performs much better. 200 | 201 | Worlds 2017 was dominated by the "Ardent Censer" meta which favored hard-scaling position 1 carries combined with position 5 supports who could abuse the item Ardent Censer (an item which amplified the damage output of the position 1 pick). This made using early picks to secure a favorable bot lane matchup extremely popular. The remaining pick from the first phase tended to be spent selecting a safe jungler like Jarvan IV, Sejuani, or Gragas. As a result, if we look at the distribution of positions made during the first phase of each of the 119 drafts conducted at 2017 Worlds we can see a strong bias against the solo lanes (positions 2 and 3). 202 | 203 | ``` 204 | Phase 1: Actual 205 | Position 1: Count 111, Ratio 0.311 206 | Position 2: Count 28, Ratio 0.0784 207 | Position 3: Count 24, Ratio 0.0672 208 | Position 4: Count 92, Ratio 0.258 209 | Position 5: Count 102, Ratio 0.286 210 | ``` 211 | 212 | We can compare this with the positions of the top 5 recommendations made by Swain Bot during the first pick phase: 213 | 214 | ``` 215 | Phase 1 Recommendations: 216 | Position 1: Count 611, Ratio 0.342 217 | Position 2: Count 179, Ratio 0.1 218 | Position 3: Count 248, Ratio 0.139 219 | Position 4: Count 305, Ratio 0.171 220 | Position 5: Count 442, Ratio 0.248 221 | ``` 222 | Swain Bot agrees with the meta in using early picks to secure a bot lane. However, by comparison it is more likely to suggest a solo lane pick in the first phase instead of a jungler. This effect was also seen in the actual drafts towards the end of the tournament where solo lane picks like Galio and Malzahar became increasingly valuable and took over what would have likely been jungler picks in the earlier stages. 
223 | 224 | ## Looking Ahead 225 | Even with the promising results so far, Swain Bot is far from complete. Here are just a few things that could be worth looking into for the future: 226 | 227 | - Compared with the massive number of possibilities for drafting, we're still limited to the relatively tiny pool of drafts coming from competitive play. As a result Swain Bot's ability to draft is really limited to the types of picks likely to be seen in the professional leagues (for fun try asking Swain Bot what to pick after adding champions in positions that are rarely/never seen in competitive play like Teemo or Heimerdinger). Ideally we would have access to detailed draft data for the millions of games played across all skill levels through [Riot's API](https://developer.riotgames.com), but unfortunately the API does not preserve submissions in draft order (yet). If we could have one wish this would be it. There is some hope for the future with the recently announced competitive 5v5 [Clash mode](https://nexus.leagueoflegends.com/en-us/2017/12/dev-clash/). 228 | 229 | - Build a sequence of models each using a small number of patches or a single patch for data. This could help improve Swain Bot's meta adaptation between patches by ensembling this sequence to make predictions from the most recent "meta history" while data from a new patch is being processed. 230 | 231 | - Including some form of team data for competitive drafts, either as a model input or as an additional layer on top of the current model structure. Swain Bot's current iteration implicitly assumes that every team is able to play every champion at the highest level. Even for the best players in the world, this is certainly not true. By adding information about the biases of the teams drafting, we could improve Swain Bot's ability to make suggestions which suit both the meta and the team. 232 | 233 | - Exploring alternative models (in particular Actor-Critic Policy Gradient). The DDQN model Swain Bot implements is by no means a poor one, but in the constantly evolving world of machine learning it is no longer cutting-edge. In particular our implementation of DDQN is deterministic, meaning that when presented with the same input state the model will always output the same action. If recommendations made by Swain Bot were to be used in practice to assemble a draft, this could be potentially abused by our opponents. Policy gradient methods parameterize the policy directly and are able to represent both stochastic and deterministic policies. 234 | 235 | - Reducing the required training time using GPUs and/or services like AWS. 236 | 237 | ## Conclusion 238 | Thank you for taking the time to read through this write up. Working on Swain Bot was originallly an excuse to dive into machine learning but eventually became a huge motivator to keep learning and grinding.. especially early on when nothing was stable. I also want to give a huge thank you to Crystal for supporting me in my now year-long diversion into machine learning and both being someone to lean on through the hard times and someone to celebrate with during the good times. I couldn't have done this without you. Thank you also to the greatest group of friends I could ask for. League is a fun game on its own, but I wouldn't still be so excited to play after so many years if it wasn't for you (and your reluctant tolerance for Nunu Top). Finally, thank you to the folks at Riot Games. The passion, excitement, and devotion you show for your product is infectious. 
239 | 240 | 241 | ## Disclaimer 242 | Swain Bot isn’t endorsed by Riot Games and doesn’t reflect the views or opinions of Riot Games or anyone officially involved in producing or managing League of Legends. League of Legends and Riot Games are trademarks or registered trademarks of Riot Games, Inc. League of Legends © Riot Games, Inc. 243 | --------------------------------------------------------------------------------