├── .github └── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── .gitignore ├── LICENSE ├── README.md ├── changelog.md ├── dataset ├── criteo_sample_10k.txt └── movielens_ratings_10k.csv ├── example ├── eda │ └── criteo_eda.ipynb ├── loader │ ├── criteo_loader.py │ └── movielens_loader.py └── model │ ├── deepfm_example.py │ ├── fm_example.py │ └── pnn_example.py ├── feature ├── __init__.py ├── feature.py └── feature_meta.py ├── model ├── basic │ ├── __init__.py │ ├── activation.py │ ├── attention.py │ ├── embedding_layer.py │ ├── enum │ │ ├── __init__.py │ │ ├── activation_enum.py │ │ └── attention_enum.py │ ├── functional.py │ ├── gbdt.py │ ├── mlp.py │ └── output_layer.py ├── ctr │ ├── __init__.py │ ├── afm.py │ ├── autoInt.py │ ├── dcn.py │ ├── deepfm.py │ ├── flen.py │ ├── fm.py │ ├── fnn.py │ ├── gbdt_lr.py │ ├── lr.py │ ├── nfm.py │ ├── pnn.py │ └── wide_and_deep.py ├── sequence │ ├── __init__.py │ ├── dien.py │ └── din.py └── wrapper │ ├── __init__.py │ ├── base.py │ └── ctr │ ├── __init__.py │ ├── fnn.py │ └── pnn.py ├── preprocess ├── discretize.py ├── feat_engineering.py └── preprocess.py └── util ├── checkpoint_util.py ├── filedir_util.py ├── log_util.py └── train.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. 
iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | /job/ 131 | /test.py 132 | /.idea/ 133 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 zeroized 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepRec-torch 2 | DeepRec-torch is a framework based on PyTorch. 3 | This project is more of a tutorial for learning recommender system models than a tool for direct use. 4 | The analysis of the implemented models is available on the author's GitHub pages, [zeroized.github.io](https://github.com/zeroized/zeroized.github.io), or on the corresponding blog, [zeroized.xyz](http://www.zeroized.xyz/); both are written in Simplified Chinese. 5 | 6 | ## Dependency 7 | 8 | - torch 1.2.0 9 | - numpy 1.17.3 10 | - pandas 0.25.3 11 | - scikit-learn 0.21.3 12 | - tensorboard 2.2.1 (For loss and metrics visualization) 13 | - lightgbm 2.3.0 (For building high-order feature interaction with GBDT) 14 | 15 | ## Quick Start 16 | 1. Load and preprocess data 17 | ```python 18 | from example.loader.criteo_loader import load_data,missing_values_process 19 | # sample 10,000 rows from the criteo-1m dataset 20 | data = load_data('/path/to/data',n_samples=10000) 21 | data = missing_values_process(data) 22 | ``` 23 | 24 | 2. Describe the columns with FeatureMeta 25 | ```python 26 | from feature.feature_meta import FeatureMeta 27 | from example.loader.criteo_loader import continuous_columns,category_columns 28 | 29 | feature_meta = FeatureMeta() 30 | for column in continuous_columns: 31 | # By default, the continuous feature will not be discretized. 
32 | feature_meta.add_continuous_feat(column) 33 | for column in category_columns: 34 | feature_meta.add_categorical_feat(column) 35 | ``` 36 | 37 | 3. Transform the data into the required format (usually feat_index and feat_value) 38 | ```python 39 | from preprocess.feat_engineering import preprocess_features 40 | 41 | x_idx, x_value = preprocess_features(feature_meta, data) 42 | 43 | label = data.y 44 | ``` 45 | 46 | 4. Prepare for training 47 | ```python 48 | import torch 49 | 50 | # Assign the device for training 51 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 52 | 53 | # Load the data onto the assigned device 54 | X_idx_tensor_gpu = torch.LongTensor(x_idx).to(device) 55 | X_value_tensor_gpu = torch.Tensor(x_value).to(device) 56 | y_tensor_gpu = torch.Tensor(label).to(device) 57 | 58 | # Note that a binary classifier requires labels with shape (n_samples,1) 59 | y_tensor_gpu = y_tensor_gpu.reshape(-1, 1) 60 | 61 | # Form a dataset for torch's DataLoader (requires `from torch.utils.data import TensorDataset`) 62 | X_cuda = TensorDataset(X_idx_tensor_gpu, X_value_tensor_gpu, y_tensor_gpu) 63 | ``` 64 | 65 | 5. Load a model and set its parameters (pre-defined models for the CTR prediction task are in the model.ctr package) 66 | ```python 67 | from model.ctr.fm import FM 68 | 69 | # Create an FM model with embedding size of 5 and binary output, and load it into the assigned device 70 | fm_model = FM(emb_dim=5, num_feats=feature_meta.get_num_feats(), out_type='binary').to(device) 71 | 72 | # Assign an optimizer for the model 73 | optimizer = torch.optim.Adam(fm_model.parameters(), lr=1e-4) 74 | ``` 75 | 76 | 6. Train the model with a trainer 77 | ```python 78 | from util.train import train_model_hold_out 79 | 80 | # Train the model with hold-out model selection 81 | train_model_hold_out(job_name='fm-binary-cls', device=device, 82 | model=fm_model, dataset=X_cuda, 83 | loss_func=torch.nn.BCELoss(), optimizer=optimizer, 84 | epochs=20, batch_size=256) 85 | # Checkpoint saving is enabled by default in trainers. 
86 | # For more custom settings, create a dict as follows: 87 | ckpt_settings = {'save_ckpt':True, 'ckpt_dir':'path/to/ckpt_dir', 'ckpt_interval':3} 88 | # Then pass it to the trainer as keyword arguments 89 | train_model_hold_out(...,**ckpt_settings) 90 | # Settings for the log file path, model saving path and tensorboard file path are similar; see util/train.py 91 | ``` 92 | The trainer acts more as a log writer than as a plain model training routine. 93 | 94 | For more examples: 95 | 96 | - Model usage examples are available in the example.model package. 97 | 98 | - Data loader examples are available in the example.loader package. 99 | 100 | - Dataset EDA examples are available in the example.eda package in Jupyter notebook format. 101 | 102 | ## Change Log 103 | 104 | See changelog.md 105 | 106 | ## Model list 107 | ### Click Through Rate Prediction 108 | | model | paper | 109 | |:-----|:------| 110 | |LR: Logistic Regression| [Simple and Scalable Response Prediction for Display Advertising][LR]| 111 | |FM: Factorization Machine|\[ICDM 2010\][Factorization Machines][FM]| 112 | |GBDT+LR: Gradient Boosting Tree with Logistic Regression|[Practical Lessons from Predicting Clicks on Ads at Facebook][GBDTLR]| 113 | |FNN: Factorization-supported Neural Network|\[ECIR 2016\][Deep Learning over Multi-field Categorical Data: A Case Study on User Response Prediction][FNN]| 114 | |PNN: Product-based Neural Network|\[ICDM 2016\][Product-based neural networks for user response prediction][PNN]| 115 | |Wide and Deep|\[DLRS 2016\][Wide & Deep Learning for Recommender Systems][WideDeep]| 116 | |DeepFM|\[IJCAI 2017\][DeepFM: A Factorization-Machine based Neural Network for CTR Prediction][DeepFM]| 117 | |AFM: Attentional Factorization Machine|\[IJCAI 2017\][Attentional Factorization Machines: Learning the Weight of Feature Interactions via Attention Networks][AFM]| 118 | |NFM: Neural Factorization Machine|\[SIGIR 2017\][Neural Factorization Machines for Sparse Predictive Analytics][NFM]| 119 | |DCN: Deep & 
Cross Network|\[ADKDD 2017\][Deep & Cross Network for Ad Click Predictions][DCN]| 120 | |AutoInt|\[CIKM 2019\][AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks][AutoInt]| 121 | |FLEN|\[AAAI 2020\][FLEN: Leveraging Field for Scalable CTR Prediction][FLEN]| 122 | 131 | [LR]:https://dl.acm.org/doi/pdf/10.1145/2532128?download=true 132 | [FM]:https://dl.acm.org/doi/10.1109/ICDM.2010.127 133 | [GBDTLR]:https://dl.acm.org/doi/pdf/10.1145/2648584.2648589 134 | [CCPM]:http://ir.ia.ac.cn/bitstream/173211/12337/1/A%20Convolutional%20Click%20Prediction%20Model.pdf 135 | [FFM]:https://dl.acm.org/doi/pdf/10.1145/2959100.2959134 136 | [FNN]:https://arxiv.org/pdf/1601.02376.pdf 137 | [PNN]:https://arxiv.org/pdf/1611.00144.pdf 138 | [WideDeep]:https://arxiv.org/pdf/1606.07792.pdf 139 | [DeepFM]:https://arxiv.org/pdf/1703.04247.pdf 140 | [PLM]:https://arxiv.org/abs/1704.05194 141 | [DCN]:https://arxiv.org/abs/1708.05123 142 | [AFM]:http://www.ijcai.org/proceedings/2017/435 143 | [NFM]:https://arxiv.org/pdf/1708.05027.pdf 144 | [xDeepFM]:https://arxiv.org/pdf/1803.05170.pdf 145 | [AutoInt]:https://arxiv.org/abs/1810.11921 146 | [ONN]:https://arxiv.org/pdf/1904.12579.pdf 147 | [FGCNN]:https://arxiv.org/pdf/1904.04447 148 | [FiBiNET]:https://arxiv.org/pdf/1905.09433.pdf 149 | [FLEN]:https://arxiv.org/pdf/1911.04690.pdf 150 | 151 | ### Sequential Recommendation 152 | | model/keywords | paper | 153 | |:------|:------| 154 | |DIN: Deep Interest Network|\[KDD 2018\][Deep Interest Network for Click-Through Rate Prediction][DIN]| 155 | |DIEN: Deep Interest Evolution Network|\[AAAI 2019\][Deep Interest Evolution Network for Click-Through Rate Prediction][DIEN]| 156 | 157 | [DIN]:https://arxiv.org/pdf/1706.06978.pdf 158 | [DIEN]:https://arxiv.org/pdf/1809.03672.pdf 159 | 160 | 178 | -------------------------------------------------------------------------------- /changelog.md: -------------------------------------------------------------------------------- 1 
| version 0.0.1, 2020-06-10 2 | - Initial the project -------------------------------------------------------------------------------- /example/eda/criteo_eda.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [ 10 | { 11 | "data": { 12 | "text/plain": " y I1 I2 I3 I4 I5 I6 I7 I8 I9 ... C17 \\\n0 0 1.0 1 5.0 0.0 1382.0 4.0 15.0 2.0 181.0 ... e5ba7672 \n1 0 2.0 0 44.0 1.0 102.0 8.0 2.0 2.0 4.0 ... 07c540c4 \n2 0 2.0 0 1.0 14.0 767.0 89.0 4.0 2.0 245.0 ... 8efede7f \n3 0 1.0 893 5.0 4.0 4392.0 33.0 0.0 0.0 0.0 ... 1e88c74f \n4 0 3.0 -1 5.0 0.0 2.0 0.0 3.0 0.0 0.0 ... 1e88c74f \n\n C18 C19 C20 C21 C22 C23 C24 \\\n0 f54016b9 21ddcdc9 b1252a9d 07b5194c 0 3a171ecb c5c50484 \n1 b04e4670 21ddcdc9 5840adea 60f6221e 0 3a171ecb 43f13e8b \n2 3412118d 0 0 e587c466 ad3062eb 3a171ecb 3b183c5c \n3 74ef3502 0 0 6b3a5ca6 0 3a171ecb 9117a34a \n4 26b3c7a7 0 0 21c9516a 0 32c7478e b34f3128 \n\n C25 C26 \n0 e8b83407 9727dd16 \n1 e8b83407 731c3655 \n2 0 0 \n3 0 0 \n4 0 0 \n\n[5 rows x 40 columns]", 13 | "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
yI1I2I3I4I5I6I7I8I9...C17C18C19C20C21C22C23C24C25C26
001.015.00.01382.04.015.02.0181.0...e5ba7672f54016b921ddcdc9b1252a9d07b5194c03a171ecbc5c50484e8b834079727dd16
102.0044.01.0102.08.02.02.04.0...07c540c4b04e467021ddcdc95840adea60f6221e03a171ecb43f13e8be8b83407731c3655
202.001.014.0767.089.04.02.0245.0...8efede7f3412118d00e587c466ad3062eb3a171ecb3b183c5c00
301.08935.04.04392.033.00.00.00.0...1e88c74f74ef3502006b3a5ca603a171ecb9117a34a00
403.0-15.00.02.00.03.00.00.0...1e88c74f26b3c7a70021c9516a032c7478eb34f312800
\n

5 rows × 40 columns

\n
" 14 | }, 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "import pandas as pd\n", 22 | "import matplotlib.pyplot as plt\n", 23 | "import numpy as np\n", 24 | "from example.loader.criteo_loader import load_data,missing_values_process\n", 25 | "data=load_data('../../dataset/criteo_sample_10k.txt')\n", 26 | "data=missing_values_process(data)\n", 27 | "data.head()" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 2, 33 | "outputs": [], 34 | "source": [ 35 | "counts=data['I1'].value_counts()\n", 36 | "counts=counts.sort_index()\n", 37 | "counts_ratio=counts/len(data)" 38 | ], 39 | "metadata": { 40 | "collapsed": false, 41 | "pycharm": { 42 | "name": "#%%\n" 43 | } 44 | } 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 13, 49 | "outputs": [ 50 | { 51 | "name": "stdout", 52 | "output_type": "stream", 53 | "text": [ 54 | "0 0.260558\n", 55 | "0 0 0 0.260558\n", 56 | "1 0.521462\n", 57 | "1 1 1 0.521462\n", 58 | "2 0.054293\n", 59 | "2 2 2 0.054293\n", 60 | "3 0.033641\n", 61 | "3 3 3 0.033641\n", 62 | "4 0.024655\n", 63 | "5 0.017578\n", 64 | "4 5 4 0.042233\n", 65 | "6 0.013687\n", 66 | "7 0.010403\n", 67 | "8 0.008458\n", 68 | "6 8 5 0.032548\n", 69 | "9 0.006793\n", 70 | "10 0.005751\n", 71 | "11 0.004659\n", 72 | "12 0.00406\n", 73 | "13 0.003334\n", 74 | "14 0.003063\n", 75 | "15 0.002627\n", 76 | "9 15 6 0.030287\n", 77 | "0.9750219999999999\n" 78 | ] 79 | } 80 | ], 81 | "source": [ 82 | "total_acc=0\n", 83 | "acc=0\n", 84 | "idx=0\n", 85 | "left=0\n", 86 | "right=0\n", 87 | "for i in range(len(counts_ratio)):\n", 88 | " total_acc+=counts_ratio[i]\n", 89 | " acc+=counts_ratio[i]\n", 90 | " print(i,counts_ratio[i])\n", 91 | "\n", 92 | " if acc>0.03:\n", 93 | " right=i\n", 94 | " print(left,right,idx,acc)\n", 95 | " idx+=1\n", 96 | " left=i+1\n", 97 | " acc=0\n", 98 | " if total_acc>0.97:\n", 99 | " print(total_acc)\n", 100 | " break\n" 101 | ], 102 | 
continuous_columns = ['I' + str(i) for i in range(1, 14)]
category_columns = ['C' + str(i) for i in range(1, 27)]
# Complete header of the raw file: the label first, then the integer and the categorical columns.
# It is built once at import time so that load_data never has to modify module state.
columns = ['y'] + continuous_columns + category_columns


def load_data(path, n_samples=-1):
    r""" Read a tab-separated criteo file without header row and name its columns

    :param path: File path of criteo dataset.
    :param n_samples: Number to sample from the full dataset. n_samples <= 0 means not to sample.
    :return: pd.DataFrame whose columns are named after ``columns`` (y, I1..I13, C1..C26)
    """
    data = pd.read_csv(path, delimiter='\t', header=None)

    # Hand pandas a copy of the header: growing the shared module-level list inside this
    # function would make every call after the first one fail with a length mismatch.
    data.columns = list(columns)

    if n_samples > 0:
        data = data.sample(n=n_samples)
        # sample() keeps the original row labels; renumber them from 0
        data.reset_index(drop=True, inplace=True)
    return data


def missing_values_process(data):
    r""" Fill missing values: column median for continuous columns, '0' for category columns

    :param data: pd.DataFrame containing every column of continuous_columns and category_columns.
    :return: The DataFrame that was passed in; its feature columns are overwritten with the filled values.
    """
    continuous_fillna = data[continuous_columns].fillna(data[continuous_columns].median())
    data[continuous_columns] = continuous_fillna
    category_fillna = data[category_columns].fillna('0')
    data[category_columns] = category_fillna
    return data
def load_data(path, n_samples=-1):
    r""" Read a ratings csv file (with header row) and optionally sample rows from it

    :param path: File path of the ratings file.
    :param n_samples: Number to sample from the full dataset. n_samples <= 0 means not to sample.
    :return: pd.DataFrame with the columns found in the file
    """
    raw_ratings = pd.read_csv(path)
    if n_samples > 0:
        raw_ratings = raw_ratings.sample(n=n_samples)
        # sample() keeps the original row labels; renumber them from 0
        raw_ratings.reset_index(drop=True, inplace=True)

    return raw_ratings


def load_and_preprocess(path, n_samples=10000, binarize_label=True, positive_threshold=3):
    r""" An example for load and preprocess movielens ratings

    :param path: File path of the ratings file. The columns userId, movieId and rating are used.
    :param n_samples: Number to sample from the full dataset. n_samples <= 0 means not to sample.
    :param binarize_label: Whether to turn the rating into a 0/1 label.
    :param positive_threshold: Ratings strictly greater than this value become 1, all others 0.
                               Only used when binarize_label is True.
    :return: X_idx,X_value,y,feature_meta
    """
    raw_ratings = load_data(path, n_samples)

    feature_meta = FeatureMeta()
    feature_meta.add_categorical_feat('userId')
    feature_meta.add_categorical_feat('movieId')

    X = raw_ratings
    X_idx, X_value = preprocess_features(feature_meta, X)
    y = raw_ratings.rating

    if binarize_label:
        # One vectorised comparison instead of calling a Python function per row.
        y = (y > positive_threshold).astype('int64')
    return X_idx, X_value, y, feature_meta
def train_deepfm(x_idx, x_value, label, feat_meta, out_type='binary'):
    r""" An example for training a DeepFM model with hold-out model selection

    :param x_idx: Feature indexes of the samples, converted with torch.LongTensor.
    :param x_value: Feature values of the samples, converted with torch.Tensor.
    :param label: Labels of the samples, reshaped to (n_samples, 1) before training.
    :param feat_meta: Feature description providing get_num_feats() and get_num_fields().
    :param out_type: Output type passed to DeepFM.
    :return: The DeepFM model instance that was handed to the trainer
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    X_idx_tensor_gpu = torch.LongTensor(x_idx).to(device)
    X_value_tensor_gpu = torch.Tensor(x_value).to(device)
    y_tensor_gpu = torch.Tensor(label).to(device)
    # a binary classifier requires label with shape (n_samples,1)
    y_tensor_gpu = y_tensor_gpu.reshape(-1, 1)

    dataset = TensorDataset(X_idx_tensor_gpu, X_value_tensor_gpu, y_tensor_gpu)
    deepfm_model = DeepFM(emb_dim=5, feat_dim=feat_meta.get_num_feats(), num_fields=feat_meta.get_num_fields(),
                          out_type=out_type).to(device)
    optimizer = torch.optim.Adam(deepfm_model.parameters(), lr=1e-4)

    # NOTE(review): job name and BCELoss are fixed even when out_type is not 'binary' - confirm intent.
    train_model_hold_out(job_name='deepfm-binary-cls', device=device,
                         model=deepfm_model, dataset=dataset,
                         loss_func=nn.BCELoss(), optimizer=optimizer,
                         epochs=20, batch_size=256)
    return deepfm_model


if __name__ == '__main__':
    # load the sampled criteo dataset shipped in the dataset directory
    path = '../../dataset/criteo_sample_10k.txt'
    X_idx, X_value, y, feature_meta = load_and_preprocess(path)
    model = train_deepfm(X_idx, X_value, y, feature_meta)
def train_fm(x_idx, x_value, label, feat_meta, out_type='binary'):
    r""" An example for training an FM model with hold-out model selection

    :param x_idx: Feature indexes of the samples, converted with torch.LongTensor.
    :param x_value: Feature values of the samples, converted with torch.Tensor.
    :param label: Labels of the samples, reshaped to (n_samples, 1) before training.
    :param feat_meta: Feature description providing get_num_feats().
    :param out_type: Output type passed to FM.
    :return: The FM model instance that was handed to the trainer
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    X_idx_tensor_gpu = torch.LongTensor(x_idx).to(device)
    X_value_tensor_gpu = torch.Tensor(x_value).to(device)
    y_tensor_gpu = torch.Tensor(label).to(device)
    # a binary classifier requires label with shape (n_samples,1)
    y_tensor_gpu = y_tensor_gpu.reshape(-1, 1)

    X_cuda = TensorDataset(X_idx_tensor_gpu, X_value_tensor_gpu, y_tensor_gpu)
    fm_model = FM(emb_dim=5, num_feats=feat_meta.get_num_feats(), out_type=out_type).to(device)
    optimizer = torch.optim.Adam(fm_model.parameters(), lr=1e-4)

    # NOTE(review): job name and BCELoss are fixed even when out_type is not 'binary' - confirm intent.
    train_model_hold_out(job_name='fm-binary-cls', device=device,
                         model=fm_model, dataset=X_cuda,
                         loss_func=nn.BCELoss(), optimizer=optimizer,
                         epochs=20, batch_size=256)
    return fm_model


if __name__ == '__main__':
    # load the sampled criteo dataset shipped in the dataset directory
    path = '../../dataset/criteo_sample_10k.txt'
    X_idx, X_value, y, feature_meta = load_and_preprocess(path)
    model = train_fm(X_idx, X_value, y, feature_meta)
class Feature:
    r"""General feature description for all types of feature

    :argument
        name (str): name of the feature, also refers to the column name in data (pd.Dataframe)
        start_idx: start index of the feature when it is transformed into index and value form
        dim: number of classes when it is transformed into index and value form
        proc_type (str): type of the feature
        processor: processor used in the preprocess stage.

    """

    def __init__(self, name, start_idx=None, proc_type='continuous', processor=None):
        super().__init__()
        self.name = name
        self.proc_type = proc_type
        self.start_idx = start_idx
        self.processor = processor

    # The class does not use ABCMeta, so this decorator only marks the method as the one
    # subclasses are expected to override; instantiating Feature itself is not prevented.
    @abstractmethod
    def get_idx_and_value(self, value):
        pass

    def __str__(self):
        return 'feature name:{0}, start index:{1}, feature type:{2}'.format(self.name, self.start_idx, self.proc_type)


class ContinuousFeature(Feature):
    r"""Feature description for continuous feature

    :argument
        discretize: method the feature to discretize, default is None. A value that is set but is not one of
                    'eq_dist', 'eq_freq', 'cluster' falls back to 'eq_freq'
        discretize_bins: number of bins the feature discretize into
        transformation: method to transform the feature, default is MinMaxScaler. Note that the feature will be
                        transformed only when discretize=None

    Every instance carries the same attributes whatever the value of discretize:
        transformation: the given transformation, or a new MinMaxScaler
        dim: discretize_bins when discretize is set, otherwise 1 (the feature occupies a single index)
        bins: empty list when the description is created
    """

    def __init__(self, name, transformation=None, discretize=None, discretize_bins=10):
        super(ContinuousFeature, self).__init__(name)
        if not transformation:
            transformation = MinMaxScaler()
        if discretize and discretize not in ['eq_dist', 'eq_freq', 'cluster']:
            discretize = 'eq_freq'
        self.discretize = discretize
        # Defined on every instance so that code reading these attributes does not depend on
        # which branch was taken at construction time.
        self.transformation = transformation
        self.dim = discretize_bins if discretize else 1
        self.bins = []

    def get_idx_and_value(self, value):
        return self.start_idx, value
class MultiCategoryFeature(Feature):
    r"""Feature description for multi-category feature

    :argument
        all_categories: give all categories of the feature when the description is created. Once the argument is
                        given (and not empty), processor and dim are decided in the constructor.
    """

    def __init__(self, name, all_categories=None):
        super().__init__(name, proc_type='multi_category')
        if not all_categories:
            return
        encoder = LabelEncoder()
        encoder.fit(all_categories)
        self.processor = encoder
        self.dim = len(encoder.classes_)

    def get_idx_and_value(self, value):
        # NOTE(review): only the single encoded label of `value` is returned next to the full
        # index range of the feature - confirm that callers expect this pairing.
        encoded = self.processor.transform([value])[0]
        first_idx = self.start_idx
        return range(first_idx, first_idx + self.dim), encoded

    def __str__(self):
        prefix = super().__str__()
        if not self.processor:
            return prefix + ',category encoder not generated yet.'
        return prefix + ', feature dim:{0}, category encoder'.format(self.dim)
class FeatureMeta:
    r"""Collection of the feature descriptions of a dataset, grouped by feature type

    :argument
        continuous_feats: OrderedDict, feature name -> ContinuousFeature
        categorical_feats: OrderedDict, feature name -> CategoricalFeature
        multi_category_feats: OrderedDict, feature name -> MultiCategoryFeature
        feat_dict: dict, feature name -> 'continuous', 'categorical' or 'multi_category'
    """

    def __init__(self):
        super().__init__()
        self.continuous_feats = OrderedDict()
        self.categorical_feats = OrderedDict()
        self.multi_category_feats = OrderedDict()
        self.feat_dict = {}

    def add_continuous_feat(self, name, transformation=None, discretize=False, discretize_bin=10):
        r"""Register a continuous feature, replacing any feature already registered under name"""
        self.delete_feat(name)
        self.continuous_feats[name] = ContinuousFeature(name, transformation, discretize, discretize_bin)
        self.feat_dict[name] = 'continuous'

    def add_categorical_feat(self, name, all_categories=None):
        r"""Register a categorical feature, replacing any feature already registered under name"""
        self.delete_feat(name)
        self.categorical_feats[name] = CategoricalFeature(name, all_categories)
        self.feat_dict[name] = 'categorical'

    def add_multi_category_feat(self, name, all_categories=None):
        r"""Register a multi-category feature, replacing any feature already registered under name"""
        self.delete_feat(name)
        self.multi_category_feats[name] = MultiCategoryFeature(name, all_categories)
        self.feat_dict[name] = 'multi_category'

    def delete_feat(self, name):
        r"""Remove the feature registered under name; unknown names are ignored"""
        # Drop the type entry together with the description. Leaving it behind would make
        # get_num_fields count a removed feature and a repeated delete raise KeyError.
        feat_type = self.feat_dict.pop(name, None)
        if feat_type == 'continuous':
            del self.continuous_feats[name]
        elif feat_type == 'categorical':
            del self.categorical_feats[name]
        elif feat_type == 'multi_category':
            del self.multi_category_feats[name]

    def get_num_feats(self):
        r"""Return 1 per continuous feature plus the dim of every categorical and multi-category feature

        Features whose dim has not been set yet raise AttributeError.
        """
        total_dim = len(self.continuous_feats)
        total_dim += sum(feat.dim for feat in self.categorical_feats.values())
        total_dim += sum(feat.dim for feat in self.multi_category_feats.values())
        return total_dim

    def get_num_fields(self):
        r"""Return the number of registered features of all types"""
        return len(self.feat_dict)

    def get_num_continuous_fields(self):
        r"""Return the number of registered continuous features"""
        return len(self.continuous_feats)

    def __str__(self):
        groups = (self.continuous_feats, self.categorical_feats, self.multi_category_feats)
        # one line per feature, every line terminated by a newline
        info_strs = [''.join(str(feat) + '\n' for feat in feats.values()) for feats in groups]
        return 'Continuous Features:\n{}Categorical Features:\n{}Multi-Category Features:\n{}'.format(*info_strs)
class LocationBasedAttention(nn.Module):
    """Attention whose signal is computed from the attended vectors alone.

    For every vector ``v`` the score is ``h . relu(v W + b)``; the scores are
    softmax-normalised over the ``num`` axis.

    :param emb_dim: size of the last axis of the input vectors
    :param att_weight_dim: size of the hidden attention space
    """

    def __init__(self, emb_dim, att_weight_dim):
        super().__init__()
        # Parameters are created in the order W, b, h.
        projection = torch.zeros(emb_dim, att_weight_dim)
        self.weights = nn.Parameter(nn.init.xavier_uniform_(projection))
        self.bias = nn.Parameter(torch.randn(att_weight_dim))
        self.h = nn.Parameter(torch.randn(att_weight_dim))

    def forward(self, values):
        """Return the attention weights for ``values``.

        :param values: tensor of shape N * num * emb_dim
        :return: tensor of shape N * num whose rows sum to one
        """
        hidden = F.relu(torch.matmul(values, self.weights) + self.bias)  # N * num * att_weight_dim
        scores = torch.mul(hidden, self.h).sum(dim=2)  # N * num
        return F.softmax(scores, dim=1)
class ScaledDotProductAttention(nn.Module):
    """Re-weight value vectors with a scaled inner-product attention signal.

    :param dim: dimension of the query/key embeddings; raw scores are divided by
        ``sqrt(dim)`` before the softmax
    """

    def __init__(self, dim):
        super(ScaledDotProductAttention, self).__init__()
        self.dim = dim
        # Keywords are required: the second positional parameter of AttentionSignal
        # is key_dim, so passing the similarity name positionally picks the wrong mode.
        self.attention_signal = AttentionSignal(dim, similarity='inner-product', scale=True)

    def forward(self, query, keys, values):
        """Return ``values`` scaled element-wise by the attention weights.

        :param query: N * emb_dim
        :param keys: N * num_keys * emb_dim
        :param values: N * num_keys * emb_dim
        :return: N * num_keys * emb_dim (the weighted values are not summed)
        """
        att_signal = self.attention_signal(query, keys)  # N * num_keys
        att_signal = att_signal.unsqueeze(dim=2)  # N * num_keys * 1
        return torch.mul(att_signal, values)


class AttentionSignal(nn.Module):
    """Softmax-normalised attention weights of one query over a set of keys.

    :param query_dim: dimension of the query vector
    :param key_dim: dimension of the key vectors, only used by the general
        similarity; defaults to ``query_dim`` when omitted
    :param similarity: ``'inner-product'`` (``q . k_i``), ``'concat'``
        (``v . act(q W_q + k_i W_k)``); any other value selects the general
        bilinear form ``q W k_i``
    :param scale: when true, raw scores are divided by ``sqrt(query_dim)``
    :param activation: ``'relu'``, ``'tanh'`` or ``'sigmoid'`` for the concat
        similarity; a falsy or unrecognised value applies no activation
    """

    def __init__(self, query_dim, key_dim=None, similarity='inner-product', scale=False, activation='relu'):
        super(AttentionSignal, self).__init__()
        if key_dim is None:
            key_dim = query_dim
        self.query_dim = query_dim
        self.key_dim = key_dim
        self.similarity = similarity
        self.scale = scale
        self.activation = activation

        if similarity == 'inner-product':  # a_i = query^T * keys_i, no parameters
            pass
        elif similarity == 'concat':  # a_i = v^T * act(W_q * query + W_k * keys_i)
            self.v_a = nn.Parameter(torch.zeros(query_dim))
            # xavier_uniform_ needs at least two dimensions, so initialise v through
            # a (1, query_dim) view that shares its storage.
            nn.init.xavier_uniform_(self.v_a.data.unsqueeze(0))
            self.weights_q = nn.Parameter(torch.zeros((query_dim, query_dim)))
            nn.init.xavier_uniform_(self.weights_q.data)
            self.weights_k = nn.Parameter(torch.zeros((query_dim, query_dim)))
            nn.init.xavier_uniform_(self.weights_k.data)
        else:  # general, a_i = query^T * W * keys_i
            self.weights_a = nn.Parameter(torch.zeros((query_dim, key_dim)))
            nn.init.xavier_uniform_(self.weights_a.data)

    def forward(self, query, keys):
        """Compute the attention weights.

        :param query: N * query_dim
        :param keys: N * num_keys * key_dim
        :return: N * num_keys, each row summing to one
        """
        query = query.unsqueeze(dim=1)  # N * 1 * query_dim

        if self.similarity == 'inner-product':
            # Raw (un-normalised) scores: the single softmax is applied at the end,
            # after the optional scaling.
            att = torch.sum(torch.mul(query, keys), dim=2)  # N * num_keys
        elif self.similarity == 'concat':
            weighted_q = torch.matmul(query, self.weights_q)  # N * 1 * emb_dim
            weighted_k = torch.matmul(keys, self.weights_k)  # N * num_keys * emb_dim
            weighted_kq = torch.add(weighted_q, weighted_k)  # N * num_keys * emb_dim
            if self.activation == 'relu':
                weighted_kq = F.relu(weighted_kq)
            elif self.activation == 'tanh':
                weighted_kq = torch.tanh(weighted_kq)
            elif self.activation == 'sigmoid':
                weighted_kq = torch.sigmoid(weighted_kq)
            att = torch.sum(torch.mul(weighted_kq, self.v_a), dim=2)  # N * num_keys
        else:
            qw = torch.matmul(query, self.weights_a)  # N * 1 * key_dim
            att = torch.bmm(keys, qw.transpose(1, 2))  # N * num_keys * 1
            # Squeeze only the trailing axis so a batch of one keeps its batch axis.
            att = att.squeeze(dim=2)  # N * num_keys

        if self.scale:
            att = att / (self.query_dim ** 0.5)
        return F.softmax(att, dim=1)
class FeatureEmbedding(nn.Module):
    """Per-field embedding: one sub-module per continuous and per categorical field.

    :param feat_meta: feature registry; its ``continuous_feats`` and
        ``categorical_feats`` mappings decide how many sub-modules are built, and
        each categorical feature's ``dim`` is used as its number of classes
    :param uniform_dim: when an int, the embedding dim used for every field
        (continuous and categorical); any other value means each categorical
        field picks its dim automatically
    :param proc_continuous: ``'concat'`` passes continuous values through
        unchanged, ``'emb_mul'`` multiplies a learned vector by the value; any
        other value falls back to ``'concat'``
    :param continuous_emb_dim: embedding dim for continuous fields in
        ``'emb_mul'`` mode; overridden by an int ``uniform_dim``
    :raises ValueError: ``'emb_mul'`` is requested without any integer dim
    """

    def __init__(self, feat_meta: FeatureMeta, uniform_dim=None, proc_continuous='concat', continuous_emb_dim=None):
        super(FeatureEmbedding, self).__init__()

        if proc_continuous not in ['concat', 'emb_mul']:
            proc_continuous = 'concat'
        self.proc_continuous = proc_continuous

        if not isinstance(uniform_dim, int):
            uniform_dim = 'auto'
        else:
            continuous_emb_dim = uniform_dim

        # An int uniform_dim has already been copied into continuous_emb_dim, so
        # checking the latter covers both parameters.
        if proc_continuous == 'emb_mul' and not isinstance(continuous_emb_dim, int):
            raise ValueError('No dim designated for embeddings of continuous feature! '
                             'Check param \'uniform_dim\' or \'continuous_emb_dim\'.')

        self.uniform_dim = uniform_dim
        self.continuous_emb_dim = continuous_emb_dim
        self.feat_meta = feat_meta

        self.num_cont_fields = len(feat_meta.continuous_feats)
        self.continuous_embeddings = nn.ModuleList([
            ContinuousEmbedding(proc_continuous, continuous_emb_dim)
            for _ in feat_meta.continuous_feats
        ])

        self.num_cate_fields = len(feat_meta.categorical_feats)
        # Forward uniform_dim ('auto' or an int) so a requested uniform dimension
        # is actually applied to the categorical fields.
        self.categorical_embeddings = nn.ModuleList([
            CategoricalEmbedding(feat.dim, uniform_dim)
            for feat in feat_meta.categorical_feats.values()
        ])

    def forward(self, continuous_value, category_index):
        """Embed every field and concatenate the results along the feature axis.

        :param continuous_value: N * num_cont_fields
        :param category_index: N * num_cate_fields, integer class index per field
        :return: 2-D tensor, N * (sum of the per-field output widths)
        """
        pieces = [
            embedding(column)  # N * 1 (concat) or N * emb_dim (emb_mul)
            for embedding, column in zip(self.continuous_embeddings, torch.split(continuous_value, 1, 1))
        ]
        # Drop the split's singleton column axis so every categorical lookup yields
        # a 2-D N * emb_dim tensor; fields may then have different dims and still
        # be concatenated with the 2-D continuous pieces.
        pieces.extend(
            embedding(column.squeeze(dim=1))  # N * emb_dim
            for embedding, column in zip(self.categorical_embeddings, torch.split(category_index, 1, 1))
        )
        return torch.cat(pieces, dim=1)
class MultiValueEmbedding(nn.Module):
    """
    Pooled embedding for a multi-value field.

    Every class owns one embedding vector; the vectors are scaled by the input
    entries and summed. With ``aggregate='avg'`` the sum is divided by the second
    forward argument, the number of non-zero values in the field.
    """
    def __init__(self, num_classes, emb_dim='auto', aggregate='sum'):
        super().__init__()
        # Anything that is not an int (including the default 'auto') selects the
        # automatically derived dimension.
        if not isinstance(emb_dim, int):
            emb_dim = get_auto_embedding_dim(num_classes)

        self.emb_dim = emb_dim
        self.num_classes = num_classes
        self.agg = aggregate

        table = torch.zeros((1, num_classes, emb_dim))
        self.emb_layer = nn.Parameter(nn.init.xavier_uniform_(table))

    def forward(self, x, num_ones):
        # x: N * num_classes
        # num_ones: N * 1
        scaled = torch.mul(x.unsqueeze(dim=2), self.emb_layer)  # N * num_classes * emb_dim
        pooled = scaled.sum(dim=1)  # N * emb_dim
        if self.agg != 'avg':
            return pooled
        return pooled / num_ones
def build_cross(num_fields, feat_emb):
    """Pair up the field embeddings for pair-wise interaction.

    Enumerates every index pair ``(i, j)`` with ``i < j < num_fields`` in
    row-major order and gathers the matching slices along axis 1.

    :param num_fields: number of fields along axis 1 of ``feat_emb``
    :param feat_emb: tensor indexed as N * num_fields * emb_dim
    :return: ``(p, q)``, each N * num_pairs * emb_dim, where pair ``k`` of ``p``
        holds field ``i`` and pair ``k`` of ``q`` holds field ``j``;
        num_pairs = num_fields * (num_fields - 1) / 2
    """
    pairs = [
        (first, second)
        for first in range(num_fields - 1)
        for second in range(first + 1, num_fields)
    ]
    left = [first for first, _ in pairs]
    right = [second for _, second in pairs]
    return feat_emb[:, left], feat_emb[:, right]
class MLP(nn.Module):

    def __init__(self, fc_in_dim, fc_dims, dropout=None, batch_norm=None, activation=nn.ReLU()):
        """
        The MLP(Multi-Layer Perceptrons) module
        :param fc_in_dim: The dimension of input tensor
        :param fc_dims: The num_neurons of each layer, should be array-like
        :param dropout: The dropout rate of the MLP module, can be a number (applied to every layer)
                        or an array-like with one rate per layer, by default None (no dropout)
        :param batch_norm: Whether to use batch normalization after each layer, by default None
        :param activation: The activation function used in each layer, by default nn.ReLU();
                           the same module instance is appended after every layer
        :raises ValueError: dropout is array-like but its length differs from len(fc_dims)
        """
        super(MLP, self).__init__()
        self.fc_dims = fc_dims
        layer_dims = [fc_in_dim, *fc_dims]
        num_layers = len(layer_dims) - 1

        # Normalise dropout to one python float per layer. Scalars (int or float)
        # are broadcast; array-likes are never truth-tested, which would raise for
        # numpy arrays with more than one element.
        if dropout is None or np.ndim(dropout) == 0:
            rates = [float(dropout or 0)] * num_layers
        else:
            rates = [float(rate) for rate in dropout] or [0.0] * num_layers
            if len(rates) != num_layers:
                raise ValueError('dropout needs one rate per layer: expected {0}, got {1}'
                                 .format(num_layers, len(rates)))

        layers = []
        for in_dim, out_dim, rate in zip(layer_dims[:-1], layer_dims[1:], rates):
            fc_layer = nn.Linear(in_features=in_dim, out_features=out_dim)
            nn.init.xavier_uniform_(fc_layer.weight)
            layers.append(fc_layer)
            if batch_norm:
                layers.append(nn.BatchNorm1d(num_features=out_dim))
            layers.append(activation)
            if rate:  # a zero rate adds no Dropout module
                layers.append(nn.Dropout(rate))
        self.mlp = nn.Sequential(*layers)

    def forward(self, feature):
        """Run ``feature`` through the layer stack and return the last layer's output."""
        return self.mlp(feature)
class AFM(nn.Module):
    """Attentional Factorization Machine.

    Combines a global bias, a first-order term and an attention-pooled sum of the
    pair-wise (element-wise product) interactions of the field embeddings.

    :param emb_dim: embedding dimension of each feature
    :param num_feats: number of rows of the embedding tables
    :param num_fields: number of fields per sample
    :param att_weight_dim: hidden size of the attention network
    :param out_type: forwarded to ``OutputLayer``
    """

    def __init__(self, emb_dim, num_feats, num_fields, att_weight_dim, out_type='binary'):
        super(AFM, self).__init__()
        self.emb_dim = emb_dim
        self.num_feats = num_feats
        self.num_fields = num_fields
        self.att_weight_dim = att_weight_dim
        self.first_order_weights = nn.Embedding(num_embeddings=num_feats, embedding_dim=1)
        nn.init.xavier_uniform_(self.first_order_weights.weight)
        self.bias = nn.Parameter(torch.randn(1))
        self.emb_layer = nn.Embedding(num_embeddings=num_feats, embedding_dim=emb_dim)
        nn.init.xavier_uniform_(self.emb_layer.weight)
        # A pair count is an integer; floor division keeps it an int
        # (the product of two consecutive integers is always even).
        self.num_pairs = num_fields * (num_fields - 1) // 2

        self.att_layer = LocationBasedAttention(emb_dim, att_weight_dim)
        self.p = nn.Parameter(torch.randn(emb_dim))

        self.output_layer = OutputLayer(1, out_type)

    def forward(self, feat_index, feat_value):
        """Score a batch.

        :param feat_index: N * num_fields, row index of each field's active feature
        :param feat_value: N * num_fields, value of each field
        :return: the output layer applied to the length-N score vector
        """
        feat_value = feat_value.unsqueeze(2)  # N * num_fields * 1

        # first order
        first_order_weight = self.first_order_weights(feat_index)  # N * num_fields * 1
        y_first_order = torch.mul(first_order_weight, feat_value)  # N * num_fields * 1
        y_first_order = torch.sum(y_first_order, dim=1).squeeze(1)  # N

        feat_emb = self.emb_layer(feat_index)  # N * num_fields * emb_dim
        feat_emb_value = torch.mul(feat_emb, feat_value)  # N * num_fields * emb_dim

        p, q = build_cross(self.num_fields, feat_emb_value)  # N * num_pairs * emb_dim
        pair_wise_inter = torch.mul(p, q)

        att_signal = self.att_layer(pair_wise_inter)  # N * num_pairs
        att_signal = att_signal.unsqueeze(dim=2)  # N * num_pairs * 1

        att_inter = torch.mul(att_signal, pair_wise_inter)  # N * num_pairs * emb_dim
        att_pooling = torch.sum(att_inter, dim=1)  # N * emb_dim

        # project the pooled interaction onto p to get one scalar per sample
        att_pooling = torch.sum(torch.mul(att_pooling, self.p), dim=1)  # N

        y = self.bias + y_first_order + att_pooling
        return self.output_layer(y)
class AutoInt(nn.Module):
    """AutoInt: one multi-head self-attention interacting layer over field embeddings.

    :param emb_dim: embedding dimension of each feature
    :param projection_dim: per-head dimension of the query/key/value projections
    :param num_heads: number of attention heads
    :param num_feats: number of rows of the embedding table
    :param num_fields: number of fields per sample
    :param use_res: add a linear projection of the raw embeddings (residual) to
        the attention output
    :param out_type: forwarded to ``OutputLayer``
    """

    def __init__(self, emb_dim, projection_dim, num_heads, num_feats, num_fields, use_res=True, out_type='binary'):
        super(AutoInt, self).__init__()
        self.emb_dim = emb_dim
        self.projection_dim = projection_dim
        self.num_heads = num_heads
        self.num_feats = num_feats
        self.num_fields = num_fields

        self.emb_layer = nn.Embedding(num_embeddings=num_feats, embedding_dim=emb_dim)
        nn.init.xavier_uniform_(self.emb_layer.weight)

        self.query_weights = nn.Parameter(torch.zeros(emb_dim, projection_dim * num_heads))
        nn.init.xavier_uniform_(self.query_weights.data)

        self.key_weights = nn.Parameter(torch.zeros(emb_dim, projection_dim * num_heads))
        nn.init.xavier_uniform_(self.key_weights.data)

        self.value_weights = nn.Parameter(torch.zeros(emb_dim, projection_dim * num_heads))
        nn.init.xavier_uniform_(self.value_weights.data)

        self.use_res = use_res
        if use_res:
            self.res_weights = nn.Parameter(torch.zeros(emb_dim, projection_dim * num_heads))
            nn.init.xavier_uniform_(self.res_weights.data)

        self.output_layer = OutputLayer(in_dim=num_fields * num_heads * projection_dim, out_type=out_type)

    def _project_heads(self, feat_emb, weights):
        """Project the embeddings and move the heads onto their own axis.

        :return: N * num_heads * num_fields * pro_dim
        """
        projected = torch.matmul(feat_emb, weights)  # N * num_fields * (pro_dim * num_heads)
        per_head = torch.split(projected, self.projection_dim, dim=2)  # [N * num_fields * pro_dim] * num_heads
        return torch.stack(per_head, dim=1)

    def forward(self, feat_index):
        """Score a batch.

        :param feat_index: N * num_fields, row index of each field's active feature
        :return: the output layer applied to the flattened interacting-layer output
        """
        feat_emb = self.emb_layer(feat_index)  # N * num_fields * emb_dim

        queries = self._project_heads(feat_emb, self.query_weights)
        keys = self._project_heads(feat_emb, self.key_weights)
        values = self._project_heads(feat_emb, self.value_weights)

        # Row i of the score matrix holds the scores of query field i against
        # every key field.
        inner_product_qk = torch.matmul(queries, keys.transpose(2, 3))  # N * num_heads * num_fields * num_fields

        # Normalise over the key axis (last axis) so each query's weights sum to
        # one. The inner-product is not scaled by sqrt(n).
        att_signal = F.softmax(inner_product_qk, dim=3)
        att_value = torch.matmul(att_signal, values)  # N * num_heads * num_fields * pro_dim

        # Concatenate the heads of every field side by side.
        att_values = torch.split(att_value, 1, dim=1)  # [N * 1 * num_fields * pro_dim] * num_heads
        att_values = torch.cat(att_values, dim=3)  # N * 1 * num_fields * (pro_dim * num_heads)
        # Squeeze only the head axis so a batch of one keeps its batch axis.
        multi_head_emb = att_values.squeeze(dim=1)  # N * num_fields * (pro_dim * num_heads)

        if self.use_res:
            res = torch.matmul(feat_emb, self.res_weights)  # N * num_fields * (pro_dim * num_heads)
            multi_head_emb = torch.add(multi_head_emb, res)

        multi_head_emb = F.relu(multi_head_emb)
        multi_head_emb = multi_head_emb.reshape((-1, self.num_fields * self.num_heads * self.projection_dim))
        return self.output_layer(multi_head_emb)
class DCN(nn.Module):
    """Deep & Cross Network: an MLP and a stack of cross layers over the same input.

    The shared input ``x0`` is the continuous values concatenated with the
    flattened categorical embeddings.

    :param emb_dim: embedding dimension of each categorical feature
    :param num_feats: total number of features; the embedding table has
        ``num_feats - num_cont_fields`` rows
    :param num_cate_fields: number of categorical fields per sample
    :param num_cont_fields: number of continuous fields per sample
    :param cross_depth: number of cross layers
    :param fc_dims: hidden sizes of the deep part, ``[32, 32]`` when falsy
    :param dropout: forwarded to ``MLP``
    :param batch_norm: forwarded to ``MLP``
    :param out_type: forwarded to ``OutputLayer``
    """

    def __init__(self, emb_dim, num_feats, num_cate_fields, num_cont_fields, cross_depth, fc_dims=None,
                 dropout=None, batch_norm=None, out_type='binary'):
        super(DCN, self).__init__()
        self.emb_dim = emb_dim
        self.num_feats = num_feats
        self.num_cate_fields = num_cate_fields
        self.num_cont_fields = num_cont_fields

        self.cross_depth = cross_depth
        # embedding for category features
        self.emb_layer = nn.Embedding(num_embeddings=num_feats - num_cont_fields, embedding_dim=emb_dim)
        nn.init.xavier_uniform_(self.emb_layer.weight)

        # deep network
        if not fc_dims:
            fc_dims = [32, 32]
        self.fc_dims = fc_dims
        x0_dim = num_cont_fields + num_cate_fields * emb_dim
        self.deep = MLP(x0_dim, fc_dims, dropout, batch_norm)

        # cross network
        self.cross = nn.ModuleList([CrossLayer(x0_dim) for _ in range(cross_depth)])

        self.out_layer = OutputLayer(in_dim=fc_dims[-1] + x0_dim, out_type=out_type)

    def forward(self, continuous_value, categorical_index):
        """Score a batch.

        :param continuous_value: N * num_cont_fields
        :param categorical_index: N * num_cate_fields, embedding row per field
        :return: the output layer applied to the concatenated deep and cross outputs
        """
        cate_emb_value = self.emb_layer(categorical_index)  # N * num_cate_fields * emb_dim
        # N * (num_cate_fields * emb_dim)
        cate_emb_value = cate_emb_value.reshape((-1, self.num_cate_fields * self.emb_dim))
        x0 = torch.cat([continuous_value, cate_emb_value], 1)

        y_dnn = self.deep(x0)

        xi = x0
        # Iterate over the cross-layer modules; cross_depth is only their count.
        for cross_layer in self.cross:
            xi = cross_layer(x0, xi)

        output = torch.cat([y_dnn, xi], dim=1)
        return self.out_layer(output)
class CrossLayer(nn.Module):
    """One layer of the DCN cross network.

    Computes ``x_{l+1} = x0 * (x_l . w) + b + x_l`` where ``w`` is an
    ``x_dim * 1`` weight column and ``b`` an ``x_dim`` bias.

    Args:
        x_dim: size of the input vectors ``x0`` and ``xi``.
    """

    def __init__(self, x_dim):
        super(CrossLayer, self).__init__()
        self.x_dim = x_dim
        self.weights = nn.Parameter(torch.zeros(x_dim, 1))  # x_dim * 1
        nn.init.xavier_uniform_(self.weights.data)
        self.bias = nn.Parameter(torch.randn(x_dim))  # x_dim

    def forward(self, x0, xi):
        """Apply the cross transformation.

        Args:
            x0: N * x_dim, the original input of the cross network.
            xi: N * x_dim, the output of the previous cross layer.

        Returns:
            N * x_dim tensor.
        """
        # Project xi onto the weight column. An element-wise ``mul`` of an
        # (N, x_dim) tensor with an (x_dim, 1) tensor does not broadcast to
        # the intended per-sample scalar, so a matrix product is required.
        scale = torch.matmul(xi, self.weights)  # N * 1
        crossed = torch.mul(scale, x0)  # N * x_dim
        return crossed + self.bias + xi
class DeepFM(nn.Module):
    """DeepFM (Guo et al., IJCAI 2017): an FM and an MLP sharing one embedding table.

    The embedding table lives inside the ``FM`` sub-module and is reused as
    the input of the deep part.

    Args:
        emb_dim: embedding size.
        feat_dim: number of features, forwarded to ``FM``.
        num_fields: number of fields per sample.
        fc_dims: hidden sizes of the deep part, ``[32, 32, 32]`` when empty/None.
        dropout: forwarded to ``MLP``.
        batch_norm: forwarded to ``MLP``.
        out_type: forwarded to the final ``OutputLayer``.
    """

    def __init__(self, emb_dim, feat_dim, num_fields, fc_dims=None, dropout=None, batch_norm=None, out_type='binary'):
        super(DeepFM, self).__init__()
        hidden_dims = fc_dims if fc_dims else [32, 32, 32]

        self.emb_dim = emb_dim
        # the FM owns the embedding layer; it is built with out_type='regression'
        self.fm = FM(emb_dim, feat_dim, out_type='regression')

        self.fc_dims = hidden_dims
        self.num_fields = num_fields
        self.dnn = MLP(emb_dim * num_fields, hidden_dims, dropout, batch_norm)

        # the extra input unit carries the FM score
        self.output_layer = OutputLayer(hidden_dims[-1] + 1, out_type)

    def forward(self, feat_index, feat_value):
        """Combine the FM score with the deep representation.

        Args:
            feat_index: N * num_fields index tensor.
            feat_value: N * num_fields value tensor.
        """
        shared_emb = self.fm.get_embedding()(feat_index)

        # FM branch, reshaped to a column so it can be concatenated
        fm_score = self.fm(feat_index, feat_value).unsqueeze(1)  # N * 1

        # deep branch on the flattened embeddings
        flat_emb = shared_emb.reshape(-1, self.emb_dim * self.num_fields)  # N * (emb_dim * num_fields)
        deep_out = self.dnn(flat_emb)  # N * fc_dims[-1]

        merged = torch.cat((fm_score, deep_out), dim=1)  # N * (fc_dims[-1] + 1)
        return self.output_layer(merged)
class FLEN(nn.Module):
    """FLEN: Leveraging Field for Scalable CTR Prediction (Chen et al., arXiv 2019).

    Combines a first-order part (S), a field-wise bi-interaction made of an
    inter-field part (MF) and an intra-field part (FM), and an MLP over the
    field embeddings.

    Args:
        emb_dim: embedding size.
        num_feats: number of rows of the embedding tables.
        num_categories: number of categorical columns in ``feat_index``.
        field_ranges: sequence whose items each select the columns of
            ``feat_index`` belonging to one field. When None/empty every
            column is its own field.
        fc_dims: hidden sizes of the MLP, ``[32, 32, 32]`` when empty/None.
        dropout: forwarded to ``MLP``.
        batch_norm: forwarded to ``MLP``.
        out_type: forwarded to ``OutputLayer``.
    """

    def __init__(self, emb_dim, num_feats, num_categories, field_ranges, fc_dims=None, dropout=None, batch_norm=None,
                 out_type='binary'):
        super(FLEN, self).__init__()
        self.num_feats = num_feats
        self.emb_dim = emb_dim
        self.num_categories = num_categories
        # ``not tensor`` is ambiguous for multi-element tensors, so test explicitly.
        if field_ranges is None or len(field_ranges) == 0:
            # One single-column field per category. List indices keep the field
            # axis when slicing (a 0-d tensor index would drop it).
            field_ranges = [[i] for i in range(num_categories)]
        self.field_ranges = field_ranges
        self.num_fields = len(field_ranges)

        # embedding layer
        self.emb_layer = nn.Embedding(num_embeddings=num_feats, embedding_dim=emb_dim)
        nn.init.xavier_uniform_(self.emb_layer.weight)

        # S part
        # Looked up with the same ``feat_index`` as ``emb_layer``, so it needs one
        # row per feature (as in FM / NFM), not one per categorical column.
        self.first_order_weights = nn.Embedding(num_embeddings=num_feats, embedding_dim=1)
        nn.init.xavier_uniform_(self.first_order_weights.weight)
        self.first_order_bias = nn.Parameter(torch.randn(1))

        # MF part
        # integer division: tensor sizes must be ints
        self.num_pairs = self.num_fields * (self.num_fields - 1) // 2
        self.r_mf = nn.Parameter(torch.zeros(self.num_pairs, 1))  # num_pairs * 1
        nn.init.xavier_uniform_(self.r_mf.data)

        # FM part
        self.r_fm = nn.Parameter(torch.zeros(self.num_fields, 1))  # num_fields * 1
        nn.init.xavier_uniform_(self.r_fm.data)

        # dnn
        if not fc_dims:
            fc_dims = [32, 32, 32]
        self.fc_dims = fc_dims
        # forward() feeds the flattened field embeddings, hence this input size
        self.fc_layers = MLP(self.num_fields * emb_dim, fc_dims, dropout, batch_norm)

        self.output_layer = OutputLayer(fc_dims[-1] + 1 + self.emb_dim, out_type)

    def forward(self, feat_index):
        """Score a batch.

        Args:
            feat_index: N * num_categories index tensor.
        """
        feat_emb = self.emb_layer(feat_index)  # N * num_categories * emb_dim

        field_wise_emb_list = [
            feat_emb[:, field_range]  # N * num_categories_in_field * emb_dim
            for field_range in self.field_ranges
        ]

        field_emb_list = [
            torch.sum(field_wise_emb, dim=1).unsqueeze(dim=1)  # N * 1 * emb_dim
            for field_wise_emb in field_wise_emb_list
        ]
        field_emb = torch.cat(field_emb_list, dim=1)  # N * num_fields * emb_dim

        # S part
        y_S = self.first_order_weights(feat_index)  # N * num_categories * 1
        y_S = y_S.squeeze(dim=2)  # N * num_categories (explicit dim keeps a batch of one intact)
        y_S = torch.sum(y_S, dim=1)  # N
        y_S = torch.add(y_S, self.first_order_bias)  # N
        y_S = y_S.unsqueeze(dim=1)  # N * 1

        # MF part -> N * emb_dim
        p, q = build_cross(self.num_fields, field_emb)  # N * num_pairs * emb_dim
        y_MF = torch.mul(p, q)  # N * num_pairs * emb_dim
        y_MF = torch.mul(y_MF, self.r_mf)  # N * num_pairs * emb_dim
        y_MF = torch.sum(y_MF, dim=1)  # N * emb_dim

        # FM part
        field_wise_fm = [
            bi_interaction(field_wise_emb).unsqueeze(dim=1)  # N * 1 * emb_dim
            for field_wise_emb in field_wise_emb_list
        ]
        field_wise_fm = torch.cat(field_wise_fm, dim=1)  # N * num_fields * emb_dim
        y_FM = torch.mul(field_wise_fm, self.r_fm)  # N * num_fields * emb_dim
        y_FM = torch.sum(y_FM, dim=1)  # N * emb_dim

        # dnn
        fc_in = field_emb.reshape((-1, self.num_fields * self.emb_dim))
        y_dnn = self.fc_layers(fc_in)

        # output
        fwBI = y_MF + y_FM
        fwBI = torch.cat([y_S, fwBI], dim=1)  # N * (emb_dim + 1)
        y = torch.cat([fwBI, y_dnn], dim=1)  # N * (fc_dims[-1] + emb_dim + 1)
        y = self.output_layer(y)
        return y
# --- model/ctr/fm.py ---
class FM(nn.Module):
    """Factorization Machine (Rendle, ICDM 2010).

    Args:
        emb_dim: size of the latent factor vectors.
        num_feats: number of features (rows of both lookup tables).
        out_type: forwarded to ``OutputLayer``.
    """

    def __init__(self, emb_dim, num_feats, out_type='binary'):
        super(FM, self).__init__()
        self.emb_dim = emb_dim
        self.num_feats = num_feats
        self.emb_layer = nn.Embedding(num_embeddings=num_feats, embedding_dim=emb_dim)
        nn.init.xavier_uniform_(self.emb_layer.weight)
        self.bias = nn.Parameter(torch.randn(1))
        self.first_order_weights = nn.Embedding(num_embeddings=num_feats, embedding_dim=1)
        nn.init.xavier_uniform_(self.first_order_weights.weight)
        self.output_layer = OutputLayer(1, out_type)

    def forward(self, feat_index, feat_value):
        """Score a batch.

        Args:
            feat_index: N * F index tensor (a single sample must be given as 1 * F).
            feat_value: N * F value tensor.
        """
        # compute first order
        feat_value = torch.unsqueeze(feat_value, dim=2)  # N * F * 1
        first_order_weights = self.first_order_weights(feat_index)  # N * F * 1
        first_order = torch.mul(feat_value, first_order_weights)  # N * F * 1
        first_order = torch.squeeze(first_order, dim=2)  # N * F
        y_first_order = torch.sum(first_order, dim=1)  # N

        # compute second order
        feat_emb = self.emb_layer(feat_index)  # N * F * K
        feat_emb_value = torch.mul(feat_emb, feat_value)  # N * F * K element-wise mul

        # pairwise interactions are delegated to the shared helper
        # (presumably N * K, built from square-of-sum and sum-of-square -- verify in
        # model.basic.functional)
        BI = bi_interaction(feat_emb_value)
        y_second_order = torch.sum(BI, dim=1)  # N

        # compute y
        y = self.bias + y_first_order + y_second_order  # N
        y = self.output_layer(y)
        return y

    def get_embedding(self):
        """Return the embedding layer so other models can share it."""
        return self.emb_layer


# --- model/ctr/fnn.py ---
class FNN(nn.Module):
    """Factorization-machine supported Neural Network (Zhang et al., arXiv 2016).

    The model has two modes selected by ``self.fm_trained``:

    * ``False`` -- ``forward`` runs the inner FM, so its embeddings get trained;
    * ``True`` -- ``forward`` runs the MLP on top of the FM embeddings.

    Args:
        emb_dim: embedding size.
        num_feats: number of features.
        num_fields: number of fields per sample.
        fc_dims: hidden sizes of the MLP, ``[32, 32]`` when empty/None.
        dropout: forwarded to ``MLP``.
        batch_norm: forwarded to ``MLP``.
        out_type: forwarded to ``FM`` and ``OutputLayer``.
        train_fm: start in FM-training mode when True.
    """

    def __init__(self, emb_dim, num_feats, num_fields, fc_dims=None, dropout=None, batch_norm=None, out_type='binary',
                 train_fm=True):
        super(FNN, self).__init__()
        # set model object to training FNN or training FM embedding
        self.fm_trained = not train_fm

        # embedding layer is embedded in the FM sub-module
        self.emb_dim = emb_dim
        self.num_feats = num_feats

        # fc layers
        if not fc_dims:
            fc_dims = [32, 32]
        self.fc_dims = fc_dims
        self.num_fields = num_fields
        self.fc_layers = MLP(emb_dim * num_fields, fc_dims, dropout, batch_norm)

        # fm model as the pre-trained embedding layer
        self.fm = FM(emb_dim, num_feats, out_type)

        # output
        self.output_layer = OutputLayer(fc_dims[-1], out_type)

    def forward(self, feat_index, feat_value):
        """Run the FM (pre-training mode) or the MLP over the FM embeddings."""
        if not self.fm_trained:
            # FM pre-training: the FM applies its own output layer
            return self.fm(feat_index, feat_value)

        emb_layer = self.fm.get_embedding()
        feat_emb = emb_layer(feat_index)

        # reshape the embedding matrix to a vector
        fc_in = feat_emb.reshape(-1, self.emb_dim * self.num_fields)

        # the MLP is stored as ``fc_layers``; there is no ``mlp`` attribute
        y = self.fc_layers(fc_in)

        # compute output
        y = self.output_layer(y)
        return y

    def train_fm_embedding(self):
        """Switch to FM pre-training mode (``forward`` runs the FM)."""
        # ``fm_trained`` False means the FM still has to be trained, matching
        # the constructor's ``fm_trained = not train_fm``.
        self.fm_trained = False

    def train_fnn(self):
        """Switch to FNN mode (``forward`` runs the MLP on the FM embeddings)."""
        self.fm_trained = True


# --- model/ctr/gbdt_lr.py ---
class GBDTLR(nn.Module):
    """GBDT + Logistic Regression (He et al., 2014, "Practical Lessons from
    Predicting Clicks on Ads at Facebook").

    A ``GBDT`` transforms the input and an ``LR`` layer with
    ``num_leaves * n_estimators`` weights scores the transformed features.
    All constructor arguments are forwarded to ``GBDT``; ``objective`` is
    also used as the ``out_type`` of the LR layer.
    """

    def __init__(self, num_leaves=31, max_depth=-1, n_estimators=100, min_data_in_leaf=20,
                 learning_rate=0.1, objective='binary'):
        super(GBDTLR, self).__init__()
        self.gbdt = GBDT(num_leaves, max_depth, n_estimators, min_data_in_leaf, learning_rate, objective)
        self.gbdt_trained = False
        self.logistic_layer = LR(num_leaves * n_estimators, out_type=objective)

    def forward(self, data):
        """Score ``data`` with the LR layer on top of the GBDT output."""
        # NOTE(review): assumes ``gbdt.pred`` returns an index tensor usable for an
        # embedding lookup -- confirm against model.basic.gbdt.
        pred_y, feat_index = self.gbdt.pred(data)
        # No feature values are passed: every selected index counts with weight 1.
        y = self.logistic_layer(feat_index)
        return y

    def train_gbdt(self, data, y):
        """Fit the GBDT and record that it has been trained."""
        self.gbdt.train(data, y)
        self.gbdt_trained = True

    def get_gbdt_trained(self):
        """Return True once ``train_gbdt`` has been called."""
        return self.gbdt_trained


# --- model/ctr/lr.py ---
class LR(nn.Module):
    """Logistic regression over sparse (index, value) features.

    Args:
        num_feats: number of features (rows of the weight table).
        out_type: forwarded to ``OutputLayer``.
    """

    def __init__(self, num_feats, out_type='binary'):
        super(LR, self).__init__()
        self.num_feats = num_feats
        self.weights = nn.Embedding(num_embeddings=num_feats, embedding_dim=1)
        self.bias = nn.Parameter(torch.randn(1))
        self.output_layer = OutputLayer(1, out_type)

    def forward(self, feat_index, feat_value=None):
        """Compute ``output_layer(sum_f value_f * w[index_f] + bias)``.

        Args:
            feat_index: N * F index tensor.
            feat_value: optional N * F value tensor. When omitted every
                indexed feature has value 1 (one-hot input), which is how
                ``GBDTLR`` calls this layer.
        """
        weights = self.weights(feat_index)  # N * F * 1
        if feat_value is None:
            first_order = weights  # N * F * 1
        else:
            feat_value = torch.unsqueeze(feat_value, dim=2)  # N * F * 1
            first_order = torch.mul(feat_value, weights)  # N * F * 1
        first_order = torch.squeeze(first_order, dim=2)  # N * F
        y = torch.sum(first_order, dim=1)  # N
        y = y + self.bias

        y = self.output_layer(y)
        return y
class NFM(nn.Module):
    """Neural Factorization Machine (He and Chua, SIGIR 2017).

    A first-order linear term plus an MLP applied on the bi-interaction
    pooling of the value-weighted feature embeddings.

    Args:
        emb_dim: embedding size.
        num_feats: number of features (rows of both lookup tables).
        num_fields: number of fields per sample.
        fc_dims: hidden sizes of the MLP, ``[32, 32]`` when empty/None.
        dropout: forwarded to ``MLP``.
        batch_norm: forwarded to ``MLP``.
        out_type: forwarded to ``OutputLayer``.
    """

    def __init__(self, emb_dim, num_feats, num_fields, fc_dims=None, dropout=None, batch_norm=None, out_type='binary'):
        super(NFM, self).__init__()
        self.emb_dim = emb_dim
        self.num_feats = num_feats
        self.num_fields = num_fields

        self.first_order_weights = nn.Embedding(num_embeddings=num_feats, embedding_dim=1)
        nn.init.xavier_uniform_(self.first_order_weights.weight)
        self.first_order_bias = nn.Parameter(torch.randn(1))

        self.emb_layer = nn.Embedding(num_embeddings=num_feats, embedding_dim=emb_dim)
        nn.init.xavier_uniform_(self.emb_layer.weight)

        # attribute name (including its spelling) is kept for compatibility
        self.bi_intaraction_layer = BiInteractionLayer()
        if not fc_dims:
            fc_dims = [32, 32]
        self.fc_dims = fc_dims
        self.fc_layers = MLP(emb_dim, fc_dims, dropout, batch_norm)

        self.h = nn.Parameter(torch.zeros(1, fc_dims[-1]))  # 1 * fc_dims[-1]
        nn.init.xavier_uniform_(self.h.data)
        self.output_layer = OutputLayer(in_dim=1, out_type=out_type)

    def forward(self, feat_index, feat_value):
        """Score a batch.

        Args:
            feat_index: N * num_fields index tensor.
            feat_value: N * num_fields value tensor.
        """
        first_order_weights = self.first_order_weights(feat_index)  # N * num_fields * 1
        # Squeeze only the trailing axis: a bare ``squeeze()`` also drops the
        # field axis when num_fields == 1, which then mis-broadcasts against
        # ``feat_value`` into an N * N matrix.
        first_order_weights = first_order_weights.squeeze(dim=2)  # N * num_fields
        first_order = torch.mul(feat_value, first_order_weights)  # N * num_fields
        first_order = torch.sum(first_order, dim=1)  # N

        feat_emb = self.emb_layer(feat_index)  # N * num_fields * emb_dim
        feat_value = feat_value.unsqueeze(dim=2)  # N * num_fields * 1
        feat_emb_value = torch.mul(feat_emb, feat_value)  # N * num_fields * emb_dim
        bi = self.bi_intaraction_layer(feat_emb_value)  # N * emb_dim

        fc_out = self.fc_layers(bi)  # N * fc_dims[-1]
        out = torch.mul(fc_out, self.h)  # N * fc_dims[-1]
        out = torch.sum(out, dim=1)  # N
        out = out + first_order + self.first_order_bias  # N
        out = out.unsqueeze(dim=1)  # N * 1
        out = self.output_layer(out)
        return out


class BiInteractionLayer(nn.Module):
    """Module wrapper around the shared ``bi_interaction`` helper."""

    def __init__(self):
        super(BiInteractionLayer, self).__init__()

    def forward(self, feat_emb_value):
        """Pool value-weighted embeddings with ``bi_interaction``.

        Args:
            feat_emb_value: N * num_fields * emb_dim tensor.
        """
        # presumably square-of-sum minus sum-of-square over the field axis
        # (N * emb_dim) -- verify in model.basic.functional
        bi_out = bi_interaction(feat_emb_value)
        return bi_out
class PNN(nn.Module):
    """Product-based Neural Network (Qu et al., ICDM 2016).

    The first hidden layer is ``relu(l_z + l_p + b_1)`` where ``l_z`` is a
    linear signal over the concatenated embeddings and ``l_p`` comes from a
    product layer; an MLP and an output layer follow.

    Args:
        emb_dim: embedding size.
        num_feats: number of features (rows of the embedding table).
        num_fields: number of fields per sample.
        fc_dims: hidden sizes, ``[32, 32]`` when empty/None; ``fc_dims[0]``
            is the width ``d1`` of the product layer.
        dropout: forwarded to ``MLP``.
        batch_norm: forwarded to ``MLP``.
        product_type: ``'inner'``, ``'outer'``, or anything else for the
            hybrid (inner + outer) layer; ``'*'`` additionally doubles ``d1``.
        out_type: forwarded to ``OutputLayer``.
    """

    def __init__(self, emb_dim, num_feats, num_fields, fc_dims=None, dropout=None, batch_norm=None,
                 product_type='inner', out_type='binary'):
        super(PNN, self).__init__()
        # embedding layer
        self.emb_dim = emb_dim
        self.num_feats = num_feats
        self.num_fields = num_fields
        self.emb_layer = nn.Embedding(num_embeddings=self.num_feats,
                                      embedding_dim=self.emb_dim)
        nn.init.xavier_uniform_(self.emb_layer.weight)

        # linear signal layer, named l_z
        if not fc_dims:
            fc_dims = [32, 32]
        self.d1 = d1 = fc_dims[0]
        self.product_type = product_type
        if product_type == '*':
            d1 *= 2
        self.linear_signal_weights = nn.Linear(in_features=num_fields * emb_dim, out_features=d1)
        nn.init.xavier_uniform_(self.linear_signal_weights.weight)

        # product layer, named l_p
        if product_type == 'inner':
            self.product_layer = InnerProductLayer(num_fields, d1)
        elif product_type == 'outer':
            self.product_layer = OuterProductLayer(emb_dim, num_fields, d1)
        else:
            self.product_layer = HybridProductLayer(emb_dim, num_fields, d1)

        # fc layers
        # l_1=relu(l_z+l_p_b_1)
        self.l1_layer = nn.ReLU()
        self.l1_bias = nn.Parameter(torch.randn(d1))
        # l_2 to l_n
        self.fc_dims = fc_dims
        self.fc_layers = MLP(d1, self.fc_dims, dropout, batch_norm)

        # output layer
        self.output_layer = OutputLayer(fc_dims[-1], out_type)

    def forward(self, feat_index):
        """Score a batch of N * num_fields feature indices."""
        feat_emb = self.emb_layer(feat_index)  # N * num_fields * emb_dim

        # compute linear signal l_z
        concat_z = feat_emb.reshape(-1, self.emb_dim * self.num_fields)
        linear_signal = self.linear_signal_weights(concat_z)

        # product_layer
        product_out = self.product_layer(feat_emb)

        # l_1=relu(l_z+l_p_b_1), then fc layers from l_2 to l_n
        l1_in = torch.add(linear_signal, self.l1_bias)
        l1_in = torch.add(l1_in, product_out)
        l1_out = self.l1_layer(l1_in)
        y = self.fc_layers(l1_out)
        y = self.output_layer(y)
        return y


class InnerProductLayer(nn.Module):
    """Inner-product layer: projects all pairwise field inner products to ``d1``.

    Args:
        num_fields: number of fields per sample.
        d1: output width.
    """

    def __init__(self, num_fields, d1):
        super(InnerProductLayer, self).__init__()
        self.num_fields = num_fields
        self.d1 = d1
        self.num_pairs = num_fields * (num_fields - 1) // 2
        self.product_layer_weights = nn.Linear(in_features=self.num_pairs, out_features=d1)
        nn.init.xavier_uniform_(self.product_layer_weights.weight)

    def forward(self, feat_emb):
        """Map N * num_fields * emb_dim embeddings to an N * d1 signal."""
        # p is a symmetric matrix, so only the upper triangle (without the
        # diagonal) needs to be computed
        p, q = build_cross(self.num_fields, feat_emb)
        pij = p * q  # N * num_pairs * emb_dim
        pij = torch.sum(pij, dim=2)  # N * num_pairs

        # l_p
        lp = self.product_layer_weights(pij)
        return lp


class OuterProductLayer(nn.Module):
    """Kernelised outer-product layer.

    Args:
        emb_dim: embedding size.
        num_fields: number of fields per sample.
        d1: output width.
        kernel_type: ``'mat'`` (default, one emb_dim * emb_dim kernel per
            pair), ``'vec'`` (one vector per pair) or ``'num'`` (one scalar
            per pair).
    """

    def __init__(self, emb_dim, num_fields, d1, kernel_type='mat'):
        super(OuterProductLayer, self).__init__()
        self.emb_dim = emb_dim
        self.num_fields = num_fields
        self.d1 = d1
        # integer division: the value is used as a tensor dimension
        self.num_pairs = num_fields * (num_fields - 1) // 2
        self.kernel_type = kernel_type
        if kernel_type == 'vec':
            kernel_shape = (self.num_pairs, emb_dim)
        elif kernel_type == 'num':
            kernel_shape = (self.num_pairs, 1)
        else:  # by default mat
            kernel_shape = (emb_dim, self.num_pairs, emb_dim)
        self.kernel_shape = kernel_shape
        self.kernel = nn.Parameter(torch.zeros(kernel_shape))
        nn.init.xavier_uniform_(self.kernel.data)
        # forward() feeds one value per field pair, so the projection takes
        # num_pairs inputs (not num_fields)
        self.product_layer_weights = nn.Linear(in_features=self.num_pairs, out_features=d1)
        nn.init.xavier_uniform_(self.product_layer_weights.weight)

    def forward(self, feat_emb):
        """Map N * num_fields * emb_dim embeddings to an N * d1 signal."""
        p, q = build_cross(self.num_fields, feat_emb)  # p, q: N * num_pairs * emb_dim

        if self.kernel_type == 'mat':
            # self.kernel: emb_dim * num_pairs * emb_dim
            p = p.unsqueeze(1)  # N * 1 * num_pairs * emb_dim
            p = p * self.kernel  # N * emb_dim * num_pairs * emb_dim
            kp = torch.sum(p, dim=-1)  # N * emb_dim * num_pairs
            kp = kp.permute(0, 2, 1)  # N * num_pairs * emb_dim
            pij = torch.sum(kp * q, -1)  # N * num_pairs
        else:
            # self.kernel: num_pairs * emb_dim/1; add the batch axis in front
            kernel = self.kernel.unsqueeze(0)  # 1 * num_pairs * emb_dim/1
            pij = p * q  # N * num_pairs * emb_dim
            pij = pij * kernel  # N * num_pairs * emb_dim
            pij = torch.sum(pij, -1)  # N * num_pairs

        # l_p
        lp = self.product_layer_weights(pij)
        return lp


class HybridProductLayer(nn.Module):
    """Concatenation of an inner- and an outer-product layer.

    The two sub-layers share the requested width: the inner layer gets
    ``d1 // 2`` units and the outer layer the remainder, so the
    concatenated output is exactly ``d1`` wide.

    Args:
        emb_dim: embedding size.
        num_fields: number of fields per sample.
        d1: total output width.
    """

    def __init__(self, emb_dim, num_fields, d1):
        super(HybridProductLayer, self).__init__()
        self.num_fields = num_fields
        inner_dim = d1 // 2
        outer_dim = d1 - inner_dim
        self.d1 = inner_dim
        self.inner_product_layer = InnerProductLayer(num_fields, inner_dim)
        self.outer_product_layer = OuterProductLayer(emb_dim, num_fields, outer_dim)

    def forward(self, feat_emb):
        """Map N * num_fields * emb_dim embeddings to an N * d1 signal."""
        inner_product_out = self.inner_product_layer(feat_emb)
        outer_product_out = self.outer_product_layer(feat_emb)
        lp = torch.cat([inner_product_out, outer_product_out], dim=1)
        return lp
class WideAndDeep(nn.Module):
    """Wide & Deep Learning (Cheng et al., DLRS 2016).

    The wide part is a linear model over categorical indices and continuous
    values; the deep part is an MLP over the continuous values concatenated
    with the categorical embeddings. Both produce one logit per sample and
    their sum goes through the output layer.

    Args:
        emb_dim: embedding size of the deep part.
        num_feats: total number of features; the lookup tables hold
            ``num_feats - num_cont_fields`` rows.
        num_cate_fields: number of categorical fields per sample.
        num_cont_fields: number of continuous fields per sample.
        num_cross_feats: stored only; not used by the visible code.
        fc_dims: hidden sizes of the deep part, ``[32, 32]`` when
            empty/None. A final layer of size 1 is appended internally.
        dropout: forwarded to ``MLP``.
        batch_norm: forwarded to ``MLP``.
        out_type: forwarded to ``OutputLayer``.
    """

    def __init__(self, emb_dim, num_feats, num_cate_fields, num_cont_fields, num_cross_feats, fc_dims=None,
                 dropout=None, batch_norm=None, out_type='binary'):
        super(WideAndDeep, self).__init__()
        self.emb_dim = emb_dim
        self.num_feats = num_feats
        self.num_cate_fields = num_cate_fields
        self.num_cont_fields = num_cont_fields
        self.num_cross_feats = num_cross_feats

        # first order weight for category features
        self.cate_weights = nn.Embedding(num_embeddings=num_feats - num_cont_fields, embedding_dim=1)
        nn.init.xavier_uniform_(self.cate_weights.weight)

        # first order weight for continuous features
        self.cont_weights = nn.Linear(in_features=num_cont_fields, out_features=1)
        # the initialiser needs the weight tensor, not the Linear module
        nn.init.xavier_uniform_(self.cont_weights.weight)

        self.wide_bias = nn.Parameter(torch.randn(1))

        if not fc_dims:
            fc_dims = [32, 32]
        # build a new list so the caller's ``fc_dims`` is not mutated
        fc_dims = list(fc_dims) + [1]
        self.fc_dims = fc_dims

        # embedding for deep network
        self.emb_layer = nn.Embedding(num_embeddings=num_feats - num_cont_fields, embedding_dim=emb_dim)
        nn.init.xavier_uniform_(self.emb_layer.weight)

        self.deep = MLP(num_cont_fields + num_cate_fields * emb_dim, fc_dims, dropout, batch_norm)
        self.out_layer = OutputLayer(in_dim=1, out_type=out_type)

    def forward(self, continuous_value, categorical_index, cross_feat_index):
        """Score a batch.

        Args:
            continuous_value: N * num_cont_fields float tensor.
            categorical_index: N * num_cate_fields index tensor.
            cross_feat_index: accepted for interface compatibility; not used.
        """
        first_order_cate = self.cate_weights(categorical_index)  # N * num_cate_fields * 1
        # sum the per-field weights so the term is N * 1 like the other wide terms
        first_order_cate = torch.sum(first_order_cate, dim=1)  # N * 1
        first_order_cont = self.cont_weights(continuous_value)  # N * 1
        y_wide = first_order_cate + first_order_cont + self.wide_bias  # N * 1

        cate_emb_value = self.emb_layer(categorical_index)  # N * num_cate_fields * emb_dim
        # N * (num_cate_fields * emb_dim)
        cate_emb_value = cate_emb_value.reshape((-1, self.num_cate_fields * self.emb_dim))
        deep_in = torch.cat([continuous_value, cate_emb_value], 1)  # N * (num_cate_fields * emb_dim + num_cont_fields)
        y_deep = self.deep(deep_in)  # N * 1
        y = y_deep + y_wide
        y = self.out_layer(y)
        return y
class DIEN(nn.Module):
    """Deep Interest Evolution Network (Zhou et al., arXiv 2018).

    A GRU extracts interest states from the behaviour history, an attention
    layer scores each state against the candidate, and an AUGRU-based
    evolution layer folds the attended states into one vector that is fed to
    an MLP together with the candidate, user-profile and context features.

    Args:
        u_emb_dim: size of the user-profile feature vector.
        c_emb_dim: size of the context feature vector.
        g_emb_dim: size of each item (history / candidate) feature vector.
        fc_dims: hidden sizes of the MLP, ``[200, 80]`` when empty/None.
        ext_hidden_dim: hidden size of the extractor GRU.
        evo_hidden_dim: hidden size of the evolution layer.
        activation_linear_dim: accepted but not used by the visible code.
        activation: ``'dice'`` for ``Dice``, anything else for ``nn.PReLU``.
        dropout: forwarded to ``MLP``.
        out_type: forwarded to ``OutputLayer``.
    """

    def __init__(self, u_emb_dim, c_emb_dim, g_emb_dim, fc_dims=None, ext_hidden_dim=32, evo_hidden_dim=32,
                 activation_linear_dim=36,
                 activation='dice',
                 dropout=None, out_type='binary'):
        super(DIEN, self).__init__()

        self.extractor_layer = nn.GRU(g_emb_dim, ext_hidden_dim, 1, batch_first=True)

        self.attention_layer = AttentionSignal(query_dim=g_emb_dim, key_dim=ext_hidden_dim, similarity='general')

        self.evolution_layer = EvolutionLayer(ext_hidden_dim, evo_hidden_dim)

        if not fc_dims:
            fc_dims = [200, 80]
        self.fc_dims = fc_dims

        if activation == 'dice':
            self.activation = Dice()
        else:
            self.activation = nn.PReLU()
        self.fc_layers = MLP(u_emb_dim + c_emb_dim + g_emb_dim + evo_hidden_dim,
                             fc_dims, dropout, None, self.activation)
        self.output_layer = OutputLayer(fc_dims[-1], out_type)

    def forward(self, history_feats, candidate_feat, user_profile_feat, context_feat):
        """Score a batch.

        Args:
            history_feats: N * seq_length * g_emb_dim.
            candidate_feat: N * g_emb_dim.
            user_profile_feat: N * u_emb_dim.
            context_feat: N * c_emb_dim.
        """
        # batch_first GRU output: N * seq_length * ext_hidden_dim
        extracted_interest, _ = self.extractor_layer(history_feats)
        att_signal = self.attention_layer(candidate_feat, extracted_interest)  # N * seq_length

        evolved_interest = self.evolution_layer(extracted_interest, att_signal)  # N * evo_hidden_dim

        fc_in = torch.cat([evolved_interest, candidate_feat, user_profile_feat, context_feat], dim=1)
        fc_out = self.fc_layers(fc_in)
        output = self.output_layer(fc_out)
        return output


class EvolutionLayer(nn.Module):
    """Runs an ``AUGRUCell`` over a sequence of extracted interests.

    Args:
        input_dim: size of each extracted interest vector.
        cell_hidden_dim: hidden size of the AUGRU cell.
    """

    def __init__(self, input_dim, cell_hidden_dim):
        super(EvolutionLayer, self).__init__()
        self.cell_hidden_dim = cell_hidden_dim
        self.cell = AUGRUCell(input_dim, cell_hidden_dim)

    def forward(self, extracted_interests: torch.Tensor, att_signal: torch.Tensor, hx=None):
        """Return the final hidden state.

        Args:
            extracted_interests: batch_size * seq_length * input_dim.
            att_signal: batch_size * seq_length attention scores.
            hx: optional initial state, batch_size * cell_hidden_dim
                (zeros when omitted).

        Returns:
            batch_size * cell_hidden_dim tensor.
        """
        # unbind drops only the time axis, so a batch of one keeps its batch axis
        interests = extracted_interests.unbind(dim=1)  # [batch_size * input_dim] * seq_length
        att_signals = att_signal.split(split_size=1, dim=1)  # [batch_size * 1] * seq_length

        if hx is None:
            # the state has the cell's hidden size, which may differ from input_dim
            hx = torch.zeros((extracted_interests.size(0), self.cell_hidden_dim),
                             dtype=extracted_interests.dtype,
                             device=extracted_interests.device)

        for interest, att in zip(interests, att_signals):
            hx = self.cell(interest, att, hx)
        return hx  # batch_size * cell_hidden_dim


class AUGRUCell(nn.Module):
    """GRU cell whose update gate is scaled by an attention score (AUGRU).

    Args:
        input_dim: size of the input vector.
        hidden_dim: size of the hidden state.
    """

    def __init__(self, input_dim, hidden_dim):
        super(AUGRUCell, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.input_weights = nn.Linear(in_features=input_dim,
                                       out_features=3 * hidden_dim)  # concat[Wu,Wr,Wh]
        nn.init.xavier_uniform_(self.input_weights.weight)
        self.hidden_weights = nn.Linear(in_features=hidden_dim,
                                        out_features=3 * hidden_dim,
                                        bias=False)  # concat[Uu,Ur,Uh]
        nn.init.xavier_uniform_(self.hidden_weights.weight)

    def forward(self, ix: torch.Tensor, att_signal: torch.Tensor, hx: torch.Tensor = None):
        """Advance the state by one step.

        Args:
            ix: batch_size * input_dim input.
            att_signal: batch_size * 1 attention score.
            hx: batch_size * hidden_dim previous state (zeros when omitted).

        Returns:
            batch_size * hidden_dim new state.
        """
        if hx is None:
            # 2-D so that chunk(dim=1) below works
            hx = torch.zeros((ix.size(0), self.hidden_dim), dtype=ix.dtype, device=ix.device)
        weighted_inputs = self.input_weights(ix)
        weighted_inputs = weighted_inputs.chunk(chunks=3, dim=1)  # [batch_size * hidden_dim] * 3
        weighted_hiddens = self.hidden_weights(hx)
        weighted_hiddens = weighted_hiddens.chunk(chunks=3, dim=1)  # [batch_size * hidden_dim] * 3

        update_gate = weighted_inputs[0] + weighted_hiddens[0]  # batch_size * hidden_dim
        update_gate = torch.sigmoid(update_gate)
        update_gate = torch.mul(att_signal, update_gate)  # attention-scaled update gate

        # the reset gate combines the input and the hidden projections
        reset_gate = weighted_inputs[1] + weighted_hiddens[1]  # batch_size * hidden_dim
        reset_gate = torch.sigmoid(reset_gate)

        hat_hidden = torch.mul(weighted_hiddens[2], reset_gate)
        hat_hidden = hat_hidden + weighted_inputs[2]
        hat_hidden = torch.tanh(hat_hidden)  # batch_size * hidden_dim

        hidden = torch.mul((1 - update_gate), hx) + torch.mul(update_gate, hat_hidden)
        return hidden
class DIN(nn.Module):
    """Deep Interest Network: attention-weighted pooling of a behaviour history
    followed by an MLP and an output layer.
    """

    def __init__(self, u_emb_dim, c_emb_dim, g_emb_dim, fc_dims=None, activation_linear_dim=36, activation='dice',
                 dropout=None, out_type='binary'):
        """
        :param u_emb_dim: width of the user profile features
        :param c_emb_dim: width of the context features
        :param g_emb_dim: width of one history / candidate item embedding
        :param fc_dims: hidden sizes of the MLP, defaults to [200, 80]
        :param activation_linear_dim: hidden size inside the activation unit
        :param activation: 'dice' selects Dice, anything else selects PReLU
        :param dropout: forwarded to the MLP
        :param out_type: forwarded to the output layer
        """
        super(DIN, self).__init__()
        self.activation_unit = ActivationUnit(g_emb_dim, activation_linear_dim, activation)
        if not fc_dims:
            fc_dims = [200, 80]
        self.fc_dims = fc_dims

        if activation == 'dice':
            self.activation = Dice()
        else:
            self.activation = nn.PReLU()
        self.fc_layers = MLP(u_emb_dim + c_emb_dim + 2 * g_emb_dim, fc_dims, dropout, None, self.activation)
        self.output_layer = OutputLayer(fc_dims[-1], out_type)

    def forward(self, history_feats, candidate_feat, user_profile_feat, context_feat):
        """
        :param history_feats: N * seq_length * g_emb_dim
        :param candidate_feat: N * g_emb_dim
        :param user_profile_feat: N * u_emb_dim
        :param context_feat: N * c_emb_dim
        :return: result of the output layer applied to the MLP output
        """
        histories = torch.split(history_feats, 1, dim=1)  # [N * 1 * g_emb_dim] * seq_length
        att_signals = [
            # Squeeze only the sequence axis: a bare squeeze() would also drop the batch
            # axis when N == 1 (or the embedding axis when g_emb_dim == 1).
            self.activation_unit(history_feat.squeeze(dim=1), candidate_feat)  # N * 1
            for history_feat in histories
        ]
        att_signal = torch.cat(att_signals, dim=1)  # N * seq_length
        att_signal = att_signal.unsqueeze(dim=2)  # N * seq_length * 1
        weighted = torch.mul(att_signal, history_feats)  # N * seq_length * g_emb_dim
        weighted_pooling = torch.sum(weighted, dim=1)  # N * g_emb_dim
        fc_in = torch.cat([user_profile_feat, weighted_pooling, candidate_feat, context_feat], dim=1)
        fc_out = self.fc_layers(fc_in)
        output = self.output_layer(fc_out)
        return output
class BaseModel:
    """Base class of the model wrappers: holds the network, the device and the
    training artefacts (logger, tensorboard writer, checkpoint / model paths).

    Subclasses set ``self.model`` and implement ``train`` and ``eval``.
    """

    def __init__(self):
        self.loader_args = None
        self.model = None
        self.job_name = ''
        self.device = torch.device('cpu')
        # These must be plain None: a trailing comma would store the tuple (None,),
        # which is truthy and has no logger / writer methods.
        self.logger = None
        self.tb_writer = None
        self.ckpt_dir = None
        self.log_path = None
        self.model_path = None
        self.ckpt_interval = -1

    def config_training(self, write_log_file=True, log_path=None,
                        save_ckpt=True, ckpt_dir=None, ckpt_interval=None,
                        save_model=True, model_path=None,
                        write_tb=True, tb_dir=None):
        """Configure logger, tensorboard writer and all output paths in one call via ``config_path``."""
        self.logger, self.tb_writer, self.ckpt_dir, self.log_path, self.model_path = \
            config_path(self.job_name, self.device, write_log_file, log_path, save_ckpt, ckpt_dir, save_model,
                        model_path, write_tb, tb_dir)
        if save_ckpt:
            self.ckpt_interval = ckpt_interval

    def config_tensorboard(self, write_tb=False, tb_dir=None):
        """Create a tensorboard writer when ``write_tb`` is set."""
        if write_tb:
            self.tb_writer = SummaryWriter(log_dir=tb_dir)

    def config_logger(self, write_log_file=False, log_path=None):
        """Create a file + console logger when ``write_log_file`` is set."""
        if write_log_file:
            self.logger = create_file_console_logger(log_path, name=self.job_name)

    def config_ckpt(self, save_ckpt=False, ckpt_dir=None, ckpt_interval=None):
        """Record the checkpoint directory and interval when ``save_ckpt`` is set."""
        if save_ckpt:
            self.ckpt_dir = ckpt_dir
            self.ckpt_interval = ckpt_interval

    def config_model_saving(self, save_model=False, model_path=None):
        """Record where the trained model is saved when ``save_model`` is set."""
        if save_model:
            self.model_path = model_path

    def config_loader_meta(self, **kwargs):
        """Store the keyword arguments later passed to the training ``DataLoader``."""
        self.loader_args = kwargs

    def _train(self, dataset, loss_func, optimizer=None, epochs=2, val_size=0):
        """Build the loaders and run ``train_model``.

        :param dataset: dataset to train on
        :param loss_func: loss callable
        :param optimizer: defaults to SGD with lr=1e-3 over the model parameters
        :param epochs: number of epochs
        :param val_size: <= 0 disables validation, otherwise forwarded to ``split_dataset``
        """
        if not optimizer:
            optimizer = torch.optim.SGD(params=self.model.parameters(), lr=1e-3)
        # Tolerate a wrapper on which config_loader_meta was never called.
        loader_args = self.loader_args or {}
        if val_size <= 0:
            train_loader = DataLoader(dataset, **loader_args)
            val_loader = None
        else:
            train_set, val_set = split_dataset(dataset, val_size)
            train_loader = DataLoader(train_set, **loader_args)
            # 1 is DataLoader's own default batch size.
            val_loader = DataLoader(val_set, batch_size=loader_args.get('batch_size', 1))
        self.model.train()
        train_model(self.model, train_loader, loss_func, optimizer, val_loader, epochs,
                    self.logger, self.tb_writer, self.ckpt_dir, self.ckpt_interval, self.model_path)

    def train(self, **kwargs):
        raise NotImplementedError

    def eval(self, **kwargs):
        raise NotImplementedError
class FNNModel(BaseModel):
    """Wrapper that builds an FNN network and trains it in two phases."""

    def __init__(self, feat_meta: FeatureMeta, emb_dim, fc_dims=None, dropout=None, batch_norm=None, out_type='binary',
                 device_name='cuda:0'):
        super(FNNModel, self).__init__()
        self.model = FNN(emb_dim, feat_meta.get_num_feats(), feat_meta.get_num_fields(), fc_dims, dropout,
                         batch_norm, out_type, True)
        # Fall back to the CPU when CUDA is unavailable.
        self.device = torch.device(device_name if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.job_name = 'FNN-' + out_type

    def train(self, feat_index, y, batch_size=32, epochs=2, shuffle=True, val_size=0.2):
        """ train the FNN model with hold-out model selection method

        :param feat_index: ndarray-like, should be shape of (n_samples,num_fields)
        :param y: ndarray-like, should be shape of (n_samples,1) for binary cls/regression
                    or (n_samples,num_classes) for multi-class cls
        :param batch_size: the size of samples in a batch
        :param epochs: the epochs to train
        :param shuffle: whether the order of samples is shuffled before each epoch, default true
        :param val_size: whether to use validation set in the training, default 0.2;
                    set to >=1 means n_samples; 0 to 1 (0 and 1 not included) means ratio;
                    0 means not to use validation set
        """
        self.config_loader_meta(batch_size=batch_size, shuffle=shuffle)

        feat_index_tensor = torch.LongTensor(feat_index).to(self.device)
        y_tensor = torch.Tensor(y).to(self.device)

        dataset = TensorDataset(feat_index_tensor, y_tensor)

        loss_func = nn.BCELoss()
        optimizer = torch.optim.Adam(params=self.model.parameters(), lr=1e-3)

        # train FM
        self.model.train_fm_embedding()
        self._train(dataset, loss_func, optimizer, epochs, val_size)

        # train FNN
        # NOTE(review): this repeats the FM-phase call; presumably a different mode switch
        # was intended for the second phase -- confirm against model.ctr.FNN.
        self.model.train_fm_embedding()
        self._train(dataset, loss_func, optimizer, epochs, val_size)

    def eval(self, feat_index):
        """Run the network in eval mode on ``feat_index``.

        :param feat_index: ndarray-like or tensor of feature indices
        :return: the network output
        """
        self.model.eval()
        # Mirror train(): the network is fed long indices living on the model's device.
        feat_index_tensor = torch.as_tensor(feat_index, dtype=torch.long, device=self.device)
        return self.model(feat_index_tensor)
def discretize(data_series: pd.Series, disc_method, bins):
    """Discretize a continuous column into ``bins`` integer-labelled buckets.

    :param data_series: the values to discretize
    :param disc_method: 'eq_dist' (equal-width, ``pd.cut``), 'eq_freq' (equal-frequency,
        ``pd.qcut`` with duplicate edges dropped) or 'cluster' (k-means cluster labels)
    :param bins: number of buckets / clusters
    :return: ``(discrete_values, intervals)``; ``intervals`` holds the bin edges and is
        None for 'cluster'
    :raises ValueError: if ``disc_method`` is not one of the supported names
    """
    if disc_method == 'eq_dist':
        discrete, intervals = pd.cut(data_series, bins=bins, labels=range(bins), retbins=True)
        return discrete.values, intervals
    if disc_method == 'eq_freq':
        discrete, intervals = pd.qcut(data_series, q=bins, labels=range(bins), retbins=True, duplicates='drop')
        return discrete.values, intervals
    if disc_method == 'cluster':
        data = np.asarray(data_series).reshape(-1, 1)
        kmeans = KMeans(n_clusters=bins)
        # fit_predict gives one cluster label per sample; fit_transform would give the
        # (n_samples, bins) matrix of distances to the centroids, not a discretization.
        return kmeans.fit_predict(data), None
    raise ValueError('unknown discretization method: {}'.format(disc_method))
def get_idx_and_value(feat_meta: FeatureMeta, raw_data: pd.DataFrame):
    """Deprecated. Allocate feature indices in ``feat_meta`` and convert every row of
    ``raw_data`` into index / value lists via ``process_line``.

    :param feat_meta: feature description; start indices, dims and processors are updated in place
    :param raw_data: raw rows; categorical columns that already have a processor are
        label-encoded in place
    :return: (list of per-row index lists, list of per-row value lists)
    """
    # Emit the deprecation at call time. Using the exception class as a decorator would
    # replace this function with a DeprecationWarning instance, which cannot be called.
    import warnings
    warnings.warn('get_idx_and_value is deprecated', DeprecationWarning, stacklevel=2)

    logger = create_console_logger(name='feat_meta')
    write_info_log(logger, 'preprocess started')
    idx = 0
    # allocate indices for continuous features
    continuous_feats = feat_meta.continuous_feats
    for name in continuous_feats:
        continuous_feat = continuous_feats[name]
        continuous_feat.start_idx = idx
        idx += 1
    # generate label encoders and allocate indices range for categorical features
    categorical_feats = feat_meta.categorical_feats
    for name in categorical_feats:
        categorical_feat = categorical_feats[name]
        le = categorical_feat.processor
        if le:
            num_classes = len(le.classes_)
            raw_data[name] = le.transform(raw_data[name])
        else:
            le = LabelEncoder()
            le.fit(raw_data[name])
            categorical_feat.processor = le
            num_classes = len(le.classes_)
        categorical_feat.dim = num_classes
        categorical_feat.start_idx = idx
        idx += num_classes
    # generate multi-hot encoders and allocate indices range for multi-category features
    multi_category_feats = feat_meta.multi_category_feats
    for name in multi_category_feats:
        multi_category_feat = multi_category_feats[name]
        le = multi_category_feat.processor
        if le:
            num_classes = len(le.classes_)
        else:
            mlb = MultiLabelBinarizer()
            mlb.fit(raw_data[name])
            multi_category_feat.processor = mlb
            num_classes = len(mlb.classes_)
        multi_category_feat.dim = num_classes
        multi_category_feat.start_idx = idx
        idx += num_classes
    write_info_log(logger, 'feature meta updated')
    # transform raw data to index and value form
    write_info_log(logger, 'index and value transformation started')
    feat_df = raw_data.apply(process_line, feat_meta=feat_meta, axis=1)
    write_info_log(logger, 'preprocess finished')
    return feat_df.feat_idx.values.tolist(), feat_df.feat_value.values.tolist()
def feature_fit_transform(feat_meta: FeatureMeta, data: pd.DataFrame):
    r""" Turn raw data into model input and update ``feat_meta`` along the way.
    Continuous columns are copied unchanged; categorical columns are label-encoded.
    Start indices are counted separately for the continuous and the categorical group.

    :param feat_meta: The FeatureMeta instance that describes raw_data.
    :param data: The raw_data to be transformed.
    :return: continuous_value, category_index, column_list
    """
    logger = create_console_logger(name='feat_meta')
    write_info_log(logger, 'preprocess started')

    continuous_feats = feat_meta.continuous_feats
    categorical_feats = feat_meta.categorical_feats
    columns = list(continuous_feats.keys()) + list(categorical_feats.keys())

    # continuous features: one index slot each, values passed through
    continuous_value = pd.DataFrame()
    write_info_log(logger, 'transforming continuous features')
    for position, feat_name in enumerate(continuous_feats):
        continuous_value[feat_name] = data[feat_name]
        continuous_feats[feat_name].start_idx = position

    # categorical features: a block of ``num_classes`` index slots each
    category_index = pd.DataFrame()
    write_info_log(logger, 'transforming categorical features')
    next_start = 0
    for feat_name in categorical_feats:
        feat = categorical_feats[feat_name]
        encoder = LabelEncoder()
        category_index[feat_name] = encoder.fit_transform(data[feat_name])
        feat.processor = encoder
        class_count = len(encoder.classes_)
        feat.dim = class_count
        feat.start_idx = next_start
        next_start += class_count

    return continuous_value, category_index, columns
def box_cox_transform(x, lmbda=None):
    """Apply the Box-Cox power transform to ``x``.

    :param x: input data as accepted by ``scipy.stats.boxcox`` (which requires positive values)
    :param lmbda: transform parameter; when None it is estimated from ``x`` by scipy
    :return: the transformed array
    """
    # Test against None explicitly: lmbda == 0 is a valid parameter (the log transform).
    if lmbda is None:
        # Without lmbda scipy returns (transformed, fitted_lambda); only the data is wanted.
        transformed, _ = stats.boxcox(x)
        return transformed
    return stats.boxcox(x, lmbda)
def get_file_dir_and_name(path):
    """Split a '/'-separated path into its directory part and its file name.

    :param path: path string using '/' as separator
    :return: ``(directory, name)``; the directory is '.' when ``path`` contains no
        separator and '/' when the file sits directly under the root, so the result
        can always be handed to ``os.path.exists`` / ``os.makedirs``
    """
    last_sep = path.rfind('/')
    if last_sep < 0:
        # No separator: rfind returns -1 and slicing with it would chop the last character.
        return '.', path
    directory = path[:last_sep] or '/'
    return directory, path[last_sep + 1:]
def close_logger(logger):
    """Close and detach every handler of ``logger``.

    :param logger: the logger to clean up; None is accepted and ignored
    """
    if logger is None:
        return
    # Iterate over a snapshot: removeHandler mutates logger.handlers, and removing while
    # iterating the live list skips every other handler.
    for handler in list(logger.handlers):
        handler.close()
        logger.removeHandler(handler)
def train_config_model(job_name, device,
                       model: nn.Module, train_set: Dataset,
                       loss_func, optimizer, val_set: Dataset = None,
                       batch_size=32, epochs=2, shuffle=True,
                       write_log_file=True, log_path=None,
                       save_ckpt=True, ckpt_dir=None, ckpt_interval=None,
                       save_model=True, model_path=None,
                       write_tb=False, tb_dir=None,
                       load_ckpt=None):
    """Set up the job outputs, optionally resume from a checkpoint, then train.

    The logger, tensorboard writer and output paths come from ``config_path``; the data
    loaders are built here and handed to ``train_model``.
    """
    # logger, tensorboard writer, checkpoint directory and model path
    logger, writer, ckpt_dir, log_path, model_path = config_path(
        job_name, device, write_log_file, log_path,
        save_ckpt, ckpt_dir, save_model, model_path,
        write_tb, tb_dir)

    # an interval larger than the epoch count means no checkpoint is ever written
    if not save_ckpt:
        ckpt_interval = epochs + 1

    # resume from a checkpoint when one is given
    start_epoch = 1
    if load_ckpt:
        trained_epochs, last_loss = load_to_train(model, optimizer, load_ckpt)
        write_info_log(logger, 'model loaded from {}'.format(load_ckpt))
        write_info_log(logger, 'epochs trained:{}, current loss:{:.5f}'.format(trained_epochs, last_loss))
        start_epoch = trained_epochs + 1

    model.to(device)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=shuffle)
    val_loader = DataLoader(val_set, batch_size=batch_size) if val_set else None

    # write training meta log
    write_model_meta(logger, job_name, device, model, loss_func, optimizer, epochs, batch_size, shuffle)

    train_model(model, train_loader, loss_func, optimizer, val_loader, epochs, logger, writer,
                ckpt_dir, ckpt_interval, model_path, start_epoch)
def train_model(model: nn.Module, train_loader,
                loss_func, optimizer, val_loader=None, epochs=2,
                logger=None, writer=None,
                ckpt_dir=None, ckpt_interval=None,
                model_path=None, first_epoch=1):
    """Train ``model``, then save it and release the writer and logger.

    :param model: network to train
    :param train_loader: loader of training batches
    :param loss_func: loss callable
    :param optimizer: optimizer bound to the model parameters
    :param val_loader: optional loader of validation batches
    :param epochs: number of epochs to run
    :param logger: logger to write to; a console logger is created when omitted
    :param writer: optional tensorboard writer, flushed and closed at the end
    :param ckpt_dir: directory for checkpoints
    :param ckpt_interval: epochs between checkpoints
    :param model_path: where to save the trained model; nothing is saved when falsy
    :param first_epoch: index of the first epoch (greater than 1 when resuming)
    """
    # The default logger=None would otherwise crash on the very first log call.
    if logger is None:
        logger = create_console_logger(name='train_model')

    # start training
    write_info_log(logger, 'training started')
    train_epochs(model, epochs, train_loader, loss_func, optimizer, ckpt_interval, logger, val_loader=val_loader,
                 tensorboard_writer=writer, ckpt_dir=ckpt_dir, first_epoch_idx=first_epoch)

    # finish training
    if model_path:
        saved_model_path = save_trained_model(model_path, model)
        write_info_log(logger, 'model saved:{}'.format(saved_model_path))
    if writer:
        writer.flush()
        writer.close()
    write_info_log(logger, 'training_finished')
    close_logger(logger)
def train_epochs(model, epochs, train_loader, loss_func, optimizer, ckpt_interval,
                 logger, val_loader=None, tensorboard_writer=None, ckpt_dir=None,
                 first_epoch_idx=1):
    """Run the epoch loop: train, optionally validate, periodically checkpoint.

    Each batch is a sequence of tensors whose last element is the target; the others
    are passed positionally to the model.

    :param model: network to train
    :param epochs: number of epochs to run
    :param train_loader: loader of training batches
    :param loss_func: loss callable ``loss_func(pred_y, y)``
    :param optimizer: optimizer bound to the model parameters
    :param ckpt_interval: epochs between checkpoints; when falsy it is derived from ``epochs``
    :param logger: logger receiving the per-epoch losses
    :param val_loader: optional loader of validation batches
    :param tensorboard_writer: optional writer receiving the per-epoch average losses
    :param ckpt_dir: checkpoint directory; no checkpoint is written when it is falsy
    :param first_epoch_idx: index of the first epoch (greater than 1 when resuming)
    """
    # config checkpoint interval epoch
    if not ckpt_interval:
        if epochs > 50:
            ckpt_interval = 10
        elif epochs > 10:
            ckpt_interval = 5
        else:
            ckpt_interval = 2

    save_ckpt_flag = ckpt_interval
    max_epoch = epochs + first_epoch_idx - 1
    for epoch in range(first_epoch_idx, max_epoch + 1):
        # train the epoch
        model.train()
        train_loss_sum = 0.0
        train_batches = 0
        for tensors in train_loader:
            y = tensors[-1]
            X = tensors[:-1]
            pred_y = model(*X)
            loss = loss_func(pred_y, y)

            # back propagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() detaches the value so no autograd graph is kept alive
            train_loss_sum += loss.item()
            train_batches += 1

        # average loss during the epoch
        avg_train_loss = train_loss_sum / train_batches if train_batches else 0.0
        if tensorboard_writer:
            tensorboard_writer.add_scalar("Loss/train", avg_train_loss, epoch)
        write_training_log(logger, epoch, max_epoch, avg_train_loss)

        # run validation set
        if val_loader:
            model.eval()
            val_loss_sum = 0.0
            val_batches = 0
            # Validation needs no gradients; skipping them saves memory and time.
            with torch.no_grad():
                for tensors in val_loader:
                    y = tensors[-1]
                    X = tensors[:-1]
                    pred_y = model(*X)
                    val_loss_sum += loss_func(pred_y, y).item()
                    val_batches += 1
            avg_val_loss = val_loss_sum / val_batches if val_batches else 0.0
            if tensorboard_writer:
                tensorboard_writer.add_scalar("Loss/validation", avg_val_loss, epoch)
            write_training_log(logger, epoch, max_epoch, avg_val_loss, loss_type='validation')

        # save model to checkpoint file
        save_ckpt_flag -= 1
        if save_ckpt_flag == 0:
            # Without a directory the path would become 'None/epoch_N.pt'.
            if ckpt_dir:
                # The checkpoint records the training loss, not a validation batch loss.
                saved_ckpt_path = save_checkpoint('{}/epoch_{}.pt'.format(ckpt_dir, epoch), model, optimizer,
                                                  avg_train_loss, epoch)
                write_info_log(logger, 'checkpoint saved:{}'.format(saved_ckpt_path))
            save_ckpt_flag = ckpt_interval