├── README.md
├── GameStats.py
├── baseballStatsFormatGuide.txt
├── learningBaseball.ipynb
└── NueralNetwork.ipynb


/README.md:
--------------------------------------------------------------------------------
1 | # Machine-Learning-Baseball ⚾
2 |
3 | ## Baseball
4 | The movie Moneyball, which is based on a true story, shows that in-game baseball statistics can be collected and analyzed in a way that provides accurate answers to specific questions. This relies on the fact that, over the course of a season, teams fall into patterns and react to factors in a repeatable manner, which ultimately affects their in-game performance. Essentially, MLB is one large complex system with feedback, stocks, and other system qualities; as a result, it can theoretically be understood.
5 |
6 | ## Hypothesis
7 | We theorized that there is a relationship between the statistics of a game and its outcome. As a result, the group focused on implementing a model that predicts the score of a particular game using the statistics of that game.
8 |
9 | ## Model Overview
10 | Both teams in a game are given individual ID values, which are made into vectors. Relevant data such as the home and away teams, home runs, RBIs, and walks are all taken into account and passed through the layers. There's no need to reinvent the wheel here: many libraries let a coder implement machine learning theories efficiently. In this case we use a library called TFLearn; documentation is available from http://tflearn.org. The program outputs the home and away teams as well as their respective score predictions.
11 |
12 | ## Implementation
13 | As mentioned earlier, the model was built using TFLearn, which is a high-level API for TensorFlow. The model's input data is the statistics, scores, and matchups of the 2015 and 2016 baseball seasons.
The model learns which statistics are useful for determining a score; it also distinguishes the teams by feeding each team's ID into a separate layer first. To train the model we used backpropagation with gradient descent, which TFLearn handled during the training process. An in-depth description of the model is given in the notebook that is provided with the report. We chose to train the model on the 2015 season, using that data to learn which statistics are important in a game, so that the trained model could then make predictions. We applied the model to the 2016 season: for each game we gave the model the statistics, scores, and teams in that game to base a prediction on, even though the statistics are not determined until after a game. We originally tried using each team's average statistics over all of its previous games as input; however, these predictions were no better than a coin flip.
14 |
15 | ## Results
16 | In the end the model predicted games quite well. Predicted scores were within 1 or 2 runs of the actual scores, and the model picked the winner of a game about 90% of the time. The high accuracy of the model supports the hypothesis that baseball game statistics are highly correlated with the final score of the game. For example, the number of home runs a team hits in a game has a direct effect on how high its score is. Also, when a team gets more hits than its opponent, it has a higher probability of scoring more runs than the other team. On defense, if a team completes a substantial number of double plays in a game, it demonstrates that its defense is effective and that the other team will have a harder time scoring runs. The neural net was able to detect these subtle relationships to make effective predictions of who wins the game and what the score is.
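As a hedged illustration of how the two figures quoted in the Results could be measured, the sketch below compares predicted and actual scores game by game. It is not the notebook's evaluation code: the helper name `evaluate_predictions` and the sample scores are made up for the example.

```python
def evaluate_predictions(predicted, actual):
    """Return (winner_accuracy, mean_abs_error) for score predictions.

    Each argument is a list of (home_score, away_score) pairs, one per game.
    Hypothetical helper for illustration; it does not come from the repository.
    """
    def winner(home, away):
        # +1 for a home win, -1 for an away win, 0 for a tie
        return (home > away) - (home < away)

    correct = 0
    abs_error = 0.0
    for (pred_home, pred_away), (act_home, act_away) in zip(predicted, actual):
        if winner(pred_home, pred_away) == winner(act_home, act_away):
            correct += 1
        abs_error += abs(pred_home - act_home) + abs(pred_away - act_away)

    n_games = len(actual)
    # Two scores are predicted per game, so the error is averaged over 2 * n_games
    return correct / n_games, abs_error / (2 * n_games)


# Made-up scores, for illustration only
predicted = [(4.6, 2.1), (1.8, 3.9), (5.2, 5.9)]
actual = [(5, 2), (2, 3), (7, 4)]
accuracy, error = evaluate_predictions(predicted, actual)
print(accuracy, error)
```

Reporting the winner accuracy and the per-score error separately matters because a model can pick the right winner while still missing the run totals by a wide margin.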
17 |
18 |
19 | ## Resources
20 | *********
21 | Baseball data source
22 | http://www.retrosheet.org/gamelogs/index.html
23 |
24 | TensorFlow library
25 | https://www.tensorflow.org
26 |
27 | TFLearn
28 | http://tflearn.org
29 |


--------------------------------------------------------------------------------
/GameStats.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class GameStats(object):
4 |
5 |     def __init__(self, homeTeamNameIndex, homeTeamScoreIndex, homeTeamStatsIndex, visitorTeamNameIndex, visitorTeamScoreIndex, visitorTeamStatsIndex):
6 |         #parse the text file
7 |         self.statsFile = open("baseball2016.txt", "r")
8 |         self.topArray = []
9 |         self.sideArray = []
10 |         self.sc = np.zeros((30,30,30), np.int32) #up to 30 scores for each of the 30x30 matchups
11 |         self.sc[:,:,:] = -1 #-1 marks an unused score slot
12 |         self.am = np.zeros((30,30), np.float32) #average score per matchup
13 |         self.gameList = []
14 |
15 |         for line in self.statsFile:
16 |             homeTeam = ""
17 |             awayTeam = ""
18 |             homeScore = 0
19 |             awayScore = 0
20 |
21 |             token = line.split(',') #tokenize the string
22 |             tokenIndex = [homeTeamNameIndex, homeTeamScoreIndex, visitorTeamNameIndex, visitorTeamScoreIndex] + [i for i in homeTeamStatsIndex] + [i for i in visitorTeamStatsIndex]
23 |             attributes = dict()
24 |
25 |             for i in range(len(token)):
26 |                 if(i in tokenIndex):
27 |                     attributes[i] = self.removeQuotes(token[i])
28 |
29 |             self.addScore(attributes[homeTeamNameIndex], attributes[visitorTeamNameIndex], attributes[homeTeamScoreIndex], attributes[visitorTeamScoreIndex])
30 |
31 |             self.addGame(attributes[homeTeamNameIndex], attributes[homeTeamScoreIndex], [attributes[i] for i in homeTeamStatsIndex], attributes[visitorTeamNameIndex], attributes[visitorTeamScoreIndex], [attributes[i] for i in visitorTeamStatsIndex])
32 |
33 |         self.buildAvgMatrix()
34 |         self.statsFile.close()
35 |
36 |     def removeQuotes(self, string):
37 |         if (string.startswith('"') and string.endswith('"')) or (string.startswith("'") and string.endswith("'")):
38 |             #strip the matching pair of surrounding quotes
39 |             return 
string[1:-1]
40 |         return string
41 |
42 |     def addGame(self, team1, score1, stats1, team2, score2, stats2):
43 |         self.gameList.append([team1, score1, stats1, team2, score2, stats2])
44 |
45 |     #give it two teams, the scores, and it will add it to the matrix
46 |     def addScore(self, team1, team2, score1, score2):
47 |         '''
48 |         for a team in top array, the index in the array corresponds to the matrix column it is located in
49 |         for a team in side array, the index in the array corresponds to the matrix row it is located in
50 |         '''
51 |         #team 1 score entry
52 |         try:
53 |             row = self.sideArray.index(team2)
54 |
55 |         except ValueError: #team not seen yet, register it
56 |             self.sideArray.append(team2)
57 |             row = self.sideArray.index(team2)
58 |
59 |         try:
60 |             col = self.topArray.index(team1)
61 |         except ValueError:
62 |             self.topArray.append(team1)
63 |             col = self.topArray.index(team1)
64 |         temp = self.sc[row, col]
65 |         counter = 0
66 |         for e in temp: #store the score in the first unused slot
67 |             if (e == -1):
68 |                 temp[counter] = score1
69 |                 break
70 |             counter += 1
71 |         self.sc[row, col] = temp
72 |
73 |         #team 2 score entry
74 |         try:
75 |             row = self.sideArray.index(team1)
76 |         except ValueError:
77 |             self.sideArray.append(team1)
78 |             row = self.sideArray.index(team1)
79 |
80 |         try:
81 |             col = self.topArray.index(team2)
82 |         except ValueError:
83 |             self.topArray.append(team2)
84 |             col = self.topArray.index(team2)
85 |         temp = self.sc[row, col]
86 |         counter = 0
87 |         for e in temp:
88 |             if (e == -1):
89 |                 temp[counter] = score2
90 |                 break
91 |             counter += 1
92 |         self.sc[row, col] = temp
93 |
94 |     #prints the score(s) for a match up
95 |     def getScore(self, team1, team2, gameSelect = None):
96 |         print(team1, team2)
97 |         try:
98 |             score1 = self.sc[self.sideArray.index(team2), self.topArray.index(team1)]
99 |             score2 = self.sc[self.sideArray.index(team1), self.topArray.index(team2)]
100 |             if (gameSelect is None):
101 |                 print(team1, score1)
102 |                 print(team2, score2)
103 |             else:
104 |                 print(team1, score1[gameSelect])
105 |                 print(team2, score2[gameSelect])
106 |         except (ValueError, IndexError): #unknown team or game number out of range
107 |             print('Invalid input of 
teams')
108 |
109 |     def getGameList(self):
110 |         return self.gameList
111 |
112 |     #constructs a matrix of the avg score in a matchup
113 |     def buildAvgMatrix(self):
114 |         for col in range(len(self.sc[:,0])): #depth
115 |             for row in range(len(self.sc[0, :])): #width
116 |                 tempScore = self.sc[row, col]
117 |                 avgScore = 0.0
118 |                 count = 0.0
119 |                 for j in tempScore:
120 |                     if (j != -1):
121 |                         avgScore += j
122 |                         count += 1
123 |                     else:
124 |                         break
125 |                 try:
126 |                     avgScore = avgScore / count
127 |                 except ZeroDivisionError: #no games recorded for this matchup
128 |                     avgScore = -1
129 |                 self.am[row, col] = avgScore
130 |
131 |     #print the value of the avg score for a match up
132 |     def getAvgScore(self, team1, team2):
133 |         try:
134 |             score1 = self.am[self.sideArray.index(team2), self.topArray.index(team1)]
135 |             score2 = self.am[self.sideArray.index(team1), self.topArray.index(team2)]
136 |             print(team1, score1)
137 |             print(team2, score2)
138 |         except ValueError: #unknown team
139 |             print('Invalid input of teams')


--------------------------------------------------------------------------------
/baseballStatsFormatGuide.txt:
--------------------------------------------------------------------------------
1 | Field(s)  Meaning
2 | 1         Date in the form "yyyymmdd"
3 | 2         Number of game:
4 |              "0" -- a single game
5 |              "1" -- the first game of a double (or triple) header
6 |                     including separate admission doubleheaders
7 |              "2" -- the second game of a double (or triple) header
8 |                     including separate admission doubleheaders
9 |              "3" -- the third game of a triple-header
10 |              "A" -- the first game of a double-header involving 3 teams
11 |              "B" -- the second game of a double-header involving 3 teams
12 | 3         Day of week ("Sun","Mon","Tue","Wed","Thu","Fri","Sat")
13 | 4-5       Visiting team and league
14 | 6         Visiting team game number
15 |           For this and the home team game number, ties are counted as
16 |           games and suspended games are counted from the starting
17 |           rather than the ending date.
18 | 7-8       Home team and league
19 | 9         Home team game number
20 | 10-11     Visiting and home team score (unquoted)
21 | 12        Length of game in outs (unquoted). A full 9-inning game would
22 |           have a 54 in this field. If the home team won without batting
23 |           in the bottom of the ninth, this field would contain a 51.
24 | 13        Day/night indicator ("D" or "N")
25 | 14        Completion information. If the game was completed at a
26 |           later date (either due to a suspension or an upheld protest)
27 |           this field will include:
28 |              "yyyymmdd,park,vs,hs,len" Where
29 |                 yyyymmdd -- the date the game was completed
30 |                 park -- the park ID where the game was completed
31 |                 vs -- the visitor score at the time of interruption
32 |                 hs -- the home score at the time of interruption
33 |                 len -- the length of the game in outs at time of interruption
34 |           All the rest of the information in the record refers to the
35 |           entire game.
36 | 15        Forfeit information:
37 |              "V" -- the game was forfeited to the visiting team
38 |              "H" -- the game was forfeited to the home team
39 |              "T" -- the game was ruled a no-decision
40 | 16        Protest information:
41 |              "P" -- the game was protested by an unidentified team
42 |              "V" -- a disallowed protest was made by the visiting team
43 |              "H" -- a disallowed protest was made by the home team
44 |              "X" -- an upheld protest was made by the visiting team
45 |              "Y" -- an upheld protest was made by the home team
46 |           Note: two of these last four codes can appear in the field
47 |           (if both teams protested the game).
48 | 17        Park ID
49 | 18        Attendance (unquoted)
50 | 19        Time of game in minutes (unquoted)
51 | 20-21     Visiting and home line scores. For example:
52 |              "010000(10)0x"
53 |           Would indicate a game where the home team scored a run in
54 |           the second inning, ten in the seventh and didn't bat in the
55 |           bottom of the ninth.
56 | 22-38     Visiting team offensive statistics (unquoted) (in order):
57 |              at-bats
58 |              hits
59 |              doubles
60 |              triples
61 |              homeruns
62 |              RBI
63 |              sacrifice hits. This may include sacrifice flies for years
64 |                 prior to 1954 when sacrifice flies were allowed.
65 |              sacrifice flies (since 1954)
66 |              hit-by-pitch
67 |              walks
68 |              intentional walks
69 |              strikeouts
70 |              stolen bases
71 |              caught stealing
72 |              grounded into double plays
73 |              awarded first on catcher's interference
74 |              left on base
75 | 39-43     Visiting team pitching statistics (unquoted) (in order):
76 |              pitchers used ( 1 means it was a complete game )
77 |              individual earned runs
78 |              team earned runs
79 |              wild pitches
80 |              balks
81 | 44-49     Visiting team defensive statistics (unquoted) (in order):
82 |              putouts. Note: prior to 1931, this may not equal 3 times
83 |                 the number of innings pitched. Prior to that, no
84 |                 putout was awarded when a runner was declared out for
85 |                 being hit by a batted ball.
86 |              assists
87 |              errors
88 |              passed balls
89 |              double plays
90 |              triple plays
91 | 50-66     Home team offensive statistics
92 | 67-71     Home team pitching statistics
93 | 72-77     Home team defensive statistics
94 | 78-79     Home plate umpire ID and name
95 | 80-81     1B umpire ID and name
96 | 82-83     2B umpire ID and name
97 | 84-85     3B umpire ID and name
98 | 86-87     LF umpire ID and name
99 | 88-89     RF umpire ID and name
100 |           If any umpire positions were not filled for a particular game
101 |           the fields will be "","(none)".
102 | 90-91     Visiting team manager ID and name
103 | 92-93     Home team manager ID and name
104 | 94-95     Winning pitcher ID and name
105 | 96-97     Losing pitcher ID and name
106 | 98-99     Saving pitcher ID and name--"","(none)" if none awarded
107 | 100-101   Game Winning RBI batter ID and name--"","(none)" if none
108 |           awarded
109 | 102-103   Visiting starting pitcher ID and name
110 | 104-105   Home starting pitcher ID and name
111 | 106-132   Visiting starting players ID, name and defensive position,
112 |           listed in the order (1-9) they appeared in the batting order.
113 | 133-159   Home starting players ID, name and defensive position
114 |           listed in the order (1-9) they appeared in the batting order.
115 | 160       Additional information. This is a grab-bag of informational
116 |           items that might not warrant a field on their own. The field
117 |           is alpha-numeric. Some items are represented by tokens such as:
118 |              "HTBF" -- home team batted first.
119 |              Note: if "HTBF" is specified it would be possible to see
120 |              something like "01002000x" in the visitor's line score.
121 |           Changes in umpire positions during a game will also appear in
122 |           this field. These will be in the form:
123 |              umpchange,inning,umpPosition,umpid with the latter three
124 |              repeated for each umpire.
125 |           These changes occur with umpire injuries, late arrival of
126 |           umpires or changes from completion of suspended games. Details
127 |           of suspended games are in field 14.
128 | 161       Acquisition information:
129 |              "Y" -- we have the complete game
130 |              "N" -- we don't have any portion of the game
131 |              "D" -- the game was derived from box score and game story
132 |              "P" -- we have some portion of the game. We may be missing
133 |                     innings at the beginning, middle and end of the game.
134 |
135 | Missing fields will be NULL.
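A minimal sketch of reading one record with this layout, assuming the game log is a comma-separated text file as the guide implies. The field numbers in the guide are 1-based, so the matching Python index is one less. The helper name `parse_game` and the sample record are hypothetical and only cover the first 14 fields. The standard `csv` module is used because a quoted field such as field 14 can itself contain commas, which a plain `line.split(',')` would cut into several pieces and shift every later index.

```python
import csv
import io

# 1-based field numbers from the guide, converted to 0-based indexes
VISITOR_TEAM = 4 - 1
HOME_TEAM = 7 - 1
VISITOR_SCORE = 10 - 1
HOME_SCORE = 11 - 1
COMPLETION_INFO = 14 - 1


def parse_game(line):
    """Parse one game-log record into a small dict (hypothetical helper)."""
    # csv.reader keeps a quoted field together even when it contains commas
    fields = next(csv.reader(io.StringIO(line)))
    return {
        "visitor": fields[VISITOR_TEAM],
        "home": fields[HOME_TEAM],
        "visitor_score": int(fields[VISITOR_SCORE]),
        "home_score": int(fields[HOME_SCORE]),
        "completion": fields[COMPLETION_INFO],
    }


# Made-up record truncated to the first 14 fields, for illustration only
sample = ('"20160403","0","Sun","AAA","NL",1,"BBB","NL",1,4,1,54,"N",'
          '"20160404,XXX01,2,1,30"')
game = parse_game(sample)
print(game)
```

The same record split with `sample.split(',')` yields 18 pieces instead of 14, which is why the index of every field after a populated field 14 would be off by four.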
136 |  -------------------------------------------------------------------------------- /learningBaseball.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": { 7 | "collapsed": false, 8 | "deletable": true, 9 | "editable": true 10 | }, 11 | "outputs": [], 12 | "source": [ 13 | "import numpy as np\n", 14 | "\n", 15 | "class GameStats(object): \n", 16 | " \n", 17 | " def __init__(self, homeTeamNameIndex, homeTeamScoreIndex, homeTeamStatsIndex, visitorTeamNameIndex, visitorTeamScoreIndex, visitorTeamStatsIndex):\n", 18 | " #parse the text file\n", 19 | " self.statsFile = open(\"baseball2016.txt\", \"r\")\n", 20 | " self.topArray = []\n", 21 | " self.sideArray = [] \n", 22 | " self.sc = np.zeros((30,30,30), np.int32) \n", 23 | " self.sc[:,:,:] = -1 \n", 24 | " self.am = np.zeros((30,30), np.float32)\n", 25 | " self.gameList = []\n", 26 | " \n", 27 | " seenTeams = []\n", 28 | " for line in self.statsFile:\n", 29 | " token = line.split(',') #tokenize the string\n", 30 | " tokenIndex = [homeTeamNameIndex, homeTeamScoreIndex, visitorTeamNameIndex, visitorTeamScoreIndex] + [i for i in homeTeamStatsIndex] + [i for i in visitorTeamStatsIndex]\n", 31 | " attributes = dict()\n", 32 | " \n", 33 | " for i in range(len(token)):\n", 34 | " if(i in tokenIndex):\n", 35 | " attributes[i] = self.removeQuotes(token[i])\n", 36 | " \n", 37 | " self.addScore(attributes[homeTeamNameIndex], attributes[visitorTeamNameIndex], attributes[homeTeamScoreIndex], attributes[visitorTeamScoreIndex]) \n", 38 | " self.addGame(attributes[homeTeamNameIndex], attributes[homeTeamScoreIndex], [attributes[i] for i in homeTeamStatsIndex], attributes[visitorTeamNameIndex], attributes[visitorTeamScoreIndex], [attributes[i] for i in visitorTeamStatsIndex])\n", 39 | " \n", 40 | " if(attributes[homeTeamNameIndex] not in seenTeams):\n", 41 | " 
seenTeams.append(attributes[homeTeamNameIndex])\n", 42 | " \n", 43 | " self.buildAvgMatrix()\n", 44 | " self.statsFile.close()\n", 45 | " #self.gameList = [bin(2**i)[2:].zfill(len(seenTeams)) if x == seenTeams[i] else x for i in range(len(seenTeams)) for x in self.gameList]\n", 46 | " # take the teams and convert the indexes of the order that they appeared in into the hot one format by rasiing\n", 47 | " # 2 to the power of them and then converting them into binary.\n", 48 | " seenTeamsDict = {k: v for v, k in enumerate(seenTeams)}\n", 49 | " temp = []\n", 50 | " for x in self.gameList:\n", 51 | " tempi = []\n", 52 | " for z in x:\n", 53 | " if(z in seenTeams):\n", 54 | " tempi.append(bin(2**seenTeamsDict[z])[2:].zfill(len(seenTeamsDict)))\n", 55 | " else:\n", 56 | " tempi.append(z)\n", 57 | " temp.append(tempi)\n", 58 | " self.gameList = temp\n", 59 | " \n", 60 | " def removeQuotes(self, string):\n", 61 | " if (string.startswith('\"') and string.endswith('\"')) or (string.startswith(\"'\") and string.endswith(\"'\")):\n", 62 | " return string[1:-1]\n", 63 | " return string \n", 64 | " \n", 65 | " def addGame(self, team1, score1, stats1, team2, score2, stats2):\n", 66 | " self.gameList.append([team1, score1, stats1, team2, score2, stats2])\n", 67 | " \n", 68 | " #give it two teams, the scores, and it will add it to the matrix\n", 69 | " def addScore(self, team1, team2, score1, score2):\n", 70 | " '''\n", 71 | " for a team in top array, the index in the array corrisponds to the matrix column there located in\n", 72 | " for a team in side array, the index in the array corrisponds to the matrix row there located in\n", 73 | " '''\n", 74 | " #team 1 score entry\n", 75 | " try:\n", 76 | " row = self.sideArray.index(team2) \n", 77 | "\n", 78 | " except:\n", 79 | " self.sideArray.append(team2)\n", 80 | " row = self.sideArray.index(team2) \n", 81 | "\n", 82 | " try:\n", 83 | " col = self.topArray.index(team1)\n", 84 | " except:\n", 85 | " self.topArray.append(team1)\n", 86 
| " col = self.topArray.index(team1)\n", 87 | " temp = self.sc[row, col]\n", 88 | " counter = 0\n", 89 | " for e in temp:\n", 90 | " if (e == -1):\n", 91 | " temp[counter] = score1\n", 92 | " break\n", 93 | " counter += 1\n", 94 | " self.sc[row, col] = temp\n", 95 | " \n", 96 | " #team 2 score entry\n", 97 | " try:\n", 98 | " row = self.sideArray.index(team1) \n", 99 | " except:\n", 100 | " self.sideArray.append(team1)\n", 101 | " row = self.sideArray.index(team1) \n", 102 | " \n", 103 | " try:\n", 104 | " col = self.topArray.index(team2)\n", 105 | " except:\n", 106 | " self.topArray.append(team2)\n", 107 | " col = self.topArray.index(team2)\n", 108 | " temp = self.sc[row, col]\n", 109 | " counter = 0\n", 110 | " for e in temp:\n", 111 | " if (e == -1):\n", 112 | " temp[counter] = score2\n", 113 | " break\n", 114 | " counter += 1\n", 115 | " self.sc[row, col] = temp\n", 116 | " \n", 117 | " #returns the score(s) for match up\n", 118 | " def getScore(self, team1, team2, gameSelect = None):\n", 119 | " print(team1, team2)\n", 120 | " try:\n", 121 | " score1 = self.sc[self.sideArray.index(team2), self.topArray.index(team1)]\n", 122 | " score2 = self.sc[self.sideArray.index(team1), self.topArray.index(team2)]\n", 123 | " if (gameSelect == None):\n", 124 | " print(team1, score1)\n", 125 | " print(team2, score2)\n", 126 | " else:\n", 127 | " print(team1, score1[gameSelect])\n", 128 | " print(team2, score2[gameSelect])\n", 129 | " except:\n", 130 | " print('Invalid input of teams')\n", 131 | " \n", 132 | " def getGameList(self):\n", 133 | " return self.gameList\n", 134 | " \n", 135 | " #constructs a matrix of the avg score in a matchup\n", 136 | " def buildAvgMatrix(self): \n", 137 | " for col in range(len(self.sc[:,0])): #depth\n", 138 | " for row in range(len(self.sc[0, :])): #width\n", 139 | " tempScore = self.sc[row, col]\n", 140 | " avgScore = 0.0\n", 141 | " count = 0.0\n", 142 | " for j in tempScore:\n", 143 | " if (j != -1):\n", 144 | " avgScore += j\n", 145 | " 
count += 1\n", 146 | " else:\n", 147 | " break\n", 148 | " try:\n", 149 | " avgScore = avgScore / count\n", 150 | " except:\n", 151 | " avgScore = -1\n", 152 | " self.am[row, col] = avgScore\n", 153 | " \n", 154 | " #get the value of the avg score for a match up\n", 155 | " def getAvgScore(self, team1, team2):\n", 156 | " try:\n", 157 | " score1 = self.am[self.sideArray.index(team2), self.topArray.index(team1)]\n", 158 | " score2 = self.am[self.sideArray.index(team1), self.topArray.index(team2)]\n", 159 | " print(team1, score1)\n", 160 | " print(team2, score2) \n", 161 | " except:\n", 162 | " print('Invalid input of teams')" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 24, 168 | "metadata": { 169 | "collapsed": false, 170 | "deletable": true, 171 | "editable": true, 172 | "scrolled": true 173 | }, 174 | "outputs": [ 175 | { 176 | "ename": "TypeError", 177 | "evalue": "dropout() missing 1 required positional argument: 'keep_prob'", 178 | "output_type": "error", 179 | "traceback": [ 180 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 181 | "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)", 182 | "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[0;32m 15\u001b[0m \u001b[0mnet\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtflearn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfully_connected\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnet\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m15\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 16\u001b[0m \u001b[1;31m# The output layer\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 17\u001b[1;33m \u001b[0mnet\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtflearn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnet\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 18\u001b[0m \u001b[0mnet\u001b[0m 
\u001b[1;33m=\u001b[0m \u001b[0mtflearn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfully_connected\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnet\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mgameList\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 19\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", 183 | "\u001b[1;31mTypeError\u001b[0m: dropout() missing 1 required positional argument: 'keep_prob'" 184 | ] 185 | } 186 | ], 187 | "source": [ 188 | "import tensorflow as tf\n", 189 | "import tflearn\n", 190 | "\n", 191 | "tf.reset_default_graph()\n", 192 | "\n", 193 | "# Parse the game files and grab the stats we want\n", 194 | "# the indexes are put in the order: homeTeamName, homeTeamScore, homeTeamStats, awayName, awayScore, awayStats\n", 195 | "stats = GameStats(3, 10, [51, 55], 6, 9, [23, 27])\n", 196 | "gameList = stats.getGameList() # get the list of games\n", 197 | "\n", 198 | "# the input layer needs to have the same dimensions as our input (in this case the teams)\n", 199 | "net = tflearn.input_data(shape=[None, len(gameList[0][0])], name='input')\n", 200 | "# next we have the hidden layer it is the feature matrix the size is arbitrary.\n", 201 | "# I chose 15 becasue that is what the matrix factorization feature size is set to.\n", 202 | "net = tflearn.fully_connected(net, 15)\n", 203 | "# The output layer\n", 204 | "net = tflearn.fully_connected(net, len(gameList[0][2]))\n", 205 | "\n", 206 | "net = tflearn.regression(net, name='target', learning_rate=0.9)\n", 207 | "\n", 208 | "# Note(daniel): right now we don't care what teams played eachother (we may in the future) therefore all of the teams are put\n", 209 | "# into one array.\n", 210 | "\n", 211 | "# take only the team names and split them into individual parts (ie 0100 becomes [0], [1], [0], 
[0])\n", 212 | "NNInput = [list(i[x]) for i in gameList for x in range(len(i)) if x == 0 or x == 3]\n", 213 | "NNInput = np.array(NNInput)\n", 214 | "\n", 215 | "# take only the stats for each team and put them into an array\n", 216 | "NNOutput = [[7, 7] for i in gameList for x in range(len(i)) if x == 2 or x == 5]\n", 217 | "NNOutput = np.array(NNOutput)\n", 218 | "\n", 219 | "# Define model\n", 220 | "model = tflearn.DNN(net)\n", 221 | "# Start training (apply gradient descent algorithm)\n", 222 | "model.fit(NNInput, NNOutput, validation_set=0.1, n_epoch=20, show_metric=True)" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 2, 228 | "metadata": { 229 | "collapsed": false, 230 | "deletable": true, 231 | "editable": true 232 | }, 233 | "outputs": [ 234 | { 235 | "ename": "NameError", 236 | "evalue": "name 'NNInput' is not defined", 237 | "output_type": "error", 238 | "traceback": [ 239 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 240 | "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", 241 | "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[0;32m 8\u001b[0m \u001b[0mtestIndex\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 9\u001b[0m \u001b[0mclosest\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m1000\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 10\u001b[1;33m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mtqdm_notebook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mNNInput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 11\u001b[0m \u001b[0mpredict\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mNNInput\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 12\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", 242 | "\u001b[1;31mNameError\u001b[0m: name 'NNInput' is not defined" 243 | ] 244 | } 245 | ], 246 | "source": [ 247 | "# Prediction Test\n", 248 | "from tqdm import tqdm_notebook\n", 249 | "\n", 250 | "# trying to write all of the data seems to crash my browser\n", 251 | "# so I'm dumping the raw output to a file to read later.\n", 252 | "file = open('netOutput.txt', 'w')\n", 253 | "\n", 254 | "testIndex = 0\n", 255 | "closest = 1000\n", 256 | "for i in tqdm_notebook(range(len(NNInput))):\n", 257 | " predict = model.predict([NNInput[i]])\n", 258 | "\n", 259 | " if(((predict[0][0]-float(NNOutput[i][0]))+(predict[0][1]-float(NNOutput[i][1])))**2 < closest):\n", 260 | " closest = ((predict[0][0]-float(NNOutput[i][0]))+(predict[0][1]-float(NNOutput[i][1])))**2 \n", 261 | " testIndex = i \n", 262 | " \n", 263 | " file.write(\"hits, RBI\")\n", 264 | " file.write(\"\\n\")\n", 265 | " file.write(str(predict))\n", 266 | " file.write(\"\\n\")\n", 267 | " file.write(str(NNOutput[i]))\n", 268 | " file.write(\"\\n\")\n", 269 | " \n", 270 | "file.close()\n", 271 | "\n", 272 | "print(\"hits, RBI\")\n", 273 | "print(testIndex)\n", 274 | "print(model.predict([NNInput[testIndex]]))\n", 275 | "print(NNOutput[testIndex])" 276 | ] 277 | } 278 | ], 279 | "metadata": { 280 | "anaconda-cloud": {}, 281 | "kernelspec": { 282 | "display_name": "Python 3", 283 | "language": "python", 284 | "name": "python3" 285 | }, 286 | "language_info": { 287 | "codemirror_mode": { 288 | "name": "ipython", 289 | "version": 3 290 | }, 291 | "file_extension": ".py", 292 | "mimetype": "text/x-python", 293 | "name": "python", 294 | "nbconvert_exporter": "python", 295 | 
"pygments_lexer": "ipython3", 296 | "version": "3.5.3" 297 | } 298 | }, 299 | "nbformat": 4, 300 | "nbformat_minor": 2 301 | } 302 | -------------------------------------------------------------------------------- /NueralNetwork.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "deletable": true, 7 | "editable": true 8 | }, 9 | "source": [ 10 | "# Baseball learning network" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "metadata": { 17 | "collapsed": false, 18 | "deletable": true, 19 | "editable": true 20 | }, 21 | "outputs": [ 22 | { 23 | "name": "stdout", 24 | "output_type": "stream", 25 | "text": [ 26 | "hdf5 is not supported on this machine (please install/reinstall h5py for optimal experience)\n" 27 | ] 28 | } 29 | ], 30 | "source": [ 31 | "import numpy as np\n", 32 | "import tensorflow as tf\n", 33 | "import tflearn\n", 34 | "from tqdm import tqdm_notebook\n", 35 | "import copy\n", 36 | "from scipy import stats" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 1, 42 | "metadata": { 43 | "collapsed": true, 44 | "deletable": true, 45 | "editable": true 46 | }, 47 | "outputs": [], 48 | "source": [ 49 | "class GameStats(object): \n", 50 | " \n", 51 | " def __init__(self, homeTeamNameIndex, homeTeamScoreIndex, homeTeamStatsIndex, visitorTeamNameIndex, visitorTeamScoreIndex, visitorTeamStatsIndex):\n", 52 | " #parse the text file\n", 53 | " self.statsFile = open(\"baseball2016.txt\", \"r\")\n", 54 | " self.topArray = []\n", 55 | " self.sideArray = [] \n", 56 | " self.sc = np.zeros((30,30,30), np.int32) \n", 57 | " self.sc[:,:,:] = -1 \n", 58 | " self.am = np.zeros((30,30), np.float32)\n", 59 | " self.gameList = []\n", 60 | " \n", 61 | " seenTeams = []\n", 62 | " for line in self.statsFile:\n", 63 | " token = line.split(',') #tokenize the string\n", 64 | " tokenIndex = [homeTeamNameIndex, 
homeTeamScoreIndex, visitorTeamNameIndex, visitorTeamScoreIndex] + [i for i in homeTeamStatsIndex] + [i for i in visitorTeamStatsIndex]\n", 65 | " attributes = dict()\n", 66 | " \n", 67 | " for i in range(len(token)):\n", 68 | " if(i in tokenIndex):\n", 69 | " attributes[i] = self.removeQuotes(token[i])\n", 70 | " \n", 71 | " self.addScore(attributes[homeTeamNameIndex], attributes[visitorTeamNameIndex], attributes[homeTeamScoreIndex], attributes[visitorTeamScoreIndex]) \n", 72 | " self.addGame(attributes[homeTeamNameIndex], attributes[homeTeamScoreIndex], [attributes[i] for i in homeTeamStatsIndex], attributes[visitorTeamNameIndex], attributes[visitorTeamScoreIndex], [attributes[i] for i in visitorTeamStatsIndex])\n", 73 | " \n", 74 | " if(attributes[homeTeamNameIndex] not in seenTeams):\n", 75 | " seenTeams.append(attributes[homeTeamNameIndex])\n", 76 | " \n", 77 | " self.buildAvgMatrix()\n", 78 | " self.statsFile.close()\n", 79 | " #self.gameList = [bin(2**i)[2:].zfill(len(seenTeams)) if x == seenTeams[i] else x for i in range(len(seenTeams)) for x in self.gameList]\n", 80 | " # take the teams and convert the index of the order in which each appeared into one-hot format by raising\n", 81 | " # 2 to the power of that index and then converting the result into a zero-padded binary string.\n", 82 | " seenTeamsDict = {k: v for v, k in enumerate(seenTeams)}\n", 83 | " temp = []\n", 84 | " for x in self.gameList:\n", 85 | " tempi = []\n", 86 | " for z in x:\n", 87 | " if(z in seenTeams):\n", 88 | " tempi.append(bin(2**seenTeamsDict[z])[2:].zfill(len(seenTeamsDict)))\n", 89 | " else:\n", 90 | " tempi.append(z)\n", 91 | " temp.append(tempi)\n", 92 | " self.gameList = temp\n", 93 | " \n", 94 | " def removeQuotes(self, string):\n", 95 | " if (string.startswith('\"') and string.endswith('\"')) or (string.startswith(\"'\") and string.endswith(\"'\")):\n", 96 | " return string[1:-1]\n", 97 | " return string \n", 98 | " \n", 99 | " def addGame(self, team1, score1, stats1, team2, score2, stats2):\n", 100 | 
" self.gameList.append([team1, score1, stats1, team2, score2, stats2])\n", 101 | " \n", 102 | " #give it two teams and their scores, and it will add them to the matrix\n", 103 | " def addScore(self, team1, team2, score1, score2):\n", 104 | " '''\n", 105 | " for a team in the top array, its index in the array corresponds to the matrix column it is located in\n", 106 | " for a team in the side array, its index in the array corresponds to the matrix row it is located in\n", 107 | " '''\n", 108 | " #team 1 score entry\n", 109 | " try:\n", 110 | " row = self.sideArray.index(team2) \n", 111 | "\n", 112 | " except ValueError:\n", 113 | " self.sideArray.append(team2)\n", 114 | " row = self.sideArray.index(team2) \n", 115 | "\n", 116 | " try:\n", 117 | " col = self.topArray.index(team1)\n", 118 | " except ValueError:\n", 119 | " self.topArray.append(team1)\n", 120 | " col = self.topArray.index(team1)\n", 121 | " temp = self.sc[row, col]\n", 122 | " counter = 0\n", 123 | " for e in temp:\n", 124 | " if (e == -1):\n", 125 | " temp[counter] = score1\n", 126 | " break\n", 127 | " counter += 1\n", 128 | " self.sc[row, col] = temp\n", 129 | " \n", 130 | " #team 2 score entry\n", 131 | " try:\n", 132 | " row = self.sideArray.index(team1) \n", 133 | " except ValueError:\n", 134 | " self.sideArray.append(team1)\n", 135 | " row = self.sideArray.index(team1) \n", 136 | " \n", 137 | " try:\n", 138 | " col = self.topArray.index(team2)\n", 139 | " except ValueError:\n", 140 | " self.topArray.append(team2)\n", 141 | " col = self.topArray.index(team2)\n", 142 | " temp = self.sc[row, col]\n", 143 | " counter = 0\n", 144 | " for e in temp:\n", 145 | " if (e == -1):\n", 146 | " temp[counter] = score2\n", 147 | " break\n", 148 | " counter += 1\n", 149 | " self.sc[row, col] = temp\n", 150 | " \n", 151 | " #returns the score(s) for a matchup\n", 152 | " def getScore(self, team1, team2, gameSelect = None):\n", 153 | " print(team1, team2)\n", 154 | " try:\n", 155 | " score1 = self.sc[self.sideArray.index(team2), self.topArray.index(team1)]\n", 156 | " 
score2 = self.sc[self.sideArray.index(team1), self.topArray.index(team2)]\n", 157 | " if (gameSelect is None):\n", 158 | " print(team1, score1)\n", 159 | " print(team2, score2)\n", 160 | " else:\n", 161 | " print(team1, score1[gameSelect])\n", 162 | " print(team2, score2[gameSelect])\n", 163 | " except:\n", 164 | " print('Invalid input of teams')\n", 165 | " \n", 166 | " def getGameList(self):\n", 167 | " return copy.deepcopy(self.gameList)\n", 168 | " \n", 169 | " #constructs a matrix of the avg score in a matchup\n", 170 | " def buildAvgMatrix(self): \n", 171 | " for col in range(len(self.sc[:,0])): #depth\n", 172 | " for row in range(len(self.sc[0, :])): #width\n", 173 | " tempScore = self.sc[row, col]\n", 174 | " avgScore = 0.0\n", 175 | " count = 0.0\n", 176 | " for j in tempScore:\n", 177 | " if (j != -1):\n", 178 | " avgScore += j\n", 179 | " count += 1\n", 180 | " else:\n", 181 | " break\n", 182 | " try:\n", 183 | " avgScore = avgScore / count\n", 184 | " except:\n", 185 | " avgScore = -1\n", 186 | " self.am[row, col] = avgScore\n", 187 | " \n", 188 | " #get the value of the avg score for a matchup\n", 189 | " def getAvgScore(self, team1, team2):\n", 190 | " try:\n", 191 | " score1 = self.am[self.sideArray.index(team2), self.topArray.index(team1)]\n", 192 | " score2 = self.am[self.sideArray.index(team1), self.topArray.index(team2)]\n", 193 | " print(team1, score1)\n", 194 | " print(team2, score2) \n", 195 | " except:\n", 196 | " print('Invalid input of teams')" 197 | ] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | "metadata": { 202 | "deletable": true, 203 | "editable": true 204 | }, 205 | "source": [ 206 | "Extract the data from the file." 
207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": 3, 212 | "metadata": { 213 | "collapsed": true, 214 | "deletable": true, 215 | "editable": true 216 | }, 217 | "outputs": [], 218 | "source": [ 219 | "gameStats = GameStats(6, 10, [8, 50, 53, 54, 58], 3, 9, [5, 22, 25, 26, 30])\n", 220 | "gameList = gameStats.getGameList() # get the list of games\n", 221 | "#convert the score and stat strings to numbers\n", 222 | "def removeQuotes(gameList):\n", 223 | " for row in gameList:\n", 224 | " for x in range(len(row)):\n", 225 | " #convert score strings to floats\n", 226 | " if (x == 1 or x == 4): row[x] = float(row[x])\n", 227 | " #convert the stats arrays to lists of floats\n", 228 | " if (x == 2 or x == 5): row[x] = list(map(float, row[x]))\n", 229 | " return gameList\n" 230 | ] 231 | }, 232 | { 233 | "cell_type": "markdown", 234 | "metadata": { 235 | "deletable": true, 236 | "editable": true 237 | }, 238 | "source": [ 239 | "Write all the data used to a file named 'gameList.txt' so the user can see what stats we're working with." 
240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": 4, 245 | "metadata": { 246 | "collapsed": false, 247 | "deletable": true, 248 | "editable": true 249 | }, 250 | "outputs": [], 251 | "source": [ 252 | "gameListFile = open(\"gameList.txt\", \"w\")\n", 253 | "for row in gameList:\n", 254 | " gameListFile.write('\\n-------------------------------------------')\n", 255 | " gameListFile.write('\\nHome: ')\n", 256 | " gameListFile.write(row[0])\n", 257 | " gameListFile.write('\\nScore: ')\n", 258 | " gameListFile.write(str(row[1]))\n", 259 | " \n", 260 | " gameListFile.write('\\n game: ')\n", 261 | " gameListFile.write(str(row[2][0]))\n", 262 | " gameListFile.write('\\n hits: ')\n", 263 | " gameListFile.write(str(row[2][1]))\n", 264 | " gameListFile.write('\\n home runs: ')\n", 265 | " gameListFile.write(str(row[2][2]))\n", 266 | " gameListFile.write('\\n RBI: ')\n", 267 | " gameListFile.write(str(row[2][3]))\n", 268 | " gameListFile.write('\\n walks: ')\n", 269 | " gameListFile.write(str(row[2][4]))\n", 270 | " \n", 271 | " gameListFile.write('\\n-----------------')\n", 272 | " \n", 273 | " gameListFile.write('\\nAway: ')\n", 274 | " gameListFile.write(row[3])\n", 275 | " gameListFile.write('\\nScore: ')\n", 276 | " gameListFile.write(str(row[4]))\n", 277 | " \n", 278 | " gameListFile.write('\\n game: ')\n", 279 | " gameListFile.write(str(row[5][0]))\n", 280 | " gameListFile.write('\\n hits: ')\n", 281 | " gameListFile.write(str(row[5][1]))\n", 282 | " gameListFile.write('\\n home runs: ')\n", 283 | " gameListFile.write(str(row[5][2]))\n", 284 | " gameListFile.write('\\n RBI: ')\n", 285 | " gameListFile.write(str(row[5][3]))\n", 286 | " gameListFile.write('\\n walks: ')\n", 287 | " gameListFile.write(str(row[5][4]))\n", 288 | " gameListFile.write('\\n-------------------------------------------')\n", 289 | "gameListFile.close()\n", 290 | "gameList = removeQuotes(gameList) #now that it's printed to the file, convert the quoted number strings to floats" 
291 | ] 292 | }, 293 | { 294 | "cell_type": "markdown", 295 | "metadata": { 296 | "deletable": true, 297 | "editable": true 298 | }, 299 | "source": [ 300 | "Now let's compute the average stats for each team." 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": 5, 306 | "metadata": { 307 | "collapsed": false, 308 | "deletable": true, 309 | "editable": true 310 | }, 311 | "outputs": [], 312 | "source": [ 313 | "avgStats = dict()\n", 314 | "np.set_printoptions(precision=4) #if you remove this, restarting the kernel is recommended\n", 315 | "\n", 316 | "#add up all the stats for each team\n", 317 | "for row in gameList:\n", 318 | " #home\n", 319 | " if row[0] in avgStats:\n", 320 | " avgStats[row[0]][0] = 0 #reset the prior games count to 0, because it is set to the right amount on the next line\n", 321 | " avgStats[row[0]] = np.sum([avgStats[row[0]], row[2]], axis=0) #total stats + individual game stats\n", 322 | " else:\n", 323 | " avgStats[row[0]] = np.array(row[2], dtype=float) #copy, so later updates do not modify gameList\n", 324 | " \n", 325 | " #away\n", 326 | " if row[3] in avgStats:\n", 327 | " avgStats[row[3]][0] = 0\n", 328 | " avgStats[row[3]] = np.sum([avgStats[row[3]], row[5]], axis=0)\n", 329 | " else:\n", 330 | " avgStats[row[3]] = np.array(row[5], dtype=float) #copy, so later updates do not modify gameList\n", 331 | "\n", 332 | "#divide the sums by the games played to get the averages\n", 333 | "keyList = avgStats.keys() \n", 334 | "for key in keyList:\n", 335 | " for index in range(1, len(avgStats[key])): #divide all stats by games played (162, 162)\n", 336 | " avgStats[key][index] = avgStats[key][index] / avgStats[key][0] " 337 | ] 338 | }, 339 | { 340 | "cell_type": "code", 341 | "execution_count": 6, 342 | "metadata": { 343 | "collapsed": false, 344 | "deletable": true, 345 | "editable": true 346 | }, 347 | "outputs": [], 348 | "source": [ 349 | "# split the stats up into different parts (this takes a lot of time, but I think it will be worth it)\n", 350 | "\n", 351 | "#gameList = gameStats.getGameList() # I want to test what happens if we just use the stats of the game.\n", 352 | "gameList = 
removeQuotes(gameList) #if gameList is re-initialized from gameStats above, the number strings must be converted again\n", 353 | "homeTeamName = np.empty((0, len(list(gameList[0][0])))) \n", 354 | "homeTeamScore = np.empty((0, 1))\n", 355 | "homeTeamStats = np.empty((0, len(gameList[0][2])))\n", 356 | "\n", 357 | "visitingTeamName = np.empty((0, len(list(gameList[0][3])))) \n", 358 | "visitingTeamScore = np.empty((0, 1))\n", 359 | "visitingTeamStats = np.empty((0, len(gameList[0][5])))\n", 360 | "\n", 361 | "for row in gameList:\n", 362 | " homeTeamName = np.vstack((homeTeamName, list(row[0])))\n", 363 | " homeTeamScore = np.vstack((homeTeamScore, row[1]))\n", 364 | " homeTeamStats = np.vstack((homeTeamStats, row[2]))\n", 365 | " visitingTeamName = np.vstack((visitingTeamName, list(row[3])))\n", 366 | " visitingTeamScore = np.vstack((visitingTeamScore, row[4]))\n", 367 | " visitingTeamStats = np.vstack((visitingTeamStats, row[5]))" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": 7, 373 | "metadata": { 374 | "collapsed": true, 375 | "deletable": true, 376 | "editable": true 377 | }, 378 | "outputs": [], 379 | "source": [ 380 | "# z-score each stat column; the array is transposed and then transposed back because zscore is applied to one column of stats at a time\n", 381 | "zScoredStatsHome = np.empty((0, len(homeTeamStats)))\n", 382 | "zScoredStatsVisitor = np.empty((0, len(visitingTeamStats)))\n", 383 | "for x in range(len(homeTeamStats[0])):\n", 384 | " zScoredStatsHome = np.vstack((zScoredStatsHome, stats.zscore([i[x] for i in homeTeamStats])))\n", 385 | " zScoredStatsVisitor = np.vstack((zScoredStatsVisitor, stats.zscore([i[x] for i in visitingTeamStats])))\n", 386 | "\n", 387 | "homeTeamStatsTemp = np.empty((len(zScoredStatsHome[0]), 0))\n", 388 | "visitingTeamStatsTemp = np.empty((len(zScoredStatsVisitor[0]), 0))\n", 389 | "for i in zScoredStatsHome:\n", 390 | " homeTeamStatsTemp = np.hstack((homeTeamStatsTemp, [[x] for x in i]))\n", 391 | "\n", 392 | "for i in 
zScoredStatsVisitor:\n", 393 | " visitingTeamStatsTemp = np.hstack((visitingTeamStatsTemp, [[x] for x in i]))" 394 | ] 395 | }, 396 | { 397 | "cell_type": "code", 398 | "execution_count": 19, 399 | "metadata": { 400 | "collapsed": false, 401 | "deletable": true, 402 | "editable": true 403 | }, 404 | "outputs": [ 405 | { 406 | "name": "stdout", 407 | "output_type": "stream", 408 | "text": [ 409 | "Training Step: 699 | total loss: \u001b[1m\u001b[32m6.14998\u001b[0m\u001b[0m | time: 0.273s\n", 410 | "| Adam | epoch: 020 | loss: 6.14998 - acc: 0.4187 -- iter: 2176/2185\n", 411 | "Training Step: 700 | total loss: \u001b[1m\u001b[32m6.14890\u001b[0m\u001b[0m | time: 1.284s\n", 412 | "| Adam | epoch: 020 | loss: 6.14890 - acc: 0.3987 | val_loss: 6.44243 - val_acc: 0.1111 -- iter: 2185/2185\n", 413 | "--\n", 414 | "\n" 415 | ] 416 | } 417 | ], 418 | "source": [ 419 | "predFile = open(\"prediction.txt\", \"w\")\n", 420 | "predFile.close() #clear the file\n", 421 | "learnArr = [1*(10**(x)) for x in [-2]]\n", 422 | "regulArr = [1*(2**(x)) for x in range(-10, 0)]\n", 423 | "layerArr = [5, 8]\n", 424 | "widthArr = [30]\n", 425 | "paramErr = dict()\n", 426 | "mastErrList = []\n", 427 | "\n", 428 | "for learnRate in tqdm_notebook(learnArr):\n", 429 | " for numWidth in tqdm_notebook(widthArr):\n", 430 | " for numLayers in tqdm_notebook(layerArr): \n", 431 | " tf.reset_default_graph()\n", 432 | " \n", 433 | " # the input layers need to have the same dimensions as our inputs (the team names and the team stats)\n", 434 | " homeTeamNameInput = tflearn.input_data(shape=[None, len(homeTeamName[0])], name='nameInput1')\n", 435 | " homeTeamStatsInput = tflearn.input_data(shape=[None, len(homeTeamStatsTemp[0])], name='statsInput1')\n", 436 | " visitingTeamNameInput = tflearn.input_data(shape=[None, len(visitingTeamName[0])], name='nameInput2')\n", 437 | " visitingTeamStatsInput = tflearn.input_data(shape=[None, len(visitingTeamStatsTemp[0])], name='statsInput2')\n", 438 | " \n", 439 | " nameProcess1 = 
tflearn.fully_connected(homeTeamNameInput, 4)\n", 440 | " nameProcess2 = tflearn.fully_connected(visitingTeamNameInput, 4)\n", 441 | " net = tflearn.layers.merge_ops.merge([nameProcess1, homeTeamStatsInput, nameProcess2, visitingTeamStatsInput], 'concat', axis=1)\n", 442 | " # next come the hidden layers; their width is arbitrary.\n", 443 | " for _ in range(numLayers):\n", 444 | " net = tflearn.fully_connected(net, numWidth)\n", 445 | " \n", 446 | " net = tflearn.dropout(net, keep_prob=0.5)\n", 447 | " # The output layer\n", 448 | " net = tflearn.fully_connected(net, 2)\n", 449 | " net = tflearn.regression(net, name='target', learning_rate=learnRate)\n", 450 | " \n", 451 | " # take only the scores for each team and put them into an array\n", 452 | " NNOutput = [[i[1], i[4]] for i in gameList]\n", 453 | " NNOutput = np.array(NNOutput)\n", 454 | " \n", 455 | " # Define model\n", 456 | " model = tflearn.DNN(net)\n", 457 | " # Start training (apply gradient descent algorithm)\n", 458 | " model.fit({'nameInput1':homeTeamName, 'statsInput1':homeTeamStatsTemp, 'nameInput2':visitingTeamName, 'statsInput2':visitingTeamStatsTemp}, NNOutput, validation_set=0.1, n_epoch=20, show_metric=True)\n", 459 | " \n", 460 | " \n", 461 | " #MAKING THE PREDICTIONS\n", 462 | " predFile = open(\"prediction.txt\", \"a\")\n", 463 | " errorArr = []\n", 464 | " for i in tqdm_notebook(range(len(homeTeamName))):\n", 465 | " predict = model.predict([[homeTeamName[i]], [homeTeamStatsTemp[i]], [visitingTeamName[i]], [visitingTeamStatsTemp[i]]])\n", 466 | " pred = predict[0] #only need 1 dimension in the array\n", 467 | " pred[0] = round(pred[0], 2) #rounding the predictions\n", 468 | " pred[1] = round(pred[1], 2)\n", 469 | " #using relative error, but we can't divide by 0, so replace an actual score of 0 with 0.1\n", 470 | " if (NNOutput[i][0] == 0): NNOutput[i][0] = 0.1\n", 471 | " if (NNOutput[i][1] == 0): NNOutput[i][1] = 0.1\n", 472 | " \n", 473 | " error = (abs(pred[0] - 
NNOutput[i][0]) / NNOutput[i][0]) + (abs(pred[1] - NNOutput[i][1]) / NNOutput[i][1]) #adding the 'relative error' of both scores\n", 474 | " error = round(error, 4) #round to 4 digits\n", 475 | " errorArr.append(error) \n", 476 | " if (i < 15): #only print 15 because printing the predictions for 1000 games isn't needed\n", 477 | " predFile.write('\\n-----------')\n", 478 | " predFile.write('\\n Home | Away')\n", 479 | " predFile.write('\\nPred: ' + str(pred[0]) + ' | '+ str(pred[1]))\n", 480 | " predFile.write('\\nActu: ' + str(NNOutput[i][0]) + ' | '+ str(NNOutput[i][1]))\n", 481 | " predFile.write('\\nError: ' + str(error))\n", 482 | " predFile.write('\\n-----------')\n", 483 | " \n", 484 | " \n", 485 | " totErr = round(sum(errorArr) / len(homeTeamName), 3) #the average relative error\n", 486 | " tempStr = ('Error: ' + str(totErr) + ', learnRate: ' + str(learnRate) + ', Width: ' + str(numWidth) + ', layers: ' + str(numLayers))\n", 487 | " mastErrList.append(tempStr) #this is a list of all the average relative errors and the combination that caused each\n", 488 | " \n", 489 | " predFile.write('\\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^'+ str(len(mastErrList)-1)+ '^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n')\n", 490 | " predFile.write(tempStr + '\\n')\n", 491 | " predFile.write('\\n_________________________________________________________________________________')\n", 492 | " predFile.write('\\n_________________________________________________________________________________')\n", 493 | " predFile.write('\\n_________________________________________________________________________________')\n", 494 | " predFile.close() #close the file so that it saves the changes" 495 | ] 496 | }, 497 | { 498 | "cell_type": "markdown", 499 | "metadata": { 500 | "deletable": true, 501 | "editable": true 502 | }, 503 | "source": [ 504 | "Here is a list of all the combinations tested and the 'average relative error' they yield." 
505 | ] 506 | }, 507 | { 508 | "cell_type": "code", 509 | "execution_count": 22, 510 | "metadata": { 511 | "collapsed": false, 512 | "deletable": true, 513 | "editable": true 514 | }, 515 | "outputs": [ 516 | { 517 | "name": "stdout", 518 | "output_type": "stream", 519 | "text": [ 520 | "Error: 1.993, learnRate: 0.01, Width: 30, layers: 5\n", 521 | "Error: 9750.962, learnRate: 0.01, Width: 30, layers: 8\n" 522 | ] 523 | } 524 | ], 525 | "source": [ 526 | "for i in mastErrList:\n", 527 | " print(i)" 528 | ] 529 | }, 530 | { 531 | "cell_type": "code", 532 | "execution_count": 54, 533 | "metadata": { 534 | "collapsed": false, 535 | "deletable": true, 536 | "editable": true 537 | }, 538 | "outputs": [], 539 | "source": [] 540 | }, 541 | { 542 | "cell_type": "code", 543 | "execution_count": 55, 544 | "metadata": { 545 | "collapsed": true, 546 | "deletable": true, 547 | "editable": true 548 | }, 549 | "outputs": [], 550 | "source": [] 551 | }, 552 | { 553 | "cell_type": "code", 554 | "execution_count": null, 555 | "metadata": { 556 | "collapsed": false, 557 | "deletable": true, 558 | "editable": true 559 | }, 560 | "outputs": [], 561 | "source": [] 562 | }, 563 | { 564 | "cell_type": "code", 565 | "execution_count": null, 566 | "metadata": { 567 | "collapsed": false, 568 | "deletable": true, 569 | "editable": true 570 | }, 571 | "outputs": [], 572 | "source": [] 573 | }, 574 | { 575 | "cell_type": "code", 576 | "execution_count": null, 577 | "metadata": { 578 | "collapsed": false, 579 | "deletable": true, 580 | "editable": true 581 | }, 582 | "outputs": [], 583 | "source": [] 584 | }, 585 | { 586 | "cell_type": "code", 587 | "execution_count": 16, 588 | "metadata": { 589 | "collapsed": false, 590 | "deletable": true, 591 | "editable": true 592 | }, 593 | "outputs": [], 594 | "source": [] 595 | }, 596 | { 597 | "cell_type": "code", 598 | "execution_count": null, 599 | "metadata": { 600 | "collapsed": true, 601 | "deletable": true, 602 | "editable": true 603 | }, 604 | 
"outputs": [], 605 | "source": [] 606 | } 607 | ], 608 | "metadata": { 609 | "kernelspec": { 610 | "display_name": "Python 3", 611 | "language": "python", 612 | "name": "python3" 613 | }, 614 | "language_info": { 615 | "codemirror_mode": { 616 | "name": "ipython", 617 | "version": 3 618 | }, 619 | "file_extension": ".py", 620 | "mimetype": "text/x-python", 621 | "name": "python", 622 | "nbconvert_exporter": "python", 623 | "pygments_lexer": "ipython3", 624 | "version": "3.5.3" 625 | }, 626 | "widgets": { 627 | "state": { 628 | "069276cff7bf4814af22e14eec1457ab": { 629 | "views": [ 630 | { 631 | "cell_index": 11 632 | } 633 | ] 634 | }, 635 | "0d0f097d22674400882ea667b246504b": { 636 | "views": [ 637 | { 638 | "cell_index": 10 639 | } 640 | ] 641 | }, 642 | "109d06593f59494f85e7109df9bebf94": { 643 | "views": [ 644 | { 645 | "cell_index": 11 646 | } 647 | ] 648 | }, 649 | "30571c55e1f34e94aade16921e1dde5d": { 650 | "views": [ 651 | { 652 | "cell_index": 11 653 | } 654 | ] 655 | }, 656 | "3beea730509648ccb18658350db19c5d": { 657 | "views": [ 658 | { 659 | "cell_index": 10 660 | } 661 | ] 662 | }, 663 | "47117cda286145d299180a796ca362a0": { 664 | "views": [ 665 | { 666 | "cell_index": 11 667 | } 668 | ] 669 | }, 670 | "541a4ea7fe984a31bc69ae35c82312b9": { 671 | "views": [ 672 | { 673 | "cell_index": 11 674 | } 675 | ] 676 | }, 677 | "5cde23d8f63349d5bf8864584852153f": { 678 | "views": [ 679 | { 680 | "cell_index": 11 681 | } 682 | ] 683 | }, 684 | "5d506b11717c4223868d57a973887c0c": { 685 | "views": [ 686 | { 687 | "cell_index": 11 688 | } 689 | ] 690 | }, 691 | "5ee5a07940444b48a41edcd45ff50c0f": { 692 | "views": [ 693 | { 694 | "cell_index": 11 695 | } 696 | ] 697 | }, 698 | "692f0a4e6c204ed2b3bba5db186aefa4": { 699 | "views": [ 700 | { 701 | "cell_index": 11 702 | } 703 | ] 704 | }, 705 | "737ba34e94b2425b9bcf2112cd165406": { 706 | "views": [ 707 | { 708 | "cell_index": 11 709 | } 710 | ] 711 | }, 712 | "7f93e45062e944dbb3fc10b28a211cfe": { 713 | "views": [ 714 | { 
715 | "cell_index": 11 716 | } 717 | ] 718 | }, 719 | "95805d5039ca4a2cbfaa45af3e3bf587": { 720 | "views": [ 721 | { 722 | "cell_index": 11 723 | } 724 | ] 725 | }, 726 | "a3e50f9ea968462a922ad2c0b32a74ee": { 727 | "views": [ 728 | { 729 | "cell_index": 11 730 | } 731 | ] 732 | }, 733 | "b80e6c386820476fb36a5711d7b89bdc": { 734 | "views": [ 735 | { 736 | "cell_index": 11 737 | } 738 | ] 739 | }, 740 | "d1c0375837c74e0dac63ba12261ed35f": { 741 | "views": [ 742 | { 743 | "cell_index": 11 744 | } 745 | ] 746 | }, 747 | "ecc0fedd1c6a45a28a912a312546f936": { 748 | "views": [ 749 | { 750 | "cell_index": 11 751 | } 752 | ] 753 | }, 754 | "ef35bef469f94e84baeaccc91a714448": { 755 | "views": [ 756 | { 757 | "cell_index": 11 758 | } 759 | ] 760 | }, 761 | "f4b3b43b824a45c5a4134b641be6cb7d": { 762 | "views": [ 763 | { 764 | "cell_index": 11 765 | } 766 | ] 767 | } 768 | }, 769 | "version": "1.2.0" 770 | } 771 | }, 772 | "nbformat": 4, 773 | "nbformat_minor": 2 774 | } 775 | --------------------------------------------------------------------------------