├── .basketball_profile
├── .gitignore
├── 20210408160343_cropped.gif
├── 20210408161424_cropped.gif
├── 267_3_gen_baller2vec++_7_cropped.gif
├── 267_3_gen_baller2vec++_8_cropped.gif
├── 267_3_gen_baller2vec++_9_cropped.gif
├── 267_3_gen_baller2vec_0_cropped.gif
├── 267_3_gen_baller2vec_1_cropped.gif
├── 267_3_gen_baller2vec_7_cropped.gif
├── 267_3_truth_cropped.gif
├── LICENSE
├── README.md
├── baller2vec++.svg
├── baller2vecplusplus.py
├── baller2vecplusplus_dataset.py
├── court.png
├── events.zip
├── generate_game_numpy_arrays.py
├── gif_cropper.sh
├── gif_merger.sh
├── image_cropper.sh
├── paper_functions.py
├── paper_plots.py
├── requirements.txt
├── settings.py
├── test_gameids.txt
├── toy_dataset.py
├── train_baller2vecplusplus.py
├── train_cropped.gif
├── train_gameids.txt
└── valid_gameids.txt
/.basketball_profile:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Code directory.
4 | export PROJECT_DIR=/home/michael/baller2vec
5 | export DATA_DIR=/mnt/raid/michael/baller2vec_data
6 | export EVENTS_DIR=${DATA_DIR}/events
7 | export TRACKING_DIR=${DATA_DIR}/NBA-Player-Movements/data/2016.NBA.Raw.SportVU.Game.Logs
8 | export GAMES_DIR=${DATA_DIR}/games
9 | export EXPERIMENTS_DIR=${DATA_DIR}/experiments
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | __pycache__/
3 |
--------------------------------------------------------------------------------
/20210408160343_cropped.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/airalcorn2/baller2vecplusplus/e9753be4995300e4e2057ebd06f6ad96131f7ba9/20210408160343_cropped.gif
--------------------------------------------------------------------------------
/20210408161424_cropped.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/airalcorn2/baller2vecplusplus/e9753be4995300e4e2057ebd06f6ad96131f7ba9/20210408161424_cropped.gif
--------------------------------------------------------------------------------
/267_3_gen_baller2vec++_7_cropped.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/airalcorn2/baller2vecplusplus/e9753be4995300e4e2057ebd06f6ad96131f7ba9/267_3_gen_baller2vec++_7_cropped.gif
--------------------------------------------------------------------------------
/267_3_gen_baller2vec++_8_cropped.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/airalcorn2/baller2vecplusplus/e9753be4995300e4e2057ebd06f6ad96131f7ba9/267_3_gen_baller2vec++_8_cropped.gif
--------------------------------------------------------------------------------
/267_3_gen_baller2vec++_9_cropped.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/airalcorn2/baller2vecplusplus/e9753be4995300e4e2057ebd06f6ad96131f7ba9/267_3_gen_baller2vec++_9_cropped.gif
--------------------------------------------------------------------------------
/267_3_gen_baller2vec_0_cropped.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/airalcorn2/baller2vecplusplus/e9753be4995300e4e2057ebd06f6ad96131f7ba9/267_3_gen_baller2vec_0_cropped.gif
--------------------------------------------------------------------------------
/267_3_gen_baller2vec_1_cropped.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/airalcorn2/baller2vecplusplus/e9753be4995300e4e2057ebd06f6ad96131f7ba9/267_3_gen_baller2vec_1_cropped.gif
--------------------------------------------------------------------------------
/267_3_gen_baller2vec_7_cropped.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/airalcorn2/baller2vecplusplus/e9753be4995300e4e2057ebd06f6ad96131f7ba9/267_3_gen_baller2vec_7_cropped.gif
--------------------------------------------------------------------------------
/267_3_truth_cropped.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/airalcorn2/baller2vecplusplus/e9753be4995300e4e2057ebd06f6ad96131f7ba9/267_3_truth_cropped.gif
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2021 Michael A. Alcorn
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 |
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 |
7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # baller2vec++
2 |
3 | This is the repository for the paper:
4 |
5 | >[Michael A. Alcorn](https://sites.google.com/view/michaelaalcorn) and [Anh Nguyen](http://anhnguyen.me). [`baller2vec++`: A Look-Ahead Multi-Entity Transformer For Modeling Coordinated Agents](https://arxiv.org/abs/2104.11980). arXiv. 2021.
6 |
7 | | <img src="baller2vec++.svg"> |
8 | |:--|
9 | | To learn statistically dependent agent trajectories, `baller2vec++` uses a specially designed self-attention mask to simultaneously process three different sets of feature vectors in a single Transformer. The three sets of feature vectors consist of location feature vectors like those found in `baller2vec`, look-ahead trajectory feature vectors, and starting location feature vectors. This design allows the model to integrate information about *concurrent* agent trajectories through *multiple* Transformer layers without seeing the future (in contrast to `baller2vec`). |
10 |
11 | | <img src="train_cropped.gif"> | | |
12 | |:--:|:--:|:--:|
13 | | Training sample | `baller2vec` | `baller2vec++` |
14 |
15 | When trained on a dataset of perfectly coordinated agent trajectories, the trajectories generated by `baller2vec` are completely *uncoordinated*, while the trajectories generated by `baller2vec++` are perfectly coordinated.
16 |
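17 | | <img src="267_3_truth_cropped.gif"> | <img src="267_3_gen_baller2vec_0_cropped.gif"> | <img src="267_3_gen_baller2vec_1_cropped.gif"> | <img src="267_3_gen_baller2vec_7_cropped.gif"> |
18 | |:--:|:--:|:--:|:--:|
19 | | Ground truth | `baller2vec` | `baller2vec` | `baller2vec` |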
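20 | | <img src="267_3_truth_cropped.gif"> | <img src="267_3_gen_baller2vec++_7_cropped.gif"> | <img src="267_3_gen_baller2vec++_8_cropped.gif"> | <img src="267_3_gen_baller2vec++_9_cropped.gif"> |
21 | | Ground truth | `baller2vec++` | `baller2vec++` | `baller2vec++` |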
22 |
23 | While `baller2vec` occasionally generates realistic trajectories for the red defender, it also makes egregious errors.
24 | In contrast, the trajectories generated by `baller2vec++` often seem plausible.
25 | The red player was placed *last* in the player order when generating his trajectory with `baller2vec++`.
26 |
27 | ## Citation
28 |
29 | If you use this code for your own research, please cite:
30 |
31 | ```
32 | @article{alcorn2021baller2vec,
33 | title={\texttt{baller2vec++}: A Look-Ahead Multi-Entity Transformer For Modeling Coordinated Agents},
34 | author={Alcorn, Michael A. and Nguyen, Anh},
35 | journal={arXiv preprint arXiv:2104.11980},
36 | year={2021}
37 | }
38 | ```
39 |
40 | ## Training baller2vec++
41 |
42 | ### Setting up `.basketball_profile`
43 |
44 | After you've cloned the repository to your desired location, create a file called `.basketball_profile` in your home directory:
45 |
46 | ```bash
47 | nano ~/.basketball_profile
48 | ```
49 |
50 | and copy and paste in the contents of [`.basketball_profile`](.basketball_profile), replacing each of the variable values with paths relevant to your environment.
51 | Next, add the following line to the end of your `~/.bashrc`:
52 |
53 | ```bash
54 | source ~/.basketball_profile
55 | ```
56 |
57 | and either log out and log back in again or run:
58 |
59 | ```bash
60 | source ~/.bashrc
61 | ```
62 |
63 | You should now be able to copy and paste all of the commands in the various instructions sections.
64 | For example:
65 |
66 | ```bash
67 | echo ${PROJECT_DIR}
68 | ```
69 |
70 | should print the path you set for `PROJECT_DIR` in `.basketball_profile`.
71 |
72 | ### Installing the necessary Python packages
73 |
74 | ```bash
75 | cd ${PROJECT_DIR}
76 | pip3 install --upgrade -r requirements.txt
77 | ```
78 |
79 | ### Organizing the play-by-play and tracking data
80 |
81 | 1) Copy `events.zip` (which I acquired from [here](https://github.com/sealneaward/nba-movement-data/tree/master/data/events) \[mirror [here](https://github.com/airalcorn2/nba-movement-data/tree/master/data/events)\] using https://downgit.github.io) to the `DATA_DIR` directory and unzip it:
82 |
83 | ```bash
84 | mkdir -p ${DATA_DIR}
85 | cp ${PROJECT_DIR}/events.zip ${DATA_DIR}
86 | cd ${DATA_DIR}
87 | unzip -q events.zip
88 | rm events.zip
89 | ```
90 |
91 | Descriptions for the various `EVENTMSGTYPE`s can be found [here](https://github.com/rd11490/NBA_Tutorials/tree/master/analyze_play_by_play) (mirror [here](https://github.com/airalcorn2/NBA_Tutorials/tree/master/analyze_play_by_play)).
92 |
93 | 2) Clone the tracking data from [here](https://github.com/linouk23/NBA-Player-Movements) (mirror [here](https://github.com/airalcorn2/NBA-Player-Movements)) to the `DATA_DIR` directory:
94 |
95 | ```bash
96 | cd ${DATA_DIR}
97 | git clone git@github.com:linouk23/NBA-Player-Movements.git
98 | ```
99 |
100 | A description of the tracking data can be found [here](https://danvatterott.com/blog/2016/06/16/creating-videos-of-nba-action-with-sportsvu-data/).
101 |
102 | ### Generating the training data
103 |
104 | ```bash
105 | cd ${PROJECT_DIR}
106 | nohup python3 generate_game_numpy_arrays.py > data.log &
107 | ```
108 |
109 | You can monitor its progress with:
110 |
111 | ```bash
112 | top
113 | ```
114 |
115 | or:
116 |
117 | ```bash
118 | ls -U ${GAMES_DIR} | wc -l
119 | ```
120 |
121 | There should be 1,262 NumPy arrays (corresponding to 631 X/y pairs) when finished.
122 |
123 | ### Running the training script
124 |
125 | Run (or copy and paste) the following script, editing the variables as appropriate.
126 |
127 | ```bash
128 | #!/usr/bin/env bash
129 |
130 | JOB=$(date +%Y%m%d%H%M%S)
131 |
132 | echo "train:" >> ${JOB}.yaml
133 | task=basketball # "basketball" or "toy".
134 | echo " task: ${task}" >> ${JOB}.yaml
135 | if [[ "$task" = "basketball" ]]
136 | then
137 |
138 | echo " train_samples_per_epoch: 20000" >> ${JOB}.yaml
139 | echo " valid_samples: 1000" >> ${JOB}.yaml
140 | echo " workers: 10" >> ${JOB}.yaml
141 | echo " learning_rate: 1.0e-5" >> ${JOB}.yaml
142 | echo " patience: 20" >> ${JOB}.yaml
143 |
144 | echo "dataset:" >> ${JOB}.yaml
145 | echo " hz: 5" >> ${JOB}.yaml
146 | echo " secs: 4.2" >> ${JOB}.yaml
147 | echo " player_traj_n: 11" >> ${JOB}.yaml
148 | echo " max_player_move: 4.5" >> ${JOB}.yaml
149 |
150 | echo "model:" >> ${JOB}.yaml
151 | echo " embedding_dim: 20" >> ${JOB}.yaml
152 | echo " sigmoid: none" >> ${JOB}.yaml
153 | echo " mlp_layers: [128, 256, 512]" >> ${JOB}.yaml
154 | echo " nhead: 8" >> ${JOB}.yaml
155 | echo " dim_feedforward: 2048" >> ${JOB}.yaml
156 | echo " num_layers: 6" >> ${JOB}.yaml
157 | echo " dropout: 0.0" >> ${JOB}.yaml
158 | echo " b2v: False" >> ${JOB}.yaml
159 |
160 | else
161 |
162 | echo " workers: 10" >> ${JOB}.yaml
163 | echo " learning_rate: 1.0e-4" >> ${JOB}.yaml
164 |
165 | echo "model:" >> ${JOB}.yaml
166 | echo " embedding_dim: 20" >> ${JOB}.yaml
167 | echo " sigmoid: none" >> ${JOB}.yaml
168 | echo " mlp_layers: [64, 128]" >> ${JOB}.yaml
169 | echo " nhead: 4" >> ${JOB}.yaml
170 | echo " dim_feedforward: 512" >> ${JOB}.yaml
171 | echo " num_layers: 2" >> ${JOB}.yaml
172 | echo " dropout: 0.0" >> ${JOB}.yaml
173 | echo " b2v: True" >> ${JOB}.yaml
174 |
175 | fi
176 |
177 | # Save experiment settings.
178 | mkdir -p ${EXPERIMENTS_DIR}/${JOB}
179 | mv ${JOB}.yaml ${EXPERIMENTS_DIR}/${JOB}/
180 |
181 | gpu=0
182 | cd ${PROJECT_DIR}
183 | nohup python3 train_baller2vecplusplus.py ${JOB} ${gpu} > ${EXPERIMENTS_DIR}/${JOB}/train.log &
184 | ```
185 |
--------------------------------------------------------------------------------
/baller2vecplusplus.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 | from torch import nn
5 |
6 |
class Baller2VecPlusPlus(nn.Module):
    """Multi-entity Transformer that predicts a trajectory class per player per step.

    With ``b2v`` set, the model is configured like baller2vec. Only location feature
    vectors enter the Transformer, under a block-causal mask: every entry at time
    step t may attend to all entries at steps <= t.

    Otherwise (baller2vec++), location feature vectors and look-ahead trajectory
    feature vectors are interleaved and causally masked. ``n_players``
    starting-location feature vectors are appended at the end, and every position
    may attend to them.

    Args:
        n_player_ids: number of distinct player indices (embedding table size).
        embedding_dim: size of each player embedding.
        sigmoid: ``"logistic"`` or ``"tanh"`` squashes the player embeddings; any
            other value leaves them unchanged.
        seq_len: number of time steps the attention masks are built for.
        mlp_layers: output sizes of the feature MLP layers. The last size is the
            Transformer's ``d_model``.
        n_players: number of players per time step.
        n_player_labels: number of output classes of the trajectory classifier.
        nhead, dim_feedforward, num_layers, dropout: ``nn.TransformerEncoder``
            hyperparameters.
        b2v: use the baller2vec configuration instead of baller2vec++.
    """

    def __init__(
        self,
        n_player_ids,
        embedding_dim,
        sigmoid,
        seq_len,
        mlp_layers,
        n_players,
        n_player_labels,
        nhead,
        dim_feedforward,
        num_layers,
        dropout,
        b2v,
    ):
        super().__init__()
        self.sigmoid = sigmoid
        self.seq_len = seq_len
        self.n_players = n_players
        self.b2v = b2v

        initrange = 0.1
        self.player_embedding = nn.Embedding(n_player_ids, embedding_dim)
        self.player_embedding.weight.data.uniform_(-initrange, initrange)

        # Three MLPs of identical shape: starting locations and per-step locations
        # share the (embedding, x, y, hoop side) input; trajectories add two diffs.
        # Layers are created in interleaved order so parameter initialization
        # consumes the RNG exactly as before.
        start_mlp = nn.Sequential()
        pos_mlp = nn.Sequential()
        traj_mlp = nn.Sequential()
        pos_in_feats = embedding_dim + 3
        traj_in_feats = embedding_dim + 5
        for (layer_idx, out_feats) in enumerate(mlp_layers):
            start_mlp.add_module(
                f"layer{layer_idx}", nn.Linear(pos_in_feats, out_feats)
            )
            pos_mlp.add_module(f"layer{layer_idx}", nn.Linear(pos_in_feats, out_feats))
            traj_mlp.add_module(
                f"layer{layer_idx}", nn.Linear(traj_in_feats, out_feats)
            )
            # No activation after the final layer.
            if layer_idx < len(mlp_layers) - 1:
                start_mlp.add_module(f"relu{layer_idx}", nn.ReLU())
                pos_mlp.add_module(f"relu{layer_idx}", nn.ReLU())
                traj_mlp.add_module(f"relu{layer_idx}", nn.ReLU())

            pos_in_feats = out_feats
            traj_in_feats = out_feats

        self.start_mlp = start_mlp
        self.pos_mlp = pos_mlp
        self.traj_mlp = traj_mlp

        d_model = mlp_layers[-1]
        self.d_model = d_model
        # Buffers move with the module across devices but are not trained.
        if b2v:
            self.register_buffer("b2v_mask", self.generate_b2v_mask())
        else:
            self.register_buffer("b2vpp_mask", self.generate_b2vpp_mask())

        encoder_layer = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

        self.traj_classifier = nn.Linear(d_model, n_player_labels)
        self.traj_classifier.weight.data.uniform_(-initrange, initrange)
        self.traj_classifier.bias.data.zero_()

    @staticmethod
    def _to_additive_mask(allowed):
        """Convert a boolean "may attend" matrix into a float attention mask.

        Allowed entries become 0.0 and disallowed entries become -inf.
        """
        return torch.zeros(allowed.shape).masked_fill(~allowed, float("-inf"))

    def generate_b2v_mask(self):
        """Build the block-causal baller2vec mask of size (seq_len * n_players)^2.

        The sequence holds ``n_players`` consecutive entries per time step. An entry
        at step t may attend to every entry at steps <= t, including the other
        players at the same step.
        """
        # Time step of every position in the flattened sequence.
        steps = torch.arange(self.seq_len).repeat_interleave(self.n_players)
        # allowed[r, c] is True when column c's step is not after row r's step.
        allowed = steps.unsqueeze(0) <= steps.unsqueeze(1)
        return self._to_additive_mask(allowed)

    def generate_b2vpp_mask(self):
        """Build the baller2vec++ attention mask.

        The first ``2 * seq_len * n_players`` positions hold the interleaved location
        and trajectory feature vectors. They are causally masked: position i may
        attend to positions <= i.

        The final ``n_players`` positions hold the starting-location feature vectors.
        Every position may attend to them, and they attend only to each other.
        """
        n_players = self.n_players
        tri_sz = 2 * (self.seq_len * n_players)
        sz = tri_sz + n_players
        allowed = torch.zeros(sz, sz, dtype=torch.bool)
        allowed[:tri_sz, :tri_sz] = torch.tril(torch.ones(tri_sz, tri_sz)).bool()
        allowed[:, -n_players:] = True
        return self._to_additive_mask(allowed)

    def forward(self, tensors):
        """Predict trajectory-class scores for every player entry.

        Args:
            tensors: dict with ``player_idxs``, ``player_xs``, ``player_ys`` and
                ``player_hoop_sides``. When ``b2v`` is False it must also contain
                ``player_x_diffs`` and ``player_y_diffs``. Each tensor is flattened,
                so all must hold the same number of elements. The registered mask
                assumes that number is ``seq_len * n_players`` in time-step-major
                order (TODO confirm against the training loop).

        Returns:
            Tensor of shape (number of player entries, n_player_labels) holding the
            trajectory classifier's raw (un-normalized) scores.
        """
        device = self.player_embedding.weight.device

        player_embeddings = self.player_embedding(
            tensors["player_idxs"].flatten().to(device)
        )
        if self.sigmoid == "logistic":
            player_embeddings = torch.sigmoid(player_embeddings)
        elif self.sigmoid == "tanh":
            player_embeddings = torch.tanh(player_embeddings)

        player_xs = tensors["player_xs"].flatten().unsqueeze(1).to(device)
        player_ys = tensors["player_ys"].flatten().unsqueeze(1).to(device)
        player_hoop_sides = (
            tensors["player_hoop_sides"].flatten().unsqueeze(1).to(device)
        )
        player_pos = torch.cat(
            [player_embeddings, player_xs, player_ys, player_hoop_sides], dim=1
        )
        # MLP outputs are scaled by sqrt(d_model) before entering the Transformer.
        pos_feats = self.pos_mlp(player_pos) * math.sqrt(self.d_model)
        if self.b2v:
            outputs = self.transformer(pos_feats.unsqueeze(1), self.b2v_mask)
            return self.traj_classifier(outputs.squeeze(1))

        n_players = self.n_players
        # The first n_players entries are the first time step's locations.
        start_feats = self.start_mlp(player_pos[:n_players]) * math.sqrt(
            self.d_model
        )

        player_x_diffs = tensors["player_x_diffs"].flatten().unsqueeze(1).to(device)
        player_y_diffs = tensors["player_y_diffs"].flatten().unsqueeze(1).to(device)
        # Look-ahead trajectory features: the next location plus the move to it.
        player_trajs = torch.cat(
            [
                player_embeddings,
                player_xs + player_x_diffs,
                player_ys + player_y_diffs,
                player_hoop_sides,
                player_x_diffs,
                player_y_diffs,
            ],
            dim=1,
        )
        trajs_feats = self.traj_mlp(player_trajs) * math.sqrt(self.d_model)

        # Layout: [pos_0, traj_0, pos_1, traj_1, ..., start_0, ..., start_{n-1}].
        # Allocate directly on the target device to avoid a per-call host copy.
        combined = torch.zeros(
            2 * len(pos_feats) + n_players, self.d_model, device=device
        )
        combined[:-n_players:2] = pos_feats
        combined[1:-n_players:2] = trajs_feats
        combined[-n_players:] = start_feats

        outputs = self.transformer(combined.unsqueeze(1), self.b2vpp_mask)
        # Predictions are read only from the location positions.
        return self.traj_classifier(outputs.squeeze(1)[:-n_players:2])
160 |
--------------------------------------------------------------------------------
/baller2vecplusplus_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from settings import COURT_LENGTH, COURT_WIDTH, GAMES_DIR
5 | from torch.utils.data import Dataset
6 |
# The raw data is recorded at 25 Hz.
RAW_DATA_HZ = 25


class Baller2VecPlusPlusDataset(Dataset):
    """Samples fixed-length, downsampled windows of player tracking data.

    Each game is loaded from ``{GAMES_DIR}/{gameid}_X.npy``, whose rows are raw
    frames. For the 10 players, columns 10:20 hold player indices, 20:30 x
    positions, 30:40 y positions and 40:50 hoop sides.

    Args:
        hz: sampling rate, in Hz, after downsampling the raw data.
        secs: window length in seconds.
        N: number of samples reported by ``__len__``.
        player_traj_n: number of bins per axis used to discretize each player's
            move between consecutive downsampled frames.
        max_player_move: moves are binned over [-max_player_move, max_player_move].
            Moves larger than 1.2 times this value are treated as tracking glitches.
        gameids: games to sample from.
        starts: fixed window start frames, indexed like ``gameids`` (used in
            "valid"/"test" mode).
        mode: "train" (random game, start, player order and court rotation), or
            "valid"/"test" (fixed game and start, sorted players, no rotation).
        n_player_ids: number of distinct player indices (stored on the dataset).
    """

    def __init__(
        self,
        hz,
        secs,
        N,
        player_traj_n,
        max_player_move,
        gameids,
        starts,
        mode,
        n_player_ids,
    ):
        self.skip = int(RAW_DATA_HZ / hz)
        self.chunk_size = int(RAW_DATA_HZ * secs)
        self.seq_len = self.chunk_size // self.skip

        self.N = N
        self.n_players = 10
        self.gameids = gameids
        self.starts = starts
        self.mode = mode
        self.n_player_ids = n_player_ids

        self.player_traj_n = player_traj_n
        self.max_player_move = max_player_move
        # player_traj_n - 1 edges give player_traj_n bins per axis.
        self.player_traj_bins = np.linspace(
            -max_player_move, max_player_move, player_traj_n - 1
        )

    def __len__(self):
        return self.N

    def _get_glitch_break(self, player_xs, player_ys):
        """Return how many leading frames to keep before the first position glitch.

        A glitch is a frame-to-frame move larger than ``1.2 * max_player_move`` in x
        or y for any player. The frame just before the jump is dropped as well. If
        there is no glitch, every frame is kept.
        """
        max_move = 1.2 * self.max_player_move
        glitches = (np.abs(np.diff(player_xs, axis=0)) > max_move) | (
            np.abs(np.diff(player_ys, axis=0)) > max_move
        )
        glitch_steps = np.flatnonzero(glitches.any(axis=1))
        return glitch_steps[0] if len(glitch_steps) > 0 else len(player_xs)

    def get_sample(self, X, start):
        """Build one training example from raw frames ``X`` starting at ``start``.

        Returns:
            dict of tensors with one row per kept step. ``player_trajs`` holds the
            binned move from each step to the next, encoded as
            ``row_bin * player_traj_n + col_bin``.
        """
        # Downsample.
        seq_data = X[start : start + self.chunk_size : self.skip]

        keep_players = np.random.choice(np.arange(10), self.n_players, False)
        if self.mode in {"valid", "test"}:
            keep_players.sort()

        # End sequence early if there is a position glitch. Often happens when there
        # was a break in the game, but glitches also happen for other reasons. See
        # glitch_example.py for an example (NOTE(review): that script is not in this
        # repository).
        seq_break = self._get_glitch_break(
            seq_data[:, 20:30][:, keep_players], seq_data[:, 30:40][:, keep_players]
        )
        seq_data = seq_data[:seq_break]

        player_idxs = seq_data[:, 10:20][:, keep_players].astype(int)
        player_xs = seq_data[:, 20:30][:, keep_players]
        player_ys = seq_data[:, 30:40][:, keep_players]
        player_hoop_sides = seq_data[:, 40:50][:, keep_players].astype(int)

        # Randomly rotate the court because the hoop direction is arbitrary.
        if (self.mode == "train") and (np.random.random() < 0.5):
            player_xs = COURT_LENGTH - player_xs
            player_ys = COURT_WIDTH - player_ys
            player_hoop_sides = (player_hoop_sides + 1) % 2

        # Get player trajectories.
        player_x_diffs = np.diff(player_xs, axis=0)
        player_y_diffs = np.diff(player_ys, axis=0)

        player_traj_rows = np.digitize(player_y_diffs, self.player_traj_bins)
        player_traj_cols = np.digitize(player_x_diffs, self.player_traj_bins)
        player_trajs = player_traj_rows * self.player_traj_n + player_traj_cols

        # The last kept frame has no next move, so positions are trimmed to match.
        return {
            "player_idxs": torch.LongTensor(player_idxs[: seq_break - 1]),
            "player_xs": torch.Tensor(player_xs[: seq_break - 1]),
            "player_ys": torch.Tensor(player_ys[: seq_break - 1]),
            "player_x_diffs": torch.Tensor(player_x_diffs),
            "player_y_diffs": torch.Tensor(player_y_diffs),
            "player_hoop_sides": torch.Tensor(player_hoop_sides[: seq_break - 1]),
            "player_trajs": torch.LongTensor(player_trajs),
        }

    def __getitem__(self, idx):
        """Return a sample from a random game ("train") or from game ``idx``.

        Raises:
            ValueError: if ``mode`` is not "train", "valid" or "test".
        """
        if self.mode == "train":
            gameid = np.random.choice(self.gameids)
        elif self.mode in {"valid", "test"}:
            gameid = self.gameids[idx]
        else:
            raise ValueError(f"Unknown mode: {self.mode}.")

        X = np.load(f"{GAMES_DIR}/{gameid}_X.npy")

        if self.mode == "train":
            # Valid starts are 0..len(X) - chunk_size inclusive; randint's high
            # bound is exclusive.
            start = np.random.randint(len(X) - self.chunk_size + 1)
        else:
            start = self.starts[idx]

        return self.get_sample(X, start)
121 |
--------------------------------------------------------------------------------
/court.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/airalcorn2/baller2vecplusplus/e9753be4995300e4e2057ebd06f6ad96131f7ba9/court.png
--------------------------------------------------------------------------------
/events.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/airalcorn2/baller2vecplusplus/e9753be4995300e4e2057ebd06f6ad96131f7ba9/events.zip
--------------------------------------------------------------------------------
/generate_game_numpy_arrays.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | import numpy as np
3 | import pandas as pd
4 | import pickle
5 | import py7zr
6 | import shutil
7 |
8 | from settings import *
9 |
# Half the court length: get_hoop_sides_worker records a ball x position below
# this value as hoop side 0 and any other position as COURT_LENGTH.
HALF_COURT_LENGTH = COURT_LENGTH // 2
# NOTE(review): THRESHOLD is not used in the visible part of this file; confirm
# where it is consumed and what it measures.
THRESHOLD = 1.0
12 |
13 |
def game_name2gameid_worker(game_7zs, queue):
    """Extract tracking archives and record the gameid of the file inside each.

    Args:
        game_7zs: file names of ``.7z`` archives located in TRACKING_DIR.
        queue: receives a single ``(game_names, gameids)`` tuple of parallel lists
            once all archives are processed. Archives that fail to extract, or that
            contain no files, are recorded with gameid ``"N/A"``.

    Each extracted directory is deleted before moving on.
    """
    game_names = []
    gameids = []
    for game_7z in game_7zs:
        game_name = game_7z.split(".7z")[0]
        game_names.append(game_name)
        game_dir = f"{TRACKING_DIR}/{game_name}"
        try:
            # The context manager closes the archive even if extraction fails.
            with py7zr.SevenZipFile(f"{TRACKING_DIR}/{game_7z}", mode="r") as archive:
                archive.extractall(path=game_dir)
        except AttributeError:
            # The directory may not exist if extraction failed early. A crash here
            # would skip queue.put and leave the parent blocked on q.get().
            shutil.rmtree(game_dir, ignore_errors=True)
            gameids.append("N/A")
            continue

        try:
            gameids.append(os.listdir(game_dir)[0].split(".")[0])
        except IndexError:
            gameids.append("N/A")

        shutil.rmtree(game_dir)

    queue.put((game_names, gameids))
37 |
38 |
def get_game_name2gameid_map():
    """Map every tracking archive name to its gameid and save the map as a CSV.

    The ``.7z`` archives in TRACKING_DIR are split into contiguous chunks, one per
    CPU, and each chunk is handled by game_name2gameid_worker in its own process.
    The merged (game_name, gameid) pairs are written to ``~/test.csv``.
    """
    result_queue = multiprocessing.Queue()

    archives = [name for name in os.listdir(TRACKING_DIR) if name.endswith(".7z")]
    n_procs = multiprocessing.cpu_count()
    per_proc = int(np.ceil(len(archives) / n_procs))
    workers = []
    for proc_idx in range(n_procs):
        chunk = archives[proc_idx * per_proc : (proc_idx + 1) * per_proc]
        worker = multiprocessing.Process(
            target=game_name2gameid_worker, args=(chunk, result_queue)
        )
        workers.append(worker)
        worker.start()

    names = []
    ids = []
    # Drain the queue before joining so no worker blocks while flushing its result.
    for _ in workers:
        (worker_names, worker_ids) = result_queue.get()
        names += worker_names
        ids += worker_ids

    for worker in workers:
        worker.join()

    mapping = pd.DataFrame.from_dict({"game_name": names, "gameid": ids})
    out_dir = os.path.expanduser("~")
    mapping.to_csv(f"{out_dir}/test.csv", index=False)
68 |
69 |
def playerid2player_idx_map_worker(game_7zs, queue):
    """Extract tracking archives and collect the name of every listed player.

    The extracted directories of readable games are left in TRACKING_DIR because
    other workers in this file read them. Archives that fail to extract or contain
    no files are removed and skipped.

    Args:
        game_7zs: file names of ``.7z`` archives located in TRACKING_DIR.
        queue: receives one dict mapping playerid -> {"name": "<first> <last>"}.
    """
    playerid2props = {}
    for game_7z in game_7zs:
        game_name = game_7z.split(".7z")[0]
        game_dir = f"{TRACKING_DIR}/{game_name}"
        try:
            # The context manager closes the archive even if extraction fails.
            with py7zr.SevenZipFile(f"{TRACKING_DIR}/{game_7z}", mode="r") as archive:
                archive.extractall(path=game_dir)
        except AttributeError:
            print(f"{game_name}\nBusted.", flush=True)
            # The directory may not exist if extraction failed early. A crash here
            # would skip queue.put and leave the parent blocked on q.get().
            shutil.rmtree(game_dir, ignore_errors=True)
            continue

        try:
            gameid = os.listdir(game_dir)[0].split(".")[0]
        except IndexError:
            print(f"No tracking data for {game_name}.", flush=True)
            shutil.rmtree(game_dir)
            continue

        df_tracking = pd.read_json(f"{game_dir}/{gameid}.json")
        # Only the first event's rosters are read; this assumes they list every
        # player in the game (TODO confirm).
        event = df_tracking["events"].iloc[0]
        players = event["home"]["players"] + event["visitor"]["players"]
        for player in players:
            playerid2props[player["playerid"]] = {
                "name": " ".join([player["firstname"], player["lastname"]]),
            }

    queue.put(playerid2props)
100 |
101 |
def get_playerid2player_idx_map():
    """Assign a contiguous integer index to every player in the tracking data.

    The archives in TRACKING_DIR are split across one process per CPU, each running
    playerid2player_idx_map_worker. Their results are merged.

    Returns:
        ``(playerid2player_idx, player_idx2props)``:
        - ``playerid2player_idx`` maps each playerid to its index.
        - ``player_idx2props`` maps each index to that player's props
          (``"name"`` and ``"playerid"``).
    """
    q = multiprocessing.Queue()

    dir_fs = os.listdir(TRACKING_DIR)
    all_game_7zs = [dir_f for dir_f in dir_fs if dir_f.endswith(".7z")]
    processes = multiprocessing.cpu_count()
    game_7zs_per_process = int(np.ceil(len(all_game_7zs) / processes))
    jobs = []
    for i in range(processes):
        start = i * game_7zs_per_process
        end = start + game_7zs_per_process
        game_7zs = all_game_7zs[start:end]
        p = multiprocessing.Process(
            target=playerid2player_idx_map_worker, args=(game_7zs, q)
        )
        jobs.append(p)
        p.start()

    playerid2props = {}
    # Drain the queue before joining so no worker blocks while flushing its result.
    for _ in jobs:
        playerid2props.update(q.get())

    for p in jobs:
        p.join()

    playerid2player_idx = {}
    player_idx2props = {}
    # Workers finish in a nondeterministic order, which changes the merged dict's
    # insertion order. Sorting makes the index assignment reproducible across runs.
    for (player_idx, playerid) in enumerate(sorted(playerid2props)):
        playerid2player_idx[playerid] = player_idx
        player_idx2props[player_idx] = {
            **playerid2props[playerid],
            "playerid": playerid,
        }

    return (playerid2player_idx, player_idx2props)
135 |
136 |
def get_game_time(game_clock_secs, period):
    """Convert a period's remaining clock time into seconds elapsed since tip-off.

    Periods 1-4 last 720 seconds and later (overtime) periods last 300 seconds.
    ``game_clock_secs`` is the time remaining on the clock in ``period``.
    """
    if period > 4:
        # Four full regulation periods, earlier overtimes, then elapsed time now.
        return 4 * 720 + (period - 5) * 300 + (300 - game_clock_secs)

    return (period - 1) * 720 + (720 - game_clock_secs)
144 |
145 |
def get_shot_times_worker(game_7zs, queue):
    """Collect the game times of shot events (EVENTMSGTYPE 1 or 2) for each game.

    Args:
        game_7zs: file names of ``.7z`` archives in TRACKING_DIR. Games without an
            extracted directory are skipped.
        queue: receives one dict mapping gameid -> {game_time: the
            PLAYER1_TEAM_ABBREVIATION of that event}.
    """
    results = {}
    for archive_name in game_7zs:
        game_name = archive_name.split(".7z")[0]
        try:
            gameid = os.listdir(f"{TRACKING_DIR}/{game_name}")[0].split(".")[0]
        except FileNotFoundError:
            continue

        events = pd.read_csv(f"{EVENTS_DIR}/{gameid}.csv")
        shots = {}
        for (_, event) in events.iterrows():
            # The game time is computed for every row, shot or not.
            clock = event["PCTIMESTRING"].split(":")
            clock_secs = 60 * int(clock[0]) + int(clock[1])
            game_time = get_game_time(clock_secs, event["PERIOD"])

            if event["EVENTMSGTYPE"] in (1, 2):
                shots[game_time] = event["PLAYER1_TEAM_ABBREVIATION"]

        results[gameid] = shots

    queue.put(results)
169 |
170 |
def get_shot_times():
    """Gather shot times for every game, processing the archives in parallel.

    Returns:
        dict mapping gameid -> {game_time: shooting team abbreviation}, merged from
        the get_shot_times_worker results (one process per CPU).
    """
    results = multiprocessing.Queue()

    archives = [name for name in os.listdir(TRACKING_DIR) if name.endswith(".7z")]
    n_procs = multiprocessing.cpu_count()
    per_proc = int(np.ceil(len(archives) / n_procs))
    workers = [
        multiprocessing.Process(
            target=get_shot_times_worker,
            args=(archives[idx * per_proc : (idx + 1) * per_proc], results),
        )
        for idx in range(n_procs)
    ]
    for worker in workers:
        worker.start()

    merged = {}
    # Drain the queue before joining so no worker blocks while flushing its result.
    for _ in workers:
        merged.update(results.get())

    for worker in workers:
        worker.join()

    return merged
195 |
196 |
def fill_in_periods(game_hoop_sides):
    """Infer hoop sides for regulation periods that have none of their own.

    Regulation periods are numbered 1-4. For each period p1 with known sides:
    - its partner in the same half (1 <-> 2, 3 <-> 4) receives a copy of p1's sides;
    - each period of the other half receives p1's sides with the two teams swapped.
    Periods that already have sides are never overwritten.

    Args:
        game_hoop_sides: dict mapping period -> {team: hoop_side}. Each populated
            period is assumed to contain exactly two teams.

    Returns:
        The same dict, updated in place.
    """
    # Periods are 1-based; iterating range(4) would skip period 4 as a source.
    for p1 in range(1, 5):
        if p1 not in game_hoop_sides:
            continue

        # Partner period in the same half: 1 -> 2, 2 -> 1, 3 -> 4, 4 -> 3.
        adder = 1 if p1 <= 2 else 3
        p2 = (p1 % 2) + adder
        if p2 not in game_hoop_sides:
            game_hoop_sides[p2] = game_hoop_sides[p1].copy()

        period_hoop_sides = list(game_hoop_sides[p1].items())
        swapped = {
            period_hoop_sides[0][0]: period_hoop_sides[1][1],
            period_hoop_sides[1][0]: period_hoop_sides[0][1],
        }

        p2s = [3, 4] if p1 <= 2 else [1, 2]
        for p2 in p2s:
            if p2 not in game_hoop_sides:
                game_hoop_sides[p2] = swapped.copy()

    return game_hoop_sides
219 |
220 |
def check_periods(game_hoop_sides, team, game):
    """Assert that ``team``'s hoop sides are consistent across regulation periods.

    A team must use the same side in both periods of a half (1 and 2, 3 and 4). It
    must use a different side in the second half (3, 4) than in the first (1, 2).
    Periods missing from ``game_hoop_sides`` are not checked.

    Raises:
        AssertionError: if an inconsistency is found.
    """
    # Periods are 1-based; range(4) would never check period 4.
    for period in range(1, 5):
        if period not in game_hoop_sides:
            continue

        if period in {2, 4}:
            if period - 1 in game_hoop_sides:
                assert (
                    game_hoop_sides[period - 1][team] == game_hoop_sides[period][team]
                ), f"{team} has different sides in periods {period} and {period - 1} of {game}."

        if period in {3, 4}:
            for first_half in [1, 2]:
                if first_half in game_hoop_sides:
                    assert (
                        game_hoop_sides[first_half][team]
                        != game_hoop_sides[period][team]
                    ), f"{team} has same side in periods {first_half} and {period} of {game}."
236 |
237 |
def get_game_hoop_sides(teams, hoop_side_counts, game):
    """Determine which hoop each team shoots at in every period of a game.

    Args:
        teams: collection of exactly two team abbreviations.
        hoop_side_counts: dict mapping period -> team -> {0: n, COURT_LENGTH: n}, the
            number of counted shots at each end of the court. A team missing from a
            period is added with zero counts (the dict is modified in place).
        game: game name, used only in messages.

    Returns:
        dict mapping period -> {team: hoop_side}, where hoop_side is 0 or
        COURT_LENGTH. Missing regulation periods are filled in by fill_in_periods
        and the result is validated by check_periods.

    Raises:
        ValueError: if ``teams`` does not hold exactly two teams.
        AssertionError: if both teams get the same side in a period, or sides are
            inconsistent across periods.
        KeyError: if neither team's side can be determined for a period.
    """
    [team_a, team_b] = list(teams)
    game_hoop_sides = {period: {} for period in hoop_side_counts}
    for period in sorted(hoop_side_counts):
        if len(hoop_side_counts[period]) == 0:
            print(f"No shooting data for {period} of {game}.")
            continue

        for team in teams:
            if team not in hoop_side_counts[period]:
                print(f"Missing {team} for {period} of {game}.")
                hoop_side_counts[period][team] = {0: 0, COURT_LENGTH: 0}

            l_count = hoop_side_counts[period][team][0]
            r_count = hoop_side_counts[period][team][COURT_LENGTH]
            if l_count > r_count:
                game_hoop_sides[period][team] = 0
            elif l_count < r_count:
                game_hoop_sides[period][team] = COURT_LENGTH
            # On a tie the side is inferred from the opponent below.

        period_sides = game_hoop_sides[period]
        undecided = [team for team in (team_a, team_b) if team not in period_sides]
        if len(undecided) == 1:
            team_out = undecided[0]
            team_in = team_b if team_out == team_a else team_a
            # The two teams shoot at opposite hoops, so the undecided team takes the
            # side its opponent is NOT using.
            period_sides[team_out] = COURT_LENGTH if period_sides[team_in] == 0 else 0

        # If both teams are undecided, this lookup raises KeyError.
        assert (
            game_hoop_sides[period][team_a] != game_hoop_sides[period][team_b]
        ), f"{team_a} and {team_b} have same side in {period} of {game}."

    game_hoop_sides = fill_in_periods(game_hoop_sides)

    for team in teams:
        check_periods(game_hoop_sides, team, game)

    return game_hoop_sides
284 |
285 |
def get_hoop_sides_worker(game_7zs, queue):
    """Per-process worker: infer per-period hoop sides for a chunk of games.

    For each archive name, reads the extracted tracking JSON and looks up each
    moment whose integer game time is a known shot time for this game (the
    module-level ``shot_times[gameid]`` dict maps game time -> shooting team).
    It counts, per period and shooting team, whether the ball was on the left
    (x < HALF_COURT_LENGTH) or right half; only the first moment seen at each
    shot time is counted. The counts are resolved with get_game_hoop_sides.

    Games whose directory is missing or empty, or that have no shot times, are
    skipped so the worker always reaches ``queue.put`` (the parent blocks on it).

    Args:
        game_7zs: archive file names ("<game_name>.7z") found in TRACKING_DIR.
        queue: multiprocessing queue that receives one {gameid: hoop_sides} dict.
    """
    hoop_sides = {}
    for game_7z in game_7zs:
        game_name = game_7z.split(".7z")[0]

        try:
            gameid = os.listdir(f"{TRACKING_DIR}/{game_name}")[0].split(".")[0]
        except (FileNotFoundError, IndexError):
            # Archive not extracted, or extracted into an empty directory.
            continue

        game_shot_times = shot_times.get(gameid)
        if game_shot_times is None:
            print(f"No shot times for {game_name}.", flush=True)
            continue

        df_tracking = pd.read_json(f"{TRACKING_DIR}/{game_name}/{gameid}.json")
        hoop_side_counts = {}
        used_game_times = set()
        teams = set()
        for tracking_event in df_tracking["events"]:
            for moment in tracking_event["moments"]:
                period = moment[0]
                game_clock = moment[2]
                game_time = int(get_game_time(game_clock, period))
                if (game_time not in game_shot_times) or (
                    game_time in used_game_times
                ):
                    continue

                ball_x = moment[5][0][2]
                hoop_side = 0 if ball_x < HALF_COURT_LENGTH else COURT_LENGTH

                shooting_team = game_shot_times[game_time]
                team_counts = hoop_side_counts.setdefault(period, {}).setdefault(
                    shooting_team, {0: 0, COURT_LENGTH: 0}
                )
                team_counts[hoop_side] += 1
                used_game_times.add(game_time)
                teams.add(shooting_team)

        if len(teams) == 0:
            print(f"The moments in the {game_name} JSON are empty.", flush=True)
            continue

        hoop_sides[gameid] = get_game_hoop_sides(teams, hoop_side_counts, game_name)

    queue.put(hoop_sides)
336 |
337 |
def get_team_hoop_sides():
    """Infer every game's per-period hoop sides using one process per CPU.

    The ``*.7z`` entries of TRACKING_DIR are split into contiguous chunks, one per
    CPU. Each chunk is handled by get_hoop_sides_worker, and the per-worker
    dictionaries are merged.

    Returns:
        dict mapping gameid -> per-period hoop-side dict.
    """
    results = multiprocessing.Queue()

    archives = [name for name in os.listdir(TRACKING_DIR) if name.endswith(".7z")]
    n_workers = multiprocessing.cpu_count()
    chunk = int(np.ceil(len(archives) / n_workers))

    workers = [
        multiprocessing.Process(
            target=get_hoop_sides_worker,
            args=(archives[w * chunk : (w + 1) * chunk], results),
        )
        for w in range(n_workers)
    ]
    for worker in workers:
        worker.start()

    # Drain the queue before joining; a child with queued data may not exit until
    # that data has been consumed.
    merged = {}
    for _ in workers:
        merged.update(results.get())

    for worker in workers:
        worker.join()

    return merged
362 |
363 |
def get_event_stream(gameid):
    """Build the event stream for one game from its play-by-play CSV.

    Reads ``{EVENTS_DIR}/{gameid}.csv`` and walks its rows in order, tracking which
    of the two teams has possession. It emits one record per kept event. Many row
    types produce no record and are skipped (some update the possession/score state
    first):
    - period starts/ends, jump balls, free throws, substitutions, ejections;
    - blank and unknown rows, TV timeouts;
    - goaltending and jump-ball violations;
    - foul types listed in ``skip_fouls``.

    Args:
        gameid: game ID string used to locate the CSV (e.g. "0021500414").

    Returns:
        List of dicts in play-by-play order, with keys:
        - "game_time": get_game_time(...) for the row, minus 1;
        - "pos_team": team that had possession before the event;
        - "event": event label, e.g. "timeout", "steal", "dunk_made";
        - "description": lower-cased combined description;
        - "event_team": PLAYER1's team abbreviation (for team-level rows,
          resolved from PLAYER1_ID through TEAM_ID2PROPS);
        - "score": latest non-empty SCORE string, "0 - 0" initially.

    Raises:
        ValueError: if a made or missed shot is credited to a team other than the
            one the possession tracking says has the ball.
    """
    df_events = pd.read_csv(f"{EVENTS_DIR}/{gameid}.csv")
    df_events = df_events.fillna("")
    # One searchable description per row, combining the home, neutral and visitor
    # columns.
    df_events["DESCRIPTION"] = (
        df_events["HOMEDESCRIPTION"]
        + " "
        + df_events["NEUTRALDESCRIPTION"]
        + " "
        + df_events["VISITORDESCRIPTION"]
    )

    # Possession times.
    # EVENTMSGTYPE descriptions can be found at: https://github.com/rd11490/NBA_Tutorials/tree/master/analyze_play_by_play.
    event_col = "EVENTMSGTYPE"
    description_col = "DESCRIPTION"
    player1_team_col = "PLAYER1_TEAM_ABBREVIATION"

    # After fillna(""), rows without a player team contribute "". It sorts first,
    # so dropping the first entry leaves the two team abbreviations. Possession is
    # tracked as an index (0 or 1) into this list.
    teams = list(df_events[player1_team_col].unique())
    teams.sort()
    teams = teams[1:] if len(teams) > 2 else teams

    event = None
    score = "0 - 0"
    # pos_team is the team that had possession prior to the event.
    (pos_team, pos_team_idx) = (None, None)
    # Index of the team that won the opening (or first-overtime) jump ball. Later
    # period starts derive possession from it.
    jump_ball_team_idx = None
    # I think most of these are technical fouls.
    skip_fouls = {10, 11, 16, 19}

    # events and pos_stream are only used for debugging (pos_stream is printed
    # when a shot is attributed to the wrong team).
    events = set()
    pos_stream = []
    event_stream = []
    for (row_idx, row) in df_events.iterrows():
        period = row["PERIOD"]
        # PCTIMESTRING is "M:SS"; converted to seconds for get_game_time (defined
        # elsewhere in this file).
        game_clock = row["PCTIMESTRING"].split(":")
        game_clock_secs = 60 * int(game_clock[0]) + int(game_clock[1])
        game_time = get_game_time(game_clock_secs, period)

        description = row[description_col].lower().strip()

        eventmsgtype = row[event_col]

        # Don't know.
        if eventmsgtype == 18:
            continue

        # Blank line.
        elif eventmsgtype == 14:
            continue

        # End of a period.
        elif eventmsgtype == 13:
            # Reset so that the first jump ball of overtime (period 5) is recorded.
            if period == 4:
                jump_ball_team_idx = None

            continue

        # Start of a period.
        elif eventmsgtype == 12:
            # The team that lost the opening tip starts periods 2 and 3; the
            # winner starts period 4.
            if 2 <= period <= 4:
                if period == 4:
                    pos_team_idx = jump_ball_team_idx
                else:
                    pos_team_idx = (jump_ball_team_idx + 1) % 2

            # Later overtimes alternate, starting from the period-5 jump ball.
            elif 6 <= period:
                pos_team_idx = (jump_ball_team_idx + (period - 5)) % 2

            continue

        # Ejection.
        elif eventmsgtype == 11:
            continue

        # Jump ball.
        elif eventmsgtype == 10:
            # PLAYER3 is the player (or team) that gained possession.
            if int(row["PLAYER3_ID"]) in TEAM_ID2PROPS:
                pos_team = TEAM_ID2PROPS[int(row["PLAYER3_ID"])]["abbreviation"]
            else:
                pos_team = row["PLAYER3_TEAM_ABBREVIATION"]

            if pos_team == teams[0]:
                pos_team_idx = 0
            else:
                pos_team_idx = 1

            # Remember who won the opening tip of regulation / overtime.
            if (period in {1, 5}) and (jump_ball_team_idx is None):
                jump_ball_team_idx = pos_team_idx

            continue

        # Timeout.
        elif eventmsgtype == 9:
            # TV timeout?
            if description == "":
                continue

            event = "timeout"

        # Substitution.
        elif eventmsgtype == 8:
            continue

        # Violation.
        elif eventmsgtype == 7:
            # With 35 seconds left in the fourth period, there was a kicked ball
            # violation attributed to Wayne Ellington of the Brooklyn Nets, but the
            # following event is a shot by the Nets, which means possession never changed.
            if (gameid == "0021500414") and (row_idx == 427):
                continue

            # Goaltending is considered a made shot, so the following event is always the
            # made shot event.
            if "goaltending" in description:
                score = row["SCORE"] if row["SCORE"] else score
                continue
            # Jump ball violations have weird possession rules.
            elif "jump ball" in description:
                # Possession goes to the team that did not commit the violation.
                if row["PLAYER1_TEAM_ABBREVIATION"] == teams[0]:
                    pos_team_idx = 1
                else:
                    pos_team_idx = 0

                pos_team = teams[pos_team_idx]
                # A violation on the opening tip decides the "jump ball winner".
                if (period == 1) and (game_time == 0):
                    jump_ball_team_idx = pos_team_idx

                continue

            else:
                # A violation by the team in possession hands the ball over.
                if row[player1_team_col] == teams[pos_team_idx]:
                    event = "violation_offense"
                    pos_team_idx = (pos_team_idx + 1) % 2
                else:
                    event = "violation_defense"

        # Foul.
        elif eventmsgtype == 6:
            # Skip weird fouls.
            if row["EVENTMSGACTIONTYPE"] in skip_fouls:
                score = row["SCORE"] if row["SCORE"] else score
                continue
            else:
                # A foul by the team in possession hands the ball over.
                if row[player1_team_col] == teams[pos_team_idx]:
                    event = "offensive_foul"
                    pos_team_idx = (pos_team_idx + 1) % 2
                else:
                    event = "defensive_foul"

        # Turnover.
        elif eventmsgtype == 5:
            if "steal" in description:
                event = "steal"
            elif "goaltending" in description:
                event = "goaltending_offense"
            elif (
                ("violation" in description)
                or ("dribble" in description)
                or ("traveling" in description)
            ):
                event = "violation_offense"
            else:
                event = "turnover"

            # Team turnover.
            # Team-level rows have no player team; PLAYER1_ID then holds a team ID.
            if row[player1_team_col] == "":
                team_id = int(row["PLAYER1_ID"])
                team_abb = TEAM_ID2PROPS[team_id]["abbreviation"]
            else:
                team_abb = row[player1_team_col]

            # Possession goes to the team that did not commit the turnover.
            pos_team_idx = 1 if team_abb == teams[0] else 0

        # Rebound.
        elif eventmsgtype == 4:
            # With 17 seconds left in the first period, Spencer Hawes missed a tip in,
            # which was rebounded by DeAndre Jordan. The tip in is recorded as a rebound
            # and a missed shot for Hawes. All three events have the same timestamp,
            # which seems to have caused the order of the events to be slightly shuffled
            # with the Jordan rebound occurring before the tip in.
            if (gameid == "0021500550") and (row_idx == 97):
                continue

            # Team rebound.
            if row[player1_team_col] == "":
                team_id = int(row["PLAYER1_ID"])
                team_abb = TEAM_ID2PROPS[team_id]["abbreviation"]
                if team_abb == teams[pos_team_idx]:
                    event = "rebound_offense"
                else:
                    event = "rebound_defense"
                    pos_team_idx = (pos_team_idx + 1) % 2

            elif row[player1_team_col] == teams[pos_team_idx]:
                event = "rebound_offense"
            else:
                event = "rebound_defense"
                pos_team_idx = (pos_team_idx + 1) % 2

        # Free throw.
        elif eventmsgtype == 3:
            # See rules for technical fouls: https://official.nba.com/rule-no-12-fouls-and-penalties/.
            # Possession only changes for too many players, which is extremely rare.
            if "technical" not in description:
                # The shooter's team has the ball during the free throws.
                pos_team_idx = 0 if row[player1_team_col] == teams[0] else 1
                # A made final free throw (not clear-path or flagrant, which keep
                # possession) hands the ball to the other team.
                if (
                    ("Clear Path" not in row[description_col])
                    and ("Flagrant" not in row[description_col])
                    and ("MISS" not in row[description_col])
                    and (
                        ("1 of 1" in description)
                        or ("2 of 2" in description)
                        or ("3 of 3" in description)
                    )
                ):
                    # Hack to handle foul shots for away from play fouls.
                    if ((gameid == "0021500274") and (row_idx == 519)) or (
                        (gameid == "0021500572") and (row_idx == 428)
                    ):
                        pass
                    # This event is a made foul shot by Thaddeus Young of the Brooklyn
                    # Nets following an and-one foul, so possession should have changed
                    # to the Milwaukee Bucks. However, the next event is a made shot by
                    # Brook Lopez (also of the Brooklyn Nets) with no event indicating a
                    # change of possession occurring before it.
                    elif (gameid == "0021500047") and (row_idx == 64):
                        pass
                    else:
                        pos_team_idx = (pos_team_idx + 1) % 2

                pos_team = teams[pos_team_idx]

            score = row["SCORE"] if row["SCORE"] else score

            continue

        # Missed shot.
        elif eventmsgtype == 2:
            if "dunk" in description:
                shot_type = "dunk"
            elif "layup" in description:
                shot_type = "layup"
            else:
                shot_type = "shot"

            if "BLOCK" in row[description_col]:
                miss_type = "block"
            else:
                miss_type = "miss"

            event = f"{shot_type}_{miss_type}"

            # Sanity check: only the team in possession can shoot.
            if row[player1_team_col] != teams[pos_team_idx]:
                print(pos_stream[-5:])
                raise ValueError(f"Incorrect possession team in row {str(row_idx)}.")

        # Made shot.
        elif eventmsgtype == 1:
            if "dunk" in description:
                shot_type = "dunk"
            elif "layup" in description:
                shot_type = "layup"
            else:
                shot_type = "shot"

            event = f"{shot_type}_made"

            # Sanity check: only the team in possession can shoot.
            if row[player1_team_col] != teams[pos_team_idx]:
                print(pos_stream[-5:])
                raise ValueError(f"Incorrect possession team in row {str(row_idx)}.")

            # A made basket gives the ball to the other team.
            pos_team_idx = (pos_team_idx + 1) % 2

        events.add(event)
        pos_stream.append(pos_team_idx)

        # Team-level rows have no player team; PLAYER1_ID then holds a team ID.
        if row[player1_team_col] == "":
            team_id = int(row["PLAYER1_ID"])
            event_team = TEAM_ID2PROPS[team_id]["abbreviation"]
        else:
            event_team = row[player1_team_col]

        # NOTE(review): the stored time is one second earlier than the row's game
        # time; the reason is not evident here. save_game_numpy_arrays labels each
        # moment with the first event whose "game_time" is >= the moment's time.
        event_stream.append(
            {
                "game_time": game_time - 1,
                "pos_team": pos_team,
                "event": event,
                "description": description,
                "event_team": event_team,
                "score": score,
            }
        )

        # With 17 seconds left in the first period, Spencer Hawes missed a tip in,
        # which was rebounded by DeAndre Jordan. The tip in is recorded as a rebound
        # and a missed shot for Hawes. All three events have the same timestamp,
        # which seems to have caused the order of the events to be slightly shuffled
        # with the Jordan rebound occurring before the tip in.
        # (The rebound skipped at row 97 is re-inserted here, after the tip in.)
        if (gameid == "0021500550") and (row_idx == 98):
            event_stream.append(
                {
                    "game_time": game_time - 1,
                    "pos_team": pos_team,
                    "event": "rebound_defense",
                    "description": "jordan rebound (off:2 def:5)",
                    "event_team": "LAC",
                    "score": score,
                }
            )
            pos_team_idx = (pos_team_idx + 1) % 2

        # This event is a missed shot by Kawhi Leonard of the San Antonio Spurs. The next
        # event is a missed shot by Bradley Beal of the Washington Wizards with no
        # event indicating a change of possession occurring before it.
        if (gameid == "0021500061") and (row_idx == 240):
            pos_team_idx = (pos_team_idx + 1) % 2

        # Possession after this event becomes the "prior" possession of the next one.
        pos_team = teams[pos_team_idx]
        score = row["SCORE"] if row["SCORE"] else score

    return event_stream
685 |
686 |
def get_event_streams_worker(game_names, queue):
    """Per-process worker: build the event stream for each game in a chunk.

    Games whose directory is empty, whose events CSV is missing, or whose
    possession tracking fails are skipped with a message. Because of this the
    worker always reaches ``queue.put``; the parent blocks on ``q.get()``, so a
    crashing worker would otherwise hang it.

    Args:
        game_names: extracted game directory names inside TRACKING_DIR.
        queue: multiprocessing queue that receives one {gameid: event_stream} dict.
    """
    gameid2event_stream = {}
    for game_name in game_names:
        try:
            gameid = os.listdir(f"{TRACKING_DIR}/{game_name}")[0].split(".")[0]
        except IndexError:
            # Empty game directory.
            continue

        try:
            gameid2event_stream[gameid] = get_event_stream(gameid)
        except (IndexError, FileNotFoundError, ValueError) as e:
            print(f"Skipping event stream for {gameid}: {e!r}", flush=True)

    queue.put(gameid2event_stream)
697 |
698 |
def get_event_streams():
    """Build every game's event stream using one process per CPU, plus an event index.

    Non-archive entries of TRACKING_DIR (extracted game directories) are split into
    contiguous chunks, one per CPU, and handled by get_event_streams_worker.

    Returns:
        (event2event_idx, gameid2event_stream):
        - event2event_idx maps each event label to an integer, numbered in order of
          first appearance when games are visited in sorted gameid order. This makes
          the mapping reproducible regardless of worker completion order.
        - gameid2event_stream maps gameid -> list of event dicts.
    """
    q = multiprocessing.Queue()

    dir_fs = os.listdir(TRACKING_DIR)
    all_game_names = [dir_f for dir_f in dir_fs if not dir_f.endswith(".7z")]

    processes = multiprocessing.cpu_count()
    game_names_per_process = int(np.ceil(len(all_game_names) / processes))
    jobs = []
    for i in range(processes):
        start = i * game_names_per_process
        end = start + game_names_per_process
        game_names = all_game_names[start:end]
        p = multiprocessing.Process(
            target=get_event_streams_worker, args=(game_names, q)
        )
        jobs.append(p)
        p.start()

    # Drain the queue before joining; a child with queued data may not exit until
    # that data has been consumed.
    gameid2event_stream = {}
    for _ in jobs:
        gameid2event_stream.update(q.get())

    for p in jobs:
        p.join()

    # Queue arrival order (and os.listdir order) is arbitrary, so visit games in a
    # fixed order to keep the label -> index mapping stable across runs.
    event2event_idx = {}
    for gameid in sorted(gameid2event_stream):
        for event in gameid2event_stream[gameid]:
            event2event_idx.setdefault(event["event"], len(event2event_idx))

    return (event2event_idx, gameid2event_stream)
731 |
732 |
def add_score_changes(X):
    """Insert "time to next score change" and "next score change" columns into X.

    A score change is a row whose column-6 value (left score minus right score)
    differs from the previous row. The sign flip at the first row of the second half
    is the teams changing sides, not points scored, so its magnitude is reported as 0.

    Args:
        X: 2-D array with the period in column 3, the score difference in column 6
            and the wall clock (milliseconds) in the last column.

    Returns:
        New array equal to X with two columns inserted just before the last
        (wall clock) column:
        - seconds from each row until the next score change strictly after it, or
          until the final row's wall clock if there is none;
        - the signed change in the score difference at that next score change, or
          0 if there is none or it is the halftime side swap.
    """
    score_diff_idx = 6
    period_idx = 3
    wall_clock_idx = -1

    score_diffs = X[:, score_diff_idx]
    periods = X[:, period_idx]
    wall_clocks = X[:, wall_clock_idx]

    # Rows where the score difference changes, and by how much.
    score_change_idxs = np.flatnonzero(np.diff(score_diffs) != 0) + 1
    score_changes = score_diffs[score_change_idxs] - score_diffs[score_change_idxs - 1]

    # First row of the second half (the first transition from period <= 2 to > 2),
    # or -1 if the array does not span halftime.
    half_idx = -1
    period_change_idxs = np.flatnonzero(np.diff(periods) != 0) + 1
    for period_change_idx in period_change_idxs:
        if (periods[period_change_idx - 1] <= 2) and (periods[period_change_idx] > 2):
            half_idx = period_change_idx
            break

    # Append a zero-valued sentinel "change" at the final row's time. Rows with no
    # later change then measure the time until the end of the array.
    next_times = np.append(wall_clocks[score_change_idxs], wall_clocks[-1])
    next_changes = np.append(
        np.where(score_change_idxs == half_idx, 0, score_changes), 0
    )
    # For each row, the position of the first score change strictly after it.
    next_pos = np.searchsorted(score_change_idxs, np.arange(len(X)), side="right")

    times_to_next_score_change = (next_times[next_pos] - wall_clocks) / 1000
    next_score_changes = next_changes[next_pos]

    return np.hstack(
        [
            X[:, :wall_clock_idx],
            times_to_next_score_change[:, None],
            next_score_changes[:, None],
            X[:, wall_clock_idx:],
        ]
    )
794 |
795 |
def save_game_numpy_arrays(game_name):
    """Convert one game's tracking JSON into feature/label arrays and save them.

    Walks the game's tracking moments in order. It skips moments whose game time does
    not advance past the last kept moment, and moments without exactly ten players.
    Each kept moment is paired with the first event in the game's event stream whose
    "game_time" is >= the moment's game time; processing stops once the event stream
    is exhausted.

    Each row of X holds 52 values (before add_score_changes):
    - game_time, period_time, shot_clock, period;
    - left_score, right_score, left_score - right_score;
    - ball_x, ball_y, ball_z;
    - the ten player indices sorted ascending, then their x positions, their y
      positions, and their hoop-side flags (1 if the player's team shoots at
      COURT_LENGTH), all in that same order;
    - the event index and the wall clock in milliseconds.
    y holds each row's event label index from event2event_idx.

    The arrays are written to ``{GAMES_DIR}/{gameid}_X.npy`` and ``_y.npy``. An empty
    game directory is deleted instead. Relies on module-level globals set in
    ``__main__`` (gameid2event_stream, hoop_sides, playerid2player_idx,
    event2event_idx).

    Args:
        game_name: name of the extracted game directory inside TRACKING_DIR.

    Raises:
        ValueError: if a row has an unexpected length or no usable moments remain.
        KeyError: if a player ID other than the placeholder 0 is unknown (the
            original lookup error is re-raised).
    """
    try:
        gameid = os.listdir(f"{TRACKING_DIR}/{game_name}")[0].split(".")[0]
    except IndexError:
        # Empty directory: nothing to convert.
        shutil.rmtree(f"{TRACKING_DIR}/{game_name}")
        return

    if gameid not in gameid2event_stream:
        print(f"Missing gameid: {gameid}", flush=True)
        return

    df_tracking = pd.read_json(f"{TRACKING_DIR}/{game_name}/{gameid}.json")
    home_team = None
    cur_time = -1
    event_idx = 0
    game_over = False
    X = []
    y = []
    event_stream = gameid2event_stream[gameid]
    for tracking_event in df_tracking["events"]:
        event_id = tracking_event["eventId"]
        if home_team is None:
            home_team_id = tracking_event["home"]["teamid"]
            home_team = TEAM_ID2PROPS[home_team_id]["abbreviation"]

        for moment in tracking_event["moments"]:
            period = moment[0]
            # Milliseconds.
            wall_clock = moment[1]
            game_clock = moment[2]
            shot_clock = moment[3]
            # A missing (falsy) shot clock falls back to the game clock.
            shot_clock = shot_clock if shot_clock else game_clock

            # Seconds elapsed in the period (720 s periods, 300 s overtimes).
            period_time = 720 - game_clock if period <= 4 else 300 - game_clock
            game_time = get_game_time(game_clock, period)

            # Moments can overlap temporally, so previously processed time points are
            # skipped along with clock stoppages.
            if game_time <= cur_time:
                continue

            # Advance to the first event at or after this moment.
            while game_time > event_stream[event_idx]["game_time"]:
                event_idx += 1
                if event_idx >= len(event_stream):
                    game_over = True
                    break

            if game_over:
                break

            event = event_stream[event_idx]
            (away_score, home_score) = (int(s) for s in event["score"].split(" - "))
            # Express the score by court side rather than home/away.
            home_hoop_side = hoop_sides[gameid][period][home_team]
            if home_hoop_side == 0:
                (left_score, right_score) = (home_score, away_score)
            else:
                (right_score, left_score) = (home_score, away_score)

            # moment[5][0] is the ball; the remaining entries are players.
            players = moment[5][1:]
            if len(players) != 10:
                continue

            (ball_x, ball_y, ball_z) = moment[5][0][2:5]
            data = [
                game_time,
                period_time,
                shot_clock,
                period,
                left_score,  # off_score,
                right_score,  # def_score,
                left_score - right_score,
                ball_x,
                ball_y,
                ball_z,
            ]

            player_idxs = []
            player_xs = []
            player_ys = []
            player_hoop_sides = []
            try:
                for player in players:
                    player_idxs.append(playerid2player_idx[player[1]])
                    player_xs.append(player[2])
                    player_ys.append(player[3])
                    hoop_side = hoop_sides[gameid][period][
                        TEAM_ID2PROPS[player[0]]["abbreviation"]
                    ]
                    player_hoop_sides.append(int(hoop_side == COURT_LENGTH))

            except KeyError:
                if player[1] == 0:
                    print(
                        f"Bad player in event {event_id} for {game_name}.", flush=True
                    )
                    continue

                # Re-raise the original error so the offending key is reported.
                raise

            # Canonical player order: ascending player index.
            order = np.argsort(player_idxs)
            for values in (player_idxs, player_xs, player_ys, player_hoop_sides):
                data.extend(values[idx] for idx in order)

            data.append(event_idx)
            data.append(wall_clock)

            if len(data) != 52:
                raise ValueError(
                    f"Expected 52 values per row, got {len(data)} for {game_name}."
                )

            X.append(np.array(data))
            y.append(event2event_idx.setdefault(event["event"], len(event2event_idx)))
            cur_time = game_time

        if game_over:
            break

    if len(X) == 0:
        raise ValueError(f"No usable moments for {game_name}.")

    X = np.stack(X)
    y = np.array(y)

    X = add_score_changes(X)

    np.save(f"{GAMES_DIR}/{gameid}_X.npy", X)
    np.save(f"{GAMES_DIR}/{gameid}_y.npy", y)
931 |
932 |
def save_numpy_arrays_worker(game_names):
    """Per-process worker: save arrays for each game, then delete its directory.

    ValueErrors from save_game_numpy_arrays (bad rows, no usable moments) are logged
    and the game is skipped. The game directory is removed afterwards unless
    save_game_numpy_arrays already removed it (it does so for empty directories).

    Args:
        game_names: extracted game directory names inside TRACKING_DIR.
    """
    for game_name in game_names:
        try:
            save_game_numpy_arrays(game_name)
        except ValueError as e:
            print(f"Could not save arrays for {game_name}: {e!r}", flush=True)

        game_dir = f"{TRACKING_DIR}/{game_name}"
        if os.path.exists(game_dir):
            shutil.rmtree(game_dir)
941 |
942 |
def save_numpy_arrays():
    """Convert every extracted game directory in TRACKING_DIR to arrays, in parallel.

    Non-archive entries of TRACKING_DIR are split into contiguous chunks, one per
    CPU, and each chunk is handled by save_numpy_arrays_worker in its own process.
    Returns once all workers have exited.
    """
    game_dirs = [name for name in os.listdir(TRACKING_DIR) if not name.endswith(".7z")]
    n_workers = multiprocessing.cpu_count()
    chunk = int(np.ceil(len(game_dirs) / n_workers))

    workers = []
    for w in range(n_workers):
        worker = multiprocessing.Process(
            target=save_numpy_arrays_worker,
            args=(game_dirs[w * chunk : (w + 1) * chunk],),
        )
        workers.append(worker)
        worker.start()

    for worker in workers:
        worker.join()
960 |
961 |
def player_idx2playing_time_map_worker(gameids, queue):
    """Per-process worker: total each player's on-court seconds over a chunk of games.

    For each saved game array, consecutive rows are compared. Time is credited only
    when all ten player indices (columns 10-19) are the same in both rows and the
    wall-clock gap (last column, milliseconds) is below THRESHOLD seconds.

    Args:
        gameids: game IDs whose ``{GAMES_DIR}/{gameid}_X.npy`` arrays are read.
        queue: multiprocessing queue that receives one {player_idx: seconds} dict.
    """
    totals = {}
    for gameid in gameids:
        game_X = np.load(f"{GAMES_DIR}/{gameid}_X.npy")
        elapsed = np.diff(game_X[:, -1]) / 1000
        player_rows = game_X[:, 10:20].astype(int)

        prev_lineup = set(player_rows[0])
        for (dt, row) in zip(elapsed, player_rows[1:]):
            lineup = set(row)
            if (len(prev_lineup & lineup) == 10) and (dt < THRESHOLD):
                for player_idx in lineup:
                    totals[player_idx] = totals.get(player_idx, 0) + dt

            prev_lineup = lineup

    queue.put(totals)
982 |
983 |
def get_player_idx2playing_time_map():
    """Total each player's playing time over every saved game, in parallel.

    Game IDs are the distinct file-name prefixes (before the first "_") in
    GAMES_DIR. They are split into contiguous chunks, one per CPU, and each chunk is
    summed by player_idx2playing_time_map_worker. The quantiles of the totals are
    printed.

    Returns:
        dict mapping player index -> total playing time in seconds.
    """
    results = multiprocessing.Queue()

    gameids = list({np_f.split("_")[0] for np_f in os.listdir(GAMES_DIR)})
    n_workers = multiprocessing.cpu_count()
    chunk = int(np.ceil(len(gameids) / n_workers))
    workers = []
    for w in range(n_workers):
        worker = multiprocessing.Process(
            target=player_idx2playing_time_map_worker,
            args=(gameids[w * chunk : (w + 1) * chunk], results),
        )
        workers.append(worker)
        worker.start()

    # Merge per-worker totals (drained before joining the workers).
    totals = {}
    for _ in workers:
        for (player_idx, seconds) in results.get().items():
            totals[player_idx] = totals.get(player_idx, 0) + seconds

    for worker in workers:
        worker.join()

    print(np.quantile(list(totals.values()), [0.1, 0.2, 0.25, 0.5, 0.75, 0.8, 0.9]))
    # [ 3364.6814 10270.2988 13314.768 39917.09399999
    # 59131.73249999 63400.76839999 72048.72879999]
    return totals
1017 |
1018 |
if __name__ == "__main__":
    # Pipeline: hoop sides -> event streams -> per-game arrays -> playing times.
    # The worker functions read the globals assigned below (shot_times, hoop_sides,
    # gameid2event_stream, event2event_idx, playerid2player_idx). NOTE(review): this
    # presumably relies on the "fork" start method so children inherit them.
    os.makedirs(GAMES_DIR, exist_ok=True)

    config_path = f"{DATA_DIR}/baller2vec_config.pydict"
    (playerid2player_idx, player_idx2props) = get_playerid2player_idx_map()

    # An existing config's player and event index mappings take precedence over the
    # freshly built ones.
    try:
        with open(config_path, "rb") as f:
            baller2vec_config = pickle.load(f)
    except FileNotFoundError:
        baller2vec_config = None

    if baller2vec_config is not None:
        player_idx2props = baller2vec_config["player_idx2props"]
        playerid2player_idx = {
            props["playerid"]: player_idx
            for (player_idx, props) in player_idx2props.items()
        }

    shot_times = get_shot_times()
    hoop_sides = get_team_hoop_sides()
    (event2event_idx, gameid2event_stream) = get_event_streams()
    if baller2vec_config is not None:
        event2event_idx = baller2vec_config["event2event_idx"]

    save_numpy_arrays()

    player_idx2playing_time = get_player_idx2playing_time_map()
    for (player_idx, playing_time) in player_idx2playing_time.items():
        player_idx2props[player_idx]["playing_time"] = playing_time

    # Only a newly created config is written; an existing one is left untouched.
    if baller2vec_config is None:
        baller2vec_config = {
            "player_idx2props": player_idx2props,
            "event2event_idx": event2event_idx,
        }
        with open(config_path, "wb") as f:
            pickle.dump(baller2vec_config, f)
1057 |
--------------------------------------------------------------------------------
/gif_cropper.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

# Crop every GIF in the current directory to the region x in [250, 500),
# y in [0, 266) and write the result to <name>_cropped.gif.
# See: https://stackoverflow.com/q/965053/1316276.
# Left region instead:
# for f in *.gif; do gifsicle --crop 0,0-250,266 --output "${f%.*}_cropped.gif" "${f}"; done
shopt -s nullglob
for f in *.gif; do
    # Skip outputs of a previous run so they are not cropped a second time.
    if [[ "${f}" == *_cropped.gif ]]; then
        continue
    fi
    gifsicle --crop 250,0-500,266 --output "${f%.*}_cropped.gif" "${f}"
done
--------------------------------------------------------------------------------
/gif_merger.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Stop on the first failure so frames are not deleted after a failed convert.
set -eu

# Merge train.gif, baller2vec.gif and baller2vec++.gif side by side into result.gif.
# See: https://stackoverflow.com/a/30932152/1316276.
# Separate frames of train.gif.
convert train.gif -coalesce a-%04d.gif
# Separate frames of baller2vec.gif.
convert baller2vec.gif -coalesce b-%04d.gif
# Separate frames of baller2vec++.gif.
convert baller2vec++.gif -coalesce c-%04d.gif
# Append frames side-by-side with 10-pixel-wide black separators; the combined
# frame overwrites the a-* frame.
for f in a-*.gif; do
    convert "${f}" -size 10x xc:black "${f/#a-/b-}" -size 10x xc:black "${f/#a-/c-}" +append "${f}"
done
# Rejoin frames.
convert -loop 0 -delay 40 a-*.gif result.gif
rm a-*.gif b-*.gif c-*.gif
--------------------------------------------------------------------------------
/image_cropper.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

# Crop every PNG in the current directory IN PLACE to the 250x266 region whose
# top-left corner is at (250, 0). mogrify overwrites its inputs, so run this only
# in a directory containing the images to crop.
# See: https://deparkes.co.uk/2015/04/30/batch-crop-images-with-imagemagick/.
# +repage discards the virtual-canvas offset that -crop leaves in the output.
mogrify -crop 250x266+250+0 +repage *.png
5 |
--------------------------------------------------------------------------------
/paper_functions.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 |
3 | matplotlib.use("Agg")
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import pandas as pd
8 | import shutil
9 | import torch
10 | import yaml
11 |
12 | from PIL import Image, ImageDraw
13 | from settings import *
14 | from torch import nn
15 | from toy_dataset import ToyDataset
16 | from train_baller2vecplusplus import init_basketball_datasets, init_model
17 |
# Matplotlib's default property-cycle colors. The image helpers index this list by
# each agent's rank (sorted player index at the first time step), so an agent keeps
# one color in every frame.
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

# Keys of the per-player entries in a sample dict.
# NOTE(review): not referenced in this part of the file; presumably these entries
# are permuted together when player order is shuffled. Confirm against callers.
shuffle_keys = [
    "player_idxs", "player_xs", "player_ys", "player_hoop_sides",
    "player_x_diffs", "player_y_diffs", "player_trajs",
]
29 |
30 |
def add_grid(img, steps, highlight_center=True):
    """Draw gray grid lines onto ``img`` in place.

    The cell size is ``int(img.width / steps)`` pixels and is used in both
    directions.

    Args:
        img: PIL image to draw on (modified in place).
        steps: number of grid cells across the image width.
        highlight_center: if True, outline in blue the cell at index ``steps // 2``
            along both axes.
    """
    # See: https://randomgeekery.org/post/2017/11/drawing-grids-with-python-and-pillow/.
    gray = (128, 128, 128)
    blue = (0, 0, 255)
    cell = int(img.width / steps)
    pen = ImageDraw.Draw(img)

    for x in range(0, img.width, cell):
        pen.line(((x, 0), (x, img.height)), fill=gray)

    for y in range(0, img.height, cell):
        pen.line(((0, y), (img.width, y)), fill=gray)

    if highlight_center:
        lo = (steps // 2) * cell
        hi = lo + cell
        for segment in (
            ((lo, lo), (hi, lo)),
            ((lo, hi), (hi, hi)),
            ((lo, lo), (lo, hi)),
            ((hi, lo), (hi, hi)),
        ):
            pen.line(segment, fill=blue)

    del pen
64 |
65 |
def _render_toy_frame(fig, ax, img, half_size, tensors, p_idxs, point_step, path_end):
    """Draw one toy-sample frame on ``ax``, return it as an RGB PIL image, clear ``ax``.

    Each agent is drawn as a marker at time step ``point_step`` plus its trajectory
    ``[:path_end]`` (``path_end=None`` draws the whole trajectory).
    """
    pt_scale = 1.4
    ax.imshow(img, zorder=0, extent=[-half_size, half_size, -half_size, half_size])
    ax.axis("off")
    ax.grid(False)
    for (idx, p_idx) in enumerate(p_idxs):
        # NOTE(review): the +0.5 / -0.5 offsets presumably align positions with the
        # drawn grid cells.
        ax.scatter(
            [tensors["player_xs"][point_step, p_idx] + 0.5],
            [tensors["player_ys"][point_step, p_idx] - 0.5],
            s=(pt_scale * matplotlib.rcParams["lines.markersize"]) ** 2,
            c=colors[idx],
            edgecolors="black",
        )
        # Turn off autoscaling so later artists don't change the view limits.
        ax.set_xlim(auto=False)
        ax.set_ylim(auto=False)
        ax.plot(
            tensors["player_xs"][:path_end, p_idx] + 0.5,
            tensors["player_ys"][:path_end, p_idx] - 0.5,
            c=colors[idx],
        )

    fig.subplots_adjust(0, 0, 1, 1)
    fig.canvas.draw()
    frame = Image.frombytes(
        "RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb()
    )
    ax.cla()
    return frame


def gen_imgs_from_sample(tensors, seq_len):
    """Render a toy-dataset sample as per-step frames plus a summary image.

    Args:
        tensors: sample dict. Reads "player_idxs" (row 0 fixes each agent's color by
            sorted player index) and "player_xs"/"player_ys", indexed as
            [time_step, player].
        seq_len: number of time steps to render.

    Returns:
        (traj_imgs, traj_img):
        - traj_imgs: ``seq_len`` RGB PIL images; frame t shows each agent at step t
          with its path so far.
        - traj_img: one image showing each agent's position at step 0 with its full
          path.
    """
    # Create grid background.
    scale = 4
    img_size = 2 * seq_len + 2
    img_array = np.full((img_size, img_size, 3), 255, dtype=np.uint8)
    img = Image.fromarray(img_array).resize(
        (scale * img_size, scale * img_size), resample=0
    )
    add_grid(img, img_size, False)

    (fig, ax) = plt.subplots(figsize=(5, 5))
    half_size = img_size / 2

    # Always give the agents the same colors.
    (_, p_idxs) = tensors["player_idxs"][0].sort()

    try:
        traj_imgs = [
            _render_toy_frame(fig, ax, img, half_size, tensors, p_idxs, step, step + 1)
            for step in range(seq_len)
        ]
        traj_img = _render_toy_frame(fig, ax, img, half_size, tensors, p_idxs, 0, None)
    finally:
        # Release the figure; pyplot otherwise keeps every figure alive.
        plt.close(fig)

    return (traj_imgs, traj_img)
141 |
142 |
def _save_traj_imgs(traj_imgs, traj_img, stem):
    """Save frames as ``{stem}.gif`` (400 ms per frame, looping) and the summary as ``{stem}.png``."""
    traj_imgs[0].save(
        f"{stem}.gif",
        save_all=True,
        append_images=traj_imgs[1:],
        duration=400,
        loop=0,
    )
    traj_img.save(f"{stem}.png")


def animate_toy_examples():
    """Render the toy ground truth and two trained models' rollouts, then zip them.

    Writes train.gif/png for sample ``ds[0]``. For each hard-coded experiment JOB it
    loads the saved model and samples every agent's next move step by step, one agent
    at a time, re-running the model on the updated sample before each draw. It writes
    ``{JOB}.gif/png``. Everything goes to ``~/results``, which is then archived to
    ``~/results.zip`` and deleted.
    """
    # Initialize dataset.
    ds = ToyDataset()
    n_players = ds.n_players
    player_traj_n = ds.player_traj_n
    seq_len = ds.seq_len

    home_dir = os.path.expanduser("~")
    results_dir = f"{home_dir}/results"
    os.makedirs(results_dir, exist_ok=True)

    tensors = ds[0]
    (traj_imgs, traj_img) = gen_imgs_from_sample(tensors, seq_len)
    _save_traj_imgs(traj_imgs, traj_img, f"{results_dir}/train")

    for JOB in ["20210408161424", "20210408160343"]:
        JOB_DIR = f"{EXPERIMENTS_DIR}/{JOB}"
        with open(f"{JOB_DIR}/{JOB}.yaml") as f:
            opts = yaml.safe_load(f)

        # Initialize model.
        device = torch.device("cuda:0")
        model = init_model(opts, ds).to(device)
        # map_location loads the checkpoint onto the chosen device regardless of
        # where it was saved.
        model.load_state_dict(
            torch.load(f"{JOB_DIR}/best_params.pth", map_location=device)
        )
        model.eval()

        with torch.no_grad():
            tensors = ds[0]
            for step in range(seq_len):
                preds_start = n_players * step
                for player_idx in range(n_players):
                    pred_idx = preds_start + player_idx
                    preds = model(tensors)[pred_idx]
                    probs = torch.softmax(preds, dim=0)
                    samp_traj = torch.multinomial(probs, 1)

                    # The sampled class indexes a player_traj_n x player_traj_n grid
                    # of moves centered on (0, 0).
                    samp_row = samp_traj // player_traj_n
                    samp_col = samp_traj % player_traj_n

                    samp_x = samp_col.item() - 1
                    samp_y = samp_row.item() - 1

                    tensors["player_x_diffs"][step, player_idx] = samp_x
                    tensors["player_y_diffs"][step, player_idx] = samp_y
                    if step < seq_len - 1:
                        tensors["player_xs"][step + 1, player_idx] = (
                            tensors["player_xs"][step, player_idx] + samp_x
                        )
                        tensors["player_ys"][step + 1, player_idx] = (
                            tensors["player_ys"][step, player_idx] + samp_y
                        )

        (traj_imgs, traj_img) = gen_imgs_from_sample(tensors, seq_len)
        _save_traj_imgs(traj_imgs, traj_img, f"{results_dir}/{JOB}")

    shutil.make_archive(results_dir, "zip", results_dir)
    shutil.rmtree(results_dir)
214 |
215 |
def gen_imgs_from_basketball_sample(tensors, seq_len):
    """Render a basketball sample as animation frames plus one summary image.

    Args:
        tensors: dict for one sample. Reads "player_idxs", "player_hoop_sides",
            "player_xs" and "player_ys", each indexed as [time_step, player].
        seq_len: number of time steps to render as animation frames.

    Returns:
        (traj_imgs, traj_img):
        - traj_imgs: a list of ``seq_len`` PIL RGB images. Frame ``t`` shows
          every player at time ``t`` together with the path traveled so far.
        - traj_img: one PIL RGB image showing every player's starting position
          and full path.
    """
    court = plt.imread("court.png")
    width = 5
    height = width * court.shape[0] / float(court.shape[1])
    pt_scale = 1.4
    marker_size = (pt_scale * matplotlib.rcParams["lines.markersize"]) ** 2
    (fig, ax) = plt.subplots(figsize=(width, height))

    # Always give the agents the same colors: colors are assigned in order of
    # player index, not in order of the player's slot in the sample.
    (_, p_idxs) = tensors["player_idxs"][0].sort()
    # Derive the player count from the sample instead of hard-coding 10.
    n_players = len(p_idxs)
    # Squares for players whose hoop side is truthy, triangles otherwise.
    markers = [
        "s" if tensors["player_hoop_sides"][0, p_idx] else "^"
        for p_idx in range(n_players)
    ]

    def render(marker_step, path_stop):
        # Draw the court, each player's marker at `marker_step`, and each
        # player's path up to (not including) `path_stop`. Then capture the
        # canvas as an RGB image.
        ax.imshow(court, zorder=0, extent=[X_MIN, X_MAX - DIFF, Y_MAX, Y_MIN])
        ax.axis("off")
        ax.grid(False)
        for (color_idx, p_idx) in enumerate(p_idxs):
            ax.scatter(
                [tensors["player_xs"][marker_step, p_idx]],
                [tensors["player_ys"][marker_step, p_idx]],
                s=marker_size,
                c=colors[color_idx],
                marker=markers[p_idx],
                edgecolors="black",
            )
            plt.xlim(auto=False)
            plt.ylim(auto=False)
            ax.plot(
                tensors["player_xs"][:path_stop, p_idx],
                tensors["player_ys"][:path_stop, p_idx],
                c=colors[color_idx],
            )

        plt.subplots_adjust(0, 0, 1, 1)
        fig.canvas.draw()
        # FigureCanvas.tostring_rgb() is deprecated (Matplotlib 3.8) and removed
        # in later releases. Read the RGBA buffer instead and drop the alpha channel.
        return Image.fromarray(np.asarray(fig.canvas.buffer_rgba())).convert("RGB")

    traj_imgs = []
    for time_step in range(seq_len):
        traj_imgs.append(render(time_step, time_step + 1))
        plt.cla()

    # Summary image: starting positions plus full paths (a stop of None keeps every step).
    traj_img = render(0, None)
    plt.close(fig)

    return (traj_imgs, traj_img)
290 |
291 |
def gen_tgt_imgs_from_basketball_sample(tensors, seq_len, tgt_idx):
    """Render a basketball sample as frames plus a summary, highlighting one player.

    The player in slot ``tgt_idx`` is drawn in red. Other players are white or
    gray depending on their hoop side.

    Args:
        tensors: dict for one sample. Reads "player_hoop_sides", "player_xs"
            and "player_ys", each indexed as [time_step, player].
        seq_len: number of time steps to render as animation frames.
        tgt_idx: player slot to highlight.

    Returns:
        (traj_imgs, traj_img):
        - traj_imgs: a list of ``seq_len`` PIL RGB images. Frame ``t`` shows
          every player at time ``t`` with the path traveled so far.
        - traj_img: one PIL RGB image showing starting positions and full paths.
    """
    court = plt.imread("court.png")
    width = 5
    height = width * court.shape[0] / float(court.shape[1])
    pt_scale = 1.4
    marker_size = (pt_scale * matplotlib.rcParams["lines.markersize"]) ** 2
    (fig, ax) = plt.subplots(figsize=(width, height))

    # Derive the player count from the sample instead of hard-coding 10.
    n_players = tensors["player_xs"].shape[1]

    # Always give the agents the same colors.
    colors = []
    markers = []
    for p_idx in range(n_players):
        hoop_side = tensors["player_hoop_sides"][0, p_idx]
        if p_idx == tgt_idx:
            colors.append("r")
        elif hoop_side:
            colors.append("white")
        else:
            colors.append("gray")

        markers.append("s" if hoop_side else "^")

    def render(marker_step, path_stop):
        # Draw the court, each player's marker at `marker_step`, and each
        # player's path up to (not including) `path_stop`. Then capture the
        # canvas as an RGB image.
        ax.imshow(court, zorder=0, extent=[X_MIN, X_MAX - DIFF, Y_MAX, Y_MIN])
        ax.axis("off")
        ax.grid(False)
        for p_idx in range(n_players):
            ax.scatter(
                [tensors["player_xs"][marker_step, p_idx]],
                [tensors["player_ys"][marker_step, p_idx]],
                s=marker_size,
                c=colors[p_idx],
                marker=markers[p_idx],
                edgecolors="black",
            )
            plt.xlim(auto=False)
            plt.ylim(auto=False)
            ax.plot(
                tensors["player_xs"][:path_stop, p_idx],
                tensors["player_ys"][:path_stop, p_idx],
                c=colors[p_idx],
            )

        plt.subplots_adjust(0, 0, 1, 1)
        fig.canvas.draw()
        # FigureCanvas.tostring_rgb() is deprecated (Matplotlib 3.8) and removed
        # in later releases. Read the RGBA buffer instead and drop the alpha channel.
        return Image.fromarray(np.asarray(fig.canvas.buffer_rgba())).convert("RGB")

    traj_imgs = []
    for time_step in range(seq_len):
        traj_imgs.append(render(time_step, time_step + 1))
        plt.cla()

    # Summary image: starting positions plus full paths (a stop of None keeps every step).
    traj_img = render(0, None)
    plt.close(fig)

    return (traj_imgs, traj_img)
373 |
374 |
def plot_generated_trajectories():
    """Save ground-truth and model-generated trajectories for one test sequence.

    For each model, the checkpoint is loaded. Then ``samples`` full sequences
    are generated for test sample ``test_idx`` by sampling every player's next
    move one prediction at a time. Ground truth and generations are saved as
    GIF/PNG files, zipped into ``~/results.zip``, and the temporary
    ``~/results`` directory is removed.
    """
    models = {"baller2vec": "20210402111942", "baller2vec++": "20210402111956"}
    samples = 3
    test_idx = 97
    device = torch.device("cuda:0")
    home_dir = os.path.expanduser("~")
    os.makedirs(f"{home_dir}/results", exist_ok=True)
    for (which_model, JOB) in models.items():
        JOB_DIR = f"{EXPERIMENTS_DIR}/{JOB}"

        # Load model.
        with open(f"{JOB_DIR}/{JOB}.yaml") as f:
            opts = yaml.safe_load(f)
        (train_dataset, _, _, _, test_dataset, _) = init_basketball_datasets(opts)
        n_players = train_dataset.n_players
        player_traj_n = test_dataset.player_traj_n
        model = init_model(opts, train_dataset).to(device)
        # Partial load: overwrite the model's own state with every matching
        # checkpoint entry, then load the merged dict. Keys missing from the
        # checkpoint keep their initial values instead of failing strict loading.
        model_dict = model.state_dict()
        pretrained_dict = torch.load(f"{JOB_DIR}/best_params.pth")
        model_dict.update(
            {k: v for (k, v) in pretrained_dict.items() if k in model_dict}
        )
        model.load_state_dict(model_dict)
        model.eval()
        seq_len = model.seq_len
        # Half the width of the last trajectory bin. It is used below to jitter a
        # sampled move uniformly around (bin value - grid_gap).
        grid_gap = np.diff(test_dataset.player_traj_bins)[-1] / 2
        # NOTE(review): 5.5 appears to close off the final bin so every label
        # indexes a bin value; confirm against Baller2VecPlusPlusDataset.
        player_traj_bins = np.array(list(test_dataset.player_traj_bins) + [5.5])

        torch.manual_seed(2010)
        np.random.seed(2010)

        tensors = test_dataset[test_idx]
        (traj_imgs, traj_img) = gen_imgs_from_basketball_sample(tensors, seq_len)
        traj_imgs[0].save(
            f"{home_dir}/results/{test_idx}_truth.gif",
            save_all=True,
            append_images=traj_imgs[1:],
            duration=400,
            loop=0,
        )
        traj_img.save(
            f"{home_dir}/results/{test_idx}_truth.png",
        )
        for sample in range(samples):
            tensors = test_dataset[test_idx]
            with torch.no_grad():
                for step in range(seq_len):
                    # Predictions are laid out time step by time step, n_players per step.
                    preds_start = n_players * step
                    for player_idx in range(n_players):
                        pred_idx = preds_start + player_idx
                        preds = model(tensors)[pred_idx]
                        probs = torch.softmax(preds, dim=0)
                        # Use a plain int so the NumPy lookups below never see a
                        # (possibly CUDA) tensor index.
                        samp_traj = torch.multinomial(probs, 1).item()

                        samp_row = samp_traj // player_traj_n
                        samp_col = samp_traj % player_traj_n

                        samp_x = (
                            player_traj_bins[samp_col]
                            - grid_gap
                            + np.random.uniform(-grid_gap, grid_gap)
                        )
                        samp_y = (
                            player_traj_bins[samp_row]
                            - grid_gap
                            + np.random.uniform(-grid_gap, grid_gap)
                        )

                        tensors["player_x_diffs"][step, player_idx] = samp_x
                        tensors["player_y_diffs"][step, player_idx] = samp_y
                        if step < seq_len - 1:
                            tensors["player_xs"][step + 1, player_idx] = (
                                tensors["player_xs"][step, player_idx] + samp_x
                            )
                            tensors["player_ys"][step + 1, player_idx] = (
                                tensors["player_ys"][step, player_idx] + samp_y
                            )

            (traj_imgs, traj_img) = gen_imgs_from_basketball_sample(tensors, seq_len)
            traj_imgs[0].save(
                f"{home_dir}/results/{test_idx}_gen_{which_model}_{sample}.gif",
                save_all=True,
                append_images=traj_imgs[1:],
                duration=400,
                loop=0,
            )
            traj_img.save(
                f"{home_dir}/results/{test_idx}_gen_{which_model}_{sample}.png",
            )

    shutil.make_archive(f"{home_dir}/results", "zip", f"{home_dir}/results")
    shutil.rmtree(f"{home_dir}/results")
475 |
476 |
def compare_single_player_generation():
    """Generate one player's trajectory while every other player keeps the ground truth.

    For each model and each player slot ``p_idx`` of test sample ``test_idx``:
    - The ground truth is saved with that player highlighted.
    - The player is swapped into the last slot.
    - ``samples`` trajectories are generated for that slot only.
    All images are zipped into ``~/results.zip`` and the temporary
    ``~/results`` directory is removed.
    """
    models = {"baller2vec": "20210402111942", "baller2vec++": "20210402111956"}
    samples = 10
    test_idx = 267
    device = torch.device("cuda:0")
    home_dir = os.path.expanduser("~")
    os.makedirs(f"{home_dir}/results", exist_ok=True)
    for (which_model, JOB) in models.items():
        print(which_model)
        JOB_DIR = f"{EXPERIMENTS_DIR}/{JOB}"

        # Load model.
        with open(f"{JOB_DIR}/{JOB}.yaml") as f:
            opts = yaml.safe_load(f)
        (train_dataset, _, _, _, test_dataset, _) = init_basketball_datasets(opts)
        n_players = train_dataset.n_players
        player_traj_n = test_dataset.player_traj_n
        model = init_model(opts, train_dataset).to(device)
        # Partial load: overwrite the model's own state with every matching
        # checkpoint entry, then load the merged dict. Keys missing from the
        # checkpoint keep their initial values instead of failing strict loading.
        model_dict = model.state_dict()
        pretrained_dict = torch.load(f"{JOB_DIR}/best_params.pth")
        model_dict.update(
            {k: v for (k, v) in pretrained_dict.items() if k in model_dict}
        )
        model.load_state_dict(model_dict)
        model.eval()
        seq_len = model.seq_len
        # Half the width of the last trajectory bin, used to jitter sampled moves.
        grid_gap = np.diff(test_dataset.player_traj_bins)[-1] / 2
        # NOTE(review): 5.5 appears to close off the final bin; confirm against the dataset.
        player_traj_bins = np.array(list(test_dataset.player_traj_bins) + [5.5])
        # The generated player always occupies the last slot.
        last_idx = n_players - 1

        torch.manual_seed(2010)
        np.random.seed(2010)

        for p_idx in range(n_players):
            print(p_idx, flush=True)
            swap_idxs = [p_idx, last_idx]
            tensors = test_dataset[test_idx]

            (traj_imgs, traj_img) = gen_tgt_imgs_from_basketball_sample(
                tensors, seq_len, p_idx
            )
            traj_imgs[0].save(
                f"{home_dir}/results/{test_idx}_{p_idx}_truth.gif",
                save_all=True,
                append_images=traj_imgs[1:],
                duration=400,
                loop=0,
            )
            traj_img.save(
                f"{home_dir}/results/{test_idx}_{p_idx}_truth.png",
            )
            for sample in range(samples):
                tensors = test_dataset[test_idx]
                # Swap player p_idx into the last slot in every per-player tensor.
                # The right-hand side is an advanced-indexing copy, so this is a true swap.
                for key in shuffle_keys:
                    tensors[key][:, swap_idxs] = tensors[key][:, swap_idxs[::-1]]

                with torch.no_grad():
                    for step in range(seq_len):
                        pred_idx = n_players * step + last_idx
                        preds = model(tensors)[pred_idx]
                        probs = torch.softmax(preds, dim=0)
                        samp_traj = torch.multinomial(probs, 1).item()

                        samp_row = samp_traj // player_traj_n
                        samp_col = samp_traj % player_traj_n

                        samp_x = (
                            player_traj_bins[samp_col]
                            - grid_gap
                            + np.random.uniform(-grid_gap, grid_gap)
                        )
                        samp_y = (
                            player_traj_bins[samp_row]
                            - grid_gap
                            + np.random.uniform(-grid_gap, grid_gap)
                        )

                        tensors["player_x_diffs"][step, last_idx] = samp_x
                        tensors["player_y_diffs"][step, last_idx] = samp_y
                        if step < seq_len - 1:
                            tensors["player_xs"][step + 1, last_idx] = (
                                tensors["player_xs"][step, last_idx] + samp_x
                            )
                            tensors["player_ys"][step + 1, last_idx] = (
                                tensors["player_ys"][step, last_idx] + samp_y
                            )

                (traj_imgs, traj_img) = gen_tgt_imgs_from_basketball_sample(
                    tensors, seq_len, last_idx
                )
                traj_imgs[0].save(
                    f"{home_dir}/results/{test_idx}_{p_idx}_gen_{which_model}_{sample}.gif",
                    save_all=True,
                    append_images=traj_imgs[1:],
                    duration=400,
                    loop=0,
                )
                traj_img.save(
                    f"{home_dir}/results/{test_idx}_{p_idx}_gen_{which_model}_{sample}.png",
                )

    shutil.make_archive(f"{home_dir}/results", "zip", f"{home_dir}/results")
    shutil.rmtree(f"{home_dir}/results")
588 |
589 |
def compare_time_steps():
    """Compare the per-time-step test NLL of baller2vec and baller2vec++.

    For each model, the cross-entropy loss is averaged over all full-length test
    sequences, both overall and separately for each time step. Prints each
    model's per-step and overall averages, then the relative per-step
    improvement of baller2vec++ over baller2vec.
    """
    models = {"baller2vec": "20210402111942", "baller2vec++": "20210402111956"}
    results = {}
    device = torch.device("cuda:0")
    criterion = nn.CrossEntropyLoss()
    for (which_model, JOB) in models.items():
        JOB_DIR = f"{EXPERIMENTS_DIR}/{JOB}"

        # Load model.
        with open(f"{JOB_DIR}/{JOB}.yaml") as f:
            opts = yaml.safe_load(f)
        (train_dataset, _, _, _, _, test_loader) = init_basketball_datasets(opts)
        n_players = train_dataset.n_players
        model = init_model(opts, train_dataset).to(device)
        # Partial load. See:
        # https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/3.
        # Merge matching checkpoint entries into the model's own state and load
        # the merged dict, so keys missing from the checkpoint keep their initial values.
        model_dict = model.state_dict()
        pretrained_dict = torch.load(f"{JOB_DIR}/best_params.pth")
        model_dict.update(
            {k: v for (k, v) in pretrained_dict.items() if k in model_dict}
        )
        model.load_state_dict(model_dict)
        model.eval()
        seq_len = model.seq_len

        test_loss_steps = np.zeros(seq_len)
        test_loss_all = 0.0
        n_test = 0
        with torch.no_grad():
            for test_tensors in test_loader:
                # Skip bad sequences.
                if len(test_tensors["player_idxs"]) < seq_len:
                    continue

                labels = test_tensors["player_trajs"].flatten().to(device)
                preds = model(test_tensors)

                test_loss_all += criterion(preds, labels).item()
                # Predictions are laid out time step by time step, n_players per step.
                for time_step in range(seq_len):
                    start = time_step * n_players
                    stop = start + n_players
                    test_loss_steps[time_step] += criterion(
                        preds[start:stop], labels[start:stop]
                    ).item()

                n_test += 1

        print(test_loss_steps / n_test)
        print(test_loss_all / n_test)
        results[which_model] = test_loss_steps / n_test

    print((results["baller2vec"] - results["baller2vec++"]) / results["baller2vec"])
645 |
646 |
def test_permutation_invariance():
    """Measure how much baller2vec++'s test NLL changes when players are reordered.

    For each full-length test sequence, the NLL with the original player order
    (the "anchor") is compared with the NLL after each of ``shuffles`` random
    reorderings of the player axis. The function then:
    - prints the mean relative absolute difference,
    - writes every (anchor, shuffled) pair to ``~/results.csv``,
    - prints the mean absolute difference over all pairs.
    """
    JOB = "20210402111956"
    JOB_DIR = f"{EXPERIMENTS_DIR}/{JOB}"

    device = torch.device("cuda:0")

    # Load model.
    with open(f"{JOB_DIR}/{JOB}.yaml") as f:
        opts = yaml.safe_load(f)
    (train_dataset, _, _, _, _, test_loader) = init_basketball_datasets(opts)
    n_players = train_dataset.n_players
    model = init_model(opts, train_dataset).to(device)
    # Partial load. See:
    # https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/3.
    # Merge matching checkpoint entries into the model's own state and load the
    # merged dict, so keys missing from the checkpoint keep their initial values.
    model_dict = model.state_dict()
    pretrained_dict = torch.load(f"{JOB_DIR}/best_params.pth")
    model_dict.update({k: v for (k, v) in pretrained_dict.items() if k in model_dict})
    model.load_state_dict(model_dict)
    model.eval()
    seq_len = model.seq_len

    criterion = nn.CrossEntropyLoss()

    shuffles = 10
    np.random.seed(2010)
    data = []
    avg_diffs = []
    with torch.no_grad():
        for test_tensors in test_loader:
            # Skip bad sequences.
            if len(test_tensors["player_idxs"]) < seq_len:
                continue

            labels = test_tensors["player_trajs"].flatten().to(device)
            test_loss = criterion(model(test_tensors), labels).item()
            avg_diff = 0

            for shuffle in range(shuffles):
                # Each permutation is applied on top of the previous one. The
                # resulting order is still a uniformly random permutation of the
                # original order.
                shuffled_players = np.random.choice(
                    np.arange(n_players), n_players, False
                )
                for key in shuffle_keys:
                    test_tensors[key] = test_tensors[key][:, shuffled_players]

                labels = test_tensors["player_trajs"].flatten().to(device)
                shuffle_loss = criterion(model(test_tensors), labels).item()
                data.append({"anchor": test_loss, "shuffled": shuffle_loss})
                avg_diff += np.abs(test_loss - shuffle_loss)

            avg_diff /= shuffles
            avg_diffs.append(avg_diff / test_loss)

    print(np.mean(avg_diffs))
    df = pd.DataFrame(data)
    home_dir = os.path.expanduser("~")
    df.to_csv(f"{home_dir}/results.csv")
    print(np.abs(df["anchor"] - df["shuffled"]).mean())
706 |
707 |
def compare_player_beginning_end():
    """Compare baller2vec++'s NLL for a player predicted first vs. predicted last.

    For each full-length test sequence and each player:
    - The player is rolled into slot 0, and the NLL of slot 0 at the first time
      step is the reference.
    - The player is then moved to the last slot, with the other players randomly
      permuted, ``shuffles`` times. The NLL of the last slot at the first time
      step is compared with the reference.

    Prints the mean relative difference, (first - last) / first.
    """
    JOB = "20210402111956"
    JOB_DIR = f"{EXPERIMENTS_DIR}/{JOB}"

    device = torch.device("cuda:0")

    # Load model.
    with open(f"{JOB_DIR}/{JOB}.yaml") as f:
        opts = yaml.safe_load(f)
    (train_dataset, _, _, _, _, test_loader) = init_basketball_datasets(opts)
    n_players = train_dataset.n_players
    model = init_model(opts, train_dataset).to(device)
    # Partial load. See:
    # https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/3.
    # Merge matching checkpoint entries into the model's own state and load the
    # merged dict, so keys missing from the checkpoint keep their initial values.
    model_dict = model.state_dict()
    pretrained_dict = torch.load(f"{JOB_DIR}/best_params.pth")
    model_dict.update({k: v for (k, v) in pretrained_dict.items() if k in model_dict})
    model.load_state_dict(model_dict)
    model.eval()
    seq_len = model.seq_len

    criterion = nn.CrossEntropyLoss()

    shuffles = 10
    shuffle_keys = [
        "player_idxs",
        "player_xs",
        "player_ys",
        "player_hoop_sides",
        "player_x_diffs",
        "player_y_diffs",
        "player_trajs",
    ]
    # Predictions are laid out time step by time step, so at the first time step
    # index 0 is the first slot and index last_idx is the last slot.
    last_idx = n_players - 1
    np.random.seed(2010)
    avg_diffs = []
    with torch.no_grad():
        for (test_idx, test_tensors) in enumerate(test_loader):
            print(test_idx)
            # Skip bad sequences.
            if len(test_tensors["player_idxs"]) < seq_len:
                continue

            for roll in range(n_players):
                # Roll from the ORIGINAL ordering so every player occupies slot 0
                # exactly once. Rolling the already-rolled tensors would
                # accumulate shifts (0, 1, 3, 6, ...), repeating some players and
                # skipping others.
                rolled_tensors = dict(test_tensors)
                for key in shuffle_keys:
                    rolled_tensors[key] = test_tensors[key].roll(roll, 1)

                labels = rolled_tensors["player_trajs"].flatten().to(device)
                preds = model(rolled_tensors)
                test_loss = criterion(preds[:1], labels[:1]).item()

                avg_diff = 0
                for shuffle in range(shuffles):
                    # Randomly permute slots 1..n_players-1 and move slot 0 to the end.
                    shuffled_players = list(
                        np.random.choice(np.arange(1, n_players), n_players - 1, False)
                    ) + [0]
                    shuffled_tensors = {
                        key: rolled_tensors[key][:, shuffled_players]
                        for key in shuffle_keys
                    }

                    labels = shuffled_tensors["player_trajs"].flatten().to(device)
                    preds = model(shuffled_tensors)
                    shuffle_loss = criterion(
                        preds[last_idx:n_players], labels[last_idx:n_players]
                    ).item()
                    avg_diff += test_loss - shuffle_loss

                avg_diff /= shuffles
                avg_diffs.append(avg_diff / test_loss)

    print(np.mean(avg_diffs))
781 |
782 |
def get_diff_for_each_seq():
    """Save a summary image of every test sequence, named by the models' relative NLL gap.

    Both models are loaded. For each full-length test sequence, the summary
    image is saved as ``{pct:03}_{test_idx}.png``, where ``pct`` is
    ``int(100 * (baller2vec NLL - baller2vec++ NLL) / baller2vec NLL)``. The
    images are zipped into ``~/results.zip`` and the temporary ``~/results``
    directory is removed.
    """
    device = torch.device("cuda:0")
    models = {"baller2vec": "20210402111942", "baller2vec++": "20210402111956"}
    # Loaded models go in a separate dict rather than replacing values of
    # `models` while iterating over it.
    loaded_models = {}
    for (which_model, JOB) in models.items():
        JOB_DIR = f"{EXPERIMENTS_DIR}/{JOB}"

        # Load model.
        with open(f"{JOB_DIR}/{JOB}.yaml") as f:
            opts = yaml.safe_load(f)
        # The test loader and seq_len of the last model loaded are used below.
        (train_dataset, _, _, _, _, test_loader) = init_basketball_datasets(opts)
        model = init_model(opts, train_dataset).to(device)
        # Partial load: merge matching checkpoint entries into the model's own
        # state and load the merged dict, so keys missing from the checkpoint
        # keep their initial values.
        model_dict = model.state_dict()
        pretrained_dict = torch.load(f"{JOB_DIR}/best_params.pth")
        model_dict.update(
            {k: v for (k, v) in pretrained_dict.items() if k in model_dict}
        )
        model.load_state_dict(model_dict)
        model.eval()
        seq_len = model.seq_len
        loaded_models[which_model] = model

    home_dir = os.path.expanduser("~")
    os.makedirs(f"{home_dir}/results", exist_ok=True)

    criterion = nn.CrossEntropyLoss()

    with torch.no_grad():
        for (test_idx, test_tensors) in enumerate(test_loader):
            print(test_idx, flush=True)

            # Skip bad sequences.
            if len(test_tensors["player_idxs"]) < seq_len:
                continue

            labels = test_tensors["player_trajs"].flatten().to(device)
            losses = {
                which_model: criterion(model(test_tensors), labels).item()
                for (which_model, model) in loaded_models.items()
            }

            b2v_loss = losses["baller2vec"]
            b2vpp_loss = losses["baller2vec++"]
            loss_diff = str(int(100 * (b2v_loss - b2vpp_loss) / b2v_loss))

            (traj_imgs, traj_img) = gen_imgs_from_basketball_sample(
                test_tensors, seq_len
            )
            traj_img.save(
                f"{home_dir}/results/{loss_diff.zfill(3)}_{test_idx}.png",
            )

    shutil.make_archive(f"{home_dir}/results", "zip", f"{home_dir}/results")
    shutil.rmtree(f"{home_dir}/results")
837 |
--------------------------------------------------------------------------------
/paper_plots.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import pandas as pd
5 |
6 | from scipy.stats import pearsonr
7 |
# Enlarge the default font for every figure drawn by this module.
font = dict(size=22)
matplotlib.rcParams.update({f"font.{prop}": value for (prop, value) in font.items()})
11 |
12 |
def plot_nll_over_time():
    """Show a grouped bar chart of the average NLL at each time step for both models.

    The first bar of each pair is baller2vec and the second is baller2vec++.
    The values are hard-coded.
    NOTE(review): they presumably come from the per-step output of
    compare_time_steps in paper_functions.py; confirm before updating.
    """
    # Bar layout adapted from: https://stackoverflow.com/a/59421062/1316276.
    baller2vec_nlls = [
        1.86897819,
        0.58918497,
        0.49199169,
        0.47602542,
        0.45367828,
        0.43433007,
        0.4377114,
        0.43828331,
        0.41888807,
        0.42203086,
        0.42575513,
        0.41059803,
        0.42874022,
        0.41963211,
        0.42825773,
        0.42731653,
        0.44892668,
        0.44514571,
        0.43499798,
        0.46455422,
    ]
    baller2vecpp_nlls = [
        1.56743299,
        0.54410327,
        0.46949427,
        0.45175594,
        0.42762741,
        0.41282431,
        0.41227819,
        0.41077594,
        0.38996656,
        0.39531277,
        0.40007369,
        0.37965269,
        0.39489294,
        0.3802205,
        0.39562016,
        0.39000066,
        0.40920398,
        0.39750678,
        0.39637249,
        0.41968062,
    ]
    # Derive the number of time steps from the data so editing the values
    # cannot silently desynchronize them from the axis.
    n_steps = len(baller2vec_nlls)

    # Left edge of each pair of bars on the x-axis.
    ind = np.arange(n_steps)
    width = 0.3

    plt.figure(figsize=(10, 5))
    plt.bar(ind, baller2vec_nlls, width, label="baller2vec")
    plt.bar(ind + width, baller2vecpp_nlls, width, label="baller2vec++")

    plt.xlabel("Time step")
    plt.ylabel("Average NLL")

    # Center each tick between its pair of bars and number the time steps from 1.
    plt.xticks(ind + width / 2, list(range(1, n_steps + 1)))

    plt.legend(loc="best", prop={"family": "monospace"})
    plt.show()
91 |
92 |
def plot_permutations(csv_path="test.csv"):
    """Scatter unshuffled vs. shuffled average NLL and print their Pearson correlation.

    Args:
        csv_path: CSV file with "anchor" (unshuffled NLL) and "shuffled"
            columns. paper_functions.test_permutation_invariance writes such a
            file to ``~/results.csv``. Defaults to "test.csv" for backward
            compatibility.
    """
    df = pd.read_csv(csv_path)
    print(pearsonr(df["anchor"], df["shuffled"]))
    plt.scatter(df["anchor"], df["shuffled"])
    plt.xlabel("Unshuffled average NLL")
    plt.ylabel("Shuffled average NLL")
    plt.show()
100 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib
2 | py7zr
3 | torch
4 | scipy
5 | numpy
6 | pandas
7 | Pillow
8 | PyYAML
9 |
--------------------------------------------------------------------------------
/settings.py:
--------------------------------------------------------------------------------
1 | import os
2 |
# Plotting and court-drawing constants.
NORMALIZATION_COEF = 7
PLAYER_CIRCLE_SIZE = 12 / NORMALIZATION_COEF
INTERVAL = 10
DIFF = 6
(X_MIN, X_MAX) = (0, 100)
(Y_MIN, Y_MAX) = (0, 50)
COL_WIDTH = 0.3
SCALE = 1.65
FONTSIZE = 6
# Center point derived from the plot bounds and DIFF.
X_CENTER = X_MAX / 2 - DIFF / 1.5 + 0.10
Y_CENTER = Y_MAX - DIFF / 1.5 - 0.35
BALL_COLOR = "#ff8c00"
(COURT_WIDTH, COURT_LENGTH) = (50, 94)

# NBA team ID -> {"color": hex color, "abbreviation": three-letter code}.
TEAM_ID2PROPS = {
    team_id: {"color": color, "abbreviation": abbreviation}
    for (team_id, abbreviation, color) in (
        (1610612737, "ATL", "#E13A3E"),
        (1610612738, "BOS", "#008348"),
        (1610612751, "BKN", "#061922"),
        (1610612766, "CHA", "#1D1160"),
        (1610612741, "CHI", "#CE1141"),
        (1610612739, "CLE", "#860038"),
        (1610612742, "DAL", "#007DC5"),
        (1610612743, "DEN", "#4D90CD"),
        (1610612765, "DET", "#006BB6"),
        (1610612744, "GSW", "#FDB927"),
        (1610612745, "HOU", "#CE1141"),
        (1610612754, "IND", "#00275D"),
        (1610612746, "LAC", "#ED174C"),
        (1610612747, "LAL", "#552582"),
        (1610612763, "MEM", "#0F586C"),
        (1610612748, "MIA", "#98002E"),
        (1610612749, "MIL", "#00471B"),
        (1610612750, "MIN", "#005083"),
        (1610612740, "NOP", "#002B5C"),
        (1610612752, "NYK", "#006BB6"),
        (1610612760, "OKC", "#007DC3"),
        (1610612753, "ORL", "#007DC5"),
        (1610612755, "PHI", "#006BB6"),
        (1610612756, "PHX", "#1D1160"),
        (1610612757, "POR", "#E03A3E"),
        (1610612758, "SAC", "#724C9F"),
        (1610612759, "SAS", "#BAC3C9"),
        (1610612761, "TOR", "#CE1141"),
        (1610612762, "UTA", "#00471B"),
        (1610612764, "WAS", "#002B5C"),
    )
}
def _require_env(name):
    """Return environment variable ``name``, or fail with a hint on how to set it.

    Raises:
        KeyError: if ``name`` is not set. This is the same exception type as
            ``os.environ[name]``, but with an actionable message.
    """
    try:
        return os.environ[name]
    except KeyError:
        raise KeyError(
            f"Environment variable {name} is not set; "
            "source .basketball_profile before running."
        ) from None


# Data locations, all defined in .basketball_profile.
EVENTS_DIR = _require_env("EVENTS_DIR")
TRACKING_DIR = _require_env("TRACKING_DIR")
GAMES_DIR = _require_env("GAMES_DIR")
DATA_DIR = _require_env("DATA_DIR")
EXPERIMENTS_DIR = _require_env("EXPERIMENTS_DIR")
55 |
--------------------------------------------------------------------------------
/test_gameids.txt:
--------------------------------------------------------------------------------
1 | 0021500637
2 | 0021500479
3 | 0021500167
4 | 0021500149
5 | 0021500456
6 | 0021500425
7 | 0021500460
8 | 0021500160
9 | 0021500399
10 | 0021500145
11 | 0021500626
12 | 0021500267
13 | 0021500340
14 | 0021500268
15 | 0021500278
16 | 0021500147
17 | 0021500163
18 | 0021500593
19 | 0021500084
20 | 0021500403
21 | 0021500061
22 | 0021500434
23 | 0021500567
24 | 0021500359
25 | 0021500079
26 | 0021500481
27 | 0021500427
28 | 0021500243
29 | 0021500245
30 | 0021500622
31 | 0021500330
32 | 0021500522
33 |
--------------------------------------------------------------------------------
/toy_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from torch.utils.data import Dataset
5 |
6 |
class ToyDataset(Dataset):
    """Synthetic two-agent dataset in which both agents move in lockstep.

    Each sample has ``seq_len`` time steps. At every step a single move
    (dx, dy), with each component in {-1, 0, 1}, is drawn and applied to both
    agents, so their trajectories are identical up to their starting points.
    The agents start at (-1, 0) and (1, 0). Which player identity starts on
    which side is random.
    """

    def __init__(self):
        self.seq_len = 20
        self.n_players = 2
        self.n_player_ids = 2
        # Column j holds player id j at every time step: shape (seq_len, 2).
        self.player_idxs = np.array(self.seq_len * [[0, 1]])
        self.player_hoop_sides = np.zeros((self.seq_len, 2))
        # Moves are labeled on a 3 x 3 grid: label = (1 + dy) * 3 + (1 + dx).
        self.player_traj_n = 3

    def __len__(self):
        return 500

    def get_sample(self):
        steps = self.seq_len
        # One shared move per time step.
        dx = np.random.randint(-1, 2, steps)
        dy = np.random.randint(-1, 2, steps)

        # The position at step t is the start plus every move made before t.
        xs = np.zeros((steps, self.n_players))
        xs[0] = [-1, 1]
        xs[1:] = xs[0] + np.cumsum(dx)[: steps - 1, None]
        ys = np.zeros((steps, self.n_players))
        ys[1:] = np.cumsum(dy)[: steps - 1, None]

        # Randomly decide which identity starts on the left.
        order = np.random.choice(np.arange(2), self.n_players, replace=False)
        idxs = self.player_idxs[:, order].astype(int)

        # Both agents share the same move, hence the same label.
        labels = (1 + dy) * self.player_traj_n + (1 + dx)

        return {
            "player_idxs": torch.LongTensor(idxs),
            "player_xs": torch.Tensor(xs),
            "player_ys": torch.Tensor(ys),
            "player_x_diffs": torch.Tensor(np.column_stack([dx, dx])),
            "player_y_diffs": torch.Tensor(np.column_stack([dy, dy])),
            "player_hoop_sides": torch.Tensor(self.player_hoop_sides),
            "player_trajs": torch.LongTensor(np.column_stack([labels, labels])),
        }

    def __getitem__(self, idx):
        return self.get_sample()
63 |
--------------------------------------------------------------------------------
/train_baller2vecplusplus.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | import sys
4 | import time
5 | import torch
6 | import yaml
7 |
8 | from baller2vecplusplus import Baller2VecPlusPlus
9 | from baller2vecplusplus_dataset import Baller2VecPlusPlusDataset
10 | from settings import *
11 | from torch import nn, optim
12 | from torch.utils.data import DataLoader
13 | from toy_dataset import ToyDataset
14 |
SEED = 2010
# Seed NumPy and PyTorch with the same value so runs are reproducible.
np.random.seed(SEED)
torch.manual_seed(SEED)
# Print wide tensors on a single line.
torch.set_printoptions(linewidth=160)
19 |
20 |
def worker_init_fn(worker_id):
    """Seed NumPy in a DataLoader worker from that worker's PyTorch seed.

    Without this, workers can produce identical NumPy random numbers. See:
    https://pytorch.org/docs/stable/notes/faq.html#my-data-loader-workers-return-identical-random-numbers
    and https://pytorch.org/docs/stable/data.html#randomness-in-multi-process-data-loading.

    ``worker_id`` is unused. The seed comes from
    ``torch.utils.data.get_worker_info()``, so this must run inside a worker.
    """
    info = torch.utils.data.get_worker_info()
    # NumPy seeds must fit in a 32-bit unsigned integer.
    numpy_seed = int(info.seed) % (2 ** 32 - 1)
    np.random.seed(numpy_seed)
27 |
28 |
def get_train_valid_test_gameids():
    """Read the train/valid/test game-ID splits from the working directory.

    Returns:
        (train_gameids, valid_gameids, test_gameids): lists of whitespace-separated
        IDs read from "train_gameids.txt", "valid_gameids.txt" and "test_gameids.txt".
    """
    splits = []
    for split_name in ("train", "valid", "test"):
        with open(f"{split_name}_gameids.txt") as f:
            splits.append(f.read().split())

    return tuple(splits)
40 |
41 |
def _get_eval_gameids_and_starts(gameids, n_samples, chunk_size):
    """Return (repeated gameids, start indices) for fixed evaluation chunks.

    Each game contributes ``ceil(n_samples / len(gameids))`` chunks. Their start
    indices are ``gaps * k`` with ``gaps = (len(y) - chunk_size) // samps_per_gameid``,
    where ``y`` is loaded from ``{GAMES_DIR}/{gameid}_y.npy``.
    """
    samps_per_gameid = int(np.ceil(n_samples / len(gameids)))
    starts = []
    for gameid in gameids:
        y = np.load(f"{GAMES_DIR}/{gameid}_y.npy")
        max_start = len(y) - chunk_size
        gaps = max_start // samps_per_gameid
        starts.append(gaps * np.arange(samps_per_gameid))

    return (np.repeat(gameids, samps_per_gameid), np.concatenate(starts))


def init_basketball_datasets(opts):
    """Build the train, validation, and test datasets and their loaders.

    Args:
        opts: experiment options. Reads ``opts["dataset"]`` (keyword arguments
            for Baller2VecPlusPlusDataset) and ``opts["train"]``
            ("train_samples_per_epoch", "valid_samples", "workers"). Note that
            ``opts["dataset"]`` is updated in place with the per-split settings.

    Returns:
        (train_dataset, train_loader, valid_dataset, valid_loader, test_dataset,
        test_loader). Every loader uses ``batch_size=None``.
    """
    # NOTE: pickle can execute arbitrary code; only load trusted config files.
    with open(f"{DATA_DIR}/baller2vec_config.pydict", "rb") as f:
        baller2vec_config = pickle.load(f)
    n_player_ids = len(baller2vec_config["player_idx2props"])

    (train_gameids, valid_gameids, test_gameids) = get_train_valid_test_gameids()
    workers = opts["train"]["workers"]

    dataset_config = opts["dataset"]
    dataset_config["gameids"] = train_gameids
    dataset_config["N"] = opts["train"]["train_samples_per_epoch"]
    dataset_config["starts"] = []
    dataset_config["mode"] = "train"
    dataset_config["n_player_ids"] = n_player_ids
    train_dataset = Baller2VecPlusPlusDataset(**dataset_config)
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=None,
        num_workers=workers,
        worker_init_fn=worker_init_fn,
    )

    # Validation and test both use deterministic, evenly spaced chunks, sized
    # from the same "valid_samples" option.
    N = opts["train"]["valid_samples"]
    eval_splits = []
    for (mode, gameids) in (("valid", valid_gameids), ("test", test_gameids)):
        (split_gameids, split_starts) = _get_eval_gameids_and_starts(
            gameids, N, train_dataset.chunk_size
        )
        dataset_config["gameids"] = split_gameids
        dataset_config["N"] = len(split_gameids)
        dataset_config["starts"] = split_starts
        dataset_config["mode"] = mode
        split_dataset = Baller2VecPlusPlusDataset(**dataset_config)
        split_loader = DataLoader(
            dataset=split_dataset,
            batch_size=None,
            num_workers=workers,
        )
        eval_splits.append((split_dataset, split_loader))

    ((valid_dataset, valid_loader), (test_dataset, test_loader)) = eval_splits

    return (
        train_dataset,
        train_loader,
        valid_dataset,
        valid_loader,
        test_dataset,
        test_loader,
    )
109 |
110 |
def init_toy_dataset(opts=None):
    """Build the toy training dataset and its DataLoader.

    The toy task has only a training split.

    Args:
        opts: Parsed experiment options. Only ``opts["train"]["workers"]`` is
            read. When omitted, the module-level ``opts`` (set in
            ``__main__``) is used, which is what the original zero-argument
            version always did.

    Returns:
        ``(train_dataset, train_loader)``. The loader reseeds NumPy in each
        worker via ``worker_init_fn``.
    """
    if opts is None:
        # Backward compatibility: fall back to the script-level options.
        opts = globals()["opts"]

    train_dataset = ToyDataset()
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=None,
        num_workers=opts["train"]["workers"],
        worker_init_fn=worker_init_fn,
    )
    return (train_dataset, train_loader)
120 |
121 |
def init_model(opts, train_dataset):
    """Construct a ``Baller2VecPlusPlus`` model sized for ``train_dataset``.

    Fills in the dataset-dependent entries of ``opts["model"]`` (that dict is
    updated in place) and passes the whole dict to the model as keyword
    arguments.

    Args:
        opts: Experiment options. Reads ``opts["model"]`` and
            ``opts["train"]["task"]``.
        train_dataset: Dataset exposing ``n_player_ids``, ``seq_len``,
            ``n_players`` and ``player_traj_n``.

    Returns:
        The constructed model. This function does not move it to a device.
    """
    seq_len = train_dataset.seq_len
    if opts["train"]["task"] == "basketball":
        # Basketball models are configured with one step fewer than the dataset's seq_len.
        seq_len -= 1

    config = opts["model"]
    config.update(
        # Add one for the generic player.
        n_player_ids=train_dataset.n_player_ids + 1,
        seq_len=seq_len,
        n_players=train_dataset.n_players,
        n_player_labels=train_dataset.player_traj_n ** 2,
    )
    return Baller2VecPlusPlus(**config)
135 |
136 |
def get_preds_labels(tensors):
    """Run the module-level ``model`` on one sample and pair it with its targets.

    Args:
        tensors: Sample dict from a dataset or loader. It must contain
            ``"player_trajs"`` plus whatever ``model`` reads.

    Returns:
        ``(preds, labels)``. ``labels`` is ``tensors["player_trajs"]``
        flattened to 1-D and moved to the module-level ``device``.
    """
    predictions = model(tensors)
    flat_labels = tensors["player_trajs"].flatten()
    return (predictions, flat_labels.to(device))
141 |
142 |
def train_model():
    """Train the module-level ``model`` indefinitely, checkpointing as it goes.

    Relies on module-level names set in ``__main__``: ``model``, ``opts``,
    ``task``, ``JOB_DIR``, ``seq_len``, ``n_players``, ``train_dataset`` and
    ``train_loader``. For the basketball task it also uses ``valid_loader``
    and ``test_loader``. ``device`` is used indirectly through
    ``get_preds_labels``.

    Resumes from ``{JOB_DIR}/best_params.pth`` (and ``optimizer.pth``) when
    present. Each epoch of ``range(1000000)`` (effectively until the process
    is stopped) does the following:

    * basketball: evaluates mean validation loss. On a new best it saves the
      optimizer and model state and computes the mean test loss. Otherwise it
      counts toward ``opts["train"]["patience"]`` and multiplies the learning
      rate by 0.1 when that count is reached.
    * runs one training pass over ``train_loader`` (skipping samples whose
      ``player_idxs`` is shorter than ``seq_len``).
    * toy: saves the optimizer and model state whenever the mean training loss
      improves, then prints predictions for one training sample.
    """
    # Initialize optimizer.
    train_params = [params for params in model.parameters()]
    optimizer = optim.Adam(train_params, lr=opts["train"]["learning_rate"])
    criterion = nn.CrossEntropyLoss()

    # Continue training on a prematurely terminated model.
    try:
        model.load_state_dict(torch.load(f"{JOB_DIR}/best_params.pth"))

        try:
            state_dict = torch.load(f"{JOB_DIR}/optimizer.pth")
            # Only restore the optimizer if it was saved with the currently
            # configured learning rate. A state saved after the 0.1x decay
            # below is therefore not restored.
            if opts["train"]["learning_rate"] == state_dict["param_groups"][0]["lr"]:
                optimizer.load_state_dict(state_dict)

        except ValueError:
            print("Old optimizer doesn't match.")

    # A missing best_params.pth means training starts from scratch. A missing
    # optimizer.pth also lands here, keeping the already-loaded model weights
    # with a fresh optimizer.
    except FileNotFoundError:
        pass

    best_train_loss = float("inf")
    best_valid_loss = float("inf")
    test_loss_best_valid = float("inf")
    # None until the first training pass has run; printed as-is on epoch 0.
    total_train_loss = None
    no_improvement = 0
    for epoch in range(1000000):
        print(f"\nepoch: {epoch}", flush=True)

        if task == "basketball":
            model.eval()
            total_valid_loss = 0.0
            with torch.no_grad():
                n_valid = 0
                for (valid_idx, valid_tensors) in enumerate(valid_loader):
                    # Skip bad sequences.
                    if len(valid_tensors["player_idxs"]) < seq_len:
                        continue

                    (preds, labels) = get_preds_labels(valid_tensors)
                    loss = criterion(preds, labels)
                    total_valid_loss += loss.item()
                    n_valid += 1

                # Show the last evaluated validation sample. Replace the logits
                # with per-row max probability and argmax class, laid out as a
                # (seq_len, n_players) grid.
                probs = torch.softmax(preds, dim=1)
                (probs, preds) = probs.max(1)
                print(probs.view(seq_len, n_players), flush=True)
                print(preds.view(seq_len, n_players), flush=True)
                print(labels.view(seq_len, n_players), flush=True)

            # NOTE(review): raises ZeroDivisionError if every validation
            # sample was skipped (n_valid == 0).
            total_valid_loss /= n_valid

            if total_valid_loss < best_valid_loss:
                best_valid_loss = total_valid_loss
                no_improvement = 0
                torch.save(optimizer.state_dict(), f"{JOB_DIR}/optimizer.pth")
                torch.save(model.state_dict(), f"{JOB_DIR}/best_params.pth")

                # Test loss is only recomputed for the best-on-validation model.
                test_loss_best_valid = 0.0
                with torch.no_grad():
                    n_test = 0
                    for (test_idx, test_tensors) in enumerate(test_loader):
                        # Skip bad sequences.
                        if len(test_tensors["player_idxs"]) < seq_len:
                            continue

                        (preds, labels) = get_preds_labels(test_tensors)
                        loss = criterion(preds, labels)
                        test_loss_best_valid += loss.item()
                        n_test += 1

                test_loss_best_valid /= n_test

            elif no_improvement < opts["train"]["patience"]:
                no_improvement += 1
                # Decay once when patience is exhausted. no_improvement then
                # stays at patience (this branch is skipped) until a new best
                # validation loss resets it to 0.
                if no_improvement == opts["train"]["patience"]:
                    print("Reducing learning rate.")
                    for g in optimizer.param_groups:
                        g["lr"] *= 0.1

        print(f"total_train_loss: {total_train_loss}")
        print(f"best_train_loss: {best_train_loss}")
        if task == "basketball":
            print(f"total_valid_loss: {total_valid_loss}")
            print(f"best_valid_loss: {best_valid_loss}")
            print(f"test_loss_best_valid: {test_loss_best_valid}")

        model.train()
        total_train_loss = 0.0
        n_train = 0
        start_time = time.time()
        for (train_idx, train_tensors) in enumerate(train_loader):
            # Progress heartbeat.
            if train_idx % 1000 == 0:
                print(train_idx, flush=True)

            # Skip bad sequences.
            if len(train_tensors["player_idxs"]) < seq_len:
                continue

            optimizer.zero_grad()
            (preds, labels) = get_preds_labels(train_tensors)
            loss = criterion(preds, labels)
            total_train_loss += loss.item()
            loss.backward()
            optimizer.step()
            n_train += 1

        epoch_time = time.time() - start_time

        # NOTE(review): raises ZeroDivisionError if every training sample was
        # skipped (n_train == 0).
        total_train_loss /= n_train
        if total_train_loss < best_train_loss:
            best_train_loss = total_train_loss
            # The toy task has no validation split, so checkpoint on train loss.
            if task == "toy":
                torch.save(optimizer.state_dict(), f"{JOB_DIR}/optimizer.pth")
                torch.save(model.state_dict(), f"{JOB_DIR}/best_params.pth")

        if task == "toy":
            # Draw until a sample is long enough. This assumes train_dataset[0]
            # returns a freshly randomized sample on each call; TODO confirm in
            # ToyDataset.
            train_tensors = train_dataset[0]
            while len(train_tensors["player_idxs"]) < seq_len:
                train_tensors = train_dataset[0]

            # NOTE(review): this diagnostic forward pass runs outside
            # torch.no_grad() with the model still in train mode.
            (preds, labels) = get_preds_labels(train_tensors)
            probs = torch.softmax(preds, dim=1)
            (probs, preds) = probs.max(1)
            print(probs.view(seq_len, n_players), flush=True)
            print(preds.view(seq_len, n_players), flush=True)
            print(labels.view(seq_len, n_players), flush=True)

        print(f"epoch_time: {epoch_time:.2f}", flush=True)
272 |
273 |
if __name__ == "__main__":
    # Usage: <script> JOB [CUDA_VISIBLE_DEVICES]
    # JOB names both the experiment directory and its YAML config file.
    JOB = sys.argv[1]
    JOB_DIR = f"{EXPERIMENTS_DIR}/{JOB}"

    # Expose only the requested GPU(s) to PyTorch (default "0"), so "cuda:0"
    # below is the first visible device.
    os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[2] if len(sys.argv) > 2 else "0"

    with open(f"{JOB_DIR}/{JOB}.yaml") as opts_file:
        opts = yaml.safe_load(opts_file)

    task = opts["train"]["task"]

    # Initialize datasets.
    if task == "basketball":
        (
            train_dataset,
            train_loader,
            valid_dataset,
            valid_loader,
            test_dataset,
            test_loader,
        ) = init_basketball_datasets(opts)
    elif task == "toy":
        (train_dataset, train_loader) = init_toy_dataset()
    else:
        # Previously this fell through to an opaque NameError on train_dataset.
        raise ValueError(f"Unknown task {task!r}; expected 'basketball' or 'toy'.")

    # Initialize model.
    device = torch.device("cuda:0")

    model = init_model(opts, train_dataset).to(device)
    print(model)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Parameters: {n_params}")

    # train_model() reads these (and the names above) as module-level globals.
    seq_len = model.seq_len
    n_players = model.n_players

    train_model()
311 |
--------------------------------------------------------------------------------
/train_cropped.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/airalcorn2/baller2vecplusplus/e9753be4995300e4e2057ebd06f6ad96131f7ba9/train_cropped.gif
--------------------------------------------------------------------------------
/train_gameids.txt:
--------------------------------------------------------------------------------
1 | 0021500257
2 | 0021500516
3 | 0021500632
4 | 0021500193
5 | 0021500170
6 | 0021500080
7 | 0021500465
8 | 0021500142
9 | 0021500053
10 | 0021500374
11 | 0021500619
12 | 0021500651
13 | 0021500576
14 | 0021500277
15 | 0021500215
16 | 0021500410
17 | 0021500121
18 | 0021500517
19 | 0021500413
20 | 0021500143
21 | 0021500166
22 | 0021500463
23 | 0021500469
24 | 0021500621
25 | 0021500056
26 | 0021500271
27 | 0021500623
28 | 0021500341
29 | 0021500263
30 | 0021500366
31 | 0021500377
32 | 0021500009
33 | 0021500414
34 | 0021500151
35 | 0021500159
36 | 0021500227
37 | 0021500448
38 | 0021500052
39 | 0021500451
40 | 0021500212
41 | 0021500549
42 | 0021500633
43 | 0021500085
44 | 0021500356
45 | 0021500237
46 | 0021500018
47 | 0021500625
48 | 0021500261
49 | 0021500449
50 | 0021500346
51 | 0021500624
52 | 0021500295
53 | 0021500508
54 | 0021500187
55 | 0021500655
56 | 0021500127
57 | 0021500339
58 | 0021500452
59 | 0021500096
60 | 0021500210
61 | 0021500335
62 | 0021500493
63 | 0021500049
64 | 0021500013
65 | 0021500015
66 | 0021500480
67 | 0021500474
68 | 0021500354
69 | 0021500140
70 | 0021500652
71 | 0021500062
72 | 0021500562
73 | 0021500293
74 | 0021500432
75 | 0021500126
76 | 0021500103
77 | 0021500433
78 | 0021500326
79 | 0021500218
80 | 0021500094
81 | 0021500397
82 | 0021500099
83 | 0021500040
84 | 0021500491
85 | 0021500109
86 | 0021500171
87 | 0021500577
88 | 0021500662
89 | 0021500279
90 | 0021500309
91 | 0021500068
92 | 0021500106
93 | 0021500524
94 | 0021500379
95 | 0021500554
96 | 0021500338
97 | 0021500305
98 | 0021500105
99 | 0021500555
100 | 0021500639
101 | 0021500031
102 | 0021500462
103 | 0021500136
104 | 0021500026
105 | 0021500417
106 | 0021500083
107 | 0021500331
108 | 0021500074
109 | 0021500511
110 | 0021500405
111 | 0021500563
112 | 0021500504
113 | 0021500205
114 | 0021500382
115 | 0021500059
116 | 0021500046
117 | 0021500334
118 | 0021500254
119 | 0021500086
120 | 0021500017
121 | 0021500296
122 | 0021500337
123 | 0021500542
124 | 0021500638
125 | 0021500547
126 | 0021500520
127 | 0021500032
128 | 0021500518
129 | 0021500422
130 | 0021500630
131 | 0021500400
132 | 0021500412
133 | 0021500484
134 | 0021500048
135 | 0021500436
136 | 0021500371
137 | 0021500450
138 | 0021500224
139 | 0021500022
140 | 0021500082
141 | 0021500146
142 | 0021500239
143 | 0021500369
144 | 0021500391
145 | 0021500027
146 | 0021500381
147 | 0021500201
148 | 0021500636
149 | 0021500290
150 | 0021500012
151 | 0021500423
152 | 0021500431
153 | 0021500365
154 | 0021500387
155 | 0021500199
156 | 0021500150
157 | 0021500073
158 | 0021500663
159 | 0021500598
160 | 0021500580
161 | 0021500325
162 | 0021500342
163 | 0021500585
164 | 0021500281
165 | 0021500467
166 | 0021500066
167 | 0021500559
168 | 0021500492
169 | 0021500063
170 | 0021500306
171 | 0021500550
172 | 0021500113
173 | 0021500631
174 | 0021500353
175 | 0021500394
176 | 0021500318
177 | 0021500148
178 | 0021500019
179 | 0021500320
180 | 0021500247
181 | 0021500489
182 | 0021500057
183 | 0021500540
184 | 0021500246
185 | 0021500192
186 | 0021500510
187 | 0021500204
188 | 0021500368
189 | 0021500444
190 | 0021500586
191 | 0021500002
192 | 0021500572
193 | 0021500286
194 | 0021500538
195 | 0021500488
196 | 0021500496
197 | 0021500620
198 | 0021500288
199 | 0021500545
200 | 0021500095
201 | 0021500249
202 | 0021500584
203 | 0021500646
204 | 0021500116
205 | 0021500098
206 | 0021500348
207 | 0021500568
208 | 0021500307
209 | 0021500453
210 | 0021500408
211 | 0021500455
212 | 0021500260
213 | 0021500490
214 | 0021500599
215 | 0021500220
216 | 0021500601
217 | 0021500357
218 | 0021500274
219 | 0021500280
220 | 0021500176
221 | 0021500378
222 | 0021500273
223 | 0021500028
224 | 0021500447
225 | 0021500248
226 | 0021500471
227 | 0021500124
228 | 0021500200
229 | 0021500256
230 | 0021500115
231 | 0021500174
232 | 0021500457
233 | 0021500075
234 | 0021500658
235 | 0021500327
236 | 0021500534
237 | 0021500117
238 | 0021500389
239 | 0021500072
240 | 0021500380
241 | 0021500104
242 | 0021500011
243 | 0021500512
244 | 0021500195
245 | 0021500138
246 | 0021500556
247 | 0021500507
248 | 0021500553
249 | 0021500409
250 | 0021500437
251 | 0021500565
252 | 0021500442
253 | 0021500566
254 | 0021500446
255 | 0021500175
256 | 0021500198
257 | 0021500351
258 | 0021500616
259 | 0021500428
260 | 0021500537
261 | 0021500486
262 | 0021500370
263 | 0021500529
264 | 0021500575
265 | 0021500169
266 | 0021500531
267 | 0021500384
268 | 0021500398
269 | 0021500064
270 | 0021500158
271 | 0021500129
272 | 0021500153
273 | 0021500445
274 | 0021500564
275 | 0021500119
276 | 0021500345
277 | 0021500242
278 | 0021500071
279 | 0021500217
280 | 0021500500
281 | 0021500476
282 | 0021500358
283 | 0021500135
284 | 0021500482
285 | 0021500165
286 | 0021500472
287 | 0021500311
288 | 0021500373
289 | 0021500005
290 | 0021500232
291 | 0021500264
292 | 0021500214
293 | 0021500231
294 | 0021500528
295 | 0021500081
296 | 0021500191
297 | 0021500188
298 | 0021500592
299 | 0021500050
300 | 0021500582
301 | 0021500364
302 | 0021500054
303 | 0021500235
304 | 0021500660
305 | 0021500001
306 | 0021500429
307 | 0021500352
308 | 0021500078
309 | 0021500122
310 | 0021500509
311 | 0021500313
312 | 0021500023
313 | 0021500016
314 | 0021500360
315 | 0021500494
316 | 0021500362
317 | 0021500321
318 | 0021500299
319 | 0021500551
320 | 0021500229
321 | 0021500441
322 | 0021500025
323 | 0021500310
324 | 0021500648
325 | 0021500207
326 | 0021500315
327 | 0021500222
328 | 0021500285
329 | 0021500595
330 | 0021500617
331 | 0021500502
332 | 0021500514
333 | 0021500038
334 | 0021500211
335 | 0021500527
336 | 0021500597
337 | 0021500291
338 | 0021500383
339 | 0021500255
340 | 0021500539
341 | 0021500141
342 | 0021500583
343 | 0021500250
344 | 0021500234
345 | 0021500202
346 | 0021500418
347 | 0021500579
348 | 0021500089
349 | 0021500178
350 | 0021500144
351 | 0021500594
352 | 0021500189
353 | 0021500478
354 | 0021500519
355 | 0021500385
356 | 0021500506
357 | 0021500653
358 | 0021500505
359 | 0021500101
360 | 0021500533
361 | 0021500180
362 | 0021500152
363 | 0021500308
364 | 0021500312
365 | 0021500355
366 | 0021500090
367 | 0021500393
368 | 0021500349
369 | 0021500498
370 | 0021500420
371 | 0021500177
372 | 0021500629
373 | 0021500501
374 | 0021500344
375 | 0021500544
376 | 0021500581
377 | 0021500047
378 | 0021500128
379 | 0021500569
380 | 0021500114
381 | 0021500503
382 | 0021500314
383 | 0021500155
384 | 0021500650
385 | 0021500004
386 | 0021500156
387 | 0021500470
388 | 0021500402
389 | 0021500317
390 | 0021500173
391 | 0021500618
392 | 0021500236
393 | 0021500020
394 | 0021500571
395 | 0021500375
396 | 0021500185
397 | 0021500110
398 | 0021500272
399 | 0021500645
400 | 0021500615
401 | 0021500303
402 | 0021500322
403 | 0021500197
404 | 0021500036
405 | 0021500526
406 | 0021500323
407 | 0021500363
408 | 0021500329
409 | 0021500419
410 | 0021500251
411 | 0021500561
412 | 0021500558
413 | 0021500426
414 | 0021500468
415 | 0021500181
416 | 0021500596
417 | 0021500573
418 | 0021500390
419 | 0021500244
420 | 0021500219
421 | 0021500600
422 | 0021500111
423 | 0021500411
424 | 0021500560
425 | 0021500336
426 | 0021500475
427 | 0021500233
428 | 0021500515
429 | 0021500591
430 | 0021500058
431 | 0021500513
432 | 0021500208
433 | 0021500157
434 | 0021500270
435 | 0021500319
436 | 0021500186
437 | 0021500421
438 | 0021500034
439 | 0021500287
440 | 0021500301
441 | 0021500372
442 | 0021500483
443 | 0021500042
444 | 0021500024
445 | 0021500435
446 | 0021500259
447 | 0021500401
448 | 0021500164
449 | 0021500067
450 | 0021500535
451 | 0021500206
452 | 0021500557
453 | 0021500541
454 | 0021500213
455 | 0021500304
456 | 0021500350
457 | 0021500137
458 | 0021500657
459 | 0021500347
460 | 0021500172
461 | 0021500035
462 | 0021500184
463 | 0021500464
464 | 0021500196
465 | 0021500131
466 | 0021500439
467 | 0021500635
468 | 0021500262
469 | 0021500241
470 | 0021500485
471 | 0021500440
472 | 0021500182
473 | 0021500458
474 | 0021500060
475 | 0021500376
476 | 0021500037
477 | 0021500168
478 | 0021500275
479 | 0021500087
480 | 0021500183
481 | 0021500044
482 | 0021500092
483 | 0021500546
484 | 0021500396
485 | 0021500395
486 | 0021500487
487 | 0021500627
488 | 0021500077
489 | 0021500548
490 | 0021500300
491 | 0021500284
492 | 0021500216
493 | 0021500316
494 | 0021500404
495 | 0021500343
496 | 0021500125
497 | 0021500093
498 | 0021500258
499 | 0021500552
500 | 0021500328
501 | 0021500161
502 | 0021500223
503 | 0021500194
504 | 0021500097
505 | 0021500386
506 | 0021500041
507 | 0021500430
508 | 0021500102
509 | 0021500070
510 | 0021500179
511 | 0021500406
512 | 0021500130
513 | 0021500473
514 | 0021500107
515 | 0021500525
516 | 0021500407
517 | 0021500091
518 | 0021500139
519 | 0021500289
520 | 0021500499
521 | 0021500221
522 | 0021500530
523 | 0021500230
524 | 0021500051
525 | 0021500228
526 | 0021500021
527 | 0021500332
528 | 0021500162
529 | 0021500133
530 | 0021500570
531 | 0021500225
532 | 0021500367
533 | 0021500108
534 | 0021500424
535 | 0021500297
536 | 0021500523
537 | 0021500298
538 | 0021500112
539 | 0021500649
540 | 0021500203
541 | 0021500252
542 | 0021500292
543 | 0021500118
544 | 0021500209
545 | 0021500361
546 | 0021500154
547 | 0021500647
548 | 0021500532
549 | 0021500190
550 | 0021500642
551 | 0021500521
552 | 0021500388
553 | 0021500543
554 | 0021500045
555 | 0021500120
556 | 0021500415
557 | 0021500269
558 | 0021500088
559 | 0021500226
560 | 0021500003
561 | 0021500076
562 | 0021500123
563 | 0021500477
564 | 0021500438
565 | 0021500033
566 | 0021500392
567 | 0021500443
568 | 0021500578
569 | 0021500634
570 |
--------------------------------------------------------------------------------
/valid_gameids.txt:
--------------------------------------------------------------------------------
1 | 0021500266
2 | 0021500466
3 | 0021500010
4 | 0021500132
5 | 0021500283
6 | 0021500043
7 | 0021500069
8 | 0021500029
9 | 0021500007
10 | 0021500055
11 | 0021500461
12 | 0021500265
13 | 0021500039
14 | 0021500134
15 | 0021500302
16 | 0021500454
17 | 0021500294
18 | 0021500628
19 | 0021500324
20 | 0021500333
21 | 0021500641
22 | 0021500253
23 | 0021500416
24 | 0021500495
25 | 0021500661
26 | 0021500497
27 | 0021500459
28 | 0021500536
29 | 0021500065
30 | 0021500574
31 |
--------------------------------------------------------------------------------