├── README.md ├── polara_based.ipynb └── main.ipynb /README.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /polara_based.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from polara import RecommenderData\n", 10 | "from polara import SVDModel\n", 11 | "from polara import get_movielens_data\n", 12 | "from polara.tools.preprocessing import filter_sessions_by_length\n", 13 | "from polara.evaluation import evaluation_engine as ee\n", 14 | "import numpy as np\n", 15 | "import scipy.sparse as SP\n", 16 | "from io import BytesIO\n", 17 | "import pandas as pd\n", 18 | "\n", 19 | "import numpy as np, scipy.stats as st\n", 20 | "import numpy as np\n", 21 | "import scipy as sp\n", 22 | "import scipy.stats\n", 23 | "\n" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "ml_train_items = np.load(\"ml_train_items.npy\")\n", 33 | "ml_train_mask = np.load(\"ml_train_mask.npy\")\n", 34 | "ml_train_users = np.load(\"ml_train_users.npy\")\n", 35 | "ml_val_items = np.load(\"ml_val_items.npy\")\n", 36 | "ml_val_mask = np.load(\"ml_val_mask.npy\")\n", 37 | "ml_val_users = np.load(\"ml_val_users.npy\")\n", 38 | "ml_test_items = np.load(\"ml_test_items.npy\")\n", 39 | "ml_test_mask = np.load(\"ml_test_mask.npy\")\n", 40 | "ml_test_users = np.load(\"ml_test_users.npy\")\n", 41 | "ml_train_user_idx = np.load('ml_train_user_idx.npy')\n", 42 | "ml_train_item_idx = np.load('ml_train_item_idx.npy')\n", 43 | "ml_train_feedback = np.load('ml_train_feedback.npy')" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 13, 49 | "metadata": {}, 50 | "outputs": [ 51 | { 52 | "data": { 53 | 
"text/plain": [ 54 | "762.8421874999999" 55 | ] 56 | }, 57 | "execution_count": 13, 58 | "metadata": {}, 59 | "output_type": "execute_result" 60 | } 61 | ], 62 | "source": [ 63 | "len(lf_train_items)/128/50*10" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 3, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "lf_train_items = np.load(\"lf_train_items.npy\")\n", 73 | "lf_train_mask = np.load(\"lf_train_mask.npy\")\n", 74 | "lf_train_users = np.load(\"lf_train_users.npy\")\n", 75 | "lf_val_items = np.load(\"lf_val_items.npy\")\n", 76 | "lf_val_mask = np.load(\"lf_val_mask.npy\")\n", 77 | "lf_val_users = np.load(\"lf_val_users.npy\")\n", 78 | "lf_test_items = np.load(\"lf_test_items.npy\")\n", 79 | "lf_test_mask = np.load(\"lf_test_mask.npy\")\n", 80 | "lf_test_users = np.load(\"lf_test_users.npy\")\n", 81 | "lf_train_user_idx = np.load('lf_train_user_idx.npy')\n", 82 | "lf_train_item_idx = np.load('lf_train_item_idx.npy')\n", 83 | "lf_train_feedback = np.load('lf_train_feedback.npy')" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 5, 89 | "metadata": {}, 90 | "outputs": [ 91 | { 92 | "name": "stdout", 93 | "output_type": "stream", 94 | "text": [ 95 | "Collecting git+https://github.com/Evfro/polara.git@develop\n", 96 | " Cloning https://github.com/Evfro/polara.git (to revision develop) to /tmp/pip-req-build-0a3rx0t8\n", 97 | "Building wheels for collected packages: polara\n", 98 | " Running setup.py bdist_wheel for polara ... 
def remove_gaps(data):
    """Re-index ``userid`` and ``movieid`` as dense, zero-based integer codes.

    Codes are handed out in order of first appearance (``sort=False``), so the
    first user/movie seen becomes 0, the next new one 1, and so on.

    The codes are computed from ``data`` itself. The previous version read the
    notebook-global ``ml_data`` instead, so it only worked when the argument
    happened to be that very frame.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``userid`` and ``movieid`` columns.

    Returns
    -------
    pandas.DataFrame
        The same object that was passed in; both columns are overwritten
        in place.
    """
    for column in ('movieid', 'userid'):
        # ngroup() is the public replacement for the deprecated
        # ``.grouper.group_info[0]``; ``.values`` keeps the assignment
        # positional, exactly as the raw code array was.
        data[column] = data.groupby(column, sort=False).ngroup().values
    return data


def normalize_timestamp(x):
    """Replace raw timestamps with their 0-based chronological rank.

    Intended to be applied per user via ``groupby("userid").apply``.

    ``np.argsort`` alone returns the permutation that sorts the values, not
    the rank of each row, so the previous version produced wrong positions
    whenever a user's rows were not already in time order. The rank is the
    inverse of that permutation. Ties are broken by original row order.

    Parameters
    ----------
    x : pandas.DataFrame
        Frame with a ``timestamp`` column.

    Returns
    -------
    pandas.DataFrame
        The same object, with ``timestamp`` overwritten by ranks.
    """
    # "mergesort" is numpy's stable sort, so equal timestamps keep row order.
    order = np.argsort(x["timestamp"].values, kind="mergesort")
    ranks = np.empty_like(order)
    ranks[order] = np.arange(len(order))  # invert the permutation
    x["timestamp"] = ranks
    return x


def length_col(x):
    """Overwrite ``timestamp`` with the number of rows in ``x``.

    Applied per user, this yields a column holding each user's history
    length, which the split code uses as the per-user upper bound.

    Returns the same object that was passed in.
    """
    x['timestamp'] = len(x)
    return x
@jit(nopython=True, nogil=True)
def recall_at_k(recommendation,true_consumption,k=20):
    """Hit indicator: 1.0 if the true item is among the top-``k`` scores.

    Parameters
    ----------
    recommendation : 1-D float array
        One score per item; higher means more recommended.
    true_consumption : number
        Index of the item that was actually consumed (cast with ``int``).
    k : int
        Cut-off. Values larger than the number of items are clipped;
        ``k <= 0`` yields 0.0 (``[-0:]`` used to select every item).

    Returns
    -------
    float
        1.0 on a hit, otherwise 0.0.
    """
    n_items = recommendation.shape[0]
    n_top = min(k, n_items)
    if n_top <= 0:
        return 0.0

    target = int(true_consumption)
    if target < 0:
        # keep the wrap-around meaning negative indices had before
        target += n_items

    # Best-first indices of the n_top highest scores.
    topk_inds = recommendation.argsort()[n_items - n_top:][::-1]
    # Scan the short top list instead of allocating a full-size indicator
    # array per call; an out-of-range id simply never matches.
    for pos in range(n_top):
        if topk_inds[pos] == target:
            return 1.0
    return 0.0

@jit(nopython=True, nogil=True)
def mrr_at_k(recommendation,true_consumption,k=20):
    """Reciprocal rank of the true item within the top-``k`` scores.

    Parameters are the same as for ``recall_at_k``.

    The previous version assigned ``k`` reciprocal ranks into
    ``len(top list)`` slots, which is a shape mismatch whenever ``k`` exceeds
    the number of items. ``k`` is now clipped.

    Returns
    -------
    float
        ``1 / rank`` (rank starting at 1) if the item is in the top-``k``,
        otherwise 0.0.
    """
    n_items = recommendation.shape[0]
    n_top = min(k, n_items)
    if n_top <= 0:
        return 0.0

    target = int(true_consumption)
    if target < 0:
        # keep the wrap-around meaning negative indices had before
        target += n_items

    topk_inds = recommendation.argsort()[n_items - n_top:][::-1]
    for pos in range(n_top):
        if topk_inds[pos] == target:
            return 1.0 / (pos + 1)
    return 0.0
0.7732885934551574\n", 508 | "Epoch 8 RMSE: 0.7632023403587169\n", 509 | "Epoch 9 RMSE: 0.7548784181251089\n", 510 | "Epoch 10 RMSE: 0.7479567842743502\n", 511 | "Epoch 11 RMSE: 0.742152630887336\n", 512 | "Epoch 12 RMSE: 0.7372427424804076\n", 513 | "Epoch 13 RMSE: 0.7330523912264771\n", 514 | "Epoch 14 RMSE: 0.7294447595737159\n", 515 | "Epoch 15 RMSE: 0.7263125637673405\n", 516 | "Epoch 16 RMSE: 0.7235714286511026\n", 517 | "Epoch 17 RMSE: 0.7211546768303492\n", 518 | "Epoch 18 RMSE: 0.7190092797572853\n", 519 | "Epoch 19 RMSE: 0.7170927493456211\n", 520 | "Epoch 20 RMSE: 0.7153707702341925\n", 521 | "Epoch 21 RMSE: 0.7138154014210378\n", 522 | "Epoch 22 RMSE: 0.7124037087144436\n", 523 | "Epoch 23 RMSE: 0.7111167205435434\n", 524 | "Epoch 24 RMSE: 0.7099386259592084\n", 525 | "Epoch 25 RMSE: 0.708856154402729\n", 526 | "Epoch 26 RMSE: 0.7078580925338287\n", 527 | "Epoch 27 RMSE: 0.706934905051986\n", 528 | "Epoch 28 RMSE: 0.7060784349627898\n", 529 | "Epoch 29 RMSE: 0.7052816649704758\n", 530 | "Epoch 30 RMSE: 0.7045385262243349\n" 531 | ] 532 | } 533 | ], 534 | "source": [ 535 | "# train_feedback[:] = 1\n", 536 | "P, Q, _ = basic_matrix_factorization(ml_train_user_idx, ml_train_item_idx, ml_train_feedback\\\n", 537 | " ,rank=20,reg = 0.01,num_epochs=30)" 538 | ] 539 | }, 540 | { 541 | "cell_type": "code", 542 | "execution_count": 12, 543 | "metadata": {}, 544 | "outputs": [ 545 | { 546 | "name": "stdout", 547 | "output_type": "stream", 548 | "text": [ 549 | "MRR@20 score for MF on MovieLens : 0.004567202376613807 ± 0.001339522261853232\n", 550 | "Recall@20 score for MF on MovieLens : 0.019013581129378128 ± 0.005759417598196718\n" 551 | ] 552 | } 553 | ], 554 | "source": [ 555 | "(mrr, h_mrr),(recall, h_recall), = estimate_model(P,Q,ml_test_items, ml_test_mask, ml_test_users,reg = 0.01)\n", 556 | "ds_name = \"MovieLens\"\n", 557 | "print(\"MRR@20 score for MF on \", ds_name,\": \",mrr,\"±\",h_mrr)\n", 558 | "print(\"Recall@20 score for MF on\",ds_name,\": 
\",recall,\"±\",h_recall)" 559 | ] 560 | }, 561 | { 562 | "cell_type": "markdown", 563 | "metadata": {}, 564 | "source": [ 565 | "### LastFM" 566 | ] 567 | }, 568 | { 569 | "cell_type": "code", 570 | "execution_count": 7, 571 | "metadata": {}, 572 | "outputs": [], 573 | "source": [ 574 | "import scipy as sp\n", 575 | "train_matrix = np.array(sp.sparse.csr_matrix((np.ones(len(lf_train_user_idx)),\n", 576 | " (lf_train_user_idx, lf_train_item_idx)),\n", 577 | " shape=(max(lf_train_user_idx)+1,max(lf_train_item_idx)+1), dtype=np.float64).todense())\n", 578 | "\n" 579 | ] 580 | }, 581 | { 582 | "cell_type": "code", 583 | "execution_count": 8, 584 | "metadata": {}, 585 | "outputs": [ 586 | { 587 | "name": "stdout", 588 | "output_type": "stream", 589 | "text": [ 590 | "Epoch 1 RMSE: 0.33389627661066806\n", 591 | "Epoch 2 RMSE: 0.13595819292966613\n", 592 | "Epoch 3 RMSE: 0.10508689900024767\n", 593 | "Epoch 4 RMSE: 0.09070079526606878\n", 594 | "Epoch 5 RMSE: 0.08147071061295962\n", 595 | "Epoch 6 RMSE: 0.07477153011689404\n", 596 | "Epoch 7 RMSE: 0.06956005889239489\n", 597 | "Epoch 8 RMSE: 0.06532270200989314\n", 598 | "Epoch 9 RMSE: 0.06177040644012343\n", 599 | "Epoch 10 RMSE: 0.058724800935394526\n", 600 | "Epoch 11 RMSE: 0.056068301626020486\n", 601 | "Epoch 12 RMSE: 0.05371946888402425\n", 602 | "Epoch 13 RMSE: 0.051619679737187375\n", 603 | "Epoch 14 RMSE: 0.04972540075475995\n", 604 | "Epoch 15 RMSE: 0.04800345859082646\n", 605 | "Epoch 16 RMSE: 0.04642801718895144\n", 606 | "Epoch 17 RMSE: 0.04497857557847389\n", 607 | "Epoch 18 RMSE: 0.04363860094936924\n", 608 | "Epoch 19 RMSE: 0.04239457073868288\n", 609 | "Epoch 20 RMSE: 0.041235285857219646\n", 610 | "Epoch 21 RMSE: 0.04015136833763086\n", 611 | "Epoch 22 RMSE: 0.03913488730514752\n", 612 | "Epoch 23 RMSE: 0.03817907605459219\n", 613 | "Epoch 24 RMSE: 0.03727811497846672\n", 614 | "Epoch 25 RMSE: 0.03642696285749802\n", 615 | "Epoch 26 RMSE: 0.035621224182504706\n", 616 | "Epoch 27 RMSE: 
0.03485704367226345\n", 617 | "Epoch 28 RMSE: 0.03413102156514454\n", 618 | "Epoch 29 RMSE: 0.03344014495524048\n", 619 | "Epoch 30 RMSE: 0.032781731648978206\n" 620 | ] 621 | } 622 | ], 623 | "source": [ 624 | "P, Q, _ = basic_matrix_factorization(lf_train_user_idx, lf_train_item_idx, lf_train_feedback\\\n", 625 | " ,rank=20,reg = 0.01,num_epochs=30)" 626 | ] 627 | }, 628 | { 629 | "cell_type": "code", 630 | "execution_count": 9, 631 | "metadata": {}, 632 | "outputs": [ 633 | { 634 | "name": "stdout", 635 | "output_type": "stream", 636 | "text": [ 637 | "MRR@20 score for MF on LastFM : 2.559322309196812e-05 ± 4.358562546655403e-05\n", 638 | "Recall@20 score for MF on LastFM : 0.00022010271460014673 ± 0.0004937387018228\n" 639 | ] 640 | } 641 | ], 642 | "source": [ 643 | "(mrr, h_mrr),(recall, h_recall), = estimate_model(P,Q,lf_test_items, lf_test_mask, lf_test_users,reg = 0.01)\n", 644 | "ds_name = \"LastFM\"\n", 645 | "print(\"MRR@20 score for MF on \", ds_name,\": \",mrr,\"±\",h_mrr)\n", 646 | "print(\"Recall@20 score for MF on\",ds_name,\": \",recall,\"±\",h_recall)" 647 | ] 648 | }, 649 | { 650 | "cell_type": "code", 651 | "execution_count": null, 652 | "metadata": {}, 653 | "outputs": [], 654 | "source": [] 655 | }, 656 | { 657 | "cell_type": "code", 658 | "execution_count": null, 659 | "metadata": {}, 660 | "outputs": [], 661 | "source": [] 662 | }, 663 | { 664 | "cell_type": "code", 665 | "execution_count": null, 666 | "metadata": {}, 667 | "outputs": [], 668 | "source": [] 669 | }, 670 | { 671 | "cell_type": "code", 672 | "execution_count": 3, 673 | "metadata": {}, 674 | "outputs": [ 675 | { 676 | "data": { 677 | "text/html": [ 678 | "
\n", 679 | "\n", 692 | "\n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | "
useridmovieidratingtimestamp
011225.0838985046
111855.0838983525
212315.0838983392
312925.0838983421
413165.0838983392
\n", 740 | "
" 741 | ], 742 | "text/plain": [ 743 | " userid movieid rating timestamp\n", 744 | "0 1 122 5.0 838985046\n", 745 | "1 1 185 5.0 838983525\n", 746 | "2 1 231 5.0 838983392\n", 747 | "3 1 292 5.0 838983421\n", 748 | "4 1 316 5.0 838983392" 749 | ] 750 | }, 751 | "execution_count": 3, 752 | "metadata": {}, 753 | "output_type": "execute_result" 754 | } 755 | ], 756 | "source": [ 757 | "ml_data = get_movielens_data(\"ml-10m.zip\", include_time=True)\n", 758 | "ml_data.head()" 759 | ] 760 | }, 761 | { 762 | "cell_type": "code", 763 | "execution_count": 10, 764 | "metadata": {}, 765 | "outputs": [], 766 | "source": [ 767 | "data = (filter_sessions_by_length(ml_data, min_session_length=20)\n", 768 | " #.query('rating >= 4')\n", 769 | " #.assign(rating=1)\n", 770 | " )" 771 | ] 772 | }, 773 | { 774 | "cell_type": "code", 775 | "execution_count": 11, 776 | "metadata": {}, 777 | "outputs": [], 778 | "source": [ 779 | "data_model = RecommenderData(data, 'userid', 'movieid', 'rating', custom_order='timestamp', seed=0)\n", 780 | "data_model.holdout_size = 1\n", 781 | "data_model.random_holdout = False\n", 782 | "data_model.warm_start = False\n", 783 | "data_model.permute_tops = False" 784 | ] 785 | }, 786 | { 787 | "cell_type": "code", 788 | "execution_count": 12, 789 | "metadata": {}, 790 | "outputs": [ 791 | { 792 | "name": "stdout", 793 | "output_type": "stream", 794 | "text": [ 795 | "Preparing data...\n", 796 | "Done.\n" 797 | ] 798 | } 799 | ], 800 | "source": [ 801 | "data_model.prepare()" 802 | ] 803 | }, 804 | { 805 | "cell_type": "code", 806 | "execution_count": 13, 807 | "metadata": {}, 808 | "outputs": [], 809 | "source": [ 810 | "idx, val, shp = data_model.to_coo()" 811 | ] 812 | }, 813 | { 814 | "cell_type": "code", 815 | "execution_count": null, 816 | "metadata": {}, 817 | "outputs": [], 818 | "source": [] 819 | }, 820 | { 821 | "cell_type": "code", 822 | "execution_count": 14, 823 | "metadata": {}, 824 | "outputs": [ 825 | { 826 | "name": "stdout", 827 | 
"output_type": "stream", 828 | "text": [ 829 | "Epoch 1 RMSE: 1.3085293049789994\n", 830 | "Epoch 2 RMSE: 0.9151132113832263\n", 831 | "Epoch 3 RMSE: 0.8808651366877939\n", 832 | "Epoch 4 RMSE: 0.8676692047213754\n", 833 | "Epoch 5 RMSE: 0.8587222170703273\n", 834 | "Epoch 6 RMSE: 0.8525469762850404\n", 835 | "Epoch 7 RMSE: 0.8479598409933002\n", 836 | "Epoch 8 RMSE: 0.8442665480920708\n", 837 | "Epoch 9 RMSE: 0.8411850669634959\n", 838 | "Epoch 10 RMSE: 0.8385955015689102\n" 839 | ] 840 | } 841 | ], 842 | "source": [ 843 | "P, Q, biases = basic_matrix_factorization(*idx.T, val,rank=20,num_epochs=10)" 844 | ] 845 | }, 846 | { 847 | "cell_type": "code", 848 | "execution_count": 20, 849 | "metadata": {}, 850 | "outputs": [], 851 | "source": [ 852 | "R = Q.dot(P.T).T" 853 | ] 854 | }, 855 | { 856 | "cell_type": "code", 857 | "execution_count": 50, 858 | "metadata": {}, 859 | "outputs": [], 860 | "source": [ 861 | "topk = np.argsort(R,axis = 1)[:,-20:]" 862 | ] 863 | }, 864 | { 865 | "cell_type": "code", 866 | "execution_count": 53, 867 | "metadata": {}, 868 | "outputs": [ 869 | { 870 | "data": { 871 | "text/plain": [ 872 | "(69878, 20)" 873 | ] 874 | }, 875 | "execution_count": 53, 876 | "metadata": {}, 877 | "output_type": "execute_result" 878 | } 879 | ], 880 | "source": [] 881 | }, 882 | { 883 | "cell_type": "code", 884 | "execution_count": 60, 885 | "metadata": {}, 886 | "outputs": [], 887 | "source": [ 888 | "user_idx, item_idx, fdbk_val = data_model.test_to_coo()" 889 | ] 890 | }, 891 | { 892 | "cell_type": "code", 893 | "execution_count": 75, 894 | "metadata": {}, 895 | "outputs": [ 896 | { 897 | "ename": "TypeError", 898 | "evalue": "__init__() got multiple values for argument 'shape'", 899 | "output_type": "error", 900 | "traceback": [ 901 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 902 | "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", 903 | "\u001b[0;32m\u001b[0m in 
# Shape of the held-out interaction matrix.
# coo_matrix takes the values and their coordinates as ONE tuple,
# (data, (row, col)). Passing them as separate positional arguments makes the
# second one bind to `shape`, which is the
# "got multiple values for argument 'shape'" TypeError saved with this cell.
# The sparse matrix already exposes .shape, so no dense copy is built.
SP.coo_matrix((fdbk_val, (user_idx, item_idx)),
              shape=(user_idx.max()+1, item_idx.max()+1)).shape
def train_MF(data_path="ml-10m.zip", min_session_length=20):
    """Load MovieLens and build the configured ``RecommenderData`` for MF.

    The previous version built and configured the data model and then fell
    off the end of the function, returning ``None`` and discarding the work.
    The model is now returned. The archive path and minimum history length
    are parameters whose defaults are the values that were hard-coded.

    TODO: the factorization itself is not implemented here yet; this only
    prepares the data-model configuration.

    Parameters
    ----------
    data_path : str
        Path to the MovieLens archive passed to ``get_movielens_data``.
    min_session_length : int
        Passed to ``filter_sessions_by_length``.

    Returns
    -------
    RecommenderData
        Configured (but not yet ``prepare()``-d) data model.
    """
    ml_data = get_movielens_data(data_path, include_time=True)
    data = (filter_sessions_by_length(ml_data, min_session_length=min_session_length)
            #.query('rating >= 4')
            #.assign(rating=1)
           )
    data_model = RecommenderData(data, 'userid', 'movieid', 'rating', custom_order='timestamp', seed=0)
    #data_model.holdout_size = 1
    data_model.random_holdout = False
    data_model.warm_start = False
    data_model.permute_tops = False
    return data_model
class LinearGRU(nn.Module):
    """User-aware GRU next-item model.

    At every time step the user embedding and the item embedding are
    concatenated and fed to a ``GRUCell``; each hidden state is projected to
    a log-probability distribution over all items.

    Parameters
    ----------
    n_users, n_items : int
        Sizes of the user and item embedding tables; ``n_items`` is also the
        output vocabulary size.
    emb_size : int or None
        Embedding width; defaults to ``hidden_units`` when ``None``.
    hidden_units : int
        GRU hidden-state width.
    dropout : float
        Dropout probability applied to the hidden states before projection.
    user_dropout : float
        Dropout probability applied to the user embeddings.
    """

    def __init__(self, n_users,n_items, emb_size=None, hidden_units=1000,dropout = 0.8,user_dropout = 0.5):
        # super(self.__class__, self) recurses forever as soon as the class
        # is subclassed; the zero-argument form is always correct.
        super().__init__()
        self.n_users = n_users
        self.n_items = n_items
        self.hidden_units = hidden_units
        if emb_size is None:
            emb_size = hidden_units
        self.emb_size = emb_size
        # TODO: justify learned embeddings here (open question from the author)
        self.user_emb = nn.Embedding(n_users,emb_size)
        self.item_emb = nn.Embedding(n_items,emb_size)
        self.grucell = nn.GRUCell(input_size = emb_size*2,hidden_size = hidden_units)
        self.linear = nn.Linear(hidden_units,n_items)
        self.dropout = nn.Dropout(dropout)
        self.user_dropout = nn.Dropout(user_dropout)

    def forward(self, user_vectors, item_vectors):
        """Run the recurrence over a batch of sequences.

        Parameters
        ----------
        user_vectors, item_vectors : LongTensor, shape (batch, seq)
            User id and item id at every position.

        Returns
        -------
        Tensor, shape (batch, seq, n_items)
            Log-probabilities over items after each step.
        """
        batch_size, sequence_size = user_vectors.size()

        users = self.user_dropout(self.user_emb(user_vectors))
        items = self.item_emb(item_vectors)

        # Allocate the initial state next to the embeddings instead of
        # reading a notebook-global `device` defined in a later cell.
        h = torch.zeros(batch_size, self.hidden_units,
                        device=items.device, dtype=items.dtype)

        # Collect states in a list and stack once; growing a tensor with
        # torch.cat on every step re-copies the whole history each time.
        hidden_states = []
        for i in range(sequence_size):
            gru_input = torch.cat([users[:,i,:],items[:,i,:]],dim=-1)
            h = self.grucell(gru_input,h)
            hidden_states.append(h)

        if hidden_states:
            stacked = torch.stack(hidden_states, dim=1)  # (batch, seq, hidden)
        else:
            # zero-length sequences still yield a well-formed empty result
            stacked = h.new_zeros(batch_size, 0, self.hidden_units)

        ln_input = self.dropout(stacked)
        output_ln = self.linear(ln_input)
        output = F.log_softmax(output_ln, dim=-1)
        return output
torch.zeros(batch_size,self.hidden_units).to(device)\n", 107 | " h_t = h.unsqueeze(0)\n", 108 | " for i in range(sequence_size):\n", 109 | " rect_users = rectified_users(self,users[:,i,:],items[:,i,:],h)\n", 110 | " gru_input = torch.cat([rect_users,items[:,i,:]],dim=-1)\n", 111 | " h = self.grucell(gru_input,h)\n", 112 | " h_t = torch.cat([h_t,h.unsqueeze(0)],dim=0)\n", 113 | " ln_input = self.dropout(h_t[1:].transpose(0,1))\n", 114 | " output_ln = self.linear(ln_input)\n", 115 | "\n", 116 | " output = F.log_softmax(output_ln, dim=-1)\n", 117 | " return output\n", 118 | " \n", 119 | "def rectified_users(self,users,items,h):\n", 120 | " \n", 121 | " k1 = self.k1(torch.cat([users,items,h],dim = -1))\n", 122 | " k2 = self.k2(torch.cat([users,items,h],dim = -1))\n", 123 | " rect_users = users\n", 124 | " rect_users[users h x d_k \n", 241 | " query, key, value = \\\n", 242 | " [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)\n", 243 | " for l, x in zip(self.linears, (query, key, value))]\n", 244 | " \n", 245 | " # 2) Apply attention on all the projected vectors in batch. \n", 246 | " x, self.attn = attention(query, key, value, mask=mask, \n", 247 | " dropout=self.dropout)\n", 248 | " \n", 249 | " # 3) \"Concat\" using a view and apply a final linear. 
class MHLinearGRU(nn.Module):
    """User-aware GRU whose inputs pass through multi-head attention.

    At each step the user and item embeddings are each re-weighted by a
    ``MultiHeadedAttention`` module that uses the current hidden state as
    key, batch-normalised, concatenated and fed to a ``GRUCell``. Every
    hidden state is projected to log-probabilities over all items.

    Parameters
    ----------
    n_users, n_items : int
        Sizes of the embedding tables; ``n_items`` is also the output size.
    emb_size : int
        Embedding width; also used as the GRU hidden width.
    head_num : int
        Number of attention heads passed to ``MultiHeadedAttention``.
    dropout : float
        Dropout probability applied to hidden states before projection.
    user_dropout : float
        Dropout probability applied to the user embeddings.
    """

    def __init__(self, n_users,n_items, emb_size=1000,head_num = 5,dropout = 0.8,user_dropout = 0.5):
        # super(self.__class__, self) recurses forever under subclassing.
        super().__init__()
        self.n_users = n_users
        self.n_items = n_items
        self.hidden_units = emb_size
        self.emb_size = emb_size
        # TODO: justify learned embeddings here (open question from the author)
        self.user_emb = nn.Embedding(n_users,emb_size)
        self.item_emb = nn.Embedding(n_items,emb_size)
        self.grucell = nn.GRUCell(input_size = emb_size*2,hidden_size = self.hidden_units)
        self.linear = nn.Linear(self.hidden_units,n_items)
        self.dropout = nn.Dropout(dropout)
        self.user_dropout = nn.Dropout(user_dropout)
        self.user_attention = MultiHeadedAttention(head_num, self.emb_size)
        self.item_attention = MultiHeadedAttention(head_num, self.emb_size)
        self.bn_user = nn.BatchNorm1d(self.emb_size)
        self.bn_item = nn.BatchNorm1d(self.emb_size)
        # Not used in forward(); kept so existing state_dicts still load.
        self.bn_last = nn.BatchNorm1d(self.emb_size)

    def forward(self, user_vectors, item_vectors):
        """Run the attended recurrence over a batch of sequences.

        Parameters
        ----------
        user_vectors, item_vectors : LongTensor, shape (batch, seq)
            User id and item id at every position.

        Returns
        -------
        Tensor, shape (batch, seq, n_items)
            Log-probabilities over items after each step.
        """
        batch_size, sequence_size = user_vectors.size()

        users = self.user_dropout(self.user_emb(user_vectors))
        items = self.item_emb(item_vectors)

        # Allocate the initial state next to the embeddings instead of
        # reading a notebook-global `device` defined in a later cell.
        h = torch.zeros(batch_size, self.hidden_units,
                        device=items.device, dtype=items.dtype)

        # Collect states and stack once; repeated torch.cat re-copies the
        # whole history on every step.
        hidden_states = []
        for i in range(sequence_size):
            step_users = users[:,i,:]
            step_items = items[:,i,:]
            # argument order is (query, key, value)
            attnd_users = self.user_attention(step_items,h,step_users).squeeze(1)
            attnd_items = self.item_attention(step_users,h,step_items).squeeze(1)
            attnd_users = self.bn_user(attnd_users)
            attnd_items = self.bn_item(attnd_items)
            gru_input = torch.cat([attnd_users,attnd_items],dim=-1)
            h = self.grucell(gru_input,h)
            hidden_states.append(h)

        if hidden_states:
            stacked = torch.stack(hidden_states, dim=1)  # (batch, seq, hidden)
        else:
            # zero-length sequences still yield a well-formed empty result
            stacked = h.new_zeros(batch_size, 0, self.hidden_units)

        ln_input = self.dropout(stacked)
        output_ln = self.linear(ln_input)
        output = F.log_softmax(output_ln, dim=-1)
        return output
\u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mml_train_items\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mml_val_items\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mml_test_items\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mn_users\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 321 | "\u001b[0;31mNameError\u001b[0m: name 'ml_test_users' is not defined" 322 | ] 323 | } 324 | ], 325 | "source": [ 326 | "n_users = int(ml_test_users.max()+1)\n", 327 | "n_items = int(np.max([ml_train_items.max()+1,ml_val_items.max()+1,ml_test_items.max()])+1)\n", 328 | "n_users" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": 97, 334 | "metadata": {}, 335 | "outputs": [ 336 | { 337 | "data": { 338 | "text/plain": [ 339 | "torch.Size([2, 5, 10678])" 340 | ] 341 | }, 342 | "execution_count": 97, 343 | "metadata": {}, 344 | "output_type": "execute_result" 345 | } 346 | ], 347 | "source": [ 348 | "device = torch.device(\"cuda:2\" if torch.cuda.is_available() else \"cpu\")\n", 349 | "# device = torch.device(\"cpu\")\n", 350 | "# network = MHLinearGRU(n_users=3,n_items=3,emb_size=20).to(device)\n", 351 | "network = MHLinearGRU(n_users=n_users,n_items=n_items).to(device)\n", 352 | "\n", 353 | "users = np.array([[1,1,1,1,1],\n", 354 | " [2,2,2,2,2]])\n", 355 | "items = np.array([[0,1,2,1,1],\n", 356 | " [0,2,2,1,0]])\n", 357 | "pred = 
network(Variable(torch.LongTensor(users)).to(device),Variable(torch.LongTensor(items)).to(device))\n", 358 | "pred.size()" 359 | ] 360 | }, 361 | { 362 | "cell_type": "markdown", 363 | "metadata": {}, 364 | "source": [ 365 | "### MovieLens Preprocessing" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": 5, 371 | "metadata": {}, 372 | "outputs": [], 373 | "source": [ 374 | "# Libraries and provided functions\n", 375 | "import pandas as pd\n", 376 | "import zipfile\n", 377 | "import wget\n", 378 | "from io import StringIO \n", 379 | "import numpy as np\n", 380 | "import scipy as sp\n", 381 | "from scipy import sparse\n", 382 | "import scipy.sparse.linalg\n", 383 | "from tqdm import tqdm # Very useful library to see progress bar during range iterations: just type `for i in tqdm(range(10)):`\n", 384 | "from matplotlib import pyplot as plt\n", 385 | "%matplotlib inline\n", 386 | "\n", 387 | "from collections import namedtuple\n", 388 | "import sys\n", 389 | "\n", 390 | "def drop_unused_items(train,val,test):\n", 391 | " \"\"\"Return (val, test) without the rows whose itemid never occurs in train.\"\"\"\n", 392 | " train_items = train.itemid.unique()\n", 393 | " val_items = val.itemid.unique()\n", 394 | " test_items = test.itemid.unique()\n", 395 | " droped_items = list((set(val_items)|set(test_items)) - set(train_items))\n", 396 | " # Select rows by itemid only: seeding the masks with\n", 397 | " # `userid == droped_items[0]` would also drop an unrelated user's rows\n", 398 | " # and raise IndexError when there is nothing to drop.\n", 399 | " val_mask = val.itemid.isin(droped_items)\n", 400 | " test_mask = test.itemid.isin(droped_items)\n", 401 | " val = val[~val_mask]\n", 402 | " test = test[~test_mask]\n", 403 | " return val,test\n", 404 | "\n", 405 | "def move_timestamps_to_end(x,max_order):\n", 406 | " new_order = x.groupby('timestamp', sort=True).grouper.group_info[0]\n", 407 | " x[\"timestamp\"] = (max_order - new_order.max())+new_order\n", 408 | " return x\n", 409 | "\n", 410 | "def normalize_timestamp(x):\n", 411 | " x[\"timestamp\"] = x.groupby(['timestamp','itemid'], 
sort=True).grouper.group_info[0]\n", 412 | " return x\n", 413 | "\n", 414 | "def set_timestamp_length(x):\n", 415 | " x['length'] = len(x)\n", 416 | " return x\n", 417 | "\n", 418 | "def to_coo(data):\n", 419 | " \n", 420 | " user_idx, item_idx, feedback = data['userid'], data['itemid'], data['rating']\n", 421 | " return user_idx, item_idx, feedback\n", 422 | "\n", 423 | "def to_matrices(data):\n", 424 | " data = split_by_groups(data)\n", 425 | " \n", 426 | " data_max_order = data['timestamp'].max()\n", 427 | " data = data.groupby(\"index\").apply(move_timestamps_to_end,data_max_order)\n", 428 | "\n", 429 | " data_shape = data[['index', 'timestamp']].max()+1\n", 430 | " data_matrix = sp.sparse.csr_matrix((data['itemid'],\n", 431 | " (data['index'], data['timestamp'])),\n", 432 | " shape=data_shape, dtype=np.float64).todense()\n", 433 | " mask_matrix = sp.sparse.csr_matrix((np.ones(len(data)),\n", 434 | " (data['index'], data['timestamp'])),\n", 435 | " shape=data_shape, dtype=np.float64).todense()\n", 436 | " \n", 437 | " data_users = data.drop_duplicates(['index'])\n", 438 | " user_data_shape = data_users['index'].max()+1\n", 439 | " user_vector = sp.sparse.csr_matrix((data_users['userid'],\n", 440 | " (data_users['index'],np.zeros(user_data_shape))),\n", 441 | " shape=(user_data_shape,1), dtype=np.float64).todense()\n", 442 | " user_matrix = np.tile(user_vector,(1,data_shape[1]))\n", 443 | " return data_matrix, mask_matrix, user_matrix\n", 444 | "\n", 445 | "def train_val_test_split(data,frac):\n", 446 | " data = data.groupby(\"userid\").apply(set_timestamp_length)\n", 447 | " max_time_stamp = data['length']*frac\n", 448 | " timestamp = data['timestamp']\n", 449 | " data_train = data[timestamp\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mdevice\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"cuda:2\"\u001b[0m \u001b[0;32mif\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_available\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m\"cpu\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;31m# device = torch.device(\"cpu\")\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mn_users\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mml_test_users\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mn_items\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mml_train_items\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mml_val_items\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mml_test_items\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mnetwork\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mMHLinearGRU\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_users\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mn_users\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mn_items\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mn_items\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 710 | "\u001b[0;31mNameError\u001b[0m: name 'ml_test_users' is not defined" 711 | ] 712 | } 713 | ], 714 | "source": [ 715 | "device = torch.device(\"cuda:2\" if torch.cuda.is_available() else \"cpu\")\n", 716 | "# device = torch.device(\"cpu\")\n", 717 | "n_users = int(ml_test_users.max()+1)\n", 718 | "n_items = int(np.max([ml_train_items.max()+1,ml_val_items.max()+1,ml_test_items.max()])+1)\n", 719 | "network = MHLinearGRU(n_users=n_users,n_items=n_items).to(device)" 720 | ] 721 | }, 722 | { 723 | "cell_type": "code", 724 | "execution_count": 89, 725 | "metadata": {}, 726 | "outputs": [], 727 | "source": [ 728 | "import torch.utils.data\n", 729 | "opt = torch.optim.Adam(network.parameters(),lr =0.001)\n", 730 | "\n", 731 | "history = []\n", 732 | "\n", 733 | "train_loader = torch.utils.data.DataLoader(\\\n", 734 | " torch.utils.data.TensorDataset(\\\n", 735 | " *(torch.LongTensor(ml_train_users),torch.LongTensor(ml_train_items),torch.FloatTensor(ml_train_mask))),\\\n", 736 | " batch_size=1000,shuffle=True)\n", 737 | "\n", 738 | "val_loader = torch.utils.data.DataLoader(\\\n", 739 | " torch.utils.data.TensorDataset(\\\n", 740 | " *(torch.LongTensor(ml_val_users),torch.LongTensor(ml_val_items),torch.FloatTensor(ml_val_mask))),\\\n", 741 | " batch_size=1000,shuffle=True)\n", 742 | "test_loader = torch.utils.data.DataLoader(\\\n", 743 | " torch.utils.data.TensorDataset(\\\n", 744 | " *(torch.LongTensor(ml_test_users),torch.LongTensor(ml_test_items),torch.FloatTensor(ml_test_mask))),\\\n", 745 | " batch_size=1000,shuffle=True)" 746 | ] 747 | }, 748 | { 749 | "cell_type": "code", 750 | 
"execution_count": 6, 751 | "metadata": {}, 752 | "outputs": [], 753 | "source": [ 754 | "import torch.utils.data\n", 755 | "from IPython.display import clear_output\n", 756 | "import gc\n", 757 | "\n", 758 | "\n", 759 | "\n", 760 | "def validate_lce(network,val_loader):\n", 761 | " torch.cuda.empty_cache()\n", 762 | " gc.collect()\n", 763 | " network.eval()\n", 764 | " losses = []\n", 765 | " with torch.no_grad():\n", 766 | " for user_batch_ix,item_batch_ix, mask_batch_ix in val_loader:\n", 767 | "\n", 768 | " user_batch_ix = Variable(user_batch_ix).to(device)\n", 769 | " item_batch_ix = Variable(item_batch_ix).to(device)\n", 770 | " mask_batch_ix = Variable(mask_batch_ix).to(device)\n", 771 | "\n", 772 | " logp_seq = network(user_batch_ix, item_batch_ix)\n", 773 | " # compute loss\n", 774 | " predictions_logp = logp_seq[:, :-1]*mask_batch_ix[:, :-1,None]\n", 775 | " actual_next_tokens = item_batch_ix[:, 1:]\n", 776 | "\n", 777 | " logp_next = torch.gather(predictions_logp, dim=2, index=actual_next_tokens[:,:,None])\n", 778 | " loss = -logp_next.sum()/mask_batch_ix[:, :-1].sum()\n", 779 | " losses.append(loss.cpu().data.numpy())\n", 780 | " torch.cuda.empty_cache()\n", 781 | " gc.collect() \n", 782 | " return np.mean(losses)\n", 783 | "\n", 784 | "\n", 785 | "def train_network(network,train_loader,val_loader,num_epoch = 10):\n", 786 | " for epoch in range(num_epoch):\n", 787 | " i=0\n", 788 | " for user_batch_ix,item_batch_ix, mask_batch_ix in train_loader:\n", 789 | " network.train()\n", 790 | " user_batch_ix = Variable(user_batch_ix).to(device)\n", 791 | " item_batch_ix = Variable(item_batch_ix).to(device)\n", 792 | " mask_batch_ix = Variable(mask_batch_ix).to(device)\n", 793 | "\n", 794 | " logp_seq = network(user_batch_ix, item_batch_ix)\n", 795 | " # compute loss\n", 796 | " predictions_logp = logp_seq[:, :-1]*mask_batch_ix[:, :-1,None]\n", 797 | " actual_next_tokens = item_batch_ix[:, 1:]\n", 798 | "\n", 799 | " logp_next = torch.gather(predictions_logp, 
dim=2, index=actual_next_tokens[:,:,None])\n", 800 | " loss = -logp_next.sum()/mask_batch_ix[:, :-1].sum()\n", 801 | "\n", 802 | " # train with backprop\n", 803 | " opt.zero_grad()\n", 804 | " loss.backward()\n", 805 | " nn.utils.clip_grad_norm_(network.parameters(),5)\n", 806 | " opt.step()\n", 807 | " \n", 808 | " if (i+1)%50==0:\n", 809 | " val_loss = validate_lce(network,val_loader)\n", 810 | " history.append(val_loss)\n", 811 | "\n", 812 | " clear_output(True)\n", 813 | " plt.title(\"Validation error\")\n", 814 | " plt.plot(history)\n", 815 | " plt.ylabel('Cross-entropy Error')\n", 816 | " plt.xlabel('#iter')\n", 817 | " plt.show()\n", 818 | " i+=1\n", 819 | "\n", 820 | " val_loss = validate_lce(network,val_loader)\n", 821 | " history.append(val_loss)\n", 822 | "\n", 823 | " clear_output(True)\n", 824 | " plt.plot(history,label='val loss')\n", 825 | " plt.legend()\n", 826 | " plt.show()\n" 827 | ] 828 | }, 829 | { 830 | "cell_type": "code", 831 | "execution_count": 91, 832 | "metadata": {}, 833 | "outputs": [ 834 | { 835 | "data": { 836 | "image/png": "\n", 837 | "text/plain": [ 838 | "
" 839 | ] 840 | }, 841 | "metadata": {}, 842 | "output_type": "display_data" 843 | } 844 | ], 845 | "source": [ 846 | "import gc\n", 847 | "torch.cuda.empty_cache()\n", 848 | "gc.collect() \n", 849 | "train_network(network,train_loader,val_loader)" 850 | ] 851 | }, 852 | { 853 | "cell_type": "code", 854 | "execution_count": 93, 855 | "metadata": {}, 856 | "outputs": [ 857 | { 858 | "name": "stderr", 859 | "output_type": "stream", 860 | "text": [ 861 | "/usr/local/lib/python3.5/dist-packages/torch/serialization.py:193: UserWarning: Couldn't retrieve source code for container of type MHLinearGRU. It won't be checked for correctness upon loading.\n", 862 | " \"type \" + obj.__name__ + \". It won't be checked \"\n", 863 | "/usr/local/lib/python3.5/dist-packages/torch/serialization.py:193: UserWarning: Couldn't retrieve source code for container of type MultiHeadedAttention. It won't be checked for correctness upon loading.\n", 864 | " \"type \" + obj.__name__ + \". It won't be checked \"\n" 865 | ] 866 | } 867 | ], 868 | "source": [ 869 | "torch.save(network,\"network_bn_bn5mh_att_linear_ml.p\")" 870 | ] 871 | }, 872 | { 873 | "cell_type": "code", 874 | "execution_count": 7, 875 | "metadata": {}, 876 | "outputs": [], 877 | "source": [ 878 | "import numpy as np, scipy.stats as st\n", 879 | "import numpy as np\n", 880 | "import scipy as sp\n", 881 | "import scipy.stats\n", 882 | "\n", 883 | "def mean_confidence_interval(data, confidence=0.95,num_parts = 5):\n", 884 | " part_len = len(data)//num_parts\n", 885 | " estimations = []\n", 886 | " for i in range(num_parts):\n", 887 | " est = np.mean(data[part_len*i:part_len*(i+1)])\n", 888 | " estimations.append(est)\n", 889 | " a = 1.0*np.array(estimations)\n", 890 | " n = len(a)\n", 891 | " m, se = np.mean(a), scipy.stats.sem(a)\n", 892 | " h = se * sp.stats.t._ppf((1+confidence)/2., n-1)\n", 893 | " return m, h\n", 894 | "\n", 895 | "def validate_mrr(network,k,test_loader):\n", 896 | " network.eval()\n", 897 | " losses = 
[]\n", 898 | " with torch.no_grad():\n", 899 | " for user_batch_ix,item_batch_ix, mask_batch_ix in test_loader:\n", 900 | " user_batch_ix = Variable(user_batch_ix).to(device)\n", 901 | " item_batch_ix = Variable(item_batch_ix).to(device)\n", 902 | " mask_batch_ix = Variable(mask_batch_ix).to(device)\n", 903 | "\n", 904 | " logp_seq = network(user_batch_ix, item_batch_ix)\n", 905 | " # compute loss\n", 906 | " predictions_logp = logp_seq[:,-2]\n", 907 | " _,ind = torch.topk(predictions_logp, k,dim=-1)\n", 908 | " mrr = torch.zeros(predictions_logp.size())\n", 909 | " mrr.scatter_(-1,ind.cpu(),1/torch.range(1,k).repeat(*ind.size()[:-1],1).type(torch.FloatTensor).cpu())\n", 910 | " actual_next_tokens = item_batch_ix[:, -1]\n", 911 | "\n", 912 | " logp_next = torch.gather(mrr.to(device)*mask_batch_ix[:, -2,None], dim=1, index=actual_next_tokens[:,None])\n", 913 | "# if mask_batch_ix[:,-2].sum() >0:\n", 914 | " loss = logp_next.sum()/mask_batch_ix[:,-2].sum()\n", 915 | " losses.append(loss.cpu().data.numpy())\n", 916 | " torch.cuda.empty_cache()\n", 917 | " gc.collect() \n", 918 | " m, h = mean_confidence_interval(losses)\n", 919 | " return m, h\n", 920 | "\n", 921 | "def validate_recall(network,k,test_loader):\n", 922 | " \"\"\"Return (mean, confidence half-width) of Recall@k for the last item of each sequence.\"\"\"\n", 923 | " torch.cuda.empty_cache()\n", 924 | " gc.collect() \n", 925 | " network.eval()\n", 926 | " losses = []\n", 927 | " with torch.no_grad():\n", 928 | " for user_batch_ix,item_batch_ix, mask_batch_ix in test_loader:\n", 929 | " user_batch_ix = Variable(user_batch_ix).to(device)\n", 930 | " item_batch_ix = Variable(item_batch_ix).to(device)\n", 931 | " mask_batch_ix = Variable(mask_batch_ix).to(device)\n", 932 | "\n", 933 | " logp_seq = network(user_batch_ix, item_batch_ix)\n", 934 | " predictions_logp = logp_seq[:, -2]\n", 935 | " # `>=` keeps the k-th largest score itself, so k items are selected; `>` selected only k-1\n", 936 | " minus_kth_biggest_logp,_ = torch.kthvalue(-predictions_logp.cpu(), k,dim=-1,keepdim=True)\n", 937 | " predicted_top_k = (predictions_logp>=(-minus_kth_biggest_logp.to(device)))\\\n", 
938 | " .type(torch.FloatTensor).to(device)\n", 939 | " actual_next_tokens = item_batch_ix[:, -1]\n", 940 | "\n", 941 | " logp_next = torch.gather(predicted_top_k*mask_batch_ix[:, -2,None], dim=1, index=actual_next_tokens[:,None])\n", 942 | " loss = logp_next.sum()/mask_batch_ix[:,-2].sum()\n", 943 | " losses.append(loss.cpu().data.numpy())\n", 944 | " torch.cuda.empty_cache()\n", 945 | " gc.collect() \n", 946 | " m, h = mean_confidence_interval(losses)\n", 947 | " return m, h\n", 948 | "\n", 949 | "def print_scores(model,name):\n", 950 | " network = torch.load(model).to(device)\n", 951 | " mrr_score, h = validate_mrr(network,20,test_loader)\n", 952 | " print(\"MRR@20 score for \", name,\": \",mrr_score,\"±\",h)\n", 953 | " recall_score, h = validate_recall(network,20,test_loader)\n", 954 | " print(\"Recall@20 score for \"+name+\": \",recall_score,\"±\",h)\n", 955 | " " 956 | ] 957 | }, 958 | { 959 | "cell_type": "markdown", 960 | "metadata": {}, 961 | "source": [ 962 | "### Training on LastFM" 963 | ] 964 | }, 965 | { 966 | "cell_type": "code", 967 | "execution_count": 29, 968 | "metadata": {}, 969 | "outputs": [ 970 | { 971 | "data": { 972 | "text/plain": [ 973 | "0" 974 | ] 975 | }, 976 | "execution_count": 29, 977 | "metadata": {}, 978 | "output_type": "execute_result" 979 | } 980 | ], 981 | "source": [ 982 | "import gc\n", 983 | "torch.cuda.empty_cache()\n", 984 | "gc.collect() " 985 | ] 986 | }, 987 | { 988 | "cell_type": "code", 989 | "execution_count": 30, 990 | "metadata": {}, 991 | "outputs": [], 992 | "source": [ 993 | "device = torch.device(\"cuda:6\" if torch.cuda.is_available() else \"cpu\")\n", 994 | "# device = torch.device(\"cpu\")\n", 995 | "n_users = int(np.max([lf_train_users.max()+1,lf_val_users.max()+1,lf_test_users.max()])+1)\n", 996 | "n_items = int(np.max([lf_train_items.max()+1,lf_val_items.max()+1,lf_test_items.max()])+1)\n", 997 | "# with torch.cuda.device(2):\n", 998 | "network = 
MHLinearGRU(n_users=n_users,n_items=n_items).to(device)\n", 999 | "# network = torch.load(\"network_linear_lf_bs128.p\").to(device)\n" 1000 | ] 1001 | }, 1002 | { 1003 | "cell_type": "code", 1004 | "execution_count": 31, 1005 | "metadata": {}, 1006 | "outputs": [], 1007 | "source": [ 1008 | "import torch.utils.data\n", 1009 | "opt = torch.optim.SGD(network.parameters(),lr =1)\n", 1010 | "\n", 1011 | "history = []\n", 1012 | "batch_size = 100\n", 1013 | "train_loader = torch.utils.data.DataLoader(\\\n", 1014 | " torch.utils.data.TensorDataset(\\\n", 1015 | " *(torch.LongTensor(lf_train_users),torch.LongTensor(lf_train_items),torch.FloatTensor(lf_train_mask))),\\\n", 1016 | " batch_size=batch_size,shuffle=True)\n", 1017 | "\n", 1018 | "val_loader = torch.utils.data.DataLoader(\\\n", 1019 | " torch.utils.data.TensorDataset(\\\n", 1020 | " *(torch.LongTensor(lf_val_users),torch.LongTensor(lf_val_items),torch.FloatTensor(lf_val_mask))),\\\n", 1021 | " batch_size=batch_size,shuffle=True)\n", 1022 | "test_loader = torch.utils.data.DataLoader(\\\n", 1023 | " torch.utils.data.TensorDataset(\\\n", 1024 | " *(torch.LongTensor(lf_test_users),torch.LongTensor(lf_test_items),torch.FloatTensor(lf_test_mask))),\\\n", 1025 | " batch_size=batch_size,shuffle=True)\n" 1026 | ] 1027 | }, 1028 | { 1029 | "cell_type": "code", 1030 | "execution_count": null, 1031 | "metadata": {}, 1032 | "outputs": [ 1033 | { 1034 | "data": { 1035 | "image/png": "\n", 1036 | "text/plain": [ 1037 | "
" 1038 | ] 1039 | }, 1040 | "metadata": {}, 1041 | "output_type": "display_data" 1042 | } 1043 | ], 1044 | "source": [ 1045 | "# 1 lr 10 epch\n", 1046 | "train_network(network.to(device),train_loader,val_loader)" 1047 | ] 1048 | }, 1049 | { 1050 | "cell_type": "code", 1051 | "execution_count": 35, 1052 | "metadata": {}, 1053 | "outputs": [ 1054 | { 1055 | "name": "stderr", 1056 | "output_type": "stream", 1057 | "text": [ 1058 | "/usr/local/lib/python3.5/dist-packages/torch/serialization.py:193: UserWarning: Couldn't retrieve source code for container of type MHLinearGRU. It won't be checked for correctness upon loading.\n", 1059 | " \"type \" + obj.__name__ + \". It won't be checked \"\n", 1060 | "/usr/local/lib/python3.5/dist-packages/torch/serialization.py:193: UserWarning: Couldn't retrieve source code for container of type MultiHeadedAttention. It won't be checked for correctness upon loading.\n", 1061 | " \"type \" + obj.__name__ + \". It won't be checked \"\n" 1062 | ] 1063 | } 1064 | ], 1065 | "source": [ 1066 | "torch.save(network,\"network_mhatt_linear_lf_bs100_lr1.p\")" 1067 | ] 1068 | }, 1069 | { 1070 | "cell_type": "markdown", 1071 | "metadata": {}, 1072 | "source": [ 1073 | "### Scores" 1074 | ] 1075 | }, 1076 | { 1077 | "cell_type": "markdown", 1078 | "metadata": {}, 1079 | "source": [ 1080 | "#### LastFM" 1081 | ] 1082 | }, 1083 | { 1084 | "cell_type": "code", 1085 | "execution_count": 13, 1086 | "metadata": {}, 1087 | "outputs": [ 1088 | { 1089 | "name": "stdout", 1090 | "output_type": "stream", 1091 | "text": [ 1092 | "MRR@20 score for Linear User-based GRU on LastFM : 0.21047361 ± 0.006856432342802942\n", 1093 | "Recall@20 score for Linear User-based GRU on LastFM: 0.3013901 ± 0.006992072808611331\n" 1094 | ] 1095 | } 1096 | ], 1097 | "source": [ 1098 | "print_scores(\"network_linear_lf_bs128_lr1.p\",'Linear User-based GRU on LastFM')" 1099 | ] 1100 | }, 1101 | { 1102 | "cell_type": "code", 1103 | "execution_count": 12, 1104 | "metadata": {}, 1105 
| "outputs": [ 1106 | { 1107 | "name": "stdout", 1108 | "output_type": "stream", 1109 | "text": [ 1110 | "MRR@20 score for Rectified Linear User-based GRU on LastFM : 0.22049585 ± 0.004507800212955085\n", 1111 | "Recall@20 score for Rectified Linear User-based GRU on LastFM: 0.3142386 ± 0.005953415133536319\n" 1112 | ] 1113 | } 1114 | ], 1115 | "source": [ 1116 | "print_scores(\"network_rect_linear_lf_bs100_lr1.p\",'Rectified Linear User-based GRU on LastFM')" 1117 | ] 1118 | }, 1119 | { 1120 | "cell_type": "code", 1121 | "execution_count": 18, 1122 | "metadata": {}, 1123 | "outputs": [ 1124 | { 1125 | "name": "stdout", 1126 | "output_type": "stream", 1127 | "text": [ 1128 | "MRR@20 score for Attentional User-based GRU on LastFM : 0.21569276 ± 0.004434641506947223\n", 1129 | "Recall@20 score for Attentional User-based GRU on LastFM: 0.3121555 ± 0.007319821406235953\n" 1130 | ] 1131 | } 1132 | ], 1133 | "source": [ 1134 | "print_scores(\"network_att_linear_lf_bs100_lr1.p\",'Attentional User-based GRU on LastFM')" 1135 | ] 1136 | }, 1137 | { 1138 | "cell_type": "code", 1139 | "execution_count": 36, 1140 | "metadata": {}, 1141 | "outputs": [ 1142 | { 1143 | "name": "stdout", 1144 | "output_type": "stream", 1145 | "text": [ 1146 | "MRR@20 score for Multi-Head Attentional User-based GRU on LastFM : 0.21176393 ± 0.005497847311917814\n", 1147 | "Recall@20 score for Multi-Head Attentional User-based GRU on LastFM: 0.30064663 ± 0.008272520008487811\n" 1148 | ] 1149 | } 1150 | ], 1151 | "source": [ 1152 | "print_scores(\"network_mhatt_linear_lf_bs100_lr1.p\",'Multi-Head Attentional User-based GRU on LastFM')" 1153 | ] 1154 | }, 1155 | { 1156 | "cell_type": "markdown", 1157 | "metadata": {}, 1158 | "source": [ 1159 | "#### MovieLens" 1160 | ] 1161 | }, 1162 | { 1163 | "cell_type": "code", 1164 | "execution_count": 46, 1165 | "metadata": {}, 1166 | "outputs": [ 1167 | { 1168 | "name": "stdout", 1169 | "output_type": "stream", 1170 | "text": [ 1171 | "MRR@20 score for Linear 
User-based GRU on MovieLens : 0.057497583 ± 0.0007039992875517064\n", 1172 | "Recall@20 score for Linear User-based GRU on MovieLens: 0.18227212 ± 0.003842047864645815\n" 1173 | ] 1174 | } 1175 | ], 1176 | "source": [ 1177 | "print_scores(\"network_linear_ml.p\",'Linear User-based GRU on MovieLens')" 1178 | ] 1179 | }, 1180 | { 1181 | "cell_type": "code", 1182 | "execution_count": 52, 1183 | "metadata": {}, 1184 | "outputs": [ 1185 | { 1186 | "name": "stdout", 1187 | "output_type": "stream", 1188 | "text": [ 1189 | "MRR@20 score for Rectified Linear User-based GRU on MovieLens : 0.060260046 ± 0.0005038204951150433\n", 1190 | "Recall@20 score for Rectified Linear User-based GRU on MovieLens: 0.19106193 ± 0.004563955364643368\n" 1191 | ] 1192 | } 1193 | ], 1194 | "source": [ 1195 | "print_scores(\"network_rect_linear_ml.p\",'Rectified Linear User-based GRU on MovieLens')" 1196 | ] 1197 | }, 1198 | { 1199 | "cell_type": "code", 1200 | "execution_count": 58, 1201 | "metadata": {}, 1202 | "outputs": [ 1203 | { 1204 | "name": "stdout", 1205 | "output_type": "stream", 1206 | "text": [ 1207 | "MRR@20 score for Attentional User-based GRU on MovieLens : 0.065337464 ± 0.0021981462220191897\n", 1208 | "Recall@20 score for Attentional User-based GRU on MovieLens: 0.1987088 ± 0.0027508449974647546\n" 1209 | ] 1210 | } 1211 | ], 1212 | "source": [ 1213 | "print_scores(\"network_att_linear_ml.p\",'Attentional User-based GRU on MovieLens')" 1214 | ] 1215 | }, 1216 | { 1217 | "cell_type": "code", 1218 | "execution_count": 32, 1219 | "metadata": {}, 1220 | "outputs": [ 1221 | { 1222 | "name": "stdout", 1223 | "output_type": "stream", 1224 | "text": [ 1225 | "MRR@20 score for Multi-Head Attentional User-based GRU on MovieLens : 0.06202981 ± 0.0017889912578660427\n", 1226 | "Recall@20 score for Multi-Head Attentional User-based GRU on MovieLens: 0.20207629 ± 0.005651086872747263\n" 1227 | ] 1228 | } 1229 | ], 1230 | "source": [ 1231 | 
"print_scores(\"network_bn5mh_att_linear_ml.p\",'Multi-Head Attentional User-based GRU on MovieLens')" 1232 | ] 1233 | }, 1234 | { 1235 | "cell_type": "code", 1236 | "execution_count": null, 1237 | "metadata": {}, 1238 | "outputs": [], 1239 | "source": [] 1240 | } 1241 | ], 1242 | "metadata": { 1243 | "kernelspec": { 1244 | "display_name": "Python 3", 1245 | "language": "python", 1246 | "name": "python3" 1247 | }, 1248 | "language_info": { 1249 | "codemirror_mode": { 1250 | "name": "ipython", 1251 | "version": 3 1252 | }, 1253 | "file_extension": ".py", 1254 | "mimetype": "text/x-python", 1255 | "name": "python", 1256 | "nbconvert_exporter": "python", 1257 | "pygments_lexer": "ipython3", 1258 | "version": "3.5.2" 1259 | } 1260 | }, 1261 | "nbformat": 4, 1262 | "nbformat_minor": 2 1263 | } 1264 | --------------------------------------------------------------------------------