├── .gitignore
├── LICENSE
├── README.md
├── datasets
│   └── sample
│       ├── num_items.txt
│       ├── test.txt
│       └── train.txt
├── lessr.py
├── main.py
├── packages.yml
├── preprocess.py
└── utils
    ├── data
    │   ├── collate.py
    │   ├── dataset.py
    │   └── preprocess.py
    └── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /.vscode/
2 | __pycache__/
3 | .ipynb_checkpoints/
4 | /datasets/*
5 | !/datasets/sample
6 | /.ignored/
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Tianwen CHEN
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LESSR
2 | A PyTorch implementation of LESSR (**L**ossless **E**dge-order preserving aggregation and **S**hortcut graph attention for **S**ession-based **R**ecommendation) from the paper:
3 | _Handling Information Loss of Graph Neural Networks for Session-based Recommendation, Tianwen Chen and Raymond Chi-Wing Wong, KDD '20_
4 |
5 | ## Requirements
6 | - PyTorch 1.6.0
7 | - NumPy 1.19.1
8 | - Pandas 1.1.3
9 | - DGL 0.5.2
10 |
11 | ## Usage
12 | 1. Install the requirements.
13 | If you use Anaconda, you can create a conda environment with the required packages using the following command.
14 | ```sh
15 | conda env create -f packages.yml
16 | ```
17 | Activate the created conda environment.
18 | ```
19 | conda activate lessr
20 | ```
21 |
22 | 2. Download and extract the datasets.
23 | - [Diginetica](https://competitions.codalab.org/competitions/11161)
24 | - [Gowalla](https://snap.stanford.edu/data/loc-Gowalla.html)
25 | - [Last.fm](http://ocelma.net/MusicRecommendationDataset/lastfm-1K.html)
26 |
27 | 3. Preprocess the datasets using [preprocess.py](preprocess.py).
28 | For example, to preprocess the *Diginetica* dataset, extract the file *train-item-views.csv* to the folder `datasets/` and run the following command:
29 | ```sh
30 | python preprocess.py -d diginetica -f datasets/train-item-views.csv
31 | ```
32 | The preprocessed dataset is stored in the folder `datasets/diginetica`.
33 | You can see the detailed usage of `preprocess.py` by running the following command:
34 | ```sh
35 | python preprocess.py -h
36 | ```
37 |
38 | 4.
Train the model using [main.py](main.py). 39 | If no arguments are passed to `main.py`, it will train a model using a sample dataset with default hyperparameters. 40 | ```sh 41 | python main.py 42 | ``` 43 | The commands to train LESSR with suggested hyperparameters on different datasets are as follows: 44 | ```sh 45 | python main.py --dataset-dir datasets/diginetica --embedding-dim 32 --num-layers 4 46 | python main.py --dataset-dir datasets/gowalla --embedding-dim 64 --num-layers 4 47 | python main.py --dataset-dir datasets/lastfm --embedding-dim 128 --num-layers 4 48 | ``` 49 | You can see the detailed usage of `main.py` by running the following command: 50 | ```sh 51 | python main.py -h 52 | ``` 53 | 54 | 5. Use your own dataset. 55 | 1. Create a subfolder in the `datasets/` folder. 56 | 2. The subfolder should contain the following 3 files. 57 | - `num_items.txt`: This file contains a single integer which is the number of items in the dataset. 58 | - `train.txt`: This file contains all the training sessions. 59 | - `test.txt`: This file contains all the test sessions. 60 | 3. Each line of `train.txt` and `test.txt` represents a session, which is a list of item IDs separated by commas. Note the item IDs must be in the range of `[0, num_items)`. 61 | 4. See the folder [datasets/sample](datasets/sample) for an example of a dataset. 62 | 63 | ## Citation 64 | If you use our code in your research, please cite our [paper](http://home.cse.ust.hk/~raywong/paper/kdd20-informationLoss-GNN.pdf): 65 | ``` 66 | @inproceedings{chen2020lessr, 67 | title="Handling Information Loss of Graph Neural Networks for Session-based Recommendation", 68 | author="Tianwen {Chen} and Raymond Chi-Wing {Wong}", 69 | booktitle="Proceedings of the 26th ACM SIGKDD Conference on Knowledge Discovery and Data Mining (KDD '20)", 70 | pages="1172-–1180", 71 | year="2020" 72 | } 73 | ``` 74 | -------------------------------------------------------------------------------- /datasets/sample/num_items.txt: -------------------------------------------------------------------------------- 1 | 3429 -------------------------------------------------------------------------------- /datasets/sample/test.txt: -------------------------------------------------------------------------------- 1 | 2090,2028,2090,2276 2 | 1070,955 3 | 1157,1821 4 | 1549,1549 5 | 1631,1833,616,1036,1036,1833 6 | 2243,2234 7 | 761,1669,1105,544 8 | 2802,2015,2065 9 | 538,1944,538 10 | 2196,1294,1551,485 11 | 874,3226,222,20,659,223 12 | 1795,618,1382 13 | 260,183 14 | 332,1407,1216,1255,1903 15 | 3405,2476 16 | 2528,2486 17 | 2970,1698 18 | 2853,2853,2016 19 | 2249,2362,2599 20 | 1623,235,942,1213,234,2665 21 | 216,231,3,1 22 | 1133,290 23 | 953,952,954 24 | 1773,1773 25 | 3305,3304,3305,3304 26 | 2541,2541 27 | 40,1134,2751 28 | 2496,2317,2846,1852,1991,2569 29 | 2199,629 30 | 3180,505 31 | 2943,2944,2945,2691 32 | 3383,3383,3383 33 | 2718,2717,2718,2718 34 | 1526,62 35 | 1785,33,377,202,3287,1900 36 | 323,485,332,1596 37 | 1951,1894 38 | 3007,2378 39 | 1283,1282 40 | 597,2875 41 | 1952,1952,1842 42 | 2476,2256 43 | 3027,3028,2902,3132,3027,3028,1120,3028,3028,3027,3028,2902,3132 44 | 1646,1625,1624 45 | 1019,2086 46 | 1425,119 47 | 2438,2497 48 | 377,396 49 | 1538,3040,2954,1283 50 | 3385,3385 51 | 2971,1149 52 | 664,578,260,355,1963,182,181,2332,664,580,1828,2946,664,1310,438,748 53 | 1526,741 54 | 311,2393 55 | 1238,1238 56 | 1912,2068 57 | 1954,3167,1954 58 | 2685,1845,1987,2142,2685,2142,2685,2142 59 | 202,649 60 | 
2057,930,2057,1135,890,121,1265,1152,121,3037 61 | 1841,1841 62 | 979,1164 63 | 2034,1858,2404 64 | 2441,2526 65 | 2466,2035 66 | 2232,1440 67 | 2729,2729 68 | 1859,2528,2561,2511 69 | 1031,2060,2061 70 | 3326,3326,3326,3326,3326,3326,3326 71 | 2450,2314 72 | 1373,1485 73 | 1581,1736 74 | 3412,3412,3400 75 | 1661,1960 76 | 1470,1459 77 | 2958,3344 78 | 2304,2299 79 | 3055,2979 80 | 3287,458 81 | 1274,394,255,1155 82 | 485,2986 83 | 2401,2304,2370,2560,1852 84 | 1898,33 85 | 1211,2920,2920 86 | 2197,732 87 | 1149,1149,1149 88 | 1204,1204 89 | 2054,891,631,1019 90 | 394,2049 91 | 3386,3386 92 | 115,3037 93 | 3402,3402 94 | 440,352 95 | 1721,1720,1721,2022,3051 96 | 3067,485,485 97 | 3260,3302,3302 98 | 2203,2203 99 | 1270,1269 100 | 2984,2984 101 | 179,2473 102 | 285,2062,116,94 103 | 1149,2000 104 | 3380,3381,3380,3381 105 | 1086,1535 106 | 116,1091 107 | 1690,2059,3031,1322,2060,1690 108 | 2755,896,1733,115 109 | 161,55,58,54,56,53,164,162,163,59,57 110 | 2241,2242,2307,2311,2242 111 | 785,785,786,1050 112 | 732,86,732,713,85,732,86,713,732,85 113 | 1517,1331 114 | 1653,1014,761 115 | 3021,3021 116 | 2973,3149,3218,2791,3218,3218 117 | 729,540,727,728,730 118 | 2958,2957,2958,2958,2957,3168,2957 119 | 240,1327 120 | 2803,3056,2501,2803,2035,2803 121 | 343,3379,199 122 | 3287,39 123 | 1149,1757 124 | 1274,124 125 | 49,1090 126 | 2549,1795 127 | 860,1596 128 | 2062,1463,1495 129 | 3198,1375,3198,1375,3198,1375,3198,51 130 | 250,9 131 | 2635,2629,2635,3125,2635,2635 132 | 485,1986 133 | 1424,1424 134 | 2739,3000 135 | 3053,3053 136 | 1635,1634 137 | 1167,735 138 | 1423,1423,631 139 | 1360,487 140 | 708,1820 141 | 1495,631,711,1423,1495 142 | 1596,711 143 | 2588,2077,2497 144 | 3395,3395 145 | 3162,3162 146 | 2487,2591 147 | 3264,3265 148 | 2037,2037 149 | 1448,1844,1602 150 | 1666,1451,1665,568,569 151 | 2378,1166 152 | 1116,1985,568 153 | 33,1360,377,459,368,33,459,1360,459,1360,487,33 154 | 46,20,748,20,1618 155 | 1625,988 156 | 1457,1460,1457,1460 157 | 2647,20 158 | 890,240,550 159 | 1076,62,2828 160 | 1785,1426 161 | 22,3013 162 | 1493,860,1470 163 | 187,1618,3016 164 | 568,1602,1602,1465 165 | 1134,1617 166 | 3413,3413 167 | 2767,1266,2767,1266,2767 168 | 83,83 169 | 1218,1321,649,1044,2552 170 | 1031,330,2513,2059 171 | 358,1555,1556,358,438 172 | 116,61,749 173 | 2651,945,918,918,1175,1265 174 | 1550,1314 175 | 629,499,629 176 | 3219,1382,276,887,486,2384 177 | 499,1624 178 | 1068,1577,1393,40,241,1068 179 | 46,992,1547 180 | 1382,1675 181 | 1583,510,890 182 | 1048,1108,841,241,516,2975,2056,841,491 183 | 792,1833 184 | 1087,1195,1087,1195,1087 185 | 2832,1019 186 | 1915,571 187 | 500,2474 188 | 1045,380,3135,1321,27,3234 189 | 1752,1348,1751,1611,629,40 190 | 486,420,421,959 191 | 3274,3274 192 | 1806,2713,2713,2713 193 | 1265,1172 194 | 1822,816 195 | 3029,618 196 | 1076,538 197 | 1227,15,1041 198 | 739,1646 199 | 3289,988 200 | 2508,2438 201 | 2246,1990,2028,2486,2246,2527,2948,350 202 | 2615,2615 203 | 468,890,1835,2874,2873,1303,509 204 | 622,886,1541,625,1544,2926 205 | 2761,2760,2761,2760,2761,1859,1854 206 | 2853,2853,2032 207 | 2438,2256 208 | 2839,2839 209 | 2958,2958,2958 210 | 3107,2844,2043 211 | 2476,3382,2476,3382,3382,3006,2476,3382,2608,2476,3382,2476,2476 212 | 3401,3401 213 | 480,1138,199 214 | 2692,1950,2692 215 | 2853,2300,2302,2301,2301,2302,2301,2300 216 | 3129,2477,2448,2034 217 | 2697,2809,1324,1324,2697 218 | 2311,2297,2307 219 | 2263,2177,2246 220 | 2980,350,2246,2486 221 | 2791,2791,2790,2791,2790,1538,2790,1538,2791,2791,2790 222 | 3106,2564 223 | 
3142,3142 224 | 2564,1859,2723,2599 225 | 3074,3074,3074,3074,3073,2321,2588 226 | 1529,1530,1529,1417,1418,1530,1417,1529,1418,1530 227 | 2134,2273,2801,2276,2136,2135,2064,2134 228 | 2951,2950 229 | 3376,3376,3376 230 | 2370,1859 231 | 597,1229,597,632 232 | 2916,2319 233 | 2077,3125,2744,2139 234 | 2526,2486,1859 235 | 2438,1772,2892 236 | 3186,3186 237 | 3008,3008,3008 238 | 3064,2404,2402 239 | 1885,2642 240 | 2483,2921 241 | 2851,2327 242 | 3408,3408,1948,1919,890,3408,1317,1027,3408,3029,1920 243 | 3262,1814,1814,3262,1814,1814,1816 244 | 2786,2786 245 | 2177,1859 246 | 2462,2043,2043,2928 247 | 3103,3104 248 | 2181,2182,2181,2181,2181,2888,2181 249 | 351,2246,2568 250 | 1775,1776,1775 251 | 2854,311,2854,311,311 252 | 2378,1423,2825,3296 253 | 2774,3014,2774,2774 254 | 2032,1859,2032,1859 255 | 3258,3259,3259,3259 256 | 2246,3399,2246,2370,2443,3399,2443,2238,2237,2275,2464,3399 257 | 2258,2258 258 | 268,1862,1861,1862,268 259 | 3071,3071 260 | 534,1859,3261,534,1859,3261,1859,534,1859,3261,534 261 | 1480,2456,1761,2935,2935,2456,2935 262 | 1385,2874,3037 263 | 3387,597 264 | 13,2824 265 | 2438,2077 266 | 3256,3256,3256,3256 267 | 1956,1712 268 | 3396,3318 269 | 3052,2450,3052,3052,2800,3052,2800 270 | 2336,2336 271 | 2157,3277 272 | 650,909,651,909,651 273 | 2147,2304,351 274 | 105,3037 275 | 1029,2386 276 | 1507,1505,3026 277 | 2715,1742 278 | 3094,3094 279 | 631,267 280 | 1000,2820 281 | 581,941,433,581,941,78,581,941,433,581,941 282 | 2599,2486 283 | 3240,3241,1721,2024,3012,2672 284 | 1859,2527 285 | 2975,1159 286 | 1524,1524,1524,1587 287 | 1259,46,1785,202,201,1259,62,1597 288 | 649,1227,15,1041,649 289 | 1134,1624 290 | 1837,3029 291 | 3349,3350 292 | 332,485 293 | 1742,3076,3075,3076,3075,3075,3076,1742,1744 294 | 2064,2248 295 | 1898,2155,1972,1781,1898 296 | 1700,1386 297 | 800,800,675 298 | 1690,1322,2060,2061,2060 299 | 94,116,923 300 | 3341,3341,3341,3341,3341 301 | 631,1289 302 | 116,963,116,1091,2005 303 | 1517,485,618,959,2040,419,961 304 | 3264,3264,3264,3265,3264,3265,3264,3265,1549 305 | 685,61,181 306 | 3030,3188,3030,3188,3030 307 | 49,2011 308 | 1868,1868,1868 309 | 1278,819 310 | 2650,2903,2650,2649,2650 311 | 643,499 312 | 316,844,69 313 | 1598,704,428,3275,807,704,740,1675 314 | 1548,42,3404,1942,2620 315 | 3303,3303,3303,3303,3303 316 | 817,839,1077,1550 317 | 1255,1255,1903 318 | 143,142 319 | 182,181,380,61,2204,62 320 | 2770,2768 321 | 988,1687 322 | 484,150,195,3,0,1,72,1676 323 | 3365,3364,3365,3364,3365,3365,3364,3365,3365,3364,3365,3364 324 | 3275,807,799 325 | 62,1833 326 | 1278,2939 327 | 3354,848,2967,3354,380,664 328 | 1459,1459,277,988 329 | 1896,2408,1896 330 | 115,2620,1290 331 | 3324,3324 332 | 3017,3017 333 | 65,7,2705,487,2705 334 | 2782,1675 335 | 601,49,374,49,733,192,49 336 | 1689,2080,1030,1063 337 | 1013,1569 338 | 621,622,1541,886,519 339 | 523,485,1582,1585 340 | 2876,1222,3269,2876,3269,3255 341 | 3370,767,1983,3370,833,3370,1983,3370,1239 342 | 649,359,649,182,355 343 | 1792,1792,1629,1792 344 | 649,194,649,260,363,321,1218,194,649,1828,1138,209,649 345 | 112,3110,3287,749 346 | 3172,3172,3172 347 | 258,2155,65,718,65,718,65,718,7,65,718,473,20 348 | 3131,3131,3131 349 | 1618,209 350 | 1149,2971,2000,1757,2971,2000 351 | 2911,2912 352 | 597,547 353 | 3038,465,3038,499,890 354 | 837,837,838,838,784,837,837 355 | 2895,3185 356 | 486,1036,1964 357 | 2473,664,182,2812,578,184,664,1640 358 | 1742,1744 359 | 3177,3178,3178,3177,3177,3172,3178,3172,3177 360 | 1093,1771 361 | 1553,2210,697,947,564,1726 362 | 3294,3294 363 | 2832,597 364 
| 103,2786 365 | 2517,311 366 | 485,3029,485,2870,485,1459 367 | 276,276 368 | 3236,1681,3236 369 | 2215,2216,2214,2215,2214,2216,2215,2939 370 | 160,2862,2863,2863,6 371 | 2656,515,283,294,490 372 | 3330,3190,3190 373 | 197,196 374 | 2735,2658,541,2210,826 375 | 1736,1797,1797 376 | 704,781 377 | 2326,2326 378 | 1626,2937,2620,1589,2620,2620,1589 379 | 1083,2423,408,2131,1559 380 | 568,761,824 381 | 317,69 382 | 3291,3291,3291 383 | 1759,1561,1759,1759,2379,2656 384 | 1510,1399,713 385 | 946,85 386 | 485,1459 387 | 1844,1449 388 | 2823,2823,601 389 | 1584,551,3227 390 | 3,165 391 | 388,3379,388,3379 392 | 3412,3412,1139 393 | 1038,2711 394 | 3182,3182,656,3182 395 | 491,1253 396 | 2705,343,378,201,379,429,61 397 | 475,2946,112 398 | 568,569,1650,1918 399 | 3037,958 400 | 887,311,1463 401 | 46,250,1898,62 402 | 3291,3291,3314 403 | 3422,3422 404 | 1258,1257 405 | 2439,93,2439,2439,3029 406 | 2200,1837,1708,1331 407 | 664,1812 408 | 1449,1448 409 | 2825,486 410 | 552,1801 411 | 309,2820 412 | 2740,2829,2739,3001,2740,2830,2740 413 | 47,20,3135 414 | 958,959,421 415 | 1056,2200 416 | 294,295 417 | 1265,2379,2651,900 418 | 961,420,1469,417 419 | 747,455,1780 420 | 485,3029,1019 421 | 616,2815 422 | 2930,2930,3029 423 | 1092,277 424 | 857,631 425 | 2079,2060,1058 426 | 3279,595 427 | 792,1833 428 | 1162,547 429 | 537,536 430 | 3137,2914,3137,2914 431 | 1707,3389,1707,3389,1707,3389,1707,1803,3389 432 | 3278,1162 433 | 2051,2051,2052 434 | 1695,1339,2208 435 | 927,211 436 | 2136,1859 437 | 1327,1027 438 | 885,885,885,885 439 | 1871,1871,2525 440 | 1049,1851,1036,332,887 441 | 1290,1075 442 | 497,456,1933 443 | 843,3025 444 | 1332,3029 445 | 3103,3105,2990,2990,1480 446 | 2116,2581 447 | 2364,2362 448 | 1873,1873,1874,1873,1873,1873 449 | 1742,1679,3218 450 | 1852,2486,2846 451 | 3277,1164 452 | 2353,2353 453 | 2791,1538,3001 454 | 3360,2378 455 | 2976,1422 456 | 2535,2535,2534,2535,2535 457 | 1105,566 458 | 2342,2342 459 | 2486,1858,2801 460 | 3419,2096 461 | 1149,1149 462 | 3391,1510,571,3391 463 | 2697,1324 464 | 2930,2930,2870 465 | 2930,2930,2930,1220 466 | 546,546,546,1830,1394,571,1754,1830,1394 467 | 3393,3393 468 | 3392,3392,3392,3392 469 | 3276,3276,3276 470 | 441,2194,780,335 471 | 438,1310,1556,2756,204,3115,580 472 | 3026,1505,3026,1505,3026,1505,1239 473 | 1138,62 474 | 2002,485,2870,3029 475 | 2428,2428,2428,2428,2428,2428 476 | 664,183 477 | 659,223,448,2282 478 | 3264,3264,3265,3264,3264,3265,3264,3264 479 | 3356,3356,877,876,3356 480 | 1053,1804,2057,2131,1431,471,2131,1083,897,1904,485 481 | 1429,1430,374,264,474,1628,1429,1429,1429,616,1493 482 | 2197,2197,1013,2197 483 | 2657,485,1337 484 | 486,486,988,3289 485 | 580,1828 486 | 3251,978,1605,3252,978,3251,1605,3252,1605,978,3251,978,3252,1605,3251 487 | 1,139,983,670 488 | 745,211 489 | 2095,1324,1324 490 | 185,2988,2946,1556,1557,438,193,2988,46 491 | 633,559,561,559 492 | 1987,1711,2328,726,2965 493 | 221,1347,212,221 494 | 1526,1420 495 | 2126,2021,1878 496 | 1632,505 497 | 391,211 498 | 1419,1419 499 | 3360,3361,3360,3361 500 | 1713,1714,1713,2143,1571,1579,1915,712 501 | 1008,1007,1008 502 | 1337,597,2875 503 | 659,223,429,3115 504 | 345,171 505 | 616,1596 506 | 3011,1817 507 | 897,485,309 508 | 1022,475 509 | 185,3226,2902,209,2552 510 | 968,474,3378,3378,3378,3378,3378 511 | 770,770 512 | 3151,3185,3152,3152,3151,3185 513 | 1422,1422 514 | 862,485 515 | 1115,1114 516 | 3236,523,3270,1382 517 | 2289,1306,236 518 | 891,190,663,2686 519 | 335,335,208 520 | 1689,2080,1690 521 | 3118,1728,182,2946 522 | 62,194 523 | 
2670,2671,2670,2671,2670,2671,2670,2671 524 | 1547,182 525 | 293,3273,843 526 | 2055,1828,1405 527 | 241,241,241,1701,241,2213,1054,1992,283 528 | 1986,485 529 | 96,1769 530 | 1277,2218 531 | 629,1752 532 | 459,2282,2283 533 | 896,2755,1752,643 534 | 3286,1700,3286,103,3286,3286 535 | 1901,1901 536 | 2926,886 537 | 2663,2662,2750,2517 538 | 967,2516,1926 539 | 2881,2930,2881,2930,1290 540 | 3322,3322,3322,3321 541 | 93,75,2657,1134,1517,1560,2619,3061,2619 542 | 641,641 543 | 3273,241,1084,1026,476,255 544 | 2008,2338,2008 545 | 41,1393 546 | 143,2235,798 547 | 3313,3312,3313,3313 548 | 2657,256,485,3029,864,2651,925,945,2657,40,241,1055,548 549 | 1312,3325 550 | 410,1290,1265 551 | 2712,1356,1356,1356,1356,1888 552 | 20,1228,20,46 553 | 753,823 554 | 1624,2782 555 | 2708,3093,2708,3143,3093 556 | 538,2664,219 557 | 710,832,710,832 558 | 3278,62,27 559 | 2232,2195 560 | 292,2056,940,238,2929,292,240,292 561 | 3015,2831,121,1554,1301,901,3015,154,2657,40,549 562 | 2681,119,2681 563 | 2080,1689,1690,2513,2059 564 | 1229,631 565 | 3264,3283,3264,3264,3265,3283,3264,3283,3264,3283,3283 566 | 927,1265,2047,2689,40,2657,1365,1586 567 | 1959,3289,988 568 | 479,739,1494,739,479,3067,1076 569 | 485,311 570 | 2203,476,485,312 571 | 1493,1526 572 | 2749,309 573 | 1158,245,2042,563 574 | 5,291 575 | 504,810,511,3180 576 | 1805,1805,1805,1805,1805,1805,1805,1805 577 | 2993,2993,1959,3120,798,3120 578 | 1742,2586,1479 579 | 1149,1757 580 | 1011,1633,1011,1633,1115 581 | 3072,3072,3072 582 | 1748,1748,1748,1748,1749 583 | 3254,3196,3197,3196,3254,3196,3254 584 | 950,949,949,949 585 | 3415,3415,3415,3415 586 | 503,490,506,606,547,267 587 | 3029,2870 588 | 616,616,1545 589 | 62,406 590 | 988,1675 591 | 2540,2540,2540,2540 592 | 2226,2226 593 | 1475,40,2657,282,2078 594 | 2960,3301,3184,62,185 595 | 3184,62,185 596 | 1480,1744 597 | 1061,2061,1058,1185,1058,1057,1325,2122,1064,331,1690,329,1058 598 | 792,40,925 599 | 2261,1115 600 | 1420,538 601 | 313,332,3160 602 | 573,574,574 603 | 2808,2807 604 | 311,311 605 | 1820,858,858 606 | 2051,2052 607 | 2509,2030,2177,2028 608 | 977,1537,977,1537,1538,1537,1283,1537,3001 609 | 792,3277 610 | 1463,2157,3014 611 | 1382,1549 612 | 1648,1565 613 | 1690,1060 614 | 957,166,417 615 | 2208,2093,1365,128,2093,2093,1431 616 | 890,1586 617 | 1079,105,105 618 | 2986,285 619 | 764,890 620 | 2714,3243 621 | 3298,3300,3298 622 | 3040,1538,3040 623 | 2065,2357,1853,2219,3124,2478 624 | 976,976 625 | 1860,1859 626 | 3201,3047,3201,3201,3048,3047 627 | 2398,2113 628 | 2527,2033,2035,2035,2246 629 | 20,473,46,20,250,659,20,223,659,20,62,20 630 | 1424,1424 631 | 1842,2296 632 | 2585,2170,2028,2238 633 | 3371,2599,2029 634 | 1641,882 635 | 2453,2154 636 | 3055,3055,2065 637 | 2482,2362,2304 638 | 1641,882 639 | 2246,2362 640 | 2066,2104 641 | 1600,1601 642 | 375,919,632 643 | 2276,2276,3417,3417,3417 644 | 1925,1740,2999,2935,1480 645 | 3124,2422,1991,2419,2422 646 | 1778,3400 647 | 338,336 648 | 2163,200 649 | 1842,2255,1842,1954 650 | 485,2384 651 | 1859,2383 652 | 1627,1297 653 | 860,2412 654 | 2889,1859 655 | 1006,1612,1613,1006,425 656 | 2147,2304 657 | 2098,2110 658 | 485,861 659 | 578,1640 660 | 1795,3037 661 | 2599,2771 662 | 485,3029,2062,285,2918,2854,311,1289,1289 663 | 2444,2359 664 | 190,663,676 665 | 3187,105 666 | 1170,1170 667 | 1215,1212 668 | 1495,2825 669 | 2256,1772,3157,2588 670 | 3383,3383 671 | 1051,40,1134,1265,1208,795 672 | 2975,1253,1558 673 | 1479,2120 674 | 3310,3309,3310,3344,3310 675 | 357,1828,182,1963,1556 676 | 618,618 677 | 311,1717 678 | 214,136 
679 | 124,154 680 | 1546,1208,1025 681 | 2517,2750,2662,2663 682 | 1888,1548,371,1888 683 | 1091,2010 684 | 2388,1132,844,107,1307,5,2909 685 | 2384,62 686 | 1569,1653,1569,920 687 | 1496,958,959,166,958,808,167,960,420 688 | 2095,2095 689 | 3279,2156,1384,2975,2436 690 | 3219,1382,311 691 | 824,2278 692 | 741,860 693 | 1115,2955,1115 694 | 3037,523 695 | 244,900,509,1646,1941,2832 696 | 1690,1064,2869 697 | 3378,3378 698 | 717,2083 699 | 185,552,1173 700 | 412,1175,123,104 701 | 3369,3369 702 | 3235,241,3235,3235,241 703 | 1889,2815,2815,887 704 | 1974,1974 705 | 604,256 706 | 3179,2875 707 | 438,223,659,20 708 | 260,1219 709 | 112,1618 710 | 2657,1051,1134,2689,41,40,132,595,1108 711 | 2304,2803 712 | 186,2089 713 | 780,2518 714 | 962,485 715 | 1134,241,256,1791 716 | 2952,2952 717 | 1724,622,1651 718 | 988,285 719 | 1336,2138 720 | 2598,2868,32,2922,14,13,12,15,380,457,1227,447 721 | 1254,722 722 | 3147,3147,3147 723 | 2964,1838,1838,890 724 | 1590,586 725 | 1289,1289,2795,1860,2296 726 | 3419,2096 727 | 2223,2120 728 | 3045,1164 729 | 3334,3333 730 | 872,1711 731 | 2317,2317 732 | 2791,1282,1538,1679,2715,3170 733 | 2531,2139 734 | 1493,988 735 | 1671,1671,1671,1671,1671,1671,1671 736 | 1389,1390,1388 737 | 1215,1213 738 | 2531,2733,2584 739 | 2815,3277 740 | 2315,2522 741 | 2798,2928 742 | 2443,2467,2137,2597 743 | 2467,2137,2597 744 | 493,216 745 | 2443,2370,2246,3399,2246,3399 746 | 3149,3149,3040 747 | 27,1443 748 | 213,1347 749 | 1357,792 750 | 3344,2066 751 | 5,5 752 | 1854,1859 753 | 2742,2865 754 | 782,1944 755 | 28,1033 756 | 51,299 757 | 3008,3008 758 | 1463,2157,3014,2157,860 759 | 666,555,666,702,666,666,1932 760 | 1549,485 761 | 486,740 762 | 309,597 763 | 1795,332 764 | 3072,1744 765 | 1036,1653 766 | 2181,2182,2181,2182,2181,2182,2181,2182 767 | 1333,1547,458,12,1138,687,377,376,34 768 | 3395,3395 769 | 1987,1711 770 | 2976,538 771 | 323,2417 772 | 2427,2337 773 | 2975,516 774 | 2815,2815,887,1889 775 | 238,842 776 | 1495,62 777 | 708,1642,807 778 | 1859,2397,3384 779 | 511,1832,1026 780 | 575,34,35,36,37,31,2900 781 | 2412,860,2543 782 | 1324,1324,2872 783 | 473,481,204,3287,577 784 | 1146,314 785 | 27,2552,62,973,20 786 | 1675,1675,3272,277,277,547,1447,3293,3292 787 | 1577,861 788 | 2080,2432 789 | 309,313 790 | 1556,1557 791 | 311,62 792 | 792,792,792 793 | 241,2387,1582,1489,2213 794 | 1252,104,2380 795 | 3131,3131 796 | 2113,2588 797 | 3140,3140 798 | 1771,602 799 | 3106,1859 800 | 1378,1913 801 | 1625,3342,629,2657,1265 802 | 448,447,2660 803 | 1369,2352,1303 804 | 277,80,81,82,277 805 | 890,1372 806 | 631,597 807 | 313,485 808 | 1254,721 809 | 2235,2235 810 | 925,511,512,506,1025,890,1643,511 811 | 1579,826,1569,1013 812 | 1651,1648 813 | 887,1625 814 | 1493,486 815 | 824,1602,1579 816 | 1372,2937 817 | 3154,3155,3154,3155,3154,3155 818 | 1241,209 819 | 1995,1677 820 | 1146,1132,844 821 | 33,405,1824,775,14,12,200,20 822 | 616,62 823 | 1301,41,890,1301 824 | 3229,3229 825 | 631,1718 826 | 1258,308 827 | 260,1244 828 | 2131,1056 829 | 237,283,2213,1884 830 | 3404,2436 831 | 2679,1172,2459 832 | 1636,83,149 833 | 2209,1799,2209,1799 834 | 1493,616 835 | 1690,1062,1030 836 | 2782,1230,309,1337,792,1577,1979,1083 837 | 547,1162 838 | 887,988,1493 839 | 2615,2615 840 | 1027,861 841 | 2815,616,793 842 | 1302,124,40,1208,2047,851 843 | 276,276,486,1596 844 | 957,958 845 | 858,2940,3275,538 846 | 1228,362 847 | 1795,2432,1113,1065,1113,1031 848 | 1545,311,2854,311,312,311,1596 849 | 2687,547,485 850 | 2933,2395,2254 851 | 131,294,550,295,255 852 | 993,911,2127 853 | 
1394,634,946 854 | 2438,2077,3381,2032 855 | 6,2058 856 | 1399,1400,1569 857 | 1378,1378 858 | 311,1382,3236,1742 859 | 237,238,237,284,283,2125,890,1201 860 | 1149,2839 861 | 2246,2153 862 | 2256,1772,2438,3375,1772,1164,1913 863 | 3272,1382 864 | 1873,1873,1873 865 | 3078,3077,1607,1606,1609 866 | 1229,2682 867 | 1550,672 868 | 1212,1214,1441 869 | 616,1596 870 | 2476,2035 871 | 1076,538 872 | 1526,860,2330,1962 873 | 1140,3400,3400 874 | 1029,742 875 | 682,589 876 | 2043,2043 877 | 2095,1324,2095 878 | 2403,2031,2254,2137,2403,2527,2147 879 | 2561,1859 880 | 909,652,651 881 | 2070,2416 882 | 2029,2105,2642,2889,2510,1859 883 | 631,2412 884 | 1278,1590,313,394 885 | 946,696,1174,1750,1174 886 | 3350,3349,3350 887 | 839,2580 888 | 2928,3183,2760 889 | 2825,1213 890 | 2242,2307,2311 891 | 1479,1744,3295,3091,3076 892 | 2577,2577 893 | 631,486 894 | 353,258 895 | 1231,1495 896 | 656,1752 897 | 2223,2120,1564,2223 898 | 2401,2362,2311 899 | 3347,3348,1600 900 | 2384,486 901 | 223,659 902 | 1550,1550,1943,2 903 | 2395,2443,2482 904 | 1091,1790 905 | 1382,2435 906 | 2220,2068 907 | 83,1729,1790 908 | 2600,1194 909 | 302,2686 910 | 925,512,513 911 | 700,2279 912 | 3192,3191,3192 913 | 1849,1850 914 | 1950,1991,2304 915 | 1270,1269 916 | 1687,2993 917 | 2697,1324 918 | 892,631,2986 919 | 3402,3402,3402 920 | 1687,1687 921 | 377,875,747,3165,526,1921,1512,7 922 | 1849,1850 923 | 2065,2030,2179 924 | 200,775,321,193 925 | 91,1583,129,1052 926 | 112,488,2083,580,3301,2083 927 | 356,2966,183,63,3135,193 928 | 877,876,1424,1424 929 | 3419,3419 930 | 1998,1998 931 | 3396,3318 932 | 358,1555 933 | 292,1384,2975,643 934 | 3424,3424 935 | 2820,309 936 | 3330,3330 937 | 313,421,2039 938 | 96,319,2092,309,597,2875 939 | 2089,20 940 | 1223,2860,2860,1223,277,311 941 | 3022,353 942 | 3177,3172,3177,3178,3177 943 | 578,2389 944 | 2706,1526 945 | 209,1138 946 | 486,267 947 | 890,3279,1385,1764,855,2656,855 948 | 1833,616 949 | 711,1596 950 | 2944,2943,2945,2691 951 | 1405,1785,1426 952 | 890,900,1704 953 | 1550,1628 954 | 3274,3274 955 | 617,617,617 956 | 2917,2199 957 | 238,2975 958 | 1351,1781,1259 959 | 1675,1231,962,309 960 | 1371,636,2905,498,644 961 | 2939,1277 962 | 3,138,217,195 963 | 283,2620 964 | 1271,1007,1008,1007 965 | 2131,2131,3037,861,3267,2910,890 966 | 506,476,490 967 | 124,2346,1302,237,1054,2975,1992 968 | 475,1310 969 | 126,146,127 970 | 1495,485 971 | 362,360 972 | 309,792,1026 973 | 3193,3018 974 | 332,3278,1625 975 | 1736,916,248,410,266 976 | 1718,631,1795,485,2964,1517,499 977 | 857,313 978 | 2316,1772 979 | 2422,1859,2421,2419,2422 980 | 832,832,832 981 | 1954,2270,1954 982 | 547,1795 983 | 2636,2307,2438 984 | 3099,3100,3100,3099,3100 985 | 2064,3332 986 | 2622,2622 987 | 2599,2510 988 | 2482,2032,2438 989 | 1289,1495 990 | 3090,3076,2998,1538 991 | 1839,2399 992 | 2406,2249 993 | 1631,797,1631 994 | 1859,2511,1859,2692,1772,2933 995 | 2509,1956 996 | 1805,1805 997 | 2090,2028,2362,2486,2136 998 | 2723,2359,2315,2311,2527 999 | 877,878 1000 | 2649,2540 1001 | 1236,3125 1002 | 532,3081,2154,2250,2252,2262,2243 1003 | 67,193,65,718 1004 | 1382,985,985 1005 | 2319,2578,2928,2588,2077 1006 | 2066,2891,2311 1007 | 1870,1870 1008 | 3367,3366,3367 1009 | 3385,3082,3385,3082,3385,3385 1010 | 2443,2370 1011 | 1721,2022,2023,1721 1012 | 309,332 1013 | 2438,2486,1858 1014 | 1047,1047 1015 | 2957,2957,2958,3344,2958,2957,2958,3117,2957 1016 | 2780,2527 1017 | 2749,1115 1018 | 1699,1699 1019 | 1015,2121 1020 | 1027,1692 1021 | 160,5 1022 | 2032,2246,2634 1023 | 271,271,272 1024 | 3053,3053 1025 
| 3232,3232 1026 | 1698,2970 1027 | 1517,1517 1028 | 1986,485,2203 1029 | 1472,3236 1030 | 2527,2304 1031 | 1459,3264,3263,3264,3265,3264 1032 | 511,1348,2962,2755,2962 1033 | 547,267 1034 | 3368,3369 1035 | 1255,1256 1036 | 267,3037,503,2831,503 1037 | 2016,1858 1038 | 1689,2080 1039 | 2136,2607 1040 | 2762,2137 1041 | 1091,116 1042 | 2907,2907,3114,2907,3114,2907 1043 | 182,355,62,3184,182,181,3184 1044 | 3148,3148 1045 | 3305,3304,3305,3304,3305,3304,3305 1046 | 1833,1526 1047 | 2933,2443,2654,2252 1048 | 183,1241 1049 | 1080,426 1050 | 2096,1324,2096 1051 | 965,407,966,407,1151,1542,966 1052 | 1771,3172,3174,3172,3174,3172 1053 | 623,1750 1054 | 89,88,90,3,51,1464 1055 | 1844,1449 1056 | 262,17 1057 | 1310,1151 1058 | 2672,2673 1059 | 1149,1779 1060 | 1560,485,597 1061 | 383,67,1245 1062 | 3386,1517,3279,244,270,1398 1063 | 2802,2015 1064 | 2786,2786 1065 | 1590,1857 1066 | 3160,3160 1067 | 1239,1983,3370,891,1983,3370 1068 | 31,2812 1069 | 547,631 1070 | 2714,1229 1071 | 631,2280 1072 | 1495,40,2657,1134,510 1073 | 1517,1331 1074 | 2650,2649,2650,2649,2650 1075 | 1526,547 1076 | 329,1690 1077 | 421,2040,419 1078 | 1460,1458 1079 | 2708,3093 1080 | 1756,714 1081 | 2895,3185 1082 | 1289,631 1083 | 1505,3026,1505,1507 1084 | 125,1733 1085 | 442,538,538 1086 | 890,490,607,1331 1087 | 2831,1331 1088 | 1772,2611 1089 | 2057,394,294 1090 | 758,758,758 1091 | 1808,424,1613,466 1092 | 1833,1526,1717 1093 | 1990,2846 1094 | 1602,1765,1449 1095 | 2861,3020 1096 | 124,2346,1546 1097 | 3182,1551 1098 | 954,951,1070 1099 | 2994,1012,1012 1100 | 1070,1070 1101 | 3396,3318 1102 | 309,792,1385,1835,2873,1026,3404,904 1103 | 2187,2641,2639 1104 | 3396,3318,3396 1105 | 3396,3396,3318,3396 1106 | 2292,551 1107 | 1439,2126 1108 | 747,403,1512,376,3165,875,3226,874,377 1109 | 1499,1476,1476 1110 | 2674,2303,2675,2599 1111 | 519,1243 1112 | 1092,1717,1715,1715,2936 1113 | 2476,3382,2476,3382,3382,2364,3382 1114 | 1978,568,519,666,568 1115 | 530,1579,1995,1394,1976 1116 | 1285,811,949,949,949,1285,811,950 1117 | 264,439 1118 | 2500,2500,1859 1119 | 2214,2214 1120 | 1972,1933,658 1121 | 314,314 1122 | 353,171,473,3135,362,360,455,20 1123 | 2170,2177,2170,2169,2168,2028,2030,2254,2370,1858,2247 1124 | 2532,112,2552 1125 | 449,105,3033 1126 | 356,3115 1127 | 33,2900,1898,46 1128 | 380,13 1129 | 1435,564 1130 | 407,407,1119 1131 | 1495,1495 1132 | 511,511 1133 | 46,193 1134 | 568,569 1135 | 1965,97 1136 | 1590,1857,1182 1137 | 116,2010,116 1138 | 568,569,568 1139 | 1265,2975 1140 | 20,1971,46 1141 | 807,3275 1142 | 691,433 1143 | 277,3272 1144 | 10,2922,13 1145 | 1236,2657,40,1762,1051,1695,241,2679 1146 | 631,631,3207 1147 | 1135,1326 1148 | 1600,1601,1600 1149 | 3285,1496 1150 | 1759,2379,1759,2656 1151 | 241,890,3025,3273 1152 | 46,1785,360 1153 | 2131,897,1693,744 1154 | 193,185 1155 | 160,400,611 1156 | 797,988,597 1157 | 3325,1312 1158 | 62,857 1159 | 563,629,292 1160 | 3235,3235 1161 | 2767,1533,2767 1162 | 1459,485 1163 | 344,717,3420,1046,1151,3420,2966,3420 1164 | 2726,2726,2797,2726 1165 | 1108,1475 1166 | 2692,2692 1167 | 2737,2628 1168 | 3123,2502,2146 1169 | 2954,1283,2954 1170 | 2149,2453,2252 1171 | 3108,1852,3108 1172 | 1422,538 1173 | 2757,1721,2022 1174 | 2246,2362 1175 | 3168,3168,3344 1176 | 3011,1493,3011,3011 1177 | 3142,3142 1178 | 2145,2145 1179 | 3376,3376 1180 | 3044,3044,3044,3044 1181 | 1480,2223,1564,2223 1182 | 311,1076,1928 1183 | 1034,1034,1034,1034 1184 | 2077,2033,1859,2032 1185 | 2898,2644,2016,2591,2509,2016,2675 1186 | 2666,234,2653,942,2025,1623,2666,1215 1187 | 2043,2043 
1188 | 1775,1774,1774,1774,1775 1189 | 2476,2476,2234 1190 | 1911,1912,1860,1958,1911 1191 | 3324,3324 1192 | 2985,2985,2985 1193 | 1353,1353 1194 | 1549,1459,3264,3263,3264,3265,1459,3264,3265,3264,3264 1195 | 3094,3094,3094 1196 | 2946,20,3115,759,20 1197 | 687,1443,396,3208,39,20,1046,1046,1345 1198 | 2083,526 1199 | 2102,2033 1200 | 3416,3416 1201 | 538,426 1202 | 2605,2237 1203 | 1634,1444,1635,1444,1634,1634 1204 | 2504,2502,2504,2502 1205 | 3117,2958,3117,2958,3117,2958,3117 1206 | 3341,3341,3341,3341 1207 | 1726,1014 1208 | 1422,1459,1382 1209 | 1597,1423 1210 | 32,46,1971,1426,1971 1211 | 1323,3338,1323,3338,1323,1323,3338,3338,1323,1323,1323 1212 | 3349,3350 1213 | 1048,3279,843,1582,254,1582 1214 | 33,406,405,552 1215 | 2362,2654 1216 | 10,194,659,2922,13,552 1217 | 2728,2856,2697,1209 1218 | 631,485,907,2975,907,907 1219 | 2901,332,332 1220 | 97,97,97 1221 | 395,3187,105 1222 | 986,868 1223 | 2097,2697 1224 | 2697,2697,1324,1324,2697,2697,1324,1324,2856,1324 1225 | 2304,2297 1226 | 3054,3054,1736,3053 1227 | 1609,1609 1228 | 2486,2611 1229 | 3052,2450,2800,3052,2800,3052,2800 1230 | 2692,2599,2890,2316,2599 1231 | 2589,2645 1232 | 517,1361,1202 1233 | 2403,3109,2397,2522,2403 1234 | 1792,1629,1792,49,1792,1629 1235 | 2895,103 1236 | 2583,2582 1237 | 267,597,2161 1238 | 238,237,3404 1239 | 52,384,397,620 1240 | 1092,740,426,335,334 1241 | 200,473,1229,597 1242 | 1226,340,340,341 1243 | 2672,3012 1244 | 482,482 1245 | 3283,3264,3264,3264,3264,3264,3264 1246 | 142,1620,733 1247 | 313,485 1248 | 51,1550 1249 | 1374,116,474,23,3150,2781 1250 | 926,509 1251 | 373,1485,1485 1252 | 1798,1531 1253 | 2389,62,1076 1254 | 3426,3426 1255 | 241,241,2213 1256 | 720,3396,844 1257 | 662,1239 1258 | 1903,332,313,1255,1408,1255 1259 | 770,2815,2816,771,771 1260 | 504,739 1261 | 2770,2769,2770,2769,2770 1262 | 270,890 1263 | 1549,277,740 1264 | 1678,631 1265 | 840,48,381,659 1266 | 1255,1268 1267 | 623,1465,1650,332,95 1268 | 540,1348 1269 | 876,877 1270 | 1833,792 1271 | 3277,1624 1272 | 962,486,1231 1273 | 2881,2930 1274 | 262,1223,2849,1223,262,1223 1275 | 1169,1170,1169,1170,1169,1170 1276 | 1689,329,2080,2059,1186,1030,1062,1064,2060,1064 1277 | 550,1328 1278 | 2501,2453 1279 | 2290,3267 1280 | 20,14,12,775 1281 | 1954,2270 1282 | 311,988,1986,1986,988,1926 1283 | 3250,3249,3351,3249,3249 1284 | 1849,1850 1285 | 239,240 1286 | 458,688 1287 | 1066,132,241,1904,294 1288 | 2434,2385 1289 | 740,311 1290 | 538,779,2687 1291 | 3414,3414 1292 | 597,1920,1996,1708,1409,897,409 1293 | 1671,1671,1671,1671 1294 | 193,759,20 1295 | 3313,3312,3313,3312 1296 | 2876,3255 1297 | 1019,1736,1624 1298 | 277,1092 1299 | 1864,1189,1206,1134,40 1300 | 1068,1172,241,2657,40,1762,41,2213,1835 1301 | 27,990 1302 | 169,5 1303 | 266,266 1304 | 890,3273,3025 1305 | 444,316,169 1306 | 3190,3190 1307 | 1901,1901,1901 1308 | 2157,860 1309 | 1848,1579,1932,1649 1310 | 1048,1290,721 1311 | 1403,1829 1312 | 476,1208,1829 1313 | 1299,311 1314 | 887,1624,1625,1178 1315 | 243,656,243,1546,1924 1316 | 1704,1084 1317 | 1528,311 1318 | 3131,3131,3131 1319 | 570,666 1320 | 105,1320,2751,1172,1068 1321 | 1596,616 1322 | 2057,324 1323 | 32,457 1324 | 2875,3037,2977 1325 | 32,457 1326 | 3147,3147,3147 1327 | 1800,2209,1799 1328 | 918,1252,104,2380,485 1329 | 1675,597 1330 | 2588,2438 1331 | 142,170,374 1332 | 1076,313 1333 | 1695,631 1334 | 1495,3037 1335 | 1066,1655,1265,2354,3273,123,595 1336 | 1748,1748,1748 1337 | 1791,2354,744,1295,1727,3037 1338 | 1336,3037 1339 | 280,377,525,497,496,456,3135 1340 | 2137,2629 1341 | 2242,2307 
1342 | 1715,1715 1343 | 1080,1236 1344 | 2438,2933 1345 | 669,1032,669,1032,669,1032,1032 1346 | 2558,2557 1347 | 2179,2028,2180 1348 | 2610,2610 1349 | 2171,2172 1350 | 2552,2902,3132,1120,3028,3132,3028,3027,3028,3027,3028,3028 1351 | 3333,1841,3333 1352 | 3186,3186 1353 | 1817,1493,485,1577,104,485,267,267 1354 | 1843,1341 1355 | 3310,3344,3309 1356 | 2958,3344,2957 1357 | 1740,2222 1358 | 1859,2321 1359 | 2965,872,1711,1845,2360,2360,1711,2361 1360 | 2276,2028 1361 | 3197,3196,3254,3196,3254,3197,3196,3254,3197,3196,3254 1362 | 2476,2476,2591 1363 | 3360,3361 1364 | 1841,1840,1841,1840 1365 | 439,2085,116,1981,1091,311 1366 | 2535,2535,2535,2535 1367 | 792,792,792 1368 | 578,664,183,748,749,179,664 1369 | 2160,2741 1370 | 288,374 1371 | 1700,1700 1372 | 1744,2990,1721,2678 1373 | 1828,759,1972,1705 1374 | 568,761,1013,1568,997,996 1375 | 609,87 1376 | 1600,1601 1377 | 459,455,459 1378 | 1484,794 1379 | 2393,2393,2393,2393 1380 | 518,1116,518,518,1190,518,522,1190,881 1381 | 65,718,3226,65,718,2155,65,718,20 1382 | 2992,2344 1383 | 486,547 1384 | 486,1382 1385 | 631,1019 1386 | 485,547,523,336,336,338,336 1387 | 2786,2893 1388 | 2057,294 1389 | 2419,2422,2419 1390 | 718,1782,1245,718,524,688 1391 | 1827,1830,1830,1918,571 1392 | 3177,601,149,3174 1393 | 1013,1587,1932 1394 | 332,383 1395 | 724,724 1396 | 3263,3264 1397 | 3177,601,149,3177,2235 1398 | 1933,27 1399 | 2996,2834,2834,2996,2834 1400 | 861,1746,925 1401 | 31,189,627 1402 | 1480,2991 1403 | 358,812,1555,438,1557,1555,1046,2598 1404 | 2213,2125,1611,2213,1611,2125 1405 | 3404,1293,3404,629,499,1293,629,505 1406 | 94,439,1091,116,3095 1407 | 1244,3354,380 1408 | 11,3354,380 1409 | 2355,2346 1410 | 1735,485 1411 | 2946,3115 1412 | 116,116 1413 | 1144,1518 1414 | 571,546,529,546 1415 | 2875,227,227,692,2662,2517 1416 | 472,1627 1417 | 643,499,2620 1418 | 2821,33 1419 | 1356,3033,1356,3033,1356,371,3033 1420 | 1209,1324,3406,1209,1324,1209,3406,3406 1421 | 1003,339,340 1422 | 211,1125,1122,1125,722,211,1125,1125,240,1125,304 1423 | 1582,1919 1424 | 195,301 1425 | 1862,268,268 1426 | 1526,1162 1427 | 1458,1865,1460 1428 | 1774,1774 1429 | 748,748 1430 | 2620,2200 1431 | 240,3404,211 1432 | 2620,1354 1433 | 116,2005 1434 | 3311,3311 1435 | 3137,2914,3137,2914,3137,2914,3137,2914 1436 | 1873,1873 1437 | 209,686 1438 | 1997,1997,1997,1997 1439 | 1382,631 1440 | 2552,3287,203,46,2552 1441 | 552,2817,2988,2817,580,2817,2817,1535 1442 | 1589,2620,1354,2620,1589,1354 1443 | 2742,2885,2794,2733 1444 | 749,1801 1445 | 735,601 1446 | 1526,485 1447 | 1425,3327,1563,1652,1563,1947,1563 1448 | 2094,2689,1265 1449 | 595,887 1450 | 2301,2395,2275 1451 | 2705,2281 1452 | 1095,1080,1003 1453 | 2477,2363 1454 | 2026,2665 1455 | 178,1077,839,484,149 1456 | 1666,699 1457 | 546,699 1458 | 332,486 1459 | 1047,1326 1460 | 3286,1700,3286,3286,1386,3286,3286,3286 1461 | 704,62,384,386,386 1462 | 1409,2621 1463 | 199,1619,598,44,423 1464 | 486,740 1465 | 3425,3425 1466 | -------------------------------------------------------------------------------- /lessr.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import nn 3 | import dgl 4 | import dgl.ops as F 5 | import dgl.function as fn 6 | 7 | 8 | class EOPA(nn.Module): 9 | def __init__( 10 | self, input_dim, output_dim, batch_norm=True, feat_drop=0.0, activation=None 11 | ): 12 | super().__init__() 13 | self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None 14 | self.feat_drop = nn.Dropout(feat_drop) 15 | self.gru = 
nn.GRU(input_dim, input_dim, batch_first=True) 16 | self.fc_self = nn.Linear(input_dim, output_dim, bias=False) 17 | self.fc_neigh = nn.Linear(input_dim, output_dim, bias=False) 18 | self.activation = activation 19 | 20 | def reducer(self, nodes): 21 | m = nodes.mailbox['m'] # (num_nodes, deg, d) 22 | # m[i]: the messages passed to the i-th node with in-degree equal to 'deg' 23 | # the order of messages follows the order of incoming edges 24 | # since the edges are sorted by occurrence time when the EOP multigraph is built 25 | # the messages are in the order required by EOPA 26 | _, hn = self.gru(m) # hn: (1, num_nodes, d) 27 | return {'neigh': hn.squeeze(0)} 28 | 29 | def forward(self, mg, feat): 30 | with mg.local_scope(): 31 | if self.batch_norm is not None: 32 | feat = self.batch_norm(feat) 33 | mg.ndata['ft'] = self.feat_drop(feat) 34 | if mg.number_of_edges() > 0: 35 | mg.update_all(fn.copy_u('ft', 'm'), self.reducer) 36 | neigh = mg.ndata['neigh'] 37 | rst = self.fc_self(feat) + self.fc_neigh(neigh) 38 | else: 39 | rst = self.fc_self(feat) 40 | if self.activation is not None: 41 | rst = self.activation(rst) 42 | return rst 43 | 44 | 45 | class SGAT(nn.Module): 46 | def __init__( 47 | self, 48 | input_dim, 49 | hidden_dim, 50 | output_dim, 51 | batch_norm=True, 52 | feat_drop=0.0, 53 | activation=None, 54 | ): 55 | super().__init__() 56 | self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None 57 | self.feat_drop = nn.Dropout(feat_drop) 58 | self.fc_q = nn.Linear(input_dim, hidden_dim, bias=True) 59 | self.fc_k = nn.Linear(input_dim, hidden_dim, bias=False) 60 | self.fc_v = nn.Linear(input_dim, output_dim, bias=False) 61 | self.fc_e = nn.Linear(hidden_dim, 1, bias=False) 62 | self.activation = activation 63 | 64 | def forward(self, sg, feat): 65 | if self.batch_norm is not None: 66 | feat = self.batch_norm(feat) 67 | feat = self.feat_drop(feat) 68 | q = self.fc_q(feat) 69 | k = self.fc_k(feat) 70 | v = self.fc_v(feat) 71 | e = F.u_add_v(sg, q, k) 72 | e = self.fc_e(th.sigmoid(e)) 73 | a = F.edge_softmax(sg, e) 74 | rst = F.u_mul_e_sum(sg, v, a) 75 | if self.activation is not None: 76 | rst = self.activation(rst) 77 | return rst 78 | 79 | 80 | class AttnReadout(nn.Module): 81 | def __init__( 82 | self, 83 | input_dim, 84 | hidden_dim, 85 | output_dim, 86 | batch_norm=True, 87 | feat_drop=0.0, 88 | activation=None, 89 | ): 90 | super().__init__() 91 | self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None 92 | self.feat_drop = nn.Dropout(feat_drop) 93 | self.fc_u = nn.Linear(input_dim, hidden_dim, bias=False) 94 | self.fc_v = nn.Linear(input_dim, hidden_dim, bias=True) 95 | self.fc_e = nn.Linear(hidden_dim, 1, bias=False) 96 | self.fc_out = ( 97 | nn.Linear(input_dim, output_dim, bias=False) 98 | if output_dim != input_dim else None 99 | ) 100 | self.activation = activation 101 | 102 | def forward(self, g, feat, last_nodes): 103 | if self.batch_norm is not None: 104 | feat = self.batch_norm(feat) 105 | feat = self.feat_drop(feat) 106 | feat_u = self.fc_u(feat) 107 | feat_v = self.fc_v(feat[last_nodes]) 108 | feat_v = dgl.broadcast_nodes(g, feat_v) 109 | e = self.fc_e(th.sigmoid(feat_u + feat_v)) 110 | alpha = F.segment.segment_softmax(g.batch_num_nodes(), e) 111 | feat_norm = feat * alpha 112 | rst = F.segment.segment_reduce(g.batch_num_nodes(), feat_norm, 'sum') 113 | if self.fc_out is not None: 114 | rst = self.fc_out(rst) 115 | if self.activation is not None: 116 | rst = self.activation(rst) 117 | return rst 118 | 119 | 120 | class LESSR(nn.Module): 121 | 
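    """Alternating stack of EOPA and SGAT layers with an attentive readout.

    Even-indexed layers are EOPA and operate on the EOP multigraph `mg`;
    odd-indexed layers are SGAT and operate on the shortcut graph `sg`.
    Each layer's output is concatenated with its input (dense connections),
    so the feature width grows by `embedding_dim` per layer. The readout
    pools node features into a global session embedding, which is
    concatenated with the last node's features, projected by `fc_sr`, and
    scored against all item embeddings to produce next-item logits.
    """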
    def __init__(
122 |         self, num_items, embedding_dim, num_layers, batch_norm=True, feat_drop=0.0
123 |     ):
124 |         super().__init__()
125 |         self.embedding = nn.Embedding(num_items, embedding_dim, max_norm=1)
126 |         self.indices = nn.Parameter(
127 |             th.arange(num_items, dtype=th.long), requires_grad=False
128 |         )
129 |         self.num_layers = num_layers
130 |         self.layers = nn.ModuleList()
131 |         input_dim = embedding_dim
132 |         for i in range(num_layers):
133 |             if i % 2 == 0:
134 |                 layer = EOPA(
135 |                     input_dim,
136 |                     embedding_dim,
137 |                     batch_norm=batch_norm,
138 |                     feat_drop=feat_drop,
139 |                     activation=nn.PReLU(embedding_dim),
140 |                 )
141 |             else:
142 |                 layer = SGAT(
143 |                     input_dim,
144 |                     embedding_dim,
145 |                     embedding_dim,
146 |                     batch_norm=batch_norm,
147 |                     feat_drop=feat_drop,
148 |                     activation=nn.PReLU(embedding_dim),
149 |                 )
150 |             input_dim += embedding_dim
151 |             self.layers.append(layer)
152 |         self.readout = AttnReadout(
153 |             input_dim,
154 |             embedding_dim,
155 |             embedding_dim,
156 |             batch_norm=batch_norm,
157 |             feat_drop=feat_drop,
158 |             activation=nn.PReLU(embedding_dim),
159 |         )
160 |         input_dim += embedding_dim
161 |         self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None
162 |         self.feat_drop = nn.Dropout(feat_drop)
163 |         self.fc_sr = nn.Linear(input_dim, embedding_dim, bias=False)
164 |
165 |     def forward(self, mg, sg=None):
166 |         iid = mg.ndata['iid']
167 |         feat = self.embedding(iid)
168 |         for i, layer in enumerate(self.layers):
169 |             if i % 2 == 0:
170 |                 out = layer(mg, feat)
171 |             else:
172 |                 out = layer(sg, feat)
173 |             feat = th.cat([out, feat], dim=1)
174 |         last_nodes = mg.filter_nodes(lambda nodes: nodes.data['last'] == 1)
175 |         sr_g = self.readout(mg, feat, last_nodes)
176 |         sr_l = feat[last_nodes]
177 |         sr = th.cat([sr_l, sr_g], dim=1)
178 |         if self.batch_norm is not None:
179 |             sr = self.batch_norm(sr)
180 |         sr = self.fc_sr(self.feat_drop(sr))
181 |         logits = sr @ self.embedding(self.indices).t()
182 |         return logits
183 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
4 | parser.add_argument(
5 |     '--dataset-dir', default='datasets/sample', help='the dataset directory'
6 | )
7 | parser.add_argument('--embedding-dim', type=int, default=32, help='the embedding size')
8 | parser.add_argument('--num-layers', type=int, default=3, help='the number of layers')
9 | parser.add_argument(
10 |     '--feat-drop', type=float, default=0.2, help='the dropout ratio for features'
11 | )
12 | parser.add_argument('--lr', type=float, default=1e-3, help='the learning rate')
13 | parser.add_argument(
14 |     '--batch-size', type=int, default=512, help='the batch size for training'
15 | )
16 | parser.add_argument(
17 |     '--epochs', type=int, default=30, help='the number of training epochs'
18 | )
19 | parser.add_argument(
20 |     '--weight-decay',
21 |     type=float,
22 |     default=1e-4,
23 |     help='the parameter for L2 regularization',
24 | )
25 | parser.add_argument(
26 |     '--Ks',
27 |     default='10,20',
28 |     help='the values of K in evaluation metrics, separated by commas',
29 | )
30 | parser.add_argument(
31 |     '--patience',
32 |     type=int,
33 |     default=2,
34 |     help='the number of epochs without performance improvement after which training stops',
35 | )
36 | parser.add_argument(
37 |     '--num-workers',
38 |     type=int,
39 |     default=4,
40 |     help='the number of processes
to load the input graphs', 41 | ) 42 | parser.add_argument( 43 | '--valid-split', 44 | type=float, 45 | default=None, 46 | help='the fraction for the validation set', 47 | ) 48 | parser.add_argument( 49 | '--log-interval', 50 | type=int, 51 | default=100, 52 | help='print the loss after this number of iterations', 53 | ) 54 | args = parser.parse_args() 55 | print(args) 56 | 57 | from pathlib import Path 58 | import torch as th 59 | from torch.utils.data import DataLoader 60 | from utils.data.dataset import read_dataset, AugmentedDataset 61 | from utils.data.collate import ( 62 | seq_to_eop_multigraph, 63 | seq_to_shortcut_graph, 64 | collate_fn_factory, 65 | ) 66 | from utils.train import TrainRunner 67 | from lessr import LESSR 68 | 69 | dataset_dir = Path(args.dataset_dir) 70 | args.Ks = [int(K) for K in args.Ks.split(',')] 71 | print('reading dataset') 72 | train_sessions, test_sessions, num_items = read_dataset(dataset_dir) 73 | 74 | if args.valid_split is not None: 75 | num_valid = int(len(train_sessions) * args.valid_split) 76 | test_sessions = train_sessions[-num_valid:] 77 | train_sessions = train_sessions[:-num_valid] 78 | 79 | train_set = AugmentedDataset(train_sessions) 80 | test_set = AugmentedDataset(test_sessions) 81 | 82 | if args.num_layers > 1: 83 | collate_fn = collate_fn_factory(seq_to_eop_multigraph, seq_to_shortcut_graph) 84 | else: 85 | collate_fn = collate_fn_factory(seq_to_eop_multigraph) 86 | 87 | train_loader = DataLoader( 88 | train_set, 89 | batch_size=args.batch_size, 90 | shuffle=True, 91 | drop_last=True, 92 | num_workers=args.num_workers, 93 | collate_fn=collate_fn, 94 | ) 95 | 96 | test_loader = DataLoader( 97 | test_set, 98 | batch_size=args.batch_size, 99 | shuffle=False, 100 | num_workers=args.num_workers, 101 | collate_fn=collate_fn, 102 | ) 103 | 104 | model = LESSR(num_items, args.embedding_dim, args.num_layers, feat_drop=args.feat_drop) 105 | device = th.device('cuda' if th.cuda.is_available() else 'cpu') 106 | model = model.to(device) 107 | print(model) 108 | 109 | runner = TrainRunner( 110 | model, 111 | train_loader, 112 | test_loader, 113 | device=device, 114 | lr=args.lr, 115 | weight_decay=args.weight_decay, 116 | patience=args.patience, 117 | Ks=args.Ks, 118 | ) 119 | 120 | print('start training') 121 | runner.train(args.epochs, args.log_interval) 122 | -------------------------------------------------------------------------------- /packages.yml: -------------------------------------------------------------------------------- 1 | name: lessr 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - dglteam 6 | - defaults 7 | dependencies: 8 | - python=3.8.6=h852b56e_0_cpython 9 | - cudatoolkit=10.2.89=hfd86e86_1 10 | - dgl-cuda10.2=0.5.2=py38_0 11 | - numpy=1.19.1=py38hbc27379_2 12 | - pandas=1.1.3=py38h950e882_0 13 | - pytorch=1.6.0=py3.8_cuda10.2.89_cudnn7.6.5_0 14 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import argparse 3 | 4 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 5 | 6 | optional = parser._action_groups.pop() 7 | required = parser.add_argument_group('required arguments') 8 | required.add_argument( 9 | '-d', 10 | '--dataset', 11 | choices=['diginetica', 'gowalla', 'lastfm'], 12 | required=True, 13 | help='the dataset name', 14 | ) 15 | required.add_argument( 16 | '-f', 17 | '--filepath', 18 | required=True, 19 | help='the 
file for the dataset, i.e., "train-item-views.csv" for diginetica, ' 20 | '"loc-gowalla_totalCheckins.txt" for gowalla, ' 21 | '"userid-timestamp-artid-artname-traid-traname.tsv" for lastfm', 22 | ) 23 | optional.add_argument( 24 | '-t', 25 | '--dataset-dir', 26 | default='datasets/{dataset}', 27 | help='the folder to save the preprocessed dataset', 28 | ) 29 | parser._action_groups.append(optional) 30 | args = parser.parse_args() 31 | 32 | dataset_dir = Path(args.dataset_dir.format(dataset=args.dataset)) 33 | 34 | if args.dataset == 'diginetica': 35 | from utils.data.preprocess import preprocess_diginetica 36 | 37 | preprocess_diginetica(dataset_dir, args.filepath) 38 | else: 39 | from pandas import Timedelta 40 | from utils.data.preprocess import preprocess_gowalla_lastfm 41 | 42 | csv_file = args.filepath 43 | if args.dataset == 'gowalla': 44 | usecols = [0, 1, 4] 45 | interval = Timedelta(days=1) 46 | n = 30000 47 | else: 48 | usecols = [0, 1, 2] 49 | interval = Timedelta(hours=8) 50 | n = 40000 51 | preprocess_gowalla_lastfm(dataset_dir, csv_file, usecols, interval, n) 52 | -------------------------------------------------------------------------------- /utils/data/collate.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | import numpy as np 3 | import torch as th 4 | import dgl 5 | 6 | 7 | def label_last(g, last_nid): 8 | is_last = th.zeros(g.number_of_nodes(), dtype=th.int32) 9 | is_last[last_nid] = 1 10 | g.ndata['last'] = is_last 11 | return g 12 | 13 | 14 | def seq_to_eop_multigraph(seq): 15 | items = np.unique(seq) 16 | iid2nid = {iid: i for i, iid in enumerate(items)} 17 | num_nodes = len(items) 18 | 19 | if len(seq) > 1: 20 | seq_nid = [iid2nid[iid] for iid in seq] 21 | src = seq_nid[:-1] 22 | dst = seq_nid[1:] 23 | else: 24 | src = th.LongTensor([]) 25 | dst = th.LongTensor([]) 26 | g = dgl.graph((src, dst), num_nodes=num_nodes) 27 | g.ndata['iid'] = th.tensor(items, dtype=th.long) 28 | label_last(g, iid2nid[seq[-1]]) 29 | return g 30 | 31 | 32 | def seq_to_shortcut_graph(seq): 33 | items = np.unique(seq) 34 | iid2nid = {iid: i for i, iid in enumerate(items)} 35 | num_nodes = len(items) 36 | 37 | seq_nid = [iid2nid[iid] for iid in seq] 38 | counter = Counter( 39 | [(seq_nid[i], seq_nid[j]) for i in range(len(seq)) for j in range(i, len(seq))] 40 | ) 41 | edges = counter.keys() 42 | src, dst = zip(*edges) 43 | 44 | g = dgl.graph((src, dst), num_nodes=num_nodes) 45 | return g 46 | 47 | 48 | def collate_fn_factory(*seq_to_graph_fns): 49 | def collate_fn(samples): 50 | seqs, labels = zip(*samples) 51 | inputs = [] 52 | for seq_to_graph in seq_to_graph_fns: 53 | graphs = list(map(seq_to_graph, seqs)) 54 | bg = dgl.batch(graphs) 55 | inputs.append(bg) 56 | labels = th.LongTensor(labels) 57 | return inputs, labels 58 | 59 | return collate_fn 60 | -------------------------------------------------------------------------------- /utils/data/dataset.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import numpy as np 3 | import pandas as pd 4 | 5 | 6 | def create_index(sessions): 7 | lens = np.fromiter(map(len, sessions), dtype=np.long) 8 | session_idx = np.repeat(np.arange(len(sessions)), lens - 1) 9 | label_idx = map(lambda l: range(1, l), lens) 10 | label_idx = itertools.chain.from_iterable(label_idx) 11 | label_idx = np.fromiter(label_idx, dtype=np.long) 12 | idx = np.column_stack((session_idx, label_idx)) 13 | return idx 14 | 15 | 16 | def 
read_sessions(filepath): 17 | sessions = pd.read_csv(filepath, sep='\t', header=None, squeeze=True) 18 | sessions = sessions.apply(lambda x: list(map(int, x.split(',')))).values 19 | return sessions 20 | 21 | 22 | def read_dataset(dataset_dir): 23 | train_sessions = read_sessions(dataset_dir / 'train.txt') 24 | test_sessions = read_sessions(dataset_dir / 'test.txt') 25 | with open(dataset_dir / 'num_items.txt', 'r') as f: 26 | num_items = int(f.readline()) 27 | return train_sessions, test_sessions, num_items 28 | 29 | 30 | class AugmentedDataset: 31 | def __init__(self, sessions, sort_by_length=True): 32 | self.sessions = sessions 33 | index = create_index(sessions) # columns: sessionId, labelIndex 34 | if sort_by_length: 35 | # sort by labelIndex in descending order 36 | ind = np.argsort(index[:, 1])[::-1] 37 | index = index[ind] 38 | self.index = index 39 | 40 | def __getitem__(self, idx): 41 | sid, lidx = self.index[idx] 42 | seq = self.sessions[sid][:lidx] 43 | label = self.sessions[sid][lidx] 44 | return seq, label 45 | 46 | def __len__(self): 47 | return len(self.index) 48 | -------------------------------------------------------------------------------- /utils/data/preprocess.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | 4 | 5 | def get_session_id(df, interval): 6 | df_prev = df.shift() 7 | is_new_session = (df.userId != df_prev.userId) | ( 8 | df.timestamp - df_prev.timestamp > interval 9 | ) 10 | session_id = is_new_session.cumsum() - 1 11 | return session_id 12 | 13 | 14 | def group_sessions(df, interval): 15 | sessionId = get_session_id(df, interval) 16 | df = df.assign(sessionId=sessionId) 17 | return df 18 | 19 | 20 | def filter_short_sessions(df, min_len=2): 21 | session_len = df.groupby('sessionId', sort=False).size() 22 | long_sessions = session_len[session_len >= min_len].index 23 | df_long = df[df.sessionId.isin(long_sessions)] 24 | return df_long 25 | 26 | 27 | def filter_infreq_items(df, min_support=5): 28 | item_support = df.groupby('itemId', sort=False).size() 29 | freq_items = item_support[item_support >= min_support].index 30 | df_freq = df[df.itemId.isin(freq_items)] 31 | return df_freq 32 | 33 | 34 | def filter_until_all_long_and_freq(df, min_len=2, min_support=5): 35 | while True: 36 | df_long = filter_short_sessions(df, min_len) 37 | df_freq = filter_infreq_items(df_long, min_support) 38 | if len(df_freq) == len(df): 39 | break 40 | df = df_freq 41 | return df 42 | 43 | 44 | def truncate_long_sessions(df, max_len=20, is_sorted=False): 45 | if not is_sorted: 46 | df = df.sort_values(['sessionId', 'timestamp']) 47 | itemIdx = df.groupby('sessionId').cumcount() 48 | df_t = df[itemIdx < max_len] 49 | return df_t 50 | 51 | 52 | def update_id(df, field): 53 | labels = pd.factorize(df[field])[0] 54 | kwargs = {field: labels} 55 | df = df.assign(**kwargs) 56 | return df 57 | 58 | 59 | def remove_immediate_repeats(df): 60 | df_prev = df.shift() 61 | is_not_repeat = (df.sessionId != df_prev.sessionId) | (df.itemId != df_prev.itemId) 62 | df_no_repeat = df[is_not_repeat] 63 | return df_no_repeat 64 | 65 | 66 | def reorder_sessions_by_endtime(df): 67 | endtime = df.groupby('sessionId', sort=False).timestamp.max() 68 | df_endtime = endtime.sort_values().reset_index() 69 | oid2nid = dict(zip(df_endtime.sessionId, df_endtime.index)) 70 | sessionId_new = df.sessionId.map(oid2nid) 71 | df = df.assign(sessionId=sessionId_new) 72 | df = df.sort_values(['sessionId', 'timestamp']) 73 | return 
df 74 | 75 | 76 | def keep_top_n_items(df, n): 77 | item_support = df.groupby('itemId', sort=False).size() 78 | top_items = item_support.nlargest(n).index 79 | df_top = df[df.itemId.isin(top_items)] 80 | return df_top 81 | 82 | 83 | def split_by_time(df, timedelta): 84 | max_time = df.timestamp.max() 85 | end_time = df.groupby('sessionId').timestamp.max() 86 | split_time = max_time - timedelta 87 | train_sids = end_time[end_time < split_time].index 88 | df_train = df[df.sessionId.isin(train_sids)] 89 | df_test = df[~df.sessionId.isin(train_sids)] 90 | return df_train, df_test 91 | 92 | 93 | def train_test_split(df, test_split=0.2): 94 | endtime = df.groupby('sessionId', sort=False).timestamp.max() 95 | endtime = endtime.sort_values() 96 | num_tests = int(len(endtime) * test_split) 97 | test_session_ids = endtime.index[-num_tests:] 98 | df_train = df[~df.sessionId.isin(test_session_ids)] 99 | df_test = df[df.sessionId.isin(test_session_ids)] 100 | return df_train, df_test 101 | 102 | 103 | def save_sessions(df, filepath): 104 | df = reorder_sessions_by_endtime(df) 105 | sessions = df.groupby('sessionId').itemId.apply(lambda x: ','.join(map(str, x))) 106 | sessions.to_csv(filepath, sep='\t', header=False, index=False) 107 | 108 | 109 | def save_dataset(dataset_dir, df_train, df_test): 110 | # filter items in test but not in train 111 | df_test = df_test[df_test.itemId.isin(df_train.itemId.unique())] 112 | df_test = filter_short_sessions(df_test) 113 | 114 | print(f'No. of Clicks: {len(df_train) + len(df_test)}') 115 | print(f'No. of Items: {df_train.itemId.nunique()}') 116 | 117 | # update itemId 118 | train_itemId_new, uniques = pd.factorize(df_train.itemId) 119 | df_train = df_train.assign(itemId=train_itemId_new) 120 | oid2nid = {oid: i for i, oid in enumerate(uniques)} 121 | test_itemId_new = df_test.itemId.map(oid2nid) 122 | df_test = df_test.assign(itemId=test_itemId_new) 123 | 124 | print(f'saving dataset to {dataset_dir}') 125 | dataset_dir.mkdir(parents=True, exist_ok=True) 126 | save_sessions(df_train, dataset_dir / 'train.txt') 127 | save_sessions(df_test, dataset_dir / 'test.txt') 128 | num_items = len(uniques) 129 | with open(dataset_dir / 'num_items.txt', 'w') as f: 130 | f.write(str(num_items)) 131 | 132 | 133 | def preprocess_diginetica(dataset_dir, csv_file): 134 | print(f'reading {csv_file}...') 135 | df = pd.read_csv( 136 | csv_file, 137 | usecols=[0, 2, 3, 4], 138 | delimiter=';', 139 | parse_dates=['eventdate'], 140 | infer_datetime_format=True, 141 | ) 142 | print('start preprocessing') 143 | # timeframe (time since the first query in a session, in milliseconds) 144 | df['timestamp'] = pd.to_timedelta(df.timeframe, unit='ms') + df.eventdate 145 | df = df.drop(['eventdate', 'timeframe'], 1) 146 | df = df.sort_values(['sessionId', 'timestamp']) 147 | df = filter_short_sessions(df) 148 | df = truncate_long_sessions(df, is_sorted=True) 149 | df = filter_infreq_items(df) 150 | df = filter_short_sessions(df) 151 | df_train, df_test = split_by_time(df, pd.Timedelta(days=7)) 152 | save_dataset(dataset_dir, df_train, df_test) 153 | 154 | 155 | def preprocess_gowalla_lastfm(dataset_dir, csv_file, usecols, interval, n): 156 | print(f'reading {csv_file}...') 157 | df = pd.read_csv( 158 | csv_file, 159 | sep='\t', 160 | header=None, 161 | names=['userId', 'timestamp', 'itemId'], 162 | usecols=usecols, 163 | parse_dates=['timestamp'], 164 | infer_datetime_format=True, 165 | ) 166 | print('start preprocessing') 167 | df = df.dropna() 168 | df = update_id(df, 'userId') 169 | df = 
update_id(df, 'itemId') 170 | df = df.sort_values(['userId', 'timestamp']) 171 | 172 | df = group_sessions(df, interval) 173 | df = remove_immediate_repeats(df) 174 | df = truncate_long_sessions(df, is_sorted=True) 175 | df = keep_top_n_items(df, n) 176 | df = filter_until_all_long_and_freq(df) 177 | df_train, df_test = train_test_split(df, test_split=0.2) 178 | save_dataset(dataset_dir, df_train, df_test) 179 | -------------------------------------------------------------------------------- /utils/train.py: -------------------------------------------------------------------------------- 1 | import time 2 | from collections import defaultdict 3 | 4 | import torch as th 5 | from torch import nn, optim 6 | 7 | 8 | # ignore weight decay for parameters in bias, batch norm and activation 9 | def fix_weight_decay(model): 10 | decay = [] 11 | no_decay = [] 12 | for name, param in model.named_parameters(): 13 | if not param.requires_grad: 14 | continue 15 | if any(map(lambda x: x in name, ['bias', 'batch_norm', 'activation'])): 16 | no_decay.append(param) 17 | else: 18 | decay.append(param) 19 | params = [{'params': decay}, {'params': no_decay, 'weight_decay': 0}] 20 | return params 21 | 22 | 23 | def prepare_batch(batch, device): 24 | inputs, labels = batch 25 | inputs_gpu = [x.to(device) for x in inputs] 26 | labels_gpu = labels.to(device) 27 | return inputs_gpu, labels_gpu 28 | 29 | 30 | def evaluate(model, data_loader, device, Ks=[20]): 31 | model.eval() 32 | num_samples = 0 33 | max_K = max(Ks) 34 | results = defaultdict(float) 35 | with th.no_grad(): 36 | for batch in data_loader: 37 | inputs, labels = prepare_batch(batch, device) 38 | logits = model(*inputs) 39 | batch_size = logits.size(0) 40 | num_samples += batch_size 41 | topk = th.topk(logits, k=max_K, sorted=True)[1] 42 | labels = labels.unsqueeze(-1) 43 | for K in Ks: 44 | hit_ranks = th.where(topk[:, :K] == labels)[1] + 1 45 | hit_ranks = hit_ranks.float().cpu() 46 | results[f'HR@{K}'] += hit_ranks.numel() 47 | results[f'MRR@{K}'] += hit_ranks.reciprocal().sum().item() 48 | results[f'NDCG@{K}'] += th.log2(1 + hit_ranks).reciprocal().sum().item() 49 | for metric in results: 50 | results[metric] /= num_samples 51 | return results 52 | 53 | 54 | def print_results(results, epochs=None): 55 | print('Metric\t' + '\t'.join(results.keys())) 56 | print( 57 | 'Value\t' + 58 | '\t'.join([f'{round(val * 100, 2):.2f}' for val in results.values()]) 59 | ) 60 | if epochs is not None: 61 | print('Epoch\t' + '\t'.join([str(epochs[metric]) for metric in results])) 62 | 63 | 64 | class TrainRunner: 65 | def __init__( 66 | self, 67 | model, 68 | train_loader, 69 | test_loader, 70 | device, 71 | lr=1e-3, 72 | weight_decay=0, 73 | patience=3, 74 | Ks=[20], 75 | ): 76 | self.model = model 77 | if weight_decay > 0: 78 | params = fix_weight_decay(model) 79 | else: 80 | params = model.parameters() 81 | self.optimizer = optim.AdamW(params, lr=lr, weight_decay=weight_decay) 82 | self.train_loader = train_loader 83 | self.test_loader = test_loader 84 | self.device = device 85 | self.epoch = 0 86 | self.batch = 0 87 | self.patience = patience 88 | self.Ks = Ks 89 | 90 | def train(self, epochs, log_interval=100): 91 | max_results = defaultdict(float) 92 | max_epochs = defaultdict(int) 93 | bad_counter = 0 94 | t = time.time() 95 | mean_loss = 0 96 | for epoch in range(epochs): 97 | self.model.train() 98 | for batch in self.train_loader: 99 | inputs, labels = prepare_batch(batch, self.device) 100 | self.optimizer.zero_grad() 101 | logits = self.model(*inputs) 102 
| loss = nn.functional.cross_entropy(logits, labels) 103 | loss.backward() 104 | self.optimizer.step() 105 | mean_loss += loss.item() / log_interval 106 | if self.batch > 0 and self.batch % log_interval == 0: 107 | print( 108 | f'Batch {self.batch}: Loss = {mean_loss:.4f}, Time Elapsed = {time.time() - t:.2f}s' 109 | ) 110 | t = time.time() 111 | mean_loss = 0 112 | self.batch += 1 113 | 114 | curr_results = evaluate( 115 | self.model, self.test_loader, self.device, Ks=self.Ks 116 | ) 117 | 118 | print(f'\nEpoch {self.epoch}:') 119 | print_results(curr_results) 120 | 121 | any_better_result = False 122 | for metric in curr_results: 123 | if curr_results[metric] > max_results[metric]: 124 | max_results[metric] = curr_results[metric] 125 | max_epochs[metric] = self.epoch 126 | any_better_result = True 127 | 128 | if any_better_result: 129 | bad_counter = 0 130 | else: 131 | bad_counter += 1 132 | if bad_counter == self.patience: 133 | break 134 | 135 | self.epoch += 1 136 | print('\nBest results') 137 | print_results(max_results, max_epochs) 138 | return max_results 139 | --------------------------------------------------------------------------------
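The sketches below are not part of the repository; they are small, self-contained illustrations of how the utilities above behave. The item IDs, timestamps, scores and the one-hour session gap are made-up example values; only the logic mirrors the code shown above.

How `AugmentedDataset` (utils/data/dataset.py) augments a session: every position from 1 onward becomes one training example whose input is the prefix before that position and whose label is the item at that position. This assumes `create_index` (defined earlier in the same file) emits one `(sessionId, labelIndex)` row per such position.

```python
# Hypothetical session of item IDs; real sessions come from train.txt / test.txt.
session = [5, 1, 8, 2]

# One (prefix, label) pair per position >= 1, mirroring AugmentedDataset.__getitem__.
pairs = [(session[:i], session[i]) for i in range(1, len(session))]
print(pairs)  # [([5], 1), ([5, 1], 8), ([5, 1, 8], 2)]
```

How `get_session_id` (utils/data/preprocess.py) splits a user's click stream: a new session starts whenever the user changes or the gap to the previous event exceeds `interval` (the interval is passed in by the caller; one hour is used here purely for illustration).

```python
import pandas as pd

df = pd.DataFrame({
    'userId': [0, 0, 0, 1, 1],
    'timestamp': pd.to_datetime([
        '2020-01-01 10:00',  # user 0, first event        -> session 0
        '2020-01-01 10:05',  # 5-minute gap, same session -> session 0
        '2020-01-01 12:00',  # gap > 1 hour, new session  -> session 1
        '2020-01-01 10:00',  # new user, new session      -> session 2
        '2020-01-01 10:10',  # 10-minute gap, same session -> session 2
    ]),
    'itemId': [10, 11, 12, 10, 13],
})

df_prev = df.shift()
is_new_session = (df.userId != df_prev.userId) | (
    df.timestamp - df_prev.timestamp > pd.Timedelta(hours=1)
)
print((is_new_session.cumsum() - 1).tolist())  # [0, 0, 1, 2, 2]
```

How `evaluate` (utils/train.py) turns the 1-based ranks of the target items within the top-K predictions into HR@K, MRR@K and NDCG@K, shown for a toy batch of 3 sessions over 6 items with K = 2.

```python
import torch as th

logits = th.tensor([
    [0.1, 0.9, 0.3, 0.2, 0.0, 0.5],  # target item 1 is ranked 1st
    [0.8, 0.1, 0.7, 0.2, 0.0, 0.6],  # target item 2 is ranked 2nd
    [0.9, 0.8, 0.1, 0.2, 0.0, 0.7],  # target item 2 misses the top-2
])
labels = th.tensor([1, 2, 2]).unsqueeze(-1)

K = 2
topk = th.topk(logits, k=K, sorted=True)[1]            # top-K item indices per session
hit_ranks = (th.where(topk == labels)[1] + 1).float()  # 1-based ranks of the hits: [1., 2.]
n = logits.size(0)

print(hit_ranks.numel() / n)                                 # HR@2   = 2/3
print(hit_ranks.reciprocal().sum().item() / n)               # MRR@2  = (1 + 1/2) / 3
print(th.log2(1 + hit_ranks).reciprocal().sum().item() / n)  # NDCG@2 = (1 + 1/log2(3)) / 3
```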