├── .gitignore
├── .style.yapf
├── LICENSE
├── MANIFEST.in
├── README.md
├── datasets
├── .gitignore
└── make_datasets.ipynb
├── description.md
├── docs
├── CTR预测.xmind
├── CTR预测基础.md
└── flow.md
├── examples
├── criteo_classification.py
└── movielens_regression.py
├── imgs
├── AFM.png
├── AutoInt.png
├── CCPM.png
├── CIN.png
├── DCN.png
├── DIEN.png
├── DIN.png
├── DSIN.png
├── DeepFM.png
├── FFM.png
├── FGCNN.png
├── FM.png
├── FNN.png
├── InteractingLayer.png
├── MLR.png
├── NFFM.png
├── NFM.png
├── PNN.png
├── WDL.png
└── xDeepFM.png
├── nbs
├── movielen.ipynb
├── test.ipynb
├── 协同过滤.ipynb
└── 评测指标.ipynb
├── requirements.txt
├── setup.py
├── test.py
└── torchctr
├── __init__.py
├── datasets
├── __init__.py
├── criteo.py
├── data.py
├── movielens.py
├── transform.py
└── utils.py
├── layers.py
├── learner.py
├── metrics.py
├── models
├── __init__.py
├── deepfm.py
├── ffm.py
├── fm.py
├── lr.py
└── mf.py
└── tools.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | # user define
107 | .vscode/
108 | .idea/
--------------------------------------------------------------------------------
/.style.yapf:
--------------------------------------------------------------------------------
1 | [style]
2 | based_on_style = pep8
3 | spaces_before_comment = 4
4 | split_before_logical_operator = true
5 | indent_width = 4
6 | column_limit = 120
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 AutuanLiu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include requirements.txt
2 | include description.md
3 | include torchctr/*
4 | exclude examples/*
5 | exclude docs/*
6 | exclude datasets/*
7 | exclude nbs/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Recommended-System-PyTorch
2 |
3 | Recommended system(2018-2019)
4 |
5 | **参考腾讯开源工具[PyTorch On Angel, arming PyTorch with a powerful Parameter Server, which enable PyTorch to train very big models.](https://github.com/Angel-ML/PyTorch-On-Angel)**
6 |
7 | ## Data
8 |
9 | (**Fin**)
10 |
11 | 1. movielen data
12 | - [ml-latest](http://files.grouplens.org/datasets/movielens/ml-latest.zip)
13 | - [ml-100k](http://files.grouplens.org/datasets/movielens/ml-100k.zip)
14 | - [ml-1m](http://files.grouplens.org/datasets/movielens/ml-1m.zip)
15 | - [ml-10m](http://files.grouplens.org/datasets/movielens/ml-10m.zip)
16 | - [ml-20m](http://files.grouplens.org/datasets/movielens/ml-20m.zi)
17 | 2. Criteo data
18 | - [dac](https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz)
19 |
20 | ## Embedding
21 |
22 | (**Fin**)
23 |
24 | 1. sparse features
25 | 2. sequence features
26 | 3. dense features
27 |
28 | ## CTR 模型
29 |
30 | (**WIP**)
31 |
32 | | model | structure |
33 | | :------------: | :----------------------------: |
34 | | FM |  |
35 | | FFM |  |
36 | | DeepFM-201703 |  |
37 | | xDeepFM-2018 |  |
38 | | AFM-201708 |  |
39 | | NFM-201708 |  |
40 | | FGCNN-201904 |  |
41 | | MLR |  |
42 | | NFFM |  |
43 | | WDL |  |
44 | | PNN-201611 |  |
45 | | CIN |  |
46 | | CCPM-201510 |  |
47 | | AutoInt-201810 |  |
48 | | DCN-201708 |  |
49 | | DSIN |  |
50 | | FNN-201601 |  |
51 | | DIEN |  |
52 | | DIN-201706 |  |
53 |
54 |
55 | ## Refrences
56 |
57 | 1. 《推荐系统实践》
58 | 2. git@github.com:dawenl/vae_cf.git
59 | 3. git@github.com:eelxpeng/CollaborativeVAE.git
60 | 4. git@github.com:hidasib/GRU4Rec.git
61 | 5. git@github.com:hexiangnan/neural_collaborative_filtering.git
62 | 6. git@github.com:NVIDIA/DeepRecommender.git
63 | 7. [shenweichen/DeepCTR](https://github.com/shenweichen/DeepCTR)
64 |
--------------------------------------------------------------------------------
/datasets/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 | !*.ipynb
4 |
--------------------------------------------------------------------------------
/datasets/make_datasets.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 8,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd\n",
10 | "from pathlib import Path\n",
11 | "import os"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 4,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "def get_dat_data(fp):\n",
21 | " \"\"\"读取 .dat 数据\n",
22 | "\n",
23 | " Args:\n",
24 | " fp (str or Path): 文件路径名\n",
25 | " \"\"\"\n",
26 | "\n",
27 | " if not isinstance(fp, Path):\n",
28 | " fp = Path(fp)\n",
29 | " data = pd.read_csv(fp, sep='::', header=None, engine='python')\n",
30 | " return data"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": 5,
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "root = Path('./ml-1m/')"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 6,
45 | "metadata": {},
46 | "outputs": [
47 | {
48 | "data": {
49 | "text/plain": [
50 | "['movies.dat', 'ratings.dat', 'README', 'users.dat']"
51 | ]
52 | },
53 | "execution_count": 6,
54 | "metadata": {},
55 | "output_type": "execute_result"
56 | }
57 | ],
58 | "source": [
59 | "os.listdir(root)"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": 9,
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "movies = get_dat_data(root/r'movies.dat') # MovieID::Title::Genres\n",
69 | "ratings = get_dat_data(root/r'ratings.dat') # UserID::MovieID::Rating::Timestamp\n",
70 | "users = get_dat_data(root/r'users.dat') # UserID::Gender::Age::Occupation::Zip-code"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": 11,
76 | "metadata": {},
77 | "outputs": [
78 | {
79 | "data": {
80 | "text/html": [
81 | "
\n",
82 | "\n",
95 | "
\n",
96 | " \n",
97 | " \n",
98 | " | \n",
99 | " 0 | \n",
100 | " 1 | \n",
101 | " 2 | \n",
102 | "
\n",
103 | " \n",
104 | " \n",
105 | " \n",
106 | " 0 | \n",
107 | " 1 | \n",
108 | " Toy Story (1995) | \n",
109 | " Animation|Children's|Comedy | \n",
110 | "
\n",
111 | " \n",
112 | " 1 | \n",
113 | " 2 | \n",
114 | " Jumanji (1995) | \n",
115 | " Adventure|Children's|Fantasy | \n",
116 | "
\n",
117 | " \n",
118 | " 2 | \n",
119 | " 3 | \n",
120 | " Grumpier Old Men (1995) | \n",
121 | " Comedy|Romance | \n",
122 | "
\n",
123 | " \n",
124 | " 3 | \n",
125 | " 4 | \n",
126 | " Waiting to Exhale (1995) | \n",
127 | " Comedy|Drama | \n",
128 | "
\n",
129 | " \n",
130 | " 4 | \n",
131 | " 5 | \n",
132 | " Father of the Bride Part II (1995) | \n",
133 | " Comedy | \n",
134 | "
\n",
135 | " \n",
136 | "
\n",
137 | "
"
138 | ],
139 | "text/plain": [
140 | " 0 1 2\n",
141 | "0 1 Toy Story (1995) Animation|Children's|Comedy\n",
142 | "1 2 Jumanji (1995) Adventure|Children's|Fantasy\n",
143 | "2 3 Grumpier Old Men (1995) Comedy|Romance\n",
144 | "3 4 Waiting to Exhale (1995) Comedy|Drama\n",
145 | "4 5 Father of the Bride Part II (1995) Comedy"
146 | ]
147 | },
148 | "execution_count": 11,
149 | "metadata": {},
150 | "output_type": "execute_result"
151 | }
152 | ],
153 | "source": [
154 | "movies.head()"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": 12,
160 | "metadata": {},
161 | "outputs": [
162 | {
163 | "data": {
164 | "text/html": [
165 | "\n",
166 | "\n",
179 | "
\n",
180 | " \n",
181 | " \n",
182 | " | \n",
183 | " 0 | \n",
184 | " 1 | \n",
185 | " 2 | \n",
186 | " 3 | \n",
187 | "
\n",
188 | " \n",
189 | " \n",
190 | " \n",
191 | " 0 | \n",
192 | " 1 | \n",
193 | " 1193 | \n",
194 | " 5 | \n",
195 | " 978300760 | \n",
196 | "
\n",
197 | " \n",
198 | " 1 | \n",
199 | " 1 | \n",
200 | " 661 | \n",
201 | " 3 | \n",
202 | " 978302109 | \n",
203 | "
\n",
204 | " \n",
205 | " 2 | \n",
206 | " 1 | \n",
207 | " 914 | \n",
208 | " 3 | \n",
209 | " 978301968 | \n",
210 | "
\n",
211 | " \n",
212 | " 3 | \n",
213 | " 1 | \n",
214 | " 3408 | \n",
215 | " 4 | \n",
216 | " 978300275 | \n",
217 | "
\n",
218 | " \n",
219 | " 4 | \n",
220 | " 1 | \n",
221 | " 2355 | \n",
222 | " 5 | \n",
223 | " 978824291 | \n",
224 | "
\n",
225 | " \n",
226 | "
\n",
227 | "
"
228 | ],
229 | "text/plain": [
230 | " 0 1 2 3\n",
231 | "0 1 1193 5 978300760\n",
232 | "1 1 661 3 978302109\n",
233 | "2 1 914 3 978301968\n",
234 | "3 1 3408 4 978300275\n",
235 | "4 1 2355 5 978824291"
236 | ]
237 | },
238 | "execution_count": 12,
239 | "metadata": {},
240 | "output_type": "execute_result"
241 | }
242 | ],
243 | "source": [
244 | "ratings.head()"
245 | ]
246 | },
247 | {
248 | "cell_type": "code",
249 | "execution_count": 13,
250 | "metadata": {},
251 | "outputs": [
252 | {
253 | "data": {
254 | "text/html": [
255 | "\n",
256 | "\n",
269 | "
\n",
270 | " \n",
271 | " \n",
272 | " | \n",
273 | " 0 | \n",
274 | " 1 | \n",
275 | " 2 | \n",
276 | " 3 | \n",
277 | " 4 | \n",
278 | "
\n",
279 | " \n",
280 | " \n",
281 | " \n",
282 | " 0 | \n",
283 | " 1 | \n",
284 | " F | \n",
285 | " 1 | \n",
286 | " 10 | \n",
287 | " 48067 | \n",
288 | "
\n",
289 | " \n",
290 | " 1 | \n",
291 | " 2 | \n",
292 | " M | \n",
293 | " 56 | \n",
294 | " 16 | \n",
295 | " 70072 | \n",
296 | "
\n",
297 | " \n",
298 | " 2 | \n",
299 | " 3 | \n",
300 | " M | \n",
301 | " 25 | \n",
302 | " 15 | \n",
303 | " 55117 | \n",
304 | "
\n",
305 | " \n",
306 | " 3 | \n",
307 | " 4 | \n",
308 | " M | \n",
309 | " 45 | \n",
310 | " 7 | \n",
311 | " 02460 | \n",
312 | "
\n",
313 | " \n",
314 | " 4 | \n",
315 | " 5 | \n",
316 | " M | \n",
317 | " 25 | \n",
318 | " 20 | \n",
319 | " 55455 | \n",
320 | "
\n",
321 | " \n",
322 | "
\n",
323 | "
"
324 | ],
325 | "text/plain": [
326 | " 0 1 2 3 4\n",
327 | "0 1 F 1 10 48067\n",
328 | "1 2 M 56 16 70072\n",
329 | "2 3 M 25 15 55117\n",
330 | "3 4 M 45 7 02460\n",
331 | "4 5 M 25 20 55455"
332 | ]
333 | },
334 | "execution_count": 13,
335 | "metadata": {},
336 | "output_type": "execute_result"
337 | }
338 | ],
339 | "source": [
340 | "users.head()"
341 | ]
342 | },
343 | {
344 | "cell_type": "code",
345 | "execution_count": 17,
346 | "metadata": {},
347 | "outputs": [],
348 | "source": [
349 | "ratings_table = ratings.pivot_table(values=2, index=0, columns=1)"
350 | ]
351 | },
352 | {
353 | "cell_type": "code",
354 | "execution_count": 18,
355 | "metadata": {},
356 | "outputs": [
357 | {
358 | "data": {
359 | "text/html": [
360 | "\n",
361 | "\n",
374 | "
\n",
375 | " \n",
376 | " \n",
377 | " 1 | \n",
378 | " 1 | \n",
379 | " 2 | \n",
380 | " 3 | \n",
381 | " 4 | \n",
382 | " 5 | \n",
383 | " 6 | \n",
384 | " 7 | \n",
385 | " 8 | \n",
386 | " 9 | \n",
387 | " 10 | \n",
388 | " ... | \n",
389 | " 3943 | \n",
390 | " 3944 | \n",
391 | " 3945 | \n",
392 | " 3946 | \n",
393 | " 3947 | \n",
394 | " 3948 | \n",
395 | " 3949 | \n",
396 | " 3950 | \n",
397 | " 3951 | \n",
398 | " 3952 | \n",
399 | "
\n",
400 | " \n",
401 | " 0 | \n",
402 | " | \n",
403 | " | \n",
404 | " | \n",
405 | " | \n",
406 | " | \n",
407 | " | \n",
408 | " | \n",
409 | " | \n",
410 | " | \n",
411 | " | \n",
412 | " | \n",
413 | " | \n",
414 | " | \n",
415 | " | \n",
416 | " | \n",
417 | " | \n",
418 | " | \n",
419 | " | \n",
420 | " | \n",
421 | " | \n",
422 | " | \n",
423 | "
\n",
424 | " \n",
425 | " \n",
426 | " \n",
427 | " 1 | \n",
428 | " 5.0 | \n",
429 | " NaN | \n",
430 | " NaN | \n",
431 | " NaN | \n",
432 | " NaN | \n",
433 | " NaN | \n",
434 | " NaN | \n",
435 | " NaN | \n",
436 | " NaN | \n",
437 | " NaN | \n",
438 | " ... | \n",
439 | " NaN | \n",
440 | " NaN | \n",
441 | " NaN | \n",
442 | " NaN | \n",
443 | " NaN | \n",
444 | " NaN | \n",
445 | " NaN | \n",
446 | " NaN | \n",
447 | " NaN | \n",
448 | " NaN | \n",
449 | "
\n",
450 | " \n",
451 | " 2 | \n",
452 | " NaN | \n",
453 | " NaN | \n",
454 | " NaN | \n",
455 | " NaN | \n",
456 | " NaN | \n",
457 | " NaN | \n",
458 | " NaN | \n",
459 | " NaN | \n",
460 | " NaN | \n",
461 | " NaN | \n",
462 | " ... | \n",
463 | " NaN | \n",
464 | " NaN | \n",
465 | " NaN | \n",
466 | " NaN | \n",
467 | " NaN | \n",
468 | " NaN | \n",
469 | " NaN | \n",
470 | " NaN | \n",
471 | " NaN | \n",
472 | " NaN | \n",
473 | "
\n",
474 | " \n",
475 | " 3 | \n",
476 | " NaN | \n",
477 | " NaN | \n",
478 | " NaN | \n",
479 | " NaN | \n",
480 | " NaN | \n",
481 | " NaN | \n",
482 | " NaN | \n",
483 | " NaN | \n",
484 | " NaN | \n",
485 | " NaN | \n",
486 | " ... | \n",
487 | " NaN | \n",
488 | " NaN | \n",
489 | " NaN | \n",
490 | " NaN | \n",
491 | " NaN | \n",
492 | " NaN | \n",
493 | " NaN | \n",
494 | " NaN | \n",
495 | " NaN | \n",
496 | " NaN | \n",
497 | "
\n",
498 | " \n",
499 | " 4 | \n",
500 | " NaN | \n",
501 | " NaN | \n",
502 | " NaN | \n",
503 | " NaN | \n",
504 | " NaN | \n",
505 | " NaN | \n",
506 | " NaN | \n",
507 | " NaN | \n",
508 | " NaN | \n",
509 | " NaN | \n",
510 | " ... | \n",
511 | " NaN | \n",
512 | " NaN | \n",
513 | " NaN | \n",
514 | " NaN | \n",
515 | " NaN | \n",
516 | " NaN | \n",
517 | " NaN | \n",
518 | " NaN | \n",
519 | " NaN | \n",
520 | " NaN | \n",
521 | "
\n",
522 | " \n",
523 | " 5 | \n",
524 | " NaN | \n",
525 | " NaN | \n",
526 | " NaN | \n",
527 | " NaN | \n",
528 | " NaN | \n",
529 | " 2.0 | \n",
530 | " NaN | \n",
531 | " NaN | \n",
532 | " NaN | \n",
533 | " NaN | \n",
534 | " ... | \n",
535 | " NaN | \n",
536 | " NaN | \n",
537 | " NaN | \n",
538 | " NaN | \n",
539 | " NaN | \n",
540 | " NaN | \n",
541 | " NaN | \n",
542 | " NaN | \n",
543 | " NaN | \n",
544 | " NaN | \n",
545 | "
\n",
546 | " \n",
547 | "
\n",
548 | "
5 rows × 3706 columns
\n",
549 | "
"
550 | ],
551 | "text/plain": [
552 | "1 1 2 3 4 5 6 7 8 9 10 ... 3943 \\\n",
553 | "0 ... \n",
554 | "1 5.0 NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN \n",
555 | "2 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN \n",
556 | "3 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN \n",
557 | "4 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN \n",
558 | "5 NaN NaN NaN NaN NaN 2.0 NaN NaN NaN NaN ... NaN \n",
559 | "\n",
560 | "1 3944 3945 3946 3947 3948 3949 3950 3951 3952 \n",
561 | "0 \n",
562 | "1 NaN NaN NaN NaN NaN NaN NaN NaN NaN \n",
563 | "2 NaN NaN NaN NaN NaN NaN NaN NaN NaN \n",
564 | "3 NaN NaN NaN NaN NaN NaN NaN NaN NaN \n",
565 | "4 NaN NaN NaN NaN NaN NaN NaN NaN NaN \n",
566 | "5 NaN NaN NaN NaN NaN NaN NaN NaN NaN \n",
567 | "\n",
568 | "[5 rows x 3706 columns]"
569 | ]
570 | },
571 | "execution_count": 18,
572 | "metadata": {},
573 | "output_type": "execute_result"
574 | }
575 | ],
576 | "source": [
577 | "ratings_table.head()"
578 | ]
579 | },
580 | {
581 | "cell_type": "markdown",
582 | "metadata": {},
583 | "source": [
584 | "## ratings 数据集"
585 | ]
586 | },
587 | {
588 | "cell_type": "code",
589 | "execution_count": 21,
590 | "metadata": {},
591 | "outputs": [],
592 | "source": [
593 | "ratings_table = ratings_table.fillna(0)"
594 | ]
595 | },
596 | {
597 | "cell_type": "code",
598 | "execution_count": 22,
599 | "metadata": {},
600 | "outputs": [
601 | {
602 | "data": {
603 | "text/html": [
604 | "\n",
605 | "\n",
618 | "
\n",
619 | " \n",
620 | " \n",
621 | " 1 | \n",
622 | " 1 | \n",
623 | " 2 | \n",
624 | " 3 | \n",
625 | " 4 | \n",
626 | " 5 | \n",
627 | " 6 | \n",
628 | " 7 | \n",
629 | " 8 | \n",
630 | " 9 | \n",
631 | " 10 | \n",
632 | " ... | \n",
633 | " 3943 | \n",
634 | " 3944 | \n",
635 | " 3945 | \n",
636 | " 3946 | \n",
637 | " 3947 | \n",
638 | " 3948 | \n",
639 | " 3949 | \n",
640 | " 3950 | \n",
641 | " 3951 | \n",
642 | " 3952 | \n",
643 | "
\n",
644 | " \n",
645 | " 0 | \n",
646 | " | \n",
647 | " | \n",
648 | " | \n",
649 | " | \n",
650 | " | \n",
651 | " | \n",
652 | " | \n",
653 | " | \n",
654 | " | \n",
655 | " | \n",
656 | " | \n",
657 | " | \n",
658 | " | \n",
659 | " | \n",
660 | " | \n",
661 | " | \n",
662 | " | \n",
663 | " | \n",
664 | " | \n",
665 | " | \n",
666 | " | \n",
667 | "
\n",
668 | " \n",
669 | " \n",
670 | " \n",
671 | " 1 | \n",
672 | " 5.0 | \n",
673 | " 0.0 | \n",
674 | " 0.0 | \n",
675 | " 0.0 | \n",
676 | " 0.0 | \n",
677 | " 0.0 | \n",
678 | " 0.0 | \n",
679 | " 0.0 | \n",
680 | " 0.0 | \n",
681 | " 0.0 | \n",
682 | " ... | \n",
683 | " 0.0 | \n",
684 | " 0.0 | \n",
685 | " 0.0 | \n",
686 | " 0.0 | \n",
687 | " 0.0 | \n",
688 | " 0.0 | \n",
689 | " 0.0 | \n",
690 | " 0.0 | \n",
691 | " 0.0 | \n",
692 | " 0.0 | \n",
693 | "
\n",
694 | " \n",
695 | " 2 | \n",
696 | " 0.0 | \n",
697 | " 0.0 | \n",
698 | " 0.0 | \n",
699 | " 0.0 | \n",
700 | " 0.0 | \n",
701 | " 0.0 | \n",
702 | " 0.0 | \n",
703 | " 0.0 | \n",
704 | " 0.0 | \n",
705 | " 0.0 | \n",
706 | " ... | \n",
707 | " 0.0 | \n",
708 | " 0.0 | \n",
709 | " 0.0 | \n",
710 | " 0.0 | \n",
711 | " 0.0 | \n",
712 | " 0.0 | \n",
713 | " 0.0 | \n",
714 | " 0.0 | \n",
715 | " 0.0 | \n",
716 | " 0.0 | \n",
717 | "
\n",
718 | " \n",
719 | " 3 | \n",
720 | " 0.0 | \n",
721 | " 0.0 | \n",
722 | " 0.0 | \n",
723 | " 0.0 | \n",
724 | " 0.0 | \n",
725 | " 0.0 | \n",
726 | " 0.0 | \n",
727 | " 0.0 | \n",
728 | " 0.0 | \n",
729 | " 0.0 | \n",
730 | " ... | \n",
731 | " 0.0 | \n",
732 | " 0.0 | \n",
733 | " 0.0 | \n",
734 | " 0.0 | \n",
735 | " 0.0 | \n",
736 | " 0.0 | \n",
737 | " 0.0 | \n",
738 | " 0.0 | \n",
739 | " 0.0 | \n",
740 | " 0.0 | \n",
741 | "
\n",
742 | " \n",
743 | " 4 | \n",
744 | " 0.0 | \n",
745 | " 0.0 | \n",
746 | " 0.0 | \n",
747 | " 0.0 | \n",
748 | " 0.0 | \n",
749 | " 0.0 | \n",
750 | " 0.0 | \n",
751 | " 0.0 | \n",
752 | " 0.0 | \n",
753 | " 0.0 | \n",
754 | " ... | \n",
755 | " 0.0 | \n",
756 | " 0.0 | \n",
757 | " 0.0 | \n",
758 | " 0.0 | \n",
759 | " 0.0 | \n",
760 | " 0.0 | \n",
761 | " 0.0 | \n",
762 | " 0.0 | \n",
763 | " 0.0 | \n",
764 | " 0.0 | \n",
765 | "
\n",
766 | " \n",
767 | " 5 | \n",
768 | " 0.0 | \n",
769 | " 0.0 | \n",
770 | " 0.0 | \n",
771 | " 0.0 | \n",
772 | " 0.0 | \n",
773 | " 2.0 | \n",
774 | " 0.0 | \n",
775 | " 0.0 | \n",
776 | " 0.0 | \n",
777 | " 0.0 | \n",
778 | " ... | \n",
779 | " 0.0 | \n",
780 | " 0.0 | \n",
781 | " 0.0 | \n",
782 | " 0.0 | \n",
783 | " 0.0 | \n",
784 | " 0.0 | \n",
785 | " 0.0 | \n",
786 | " 0.0 | \n",
787 | " 0.0 | \n",
788 | " 0.0 | \n",
789 | "
\n",
790 | " \n",
791 | "
\n",
792 | "
5 rows × 3706 columns
\n",
793 | "
"
794 | ],
795 | "text/plain": [
796 | "1 1 2 3 4 5 6 7 8 9 10 ... 3943 \\\n",
797 | "0 ... \n",
798 | "1 5.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 \n",
799 | "2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 \n",
800 | "3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 \n",
801 | "4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 \n",
802 | "5 0.0 0.0 0.0 0.0 0.0 2.0 0.0 0.0 0.0 0.0 ... 0.0 \n",
803 | "\n",
804 | "1 3944 3945 3946 3947 3948 3949 3950 3951 3952 \n",
805 | "0 \n",
806 | "1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
807 | "2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
808 | "3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
809 | "4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
810 | "5 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
811 | "\n",
812 | "[5 rows x 3706 columns]"
813 | ]
814 | },
815 | "execution_count": 22,
816 | "metadata": {},
817 | "output_type": "execute_result"
818 | }
819 | ],
820 | "source": [
821 | "ratings_table.head()"
822 | ]
823 | },
824 | {
825 | "cell_type": "markdown",
826 | "metadata": {},
827 | "source": [
828 | "保存数据(结构化数据)"
829 | ]
830 | },
831 | {
832 | "cell_type": "code",
833 | "execution_count": 24,
834 | "metadata": {},
835 | "outputs": [],
836 | "source": [
837 | "ratings_table.to_csv(root/r'ratings_table.csv', encoding='utf-8')"
838 | ]
839 | },
840 | {
841 | "cell_type": "code",
842 | "execution_count": null,
843 | "metadata": {},
844 | "outputs": [],
845 | "source": []
846 | }
847 | ],
848 | "metadata": {
849 | "kernelspec": {
850 | "display_name": "Python 3",
851 | "language": "python",
852 | "name": "python3"
853 | },
854 | "language_info": {
855 | "name": ""
856 | }
857 | },
858 | "nbformat": 4,
859 | "nbformat_minor": 2
860 | }
861 |
--------------------------------------------------------------------------------
/description.md:
--------------------------------------------------------------------------------
1 | # torchctr
2 |
--------------------------------------------------------------------------------
/docs/CTR预测.xmind:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/docs/CTR预测.xmind
--------------------------------------------------------------------------------
/docs/CTR预测基础.md:
--------------------------------------------------------------------------------
1 | # CTR 预测基础
2 |
3 | 在计算广告和推荐系统中,CTR预估(click-through rate)是非常重要的一个环节,判断一个商品的是否进行推荐需要根据CTR预估的点击率来进行。在进行
4 | CTR预估时,除了单特征外,往往要对特征进行组合。对于特征组合来说,业界现在通用的做法主要有两大类:**FM系列与Tree系列**
5 |
6 | FM(Factorization Machine)主要是为了解决**数据稀疏**的情况下,特征怎样组合的问题。普通的线性模型,我们都是将各个特征独立考虑的,并没有考虑
7 | 到特征与特征之间的相互关系。但实际上,大量的特征之间是有关联的。一般的线性模型压根没有考虑特征间的关联。为了表述特征间的相关性,我们采用**多项式模型**。与线性模型相比,FM的模型就多了后面**特征组合**的部分。
8 |
9 |
10 |
11 | ## 参考文献
12 |
13 | 1. [推荐系统遇上深度学习(一)--FM模型理论和实践 - 简书](https://www.jianshu.com/p/152ae633fb00)
14 | 2. [简单易学的机器学习算法——因子分解机(Factorization Machine) - null的专栏 - CSDN博客](https://blog.csdn.net/google19890102/article/details/45532745)
15 | 3. [分解机(Factorization Machines)推荐算法原理 - 刘建平Pinard - 博客园](https://www.cnblogs.com/pinard/p/6370127.html)
16 | 4. [机器学习算法系列(26):因子分解机(FM)与场感知分解机(FFM) | Free Will](https://plushunter.github.io/2017/07/13/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E7%AE%97%E6%B3%95%E7%B3%BB%E5%88%97%EF%BC%8826%EF%BC%89%EF%BC%9A%E5%9B%A0%E5%AD%90%E5%88%86%E8%A7%A3%E6%9C%BA%EF%BC%88FM%EF%BC%89%E4%B8%8E%E5%9C%BA%E6%84%9F%E7%9F%A5%E5%88%86%E8%A7%A3%E6%9C%BA%EF%BC%88FFM%EF%BC%89/)
17 | 5. [第09章:深入浅出ML之Factorization家族 | 计算广告与机器学习](http://www.52caml.com/head_first_ml/ml-chapter9-factorization-family/)
18 | 6. [深入FFM原理与实践 - 美团技术团队](https://tech.meituan.com/2016/03/03/deep-understanding-of-ffm-principles-and-practices.html)
19 | 7. [从FFM到DeepFFM,推荐排序模型到底哪家强?](https://www.infoq.cn/article/vKoKh_ZDXcWRh8fLSsRp)
20 | 8. [FM与FFM的区别 - AI_盲的博客 - CSDN博客](https://blog.csdn.net/xwd18280820053/article/details/77529274)
21 | 9. [矩阵分解在推荐系统中的应用:NMF和经典SVD实战 | 乐天的个人网站](https://www.letiantian.me/2015-05-25-nmf-svd-recommend/)
22 | 10. [TF-IDF与余弦相似度 - 知乎](https://zhuanlan.zhihu.com/p/32826433)
23 | 11. [王喆的机器学习笔记 - 知乎](https://zhuanlan.zhihu.com/wangzhenotes)
24 | 12. [Embedding在深度推荐系统中的3大应用方向 - 知乎](https://zhuanlan.zhihu.com/p/67218758)
25 | 13. [谷歌、阿里、微软等10大深度学习CTR模型最全演化图谱【推荐、广告、搜索领域】 - 知乎](https://zhuanlan.zhihu.com/p/63186101)
--------------------------------------------------------------------------------
/docs/flow.md:
--------------------------------------------------------------------------------
1 | download_data --> read_data --> process_data --> Dataset --> split_dataset --> DataLoader --> model --> prediction
--------------------------------------------------------------------------------
/examples/criteo_classification.py:
--------------------------------------------------------------------------------
1 | from torchctr.datasets.criteo import get_criteo
2 | 
3 | # step 1: download dataset -- Criteo DAC (presumably fetched/extracted into ./datasets; see README for the source URL)
4 | get_criteo('datasets')
5 | 
6 | # step 2: read data (TODO: not implemented yet -- this example only downloads the data)
7 | 
--------------------------------------------------------------------------------
/examples/movielens_regression.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import torch
3 | from torchctr.layers import EmbeddingLayer
4 | from torchctr.datasets import (FeatureDict, get_movielens, make_datasets, read_data, defaults, fillna, make_dataloader,
5 | RecommendDataset)
6 | 
7 | # step 1: download dataset -- returns the path to the extracted ml-1m directory
8 | root = get_movielens('datasets', 'ml-1m')
9 | 
10 | # step 2: read data -- ml-1m .dat files are '::'-separated and have no header row
11 | users = read_data(root / 'users.dat', sep='::', names=['UserID', 'Gender', 'Age', 'Occupation', 'Zip-code'])
12 | movies = read_data(root / 'movies.dat', sep='::', names=['MovieID', 'Title', 'Genres'])
13 | ratings = read_data(root / 'ratings.dat', sep='::', names=['UserID', 'MovieID', 'Rating', 'Timestamp'])
14 | 
15 | # step 3: make dataset -- join user attributes, then movie attributes, onto each rating row
16 | dataset = pd.merge(ratings, users, on='UserID')
17 | dataset = pd.merge(dataset, movies, on='MovieID')
18 | 
19 | # subsample -- keep 5000 rows so the example runs quickly
20 | dataset = dataset.iloc[5000:10000, :]
21 | 
22 | # step 4: make features and dataloader
23 | sparse_features = ['UserID', 'Gender', 'Age', 'Occupation', 'Zip-code', 'MovieID']    # categorical columns
24 | sequence_features = ['Genres']    # multi-valued column, '|'-separated (e.g. "Comedy|Romance")
25 | dataset = fillna(dataset, dataset.columns, fill_v='unk')    # replace missing values with the sentinel 'unk'
26 | features = FeatureDict(sparse_features, None, sequence_features)    # NOTE(review): middle arg is presumably dense features (none here) -- confirm against FeatureDict
27 | input, _ = make_datasets(dataset, features, sep='|')    # NOTE(review): `input` shadows the builtin of the same name
28 | loader = make_dataloader(input, dataset['Rating'].values, batch_size=64, shuffle=True)
29 | dataset = RecommendDataset(input, dataset['Rating'].values)    # NOTE(review): rebinds `dataset`; the DataFrame is no longer reachable after this line
30 | print(dataset)
31 | 
32 | # step 5: build model -- embedding layer over the encoded features, on the configured device
33 | model = EmbeddingLayer(input).to(defaults.device)
34 | print(model)
35 | out = model(input)
36 | print(out.shape, out, sep='\n')
37 | 
38 | for data, target in loader:
39 | print(data, target)
40 | 
--------------------------------------------------------------------------------
/imgs/AFM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/AFM.png
--------------------------------------------------------------------------------
/imgs/AutoInt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/AutoInt.png
--------------------------------------------------------------------------------
/imgs/CCPM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/CCPM.png
--------------------------------------------------------------------------------
/imgs/CIN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/CIN.png
--------------------------------------------------------------------------------
/imgs/DCN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/DCN.png
--------------------------------------------------------------------------------
/imgs/DIEN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/DIEN.png
--------------------------------------------------------------------------------
/imgs/DIN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/DIN.png
--------------------------------------------------------------------------------
/imgs/DSIN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/DSIN.png
--------------------------------------------------------------------------------
/imgs/DeepFM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/DeepFM.png
--------------------------------------------------------------------------------
/imgs/FFM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/FFM.png
--------------------------------------------------------------------------------
/imgs/FGCNN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/FGCNN.png
--------------------------------------------------------------------------------
/imgs/FM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/FM.png
--------------------------------------------------------------------------------
/imgs/FNN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/FNN.png
--------------------------------------------------------------------------------
/imgs/InteractingLayer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/InteractingLayer.png
--------------------------------------------------------------------------------
/imgs/MLR.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/MLR.png
--------------------------------------------------------------------------------
/imgs/NFFM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/NFFM.png
--------------------------------------------------------------------------------
/imgs/NFM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/NFM.png
--------------------------------------------------------------------------------
/imgs/PNN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/PNN.png
--------------------------------------------------------------------------------
/imgs/WDL.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/WDL.png
--------------------------------------------------------------------------------
/imgs/xDeepFM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/imgs/xDeepFM.png
--------------------------------------------------------------------------------
/nbs/movielen.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd\n",
10 | "import torch\n",
11 | "from torchctr.layers import EmbeddingLayer, EmbeddingDropout\n",
12 | "from torchctr.datasets import (FeatureDict, get_movielens, make_datasets, read_data, defaults, fillna, make_dataloader, DataMeta)"
13 | ]
14 | },
15 | {
16 | "cell_type": "markdown",
17 | "metadata": {},
18 | "source": [
19 | "## step 1: download dataset"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 2,
25 | "metadata": {},
26 | "outputs": [
27 | {
28 | "name": "stdout",
29 | "output_type": "stream",
30 | "text": [
31 | "Downloading...\n",
32 | "Using downloaded and verified file: ../datasets/ml-1m/raw/ml-1m.zip\n",
33 | "Extracting...\n",
34 | "Done!\n"
35 | ]
36 | }
37 | ],
38 | "source": [
39 | "root = get_movielens('../datasets', 'ml-1m')"
40 | ]
41 | },
42 | {
43 | "cell_type": "markdown",
44 | "metadata": {},
45 | "source": [
46 | "## step 2: read data"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": 3,
52 | "metadata": {},
53 | "outputs": [],
54 | "source": [
55 | "users = read_data(root / 'users.dat', sep='::', names=['UserID', 'Gender', 'Age', 'Occupation', 'Zip-code'])\n",
56 | "movies = read_data(root / 'movies.dat', sep='::', names=['MovieID', 'Title', 'Genres'])\n",
57 | "ratings = read_data(root / 'ratings.dat', sep='::', names=['UserID', 'MovieID', 'Rating', 'Timestamp'])"
58 | ]
59 | },
60 | {
61 | "cell_type": "markdown",
62 | "metadata": {},
63 | "source": [
64 | "## step 3: make dataset"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 4,
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 | "dataset = pd.merge(ratings, users, on='UserID')\n",
74 | "dataset = pd.merge(dataset, movies, on='MovieID')"
75 | ]
76 | },
77 | {
78 | "cell_type": "markdown",
79 | "metadata": {},
80 | "source": [
81 | "## subsample(optional)"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": 5,
87 | "metadata": {},
88 | "outputs": [],
89 | "source": [
90 | "dataset = dataset.iloc[5000:10000, :]"
91 | ]
92 | },
93 | {
94 | "cell_type": "markdown",
95 | "metadata": {},
96 | "source": [
97 | "## step 4: make features and dataloader"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": 6,
103 | "metadata": {},
104 | "outputs": [
105 | {
106 | "name": "stdout",
107 | "output_type": "stream",
108 | "text": [
109 | "Making dataset Done!\n"
110 | ]
111 | }
112 | ],
113 | "source": [
114 | "sparse_features = ['UserID', 'Gender', 'Age', 'Occupation', 'Zip-code', 'MovieID']\n",
115 | "sequence_features = ['Genres']\n",
116 | "dataset = fillna(dataset, dataset.columns, fill_v='unk')\n",
117 | "features = FeatureDict(sparse_features, None, sequence_features)\n",
118 | "input, _ = make_datasets(dataset, features, sep='|')\n",
119 | "# loader = make_dataloader(input, dataset['Rating'].values, batch_size=64, shuffle=True)"
120 | ]
121 | },
122 | {
123 | "cell_type": "markdown",
124 | "metadata": {},
125 | "source": [
126 | "## step 5: build model"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": 7,
132 | "metadata": {},
133 | "outputs": [
134 | {
135 | "name": "stdout",
136 | "output_type": "stream",
137 | "text": [
138 | "EmbeddingLayer(\n",
139 | " (sparse_embeds): ModuleList(\n",
140 | " (0): EmbeddingDropout(\n",
141 | " (emb): Embedding(3205, 147)\n",
142 | " )\n",
143 | " (1): EmbeddingDropout(\n",
144 | " (emb): Embedding(2, 2)\n",
145 | " )\n",
146 | " (2): EmbeddingDropout(\n",
147 | " (emb): Embedding(7, 5)\n",
148 | " )\n",
149 | " (3): EmbeddingDropout(\n",
150 | " (emb): Embedding(21, 9)\n",
151 | " )\n",
152 | " (4): EmbeddingDropout(\n",
153 | " (emb): Embedding(2153, 118)\n",
154 | " )\n",
155 | " (5): EmbeddingDropout(\n",
156 | " (emb): Embedding(4, 3)\n",
157 | " )\n",
158 | " )\n",
159 | " (sequence_embeds): ModuleList(\n",
160 | " (0): Embedding(7, 5)\n",
161 | " )\n",
162 | ")\n"
163 | ]
164 | }
165 | ],
166 | "source": [
167 | "model = EmbeddingLayer(input, emb_drop=0.1).to(defaults.device)\n",
168 | "print(model)\n",
169 | "out = model(input)"
170 | ]
171 | },
172 | {
173 | "cell_type": "code",
174 | "execution_count": 8,
175 | "metadata": {},
176 | "outputs": [
177 | {
178 | "data": {
179 | "text/plain": [
180 | "tensor([[-0.8287, 0.3714, -0.7944, 0.5302, -0.1847],\n",
181 | " [-0.8287, 0.3714, -0.7944, 0.5302, -0.1847],\n",
182 | " [-0.8287, 0.3714, -0.7944, 0.5302, -0.1847],\n",
183 | " ...,\n",
184 | " [-0.4850, -0.0608, 1.1737, 0.4636, -0.4604],\n",
185 | " [-0.4850, -0.0608, 1.1737, 0.4636, -0.4604],\n",
186 | " [-0.4850, -0.0608, 1.1737, 0.4636, -0.4604]], device='cuda:0',\n",
187 | " grad_fn=)"
188 | ]
189 | },
190 | "execution_count": 8,
191 | "metadata": {},
192 | "output_type": "execute_result"
193 | }
194 | ],
195 | "source": [
196 | "out[:, -5:]"
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": 9,
202 | "metadata": {},
203 | "outputs": [],
204 | "source": [
205 | "ly = EmbeddingDropout(torch.nn.Embedding(7, 5), 0.1)"
206 | ]
207 | },
208 | {
209 | "cell_type": "code",
210 | "execution_count": 10,
211 | "metadata": {},
212 | "outputs": [
213 | {
214 | "data": {
215 | "text/plain": [
216 | "EmbeddingDropout(\n",
217 | " (emb): Embedding(7, 5)\n",
218 | ")"
219 | ]
220 | },
221 | "execution_count": 10,
222 | "metadata": {},
223 | "output_type": "execute_result"
224 | }
225 | ],
226 | "source": [
227 | "ly"
228 | ]
229 | },
230 | {
231 | "cell_type": "code",
232 | "execution_count": 11,
233 | "metadata": {},
234 | "outputs": [
235 | {
236 | "data": {
237 | "text/plain": [
238 | "Parameter containing:\n",
239 | "tensor([[ 0.0122, 0.7963, -0.9860, 2.5023, 0.9121],\n",
240 | " [-0.2414, -1.1864, -0.0428, 1.4428, 0.6048],\n",
241 | " [-3.1064, -0.8661, -0.4674, -0.6350, -0.0244],\n",
242 | " [-1.4281, -0.2473, 1.4546, 0.1025, -0.1300],\n",
243 | " [-2.0995, 0.1254, 0.0183, -0.6482, 0.9680],\n",
244 | " [ 0.2651, -2.6695, -0.7403, -1.3880, 0.3184],\n",
245 | " [-0.6377, 0.6056, 0.6045, -0.6367, -0.1732]], requires_grad=True)"
246 | ]
247 | },
248 | "execution_count": 11,
249 | "metadata": {},
250 | "output_type": "execute_result"
251 | }
252 | ],
253 | "source": [
254 | "ly.emb.weight"
255 | ]
256 | },
257 | {
258 | "cell_type": "code",
259 | "execution_count": 12,
260 | "metadata": {},
261 | "outputs": [
262 | {
263 | "data": {
264 | "text/plain": [
265 | "array([[0, 0, 1, ..., 1, 0, 0],\n",
266 | " [0, 0, 1, ..., 1, 0, 0],\n",
267 | " [0, 0, 1, ..., 1, 0, 0],\n",
268 | " ...,\n",
269 | " [0, 0, 0, ..., 1, 1, 0],\n",
270 | " [0, 0, 0, ..., 1, 1, 0],\n",
271 | " [0, 0, 0, ..., 1, 1, 0]], dtype=int64)"
272 | ]
273 | },
274 | "execution_count": 12,
275 | "metadata": {},
276 | "output_type": "execute_result"
277 | }
278 | ],
279 | "source": [
280 | "input.sequence_data.data"
281 | ]
282 | },
283 | {
284 | "cell_type": "code",
285 | "execution_count": 13,
286 | "metadata": {},
287 | "outputs": [],
288 | "source": [
289 | "y = torch.as_tensor(input.sequence_data.data).float()"
290 | ]
291 | },
292 | {
293 | "cell_type": "code",
294 | "execution_count": 14,
295 | "metadata": {},
296 | "outputs": [
297 | {
298 | "data": {
299 | "text/plain": [
300 | "tensor([[-2.2113, -0.3294, 0.3352, -0.3936, 0.2712],\n",
301 | " [-2.2113, -0.3294, 0.3352, -0.3936, 0.2712],\n",
302 | " [-2.2113, -0.3294, 0.3352, -0.3936, 0.2712],\n",
303 | " ...,\n",
304 | " [-0.9172, -1.2720, -0.3610, -1.0181, 0.6432],\n",
305 | " [-0.9172, -1.2720, -0.3610, -1.0181, 0.6432],\n",
306 | " [-0.9172, -1.2720, -0.3610, -1.0181, 0.6432]], grad_fn=)"
307 | ]
308 | },
309 | "execution_count": 14,
310 | "metadata": {},
311 | "output_type": "execute_result"
312 | }
313 | ],
314 | "source": [
315 | "y @ ly.emb.weight/y.sum(1).view(-1,1)"
316 | ]
317 | },
318 | {
319 | "cell_type": "code",
320 | "execution_count": 15,
321 | "metadata": {},
322 | "outputs": [
323 | {
324 | "data": {
325 | "text/plain": [
326 | "[7]"
327 | ]
328 | },
329 | "execution_count": 15,
330 | "metadata": {},
331 | "output_type": "execute_result"
332 | }
333 | ],
334 | "source": [
335 | "input.sequence_data.nunique"
336 | ]
337 | },
338 | {
339 | "cell_type": "code",
340 | "execution_count": 16,
341 | "metadata": {},
342 | "outputs": [],
343 | "source": [
344 | "from sklearn.feature_extraction.text import CountVectorizer\n",
345 | "import numpy as np\n",
346 | "def sequence_feature_encoding(data, features_names, sep: str = ','):\n",
347 | " \"\"\"Encoding for sequence features.\"\"\"\n",
348 | "\n",
349 | " if not features_names:\n",
350 | " return None\n",
351 | " data_value, nuniques = [], []\n",
352 | " for feature in features_names:\n",
353 | " vocab = set.union(*[set(str(x).strip().split(sep=sep)) for x in data[feature]])\n",
354 | " vec = CountVectorizer(vocabulary=vocab)\n",
355 | " multi_hot = vec.transform(data[feature])\n",
356 | " data_value.append(multi_hot.toarray())\n",
357 | " nuniques.append(len(vocab))\n",
358 | " data_meta = DataMeta(np.hstack(data_value), None, features_names, nuniques)\n",
359 | " return data_meta"
360 | ]
361 | },
362 | {
363 | "cell_type": "code",
364 | "execution_count": 17,
365 | "metadata": {},
366 | "outputs": [
367 | {
368 | "data": {
369 | "text/html": [
370 | "\n",
371 | "\n",
384 | "
\n",
385 | " \n",
386 | " \n",
387 | " | \n",
388 | " UserID | \n",
389 | " MovieID | \n",
390 | " Rating | \n",
391 | " Timestamp | \n",
392 | " Gender | \n",
393 | " Age | \n",
394 | " Occupation | \n",
395 | " Zip-code | \n",
396 | " Title | \n",
397 | " Genres | \n",
398 | "
\n",
399 | " \n",
400 | " \n",
401 | " \n",
402 | " 8756 | \n",
403 | " 2396 | \n",
404 | " 1 | \n",
405 | " 5 | \n",
406 | " 999463533 | \n",
407 | " 1 | \n",
408 | " 2 | \n",
409 | " 17 | \n",
410 | " 54 | \n",
411 | " Ben-Hur (1959) | \n",
412 | " Action|Adventure|Drama | \n",
413 | "
\n",
414 | " \n",
415 | " 8845 | \n",
416 | " 2867 | \n",
417 | " 1 | \n",
418 | " 4 | \n",
419 | " 959889599 | \n",
420 | " 1 | \n",
421 | " 3 | \n",
422 | " 1 | \n",
423 | " 581 | \n",
424 | " Ben-Hur (1959) | \n",
425 | " Action|Adventure|Drama | \n",
426 | "
\n",
427 | " \n",
428 | " 7664 | \n",
429 | " 2417 | \n",
430 | " 0 | \n",
431 | " 5 | \n",
432 | " 964113264 | \n",
433 | " 1 | \n",
434 | " 2 | \n",
435 | " 14 | \n",
436 | " 2061 | \n",
437 | " Princess Bride, The (1987) | \n",
438 | " Action|Adventure|Comedy|Romance | \n",
439 | "
\n",
440 | " \n",
441 | " 9482 | \n",
442 | " 1127 | \n",
443 | " 3 | \n",
444 | " 5 | \n",
445 | " 974177311 | \n",
446 | " 1 | \n",
447 | " 2 | \n",
448 | " 7 | \n",
449 | " 1637 | \n",
450 | " Christmas Story, A (1983) | \n",
451 | " Comedy|Drama | \n",
452 | "
\n",
453 | " \n",
454 | " 7709 | \n",
455 | " 2491 | \n",
456 | " 0 | \n",
457 | " 3 | \n",
458 | " 963196099 | \n",
459 | " 0 | \n",
460 | " 5 | \n",
461 | " 0 | \n",
462 | " 133 | \n",
463 | " Princess Bride, The (1987) | \n",
464 | " Action|Adventure|Comedy|Romance | \n",
465 | "
\n",
466 | " \n",
467 | " 8167 | \n",
468 | " 3133 | \n",
469 | " 0 | \n",
470 | " 3 | \n",
471 | " 957361499 | \n",
472 | " 1 | \n",
473 | " 2 | \n",
474 | " 1 | \n",
475 | " 1483 | \n",
476 | " Princess Bride, The (1987) | \n",
477 | " Action|Adventure|Comedy|Romance | \n",
478 | "
\n",
479 | " \n",
480 | " 5151 | \n",
481 | " 1674 | \n",
482 | " 2 | \n",
483 | " 4 | \n",
484 | " 967416107 | \n",
485 | " 1 | \n",
486 | " 1 | \n",
487 | " 12 | \n",
488 | " 1123 | \n",
489 | " Bug's Life, A (1998) | \n",
490 | " Animation|Children's|Comedy | \n",
491 | "
\n",
492 | " \n",
493 | " 6165 | \n",
494 | " 319 | \n",
495 | " 0 | \n",
496 | " 4 | \n",
497 | " 975607581 | \n",
498 | " 1 | \n",
499 | " 3 | \n",
500 | " 16 | \n",
501 | " 614 | \n",
502 | " Princess Bride, The (1987) | \n",
503 | " Action|Adventure|Comedy|Romance | \n",
504 | "
\n",
505 | " \n",
506 | " 8179 | \n",
507 | " 3148 | \n",
508 | " 0 | \n",
509 | " 5 | \n",
510 | " 957213786 | \n",
511 | " 0 | \n",
512 | " 4 | \n",
513 | " 15 | \n",
514 | " 1476 | \n",
515 | " Princess Bride, The (1987) | \n",
516 | " Action|Adventure|Comedy|Romance | \n",
517 | "
\n",
518 | " \n",
519 | " 9677 | \n",
520 | " 1652 | \n",
521 | " 3 | \n",
522 | " 5 | \n",
523 | " 967468418 | \n",
524 | " 1 | \n",
525 | " 2 | \n",
526 | " 14 | \n",
527 | " 217 | \n",
528 | " Christmas Story, A (1983) | \n",
529 | " Comedy|Drama | \n",
530 | "
\n",
531 | " \n",
532 | "
\n",
533 | "
"
534 | ],
535 | "text/plain": [
536 | " UserID MovieID Rating Timestamp Gender Age Occupation Zip-code \\\n",
537 | "8756 2396 1 5 999463533 1 2 17 54 \n",
538 | "8845 2867 1 4 959889599 1 3 1 581 \n",
539 | "7664 2417 0 5 964113264 1 2 14 2061 \n",
540 | "9482 1127 3 5 974177311 1 2 7 1637 \n",
541 | "7709 2491 0 3 963196099 0 5 0 133 \n",
542 | "8167 3133 0 3 957361499 1 2 1 1483 \n",
543 | "5151 1674 2 4 967416107 1 1 12 1123 \n",
544 | "6165 319 0 4 975607581 1 3 16 614 \n",
545 | "8179 3148 0 5 957213786 0 4 15 1476 \n",
546 | "9677 1652 3 5 967468418 1 2 14 217 \n",
547 | "\n",
548 | " Title Genres \n",
549 | "8756 Ben-Hur (1959) Action|Adventure|Drama \n",
550 | "8845 Ben-Hur (1959) Action|Adventure|Drama \n",
551 | "7664 Princess Bride, The (1987) Action|Adventure|Comedy|Romance \n",
552 | "9482 Christmas Story, A (1983) Comedy|Drama \n",
553 | "7709 Princess Bride, The (1987) Action|Adventure|Comedy|Romance \n",
554 | "8167 Princess Bride, The (1987) Action|Adventure|Comedy|Romance \n",
555 | "5151 Bug's Life, A (1998) Animation|Children's|Comedy \n",
556 | "6165 Princess Bride, The (1987) Action|Adventure|Comedy|Romance \n",
557 | "8179 Princess Bride, The (1987) Action|Adventure|Comedy|Romance \n",
558 | "9677 Christmas Story, A (1983) Comedy|Drama "
559 | ]
560 | },
561 | "execution_count": 17,
562 | "metadata": {},
563 | "output_type": "execute_result"
564 | }
565 | ],
566 | "source": [
567 | "dataset.sample(10)"
568 | ]
569 | },
570 | {
571 | "cell_type": "code",
572 | "execution_count": 18,
573 | "metadata": {},
574 | "outputs": [],
575 | "source": [
576 | "x = sequence_feature_encoding(dataset, ['Genres'], '|')"
577 | ]
578 | },
579 | {
580 | "cell_type": "code",
581 | "execution_count": 19,
582 | "metadata": {},
583 | "outputs": [
584 | {
585 | "data": {
586 | "text/plain": [
587 | "{'Action',\n",
588 | " 'Adventure',\n",
589 | " 'Animation',\n",
590 | " \"Children's\",\n",
591 | " 'Comedy',\n",
592 | " 'Drama',\n",
593 | " 'Romance'}"
594 | ]
595 | },
596 | "execution_count": 19,
597 | "metadata": {},
598 | "output_type": "execute_result"
599 | }
600 | ],
601 | "source": [
602 | "vocab = set.union(*[set(str(x).strip().split(sep='|')) for x in dataset['Genres']])\n",
603 | "vocab"
604 | ]
605 | },
606 | {
607 | "cell_type": "code",
608 | "execution_count": 20,
609 | "metadata": {},
610 | "outputs": [],
611 | "source": [
612 | "vec = CountVectorizer(vocabulary=vocab)"
613 | ]
614 | },
615 | {
616 | "cell_type": "code",
617 | "execution_count": 21,
618 | "metadata": {},
619 | "outputs": [],
620 | "source": [
621 | "# [','.join(str(x).strip().split(sep='|')) for x in dataset['Genres']]"
622 | ]
623 | },
624 | {
625 | "cell_type": "code",
626 | "execution_count": 22,
627 | "metadata": {},
628 | "outputs": [
629 | {
630 | "data": {
631 | "text/plain": [
632 | "CountVectorizer(analyzer='word', binary=False, decode_error='strict',\n",
633 | " dtype=, encoding='utf-8', input='content',\n",
634 | " lowercase=True, max_df=1.0, max_features=None, min_df=1,\n",
635 | " ngram_range=(1, 1), preprocessor=None, stop_words=None,\n",
636 | " strip_accents=None, token_pattern='(?u)\\\\b\\\\w\\\\w+\\\\b',\n",
637 | " tokenizer=None,\n",
638 | " vocabulary={\"Children's\", 'Romance', 'Adventure', 'Drama', 'Animation', 'Action', 'Comedy'})"
639 | ]
640 | },
641 | "execution_count": 22,
642 | "metadata": {},
643 | "output_type": "execute_result"
644 | }
645 | ],
646 | "source": [
647 | "vec.fit([' '.join(str(x).strip().split(sep='|')) for x in dataset['Genres']])"
648 | ]
649 | },
650 | {
651 | "cell_type": "code",
652 | "execution_count": 23,
653 | "metadata": {},
654 | "outputs": [],
655 | "source": [
656 | "multi_hot = vec.transform(['Action Comedy', 'Action'])"
657 | ]
658 | },
659 | {
660 | "cell_type": "code",
661 | "execution_count": 24,
662 | "metadata": {},
663 | "outputs": [
664 | {
665 | "data": {
666 | "text/plain": [
667 | "[array([0, 0, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 0, 0, 0])]"
668 | ]
669 | },
670 | "execution_count": 24,
671 | "metadata": {},
672 | "output_type": "execute_result"
673 | }
674 | ],
675 | "source": [
676 | "list(multi_hot.toarray())"
677 | ]
678 | },
679 | {
680 | "cell_type": "code",
681 | "execution_count": 25,
682 | "metadata": {},
683 | "outputs": [
684 | {
685 | "data": {
686 | "text/plain": [
687 | "array([[0, 1, 1, 1, 0, 0, 0, 0],\n",
688 | " [0, 0, 0, 0, 1, 1, 1, 0],\n",
689 | " [1, 0, 0, 0, 0, 0, 0, 1]], dtype=int64)"
690 | ]
691 | },
692 | "execution_count": 25,
693 | "metadata": {},
694 | "output_type": "execute_result"
695 | }
696 | ],
697 | "source": [
698 | "CountVectorizer(token_pattern=r'(?u)\\b\\w+\\b', analyzer='word').fit_transform(['1 2 31', 'a, b, c3', '中 0']).toarray()"
699 | ]
700 | },
701 | {
702 | "cell_type": "code",
703 | "execution_count": 26,
704 | "metadata": {},
705 | "outputs": [],
706 | "source": [
707 | "corpus = [' '.join(str(x).strip().split(sep='|')) for x in dataset['Genres']]"
708 | ]
709 | },
710 | {
711 | "cell_type": "code",
712 | "execution_count": 27,
713 | "metadata": {},
714 | "outputs": [],
715 | "source": [
716 | "vocab = set.union(*[set(x.split(' ')) for x in corpus])"
717 | ]
718 | },
719 | {
720 | "cell_type": "code",
721 | "execution_count": 28,
722 | "metadata": {},
723 | "outputs": [],
724 | "source": [
725 | "vec = CountVectorizer(token_pattern=r'(?u)\\b[\\w\\']+\\b')\n",
726 | "# vec = CountVectorizer(vocabulary=vocab)"
727 | ]
728 | },
729 | {
730 | "cell_type": "code",
731 | "execution_count": 29,
732 | "metadata": {},
733 | "outputs": [
734 | {
735 | "data": {
736 | "text/plain": [
737 | "array([[0, 0, 1, ..., 1, 0, 0],\n",
738 | " [0, 0, 1, ..., 1, 0, 0],\n",
739 | " [0, 0, 1, ..., 1, 0, 0],\n",
740 | " ...,\n",
741 | " [0, 0, 0, ..., 1, 1, 0],\n",
742 | " [0, 0, 0, ..., 1, 1, 0],\n",
743 | " [0, 0, 0, ..., 1, 1, 0]], dtype=int64)"
744 | ]
745 | },
746 | "execution_count": 29,
747 | "metadata": {},
748 | "output_type": "execute_result"
749 | }
750 | ],
751 | "source": [
752 | "vec.fit_transform(corpus).toarray()"
753 | ]
754 | },
755 | {
756 | "cell_type": "code",
757 | "execution_count": 30,
758 | "metadata": {},
759 | "outputs": [
760 | {
761 | "data": {
762 | "text/plain": [
763 | "{'animation': 2,\n",
764 | " \"children's\": 3,\n",
765 | " 'comedy': 4,\n",
766 | " 'action': 0,\n",
767 | " 'adventure': 1,\n",
768 | " 'romance': 6,\n",
769 | " 'drama': 5}"
770 | ]
771 | },
772 | "execution_count": 30,
773 | "metadata": {},
774 | "output_type": "execute_result"
775 | }
776 | ],
777 | "source": [
778 | "vec.vocabulary_"
779 | ]
780 | },
781 | {
782 | "cell_type": "code",
783 | "execution_count": null,
784 | "metadata": {},
785 | "outputs": [],
786 | "source": []
787 | }
788 | ],
789 | "metadata": {
790 | "kernelspec": {
791 | "display_name": "Python 3",
792 | "language": "python",
793 | "name": "python3"
794 | },
795 | "language_info": {
796 | "codemirror_mode": {
797 | "name": "ipython",
798 | "version": 3
799 | },
800 | "file_extension": ".py",
801 | "mimetype": "text/x-python",
802 | "name": "python",
803 | "nbconvert_exporter": "python",
804 | "pygments_lexer": "ipython3",
805 | "version": "3.6.6"
806 | }
807 | },
808 | "nbformat": 4,
809 | "nbformat_minor": 2
810 | }
811 |
--------------------------------------------------------------------------------
/nbs/test.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd\n",
10 | "from torchctr.layers import EmbeddingLayer\n",
11 | "from torchctr.datasets import (FeatureDict, get_movielens, make_datasets, read_data, defaults, fillna, make_dataloader)\n",
12 | "from torchctr.datasets.data import RecommendDataset"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 2,
18 | "metadata": {},
19 | "outputs": [
20 | {
21 | "name": "stdout",
22 | "output_type": "stream",
23 | "text": [
24 | "Downloading...\n",
25 | "Using downloaded and verified file: ../datasets\\ml-1m\\raw\\ml-1m.zip\n",
26 | "Extracting...\n",
27 | "Done!\n"
28 | ]
29 | }
30 | ],
31 | "source": [
32 | "# step 1: download dataset\n",
33 | "root = get_movielens('../datasets', 'ml-1m')"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 3,
39 | "metadata": {},
40 | "outputs": [],
41 | "source": [
42 | "# step 2: read data\n",
43 | "users = read_data(root / 'users.dat', sep='::', names=['UserID', 'Gender', 'Age', 'Occupation', 'Zip-code'])\n",
44 | "movies = read_data(root / 'movies.dat', sep='::', names=['MovieID', 'Title', 'Genres'])\n",
45 | "ratings = read_data(root / 'ratings.dat', sep='::', names=['UserID', 'MovieID', 'Rating', 'Timestamp'])"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": 4,
51 | "metadata": {},
52 | "outputs": [],
53 | "source": [
54 | "# step 3: make dataset\n",
55 | "dataset = pd.merge(ratings, users, on='UserID')\n",
56 | "dataset = pd.merge(dataset, movies, on='MovieID')"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": 5,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "# subsample\n",
66 | "dataset = dataset.iloc[5000:10000, :]"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 6,
72 | "metadata": {},
73 | "outputs": [
74 | {
75 | "name": "stdout",
76 | "output_type": "stream",
77 | "text": [
78 | "Making dataset Done!\n"
79 | ]
80 | }
81 | ],
82 | "source": [
83 | "# step 4: make features and dataloader\n",
84 | "sparse_features = ['UserID', 'Gender', 'Age', 'Occupation', 'Zip-code', 'MovieID']\n",
85 | "sequence_features = ['Genres']\n",
86 | "dataset = fillna(dataset, dataset.columns, fill_v='unk')\n",
87 | "features = FeatureDict(sparse_features, None, sequence_features)\n",
88 | "input, _ = make_datasets(dataset, features, sep='|')\n",
89 | "loader = make_dataloader(input, dataset['Rating'].values, batch_size=64, shuffle=True)"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": 7,
95 | "metadata": {},
96 | "outputs": [
97 | {
98 | "name": "stdout",
99 | "output_type": "stream",
100 | "text": [
101 | "EmbeddingLayer(\n",
102 | " (sparse_embeds): ModuleList(\n",
103 | " (0): Embedding(3205, 147)\n",
104 | " (1): Embedding(2, 2)\n",
105 | " (2): Embedding(7, 5)\n",
106 | " (3): Embedding(21, 9)\n",
107 | " (4): Embedding(2153, 118)\n",
108 | " (5): Embedding(4, 3)\n",
109 | " )\n",
110 | " (sequence_embeds): ModuleList(\n",
111 | " (0): EmbeddingBag(7, 5, mode=mean)\n",
112 | " )\n",
113 | " (drop): Dropout(p=0.0)\n",
114 | ")\n",
115 | "torch.Size([5000, 289])\n",
116 | "tensor([[ 1.0832, -0.3852, 0.9774, ..., 0.4901, 0.2720, 0.2515],\n",
117 | " [-2.9299, 1.2940, -0.9595, ..., 0.4901, 0.2720, 0.2515],\n",
118 | " [ 2.9813, 0.2656, 0.1590, ..., 0.4901, 0.2720, 0.2515],\n",
119 | " ...,\n",
120 | " [ 0.6574, 0.1386, 0.7176, ..., 1.2335, 0.4204, 0.3841],\n",
121 | " [ 0.0121, -0.4749, -0.2445, ..., 1.2335, 0.4204, 0.3841],\n",
122 | " [-0.6250, 1.1999, 0.7947, ..., 1.2335, 0.4204, 0.3841]],\n",
123 | " grad_fn=)\n"
124 | ]
125 | }
126 | ],
127 | "source": [
128 | "# step 5: build model\n",
129 | "model = EmbeddingLayer(input).to(defaults.device)\n",
130 | "print(model)\n",
131 | "out = model(input)\n",
132 | "print(out.shape, out, sep='\\n')\n",
133 | "# print(input)"
134 | ]
135 | },
136 | {
137 | "cell_type": "code",
138 | "execution_count": 8,
139 | "metadata": {},
140 | "outputs": [
141 | {
142 | "data": {
143 | "text/plain": [
144 | "16244"
145 | ]
146 | },
147 | "execution_count": 8,
148 | "metadata": {},
149 | "output_type": "execute_result"
150 | }
151 | ],
152 | "source": [
153 | "len(input.sequence_data.data[0])"
154 | ]
155 | },
156 | {
157 | "cell_type": "code",
158 | "execution_count": 9,
159 | "metadata": {},
160 | "outputs": [
161 | {
162 | "data": {
163 | "text/plain": [
164 | "5000"
165 | ]
166 | },
167 | "execution_count": 9,
168 | "metadata": {},
169 | "output_type": "execute_result"
170 | }
171 | ],
172 | "source": [
173 | "len(input.sequence_data.bag_offsets[0])"
174 | ]
175 | },
176 | {
177 | "cell_type": "code",
178 | "execution_count": 10,
179 | "metadata": {},
180 | "outputs": [],
181 | "source": [
182 | "targets = dataset['Rating'].values"
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": 11,
188 | "metadata": {},
189 | "outputs": [],
190 | "source": [
191 | "data = RecommendDataset(input, targets)"
192 | ]
193 | },
194 | {
195 | "cell_type": "code",
196 | "execution_count": 12,
197 | "metadata": {},
198 | "outputs": [
199 | {
200 | "data": {
201 | "text/plain": [
202 | "5000"
203 | ]
204 | },
205 | "execution_count": 12,
206 | "metadata": {},
207 | "output_type": "execute_result"
208 | }
209 | ],
210 | "source": [
211 | "len(data)"
212 | ]
213 | },
214 | {
215 | "cell_type": "code",
216 | "execution_count": 13,
217 | "metadata": {},
218 | "outputs": [
219 | {
220 | "data": {
221 | "text/plain": [
222 | "16244"
223 | ]
224 | },
225 | "execution_count": 13,
226 | "metadata": {},
227 | "output_type": "execute_result"
228 | }
229 | ],
230 | "source": [
231 | "len(input.sequence_data.data[0])"
232 | ]
233 | },
234 | {
235 | "cell_type": "code",
236 | "execution_count": 14,
237 | "metadata": {},
238 | "outputs": [],
239 | "source": [
240 | "import numpy as np"
241 | ]
242 | },
243 | {
244 | "cell_type": "code",
245 | "execution_count": 15,
246 | "metadata": {},
247 | "outputs": [
248 | {
249 | "name": "stdout",
250 | "output_type": "stream",
251 | "text": [
252 | "Wall time: 3.99 ms\n"
253 | ]
254 | }
255 | ],
256 | "source": [
257 | "%%time\n",
258 | "data1, offsets = [], np.zeros((data.lens, len(input.sequence_data.bag_offsets)), dtype=int)\n",
259 | "for x, y in zip(input.sequence_data.data, input.sequence_data.bag_offsets):\n",
260 | " tmp = []\n",
261 | " for idx, item in enumerate(y):\n",
262 | " tmp1 = []\n",
263 | " if idx == data.lens - 1:\n",
264 | " tmp1.extend(x[item:])\n",
265 | " else:\n",
266 | " tmp1.extend(x[item:y[idx + 1]])\n",
267 | " data1.append(tmp)"
268 | ]
269 | },
270 | {
271 | "cell_type": "code",
272 | "execution_count": 16,
273 | "metadata": {},
274 | "outputs": [],
275 | "source": [
276 | "# input.sequence_data.data/"
277 | ]
278 | },
279 | {
280 | "cell_type": "code",
281 | "execution_count": 17,
282 | "metadata": {},
283 | "outputs": [
284 | {
285 | "data": {
286 | "text/plain": [
287 | "[]"
288 | ]
289 | },
290 | "execution_count": 17,
291 | "metadata": {},
292 | "output_type": "execute_result"
293 | }
294 | ],
295 | "source": [
296 | "data1[0][3:8]"
297 | ]
298 | },
299 | {
300 | "cell_type": "code",
301 | "execution_count": 18,
302 | "metadata": {},
303 | "outputs": [
304 | {
305 | "name": "stdout",
306 | "output_type": "stream",
307 | "text": [
308 | "Wall time: 6.98 ms\n"
309 | ]
310 | }
311 | ],
312 | "source": [
313 | "%%time\n",
314 | "data1, offsets = [], []\n",
315 | "for i in range(data.lens):\n",
316 | " tmp = []\n",
317 | " for x, y in zip(input.sequence_data.data, input.sequence_data.bag_offsets): \n",
318 | " if i == data.lens - 1:\n",
319 | " t = x[y[-1]:]\n",
320 | " t = [t] if isinstance(t, int) else t\n",
321 | " tmp.append(t)\n",
322 | " else:\n",
323 | " t = x[y[i]:y[i + 1]]\n",
324 | " t = [t] if isinstance(t, int) else t\n",
325 | " tmp.append(t)\n",
326 | " data1.append(tmp)"
327 | ]
328 | },
329 | {
330 | "cell_type": "code",
331 | "execution_count": 19,
332 | "metadata": {},
333 | "outputs": [],
334 | "source": [
335 | "offsets=np.zeros((data.lens, len(input.sequence_data.bag_offsets)), dtype=int)"
336 | ]
337 | },
338 | {
339 | "cell_type": "code",
340 | "execution_count": 20,
341 | "metadata": {},
342 | "outputs": [
343 | {
344 | "data": {
345 | "text/plain": [
346 | "array([[0],\n",
347 | " [0],\n",
348 | " [0],\n",
349 | " [0]])"
350 | ]
351 | },
352 | "execution_count": 20,
353 | "metadata": {},
354 | "output_type": "execute_result"
355 | }
356 | ],
357 | "source": [
358 | "offsets[3:7]"
359 | ]
360 | },
361 | {
362 | "cell_type": "code",
363 | "execution_count": 21,
364 | "metadata": {},
365 | "outputs": [
366 | {
367 | "data": {
368 | "text/plain": [
369 | "[[[2, 3, 4]], [[2, 3, 4]], [[2, 3, 4]], [[2, 3, 4]]]"
370 | ]
371 | },
372 | "execution_count": 21,
373 | "metadata": {},
374 | "output_type": "execute_result"
375 | }
376 | ],
377 | "source": [
378 | "data1[3:7]"
379 | ]
380 | },
381 | {
382 | "cell_type": "code",
383 | "execution_count": 22,
384 | "metadata": {},
385 | "outputs": [
386 | {
387 | "data": {
388 | "text/plain": [
389 | "[2, 3, 4]"
390 | ]
391 | },
392 | "execution_count": 22,
393 | "metadata": {},
394 | "output_type": "execute_result"
395 | }
396 | ],
397 | "source": [
398 | "data1[3:7][1][0]"
399 | ]
400 | },
401 | {
402 | "cell_type": "code",
403 | "execution_count": 23,
404 | "metadata": {},
405 | "outputs": [
406 | {
407 | "ename": "TypeError",
408 | "evalue": "'int' object is not iterable",
409 | "output_type": "error",
410 | "traceback": [
411 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
412 | "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)",
413 | "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msequence_data\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbag_offsets\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 2\u001b[0m \u001b[0my\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m4\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mextend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata1\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m3\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;36m7\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mt\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
414 | "\u001b[1;31mTypeError\u001b[0m: 'int' object is not iterable"
415 | ]
416 | }
417 | ],
418 | "source": [
419 | "for i in len(input.sequence_data.bag_offsets):\n",
420 | " y = []\n",
421 | " for t in range(4):\n",
422 | " y.extend(data1[3:7][t][i])"
423 | ]
424 | },
425 | {
426 | "cell_type": "code",
427 | "execution_count": null,
428 | "metadata": {},
429 | "outputs": [],
430 | "source": [
431 | "input.sequence_data.bag_offsets[0][235]"
432 | ]
433 | },
434 | {
435 | "cell_type": "code",
436 | "execution_count": null,
437 | "metadata": {},
438 | "outputs": [],
439 | "source": [
440 | "# input.sequence_data"
441 | ]
442 | },
443 | {
444 | "cell_type": "code",
445 | "execution_count": null,
446 | "metadata": {},
447 | "outputs": [],
448 | "source": [
449 | "# data.sequence_data"
450 | ]
451 | },
452 | {
453 | "cell_type": "code",
454 | "execution_count": null,
455 | "metadata": {},
456 | "outputs": [],
457 | "source": []
458 | },
459 | {
460 | "cell_type": "code",
461 | "execution_count": null,
462 | "metadata": {},
463 | "outputs": [],
464 | "source": [
465 | "# for data, target in loader:\n",
466 | "# print(data, target)"
467 | ]
468 | }
469 | ],
470 | "metadata": {
471 | "kernelspec": {
472 | "display_name": "Python 3",
473 | "language": "python",
474 | "name": "python3"
475 | },
476 | "language_info": {
477 | "codemirror_mode": {
478 | "name": "ipython",
479 | "version": 3
480 | },
481 | "file_extension": ".py",
482 | "mimetype": "text/x-python",
483 | "name": "python",
484 | "nbconvert_exporter": "python",
485 | "pygments_lexer": "ipython3",
486 | "version": "3.6.9"
487 | }
488 | },
489 | "nbformat": 4,
490 | "nbformat_minor": 4
491 | }
492 |
--------------------------------------------------------------------------------
/nbs/协同过滤.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%matplotlib inline\n",
10 | "%reload_ext autoreload\n",
11 | "%autoreload 2\n",
12 | "# 多行输出\n",
13 | "from IPython.core.interactiveshell import InteractiveShell\n",
14 | "InteractiveShell.ast_node_interactivity = \"all\""
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "# 协同过滤"
22 | ]
23 | },
24 | {
25 | "cell_type": "markdown",
26 | "metadata": {},
27 | "source": [
28 | "协同过滤就是指用户可以齐心协力,通过不断地和网站互动,使自己的推荐列表能够不断过滤掉自己不感兴趣的物品,从而越来越满足自己的需求"
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {},
34 | "source": [
35 | "1. **基于用户的协同过滤算法** 这种算法给用户推荐和他兴趣相似的其他用户喜欢的物品。\n",
36 | "2. **基于物品的协同过滤算法** 这种算法给用户推荐和他之前喜欢的物品相似的物品\n",
37 | "\n",
38 | "**TopN** 推荐的任务是预测用户会不会对某部电影评分,而不是预测用户在准备对某部电影评分的前提下会给电影评多少分"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": null,
44 | "metadata": {},
45 | "outputs": [],
46 | "source": []
47 | }
48 | ],
49 | "metadata": {
50 | "kernelspec": {
51 | "display_name": "Python 3",
52 | "language": "python",
53 | "name": "python3"
54 | },
55 | "language_info": {
56 | "codemirror_mode": {
57 | "name": "ipython",
58 | "version": 3
59 | },
60 | "file_extension": ".py",
61 | "mimetype": "text/x-python",
62 | "name": "python",
63 | "nbconvert_exporter": "python",
64 | "pygments_lexer": "ipython3",
65 | "version": "3.6.6"
66 | }
67 | },
68 | "nbformat": 4,
69 | "nbformat_minor": 2
70 | }
71 |
--------------------------------------------------------------------------------
/nbs/评测指标.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%matplotlib inline\n",
10 | "%reload_ext autoreload\n",
11 | "%autoreload 2\n",
12 | "# 多行输出\n",
13 | "from IPython.core.interactiveshell import InteractiveShell\n",
14 | "InteractiveShell.ast_node_interactivity = \"all\""
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "# 评测指标"
22 | ]
23 | },
24 | {
25 | "cell_type": "markdown",
26 | "metadata": {},
27 | "source": [
28 | "1. 用户满意度(在线)\n",
29 | " \n",
30 | " - 问卷调查\n",
31 | " - 用购买率度量用户的满意度\n",
32 | " - 用户反馈界面收集用户满意度\n",
33 | " - 点击率\n",
34 | " - 用户停留时间\n",
35 | " - 转化率"
36 | ]
37 | },
38 | {
39 | "cell_type": "markdown",
40 | "metadata": {},
41 | "source": [
42 | "2. 预测准确度(离线)\n",
43 | " - 评分预测\n",
44 | " \n",
45 | " 评分预测的预测准确度一般通过均方根误差(RMSE)和平均绝对误差(MAE)计算\n",
46 | " $$\n",
47 | " \\operatorname{RMSE}=\\frac{\\sqrt{\\sum_{u, i\\in T}\\left(r_{u i}-\\hat{r}_{u i}\\right)^{2}}}{|T|}\n",
48 | " $$\n",
49 | " $$\n",
50 | " \\mathrm{MAE}=\\frac{\\sum_{u, i \\in T}\\left|r_{u i}-\\hat{r}_{u i}\\right|}{|T|}\n",
51 | " $$\n",
52 | "\n",
53 | " $r_{u i}$ 用户 u 对物品 i 的实际评分,而 $\\hat{r}_{u i}$ 是推荐算法给出的预测评分,RMSE加大了对预测不准的用户物品评分的惩罚(平方项的惩罚),因而对系统的评测更加苛刻 "
54 | ]
55 | },
56 | {
57 | "cell_type": "markdown",
58 | "metadata": {},
59 | "source": [
60 | "3. TopN 推荐\n",
61 | "\n",
62 | "TopN推荐的预测准确率一般通过准确率(precision) /召回率(recall)度量\n",
63 | "\n",
64 | "$$\n",
65 | "\\operatorname{Recall}=\\frac{\\sum_{u \\in U}|R(u) \\cap T(u)|}{\\sum_{u \\in U}|T(u)|}\n",
66 | "$$\n",
67 | "\n",
68 | "$$\n",
69 | "\\operatorname{Precision}=\\frac{\\sum_{u \\in U}|R(u) \\cap T(u)|}{\\sum_{u \\in U}|R(u)|}\n",
70 | "$$\n",
71 | "\n",
72 | "R(u) 是根据用户在训练集上的行为给用户作出的推荐列表,而 T(u) 是用户在测试集上的行为列表\n",
73 | "\n",
74 | "为了全面评测TopN推荐的准确率和召回率,一般会选取不同的推荐列表长度 N,计算出一组准确率/召回率,然后画出准确率/召回率曲线(precision/recall curve)\n",
75 | "\n",
76 | "预测用户是否会看一部电影,应该比预测用户看了电影后会给它什么评分更加重要。TopN 预测更符合实际要求"
77 | ]
78 | },
79 | {
80 | "cell_type": "markdown",
81 | "metadata": {},
82 | "source": [
83 | "4. 覆盖率\n",
84 | "\n",
85 | "覆盖率定义为推荐系统能够推荐出来的物品占总物品集合的比例。\n",
86 | "\n",
87 | "$$\n",
88 | "\\operatorname{Coverage}=\\frac{\\left|U_{u \\in U} R(u)\\right|}{|I|}\n",
89 | "$$\n",
90 | "\n",
91 | "\n",
92 | " - 覆盖率是一个内容提供商会关心的指标.覆盖率为100%的推荐系统可以将每个物品都推荐给至少一个用户\n",
93 | " - 热门排行榜的推荐覆盖率是很低的,它只会推荐那些热门的物品,这些物品在总物品中占的比例很小\n",
94 | " - 一个好的推荐系统不仅需要有比较高的用户满意度,也要有较高的覆盖率\n",
95 | "\n",
96 | " - 信息熵\n",
97 | "\n",
98 | " $$\n",
99 | " H=-\\sum_{i=1}^{n} p(i) \\log p(i)\n",
100 | " $$\n",
101 | "\n",
102 | " p(i) 是物品 i 的流行度除以所有物品流行度之和\n",
103 | "\n",
104 | " - 基尼系数\n",
105 | "\n",
106 | " $$\n",
107 | " G=\\frac{1}{n-1} \\sum_{j=1}^{n}(2 j-n-1) p\\left(i_{j}\\right)\n",
108 | " $$\n",
109 | "\n",
110 | " $i_j$ 是按照物品流行度 p() 从小到大排序的物品列表中第 j 个物品"
111 | ]
112 | },
113 | {
114 | "cell_type": "markdown",
115 | "metadata": {},
116 | "source": [
117 | "5. 多样性\n",
118 | "\n",
119 | "多样性描述了推荐列表中物品两两之间的不相似性\n",
120 | "\n",
121 | "6. 新颖性\n",
122 | "\n",
123 | "新颖的推荐是指给用户推荐那些他们以前没有听说过的物品\n",
124 | "\n",
125 | "7. 惊喜度\n",
126 | "\n",
127 | "如果推荐结果和用户的历史兴趣不相似,但却让用户觉得满意,那么就可以说推荐结果的惊喜度很高,而推荐的新颖性仅仅取决于用户是否听说过这个推荐结果\n",
128 | "\n",
129 | "8. 信任度\n",
130 | "\n",
131 | "度量推荐系统的信任度只能通过问卷调查的方式,询问用户是否信任推荐系统的推荐结果"
132 | ]
133 | },
134 | {
135 | "cell_type": "markdown",
136 | "metadata": {},
137 | "source": [
138 | "9. 评测维度\n",
139 | "\n",
140 | " - 用户维度\n",
141 | " - 物品维度\n",
142 | " - 时间维度\n",
143 | "\n",
144 | "在评测系统中还需要考虑评测维度,比如一个推荐算法,虽然整体性能不好,但可能在某种情况下性能比较好,而增加评测维度的目的就是知道一个算法在什么情况下性能最好。"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": null,
150 | "metadata": {},
151 | "outputs": [],
152 | "source": []
153 | }
154 | ],
155 | "metadata": {
156 | "kernelspec": {
157 | "display_name": "Python 3",
158 | "language": "python",
159 | "name": "python3"
160 | },
161 | "language_info": {
162 | "codemirror_mode": {
163 | "name": "ipython",
164 | "version": 3
165 | },
166 | "file_extension": ".py",
167 | "mimetype": "text/x-python",
168 | "name": "python",
169 | "nbconvert_exporter": "python",
170 | "pygments_lexer": "ipython3",
171 | "version": "3.6.6"
172 | }
173 | },
174 | "nbformat": 4,
175 | "nbformat_minor": 2
176 | }
177 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | scikit-learn
3 | torch>=1.0
4 | torchvision
5 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """setup
2 | Copyright:
3 | ----------
4 | Author: AutuanLiu
5 | Date: 2019/06/01
6 | """
7 |
8 | import distutils.spawn
9 | import shlex
10 | import subprocess
11 | import sys
12 |
13 | from setuptools import find_packages, setup
14 |
15 | version = '0.1.0'
16 |
# Release helper: ``python setup.py release`` tags the current version,
# pushes it, builds an sdist and uploads it to PyPI via twine.
# BUG FIX: guard against ``python setup.py`` with no arguments (IndexError).
if len(sys.argv) > 1 and sys.argv[1] == 'release':
    if not distutils.spawn.find_executable('twine'):
        print(
            'Please install twine:\n\n\tpip install twine\n',
            file=sys.stderr,
        )
        sys.exit(1)

    commands = [
        'git pull origin master',
        'git tag v{:s}'.format(version),
        'git push origin master --tag',
        'python setup.py sdist',
        # BUG FIX: uploaded artifact was 'imgviz-*' (leftover from the template
        # this setup.py was copied from); this package builds torchctr-*.
        'twine upload dist/torchctr-{:s}.tar.gz'.format(version),
    ]
    for cmd in commands:
        print('+ {}'.format(cmd))
        subprocess.check_call(shlex.split(cmd))
    sys.exit(0)
36 |
37 |
def get_install_requires():
    """Read requirements.txt and return its entries as a list of strings."""
    with open('requirements.txt') as req_file:
        return [line.strip() for line in req_file]
44 |
45 |
# Long description for PyPI is rendered from description.md as Markdown.
with open('description.md') as f:
    long_description = f.read()

# Package metadata; dependencies come straight from requirements.txt.
setup(
    name='torchctr',
    version=version,
    packages=find_packages(),
    install_requires=get_install_requires(),
    description='CTR prediction based on PyTorch.',
    long_description=long_description,
    long_description_content_type='text/markdown',
    include_package_data=True,
    python_requires='>=3.5',
    author='Autuan Liu',
    author_email='autuanliu@163.com',
    url='https://github.com/AutuanLiu/torchctr',
    license='MIT',
    classifiers=[
        'Development Status :: 5 - Production/Stable',
        'Intended Audience :: Developers',
        'Natural Language :: English',
        'Programming Language :: Python',
        'Programming Language :: Python :: 3.5',
        'Programming Language :: Python :: 3.6',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: Implementation :: CPython',
        'Programming Language :: Python :: Implementation :: PyPy',
    ],
)
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
# Smoke test: download MovieLens-1M and load its three raw tables.
from torchctr.datasets import get_movielens, read_data

# step 1: download dataset (returns the processed-data directory as a Path)
root = get_movielens('datasets', 'ml-1m')

# step 2: read data — MovieLens-1M .dat files are '::'-separated, headerless
users = read_data(root / 'users.dat', sep='::', names=['UserID', 'Gender', 'Age', 'Occupation', 'Zip-code'])
movies = read_data(root / 'movies.dat', sep='::', names=['MovieID', 'Title', 'Genres'])
ratings = read_data(root / 'ratings.dat', sep='::', names=['UserID', 'MovieID', 'Rating', 'Timestamp'])
10 |
--------------------------------------------------------------------------------
/torchctr/__init__.py:
--------------------------------------------------------------------------------
1 | from .layers import EmbeddingDropout, EmbeddingLayer
2 | from .tools import timmer
3 |
4 | __all__ = ['EmbeddingLayer', 'timmer', 'EmbeddingDropout']
5 |
--------------------------------------------------------------------------------
/torchctr/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .criteo import get_criteo
2 | from .data import RecommendDataset
3 | from .movielens import get_movielens
4 | from .transform import (dense_feature_scale, fillna, make_dataloader, make_datasets, sequence_feature_encoding,
5 | sparse_feature_encoding)
6 | from .utils import (DataInput, DataMeta, FeatureDict, defaults, dropout_mask, emb_sz_rule, extract_file, read_data,
7 | totensor, train_test_split)
8 |
9 | __all__ = [
10 | 'RecommendDataset', 'extract_file', 'get_movielens', 'get_criteo', 'train_test_split', 'DataMeta', 'DataInput',
11 | 'FeatureDict', 'defaults', 'read_data', 'sequence_feature_encoding', 'dense_feature_scale', 'dropout_mask',
12 | 'sparse_feature_encoding', 'make_datasets', 'fillna', 'emb_sz_rule', 'totensor', 'make_dataloader'
13 | ]
14 |
--------------------------------------------------------------------------------
/torchctr/datasets/criteo.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | from torchvision.datasets.utils import download_url, makedir_exist_ok
5 |
6 | from .utils import extract_file
7 |
8 |
def get_criteo(root):
    """Fetch and unpack the Criteo display-advertising dataset.

    Downloads the archive into ``<root>/criteo/raw`` and extracts it
    into ``<root>/criteo/processed``.

    Args:
        root (str or Path): base directory for datasets.

    Returns:
        Path: the processed-data directory.
    """

    url = 'https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz'

    raw_folder = os.path.join(root, 'criteo', 'raw')
    processed_folder = os.path.join(root, 'criteo', 'processed')
    for folder in (raw_folder, processed_folder):
        makedir_exist_ok(folder)

    # archive name is the last path component of the URL
    filename = url.rpartition('/')[2]
    print('Downloading...')
    download_url(url, root=raw_folder, filename=filename, md5=None)
    print('Extracting...')
    extract_file(os.path.join(raw_folder, filename), processed_folder)
    print('Done!')
    return Path(processed_folder)
27 |
--------------------------------------------------------------------------------
/torchctr/datasets/data.py:
--------------------------------------------------------------------------------
1 | from .utils import DataInput, DataMeta, totensor
2 |
3 |
class RecommendDataset:
    """only support for sparse, sequence and dense data"""

    # NOTE(review): this class is still a stub — __init__ stores nothing and
    # ``self.lens`` is never assigned, so __len__ raises AttributeError until
    # the implementation lands.
    def __init__(self, input, target):
        pass

    def __getitem__(self, index):
        pass

    def __len__(self):
        # relies on ``self.lens`` being set elsewhere; currently never set
        return self.lens
15 |
--------------------------------------------------------------------------------
/torchctr/datasets/movielens.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | from torchvision.datasets.utils import download_url, makedir_exist_ok
5 |
6 | from .utils import extract_file
7 |
8 |
def get_movielens(root, version='ml-20m'):
    """Download and extract a MovieLens dataset if it doesn't exist.

    Args:
        root (str or Path): base directory for datasets.
        version (str): one of 'ml-latest', 'ml-100k', 'ml-1m', 'ml-10m',
            'ml-20m'. BUG FIX: the default used to be 'ml_20m' (underscore),
            which is not a valid key, so calling with the default always
            failed the assertion below.

    Returns:
        Path: directory containing the extracted dataset.
    """

    urls = {
        'ml-latest': 'http://files.grouplens.org/datasets/movielens/ml-latest.zip',
        'ml-100k': 'http://files.grouplens.org/datasets/movielens/ml-100k.zip',
        'ml-1m': 'http://files.grouplens.org/datasets/movielens/ml-1m.zip',
        'ml-10m': 'http://files.grouplens.org/datasets/movielens/ml-10m.zip',
        'ml-20m': 'http://files.grouplens.org/datasets/movielens/ml-20m.zip'
    }

    assert version in urls.keys(), f"version must be one of {set(urls.keys())}"
    raw_folder = os.path.join(root, version, 'raw')
    processed_folder = os.path.join(root, version, 'processed')
    makedir_exist_ok(raw_folder)
    makedir_exist_ok(processed_folder)

    # download files and extract
    filename = urls[version].rpartition('/')[2]
    print('Downloading...')
    download_url(urls[version], root=raw_folder, filename=filename, md5=None)
    print('Extracting...')
    extract_file(os.path.join(raw_folder, filename), processed_folder)
    print('Done!')
    return Path(os.path.join(processed_folder, version))
34 |
--------------------------------------------------------------------------------
/torchctr/datasets/transform.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 |
3 | import numpy as np
4 | import pandas as pd
5 | from sklearn.feature_extraction.text import CountVectorizer
6 | from sklearn.preprocessing import LabelEncoder, MinMaxScaler, StandardScaler
7 | from torch.utils.data import DataLoader
8 |
9 | from .data import RecommendDataset
10 | from .utils import DataInput, DataMeta, FeatureDict, defaults
11 |
12 |
def sparse_feature_encoding(data: pd.DataFrame, features_names: Union[str, List[str]]):
    """Label-encode sparse (categorical) columns in place.

    Args:
        data: frame holding the features; encoded columns are written back.
        features_names: columns to encode; falsy -> None is returned.

    Returns:
        DataMeta: encoded values, feature names and per-feature cardinalities.
    """

    if not features_names:
        return None
    cardinalities = []
    for name in features_names:
        encoder = LabelEncoder()
        data[name] = encoder.fit_transform(data[name])
        cardinalities.append(len(encoder.classes_))
    return DataMeta(data[features_names].values, features_names, cardinalities)
25 |
26 |
def sequence_feature_encoding(data: pd.DataFrame, features_names: Union[str, List[str]], sep: str = ','):
    """Encode multi-valued (sequence) features as flat index lists with bag offsets.

    BUG FIX: the previous body was syntactically broken (a stray ``to index``
    line, a premature DataMeta construction, and ``bags_offsets`` used before
    being defined). Reconstructed to build, per feature, a flattened list of
    vocabulary indices plus EmbeddingBag-style row offsets.

    Args:
        data: frame holding the features.
        features_names: sequence-valued columns; falsy -> None is returned.
        sep: token separator inside each cell.

    Returns:
        DataMeta: (flattened index lists, None, names, cardinalities, offsets).
    """

    if not features_names:
        return None
    data_value, nuniques, bags_offsets = [], [], []
    for feature in features_names:
        vocab = set.union(*[set(str(x).strip().split(sep=sep)) for x in data[feature]])
        # deterministic token -> index mapping (sorted, as CountVectorizer builds it)
        vocab_map = {token: idx for idx, token in enumerate(sorted(vocab))}
        nuniques.append(len(vocab))
        # flatten every row's tokens to indices; offsets mark where each row starts
        indices, offsets, offset = [], [], 0
        for row in data[feature]:
            offsets.append(offset)
            tokens = row.split(sep) if isinstance(row, str) else str(row).split(sep)
            indices.extend(vocab_map[token] for token in tokens)
            offset += len(tokens)
        data_value.append(indices)
        bags_offsets.append(offsets)
    return DataMeta(data_value, None, features_names, nuniques, bags_offsets)
51 |
52 |
def dense_feature_scale(data: pd.DataFrame, features_names: Union[str, List[str]], scaler_instance=None):
    """Fit (or reuse) a scaler on the dense columns and transform them in place.

    Args:
        data: frame holding the features; scaled values are written back.
        features_names: dense columns; falsy -> (None, None) is returned.
        scaler_instance: optional pre-built scaler; defaults to StandardScaler.

    Returns:
        tuple: (DataMeta over the scaled values, the fitted scaler).
    """

    if not features_names:
        return None, None
    scaler = scaler_instance if scaler_instance else StandardScaler()
    scaler = scaler.fit(data[features_names])
    data[features_names] = scaler.transform(data[features_names])
    return DataMeta(data[features_names].values, features_names), scaler
63 |
64 |
def fillna(data: pd.DataFrame, features_names: Union[str, List[str]], fill_v, **kwargs):
    """Replace NaN in the given columns with ``fill_v`` and return the frame."""

    filled = data[features_names].fillna(fill_v, **kwargs)
    data[features_names] = filled
    return data
70 |
71 |
def make_datasets(data: pd.DataFrame, features_dict=None, sep=',', scaler=None):
    """make dataset for df.

    NOTE(review): still a stub — returns None.

    Args:
        data (pd.DataFrame): data
        features_dict (FeatureDict): instance of FeatureDict. Defaults to None.
        sep (str, optional): sep for sequence. Defaults to ','.
        scaler: scaler for dense data.
    """

    pass
83 |
84 |
def make_dataloader(input: DataInput, targets=None, batch_size=64, shuffle=False, drop_last=False):
    """Build a DataLoader over a DataInput. NOTE(review): still a stub — returns None."""
    pass
87 |
--------------------------------------------------------------------------------
/torchctr/datasets/utils.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import os
3 | import tarfile
4 | import zipfile
5 | from collections import namedtuple
6 | from pathlib import Path
7 | from types import SimpleNamespace
8 | from functools import lru_cache
9 |
10 | import pandas as pd
11 | import torch
12 | from torch.utils.data import random_split
13 |
14 |
def num_cpus() -> int:
    """Return the number of CPUs usable by this process (affinity-aware when available)."""

    affinity = getattr(os, 'sched_getaffinity', None)
    # sched_getaffinity is POSIX-only; fall back to the raw CPU count elsewhere
    return len(affinity(0)) if affinity is not None else os.cpu_count()
22 |
23 |
# Project-wide defaults: worker count (capped at 16) and preferred torch device.
defaults = SimpleNamespace(cpus=min(16, num_cpus()),
                           device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
27 |
28 |
def extract_file(from_path, to_path, remove_finished=False):
    """Extract an archive (.zip / .tar / .tar.gz / plain .gz) into ``to_path``.

    Adapted from torchvision:
    https://github.com/pytorch/vision/blob/master/torchvision/datasets/utils.py

    Args:
        from_path (str): path of the archive.
        to_path (str): destination directory.
        remove_finished (bool): delete the archive after extraction.

    Raises:
        ValueError: if the archive extension is not supported.
    """

    if from_path.endswith(".zip"):
        with zipfile.ZipFile(from_path, 'r') as z:
            z.extractall(to_path)
    elif from_path.endswith(".tar.gz"):
        with tarfile.open(from_path, 'r:gz') as tar:
            tar.extractall(path=to_path)
    elif from_path.endswith(".tar"):
        with tarfile.open(from_path, 'r:') as tar:
            tar.extractall(path=to_path)
    elif from_path.endswith(".gz"):
        # plain gzip: decompress into to_path using the stem of the source name
        to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
        with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
            out_f.write(zip_f.read())
    else:
        # BUG FIX: the message was a plain string containing a literal
        # "{from_path}"; it is now an f-string that names the offending file.
        raise ValueError(f"Extraction of {from_path} not supported")

    if remove_finished:
        os.unlink(from_path)
50 |
51 |
def train_test_split(dataset, test_rate):
    """Randomly partition ``dataset`` into (train, test) subsets by ``test_rate``."""

    n_test = round(len(dataset) * test_rate)
    n_train = len(dataset) - n_test
    return random_split(dataset, [n_train, n_test])
58 |
59 |
def read_data(filename, **kwargs):
    """Load a delimited text file into a DataFrame (python engine).

    Args:
        filename (str or Path): file to read.
        **kwargs: forwarded to ``pandas.read_csv``.
    """

    path = filename if isinstance(filename, Path) else Path(filename)
    return pd.read_csv(path, engine='python', **kwargs)
70 |
71 |
def emb_sz_rule(dim: int) -> int:
    """Heuristic embedding width for a feature with ``dim`` levels, capped at 600."""

    size = round(1.6 * dim ** 0.56)
    return size if size < 600 else 600
74 |
75 |
def totensor(x):
    """Pass tensors through unchanged; convert anything else onto the default device."""

    if isinstance(x, torch.Tensor):
        return x
    return torch.as_tensor(x, device=defaults.device)
78 |
79 |
def dropout_mask(x: torch.Tensor, sz, p: float):
    """Return a dropout mask of the same type as `x`, size `sz`, with probability `p` to cancel an element.

    BUG FIX: ``sz`` was annotated ``Collection[int]`` but ``Collection`` was
    never imported, so importing this module raised NameError; the type now
    lives in the docstring instead.

    Args:
        x: template tensor — the mask is created via ``x.new`` (same dtype/device).
        sz: mask shape, a collection of ints, e.g. ``(rows, 1)``.
        p: drop probability; kept entries are scaled by 1/(1-p) (inverted dropout).
    """

    return x.new(*sz).bernoulli_(1 - p).div_(1 - p)
84 |
--------------------------------------------------------------------------------
/torchctr/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from .datasets import DataInput, defaults, dropout_mask, emb_sz_rule, totensor
4 | from typing import Optional
5 | import torch.nn.functional as F
6 |
7 |
class EmbeddingDropout(nn.Module):
    """Wrap embedding layer `emb`, zeroing whole rows with probability `embed_p`."""

    def __init__(self, emb: nn.Module, embed_p: float):
        super().__init__()
        self.emb = emb
        self.embed_p = embed_p

    def forward(self, words: torch.LongTensor, scale: Optional[float] = None) -> torch.Tensor:
        weight = self.emb.weight
        if self.training and self.embed_p != 0:
            # one Bernoulli draw per vocabulary row -> drops entire embedding vectors
            mask = dropout_mask(weight.data, (weight.size(0), 1), self.embed_p)
            weight = weight * mask
        if scale:
            weight.mul_(scale)
        return F.embedding(words, weight, self.emb.padding_idx, self.emb.max_norm, self.emb.norm_type,
                           self.emb.scale_grad_by_freq, self.emb.sparse)
25 |
26 |
class EmbeddingLayer(nn.Module):
    """Embedding layer: convert sparse data to dense data.

    Args:
        x (DataInput): instance of DataInput (sparse, sequence, dense parts).
        emb_szs_dict (dict): {feature: embedding size}; None -> use emb_sz_rule.
        emb_drop (float): embedding dropout probability (sparse part only).
        mode (str): 'sum' or 'mean' pooling for sequence features.

    Returns:
        torch.Tensor: all parts embedded and concatenated along dim 1.
    """

    def __init__(self, x, emb_szs_dict=None, emb_drop=0, mode='mean'):
        super().__init__()
        assert mode in ['sum', 'mean'], "mode must in {'sum', 'mean'}"
        self.mode = mode
        if x.sparse_data:
            nuniques = x.sparse_data.nunique
            if emb_szs_dict:
                emb_szs = [emb_szs_dict[f] for f in x.sparse_data.features]
            else:
                emb_szs = [emb_sz_rule(t) for t in nuniques]
            self.sparse_embeds = nn.ModuleList(
                [EmbeddingDropout(nn.Embedding(ni, nf), emb_drop) for ni, nf in zip(nuniques, emb_szs)])
        if x.sequence_data:
            nuniques = x.sequence_data.nunique
            if emb_szs_dict:
                emb_szs = [emb_szs_dict[f] for f in x.sequence_data.features]
            else:
                # BUG FIX: was ``self.emb_sz_rule(t)`` — emb_sz_rule is a
                # module-level helper, not a method; the self call raised
                # AttributeError (cf. the sparse branch above, which was right).
                emb_szs = [emb_sz_rule(t) for t in nuniques]
            # plain Embedding + manual pooling in forward (EmbeddingBag variant
            # was left commented out by the original author)
            self.sequence_embeds = nn.ModuleList(
                [nn.Embedding(ni, nf) for ni, nf in zip(nuniques, emb_szs)])
        self.drop = emb_drop

    def forward(self, x):
        out = []
        if x.sparse_data:
            data = totensor(x.sparse_data.data).long()
            sparse_out = [e(data[:, i]) for i, e in enumerate(self.sparse_embeds)]
            out.append(torch.cat(sparse_out, 1))
        if x.sequence_data:
            nuniques = x.sequence_data.nunique
            data = totensor(x.sequence_data.data).float()
            # split the multi-hot matrix into one chunk per sequence feature
            data = data.split(nuniques, dim=1)
            # matmul with the embedding table pools the bag; 'mean' divides by
            # the number of active tokens per row
            sequence_out = [
                data[i] @ e.weight if self.mode == 'sum' else data[i] @ e.weight / data[i].sum(dim=1).view(-1, 1)
                for i, e in enumerate(self.sequence_embeds)
            ]
            out.append(torch.cat(sequence_out, 1))
        if x.dense_data:
            out.append(totensor(x.dense_data.data).float())
        return torch.cat(out, 1)
88 |
--------------------------------------------------------------------------------
/torchctr/learner.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | import torch
4 | from torch import nn, optim
5 | from .datasets.utils import defaults, totensor
6 |
7 |
@dataclass
class Learner:
    """Bundle a model, loss and optimizer; handles device placement and checkpoints.

    BUG FIX: the original declared ``model: nn.Module = model.to(...)`` which
    referenced an undefined name at class-creation time (NameError) and put a
    default before non-default fields (TypeError); the device transfer now
    happens in ``__post_init__``. ``fit``/``predict`` were also missing ``self``.
    """

    model: nn.Module
    criterion: nn.Module
    opt: optim.Optimizer

    def __post_init__(self):
        # move the model to the project default device once, at construction
        self.model = self.model.to(defaults.device)

    def fit(self, input_loader, epoch=100):
        """Train the model. NOTE(review): still a stub."""
        pass

    @torch.no_grad()
    def predict(self, input):
        """Run inference without tracking gradients. NOTE(review): still a stub."""
        pass

    def save_trained_model(self, path):
        """save trained model's weights.

        Args:
            path (str): the path to save checkpoint.
        """

        torch.save(self.model.state_dict(), path)

    def save_model(self, path):
        """save the whole model object.

        Args:
            path (str): the path to save checkpoint.
        """

        torch.save(self.model, path)
--------------------------------------------------------------------------------
/torchctr/metrics.py:
--------------------------------------------------------------------------------
1 | import math
2 |
class Metric():
    """Offline recommendation metrics over a train/test split.

    Subclasses must provide ``GetRecommendation(user)`` returning a list of
    ``(item, score)`` pairs; ``getRec`` caches it per test user at init time.

    Args:
        train (dict): user -> list of interacted items (training split).
        test (dict): user -> list of interacted items (held-out split).
    """

    def __init__(self, train, test):
        self.train = train
        self.test = test
        self.recs = self.getRec()

    def getRec(self):
        """Precompute the recommendation list for every test user."""
        return {user: self.GetRecommendation(user) for user in self.test}

    def precision(self):
        """Percent of recommended items that appear in the user's test items."""
        # BUG FIX (idiom): accumulator renamed from ``all`` — it shadowed the builtin
        total, hit = 0, 0
        for user in self.test:
            test_items = set(self.test[user])
            for item, score in self.recs[user]:
                if item in test_items:
                    hit += 1
            total += len(self.recs[user])
        return round(hit / total * 100, 2)

    def recall(self):
        """Percent of held-out test items that were recommended."""
        total, hit = 0, 0
        for user in self.test:
            test_items = set(self.test[user])
            for item, score in self.recs[user]:
                if item in test_items:
                    hit += 1
            total += len(test_items)
        return round(hit / total * 100, 2)

    def coverage(self):
        """Share of catalog items (train side) ever recommended, in percent."""
        all_item, recom_item = set(), set()
        for user in self.test:
            for item in self.train[user]:
                all_item.add(item)
            for item, score in self.recs[user]:
                recom_item.add(item)
        return round(len(recom_item) / len(all_item) * 100, 2)

    def popularity(self):
        """Mean log-popularity of recommended items (lower = more novel)."""
        # item popularity counted over the training split
        item_pop = {}
        for user in self.train:
            for item in self.train[user]:
                item_pop[item] = item_pop.get(item, 0) + 1

        num, pop = 0, 0
        for user in self.test:
            for item, score in self.recs[user]:
                # log damping keeps head (very popular) items from dominating
                pop += math.log(1 + item_pop[item])
                num += 1
        return round(pop / num, 6)
76 |
--------------------------------------------------------------------------------
/torchctr/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AutuanLiu/torchctr/7300b640179f46402e552d0b434e0f49f6c2ddaf/torchctr/models/__init__.py
--------------------------------------------------------------------------------
/torchctr/models/deepfm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from torch import Tensor, nn
5 | from typing import List
6 |
7 |
class DeepFM(nn.Module):
    """DeepFM: FM first- and second-order terms plus a DNN over shared embeddings.

    Args:
        input_dim: total number of sparse features (rows of weight/embedding tables).
        n_fields: number of feature fields; DNN input width is n_fields * embedding_dim.
        embedding_dim: latent dimension k.
        fc_dims: hidden sizes of the DNN tower; empty/None skips parameter creation.
    """

    def __init__(self, input_dim=-1, n_fields=-1, embedding_dim=-1, fc_dims=None):
        super().__init__()
        self.input_dim = input_dim
        self.n_fields = n_fields
        self.embedding_dim = embedding_dim
        fc_dims = fc_dims or []    # also avoids the mutable-default pitfall
        # BUG FIX: ``self.mats`` was a plain Python list, so the DNN weights
        # were never registered — invisible to optimizers, ``state_dict`` and
        # ``.to(device)``. ParameterList registers them properly.
        self.mats = torch.nn.ParameterList()
        if input_dim > 0 and embedding_dim > 0 and n_fields > 0 and fc_dims:
            self.bias = torch.nn.Parameter(torch.zeros(1, 1))
            self.weights = torch.nn.Parameter(torch.zeros(input_dim, 1))
            self.embedding = torch.nn.Parameter(torch.zeros(input_dim, embedding_dim))
            torch.nn.init.xavier_uniform_(self.weights)
            torch.nn.init.xavier_uniform_(self.embedding)
            dim = n_fields * embedding_dim    # DNN input dim
            # DNN FC layers: [weight, bias] per layer, weights xavier-initialized
            for (index, fc_dim) in enumerate(fc_dims):
                self.mats.append(torch.nn.Parameter(torch.randn(dim, fc_dim)))    # weight
                self.mats.append(torch.nn.Parameter(torch.randn(1, 1)))    # bias
                torch.nn.init.xavier_uniform_(self.mats[index * 2])
                dim = fc_dim

    def first_order(self, batch_size, index, values, bias, weights):
        # type: (int, Tensor, Tensor, Tensor, Tensor) -> Tensor
        """Linear FM term: per-sample sum of feature weights, plus global bias."""
        srcs = weights.view(1, -1).mul(values.view(1, -1)).view(-1)
        output = torch.zeros(batch_size, dtype=torch.float32)
        # scatter each feature's contribution into its owning sample slot
        output.scatter_add_(0, index, srcs)
        first = output + bias
        return first

    def second_order(self, batch_size, index, values, embeddings):
        # type: (int, Tensor, Tensor, Tensor) -> Tensor
        """Pairwise FM term via the 0.5 * ((sum v)^2 - sum v^2) identity."""
        k = embeddings.size(1)
        b = batch_size

        # t1: per-sample sum of value-weighted embeddings, then squared
        t1 = embeddings.mul(values.view(-1, 1)).transpose_(0, 1)
        t1_ = torch.zeros(k, b, dtype=torch.float32)
        for i in range(k):
            t1_[i].scatter_add_(0, index, t1[i])
        t1 = t1_.pow(2)

        # t2: per-sample sum of squared value-weighted embeddings
        t2 = embeddings.pow(2).mul(values.pow(2).view(-1, 1)).transpose_(0, 1)
        t2_ = torch.zeros(k, b, dtype=torch.float32)
        for i in range(k):
            t2_[i].scatter_add_(0, index, t2[i])
        t2 = t2_

        second = t1.sub(t2).transpose_(0, 1).sum(1).mul(0.5)
        return second

    def higher_order(self, batch_size, embeddings, mats):
        # type: (int, Tensor, List[Tensor]) -> Tensor
        """DNN tower over the flattened embeddings; ReLU between FC layers."""
        output = embeddings.view(batch_size, -1)
        # mats holds [weight, bias] pairs per layer
        for i in range(int(len(mats) / 2)):
            output = torch.relu(output.matmul(mats[i * 2]) + mats[i * 2 + 1])
        return output.view(-1)

    def forward_(self, batch_size, index, feats, values, bias, weights, embeddings, mats):
        # type: (int, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, List[Tensor]) -> Tensor
        """Combine the three towers and squash with sigmoid (CTR probability)."""
        first = self.first_order(batch_size, index, values, bias, weights)
        second = self.second_order(batch_size, index, values, embeddings)
        higher = self.higher_order(batch_size, embeddings, mats)
        return torch.sigmoid(first + second + higher)

    def forward(self, batch_size, index, feats, values):
        # type: (int, Tensor, Tensor, Tensor) -> Tensor
        # look up each active feature's linear weight and embedding row
        batch_first = F.embedding(feats, self.weights)
        batch_second = F.embedding(feats, self.embedding)
        return self.forward_(batch_size, index, feats, values, self.bias, batch_first, batch_second, self.mats)
90 |
--------------------------------------------------------------------------------
/torchctr/models/ffm.py:
--------------------------------------------------------------------------------
1 | # https://github.com/LLSean/data-mining
2 |
3 | import os
4 | import sys
5 | import tensorflow as tf
6 | import logging
7 | logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
8 | import numpy as np
9 | import argparse
10 | from util import *
11 | from sklearn.metrics import *
12 |
13 |
class FFM(object):
    """Field-aware Factorization Machine (TensorFlow 1.x, graph mode).

    Builds a linear term plus field-aware pairwise interactions over a
    dense input of shape [None, feature_length], then a sigmoid (binary)
    or softmax (multi-class) head.

    NOTE(review): reg_l1 / reg_l2 are stored but never applied to the
    loss — confirm whether L1/L2 regularization was intended.
    """

    def __init__(self, num_classes, k, field, lr, batch_size, feature_length, reg_l1, reg_l2, feature2field):
        self.num_classes = num_classes      # 2 -> sigmoid head, >2 -> softmax head
        self.k = k                          # latent dimension per (feature, field)
        self.field = field                  # number of fields
        self.lr = lr                        # Adagrad learning rate
        self.batch_size = batch_size
        self.p = feature_length             # total number of input features
        self.reg_l1 = reg_l1                # currently unused (see class docstring)
        self.reg_l2 = reg_l2                # currently unused (see class docstring)
        self.feature2field = feature2field  # mapping: feature id -> field id

    def add_input(self):
        """Create the input placeholders X, y and keep_prob."""
        self.X = tf.placeholder('float32', [None, self.p])
        # BUG FIX: previously referenced the module-level global `num_classes`
        self.y = tf.placeholder('float32', [None, self.num_classes])
        self.keep_prob = tf.placeholder('float32')

    def inference(self):
        """Build linear + field-aware interaction terms and the output head."""
        with tf.variable_scope('linear_layer'):
            w0 = tf.get_variable('w0', shape=[self.num_classes], initializer=tf.zeros_initializer())
            # BUG FIX: shape previously referenced the global `num_classes`
            self.w = tf.get_variable('w', shape=[self.p, self.num_classes], initializer=tf.truncated_normal_initializer(mean=0, stddev=0.01))
            self.linear_terms = tf.add(tf.matmul(self.X, self.w), w0)

        with tf.variable_scope('interaction_layer'):
            self.v = tf.get_variable('v', shape=[self.p, self.field, self.k], initializer=tf.truncated_normal_initializer(mean=0, stddev=0.01))
            self.interaction_terms = tf.constant(0, dtype='float32')
            # O(p^2) graph construction: one op per feature pair (i, j);
            # fine for small p, prohibitively slow for large feature spaces
            for i in range(self.p):
                for j in range(i + 1, self.p):
                    self.interaction_terms += tf.multiply(
                        tf.reduce_sum(tf.multiply(self.v[i, self.feature2field[i]], self.v[j, self.feature2field[j]])), tf.multiply(self.X[:, i], self.X[:, j]))
            self.interaction_terms = tf.reshape(self.interaction_terms, [-1, 1])
        self.y_out = tf.math.add(self.linear_terms, self.interaction_terms)
        if self.num_classes == 2:
            self.y_out_prob = tf.nn.sigmoid(self.y_out)
        elif self.num_classes > 2:
            self.y_out_prob = tf.nn.softmax(self.y_out)

    def add_loss(self):
        """Mean cross-entropy loss (sigmoid for binary, softmax otherwise)."""
        if self.num_classes == 2:
            cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=self.y, logits=self.y_out)
        elif self.num_classes > 2:
            cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=self.y, logits=self.y_out)
        else:
            # previously fell through to an UnboundLocalError; fail loudly instead
            raise ValueError('num_classes must be >= 2')
        mean_loss = tf.reduce_mean(cross_entropy)
        self.loss = mean_loss
        tf.summary.scalar('loss', self.loss)

    def add_accuracy(self):
        """Mean argmax-match accuracy, exported as a summary scalar."""
        self.correct_prediction = tf.equal(tf.cast(tf.argmax(self.y_out, 1), tf.float32), tf.cast(tf.argmax(self.y, 1), tf.float32))
        self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32))
        tf.summary.scalar('accuracy', self.accuracy)

    def train(self):
        """Create the Adagrad train op (runs queued UPDATE_OPS first)."""
        self.global_step = tf.Variable(0, trainable=False)
        optimizer = tf.train.AdagradOptimizer(self.lr)
        extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(extra_update_ops):
            self.train_op = optimizer.minimize(self.loss, global_step=self.global_step)

    def build_graph(self):
        """Wire the full graph: inputs -> inference -> loss -> metrics -> train op."""
        self.add_input()
        self.inference()
        self.add_loss()
        self.add_accuracy()
        self.train()
80 |
81 |
def train_model(sess, model, epochs=100, print_every=50):
    """Train `model` for `epochs` epochs, logging loss/accuracy summaries.

    Relies on module-level globals defined in the __main__ block:
    `x_train`, `batch_size`, `batch_gen` and `saver`.
    """
    # Merge all the summaries and write them out to train_logs
    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter('train_logs', sess.graph)

    # ceil(len / batch_size): the previous `len // batch_size + 1` ran one
    # extra (wrapped-around, duplicated) batch whenever len was an exact
    # multiple of batch_size, skewing the per-epoch loss average
    num_batches = (len(x_train) + batch_size - 1) // batch_size

    for e in range(epochs):
        num_samples = 0
        losses = []
        for ibatch in range(num_batches):
            # pull the next mini-batch from the (infinite) generator
            batch_x, batch_y = next(batch_gen)
            batch_y = np.array(batch_y).astype(np.float32)
            actual_batch_size = len(batch_y)
            # create a feed dictionary for this batch
            feed_dict = {model.X: batch_x, model.y: batch_y, model.keep_prob: 1.0}

            loss, accuracy, summary, global_step, _ = sess.run([model.loss, model.accuracy, merged, model.global_step, model.train_op], feed_dict=feed_dict)
            # weight by batch size so the epoch loss is a true per-sample mean
            losses.append(loss * actual_batch_size)
            num_samples += actual_batch_size
            # record summaries and training-set accuracy
            train_writer.add_summary(summary, global_step=global_step)
            # periodically print training loss/accuracy and checkpoint
            if global_step % print_every == 0:
                logging.info("Iteration {0}: with minibatch training loss = {1} and accuracy of {2}".format(global_step, loss, accuracy))
                saver.save(sess, "checkpoints/model", global_step=global_step)
        # print mean per-sample loss of the epoch
        total_loss = np.sum(losses) / num_samples
        print("Epoch {1}, Overall loss = {0:.3g}".format(total_loss, e + 1))
115 |
116 |
def test_model(sess, model, print_every=50):
    """Evaluate `model` on the test set, printing a confusion matrix per batch.

    Relies on module-level globals defined in the __main__ block:
    `y_test`, `batch_size` and `test_batch_gen`.
    """
    # ceil(len / batch_size): avoids the off-by-one extra batch of the
    # previous `len // batch_size + 1` when len divides evenly
    num_batches = (len(y_test) + batch_size - 1) // batch_size

    for ibatch in range(num_batches):
        # batch_size data
        batch_x, batch_y = next(test_batch_gen)
        # create a feed dictionary for this batch (labels not fed at test time)
        feed_dict = {model.X: batch_x, model.keep_prob: 1}
        # y_out_prob: class probabilities, shape [None, num_classes]
        y_out_prob = sess.run([model.y_out_prob], feed_dict=feed_dict)
        y_out_prob = np.array(y_out_prob[0])

        # predicted vs. true class indices for this batch
        batch_clicks = np.argmax(y_out_prob, axis=1)
        batch_y = np.argmax(batch_y, axis=1)

        print(confusion_matrix(batch_y, batch_clicks))
        # 1-based progress logging (matches the original's manual increment)
        if (ibatch + 1) % print_every == 0:
            logging.info("Iteration {0} has finished".format(ibatch + 1))
143 |
144 |
def shuffle_list(data):
    """Apply one shared random permutation to every array in `data`.

    All arrays must have the same first dimension; row alignment across
    arrays (e.g. features vs. labels) is preserved.
    """
    perm = np.random.permutation(data[0].shape[0])
    return [arr[perm] for arr in data]
149 |
150 |
def batch_generator(data, batch_size, shuffle=True):
    """Endlessly yield aligned mini-batches from `data` (a list of arrays).

    Only full batches are produced: when fewer than `batch_size` samples
    remain, the cursor wraps back to the start (reshuffling first when
    `shuffle` is True), so trailing samples of a pass may be skipped.
    """
    if shuffle:
        data = shuffle_list(data)

    cursor = 0
    while True:
        # wrap around once the next full batch would run past the end
        if cursor * batch_size + batch_size > len(data[0]):
            cursor = 0

            if shuffle:
                data = shuffle_list(data)

        lo = cursor * batch_size
        hi = lo + batch_size
        cursor += 1
        yield [d[lo:hi] for d in data]
167 |
168 |
def check_restore_parameters(sess, saver):
    """Restore previously saved weights from ./checkpoints, if any exist."""
    ckpt = tf.train.get_checkpoint_state("checkpoints")
    # guard clause: nothing to restore -> start from fresh parameters
    if not (ckpt and ckpt.model_checkpoint_path):
        logging.info("Initializing fresh parameters for the my Factorization Machine")
        return
    logging.info("Loading parameters for the my Factorization Machine")
    saver.restore(sess, ckpt.model_checkpoint_path)
177 |
178 |
if __name__ == '__main__':
    '''launching TensorBoard: tensorboard --logdir=path/to/log-directory'''
    # parse run mode ('train' or 'test') from the command line
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', help='train or test', type=str)
    args = parser.parse_args()
    mode = args.mode
    # load data; load_dataset comes from `util` (star import above)
    # NOTE(review): assumes it returns arrays plus a feature->field mapping — confirm in util
    x_train, y_train, x_test, y_test, feature2field = load_dataset()
    # hyper-parameters; these module-level names are read as globals by
    # train_model/test_model (x_train, batch_size, batch_gen, saver, ...)
    num_classes = 2
    lr = 0.01
    batch_size = 128
    k = 8
    # NOTE(review): reg_l1/reg_l2 are passed to FFM but never used in its loss
    reg_l1 = 2e-2
    reg_l2 = 0
    feature_length = x_train.shape[1]
    # infinite mini-batch generators over train and test splits
    batch_gen = batch_generator([x_train, y_train], batch_size)
    test_batch_gen = batch_generator([x_test, y_test], batch_size)
    # build the FFM model; field count is hard-coded to 5 here — verify it matches feature2field
    model = FFM(num_classes, k, 5, lr, batch_size, feature_length, reg_l1, reg_l2, feature2field)
    # assemble the TF graph (inputs, inference, loss, metrics, train op)
    model.build_graph()

    saver = tf.train.Saver(max_to_keep=5)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        check_restore_parameters(sess, saver)
        if mode == 'train':
            print('start training...')
            train_model(sess, model, epochs=100, print_every=500)
        if mode == 'test':
            print('start testing...')
            test_model(sess, model)
--------------------------------------------------------------------------------
/torchctr/models/fm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 |
class FactorizationMachine(nn.Module):
    """Second-order factorization machine over sparse COO-style input.

    Each nonzero entry of the batch is described by three aligned tensors:
    `index` (which sample row it belongs to), `feats` (its feature id) and
    `values` (its value). Parameters are created only when both dimensions
    are positive at construction time.
    """

    def __init__(self, input_dim=-1, embedding_dim=-1):
        super().__init__()
        self.input_dim = input_dim
        self.embedding_dim = embedding_dim

        if input_dim > 0 and embedding_dim > 0:
            # draw with randn in a fixed order (bias, weights, embedding) so
            # seeded runs stay reproducible, then re-init weights with xavier
            self.bias = nn.Parameter(torch.randn(1, 1, dtype=torch.float32), requires_grad=True)
            self.weights = nn.Parameter(torch.randn(input_dim, 1), requires_grad=True)
            self.embedding = nn.Parameter(torch.randn(input_dim, embedding_dim), requires_grad=True)
            nn.init.xavier_uniform_(self.weights)
            nn.init.xavier_uniform_(self.embedding)

    def first_order(self, batch_size, index, values, bias, weights):
        # type: (int, Tensor, Tensor, Tensor, Tensor) -> Tensor
        """Per-sample sum of w_i * x_i plus the global bias."""
        contrib = weights.view(-1) * values.view(-1)
        totals = torch.zeros(batch_size, dtype=torch.float32)
        # bucket each nonzero entry's contribution into its sample row
        totals.scatter_add_(0, index, contrib)
        return totals + bias

    def second_order(self, batch_size, index, values, embeddings):
        # type: (int, Tensor, Tensor, Tensor) -> Tensor
        """Pairwise term via 0.5 * ((sum v x)^2 - sum (v x)^2) per sample."""
        dim = embeddings.size(1)

        # weighted: [k, n] — embedding of each nonzero entry scaled by its value
        weighted = (embeddings * values.view(-1, 1)).t()
        # per-sample sums of weighted embeddings: [k, b]
        sum_per_sample = torch.zeros(dim, batch_size, dtype=torch.float32)
        for f in range(dim):
            sum_per_sample[f].scatter_add_(0, index, weighted[f])
        square_of_sum = sum_per_sample.pow(2)

        # same reduction on squared terms: [k, n] -> [k, b]
        weighted_sq = (embeddings.pow(2) * values.pow(2).view(-1, 1)).t()
        sum_of_square = torch.zeros(dim, batch_size, dtype=torch.float32)
        for f in range(dim):
            sum_of_square[f].scatter_add_(0, index, weighted_sq[f])

        return 0.5 * (square_of_sum - sum_of_square).t().sum(1)

    def forward_(self, batch_size, index, feats, values, bias, weights, embeddings):
        # type: (int, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tensor
        """Probability = sigmoid(first-order + second-order)."""
        linear = self.first_order(batch_size, index, values, bias, weights)
        pairwise = self.second_order(batch_size, index, values, embeddings)
        return torch.sigmoid(linear + pairwise)

    def forward(self, batch_size, index, feats, values):
        # type: (int, Tensor, Tensor, Tensor) -> Tensor
        """Gather per-entry weights/embeddings for `feats` and score the batch."""
        gathered_w = F.embedding(feats, self.weights)
        gathered_v = F.embedding(feats, self.embedding)
        return self.forward_(batch_size, index, feats, values, self.bias, gathered_w, gathered_v)
71 |
--------------------------------------------------------------------------------
/torchctr/models/lr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 |
class LogisticRegression(nn.Module):
    """Sparse logistic regression over COO-style (index, feats, values) input."""

    def __init__(self, input_dim=-1):
        super().__init__()
        self.input_dim = input_dim
        assert input_dim > 0, "input_dim must be greater than 0."
        # zero-initialized bias; xavier-initialized weight column
        self.bias = torch.nn.Parameter(torch.zeros(1, 1, dtype=torch.float32), requires_grad=True)
        self.weights = torch.nn.Parameter(torch.randn(input_dim, 1), requires_grad=True)
        torch.nn.init.xavier_uniform_(self.weights)

    def forward_(self, batch_size, index, feats, values, bias, weight):
        # type: (int, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tensor
        """sigmoid(sum_i w_i * x_i + b), accumulated per sample via scatter_add."""
        flat_index = index.view(-1)
        contrib = weight.view(-1) * values.view(-1)
        logits = torch.zeros(batch_size, dtype=torch.float32)
        # route each nonzero entry's contribution to its sample row
        logits.scatter_add_(0, flat_index, contrib)
        return torch.sigmoid(logits + bias)

    def forward(self, batch_size, index, feats, values):
        # type: (int, Tensor, Tensor, Tensor) -> Tensor
        """Gather per-entry weights; index = sample id, feats = feature id."""
        gathered = F.embedding(feats, self.weights)
        return self.forward_(batch_size, index, feats, values, self.bias, gathered)
31 |
--------------------------------------------------------------------------------
/torchctr/models/mf.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
class MatrixFactorization(nn.Module):
    """Plain matrix factorization: score(u, i) = <user_factors[u], item_factors[i]>."""

    def __init__(self, n_users, n_items, n_factors=20):
        super().__init__()
        # sparse=True lets optimizers apply sparse gradient updates
        self.user_factors = nn.Embedding(n_users, n_factors, sparse=True)
        self.item_factors = nn.Embedding(n_items, n_factors, sparse=True)

    def forward(self, user, item):
        # one dot product per (user, item) pair
        u = self.user_factors(user)
        v = self.item_factors(item)
        return (u * v).sum(dim=1)
15 |
--------------------------------------------------------------------------------
/torchctr/tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 |
def timmer(func):
    """Decorator that prints the call's start time, name and duration.

    Returns the wrapped function's result unchanged.
    """
    from functools import wraps

    @wraps(func)  # preserve func.__name__/__doc__ on the wrapper
    def wrapper(*args, **kwargs):
        # capture the begin timestamp BEFORE the call; the original
        # formatted localtime() after the call, so "Begin" was wrong
        begin = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        start_time = time.time()
        res = func(*args, **kwargs)
        stop_time = time.time()
        print(f'Begin: {begin}\nfunc_name: {func.__name__}\nCost: {(stop_time - start_time):.4f}s')
        return res

    return wrapper
14 |
--------------------------------------------------------------------------------