├── .basketball_profile ├── .gitignore ├── 20210408160343_cropped.gif ├── 20210408161424_cropped.gif ├── 267_3_gen_baller2vec++_7_cropped.gif ├── 267_3_gen_baller2vec++_8_cropped.gif ├── 267_3_gen_baller2vec++_9_cropped.gif ├── 267_3_gen_baller2vec_0_cropped.gif ├── 267_3_gen_baller2vec_1_cropped.gif ├── 267_3_gen_baller2vec_7_cropped.gif ├── 267_3_truth_cropped.gif ├── LICENSE ├── README.md ├── baller2vec++.svg ├── baller2vecplusplus.py ├── baller2vecplusplus_dataset.py ├── court.png ├── events.zip ├── generate_game_numpy_arrays.py ├── gif_cropper.sh ├── gif_merger.sh ├── image_cropper.sh ├── paper_functions.py ├── paper_plots.py ├── requirements.txt ├── settings.py ├── test_gameids.txt ├── toy_dataset.py ├── train_baller2vecplusplus.py ├── train_cropped.gif ├── train_gameids.txt └── valid_gameids.txt /.basketball_profile: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Code directory (where the repository was cloned). Replace every path below with one for your machine. 4 | export PROJECT_DIR=/home/michael/baller2vec 5 | export DATA_DIR=/mnt/raid/michael/baller2vec_data # Data root; the remaining directories below are derived from it. 6 | export EVENTS_DIR=${DATA_DIR}/events 7 | export TRACKING_DIR=${DATA_DIR}/NBA-Player-Movements/data/2016.NBA.Raw.SportVU.Game.Logs 8 | export GAMES_DIR=${DATA_DIR}/games 9 | export EXPERIMENTS_DIR=${DATA_DIR}/experiments -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | __pycache__/ 3 | -------------------------------------------------------------------------------- /20210408160343_cropped.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airalcorn2/baller2vecplusplus/e9753be4995300e4e2057ebd06f6ad96131f7ba9/20210408160343_cropped.gif -------------------------------------------------------------------------------- /20210408161424_cropped.gif:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/airalcorn2/baller2vecplusplus/e9753be4995300e4e2057ebd06f6ad96131f7ba9/20210408161424_cropped.gif -------------------------------------------------------------------------------- /267_3_gen_baller2vec++_7_cropped.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airalcorn2/baller2vecplusplus/e9753be4995300e4e2057ebd06f6ad96131f7ba9/267_3_gen_baller2vec++_7_cropped.gif -------------------------------------------------------------------------------- /267_3_gen_baller2vec++_8_cropped.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airalcorn2/baller2vecplusplus/e9753be4995300e4e2057ebd06f6ad96131f7ba9/267_3_gen_baller2vec++_8_cropped.gif -------------------------------------------------------------------------------- /267_3_gen_baller2vec++_9_cropped.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airalcorn2/baller2vecplusplus/e9753be4995300e4e2057ebd06f6ad96131f7ba9/267_3_gen_baller2vec++_9_cropped.gif -------------------------------------------------------------------------------- /267_3_gen_baller2vec_0_cropped.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airalcorn2/baller2vecplusplus/e9753be4995300e4e2057ebd06f6ad96131f7ba9/267_3_gen_baller2vec_0_cropped.gif -------------------------------------------------------------------------------- /267_3_gen_baller2vec_1_cropped.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airalcorn2/baller2vecplusplus/e9753be4995300e4e2057ebd06f6ad96131f7ba9/267_3_gen_baller2vec_1_cropped.gif 
-------------------------------------------------------------------------------- /267_3_gen_baller2vec_7_cropped.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airalcorn2/baller2vecplusplus/e9753be4995300e4e2057ebd06f6ad96131f7ba9/267_3_gen_baller2vec_7_cropped.gif -------------------------------------------------------------------------------- /267_3_truth_cropped.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airalcorn2/baller2vecplusplus/e9753be4995300e4e2057ebd06f6ad96131f7ba9/267_3_truth_cropped.gif -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2021 Michael A. Alcorn 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # baller2vec++ 2 | 3 | This is the repository for the paper: 4 | 5 | >[Michael A. Alcorn](https://sites.google.com/view/michaelaalcorn) and [Anh Nguyen](http://anhnguyen.me). [`baller2vec++`: A Look-Ahead Multi-Entity Transformer For Modeling Coordinated Agents](https://arxiv.org/abs/2104.11980). arXiv. 2021. 6 | 7 | | | 8 | |:--| 9 | | To learn statistically dependent agent trajectories, baller2vec++ uses a specially designed self-attention mask to simultaneously process three different sets of feature vectors in a single Transformer. The three sets of feature vectors consist of location feature vectors like those found in baller2vec, look-ahead trajectory feature vectors, and starting location feature vectors. This design allows the model to integrate information about *concurrent* agent trajectories through *multiple* Transformer layers without seeing the future (in contrast to baller2vec). | 10 | 11 | | | | | 12 | |:--:|:--:|:--:| 13 | | Training sample | baller2vec | baller2vec++ | 14 | 15 | When trained on a dataset of perfectly coordinated agent trajectories, the trajectories generated by baller2vec are completely *uncoordinated* while the trajectories generated by baller2vec++ are perfectly coordinated. 16 | 17 | | | | | | 18 | |:--:|:--:|:--:|:--:| 19 | | Ground truth | baller2vec | baller2vec | baller2vec | 20 | | | | | | 21 | | Ground truth | baller2vec++ | baller2vec++ | baller2vec++ | 22 | 23 | While baller2vec occasionally generates realistic trajectories for the red defender, it also makes egregious errors. 24 | In contrast, the trajectories generated by baller2vec++ often seem plausible. 25 | The red player was placed *last* in the player order when generating his trajectory with baller2vec++.
26 | 27 | ## Citation 28 | 29 | If you use this code for your own research, please cite: 30 | 31 | ``` 32 | @article{alcorn2021baller2vec, 33 | title={\texttt{baller2vec++}: A Look-Ahead Multi-Entity Transformer For Modeling Coordinated Agents}, 34 | author={Alcorn, Michael A. and Nguyen, Anh}, 35 | journal={arXiv preprint arXiv:2104.11980}, 36 | year={2021} 37 | } 38 | ``` 39 | 40 | ## Training baller2vec++ 41 | 42 | ### Setting up `.basketball_profile` 43 | 44 | After you've cloned the repository to your desired location, create a file called `.basketball_profile` in your home directory: 45 | 46 | ```bash 47 | nano ~/.basketball_profile 48 | ``` 49 | 50 | and copy and paste in the contents of [`.basketball_profile`](.basketball_profile), replacing each of the variable values with paths relevant to your environment. 51 | Next, add the following line to the end of your `~/.bashrc`: 52 | 53 | ```bash 54 | source ~/.basketball_profile 55 | ``` 56 | 57 | and either log out and log back in again or run: 58 | 59 | ```bash 60 | source ~/.bashrc 61 | ``` 62 | 63 | You should now be able to copy and paste all of the commands in the various instructions sections. 64 | For example: 65 | 66 | ```bash 67 | echo ${PROJECT_DIR} 68 | ``` 69 | 70 | should print the path you set for `PROJECT_DIR` in `.basketball_profile`. 
71 | 72 | ### Installing the necessary Python packages 73 | 74 | ```bash 75 | cd ${PROJECT_DIR} 76 | pip3 install --upgrade -r requirements.txt 77 | ``` 78 | 79 | ### Organizing the play-by-play and tracking data 80 | 81 | 1) Copy `events.zip` (which I acquired from [here](https://github.com/sealneaward/nba-movement-data/tree/master/data/events) \[mirror [here](https://github.com/airalcorn2/nba-movement-data/tree/master/data/events)\] using https://downgit.github.io) to the `DATA_DIR` directory and unzip it: 82 | 83 | ```bash 84 | mkdir -p ${DATA_DIR} 85 | cp ${PROJECT_DIR}/events.zip ${DATA_DIR} 86 | cd ${DATA_DIR} 87 | unzip -q events.zip 88 | rm events.zip 89 | ``` 90 | 91 | Descriptions for the various `EVENTMSGTYPE`s can be found [here](https://github.com/rd11490/NBA_Tutorials/tree/master/analyze_play_by_play) (mirror [here](https://github.com/airalcorn2/NBA_Tutorials/tree/master/analyze_play_by_play)). 92 | 93 | 2) Clone the tracking data from [here](https://github.com/linouk23/NBA-Player-Movements) (mirror [here](https://github.com/airalcorn2/NBA-Player-Movements)) to the `DATA_DIR` directory: 94 | 95 | ```bash 96 | cd ${DATA_DIR} 97 | git clone git@github.com:linouk23/NBA-Player-Movements.git 98 | ``` 99 | 100 | A description of the tracking data can be found [here](https://danvatterott.com/blog/2016/06/16/creating-videos-of-nba-action-with-sportsvu-data/). 101 | 102 | ### Generating the training data 103 | 104 | ```bash 105 | cd ${PROJECT_DIR} 106 | nohup python3 generate_game_numpy_arrays.py > data.log & 107 | ``` 108 | 109 | You can monitor its progress with: 110 | 111 | ```bash 112 | top 113 | ``` 114 | 115 | or: 116 | 117 | ```bash 118 | ls -U ${GAMES_DIR} | wc -l 119 | ``` 120 | 121 | There should be 1,262 NumPy arrays (corresponding to 631 X/y pairs) when finished. 122 | 123 | ### Running the training script 124 | 125 | Run (or copy and paste) the following script, editing the variables as appropriate. 
126 | 127 | ```bash 128 | #!/usr/bin/env bash 129 | 130 | JOB=$(date +%Y%m%d%H%M%S) 131 | 132 | echo "train:" >> ${JOB}.yaml 133 | task=basketball # "basketball" or "toy". 134 | echo " task: ${task}" >> ${JOB}.yaml 135 | if [[ "$task" = "basketball" ]] 136 | then 137 | 138 | echo " train_samples_per_epoch: 20000" >> ${JOB}.yaml 139 | echo " valid_samples: 1000" >> ${JOB}.yaml 140 | echo " workers: 10" >> ${JOB}.yaml 141 | echo " learning_rate: 1.0e-5" >> ${JOB}.yaml 142 | echo " patience: 20" >> ${JOB}.yaml 143 | 144 | echo "dataset:" >> ${JOB}.yaml 145 | echo " hz: 5" >> ${JOB}.yaml 146 | echo " secs: 4.2" >> ${JOB}.yaml 147 | echo " player_traj_n: 11" >> ${JOB}.yaml 148 | echo " max_player_move: 4.5" >> ${JOB}.yaml 149 | 150 | echo "model:" >> ${JOB}.yaml 151 | echo " embedding_dim: 20" >> ${JOB}.yaml 152 | echo " sigmoid: none" >> ${JOB}.yaml 153 | echo " mlp_layers: [128, 256, 512]" >> ${JOB}.yaml 154 | echo " nhead: 8" >> ${JOB}.yaml 155 | echo " dim_feedforward: 2048" >> ${JOB}.yaml 156 | echo " num_layers: 6" >> ${JOB}.yaml 157 | echo " dropout: 0.0" >> ${JOB}.yaml 158 | echo " b2v: False" >> ${JOB}.yaml 159 | 160 | else 161 | 162 | echo " workers: 10" >> ${JOB}.yaml 163 | echo " learning_rate: 1.0e-4" >> ${JOB}.yaml 164 | 165 | echo "model:" >> ${JOB}.yaml 166 | echo " embedding_dim: 20" >> ${JOB}.yaml 167 | echo " sigmoid: none" >> ${JOB}.yaml 168 | echo " mlp_layers: [64, 128]" >> ${JOB}.yaml 169 | echo " nhead: 4" >> ${JOB}.yaml 170 | echo " dim_feedforward: 512" >> ${JOB}.yaml 171 | echo " num_layers: 2" >> ${JOB}.yaml 172 | echo " dropout: 0.0" >> ${JOB}.yaml 173 | echo " b2v: True" >> ${JOB}.yaml 174 | 175 | fi 176 | 177 | # Save experiment settings. 
178 | mkdir -p ${EXPERIMENTS_DIR}/${JOB} 179 | mv ${JOB}.yaml ${EXPERIMENTS_DIR}/${JOB}/ 180 | 181 | gpu=0 182 | cd ${PROJECT_DIR} 183 | nohup python3 train_baller2vecplusplus.py ${JOB} ${gpu} > ${EXPERIMENTS_DIR}/${JOB}/train.log & 184 | ``` 185 | -------------------------------------------------------------------------------- /baller2vecplusplus.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | from torch import nn 5 | 6 | 7 | class Baller2VecPlusPlus(nn.Module): # Transformer over per-player feature vectors that outputs n_player_labels logits per position token; b2v=True uses the baller2vec mask, b2v=False the baller2vec++ look-ahead mask. 8 | def __init__( 9 | self, 10 | n_player_ids, 11 | embedding_dim, 12 | sigmoid, 13 | seq_len, 14 | mlp_layers, 15 | n_players, 16 | n_player_labels, 17 | nhead, 18 | dim_feedforward, 19 | num_layers, 20 | dropout, 21 | b2v, 22 | ): 23 | super().__init__() 24 | self.sigmoid = sigmoid 25 | self.seq_len = seq_len 26 | self.n_players = n_players 27 | self.b2v = b2v 28 | 29 | initrange = 0.1 30 | self.player_embedding = nn.Embedding(n_player_ids, embedding_dim) 31 | self.player_embedding.weight.data.uniform_(-initrange, initrange) 32 | 33 | start_mlp = nn.Sequential() 34 | pos_mlp = nn.Sequential() 35 | traj_mlp = nn.Sequential() 36 | pos_in_feats = embedding_dim + 3 # embedding + x, y, hoop side (see forward). 37 | traj_in_feats = embedding_dim + 5 # embedding + next x, next y, hoop side, x diff, y diff (see forward). 38 | for (layer_idx, out_feats) in enumerate(mlp_layers): # three parallel MLPs with the same widths; ReLU between layers but not after the last one. 39 | start_mlp.add_module( 40 | f"layer{layer_idx}", nn.Linear(pos_in_feats, out_feats) 41 | ) 42 | pos_mlp.add_module(f"layer{layer_idx}", nn.Linear(pos_in_feats, out_feats)) 43 | traj_mlp.add_module( 44 | f"layer{layer_idx}", nn.Linear(traj_in_feats, out_feats) 45 | ) 46 | if layer_idx < len(mlp_layers) - 1: 47 | start_mlp.add_module(f"relu{layer_idx}", nn.ReLU()) 48 | pos_mlp.add_module(f"relu{layer_idx}", nn.ReLU()) 49 | traj_mlp.add_module(f"relu{layer_idx}", nn.ReLU()) 50 | 51 | pos_in_feats = out_feats 52 | traj_in_feats = out_feats 53 | 54 | self.start_mlp = start_mlp 55 | self.pos_mlp = pos_mlp 56 | self.traj_mlp = traj_mlp 57 | 58 | d_model = mlp_layers[-1] 59 |
self.d_model = d_model 60 | if b2v: 61 | self.register_buffer("b2v_mask", self.generate_b2v_mask()) 62 | else: 63 | self.register_buffer("b2vpp_mask", self.generate_b2vpp_mask()) 64 | 65 | encoder_layer = nn.TransformerEncoderLayer( 66 | d_model, nhead, dim_feedforward, dropout 67 | ) 68 | self.transformer = nn.TransformerEncoder(encoder_layer, num_layers) 69 | 70 | self.traj_classifier = nn.Linear(d_model, n_player_labels) 71 | self.traj_classifier.weight.data.uniform_(-initrange, initrange) 72 | self.traj_classifier.bias.data.zero_() 73 | 74 | def generate_b2v_mask(self): # Additive float mask over seq_len * n_players tokens: a token at time step t may attend to every player's token at steps <= t (0.0 = allowed, -inf = blocked). 75 | seq_len = self.seq_len 76 | n_players = self.n_players 77 | sz = seq_len * n_players 78 | mask = torch.zeros(sz, sz) 79 | for step in range(seq_len): 80 | start = step * n_players 81 | stop = start + n_players 82 | mask[start:stop, :stop] = 1 83 | 84 | mask = mask.masked_fill(mask == 0, float("-inf")) 85 | mask = mask.masked_fill(mask == 1, float(0.0)) 86 | return mask 87 | 88 | def generate_b2vpp_mask(self): # Additive float mask: lower-triangular (causal) over the 2 * seq_len * n_players interleaved position/trajectory tokens, plus n_players trailing start tokens that every token may attend to. 89 | n_players = self.n_players 90 | tri_sz = 2 * (self.seq_len * n_players) 91 | sz = tri_sz + n_players 92 | mask = torch.zeros(sz, sz) 93 | mask[:tri_sz, :tri_sz] = torch.tril(torch.ones(tri_sz, tri_sz)) 94 | mask[:, -n_players:] = 1 95 | mask = mask.masked_fill(mask == 0, float("-inf")) 96 | mask = mask.masked_fill(mask == 1, float(0.0)) 97 | return mask 98 | 99 | def forward(self, tensors): 100 | device = list(self.player_embedding.parameters())[0].device # inputs are moved to the device holding the model's parameters. 101 | 102 | player_embeddings = self.player_embedding( 103 | tensors["player_idxs"].flatten().to(device) 104 | ) 105 | if self.sigmoid == "logistic": 106 | player_embeddings = torch.sigmoid(player_embeddings) 107 | elif self.sigmoid == "tanh": 108 | player_embeddings = torch.tanh(player_embeddings) 109 | 110 | player_xs = tensors["player_xs"].flatten().unsqueeze(1).to(device) 111 | player_ys = tensors["player_ys"].flatten().unsqueeze(1).to(device) 112 | player_hoop_sides = ( 113 |
tensors["player_hoop_sides"].flatten().unsqueeze(1).to(device) 114 | ) 115 | player_pos = torch.cat( 116 | [ 117 | player_embeddings, 118 | player_xs, 119 | player_ys, 120 | player_hoop_sides, 121 | ], 122 | dim=1, 123 | ) 124 | pos_feats = self.pos_mlp(player_pos) * math.sqrt(self.d_model) 125 | if self.b2v: 126 | outputs = self.transformer(pos_feats.unsqueeze(1), self.b2v_mask) # unsqueeze(1) inserts a batch of size 1 (sequence-first layout, the nn.TransformerEncoderLayer default). 127 | preds = self.traj_classifier(outputs.squeeze(1)) 128 | 129 | else: 130 | start_feats = self.start_mlp(player_pos[: self.n_players]) * math.sqrt( # first n_players rows; presumably the first time step since inputs are flattened (time, player) arrays - confirm with the dataset. 131 | self.d_model 132 | ) 133 | 134 | player_x_diffs = tensors["player_x_diffs"].flatten().unsqueeze(1).to(device) 135 | player_y_diffs = tensors["player_y_diffs"].flatten().unsqueeze(1).to(device) 136 | player_trajs = torch.cat( 137 | [ 138 | player_embeddings, 139 | player_xs + player_x_diffs, 140 | player_ys + player_y_diffs, 141 | player_hoop_sides, 142 | player_x_diffs, 143 | player_y_diffs, 144 | ], 145 | dim=1, 146 | ) 147 | trajs_feats = self.traj_mlp(player_trajs) * math.sqrt(self.d_model) 148 | 149 | combined = torch.zeros( 150 | 2 * len(pos_feats) + self.n_players, self.d_model 151 | ).to(device) 152 | combined[: -self.n_players : 2] = pos_feats # even slots: position tokens. 153 | combined[1 : -self.n_players : 2] = trajs_feats # odd slots: look-ahead trajectory tokens. 154 | combined[-self.n_players :] = start_feats # last n_players slots: start tokens. 155 | 156 | outputs = self.transformer(combined.unsqueeze(1), self.b2vpp_mask) 157 | preds = self.traj_classifier(outputs.squeeze(1)[: -self.n_players : 2]) # predictions are read from the position tokens only. 158 | 159 | return preds 160 | -------------------------------------------------------------------------------- /baller2vecplusplus_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from settings import COURT_LENGTH, COURT_WIDTH, GAMES_DIR 5 | from torch.utils.data import Dataset 6 | 7 | # The raw data is recorded at 25 Hz.
8 | RAW_DATA_HZ = 25 9 | 10 | 11 | class Baller2VecPlusPlusDataset(Dataset): # Yields downsampled windows of all 10 players from {GAMES_DIR}/{gameid}_X.npy; 'train' samples a random game and start, 'valid'/'test' use gameids[idx] and starts[idx]. 12 | def __init__( 13 | self, 14 | hz, 15 | secs, 16 | N, 17 | player_traj_n, 18 | max_player_move, 19 | gameids, 20 | starts, 21 | mode, 22 | n_player_ids, 23 | ): 24 | self.skip = int(RAW_DATA_HZ / hz) # keep every skip-th raw frame. 25 | self.chunk_size = int(RAW_DATA_HZ * secs) # window length in raw frames. 26 | self.seq_len = self.chunk_size // self.skip 27 | 28 | self.N = N 29 | self.n_players = 10 30 | self.gameids = gameids 31 | self.starts = starts 32 | self.mode = mode 33 | self.n_player_ids = n_player_ids 34 | 35 | self.player_traj_n = player_traj_n 36 | self.max_player_move = max_player_move 37 | self.player_traj_bins = np.linspace( # player_traj_n - 1 edges, so np.digitize yields player_traj_n bins per axis. 38 | -max_player_move, max_player_move, player_traj_n - 1 39 | ) 40 | 41 | def __len__(self): 42 | return self.N 43 | 44 | def get_sample(self, X, start): 45 | # Downsample. 46 | seq_data = X[start : start + self.chunk_size : self.skip] 47 | 48 | keep_players = np.random.choice(np.arange(10), self.n_players, False) # with n_players == 10 this is a random permutation of the players. 49 | if self.mode in {"valid", "test"}: 50 | keep_players.sort() 51 | 52 | # End sequence early if there is a position glitch. Often happens when there was 53 | # a break in the game, but glitches also happen for other reasons. A glitch is a 54 | # per-step move larger than 1.2 * max_player_move. NOTE(review): glitch_example.py is not in this repository.
55 | player_xs = seq_data[:, 20:30][:, keep_players] 56 | player_ys = seq_data[:, 30:40][:, keep_players] 57 | player_x_diffs = np.diff(player_xs, axis=0) 58 | player_y_diffs = np.diff(player_ys, axis=0) 59 | try: 60 | glitch_x_break = np.where( 61 | np.abs(player_x_diffs) > 1.2 * self.max_player_move 62 | )[0].min() 63 | except ValueError: 64 | glitch_x_break = len(seq_data) 65 | 66 | try: 67 | glitch_y_break = np.where( 68 | np.abs(player_y_diffs) > 1.2 * self.max_player_move 69 | )[0].min() 70 | except ValueError: 71 | glitch_y_break = len(seq_data) 72 | 73 | seq_break = min(glitch_x_break, glitch_y_break) 74 | seq_data = seq_data[:seq_break] 75 | 76 | player_idxs = seq_data[:, 10:20][:, keep_players].astype(int) # column layout used here: 10:20 player idxs, 20:30 x, 30:40 y, 40:50 hoop sides. 77 | player_xs = seq_data[:, 20:30][:, keep_players] 78 | player_ys = seq_data[:, 30:40][:, keep_players] 79 | player_hoop_sides = seq_data[:, 40:50][:, keep_players].astype(int) 80 | 81 | # Randomly rotate the court because the hoop direction is arbitrary. 82 | if (self.mode == "train") and (np.random.random() < 0.5): 83 | player_xs = COURT_LENGTH - player_xs 84 | player_ys = COURT_WIDTH - player_ys 85 | player_hoop_sides = (player_hoop_sides + 1) % 2 86 | 87 | # Get player trajectories.
88 | player_x_diffs = np.diff(player_xs, axis=0) 89 | player_y_diffs = np.diff(player_ys, axis=0) 90 | 91 | player_traj_rows = np.digitize(player_y_diffs, self.player_traj_bins) 92 | player_traj_cols = np.digitize(player_x_diffs, self.player_traj_bins) 93 | player_trajs = player_traj_rows * self.player_traj_n + player_traj_cols # joint (y bin, x bin) label. 94 | 95 | return { 96 | "player_idxs": torch.LongTensor(player_idxs[: seq_break - 1]), # per-step inputs drop the last step so they align with the diffs/targets. 97 | "player_xs": torch.Tensor(player_xs[: seq_break - 1]), 98 | "player_ys": torch.Tensor(player_ys[: seq_break - 1]), 99 | "player_x_diffs": torch.Tensor(player_x_diffs), 100 | "player_y_diffs": torch.Tensor(player_y_diffs), 101 | "player_hoop_sides": torch.Tensor(player_hoop_sides[: seq_break - 1]), 102 | "player_trajs": torch.LongTensor(player_trajs), 103 | } 104 | 105 | def __getitem__(self, idx): # mode must be 'train', 'valid', or 'test'; any other value leaves gameid/start unbound. 106 | if self.mode == "train": 107 | gameid = np.random.choice(self.gameids) 108 | 109 | elif self.mode in {"valid", "test"}: 110 | gameid = self.gameids[idx] 111 | 112 | X = np.load(f"{GAMES_DIR}/{gameid}_X.npy") 113 | 114 | if self.mode == "train": 115 | start = np.random.randint(len(X) - self.chunk_size) 116 | 117 | elif self.mode in {"valid", "test"}: 118 | start = self.starts[idx] 119 | 120 | return self.get_sample(X, start) 121 | -------------------------------------------------------------------------------- /court.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airalcorn2/baller2vecplusplus/e9753be4995300e4e2057ebd06f6ad96131f7ba9/court.png -------------------------------------------------------------------------------- /events.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airalcorn2/baller2vecplusplus/e9753be4995300e4e2057ebd06f6ad96131f7ba9/events.zip -------------------------------------------------------------------------------- /generate_game_numpy_arrays.py:
-------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import numpy as np 3 | import pandas as pd 4 | import pickle 5 | import py7zr 6 | import shutil 7 | 8 | from settings import * 9 | 10 | HALF_COURT_LENGTH = COURT_LENGTH // 2 11 | THRESHOLD = 1.0 12 | 13 | 14 | def game_name2gameid_worker(game_7zs, queue): 15 | game_names = [] 16 | gameids = [] 17 | for game_7z in game_7zs: 18 | game_name = game_7z.split(".7z")[0] 19 | game_names.append(game_name) 20 | try: 21 | archive = py7zr.SevenZipFile(f"{TRACKING_DIR}/{game_7z}", mode="r") 22 | archive.extractall(path=f"{TRACKING_DIR}/{game_name}") 23 | archive.close() 24 | except AttributeError: 25 | shutil.rmtree(f"{TRACKING_DIR}/{game_name}") 26 | gameids.append("N/A") 27 | continue 28 | 29 | try: 30 | gameids.append(os.listdir(f"{TRACKING_DIR}/{game_name}")[0].split(".")[0]) 31 | except IndexError: 32 | gameids.append("N/A") 33 | 34 | shutil.rmtree(f"{TRACKING_DIR}/{game_name}") 35 | 36 | queue.put((game_names, gameids)) 37 | 38 | 39 | def get_game_name2gameid_map(): 40 | q = multiprocessing.Queue() 41 | 42 | dir_fs = os.listdir(TRACKING_DIR) 43 | all_game_7zs = [dir_f for dir_f in dir_fs if dir_f.endswith(".7z")] 44 | processes = multiprocessing.cpu_count() 45 | game_7zs_per_process = int(np.ceil(len(all_game_7zs) / processes)) 46 | jobs = [] 47 | for i in range(processes): 48 | start = i * game_7zs_per_process 49 | end = start + game_7zs_per_process 50 | game_7zs = all_game_7zs[start:end] 51 | p = multiprocessing.Process(target=game_name2gameid_worker, args=(game_7zs, q)) 52 | jobs.append(p) 53 | p.start() 54 | 55 | all_game_names = [] 56 | all_gameids = [] 57 | for _ in jobs: 58 | (game_names, gameids) = q.get() 59 | all_game_names.extend(game_names) 60 | all_gameids.extend(gameids) 61 | 62 | for p in jobs: 63 | p.join() 64 | 65 | df = pd.DataFrame.from_dict({"game_name": all_game_names, "gameid": all_gameids}) 66 | home_dir = os.path.expanduser("~") 67 | 
df.to_csv(f"{home_dir}/test.csv", index=False) 68 | 69 | 70 | def playerid2player_idx_map_worker(game_7zs, queue): 71 | playerid2props = {} 72 | for game_7z in game_7zs: 73 | game_name = game_7z.split(".7z")[0] 74 | try: 75 | archive = py7zr.SevenZipFile(f"{TRACKING_DIR}/{game_7z}", mode="r") 76 | archive.extractall(path=f"{TRACKING_DIR}/{game_name}") 77 | archive.close() 78 | except AttributeError: 79 | print(f"{game_name}\nBusted.", flush=True) 80 | shutil.rmtree(f"{TRACKING_DIR}/{game_name}") 81 | continue 82 | 83 | try: 84 | gameid = os.listdir(f"{TRACKING_DIR}/{game_name}")[0].split(".")[0] 85 | except IndexError: 86 | print(f"No tracking data for {game_name}.", flush=True) 87 | shutil.rmtree(f"{TRACKING_DIR}/{game_name}") 88 | continue 89 | 90 | df_tracking = pd.read_json(f"{TRACKING_DIR}/{game_name}/{gameid}.json") 91 | event = df_tracking["events"].iloc[0] 92 | players = event["home"]["players"] + event["visitor"]["players"] 93 | for player in players: 94 | playerid = player["playerid"] 95 | playerid2props[playerid] = { 96 | "name": " ".join([player["firstname"], player["lastname"]]), 97 | } 98 | 99 | queue.put(playerid2props) 100 | 101 | 102 | def get_playerid2player_idx_map(): 103 | q = multiprocessing.Queue() 104 | 105 | dir_fs = os.listdir(TRACKING_DIR) 106 | all_game_7zs = [dir_f for dir_f in dir_fs if dir_f.endswith(".7z")] 107 | processes = multiprocessing.cpu_count() 108 | game_7zs_per_process = int(np.ceil(len(all_game_7zs) / processes)) 109 | jobs = [] 110 | for i in range(processes): 111 | start = i * game_7zs_per_process 112 | end = start + game_7zs_per_process 113 | game_7zs = all_game_7zs[start:end] 114 | p = multiprocessing.Process( 115 | target=playerid2player_idx_map_worker, args=(game_7zs, q) 116 | ) 117 | jobs.append(p) 118 | p.start() 119 | 120 | playerid2props = {} 121 | for _ in jobs: 122 | playerid2props.update(q.get()) 123 | 124 | for p in jobs: 125 | p.join() 126 | 127 | playerid2player_idx = {} 128 | player_idx2props = {} 129 | for 
(player_idx, playerid) in enumerate(playerid2props): 130 | playerid2player_idx[playerid] = player_idx 131 | player_idx2props[player_idx] = playerid2props[playerid] 132 | player_idx2props[player_idx]["playerid"] = playerid 133 | 134 | return (playerid2player_idx, player_idx2props) 135 | 136 | 137 | def get_game_time(game_clock_secs, period): 138 | period_secs = 720 if period <= 4 else 300 139 | period_time = period_secs - game_clock_secs 140 | if period <= 4: 141 | return (period - 1) * 720 + period_time 142 | else: 143 | return 4 * 720 + (period - 5) * 300 + period_time 144 | 145 | 146 | def get_shot_times_worker(game_7zs, queue): 147 | shot_times = {} 148 | for game_7z in game_7zs: 149 | game_name = game_7z.split(".7z")[0] 150 | try: 151 | gameid = os.listdir(f"{TRACKING_DIR}/{game_name}")[0].split(".")[0] 152 | except FileNotFoundError: 153 | continue 154 | 155 | df_events = pd.read_csv(f"{EVENTS_DIR}/{gameid}.csv") 156 | game_shot_times = {} 157 | for (row_idx, row) in df_events.iterrows(): 158 | period = row["PERIOD"] 159 | game_clock = row["PCTIMESTRING"].split(":") 160 | game_clock_secs = 60 * int(game_clock[0]) + int(game_clock[1]) 161 | game_time = get_game_time(game_clock_secs, period) 162 | 163 | if (row["EVENTMSGTYPE"] == 1) or (row["EVENTMSGTYPE"] == 2): 164 | game_shot_times[game_time] = row["PLAYER1_TEAM_ABBREVIATION"] 165 | 166 | shot_times[gameid] = game_shot_times 167 | 168 | queue.put(shot_times) 169 | 170 | 171 | def get_shot_times(): 172 | q = multiprocessing.Queue() 173 | 174 | dir_fs = os.listdir(TRACKING_DIR) 175 | all_game_7zs = [dir_f for dir_f in dir_fs if dir_f.endswith(".7z")] 176 | processes = multiprocessing.cpu_count() 177 | game_7zs_per_process = int(np.ceil(len(all_game_7zs) / processes)) 178 | jobs = [] 179 | for i in range(processes): 180 | start = i * game_7zs_per_process 181 | end = start + game_7zs_per_process 182 | game_7zs = all_game_7zs[start:end] 183 | p = multiprocessing.Process(target=get_shot_times_worker, 
args=(game_7zs, q)) 184 | jobs.append(p) 185 | p.start() 186 | 187 | shot_times = {} 188 | for _ in jobs: 189 | shot_times.update(q.get()) 190 | 191 | for p in jobs: 192 | p.join() 193 | 194 | return shot_times 195 | 196 | 197 | def fill_in_periods(game_hoop_sides): 198 | for p1 in range(4): 199 | if p1 not in game_hoop_sides: 200 | continue 201 | 202 | adder = 1 if p1 <= 2 else 3 203 | p2 = (p1 % 2) + adder 204 | if p2 not in game_hoop_sides: 205 | game_hoop_sides[p2] = game_hoop_sides[p1].copy() 206 | 207 | period_hoop_sides = list(game_hoop_sides[p1].items()) 208 | swapped = { 209 | period_hoop_sides[0][0]: period_hoop_sides[1][1], 210 | period_hoop_sides[1][0]: period_hoop_sides[0][1], 211 | } 212 | 213 | p2s = [3, 4] if p1 <= 2 else [1, 2] 214 | for p2 in p2s: 215 | if p2 not in game_hoop_sides: 216 | game_hoop_sides[p2] = swapped.copy() 217 | 218 | return game_hoop_sides 219 | 220 | 221 | def check_periods(game_hoop_sides, team, game): 222 | for period in range(4): 223 | if period in {2, 4}: 224 | if period - 1 in game_hoop_sides: 225 | assert ( 226 | game_hoop_sides[period - 1][team] == game_hoop_sides[period][team] 227 | ), f"{team} has different sides in periods {period} and {period - 1} of {game}." 228 | 229 | if period in {3, 4}: 230 | for first_half in [1, 2]: 231 | if first_half in game_hoop_sides: 232 | assert ( 233 | game_hoop_sides[first_half][team] 234 | != game_hoop_sides[period][team] 235 | ), f"{team} has same side in periods {first_half} and {period} of {game}." 
236 | 237 | 238 | def get_game_hoop_sides(teams, hoop_side_counts, game): 239 | [team_a, team_b] = list(teams) 240 | game_hoop_sides = {period: {} for period in hoop_side_counts} 241 | do_check = False 242 | periods = list(hoop_side_counts) 243 | periods.sort() 244 | for period in periods: 245 | if len(hoop_side_counts[period]) == 0: 246 | print(f"No shooting data for {period} of {game}.") 247 | continue 248 | 249 | for team in teams: 250 | if team not in hoop_side_counts[period]: 251 | print(f"Missing {team} for {period} of {game}.") 252 | hoop_side_counts[period][team] = {0: 0, COURT_LENGTH: 0} 253 | 254 | l_count = hoop_side_counts[period][team][0] 255 | r_count = hoop_side_counts[period][team][COURT_LENGTH] 256 | if l_count > r_count: 257 | game_hoop_sides[period][team] = 0 258 | elif l_count < r_count: 259 | game_hoop_sides[period][team] = COURT_LENGTH 260 | else: 261 | do_check = True 262 | 263 | if do_check: 264 | team_in = team_a if team_a in game_hoop_sides[period] else team_b 265 | team_out = team_a if team_a not in game_hoop_sides[period] else team_b 266 | hoop_side = game_hoop_sides[period][team_in] 267 | if hoop_side == 0: 268 | game_hoop_sides[period][team_out] = 0 269 | else: 270 | game_hoop_sides[period][team_out] = COURT_LENGTH 271 | 272 | do_check = False 273 | 274 | assert ( 275 | game_hoop_sides[period][team_a] != game_hoop_sides[period][team_b] 276 | ), f"{team_a} and {team_b} have same side in {period} of {game}." 
277 | 278 | game_hoop_sides = fill_in_periods(game_hoop_sides) 279 | 280 | for team in teams: 281 | check_periods(game_hoop_sides, team, game) 282 | 283 | return game_hoop_sides 284 | 285 | 286 | def get_hoop_sides_worker(game_7zs, queue): 287 | hoop_sides = {} 288 | for game_7z in game_7zs: 289 | game_name = game_7z.split(".7z")[0] 290 | 291 | try: 292 | gameid = os.listdir(f"{TRACKING_DIR}/{game_name}")[0].split(".")[0] 293 | except FileNotFoundError: 294 | continue 295 | 296 | df_tracking = pd.read_json(f"{TRACKING_DIR}/{game_name}/{gameid}.json") 297 | hoop_side_counts = {} 298 | used_game_times = set() 299 | teams = set() 300 | for tracking_event in df_tracking["events"]: 301 | for moment in tracking_event["moments"]: 302 | period = moment[0] 303 | game_clock = moment[2] 304 | game_time = int(get_game_time(game_clock, period)) 305 | if (game_time in shot_times[gameid]) and ( 306 | game_time not in used_game_times 307 | ): 308 | ball_x = moment[5][0][2] 309 | 310 | if ball_x < HALF_COURT_LENGTH: 311 | hoop_side = 0 312 | else: 313 | hoop_side = COURT_LENGTH 314 | 315 | if period not in hoop_side_counts: 316 | hoop_side_counts[period] = {} 317 | 318 | shooting_team = shot_times[gameid][game_time] 319 | if shooting_team not in hoop_side_counts[period]: 320 | hoop_side_counts[period][shooting_team] = { 321 | 0: 0, 322 | COURT_LENGTH: 0, 323 | } 324 | 325 | hoop_side_counts[period][shooting_team][hoop_side] += 1 326 | used_game_times.add(game_time) 327 | teams.add(shooting_team) 328 | 329 | if len(teams) == 0: 330 | print(f"The moments in the {game_name} JSON are empty.", flush=True) 331 | continue 332 | 333 | hoop_sides[gameid] = get_game_hoop_sides(teams, hoop_side_counts, game_name) 334 | 335 | queue.put(hoop_sides) 336 | 337 | 338 | def get_team_hoop_sides(): 339 | q = multiprocessing.Queue() 340 | 341 | dir_fs = os.listdir(TRACKING_DIR) 342 | all_game_7zs = [dir_f for dir_f in dir_fs if dir_f.endswith(".7z")] 343 | processes = multiprocessing.cpu_count() 344 | 
game_7zs_per_process = int(np.ceil(len(all_game_7zs) / processes)) 345 | jobs = [] 346 | for i in range(processes): 347 | start = i * game_7zs_per_process 348 | end = start + game_7zs_per_process 349 | game_7zs = all_game_7zs[start:end] 350 | p = multiprocessing.Process(target=get_hoop_sides_worker, args=(game_7zs, q)) 351 | jobs.append(p) 352 | p.start() 353 | 354 | hoop_sides = {} 355 | for _ in jobs: 356 | hoop_sides.update(q.get()) 357 | 358 | for p in jobs: 359 | p.join() 360 | 361 | return hoop_sides 362 | 363 | 364 | def get_event_stream(gameid): 365 | df_events = pd.read_csv(f"{EVENTS_DIR}/{gameid}.csv") 366 | df_events = df_events.fillna("") 367 | df_events["DESCRIPTION"] = ( 368 | df_events["HOMEDESCRIPTION"] 369 | + " " 370 | + df_events["NEUTRALDESCRIPTION"] 371 | + " " 372 | + df_events["VISITORDESCRIPTION"] 373 | ) 374 | 375 | # Posession times. 376 | # EVENTMSGTYPE descriptions can be found at: https://github.com/rd11490/NBA_Tutorials/tree/master/analyze_play_by_play. 377 | event_col = "EVENTMSGTYPE" 378 | description_col = "DESCRIPTION" 379 | player1_team_col = "PLAYER1_TEAM_ABBREVIATION" 380 | 381 | teams = list(df_events[player1_team_col].unique()) 382 | teams.sort() 383 | teams = teams[1:] if len(teams) > 2 else teams 384 | 385 | event = None 386 | score = "0 - 0" 387 | # pos_team is the team that had possession prior to the event. 388 | (pos_team, pos_team_idx) = (None, None) 389 | jump_ball_team_idx = None 390 | # I think most of these are technical fouls. 391 | skip_fouls = {10, 11, 16, 19} 392 | 393 | events = set() 394 | pos_stream = [] 395 | event_stream = [] 396 | for (row_idx, row) in df_events.iterrows(): 397 | period = row["PERIOD"] 398 | game_clock = row["PCTIMESTRING"].split(":") 399 | game_clock_secs = 60 * int(game_clock[0]) + int(game_clock[1]) 400 | game_time = get_game_time(game_clock_secs, period) 401 | 402 | description = row[description_col].lower().strip() 403 | 404 | eventmsgtype = row[event_col] 405 | 406 | # Don't know. 
407 | if eventmsgtype == 18: 408 | continue 409 | 410 | # Blank line. 411 | elif eventmsgtype == 14: 412 | continue 413 | 414 | # End of a period. 415 | elif eventmsgtype == 13: 416 | if period == 4: 417 | jump_ball_team_idx = None 418 | 419 | continue 420 | 421 | # Start of a period. 422 | elif eventmsgtype == 12: 423 | if 2 <= period <= 4: 424 | if period == 4: 425 | pos_team_idx = jump_ball_team_idx 426 | else: 427 | pos_team_idx = (jump_ball_team_idx + 1) % 2 428 | 429 | elif 6 <= period: 430 | pos_team_idx = (jump_ball_team_idx + (period - 5)) % 2 431 | 432 | continue 433 | 434 | # Ejection. 435 | elif eventmsgtype == 11: 436 | continue 437 | 438 | # Jump ball. 439 | elif eventmsgtype == 10: 440 | if int(row["PLAYER3_ID"]) in TEAM_ID2PROPS: 441 | pos_team = TEAM_ID2PROPS[int(row["PLAYER3_ID"])]["abbreviation"] 442 | else: 443 | pos_team = row["PLAYER3_TEAM_ABBREVIATION"] 444 | 445 | if pos_team == teams[0]: 446 | pos_team_idx = 0 447 | else: 448 | pos_team_idx = 1 449 | 450 | if (period in {1, 5}) and (jump_ball_team_idx is None): 451 | jump_ball_team_idx = pos_team_idx 452 | 453 | continue 454 | 455 | # Timeout. 456 | elif eventmsgtype == 9: 457 | # TV timeout? 458 | if description == "": 459 | continue 460 | 461 | event = "timeout" 462 | 463 | # Substitution. 464 | elif eventmsgtype == 8: 465 | continue 466 | 467 | # Violation. 468 | elif eventmsgtype == 7: 469 | # With 35 seconds left in the fourth period, there was a kicked ball 470 | # violation attributed to Wayne Ellington of the Brooklyn Nets, but the 471 | # following event is a shot by the Nets, which means possession never changed. 472 | if (gameid == "0021500414") and (row_idx == 427): 473 | continue 474 | 475 | # Goaltending is considered a made shot, so the following event is always the 476 | # made shot event. 477 | if "goaltending" in description: 478 | score = row["SCORE"] if row["SCORE"] else score 479 | continue 480 | # Jump ball violations have weird possession rules. 
481 | elif "jump ball" in description: 482 | if row["PLAYER1_TEAM_ABBREVIATION"] == teams[0]: 483 | pos_team_idx = 1 484 | else: 485 | pos_team_idx = 0 486 | 487 | pos_team = teams[pos_team_idx] 488 | if (period == 1) and (game_time == 0): 489 | jump_ball_team_idx = pos_team_idx 490 | 491 | continue 492 | 493 | else: 494 | if row[player1_team_col] == teams[pos_team_idx]: 495 | event = "violation_offense" 496 | pos_team_idx = (pos_team_idx + 1) % 2 497 | else: 498 | event = "violation_defense" 499 | 500 | # Foul. 501 | elif eventmsgtype == 6: 502 | # Skip weird fouls. 503 | if row["EVENTMSGACTIONTYPE"] in skip_fouls: 504 | score = row["SCORE"] if row["SCORE"] else score 505 | continue 506 | else: 507 | if row[player1_team_col] == teams[pos_team_idx]: 508 | event = "offensive_foul" 509 | pos_team_idx = (pos_team_idx + 1) % 2 510 | else: 511 | event = "defensive_foul" 512 | 513 | # Turnover. 514 | elif eventmsgtype == 5: 515 | if "steal" in description: 516 | event = "steal" 517 | elif "goaltending" in description: 518 | event = "goaltending_offense" 519 | elif ( 520 | ("violation" in description) 521 | or ("dribble" in description) 522 | or ("traveling" in description) 523 | ): 524 | event = "violation_offense" 525 | else: 526 | event = "turnover" 527 | 528 | # Team turnover. 529 | if row[player1_team_col] == "": 530 | team_id = int(row["PLAYER1_ID"]) 531 | team_abb = TEAM_ID2PROPS[team_id]["abbreviation"] 532 | else: 533 | team_abb = row[player1_team_col] 534 | 535 | pos_team_idx = 1 if team_abb == teams[0] else 0 536 | 537 | # Rebound. 538 | elif eventmsgtype == 4: 539 | # With 17 seconds left in the first period, Spencer Hawes missed a tip in, 540 | # which was rebounded by DeAndre Jordan. The tip in is recorded as a rebound 541 | # and a missed shot for Hawes. All three events have the same timestamp, 542 | # which seems to have caused the order of the events to be slightly shuffled 543 | # with the Jordan rebound occurring before the tip in. 
544 | if (gameid == "0021500550") and (row_idx == 97): 545 | continue 546 | 547 | # Team rebound. 548 | if row[player1_team_col] == "": 549 | team_id = int(row["PLAYER1_ID"]) 550 | team_abb = TEAM_ID2PROPS[team_id]["abbreviation"] 551 | if team_abb == teams[pos_team_idx]: 552 | event = "rebound_offense" 553 | else: 554 | event = "rebound_defense" 555 | pos_team_idx = (pos_team_idx + 1) % 2 556 | 557 | elif row[player1_team_col] == teams[pos_team_idx]: 558 | event = "rebound_offense" 559 | else: 560 | event = "rebound_defense" 561 | pos_team_idx = (pos_team_idx + 1) % 2 562 | 563 | # Free throw. 564 | elif eventmsgtype == 3: 565 | # See rules for technical fouls: https://official.nba.com/rule-no-12-fouls-and-penalties/. 566 | # Possession only changes for too many players, which is extremely rare. 567 | if "technical" not in description: 568 | pos_team_idx = 0 if row[player1_team_col] == teams[0] else 1 569 | if ( 570 | ("Clear Path" not in row[description_col]) 571 | and ("Flagrant" not in row[description_col]) 572 | and ("MISS" not in row[description_col]) 573 | and ( 574 | ("1 of 1" in description) 575 | or ("2 of 2" in description) 576 | or ("3 of 3" in description) 577 | ) 578 | ): 579 | # Hack to handle foul shots for away from play fouls. 580 | if ((gameid == "0021500274") and (row_idx == 519)) or ( 581 | (gameid == "0021500572") and (row_idx == 428) 582 | ): 583 | pass 584 | # This event is a made foul shot by Thaddeus Young of the Brooklyn 585 | # Nets following an and-one foul, so possession should have changed 586 | # to the Milwaukee Bucks. However, the next event is a made shot by 587 | # Brook Lopez (also of the Brooklyn Nets) with no event indicating a 588 | # change of possession occurring before it. 
589 | elif (gameid == "0021500047") and (row_idx == 64): 590 | pass 591 | else: 592 | pos_team_idx = (pos_team_idx + 1) % 2 593 | 594 | pos_team = teams[pos_team_idx] 595 | 596 | score = row["SCORE"] if row["SCORE"] else score 597 | 598 | continue 599 | 600 | # Missed shot. 601 | elif eventmsgtype == 2: 602 | if "dunk" in description: 603 | shot_type = "dunk" 604 | elif "layup" in description: 605 | shot_type = "layup" 606 | else: 607 | shot_type = "shot" 608 | 609 | if "BLOCK" in row[description_col]: 610 | miss_type = "block" 611 | else: 612 | miss_type = "miss" 613 | 614 | event = f"{shot_type}_{miss_type}" 615 | 616 | if row[player1_team_col] != teams[pos_team_idx]: 617 | print(pos_stream[-5:]) 618 | raise ValueError(f"Incorrect possession team in row {str(row_idx)}.") 619 | 620 | # Made shot. 621 | elif eventmsgtype == 1: 622 | if "dunk" in description: 623 | shot_type = "dunk" 624 | elif "layup" in description: 625 | shot_type = "layup" 626 | else: 627 | shot_type = "shot" 628 | 629 | event = f"{shot_type}_made" 630 | 631 | if row[player1_team_col] != teams[pos_team_idx]: 632 | print(pos_stream[-5:]) 633 | raise ValueError(f"Incorrect possession team in row {str(row_idx)}.") 634 | 635 | pos_team_idx = (pos_team_idx + 1) % 2 636 | 637 | events.add(event) 638 | pos_stream.append(pos_team_idx) 639 | 640 | if row[player1_team_col] == "": 641 | team_id = int(row["PLAYER1_ID"]) 642 | event_team = TEAM_ID2PROPS[team_id]["abbreviation"] 643 | else: 644 | event_team = row[player1_team_col] 645 | 646 | event_stream.append( 647 | { 648 | "game_time": game_time - 1, 649 | "pos_team": pos_team, 650 | "event": event, 651 | "description": description, 652 | "event_team": event_team, 653 | "score": score, 654 | } 655 | ) 656 | 657 | # With 17 seconds left in the first period, Spencer Hawes missed a tip in, 658 | # which was rebounded by DeAndre Jordan. The tip in is recorded as a rebound 659 | # and a missed shot for Hawes. 
All three events have the same timestamp, 660 | # which seems to have caused the order of the events to be slightly shuffled 661 | # with the Jordan rebound occurring before the tip in. 662 | if (gameid == "0021500550") and (row_idx == 98): 663 | event_stream.append( 664 | { 665 | "game_time": game_time - 1, 666 | "pos_team": pos_team, 667 | "event": "rebound_defense", 668 | "description": "jordan rebound (off:2 def:5)", 669 | "event_team": "LAC", 670 | "score": score, 671 | } 672 | ) 673 | pos_team_idx = (pos_team_idx + 1) % 2 674 | 675 | # This event is a missed shot by Kawhi Leonard of the San Antonio Spurs. The next 676 | # event is a missed shot by Bradley Beal of the Washington Wizards with no 677 | # event indicating a change of possession occurring before it. 678 | if (gameid == "0021500061") and (row_idx == 240): 679 | pos_team_idx = (pos_team_idx + 1) % 2 680 | 681 | pos_team = teams[pos_team_idx] 682 | score = row["SCORE"] if row["SCORE"] else score 683 | 684 | return event_stream 685 | 686 | 687 | def get_event_streams_worker(game_names, queue): 688 | gameid2event_stream = {} 689 | for (game_idx, game_name) in enumerate(game_names): 690 | try: 691 | gameid = os.listdir(f"{TRACKING_DIR}/{game_name}")[0].split(".")[0] 692 | gameid2event_stream[gameid] = get_event_stream(gameid) 693 | except IndexError: 694 | continue 695 | 696 | queue.put(gameid2event_stream) 697 | 698 | 699 | def get_event_streams(): 700 | q = multiprocessing.Queue() 701 | 702 | dir_fs = os.listdir(TRACKING_DIR) 703 | all_game_names = [dir_f for dir_f in dir_fs if not dir_f.endswith(".7z")] 704 | 705 | processes = multiprocessing.cpu_count() 706 | game_names_per_process = int(np.ceil(len(all_game_names) / processes)) 707 | jobs = [] 708 | for i in range(processes): 709 | start = i * game_names_per_process 710 | end = start + game_names_per_process 711 | game_names = all_game_names[start:end] 712 | p = multiprocessing.Process( 713 | target=get_event_streams_worker, args=(game_names, q) 
714 | ) 715 | jobs.append(p) 716 | p.start() 717 | 718 | gameid2event_stream = {} 719 | for _ in jobs: 720 | gameid2event_stream.update(q.get()) 721 | 722 | for p in jobs: 723 | p.join() 724 | 725 | event2event_idx = {} 726 | for event_stream in gameid2event_stream.values(): 727 | for event in event_stream: 728 | event2event_idx.setdefault(event["event"], len(event2event_idx)) 729 | 730 | return (event2event_idx, gameid2event_stream) 731 | 732 | 733 | def add_score_changes(X): 734 | score_diff_idx = 6 735 | period_idx = 3 736 | wall_clock_idx = -1 737 | 738 | score_change_idxs = np.where(np.diff(X[:, score_diff_idx]) != 0)[0] + 1 739 | score_changes = ( 740 | X[score_change_idxs, score_diff_idx] - X[score_change_idxs - 1, score_diff_idx] 741 | ) 742 | 743 | # Score changes at the half are the result of the teams changing sides. 744 | half_idx = -1 745 | if (X[:, period_idx].min() <= 2) and (X[:, period_idx].max() >= 3): 746 | period_change_idxs = np.where(np.diff(X[:, 3]) != 0)[0] + 1 747 | for period_change_idx in period_change_idxs: 748 | before_period = X[period_change_idx - 1, period_idx] 749 | after_period = X[period_change_idx, period_idx] 750 | if (before_period <= 2) and (after_period > 2): 751 | half_idx = period_change_idx 752 | break 753 | 754 | else: 755 | half_idx = -1 756 | 757 | cur_score_change_idx = 0 758 | times_to_next_score_change = [] 759 | next_score_changes = [] 760 | end_time = X[-1, wall_clock_idx] 761 | for (idx, row) in enumerate(X): 762 | try: 763 | if idx == score_change_idxs[cur_score_change_idx]: 764 | cur_score_change_idx += 1 765 | 766 | score_change_idx = score_change_idxs[cur_score_change_idx] 767 | next_score_time = X[score_change_idx, wall_clock_idx] 768 | if score_change_idx == half_idx: 769 | next_score_change = 0 770 | else: 771 | next_score_change = score_changes[cur_score_change_idx] 772 | 773 | except IndexError: 774 | next_score_time = end_time 775 | next_score_change = 0 776 | 777 | cur_time = X[idx, wall_clock_idx] 778 
| time_to_next_score_change = (next_score_time - cur_time) / 1000 779 | times_to_next_score_change.append(time_to_next_score_change) 780 | next_score_changes.append(next_score_change) 781 | 782 | times_to_next_score_change = np.array(times_to_next_score_change)[None].T 783 | next_score_changes = np.array(next_score_changes)[None].T 784 | 785 | X = np.hstack( 786 | [ 787 | X[:, :wall_clock_idx], 788 | times_to_next_score_change, 789 | next_score_changes, 790 | X[:, wall_clock_idx:], 791 | ] 792 | ) 793 | return X 794 | 795 | 796 | def save_game_numpy_arrays(game_name): 797 | try: 798 | gameid = os.listdir(f"{TRACKING_DIR}/{game_name}")[0].split(".")[0] 799 | except IndexError: 800 | shutil.rmtree(f"{TRACKING_DIR}/{game_name}") 801 | return 802 | 803 | if gameid not in gameid2event_stream: 804 | print(f"Missing gameid: {gameid}", flush=True) 805 | return 806 | 807 | df_tracking = pd.read_json(f"{TRACKING_DIR}/{game_name}/{gameid}.json") 808 | home_team = None 809 | cur_time = -1 810 | event_idx = 0 811 | game_over = False 812 | X = [] 813 | y = [] 814 | event_stream = gameid2event_stream[gameid] 815 | for tracking_event in df_tracking["events"]: 816 | event_id = tracking_event["eventId"] 817 | if home_team is None: 818 | home_team_id = tracking_event["home"]["teamid"] 819 | home_team = TEAM_ID2PROPS[home_team_id]["abbreviation"] 820 | 821 | moments = tracking_event["moments"] 822 | for moment in moments: 823 | period = moment[0] 824 | # Milliseconds. 825 | wall_clock = moment[1] 826 | game_clock = moment[2] 827 | shot_clock = moment[3] 828 | shot_clock = shot_clock if shot_clock else game_clock 829 | 830 | period_time = 720 - game_clock if period <= 4 else 300 - game_clock 831 | game_time = get_game_time(game_clock, period) 832 | 833 | # Moments can overlap temporally, so previously processed time points are 834 | # skipped along with clock stoppages. 
835 | if game_time <= cur_time: 836 | continue 837 | 838 | while game_time > event_stream[event_idx]["game_time"]: 839 | event_idx += 1 840 | if event_idx >= len(event_stream): 841 | game_over = True 842 | break 843 | 844 | if game_over: 845 | break 846 | 847 | event = event_stream[event_idx] 848 | score = event["score"] 849 | (away_score, home_score) = (int(s) for s in score.split(" - ")) 850 | home_hoop_side = hoop_sides[gameid][period][home_team] 851 | if home_hoop_side == 0: 852 | (left_score, right_score) = (home_score, away_score) 853 | else: 854 | (right_score, left_score) = (home_score, away_score) 855 | 856 | (ball_x, ball_y, ball_z) = moment[5][0][2:5] 857 | data = [ 858 | game_time, 859 | period_time, 860 | shot_clock, 861 | period, 862 | left_score, # off_score, 863 | right_score, # def_score, 864 | left_score - right_score, 865 | ball_x, 866 | ball_y, 867 | ball_z, 868 | ] 869 | 870 | if len(moment[5][1:]) != 10: 871 | continue 872 | 873 | player_idxs = [] 874 | player_xs = [] 875 | player_ys = [] 876 | player_hoop_sides = [] 877 | 878 | try: 879 | for player in moment[5][1:]: 880 | player_idxs.append(playerid2player_idx[player[1]]) 881 | player_xs.append(player[2]) 882 | player_ys.append(player[3]) 883 | hoop_side = hoop_sides[gameid][period][ 884 | TEAM_ID2PROPS[player[0]]["abbreviation"] 885 | ] 886 | player_hoop_sides.append(int(hoop_side == COURT_LENGTH)) 887 | 888 | except KeyError: 889 | if player[1] == 0: 890 | print( 891 | f"Bad player in event {event_id} for {game_name}.", flush=True 892 | ) 893 | continue 894 | 895 | else: 896 | raise KeyError 897 | 898 | order = np.argsort(player_idxs) 899 | for idx in order: 900 | data.append(player_idxs[idx]) 901 | 902 | for idx in order: 903 | data.append(player_xs[idx]) 904 | 905 | for idx in order: 906 | data.append(player_ys[idx]) 907 | 908 | for idx in order: 909 | data.append(player_hoop_sides[idx]) 910 | 911 | data.append(event_idx) 912 | data.append(wall_clock) 913 | 914 | if len(data) != 52: 915 
| raise ValueError 916 | 917 | X.append(np.array(data)) 918 | y.append(event2event_idx.setdefault(event["event"], len(event2event_idx))) 919 | cur_time = game_time 920 | 921 | if game_over: 922 | break 923 | 924 | X = np.stack(X) 925 | y = np.array(y) 926 | 927 | X = add_score_changes(X) 928 | 929 | np.save(f"{GAMES_DIR}/{gameid}_X.npy", X) 930 | np.save(f"{GAMES_DIR}/{gameid}_y.npy", y) 931 | 932 | 933 | def save_numpy_arrays_worker(game_names): 934 | for game_name in game_names: 935 | try: 936 | save_game_numpy_arrays(game_name) 937 | except ValueError: 938 | pass 939 | 940 | shutil.rmtree(f"{TRACKING_DIR}/{game_name}") 941 | 942 | 943 | def save_numpy_arrays(): 944 | dir_fs = os.listdir(TRACKING_DIR) 945 | all_game_names = [dir_f for dir_f in dir_fs if not dir_f.endswith(".7z")] 946 | 947 | processes = multiprocessing.cpu_count() 948 | game_names_per_process = int(np.ceil(len(all_game_names) / processes)) 949 | jobs = [] 950 | for i in range(processes): 951 | start = i * game_names_per_process 952 | end = start + game_names_per_process 953 | game_names = all_game_names[start:end] 954 | p = multiprocessing.Process(target=save_numpy_arrays_worker, args=(game_names,)) 955 | jobs.append(p) 956 | p.start() 957 | 958 | for p in jobs: 959 | p.join() 960 | 961 | 962 | def player_idx2playing_time_map_worker(gameids, queue): 963 | player_idx2playing_time = {} 964 | for gameid in gameids: 965 | X = np.load(f"{GAMES_DIR}/{gameid}_X.npy") 966 | wall_clock_diffs = np.diff(X[:, -1]) / 1000 967 | all_player_idxs = X[:, 10:20].astype(int) 968 | prev_players = set(all_player_idxs[0]) 969 | for (row_idx, player_idxs) in enumerate(all_player_idxs[1:]): 970 | current_players = set(player_idxs) 971 | if len(prev_players & current_players) == 10: 972 | wall_clock_diff = wall_clock_diffs[row_idx] 973 | if wall_clock_diff < THRESHOLD: 974 | for player_idx in current_players: 975 | player_idx2playing_time[player_idx] = ( 976 | player_idx2playing_time.get(player_idx, 0) + wall_clock_diff 
977 | ) 978 | 979 | prev_players = current_players 980 | 981 | queue.put(player_idx2playing_time) 982 | 983 | 984 | def get_player_idx2playing_time_map(): 985 | q = multiprocessing.Queue() 986 | 987 | all_gameids = list(set([np_f.split("_")[0] for np_f in os.listdir(GAMES_DIR)])) 988 | processes = multiprocessing.cpu_count() 989 | gameids_per_process = int(np.ceil(len(all_gameids) / processes)) 990 | jobs = [] 991 | for i in range(processes): 992 | start = i * gameids_per_process 993 | end = start + gameids_per_process 994 | gameids = all_gameids[start:end] 995 | p = multiprocessing.Process( 996 | target=player_idx2playing_time_map_worker, args=(gameids, q) 997 | ) 998 | jobs.append(p) 999 | p.start() 1000 | 1001 | player_idx2playing_time = {} 1002 | for _ in jobs: 1003 | game_player_idx2_playing_time = q.get() 1004 | for (player_idx, playing_time) in game_player_idx2_playing_time.items(): 1005 | player_idx2playing_time[player_idx] = ( 1006 | player_idx2playing_time.get(player_idx, 0) + playing_time 1007 | ) 1008 | 1009 | for p in jobs: 1010 | p.join() 1011 | 1012 | playing_times = list(player_idx2playing_time.values()) 1013 | print(np.quantile(playing_times, [0.1, 0.2, 0.25, 0.5, 0.75, 0.8, 0.9])) 1014 | # [ 3364.6814 10270.2988 13314.768 39917.09399999 1015 | # 59131.73249999 63400.76839999 72048.72879999] 1016 | return player_idx2playing_time 1017 | 1018 | 1019 | if __name__ == "__main__": 1020 | os.makedirs(GAMES_DIR, exist_ok=True) 1021 | 1022 | (playerid2player_idx, player_idx2props) = get_playerid2player_idx_map() 1023 | 1024 | try: 1025 | baller2vec_config = pickle.load( 1026 | open(f"{DATA_DIR}/baller2vec_config.pydict", "rb") 1027 | ) 1028 | player_idx2props = baller2vec_config["player_idx2props"] 1029 | event2event_idx = baller2vec_config["event2event_idx"] 1030 | playerid2player_idx = {} 1031 | for (player_idx, props) in player_idx2props.items(): 1032 | playerid2player_idx[props["playerid"]] = player_idx 1033 | 1034 | except FileNotFoundError: 1035 | 
baller2vec_config = False 1036 | 1037 | shot_times = get_shot_times() 1038 | hoop_sides = get_team_hoop_sides() 1039 | (event2event_idx, gameid2event_stream) = get_event_streams() 1040 | if baller2vec_config: 1041 | event2event_idx = baller2vec_config["event2event_idx"] 1042 | 1043 | save_numpy_arrays() 1044 | 1045 | player_idx2playing_time = get_player_idx2playing_time_map() 1046 | for (player_idx, playing_time) in player_idx2playing_time.items(): 1047 | player_idx2props[player_idx]["playing_time"] = playing_time 1048 | 1049 | if not baller2vec_config: 1050 | baller2vec_config = { 1051 | "player_idx2props": player_idx2props, 1052 | "event2event_idx": event2event_idx, 1053 | } 1054 | pickle.dump( 1055 | baller2vec_config, open(f"{DATA_DIR}/baller2vec_config.pydict", "wb") 1056 | ) 1057 | -------------------------------------------------------------------------------- /gif_cropper.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # See: https://stackoverflow.com/q/965053/1316276. 4 | # for f in *.gif; do gifsicle --crop 0,0-250,266 --output ${f%.*}_cropped.gif ${f}; done 5 | for f in *.gif; do gifsicle --crop 250,0-500,266 --output ${f%.*}_cropped.gif ${f}; done -------------------------------------------------------------------------------- /gif_merger.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # See: https://stackoverflow.com/a/30932152/1316276. 4 | # Separate frames of train.gif. 5 | convert train.gif -coalesce a-%04d.gif 6 | # Separate frames of baller2vec.gif. 7 | convert baller2vec.gif -coalesce b-%04d.gif 8 | # Separate frames of baller2vec++.gif. 9 | convert baller2vec++.gif -coalesce c-%04d.gif 10 | # Append frames side-by-side. 11 | for f in a-*.gif; do convert $f -size 10x xc:black ${f/a/b} -size 10x xc:black ${f/a/c} +append ${f}; done 12 | # Rejoin frames. 
13 | convert -loop 0 -delay 40 a-*.gif result.gif 14 | rm a-*.gif b-*.gif c-*.gif -------------------------------------------------------------------------------- /image_cropper.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # See: https://deparkes.co.uk/2015/04/30/batch-crop-images-with-imagemagick/. 4 | mogrify -crop 250x266+250+0 *.png 5 | -------------------------------------------------------------------------------- /paper_functions.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | 3 | matplotlib.use("Agg") 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import pandas as pd 8 | import shutil 9 | import torch 10 | import yaml 11 | 12 | from PIL import Image, ImageDraw 13 | from settings import * 14 | from torch import nn 15 | from toy_dataset import ToyDataset 16 | from train_baller2vecplusplus import init_basketball_datasets, init_model 17 | 18 | colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] 19 | 20 | shuffle_keys = [ 21 | "player_idxs", 22 | "player_xs", 23 | "player_ys", 24 | "player_hoop_sides", 25 | "player_x_diffs", 26 | "player_y_diffs", 27 | "player_trajs", 28 | ] 29 | 30 | 31 | def add_grid(img, steps, highlight_center=True): 32 | # See: https://randomgeekery.org/post/2017/11/drawing-grids-with-python-and-pillow/. 
33 | fill = (128, 128, 128) 34 | 35 | draw = ImageDraw.Draw(img) 36 | y_start = 0 37 | y_end = img.height 38 | step_size = int(img.width / steps) 39 | 40 | for x in range(0, img.width, step_size): 41 | line = ((x, y_start), (x, y_end)) 42 | draw.line(line, fill=fill) 43 | 44 | x_start = 0 45 | x_end = img.width 46 | 47 | for y in range(0, img.height, step_size): 48 | line = ((x_start, y), (x_end, y)) 49 | draw.line(line, fill=fill) 50 | 51 | if highlight_center: 52 | mid_color = (0, 0, 255) 53 | half_steps = steps // 2 54 | (mid_start, mid_end) = ( 55 | half_steps * step_size, 56 | half_steps * step_size + step_size, 57 | ) 58 | draw.line(((mid_start, mid_start), (mid_end, mid_start)), fill=mid_color) 59 | draw.line(((mid_start, mid_end), (mid_end, mid_end)), fill=mid_color) 60 | draw.line(((mid_start, mid_start), (mid_start, mid_end)), fill=mid_color) 61 | draw.line(((mid_end, mid_start), (mid_end, mid_end)), fill=mid_color) 62 | 63 | del draw 64 | 65 | 66 | def gen_imgs_from_sample(tensors, seq_len): 67 | # Create grid background. 68 | scale = 4 69 | img_size = 2 * seq_len + 2 70 | img_array = np.full((img_size, img_size, 3), 255, dtype=np.uint8) 71 | img = Image.fromarray(img_array).resize( 72 | (scale * img_size, scale * img_size), resample=0 73 | ) 74 | add_grid(img, img_size, False) 75 | 76 | (width, height) = (5, 5) 77 | (fig, ax) = plt.subplots(figsize=(width, height)) 78 | pt_scale = 1.4 79 | half_size = img_size / 2 80 | 81 | # Always give the agents the same colors. 
82 | (player_idxs, p_idxs) = tensors["player_idxs"][0].sort() 83 | 84 | traj_imgs = [] 85 | for time_step in range(seq_len): 86 | ax.imshow(img, zorder=0, extent=[-half_size, half_size, -half_size, half_size]) 87 | ax.axis("off") 88 | ax.grid(False) 89 | for (idx, p_idx) in enumerate(p_idxs): 90 | ax.scatter( 91 | [tensors["player_xs"][time_step, p_idx] + 0.5], 92 | [tensors["player_ys"][time_step, p_idx] - 0.5], 93 | s=(pt_scale * matplotlib.rcParams["lines.markersize"]) ** 2, 94 | c=colors[idx], 95 | edgecolors="black", 96 | ) 97 | plt.xlim(auto=False) 98 | plt.ylim(auto=False) 99 | ax.plot( 100 | tensors["player_xs"][: time_step + 1, p_idx] + 0.5, 101 | tensors["player_ys"][: time_step + 1, p_idx] - 0.5, 102 | c=colors[idx], 103 | ) 104 | 105 | plt.subplots_adjust(0, 0, 1, 1) 106 | fig.canvas.draw() 107 | plt.cla() 108 | traj_imgs.append( 109 | Image.frombytes( 110 | "RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb() 111 | ) 112 | ) 113 | 114 | ax.imshow(img, zorder=0, extent=[-half_size, half_size, -half_size, half_size]) 115 | ax.axis("off") 116 | ax.grid(False) 117 | for (idx, p_idx) in enumerate(p_idxs): 118 | ax.scatter( 119 | [tensors["player_xs"][0, p_idx] + 0.5], 120 | [tensors["player_ys"][0, p_idx] - 0.5], 121 | s=(pt_scale * matplotlib.rcParams["lines.markersize"]) ** 2, 122 | c=colors[idx], 123 | edgecolors="black", 124 | ) 125 | plt.xlim(auto=False) 126 | plt.ylim(auto=False) 127 | ax.plot( 128 | tensors["player_xs"][:, p_idx] + 0.5, 129 | tensors["player_ys"][:, p_idx] - 0.5, 130 | c=colors[idx], 131 | ) 132 | 133 | plt.subplots_adjust(0, 0, 1, 1) 134 | fig.canvas.draw() 135 | plt.cla() 136 | traj_img = Image.frombytes( 137 | "RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb() 138 | ) 139 | 140 | return (traj_imgs, traj_img) 141 | 142 | 143 | def animate_toy_examples(): 144 | # Initialize dataset. 
145 | ds = ToyDataset() 146 | n_players = ds.n_players 147 | player_traj_n = ds.player_traj_n 148 | seq_len = ds.seq_len 149 | 150 | home_dir = os.path.expanduser("~") 151 | os.makedirs(f"{home_dir}/results", exist_ok=True) 152 | 153 | tensors = ds[0] 154 | (traj_imgs, traj_img) = gen_imgs_from_sample(tensors, seq_len) 155 | traj_imgs[0].save( 156 | f"{home_dir}/results/train.gif", 157 | save_all=True, 158 | append_images=traj_imgs[1:], 159 | duration=400, 160 | loop=0, 161 | ) 162 | traj_img.save( 163 | f"{home_dir}/results/train.png", 164 | ) 165 | 166 | for JOB in ["20210408161424", "20210408160343"]: 167 | JOB_DIR = f"{EXPERIMENTS_DIR}/{JOB}" 168 | opts = yaml.safe_load(open(f"{JOB_DIR}/{JOB}.yaml")) 169 | 170 | # Initialize model. 171 | device = torch.device("cuda:0") 172 | model = init_model(opts, ds).to(device) 173 | model.load_state_dict(torch.load(f"{JOB_DIR}/best_params.pth")) 174 | model.eval() 175 | 176 | with torch.no_grad(): 177 | tensors = ds[0] 178 | for step in range(seq_len): 179 | preds_start = n_players * step 180 | for player_idx in range(n_players): 181 | pred_idx = preds_start + player_idx 182 | preds = model(tensors)[pred_idx] 183 | probs = torch.softmax(preds, dim=0) 184 | samp_traj = torch.multinomial(probs, 1) 185 | 186 | samp_row = samp_traj // player_traj_n 187 | samp_col = samp_traj % player_traj_n 188 | 189 | samp_x = samp_col.item() - 1 190 | samp_y = samp_row.item() - 1 191 | 192 | tensors["player_x_diffs"][step, player_idx] = samp_x 193 | tensors["player_y_diffs"][step, player_idx] = samp_y 194 | if step < seq_len - 1: 195 | tensors["player_xs"][step + 1, player_idx] = ( 196 | tensors["player_xs"][step, player_idx] + samp_x 197 | ) 198 | tensors["player_ys"][step + 1, player_idx] = ( 199 | tensors["player_ys"][step, player_idx] + samp_y 200 | ) 201 | 202 | (traj_imgs, traj_img) = gen_imgs_from_sample(tensors, seq_len) 203 | traj_imgs[0].save( 204 | f"{home_dir}/results/{JOB}.gif", 205 | save_all=True, 206 | 
append_images=traj_imgs[1:], 207 | duration=400, 208 | loop=0, 209 | ) 210 | traj_img.save(f"{home_dir}/results/{JOB}.png") 211 | 212 | shutil.make_archive(f"{home_dir}/results", "zip", f"{home_dir}/results") 213 | shutil.rmtree(f"{home_dir}/results") 214 | 215 | 216 | def gen_imgs_from_basketball_sample(tensors, seq_len): 217 | court = plt.imread("court.png") 218 | width = 5 219 | height = width * court.shape[0] / float(court.shape[1]) 220 | pt_scale = 1.4 221 | (fig, ax) = plt.subplots(figsize=(width, height)) 222 | 223 | # Always give the agents the same colors. 224 | (player_idxs, p_idxs) = tensors["player_idxs"][0].sort() 225 | markers = [] 226 | for p_idx in range(10): 227 | if tensors["player_hoop_sides"][0, p_idx]: 228 | markers.append("s") 229 | else: 230 | markers.append("^") 231 | 232 | traj_imgs = [] 233 | for time_step in range(seq_len): 234 | ax.imshow(court, zorder=0, extent=[X_MIN, X_MAX - DIFF, Y_MAX, Y_MIN]) 235 | ax.axis("off") 236 | ax.grid(False) 237 | for (idx, p_idx) in enumerate(p_idxs): 238 | ax.scatter( 239 | [tensors["player_xs"][time_step, p_idx]], 240 | [tensors["player_ys"][time_step, p_idx]], 241 | s=(pt_scale * matplotlib.rcParams["lines.markersize"]) ** 2, 242 | c=colors[idx], 243 | marker=markers[p_idx], 244 | edgecolors="black", 245 | ) 246 | plt.xlim(auto=False) 247 | plt.ylim(auto=False) 248 | ax.plot( 249 | tensors["player_xs"][: time_step + 1, p_idx], 250 | tensors["player_ys"][: time_step + 1, p_idx], 251 | c=colors[idx], 252 | ) 253 | 254 | plt.subplots_adjust(0, 0, 1, 1) 255 | fig.canvas.draw() 256 | traj_img = Image.frombytes( 257 | "RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb() 258 | ) 259 | plt.cla() 260 | traj_imgs.append(traj_img) 261 | 262 | ax.imshow(court, zorder=0, extent=[X_MIN, X_MAX - DIFF, Y_MAX, Y_MIN]) 263 | ax.axis("off") 264 | ax.grid(False) 265 | for (idx, p_idx) in enumerate(p_idxs): 266 | ax.scatter( 267 | [tensors["player_xs"][0, p_idx]], 268 | [tensors["player_ys"][0, p_idx]], 269 | 
s=(pt_scale * matplotlib.rcParams["lines.markersize"]) ** 2, 270 | c=colors[idx], 271 | marker=markers[p_idx], 272 | edgecolors="black", 273 | ) 274 | plt.xlim(auto=False) 275 | plt.ylim(auto=False) 276 | ax.plot( 277 | tensors["player_xs"][:, p_idx], 278 | tensors["player_ys"][:, p_idx], 279 | c=colors[idx], 280 | ) 281 | 282 | plt.subplots_adjust(0, 0, 1, 1) 283 | fig.canvas.draw() 284 | traj_img = Image.frombytes( 285 | "RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb() 286 | ) 287 | plt.close() 288 | 289 | return (traj_imgs, traj_img) 290 | 291 | 292 | def gen_tgt_imgs_from_basketball_sample(tensors, seq_len, tgt_idx): 293 | court = plt.imread("court.png") 294 | width = 5 295 | height = width * court.shape[0] / float(court.shape[1]) 296 | pt_scale = 1.4 297 | (fig, ax) = plt.subplots(figsize=(width, height)) 298 | 299 | # Always give the agents the same colors. 300 | colors = [] 301 | markers = [] 302 | for p_idx in range(10): 303 | if p_idx == tgt_idx: 304 | colors.append("r") 305 | elif tensors["player_hoop_sides"][0, p_idx]: 306 | colors.append("white") 307 | else: 308 | colors.append("gray") 309 | 310 | if tensors["player_hoop_sides"][0, p_idx]: 311 | markers.append("s") 312 | else: 313 | markers.append("^") 314 | 315 | traj_imgs = [] 316 | for time_step in range(seq_len): 317 | ax.imshow(court, zorder=0, extent=[X_MIN, X_MAX - DIFF, Y_MAX, Y_MIN]) 318 | ax.axis("off") 319 | ax.grid(False) 320 | for p_idx in range(10): 321 | ax.scatter( 322 | [tensors["player_xs"][time_step, p_idx]], 323 | [tensors["player_ys"][time_step, p_idx]], 324 | s=(pt_scale * matplotlib.rcParams["lines.markersize"]) ** 2, 325 | c=colors[p_idx], 326 | marker=markers[p_idx], 327 | edgecolors="black", 328 | ) 329 | plt.xlim(auto=False) 330 | plt.ylim(auto=False) 331 | ax.plot( 332 | tensors["player_xs"][: time_step + 1, p_idx], 333 | tensors["player_ys"][: time_step + 1, p_idx], 334 | c=colors[p_idx], 335 | ) 336 | 337 | plt.subplots_adjust(0, 0, 1, 1) 338 | 
fig.canvas.draw() 339 | traj_img = Image.frombytes( 340 | "RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb() 341 | ) 342 | plt.cla() 343 | traj_imgs.append(traj_img) 344 | 345 | ax.imshow(court, zorder=0, extent=[X_MIN, X_MAX - DIFF, Y_MAX, Y_MIN]) 346 | ax.axis("off") 347 | ax.grid(False) 348 | for p_idx in range(10): 349 | ax.scatter( 350 | [tensors["player_xs"][0, p_idx]], 351 | [tensors["player_ys"][0, p_idx]], 352 | s=(pt_scale * matplotlib.rcParams["lines.markersize"]) ** 2, 353 | c=colors[p_idx], 354 | marker=markers[p_idx], 355 | edgecolors="black", 356 | ) 357 | plt.xlim(auto=False) 358 | plt.ylim(auto=False) 359 | ax.plot( 360 | tensors["player_xs"][:, p_idx], 361 | tensors["player_ys"][:, p_idx], 362 | c=colors[p_idx], 363 | ) 364 | 365 | plt.subplots_adjust(0, 0, 1, 1) 366 | fig.canvas.draw() 367 | traj_img = Image.frombytes( 368 | "RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb() 369 | ) 370 | plt.close() 371 | 372 | return (traj_imgs, traj_img) 373 | 374 | 375 | def plot_generated_trajectories(): 376 | models = {"baller2vec": "20210402111942", "baller2vec++": "20210402111956"} 377 | samples = 3 378 | device = torch.device("cuda:0") 379 | home_dir = os.path.expanduser("~") 380 | os.makedirs(f"{home_dir}/results", exist_ok=True) 381 | for (which_model, JOB) in models.items(): 382 | JOB_DIR = f"{EXPERIMENTS_DIR}/{JOB}" 383 | 384 | # Load model. 
385 | opts = yaml.safe_load(open(f"{JOB_DIR}/{JOB}.yaml")) 386 | (train_dataset, _, _, _, test_dataset, _) = init_basketball_datasets(opts) 387 | n_players = train_dataset.n_players 388 | player_traj_n = test_dataset.player_traj_n 389 | model = init_model(opts, train_dataset).to(device) 390 | model_dict = model.state_dict() 391 | pretrained_dict = torch.load(f"{JOB_DIR}/best_params.pth") 392 | pretrained_dict = { 393 | k: v for (k, v) in pretrained_dict.items() if k in model_dict 394 | } 395 | model_dict.update(pretrained_dict) 396 | model.load_state_dict(pretrained_dict) 397 | model.eval() 398 | seq_len = model.seq_len 399 | grid_gap = np.diff(test_dataset.player_traj_bins)[-1] / 2 400 | 401 | cand_test_idxs = [] 402 | for test_idx in range(len(test_dataset.gameids)): 403 | tensors = test_dataset[test_idx] 404 | if len(tensors["player_idxs"]) == model.seq_len: 405 | cand_test_idxs.append(test_idx) 406 | 407 | player_traj_bins = np.array(list(test_dataset.player_traj_bins) + [5.5]) 408 | 409 | torch.manual_seed(2010) 410 | np.random.seed(2010) 411 | 412 | test_idx = 97 413 | tensors = test_dataset[test_idx] 414 | 415 | (traj_imgs, traj_img) = gen_imgs_from_basketball_sample(tensors, seq_len) 416 | traj_imgs[0].save( 417 | f"{home_dir}/results/{test_idx}_truth.gif", 418 | save_all=True, 419 | append_images=traj_imgs[1:], 420 | duration=400, 421 | loop=0, 422 | ) 423 | traj_img.save( 424 | f"{home_dir}/results/{test_idx}_truth.png", 425 | ) 426 | for sample in range(samples): 427 | tensors = test_dataset[test_idx] 428 | with torch.no_grad(): 429 | for step in range(seq_len): 430 | preds_start = 10 * step 431 | for player_idx in range(n_players): 432 | pred_idx = preds_start + player_idx 433 | preds = model(tensors)[pred_idx] 434 | probs = torch.softmax(preds, dim=0) 435 | samp_traj = torch.multinomial(probs, 1) 436 | 437 | samp_row = samp_traj // player_traj_n 438 | samp_col = samp_traj % player_traj_n 439 | 440 | samp_x = ( 441 | player_traj_bins[samp_col] 442 | - 
grid_gap 443 | + np.random.uniform(-grid_gap, grid_gap) 444 | ) 445 | samp_y = ( 446 | player_traj_bins[samp_row] 447 | - grid_gap 448 | + np.random.uniform(-grid_gap, grid_gap) 449 | ) 450 | 451 | tensors["player_x_diffs"][step, player_idx] = samp_x 452 | tensors["player_y_diffs"][step, player_idx] = samp_y 453 | if step < seq_len - 1: 454 | tensors["player_xs"][step + 1, player_idx] = ( 455 | tensors["player_xs"][step, player_idx] + samp_x 456 | ) 457 | tensors["player_ys"][step + 1, player_idx] = ( 458 | tensors["player_ys"][step, player_idx] + samp_y 459 | ) 460 | 461 | (traj_imgs, traj_img) = gen_imgs_from_basketball_sample(tensors, seq_len) 462 | traj_imgs[0].save( 463 | f"{home_dir}/results/{test_idx}_gen_{which_model}_{sample}.gif", 464 | save_all=True, 465 | append_images=traj_imgs[1:], 466 | duration=400, 467 | loop=0, 468 | ) 469 | traj_img.save( 470 | f"{home_dir}/results/{test_idx}_gen_{which_model}_{sample}.png", 471 | ) 472 | 473 | shutil.make_archive(f"{home_dir}/results", "zip", f"{home_dir}/results") 474 | shutil.rmtree(f"{home_dir}/results") 475 | 476 | 477 | def compare_single_player_generation(): 478 | models = {"baller2vec": "20210402111942", "baller2vec++": "20210402111956"} 479 | samples = 10 480 | device = torch.device("cuda:0") 481 | home_dir = os.path.expanduser("~") 482 | os.makedirs(f"{home_dir}/results", exist_ok=True) 483 | for (which_model, JOB) in models.items(): 484 | print(which_model) 485 | JOB_DIR = f"{EXPERIMENTS_DIR}/{JOB}" 486 | 487 | # Load model. 
488 | opts = yaml.safe_load(open(f"{JOB_DIR}/{JOB}.yaml")) 489 | (train_dataset, _, _, _, test_dataset, _) = init_basketball_datasets(opts) 490 | n_players = train_dataset.n_players 491 | player_traj_n = test_dataset.player_traj_n 492 | model = init_model(opts, train_dataset).to(device) 493 | model_dict = model.state_dict() 494 | pretrained_dict = torch.load(f"{JOB_DIR}/best_params.pth") 495 | pretrained_dict = { 496 | k: v for (k, v) in pretrained_dict.items() if k in model_dict 497 | } 498 | model_dict.update(pretrained_dict) 499 | model.load_state_dict(pretrained_dict) 500 | model.eval() 501 | seq_len = model.seq_len 502 | grid_gap = np.diff(test_dataset.player_traj_bins)[-1] / 2 503 | 504 | cand_test_idxs = [] 505 | for test_idx in range(len(test_dataset.gameids)): 506 | tensors = test_dataset[test_idx] 507 | if len(tensors["player_idxs"]) == model.seq_len: 508 | cand_test_idxs.append(test_idx) 509 | 510 | player_traj_bins = np.array(list(test_dataset.player_traj_bins) + [5.5]) 511 | 512 | torch.manual_seed(2010) 513 | np.random.seed(2010) 514 | 515 | test_idx = 267 516 | for p_idx in range(n_players): 517 | # for p_idx in [0, 1]: 518 | print(p_idx, flush=True) 519 | swap_idxs = [p_idx, 9] 520 | tensors = test_dataset[test_idx] 521 | 522 | (traj_imgs, traj_img) = gen_tgt_imgs_from_basketball_sample( 523 | tensors, seq_len, p_idx 524 | ) 525 | traj_imgs[0].save( 526 | f"{home_dir}/results/{test_idx}_{p_idx}_truth.gif", 527 | save_all=True, 528 | append_images=traj_imgs[1:], 529 | duration=400, 530 | loop=0, 531 | ) 532 | traj_img.save( 533 | f"{home_dir}/results/{test_idx}_{p_idx}_truth.png", 534 | ) 535 | for sample in range(samples): 536 | tensors = test_dataset[test_idx] 537 | for key in shuffle_keys: 538 | tensors[key][:, swap_idxs] = tensors[key][:, swap_idxs[::-1]] 539 | 540 | with torch.no_grad(): 541 | for step in range(seq_len): 542 | preds_start = 10 * step 543 | pred_idx = preds_start + 9 544 | preds = model(tensors)[pred_idx] 545 | probs = 
torch.softmax(preds, dim=0) 546 | samp_traj = torch.multinomial(probs, 1) 547 | 548 | samp_row = samp_traj // player_traj_n 549 | samp_col = samp_traj % player_traj_n 550 | 551 | samp_x = ( 552 | player_traj_bins[samp_col] 553 | - grid_gap 554 | + np.random.uniform(-grid_gap, grid_gap) 555 | ) 556 | samp_y = ( 557 | player_traj_bins[samp_row] 558 | - grid_gap 559 | + np.random.uniform(-grid_gap, grid_gap) 560 | ) 561 | 562 | tensors["player_x_diffs"][step, 9] = samp_x 563 | tensors["player_y_diffs"][step, 9] = samp_y 564 | if step < seq_len - 1: 565 | tensors["player_xs"][step + 1, 9] = ( 566 | tensors["player_xs"][step, 9] + samp_x 567 | ) 568 | tensors["player_ys"][step + 1, 9] = ( 569 | tensors["player_ys"][step, 9] + samp_y 570 | ) 571 | 572 | (traj_imgs, traj_img) = gen_tgt_imgs_from_basketball_sample( 573 | tensors, seq_len, 9 574 | ) 575 | traj_imgs[0].save( 576 | f"{home_dir}/results/{test_idx}_{p_idx}_gen_{which_model}_{sample}.gif", 577 | save_all=True, 578 | append_images=traj_imgs[1:], 579 | duration=400, 580 | loop=0, 581 | ) 582 | traj_img.save( 583 | f"{home_dir}/results/{test_idx}_{p_idx}_gen_{which_model}_{sample}.png", 584 | ) 585 | 586 | shutil.make_archive(f"{home_dir}/results", "zip", f"{home_dir}/results") 587 | shutil.rmtree(f"{home_dir}/results") 588 | 589 | 590 | def compare_time_steps(): 591 | models = {"baller2vec": "20210402111942", "baller2vec++": "20210402111956"} 592 | results = {} 593 | device = torch.device("cuda:0") 594 | for (which_model, JOB) in models.items(): 595 | JOB_DIR = f"{EXPERIMENTS_DIR}/{JOB}" 596 | 597 | # Load model. 598 | opts = yaml.safe_load(open(f"{JOB_DIR}/{JOB}.yaml")) 599 | (train_dataset, _, _, _, test_dataset, test_loader) = init_basketball_datasets( 600 | opts 601 | ) 602 | n_players = train_dataset.n_players 603 | model = init_model(opts, train_dataset).to(device) 604 | # See: https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/3. 
605 | model_dict = model.state_dict() 606 | pretrained_dict = torch.load(f"{JOB_DIR}/best_params.pth") 607 | pretrained_dict = { 608 | k: v for (k, v) in pretrained_dict.items() if k in model_dict 609 | } 610 | model_dict.update(pretrained_dict) 611 | model.load_state_dict(pretrained_dict) 612 | model.eval() 613 | seq_len = model.seq_len 614 | 615 | criterion = nn.CrossEntropyLoss() 616 | 617 | test_loss_steps = np.zeros(seq_len) 618 | test_loss_all = 0.0 619 | n_test = 0 620 | with torch.no_grad(): 621 | for test_tensors in test_loader: 622 | # Skip bad sequences. 623 | if len(test_tensors["player_idxs"]) < seq_len: 624 | continue 625 | 626 | player_trajs = test_tensors["player_trajs"].flatten() 627 | labels = player_trajs.to(device) 628 | preds = model(test_tensors) 629 | 630 | test_loss_all += criterion(preds, labels).item() 631 | for time_step in range(seq_len): 632 | start = time_step * n_players 633 | stop = start + n_players 634 | test_loss_steps[time_step] += criterion( 635 | preds[start:stop], labels[start:stop] 636 | ).item() 637 | 638 | n_test += 1 639 | 640 | print(test_loss_steps / n_test) 641 | print(test_loss_all / n_test) 642 | results[which_model] = test_loss_steps / n_test 643 | 644 | print((results["baller2vec"] - results["baller2vec++"]) / results["baller2vec"]) 645 | 646 | 647 | def test_permutation_invariance(): 648 | JOB = "20210402111956" 649 | JOB_DIR = f"{EXPERIMENTS_DIR}/{JOB}" 650 | 651 | device = torch.device("cuda:0") 652 | 653 | # Load model. 654 | opts = yaml.safe_load(open(f"{JOB_DIR}/{JOB}.yaml")) 655 | (train_dataset, _, _, _, test_dataset, test_loader) = init_basketball_datasets(opts) 656 | n_players = train_dataset.n_players 657 | model = init_model(opts, train_dataset).to(device) 658 | # See: https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/3. 
659 | model_dict = model.state_dict() 660 | pretrained_dict = torch.load(f"{JOB_DIR}/best_params.pth") 661 | pretrained_dict = {k: v for (k, v) in pretrained_dict.items() if k in model_dict} 662 | model_dict.update(pretrained_dict) 663 | model.load_state_dict(pretrained_dict) 664 | model.eval() 665 | seq_len = model.seq_len 666 | 667 | criterion = nn.CrossEntropyLoss() 668 | 669 | shuffles = 10 670 | np.random.seed(2010) 671 | data = [] 672 | avg_diffs = [] 673 | with torch.no_grad(): 674 | for test_tensors in test_loader: 675 | # Skip bad sequences. 676 | if len(test_tensors["player_idxs"]) < seq_len: 677 | continue 678 | 679 | player_trajs = test_tensors["player_trajs"].flatten() 680 | labels = player_trajs.to(device) 681 | preds = model(test_tensors) 682 | 683 | test_loss = criterion(preds, labels).item() 684 | avg_diff = 0 685 | 686 | for shuffle in range(shuffles): 687 | shuffled_players = np.random.choice(np.arange(10), n_players, False) 688 | for key in shuffle_keys: 689 | test_tensors[key] = test_tensors[key][:, shuffled_players] 690 | 691 | player_trajs = test_tensors["player_trajs"].flatten() 692 | labels = player_trajs.to(device) 693 | preds = model(test_tensors) 694 | shuffle_loss = criterion(preds, labels).item() 695 | data.append({"anchor": test_loss, "shuffled": shuffle_loss}) 696 | avg_diff += np.abs(test_loss - shuffle_loss) 697 | 698 | avg_diff /= shuffles 699 | avg_diffs.append(avg_diff / test_loss) 700 | 701 | print(np.mean(avg_diffs)) 702 | df = pd.DataFrame(data) 703 | home_dir = os.path.expanduser("~") 704 | df.to_csv(f"{home_dir}/results.csv") 705 | print(np.abs(df["anchor"] - df["shuffled"]).mean()) 706 | 707 | 708 | def compare_player_beginning_end(): 709 | JOB = "20210402111956" 710 | JOB_DIR = f"{EXPERIMENTS_DIR}/{JOB}" 711 | 712 | device = torch.device("cuda:0") 713 | 714 | # Load model. 
715 | opts = yaml.safe_load(open(f"{JOB_DIR}/{JOB}.yaml")) 716 | (train_dataset, _, _, _, test_dataset, test_loader) = init_basketball_datasets(opts) 717 | n_players = train_dataset.n_players 718 | model = init_model(opts, train_dataset).to(device) 719 | # See: https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/3. 720 | model_dict = model.state_dict() 721 | pretrained_dict = torch.load(f"{JOB_DIR}/best_params.pth") 722 | pretrained_dict = {k: v for (k, v) in pretrained_dict.items() if k in model_dict} 723 | model_dict.update(pretrained_dict) 724 | model.load_state_dict(pretrained_dict) 725 | model.eval() 726 | seq_len = model.seq_len 727 | 728 | criterion = nn.CrossEntropyLoss() 729 | 730 | shuffles = 10 731 | shuffle_keys = [ 732 | "player_idxs", 733 | "player_xs", 734 | "player_ys", 735 | "player_hoop_sides", 736 | "player_x_diffs", 737 | "player_y_diffs", 738 | "player_trajs", 739 | ] 740 | np.random.seed(2010) 741 | avg_diffs = [] 742 | with torch.no_grad(): 743 | for (test_idx, test_tensors) in enumerate(test_loader): 744 | print(test_idx) 745 | # Skip bad sequences. 
746 | if len(test_tensors["player_idxs"]) < seq_len: 747 | continue 748 | 749 | for roll in range(n_players): 750 | for key in shuffle_keys: 751 | test_tensors[key] = test_tensors[key].roll(roll, 1) 752 | 753 | player_trajs = test_tensors["player_trajs"].flatten() 754 | labels = player_trajs.to(device) 755 | preds = model(test_tensors) 756 | test_loss = criterion(preds[:1], labels[:1]).item() 757 | # print(test_loss) 758 | 759 | avg_diff = 0 760 | for shuffle in range(shuffles): 761 | shuffled_players = list( 762 | np.random.choice(np.arange(1, 10), n_players - 1, False) 763 | ) + [0] 764 | shuffled_tensors = {} 765 | for key in shuffle_keys: 766 | shuffled_tensors[key] = test_tensors[key][:, shuffled_players] 767 | 768 | player_trajs = shuffled_tensors["player_trajs"].flatten() 769 | labels = player_trajs.to(device) 770 | preds = model(shuffled_tensors) 771 | shuffle_loss = criterion(preds[9:10], labels[9:10]).item() 772 | # print(shuffle_loss) 773 | avg_diff += test_loss - shuffle_loss 774 | 775 | # print() 776 | avg_diff /= shuffles 777 | avg_diffs.append(avg_diff / test_loss) 778 | # print(avg_diffs[-1]) 779 | 780 | print(np.mean(avg_diffs)) 781 | 782 | 783 | def get_diff_for_each_seq(): 784 | device = torch.device("cuda:0") 785 | models = {"baller2vec": "20210402111942", "baller2vec++": "20210402111956"} 786 | for (which_model, JOB) in models.items(): 787 | JOB_DIR = f"{EXPERIMENTS_DIR}/{JOB}" 788 | 789 | # Load model. 
790 | opts = yaml.safe_load(open(f"{JOB_DIR}/{JOB}.yaml")) 791 | (train_dataset, _, _, _, _, test_loader) = init_basketball_datasets(opts) 792 | model = init_model(opts, train_dataset).to(device) 793 | model_dict = model.state_dict() 794 | pretrained_dict = torch.load(f"{JOB_DIR}/best_params.pth") 795 | pretrained_dict = { 796 | k: v for (k, v) in pretrained_dict.items() if k in model_dict 797 | } 798 | model_dict.update(pretrained_dict) 799 | model.load_state_dict(pretrained_dict) 800 | model.eval() 801 | seq_len = model.seq_len 802 | models[which_model] = model 803 | 804 | home_dir = os.path.expanduser("~") 805 | os.makedirs(f"{home_dir}/results", exist_ok=True) 806 | 807 | criterion = nn.CrossEntropyLoss() 808 | 809 | with torch.no_grad(): 810 | for (test_idx, test_tensors) in enumerate(test_loader): 811 | print(test_idx, flush=True) 812 | 813 | # Skip bad sequences. 814 | if len(test_tensors["player_idxs"]) < seq_len: 815 | continue 816 | 817 | player_trajs = test_tensors["player_trajs"].flatten() 818 | labels = player_trajs.to(device) 819 | losses = {} 820 | for (which_model, model) in models.items(): 821 | preds = model(test_tensors) 822 | losses[which_model] = criterion(preds, labels).item() 823 | 824 | b2v_loss = losses["baller2vec"] 825 | b2vpp_loss = losses["baller2vec++"] 826 | loss_diff = str(int(100 * (b2v_loss - b2vpp_loss) / b2v_loss)) 827 | 828 | (traj_imgs, traj_img) = gen_imgs_from_basketball_sample( 829 | test_tensors, seq_len 830 | ) 831 | traj_img.save( 832 | f"{home_dir}/results/{loss_diff.zfill(3)}_{test_idx}.png", 833 | ) 834 | 835 | shutil.make_archive(f"{home_dir}/results", "zip", f"{home_dir}/results") 836 | shutil.rmtree(f"{home_dir}/results") 837 | -------------------------------------------------------------------------------- /paper_plots.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import pandas as pd 5 | 6 | 
from scipy.stats import pearsonr 7 | 8 | font = {"size": 22} 9 | 10 | matplotlib.rc("font", **font) 11 | 12 | 13 | def plot_nll_over_time(): 14 | # From: https://stackoverflow.com/a/59421062/1316276. 15 | # Numbers of pairs of bars you want. 16 | N = 20 17 | 18 | # Data on x-axis. 19 | 20 | # Specify the values of blue bars (height). 21 | blue_bar = [ 22 | 1.86897819, 23 | 0.58918497, 24 | 0.49199169, 25 | 0.47602542, 26 | 0.45367828, 27 | 0.43433007, 28 | 0.4377114, 29 | 0.43828331, 30 | 0.41888807, 31 | 0.42203086, 32 | 0.42575513, 33 | 0.41059803, 34 | 0.42874022, 35 | 0.41963211, 36 | 0.42825773, 37 | 0.42731653, 38 | 0.44892668, 39 | 0.44514571, 40 | 0.43499798, 41 | 0.46455422, 42 | ] 43 | # Specify the values of orange bars (height). 44 | orange_bar = [ 45 | 1.56743299, 46 | 0.54410327, 47 | 0.46949427, 48 | 0.45175594, 49 | 0.42762741, 50 | 0.41282431, 51 | 0.41227819, 52 | 0.41077594, 53 | 0.38996656, 54 | 0.39531277, 55 | 0.40007369, 56 | 0.37965269, 57 | 0.39489294, 58 | 0.3802205, 59 | 0.39562016, 60 | 0.39000066, 61 | 0.40920398, 62 | 0.39750678, 63 | 0.39637249, 64 | 0.41968062, 65 | ] 66 | 67 | # Position of bars on x-axis. 68 | ind = np.arange(N) 69 | 70 | # Figure size. 71 | plt.figure(figsize=(10, 5)) 72 | 73 | # Width of a bar. 74 | width = 0.3 75 | 76 | # Plotting. 77 | plt.bar(ind, blue_bar, width, label="baller2vec") 78 | plt.bar(ind + width, orange_bar, width, label="baller2vec++") 79 | 80 | plt.xlabel("Time step") 81 | plt.ylabel("Average NLL") 82 | 83 | # xticks(). 84 | # First argument: a list of positions at which ticks should be placed. 85 | # Second argument: a list of labels to place at the given locations. 86 | plt.xticks(ind + width / 2, list(range(1, N + 1))) 87 | 88 | # Finding the best position for legends and putting it there. 
89 | plt.legend(loc="best", prop={"family": "monospace"}) 90 | plt.show() 91 | 92 | 93 | def plot_permutations(): 94 | df = pd.read_csv("test.csv") 95 | print(pearsonr(df["anchor"], df["shuffled"])) 96 | plt.scatter(df["anchor"], df["shuffled"]) 97 | plt.xlabel("Unshuffled average NLL") 98 | plt.ylabel("Shuffled average NLL") 99 | plt.show() 100 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | py7zr 3 | torch 4 | scipy 5 | numpy 6 | pandas 7 | Pillow 8 | PyYAML 9 | -------------------------------------------------------------------------------- /settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | NORMALIZATION_COEF = 7 4 | PLAYER_CIRCLE_SIZE = 12 / NORMALIZATION_COEF 5 | INTERVAL = 10 6 | DIFF = 6 7 | X_MIN = 0 8 | X_MAX = 100 9 | Y_MIN = 0 10 | Y_MAX = 50 11 | COL_WIDTH = 0.3 12 | SCALE = 1.65 13 | FONTSIZE = 6 14 | X_CENTER = X_MAX / 2 - DIFF / 1.5 + 0.10 15 | Y_CENTER = Y_MAX - DIFF / 1.5 - 0.35 16 | BALL_COLOR = "#ff8c00" 17 | (COURT_WIDTH, COURT_LENGTH) = (50, 94) 18 | TEAM_ID2PROPS = { 19 | 1610612737: {"color": "#E13A3E", "abbreviation": "ATL"}, 20 | 1610612738: {"color": "#008348", "abbreviation": "BOS"}, 21 | 1610612751: {"color": "#061922", "abbreviation": "BKN"}, 22 | 1610612766: {"color": "#1D1160", "abbreviation": "CHA"}, 23 | 1610612741: {"color": "#CE1141", "abbreviation": "CHI"}, 24 | 1610612739: {"color": "#860038", "abbreviation": "CLE"}, 25 | 1610612742: {"color": "#007DC5", "abbreviation": "DAL"}, 26 | 1610612743: {"color": "#4D90CD", "abbreviation": "DEN"}, 27 | 1610612765: {"color": "#006BB6", "abbreviation": "DET"}, 28 | 1610612744: {"color": "#FDB927", "abbreviation": "GSW"}, 29 | 1610612745: {"color": "#CE1141", "abbreviation": "HOU"}, 30 | 1610612754: {"color": "#00275D", "abbreviation": "IND"}, 31 | 1610612746: {"color": 
"#ED174C", "abbreviation": "LAC"}, 32 | 1610612747: {"color": "#552582", "abbreviation": "LAL"}, 33 | 1610612763: {"color": "#0F586C", "abbreviation": "MEM"}, 34 | 1610612748: {"color": "#98002E", "abbreviation": "MIA"}, 35 | 1610612749: {"color": "#00471B", "abbreviation": "MIL"}, 36 | 1610612750: {"color": "#005083", "abbreviation": "MIN"}, 37 | 1610612740: {"color": "#002B5C", "abbreviation": "NOP"}, 38 | 1610612752: {"color": "#006BB6", "abbreviation": "NYK"}, 39 | 1610612760: {"color": "#007DC3", "abbreviation": "OKC"}, 40 | 1610612753: {"color": "#007DC5", "abbreviation": "ORL"}, 41 | 1610612755: {"color": "#006BB6", "abbreviation": "PHI"}, 42 | 1610612756: {"color": "#1D1160", "abbreviation": "PHX"}, 43 | 1610612757: {"color": "#E03A3E", "abbreviation": "POR"}, 44 | 1610612758: {"color": "#724C9F", "abbreviation": "SAC"}, 45 | 1610612759: {"color": "#BAC3C9", "abbreviation": "SAS"}, 46 | 1610612761: {"color": "#CE1141", "abbreviation": "TOR"}, 47 | 1610612762: {"color": "#00471B", "abbreviation": "UTA"}, 48 | 1610612764: {"color": "#002B5C", "abbreviation": "WAS"}, 49 | } 50 | EVENTS_DIR = os.environ["EVENTS_DIR"] 51 | TRACKING_DIR = os.environ["TRACKING_DIR"] 52 | GAMES_DIR = os.environ["GAMES_DIR"] 53 | DATA_DIR = os.environ["DATA_DIR"] 54 | EXPERIMENTS_DIR = os.environ["EXPERIMENTS_DIR"] 55 | -------------------------------------------------------------------------------- /test_gameids.txt: -------------------------------------------------------------------------------- 1 | 0021500637 2 | 0021500479 3 | 0021500167 4 | 0021500149 5 | 0021500456 6 | 0021500425 7 | 0021500460 8 | 0021500160 9 | 0021500399 10 | 0021500145 11 | 0021500626 12 | 0021500267 13 | 0021500340 14 | 0021500268 15 | 0021500278 16 | 0021500147 17 | 0021500163 18 | 0021500593 19 | 0021500084 20 | 0021500403 21 | 0021500061 22 | 0021500434 23 | 0021500567 24 | 0021500359 25 | 0021500079 26 | 0021500481 27 | 0021500427 28 | 0021500243 29 | 0021500245 30 | 0021500622 31 | 0021500330 32 | 
0021500522 33 | -------------------------------------------------------------------------------- /toy_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class ToyDataset(Dataset): 8 | def __init__(self): 9 | self.seq_len = 20 10 | self.n_players = 2 11 | self.n_player_ids = 2 12 | player_0_idxs = np.array(self.seq_len * [0]) 13 | player_1_idxs = np.array(self.seq_len * [1]) 14 | self.player_idxs = np.vstack([player_0_idxs, player_1_idxs]).T 15 | self.player_hoop_sides = np.zeros((self.seq_len, 2)) 16 | self.player_traj_n = 3 17 | 18 | def __len__(self): 19 | return 500 20 | 21 | def get_sample(self): 22 | # Sample coordinated trajectories. 23 | player_x_diffs = np.random.randint(-1, 2, self.seq_len) 24 | player_y_diffs = np.random.randint(-1, 2, self.seq_len) 25 | 26 | # Calculate total distance traveled in x and y directions. 27 | player_x_dists = np.cumsum(player_x_diffs)[None].T 28 | player_y_dists = np.cumsum(player_y_diffs)[None].T 29 | 30 | # Calculate x and y positions for agents. 31 | player_xs = np.zeros((self.seq_len, self.n_players)) 32 | # Agents start one unit to the left and right of the origin. 33 | player_xs[0] = [-1, 1] 34 | player_xs[1:] = player_xs[0] + player_x_dists[: self.seq_len - 1] 35 | player_ys = np.zeros((self.seq_len, self.n_players)) 36 | player_ys[1:] = player_y_dists[: self.seq_len - 1] 37 | 38 | # Agents start in random order from left to right. 39 | keep_players = np.random.choice(np.arange(2), self.n_players, False) 40 | player_idxs = self.player_idxs[:, keep_players].astype(int) 41 | 42 | # Convert trajectories to index labels. 
43 | player_traj_rows = 1 + player_y_diffs 44 | player_traj_cols = 1 + player_x_diffs 45 | player_trajs = player_traj_rows * self.player_traj_n + player_traj_cols 46 | player_trajs = np.vstack([player_trajs, player_trajs]).T 47 | 48 | player_x_diffs = np.vstack([player_x_diffs, player_x_diffs]).T 49 | player_y_diffs = np.vstack([player_y_diffs, player_y_diffs]).T 50 | 51 | return { 52 | "player_idxs": torch.LongTensor(player_idxs), 53 | "player_xs": torch.Tensor(player_xs), 54 | "player_ys": torch.Tensor(player_ys), 55 | "player_x_diffs": torch.Tensor(player_x_diffs), 56 | "player_y_diffs": torch.Tensor(player_y_diffs), 57 | "player_hoop_sides": torch.Tensor(self.player_hoop_sides), 58 | "player_trajs": torch.LongTensor(player_trajs), 59 | } 60 | 61 | def __getitem__(self, idx): 62 | return self.get_sample() 63 | -------------------------------------------------------------------------------- /train_baller2vecplusplus.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import sys 4 | import time 5 | import torch 6 | import yaml 7 | 8 | from baller2vecplusplus import Baller2VecPlusPlus 9 | from baller2vecplusplus_dataset import Baller2VecPlusPlusDataset 10 | from settings import * 11 | from torch import nn, optim 12 | from torch.utils.data import DataLoader 13 | from toy_dataset import ToyDataset 14 | 15 | SEED = 2010 16 | torch.manual_seed(SEED) 17 | torch.set_printoptions(linewidth=160) 18 | np.random.seed(SEED) 19 | 20 | 21 | def worker_init_fn(worker_id): 22 | # See: https://pytorch.org/docs/stable/notes/faq.html#my-data-loader-workers-return-identical-random-numbers 23 | # and: https://pytorch.org/docs/stable/data.html#multi-process-data-loading 24 | # and: https://pytorch.org/docs/stable/data.html#randomness-in-multi-process-data-loading. 25 | # NumPy seed takes a 32-bit unsigned integer. 
26 | np.random.seed(int(torch.utils.data.get_worker_info().seed) % (2 ** 32 - 1)) 27 | 28 | 29 | def get_train_valid_test_gameids(): 30 | with open("train_gameids.txt") as f: 31 | train_gameids = f.read().split() 32 | 33 | with open("valid_gameids.txt") as f: 34 | valid_gameids = f.read().split() 35 | 36 | with open("test_gameids.txt") as f: 37 | test_gameids = f.read().split() 38 | 39 | return (train_gameids, valid_gameids, test_gameids) 40 | 41 | 42 | def init_basketball_datasets(opts): 43 | baller2vec_config = pickle.load(open(f"{DATA_DIR}/baller2vec_config.pydict", "rb")) 44 | n_player_ids = len(baller2vec_config["player_idx2props"]) 45 | 46 | (train_gameids, valid_gameids, test_gameids) = get_train_valid_test_gameids() 47 | 48 | dataset_config = opts["dataset"] 49 | dataset_config["gameids"] = train_gameids 50 | dataset_config["N"] = opts["train"]["train_samples_per_epoch"] 51 | dataset_config["starts"] = [] 52 | dataset_config["mode"] = "train" 53 | dataset_config["n_player_ids"] = n_player_ids 54 | train_dataset = Baller2VecPlusPlusDataset(**dataset_config) 55 | train_loader = DataLoader( 56 | dataset=train_dataset, 57 | batch_size=None, 58 | num_workers=opts["train"]["workers"], 59 | worker_init_fn=worker_init_fn, 60 | ) 61 | 62 | N = opts["train"]["valid_samples"] 63 | samps_per_gameid = int(np.ceil(N / len(valid_gameids))) 64 | starts = [] 65 | for gameid in valid_gameids: 66 | y = np.load(f"{GAMES_DIR}/{gameid}_y.npy") 67 | max_start = len(y) - train_dataset.chunk_size 68 | gaps = max_start // samps_per_gameid 69 | starts.append(gaps * np.arange(samps_per_gameid)) 70 | 71 | dataset_config["gameids"] = np.repeat(valid_gameids, samps_per_gameid) 72 | dataset_config["N"] = len(dataset_config["gameids"]) 73 | dataset_config["starts"] = np.concatenate(starts) 74 | dataset_config["mode"] = "valid" 75 | valid_dataset = Baller2VecPlusPlusDataset(**dataset_config) 76 | valid_loader = DataLoader( 77 | dataset=valid_dataset, 78 | batch_size=None, 79 | 
num_workers=opts["train"]["workers"], 80 | ) 81 | 82 | samps_per_gameid = int(np.ceil(N / len(test_gameids))) 83 | starts = [] 84 | for gameid in test_gameids: 85 | y = np.load(f"{GAMES_DIR}/{gameid}_y.npy") 86 | max_start = len(y) - train_dataset.chunk_size 87 | gaps = max_start // samps_per_gameid 88 | starts.append(gaps * np.arange(samps_per_gameid)) 89 | 90 | dataset_config["gameids"] = np.repeat(test_gameids, samps_per_gameid) 91 | dataset_config["N"] = len(dataset_config["gameids"]) 92 | dataset_config["starts"] = np.concatenate(starts) 93 | dataset_config["mode"] = "test" 94 | test_dataset = Baller2VecPlusPlusDataset(**dataset_config) 95 | test_loader = DataLoader( 96 | dataset=test_dataset, 97 | batch_size=None, 98 | num_workers=opts["train"]["workers"], 99 | ) 100 | 101 | return ( 102 | train_dataset, 103 | train_loader, 104 | valid_dataset, 105 | valid_loader, 106 | test_dataset, 107 | test_loader, 108 | ) 109 | 110 | 111 | def init_toy_dataset(): # Build the toy training set and its DataLoader (batch_size=None disables automatic batching); reads the globals opts and worker_init_fn (the latter defined outside this view). 112 | train_dataset = ToyDataset() 113 | train_loader = DataLoader( 114 | dataset=train_dataset, 115 | batch_size=None, 116 | num_workers=opts["train"]["workers"], 117 | worker_init_fn=worker_init_fn, 118 | ) 119 | return (train_dataset, train_loader) 120 | 121 | 122 | def init_model(opts, train_dataset): # Build Baller2VecPlusPlus from opts["model"], filling in sizes read from train_dataset. 123 | model_config = opts["model"] # NOTE: aliases opts["model"], so the keys set below are also written into opts. 124 | # Add one for the generic player.
125 | model_config["n_player_ids"] = train_dataset.n_player_ids + 1 126 | model_config["seq_len"] = train_dataset.seq_len 127 | if opts["train"]["task"] == "basketball": # the model's seq_len is one step shorter than the dataset's for this task (reason not visible here). 128 | model_config["seq_len"] -= 1 129 | 130 | model_config["n_players"] = train_dataset.n_players 131 | model_config["n_player_labels"] = train_dataset.player_traj_n ** 2 # presumably a player_traj_n x player_traj_n grid of trajectory bins -- TODO confirm in the dataset code. 132 | model = Baller2VecPlusPlus(**model_config) 133 | 134 | return model 135 | 136 | 137 | def get_preds_labels(tensors): # Run the global model on one batch; return (preds, labels), labels being player_trajs flattened to 1-D and moved to the global device. 138 | preds = model(tensors) 139 | labels = tensors["player_trajs"].flatten().to(device) 140 | return (preds, labels) 141 | 142 | 143 | def train_model(): # Train the global model, resuming from JOB_DIR checkpoints when present; checkpoints on best validation loss (basketball) or best training loss (toy). 144 | # Initialize optimizer. 145 | train_params = [params for params in model.parameters()] 146 | optimizer = optim.Adam(train_params, lr=opts["train"]["learning_rate"]) 147 | criterion = nn.CrossEntropyLoss() 148 | 149 | # Continue training on a prematurely terminated model. 150 | try: 151 | model.load_state_dict(torch.load(f"{JOB_DIR}/best_params.pth")) 152 | 153 | try: 154 | state_dict = torch.load(f"{JOB_DIR}/optimizer.pth") 155 | if opts["train"]["learning_rate"] == state_dict["param_groups"][0]["lr"]: # restore optimizer state only if the configured LR equals the saved one. 156 | optimizer.load_state_dict(state_dict) 157 | 158 | except ValueError: 159 | print("Old optimizer doesn't match.") 160 | 161 | except FileNotFoundError: # no checkpoint: start fresh (if only optimizer.pth is missing, the parameters loaded above are kept). 162 | pass 163 | 164 | best_train_loss = float("inf") 165 | best_valid_loss = float("inf") 166 | test_loss_best_valid = float("inf") 167 | total_train_loss = None # stays None, and is printed as None, until the first training pass finishes. 168 | no_improvement = 0 169 | for epoch in range(1000000): # effectively unbounded; no early-stopping exit is visible in this function. 170 | print(f"\nepoch: {epoch}", flush=True) 171 | 172 | if task == "basketball": # validate (and test, when validation improves) before each training pass. 173 | model.eval() 174 | total_valid_loss = 0.0 175 | with torch.no_grad(): 176 | n_valid = 0 177 | for (valid_idx, valid_tensors) in enumerate(valid_loader): 178 | # Skip bad sequences.
179 | if len(valid_tensors["player_idxs"]) < seq_len: 180 | continue 181 | 182 | (preds, labels) = get_preds_labels(valid_tensors) 183 | loss = criterion(preds, labels) 184 | total_valid_loss += loss.item() 185 | n_valid += 1 186 | 187 | probs = torch.softmax(preds, dim=1) # NOTE(review): indentation is lost in this dump; presumably this displays the last validation batch after the loop -- confirm. 188 | (probs, preds) = probs.max(1) # max probability and argmax class per row. 189 | print(probs.view(seq_len, n_players), flush=True) 190 | print(preds.view(seq_len, n_players), flush=True) 191 | print(labels.view(seq_len, n_players), flush=True) 192 | 193 | total_valid_loss /= n_valid # NOTE(review): raises ZeroDivisionError if every validation sequence was skipped. 194 | 195 | if total_valid_loss < best_valid_loss: # new best: checkpoint model and optimizer, then score the test set. 196 | best_valid_loss = total_valid_loss 197 | no_improvement = 0 198 | torch.save(optimizer.state_dict(), f"{JOB_DIR}/optimizer.pth") 199 | torch.save(model.state_dict(), f"{JOB_DIR}/best_params.pth") 200 | 201 | test_loss_best_valid = 0.0 202 | with torch.no_grad(): 203 | n_test = 0 204 | for (test_idx, test_tensors) in enumerate(test_loader): 205 | # Skip bad sequences. 206 | if len(test_tensors["player_idxs"]) < seq_len: 207 | continue 208 | 209 | (preds, labels) = get_preds_labels(test_tensors) 210 | loss = criterion(preds, labels) 211 | test_loss_best_valid += loss.item() 212 | n_test += 1 213 | 214 | test_loss_best_valid /= n_test # NOTE(review): raises ZeroDivisionError if every test sequence was skipped. 215 | 216 | elif no_improvement < opts["train"]["patience"]: # count a non-improving epoch; the LR is cut 10x when the count first reaches patience, after which counting stops until an improvement resets it. 217 | no_improvement += 1 218 | if no_improvement == opts["train"]["patience"]: 219 | print("Reducing learning rate.") 220 | for g in optimizer.param_groups: 221 | g["lr"] *= 0.1 222 | 223 | print(f"total_train_loss: {total_train_loss}") 224 | print(f"best_train_loss: {best_train_loss}") 225 | if task == "basketball": 226 | print(f"total_valid_loss: {total_valid_loss}") 227 | print(f"best_valid_loss: {best_valid_loss}") 228 | print(f"test_loss_best_valid: {test_loss_best_valid}") 229 | 230 | model.train() 231 | total_train_loss = 0.0 232 | n_train = 0 233 | start_time = time.time() 234 | for (train_idx, train_tensors) in enumerate(train_loader): 235 | if train_idx % 1000 == 0: # progress print every 1000 loader items. 236 | print(train_idx, flush=True)
237 | 238 | # Skip bad sequences. 239 | if len(train_tensors["player_idxs"]) < seq_len: 240 | continue 241 | 242 | optimizer.zero_grad() 243 | (preds, labels) = get_preds_labels(train_tensors) 244 | loss = criterion(preds, labels) 245 | total_train_loss += loss.item() 246 | loss.backward() 247 | optimizer.step() 248 | n_train += 1 249 | 250 | epoch_time = time.time() - start_time 251 | 252 | total_train_loss /= n_train # NOTE(review): raises ZeroDivisionError if every training sequence was skipped. 253 | if total_train_loss < best_train_loss: 254 | best_train_loss = total_train_loss 255 | if task == "toy": # the toy task has no validation set, so it checkpoints on training loss. 256 | torch.save(optimizer.state_dict(), f"{JOB_DIR}/optimizer.pth") 257 | torch.save(model.state_dict(), f"{JOB_DIR}/best_params.pth") 258 | 259 | if task == "toy": 260 | train_tensors = train_dataset[0] # NOTE(review): the loop below re-draws index 0 until it is long enough; it only terminates if ToyDataset returns a fresh sample per call -- confirm. 261 | while len(train_tensors["player_idxs"]) < seq_len: 262 | train_tensors = train_dataset[0] 263 | 264 | (preds, labels) = get_preds_labels(train_tensors) # display only, yet runs in train mode with autograd enabled. 265 | probs = torch.softmax(preds, dim=1) 266 | (probs, preds) = probs.max(1) 267 | print(probs.view(seq_len, n_players), flush=True) 268 | print(preds.view(seq_len, n_players), flush=True) 269 | print(labels.view(seq_len, n_players), flush=True) 270 | 271 | print(f"epoch_time: {epoch_time:.2f}", flush=True) 272 | 273 | 274 | if __name__ == "__main__": 275 | JOB = sys.argv[1] # argv: JOB name, then an optional CUDA device id (defaults to "0"). 276 | JOB_DIR = f"{EXPERIMENTS_DIR}/{JOB}" 277 | 278 | try: 279 | os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[2] 280 | except IndexError: 281 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 282 | 283 | opts = yaml.safe_load(open(f"{JOB_DIR}/{JOB}.yaml")) # NOTE(review): the config file handle is never explicitly closed. 284 | task = opts["train"]["task"] 285 | 286 | # Initialize datasets. 287 | if task == "basketball": 288 | ( 289 | train_dataset, 290 | train_loader, 291 | valid_dataset, 292 | valid_loader, 293 | test_dataset, 294 | test_loader, 295 | ) = init_basketball_datasets(opts) 296 | elif task == "toy": # NOTE(review): any other task leaves train_dataset unset, so init_model below fails. 297 | (train_dataset, train_loader) = init_toy_dataset() 298 | 299 | # Initialize model.
300 | device = torch.device("cuda:0") # always cuda:0; CUDA_VISIBLE_DEVICES above selects which physical GPU that is. 301 | 302 | model = init_model(opts, train_dataset).to(device) 303 | print(model) 304 | n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 305 | print(f"Parameters: {n_params}") 306 | 307 | seq_len = model.seq_len # seq_len and n_players are module-level globals read by train_model(). 308 | n_players = model.n_players 309 | 310 | train_model() 311 | -------------------------------------------------------------------------------- /train_cropped.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/airalcorn2/baller2vecplusplus/e9753be4995300e4e2057ebd06f6ad96131f7ba9/train_cropped.gif -------------------------------------------------------------------------------- /train_gameids.txt: -------------------------------------------------------------------------------- 1 | 0021500257 2 | 0021500516 3 | 0021500632 4 | 0021500193 5 | 0021500170 6 | 0021500080 7 | 0021500465 8 | 0021500142 9 | 0021500053 10 | 0021500374 11 | 0021500619 12 | 0021500651 13 | 0021500576 14 | 0021500277 15 | 0021500215 16 | 0021500410 17 | 0021500121 18 | 0021500517 19 | 0021500413 20 | 0021500143 21 | 0021500166 22 | 0021500463 23 | 0021500469 24 | 0021500621 25 | 0021500056 26 | 0021500271 27 | 0021500623 28 | 0021500341 29 | 0021500263 30 | 0021500366 31 | 0021500377 32 | 0021500009 33 | 0021500414 34 | 0021500151 35 | 0021500159 36 | 0021500227 37 | 0021500448 38 | 0021500052 39 | 0021500451 40 | 0021500212 41 | 0021500549 42 | 0021500633 43 | 0021500085 44 | 0021500356 45 | 0021500237 46 | 0021500018 47 | 0021500625 48 | 0021500261 49 | 0021500449 50 | 0021500346 51 | 0021500624 52 | 0021500295 53 | 0021500508 54 | 0021500187 55 | 0021500655 56 | 0021500127 57 | 0021500339 58 | 0021500452 59 | 0021500096 60 | 0021500210 61 | 0021500335 62 | 0021500493 63 | 0021500049 64 | 0021500013 65 | 0021500015 66 | 0021500480 67 | 0021500474 68 | 0021500354 69 | 0021500140 70 | 0021500652 71 | 0021500062 72 | 0021500562 73 | 0021500293 74 |
0021500432 75 | 0021500126 76 | 0021500103 77 | 0021500433 78 | 0021500326 79 | 0021500218 80 | 0021500094 81 | 0021500397 82 | 0021500099 83 | 0021500040 84 | 0021500491 85 | 0021500109 86 | 0021500171 87 | 0021500577 88 | 0021500662 89 | 0021500279 90 | 0021500309 91 | 0021500068 92 | 0021500106 93 | 0021500524 94 | 0021500379 95 | 0021500554 96 | 0021500338 97 | 0021500305 98 | 0021500105 99 | 0021500555 100 | 0021500639 101 | 0021500031 102 | 0021500462 103 | 0021500136 104 | 0021500026 105 | 0021500417 106 | 0021500083 107 | 0021500331 108 | 0021500074 109 | 0021500511 110 | 0021500405 111 | 0021500563 112 | 0021500504 113 | 0021500205 114 | 0021500382 115 | 0021500059 116 | 0021500046 117 | 0021500334 118 | 0021500254 119 | 0021500086 120 | 0021500017 121 | 0021500296 122 | 0021500337 123 | 0021500542 124 | 0021500638 125 | 0021500547 126 | 0021500520 127 | 0021500032 128 | 0021500518 129 | 0021500422 130 | 0021500630 131 | 0021500400 132 | 0021500412 133 | 0021500484 134 | 0021500048 135 | 0021500436 136 | 0021500371 137 | 0021500450 138 | 0021500224 139 | 0021500022 140 | 0021500082 141 | 0021500146 142 | 0021500239 143 | 0021500369 144 | 0021500391 145 | 0021500027 146 | 0021500381 147 | 0021500201 148 | 0021500636 149 | 0021500290 150 | 0021500012 151 | 0021500423 152 | 0021500431 153 | 0021500365 154 | 0021500387 155 | 0021500199 156 | 0021500150 157 | 0021500073 158 | 0021500663 159 | 0021500598 160 | 0021500580 161 | 0021500325 162 | 0021500342 163 | 0021500585 164 | 0021500281 165 | 0021500467 166 | 0021500066 167 | 0021500559 168 | 0021500492 169 | 0021500063 170 | 0021500306 171 | 0021500550 172 | 0021500113 173 | 0021500631 174 | 0021500353 175 | 0021500394 176 | 0021500318 177 | 0021500148 178 | 0021500019 179 | 0021500320 180 | 0021500247 181 | 0021500489 182 | 0021500057 183 | 0021500540 184 | 0021500246 185 | 0021500192 186 | 0021500510 187 | 0021500204 188 | 0021500368 189 | 0021500444 190 | 0021500586 191 | 0021500002 192 | 0021500572 193 | 
0021500286 194 | 0021500538 195 | 0021500488 196 | 0021500496 197 | 0021500620 198 | 0021500288 199 | 0021500545 200 | 0021500095 201 | 0021500249 202 | 0021500584 203 | 0021500646 204 | 0021500116 205 | 0021500098 206 | 0021500348 207 | 0021500568 208 | 0021500307 209 | 0021500453 210 | 0021500408 211 | 0021500455 212 | 0021500260 213 | 0021500490 214 | 0021500599 215 | 0021500220 216 | 0021500601 217 | 0021500357 218 | 0021500274 219 | 0021500280 220 | 0021500176 221 | 0021500378 222 | 0021500273 223 | 0021500028 224 | 0021500447 225 | 0021500248 226 | 0021500471 227 | 0021500124 228 | 0021500200 229 | 0021500256 230 | 0021500115 231 | 0021500174 232 | 0021500457 233 | 0021500075 234 | 0021500658 235 | 0021500327 236 | 0021500534 237 | 0021500117 238 | 0021500389 239 | 0021500072 240 | 0021500380 241 | 0021500104 242 | 0021500011 243 | 0021500512 244 | 0021500195 245 | 0021500138 246 | 0021500556 247 | 0021500507 248 | 0021500553 249 | 0021500409 250 | 0021500437 251 | 0021500565 252 | 0021500442 253 | 0021500566 254 | 0021500446 255 | 0021500175 256 | 0021500198 257 | 0021500351 258 | 0021500616 259 | 0021500428 260 | 0021500537 261 | 0021500486 262 | 0021500370 263 | 0021500529 264 | 0021500575 265 | 0021500169 266 | 0021500531 267 | 0021500384 268 | 0021500398 269 | 0021500064 270 | 0021500158 271 | 0021500129 272 | 0021500153 273 | 0021500445 274 | 0021500564 275 | 0021500119 276 | 0021500345 277 | 0021500242 278 | 0021500071 279 | 0021500217 280 | 0021500500 281 | 0021500476 282 | 0021500358 283 | 0021500135 284 | 0021500482 285 | 0021500165 286 | 0021500472 287 | 0021500311 288 | 0021500373 289 | 0021500005 290 | 0021500232 291 | 0021500264 292 | 0021500214 293 | 0021500231 294 | 0021500528 295 | 0021500081 296 | 0021500191 297 | 0021500188 298 | 0021500592 299 | 0021500050 300 | 0021500582 301 | 0021500364 302 | 0021500054 303 | 0021500235 304 | 0021500660 305 | 0021500001 306 | 0021500429 307 | 0021500352 308 | 0021500078 309 | 0021500122 310 | 0021500509 
311 | 0021500313 312 | 0021500023 313 | 0021500016 314 | 0021500360 315 | 0021500494 316 | 0021500362 317 | 0021500321 318 | 0021500299 319 | 0021500551 320 | 0021500229 321 | 0021500441 322 | 0021500025 323 | 0021500310 324 | 0021500648 325 | 0021500207 326 | 0021500315 327 | 0021500222 328 | 0021500285 329 | 0021500595 330 | 0021500617 331 | 0021500502 332 | 0021500514 333 | 0021500038 334 | 0021500211 335 | 0021500527 336 | 0021500597 337 | 0021500291 338 | 0021500383 339 | 0021500255 340 | 0021500539 341 | 0021500141 342 | 0021500583 343 | 0021500250 344 | 0021500234 345 | 0021500202 346 | 0021500418 347 | 0021500579 348 | 0021500089 349 | 0021500178 350 | 0021500144 351 | 0021500594 352 | 0021500189 353 | 0021500478 354 | 0021500519 355 | 0021500385 356 | 0021500506 357 | 0021500653 358 | 0021500505 359 | 0021500101 360 | 0021500533 361 | 0021500180 362 | 0021500152 363 | 0021500308 364 | 0021500312 365 | 0021500355 366 | 0021500090 367 | 0021500393 368 | 0021500349 369 | 0021500498 370 | 0021500420 371 | 0021500177 372 | 0021500629 373 | 0021500501 374 | 0021500344 375 | 0021500544 376 | 0021500581 377 | 0021500047 378 | 0021500128 379 | 0021500569 380 | 0021500114 381 | 0021500503 382 | 0021500314 383 | 0021500155 384 | 0021500650 385 | 0021500004 386 | 0021500156 387 | 0021500470 388 | 0021500402 389 | 0021500317 390 | 0021500173 391 | 0021500618 392 | 0021500236 393 | 0021500020 394 | 0021500571 395 | 0021500375 396 | 0021500185 397 | 0021500110 398 | 0021500272 399 | 0021500645 400 | 0021500615 401 | 0021500303 402 | 0021500322 403 | 0021500197 404 | 0021500036 405 | 0021500526 406 | 0021500323 407 | 0021500363 408 | 0021500329 409 | 0021500419 410 | 0021500251 411 | 0021500561 412 | 0021500558 413 | 0021500426 414 | 0021500468 415 | 0021500181 416 | 0021500596 417 | 0021500573 418 | 0021500390 419 | 0021500244 420 | 0021500219 421 | 0021500600 422 | 0021500111 423 | 0021500411 424 | 0021500560 425 | 0021500336 426 | 0021500475 427 | 0021500233 428 | 
0021500515 429 | 0021500591 430 | 0021500058 431 | 0021500513 432 | 0021500208 433 | 0021500157 434 | 0021500270 435 | 0021500319 436 | 0021500186 437 | 0021500421 438 | 0021500034 439 | 0021500287 440 | 0021500301 441 | 0021500372 442 | 0021500483 443 | 0021500042 444 | 0021500024 445 | 0021500435 446 | 0021500259 447 | 0021500401 448 | 0021500164 449 | 0021500067 450 | 0021500535 451 | 0021500206 452 | 0021500557 453 | 0021500541 454 | 0021500213 455 | 0021500304 456 | 0021500350 457 | 0021500137 458 | 0021500657 459 | 0021500347 460 | 0021500172 461 | 0021500035 462 | 0021500184 463 | 0021500464 464 | 0021500196 465 | 0021500131 466 | 0021500439 467 | 0021500635 468 | 0021500262 469 | 0021500241 470 | 0021500485 471 | 0021500440 472 | 0021500182 473 | 0021500458 474 | 0021500060 475 | 0021500376 476 | 0021500037 477 | 0021500168 478 | 0021500275 479 | 0021500087 480 | 0021500183 481 | 0021500044 482 | 0021500092 483 | 0021500546 484 | 0021500396 485 | 0021500395 486 | 0021500487 487 | 0021500627 488 | 0021500077 489 | 0021500548 490 | 0021500300 491 | 0021500284 492 | 0021500216 493 | 0021500316 494 | 0021500404 495 | 0021500343 496 | 0021500125 497 | 0021500093 498 | 0021500258 499 | 0021500552 500 | 0021500328 501 | 0021500161 502 | 0021500223 503 | 0021500194 504 | 0021500097 505 | 0021500386 506 | 0021500041 507 | 0021500430 508 | 0021500102 509 | 0021500070 510 | 0021500179 511 | 0021500406 512 | 0021500130 513 | 0021500473 514 | 0021500107 515 | 0021500525 516 | 0021500407 517 | 0021500091 518 | 0021500139 519 | 0021500289 520 | 0021500499 521 | 0021500221 522 | 0021500530 523 | 0021500230 524 | 0021500051 525 | 0021500228 526 | 0021500021 527 | 0021500332 528 | 0021500162 529 | 0021500133 530 | 0021500570 531 | 0021500225 532 | 0021500367 533 | 0021500108 534 | 0021500424 535 | 0021500297 536 | 0021500523 537 | 0021500298 538 | 0021500112 539 | 0021500649 540 | 0021500203 541 | 0021500252 542 | 0021500292 543 | 0021500118 544 | 0021500209 545 | 0021500361 
546 | 0021500154 547 | 0021500647 548 | 0021500532 549 | 0021500190 550 | 0021500642 551 | 0021500521 552 | 0021500388 553 | 0021500543 554 | 0021500045 555 | 0021500120 556 | 0021500415 557 | 0021500269 558 | 0021500088 559 | 0021500226 560 | 0021500003 561 | 0021500076 562 | 0021500123 563 | 0021500477 564 | 0021500438 565 | 0021500033 566 | 0021500392 567 | 0021500443 568 | 0021500578 569 | 0021500634 570 | -------------------------------------------------------------------------------- /valid_gameids.txt: -------------------------------------------------------------------------------- 1 | 0021500266 2 | 0021500466 3 | 0021500010 4 | 0021500132 5 | 0021500283 6 | 0021500043 7 | 0021500069 8 | 0021500029 9 | 0021500007 10 | 0021500055 11 | 0021500461 12 | 0021500265 13 | 0021500039 14 | 0021500134 15 | 0021500302 16 | 0021500454 17 | 0021500294 18 | 0021500628 19 | 0021500324 20 | 0021500333 21 | 0021500641 22 | 0021500253 23 | 0021500416 24 | 0021500495 25 | 0021500661 26 | 0021500497 27 | 0021500459 28 | 0021500536 29 | 0021500065 30 | 0021500574 31 | --------------------------------------------------------------------------------