├── .gitignore ├── .style.yapf ├── LICENSE ├── MANIFEST.in ├── README.md ├── datasets ├── .gitignore └── make_datasets.ipynb ├── description.md ├── docs ├── CTR预测.xmind ├── CTR预测基础.md └── flow.md ├── examples ├── criteo_classification.py └── movielens_regression.py ├── imgs ├── AFM.png ├── AutoInt.png ├── CCPM.png ├── CIN.png ├── DCN.png ├── DIEN.png ├── DIN.png ├── DSIN.png ├── DeepFM.png ├── FFM.png ├── FGCNN.png ├── FM.png ├── FNN.png ├── InteractingLayer.png ├── MLR.png ├── NFFM.png ├── NFM.png ├── PNN.png ├── WDL.png └── xDeepFM.png ├── nbs ├── movielen.ipynb ├── test.ipynb ├── 协同过滤.ipynb └── 评测指标.ipynb ├── requirements.txt ├── setup.py ├── test.py └── torchctr ├── __init__.py ├── datasets ├── __init__.py ├── criteo.py ├── data.py ├── movielens.py ├── transform.py └── utils.py ├── layers.py ├── learner.py ├── metrics.py ├── models ├── __init__.py ├── deepfm.py ├── ffm.py ├── fm.py ├── lr.py └── mf.py └── tools.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # user define 107 | .vscode/ 108 | .idea/ -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style = pep8 3 | spaces_before_comment = 4 4 | split_before_logical_operator = true 5 | indent_width = 4 6 | column_limit = 120 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 AutuanLiu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to 
use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | include description.md 3 | include torchctr/* 4 | exclude examples/* 5 | exclude docs/* 6 | exclude datasets/* 7 | exclude nbs/* -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Recommended-System-PyTorch 2 | 3 | Recommended system(2018-2019) 4 | 5 | **参考腾讯开源工具[PyTorch On Angel, arming PyTorch with a powerful Parameter Server, which enable PyTorch to train very big models.](https://github.com/Angel-ML/PyTorch-On-Angel)** 6 | 7 | ## Data 8 | 9 | (**Fin**) 10 | 11 | 1. 
movielens data 12 | - [ml-latest](http://files.grouplens.org/datasets/movielens/ml-latest.zip) 13 | - [ml-100k](http://files.grouplens.org/datasets/movielens/ml-100k.zip) 14 | - [ml-1m](http://files.grouplens.org/datasets/movielens/ml-1m.zip) 15 | - [ml-10m](http://files.grouplens.org/datasets/movielens/ml-10m.zip) 16 | - [ml-20m](http://files.grouplens.org/datasets/movielens/ml-20m.zip) 17 | 2. Criteo data 18 | - [dac](https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz) 19 | 20 | ## Embedding 21 | 22 | (**Fin**) 23 | 24 | 1. sparse features 25 | 2. sequence features 26 | 3. dense features 27 | 28 | ## CTR 模型 29 | 30 | (**WIP**) 31 | 32 | | model | structure | 33 | | :------------: | :----------------------------: | 34 | | FM | ![FM](./imgs/FM.png) | 35 | | FFM | ![FFM](./imgs/FFM.png) | 36 | | DeepFM-201703 | ![DeepFM](./imgs/DeepFM.png) | 37 | | xDeepFM-2018 | ![xDeepFM](./imgs/xDeepFM.png) | 38 | | AFM-201708 | ![AFM](./imgs/AFM.png) | 39 | | NFM-201708 | ![NFM](./imgs/NFM.png) | 40 | | FGCNN-201904 | ![FGCNN](./imgs/FGCNN.png) | 41 | | MLR | ![MLR](./imgs/MLR.png) | 42 | | NFFM | ![NFFM](./imgs/NFFM.png) | 43 | | WDL | ![WDL](./imgs/WDL.png) | 44 | | PNN-201611 | ![PNN](./imgs/PNN.png) | 45 | | CIN | ![CIN](./imgs/CIN.png) | 46 | | CCPM-201510 | ![CCPM](./imgs/CCPM.png) | 47 | | AutoInt-201810 | ![AutoInt](./imgs/AutoInt.png) | 48 | | DCN-201708 | ![DCN](./imgs/DCN.png) | 49 | | DSIN | ![DSIN](./imgs/DSIN.png) | 50 | | FNN-201601 | ![FNN](./imgs/FNN.png) | 51 | | DIEN | ![DIEN](./imgs/DIEN.png) | 52 | | DIN-201706 | ![DIN](./imgs/DIN.png) | 53 | 54 | 55 | ## References 56 | 57 | 1. 《推荐系统实践》 58 | 2. git@github.com:dawenl/vae_cf.git 59 | 3. git@github.com:eelxpeng/CollaborativeVAE.git 60 | 4. git@github.com:hidasib/GRU4Rec.git 61 | 5. git@github.com:hexiangnan/neural_collaborative_filtering.git 62 | 6. git@github.com:NVIDIA/DeepRecommender.git 63 | 7.
[shenweichen/DeepCTR](https://github.com/shenweichen/DeepCTR) 64 | -------------------------------------------------------------------------------- /datasets/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | !*.ipynb 4 | -------------------------------------------------------------------------------- /datasets/make_datasets.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 8, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "from pathlib import Path\n", 11 | "import os" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 4, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "def get_dat_data(fp):\n", 21 | " \"\"\"读取 .dat 数据\n", 22 | "\n", 23 | " Args:\n", 24 | " fp (str or Path): 文件路径名\n", 25 | " \"\"\"\n", 26 | "\n", 27 | " if not isinstance(fp, Path):\n", 28 | " fp = Path(fp)\n", 29 | " data = pd.read_csv(fp, sep='::', header=None, engine='python')\n", 30 | " return data" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 5, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "root = Path('./ml-1m/')" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 6, 45 | "metadata": {}, 46 | "outputs": [ 47 | { 48 | "data": { 49 | "text/plain": [ 50 | "['movies.dat', 'ratings.dat', 'README', 'users.dat']" 51 | ] 52 | }, 53 | "execution_count": 6, 54 | "metadata": {}, 55 | "output_type": "execute_result" 56 | } 57 | ], 58 | "source": [ 59 | "os.listdir(root)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 9, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "movies = get_dat_data(root/r'movies.dat') # MovieID::Title::Genres\n", 69 | "ratings = get_dat_data(root/r'ratings.dat') # UserID::MovieID::Rating::Timestamp\n", 70 | "users = 
get_dat_data(root/r'users.dat') # UserID::Gender::Age::Occupation::Zip-code" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 11, 76 | "metadata": {}, 77 | "outputs": [ 78 | { 79 | "data": { 80 | "text/html": [ 81 | "
\n", 82 | "\n", 95 | "\n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | "
012
01Toy Story (1995)Animation|Children's|Comedy
12Jumanji (1995)Adventure|Children's|Fantasy
23Grumpier Old Men (1995)Comedy|Romance
34Waiting to Exhale (1995)Comedy|Drama
45Father of the Bride Part II (1995)Comedy
\n", 137 | "
" 138 | ], 139 | "text/plain": [ 140 | " 0 1 2\n", 141 | "0 1 Toy Story (1995) Animation|Children's|Comedy\n", 142 | "1 2 Jumanji (1995) Adventure|Children's|Fantasy\n", 143 | "2 3 Grumpier Old Men (1995) Comedy|Romance\n", 144 | "3 4 Waiting to Exhale (1995) Comedy|Drama\n", 145 | "4 5 Father of the Bride Part II (1995) Comedy" 146 | ] 147 | }, 148 | "execution_count": 11, 149 | "metadata": {}, 150 | "output_type": "execute_result" 151 | } 152 | ], 153 | "source": [ 154 | "movies.head()" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 12, 160 | "metadata": {}, 161 | "outputs": [ 162 | { 163 | "data": { 164 | "text/html": [ 165 | "
\n", 166 | "\n", 179 | "\n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | "
0123
0111935978300760
116613978302109
219143978301968
3134084978300275
4123555978824291
\n", 227 | "
" 228 | ], 229 | "text/plain": [ 230 | " 0 1 2 3\n", 231 | "0 1 1193 5 978300760\n", 232 | "1 1 661 3 978302109\n", 233 | "2 1 914 3 978301968\n", 234 | "3 1 3408 4 978300275\n", 235 | "4 1 2355 5 978824291" 236 | ] 237 | }, 238 | "execution_count": 12, 239 | "metadata": {}, 240 | "output_type": "execute_result" 241 | } 242 | ], 243 | "source": [ 244 | "ratings.head()" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 13, 250 | "metadata": {}, 251 | "outputs": [ 252 | { 253 | "data": { 254 | "text/html": [ 255 | "
\n", 256 | "\n", 269 | "\n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | "
01234
01F11048067
12M561670072
23M251555117
34M45702460
45M252055455
\n", 323 | "
" 324 | ], 325 | "text/plain": [ 326 | " 0 1 2 3 4\n", 327 | "0 1 F 1 10 48067\n", 328 | "1 2 M 56 16 70072\n", 329 | "2 3 M 25 15 55117\n", 330 | "3 4 M 45 7 02460\n", 331 | "4 5 M 25 20 55455" 332 | ] 333 | }, 334 | "execution_count": 13, 335 | "metadata": {}, 336 | "output_type": "execute_result" 337 | } 338 | ], 339 | "source": [ 340 | "users.head()" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": 17, 346 | "metadata": {}, 347 | "outputs": [], 348 | "source": [ 349 | "ratings_table = ratings.pivot_table(values=2, index=0, columns=1)" 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": 18, 355 | "metadata": {}, 356 | "outputs": [ 357 | { 358 | "data": { 359 | "text/html": [ 360 | "
\n", 361 | "\n", 374 | "\n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " 
\n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | "
112345678910...3943394439453946394739483949395039513952
0
15.0NaNNaNNaNNaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
2NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
3NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
4NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
5NaNNaNNaNNaNNaN2.0NaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
\n", 548 | "

5 rows × 3706 columns

\n", 549 | "
" 550 | ], 551 | "text/plain": [ 552 | "1 1 2 3 4 5 6 7 8 9 10 ... 3943 \\\n", 553 | "0 ... \n", 554 | "1 5.0 NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN \n", 555 | "2 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN \n", 556 | "3 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN \n", 557 | "4 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN \n", 558 | "5 NaN NaN NaN NaN NaN 2.0 NaN NaN NaN NaN ... NaN \n", 559 | "\n", 560 | "1 3944 3945 3946 3947 3948 3949 3950 3951 3952 \n", 561 | "0 \n", 562 | "1 NaN NaN NaN NaN NaN NaN NaN NaN NaN \n", 563 | "2 NaN NaN NaN NaN NaN NaN NaN NaN NaN \n", 564 | "3 NaN NaN NaN NaN NaN NaN NaN NaN NaN \n", 565 | "4 NaN NaN NaN NaN NaN NaN NaN NaN NaN \n", 566 | "5 NaN NaN NaN NaN NaN NaN NaN NaN NaN \n", 567 | "\n", 568 | "[5 rows x 3706 columns]" 569 | ] 570 | }, 571 | "execution_count": 18, 572 | "metadata": {}, 573 | "output_type": "execute_result" 574 | } 575 | ], 576 | "source": [ 577 | "ratings_table.head()" 578 | ] 579 | }, 580 | { 581 | "cell_type": "markdown", 582 | "metadata": {}, 583 | "source": [ 584 | "## ratings 数据集" 585 | ] 586 | }, 587 | { 588 | "cell_type": "code", 589 | "execution_count": 21, 590 | "metadata": {}, 591 | "outputs": [], 592 | "source": [ 593 | "ratings_table = ratings_table.fillna(0)" 594 | ] 595 | }, 596 | { 597 | "cell_type": "code", 598 | "execution_count": 22, 599 | "metadata": {}, 600 | "outputs": [ 601 | { 602 | "data": { 603 | "text/html": [ 604 | "
\n", 605 | "\n", 618 | "\n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | " \n", 682 | " \n", 683 | " \n", 684 | " \n", 685 | " \n", 686 | " \n", 687 | " \n", 688 | " \n", 689 | " \n", 690 | " \n", 691 | " \n", 692 | " \n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | " \n", 745 | " \n", 746 | " \n", 747 | " \n", 748 | " \n", 749 | " \n", 750 | " \n", 751 | " \n", 752 | " \n", 753 | " \n", 754 | " \n", 755 | " \n", 756 | " \n", 757 | " \n", 758 | " \n", 759 | " \n", 760 | " \n", 761 | " \n", 762 | " \n", 763 | " \n", 764 | " \n", 765 | " \n", 766 | " \n", 767 | " \n", 768 | " \n", 769 | " \n", 770 | " 
\n", 771 | " \n", 772 | " \n", 773 | " \n", 774 | " \n", 775 | " \n", 776 | " \n", 777 | " \n", 778 | " \n", 779 | " \n", 780 | " \n", 781 | " \n", 782 | " \n", 783 | " \n", 784 | " \n", 785 | " \n", 786 | " \n", 787 | " \n", 788 | " \n", 789 | " \n", 790 | " \n", 791 | "
112345678910...3943394439453946394739483949395039513952
0
15.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
20.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
30.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
40.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
50.00.00.00.00.02.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
\n", 792 | "

5 rows × 3706 columns

\n", 793 | "
" 794 | ], 795 | "text/plain": [ 796 | "1 1 2 3 4 5 6 7 8 9 10 ... 3943 \\\n", 797 | "0 ... \n", 798 | "1 5.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 \n", 799 | "2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 \n", 800 | "3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 \n", 801 | "4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 \n", 802 | "5 0.0 0.0 0.0 0.0 0.0 2.0 0.0 0.0 0.0 0.0 ... 0.0 \n", 803 | "\n", 804 | "1 3944 3945 3946 3947 3948 3949 3950 3951 3952 \n", 805 | "0 \n", 806 | "1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", 807 | "2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", 808 | "3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", 809 | "4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", 810 | "5 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", 811 | "\n", 812 | "[5 rows x 3706 columns]" 813 | ] 814 | }, 815 | "execution_count": 22, 816 | "metadata": {}, 817 | "output_type": "execute_result" 818 | } 819 | ], 820 | "source": [ 821 | "ratings_table.head()" 822 | ] 823 | }, 824 | { 825 | "cell_type": "markdown", 826 | "metadata": {}, 827 | "source": [ 828 | "保存数据(结构化数据)" 829 | ] 830 | }, 831 | { 832 | "cell_type": "code", 833 | "execution_count": 24, 834 | "metadata": {}, 835 | "outputs": [], 836 | "source": [ 837 | "ratings_table.to_csv(root/r'ratings_table.csv', encoding='utf-8')" 838 | ] 839 | }, 840 | { 841 | "cell_type": "code", 842 | "execution_count": null, 843 | "metadata": {}, 844 | "outputs": [], 845 | "source": [] 846 | } 847 | ], 848 | "metadata": { 849 | "kernelspec": { 850 | "display_name": "Python 3", 851 | "language": "python", 852 | "name": "python3" 853 | }, 854 | "language_info": { 855 | "name": "" 856 | } 857 | }, 858 | "nbformat": 4, 859 | "nbformat_minor": 2 860 | } 861 | -------------------------------------------------------------------------------- /description.md: -------------------------------------------------------------------------------- 1 | # torchctr 2 | -------------------------------------------------------------------------------- 
/docs/CTR预测.xmind: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/docs/CTR预测.xmind -------------------------------------------------------------------------------- /docs/CTR预测基础.md: -------------------------------------------------------------------------------- 1 | # CTR 预测基础 2 | 3 | 在计算广告和推荐系统中,CTR预估(click-through rate)是非常重要的一个环节,判断一个商品的是否进行推荐需要根据CTR预估的点击率来进行。在进行 4 | CTR预估时,除了单特征外,往往要对特征进行组合。对于特征组合来说,业界现在通用的做法主要有两大类:**FM系列与Tree系列** 5 | 6 | FM(Factorization Machine)主要是为了解决**数据稀疏**的情况下,特征怎样组合的问题。普通的线性模型,我们都是将各个特征独立考虑的,并没有考虑 7 | 到特征与特征之间的相互关系。但实际上,大量的特征之间是有关联的。一般的线性模型压根没有考虑特征间的关联。为了表述特征间的相关性,我们采用**多项式模型**。与线性模型相比,FM的模型就多了后面**特征组合**的部分。 8 | 9 | 10 | 11 | ## 参考文献 12 | 13 | 1. [推荐系统遇上深度学习(一)--FM模型理论和实践 - 简书](https://www.jianshu.com/p/152ae633fb00) 14 | 2. [简单易学的机器学习算法——因子分解机(Factorization Machine) - null的专栏 - CSDN博客](https://blog.csdn.net/google19890102/article/details/45532745) 15 | 3. [分解机(Factorization Machines)推荐算法原理 - 刘建平Pinard - 博客园](https://www.cnblogs.com/pinard/p/6370127.html) 16 | 4. [机器学习算法系列(26):因子分解机(FM)与场感知分解机(FFM) | Free Will](https://plushunter.github.io/2017/07/13/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E7%AE%97%E6%B3%95%E7%B3%BB%E5%88%97%EF%BC%8826%EF%BC%89%EF%BC%9A%E5%9B%A0%E5%AD%90%E5%88%86%E8%A7%A3%E6%9C%BA%EF%BC%88FM%EF%BC%89%E4%B8%8E%E5%9C%BA%E6%84%9F%E7%9F%A5%E5%88%86%E8%A7%A3%E6%9C%BA%EF%BC%88FFM%EF%BC%89/) 17 | 5. [第09章:深入浅出ML之Factorization家族 | 计算广告与机器学习](http://www.52caml.com/head_first_ml/ml-chapter9-factorization-family/) 18 | 6. [深入FFM原理与实践 - 美团技术团队](https://tech.meituan.com/2016/03/03/deep-understanding-of-ffm-principles-and-practices.html) 19 | 7. [从FFM到DeepFFM,推荐排序模型到底哪家强?](https://www.infoq.cn/article/vKoKh_ZDXcWRh8fLSsRp) 20 | 8. [FM与FFM的区别 - AI_盲的博客 - CSDN博客](https://blog.csdn.net/xwd18280820053/article/details/77529274) 21 | 9. 
[矩阵分解在推荐系统中的应用:NMF和经典SVD实战 | 乐天的个人网站](https://www.letiantian.me/2015-05-25-nmf-svd-recommend/) 22 | 10. [TF-IDF与余弦相似度 - 知乎](https://zhuanlan.zhihu.com/p/32826433) 23 | 11. [王喆的机器学习笔记 - 知乎](https://zhuanlan.zhihu.com/wangzhenotes) 24 | 12. [Embedding在深度推荐系统中的3大应用方向 - 知乎](https://zhuanlan.zhihu.com/p/67218758) 25 | 13. [谷歌、阿里、微软等10大深度学习CTR模型最全演化图谱【推荐、广告、搜索领域】 - 知乎](https://zhuanlan.zhihu.com/p/63186101) -------------------------------------------------------------------------------- /docs/flow.md: -------------------------------------------------------------------------------- 1 | download_data --> read_data --> process_data --> Dataset --> split_dataset --> DataLoader --> model --> prediction -------------------------------------------------------------------------------- /examples/criteo_classification.py: -------------------------------------------------------------------------------- 1 | from torchctr.datasets.criteo import get_criteo 2 | 3 | # step 1: download dataset 4 | get_criteo('datasets') 5 | 6 | # step 2: read data 7 | -------------------------------------------------------------------------------- /examples/movielens_regression.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | from torchctr.layers import EmbeddingLayer 4 | from torchctr.datasets import (FeatureDict, get_movielens, make_datasets, read_data, defaults, fillna, make_dataloader, 5 | RecommendDataset) 6 | 7 | # step 1: download dataset 8 | root = get_movielens('datasets', 'ml-1m') 9 | 10 | # step 2: read data 11 | users = read_data(root / 'users.dat', sep='::', names=['UserID', 'Gender', 'Age', 'Occupation', 'Zip-code']) 12 | movies = read_data(root / 'movies.dat', sep='::', names=['MovieID', 'Title', 'Genres']) 13 | ratings = read_data(root / 'ratings.dat', sep='::', names=['UserID', 'MovieID', 'Rating', 'Timestamp']) 14 | 15 | # step 3: make dataset 16 | dataset = pd.merge(ratings, users, on='UserID') 17 | 
dataset = pd.merge(dataset, movies, on='MovieID') 18 | 19 | # subsample 20 | dataset = dataset.iloc[5000:10000, :] 21 | 22 | # step 4: make features and dataloader 23 | sparse_features = ['UserID', 'Gender', 'Age', 'Occupation', 'Zip-code', 'MovieID'] 24 | sequence_features = ['Genres'] 25 | dataset = fillna(dataset, dataset.columns, fill_v='unk') 26 | features = FeatureDict(sparse_features, None, sequence_features) 27 | input, _ = make_datasets(dataset, features, sep='|') 28 | loader = make_dataloader(input, dataset['Rating'].values, batch_size=64, shuffle=True) 29 | dataset = RecommendDataset(input, dataset['Rating'].values) 30 | print(dataset) 31 | 32 | # step 5: build model 33 | model = EmbeddingLayer(input).to(defaults.device) 34 | print(model) 35 | out = model(input) 36 | print(out.shape, out, sep='\n') 37 | 38 | for data, target in loader: 39 | print(data, target) 40 | -------------------------------------------------------------------------------- /imgs/AFM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/AFM.png -------------------------------------------------------------------------------- /imgs/AutoInt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/AutoInt.png -------------------------------------------------------------------------------- /imgs/CCPM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/CCPM.png -------------------------------------------------------------------------------- /imgs/CIN.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/CIN.png -------------------------------------------------------------------------------- /imgs/DCN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/DCN.png -------------------------------------------------------------------------------- /imgs/DIEN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/DIEN.png -------------------------------------------------------------------------------- /imgs/DIN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/DIN.png -------------------------------------------------------------------------------- /imgs/DSIN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/DSIN.png -------------------------------------------------------------------------------- /imgs/DeepFM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/DeepFM.png -------------------------------------------------------------------------------- /imgs/FFM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/FFM.png -------------------------------------------------------------------------------- /imgs/FGCNN.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/FGCNN.png -------------------------------------------------------------------------------- /imgs/FM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/FM.png -------------------------------------------------------------------------------- /imgs/FNN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/FNN.png -------------------------------------------------------------------------------- /imgs/InteractingLayer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/InteractingLayer.png -------------------------------------------------------------------------------- /imgs/MLR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/MLR.png -------------------------------------------------------------------------------- /imgs/NFFM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/NFFM.png -------------------------------------------------------------------------------- /imgs/NFM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/NFM.png -------------------------------------------------------------------------------- /imgs/PNN.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/PNN.png -------------------------------------------------------------------------------- /imgs/WDL.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/WDL.png -------------------------------------------------------------------------------- /imgs/xDeepFM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/xDeepFM.png -------------------------------------------------------------------------------- /nbs/movielen.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "import torch\n", 11 | "from torchctr.layers import EmbeddingLayer, EmbeddingDropout\n", 12 | "from torchctr.datasets import (FeatureDict, get_movielens, make_datasets, read_data, defaults, fillna, make_dataloader, DataMeta)" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "metadata": {}, 18 | "source": [ 19 | "## step 1: download dataset" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "name": "stdout", 29 | "output_type": "stream", 30 | "text": [ 31 | "Downloading...\n", 32 | "Using downloaded and verified file: ../datasets/ml-1m/raw/ml-1m.zip\n", 33 | "Extracting...\n", 34 | "Done!\n" 35 | ] 36 | } 37 | ], 38 | "source": [ 39 | "root = get_movielens('../datasets', 'ml-1m')" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "## step 2: read 
data" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 3, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "users = read_data(root / 'users.dat', sep='::', names=['UserID', 'Gender', 'Age', 'Occupation', 'Zip-code'])\n", 56 | "movies = read_data(root / 'movies.dat', sep='::', names=['MovieID', 'Title', 'Genres'])\n", 57 | "ratings = read_data(root / 'ratings.dat', sep='::', names=['UserID', 'MovieID', 'Rating', 'Timestamp'])" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "## step 3: make dataset" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 4, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "dataset = pd.merge(ratings, users, on='UserID')\n", 74 | "dataset = pd.merge(dataset, movies, on='MovieID')" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "## subsample(optional)" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 5, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "dataset = dataset.iloc[5000:10000, :]" 91 | ] 92 | }, 93 | { 94 | "cell_type": "markdown", 95 | "metadata": {}, 96 | "source": [ 97 | "## step 4: make features and dataloader" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 6, 103 | "metadata": {}, 104 | "outputs": [ 105 | { 106 | "name": "stdout", 107 | "output_type": "stream", 108 | "text": [ 109 | "Making dataset Done!\n" 110 | ] 111 | } 112 | ], 113 | "source": [ 114 | "sparse_features = ['UserID', 'Gender', 'Age', 'Occupation', 'Zip-code', 'MovieID']\n", 115 | "sequence_features = ['Genres']\n", 116 | "dataset = fillna(dataset, dataset.columns, fill_v='unk')\n", 117 | "features = FeatureDict(sparse_features, None, sequence_features)\n", 118 | "input, _ = make_datasets(dataset, features, sep='|')\n", 119 | "# loader = make_dataloader(input, dataset['Rating'].values, batch_size=64, shuffle=True)" 120 | ] 121 | 
}, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": {}, 125 | "source": [ 126 | "## step 5: build model" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 7, 132 | "metadata": {}, 133 | "outputs": [ 134 | { 135 | "name": "stdout", 136 | "output_type": "stream", 137 | "text": [ 138 | "EmbeddingLayer(\n", 139 | " (sparse_embeds): ModuleList(\n", 140 | " (0): EmbeddingDropout(\n", 141 | " (emb): Embedding(3205, 147)\n", 142 | " )\n", 143 | " (1): EmbeddingDropout(\n", 144 | " (emb): Embedding(2, 2)\n", 145 | " )\n", 146 | " (2): EmbeddingDropout(\n", 147 | " (emb): Embedding(7, 5)\n", 148 | " )\n", 149 | " (3): EmbeddingDropout(\n", 150 | " (emb): Embedding(21, 9)\n", 151 | " )\n", 152 | " (4): EmbeddingDropout(\n", 153 | " (emb): Embedding(2153, 118)\n", 154 | " )\n", 155 | " (5): EmbeddingDropout(\n", 156 | " (emb): Embedding(4, 3)\n", 157 | " )\n", 158 | " )\n", 159 | " (sequence_embeds): ModuleList(\n", 160 | " (0): Embedding(7, 5)\n", 161 | " )\n", 162 | ")\n" 163 | ] 164 | } 165 | ], 166 | "source": [ 167 | "model = EmbeddingLayer(input, emb_drop=0.1).to(defaults.device)\n", 168 | "print(model)\n", 169 | "out = model(input)" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 8, 175 | "metadata": {}, 176 | "outputs": [ 177 | { 178 | "data": { 179 | "text/plain": [ 180 | "tensor([[-0.8287, 0.3714, -0.7944, 0.5302, -0.1847],\n", 181 | " [-0.8287, 0.3714, -0.7944, 0.5302, -0.1847],\n", 182 | " [-0.8287, 0.3714, -0.7944, 0.5302, -0.1847],\n", 183 | " ...,\n", 184 | " [-0.4850, -0.0608, 1.1737, 0.4636, -0.4604],\n", 185 | " [-0.4850, -0.0608, 1.1737, 0.4636, -0.4604],\n", 186 | " [-0.4850, -0.0608, 1.1737, 0.4636, -0.4604]], device='cuda:0',\n", 187 | " grad_fn=)" 188 | ] 189 | }, 190 | "execution_count": 8, 191 | "metadata": {}, 192 | "output_type": "execute_result" 193 | } 194 | ], 195 | "source": [ 196 | "out[:, -5:]" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 9, 202 | 
"metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "ly = EmbeddingDropout(torch.nn.Embedding(7, 5), 0.1)" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 10, 211 | "metadata": {}, 212 | "outputs": [ 213 | { 214 | "data": { 215 | "text/plain": [ 216 | "EmbeddingDropout(\n", 217 | " (emb): Embedding(7, 5)\n", 218 | ")" 219 | ] 220 | }, 221 | "execution_count": 10, 222 | "metadata": {}, 223 | "output_type": "execute_result" 224 | } 225 | ], 226 | "source": [ 227 | "ly" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 11, 233 | "metadata": {}, 234 | "outputs": [ 235 | { 236 | "data": { 237 | "text/plain": [ 238 | "Parameter containing:\n", 239 | "tensor([[ 0.0122, 0.7963, -0.9860, 2.5023, 0.9121],\n", 240 | " [-0.2414, -1.1864, -0.0428, 1.4428, 0.6048],\n", 241 | " [-3.1064, -0.8661, -0.4674, -0.6350, -0.0244],\n", 242 | " [-1.4281, -0.2473, 1.4546, 0.1025, -0.1300],\n", 243 | " [-2.0995, 0.1254, 0.0183, -0.6482, 0.9680],\n", 244 | " [ 0.2651, -2.6695, -0.7403, -1.3880, 0.3184],\n", 245 | " [-0.6377, 0.6056, 0.6045, -0.6367, -0.1732]], requires_grad=True)" 246 | ] 247 | }, 248 | "execution_count": 11, 249 | "metadata": {}, 250 | "output_type": "execute_result" 251 | } 252 | ], 253 | "source": [ 254 | "ly.emb.weight" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": 12, 260 | "metadata": {}, 261 | "outputs": [ 262 | { 263 | "data": { 264 | "text/plain": [ 265 | "array([[0, 0, 1, ..., 1, 0, 0],\n", 266 | " [0, 0, 1, ..., 1, 0, 0],\n", 267 | " [0, 0, 1, ..., 1, 0, 0],\n", 268 | " ...,\n", 269 | " [0, 0, 0, ..., 1, 1, 0],\n", 270 | " [0, 0, 0, ..., 1, 1, 0],\n", 271 | " [0, 0, 0, ..., 1, 1, 0]], dtype=int64)" 272 | ] 273 | }, 274 | "execution_count": 12, 275 | "metadata": {}, 276 | "output_type": "execute_result" 277 | } 278 | ], 279 | "source": [ 280 | "input.sequence_data.data" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": 13, 286 | "metadata": {}, 
287 | "outputs": [], 288 | "source": [ 289 | "y = torch.as_tensor(input.sequence_data.data).float()" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 14, 295 | "metadata": {}, 296 | "outputs": [ 297 | { 298 | "data": { 299 | "text/plain": [ 300 | "tensor([[-2.2113, -0.3294, 0.3352, -0.3936, 0.2712],\n", 301 | " [-2.2113, -0.3294, 0.3352, -0.3936, 0.2712],\n", 302 | " [-2.2113, -0.3294, 0.3352, -0.3936, 0.2712],\n", 303 | " ...,\n", 304 | " [-0.9172, -1.2720, -0.3610, -1.0181, 0.6432],\n", 305 | " [-0.9172, -1.2720, -0.3610, -1.0181, 0.6432],\n", 306 | " [-0.9172, -1.2720, -0.3610, -1.0181, 0.6432]], grad_fn=)" 307 | ] 308 | }, 309 | "execution_count": 14, 310 | "metadata": {}, 311 | "output_type": "execute_result" 312 | } 313 | ], 314 | "source": [ 315 | "y @ ly.emb.weight/y.sum(1).view(-1,1)" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": 15, 321 | "metadata": {}, 322 | "outputs": [ 323 | { 324 | "data": { 325 | "text/plain": [ 326 | "[7]" 327 | ] 328 | }, 329 | "execution_count": 15, 330 | "metadata": {}, 331 | "output_type": "execute_result" 332 | } 333 | ], 334 | "source": [ 335 | "input.sequence_data.nunique" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": 16, 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [ 344 | "from sklearn.feature_extraction.text import CountVectorizer\n", 345 | "import numpy as np\n", 346 | "def sequence_feature_encoding(data, features_names, sep: str = ','):\n", 347 | " \"\"\"Encoding for sequence features.\"\"\"\n", 348 | "\n", 349 | " if not features_names:\n", 350 | " return None\n", 351 | " data_value, nuniques = [], []\n", 352 | " for feature in features_names:\n", 353 | " vocab = set.union(*[set(str(x).strip().split(sep=sep)) for x in data[feature]])\n", 354 | " vec = CountVectorizer(vocabulary=vocab)\n", 355 | " multi_hot = vec.transform(data[feature])\n", 356 | " data_value.append(multi_hot.toarray())\n", 357 | " 
nuniques.append(len(vocab))\n", 358 | " data_meta = DataMeta(np.hstack(data_value), None, features_names, nuniques)\n", 359 | " return data_meta" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": 17, 365 | "metadata": {}, 366 | "outputs": [ 367 | { 368 | "data": { 369 | "text/html": [ 370 | "
\n", 371 | "\n", 384 | "\n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | "
UserIDMovieIDRatingTimestampGenderAgeOccupationZip-codeTitleGenres
8756239615999463533121754Ben-Hur (1959)Action|Adventure|Drama
8845286714959889599131581Ben-Hur (1959)Action|Adventure|Drama
766424170596411326412142061Princess Bride, The (1987)Action|Adventure|Comedy|Romance
94821127359741773111271637Christmas Story, A (1983)Comedy|Drama
7709249103963196099050133Princess Bride, The (1987)Action|Adventure|Comedy|Romance
81673133039573614991211483Princess Bride, The (1987)Action|Adventure|Comedy|Romance
515116742496741610711121123Bug's Life, A (1998)Animation|Children's|Comedy
6165319049756075811316614Princess Bride, The (1987)Action|Adventure|Comedy|Romance
817931480595721378604151476Princess Bride, The (1987)Action|Adventure|Comedy|Romance
96771652359674684181214217Christmas Story, A (1983)Comedy|Drama
\n", 533 | "
" 534 | ], 535 | "text/plain": [ 536 | " UserID MovieID Rating Timestamp Gender Age Occupation Zip-code \\\n", 537 | "8756 2396 1 5 999463533 1 2 17 54 \n", 538 | "8845 2867 1 4 959889599 1 3 1 581 \n", 539 | "7664 2417 0 5 964113264 1 2 14 2061 \n", 540 | "9482 1127 3 5 974177311 1 2 7 1637 \n", 541 | "7709 2491 0 3 963196099 0 5 0 133 \n", 542 | "8167 3133 0 3 957361499 1 2 1 1483 \n", 543 | "5151 1674 2 4 967416107 1 1 12 1123 \n", 544 | "6165 319 0 4 975607581 1 3 16 614 \n", 545 | "8179 3148 0 5 957213786 0 4 15 1476 \n", 546 | "9677 1652 3 5 967468418 1 2 14 217 \n", 547 | "\n", 548 | " Title Genres \n", 549 | "8756 Ben-Hur (1959) Action|Adventure|Drama \n", 550 | "8845 Ben-Hur (1959) Action|Adventure|Drama \n", 551 | "7664 Princess Bride, The (1987) Action|Adventure|Comedy|Romance \n", 552 | "9482 Christmas Story, A (1983) Comedy|Drama \n", 553 | "7709 Princess Bride, The (1987) Action|Adventure|Comedy|Romance \n", 554 | "8167 Princess Bride, The (1987) Action|Adventure|Comedy|Romance \n", 555 | "5151 Bug's Life, A (1998) Animation|Children's|Comedy \n", 556 | "6165 Princess Bride, The (1987) Action|Adventure|Comedy|Romance \n", 557 | "8179 Princess Bride, The (1987) Action|Adventure|Comedy|Romance \n", 558 | "9677 Christmas Story, A (1983) Comedy|Drama " 559 | ] 560 | }, 561 | "execution_count": 17, 562 | "metadata": {}, 563 | "output_type": "execute_result" 564 | } 565 | ], 566 | "source": [ 567 | "dataset.sample(10)" 568 | ] 569 | }, 570 | { 571 | "cell_type": "code", 572 | "execution_count": 18, 573 | "metadata": {}, 574 | "outputs": [], 575 | "source": [ 576 | "x = sequence_feature_encoding(dataset, ['Genres'], '|')" 577 | ] 578 | }, 579 | { 580 | "cell_type": "code", 581 | "execution_count": 19, 582 | "metadata": {}, 583 | "outputs": [ 584 | { 585 | "data": { 586 | "text/plain": [ 587 | "{'Action',\n", 588 | " 'Adventure',\n", 589 | " 'Animation',\n", 590 | " \"Children's\",\n", 591 | " 'Comedy',\n", 592 | " 'Drama',\n", 593 | " 'Romance'}" 594 | ] 595 
| }, 596 | "execution_count": 19, 597 | "metadata": {}, 598 | "output_type": "execute_result" 599 | } 600 | ], 601 | "source": [ 602 | "vocab = set.union(*[set(str(x).strip().split(sep='|')) for x in dataset['Genres']])\n", 603 | "vocab" 604 | ] 605 | }, 606 | { 607 | "cell_type": "code", 608 | "execution_count": 20, 609 | "metadata": {}, 610 | "outputs": [], 611 | "source": [ 612 | "vec = CountVectorizer(vocabulary=vocab)" 613 | ] 614 | }, 615 | { 616 | "cell_type": "code", 617 | "execution_count": 21, 618 | "metadata": {}, 619 | "outputs": [], 620 | "source": [ 621 | "# [','.join(str(x).strip().split(sep='|')) for x in dataset['Genres']]" 622 | ] 623 | }, 624 | { 625 | "cell_type": "code", 626 | "execution_count": 22, 627 | "metadata": {}, 628 | "outputs": [ 629 | { 630 | "data": { 631 | "text/plain": [ 632 | "CountVectorizer(analyzer='word', binary=False, decode_error='strict',\n", 633 | " dtype=, encoding='utf-8', input='content',\n", 634 | " lowercase=True, max_df=1.0, max_features=None, min_df=1,\n", 635 | " ngram_range=(1, 1), preprocessor=None, stop_words=None,\n", 636 | " strip_accents=None, token_pattern='(?u)\\\\b\\\\w\\\\w+\\\\b',\n", 637 | " tokenizer=None,\n", 638 | " vocabulary={\"Children's\", 'Romance', 'Adventure', 'Drama', 'Animation', 'Action', 'Comedy'})" 639 | ] 640 | }, 641 | "execution_count": 22, 642 | "metadata": {}, 643 | "output_type": "execute_result" 644 | } 645 | ], 646 | "source": [ 647 | "vec.fit([' '.join(str(x).strip().split(sep='|')) for x in dataset['Genres']])" 648 | ] 649 | }, 650 | { 651 | "cell_type": "code", 652 | "execution_count": 23, 653 | "metadata": {}, 654 | "outputs": [], 655 | "source": [ 656 | "multi_hot = vec.transform(['Action Comedy', 'Action'])" 657 | ] 658 | }, 659 | { 660 | "cell_type": "code", 661 | "execution_count": 24, 662 | "metadata": {}, 663 | "outputs": [ 664 | { 665 | "data": { 666 | "text/plain": [ 667 | "[array([0, 0, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 0, 0, 0])]" 668 | ] 669 | }, 670 | 
"execution_count": 24, 671 | "metadata": {}, 672 | "output_type": "execute_result" 673 | } 674 | ], 675 | "source": [ 676 | "list(multi_hot.toarray())" 677 | ] 678 | }, 679 | { 680 | "cell_type": "code", 681 | "execution_count": 25, 682 | "metadata": {}, 683 | "outputs": [ 684 | { 685 | "data": { 686 | "text/plain": [ 687 | "array([[0, 1, 1, 1, 0, 0, 0, 0],\n", 688 | " [0, 0, 0, 0, 1, 1, 1, 0],\n", 689 | " [1, 0, 0, 0, 0, 0, 0, 1]], dtype=int64)" 690 | ] 691 | }, 692 | "execution_count": 25, 693 | "metadata": {}, 694 | "output_type": "execute_result" 695 | } 696 | ], 697 | "source": [ 698 | "CountVectorizer(token_pattern=r'(?u)\\b\\w+\\b', analyzer='word').fit_transform(['1 2 31', 'a, b, c3', '中 0']).toarray()" 699 | ] 700 | }, 701 | { 702 | "cell_type": "code", 703 | "execution_count": 26, 704 | "metadata": {}, 705 | "outputs": [], 706 | "source": [ 707 | "corpus = [' '.join(str(x).strip().split(sep='|')) for x in dataset['Genres']]" 708 | ] 709 | }, 710 | { 711 | "cell_type": "code", 712 | "execution_count": 27, 713 | "metadata": {}, 714 | "outputs": [], 715 | "source": [ 716 | "vocab = set.union(*[set(x.split(' ')) for x in corpus])" 717 | ] 718 | }, 719 | { 720 | "cell_type": "code", 721 | "execution_count": 28, 722 | "metadata": {}, 723 | "outputs": [], 724 | "source": [ 725 | "vec = CountVectorizer(token_pattern=r'(?u)\\b[\\w\\']+\\b')\n", 726 | "# vec = CountVectorizer(vocabulary=vocab)" 727 | ] 728 | }, 729 | { 730 | "cell_type": "code", 731 | "execution_count": 29, 732 | "metadata": {}, 733 | "outputs": [ 734 | { 735 | "data": { 736 | "text/plain": [ 737 | "array([[0, 0, 1, ..., 1, 0, 0],\n", 738 | " [0, 0, 1, ..., 1, 0, 0],\n", 739 | " [0, 0, 1, ..., 1, 0, 0],\n", 740 | " ...,\n", 741 | " [0, 0, 0, ..., 1, 1, 0],\n", 742 | " [0, 0, 0, ..., 1, 1, 0],\n", 743 | " [0, 0, 0, ..., 1, 1, 0]], dtype=int64)" 744 | ] 745 | }, 746 | "execution_count": 29, 747 | "metadata": {}, 748 | "output_type": "execute_result" 749 | } 750 | ], 751 | "source": [ 752 | 
"vec.fit_transform(corpus).toarray()" 753 | ] 754 | }, 755 | { 756 | "cell_type": "code", 757 | "execution_count": 30, 758 | "metadata": {}, 759 | "outputs": [ 760 | { 761 | "data": { 762 | "text/plain": [ 763 | "{'animation': 2,\n", 764 | " \"children's\": 3,\n", 765 | " 'comedy': 4,\n", 766 | " 'action': 0,\n", 767 | " 'adventure': 1,\n", 768 | " 'romance': 6,\n", 769 | " 'drama': 5}" 770 | ] 771 | }, 772 | "execution_count": 30, 773 | "metadata": {}, 774 | "output_type": "execute_result" 775 | } 776 | ], 777 | "source": [ 778 | "vec.vocabulary_" 779 | ] 780 | }, 781 | { 782 | "cell_type": "code", 783 | "execution_count": null, 784 | "metadata": {}, 785 | "outputs": [], 786 | "source": [] 787 | } 788 | ], 789 | "metadata": { 790 | "kernelspec": { 791 | "display_name": "Python 3", 792 | "language": "python", 793 | "name": "python3" 794 | }, 795 | "language_info": { 796 | "codemirror_mode": { 797 | "name": "ipython", 798 | "version": 3 799 | }, 800 | "file_extension": ".py", 801 | "mimetype": "text/x-python", 802 | "name": "python", 803 | "nbconvert_exporter": "python", 804 | "pygments_lexer": "ipython3", 805 | "version": "3.6.6" 806 | } 807 | }, 808 | "nbformat": 4, 809 | "nbformat_minor": 2 810 | } 811 | -------------------------------------------------------------------------------- /nbs/test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "from torchctr.layers import EmbeddingLayer\n", 11 | "from torchctr.datasets import (FeatureDict, get_movielens, make_datasets, read_data, defaults, fillna, make_dataloader)\n", 12 | "from torchctr.datasets.data import RecommendDataset" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "metadata": {}, 19 | "outputs": [ 20 | { 21 | "name": "stdout", 22 | "output_type": "stream", 23 | "text": [ 
24 | "Downloading...\n", 25 | "Using downloaded and verified file: ../datasets\\ml-1m\\raw\\ml-1m.zip\n", 26 | "Extracting...\n", 27 | "Done!\n" 28 | ] 29 | } 30 | ], 31 | "source": [ 32 | "# step 1: download dataset\n", 33 | "root = get_movielens('../datasets', 'ml-1m')" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "# step 2: read data\n", 43 | "users = read_data(root / 'users.dat', sep='::', names=['UserID', 'Gender', 'Age', 'Occupation', 'Zip-code'])\n", 44 | "movies = read_data(root / 'movies.dat', sep='::', names=['MovieID', 'Title', 'Genres'])\n", 45 | "ratings = read_data(root / 'ratings.dat', sep='::', names=['UserID', 'MovieID', 'Rating', 'Timestamp'])" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 4, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "# step 3: make dataset\n", 55 | "dataset = pd.merge(ratings, users, on='UserID')\n", 56 | "dataset = pd.merge(dataset, movies, on='MovieID')" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 5, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "# subsample\n", 66 | "dataset = dataset.iloc[5000:10000, :]" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 6, 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "name": "stdout", 76 | "output_type": "stream", 77 | "text": [ 78 | "Making dataset Done!\n" 79 | ] 80 | } 81 | ], 82 | "source": [ 83 | "# step 4: make features and dataloader\n", 84 | "sparse_features = ['UserID', 'Gender', 'Age', 'Occupation', 'Zip-code', 'MovieID']\n", 85 | "sequence_features = ['Genres']\n", 86 | "dataset = fillna(dataset, dataset.columns, fill_v='unk')\n", 87 | "features = FeatureDict(sparse_features, None, sequence_features)\n", 88 | "input, _ = make_datasets(dataset, features, sep='|')\n", 89 | "loader = make_dataloader(input, dataset['Rating'].values, batch_size=64, shuffle=True)" 90 | ] 91 | }, 
92 | { 93 | "cell_type": "code", 94 | "execution_count": 7, 95 | "metadata": {}, 96 | "outputs": [ 97 | { 98 | "name": "stdout", 99 | "output_type": "stream", 100 | "text": [ 101 | "EmbeddingLayer(\n", 102 | " (sparse_embeds): ModuleList(\n", 103 | " (0): Embedding(3205, 147)\n", 104 | " (1): Embedding(2, 2)\n", 105 | " (2): Embedding(7, 5)\n", 106 | " (3): Embedding(21, 9)\n", 107 | " (4): Embedding(2153, 118)\n", 108 | " (5): Embedding(4, 3)\n", 109 | " )\n", 110 | " (sequence_embeds): ModuleList(\n", 111 | " (0): EmbeddingBag(7, 5, mode=mean)\n", 112 | " )\n", 113 | " (drop): Dropout(p=0.0)\n", 114 | ")\n", 115 | "torch.Size([5000, 289])\n", 116 | "tensor([[ 1.0832, -0.3852, 0.9774, ..., 0.4901, 0.2720, 0.2515],\n", 117 | " [-2.9299, 1.2940, -0.9595, ..., 0.4901, 0.2720, 0.2515],\n", 118 | " [ 2.9813, 0.2656, 0.1590, ..., 0.4901, 0.2720, 0.2515],\n", 119 | " ...,\n", 120 | " [ 0.6574, 0.1386, 0.7176, ..., 1.2335, 0.4204, 0.3841],\n", 121 | " [ 0.0121, -0.4749, -0.2445, ..., 1.2335, 0.4204, 0.3841],\n", 122 | " [-0.6250, 1.1999, 0.7947, ..., 1.2335, 0.4204, 0.3841]],\n", 123 | " grad_fn=)\n" 124 | ] 125 | } 126 | ], 127 | "source": [ 128 | "# step 5: build model\n", 129 | "model = EmbeddingLayer(input).to(defaults.device)\n", 130 | "print(model)\n", 131 | "out = model(input)\n", 132 | "print(out.shape, out, sep='\\n')\n", 133 | "# print(input)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 8, 139 | "metadata": {}, 140 | "outputs": [ 141 | { 142 | "data": { 143 | "text/plain": [ 144 | "16244" 145 | ] 146 | }, 147 | "execution_count": 8, 148 | "metadata": {}, 149 | "output_type": "execute_result" 150 | } 151 | ], 152 | "source": [ 153 | "len(input.sequence_data.data[0])" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 9, 159 | "metadata": {}, 160 | "outputs": [ 161 | { 162 | "data": { 163 | "text/plain": [ 164 | "5000" 165 | ] 166 | }, 167 | "execution_count": 9, 168 | "metadata": {}, 169 | "output_type": 
"execute_result" 170 | } 171 | ], 172 | "source": [ 173 | "len(input.sequence_data.bag_offsets[0])" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 10, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "targets = dataset['Rating'].values" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 11, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "data = RecommendDataset(input, targets)" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 12, 197 | "metadata": {}, 198 | "outputs": [ 199 | { 200 | "data": { 201 | "text/plain": [ 202 | "5000" 203 | ] 204 | }, 205 | "execution_count": 12, 206 | "metadata": {}, 207 | "output_type": "execute_result" 208 | } 209 | ], 210 | "source": [ 211 | "len(data)" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 13, 217 | "metadata": {}, 218 | "outputs": [ 219 | { 220 | "data": { 221 | "text/plain": [ 222 | "16244" 223 | ] 224 | }, 225 | "execution_count": 13, 226 | "metadata": {}, 227 | "output_type": "execute_result" 228 | } 229 | ], 230 | "source": [ 231 | "len(input.sequence_data.data[0])" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": 14, 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "import numpy as np" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 15, 246 | "metadata": {}, 247 | "outputs": [ 248 | { 249 | "name": "stdout", 250 | "output_type": "stream", 251 | "text": [ 252 | "Wall time: 3.99 ms\n" 253 | ] 254 | } 255 | ], 256 | "source": [ 257 | "%%time\n", 258 | "data1, offsets = [], np.zeros((data.lens, len(input.sequence_data.bag_offsets)), dtype=int)\n", 259 | "for x, y in zip(input.sequence_data.data, input.sequence_data.bag_offsets):\n", 260 | " tmp = []\n", 261 | " for idx, item in enumerate(y):\n", 262 | " tmp1 = []\n", 263 | " if idx == data.lens - 1:\n", 264 | " tmp1.extend(x[item:])\n", 265 | " 
else:\n", 266 | " tmp1.extend(x[item:y[idx + 1]])\n", 267 | " data1.append(tmp)" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 16, 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [ 276 | "# input.sequence_data.data/" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 17, 282 | "metadata": {}, 283 | "outputs": [ 284 | { 285 | "data": { 286 | "text/plain": [ 287 | "[]" 288 | ] 289 | }, 290 | "execution_count": 17, 291 | "metadata": {}, 292 | "output_type": "execute_result" 293 | } 294 | ], 295 | "source": [ 296 | "data1[0][3:8]" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": 18, 302 | "metadata": {}, 303 | "outputs": [ 304 | { 305 | "name": "stdout", 306 | "output_type": "stream", 307 | "text": [ 308 | "Wall time: 6.98 ms\n" 309 | ] 310 | } 311 | ], 312 | "source": [ 313 | "%%time\n", 314 | "data1, offsets = [], []\n", 315 | "for i in range(data.lens):\n", 316 | " tmp = []\n", 317 | " for x, y in zip(input.sequence_data.data, input.sequence_data.bag_offsets): \n", 318 | " if i == data.lens - 1:\n", 319 | " t = x[y[-1]:]\n", 320 | " t = [t] if isinstance(t, int) else t\n", 321 | " tmp.append(t)\n", 322 | " else:\n", 323 | " t = x[y[i]:y[i + 1]]\n", 324 | " t = [t] if isinstance(t, int) else t\n", 325 | " tmp.append(t)\n", 326 | " data1.append(tmp)" 327 | ] 328 | }, 329 | { 330 | "cell_type": "code", 331 | "execution_count": 19, 332 | "metadata": {}, 333 | "outputs": [], 334 | "source": [ 335 | "offsets=np.zeros((data.lens, len(input.sequence_data.bag_offsets)), dtype=int)" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": 20, 341 | "metadata": {}, 342 | "outputs": [ 343 | { 344 | "data": { 345 | "text/plain": [ 346 | "array([[0],\n", 347 | " [0],\n", 348 | " [0],\n", 349 | " [0]])" 350 | ] 351 | }, 352 | "execution_count": 20, 353 | "metadata": {}, 354 | "output_type": "execute_result" 355 | } 356 | ], 357 | "source": [ 358 | "offsets[3:7]" 359 
| ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": 21, 364 | "metadata": {}, 365 | "outputs": [ 366 | { 367 | "data": { 368 | "text/plain": [ 369 | "[[[2, 3, 4]], [[2, 3, 4]], [[2, 3, 4]], [[2, 3, 4]]]" 370 | ] 371 | }, 372 | "execution_count": 21, 373 | "metadata": {}, 374 | "output_type": "execute_result" 375 | } 376 | ], 377 | "source": [ 378 | "data1[3:7]" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": 22, 384 | "metadata": {}, 385 | "outputs": [ 386 | { 387 | "data": { 388 | "text/plain": [ 389 | "[2, 3, 4]" 390 | ] 391 | }, 392 | "execution_count": 22, 393 | "metadata": {}, 394 | "output_type": "execute_result" 395 | } 396 | ], 397 | "source": [ 398 | "data1[3:7][1][0]" 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "execution_count": 23, 404 | "metadata": {}, 405 | "outputs": [ 406 | { 407 | "ename": "TypeError", 408 | "evalue": "'int' object is not iterable", 409 | "output_type": "error", 410 | "traceback": [ 411 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 412 | "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)", 413 | "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msequence_data\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbag_offsets\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 2\u001b[0m \u001b[0my\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m4\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m 
\u001b[0my\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mextend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata1\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m3\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;36m7\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mt\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 414 | "\u001b[1;31mTypeError\u001b[0m: 'int' object is not iterable" 415 | ] 416 | } 417 | ], 418 | "source": [ 419 | "for i in len(input.sequence_data.bag_offsets):\n", 420 | " y = []\n", 421 | " for t in range(4):\n", 422 | " y.extend(data1[3:7][t][i])" 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "execution_count": null, 428 | "metadata": {}, 429 | "outputs": [], 430 | "source": [ 431 | "input.sequence_data.bag_offsets[0][235]" 432 | ] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": null, 437 | "metadata": {}, 438 | "outputs": [], 439 | "source": [ 440 | "# input.sequence_data" 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": null, 446 | "metadata": {}, 447 | "outputs": [], 448 | "source": [ 449 | "# data.sequence_data" 450 | ] 451 | }, 452 | { 453 | "cell_type": "code", 454 | "execution_count": null, 455 | "metadata": {}, 456 | "outputs": [], 457 | "source": [] 458 | }, 459 | { 460 | "cell_type": "code", 461 | "execution_count": null, 462 | "metadata": {}, 463 | "outputs": [], 464 | "source": [ 465 | "# for data, target in loader:\n", 466 | "# print(data, target)" 467 | ] 468 | } 469 | ], 470 | "metadata": { 471 | "kernelspec": { 472 | "display_name": "Python 3", 473 | "language": "python", 474 | "name": "python3" 475 | }, 476 | "language_info": { 477 | "codemirror_mode": { 478 | "name": "ipython", 479 | "version": 3 480 | }, 481 | "file_extension": ".py", 482 | "mimetype": "text/x-python", 483 | "name": "python", 484 | "nbconvert_exporter": "python", 485 | "pygments_lexer": "ipython3", 486 | "version": 
"3.6.9" 487 | } 488 | }, 489 | "nbformat": 4, 490 | "nbformat_minor": 4 491 | } 492 | -------------------------------------------------------------------------------- /nbs/协同过滤.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%matplotlib inline\n", 10 | "%reload_ext autoreload\n", 11 | "%autoreload 2\n", 12 | "# 多行输出\n", 13 | "from IPython.core.interactiveshell import InteractiveShell\n", 14 | "InteractiveShell.ast_node_interactivity = \"all\"" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "# 协同过滤" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "协同过滤就是指用户可以齐心协力,通过不断地和网站互动,使自己的推荐列表能够不断过滤掉自己不感兴趣的物品,从而越来越满足自己的需求" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "1. **基于用户的协同过滤算法** 这种算法给用户推荐和他兴趣相似的其他用户喜欢的物品。\n", 36 | "2. 
**基于物品的协同过滤算法** 这种算法给用户推荐和他之前喜欢的物品相似的物品\n", 37 | "\n", 38 | "**TopN** 推荐的任务是预测用户会不会对某部电影评分,而不是预测用户在准备对某部电影评分的前提下会给电影评多少分" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [] 47 | } 48 | ], 49 | "metadata": { 50 | "kernelspec": { 51 | "display_name": "Python 3", 52 | "language": "python", 53 | "name": "python3" 54 | }, 55 | "language_info": { 56 | "codemirror_mode": { 57 | "name": "ipython", 58 | "version": 3 59 | }, 60 | "file_extension": ".py", 61 | "mimetype": "text/x-python", 62 | "name": "python", 63 | "nbconvert_exporter": "python", 64 | "pygments_lexer": "ipython3", 65 | "version": "3.6.6" 66 | } 67 | }, 68 | "nbformat": 4, 69 | "nbformat_minor": 2 70 | } 71 | -------------------------------------------------------------------------------- /nbs/评测指标.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%matplotlib inline\n", 10 | "%reload_ext autoreload\n", 11 | "%autoreload 2\n", 12 | "# 多行输出\n", 13 | "from IPython.core.interactiveshell import InteractiveShell\n", 14 | "InteractiveShell.ast_node_interactivity = \"all\"" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "# 评测指标" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "1. 用户满意度(在线)\n", 29 | " \n", 30 | " - 问卷调查\n", 31 | " - 用购买率度量用户的满意度\n", 32 | " - 用户反馈界面收集用户满意度\n", 33 | " - 点击率\n", 34 | " - 用户停留时间\n", 35 | " - 转化率" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": {}, 41 | "source": [ 42 | "2. 
预测准确度(离线)\n", 43 | " - 评分预测\n", 44 | " \n", 45 | " 评分预测的预测准确度一般通过均方根误差(RMSE)和平均绝对误差(MAE)计算\n", 46 | " $$\n", 47 | " \\operatorname{RMSE}=\\frac{\\sqrt{\\sum_{u, i\\in T}\\left(r_{u i}-\\hat{r}_{u i}\\right)^{2}}}{|T|}\n", 48 | " $$\n", 49 | " $$\n", 50 | " \\mathrm{MAE}=\\frac{\\sum_{u, i \\in T}\\left|r_{u i}-\\hat{r}_{u i}\\right|}{|T|}\n", 51 | " $$\n", 52 | "\n", 53 | " $r_{u i}$ 用户 u 对物品 i 的实际评分,而 $\\hat{r}_{u i}$ 是推荐算法给出的预测评分,RMSE加大了对预测不准的用户物品评分的惩罚(平方项的惩罚),因而对系统的评测更加苛刻 " 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "3. TopN 推荐\n", 61 | "\n", 62 | "TopN推荐的预测准确率一般通过准确率(precision) /召回率(recall)度量\n", 63 | "\n", 64 | "$$\n", 65 | "\\operatorname{Recall}=\\frac{\\sum_{u \\in U}|R(u) \\cap T(u)|}{\\sum_{u \\in U}|T(u)|}\n", 66 | "$$\n", 67 | "\n", 68 | "$$\n", 69 | "\\operatorname{Precision}=\\frac{\\sum_{u \\in U}|R(u) \\cap T(u)|}{\\sum_{u \\in U}|R(u)|}\n", 70 | "$$\n", 71 | "\n", 72 | "R(u) 是根据用户在训练集上的行为给用户作出的推荐列表,而 T(u) 是用户在测试集上的行为列表\n", 73 | "\n", 74 | "为了全面评测TopN推荐的准确率和召回率,一般会选取不同的推荐列表长度 N,计算出一组准确率/召回率,然后画出准确率/召回率曲线(precision/recall curve)\n", 75 | "\n", 76 | "预测用户是否会看一部电影,应该比预测用户看了电影后会给它什么评分更加重要。TopN 预测更符合实际要求" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "4. 
覆盖率\n", 84 | "\n", 85 | "覆盖率定义为推荐系统能够推荐出来的物品占总物品集合的比例。\n", 86 | "\n", 87 | "$$\n", 88 | "\\operatorname{Coverage}=\\frac{\\left|U_{u \\in U} R(u)\\right|}{|I|}\n", 89 | "$$\n", 90 | "\n", 91 | "\n", 92 | " - 覆盖率是一个内容提供商会关心的指标.覆盖率为100%的推荐系统可以将每个物品都推荐给至少一个用户\n", 93 | " - 热门排行榜的推荐覆盖率是很低的,它只会推荐那些热门的物品,这些物品在总物品中占的比例很小\n", 94 | " - 一个好的推荐系统不仅需要有比较高的用户满意度,也要有较高的覆盖率\n", 95 | "\n", 96 | " - 信息熵\n", 97 | "\n", 98 | " $$\n", 99 | " H=-\\sum_{i=1}^{n} p(i) \\log p(i)\n", 100 | " $$\n", 101 | "\n", 102 | " p(i) 是物品 i 的流行度除以所有物品流行度之和\n", 103 | "\n", 104 | " - 基尼系数\n", 105 | "\n", 106 | " $$\n", 107 | " G=\\frac{1}{n-1} \\sum_{j=1}^{n}(2 j-n-1) p\\left(i_{j}\\right)\n", 108 | " $$\n", 109 | "\n", 110 | " $i_j$ 是按照物品流行度 p() 从小到大排序的物品列表中第 j 个物品" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": {}, 116 | "source": [ 117 | "5. 多样性\n", 118 | "\n", 119 | "多样性描述了推荐列表中物品两两之间的不相似性\n", 120 | "\n", 121 | "6. 新颖性\n", 122 | "\n", 123 | "新颖的推荐是指给用户推荐那些他们以前没有听说过的物品\n", 124 | "\n", 125 | "7. 惊喜度\n", 126 | "\n", 127 | "如果推荐结果和用户的历史兴趣不相似,但却让用户觉得满意,那么就可以说推荐结果的惊喜度很高,而推荐的新颖性仅仅取决于用户是否听说过这个推荐结果\n", 128 | "\n", 129 | "8. 信任度\n", 130 | "\n", 131 | "度量推荐系统的信任度只能通过问卷调查的方式,询问用户是否信任推荐系统的推荐结果" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": {}, 137 | "source": [ 138 | "9. 
评测维度\n", 139 | "\n", 140 | " - 用户维度\n", 141 | " - 物品维度\n", 142 | " - 时间维度\n", 143 | "\n", 144 | "在评测系统中还需要考虑评测维度,比如一个推荐算法,虽然整体性能不好,但可能在某种情况下性能比较好,而增加评测维度的目的就是知道一个算法在什么情况下性能最好。" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [] 153 | } 154 | ], 155 | "metadata": { 156 | "kernelspec": { 157 | "display_name": "Python 3", 158 | "language": "python", 159 | "name": "python3" 160 | }, 161 | "language_info": { 162 | "codemirror_mode": { 163 | "name": "ipython", 164 | "version": 3 165 | }, 166 | "file_extension": ".py", 167 | "mimetype": "text/x-python", 168 | "name": "python", 169 | "nbconvert_exporter": "python", 170 | "pygments_lexer": "ipython3", 171 | "version": "3.6.6" 172 | } 173 | }, 174 | "nbformat": 4, 175 | "nbformat_minor": 2 176 | } 177 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scikit-learn 3 | torch>=1.0 4 | torchvision 5 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """setup 2 | Copyright: 3 | ---------- 4 | Author: AutuanLiu 5 | Date: 2019/06/01 6 | """ 7 | 8 | import distutils.spawn 9 | import shlex 10 | import subprocess 11 | import sys 12 | 13 | from setuptools import find_packages, setup 14 | 15 | version = '0.1.0' 16 | 17 | if sys.argv[1] == 'release': 18 | if not distutils.spawn.find_executable('twine'): 19 | print( 20 | 'Please install twine:\n\n\tpip install twine\n', 21 | file=sys.stderr, 22 | ) 23 | sys.exit(1) 24 | 25 | commands = [ 26 | 'git pull origin master', 27 | 'git tag v{:s}'.format(version), 28 | 'git push origin master --tag', 29 | 'python setup.py sdist', 30 | 'twine upload dist/imgviz-{:s}.tar.gz'.format(version), 31 | ] 32 | for cmd in commands: 33 | print('+ 
{}'.format(cmd)) 34 | subprocess.check_call(shlex.split(cmd)) 35 | sys.exit(0) 36 | 37 | 38 | def get_install_requires(): 39 | install_requires = [] 40 | with open('requirements.txt') as f: 41 | for req in f: 42 | install_requires.append(req.strip()) 43 | return install_requires 44 | 45 | 46 | with open('description.md') as f: 47 | long_description = f.read() 48 | 49 | setup( 50 | name='torchctr', 51 | version=version, 52 | packages=find_packages(), 53 | install_requires=get_install_requires(), 54 | description='CTR prediction based on PyTorch.', 55 | long_description=long_description, 56 | long_description_content_type='text/markdown', 57 | include_package_data=True, 58 | python_requires='>=3.5', 59 | author='Autuan Liu', 60 | author_email='autuanliu@163.com', 61 | url='https://github.com/AutuanLiu/torchctr', 62 | license='MIT', 63 | classifiers=[ 64 | 'Development Status :: 5 - Production/Stable', 65 | 'Intended Audience :: Developers', 66 | 'Natural Language :: English', 67 | 'Programming Language :: Python', 68 | 'Programming Language :: Python :: 3.5', 69 | 'Programming Language :: Python :: 3.6', 70 | 'Programming Language :: Python :: 3.7', 71 | 'Programming Language :: Python :: Implementation :: CPython', 72 | 'Programming Language :: Python :: Implementation :: PyPy', 73 | ], 74 | ) -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from torchctr.datasets import get_movielens, read_data 2 | 3 | # step 1: download dataset 4 | root = get_movielens('datasets', 'ml-1m') 5 | 6 | # step 2: read data 7 | users = read_data(root / 'users.dat', sep='::', names=['UserID', 'Gender', 'Age', 'Occupation', 'Zip-code']) 8 | movies = read_data(root / 'movies.dat', sep='::', names=['MovieID', 'Title', 'Genres']) 9 | ratings = read_data(root / 'ratings.dat', sep='::', names=['UserID', 'MovieID', 'Rating', 'Timestamp']) 10 | 
-------------------------------------------------------------------------------- /torchctr/__init__.py: -------------------------------------------------------------------------------- 1 | from .layers import EmbeddingDropout, EmbeddingLayer 2 | from .tools import timmer 3 | 4 | __all__ = ['EmbeddingLayer', 'timmer', 'EmbeddingDropout'] 5 | -------------------------------------------------------------------------------- /torchctr/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .criteo import get_criteo 2 | from .data import RecommendDataset 3 | from .movielens import get_movielens 4 | from .transform import (dense_feature_scale, fillna, make_dataloader, make_datasets, sequence_feature_encoding, 5 | sparse_feature_encoding) 6 | from .utils import (DataInput, DataMeta, FeatureDict, defaults, dropout_mask, emb_sz_rule, extract_file, read_data, 7 | totensor, train_test_split) 8 | 9 | __all__ = [ 10 | 'RecommendDataset', 'extract_file', 'get_movielens', 'get_criteo', 'train_test_split', 'DataMeta', 'DataInput', 11 | 'FeatureDict', 'defaults', 'read_data', 'sequence_feature_encoding', 'dense_feature_scale', 'dropout_mask', 12 | 'sparse_feature_encoding', 'make_datasets', 'fillna', 'emb_sz_rule', 'totensor', 'make_dataloader' 13 | ] 14 | -------------------------------------------------------------------------------- /torchctr/datasets/criteo.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | from torchvision.datasets.utils import download_url, makedir_exist_ok 5 | 6 | from .utils import extract_file 7 | 8 | 9 | def get_criteo(root): 10 | """Download the Criteo data if it doesn't exist.""" 11 | 12 | url = 'https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz' 13 | 14 | raw_folder = os.path.join(root, 'criteo', 'raw') 15 | processed_folder = os.path.join(root, 'criteo', 'processed') 16 | 
makedir_exist_ok(raw_folder) 17 | makedir_exist_ok(processed_folder) 18 | 19 | # download files and extract 20 | filename = url.rpartition('/')[2] 21 | print('Downloading...') 22 | download_url(url, root=raw_folder, filename=filename, md5=None) 23 | print('Extracting...') 24 | extract_file(os.path.join(raw_folder, filename), processed_folder) 25 | print('Done!') 26 | return Path(processed_folder) 27 | -------------------------------------------------------------------------------- /torchctr/datasets/data.py: -------------------------------------------------------------------------------- 1 | from .utils import DataInput, DataMeta, totensor 2 | 3 | 4 | class RecommendDataset: 5 | """only support for sparse, sequence and dense data""" 6 | 7 | def __init__(self, input, target): 8 | pass 9 | 10 | def __getitem__(self, index): 11 | pass 12 | 13 | def __len__(self): 14 | return self.lens 15 | -------------------------------------------------------------------------------- /torchctr/datasets/movielens.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | from torchvision.datasets.utils import download_url, makedir_exist_ok 5 | 6 | from .utils import extract_file 7 | 8 | 9 | def get_movielens(root, version='ml_20m'): 10 | """Download the MovieLens data if it doesn't exist.""" 11 | 12 | urls = { 13 | 'ml-latest': 'http://files.grouplens.org/datasets/movielens/ml-latest.zip', 14 | 'ml-100k': 'http://files.grouplens.org/datasets/movielens/ml-100k.zip', 15 | 'ml-1m': 'http://files.grouplens.org/datasets/movielens/ml-1m.zip', 16 | 'ml-10m': 'http://files.grouplens.org/datasets/movielens/ml-10m.zip', 17 | 'ml-20m': 'http://files.grouplens.org/datasets/movielens/ml-20m.zip' 18 | } 19 | 20 | assert version in urls.keys(), f"version must be one of {set(urls.keys())}" 21 | raw_folder = os.path.join(root, version, 'raw') 22 | processed_folder = os.path.join(root, version, 'processed') 23 | 
makedir_exist_ok(raw_folder) 24 | makedir_exist_ok(processed_folder) 25 | 26 | # download files and extract 27 | filename = urls[version].rpartition('/')[2] 28 | print('Downloading...') 29 | download_url(urls[version], root=raw_folder, filename=filename, md5=None) 30 | print('Extracting...') 31 | extract_file(os.path.join(raw_folder, filename), processed_folder) 32 | print('Done!') 33 | return Path(os.path.join(processed_folder, version)) 34 | -------------------------------------------------------------------------------- /torchctr/datasets/transform.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from sklearn.feature_extraction.text import CountVectorizer 6 | from sklearn.preprocessing import LabelEncoder, MinMaxScaler, StandardScaler 7 | from torch.utils.data import DataLoader 8 | 9 | from .data import RecommendDataset 10 | from .utils import DataInput, DataMeta, FeatureDict, defaults 11 | 12 | 13 | def sparse_feature_encoding(data: pd.DataFrame, features_names: Union[str, List[str]]): 14 | """Encoding for sparse features.""" 15 | 16 | if not features_names: 17 | return None 18 | nuniques = [] 19 | for feat in features_names: 20 | lbe = LabelEncoder() 21 | data[feat] = lbe.fit_transform(data[feat]) 22 | nuniques.append(len(lbe.classes_)) 23 | data_meta = DataMeta(data[features_names].values, features_names, nuniques) 24 | return data_meta 25 | 26 | 27 | def sequence_feature_encoding(data: pd.DataFrame, features_names: Union[str, List[str]], sep: str = ','): 28 | """Encoding for sequence features.""" 29 | 30 | if not features_names: 31 | return None 32 | data_value, nuniques = [], [] 33 | for feature in features_names: 34 | vocab = set.union(*[set(str(x).strip().split(sep=sep)) for x in data[feature]]) 35 | vec = CountVectorizer(vocabulary=vocab) 36 | multi_hot = vec.transform(data[feature]) 37 | # data_value.append(multi_hot) 38 | 
nuniques.append(len(vocab)) 39 | data_meta = DataMeta(data_value, None, features_names, nuniques, bags_offsets) 40 | to index 41 | ret, offsets, offset = [], [], 0 42 | for row in data[feature]: 43 | offsets.append(offset) 44 | row = row.split(sep) if isinstance(row, str) else str(row).split(sep) 45 | ret.extend(map(lambda word: vec.vocabulary_[word], row)) 46 | offset += len(row) 47 | data_value.append(ret) 48 | bags_offsets.append(offsets) 49 | data_meta = DataMeta(data_value, None, features_names, nuniques, bags_offsets) 50 | return data_meta 51 | 52 | 53 | def dense_feature_scale(data: pd.DataFrame, features_names: Union[str, List[str]], scaler_instance=None): 54 | """Scaling for sparse features.""" 55 | 56 | if not features_names: 57 | return None, None 58 | scaler = scaler_instance if scaler_instance else StandardScaler() 59 | scaler = scaler.fit(data[features_names]) 60 | data[features_names] = scaler.transform(data[features_names]) 61 | data_meta = DataMeta(data[features_names].values, features_names) 62 | return data_meta, scaler 63 | 64 | 65 | def fillna(data: pd.DataFrame, features_names: Union[str, List[str]], fill_v, **kwargs): 66 | """Fill Nan with fill_v.""" 67 | 68 | data[features_names] = data[features_names].fillna(fill_v, **kwargs) 69 | return data 70 | 71 | 72 | def make_datasets(data: pd.DataFrame, features_dict=None, sep=',', scaler=None): 73 | """make dataset for df. 74 | 75 | Args: 76 | data (pd.DataFrame): data 77 | features_dict (FeatureDict): instance of FeatureDict. Defaults to None. 78 | sep (str, optional): sep for sequence. Defaults to ','. 79 | scaler: sacler for dense data. 
80 | """ 81 | 82 | pass 83 | 84 | 85 | def make_dataloader(input: DataInput, targets=None, batch_size=64, shuffle=False, drop_last=False): 86 | pass 87 | -------------------------------------------------------------------------------- /torchctr/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import os 3 | import tarfile 4 | import zipfile 5 | from collections import namedtuple 6 | from pathlib import Path 7 | from types import SimpleNamespace 8 | from functools import lru_cache 9 | 10 | import pandas as pd 11 | import torch 12 | from torch.utils.data import random_split 13 | 14 | 15 | def num_cpus() -> int: 16 | "Get number of cpus" 17 | 18 | try: 19 | return len(os.sched_getaffinity(0)) 20 | except AttributeError: 21 | return os.cpu_count() 22 | 23 | 24 | # simple name space 25 | defaults = SimpleNamespace(cpus=min(16, num_cpus()), 26 | device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) 27 | 28 | 29 | def extract_file(from_path, to_path, remove_finished=False): 30 | """https://github.com/pytorch/vision/blob/master/torchvision/datasets/utils.py""" 31 | 32 | if from_path.endswith(".zip"): 33 | with zipfile.ZipFile(from_path, 'r') as z: 34 | z.extractall(to_path) 35 | elif from_path.endswith(".tar"): 36 | with tarfile.open(from_path, 'r:') as tar: 37 | tar.extractall(path=to_path) 38 | elif from_path.endswith(".tar.gz"): 39 | with tarfile.open(from_path, 'r:gz') as tar: 40 | tar.extractall(path=to_path) 41 | elif from_path.endswith(".gz") and not from_path.endswith(".tar.gz"): 42 | to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0]) 43 | with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f: 44 | out_f.write(zip_f.read()) 45 | else: 46 | raise ValueError("Extraction of {from_path} not supported") 47 | 48 | if remove_finished: 49 | os.unlink(from_path) 50 | 51 | 52 | def train_test_split(dataset, test_rate): 53 | """Split dataset into 
two subdataset(train/test).""" 54 | 55 | test_size = round(len(dataset) * test_rate) 56 | train_size = len(dataset) - test_size 57 | return random_split(dataset, [train_size, test_size]) 58 | 59 | 60 | def read_data(filename, **kwargs): 61 | """read data from files. 62 | 63 | Args: 64 | filename (str or Path): file name. 65 | """ 66 | 67 | if not isinstance(filename, Path): 68 | filename = Path(filename) 69 | return pd.read_csv(filename, engine='python', **kwargs) 70 | 71 | 72 | def emb_sz_rule(dim: int) -> int: 73 | return min(600, round(1.6 * dim**0.56)) 74 | 75 | 76 | def totensor(x): 77 | return x if isinstance(x, torch.Tensor) else torch.as_tensor(x, device=defaults.device) 78 | 79 | 80 | def dropout_mask(x: torch.Tensor, sz: Collection[int], p: float): 81 | "Return a dropout mask of the same type as `x`, size `sz`, with probability `p` to cancel an element." 82 | 83 | return x.new(*sz).bernoulli_(1 - p).div_(1 - p) 84 | -------------------------------------------------------------------------------- /torchctr/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .datasets import DataInput, defaults, dropout_mask, emb_sz_rule, totensor 4 | from typing import Optional 5 | import torch.nn.functional as F 6 | 7 | 8 | class EmbeddingDropout(nn.Module): 9 | "Apply dropout with probabily `embed_p` to an embedding layer `emb`." 
10 | 11 | def __init__(self, emb: nn.Module, embed_p: float): 12 | super().__init__() 13 | self.emb, self.embed_p = emb, embed_p 14 | 15 | def forward(self, words: torch.LongTensor, scale: Optional[float] = None) -> torch.Tensor: 16 | if self.training and self.embed_p != 0: 17 | size = (self.emb.weight.size(0), 1) 18 | mask = dropout_mask(self.emb.weight.data, size, self.embed_p) 19 | masked_embed = self.emb.weight * mask 20 | else: 21 | masked_embed = self.emb.weight 22 | if scale: masked_embed.mul_(scale) 23 | return F.embedding(words, masked_embed, self.emb.padding_idx, self.emb.max_norm, self.emb.norm_type, 24 | self.emb.scale_grad_by_freq, self.emb.sparse) 25 | 26 | 27 | class EmbeddingLayer(nn.Module): 28 | """Embedding layer: convert sparse data to dense data. 29 | 30 | Args: 31 | emb_szs (dict): {feature: embedding size}. 32 | emb_drop (float): drop out. only support for sparse data now. 33 | x (DataInput): instance of DataInput, which includes sparse, sequence, dense data. 34 | 35 | Returns: 36 | torch.Tensor: dense data. 
37 | """ 38 | 39 | def __init__(self, x, emb_szs_dict=None, emb_drop=0, mode='mean'): 40 | super().__init__() 41 | assert mode in ['sum', 'mean'], "mode must in {'sum', 'mean'}" 42 | layers = [] 43 | self.mode = mode 44 | if x.sparse_data: 45 | nuniques = x.sparse_data.nunique 46 | if emb_szs_dict: 47 | emb_szs = [emb_szs_dict[f] for f in x.sparse_data.features] 48 | else: 49 | emb_szs = [emb_sz_rule(t) for t in nuniques] 50 | self.sparse_embeds = nn.ModuleList( 51 | [EmbeddingDropout(nn.Embedding(ni, nf), emb_drop) for ni, nf in zip(nuniques, emb_szs)]) 52 | del nuniques, emb_szs 53 | if x.sequence_data: 54 | nuniques = x.sequence_data.nunique 55 | if emb_szs_dict: 56 | emb_szs = [emb_szs_dict[f] for f in x.sequence_data.features] 57 | else: 58 | emb_szs = [self.emb_sz_rule(t) for t in nuniques] 59 | # self.sequence_embeds = nn.ModuleList( 60 | # [nn.EmbeddingBag(ni, nf, mode=mode) for ni, nf in zip(nuniques, emb_szs)]) 61 | self.sequence_embeds = nn.ModuleList( 62 | [nn.Embedding(ni, nf) for ni, nf in zip(nuniques, emb_szs)]) 63 | del nuniques, emb_szs 64 | self.drop = emb_drop 65 | 66 | def forward(self, x): 67 | out = [] 68 | if x.sparse_data: 69 | data = totensor(x.sparse_data.data).long() 70 | sparse_out = [e(data[:, i]) for i, e in enumerate(self.sparse_embeds)] 71 | sparse_out = torch.cat(sparse_out, 1) 72 | out.append(sparse_out) 73 | if x.sequence_data: 74 | nuniques = x.sequence_data.nunique 75 | data = totensor(x.sequence_data.data).float() 76 | data = data.split(nuniques, dim=1) 77 | 78 | sequence_out = [ 79 | data[i] @ e.weight if self.mode == 'sum' else data[i] @ e.weight / data[i].sum(dim=1).view(-1, 1) 80 | for i, e in enumerate(self.sequence_embeds) 81 | ] 82 | sequence_out = torch.cat(sequence_out, 1) 83 | out.append(sequence_out) 84 | if x.dense_data: 85 | dense_data = totensor(x.dense_data.data).float() 86 | out.append(dense_data) 87 | return torch.cat(out, 1) 88 | -------------------------------------------------------------------------------- 
/torchctr/learner.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from torch import nn, optim 5 | from .datasets.utils import defaults, totensor 6 | 7 | 8 | @dataclass 9 | class Learner: 10 | model: nn.Module = model.to(defaults.device) 11 | criterion: nn.Module 12 | opt: optim.Optimizer 13 | 14 | def fit(input_loader, epoch=100): 15 | pass 16 | 17 | @torch.no_grad() 18 | def predict(input): 19 | pass 20 | 21 | def save_trained_model(self, path): 22 | """save trained model's weights. 23 | Args: 24 | path (str): the path to save checkpoint. 25 | """ 26 | 27 | # save model weights 28 | torch.save(self.model.state_dict(), path) 29 | 30 | def save_model(self, path): 31 | """save model. 32 | Args: 33 | path (str): the path to save checkpoint. 34 | """ 35 | 36 | # save model weights 37 | torch.save(self.model, path) 38 | -------------------------------------------------------------------------------- /torchctr/metrics.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | class Metric(): 4 | def __init__(self, train, test): 5 | """评价指标 6 | 7 | Args: 8 | train ([type]): 训练数据 9 | test ([type]): 测试数据 10 | """ 11 | 12 | self.train = train 13 | self.test = test 14 | self.recs = self.getRec() 15 | 16 | 17 | def getRec(self): 18 | recs = {} 19 | for user in self.test: 20 | rank = self.GetRecommendation(user) 21 | recs[user] = rank 22 | return recs 23 | 24 | def precision(self): 25 | all, hit = 0, 0 26 | for user in self.test: 27 | test_items = set(self.test[user]) 28 | rank = self.recs[user] 29 | for item, score in rank: 30 | if item in test_items: 31 | hit += 1 32 | all += len(rank) 33 | return round(hit / all * 100, 2) 34 | 35 | # 定义召回率指标计算方式 36 | def recall(self): 37 | all, hit = 0, 0 38 | for user in self.test: 39 | test_items = set(self.test[user]) 40 | rank = self.recs[user] 41 | for item, score in rank: 42 | if item in 
test_items: 43 | hit += 1 44 | all += len(test_items) 45 | return round(hit / all * 100, 2) 46 | 47 | # 定义覆盖率指标计算方式 48 | def coverage(self): 49 | all_item, recom_item = set(), set() 50 | for user in self.test: 51 | for item in self.train[user]: 52 | all_item.add(item) 53 | rank = self.recs[user] 54 | for item, score in rank: 55 | recom_item.add(item) 56 | return round(len(recom_item) / len(all_item) * 100, 2) 57 | 58 | # 定义新颖度指标计算方式 59 | def popularity(self): 60 | # 计算物品的流行度 61 | item_pop = {} 62 | for user in self.train: 63 | for item in self.train[user]: 64 | if item not in item_pop: 65 | item_pop[item] = 0 66 | item_pop[item] += 1 67 | 68 | num, pop = 0, 0 69 | for user in self.test: 70 | rank = self.recs[user] 71 | for item, score in rank: 72 | # 取对数,防止因长尾问题带来的被流行物品所主导 73 | pop += math.log(1 + item_pop[item]) 74 | num += 1 75 | return round(pop / num, 6) 76 | -------------------------------------------------------------------------------- /torchctr/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/torchctr/models/__init__.py -------------------------------------------------------------------------------- /torchctr/models/deepfm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from torch import Tensor, nn 5 | from typing import List 6 | 7 | 8 | class DeepFM(nn.Module): 9 | def __init__(self, input_dim=-1, n_fields=-1, embedding_dim=-1, fc_dims=[]): 10 | super().__init__() 11 | self.input_dim = input_dim 12 | self.n_fields = n_fields 13 | self.embedding_dim = embedding_dim 14 | self.mats = [] 15 | if input_dim > 0 and embedding_dim > 0 and n_fields > 0 and fc_dims: 16 | self.bias = torch.nn.Parameter(torch.zeros(1, 1)) 17 | self.weights = torch.nn.Parameter(torch.zeros(input_dim, 1)) 18 | self.embedding = 
torch.nn.Parameter(torch.zeros(input_dim, embedding_dim)) 19 | torch.nn.init.xavier_uniform_(self.weights) 20 | torch.nn.init.xavier_uniform_(self.embedding) 21 | dim = n_fields * embedding_dim # DNN input dim 22 | # DNN FC layers 23 | for (index, fc_dim) in enumerate(fc_dims): 24 | self.mats.append(torch.nn.Parameter(torch.randn(dim, fc_dim))) # weight 25 | self.mats.append(torch.nn.Parameter(torch.randn(1, 1))) # bias 26 | torch.nn.init.xavier_uniform_(self.mats[index * 2]) 27 | dim = fc_dim 28 | 29 | def first_order(self, batch_size, index, values, bias, weights): 30 | # type: (int, Tensor, Tensor, Tensor, Tensor) -> Tensor 31 | srcs = weights.view(1, -1).mul(values.view(1, -1)).view(-1) 32 | output = torch.zeros(batch_size, dtype=torch.float32) 33 | output.scatter_add_(0, index, srcs) 34 | first = output + bias 35 | return first 36 | 37 | def second_order(self, batch_size, index, values, embeddings): 38 | # type: (int, Tensor, Tensor, Tensor) -> Tensor 39 | k = embeddings.size(1) 40 | b = batch_size 41 | 42 | # t1: [k, n] 43 | t1 = embeddings.mul(values.view(-1, 1)).transpose_(0, 1) 44 | # t1: [k, b] 45 | t1_ = torch.zeros(k, b, dtype=torch.float32) 46 | 47 | for i in range(k): 48 | t1_[i].scatter_add_(0, index, t1[i]) 49 | 50 | # t1: [k, b] 51 | t1 = t1_.pow(2) 52 | 53 | # t2: [k, n] 54 | t2 = embeddings.pow(2).mul(values.pow(2).view(-1, 1)).transpose_(0, 1) 55 | # t2: [k, b] 56 | t2_ = torch.zeros(k, b, dtype=torch.float32) 57 | for i in range(k): 58 | t2_[i].scatter_add_(0, index, t2[i]) 59 | 60 | # t2: [k, b] 61 | t2 = t2_ 62 | 63 | second = t1.sub(t2).transpose_(0, 1).sum(1).mul(0.5) 64 | return second 65 | 66 | def higher_order(self, batch_size, embeddings, mats): 67 | # type: (int, Tensor, List[Tensor]) -> Tensor 68 | # activate function: relu 69 | output = embeddings.view(batch_size, -1) 70 | 71 | for i in range(int(len(mats) / 2)): 72 | output = torch.relu(output.matmul(mats[i * 2]) + mats[i * 2 + 1]) 73 | 74 | return output.view(-1) 75 | 76 | def 
forward_(self, batch_size, index, feats, values, bias, weights, embeddings, mats): 77 | # type: (int, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, List[Tensor]) -> Tensor 78 | 79 | first = self.first_order(batch_size, index, values, bias, weights) 80 | second = self.second_order(batch_size, index, values, embeddings) 81 | higher = self.higher_order(batch_size, embeddings, mats) 82 | 83 | return torch.sigmoid(first + second + higher) 84 | 85 | def forward(self, batch_size, index, feats, values): 86 | # type: (int, Tensor, Tensor, Tensor) -> Tensor 87 | batch_first = F.embedding(feats, self.weights) 88 | batch_second = F.embedding(feats, self.embedding) 89 | return self.forward_(batch_size, index, feats, values, self.bias, batch_first, batch_second, self.mats) 90 | -------------------------------------------------------------------------------- /torchctr/models/ffm.py: -------------------------------------------------------------------------------- 1 | # https://github.com/LLSean/data-mining 2 | 3 | import os 4 | import sys 5 | import tensorflow as tf 6 | import logging 7 | logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO) 8 | import numpy as np 9 | import argparse 10 | from util import * 11 | from sklearn.metrics import * 12 | 13 | 14 | class FFM(object): 15 | def __init__(self, num_classes, k, field, lr, batch_size, feature_length, reg_l1, reg_l2, feature2field): 16 | self.num_classes = num_classes 17 | self.k = k 18 | self.field = field 19 | self.lr = lr 20 | self.batch_size = batch_size 21 | self.p = feature_length 22 | self.reg_l1 = reg_l1 23 | self.reg_l2 = reg_l2 24 | self.feature2field = feature2field 25 | 26 | def add_input(self): 27 | self.X = tf.placeholder('float32', [None, self.p]) 28 | self.y = tf.placeholder('float32', [None, num_classes]) 29 | self.keep_prob = tf.placeholder('float32') 30 | 31 | def inference(self): 32 | with tf.variable_scope('linear_layer'): 33 | w0 = tf.get_variable('w0', 
shape=[self.num_classes], initializer=tf.zeros_initializer()) 34 | self.w = tf.get_variable('w', shape=[self.p, num_classes], initializer=tf.truncated_normal_initializer(mean=0, stddev=0.01)) 35 | self.linear_terms = tf.add(tf.matmul(self.X, self.w), w0) 36 | 37 | with tf.variable_scope('interaction_layer'): 38 | self.v = tf.get_variable('v', shape=[self.p, self.field, self.k], initializer=tf.truncated_normal_initializer(mean=0, stddev=0.01)) 39 | self.interaction_terms = tf.constant(0, dtype='float32') 40 | for i in range(self.p): 41 | for j in range(i + 1, self.p): 42 | self.interaction_terms += tf.multiply( 43 | tf.reduce_sum(tf.multiply(self.v[i, self.feature2field[i]], self.v[j, self.feature2field[j]])), tf.multiply(self.X[:, i], self.X[:, j])) 44 | self.interaction_terms = tf.reshape(self.interaction_terms, [-1, 1]) 45 | self.y_out = tf.math.add(self.linear_terms, self.interaction_terms) 46 | if self.num_classes == 2: 47 | self.y_out_prob = tf.nn.sigmoid(self.y_out) 48 | elif self.num_classes > 2: 49 | self.y_out_prob = tf.nn.softmax(self.y_out) 50 | 51 | def add_loss(self): 52 | if self.num_classes == 2: 53 | cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=self.y, logits=self.y_out) 54 | elif self.num_classes > 2: 55 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=self.y, logits=self.y_out) 56 | mean_loss = tf.reduce_mean(cross_entropy) 57 | self.loss = mean_loss 58 | tf.summary.scalar('loss', self.loss) 59 | 60 | def add_accuracy(self): 61 | # accuracy 62 | self.correct_prediction = tf.equal(tf.cast(tf.argmax(self.y_out, 1), tf.float32), tf.cast(tf.argmax(self.y, 1), tf.float32)) 63 | self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32)) 64 | # add summary to accuracy 65 | tf.summary.scalar('accuracy', self.accuracy) 66 | 67 | def train(self): 68 | self.global_step = tf.Variable(0, trainable=False) 69 | optimizer = tf.train.AdagradOptimizer(self.lr) 70 | extra_update_ops = 
tf.get_collection(tf.GraphKeys.UPDATE_OPS) 71 | with tf.control_dependencies(extra_update_ops): 72 | self.train_op = optimizer.minimize(self.loss, global_step=self.global_step) 73 | 74 | def build_graph(self): 75 | self.add_input() 76 | self.inference() 77 | self.add_loss() 78 | self.add_accuracy() 79 | self.train() 80 | 81 | 82 | def train_model(sess, model, epochs=100, print_every=50): 83 | """training model""" 84 | # Merge all the summaries and write them out to train_logs 85 | merged = tf.summary.merge_all() 86 | train_writer = tf.summary.FileWriter('train_logs', sess.graph) 87 | 88 | # get number of batches 89 | num_batches = len(x_train) // batch_size + 1 90 | 91 | for e in range(epochs): 92 | num_samples = 0 93 | losses = [] 94 | for ibatch in range(num_batches): 95 | # batch_size data 96 | batch_x, batch_y = next(batch_gen) 97 | batch_y = np.array(batch_y).astype(np.float32) 98 | actual_batch_size = len(batch_y) 99 | # create a feed dictionary for this batch 100 | feed_dict = {model.X: batch_x, model.y: batch_y, model.keep_prob: 1.0} 101 | 102 | loss, accuracy, summary, global_step, _ = sess.run([model.loss, model.accuracy, merged, model.global_step, model.train_op], feed_dict=feed_dict) 103 | # aggregate performance stats 104 | losses.append(loss * actual_batch_size) 105 | num_samples += actual_batch_size 106 | # Record summaries and train.csv-set accuracy 107 | train_writer.add_summary(summary, global_step=global_step) 108 | # print training loss and accuracy 109 | if global_step % print_every == 0: 110 | logging.info("Iteration {0}: with minibatch training loss = {1} and accuracy of {2}".format(global_step, loss, accuracy)) 111 | saver.save(sess, "checkpoints/model", global_step=global_step) 112 | # print loss of one epoch 113 | total_loss = np.sum(losses) / num_samples 114 | print("Epoch {1}, Overall loss = {0:.3g}".format(total_loss, e + 1)) 115 | 116 | 117 | def test_model(sess, model, print_every=50): 118 | """training model""" 119 | # get testing 
data, iterable 120 | all_ids = [] 121 | all_clicks = [] 122 | # get number of batches 123 | num_batches = len(y_test) // batch_size + 1 124 | 125 | for ibatch in range(num_batches): 126 | # batch_size data 127 | batch_x, batch_y = next(test_batch_gen) 128 | actual_batch_size = len(batch_y) 129 | # create a feed dictionary for this15162 batch 130 | feed_dict = {model.X: batch_x, model.keep_prob: 1} 131 | # shape of [None,2] 132 | y_out_prob = sess.run([model.y_out_prob], feed_dict=feed_dict) 133 | y_out_prob = np.array(y_out_prob[0]) 134 | 135 | batch_clicks = np.argmax(y_out_prob, axis=1) 136 | 137 | batch_y = np.argmax(batch_y, axis=1) 138 | 139 | print(confusion_matrix(batch_y, batch_clicks)) 140 | ibatch += 1 141 | if ibatch % print_every == 0: 142 | logging.info("Iteration {0} has finished".format(ibatch)) 143 | 144 | 145 | def shuffle_list(data): 146 | num = data[0].shape[0] 147 | p = np.random.permutation(num) 148 | return [d[p] for d in data] 149 | 150 | 151 | def batch_generator(data, batch_size, shuffle=True): 152 | if shuffle: 153 | data = shuffle_list(data) 154 | 155 | batch_count = 0 156 | while True: 157 | if batch_count * batch_size + batch_size > len(data[0]): 158 | batch_count = 0 159 | 160 | if shuffle: 161 | data = shuffle_list(data) 162 | 163 | start = batch_count * batch_size 164 | end = start + batch_size 165 | batch_count += 1 166 | yield [d[start:end] for d in data] 167 | 168 | 169 | def check_restore_parameters(sess, saver): 170 | """ Restore the previously trained parameters if there are any. 
""" 171 | ckpt = tf.train.get_checkpoint_state("checkpoints") 172 | if ckpt and ckpt.model_checkpoint_path: 173 | logging.info("Loading parameters for the my Factorization Machine") 174 | saver.restore(sess, ckpt.model_checkpoint_path) 175 | else: 176 | logging.info("Initializing fresh parameters for the my Factorization Machine") 177 | 178 | 179 | if __name__ == '__main__': 180 | '''launching TensorBoard: tensorboard --logdir=path/to/log-directory''' 181 | # get mode (train or test) 182 | parser = argparse.ArgumentParser() 183 | parser.add_argument('--mode', help='train or test', type=str) 184 | args = parser.parse_args() 185 | mode = args.mode 186 | # length of representation 187 | x_train, y_train, x_test, y_test, feature2field = load_dataset() 188 | # initialize the model 189 | num_classes = 2 190 | lr = 0.01 191 | batch_size = 128 192 | k = 8 193 | reg_l1 = 2e-2 194 | reg_l2 = 0 195 | feature_length = x_train.shape[1] 196 | # initialize FM model 197 | batch_gen = batch_generator([x_train, y_train], batch_size) 198 | test_batch_gen = batch_generator([x_test, y_test], batch_size) 199 | model = FFM(num_classes, k, 5, lr, batch_size, feature_length, reg_l1, reg_l2, feature2field) 200 | # build graph for model 201 | model.build_graph() 202 | 203 | saver = tf.train.Saver(max_to_keep=5) 204 | 205 | with tf.Session() as sess: 206 | sess.run(tf.global_variables_initializer()) 207 | check_restore_parameters(sess, saver) 208 | if mode == 'train': 209 | print('start training...') 210 | train_model(sess, model, epochs=100, print_every=500) 211 | if mode == 'test': 212 | print('start testing...') 213 | test_model(sess, model) -------------------------------------------------------------------------------- /torchctr/models/fm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FactorizationMachine(nn.Module): 7 | def __init__(self, input_dim=-1, 
embedding_dim=-1): 8 | super().__init__() 9 | self.input_dim = input_dim 10 | self.embedding_dim = embedding_dim 11 | 12 | if input_dim > 0 and embedding_dim > 0: 13 | self.bias = torch.randn(1, 1, dtype=torch.float32) 14 | self.weights = torch.randn(input_dim, 1) 15 | self.embedding = torch.randn(input_dim, embedding_dim) 16 | self.bias = nn.Parameter(self.bias, requires_grad=True) 17 | self.weights = nn.Parameter(self.weights, requires_grad=True) 18 | self.embedding = nn.Parameter(self.embedding, requires_grad=True) 19 | nn.init.xavier_uniform_(self.weights) 20 | nn.init.xavier_uniform_(self.embedding) 21 | 22 | def first_order(self, batch_size, index, values, bias, weights): 23 | # type: (int, Tensor, Tensor, Tensor, Tensor) -> Tensor 24 | size = batch_size 25 | srcs = weights.view(1, -1).mul(values.view(1, -1)).view(-1) 26 | output = torch.zeros(size, dtype=torch.float32) 27 | output.scatter_add_(0, index, srcs) 28 | first = output + bias 29 | return first 30 | 31 | def second_order(self, batch_size, index, values, embeddings): 32 | # type: (int, Tensor, Tensor, Tensor) -> Tensor 33 | k = embeddings.size(1) 34 | b = batch_size 35 | 36 | # t1: [k, n] 37 | t1 = embeddings.mul(values.view(-1, 1)).transpose_(0, 1) 38 | # t1: [k, b] 39 | t1_ = torch.zeros(k, b, dtype=torch.float32) 40 | 41 | for i in range(k): 42 | t1_[i].scatter_add_(0, index, t1[i]) 43 | 44 | # t1: [k, b] 45 | t1 = t1_.pow(2) 46 | 47 | # t2: [k, n] 48 | t2 = embeddings.pow(2).mul(values.pow(2).view(-1, 1)).transpose_(0, 1) 49 | # t2: [k, b] 50 | t2_ = torch.zeros(k, b, dtype=torch.float32) 51 | for i in range(k): 52 | t2_[i].scatter_add_(0, index, t2[i]) 53 | 54 | # t2: [k, b] 55 | t2 = t2_ 56 | 57 | second = t1.sub(t2).transpose_(0, 1).sum(1).mul(0.5) 58 | return second 59 | 60 | def forward_(self, batch_size, index, feats, values, bias, weights, embeddings): 61 | # type: (int, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tensor 62 | first = self.first_order(batch_size, index, values, bias, 
weights) 63 | second = self.second_order(batch_size, index, values, embeddings) 64 | return torch.sigmoid(first + second) 65 | 66 | def forward(self, batch_size, index, feats, values): 67 | # type: (int, Tensor, Tensor, Tensor) -> Tensor 68 | batch_first = F.embedding(feats, self.weights) 69 | batch_second = F.embedding(feats, self.embedding) 70 | return self.forward_(batch_size, index, feats, values, self.bias, batch_first, batch_second) 71 | -------------------------------------------------------------------------------- /torchctr/models/lr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class LogisticRegression(nn.Module): 7 | def __init__(self, input_dim=-1): 8 | super().__init__() 9 | self.input_dim = input_dim 10 | assert input_dim > 0, "input_dim must be greater than 0." 11 | self.bias = torch.nn.Parameter(torch.zeros(1, 1, dtype=torch.float32), requires_grad=True) 12 | self.weights = torch.nn.Parameter(torch.randn(input_dim, 1), requires_grad=True) 13 | torch.nn.init.xavier_uniform_(self.weights) 14 | 15 | def forward_(self, batch_size, index, feats, values, bias, weight): 16 | # type: (int, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tensor 17 | index = index.view(-1) 18 | values = values.view(1, -1) 19 | srcs = weight.view(1, -1).mul(values).view(-1) 20 | output = torch.zeros(batch_size, dtype=torch.float32) 21 | output.scatter_add_(0, index, srcs) 22 | output = output + bias 23 | return torch.sigmoid(output) 24 | 25 | def forward(self, batch_size, index, feats, values): 26 | # index: sample id, feats: feature id, values: feature value 27 | # type: (int, Tensor, Tensor, Tensor) -> Tensor 28 | weight = F.embedding(feats, self.weights) 29 | bias = self.bias 30 | return self.forward_(batch_size, index, feats, values, bias, weight) 31 | -------------------------------------------------------------------------------- 
# ======================================================================
# torchctr/models/mf.py
# ======================================================================
from torch import nn


class MatrixFactorization(nn.Module):
    """Plain matrix factorization: score(u, i) = <user_factors[u], item_factors[i]>.

    Args:
        n_users: number of distinct users (rows of the user embedding).
        n_items: number of distinct items (rows of the item embedding).
        n_factors: latent-factor dimension (default 20).
    """

    def __init__(self, n_users, n_items, n_factors=20):
        super().__init__()
        # sparse=True: only rows touched in a batch get gradients, which
        # requires a sparse-capable optimizer (e.g. optim.SparseAdam).
        self.user_factors = nn.Embedding(n_users, n_factors, sparse=True)
        self.item_factors = nn.Embedding(n_items, n_factors, sparse=True)

    def forward(self, user, item):
        """Batched dot product of user and item latent vectors.

        `user` and `item` are 1-D LongTensors of equal length; returns a
        1-D tensor of predicted scores.
        """
        return (self.user_factors(user) * self.item_factors(item)).sum(1)


# ======================================================================
# torchctr/tools.py
# ======================================================================
import os
import time
from functools import wraps


def timmer(func):
    """Decorator that prints the wall-clock time each call to *func* takes."""

    # FIX: wraps() preserves func.__name__/__doc__ on the wrapper.
    @wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()
        res = func(*args, **kwargs)
        stop_time = time.time()
        print(f'Begin: {time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())}\nfunc_name: {func.__name__}\nCost: {(stop_time - start_time):.4f}s')
        return res

    return wrapper