├── README.md ├── polara_based.ipynb └── main.ipynb /README.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /polara_based.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from polara import RecommenderData\n", 10 | "from polara import SVDModel\n", 11 | "from polara import get_movielens_data\n", 12 | "from polara.tools.preprocessing import filter_sessions_by_length\n", 13 | "from polara.evaluation import evaluation_engine as ee\n", 14 | "import numpy as np\n", 15 | "import scipy.sparse as SP\n", 16 | "from io import BytesIO\n", 17 | "import pandas as pd\n", 18 | "\n", 19 | "import numpy as np, scipy.stats as st\n", 20 | "import numpy as np\n", 21 | "import scipy as sp\n", 22 | "import scipy.stats\n", 23 | "\n" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "ml_train_items = np.load(\"ml_train_items.npy\")\n", 33 | "ml_train_mask = np.load(\"ml_train_mask.npy\")\n", 34 | "ml_train_users = np.load(\"ml_train_users.npy\")\n", 35 | "ml_val_items = np.load(\"ml_val_items.npy\")\n", 36 | "ml_val_mask = np.load(\"ml_val_mask.npy\")\n", 37 | "ml_val_users = np.load(\"ml_val_users.npy\")\n", 38 | "ml_test_items = np.load(\"ml_test_items.npy\")\n", 39 | "ml_test_mask = np.load(\"ml_test_mask.npy\")\n", 40 | "ml_test_users = np.load(\"ml_test_users.npy\")\n", 41 | "ml_train_user_idx = np.load('ml_train_user_idx.npy')\n", 42 | "ml_train_item_idx = np.load('ml_train_item_idx.npy')\n", 43 | "ml_train_feedback = np.load('ml_train_feedback.npy')" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 13, 49 | "metadata": {}, 50 | "outputs": [ 51 | { 52 | "data": { 53 | 
"text/plain": [ 54 | "762.8421874999999" 55 | ] 56 | }, 57 | "execution_count": 13, 58 | "metadata": {}, 59 | "output_type": "execute_result" 60 | } 61 | ], 62 | "source": [ 63 | "len(lf_train_items)/128/50*10" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 3, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "lf_train_items = np.load(\"lf_train_items.npy\")\n", 73 | "lf_train_mask = np.load(\"lf_train_mask.npy\")\n", 74 | "lf_train_users = np.load(\"lf_train_users.npy\")\n", 75 | "lf_val_items = np.load(\"lf_val_items.npy\")\n", 76 | "lf_val_mask = np.load(\"lf_val_mask.npy\")\n", 77 | "lf_val_users = np.load(\"lf_val_users.npy\")\n", 78 | "lf_test_items = np.load(\"lf_test_items.npy\")\n", 79 | "lf_test_mask = np.load(\"lf_test_mask.npy\")\n", 80 | "lf_test_users = np.load(\"lf_test_users.npy\")\n", 81 | "lf_train_user_idx = np.load('lf_train_user_idx.npy')\n", 82 | "lf_train_item_idx = np.load('lf_train_item_idx.npy')\n", 83 | "lf_train_feedback = np.load('lf_train_feedback.npy')" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 5, 89 | "metadata": {}, 90 | "outputs": [ 91 | { 92 | "name": "stdout", 93 | "output_type": "stream", 94 | "text": [ 95 | "Collecting git+https://github.com/Evfro/polara.git@develop\n", 96 | " Cloning https://github.com/Evfro/polara.git (to revision develop) to /tmp/pip-req-build-0a3rx0t8\n", 97 | "Building wheels for collected packages: polara\n", 98 | " Running setup.py bdist_wheel for polara ... 
def remove_gaps(data):
    """Re-index ``movieid`` and ``userid`` to contiguous integer codes.

    Each distinct id is replaced by a code in ``0 .. n_unique - 1``, numbered
    in order of first appearance (``sort=False``).

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``movieid`` and ``userid`` columns.

    Returns
    -------
    pandas.DataFrame
        The same frame object that was passed in, with both id columns
        overwritten (the frame is modified in place, as before).
    """
    # The codes are derived from the frame that was passed in. The previous
    # version grouped the notebook-global ``ml_data`` instead, which only
    # worked when ``data`` happened to be that exact frame.
    for column in ('movieid', 'userid'):
        # ``.values`` drops the index so the assignment is purely positional.
        data[column] = data.groupby(column, sort=False).ngroup().values
    return data
# NOTE(review): ``jit`` is not imported in the visible part of the notebook;
# presumably it is numba's ``jit`` - confirm the import lives in the same cell.
@jit(nopython=True, nogil=True)
def recall_at_k(recommendation, true_consumption, k=20):
    """Hit indicator: 1.0 if the true item is among the ``k`` best-scored items.

    Parameters
    ----------
    recommendation : 1-D array of scores, one entry per item.
    true_consumption : index of the item that was actually consumed
        (converted with ``int``).
    k : length of the recommendation list.

    Returns
    -------
    float
        1.0 on a hit, 0.0 otherwise.
    """
    target = int(true_consumption)
    # Best score first. When k exceeds the number of items the slice simply
    # yields every item.
    topk_inds = recommendation.argsort()[-k:][::-1]
    # Scan the short top-k list instead of allocating a dense per-item array.
    for ind in topk_inds:
        if ind == target:
            return 1.0
    return 0.0


@jit(nopython=True, nogil=True)
def mrr_at_k(recommendation, true_consumption, k=20):
    """Reciprocal rank of the true item within the top-``k`` list.

    Parameters are the same as for ``recall_at_k``.

    Returns
    -------
    float
        ``1 / rank`` (rank starts at 1) when the true item is in the top
        ``k``, 0.0 otherwise.
    """
    target = int(true_consumption)
    topk_inds = recommendation.argsort()[-k:][::-1]
    # Walking the list also copes with k larger than the number of items; the
    # previous dense assignment of ``1/np.arange(1, k+1)`` failed in that case.
    for position in range(topk_inds.shape[0]):
        if topk_inds[position] == target:
            return 1.0 / (position + 1)
    return 0.0
0.7732885934551574\n", 508 | "Epoch 8 RMSE: 0.7632023403587169\n", 509 | "Epoch 9 RMSE: 0.7548784181251089\n", 510 | "Epoch 10 RMSE: 0.7479567842743502\n", 511 | "Epoch 11 RMSE: 0.742152630887336\n", 512 | "Epoch 12 RMSE: 0.7372427424804076\n", 513 | "Epoch 13 RMSE: 0.7330523912264771\n", 514 | "Epoch 14 RMSE: 0.7294447595737159\n", 515 | "Epoch 15 RMSE: 0.7263125637673405\n", 516 | "Epoch 16 RMSE: 0.7235714286511026\n", 517 | "Epoch 17 RMSE: 0.7211546768303492\n", 518 | "Epoch 18 RMSE: 0.7190092797572853\n", 519 | "Epoch 19 RMSE: 0.7170927493456211\n", 520 | "Epoch 20 RMSE: 0.7153707702341925\n", 521 | "Epoch 21 RMSE: 0.7138154014210378\n", 522 | "Epoch 22 RMSE: 0.7124037087144436\n", 523 | "Epoch 23 RMSE: 0.7111167205435434\n", 524 | "Epoch 24 RMSE: 0.7099386259592084\n", 525 | "Epoch 25 RMSE: 0.708856154402729\n", 526 | "Epoch 26 RMSE: 0.7078580925338287\n", 527 | "Epoch 27 RMSE: 0.706934905051986\n", 528 | "Epoch 28 RMSE: 0.7060784349627898\n", 529 | "Epoch 29 RMSE: 0.7052816649704758\n", 530 | "Epoch 30 RMSE: 0.7045385262243349\n" 531 | ] 532 | } 533 | ], 534 | "source": [ 535 | "# train_feedback[:] = 1\n", 536 | "P, Q, _ = basic_matrix_factorization(ml_train_user_idx, ml_train_item_idx, ml_train_feedback\\\n", 537 | " ,rank=20,reg = 0.01,num_epochs=30)" 538 | ] 539 | }, 540 | { 541 | "cell_type": "code", 542 | "execution_count": 12, 543 | "metadata": {}, 544 | "outputs": [ 545 | { 546 | "name": "stdout", 547 | "output_type": "stream", 548 | "text": [ 549 | "MRR@20 score for MF on MovieLens : 0.004567202376613807 ± 0.001339522261853232\n", 550 | "Recall@20 score for MF on MovieLens : 0.019013581129378128 ± 0.005759417598196718\n" 551 | ] 552 | } 553 | ], 554 | "source": [ 555 | "(mrr, h_mrr),(recall, h_recall), = estimate_model(P,Q,ml_test_items, ml_test_mask, ml_test_users,reg = 0.01)\n", 556 | "ds_name = \"MovieLens\"\n", 557 | "print(\"MRR@20 score for MF on \", ds_name,\": \",mrr,\"±\",h_mrr)\n", 558 | "print(\"Recall@20 score for MF on\",ds_name,\": 
\",recall,\"±\",h_recall)" 559 | ] 560 | }, 561 | { 562 | "cell_type": "markdown", 563 | "metadata": {}, 564 | "source": [ 565 | "### LastFM" 566 | ] 567 | }, 568 | { 569 | "cell_type": "code", 570 | "execution_count": 7, 571 | "metadata": {}, 572 | "outputs": [], 573 | "source": [ 574 | "import scipy as sp\n", 575 | "train_matrix = np.array(sp.sparse.csr_matrix((np.ones(len(lf_train_user_idx)),\n", 576 | " (lf_train_user_idx, lf_train_item_idx)),\n", 577 | " shape=(max(lf_train_user_idx)+1,max(lf_train_item_idx)+1), dtype=np.float64).todense())\n", 578 | "\n" 579 | ] 580 | }, 581 | { 582 | "cell_type": "code", 583 | "execution_count": 8, 584 | "metadata": {}, 585 | "outputs": [ 586 | { 587 | "name": "stdout", 588 | "output_type": "stream", 589 | "text": [ 590 | "Epoch 1 RMSE: 0.33389627661066806\n", 591 | "Epoch 2 RMSE: 0.13595819292966613\n", 592 | "Epoch 3 RMSE: 0.10508689900024767\n", 593 | "Epoch 4 RMSE: 0.09070079526606878\n", 594 | "Epoch 5 RMSE: 0.08147071061295962\n", 595 | "Epoch 6 RMSE: 0.07477153011689404\n", 596 | "Epoch 7 RMSE: 0.06956005889239489\n", 597 | "Epoch 8 RMSE: 0.06532270200989314\n", 598 | "Epoch 9 RMSE: 0.06177040644012343\n", 599 | "Epoch 10 RMSE: 0.058724800935394526\n", 600 | "Epoch 11 RMSE: 0.056068301626020486\n", 601 | "Epoch 12 RMSE: 0.05371946888402425\n", 602 | "Epoch 13 RMSE: 0.051619679737187375\n", 603 | "Epoch 14 RMSE: 0.04972540075475995\n", 604 | "Epoch 15 RMSE: 0.04800345859082646\n", 605 | "Epoch 16 RMSE: 0.04642801718895144\n", 606 | "Epoch 17 RMSE: 0.04497857557847389\n", 607 | "Epoch 18 RMSE: 0.04363860094936924\n", 608 | "Epoch 19 RMSE: 0.04239457073868288\n", 609 | "Epoch 20 RMSE: 0.041235285857219646\n", 610 | "Epoch 21 RMSE: 0.04015136833763086\n", 611 | "Epoch 22 RMSE: 0.03913488730514752\n", 612 | "Epoch 23 RMSE: 0.03817907605459219\n", 613 | "Epoch 24 RMSE: 0.03727811497846672\n", 614 | "Epoch 25 RMSE: 0.03642696285749802\n", 615 | "Epoch 26 RMSE: 0.035621224182504706\n", 616 | "Epoch 27 RMSE: 
0.03485704367226345\n", 617 | "Epoch 28 RMSE: 0.03413102156514454\n", 618 | "Epoch 29 RMSE: 0.03344014495524048\n", 619 | "Epoch 30 RMSE: 0.032781731648978206\n" 620 | ] 621 | } 622 | ], 623 | "source": [ 624 | "P, Q, _ = basic_matrix_factorization(lf_train_user_idx, lf_train_item_idx, lf_train_feedback\\\n", 625 | " ,rank=20,reg = 0.01,num_epochs=30)" 626 | ] 627 | }, 628 | { 629 | "cell_type": "code", 630 | "execution_count": 9, 631 | "metadata": {}, 632 | "outputs": [ 633 | { 634 | "name": "stdout", 635 | "output_type": "stream", 636 | "text": [ 637 | "MRR@20 score for MF on LastFM : 2.559322309196812e-05 ± 4.358562546655403e-05\n", 638 | "Recall@20 score for MF on LastFM : 0.00022010271460014673 ± 0.0004937387018228\n" 639 | ] 640 | } 641 | ], 642 | "source": [ 643 | "(mrr, h_mrr),(recall, h_recall), = estimate_model(P,Q,lf_test_items, lf_test_mask, lf_test_users,reg = 0.01)\n", 644 | "ds_name = \"LastFM\"\n", 645 | "print(\"MRR@20 score for MF on \", ds_name,\": \",mrr,\"±\",h_mrr)\n", 646 | "print(\"Recall@20 score for MF on\",ds_name,\": \",recall,\"±\",h_recall)" 647 | ] 648 | }, 649 | { 650 | "cell_type": "code", 651 | "execution_count": null, 652 | "metadata": {}, 653 | "outputs": [], 654 | "source": [] 655 | }, 656 | { 657 | "cell_type": "code", 658 | "execution_count": null, 659 | "metadata": {}, 660 | "outputs": [], 661 | "source": [] 662 | }, 663 | { 664 | "cell_type": "code", 665 | "execution_count": null, 666 | "metadata": {}, 667 | "outputs": [], 668 | "source": [] 669 | }, 670 | { 671 | "cell_type": "code", 672 | "execution_count": 3, 673 | "metadata": {}, 674 | "outputs": [ 675 | { 676 | "data": { 677 | "text/html": [ 678 | "
\n", 679 | "\n", 692 | "\n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | "
useridmovieidratingtimestamp
011225.0838985046
111855.0838983525
212315.0838983392
312925.0838983421
413165.0838983392
\n", 740 | "
" 741 | ], 742 | "text/plain": [ 743 | " userid movieid rating timestamp\n", 744 | "0 1 122 5.0 838985046\n", 745 | "1 1 185 5.0 838983525\n", 746 | "2 1 231 5.0 838983392\n", 747 | "3 1 292 5.0 838983421\n", 748 | "4 1 316 5.0 838983392" 749 | ] 750 | }, 751 | "execution_count": 3, 752 | "metadata": {}, 753 | "output_type": "execute_result" 754 | } 755 | ], 756 | "source": [ 757 | "ml_data = get_movielens_data(\"ml-10m.zip\", include_time=True)\n", 758 | "ml_data.head()" 759 | ] 760 | }, 761 | { 762 | "cell_type": "code", 763 | "execution_count": 10, 764 | "metadata": {}, 765 | "outputs": [], 766 | "source": [ 767 | "data = (filter_sessions_by_length(ml_data, min_session_length=20)\n", 768 | " #.query('rating >= 4')\n", 769 | " #.assign(rating=1)\n", 770 | " )" 771 | ] 772 | }, 773 | { 774 | "cell_type": "code", 775 | "execution_count": 11, 776 | "metadata": {}, 777 | "outputs": [], 778 | "source": [ 779 | "data_model = RecommenderData(data, 'userid', 'movieid', 'rating', custom_order='timestamp', seed=0)\n", 780 | "data_model.holdout_size = 1\n", 781 | "data_model.random_holdout = False\n", 782 | "data_model.warm_start = False\n", 783 | "data_model.permute_tops = False" 784 | ] 785 | }, 786 | { 787 | "cell_type": "code", 788 | "execution_count": 12, 789 | "metadata": {}, 790 | "outputs": [ 791 | { 792 | "name": "stdout", 793 | "output_type": "stream", 794 | "text": [ 795 | "Preparing data...\n", 796 | "Done.\n" 797 | ] 798 | } 799 | ], 800 | "source": [ 801 | "data_model.prepare()" 802 | ] 803 | }, 804 | { 805 | "cell_type": "code", 806 | "execution_count": 13, 807 | "metadata": {}, 808 | "outputs": [], 809 | "source": [ 810 | "idx, val, shp = data_model.to_coo()" 811 | ] 812 | }, 813 | { 814 | "cell_type": "code", 815 | "execution_count": null, 816 | "metadata": {}, 817 | "outputs": [], 818 | "source": [] 819 | }, 820 | { 821 | "cell_type": "code", 822 | "execution_count": 14, 823 | "metadata": {}, 824 | "outputs": [ 825 | { 826 | "name": "stdout", 827 | 
"output_type": "stream", 828 | "text": [ 829 | "Epoch 1 RMSE: 1.3085293049789994\n", 830 | "Epoch 2 RMSE: 0.9151132113832263\n", 831 | "Epoch 3 RMSE: 0.8808651366877939\n", 832 | "Epoch 4 RMSE: 0.8676692047213754\n", 833 | "Epoch 5 RMSE: 0.8587222170703273\n", 834 | "Epoch 6 RMSE: 0.8525469762850404\n", 835 | "Epoch 7 RMSE: 0.8479598409933002\n", 836 | "Epoch 8 RMSE: 0.8442665480920708\n", 837 | "Epoch 9 RMSE: 0.8411850669634959\n", 838 | "Epoch 10 RMSE: 0.8385955015689102\n" 839 | ] 840 | } 841 | ], 842 | "source": [ 843 | "P, Q, biases = basic_matrix_factorization(*idx.T, val,rank=20,num_epochs=10)" 844 | ] 845 | }, 846 | { 847 | "cell_type": "code", 848 | "execution_count": 20, 849 | "metadata": {}, 850 | "outputs": [], 851 | "source": [ 852 | "R = Q.dot(P.T).T" 853 | ] 854 | }, 855 | { 856 | "cell_type": "code", 857 | "execution_count": 50, 858 | "metadata": {}, 859 | "outputs": [], 860 | "source": [ 861 | "topk = np.argsort(R,axis = 1)[:,-20:]" 862 | ] 863 | }, 864 | { 865 | "cell_type": "code", 866 | "execution_count": 53, 867 | "metadata": {}, 868 | "outputs": [ 869 | { 870 | "data": { 871 | "text/plain": [ 872 | "(69878, 20)" 873 | ] 874 | }, 875 | "execution_count": 53, 876 | "metadata": {}, 877 | "output_type": "execute_result" 878 | } 879 | ], 880 | "source": [] 881 | }, 882 | { 883 | "cell_type": "code", 884 | "execution_count": 60, 885 | "metadata": {}, 886 | "outputs": [], 887 | "source": [ 888 | "user_idx, item_idx, fdbk_val = data_model.test_to_coo()" 889 | ] 890 | }, 891 | { 892 | "cell_type": "code", 893 | "execution_count": 75, 894 | "metadata": {}, 895 | "outputs": [ 896 | { 897 | "ename": "TypeError", 898 | "evalue": "__init__() got multiple values for argument 'shape'", 899 | "output_type": "error", 900 | "traceback": [ 901 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 902 | "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", 903 | "\u001b[0;32m\u001b[0m in 
\u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mSP\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcoo_matrix\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfdbk_val\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0muser_idx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mitem_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mshape\u001b[0m \u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0muser_idx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mitem_idx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtodense\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 904 | "\u001b[0;31mTypeError\u001b[0m: __init__() got multiple values for argument 'shape'" 905 | ] 906 | } 907 | ], 908 | "source": [ 909 | "SP.coo_matrix(fdbk_val,(user_idx,item_idx),shape =(user_idx.max()+1,item_idx.max()+1) ).todense().shape" 910 | ] 911 | }, 912 | { 913 | "cell_type": "code", 914 | "execution_count": 73, 915 | "metadata": {}, 916 | "outputs": [ 917 | { 918 | "data": { 919 | "text/plain": [ 920 | "array([ 108, 1557, 1564, ..., 1365, 775, 2300])" 921 | ] 922 | }, 923 | "execution_count": 73, 924 | "metadata": {}, 925 | "output_type": "execute_result" 926 | } 927 | ], 928 | "source": [ 929 | "item_idx" 930 | ] 931 | }, 932 | { 933 | "cell_type": "code", 934 | "execution_count": null, 935 | "metadata": {}, 936 | "outputs": [], 937 | "source": [ 938 | "from polara import RecommenderData\n", 939 | "from polara import SVDModel\n", 940 | "from polara import get_movielens_data\n", 941 | "from polara.tools.preprocessing import filter_sessions_by_length\n", 942 | "from polara.evaluation import 
def train_MF(data_path="ml-10m.zip", min_session_length=20):
    """Load MovieLens ratings and build a configured polara ``RecommenderData``.

    Despite the name, no model is fitted here: the function reads the ratings
    (with timestamps), applies ``filter_sessions_by_length`` and sets the split
    options on the data model. The data model is now returned; previously it
    was built and then discarded.

    Parameters
    ----------
    data_path : str
        Path to the MovieLens archive passed to ``get_movielens_data``.
    min_session_length : int
        Threshold forwarded to ``filter_sessions_by_length``.

    Returns
    -------
    RecommenderData
        Configured data model; ``prepare()`` has not been called on it.
    """
    ml_data = get_movielens_data(data_path, include_time=True)
    # Rating binarisation (``rating >= 4`` -> 1) stays disabled, as in the
    # original cell.
    data = filter_sessions_by_length(ml_data, min_session_length=min_session_length)

    data_model = RecommenderData(data, 'userid', 'movieid', 'rating',
                                 custom_order='timestamp', seed=0)
    # ``holdout_size`` is intentionally left at the library default here (the
    # assignment was commented out in the original).
    data_model.random_holdout = False
    data_model.warm_start = False
    data_model.permute_tops = False
    return data_model
class LinearGRU(nn.Module):
    """User-aware GRU that scores every item at every step of a sequence.

    At each time step the user embedding and the item embedding are
    concatenated and fed to a ``GRUCell``; the hidden state is projected to
    ``n_items`` logits and normalised with ``log_softmax``.

    Parameters
    ----------
    n_users, n_items : int
        Sizes of the user and item embedding tables.
    emb_size : int or None
        Embedding width; falls back to ``hidden_units`` when ``None``.
    hidden_units : int
        Width of the GRU hidden state.
    dropout : float
        Dropout probability applied to the hidden states before the output layer.
    user_dropout : float
        Dropout probability applied to the user embeddings.
    """

    def __init__(self, n_users, n_items, emb_size=None, hidden_units=1000,
                 dropout=0.8, user_dropout=0.5):
        # Zero-argument super(): ``super(self.__class__, self)`` recurses
        # forever as soon as the class is subclassed.
        super().__init__()
        self.n_users = n_users
        self.n_items = n_items
        self.hidden_units = hidden_units
        if emb_size is None:
            emb_size = hidden_units
        self.emb_size = emb_size
        # Sub-modules are created in the original order so that parameter
        # initialisation under a fixed seed is unchanged.
        self.user_emb = nn.Embedding(n_users, emb_size)
        self.item_emb = nn.Embedding(n_items, emb_size)
        self.grucell = nn.GRUCell(input_size=emb_size * 2, hidden_size=hidden_units)
        self.linear = nn.Linear(hidden_units, n_items)
        self.dropout = nn.Dropout(dropout)
        self.user_dropout = nn.Dropout(user_dropout)

    def forward(self, user_vectors, item_vectors):
        """Return log-probabilities over items for every step.

        Parameters
        ----------
        user_vectors, item_vectors : integer tensors of shape ``(batch, seq)``
            holding user ids and item ids (``user_vectors`` is unpacked as a
            2-D size below).

        Returns
        -------
        Tensor of shape ``(batch, seq, n_items)``.
        """
        batch_size, sequence_size = user_vectors.size()

        users = self.user_dropout(self.user_emb(user_vectors))
        items = self.item_emb(item_vectors)

        # Take device and dtype from the embeddings instead of relying on a
        # notebook-global ``device`` that may not exist yet (or may not match
        # where the module actually lives).
        h = users.new_zeros(batch_size, self.hidden_units)
        hidden_states = []
        for step in range(sequence_size):
            gru_input = torch.cat([users[:, step, :], items[:, step, :]], dim=-1)
            h = self.grucell(gru_input, h)
            hidden_states.append(h)

        if hidden_states:
            # One stack at the end instead of re-concatenating every step.
            stacked = torch.stack(hidden_states, dim=1)
        else:
            # Zero-length sequences still produce an (empty) 3-D output.
            stacked = users.new_zeros(batch_size, 0, self.hidden_units)

        logits = self.linear(self.dropout(stacked))
        return F.log_softmax(logits, dim=-1)
torch.zeros(batch_size,self.hidden_units).to(device)\n", 107 | " h_t = h.unsqueeze(0)\n", 108 | " for i in range(sequence_size):\n", 109 | " rect_users = rectified_users(self,users[:,i,:],items[:,i,:],h)\n", 110 | " gru_input = torch.cat([rect_users,items[:,i,:]],dim=-1)\n", 111 | " h = self.grucell(gru_input,h)\n", 112 | " h_t = torch.cat([h_t,h.unsqueeze(0)],dim=0)\n", 113 | " ln_input = self.dropout(h_t[1:].transpose(0,1))\n", 114 | " output_ln = self.linear(ln_input)\n", 115 | "\n", 116 | " output = F.log_softmax(output_ln, dim=-1)\n", 117 | " return output\n", 118 | " \n", 119 | "def rectified_users(self,users,items,h):\n", 120 | " \n", 121 | " k1 = self.k1(torch.cat([users,items,h],dim = -1))\n", 122 | " k2 = self.k2(torch.cat([users,items,h],dim = -1))\n", 123 | " rect_users = users\n", 124 | " rect_users[users h x d_k \n", 241 | " query, key, value = \\\n", 242 | " [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)\n", 243 | " for l, x in zip(self.linears, (query, key, value))]\n", 244 | " \n", 245 | " # 2) Apply attention on all the projected vectors in batch. \n", 246 | " x, self.attn = attention(query, key, value, mask=mask, \n", 247 | " dropout=self.dropout)\n", 248 | " \n", 249 | " # 3) \"Concat\" using a view and apply a final linear. 
class MHLinearGRU(nn.Module):
    """User-aware GRU whose inputs are re-weighted by multi-head attention.

    At each step two ``MultiHeadedAttention`` modules are applied, each taking
    the current hidden state as key: one with the item embedding as query and
    the user embedding as value, the other with the roles of user and item
    swapped. Both results are batch-normalised, concatenated and fed to a
    ``GRUCell``; the hidden states are projected to ``n_items`` logits and
    normalised with ``log_softmax``.

    Parameters
    ----------
    n_users, n_items : int
        Sizes of the user and item embedding tables.
    emb_size : int
        Embedding width; also used as the GRU hidden size.
    head_num : int
        Number of attention heads passed to ``MultiHeadedAttention``.
    dropout : float
        Dropout probability applied to the hidden states before the output layer.
    user_dropout : float
        Dropout probability applied to the user embeddings.
    """

    def __init__(self, n_users, n_items, emb_size=1000, head_num=5,
                 dropout=0.8, user_dropout=0.5):
        # Zero-argument super(): ``super(self.__class__, self)`` recurses
        # forever as soon as the class is subclassed.
        super().__init__()
        self.n_users = n_users
        self.n_items = n_items
        self.hidden_units = emb_size
        self.emb_size = emb_size
        # Sub-modules are created in the original order so that parameter
        # initialisation under a fixed seed and state_dict keys are unchanged.
        self.user_emb = nn.Embedding(n_users, emb_size)
        self.item_emb = nn.Embedding(n_items, emb_size)
        self.grucell = nn.GRUCell(input_size=emb_size * 2, hidden_size=self.hidden_units)
        self.linear = nn.Linear(self.hidden_units, n_items)
        self.dropout = nn.Dropout(dropout)
        self.user_dropout = nn.Dropout(user_dropout)
        self.user_attention = MultiHeadedAttention(head_num, self.emb_size)
        self.item_attention = MultiHeadedAttention(head_num, self.emb_size)
        self.bn_user = nn.BatchNorm1d(self.emb_size)
        self.bn_item = nn.BatchNorm1d(self.emb_size)
        # Not used in forward(); kept so existing checkpoints still load.
        self.bn_last = nn.BatchNorm1d(self.emb_size)

    def forward(self, user_vectors, item_vectors):
        """Return log-probabilities over items for every step.

        Parameters
        ----------
        user_vectors, item_vectors : integer tensors of shape ``(batch, seq)``
            holding user ids and item ids (``user_vectors`` is unpacked as a
            2-D size below).

        Returns
        -------
        Tensor of shape ``(batch, seq, n_items)``.
        """
        batch_size, sequence_size = user_vectors.size()

        users = self.user_dropout(self.user_emb(user_vectors))
        items = self.item_emb(item_vectors)

        # Take device and dtype from the embeddings instead of relying on a
        # notebook-global ``device``.
        h = users.new_zeros(batch_size, self.hidden_units)
        hidden_states = []
        for step in range(sequence_size):
            user_step = users[:, step, :]
            item_step = items[:, step, :]
            # Argument order is (query, key, value), as in the original calls.
            attnd_users = self.user_attention(item_step, h, user_step).squeeze(1)
            attnd_items = self.item_attention(user_step, h, item_step).squeeze(1)
            attnd_users = self.bn_user(attnd_users)
            attnd_items = self.bn_item(attnd_items)
            gru_input = torch.cat([attnd_users, attnd_items], dim=-1)
            h = self.grucell(gru_input, h)
            hidden_states.append(h)

        if hidden_states:
            # One stack at the end instead of re-concatenating every step.
            stacked = torch.stack(hidden_states, dim=1)
        else:
            # Zero-length sequences still produce an (empty) 3-D output.
            stacked = users.new_zeros(batch_size, 0, self.hidden_units)

        logits = self.linear(self.dropout(stacked))
        return F.log_softmax(logits, dim=-1)
\u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mml_train_items\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mml_val_items\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mml_test_items\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mn_users\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 321 | "\u001b[0;31mNameError\u001b[0m: name 'ml_test_users' is not defined" 322 | ] 323 | } 324 | ], 325 | "source": [ 326 | "n_users = int(ml_test_users.max()+1)\n", 327 | "n_items = int(np.max([ml_train_items.max()+1,ml_val_items.max()+1,ml_test_items.max()])+1)\n", 328 | "n_users" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": 97, 334 | "metadata": {}, 335 | "outputs": [ 336 | { 337 | "data": { 338 | "text/plain": [ 339 | "torch.Size([2, 5, 10678])" 340 | ] 341 | }, 342 | "execution_count": 97, 343 | "metadata": {}, 344 | "output_type": "execute_result" 345 | } 346 | ], 347 | "source": [ 348 | "device = torch.device(\"cuda:2\" if torch.cuda.is_available() else \"cpu\")\n", 349 | "# device = torch.device(\"cpu\")\n", 350 | "# network = MHLinearGRU(n_users=3,n_items=3,emb_size=20).to(device)\n", 351 | "network = MHLinearGRU(n_users=n_users,n_items=n_items).to(device)\n", 352 | "\n", 353 | "users = np.array([[1,1,1,1,1],\n", 354 | " [2,2,2,2,2]])\n", 355 | "items = np.array([[0,1,2,1,1],\n", 356 | " [0,2,2,1,0]])\n", 357 | "pred = 
network(Variable(torch.LongTensor(users)).to(device),Variable(torch.LongTensor(items)).to(device))\n", 358 | "pred.size()" 359 | ] 360 | }, 361 | { 362 | "cell_type": "markdown", 363 | "metadata": {}, 364 | "source": [ 365 | "### MovieLens Preprocessing" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": 5, 371 | "metadata": {}, 372 | "outputs": [], 373 | "source": [ 374 | "# Libraries and provided functions\n", 375 | "import pandas as pd\n", 376 | "import zipfile\n", 377 | "import wget\n", 378 | "from io import StringIO \n", 379 | "import numpy as np\n", 380 | "import scipy as sp\n", 381 | "from scipy import sparse\n", 382 | "import scipy.sparse.linalg\n", 383 | "from tqdm import tqdm # Very useful library to see progress bar during range iterations: just type `for i in tqdm(range(10)):`\n", 384 | "from matplotlib import pyplot as plt\n", 385 | "%matplotlib inline\n", 386 | "\n", 387 | "from collections import namedtuple\n", 388 | "import sys\n", 389 | "\n", 390 | "def drop_unused_items(train,val,test):\n", 391 | " train_items = train.itemid.unique()\n", 392 | " val_items = val.itemid.unique()\n", 393 | " test_items = test.itemid.unique()\n", 394 | " \n", 395 | " droped_items = list((set(val_items)|set(test_items)) - set(train_items))\n", 396 | " val_mask = val.userid == droped_items[0]\n", 397 | " test_mask = test.userid == droped_items[0]\n", 398 | " for droped_item in droped_items:\n", 399 | " val_mask += val.itemid==droped_item\n", 400 | " test_mask += test.itemid==droped_item\n", 401 | " val = val[~val_mask]\n", 402 | " test = test[~test_mask]\n", 403 | " return val,test\n", 404 | "\n", 405 | "def move_timestamps_to_end(x,max_order):\n", 406 | " new_order = x.groupby('timestamp', sort=True).grouper.group_info[0]\n", 407 | " x[\"timestamp\"] = (max_order - new_order.max())+new_order\n", 408 | " return x\n", 409 | "\n", 410 | "def normalize_timestamp(x):\n", 411 | " x[\"timestamp\"] = x.groupby(['timestamp','itemid'], 
sort=True).grouper.group_info[0]\n", 412 | " return x\n", 413 | "\n", 414 | "def set_timestamp_length(x):\n", 415 | " x['length'] = len(x)\n", 416 | " return x\n", 417 | "\n", 418 | "def to_coo(data):\n", 419 | " \n", 420 | " user_idx, item_idx, feedback = data['userid'], data['itemid'], data['rating']\n", 421 | " return user_idx, item_idx, feedback\n", 422 | "\n", 423 | "def to_matrices(data):\n", 424 | " data = split_by_groups(data)\n", 425 | " \n", 426 | " data_max_order = data['timestamp'].max()\n", 427 | " data = data.groupby(\"index\").apply(move_timestamps_to_end,data_max_order)\n", 428 | "\n", 429 | " data_shape = data[['index', 'timestamp']].max()+1\n", 430 | " data_matrix = sp.sparse.csr_matrix((data['itemid'],\n", 431 | " (data['index'], data['timestamp'])),\n", 432 | " shape=data_shape, dtype=np.float64).todense()\n", 433 | " mask_matrix = sp.sparse.csr_matrix((np.ones(len(data)),\n", 434 | " (data['index'], data['timestamp'])),\n", 435 | " shape=data_shape, dtype=np.float64).todense()\n", 436 | " \n", 437 | " data_users = data.drop_duplicates(['index'])\n", 438 | " user_data_shape = data_users['index'].max()+1\n", 439 | " user_vector = sp.sparse.csr_matrix((data_users['userid'],\n", 440 | " (data_users['index'],np.zeros(user_data_shape))),\n", 441 | " shape=(user_data_shape,1), dtype=np.float64).todense()\n", 442 | " user_matrix = np.tile(user_vector,(1,data_shape[1]))\n", 443 | " return data_matrix, mask_matrix, user_matrix\n", 444 | "\n", 445 | "def train_val_test_split(data,frac):\n", 446 | " data = data.groupby(\"userid\").apply(set_timestamp_length)\n", 447 | " max_time_stamp = data['length']*frac\n", 448 | " timestamp = data['timestamp']\n", 449 | " data_train = data[timestamp\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mdevice\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"cuda:2\"\u001b[0m \u001b[0;32mif\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_available\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m\"cpu\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;31m# device = torch.device(\"cpu\")\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mn_users\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mml_test_users\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mn_items\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mml_train_items\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mml_val_items\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mml_test_items\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mnetwork\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mMHLinearGRU\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_users\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mn_users\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mn_items\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mn_items\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 710 | "\u001b[0;31mNameError\u001b[0m: name 'ml_test_users' is not defined" 711 | ] 712 | } 713 | ], 714 | "source": [ 715 | "device = torch.device(\"cuda:2\" if torch.cuda.is_available() else \"cpu\")\n", 716 | "# device = torch.device(\"cpu\")\n", 717 | "n_users = int(ml_test_users.max()+1)\n", 718 | "n_items = int(np.max([ml_train_items.max()+1,ml_val_items.max()+1,ml_test_items.max()])+1)\n", 719 | "network = MHLinearGRU(n_users=n_users,n_items=n_items).to(device)" 720 | ] 721 | }, 722 | { 723 | "cell_type": "code", 724 | "execution_count": 89, 725 | "metadata": {}, 726 | "outputs": [], 727 | "source": [ 728 | "import torch.utils.data\n", 729 | "opt = torch.optim.Adam(network.parameters(),lr =0.001)\n", 730 | "\n", 731 | "history = []\n", 732 | "\n", 733 | "train_loader = torch.utils.data.DataLoader(\\\n", 734 | " torch.utils.data.TensorDataset(\\\n", 735 | " *(torch.LongTensor(ml_train_users),torch.LongTensor(ml_train_items),torch.FloatTensor(ml_train_mask))),\\\n", 736 | " batch_size=1000,shuffle=True)\n", 737 | "\n", 738 | "val_loader = torch.utils.data.DataLoader(\\\n", 739 | " torch.utils.data.TensorDataset(\\\n", 740 | " *(torch.LongTensor(ml_val_users),torch.LongTensor(ml_val_items),torch.FloatTensor(ml_val_mask))),\\\n", 741 | " batch_size=1000,shuffle=True)\n", 742 | "test_loader = torch.utils.data.DataLoader(\\\n", 743 | " torch.utils.data.TensorDataset(\\\n", 744 | " *(torch.LongTensor(ml_test_users),torch.LongTensor(ml_test_items),torch.FloatTensor(ml_test_mask))),\\\n", 745 | " batch_size=1000,shuffle=True)" 746 | ] 747 | }, 748 | { 749 | "cell_type": "code", 750 | 
"execution_count": 6, 751 | "metadata": {}, 752 | "outputs": [], 753 | "source": [ 754 | "import torch.utils.data\n", 755 | "from IPython.display import clear_output\n", 756 | "import gc\n", 757 | "\n", 758 | "\n", 759 | "\n", 760 | "def validate_lce(network,val_loader):\n", 761 | " torch.cuda.empty_cache()\n", 762 | " gc.collect()\n", 763 | " network.eval()\n", 764 | " losses = []\n", 765 | " with torch.no_grad():\n", 766 | " for user_batch_ix,item_batch_ix, mask_batch_ix in val_loader:\n", 767 | "\n", 768 | " user_batch_ix = Variable(user_batch_ix).to(device)\n", 769 | " item_batch_ix = Variable(item_batch_ix).to(device)\n", 770 | " mask_batch_ix = Variable(mask_batch_ix).to(device)\n", 771 | "\n", 772 | " logp_seq = network(user_batch_ix, item_batch_ix)\n", 773 | " # compute loss\n", 774 | " predictions_logp = logp_seq[:, :-1]*mask_batch_ix[:, :-1,None]\n", 775 | " actual_next_tokens = item_batch_ix[:, 1:]\n", 776 | "\n", 777 | " logp_next = torch.gather(predictions_logp, dim=2, index=actual_next_tokens[:,:,None])\n", 778 | " loss = -logp_next.sum()/mask_batch_ix[:, :-1].sum()\n", 779 | " losses.append(loss.cpu().data.numpy())\n", 780 | " torch.cuda.empty_cache()\n", 781 | " gc.collect() \n", 782 | " return np.mean(losses)\n", 783 | "\n", 784 | "\n", 785 | "def train_network(network,train_loader,val_loader,num_epoch = 10):\n", 786 | " for epoch in range(num_epoch):\n", 787 | " i=0\n", 788 | " for user_batch_ix,item_batch_ix, mask_batch_ix in train_loader:\n", 789 | " network.train()\n", 790 | " user_batch_ix = Variable(user_batch_ix).to(device)\n", 791 | " item_batch_ix = Variable(item_batch_ix).to(device)\n", 792 | " mask_batch_ix = Variable(mask_batch_ix).to(device)\n", 793 | "\n", 794 | " logp_seq = network(user_batch_ix, item_batch_ix)\n", 795 | " # compute loss\n", 796 | " predictions_logp = logp_seq[:, :-1]*mask_batch_ix[:, :-1,None]\n", 797 | " actual_next_tokens = item_batch_ix[:, 1:]\n", 798 | "\n", 799 | " logp_next = torch.gather(predictions_logp, 
dim=2, index=actual_next_tokens[:,:,None])\n", 800 | " loss = -logp_next.sum()/mask_batch_ix[:, :-1].sum()\n", 801 | "\n", 802 | " # train with backprop\n", 803 | " opt.zero_grad()\n", 804 | " loss.backward()\n", 805 | " nn.utils.clip_grad_norm_(network.parameters(),5)\n", 806 | " opt.step()\n", 807 | " \n", 808 | " if (i+1)%50==0:\n", 809 | " val_loss = validate_lce(network,val_loader)\n", 810 | " history.append(val_loss)\n", 811 | "\n", 812 | " clear_output(True)\n", 813 | " plt.title(\"Validation error\")\n", 814 | " plt.plot(history)\n", 815 | " plt.ylabel('Cross-entropy Error')\n", 816 | " plt.xlabel('#iter')\n", 817 | " plt.show()\n", 818 | " i+=1\n", 819 | "\n", 820 | " val_loss = validate_lce(network,val_loader)\n", 821 | " history.append(val_loss)\n", 822 | "\n", 823 | " clear_output(True)\n", 824 | " plt.plot(history,label='val loss')\n", 825 | " plt.legend()\n", 826 | " plt.show()\n" 827 | ] 828 | }, 829 | { 830 | "cell_type": "code", 831 | "execution_count": 91, 832 | "metadata": {}, 833 | "outputs": [ 834 | { 835 | "data": { 836 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xl4leWd//H3NycJ2UhCQkhISNiUzSAIUXHDVrRqVayjjto6VVtr++um0/l1atvpPr/pojO2vWzLONqpvVr3utZ9q2JVFJAdUXayQBIgCwnZv78/zkFDDCSBk5yccz6v6+JqzvPcec73nD58uL2f57lvc3dERCS2JES6ABERCT+Fu4hIDFK4i4jEIIW7iEgMUriLiMQghbuISAxSuIuIxCCFu4hIDFK4i4jEoMRIvfHo0aN9woQJkXp7EZGotGzZslp3z+urXcTCfcKECSxdujRSby8iEpXMbFt/2vVrWMbM/tnM1prZGjO718xSeuz/hpmtM7NVZvaimY0/kqJFRCQ8+gx3MysCvg6UuXspEACu7NHsndD+44GHgF+Eu1AREem//l5QTQRSzSwRSAMqu+9095fdvTn08k1gXPhKFBGRgepzzN3dK8zsVmA7sB94zt2fO8yvfB54Okz1iUiMam9vp7y8nJaWlkiXMiylpKQwbtw4kpKSjuj3+wx3MxsFXAxMBOqAB83sanf/Uy9trwbKgDMPcawbgBsASkpKjqhgEYkN5eXljBw5kgkTJmBmkS5nWHF3du/eTXl5ORMnTjyiY/RnWOZsYIu717h7O/AwcGrPRmZ2NvBdYKG7tx6i4Dvcvczdy/Ly+ryTR0RiWEtLC7m5uQr2XpgZubm5R/VfNf0J9+3APDNLs+D/CwuA9T0KOQH4b4LBXn3E1YhIXFGwH9rRfjd9hru7LyF4B8xyYHXod+4wsx+b2cJQs1uADIJDNivM7PGjquow3t3ZwK3PbmBPU9tgvYWISNTr10NM7v4D4Ac9Nn+/2/6zw1nU4WytbeL2lzdy/swCctKTh+ptRUTIyMhg3759/d4eSVE3t0xmavDKcX1ze4QrEREZvqIu3LNTg731+v0KdxE5cjfffDO/+c1vPnj9wx/+kFtvvZV9+/axYMEC5syZw8yZM3nsscf6fUx355vf/CalpaXMnDmT+++/H4Cqqirmz5/P7NmzKS0tZfHixXR2dnLttdd+0Pa2224L6+eL2NwyRyorLdRzV7iLxIwfPbGWdZUNYT3mjMJMfnDRcYfcf8UVV3DTTTfxla98BYAHHniAZ599lpSUFB555BEyMzOpra1l3rx5LFy4sF8XOB9++GFWrFjBypUrqa2t5cQTT2T+/Pncc889nHvuuXz3u9+ls7OT5uZmVqxYQUVFBWvWrAGgrq4uPB88JPrCPTQsU6dwF5GjcMIJJ1BdXU1lZSU1NTWMGjWK4uJi2tvb+c53vsOrr75KQkICFRUV7Nq1i4KCgj6P+dprr3HVVVcRCATIz8/nzDPP5O233+bEE0/kc5/7HO3t7XzqU59i9uzZTJo0ic2bN/O1r32NCy64gE984hNh/XxRF+7pyQESE0w9d5EYcrge9mC6/PLLeeihh9i5cydXXHEFAH/+85+pqalh2bJlJCUlMWHChKN+inb+/Pm8+uqrPPnkk1x77bV84xvf4LOf/SwrV67k2WefZdGiRTzwwAP8/ve/D8fHAqJwzN3MyEpNok4XVEXkKF1xxRXcd999PPTQQ1x++eUA1NfXM2bMGJKSknj55ZfZtq1fM+wCcMYZZ3D//ffT2dlJTU0Nr776KieddBLbtm0jPz+fL3zhC1x//fUsX76c2tpaurq6uPTSS/n3f/93li9fHtbPFnU9dwiOuzeo5y4iR+m4446jsbGRoqIixo4dC8BnPvMZLrroImbOnElZWRnTpk3r9/EuueQS3njjDWbNmoWZ8Ytf/IKCggLuvvtubrnlFpKSksjIyOC
Pf/wjFRUVXHfddXR1dQHw05/+NKyfzdw9rAfsr7KyMj/SxTou+e3fSU9O5E/XnxzmqkRkqKxfv57p06dHuoxhrbfvyMyWuXtZX78bdcMyANmpSdTt1xOqIiKHEpXhnpWapAuqIiKHEbXhrguqItEvUsPC0eBov5voDPe0ZBpbOujs0okhEq1SUlLYvXu3Ar4XB+ZzT0lJ6bvxIUTn3TKhB5kaW9rJTtPkYSLRaNy4cZSXl1NTUxPpUoalAysxHamoDPfsA0+pNivcRaJVUlLSEa8yJH2LzmGZVM0vIyJyOFEZ7tlpml9GRORw+hXuZvbPZrbWzNaY2b1mltJj/wgzu9/MNprZEjObMBjFHqCeu4jI4fUZ7mZWBHwdKHP3UiAAXNmj2eeBve5+DHAb8PNwF9qdwl1E5PD6OyyTCKSaWSKQBlT22H8xcHfo54eABTaIK99+uBqTnlIVEelNfxbIrgBuBbYDVUC9uz/Xo1kRsCPUvgOoB3LDW+qHUpICpCQlqOcuInII/RmWGUWwZz4RKATSzezqI3kzM7vBzJaa2dKjvbc1OzVZ4S4icgj9GZY5G9ji7jXu3g48DJzao00FUAwQGrrJAnb3PJC73+HuZe5elpeXd1SFawoCEZFD60+4bwfmmVlaaBx9AbC+R5vHgWtCP18GvOSD/EyxJg8TETm0/oy5LyF4kXQ5sDr0O3eY2Y/NbGGo2V1ArpltBL4B3DxI9X4gK03hLiJyKP2afsDdfwD8oMfm73fb3wJcHsa6+pSVmsQahbuISK+i8glVCM4vo567iEjvojbcs1KTaG7rpK2jK9KliIgMO1Eb7gfml1HvXUTko6I23D94SlVrqYqIfETUhrvmlxERObSoDfcDi3Qo3EVEPipqwz2r22pMIiJysKgN92wNy4iIHFLUhnumeu4iIocUteEeSDBGpiSq5y4i0ouoDXcIjrs3KNxFRD4i6sNdi2SLiHxUVId7tmaGFBHpVVSHe3DBDj2hKiLSU5SHezL1+zsiXYaIyLAT5eEevKA6yIs+iYhEnf4skD3VzFZ0+9NgZjf1aJNlZk+Y2UozW2tm1w1eyR/KSk2irbOL/e2dQ/F2IiJRo8+VmNx9AzAbwMwCBBfDfqRHs68A69z9IjPLAzaY2Z/dfVAHxLtP+5uW3K9FpURE4sJAh2UWAJvcfVuP7Q6MDC2gnQHsAQZ9MFzzy4iI9G6g4X4lcG8v228HpgOVBBfRvtHdP7JEkpndYGZLzWxpTU3NgIvtSfPLiIj0rt/hbmbJwELgwV52nwusAAoJDuHcbmaZPRu5+x3uXubuZXl5eUdY8ocyFe4iIr0aSM/9fGC5u+/qZd91wMMetBHYAkwLR4GH88GYu4ZlREQOMpBwv4reh2QAthMcj8fM8oGpwOajK61vWo1JRKR3/brFxMzSgXOAL3bb9iUAd18E/AT4g5mtBgz4lrvXhr/cg2WMSCSQYNRpHVURkYP0K9zdvQnI7bFtUbefK4FPhLe0vpkZeRkj2NXQOtRvLSIyrEX1E6oAJTlpbN/THOkyRESGlagP93E5qexQuIuIHCTqw70kJ42dDS20dmgKAhGRA2Ii3N2hYu/+SJciIjJsRH24F+ekAWjcXUSkm6gP95JQuGvcXUTkQ1Ef7nkZIxiRmMAODcuIiHwg6sM9IcEozklj+2713EVEDoj6cAcoHpWqMXcRkW5iItxLctLYsadZy+2JiITERLgX56TR2NqhCcREREJiJtxBt0OKiBwQE+FeonAXETlITIS7eu4iIgeLiXDPGJFITnoyO/boXncREehHuJvZVDNb0e1Pg5nd1Eu7j4X2rzWzVwan3EMrDt0xIyIi/Visw903EFz0GjMLABXAI93bmFk28FvgPHffbmZjBqHWwyrJSWPljrqhflsRkWFpoMMyC4BN7r6tx/ZPE1wgezuAu1eHo7iBKB6VSmXdfjo6u4b6rUVEhp2BhvuV9L5I9hRglJn9zcyWmdlnj760gSn
JSaOjy6mqbxnqtxYRGXb6He5mlgwsBB7sZXciMBe4ADgX+J6ZTenlGDeY2VIzW1pTU3OEJfdOs0OKiHxoID3384Hl7r6rl33lwLPu3uTutcCrwKyejdz9Dncvc/eyvLy8I6v4EHQ7pIjIhwYS7lfR+5AMwGPA6WaWaGZpwMnA+qMtbiDGZqUQSDB27FW4i4j0ebcMgJmlA+cAX+y27UsA7r7I3deb2TPAKqALuNPd1wxCvYeUGEigKDuV7brXXUSkf+Hu7k1Abo9ti3q8vgW4JXylDVxxjqb+FRGBGHlC9YCSnDS27W7S1L8iEvdiKtxnjcumrrmdzbVNkS5FRCSiYircT5yYA8BbW/ZEuBIRkciKqXCfNDqd0RnJCncRiXsxFe5mxkkTcxTuIhL3YircAU6akENF3X7Kdb+7iMSx2Av3icE7Nt/eqt67iMSvmAv3qQUjyUxJ1NCMiMS1mAv3QIJx4oQclijcRSSOxVy4A5w0MYfNNU3UNLZGuhQRkYiIyXA/cL+7xt1FJF7FZLiXFmaRmhTQuLuIxK2YDPfkxATmjM/WuLuIxK2YDHeAkybk8u7OBur3t0e6FBGRIRe74T4xB3fNMyMi8Slmw33O+GzSkwO8vKE60qWIiAy5PsPdzKaa2YpufxrM7KZDtD3RzDrM7LLwlzowIxIDnHFsHi+tr9b87iISd/oMd3ff4O6z3X02MBdoBh7p2c7MAsDPgefCXuURWjB9DDsbWlhb2RDpUkREhtRAh2UWAJvcfVsv+74G/AUYNuMgH582BjN4cf2wKUlEZEgMNNyvBO7tudHMioBLgN+Fo6hwGZ0xgtnF2bz07q5IlyIiMqT6He5mlgwsBB7sZfcvgW+5e1cfx7jBzJaa2dKampqBVXqEFkwbw8ryeqobWobk/UREhoOB9NzPB5a7e2/d4DLgPjPbClwG/NbMPtWzkbvf4e5l7l6Wl5d3RAUP1ILp+QC6a0ZE4spAwv0qehmSAXD3ie4+wd0nAA8BX3b3R8NQ31GbVjCSwqwUXtC4u4jEkX6Fu5mlA+cAD3fb9iUz+9JgFRYuZsaC6fm89n4tLe2dkS5HRGRI9Cvc3b3J3XPdvb7btkXuvqiXtte6+0PhLPJonTV9DPvbO3lj8+5IlyIiMiRi9gnV7k6ZlEtqUoAX1umuGRGJD3ER7ilJAc6aPoanVlfR2qGhGRGJfXER7gD/WFbM3uZ2PdAkInEhbsL99GNGMzYrhfvf3hHpUkREBl3chHsgwbhs7jhefb+Gyrr9kS5HRGRQxU24A1w2dxzu8PDy8kiXIiIyqOIq3MfnpjNvUg4PLC2nq0vTAItI7IqrcIfghdXte5p5a6tWaBKR2BV34X5+6VgyRiTywFJdWBWR2BV34Z6aHOCiWYU8tbqKmsbWSJcjIjIo4i7cAa4/YyKdXc5Pn1of6VJERAZFXIb75LwMvjh/Mg+/U8Hrm2ojXY6ISNjFZbgDfPWsYyjOSeV7j66hreOwa4yIiESduA33lKQAP15YyqaaJv5n8eZIlyMiElZxG+4QXED7vOMK+PWL77NjT3OkyxERCZu4DneAHyycQSDB+NETayNdiohI2PQZ7mY21cxWdPvTYGY39WjzGTNbZWarzex1M5s1eCWH19isVG5ccCwvrK/mxfWa711EYkOf4e7uG9x9trvPBuYCzcAjPZptAc5095nAT4A7wl7pILrutIlMzkvnR0+s01J8IhITBjosswDY5O7bum9099fdfW/o5ZvAuHAUN1SSExP48cWlbN/TzH+/oourIhL9BhruVwL39tHm88DTR1ZO5Jx2zGguPH4sv/3bRl1cFZGo1+9wN7NkYCHw4GHafJxguH/rEPtvMLOlZra0pqZmoLUOuu9eMJ1AgvH9x9bgrlkjRSR6DaTnfj6w3N17vepoZscDdwIXu/vu3tq4+x3uXubuZXl5eQOvdpCNzUrlm+dO5eUNNdz9+tZIlyMicsQGEu5XcYghGTMrAR4G/sn
d3wtHYZFy7akTWDBtDP/x1LusqaiPdDkiIkekX+FuZunAOQQD/MC2L5nZl0Ivvw/kAr8N3S65NOyVDhEz45bLZ5GTnsxX71nOvtaOSJckIjJg/Qp3d29y91x3r++2bZG7Lwr9fL27jzpwy6S7lw1WwUMhJz2ZX191Atv3NPNvj6zW+LuIRJ24f0L1UE6amMNNZ0/h0RWVPL9ODzeJSHRRuB/Glz82mUl56fzsmXfp6NTMkSISPRTuh5EYSODm86axuaaJ+7Usn4hEEYV7H86Zkc+JE0Zx2/Pv06SLqyISJRTufTAzvv3J6dTua9W87yISNRTu/TCnZBSfnFnAHa9uprqxJdLliIj0SeHeT988dxptHV189Z53qG5QwIvI8KZw76eJo9O55fLjWVVex/m/Wswr7w2/uXFERA5QuA/AJSeM44mvns7ojBFc8/u3+NnT79KuWyRFZBhSuA/Qsfkjeeyrp3HVSSUsemUTV97xJhV1+yNdlojIQRTuRyAlKcBP/2Emv77qBDbsbOSCXy/mBT3FKiLDiML9KCycVcgTXzudouxUrv/jUn7y13W0dWiYRkQiT+F+lCaOTucv/+dUrjllPHe9toXLFr3Ott1NkS5LROKcwj0MUpIC/OjiUhZdPZettU1c+OvXeODtHXR2aTZJEYkMhXsYnVdawFM3nsHUgpH8619Wcc5tr/DYigqFvIgMOYV7mI0blcYDXzyFRVfPISkhgRvvW8H5v3qVlTvqIl2aiMSRPsPdzKaGVlc68KfBzG7q0cbM7NdmttHMVpnZnMErefhLSDDOKx3L0zeewe2fPoF9LR1c+rvX+e3fNqoXLyJDos9wd/cNB1ZYAuYCzcAjPZqdDxwb+nMD8LtwFxqNEhKMC48v5Okb53NuaQG/eGYDn7nzTXbsaY50aSIS4wY6LLMA2OTu23psvxj4owe9CWSb2diwVBgDstKSuP2qE7jlsuNZXV7P2f/1Cr984T1a2jsjXZqIxKiBhvuVwL29bC8Cuq9mUR7adhAzu8HMlprZ0pqa+Jqbxcy4vKyY575xJmfPyOeXL7zPgv8MXnDVFAYiEm79DnczSwYWAg8e6Zu5+x3uXubuZXl5eUd6mKhWlJ3Kbz49h3u/MI+RKYnceN8KTv3ZS/zncxuo1DQGIhImA+m5nw8sd/fenrOvAIq7vR4X2iaHcMrkXJ78+hncdU0ZpYWZ3P7yRk7/+Ut879E17G1qi3R5IhLlEgfQ9ip6H5IBeBz4qpndB5wM1Lt71dEWF+sCCcaC6fksmJ7Pjj3N/M/izfx5yXYeX1nJv3xiCp8+qYTEgO5WFZGBM/e+b80zs3RgOzDJ3etD274E4O6LzMyA24HzCN5Nc527Lz3cMcvKynzp0sM2iUsbdjbyoyfW8vqm3cwal8Xtn55DcU5apMsSkWHCzJa5e1mf7foT7oNB4X5o7s6Tq6v49sOrAfj5pcfzyZm6+UhE+h/u+m/+YcgseH/8U18/g0l5GXz5z8v55oMreXPzbt1ZIyL9op77MNfW0cWtz23g969toaPLyRiRyCmTc/n41DGcNW0MBVkpkS5RRIaQhmViTGNLO69v2s0r79XwyoaaD1Z/mjE2kwuOH8vV88aTlZoU4SpFZLAp3GOYu/N+9T5eXF/Ni+t3sXTbXkaOSOTqU8bzudMmkjdyRKRLFJFBonCPI2sr6/nt3zbx1OoqkgIJnF9awBVlxcyblEtCgkW6PBEJI4V7HNpcs4///ftWHl1RQWNLByU5aVxw/FjOmjaGE4qzdc+8SAxQuMexlvZOnl27kweXlvPm5t10dDlZqUmcfsxo5owfxdzxo5gxNpPkRIW9SLTpb7gP5AlViRIpSQEunl3ExbOLaGhp57X3a3np3Wre2LSbJ1dXhdokcM2pE/jqx49hZIouxIrEGvXc48zO+haWb9/Lc2t38uiKSvJGjuBfz53KpXPGaXxeJApoWEb6tHJHHT98Yi3vbK+jKDuV047J5bRjRnP
q5NG640ZkmFK4S790dTl/XV3FU6uqeGPzbur3txNIMC48fixfnD+ZGYWZkS5RRLpRuMuAdXY56yobeHxlBfcs2U5TWyfzp+RxfmkBM4uymJI/UhdhRSJM4S5Hpb65nT8t2cYfXt9KTWMrAMmBBKaPHcms4myOH5fN7OJsJuelE5wUVESGgsJdwsLd2b6nmVXl9ayuqGdVeR2ry+tpaguu/1qUncrZ08ewYHo+J07IITU5EOGKRWKbwl0GTWeXs7lmH0u37eXF9dW8trGGlvYuzKB4VBrHjslg+thM5k3KZe74UQp8kTBSuMuQaWnv5PVNtawub+C96kY27trHxpp9dHY5yYEEZpdkc35pAQtnFZKbobtwRI5GWMPdzLKBO4FSwIHPufsb3fZnAX8CSgg+GHWru//v4Y6pcI9t+1o7eHvrHt4MzWT57s5GEhOMj03N46JZhcw/No9R6cmRLlMk6oQ73O8GFrv7nWaWDKS5e123/d8Bstz9W2aWB2wACtz9kCs9K9zjy7s7G3hkeQWPvFNBdWMrCQazirP52JQxnDI5l1nFWYxI1PCNSF/CFu6hXvkKguun9trYzL4NFANfASYAzwNT3P2QywYp3ONTZ5ezqryOv22o4W/v1bCqvA53GJGYwNzxo7hoViGXnFBESpKCXqQ34Qz32cAdwDpgFrAMuNHdm7q1GQk8DkwDRgJXuPuTvRzrBuAGgJKSkrnbtm3r9weS2FTX3MZbW/bw5uY9LH6/hver95GTnszV88Zz6Zwi8jNTFPQi3YQz3MuAN4HT3H2Jmf0KaHD373VrcxlwGvANYDLBnvssd2841HHVc5ee3J03N+/hrtc288L66g+2pyYFyElPJj9zBGOzUynMSuGkibmcNW0MAc2HI3EmnLNClgPl7r4k9Poh4OYeba4DfhYattloZlsI9uLfGkDNEufMjFMm53LK5Fw21+xjyZY97G1uY29TG7v3tVFV38LainpeWLeL/1m8hQm5aXzu9IlcNnccacma4FSkuz7/Rrj7TjPbYWZT3X0DsIDgEE1320PbF5tZPjAV2Bz2aiVuTMrLYFJeRq/7Ojq7eGbtTu5cvIXvP7aW/3hqPdPHZjJjbCYzCjM5rjCLaQUjNZwjca2/d8vMJngrZDLB0L4OuALA3ReZWSHwB2AsYAR78X863DE1LCPhsGzbXv66qpJ1lQ2sq2qgsaUDgECCceyYDCaPyWDMyBGMGZlCYXaKZryUqKeHmCTuuDvle/eztrKBtZXB6RK27W6muqHlg+kSzGBuySjOmZHPhNHpwW1AbkYyJxSP0pz2MuxpJSaJO2ZGcU4axTlpnFdacNC+ptYOttQ28cL6XTy3dhc/ffrdj/x+UXYq/zCniEtOKGLiaE2IJtFNPXeJS1X1+9nT1MaB039TzT7+sryC196vocshKzWJiaPTmTQ6ndOPHc3Fs4t0Z44MCxqWETkCuxpaeHbtTt7b1ciW2iY2Vu9jV0MrU/Iz+NZ50zhr2hj16CWiNCwjcgTyM1P47CkTPnjt7jy9Zie3PLuBz9+9lCn5GSSYsbe5jbrmdkalJTM+N43xuWlMLcjkjGNHc+yYDP0DIBGnnrtIP7R3dnHf2zt4enUV6SMSGZWWRFZqEnua2tm2u4mtu5up3Rdc1KQgM4V5k3LISPmw73RcYRYXzSokY4T6U3J0NCwjMsTK9zbz2vu1vPp+Dcu27aWjM/h3q6PLqd/fTlpygItnFzJvUi7rKhtYum0vayrqKclJY96kXE6elMPpx4wmO02zZcqhKdxFhgl3550dddz31naeWFnF/vZOkgMJzByXxcyiLDbXNrFs6x6a2joZmZLI//3EVK6eN14XcKVXCneRYaihpZ2ttU1MyT/4CdqOzi5Wltdx2/Pv89rGWo4rzOTfLpjB5Lx0RiQFSElKoKPTaWrroLm1k4yUREZr4ZO4pHAXiULuzpOrq/jJX9exq6H1kO0SDBZMz+fqeeM545jRevgqjuh
uGZEoZGZceHwhH5s6hhfX76KxpYOW9k5aO7pITDDSRiSSnhzg/ep9PPD2Dp5ft4ui7FQyU5NoDbVLTkwgJz2ZnPRkirJTOWvaGOZNyiU5MSHSH0+GkHruIlGqtaOTZ9bs5KnVVXR2QUpSAiMSA7R2dLKnqY09TW1s293M/vZOMlMSOWvaGHLSR7C/vZOW9k4KslK4et54irJTI/1RZAA0LCMitLR38tr7tTyzdicvv1tNW0cXKcnBMfzKuhYAzi8t4AtnTGJWcfZHfr92XytVdS2UFmXq3v1hQsMyIkJKUoCzZ+Rz9oz8j+yrrNvPH17fyr1LtvPXVVWUFmVy5YklLJxdSGXdfu5avIXHVlTS1tnFjLGZfPHMSVwwcyyJgY8O77g7bZ1dWgd3GFHPXSTO7Wvt4C/Lyrn3re28u7OR5MQE2jq6SE0KcOncIqYWZPKHv29hU00ThVkpzC7JpiAzlcLsFBpbOlhdUc+q8nr2NLUyb1IuFxw/lvOOKyBXd/MMCg3LiMiAuDuryut55J0K8jNTuOqk4g8eqOrqcl56t5p73trO1t1NVNW1sL+9kwSDY8eMpLQoi9yMZJ5ft4sttU0EEoxpBSMpLcyitCiT7LRkttY2saW2ierGVs4rLeCyueO0oMoRCGu4m1k2wcU6SgEHPufub/Ro8zHgl0ASUOvuZx7umAp3kejl7jTs7yAp0Q5a4tDdWV/VyDNrqnhnRx1rKurZ29z+wf7CrBRSkwNsqmkiNz2Za06dwPjcNNZWNrCusoGKuv2kjwgwckRweofTjsnlwuMLGZWup3YPCHe43w0sdvc7zSwZSHP3um77s4HXgfPcfbuZjXH36kMdDxTuIvHA3amsb6FhfzsTctNJTQ7g7ry1ZQ///epmXno3GBPJiQlMzR9JSW4a+9s6adjfTs2+VrbtbiYpYJw1bQxl43No7eikpb2LprYOqhtaqarfT3VjK9MKMvn0ycWcOSX2F00PW7ibWRawApjkh2hsZl8GCt393/pboMJdRLbWNtHS0cnkvAySerlQu66ygYeXl/PoisoPJmYzg7SkAPmZKRRkpZCbMYI3Nu2mdl8rhVkpnHrMaHY1tFC+dz+1ja3Mn5rHP81DRJa0AAAGR0lEQVQbz8kTc2Lijp9whvts4A6Ci2LPApYBN7p7U7c2B4ZjjgNGAr9y9z8e7rgKdxHpr86u4NQLKYkBkgL2kZBu6+jixfW7uOet7ayvaqAwO5XiUWmkjwjwzJqdNLR0MCU/g3Nm5DNuVBpF2amMyRxBwIzgoYwRiQmkJQdIS04kJSlh2P5DEM5wLwPeBE5z9yVm9iugwd2/163N7UAZsABIBd4ALnD393oc6wbgBoCSkpK527ZtG9inEhEZoP1tnTyxspI/LdnGmop6uvpxD8nojBFcMLOAhbMLmVMyis21TfxtQw1/31jLqLRkLjx+LKcdM5rkxAS21jbx6IoKXnmvhtnF2Xzm5BKOGTNy0D5POMO9AHjT3SeEXp8B3OzuF3RrczOQ6u4/CL2+C3jG3R881HHVcxeRodbR2cWuxlYq9u6nprGVLnec4LWB1o4u9rd10tzWyeqKOl5cX01rRxfpyYEPFlifNDqdmn2tNLZ0kJWaxLhRqaytbMAMjivMZMPORto7nZMn5nDFicWcPSOfzJSksH6GsD3E5O47zWyHmU119w0Ee+frejR7DLjdzBKBZOBk4LYjqFtEZNAkBhIoyk7t15QLjS3tPL9uF29t2UNpURZnTsmjOCeN1o5OFr9Xy5Orq9ixp5nvfHIaF80qZGxWKrX7WnlwafCZgW88sJLkQAJnHDuac48rYEZhJhNHp5M+RAu29PdumdkEb4VMBjYD1wFXALj7olCbb4a2dwF3uvsvD3dM9dxFJFZ1dTkryut4alUVT62uorK+5YN9BZkpfP70iXxh/qQjOrYeYhIRGQa6upyNNfvYVL2PzbVNbKrex5lT87h4dtERHU9zy4iIDAMJCcaU/JF
MyR+8i6y9vu+QvpuIiAwJhbuISAxSuIuIxCCFu4hIDFK4i4jEIIW7iEgMUriLiMQghbuISAyK2BOqZlYDHOm0kKOB2jCWE+30fRxM38eH9F0cLBa+j/HuntdXo4iF+9Ews6X9efw2Xuj7OJi+jw/puzhYPH0fGpYREYlBCncRkRgUreF+R6QLGGb0fRxM38eH9F0cLG6+j6gccxcRkcOL1p67iIgcRtSFu5mdZ2YbzGxjaO3WuGFmxWb2spmtM7O1ZnZjaHuOmT1vZu+H/ndUpGsdSmYWMLN3zOyvodcTzWxJ6By538ySI13jUDGzbDN7yMzeNbP1ZnZKvJ4fZvbPob8na8zsXjNLiadzI6rC3cwCwG+A84EZwFVmNiOyVQ2pDuBf3H0GMA/4Sujz3wy86O7HAi+GXseTG4H13V7/HLjN3Y8B9gKfj0hVkfErgovTTwNmEfxe4u78MLMi4OtAmbuXAgHgSuLo3IiqcAdOAja6+2Z3bwPuAy6OcE1Dxt2r3H156OdGgn9xiwh+B3eHmt0NfCoyFQ49MxsHXEBwjV/MzICzgIdCTeLm+zCzLGA+cBeAu7e5ex3xe34kAqlmlgikAVXE0bkRbeFeBOzo9ro8tC3umNkE4ARgCZDv7lWhXTuB/AiVFQm/BP6V4MLsALlAnbt3hF7H0zkyEagB/jc0THWnmaUTh+eHu1cAtwLbCYZ6PbCMODo3oi3cBTCzDOAvwE3u3tB9nwdvf4qLW6DM7EKg2t2XRbqWYSIRmAP8zt1PAJroMQQTL+dH6LrCxQT/wSsE0oHzIlrUEIu2cK8Airu9HhfaFjfMLIlgsP/Z3R8Obd5lZmND+8cC1ZGqb4idBiw0s60Eh+jOIjjmnB36T3GIr3OkHCh39yWh1w8RDPt4PD/OBra4e427twMPEzxf4ubciLZwfxs4NnTFO5ngBZLHI1zTkAmNJ98FrHf3/+q263HgmtDP1wCPDXVtkeDu33b3ce4+geC58JK7fwZ4Gbgs1Cyevo+dwA4zmxratABYR3yeH9uBeWaWFvp7c+C7iJtzI+oeYjKzTxIcZw0Av3f3/xfhkoaMmZ0OLAZW8+EY83cIjrs/AJQQnGnzH919T0SKjBAz+xjwf939QjObRLAnnwO8A1zt7q2RrG+omNlsgheXk4HNwHUEO3Fxd36Y2Y+AKwjeZfYOcD3BMfa4ODeiLtxFRKRv0TYsIyIi/aBwFxGJQQp3EZEYpHAXEYlBCncRkRikcBcRiUEKdxGRGKRwFxGJQf8fvsLO/I0iSRgAAAAASUVORK5CYII=\n", 837 | "text/plain": [ 838 | "
" 839 | ] 840 | }, 841 | "metadata": {}, 842 | "output_type": "display_data" 843 | } 844 | ], 845 | "source": [ 846 | "import gc\n", 847 | "torch.cuda.empty_cache()\n", 848 | "gc.collect() \n", 849 | "train_network(network,train_loader,val_loader)" 850 | ] 851 | }, 852 | { 853 | "cell_type": "code", 854 | "execution_count": 93, 855 | "metadata": {}, 856 | "outputs": [ 857 | { 858 | "name": "stderr", 859 | "output_type": "stream", 860 | "text": [ 861 | "/usr/local/lib/python3.5/dist-packages/torch/serialization.py:193: UserWarning: Couldn't retrieve source code for container of type MHLinearGRU. It won't be checked for correctness upon loading.\n", 862 | " \"type \" + obj.__name__ + \". It won't be checked \"\n", 863 | "/usr/local/lib/python3.5/dist-packages/torch/serialization.py:193: UserWarning: Couldn't retrieve source code for container of type MultiHeadedAttention. It won't be checked for correctness upon loading.\n", 864 | " \"type \" + obj.__name__ + \". It won't be checked \"\n" 865 | ] 866 | } 867 | ], 868 | "source": [ 869 | "torch.save(network,\"network_bn_bn5mh_att_linear_ml.p\")" 870 | ] 871 | }, 872 | { 873 | "cell_type": "code", 874 | "execution_count": 7, 875 | "metadata": {}, 876 | "outputs": [], 877 | "source": [ 878 | "import numpy as np, scipy.stats as st\n", 879 | "import numpy as np\n", 880 | "import scipy as sp\n", 881 | "import scipy.stats\n", 882 | "\n", 883 | "def mean_confidence_interval(data, confidence=0.95,num_parts = 5):\n", 884 | " part_len = len(data)//num_parts\n", 885 | " estimations = []\n", 886 | " for i in range(num_parts):\n", 887 | " est = np.mean(data[part_len*i:part_len*(i+1)])\n", 888 | " estimations.append(est)\n", 889 | " a = 1.0*np.array(estimations)\n", 890 | " n = len(a)\n", 891 | " m, se = np.mean(a), scipy.stats.sem(a)\n", 892 | " h = se * sp.stats.t._ppf((1+confidence)/2., n-1)\n", 893 | " return m, h\n", 894 | "\n", 895 | "def validate_mrr(network,k,test_loader):\n", 896 | " network.eval()\n", 897 | " losses = 
[]\n", 898 | " with torch.no_grad():\n", 899 | " for user_batch_ix,item_batch_ix, mask_batch_ix in test_loader:\n", 900 | " user_batch_ix = Variable(user_batch_ix).to(device)\n", 901 | " item_batch_ix = Variable(item_batch_ix).to(device)\n", 902 | " mask_batch_ix = Variable(mask_batch_ix).to(device)\n", 903 | "\n", 904 | " logp_seq = network(user_batch_ix, item_batch_ix)\n", 905 | " # compute loss\n", 906 | " predictions_logp = logp_seq[:,-2]\n", 907 | " _,ind = torch.topk(predictions_logp, k,dim=-1)\n", 908 | " mrr = torch.zeros(predictions_logp.size())\n", 909 | " mrr.scatter_(-1,ind.cpu(),1/torch.range(1,k).repeat(*ind.size()[:-1],1).type(torch.FloatTensor).cpu())\n", 910 | " actual_next_tokens = item_batch_ix[:, -1]\n", 911 | "\n", 912 | " logp_next = torch.gather(mrr.to(device)*mask_batch_ix[:, -2,None], dim=1, index=actual_next_tokens[:,None])\n", 913 | "# if mask_batch_ix[:,-2].sum() >0:\n", 914 | " loss = logp_next.sum()/mask_batch_ix[:,-2].sum()\n", 915 | " losses.append(loss.cpu().data.numpy())\n", 916 | " torch.cuda.empty_cache()\n", 917 | " gc.collect() \n", 918 | " m, h = mean_confidence_interval(losses)\n", 919 | " return m, h\n", 920 | "\n", 921 | "def validate_recall(network,k,test_loader):\n", 922 | " torch.cuda.empty_cache()\n", 923 | " gc.collect() \n", 924 | " \n", 925 | " network.eval()\n", 926 | " losses = []\n", 927 | " with torch.no_grad():\n", 928 | " for user_batch_ix,item_batch_ix, mask_batch_ix in test_loader:\n", 929 | " user_batch_ix = Variable(user_batch_ix).to(device)\n", 930 | " item_batch_ix = Variable(item_batch_ix).to(device)\n", 931 | " mask_batch_ix = Variable(mask_batch_ix).to(device)\n", 932 | "\n", 933 | " logp_seq = network(user_batch_ix, item_batch_ix)\n", 934 | " # compute loss\n", 935 | " predictions_logp = logp_seq[:, -2]\n", 936 | " minus_kth_biggest_logp,_ = torch.kthvalue(-predictions_logp.cpu(), k,dim=-1,keepdim=True)\n", 937 | " prediicted_kth_biggest = (predictions_logp>(-minus_kth_biggest_logp.to(device)))\\\n", 
938 | " .type(torch.FloatTensor).to(device)\n", 939 | " actual_next_tokens = item_batch_ix[:, -1]\n", 940 | "\n", 941 | " logp_next = torch.gather(prediicted_kth_biggest*mask_batch_ix[:, -2,None], dim=1, index=actual_next_tokens[:,None])\n", 942 | " loss = logp_next.sum()/mask_batch_ix[:,-2].sum()\n", 943 | " losses.append(loss.cpu().data.numpy())\n", 944 | " torch.cuda.empty_cache()\n", 945 | " gc.collect() \n", 946 | " m, h = mean_confidence_interval(losses)\n", 947 | " return m, h\n", 948 | "\n", 949 | "def print_scores(model,name):\n", 950 | " network = torch.load(model).to(device)\n", 951 | " mrr_score, h = validate_mrr(network,20,test_loader)\n", 952 | " print(\"MRR@20 score for \", name,\": \",mrr_score,\"±\",h)\n", 953 | " recall_score, h = validate_recall(network,20,test_loader)\n", 954 | " print(\"Recall@20 score for \"+name+\": \",recall_score,\"±\",h)\n", 955 | " " 956 | ] 957 | }, 958 | { 959 | "cell_type": "markdown", 960 | "metadata": {}, 961 | "source": [ 962 | "### Training on LastFM" 963 | ] 964 | }, 965 | { 966 | "cell_type": "code", 967 | "execution_count": 29, 968 | "metadata": {}, 969 | "outputs": [ 970 | { 971 | "data": { 972 | "text/plain": [ 973 | "0" 974 | ] 975 | }, 976 | "execution_count": 29, 977 | "metadata": {}, 978 | "output_type": "execute_result" 979 | } 980 | ], 981 | "source": [ 982 | "import gc\n", 983 | "torch.cuda.empty_cache()\n", 984 | "gc.collect() " 985 | ] 986 | }, 987 | { 988 | "cell_type": "code", 989 | "execution_count": 30, 990 | "metadata": {}, 991 | "outputs": [], 992 | "source": [ 993 | "device = torch.device(\"cuda:6\" if torch.cuda.is_available() else \"cpu\")\n", 994 | "# device = torch.device(\"cpu\")\n", 995 | "n_users = int(np.max([lf_train_users.max()+1,lf_val_users.max()+1,lf_test_users.max()])+1)\n", 996 | "n_items = int(np.max([lf_train_items.max()+1,lf_val_items.max()+1,lf_test_items.max()])+1)\n", 997 | "# with torch.cuda.device(2):\n", 998 | "network = 
MHLinearGRU(n_users=n_users,n_items=n_items).to(device)\n", 999 | "# network = torch.load(\"network_linear_lf_bs128.p\").to(device)\n" 1000 | ] 1001 | }, 1002 | { 1003 | "cell_type": "code", 1004 | "execution_count": 31, 1005 | "metadata": {}, 1006 | "outputs": [], 1007 | "source": [ 1008 | "import torch.utils.data\n", 1009 | "opt = torch.optim.SGD(network.parameters(),lr =1)\n", 1010 | "\n", 1011 | "history = []\n", 1012 | "batch_size = 100\n", 1013 | "train_loader = torch.utils.data.DataLoader(\\\n", 1014 | " torch.utils.data.TensorDataset(\\\n", 1015 | " *(torch.LongTensor(lf_train_users),torch.LongTensor(lf_train_items),torch.FloatTensor(lf_train_mask))),\\\n", 1016 | " batch_size=batch_size,shuffle=True)\n", 1017 | "\n", 1018 | "val_loader = torch.utils.data.DataLoader(\\\n", 1019 | " torch.utils.data.TensorDataset(\\\n", 1020 | " *(torch.LongTensor(lf_val_users),torch.LongTensor(lf_val_items),torch.FloatTensor(lf_val_mask))),\\\n", 1021 | " batch_size=batch_size,shuffle=True)\n", 1022 | "test_loader = torch.utils.data.DataLoader(\\\n", 1023 | " torch.utils.data.TensorDataset(\\\n", 1024 | " *(torch.LongTensor(lf_test_users),torch.LongTensor(lf_test_items),torch.FloatTensor(lf_test_mask))),\\\n", 1025 | " batch_size=batch_size,shuffle=True)\n" 1026 | ] 1027 | }, 1028 | { 1029 | "cell_type": "code", 1030 | "execution_count": null, 1031 | "metadata": {}, 1032 | "outputs": [ 1033 | { 1034 | "data": { 1035 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAZgAAAEWCAYAAABbgYH9AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xd4FWX6//H3J/QuJSC9SZEuRjpBXKSqiIKia1kb9gK7ul93bev6dV13l2oDFF0bFmwISLMQQIpBIfSOAiJFlC4t9++PM/y+x0jIAXJyTpL7dV3nyswzM8/cQ0LuzDwz98jMcM4557JbQqwDcM45lzd5gnHOORcVnmCcc85FhScY55xzUeEJxjnnXFR4gnHOORcVnmCcOwFJtSSZpILB/CeSro9k3VPY118kvXg68ToXTzzBuDxN0mRJjx+nvbekH042GZhZDzP7bzbEdb6kTRn6ftLMbj7dvp2LF55gXF73X+AaScrQfi3whpkdiUFMcet4Cfdkk7BC/HeL8wTj8rwPgfJAx2MNksoCFwGvBvO9JH0jabekjZIey6wzSV9IujmYLiDp35J2SFoH9Mqw7g2SlkvaI2mdpFuD9hLAJ0AVSXuDTxVJj0l6PWz7SyQtlfRzsN+zw5ZtkPQnSWmSdkl6W1LRE8R9YxDLT5KmSKoZtswk3SlpNbD6BG3tJH0V7O8rSe0y/Lv8r6TZwH6gTmaxuPzDE4zL08zsAPAOcF1Y8xXACjNbFMzvC5afQShJ3C7p0gi6v4VQojoHSAL6Zli+LVheGrgBGCKppZntA3oA35tZyeDzffiGkuoDY4H7gERgEvCxpMIZjqM7UBtoBvzheEFK6g38Bbgs6Gtm0He4S4HWQKPjtUkqB0wEhhNK2IOBiZLKh61/LTAAKAV8e7xYXP7iCcblB/8F+ob9hX9d0AaAmX1hZovNLN3M0gj98u0UQb9XAEPNbKOZ7QT+Eb7QzCaa2VoLmQFMJexMKgtXAhPNbJqZHQb+DRQD2oWtM9zMvg/2/THQIpO+bgP+YWbLg0uCTwItws9iguU7g4R8vLZewGoze83MjpjZWGAFcHHY+q+Y2dJg+eEIj9PlYZ5gXJ5nZrOAHcClkuoCrYA3jy2X1FrS55K2S9pF6BdyhQi6rgJsDJv/1V/tknpImitpp6SfgZ4R9nus7//fn5mlB/uqGrbOD2HT+4GSmfRVExgWXGr7GdgJKENfG4+zXXjbr+IJfBtBHy4f8wTj8otXCZ25XANMMbOtYcveBMYD1c2sDPACoV/AWdkCVA+br3FsQlIR4D1CZx6VzOwMQpe5jvWbVRnz7wklhmP9KdjX5gjiymgjcKuZnRH2KWZmX4atc7x4wtt+FU+gRoZ4vDS7+xVPMC6/eBXoQmjcJONtxqWAnWb2i6RWwNUR9vkOcI+kasGNA/8TtqwwUATYDhyR1APoGrZ8K1BeUpkT9N1L0u8kFQL+CBwEvsxk/RN5AXhQUmMASWUk9TvJPiYB9SVdLamgpCsJjddMOIV4XD7hCcblC2a2gdAv5xKEzlbC3QE8LmkP8AihX+6RGA1MARYBXwPvh+1vD3BP0NdPhJLW+LDlKwiN9awLLl1VyRDvSkJnWyMIXd67GLjYzA5FGFt4Xx8A/wTekrQbWELoJoOT6eNHQjcs/BH4EXgAuMjMdpxsPC7/kL9wzDnnXDT4GYxzzrmo8ATjnHMuKjzBOOeciwpPMM4556LilMqK5xUVKlSwWrVqxToM55zLVRYsWLDDzBKzWi9fJ5hatWqRmpoa6zCccy5XkRRRrTm/ROaccy4qPME455yLCk8wzjnnosITjHPOuajwBOOccy4qPME455yLCk8wzjnnosITzCn45fBRHhu/lG27f4l1KM45F7c8wZyCRRt/5s3539Fl8AzeSd2Iv/LAOed+yxPMKWhdpzyT7+1IwzNL88C4NK59aT4bd+6PdVjOORdXopZgJI2RtE3SkrC
2fpKWSkqXlHSCbTdIWixpoaTUsPa3g7aFwToLw5Y9KGmNpJWSukXruI6pk1iStwa04e+XNuGb736i65AUxsxaz9F0P5txzjmI7hnMK0D3DG1LgMuAlAi272xmLczs/yciM7syaGsBvEfwilpJjYD+QONgn89JKnD6h3BiCQni2jY1mTqoE63rlOPxCcvo98KXrN66J9q7ds65uBe1BGNmKcDODG3Lg3eNnxZJAq4g9E5zgN7AW2Z20MzWA2uAVqe7n0hVPaMYL//hPIZc2Zx1O/bRa/gsRny6msNH03MqBOecizvxOgZjwFRJCyQNOM7yjsBWM1sdzFcFNoYt3xS0/YakAZJSJaVu37492wKWRJ9zqjF9UCcubFyJ/0xbxcUjZrF4065s24dzzuUm8ZpgOphZS6AHcKek5AzLr+L/zl5OipmNMrMkM0tKTMzydQYnrULJIjx7dUtGXnsuO/cdovezs/jHJ8v55fDRbN+Xc87Fs7hMMGa2Ofi6DfiAsMtdkgoSGsd5O2yTzUD1sPlqQVvMdGt8JtMGdeKKpOqMnLGOHsNmMm/dj7EMyTnnclTcJRhJJSSVOjYNdCV0c8AxXYAVZrYprG080F9SEUm1gXrA/JyKOTNlihXiqcub8cbNrTmSns6Vo+by0IeL2fPL4ViH5pxzURfN25THAnOABpI2SbpJUh9Jm4C2wERJU4J1q0iaFGxaCZglaRGhJDHRzCaHdd2fDJfHzGwp8A6wDJgM3GlmcXNNqv1ZFZhyXzI3dajNG/O+o+uQFD5fsS3WYTnnXFQpPz+FnpSUZDn9yuSvv/uJP49LY/W2vVzaogqPXNyYciUK52gMzjl3OiQtCH+EJDNxd4ksr2tZoywT7unAPb+rx4S0LVw4eAYfL/rey8045/IcTzAxUKRgAQZdWJ+P7+5A1bLFuHvsN9zy6gK2evFM51we4gkmhs6uXJr3b2/HX3uezczV2+kyeAZvzf/Oz2acc3mCJ5gYK1gggVuS6zDlvmQaVS7N/7y/mKtHz+PbH/fFOjTnnDstnmDiRK0KJRh7Sxue7NOUxZt30W1oCi/OXOfFM51zuZYnmDiSkCCubl2DaYOSaVe3Ak9MXM5lz3/Jyh+8eKZzLvfxBBOHKpcpxkvXJzGsfws27tzPRSNmMnT6Kg4d8eKZzrncwxNMnJJE7xZVmTYwmZ5NKzN0+mouHjGLRRt/jnVozjkXEU8wca58ySIM638OL16XxK4Dh+nz3Gz+d+IyDhyKm0IFzjl3XJ5gcokujSoxdVAy/VvVYPTM9XQbmsKXa3fEOiznnMuUJ5hcpHTRQjzZpylv3tIaCa4ePY8H31/Mbi+e6ZyLQ55gcqF2dSsw+d5kBiTX4e2vvuPCwTOYvmxrrMNyzrlf8QSTSxUrXIC/9DybD+5oT9nihbn51VTuGfsNP+49GOvQnHMO8AST6zWvfgbj7+rAwC71+WTJFroMnsFHCzd7uRnnXMx5gskDChdM4N4u9Zh4T0dqli/BvW8t5Ob/prJl14FYh+acy8c8weQh9SuV4r3b2/FQr7OZvXYHFw5O4Y1535Lu5WacczHgCSaPKZAgbu5Yh6n3daJZtTL89YMlXDV6Lut3ePFM51zOiuYrk8dI2iZpSVhbP0lLJaVLyvRtaJI2SFosaaGk1AzL7pa0Iujn6aCtlqQDwfoLJb0QrePKLWqUL84bN7fmqcuasuz73XQfmsKolLUcOerlZpxzOaNgFPt+BXgGeDWsbQlwGTAygu07m9mvniSU1BnoDTQ3s4OSKoYtXmtmLU4v5LxFEv1b1eD8BhV56MMlPDlpBRPStvDPy5txduXSsQ7POZfHRe0MxsxSgJ0Z2pab2crT6PZ24CkzOxj0t+00+so3zixTlNHXncszV5/D5p8OcPGIWQyetoqDR7zcjHMueuJ1DMaAqZIWSBoQ1l4f6ChpnqQZks4LW1Zb0jdBe8fMOpY0QFKqpNTt27dHK/64I4mLmlVh+qBOXNy
8CsM/Xc1Fw2fx9Xc/xTo051weFa8JpoOZtQR6AHdKSg7aCwLlgDbA/cA7kgRsAWqY2TnAIOBNSce9BmRmo8wsycySEhMTo34g8aZsicIMubIFL//hPPYePMLlz3/J4x8vY/+hI7EOzTmXx8RlgjGzzcHXbcAHQKtg0SbgfQuZD6QDFczsoJn9GGyzAFhL6GzHZaJzw4pMHZjM71vXYMzsUPHM2Wu8eKZzLvvEXYKRVEJSqWPTQFdCNwcAfAh0DpbVBwoDOyQlSioQtNcB6gHrcjr23KZU0UI8cWlT3h7QhoIJCfz+xXn8eVwauw548Uzn3OmL5m3KY4E5QANJmyTdJKmPpE1AW2CipCnBulUkTQo2rQTMkrQImA9MNLPJwbIxQJ3g1ue3gOstVBMlGUiTtBAYB9xmZr+6wcBlrnWd8nxyb0du61SXcV9v4sLBM5i69IdYh+Wcy+WUn2tWJSUlWWpqatYr5iOLN+3igffSWL5lN72aVeaxixuTWKpIrMNyzsURSQvMLNNnGY+Ju0tkLraaVivD+Lva86eu9Zm2dCsXDpnB+19v8uKZzrmT5gnG/UahAgncdUE9Jt3bgToVSjDonUXc8MpXbP7Zi2c65yLnCcZl6qyKpXj3tnY8enEj5q3bSdfBM3htzgYvnumci4gnGHdCBRLEDe1rM3VgMi1rluXhj5bSf9Rc1m3fG+vQnHNxzhOMi0j1csV59cZW/KtvM1b8sJvuw2by/BdePNM5lzlPMC5ikuiXVJ3pgzrRuUEi/5y8gkufm83S73fFOjTnXBzyBONOWsXSRRl5bRLP/74lP+w6yCXPzOZfU1bwy2Evnumc+z+eYNwp69G0MtMHJXNpi6o8+/laeg2fyYJv/flW51yIJxh3Ws4oXpj/XNGc/97Yil8Op9P3hTk8Nn4p+w568Uzn8jtPMC5bdKqfyJSByVzXpib/nbOBrkNSSFmVf16H4Jz7LU8wLtuULFKQv/Vuwju3tqVIoQSuGzOfP727iF37vXimc/mRJxiX7c6rVY5J93TkjvPr8sE3m+kyZAaTl2yJdVjOuRzmCcZFRdFCBXige0M+urM9iSWLcNvrX3P76wvYtueXWIfmnMshnmBcVDWpWoaP7mrP/d0a8OmKbVw4OIV3Uzd68Uzn8gFPMC7qChVI4M7OZzHpno7Uq1iS+8elcd2Y+WzcuT/WoTnnouiECUZSAUkDcyoYl7edVbEk79zalsd7N+brb3+i29AUXpm93otnOpdHnTDBmNlR4KocisXlAwkJ4rq2tZgyMJmkWuV47ONlXDFyDmu2efFM5/KaSC6RzZb0jKSOkloe+2S1kaQxkrYFrzc+1tZP0lJJ6ZIyfRuapA2SFktaKCk1w7K7Ja0I+nk6rP1BSWskrZTULYLjcjFUrWxx/nvDefynX3NWb9tLz2EzefbzNRz24pnO5RkFI1inRfD18bA2Ay7IYrtXgGeAV8PalgCXASMj2G9nM9sR3iCpM9AbaG5mByVVDNobAf2BxkAVYLqk+sEZmItTkrj83Gok10/k0fFL+NeUlUxM28LTfZvRpGqZWIfnnDtNWSYYM+t8Kh2bWYqkWhnalkPoF8spuh14yswOBv1tC9p7A28F7eslrQFaAXNOdUcu5ySWKsJzvz+XyUu28PBHS+n97GwGJNfh3t/Vo2ihArEOzzl3irK8RCapjKTBklKDz38kRfvPSwOmSlogaUBYe32go6R5kmZIOi9orwpsDFtvU9D2G5IGHDuW7du9lEk86d6kMtMHduLyllV5/ou19Bw2k682ePFM53KrSMZgxgB7gCuCz27g5WgGBXQws5ZAD+BOSclBe0GgHNAGuB94Ryd5OmRmo8wsycySEhMTszVod/rKFC/E032b8/pNrTl0NJ1+L8zhkY+WsNeLZzqX60SSYOqa2aNmti74/A2oE82gzGxz8HUb8AGhy10QOjN530LmA+lABWAzUD2si2pBm8ulOtSrwJT7krmhfS1em/s
tXQfP4POV27Le0DkXNyJJMAckdTg2I6k9cCBaAUkqIanUsWmgK6GbAwA+BDoHy+oDhYEdwHigv6QikmoD9YD50YrR5YwSRQry6MWNGXdbO4oXKcgNL3/FoLcX8tO+Q7EOzTkXgUgSzG3As8GtwxsI3Rl2a1YbSRpLaJC9gaRNkm6S1EfSJqAtMFHSlGDdKpImBZtWAmZJWkQoSUw0s8nBsjFAneDW57eA64OzmaXAO8AyYDJwp99BlnecW7MsE+/pwN0XnMX4Rd9z4ZAZTEzb4uVmnItzOtF/UkkJQF8ze0dSaQAz251TwUVbUlKSpaamZr2iixvLvt/Nn99LY/HmXXRtVIknLm1CxdJFYx2Wc/mKpAVmlumzjMdk9SR/OvBAML07LyUXlzs1qlKaD+5ox4M9GjJj1XZ+N3gG73zlxTOdi0eRXCKbLulPkqpLKnfsE/XInMtEwQIJ3NqpLp/c25GzK5fmgffSuPYlL57pXLw54SUyAEnrj9NsZhbVO8lygl8iy/3S040353/HU5+s4Gi6cX+3BlzfrhYFEk75YV7nXBay5RJZMAZzjZnVzvDJ9cnF5Q0JCeKaNjWZOjCZ1nXK8fiEZfR94UtWb90T69Ccy/ciGYN5Jodice6UVTmjGC//4TyGXtmCDTv20Wv4LIZ/uppDR7x4pnOxEskYzKeSLj/ZJ+ady2mSuPScqkwb1IluTc5k8LRVXPLMLNI2/Rzr0JzLlyJJMLcC7wIHJe2WtEeS303m4laFkkUYcdU5jL4uiZ/2H+LSZ2fzj0nL+eWwPxrlXE7KMsGYWSkzSzCzwmZWOpgvnRPBOXc6LmxUiakDO3HledUZmbKO7kNTmLvux1iH5Vy+kWmCkXRN2HT7DMvuimZQzmWXMsUK8Y/LmvHmza1JN+g/ai5//WAxe345HOvQnMvzTnQGMyhsekSGZTdGIRbnoqbdWRWYfF9Hbu5Qm7Hzv6PrkBQ+W7E11mE5l6edKMEok+njzTsX94oXLshDFzXivdvbUapoQW58JZX73vqGnV4807moOFGCsUymjzfvXK5xTo2yTLi7I/f+rh4TF2+hy+AZjF/0vZebcS6bnSjBNJSUJmlx2PSx+QY5FJ9zUVG4YAIDL6zPx3d3oHrZYtwz9htueXUBP+z6JdahOZdnZFoqRlLNE21oZt9GJaIc5KViHMDRdGPMrPX8Z9pKCiUk8JdeZ9P/vOr4o1/OHV+kpWKyrEWWl3mCceE27NjH/7yfxtx1O2lbpzxPXd6UmuVLxDos5+JOttQicy4/qVWhBG/e3IYn+zRlyeZddBuawosz13E0Pf/+Eebc6YhagpE0RtK24O2Tx9r6SVoqKV1SptkveHvmYkkLJaWGtT8maXPQvlBSz6C9lqQDYe0vROu4XN6WkCCubl2DqYOSaV+3Ak9MXM5lz3/Jyh+8eKZzJyvLBCPp4qCq8sl6BeieoW0JcBmQEsH2nc2sxXFOw4YE7S3MbFJY+9qw9ttOIV7n/r/KZYrx4vVJDL/qHDbu3M9FI2YydPoqL57p3EmIJHFcCayW9LSkhpF2bGYpwM4MbcvNbOVJxuhcTEjikuZVmD6oEz2bVmbo9NVcPGIWCzd68UznIhFJLbJrgHOAtcArkuZIGiCpVBTjMmCqpAWSBmRYdldwu/QYSWXD2mtL+kbSDEkdoxiby2fKlSjMsP7n8NL1Sew6cJjLnpvNExOWceCQF8907kQiuvRlZruBccBbQGWgD/C1pLujFFcHM2sJ9ADulJQctD8P1AVaAFuA/wTtW4AaZnYOoRI3b0o6bkHOIDmmSkrdvn17lMJ3edHvzq7E1EHJ9G9Vgxdnrafb0BS+XLsj1mE5F7ciGYO5RNIHwBdAIaCVmfUAmgN/jEZQZrY5+LoN+ABoFcxvNbOjwYvQRoe1HzSzH4PpBYTOtupn0vcoM0sys6TExMRohO/ysNJFC/Fkn6aMvaUNCYKrR8/jwffT2O3FM537jUjOYC4nNLDe1Mz
+FfzSx8z2Azdld0CSShy7/CapBNCV0M0BSKoctmqfsPZESQWC6TpAPWBddsfm3DFt65bnk3uTuTW5Dm9/tZELB89g+jIvnulcuEjGYK4HVgVnMhdLOjNs2aeZbSdpLDAHaCBpk6SbJPWRtAloC0yUNCVYt4qkY3eEVQJmSVoEzAcmmtnkYNnTwe3LaUBnYGDQngykSVpI6FLebWb2qxsMnMtuxQoX4MGeZ/Phne0pW7wwN7+ayt1jv+HHvQdjHZpzcSHLJ/kl3QQ8CnxGqIpyJ+BxMxsT/fCiy5/kd9nl0JF0XpixlhGfraZkkYI8dkljLmlexcvNuDwp20rFSFoJtDs2xiGpPPClmeX6gpeeYFx2W7V1Dw+MS2Phxp+5oGFFnri0CVXOKBbrsJzLVtlZKuZHIPwx5j1Bm3Mug/qVSvHe7e14+KJGzFn7I12HpPD63G9J93IzLh+KJMGsAeYFZVoeBeYSGpMZJGlQFts6l+8USBA3dajNlPuSaV69DA99uISrRs9l/Y59sQ7NuRwVSYJZC3zI/71k7CNgPVAq+DjnjqNG+eK8flNrnr68Gcu27Kb70BRGzljLkaNebsblDxGX65dUEsDM9kY1ohzkYzAup2zd/QsPfbiEacu20qxaGf55eTPOrnzcZ4Gdi3vZNgYjqYmkb4ClwNKgfEvj7AjSufyiUumijLr2XJ69uiXf/3yAi0fMYvDUlRw84uVmXN4VySWyUcAgM6tpZjUJPb0/OrphOZf3SKJXs8pMG9iJS5pXYfhna+g1fBYLvv0p1qE5FxWRJJgSZvb5sRkz+wLw1/w5d4rKlijM4Ctb8PIN57H/4BH6vvAlf/t4KfsPHYl1aM5lq0gSzDpJDwcv9aol6SG8DItzp61zg4pMGZjMNa1r8vLsDXQdksKs1V480+UdkSSYG4FE4H3gPaBC0OacO02lihbi75c24Z1b21KoQALXvDSPB8YtYtcBL57pcr+CJ1oYFJD8q5ndk0PxOJcvtapdjk/u7ciwT1czKmUdX6zczt8vbUK3xmdmvbFzceqEZzBmdhTokEOxOJevFS1UgD93b8iHd7SnfMki3PraAu5842u27/HimS53OuEZTOAbSeOBd4H//yiymb0ftaicy8eaVivD+LvaMyplHcOmr2bWmh08clEjLmtZ1YtnulwlkjGYooRqj10AXBx8LopmUM7ld4UKJHBn57OYdG8HzqpYkj++u4g/vPwVm38+EOvQnItYJNWU25vZ7KzaciN/kt/lBunpxqtzNvD0lJUI+HOPhlzTuiYJCX4242IjO6spj4iwzTkXBQkJ4g/tQ8UzW9YsyyMfLeXKUXNYuz3PVG1yeVSmYzCS2gLtgMQMVZNLAwWiHZhz7teqlyvOqze2YtyCTfx9wjJ6DJvJfV3qMaBjHQoWiORvRedy1ol+KgsDJQkloVJhn91A36w6ljRG0jZJS8La+klaKildUqanV5I2BK9GXigpNaz9MUmbg/aFknqGLXtQ0hpJKyV1yyo+53IjSfRLqs70P3biggYVeXrySi59bjZLv98V69Cc+41IxmBqmtm3J92xlAzsBV41syZB29lAOjAS+JOZHXcARNIGIMnMdmRofwzYa2b/ztDeCBgLtAKqANOB+sFt1pnyMRiX232yeAsPf7SUn/Yf4rZOdbj7gnoULeQXGFx0RToGE8ltykUkjQJqha9vZhecaCMzS5FUK0Pb8iC4CHZ7UnoDb5nZQWC9pDWEks2c7N6Rc/GkR9PKtK1bnicmLufZz9fyyZIfePryZiTVKhfr0JyLaJD/XeAb4CHg/rBPNBkwNXg1wIAMy+6SlBZcgisbtFUFNoatsylo+w1JAySlSkrdvn179kfuXA47o3hh/t2vOa/e2IqDh9PpN3IOj41fyr6DXjzTxVYkCeaImT1vZvPNbMGxT5Tj6mBmLYEewJ3B5TaA54G6QAtgC/Cfk+3YzEaZWZKZJSUmJmZbwM7FWnL9RKYOTOb6trX
475xQ8cyUVf5HlIudSBLMx5LukFRZUrljn2gGZWabg6/bgA8IXe7CzLaa2VEzSyf0TppWwSabgephXVQL2pzLV0oUKchjlzTm3VvbUqRQAteNmc+f3l3Ez/sPxTo0lw9FkmCuJ3RJ7EtgQfCJ2si4pBKSSh2bBroCS4L5ymGr9jnWDowH+ksqIqk2UA+YH60YnYt3SbXKMemejtzZuS4ffLOZLoNT+GTxlliH5fKZLAf5zaz2qXQsaSxwPlBB0ibgUWAnoYc0E4GJkhaaWTdJVYAXzawnUAn4ILgRoCDwpplNDrp9WlILQmM0G4BbgxiXSnoHWAYcAe7M6g4y5/K6ooUKcH+3hvRsWpkHxqVx+xtf073xmTzeuzEVSxeNdXguH4jkNuXiwCCghpkNkFQPaGBmE3IiwGjy25RdfnHkaDqjZ65nyPRVFC2YwMMXNaLvudW8eKY7JdlZKuZl4BChp/ohNLbxxGnE5pzLYQULJHD7+XX55N6ONDizFPePS+O6MfPZuHN/rENzeVgkCaaumT0NHAYws/2A/9njXC5UN7Ekbw9oy997N+brb3+i29AUXpm9nvT0E1/JcO5URJJgDkkqRmjcA0l1AX8DknO5VEKCuLZtLaYMTOa8WuV47ONl9Bs5hzXb9sQ6NJfHRJJgHgUmA9UlvQF8CjwQ1aicc1FXrWxxXrnhPAZf0Zy12/fSc9gsnv18DYePpsc6NJdHZDnIDyCpPNCG0KWxuRlrhOVWPsjvXMj2PQd57OOlTEzbwtmVS/Ovvs1oUrVMrMNycSo7B/kxsx/NbCLHKUDpnMv9EksV4dmrWzLy2nPZsfcgvZ+dzVOfrOCXw363vzt1J/sSiUuiEoVzLi50a3wm0wd2om/LarwwYy09h81k/vqdsQ7L5VInm2D87jHn8rgyxQvxz77NeP2m1hw6ms4VI+fw8IdL2OvFM91JOtkEc25UonDOxZ0O9SowdWAyN7avzevzvqXr4Bl8vnJbrMNyuUiWCUbS05JKSyoETJO0XdI1ORCbcy7GihcuyCMXN2Lcbe0oUaQgN7z8FYPeXshP+7x4pstaJGcwXc1sN3ARofpfZxH998E45+LIuTXLMuGeDtxzwVmMX/S18duNAAAU3ElEQVQ9XQbPYELa90RyF6rLvyJJMMcKYvYC3jUzf/m3c/lQkYIFGNS1AR/f3YEqZxTjrje/4dbXFrB19y+xDs3FqUgSzARJKwiNv3wqKRHwnyjn8qmzK5fmgzva8WCPhsxYtZ0ug2fw9lff+dmM+41IH7QsB+wys6NBdeXSZvZD1KOLMn/Q0rnTs37HPv78Xhrz1++k/Vnl+UefZtQoXzzWYbkoy7YHLSX1Aw4HyeUh4HWgSjbE6JzL5WpXKMFbt7ThiUubsGjjLroNTeGlWes56sUzHZFdInvYzPZI6gB0AV4Cno9uWM653CIhQVzTpiZTBybTtm55/j5hGZc//yWrtnrxzPwukgRzrFZEL2BUUDKmcPRCcs7lRlXOKMZL1ycxrH8Lvv1xH72Gz2T4p6s5dMSLZ+ZXkSSYzZJGAlcCkyQViWQ7SWMkbZO0JKytn6SlktIlZXr9TtIGSYslLZT0m0ESSX+UZJIqBPPnS9oVrL9Q0iMRHJdzLptJoneLqkwf1InuTSozeNoqLnlmFos2/hzr0FwMRJJgrgCmAN3M7GegHJE9B/MK0D1D2xLgMiAlgu07m1mLjANJkqoDXYHvMqw/M1i/hZk9HkH/zrkoKV+yCCOuOofR1yXx0/5D9HluNv+YtJwDh7x4Zn6SZYIJ3mC5Fugm6S6goplNjWC7FGBnhrblZrbyVIMNDCH0PhofRXQuzl3YqBLTBnXiyvOqMzJlHT2GpTB33Y+xDsvlkEgudd0LvAFUDD6vS7o7ynEZMFXSAkkDwmLpDWw2s0XH2aatpEWSPpHUOLOOJQ2QlCopdfv27VEI3TkXrnTRQvzjsma8eXNr0g36j5rLXz5YzO5fDsc6NBdlWT4
HIykNaGtm+4L5EsAcM2uWZedSLWCCmTXJ0P4F8CczO+5DKJKqmtlmSRWBacDdQCrwOaHSNbskbSB4P42k0kC6me2V1BMYZmb1sorPn4NxLmcdOHSUwdNW8tKs9VQsVZQnL2vCBQ0rxTosd5Ky84Vj4v/uJCOYjmrZfjPbHHzdBnwAtALqArWBRUFyqQZ8LelMM9ttZnuDbSYBhY7dAOCcix/FChfgr70a8f4d7SlTrBA3vpLKvW99w497D8Y6NBcFkSSYl4F5kh6T9Bgwl9CzMFEhqYSkUsemCQ3oLzGzxWZW0cxqmVktYBPQ0sx+kHSmJAXbtCJ0XH6h17k41aL6GXx8dwfu61KPSYu3cOGQFMYv8uKZeU0kg/yDgRsIDdjvBG4ws6FZbSdpLDAHaCBpk6SbJPWRtAloC0yUNCVYt4qkScGmlYBZkhYB84GJZjY5i931BZYE2wwH+pv/pDoX1woXTOC+LvWZcHdHqpcrzj1jv+GWV1P5YZeXOswrTjgGI6kAsNTMGuZcSDnHx2Cciw9H042XZ6/n31NXUighgb/0Opv+51UnuDDh4ky2jMGY2VFgpaQa2RaZc85lUCBB3NyxDlPuS6ZJ1TI8+P5irh49jw079sU6NHcaIhmDKQsslfSppPHHPtEOzDmX/9QsX4I3b2nNU5c1ZcnmXXQflsLolHVePDOXKpj1Kjwc9Siccy4gif6tanB+g4o89OFi/nfSciakfc/TfZvT4MxSsQ7PnYRMz2AknSWpvZnNCP8Quk15U86F6JzLj84sU5TR1yUx4qpz2PTTAS4aMZMh01Z58cxc5ESXyIYCu4/TvitY5pxzUSWJi5tXYdqgTvRqWplhn67mohEzWejFM3OFEyWYSma2OGNj0FYrahE551wG5UoUZmj/cxjzhyT2/HKEy56bzRMTlrH/0JFYh+ZO4EQJ5owTLCuW3YE451xWLmhYiakDk7mqVQ1enLWe7kNn8uWaHbEOy2XiRAkmVdItGRsl3QwsiF5IzjmXuVJFC/G/fZry1oA2JAiufnEe//NeGrsOePHMeJPpg5aSKhGqA3aI/0soSYTeZtnHzH7IkQijyB+0dC53++XwUYZMX8XolHUklirCE5c25cJGXjwz2iJ90DKSasqdgWPVkJea2WfZEF9c8ATjXN6QtulnHhiXxoof9nBRs8o8dkljKpQsEuuw8qxsSzB5mScY5/KOQ0fSGTljLSM+W0OJIgV49OLG9G5RxcvNREF2lut3zrm4V7hgAnf/rh4T7+lArQoluO/thdz4yld8//OBWIeWb3mCcc7lKfUqlWLcbe145KJGzF23k65DUnht7reke7mZHOcJxjmX5xRIEDd2qM3Ugcm0qH4GD3+4hP6j57Lei2fmKE8wzrk8q3q54rx2UyuevrwZy7fspvvQFF6YsZYjR73cTE7wBOOcy9MkccV51Zk+qBOd6ify1Ccr6PPclyz7/niVsFx28gTjnMsXKpUuyshrz+XZq1uyZdcBLnlmFv+ZupKDR47GOrQ8K6oJRtIYSdskLQlr6ydpqaR0SZne5iZpg6TFkhZK+s29xJL+KMkkVQjmJWm4pDWS0iS1jM5ROedyK0n0alaZaQM7cUmLKoz4bA29hs9iwbc/xTq0PCnaZzCvAN0ztC0BLgNSIti+s5m1yHi/taTqQFfgu7DmHkC94DMAeP4UY3bO5XFlSxRm8BUteOWG8zhw6Ch9X/iSv328lH0HvXhmdopqgjGzFGBnhrblZrbyNLseAjwAhN932Bt41ULmAmdIqnya+3HO5WHnN6jIlIHJXNumJi/P3kC3oSnMXL091mHlGfE8BmPAVEkLJA041iipN7DZzBZlWL8qsDFsflPQ9iuSBkhKlZS6fbv/IDmX35UsUpDHezfhnVvbUrhAAte+NJ8Hxi1i134vnnm64jnBdDCzloQufd0pKVlSceAvwCOn2qmZjTKzJDNLSkxMzK5YnXO5XKva5Zh0b0duP78u7329mS5
DZjB5Sa6v6RtTcZtgzGxz8HUboarOrYC6QG1gkaQNQDXga0lnApuB6mFdVAvanHMuIkULFeDP3Rvy0Z3tSSxZhNteX8Cdb3zN9j0HYx1arhSXCUZSCUmljk0TGtBfYmaLzayimdUys1qELoO1DF4dMB64LribrA2wy8y2xOoYnHO5V5OqZfjorvbc360B05ZvpcvgGby3YBP5uTjwqYj2bcpjgTlAA0mbJN0kqY+kTUBbYKKkKcG6VSRNCjatBMyStAiYD0w0s8lZ7G4SsA5YA4wG7ojCITnn8olCBRK4s/NZTLqnI2dVLMkf313E9S9/xaaf9sc6tFzDy/V7uX7nXBbS043X5n7LPyevQMCfezTkmtY1SUjIn68C8HL9zjmXTRISxPXtajHlvmRa1izLIx8t5cpRc1i7fW+sQ4trnmCccy5C1csV59UbW/Hvfs1ZtXUvPYbN5Lkv1nDYi2celycY55w7CZLoe241pg1KpsvZFXl68koufXY2SzbvinVocccTjHPOnYKKpYry3O/P5YVrWrJ190F6Pzubpyev4JfDXjzzGE8wzjl3Gro3qcyngzpx2TlVee6LtfQcPpPUDTuz3jAf8ATjnHOnqUzxQvyrX3NevbEVBw+n02/kHB79aAl783nxTE8wzjmXTZLrJzJ1YDLXt63Fq3O/pduQFGasyr81Dz3BOOdcNipRpCCPXdKYd29tS9FCCVw/Zj5/fGcRP+8/FOvQcpwnGOeci4KkWuWYeE9H7up8Fh8t3EyXwSl8sjh/Va/yBOOcc1FStFAB/tStAR/d1Z4zyxTh9je+5rbXFrBt9y+xDi1HeIJxzrkoa1ylDB/e0Z4/d2/IZyu30WXwDN5J3Zjni2d6gnHOuRxQsEACt59fl8n3dqThmaV5YFwa142Zz8adebd4picY55zLQXUSS/LWgDb8vXdjvv72J7oNTeHl2es5mp73zmY8wTjnXA5LSBDXtq3F1EGdaFW7HH/7eBlXjJzDmm17Yh1atvIE45xzMVL1jGK8/IfzGHJlc9Zu30vPYbN45rPVeaZ4picY55yLIUn0Oaca0wd14sLGlfj31FVcPGIWizfl/uKZUUswksZI2iZpSVhbP0lLJaVLyvRlNZI2SFosaaGk1LD2v0tKC9qnSqoStJ8vaVfQvlDSI9E6Lueci4YKJYvw7NUtGXntuezcd4hLn5vNU5/k7uKZ0TyDeQXonqFtCXAZkBLB9p3NrEWGt6b9y8yamVkLYAIQnkhmBuu3MLPHTydw55yLlW6Nz2TaoE70bVmNF2aspcewmcxb92OswzolUUswZpYC7MzQttzMVp5Gn7vDZksAee+2C+dcvlemWCH+2bcZb9zcmiPp6Vw5ai4Pf7iEPb8cjnVoJyVex2AMmCppgaQB4Qsk/a+kjcDv+fUZTFtJiyR9IqlxTgbrnHPR0P6sCky5L5mbOtTm9Xmh4pmfr9gW67AiFq8JpoOZtQR6AHdKSj62wMz+ambVgTeAu4Lmr4GaZtYcGAF8mFnHkgZISpWUun17/q1y6pzLHYoXLsjDFzXivdvbUaJIQW545SsGvr2Qnfviv3hmXCYYM9scfN0GfAC0Os5qbwCXB+vtNrO9wfQkoJCkCpn0PcrMkswsKTExMSrxO+dcdmtZoywT7unAPb+rx8eLvufCwTOYkPZ9XJebibsEI6mEpFLHpoGuhG4OQFK9sFV7AyuC9jMlKZhuRei4cueomHPOZaJIwQIMurA+H9/dgapli3HXm98w4LUFbI3T4pnRvE15LDAHaCBpk6SbJPWRtAloC0yUNCVYt4qkScGmlYBZkhYB84GJZjY5WPaUpCWS0gglnnuD9r7AkmCb4UB/i+e07pxzp+HsyqV5//Z2/KVnQ1JWbafL4Bm8/dV3cXc2o3gLKCclJSVZampq1is651yc2rBjH39+L41563fSrm55nrqsGTXKF4/qPiUtyPAIyXHF3SUy55xzkatVoQRjb2nDk32akrZpF92
GpvDizHVxUTzTE4xzzuVyCQni6tY1mDYombZ1y/PExOVc/vyXrNoa2+KZnmCccy6PqFymGC9dn8Sw/i34bud+eg2fybDpqzl0JDbFMz3BOOdcHiKJ3i2qMm1gMj2aVGbI9FVc8swsFm38Ocdj8QTjnHN5UPmSRRh+1Tm8eF0SP+8/TJ/nZvPkpOUcOJRzxTM9wTjnXB7WpVElpg5Kpn+rGoxKWUePYSnMWZszjwl6gnHOuTyudNFCPNmnKW/e0hoDrho9lycmLIv6fj3BOOdcPtGubgUm35vMgOQ61IzyszIABaO+B+ecc3GjWOEC/KXn2TmyLz+Dcc45FxWeYJxzzkWFJxjnnHNR4QnGOedcVHiCcc45FxWeYJxzzkWFJxjnnHNR4QnGOedcVOTrN1pK2g58expdVAB2ZFM4uUF+O17wY84v/JhPTk0zS8xqpXydYE6XpNRIXhuaV+S34wU/5vzCjzk6/BKZc865qPAE45xzLio8wZyeUbEOIIflt+MFP+b8wo85CnwMxjnnXFT4GYxzzrmo8ATjnHMuKjzBZEFSd0krJa2R9D/HWV5E0tvB8nmSauV8lNkrgmMeJGmZpDRJn0qqGYs4s1NWxxy23uWSTFKuv6U1kmOWdEXwvV4q6c2cjjG7RfCzXUPS55K+CX6+e8YizuwiaYykbZKWZLJckoYH/x5pklpmawBm5p9MPkABYC1QBygMLAIaZVjnDuCFYLo/8Has486BY+4MFA+mb88PxxysVwpIAeYCSbGOOwe+z/WAb4CywXzFWMedA8c8Crg9mG4EbIh13Kd5zMlAS2BJJst7Ap8AAtoA87Jz/34Gc2KtgDVmts7MDgFvAb0zrNMb+G8wPQ74nSTlYIzZLctjNrPPzWx/MDsXqJbDMWa3SL7PAH8H/gn8kpPBRUkkx3wL8KyZ/QRgZttyOMbsFskxG1A6mC4DfJ+D8WU7M0sBdp5gld7AqxYyFzhDUuXs2r8nmBOrCmwMm98UtB13HTM7AuwCyudIdNERyTGHu4nQX0C5WZbHHFw6qG5mE3MysCiK5PtcH6gvabakuZK651h00RHJMT8GXCNpEzAJuDtnQouZk/3/flIKZldHLv+RdA2QBHSKdSzRJCkBGAz8Icah5LSChC6TnU/oLDVFUlMz+zmmUUXXVcArZvYfSW2B1yQ1MbP0WAeWG/kZzIltBqqHzVcL2o67jqSChE6rf8yR6KIjkmNGUhfgr8AlZnYwh2KLlqyOuRTQBPhC0gZC16rH5/KB/ki+z5uA8WZ22MzWA6sIJZzcKpJjvgl4B8DM5gBFCRWFzKsi+v9+qjzBnNhXQD1JtSUVJjSIPz7DOuOB64PpvsBnFoye5VJZHrOkc4CRhJJLbr8uD1kcs5ntMrMKZlbLzGoRGne6xMxSYxNutojkZ/tDQmcvSKpA6JLZupwMMptFcszfAb8DkHQ2oQSzPUejzFnjgeuCu8naALvMbEt2de6XyE7AzI5IuguYQugOlDFmtlTS40CqmY0HXiJ0Gr2G0GBa/9hFfPoiPOZ/ASWBd4P7Gb4zs0tiFvRpivCY85QIj3kK0FXSMuAocL+Z5dqz8wiP+Y/AaEkDCQ34/yE3/8EoaSyhPxIqBONKjwKFAMzsBULjTD2BNcB+4IZs3X8u/rdzzjkXx/wSmXPOuajwBOOccy4qPME455yLCk8wzjnnosITjHPOuajwBONcDEj6h6TOki6V9GDQ9njwACuS7pNUPLZROnd6/DZl52JA0mdAL+BJYJyZzc6wfAOhis07TqLPAmZ2NFsDde40+IOWzuUgSf8CugG1gTlAXUIVuMcRKiM/AagSfD6XtMPMOkvqCvwNKEKo5PwNZrY3SERvAxcCTxOqEOxcXPBLZM7lIDO7n1C9q1eA84A0M2tmZo+HrTOcUJn4zkFyqQA8BHQxs5ZAKjAorNsfzaylmXlycXHFz2Ccy3ktCb3sqiGwPIL12xB6+dXsoDRPYUJnP8e8nd0BOpc
dPME4l0MktSB05lIN2AEUDzVrIdD2RJsC08zsqkyW78vOOJ3LLn6JzLkcYmYLzawFobL3jYDPgG5m1sLMDmRYfQ+h1wRAqHpze0lnAUgqIal+TsXt3KnyBONcDpKUCPwUvMCqoZkty2TVUcBkSZ+b2XZCLzsbKymN0OWxhjkSsHOnwW9Tds45FxV+BuOccy4qPME455yLCk8wzjnnosITjHPOuajwBOOccy4qPME455yLCk8wzjnnouL/AV4n8wQc0zK1AAAAAElFTkSuQmCC\n", 1036 | "text/plain": [ 1037 | "
" 1038 | ] 1039 | }, 1040 | "metadata": {}, 1041 | "output_type": "display_data" 1042 | } 1043 | ], 1044 | "source": [ 1045 | "# 1 lr 10 epch\n", 1046 | "train_network(network.to(device),train_loader,val_loader)" 1047 | ] 1048 | }, 1049 | { 1050 | "cell_type": "code", 1051 | "execution_count": 35, 1052 | "metadata": {}, 1053 | "outputs": [ 1054 | { 1055 | "name": "stderr", 1056 | "output_type": "stream", 1057 | "text": [ 1058 | "/usr/local/lib/python3.5/dist-packages/torch/serialization.py:193: UserWarning: Couldn't retrieve source code for container of type MHLinearGRU. It won't be checked for correctness upon loading.\n", 1059 | " \"type \" + obj.__name__ + \". It won't be checked \"\n", 1060 | "/usr/local/lib/python3.5/dist-packages/torch/serialization.py:193: UserWarning: Couldn't retrieve source code for container of type MultiHeadedAttention. It won't be checked for correctness upon loading.\n", 1061 | " \"type \" + obj.__name__ + \". It won't be checked \"\n" 1062 | ] 1063 | } 1064 | ], 1065 | "source": [ 1066 | "torch.save(network,\"network_mhatt_linear_lf_bs100_lr1.p\")" 1067 | ] 1068 | }, 1069 | { 1070 | "cell_type": "markdown", 1071 | "metadata": {}, 1072 | "source": [ 1073 | "### Scores" 1074 | ] 1075 | }, 1076 | { 1077 | "cell_type": "markdown", 1078 | "metadata": {}, 1079 | "source": [ 1080 | "#### LastFM" 1081 | ] 1082 | }, 1083 | { 1084 | "cell_type": "code", 1085 | "execution_count": 13, 1086 | "metadata": {}, 1087 | "outputs": [ 1088 | { 1089 | "name": "stdout", 1090 | "output_type": "stream", 1091 | "text": [ 1092 | "MRR@20 score for Linear User-based GRU on LastFM : 0.21047361 ± 0.006856432342802942\n", 1093 | "Recall@20 score for Linear User-based GRU on LastFM: 0.3013901 ± 0.006992072808611331\n" 1094 | ] 1095 | } 1096 | ], 1097 | "source": [ 1098 | "print_scores(\"network_linear_lf_bs128_lr1.p\",'Linear User-based GRU on LastFM')" 1099 | ] 1100 | }, 1101 | { 1102 | "cell_type": "code", 1103 | "execution_count": 12, 1104 | "metadata": {}, 1105 
| "outputs": [ 1106 | { 1107 | "name": "stdout", 1108 | "output_type": "stream", 1109 | "text": [ 1110 | "MRR@20 score for Rectified Linear User-based GRU on LastFM : 0.22049585 ± 0.004507800212955085\n", 1111 | "Recall@20 score for Rectified Linear User-based GRU on LastFM: 0.3142386 ± 0.005953415133536319\n" 1112 | ] 1113 | } 1114 | ], 1115 | "source": [ 1116 | "print_scores(\"network_rect_linear_lf_bs100_lr1.p\",'Rectified Linear User-based GRU on LastFM')" 1117 | ] 1118 | }, 1119 | { 1120 | "cell_type": "code", 1121 | "execution_count": 18, 1122 | "metadata": {}, 1123 | "outputs": [ 1124 | { 1125 | "name": "stdout", 1126 | "output_type": "stream", 1127 | "text": [ 1128 | "MRR@20 score for Attentional User-based GRU on LastFM : 0.21569276 ± 0.004434641506947223\n", 1129 | "Recall@20 score for Attentional User-based GRU on LastFM: 0.3121555 ± 0.007319821406235953\n" 1130 | ] 1131 | } 1132 | ], 1133 | "source": [ 1134 | "print_scores(\"network_att_linear_lf_bs100_lr1.p\",'Attentional User-based GRU on LastFM')" 1135 | ] 1136 | }, 1137 | { 1138 | "cell_type": "code", 1139 | "execution_count": 36, 1140 | "metadata": {}, 1141 | "outputs": [ 1142 | { 1143 | "name": "stdout", 1144 | "output_type": "stream", 1145 | "text": [ 1146 | "MRR@20 score for Multi-Head Attentional User-based GRU on LastFM : 0.21176393 ± 0.005497847311917814\n", 1147 | "Recall@20 score for Multi-Head Attentional User-based GRU on LastFM: 0.30064663 ± 0.008272520008487811\n" 1148 | ] 1149 | } 1150 | ], 1151 | "source": [ 1152 | "print_scores(\"network_mhatt_linear_lf_bs100_lr1.p\",'Multi-Head Attentional User-based GRU on LastFM')" 1153 | ] 1154 | }, 1155 | { 1156 | "cell_type": "markdown", 1157 | "metadata": {}, 1158 | "source": [ 1159 | "#### MovieLens" 1160 | ] 1161 | }, 1162 | { 1163 | "cell_type": "code", 1164 | "execution_count": 46, 1165 | "metadata": {}, 1166 | "outputs": [ 1167 | { 1168 | "name": "stdout", 1169 | "output_type": "stream", 1170 | "text": [ 1171 | "MRR@20 score for Linear 
User-based GRU on MovieLens : 0.057497583 ± 0.0007039992875517064\n", 1172 | "Recall@20 score for Linear User-based GRU on MovieLens: 0.18227212 ± 0.003842047864645815\n" 1173 | ] 1174 | } 1175 | ], 1176 | "source": [ 1177 | "print_scores(\"network_linear_ml.p\",'Linear User-based GRU on MovieLens')" 1178 | ] 1179 | }, 1180 | { 1181 | "cell_type": "code", 1182 | "execution_count": 52, 1183 | "metadata": {}, 1184 | "outputs": [ 1185 | { 1186 | "name": "stdout", 1187 | "output_type": "stream", 1188 | "text": [ 1189 | "MRR@20 score for Rectified Linear User-based GRU on MovieLens : 0.060260046 ± 0.0005038204951150433\n", 1190 | "Recall@20 score for Rectified Linear User-based GRU on MovieLens: 0.19106193 ± 0.004563955364643368\n" 1191 | ] 1192 | } 1193 | ], 1194 | "source": [ 1195 | "print_scores(\"network_rect_linear_ml.p\",'Rectified Linear User-based GRU on MovieLens')" 1196 | ] 1197 | }, 1198 | { 1199 | "cell_type": "code", 1200 | "execution_count": 58, 1201 | "metadata": {}, 1202 | "outputs": [ 1203 | { 1204 | "name": "stdout", 1205 | "output_type": "stream", 1206 | "text": [ 1207 | "MRR@20 score for Attentional User-based GRU on MovieLens : 0.065337464 ± 0.0021981462220191897\n", 1208 | "Recall@20 score for Attentional User-based GRU on MovieLens: 0.1987088 ± 0.0027508449974647546\n" 1209 | ] 1210 | } 1211 | ], 1212 | "source": [ 1213 | "print_scores(\"network_att_linear_ml.p\",'Attentional User-based GRU on MovieLens')" 1214 | ] 1215 | }, 1216 | { 1217 | "cell_type": "code", 1218 | "execution_count": 32, 1219 | "metadata": {}, 1220 | "outputs": [ 1221 | { 1222 | "name": "stdout", 1223 | "output_type": "stream", 1224 | "text": [ 1225 | "MRR@20 score for Multi-Head Attentional User-based GRU on MovieLens : 0.06202981 ± 0.0017889912578660427\n", 1226 | "Recall@20 score for Multi-Head Attentional User-based GRU on MovieLens: 0.20207629 ± 0.005651086872747263\n" 1227 | ] 1228 | } 1229 | ], 1230 | "source": [ 1231 | 
"print_scores(\"network_bn5mh_att_linear_ml.p\",'Multi-Head Attentional User-based GRU on MovieLens')" 1232 | ] 1233 | }, 1234 | { 1235 | "cell_type": "code", 1236 | "execution_count": null, 1237 | "metadata": {}, 1238 | "outputs": [], 1239 | "source": [] 1240 | } 1241 | ], 1242 | "metadata": { 1243 | "kernelspec": { 1244 | "display_name": "Python 3", 1245 | "language": "python", 1246 | "name": "python3" 1247 | }, 1248 | "language_info": { 1249 | "codemirror_mode": { 1250 | "name": "ipython", 1251 | "version": 3 1252 | }, 1253 | "file_extension": ".py", 1254 | "mimetype": "text/x-python", 1255 | "name": "python", 1256 | "nbconvert_exporter": "python", 1257 | "pygments_lexer": "ipython3", 1258 | "version": "3.5.2" 1259 | } 1260 | }, 1261 | "nbformat": 4, 1262 | "nbformat_minor": 2 1263 | } 1264 | --------------------------------------------------------------------------------