├── models
│   └── .placeholder
├── src
│   ├── data
│   │   ├── __init__.py
│   │   ├── create_database.py
│   │   ├── riotapi.py
│   │   ├── match_pool.py
│   │   ├── champion_info.py
│   │   ├── database_ops.py
│   │   └── query_wiki.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── softmax.py
│   │   ├── inference_model.py
│   │   └── qNetwork.py
│   ├── run_tests.py
│   ├── tests
│   │   └── tests.py
│   ├── requirements.txt
│   ├── update_champions_data.py
│   ├── features
│   │   ├── rewards.py
│   │   ├── experience_replay.py
│   │   ├── draft.py
│   │   ├── match_processing.py
│   │   └── draftstate.py
│   ├── update_match_data.py
│   ├── main.py
│   ├── model_predictions.py
│   └── trainer.py
├── _config.yml
├── common
│   └── images
│       ├── nms_error.png
│       ├── val_outliers.png
│       ├── val_matches_1.png
│       ├── val_matches_2.png
│       ├── discount_factor.png
│       ├── reward_sched_term.png
│       ├── validation_matches.png
│       ├── league_draft_structure.png
│       └── reward_sched_non-term.png
├── data
│   ├── competitiveMatchData.db
│   ├── match_sources.json
│   ├── test_train_split.txt
│   └── patch_info.json
├── LICENSE
└── README.md
/models/.placeholder:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-leap-day
--------------------------------------------------------------------------------
/src/run_tests.py:
--------------------------------------------------------------------------------
1 | from tests.tests import run
2 |
3 | run()
4 |
--------------------------------------------------------------------------------
/common/images/nms_error.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lightd22/swainBot/HEAD/common/images/nms_error.png
--------------------------------------------------------------------------------
/src/tests/tests.py:
--------------------------------------------------------------------------------
1 | from features.draftstate import DraftState
2 |
3 | def run():
4 | print("HELLO")
5 |
--------------------------------------------------------------------------------
/common/images/val_outliers.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lightd22/swainBot/HEAD/common/images/val_outliers.png
--------------------------------------------------------------------------------
/data/competitiveMatchData.db:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lightd22/swainBot/HEAD/data/competitiveMatchData.db
--------------------------------------------------------------------------------
/common/images/val_matches_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lightd22/swainBot/HEAD/common/images/val_matches_1.png
--------------------------------------------------------------------------------
/common/images/val_matches_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lightd22/swainBot/HEAD/common/images/val_matches_2.png
--------------------------------------------------------------------------------
/common/images/discount_factor.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lightd22/swainBot/HEAD/common/images/discount_factor.png
--------------------------------------------------------------------------------
/common/images/reward_sched_term.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lightd22/swainBot/HEAD/common/images/reward_sched_term.png
--------------------------------------------------------------------------------
/common/images/validation_matches.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lightd22/swainBot/HEAD/common/images/validation_matches.png
--------------------------------------------------------------------------------
/common/images/league_draft_structure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lightd22/swainBot/HEAD/common/images/league_draft_structure.png
--------------------------------------------------------------------------------
/common/images/reward_sched_non-term.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lightd22/swainBot/HEAD/common/images/reward_sched_non-term.png
--------------------------------------------------------------------------------
/src/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.13.3
2 | matplotlib==1.5.3
3 | tensorflow==1.15.2
4 | requests==2.11.1
5 | pandas==0.18.1
6 | luigi==2.8.0
7 |
--------------------------------------------------------------------------------
/data/match_sources.json:
--------------------------------------------------------------------------------
1 | {
2 | "match_sources":
3 | {
4 | "tournaments":[],
5 | "patches":["8.16","8.17","8.18","8.19"]
6 | }
7 | }
8 |
--------------------------------------------------------------------------------
/src/models/base_model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | class BaseModel():
3 | def __init__(self, name, path):
4 | self._name = name
5 | self._path_to_model = path
6 | self._graph = tf.Graph()
7 | self.sess = tf.Session(graph=self._graph)
8 |
9 | def __del__(self):
10 | try:
11 | self.sess.close()
12 | del self.sess
13 | finally:
14 | print("Model closed..")
15 |
16 | def build_model(self):
17 | raise NotImplementedError
18 | def init_saver(self):
19 | raise NotImplementedError
20 | def save(self):
21 | raise NotImplementedError
22 | def load(self):
23 | raise NotImplementedError
24 |
--------------------------------------------------------------------------------
/data/test_train_split.txt:
--------------------------------------------------------------------------------
1 | {"validation_ids": [1111, 238, 2300, 1104, 243, 481, 1422, 220, 1420, 2317, 2318, 219, 2122, 240, 1418, 221, 2334, 2333, 211, 1129], "training_ids": [2129, 1113, 490, 2315, 1097, 465, 1425, 225, 217, 214, 227, 2108, 244, 1419, 476, 1101, 467, 1122, 471, 1416, 469, 1102, 246, 489, 463, 2331, 2330, 1417, 2123, 1125, 2113, 475, 458, 2304, 1437, 229, 2313, 1128, 2115, 1433, 224, 485, 2307, 2325, 216, 488, 241, 2320, 218, 209, 233, 2332, 1430, 1108, 242, 1134, 1095, 487, 1436, 461, 479, 486, 2116, 466, 477, 1424, 2323, 2302, 1423, 462, 2306, 2327, 1127, 1098, 2321, 2111, 478, 234, 2125, 236, 472, 239, 228, 212, 1130, 1432, 468, 1116, 484, 1123, 1117, 2326, 245, 2109, 470, 2127, 2119, 1434, 473, 1124, 2118, 1121, 1106, 1426, 474, 235, 459, 2128, 480, 2124, 2305, 226, 2121, 215, 2117, 464, 232, 1112, 1120, 1110, 1100, 2324, 1431, 2114, 2314, 2303, 2311, 1114, 1119, 2120, 237, 2110, 2319, 2126, 1135, 1131, 230, 1132, 2329, 2309, 1099, 1096, 482, 1435, 460, 1126, 483, 231, 1109, 1133, 1103, 1118, 1429, 1115, 1107, 213, 223, 1428, 1427, 2312, 2328, 2107, 222, 2322, 2308, 210, 2112, 2316, 1105, 491, 2301, 2310, 1421]}
--------------------------------------------------------------------------------
/src/update_champions_data.py:
--------------------------------------------------------------------------------
1 | import luigi
2 | import requests
3 | import json
4 | import time
5 | import os
6 |
7 | class ChampionsDownload(luigi.ExternalTask):
8 | champions_path = luigi.Parameter(default="champions.json")
9 | def output(self):
10 | # output is temporary file to force task to check if patch data matches
11 | return luigi.LocalTarget("tmp/pipeline/champions{}.json".format(time.time()))
12 |
13 | def run(self):
14 | url = "https://ddragon.leagueoflegends.com/api/versions.json"
15 | response = requests.get(url=url).json()
16 | current_patch = response[0]
17 | # Check for local file patch version
18 | try:
19 | with open(self.champions_path, 'r') as infile:
20 | data = json.load(infile)
21 | local_patch = data["version"]
22 |         except (OSError, ValueError, KeyError):
23 | local_patch = None
24 |
25 | # update local file if patches do not match (uses temporary file)
26 | if local_patch != current_patch:
27 | print("Local patch does not match current patch.. Updating")
28 | url = "http://ddragon.leagueoflegends.com/cdn/{current_patch}/data/en_US/champion.json".format(current_patch=current_patch)
29 | response = requests.get(url=url).json()
30 | tmp_file = self.output().path
31 | with open(tmp_file, 'w') as outfile:
32 | json.dump(response, outfile)
33 | os.rename(tmp_file, self.champions_path)
34 | else:
35 | print("Local patch matches current patch.. Skipping")
36 | return
37 |
38 | if __name__ == "__main__":
39 | luigi.run(['ChampionsDownload', "--local-scheduler"])
40 |
--------------------------------------------------------------------------------
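ChampionsDownload above only re-downloads champion data when the local patch version differs from the live one. A minimal invocation sketch, assuming luigi's standard command-line parameter mapping; the alternate output path is purely illustrative:

    # Run with the default champions.json target:
    #     python update_champions_data.py
    # Or override the output location from another script (path is illustrative):
    import luigi
    luigi.run(['ChampionsDownload', '--champions-path', 'local/champions.json', '--local-scheduler'])
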
/src/data/create_database.py:
--------------------------------------------------------------------------------
1 | import sqlite3
2 | from sqlite3 import Error
3 | import json
4 | from .query_wiki import query_wiki
5 | import re
6 | from . import database_ops as dbo
7 |
8 | def table_col_info(cursor, tableName, printOut=False):
9 | """
10 |     Returns a list of tuples with column information:
11 | (id, name, type, notnull, default_value, primary_key)
12 | """
13 | cursor.execute('PRAGMA TABLE_INFO({})'.format(tableName))
14 | info = cursor.fetchall()
15 |
16 | if printOut:
17 | print("Column Info:\nID, Name, Type, NotNull, DefaultVal, PrimaryKey")
18 | for col in info:
19 | print(col)
20 | return info
21 |
22 | def create_tables(cursor, tableNames, columnInfo, clobber = False):
23 | """
24 | create_tables attempts to create a table for each table in the list tableNames with
25 | columns as defined by columnInfo. For each if table = tableNames[k] then the columns for
26 | table are defined by columns = columnInfo[k]. Note that each element in columnInfo must
27 | be a list of strings of the form column[j] = "jth_column_name jth_column_data_type"
28 |
29 | Args:
30 |         cursor (sqlite cursor): cursor used to execute commands
31 | tableNames (list(string)): string labels for tableNames
32 | columnInfo (list(list(string))): list of string labels for each column of each table
33 | clobber (bool): flag to determine if old tables should be overwritten
34 |
35 | Returns:
36 | status (int): 0 if table creation failed, 1 if table creation was successful
37 | """
38 |
39 | for (table, colInfo) in zip(tableNames, columnInfo):
40 | columnInfoString = ", ".join(colInfo)
41 | try:
42 | if clobber:
43 | cursor.execute("DROP TABLE IF EXISTS {tableName}".format(tableName=table))
44 | cursor.execute("CREATE TABLE {tableName} ({columnInfo})".format(tableName=table,columnInfo=columnInfoString))
45 | except Error as e:
46 | print(e)
47 | print("Table {} already exists! Here's it's column info:".format(table))
48 | table_col_info(cursor, table, True)
49 | print("***")
50 | return 0
51 |
52 | return 1
53 |
--------------------------------------------------------------------------------
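create_tables pairs each entry of tableNames with the column definition list at the same index in columnInfo. A minimal sketch against an in-memory SQLite database; the table and columns below are illustrative, not the project's full schema (that lives in update_match_data.py):

    import sqlite3
    from data.create_database import create_tables, table_col_info

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    status = create_tables(cur,
                           tableNames=["game"],
                           columnInfo=[["id INTEGER PRIMARY KEY", "patch TEXT", "winning_team INTEGER"]],
                           clobber=True)          # returns 1 on success, 0 on failure
    table_col_info(cur, "game", printOut=True)    # prints (id, name, type, notnull, default, pk) per column
    conn.close()
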
/src/features/rewards.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .draftstate import DraftState as ds
3 |
4 | def get_reward(state, match, submitted_action, actual_action):
5 | """
6 | Args:
7 | state (DraftState): Present state of the draft to be checked for reward
8 | match (dict): record of match which identifies winning team
9 | submitted_action (tuple(int)): id of action submitted by model
10 | actual_action (tuple(int)): id of action submitted in observation
11 | Returns:
12 | reward (int): Integer value representing the reward earned for the draft state.
13 |
14 | get_reward takes a draft state and returns the immediate reward for reaching that state. The reward is determined by a simple reward table
15 | 1) state is invalid -> reward = -10
16 | 2) state is complete, valid, and the selection was submitted by the winning team -> reward = +5
17 | 3) state is complete, valid but the submission was made by the losing team -> reward = +2.5
18 |     4) state is valid, but incomplete -> reward = 0
19 | """
20 | status = state.evaluate()
21 | if(status in ds.invalid_states):
22 | return -10.
23 |
24 | reward = 0.
25 | winner = get_winning_team(match)
26 | if(status == ds.DRAFT_COMPLETE and winner is not None):
27 | if(state.team == winner):
28 | reward += 5.
29 | else:
30 | reward += 2.5
31 |
32 | if(submitted_action == actual_action):
33 | reward += 0.5
34 | else:
35 | reward += -0.5
36 |
37 | return reward
38 |
39 | def get_winning_team(match):
40 | """
41 | Args:
42 | match (dict): match dictionary with pick and ban data for a single game.
43 | Returns:
44 | val (int): Integer representing which team won the match.
45 | val = DraftState.RED_TEAM if the red team won
46 | val = DraftState.BLUE_TEAM if blue team won
47 | val = None if match does not have data for winning team
48 |
49 | get_winning_team returns the winning team of the input match encoded as an integer according to DraftState.
50 | """
51 | if match["winner"]==0:
52 | return ds.BLUE_TEAM
53 | elif match["winner"]==1:
54 | return ds.RED_TEAM
55 | return None
56 |
--------------------------------------------------------------------------------
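To make the reward table above concrete, the terminal bonus and the per-submission matching term simply add (values taken directly from get_reward; the surrounding state and match objects are not shown here):

    completed valid draft, winning team, submitted action matches observation:   5.0 + 0.5 = 5.5
    completed valid draft, losing team, submitted action differs:                2.5 - 0.5 = 2.0
    valid but incomplete draft, submitted action matches observation:            0.0 + 0.5 = 0.5
    any invalid draft state:                                                    -10.0 (matching term skipped)
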
/src/features/experience_replay.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | class ExperienceBuffer():
4 | """
5 | ExperienceBuffer is a class for storing and adding experiences to be sampled from when batch learning a Qnetwork. An experience is defined as a tuple of the form
6 | (s, a, r, s') where
7 | s = input state
8 | a = action taken from state s
9 | r = reward obtained for taking action a
10 | s' = ending state after taking action a
11 | Args:
12 | max_buffer_size (int): maximum number of experiences to store in the buffer, default value is 300.
13 | """
14 | def __init__(self, max_buffer_size = 300):
15 | self.buffer = []
16 | self.buffer_size = max_buffer_size
17 | self.oldest_experience = 0
18 |
19 | def store(self, experiences):
20 | """
21 |         ExperienceBuffer.store stores the input list of experience tuples into the buffer. The experience is stored in one of two ways:
22 | 1) If the buffer has space remaining, the experience is appended to the end
23 | 2) If the buffer is full, the input experience replaces the oldest experience in the buffer
24 |
25 | Args:
26 | experiences ( list(tuple) ): each experience is a tuple of the form (s, a, r, s')
27 | Returns:
28 | None
29 | """
30 | for experience in experiences:
31 | if len(self.buffer) < self.buffer_size:
32 | self.buffer.append(experience)
33 | else:
34 | self.buffer[self.oldest_experience] = experience
35 | self.oldest_experience += 1
36 | self.oldest_experience = self.oldest_experience % self.buffer_size
37 | return None
38 |
39 | def sample(self, sample_size):
40 | """
41 | ExperienceBuffer.sample samples the current buffer using random.sample to return a collection of sample_size experiences from the replay buffer.
42 | random.sample samples without replacement, so sample_size must be no larger than the length of the current buffer.
43 |
44 | Args:
45 | sample_size (int): number of samples to take from buffer
46 | Returns:
47 | sample (list(tuples)): list of experience replay samples. len(sample) = sample_size.
48 | """
49 | return random.sample(self.buffer,sample_size)
50 |
51 | def get_buffer_size(self):
52 | """
53 | Returns length of the buffer.
54 | """
55 | return len(self.buffer)
56 |
--------------------------------------------------------------------------------
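The buffer above appends until full and then overwrites its oldest entry in ring-buffer fashion. A minimal, self-contained usage sketch; the experience tuples here are placeholder strings rather than real DraftState objects:

    from features.experience_replay import ExperienceBuffer

    buf = ExperienceBuffer(max_buffer_size=2)
    buf.store([("s0", "a0", 0.0, "s1"),
               ("s1", "a1", 0.5, "s2"),
               ("s2", "a2", 5.0, "s3")])   # third experience overwrites the oldest ("s0", ...)
    print(buf.get_buffer_size())           # 2
    batch = buf.sample(sample_size=2)      # sampling is without replacement, so sample_size <= 2 here
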
/src/data/riotapi.py:
--------------------------------------------------------------------------------
1 | import requests
2 | from . import myRiotApiKey
3 | import time
4 | api_versions = {
5 | "staticdata": "v3",
6 | "datadragon": "7.15.1"
7 | }
8 | valid_methods = ["GET", "PUT", "POST"]
9 | region = "na1"
10 | def set_api_key(key):
11 | """
12 | Set calling user's API Key
13 |
14 | Args:
15 | key (string): the Riot API key desired for use.
16 | """
17 | myRiotApiKey.api_key = key
18 | return key
19 |
20 | def set_region(reg):
21 | """
22 | Set region to run API queries through
23 |
24 | Args:
25 | reg (string): region through which we are sending API requests
26 | """
27 | reg = reg.lower()
28 | regions = ["br1", "eun1", "euw1", "jp1", "kr", "la1", "la2", "na1", "oce1", "tr1", "ru"]
29 |
30 | assert (reg in regions), "Invalid region!"
31 | region = reg
32 | return region
33 |
34 | def make_request(request, method, params={}):
35 | """
36 | Makes a rate-limited HTTP request to Riot API and returns the response data
37 | """
38 | url = "https://{region}.api.riotgames.com/lol/{request}".format(region=region,request=request)
39 | try:
40 | response = execute_request(url, method, params)
41 | if(not response.ok):
42 | response.raise_for_status()
43 | return response.json()
44 | except requests.exceptions.HTTPError as e:
45 | # Wait and try again on 429 (rate limit exceeded)
46 | if response.status_code == 429:
47 | if "X-Rate-Limit-Type" not in e.headers or e.headers["X-Rate-Limit-Type"] == "service":
48 | # Wait 1 second before retrying
49 | time.sleep(1)
50 | else:
51 | retry_after = 1
52 |                 if "Retry-After" in response.headers:
53 |                     retry_after += int(response.headers["Retry-After"])
54 |
55 | time.sleep(retry_after)
56 | return make_request(request, method, params)
57 | else:
58 | raise
59 |
60 | def execute_request(url, method, params={}):
61 | """
62 | Executes HTTP request using requests library and returns response object.
63 | Args:
64 | url (str): full url string to request
65 | method (str): HTTP method to use (one of "GET", "PUT", or "POST")
66 | params (dict): dictionary of parameters to send along url
67 |
68 | Returns:
69 | response object returned by requests
70 | """
71 |
72 | response = None
73 | assert(method in valid_methods), "[execute_request] Invalid HTTP method!"
74 | if(method == "GET"):
75 | response = requests.get(url=url, params=params)
76 | return response
77 |
--------------------------------------------------------------------------------
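A hedged usage sketch for the helpers above. The endpoint path and the explicit api_key parameter are illustrative assumptions: the module shows how requests are built and how rate-limit retries are handled, but not how the stored key is attached to each call, and the myRiotApiKey module it imports is not tracked in the repository.

    from data import riotapi, myRiotApiKey

    riotapi.set_api_key("RGAPI-xxxx")      # stored on the shared myRiotApiKey module
    riotapi.set_region("euw1")
    # make_request prefixes the path with https://{region}.api.riotgames.com/lol/
    data = riotapi.make_request("static-data/v3/champions", "GET",
                                params={"api_key": myRiotApiKey.api_key})
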
/data/patch_info.json:
--------------------------------------------------------------------------------
1 | {
2 | "patch_info":
3 | {
4 | "2018":
5 | {
6 | "NA_LCS":
7 | {
8 | "Spring_Season":["8.1","8.1","8.2","8.2","8.3","8.3","8.4","8.4","8.5"],
9 | "Spring_Playoffs":["8.5"],
10 | "Summer_Season":["8.11","8.12","8.12","8.13","8.13","8.14","8.14","8.15","8.15"],
11 | "Summer_Playoffs":["8.16"],
12 | "Regional_Finals":["8.16"]
13 | },
14 | "NA_ACA":
15 | {
16 | "Spring_Season":["8.1","8.1","8.2","8.2","8.3","8.3","8.4","8.4","8.5"],
17 | "Spring_Playoffs":["8.5"],
18 | "Summer_Season":["8.11","8.12","8.13","8.13","8.13","8.14","8.14","8.15"],
19 | "Summer_Playoffs":["8.16"]
20 | },
21 | "EU_LCS":
22 | {
23 | "Spring_Season":["8.1","8.1","8.2","8.2","8.3","8.3","8.4","8.4","8.5"],
24 | "Spring_Playoffs":["8.5"],
25 | "Summer_Season":["8.11","8.12","8.12","8.13","8.13","8.14","8.14","8.15","8.15"],
26 | "Summer_Playoffs":["8.16"],
27 | "Regional_Finals":["8.16"]
28 | },
29 | "LCK":
30 | {
31 | "Spring_Season":["8.1","8.1","8.1","8.2","8.3","8.3","8.4","8.4","8.5"],
32 | "Spring_Playoffs":["8.5"],
33 | "Summer_Season":["8.11","8.11","8.12","8.13","8.13","8.13","8.14","8.14"],
34 | "Summer_Playoffs":["8.15"],
35 | "Regional_Finals":["8.15"]
36 | },
37 | "KR_CHAL":
38 | {
39 | "Spring_Season":["8.1","8.1","8.1","8.2","8.2","8.3","8.4","8.4","8.5"],
40 | "Spring_Playoffs":["8.6"],
41 | "Summer_Season":["8.11","8.11","8.12","8.13","8.13","8.13","8.14","8.14","8.15","8.15","8.15"],
42 | "Summer_Playoffs":["8.15"]
43 | },
44 | "LPL":
45 | {
46 | "Spring_Season":["8.1","8.1","8.1","8.1","8.2","8.2","8.4","8.4","8.5"],
47 | "Spring_Playoffs":["8.5"],
48 | "Summer_Season":["8.11","8.11","8.11","8.13","8.13","8.13","8.14","8.14","8.15","8.15","8.15"],
49 | "Summer_Playoffs":["8.16"],
50 | "Regional_Finals":["8.16"]
51 | },
52 | "LDL":
53 | {
54 | "Spring_Playoffs":["8.8"],
55 | "Grand_Finals":["8.16"]
56 | },
57 | "LMS":
58 | {
59 | "Spring_Season":["8.1","8.1","8.1","8.2","8.3","8.3","8.4","8.4","8.5","8.5"],
60 | "Spring_Playoffs":["8.6"],
61 | "Summer_Season":["8.11","8.12","8.13","8.13","8.13","8.14","8.14","8.15"],
62 | "Summer_Playoffs":["8.17"],
63 | "Regional_Finals":["8.17"]
64 | },
65 | "MSI":
66 | {
67 | "Play-In":["8.8"],
68 | "Main_Event":["8.8"]
69 | },
70 | "WORLDS":
71 | {
72 | "Play-In":["8.19"],
73 | "Main_Event":["8.19"]
74 | }
75 | }
76 | }
77 | }
78 |
--------------------------------------------------------------------------------
/src/features/draft.py:
--------------------------------------------------------------------------------
1 | class Draft(object):
2 | BLUE_TEAM = 0
3 | RED_TEAM = 1
4 | BAN = 201
5 | PICK = 202
6 | PHASES = [BAN, PICK]
7 |
8 |     # Draft specifications
9 | # Default specification for drafting structure
10 | default_draft = [
11 | (BLUE_TEAM, BAN),
12 | (RED_TEAM, BAN),
13 | (BLUE_TEAM, BAN),
14 | (RED_TEAM, BAN),
15 | (BLUE_TEAM, BAN),
16 | (RED_TEAM, BAN),
17 |
18 | (BLUE_TEAM, PICK),
19 | (RED_TEAM, PICK),
20 | (RED_TEAM, PICK),
21 | (BLUE_TEAM, PICK),
22 | (BLUE_TEAM, PICK),
23 | (RED_TEAM, PICK),
24 |
25 | (RED_TEAM, BAN),
26 | (BLUE_TEAM, BAN),
27 | (RED_TEAM, BAN),
28 | (BLUE_TEAM, BAN),
29 |
30 | (RED_TEAM, PICK),
31 | (BLUE_TEAM, PICK),
32 | (BLUE_TEAM, PICK),
33 | (RED_TEAM, PICK),
34 | ]
35 |
36 | no_bans = [
37 | (BLUE_TEAM, PICK),
38 | (RED_TEAM, PICK),
39 | (RED_TEAM, PICK),
40 | (BLUE_TEAM, PICK),
41 | (BLUE_TEAM, PICK),
42 | (RED_TEAM, PICK),
43 |
44 | (RED_TEAM, PICK),
45 | (BLUE_TEAM, PICK),
46 | (BLUE_TEAM, PICK),
47 | (RED_TEAM, PICK),
48 | ]
49 | # Dictionary mapping draft labels to draft structures
50 | draft_structures = {'default': default_draft,
51 | 'no_bans': no_bans,
52 | }
53 |
54 | def __init__(self, draft_type = 'default'):
55 | self._draft_structure = None
56 | try:
57 | self._draft_structure = Draft.draft_structures[draft_type]
58 | except KeyError:
59 | print("In draft.py: Draft structure not defined")
60 | raise
61 |
62 | self.PHASE_LENGTHS = {}
63 | for phase in Draft.PHASES:
64 | self.PHASE_LENGTHS[phase] = []
65 |
66 | phase_length = 0
67 | current_phase = None
68 | for (team, phase) in self._draft_structure:
69 | if not current_phase:
70 | current_phase = phase
71 | if phase == current_phase:
72 | phase_length += 1
73 | else:
74 | self.PHASE_LENGTHS[current_phase].append(phase_length)
75 | current_phase = phase
76 | phase_length = 1
77 | self.PHASE_LENGTHS[current_phase].append(phase_length) # don't forget last phase
78 | self.NUM_BANS = sum(self.PHASE_LENGTHS[Draft.BAN]) # Total number of bans in draft
79 | self.NUM_PICKS = sum(self.PHASE_LENGTHS[Draft.PICK]) # Total number of picks in draft
80 |
81 | # submission_dist[k] gives tuple of counts for pick types just before kth submission is made (last element will hold final submission distribution for draft)
82 | self.submission_dist = [(0,0,0)]
83 | for (team, phase) in self._draft_structure:
84 | (cur_ban, cur_blue, cur_red) = self.submission_dist[-1]
85 | if phase == Draft.BAN:
86 | next_dist = (cur_ban+1, cur_blue, cur_red)
87 | elif team == Draft.BLUE_TEAM:
88 | next_dist = (cur_ban, cur_blue+1, cur_red)
89 | else:
90 | next_dist = (cur_ban, cur_blue, cur_red+1)
91 | self.submission_dist += [next_dist]
92 |
93 | def get_active_team(self, submission_count):
94 | """
95 | Gets the active team in the draft based on the number of submissions currently present
96 | Args:
97 | submission_count (int): number of submissions currently submitted to draft
98 | Returns:
99 | Draft.BLUE_TEAM if blue is active, else Draft.RED_TEAM
100 | """
101 | if submission_count > len(self._draft_structure):
102 |             raise ValueError("submission_count exceeds the length of the draft structure")
103 | elif submission_count == len(self._draft_structure):
104 | return None
105 | (team, sub_type) = self._draft_structure[submission_count]
106 | return team
107 |
108 | def get_active_phase(self, submission_count):
109 | """
110 | Returns phase identifier for current phase of the draft based on the number of submissions made.
111 | Args:
112 |             submission_count (int): number of submissions currently submitted to draft
113 | Returns:
114 | Draft.BAN if state is in banning phase, otherwise Draft.PICK
115 | """
116 | if submission_count > len(self._draft_structure):
117 |             raise ValueError("submission_count exceeds the length of the draft structure")
118 | elif submission_count == len(self._draft_structure):
119 | return None
120 | (team, sub_type) = self._draft_structure[submission_count]
121 | return sub_type
122 |
123 | if __name__ == "__main__":
124 | draft = Draft("default")
125 | [print(thing) for thing in draft.submission_dist]
126 |
--------------------------------------------------------------------------------
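The default structure above encodes the standard competitive order (six alternating bans, six picks, four bans, four picks). A small sketch of the query helpers:

    from features.draft import Draft

    draft = Draft('default')
    print(draft.NUM_BANS, draft.NUM_PICKS)                  # 10 10
    print(draft.get_active_team(0) == Draft.BLUE_TEAM)      # True: blue opens the first ban phase
    print(draft.get_active_phase(6) == Draft.PICK)          # True: the seventh submission starts the picks
    print(draft.get_active_team(20))                        # None: the draft is complete after 20 submissions
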
/src/models/softmax.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 | from . import base_model
5 |
6 | class SoftmaxNetwork(base_model.BaseModel):
7 | """
8 | Args:
9 | input_shape (tuple): tuple of inputs to network.
10 | output_shape (int): number of output nodes for network.
11 | filter_sizes (tuple of 2 ints): number of filters in each of the two hidden layers. Defaults to (16,32).
12 | learning_rate (float): network's willingness to change current weights given new example
13 | regularization (float): strength of weights regularization term in loss function
14 |
15 |     A simple softmax network class which is responsible for holding and updating the weights and biases used in predicting actions for a given state. This network consists of
16 | the following layers:
17 | 1) Input- a DraftState state s (an array of bool) representing the current state reshaped into an [n_batch, *input_shape] tensor.
18 |     2-3) Two relu-activated, fully connected hidden layers
19 | 4) Output- softmax-obtained probability of action submission for output_shape actions available.
20 |
21 | """
22 | @property
23 | def name(self):
24 | return self._name
25 |
26 | def __init__(self, name, path, input_shape, output_shape, filter_sizes = (512,512), learning_rate=1.e-3, regularization_coeff = 0.01):
27 | super().__init__(name=name, path=path)
28 | self._input_shape = input_shape
29 | self._output_shape = output_shape
30 | self._learning_rate = learning_rate
31 | self._regularization_coeff = regularization_coeff
32 | self._n_hidden_layers = len(filter_sizes)
33 | self._n_layers = self._n_hidden_layers + 2
34 | self._filter_sizes = filter_sizes
35 |
36 | self.ops_dict = self.build_model(name=self._name)
37 | with self._graph.as_default():
38 | self.ops_dict["init"] = tf.global_variables_initializer()
39 |
40 | self.init_saver()
41 |
42 | def init_saver(self):
43 | with self._graph.as_default():
44 | self.saver = tf.train.Saver()
45 |
46 | def save(self, path):
47 | self.saver.save(self.sess, save_path=path)
48 |
49 | def load(self, path):
50 | self.saver.restore(self.sess, save_path=path)
51 |
52 | def build_model(self, name):
53 | ops_dict = {}
54 | with self._graph.as_default():
55 | with tf.variable_scope(name):
56 | ops_dict["learning_rate"] = tf.Variable(self._learning_rate, trainable=False, name="learning_rate")
57 |
58 | # Incoming state matrices are of size input_size = (nChampions, nPos+2)
59 | # 'None' here means the input tensor will flex with the number of training
60 | # examples (aka batch size).
61 | ops_dict["input"] = tf.placeholder(tf.float32, (None,)+self._input_shape, name="inputs")
62 | ops_dict["dropout_keep_prob"] = tf.placeholder_with_default(1.0,shape=())
63 |
64 | # Fully connected (FC) layers:
65 | fc0 = tf.layers.dense(
66 | ops_dict["input"],
67 | self._filter_sizes[0],
68 | activation=tf.nn.relu,
69 | bias_initializer=tf.constant_initializer(0.1),
70 | name="fc_0")
71 | dropout0 = tf.nn.dropout(fc0, ops_dict["dropout_keep_prob"])
72 |
73 | fc1 = tf.layers.dense(
74 | dropout0,
75 | self._filter_sizes[1],
76 | activation=tf.nn.relu,
77 | bias_initializer=tf.constant_initializer(0.1),
78 | name="fc_1")
79 | dropout1 = tf.nn.dropout(fc1, ops_dict["dropout_keep_prob"])
80 |
81 | # Logits layer
82 | ops_dict["logits"] = tf.layers.dense(
83 | dropout1,
84 | self._output_shape,
85 | activation=None,
86 | bias_initializer=tf.constant_initializer(0.1),
87 | kernel_regularizer=tf.contrib.layers.l2_regularizer(scale=self._regularization_coeff),
88 | name="logits")
89 |
90 | # Placeholder for valid actions filter
91 | ops_dict["valid_actions"] = tf.placeholder(tf.bool, shape=ops_dict["logits"].shape, name="valid_actions")
92 |
93 | # Filtered logits
94 | ops_dict["valid_logits"] = tf.where(ops_dict["valid_actions"], ops_dict["logits"], tf.scalar_mul(-np.inf, tf.ones_like(ops_dict["logits"])), name="valid_logits")
95 |
96 | # Predicted optimal action amongst valid actions
97 | ops_dict["probabilities"] = tf.nn.softmax(ops_dict["valid_logits"], name="action_probabilites")
98 | ops_dict["prediction"] = tf.argmax(input=ops_dict["valid_logits"], axis=1, name="predictions")
99 |
100 | ops_dict["actions"] = tf.placeholder(tf.int32, shape=[None], name="submitted_actions")
101 |
102 | ops_dict["loss"] = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=ops_dict["actions"], logits=ops_dict["valid_logits"]), name="loss")
103 |
104 | ops_dict["trainer"] = tf.train.AdamOptimizer(learning_rate = ops_dict["learning_rate"])
105 | ops_dict["update"] = ops_dict["trainer"].minimize(ops_dict["loss"], name="update")
106 |
107 | return ops_dict
108 |
--------------------------------------------------------------------------------
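A hedged sketch of one supervised update step driven through the ops_dict returned by build_model. The input_shape, output_shape, and random batch below are illustrative assumptions; in the project these come from DraftState.format_state() and DraftState.num_actions (see main.py):

    import numpy as np
    from models.softmax import SoftmaxNetwork

    net = SoftmaxNetwork(name="softmax_demo", path="../models/softmax_demo.ckpt",
                         input_shape=(280,), output_shape=1000)   # shapes are illustrative only
    net.sess.run(net.ops_dict["init"])

    states = np.random.rand(4, 280).astype(np.float32)            # dummy batch of 4 states
    valid = np.ones((4, 1000), dtype=bool)                        # every action allowed in this toy batch
    actions = np.random.randint(0, 1000, size=4)                  # "observed" submissions

    _, loss = net.sess.run([net.ops_dict["update"], net.ops_dict["loss"]],
                           feed_dict={net.ops_dict["input"]: states,
                                      net.ops_dict["valid_actions"]: valid,
                                      net.ops_dict["actions"]: actions,
                                      net.ops_dict["dropout_keep_prob"]: 0.5})
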
/src/models/inference_model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from . import base_model
3 |
4 | class QNetInferenceModel(base_model.BaseModel):
5 | def __init__(self, name, path):
6 | super().__init__(name=name, path=path)
7 | self.init_saver()
8 | self.ops_dict = self.build_model()
9 |
10 | def init_saver(self):
11 | with self._graph.as_default():
12 | self.saver = tf.train.import_meta_graph("{path}.ckpt.meta".format(path=self._path_to_model))
13 | self.saver.restore(self.sess,"{path}.ckpt".format(path=self._path_to_model))
14 | def build_model(self):
15 | ops_dict = {}
16 | with self._graph.as_default():
17 | ops_dict["predict_q"] = tf.get_default_graph().get_tensor_by_name("online/valid_q_vals:0")
18 | ops_dict["prediction"] = tf.get_default_graph().get_tensor_by_name("online/prediction:0")
19 | ops_dict["input"] = tf.get_default_graph().get_tensor_by_name("online/inputs:0")
20 | ops_dict["valid_actions"] = tf.get_default_graph().get_tensor_by_name("online/valid_actions:0")
21 | return ops_dict
22 |
23 | def predict(self, states):
24 | """
25 | Feeds state into model and returns current predicted Q-values.
26 | Args:
27 | states (list of DraftStates): states to predict from
28 | Returns:
29 | predicted_Q (numpy array): model estimates of Q-values for actions from input states.
30 | predicted_Q[k,:] holds Q-values for state states[k]
31 | """
32 | inputs = [state.format_state() for state in states]
33 | valid_actions = [state.get_valid_actions() for state in states]
34 |
35 | feed_dict = {self.ops_dict["input"]:inputs,
36 | self.ops_dict["valid_actions"]:valid_actions}
37 | predicted_Q = self.sess.run(self.ops_dict["predict_q"], feed_dict=feed_dict)
38 | return predicted_Q
39 |
40 | def predict_action(self, states):
41 | """
42 |         Feeds states into the model and returns recommended actions to take from the input states, based on estimated Q-values.
43 | Args:
44 | state (list of DraftStates): states to predict from
45 | Returns:
46 | predicted_action (numpy array): array of integer representations of actions recommended by model.
47 | """
48 | inputs = [state.format_state() for state in states]
49 | valid_actions = [state.get_valid_actions() for state in states]
50 |
51 | feed_dict = {self.ops_dict["input"]:inputs,
52 | self.ops_dict["valid_actions"]:valid_actions}
53 | predicted_actions = self.sess.run(self.ops_dict["prediction"], feed_dict=feed_dict)
54 | return predicted_actions
55 |
56 | class SoftmaxInferenceModel(base_model.BaseModel):
57 | def __init__(self, name, path):
58 | super().__init__(name=name, path=path)
59 | self.init_saver()
60 | self.ops_dict = self.build_model()
61 |
62 | def init_saver(self):
63 | with self._graph.as_default():
64 | self.saver = tf.train.import_meta_graph("{path}.ckpt.meta".format(path=self._path_to_model))
65 | self.saver.restore(self.sess,"{path}.ckpt".format(path=self._path_to_model))
66 |
67 | def build_model(self):
68 | ops_dict = {}
69 | with self._graph.as_default():
70 | ops_dict["probabilities"] = tf.get_default_graph().get_tensor_by_name("softmax/action_probabilites:0")
71 | ops_dict["prediction"] = tf.get_default_graph().get_tensor_by_name("softmax/predictions:0")
72 | ops_dict["input"] = tf.get_default_graph().get_tensor_by_name("softmax/inputs:0")
73 | ops_dict["valid_actions"] = tf.get_default_graph().get_tensor_by_name("softmax/valid_actions:0")
74 | return ops_dict
75 |
76 | def predict(self, states):
77 | """
78 | Feeds state into model and returns current predicted probabilities.
79 | Args:
80 | states (list of DraftStates): states to predict from
81 | Returns:
82 | probabilities (numpy array): model estimates of probabilities for actions from input states.
83 |                           probabilities[k,:] holds action probabilities for state states[k]
84 | """
85 | inputs = [state.format_state() for state in states]
86 | valid_actions = [state.get_valid_actions() for state in states]
87 |
88 | feed_dict = {self.ops_dict["input"]:inputs,
89 | self.ops_dict["valid_actions"]:valid_actions}
90 | probabilities = self.sess.run(self.ops_dict["probabilities"], feed_dict=feed_dict)
91 | return probabilities
92 |
93 | def predict_action(self, states):
94 | """
95 |         Feeds states into the model and returns recommended actions to take from the input states, based on estimated action probabilities.
96 | Args:
97 | state (list of DraftStates): states to predict from
98 | Returns:
99 | predicted_action (numpy array): array of integer representations of actions recommended by model.
100 | """
101 | inputs = [state.format_state() for state in states]
102 | valid_actions = [state.get_valid_actions() for state in states]
103 |
104 | feed_dict = {self.ops_dict["input"]:inputs,
105 | self.ops_dict["valid_actions"]:valid_actions}
106 | predicted_actions = self.sess.run(self.ops_dict["prediction"], feed_dict=feed_dict)
107 | return predicted_actions
108 |
--------------------------------------------------------------------------------
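Restoring a trained checkpoint only requires the saved variable scope to match the hard-coded tensor names above ("softmax/..." or "online/..."). A minimal sketch, assuming a checkpoint written by main.py and a DraftState instance `state` that provides format_state() and get_valid_actions():

    from models.inference_model import SoftmaxInferenceModel

    # The path omits the ".ckpt" suffix; init_saver appends ".ckpt.meta" / ".ckpt".
    model = SoftmaxInferenceModel(name="infer", path="../models/softmax_model_E45")
    probs = model.predict([state])              # shape (1, num_actions)
    action = model.predict_action([state])[0]   # integer action id; decode with state.format_action(action)
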
/src/update_match_data.py:
--------------------------------------------------------------------------------
1 | import luigi
2 | import requests
3 | import json
4 | import time
5 | import sqlite3
6 | from data.create_database import create_tables
7 | import data.database_ops as dbo
8 | from data.query_wiki import query_wiki
9 |
10 | class CreateMatchDB(luigi.Task):
11 | path_to_db = luigi.Parameter(default="../data/competitiveMatchData.db")
12 |
13 | def output(self):
14 | return luigi.LocalTarget(self.path_to_db)
15 |
16 | def run(self):
17 | tableNames = ["game", "pick", "ban", "team"]
18 |
19 | columnInfo = []
20 | # Game table columns
21 | columnInfo.append(["id INTEGER PRIMARY KEY",
22 | "tournament TEXT","tourn_game_id INTEGER", "week INTEGER", "patch TEXT",
23 | "blue_teamid INTEGER NOT NULL", "red_teamid INTEGER NOT NULL",
24 | "winning_team INTEGER"])
25 | # Pick table columns
26 | columnInfo.append(["id INTEGER PRIMARY KEY",
27 | "game_id INTEGER", "champion_id INTEGER","position_id INTEGER",
28 | "selection_order INTEGER", "side_id INTEGER"])
29 | # Ban table columns
30 | columnInfo.append(["id INTEGER PRIMARY KEY",
31 | "game_id INTEGER", "champion_id INTEGER", "selection_order INTEGER", "side_id INTEGER"])
32 | # Team table columns
33 | columnInfo.append(["id INTEGER PRIMARY KEY",
34 | "region TEXT", "display_name TEXT"])
35 |
36 | conn = sqlite3.connect(self.path_to_db)
37 | cur = conn.cursor()
38 | print("Creating tables..")
39 | _ = create_tables(cur, tableNames, columnInfo, clobber = True)
40 | conn.close()
41 |
42 | return 1
43 |
44 | def validate_match_data(match_data):
45 | """
46 | validate_match_data performs basic match data validation by examining the following:
47 | 1. Number of picks/bans present in data
48 |     2. Presence of duplicate picks/bans
49 | 3. Duplicate roles on a single side
50 |
51 | Args:
52 | match_data (dict): dictionary of formatted match data
53 |
54 | Returns:
55 | bool: True if match_data passes validation checks, False otherwise.
56 | """
57 | NUM_BANS = 10
58 | NUM_PICKS = 10
59 |
60 | is_valid = True
61 | bans = match_data["bans"]["blue"] + match_data["bans"]["red"]
62 | picks = match_data["picks"]["blue"] + match_data["picks"]["red"]
63 | if(len(bans) != NUM_BANS or len(picks)!= NUM_PICKS):
64 | print("Incorrect number of picks and/or bans found! {} picks, {} bans".format(len(picks), len(bans)))
65 | is_valid = False
66 |
67 | # Need to consider edge case where teams fail to submit multiple bans (rare, but possible)
68 | champs = [ban for ban in bans if ban != "none"] + [p for (p,_) in picks]
69 | if len(set(champs)) != len(champs):
70 | print("Duplicate submission(s) encountered.")
71 | counts = {}
72 | for champ in champs:
73 | if champ not in counts:
74 | counts[champ] = 1
75 | else:
76 | counts[champ] += 1
77 | print(sorted([(value, key) for (key, value) in counts.items() if value>1]))
78 | is_valid = False
79 |
80 | for side in ["blue", "red"]:
81 | if len(set([pos for (_,pos) in match_data["picks"][side]])) != len(match_data["picks"][side]):
82 | print("Duplicate position on side {} found.".format(side))
83 | is_valid = False
84 |
85 | return is_valid
86 |
87 | if __name__ == "__main__":
88 | path_to_db = "../data/competitiveMatchData.db"
89 | luigi.run(
90 | cmdline_args=["--path-to-db={}".format(path_to_db)],
91 | main_task_cls=CreateMatchDB,
92 | local_scheduler=True)
93 |
94 | conn = sqlite3.connect(path_to_db)
95 | cur = conn.cursor()
96 |
97 | # deleted_match_ids = [770]
98 | # dbo.delete_game_from_table(cur, game_ids = deleted_match_ids, table_name="pick")
99 | # dbo.delete_game_from_table(cur, game_ids = deleted_match_ids, table_name="ban")
100 |
101 | year = "2018"
102 | schedule = []
103 | regions = ["EU_LCS","NA_LCS","LPL","LMS","LCK"]; tournaments = ["Spring_Season", "Spring_Playoffs", "Summer_Season", "Summer_Playoffs", "Regional_Finals"]
104 | schedule.append((regions,tournaments))
105 | regions = ["NA_ACA","KR_CHAL"]; tournaments = ["Spring_Season", "Spring_Playoffs", "Summer_Season", "Summer_Playoffs"]
106 | schedule.append((regions,tournaments))
107 | regions = ["LDL"]; tournaments = ["Spring_Playoffs", "Grand_Finals"]
108 | schedule.append((regions,tournaments))
109 |
110 | NUM_BANS = 10
111 | NUM_PICKS = 10
112 | for regions, tournaments in schedule:
113 | for region in regions:
114 | for tournament in tournaments:
115 | skip_commit = False
116 | print("Querying: {}".format(year+"/"+region+"/"+tournament))
117 | gameData = query_wiki(year, region, tournament)
118 | print("Found {} games.".format(len(gameData)))
119 | for i,game in enumerate(gameData):
120 | is_valid = validate_match_data(game)
121 | if not is_valid:
122 | skip_commit = True
123 | print("Errors in match: h_id {} tourn_g_id {}: {} vs {}".format(game["header_id"], game["tourn_game_id"], game["blue_team"], game["red_team"]))
124 |
125 | if(not skip_commit):
126 | print("Attempting to insert {} games..".format(len(gameData)))
127 | status = dbo.insert_team(cur,gameData)
128 | status = dbo.insert_game(cur,gameData)
129 | status = dbo.insert_ban(cur,gameData)
130 | status = dbo.insert_pick(cur,gameData)
131 | print("Committing changes to db..")
132 | conn.commit()
133 | else:
134 | print("Errors found in match data.. skipping commit")
135 |                     raise RuntimeError("Errors found in match data; commit skipped")
136 |
--------------------------------------------------------------------------------
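validate_match_data expects per-side ban lists and (champion, position) pick tuples. A toy sketch of the expected shape; the champion labels and position ids below are illustrative:

    from update_match_data import validate_match_data

    match_data = {
        "bans":  {"blue": ["aatrox", "ahri", "akali", "none", "alistar"],
                  "red":  ["amumu", "anivia", "annie", "ashe", "azir"]},
        "picks": {"blue": [("bard", 1), ("blitzcrank", 2), ("brand", 3), ("braum", 4), ("caitlyn", 5)],
                  "red":  [("camille", 1), ("cassiopeia", 2), ("chogath", 3), ("corki", 4), ("darius", 5)]},
    }
    assert validate_match_data(match_data)   # 10 bans, 10 picks, no duplicates, unique positions per side
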
/src/main.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import json
4 | import sqlite3
5 | import matplotlib.pyplot as plt
6 | import time
7 |
8 | from features.draftstate import DraftState
9 | import data.champion_info as cinfo
10 | import features.match_processing as mp
11 | from data.match_pool import test_train_split
12 | import data.database_ops as dbo
13 |
14 | from models import qNetwork, softmax
15 | from trainer import DDQNTrainer, SoftmaxTrainer
16 | from models.inference_model import QNetInferenceModel, SoftmaxInferenceModel
17 |
18 | import tensorflow as tf
19 |
20 | print("")
21 | print("********************************")
22 | print("** Beginning Swain Bot Run! **")
23 | print("********************************")
24 |
25 | valid_champ_ids = cinfo.get_champion_ids()
26 | print("Number of valid championIds: {}".format(len(valid_champ_ids)))
27 |
28 | LIST_PATH = None#"../data/test_train_split.txt"
29 | LIST_SAVE_PATH = "../data/test_train_split.txt"
30 | PATH_TO_DB = "../data/competitiveMatchData.db"
31 | MODEL_DIR = "../models/"
32 | N_TRAIN = 173
33 | N_VAL = 20
34 | PATCHES = None
35 | PRUNE_PATCHES = None
36 | result = test_train_split(N_TRAIN, N_VAL, PATH_TO_DB, LIST_PATH, LIST_SAVE_PATH)
37 |
38 | validation_ids = result["validation_ids"]
39 | training_ids = result["training_ids"]
40 | print("Found {} training matches and {} validation matches in pool.".format(len(training_ids), len(validation_ids)))
41 |
42 | validation_matches = dbo.get_matches_by_id(validation_ids, PATH_TO_DB)
43 |
44 | print("***")
45 | print("Displaying Validation matches:")
46 | count = 0
47 | for match in validation_matches:
48 | count += 1
49 | print("Match: {:2} id: {:4} {:6} vs {:6} winner: {:2}".format(count, match["id"], match["blue_team"], match["red_team"], match["winner"]))
50 | for team in ["blue", "red"]:
51 | bans = match[team]["bans"]
52 | picks = match[team]["picks"]
53 | pretty_bans = []
54 | pretty_picks = []
55 | for ban in bans:
56 | pretty_bans.append(cinfo.champion_name_from_id(ban[0]))
57 | for pick in picks:
58 | pretty_picks.append((cinfo.champion_name_from_id(pick[0]), pick[1]))
59 | print("{} bans:{}".format(team, pretty_bans))
60 | print("{} picks:{}".format(team, pretty_picks))
61 | print("")
62 | print("***")
63 |
64 | # Network parameters
65 | state = DraftState(DraftState.BLUE_TEAM, valid_champ_ids)
66 | input_size = state.format_state().shape
67 | output_size = state.num_actions
68 | filter_size = (1024,1024)
69 | regularization_coeff = 7.5e-5#1.5e-4
70 | path_to_model = None#"model_predictions/spring_2018/week_3/model_E{}.ckpt".format(30)#None
71 | load_path = None#"tmp/ddqn_model_E45.ckpt"
72 |
73 | # Training parameters
74 | batch_size = 16#32
75 | buffer_size = 4096#2048
76 | n_epoch = 45
77 | discount_factor = 0.9
78 | learning_rate = 1.0e-4#2.0e-5#
79 | time.sleep(2.)
80 | for i in range(1):
81 | training_matches = dbo.get_matches_by_id(training_ids, PATH_TO_DB)
82 | print("Learning on {} matches for {} epochs. lr {:.4e} reg {:4e}".format(len(training_matches),n_epoch, learning_rate, regularization_coeff),flush=True)
83 | break
84 |
85 | tf.reset_default_graph()
86 | name = "softmax"
87 | out_path = "{}{}_model_E{}.ckpt".format(MODEL_DIR, name, n_epoch)
88 | softnet = softmax.SoftmaxNetwork(name, out_path, input_size, output_size, filter_size, learning_rate, regularization_coeff)
89 | trainer = SoftmaxTrainer(softnet, n_epoch, training_matches, validation_matches, batch_size, load_path=None)
90 | summaries = trainer.train()
91 |
92 | tf.reset_default_graph()
93 | name = "ddqn"
94 | out_path = "{}{}_model_E{}.ckpt".format(MODEL_DIR, name, n_epoch)
95 | ddqn = qNetwork.Qnetwork(name, out_path, input_size, output_size, filter_size, learning_rate, regularization_coeff, discount_factor)
96 | trainer = DDQNTrainer(ddqn, n_epoch, training_matches, validation_matches, batch_size, buffer_size, load_path)
97 | summaries = trainer.train()
98 |
99 | print("Learning complete!")
100 | print("..final training accuracy: {:.4f}".format(summaries["train_acc"][-1]))
101 | x = [i+1 for i in range(len(summaries["loss"]))]
102 | fig = plt.figure()
103 | plt.plot(x,summaries["loss"])
104 | plt.ylabel('loss')
105 | plt.xlabel('epoch')
106 | #plt.ylim([0,2])
107 | fig_name = "tmp/loss_figures/annuled_rate/loss_E{}_run_{}.pdf".format(n_epoch,i+1)
108 | print("Loss figure saved in:{}".format(fig_name),flush=True)
109 | fig.savefig(fig_name)
110 |
111 | fig = plt.figure()
112 | plt.plot(x, summaries["train_acc"], x, summaries["val_acc"])
113 | fig_name = "tmp/acc_figs/acc_E{}_run_{}.pdf".format(n_epoch,i+1)
114 | print("Accuracy figure saved in:{}".format(fig_name),flush=True)
115 | fig.savefig(fig_name)
116 |
117 |
118 | # Look at predicted Q values for states in a randomly drawn match
119 | match = random.sample(training_matches,1)[0]
120 | team = DraftState.RED_TEAM if match["winner"]==1 else DraftState.BLUE_TEAM
121 | experiences = mp.process_match(match,team)
122 | count = 0
123 | # x labels for q val plots
124 | xticks = []
125 | xtick_locs = []
126 | for a in range(state.num_actions):
127 | cid,pos = state.format_action(a)
128 | if cid not in xticks:
129 | xticks.append(cid)
130 | xtick_locs.append(a)
131 | xtick_labels = [cinfo.champion_name_from_id(cid)[:6] for cid in xticks]
132 |
133 | tf.reset_default_graph()
134 | #path_to_model = "../models/ddqn_model_E{}".format(45)#"tmp/ddqn_model_E45"#"tmp/model_E{}".format(n_epoch)
135 | #model = QNetInferenceModel(name="infer", path=path_to_model)
136 | path_to_model = "../models/softmax_model_E{}".format(45)#"tmp/ddqn_model_E45"#"tmp/model_E{}".format(n_epoch)
137 | model = SoftmaxInferenceModel(name="infer", path=path_to_model)
138 |
139 | for exp in experiences:
140 | state,act,rew,next_state = exp
141 | cid,pos = act
142 | if cid == None:
143 | continue
144 | count += 1
145 | form_act = state.get_action(cid,pos)
146 | pred_act = model.predict_action([state])
147 | pred_act = pred_act[0]
148 | pred_Q = model.predict([state])
149 | pred_Q = pred_Q[0,:]
150 |
151 | p_cid,p_pos = state.format_action(pred_act)
152 | actual = (cinfo.champion_name_from_id(cid),pos,pred_Q[form_act])
153 | pred = (cinfo.champion_name_from_id(p_cid),p_pos,pred_Q[pred_act])
154 | print("pred:{}, actual:{}".format(pred,actual))
155 |
156 | # Plot Q-val figure
157 | fig = plt.figure(figsize=(25,5))
158 | plt.ylabel('$Q(s,a)$')
159 | plt.xlabel('$a$')
160 | plt.xticks(xtick_locs, xtick_labels, rotation=70)
161 | plt.tick_params(axis='x',which='both',labelsize=6)
162 | x = np.arange(len(pred_Q))
163 | plt.bar(x,pred_Q, align='center',alpha=0.8,color='b')
164 | plt.bar(pred_act, pred_Q[pred_act],align='center',color='r')
165 | plt.bar(form_act, pred_Q[form_act],align='center',color='g')
166 |
167 | fig_name = "tmp/qval_figs/{}.pdf".format(count)
168 | fig.savefig(fig_name)
169 |
170 | print("")
171 | print("********************************")
172 | print("** Ending Swain Bot Run! **")
173 | print("********************************")
174 |
--------------------------------------------------------------------------------
/src/data/match_pool.py:
--------------------------------------------------------------------------------
1 | import random
2 | import json
3 | import sqlite3
4 | from .database_ops import get_matches_by_id, get_game_ids, get_match_data, get_game_ids_by_tournament, get_tournament_data
5 |
6 | def test_train_split(n_training, n_validation, path_to_db, list_path=None, save_path=None, match_sources=None, prune_patches=None):
7 | """
8 | test_train_split returns a match_ids split into two nominal groups: a training set and a test set.
9 | Args:
10 | n_training (int): number of match_ids in training split
11 |         n_validation (int): number of match_ids in the validation split
12 | path_to_db (str): path to database containing match data
13 |         list_path (str, optional): path to an existing match id split to grow or prune
14 | save_path (str, optional): path to save split match ids
15 | match_sources (dict, optional): dictionary containing "patches" and "tournaments" keys containing lists of tournament and patch identifiers to use for match sources
16 | prune_patches (list, optional): list of patches to prune from existing split match ids
17 |
18 | Returns:
19 | Dictionary {"training_ids":list(int),"validation_ids":list(int)}
20 | """
21 | save_match_pool = False
22 | validation_ids = []
23 | training_ids = []
24 | if list_path:
25 | print("Building list off of match data in {}.".format(list_path))
26 | with open(list_path,'r') as infile:
27 | data = json.load(infile)
28 | validation_ids = data["validation_ids"]
29 | training_ids = data["training_ids"]
30 | if prune_patches:
31 | pre_prune_id_count = len(validation_ids)+len(training_ids)
32 | validation_ids = prune_match_list(validation_ids, path=path_to_db, patches=prune_patches)
33 | training_ids = prune_match_list(training_ids, path=path_to_db, patches=prune_patches)
34 | post_prune_id_count = len(validation_ids)+len(training_ids)
35 | save_match_pool = True
36 | print("Pruned {} matches from the match list".format(pre_prune_id_count-post_prune_id_count))
37 |
38 | val_diff = max([n_validation - len(validation_ids),0])
39 | train_diff = max([n_training - len(training_ids),0])
40 |
41 | current = []
42 | current.extend(validation_ids)
43 | current.extend(training_ids)
44 |
45 | count = val_diff + train_diff
46 | if(count > 0):
47 | new_matches = grow_pool(count, current, path_to_db, match_sources)
48 | if(val_diff>0):
49 | print("Insufficient number of validation matches. Attempting to add difference..")
50 | validation_ids.extend(new_matches[:val_diff])
51 | if(train_diff>0):
52 | print("Insufficient number of training matches. Attempting to add difference..")
53 | training_ids.extend(new_matches[val_diff:count])
54 | print("Successfully added {} matches to validation and {} matches to training.".format(val_diff, train_diff))
55 | save_match_pool = True
56 |
57 | if(save_match_pool and save_path):
58 | print("Saving pool to {}..".format(save_path))
59 | with open(save_path,'w') as outfile:
60 | json.dump({"training_ids":training_ids,"validation_ids":validation_ids},outfile)
61 |
62 | return {"training_ids":training_ids,"validation_ids":validation_ids}
63 |
64 | def grow_pool(count, current_pool, path_to_db, match_sources=None):
65 | total = match_pool(0, path_to_db, randomize=False, match_sources=match_sources)["match_ids"]
66 | new = list(set(total)-set(current_pool))
67 | assert(len(new) >= count), "Not enough new matches to match required count! avail: {} needed: {}".format(len(new), count)
68 | random.shuffle(new)
69 |
70 | return new[:count]
71 |
72 | def prune_match_list(match_ids, path_to_db, patches=None):
73 | """
74 | Prunes match list by removing matches played on specified patches.
75 | """
76 | matches = get_matches_by_id(match_ids, path_to_db)
77 | pruned_match_list = []
78 | for match in matches:
79 | patch = match["patch"]
80 | if patch not in patches:
81 | pruned_match_list.append(match["id"])
82 | return pruned_match_list
83 |
84 | def match_pool(num_matches, path_to_db, randomize=True, match_sources=None):
85 | """
86 | Args:
87 | num_matches (int): Number of matches to include in the queue (0 indicates to use the maximum number of matches available)
88 |         path_to_db (str): Path to match database to query against
89 | randomize (bool): Flag for randomizing order of output matches.
90 | match_sources (dict(string)): Dict containing "tournaments" and "patches" keys to use when building pool, if None, defaults to using patches/tournaments in data/match_sources.json
91 | Returns:
92 | match_data (dictionary): dictionary containing two keys:
93 | "match_ids": list of match_ids for pooled matches
94 | "matches": list of pooled match data to process
95 |
96 | Builds a set of matchids and match data used during learning phase. If randomize flag is set
97 | to false this returns the first num_matches in order according to match_sources.
98 | """
99 | if(match_sources is None):
100 | with open("../data/match_sources.json") as infile:
101 | data = json.load(infile)
102 | patches = data["match_sources"]["patches"]
103 | tournaments = data["match_sources"]["tournaments"]
104 | else:
105 | patches = match_sources["patches"]
106 | tournaments = match_sources["tournaments"]
107 |
108 |     # If patches or tournaments is empty, grab matches from all patches for the specified tournaments or from all tournaments for the specified patches
109 | if not patches:
110 | patches = [None]
111 | if not tournaments:
112 | tournaments = [None]
113 |
114 | match_pool = []
115 | conn = sqlite3.connect(path_to_db)
116 | cur = conn.cursor()
117 | # Build list of eligible match ids
118 | for patch in patches:
119 | for tournament in tournaments:
120 | game_ids = get_game_ids(cur, tournament, patch)
121 | match_pool.extend(game_ids)
122 |
123 | print("Number of available matches for training={}".format(len(match_pool)))
124 | if(num_matches == 0):
125 | num_matches = len(match_pool)
126 | assert num_matches <= len(match_pool), "Not enough matches found to sample!"
127 | if(randomize):
128 | selected_match_ids = random.sample(match_pool, num_matches)
129 | else:
130 | selected_match_ids = match_pool[:num_matches]
131 |
132 | selected_matches = []
133 | for match_id in selected_match_ids:
134 | match = get_match_data(cur, match_id)
135 | selected_matches.append(match)
136 | conn.close()
137 | return {"match_ids":selected_match_ids, "matches":selected_matches}
138 |
139 | if __name__ == "__main__":
140 | match_sources = {"patches":[], "tournaments": ["2018/NA/Summer_Season"]}
141 | path_to_db = "../data/competitiveMatchData.db"
142 | out_path = "../data/test_train_split.txt"
143 | res = test_train_split(20, 20, path_to_db, list_path=out_path, save_path=out_path, match_sources=match_sources, prune_patches=None)
144 | print(res)
145 |
--------------------------------------------------------------------------------
/src/data/champion_info.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | #from cassiopeia import riotapi
3 | from .riotapi import make_request, api_versions
4 | import requests
5 | import re
6 | import json
7 | from . import myRiotApiKey
8 |
9 | # Box is a vacant class with no initial members. This will be used to hold the champion_id list and champion_id <-> name dictionaries.
10 |
11 | class Box:
12 | pass
13 | __m = Box()
14 | __m.champion_name_from_id = None
15 | __m.champion_id_from_name = None
16 | __m.valid_champion_ids = None
17 | __m.championAliases = {
18 | "blitz": "blitzcrank",
19 | "gp": "gangplank",
20 | "jarvan": "jarvaniv",
21 | "cait": "caitlyn",
22 | "lb": "leblanc",
23 | "cass": "cassiopeia",
24 | "casiopeia": "cassiopeia",
25 | "ori": "orianna",
26 | "lee": "leesin",
27 | "vlad": "vladimir",
28 | "j4": "jarvaniv",
29 | "as": "aurelionsol", # who the fuck thinks this is unique?
30 | "kass": "kassadin",
31 | "tk": "tahmkench",
32 | "malz": "malzahar",
33 | "sej": "sejuani",
34 | "nid": "nidalee",
35 | "aurelion": "aurelionsol",
36 | "mundo": "drmundo",
37 | "tahm": "tahmkench",
38 | "kayne": "kayn",
39 | "zil": "zilean",
40 | "naut": "nautilus",
41 | "morg": "morgana",
42 | "ez": "ezreal",
43 | "nunu": "nunuwillump",
44 | "yi": "masteryi",
45 | "cho":"chogath",
46 | "mao":"maokai",
47 | "morde":"mordekaiser",
48 | "xin":"xinzhao",
49 | "eve":"evelynn",
50 | "sol":"aurelionsol",
51 | "fid":"fiddlesticks",
52 | "fiddle":"fiddlesticks",
53 | "heimer":"heimerdinger",
54 | "noc":"nocturne",
55 | "tf":"twistedfate",
56 | "kog":"kogmaw"
57 | }
58 |
59 | # This is a flag to make champion_info methods look for data stored locally
60 | # rather than query the API. Useful if API is out.
61 | look_local = True
62 | LOCAL_CHAMPION_PATH = "../data/champions.json"
63 |
64 | class Champion():
65 | def __init__(self,dictionary):
66 | if(look_local):
67 | # Local file is a cached query to data dragon
68 | # Data dragon reverses the meaning of keys and ids from the API.
69 | self.key = dictionary["id"]
70 | self.id = int(dictionary["key"])
71 | else:
72 | self.key = dictionary["key"]
73 | self.id = int(dictionary["id"])
74 | self.name = dictionary["name"]
75 | self.title = dictionary["title"]
76 |
77 | class AliasException(Exception):
78 | def __init__(self, message, errors):
79 | super().__init__(message)
80 | self.errors = errors
81 | self.message = message
82 |
83 | def convert_champion_alias(alias):
84 | """
85 | Args:
86 | alias (string): lowercase and pruned string alias for a champion
87 | Returns:
88 | name (string): lowercase and pruned string name for champion
89 |
90 |     convert_champion_alias converts a given champion alias (i.e. "blitz")
91 |     to that champion's proper name, in the form suitable for passing to
92 |     champion_id_from_name(). If no such alias can be found, this raises an AliasException.
93 |
94 | Example: name = convert_champion_alias("blitz") will yield name = "blitzcrank"
95 | """
96 | null_champion = ["none","lossofban"]
97 | if (alias in null_champion):
98 | return "none"
99 | try:
100 | if (alias in __m.championAliases):
101 | return __m.championAliases[alias]
102 | else:
103 | raise AliasException("Champion alias not found!", alias)
104 | except AliasException as e:
105 | print("*****")
106 | print(e.message)
107 | print("Offending alias: {}".format(e.errors))
108 | print("*****")
109 | raise
110 |
111 | def champion_name_from_id(champion_id):
112 | """
113 | Args:
114 | champion_id (int): Integer Id corresponding to the desired champion name.
115 | Returns:
116 |         name (string): String name of requested champion. If no such champion can be found, returns None
117 |
118 |     champion_name_from_id takes a requested champion_id number and returns the string name of that champion using the module's champion_name_from_id dictionary.
119 |     If the dictionary has not yet been populated, it is built by populate_champion_dictionary() from the local champion cache or Riot's static data API.
120 | """
121 | if __m.champion_name_from_id is None:
122 | populate_champion_dictionary()
123 |
124 | if (champion_id in __m.champion_name_from_id):
125 | return __m.champion_name_from_id[champion_id]
126 | return None
127 |
128 | def champion_id_from_name(champion_name):
129 | """
130 | Args:
131 | champion_name (string): lowercase and pruned string label corresponding to the desired champion id.
132 | Returns:
133 |         id (int): id of requested champion. If no such champion can be found, returns None
134 |
135 |     champion_id_from_name takes a requested champion name and returns the id of that champion using the module's champion_id_from_name dictionary.
136 |     If the dictionary has not yet been populated, it is built by populate_champion_dictionary() from the local champion cache or Riot's static data API.
137 | Note that champion_name should be all lowercase and have any non-alphanumeric characters (including whitespace) removed.
138 | """
139 | if __m.champion_id_from_name is None:
140 | populate_champion_dictionary()
141 |
142 | if (champion_name in __m.champion_id_from_name):
143 | return __m.champion_id_from_name[champion_name]
144 | return None
145 |
146 | def valid_champion_id(champion_id):
147 | """
148 | Checks to see if champion_id corresponds to a valid champion id code.
149 | Returns: True if champion_id is valid. False otherwise.
150 | Args:
151 | champion_id (int): Id of champion to be verified.
152 | """
153 | if __m.champion_name_from_id is None:
154 | populate_champion_dictionary()
155 |
156 | return champion_id in __m.valid_champion_ids
157 |
158 | def get_champion_ids():
159 | """
160 | Returns a sorted list of valid champion IDs.
161 | Args:
162 | None
163 | Returns:
164 | validIds (list(ints)): sorted list of valid champion IDs.
165 | """
166 | if __m.valid_champion_ids is None:
167 | populate_champion_dictionary()
168 |
169 | return __m.valid_champion_ids[:]
170 |
171 | def populate_champion_dictionary():
172 | """
173 | Args:
174 | None
175 | Returns:
176 |         True if successful, False otherwise
177 | Populates the module dictionary whose keys are champion Ids and values are strings of the corresponding champion's name.
178 | """
179 | #riotapi.set_region("NA")
180 | #riotapi.set_api_key(myRiotApiKey.api_key)
181 | #champions = riotapi.get_champions()
182 | DISABLED_CHAMPIONS = []
183 | if(look_local):
184 | with open(LOCAL_CHAMPION_PATH, 'r') as local_data:
185 | response = json.load(local_data)
186 | else:
187 | request = "{static}/{version}/champions".format(static="static-data",version=api_versions["staticdata"])
188 | params = {"locale":"en_US", "dataById":"true", "api_key":myRiotApiKey.api_key }
189 | response = make_request(request,"GET",params)
190 | data = response["data"]
191 | champions = []
192 | for value in data.values():
193 | if(value["name"] in DISABLED_CHAMPIONS):
194 | continue
195 | champion = Champion(value)
196 | champions.append(champion)
197 |
198 | __m.champion_name_from_id = {champion.id: champion.name for champion in champions}
199 | __m.champion_id_from_name = {re.sub("[^A-Za-z0-9]+", "", champion.name.lower()): champion.id for champion in champions}
200 | __m.valid_champion_ids = sorted(__m.champion_name_from_id.keys())
201 | if not __m.champion_name_from_id:
202 | return False
203 | return True
204 |
205 | def create_Champion_fixture():
206 | valid_ids = get_champion_ids()
207 | champions = []
208 | model = 'predict.Champion'
209 | for cid in valid_ids:
210 | champion = {}
211 | champion["model"] = model
212 | champion["pk"] = cid
213 | fields = {}
214 | fields["id"] = cid
215 | fields["display_name"] = champion_name_from_id(cid)
216 | champion["fields"] = fields
217 | champions.append(champion)
218 | with open('champions_fixture.json','w') as outfile:
219 | json.dump(champions,outfile)
220 |
221 | if __name__ == "__main__":
222 | create_Champion_fixture()
223 |
--------------------------------------------------------------------------------
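A minimal usage sketch for the lookup helpers in champion_info.py above. It assumes the local cache at ../data/champions.json exists (look_local defaults to True) and that the module is imported from src/ as data.champion_info per the repo tree; both are assumptions, not guarantees made by this file.

    # Sketch only: resolve an alias to a pruned name, then round-trip id <-> name.
    from data.champion_info import (champion_id_from_name, champion_name_from_id,
                                    convert_champion_alias)  # assumed import path

    name = convert_champion_alias("blitz")       # -> "blitzcrank"
    cid = champion_id_from_name(name)            # first call populates the module dictionaries
    print(cid, champion_name_from_id(cid))       # id plus the display name, e.g. "Blitzcrank"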
/src/features/match_processing.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 | from .draftstate import DraftState
3 | from .rewards import get_reward
4 | from copy import deepcopy
5 |
6 | import random
7 | import json
8 |
9 | def process_match(match, team, augment_data=True):
10 | """
11 |     process_match takes an input match and breaks the draft down, pick by pick and ban by ban, into experiences (aka "memories").
12 |
13 | Args:
14 | match (dict): match dictionary with pick and ban data for a single game.
15 | team (DraftState.BLUE_TEAM or DraftState.RED_TEAM): The team perspective that is used to process match
16 | The selected team has the positions for each pick explicitly included with the experience while the
17 | "opposing" team has the assigned positions for its champion picks masked.
18 | augment_data (optional) (bool): flag controlling the randomized ordering of submissions that do not affect the draft as a whole
19 | Returns:
20 | experiences ( list(tuple) ): list of experience tuples. Each experience is of the form (s, a, r, s') where:
21 | - s and s' are DraftState states before and after a single action
22 | - a is the (stateIndex, position) tuple of selected champion to be banned or picked. position = 0 for submissions
23 | by the opposing team
24 | - r is the integer reward obtained from submitting the action a
25 |
26 |     process_match() can parse the draft from either team's vantage point for memories. This means we can ultimately sample from
27 | both winning drafts (positive reinforcement) and losing drafts (negative reinforcement) when training.
28 | """
29 | experiences = []
30 |
31 |     # This section controls data augmentation of the match. Certain submissions in the draft are
32 | # submitted consecutively by the same team during the same phase (ie team1 pick0 -> team1 pick1).
33 | # Although these submissions were produced in a particular order, from a draft perspective
34 | # there is no difference between submissions of the form
35 |     #   team1 pick0 -> team1 pick1   vs.   team1 pick1 -> team1 pick0
36 | # provided that the two picks are from the same phase (both bans or both picks).
37 | # Therefore it is possible to augment the order in which these submissions are processed.
38 |
39 | # Note that we can also augment the banning phase if desired. Although these submissions technically
40 | # fall outside of the conditions listed above, in practice bans made in the same phase are
41 |     # interchangeable in order.
42 |
43 | # Build queue of actions from match reference (augmenting if desired)
44 | augments_list = [
45 | ("blue","bans",slice(0,3)), # Blue bans 0,1,2 are augmentable
46 | ("blue","bans",slice(3,5)), # Blue bans 3,4 are augmentable
47 | ("red","bans",slice(0,3)),
48 | ("red","bans",slice(3,5)),
49 | ("blue","picks",slice(1,3)), # Blue picks 1,2 are augmentable
50 | ("blue","picks",slice(3,5)), # Blue picks 3,4 are augmentable
51 | ("red","picks",slice(0,2)) # Red picks 0,1 are augmentable
52 | ]
53 | if(augment_data):
54 | augmented_match = deepcopy(match) # Deepcopy match to avoid side effects
55 | for aug in augments_list:
56 | (k1,k2,aug_range) = aug
57 | count = len(augmented_match[k1][k2][aug_range])
58 | augmented_match[k1][k2][aug_range] = random.sample(augmented_match[k1][k2][aug_range],count)
59 |
60 | action_queue = build_action_queue(augmented_match)
61 | else:
62 | action_queue = build_action_queue(match)
63 |
64 | # Set up draft state
65 | draft = DraftState(team)
66 |
67 | finish_memory = False
68 | while action_queue:
69 | # Get next pick from deque
70 | submission = action_queue.popleft()
71 | (submitting_team, pick, position) = submission
72 |
73 | # There are two conditions under which we want to finalize a memory:
74 | # 1. Non-designated team has finished submitting picks for this phase (ie next submission belongs to the designated team)
75 | # 2. Draft is complete (no further picks in the draft)
76 | if submitting_team == team:
77 | if finish_memory:
78 | # This is case 1 to store memory
79 | r = get_reward(draft, match, a, a)
80 | s_next = deepcopy(draft)
81 | memory = (s, a, r, s_next)
82 | experiences.append(memory)
83 | finish_memory = False
84 | # Memory starts when upcoming pick belongs to designated team
85 | s = deepcopy(draft)
86 | # Store action = (champIndex, pos)
87 | a = (pick, position)
88 | finish_memory = True
89 | else:
90 | # Mask positions for pick submissions belonging to the non-designated team
91 | if position != -1:
92 | position = 0
93 |
94 | draft.update(pick, position)
95 |
96 | # Once the queue is empty, store last memory. This is case 2 above.
97 | # There is always an outstanding memory at the completion of the draft.
98 | # RED_TEAM always gets last pick. Therefore:
99 | # if team = BLUE_TEAM -> There is an outstanding memory from last RED_TEAM submission
100 | # if team = RED_TEAM -> Memory is open from just before our last submission
101 | if(draft.evaluate() == DraftState.DRAFT_COMPLETE):
102 | assert finish_memory == True
103 | r = get_reward(draft, match, a, a)
104 | s_next = deepcopy(draft)
105 | memory = (s, a, r, s_next)
106 | experiences.append(memory)
107 | else:
108 | print("Week {} match_id {} {} vs {}".format(match["week"], match["id"], match["blue_team"],match["red_team"]))
109 | draft.display()
110 | print("Error code {}".format(draft.evaluate()))
111 | print("Number of experiences {}".format(len(experiences)))
112 | for experience in experiences:
113 | _,a,_,_ = experience
114 | print(a)
115 | print("")#raise
116 |
117 | return experiences
118 |
119 | def build_action_queue(match):
120 | """
121 |     Builds a queue of the draft's ban and pick actions in selection order.
122 | Args:
123 |         match (dict): dictionary structure of match data to be parsed
124 | Returns:
125 | action_queue (deque(tuple)): deque of pick tuples of the form (side_id, champion_id, position_id).
126 | action_queue is produced in selection order.
127 | """
128 | winner = match["winner"]
129 | action_queue = deque()
130 | phases = {0:{"phase_type":"bans", "pick_order":["blue", "red", "blue", "red", "blue", "red"]}, # phase 1 bans
131 | 1:{"phase_type":"picks", "pick_order":["blue", "red", "red", "blue", "blue", "red"]}, # phase 1 picks
132 | 2:{"phase_type":"bans", "pick_order":["red", "blue", "red", "blue"]}, # phase 2 bans
133 | 3:{"phase_type":"picks","pick_order":["red", "blue", "blue", "red"]}} # phase 2 picks
134 | ban_index = 0
135 | pick_index = 0
136 | completed_actions = 0
137 | for phase in range(4):
138 | phase_type = phases[phase]["phase_type"]
139 | pick_order = phases[phase]["pick_order"]
140 |
141 | num_actions = len(pick_order)
142 | for pick_num in range(num_actions):
143 | side = pick_order[pick_num]
144 | if side == "blue":
145 | side_id = DraftState.BLUE_TEAM
146 | else:
147 | side_id = DraftState.RED_TEAM
148 | if phase_type == "bans":
149 | position_id = -1
150 | index = ban_index
151 | ban_index += pick_num%2 # Order matters here. index needs to be updated *after* use
152 | else:
153 | position_id = match[side][phase_type][pick_index][1]
154 | index = pick_index
155 | pick_index += pick_num%2 # Order matters here. index needs to be updated *after* use
156 | action = (side_id, match[side][phase_type][index][0], position_id)
157 | action_queue.append(action)
158 | completed_actions += 1
159 |
160 | if(completed_actions != 20):
161 | print("Found a match with missing actions!")
162 | print("num_actions = {}".format(num_actions))
163 | print(json.dumps(match, indent=2, sort_keys=True))
164 | return action_queue
165 |
166 | if __name__ == "__main__":
167 | data = build_match_pool(1, patches=["8.3"])
168 | matches = data["matches"]
169 | for match in matches:
170 | print(match["patch"])
171 | for team in [DraftState.BLUE_TEAM, DraftState.RED_TEAM]:
172 | for augment_data in [False, True]:
173 | experiences = process_match(match, team, augment_data)
174 | count = 0
175 | for exp in experiences:
176 | _,a,_,_ = exp
177 | print("{} - {}".format(count,a))
178 | count+=1
179 | print("")
180 |
181 | data = build_match_pool(0, randomize=False, patches=["8.4","8.5"])
182 | # matches = data["matches"]
183 | # for match in matches:
184 | # print("Week {}, Patch {}: {} vs {}. Winner:{}".format(match["week"], match["patch"], match["blue_team"], match["red_team"], match["winner"]))
185 |
--------------------------------------------------------------------------------
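A minimal usage sketch for process_match() above, pairing it with get_match_data() from the database module. The import paths and database path below are assumptions based on the repo tree, not something shown in this file:

    # Sketch only: turn one stored match into (s, a, r, s') draft experiences.
    import sqlite3
    from data.database_ops import get_match_data           # assumed import path
    from features.match_processing import process_match
    from features.draftstate import DraftState

    conn = sqlite3.connect("../data/competitiveMatchData.db")  # db path used elsewhere in this repo
    match = get_match_data(conn.cursor(), 1)                    # match id 1 is arbitrary
    conn.close()

    # Each experience is an (s, a, r, s') tuple as described in the process_match docstring.
    for state, action, reward, next_state in process_match(match, DraftState.BLUE_TEAM, augment_data=False):
        print(action, reward)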
/src/model_predictions.py:
--------------------------------------------------------------------------------
1 | import experience_replay as er
2 | import match_processing as mp
3 | import champion_info as cinfo
4 | import draft_db_ops as dbo
5 | from draftstate import DraftState
6 | from models.inference_model import QNetInferenceModel, SoftmaxInferenceModel
7 |
8 | import json
9 | import pandas as pd
10 | import numpy as np
11 | import tensorflow as tf
12 | import sqlite3
13 | import math
14 |
15 | #path_to_model = "model_predictions/spring_2018/week_3/run_2/model_E10"
16 | #path_to_model = "tmp/models/model_E10"
17 | #path_to_model = "tmp/model_E{}".format(45)
18 |
19 | path_to_model = "tmp/ddqn_model_E{}".format(45)
20 | model = QNetInferenceModel(name="ddqn", path=path_to_model)
21 |
22 | #path_to_model = "tmp/softmax_model_E{}".format(45)
23 | #model = SoftmaxInferenceModel(name="softmax", path=path_to_model)
24 | print("***")
25 | print("Loading Model From: {}".format(path_to_model))
26 | print("***")
27 |
28 | out_dir = "model_predictions/dump"
29 | print("***")
30 | print("Outputting To: {}".format(out_dir))
31 | print("***")
32 |
33 | specific_team = None#"tsm"
34 | print("***")
35 | if(specific_team):
36 | print("Looking at drafts by team:{}".format(specific_team))
37 | else:
38 | print("Looking at drafts submitted by winning team")
39 | print("***")
40 |
41 | #with open('worlds_matchids_by_stage.txt','r') as infile:
42 | # data = json.load(infile)
43 | #match_ids = data["groups"]
44 | #match_ids.extend(data["knockouts"])
45 | #match_ids.extend(data["finals"])
46 | #match_ids.extend(data["play_ins_rd1"])
47 | #match_ids.extend(data["play_ins_rd2"])
48 | with open('match_pool.txt','r') as infile:
49 | data = json.load(infile)
50 | match_ids = data['validation_ids']
51 | #match_ids = data['training_ids']
52 | #match_ids.extend(data['training_ids'])
53 | dbName = "competitiveGameData.db"
54 | conn = sqlite3.connect("tmp/"+dbName)
55 | cur = conn.cursor()
56 | #match_ids = dbo.get_game_ids_by_tournament(cur,"2017/INTL/WRLDS")
57 | matches = [dbo.get_match_data(cur,match_id) for match_id in match_ids]
58 | conn.close()
59 | if(specific_team):
60 | matches = [match for match in matches if (match["blue_team"]==specific_team or match["red_team"]==specific_team)]
61 |
62 | count = 0
63 | print("************************")
64 | print("Match Schedule:")
65 | print("************************")
66 | with open("{}/_match_schedule.txt".format(out_dir),'w') as outfile:
67 | outfile.write("************************\n")
68 | for match in matches:
69 | output_string = "Match {:2}: id: {:5} tourn: {:20} game_no: {:3} {:6} vs {:6} winner: {:2}".format(count, match["id"], match["tournament"], match["tourn_game_id"], match["blue_team"], match["red_team"], match["winner"])
70 | print(output_string)
71 | outfile.write(output_string+'\n')
72 | count += 1
73 | outfile.write("************************\n")
74 |
75 | with open("{}/match_data.json".format(out_dir),'w') as outfile:
76 | json.dump(matches,outfile)
77 |
78 | count = 0
79 | k = 5 # Rank to look for in topk range
80 |
81 | full_diag = {"top1":0, "topk":0, "target":0, "l2":[],"k":k}
82 | no_rd1_ban_diag = {"top1":0, "topk":0, "target":0, "l2":[],"k":k}
83 | no_ban_diag = {"top1":0, "topk":0, "target":0, "l2":[],"k":k}
84 | second_phase_only = {"top1":0, "topk":0, "target":0, "l2":[],"k":k}
85 | bans_only = {"top1":0, "topk":0, "target":0, "l2":[],"k":k}
86 | model_diagnostics = {"full":full_diag, "no_rd1_ban":no_rd1_ban_diag, "no_bans":no_ban_diag, "phase_2_only":second_phase_only, "bans":bans_only}
87 |
88 | position_distributions = {"phase_1":[0,0,0,0,0], "phase_2":[0,0,0,0,0]}
89 | actual_pos_distributions = {"phase_1":[0,0,0,0,0], "phase_2":[0,0,0,0,0]}
90 | augmentable_picks = {DraftState.BLUE_TEAM:[0,1,4,6,8], DraftState.RED_TEAM:[0,1,3,6]}
91 | targets = [10,10,10,9,8,7,6,6,6,5]
92 | for match in matches:
93 | # if(specific_team):
94 | # team = DraftState.RED_TEAM if match["red_team"]==specific_team else DraftState.BLUE_TEAM
95 | # else:
96 | # team = DraftState.RED_TEAM if match["winner"]==1 else DraftState.BLUE_TEAM
97 | # teams = [DraftState.BLUE_TEAM, DraftState.RED_TEAM]
98 | teams = [DraftState.RED_TEAM if match["winner"]==1 else DraftState.BLUE_TEAM]
99 | for team in teams:
100 |
101 | experiences = mp.process_match(match, team, augment_data=False)
102 |
103 | print("")
104 | print("Match: {:2} {:6} vs {:6} winner: {:2}".format(count, match["blue_team"], match["red_team"], match["winner"]))
105 | for pick_count, exp in enumerate(experiences):
106 | print(" === ")
107 | print(" Match {}, Pick {}".format(count, pick_count))
108 | print(" === ")
109 | state,act,rew,next_state = exp
110 | cid,pos = act
111 |             if cid is None:
112 | continue
113 |
114 | predicted_q_values = model.predict([state])
115 | predicted_q_values = predicted_q_values[0,:]
116 | submitted_action_id = state.get_action(*act)
117 |
118 | data = [(a,*state.format_action(a),predicted_q_values[a]) for a in range(len(predicted_q_values))]
119 | data = [(a,cinfo.champion_name_from_id(cid),pos,Q) for (a,cid,pos,Q) in data]
120 | df = pd.DataFrame(data, columns=['act_id','cname','pos','Q(s,a)'])
121 |
122 | df.sort_values('Q(s,a)',ascending=False,inplace=True)
123 | df.reset_index(drop=True,inplace=True)
124 |
125 | df['rank'] = df.index
126 | df['error'] = abs(df['Q(s,a)'][0] - df['Q(s,a)'])/abs(df['Q(s,a)'][0])
127 |
128 | submitted_row = df[df['act_id']==submitted_action_id]
129 | print(" Submitted action:")
130 | print(submitted_row)
131 |
132 | rank = submitted_row['rank'].iloc[0]
133 | err = submitted_row['error'].iloc[0]
134 |
135 | # For picks submitted back-to-back look ahead to next action to see if it was possibly recommended
136 | if (rank >= k and pick_count in augmentable_picks[team]):#if False:
137 | _,next_action,_,_ = experiences[pick_count+1]
138 | cid,_ = next_action
139 | if(cid):
140 | next_action_id = state.get_action(*next_action)
141 | next_row = df[df['act_id']==next_action_id]
142 | next_rank = next_row['rank'].iloc[0]
143 | if(next_rank < k):
144 | result = state.update(*next_action)
145 | new_exp = (state, act, rew, None)
146 | experiences[pick_count+1] = new_exp
147 | rank = next_rank
148 | print(" AUGMENTED ACTION:")
149 | print(next_row)
150 |
151 | t = targets[pick_count]
152 | # Norms measuring all submissions
153 | if(rank == 0):
154 | model_diagnostics["full"]["top1"] += 1
155 | if(rank < t):
156 | model_diagnostics["full"]["target"] += 1
157 | if(rank < k):
158 | model_diagnostics["full"]["topk"] += 1
159 | model_diagnostics["full"]["l2"].append(err)
160 |
161 | # Norms excluding round 1 bans
162 | if(pick_count > 2):
163 | if(rank == 0):
164 | model_diagnostics["no_rd1_ban"]["top1"] += 1
165 | if(rank < t):
166 | model_diagnostics["no_rd1_ban"]["target"] += 1
167 | if(rank < k):
168 | model_diagnostics["no_rd1_ban"]["topk"] += 1
169 | model_diagnostics["no_rd1_ban"]["l2"].append(err)
170 |
171 | # Norms excluding round 1 completely
172 | if(pick_count > 5):
173 | if(rank == 0):
174 | model_diagnostics["phase_2_only"]["top1"] += 1
175 | if(rank < t):
176 | model_diagnostics["phase_2_only"]["target"] += 1
177 | if(rank < k):
178 | model_diagnostics["phase_2_only"]["topk"] += 1
179 | model_diagnostics["phase_2_only"]["l2"].append(err)
180 |
181 | # Norms excluding all bans
182 | if(pos != -1):
183 | if(rank == 0):
184 | model_diagnostics["no_bans"]["top1"] += 1
185 | if(rank < t):
186 | model_diagnostics["no_bans"]["target"] += 1
187 | if(rank < k):
188 | model_diagnostics["no_bans"]["topk"] += 1
189 | model_diagnostics["no_bans"]["l2"].append(err)
190 |
191 | # Norms for bans only
192 | if(pos == -1):
193 | if(rank == 0):
194 | model_diagnostics["bans"]["top1"] += 1
195 | if(rank < t):
196 | model_diagnostics["bans"]["target"] += 1
197 | if(rank < k):
198 | model_diagnostics["bans"]["topk"] += 1
199 | model_diagnostics["bans"]["l2"].append(err)
200 |
201 | if(rank >= t):
202 | print(" Top predictions:")
203 | print(df.head()) # Print top 5 choices for network
204 | #df.to_pickle("{}/match{}_pick{}.pkl".format(out_dir,count,pick_count))
205 |
206 | # Position distribution for picks
207 | if(pos > 0):
208 | top_pos = df.head()["pos"].values.tolist()
209 | if(pick_count <=5):
210 | actual_pos_distributions["phase_1"][pos-1] += 1
211 | for pos in top_pos:
212 | position_distributions["phase_1"][pos-1] += 1
213 | else:
214 | actual_pos_distributions["phase_2"][pos-1] += 1
215 | for pos in top_pos:
216 | position_distributions["phase_2"][pos-1] += 1
217 |
218 | pick_count += 1
219 | count += 1
220 |
221 | print("******************")
222 | print("Pick position distributions:")
223 | for phase in ["phase_1", "phase_2"]:
224 | print("{}: Recommendations".format(phase))
225 | count = sum(position_distributions[phase])
226 | for pos in range(len(position_distributions[phase])):
227 | pos_ratio = position_distributions[phase][pos] / count
228 | print(" Position {}: Count {:3}, Ratio {:.3}".format(pos+1, position_distributions[phase][pos], pos_ratio))
229 |
230 | print("{}: Actual".format(phase))
231 | count = sum(actual_pos_distributions[phase])
232 | for pos in range(len(actual_pos_distributions[phase])):
233 | pos_ratio = actual_pos_distributions[phase][pos] / count
234 | print(" Position {}: Count {:3}, Ratio {:.3}".format(pos+1, actual_pos_distributions[phase][pos], pos_ratio))
235 |
236 | print("******************")
237 | print("Norm Information:")
238 | for key in sorted(model_diagnostics.keys()):
239 | print(" {}".format(key))
240 | err_list = model_diagnostics[key]["l2"]
241 | err = math.sqrt((sum([e**2 for e in err_list])/len(err_list)))
242 | num_predictions = len(err_list)
243 | top1 = model_diagnostics[key]["top1"]
244 | topk = model_diagnostics[key]["topk"]
245 | target = model_diagnostics[key]["target"]
246 | k = model_diagnostics[key]["k"]
247 | print(" Num_predictions = {}".format(num_predictions))
248 | print(" top 1: count {} -> acc: {:.4}".format(top1, top1/num_predictions))
249 | print(" top {}: count {} -> acc: {:.4}".format(k, topk, topk/num_predictions))
250 | print(" target: count {} -> acc: {:.4}".format(target, target/num_predictions))
251 | print(" l2 error: {:.4}".format(err))
252 | print("---")
253 | print("******************")
254 |
--------------------------------------------------------------------------------
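A minimal sketch of the ranking logic the script above applies to each submission: the submitted action's rank among the model's predicted Q-values drives the top-1/top-k tallies, and the relative error compares its Q-value against the best one. Numpy stands in for the model output and every value below is made up:

    import numpy as np

    predicted_q = np.array([0.1, 0.9, 0.4, 0.7])  # toy Q(s,a) values for four actions
    submitted_action_id = 2                       # action the pro team actually took
    k = 2

    order = np.argsort(-predicted_q)              # action ids sorted by descending Q-value
    rank = int(np.where(order == submitted_action_id)[0][0])

    top1 = (rank == 0)                            # model's best action matches the submission
    topk = (rank < k)                             # submission falls inside the top-k recommendations
    rel_error = abs(predicted_q[order[0]] - predicted_q[submitted_action_id]) / abs(predicted_q[order[0]])
    print(rank, top1, topk, rel_error)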
/src/models/qNetwork.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 | from . import base_model
5 |
6 | class Qnetwork(base_model.BaseModel):
7 | """
8 | Args:
9 | name (string): label for model namespace
10 | path (string): path to save/load model
11 | input_shape (tuple): tuple of inputs to network.
12 | output_shape (int): number of output nodes for network.
13 | filter_sizes (tuple of ints): number of filters in each of the two hidden layers. Defaults to (512,512).
14 |         learning_rate (float): step size used by the optimizer when updating network weights
15 | regularization (float): strength of weights regularization term in loss function
16 |         discount_factor (float): factor by which rewards following the next action are discounted
17 | tau (float): Hyperparameter used in updating target network (if used)
18 | Some notable values:
19 | tau = 1.e-3 -> used in original paper
20 | tau = 0.5 -> average DDQN
21 | tau = 1.0 -> copy online -> target
22 |
23 |     A Q-network class which is responsible for holding and updating the weights and biases used in predicting Q-values for a given state. This Q-network consists of
24 | the following layers:
25 | 1) Input- a DraftState state s (an array of bool) representing the current state reshaped into an [n_batch, *input_shape] tensor.
26 |         2) Two relu-activated hidden fully connected (FC) layers with dropout
27 | 3) Output- linearly activated estimations for Q-values Q(s,a) for each of the output_shape actions a available.
28 |
29 | """
30 | @property
31 | def name(self):
32 | return self._name
33 |
34 | @property
35 | def discount_factor(self):
36 | return self._discount_factor
37 |
38 | def __init__(self, name, path, input_shape, output_shape, filter_sizes=(512,512), learning_rate=1.e-5, regularization_coeff=1.e-4, discount_factor=0.9, tau=1.0):
39 | super().__init__(name=name, path=path)
40 | self._input_shape = input_shape
41 | self._output_shape = output_shape
42 | self._filter_sizes = filter_sizes
43 | self._learning_rate = learning_rate
44 | self._regularization_coeff = regularization_coeff
45 | self._discount_factor = discount_factor
46 | self._n_hidden_layers = len(filter_sizes)
47 | self._n_layers = self._n_hidden_layers + 2
48 | self._tau = tau
49 |
50 | self.online_name = "online"
51 | self.target_name = "target"
52 | # Build base Q-network model
53 | self.online_ops = self.build_model(name = self.online_name)
54 | # If using a target network for DDQN network, add related ops to model
55 | if(self.target_name):
56 | self.target_ops = self.build_model(name = self.target_name)
57 | self.target_ops["target_init"] = self.create_target_initialization_ops(self.target_name, self.online_name)
58 | self.target_ops["target_update"] = self.create_target_update_ops(self.target_name, self.online_name, tau=self._tau)
59 | with self._graph.as_default():
60 | self.online_ops["init"] = tf.global_variables_initializer()
61 | self.init_saver()
62 |
63 | def init_saver(self):
64 | with self._graph.as_default():
65 | self.saver = tf.train.Saver()
66 |
67 | def save(self, path):
68 | self.saver.save(self.sess, save_path=path)
69 |
70 | def load(self, path):
71 | self.saver.restore(self.sess, save_path=path)
72 |
73 | def build_model(self, name):
74 | ops_dict = {}
75 | with self._graph.as_default():
76 | with tf.variable_scope(name):
77 | ops_dict["learning_rate"] = tf.Variable(self._learning_rate, trainable=False, name="learning_rate")
78 |
79 | # Incoming state matrices are of size input_size = (nChampions, nPos+2)
80 | # 'None' here means the input tensor will flex with the number of training
81 | # examples (aka batch size).
82 | ops_dict["input"] = tf.placeholder(tf.float32, (None,)+self._input_shape, name="inputs")
83 | ops_dict["dropout_keep_prob"] = tf.placeholder_with_default(1.0,shape=())
84 |
85 | # Fully connected (FC) layers:
86 | fc0 = tf.layers.dense(
87 | ops_dict["input"],
88 | self._filter_sizes[0],
89 | activation=tf.nn.relu,
90 | bias_initializer=tf.constant_initializer(0.1),
91 | name="fc_0")
92 | dropout0 = tf.nn.dropout(fc0, ops_dict["dropout_keep_prob"])
93 |
94 | fc1 = tf.layers.dense(
95 | dropout0,
96 | self._filter_sizes[1],
97 | activation=tf.nn.relu,
98 | bias_initializer=tf.constant_initializer(0.1),
99 | name="fc_1")
100 | dropout1 = tf.nn.dropout(fc1, ops_dict["dropout_keep_prob"])
101 |
102 | # FC output layer
103 | ops_dict["outQ"] = tf.layers.dense(
104 | dropout1,
105 | self._output_shape,
106 | activation=None,
107 | bias_initializer=tf.constant_initializer(0.1),
108 | kernel_regularizer=tf.contrib.layers.l2_regularizer(scale=self._regularization_coeff),
109 | name="q_vals")
110 |
111 | # Placeholder for valid actions filter
112 | ops_dict["valid_actions"] = tf.placeholder(tf.bool, shape=ops_dict["outQ"].shape, name="valid_actions")
113 |
114 | # Filtered Q-values
115 | ops_dict["valid_outQ"] = tf.where(ops_dict["valid_actions"], ops_dict["outQ"], tf.scalar_mul(-np.inf,tf.ones_like(ops_dict["outQ"])), name="valid_q_vals")
116 |
117 | # Max Q value amongst valid actions
118 | ops_dict["max_Q"] = tf.reduce_max(ops_dict["valid_outQ"], axis=1, name="max_Q")
119 |
120 | # Predicted optimal action amongst valid actions
121 | ops_dict["prediction"] = tf.argmax(ops_dict["valid_outQ"], axis=1, name="prediction")
122 |
123 | # Loss function and optimization:
124 | # The inputs self.target and self.actions are indexed by training example. If
125 | # s[i] = starting state for ith training example (recall that input state s is described by a vector so this will be a matrix)
126 | # a*[i] = action taken from state s[i] during this training sample
127 | # Q*(s[i],a*[i]) = the actual value observed from taking action a*[i] from state s[i]
128 | # outQ[i,-] = estimated values for all actions from state s[i]
129 | # Then we can write the inputs as
130 | # self.target[i] = Q*(s[i],a*[i])
131 | # self.actions[i] = a*[i]
132 |
133 | ops_dict["target"] = tf.placeholder(tf.float32, shape=[None], name="target_Q")
134 | ops_dict["actions"] = tf.placeholder(tf.int32, shape=[None], name="submitted_action")
135 |
136 | # Since the Qnet outputs a vector Q(s,-) of predicted values for every possible action that can be taken from state s,
137 | # we need to connect each target value with the appropriate predicted Q(s,a*) = Qout[i,a*[i]].
138 | # Main idea is to get indexes into the outQ tensor based on input actions and gather the resulting Q values
139 | # For some reason this isn't easy for tensorflow to do. So we must manually form the list of
140 | # [i, actions[i]] index pairs for outQ..
141 | # n_batch = outQ.shape[0] = actions.shape[0]
142 | # n_actions = outQ.shape[1]
143 | ind = tf.stack([tf.range(tf.shape(ops_dict["actions"])[0]),ops_dict["actions"]],axis=1)
144 | # and then "gather" them.
145 | estimatedQ = tf.gather_nd(ops_dict["outQ"], ind)
146 | # Special notes: this is more efficient than indexing into the flattened version of outQ (which I have seen before)
147 | # because the gather operation is applied to outQ directly. Apparently this propagates the gradient more efficiently
148 | # under specific sparsity conditions (which tf.Variables like outQ satisfy)
149 |
150 | # Simple sum-of-squares loss (error) function. Note that biases do not
151 | # need to be regularized since they are (generally) not subject to overfitting.
152 | ops_dict["loss"] = tf.reduce_mean(0.5*tf.square(ops_dict["target"]-estimatedQ), name="loss")
153 |
154 | ops_dict["trainer"] = tf.train.AdamOptimizer(learning_rate = ops_dict["learning_rate"])
155 | ops_dict["update"] = ops_dict["trainer"].minimize(ops_dict["loss"], name="update")
156 |
157 | return ops_dict
158 |
159 | def create_target_update_ops(self, target_scope, online_scope, tau=1e-3, name="target_update"):
160 | """
161 |         Adds operations to the graph which are used to update the target network after a training batch is sent
162 | through the online network.
163 |
164 | This function should be executed only once before training begins. The resulting operations should
165 | be run within a tf.Session() once per training batch.
166 |
167 |         In double-Q network learning, the online (primary) network is updated using traditional backpropagation techniques
168 | with target values produced by the target-Q network.
169 | To improve stability, the target-Q is updated using a linear combination of its current weights
170 | with the current weights of the online network:
171 | Q_target = tau*Q_online + (1-tau)*Q_target
172 | Typical tau values are small (tau ~ 1e-3). For more, see https://arxiv.org/abs/1509.06461 and https://arxiv.org/pdf/1509.02971.pdf.
173 | Args:
174 | target_scope (str): name of scope that target network occupies
175 | online_scope (str): name of scope that online network occupies
176 | tau (float32): Hyperparameter for combining target-Q and online-Q networks
177 | name (str): name of operation which updates the target network when run within a session
178 |         Returns: Tensorflow operation which updates the target network when run.
179 | """
180 | with self._graph.as_default():
181 | target_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=target_scope)
182 | online_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=online_scope)
183 | ops = [target_params[i].assign(tf.add(tf.multiply(tau,online_params[i]),tf.multiply(1.-tau,target_params[i]))) for i in range(len(target_params))]
184 | return tf.group(*ops,name=name)
185 |
186 | def create_target_initialization_ops(self, target_scope, online_scope):
187 | """
188 | This adds operations to the graph in order to initialize the target Q network to the same values as the
189 | online network.
190 |
191 | This function should be executed only once just after the online network has been initialized.
192 |
193 | Args:
194 | target_scope (str): name of scope that target network occupies
195 | online_scope (str): name of scope that online network occupies
196 | Returns:
197 |             Tensorflow operation (named "target_init") which initializes the target network when run.
198 | """
199 | return self.create_target_update_ops(target_scope, online_scope, tau=1.0, name="target_init")
200 |
--------------------------------------------------------------------------------
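A minimal numpy sketch of the soft target update implemented by create_target_update_ops() above: Q_target <- tau*Q_online + (1-tau)*Q_target, where tau = 1.0 reduces to a straight copy (which is how create_target_initialization_ops() uses it). The weight vectors below are toy values, not real network parameters:

    import numpy as np

    tau = 1.e-3
    online_weights = np.array([0.20, 0.50, -0.10])  # toy online-network parameters
    target_weights = np.array([0.00, 0.00,  0.00])  # toy target-network parameters

    # Soft update: the target network drifts slowly toward the online network.
    target_weights = tau * online_weights + (1.0 - tau) * target_weights
    print(target_weights)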
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "{}"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright {yyyy} {name of copyright owner}
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/src/data/database_ops.py:
--------------------------------------------------------------------------------
1 | import sqlite3
2 | import re
3 | from .champion_info import champion_id_from_name,champion_name_from_id, convert_champion_alias, AliasException
4 |
5 | regionsDict = {"NA_LCS":"NA", "EU_LCS":"EU", "LCK":"LCK", "LPL":"LPL",
6 | "LMS":"LMS", "International":"INTL", "NA_ACA": "NA_ACA", "KR_CHAL":"KR_CHAL", "LDL":"LDL"}
7 | internationalEventsDict = {"Mid-Season_Invitational":"MSI",
8 | "Rift_Rivals":"RR","World_Championship":"WRLDS"}
9 |
10 | def get_matches_by_id(match_ids, path):
11 | """
12 | Returns match data for each match_id in the list match_ids
13 | """
14 | conn = sqlite3.connect(path)
15 | cur = conn.cursor()
16 | match_data = []
17 | for match_id in match_ids:
18 | match = get_match_data(cur, match_id)
19 | match_data.append(match)
20 | conn.close()
21 | return match_data
22 |
23 | def get_game_ids_by_tournament(cursor, tournament, patch=None):
24 | """
25 |     get_game_ids_by_tournament queries the connected db for game ids which match the
26 | input tournament string.
27 |
28 | Args:
29 |         cursor (sqlite cursor): cursor used to execute commands
30 | tournament (string): id string for tournament (ie "2017/EU/Summer_Split")
31 | patch (string, optional): id string for patch to additionally filter
32 | Returns:
33 | gameIds (list(int)): list of gameIds
34 | """
35 | if patch:
36 | query = "SELECT id FROM game WHERE tournament=? AND patch=? ORDER BY id"
37 | params = (tournament, patch)
38 | else:
39 | query = "SELECT id FROM game WHERE tournament=? ORDER BY id"
40 | params = (tournament,)
41 | cursor.execute(query, params)
42 | response = cursor.fetchall()
43 | vals = []
44 | for r in response:
45 | vals.append(r[0])
46 | return vals
47 |
48 | def get_game_ids(cursor, tournament=None, patch=None):
49 | """
50 | get_game_ids queries the connected db for game ids which match the
51 | input tournament and patch strings.
52 |
53 | Args:
54 |         cursor (sqlite cursor): cursor used to execute commands
55 | tournament (string, optional): id string for tournament (ie "2017/EU/Summer_Split")
56 | patch (string, optional): id string for patch to filter for
57 | Returns:
58 | gameIds (list(int)): list of gameIds
59 | """
60 | if not patch and not tournament:
61 | return []
62 |
63 | params = ()
64 | where_clause = []
65 | if tournament:
66 | where_clause.append("tournament=?")
67 | params += (tournament,)
68 | if patch:
69 | where_clause.append("patch=?")
70 | params += (patch,)
71 |
72 | query = "SELECT id FROM game WHERE {where_clause} ORDER BY id".format(where_clause=" AND ".join(where_clause))
73 | cursor.execute(query, params)
74 | response = cursor.fetchall()
75 | vals = []
76 | for r in response:
77 | vals.append(r[0])
78 | return vals
79 |
80 | def get_match_data(cursor, gameId):
81 | """
82 | get_match_data queries the connected db for draft data and organizes it into a more convenient
83 | format.
84 |
85 | Args:
86 |         cursor (sqlite cursor): cursor used to execute commands
87 | gameId (int): primary key of game to process
88 | Returns:
89 | match (dict): formatted pick/ban phase data for game
90 | """
91 | match = {"id": gameId ,"winner": None, "blue":{}, "red":{}, "blue_team":None, "red_team":None, "header_id":None, "patch":None}
92 | # Get winning team
93 | query = "SELECT tournament, tourn_game_id, week, patch, winning_team FROM game WHERE id=?"
94 | params = (gameId,)
95 | cursor.execute(query, params)
96 | match["tournament"], match["tourn_game_id"], match["header_id"], match["patch"], match["winner"] = cursor.fetchone()#[0]
97 |
98 | # Get ban data
99 | query = "SELECT champion_id, selection_order FROM ban WHERE game_id=? and side_id=? ORDER BY selection_order"
100 | params = (gameId,0)
101 | cursor.execute(query, params)
102 | match["blue"]["bans"] = list(cursor.fetchall())
103 |
104 | query = "SELECT champion_id, selection_order FROM ban WHERE game_id=? and side_id=? ORDER BY selection_order"
105 | params = (gameId,1)
106 | cursor.execute(query, params)
107 | match["red"]["bans"] = list(cursor.fetchall())
108 |
109 | # Get pick data
110 | query = "SELECT champion_id, position_id, selection_order FROM pick WHERE game_id=? AND side_id=? ORDER BY selection_order"
111 | params = (gameId,0)
112 | cursor.execute(query, params)
113 | match["blue"]["picks"] = list(cursor.fetchall())
114 |
115 | query = "SELECT champion_id, position_id, selection_order FROM pick WHERE game_id=? AND side_id=? ORDER BY selection_order"
116 | params = (gameId,1)
117 | cursor.execute(query, params)
118 | match["red"]["picks"] = list(cursor.fetchall())
119 |
120 | query = "SELECT display_name FROM team JOIN game ON team.id = blue_teamid WHERE game.id = ?"
121 | params = (gameId,)
122 | cursor.execute(query, params)
123 | match["blue_team"] = cursor.fetchone()[0]
124 |
125 | query = "SELECT display_name FROM team JOIN game ON team.id = red_teamid WHERE game.id = ?"
126 | params = (gameId,)
127 | cursor.execute(query, params)
128 | match["red_team"] = cursor.fetchone()[0]
129 |
130 | return match
131 |
132 | def get_tournament_data(gameData):
133 | """
134 | get_tournament_data cleans up and combines the region/year/tournament fields in gameData for entry into
135 | the game table. When combined with the game_id field it uniquely identifies the match played.
136 | The format of tournamentData output is 'year/region_abbrv/tournament' (forward slash delimiters)
137 |
138 | Args:
139 |         gameData (dict): dictionary output from query_wiki()
140 | Returns:
141 | tournamentData (string): formatted and cleaned region/year/split data
142 | """
143 | tournamentData = "/".join([gameData["year"], regionsDict[gameData["region"]], gameData["tournament"]])
144 | return tournamentData
145 |
146 | def get_game_id(cursor,gameData):
147 | """
148 | get_game_id looks in the game table for an entry with matching tournament and tourn_game_id as the input
149 | gameData and returns the id field. If no such entry is found, it adds this game to the game table and returns the
150 | id field.
151 |
152 | Args:
153 |         cursor (sqlite cursor): cursor used to execute commands
154 | gameData (dict): dictionary output from query_wiki()
155 | Returns:
156 | gameId (int): Primary key in game table corresponding to this gameData
157 | """
158 | tournament = get_tournament_data(gameData)
159 | vals = (tournament,gameData["tourn_game_id"])
160 | gameId = None
161 | while gameId is None:
162 | cursor.execute("SELECT id FROM game WHERE tournament=? AND tourn_game_id=?", vals)
163 | gameId = cursor.fetchone()
164 | if gameId is None:
165 | print("Warning: Game not found. Attempting to add game.")
166 |             err = insert_game(cursor,[gameData])
167 | else:
168 | gameId = gameId[0]
169 | return gameId
170 |
171 | def delete_game_from_table(cursor, game_ids, table_name):
172 | """
173 | Deletes rows corresponding to game_id from table table_name.
174 | Args:
175 |         cursor (sqlite cursor): cursor used to execute commands
176 | game_ids (list(int)): game_ids to be removed from table
177 | table_name (string): name of table to remove rows from
178 | Returns:
179 | status (int): status = 1 if delete was successful, otherwise status = 0
180 | """
181 | status = 0
182 | assert isinstance(game_ids,list), "game_ids is not a list"
183 | for game_id in game_ids:
184 | query = "SELECT count(*) FROM {table_name} WHERE game_id=?".format(table_name=table_name)
185 | vals = (game_id,)
186 | cursor.execute(query, vals)
187 | print("Found {count} rows for game_id={game_id} to delete from table {table}".format(count=cursor.fetchone()[0], game_id=game_id, table=table_name))
188 |
189 | query = "DELETE FROM {table_name} WHERE game_id=?".format(table_name=table_name)
190 | cursor.execute(query, vals)
191 | status = 1
192 | return status
193 |
194 | def insert_game(cursor, gameData):
195 | """
196 | insert_game attempts to format collected gameData from query_wiki() and insert
197 | into the game table in the competitiveGameData.db.
198 |
199 | Args:
200 |         cursor (sqlite cursor): cursor used to execute commands
201 | gameData (list(dict)): list of dictionary output from query_wiki()
202 | Returns:
203 | status (int): status = 1 if insert was successful, otherwise status = 0
204 | """
205 | status = 0
206 | assert isinstance(gameData,list), "gameData is not a list"
207 | for game in gameData:
208 | tournGameId = game["tourn_game_id"] # Which game this is within current tournament
209 | tournamentData = get_tournament_data(game)
210 |
211 | # Check to see if game data is already in table
212 | vals = (tournamentData,tournGameId)
213 | cursor.execute("SELECT id FROM game WHERE tournament=? AND tourn_game_id=?", vals)
214 | result = cursor.fetchone()
215 | if result is not None:
216 | print("game {} already exists in table.. skipping".format(result[0]))
217 | else:
218 | # Get blue and red team_ids
219 | blueTeamId = None
220 | redTeamId = None
221 | while (blueTeamId is None or redTeamId is None):
222 | cursor.execute("SELECT id FROM team WHERE display_name=?",(game["blue_team"],))
223 | blueTeamId = cursor.fetchone()
224 | cursor.execute("SELECT id FROM team WHERE display_name=?",(game["red_team"],))
225 | redTeamId = cursor.fetchone()
226 | if (blueTeamId is None) or (redTeamId is None):
227 | print("*WARNING: When inserting game-- team not found. Attempting to add teams")
228 | err = insert_team(cursor, [game])
229 | else:
230 | blueTeamId = blueTeamId[0]
231 | redTeamId = redTeamId[0]
232 |
233 | winner = game["winning_team"]
234 | header_id = game["header_id"]
235 | patch = game["patch"]
236 | vals = (tournamentData, tournGameId, header_id, patch, blueTeamId, redTeamId, winner)
237 | cursor.execute("INSERT INTO game(tournament, tourn_game_id, week, patch, blue_teamid, red_teamid, winning_team) VALUES(?,?,?,?,?,?,?)", vals)
238 | status = 1
239 | return status
240 |
241 | def insert_team(cursor, gameData):
242 | """
243 | insert_team attempts to format collected gameData from query_wiki() and insert
244 | into the team table in the competitiveGameData.db.
245 |
246 | Args:
247 |         cursor (sqlite cursor): cursor used to execute commands
248 |         gameData (list(dict)): dictionary output from query_wiki()
249 | Returns:
250 | status (int): status = 1 if insert was successful, otherwise status = 0
251 | """
252 | status = 0
253 | assert isinstance(gameData,list), "gameData is not a list"
254 | for game in gameData:
255 | # We don't track all regions (i.e wildcard regions), but they can still appear at
256 | # international tournaments. When this happens we will track the team, but list their
257 | # region as NULL.
258 | if game["region"] is "Inernational":
259 | region = None
260 | else:
261 | region = regionsDict[game["region"]]
262 | teams = [game["blue_team"], game["red_team"]]
263 | for team in teams:
264 | vals = (region,team)
265 | # This only looks for matching display names.. what happens if there's a
266 | # NA TSM and an EU TSM?
267 | cursor.execute("SELECT * FROM team WHERE display_name=?", (team,))
268 | result = cursor.fetchone()
269 | if result is None:
270 | cursor.execute("INSERT INTO team(region, display_name) VALUES(?,?)", vals)
271 | status = 1
272 | return status
273 |
274 | def insert_ban(cursor, gameData):
275 | """
276 | insert_ban attempts to format collected gameData from query_wiki() and insert into the
277 | ban table in the competitiveGameData.db.
278 |
279 | Args:
280 | cursor (sqlite cursor): cursor used to execute commands
281 | gameData (list(dict)): dictionary output from query_wiki()
282 | Returns:
283 | status (int): status = 1 if insert was successful, otherwise status = 0
284 | """
285 | status = 0
286 | assert isinstance(gameData,list), "gameData is not a list"
287 | teams = ["blue", "red"]
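# Note (added annotation): the side_id stored below follows the order of this list, so side 0 corresponds to blue and side 1 to red.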
288 | for game in gameData:
289 | tournament = get_tournament_data(game)
290 | vals = (tournament,game["tourn_game_id"])
291 | gameId = get_game_id(cursor,game)
292 | # Check for existing entries in table. Skip if they already exist.
293 | cursor.execute("SELECT game_id FROM ban WHERE game_id=?",(gameId,))
294 | result = cursor.fetchone()
295 | if result is not None:
296 | print("Bans for game {} already exists in table.. skipping".format(result[0]))
297 | else:
298 | for k in range(len(teams)):
299 | bans = game["bans"][teams[k]]
300 | selectionOrder = 0
301 | side = k
302 | for ban in bans:
303 | if ban in ["lossofban","none"]:
304 | # Special case if no ban was submitted in game
305 | banId = None
306 | else:
307 | # print("ban={}".format(ban))
308 | banId = champion_id_from_name(ban)
309 | # If no such champion name is found, try looking for an alias
310 | if banId is None:
311 | banId = champion_id_from_name(convert_champion_alias(ban))
312 | selectionOrder += 1
313 | vals = (gameId,banId,selectionOrder,side)
314 | cursor.execute("INSERT INTO ban(game_id, champion_id, selection_order, side_id) VALUES(?,?,?,?)", vals)
315 | status = 1
316 | return status
317 |
318 | def insert_pick(cursor, gameData):
319 | """
320 | insert_pick formats collected gameData from query_wiki() and inserts it into the pick table of the
321 | competitiveGameData.db.
322 |
323 | Args:
324 | cursor (sqlite cursor): cursor used to execute commands
325 | gameData (list(dict)): list of formatted game data from query_wiki()
326 | Returns:
327 | status (int): status = 1 if insert was successful, otherwise status = 0
328 | """
329 | status = 0
330 | assert isinstance(gameData,list), "gameData is not a list"
331 | teams = ["blue", "red"]
332 | for game in gameData:
333 | tournament = get_tournament_data(game)
334 | vals = (tournament,game["tourn_game_id"])
335 | gameId = get_game_id(cursor,game)
336 | # Check for existing entries in table. Skip if they already exist.
337 | cursor.execute("SELECT game_id FROM pick WHERE game_id=?",(gameId,))
338 | result = cursor.fetchone()
339 | if result is not None:
340 | print("Picks for game {} already exists in table.. skipping".format(result[0]))
341 | else:
342 | for k in range(len(teams)):
343 | picks = game["picks"][teams[k]]
344 | selectionOrder = 0
345 | side = k
346 | for (pick,position) in picks:
347 | if pick in ["lossofpick","none"]:
348 | # Special case if no pick was submitted in game (not really sure what that would mean
349 | # but being consistent with insert_ban())
350 | pickId = None
351 | else:
352 | pickId = champion_id_from_name(pick)
353 | # If no such champion name is found, try looking for an alias
354 | if pickId is None:
355 | pickId = champion_id_from_name(convert_champion_alias(pick))
356 | selectionOrder += 1
357 | vals = (gameId,pickId,position,selectionOrder,side)
358 | cursor.execute("INSERT INTO pick(game_id, champion_id, position_id, selection_order, side_id) VALUES(?,?,?,?,?)", vals)
359 | status = 1
360 | return status
361 |
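# Illustrative end-to-end sketch (added annotation, not part of the module): a typical ingestion pass
# chains the helpers above. The database path and query_wiki arguments are assumptions for the example.
#   import sqlite3
#   from .query_wiki import query_wiki
#   conn = sqlite3.connect("../data/competitiveMatchData.db")
#   cursor = conn.cursor()
#   game_data = query_wiki("2017", "EU_LCS", "Summer_Season")
#   insert_team(cursor, game_data)
#   insert_game(cursor, game_data)
#   insert_ban(cursor, game_data)
#   insert_pick(cursor, game_data)
#   conn.commit()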
--------------------------------------------------------------------------------
/src/data/query_wiki.py:
--------------------------------------------------------------------------------
1 | import json # JSON tools
2 | import requests # URL api tools
3 | import re # regex tools
4 | from .champion_info import convert_champion_alias, champion_id_from_name
5 |
6 | def query_wiki(year, region, tournament):
7 | """
8 | query_wiki takes identifying sections and subsections for a page title on leaguepedia, then builds and executes a set of requests to the
9 | API looking for the pick/ban data corresponding to those sections and subsections. The response is then
10 | pruned and formatted into a list of dictionaries. The specified sections and subsections should combine into a unique identifying string
11 | for a specific tournament, and query_wiki() will return all games for that tournament.
12 |
13 | For example, if we are interested in the regular season of the 2017 European Summer Split we would call:
14 | query_wiki("2017", "EU_LCS", "Summer_Season")
15 |
16 | If we were interested in the 2016 World Championship we would pass:
17 | query_wiki("2016", "International", "WORLDS/Main_Event")
18 |
19 | Each dictionary corresponds to the pick/ban phase of an LCS game with the following keys (among others):
20 | "region":
21 | "year":
22 | "tournament":
23 | "bans": {"blue":, "red":}
24 | "blue_team":
25 | "blue_score":
26 | "red_team":
27 | "red_score":
28 | "tourn_game_id":
29 | "picks": {"blue":, "red":}
30 |
31 | Args:
32 | year (string): year of game data of interest
33 | region (string): region of play for games
34 | tournament (string): which tournament games were played in
35 | Returns:
36 | List of dictionaries containing formatted response data from lol.gamepedia api
37 | """
38 | # Common root for all requests
39 | url_root = "https://lol.gamepedia.com/api.php"
40 |
41 | # Semi-standardized page suffixes for pick/ban pages
42 | page_suffixes = ["", "/Bracket_Stage", "/3-4", "/5-6", "/5-8", "/4-6", "/4-7", "/7-9", "/7-10", "/8-10", "/7-8", "/9-11"]
43 |
44 | formatted_regions = {"NA_LCS":"League_Championship_Series/North_America",
45 | "NA_ACA":"NA_Academy_League",
46 | "EU_LCS":"League_Championship_Series/Europe",
47 | "LCK":"LCK",
48 | "KR_CHAL":"Challengers_Korea",
49 | "LPL":"LPL",
50 | "LMS":"LMS",
51 | "LDL":"LDL"}
52 |
53 | formatted_international_tournaments = {
54 | "WORLDS/Play-In": "Season_World_Championship/Play-In",
55 | "WORLDS/Main_Event": "Season_World_Championship/Main_Event",
56 | "MSI/Play-In": "Mid-Season_Invitational/Play-In",
57 | "MSI/Main_Event": "Mid-Season_Invitational/Main_Event",
58 | "WORLDS_QUALS/NA": "Season_North_America_Regional_Finals",
59 | "WORLDS_QUALS/EU": "Season_Europe_Regional_Finals",
60 | "WORLDS_QUALS/LCK": "Season_Korea_Regional_Finals",
61 | "WORLDS_QUALS/LPL": "Season_China_Regional_Finals",
62 | "WORLDS_QUALS/LMS": "Season_Taiwan_Regional_Finals",
63 | }
64 |
65 | with open('../data/patch_info.json','r') as infile:
66 | patch_data = json.load(infile)
67 | patches = patch_data["patch_info"][year][region][tournament]
68 | print(patches)
69 |
70 | # Build list of titles of pages to query
71 | if region == "International":
72 | title_root = ["_".join([year,formatted_international_tournaments[tournament]])]
73 | else:
74 | formatted_region = formatted_regions[region]
75 | formatted_year = "_".join([year,"Season"])
76 | title_root = [formatted_region, formatted_year, tournament]
77 | title_root.append("Picks_and_Bans")
78 | title_root = "/".join(title_root)
79 |
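# For example (added annotation, following the EU LCS call shown in the docstring above), the joined
# title_root here would be "League_Championship_Series/Europe/2017_Season/Summer_Season/Picks_and_Bans".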
80 | title_list = []
81 | for suffix in page_suffixes:
82 | title_list.append(title_root+suffix)
83 | formatted_title_list = "|".join(title_list) # Parameter string to pass to API
84 | params = {"action": "query", "titles": formatted_title_list,
85 | "prop":"revisions", "rvprop":"content", "format": "json"}
86 |
87 | response = requests.get(url=url_root, params=params)
88 | print(response.url)
89 | data = json.loads(response.text)
90 | page_data = data['query']['pages']
91 | # Get list of page keys (actually a list of pageIds.. could be used to identify pages?)
92 | page_keys = list(sorted(page_data.keys()))
93 | page_keys = [k for k in page_keys if int(k)>=0] # Filter out "invalid page" and "missing page" responses
94 | formatted_data = []
95 | tournGameId = 0
96 |
97 | for page in page_keys:
98 | # Get the raw text of the most recent revision of the current page
99 | # Note that we remove all space characters from the raw text, including those
100 | # in team or champion names.
101 | raw_text = page_data[page]["revisions"][0]["*"].replace(" ","").replace("\\n"," ")
102 | print(page_data[page]["title"])
103 |
104 | # week_labels = parse_raw_text("(name=Week[0-9]+)", raw_text)
105 | # week_numbers = [int(i.replace("week","")) for i in week_labels]
106 | # week_data = re.split("(name=Week[0-9]+)", raw_text)[2::2]
107 |
108 | # Get section headers
109 | headers = parse_raw_text("(name=[\w0-9]+)", raw_text)
110 | # Look for indexed headers first. If found, use this index for patch data. If no such index is found, use the first element of patch data
111 | search = [re.search("[0-9]+", header) for header in headers]
112 | header_indices = [int(s.group()) if s else 0 for s in search]
113 |
114 | # If there's only one patch for this tournament, make sure all header indices point to it.
115 | if len(patches) == 1:
116 | header_indices = [0 for header in header_indices]
117 |
118 | # This is an edge case for when some header indices can be read but there are also non-indexable headers
119 | # A good example is when there is a "Tiebreaker" section at the end of an indexed regular split.
120 | # This bit sets the patch data index for such sections to the same one as the most recent non-zero index
121 | # Ex: [1,2,3,4,0] -> [1,2,3,4,4]
122 | if not all(value==0 for value in header_indices):
123 | last_val = 0
124 | for i,val in enumerate(header_indices):
125 | if val!=0:
126 | last_val = val
127 | else:
128 | header_indices[i] = last_val
129 |
130 | section_data = re.split("(name=[\w0-9]+)", raw_text)[2::2]
131 |
132 | num_games_on_page = 0
133 | for i in range(len(section_data)):
134 | data = section_data[i]
135 |
136 | # winning_teams holds which team won for each parsed game
137 | # winner = 1 -> first team won (i.e blue team)
138 | # winner = 2 -> second team won (i.e red team)
139 | winning_teams = parse_raw_text("(winner=[0-9])", data)
140 | winning_teams = [int(i)-1 for i in winning_teams] # Convert string response to int
141 | num_games_in_week = len(winning_teams)
142 |
143 | if(num_games_in_week == 0):
144 | continue
145 | else:
146 | num_games_on_page += num_games_in_week
147 |
148 | # string representation of blue and red teams, ordered by game
149 | blue_teams = parse_raw_text("(team1=[\w\s]+)", data)
150 | red_teams = parse_raw_text("(team2=[\w\s]+)", data)
151 |
152 | blue_scores = parse_raw_text("(team1score=[0-9])", data)
153 | red_scores = parse_raw_text("(team2score=[0-9])", data)
154 |
155 | # bans holds the string identifiers of submitted bans for each team in the parsed game
156 | # ex: bans[k] = list of bans for kth game on the page
157 | all_blue_bans = parse_raw_text("(blueban[0-9]=\w[\w\s',.]+)", data)
158 | all_red_bans = parse_raw_text("(red_ban[0-9]=\w[\w\s',.]+)", data)
159 | assert len(all_blue_bans)==len(all_red_bans), "blue bans: {}, red bans: {}".format(len(all_blue_bans),len(all_red_bans))
160 | bans_per_team = len(all_blue_bans)//num_games_in_week
161 |
162 | # blue_picks[k] = list of picks for the kth game on the page
163 | all_blue_picks = parse_raw_text("(bluepick[0-9]=\w[\w\s',.]+)", data)
164 | all_blue_roles = parse_raw_text("(bluerole[0-9]=\w[\w\s',.]?)", data)
165 | all_red_picks = parse_raw_text("(red_pick[0-9]=\w[\w\s',.]+)", data)
166 | all_red_roles = parse_raw_text("(red_role[0-9]=\w[\w\s',.]?)", data)
167 | #print(data)
168 | assert len(all_blue_picks)==len(all_red_picks), "blue picks: {}, red picks: {}".format(len(all_blue_picks),len(all_red_picks))
169 | assert len(all_blue_roles)==len(all_red_roles), "blue roles: {}, red roles: {}".format(len(all_blue_roles),len(all_red_roles))
170 | picks_per_team = len(all_blue_picks)//num_games_in_week
171 |
172 | # Clean fields involving champion names, looking for aliases if necessary
173 | all_blue_bans = clean_champion_names(all_blue_bans)
174 | all_red_bans = clean_champion_names(all_red_bans)
175 | all_blue_picks = clean_champion_names(all_blue_picks)
176 | all_red_picks = clean_champion_names(all_red_picks)
177 |
178 | # Format data by match
179 | blue_bans = []
180 | red_bans = []
181 | for k in range(num_games_in_week):
182 | blue_bans.append(all_blue_bans[bans_per_team*k:bans_per_team*(k+1)])
183 | red_bans.append(all_red_bans[bans_per_team*k:bans_per_team*(k+1)])
184 |
185 | # submissions holds the identifiers of submitted (pick, position) pairs for each team in the parsed game
186 | # string representation for the positions are converted to ints to match DraftState expectations
187 | blue_picks = []
188 | red_picks = []
189 | for k in range(num_games_in_week):
190 | picks = all_blue_picks[picks_per_team*k:picks_per_team*(k+1)]
191 | positions = position_string_to_id(all_blue_roles[picks_per_team*k:picks_per_team*(k+1)])
192 | blue_picks.append(list(zip(picks,positions)))
193 |
194 | picks = all_red_picks[picks_per_team*k:picks_per_team*(k+1)]
195 | positions = position_string_to_id(all_red_roles[picks_per_team*k:picks_per_team*(k+1)])
196 | red_picks.append(list(zip(picks,positions)))
197 |
198 | total_blue_bans = sum([len(bans) for bans in blue_bans])
199 | total_red_bans = sum([len(bans) for bans in red_bans])
200 | total_blue_picks = sum([len(picks) for picks in blue_picks])
201 | total_red_picks = sum([len(picks) for picks in red_picks])
202 | print("Total number of games found: {}".format(num_games_in_week))
203 | print("There should be {} bans. We found {} blue bans and {} red bans".format(num_games_in_week*5,total_blue_bans,total_red_bans))
204 | print("There should be {} picks. We found {} blue picks and {} red picks".format(num_games_in_week*5,total_blue_picks,total_red_picks))
205 | assert total_red_bans==total_blue_bans, "Bans don't match!"
206 | assert total_red_picks==total_blue_picks, "Picks don't match!"
207 | if(num_games_in_week > 0): # At least one game found on current page
208 | for k in range(num_games_in_week):
209 | print("Header_id {}, Game {}: {} vs {}".format(header_indices[i],k+1,blue_teams[k],red_teams[k]))
210 |
211 | tournGameId += 1
212 | bans = {"blue": blue_bans[k], "red":red_bans[k]}
213 | picks = {"blue": blue_picks[k], "red":red_picks[k]}
214 | blue = {"bans": blue_bans[k], "picks":blue_picks[k]}
215 | red = {"bans": red_bans[k], "picks":red_picks[k]}
216 | gameData = {"region": region, "year":year, "tournament": tournament,
217 | "blue_team": blue_teams[k], "red_team": red_teams[k],
218 | "winning_team": winning_teams[k],
219 | "blue_score":blue_scores[k], "red_score":red_scores[k],
220 | "bans": bans, "picks": picks, "blue":blue, "red":red,
221 | "tourn_game_id": tournGameId, "header_id":header_indices[i],
222 | "patch":patches[header_indices[i]-1]}
223 | formatted_data.append(gameData)
224 | return formatted_data
225 | def position_string_to_id(positions):
226 | """
227 | position_string_to_id converts input position strings to their integer representations defined by:
228 | Position 1 = Primary farm (ADC)
229 | Position 2 = Secondary farm (Mid)
230 | Position 3 = Tertiary farm (Top)
231 | Position 4 = Farming support (Jungle)
232 | Position 5 = Primary support (Support)
233 | Note that because of variable standardization of the string representations for each position
234 | (i.e "jg"="jng"="jungle"), this function only looks at the first character of each string when
235 | assigning integer positions since this seems to be more or less standard.
236 |
237 | Args:
238 | positions (list(string))
239 | Returns:
240 | list(int)
241 | """
242 |
243 | d = {"a":1, "m":2, "t":3, "j":4, "s":5} # This is lazy and I know it
244 | out = []
245 | for position in positions:
246 | char = position[0] # Look at first character for position information
247 | out.append(d[char])
248 | return out
249 |
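# Worked example of the first-character mapping above (added annotation, illustrative input strings):
#   position_string_to_id(["top", "jungle", "mid", "adc", "support"]) -> [3, 4, 2, 1, 5]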
250 | def parse_raw_text(regex, rawText):
251 | """
252 | parse_raw_text is a helper function which outputs a list of matching expressions defined
253 | by the regex input. Note that this function assumes that each regex yields matches of the form
254 | "A=B" which is passed to split_id_strings() for fomatting.
255 |
256 | Args:
257 | regex: desired regex to match with
258 | rawText: raw input string to find matches in
259 | Returns:
260 | List of formatted strings containing the matched data.
261 | """
262 | # Parse raw text responses for data. Note that a regular expression of the form
263 | # "(match)" will produce result = [stuff_before_match, match, stuff_after_match]
264 | # this means that the list of desired matches will be result[1::2]
265 | out = re.split(regex, rawText)
266 | out = split_id_strings(out[1::2]) # Format matching strings
267 | return out
268 |
269 | def split_id_strings(rawStrings):
270 | """
271 | split_id_strings takes a list of strings each of the form "A=B" and splits them
272 | along the "=" delimiting character. Returns the list formed by each of the "B"
273 | components of the raw input strings. For standardization purposes, the "B" string
274 | has the following done to it:
275 | 1. replace uppercase characters with lowercase
276 | 2. remove special characters (i.e non-alphanumeric)
277 |
278 | Args:
279 | rawStrings (list of strings): list of strings, each of the form "A=B"
280 | Returns:
281 | out (list of strings): list of strings formed by the "B" portion of each of the raw input strings
282 | """
283 | out = []
284 | for string in rawStrings:
285 | rightHandString = string.split("=")[1].lower() # Grab "B" part of string, make lowercase
286 | out.append(re.sub("[^A-Za-z0-9,]+", "", rightHandString)) # Remove special chars
287 | return out
288 |
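# Worked example (added annotation, hypothetical raw strings): split_id_strings(["team1=Team SoloMid", "winner=1"])
# returns ["teamsolomid", "1"] -- the right-hand side lowercased with special characters stripped.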
289 | def convert_lcs_positions(index):
290 | """
291 | Given the index of a pick in LCS order, returns the position id corresponding
292 | to that index.
293 |
294 | LCS picks are submitted in the following order
295 | Index | Role | Position
296 | 0 Top 3
297 | 1 Jng 4
298 | 2 Mid 2
299 | 3 Adc 1
300 | 4 Sup 5
301 | """
302 | lcsOrderToPos = {i:j for i,j in enumerate([3,4,2,1,5])}
303 | return lcsOrderToPos[index]
304 |
305 | def create_position_dict(picks_in_lcs_order):
306 | """
307 | Given a list of champions selected in lcs order (ie top,jungle,mid,adc,support)
308 | returns a dictionary which matches pick -> position.
309 | Args:
310 | picks_in_lcs_order (list(string)): list of string identifiers of picks. Assumed to be in LCS order
311 | Returns:
312 | dict (dictionary): dictionary with champion names as keys and the position each champion was played in as values.
313 | """
314 | d = {}
315 | cleaned_names = clean_champion_names(picks_in_lcs_order)
316 | for k in range(len(picks_in_lcs_order)):
317 | pos = convert_lcs_positions(k)
318 | d.update({cleaned_names[k]:pos})
319 | return d
320 |
321 |
322 | def clean_champion_names(names):
323 | """
324 | Takes a list of champion names and standardizes them by looking for possible aliases
325 | if necessary.
326 | Args:
327 | names (list(string)): list of champion names to be standardized
328 | Returns:
329 | cleanedNames (list(string)): list of standardized champion names
330 | """
331 | cleanedNames = []
332 | for name in names:
333 | if champion_id_from_name(name) is None:
334 | name = convert_champion_alias(name)
335 | cleanedNames.append(name)
336 | return cleanedNames
337 |
--------------------------------------------------------------------------------
/src/trainer.py:
--------------------------------------------------------------------------------
1 | import time
2 | import random
3 | from copy import deepcopy
4 |
5 | import tensorflow as tf
6 | import pandas as pd
7 | import numpy as np
8 |
9 | import data.match_pool as pool
10 |
11 | from features.draftstate import DraftState
12 | import features.experience_replay as er
13 | import features.match_processing as mp
14 | from features.rewards import get_reward
15 |
16 | class BaseTrainer():
17 | pass
18 |
19 | class DDQNTrainer(BaseTrainer):
20 | """
21 | Trainer class for Double DQN networks.
22 | Args:
23 | q_network (qNetwork): Q-network containing "online" and "target" networks
24 | n_epoch (int): number of times to iterate through given data
25 | training_data (list(match)): list of matches to be trained on
26 | validation_data (list(match)): list of matches to validate model against
27 | batch_size (int): size of each training set sampled from the replay buffer which will be used to update Qnet at a time
28 | buffer_size (int): size of replay buffer used
29 | load_path (string): path to reload existing model
30 | """
31 | def __init__(self, q_network, n_epoch, training_data, validation_data, batch_size, buffer_size, load_path=None):
32 | num_episodes = len(training_data)
33 | print("***")
34 | print("Beginning training..")
35 | print(" train_epochs: {}".format(n_epoch))
36 | print(" num_episodes: {}".format(num_episodes))
37 | print(" batch_size: {}".format(batch_size))
38 | print(" buffer_size: {}".format(buffer_size))
39 | print("***")
40 |
41 | self.ddq_net = q_network
42 | self.n_epoch = n_epoch
43 | self.training_data = training_data
44 | self.validation_data = validation_data
45 | self.batch_size = batch_size
46 | self.buffer_size = buffer_size
47 | self.load_path = load_path
48 |
49 | self.replay = er.ExperienceBuffer(self.buffer_size)
50 | self.step_count = 0
51 | self.epoch_count = 0
52 |
53 | self.dampen_states = False
54 | self.teams = [DraftState.BLUE_TEAM, DraftState.RED_TEAM]
55 |
56 | self.N_TEMP_TRAIN_MATCHES = 25
57 | self.TEMP_TRAIN_PATCHES = ["8.13","8.14","8.15"]
58 |
59 | def train(self):
60 | """
61 | Core training loop over epochs
62 | """
63 | self.target_update_frequency = 10000 # How often to update target network
64 |
65 | stash_model = True # Flag for stashing a copy of the model
66 | model_stash_interval = 10 # Stashes a copy of the model this often
67 |
68 | # Number of steps to take before training. Allows buffer to partially fill.
69 | # Must be at least batch_size to avoid error when sampling from experience replay
70 | self.pre_training_steps = 10*self.batch_size
71 | assert(self.pre_training_steps <= self.buffer_size), "Replay not large enough for pre-training!"
72 | assert(self.pre_training_steps >= self.batch_size), "Buffer not allowed to fill enough before sampling!"
73 | # Number of steps to force learner to observe submitted actions, rather than submit its own actions
74 | self.observations = 2000
75 |
76 | self.epsilon = 0.5 # Initial probability of letting the learner submit its own action
77 | self.eps_decay_rate = 1./(25*20*len(self.training_data)) # Rate at which epsilon decays per submission
78 |
79 | lr_decay_freq = 10 # Decay learning rate after a set number of epochs
80 | min_learning_rate = 1.e-8 # Minimum learning rate allowed to decay to
81 |
82 | summaries = {}
83 | summaries["loss"] = []
84 | summaries["train_acc"] = []
85 | summaries["val_acc"] = []
86 | # Load existing model
87 | self.ddq_net.sess.run(self.ddq_net.online_ops["init"])
88 | if(self.load_path):
89 | self.ddq_net.load(self.load_path)
90 | print("\nCheckpoint loaded from {}".format(self.load_path))
91 |
92 | # Initialize target network
93 | self.ddq_net.sess.run(self.ddq_net.target_ops["target_init"])
94 |
95 | for self.epoch_count in range(self.n_epoch):
96 | t0 = time.time()
97 | learning_rate = self.ddq_net.online_ops["learning_rate"].eval(self.ddq_net.sess)
98 | if((self.epoch_count>0) and (self.epoch_count % lr_decay_freq == 0) and (learning_rate>= min_learning_rate)):
99 | # Decay learning rate according to schedule
100 | learning_rate = 0.5*learning_rate
101 | self.ddq_net.sess.run(self.ddq_net.online_ops["learning_rate"].assign(learning_rate))
102 |
103 | # Run single epoch of training
104 | loss, train_acc, val_acc = self.train_epoch()
105 | dt = time.time()-t0
106 |
107 | print(" Finished epoch {:2}/{}: lr: {:.4e}, dt {:.2f}, loss {:.6f}, train {:.6f}, val {:.6f}".format(self.epoch_count+1, self.n_epoch, learning_rate, dt, loss, train_acc, val_acc), flush=True)
108 | summaries["loss"].append(loss)
109 | summaries["train_acc"].append(train_acc)
110 | summaries["val_acc"].append(val_acc)
111 |
112 | if(stash_model):
113 | if(self.epoch_count>0 and (self.epoch_count+1)%model_stash_interval==0):
114 | # Stash a copy of the current model
115 | out_path = "tmp/models/{}_model_E{}.ckpt".format(self.ddq_net._name, self.epoch_count+1)
116 | self.ddq_net.save(path=out_path)
117 | print("Stashed a copy of the current model in {}".format(out_path))
118 |
119 | self.ddq_net.save(path=self.ddq_net._path_to_model)
120 | return summaries
121 |
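# Minimal usage sketch (added annotation, illustrative only): `ddq_net`, `train_matches`, and `val_matches`
# are assumed to come from the qNetwork wrapper and match_pool, and the hyperparameter values are examples.
#   trainer = DDQNTrainer(ddq_net, n_epoch=20, training_data=train_matches,
#                         validation_data=val_matches, batch_size=32, buffer_size=4096)
#   summaries = trainer.train()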
122 | def train_epoch(self):
123 | """
124 | Training loop for a single epoch
125 | """
126 | # We can't validate a winner for submissions generated by the learner,
127 | # so we will use a winner-less match when getting rewards for such states
128 | blank_match = {"winner":None}
129 |
130 | learner_submitted_actions = 0
131 | null_actions = 0
132 |
133 | # Shuffle match presentation order
134 | if(self.N_TEMP_TRAIN_MATCHES):
135 | path_to_db = "../data/competitiveMatchData.db"
136 | sources = {"patches":self.TEMP_TRAIN_PATCHES, "tournaments":[]}
137 | print("Adding {} matches to training pool from {}.".format(self.N_TEMP_TRAIN_MATCHES, path_to_db))
138 | temp_matches = pool.match_pool(self.N_TEMP_TRAIN_MATCHES, path_to_db, randomize=True, match_sources=sources)["matches"]
139 | else:
140 | temp_matches = []
141 | data = self.training_data + temp_matches
142 |
143 | shuffled_matches = random.sample(data, len(data))
144 | for match in shuffled_matches:
145 | for team in self.teams:
146 | # Process match into individual experiences
147 | experiences = mp.process_match(match, team)
148 | for pick_id, experience in enumerate(experiences):
149 | # Some experiences include NULL submissions (usually missing bans)
150 | # The learner isn't allowed to submit NULL picks so skip adding these
151 | # to the buffer.
152 | state,actual,_,_ = experience
153 | (cid,pos) = actual
154 | if cid is None:
155 | null_actions += 1
156 | continue
157 | # Store original experience
158 | self.replay.store([experience])
159 | self.step_count += 1
160 |
161 | # Give model feedback on current estimations
162 | if(self.step_count > self.observations):
163 | # Let the network predict the next action
164 | feed_dict = {self.ddq_net.online_ops["input"]:[state.format_state()],
165 | self.ddq_net.online_ops["valid_actions"]:[state.get_valid_actions()]}
166 | q_vals = self.ddq_net.sess.run(self.ddq_net.online_ops["valid_outQ"], feed_dict=feed_dict)
167 | sorted_actions = q_vals[0,:].argsort()[::-1]
168 | top_actions = sorted_actions[0:4]
169 |
170 | if(random.random() < self.epsilon):
171 | pred_act = random.sample(list(top_actions), 1)
172 | else:
173 | # Use model's top prediction
174 | pred_act = [sorted_actions[0]]
175 |
176 | for action in pred_act:
177 | (cid,pos) = state.format_action(action)
178 | if((cid,pos)!=actual):
179 | pred_state = deepcopy(state)
180 | pred_state.update(cid,pos)
181 | r = get_reward(pred_state, blank_match, (cid,pos), actual)
182 | new_experience = (state, (cid,pos), r, pred_state)
183 |
184 | self.replay.store([new_experience])
185 | learner_submitted_actions += 1
186 |
187 | if(self.epsilon > 0.1):
188 | # Reduce epsilon over time
189 | self.epsilon -= self.eps_decay_rate
190 |
191 | # Use minibatch sample to update online network
192 | if(self.step_count > self.pre_training_steps):
193 | self.train_step()
194 |
195 | if(self.step_count % self.target_update_frequency == 0):
196 | # After the online network has been updated, update target network
197 | _ = self.ddq_net.sess.run(self.ddq_net.target_ops["target_update"])
198 |
199 | # Get training loss, training_acc, and val_acc to return
200 | loss, train_acc = self.validate_model(self.training_data)
201 | _, val_acc = self.validate_model(self.validation_data)
202 | return (loss, train_acc, val_acc)
203 |
204 | def train_step(self):
205 | """
206 | Training logic for a single mini-batch update sampled from replay
207 | """
208 | # Sample training batch from replay
209 | training_batch = self.replay.sample(self.batch_size)
210 |
211 | # Calculate target Q values for each example:
212 | # For non-terminal states, targetQ is estimated according to
213 | # targetQ = r + gamma*Q'(s', argmax_a Q(s',a))
214 | # where Q' denotes the target network.
215 | # For terminating states the target is computed as
216 | # targetQ = r
217 | updates = []
218 | for exp in training_batch:
219 | start,_,reward,end = exp
220 | if(self.dampen_states):
221 | # To dampen states (usually done after major patches or when the meta shifts)
222 | # we replace winning rewards with 0.
223 | reward = 0.
224 | state_code = end.evaluate()
225 | if(state_code==DraftState.DRAFT_COMPLETE or state_code in DraftState.invalid_states):
226 | # Action moves to terminal state
227 | updates.append(reward)
228 | else:
229 | # Following the double DQN paper (https://arxiv.org/abs/1509.06461).
230 | # Action is chosen by online network, but the target network is used to evaluate this policy.
231 | # Each row in predicted_Q gives estimated Q(s',a) values for all possible actions for the input state s'.
232 | feed_dict = {self.ddq_net.online_ops["input"]:[end.format_state()],
233 | self.ddq_net.online_ops["valid_actions"]:[end.get_valid_actions()]}
234 | predicted_action = self.ddq_net.sess.run(self.ddq_net.online_ops["prediction"], feed_dict=feed_dict)[0]
235 |
236 | feed_dict = {self.ddq_net.target_ops["input"]:[end.format_state()]}
237 | predicted_Q = self.ddq_net.sess.run(self.ddq_net.target_ops["outQ"], feed_dict=feed_dict)
238 |
239 | updates.append(reward + self.ddq_net.discount_factor*predicted_Q[0,predicted_action])
240 |
241 | # Update online net using target Q
242 | # Experience replay stores action = (champion_id, position) pairs
243 | # these need to be converted into the corresponding index of the input vector to the Qnet
244 | actions = np.array([exp[0].get_action(*exp[1]) for exp in training_batch])
245 | targetQ = np.array(updates)
246 | feed_dict = {self.ddq_net.online_ops["input"]:np.stack([exp[0].format_state() for exp in training_batch],axis=0),
247 | self.ddq_net.online_ops["actions"]:actions,
248 | self.ddq_net.online_ops["target"]:targetQ,
249 | self.ddq_net.online_ops["dropout_keep_prob"]:0.5}
250 | _ = self.ddq_net.sess.run(self.ddq_net.online_ops["update"],feed_dict=feed_dict)
251 |
252 | def validate_model(self, data):
253 | """
254 | Validates given model by computing loss and absolute accuracy for data using current Qnet.
255 | Args:
256 | data (list(dict)): list of matches to validate against
257 | Returns:
258 | stats (tuple(float)): list of statistical measures of performance. stats = (loss,acc)
259 | """
260 | buf = []
261 | for match in data:
262 | # Loss is only computed for winning side of drafts
263 | team = DraftState.RED_TEAM if match["winner"]==1 else DraftState.BLUE_TEAM
264 | # Process match into individual experiences
265 | experiences = mp.process_match(match, team)
266 | for exp in experiences:
267 | _,act,_,_ = exp
268 | (cid,pos) = act
269 | if cid is None:
270 | # Skip null actions such as missing/skipped bans
271 | continue
272 | buf.append(exp)
273 |
274 | n_exp = len(buf)
275 | targets = []
276 | for exp in buf:
277 | start,_,reward,end = exp
278 | state_code = end.evaluate()
279 | if(state_code==DraftState.DRAFT_COMPLETE or state_code in DraftState.invalid_states):
280 | # Action moves to terminal state
281 | targets.append(reward)
282 | else:
283 | feed_dict = {self.ddq_net.online_ops["input"]:[end.format_state()],
284 | self.ddq_net.online_ops["valid_actions"]:[end.get_valid_actions()]}
285 | predicted_action = self.ddq_net.sess.run(self.ddq_net.online_ops["prediction"], feed_dict=feed_dict)[0]
286 |
287 | feed_dict = {self.ddq_net.target_ops["input"]:[end.format_state()]}
288 | predicted_Q = self.ddq_net.sess.run(self.ddq_net.target_ops["outQ"], feed_dict=feed_dict)
289 |
290 | targets.append(reward + self.ddq_net.discount_factor*predicted_Q[0,predicted_action])
291 |
292 | actions = np.array([exp[0].get_action(*exp[1]) for exp in buf])
293 | targets = np.array(targets)
294 |
295 | feed_dict = {self.ddq_net.online_ops["input"]:np.stack([exp[0].format_state() for exp in buf],axis=0),
296 | self.ddq_net.online_ops["actions"]:actions,
297 | self.ddq_net.online_ops["target"]:targets,
298 | self.ddq_net.online_ops["valid_actions"]:np.stack([exp[0].get_valid_actions() for exp in buf],axis=0)}
299 |
300 | loss, pred_q = self.ddq_net.sess.run([self.ddq_net.online_ops["loss"], self.ddq_net.online_ops["valid_outQ"]],feed_dict=feed_dict)
301 |
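# Added annotation: accuracy below is a top-k measure -- a submitted action counts as accurate when its id
# is among the rank_tolerance highest-Q actions predicted for that state.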
302 | accurate_predictions = 0
303 | rank_tolerance = 5
304 | for n in range(n_exp):
305 | state,act,_,_ = buf[n]
306 | submitted_action_id = state.get_action(*act)
307 |
308 | data = [(a,pred_q[n,a]) for a in range(pred_q.shape[1])]
309 | df = pd.DataFrame(data, columns=['act_id','Q'])
310 | df.sort_values('Q',ascending=False,inplace=True)
311 | df.reset_index(drop=True,inplace=True)
312 | df['rank'] = df.index
313 | submitted_row = df[df['act_id']==submitted_action_id]
314 | rank = submitted_row['rank'].iloc[0]
315 | if rank < rank_tolerance:
316 | accurate_predictions += 1
317 |
318 | accuracy = accurate_predictions/n_exp
319 | return (loss, accuracy)
320 |
321 | class SoftmaxTrainer(BaseTrainer):
322 | def __init__(self, network, n_epoch, training_data, validation_data, batch_size, load_path=None):
323 | num_episodes = len(training_data)
324 | print("***")
325 | print("Beginning training..")
326 | print(" train_epochs: {}".format(n_epoch))
327 | print(" num_episodes: {}".format(num_episodes))
328 | print(" batch_size: {}".format(batch_size))
329 | print("***")
330 |
331 | self.model = network
332 | self.n_epoch = n_epoch
333 | self.training_data = training_data
334 | self.validation_data = validation_data
335 | self.batch_size = batch_size
336 | self.load_path = load_path
337 |
338 | self.step_count = 0
339 | self.epoch_count = 0
340 |
341 | self.teams = [DraftState.BLUE_TEAM, DraftState.RED_TEAM]
342 |
343 | self._buffer = er.ExperienceBuffer(max_buffer_size=20*len(training_data))
344 | self._val_buffer = er.ExperienceBuffer(max_buffer_size=20*len(validation_data))
345 |
346 | self.fill_buffer(training_data, self._buffer)
347 | self.fill_buffer(validation_data, self._val_buffer)
348 |
349 | def fill_buffer(self, data, buf):
350 | for match in data:
351 | for team in self.teams:
352 | experiences = mp.process_match(match, team)
353 | # remove null actions (usually missing bans)
354 | for exp in experiences:
355 | _,act,_,_ = exp
356 | cid,pos = act
357 | if(cid):
358 | buf.store([exp])
359 |
360 | def sample_buffer(self, buf, n_samples):
361 | experiences = buf.sample(n_samples)
362 | states = []
363 | actions = []
364 | valid_actions = []
365 | for (state, action, _, _) in experiences:
366 | states.append(state.format_state())
367 | valid_actions.append(state.get_valid_actions())
368 | actions.append(state.get_action(*action))
369 |
370 | return (states, actions, valid_actions)
371 |
372 | def train(self):
373 | summaries = {}
374 | summaries["loss"] = []
375 | summaries["train_acc"] = []
376 | summaries["val_acc"] = []
377 |
378 | lr_decay_freq = 10
379 | min_learning_rate = 1.e-8 # Minimum learning rate allowed to decay to
380 |
381 | stash_model = True # Flag for stashing a copy of the model
382 | model_stash_interval = 10 # Stashes a copy of the model this often
383 |
384 | # Load existing model
385 | self.model.sess.run(self.model.ops_dict["init"])
386 | if(self.load_path):
387 | self.model.load(self.load_path)
388 | print("\nCheckpoint loaded from {}".format(self.load_path))
389 |
390 | for self.epoch_count in range(self.n_epoch):
391 | learning_rate = self.model.ops_dict["learning_rate"].eval(self.model.sess)
392 | if((self.epoch_count>0) and (self.epoch_count % lr_decay_freq == 0) and (learning_rate>= min_learning_rate)):
393 | # Decay learning rate according to schedule
394 | learning_rate = 0.5*learning_rate
395 | self.model.sess.run(self.model.ops_dict["learning_rate"].assign(learning_rate))
396 |
397 | t0 = time.time()
398 | loss, train_acc, val_acc = self.train_epoch()
399 | dt = time.time()-t0
400 | print(" Finished epoch {:2}/{}: lr: {:.4e}, dt {:.2f}, loss {:.6f}, train {:.6f}, val {:.6f}".format(self.epoch_count+1, self.n_epoch, learning_rate, dt, loss, train_acc, val_acc), flush=True)
401 | summaries["loss"].append(loss)
402 | summaries["train_acc"].append(train_acc)
403 | summaries["val_acc"].append(val_acc)
404 |
405 | if(stash_model):
406 | if(self.epoch_count>0 and (self.epoch_count+1)%model_stash_interval==0):
407 | # Stash a copy of the current model
408 | out_path = "tmp/models/{}_model_E{}.ckpt".format(self.model._name, self.epoch_count+1)
409 | self.model.save(path=out_path)
410 | print("Stashed a copy of the current model in {}".format(out_path))
411 |
412 | self.model.save(path=self.model._path_to_model)
413 | return summaries
414 |
415 | def train_epoch(self):
416 | n_iter = self._buffer.buffer_size // self.batch_size
417 |
418 | for it in range(n_iter):
419 | self.train_step()
420 |
421 | loss, train_acc = self.validate_model(self._buffer)
422 | _, val_acc = self.validate_model(self._val_buffer)
423 |
424 | return (loss, train_acc, val_acc)
425 |
426 | def train_step(self):
427 | states, actions, valid_actions = self.sample_buffer(self._buffer, self.batch_size)
428 |
429 | feed_dict = {self.model.ops_dict["input"]:np.stack(states, axis=0),
430 | self.model.ops_dict["valid_actions"]:np.stack(valid_actions, axis=0),
431 | self.model.ops_dict["actions"]:actions,
432 | self.model.ops_dict["dropout_keep_prob"]:0.5}
433 | _ = self.model.sess.run(self.model.ops_dict["update"], feed_dict=feed_dict)
434 |
435 | def validate_model(self, buf):
436 | states, actions, valid_actions = self.sample_buffer(buf, buf.get_buffer_size())
437 |
438 | feed_dict = {self.model.ops_dict["input"]:np.stack(states, axis=0),
439 | self.model.ops_dict["valid_actions"]:np.stack(valid_actions, axis=0),
440 | self.model.ops_dict["actions"]:actions}
441 | loss, train_probs = self.model.sess.run([self.model.ops_dict["loss"], self.model.ops_dict["probabilities"]], feed_dict=feed_dict)
442 |
443 | THRESHOLD = 5
444 | accurate_predictions = 0
445 | for k in range(len(states)):
446 | probabilities = train_probs[k,:]
447 | data = [(a, probabilities[a]) for a in range(len(probabilities))]
448 | df = pd.DataFrame(data, columns=['act_id','prob'])
449 |
450 | df.sort_values('prob',ascending=False,inplace=True)
451 | df.reset_index(drop=True,inplace=True)
452 | df['rank'] = df.index
453 |
454 | submitted_action_id = actions[k]
455 | submitted_row = df[df['act_id']==submitted_action_id]
456 |
457 | rank = submitted_row['rank'].iloc[0]
458 | if(rank < THRESHOLD):
459 | accurate_predictions += 1
460 |
461 | accuracy = accurate_predictions/len(states)
462 | return (loss, accuracy)
463 |
--------------------------------------------------------------------------------
/src/features/draftstate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from data.champion_info import champion_name_from_id, valid_champion_id, get_champion_ids
3 | from .draft import Draft
4 |
5 | class InvalidDraftState(Exception):
6 | pass
7 |
8 | class DraftState:
9 | """
10 | Args:
11 | team (int) : indicator for which team we are drafting for (RED_TEAM or BLUE_TEAM)
12 | champ_ids (list(int)) : list of valid championids which are available for drafting.
13 | num_positions (int) : number of available positions to draft for. Default is 5 for a standard 5x5 draft.
14 |
15 | DraftState is the class responsible for holding and maintaining the current state of the draft. For a given champion with championid c,
16 | that champion's state with respect to the draft is at most one of:
17 | - c is banned from selection.
18 | - c is selected as part of the opponent's team.
19 | - c is selected as one of our team's position.
20 |
21 | The state of the draft will be stored as a (numChampions) x (numRoles+2) numPy array. If state(c,k) = 1 then:
22 | - k = 0 -> champion c is selected as part of the enemy team.
23 | - k = 1 -> champion c is banned from selection.
24 | - 2 <= k <= num_positions+1 -> champion c is selected as position k-1 in our draft.
25 |
26 | Default draft positions are interpreted as:
27 | Position 1 -> ADC/Marksman (Primary farm)
28 | Position 2 -> Middle (Secondary farm)
29 | Position 3 -> Top (Tertiary farm)
30 | Position 4 -> Jungle (Farming support)
31 | Position 5 -> Support (Primary support)
32 | """
33 | # State codes
34 | BAN_AND_SUBMISSION = 101
35 | DUPLICATE_SUBMISSION = 102
36 | DUPLICATE_ROLE = 103
37 | INVALID_SUBMISSION = 104
38 | TOO_MANY_BANS = 105
39 | TOO_MANY_PICKS = 106
40 | invalid_states = [BAN_AND_SUBMISSION, DUPLICATE_ROLE, DUPLICATE_SUBMISSION, INVALID_SUBMISSION,
41 | TOO_MANY_BANS, TOO_MANY_PICKS]
42 |
43 | DRAFT_COMPLETE = 1
44 | BLUE_TEAM = Draft.BLUE_TEAM
45 | RED_TEAM = Draft.RED_TEAM
46 | BAN_PHASE = Draft.BAN
47 | PICK_PHASE = Draft.PICK
48 |
49 | def __init__(self, team, champ_ids = get_champion_ids(), num_positions = 5, draft = Draft('default')):
50 | #TODO (Devin): This should make sure that numChampions >= num_positions
51 | self.num_champions = len(champ_ids)
52 | self.num_positions = num_positions
53 | self.num_actions = (self.num_positions+1)*self.num_champions
54 | self.state_index_to_champ_id = {i:k for i,k in zip(range(self.num_champions),champ_ids)}
55 | self.champ_id_to_state_index = {k:i for i,k in zip(range(self.num_champions),champ_ids)}
56 | self.state = np.zeros((self.num_champions, self.num_positions+2), dtype=bool)
57 | self.picks = []
58 | self.bans = []
59 | self.selected_pos = []
60 |
61 | self.team = team
62 | self.draft_structure = draft
63 | # Get phase information from draft
64 | self.BAN_PHASE_LENGTHS = self.draft_structure.PHASE_LENGTHS[DraftState.BAN_PHASE]
65 | self.PICK_PHASE_LENGTHS = self.draft_structure.PHASE_LENGTHS[DraftState.PICK_PHASE]
66 |
67 | # The dicts pos_to_pos_index and pos_index_to_pos contain the mapping
68 | # from position labels to indices to the state matrix and vice versa.
69 | self.positions = [i-1 for i in range(num_positions+2)]
70 | self.pos_indices = [1,0]
71 | self.pos_indices.extend(range(2,num_positions+2))
72 | self.pos_to_pos_index = dict(zip(self.positions,self.pos_indices))
73 | self.pos_index_to_pos = dict(zip(self.pos_indices,self.positions))
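# Added annotation, derived from the two lists above: for the default 5-position draft the mapping is
#   position label -1 (ban) -> column 1, label 0 (enemy pick) -> column 0, labels 1..5 -> columns 2..6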
74 |
75 | def reset(self):
76 | """
77 | Resets draft state back to default values.
78 | Args:
79 | None
80 | Returns:
81 | None
82 | """
83 | self.state[:] = False
84 | self.picks = []
85 | self.bans = []
86 | self.selected_pos = []
87 |
88 | def get_valid_actions(self, form="mask"):
89 | """
90 | Returns the valid actions for the current state.
91 | Input:
92 | self.state
93 | form (string): default returns actions as a mask. "list" returns actions as a list of ids
94 | Returns:
95 | action_ids (list[bool/int]): valid actions that can be taken from state. If form = "list" ids are returned as a list of action_ids,
96 | otherwise actions are returned as a boolean mask
97 |
98 | If the draft is complete or in an invalid state, get_valid_actions will return an empty list of actions.
99 | """
100 | # Check if draft is complete or invalid
101 | if(self.evaluate()):
102 | if(form == "list"):
103 | return np.array([])
104 | else:
105 | return np.zeros_like(self.state[:,1:].reshape(-1))
106 |
107 | sub_count = len(self.bans)+len(self.picks)
108 | phase = self.draft_structure.get_active_phase(sub_count)
109 | champ_available = np.logical_not(np.amax(self.state[:,:],axis=1))
110 | pos_available = [pos for pos in range(1, self.num_positions+1) if pos not in self.selected_pos]
111 | valid_actions = np.zeros_like(self.state[:,1:])
112 | if(phase == Draft.BAN):
113 | # only bans are (potentially) valid during ban phase
114 | valid_actions[:,0] = champ_available
115 | else:
116 | # only picks are (potentially) valid during pick phase
117 | for pos in pos_available:
118 | valid_actions[:,pos] = champ_available
119 |
120 | if(form == "list"):
121 | return np.nonzero(valid_actions.reshape(-1))
122 | else:
123 | return valid_actions.reshape(-1)
124 |
125 | def is_submission_legal(self, champion_id, position):
126 | """
127 | Checks if submission (champion_id, position) is a valid and legal submission for the current state.
128 | Returns:
129 | True if submission is legal, False otherwise.
130 | """
131 | if(not self.can_ban(champion_id) or not self.can_pick(champion_id)):
132 | return False
133 | sub_count = len(self.bans)+len(self.picks)
134 | phase = self.draft_structure.get_active_phase(sub_count)
135 | if phase == DraftState.BAN_PHASE and position != -1:
136 | return False
137 | if phase == DraftState.PICK_PHASE:
138 | pos_index = self.get_position_index(position)
139 | is_pos_filled = np.amax(self.state[:,pos_index])
140 | if(is_pos_filled):
141 | return False
142 | return True
143 |
144 | def get_champ_id(self,index):
145 | """
146 | get_champ_id returns the valid champion ID corresponding to the given state index. Since champion IDs are not contiguously defined or even necessarily ordered,
147 | this mapping will not be trivial. If index is invalid, returns -1.
148 | Args:
149 | index (int): location index in the state array of the desired champion.
150 | Returns:
151 | champ_id (int): champion ID corresponding to index (as defined by champ_ids)
152 | """
153 | if index not in self.state_index_to_champ_id.keys():
154 | return -1
155 | return self.state_index_to_champ_id[index]
156 |
157 | def get_state_index(self,champ_id):
158 | """
159 | get_state_index returns the state index corresponding to the given champion ID. Since champion IDs are not contiguously defined or even necessarily ordered,
160 | this mapping is non-trivial. If champ_id is invalid, returns -1.
161 | Args:
162 | champ_id (int): id of champion to look up
163 | Returns
164 | index (int): state index of corresponding champion id
165 | """
166 | if champ_id not in self.champ_id_to_state_index.keys():
167 | return -1
168 | return self.champ_id_to_state_index[champ_id]
169 |
170 | def get_position_index(self,position):
171 | """
172 | get_position_index returns the index of the state matrix corresponding to the given position label.
173 | If the position is invalid, returns False.
174 | Args:
175 | position (int): position label to look up
176 | Returns:
177 | index (int): index into the state matrix corresponding to this position
178 | """
179 | if position not in self.positions:
180 | return False
181 | return self.pos_to_pos_index[position]
182 |
183 | def get_position(self, pos_index):
184 | """
185 | get_position returns the position label corresponding to the given position index into the state matrix.
186 | If the position index is invalid, returns False.
187 | Args:
188 | pos_index (int): position index to look up
189 | Returns:
190 | position (int): position label corresponding to this position index
191 | """
192 | if pos_index not in self.pos_indices:
193 | return False
194 | return self.pos_index_to_pos[pos_index]
195 |
196 | def format_state(self):
197 | """
198 | Format the state so the Q-network can process it.
199 | Args:
200 | None
201 | Returns:
202 | A copy of self.state
203 | """
204 | if(self.evaluate() in DraftState.invalid_states):
205 | raise InvalidDraftState("Attempting to format an invalid draft state for network input with code {}".format(self.evaluate()))
206 |
207 | return self.state.reshape(-1)
208 |
209 | def format_secondary_inputs(self):
210 | """
211 | Produces secondary input information (information about filled positions and draft phase)
212 | to send to Q-network.
213 | Args:
214 | None
215 | Returns:
216 | Numpy vector of secondary network inputs
217 | """
218 | if(self.evaluate() in DraftState.invalid_states):
219 | raise InvalidDraftState("Attempting to format an invalid draft state for network input with code {}".format(self.evaluate()))
220 |
221 | # First segment of information checks whether or not each position has been filled in the state
222 | # This is done by looking at columns in the subarray corresponding to positions 1 thru 5
223 | start = self.get_position_index(1)
224 | end = self.get_position_index(5)
225 | secondary_inputs = np.amax(self.state[:,start:end+1],axis=0)
226 |
227 | # Second segment checks if the phase corresponding to this state is a pick phase
228 | # This is done by counting the number of bans currently submitted. Note that this assumes
229 | # that state is currently a valid state. If this is not necessarily the case a check can be made using
230 | # evaluate().
231 | submission_count = len(self.bans)+len(self.picks)
232 | phase = self.draft_structure.get_active_phase(submission_count)
233 | is_pick_phase = phase == DraftState.PICK_PHASE
234 | secondary_inputs = np.append(secondary_inputs, is_pick_phase)
235 | return secondary_inputs
236 |
237 | def format_action(self,action):
238 | """
239 | Format input action into the corresponding tuple (champ_id, position) which describes the input action.
240 | Args:
241 | action (int): Action to be interpreted. Assumed to be generated as output of ANN. action is the index
242 | of the flattened 'actionable state' matrix
243 | Returns:
244 | (championId, position) (tuple of ints): Tuple of integer values which may be passed as arguments to either
245 | self.add_pick() or self.add_ban() depending on the value of position. If position = -1 -> action is a ban otherwise action
246 | is a pick.
247 |
248 | Note: format_action() explicitly indexes into 'actionable state' matrix which excludes the portion of the state
249 | matrix corresponding to opponent team submission. In practice this means that (cid, pos) = format_action(a) will
250 | never output pos = 0.
251 | """
252 | # 'actionable state' is the sub-state of the state matrix with 'enemy picks' column removed.
253 | actionable_state = self.state[:,1:]
254 | if(action not in range(actionable_state.size)):
255 | raise "Invalid action to format_action()!"
256 | (state_index, position_index) = np.unravel_index(action,actionable_state.shape)
257 | # Action corresponds to a submission that we are allowed to make, ie. a pick or a ban.
258 | # We can't make submissions to the enemy team, so the indices corresponding to these actions are removed.
259 | # position_index needs to be shifted by 1 in order to correctly index into full state array
260 | position_index += 1
261 | position = self.get_position(position_index)
262 | champ_id = self.get_champ_id(state_index)
263 | return (champ_id,position)
264 |
265 | def get_action(self, champion_id, position):
266 | """
267 | Given a (champion_id, position) submission pair, return the corresponding action index in the flattened actionable state array.
268 | Args:
269 | champion_id (int): id of a champion to be picked/banned.
270 | position (int): Position of champion to be selected. The value of position determines if championId is interpreted as a pick or ban:
271 | position = -1 -> champion ban submitted.
272 | 0 < position <= num_positions -> champion selection submitted by our team for pos = position
273 | Returns:
274 | action (int): Action to be interpreted as index into the flattened actionable state vector. If no such action can be found, returns -1
275 |
276 | Note: get_action() explicitly indexes into 'actionable state' matrix which excludes the portion of the state
277 | matrix corresponding to opponent team submissions. In practice this means that a = get_action(cid,pos) will
278 | produce an invalid action for pos = 0.
279 | """
280 | state_index = self.get_state_index(champion_id)
281 | pos_index = self.get_position_index(position)
282 | if ((state_index==-1) or (pos_index not in range(1,self.state.shape[1]))):
283 | print("Invalid state index or position out of range!")
284 | print("cid = {}".format(champion_id))
285 | print("pos = {}".format(position))
286 | return -1
287 | # Convert position index for full state matrix into index for actionable state
288 | pos_index -= 1
289 | action = np.ravel_multi_index((state_index,pos_index),self.state[:,1:].shape)
290 | return action
291 |
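# Worked example (added annotation): with num_positions = 5 the actionable state has num_positions+1 = 6
# columns, so get_action maps a (champion_id, position) pair to
#   action = state_index*6 + (pos_index - 1)
# where state_index is the champion's row and pos_index its column in the full state matrix.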
292 | def update(self, champion_id, position):
293 | """
294 | Attempt to update the current state of the draft and pick/ban lists with a given championId.
295 | Returns: True if selection was successful, False otherwise
296 | Args:
297 | champion_id (int): Id of champion to add to pick list.
298 | position (int): Position of champion to be selected. The value of position determines if championId is interpreted as a pick or ban:
299 | position = -1 -> champion ban submitted.
300 | position = 0 -> champion selection submitted by the opposing team.
301 | 0 < position <= num_positions -> champion selection submitted by our team for pos = position
302 | """
303 | # Special case for NULL ban submitted.
304 | if (champion_id is None and position == -1):
305 | # Only append NULL bans to ban list (nothing done to state matrix)
306 | self.bans.append(champion_id)
307 | return True
308 |
309 | # Submitted picks of the form (champ_id, pos) correspond with the selection champion = champion_id in position = pos.
310 | # Bans are given pos = -1 and enemy picks pos = 0. However, these labels are not how the submissions are stored in the state array,
311 | # and they do not match the column indexing used for the state array or the action vector (which follows state indexing).
312 | if((position < -1) or (position > self.num_positions) or (not valid_champion_id(champion_id))):
313 | return False
314 |
315 | index = self.champ_id_to_state_index[champion_id]
316 | pos_index = self.get_position_index(position)
317 | if(position == -1):
318 | self.bans.append(champion_id)
319 | else:
320 | self.picks.append(champion_id)
321 | self.selected_pos.append(position)
322 |
323 | self.state[index,pos_index] = True
324 | return True
325 |
326 | def display(self):
327 | #TODO (Devin): Clean up display to make it prettier.
328 | print("=== Begin Draft State ===")
329 | print("There are {num_picks} picks and {num_bans} bans completed in this draft. \n".format(num_picks=len(self.picks),num_bans=len(self.bans)))
330 |
331 | print("Banned Champions: {0}".format(list(map(champion_name_from_id, self.bans))))
332 | print("Picked Champions: {0}".format(list(map(champion_name_from_id, self.picks))))
333 | pos_index = self.get_position_index(0)
334 | enemy_draft_ids = list(map(self.get_champ_id, list(np.where(self.state[:,pos_index])[0])))
335 | print("Enemy Draft: {0}".format(list(map(champion_name_from_id,enemy_draft_ids))))
336 |
337 | print("Ally Draft:")
338 | for pos_index in range(2,len(self.state[0,:])): # Iterate through each position columns in state
339 | champ_index = np.where(self.state[:,pos_index])[0] # Find non-zero index
340 | if not champ_index.size: # No pick is found for this position, create a filler string
341 | draft_name = "--"
342 | else:
343 | draft_name = champion_name_from_id(self.get_champ_id(champ_index[0]))
344 | print("Position {p}: {c}".format(p=pos_index-1,c=draft_name))
345 | print("=== End Draft State ===")
346 |
347 | def can_pick(self, champion_id):
348 | """
349 | Check to see if a champion is available to be selected.
350 | Returns: True if champion is a valid selection, False otherwise.
351 | Args:
352 | champion_id (int): Id of champion to check for valid selection.
353 | """
354 | return ((champion_id not in self.picks) and valid_champion_id(champion_id))
355 |
356 | def can_ban(self, champion_id):
357 | """
358 | Check to see if a champion is available to be banned.
359 | Returns: True if champion is a valid ban, False otherwise.
360 | Args:
361 | champion_id (int): Id of champion to check for valid ban.
362 | """
363 | return ((champion_id not in self.bans) and valid_champion_id(champion_id))
364 |
365 | def add_pick(self, champion_id, position):
366 | """
367 | Attempt to add a champion to the selected champion list and update the state.
368 | Returns: True if selection was successful, False otherwise
369 | Args:
370 | champion_id (int): Id of champion to add to pick list.
371 | position (int): Position of champion to be selected. If position = 0 this is interpreted as a selection submitted by the opposing team.
372 | """
373 | if((position < 0) or (position > self.num_positions) or (not valid_champion_id(champion_id))):
374 | return False
375 | self.picks.append(champion_id)
376 | self.selected_pos.append(position)
377 | index = self.get_state_index(champion_id)
378 | pos_index = self.get_position_index(position)
379 | self.state[index,pos_index] = True
380 | return True
381 |
382 | def add_ban(self, champion_id):
383 | """
384 | Attempt to add a champion to the banned champion list and update the state.
385 | Returns: True if the ban was successful, False otherwise
386 | Args:
387 | champion_id (int): Id of champion to add to bans.
388 | """
389 | if(not valid_champion_id(champion_id)):
390 | return False
391 | self.bans.append(champion_id)
392 | index = self.get_state_index(champion_id)
393 | self.state[index,self.get_position_index(-1)] = True
394 | return True
395 |
396 | def evaluate(self):
397 | """
398 | evaluate checks the current state and determines if the draft as it is currently recorded is valid.
399 | Returns: value (int) - code indicating validity of state
400 | Valid codes:
401 | value = 0 -> state is valid but incomplete.
402 | value = DRAFT_COMPLETE -> state is valid and complete.
403 | Invalid codes:
404 | value = BAN_AND_SUBMISSION -> state has a banned champion selected for draft. This will also appear if a ban is submitted which matches a previously submitted champion.
405 | value = DUPLICATE_SUBMISSION -> state has a champion drafted which is already part of the opposing team or has already been selected by our team.
406 | value = DUPLICATE_ROLE -> state has multiple champions selected for a single role
407 | value = INVALID_SUBMISSION -> state has a submission that was included out of the draft phase order (e.g. pick during ban phase / ban during pick phase)
408 | """
409 | # Check for duplicate submissions appearing in picks or bans
410 | duplicate_picks = set([cid for cid in self.picks if self.picks.count(cid)>1])
411 | # Need to remove possible NULL bans as duplicates (since these may be legitimate)
412 | duplicate_bans = set([cid for cid in self.bans if self.bans.count(cid)>1]).difference(set([None]))
413 | if(len(duplicate_picks)>0 or len(duplicate_bans)>0):
414 | return DraftState.DUPLICATE_SUBMISSION
415 |
416 | # Check for submissions appearing in both picks and bans
417 | if(len(set(self.picks).intersection(set(self.bans)))>0):
418 | # Invalid state includes an already banned champion
419 | return DraftState.BAN_AND_SUBMISSION
420 |
421 | # Check for different champions that have been submitted for the same role
422 | for pos in range(2,self.num_positions+2):
423 | loc = np.argwhere(self.state[:,pos])
424 | if(len(loc)>1):
425 | # Invalid state includes multiple champions intended for the same role.
426 | return DraftState.DUPLICATE_ROLE
427 |
428 | # Check for out of phase submissions
429 | num_bans = len(self.bans)
430 | num_picks = len(self.picks)
431 | sub_count = num_bans+num_picks
432 |
433 | if(num_bans > self.draft_structure.NUM_BANS):
434 | return DraftState.TOO_MANY_BANS
435 | if(num_picks > self.draft_structure.NUM_PICKS):
436 | return DraftState.TOO_MANY_PICKS
437 |
438 | # validation is tuple of form (target_ban_count, target_blue_pick_count, target_red_pick_count)
439 | validation = self.draft_structure.submission_dist[sub_count]
440 | num_opponent_sub = np.count_nonzero(self.state[:,self.get_position_index(0)])
441 | num_ally_sub = num_picks - num_opponent_sub
442 | if self.team == DraftState.BLUE_TEAM:
443 | dist = (num_bans, num_ally_sub, num_opponent_sub)
444 | else:
445 | dist = (num_bans, num_opponent_sub, num_ally_sub)
446 | if(dist != validation):
447 | return DraftState.INVALID_SUBMISSION
448 |
449 | # State is valid, check if draft is complete
450 | if(num_ally_sub == self.num_positions and num_opponent_sub == self.num_positions):
451 | # Draft is valid and complete. Note that it isn't necessary
452 | # to have the full number of valid bans to register a complete draft. This is
453 | # because teams can be forced to forfeit bans due to disciplinary factors (rare)
454 | # or they can elect to not submit a ban (very rare).
455 | return DraftState.DRAFT_COMPLETE
456 |
457 | # Draft is valid, but not complete
458 | return 0
459 |
460 | if __name__=="__main__":
461 | state = DraftState(DraftState.BLUE_TEAM)
462 | print(state.evaluate())
463 | print(state.num_actions)
464 | state.display()
465 | actions = state.get_valid_actions()
466 | print(actions)
467 |
468 | state.update(1,-1)
469 | state.update(2,-1)
470 | state.update(3,-1)
471 | state.update(4,-1)
472 | state.update(5,-1)
473 | state.update(6,-1)
474 |
475 | state.update(7,1)
476 | state.update(8,0)
477 | state.update(9,0)
478 |
479 | new_actions = state.get_valid_actions()
480 | print(new_actions)
481 | print("")
482 | for aid in range(len(new_actions)):
483 | if(new_actions[aid]):
484 | print(state.format_action(aid))
485 |
486 | state.update(10,2)
487 | state.update(11,3)
488 | state.update(12,0)
489 |
490 | state.update(13,-1)
491 | state.update(14,-1)
492 | state.update(15,-1)
493 | state.update(16,-1)
494 |
495 | state.update(17,0)
496 | state.update(18,5)
497 | state.update(19,4)
498 |
499 | state.display()
500 | print(state.evaluate())
501 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Swain Bot
2 | Created by Devin Light
3 |
4 | ~~A [web version of Swain Bot](https://swainbot.herokuapp.com) (hosted on Heroku) is available to play with. (Note: this model has a limited number of concurrent users and is much slower than the local version because it is hosted using the free tier offered by Heroku)~~ Since this project has been on the back-burner for a while, the Heroku-hosted version of Swain Bot has been taken down; it may return at a later date.
5 |
6 | ## Introduction
7 | ### What is League of Legends?
8 | League of Legends (abbreviated as LoL, or League) is a multiplayer online battle arena (MOBA) game developed by Riot Games which features two teams of five players each competing in head-to-head matches with the ultimate goal of destroying the opposing team's nexus structure. The game boasts millions of monthly players and a large competitive scene involving dozens of teams participating in both national and international tournaments. The game takes place across two broadly defined phases. In the first phase (or Drafting phase), each side takes turns assembling their team by selecting a unique character (called a champion) from a pool of 138 (as of this writing) without replacement. Then, in the second phase (or Game phase), each player in the match takes control of one of the champions chosen by their team and attempts to claim victory. Although not strictly required by the game, over the years players have grown to play their champion in one of five roles named after the location on the map in which they typically start the game, and often corresponding to the amount of in-game resources that player will have devoted to them. The five roles are:
9 |
10 | - Position 1 (primary farm)-> ADC/Marksman<sup>1</sup>
11 | - Position 2 (secondary farm)-> Middle
12 | - Position 3 (tertiary farm)-> Top
13 | - Position 4 (farming support)-> Jungle
14 | - Position 5 (primary support)-> Support<sup>1</sup>
15 |
16 | <sup>1</sup> Traditionally the ADC and Support begin the game together in the same lane and are collectively called 'Bottom'.
17 |
18 | Each champion has a distinct set of characteristics and abilities that allows them to excel in certain situations while struggling in others. In order to maximize the odds of victory, it is important that the team assembled during the drafting phase simultaneously plays into one cohesive set of strengths and disrupts or plays well against the strengths of the opposing draft. There are two types of submissions made during the drafting phase. In the banning portions of the draft, champions are removed from the pool of allowed submissions, whereas champions are added to the roster of the submitting team during the pick phases. The drafting alternates between banning and picking until both teams have a full roster of five champions, at which point the game is played. The structure of the drafting phase is displayed in Figure 1. Note the asymmetry between teams (for example Blue bans first in ban phase one, while Red bans first in ban phase two) and between the phases themselves (ban phases always alternate sides, while pick phases "snake" between teams).
19 |
20 | 
21 |
22 | ### What is Swain Bot?
23 | Swain Bot (named after the champion Swain who is associated with being a ruthless master tactician/general) is a machine learning application built in Python using Google's Tensorflow framework. Swain Bot is designed to analyze the drafting phase of competitive League of Legends matches. Given a state of the draft which includes full information of our team's submissions (champions and positions) and partial information of the opponent's submissions (champions only), Swain Bot attempts to suggest picks and bans that are well-suited for the current state of our draft.
24 |
25 | ### What do we hope to do with Swain Bot?
26 | Knowing the best pick for a given draft situation can dramatically improve a team's chance for success. Our objective with Swain Bot is to help provide insight into League's crucial draft phase by attempting to answer questions like:
27 | - Can we estimate how valuable each submission is for a given state of the draft?
28 | - Is there a common structure or theme to how professional League teams draft?
29 | - Can we identify the differences between a winning and a losing draft?
30 |
31 | ## Assumptions and Limitations
32 | Every model tasked with approaching a difficult problem is predicated on some number of assumptions, which in turn define the boundaries within which the model can safely be applied. Swain Bot is no exception, so here we outline and discuss some of the explicit assumptions being made going into the construction of the underlying model Swain Bot uses to make its predictions. Some of the assumptions are more impactful than others and some could be removed in the future to improve Swain Bot's performance, but are in place for now for various reasons.
33 |
34 | 1. Swain Bot is limited to data from recorded professionally played games from the "Big 5" regions (NALCS, EULCS, LCK, LPL, and LMS). Limiting potential data sources to competitive leagues is very restrictive when compared to the pool of amateur matches played on servers across the world. However, this assumption is in place as a result of changes in Riot's (otherwise exceptional) API which effectively randomizes the order in which the champion submissions for a draft are presented, rendering it impossible to recover the sequence of draft states that make up the complete information regarding the draft. Should the API be changed in the future Swain Bot will be capable of learning from amateur matches as well.
35 |
36 | 2. Swain Bot does not receive information about either the patch the game was played on or the teams involved in the match. Not including the patch allows us to stretch the data as much as we can given the restricted pool. Although the effectiveness of a champion might change as they are tuned between patches, it is unlikely that they are changed so much that the situations that the champion would normally be picked in are dramatically different. Nevertheless, substantial champion changes have occurred in the past, usually in the form of a total redesign. Additionally, although team data for competitive matches is available during the draft, Swain Bot's primary objective is to identify the most effective submissions for a given draft state rather than predict what a specific team might select in that situation. It would be possible to combine Swain Bot's output with information about a team's drafting tendencies (using ensemble techniques like stacking) to produce a final prediction which both suits the draft and is likely to be chosen by the team. However, we will leave this work for later.
37 |
38 | 3. Swain Bot's objective is to associate the combination of a state and a potential submission with a value and to suggest taking the action which has the highest value. This valuation should be based primarily on what is likely to win the draft (or move us towards a winning state), and partly on what is likely to be done. Although these two goals may be correlated (a champion that is highly-valued might also be the one selected most frequently) they are not necessarily the same since, for example, teams may be biased towards or against specific strategies or champions.
39 |
40 | 4. Swain Bot's objective to estimate the value of submissions for a given draft state is commonly approached using techniques from Reinforcement Learning (RL). RL methods have been successfully used in a variety of situations such as teaching robots how to move, playing ATARI games, and even [playing DOTA2](https://blog.openai.com/dota-2/). A common element to most RL applications is the ability to automatically explore and evaluate states as they are encountered in a simulated environment. However, Swain Bot is not capable of automatically playing out the drafts it recommends in order to evaluate them (yet..) and so is dependent on the data it observes originating from games that were previously played. This scenario is reminiscent of a Supervised Learning (SL) problem called behavioral cloning, where the task is to learn and replicate the policy outlined by an expert. However, behavioral cloning does not include the estimation of values associated with actions and attempts to directly mimic the expert policy. Swain Bot instead implements an RL algorithm to estimate action values (Q-Learning), but trained using expertly-generated data. In practice this means that the predictions made by Swain Bot can only have an expectation of accuracy when following trajectories that are similar to the paths prescribed by the training data.
41 |
42 | ## Methods
43 | This section is not designed to be too technical, but rather give some insight into how Swain Bot is implemented and some important modifications that helped with the learning process. For some excellent and thorough discussions on RL, check out the following:
44 | - [David Silver's course on RL](http://www0.cs.ucl.ac.uk/staff/d.silver/web/Teaching.html) [(with video lectures)](https://www.youtube.com/watch?v=2pWv7GOvuf0)
45 | - [Reinforcement Learning](http://incompleteideas.net/book/the-book-2nd.html) By Sutton and Barto
46 | - [The DeepMind ATARI paper](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
47 | - [Dueling DQNs](https://arxiv.org/pdf/1511.06581.pdf)
48 | - And finally a [few](http://outlace.com/rlpart3.html), [useful](https://www.intelnervana.com/demystifying-deep-reinforcement-learning/), [tutorials](https://medium.com/emergent-future/simple-reinforcement-learning-with-tensorflow-part-0-q-learning-with-tables-and-neural-networks-d195264329d0)
49 |
50 | ### Representing Draft States and Actions
51 | Each of the _N_ eligible champions (138 as of this writing) in a draft is represented by a unique `champion_id` integer and every position in the game (five positions per team plus banned champions) is given by a `position_id`. An _action_ (or _submission_) to the draft is defined as a tuple of the form `(champion_id, position_id) = (i,j)` representing the selection of champion `i` to position `j` in the draft. We can represent the _draft state_ as a boolean matrix _S_ where _S_(i,j) = 1 if the ith champion is present in the draft at the jth position. The number of columns in _S_ is determined by how much information about the positions is available to the drafter:
52 | - In a _completely informed_ draft all position information is known so _S_ is an `N x 11` matrix (10 positions + bans).
53 | - In a _partially informed_ draft position information is only known for the drafter's team whereas only the `champion_id`s are known for the opponent's team. As a result _S_ is given by an `N x 7` matrix (5 positions + bans + enemy champion ids).
54 |
55 | Note that _S_ is a sparse matrix: for any given state of a draft no more than 10 picks and 10 bans have been submitted, so _S_ contains no more than 20 non-zero entries at any given time. Swain Bot operates using partially informed draft states as inputs, which may be obtained by projecting the five columns in the completely informed state corresponding to the positions in the opponent's draft onto a single column. Finally, we define the _actionable state_ to be the submatrix of _S_ corresponding to the actions the drafter may submit -- that is, the columns corresponding to bans as well as the drafter's five submittable positions.
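
As a rough illustration of this layout (a minimal sketch, not the `DraftState` implementation above), a partially informed state can be held in a boolean NumPy array; the column ordering and the example indices below are assumptions:

```python
import numpy as np

N_CHAMPIONS = 138            # eligible champions (as of this writing)
BAN_COL, ENEMY_COL = 0, 1    # assumed column layout: bans, enemy picks,
                             # then columns 2-6 for the drafter's five positions

state = np.zeros((N_CHAMPIONS, 7), dtype=bool)   # partially informed state, N x 7

def submit(state, champ_index, position):
    """Record an action; position codes mirror the DraftState convention:
    -1 = ban, 0 = enemy pick, 1..5 = allied positions."""
    col = {-1: BAN_COL, 0: ENEMY_COL}.get(position, position + 1)
    state[champ_index, col] = True

submit(state, champ_index=12, position=-1)   # ban the champion in row 12
submit(state, champ_index=40, position=3)    # ally selects champion 40 for position 3
assert state.sum() == 2                      # sparse: one non-zero entry per submission
```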
56 |
57 | ### Drafting as a Markov Decision Process (MDP)
58 | An individual experience (sometimes called a memory) observed during the draft can be recorded as a tuple of the form `e_k = (s_k,a_k,r_k,s')` where `s_k` is the initial state of the draft before an action was taken, `a_k` is the action submitted, `r_k` is the reward observed as a result of taking `a_k`, and `s'` is the successor state reached from `s_k`. For either completely or partially informed states, the draft can be fully recovered using the sequence of experiences `(e_0, e_1,..., e_n)` transitioned through during the draft. This sequence defines a Markov chain because, given the current state _s_, the value of the successor state _s'_ is independent of the states that were transitioned through before _s_. In other words, the possible states we can transition to from _s_ depend only on _s_ itself, and not on the states that were seen on the way to _s_.
59 |
60 | To complete the description of drafting as an MDP we need to define a reward schedule and discount factor. The discount factor is a scalar value between 0 and 1 that governs the present value of future expected rewards. Two common reasons to use a discount factor are to express uncertainty about the potential value of the future and to capture the extra value of taking an immediate reward over a delayed one (e.g. if the reward is financial, an immediate reward is worth more than a delayed reward because that immediate reward can then be used to earn additional interest). Typical discount factor values are in the range `0.9` to `0.99`. Swain Bot uses `discount_factor = 0.9`.
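
As a quick illustration of what `discount_factor = 0.9` means in practice (the reward magnitude and number of steps here are made up):

```python
discount_factor = 0.9

def present_value(reward, steps_away, gamma=discount_factor):
    """Present value of a reward observed `steps_away` transitions in the future."""
    return (gamma ** steps_away) * reward

# A hypothetical terminal reward of 10 observed 9 submissions later is worth
# about 3.87 from the perspective of the current state.
print(present_value(10, 9))   # ~3.874
```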
61 |
62 | The reward schedule is a vital component of the MDP and determines what policy the model will ultimately converge towards. As previously discussed, Swain Bot's long-term objective is to select actions which move the draft towards a winning state while in the short term placing some value on actions which are likely to be taken. The ideal reward schedule should combine these two goals so that Swain Bot predicts both good and probable actions. We will approach this by associating larger magnitude rewards with _terminal_ states and smaller magnitude rewards with non-terminal states. A terminal state _s_ occurs in one of the three scenarios where _s_ represents:
63 | 1. a valid, complete draft which resulted in a win
64 | 2. a valid, complete draft which resulted in a loss
65 | 3. an invalid draft (which cannot be played)
66 |
67 | All other states are valid, but non-terminal. An invalid state _s_ is one in which one or more of the following conditions are satisfied where _s_ represents:
68 | 1. an incorrect number of picks or bans for the phase of the draft described by that state (e.g. picks submitted during Ban Phase 1, too many picks submitted in a single phase, two consecutive picks associated with red side during Pick Phase 2, etc.)
69 | 2. at least one champion selected in more than one position (e.g. picked and banned, picked by both teams, or picked by a team in more than one role)
70 | 3. at least one non-ban actionable position with more than one champion submitted to it. For partially informed drafts, the opposing team's column must also contain no more than five submissions.
71 |
72 | It's reasonable to have the network (or a secondary network) infer which actions lead to invalid drafts, and use that information to avoid predicting them. This essentially amounts to learning the rules of drafting. However we can get away with a smaller network by filtering out the illegal actions before making predictions. This helps significantly reduce the amount of training time required before observing reasonable results.
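
A minimal sketch of that filtering step, assuming the network outputs one Q-value per action and the draft state can report which actions are currently legal (as `get_valid_actions` does in the `DraftState` class above); the example arrays are made up:

```python
import numpy as np

def best_legal_action(q_values, valid_mask):
    """Pick the highest-valued action among those the draft rules allow.

    q_values   : 1-D array of estimated action values, one per (champion, position) pair
    valid_mask : boolean array of the same length, True where the action is legal
    """
    masked = np.where(valid_mask, q_values, -np.inf)   # illegal actions can never win the argmax
    return int(np.argmax(masked))

q = np.array([0.3, 0.9, 0.1, 0.7])           # e.g. q = online_network.predict(state)
mask = np.array([True, False, True, True])   # action 1 is illegal (say, already banned)
assert best_legal_action(q, mask) == 3
```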
73 |
74 | The empirically determined reward schedule is defined in two parts depending on if _s_ is a terminal state. If _s_ is terminal, the reward is given by
75 |
76 |
77 |
78 | If _s_ is non-terminal, the reward has the form
79 |
80 |
81 |
82 | where _a_* is the action taken during the original memory.
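
The two formulas above are embedded as images in the original write-up and their constants were tuned empirically; the sketch below only mirrors the qualitative shape described in the text (large-magnitude rewards for terminal states, a small reward for matching the expert action `a*`), and every magnitude in it is hypothetical:

```python
# Hypothetical status codes standing in for DraftState.evaluate() results.
VALID_INCOMPLETE, DRAFT_COMPLETE, INVALID_DRAFT = 0, 1, -1

def reward(status, won, action, expert_action):
    """Qualitative shape of the reward schedule; all magnitudes are hypothetical."""
    if status == INVALID_DRAFT:
        return -15.0                                  # heavily penalize unplayable drafts
    if status == DRAFT_COMPLETE:
        return 10.0 if won else -10.0                 # large terminal reward keyed to the match result
    return 0.5 if action == expert_action else 0.0    # small shaping reward toward a*

print(reward(DRAFT_COMPLETE, won=True, action=None, expert_action=None))         # 10.0
print(reward(VALID_INCOMPLETE, won=False, action=(7, 1), expert_action=(7, 1)))  # 0.5
```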
83 |
84 | ### Deep Q-Learning (DQN)
85 | With the framework describing drafting as an MDP, we can apply a Q-Learning algorithm to estimate `Q(s,a)`, the maximum expected future return by taking action `a` from state `s`. With 138 total champions and 20 chosen at a time to appear in the final draft state, there are roughly `6.07x10^{23}` possible ending states, making a tabular Q-learning method out of the question. Instead, we opt to use a straightforward fully connected neural network to estimate the value function for an arbitrary state. The selected model's architecture consists of 4 total layers:
86 | - 1 One-hot encoded input layer (representing draft state)
87 | - 2 FC + ReLU layers with `1024` nodes each
88 | - 1 FC linearly activated output layer with `n_champ x (n_pos + 1)` nodes (representing actions from the actionable state)
89 |
90 | Regularization is present in the model via dropout in between the FC layers. In order to help reduce the amount of time the model spends learning the underlying draft structure, the first FC layer is given several supplementary boolean inputs corresponding to which positions have been filled in the draft as well as whether the draft is currently in a banning phase. These inputs aren't strictly necessary since the value of these variables is directly inferable from the state matrix, but including them substantially reduces the amount of training time required before the model begins producing legal submissions for input states. There are also three important modifications to the fundamental DQN algorithm that are necessary in order to ensure that the Q-learning is stable (a rough sketch of the network itself follows the list below):
91 | 1. Data augmentation
92 | 2. Experience Replay (ER)
93 | 3. Double DQN (DDQN)
94 |
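A rough `tf.keras` analogue of the architecture described above (a sketch only, not the project's actual TensorFlow implementation); the dropout rate and the way the supplementary flags are concatenated onto the input vector are assumptions:

```python
import tensorflow as tf

N_CHAMP, N_POS = 138, 6        # champions x (5 allied positions + ban column)
state_dim = N_CHAMP * 7        # flattened partially informed state matrix
extra_dim = 6                  # assumed: 5 "position filled" flags + 1 "is ban phase" flag
n_actions = N_CHAMP * N_POS    # one Q-value per actionable (champion, position) pair

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(state_dim + extra_dim,)),
    tf.keras.layers.Dense(1024, activation="relu"),
    tf.keras.layers.Dropout(0.5),              # dropout rate is an assumption
    tf.keras.layers.Dense(1024, activation="relu"),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(n_actions),          # linear output: estimated Q(s, a)
])
```
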
95 | Data augmentation refers to modifying existing training data in order to effectively increase the size of the training set and help reduce model overfitting. The techniques used to modify the data include adding random noise to training data or cropping, mirroring, and rotating images. For the purposes of data augmentation in Swain Bot, we identify two types of augmentable submissions: sequential picks and bans. Referring to the original image of the drafting structure, notice that several of the submissions made during the two pick phases are made sequentially: red side's first two picks, blue side's second and third picks, and blue side's fourth and fifth picks. Swain Bot makes predictions one action at a time, so the order in which these submissions are presented is preserved and the input state for the second submission will include the first submission as a result. However, since these submissions are made back-to-back, in practice sequential picks use the same initial state and the order in which they are presented is arbitrary. A simple augmentation which helps Swain Bot emulate this common initial state is to randomly exchange the order in which sequential picks in the draft are executed. For example, if a sample draft includes two memories corresponding to red side sequentially picking (Kalista, ADC) and then (Leona, Support), half of the time the order in which these picks are learned from would be reversed.
96 |
97 | The second augmented submissions are bans that share a common phase. Technically the drafting structure alternates between teams during bans, meaning that the initial state for each ban is distinct and as a result the order in which they are submitted is important. On the other hand, the purpose behind banning is to either protect your draft (by removing strong submissions against your team) or attack the enemy draft (by removing strong submissions for their team). In either circumstance, the bans we submit do not depend on the opponent's submissions _except in the case where they ban a champion which we would otherwise ban_. Unlike picks, however, our opponent "sniping" a ban actually _benefits_ our draft by effectively allowing our team to make an additional ban. This effect can be approximated by randomly shuffling the order in which the bans of a given phase are presented, just like we did with sequential picks. Note that bans cannot be shuffled across phases because bans made during the second phase are made with information about 3/5ths of the picks for each team available.
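
Both augmentations amount to permuting submissions whose relative order carries no information before the draft's experiences are rebuilt; a minimal sketch, where the index groups passed in (and which submissions qualify) are assumptions:

```python
import random

def augment_draft(actions, sequential_pick_pairs, ban_phase_groups):
    """Permute order-insensitive submissions in a draft's action list.

    actions: list of (champion_id, position) tuples in submission order.
    """
    actions = list(actions)
    for i, j in sequential_pick_pairs:        # e.g. red side's first two picks
        if random.random() < 0.5:             # swap the pair half of the time
            actions[i], actions[j] = actions[j], actions[i]
    for group in ban_phase_groups:            # bans submitted within a single ban phase
        shuffled = [actions[k] for k in group]
        random.shuffle(shuffled)
        for k, sub in zip(group, shuffled):
            actions[k] = sub
    return actions
```

The permuted action list can then be replayed through the draft to regenerate the corresponding experience tuples.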
98 |
99 | Experience replay provides a mechanism for separating the generation of memories from learning from those memories. In experience replay, each experience associated with a draft is stored into a pool of experiences spanning many drafts. The Q-learning update is applied to a minibatch of randomly sampled experiences contained within the pool. This is important because consecutive experiences generated from the same draft are strongly correlated and learning from them all simultaneously using SGD is not only inefficient, but may also lead to a local (suboptimal) minimum. By randomizing the samples drawn from the replay buffer, the correlation between experiences is at least broken. Additionally, each memory is potentially used in multiple updates, improving overall data efficiency.
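
A minimal replay buffer sketch (the capacity and batch size here are assumptions, not the values used in training):

```python
import random
from collections import deque

class ReplayBuffer:
    """Fixed-size pool of experiences drawn from many drafts."""

    def __init__(self, capacity=20000):
        self.memories = deque(maxlen=capacity)   # oldest experiences fall out first

    def store(self, experience):
        """experience = (state, action, reward, next_state)."""
        self.memories.append(experience)

    def sample(self, batch_size=32):
        """Return a minibatch of uncorrelated experiences for one Q-learning update."""
        return random.sample(list(self.memories), batch_size)
```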
100 |
101 | The default DQN algorithm selects actions "greedily" by taking the maximum over the estimated action-values. A side effect of this maximization is that the DQN tends to learn overestimated values. Unfortunately, this over-optimism is often non-uniformly distributed across actions and can degrade the performance of the learned policy. Furthermore, this overestimation also tends to grow as the number of actions increases. There are 822 (137 champions each selectable in 6 positions) possible actions during each stage of drafting. As a result, it is desirable to control this overestimation as much as possible. The DDQN algorithm proposed by van Hasselt et al. attempts to limit this overestimation by pseudo-decoupling action selection from evaluation by utilizing two networks: an "online" network and a "target" network. The online network represents the most up-to-date parameter estimates, while the target network is a periodic snapshot of the online network. In simplest terms, the original update for DQN
102 |
103 | `update = reward + discount_factor*max_a'{Q(s',a')}`
104 |
105 | is replaced with
106 |
107 | `update = reward + discount_factor*Q_target(s', max_a'{Q_online(s',a')})`.
108 |
109 | Note that this doesn't truly decouple action selection and evaluation because the target network is a copy of a previous online network. The goal of DDQN is to be the simplest modification of DQN in the direction of a truly decoupled algorithm (like double Q-learning) in order to get most of the benefit with the smallest computational overhead required.
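
A sketch of how that replacement looks when computing targets for a minibatch; the function signature and the callables standing in for the two networks are placeholders:

```python
import numpy as np

def ddqn_targets(rewards, next_states, terminal, q_online, q_target, gamma=0.9):
    """Compute DDQN update targets for a minibatch.

    q_online(s) and q_target(s) each return an array of action values for state s.
    """
    targets = np.array(rewards, dtype=float)
    for k, (s_next, done) in enumerate(zip(next_states, terminal)):
        if done:
            continue                                    # terminal: the target is the reward alone
        a_star = int(np.argmax(q_online(s_next)))       # online network selects the action...
        targets[k] += gamma * q_target(s_next)[a_star]  # ...target network evaluates it
    return targets
```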
110 |
111 | ## Analysis
112 | ### Evaluating the Model
113 | In addition to the "eyeball test" of Swain Bot's predictions (i.e. no illegal submissions, correct number of roles, overall "meta-feel" of drafts, etc.), we're also interested in a quantitative measure of performance.
114 |
115 | One approach is to treat predictions as we would with a classifier and measure the fraction of predictions which agree with what was actually submitted in a winning draft. However, it's important to recall that our objective is to predict valuable submissions which may not necessarily overlap with what a specific team is likely to submit. It is often the case that multiple submissions are suited for the draft and as a result each have roughly equal value (this is particularly true for bans and early submissions). Selecting an action from amongst these picks is mostly a function of the biases of the drafting team. Since team identities aren't included as part of the input, it is unrealistic to expect the model to match the exact submission made by every team. A simple way to try and compensate for this is to group the top `k` submissions and regard these as a set of "good" picks according to the model. Then we measure accuracy as the fraction of submissions made that are contained in the predicted "good" submission pool for each state.
116 |
117 | Another approach is to examine the difference in estimated Q-values between the top prediction (`max_a{Q(s,a)}`) and the actual submission (`Q(s,a*)`). The difference between these two values estimates how far off the actual action that was submitted is from taking over as the top prediction. If both the top recommendation and `a*` are really good submissions for this state, this difference should be relatively small. If we use this to compute a normalized mean squared error over each state we are predicting for, we get an estimate of the model's performance:
118 |
119 |
120 |
121 | Note that if the model were to assign all actions the same value then this measure of error would be trivially zero. So just like the classification measure of accuracy, this measure of error is not perfect. Nevertheless the combination gives some insight into how the model performs.
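
For concreteness, here is a small sketch of both measures; the exact normalization used for the error is not spelled out above, so dividing the Q-value gap by the top predicted value is an assumption:

```python
import numpy as np

def top_k_accuracy(q_by_state, submitted, k=5):
    """Fraction of actual submissions that fall inside the model's top-k set."""
    hits = sum(a in np.argsort(q)[-k:] for q, a in zip(q_by_state, submitted))
    return hits / len(submitted)

def normalized_q_gap_error(q_by_state, submitted):
    """Mean squared gap between the top prediction and the submitted action,
    normalized by the magnitude of the top predicted value (assumed form)."""
    gaps = [(q.max() - q[a]) / max(abs(q.max()), 1e-8)
            for q, a in zip(q_by_state, submitted)]
    return float(np.mean(np.square(gaps)))
```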
122 |
123 | ### Training Data
124 | Competitive match data was pulled from [lol.gamepedia.com](https://lol.gamepedia.com).
125 |
126 | The model was trained in two stages. For the first stage, data was obtained from matches played during the 2017 Summer Season and 2017 Worlds Championship Qualifiers for the five major international regions:
127 | - North America (NA LCS)
128 | - Europe (EU LCS)
129 | - China (LPL)
130 | - Korea (LCK)
131 | - Taiwan, Hong Kong, & Macau (LMS)
132 |
133 | Between the two training stages the model underwent a dampening iteration (in which the value of all predictions was reduced) in order to simulate a change in metas associated with the gap in time between the Summer Season and Worlds Championships. The second stage of training used data from the 119 matches played during the 2017 World Championship, with 11 randomly selected matches from the knockout stages held out for validation. The model learned on each match in the training set for 100 epochs (i.e. each match was seen 100 times in expectation). The model was trained using a smaller learning rate `alpha_0 = 1.0e-05` which was halved every 10 epochs until it reached a minimum value of `alpha_f = 1.0e-08`.
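
The stated step-decay schedule is simple enough to write down directly; as a small sketch:

```python
def learning_rate(epoch, alpha_0=1.0e-05, alpha_f=1.0e-08, halve_every=10):
    """Halve alpha_0 every `halve_every` epochs, clamped at the floor alpha_f."""
    return max(alpha_f, alpha_0 * 0.5 ** (epoch // halve_every))

print(learning_rate(0))     # 1e-05
print(learning_rate(25))    # 2.5e-06
print(learning_rate(200))   # 1e-08 (floor reached)
```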
134 |
135 | ### Validation Matches
136 | The figures below illustrate the drafts from each match in the validation set. The left side of the draft represents the submissions made by the blue side, while the right side depicts submissions made by the red side.
137 |
138 |
139 |
140 |
141 | The distribution of the eleven matches according to the stage of Worlds in which they were played was:
142 | - Play-in stage: 2 matches
143 | - Group stage: 5 matches
144 | - Knockout stage: 3 matches
145 | - Finals: 1 match
146 |
147 | The table below lists the classification (top 1), "good" set classification (top 5), and normalized root mean square error (l2 error) for three categories: all submissions predicted (full), all submissions excluding the first phase of bans (no round 1 bans), and picks only (no bans).
148 |
149 | ```
150 | Norm Information:
151 | Full
152 | Num_predictions = 110
153 | top 1: count 19 -> acc: 0.1727
154 | top 5: count 62 -> acc: 0.5636
155 | l2 error: 0.02775
156 | ---
157 | No Round 1 Bans
158 | Num_predictions = 77
159 | top 1: count 17 -> acc: 0.2208
160 | top 5: count 47 -> acc: 0.6104
161 | l2 error: 0.03004
162 | ---
163 | No Bans
164 | Num_predictions = 55
165 | top 1: count 14 -> acc: 0.2545
166 | top 5: count 39 -> acc: 0.7091
167 | l2 error: 0.02436
168 | ```
169 |
170 | Starting with the top 1 classification accuracy, it's apparent that predicting the exact submission at every stage in the draft is difficult for the model, which achieves an abysmal 17-25% accuracy across the three categories. This is likely due to a combination of lacking information about which teams are drafting (which would allow the model to distinguish between "comfort picks" and picks that a team is biased against) and the weighting difference between winning submissions and likely submissions. For the top 5 classification accuracy the model improves significantly, particularly when the first three bans are ignored. The stable l2 error also indicates that the model at least associates elevated values with the submissions from winning drafts even when they are not the exact submission it predicts. Finally, there is a clear jump in performance between the full set of predictions and the predictions excluding the first phase of banning, while the further improvement after also removing the second phase of bans is smaller. This suggests that earlier submissions in the draft are significantly more difficult to predict than later ones, likely due to a combination of factors. First, since these submissions are the furthest away from the most rewarding states, there is large uncertainty in associating the reward observed at the end of a draft with the first few selections submitted. Second, even after it is complete, the first phase of bans contributes relatively little information to the draft compared with the information gained after several picks have been made. Finally, the first-phase ban submissions are perhaps the most heavily influenced by team biases, since they tend to revolve around removing the picks you and your opponent are likely to make.
171 |
172 | It's also interesting to look at the types of misclassifications the model made on the picks submitted in this data set. Naturally, many of the misclassifications occurred with the model predicting one set of meta picks but the submitting team selecting something within the meta yet outside that list (usually within the top 10 or so predictions). In particular, the model tended to under-value using a first pick on Janna as a support, instead often predicting ADCs or junglers in that slot. Other mistakes occurred when teams opted for pocket picks to gain an advantage, either by selecting a seldom-seen champion or by "flexing" a champion to a secondary role. These submissions were the most surprising to the model in the sense that they often lay well outside the model's top five predictions. Below we highlight a few of these picks:
173 |
174 |
175 |
176 | The C9 v LYN match was one of a very small number of games where Kalista was not removed in the first phase of bans. Although she certainly defined the meta (with a staggering 100% presence in the drafting phase at Worlds 2017), there weren't enough matches played with her involved for the model to recommend her over the more commonly seen ADCs like Xayah, Tristana, and Kog'maw. This illustrates that the model does not currently connect "must ban" with "must pick", although this is often the case with the strongest champions in the meta. C9's follow-up pick on Thresh was surprising partly because Thresh was in his own right an uncommon pick (played only 5 times total), but also because the model lacked context for his synergy with Kalista. In SSG v SKT, Zac was a surprising pick because he had only appeared 5 times in the tournament, and this match was his only win. The SKT v MSF game was the only game at the tournament involving flexing Trundle to support, and RNG was the only team to pick Soraka (although they were successful with it). Ultimately it is a challenging problem for the model to accurately predict niche picks without overfitting.
177 |
178 | For completeness here is the table for all 2017 Worlds matches (including the play-in stages):
179 | ```
180 | Norm Information:
181 | Full
182 | Num_predictions = 1190
183 | top 1: count 549 -> acc: 0.4613
184 | top 5: count 905 -> acc: 0.7605
185 | l2 error: 0.01443
186 | ---
187 | No Round 1 Bans
188 | Num_predictions = 833
189 | top 1: count 513 -> acc: 0.6158
190 | top 5: count 744 -> acc: 0.8932
191 | l2 error: 0.01136
192 | ---
193 | No Bans
194 | Num_predictions = 595
195 | top 1: count 371 -> acc: 0.6235
196 | top 5: count 525 -> acc: 0.8824
197 | l2 error: 0.01049
198 | ```
199 | Obviously since the model was trained on the majority of these matches it performs much better.
200 |
201 | Worlds 2017 was dominated by the "Ardent Censer" meta, which favored hard-scaling position 1 carries combined with position 5 supports who could abuse the item Ardent Censer (an item which amplified the damage output of the position 1 pick). This made using early picks to secure a favorable bot lane matchup extremely popular. The remaining pick from the first phase tended to be spent selecting a safe jungler like Jarvan IV, Sejuani, or Gragas. As a result, if we look at the distribution of positions selected during the first phase of each of the 119 drafts conducted at 2017 Worlds, we can see a strong bias against the solo lanes (positions 2 and 3).
202 |
203 | ```
204 | Phase 1: Actual
205 | Position 1: Count 111, Ratio 0.311
206 | Position 2: Count 28, Ratio 0.0784
207 | Position 3: Count 24, Ratio 0.0672
208 | Position 4: Count 92, Ratio 0.258
209 | Position 5: Count 102, Ratio 0.286
210 | ```
211 |
212 | We can compare this with the positions of the top 5 recommendations made by Swain Bot during the first pick phase:
213 |
214 | ```
215 | Phase 1 Recommendations:
216 | Position 1: Count 611, Ratio 0.342
217 | Position 2: Count 179, Ratio 0.1
218 | Position 3: Count 248, Ratio 0.139
219 | Position 4: Count 305, Ratio 0.171
220 | Position 5: Count 442, Ratio 0.248
221 | ```
222 | Swain Bot agrees with the meta in using early picks to secure a bot lane. However, by comparison it is more likely to suggest a solo lane pick in the first phase instead of a jungler. This effect was also seen in the actual drafts towards the end of the tournament where solo lane picks like Galio and Malzahar became increasingly valuable and took over what would have likely been jungler picks in the earlier stages.
223 |
224 | ## Looking Ahead
225 | Even with the promising results so far, Swain Bot is far from complete. Here are just a few things that could be worth looking into for the future:
226 |
227 | - Compared with the massive number of possibilities for drafting, we're still limited to the relatively tiny pool of drafts coming from competitive play. As a result Swain Bot's ability to draft is really limited to the types of picks likely to be seen in the professional leagues (for fun try asking Swain Bot what to pick after adding champions in positions that are rarely/never seen in competitive play like Teemo or Heimerdinger). Ideally we would have access to detailed draft data for the millions of games played across all skill levels through [Riot's API](https://developer.riotgames.com), but unfortunately the API does not preserve submissions in draft order (yet). If we could have one wish this would be it. There is some hope for the future with the recently announced competitive 5v5 [Clash mode](https://nexus.leagueoflegends.com/en-us/2017/12/dev-clash/).
228 |
229 | - Build a sequence of models each using a small number of patches or a single patch for data. This could help improve Swain Bot's meta adaptation between patches by ensembling this sequence to make predictions from the most recent "meta history" while data from a new patch is being processed.
230 |
231 | - Including some form of team data for competitive drafts, either as a model input or as an additional layer on top of the current model structure. Swain Bot's current iteration implicitly assumes that every team is able to play every champion at the highest level. Even for the best players in the world, this is certainly not true. By adding information about the biases of the teams drafting, we could improve Swain Bot's ability to make suggestions which suit both the meta and the team.
232 |
233 | - Exploring alternative models (in particular Actor-Critic Policy Gradient). The DDQN model Swain Bot implements is by no means a poor one, but in the constantly evolving world of machine learning it is no longer cutting-edge. In particular our implementation of DDQN is deterministic, meaning that when presented with the same input state the model will always output the same action. If recommendations made by Swain Bot were to be used in practice to assemble a draft, this could be potentially abused by our opponents. Policy gradient methods parameterize the policy directly and are able to represent both stochastic and deterministic policies.
234 |
235 | - Reducing the required training time using GPUs and/or services like AWS.
236 |
237 | ## Conclusion
238 | Thank you for taking the time to read through this write-up. Working on Swain Bot was originally an excuse to dive into machine learning but eventually became a huge motivator to keep learning and grinding, especially early on when nothing was stable. I also want to give a huge thank you to Crystal for supporting me in my now year-long diversion into machine learning and both being someone to lean on through the hard times and someone to celebrate with during the good times. I couldn't have done this without you. Thank you also to the greatest group of friends I could ask for. League is a fun game on its own, but I wouldn't still be so excited to play after so many years if it wasn't for you (and your reluctant tolerance for Nunu Top). Finally, thank you to the folks at Riot Games. The passion, excitement, and devotion you show for your product is infectious.
239 |
240 |
241 | ## Disclaimer
242 | Swain Bot isn’t endorsed by Riot Games and doesn’t reflect the views or opinions of Riot Games or anyone officially involved in producing or managing League of Legends. League of Legends and Riot Games are trademarks or registered trademarks of Riot Games, Inc. League of Legends © Riot Games, Inc.
243 |
--------------------------------------------------------------------------------