├── .gitignore ├── README.md ├── data.py ├── data ├── __init__.py ├── config.ini ├── convert_replays.py ├── create_dataset.py ├── download_replays.py ├── interfacers │ ├── __init__.py │ ├── base_interfacer.py │ ├── calculatedgg_api │ │ ├── __init__.py │ │ ├── api_interfacer.py │ │ ├── errors.py │ │ ├── query_params.py │ │ └── test_api_interfacer.py │ └── local_interfacer.py └── utils │ ├── __init__.py │ ├── columns.py │ ├── number_check.py │ ├── playlists.py │ └── utils.py ├── data_main.py ├── models ├── __init__.py ├── base_model.py └── dense_model.py ├── rank.ipynb ├── requirements.txt ├── trainers ├── __init__.py ├── callbacks │ ├── __init__.py │ ├── metric_tracer.py │ ├── metrics.py │ ├── prediction_plotter.py │ ├── prediction_plotter_plot.py │ └── tensorboard.py └── sequences │ ├── __init__.py │ └── calculated_sequence.py ├── value_function ├── __init__.py ├── value_function_conv_model.py ├── value_function_model.py └── value_function_trainer.py └── x_things ├── __init__.py ├── model_retrainer.py ├── x_goals_conv_model.py ├── x_goals_trainer.py └── x_things_model.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | replays/ 3 | *.json 4 | ReplayDataProcessing/ 5 | __pycache__/ 6 | data/local/* -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ReplayModels 2 | 3 | Creates models for Rocket League replay analysis. 4 | 5 | ## Implementation 6 | ### General classes and modules 7 | Data retrieval is done through the `DataManager` subclasses (e.g. `CalculatedLocalDM`). 8 | These subclasses expose a `get_data()` method which retrieves a `GameData` object (with `.df` and `.proto` attributes). 9 | 10 | General utility functions such as filtering columns of the dataframe are available in `data/utils/utils.py`. 
11 | 12 | `data/utils/number_check.py` checks how many replays are available in calculated.gg's API for a certain query, 13 | for a given playlist and minimum MMR. 14 | 15 | ### `value_function` 16 | Use `batched_value_function.py`, which uses the refactored class `BatchTrainer`. 17 | Running it should cache replay dataframes and protos, and plot loss with quicktracer. 18 | 19 | ## data_main.py 20 | Run this script to either download replay files, convert them to CSV, or combine CSVs into a dataset. 21 | This relies on the `config.ini` file in `data/`. 22 | 23 | Steps to get a dataframe of replay data: 24 | 25 | Set up your config.ini file (which mode and MMR range you want to work with, and path options). 26 | 27 | In the command line with the necessary packages installed: 28 | 29 | python data_main.py (to see which arguments are available) 30 | 31 | python data_main.py download [args] 32 | python data_main.py convert [args] 33 | python data_main.py dataset [args] 34 | 35 | You now have a .h5 file that pandas can load into a dataframe.
36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import gzip 3 | import io 4 | import os 5 | import pandas as pd 6 | 7 | import carball 8 | import requests 9 | from carball.analysis.analysis_manager import PandasManager, AnalysisManager 10 | from carball.analysis.utils.proto_manager import ProtobufManager 11 | from carball.generated.api.game_pb2 import Game 12 | 13 | BASE_URL = 'https://calculated.gg/api/v1/' 14 | 15 | 16 | class DataManager: 17 | 18 | def get_replay_list(self, num: int = 50): 19 | raise NotImplementedError() 20 | 21 | def get_pandas(self, id_: str): 22 | raise NotImplementedError() 23 | 24 | def get_proto(self, id_: str): 25 | raise NotImplementedError() 26 | 27 | 28 | class Calculated(DataManager): 29 | PANDAS_MAP = {} 30 | PROTO_MAP = {} 31 | BROKEN = [] 32 | 33 | def get_replay_list(self, num=50, page=1): 34 | r = requests.get(BASE_URL + 'replays?key=1&minmmr=1300&maxmmr=1400&playlist=13&num={}&page={}'.format(num, page)) 35 | return [replay['hash'] for replay in r.json()['data']] 36 | 37 | def get_pandas(self, id_): 38 | if id_ in self.BROKEN: 39 | return None 40 | if id_ in self.PANDAS_MAP: 41 | return self.PANDAS_MAP[id_] 42 | url = BASE_URL + 'parsed/{}.replay.gzip?key=1'.format(id_) 43 | r = requests.get(url) 44 | gzip_file = gzip.GzipFile(fileobj=io.BytesIO(r.content), mode='rb') 45 | try: 46 | pandas_ = PandasManager.safe_read_pandas_to_memory(gzip_file) 47 | except: 48 | self.PANDAS_MAP[id_] = None 49 | self.BROKEN.append(id_) 50 | return None 51 | self.PANDAS_MAP[id_] = pandas_ 52 | return pandas_ 53 | 54 | def get_proto(self, id_): 55 | if id_ in self.PROTO_MAP: 56 | return self.PROTO_MAP[id_] 57 | url = BASE_URL + 'parsed/{}.replay.pts?key=1'.format(id_) 58 | r = requests.get(url) 59 | # file_obj = io.BytesIO() 60 | # for chunk in r.iter_content(chunk_size=1024): 61 | # if 
chunk: # filter out keep-alive new chunks 62 | # file_obj.write(chunk) 63 | proto = ProtobufManager.read_proto_out_from_file(io.BytesIO(r.content)) 64 | self.PROTO_MAP[id_] = proto 65 | return proto 66 | 67 | 68 | class Carball(DataManager): 69 | REPLAYS_DIR = 'replays' 70 | REPLAYS_MAP = {} 71 | 72 | def get_replay_list(self, num=50): 73 | replays = glob.glob(os.path.join(self.REPLAYS_DIR, '*.replay')) 74 | return [os.path.basename(replay).split('.')[0] for replay in replays] 75 | 76 | def get_pandas(self, id_) -> pd.DataFrame: 77 | return self._process(id_).data_frame 78 | 79 | def get_proto(self, id_) -> Game: 80 | return self._process(id_).protobuf_game 81 | 82 | def _process(self, id_) -> AnalysisManager: 83 | if id_ in self.REPLAYS_MAP: 84 | return self.REPLAYS_MAP[id_] 85 | path = os.path.join(self.REPLAYS_DIR, id_ + '.replay') 86 | manager = carball.analyze_replay_file(path, "replay.json") 87 | self.REPLAYS_MAP[id_] = manager 88 | return manager 89 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SaltieRL/ReplayModels/62fb8765e8efc44c4865ee9fe121de35db7528cb/data/__init__.py -------------------------------------------------------------------------------- /data/config.ini: -------------------------------------------------------------------------------- 1 | [MODES] 2 | #These are the game modes you can work with 3 | #Select one of them in [VARS] (probably don't want to edit them here) 4 | UNRANKED_DUELS = 1,1 5 | UNRANKED_DOUBLES = 2,2 6 | UNRANKED_STANDARD = 3,3 7 | RANKED_DUELS = 10,1 8 | RANKED_DOUBLES = 11,2 9 | RANKED_SOLO_STANDARD = 12,3 10 | RANKED_STANDARD = 13,3 11 | [VARS] 12 | #Edit these 13 | mode = ${MODES:RANKED_DOUBLES} 14 | #Name the mode. Probably just copy it from "mode". 15 | mode_str = RANKED_DOUBLES 16 | #What range of MMR do you want to pull from? 
17 | #This is used by the entire pipeline (download, convert, dataset) and isn't "smart" 18 | #i.e. if you download 1000-1100 and then 1100-1200, and then try to convert "1000-1200", it won't work. Sorry! 19 | mmr_range = 1000-1100 20 | [PATHVARS] 21 | #Edit these 22 | #Name the folder everything will go in 23 | main_path = ReplayDataProcessing 24 | #It works with different drives as well 25 | #main_path = D:\ReplayDataProcessing 26 | #The code mostly assumes PATHVARS are unique and hold one type of file. (Except dpath can have .csv, .h5, anything) 27 | #Sub-paths for different types of files. 28 | rpath = Replays 29 | cpath = CSVs 30 | tcpath = TestCSVs 31 | dpath = Datasets 32 | jpath = Jsons 33 | [PATHS] 34 | #I kept this loose so that it's easier to change your directory structure 35 | #Probably don't edit this. 36 | vars= ${VARS:mode} 37 | #PATHVARS for files we are skipping for a few reasons 38 | error_path = ${PATHVARS:main_path}/${VARS:mode_str}/${PATHVARS:rpath}/${VARS:mmr_range}-ERROR/ 39 | skip_path = ${PATHVARS:main_path}/${VARS:mode_str}/${PATHVARS:rpath}/${VARS:mmr_range}-SKIPPED/ 40 | #PATHVARS specific to the number of players 41 | replay_path = ${PATHVARS:main_path}/${VARS:mode_str}/${PATHVARS:rpath}/${VARS:mmr_range}/ 42 | replay_log = ${PATHVARS:main_path}/${VARS:mode_str}/${PATHVARS:rpath}/${VARS:mmr_range}_log.csv 43 | json_path = ${PATHVARS:main_path}/${VARS:mode_str}/${PATHVARS:jpath}/${VARS:mmr_range}/ 44 | csv_path = ${PATHVARS:main_path}/${VARS:mode_str}/${PATHVARS:cpath}/${VARS:mmr_range}/ 45 | dataset_path = ${PATHVARS:main_path}/${VARS:mode_str}/${PATHVARS:dpath}/${VARS:mmr_range}/ 46 | testcsv_path = ${PATHVARS:main_path}/${VARS:mode_str}/${PATHVARS:tcpath}/${VARS:mmr_range}/ 47 | [CSV] 48 | #Only change these if you change replay conversion such that there are more/fewer columns 49 | #TODO: Programmatically edit these during replay conversion?
50 | #Used by dataset creation 51 | columns_per_player = 18 52 | game_columns = 17 -------------------------------------------------------------------------------- /data/convert_replays.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | import os 3 | import random 4 | import sys 5 | import time 6 | from configparser import ConfigParser, ExtendedInterpolation 7 | from datetime import datetime, timedelta 8 | from multiprocessing import Process, Value, current_process 9 | from shutil import copyfile 10 | from typing import List 11 | 12 | import carball 13 | import numpy as np 14 | import pandas as pd 15 | from carball.rattletrap.run_rattletrap import RattleTrapException 16 | 17 | # Getting config 18 | config = ConfigParser(interpolation=ExtendedInterpolation()) 19 | config.read('data/config.ini') 20 | mode = config['VARS']['MODE'].split(',') 21 | mmrs = config['VARS']['mmr_range'].split('-') 22 | NUM_PLAYERS = int(mode[1]) 23 | # Paths 24 | paths = config['PATHS'] 25 | # Paths for files we are skipping for various reasons 26 | error_path = paths['error_path'] 27 | skip_path = paths['skip_path'] 28 | # PATHVARS specific to the number of players 29 | replay_path = paths['replay_path'] 30 | json_path = paths['json_path'] 31 | csv_path = paths['csv_path'] 32 | dataset_path = paths['dataset_path'] 33 | testcsv_path = paths['testcsv_path'] 34 | 35 | 36 | # HELPERS 37 | # STOPPING RATTLETRAPS PRINTS (https://stackoverflow.com/questions/8391411/suppress-calls-to-print-python) 38 | class HiddenPrints: 39 | """ 40 | Context manager that prevents printing to stdout or stderr. 
41 | """ 42 | 43 | def __enter__(self): 44 | self._original_stdout = sys.stdout 45 | self._original_stderr = sys.stderr 46 | sys.stdout = open(os.devnull, 'w') 47 | sys.stderr = open(os.devnull, 'w') 48 | 49 | def __exit__(self, exc_type, exc_val, exc_tb): 50 | sys.stdout.close() 51 | sys.stderr.close() 52 | sys.stdout = self._original_stdout 53 | sys.stderr = self._original_stderr 54 | 55 | 56 | """ with HiddenPrints(): 57 | print("This will not be printed") 58 | ... 59 | print("This will be printed as before")""" 60 | 61 | 62 | # Helper: Ordering the columns of the df 63 | def get_ordered_columns(players_per_team: int) -> List[str]: 64 | """ 65 | Return an ordered list of column names to be passed to a game dataframe. 66 | :param players_per_team: Determines how many player columns to return. 67 | :type players_per_team: int 68 | :return: A list of column names. 69 | :rtype: List[str] 70 | """ 71 | x = players_per_team 72 | non_player = ['ball_pos_x', 'ball_pos_y', 'ball_pos_z', 'ball_rot_x', 'ball_rot_y', 'ball_rot_z', 73 | 'ball_vel_x', 'ball_vel_y', 'ball_vel_z', 'ball_ang_vel_x', 'ball_ang_vel_y', 'ball_ang_vel_z', 74 | 'game_seconds_remaining', 'game_goal_number'] 75 | z_zero = ['z_0_pos_x', 'z_0_pos_y', 'z_0_pos_z', 'z_0_rot_x', 'z_0_rot_y', 'z_0_rot_z', 76 | 'z_0_vel_x', 'z_0_vel_y', 'z_0_vel_z', 'z_0_ang_vel_x', 'z_0_ang_vel_y', 'z_0_ang_vel_z', 77 | 'z_0_boost', 'z_0_boost_active', 'z_0_jump_active', 'z_0_double_jump_active', 'z_0_dodge_active', 78 | 'z_0_is_demo'] 79 | z_one = [] 80 | z_two = [] 81 | o_zero = [] 82 | o_one = [] 83 | o_two = [] 84 | for col in z_zero: 85 | if x > 1: 86 | z_one.append(col.replace('0', '1', 1)) 87 | if x > 2: 88 | z_two.append(col.replace('0', '2', 1)) 89 | o_zero.append(col.replace('z', 'o', 1)) 90 | if x > 1: 91 | for col in o_zero: 92 | o_one.append(col.replace('0', '1', 1)) 93 | if x > 2: 94 | o_two.append(col.replace('0', '2', 1)) 95 | 96 | columns_ordered = z_zero + z_one + z_two + o_zero + o_one + o_two + non_player 
97 | return columns_ordered 98 | 99 | 100 | def shrink_df(input_df): 101 | """ 102 | Shrink a dataframes size in memory while minimizing data loss. 103 | :param input_df: The input game dataframe 104 | :type input_df: pd.DataFrame 105 | :return: A copy of the dataframe, with column types changed. 106 | :rtype: pd.DataFrame 107 | """ 108 | df = input_df.copy(deep=True) 109 | sub_dict = {"pos": [1, np.int16], "rot": [1000, np.int16], "vel": [1, np.int16], "boost": [1, np.uint8], 110 | "active": [1, np.int8], "next": [1, np.int8], "is": [1, np.int8], 111 | "sec": [1, np.int16], "score": [1, np.int8]} 112 | for sub, ops in sub_dict.items(): 113 | cols = [col for col in df.columns if sub in col] 114 | for col in cols: 115 | df[col] = (df[col] * ops[0]).astype(ops[1]) 116 | return df 117 | 118 | 119 | def restructure_and_get_goals(proto_game, gdf: pd.DataFrame): 120 | """ 121 | Rename and reorder gdf columns and create structures to continue converting the gdf. 122 | :param proto_game: protobuf data from analysis object 123 | :type proto_game: proto 124 | :param gdf: A dataframe created by an analysis object 125 | :type gdf: pd.Dataframe 126 | :return: Lists holding goal-related data, and the modified dataframe. 127 | :rtype: Union[List[int], List[int], List[str], List[int], pd.Dataframe] 128 | """ 129 | # Init structures from json 130 | goal_seconds, goal_frames, goal_scorers, goal_teams = [], [], [], [] 131 | # Player Data 132 | # ~1 in 3 games have the blue team as team "1" instead of "0". We have to fix the team order in these cases. 
133 | team_names = [[], []] 134 | game = proto_game.game_metadata 135 | if proto_game.teams[1].is_orange: 136 | team_index = [0, 1] 137 | else: # Reverse the indexing 138 | team_index = [1, 0] 139 | 140 | for player in proto_game.players: 141 | if player.is_orange: 142 | team_names[team_index[1]].append(player.name) 143 | else: 144 | team_names[team_index[0]].append(player.name) 145 | 146 | for goal in game.goals: 147 | goal_frames.append(goal.frame_number) 148 | goal_scorers.append(goal.player_id) # Currently unused 149 | if goal.player_id in proto_game.teams[0].player_ids: 150 | goal_teams.append(team_index[0]) 151 | elif goal.player_id in proto_game.teams[1].player_ids: 152 | goal_teams.append(team_index[1]) 153 | # Dictionary for rename 154 | rename_dict = {} 155 | for tup in gdf.columns: 156 | # Need to change all the player values and make them in order 157 | if tup[0] in team_names[0]: 158 | i = team_names[0].index(tup[0]) 159 | sub = 'z_' + str(i) + '_' + tup[1] 160 | 161 | elif tup[0] in team_names[1]: 162 | i = team_names[1].index(tup[0]) 163 | sub = 'o_' + str(i) + '_' + tup[1] 164 | 165 | else: 166 | sub = tup[0] + '_' + tup[1] 167 | 168 | rename_dict[tup] = sub 169 | 170 | gdf = gdf.rename(rename_dict, axis='columns') 171 | # Add demo columns 172 | for team in ['z_', 'o_']: 173 | for i in range(NUM_PLAYERS): 174 | gdf[team + str(i) + '_is_demo'] = np.zeros(len(gdf)) 175 | 176 | return goal_seconds, goal_frames, goal_scorers, goal_teams, gdf 177 | 178 | 179 | def add_game_columns(gdf: pd.DataFrame, goal_frames: List[int], goal_seconds: List[int], goal_teams: List[int]) -> int: 180 | """ 181 | Mutate the game dataframe, adding columns for goals predictions, and score, and time until goal. 
182 | :param gdf: The game dataframe from the function "replays_to_csv" 183 | :type gdf: pd.Dataframe 184 | :param goal_frames: 185 | :type goal_frames: List[int] 186 | :param goal_seconds: 187 | :type goal_seconds: List[int] 188 | :param goal_teams: 189 | :type goal_teams: List[int] 190 | :return: The length of the added columns. Used to truncate the gdf. 191 | :rtype: int 192 | """ 193 | goal_one_column = np.empty([0]) 194 | score_0 = np.empty([0]) 195 | score_1 = np.empty([0]) 196 | until_goal = np.array([300]) 197 | score = [0, 0] 198 | index = 0 199 | for i in range(len(goal_frames)): 200 | if i == 0: 201 | min_index = 0 202 | else: 203 | min_index = goal_frames[i - 1] 204 | # Get the length of the slice of the game between goals 205 | length = goal_frames[i] - min_index 206 | index += length 207 | # Make arrays to be added to columns 208 | l_1 = np.full(length, goal_teams[i]) 209 | s_0 = np.full(length, score[0]) 210 | s_1 = np.full(length, score[1]) 211 | 212 | until_seconds = np.full(length, goal_seconds[i]) 213 | # Get a slice from game_seconds_remaining and get it as an array 214 | arr_secs_remaining = gdf['game_seconds_remaining'][min_index:goal_frames[i]].to_numpy(copy=True) 215 | # Concat until we have our full columns 216 | goal_one_column = np.concatenate((goal_one_column, l_1)) 217 | score_0 = np.concatenate((score_0, s_0)) 218 | score_1 = np.concatenate((score_1, s_1)) 219 | # TODO: Fix how overtime causes negative values (multiple by -ot?) 
220 | until_goal = np.concatenate((until_goal, arr_secs_remaining - until_seconds)) 221 | # Update score 222 | score[0] += 1 - goal_teams[i] 223 | score[1] += goal_teams[i] 224 | # Convert to series so it can be added to the df 225 | goal_one_column = pd.Series(goal_one_column) 226 | score_0 = pd.Series(score_0) 227 | score_1 = pd.Series(score_1) 228 | until_goal = pd.Series(until_goal) 229 | # Add columns 230 | gdf['secs_to_goal'] = until_goal 231 | gdf['next_goal_one'] = goal_one_column 232 | gdf['score_zero'] = score_0 233 | gdf['score_one'] = score_1 234 | 235 | return len(goal_one_column) 236 | 237 | 238 | # TODO: actually use logging instead of prints 239 | def reporting(shared, interval_mins): 240 | """ 241 | Multiprocessing function called by pre_process_parallel. Reports the status of the other processes and the work. 242 | :param shared: A list of multiprocessing values to track errors. 243 | :type shared: List[Value] 244 | :param interval_mins: The interval on which to report. 245 | :type interval_mins: int 246 | :return: None 247 | :rtype: None 248 | """ 249 | try: 250 | interval = timedelta(minutes=interval_mins) 251 | in_len = len(os.listdir(replay_path)) 252 | out_len = len(os.listdir(csv_path)) 253 | test_len = len(os.listdir(testcsv_path)) 254 | err_len = len(os.listdir(error_path)) 255 | skip_len = len(os.listdir(skip_path)) 256 | remaining = in_len - out_len - test_len - err_len - skip_len 257 | print( 258 | "Starting Reporting\n There are {} total replays.\n There are {} replays left to process. ({}%)\n".format( 259 | in_len, remaining, round(remaining / in_len, 2) * 100)) 260 | sys.stdout.flush() 261 | # Un-fancy way to ensure the first print is accurate, wait for other processes to start. Fancy method is not much better and more lines. 
262 | time.sleep(10) 263 | average = [timedelta(0), 0] 264 | errors = shared[0].value + shared[1].value + shared[2].value + shared[3].value + shared[4].value + shared[ 265 | 5].value 266 | start = datetime.now() 267 | while remaining > 0: 268 | try: 269 | with shared[6].get_lock(): 270 | if shared[6].value == 0: 271 | print("No remaining processes") 272 | return 273 | while (datetime.now() - start) < interval: 274 | time.sleep(interval_mins) 275 | elapsed = datetime.now() - start 276 | start = datetime.now() 277 | 278 | new_errors = shared[0].value + shared[1].value + shared[2].value + \ 279 | shared[3].value + shared[4].value + shared[5].value 280 | 281 | processed = (len(os.listdir(csv_path)) - out_len) + (len(os.listdir(testcsv_path)) - test_len) + ( 282 | new_errors - errors) 283 | 284 | out_len = len(os.listdir(csv_path)) 285 | test_len = len(os.listdir(testcsv_path)) 286 | err_len = len(os.listdir(error_path)) 287 | skip_len = len(os.listdir(skip_path)) 288 | 289 | remaining = in_len - out_len - test_len - err_len - skip_len 290 | 291 | errors = new_errors 292 | if processed == 0: 293 | print("Empty Reporting Interval") 294 | continue 295 | average[0] += elapsed 296 | average[1] += processed 297 | 298 | print("Total errors:{} ({})%".format(errors, round(errors / in_len, 2))) 299 | print( 300 | "Processed {} files.\nThe average time per file is {}".format(processed, (average[0] / average[1]))) 301 | print( 302 | "There are {} replays left to process. 
({}%)".format(remaining, round(remaining / in_len, 2) * 100)) 303 | print("Total errors:{} ({})%".format(errors, round(errors / in_len, 2))) 304 | print("Estimated completion in {}\n\n".format(average[0] / average[1] * remaining)) 305 | sys.stdout.flush() 306 | 307 | except KeyboardInterrupt: 308 | print("Exiting Reporting") 309 | break 310 | 311 | except KeyboardInterrupt: 312 | print("Exiting Reporting") 313 | 314 | print( 315 | "err_analysis_index: {}\nerr_analysis_key: {}\nerr_analysis_rattletrap: {}\nerr_analysis_unbound: {}\nerr_analysis_other: {}\nerr_gdf_index: {}\n".format( 316 | shared[0].value, shared[1].value, shared[2].value, shared[3].value, shared[4].value, shared[5].value)) 317 | return 318 | 319 | 320 | def replays_to_csv(in_files: List[str], output_path: str, shared: List[Value]): 321 | """ 322 | Multiprocessing function called by pre_process_parallel. Converts replay files to csv files. 323 | :param in_files: A list of files to process. 324 | :type in_files: List[str] 325 | :param output_path: The path to save CSVs 326 | :type output_path: str 327 | :param shared: A list of multiprocessing values to track errors. 
328 | :type shared: List[Value] 329 | :return: None 330 | :rtype: None 331 | """ 332 | try: 333 | ordered_cols = get_ordered_columns(NUM_PLAYERS) 334 | print("Starting {} with {} files".format(current_process().name, len(in_files))) 335 | file_average = [timedelta(), 0] 336 | for file in in_files: 337 | file_start = datetime.now() 338 | try: 339 | with HiddenPrints(): 340 | try: 341 | e = False 342 | # Analysis has a lot of possible errors 343 | analysis = carball.analyze_replay_file(replay_path + file) 344 | except IndexError: 345 | with shared[0].get_lock(): 346 | shared[0].value += 1 347 | e = True 348 | continue 349 | except KeyError: 350 | with shared[1].get_lock(): 351 | shared[1].value += 1 352 | e = True 353 | continue 354 | except RattleTrapException: # Currently can't figure out how to except "carball.rattletrap.run_rattletrap.RattleTrapException" 355 | with shared[2].get_lock(): 356 | shared[2].value += 1 357 | e = True 358 | continue 359 | except UnboundLocalError: 360 | with shared[3].get_lock(): 361 | shared[3].value += 1 362 | e = True 363 | continue 364 | except KeyboardInterrupt: 365 | print("Exiting") 366 | break 367 | except(): 368 | with shared[4].get_lock(): 369 | shared[4].value += 1 370 | e = True 371 | continue 372 | finally: 373 | if e: 374 | copyfile(replay_path + file, error_path + file) 375 | 376 | proto_game = analysis.get_protobuf_data() 377 | gdf = analysis.get_data_frame() 378 | # Check if game has overtime, and skip it if it has extra columns (very rare, but can happen) 379 | gdf.columns = gdf.columns.to_flat_index() 380 | game_has_overtime = False 381 | if len(gdf.columns) != 65: 382 | if len(gdf.columns) == 66: 383 | if ('game', 'is_overtime') in gdf.columns: 384 | game_has_overtime = True 385 | else: 386 | copyfile(replay_path + file, skip_path + file) 387 | continue 388 | 389 | goal_seconds, goal_frames, goal_scorers, goal_teams, gdf = restructure_and_get_goals(proto_game, gdf) 390 | 391 | # Skip if no goals or if game is too 
short 392 | # Take this out or change it depending on how much you care about dataset quality 393 | if 210 not in gdf['game_seconds_remaining']: 394 | copyfile(replay_path + file, skip_path + file) 395 | continue 396 | 397 | if len(goal_frames) == 0: 398 | copyfile(replay_path + file, skip_path + file) 399 | assert (len(proto_game.game_metadata.goals) == 0) 400 | continue 401 | # Order the columns of the df 402 | gdf = gdf[ordered_cols] 403 | # Get times of each goal 404 | try: 405 | for i in goal_frames: 406 | goal_seconds.append(gdf.iloc[i]['game_seconds_remaining']) 407 | except IndexError: 408 | with shared[5].get_lock(): 409 | shared[5].value += 1 410 | copyfile(replay_path + file, error_path + file) 411 | print('Index Error') 412 | continue 413 | # Skip overtimes. 414 | if game_has_overtime: 415 | if len(goal_seconds) == 1: 416 | copyfile(replay_path + file, skip_path + file) 417 | continue 418 | else: 419 | goal_seconds = goal_seconds[:-1] 420 | goal_frames = goal_frames[:-1] 421 | 422 | trunc_length = add_game_columns(gdf, goal_frames, goal_seconds, goal_teams) 423 | 424 | # Fixing up missing values and rows after the last goal 425 | gdf = gdf.truncate(after=trunc_length - 1) 426 | # Drop when all player values are NA 427 | sub = ['z_0_pos_x', 'o_0_pos_x', 'z_1_pos_x', 'o_1_pos_x', 'z_2_pos_x', 'o_2_pos_x'] 428 | gdf = gdf.dropna(how='all', subset=sub[:NUM_PLAYERS*2]) 429 | # forward fill demos (Single player NA), then fill empty values (Ball values) 430 | for team in ['z_', 'o_']: 431 | for i in range(NUM_PLAYERS): 432 | num = str(i) + '_' 433 | gen_list = ['pos_x', 'pos_y', 'pos_z', 'rot_x', 'rot_y', 'rot_z', 'vel_x', 'vel_y', 'vel_z', 434 | 'ang_vel_x', 'ang_vel_y', 435 | 'ang_vel_z', 'boost_active', 'jump_active', 'double_jump_active', 'dodge_active'] 436 | fill_list = [team + num + entry for entry in gen_list] 437 | # Change demo column using presence of NA values 438 | gdf[team + num + 'is_demo'] = gdf[fill_list].isna().replace({True: 1, False: 
0}).mean(axis=1) 439 | # Turn NA values into value before demo 440 | for _ in fill_list: 441 | gdf.loc[:, fill_list] = gdf.loc[:, fill_list].ffill(axis=0) 442 | 443 | # Drop the time after a goal is scored but before reset (in these casesm game_goal_number is N/A) 444 | gdf = gdf.dropna(axis='index', subset=['game_goal_number']) 445 | gdf = gdf.drop(['game_goal_number'], axis=1) 446 | gdf = gdf[gdf['secs_to_goal'] > 0] 447 | # Fill rest of NA value with 0 448 | gdf = gdf.fillna(0) 449 | # Change active values to boolean 450 | gdf['z_0_jump_active'] = ((gdf['z_0_jump_active'] % 2) != 0).astype(int) 451 | gdf['o_0_jump_active'] = ((gdf['o_0_jump_active'] % 2) != 0).astype(int) 452 | gdf['z_0_double_jump_active'] = ((gdf['z_0_double_jump_active'] % 2) != 0).astype(int) 453 | gdf['o_0_double_jump_active'] = ((gdf['o_0_double_jump_active'] % 2) != 0).astype(int) 454 | gdf['z_0_dodge_active'] = ((gdf['z_0_dodge_active'] % 2) != 0).astype(int) 455 | gdf['o_0_dodge_active'] = ((gdf['o_0_dodge_active'] % 2) != 0).astype(int) 456 | # Convert all booleans to 0 or 1 457 | gdf = gdf.replace({True: 1, False: 0}) 458 | # Reduce size in memory 459 | # Write out to CSV after shrinking 460 | gdf = shrink_df(gdf) 461 | gdf.to_csv(output_path + file.split('.')[0] + '.csv') 462 | file_average[1] += 1 463 | file_average[0] += (datetime.now() - file_start) 464 | sys.stdout.flush() 465 | except(KeyboardInterrupt, SystemExit): 466 | break 467 | except(KeyboardInterrupt, SystemExit): 468 | pass 469 | # Update number of processes so reporting doesn't idle 470 | with shared[6].get_lock(): 471 | shared[6].value -= 1 472 | print("{} Exiting".format(current_process().name)) 473 | return 474 | 475 | 476 | def pre_process_parallel(num_processes, test_ratio=.1, overwrite=False, verbose_interval=10): 477 | """ 478 | Execute a number of processes to pre-process replay files to CSVs, as designated by the config. 479 | :param num_processes: The number of processes to run. 
480 | :type num_processes: int 481 | :param test_ratio: The ratio of replays to use as a test set 482 | :type test_ratio: float 483 | :param overwrite: Whether to overwrite existing CSVs 484 | :type overwrite: bool 485 | :param verbose_interval: How often to print information (Minutes) 486 | :type verbose_interval: float 487 | :return: None 488 | :rtype: None 489 | """ 490 | # check if num_processes is reasonable 491 | if num_processes < 1: 492 | print("Processes must be at least 1") 493 | return 494 | if num_processes > mp.cpu_count(): 495 | print("Running more processes than cpu_count") 496 | # prepare paths 497 | for p in [csv_path, testcsv_path, error_path, skip_path]: 498 | if not os.path.exists(p): 499 | os.makedirs(p) 500 | print("Created directories in {}".format(p)) 501 | 502 | # Get file names 503 | in_files = os.listdir(replay_path) 504 | out_files = os.listdir(csv_path) 505 | out_test = os.listdir(testcsv_path) 506 | err_files = os.listdir(error_path) 507 | skip_files = os.listdir(skip_path) 508 | # Tracking errors and handling separate test output 509 | total = len(in_files) 510 | # TODO: Restructure these loops into 1? 
511 | # Skip existing CSVs unless we are overwriting, and count extraneous CSVs 512 | extraneous = 0 513 | if not overwrite: 514 | for file in out_files: 515 | if file.split('.')[0] + '.replay' in in_files: 516 | in_files.remove(file.split('.')[0] + '.replay') 517 | else: 518 | extraneous += 1 519 | for file in out_test: 520 | if file.split('.')[0] + '.replay' in in_files: 521 | in_files.remove(file.split('.')[0] + '.replay') 522 | else: 523 | extraneous += 1 524 | # Skip error replays 525 | err_count = 0 526 | for file in err_files + skip_files: 527 | if file in in_files: 528 | in_files.remove(file) 529 | err_count += 1 530 | 531 | print("Skipping {} files recorded as causing errors or meeting criteria to be skipped.".format(err_count)) 532 | print("There are {} existing output CSV's that don't correspond with a replay file.".format(extraneous)) 533 | if len(in_files) == 0: 534 | print("No replays left to process") 535 | return 536 | # Remove duplicate CSV's in test and out (from out) 537 | duplicates = 0 538 | for file in out_test: 539 | if file in out_files: 540 | duplicates += 1 541 | os.remove(csv_path + file) 542 | 543 | # Tracking total counts of things 544 | err_analysis_index = Value('I', 0) 545 | err_analysis_key = Value('I', 0) 546 | err_analysis_rattletrap = Value('I', 0) 547 | err_analysis_unbound = Value('I', 0) 548 | err_analysis_other = Value('I', 0) 549 | err_gdf_index = Value('I', 0) 550 | running_processes = Value('I', 0) 551 | 552 | shared = [err_analysis_index, err_analysis_key, err_analysis_rattletrap, 553 | err_analysis_unbound, err_analysis_other, err_gdf_index, running_processes] 554 | processes: List[Process] = [] 555 | if verbose_interval > 0: 556 | reporter = Process(target=reporting, args=(shared, verbose_interval)) 557 | # Randomly get test set 558 | random.shuffle(in_files) 559 | if test_ratio is not 0: 560 | total_test = int(test_ratio * total) 561 | num_test = total_test - len(out_test) 562 | if num_test > len(in_files): 563 | 
print( 564 | "There are not enough input files to produce the desired test ratio. This can happen if overwrite = False and we have already partially processed these folders with a smaller ratio.") 565 | return 566 | if num_test > 0: 567 | test_files = in_files[:num_test] 568 | in_files = in_files[num_test:] 569 | processes.append(Process(target=replays_to_csv, args=(test_files, testcsv_path, shared))) 570 | # If only 1 process (why?), just do it all and return (Pretend you didn't see this code) 571 | num_processes -= 1 572 | if num_processes == 0: 573 | processes[0].start() 574 | processes[0].join() 575 | process = Process(target=replays_to_csv, args=(in_files, csv_path, shared)) 576 | process.start() 577 | process.join() 578 | return 579 | else: 580 | print("There are already more files in the test set than given by the ratio.") 581 | # TODO: Stop bottlenecking on the test set process 582 | 583 | # Split up the work and start processes 584 | workloads = np.array_split(in_files, num_processes) 585 | if verbose_interval > 0: 586 | reporter.start() 587 | for i in range(num_processes): 588 | processes.append(Process(target=replays_to_csv, args=(workloads[i], csv_path, shared))) 589 | for p in processes: 590 | p.start() 591 | with shared[6].get_lock(): 592 | shared[6].value += 1 593 | for p in processes: 594 | p.join() 595 | if verbose_interval > 0: 596 | reporter.join() 597 | return 598 | -------------------------------------------------------------------------------- /data/create_dataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import random 4 | from configparser import ConfigParser, ExtendedInterpolation 5 | from datetime import datetime 6 | import psutil 7 | import pandas as pd 8 | import numpy as np 9 | 10 | # Getting config 11 | config = ConfigParser(interpolation=ExtendedInterpolation()) 12 | config.read('data/config.ini') 13 | mode = config['VARS']['MODE'].split(',') 14 | mmrs = 
config['VARS']['mmr_range'].split('-') 15 | NUM_PLAYERS = int(mode[1]) 16 | # Paths 17 | paths = config['PATHS'] 18 | csv_path = paths['csv_path'] 19 | dataset_path = paths['dataset_path'] 20 | testcsv_path = paths['testcsv_path'] 21 | # CSV vars 22 | cols_per_player = int(config['CSV']['columns_per_player']) 23 | game_columns = int(config['CSV']['game_columns']) 24 | 25 | 26 | def name_dataset(size, test): 27 | """ 28 | Return a filename that reflects some context of the dataset creation. 29 | :param test: Whether the dataset is a test set. 30 | :type test: bool 31 | :param size: The number of games in the dataset. 32 | :type size: int 33 | :return: A filename. 34 | :rtype: str 35 | """ 36 | if test: 37 | t = '-test_set' 38 | else: 39 | t = '' 40 | pcols = cols_per_player 41 | gcols = game_columns 42 | return f"{size}_games-{pcols}_pcols-{gcols}_gcols{t}" 43 | 44 | 45 | # concat CSVs all together into a h5 dataset 46 | # This overwrites the output file right now 47 | def dataset(output_file=None, test=False, max_games=None, ram_max=75): 48 | """ 49 | Concatenate CSVs designated by the config into a dataset. 50 | :param ram_max: When RAM usage hits this percentage, start a new chunk. Actual RAM usage at this point will only be half the max. 51 | :type ram_max: int 52 | :param output_file: An output filename. Generated automatically if None. 53 | :type output_file: Optional[str] 54 | :param test: If the dataset will be a test set. 55 | :type test: bool 56 | :param max_games: Maximum number of games to put into the dataset. 57 | :type max_games: int 58 | :return: None 59 | :rtype: None 60 | """ 61 | ''' 62 | On chunking: The downside is that the shuffling is per chunk (frame level shuffling, not game shuffling), 63 | so with really small chunks the dataset won't be shuffled well. 64 | You'd think this would be fine since we can shuffle during training, but if you have RAM problems (4-16GB) that may not be the case. 
65 | It's probably better to have a good shuffle (bigger chunks) now and then be able to train on smaller chunks later. 66 | ''' 67 | # TODO: make ram_max smarter 68 | 69 | # Total size of team set 70 | if not os.path.exists(dataset_path): 71 | os.makedirs(dataset_path) 72 | print("Created directories in {}".format(dataset_path)) 73 | if test and not os.path.exists(testcsv_path): 74 | print("Test is true but there is no testcsv_path") 75 | return 76 | if not os.path.exists(csv_path): 77 | print("There is no csv_path") 78 | return 79 | 80 | # Use testcsv_path for test data and csv_path for training data 81 | if test: 82 | input_glob = testcsv_path + '*.csv' 83 | if output_file is not None: 84 | output_file += '-test_set' 85 | else: 86 | input_glob = csv_path + '*.csv' 87 | dfs = glob.glob(input_glob) 88 | if (max_games is not None) and (max_games < len(dfs)): 89 | dfs = dfs[:max_games] 90 | if output_file is None: 91 | size = len(dfs) 92 | output_file = name_dataset(size, test) 93 | 94 | random.shuffle(dfs) 95 | new = True 96 | append_list = [] 97 | chunk_time = datetime.now() 98 | print("Starting") 99 | for df in dfs: 100 | game = pd.read_csv(df) 101 | game = game.drop('Unnamed: 0', axis=1) # index_col arg is often problematic, this is an extra line but always works. 102 | append_list.append(game) 103 | # Repeat until done with chunk or out of df's. 
104 | if psutil.virtual_memory().percent < ram_max and (df != dfs[-1]): 105 | continue 106 | else: 107 | print(psutil.virtual_memory()) 108 | 109 | # Get here once every chunk 110 | result = pd.concat(append_list) 111 | append_list = [] 112 | print("Sample") 113 | result = result.sample(frac=1) 114 | print("Write") 115 | if new: 116 | result.to_hdf(dataset_path + output_file + '.h5', 117 | 'data', mode='w', format='table', data_columns=['secs_to_goal', 'next_goal_one']) 118 | new = False 119 | else: 120 | result.to_hdf(dataset_path + output_file + '.h5', 121 | 'data', mode='a', format='table', append=True, data_columns=['secs_to_goal', 'next_goal_one']) 122 | print("--- chunk time: {} ---".format(datetime.now() - chunk_time)) 123 | print("processed {} out of {} files".format(dfs.index(df) + 1, len(dfs))) 124 | chunk_time = datetime.now() 125 | del result 126 | 127 | return 128 | -------------------------------------------------------------------------------- /data/download_replays.py: -------------------------------------------------------------------------------- 1 | from data.interfacers.calculatedgg_api.api_interfacer import CalculatedApiInterfacer 2 | from data.interfacers.calculatedgg_api.query_params import CalculatedApiQueryParams 3 | import requests 4 | import os 5 | from configparser import ConfigParser, ExtendedInterpolation 6 | import pandas as pd 7 | 8 | config = ConfigParser(interpolation=ExtendedInterpolation()) 9 | config.read('data/config.ini') 10 | mode_tuple = config['VARS']['MODE'].split(',') 11 | mmrs = config['VARS']['mmr_range'].split('-') 12 | paths = config['PATHS'] 13 | 14 | 15 | def download_replays_range(playlist_tuple=mode_tuple, min_mmr=mmrs[0], max_mmr=mmrs[1], 16 | replay_path=paths['replay_path'], 17 | replay_log=paths['replay_log'], max_downloaded=None): 18 | """ 19 | Download replays from calculated.gg to the replay folder. Uses the config. 
20 | :param replay_log: Path to logging csv 21 | :type replay_log: str 22 | :param playlist_tuple: Which playlist to draw from. 23 | :type playlist_tuple: tuple(int,int) 24 | :param min_mmr: 25 | :type min_mmr: int 26 | :param max_mmr: 27 | :type max_mmr: int 28 | :param replay_path: Path to replays folder 29 | :type replay_path: str 30 | :param max_downloaded: Maximum replays to download. 31 | :type max_downloaded: Optional[int] 32 | :return: None 33 | :rtype: None 34 | """ 35 | if not os.path.exists(replay_path): 36 | os.makedirs(replay_path) 37 | print("Created directories in replay_path") 38 | mode = int(playlist_tuple[0]) 39 | num_players = int(playlist_tuple[1]) 40 | data_cols = ['hash', 'download', 'map', 'match_date', 'upload_date', 'team_blue_score', 'team_orange_score'] + [ 41 | f'p_{x}' for x in range(num_players * 2)] + [f'mmr_{x}' for x in range(num_players * 2)] 42 | if not os.path.exists(replay_log): 43 | log = pd.DataFrame(columns=data_cols) 44 | log.to_csv(replay_log) 45 | print("Created replay log") 46 | 47 | params = CalculatedApiQueryParams(1, 200, mode, min_mmr, max_mmr) 48 | interfacer = CalculatedApiInterfacer(params) 49 | if max_downloaded is not None: 50 | pages = (max_downloaded // 200) + 1 51 | last_page_len = max_downloaded % 200 52 | else: 53 | pages = 100 54 | last_page_len = 0 55 | existing = os.listdir(replay_path) 56 | log = pd.read_csv(replay_log, index_col=0) 57 | logged = log['hash'].values 58 | old = 0 59 | new = 0 60 | logs = [] 61 | 62 | for page in range(pages): 63 | if page == 0: 64 | page_req = interfacer._get_replays_request(params).json() 65 | print(f"Matched {page_req['total_count']} replays") 66 | page_data = page_req['data'] 67 | else: 68 | page_data = interfacer._get_replays_request(params).json()['data'] 69 | if page == pages - 1: 70 | page_data = page_data[:last_page_len] 71 | if len(page_data) == 0: 72 | break 73 | for game in page_data: 74 | # Skip existing 75 | name = game['hash'] 76 | if name + '.replay' not in 
existing: 77 | # Write replay file 78 | open(replay_path + name + '.replay', 'wb').write( 79 | requests.get(f"https://calculated.gg/api/replay/{name}/download").content) 80 | new += 1 81 | else: 82 | old += 1 83 | 84 | if name not in logged: 85 | # Logging 86 | log_row = {x: None for x in data_cols} 87 | for key in list(set(log_row.keys()) & set(game.keys())): 88 | log_row[key] = game[key] 89 | for i in range(len(game['players'])): 90 | log_row[f'p_{i}'] = game['players'][i] 91 | # MMRS has unknown length, missing MMRS will be 'None' 92 | for i in range(len(game['mmrs'])): 93 | log_row[f'mmr_{i}'] = game['mmrs'][i] 94 | 95 | logs.append(log_row) 96 | 97 | params = params._replace(page=params.page + 1) 98 | 99 | new_log = pd.DataFrame(logs, columns=data_cols) 100 | pd.concat([log, new_log]).to_csv(replay_log) 101 | print(f"Existing: {old}") 102 | print(f"Wrote new: {new}") 103 | print(f"Logged existing: {len(new_log) - new}") 104 | -------------------------------------------------------------------------------- /data/interfacers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SaltieRL/ReplayModels/62fb8765e8efc44c4865ee9fe121de35db7528cb/data/interfacers/__init__.py -------------------------------------------------------------------------------- /data/interfacers/base_interfacer.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | import pandas as pd 4 | from carball.generated.api.game_pb2 import Game 5 | 6 | 7 | class BaseInterfacer: 8 | """ 9 | Represents a set of replays for Sequences to interact with. 
10 | """ 11 | 12 | def get_total_count(self) -> int: 13 | raise NotImplementedError 14 | 15 | def get_all_replay_ids(self) -> Set[str]: 16 | raise NotImplementedError 17 | 18 | def get_replay_proto(self, replay_id: str) -> Game: 19 | raise NotImplementedError 20 | 21 | def get_replay_df(self, replay_id: str) -> pd.DataFrame: 22 | raise NotImplementedError 23 | 24 | def copy(self) -> 'BaseInterfacer': 25 | raise NotImplementedError 26 | 27 | -------------------------------------------------------------------------------- /data/interfacers/calculatedgg_api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SaltieRL/ReplayModels/62fb8765e8efc44c4865ee9fe121de35db7528cb/data/interfacers/calculatedgg_api/__init__.py -------------------------------------------------------------------------------- /data/interfacers/calculatedgg_api/api_interfacer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import io 3 | import logging 4 | import pandas as pd 5 | from typing import List, Set 6 | 7 | import requests 8 | from carball.analysis.utils.pandas_manager import PandasManager 9 | from carball.analysis.utils.proto_manager import ProtobufManager 10 | from carball.generated.api.game_pb2 import Game 11 | from requests import Response 12 | 13 | from data.interfacers.base_interfacer import BaseInterfacer 14 | from data.interfacers.calculatedgg_api.errors import BrokenDataFrameError 15 | from data.interfacers.calculatedgg_api.query_params import CalculatedApiQueryParams 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class CalculatedApiInterfacer(BaseInterfacer): 21 | BASE_URL = 'https://calculated.gg/api/v1/' 22 | PROTO_URL = 'https://storage.googleapis.com/calculatedgg-proto/' 23 | REPLAY_DF_URL = 'https://storage.googleapis.com/calculatedgg-parsed/' 24 | 25 | def __init__(self, query_params: CalculatedApiQueryParams = 
CalculatedApiQueryParams()): 26 | self.initial_query_params = query_params 27 | self.query_params = query_params.copy() 28 | 29 | def get_total_count(self) -> int: 30 | r = self._get_replays_request(self.initial_query_params) 31 | r_json = r.json() 32 | total_count = r_json['total_count'] 33 | logger.debug(f'Found a total of {total_count} replays.') 34 | return total_count 35 | 36 | def get_all_replay_ids(self) -> Set[str]: 37 | replay_ids = set() 38 | _query_params = self.initial_query_params._replace(page=1) 39 | while True: 40 | replay_ids_on_page = self.get_replay_list(_query_params) 41 | if len(replay_ids_on_page) == 0: 42 | break 43 | _query_params = _query_params._replace(page=_query_params.page + 1) 44 | replay_ids.update(replay_ids_on_page) 45 | logger.info(f'Found a total of {len(replay_ids)} unique replay ids.') 46 | return replay_ids 47 | 48 | def get_replay_list(self, query_params: CalculatedApiQueryParams = None) -> List[str]: 49 | query_params = query_params if query_params is not None else self.query_params 50 | r = self._get_replays_request(query_params) 51 | return [replay['hash'] for replay in r.json()['data']] 52 | 53 | @classmethod 54 | def _get_replays_request(cls, query_params: CalculatedApiQueryParams) -> Response: 55 | r = requests.get(cls.BASE_URL + 'replays', params=query_params._asdict()) 56 | r.raise_for_status() 57 | logger.debug(f'Performed request for {r.url}') 58 | return r 59 | 60 | def get_replay_proto(self, replay_id: str) -> Game: 61 | url = self.PROTO_URL + f'{replay_id}.replay.pts' 62 | r = requests.get(url) 63 | r.raise_for_status() 64 | file = io.BytesIO(r.content) 65 | proto = ProtobufManager.read_proto_out_from_file(file) 66 | logger.debug(f"Loaded {replay_id} proto from site.") 67 | return proto 68 | 69 | def get_replay_df(self, replay_id: str) -> pd.DataFrame: 70 | url = self.REPLAY_DF_URL + f'{replay_id}.replay.gzip' 71 | r = requests.get(url) 72 | r.raise_for_status() 73 | gzip_file = 
gzip.GzipFile(fileobj=io.BytesIO(r.content), mode='rb') 74 | df = PandasManager.safe_read_pandas_to_memory(gzip_file) 75 | if df is None: 76 | raise BrokenDataFrameError 77 | logger.debug(f"Loaded {replay_id} df from site.") 78 | return df 79 | 80 | def copy(self): 81 | return self.__class__(self.query_params.copy()) 82 | -------------------------------------------------------------------------------- /data/interfacers/calculatedgg_api/errors.py: -------------------------------------------------------------------------------- 1 | class BrokenDataFrameError(Exception): 2 | pass 3 | -------------------------------------------------------------------------------- /data/interfacers/calculatedgg_api/query_params.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | 3 | 4 | class CalculatedApiQueryParams(NamedTuple): 5 | page: int = 1 6 | num: int = 200 7 | playlist: int = None 8 | minmmr: int = None 9 | maxmmr: int = None 10 | mmrany: bool = None 11 | minrank: int = None 12 | maxrank: int = None 13 | rankany: bool = None 14 | start_timestamp: float = None 15 | end_timestamp: float = None 16 | key: str = 'PLACEHOLDER' 17 | 18 | def copy(self): 19 | return self.__class__(**self._asdict()) 20 | 21 | 22 | -------------------------------------------------------------------------------- /data/interfacers/calculatedgg_api/test_api_interfacer.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import pandas as pd 3 | import unittest 4 | 5 | from carball.generated.api.game_pb2 import Game 6 | from requests import HTTPError 7 | 8 | from data.interfacers.calculatedgg_api.api_interfacer import CalculatedApiInterfacer 9 | from data.interfacers.calculatedgg_api.query_params import CalculatedApiQueryParams 10 | 11 | 12 | class ApiInterfacerTest(unittest.TestCase): 13 | 14 | def setUp(self) -> None: 15 | self.interfacer = CalculatedApiInterfacer() 16 | 17 | def 
test_default_get_replay_list(self): 18 | replay_ids = self.interfacer.get_replay_list() 19 | self.assertGreater(len(replay_ids), 0) 20 | 21 | def test_get_total_count(self): 22 | total_count = self.interfacer.get_total_count() 23 | self.assertGreater(total_count, 0) 24 | 25 | def test_get_all_replay_ids(self): 26 | query_params = CalculatedApiQueryParams( 27 | playlist=13, 28 | minmmr=1800, 29 | start_timestamp=int(datetime.datetime(2019, 1, 1).timestamp()), 30 | end_timestamp=int(datetime.datetime(2019, 1, 5).timestamp()), 31 | ) 32 | interfacer = CalculatedApiInterfacer(query_params) 33 | total_count = interfacer.get_total_count() 34 | self.assertGreater(total_count, 0) 35 | all_replays = interfacer.get_all_replay_ids() 36 | self.assertEqual(total_count, len(all_replays)) 37 | 38 | def test_get_replay_data_from_id(self): 39 | replay_id = "96BDDDEE11E8B6D4396D1B9668244BC6" # actually exists 40 | proto: Game = self.interfacer.get_replay_proto(replay_id) 41 | df = self.interfacer.get_replay_df(replay_id) 42 | 43 | self.assertEqual(replay_id, proto.game_metadata.match_guid) 44 | self.assertIsInstance(df, pd.DataFrame) 45 | 46 | replay_id = "MADE_UP_THING" # Does not exist 47 | with self.assertRaises(HTTPError): 48 | proto: Game = self.interfacer.get_replay_proto(replay_id) 49 | with self.assertRaises(HTTPError): 50 | df = self.interfacer.get_replay_df(replay_id) 51 | 52 | 53 | if __name__ == '__main__': 54 | unittest.main() 55 | -------------------------------------------------------------------------------- /data/interfacers/local_interfacer.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import logging 3 | import os 4 | from typing import Set, Dict 5 | 6 | import carball 7 | import pandas as pd 8 | from carball.analysis.utils.pandas_manager import PandasManager 9 | from carball.analysis.utils.proto_manager import ProtobufManager 10 | from carball.generated.api.game_pb2 import Game 11 | 12 | from 
data.interfacers.base_interfacer import BaseInterfacer 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class LocalInterfacer(BaseInterfacer): 18 | CACHE_PATH = r"C:\Users\harry\Documents\rocket_league\ReplayModels\cache" 19 | 20 | def __init__(self, folder_path: str): 21 | self.folder_path = folder_path 22 | self.replay_paths: Dict[str, str] = self._get_replay_paths() # replay id to replay path 23 | 24 | self.known_bad_ids = [] 25 | 26 | def _get_replay_paths(self) -> Dict[str, str]: 27 | replays = glob.glob(os.path.join(self.folder_path, '**/*.replay'), recursive=True) 28 | replay_paths = { 29 | os.path.basename(replay_path)[:-7]: replay_path # -7 to remove ".replay" 30 | for replay_path in replays 31 | } 32 | return replay_paths 33 | 34 | def get_total_count(self) -> int: 35 | return len(self.replay_paths) 36 | 37 | def get_all_replay_ids(self) -> Set[str]: 38 | replay_ids = set(self.replay_paths.keys()) 39 | logger.info(f'Found a total of {len(replay_ids)} unique replay ids.') 40 | return replay_ids 41 | 42 | def get_replay_proto(self, replay_id: str) -> Game: 43 | proto_filename = self._get_proto_filename(replay_id) 44 | if proto_filename in os.listdir(self.CACHE_PATH): 45 | with open(os.path.join(self.CACHE_PATH, proto_filename), 'rb') as f: 46 | proto = ProtobufManager.read_proto_out_from_file(f) 47 | return proto 48 | else: 49 | proto = self.parse_replay(replay_id, return_proto=True) 50 | return proto 51 | 52 | def get_replay_df(self, replay_id: str) -> pd.DataFrame: 53 | dataframe_filename = self._get_dataframe_filename(replay_id) 54 | 55 | if dataframe_filename in os.listdir(self.CACHE_PATH): 56 | with open(os.path.join(self.CACHE_PATH, dataframe_filename), 'rb') as f: 57 | dataframe = PandasManager.safe_read_pandas_to_memory(f) 58 | if dataframe is None: 59 | self.known_bad_ids.append(replay_id) 60 | raise Exception(f'Cannot read replay dataframe {replay_id}') 61 | return dataframe 62 | else: 63 | dataframe = self.parse_replay(replay_id, 
return_dataframe=True) 64 | return dataframe 65 | 66 | def parse_replay(self, replay_id: str, return_proto: bool = False, return_dataframe: bool = False): 67 | assert not (return_proto and return_dataframe), 'Cannot return both proto and dataframe' 68 | replay_path = self.replay_paths[replay_id] 69 | try: 70 | analysis_manager = carball.analyze_replay_file(replay_path) 71 | proto_filename = self._get_proto_filename(replay_id) 72 | proto_filepath = os.path.join(self.CACHE_PATH, proto_filename) 73 | with open(proto_filepath, 'wb') as f: 74 | analysis_manager.write_proto_out_to_file(f) 75 | 76 | dataframe_filename = self._get_dataframe_filename(replay_id) 77 | dataframe_filepath = os.path.join(self.CACHE_PATH, dataframe_filename) 78 | with open(dataframe_filepath, 'wb') as f: 79 | analysis_manager.write_pandas_out_to_file(f) 80 | 81 | if return_proto: 82 | return analysis_manager.protobuf_game 83 | if return_dataframe: 84 | return analysis_manager.data_frame 85 | except Exception as e: 86 | print(f'Failed to parse replay: {e}') 87 | self.known_bad_ids.append(replay_id) 88 | 89 | @staticmethod 90 | def _get_proto_filename(replay_id: str): 91 | return replay_id + '.proto' 92 | 93 | @staticmethod 94 | def _get_dataframe_filename(replay_id: str): 95 | return replay_id + '.df' 96 | 97 | def copy(self) -> 'LocalInterfacer': 98 | return self.__class__(folder_path=self.folder_path) 99 | 100 | 101 | if __name__ == '__main__': 102 | logging.basicConfig(level=logging.INFO) 103 | folder = r"C:\Users\harry\Documents\rocket_league\replays\DHPC DHE19 Replay Files" 104 | folder = r"C:\Users\harry\Documents\rocket_league\replays\RLCS Season 6" 105 | interfacer = LocalInterfacer(folder) 106 | replay_ids = interfacer.get_all_replay_ids() 107 | 108 | print(replay_ids) 109 | replay_id = sorted(list(replay_ids))[0] 110 | proto: Game = interfacer.get_replay_proto(replay_id) 111 | 112 | print(len(proto.players)) 113 | 
# --------------------------------------------------------------------------------
# /data/utils/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/SaltieRL/ReplayModels/62fb8765e8efc44c4865ee9fe121de35db7528cb/data/utils/__init__.py
# --------------------------------------------------------------------------------
# /data/utils/columns.py:
# --------------------------------------------------------------------------------
# Enums naming the second-level columns of a replay dataframe. filter_columns() in
# data/utils/utils.py selects columns by these enum values; member declaration order
# is also the enums' iteration order.
from enum import Enum


class PlayerColumn(Enum):
    """Per-player column names (second column level under each player's top-level name)."""
    # Position, velocity, rotation and angular velocity components.
    POS_X = 'pos_x'
    POS_Y = 'pos_y'
    POS_Z = 'pos_z'
    VEL_X = 'vel_x'
    VEL_Y = 'vel_y'
    VEL_Z = 'vel_z'
    ROT_X = 'rot_x'
    ROT_Y = 'rot_y'
    ROT_Z = 'rot_z'
    ANG_VEL_X = 'ang_vel_x'
    ANG_VEL_Y = 'ang_vel_y'
    ANG_VEL_Z = 'ang_vel_z'
    # Controller inputs and player state flags.
    THROTTLE = 'throttle'
    STEER = 'steer'
    HANDBRAKE = 'handbrake'
    BALL_CAM = 'ball_cam'
    DODGE_ACTIVE = 'dodge_active'
    DOUBLE_JUMP_ACTIVE = 'double_jump_active'
    JUMP_ACTIVE = 'jump_active'
    BOOST = 'boost'
    BOOST_ACTIVE = 'boost_active'
    PING = 'ping'
    BOOST_COLLECT = 'boost_collect'


class BallColumn(Enum):
    """Ball column names (second column level under the top-level 'ball' column)."""
    POS_X = 'pos_x'
    POS_Y = 'pos_y'
    POS_Z = 'pos_z'
    VEL_X = 'vel_x'
    VEL_Y = 'vel_y'
    VEL_Z = 'vel_z'
    ROT_X = 'rot_x'
    ROT_Y = 'rot_y'
    ROT_Z = 'rot_z'
    ANG_VEL_X = 'ang_vel_x'
    ANG_VEL_Y = 'ang_vel_y'
    ANG_VEL_Z = 'ang_vel_z'
    # presumably the team number of the last player to touch the ball — TODO confirm
    HIT_TEAM_NO = 'hit_team_no'


class GameColumn(Enum):
    """Game-wide column names (second column level under the top-level 'game' column)."""
    TIME = 'time'
    DELTA = 'delta'
    SECONDS_REMAINING = 'seconds_remaining'
    REPLICATED_SECONDS_REMAINING = 'replicated_seconds_remaining'
    IS_OVERTIME = 'is_overtime'
    BALL_HAS_BEEN_HIT = 'ball_has_been_hit'
    GOAL_NUMBER = 'goal_number'

# --------------------------------------------------------------------------------
# /data/utils/number_check.py:
# --------------------------------------------------------------------------------
# Checks the number of available replays in calculated.gg's API for a query
# (playlist and minimum MMR).
from typing import List, Optional

import requests

from data.utils.playlists import Playlist

BASE_URL = 'https://calculated.gg/api/v1/'

# Default minimum MMR for get_query(). It is read at call time, so rebinding this module
# global (as the __main__ section below does) changes the following queries.
MIN_MMR = 1500


def get_query(playlist: int = 11, min_mmr: Optional[int] = None) -> str:
    """
    Build the replay-search URL for a playlist.
    :param playlist: the playlist query param (see Playlist for known values).
    :param min_mmr: minimum MMR filter; falls back to the module-level MIN_MMR when None.
    :return: the query URL.
    """
    if min_mmr is None:
        min_mmr = MIN_MMR
    return BASE_URL + f'replays?key=1&minmmr={min_mmr}&playlist={playlist}&year=2019'


def get_replay_list(query: str) -> List[str]:
    """
    Does query and parses response for replay ids.
    :param query: query (given from get_query())
    :return: List of replay ids
    :raises requests.HTTPError: if the API answers with an error status.
    """
    r = requests.get(query)
    r.raise_for_status()
    return [replay['hash'] for replay in r.json()['data']]


def check_playlists():
    """
    Checks the available replay count for every playlist in the Playlist enum and prints
    a {playlist name: count} dict.
    :return: None
    """
    playlist_count = {playlist.name: check_playlist(playlist.value) for playlist in Playlist}
    print(playlist_count)


def check_playlist(playlist: int) -> int:
    """
    Checks available replays for the given playlist
    :param playlist: the playlist query param
    :return: The number of available replays for the given playlist.
    :raises requests.HTTPError: if the API answers with an error status.
    """
    query = get_query(playlist)
    print(f"query: {query}")
    r = requests.get(query)
    r.raise_for_status()  # surface API errors instead of a KeyError on 'total_count'

    count = r.json()['total_count']
    print(f"count: {count}")
    return count


if __name__ == '__main__':
    # MIN_MMR = 1500
    # check_playlists()
    MIN_MMR = 1700
    check_playlist(13)  # Ranked Standard
    # MIN_MMR = 1300
    # check_playlist(10)  # Ranked Duels

# --------------------------------------------------------------------------------
# /data/utils/playlists.py:
# --------------------------------------------------------------------------------
from enum import Enum


class Playlist(Enum):
    """Playlist ids, passed as the ``playlist`` query param of the calculated.gg API."""
    UNRANKED_SOLO_DUEL = 1
    UNRANKED_DOUBLES = 2
    UNRANKED_STANDARD = 3
    UNRANKED_CHAOS = 4
    RANKED_DUEL = 10
    RANKED_DOUBLES = 11
    RANKED_SOLO_STANDARD = 12
    RANKED_STANDARD = 13
    RANKED_HOOPS = 27
    RANKED_RUMBLE = 28
    RANKED_DROPSHOT = 29
    RANKED_SNOW_DAY = 30

# --------------------------------------------------------------------------------
# /data/utils/utils.py:
# --------------------------------------------------------------------------------
from typing import Sequence, Union

import pandas as pd
import numpy as np

from .columns import PlayerColumn, BallColumn, GameColumn

DataColumn = Union[PlayerColumn, BallColumn, GameColumn]


def flip_teams(df: pd.DataFrame):
    """
    Returns a new df where players' teams are effectively swapped.
    pos_x, pos_y, vel_x and vel_y are negated for every player and the ball, and np.pi is
    added to every player's rot_z. The input df is not modified.
    :param df: Replay's df; top column level is a player name, 'ball' or 'game'.
    :return: a modified copy of the input pd.DataFrame
    """
    # NOTE(review): rot_z is not wrapped back into its original range and angular velocities
    # are not negated — confirm whether downstream models need either.
    _df = df.copy()
    players = [
        name for name in _df.columns.get_level_values(level=0).unique()
        if name not in ('ball', 'game')
    ]
    for entity in players + ['ball']:
        for column in ('pos_x', 'pos_y', 'vel_x', 'vel_y'):
            _df.loc[:, (entity, column)] *= -1
    for player in players:
        _df.loc[:, (player, 'rot_z')] += np.pi
    return _df


def filter_columns(df: pd.DataFrame, columns: Sequence[DataColumn]):
    """
    Returns a new pd.DataFrame only containing the given columns.
    Player columns are kept for every player, ball columns under 'ball' and game columns
    under 'game'.
    :param df: Replay's df to filter
    :param columns: Sequence of DataColumns to keep
    :return: new pd.DataFrame
    """
    player_df = df.drop(columns=['ball', 'game'], level=0)
    ball_df = df.xs('ball', level=0, axis=1, drop_level=False)
    game_df = df.xs('game', level=0, axis=1, drop_level=False)

    player_columns = [column.value for column in columns if isinstance(column, PlayerColumn)]
    ball_columns = [column.value for column in columns if isinstance(column, BallColumn)]
    game_columns = [column.value for column in columns if isinstance(column, GameColumn)]

    filtered_player_df = player_df.loc[:, (slice(None), player_columns)]
    filtered_ball_df = ball_df.loc[:, (slice(None), ball_columns)]
    filtered_game_df = game_df.loc[:, (slice(None), game_columns)]

    return pd.concat([filtered_player_df, filtered_ball_df, filtered_game_df], axis=1)


# Divisors applied by normalise_df to the matching second-level column of every entity.
NORMALISATION_FACTORS = {
    'pos_x': 4096,
    'pos_y': 6000,
    'pos_z': 2048,
    'rot_x': np.pi,
    'rot_y': np.pi,
    'rot_z': np.pi,
    'vel_x': 23000,
    'vel_y': 23000,
    'vel_z': 23000,
    'ang_vel_x': 5500,
    'ang_vel_y': 5500,
    'ang_vel_z': 5500,
    'throttle': 255,
    'steer': 255,
    'boost': 255,
}


def normalise_df(df: pd.DataFrame, inplace: bool = False):
    """
    Divide each column named in NORMALISATION_FACTORS (under every top-level entity) by its
    factor. Columns absent from df (e.g. dropped by filter_columns) are skipped.
    :param df: Replay's df with a two-level column index.
    :param inplace: modify df itself instead of a copy.
    :return: the normalised df.
    """
    if not inplace:
        df = df.copy()
    present_columns = df.columns.get_level_values(1)
    for column, normalisation_factor in NORMALISATION_FACTORS.items():
        if column not in present_columns:
            continue  # selecting a missing label would raise KeyError
        df.loc[:, (slice(None), column)] /= normalisation_factor
    return df

# --------------------------------------------------------------------------------
# /data_main.py:
# --------------------------------------------------------------------------------
import sys

from data.download_replays import download_replays_range
from data.convert_replays import pre_process_parallel
from data.create_dataset import dataset


def _print_usage():
    """Print the top-level command summary."""
    print("Usage: download | convert | dataset")
    print("['optional_arg'] <'required_arg'> '|' -> 'or'")
    print("download: download [max_replays]")
    print("convert: convert <num_processes> <test_ratio> <verbose_interval>")
    print(
        "dataset: dataset <0 (No test set) | 1 (with test set) | 2 (only test set)> <max_games> <ram_percent> [output_filename]")


def main():
    """
    Command-line entry point: dispatch on sys.argv[1] to download, convert or dataset.
    Unknown or missing commands print the usage summary.
    """
    nargs = len(sys.argv)
    if nargs <= 1:
        _print_usage()
        return
    command = sys.argv[1]

    # Downloading Replays
    if command == 'download':
        if nargs == 2:
            download_replays_range()
        elif nargs == 3:
            download_replays_range(max_downloaded=int(sys.argv[2]))
        else:
            print("download usage: download [max_downloaded]")

    # Converting Replays to CSV
    elif command == 'convert':
        if nargs == 5:
            pre_process_parallel(int(sys.argv[2]), test_ratio=float(sys.argv[3]), overwrite=False,
                                 verbose_interval=float(sys.argv[4]))
        else:
            print("convert usage: convert <num_processes> <test_ratio> <verbose_interval>")
            print("\nNotes: test_ratio and verbose_interval are floats which can be 0 if you don't want either functionality.")
            print("test_ratio = .1 means 10% of the replays will go into the test folder. A single process handles test, so if the ratio is >>1/num_processes it will bottleneck.")
            print("verbose_interval = 10 means print every 10 minutes.")

    # Creating dataset from csv
    elif command == 'dataset':
        if nargs == 5:
            output = None
        elif nargs == 6:
            output = sys.argv[5]
        else:
            print("dataset usage: dataset <0 (No test set) | 1 (with test set) | 2 (only test set)> <max_games> <ram_percent> [output_filename]")
            print("max_games: Each csv is a single game. 0 means no limit.")
            print("ram_percent: RAM usage cap to avoid memory errors. Max this out with trial and error, if you want good shuffling. 0 keeps the default.")
            print(" NOTE: After hitting the cap, RAM usage of the program will spike beyond the cap. So '100' probably won't work.")
            print("output_filename overrides the default naming.")
            print("ex: dataset 1 1000 250000")
            return
        max_games = int(sys.argv[3])
        ram_percent = int(sys.argv[4])
        dataset_kwargs = {'max_games': None if max_games == 0 else max_games}
        # 0 keeps dataset()'s own default: passing ram_max=None would break its
        # `percent < ram_max` comparison.
        if ram_percent != 0:
            dataset_kwargs['ram_max'] = ram_percent
        t_int = int(sys.argv[2])  # 0: train only, 1: test and train, 2: test only
        if t_int >= 1:
            dataset(output_file=output, test=True, **dataset_kwargs)
        if t_int <= 1:
            dataset(output_file=output, test=False, **dataset_kwargs)

    else:
        _print_usage()


if __name__ == "__main__":
    main()

# --------------------------------------------------------------------------------
# /models/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/SaltieRL/ReplayModels/62fb8765e8efc44c4865ee9fe121de35db7528cb/models/__init__.py
# --------------------------------------------------------------------------------
# /models/base_model.py:
# --------------------------------------------------------------------------------
import logging
import time
from typing import Union

from tensorflow.python.keras.models import load_model, Sequential, Model, save_model

logger = logging.getLogger(__name__)


class BaseModel:
    """
    Wrapper around a Keras model. Subclasses implement build_model(); evaluate, predict,
    fit and fit_generator are forwarded to the wrapped model.
    """

    def __init__(self, inputs: int, outputs: int, load_from_filepath: str = None, **kwargs):
        """
        :param inputs: number of model inputs.
        :param outputs: number of model outputs.
        :param load_from_filepath: if given, load a saved Keras model from this path instead
            of calling build_model().
        :param kwargs: accepted and ignored here.
        """
        self.inputs = inputs
        self.outputs = outputs

        self.load_from_filepath = load_from_filepath  # Needed for copying with __dict__.
        if load_from_filepath is not None:
            self.model = load_model(load_from_filepath)
        else:
            self.model = self.build_model()

    def save_model(self, name: str = None, use_timestamp: bool = True):
        """
        Save the wrapped model to the relative path ``model_<name>[<YYYYmmdd-HHMMSS>]``.
        :param name: defaults to the class name.
        :param use_timestamp: append the current local time to the filename.
        """
        if name is None:
            name = self.__class__.__name__
        filename = f"model_{name}"
        if use_timestamp:
            filename += f"{time.strftime('%Y%m%d-%H%M%S')}"
        save_model(self.model, filename)

    def build_model(self) -> Union[Sequential, Model]:
        """Create and return a compiled Keras model; must be implemented by subclasses."""
        raise NotImplementedError

    def evaluate(self, *args, **kwargs):
        return self.model.evaluate(*args, **kwargs)

    def predict(self, *args, **kwargs):
        return self.model.predict(*args, **kwargs)

    def fit(self, *args, **kwargs):
        return self.model.fit(*args, **kwargs)

    def fit_generator(self, *args, **kwargs):
        return self.model.fit_generator(*args, **kwargs)

# --------------------------------------------------------------------------------
# /models/dense_model.py:
# --------------------------------------------------------------------------------
import logging
import time
from typing import Sequence

import tensorflow as tf
from tensorflow import keras
from tensorflow.python.keras.models import load_model, save_model

logger = logging.getLogger(__name__)

# Sentinel default for DenseModel's `regularizer`: an explicit None must keep meaning
# "no kernel regularization", so "argument not given" needs its own marker.
_DEFAULT_REGULARIZER = object()


class DenseModel:
    """
    Fully connected Keras model: a linear Dense input layer of width ``inputs``, hidden
    Dense layers sized by ``layer_nodes`` and a Dense output layer, compiled with Adam and
    an 'mae' metric.
    """

    def __init__(self, inputs: int, outputs: int,
                 layer_nodes: Sequence[int] = (24, 24),
                 inner_activation=tf.nn.relu, output_activation='linear',
                 regularizer=_DEFAULT_REGULARIZER,
                 learning_rate=0.003,
                 loss_fn='mse', load_from_filepath: str = None):
        """
        :param inputs: number of model inputs.
        :param outputs: number of model outputs.
        :param layer_nodes: unit count of each hidden layer.
        :param inner_activation: activation of the hidden layers.
        :param output_activation: activation of the output layer.
        :param regularizer: kernel regularizer for the hidden layers; when omitted a new
            keras.regularizers.l2(1e-4) is created for this instance. Pass None to disable.
        :param learning_rate: Adam learning rate.
        :param loss_fn: loss passed to model.compile().
        :param load_from_filepath: if given, load a saved Keras model instead of building one.
        """
        logger.info(f'Creating DenseModel with {inputs} inputs and {outputs} outputs.')
        if regularizer is _DEFAULT_REGULARIZER:
            # Built per instance instead of once at import time as a shared default object.
            regularizer = keras.regularizers.l2(1e-4)
        self.layer_nodes = layer_nodes
        self.inner_activation = inner_activation
        self.output_activation = output_activation
        self.regularizer = regularizer
        self.learning_rate = learning_rate
        self.loss_fn = loss_fn

        self.inputs = inputs
        self.outputs = outputs

        self.load_from_filepath = load_from_filepath  # Needed for copying with __dict__.
        if load_from_filepath is not None:
            self.model = load_model(load_from_filepath)
        else:
            self.model = self.build_model()

    def save_model(self, name: str, use_timestamp: bool = True):
        """Save the model to the relative path ``model_<name>[<YYYYmmdd-HHMMSS>]``."""
        filename = f"model_{name}"
        if use_timestamp:
            filename += f"{time.strftime('%Y%m%d-%H%M%S')}"
        save_model(self.model, filename)

    def build_model(self) -> keras.Sequential:
        """Build and compile the Sequential model described by this instance's settings."""
        model = keras.Sequential()

        # Verbose version needed because https://github.com/tensorflow/tensorflow/issues/22837#issuecomment-428327601
        # model.add(keras.layers.Dense(self.inputs))
        model.add(keras.layers.Dense(input_shape=(self.inputs,), units=self.inputs))

        for _layer_nodes in self.layer_nodes:
            model.add(
                keras.layers.Dense(_layer_nodes, activation=self.inner_activation, kernel_regularizer=self.regularizer)
            )

        model.add(keras.layers.Dense(self.outputs, activation=self.output_activation))
        # NOTE(review): newer Keras releases expect `learning_rate=`; `lr` is a deprecated
        # alias there — confirm against the TensorFlow version actually installed.
        optimizer = keras.optimizers.Adam(lr=self.learning_rate)
        model.compile(loss=self.loss_fn, optimizer=optimizer, metrics=['mae'])
        return model

# --------------------------------------------------------------------------------
# /requirements.txt:
# --------------------------------------------------------------------------------
# psutil
# numpy
# carball
# matplotlib
# pandas
# requests
# scikit-learn
# NOTE(review): models/ imports tensorflow, which is not listed above.
-------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SaltieRL/ReplayModels/62fb8765e8efc44c4865ee9fe121de35db7528cb/trainers/__init__.py -------------------------------------------------------------------------------- /trainers/callbacks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SaltieRL/ReplayModels/62fb8765e8efc44c4865ee9fe121de35db7528cb/trainers/callbacks/__init__.py -------------------------------------------------------------------------------- /trainers/callbacks/metric_tracer.py: -------------------------------------------------------------------------------- 1 | from tensorflow.python.keras.callbacks import Callback 2 | from quicktracer import trace 3 | 4 | 5 | class MetricTracer(Callback): 6 | def on_batch_end(self, batch, logs=None): 7 | self._trace_logs(logs) 8 | 9 | def on_epoch_end(self, epoch, logs=None): 10 | self._trace_logs(logs) 11 | 12 | def _trace_logs(self, logs): 13 | for metric in self.params['metrics']: 14 | if metric in logs: 15 | trace(float(logs[metric]), key=metric) 16 | -------------------------------------------------------------------------------- /trainers/callbacks/metrics.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Tuple 3 | 4 | import numpy as np 5 | from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score, precision_recall_fscore_support 6 | from tensorflow.python.keras.callbacks import Callback 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class ClassificationMetrics(Callback): 12 | 13 | def __init__(self, validation_data_: Tuple[np.ndarray, np.ndarray]): 14 | super().__init__() 15 | self.validation_data_ = validation_data_ 16 | 17 | def on_train_begin(self, 
logs=None): 18 | self.val_f1s = [] 19 | self.val_recalls = [] 20 | self.val_precisions = [] 21 | 22 | def on_epoch_end(self, epoch, logs=None): 23 | metrics = ['f1_metric', 'precision_metric', 'recall_metric'] 24 | for metric in metrics: 25 | if metric not in self.params['metrics']: 26 | self.params['metrics'].append(metric) 27 | 28 | val_predict = (np.asarray(self.model.predict(self.validation_data_[0]))).round() 29 | val_targ = self.validation_data_[1] 30 | _val_f1 = f1_score(val_targ, val_predict) 31 | _val_recall = recall_score(val_targ, val_predict) 32 | _val_precision = precision_score(val_targ, val_predict) 33 | _val_precision, _val_recall, _val_f1, support = precision_recall_fscore_support(val_targ, val_predict, 34 | average='binary') 35 | logger.debug(f"Support (count of occurences in target): {support}") 36 | unique, counts = np.unique(val_predict, return_counts=True) 37 | logger.info(f"\nValidation predictions: unique: {unique}, counts: {counts}.") 38 | 39 | self.val_f1s.append(_val_f1) 40 | self.val_recalls.append(_val_recall) 41 | self.val_precisions.append(_val_precision) 42 | # trace(_val_f1) 43 | # trace(_val_precision) 44 | # trace(_val_recall) 45 | if logs is not None: 46 | logs['f1_metric'] = _val_f1 47 | logs['precision_metric'] = _val_precision 48 | logs['recall_metric'] = _val_recall 49 | else: 50 | logger.warning('ClassificationMetrics not added to logs as logs is None.') 51 | return 52 | -------------------------------------------------------------------------------- /trainers/callbacks/prediction_plotter.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | from tensorflow.keras import backend as K 6 | from tensorflow.python.keras import Sequential 7 | from tensorflow.python.keras.callbacks import Callback 8 | 9 | from trainers.callbacks.prediction_plotter_plot import PredictionPlotterPlot 10 | 11 | logger = 
logging.getLogger(__name__) 12 | 13 | 14 | class PredictionPlotter(Callback): 15 | """ 16 | See https://stackoverflow.com/a/47081613 for implementation idea. 17 | 18 | An "on_run_begin(model)" method is defined as the initialisation has to be called before fit() 19 | (and thus before any of the existing methods of the base class, including even set_model()) 20 | """ 21 | 22 | def __init__(self, model: Sequential, plot_every_x_steps: int = 20): 23 | super().__init__() 24 | self.plot_every_x_steps = plot_every_x_steps 25 | self.plot = PredictionPlotterPlot() 26 | 27 | self.targets = [] # collect y_true batches 28 | self.outputs = [] # collect y_pred batches 29 | 30 | # the shape of these 2 variables will change according to batch shape 31 | # to handle the "last batch", specify `validate_shape=False` 32 | self.var_y_true = tf.Variable(0., validate_shape=False) 33 | self.var_y_pred = tf.Variable(0., validate_shape=False) 34 | 35 | self._initialise_variables(model) 36 | 37 | def _initialise_variables(self, model): 38 | fetches = [tf.assign(self.var_y_true, model.targets[0], validate_shape=False), 39 | tf.assign(self.var_y_pred, model.outputs[0], validate_shape=False)] 40 | 41 | model._function_kwargs = {'fetches': fetches} 42 | # use `model._function_kwargs` if using `Model` instead of `Sequential` 43 | 44 | def on_batch_end(self, batch, logs=None): 45 | self.targets.append(K.eval(self.var_y_true)) 46 | self.outputs.append(K.eval(self.var_y_pred)) 47 | 48 | if batch % self.plot_every_x_steps == 0: 49 | self.update_plot() 50 | 51 | def update_plot(self): 52 | actual = np.concatenate(self.targets, axis=None) 53 | predicted = np.concatenate(self.outputs, axis=None) 54 | logger.info(f"Plotting prediction on {len(predicted)} samples in epoch") 55 | self.plot.update_plot(actual, predicted) 56 | self.targets = [] 57 | self.outputs = [] 58 | -------------------------------------------------------------------------------- /trainers/callbacks/prediction_plotter_plot.py: 
-------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | 5 | class PredictionPlotterPlot: 6 | 7 | def __init__(self): 8 | self.fig = plt.figure() 9 | self.ax = plt.gca() 10 | self.shown = False 11 | 12 | def update_plot(self, actual: np.ndarray, predicted: np.ndarray): 13 | self.ax.clear() 14 | self.ax.scatter(actual, predicted, marker='.', alpha=0.4) 15 | self.ax.grid() 16 | self.ax.set_xlabel('Actual values') 17 | self.ax.set_ylabel('Predicted values') 18 | self.fig.tight_layout() 19 | self.fig.canvas.draw() 20 | if not self.shown: 21 | plt.show(block=False) 22 | self.shown = True 23 | plt.pause(0.001) 24 | -------------------------------------------------------------------------------- /trainers/callbacks/tensorboard.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pathlib 4 | 5 | from tensorflow.python.keras.callbacks import TensorBoard 6 | 7 | LOG_FOLDER = "tensorboard_logs" 8 | FOLDER_PREFIX = "run" 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | dir_path = os.path.dirname(os.path.realpath(__file__)) 13 | log_path = os.path.join(dir_path, LOG_FOLDER) 14 | 15 | 16 | def get_tensorboard() -> TensorBoard: 17 | try: 18 | i = find_max_run_number() + 1 19 | except (FileNotFoundError, ValueError): 20 | i = 0 21 | _log_path = os.path.join(log_path, f"{FOLDER_PREFIX}_{i}") 22 | pathlib.Path(_log_path).mkdir(parents=True, exist_ok=True) 23 | 24 | callback = TensorBoard(_log_path) 25 | logger.info(f"Created TensorBoard logs in {_log_path}.") 26 | return callback 27 | 28 | 29 | def find_max_run_number() -> int: 30 | files = os.listdir(log_path) 31 | run_numbers = [] 32 | for file in files: 33 | if file.startswith(FOLDER_PREFIX): 34 | run_number = int(file[len(FOLDER_PREFIX) + 1:]) 35 | run_numbers.append(run_number) 36 | return max(run_numbers) 37 | 
-------------------------------------------------------------------------------- /trainers/sequences/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SaltieRL/ReplayModels/62fb8765e8efc44c4865ee9fe121de35db7528cb/trainers/sequences/__init__.py -------------------------------------------------------------------------------- /trainers/sequences/calculated_sequence.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | from typing import List, Callable, Tuple, Union 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from carball.generated.api.game_pb2 import Game 8 | from tensorflow.python.keras.utils import Sequence 9 | 10 | from data.interfacers.base_interfacer import BaseInterfacer 11 | from data.interfacers.calculatedgg_api.errors import BrokenDataFrameError 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | GameDataToArraysTransformer = Callable[[pd.DataFrame, Game], 16 | Union[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray]]] 17 | 18 | 19 | class CalculatedSequence(Sequence): 20 | def __init__(self, interfacer: BaseInterfacer, 21 | game_data_transformer: GameDataToArraysTransformer, 22 | replay_ids: List[str] = None): 23 | self.interfacer = interfacer 24 | self.game_data_transformer = game_data_transformer 25 | self.replay_ids: List[str] = replay_ids 26 | self._setup_sequence() 27 | 28 | def _setup_sequence(self): 29 | if self.replay_ids is None: 30 | total_count = self.interfacer.get_total_count() 31 | logger.info(f'Setting up sequence: total count: {total_count}') 32 | self.replay_ids = list(self.interfacer.get_all_replay_ids()) 33 | logger.info(f'Created sequence with {len(self)} replay ids.') 34 | 35 | def __len__(self): 36 | return len(self.replay_ids) 37 | 38 | def __getitem__(self, index: int): 39 | replay_id = self.replay_ids[index] 40 | try: 41 | proto = 
self.interfacer.get_replay_proto(replay_id) 42 | df = self.interfacer.get_replay_df(replay_id) 43 | return self.game_data_transformer(df, proto) 44 | except BrokenDataFrameError: 45 | logger.warning(f'Replay {replay_id} has broken dataframe.') 46 | except Exception as e: 47 | logger.exception(e) 48 | 49 | # Try random replay. 50 | while True: 51 | replay_id = random.choice(self.replay_ids) 52 | try: 53 | proto = self.interfacer.get_replay_proto(replay_id) 54 | df = self.interfacer.get_replay_df(replay_id) 55 | return self.game_data_transformer(df, proto) 56 | except BrokenDataFrameError: 57 | logger.warning(f'Replay {replay_id} has broken dataframe.') 58 | except Exception as e: 59 | logger.exception(e) 60 | 61 | def create_eval_sequence(self, eval_count: int): 62 | """ 63 | Creates a sequence to be used as evaluation. 64 | Removes the replays put into the evaluation sequence from the existing sequence. 65 | :param eval_count: 66 | :return: 67 | """ 68 | eval_set = random.sample(self.replay_ids, eval_count) 69 | 70 | # Remove eval set from this set 71 | self.replay_ids = [replay_id for replay_id in self.replay_ids if replay_id not in eval_set] 72 | 73 | return self.__class__(interfacer=self.interfacer.copy(), 74 | game_data_transformer=self.game_data_transformer, 75 | replay_ids=eval_set) 76 | 77 | def as_arrays(self) -> Union[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray]]: 78 | logger.info(f"Generating arrays from sequence.") 79 | inputs = [] 80 | outputs = [] 81 | weights = [] 82 | for i in range(len(self)): 83 | arrays = self[i] 84 | if len(arrays) == 2: 85 | input_, output = arrays 86 | inputs.append(input_) 87 | outputs.append(output) 88 | elif len(arrays) == 3: 89 | input_, output, sample_weights = arrays 90 | inputs.append(input_) 91 | outputs.append(output) 92 | weights.append(sample_weights) 93 | elif len(arrays) == 0: 94 | continue 95 | else: 96 | raise Exception(f"GameDataToArrayTransformer should return tuple of length 2 or 3, 
not {len(arrays)}.") 97 | 98 | logger.info(f"Generated arrays from {len(inputs)} arrays.") 99 | if len(weights): 100 | return np.concatenate(inputs), np.concatenate(outputs), np.concatenate(weights) 101 | else: 102 | return np.concatenate(inputs), np.concatenate(outputs) 103 | -------------------------------------------------------------------------------- /value_function/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SaltieRL/ReplayModels/62fb8765e8efc44c4865ee9fe121de35db7528cb/value_function/__init__.py -------------------------------------------------------------------------------- /value_function/value_function_conv_model.py: -------------------------------------------------------------------------------- 1 | from tensorflow.python.keras import Sequential 2 | from tensorflow.python.keras.layers import Dense, Dropout, Reshape, Conv1D, Flatten 3 | from tensorflow.python.keras.optimizers import Adam 4 | 5 | from models.base_model import BaseModel 6 | 7 | 8 | class ValueFunctionConvModel(BaseModel): 9 | def __init__(self, inputs: int): 10 | super().__init__(inputs, 1) 11 | 12 | def build_model(self) -> Sequential: 13 | model = Sequential([ 14 | Reshape((self.inputs, 1), input_shape=(self.inputs,)), 15 | Conv1D(filters=512, kernel_size=3, strides=3, padding="same"), 16 | Flatten(data_format='channels_last'), 17 | Dense(256, activation='relu'), 18 | Dense(128, activation='relu'), 19 | Dense(1, activation='sigmoid') 20 | ]) 21 | 22 | optimizer = Adam(lr=1e-4) 23 | 24 | model.compile(loss='mse', optimizer=optimizer, metrics=['mae']) 25 | return model 26 | -------------------------------------------------------------------------------- /value_function/value_function_model.py: -------------------------------------------------------------------------------- 1 | from tensorflow.python.keras import Sequential 2 | from tensorflow.python.keras.layers import Dense, Dropout 3 | from 
tensorflow.python.keras.optimizers import Adam 4 | 5 | from models.base_model import BaseModel 6 | 7 | 8 | class ValueFunctionModel(BaseModel): 9 | def __init__(self, inputs: int): 10 | super().__init__(inputs, 1) 11 | 12 | def build_model(self) -> Sequential: 13 | model = Sequential([ 14 | Dense(512, input_dim=self.inputs, activation='relu'), 15 | Dropout(0.5), 16 | Dense(512, activation='relu'), 17 | Dropout(0.5), 18 | Dense(128, activation='relu'), 19 | Dropout(0.5), 20 | Dense(1, activation='sigmoid') 21 | ]) 22 | 23 | optimizer = Adam(lr=1e-4) 24 | 25 | model.compile(loss='mse', optimizer=optimizer, metrics=['mae']) 26 | return model 27 | -------------------------------------------------------------------------------- /value_function/value_function_trainer.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | from typing import Tuple 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from carball.generated.api.game_pb2 import Game 8 | from tensorflow.python.keras.callbacks import ModelCheckpoint 9 | 10 | from data.interfacers.calculatedgg_api.api_interfacer import CalculatedApiInterfacer 11 | from data.interfacers.calculatedgg_api.query_params import CalculatedApiQueryParams 12 | from data.utils.columns import PlayerColumn, BallColumn, GameColumn 13 | from data.utils.utils import filter_columns 14 | from trainers.callbacks.metric_tracer import MetricTracer 15 | from trainers.callbacks.tensorboard import get_tensorboard 16 | from trainers.sequences.calculated_sequence import CalculatedSequence 17 | from value_function.value_function_conv_model import ValueFunctionConvModel 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | logging.basicConfig(level=logging.INFO) 22 | logging.getLogger("carball").setLevel(logging.CRITICAL) 23 | logging.getLogger("data.base_data_manager").setLevel(logging.WARNING) 24 | 25 | 26 | def get_input_and_output_from_game_datas(df: pd.DataFrame, proto: 
Game) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 27 | logger.debug('Getting input and output') 28 | players = df.columns.levels[0] 29 | teams = proto.teams 30 | team_map = { 31 | player.id: team.is_orange 32 | for team in teams 33 | for player in team.player_ids 34 | } 35 | name_team_map = {player.name: player.is_orange for player in proto.players} 36 | 37 | sorted_players = sorted( 38 | [player for player in players if player not in ['ball', 'game']], 39 | key=lambda x: name_team_map[x] 40 | ) + ['ball', 'game'] 41 | 42 | goal_teams_list = [team_map[goal.player_id.id] for goal in proto.game_metadata.goals] 43 | goal_frame_numbers = [goal.frame_number for goal in proto.game_metadata.goals] 44 | goal_times_list = df.loc[goal_frame_numbers, ('game', 'time')].tolist() 45 | 46 | # Get goal number 47 | goal_number = df.loc[:, ('game', 'goal_number')].dropna() 48 | goal_number = goal_number[goal_number >= 0] 49 | goal_team = goal_number.apply(lambda x: goal_teams_list[int(x)]) 50 | goal_time = goal_number.apply(lambda x: goal_times_list[int(x)]) 51 | 52 | # Only train on 1/3 of the frames - reduce extremely similar frames 53 | _df = df.sample(frac=0.3) 54 | 55 | _df[('game', 'goal_team')] = goal_team.rename('goal_team') 56 | _df[('game', 'goal_time')] = goal_time.rename('goal_time') 57 | _df[('game', 'time_to_goal')] = _df.loc[:, ('game', 'goal_time')] - _df.loc[:, ('game', 'time')] 58 | _df = _df.dropna(subset=[('game', 'goal_team')]) 59 | 60 | # Remove post-goal frames 61 | _df = _df[_df.loc[:, ('game', 'time_to_goal')] >= 0] 62 | 63 | # _df = _df[sorted_players] # Same thing as below line 64 | _df.reindex(columns=sorted_players, level=0) 65 | 66 | # Set up data 67 | INPUT_COLUMNS = [ 68 | PlayerColumn.POS_X, PlayerColumn.POS_Y, PlayerColumn.POS_Z, 69 | PlayerColumn.ROT_X, PlayerColumn.ROT_Y, PlayerColumn.ROT_Z, 70 | PlayerColumn.VEL_X, PlayerColumn.VEL_Y, PlayerColumn.VEL_Z, 71 | # PlayerColumn.ANG_VEL_X, PlayerColumn.ANG_VEL_Y, PlayerColumn.ANG_VEL_Z, 72 | 
BallColumn.POS_X, BallColumn.POS_Y, BallColumn.POS_Z, 73 | BallColumn.VEL_X, BallColumn.VEL_Y, BallColumn.VEL_Z, 74 | GameColumn.SECONDS_REMAINING 75 | ] 76 | input_ = filter_columns(_df, INPUT_COLUMNS).fillna(0).astype(float) 77 | 78 | # Move value towards 0 if blue scored, or towards 1 if orange scored 79 | value_coefficient = ((-1) ** (_df.game.goal_team + 1)).astype(np.int8) 80 | 81 | MAX_TIME = 10 82 | 83 | # 0 (goal long later) to 1 (goal now) 84 | raw_value = ((MAX_TIME - _df.loc[:, ('game', 'time_to_goal')]) / MAX_TIME).clip(0, 1) 85 | output = 0.5 + 0.5 * raw_value * value_coefficient 86 | output = output.values.reshape((-1, 1)) 87 | input_ = input_.values 88 | logger.debug(f'Got input and output: input shape: {input_.shape}, output shape:{output.shape}') 89 | return input_, output, get_sample_weight(output) 90 | 91 | 92 | def get_sample_weight(output: np.ndarray) -> np.ndarray: 93 | weights = np.ones_like(output) 94 | # weights[output == 0.5] = 1 / np.sum(output == 0.5) 95 | weights[output == 0.5] = 1 / 5 96 | 97 | return weights.flatten() 98 | 99 | 100 | INPUT_FEATURES = 61 101 | 102 | 103 | if __name__ == '__main__': 104 | model = ValueFunctionConvModel(INPUT_FEATURES) 105 | # model = ValueFunctionModel(INPUT_FEATURES) 106 | 107 | interfacer = CalculatedApiInterfacer( 108 | CalculatedApiQueryParams(playlist=13, minmmr=1500, maxmmr=1550, 109 | start_timestamp=int(datetime.datetime(2018, 11, 1).timestamp())) 110 | ) 111 | sequence = CalculatedSequence( 112 | interfacer=interfacer, 113 | game_data_transformer=get_input_and_output_from_game_datas, 114 | ) 115 | 116 | EVAL_COUNT = 5 117 | eval_sequence = sequence.create_eval_sequence(EVAL_COUNT) 118 | eval_inputs, eval_outputs, _UNUSED_eval_weights = eval_sequence.as_arrays() 119 | 120 | save_callback = ModelCheckpoint('value_function.{epoch:02d}-{val_loss:.5f}.hdf5', save_best_only=True) 121 | callbacks = [ 122 | MetricTracer(), 123 | save_callback, 124 | get_tensorboard(), 125 | # 
PredictionPlotter(model.model), 126 | ] 127 | model.fit_generator(sequence, 128 | steps_per_epoch=500, 129 | validation_data=eval_sequence, epochs=1000, callbacks=callbacks, 130 | workers=4, use_multiprocessing=True) 131 | -------------------------------------------------------------------------------- /x_things/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SaltieRL/ReplayModels/62fb8765e8efc44c4865ee9fe121de35db7528cb/x_things/__init__.py -------------------------------------------------------------------------------- /x_things/model_retrainer.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from tensorflow.python.keras.callbacks import ModelCheckpoint 4 | from tensorflow.python.keras.models import load_model 5 | from tensorflow.python.keras import backend as K 6 | from tensorflow.python.keras.optimizers import Adam 7 | 8 | from data.interfacers.local_interfacer import LocalInterfacer 9 | from trainers.callbacks.metric_tracer import MetricTracer 10 | from trainers.callbacks.metrics import ClassificationMetrics 11 | from trainers.callbacks.tensorboard import get_tensorboard 12 | from trainers.sequences.calculated_sequence import CalculatedSequence 13 | from x_things.x_goals_trainer import get_input_and_output_from_game_datas, WeightMethod 14 | 15 | 16 | 17 | if __name__ == '__main__': 18 | 19 | filepath = "x_goals.113-0.88718.hdf5" 20 | 21 | 22 | model = load_model(filepath) 23 | 24 | # Set only last layer to be trainable 25 | for layer in model.layers[:-2]: 26 | layer.trainable = False 27 | 28 | for layer in model.layers: 29 | print(layer, layer.trainable) 30 | print(model.summary()) 31 | 32 | optimizer = Adam(lr=1e-5) 33 | model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy']) 34 | # K.set_value(model.optimizer.lr, 1e-5) 35 | 36 | folder = 
r"C:\Users\harry\Documents\rocket_league\replays\DHPC DHE19 Replay Files" 37 | interfacer = LocalInterfacer(folder) 38 | retraining_sequence = CalculatedSequence( 39 | interfacer=interfacer, 40 | game_data_transformer=partial(get_input_and_output_from_game_datas, weight_method=WeightMethod.EMPHASISE_SHOTS), 41 | ) 42 | 43 | EVAL_COUNT = 50 44 | eval_sequence = retraining_sequence.create_eval_sequence(EVAL_COUNT) 45 | eval_inputs, eval_outputs, _ = eval_sequence.as_arrays() 46 | 47 | save_callback = ModelCheckpoint('x_goals_retrained.{epoch:02d}-{val_acc:.5f}.hdf5', monitor='val_acc', save_best_only=True) 48 | classificaion_metrics = ClassificationMetrics((eval_inputs, eval_outputs)) 49 | callbacks = [ 50 | classificaion_metrics, 51 | MetricTracer(), 52 | save_callback, 53 | get_tensorboard(), 54 | # PredictionPlotter(model.model), 55 | ] 56 | model.fit_generator(retraining_sequence, 57 | steps_per_epoch=500, 58 | validation_data=eval_sequence, epochs=1000, callbacks=callbacks, 59 | workers=4, use_multiprocessing=True) 60 | -------------------------------------------------------------------------------- /x_things/x_goals_conv_model.py: -------------------------------------------------------------------------------- 1 | from tensorflow.python.keras import Sequential, Input, Model 2 | from tensorflow.python.keras.layers import Dense, Conv1D, Reshape, Flatten, Dot, Concatenate 3 | from tensorflow.python.keras.optimizers import Adam 4 | 5 | from models.base_model import BaseModel 6 | 7 | 8 | class XGoalsConvModel(BaseModel): 9 | def __init__(self, inputs: int, outputs: int, load_from_filepath: str = None): 10 | super().__init__(inputs, outputs, load_from_filepath) 11 | 12 | def build_model(self) -> Sequential: 13 | model = Sequential([ 14 | Reshape((self.inputs, 1), input_shape=(self.inputs,)), 15 | Conv1D(filters=1024, kernel_size=3, strides=3, padding="same"), 16 | Flatten(data_format='channels_last'), 17 | # Dense(1024, activation='relu', input_shape=(self.inputs,)), 
18 | Dense(512, activation='relu'), 19 | Dense(128, activation='relu'), 20 | Dense(self.outputs, activation='sigmoid') 21 | ]) 22 | 23 | optimizer = Adam(lr=3e-4) 24 | 25 | model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy']) 26 | return model 27 | -------------------------------------------------------------------------------- /x_things/x_goals_trainer.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import random 4 | from enum import Enum, auto 5 | from functools import partial 6 | from typing import Tuple 7 | 8 | import numpy as np 9 | import pandas as pd 10 | from carball.generated.api.game_pb2 import Game 11 | from tensorflow.python.keras.callbacks import ModelCheckpoint 12 | 13 | from data.interfacers.calculatedgg_api.api_interfacer import CalculatedApiInterfacer 14 | from data.interfacers.calculatedgg_api.query_params import CalculatedApiQueryParams 15 | from data.utils.columns import PlayerColumn, BallColumn, GameColumn 16 | from data.utils.utils import filter_columns, flip_teams, normalise_df 17 | from trainers.callbacks.metric_tracer import MetricTracer 18 | from trainers.callbacks.metrics import ClassificationMetrics 19 | from trainers.callbacks.tensorboard import get_tensorboard 20 | from trainers.sequences.calculated_sequence import CalculatedSequence 21 | from x_things.x_goals_conv_model import XGoalsConvModel 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | logging.basicConfig(level=logging.INFO) 26 | logging.getLogger("carball").setLevel(logging.CRITICAL) 27 | logging.getLogger("data.base_data_manager").setLevel(logging.WARNING) 28 | 29 | 30 | class WeightMethod(Enum): 31 | OUTPUT_CATEGORY = auto() 32 | EMPHASISE_SHOTS = auto() 33 | 34 | 35 | def get_input_and_output_from_game_datas(df: pd.DataFrame, proto: Game, 36 | weight_method: WeightMethod = WeightMethod.OUTPUT_CATEGORY) -> Tuple[ 37 | np.ndarray, np.ndarray, np.ndarray]: 38 | 
logger.debug('Getting input and output') 39 | 40 | df = normalise_df(df, inplace=True) 41 | name_team_map = {player.name: player.is_orange for player in proto.players} 42 | 43 | # Set up data 44 | INPUT_COLUMNS = [ 45 | PlayerColumn.POS_X, PlayerColumn.POS_Y, PlayerColumn.POS_Z, 46 | PlayerColumn.ROT_X, PlayerColumn.ROT_Y, PlayerColumn.ROT_Z, 47 | PlayerColumn.VEL_X, PlayerColumn.VEL_Y, PlayerColumn.VEL_Z, 48 | PlayerColumn.ANG_VEL_X, PlayerColumn.ANG_VEL_Y, PlayerColumn.ANG_VEL_Z, 49 | BallColumn.POS_X, BallColumn.POS_Y, BallColumn.POS_Z, 50 | BallColumn.VEL_X, BallColumn.VEL_Y, BallColumn.VEL_Z, 51 | GameColumn.SECONDS_REMAINING 52 | ] 53 | filtered_df = filter_columns(df, INPUT_COLUMNS).fillna(0).astype(float) 54 | filtered_df_orange = flip_teams(filtered_df) 55 | 56 | player_id_to_player = { 57 | player.id.id: player 58 | for player in proto.players 59 | } 60 | 61 | hits = proto.game_stats.hits 62 | inputs = [] 63 | outputs = [] 64 | if weight_method == WeightMethod.EMPHASISE_SHOTS: 65 | weights = [] 66 | 67 | for hit in hits: 68 | # if not hit.shot: 69 | # continue 70 | player_name = player_id_to_player[hit.player_id.id].name 71 | 72 | player = [player for player in proto.players if player.name == player_name][0] 73 | 74 | # Make player taking shot be blue 75 | _df = filtered_df_orange if player.is_orange else filtered_df 76 | 77 | # Get right frame 78 | try: 79 | frame = _df.loc[hit.frame_number - random.randint(1, 10), :] 80 | except KeyError: 81 | frame = _df.loc[hit.frame_number, :] 82 | 83 | # Move player taking shot 84 | def key_fn(player_name: str) -> int: 85 | # Move player to front, move team to front. 
86 | if player_name == player.name: 87 | return 0 88 | elif name_team_map[player_name] == player.is_orange: 89 | # return 1 90 | return random.randint(1, 10) # randomises teammates order 91 | else: 92 | # return 2 93 | return random.randint(11, 50) # randomises opponents order 94 | 95 | sorted_players = sorted( 96 | [player.name for player in proto.players], 97 | key=key_fn 98 | ) + ['ball', 'game'] 99 | 100 | frame = frame.reindex(sorted_players, level=0) 101 | inputs.append(frame.values) 102 | hit_output = [bool(getattr(hit, category)) for category in HIT_CATEGORIES] 103 | outputs.append(hit_output) 104 | if weight_method == WeightMethod.EMPHASISE_SHOTS: 105 | weights.append(10 if hit.shot else 1) 106 | 107 | input_ = np.array(inputs, dtype=np.float32) 108 | output = np.array(outputs, dtype=np.float32) 109 | if weight_method == WeightMethod.EMPHASISE_SHOTS: 110 | weights = np.array(weights, dtype=np.float32) 111 | 112 | logger.debug(f'Got input and output: input shape: {input_.shape}, output shape:{output.shape}') 113 | assert not np.any(np.isnan(input_)), "input contains nan" 114 | assert not np.any(np.isnan(output)), "output contains nan" 115 | 116 | assert INPUT_FEATURES in input_.shape, f"input has shape {input_.shape}, expected: {INPUT_FEATURES}." 
117 | 118 | if weight_method == WeightMethod.OUTPUT_CATEGORY: 119 | weights = get_sample_weight(output) 120 | 121 | return input_, output, weights 122 | 123 | 124 | def get_sample_weight(output: np.ndarray): 125 | weights = np.ones_like(output) 126 | for output_category in (0, 1): 127 | category_mask = output == output_category 128 | if category_mask.any(): 129 | weights[category_mask] = 1 / np.sum(category_mask) 130 | return weights.flatten() 131 | 132 | 133 | INPUT_FEATURES = 79 134 | # HIT_CATEGORIES = ['pass_', 'passed', 'dribble', 'dribble_continuation', 'shot', 'goal', 'assist', 'assisted', 135 | # 'save', 'aerial'] 136 | # HIT_CATEGORIES = ['pass_', 'shot', 'goal', 'aerial'] 137 | HIT_CATEGORIES = ['goal'] 138 | # HIT_CATEGORIES = ['shot'] 139 | 140 | 141 | if __name__ == '__main__': 142 | model = XGoalsConvModel(INPUT_FEATURES, len(HIT_CATEGORIES)) 143 | 144 | interfacer = CalculatedApiInterfacer( 145 | # CalculatedApiQueryParams(playlist=13, minmmr=1350, maxmmr=1550, 146 | CalculatedApiQueryParams(playlist=13, minmmr=1500, 147 | start_timestamp=int(datetime.datetime(2018, 11, 1).timestamp())) 148 | ) 149 | sequence = CalculatedSequence( 150 | interfacer=interfacer, 151 | game_data_transformer=partial(get_input_and_output_from_game_datas, weight_method=WeightMethod.OUTPUT_CATEGORY), 152 | ) 153 | 154 | EVAL_COUNT = 100 155 | eval_sequence = sequence.create_eval_sequence(EVAL_COUNT) 156 | eval_inputs, eval_outputs, _UNUSED_eval_weights = eval_sequence.as_arrays() 157 | 158 | save_callback = ModelCheckpoint('x_goals.{epoch:02d}-{val_acc:.5f}.hdf5', monitor='val_acc', save_best_only=True) 159 | classificaion_metrics = ClassificationMetrics((eval_inputs, eval_outputs)) 160 | callbacks = [ 161 | classificaion_metrics, 162 | MetricTracer(), 163 | save_callback, 164 | get_tensorboard(), 165 | # PredictionPlotter(model.model), 166 | ] 167 | model.fit_generator(sequence, 168 | steps_per_epoch=1000, 169 | validation_data=eval_sequence, epochs=1000, 
callbacks=callbacks, 170 | workers=4, use_multiprocessing=True) 171 | -------------------------------------------------------------------------------- /x_things/x_things_model.py: -------------------------------------------------------------------------------- 1 | from tensorflow.python.keras import Sequential 2 | from tensorflow.python.keras.layers import Dense, Dropout 3 | from tensorflow.python.keras.optimizers import Adam 4 | 5 | from models.base_model import BaseModel 6 | 7 | 8 | class XThingsModel(BaseModel): 9 | def __init__(self, inputs: int, outputs: int): 10 | super().__init__(inputs, outputs) 11 | 12 | def build_model(self) -> Sequential: 13 | model = Sequential([ 14 | Dense(256, input_dim=self.inputs, activation='relu'), 15 | Dropout(0.5), 16 | Dense(256, activation='relu'), 17 | Dropout(0.5), 18 | Dense(128, activation='relu'), 19 | Dense(self.outputs, activation='sigmoid') 20 | ]) 21 | 22 | optimizer = Adam(lr=1e-3) 23 | 24 | model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy']) 25 | return model 26 | --------------------------------------------------------------------------------