├── images ├── attn-1.png ├── model.pdf └── model-1.png ├── LICENSE ├── README.md └── model.ipynb /images/attn-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ojedaf/IC-TIR-Lol/HEAD/images/attn-1.png -------------------------------------------------------------------------------- /images/model.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ojedaf/IC-TIR-Lol/HEAD/images/model.pdf -------------------------------------------------------------------------------- /images/model-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ojedaf/IC-TIR-Lol/HEAD/images/model-1.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 ojedaf 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Interpretable Contextual Team-aware Item Recommendation: Application in Multiplayer Online Battle Arena Games 2 | 3 | 4 | Open In Colab 5 | 6 | 7 | See the paper on [arXiv](https://arxiv.org/abs/2007.15236). 8 | 9 | ![tir-model](https://github.com/ojedaf/IC-TIR-Lol/blob/master/images/model-1.png) 10 | 11 | This repository contains the PyTorch code of the TTIR model. 12 | 13 | ## Content 14 | 15 | - [Prerequisites](#prerequisites) 16 | - [Dataset](#dataset) 17 | - [Code](#code) 18 | - [Results](#results) 19 | 20 | ## Prerequisites 21 | 22 | The code is built with the following libraries: 23 | 24 | - [PyTorch](https://pytorch.org/) 1.0 or higher 25 | - [Comet-ml](https://www.comet.ml/site/) 26 | - [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) 27 | - [Google Colab](https://colab.research.google.com/) 28 | 29 | ## Dataset 30 | 31 | The dataset used in this work is available [here](https://drive.google.com/drive/folders/1lsCjmVrOA0stNiUguGWKN46fEqzzsXPH?usp=sharing). 32 | 33 | ## Code 34 | 35 | This project was developed in Google Colab, so you need a Google account and a copy of the dataset in a Google Drive folder. You also need to update these paths to match the location of the dataset.
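Because the notebook reads the dataset from Google Drive, Drive must be mounted in the Colab session before these paths resolve. The sketch below is a hypothetical helper, not part of this repository: `dataset_paths`, `missing_files` and `DATASET_FILES` are illustrative names, and it assumes the three pickle files keep the names used in the notebook and sit together in one folder. The `drive.mount` call is left as a comment because `google.colab` only exists inside a Colab runtime.

```python
from pathlib import Path

# In Colab, mount Google Drive first (only available inside a Colab runtime):
#   from google.colab import drive
#   drive.mount('/content/gdrive')

# Assumed file names, taken from the paths used in the notebook.
DATASET_FILES = ('train_splits.pkl', 'test_splits.pkl', 'champion_types.pkl')


def dataset_paths(base_dir):
    """Hypothetical helper: build the train, test and champion-type paths from one folder."""
    base = Path(base_dir)
    train_path, test_path, champion_path = (str(base / name) for name in DATASET_FILES)
    return train_path, test_path, champion_path


def missing_files(base_dir):
    """Return the expected dataset files that are not present in base_dir."""
    base = Path(base_dir)
    return [name for name in DATASET_FILES if not (base / name).is_file()]


if __name__ == '__main__':
    # Adjust this folder to wherever you copied the dataset in your Drive.
    folder = '/content/gdrive/My Drive/Proyecto_RecSys/dataset'
    train_path, test_path, champion_path = dataset_paths(folder)
    missing = missing_files(folder)
    if missing:
        print('Missing dataset files:', missing)
```

Checking for missing files up front gives a clearer error than a failed `pickle` load deep inside the data pipeline. With Drive mounted, the notebook's own path variables are set as follows: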
36 | 37 | ```python 38 | train_path = '/content/gdrive/My Drive/Proyecto_RecSys/dataset/train_splits.pkl' 39 | test_path = '/content/gdrive/My Drive/Proyecto_RecSys/dataset/test_splits.pkl' 40 | champion_path = '/content/gdrive/My Drive/Proyecto_RecSys/dataset/champion_types.pkl' 41 | ``` 42 | 43 | Likewise, set your Comet parameters (`api_key`, `project_name` and `workspace`): 44 | 45 | ```python 46 | comet_logger = CometLogger( 47 | experiment_name=conf['exp_name'], 48 | api_key='YOUR_KEY', 49 | project_name='YOUR_PROJECT_NAME', 50 | workspace='YOUR_WORKSPACE' 51 | ) 52 | ``` 53 | 54 | ## Baselines 55 | 56 | This work uses the models proposed in the paper [Data mining for item recommendation in MOBA games](https://github.com/vgaraujov/RecSysLoL) as baselines. 57 | 58 | ## Results 59 | 60 | TTIR outperforms the state-of-the-art baselines on all four metrics, and its attention weights help explain its recommendations. 61 | 62 | Method | Precision@6 | Recall@6 | F1@6 | MAP@6 | 63 | --- | --- | --- | --- | --- | 64 | TTIR | 0.492 | 0.756 | 0.596 | 0.805 | 65 | CNN | 0.484 | 0.744 | 0.586 | 0.795 | 66 | ANN | 0.476 | 0.732 | 0.566 | 0.785 | 67 | 68 | ![tir-att](https://github.com/ojedaf/IC-TIR-Lol/blob/master/images/attn-1.png) 69 | 70 | ## Citation 71 | 72 | If you find this repository useful for your research, please consider citing our paper: 73 | ``` 74 | @inproceedings{10.1145/3383313.3412211, 75 | author = {Villa, Andr\'{e}s and Araujo, Vladimir and Cattan, Francisca and Parra, Denis}, 76 | title = {Interpretable Contextual Team-Aware Item Recommendation: Application in Multiplayer Online Battle Arena Games}, 77 | year = {2020}, 78 | isbn = {9781450375832}, 79 | publisher = {Association for Computing Machinery}, 80 | address = {New York, NY, USA}, 81 | url = {https://doi.org/10.1145/3383313.3412211}, 82 | doi = {10.1145/3383313.3412211}, 83 | booktitle = {Fourteenth ACM Conference on Recommender Systems}, 84 | pages = {503–508}, 85 | numpages = {6}, 86 | keywords = {Item Recommendation, Deep Learning, MOBA Games}, 87 | location =
{Virtual Event, Brazil}, 88 | series = {RecSys '20} 89 | } 90 | ``` 91 | 92 | For any questions, feel free to open an issue or contact Andrés Villa (afvilla@uc.cl) or Vladimir Araujo (vgaraujo@uc.cl). 93 | -------------------------------------------------------------------------------- /model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Aux_RecAttModel_Multitask_Aug_Data_Villa_Cattan_RECSYS_champ_type.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "toc_visible": true, 10 | "machine_shape": "hm" 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "accelerator": "GPU", 17 | "widgets": { 18 | "application/vnd.jupyter.widget-state+json": { 19 | "a705d3b71b5d4fb587b3bb1fb38161fa": { 20 | "model_module": "@jupyter-widgets/controls", 21 | "model_name": "HBoxModel", 22 | "state": { 23 | "_view_name": "HBoxView", 24 | "_dom_classes": [], 25 | "_model_name": "HBoxModel", 26 | "_view_module": "@jupyter-widgets/controls", 27 | "_model_module_version": "1.5.0", 28 | "_view_count": null, 29 | "_view_module_version": "1.5.0", 30 | "box_style": "", 31 | "layout": "IPY_MODEL_abfceea38af9444b8da09122eb0c867d", 32 | "_model_module": "@jupyter-widgets/controls", 33 | "children": [ 34 | "IPY_MODEL_0dfe065d9c0c468389f75dd74a9a11ac", 35 | "IPY_MODEL_741dabfff47941f3b290d4ad4cb6be12" 36 | ] 37 | } 38 | }, 39 | "abfceea38af9444b8da09122eb0c867d": { 40 | "model_module": "@jupyter-widgets/base", 41 | "model_name": "LayoutModel", 42 | "state": { 43 | "_view_name": "LayoutView", 44 | "grid_template_rows": null, 45 | "right": null, 46 | "justify_content": null, 47 | "_view_module": "@jupyter-widgets/base", 48 | "overflow": null, 49 | "_model_module_version": "1.2.0", 50 | "_view_count": null, 51 | "flex_flow": "row wrap", 52 | "width": "100%", 53 | "min_width": null, 54 | "border": null, 55 |
"align_items": null, 56 | "bottom": null, 57 | "_model_module": "@jupyter-widgets/base", 58 | "top": null, 59 | "grid_column": null, 60 | "overflow_y": null, 61 | "overflow_x": null, 62 | "grid_auto_flow": null, 63 | "grid_area": null, 64 | "grid_template_columns": null, 65 | "flex": null, 66 | "_model_name": "LayoutModel", 67 | "justify_items": null, 68 | "grid_row": null, 69 | "max_height": null, 70 | "align_content": null, 71 | "visibility": null, 72 | "align_self": null, 73 | "height": null, 74 | "min_height": null, 75 | "padding": null, 76 | "grid_auto_rows": null, 77 | "grid_gap": null, 78 | "max_width": null, 79 | "order": null, 80 | "_view_module_version": "1.2.0", 81 | "grid_template_areas": null, 82 | "object_position": null, 83 | "object_fit": null, 84 | "grid_auto_columns": null, 85 | "margin": null, 86 | "display": "inline-flex", 87 | "left": null 88 | } 89 | }, 90 | "0dfe065d9c0c468389f75dd74a9a11ac": { 91 | "model_module": "@jupyter-widgets/controls", 92 | "model_name": "FloatProgressModel", 93 | "state": { 94 | "_view_name": "ProgressView", 95 | "style": "IPY_MODEL_503e7ca9191948a4bcf6cc64a8862820", 96 | "_dom_classes": [], 97 | "description": "Validation sanity check: 100%", 98 | "_model_name": "FloatProgressModel", 99 | "bar_style": "info", 100 | "max": 1, 101 | "_view_module": "@jupyter-widgets/controls", 102 | "_model_module_version": "1.5.0", 103 | "value": 1, 104 | "_view_count": null, 105 | "_view_module_version": "1.5.0", 106 | "orientation": "horizontal", 107 | "min": 0, 108 | "description_tooltip": null, 109 | "_model_module": "@jupyter-widgets/controls", 110 | "layout": "IPY_MODEL_249092a9ce0940218d291fe75cf35f3e" 111 | } 112 | }, 113 | "741dabfff47941f3b290d4ad4cb6be12": { 114 | "model_module": "@jupyter-widgets/controls", 115 | "model_name": "HTMLModel", 116 | "state": { 117 | "_view_name": "HTMLView", 118 | "style": "IPY_MODEL_828e6e2c56f3456cb016bf7c8b701ba8", 119 | "_dom_classes": [], 120 | "description": "", 121 | "_model_name": 
"HTMLModel", 122 | "placeholder": "​", 123 | "_view_module": "@jupyter-widgets/controls", 124 | "_model_module_version": "1.5.0", 125 | "value": " 1/1.0 [00:05<00:00, 2.86s/it]", 126 | "_view_count": null, 127 | "_view_module_version": "1.5.0", 128 | "description_tooltip": null, 129 | "_model_module": "@jupyter-widgets/controls", 130 | "layout": "IPY_MODEL_7540c343bf9f461a84c157f84d529cd9" 131 | } 132 | }, 133 | "503e7ca9191948a4bcf6cc64a8862820": { 134 | "model_module": "@jupyter-widgets/controls", 135 | "model_name": "ProgressStyleModel", 136 | "state": { 137 | "_view_name": "StyleView", 138 | "_model_name": "ProgressStyleModel", 139 | "description_width": "initial", 140 | "_view_module": "@jupyter-widgets/base", 141 | "_model_module_version": "1.5.0", 142 | "_view_count": null, 143 | "_view_module_version": "1.2.0", 144 | "bar_color": null, 145 | "_model_module": "@jupyter-widgets/controls" 146 | } 147 | }, 148 | "249092a9ce0940218d291fe75cf35f3e": { 149 | "model_module": "@jupyter-widgets/base", 150 | "model_name": "LayoutModel", 151 | "state": { 152 | "_view_name": "LayoutView", 153 | "grid_template_rows": null, 154 | "right": null, 155 | "justify_content": null, 156 | "_view_module": "@jupyter-widgets/base", 157 | "overflow": null, 158 | "_model_module_version": "1.2.0", 159 | "_view_count": null, 160 | "flex_flow": null, 161 | "width": null, 162 | "min_width": null, 163 | "border": null, 164 | "align_items": null, 165 | "bottom": null, 166 | "_model_module": "@jupyter-widgets/base", 167 | "top": null, 168 | "grid_column": null, 169 | "overflow_y": null, 170 | "overflow_x": null, 171 | "grid_auto_flow": null, 172 | "grid_area": null, 173 | "grid_template_columns": null, 174 | "flex": "2", 175 | "_model_name": "LayoutModel", 176 | "justify_items": null, 177 | "grid_row": null, 178 | "max_height": null, 179 | "align_content": null, 180 | "visibility": null, 181 | "align_self": null, 182 | "height": null, 183 | "min_height": null, 184 | "padding": null, 185 | 
"grid_auto_rows": null, 186 | "grid_gap": null, 187 | "max_width": null, 188 | "order": null, 189 | "_view_module_version": "1.2.0", 190 | "grid_template_areas": null, 191 | "object_position": null, 192 | "object_fit": null, 193 | "grid_auto_columns": null, 194 | "margin": null, 195 | "display": null, 196 | "left": null 197 | } 198 | }, 199 | "828e6e2c56f3456cb016bf7c8b701ba8": { 200 | "model_module": "@jupyter-widgets/controls", 201 | "model_name": "DescriptionStyleModel", 202 | "state": { 203 | "_view_name": "StyleView", 204 | "_model_name": "DescriptionStyleModel", 205 | "description_width": "", 206 | "_view_module": "@jupyter-widgets/base", 207 | "_model_module_version": "1.5.0", 208 | "_view_count": null, 209 | "_view_module_version": "1.2.0", 210 | "_model_module": "@jupyter-widgets/controls" 211 | } 212 | }, 213 | "7540c343bf9f461a84c157f84d529cd9": { 214 | "model_module": "@jupyter-widgets/base", 215 | "model_name": "LayoutModel", 216 | "state": { 217 | "_view_name": "LayoutView", 218 | "grid_template_rows": null, 219 | "right": null, 220 | "justify_content": null, 221 | "_view_module": "@jupyter-widgets/base", 222 | "overflow": null, 223 | "_model_module_version": "1.2.0", 224 | "_view_count": null, 225 | "flex_flow": null, 226 | "width": null, 227 | "min_width": null, 228 | "border": null, 229 | "align_items": null, 230 | "bottom": null, 231 | "_model_module": "@jupyter-widgets/base", 232 | "top": null, 233 | "grid_column": null, 234 | "overflow_y": null, 235 | "overflow_x": null, 236 | "grid_auto_flow": null, 237 | "grid_area": null, 238 | "grid_template_columns": null, 239 | "flex": null, 240 | "_model_name": "LayoutModel", 241 | "justify_items": null, 242 | "grid_row": null, 243 | "max_height": null, 244 | "align_content": null, 245 | "visibility": null, 246 | "align_self": null, 247 | "height": null, 248 | "min_height": null, 249 | "padding": null, 250 | "grid_auto_rows": null, 251 | "grid_gap": null, 252 | "max_width": null, 253 | "order": null, 254 
| "_view_module_version": "1.2.0", 255 | "grid_template_areas": null, 256 | "object_position": null, 257 | "object_fit": null, 258 | "grid_auto_columns": null, 259 | "margin": null, 260 | "display": null, 261 | "left": null 262 | } 263 | }, 264 | "c43e576730a940c28fa78d49e95e7165": { 265 | "model_module": "@jupyter-widgets/controls", 266 | "model_name": "HBoxModel", 267 | "state": { 268 | "_view_name": "HBoxView", 269 | "_dom_classes": [], 270 | "_model_name": "HBoxModel", 271 | "_view_module": "@jupyter-widgets/controls", 272 | "_model_module_version": "1.5.0", 273 | "_view_count": null, 274 | "_view_module_version": "1.5.0", 275 | "box_style": "", 276 | "layout": "IPY_MODEL_b3d8f9b86d0d47ae8c51b8f2eb202aab", 277 | "_model_module": "@jupyter-widgets/controls", 278 | "children": [ 279 | "IPY_MODEL_82170aabdef246edbf668bb1cdf4a5e3", 280 | "IPY_MODEL_43a58bf6d9ab454095f8f5f30f10cdca" 281 | ] 282 | } 283 | }, 284 | "b3d8f9b86d0d47ae8c51b8f2eb202aab": { 285 | "model_module": "@jupyter-widgets/base", 286 | "model_name": "LayoutModel", 287 | "state": { 288 | "_view_name": "LayoutView", 289 | "grid_template_rows": null, 290 | "right": null, 291 | "justify_content": null, 292 | "_view_module": "@jupyter-widgets/base", 293 | "overflow": null, 294 | "_model_module_version": "1.2.0", 295 | "_view_count": null, 296 | "flex_flow": "row wrap", 297 | "width": "100%", 298 | "min_width": null, 299 | "border": null, 300 | "align_items": null, 301 | "bottom": null, 302 | "_model_module": "@jupyter-widgets/base", 303 | "top": null, 304 | "grid_column": null, 305 | "overflow_y": null, 306 | "overflow_x": null, 307 | "grid_auto_flow": null, 308 | "grid_area": null, 309 | "grid_template_columns": null, 310 | "flex": null, 311 | "_model_name": "LayoutModel", 312 | "justify_items": null, 313 | "grid_row": null, 314 | "max_height": null, 315 | "align_content": null, 316 | "visibility": null, 317 | "align_self": null, 318 | "height": null, 319 | "min_height": null, 320 | "padding": null, 321 
| "grid_auto_rows": null, 322 | "grid_gap": null, 323 | "max_width": null, 324 | "order": null, 325 | "_view_module_version": "1.2.0", 326 | "grid_template_areas": null, 327 | "object_position": null, 328 | "object_fit": null, 329 | "grid_auto_columns": null, 330 | "margin": null, 331 | "display": "inline-flex", 332 | "left": null 333 | } 334 | }, 335 | "82170aabdef246edbf668bb1cdf4a5e3": { 336 | "model_module": "@jupyter-widgets/controls", 337 | "model_name": "FloatProgressModel", 338 | "state": { 339 | "_view_name": "ProgressView", 340 | "style": "IPY_MODEL_f80b50d773b240a091c6c3bdf7961924", 341 | "_dom_classes": [], 342 | "description": "Epoch 1: 10%", 343 | "_model_name": "FloatProgressModel", 344 | "bar_style": "info", 345 | "max": 1577, 346 | "_view_module": "@jupyter-widgets/controls", 347 | "_model_module_version": "1.5.0", 348 | "value": 160, 349 | "_view_count": null, 350 | "_view_module_version": "1.5.0", 351 | "orientation": "horizontal", 352 | "min": 0, 353 | "description_tooltip": null, 354 | "_model_module": "@jupyter-widgets/controls", 355 | "layout": "IPY_MODEL_f8a352088b8d4ed896bf8a206ecc024e" 356 | } 357 | }, 358 | "43a58bf6d9ab454095f8f5f30f10cdca": { 359 | "model_module": "@jupyter-widgets/controls", 360 | "model_name": "HTMLModel", 361 | "state": { 362 | "_view_name": "HTMLView", 363 | "style": "IPY_MODEL_2e151daf92e84c369ee90e8ded7a24f2", 364 | "_dom_classes": [], 365 | "description": "", 366 | "_model_name": "HTMLModel", 367 | "placeholder": "​", 368 | "_view_module": "@jupyter-widgets/controls", 369 | "_model_module_version": "1.5.0", 370 | "value": " 160/1577 [06:58<1:01:48, 2.62s/it, loss=0.485, v_num=8dcde95206ec45daac4cc6657844b03d, train_loss=0.0916, train_loss_aux=1.97, train_prec_avg=0.405, total_loss_train=0.485, train_loss_avg=0.117]", 371 | "_view_count": null, 372 | "_view_module_version": "1.5.0", 373 | "description_tooltip": null, 374 | "_model_module": "@jupyter-widgets/controls", 375 | "layout": 
"IPY_MODEL_5b0b2240357743c4b8285ce6017638c4" 376 | } 377 | }, 378 | "f80b50d773b240a091c6c3bdf7961924": { 379 | "model_module": "@jupyter-widgets/controls", 380 | "model_name": "ProgressStyleModel", 381 | "state": { 382 | "_view_name": "StyleView", 383 | "_model_name": "ProgressStyleModel", 384 | "description_width": "initial", 385 | "_view_module": "@jupyter-widgets/base", 386 | "_model_module_version": "1.5.0", 387 | "_view_count": null, 388 | "_view_module_version": "1.2.0", 389 | "bar_color": null, 390 | "_model_module": "@jupyter-widgets/controls" 391 | } 392 | }, 393 | "f8a352088b8d4ed896bf8a206ecc024e": { 394 | "model_module": "@jupyter-widgets/base", 395 | "model_name": "LayoutModel", 396 | "state": { 397 | "_view_name": "LayoutView", 398 | "grid_template_rows": null, 399 | "right": null, 400 | "justify_content": null, 401 | "_view_module": "@jupyter-widgets/base", 402 | "overflow": null, 403 | "_model_module_version": "1.2.0", 404 | "_view_count": null, 405 | "flex_flow": null, 406 | "width": null, 407 | "min_width": null, 408 | "border": null, 409 | "align_items": null, 410 | "bottom": null, 411 | "_model_module": "@jupyter-widgets/base", 412 | "top": null, 413 | "grid_column": null, 414 | "overflow_y": null, 415 | "overflow_x": null, 416 | "grid_auto_flow": null, 417 | "grid_area": null, 418 | "grid_template_columns": null, 419 | "flex": "2", 420 | "_model_name": "LayoutModel", 421 | "justify_items": null, 422 | "grid_row": null, 423 | "max_height": null, 424 | "align_content": null, 425 | "visibility": null, 426 | "align_self": null, 427 | "height": null, 428 | "min_height": null, 429 | "padding": null, 430 | "grid_auto_rows": null, 431 | "grid_gap": null, 432 | "max_width": null, 433 | "order": null, 434 | "_view_module_version": "1.2.0", 435 | "grid_template_areas": null, 436 | "object_position": null, 437 | "object_fit": null, 438 | "grid_auto_columns": null, 439 | "margin": null, 440 | "display": null, 441 | "left": null 442 | } 443 | }, 444 | 
"2e151daf92e84c369ee90e8ded7a24f2": { 445 | "model_module": "@jupyter-widgets/controls", 446 | "model_name": "DescriptionStyleModel", 447 | "state": { 448 | "_view_name": "StyleView", 449 | "_model_name": "DescriptionStyleModel", 450 | "description_width": "", 451 | "_view_module": "@jupyter-widgets/base", 452 | "_model_module_version": "1.5.0", 453 | "_view_count": null, 454 | "_view_module_version": "1.2.0", 455 | "_model_module": "@jupyter-widgets/controls" 456 | } 457 | }, 458 | "5b0b2240357743c4b8285ce6017638c4": { 459 | "model_module": "@jupyter-widgets/base", 460 | "model_name": "LayoutModel", 461 | "state": { 462 | "_view_name": "LayoutView", 463 | "grid_template_rows": null, 464 | "right": null, 465 | "justify_content": null, 466 | "_view_module": "@jupyter-widgets/base", 467 | "overflow": null, 468 | "_model_module_version": "1.2.0", 469 | "_view_count": null, 470 | "flex_flow": null, 471 | "width": null, 472 | "min_width": null, 473 | "border": null, 474 | "align_items": null, 475 | "bottom": null, 476 | "_model_module": "@jupyter-widgets/base", 477 | "top": null, 478 | "grid_column": null, 479 | "overflow_y": null, 480 | "overflow_x": null, 481 | "grid_auto_flow": null, 482 | "grid_area": null, 483 | "grid_template_columns": null, 484 | "flex": null, 485 | "_model_name": "LayoutModel", 486 | "justify_items": null, 487 | "grid_row": null, 488 | "max_height": null, 489 | "align_content": null, 490 | "visibility": null, 491 | "align_self": null, 492 | "height": null, 493 | "min_height": null, 494 | "padding": null, 495 | "grid_auto_rows": null, 496 | "grid_gap": null, 497 | "max_width": null, 498 | "order": null, 499 | "_view_module_version": "1.2.0", 500 | "grid_template_areas": null, 501 | "object_position": null, 502 | "object_fit": null, 503 | "grid_auto_columns": null, 504 | "margin": null, 505 | "display": null, 506 | "left": null 507 | } 508 | } 509 | } 510 | } 511 | }, 512 | "cells": [ 513 | { 514 | "cell_type": "markdown", 515 | "metadata": { 516 
| "id": "uoLSVVIBCwLm", 517 | "colab_type": "text" 518 | }, 519 | "source": [ 520 | "# Interpretable Contextual Team-aware Item Recommendation: Application in Multiplayer Online Battle Arena Games\n", 521 | "*Andres Villa, Vladimir Araujo, Francisca Cattan*" 522 | ] 523 | }, 524 | { 525 | "cell_type": "markdown", 526 | "metadata": { 527 | "id": "t8_YV_PIDR97", 528 | "colab_type": "text" 529 | }, 530 | "source": [ 531 | "# Introduction\n", 532 | "\n", 533 | "This notebook contains the code of the proposed model. It is composed of 11 main stages:\n", 534 | "\n", 535 | "1. Connect to gDrive\n", 536 | "2. Dataset and Transformations\n", 537 | "3. Model\n", 538 | "4. Logger and Checkpointer\n", 539 | "5. Metrics\n", 540 | "6. Training and evaluation loop\n", 541 | "7. Config file\n", 542 | "8. Training and evaluation executor\n", 543 | "9. Obtain the role and id of each champion in each match\n", 544 | "10. Load the attention weights\n", 545 | "11. Draw the attention map\n", 546 | "\n", 547 | "*This notebook can be run in its entirety. The final cell executes the training and validation of the model.* 
" ] 548 | }, 549 | { 550 | "cell_type": "markdown", 551 | "metadata": { 552 | "id": "eYxhCKYPbBT2", 553 | "colab_type": "toc" 554 | }, 555 | "source": [ 556 | ">[Main Model - Project Title](#scrollTo=uoLSVVIBCwLm)\n", 557 | "\n", 558 | ">[Introduction](#scrollTo=t8_YV_PIDR97)\n", 559 | "\n", 560 | ">[Install all the dependencies](#scrollTo=etkQTYydGkFM)\n", 561 | "\n", 562 | ">[Import the dependencies](#scrollTo=S0YvGjijGxET)\n", 563 | "\n", 564 | ">[Connect to gDrive](#scrollTo=pfDyM4E7G4L2)\n", 565 | "\n", 566 | ">[Dataset and Transformations](#scrollTo=h9MDWroJSkhM)\n", 567 | "\n", 568 | ">[Model](#scrollTo=UIm1_KUCUNB0)\n", 569 | "\n", 570 | ">>[Transformer encoder modified to obtain the attention weights](#scrollTo=qr3TZbrnUg2H)\n", 571 | "\n", 572 | ">>[Auxiliary Task Classes](#scrollTo=pwRy106QU6sH)\n", 573 | "\n", 574 | ">>[Main Class of the proposed model](#scrollTo=9AlU_u42VG8A)\n", 575 | "\n", 576 | ">[Logger and Checkpointer](#scrollTo=rwYoKWcsVqex)\n", 577 | "\n", 578 | ">[Metrics](#scrollTo=5ktMqAUMWeEz)\n", 579 | "\n", 580 | ">[Training and evaluation loop](#scrollTo=WDA0GHysW4vX)\n", 581 | "\n", 582 | ">[Config file](#scrollTo=CyRfaqN8XvYi)\n", 583 | "\n", 584 | ">[Training and evaluation executor](#scrollTo=IVtKoVTcYDS1)\n", 585 | "\n", 586 | ">[T-test](#scrollTo=D2TPs5U3vv7m)\n", 587 | "\n", 588 | ">[Obtain the role and id of each champion in each match](#scrollTo=sFtaUCU5T8fl)\n", 589 | "\n", 590 | ">[Load the attention weights](#scrollTo=VINfHm76U1vz)\n", 591 | "\n", 592 | ">[Draw the attention map](#scrollTo=SvhQCEzcU6x_)\n", 593 | "\n", 594 | "\n" 595 | ] 596 | }, 597 | { 598 | "cell_type": "markdown", 599 | "metadata": { 600 | "id": "etkQTYydGkFM", 601 | "colab_type": "text" 602 | }, 603 | "source": [ 604 | "# Install all the dependencies" 605 | ] 606 | }, 607 | { 608 | "cell_type": "markdown", 609 | "metadata": { 610 | "id": "uH0huCgYRJou", 611 | "colab_type": "text" 612 | }, 613 | "source": [ 614 | "Install all the libraries necessary to run the
model. " 615 | ] 616 | }, 617 | { 618 | "cell_type": "code", 619 | "metadata": { 620 | "id": "qsqG6vM9tqER", 621 | "colab_type": "code", 622 | "colab": { 623 | "base_uri": "https://localhost:8080/", 624 | "height": 369 625 | }, 626 | "outputId": "1450ea6d-a49f-4586-b3a9-6c111b197efa" 627 | }, 628 | "source": [ 629 | "!nvidia-smi" 630 | ], 631 | "execution_count": 3, 632 | "outputs": [ 633 | { 634 | "output_type": "stream", 635 | "text": [ 636 | "Tue Jul 28 06:40:42 2020 \n", 637 | "+-----------------------------------------------------------------------------+\n", 638 | "| NVIDIA-SMI 450.51.05 Driver Version: 418.67 CUDA Version: 10.1 |\n", 639 | "|-------------------------------+----------------------+----------------------+\n", 640 | "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", 641 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", 642 | "| | | MIG M. |\n", 643 | "|===============================+======================+======================|\n", 644 | "| 0 Tesla P100-PCIE... Off | 00000000:00:04.0 Off | 0 |\n", 645 | "| N/A 43C P0 28W / 250W | 0MiB / 16280MiB | 0% Default |\n", 646 | "| | | ERR! 
|\n", 647 | "+-------------------------------+----------------------+----------------------+\n", 648 | " \n", 649 | "+-----------------------------------------------------------------------------+\n", 650 | "| Processes: |\n", 651 | "| GPU GI CI PID Type Process name GPU Memory |\n", 652 | "| ID ID Usage |\n", 653 | "|=============================================================================|\n", 654 | "| No running processes found |\n", 655 | "+-----------------------------------------------------------------------------+\n" 656 | ], 657 | "name": "stdout" 658 | } 659 | ] 660 | }, 661 | { 662 | "cell_type": "code", 663 | "metadata": { 664 | "id": "qTd06UwscDsS", 665 | "colab_type": "code", 666 | "colab": { 667 | "base_uri": "https://localhost:8080/", 668 | "height": 1000 669 | }, 670 | "outputId": "ca753b75-1078-420f-cdfb-b3e5ee505b97" 671 | }, 672 | "source": [ 673 | "!pip install git+git://github.com/williamFalcon/pytorch-lightning.git@master --upgrade" 674 | ], 675 | "execution_count": 4, 676 | "outputs": [ 677 | { 678 | "output_type": "stream", 679 | "text": [ 680 | "Collecting git+git://github.com/williamFalcon/pytorch-lightning.git@master\n", 681 | " Cloning git://github.com/williamFalcon/pytorch-lightning.git (to revision master) to /tmp/pip-req-build-j84m_1zy\n", 682 | " Running command git clone -q git://github.com/williamFalcon/pytorch-lightning.git /tmp/pip-req-build-j84m_1zy\n", 683 | " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", 684 | " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", 685 | " Preparing wheel metadata ... 
\u001b[?25l\u001b[?25hdone\n", 686 | "Collecting future>=0.17.1\n", 687 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/45/0b/38b06fd9b92dc2b68d58b75f900e97884c45bedd2ff83203d933cf5851c9/future-0.18.2.tar.gz (829kB)\n", 688 | "\u001b[K |████████████████████████████████| 829kB 2.9MB/s \n", 689 | "\u001b[?25hRequirement already satisfied, skipping upgrade: tqdm>=4.41.0 in /usr/local/lib/python3.6/dist-packages (from pytorch-lightning==0.9.0rc2) (4.41.1)\n", 690 | "Requirement already satisfied, skipping upgrade: numpy>=1.16.4 in /usr/local/lib/python3.6/dist-packages (from pytorch-lightning==0.9.0rc2) (1.18.5)\n", 691 | "Requirement already satisfied, skipping upgrade: torch>=1.3 in /usr/local/lib/python3.6/dist-packages (from pytorch-lightning==0.9.0rc2) (1.5.1+cu101)\n", 692 | "Requirement already satisfied, skipping upgrade: tensorboard>=1.14 in /usr/local/lib/python3.6/dist-packages (from pytorch-lightning==0.9.0rc2) (2.2.2)\n", 693 | "Collecting PyYAML>=5.1\n", 694 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)\n", 695 | "\u001b[K |████████████████████████████████| 276kB 15.9MB/s \n", 696 | "\u001b[?25hRequirement already satisfied, skipping upgrade: protobuf>=3.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (3.12.2)\n", 697 | "Requirement already satisfied, skipping upgrade: grpcio>=1.24.3 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (1.30.0)\n", 698 | "Requirement already satisfied, skipping upgrade: wheel>=0.26; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (0.34.2)\n", 699 | "Requirement already satisfied, skipping upgrade: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (1.0.1)\n", 700 
| "Requirement already satisfied, skipping upgrade: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (3.2.2)\n", 701 | "Requirement already satisfied, skipping upgrade: absl-py>=0.4 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (0.9.0)\n", 702 | "Requirement already satisfied, skipping upgrade: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (1.15.0)\n", 703 | "Requirement already satisfied, skipping upgrade: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (1.7.0)\n", 704 | "Requirement already satisfied, skipping upgrade: setuptools>=41.0.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (49.1.0)\n", 705 | "Requirement already satisfied, skipping upgrade: google-auth<2,>=1.6.3 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (1.17.2)\n", 706 | "Requirement already satisfied, skipping upgrade: requests<3,>=2.21.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (2.23.0)\n", 707 | "Requirement already satisfied, skipping upgrade: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (0.4.1)\n", 708 | "Requirement already satisfied, skipping upgrade: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from markdown>=2.6.8->tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (1.7.0)\n", 709 | "Requirement already satisfied, skipping upgrade: rsa<5,>=3.1.4; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (4.6)\n", 710 | "Requirement already satisfied, skipping upgrade: 
cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (4.1.1)\n", 711 | "Requirement already satisfied, skipping upgrade: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (0.2.8)\n", 712 | "Requirement already satisfied, skipping upgrade: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (2020.6.20)\n", 713 | "Requirement already satisfied, skipping upgrade: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (1.24.3)\n", 714 | "Requirement already satisfied, skipping upgrade: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (3.0.4)\n", 715 | "Requirement already satisfied, skipping upgrade: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (2.10)\n", 716 | "Requirement already satisfied, skipping upgrade: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (1.3.0)\n", 717 | "Requirement already satisfied, skipping upgrade: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < \"3.8\"->markdown>=2.6.8->tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (3.1.0)\n", 718 | "Requirement already satisfied, skipping upgrade: pyasn1>=0.1.3 in /usr/local/lib/python3.6/dist-packages (from rsa<5,>=3.1.4; python_version >= \"3\"->google-auth<2,>=1.6.3->tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (0.4.8)\n", 719 | "Requirement already satisfied, skipping upgrade: oauthlib>=3.0.0 in 
/usr/local/lib/python3.6/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=1.14->pytorch-lightning==0.9.0rc2) (3.1.0)\n", 720 | "Building wheels for collected packages: pytorch-lightning\n", 721 | " Building wheel for pytorch-lightning (PEP 517) ... \u001b[?25l\u001b[?25hdone\n", 722 | " Created wheel for pytorch-lightning: filename=pytorch_lightning-0.9.0rc2-cp36-none-any.whl size=353828 sha256=30b73a303ccd241770a24f1984519ad7e086144a4db7285f7b8166e4de330d64\n", 723 | " Stored in directory: /tmp/pip-ephem-wheel-cache-jlmk53yu/wheels/02/e9/33/ecf2ab0b937f47c530a3d24222ca1a784412a0c7d490195c5f\n", 724 | "Successfully built pytorch-lightning\n", 725 | "Building wheels for collected packages: future, PyYAML\n", 726 | " Building wheel for future (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 727 | " Created wheel for future: filename=future-0.18.2-cp36-none-any.whl size=491057 sha256=15e732369ebb372a11250e7c9d0e57f7af49cc24c3e76854d4126655f02df3b9\n", 728 | " Stored in directory: /root/.cache/pip/wheels/8b/99/a0/81daf51dcd359a9377b110a8a886b3895921802d2fc1b2397e\n", 729 | " Building wheel for PyYAML (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", 730 | " Created wheel for PyYAML: filename=PyYAML-5.3.1-cp36-cp36m-linux_x86_64.whl size=44621 sha256=085fe27c9be3cbcd42d6336ee010929d00e36263ea5e88a9613aeb9fbf6e1b49\n", 731 | " Stored in directory: /root/.cache/pip/wheels/a7/c1/ea/cf5bd31012e735dc1dfea3131a2d5eae7978b251083d6247bd\n", 732 | "Successfully built future PyYAML\n", 733 | "Installing collected packages: future, PyYAML, pytorch-lightning\n", 734 | " Found existing installation: future 0.16.0\n", 735 | " Uninstalling future-0.16.0:\n", 736 | " Successfully uninstalled future-0.16.0\n", 737 | " Found existing installation: PyYAML 3.13\n", 738 | " Uninstalling PyYAML-3.13:\n", 739 | " Successfully uninstalled PyYAML-3.13\n", 740 | "Successfully installed PyYAML-5.3.1 future-0.18.2 pytorch-lightning-0.9.0rc2\n" 741 | ], 742 | "name": "stdout" 743 | } 744 | ] 745 | }, 746 | { 747 | "cell_type": "code", 748 | "metadata": { 749 | "id": "AwD1P0lHVaLO", 750 | "colab_type": "code", 751 | "colab": { 752 | "base_uri": "https://localhost:8080/", 753 | "height": 600 754 | }, 755 | "outputId": "4d2b24ba-90cd-4304-cb42-f09c697a3827" 756 | }, 757 | "source": [ 758 | "!pip install comet_ml==3.0.2" 759 | ], 760 | "execution_count": 5, 761 | "outputs": [ 762 | { 763 | "output_type": "stream", 764 | "text": [ 765 | "Collecting comet_ml==3.0.2\n", 766 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/99/c6/fac88f43f2aa61a09fee4ffb769c73fe93fe7de75764246e70967d31da09/comet_ml-3.0.2-py3-none-any.whl (170kB)\n", 767 | "\u001b[K |████████████████████████████████| 174kB 2.9MB/s \n", 768 | "\u001b[?25hCollecting websocket-client>=0.55.0\n", 769 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/4c/5f/f61b420143ed1c8dc69f9eaec5ff1ac36109d52c80de49d66e0c36c3dfdf/websocket_client-0.57.0-py2.py3-none-any.whl (200kB)\n", 770 | "\u001b[K |████████████████████████████████| 204kB 8.9MB/s \n", 771 | "\u001b[?25hCollecting everett[ini]>=1.0.1; python_version >= \"3.0\"\n", 772 
| " Downloading https://files.pythonhosted.org/packages/12/34/de70a3d913411e40ce84966f085b5da0c6df741e28c86721114dd290aaa0/everett-1.0.2-py2.py3-none-any.whl\n", 773 | "Requirement already satisfied: requests>=2.18.4 in /usr/local/lib/python3.6/dist-packages (from comet_ml==3.0.2) (2.23.0)\n", 774 | "Collecting wurlitzer>=1.0.2\n", 775 | " Downloading https://files.pythonhosted.org/packages/0c/1e/52f4effa64a447c4ec0fb71222799e2ac32c55b4b6c1725fccdf6123146e/wurlitzer-2.0.1-py2.py3-none-any.whl\n", 776 | "Collecting comet-git-pure>=0.19.11\n", 777 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/72/7a/483413046e48908986a0f9a1d8a917e1da46ae58e6ba16b2ac71b3adf8d7/comet_git_pure-0.19.16-py3-none-any.whl (409kB)\n", 778 | "\u001b[K |████████████████████████████████| 419kB 8.8MB/s \n", 779 | "\u001b[?25hRequirement already satisfied: jsonschema<3.1.0,>=2.6.0 in /usr/local/lib/python3.6/dist-packages (from comet_ml==3.0.2) (2.6.0)\n", 780 | "Collecting netifaces>=0.10.7\n", 781 | " Downloading https://files.pythonhosted.org/packages/0c/9b/c4c7eb09189548d45939a3d3a6b3d53979c67d124459b27a094c365c347f/netifaces-0.10.9-cp36-cp36m-manylinux1_x86_64.whl\n", 782 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from comet_ml==3.0.2) (1.15.0)\n", 783 | "Requirement already satisfied: nvidia-ml-py3>=7.352.0 in /usr/local/lib/python3.6/dist-packages (from comet_ml==3.0.2) (7.352.0)\n", 784 | "Collecting configobj; extra == \"ini\"\n", 785 | " Downloading https://files.pythonhosted.org/packages/64/61/079eb60459c44929e684fa7d9e2fdca403f67d64dd9dbac27296be2e0fab/configobj-5.0.6.tar.gz\n", 786 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests>=2.18.4->comet_ml==3.0.2) (1.24.3)\n", 787 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests>=2.18.4->comet_ml==3.0.2) (2020.6.20)\n", 788 | "Requirement 
already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests>=2.18.4->comet_ml==3.0.2) (3.0.4)\n", 789 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests>=2.18.4->comet_ml==3.0.2) (2.10)\n", 790 | "Building wheels for collected packages: configobj\n", 791 | " Building wheel for configobj (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 792 | " Created wheel for configobj: filename=configobj-5.0.6-cp36-none-any.whl size=34546 sha256=b4795ece5a011d0faed00fc2c825d5a241e54d2ac170c3ec2b6ba97eafcb809f\n", 793 | " Stored in directory: /root/.cache/pip/wheels/f1/e4/16/4981ca97c2d65106b49861e0b35e2660695be7219a2d351ee0\n", 794 | "Successfully built configobj\n", 795 | "Installing collected packages: websocket-client, configobj, everett, wurlitzer, comet-git-pure, netifaces, comet-ml\n", 796 | "Successfully installed comet-git-pure-0.19.16 comet-ml-3.0.2 configobj-5.0.6 everett-1.0.2 netifaces-0.10.9 websocket-client-0.57.0 wurlitzer-2.0.1\n" 797 | ], 798 | "name": "stdout" 799 | } 800 | ] 801 | }, 802 | { 803 | "cell_type": "code", 804 | "metadata": { 805 | "id": "9vHwQVrRWBgV", 806 | "colab_type": "code", 807 | "colab": { 808 | "base_uri": "https://localhost:8080/", 809 | "height": 160 810 | }, 811 | "outputId": "3b656b18-e85b-4549-a55b-f99a5d1b1c73" 812 | }, 813 | "source": [ 814 | "!pip install omegaconf" 815 | ], 816 | "execution_count": 6, 817 | "outputs": [ 818 | { 819 | "output_type": "stream", 820 | "text": [ 821 | "Collecting omegaconf\n", 822 | " Downloading https://files.pythonhosted.org/packages/3d/95/ebd73361f9c6e94bd0f3b19ffe31c24e833834c022f1c0328ac71b2d6c90/omegaconf-2.0.0-py3-none-any.whl\n", 823 | "Requirement already satisfied: PyYAML in /usr/local/lib/python3.6/dist-packages (from omegaconf) (5.3.1)\n", 824 | "Requirement already satisfied: dataclasses; python_version == \"3.6\" in /usr/local/lib/python3.6/dist-packages (from omegaconf) (0.7)\n", 825 | "Requirement already 
satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from omegaconf) (3.7.4.2)\n", 826 | "Installing collected packages: omegaconf\n", 827 | "Successfully installed omegaconf-2.0.0\n" 828 | ], 829 | "name": "stdout" 830 | } 831 | ] 832 | }, 833 | { 834 | "cell_type": "code", 835 | "metadata": { 836 | "id": "6yA6EwQM6P8i", 837 | "colab_type": "code", 838 | "colab": { 839 | "base_uri": "https://localhost:8080/", 840 | "height": 160 841 | }, 842 | "outputId": "0dfe679a-88c9-4cdc-81d8-c98b24d5e010" 843 | }, 844 | "source": [ 845 | "!pip install adabound" 846 | ], 847 | "execution_count": 7, 848 | "outputs": [ 849 | { 850 | "output_type": "stream", 851 | "text": [ 852 | "Collecting adabound\n", 853 | " Downloading https://files.pythonhosted.org/packages/cd/44/0c2c414effb3d9750d780b230dbb67ea48ddc5d9a6d7a9b7e6fcc6bdcff9/adabound-0.0.5-py3-none-any.whl\n", 854 | "Requirement already satisfied: torch>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from adabound) (1.5.1+cu101)\n", 855 | "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch>=0.4.0->adabound) (0.18.2)\n", 856 | "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch>=0.4.0->adabound) (1.18.5)\n", 857 | "Installing collected packages: adabound\n", 858 | "Successfully installed adabound-0.0.5\n" 859 | ], 860 | "name": "stdout" 861 | } 862 | ] 863 | }, 864 | { 865 | "cell_type": "code", 866 | "metadata": { 867 | "id": "JMNgXeWwNx-N", 868 | "colab_type": "code", 869 | "colab": { 870 | "base_uri": "https://localhost:8080/", 871 | "height": 283 872 | }, 873 | "outputId": "521241cd-9097-4098-da3a-71bc7d2e3c05" 874 | }, 875 | "source": [ 876 | "!pip install ml_metrics" 877 | ], 878 | "execution_count": 8, 879 | "outputs": [ 880 | { 881 | "output_type": "stream", 882 | "text": [ 883 | "Collecting ml_metrics\n", 884 | " Downloading 
https://files.pythonhosted.org/packages/c1/e7/c31a2dd37045a0c904bee31c2dbed903d4f125a6ce980b91bae0c961abb8/ml_metrics-0.1.4.tar.gz\n", 885 | "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from ml_metrics) (1.18.5)\n", 886 | "Requirement already satisfied: pandas in /usr/local/lib/python3.6/dist-packages (from ml_metrics) (1.0.5)\n", 887 | "Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from pandas->ml_metrics) (2.8.1)\n", 888 | "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas->ml_metrics) (2018.9)\n", 889 | "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.6/dist-packages (from python-dateutil>=2.6.1->pandas->ml_metrics) (1.15.0)\n", 890 | "Building wheels for collected packages: ml-metrics\n", 891 | " Building wheel for ml-metrics (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 892 | " Created wheel for ml-metrics: filename=ml_metrics-0.1.4-cp36-none-any.whl size=7850 sha256=4c96be29d4d35a67a4b7cfad5648a20405ba81588e3a54aafe24d966ed354a94\n", 893 | " Stored in directory: /root/.cache/pip/wheels/b3/61/2d/776be7b8a4f14c5db48c8e5451451cabc58dc6aa7ee3801163\n", 894 | "Successfully built ml-metrics\n", 895 | "Installing collected packages: ml-metrics\n", 896 | "Successfully installed ml-metrics-0.1.4\n" 897 | ], 898 | "name": "stdout" 899 | } 900 | ] 901 | }, 902 | { 903 | "cell_type": "markdown", 904 | "metadata": { 905 | "id": "S0YvGjijGxET", 906 | "colab_type": "text" 907 | }, 908 | "source": [ 909 | "# Import the dependencies" 910 | ] 911 | }, 912 | { 913 | "cell_type": "markdown", 914 | "metadata": { 915 | "id": "ypMu2vhSRZD2", 916 | "colab_type": "text" 917 | }, 918 | "source": [ 919 | "Import all the libraries necessary to run the model."
920 | ] 921 | }, 922 | { 923 | "cell_type": "code", 924 | "metadata": { 925 | "id": "jNaJbRoaa8ZU", 926 | "colab_type": "code", 927 | "colab": {} 928 | }, 929 | "source": [ 930 | "from comet_ml import Experiment as CometExperiment\n", 931 | "from comet_ml import ExistingExperiment as CometExistingExperiment\n", 932 | "from google.colab import drive\n", 933 | "import torch\n", 934 | "import copy\n", 935 | "import torch.nn as nn\n", 936 | "import torch.nn.functional as F\n", 937 | "from torch.utils.data import Dataset, DataLoader\n", 938 | "import numpy as np\n", 939 | "from omegaconf import OmegaConf\n", 940 | "from omegaconf.dictconfig import DictConfig\n", 941 | "import pandas as pd\n", 942 | "import time\n", 943 | "\n", 944 | "# from tqdm.notebook import trange, tqdm\n", 945 | "from pytorch_lightning.callbacks import ModelCheckpoint\n", 946 | "from pytorch_lightning.utilities import rank_zero_only\n", 947 | "from pytorch_lightning.loggers import LightningLoggerBase\n", 948 | "from pytorch_lightning.loggers import CometLogger\n", 949 | "\n", 950 | "import os\n", 951 | "import pytorch_lightning as pl\n", 952 | "import pickle\n", 953 | "import adabound\n", 954 | "import ml_metrics as metrics\n", 955 | "import random\n", 956 | "import itertools\n", 957 | "from torchvision import transforms\n" 958 | ], 959 | "execution_count": 9, 960 | "outputs": [] 961 | }, 962 | { 963 | "cell_type": "markdown", 964 | "metadata": { 965 | "id": "pfDyM4E7G4L2", 966 | "colab_type": "text" 967 | }, 968 | "source": [ 969 | "# Connect to gDrive" 970 | ] 971 | }, 972 | { 973 | "cell_type": "markdown", 974 | "metadata": { 975 | "id": "XLY4l7XzReoa", 976 | "colab_type": "text" 977 | }, 978 | "source": [ 979 | "Mount Google Drive in the notebook. It is required to load the dataset and to save checkpoints and attention weights."
 980 | ] 981 | }, 982 | { 983 | "cell_type": "code", 984 | "metadata": { 985 | "id": "hmQwtSRCbchS", 986 | "colab_type": "code", 987 | "colab": { 988 | "base_uri": "https://localhost:8080/", 989 | "height": 125 990 | }, 991 | "outputId": "26c7f9ce-aa16-4b69-aea8-38817cb64731" 992 | }, 993 | "source": [ 994 | "drive.mount('/content/gdrive/')" 995 | ], 996 | "execution_count": 10, 997 | "outputs": [ 998 | { 999 | "output_type": "stream", 1000 | "text": [ 1001 | "Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly\n", 1002 | "\n", 1003 | "Enter your authorization code:\n", 1004 | "··········\n", 1005 | "Mounted at /content/gdrive/\n" 1006 | ], 1007 | "name": "stdout" 1008 | } 1009 | ] 1010 | }, 1011 | { 1012 | "cell_type": "markdown", 1013 | "metadata": { 1014 | "id": "h9MDWroJSkhM", 1015 | "colab_type": "text" 1016 | }, 1017 | "source": [ 1018 | "# Dataset and Transformations" 1019 | ] 1020 | }, 1021 | { 1022 | "cell_type": "markdown", 1023 | "metadata": { 1024 | "id": "uIDIGrXVSzdy", 1025 | "colab_type": "text" 1026 | }, 1027 | "source": [ 1028 | "The following cells load the k train/test partitions obtained with k-fold cross-validation; matches in which any participant has no items are discarded."
 1029 | ] 1030 | }, 1031 | { 1032 | "cell_type": "code", 1033 | "metadata": { 1034 | "id": "Mir9w6zDzkSN", 1035 | "colab_type": "code", 1036 | "colab": {} 1037 | }, 1038 | "source": [ 1039 | "train_path = '/content/gdrive/My Drive/Proyecto_RecSys/dataset/train_splits.pkl'\n", 1040 | "test_path = '/content/gdrive/My Drive/Proyecto_RecSys/dataset/test_splits.pkl'\n", 1041 | "champion_path = '/content/gdrive/My Drive/Proyecto_RecSys/dataset/champion_types.pkl'" 1042 | ], 1043 | "execution_count": 11, 1044 | "outputs": [] 1045 | }, 1046 | { 1047 | "cell_type": "code", 1048 | "metadata": { 1049 | "id": "yDP51QcmNgCd", 1050 | "colab_type": "code", 1051 | "cellView": "form", 1052 | "colab": {} 1053 | }, 1054 | "source": [ 1055 | "#@title Load the partition lists\n", 1056 | "with open(train_path, 'rb') as handle:\n", 1057 | " list_trainset = pickle.load(handle)\n", 1058 | "\n", 1059 | "with open(test_path, 'rb') as handle:\n", 1060 | " list_testset = pickle.load(handle)\n", 1061 | "\n", 1062 | "with open(champion_path, 'rb') as handle:\n", 1063 | " champion_types = pickle.load(handle)" 1064 | ], 1065 | "execution_count": 12, 1066 | "outputs": [] 1067 | }, 1068 | { 1069 | "cell_type": "code", 1070 | "metadata": { 1071 | "id": "2zbdC7M1S3kh", 1072 | "colab_type": "code", 1073 | "colab": {} 1074 | }, 1075 | "source": [ 1076 | "def get_partition(id_split, list_splits = list_trainset):\n", 1077 | " df = list_splits[id_split]\n", 1078 | " null_registers = df.loc[(df.item1 == 0) & (df.item2 == 0) & (df.item3 == 0) & (df.item4 == 0) & (df.item5 == 0) & (df.item6 == 0)]\n", 1079 | " match_to_del = list(set(null_registers['matchid']))\n", 1080 | " df = df[~df.matchid.isin(match_to_del)]\n", 1081 | " return df" 1082 | ], 1083 | "execution_count": 13, 1084 | "outputs": [] 1085 | }, 1086 | { 1087 | "cell_type": "markdown", 1088 | "metadata": { 1089 | "id": "dK-OEg1dTYJw", 1090 | "colab_type": "text" 1091 | }, 1092 | "source": [ 1093 | "These transformations augment each match: `RandomSort_Team` swaps the order of the two teams (flipping the win label accordingly), and `RandomSort_Part` randomly shuffles the champions within each team."
 1094 | ] 1095 | }, 1096 | { 1097 | "cell_type": "code", 1098 | "metadata": { 1099 | "id": "muCDKh3n524z", 1100 | "colab_type": "code", 1101 | "colab": {} 1102 | }, 1103 | "source": [ 1104 | "class RandomSort_Team(object):\n", 1105 | " \"\"\"Swap the order of the two teams in a match.\n", 1106 | "\n", 1107 | " Returns the input sample(s) stacked together with a copy in which\n", 1108 | " the second team comes first; champions, roles, types and items are\n", 1109 | " reordered consistently and the win label is flipped.\n", 1110 | " \"\"\"\n", 1111 | " \n", 1112 | " def get_random_sample(self, sample):\n", 1113 | " x, y = sample\n", 1114 | "\n", 1115 | " ids_teams_1 = [x for x in range(5)]\n", 1116 | " ids_teams_2 = [x for x in range(5,10)]\n", 1117 | "\n", 1118 | " ids_team_t = [ids_teams_1, ids_teams_2]\n", 1119 | "\n", 1120 | " ids_teams = [1, 0]\n", 1121 | " #ids_teams = [x for x in range(2)]\n", 1122 | " #random.shuffle(ids_teams)\n", 1123 | "\n", 1124 | " ids_team_t = [ids_team_t[i] for i in ids_teams]\n", 1125 | " \n", 1126 | " ids_team_t = list(itertools.chain.from_iterable(ids_team_t))\n", 1127 | "\n", 1128 | " x['champions'] = x['champions'][ids_team_t]\n", 1129 | " x['role'] = x['role'][ids_team_t]\n", 1130 | " x['type'] = x['type'][ids_team_t,:]\n", 1131 | "\n", 1132 | " y['items'] = y['items'][ids_team_t,:]\n", 1133 | "\n", 1134 | " if ids_teams == [1, 0]:\n", 1135 | " y['win'] = torch.tensor(1) - y['win']\n", 1136 | " \n", 1137 | " return x, y\n", 1138 | "\n", 1139 | " def __call__(self, sample_list):\n", 1140 | " list_x_champions = []\n", 1141 | " list_x_role = []\n", 1142 | " list_x_type = []\n", 1143 | " list_y_items = []\n", 1144 | " list_y_win = []\n", 1145 | " x_old, y_old = sample_list\n", 1146 | " if isinstance(x_old, (list)) and isinstance(y_old, (list)):\n", 1147 | " for i in range(len(x_old)):\n", 1148 | " list_x_champions.append(x_old[i]['champions'])\n", 1149 | " list_x_role.append(x_old[i]['role'])\n", 1150 | " list_x_type.append(x_old[i]['type'])\n", 1151 | "
list_y_items.append(y_old[i]['items'])\n", 1152 | " list_y_win.append(y_old[i]['win'])\n", 1153 | " sample = x_old[i], y_old[i]\n", 1154 | " x, y = self.get_random_sample(sample)\n", 1155 | " list_x_champions.append(x['champions'])\n", 1156 | " list_x_role.append(x['role'])\n", 1157 | " list_x_type.append(x['type'])\n", 1158 | " list_y_items.append(y['items'])\n", 1159 | " list_y_win.append(y['win'])\n", 1160 | " else:\n", 1161 | " list_x_champions.append(x_old['champions'])\n", 1162 | " list_x_role.append(x_old['role'])\n", 1163 | " list_x_type.append(x_old['type'])\n", 1164 | " list_y_items.append(y_old['items'])\n", 1165 | " list_y_win.append(y_old['win'])\n", 1166 | " sample = x_old, y_old\n", 1167 | " x, y = self.get_random_sample(sample)\n", 1168 | " list_x_champions.append(x['champions'])\n", 1169 | " list_x_role.append(x['role'])\n", 1170 | " list_x_type.append(x['type'])\n", 1171 | " list_y_items.append(y['items'])\n", 1172 | " list_y_win.append(y['win'])\n", 1173 | " new_x = {\n", 1174 | " 'champions': torch.stack(list_x_champions, dim=0),\n", 1175 | " 'role': torch.stack(list_x_role, dim=0),\n", 1176 | " 'type': torch.stack(list_x_type, dim=0)\n", 1177 | " }\n", 1178 | " new_y = {\n", 1179 | " 'items': torch.stack(list_y_items, dim=0),\n", 1180 | " 'win': torch.stack(list_y_win, dim=0)\n", 1181 | " }\n", 1182 | " return new_x, new_y\n" 1183 | ], 1184 | "execution_count": 14, 1185 | "outputs": [] 1186 | }, 1187 | { 1188 | "cell_type": "code", 1189 | "metadata": { 1190 | "id": "ozFSmovk06GG", 1191 | "colab_type": "code", 1192 | "colab": {} 1193 | }, 1194 | "source": [ 1195 | "class RandomSort_Part(object):\n",
 1196 | " \"\"\"Shuffle the champions within each team of a match.\n", 1197 | "\n", 1198 | " Returns ([x, x_shuffled], [y, y_shuffled]), where the shuffled copy\n", 1199 | " randomly permutes the five players within each team (champions, roles,\n", 1200 | " types and items consistently); the team order is kept.\n", 1201 | " \"\"\"\n", 1202 | " \n", 1203 | "\n", 1204 | " def __call__(self, sample):\n", 1205 | "\n", 1206 | " list_t_x = []\n", 1207 | " list_t_y = []\n", 1208 | " x, y = sample\n", 1209 | "\n", 1210 | " list_t_x.append(x)\n", 1211 | " list_t_y.append(y)\n", 1212 | "\n", 1213 | " ids_team_1 = [x for x in range(5)]\n", 1214 | " ids_team_2 = [x for x in range(5,10)]\n", 1215 | " random.shuffle(ids_team_1)\n", 1216 | " random.shuffle(ids_team_2)\n", 1217 | "\n", 1218 | " ids_match = ids_team_1\n", 1219 | " ids_match.extend(ids_team_2)\n", 1220 | " x, y = dict(x), dict(y) # shallow copies, so the original sample appended above is not overwritten\n", 1221 | " x['champions'] = x['champions'][ids_match]\n", 1222 | " x['role'] = x['role'][ids_match]\n", 1223 | " x['type'] = x['type'][ids_match,:]\n", 1224 | "\n", 1225 | " y['items'] = y['items'][ids_match,:]\n", 1226 | "\n", 1227 | " list_t_x.append(x)\n", 1228 | " list_t_y.append(y)\n", 1229 | "\n", 1230 | " return list_t_x, list_t_y" 1231 | ], 1232 | "execution_count": 15, 1233 | "outputs": [] 1234 | }, 1235 | { 1236 | "cell_type": "code", 1237 | "metadata": { 1238 | "id": "i3KCpfKDPX3D", 1239 | "colab_type": "code", 1240 | "colab": {} 1241 | }, 1242 | "source": [ 1243 | "class LolDataset(Dataset):\n", 1244 | " def __init__(self, data, transform=None):\n", 1245 | " # load the dataset\n", 1246 | " #self.matches = self._load_matches(path)\n", 1247 | " self.matches = data\n", 1248 | " # check whether the .pkl with the dictionaries exists\n", 1249 | "\n", 1250 | " # else:\n", 1251 | " # extract information from the dataframe\n",
 1252 | " self.champions = set(self.matches['championid'])\n", 1253 | " self.roles = set(self.matches['position-role'])\n", 1254 | " self.matches_id = list(set(self.matches['matchid']))\n", 1255 | " # Gather the item ids from all six item slots.\n", 1256 | " # (pd.Series.append returns a new Series instead of modifying in place,\n", 1257 | " # so chained .append calls would keep only the item1 column.)\n", 1258 | " item_cols = ['item1', 'item2', 'item3', 'item4', 'item5', 'item6']\n", 1259 | " self.items = pd.concat([self.matches[c] for c in item_cols])\n", 1260 | "\n", 1261 | " items = set(self.items)\n", 1262 | " self.items = {i for i in items if i != 0}\n", 1263 | " self.champion_types = champion_types\n", 1264 | " list_champion_types = []\n", 1265 | " for k,v in champion_types.items():\n", 1266 | " list_champion_types.extend(v)\n", 1267 | " \n", 1268 | " self.set_champ_type = set(list_champion_types)\n", 1269 | "\n", 1270 | " # build the token2id and id2token dictionaries\n", 1271 | " self.champions_token2id, self.champions_id2token = self._token_dict(self.champions)\n", 1272 | " self.roles_token2id, self.roles_id2token = self._token_dict(self.roles)\n", 1273 | " self.items_token2id, self.items_id2token = self._token_dict(self.items)\n", 1274 | " self.types_token2id, self.types_id2token = self._token_dict(self.set_champ_type)\n", 1275 | "\n", 1276 | " self.transform = transform\n", 1277 | "\n", 1278 | " def _load_matches(self, path):\n", 1279 | " data_matches = pd.read_csv(path) \n", 1280 | " return data_matches\n", 1281 | "\n", 1282 | " def _token_dict(self, data):\n", 1283 | " token2id = {}\n", 1284 | " id2token = {}\n", 1285 | " for i, j in enumerate(data):\n", 1286 | " token2id.update({j:i})\n", 1287 | " id2token.update({i:j})\n", 1288 | "\n", 1289 | " return token2id, id2token\n", 1290 | "\n", 1291 | " def _tokens2ids(self, token2id, tokens):\n", 1292 | " ids = []\n", 1293 | " for token in tokens:\n", 1294 | " ids.append(token2id[token])\n", 1295 | " \n", 1296 | " return ids\n", 1297 | "\n", 1298 | " def
_tokens2ids_items(self, token2id, tokens):\n", 1299 | " #items_vecs = []\n", 1300 | " item_vec = np.zeros((len(token2id)))\n", 1301 | " for token in tokens:\n", 1302 | " if token in token2id: \n", 1303 | " item_vec[token2id[token]] = 1\n", 1304 | " #items_vecs.append(item_vec)\n", 1305 | " \n", 1306 | " return item_vec\n", 1307 | "\n", 1308 | " def _build_dict(self, match):\n", 1309 | " # extract the match's champions in order\n", 1310 | " champion_tokens = list(match['championid'])\n", 1311 | " champions_ids = self._tokens2ids(self.champions_token2id, champion_tokens)\n", 1312 | "\n", 1313 | " # extract the match's items in order\n", 1314 | " #items_tokens = match['championid']\n", 1315 | " #items_ids = self._tokens2ids(self.items_token2id, items_tokens)\n", 1316 | " # extract the match's roles in order\n", 1317 | " role_tokens = list(match['position-role'])\n", 1318 | " role_ids = self._tokens2ids(self.roles_token2id, role_tokens)\n", 1319 | " list_win = list(match['win'])[4:6]\n", 1320 | " # win flags of players 4 and 5 (one per team); num_win is the index of the winning team (0 = first team)\n", 1321 | " list_win = np.array(list_win)\n", 1322 | " num_win = np.argsort(list_win)\n", 1323 | " num_win = num_win[len(num_win)-1]\n", 1324 | "\n", 1325 | " list_part_items = []\n", 1326 | " list_types = []\n", 1327 | " items_list = ['item1','item2','item3','item4','item5','item6']\n", 1328 | " for id_champ in champion_tokens:\n", 1329 | " champ_atr = match[match.championid == id_champ]\n", 1330 | " items = champ_atr[items_list]\n", 1331 | " items_tokens = list(items.iloc[0, :])\n", 1332 | " items_ids = self._tokens2ids_items(self.items_token2id, items_tokens)\n", 1333 | " list_part_items.append(items_ids)\n", 1334 | "\n", 1335 | " type_champ = self.champion_types[id_champ]\n", 1336 | " type_ids = self._tokens2ids(self.types_token2id, type_champ)\n", 1337 | " list_types.append(type_ids)\n", 1338 | "\n", 1339 | " # build five 0s and five 1s (team ids)\n", 1340 | " #team_ids = \n", 1341 | " x = {\n", 1342 | " 'champions':
torch.from_numpy(np.array(champions_ids)),\n", 1343 | " 'role': torch.from_numpy(np.array(role_ids)),\n", 1344 | " 'type': torch.from_numpy(np.array(list_types))\n", 1345 | " }\n", 1346 | " y= {\n", 1347 | " 'items': torch.from_numpy(np.array(list_part_items)),\n", 1348 | " 'win': torch.from_numpy(np.array(num_win))\n", 1349 | " }\n", 1350 | " \n", 1351 | " return x, y\n", 1352 | "\n", 1353 | " def __getitem__(self, idx): \n", 1354 | " # idx is a position in self.matches_id, not the match id itself\n", 1355 | " # the function returns the information of one match\n", 1356 | " # find the match in the dataframe and return the dictionaries with its attributes\n", 1357 | " id_match = self.matches_id[idx]\n", 1358 | " match = self.matches[(self.matches.matchid == id_match)]\n", 1359 | " x, y = self._build_dict(match) # build the input/target dicts from this match's rows\n", 1360 | " if self.transform:\n", 1361 | " sample = x, y\n", 1362 | " x, y = self.transform(sample)\n", 1363 | " return x, y # the item itself: the match with all its features\n", 1364 | "\n", 1365 | " def __len__(self):\n", 1366 | " return len(self.matches_id)\n" 1367 | ], 1368 | "execution_count": 16, 1369 | "outputs": [] 1370 | }, 1371 | { 1372 | "cell_type": "markdown", 1373 | "metadata": { 1374 | "id": "UIm1_KUCUNB0", 1375 | "colab_type": "text" 1376 | }, 1377 | "source": [ 1378 | "# Model" 1379 | ] 1380 | }, 1381 | { 1382 | "cell_type": "markdown", 1383 | "metadata": { 1384 | "id": "qr3TZbrnUg2H", 1385 | "colab_type": "text" 1386 | }, 1387 | "source": [ 1388 | "## Transformer encoder modified to obtain the attention weights" 1389 | ] 1390 | }, 1391 | { 1392 | "cell_type": "code", 1393 | "metadata": { 1394 | "colab_type": "code", 1395 | "id": "HQvpd9lxxQqI", 1396 | "colab": {} 1397 | }, 1398 | "source": [ 1399 | "class TransformerEncoder(nn.Module):\n", 1400 | " \"\"\"TransformerEncoder is a stack of N encoder layers\n", 1401 | "\n", 1402 | " Args:\n", 1403 | " encoder_layer: an instance of the TransformerEncoderLayer() 
class (required).\n", 1404 | " num_layers: the number of sub-encoder-layers in the encoder (required).\n", 1405 | " norm: the layer normalization component (optional).\n", 1406 | "\n", 1407 | " Examples::\n", 1408 | " >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)\n", 1409 | " >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)\n", 1410 | " >>> src = torch.rand(10, 32, 512)\n", 1411 | " >>> out, att_weights = transformer_encoder(src)\n", 1412 | " \"\"\"\n", 1413 | "\n", 1414 | " def __init__(self, encoder_layer, num_layers, norm=None):\n", 1415 | " super(TransformerEncoder, self).__init__()\n", 1416 | " self.layers = _get_clones(encoder_layer, num_layers)\n", 1417 | " self.num_layers = num_layers\n", 1418 | " self.norm = norm\n", 1419 | "\n", 1420 | " def forward(self, src, mask=None, src_key_padding_mask=None):\n", 1421 | " \"\"\"Pass the input through the encoder layers in turn.\n", 1422 | "\n", 1423 | " Args:\n", 1424 | " src: the sequence to the encoder (required).\n", 1425 | " mask: the mask for the src sequence (optional).\n", 1426 | " src_key_padding_mask: the mask for the src keys per batch (optional).\n", 1427 | "\n", 1428 | " Shape:\n", 1429 | " see the docs in Transformer class.\n", 1430 | " \"\"\"\n", 1431 | " output = src\n", 1432 | " att_weights = []\n", 1433 | "\n", 1434 | " for i in range(self.num_layers):\n", 1435 | " output, attn_output_weights = self.layers[i](output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)\n", 1436 | " \n", 1437 | " att_weights.append(attn_output_weights)\n", 1438 | "\n", 1439 | " if self.norm:\n", 1440 | " output = self.norm(output)\n", 1441 | "\n", 1442 | " return output, att_weights\n", 1443 | "\n" 1444 | ], 1445 | "execution_count": 17, 1446 | "outputs": [] 1447 | }, 1448 | { 1449 | "cell_type": "code", 1450 | "metadata": { 1451 | "id": "eh2Nz2CWDypX", 1452 | "colab_type": "code", 1453 | "colab": {} 1454 | }, 1455 | "source": [ 1456 | "def _get_clones(module, N):\n", 1457 |
" return nn.ModuleList([copy.deepcopy(module) for i in range(N)])\n", 1458 | "\n", 1459 | "\n", 1460 | "def _get_activation_fn(activation):\n", 1461 | " if activation == \"relu\":\n", 1462 | " return F.relu\n", 1463 | " elif activation == \"gelu\":\n", 1464 | " return F.gelu\n", 1465 | " else:\n", 1466 | " raise RuntimeError(\"activation should be relu/gelu, not %s.\" % activation)" 1467 | ], 1468 | "execution_count": 18, 1469 | "outputs": [] 1470 | }, 1471 | { 1472 | "cell_type": "code", 1473 | "metadata": { 1474 | "id": "zJCX_mqojgwQ", 1475 | "colab_type": "code", 1476 | "colab": {} 1477 | }, 1478 | "source": [ 1479 | "class TransformerEncoderLayer(nn.Module):\n", 1480 | " \"\"\"TransformerEncoderLayer is made up of self-attn and feedforward network.\n", 1481 | " This standard encoder layer is based on the paper \"Attention Is All You Need\".\n", 1482 | " Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,\n", 1483 | " Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in\n", 1484 | " Neural Information Processing Systems, pages 6000-6010. 
 Users may modify or implement\n", 1485 | " in a different way during application.\n", 1486 | "\n", 1487 | " Args:\n", 1488 | " d_model: the number of expected features in the input (required).\n", 1489 | " nhead: the number of heads in the multi-head attention model (required).\n", 1490 | " dim_feedforward: the dimension of the feedforward network model (default=2048).\n", 1491 | " dropout: the dropout value (default=0.1).\n", 1492 | " activation: the activation function of the intermediate layer, relu or gelu (default=relu).\n", 1493 | "\n", 1494 | " Examples::\n", 1495 | " >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)\n", 1496 | " >>> src = torch.rand(10, 32, 512)\n", 1497 | " >>> out, attn_weights = encoder_layer(src)\n", 1498 | " \"\"\"\n", 1499 | "\n", 1500 | " def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=\"relu\"):\n", 1501 | " super(TransformerEncoderLayer, self).__init__()\n", 1502 | " self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n", 1503 | " # Implementation of Feedforward model\n", 1504 | " self.linear1 = nn.Linear(d_model, dim_feedforward)\n", 1505 | " self.dropout = nn.Dropout(dropout)\n", 1506 | " self.linear2 = nn.Linear(dim_feedforward, d_model)\n", 1507 | "\n", 1508 | " self.norm1 = nn.LayerNorm(d_model)\n", 1509 | " self.norm2 = nn.LayerNorm(d_model)\n", 1510 | " self.dropout1 = nn.Dropout(dropout)\n", 1511 | " self.dropout2 = nn.Dropout(dropout)\n", 1512 | "\n", 1513 | " self.activation = _get_activation_fn(activation)\n", 1514 | "\n", 1515 | " def forward(self, src, src_mask=None, src_key_padding_mask=None):\n", 1516 | " \"\"\"Pass the input through the encoder layer.\n", 1517 | "\n", 1518 | " Args:\n", 1519 | " src: the sequence to the encoder layer (required).\n", 1520 | " src_mask: the mask for the src sequence (optional).\n", 1521 | " src_key_padding_mask: the mask for the src keys per batch (optional).\n", 1522 | "\n", 1523 | " Shape:\n", 1524 | " see the docs in Transformer 
class.\n", 1525 | " \"\"\"\n", 1526 | " src2, attn_output_weights = self.self_attn(src, src, src, attn_mask=src_mask,\n", 1527 | " key_padding_mask=src_key_padding_mask)\n", 1528 | " src = src + self.dropout1(src2)\n", 1529 | " src = self.norm1(src)\n", 1530 | " if hasattr(self, \"activation\"):\n", 1531 | " src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))\n", 1532 | " else: # for backward compatibility\n", 1533 | " src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))\n", 1534 | " src = src + self.dropout2(src2)\n", 1535 | " src = self.norm2(src)\n", 1536 | " return src, attn_output_weights" 1537 | ], 1538 | "execution_count": 19, 1539 | "outputs": [] 1540 | }, 1541 | { 1542 | "cell_type": "markdown", 1543 | "metadata": { 1544 | "id": "pwRy106QU6sH", 1545 | "colab_type": "text" 1546 | }, 1547 | "source": [ 1548 | "## Auxiliary Task Classes" 1549 | ] 1550 | }, 1551 | { 1552 | "cell_type": "code", 1553 | "metadata": { 1554 | "id": "RNyvw9ynsLq3", 1555 | "colab_type": "code", 1556 | "colab": {} 1557 | }, 1558 | "source": [ 1559 | "def getItems(gt_items, table_emb, num_items, emb_dim):\n", 1560 | " list_match = []\n", 1561 | " device = gt_items.device\n", 1562 | " for i in range(gt_items.size(0)):\n", 1563 | " match = gt_items[i,:,:]\n", 1564 | " list_part_item = []\n", 1565 | " for j in range(gt_items.size(1)):\n", 1566 | " participant_items = match[j,:]\n", 1567 | " sum_k = torch.sum(participant_items, dim = 0).item()\n", 1568 | " if int(sum_k) > 0:\n", 1569 | " _, pos_items = torch.topk(participant_items, k = int(sum_k), dim = 0)\n", 1570 | " items_emb = table_emb(pos_items)\n", 1571 | " items_emb = torch.mean(items_emb, dim = 0)\n", 1572 | " list_part_item.append(items_emb)\n", 1573 | " else:\n", 1574 | " list_part_item.append(torch.zeros(emb_dim).to(device))\n", 1575 | " team_item_emb = torch.stack(list_part_item)\n", 1576 | " list_match.append(team_item_emb)\n", 1577 | " return torch.stack(list_match)" 1578 | ], 1579 | 
"execution_count": 20, 1580 | "outputs": [] 1581 | }, 1582 | { 1583 | "cell_type": "code", 1584 | "metadata": { 1585 | "id": "KahjODMTdPS4", 1586 | "colab_type": "code", 1587 | "colab": {} 1588 | }, 1589 | "source": [ 1590 | "class WinEncoder(nn.Module):\n", 1591 | " def __init__(self, model_dim, n_items):\n", 1592 | " super(WinEncoder, self).__init__()\n", 1593 | " self.proj_win = nn.Linear(4*model_dim, 2)\n", 1594 | " self.embeddings_table_items = nn.Embedding(num_embeddings = n_items, embedding_dim = model_dim)\n", 1595 | " self.n_items = n_items\n", 1596 | " self.model_dim = model_dim\n", 1597 | " self.init_weights()\n", 1598 | "\n", 1599 | " def init_weights(self):\n", 1600 | " initrange = 0.1\n", 1601 | " self.proj_win.bias.data.zero_()\n", 1602 | " self.proj_win.weight.data.uniform_(-initrange, initrange)\n", 1603 | " self.embeddings_table_items.weight.data.uniform_(-initrange, initrange)\n", 1604 | "\n", 1605 | " def forward(self, att_match, item_list):\n", 1606 | " # att_match size (Batch, Seq, Emb)\n", 1607 | " # item_list size (Batch, Seq, Num_items): multi-hot item indicators per participant\n", 1608 | " att_item_team_1, att_item_team_2 = torch.chunk(att_match, 2, dim=1)\n", 1609 | " items_team_1, items_team_2 = torch.chunk(item_list, 2, dim=1)\n", 1610 | "\n", 1611 | " items_team_1 = getItems(items_team_1, self.embeddings_table_items, self.n_items,self.model_dim)\n", 1612 | " items_team_2 = getItems(items_team_2, self.embeddings_table_items, self.n_items,self.model_dim)\n", 1613 | "\n", 1614 | " att_item_team_1 = torch.mean(att_item_team_1, dim=1)\n", 1615 | " att_item_team_1 = F.relu(att_item_team_1)\n", 1616 | " att_item_team_1 = (att_item_team_1 / att_item_team_1.max())\n", 1617 | " items_team_1 = torch.mean(items_team_1, dim=1)\n", 1618 | " items_team_1 = F.relu(items_team_1)\n", 1619 | " items_team_1 = (items_team_1 / items_team_1.max())\n", 1620 | "\n", 1621 | " att_item_team_2 = torch.mean(att_item_team_2, dim=1)\n", 1622 | " att_item_team_2 = F.relu(att_item_team_2)\n", 1623 |
att_item_team_2 = (att_item_team_2 / att_item_team_2.max())\n", 1624 | " items_team_2 = torch.mean(items_team_2, dim=1)\n", 1625 | " items_team_2 = F.relu(items_team_2)\n", 1626 | " items_team_2 = (items_team_2 / items_team_2.max())\n", 1627 | "\n", 1628 | " att_item_team_1 = torch.cat((att_item_team_1, items_team_1), 1)\n", 1629 | " att_item_team_2 = torch.cat((att_item_team_2, items_team_2), 1)\n", 1630 | " proj_win_team = torch.cat((att_item_team_1, att_item_team_2), 1)\n", 1631 | " win_emb = self.proj_win(F.relu(proj_win_team))\n", 1632 | "\n", 1633 | " return win_emb" 1634 | ], 1635 | "execution_count": 21, 1636 | "outputs": [] 1637 | }, 1638 | { 1639 | "cell_type": "code", 1640 | "metadata": { 1641 | "id": "pdzByFr9ZyB7", 1642 | "colab_type": "code", 1643 | "colab": {} 1644 | }, 1645 | "source": [ 1646 | "def getTensorPredItem(items_logits):\n", 1647 | " pred_items = torch.zeros(items_logits.size())\n", 1648 | " for i in range(items_logits.size(0)):\n", 1649 | " for j in range(items_logits.size(1)):\n", 1650 | " _,pos_items = torch.topk(items_logits[i,j,:],k = 6,dim=0)\n", 1651 | " pred_items[i,j,pos_items] = 1\n", 1652 | " return pred_items" 1653 | ], 1654 | "execution_count": 22, 1655 | "outputs": [] 1656 | }, 1657 | { 1658 | "cell_type": "markdown", 1659 | "metadata": { 1660 | "id": "9AlU_u42VG8A", 1661 | "colab_type": "text" 1662 | }, 1663 | "source": [ 1664 | "## Main Class of the proposed model" 1665 | ] 1666 | }, 1667 | { 1668 | "cell_type": "code", 1669 | "metadata": { 1670 | "id": "xY-TRDX2BHAr", 1671 | "colab_type": "code", 1672 | "colab": {} 1673 | }, 1674 | "source": [ 1675 | "class TransformerLolRecommender(nn.Module):\n", 1676 | "\n", 1677 | " def __init__(self, n_role, n_champions, embeddings_size, nhead, n_items, n_type, nlayers = 1, nhid = 2048, dropout=0.5, aux_task = False, \n", 1678 | " learnable_team_emb = False):\n", 1679 | " super(TransformerLolRecommender, self).__init__()\n", 1680 | "\n", 1681 | " self.device = torch.device('cuda') if 
torch.cuda.is_available() else torch.device('cpu')\n", 1682 | " \n", 1683 | " self.embeddings_table_role = nn.Embedding(num_embeddings = n_role, embedding_dim = embeddings_size)\n", 1684 | " \n", 1685 | " self.embeddings_table_champion = nn.Embedding(num_embeddings = n_champions, embedding_dim = embeddings_size)\n", 1686 | "\n", 1687 | " self.embeddings_table_type = nn.Embedding(num_embeddings = n_type, embedding_dim = embeddings_size, padding_idx=0)\n", 1688 | " \n", 1689 | " self.learnable_team_emb = learnable_team_emb\n", 1690 | " if learnable_team_emb:\n", 1691 | " self.team_encoder = nn.Embedding(num_embeddings = 2, embedding_dim = embeddings_size)\n", 1692 | " else:\n", 1693 | " self.team_encoder = self.get_team_encoding(embeddings_size, 10)\n", 1694 | " \n", 1695 | " encoder_layers = TransformerEncoderLayer(embeddings_size, nhead, nhid, dropout)\n", 1696 | " self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)\n", 1697 | " \n", 1698 | " self.recommender = nn.Linear(embeddings_size, n_items)\n", 1699 | " self.pred_champ = nn.Linear(embeddings_size, n_champions)\n", 1700 | "\n", 1701 | " self.aux_task = aux_task\n", 1702 | "\n", 1703 | " if self.aux_task: \n", 1704 | " self.win_encoder = WinEncoder(embeddings_size, n_items)\n", 1705 | "\n", 1706 | " self.init_weights()\n", 1707 | " \n", 1708 | " def get_learnable_team_emb(self, num_batch):\n", 1709 | " emb_team_0 = self.team_encoder(torch.LongTensor([0]).to(self.device))\n", 1710 | " emb_team_0 = emb_team_0.expand(5, emb_team_0.size(1))\n", 1711 | " emb_team_1 = self.team_encoder(torch.LongTensor([1]).to(self.device))\n", 1712 | " emb_team_1 = emb_team_1.expand(5, emb_team_1.size(1))\n", 1713 | " emb_team = torch.cat([emb_team_0, emb_team_1], dim = 0)\n", 1714 | " emb_team = emb_team.unsqueeze(0).expand(num_batch, emb_team.size(0), emb_team.size(1))\n", 1715 | " return emb_team\n", 1716 | "\n", 1717 | " \n", 1718 | " def get_team_encoding(self, embedding_dim, num_champions = 10):\n", 1719 | " 
team_encoding = torch.zeros(num_champions, embedding_dim)\n", 1720 | " team_encoding[5:,:] = 1\n", 1721 | " return team_encoding.to(self.device)\n", 1722 | "\n", 1723 | " def init_weights(self):\n", 1724 | " initrange = 0.1\n", 1725 | " \n", 1726 | " self.embeddings_table_role.weight.data.uniform_(-initrange, initrange)\n", 1727 | " self.embeddings_table_champion.weight.data.uniform_(-initrange, initrange)\n", 1728 | " self.embeddings_table_type.weight.data.uniform_(-initrange, initrange)\n", 1729 | " \n", 1730 | " self.recommender.bias.data.zero_()\n", 1731 | " self.recommender.weight.data.uniform_(-initrange, initrange)\n", 1732 | "\n", 1733 | " self.pred_champ.bias.data.zero_()\n", 1734 | " self.pred_champ.weight.data.uniform_(-initrange, initrange)\n", 1735 | "\n", 1736 | " if self.learnable_team_emb:\n", 1737 | " self.team_encoder.weight.data.uniform_(-initrange, initrange)\n", 1738 | "\n", 1739 | " def forward(self, role, champion_id, types, items, win, enable_teacher_f):\n", 1740 | "\n", 1741 | " role_participants = self.embeddings_table_role(role)\n", 1742 | " id_participants = self.embeddings_table_champion(champion_id)\n", 1743 | " type_champ = self.embeddings_table_type(types)\n", 1744 | " type_champ = torch.sum(type_champ, dim =2)\n", 1745 | " batch_size = role_participants.size(0)\n", 1746 | " if self.learnable_team_emb:\n", 1747 | " team_participants = self.get_learnable_team_emb(batch_size)\n", 1748 | " else:\n", 1749 | " size_team_emb = self.team_encoder.size()\n", 1750 | " team_participants = self.team_encoder.unsqueeze(0).expand(batch_size, size_team_emb[0], size_team_emb[1])\n", 1751 | "\n", 1752 | " sel_champions = []\n", 1753 | " pos_champions = []\n", 1754 | " for i in range(win.size(0)):\n", 1755 | " id_el = random.randint(0,4)\n", 1756 | " pos_champions.append(id_el)\n", 1757 | " if win[i] != 0:\n", 1758 | " id_el = id_el + 5\n", 1759 | " sel_champion = champion_id[i,id_el]\n", 1760 | " id_participants[i,id_el,:] = 0\n", 1761 | " 
sel_champions.append(sel_champion)\n", 1762 | "\n", 1763 | " sel_champions = torch.stack(sel_champions)\n", 1764 | " # pos_champions = torch.stack(pos_champions)\n", 1765 | "\n", 1766 | " participants = role_participants + id_participants + team_participants + type_champ\n", 1767 | " # size (Seq, Batch, Emb)\n", 1768 | " participants = participants.permute(1,0,2)\n", 1769 | " # size (Seq, Batch, Emb)\n", 1770 | " output, att_weights = self.transformer_encoder(participants)\n", 1771 | " # size (Batch, Seq, Emb)\n", 1772 | " output = output.permute(1,0,2)\n", 1773 | " logits_items = self.recommender(output)\n", 1774 | "\n", 1775 | " output_obj = {\n", 1776 | " 'logits_items': logits_items,\n", 1777 | " 'att_weights': att_weights,\n", 1778 | " 'outputs': output,\n", 1779 | " 'sel_champions': sel_champions,\n", 1780 | " 'pos_champions': pos_champions\n", 1781 | " }\n", 1782 | "\n", 1783 | " if self.aux_task:\n", 1784 | " if enable_teacher_f: \n", 1785 | " items_used = items\n", 1786 | " else:\n", 1787 | " items_used = getTensorPredItem(logits_items).to(self.device)\n", 1788 | " logits_win = self.win_encoder(output, items_used)\n", 1789 | " output_obj['logits_win'] = logits_win\n", 1790 | "\n", 1791 | " return output_obj" 1792 | ], 1793 | "execution_count": 23, 1794 | "outputs": [] 1795 | }, 1796 | { 1797 | "cell_type": "markdown", 1798 | "metadata": { 1799 | "id": "rwYoKWcsVqex", 1800 | "colab_type": "text" 1801 | }, 1802 | "source": [ 1803 | "# Logger and Checkpointer" 1804 | ] 1805 | }, 1806 | { 1807 | "cell_type": "markdown", 1808 | "metadata": { 1809 | "id": "nrXJc5reVwyM", 1810 | "colab_type": "text" 1811 | }, 1812 | "source": [ 1813 | "These classes and methods log relevant information about the model and its metrics to Comet. They also save a checkpoint at each epoch.
" 1814 | ] 1815 | }, 1816 | { 1817 | "cell_type": "code", 1818 | "metadata": { 1819 | "id": "D0JLqdBEV67X", 1820 | "colab_type": "code", 1821 | "colab": {} 1822 | }, 1823 | "source": [ 1824 | "def load_defaults(defaults_file):\n", 1825 | " return OmegaConf.load(defaults_file)\n", 1826 | "\n", 1827 | "\n", 1828 | "def load_config_file(config_file):\n", 1829 | " if not config_file:\n", 1830 | " return OmegaConf.create()\n", 1831 | " return OmegaConf.load(config_file)\n", 1832 | "\n", 1833 | "\n", 1834 | "def load_config(config_file, defaults_file):\n", 1835 | " defaults = load_defaults(defaults_file)\n", 1836 | " config = OmegaConf.merge(defaults, load_config_file(config_file))\n", 1837 | " config.merge_with_cli()\n", 1838 | " return config\n", 1839 | "\n", 1840 | "\n", 1841 | "def build_config(args):\n", 1842 | " return load_config(args.config_file, args.defaults_file)\n", 1843 | "\n", 1844 | "\n", 1845 | "def config_to_dict(cfg):\n", 1846 | " return dict(cfg)\n", 1847 | "\n", 1848 | "\n", 1849 | "def config_to_comet(cfg):\n", 1850 | " def _config_to_comet(cfg, local_dict, parent_str):\n", 1851 | " for key, value in cfg.items():\n", 1852 | " full_key = \"{}.{}\".format(parent_str, key)\n", 1853 | " if isinstance(value, (dict, DictConfig)):\n", 1854 | " _config_to_comet(value, local_dict, full_key)\n", 1855 | " else:\n", 1856 | " local_dict[full_key] = value\n", 1857 | "\n", 1858 | " local_dict = {}\n", 1859 | " for key, value in cfg.items():\n", 1860 | " if isinstance(value, (dict, DictConfig)):\n", 1861 | " _config_to_comet(value, local_dict, key)\n", 1862 | " else:\n", 1863 | " local_dict[key] = value\n", 1864 | " return local_dict" 1865 | ], 1866 | "execution_count": 24, 1867 | "outputs": [] 1868 | }, 1869 | { 1870 | "cell_type": "code", 1871 | "metadata": { 1872 | "id": "RWTpFcpaU50q", 1873 | "colab_type": "code", 1874 | "colab": {} 1875 | }, 1876 | "source": [ 1877 | "def get_checkpointer(save_path, metric_name='val_acc'):\n", 1878 | " if not 
os.path.exists(save_path):\n", 1879 | " os.makedirs(save_path)\n", 1880 | " return ModelCheckpoint(\n", 1881 | " filepath=save_path,\n", 1882 | " verbose=True,\n", 1883 | " monitor=metric_name,\n", 1884 | " mode='max',\n", 1885 | " )\n", 1886 | "\n", 1887 | "\n", 1888 | "# class CometLogger(LightningLoggerBase):\n", 1889 | "# # Thank you @ceyzaguirre4\n", 1890 | "# def __init__(self, config, *args, **kwargs):\n", 1891 | "# super().__init__()\n", 1892 | "# self.comet_exp = CometExperiment(*args, **kwargs)\n", 1893 | "# self.comet_exp.set_name(config['exp_name'])\n", 1894 | "# self.comet_exp.log_parameters(config)\n", 1895 | "# self.config = config\n", 1896 | "\n", 1897 | "# @rank_zero_only\n", 1898 | "# def log_hyperparams(self, params):\n", 1899 | "# self.comet_exp.log_parameters(config_to_comet(params))\n", 1900 | "\n", 1901 | "# @rank_zero_only\n", 1902 | "# def log_metrics(self, metrics, step):\n", 1903 | "# self.comet_exp.log_metrics(metrics)\n", 1904 | "\n", 1905 | "# @rank_zero_only\n", 1906 | "# def finalize(self, status):\n", 1907 | "# self.comet_exp.end()\n", 1908 | " \n", 1909 | "# def version(self):\n", 1910 | "# return self.config['exp']\n" 1911 | ], 1912 | "execution_count": 25, 1913 | "outputs": [] 1914 | }, 1915 | { 1916 | "cell_type": "markdown", 1917 | "metadata": { 1918 | "id": "5ktMqAUMWeEz", 1919 | "colab_type": "text" 1920 | }, 1921 | "source": [ 1922 | "# Metrics" 1923 | ] 1924 | }, 1925 | { 1926 | "cell_type": "code", 1927 | "metadata": { 1928 | "id": "-TMnKThtdtb4", 1929 | "colab_type": "code", 1930 | "colab": {} 1931 | }, 1932 | "source": [ 1933 | "def recall_at_k(output, target, k = 6):\n", 1934 | " output_k, ind_k = torch.topk(output, k, dim = 1)\n", 1935 | " sum_recall = 0\n", 1936 | " num_part = output_k.size(0)\n", 1937 | " relevants = target.sum(dim = 1)\n", 1938 | " list_recall = []\n", 1939 | " for i in range(num_part):\n", 1940 | " target_k = target[i, ind_k[i,:]]\n", 1941 | " intersection = target_k.sum(dim = 0)\n", 1942 | " 
recall_n = intersection/relevants[i]\n", 1943 | " list_recall.append(recall_n)\n", 1944 | " sum_recall+=recall_n\n", 1945 | " \n", 1946 | " recall_avg = sum_recall/num_part\n", 1947 | " return recall_avg, num_part, list_recall\n" 1948 | ], 1949 | "execution_count": 26, 1950 | "outputs": [] 1951 | }, 1952 | { 1953 | "cell_type": "code", 1954 | "metadata": { 1955 | "id": "1xRNZaAJ_RZ8", 1956 | "colab_type": "code", 1957 | "colab": {} 1958 | }, 1959 | "source": [ 1960 | "def precision_at_k(r, k):\n", 1961 | " \"\"\"Score is precision @ k\n", 1962 | "\n", 1963 | " Relevance is binary (nonzero is relevant).\n", 1964 | "\n", 1965 | " >>> r = [0, 0, 1]\n", 1966 | " >>> precision_at_k(r, 1)\n", 1967 | " 0.0\n", 1968 | " >>> precision_at_k(r, 2)\n", 1969 | " 0.0\n", 1970 | " >>> precision_at_k(r, 3)\n", 1971 | " 0.33333333333333331\n", 1972 | " >>> precision_at_k(r, 4)\n", 1973 | " Traceback (most recent call last):\n", 1974 | " File \"<stdin>\", line 1, in ?\n", 1975 | " ValueError: Relevance score length < k\n", 1976 | "\n", 1977 | "\n", 1978 | " Args:\n", 1979 | " r: Relevance scores (list or numpy) in rank order\n", 1980 | " (first element is the first item)\n", 1981 | "\n", 1982 | " Returns:\n", 1983 | " Precision @ k\n", 1984 | "\n", 1985 | " Raises:\n", 1986 | " ValueError: len(r) must be >= k\n", 1987 | " \"\"\"\n", 1988 | " assert k >= 1\n", 1989 | " r = np.asarray(r)[:k] != 0\n", 1990 | " if r.size != k:\n", 1991 | " raise ValueError('Relevance score length < k')\n", 1992 | " return np.mean(r)\n", 1993 | "\n", 1994 | "\n", 1995 | "def average_precision(r):\n", 1996 | " \"\"\"Score is average precision (area under PR curve)\n", 1997 | "\n", 1998 | " Relevance is binary (nonzero is relevant).\n", 1999 | "\n", 2000 | " >>> r = [1, 1, 0, 1, 0, 1, 0, 0, 0, 1]\n", 2001 | " >>> delta_r = 1. / sum(r)\n", 2002 | " >>> sum([sum(r[:x + 1]) / (x + 1.)
* delta_r for x, y in enumerate(r) if y])\n", 2003 | " 0.7833333333333333\n", 2004 | " >>> average_precision(r)\n", 2005 | " 0.78333333333333333\n", 2006 | "\n", 2007 | " Args:\n", 2008 | " r: Relevance scores (list or numpy) in rank order\n", 2009 | " (first element is the first item)\n", 2010 | "\n", 2011 | " Returns:\n", 2012 | " Average precision\n", 2013 | " \"\"\"\n", 2014 | " r = np.asarray(r) != 0\n", 2015 | " out = [precision_at_k(r, k + 1) for k in range(r.size) if r[k]]\n", 2016 | " if not out:\n", 2017 | " return 0.\n", 2018 | " return np.mean(out)" 2019 | ], 2020 | "execution_count": 27, 2021 | "outputs": [] 2022 | }, 2023 | { 2024 | "cell_type": "code", 2025 | "metadata": { 2026 | "id": "HfusZ9qS_uSq", 2027 | "colab_type": "code", 2028 | "colab": {} 2029 | }, 2030 | "source": [ 2031 | "def map_at(output, target, k=6):\n", 2032 | " sum_ap = 0\n", 2033 | " num_part = output.size(0)\n", 2034 | " list_map = []\n", 2035 | " for i in range(num_part):\n", 2036 | " out_p = output[i,:]\n", 2037 | " target_p = target[i,:]\n", 2038 | " output_k, ind_k = torch.topk(out_p, k, dim = 0)\n", 2039 | " list_rel = target_p[ind_k].tolist()\n", 2040 | " ap_at = average_precision(list_rel)\n", 2041 | " list_map.append(ap_at) \n", 2042 | " sum_ap += ap_at\n", 2043 | " return sum_ap/num_part, list_map" 2044 | ], 2045 | "execution_count": 28, 2046 | "outputs": [] 2047 | }, 2048 | { 2049 | "cell_type": "code", 2050 | "metadata": { 2051 | "id": "CctZXsremks4", 2052 | "colab_type": "code", 2053 | "colab": {} 2054 | }, 2055 | "source": [ 2056 | "def calc_precision_multiclass(output, target, k = 6):\n", 2057 | " output_k, ind_k = torch.topk(output, k, dim = 1)\n", 2058 | " sum_prec = 0\n", 2059 | " num_part = output_k.size(0)\n", 2060 | " list_prec = []\n", 2061 | " for i in range(num_part):\n", 2062 | " target_k = target[i, ind_k[i,:]]\n", 2063 | " intersection = target_k.sum(dim = 0)\n", 2064 | " preci_n = intersection/k\n", 2065 | " list_prec.append(preci_n)\n", 2066 | " 
sum_prec+=preci_n\n", 2067 | " \n", 2068 | " prec_avg = sum_prec/num_part\n", 2069 | " return prec_avg, num_part, list_prec" 2070 | ], 2071 | "execution_count": 29, 2072 | "outputs": [] 2073 | }, 2074 | { 2075 | "cell_type": "code", 2076 | "metadata": { 2077 | "id": "mykasAMS0bnG", 2078 | "colab_type": "code", 2079 | "colab": {} 2080 | }, 2081 | "source": [ 2082 | "def f1_score(recall, precision):\n", 2083 | " f1 = 2 * ((precision * recall) / (precision + recall)) if (precision + recall) > 0 else 0 * precision  # avoid NaN (0/0) when both are zero\n", 2084 | " return f1" 2085 | ], 2086 | "execution_count": 30, 2087 | "outputs": [] 2088 | }, 2089 | { 2090 | "cell_type": "code", 2091 | "metadata": { 2092 | "id": "1zMdgUJslL4V", 2093 | "colab_type": "code", 2094 | "colab": {} 2095 | }, 2096 | "source": [ 2097 | "class AverageMeter(object):\n", 2098 | " \"\"\"Computes and stores the average and current value\n", 2099 | " Taken from PyTorch's examples.imagenet.main\n", 2100 | " \"\"\"\n", 2101 | " def __init__(self):\n", 2102 | " self.reset()\n", 2103 | "\n", 2104 | " def reset(self):\n", 2105 | " self.val = 0\n", 2106 | " self.avg = 0\n", 2107 | " self.sum = 0\n", 2108 | " self.count = 0\n", 2109 | "\n", 2110 | " def update(self, val, n=1):\n", 2111 | " self.val = val\n", 2112 | " self.sum += val * n\n", 2113 | " self.count += n\n", 2114 | " self.avg = self.sum / self.count" 2115 | ], 2116 | "execution_count": 31, 2117 | "outputs": [] 2118 | }, 2119 | { 2120 | "cell_type": "code", 2121 | "metadata": { 2122 | "id": "x01LspaO_A7f", 2123 | "colab_type": "code", 2124 | "colab": {} 2125 | }, 2126 | "source": [ 2127 | "def set_seed(seed, slow=False):\n", 2128 | " import random\n", 2129 | "\n", 2130 | " if torch.cuda.is_available():\n", 2131 | " torch.cuda.manual_seed(seed)\n", 2132 | "\n", 2133 | " torch.manual_seed(seed)\n", 2134 | " random.seed(seed)\n", 2135 | " np.random.seed(seed)\n", 2136 | "\n", 2137 | " if slow:\n", 2138 | " torch.backends.cudnn.deterministic = True\n", 2139 | " torch.backends.cudnn.benchmark = False" 2140 | ], 2141 |
"execution_count": 32, 2142 | "outputs": [] 2143 | }, 2144 | { 2145 | "cell_type": "code", 2146 | "metadata": { 2147 | "id": "2URNBE4I7x7j", 2148 | "colab_type": "code", 2149 | "colab": {} 2150 | }, 2151 | "source": [ 2152 | "def get_winners(att_vec, gt_item, win_vec, pos_champions, outputs_log):\n", 2153 | " list_att = []\n", 2154 | " list_gt = []\n", 2155 | " list_cham = []\n", 2156 | " for i in range(att_vec.size(0)):\n", 2157 | " win = win_vec[i]\n", 2158 | " pos = pos_champions[i]\n", 2159 | " if win == 0:\n", 2160 | " a = list(range(0,5))\n", 2161 | " del a[pos]\n", 2162 | " att_vec_match = att_vec[i,a,:]\n", 2163 | " gt_match = gt_item[i,a, :]\n", 2164 | " list_cham.append(outputs_log[i,pos, :])\n", 2165 | " list_att.append(att_vec_match)\n", 2166 | " list_gt.append(gt_match) \n", 2167 | " else:\n", 2168 | " a = list(range(5,10))\n", 2169 | " del a[pos]\n", 2170 | " att_vec_match = att_vec[i,a,:]\n", 2171 | " gt_match = gt_item[i, a, :]\n", 2172 | " list_cham.append(outputs_log[i,pos + 5, :])\n", 2173 | " list_att.append(att_vec_match)\n", 2174 | " list_gt.append(gt_match)\n", 2175 | "\n", 2176 | " att_winners = torch.stack(list_att, dim=0)\n", 2177 | " gt_winners = torch.stack(list_gt, dim=0)\n", 2178 | " att_cham = torch.stack(list_cham, dim=0)\n", 2179 | " return att_winners, gt_winners, att_cham\n", 2180 | " \n", 2181 | "\n" 2182 | ], 2183 | "execution_count": 33, 2184 | "outputs": [] 2185 | }, 2186 | { 2187 | "cell_type": "code", 2188 | "metadata": { 2189 | "id": "X3xn5Zd30i3C", 2190 | "colab_type": "code", 2191 | "colab": {} 2192 | }, 2193 | "source": [ 2194 | "def save_att_weights(list_att, path_save_att):\n", 2195 | " with open(path_save_att, 'wb') as handle:\n", 2196 | " pickle.dump(list_att, handle)" 2197 | ], 2198 | "execution_count": 34, 2199 | "outputs": [] 2200 | }, 2201 | { 2202 | "cell_type": "markdown", 2203 | "metadata": { 2204 | "id": "WDA0GHysW4vX", 2205 | "colab_type": "text" 2206 | }, 2207 | "source": [ 2208 | "# Training and evaluation 
loop" 2209 | ] 2210 | }, 2211 | { 2212 | "cell_type": "markdown", 2213 | "metadata": { 2214 | "id": "TVeG7SOJW80S", 2215 | "colab_type": "text" 2216 | }, 2217 | "source": [ 2218 | "The training and evaluation loops are based on [PyTorch Lightning](https://github.com/williamFalcon/pytorch-lightning)." 2219 | ] 2220 | }, 2221 | { 2222 | "cell_type": "code", 2223 | "metadata": { 2224 | "id": "OOJROUqzSqJU", 2225 | "colab_type": "code", 2226 | "colab": {} 2227 | }, 2228 | "source": [ 2229 | "import argparse" 2230 | ], 2231 | "execution_count": 35, 2232 | "outputs": [] 2233 | }, 2234 | { 2235 | "cell_type": "code", 2236 | "metadata": { 2237 | "id": "mR77pKXdXDDF", 2238 | "colab_type": "code", 2239 | "colab": {} 2240 | }, 2241 | "source": [ 2242 | "class Struct:\n", 2243 | " def __init__(self, **entries):\n", 2244 | " self.__dict__.update(entries)\n", 2245 | " #self.elems = entries.items()\n", 2246 | " \n", 2247 | " def items(self):\n", 2248 | " return self.__dict__.items()" 2249 | ], 2250 | "execution_count": 36, 2251 | "outputs": [] 2252 | }, 2253 | { 2254 | "cell_type": "code", 2255 | "metadata": { 2256 | "id": "RxtiaNMYc6iw", 2257 | "colab_type": "code", 2258 | "colab": {} 2259 | }, 2260 | "source": [ 2261 | "class LolRecAttModel(pl.LightningModule):\n", 2262 | "\n", 2263 | " def __init__(self, cfg):\n", 2264 | " super(LolRecAttModel, self).__init__()\n", 2265 | " \n", 2266 | " if type(cfg) is argparse.Namespace:\n", 2267 | " cfg = vars(cfg)\n", 2268 | " self.conf = cfg\n", 2269 | " self.hparams = cfg\n", 2270 | " self.index_split = self.conf['index_split']\n", 2271 | " self.optim = self.conf['optim']\n", 2272 | " set_seed(seed = self.conf['seed'])\n", 2273 | " train_dataset = self.train_dataset()\n", 2274 | " self.batch_size = self.conf['batch_size']\n", 2275 | " self.iter_max_train = len(train_dataset)//self.batch_size\n", 2276 | " num_roles = len(train_dataset.roles)\n", 2277 | " num_champions = len(train_dataset.champions)\n", 2278 | " n_items =
len(train_dataset.items)\n", 2279 | " n_types = len(train_dataset.set_champ_type)\n", 2280 | " self.model = TransformerLolRecommender(n_role=num_roles, n_champions=num_champions, embeddings_size=self.conf['embeddings_size'], nhead=self.conf['nhead'], n_items=n_items, n_type=n_types,\n", 2281 | " nlayers = self.conf['nlayers'], nhid = self.conf['nhid'], dropout=self.conf['dropout'], aux_task = self.conf['win_task'], \n", 2282 | " learnable_team_emb = self.conf['learnable_team_emb'])\n", 2283 | " self.loss = nn.BCEWithLogitsLoss()\n", 2284 | " self.loss_aux = nn.CrossEntropyLoss()\n", 2285 | " self.train_loss = AverageMeter()\n", 2286 | " self.train_prec = AverageMeter()\n", 2287 | " self.iter_epoch = 0\n", 2288 | " isExist = os.path.exists(path_save) \n", 2289 | " if isExist:\n", 2290 | " dirs = os.listdir(path_save)\n", 2291 | " self.iter_epoch = len(dirs)\n", 2292 | "\n", 2293 | " self.aux_task = self.conf['win_task']\n", 2294 | " \n", 2295 | " if self.aux_task:\n", 2296 | " self.second_loss = nn.CrossEntropyLoss()\n", 2297 | " self.train_acc_win = AverageMeter()\n", 2298 | " self.train_main_loss = AverageMeter()\n", 2299 | " self.train_win_loss = AverageMeter()\n", 2300 | " self.alpha = self.conf['alpha']\n", 2301 | " self.beta = self.conf['beta']\n", 2302 | " self.epoch_to_win = self.conf['init_epoch']\n", 2303 | "\n", 2304 | " def check_epoch(self, num_iter):\n", 2305 | " if num_iter == 0:\n", 2306 | " self.train_loss = AverageMeter()\n", 2307 | " self.train_prec = AverageMeter()\n", 2308 | " if self.aux_task:\n", 2309 | " self.train_acc_win = AverageMeter()\n", 2310 | " self.train_main_loss = AverageMeter()\n", 2311 | " self.train_win_loss = AverageMeter()\n", 2312 | " self.iter_epoch+=1\n", 2313 | "\n", 2314 | " def custom_print(self, batch, loss, start_time, prec, acc = 0, log_interval = 100, loss_win =0, epoch=1):\n", 2315 | " if batch % log_interval == 0:\n", 2316 | " elapsed = time.time() - start_time\n", 2317 | " elapsed = elapsed*log_interval if batch > 
0 else elapsed\n", 2318 | " if self.aux_task and self.iter_epoch >= self.epoch_to_win:\n", 2319 | " print('| epoch {:3d} | {:5d}/{:5d} batches | '\n", 2320 | " 'ms/batch {:5.2f} | '\n", 2321 | " 'loss {:5.6f} | loss win {:5.6f} | precision {:5.6f} | Accuracy (win) {:5.6f}'.format(\n", 2322 | " self.iter_epoch, batch, self.iter_max_train,\n", 2323 | " elapsed, loss, loss_win, prec, acc))\n", 2324 | " else:\n", 2325 | " print('| epoch {:3d} | {:5d}/{:5d} batches | '\n", 2326 | " 'ms/batch {:5.2f} | '\n", 2327 | " 'loss {:5.6f} | precision {:5.6f}'.format(\n", 2328 | " self.iter_epoch, batch, self.iter_max_train,\n", 2329 | " elapsed, loss, prec)) \n", 2330 | "\n", 2331 | " def forward(self, x, items, win, teacher_forcing):\n", 2332 | " role = x['role']\n", 2333 | " champions = x['champions']\n", 2334 | " types = x['type']\n", 2335 | " out = self.model(role, champions, types, items, win, teacher_forcing)\n", 2336 | " return out\n", 2337 | " #return torch.relu(self.l1(x.view(x.size(0), -1)))\n", 2338 | "\n", 2339 | " def training_step(self, batch, batch_nb):\n", 2340 | " # REQUIRED\n", 2341 | " self.check_epoch(batch_nb)\n", 2342 | " start_time = time.time()\n", 2343 | " x, y = batch\n", 2344 | " if len(x['role'].size()) == 3:\n", 2345 | " x['role'] = x['role'].reshape(x['role'].size(0)*x['role'].size(1), x['role'].size(2))\n", 2346 | " x['champions'] = x['champions'].reshape(x['champions'].size(0)*x['champions'].size(1), x['champions'].size(2))\n", 2347 | " x['type'] = x['type'].reshape(x['type'].size(0)*x['type'].size(1), x['type'].size(2), x['type'].size(3))\n", 2348 | " y['items'] = y['items'].reshape(y['items'].size(0)*y['items'].size(1), y['items'].size(2), y['items'].size(3))\n", 2349 | " y['win'] = y['win'].reshape(y['win'].size(0)*y['win'].size(1))\n", 2350 | " y_hat = self.forward(x, y['items'], y['win'], self.conf['teacher_forcing'])\n", 2351 | " \n", 2352 | " # Main task\n", 2353 | " logits_items = y_hat['logits_items']\n", 2354 | " gt_items =
y['items']\n", 2355 | " sel_champions = y_hat['sel_champions']\n", 2356 | " pos_champions = y_hat['pos_champions']\n", 2357 | " outputs_log = y_hat['outputs']\n", 2358 | " logits_items, gt_items, att_cham = get_winners(logits_items, gt_items, y['win'], pos_champions, outputs_log)\n", 2359 | " \n", 2360 | " out = logits_items.reshape(logits_items.size(0)*logits_items.size(1), logits_items.size(2))\n", 2361 | " out_aux = self.model.pred_champ(att_cham)\n", 2362 | "\n", 2363 | " gt = gt_items.reshape(gt_items.size(0)*gt_items.size(1), gt_items.size(2))\n", 2364 | " loss = self.loss(out, gt)\n", 2365 | " loss_aux = self.loss_aux(out_aux, sel_champions)\n", 2366 | "\n", 2367 | " prec, num, _ = calc_precision_multiclass(out, gt, k=6)\n", 2368 | " self.train_prec.update(prec, num)\n", 2369 | "\n", 2370 | " tensor_avg_prec = torch.tensor([self.train_prec.avg], device=loss.device)\n", 2371 | " tensorboard_logs = {'train_loss': loss, 'train_loss_aux': loss_aux, 'train_prec_avg': tensor_avg_prec}\n", 2372 | "\n", 2373 | " if self.aux_task and self.iter_epoch >= self.epoch_to_win:\n", 2374 | "\n", 2375 | " #Second Task\n", 2376 | " out_win = y_hat['logits_win']\n", 2377 | " \n", 2378 | " gt_win = y['win'].reshape(-1)\n", 2379 | "\n", 2380 | " _, preds_win = torch.max(out_win, 1)\n", 2381 | "\n", 2382 | " loss_win = self.second_loss(out_win, gt_win)\n", 2383 | " loss_total = self.alpha*loss + self.beta*loss_win\n", 2384 | " self.train_loss.update(self.alpha*loss.item(), out.size(0))\n", 2385 | " self.train_loss.update(self.beta*loss_win.item(), out_win.size(0))\n", 2386 | "\n", 2387 | " train_acc = torch.sum(preds_win == gt_win).item()/out_win.size(0)\n", 2388 | " self.train_acc_win.update(train_acc, out_win.size(0))\n", 2389 | " self.train_main_loss.update(loss.item(), out.size(0))\n", 2390 | " self.train_win_loss.update(loss_win.item(), out_win.size(0))\n", 2391 | "\n", 2392 | " tensor_avg_acc = torch.tensor([self.train_acc_win.avg], device=loss.device)\n", 2393 | " 
tensorboard_logs['train_acc_win_avg'] = tensor_avg_acc\n", 2394 | " tensorboard_logs['train_win_loss'] = loss_win\n", 2395 | " tensorboard_logs['train_main_loss'] = loss\n", 2396 | " tensorboard_logs['train_win_loss_avg'] = torch.tensor([self.train_win_loss.avg], device=loss.device)\n", 2397 | " tensorboard_logs['train_main_loss_avg'] = torch.tensor([self.train_main_loss.avg], device=loss.device)\n", 2398 | " tensorboard_logs['train_loss'] = loss_total\n", 2399 | "\n", 2400 | " # self.custom_print(batch_nb, self.train_main_loss.avg, start_time, self.train_prec.avg, self.train_acc_win.avg, 100, self.train_win_loss.avg)\n", 2401 | " else:\n", 2402 | " loss_total = loss + 0.2*loss_aux\n", 2403 | " self.train_loss.update(loss.item(), out.size(0))\n", 2404 | " tensorboard_logs['total_loss_train'] = loss_total\n", 2405 | " # self.custom_print(batch_nb, self.train_loss.avg, start_time, self.train_prec.avg, 0, 100)\n", 2406 | " \n", 2407 | " tensor_avg_loss = torch.tensor([self.train_loss.avg], device=loss.device)\n", 2408 | " tensorboard_logs['train_loss_avg'] = tensor_avg_loss\n", 2409 | " return {'loss': loss_total, 'progress_bar': tensorboard_logs, 'avg_loss': tensor_avg_loss, 'avg_prec':tensor_avg_prec ,'log': tensorboard_logs}\n", 2410 | "\n", 2411 | " def validation_step(self, batch, batch_nb):\n", 2412 | " x, y = batch\n", 2413 | " y_hat = self.forward(x, y['items'], y['win'], False)\n", 2414 | " att_weights = y_hat['att_weights']\n", 2415 | "\n", 2416 | " #Main Task\n", 2417 | " logits_items = y_hat['logits_items']\n", 2418 | " gt_items = y['items']\n", 2419 | " sel_champions = y_hat['sel_champions']\n", 2420 | " pos_champions = y_hat['pos_champions']\n", 2421 | " outputs_log = y_hat['outputs']\n", 2422 | "\n", 2423 | " logits_items, gt_items, att_cham = get_winners(logits_items, gt_items, y['win'], pos_champions, outputs_log)\n", 2424 | " out = logits_items.reshape(logits_items.size(0)*logits_items.size(1), logits_items.size(2))\n", 2425 | " out_aux = 
self.model.pred_champ(att_cham)\n", 2426 | " \n", 2427 | " gt = gt_items.reshape(gt_items.size(0)*gt_items.size(1), gt_items.size(2))\n", 2428 | "\n", 2429 | " loss = self.loss(out, gt)\n", 2430 | " loss_aux = self.loss_aux(out_aux, sel_champions)\n", 2431 | "\n", 2432 | " prec, num, list_prec = calc_precision_multiclass(out, gt, k=6)\n", 2433 | " prec1, num, list_prec1 = calc_precision_multiclass(out, gt, k=1)\n", 2434 | " prec3, num, list_prec3 = calc_precision_multiclass(out, gt, k=3)\n", 2435 | "\n", 2436 | " recall1, num, list_recall1 = recall_at_k(out, gt, k=1)\n", 2437 | " recall3, num, list_recall3 = recall_at_k(out, gt, k=3)\n", 2438 | " recall6, num, list_recall6 = recall_at_k(out, gt, k=6)\n", 2439 | "\n", 2440 | " f11 = f1_score(recall1, prec1)\n", 2441 | " f13 = f1_score(recall3, prec3) \n", 2442 | " f16 = f1_score(recall6, prec)\n", 2443 | "\n", 2444 | " map6, list_map6 = map_at(out, gt, k=6)\n", 2445 | " map1, list_map1 = map_at(out, gt, k=1)\n", 2446 | " map3, list_map3 = map_at(out, gt, k=3)\n", 2447 | "\n", 2448 | " obj_list = {\n", 2449 | " 'list_prec1': list_prec1,\n", 2450 | " 'list_prec3': list_prec3,\n", 2451 | " 'list_prec': list_prec,\n", 2452 | " 'list_recall1': list_recall1,\n", 2453 | " 'list_recall3': list_recall3,\n", 2454 | " 'list_recall6': list_recall6,\n", 2455 | " 'list_map1': list_map1,\n", 2456 | " 'list_map3': list_map3,\n", 2457 | " 'list_map6': list_map6\n", 2458 | " }\n", 2459 | " obj_res = {'val_loss': loss, 'val_loss_aux': loss_aux, 'val_prec': prec, 'num_batch': out.size(0), 'num':num, 'map6': map6, \n", 2460 | " 'map1': map1, 'map3': map3, 'val_prec1': prec1, 'val_prec3': prec3, 'val_recall1': recall1, \n", 2461 | " 'val_recall3': recall3, 'val_recall6': recall6, 'val_f1_1': f11, 'val_f1_3': f13, 'val_f1_6': f16, \n", 2462 | " 'att_weights': att_weights, 'logits_items': logits_items, 'obj_list': obj_list}\n", 2463 | "\n", 2464 | " #Second Task\n", 2465 | " if self.aux_task and self.iter_epoch >= self.epoch_to_win:\n", 
2466 | " out_win = y_hat['logits_win']\n", 2467 | " \n", 2468 | " gt_win = y['win'].reshape(-1)\n", 2469 | " _, preds_win = torch.max(out_win, 1)\n", 2470 | "\n", 2471 | " loss_win = self.second_loss(out_win, gt_win)\n", 2472 | "\n", 2473 | " acc_win = torch.sum(preds_win == gt_win).item()/out_win.size(0)\n", 2474 | " obj_res['val_acc'] = acc_win\n", 2475 | " obj_res['val_loss_win'] = loss_win\n", 2476 | " obj_res['val_main_loss'] = loss\n", 2477 | " obj_res['num_batch_acc'] = out_win.size(0)\n", 2478 | "\n", 2479 | " return obj_res\n", 2480 | "\n", 2481 | " def validation_epoch_end(self, outputs):\n", 2482 | " avg_loss = AverageMeter()\n", 2483 | " avg_loss_aux = AverageMeter()\n", 2484 | " avg_prec = AverageMeter()\n", 2485 | " avg_prec1 = AverageMeter()\n", 2486 | " avg_prec3 = AverageMeter()\n", 2487 | "\n", 2488 | " avg_recall1 = AverageMeter()\n", 2489 | " avg_recall3 = AverageMeter()\n", 2490 | " avg_recall6 = AverageMeter()\n", 2491 | "\n", 2492 | " avg_f1_1 = AverageMeter()\n", 2493 | " avg_f1_3 = AverageMeter()\n", 2494 | " avg_f1_6 = AverageMeter()\n", 2495 | "\n", 2496 | " avg_map = AverageMeter()\n", 2497 | " avg_map1 = AverageMeter()\n", 2498 | " avg_map3 = AverageMeter()\n", 2499 | "\n", 2500 | " list_att_weights = []\n", 2501 | " list_logits_items = []\n", 2502 | "\n", 2503 | " list_prec1 = []\n", 2504 | " list_prec3 = []\n", 2505 | " list_prec6 = []\n", 2506 | "\n", 2507 | " list_recall1 = []\n", 2508 | " list_recall3 = []\n", 2509 | " list_recall6 = []\n", 2510 | "\n", 2511 | " list_map1 = []\n", 2512 | " list_map3 = []\n", 2513 | " list_map6 = []\n", 2514 | "\n", 2515 | " if self.aux_task and self.iter_epoch >= self.epoch_to_win:\n", 2516 | " avg_main_loss = AverageMeter()\n", 2517 | " avg_win_loss = AverageMeter()\n", 2518 | " avg_acc = AverageMeter()\n", 2519 | "\n", 2520 | " device = None\n", 2521 | " for x in outputs:\n", 2522 | "\n", 2523 | " avg_prec.update(x['val_prec'], x['num'])\n", 2524 | " avg_prec1.update(x['val_prec1'], x['num'])\n", 
2525 | " avg_prec3.update(x['val_prec3'], x['num'])\n", 2526 | "\n", 2527 | " avg_recall1.update(x['val_recall1'], x['num'])\n", 2528 | " avg_recall3.update(x['val_recall3'], x['num'])\n", 2529 | " avg_recall6.update(x['val_recall6'], x['num'])\n", 2530 | "\n", 2531 | " avg_f1_1.update(x['val_f1_1'], x['num'])\n", 2532 | " avg_f1_3.update(x['val_f1_3'], x['num'])\n", 2533 | " avg_f1_6.update(x['val_f1_6'], x['num'])\n", 2534 | "\n", 2535 | " avg_map.update(x['map6'], x['num_batch'])\n", 2536 | " avg_map1.update(x['map1'], x['num_batch'])\n", 2537 | " avg_map3.update(x['map3'], x['num_batch'])\n", 2538 | "\n", 2539 | " list_att_weights.append(x['att_weights'])\n", 2540 | " list_logits_items.append(x['logits_items'])\n", 2541 | "\n", 2542 | " list_prec1.extend(x['obj_list']['list_prec1'])\n", 2543 | " list_prec3.extend(x['obj_list']['list_prec3'])\n", 2544 | " list_prec6.extend(x['obj_list']['list_prec'])\n", 2545 | "\n", 2546 | " list_recall1.extend(x['obj_list']['list_recall1'])\n", 2547 | " list_recall3.extend(x['obj_list']['list_recall3'])\n", 2548 | " list_recall6.extend(x['obj_list']['list_recall6'])\n", 2549 | "\n", 2550 | " list_map1.extend(x['obj_list']['list_map1'])\n", 2551 | " list_map3.extend(x['obj_list']['list_map3'])\n", 2552 | " list_map6.extend(x['obj_list']['list_map6'])\n", 2553 | "\n", 2554 | " device = x['val_loss'].device\n", 2555 | "\n", 2556 | " if self.aux_task and self.iter_epoch >= self.epoch_to_win:\n", 2557 | " avg_main_loss.update(x['val_main_loss'], x['num_batch'])\n", 2558 | " avg_win_loss.update(x['val_loss_win'], x['num_batch_acc'])\n", 2559 | " avg_acc.update(x['val_acc'], x['num_batch_acc'])\n", 2560 | "\n", 2561 | " avg_loss.update(self.alpha*x['val_main_loss'], x['num_batch'])\n", 2562 | " avg_loss.update(self.beta*x['val_loss_win'], x['num_batch_acc'])\n", 2563 | " else:\n", 2564 | " avg_loss.update(x['val_loss'], x['num_batch'])\n", 2565 | " avg_loss_aux.update(x['val_loss_aux'], x['num_batch'])\n", 2566 | "\n", 2567 | " 
tensorboard_logs = {'val_loss': torch.tensor([avg_loss.avg], device=device), 'val_prec': torch.tensor([avg_prec.avg], device=device), \n", 2568 | " 'val_map6': torch.tensor([avg_map.avg], device=device), 'val_map1': torch.tensor([avg_map1.avg], device=device),\n", 2569 | " 'val_map3': torch.tensor([avg_map3.avg], device=device), 'val_prec1': torch.tensor([avg_prec1.avg], device=device), \n", 2570 | " 'val_prec3': torch.tensor([avg_prec3.avg], device=device), 'val_recall1': torch.tensor([avg_recall1.avg], device=device),\n", 2571 | " 'val_recall3': torch.tensor([avg_recall3.avg], device=device), 'val_recall6': torch.tensor([avg_recall6.avg], device=device),\n", 2572 | " 'val_f1_1': torch.tensor([avg_f1_1.avg], device=device), 'val_f1_3': torch.tensor([avg_f1_3.avg], device=device),\n", 2573 | " 'val_f1_6': torch.tensor([avg_f1_6.avg], device=device)}\n", 2574 | "\n", 2575 | " if self.aux_task and self.iter_epoch >= self.epoch_to_win:\n", 2576 | " tensorboard_logs['val_main_loss'] = torch.tensor([avg_main_loss.avg], device=device)\n", 2577 | " tensorboard_logs['val_win_loss'] = torch.tensor([avg_win_loss.avg], device=device)\n", 2578 | " tensorboard_logs['val_win_acc'] = torch.tensor([avg_acc.avg], device=device)\n", 2579 | " print('| loss_val {:5.6f} | main_loss_val {:5.6f} | win_loss_val {:5.6f} | precision_val {:5.6f} | map6_val {:5.6f} | acc_val {:5.6f}'.format(avg_loss.avg, avg_main_loss.avg, \n", 2580 | " avg_win_loss.avg, avg_prec.avg, \n", 2581 | " avg_map.avg, avg_acc.avg))\n", 2582 | " # else:\n", 2583 | " # print('| loss_val {:5.6f} | precision1_val {:5.6f} | precision3_val {:5.6f} | precision6_val {:5.6f} | map1_val {:5.6f} | map3_val {:5.6f} | map6_val {:5.6f} | recall1 {:5.6f} | recall3 {:5.6f} | recall6 {:5.6f} | f1_1 {:5.6f} | f1_3 {:5.6f} | f1_6 {:5.6f}'.format(\n", 2584 | " # avg_loss.avg, avg_prec1.avg, avg_prec3.avg, avg_prec.avg, avg_map1.avg, avg_map3.avg, avg_map.avg, avg_recall1.avg, avg_recall3.avg, avg_recall6.avg, avg_f1_1.avg, 
avg_f1_3.avg, avg_f1_6.avg))\n", 2585 | " \n", 2586 | " path_save_att = path_save_att_format.format(str(self.conf['index_split']), str(self.conf['exp']), str(self.iter_epoch))\n", 2587 | " path_save_list_metrics = path_save_list_metrics_format.format(str(self.conf['index_split']), str(self.conf['exp']), str(self.iter_epoch))\n", 2588 | " weights_items = {\n", 2589 | " 'list_att_weights': list_att_weights,\n", 2590 | " 'list_logits_items': list_logits_items\n", 2591 | " }\n", 2592 | "\n", 2593 | " list_metrics = {\n", 2594 | " 'list_prec1': list_prec1, \n", 2595 | " 'list_prec3': list_prec3,\n", 2596 | " 'list_prec6': list_prec6,\n", 2597 | " 'list_recall1': list_recall1,\n", 2598 | " 'list_recall3': list_recall3,\n", 2599 | " 'list_recall6': list_recall6,\n", 2600 | " 'list_map1': list_map1,\n", 2601 | " 'list_map3': list_map3,\n", 2602 | " 'list_map6': list_map6\n", 2603 | " }\n", 2604 | " save_att_weights(weights_items, path_save_att)\n", 2605 | " save_att_weights(list_metrics, path_save_list_metrics)\n", 2606 | " return {'avg_val_loss': avg_loss.avg, 'avg_val_prec': avg_prec.avg, 'val_map6': avg_map.avg,'progress_bar': tensorboard_logs,'log': tensorboard_logs}\n", 2607 | "\n", 2608 | " def test_step(self, batch, batch_idx):\n", 2609 | " # OPTIONAL\n", 2610 | " return self.validation_step(batch, batch_idx)\n", 2611 | "\n", 2612 | " def test_epoch_end(self, outputs):\n", 2613 | " # Reuse the validation aggregation (the hook is validation_epoch_end, not the deprecated validation_end)\n", 2614 | " return self.validation_epoch_end(outputs)\n", 2615 | "\n", 2616 | " def configure_optimizers(self):\n", 2617 | " # REQUIRED\n", 2618 | " # can return multiple optimizers and learning_rate schedulers\n", 2619 | " # (LBFGS is automatically supported, no need for closure function)\n", 2620 | " if self.optim == 'adabound':\n", 2621 | " optimizer = adabound.AdaBound(self.model.parameters(), lr=1e-3, final_lr=0.1)\n", 2622 | " else:\n", 2623 | " optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)\n", 2624 | " return optimizer\n", 2625 | " \n", 2626 | " def 
train_dataset(self):\n", 2627 | "\n", 2628 | " data = get_partition(self.index_split, list_trainset)\n", 2629 | " composed = transforms.Compose([RandomSort_Part(),\n", 2630 | " RandomSort_Team()])\n", 2631 | " train_dataset = LolDataset(data, transform=composed)\n", 2632 | " return train_dataset\n", 2633 | "\n", 2634 | " @pl.data_loader\n", 2635 | " def train_dataloader(self):\n", 2636 | " \n", 2637 | " train_dataset = self.train_dataset()\n", 2638 | " return DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)\n", 2639 | "\n", 2640 | " @pl.data_loader\n", 2641 | " def val_dataloader(self):\n", 2642 | " #data = list_testset[self.index_split]\n", 2643 | " data = get_partition(self.index_split, list_testset)\n", 2644 | " val_dataset = LolDataset(data)\n", 2645 | " return DataLoader(val_dataset, batch_size=self.batch_size)\n", 2646 | " \n", 2647 | " @pl.data_loader\n", 2648 | " def test_dataloader(self):\n", 2649 | " # OPTIONAL\n", 2650 | " return self.val_dataloader()" 2651 | ], 2652 | "execution_count": 37, 2653 | "outputs": [] 2654 | }, 2655 | { 2656 | "cell_type": "markdown", 2657 | "metadata": { 2658 | "id": "CyRfaqN8XvYi", 2659 | "colab_type": "text" 2660 | }, 2661 | "source": [ 2662 | "# Config file" 2663 | ] 2664 | }, 2665 | { 2666 | "cell_type": "markdown", 2667 | "metadata": { 2668 | "id": "4Aa7VmxeYF7X", 2669 | "colab_type": "text" 2670 | }, 2671 | "source": [ 2672 | "This config sets the following hyperparameters:\n", 2673 | "\n", 2674 | "1. index_split - index of the data partition used for training and evaluation.\n", 2675 | "2. optim - optimizer ('adam' or 'adabound').\n", 2676 | "3. batch_size - batch size.\n", 2677 | "4. embeddings_size - model (embedding) dimension.\n", 2678 | "5. nhead - number of attention heads.\n", 2679 | "6. nlayers - number of encoder layers.\n", 2680 | "7. exp - experiment number.\n", 2681 | "8. epochs - number of training epochs.\n", 2682 | "9. exp_name - experiment name in comet.ml.\n", 2683 | "10. alpha, beta - weights of the main (item) loss and the auxiliary (win) loss.\n", 2684 | "11. 
win_task - enables the auxiliary (win prediction) task.\n", 2685 | "12. learnable_team_emb - if True, the team embedding is learnable;\n", 2686 | "otherwise it is static.\n", 2687 | "13. teacher_forcing - enables teacher forcing for the auxiliary task.\n", 2688 | "14. init_epoch - the epoch at which the auxiliary task starts.\n" 2689 | ] 2690 | }, 2691 | { 2692 | "cell_type": "code", 2693 | "metadata": { 2694 | "id": "dqm4Zi0_ha8W", 2695 | "colab_type": "code", 2696 | "colab": {} 2697 | }, 2698 | "source": [ 2699 | "conf = {\n", 2700 | " 'index_split': 0,\n", 2701 | " 'optim': 'adam',\n", 2702 | " 'seed': 1642,\n", 2703 | " 'batch_size': 100,\n", 2704 | " 'embeddings_size': 512,\n", 2705 | " 'nhead': 2,\n", 2706 | " 'nlayers': 1, \n", 2707 | " 'nhid': 2048, \n", 2708 | " 'dropout': 0.5,\n", 2709 | " 'exp': 13,\n", 2710 | " 'epochs': 10,\n", 2711 | " 'exp_name': 'Main_tasks_rec_only_winners_final_prueba',\n", 2712 | " 'win_task': False,\n", 2713 | " 'alpha': 1,\n", 2714 | " 'beta': 1,\n", 2715 | " 'learnable_team_emb': True,\n", 2716 | " 'teacher_forcing': False,\n", 2717 | " 'init_epoch': 2\n", 2718 | "}" 2719 | ], 2720 | "execution_count": 38, 2721 | "outputs": [] 2722 | }, 2723 | { 2724 | "cell_type": "markdown", 2725 | "metadata": { 2726 | "id": "IVtKoVTcYDS1", 2727 | "colab_type": "text" 2728 | }, 2729 | "source": [ 2730 | "# Training and evaluation executor" 2731 | ] 2732 | }, 2733 | { 2734 | "cell_type": "code", 2735 | "metadata": { 2736 | "id": "EIBfoXNe1TOh", 2737 | "colab_type": "code", 2738 | "colab": { 2739 | "base_uri": "https://localhost:8080/", 2740 | "height": 833, 2741 | "referenced_widgets": [ 2742 | "a705d3b71b5d4fb587b3bb1fb38161fa", 2743 | "abfceea38af9444b8da09122eb0c867d", 2744 | "0dfe065d9c0c468389f75dd74a9a11ac", 2745 | "741dabfff47941f3b290d4ad4cb6be12", 2746 | "503e7ca9191948a4bcf6cc64a8862820", 2747 | "249092a9ce0940218d291fe75cf35f3e", 2748 | "828e6e2c56f3456cb016bf7c8b701ba8", 2749 | "7540c343bf9f461a84c157f84d529cd9", 2750 | 
"c43e576730a940c28fa78d49e95e7165", 2751 | "b3d8f9b86d0d47ae8c51b8f2eb202aab", 2752 | "82170aabdef246edbf668bb1cdf4a5e3", 2753 | "43a58bf6d9ab454095f8f5f30f10cdca", 2754 | "f80b50d773b240a091c6c3bdf7961924", 2755 | "f8a352088b8d4ed896bf8a206ecc024e", 2756 | "2e151daf92e84c369ee90e8ded7a24f2", 2757 | "5b0b2240357743c4b8285ce6017638c4" 2758 | ] 2759 | }, 2760 | "outputId": "3c6e7509-69b7-4562-e993-c1012e1c94dc" 2761 | }, 2762 | "source": [ 2763 | "from pytorch_lightning import Trainer\n", 2764 | "\n", 2765 | "path_save = '/content/gdrive/My Drive/Proyecto_RecSys/split/{}/exp_recsys/{}/checkpoints/'.format(str(conf['index_split']), str(conf['exp']))\n", 2766 | "path_save_att_format = '/content/gdrive/My Drive/Proyecto_RecSys/split/{}/exp_recsys/{}/checkpoints/att_weights_{}.pkl'\n", 2767 | "path_save_list_metrics_format = '/content/gdrive/My Drive/Proyecto_RecSys/split/{}/exp_recsys/{}/checkpoints/list_metrics_{}.pkl'\n", 2768 | "\n", 2769 | "model = LolRecAttModel(conf)\n", 2770 | "\n", 2771 | "checkpoint_callback = get_checkpointer(path_save,'avg_val_prec')\n", 2772 | "\n", 2773 | "\n", 2774 | "comet_logger = CometLogger(\n", 2775 | " experiment_name=conf['exp_name'],\n", 2776 | " api_key = 'YOUR_KEY',\n", 2777 | " project_name=\"YOUR_PROJECT_NAME\",\n", 2778 | " workspace = 'YOUR_WORKSPACE'\n", 2779 | ")\n", 2780 | "trainer = Trainer(\n", 2781 | " gpus=[0],\n", 2782 | " distributed_backend='dp',\n", 2783 | " logger=comet_logger,\n", 2784 | " max_epochs=conf['epochs'],\n", 2785 | " checkpoint_callback=checkpoint_callback,\n", 2786 | " show_progress_bar=False,\n", 2787 | " gradient_clip_val=0.5\n", 2788 | ")\n", 2789 | "\n", 2790 | "trainer.fit(model) " 2791 | ], 2792 | "execution_count": null, 2793 | "outputs": [ 2794 | { 2795 | "output_type": "stream", 2796 | "text": [ 2797 | "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/utilities/distributed.py:22: UserWarning: Checkpoint directory /content/gdrive/My 
Drive/Proyecto_RecSys/split/0/exp_recsys/13/checkpoints/ exists and is not empty with save_top_k != 0.All files in this directory will be deleted when a checkpoint is saved!\n", 2798 | " warnings.warn(*args, **kwargs)\n", 2799 | "CometLogger will be initialized in online mode\n", 2800 | "COMET INFO: ----------------------------\n", 2801 | "COMET INFO: Comet.ml Experiment Summary:\n", 2802 | "COMET INFO: Data:\n", 2803 | "COMET INFO: url: https://www.comet.ml/afvilla/lolnet/5094e7cd80244d62bac9be446cbfeb0b\n", 2804 | "COMET INFO: Metrics [count] (min, max):\n", 2805 | "COMET INFO: sys.cpu.percent.01 [4] : (1.0, 12.3)\n", 2806 | "COMET INFO: sys.cpu.percent.02 [4] : (1.0, 12.9)\n", 2807 | "COMET INFO: sys.cpu.percent.03 [4] : (0.9, 12.5)\n", 2808 | "COMET INFO: sys.cpu.percent.04 [4] : (1.0, 12.9)\n", 2809 | "COMET INFO: sys.cpu.percent.avg [4] : (0.975, 12.65)\n", 2810 | "COMET INFO: sys.gpu.0.free_memory [4] : (17061249024.0, 17061249024.0)\n", 2811 | "COMET INFO: sys.gpu.0.gpu_utilization [4]: (0.0, 0.0)\n", 2812 | "COMET INFO: sys.gpu.0.total_memory : (17071734784.0, 17071734784.0)\n", 2813 | "COMET INFO: sys.gpu.0.used_memory [4] : (10485760.0, 10485760.0)\n", 2814 | "COMET INFO: sys.ram.total [4] : (27393740800.0, 27393740800.0)\n", 2815 | "COMET INFO: sys.ram.used [4] : (7792205824.0, 7797219328.0)\n", 2816 | "COMET INFO: Other [count]:\n", 2817 | "COMET INFO: Name: Main_tasks_rec_only_winners_final_prueba\n", 2818 | "COMET INFO: ----------------------------\n", 2819 | "COMET INFO: old comet version (3.0.2) detected. 
current: 3.1.14 please update your comet lib with command: `pip install --no-cache-dir --upgrade comet_ml`\n", 2820 | "COMET INFO: Experiment is live on comet.ml https://www.comet.ml/afvilla/lolnet/8dcde95206ec45daac4cc6657844b03d\n", 2821 | "\n", 2822 | "GPU available: True, used: True\n", 2823 | "TPU available: False, using: 0 TPU cores\n", 2824 | "CUDA_VISIBLE_DEVICES: [0]\n", 2825 | "\n", 2826 | " | Name | Type | Params\n", 2827 | "-------------------------------------------------------\n", 2828 | "0 | model | TransformerLolRecommender | 3 M \n", 2829 | "1 | loss | BCEWithLogitsLoss | 0 \n", 2830 | "2 | loss_aux | CrossEntropyLoss | 0 \n", 2831 | "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/utilities/distributed.py:22: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 4 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", 2832 | " warnings.warn(*args, **kwargs)\n" 2833 | ], 2834 | "name": "stderr" 2835 | }, 2836 | { 2837 | "output_type": "display_data", 2838 | "data": { 2839 | "application/vnd.jupyter.widget-view+json": { 2840 | "model_id": "a705d3b71b5d4fb587b3bb1fb38161fa", 2841 | "version_minor": 0, 2842 | "version_major": 2 2843 | }, 2844 | "text/plain": [ 2845 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…" 2846 | ] 2847 | }, 2848 | "metadata": { 2849 | "tags": [] 2850 | } 2851 | }, 2852 | { 2853 | "output_type": "stream", 2854 | "text": [ 2855 | "/usr/local/lib/python3.6/dist-packages/torch/storage.py:34: FutureWarning: pickle support for Storage will be removed in 1.5. Use `torch.save` instead\n", 2856 | " warnings.warn(\"pickle support for Storage will be removed in 1.5. 
Use `torch.save` instead\", FutureWarning)\n" 2857 | ], 2858 | "name": "stderr" 2859 | }, 2860 | { 2861 | "output_type": "stream", 2862 | "text": [ 2863 | "\r" 2864 | ], 2865 | "name": "stdout" 2866 | }, 2867 | { 2868 | "output_type": "stream", 2869 | "text": [ 2870 | "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/utilities/distributed.py:22: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 4 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", 2871 | " warnings.warn(*args, **kwargs)\n" 2872 | ], 2873 | "name": "stderr" 2874 | }, 2875 | { 2876 | "output_type": "display_data", 2877 | "data": { 2878 | "application/vnd.jupyter.widget-view+json": { 2879 | "model_id": "c43e576730a940c28fa78d49e95e7165", 2880 | "version_minor": 0, 2881 | "version_major": 2 2882 | }, 2883 | "text/plain": [ 2884 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…" 2885 | ] 2886 | }, 2887 | "metadata": { 2888 | "tags": [] 2889 | } 2890 | } 2891 | ] 2892 | } 2893 | ] 2894 | } 2895 | --------------------------------------------------------------------------------
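In `validation_step` the notebook computes precision@k, recall@k, MAP@k and F1 for k = 1, 3 and 6. `validation_epoch_end` then folds each batch's value into an `AverageMeter`, weighting it by the batch's sample count (`update(value, n)`). The helpers `calc_precision_multiclass`, `recall_at_k`, `map_at`, `f1_score` and `AverageMeter` are defined in earlier cells that are not shown here. The sketch below is therefore only a minimal pure-Python illustration that assumes the standard definitions of precision@k and recall@k over multi-hot item targets. `RunningAverage`, `top_k_indices`, `precision_recall_at_k`, `f1_from` and `batch_means` are hypothetical names, not the notebook's API, and MAP@k is not covered.

```python
from typing import List, Sequence, Tuple


class RunningAverage:
    """Count-weighted running mean.

    Hypothetical stand-in for the notebook's AverageMeter; only the
    update(value, n) / .avg usage seen in validation_epoch_end is assumed.
    """

    def __init__(self) -> None:
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value: float, n: int = 1) -> None:
        # `value` is a mean over n samples, so it contributes value * n to the total.
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count if self.count else 0.0


def top_k_indices(scores: Sequence[float], k: int) -> List[int]:
    """Indices of the k highest scores; ties go to the lower index."""
    return sorted(range(len(scores)), key=lambda i: (-scores[i], i))[:k]


def precision_recall_at_k(
    scores: Sequence[float], relevant: Sequence[int], k: int
) -> Tuple[float, float]:
    """Standard precision@k and recall@k for one sample with multi-hot targets."""
    hits = sum(relevant[i] for i in top_k_indices(scores, k))
    n_relevant = sum(relevant)
    precision = hits / k
    recall = hits / n_relevant if n_relevant else 0.0
    return precision, recall


def f1_from(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (0 when both are 0)."""
    total = precision + recall
    return 2 * precision * recall / total if total else 0.0


def batch_means(batch_scores, batch_relevant, k: int) -> Tuple[float, float, int]:
    """Mean precision@k and recall@k over one batch, plus the batch size."""
    pairs = [precision_recall_at_k(s, r, k) for s, r in zip(batch_scores, batch_relevant)]
    n = len(pairs)
    return sum(p for p, _ in pairs) / n, sum(r for _, r in pairs) / n, n


# Two toy validation batches (sizes 2 and 1): per-item logits and multi-hot targets.
batches = [
    ([[0.9, 0.1, 0.8, 0.2], [0.3, 0.7, 0.6, 0.1]], [[1, 0, 1, 0], [1, 0, 0, 1]]),
    ([[0.2, 0.4, 0.9, 0.5]], [[0, 0, 1, 1]]),
]
avg_prec, avg_recall = RunningAverage(), RunningAverage()
for scores, relevant in batches:
    prec, rec, n = batch_means(scores, relevant, k=2)
    avg_prec.update(prec, n)
    avg_recall.update(rec, n)

print(round(avg_prec.avg, 4), round(avg_recall.avg, 4))  # 0.6667 0.6667
print(round(f1_from(avg_prec.avg, avg_recall.avg), 4))  # 0.6667
```

Weighting by `n` makes the epoch value equal to the mean over all samples: here the batch precisions are 0.5 (two samples) and 1.0 (one sample), giving 2/3, whereas a plain mean of the two batch values would give 0.75 and over-weight the smaller batch. Note that this sketch takes F1 from the epoch averages, while the notebook computes F1 per batch (for example `f1_score(recall6, prec)`) and then averages those values.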