├── images
│   ├── attn-1.png
│   ├── model.pdf
│   └── model-1.png
├── LICENSE
├── README.md
└── model.ipynb
/images/attn-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ojedaf/IC-TIR-Lol/HEAD/images/attn-1.png
--------------------------------------------------------------------------------
/images/model.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ojedaf/IC-TIR-Lol/HEAD/images/model.pdf
--------------------------------------------------------------------------------
/images/model-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ojedaf/IC-TIR-Lol/HEAD/images/model-1.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 ojedaf
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Interpretable Contextual Team-aware Item Recommendation: Application in Multiplayer Online Battle Arena Game
2 |
3 |
4 |
5 |
6 |
7 | See the paper on [arXiv](https://arxiv.org/abs/2007.15236)
8 |
9 | 
10 |
11 | We release the PyTorch code of the TTIR model.
12 |
13 | ## Content
14 |
15 | - [Prerequisites](#prerequisites)
16 | - [Dataset](#dataset)
17 | - [Code](#code)
18 | - [Results](#results)
19 |
20 | ## Prerequisites
21 |
22 | The code is built with the following libraries (the notebook's install commands are shown after this list):
23 |
24 | - [PyTorch](https://pytorch.org/) 1.0 or higher
25 | - [Comet-ml](https://www.comet.ml/site/)
26 | - [PyTorchLightning](https://github.com/PyTorchLightning/pytorch-lightning)
27 | - [Google Colab](https://colab.research.google.com/)
28 |
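For reference, `model.ipynb` installs these dependencies (plus `omegaconf`, `adabound`, and `ml_metrics`) at the top of the notebook:

```
!pip install git+git://github.com/williamFalcon/pytorch-lightning.git@master --upgrade
!pip install comet_ml==3.0.2
!pip install omegaconf
!pip install adabound
!pip install ml_metrics
```
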
29 | ## Dataset
30 |
31 | The dataset used is available [here](https://drive.google.com/drive/folders/1lsCjmVrOA0stNiUguGWKN46fEqzzsXPH?usp=sharing).
32 |
33 | ## Code
34 |
35 | We developed this project using Google Colab, so you need a Google account and the dataset stored in a Google Drive folder. Change the following paths to match the location of the dataset in your Drive:
36 |
37 | ```python
38 | train_path = '/content/gdrive/My Drive/Proyecto_RecSys/dataset/train_splits.pkl'
39 | test_path = '/content/gdrive/My Drive/Proyecto_RecSys/dataset/test_splits.pkl'
40 | champion_path = '/content/gdrive/My Drive/Proyecto_RecSys/dataset/champion_types.pkl'
41 | ```
42 |
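For reference, a minimal sketch of how the notebook mounts Drive and loads these pickled files (variable names follow `model.ipynb`):

```python
from google.colab import drive
import pickle

# Mount Google Drive so the paths above resolve
drive.mount('/content/gdrive/')

# The train/test files hold the k cross-validation splits;
# champion_types maps each champion to its list of types
with open(train_path, 'rb') as handle:
    list_trainset = pickle.load(handle)
with open(test_path, 'rb') as handle:
    list_testset = pickle.load(handle)
with open(champion_path, 'rb') as handle:
    champion_types = pickle.load(handle)
```
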
43 | Also set the Comet parameters (api_key, project_name, workspace):
44 |
45 | ```python
46 | comet_logger = CometLogger(
47 |     experiment_name=conf['exp_name'],
48 |     api_key='YOUR_KEY',
49 |     project_name='YOUR_PROJECT_NAME',
50 |     workspace='YOUR_WORKSPACE'
51 | )
52 | ```
53 |
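The resulting logger is then passed to the PyTorch Lightning `Trainer` (e.g. `Trainer(logger=comet_logger)`), which is the standard way to attach a logger in Lightning.
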
54 | ## Baselines
55 |
56 | This work uses the models proposed in the paper [Data Mining for Item Recommendation in MOBA Games](https://github.com/vgaraujov/RecSysLoL) as baselines.
57 |
58 | ## Results
59 |
60 | Our method outperforms the state-of-the-art approaches while also explaining its recommendations through attention maps. A sketch of the metric computation follows the table.
61 |
62 | Method | Precision@6 | Recall@6 | F1@6 | MAP@6 |
63 | --- | --- | --- | --- | --- |
64 | TTIR | 0.492 | 0.756 | 0.596 | 0.805 |
65 | CNN | 0.484 | 0.744 | 0.586 | 0.795 |
66 | ANN | 0.476 | 0.732 | 0.566 | 0.785 |
67 |
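For instance, MAP@6 can be reproduced with the `ml_metrics` package the notebook installs; a minimal sketch (the item ids below are made up):

```python
import ml_metrics as metrics

# Ground-truth items bought by each player (hypothetical ids)
actual = [[1034, 3078, 3111], [3089, 3020]]
# Ranked top-6 recommendations for each player
predicted = [[3078, 1034, 3047, 3111, 3065, 3742],
             [3089, 3135, 3020, 3157, 3165, 3116]]

print(metrics.mapk(actual, predicted, k=6))  # mean average precision at 6
```
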
68 | 
69 |
70 | ## Citation
71 |
72 | If you find this repository useful for your research, please consider citing our paper:
73 | ```
74 | @inproceedings{10.1145/3383313.3412211,
75 | author = {Villa, Andr\'{e}s and Araujo, Vladimir and Cattan, Francisca and Parra, Denis},
76 | title = {Interpretable Contextual Team-Aware Item Recommendation: Application in Multiplayer Online Battle Arena Games},
77 | year = {2020},
78 | isbn = {9781450375832},
79 | publisher = {Association for Computing Machinery},
80 | address = {New York, NY, USA},
81 | url = {https://doi.org/10.1145/3383313.3412211},
82 | doi = {10.1145/3383313.3412211},
83 | booktitle = {Fourteenth ACM Conference on Recommender Systems},
84 | pages = {503–508},
85 | numpages = {6},
86 | keywords = {Item Recommendation, Deep Learning, MOBA Games},
87 | location = {Virtual Event, Brazil},
88 | series = {RecSys '20}
89 | }
90 | ```
91 |
92 | For any questions, feel free to open an issue or contact Andrés Villa (afvilla@uc.cl) or Vladimir Araujo (vgaraujo@uc.cl).
93 |
--------------------------------------------------------------------------------
/model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "Aux_RecAttModel_Multitask_Aug_Data_Villa_Cattan_RECSYS_champ_type.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "toc_visible": true,
10 | "machine_shape": "hm"
11 | },
12 | "kernelspec": {
13 | "name": "python3",
14 | "display_name": "Python 3"
15 | },
16 | "accelerator": "GPU",
17 | "widgets": {
18 | "application/vnd.jupyter.widget-state+json": {
19 | "a705d3b71b5d4fb587b3bb1fb38161fa": {
20 | "model_module": "@jupyter-widgets/controls",
21 | "model_name": "HBoxModel",
22 | "state": {
23 | "_view_name": "HBoxView",
24 | "_dom_classes": [],
25 | "_model_name": "HBoxModel",
26 | "_view_module": "@jupyter-widgets/controls",
27 | "_model_module_version": "1.5.0",
28 | "_view_count": null,
29 | "_view_module_version": "1.5.0",
30 | "box_style": "",
31 | "layout": "IPY_MODEL_abfceea38af9444b8da09122eb0c867d",
32 | "_model_module": "@jupyter-widgets/controls",
33 | "children": [
34 | "IPY_MODEL_0dfe065d9c0c468389f75dd74a9a11ac",
35 | "IPY_MODEL_741dabfff47941f3b290d4ad4cb6be12"
36 | ]
37 | }
38 | },
39 | "abfceea38af9444b8da09122eb0c867d": {
40 | "model_module": "@jupyter-widgets/base",
41 | "model_name": "LayoutModel",
42 | "state": {
43 | "_view_name": "LayoutView",
44 | "grid_template_rows": null,
45 | "right": null,
46 | "justify_content": null,
47 | "_view_module": "@jupyter-widgets/base",
48 | "overflow": null,
49 | "_model_module_version": "1.2.0",
50 | "_view_count": null,
51 | "flex_flow": "row wrap",
52 | "width": "100%",
53 | "min_width": null,
54 | "border": null,
55 | "align_items": null,
56 | "bottom": null,
57 | "_model_module": "@jupyter-widgets/base",
58 | "top": null,
59 | "grid_column": null,
60 | "overflow_y": null,
61 | "overflow_x": null,
62 | "grid_auto_flow": null,
63 | "grid_area": null,
64 | "grid_template_columns": null,
65 | "flex": null,
66 | "_model_name": "LayoutModel",
67 | "justify_items": null,
68 | "grid_row": null,
69 | "max_height": null,
70 | "align_content": null,
71 | "visibility": null,
72 | "align_self": null,
73 | "height": null,
74 | "min_height": null,
75 | "padding": null,
76 | "grid_auto_rows": null,
77 | "grid_gap": null,
78 | "max_width": null,
79 | "order": null,
80 | "_view_module_version": "1.2.0",
81 | "grid_template_areas": null,
82 | "object_position": null,
83 | "object_fit": null,
84 | "grid_auto_columns": null,
85 | "margin": null,
86 | "display": "inline-flex",
87 | "left": null
88 | }
89 | },
90 | "0dfe065d9c0c468389f75dd74a9a11ac": {
91 | "model_module": "@jupyter-widgets/controls",
92 | "model_name": "FloatProgressModel",
93 | "state": {
94 | "_view_name": "ProgressView",
95 | "style": "IPY_MODEL_503e7ca9191948a4bcf6cc64a8862820",
96 | "_dom_classes": [],
97 | "description": "Validation sanity check: 100%",
98 | "_model_name": "FloatProgressModel",
99 | "bar_style": "info",
100 | "max": 1,
101 | "_view_module": "@jupyter-widgets/controls",
102 | "_model_module_version": "1.5.0",
103 | "value": 1,
104 | "_view_count": null,
105 | "_view_module_version": "1.5.0",
106 | "orientation": "horizontal",
107 | "min": 0,
108 | "description_tooltip": null,
109 | "_model_module": "@jupyter-widgets/controls",
110 | "layout": "IPY_MODEL_249092a9ce0940218d291fe75cf35f3e"
111 | }
112 | },
113 | "741dabfff47941f3b290d4ad4cb6be12": {
114 | "model_module": "@jupyter-widgets/controls",
115 | "model_name": "HTMLModel",
116 | "state": {
117 | "_view_name": "HTMLView",
118 | "style": "IPY_MODEL_828e6e2c56f3456cb016bf7c8b701ba8",
119 | "_dom_classes": [],
120 | "description": "",
121 | "_model_name": "HTMLModel",
122 | "placeholder": "",
123 | "_view_module": "@jupyter-widgets/controls",
124 | "_model_module_version": "1.5.0",
125 | "value": " 1/1.0 [00:05<00:00, 2.86s/it]",
126 | "_view_count": null,
127 | "_view_module_version": "1.5.0",
128 | "description_tooltip": null,
129 | "_model_module": "@jupyter-widgets/controls",
130 | "layout": "IPY_MODEL_7540c343bf9f461a84c157f84d529cd9"
131 | }
132 | },
133 | "503e7ca9191948a4bcf6cc64a8862820": {
134 | "model_module": "@jupyter-widgets/controls",
135 | "model_name": "ProgressStyleModel",
136 | "state": {
137 | "_view_name": "StyleView",
138 | "_model_name": "ProgressStyleModel",
139 | "description_width": "initial",
140 | "_view_module": "@jupyter-widgets/base",
141 | "_model_module_version": "1.5.0",
142 | "_view_count": null,
143 | "_view_module_version": "1.2.0",
144 | "bar_color": null,
145 | "_model_module": "@jupyter-widgets/controls"
146 | }
147 | },
148 | "249092a9ce0940218d291fe75cf35f3e": {
149 | "model_module": "@jupyter-widgets/base",
150 | "model_name": "LayoutModel",
151 | "state": {
152 | "_view_name": "LayoutView",
153 | "grid_template_rows": null,
154 | "right": null,
155 | "justify_content": null,
156 | "_view_module": "@jupyter-widgets/base",
157 | "overflow": null,
158 | "_model_module_version": "1.2.0",
159 | "_view_count": null,
160 | "flex_flow": null,
161 | "width": null,
162 | "min_width": null,
163 | "border": null,
164 | "align_items": null,
165 | "bottom": null,
166 | "_model_module": "@jupyter-widgets/base",
167 | "top": null,
168 | "grid_column": null,
169 | "overflow_y": null,
170 | "overflow_x": null,
171 | "grid_auto_flow": null,
172 | "grid_area": null,
173 | "grid_template_columns": null,
174 | "flex": "2",
175 | "_model_name": "LayoutModel",
176 | "justify_items": null,
177 | "grid_row": null,
178 | "max_height": null,
179 | "align_content": null,
180 | "visibility": null,
181 | "align_self": null,
182 | "height": null,
183 | "min_height": null,
184 | "padding": null,
185 | "grid_auto_rows": null,
186 | "grid_gap": null,
187 | "max_width": null,
188 | "order": null,
189 | "_view_module_version": "1.2.0",
190 | "grid_template_areas": null,
191 | "object_position": null,
192 | "object_fit": null,
193 | "grid_auto_columns": null,
194 | "margin": null,
195 | "display": null,
196 | "left": null
197 | }
198 | },
199 | "828e6e2c56f3456cb016bf7c8b701ba8": {
200 | "model_module": "@jupyter-widgets/controls",
201 | "model_name": "DescriptionStyleModel",
202 | "state": {
203 | "_view_name": "StyleView",
204 | "_model_name": "DescriptionStyleModel",
205 | "description_width": "",
206 | "_view_module": "@jupyter-widgets/base",
207 | "_model_module_version": "1.5.0",
208 | "_view_count": null,
209 | "_view_module_version": "1.2.0",
210 | "_model_module": "@jupyter-widgets/controls"
211 | }
212 | },
213 | "7540c343bf9f461a84c157f84d529cd9": {
214 | "model_module": "@jupyter-widgets/base",
215 | "model_name": "LayoutModel",
216 | "state": {
217 | "_view_name": "LayoutView",
218 | "grid_template_rows": null,
219 | "right": null,
220 | "justify_content": null,
221 | "_view_module": "@jupyter-widgets/base",
222 | "overflow": null,
223 | "_model_module_version": "1.2.0",
224 | "_view_count": null,
225 | "flex_flow": null,
226 | "width": null,
227 | "min_width": null,
228 | "border": null,
229 | "align_items": null,
230 | "bottom": null,
231 | "_model_module": "@jupyter-widgets/base",
232 | "top": null,
233 | "grid_column": null,
234 | "overflow_y": null,
235 | "overflow_x": null,
236 | "grid_auto_flow": null,
237 | "grid_area": null,
238 | "grid_template_columns": null,
239 | "flex": null,
240 | "_model_name": "LayoutModel",
241 | "justify_items": null,
242 | "grid_row": null,
243 | "max_height": null,
244 | "align_content": null,
245 | "visibility": null,
246 | "align_self": null,
247 | "height": null,
248 | "min_height": null,
249 | "padding": null,
250 | "grid_auto_rows": null,
251 | "grid_gap": null,
252 | "max_width": null,
253 | "order": null,
254 | "_view_module_version": "1.2.0",
255 | "grid_template_areas": null,
256 | "object_position": null,
257 | "object_fit": null,
258 | "grid_auto_columns": null,
259 | "margin": null,
260 | "display": null,
261 | "left": null
262 | }
263 | },
264 | "c43e576730a940c28fa78d49e95e7165": {
265 | "model_module": "@jupyter-widgets/controls",
266 | "model_name": "HBoxModel",
267 | "state": {
268 | "_view_name": "HBoxView",
269 | "_dom_classes": [],
270 | "_model_name": "HBoxModel",
271 | "_view_module": "@jupyter-widgets/controls",
272 | "_model_module_version": "1.5.0",
273 | "_view_count": null,
274 | "_view_module_version": "1.5.0",
275 | "box_style": "",
276 | "layout": "IPY_MODEL_b3d8f9b86d0d47ae8c51b8f2eb202aab",
277 | "_model_module": "@jupyter-widgets/controls",
278 | "children": [
279 | "IPY_MODEL_82170aabdef246edbf668bb1cdf4a5e3",
280 | "IPY_MODEL_43a58bf6d9ab454095f8f5f30f10cdca"
281 | ]
282 | }
283 | },
284 | "b3d8f9b86d0d47ae8c51b8f2eb202aab": {
285 | "model_module": "@jupyter-widgets/base",
286 | "model_name": "LayoutModel",
287 | "state": {
288 | "_view_name": "LayoutView",
289 | "grid_template_rows": null,
290 | "right": null,
291 | "justify_content": null,
292 | "_view_module": "@jupyter-widgets/base",
293 | "overflow": null,
294 | "_model_module_version": "1.2.0",
295 | "_view_count": null,
296 | "flex_flow": "row wrap",
297 | "width": "100%",
298 | "min_width": null,
299 | "border": null,
300 | "align_items": null,
301 | "bottom": null,
302 | "_model_module": "@jupyter-widgets/base",
303 | "top": null,
304 | "grid_column": null,
305 | "overflow_y": null,
306 | "overflow_x": null,
307 | "grid_auto_flow": null,
308 | "grid_area": null,
309 | "grid_template_columns": null,
310 | "flex": null,
311 | "_model_name": "LayoutModel",
312 | "justify_items": null,
313 | "grid_row": null,
314 | "max_height": null,
315 | "align_content": null,
316 | "visibility": null,
317 | "align_self": null,
318 | "height": null,
319 | "min_height": null,
320 | "padding": null,
321 | "grid_auto_rows": null,
322 | "grid_gap": null,
323 | "max_width": null,
324 | "order": null,
325 | "_view_module_version": "1.2.0",
326 | "grid_template_areas": null,
327 | "object_position": null,
328 | "object_fit": null,
329 | "grid_auto_columns": null,
330 | "margin": null,
331 | "display": "inline-flex",
332 | "left": null
333 | }
334 | },
335 | "82170aabdef246edbf668bb1cdf4a5e3": {
336 | "model_module": "@jupyter-widgets/controls",
337 | "model_name": "FloatProgressModel",
338 | "state": {
339 | "_view_name": "ProgressView",
340 | "style": "IPY_MODEL_f80b50d773b240a091c6c3bdf7961924",
341 | "_dom_classes": [],
342 | "description": "Epoch 1: 10%",
343 | "_model_name": "FloatProgressModel",
344 | "bar_style": "info",
345 | "max": 1577,
346 | "_view_module": "@jupyter-widgets/controls",
347 | "_model_module_version": "1.5.0",
348 | "value": 160,
349 | "_view_count": null,
350 | "_view_module_version": "1.5.0",
351 | "orientation": "horizontal",
352 | "min": 0,
353 | "description_tooltip": null,
354 | "_model_module": "@jupyter-widgets/controls",
355 | "layout": "IPY_MODEL_f8a352088b8d4ed896bf8a206ecc024e"
356 | }
357 | },
358 | "43a58bf6d9ab454095f8f5f30f10cdca": {
359 | "model_module": "@jupyter-widgets/controls",
360 | "model_name": "HTMLModel",
361 | "state": {
362 | "_view_name": "HTMLView",
363 | "style": "IPY_MODEL_2e151daf92e84c369ee90e8ded7a24f2",
364 | "_dom_classes": [],
365 | "description": "",
366 | "_model_name": "HTMLModel",
367 | "placeholder": "",
368 | "_view_module": "@jupyter-widgets/controls",
369 | "_model_module_version": "1.5.0",
370 | "value": " 160/1577 [06:58<1:01:48, 2.62s/it, loss=0.485, v_num=8dcde95206ec45daac4cc6657844b03d, train_loss=0.0916, train_loss_aux=1.97, train_prec_avg=0.405, total_loss_train=0.485, train_loss_avg=0.117]",
371 | "_view_count": null,
372 | "_view_module_version": "1.5.0",
373 | "description_tooltip": null,
374 | "_model_module": "@jupyter-widgets/controls",
375 | "layout": "IPY_MODEL_5b0b2240357743c4b8285ce6017638c4"
376 | }
377 | },
378 | "f80b50d773b240a091c6c3bdf7961924": {
379 | "model_module": "@jupyter-widgets/controls",
380 | "model_name": "ProgressStyleModel",
381 | "state": {
382 | "_view_name": "StyleView",
383 | "_model_name": "ProgressStyleModel",
384 | "description_width": "initial",
385 | "_view_module": "@jupyter-widgets/base",
386 | "_model_module_version": "1.5.0",
387 | "_view_count": null,
388 | "_view_module_version": "1.2.0",
389 | "bar_color": null,
390 | "_model_module": "@jupyter-widgets/controls"
391 | }
392 | },
393 | "f8a352088b8d4ed896bf8a206ecc024e": {
394 | "model_module": "@jupyter-widgets/base",
395 | "model_name": "LayoutModel",
396 | "state": {
397 | "_view_name": "LayoutView",
398 | "grid_template_rows": null,
399 | "right": null,
400 | "justify_content": null,
401 | "_view_module": "@jupyter-widgets/base",
402 | "overflow": null,
403 | "_model_module_version": "1.2.0",
404 | "_view_count": null,
405 | "flex_flow": null,
406 | "width": null,
407 | "min_width": null,
408 | "border": null,
409 | "align_items": null,
410 | "bottom": null,
411 | "_model_module": "@jupyter-widgets/base",
412 | "top": null,
413 | "grid_column": null,
414 | "overflow_y": null,
415 | "overflow_x": null,
416 | "grid_auto_flow": null,
417 | "grid_area": null,
418 | "grid_template_columns": null,
419 | "flex": "2",
420 | "_model_name": "LayoutModel",
421 | "justify_items": null,
422 | "grid_row": null,
423 | "max_height": null,
424 | "align_content": null,
425 | "visibility": null,
426 | "align_self": null,
427 | "height": null,
428 | "min_height": null,
429 | "padding": null,
430 | "grid_auto_rows": null,
431 | "grid_gap": null,
432 | "max_width": null,
433 | "order": null,
434 | "_view_module_version": "1.2.0",
435 | "grid_template_areas": null,
436 | "object_position": null,
437 | "object_fit": null,
438 | "grid_auto_columns": null,
439 | "margin": null,
440 | "display": null,
441 | "left": null
442 | }
443 | },
444 | "2e151daf92e84c369ee90e8ded7a24f2": {
445 | "model_module": "@jupyter-widgets/controls",
446 | "model_name": "DescriptionStyleModel",
447 | "state": {
448 | "_view_name": "StyleView",
449 | "_model_name": "DescriptionStyleModel",
450 | "description_width": "",
451 | "_view_module": "@jupyter-widgets/base",
452 | "_model_module_version": "1.5.0",
453 | "_view_count": null,
454 | "_view_module_version": "1.2.0",
455 | "_model_module": "@jupyter-widgets/controls"
456 | }
457 | },
458 | "5b0b2240357743c4b8285ce6017638c4": {
459 | "model_module": "@jupyter-widgets/base",
460 | "model_name": "LayoutModel",
461 | "state": {
462 | "_view_name": "LayoutView",
463 | "grid_template_rows": null,
464 | "right": null,
465 | "justify_content": null,
466 | "_view_module": "@jupyter-widgets/base",
467 | "overflow": null,
468 | "_model_module_version": "1.2.0",
469 | "_view_count": null,
470 | "flex_flow": null,
471 | "width": null,
472 | "min_width": null,
473 | "border": null,
474 | "align_items": null,
475 | "bottom": null,
476 | "_model_module": "@jupyter-widgets/base",
477 | "top": null,
478 | "grid_column": null,
479 | "overflow_y": null,
480 | "overflow_x": null,
481 | "grid_auto_flow": null,
482 | "grid_area": null,
483 | "grid_template_columns": null,
484 | "flex": null,
485 | "_model_name": "LayoutModel",
486 | "justify_items": null,
487 | "grid_row": null,
488 | "max_height": null,
489 | "align_content": null,
490 | "visibility": null,
491 | "align_self": null,
492 | "height": null,
493 | "min_height": null,
494 | "padding": null,
495 | "grid_auto_rows": null,
496 | "grid_gap": null,
497 | "max_width": null,
498 | "order": null,
499 | "_view_module_version": "1.2.0",
500 | "grid_template_areas": null,
501 | "object_position": null,
502 | "object_fit": null,
503 | "grid_auto_columns": null,
504 | "margin": null,
505 | "display": null,
506 | "left": null
507 | }
508 | }
509 | }
510 | }
511 | },
512 | "cells": [
513 | {
514 | "cell_type": "markdown",
515 | "metadata": {
516 | "id": "uoLSVVIBCwLm",
517 | "colab_type": "text"
518 | },
519 | "source": [
520 | "# Interpretable Contextual Team-aware Item Recommendation: Application in Multiplayer Online Battle Arena Games\n",
521 | "*Andres Villa, Vladimir Araujo, Francisca Cattan*"
522 | ]
523 | },
524 | {
525 | "cell_type": "markdown",
526 | "metadata": {
527 | "id": "t8_YV_PIDR97",
528 | "colab_type": "text"
529 | },
530 | "source": [
531 | "# Introduction\n",
532 | "\n",
533 | "This notebook contains the code of the proposed model. It is composed of 8 main stages:\n",
534 | "\n",
535 | "1. Connect to gDrive\n",
536 | "2. Dataset and Transformations\n",
537 | "3. Model\n",
538 | "4. Logger and Checkpointer\n",
539 | "5. Metrics\n",
540 | "6. Training and evaluation loop\n",
541 | "7. Config file\n",
542 | "8. Training and evaluation executor\n",
543 | "9. Obtain the role and id of each champion in each match\n",
544 | "10. Load the attention weights\n",
545 | "11. Draw the attention map\n",
546 | "\n",
547 | "*This notebook can be run in it's entirety. The final cell executes the training and validation of the model. "
548 | ]
549 | },
550 | {
551 | "cell_type": "markdown",
552 | "metadata": {
553 | "id": "eYxhCKYPbBT2",
554 | "colab_type": "toc"
555 | },
556 | "source": [
557 | ">[Main Model - Project Title](#scrollTo=uoLSVVIBCwLm)\n",
558 | "\n",
559 | ">[Introduction](#scrollTo=t8_YV_PIDR97)\n",
560 | "\n",
561 | ">[Install all the dependencies](#scrollTo=etkQTYydGkFM)\n",
562 | "\n",
563 | ">[Import the dependencies](#scrollTo=S0YvGjijGxET)\n",
564 | "\n",
565 | ">[Connect to gDrive](#scrollTo=pfDyM4E7G4L2)\n",
566 | "\n",
567 | ">[Dataset and Transformations](#scrollTo=h9MDWroJSkhM)\n",
568 | "\n",
569 | ">[Model](#scrollTo=UIm1_KUCUNB0)\n",
570 | "\n",
571 | ">>[Transformer encoder modified to obtain the attention weights](#scrollTo=qr3TZbrnUg2H)\n",
572 | "\n",
573 | ">>[Auxiliary Task Classes](#scrollTo=pwRy106QU6sH)\n",
574 | "\n",
575 | ">>[Main Class of the proposed model](#scrollTo=9AlU_u42VG8A)\n",
576 | "\n",
577 | ">[Logger and Checkpointer](#scrollTo=rwYoKWcsVqex)\n",
578 | "\n",
579 | ">[Metrics](#scrollTo=5ktMqAUMWeEz)\n",
580 | "\n",
581 | ">[Training and evaluation loop](#scrollTo=WDA0GHysW4vX)\n",
582 | "\n",
583 | ">[Config file](#scrollTo=CyRfaqN8XvYi)\n",
584 | "\n",
585 | ">[Training and evaluation executor](#scrollTo=IVtKoVTcYDS1)\n",
586 | "\n",
587 | ">[T-test](#scrollTo=D2TPs5U3vv7m)\n",
588 | "\n",
589 | ">[Obtain the role and id of each champion in each match](#scrollTo=sFtaUCU5T8fl)\n",
590 | "\n",
591 | ">[Load the attention weights](#scrollTo=VINfHm76U1vz)\n",
592 | "\n",
593 | ">[Draw the attention map](#scrollTo=SvhQCEzcU6x_)\n",
594 | "\n"
595 | ]
596 | },
597 | {
598 | "cell_type": "markdown",
599 | "metadata": {
600 | "id": "etkQTYydGkFM",
601 | "colab_type": "text"
602 | },
603 | "source": [
604 | "# Install all the dependencies"
605 | ]
606 | },
607 | {
608 | "cell_type": "markdown",
609 | "metadata": {
610 | "id": "uH0huCgYRJou",
611 | "colab_type": "text"
612 | },
613 | "source": [
614 | "Install all the libraries neccesary to run the model. "
615 | ]
616 | },
617 | {
618 | "cell_type": "code",
619 | "metadata": {
620 | "id": "qsqG6vM9tqER",
621 | "colab_type": "code",
622 | "colab": {
623 | "base_uri": "https://localhost:8080/",
624 | "height": 369
625 | },
626 | "outputId": "1450ea6d-a49f-4586-b3a9-6c111b197efa"
627 | },
628 | "source": [
629 | "!nvidia-smi"
630 | ],
631 | "execution_count": 3,
632 | "outputs": [
633 | {
634 | "output_type": "stream",
635 | "text": [
636 | "Tue Jul 28 06:40:42 2020 \n",
637 | "+-----------------------------------------------------------------------------+\n",
638 | "| NVIDIA-SMI 450.51.05 Driver Version: 418.67 CUDA Version: 10.1 |\n",
639 | "|-------------------------------+----------------------+----------------------+\n",
640 | "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
641 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
642 | "| | | MIG M. |\n",
643 | "|===============================+======================+======================|\n",
644 | "| 0 Tesla P100-PCIE... Off | 00000000:00:04.0 Off | 0 |\n",
645 | "| N/A 43C P0 28W / 250W | 0MiB / 16280MiB | 0% Default |\n",
646 | "| | | ERR! |\n",
647 | "+-------------------------------+----------------------+----------------------+\n",
648 | " \n",
649 | "+-----------------------------------------------------------------------------+\n",
650 | "| Processes: |\n",
651 | "| GPU GI CI PID Type Process name GPU Memory |\n",
652 | "| ID ID Usage |\n",
653 | "|=============================================================================|\n",
654 | "| No running processes found |\n",
655 | "+-----------------------------------------------------------------------------+\n"
656 | ],
657 | "name": "stdout"
658 | }
659 | ]
660 | },
661 | {
662 | "cell_type": "code",
663 | "metadata": {
664 | "id": "qTd06UwscDsS",
665 | "colab_type": "code",
666 | "colab": {
667 | "base_uri": "https://localhost:8080/",
668 | "height": 1000
669 | },
670 | "outputId": "ca753b75-1078-420f-cdfb-b3e5ee505b97"
671 | },
672 | "source": [
673 | "!pip install git+git://github.com/williamFalcon/pytorch-lightning.git@master --upgrade"
674 | ],
675 | "execution_count": 4,
676 | "outputs": [
677 | {
678 | "output_type": "stream",
679 | "text": [
680 | "Collecting git+git://github.com/williamFalcon/pytorch-lightning.git@master\n",
681 | " Cloning git://github.com/williamFalcon/pytorch-lightning.git (to revision master) to /tmp/pip-req-build-j84m_1zy\n",
682 | " Running command git clone -q git://github.com/williamFalcon/pytorch-lightning.git /tmp/pip-req-build-j84m_1zy\n",
683 | " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
684 | " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
685 | " Preparing wheel metadata ... \u001b[?25l\u001b[?25hdone\n",
686 | "Collecting future>=0.17.1\n",
687 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/45/0b/38b06fd9b92dc2b68d58b75f900e97884c45bedd2ff83203d933cf5851c9/future-0.18.2.tar.gz (829kB)\n",
688 | "\u001b[K |████████████████████████████████| 829kB 2.9MB/s \n",
689 | "\u001b[?25hRequirement already satisfied, skipping upgrade: tqdm>=4.41.0 in /usr/local/lib/python3.6/dist-packages (from pytorch-lightning==0.9.0rc2) (4.41.1)\n",
690 | "Requirement already satisfied, skipping upgrade: numpy>=1.16.4 in /usr/local/lib/python3.6/dist-packages (from pytorch-lightning==0.9.0rc2) (1.18.5)\n",
691 | "Requirement already satisfied, skipping upgrade: torch>=1.3 in /usr/local/lib/python3.6/dist-packages (from pytorch-lightning==0.9.0rc2) (1.5.1+cu101)\n",
692 | "Requirement already satisfied, skipping upgrade: tensorboard>=1.14 in /usr/local/lib/python3.6/dist-packages (from pytorch-lightning==0.9.0rc2) (2.2.2)\n",
693 | "Collecting PyYAML>=5.1\n",
694 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)\n",
695 | "\u001b[K |████████████████████████████████| 276kB 15.9MB/s \n",
696 | "\u001b[?25hRequirement already satisfied, skipping upgrade: protobuf>=3.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (3.12.2)\n",
697 | "Requirement already satisfied, skipping upgrade: grpcio>=1.24.3 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (1.30.0)\n",
698 | "Requirement already satisfied, skipping upgrade: wheel>=0.26; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (0.34.2)\n",
699 | "Requirement already satisfied, skipping upgrade: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (1.0.1)\n",
700 | "Requirement already satisfied, skipping upgrade: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (3.2.2)\n",
701 | "Requirement already satisfied, skipping upgrade: absl-py>=0.4 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (0.9.0)\n",
702 | "Requirement already satisfied, skipping upgrade: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (1.15.0)\n",
703 | "Requirement already satisfied, skipping upgrade: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (1.7.0)\n",
704 | "Requirement already satisfied, skipping upgrade: setuptools>=41.0.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (49.1.0)\n",
705 | "Requirement already satisfied, skipping upgrade: google-auth<2,>=1.6.3 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (1.17.2)\n",
706 | "Requirement already satisfied, skipping upgrade: requests<3,>=2.21.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (2.23.0)\n",
707 | "Requirement already satisfied, skipping upgrade: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (0.4.1)\n",
708 | "Requirement already satisfied, skipping upgrade: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from markdown>=2.6.8->tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (1.7.0)\n",
709 | "Requirement already satisfied, skipping upgrade: rsa<5,>=3.1.4; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (4.6)\n",
710 | "Requirement already satisfied, skipping upgrade: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (4.1.1)\n",
711 | "Requirement already satisfied, skipping upgrade: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (0.2.8)\n",
712 | "Requirement already satisfied, skipping upgrade: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (2020.6.20)\n",
713 | "Requirement already satisfied, skipping upgrade: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (1.24.3)\n",
714 | "Requirement already satisfied, skipping upgrade: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (3.0.4)\n",
715 | "Requirement already satisfied, skipping upgrade: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (2.10)\n",
716 | "Requirement already satisfied, skipping upgrade: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (1.3.0)\n",
717 | "Requirement already satisfied, skipping upgrade: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < \"3.8\"->markdown>=2.6.8->tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (3.1.0)\n",
718 | "Requirement already satisfied, skipping upgrade: pyasn1>=0.1.3 in /usr/local/lib/python3.6/dist-packages (from rsa<5,>=3.1.4; python_version >= \"3\"->google-auth<2,>=1.6.3->tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (0.4.8)\n",
719 | "Requirement already satisfied, skipping upgrade: oauthlib>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (3.1.0)\n",
720 | "Building wheels for collected packages: pytorch-lightning\n",
721 | " Building wheel for pytorch-lightning (PEP 517) ... \u001b[?25l\u001b[?25hdone\n",
722 | " Created wheel for pytorch-lightning: filename=pytorch_lightning-0.9.0rc2-cp36-none-any.whl size=353828 sha256=30b73a303ccd241770a24f1984519ad7e086144a4db7285f7b8166e4de330d64\n",
723 | " Stored in directory: /tmp/pip-ephem-wheel-cache-jlmk53yu/wheels/02/e9/33/ecf2ab0b937f47c530a3d24222ca1a784412a0c7d490195c5f\n",
724 | "Successfully built pytorch-lightning\n",
725 | "Building wheels for collected packages: future, PyYAML\n",
726 | " Building wheel for future (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
727 | " Created wheel for future: filename=future-0.18.2-cp36-none-any.whl size=491057 sha256=15e732369ebb372a11250e7c9d0e57f7af49cc24c3e76854d4126655f02df3b9\n",
728 | " Stored in directory: /root/.cache/pip/wheels/8b/99/a0/81daf51dcd359a9377b110a8a886b3895921802d2fc1b2397e\n",
729 | " Building wheel for PyYAML (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
730 | " Created wheel for PyYAML: filename=PyYAML-5.3.1-cp36-cp36m-linux_x86_64.whl size=44621 sha256=085fe27c9be3cbcd42d6336ee010929d00e36263ea5e88a9613aeb9fbf6e1b49\n",
731 | " Stored in directory: /root/.cache/pip/wheels/a7/c1/ea/cf5bd31012e735dc1dfea3131a2d5eae7978b251083d6247bd\n",
732 | "Successfully built future PyYAML\n",
733 | "Installing collected packages: future, PyYAML, pytorch-lightning\n",
734 | " Found existing installation: future 0.16.0\n",
735 | " Uninstalling future-0.16.0:\n",
736 | " Successfully uninstalled future-0.16.0\n",
737 | " Found existing installation: PyYAML 3.13\n",
738 | " Uninstalling PyYAML-3.13:\n",
739 | " Successfully uninstalled PyYAML-3.13\n",
740 | "Successfully installed PyYAML-5.3.1 future-0.18.2 pytorch-lightning-0.9.0rc2\n"
741 | ],
742 | "name": "stdout"
743 | }
744 | ]
745 | },
746 | {
747 | "cell_type": "code",
748 | "metadata": {
749 | "id": "AwD1P0lHVaLO",
750 | "colab_type": "code",
751 | "colab": {
752 | "base_uri": "https://localhost:8080/",
753 | "height": 600
754 | },
755 | "outputId": "4d2b24ba-90cd-4304-cb42-f09c697a3827"
756 | },
757 | "source": [
758 | "!pip install comet_ml==3.0.2"
759 | ],
760 | "execution_count": 5,
761 | "outputs": [
762 | {
763 | "output_type": "stream",
764 | "text": [
765 | "Collecting comet_ml==3.0.2\n",
766 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/99/c6/fac88f43f2aa61a09fee4ffb769c73fe93fe7de75764246e70967d31da09/comet_ml-3.0.2-py3-none-any.whl (170kB)\n",
767 | "\u001b[K |████████████████████████████████| 174kB 2.9MB/s \n",
768 | "\u001b[?25hCollecting websocket-client>=0.55.0\n",
769 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/4c/5f/f61b420143ed1c8dc69f9eaec5ff1ac36109d52c80de49d66e0c36c3dfdf/websocket_client-0.57.0-py2.py3-none-any.whl (200kB)\n",
770 | "\u001b[K |████████████████████████████████| 204kB 8.9MB/s \n",
771 | "\u001b[?25hCollecting everett[ini]>=1.0.1; python_version >= \"3.0\"\n",
772 | " Downloading https://files.pythonhosted.org/packages/12/34/de70a3d913411e40ce84966f085b5da0c6df741e28c86721114dd290aaa0/everett-1.0.2-py2.py3-none-any.whl\n",
773 | "Requirement already satisfied: requests>=2.18.4 in /usr/local/lib/python3.6/dist-packages (from comet_ml==3.0.2) (2.23.0)\n",
774 | "Collecting wurlitzer>=1.0.2\n",
775 | " Downloading https://files.pythonhosted.org/packages/0c/1e/52f4effa64a447c4ec0fb71222799e2ac32c55b4b6c1725fccdf6123146e/wurlitzer-2.0.1-py2.py3-none-any.whl\n",
776 | "Collecting comet-git-pure>=0.19.11\n",
777 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/72/7a/483413046e48908986a0f9a1d8a917e1da46ae58e6ba16b2ac71b3adf8d7/comet_git_pure-0.19.16-py3-none-any.whl (409kB)\n",
778 | "\u001b[K |████████████████████████████████| 419kB 8.8MB/s \n",
779 | "\u001b[?25hRequirement already satisfied: jsonschema<3.1.0,>=2.6.0 in /usr/local/lib/python3.6/dist-packages (from comet_ml==3.0.2) (2.6.0)\n",
780 | "Collecting netifaces>=0.10.7\n",
781 | " Downloading https://files.pythonhosted.org/packages/0c/9b/c4c7eb09189548d45939a3d3a6b3d53979c67d124459b27a094c365c347f/netifaces-0.10.9-cp36-cp36m-manylinux1_x86_64.whl\n",
782 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from comet_ml==3.0.2) (1.15.0)\n",
783 | "Requirement already satisfied: nvidia-ml-py3>=7.352.0 in /usr/local/lib/python3.6/dist-packages (from comet_ml==3.0.2) (7.352.0)\n",
784 | "Collecting configobj; extra == \"ini\"\n",
785 | " Downloading https://files.pythonhosted.org/packages/64/61/079eb60459c44929e684fa7d9e2fdca403f67d64dd9dbac27296be2e0fab/configobj-5.0.6.tar.gz\n",
786 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests>=2.18.4->comet_ml==3.0.2) (1.24.3)\n",
787 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests>=2.18.4->comet_ml==3.0.2) (2020.6.20)\n",
788 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests>=2.18.4->comet_ml==3.0.2) (3.0.4)\n",
789 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests>=2.18.4->comet_ml==3.0.2) (2.10)\n",
790 | "Building wheels for collected packages: configobj\n",
791 | " Building wheel for configobj (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
792 | " Created wheel for configobj: filename=configobj-5.0.6-cp36-none-any.whl size=34546 sha256=b4795ece5a011d0faed00fc2c825d5a241e54d2ac170c3ec2b6ba97eafcb809f\n",
793 | " Stored in directory: /root/.cache/pip/wheels/f1/e4/16/4981ca97c2d65106b49861e0b35e2660695be7219a2d351ee0\n",
794 | "Successfully built configobj\n",
795 | "Installing collected packages: websocket-client, configobj, everett, wurlitzer, comet-git-pure, netifaces, comet-ml\n",
796 | "Successfully installed comet-git-pure-0.19.16 comet-ml-3.0.2 configobj-5.0.6 everett-1.0.2 netifaces-0.10.9 websocket-client-0.57.0 wurlitzer-2.0.1\n"
797 | ],
798 | "name": "stdout"
799 | }
800 | ]
801 | },
802 | {
803 | "cell_type": "code",
804 | "metadata": {
805 | "id": "9vHwQVrRWBgV",
806 | "colab_type": "code",
807 | "colab": {
808 | "base_uri": "https://localhost:8080/",
809 | "height": 160
810 | },
811 | "outputId": "3b656b18-e85b-4549-a55b-f99a5d1b1c73"
812 | },
813 | "source": [
814 | "!pip install omegaconf"
815 | ],
816 | "execution_count": 6,
817 | "outputs": [
818 | {
819 | "output_type": "stream",
820 | "text": [
821 | "Collecting omegaconf\n",
822 | " Downloading https://files.pythonhosted.org/packages/3d/95/ebd73361f9c6e94bd0f3b19ffe31c24e833834c022f1c0328ac71b2d6c90/omegaconf-2.0.0-py3-none-any.whl\n",
823 | "Requirement already satisfied: PyYAML in /usr/local/lib/python3.6/dist-packages (from omegaconf) (5.3.1)\n",
824 | "Requirement already satisfied: dataclasses; python_version == \"3.6\" in /usr/local/lib/python3.6/dist-packages (from omegaconf) (0.7)\n",
825 | "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from omegaconf) (3.7.4.2)\n",
826 | "Installing collected packages: omegaconf\n",
827 | "Successfully installed omegaconf-2.0.0\n"
828 | ],
829 | "name": "stdout"
830 | }
831 | ]
832 | },
833 | {
834 | "cell_type": "code",
835 | "metadata": {
836 | "id": "6yA6EwQM6P8i",
837 | "colab_type": "code",
838 | "colab": {
839 | "base_uri": "https://localhost:8080/",
840 | "height": 160
841 | },
842 | "outputId": "0dfe679a-88c9-4cdc-81d8-c98b24d5e010"
843 | },
844 | "source": [
845 | "!pip install adabound"
846 | ],
847 | "execution_count": 7,
848 | "outputs": [
849 | {
850 | "output_type": "stream",
851 | "text": [
852 | "Collecting adabound\n",
853 | " Downloading https://files.pythonhosted.org/packages/cd/44/0c2c414effb3d9750d780b230dbb67ea48ddc5d9a6d7a9b7e6fcc6bdcff9/adabound-0.0.5-py3-none-any.whl\n",
854 | "Requirement already satisfied: torch>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from adabound) (1.5.1+cu101)\n",
855 | "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch>=0.4.0->adabound) (0.18.2)\n",
856 | "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch>=0.4.0->adabound) (1.18.5)\n",
857 | "Installing collected packages: adabound\n",
858 | "Successfully installed adabound-0.0.5\n"
859 | ],
860 | "name": "stdout"
861 | }
862 | ]
863 | },
864 | {
865 | "cell_type": "code",
866 | "metadata": {
867 | "id": "JMNgXeWwNx-N",
868 | "colab_type": "code",
869 | "colab": {
870 | "base_uri": "https://localhost:8080/",
871 | "height": 283
872 | },
873 | "outputId": "521241cd-9097-4098-da3a-71bc7d2e3c05"
874 | },
875 | "source": [
876 | "!pip install ml_metrics"
877 | ],
878 | "execution_count": 8,
879 | "outputs": [
880 | {
881 | "output_type": "stream",
882 | "text": [
883 | "Collecting ml_metrics\n",
884 | " Downloading https://files.pythonhosted.org/packages/c1/e7/c31a2dd37045a0c904bee31c2dbed903d4f125a6ce980b91bae0c961abb8/ml_metrics-0.1.4.tar.gz\n",
885 | "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from ml_metrics) (1.18.5)\n",
886 | "Requirement already satisfied: pandas in /usr/local/lib/python3.6/dist-packages (from ml_metrics) (1.0.5)\n",
887 | "Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from pandas->ml_metrics) (2.8.1)\n",
888 | "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas->ml_metrics) (2018.9)\n",
889 | "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.6/dist-packages (from python-dateutil>=2.6.1->pandas->ml_metrics) (1.15.0)\n",
890 | "Building wheels for collected packages: ml-metrics\n",
891 | " Building wheel for ml-metrics (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
892 | " Created wheel for ml-metrics: filename=ml_metrics-0.1.4-cp36-none-any.whl size=7850 sha256=4c96be29d4d35a67a4b7cfad5648a20405ba81588e3a54aafe24d966ed354a94\n",
893 | " Stored in directory: /root/.cache/pip/wheels/b3/61/2d/776be7b8a4f14c5db48c8e5451451cabc58dc6aa7ee3801163\n",
894 | "Successfully built ml-metrics\n",
895 | "Installing collected packages: ml-metrics\n",
896 | "Successfully installed ml-metrics-0.1.4\n"
897 | ],
898 | "name": "stdout"
899 | }
900 | ]
901 | },
902 | {
903 | "cell_type": "markdown",
904 | "metadata": {
905 | "id": "S0YvGjijGxET",
906 | "colab_type": "text"
907 | },
908 | "source": [
909 | "# Import the dependencies"
910 | ]
911 | },
912 | {
913 | "cell_type": "markdown",
914 | "metadata": {
915 | "id": "ypMu2vhSRZD2",
916 | "colab_type": "text"
917 | },
918 | "source": [
919 | "Import all the libraries neccesary to run the model."
920 | ]
921 | },
922 | {
923 | "cell_type": "code",
924 | "metadata": {
925 | "id": "jNaJbRoaa8ZU",
926 | "colab_type": "code",
927 | "colab": {}
928 | },
929 | "source": [
930 | "from comet_ml import Experiment as CometExperiment\n",
931 | "from comet_ml import ExistingExperiment as CometExistingExperiment\n",
932 | "from google.colab import drive\n",
933 | "import torch\n",
934 | "import copy\n",
935 | "import torch.nn as nn\n",
936 | "import torch.nn.functional as F\n",
937 | "from torch.utils.data import Dataset, DataLoader\n",
938 | "import numpy as np\n",
939 | "from omegaconf import OmegaConf\n",
940 | "from omegaconf.dictconfig import DictConfig\n",
941 | "import pandas as pd\n",
942 | "import time\n",
943 | "\n",
944 | "# from tqdm.notebook import trange, tqdm\n",
945 | "from pytorch_lightning.callbacks import ModelCheckpoint\n",
946 | "from pytorch_lightning.utilities import rank_zero_only\n",
947 | "from pytorch_lightning.logging import LightningLoggerBase\n",
948 | "from pytorch_lightning.loggers import CometLogger\n",
949 | "\n",
950 | "import os\n",
951 | "import pytorch_lightning as pl\n",
952 | "import pickle\n",
953 | "import adabound\n",
954 | "import ml_metrics as metrics\n",
955 | "import random\n",
956 | "import itertools\n",
957 | "from torchvision import transforms\n"
958 | ],
959 | "execution_count": 9,
960 | "outputs": []
961 | },
962 | {
963 | "cell_type": "markdown",
964 | "metadata": {
965 | "id": "pfDyM4E7G4L2",
966 | "colab_type": "text"
967 | },
968 | "source": [
969 | "# Connect to gDrive"
970 | ]
971 | },
972 | {
973 | "cell_type": "markdown",
974 | "metadata": {
975 | "id": "XLY4l7XzReoa",
976 | "colab_type": "text"
977 | },
978 | "source": [
979 | "Connect the notebook with the gDrive, which is essential to load and save data like dataset, checkpoints, and attention weights. "
980 | ]
981 | },
982 | {
983 | "cell_type": "code",
984 | "metadata": {
985 | "id": "hmQwtSRCbchS",
986 | "colab_type": "code",
987 | "colab": {
988 | "base_uri": "https://localhost:8080/",
989 | "height": 125
990 | },
991 | "outputId": "26c7f9ce-aa16-4b69-aea8-38817cb64731"
992 | },
993 | "source": [
994 | "drive.mount('/content/gdrive/')"
995 | ],
996 | "execution_count": 10,
997 | "outputs": [
998 | {
999 | "output_type": "stream",
1000 | "text": [
1001 | "Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly\n",
1002 | "\n",
1003 | "Enter your authorization code:\n",
1004 | "··········\n",
1005 | "Mounted at /content/gdrive/\n"
1006 | ],
1007 | "name": "stdout"
1008 | }
1009 | ]
1010 | },
1011 | {
1012 | "cell_type": "markdown",
1013 | "metadata": {
1014 | "id": "h9MDWroJSkhM",
1015 | "colab_type": "text"
1016 | },
1017 | "source": [
1018 | "# Dataset and Transformations"
1019 | ]
1020 | },
1021 | {
1022 | "cell_type": "markdown",
1023 | "metadata": {
1024 | "id": "uIDIGrXVSzdy",
1025 | "colab_type": "text"
1026 | },
1027 | "source": [
1028 | "This is important to load the k different partitions which are obtained using cross validation k-fold."
1029 | ]
1030 | },
1031 | {
1032 | "cell_type": "code",
1033 | "metadata": {
1034 | "id": "Mir9w6zDzkSN",
1035 | "colab_type": "code",
1036 | "colab": {}
1037 | },
1038 | "source": [
1039 | "train_path = '/content/gdrive/My Drive/Proyecto_RecSys/dataset/train_splits.pkl'\n",
1040 | "test_path = '/content/gdrive/My Drive/Proyecto_RecSys/dataset/test_splits.pkl'\n",
1041 | "champion_path = '/content/gdrive/My Drive/Proyecto_RecSys/dataset/champion_types.pkl'"
1042 | ],
1043 | "execution_count": 11,
1044 | "outputs": []
1045 | },
1046 | {
1047 | "cell_type": "code",
1048 | "metadata": {
1049 | "id": "yDP51QcmNgCd",
1050 | "colab_type": "code",
1051 | "cellView": "form",
1052 | "colab": {}
1053 | },
1054 | "source": [
1055 | "#@title Cargar listas de particiones\n",
1056 | "with open(train_path, 'rb') as handle:\n",
1057 | " list_trainset = pickle.load(handle)\n",
1058 | "\n",
1059 | "with open(test_path, 'rb') as handle:\n",
1060 | " list_testset = pickle.load(handle)\n",
1061 | "\n",
1062 | "with open(champion_path, 'rb') as handle:\n",
1063 | " champion_types = pickle.load(handle)"
1064 | ],
1065 | "execution_count": 12,
1066 | "outputs": []
1067 | },
1068 | {
1069 | "cell_type": "code",
1070 | "metadata": {
1071 | "id": "2zbdC7M1S3kh",
1072 | "colab_type": "code",
1073 | "colab": {}
1074 | },
1075 | "source": [
1076 | "def get_partition(id_split, list_splits = list_trainset):\n",
1077 | " df = list_splits[id_split]\n",
1078 | " null_registers = df.loc[(df.item1 == 0) & (df.item2 == 0) & (df.item3 == 0) & (df.item4 == 0) & (df.item5 == 0) & (df.item6 == 0)]\n",
1079 | " match_to_del = list(set(null_registers['matchid']))\n",
1080 | " df = df[~df.matchid.isin(match_to_del)]\n",
1081 | " return df"
1082 | ],
1083 | "execution_count": 13,
1084 | "outputs": []
1085 | },
1086 | {
1087 | "cell_type": "markdown",
1088 | "metadata": {
1089 | "id": "dK-OEg1dTYJw",
1090 | "colab_type": "text"
1091 | },
1092 | "source": [
1093 | "These transformations rote randomly the order between the two teams, and the champions within each team."
1094 | ]
1095 | },
1096 | {
1097 | "cell_type": "code",
1098 | "metadata": {
1099 | "id": "muCDKh3n524z",
1100 | "colab_type": "code",
1101 | "colab": {}
1102 | },
1103 | "source": [
1104 | "class RandomSort_Team(object):\n",
1105 | " \"\"\"Crop randomly the image in a sample.\n",
1106 | "\n",
1107 | " Args:\n",
1108 | " output_size (tuple or int): Desired output size. If int, square crop\n",
1109 | " is made.\n",
1110 | " \"\"\"\n",
1111 | " \n",
1112 | " def get_random_sample(self, sample):\n",
1113 | " x, y = sample\n",
1114 | "\n",
1115 | " ids_teams_1 = [x for x in range(5)]\n",
1116 | " ids_teams_2 = [x for x in range(5,10)]\n",
1117 | "\n",
1118 | " ids_team_t = [ids_teams_1, ids_teams_2]\n",
1119 | "\n",
1120 | " ids_teams = [1, 0]\n",
1121 | " #ids_teams = [x for x in range(2)]\n",
1122 | " #random.shuffle(ids_teams)\n",
1123 | "\n",
1124 | " ids_team_t = [ids_team_t[i] for i in ids_teams]\n",
1125 | " \n",
1126 | " ids_team_t = list(itertools.chain.from_iterable(ids_team_t))\n",
1127 | "\n",
1128 | " x['champions'] = x['champions'][ids_team_t]\n",
1129 | " x['role'] = x['role'][ids_team_t]\n",
1130 | " x['type'] = x['type'][ids_team_t,:]\n",
1131 | "\n",
1132 | " y['items'] = y['items'][ids_team_t,:]\n",
1133 | "\n",
1134 | " if ids_teams == [1, 0]:\n",
1135 | " y['win'] = torch.tensor(1) - y['win']\n",
1136 | " \n",
1137 | " return x, y\n",
1138 | "\n",
1139 | " def __call__(self, sample_list):\n",
1140 | " list_x_champions = []\n",
1141 | " list_x_role = []\n",
1142 | " list_x_type = []\n",
1143 | " list_y_items = []\n",
1144 | " list_y_win = []\n",
1145 | " x_old, y_old = sample_list\n",
1146 | " if isinstance(x_old, (list)) and isinstance(y_old, (list)):\n",
1147 | " for i in range(len(x_old)):\n",
1148 | " list_x_champions.append(x_old[i]['champions'])\n",
1149 | " list_x_role.append(x_old[i]['role'])\n",
1150 | " list_x_type.append(x_old[i]['type'])\n",
1151 | " list_y_items.append(y_old[i]['items'])\n",
1152 | " list_y_win.append(y_old[i]['win'])\n",
1153 | " sample = x_old[i], y_old[i]\n",
1154 | " x, y = self.get_random_sample(sample)\n",
1155 | " list_x_champions.append(x['champions'])\n",
1156 | " list_x_role.append(x['role'])\n",
1157 | " list_x_type.append(x['type'])\n",
1158 | " list_y_items.append(y['items'])\n",
1159 | " list_y_win.append(y['win'])\n",
1160 | " else:\n",
1161 | " list_x_champions.append(x_old['champions'])\n",
1162 | " list_x_role.append(x_old['role'])\n",
1163 | " list_x_type.append(x_old['type'])\n",
1164 | " list_y_items.append(y_old['items'])\n",
1165 | " list_y_win.append(y_old['win'])\n",
1166 | " sample = x_old, y_old\n",
1167 | " x, y = self.get_random_sample(sample)\n",
1168 | " list_x_champions.append(x['champions'])\n",
1169 | " list_x_role.append(x['role'])\n",
1170 | " list_x_type.append(x['type'])\n",
1171 | " list_y_items.append(y['items'])\n",
1172 | " list_y_win.append(y['win'])\n",
1173 | " new_x = {\n",
1174 | " 'champions': torch.stack(list_x_champions, dim=0),\n",
1175 | " 'role': torch.stack(list_x_role, dim=0),\n",
1176 | " 'type': torch.stack(list_x_type, dim=0)\n",
1177 | " }\n",
1178 | " new_y = {\n",
1179 | " 'items': torch.stack(list_y_items, dim=0),\n",
1180 | " 'win': torch.stack(list_y_win, dim=0)\n",
1181 | " }\n",
1182 | " return new_x, new_y\n"
1183 | ],
1184 | "execution_count": 14,
1185 | "outputs": []
1186 | },
1187 | {
1188 | "cell_type": "code",
1189 | "metadata": {
1190 | "id": "ozFSmovk06GG",
1191 | "colab_type": "code",
1192 | "colab": {}
1193 | },
1194 | "source": [
1195 | "class RandomSort_Part(object):\n",
1196 | " \"\"\"Crop randomly the image in a sample.\n",
1197 | "\n",
1198 | " Args:\n",
1199 | " output_size (tuple or int): Desired output size. If int, square crop\n",
1200 | " is made.\n",
1201 | " \"\"\"\n",
1202 | " \n",
1203 | "\n",
1204 | " def __call__(self, sample):\n",
1205 | "\n",
1206 | " list_t_x = []\n",
1207 | " list_t_y = []\n",
1208 | " x, y = sample\n",
1209 | "\n",
1210 | " list_t_x.append(x)\n",
1211 | " list_t_y.append(y)\n",
1212 | "\n",
1213 | " ids_team_1 = [x for x in range(5)]\n",
1214 | " ids_team_2 = [x for x in range(5,10)]\n",
1215 | " random.shuffle(ids_team_1)\n",
1216 | " random.shuffle(ids_team_2)\n",
1217 | "\n",
1218 | " ids_match = ids_team_1\n",
1219 | " ids_match.extend(ids_team_2)\n",
1220 | " \n",
1221 | " x['champions'] = x['champions'][ids_match]\n",
1222 | " x['role'] = x['role'][ids_match]\n",
1223 | " x['type'] = x['type'][ids_match,:]\n",
1224 | "\n",
1225 | " y['items'] = y['items'][ids_match,:]\n",
1226 | "\n",
1227 | " list_t_x.append(x)\n",
1228 | " list_t_y.append(y)\n",
1229 | "\n",
1230 | " return list_t_x, list_t_y"
1231 | ],
1232 | "execution_count": 15,
1233 | "outputs": []
1234 | },
1235 | {
1236 | "cell_type": "code",
1237 | "metadata": {
1238 | "id": "i3KCpfKDPX3D",
1239 | "colab_type": "code",
1240 | "colab": {}
1241 | },
1242 | "source": [
1243 | "class LolDataset(Dataset):\n",
1244 | " def __init__(self, data, transform=None):\n",
1245 | " # cargar el dataset\n",
1246 | " #self.matches = self._load_matches(path)\n",
1247 | " self.matches = data\n",
1248 | " # comprobar si existe el .pkl con los diccionarios\n",
1249 | "\n",
1250 | " # else:\n",
1251 | " # extraer info. del dataframe\n",
1252 | " self.champions = set(self.matches['championid'])\n",
1253 | " self.roles = set(self.matches['position-role'])\n",
1254 | " self.matches_id = list(set(self.matches['matchid']))\n",
1255 | " self.items = self.matches['item1']\n",
1256 | " self.items.append(self.matches['item2'])\n",
1257 | " self.items.append(self.matches['item3'])\n",
1258 | " self.items.append(self.matches['item4'])\n",
1259 | " self.items.append(self.matches['item5'])\n",
1260 | " self.items.append(self.matches['item6'])\n",
1261 | " items = set(self.items)\n",
1262 | " self.items = {i for i in items if i != 0}\n",
1263 | " self.champion_types = champion_types\n",
1264 | " list_champion_types = []\n",
1265 | " for k,v in champion_types.items():\n",
1266 | " list_champion_types.extend(v)\n",
1267 | " \n",
1268 | " self.set_champ_type = set(list_champion_types)\n",
1269 | "\n",
1270 | " # crear diccionarios token2id y id2token\n",
1271 | " self.champions_token2id, self.champions_id2token = self._token_dict(self.champions)\n",
1272 | " self.roles_token2id, self.roles_id2token = self._token_dict(self.roles)\n",
1273 | " self.items_token2id, self.items_id2token = self._token_dict(self.items)\n",
1274 | " self.types_token2id, self.types_id2token = self._token_dict(self.set_champ_type)\n",
1275 | "\n",
1276 | " self.transform = transform\n",
1277 | "\n",
1278 | " def _load_matches(self, path):\n",
1279 | " data_matches = pd.read_csv(path) \n",
1280 | " return data_matches\n",
1281 | "\n",
1282 | " def _token_dict(self, data):\n",
1283 | " token2id = {}\n",
1284 | " id2token = {}\n",
1285 | " for i, j in enumerate(data):\n",
1286 | " token2id.update({j:i})\n",
1287 | " id2token.update({i:j})\n",
1288 | "\n",
1289 | " return token2id, id2token\n",
1290 | "\n",
1291 | " def _tokens2ids(self, token2id, tokens):\n",
1292 | " ids = []\n",
1293 | " for token in tokens:\n",
1294 | " ids.append(token2id[token])\n",
1295 | " \n",
1296 | " return ids\n",
1297 | "\n",
1298 | " def _tokens2ids_items(self, token2id, tokens):\n",
1299 | " #items_vecs = []\n",
1300 | " item_vec = np.zeros((len(token2id)))\n",
1301 | " for token in tokens:\n",
1302 | " if token in token2id: \n",
1303 | " item_vec[token2id[token]] = 1\n",
1304 | " #items_vecs.append(item_vec)\n",
1305 | " \n",
1306 | " return item_vec\n",
1307 | "\n",
1308 | " def _build_dict(self, match):\n",
1309 | " # sacar en orden los campeones de la partida\n",
1310 | " champion_tokens = list(match['championid'])\n",
1311 | " champions_ids = self._tokens2ids(self.champions_token2id, champion_tokens)\n",
1312 | "\n",
1313 | " # sacar en orden los items de la partida\n",
1314 | " #items_tokens = match['championid']\n",
1315 | " #items_ids = self._tokens2ids(self.items_token2id, items_tokens)\n",
1316 | " # sacar en orden los roles de la partida\n",
1317 | " role_tokens = list(match['position-role'])\n",
1318 | " role_ids = self._tokens2ids(self.roles_token2id, role_tokens)\n",
1319 | " list_win = list(match['win'])[4:6]\n",
1320 | " \n",
1321 | " list_win = np.array(list_win)\n",
1322 | " num_win = np.argsort(list_win)\n",
1323 | " num_win = num_win[len(num_win)-1]\n",
1324 | "\n",
1325 | " list_part_items = []\n",
1326 | " list_types = []\n",
1327 | " items_list = ['item1','item2','item3','item4','item5','item6']\n",
1328 | " for id_champ in champion_tokens:\n",
1329 | " champ_atr = match[match.championid == id_champ]\n",
1330 | " items = champ_atr[items_list]\n",
1331 | " items_tokens = list(items.iloc[0, :])\n",
1332 | " items_ids = self._tokens2ids_items(self.items_token2id, items_tokens)\n",
1333 | " list_part_items.append(items_ids)\n",
1334 | "\n",
1335 | " type_champ = self.champion_types[id_champ]\n",
1336 | " type_ids = self._tokens2ids(self.types_token2id, type_champ)\n",
1337 | " list_types.append(type_ids)\n",
1338 | "\n",
1339 | " # construir 5 veces 0s y 5 veces 1s\n",
1340 | " #team_ids = \n",
1341 | " x = {\n",
1342 | " 'champions': torch.from_numpy(np.array(champions_ids)),\n",
1343 | " 'role': torch.from_numpy(np.array(role_ids)),\n",
1344 | " 'type': torch.from_numpy(np.array(list_types))\n",
1345 | " }\n",
1346 | " y= {\n",
1347 | " 'items': torch.from_numpy(np.array(list_part_items)),\n",
1348 | " 'win': torch.from_numpy(np.array(num_win))\n",
1349 | " }\n",
1350 | " \n",
1351 | " return x, y\n",
1352 | "\n",
1353 | " def __getitem__(self, idx): \n",
1354 | " # idx es el match_id en este caso\n",
1355 | " # la función debiera retornar la info de cada partida\n",
1356 | " # buscar idx de la partida en mi estructura, y retornar los diccionarios con los atributos\n",
1357 | " id_match = self.matches_id[idx]\n",
1358 | " match = self.matches[(self.matches.matchid == id_match)]\n",
1359 | " x, y = self._build_dict(match) # entrega un df de la partida según el idx\n",
1360 | " if self.transform:\n",
1361 | " sample = x, y\n",
1362 | " x, y = self.transform(sample)\n",
1363 | " return x, y # el item per sé, la partida con todas sus características\n",
1364 | "\n",
1365 | " def __len__(self):\n",
1366 | " return len(self.matches_id)\n"
1367 | ],
1368 | "execution_count": 16,
1369 | "outputs": []
1370 | },
1371 | {
1372 | "cell_type": "markdown",
1373 | "metadata": {
1374 | "id": "UIm1_KUCUNB0",
1375 | "colab_type": "text"
1376 | },
1377 | "source": [
1378 | "# Model"
1379 | ]
1380 | },
1381 | {
1382 | "cell_type": "markdown",
1383 | "metadata": {
1384 | "id": "qr3TZbrnUg2H",
1385 | "colab_type": "text"
1386 | },
1387 | "source": [
1388 | "## Transformer encoder modified to obtain the attention weights"
1389 | ]
1390 | },
1391 | {
1392 | "cell_type": "code",
1393 | "metadata": {
1394 | "colab_type": "code",
1395 | "id": "HQvpd9lxxQqI",
1396 | "colab": {}
1397 | },
1398 | "source": [
1399 | "class TransformerEncoder(nn.Module):\n",
1400 | " \"\"\"TransformerEncoder is a stack of N encoder layers\n",
1401 | "\n",
1402 | " Args:\n",
1403 | " encoder_layer: an instance of the TransformerEncoderLayer() class (required).\n",
1404 | " num_layers: the number of sub-encoder-layers in the encoder (required).\n",
1405 | " norm: the layer normalization component (optional).\n",
1406 | "\n",
1407 | " Examples::\n",
1408 | " >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)\n",
1409 | " >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)\n",
1410 | " >>> src = torch.rand(10, 32, 512)\n",
1411 | " >>> out = transformer_encoder(src)\n",
1412 | " \"\"\"\n",
1413 | "\n",
1414 | " def __init__(self, encoder_layer, num_layers, norm=None):\n",
1415 | " super(TransformerEncoder, self).__init__()\n",
1416 | " self.layers = _get_clones(encoder_layer, num_layers)\n",
1417 | " self.num_layers = num_layers\n",
1418 | " self.norm = norm\n",
1419 | "\n",
1420 | " def forward(self, src, mask=None, src_key_padding_mask=None):\n",
1421 | " \"\"\"Pass the input through the endocder layers in turn.\n",
1422 | "\n",
1423 | " Args:\n",
1424 | " src: the sequnce to the encoder (required).\n",
1425 | " mask: the mask for the src sequence (optional).\n",
1426 | " src_key_padding_mask: the mask for the src keys per batch (optional).\n",
1427 | "\n",
1428 | " Shape:\n",
1429 | " see the docs in Transformer class.\n",
1430 | " \"\"\"\n",
1431 | " output = src\n",
1432 | " att_weights = []\n",
1433 | "\n",
1434 | " for i in range(self.num_layers):\n",
1435 | " output, attn_output_weights = self.layers[i](output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)\n",
1436 | " \n",
1437 | " att_weights.append(attn_output_weights)\n",
1438 | "\n",
1439 | " if self.norm:\n",
1440 | " output = self.norm(output)\n",
1441 | "\n",
1442 | " return output, att_weights\n",
1443 | "\n"
1444 | ],
1445 | "execution_count": 17,
1446 | "outputs": []
1447 | },
1448 | {
1449 | "cell_type": "code",
1450 | "metadata": {
1451 | "id": "eh2Nz2CWDypX",
1452 | "colab_type": "code",
1453 | "colab": {}
1454 | },
1455 | "source": [
1456 | "def _get_clones(module, N):\n",
1457 | " return nn.ModuleList([copy.deepcopy(module) for i in range(N)])\n",
1458 | "\n",
1459 | "\n",
1460 | "def _get_activation_fn(activation):\n",
1461 | " if activation == \"relu\":\n",
1462 | " return F.relu\n",
1463 | " elif activation == \"gelu\":\n",
1464 | " return F.gelu\n",
1465 | " else:\n",
1466 | " raise RuntimeError(\"activation should be relu/gelu, not %s.\" % activation)"
1467 | ],
1468 | "execution_count": 18,
1469 | "outputs": []
1470 | },
1471 | {
1472 | "cell_type": "code",
1473 | "metadata": {
1474 | "id": "zJCX_mqojgwQ",
1475 | "colab_type": "code",
1476 | "colab": {}
1477 | },
1478 | "source": [
1479 | "class TransformerEncoderLayer(nn.Module):\n",
1480 | " \"\"\"TransformerEncoderLayer is made up of self-attn and feedforward network.\n",
1481 | " This standard encoder layer is based on the paper \"Attention Is All You Need\".\n",
1482 | " Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,\n",
1483 | " Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in\n",
1484 | " Neural Information Processing Systems, pages 6000-6010. Users may modify or implement\n",
1485 | " in a different way during application.\n",
1486 | "\n",
1487 | " Args:\n",
1488 | " d_model: the number of expected features in the input (required).\n",
1489 | " nhead: the number of heads in the multiheadattention models (required).\n",
1490 | " dim_feedforward: the dimension of the feedforward network model (default=2048).\n",
1491 | " dropout: the dropout value (default=0.1).\n",
1492 | " activation: the activation function of intermediate layer, relu or gelu (default=relu).\n",
1493 | "\n",
1494 | " Examples::\n",
1495 | " >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)\n",
1496 | " >>> src = torch.rand(10, 32, 512)\n",
1497 | " >>> out = encoder_layer(src)\n",
1498 | " \"\"\"\n",
1499 | "\n",
1500 | " def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=\"relu\"):\n",
1501 | " super(TransformerEncoderLayer, self).__init__()\n",
1502 | " self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n",
1503 | " # Implementation of Feedforward model\n",
1504 | " self.linear1 = nn.Linear(d_model, dim_feedforward)\n",
1505 | " self.dropout = nn.Dropout(dropout)\n",
1506 | " self.linear2 = nn.Linear(dim_feedforward, d_model)\n",
1507 | "\n",
1508 | " self.norm1 = nn.LayerNorm(d_model)\n",
1509 | " self.norm2 = nn.LayerNorm(d_model)\n",
1510 | " self.dropout1 = nn.Dropout(dropout)\n",
1511 | " self.dropout2 = nn.Dropout(dropout)\n",
1512 | "\n",
1513 | " self.activation = _get_activation_fn(activation)\n",
1514 | "\n",
1515 | " def forward(self, src, src_mask=None, src_key_padding_mask=None):\n",
1516 | " \"\"\"Pass the input through the endocder layer.\n",
1517 | "\n",
1518 | " Args:\n",
1519 | " src: the sequnce to the encoder layer (required).\n",
1520 | " src_mask: the mask for the src sequence (optional).\n",
1521 | " src_key_padding_mask: the mask for the src keys per batch (optional).\n",
1522 | "\n",
1523 | " Shape:\n",
1524 | " see the docs in Transformer class.\n",
1525 | " \"\"\"\n",
1526 | " src2, attn_output_weights = self.self_attn(src, src, src, attn_mask=src_mask,\n",
1527 | " key_padding_mask=src_key_padding_mask)\n",
1528 | " src = src + self.dropout1(src2)\n",
1529 | " src = self.norm1(src)\n",
1530 | " if hasattr(self, \"activation\"):\n",
1531 | " src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))\n",
1532 | " else: # for backward compatibility\n",
1533 | " src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))\n",
1534 | " src = src + self.dropout2(src2)\n",
1535 | " src = self.norm2(src)\n",
1536 | " return src, attn_output_weights"
1537 | ],
1538 | "execution_count": 19,
1539 | "outputs": []
1540 | },
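1541 | {
1542 | "cell_type": "markdown",
1543 | "metadata": {
1544 | "id": "addedEncSketchMd",
1545 | "colab_type": "text"
1546 | },
1547 | "source": [
1548 | "The next cell is an added sanity-check sketch (not part of the original experiments; the toy dimensions are assumptions): the modified stack returns the encoded sequence plus one attention-weight tensor per layer, each averaged over heads with shape (batch, seq, seq)."
1549 | ]
1550 | },
1551 | {
1552 | "cell_type": "code",
1553 | "metadata": {
1554 | "id": "addedEncSketchCode",
1555 | "colab_type": "code",
1556 | "colab": {}
1557 | },
1558 | "source": [
1559 | "# Added illustrative sketch; toy sizes, not the trained configuration\n",
1560 | "enc_layer = TransformerEncoderLayer(d_model=16, nhead=2, dim_feedforward=32, dropout=0.0)\n",
1561 | "enc = TransformerEncoder(enc_layer, num_layers=2)\n",
1562 | "src = torch.rand(10, 4, 16)  # (seq, batch, emb)\n",
1563 | "out, att = enc(src)\n",
1564 | "print(out.shape)               # torch.Size([10, 4, 16])\n",
1565 | "print(len(att), att[0].shape)  # 2 torch.Size([4, 10, 10]), one map per layer"
1566 | ],
1567 | "execution_count": null,
1568 | "outputs": []
1569 | },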
1541 | {
1542 | "cell_type": "markdown",
1543 | "metadata": {
1544 | "id": "pwRy106QU6sH",
1545 | "colab_type": "text"
1546 | },
1547 | "source": [
1548 | "## Auxiliary Task Classes"
1549 | ]
1550 | },
1551 | {
1552 | "cell_type": "code",
1553 | "metadata": {
1554 | "id": "RNyvw9ynsLq3",
1555 | "colab_type": "code",
1556 | "colab": {}
1557 | },
1558 | "source": [
1559 | "def getItems(gt_items, table_emb, num_items, emb_dim):\n",
1560 | " list_match = []\n",
1561 | " device = gt_items.device\n",
1562 | " for i in range(gt_items.size(0)):\n",
1563 | " match = gt_items[i,:,:]\n",
1564 | " list_part_item = []\n",
1565 | " for j in range(gt_items.size(1)):\n",
1566 | " participant_items = match[j,:]\n",
1567 | " sum_k = torch.sum(participant_items, dim = 0).item()\n",
1568 | " if int(sum_k) > 0:\n",
1569 | " _, pos_items = torch.topk(participant_items, k = int(sum_k), dim = 0)\n",
1570 | " items_emb = table_emb(pos_items)\n",
1571 | " items_emb = torch.mean(items_emb, dim = 0)\n",
1572 | " list_part_item.append(items_emb)\n",
1573 | " else:\n",
1574 | " list_part_item.append(torch.zeros(emb_dim).to(device))\n",
1575 | " team_item_emb = torch.stack(list_part_item)\n",
1576 | " list_match.append(team_item_emb)\n",
1577 | " return torch.stack(list_match)"
1578 | ],
1579 | "execution_count": 20,
1580 | "outputs": []
1581 | },
1582 | {
1583 | "cell_type": "code",
1584 | "metadata": {
1585 | "id": "KahjODMTdPS4",
1586 | "colab_type": "code",
1587 | "colab": {}
1588 | },
1589 | "source": [
1590 | "class WinEncoder(nn.Module):\n",
1591 | " def __init__(self, model_dim, n_items):\n",
1592 | " super(WinEncoder, self).__init__()\n",
1593 | " self.proj_win = nn.Linear(4*model_dim, 2)\n",
1594 | " self.embeddings_table_items = nn.Embedding(num_embeddings = n_items, embedding_dim = model_dim)\n",
1595 | " self.n_items = n_items\n",
1596 | " self.model_dim = model_dim\n",
1597 | " self.init_weights()\n",
1598 | "\n",
1599 | " def init_weights(self):\n",
1600 | " initrange = 0.1\n",
1601 | " self.proj_win.bias.data.zero_()\n",
1602 | " self.proj_win.weight.data.uniform_(-initrange, initrange)\n",
1603 | " self.embeddings_table_items.weight.data.uniform_(-initrange, initrange)\n",
1604 | "\n",
1605 | " def forward(self, att_match, item_list):\n",
1606 | " # att_match size (Batch, Seq, Emb)\n",
1607 | " # item_list size (Batch, Seq, Num_items, Emb)\n",
1608 | " att_item_team_1, att_item_team_2 = torch.chunk(att_match, 2, dim=1)\n",
1609 | " items_team_1, items_team_2 = torch.chunk(item_list, 2, dim=1)\n",
1610 | "\n",
1611 | " items_team_1 = getItems(items_team_1, self.embeddings_table_items, self.n_items,self.model_dim)\n",
1612 | " items_team_2 = getItems(items_team_2, self.embeddings_table_items, self.n_items,self.model_dim)\n",
1613 | "\n",
1614 | " att_item_team_1 = torch.mean(att_item_team_1, dim=1)\n",
1615 | " att_item_team_1 = F.relu(att_item_team_1)\n",
1616 | " att_item_team_1 = (att_item_team_1 / att_item_team_1.max())\n",
1617 | " items_team_1 = torch.mean(items_team_1, dim=1)\n",
1618 | " items_team_1 = F.relu(items_team_1)\n",
1619 | " items_team_1 = (items_team_1 / items_team_1.max())\n",
1620 | "\n",
1621 | " att_item_team_2 = torch.mean(att_item_team_2, dim=1)\n",
1622 | " att_item_team_2 = F.relu(att_item_team_2)\n",
1623 | " att_item_team_2 = (att_item_team_2 / att_item_team_2.max())\n",
1624 | " items_team_2 = torch.mean(items_team_2, dim=1)\n",
1625 | " items_team_2 = F.relu(items_team_2)\n",
1626 | " items_team_2 = (items_team_2 / items_team_2.max())\n",
1627 | "\n",
1628 | " att_item_team_1 = torch.cat((att_item_team_1, items_team_1), 1)\n",
1629 | " att_item_team_2 = torch.cat((att_item_team_2, items_team_2), 1)\n",
1630 | " proj_win_team = torch.cat((att_item_team_1, att_item_team_2), 1)\n",
1631 | " win_emb = self.proj_win(F.relu(proj_win_team))\n",
1632 | "\n",
1633 | " return win_emb"
1634 | ],
1635 | "execution_count": 21,
1636 | "outputs": []
1637 | },
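1638 | {
1639 | "cell_type": "markdown",
1640 | "metadata": {
1641 | "id": "addedWinSketchMd",
1642 | "colab_type": "text"
1643 | },
1644 | "source": [
1645 | "Added shape-level sketch (the toy sizes are assumptions, not the paper's settings): `WinEncoder` pools each team's attended representations and its ground-truth item embeddings (via `getItems`), concatenates both teams, and projects to 2-way team-win logits."
1646 | ]
1647 | },
1648 | {
1649 | "cell_type": "code",
1650 | "metadata": {
1651 | "id": "addedWinSketchCode",
1652 | "colab_type": "code",
1653 | "colab": {}
1654 | },
1655 | "source": [
1656 | "# Added illustrative sketch: shape check for the auxiliary win head\n",
1657 | "model_dim, n_items = 8, 12\n",
1658 | "win_enc = WinEncoder(model_dim, n_items)\n",
1659 | "att_match = torch.rand(2, 10, model_dim)   # (batch, participants, emb)\n",
1660 | "gt_items = torch.zeros(2, 10, n_items)     # multi-hot purchased items\n",
1661 | "gt_items[:, :, :3] = 1\n",
1662 | "print(win_enc(att_match, gt_items).shape)  # torch.Size([2, 2]), team-win logits"
1663 | ],
1664 | "execution_count": null,
1665 | "outputs": []
1666 | },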
1638 | {
1639 | "cell_type": "code",
1640 | "metadata": {
1641 | "id": "pdzByFr9ZyB7",
1642 | "colab_type": "code",
1643 | "colab": {}
1644 | },
1645 | "source": [
1646 | "def getTensorPredItem(items_logits):\n",
1647 | " pred_items = torch.zeros(items_logits.size())\n",
1648 | " for i in range(items_logits.size(0)):\n",
1649 | " for j in range(items_logits.size(1)):\n",
1650 | " _,pos_items = torch.topk(items_logits[i,j,:],k = 6,dim=0)\n",
1651 | " pred_items[i,j,pos_items] = 1\n",
1652 | " return pred_items"
1653 | ],
1654 | "execution_count": 22,
1655 | "outputs": []
1656 | },
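1657 | {
1658 | "cell_type": "markdown",
1659 | "metadata": {
1660 | "id": "addedPredItemSketchMd",
1661 | "colab_type": "text"
1662 | },
1663 | "source": [
1664 | "Added sketch (toy numbers assumed): `getTensorPredItem` binarizes the logits by keeping the top-6 items per participant, matching the six item slots of a match."
1665 | ]
1666 | },
1667 | {
1668 | "cell_type": "code",
1669 | "metadata": {
1670 | "id": "addedPredItemSketchCode",
1671 | "colab_type": "code",
1672 | "colab": {}
1673 | },
1674 | "source": [
1675 | "# Added sketch: every participant gets exactly 6 predicted items\n",
1676 | "logits = torch.rand(1, 10, 30)  # toy (batch, participants, n_items)\n",
1677 | "preds = getTensorPredItem(logits)\n",
1678 | "print(preds.sum(dim=2))  # a (1, 10) tensor of sixes"
1679 | ],
1680 | "execution_count": null,
1681 | "outputs": []
1682 | },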
1657 | {
1658 | "cell_type": "markdown",
1659 | "metadata": {
1660 | "id": "9AlU_u42VG8A",
1661 | "colab_type": "text"
1662 | },
1663 | "source": [
1664 | "## Main Class of the proposed model"
1665 | ]
1666 | },
1667 | {
1668 | "cell_type": "code",
1669 | "metadata": {
1670 | "id": "xY-TRDX2BHAr",
1671 | "colab_type": "code",
1672 | "colab": {}
1673 | },
1674 | "source": [
1675 | "class TransformerLolRecommender(nn.Module):\n",
1676 | "\n",
1677 | " def __init__(self, n_role, n_champions, embeddings_size, nhead, n_items, n_type, nlayers = 1, nhid = 2048, dropout=0.5, aux_task = False, \n",
1678 | " learnable_team_emb = False):\n",
1679 | " super(TransformerLolRecommender, self).__init__()\n",
1680 | "\n",
1681 | " self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
1682 | " \n",
1683 | " self.embeddings_table_role = nn.Embedding(num_embeddings = n_role, embedding_dim = embeddings_size)\n",
1684 | " \n",
1685 | " self.embeddings_table_champion = nn.Embedding(num_embeddings = n_champions, embedding_dim = embeddings_size)\n",
1686 | "\n",
1687 | " self.embeddings_table_type = nn.Embedding(num_embeddings = n_type, embedding_dim = embeddings_size, padding_idx=0)\n",
1688 | " \n",
1689 | " self.learnable_team_emb = learnable_team_emb\n",
1690 | " if learnable_team_emb:\n",
1691 | " self.team_encoder = nn.Embedding(num_embeddings = 2, embedding_dim = embeddings_size)\n",
1692 | " else:\n",
1693 | " self.team_encoder = self.get_team_encoding(embeddings_size, 10)\n",
1694 | " \n",
1695 | " encoder_layers = TransformerEncoderLayer(embeddings_size, nhead, nhid, dropout)\n",
1696 | " self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)\n",
1697 | " \n",
1698 | " self.recommender = nn.Linear(embeddings_size, n_items)\n",
1699 | " self.pred_champ = nn.Linear(embeddings_size, n_champions)\n",
1700 | "\n",
1701 | " self.aux_task = aux_task\n",
1702 | "\n",
1703 | " if self.aux_task: \n",
1704 | " self.win_encoder = WinEncoder(embeddings_size, n_items)\n",
1705 | "\n",
1706 | " self.init_weights()\n",
1707 | " \n",
1708 | " def get_learnable_team_emb(self, num_batch):\n",
1709 | " emb_team_0 = self.team_encoder(torch.LongTensor([0]).to(self.device))\n",
1710 | " emb_team_0 = emb_team_0.expand(5, emb_team_0.size(1))\n",
1711 | " emb_team_1 = self.team_encoder(torch.LongTensor([1]).to(self.device))\n",
1712 | " emb_team_1 = emb_team_1.expand(5, emb_team_1.size(1))\n",
1713 | " emb_team = torch.cat([emb_team_0, emb_team_1], dim = 0)\n",
1714 | " emb_team = emb_team.unsqueeze(0).expand(num_batch, emb_team.size(0), emb_team.size(1))\n",
1715 | " return emb_team\n",
1716 | "\n",
1717 | " \n",
1718 | " def get_team_encoding(self, embedding_dim, num_champions = 10):\n",
1719 | " team_encoding = torch.zeros(num_champions, embedding_dim)\n",
1720 | " team_encoding[5:,:] = 1\n",
1721 | " return team_encoding.to(self.device)\n",
1722 | "\n",
1723 | " def init_weights(self):\n",
1724 | " initrange = 0.1\n",
1725 | " \n",
1726 | " self.embeddings_table_role.weight.data.uniform_(-initrange, initrange)\n",
1727 | " self.embeddings_table_champion.weight.data.uniform_(-initrange, initrange)\n",
1728 | " self.embeddings_table_type.weight.data.uniform_(-initrange, initrange)\n",
1729 | " \n",
1730 | " self.recommender.bias.data.zero_()\n",
1731 | " self.recommender.weight.data.uniform_(-initrange, initrange)\n",
1732 | "\n",
1733 | " self.pred_champ.bias.data.zero_()\n",
1734 | " self.pred_champ.weight.data.uniform_(-initrange, initrange)\n",
1735 | "\n",
1736 | " if self.learnable_team_emb:\n",
1737 | " self.team_encoder.weight.data.uniform_(-initrange, initrange)\n",
1738 | "\n",
1739 | " def forward(self, role, champion_id, types, items, win, enable_teacher_f):\n",
1740 | "\n",
1741 | " role_participants = self.embeddings_table_role(role)\n",
1742 | " id_participants = self.embeddings_table_champion(champion_id)\n",
1743 | " type_champ = self.embeddings_table_type(types)\n",
1744 | " type_champ = torch.sum(type_champ, dim =2)\n",
1745 | " batch_size = role_participants.size(0)\n",
1746 | " if self.learnable_team_emb:\n",
1747 | " team_participants = self.get_learnable_team_emb(batch_size)\n",
1748 | " else:\n",
1749 | " size_team_emb = self.team_encoder.size()\n",
1750 | " team_participants = self.team_encoder.unsqueeze(0).expand(batch_size, size_team_emb[0], size_team_emb[1])\n",
1751 | "\n",
1752 | " sel_champions = []\n",
1753 | " pos_champions = []\n",
1754 | " for i in range(win.size(0)):\n",
1755 | " id_el = random.randint(0,4)\n",
1756 | " pos_champions.append(id_el)\n",
1757 | " if win[i] != 0:\n",
1758 | " id_el = id_el + 5\n",
1759 | " sel_champion = champion_id[i,id_el]\n",
1760 | " id_participants[i,id_el,:] = 0\n",
1761 | " sel_champions.append(sel_champion)\n",
1762 | "\n",
1763 | " sel_champions = torch.stack(sel_champions)\n",
1764 | " # pos_champions = torch.stack(pos_champions)\n",
1765 | "\n",
1766 | " participants = role_participants + id_participants + team_participants + type_champ\n",
1767 | " # size (Seq, Batch, Emb)\n",
1768 | " participants = participants.permute(1,0,2)\n",
1769 | " # size (Seq, Batch, Emb)\n",
1770 | " output, att_weights = self.transformer_encoder(participants)\n",
1771 | " # size (Batch, Seq, Emb)\n",
1772 | " output = output.permute(1,0,2)\n",
1773 | " logits_items = self.recommender(output)\n",
1774 | "\n",
1775 | " output_obj = {\n",
1776 | " 'logits_items': logits_items,\n",
1777 | " 'att_weights': att_weights,\n",
1778 | " 'outputs': output,\n",
1779 | " 'sel_champions': sel_champions,\n",
1780 | " 'pos_champions': pos_champions\n",
1781 | " }\n",
1782 | "\n",
1783 | " if self.aux_task:\n",
1784 | " if enable_teacher_f: \n",
1785 | " items_used = items\n",
1786 | " else:\n",
1787 | " items_used = getTensorPredItem(logits_items).to(self.device)\n",
1788 | " logits_win = self.win_encoder(output, items_used)\n",
1789 | " output_obj['logits_win'] = logits_win\n",
1790 | "\n",
1791 | " return output_obj"
1792 | ],
1793 | "execution_count": 23,
1794 | "outputs": []
1795 | },
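1796 | {
1797 | "cell_type": "markdown",
1798 | "metadata": {
1799 | "id": "addedModelSketchMd",
1800 | "colab_type": "text"
1801 | },
1802 | "source": [
1803 | "Added dummy forward pass (all sizes here are illustrative assumptions, not the trained configuration): it checks that the recommender emits one item-logit vector per participant. Since the static team encoding is created on the detected device, the sketch keeps the model and inputs on that same device."
1804 | ]
1805 | },
1806 | {
1807 | "cell_type": "code",
1808 | "metadata": {
1809 | "id": "addedModelSketchCode",
1810 | "colab_type": "code",
1811 | "colab": {}
1812 | },
1813 | "source": [
1814 | "# Added illustrative sketch: toy batch through the recommender\n",
1815 | "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
1816 | "toy = TransformerLolRecommender(n_role=5, n_champions=20, embeddings_size=16, nhead=2,\n",
1817 | "                                n_items=30, n_type=6, nlayers=1, nhid=32, dropout=0.1).to(device)\n",
1818 | "role = torch.randint(0, 5, (2, 10), device=device)      # (batch, 10 participants)\n",
1819 | "champs = torch.randint(0, 20, (2, 10), device=device)\n",
1820 | "types = torch.randint(0, 6, (2, 10, 2), device=device)  # up to 2 type tags each\n",
1821 | "items = torch.zeros(2, 10, 30, device=device)           # multi-hot ground truth\n",
1822 | "win = torch.randint(0, 2, (2,), device=device)\n",
1823 | "out = toy(role, champs, types, items, win, enable_teacher_f=False)\n",
1824 | "print(out['logits_items'].shape)  # torch.Size([2, 10, 30])\n",
1825 | "print(len(out['att_weights']), out['att_weights'][0].shape)  # 1 torch.Size([2, 10, 10])"
1826 | ],
1827 | "execution_count": null,
1828 | "outputs": []
1829 | },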
1796 | {
1797 | "cell_type": "markdown",
1798 | "metadata": {
1799 | "id": "rwYoKWcsVqex",
1800 | "colab_type": "text"
1801 | },
1802 | "source": [
1803 | "# Logger and Checkpointer"
1804 | ]
1805 | },
1806 | {
1807 | "cell_type": "markdown",
1808 | "metadata": {
1809 | "id": "nrXJc5reVwyM",
1810 | "colab_type": "text"
1811 | },
1812 | "source": [
1813 | "These classes and methods are essential to log relevant information about the model and metrics in Coment. Likewise, they allow to save checkpoint in each epoch. "
1814 | ]
1815 | },
1816 | {
1817 | "cell_type": "code",
1818 | "metadata": {
1819 | "id": "D0JLqdBEV67X",
1820 | "colab_type": "code",
1821 | "colab": {}
1822 | },
1823 | "source": [
1824 | "def load_defaults(defaults_file):\n",
1825 | " return OmegaConf.load(defaults_file)\n",
1826 | "\n",
1827 | "\n",
1828 | "def load_config_file(config_file):\n",
1829 | " if not config_file:\n",
1830 | " return OmegaConf.create()\n",
1831 | " return OmegaConf.load(config_file)\n",
1832 | "\n",
1833 | "\n",
1834 | "def load_config(config_file, defaults_file):\n",
1835 | " defaults = load_defaults(defaults_file)\n",
1836 | " config = OmegaConf.merge(defaults, load_config_file(config_file))\n",
1837 | " config.merge_with_cli()\n",
1838 | " return config\n",
1839 | "\n",
1840 | "\n",
1841 | "def build_config(args):\n",
1842 | " return load_config(args.config_file, args.defaults_file)\n",
1843 | "\n",
1844 | "\n",
1845 | "def config_to_dict(cfg):\n",
1846 | " return dict(cfg)\n",
1847 | "\n",
1848 | "\n",
1849 | "def config_to_comet(cfg):\n",
1850 | " def _config_to_comet(cfg, local_dict, parent_str):\n",
1851 | " for key, value in cfg.items():\n",
1852 | " full_key = \"{}.{}\".format(parent_str, key)\n",
1853 | " if isinstance(value, (dict, DictConfig)):\n",
1854 | " _config_to_comet(value, local_dict, full_key)\n",
1855 | " else:\n",
1856 | " local_dict[full_key] = value\n",
1857 | "\n",
1858 | " local_dict = {}\n",
1859 | " for key, value in cfg.items():\n",
1860 | " if isinstance(value, (dict, DictConfig)):\n",
1861 | " _config_to_comet(value, local_dict, key)\n",
1862 | " else:\n",
1863 | " local_dict[key] = value\n",
1864 | " return local_dict"
1865 | ],
1866 | "execution_count": 24,
1867 | "outputs": []
1868 | },
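1869 | {
1870 | "cell_type": "markdown",
1871 | "metadata": {
1872 | "id": "addedConfigSketchMd",
1873 | "colab_type": "text"
1874 | },
1875 | "source": [
1876 | "Added example (toy dict, illustrative only): `config_to_comet` flattens nested config sections into dotted keys so Comet can log them as flat parameters."
1877 | ]
1878 | },
1879 | {
1880 | "cell_type": "code",
1881 | "metadata": {
1882 | "id": "addedConfigSketchCode",
1883 | "colab_type": "code",
1884 | "colab": {}
1885 | },
1886 | "source": [
1887 | "# Added sketch: nested sections become dotted keys\n",
1888 | "cfg_demo = {'model': {'nhead': 2, 'nlayers': 1}, 'seed': 1642}\n",
1889 | "print(config_to_comet(cfg_demo))\n",
1890 | "# {'model.nhead': 2, 'model.nlayers': 1, 'seed': 1642}"
1891 | ],
1892 | "execution_count": null,
1893 | "outputs": []
1894 | },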
1869 | {
1870 | "cell_type": "code",
1871 | "metadata": {
1872 | "id": "RWTpFcpaU50q",
1873 | "colab_type": "code",
1874 | "colab": {}
1875 | },
1876 | "source": [
1877 | "def get_checkpointer(save_path, metric_name='val_acc'):\n",
1878 | " if not os.path.exists(save_path):\n",
1879 | " os.makedirs(save_path)\n",
1880 | " return ModelCheckpoint(\n",
1881 | " filepath=save_path,\n",
1882 | " verbose=True,\n",
1883 | " monitor=metric_name,\n",
1884 | " mode='max',\n",
1885 | " )\n",
1886 | "\n",
1887 | "\n",
1888 | "# class CometLogger(LightningLoggerBase):\n",
1889 | "# # Thank you @ceyzaguirre4\n",
1890 | "# def __init__(self, config, *args, **kwargs):\n",
1891 | "# super().__init__()\n",
1892 | "# self.comet_exp = CometExperiment(*args, **kwargs)\n",
1893 | "# self.comet_exp.set_name(config['exp_name'])\n",
1894 | "# self.comet_exp.log_parameters(config)\n",
1895 | "# self.config = config\n",
1896 | "\n",
1897 | "# @rank_zero_only\n",
1898 | "# def log_hyperparams(self, params):\n",
1899 | "# self.comet_exp.log_parameters(config_to_comet(params))\n",
1900 | "\n",
1901 | "# @rank_zero_only\n",
1902 | "# def log_metrics(self, metrics, step):\n",
1903 | "# self.comet_exp.log_metrics(metrics)\n",
1904 | "\n",
1905 | "# @rank_zero_only\n",
1906 | "# def finalize(self, status):\n",
1907 | "# self.comet_exp.end()\n",
1908 | " \n",
1909 | "# def version(self):\n",
1910 | "# return self.config['exp']\n"
1911 | ],
1912 | "execution_count": 25,
1913 | "outputs": []
1914 | },
1915 | {
1916 | "cell_type": "markdown",
1917 | "metadata": {
1918 | "id": "5ktMqAUMWeEz",
1919 | "colab_type": "text"
1920 | },
1921 | "source": [
1922 | "# Metrics"
1923 | ]
1924 | },
1925 | {
1926 | "cell_type": "code",
1927 | "metadata": {
1928 | "id": "-TMnKThtdtb4",
1929 | "colab_type": "code",
1930 | "colab": {}
1931 | },
1932 | "source": [
1933 | "def recall_at_k(output, target, k = 6):\n",
1934 | " output_k, ind_k = torch.topk(output, k, dim = 1)\n",
1935 | " sum_recall = 0\n",
1936 | " num_part = output_k.size(0)\n",
1937 | " relevants = target.sum(dim = 1)\n",
1938 | " list_recall = []\n",
1939 | " for i in range(num_part):\n",
1940 | " target_k = target[i, ind_k[i,:]]\n",
1941 | " intersection = target_k.sum(dim = 0)\n",
1942 | " recall_n = intersection/relevants[i]\n",
1943 | " list_recall.append(recall_n)\n",
1944 | " sum_recall+=recall_n\n",
1945 | " \n",
1946 | " recall_avg = sum_recall/num_part\n",
1947 | " return recall_avg, num_part, list_recall\n"
1948 | ],
1949 | "execution_count": 26,
1950 | "outputs": []
1951 | },
1952 | {
1953 | "cell_type": "code",
1954 | "metadata": {
1955 | "id": "1xRNZaAJ_RZ8",
1956 | "colab_type": "code",
1957 | "colab": {}
1958 | },
1959 | "source": [
1960 | "def precision_at_k(r, k):\n",
1961 | " \"\"\"Score is precision @ k\n",
1962 | "\n",
1963 | " Relevance is binary (nonzero is relevant).\n",
1964 | "\n",
1965 | " >>> r = [0, 0, 1]\n",
1966 | " >>> precision_at_k(r, 1)\n",
1967 | " 0.0\n",
1968 | " >>> precision_at_k(r, 2)\n",
1969 | " 0.0\n",
1970 | " >>> precision_at_k(r, 3)\n",
1971 | " 0.33333333333333331\n",
1972 | " >>> precision_at_k(r, 4)\n",
1973 | " Traceback (most recent call last):\n",
1974 | " File \"\", line 1, in ?\n",
1975 | " ValueError: Relevance score length < k\n",
1976 | "\n",
1977 | "\n",
1978 | " Args:\n",
1979 | " r: Relevance scores (list or numpy) in rank order\n",
1980 | " (first element is the first item)\n",
1981 | "\n",
1982 | " Returns:\n",
1983 | " Precision @ k\n",
1984 | "\n",
1985 | " Raises:\n",
1986 | " ValueError: len(r) must be >= k\n",
1987 | " \"\"\"\n",
1988 | " assert k >= 1\n",
1989 | " r = np.asarray(r)[:k] != 0\n",
1990 | " if r.size != k:\n",
1991 | " raise ValueError('Relevance score length < k')\n",
1992 | " return np.mean(r)\n",
1993 | "\n",
1994 | "\n",
1995 | "def average_precision(r):\n",
1996 | " \"\"\"Score is average precision (area under PR curve)\n",
1997 | "\n",
1998 | " Relevance is binary (nonzero is relevant).\n",
1999 | "\n",
2000 | " >>> r = [1, 1, 0, 1, 0, 1, 0, 0, 0, 1]\n",
2001 | " >>> delta_r = 1. / sum(r)\n",
2002 | " >>> sum([sum(r[:x + 1]) / (x + 1.) * delta_r for x, y in enumerate(r) if y])\n",
2003 | " 0.7833333333333333\n",
2004 | " >>> average_precision(r)\n",
2005 | " 0.78333333333333333\n",
2006 | "\n",
2007 | " Args:\n",
2008 | " r: Relevance scores (list or numpy) in rank order\n",
2009 | " (first element is the first item)\n",
2010 | "\n",
2011 | " Returns:\n",
2012 | " Average precision\n",
2013 | " \"\"\"\n",
2014 | " r = np.asarray(r) != 0\n",
2015 | " out = [precision_at_k(r, k + 1) for k in range(r.size) if r[k]]\n",
2016 | " if not out:\n",
2017 | " return 0.\n",
2018 | " return np.mean(out)"
2019 | ],
2020 | "execution_count": 27,
2021 | "outputs": []
2022 | },
2023 | {
2024 | "cell_type": "code",
2025 | "metadata": {
2026 | "id": "HfusZ9qS_uSq",
2027 | "colab_type": "code",
2028 | "colab": {}
2029 | },
2030 | "source": [
2031 | "def map_at(output, target, k=6):\n",
2032 | " sum_ap = 0\n",
2033 | " num_part = output.size(0)\n",
2034 | " list_map = []\n",
2035 | " for i in range(num_part):\n",
2036 | " out_p = output[i,:]\n",
2037 | " target_p = target[i,:]\n",
2038 | " output_k, ind_k = torch.topk(out_p, k, dim = 0)\n",
2039 | " list_rel = target_p[ind_k].tolist()\n",
2040 | " ap_at = average_precision(list_rel)\n",
2041 | " list_map.append(ap_at) \n",
2042 | " sum_ap += ap_at\n",
2043 | " return sum_ap/num_part, list_map"
2044 | ],
2045 | "execution_count": 28,
2046 | "outputs": []
2047 | },
2048 | {
2049 | "cell_type": "code",
2050 | "metadata": {
2051 | "id": "CctZXsremks4",
2052 | "colab_type": "code",
2053 | "colab": {}
2054 | },
2055 | "source": [
2056 | "def calc_precision_multiclass(output, target, k = 6):\n",
2057 | " output_k, ind_k = torch.topk(output, k, dim = 1)\n",
2058 | " sum_prec = 0\n",
2059 | " num_part = output_k.size(0)\n",
2060 | " list_prec = []\n",
2061 | " for i in range(num_part):\n",
2062 | " target_k = target[i, ind_k[i,:]]\n",
2063 | " intersection = target_k.sum(dim = 0)\n",
2064 | " preci_n = intersection/k\n",
2065 | " list_prec.append(preci_n)\n",
2066 | " sum_prec+=preci_n\n",
2067 | " \n",
2068 | " prec_avg = sum_prec/num_part\n",
2069 | " return prec_avg, num_part, list_prec"
2070 | ],
2071 | "execution_count": 29,
2072 | "outputs": []
2073 | },
2074 | {
2075 | "cell_type": "code",
2076 | "metadata": {
2077 | "id": "mykasAMS0bnG",
2078 | "colab_type": "code",
2079 | "colab": {}
2080 | },
2081 | "source": [
2082 | "def f1_score(recall, precision):\n",
2083 | " f1 = 2 * ((precision * recall) / (precision + recall))\n",
2084 | " return f1"
2085 | ],
2086 | "execution_count": 30,
2087 | "outputs": []
2088 | },
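2089 | {
2090 | "cell_type": "markdown",
2091 | "metadata": {
2092 | "id": "addedMetricsSketchMd",
2093 | "colab_type": "text"
2094 | },
2095 | "source": [
2096 | "Added toy run of the metric helpers above (the numbers are purely illustrative): two participants, eight candidate items, three relevant items each."
2097 | ]
2098 | },
2099 | {
2100 | "cell_type": "code",
2101 | "metadata": {
2102 | "id": "addedMetricsSketchCode",
2103 | "colab_type": "code",
2104 | "colab": {}
2105 | },
2106 | "source": [
2107 | "# Added sketch: exercise the metrics on a tiny example\n",
2108 | "out = torch.rand(2, 8)  # scores for 2 participants over 8 items\n",
2109 | "gt = torch.zeros(2, 8)\n",
2110 | "gt[0, [0, 2, 5]] = 1    # 3 relevant items per participant\n",
2111 | "gt[1, [1, 3, 4]] = 1\n",
2112 | "rec6, _, _ = recall_at_k(out, gt, k=6)\n",
2113 | "prec6, _, _ = calc_precision_multiclass(out, gt, k=6)\n",
2114 | "map6, _ = map_at(out, gt, k=6)\n",
2115 | "print(rec6, prec6, map6, f1_score(rec6, prec6))"
2116 | ],
2117 | "execution_count": null,
2118 | "outputs": []
2119 | },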
2089 | {
2090 | "cell_type": "code",
2091 | "metadata": {
2092 | "id": "1zMdgUJslL4V",
2093 | "colab_type": "code",
2094 | "colab": {}
2095 | },
2096 | "source": [
2097 | "class AverageMeter(object):\n",
2098 | " \"\"\"Computes and stores the average and current value\n",
2099 | " Taken from PyTorch's examples.imagenet.main\n",
2100 | " \"\"\"\n",
2101 | " def __init__(self):\n",
2102 | " self.reset()\n",
2103 | "\n",
2104 | " def reset(self):\n",
2105 | " self.val = 0\n",
2106 | " self.avg = 0\n",
2107 | " self.sum = 0\n",
2108 | " self.count = 0\n",
2109 | "\n",
2110 | " def update(self, val, n=1):\n",
2111 | " self.val = val\n",
2112 | " self.sum += val * n\n",
2113 | " self.count += n\n",
2114 | " self.avg = self.sum / self.count"
2115 | ],
2116 | "execution_count": 31,
2117 | "outputs": []
2118 | },
2119 | {
2120 | "cell_type": "code",
2121 | "metadata": {
2122 | "id": "x01LspaO_A7f",
2123 | "colab_type": "code",
2124 | "colab": {}
2125 | },
2126 | "source": [
2127 | "def set_seed(seed, slow=False):\n",
2128 | " import random\n",
2129 | "\n",
2130 | " if torch.cuda.is_available():\n",
2131 | " torch.cuda.manual_seed(seed)\n",
2132 | "\n",
2133 | " torch.manual_seed(seed)\n",
2134 | " random.seed(seed)\n",
2135 | " np.random.seed(seed)\n",
2136 | "\n",
2137 | " if slow:\n",
2138 | " torch.backends.cudnn.deterministic = True\n",
2139 | " torch.backends.cudnn.benchmark = False"
2140 | ],
2141 | "execution_count": 32,
2142 | "outputs": []
2143 | },
2144 | {
2145 | "cell_type": "code",
2146 | "metadata": {
2147 | "id": "2URNBE4I7x7j",
2148 | "colab_type": "code",
2149 | "colab": {}
2150 | },
2151 | "source": [
2152 | "def get_winners(att_vec, gt_item, win_vec, pos_champions, outputs_log):\n",
2153 | " list_att = []\n",
2154 | " list_gt = []\n",
2155 | " list_cham = []\n",
2156 | " for i in range(att_vec.size(0)):\n",
2157 | " win = win_vec[i]\n",
2158 | " pos = pos_champions[i]\n",
2159 | " if win == 0:\n",
2160 | " a = list(range(0,5))\n",
2161 | " del a[pos]\n",
2162 | " att_vec_match = att_vec[i,a,:]\n",
2163 | " gt_match = gt_item[i,a, :]\n",
2164 | " list_cham.append(outputs_log[i,pos, :])\n",
2165 | " list_att.append(att_vec_match)\n",
2166 | " list_gt.append(gt_match) \n",
2167 | " else:\n",
2168 | " a = list(range(5,10))\n",
2169 | " del a[pos]\n",
2170 | " att_vec_match = att_vec[i,a,:]\n",
2171 | " gt_match = gt_item[i, a, :]\n",
2172 | " list_cham.append(outputs_log[i,pos + 5, :])\n",
2173 | " list_att.append(att_vec_match)\n",
2174 | " list_gt.append(gt_match)\n",
2175 | "\n",
2176 | " att_winners = torch.stack(list_att, dim=0)\n",
2177 | " gt_winners = torch.stack(list_gt, dim=0)\n",
2178 | " att_cham = torch.stack(list_cham, dim=0)\n",
2179 | " return att_winners, gt_winners, att_cham\n",
2180 | " \n",
2181 | "\n"
2182 | ],
2183 | "execution_count": 33,
2184 | "outputs": []
2185 | },
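2186 | {
2187 | "cell_type": "markdown",
2188 | "metadata": {
2189 | "id": "addedWinnersSketchMd",
2190 | "colab_type": "text"
2191 | },
2192 | "source": [
2193 | "Added shape sketch for `get_winners` (toy sizes assumed): for each match it keeps the winning team's four unmasked rows of item logits and ground truth, plus the masked champion's encoder output used by the champion-prediction auxiliary loss."
2194 | ]
2195 | },
2196 | {
2197 | "cell_type": "code",
2198 | "metadata": {
2199 | "id": "addedWinnersSketchCode",
2200 | "colab_type": "code",
2201 | "colab": {}
2202 | },
2203 | "source": [
2204 | "# Added sketch: get_winners keeps the winning team's 4 unmasked rows per match\n",
2205 | "att = torch.rand(2, 10, 30)      # item logits per participant\n",
2206 | "gt = torch.zeros(2, 10, 30)\n",
2207 | "outs = torch.rand(2, 10, 16)     # encoder outputs\n",
2208 | "win = torch.tensor([0, 1])       # winning team per match\n",
2209 | "pos = [2, 4]                     # masked position within the winning team\n",
2210 | "a, g, c = get_winners(att, gt, win, pos, outs)\n",
2211 | "print(a.shape, g.shape, c.shape)  # (2, 4, 30) (2, 4, 30) (2, 16)"
2212 | ],
2213 | "execution_count": null,
2214 | "outputs": []
2215 | },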
2186 | {
2187 | "cell_type": "code",
2188 | "metadata": {
2189 | "id": "X3xn5Zd30i3C",
2190 | "colab_type": "code",
2191 | "colab": {}
2192 | },
2193 | "source": [
2194 | "def save_att_weights(list_att, path_save_att):\n",
2195 | " with open(path_save_att, 'wb') as handle:\n",
2196 | " pickle.dump(list_att, handle)"
2197 | ],
2198 | "execution_count": 34,
2199 | "outputs": []
2200 | },
2201 | {
2202 | "cell_type": "markdown",
2203 | "metadata": {
2204 | "id": "WDA0GHysW4vX",
2205 | "colab_type": "text"
2206 | },
2207 | "source": [
2208 | "# Training and evaluation loop"
2209 | ]
2210 | },
2211 | {
2212 | "cell_type": "markdown",
2213 | "metadata": {
2214 | "id": "TVeG7SOJW80S",
2215 | "colab_type": "text"
2216 | },
2217 | "source": [
2218 | "The training and evaluation loop are based on [Pytorch-lightning](https://github.com/williamFalcon/pytorch-lightning)"
2219 | ]
2220 | },
2221 | {
2222 | "cell_type": "code",
2223 | "metadata": {
2224 | "id": "OOJROUqzSqJU",
2225 | "colab_type": "code",
2226 | "colab": {}
2227 | },
2228 | "source": [
2229 | "import argparse"
2230 | ],
2231 | "execution_count": 35,
2232 | "outputs": []
2233 | },
2234 | {
2235 | "cell_type": "code",
2236 | "metadata": {
2237 | "id": "mR77pKXdXDDF",
2238 | "colab_type": "code",
2239 | "colab": {}
2240 | },
2241 | "source": [
2242 | "class Struct:\n",
2243 | " def __init__(self, **entries):\n",
2244 | " self.__dict__.update(entries)\n",
2245 | " #self.elems = entries.items()\n",
2246 | " \n",
2247 | " def items(self):\n",
2248 | " return self.__dict__.items()"
2249 | ],
2250 | "execution_count": 36,
2251 | "outputs": []
2252 | },
2253 | {
2254 | "cell_type": "code",
2255 | "metadata": {
2256 | "id": "RxtiaNMYc6iw",
2257 | "colab_type": "code",
2258 | "colab": {}
2259 | },
2260 | "source": [
2261 | "class LolRecAttModel(pl.LightningModule):\n",
2262 | "\n",
2263 | " def __init__(self, cfg):\n",
2264 | " super(LolRecAttModel, self).__init__()\n",
2265 | " \n",
2266 | " if type(cfg) is argparse.Namespace:\n",
2267 | " cfg = vars(cfg)\n",
2268 | " self.conf = cfg\n",
2269 | " self.hparams = cfg\n",
2270 | " self.index_split = self.conf['index_split']\n",
2271 | " self.optim = self.conf['optim']\n",
2272 | " set_seed(seed = self.conf['seed'])\n",
2273 | " train_dataset = self.train_dataset()\n",
2274 | " self.batch_size = self.conf['batch_size']\n",
2275 | " self.iter_max_train = len(train_dataset)//self.batch_size\n",
2276 | " num_roles = len(train_dataset.roles)\n",
2277 | " num_champions = len(train_dataset.champions)\n",
2278 | " n_items = len(train_dataset.items)\n",
2279 | " n_types = len(train_dataset.set_champ_type)\n",
2280 | " self.model = TransformerLolRecommender(n_role=num_roles, n_champions=num_champions, embeddings_size=self.conf['embeddings_size'], nhead=self.conf['nhead'], n_items=n_items, n_type=n_types,\n",
2281 | " nlayers = self.conf['nlayers'], nhid = self.conf['nhid'], dropout=self.conf['dropout'], aux_task = self.conf['win_task'], \n",
2282 | " learnable_team_emb = self.conf['learnable_team_emb'])\n",
2283 | " self.loss = nn.BCEWithLogitsLoss()\n",
2284 | " self.loss_aux = nn.CrossEntropyLoss()\n",
2285 | " self.train_loss = AverageMeter()\n",
2286 | " self.train_prec = AverageMeter()\n",
2287 | " self.iter_epoch = 0\n",
2288 | " isExist = os.path.exists(path_save) \n",
2289 | " if isExist:\n",
2290 | " dirs = os.listdir(path_save)\n",
2291 | " self.iter_epoch = len(dirs)\n",
2292 | "\n",
2293 | " self.aux_task = self.conf['win_task']\n",
2294 | " \n",
2295 | " if self.aux_task:\n",
2296 | " self.second_loss = nn.CrossEntropyLoss()\n",
2297 | " self.train_acc_win = AverageMeter()\n",
2298 | " self.train_main_loss = AverageMeter()\n",
2299 | " self.train_win_loss = AverageMeter()\n",
2300 | " self.alpha = self.conf['alpha']\n",
2301 | " self.beta = self.conf['beta']\n",
2302 | " self.epoch_to_win = self.conf['init_epoch']\n",
2303 | "\n",
2304 | " def check_epoch(self, num_iter):\n",
2305 | " if num_iter == 0:\n",
2306 | " self.train_loss = AverageMeter()\n",
2307 | " self.train_prec = AverageMeter()\n",
2308 | " if self.aux_task:\n",
2309 | " self.train_acc_win = AverageMeter()\n",
2310 | " self.train_main_loss = AverageMeter()\n",
2311 | " self.train_win_loss = AverageMeter()\n",
2312 | " self.iter_epoch+=1\n",
2313 | "\n",
2314 | " def custom_print(self, batch, loss, start_time, prec, acc = 0, log_interval = 100, loss_win =0, epoch=1):\n",
2315 | " if batch % log_interval == 0:\n",
2316 | " elapsed = time.time() - start_time\n",
2317 | " elapsed = elapsed*log_interval if batch > 0 else elapsed\n",
2318 | " if self.aux_task and self.iter_epoch >= self.epoch_to_win:\n",
2319 | " print('| epoch {:3d} | {:5d}/{:5d} batches | '\n",
2320 | " 'ms/batch {:5.2f} | '\n",
2321 | " 'loss {:5.6f} | loss win {:5.6f} | precision {:5.6f} | Accuracy (win) {:5.6f}'.format(\n",
2322 | " self.iter_epoch, batch, self.iter_max_train,\n",
2323 | " elapsed, loss, loss_win, prec, acc))\n",
2324 | " else:\n",
2325 | " print('| epoch {:3d} | {:5d}/{:5d} batches | '\n",
2326 | " 'ms/batch {:5.2f} | '\n",
2327 | " 'loss {:5.6f} | precision {:5.6f}'.format(\n",
2328 | " self.iter_epoch, batch, self.iter_max_train,\n",
2329 | " elapsed, loss, prec)) \n",
2330 | "\n",
2331 | " def forward(self, x, items, win, teacher_forcing):\n",
2332 | " role = x['role']\n",
2333 | " champions = x['champions']\n",
2334 | " types = x['type']\n",
2335 | " out = self.model(role, champions, types, items, win, teacher_forcing)\n",
2336 | " return out\n",
2337 | " #return torch.relu(self.l1(x.view(x.size(0), -1)))\n",
2338 | "\n",
2339 | " def training_step(self, batch, batch_nb):\n",
2340 | " # REQUIRED\n",
2341 | " self.check_epoch(batch_nb)\n",
2342 | " start_time = time.time()\n",
2343 | " x, y = batch\n",
2344 | " if len(x['role'].size()) == 3:\n",
2345 | " x['role'] = x['role'].reshape(x['role'].size(0)*x['role'].size(1), x['role'].size(2))\n",
2346 | " x['champions'] = x['champions'].reshape(x['champions'].size(0)*x['champions'].size(1), x['champions'].size(2))\n",
2347 | " x['type'] = x['type'].reshape(x['type'].size(0)*x['type'].size(1), x['type'].size(2), x['type'].size(3))\n",
2348 | " y['items'] = y['items'].reshape(y['items'].size(0)*y['items'].size(1), y['items'].size(2), y['items'].size(3))\n",
2349 | " y['win'] = y['win'].reshape(y['win'].size(0)*y['win'].size(1))\n",
2350 | " y_hat = self.forward(x, y['items'], y['win'], self.conf['teacher_forcing'])\n",
2351 | " \n",
2352 | " #Mains task\n",
2353 | " logits_items = y_hat['logits_items']\n",
2354 | " gt_items = y['items']\n",
2355 | " sel_champions = y_hat['sel_champions']\n",
2356 | " pos_champions = y_hat['pos_champions']\n",
2357 | " outputs_log = y_hat['outputs']\n",
2358 | " logits_items, gt_items, att_cham = get_winners(logits_items, gt_items, y['win'], pos_champions, outputs_log)\n",
2359 | " \n",
2360 | " out = logits_items.reshape(logits_items.size(0)*logits_items.size(1), logits_items.size(2))\n",
2361 | " out_aux = self.model.pred_champ(att_cham)\n",
2362 | "\n",
2363 | " gt = gt_items.reshape(gt_items.size(0)*gt_items.size(1), gt_items.size(2))\n",
2364 | " loss = self.loss(out, gt)\n",
2365 | " loss_aux = self.loss_aux(out_aux, sel_champions)\n",
2366 | "\n",
2367 | " prec, num, _ = calc_precision_multiclass(out, gt, k=6)\n",
2368 | " self.train_prec.update(prec, num)\n",
2369 | "\n",
2370 | " tensor_avg_prec = torch.tensor([self.train_prec.avg], device=loss.device)\n",
2371 | " tensorboard_logs = {'train_loss': loss, 'train_loss_aux': loss_aux, 'train_prec_avg': tensor_avg_prec}\n",
2372 | "\n",
2373 | " if self.aux_task and self.iter_epoch >= self.epoch_to_win:\n",
2374 | "\n",
2375 | " #Second Task\n",
2376 | " out_win = y_hat['logits_win']\n",
2377 | " \n",
2378 | " gt_win = y['win'].reshape(-1)\n",
2379 | "\n",
2380 | " _, preds_win = torch.max(out_win, 1)\n",
2381 | "\n",
2382 | " loss_win = self.second_loss(out_win, gt_win)\n",
2383 | " loss_total = self.alpha*loss + self.beta*loss_win\n",
2384 | " self.train_loss.update(self.alpha*loss.item(), out.size(0))\n",
2385 | " self.train_loss.update(self.beta*loss_win.item(), out_win.size(0))\n",
2386 | "\n",
2387 | " train_acc = torch.sum(preds_win == gt_win).item()/out_win.size(0)\n",
2388 | " self.train_acc_win.update(train_acc, out_win.size(0))\n",
2389 | " self.train_main_loss.update(loss.item(), out.size(0))\n",
2390 | " self.train_win_loss.update(loss_win.item(), out_win.size(0))\n",
2391 | "\n",
2392 | " tensor_avg_acc = torch.tensor([self.train_acc_win.avg], device=loss.device)\n",
2393 | " tensorboard_logs['train_acc_win_avg'] = tensor_avg_acc\n",
2394 | " tensorboard_logs['train_win_loss'] = loss_win\n",
2395 | " tensorboard_logs['train_main_loss'] = loss\n",
2396 | " tensorboard_logs['train_win_loss_avg'] = torch.tensor([self.train_win_loss.avg], device=loss.device)\n",
2397 | " tensorboard_logs['train_main_loss_avg'] = torch.tensor([self.train_main_loss.avg], device=loss.device)\n",
2398 | " tensorboard_logs['train_loss'] = loss_total\n",
2399 | "\n",
2400 | " # self.custom_print(batch_nb, self.train_main_loss.avg, start_time, self.train_prec.avg, self.train_acc_win.avg, 100, self.train_win_loss.avg)\n",
2401 | " else:\n",
2402 | " loss_total = loss + 0.2*loss_aux\n",
2403 | " self.train_loss.update(loss.item(), out.size(0))\n",
2404 | " tensorboard_logs['total_loss_train'] = loss_total\n",
2405 | " # self.custom_print(batch_nb, self.train_loss.avg, start_time, self.train_prec.avg, 0, 100)\n",
2406 | " \n",
2407 | " tensor_avg_loss = torch.tensor([self.train_loss.avg], device=loss.device)\n",
2408 | " tensorboard_logs['train_loss_avg'] = tensor_avg_loss\n",
2409 | " return {'loss': loss_total, 'progress_bar': tensorboard_logs, 'avg_loss': tensor_avg_loss, 'avg_prec':tensor_avg_prec ,'log': tensorboard_logs}\n",
2410 | "\n",
2411 | " def validation_step(self, batch, batch_nb):\n",
2412 | " x, y = batch\n",
2413 | " y_hat = self.forward(x, y['items'], y['win'], False)\n",
2414 | " att_weights = y_hat['att_weights']\n",
2415 | "\n",
2416 | " #Main Task\n",
2417 | " logits_items = y_hat['logits_items']\n",
2418 | " gt_items = y['items']\n",
2419 | " sel_champions = y_hat['sel_champions']\n",
2420 | " pos_champions = y_hat['pos_champions']\n",
2421 | " outputs_log = y_hat['outputs']\n",
2422 | "\n",
2423 | " logits_items, gt_items, att_cham = get_winners(logits_items, gt_items, y['win'], pos_champions, outputs_log)\n",
2424 | " out = logits_items.reshape(logits_items.size(0)*logits_items.size(1), logits_items.size(2))\n",
2425 | " out_aux = self.model.pred_champ(att_cham)\n",
2426 | " \n",
2427 | " gt = gt_items.reshape(gt_items.size(0)*gt_items.size(1), gt_items.size(2))\n",
2428 | "\n",
2429 | " loss = self.loss(out, gt)\n",
2430 | " loss_aux = self.loss_aux(out_aux, sel_champions)\n",
2431 | "\n",
2432 | " prec, num, list_prec = calc_precision_multiclass(out, gt, k=6)\n",
2433 | " prec1, num, list_prec1 = calc_precision_multiclass(out, gt, k=1)\n",
2434 | " prec3, num, list_prec3 = calc_precision_multiclass(out, gt, k=3)\n",
2435 | "\n",
2436 | " recall1, num, list_recall1 = recall_at_k(out, gt, k=1)\n",
2437 | " recall3, num, list_recall3 = recall_at_k(out, gt, k=3)\n",
2438 | " recall6, num, list_recall6 = recall_at_k(out, gt, k=6)\n",
2439 | "\n",
2440 | " f11 = f1_score(recall1, prec1)\n",
2441 | " f13 = f1_score(recall3, prec3) \n",
2442 | " f16 = f1_score(recall6, prec)\n",
2443 | "\n",
2444 | " map6, list_map6 = map_at(out, gt, k=6)\n",
2445 | " map1, list_map1 = map_at(out, gt, k=1)\n",
2446 | " map3, list_map3 = map_at(out, gt, k=3)\n",
2447 | "\n",
2448 | " obj_list = {\n",
2449 | " 'list_prec1': list_prec1,\n",
2450 | " 'list_prec3': list_prec3,\n",
2451 | " 'list_prec': list_prec,\n",
2452 | " 'list_recall1': list_recall1,\n",
2453 | " 'list_recall3': list_recall3,\n",
2454 | " 'list_recall6': list_recall6,\n",
2455 | " 'list_map1': list_map1,\n",
2456 | " 'list_map3': list_map3,\n",
2457 | " 'list_map6': list_map6\n",
2458 | " }\n",
2459 | " obj_res = {'val_loss': loss, 'val_loss_aux': loss_aux, 'val_prec': prec, 'num_batch': out.size(0), 'num':num, 'map6': map6, \n",
2460 | " 'map1': map1, 'map3': map3, 'val_prec1': prec1, 'val_prec3': prec3, 'val_recall1': recall1, \n",
2461 | " 'val_recall3': recall3, 'val_recall6': recall6, 'val_f1_1': f11, 'val_f1_3': f13, 'val_f1_6': f16, \n",
2462 | " 'att_weights': att_weights, 'logits_items': logits_items, 'obj_list': obj_list}\n",
2463 | "\n",
2464 | " #Second Task\n",
2465 | " if self.aux_task and self.iter_epoch >= self.epoch_to_win:\n",
2466 | " out_win = y_hat['logits_win']\n",
2467 | " \n",
2468 | " gt_win = y['win'].reshape(-1)\n",
2469 | " _, preds_win = torch.max(out_win, 1)\n",
2470 | "\n",
2471 | " loss_win = self.second_loss(out_win, gt_win)\n",
2472 | "\n",
2473 | " acc_win = torch.sum(preds_win == gt_win).item()/out_win.size(0)\n",
2474 | " obj_res['val_acc'] = acc_win\n",
2475 | " obj_res['val_loss_win'] = loss_win\n",
2476 | " obj_res['val_main_loss'] = loss\n",
2477 | " obj_res['num_batch_acc'] = out_win.size(0)\n",
2478 | "\n",
2479 | " return obj_res\n",
2480 | "\n",
2481 | " def validation_epoch_end(self, outputs):\n",
2482 | " avg_loss = AverageMeter()\n",
2483 | " avg_loss_aux = AverageMeter()\n",
2484 | " avg_prec = AverageMeter()\n",
2485 | " avg_prec1 = AverageMeter()\n",
2486 | " avg_prec3 = AverageMeter()\n",
2487 | "\n",
2488 | " avg_recall1 = AverageMeter()\n",
2489 | " avg_recall3 = AverageMeter()\n",
2490 | " avg_recall6 = AverageMeter()\n",
2491 | "\n",
2492 | " avg_f1_1 = AverageMeter()\n",
2493 | " avg_f1_3 = AverageMeter()\n",
2494 | " avg_f1_6 = AverageMeter()\n",
2495 | "\n",
2496 | " avg_map = AverageMeter()\n",
2497 | " avg_map1 = AverageMeter()\n",
2498 | " avg_map3 = AverageMeter()\n",
2499 | "\n",
2500 | " list_att_weights = []\n",
2501 | " list_logits_items = []\n",
2502 | "\n",
2503 | " list_prec1 = []\n",
2504 | " list_prec3 = []\n",
2505 | " list_prec6 = []\n",
2506 | "\n",
2507 | " list_recall1 = []\n",
2508 | " list_recall3 = []\n",
2509 | " list_recall6 = []\n",
2510 | "\n",
2511 | " list_map1 = []\n",
2512 | " list_map3 = []\n",
2513 | " list_map6 = []\n",
2514 | "\n",
2515 | " if self.aux_task and self.iter_epoch >= self.epoch_to_win:\n",
2516 | " avg_main_loss = AverageMeter()\n",
2517 | " avg_win_loss = AverageMeter()\n",
2518 | " avg_acc = AverageMeter()\n",
2519 | "\n",
2520 | " device = None\n",
2521 | " for x in outputs:\n",
2522 | "\n",
2523 | " avg_prec.update(x['val_prec'], x['num'])\n",
2524 | " avg_prec1.update(x['val_prec1'], x['num'])\n",
2525 | " avg_prec3.update(x['val_prec3'], x['num'])\n",
2526 | "\n",
2527 | " avg_recall1.update(x['val_recall1'], x['num'])\n",
2528 | " avg_recall3.update(x['val_recall3'], x['num'])\n",
2529 | " avg_recall6.update(x['val_recall6'], x['num'])\n",
2530 | "\n",
2531 | " avg_f1_1.update(x['val_f1_1'], x['num'])\n",
2532 | " avg_f1_3.update(x['val_f1_3'], x['num'])\n",
2533 | " avg_f1_6.update(x['val_f1_6'], x['num'])\n",
2534 | "\n",
2535 | " avg_map.update(x['map6'], x['num_batch'])\n",
2536 | " avg_map1.update(x['map1'], x['num_batch'])\n",
2537 | " avg_map3.update(x['map3'], x['num_batch'])\n",
2538 | "\n",
2539 | " list_att_weights.append(x['att_weights'])\n",
2540 | " list_logits_items.append(x['logits_items'])\n",
2541 | "\n",
2542 | " list_prec1.extend(x['obj_list']['list_prec1'])\n",
2543 | " list_prec3.extend(x['obj_list']['list_prec3'])\n",
2544 | " list_prec6.extend(x['obj_list']['list_prec'])\n",
2545 | "\n",
2546 | " list_recall1.extend(x['obj_list']['list_recall1'])\n",
2547 | " list_recall3.extend(x['obj_list']['list_recall3'])\n",
2548 | " list_recall6.extend(x['obj_list']['list_recall6'])\n",
2549 | "\n",
2550 | " list_map1.extend(x['obj_list']['list_map1'])\n",
2551 | " list_map3.extend(x['obj_list']['list_map3'])\n",
2552 | " list_map6.extend(x['obj_list']['list_map6'])\n",
2553 | "\n",
2554 | " device = x['val_loss'].device\n",
2555 | "\n",
2556 | " if self.aux_task and self.iter_epoch >= self.epoch_to_win:\n",
2557 | " avg_main_loss.update(x['val_main_loss'], x['num_batch'])\n",
2558 | " avg_win_loss.update(x['val_loss_win'], x['num_batch_acc'])\n",
2559 | " avg_acc.update(x['val_acc'], x['num_batch_acc'])\n",
2560 | "\n",
2561 | " avg_loss.update(self.alpha*x['val_main_loss'], x['num_batch'])\n",
2562 | " avg_loss.update(self.beta*x['val_loss_win'], x['num_batch_acc'])\n",
2563 | " else:\n",
2564 | " avg_loss.update(x['val_loss'], x['num_batch'])\n",
2565 | " avg_loss_aux.update(x['val_loss_aux'], x['num_batch'])\n",
2566 | "\n",
2567 | " tensorboard_logs = {'val_loss': torch.tensor([avg_loss.avg], device=device), 'val_prec': torch.tensor([avg_prec.avg], device=device), \n",
2568 | " 'val_map6': torch.tensor([avg_map.avg], device=device), 'val_map1': torch.tensor([avg_map1.avg], device=device),\n",
2569 | " 'val_map3': torch.tensor([avg_map3.avg], device=device), 'val_prec1': torch.tensor([avg_prec1.avg], device=device), \n",
2570 | " 'val_prec3': torch.tensor([avg_prec3.avg], device=device), 'val_recall1': torch.tensor([avg_recall1.avg], device=device),\n",
2571 | " 'val_recall3': torch.tensor([avg_recall3.avg], device=device), 'val_recall6': torch.tensor([avg_recall6.avg], device=device),\n",
2572 | " 'val_f1_1': torch.tensor([avg_f1_1.avg], device=device), 'val_f1_3': torch.tensor([avg_f1_3.avg], device=device),\n",
2573 | " 'val_f1_6': torch.tensor([avg_f1_6.avg], device=device)}\n",
2574 | "\n",
2575 | " if self.aux_task and self.iter_epoch >= self.epoch_to_win:\n",
2576 | " tensorboard_logs['val_main_loss'] = torch.tensor([avg_main_loss.avg], device=device)\n",
2577 | " tensorboard_logs['val_win_loss'] = torch.tensor([avg_win_loss.avg], device=device)\n",
2578 | " tensorboard_logs['val_win_acc'] = torch.tensor([avg_acc.avg], device=device)\n",
2579 | " print('| loss_val {:5.6f} | main_loss_val {:5.6f} | win_loss_val {:5.6f} | precision_val {:5.6f} | map6_val {:5.6f} | acc_val {:5.6f}'.format(avg_loss.avg, avg_main_loss.avg, \n",
2580 | " avg_win_loss.avg, avg_prec.avg, \n",
2581 | " avg_map.avg, avg_acc.avg))\n",
2582 | " # else:\n",
2583 | " # print('| loss_val {:5.6f} | precision1_val {:5.6f} | precision3_val {:5.6f} | precision6_val {:5.6f} | map1_val {:5.6f} | map3_val {:5.6f} | map6_val {:5.6f} | recall1 {:5.6f} | recall3 {:5.6f} | recall6 {:5.6f} | f1_1 {:5.6f} | f1_3 {:5.6f} | f1_6 {:5.6f}'.format(\n",
2584 | " # avg_loss.avg, avg_prec1.avg, avg_prec3.avg, avg_prec.avg, avg_map1.avg, avg_map3.avg, avg_map.avg, avg_recall1.avg, avg_recall3.avg, avg_recall6.avg, avg_f1_1.avg, avg_f1_3.avg, avg_f1_6.avg))\n",
2585 | " \n",
2586 | " path_save_att = path_save_att_format.format(str(self.conf['index_split']), str(self.conf['exp']), str(self.iter_epoch))\n",
2587 | " path_save_list_metrics = path_save_list_metrics_format.format(str(self.conf['index_split']), str(self.conf['exp']), str(self.iter_epoch))\n",
2588 | " weights_items = {\n",
2589 | " 'list_att_weights': list_att_weights,\n",
2590 | " 'list_logits_items': list_logits_items\n",
2591 | " }\n",
2592 | "\n",
2593 | " list_metrics = {\n",
2594 | " 'list_prec1': list_prec1, \n",
2595 | " 'list_prec3': list_prec3,\n",
2596 | " 'list_prec6': list_prec6,\n",
2597 | " 'list_recall1': list_recall1,\n",
2598 | " 'list_recall3': list_recall3,\n",
2599 | " 'list_recall6': list_recall6,\n",
2600 | " 'list_map1': list_map1,\n",
2601 | " 'list_map3': list_map3,\n",
2602 | " 'list_map6': list_map6\n",
2603 | " }\n",
2604 | " save_att_weights(weights_items, path_save_att)\n",
2605 | " save_att_weights(list_metrics, path_save_list_metrics)\n",
2606 | " return {'avg_val_loss': avg_loss.avg, 'avg_val_prec': avg_prec.avg, 'val_map6': avg_map.avg,'progress_bar': tensorboard_logs,'log': tensorboard_logs}\n",
2607 | "\n",
2608 | " def test_step(self, batch, batch_idx):\n",
2609 | " # OPTIONAL\n",
2610 | " return self.validation_step(batch, batch_idx)\n",
2611 | "\n",
2612 | " def test_epoch_end(self, outputs):\n",
2613 | " \n",
2614 | " return self.validation_end(outputs)\n",
2615 | "\n",
2616 | " def configure_optimizers(self):\n",
2617 | " # REQUIRED\n",
2618 | " # can return multiple optimizers and learning_rate schedulers\n",
2619 | " # (LBFGS it is automatically supported, no need for closure function)\n",
2620 | " if self.optim == 'adabound':\n",
2621 | " optimizer = adabound.AdaBound(self.model.parameters(), lr=1e-3, final_lr=0.1)\n",
2622 | " else:\n",
2623 | " optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)\n",
2624 | " return optimizer\n",
2625 | " \n",
2626 | " def train_dataset(self):\n",
2627 | "\n",
2628 | " data = get_partition(self.index_split, list_trainset)\n",
2629 | " composed = transforms.Compose([RandomSort_Part(),\n",
2630 | " RandomSort_Team()])\n",
2631 | " train_dataset = LolDataset(data, transform=composed)\n",
2632 | " return train_dataset\n",
2633 | "\n",
2634 | " @pl.data_loader\n",
2635 | " def train_dataloader(self):\n",
2636 | " \n",
2637 | " train_dataset = self.train_dataset()\n",
2638 | " return DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)\n",
2639 | "\n",
2640 | " @pl.data_loader\n",
2641 | " def val_dataloader(self):\n",
2642 | " #data = list_testset[self.index_split]\n",
2643 | " data = get_partition(self.index_split, list_testset)\n",
2644 | " val_dataset = LolDataset(data)\n",
2645 | " return DataLoader(val_dataset, batch_size=self.batch_size)\n",
2646 | " \n",
2647 | " @pl.data_loader\n",
2648 | " def test_dataloader(self):\n",
2649 | " # OPTIONAL\n",
2650 | " return self.val_dataloader()"
2651 | ],
2652 | "execution_count": 37,
2653 | "outputs": []
2654 | },
2655 | {
2656 | "cell_type": "markdown",
2657 | "metadata": {
2658 | "id": "CyRfaqN8XvYi",
2659 | "colab_type": "text"
2660 | },
2661 | "source": [
2662 | "# Config file"
2663 | ]
2664 | },
2665 | {
2666 | "cell_type": "markdown",
2667 | "metadata": {
2668 | "id": "4Aa7VmxeYF7X",
2669 | "colab_type": "text"
2670 | },
2671 | "source": [
2672 | "This config establish the model hyperparameters like:\n",
2673 | "\n",
2674 | "1. index_split - num of the partition used to train.\n",
2675 | "2. optim - optimizer (could be adam or adabound).\n",
2676 | "3. batch_size - Batch size \n",
2677 | "4. embeddings_size - model dim\n",
2678 | "5. nhead - number of attention heads \n",
2679 | "6. nlayers - number of encoder layers\n",
2680 | "7. exp - experiment number\n",
2681 | "8. epochs - number of epoch\n",
2682 | "9. exp_name - experiment name in comet.ml\n",
2683 | "10. alpha, beta - importance weights for losses\n",
2684 | "11. win_task - enable the auxiliary task.\n",
2685 | "12. learnable_team_emb - when it is True the team embedding is learnable \n",
2686 | "otherwise it is static. \n",
2687 | "13. teacher_forcing - enable the teacher forcing for the auxiliary task.\n",
2688 | "14. init_epoch - indicate the epoch when the second task start. \n"
2689 | ]
2690 | },
2691 | {
2692 | "cell_type": "code",
2693 | "metadata": {
2694 | "id": "dqm4Zi0_ha8W",
2695 | "colab_type": "code",
2696 | "colab": {}
2697 | },
2698 | "source": [
2699 | "conf = {\n",
2700 | " 'index_split': 0,\n",
2701 | " 'optim': 'adam',\n",
2702 | " 'seed': 1642,\n",
2703 | " 'batch_size': 100,\n",
2704 | " 'embeddings_size': 512,\n",
2705 | " 'nhead': 2,\n",
2706 | " 'nlayers': 1, \n",
2707 | " 'nhid': 2048, \n",
2708 | " 'dropout': 0.5,\n",
2709 | " 'exp': 13,\n",
2710 | " 'epochs': 10,\n",
2711 | " 'exp_name': 'Main_tasks_rec_only_winners_final_prueba',\n",
2712 | " 'win_task': False,\n",
2713 | " 'alpha': 1,\n",
2714 | " 'beta': 1,\n",
2715 | " 'learnable_team_emb': True,\n",
2716 | " 'teacher_forcing': False,\n",
2717 | " 'init_epoch': 2\n",
2718 | "}"
2719 | ],
2720 | "execution_count": 38,
2721 | "outputs": []
2722 | },
2723 | {
2724 | "cell_type": "markdown",
2725 | "metadata": {
2726 | "id": "IVtKoVTcYDS1",
2727 | "colab_type": "text"
2728 | },
2729 | "source": [
2730 | "# Training and evaluation executor"
2731 | ]
2732 | },
2733 | {
2734 | "cell_type": "code",
2735 | "metadata": {
2736 | "id": "EIBfoXNe1TOh",
2737 | "colab_type": "code",
2738 | "colab": {
2739 | "base_uri": "https://localhost:8080/",
2740 | "height": 833,
2741 | "referenced_widgets": [
2742 | "a705d3b71b5d4fb587b3bb1fb38161fa",
2743 | "abfceea38af9444b8da09122eb0c867d",
2744 | "0dfe065d9c0c468389f75dd74a9a11ac",
2745 | "741dabfff47941f3b290d4ad4cb6be12",
2746 | "503e7ca9191948a4bcf6cc64a8862820",
2747 | "249092a9ce0940218d291fe75cf35f3e",
2748 | "828e6e2c56f3456cb016bf7c8b701ba8",
2749 | "7540c343bf9f461a84c157f84d529cd9",
2750 | "c43e576730a940c28fa78d49e95e7165",
2751 | "b3d8f9b86d0d47ae8c51b8f2eb202aab",
2752 | "82170aabdef246edbf668bb1cdf4a5e3",
2753 | "43a58bf6d9ab454095f8f5f30f10cdca",
2754 | "f80b50d773b240a091c6c3bdf7961924",
2755 | "f8a352088b8d4ed896bf8a206ecc024e",
2756 | "2e151daf92e84c369ee90e8ded7a24f2",
2757 | "5b0b2240357743c4b8285ce6017638c4"
2758 | ]
2759 | },
2760 | "outputId": "3c6e7509-69b7-4562-e993-c1012e1c94dc"
2761 | },
2762 | "source": [
2763 | "from pytorch_lightning import Trainer\n",
2764 | "\n",
2765 | "path_save = '/content/gdrive/My Drive/Proyecto_RecSys/split/{}/exp_recsys/{}/checkpoints/'.format(str(conf['index_split']), str(conf['exp']))\n",
2766 | "path_save_att_format = '/content/gdrive/My Drive/Proyecto_RecSys/split/{}/exp_recsys/{}/checkpoints/att_weights_{}.pkl'\n",
2767 | "path_save_list_metrics_format = '/content/gdrive/My Drive/Proyecto_RecSys/split/{}/exp_recsys/{}/checkpoints/list_metrics_{}.pkl'\n",
2768 | "\n",
2769 | "model = LolRecAttModel(conf)\n",
2770 | "\n",
2771 | "checkpoint_callback = get_checkpointer(path_save,'avg_val_prec')\n",
2772 | "\n",
2773 | "\n",
2774 | "comet_logger = CometLogger(\n",
2775 | " experiment_name=conf['exp_name'],\n",
2776 | " api_key = 'YOUR_KEY',\n",
2777 | " project_name=\"YOUR_PROJECT_NAME\",\n",
2778 | " workspace = 'YOUR_WORKSPACE'\n",
2779 | ")\n",
2780 | "trainer = Trainer(\n",
2781 | " gpus=[0],\n",
2782 | " distributed_backend='dp',\n",
2783 | " logger=comet_logger,\n",
2784 | " max_epochs=conf['epochs'],\n",
2785 | " checkpoint_callback=checkpoint_callback,\n",
2786 | " show_progress_bar=False,\n",
2787 | " gradient_clip_val=0.5\n",
2788 | ")\n",
2789 | "\n",
2790 | "trainer.fit(model) "
2791 | ],
2792 | "execution_count": null,
2793 | "outputs": [
2794 | {
2795 | "output_type": "stream",
2796 | "text": [
2797 | "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/utilities/distributed.py:22: UserWarning: Checkpoint directory /content/gdrive/My Drive/Proyecto_RecSys/split/0/exp_recsys/13/checkpoints/ exists and is not empty with save_top_k != 0.All files in this directory will be deleted when a checkpoint is saved!\n",
2798 | " warnings.warn(*args, **kwargs)\n",
2799 | "CometLogger will be initialized in online mode\n",
2800 | "COMET INFO: ----------------------------\n",
2801 | "COMET INFO: Comet.ml Experiment Summary:\n",
2802 | "COMET INFO: Data:\n",
2803 | "COMET INFO: url: https://www.comet.ml/afvilla/lolnet/5094e7cd80244d62bac9be446cbfeb0b\n",
2804 | "COMET INFO: Metrics [count] (min, max):\n",
2805 | "COMET INFO: sys.cpu.percent.01 [4] : (1.0, 12.3)\n",
2806 | "COMET INFO: sys.cpu.percent.02 [4] : (1.0, 12.9)\n",
2807 | "COMET INFO: sys.cpu.percent.03 [4] : (0.9, 12.5)\n",
2808 | "COMET INFO: sys.cpu.percent.04 [4] : (1.0, 12.9)\n",
2809 | "COMET INFO: sys.cpu.percent.avg [4] : (0.975, 12.65)\n",
2810 | "COMET INFO: sys.gpu.0.free_memory [4] : (17061249024.0, 17061249024.0)\n",
2811 | "COMET INFO: sys.gpu.0.gpu_utilization [4]: (0.0, 0.0)\n",
2812 | "COMET INFO: sys.gpu.0.total_memory : (17071734784.0, 17071734784.0)\n",
2813 | "COMET INFO: sys.gpu.0.used_memory [4] : (10485760.0, 10485760.0)\n",
2814 | "COMET INFO: sys.ram.total [4] : (27393740800.0, 27393740800.0)\n",
2815 | "COMET INFO: sys.ram.used [4] : (7792205824.0, 7797219328.0)\n",
2816 | "COMET INFO: Other [count]:\n",
2817 | "COMET INFO: Name: Main_tasks_rec_only_winners_final_prueba\n",
2818 | "COMET INFO: ----------------------------\n",
2819 | "COMET INFO: old comet version (3.0.2) detected. current: 3.1.14 please update your comet lib with command: `pip install --no-cache-dir --upgrade comet_ml`\n",
2820 | "COMET INFO: Experiment is live on comet.ml https://www.comet.ml/afvilla/lolnet/8dcde95206ec45daac4cc6657844b03d\n",
2821 | "\n",
2822 | "GPU available: True, used: True\n",
2823 | "TPU available: False, using: 0 TPU cores\n",
2824 | "CUDA_VISIBLE_DEVICES: [0]\n",
2825 | "\n",
2826 | " | Name | Type | Params\n",
2827 | "-------------------------------------------------------\n",
2828 | "0 | model | TransformerLolRecommender | 3 M \n",
2829 | "1 | loss | BCEWithLogitsLoss | 0 \n",
2830 | "2 | loss_aux | CrossEntropyLoss | 0 \n",
2831 | "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/utilities/distributed.py:22: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 4 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
2832 | " warnings.warn(*args, **kwargs)\n"
2833 | ],
2834 | "name": "stderr"
2835 | },
2836 | {
2837 | "output_type": "display_data",
2838 | "data": {
2839 | "application/vnd.jupyter.widget-view+json": {
2840 | "model_id": "a705d3b71b5d4fb587b3bb1fb38161fa",
2841 | "version_minor": 0,
2842 | "version_major": 2
2843 | },
2844 | "text/plain": [
2845 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…"
2846 | ]
2847 | },
2848 | "metadata": {
2849 | "tags": []
2850 | }
2851 | },
2852 | {
2853 | "output_type": "stream",
2854 | "text": [
2855 | "/usr/local/lib/python3.6/dist-packages/torch/storage.py:34: FutureWarning: pickle support for Storage will be removed in 1.5. Use `torch.save` instead\n",
2856 | " warnings.warn(\"pickle support for Storage will be removed in 1.5. Use `torch.save` instead\", FutureWarning)\n"
2857 | ],
2858 | "name": "stderr"
2859 | },
2860 | {
2861 | "output_type": "stream",
2862 | "text": [
2863 | "\r"
2864 | ],
2865 | "name": "stdout"
2866 | },
2867 | {
2868 | "output_type": "stream",
2869 | "text": [
2870 | "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/utilities/distributed.py:22: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 4 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
2871 | " warnings.warn(*args, **kwargs)\n"
2872 | ],
2873 | "name": "stderr"
2874 | },
2875 | {
2876 | "output_type": "display_data",
2877 | "data": {
2878 | "application/vnd.jupyter.widget-view+json": {
2879 | "model_id": "c43e576730a940c28fa78d49e95e7165",
2880 | "version_minor": 0,
2881 | "version_major": 2
2882 | },
2883 | "text/plain": [
2884 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…"
2885 | ]
2886 | },
2887 | "metadata": {
2888 | "tags": []
2889 | }
2890 | }
2891 | ]
2892 | }
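2893 | ,
2894 | {
2895 | "cell_type": "markdown",
2896 | "metadata": {},
2897 | "source": [
2898 | "A usage sketch, not part of the original pipeline: assuming the saved checkpoints follow PyTorch Lightning's standard format (weights stored under a `state_dict` key), a trained `LolRecAttModel` can be restored for offline evaluation. The file name `example.ckpt` is a hypothetical placeholder for whatever `get_checkpointer` actually wrote into `path_save`."
2899 | ]
2900 | },
2901 | {
2902 | "cell_type": "code",
2903 | "metadata": {},
2904 | "source": [
2905 | "import torch\n",
2906 | "\n",
2907 | "# Hypothetical placeholder: point this at a real .ckpt file written to path_save.\n",
2908 | "best_ckpt_path = path_save + 'example.ckpt'\n",
2909 | "\n",
2910 | "# Lightning checkpoints store the module weights under the 'state_dict' key.\n",
2911 | "restored = LolRecAttModel(conf)\n",
2912 | "checkpoint = torch.load(best_ckpt_path, map_location='cpu')\n",
2913 | "restored.load_state_dict(checkpoint['state_dict'])\n",
2914 | "restored.eval()"
2915 | ],
2916 | "execution_count": null,
2917 | "outputs": []
2918 | }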
2893 | ]
2894 | }
2895 |
--------------------------------------------------------------------------------