├── .gitignore ├── LICENSE ├── README.md ├── main.py ├── requirements.txt ├── torchfm ├── __init__.py ├── dataset │ ├── __init__.py │ ├── amazon.py │ ├── avazu.py │ ├── criteo.py │ └── movielens.py ├── layer.py └── model │ ├── __init__.py │ ├── afi.py │ ├── afm.py │ ├── afn.py │ ├── awesome.py │ ├── dcn.py │ ├── dcnv2.py │ ├── dfm.py │ ├── ffm.py │ ├── finalmlp.py │ ├── fm.py │ ├── fnfm.py │ ├── fnn.py │ ├── fwfm.py │ ├── hofm.py │ ├── lr.py │ ├── mwd.py │ ├── ncf.py │ ├── nfm.py │ ├── pnn.py │ ├── rdcnv2.py │ ├── temp.py │ ├── transformer.py │ ├── wd.py │ ├── wrdcnv2.py │ └── xdfm.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 112 | .pdm.toml 113 | .pdm-python 114 | .pdm-build/ 115 | 116 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 164 | #.idea/ 165 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 THUML @ Tsinghua University 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-Embedding (ICML 2024) 2 | 3 | On the Embedding Collapse When Scaling Up Recommendation Models [[paper]](https://arxiv.org/abs/2310.04400) 4 | 5 | Recommendation models lack scalability due to the **embedding collapse** phenomenon. We propose **multi-embedding**, which scales the number of independent embedding sets rather than the size of a single embedding, to overcome embedding collapse while keeping the feature interaction modules unchanged.
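
## Getting started

A minimal sketch of how the pieces in this repo fit together (the paths and the 4x10 configuration here are illustrative; point the dataset at your local Criteo `train.txt` or Avazu `train` file):

```python
from torchfm.dataset.criteo import CriteoDataset
from torchfm.model.awesome import MultiDCNnew2

dataset = CriteoDataset('criteo/train.txt')  # builds an LMDB cache (.criteo) on first use
model = MultiDCNnew2(
    dataset.field_dims,    # vocabulary size of each categorical field
    embed_dims=[10] * 4,   # 4 independent embedding sets, 10 dimensions each
    num_layers=4,          # depth of each DCNv2 cross network
    mlp_dims=(400, 400),
    dropout=0.2,
)
```

The same configuration can be trained end-to-end with `python main.py --dataset_name criteo --dataset_path criteo/train.txt --model_name mdcn-4x10`. See `get_model` in `main.py` for the full list of model names, e.g. `dcn-10` through `dcn-100` for single-embedding baselines and `mdcn-2x10` through `mdcn-10x10` for the multi-embedding counterparts.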
6 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import time 4 | from sklearn.metrics import roc_auc_score 5 | from torch.utils.data import DataLoader 6 | 7 | from torchfm.dataset.avazu import AvazuDataset 8 | from torchfm.dataset.criteo import CriteoDataset 9 | from torchfm.model.ffm import FieldAwareFactorizationMachineModel 10 | from torchfm.model.pnn import ProductNeuralNetworkModel, MultiPNNModel 11 | from torchfm.model.xdfm import XDeepFM, MultiXDeepFM 12 | from torchfm.model.fwfm import NFwFMModel, MultiNFwFMModel 13 | from torchfm.model.dcnv2 import CrossNetworkV2Model 14 | from torchfm.model.mwd import ( 15 | DNNModel, 16 | MultiDNNModel, 17 | ) 18 | from torchfm.model.awesome import ( 19 | MultiDCNnew2, 20 | WeightNormAlignedMultiDCNnew2, 21 | MultiESingleIDCNv2, 22 | SpaceSimilarityRegularizedMultiDCNnew2, 23 | SingularValueRegularizedDCNv2, 24 | ) 25 | from torchfm.model.rdcnv2 import RestrictedCrossNetworkV2Model 26 | from torchfm.model.wrdcnv2 import WeightedRestrictedCrossNetworkV2Model, WeightedRestrictedMultiDCN 27 | from torchfm.model.finalmlp import FinalMLP, MultiFinalMLP 28 | from utils import CompleteLogger, AverageMeter, ProgressMeter, CriterionWithLoss, EarlyStopper 29 | 30 | 31 | def get_dataset(name, path): 32 | if name == 'criteo': 33 | return CriteoDataset(path) 34 | elif name == 'avazu': 35 | return AvazuDataset(path) 36 | else: 37 | raise ValueError('unknown dataset name: ' + name) 38 | 39 | 40 | def get_model(name, dataset): 41 | """ 42 | Hyperparameters are empirically determined, not optimized. 43 | """ 44 | field_dims = dataset.field_dims 45 | print(field_dims) 46 | print(sum(field_dims)) 47 | 48 | ########################################## Appendix ######################################## 49 | if name == 'space-similarity-regularized-mdcn-4x10-1e-3': 50 | return SpaceSimilarityRegularizedMultiDCNnew2(field_dims, embed_dims=[10]*4, num_layers=4, mlp_dims=(400, 400), dropout=0.2, reg_weight=1e-3) 51 | elif name == 'space-similarity-regularized-mdcn-4x10-1e-4': 52 | return SpaceSimilarityRegularizedMultiDCNnew2(field_dims, embed_dims=[10]*4, num_layers=4, mlp_dims=(400, 400), dropout=0.2, reg_weight=1e-4) 53 | elif name == 'space-similarity-regularized-mdcn-4x10-1e-5': 54 | return SpaceSimilarityRegularizedMultiDCNnew2(field_dims, embed_dims=[10]*4, num_layers=4, mlp_dims=(400, 400), dropout=0.2, reg_weight=1e-5) 55 | elif name == 'singular-value-regularized-dcn-40-1e-3': 56 | return SingularValueRegularizedDCNv2(field_dims, embed_dim=40, num_layers=4, mlp_dims=(400, 400), dropout=0.2, reg_weight=1e-3) 57 | elif name == 'singular-value-regularized-dcn-40-1e-4': 58 | return SingularValueRegularizedDCNv2(field_dims, embed_dim=40, num_layers=4, mlp_dims=(400, 400), dropout=0.2, reg_weight=1e-4) 59 | elif name == 'singular-value-regularized-dcn-40-1e-5': 60 | return SingularValueRegularizedDCNv2(field_dims, embed_dim=40, num_layers=4, mlp_dims=(400, 400), dropout=0.2, reg_weight=1e-5) 61 | elif name == "me-si-dcn-2x10": 62 | return MultiESingleIDCNv2(field_dims, embed_dims=[10]*2, num_layers=4, mlp_dims=(400, 400), dropout=0.2) 63 | elif name == "me-si-dcn-3x10": 64 | return MultiESingleIDCNv2(field_dims, embed_dims=[10]*3, num_layers=4, mlp_dims=(400, 400), dropout=0.2) 65 | elif name == "me-si-dcn-4x10": 66 | return MultiESingleIDCNv2(field_dims, embed_dims=[10]*4, num_layers=4,
mlp_dims=(400, 400), dropout=0.2) 67 | elif name == "me-si-dcn-10x10": 68 | return MultiESingleIDCNv2(field_dims, embed_dims=[10]*10, num_layers=4, mlp_dims=(400, 400), dropout=0.2) 69 | elif name == 'rebuttal-restricted-weighted-mdcn-2x10': 70 | return WeightedRestrictedMultiDCN(field_dims, embed_dims=[10]*2, num_layers=4, mlp_dims=(400, 400), dropout=0.2) 71 | elif name == 'rebuttal-restricted-weighted-mdcn-3x10': 72 | return WeightedRestrictedMultiDCN(field_dims, embed_dims=[10]*3, num_layers=4, mlp_dims=(400, 400), dropout=0.2) 73 | elif name == 'rebuttal-restricted-weighted-mdcn-4x10': 74 | return WeightedRestrictedMultiDCN(field_dims, embed_dims=[10]*4, num_layers=4, mlp_dims=(400, 400), dropout=0.2) 75 | elif name == 'rebuttal-restricted-weighted-mdcn-10x10': 76 | return WeightedRestrictedMultiDCN(field_dims, embed_dims=[10]*10, num_layers=4, mlp_dims=(400, 400), dropout=0.2) 77 | ########################################## Appendix ######################################## 78 | 79 | ############################## 80 | # Main experiments started. # 81 | ############################## 82 | elif name == 'ffm': 83 | return FieldAwareFactorizationMachineModel(field_dims, embed_dim=8) 84 | elif name == 'dcn-10': 85 | return CrossNetworkV2Model(field_dims, embed_dim=10, num_layers=4, mlp_dims=(400, 400), dropout=0.2) 86 | elif name == 'dcn-20': 87 | return CrossNetworkV2Model(field_dims, embed_dim=20, num_layers=4, mlp_dims=(400, 400), dropout=0.2) 88 | elif name == 'dcn-30': 89 | return CrossNetworkV2Model(field_dims, embed_dim=30, num_layers=4, mlp_dims=(400, 400), dropout=0.2) 90 | elif name == 'dcn-40': 91 | return CrossNetworkV2Model(field_dims, embed_dim=40, num_layers=4, mlp_dims=(400, 400), dropout=0.2) 92 | elif name == 'dcn-100': 93 | return CrossNetworkV2Model(field_dims, embed_dim=100, num_layers=4, mlp_dims=(400, 400), dropout=0.2) 94 | 95 | elif name == "mdcn-2x10": 96 | return MultiDCNnew2(field_dims, embed_dims=[10]*2, num_layers=4, mlp_dims=(400, 400), dropout=0.2) 97 | elif name == "mdcn-3x10": 98 | return MultiDCNnew2(field_dims, embed_dims=[10]*3, num_layers=4, mlp_dims=(400, 400), dropout=0.2) 99 | elif name == "mdcn-4x10": 100 | return MultiDCNnew2(field_dims, embed_dims=[10]*4, num_layers=4, mlp_dims=(400, 400), dropout=0.2) 101 | elif name == "mdcn-10x10": 102 | return MultiDCNnew2(field_dims, embed_dims=[10]*10, num_layers=4, mlp_dims=(400, 400), dropout=0.2) 103 | 104 | elif name == "weight-norm-aligned-mdcn-2x10": 105 | return WeightNormAlignedMultiDCNnew2(field_dims, embed_dims=[10]*2, num_layers=4, mlp_dims=(400, 400), dropout=0.2, reg_weight=1.0) 106 | elif name == "weight-norm-aligned-mdcn-3x10": 107 | return WeightNormAlignedMultiDCNnew2(field_dims, embed_dims=[10]*3, num_layers=4, mlp_dims=(400, 400), dropout=0.2, reg_weight=1.0) 108 | elif name == "weight-norm-aligned-mdcn-4x10": 109 | return WeightNormAlignedMultiDCNnew2(field_dims, embed_dims=[10]*4, num_layers=4, mlp_dims=(400, 400), dropout=0.2, reg_weight=1.0) 110 | elif name == "weight-norm-aligned-mdcn-10x10": 111 | return WeightNormAlignedMultiDCNnew2(field_dims, embed_dims=[10]*10, num_layers=4, mlp_dims=(400, 400), dropout=0.2, reg_weight=1.0) 112 | 113 | elif name == "dnn-10": 114 | return DNNModel(field_dims, embed_dim=10, mlp_dims=(400, 400), dropout=0.2) 115 | elif name == "dnn-20": 116 | return DNNModel(field_dims, embed_dim=20, mlp_dims=(400, 400), dropout=0.2) 117 | elif name == "dnn-30": 118 | return DNNModel(field_dims, embed_dim=30, mlp_dims=(400, 400), dropout=0.2) 119 | elif name == 
"dnn-40": 120 | return DNNModel(field_dims, embed_dim=40, mlp_dims=(400, 400), dropout=0.2) 121 | elif name == "dnn-100": 122 | return DNNModel(field_dims, embed_dim=100, mlp_dims=(400, 400), dropout=0.2) 123 | 124 | elif name == "mdnnW-2x10": 125 | return MultiDNNModel(field_dims, embed_dims=[10]*2, mlp_dims=(400, 400), dropout=0.2) 126 | elif name == "mdnnW-3x10": 127 | return MultiDNNModel(field_dims, embed_dims=[10]*3, mlp_dims=(400, 400), dropout=0.2) 128 | elif name == "mdnnW-4x10": 129 | return MultiDNNModel(field_dims, embed_dims=[10]*4, mlp_dims=(400, 400), dropout=0.2) 130 | elif name == "mdnnW-10x10": 131 | return MultiDNNModel(field_dims, embed_dims=[10]*10, mlp_dims=(400, 400), dropout=0.2) 132 | 133 | elif name == 'restricted-weighted-dcn-10': 134 | return WeightedRestrictedCrossNetworkV2Model(field_dims, embed_dim=10, num_layers=4, mlp_dims=(400, 400), dropout=0.2) 135 | elif name == 'restricted-weighted-dcn-20': 136 | return WeightedRestrictedCrossNetworkV2Model(field_dims, embed_dim=20, num_layers=4, mlp_dims=(400, 400), dropout=0.2) 137 | elif name == 'restricted-weighted-dcn-30': 138 | return WeightedRestrictedCrossNetworkV2Model(field_dims, embed_dim=30, num_layers=4, mlp_dims=(400, 400), dropout=0.2) 139 | elif name == 'restricted-weighted-dcn-40': 140 | return WeightedRestrictedCrossNetworkV2Model(field_dims, embed_dim=40, num_layers=4, mlp_dims=(400, 400), dropout=0.2) 141 | elif name == 'restricted-weighted-dcn-100': 142 | return WeightedRestrictedCrossNetworkV2Model(field_dims, embed_dim=100, num_layers=4, mlp_dims=(400, 400), dropout=0.2) 143 | 144 | elif name == 'ipnn-10': 145 | return ProductNeuralNetworkModel(field_dims, embed_dim=10, mlp_dims=(400, 400), method='inner', dropout=0.2) 146 | elif name == 'ipnn-20': 147 | return ProductNeuralNetworkModel(field_dims, embed_dim=10, mlp_dims=(400, 400), method='inner', dropout=0.2) 148 | elif name == 'ipnn-30': 149 | return ProductNeuralNetworkModel(field_dims, embed_dim=10, mlp_dims=(400, 400), method='inner', dropout=0.2) 150 | elif name == 'ipnn-40': 151 | return ProductNeuralNetworkModel(field_dims, embed_dim=10, mlp_dims=(400, 400), method='inner', dropout=0.2) 152 | elif name == 'ipnn-100': 153 | return ProductNeuralNetworkModel(field_dims, embed_dim=10, mlp_dims=(400, 400), method='inner', dropout=0.2) 154 | 155 | elif name == 'multi-ipnn-2x10': 156 | return MultiPNNModel(field_dims, embed_dims=[10]*2, mlp_dims=(400, 400), method='inner', dropout=0.2) 157 | elif name == 'multi-ipnn-3x10': 158 | return MultiPNNModel(field_dims, embed_dims=[10]*3, mlp_dims=(400, 400), method='inner', dropout=0.2) 159 | elif name == 'multi-ipnn-4x10': 160 | return MultiPNNModel(field_dims, embed_dims=[10]*4, mlp_dims=(400, 400), method='inner', dropout=0.2) 161 | elif name == 'multi-ipnn-10x10': 162 | return MultiPNNModel(field_dims, embed_dims=[10]*10, mlp_dims=(400, 400), method='inner', dropout=0.2) 163 | 164 | elif name == 'nfwfm-50': 165 | return NFwFMModel(field_dims, embed_dim=50, mlp_dims=(400, 400), dropouts=(0.2, 0.2)) 166 | elif name == 'nfwfm-100': 167 | return NFwFMModel(field_dims, embed_dim=100, mlp_dims=(400, 400), dropouts=(0.2, 0.2)) 168 | elif name == 'nfwfm-150': 169 | return NFwFMModel(field_dims, embed_dim=150, mlp_dims=(400, 400), dropouts=(0.2, 0.2)) 170 | elif name == 'nfwfm-200': 171 | return NFwFMModel(field_dims, embed_dim=200, mlp_dims=(400, 400), dropouts=(0.2, 0.2)) 172 | elif name == 'nfwfm-500': 173 | return NFwFMModel(field_dims, embed_dim=500, mlp_dims=(400, 400), dropouts=(0.2, 0.2)) 174 | 175 
| elif name == 'multi-nfwfm-2x50': 176 | return MultiNFwFMModel(field_dims, embed_dims=[50]*2, mlp_dims=(400, 400), dropouts=(0.2, 0.2)) 177 | elif name == 'multi-nfwfm-3x50': 178 | return MultiNFwFMModel(field_dims, embed_dims=[50]*3, mlp_dims=(400, 400), dropouts=(0.2, 0.2)) 179 | elif name == 'multi-nfwfm-4x50': 180 | return MultiNFwFMModel(field_dims, embed_dims=[50]*4, mlp_dims=(400, 400), dropouts=(0.2, 0.2)) 181 | elif name == 'multi-nfwfm-10x50': 182 | return MultiNFwFMModel(field_dims, embed_dims=[50]*10, mlp_dims=(400, 400), dropouts=(0.2, 0.2)) 183 | 184 | elif name == 'xdfm-10': 185 | return XDeepFM(field_dims, embed_dim=10, mlp_dims=(400, 400), dropout=0.2, cross_layer_sizes=(16, 16)) 186 | elif name == 'xdfm-20': 187 | return XDeepFM(field_dims, embed_dim=20, mlp_dims=(400, 400), dropout=0.2, cross_layer_sizes=(16, 16)) 188 | elif name == 'xdfm-30': 189 | return XDeepFM(field_dims, embed_dim=30, mlp_dims=(400, 400), dropout=0.2, cross_layer_sizes=(16, 16)) 190 | elif name == 'xdfm-40': 191 | return XDeepFM(field_dims, embed_dim=40, mlp_dims=(400, 400), dropout=0.2, cross_layer_sizes=(16, 16)) 192 | elif name == 'xdfm-100': 193 | return XDeepFM(field_dims, embed_dim=100, mlp_dims=(400, 400), dropout=0.2, cross_layer_sizes=(16, 16)) 194 | 195 | elif name == 'multi-xdfm-2x10': 196 | return MultiXDeepFM(field_dims, embed_dims=[10]*2, mlp_dims=(400, 400), dropout=0.2, cross_layer_sizes=(16, 16)) 197 | elif name == 'multi-xdfm-3x10': 198 | return MultiXDeepFM(field_dims, embed_dims=[10]*3, mlp_dims=(400, 400), dropout=0.2, cross_layer_sizes=(16, 16)) 199 | elif name == 'multi-xdfm-4x10': 200 | return MultiXDeepFM(field_dims, embed_dims=[10]*4, mlp_dims=(400, 400), dropout=0.2, cross_layer_sizes=(16, 16)) 201 | elif name == 'multi-xdfm-10x10': 202 | return MultiXDeepFM(field_dims, embed_dims=[10]*10, mlp_dims=(400, 400), dropout=0.2, cross_layer_sizes=(16, 16)) 203 | 204 | elif name == 'finalmlp-10': 205 | return FinalMLP(field_dims, embed_dim=10, mlp_dims=(400, 400), fs_mlp_dims=(800, ), dropout=0.2) 206 | elif name == 'finalmlp-20': 207 | return FinalMLP(field_dims, embed_dim=20, mlp_dims=(400, 400), fs_mlp_dims=(800, ), dropout=0.2) 208 | elif name == 'finalmlp-30': 209 | return FinalMLP(field_dims, embed_dim=30, mlp_dims=(400, 400), fs_mlp_dims=(800, ), dropout=0.2) 210 | elif name == 'finalmlp-40': 211 | return FinalMLP(field_dims, embed_dim=40, mlp_dims=(400, 400), fs_mlp_dims=(800, ), dropout=0.2) 212 | elif name == 'finalmlp-100': 213 | return FinalMLP(field_dims, embed_dim=100, mlp_dims=(400, 400), fs_mlp_dims=(800, ), dropout=0.2) 214 | 215 | elif name == 'multi-finalmlp-2x10': 216 | return MultiFinalMLP(field_dims, embed_dims=[10]*2, mlp_dims=(400, 400), fs_mlp_dims=(800, ), dropout=0.2) 217 | elif name == 'multi-finalmlp-3x10': 218 | return MultiFinalMLP(field_dims, embed_dims=[10]*3, mlp_dims=(400, 400), fs_mlp_dims=(800, ), dropout=0.2) 219 | elif name == 'multi-finalmlp-4x10': 220 | return MultiFinalMLP(field_dims, embed_dims=[10]*4, mlp_dims=(400, 400), fs_mlp_dims=(800, ), dropout=0.2) 221 | elif name == 'multi-finalmlp-10x10': 222 | return MultiFinalMLP(field_dims, embed_dims=[10]*10, mlp_dims=(400, 400), fs_mlp_dims=(800, ), dropout=0.2) 223 | 224 | else: 225 | raise ValueError('unknown model name: ' + name) 226 | 227 | 228 | def train(model, optimizer, data_loader, criterion, device, epoch, accumulate_gradient=1, log_interval=500): 229 | model.train() 230 | batch_time = AverageMeter('Total Time', ':4.2f') 231 | losses = AverageMeter("Loss", ":5.4f") 232 | 
progress = ProgressMeter(len(data_loader), [batch_time, losses], prefix="Epoch: [{}]".format(epoch)) 233 | steps = 0 234 | optimizer.zero_grad() 235 | end = time.time() 236 | for i, (fields, target) in enumerate(data_loader): 237 | fields, target = fields.to(device), target.to(device) 238 | y = model(fields) 239 | loss = criterion(y, target.float()) 240 | losses.update(loss.item()) 241 | accumulate_loss = loss / accumulate_gradient 242 | accumulate_loss.backward() 243 | steps += 1 244 | if steps % accumulate_gradient == 0: 245 | optimizer.step() 246 | optimizer.zero_grad() 247 | batch_time.update(len(data_loader) * (time.time() - end)) 248 | end = time.time() 249 | if (i + 1) % log_interval == 0: 250 | progress.display(i + 1) 251 | optimizer.zero_grad() 252 | 253 | 254 | def test(model, data_loader, device, log_interval=500): 255 | model.eval() 256 | batch_time = AverageMeter('Total Time', ':4.2f') 257 | progress = ProgressMeter(len(data_loader), [batch_time], prefix="Test:") 258 | targets, predicts = list(), list() 259 | end = time.time() 260 | with torch.no_grad(): 261 | for i, (fields, target) in enumerate(data_loader): 262 | fields, target = fields.to(device), target.to(device) 263 | y = model(fields) 264 | targets.extend(target.tolist()) 265 | predicts.extend(y.tolist()) 266 | batch_time.update(len(data_loader) * (time.time() - end)) 267 | end = time.time() 268 | if (i + 1) % log_interval == 0: 269 | progress.display(i + 1) 270 | 271 | return roc_auc_score(targets, predicts) 272 | 273 | 274 | def main(dataset_name, 275 | dataset_path, 276 | model_name, 277 | epoch, 278 | learning_rate, 279 | batch_size, 280 | weight_decay, 281 | device, 282 | phase, 283 | seed, 284 | accumulate_gradient, 285 | ): 286 | logger = CompleteLogger(args.log, args.phase) 287 | print(args) 288 | device = torch.device(device) 289 | dataset = get_dataset(dataset_name, dataset_path) 290 | train_length = int(len(dataset) * 0.8) 291 | valid_length = int(len(dataset) * 0.1) 292 | test_length = len(dataset) - train_length - valid_length 293 | train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split( 294 | dataset, (train_length, valid_length, test_length), generator=torch.Generator().manual_seed(seed)) 295 | train_data_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=8) 296 | valid_data_loader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=8) 297 | test_data_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=8) 298 | 299 | 300 | model = get_model(model_name, dataset).to(device) 301 | print(model.state_dict().keys()) 302 | 303 | # count parameters 304 | for n, p in model.named_parameters(): 305 | print(n, p.numel()) 306 | print(sum(p.numel() for p in model.parameters() if p.requires_grad)) 307 | 308 | if phase == "train": 309 | if isinstance(model, (RestrictedCrossNetworkV2Model, 310 | WeightedRestrictedCrossNetworkV2Model, 311 | WeightedRestrictedMultiDCN, 312 | WeightNormAlignedMultiDCNnew2, 313 | SpaceSimilarityRegularizedMultiDCNnew2, 314 | SingularValueRegularizedDCNv2,)): 315 | criterion = CriterionWithLoss(torch.nn.BCELoss()) 316 | else: 317 | criterion = torch.nn.BCELoss() 318 | if hasattr(model, "get_parameters"): 319 | optimizer = torch.optim.Adam(params=model.get_parameters(learning_rate), weight_decay=weight_decay) 320 | else: 321 | optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=weight_decay) 322 | save_paths = [logger.get_checkpoint_path("best"), logger.get_checkpoint_path("optimizer")] 323 | 
if epoch == 0: 324 | early_stopper = EarlyStopper(num_trials=3, save_paths=save_paths) 325 | epoch = 100 326 | else: 327 | early_stopper = EarlyStopper(num_trials=epoch, save_paths=save_paths) 328 | auc = test(model, valid_data_loader, device) 329 | print(auc) 330 | for epoch_i in range(epoch): 331 | train(model, optimizer, train_data_loader, criterion, device, epoch_i, accumulate_gradient) 332 | torch.save(model.state_dict(), logger.get_checkpoint_path("latest")) 333 | auc = test(model, valid_data_loader, device) 334 | print('epoch:', epoch_i, 'validation: auc:', auc) 335 | if not early_stopper.is_continuable((model, optimizer), auc): 336 | print(f'validation: best auc: {early_stopper.best_accuracy}') 337 | break 338 | 339 | model.load_state_dict(torch.load(logger.get_checkpoint_path("best"))) 340 | 341 | auc = test(model, test_data_loader, device) 342 | print('test auc:', auc) 343 | 344 | logger.close() 345 | 346 | 347 | if __name__ == '__main__': 348 | import argparse 349 | 350 | parser = argparse.ArgumentParser() 351 | parser.add_argument('--dataset_name', default='criteo') 352 | parser.add_argument('--dataset_path', help='criteo/train.txt or avazu/train') 353 | parser.add_argument('--model_name', default='dcn-10') 354 | parser.add_argument('--epoch', type=int, default=0) 355 | parser.add_argument('--learning_rate', type=float, default=0.001) 356 | parser.add_argument('--batch_size', type=int, default=2048) 357 | parser.add_argument('--weight_decay', type=float, default=1e-6) 358 | parser.add_argument('--k_dim', type=int, default=None) 359 | parser.add_argument('--device', default='cuda:0') 360 | parser.add_argument('--log', default='logs/test') 361 | parser.add_argument('--phase', default='train', choices=['train', 'test']) 362 | parser.add_argument('--seed', type=int, default=0) 363 | parser.add_argument('--accumulate_gradient', '--acc_grad', type=int, default=1) 364 | args = parser.parse_args() 365 | main(args.dataset_name, 366 | args.dataset_path, 367 | args.model_name, 368 | args.epoch, 369 | args.learning_rate, 370 | args.batch_size, 371 | args.weight_decay, 372 | args.device, 373 | args.phase, 374 | args.seed, 375 | args.accumulate_gradient) 376 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.1 2 | scikit-learn 3 | numpy 4 | pandas 5 | tqdm 6 | lmdb -------------------------------------------------------------------------------- /torchfm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Multi-Embedding/55078571bf1922dced112505514958f85ca4804c/torchfm/__init__.py -------------------------------------------------------------------------------- /torchfm/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Multi-Embedding/55078571bf1922dced112505514958f85ca4804c/torchfm/dataset/__init__.py -------------------------------------------------------------------------------- /torchfm/dataset/amazon.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | import pandas as pd 4 | import torch.utils.data 5 | import random 6 | import math 7 | 8 | 9 | class Amazon(torch.utils.data.Dataset): 10 | """ 11 | """ 12 | 13 | def __init__(self, data, split, field_dims): 14 | data.sort_values('timestamp', 
inplace=True) 15 | 16 | timestamps = data.timestamp.to_numpy() 17 | if split == 'train': 18 | start_timestamp = 0 19 | end_timestamp = timestamps[int(0.8 * len(data))] 20 | elif split == 'val': 21 | start_timestamp = timestamps[int(0.8 * len(data))] 22 | end_timestamp = timestamps[int(0.9 * len(data))] 23 | else: 24 | start_timestamp = timestamps[int(0.9 * len(data))] 25 | end_timestamp = math.inf 26 | # print(split, start_timestamp, end_timestamp) 27 | 28 | def gen_neg(pos_list): 29 | neg = pos_list[0] 30 | while neg in pos_list: 31 | neg = random.randint(0, field_dims[1] - 1) 32 | return neg 33 | 34 | self.items = [] 35 | self.targets = [] 36 | for user_id, user_data in data.groupby("user_id"): 37 | if len(user_data) < 5: 38 | continue 39 | item_ids = user_data.to_numpy()[:, 1] 40 | timestamps = user_data.to_numpy()[:, 3] 41 | for item_id, timestamp in zip(item_ids, timestamps): 42 | if start_timestamp <= timestamp < end_timestamp: 43 | self.items.append([user_id, item_id]) 44 | self.targets.append(1) 45 | self.items.append([user_id, gen_neg(item_ids)]) 46 | self.targets.append(0) 47 | 48 | self.items = np.array(self.items).astype(np.int64) 49 | self.targets = np.array(self.targets).astype(np.float32) 50 | self.field_dims = field_dims 51 | self.user_field_idx = np.array((0,), dtype=np.int64) 52 | self.item_field_idx = np.array((1,), dtype=np.int64) 53 | 54 | def __len__(self): 55 | return self.targets.shape[0] 56 | 57 | def __getitem__(self, index): 58 | return self.items[index], self.targets[index] 59 | 60 | 61 | def get_amazon_datasets(root, domains): 62 | from sklearn.preprocessing import LabelEncoder 63 | label_encoders = { 64 | "user_id": LabelEncoder(), 65 | "item_id": LabelEncoder() 66 | } 67 | data = {} 68 | for domain in domains: 69 | # TODO add sorting preprocessing 70 | data[domain] = pd.read_csv(osp.join(root, "ratings_{}.csv".format(domain)), 71 | names=["user_id", "item_id", "label", "timestamp"]) 72 | all_data = pd.concat(data.values()) 73 | field_dims = [] 74 | for feat in label_encoders: 75 | label_encoders[feat].fit(all_data[feat]) 76 | field_dims.append(len(label_encoders[feat].classes_)) 77 | 78 | for domain in domains: 79 | for feat in label_encoders: 80 | data[domain][feat] = label_encoders[feat].transform(data[domain][feat]) 81 | 82 | return {domain: Amazon(data[domain], 'train', field_dims) for domain in domains},\ 83 | {domain: Amazon(data[domain], 'val', field_dims) for domain in domains}, \ 84 | {domain: Amazon(data[domain], 'test', field_dims) for domain in domains} 85 | -------------------------------------------------------------------------------- /torchfm/dataset/avazu.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import struct 3 | from collections import defaultdict 4 | from pathlib import Path 5 | 6 | import lmdb 7 | import numpy as np 8 | import torch.utils.data 9 | from tqdm import tqdm 10 | 11 | 12 | class AvazuDataset(torch.utils.data.Dataset): 13 | """ 14 | Avazu Click-Through Rate Prediction Dataset 15 | 16 | Dataset preparation 17 | Remove the infrequent features (appearing in fewer than threshold instances) and treat them as a single feature 18 | 19 | :param dataset_path: avazu train path 20 | :param cache_path: lmdb cache path 21 | :param rebuild_cache: If True, lmdb cache is refreshed 22 | :param min_threshold: infrequent feature threshold 23 | 24 | Reference 25 | https://www.kaggle.com/c/avazu-ctr-prediction 26 | """ 27 | 28 | def __init__(self, dataset_path=None, cache_path='.avazu', rebuild_cache=False,
min_threshold=4): 29 | self.NUM_FEATS = 22 30 | self.min_threshold = min_threshold 31 | if rebuild_cache or not Path(cache_path).exists(): 32 | shutil.rmtree(cache_path, ignore_errors=True) 33 | if dataset_path is None: 34 | raise ValueError('create cache: failed: dataset_path is None') 35 | self.__build_cache(dataset_path, cache_path) 36 | self.env = lmdb.open(cache_path, create=False, lock=False, readonly=True) 37 | with self.env.begin(write=False) as txn: 38 | self.length = txn.stat()['entries'] - 1 39 | self.field_dims = np.frombuffer(txn.get(b'field_dims'), dtype=np.uint32) 40 | 41 | def __getitem__(self, index): 42 | with self.env.begin(write=False) as txn: 43 | np_array = np.frombuffer( 44 | txn.get(struct.pack('>I', index)), dtype=np.uint32).astype(dtype=int) 45 | return np_array[1:], np_array[0] 46 | 47 | def __len__(self): 48 | return self.length 49 | 50 | def __build_cache(self, path, cache_path): 51 | feat_mapper, defaults = self.__get_feat_mapper(path) 52 | with lmdb.open(cache_path, map_size=int(1e11)) as env: 53 | field_dims = np.zeros(self.NUM_FEATS, dtype=np.uint32) 54 | for i, fm in feat_mapper.items(): 55 | field_dims[i - 1] = len(fm) + 1 56 | with env.begin(write=True) as txn: 57 | txn.put(b'field_dims', field_dims.tobytes()) 58 | for buffer in self.__yield_buffer(path, feat_mapper, defaults): 59 | with env.begin(write=True) as txn: 60 | for key, value in buffer: 61 | txn.put(key, value) 62 | 63 | def __get_feat_mapper(self, path): 64 | feat_cnts = defaultdict(lambda: defaultdict(int)) 65 | with open(path) as f: 66 | f.readline() 67 | pbar = tqdm(f, mininterval=1, smoothing=0.1) 68 | pbar.set_description('Create avazu dataset cache: counting features') 69 | for line in pbar: 70 | values = line.rstrip('\n').split(',') 71 | if len(values) != self.NUM_FEATS + 2: 72 | continue 73 | for i in range(1, self.NUM_FEATS + 1): 74 | feat_cnts[i][values[i + 1]] += 1 75 | feat_mapper = {i: {feat for feat, c in cnt.items() if c >= self.min_threshold} for i, cnt in feat_cnts.items()} 76 | feat_mapper = {i: {feat: idx for idx, feat in enumerate(cnt)} for i, cnt in feat_mapper.items()} 77 | defaults = {i: len(cnt) for i, cnt in feat_mapper.items()} 78 | return feat_mapper, defaults 79 | 80 | def __yield_buffer(self, path, feat_mapper, defaults, buffer_size=int(1e5)): 81 | item_idx = 0 82 | buffer = list() 83 | with open(path) as f: 84 | f.readline() 85 | pbar = tqdm(f, mininterval=1, smoothing=0.1) 86 | pbar.set_description('Create avazu dataset cache: setup lmdb') 87 | for line in pbar: 88 | values = line.rstrip('\n').split(',') 89 | if len(values) != self.NUM_FEATS + 2: 90 | continue 91 | np_array = np.zeros(self.NUM_FEATS + 1, dtype=np.uint32) 92 | np_array[0] = int(values[1]) 93 | for i in range(1, self.NUM_FEATS + 1): 94 | np_array[i] = feat_mapper[i].get(values[i+1], defaults[i]) 95 | buffer.append((struct.pack('>I', item_idx), np_array.tobytes())) 96 | item_idx += 1 97 | if item_idx % buffer_size == 0: 98 | yield buffer 99 | buffer.clear() 100 | yield buffer 101 | -------------------------------------------------------------------------------- /torchfm/dataset/criteo.py: -------------------------------------------------------------------------------- 1 | import math 2 | import shutil 3 | import struct 4 | from collections import defaultdict 5 | from functools import lru_cache 6 | from pathlib import Path 7 | 8 | import lmdb 9 | import numpy as np 10 | import torch.utils.data 11 | from tqdm import tqdm 12 | 13 | 14 | class CriteoDataset(torch.utils.data.Dataset): 15 | """ 16 | 
Criteo Display Advertising Challenge Dataset 17 | 18 | Data preparation: 19 | * Remove the infrequent features (appearing in fewer than threshold instances) and treat them as a single feature 20 | * Discretize numerical values by the log2 transformation proposed by the winner of the Criteo Competition 21 | 22 | :param dataset_path: criteo train.txt path. 23 | :param cache_path: lmdb cache path. 24 | :param rebuild_cache: If True, lmdb cache is refreshed. 25 | :param min_threshold: infrequent feature threshold. 26 | 27 | Reference: 28 | https://labs.criteo.com/2014/02/kaggle-display-advertising-challenge-dataset 29 | https://www.csie.ntu.edu.tw/~r01922136/kaggle-2014-criteo.pdf 30 | """ 31 | 32 | def __init__(self, dataset_path=None, cache_path='.criteo', rebuild_cache=False, min_threshold=10): 33 | self.NUM_FEATS = 39 34 | self.NUM_INT_FEATS = 13 35 | self.min_threshold = min_threshold 36 | if rebuild_cache or not Path(cache_path).exists(): 37 | shutil.rmtree(cache_path, ignore_errors=True) 38 | if dataset_path is None: 39 | raise ValueError('create cache: failed: dataset_path is None') 40 | self.__build_cache(dataset_path, cache_path) 41 | self.env = lmdb.open(cache_path, create=False, lock=False, readonly=True) 42 | with self.env.begin(write=False) as txn: 43 | self.length = txn.stat()['entries'] - 1 44 | self.field_dims = np.frombuffer(txn.get(b'field_dims'), dtype=np.uint32) 45 | 46 | def __getitem__(self, index): 47 | with self.env.begin(write=False) as txn: 48 | np_array = np.frombuffer( 49 | txn.get(struct.pack('>I', index)), dtype=np.uint32).astype(dtype=np.int32) 50 | return np_array[1:], np_array[0] 51 | 52 | def __len__(self): 53 | return self.length 54 | 55 | def __build_cache(self, path, cache_path): 56 | feat_mapper, defaults = self.__get_feat_mapper(path) 57 | with lmdb.open(cache_path, map_size=int(1e11)) as env: 58 | field_dims = np.zeros(self.NUM_FEATS, dtype=np.uint32) 59 | for i, fm in feat_mapper.items(): 60 | field_dims[i - 1] = len(fm) + 1 61 | with env.begin(write=True) as txn: 62 | txn.put(b'field_dims', field_dims.tobytes()) 63 | for buffer in self.__yield_buffer(path, feat_mapper, defaults): 64 | with env.begin(write=True) as txn: 65 | for key, value in buffer: 66 | txn.put(key, value) 67 | 68 | def __get_feat_mapper(self, path): 69 | feat_cnts = defaultdict(lambda: defaultdict(int)) 70 | with open(path) as f: 71 | pbar = tqdm(f, mininterval=1, smoothing=0.1) 72 | pbar.set_description('Create criteo dataset cache: counting features') 73 | for line in pbar: 74 | values = line.rstrip('\n').split('\t') 75 | if len(values) != self.NUM_FEATS + 1: 76 | continue 77 | for i in range(1, self.NUM_INT_FEATS + 1): 78 | feat_cnts[i][convert_numeric_feature(values[i])] += 1 79 | for i in range(self.NUM_INT_FEATS + 1, self.NUM_FEATS + 1): 80 | feat_cnts[i][values[i]] += 1 81 | feat_mapper = {i: {feat for feat, c in cnt.items() if c >= self.min_threshold} for i, cnt in feat_cnts.items()} 82 | feat_mapper = {i: {feat: idx for idx, feat in enumerate(cnt)} for i, cnt in feat_mapper.items()} 83 | defaults = {i: len(cnt) for i, cnt in feat_mapper.items()} 84 | return feat_mapper, defaults 85 | 86 | def __yield_buffer(self, path, feat_mapper, defaults, buffer_size=int(1e5)): 87 | item_idx = 0 88 | buffer = list() 89 | with open(path) as f: 90 | pbar = tqdm(f, mininterval=1, smoothing=0.1) 91 | pbar.set_description('Create criteo dataset cache: setup lmdb') 92 | for line in pbar: 93 | values = line.rstrip('\n').split('\t') 94 | if len(values) != self.NUM_FEATS + 1: 95 | continue 96 | np_array = np.zeros(self.NUM_FEATS + 1, dtype=np.uint32) 97 | np_array[0] = int(values[0]) 98 | for i in range(1, self.NUM_INT_FEATS + 1): 99 | np_array[i] = feat_mapper[i].get(convert_numeric_feature(values[i]), defaults[i]) 100 | for i in range(self.NUM_INT_FEATS + 1, self.NUM_FEATS + 1): 101 | np_array[i] = feat_mapper[i].get(values[i], defaults[i]) 102 | buffer.append((struct.pack('>I', item_idx), np_array.tobytes())) 103 | item_idx += 1 104 | if item_idx % buffer_size == 0: 105 | yield buffer 106 | buffer.clear() 107 | yield buffer 108 | 109 | 110 | @lru_cache(maxsize=None) 111 | def convert_numeric_feature(val: str): 112 | if val == '': 113 | return 'NULL' 114 | v = int(val) 115 | if v > 2: 116 | return str(int(math.log(v) ** 2)) 117 | else: 118 | return str(v - 2) 119 | -------------------------------------------------------------------------------- /torchfm/dataset/movielens.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import torch.utils.data 4 | 5 | 6 | class MovieLens20MDataset(torch.utils.data.Dataset): 7 | """ 8 | MovieLens 20M Dataset 9 | 10 | Data preparation 11 | treat samples with a rating of 3 or less as negative samples 12 | 13 | :param dataset_path: MovieLens dataset path 14 | 15 | Reference: 16 | https://grouplens.org/datasets/movielens 17 | """ 18 | 19 | def __init__(self, dataset_path, sep=',', engine='c', header='infer'): 20 | data = pd.read_csv(dataset_path, sep=sep, engine=engine, header=header).to_numpy()[:, :3] 21 | self.items = data[:, :2].astype(np.int64) - 1 # -1 because ID begins from 1 22 | self.targets = self.__preprocess_target(data[:, 2]).astype(np.float32) 23 | self.field_dims = np.max(self.items, axis=0) + 1 24 | self.user_field_idx = np.array((0, ), dtype=np.int64) 25 | self.item_field_idx = np.array((1,), dtype=np.int64) 26 | 27 | def __len__(self): 28 | return self.targets.shape[0] 29 | 30 | def __getitem__(self, index): 31 | return self.items[index], self.targets[index] 32 | 33 | def __preprocess_target(self, target): 34 | target[target <= 3] = 0 35 | target[target > 3] = 1 36 | return target 37 | 38 | 39 | class MovieLens1MDataset(MovieLens20MDataset): 40 | """ 41 | MovieLens 1M Dataset 42 | 43 | Data preparation 44 | treat samples with a rating of 3 or less as negative samples 45 | 46 | :param dataset_path: MovieLens dataset path 47 | 48 | Reference: 49 | https://grouplens.org/datasets/movielens 50 | """ 51 | 52 | def __init__(self, dataset_path): 53 | super().__init__(dataset_path, sep='::', engine='python', header=None) 54 | -------------------------------------------------------------------------------- /torchfm/layer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | class FeaturesLinear(torch.nn.Module): 7 | 8 | def __init__(self, field_dims, output_dim=1): 9 | super().__init__() 10 | self.fc = torch.nn.Embedding(sum(field_dims), output_dim) 11 | self.bias = torch.nn.Parameter(torch.zeros((output_dim,))) 12 | self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=int) 13 | 14 | def forward(self, x): 15 | """ 16 | :param x: Long tensor of size ``(batch_size, num_fields)`` 17 | """ 18 | x = x + x.new_tensor(self.offsets).unsqueeze(0) 19 | return torch.sum(self.fc(x), dim=1) + self.bias 20 | 21 | 22 | class FeaturesEmbedding(torch.nn.Module): 23 | 24 | def __init__(self, field_dims, embed_dim): 25 | super().__init__() 26 |
self.embedding = torch.nn.Embedding(sum(field_dims), embed_dim) 27 | self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int32) 28 | torch.nn.init.xavier_uniform_(self.embedding.weight.data) 29 | 30 | def forward(self, x): 31 | """ 32 | :param x: Long tensor of size ``(batch_size, num_fields)`` 33 | """ 34 | x = x + x.new_tensor(self.offsets).unsqueeze(0) 35 | return self.embedding(x) 36 | 37 | 38 | class FieldAwareFactorizationMachine(torch.nn.Module): 39 | 40 | def __init__(self, field_dims, embed_dim): 41 | super().__init__() 42 | self.num_fields = len(field_dims) 43 | self.embeddings = torch.nn.ModuleList([ 44 | torch.nn.Embedding(sum(field_dims), embed_dim) for _ in range(self.num_fields) 45 | ]) 46 | self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64) 47 | for embedding in self.embeddings: 48 | torch.nn.init.xavier_uniform_(embedding.weight.data) 49 | 50 | def forward(self, x): 51 | """ 52 | :param x: Long tensor of size ``(batch_size, num_fields)`` 53 | """ 54 | x = x + x.new_tensor(self.offsets).unsqueeze(0) 55 | xs = [self.embeddings[i](x) for i in range(self.num_fields)] 56 | ix = list() 57 | for i in range(self.num_fields - 1): 58 | for j in range(i + 1, self.num_fields): 59 | ix.append(xs[j][:, i] * xs[i][:, j]) 60 | ix = torch.stack(ix, dim=1) 61 | return ix 62 | 63 | 64 | class FactorizationMachine(torch.nn.Module): 65 | 66 | def __init__(self, reduce_sum=True): 67 | super().__init__() 68 | self.reduce_sum = reduce_sum 69 | 70 | def forward(self, x): 71 | """ 72 | :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)`` 73 | """ 74 | square_of_sum = torch.sum(x, dim=1) ** 2 75 | sum_of_square = torch.sum(x ** 2, dim=1) 76 | ix = square_of_sum - sum_of_square 77 | if self.reduce_sum: 78 | ix = torch.sum(ix, dim=1, keepdim=True) 79 | return 0.5 * ix 80 | 81 | 82 | class MultiLayerPerceptron(torch.nn.Module): 83 | 84 | def __init__(self, input_dim, embed_dims, dropout, output_layer=True, num_tasks=1): 85 | super().__init__() 86 | layers = list() 87 | for embed_dim in embed_dims: 88 | layers.append(torch.nn.Linear(input_dim, embed_dim)) 89 | layers.append(torch.nn.BatchNorm1d(embed_dim)) 90 | layers.append(torch.nn.ReLU()) 91 | layers.append(torch.nn.Dropout(p=dropout)) 92 | input_dim = embed_dim 93 | if output_layer: 94 | layers.append(torch.nn.Linear(input_dim, num_tasks)) 95 | self.mlp = torch.nn.Sequential(*layers) 96 | 97 | def forward(self, x): 98 | """ 99 | :param x: Float tensor of size ``(batch_size, embed_dim)`` 100 | """ 101 | return self.mlp(x) 102 | 103 | 104 | class InnerProductNetwork(torch.nn.Module): 105 | 106 | def forward(self, x): 107 | """ 108 | :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)`` 109 | """ 110 | num_fields = x.shape[1] 111 | row, col = list(), list() 112 | for i in range(num_fields - 1): 113 | for j in range(i + 1, num_fields): 114 | row.append(i), col.append(j) 115 | return torch.sum(x[:, row] * x[:, col], dim=2) 116 | 117 | 118 | class OuterProductNetwork(torch.nn.Module): 119 | 120 | def __init__(self, num_fields, embed_dim, kernel_type='mat'): 121 | super().__init__() 122 | num_ix = num_fields * (num_fields - 1) // 2 123 | if kernel_type == 'mat': 124 | kernel_shape = embed_dim, num_ix, embed_dim 125 | elif kernel_type == 'vec': 126 | kernel_shape = num_ix, embed_dim 127 | elif kernel_type == 'num': 128 | kernel_shape = num_ix, 1 129 | else: 130 | raise ValueError('unknown kernel type: ' + kernel_type) 131 | self.kernel_type = kernel_type 132 | self.kernel = torch.nn.Parameter(torch.zeros(kernel_shape)) 133 | torch.nn.init.xavier_uniform_(self.kernel.data) 134 | 135 | def forward(self, x): 136 | """ 137 | :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)`` 138 | """ 139 | num_fields = x.shape[1] 140 | row, col = list(), list() 141 | for i in range(num_fields - 1): 142 | for j in range(i + 1, num_fields): 143 | row.append(i), col.append(j) 144 | p, q = x[:, row], x[:, col] 145 | if self.kernel_type == 'mat': 146 | kp = torch.sum(p.unsqueeze(1) * self.kernel, dim=-1).permute(0, 2, 1) 147 | return torch.sum(kp * q, -1) 148 | else: 149 | return torch.sum(p * q * self.kernel.unsqueeze(0), -1) 150 | 151 | 152 | class CrossNetwork(torch.nn.Module): 153 | 154 | def __init__(self, input_dim, num_layers): 155 | super().__init__() 156 | self.num_layers = num_layers 157 | self.w = torch.nn.ModuleList([ 158 | torch.nn.Linear(input_dim, 1, bias=False) for _ in range(num_layers) 159 | ]) 160 | self.b = torch.nn.ParameterList([ 161 | torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers) 162 | ]) 163 | 164 | def forward(self, x): 165 | """ 166 | :param x: Float tensor of size ``(batch_size, num_fields * embed_dim)`` 167 | """ 168 | x0 = x 169 | for i in range(self.num_layers): 170 | xw = self.w[i](x) 171 | x = x0 * xw + self.b[i] + x 172 | return x 173 | 174 | 175 | class AttentionalFactorizationMachine(torch.nn.Module): 176 | 177 | def __init__(self, embed_dim, attn_size, dropouts): 178 | super().__init__() 179 | self.attention = torch.nn.Linear(embed_dim, attn_size) 180 | self.projection = torch.nn.Linear(attn_size, 1) 181 | self.fc = torch.nn.Linear(embed_dim, 1) 182 | self.dropouts = dropouts 183 | 184 | def forward(self, x): 185 | """ 186 | :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)`` 187 | """ 188 | num_fields = x.shape[1] 189 | row, col = list(), list() 190 | for i in range(num_fields - 1): 191 | for j in range(i + 1, num_fields): 192 | row.append(i), col.append(j) 193 | p, q = x[:, row], x[:, col] 194 | inner_product = p * q 195 | attn_scores = F.relu(self.attention(inner_product)) 196 | attn_scores = F.softmax(self.projection(attn_scores), dim=1) 197 | attn_scores = F.dropout(attn_scores, p=self.dropouts[0], training=self.training) 198 | attn_output = torch.sum(attn_scores * inner_product, dim=1) 199 | attn_output = F.dropout(attn_output, p=self.dropouts[1], training=self.training) 200 | return self.fc(attn_output) 201 | 202 | 203 | class CompressedInteractionNetwork(torch.nn.Module): 204 | 205 | def __init__(self, input_dim, cross_layer_sizes, split_half=True): 206 | super().__init__() 207 | self.num_layers = len(cross_layer_sizes) 208 | self.split_half = split_half 209 | self.conv_layers = torch.nn.ModuleList() 210 | prev_dim, fc_input_dim = input_dim, 0 211 | for i in range(self.num_layers): 212 | cross_layer_size = cross_layer_sizes[i] 213 | self.conv_layers.append(torch.nn.Conv1d(input_dim * prev_dim, cross_layer_size, 1, 214 | stride=1, dilation=1, bias=True)) 215 | # (B, N x C, K) -> (B, C', K) 216 | if self.split_half and i != self.num_layers - 1: 217 | cross_layer_size //= 2 218 | prev_dim = cross_layer_size 219 | fc_input_dim += prev_dim 220 | self.fc = torch.nn.Linear(fc_input_dim, 1) 221 | 222 | def forward(self, x): 223 | """ 224 | :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)`` 225 | """ 226 | xs = list() 227 | x0, h = x.unsqueeze(2), x 228 | # x0: (B, N, 1, K) 229 | # h: (B, C, K) 230 | for i in range(self.num_layers): 231 | x = x0 * h.unsqueeze(1) 232
| batch_size, f0_dim, fin_dim, embed_dim = x.shape 233 | x = x.view(batch_size, f0_dim * fin_dim, embed_dim) 234 | x = F.relu(self.conv_layers[i](x)) 235 | if self.split_half and i != self.num_layers - 1: 236 | x, h = torch.split(x, x.shape[1] // 2, dim=1) 237 | else: 238 | h = x 239 | xs.append(x) 240 | return self.fc(torch.sum(torch.cat(xs, dim=1), 2)) 241 | 242 | 243 | class AnovaKernel(torch.nn.Module): 244 | 245 | def __init__(self, order, reduce_sum=True): 246 | super().__init__() 247 | self.order = order 248 | self.reduce_sum = reduce_sum 249 | 250 | def forward(self, x): 251 | """ 252 | :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)`` 253 | """ 254 | batch_size, num_fields, embed_dim = x.shape 255 | a_prev = torch.ones((batch_size, num_fields + 1, embed_dim), dtype=torch.float).to(x.device) 256 | for t in range(self.order): 257 | a = torch.zeros((batch_size, num_fields + 1, embed_dim), dtype=torch.float).to(x.device) 258 | a[:, t+1:, :] += x[:, t:, :] * a_prev[:, t:-1, :] 259 | a = torch.cumsum(a, dim=1) 260 | a_prev = a 261 | if self.reduce_sum: 262 | return torch.sum(a[:, -1, :], dim=-1, keepdim=True) 263 | else: 264 | return a[:, -1, :] 265 | -------------------------------------------------------------------------------- /torchfm/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Multi-Embedding/55078571bf1922dced112505514958f85ca4804c/torchfm/model/__init__.py -------------------------------------------------------------------------------- /torchfm/model/afi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from torchfm.layer import FeaturesEmbedding, FeaturesLinear, MultiLayerPerceptron 5 | 6 | 7 | class AutomaticFeatureInteractionModel(torch.nn.Module): 8 | """ 9 | A pytorch implementation of AutoInt. 10 | 11 | Reference: 12 | W Song, et al. AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks, 2018. 
13 | """ 14 | 15 | def __init__(self, field_dims, embed_dim, atten_embed_dim, num_heads, num_layers, mlp_dims, dropouts, has_residual=True): 16 | super().__init__() 17 | self.num_fields = len(field_dims) 18 | self.linear = FeaturesLinear(field_dims) 19 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 20 | self.atten_embedding = torch.nn.Linear(embed_dim, atten_embed_dim) 21 | self.embed_output_dim = len(field_dims) * embed_dim 22 | self.atten_output_dim = len(field_dims) * atten_embed_dim 23 | self.has_residual = has_residual 24 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropouts[1]) 25 | self.self_attns = torch.nn.ModuleList([ 26 | torch.nn.MultiheadAttention(atten_embed_dim, num_heads, dropout=dropouts[0]) for _ in range(num_layers) 27 | ]) 28 | self.attn_fc = torch.nn.Linear(self.atten_output_dim, 1) 29 | if self.has_residual: 30 | self.V_res_embedding = torch.nn.Linear(embed_dim, atten_embed_dim) 31 | 32 | def forward(self, x): 33 | """ 34 | :param x: Long tensor of size ``(batch_size, num_fields)`` 35 | """ 36 | embed_x = self.embedding(x) 37 | atten_x = self.atten_embedding(embed_x) 38 | cross_term = atten_x.transpose(0, 1) 39 | for self_attn in self.self_attns: 40 | cross_term, _ = self_attn(cross_term, cross_term, cross_term) 41 | cross_term = cross_term.transpose(0, 1) 42 | if self.has_residual: 43 | V_res = self.V_res_embedding(embed_x) 44 | cross_term += V_res 45 | cross_term = F.relu(cross_term).contiguous().view(-1, self.atten_output_dim) 46 | x = self.linear(x) + self.attn_fc(cross_term) + self.mlp(embed_x.view(-1, self.embed_output_dim)) 47 | return torch.sigmoid(x.squeeze(1)) 48 | 49 | 50 | class AttentionModule(torch.nn.Module): 51 | 52 | def __init__(self, atten_embed_dim, num_heads, dropout, residual=True, layer_norm=True): 53 | super().__init__() 54 | self.atten = torch.nn.MultiheadAttention(atten_embed_dim, num_heads, dropout=dropout) 55 | self.layer_norm = torch.nn.LayerNorm(atten_embed_dim) if layer_norm else None 56 | self.residual = residual 57 | 58 | def forward(self, x): 59 | if self.residual: 60 | x = x + self.atten(x, x, x)[0] 61 | else: 62 | x = self.atten(x, x, x)[0] 63 | if self.layer_norm: 64 | x = self.layer_norm(x) 65 | return x 66 | 67 | class MultiHeadSelfAttentionInteraction(torch.nn.Module): 68 | """ 69 | Multi-head self-attention only 70 | """ 71 | def __init__(self, embed_dim, atten_embed_dim, num_heads, num_layers, dropout, residual=True, layer_norm=True): 72 | super().__init__() 73 | self.atten_embedding = torch.nn.Linear(embed_dim, atten_embed_dim) 74 | self.attens = torch.nn.ModuleList([ 75 | AttentionModule(atten_embed_dim, num_heads, dropout, residual=residual, layer_norm=layer_norm) 76 | for _ in range(num_layers) 77 | ]) 78 | 79 | def forward(self, emb): 80 | atten_emb = self.atten_embedding(emb) 81 | atten_emb = atten_emb.transpose(0, 1) # batch second: (N, B, H) 82 | for atten in self.attens: 83 | atten_emb = atten(atten_emb) 84 | atten_emb = atten_emb.transpose(0, 1) # batch first: (B, N, H) 85 | atten_emb = F.relu(atten_emb) 86 | return atten_emb 87 | 88 | 89 | class AutoInt(torch.nn.Module): 90 | """ 91 | AutoInt w/o linear 92 | """ 93 | def __init__(self, field_dims, embed_dim, atten_embed_dim, num_heads, num_layers, dropouts, mlp_dims, residual=True, layer_norm=True): 94 | super().__init__() 95 | self.num_fields = len(field_dims) 96 | self.embed_output_dim = len(field_dims) * embed_dim 97 | self.atten_output_dim = len(field_dims) * atten_embed_dim 98 | self.embedding = FeaturesEmbedding(field_dims, 
embed_dim) 99 | self.atten = MultiHeadSelfAttentionInteraction(embed_dim, atten_embed_dim, num_heads, num_layers, dropouts[0], residual=residual, layer_norm=layer_norm) 100 | self.atten_post = torch.nn.Linear(self.atten_output_dim, 1) 101 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropouts[1]) if mlp_dims else None 102 | 103 | def forward(self, x): 104 | emb = self.embedding(x) 105 | x = self.atten_post(self.atten(emb).flatten(1)) 106 | if self.mlp: 107 | x += self.mlp(emb.flatten(1)) 108 | return torch.sigmoid(x.squeeze(1)) 109 | 110 | 111 | class MultiAutoInt(torch.nn.Module): 112 | 113 | def __init__(self, field_dims, embed_dims, atten_embed_dim, num_heads, num_layers, dropouts, mlp_dims, residual=True, layer_norm=True): 114 | super().__init__() 115 | self.num_fields = len(field_dims) 116 | self.embed_output_dims = [len(field_dims) * embed_dim for embed_dim in embed_dims] 117 | self.atten_output_dim = len(field_dims) * atten_embed_dim 118 | self.embeddings = torch.nn.ModuleList([FeaturesEmbedding(field_dims, embed_dim) for embed_dim in embed_dims]) 119 | self.attens = torch.nn.ModuleList([ 120 | MultiHeadSelfAttentionInteraction(embed_dim, atten_embed_dim, num_heads, num_layers, dropouts[0], residual=residual, layer_norm=layer_norm) 121 | for embed_dim in embed_dims 122 | ]) 123 | self.atten_post = torch.nn.Linear(self.atten_output_dim, 1) 124 | self.mlps = torch.nn.ModuleList([ 125 | MultiLayerPerceptron(embed_output_dim, mlp_dims[:1], dropouts[1], output_layer=False) 126 | for embed_output_dim in self.embed_output_dims 127 | ]) if mlp_dims else None 128 | self.mlp_post = MultiLayerPerceptron(mlp_dims[0], mlp_dims[1:], dropouts[1]) if mlp_dims else None 129 | 130 | def forward(self, x): 131 | embs = [embedding(x) for embedding in self.embeddings] 132 | atten_hidden = torch.stack([atten(emb) for atten, emb in zip(self.attens, embs)], dim=-1).mean(dim=-1) 133 | x = self.atten_post(atten_hidden.flatten(1)) 134 | if self.mlps: 135 | mlp_hidden = torch.stack([mlp(emb.flatten(1)) for mlp, emb in zip(self.mlps, embs)], dim=-1).mean(dim=-1) 136 | x += self.mlp_post(mlp_hidden) 137 | return torch.sigmoid(x.squeeze(1)) 138 | -------------------------------------------------------------------------------- /torchfm/model/afm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FeaturesEmbedding, FeaturesLinear, AttentionalFactorizationMachine 4 | 5 | 6 | class AttentionalFactorizationMachineModel(torch.nn.Module): 7 | """ 8 | A pytorch implementation of Attentional Factorization Machine. 9 | 10 | Reference: 11 | J Xiao, et al. Attentional Factorization Machines: Learning the Weight of Feature Interactions via Attention Networks, 2017. 
12 | """ 13 | 14 | def __init__(self, field_dims, embed_dim, attn_size, dropouts): 15 | super().__init__() 16 | self.num_fields = len(field_dims) 17 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 18 | self.linear = FeaturesLinear(field_dims) 19 | self.afm = AttentionalFactorizationMachine(embed_dim, attn_size, dropouts) 20 | 21 | def forward(self, x): 22 | """ 23 | :param x: Long tensor of size ``(batch_size, num_fields)`` 24 | """ 25 | x = self.linear(x) + self.afm(self.embedding(x)) 26 | return torch.sigmoid(x.squeeze(1)) 27 | -------------------------------------------------------------------------------- /torchfm/model/afn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from torchfm.layer import FeaturesEmbedding, FeaturesLinear, MultiLayerPerceptron 6 | 7 | class LNN(torch.nn.Module): 8 | """ 9 | A pytorch implementation of LNN layer 10 | Input shape 11 | - A 3D tensor with shape: ``(batch_size,field_size,embedding_size)``. 12 | Output shape 13 | - 2D tensor with shape:``(batch_size,LNN_dim*embedding_size)``. 14 | Arguments 15 | - **in_features** : Embedding of feature. 16 | - **num_fields**: int.The field size of feature. 17 | - **LNN_dim**: int.The number of Logarithmic neuron. 18 | - **bias**: bool.Whether or not use bias in LNN. 19 | """ 20 | def __init__(self, num_fields, embed_dim, LNN_dim, bias=False): 21 | super(LNN, self).__init__() 22 | self.num_fields = num_fields 23 | self.embed_dim = embed_dim 24 | self.LNN_dim = LNN_dim 25 | self.lnn_output_dim = LNN_dim * embed_dim 26 | self.weight = torch.nn.Parameter(torch.Tensor(LNN_dim, num_fields)) 27 | if bias: 28 | self.bias = torch.nn.Parameter(torch.Tensor(LNN_dim, embed_dim)) 29 | else: 30 | self.register_parameter('bias', None) 31 | self.reset_parameters() 32 | 33 | def reset_parameters(self): 34 | stdv = 1. / math.sqrt(self.weight.size(1)) 35 | self.weight.data.uniform_(-stdv, stdv) 36 | if self.bias is not None: 37 | self.bias.data.uniform_(-stdv, stdv) 38 | 39 | def forward(self, x): 40 | """ 41 | :param x: Long tensor of size ``(batch_size, num_fields, embedding_size)`` 42 | """ 43 | embed_x_abs = torch.abs(x) # Computes the element-wise absolute value of the given input tensor. 44 | embed_x_afn = torch.add(embed_x_abs, 1e-7) 45 | # Logarithmic Transformation 46 | embed_x_log = torch.log1p(embed_x_afn) # torch.log1p and torch.expm1 47 | lnn_out = torch.matmul(self.weight, embed_x_log) 48 | if self.bias is not None: 49 | lnn_out += self.bias 50 | lnn_exp = torch.expm1(lnn_out) 51 | output = F.relu(lnn_exp).contiguous().view(-1, self.lnn_output_dim) 52 | return output 53 | 54 | 55 | 56 | 57 | 58 | 59 | class AdaptiveFactorizationNetwork(torch.nn.Module): 60 | """ 61 | A pytorch implementation of AFN. 62 | 63 | Reference: 64 | Cheng W, et al. Adaptive Factorization Network: Learning Adaptive-Order Feature Interactions, 2019. 
65 | """ 66 | def __init__(self, field_dims, embed_dim, LNN_dim, mlp_dims, dropouts): 67 | super().__init__() 68 | self.num_fields = len(field_dims) 69 | self.linear = FeaturesLinear(field_dims) # Linear 70 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) # Embedding 71 | self.LNN_dim = LNN_dim 72 | self.LNN_output_dim = self.LNN_dim * embed_dim 73 | self.LNN = LNN(self.num_fields, embed_dim, LNN_dim) 74 | self.mlp = MultiLayerPerceptron(self.LNN_output_dim, mlp_dims, dropouts[0]) 75 | 76 | def forward(self, x): 77 | """ 78 | :param x: Long tensor of size ``(batch_size, num_fields)`` 79 | """ 80 | embed_x = self.embedding(x) 81 | lnn_out = self.LNN(embed_x) 82 | x = self.linear(x) + self.mlp(lnn_out) 83 | return torch.sigmoid(x.squeeze(1)) 84 | 85 | -------------------------------------------------------------------------------- /torchfm/model/awesome.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FeaturesEmbedding, MultiLayerPerceptron 4 | from torchfm.model.dcnv2 import CrossNetworkV2, CrossNetworkV2Model 5 | 6 | 7 | class MultiDCNnew2(torch.nn.Module): 8 | 9 | def __init__(self, field_dims, embed_dims, num_layers, mlp_dims, dropout): 10 | super().__init__() 11 | self.num_fields = len(field_dims) 12 | self.embeddings = torch.nn.ModuleList([FeaturesEmbedding(field_dims, embed_dim) for embed_dim in embed_dims]) 13 | self.embed_output_dims = [len(field_dims) * embed_dim for embed_dim in embed_dims] 14 | self.cns = torch.nn.ModuleList([ 15 | torch.nn.Sequential( 16 | CrossNetworkV2(embed_output_dim, num_layers, self.num_fields, embed_dim), 17 | MultiLayerPerceptron(embed_output_dim, mlp_dims[:1], dropout, output_layer=False) 18 | ) for embed_dim, embed_output_dim in zip(embed_dims, self.embed_output_dims)]) 19 | self.mlp = MultiLayerPerceptron(mlp_dims[0], mlp_dims[1:], dropout) 20 | 21 | def forward(self, x): 22 | x_l1 = torch.stack([cn(embedding(x).view(-1, embed_output_dim)) 23 | for embedding, embed_output_dim, cn in 24 | zip(self.embeddings, self.embed_output_dims, self.cns)], dim=-1) 25 | x_l1 = x_l1.mean(dim=-1) 26 | p = self.mlp(x_l1) 27 | return torch.sigmoid(p.squeeze(1)) 28 | 29 | 30 | class MultiESingleIDCNv2(torch.nn.Module): 31 | 32 | def __init__(self, field_dims, embed_dims, num_layers, mlp_dims, dropout): 33 | super().__init__() 34 | self.num_fields = len(field_dims) 35 | self.embeddings = torch.nn.ModuleList([FeaturesEmbedding(field_dims, embed_dim) for embed_dim in embed_dims]) 36 | self.embed_output_dims = [len(field_dims) * embed_dim for embed_dim in embed_dims] 37 | assert all([embed_output_dim == self.embed_output_dims[0] for embed_output_dim in self.embed_output_dims]) 38 | self.cn = torch.nn.Sequential( 39 | CrossNetworkV2(self.embed_output_dims[0], num_layers, self.num_fields, embed_dims[0]), 40 | MultiLayerPerceptron(self.embed_output_dims[0], mlp_dims[:1], dropout, output_layer=False) 41 | ) 42 | self.mlp = MultiLayerPerceptron(mlp_dims[0], mlp_dims[1:], dropout) 43 | 44 | def forward(self, x): 45 | x_l1 = torch.stack([self.cn(embedding(x).view(-1, embed_output_dim)) 46 | for embedding, embed_output_dim in 47 | zip(self.embeddings, self.embed_output_dims)], dim=-1) 48 | x_l1 = x_l1.mean(dim=-1) 49 | p = self.mlp(x_l1) 50 | return torch.sigmoid(p.squeeze(1)) 51 | 52 | 53 | class WeightNormAlignedMultiDCNnew2(MultiDCNnew2): 54 | 55 | def __init__(self, field_dims, embed_dims, num_layers, mlp_dims, dropout, reg_weight=0.0): 56 | assert all([embed_dim == embed_dims[0] 
for embed_dim in embed_dims]) 57 | super().__init__(field_dims, embed_dims, num_layers, mlp_dims, dropout) 58 | self.embed_dim = embed_dims[0] 59 | self.reg_weight = reg_weight 60 | 61 | def forward(self, x): 62 | output = super().forward(x) 63 | if self.training: 64 | W_all = torch.stack([torch.stack(list(cn[0].W), dim=0) for cn in self.cns], dim=0) # (num_embed, num_layer, ND, ND) 65 | W_all = W_all.reshape(W_all.shape[0], W_all.shape[1], self.num_fields, self.embed_dim, self.num_fields, self.embed_dim) 66 | W_norm_all = (W_all ** 2).sum(dim=(3, 5)) # (num_embed, num_layer, N, N) 67 | W_norm_mean = W_norm_all.mean(dim=0, keepdim=True) + 1e-6 # (1, num_layer, N, N) 68 | W_norm_variance_normalized = (W_norm_all - W_norm_mean).var(dim=0, unbiased=False) # (num_layer, N, N) 69 | reg_loss = W_norm_variance_normalized.mean() 70 | return output, self.reg_weight * reg_loss 71 | else: 72 | return output 73 | 74 | 75 | class SpaceSimilarityRegularizedMultiDCNnew2(MultiDCNnew2): 76 | 77 | def __init__(self, field_dims, embed_dims, num_layers, mlp_dims, dropout, reg_weight=0.0): 78 | assert all([embed_dim == embed_dims[0] for embed_dim in embed_dims]) 79 | super().__init__(field_dims, embed_dims, num_layers, mlp_dims, dropout) 80 | self.embed_dim = embed_dims[0] 81 | self.reg_weight = reg_weight 82 | 83 | def forward(self, x): 84 | es = [embedding(x) for embedding in self.embeddings] 85 | x_l1 = torch.stack([cn(e.view(-1, embed_output_dim)) 86 | for e, embed_output_dim, cn in 87 | zip(es, self.embed_output_dims, self.cns)], dim=-1) 88 | x_l1 = x_l1.mean(dim=-1) 89 | p = self.mlp(x_l1) 90 | output = torch.sigmoid(p.squeeze(1)) 91 | if self.training: 92 | sims = [] 93 | for i, e_i in enumerate(es): 94 | for e_j in es[:i]: 95 | simss = [] 96 | for k in range(e_i.shape[1]): 97 | simss.append(torch.svd(e_i[:, k, :].t() @ e_j[:, k, :]).S) 98 | sims.append(torch.stack(simss, dim=0).mean(dim=0)) 99 | sim = torch.cat(sims, dim=0) 100 | reg_loss = (sim ** 2).mean() 101 | return output, self.reg_weight * reg_loss 102 | else: 103 | return output 104 | 105 | 106 | class SingularValueRegularizedDCNv2(CrossNetworkV2Model): 107 | 108 | def __init__(self, field_dims, embed_dim, num_layers, mlp_dims, dropout, reg_weight=0.0): 109 | super().__init__(field_dims, embed_dim, num_layers, mlp_dims, dropout) 110 | self.reg_weight = reg_weight 111 | 112 | def forward(self, x): 113 | embed_x = self.embedding(x) 114 | x_l1 = self.cn(embed_x.view(-1, self.embed_output_dim)) 115 | h_l2 = self.mlp(x_l1) 116 | p = self.linear(h_l2) 117 | output = torch.sigmoid(p.squeeze(1)) 118 | if self.training: 119 | reg_losses = [] 120 | for i in range(embed_x.shape[1]): # field 121 | _, S, _ = torch.svd(embed_x[:, i, :]) 122 | # regularize the diversity of S 123 | # var(S / mean(S)) 124 | reg_losses.append((S / S.mean()).var()) 125 | reg_loss = torch.stack(reg_losses).mean() 126 | return output, self.reg_weight * reg_loss 127 | else: 128 | return output 129 | -------------------------------------------------------------------------------- /torchfm/model/dcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FeaturesEmbedding, CrossNetwork, MultiLayerPerceptron 4 | 5 | 6 | class DeepCrossNetworkModel(torch.nn.Module): 7 | """ 8 | A pytorch implementation of Deep & Cross Network. 9 | 10 | Reference: 11 | R Wang, et al. Deep & Cross Network for Ad Click Predictions, 2017. 
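    Example (a minimal sketch; sizes are illustrative). The cross network and the MLP both consume the flattened embeddings in parallel, and their outputs are concatenated before the final linear layer::

        >>> import torch
        >>> from torchfm.model.dcn import DeepCrossNetworkModel
        >>> model = DeepCrossNetworkModel(field_dims=[10, 20, 30], embed_dim=16, num_layers=3, mlp_dims=(64, 32), dropout=0.2)
        >>> x = torch.randint(0, 10, (4, 3))
        >>> model(x).shape
        torch.Size([4])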
12 | """ 13 | 14 | def __init__(self, field_dims, embed_dim, num_layers, mlp_dims, dropout): 15 | super().__init__() 16 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 17 | self.embed_output_dim = len(field_dims) * embed_dim 18 | self.cn = CrossNetwork(self.embed_output_dim, num_layers) 19 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout, output_layer=False) 20 | self.linear = torch.nn.Linear(mlp_dims[-1] + self.embed_output_dim, 1) 21 | 22 | def forward(self, x): 23 | """ 24 | :param x: Long tensor of size ``(batch_size, num_fields)`` 25 | """ 26 | embed_x = self.embedding(x).view(-1, self.embed_output_dim) 27 | x_l1 = self.cn(embed_x) 28 | h_l2 = self.mlp(embed_x) 29 | x_stack = torch.cat([x_l1, h_l2], dim=1) 30 | p = self.linear(x_stack) 31 | return torch.sigmoid(p.squeeze(1)) 32 | -------------------------------------------------------------------------------- /torchfm/model/dcnv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchfm.layer import FeaturesEmbedding, CrossNetwork, MultiLayerPerceptron 3 | 4 | 5 | class CrossNetworkV2(torch.nn.Module): # layer 6 | 7 | def __init__(self, input_dim, num_layers, num_fields, embed_dim): 8 | """ 9 | input_dim: num_fields*embed_dim 10 | """ 11 | super().__init__() 12 | self.num_layers = num_layers 13 | self.num_fields = num_fields 14 | self.embed_dim = embed_dim 15 | self.input_dim = input_dim 16 | self.W = torch.nn.ParameterList([ 17 | torch.nn.Parameter(torch.Tensor(input_dim, input_dim)) for _ in range(num_layers) 18 | ]) 19 | self.b = torch.nn.ParameterList([ 20 | torch.nn.Parameter(torch.zeros(input_dim, )) for _ in range(num_layers) 21 | ]) 22 | for i in range(num_layers): 23 | torch.nn.init.xavier_uniform_(self.W[i]) 24 | 25 | def forward(self, x): 26 | """ 27 | x: Tensor of size ``(batch_size, num_fields*embed_dim)`` 28 | """ 29 | x0 = x 30 | for i in range(self.num_layers): 31 | x = x.unsqueeze(2) 32 | xw = torch.matmul(self.W[i], x) 33 | xw = xw.squeeze(2) 34 | x = x.squeeze(2) 35 | x = x0 * (xw + self.b[i]) + x 36 | return x 37 | 38 | 39 | class CrossNetworkV2Model(torch.nn.Module): # model 40 | """ 41 | A pytorch implementation of Deep & Cross Network - V2. 42 | Only Cross Network, without deep network. 43 | Reference: 44 | R Wang, et al. 
DCN V2: Improved Deep & Cross Network for Feature Cross Learning in Web-scale Learning to Rank Systems, 2021 45 | """ 46 | def __init__(self, field_dims, embed_dim, num_layers, mlp_dims, dropout): 47 | super().__init__() 48 | self.num_fields = len(field_dims) 49 | self.embed_dim = embed_dim 50 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 51 | self.embed_output_dim = len(field_dims) * embed_dim 52 | self.cn = CrossNetworkV2(self.embed_output_dim, num_layers, self.num_fields, embed_dim) 53 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout, output_layer=False) 54 | self.linear = torch.nn.Linear(mlp_dims[-1], 1) 55 | 56 | def forward(self, x): 57 | """ 58 | x: Tensor of size ``(batch_size, num_fields)`` 59 | self.embedding(x): Tensor of size ``(batch_size, num_fields, embed_dim)`` 60 | embed_x: Tensor of size ``(batch_size, num_fields*embed_dim)`` 61 | x: Tensor of size ``(batch_size, num_fields*embed_dim)`` 62 | """ 63 | embed_x = self.embedding(x).view(-1, self.embed_output_dim) 64 | x_l1 = self.cn(embed_x) 65 | h_l2 = self.mlp(x_l1) 66 | p = self.linear(h_l2) 67 | return torch.sigmoid(p.squeeze(1)) 68 | -------------------------------------------------------------------------------- /torchfm/model/dfm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FactorizationMachine, FeaturesEmbedding, FeaturesLinear, MultiLayerPerceptron 4 | 5 | 6 | class DeepFactorizationMachineModel(torch.nn.Module): 7 | """ 8 | A pytorch implementation of DeepFM. 9 | 10 | Reference: 11 | H Guo, et al. DeepFM: A Factorization-Machine based Neural Network for CTR Prediction, 2017. 12 | """ 13 | 14 | def __init__(self, field_dims, embed_dim, mlp_dims, dropout): 15 | super().__init__() 16 | self.linear = FeaturesLinear(field_dims) 17 | self.fm = FactorizationMachine(reduce_sum=True) 18 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 19 | self.embed_output_dim = len(field_dims) * embed_dim 20 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout) 21 | 22 | def forward(self, x): 23 | """ 24 | :param x: Long tensor of size ``(batch_size, num_fields)`` 25 | """ 26 | embed_x = self.embedding(x) 27 | x = self.linear(x) + self.fm(embed_x) + self.mlp(embed_x.view(-1, self.embed_output_dim)) 28 | return torch.sigmoid(x.squeeze(1)) 29 | -------------------------------------------------------------------------------- /torchfm/model/ffm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FeaturesLinear, FieldAwareFactorizationMachine 4 | 5 | 6 | class FieldAwareFactorizationMachineModel(torch.nn.Module): 7 | """ 8 | A pytorch implementation of Field-aware Factorization Machine. 9 | 10 | Reference: 11 | Y Juan, et al. Field-aware Factorization Machines for CTR Prediction, 2015. 
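    Example (a minimal sketch; sizes are illustrative). FFM learns one embedding table per (feature, target-field) pair, so parameters grow with the number of fields and small ``embed_dim`` values are common::

        >>> import torch
        >>> from torchfm.model.ffm import FieldAwareFactorizationMachineModel
        >>> model = FieldAwareFactorizationMachineModel(field_dims=[10, 20, 30], embed_dim=4)
        >>> x = torch.randint(0, 10, (4, 3))
        >>> model(x).shape
        torch.Size([4])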
12 | """ 13 | 14 | def __init__(self, field_dims, embed_dim): 15 | super().__init__() 16 | self.linear = FeaturesLinear(field_dims) 17 | self.ffm = FieldAwareFactorizationMachine(field_dims, embed_dim) 18 | 19 | def forward(self, x): 20 | """ 21 | :param x: Long tensor of size ``(batch_size, num_fields)`` 22 | """ 23 | ffm_term = torch.sum(torch.sum(self.ffm(x), dim=1), dim=1, keepdim=True) 24 | x = self.linear(x) + ffm_term 25 | return torch.sigmoid(x.squeeze(1)) 26 | -------------------------------------------------------------------------------- /torchfm/model/finalmlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FeaturesEmbedding, MultiLayerPerceptron 4 | 5 | 6 | class FeatureSelection(torch.nn.Module): 7 | 8 | def __init__(self, embed_output_dim, embed_dim, fs_mlp_dims, dropout): 9 | super().__init__() 10 | self.ctx = torch.nn.Parameter(torch.zeros(1, embed_dim)) 11 | self.gate = MultiLayerPerceptron(embed_dim, fs_mlp_dims + (embed_output_dim, ), dropout, output_layer=False) 12 | 13 | def forward(self, emb): 14 | return 2 * self.gate(self.ctx.repeat(emb.shape[0], 1)) * emb 15 | 16 | 17 | class FinalMLPInter(torch.nn.Module): 18 | 19 | def __init__(self, embed_dim, embed_output_dim, mlp_dims, fs_mlp_dims, dropout): 20 | super().__init__() 21 | self.mlps = torch.nn.ModuleList([ 22 | torch.nn.Sequential( 23 | # FeatureSelection(embed_output_dim, embed_dim, fs_mlp_dims, dropout), 24 | torch.nn.Identity(), 25 | MultiLayerPerceptron(embed_output_dim, mlp_dims, dropout, output_layer=False) 26 | ) 27 | for _ in range(2) 28 | ]) 29 | 30 | def forward(self, emb): 31 | return torch.stack([mlp(emb.flatten(1)) for mlp in self.mlps], dim=-1) 32 | 33 | 34 | class FinalMLPFusion(torch.nn.Module): 35 | 36 | def __init__(self, hidden_dim): 37 | super().__init__() 38 | self.w_x = torch.nn.Linear(hidden_dim, 1) 39 | self.w_y = torch.nn.Linear(hidden_dim, 1) 40 | self.w_xy = torch.nn.Parameter(torch.zeros(hidden_dim, hidden_dim)) 41 | # torch.nn.init.xavier_normal_(self.w_xy) 42 | 43 | def forward(self, hidden): 44 | x, y = hidden[:, :, 0], hidden[:, :, 1] 45 | xy = self.w_x(x) + self.w_y(y) + ((x @ self.w_xy) * y).sum(dim=1, keepdim=True) 46 | return xy 47 | 48 | 49 | class FinalMLP(torch.nn.Module): 50 | 51 | def __init__(self, field_dims, embed_dim, mlp_dims, fs_mlp_dims, dropout): 52 | super().__init__() 53 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 54 | self.embed_output_dim = len(field_dims) * embed_dim 55 | self.inter = FinalMLPInter(embed_dim, self.embed_output_dim, mlp_dims, fs_mlp_dims, dropout) 56 | self.fusion = FinalMLPFusion(mlp_dims[-1]) 57 | 58 | def forward(self, x): 59 | cross_term = self.inter(self.embedding(x)) 60 | x = self.fusion(cross_term) 61 | return torch.sigmoid(x.squeeze(1)) 62 | 63 | 64 | class MultiFinalMLP(torch.nn.Module): 65 | 66 | def __init__(self, field_dims, embed_dims, mlp_dims, fs_mlp_dims, dropout): 67 | super().__init__() 68 | self.embeddings = torch.nn.ModuleList([FeaturesEmbedding(field_dims, embed_dim) for embed_dim in embed_dims]) 69 | self.embed_output_dims = [len(field_dims) * embed_dim for embed_dim in embed_dims] 70 | self.inters = torch.nn.ModuleList([ 71 | FinalMLPInter(embed_dim, embed_output_dim, mlp_dims, fs_mlp_dims, dropout) 72 | for embed_dim, embed_output_dim in zip(embed_dims, self.embed_output_dims) 73 | ]) 74 | self.fusion = FinalMLPFusion(mlp_dims[-1]) 75 | 76 | def forward(self, x): 77 | cross_term = torch.stack([inter(embedding(x)) for 
inter, embedding in zip(self.inters, self.embeddings)], dim=-1).mean(dim=-1) 78 | x = self.fusion(cross_term) 79 | return torch.sigmoid(x.squeeze(1)) 80 | -------------------------------------------------------------------------------- /torchfm/model/fm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FactorizationMachine, FeaturesEmbedding, FeaturesLinear 4 | 5 | 6 | class FactorizationMachineModel(torch.nn.Module): 7 | """ 8 | A pytorch implementation of Factorization Machine. 9 | 10 | Reference: 11 | S Rendle, Factorization Machines, 2010. 12 | """ 13 | 14 | def __init__(self, field_dims, embed_dim): 15 | super().__init__() 16 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 17 | self.linear = FeaturesLinear(field_dims) 18 | self.fm = FactorizationMachine(reduce_sum=True) 19 | 20 | def forward(self, x): 21 | """ 22 | :param x: Long tensor of size ``(batch_size, num_fields)`` 23 | """ 24 | x = self.linear(x) + self.fm(self.embedding(x)) 25 | return torch.sigmoid(x.squeeze(1)) 26 | 27 | 28 | class FactorizationMachineModelNoLinear(torch.nn.Module): 29 | """ 30 | A pytorch implementation of Factorization Machine. 31 | 32 | Reference: 33 | S Rendle, Factorization Machines, 2010. 34 | """ 35 | 36 | def __init__(self, field_dims, embed_dim): 37 | super().__init__() 38 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 39 | # self.linear = FeaturesLinear(field_dims) 40 | self.fm = FactorizationMachine(reduce_sum=True) 41 | 42 | def forward(self, x): 43 | """ 44 | :param x: Long tensor of size ``(batch_size, num_fields)`` 45 | """ 46 | x = self.fm(self.embedding(x)) 47 | return torch.sigmoid(x.squeeze(1)) 48 | 49 | 50 | class MultiFM(torch.nn.Module): 51 | """ 52 | Multi-Kernel-FM: A Multi-Embedding & Kernelization Factorization Machine Framework for CTR Prediction 53 | """ 54 | 55 | def __init__(self, field_dims, embed_dims): 56 | super().__init__() 57 | self.fms = torch.nn.ModuleList([FactorizationMachineModel(field_dims, embed_dim) for embed_dim in embed_dims]) 58 | self.weights = torch.nn.Parameter(torch.ones(len(embed_dims))) 59 | 60 | def forward(self, x): 61 | x = torch.stack([fm.linear(x) + fm.fm(fm.embedding(x)) for fm in self.fms], dim=-1) 62 | x = torch.sum(x * self.weights.unsqueeze(0).unsqueeze(0), dim=-1) 63 | return torch.sigmoid(x.squeeze(1)) 64 | -------------------------------------------------------------------------------- /torchfm/model/fnfm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FieldAwareFactorizationMachine, MultiLayerPerceptron, FeaturesLinear 4 | 5 | 6 | class FieldAwareNeuralFactorizationMachineModel(torch.nn.Module): 7 | """ 8 | A pytorch implementation of Field-aware Neural Factorization Machine. 9 | 10 | Reference: 11 | L Zhang, et al. Field-aware Neural Factorization Machine for Click-Through Rate Prediction, 2019. 
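    Example (a minimal sketch; sizes are illustrative). ``dropouts[0]`` is applied to the batch-normalized interaction vector and ``dropouts[1]`` inside the MLP::

        >>> import torch
        >>> from torchfm.model.fnfm import FieldAwareNeuralFactorizationMachineModel
        >>> model = FieldAwareNeuralFactorizationMachineModel(field_dims=[10, 20, 30], embed_dim=4, mlp_dims=(64,), dropouts=(0.2, 0.2))
        >>> x = torch.randint(0, 10, (4, 3))
        >>> model(x).shape
        torch.Size([4])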
12 | """ 13 | 14 | def __init__(self, field_dims, embed_dim, mlp_dims, dropouts): 15 | super().__init__() 16 | self.linear = FeaturesLinear(field_dims) 17 | self.ffm = FieldAwareFactorizationMachine(field_dims, embed_dim) 18 | self.ffm_output_dim = len(field_dims) * (len(field_dims) - 1) // 2 * embed_dim 19 | self.bn = torch.nn.BatchNorm1d(self.ffm_output_dim) 20 | self.dropout = torch.nn.Dropout(dropouts[0]) 21 | self.mlp = MultiLayerPerceptron(self.ffm_output_dim, mlp_dims, dropouts[1]) 22 | 23 | def forward(self, x): 24 | """ 25 | :param x: Long tensor of size ``(batch_size, num_fields)`` 26 | """ 27 | cross_term = self.ffm(x).view(-1, self.ffm_output_dim) 28 | cross_term = self.bn(cross_term) 29 | cross_term = self.dropout(cross_term) 30 | x = self.linear(x) + self.mlp(cross_term) 31 | return torch.sigmoid(x.squeeze(1)) 32 | -------------------------------------------------------------------------------- /torchfm/model/fnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FeaturesEmbedding, MultiLayerPerceptron 4 | 5 | 6 | class FactorizationSupportedNeuralNetworkModel(torch.nn.Module): 7 | """ 8 | A pytorch implementation of Neural Factorization Machine. 9 | 10 | Reference: 11 | W Zhang, et al. Deep Learning over Multi-field Categorical Data - A Case Study on User Response Prediction, 2016. 12 | """ 13 | 14 | def __init__(self, field_dims, embed_dim, mlp_dims, dropout): 15 | super().__init__() 16 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 17 | self.embed_output_dim = len(field_dims) * embed_dim 18 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout) 19 | 20 | def forward(self, x): 21 | """ 22 | :param x: Long tensor of size ``(batch_size, num_fields)`` 23 | """ 24 | embed_x = self.embedding(x) 25 | x = self.mlp(embed_x.view(-1, self.embed_output_dim)) 26 | return torch.sigmoid(x.squeeze(1)) 27 | -------------------------------------------------------------------------------- /torchfm/model/fwfm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from torchfm.layer import FeaturesEmbedding, FeaturesEmbeddingMultiDevice, MultiLayerPerceptron, FeaturesLinear 5 | 6 | 7 | class NFwFM(torch.nn.Module): 8 | 9 | def __init__(self, num_fields): 10 | super().__init__() 11 | self.num_fields = num_fields 12 | self.weight = nn.Parameter(torch.randn( 13 | 1, self.num_fields, self.num_fields 14 | ), requires_grad=True) 15 | 16 | def forward(self, inputs): 17 | batch_size = inputs.shape[0] 18 | weight = self.weight.expand(batch_size, -1, -1, -1) # B x 1 x F x F 19 | inputs_a = inputs.transpose(1, 2).unsqueeze(dim=-1) # B x E x F x 1 20 | inputs_b = inputs.transpose(1, 2).unsqueeze(dim=-2) # B x E x 1 x F 21 | 22 | # fwfm_inter_list = [] 23 | # for f1 in range(self.num_fields): 24 | # fwfm_inter_list.append((inputs[:, f1, :].unsqueeze(1) * inputs[:, :, :] * self.weight[:, f1, :].unsqueeze(2)).sum(dim=1)) 25 | # fwfm_inter = sum(fwfm_inter_list) 26 | 27 | fwfm_inter = torch.matmul(inputs_a, inputs_b) * weight # B x E x F x F 28 | fwfm_inter = torch.sum(torch.sum(fwfm_inter, dim=-1), dim=-1) # [batch_size, emb_dim] 29 | 30 | return fwfm_inter 31 | 32 | 33 | class NFwFMModel(torch.nn.Module): 34 | """ 35 | A pytorch implementation of Neural Factorization Machine. 36 | 37 | Reference: 38 | X He and TS Chua, Neural Factorization Machines for Sparse Predictive Analytics, 2017. 
39 | """ 40 | 41 | def __init__(self, field_dims, embed_dim, mlp_dims, dropouts): 42 | super().__init__() 43 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 44 | self.fm = torch.nn.Sequential( 45 | NFwFM(len(field_dims)), 46 | torch.nn.BatchNorm1d(embed_dim), 47 | torch.nn.Dropout(dropouts[0]) 48 | ) 49 | self.mlp = MultiLayerPerceptron(embed_dim, mlp_dims, dropouts[1]) 50 | 51 | def forward(self, x): 52 | """ 53 | :param x: Long tensor of size ``(batch_size, num_fields)`` 54 | """ 55 | cross_term = self.fm(self.embedding(x)) 56 | x = self.mlp(cross_term) 57 | return torch.sigmoid(x.squeeze(1)) 58 | 59 | 60 | class MultiNFwFMModel(torch.nn.Module): 61 | 62 | def __init__(self, field_dims, embed_dims, mlp_dims, dropouts): 63 | super().__init__() 64 | self.num_fields = len(field_dims) 65 | self.embeddings = torch.nn.ModuleList([FeaturesEmbedding(field_dims, embed_dim) for embed_dim in embed_dims]) 66 | self.fwfms = torch.nn.ModuleList([ 67 | torch.nn.Sequential( 68 | NFwFM(len(field_dims)), 69 | torch.nn.BatchNorm1d(embed_dim), 70 | torch.nn.Dropout(dropouts[0]), 71 | MultiLayerPerceptron(embed_dim, mlp_dims[:1], dropouts[1], output_layer=False) 72 | ) for embed_dim in embed_dims 73 | ]) 74 | self.mlp = MultiLayerPerceptron(mlp_dims[0], mlp_dims[1:], dropouts[1]) 75 | 76 | def forward(self, x): 77 | x_l1 = torch.stack([fwfm(embedding(x)) for embedding, fwfm in zip(self.embeddings, self.fwfms)], dim=-1) 78 | x_l1 = x_l1.mean(dim=-1) 79 | p = self.mlp(x_l1) 80 | return torch.sigmoid(p.squeeze(1)) 81 | -------------------------------------------------------------------------------- /torchfm/model/hofm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FeaturesLinear, FactorizationMachine, AnovaKernel, FeaturesEmbedding 4 | 5 | 6 | class HighOrderFactorizationMachineModel(torch.nn.Module): 7 | """ 8 | A pytorch implementation of Higher-Order Factorization Machines. 9 | 10 | Reference: 11 | M Blondel, et al. Higher-Order Factorization Machines, 2016. 12 | """ 13 | 14 | def __init__(self, field_dims, order, embed_dim): 15 | super().__init__() 16 | if order < 1: 17 | raise ValueError(f'invalid order: {order}') 18 | self.order = order 19 | self.embed_dim = embed_dim 20 | self.linear = FeaturesLinear(field_dims) 21 | if order >= 2: 22 | self.embedding = FeaturesEmbedding(field_dims, embed_dim * (order - 1)) 23 | self.fm = FactorizationMachine(reduce_sum=True) 24 | if order >= 3: 25 | self.kernels = torch.nn.ModuleList([ 26 | AnovaKernel(order=i, reduce_sum=True) for i in range(3, order + 1) 27 | ]) 28 | 29 | def forward(self, x): 30 | """ 31 | :param x: Long tensor of size ``(batch_size, num_fields)`` 32 | """ 33 | y = self.linear(x).squeeze(1) 34 | if self.order >= 2: 35 | x = self.embedding(x) 36 | x_part = x[:, :, :self.embed_dim] 37 | y += self.fm(x_part).squeeze(1) 38 | for i in range(self.order - 2): 39 | x_part = x[:, :, (i + 1) * self.embed_dim: (i + 2) * self.embed_dim] 40 | y += self.kernels[i](x_part).squeeze(1) 41 | return torch.sigmoid(y) 42 | -------------------------------------------------------------------------------- /torchfm/model/lr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FeaturesLinear 4 | 5 | 6 | class LogisticRegressionModel(torch.nn.Module): 7 | """ 8 | A pytorch implementation of Logistic Regression. 
9 | """ 10 | 11 | def __init__(self, field_dims): 12 | super().__init__() 13 | self.linear = FeaturesLinear(field_dims) 14 | 15 | def forward(self, x): 16 | """ 17 | :param x: Long tensor of size ``(batch_size, num_fields)`` 18 | """ 19 | return torch.sigmoid(self.linear(x).squeeze(1)) 20 | -------------------------------------------------------------------------------- /torchfm/model/mwd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FeaturesLinear, MultiLayerPerceptron, FeaturesEmbedding 4 | 5 | 6 | class SingleHeadDeepModel(torch.nn.Module): 7 | """ 8 | A pytorch implementation of wide and deep learning. 9 | 10 | Reference: 11 | HT Cheng, et al. Wide & Deep Learning for Recommender Systems, 2016. 12 | """ 13 | 14 | def __init__(self, field_dims, embed_dim, mlp_dims, dropout): 15 | super().__init__() 16 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 17 | self.embed_output_dim = len(field_dims) * embed_dim 18 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout, output_layer=False) 19 | 20 | def forward(self, x): 21 | """ 22 | :param x: Long tensor of size ``(batch_size, num_fields)`` 23 | """ 24 | embed_x = self.embedding(x) 25 | x = self.mlp(embed_x.view(-1, self.embed_output_dim)) 26 | return x 27 | 28 | 29 | class MultiHeadWideAndDeepModel(torch.nn.Module): 30 | def __init__(self, field_dims, embed_dims, mlp_dims, dropout): 31 | super().__init__() 32 | self.linear = FeaturesLinear(field_dims) 33 | self.embeddings = torch.nn.ModuleList([SingleHeadDeepModel(field_dims, embed_dim, mlp_dims, dropout) for embed_dim in embed_dims]) 34 | self.mlp = torch.nn.Linear(mlp_dims[-1], 1) 35 | self.weights = torch.nn.Parameter(torch.ones(len(embed_dims))) 36 | 37 | def forward(self, x): 38 | """ 39 | :param x: Long tensor of size ``(batch_size, num_fields)`` 40 | """ 41 | embed_x = torch.stack([embedding(x) for embedding in self.embeddings], dim=-1) 42 | embed_x = torch.sum(embed_x * self.weights.unsqueeze(0).unsqueeze(0), dim=-1) 43 | x = self.linear(x) + self.mlp(embed_x) 44 | return torch.sigmoid(x.squeeze(1)) 45 | 46 | 47 | class MultiHeadWideAndDeepModelNoLinear(torch.nn.Module): 48 | def __init__(self, field_dims, embed_dims, mlp_dims, dropout): 49 | super().__init__() 50 | self.embeddings = torch.nn.ModuleList([SingleHeadDeepModel(field_dims, embed_dim, mlp_dims, dropout) for embed_dim in embed_dims]) 51 | self.mlp = torch.nn.Linear(mlp_dims[-1], 1) 52 | self.weights = torch.nn.Parameter(torch.ones(len(embed_dims))) 53 | 54 | def forward(self, x): 55 | """ 56 | :param x: Long tensor of size ``(batch_size, num_fields)`` 57 | """ 58 | embed_x = torch.stack([embedding(x) for embedding in self.embeddings], dim=-1) 59 | embed_x = torch.sum(embed_x * self.weights.unsqueeze(0).unsqueeze(0), dim=-1) 60 | x = self.mlp(embed_x) 61 | return torch.sigmoid(x.squeeze(1)) 62 | 63 | 64 | class SharedEmbeddingMWDNoLinear(torch.nn.Module): 65 | def __init__(self, field_dims, embed_dim, num_mlps, mlp_dims, dropout): 66 | super().__init__() 67 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 68 | self.embed_output_dim = len(field_dims) * embed_dim 69 | self.mlps = torch.nn.ModuleList([MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout, output_layer=False) for _ in range(num_mlps)]) 70 | self.head = torch.nn.Linear(mlp_dims[-1], 1) 71 | 72 | 73 | def forward(self, x): 74 | """ 75 | :param x: Long tensor of size ``(batch_size, num_fields)`` 76 | """ 77 | embed_x = 
self.embedding(x).view(-1, self.embed_output_dim) 78 | feat_x = torch.stack([mlp(embed_x) for mlp in self.mlps], dim=-1) 79 | feat_x = feat_x.mean(dim=-1) 80 | x = self.head(feat_x) 81 | return torch.sigmoid(x.squeeze(1)) 82 | 83 | 84 | """ 85 | 2023.08.03 86 | """ 87 | class DNNModel(torch.nn.Module): 88 | def __init__(self, field_dims, embed_dim, mlp_dims, dropout): 89 | super().__init__() 90 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 91 | self.embed_output_dim = len(field_dims) * embed_dim 92 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout) 93 | 94 | def forward(self, x): 95 | """ 96 | :param x: Long tensor of size ``(batch_size, num_fields)`` 97 | """ 98 | embed_x = self.embedding(x) 99 | x = self.mlp(embed_x.view(-1, self.embed_output_dim)) 100 | return torch.sigmoid(x.squeeze(1)) 101 | 102 | 103 | class MultiDNNModel(torch.nn.Module): 104 | def __init__(self, field_dims, embed_dims, mlp_dims, dropout): 105 | super().__init__() 106 | self.embeddings = torch.nn.ModuleList([ 107 | torch.nn.Sequential( 108 | FeaturesEmbedding(field_dims, embed_dim), 109 | torch.nn.Flatten(start_dim=1), 110 | MultiLayerPerceptron(embed_dim * len(field_dims), (mlp_dims[0],), dropout, output_layer=False) 111 | ) for embed_dim in embed_dims]) 112 | self.mlp = MultiLayerPerceptron(mlp_dims[0], mlp_dims[1:], dropout) 113 | 114 | def forward(self, x): 115 | """ 116 | :param x: Long tensor of size ``(batch_size, num_fields)`` 117 | """ 118 | embed_xs = torch.stack([embedding(x) for embedding in self.embeddings], dim=-1) 119 | embed_x = embed_xs.mean(dim=-1) 120 | x = self.mlp(embed_x) 121 | return torch.sigmoid(x.squeeze(1)) 122 | -------------------------------------------------------------------------------- /torchfm/model/ncf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchfm.layer import FeaturesEmbedding, MultiLayerPerceptron 3 | 4 | 5 | class NeuralCollaborativeFiltering(torch.nn.Module): 6 | """ 7 | A pytorch implementation of Neural Collaborative Filtering. 8 | 9 | Reference: 10 | X He, et al. Neural Collaborative Filtering, 2017. 11 | """ 12 | 13 | def __init__(self, field_dims, user_field_idx, item_field_idx, embed_dim, mlp_dims, dropout): 14 | super().__init__() 15 | self.user_field_idx = user_field_idx 16 | self.item_field_idx = item_field_idx 17 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 18 | self.embed_output_dim = len(field_dims) * embed_dim 19 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout, output_layer=False) 20 | self.fc = torch.nn.Linear(mlp_dims[-1] + embed_dim, 1) 21 | 22 | def forward(self, x): 23 | """ 24 | :param x: Long tensor of size ``(batch_size, num_user_fields)`` 25 | """ 26 | x = self.embedding(x) 27 | user_x = x[:, self.user_field_idx].squeeze(1) 28 | item_x = x[:, self.item_field_idx].squeeze(1) 29 | x = self.mlp(x.view(-1, self.embed_output_dim)) 30 | gmf = user_x * item_x 31 | x = torch.cat([gmf, x], dim=1) 32 | x = self.fc(x).squeeze(1) 33 | return torch.sigmoid(x) 34 | -------------------------------------------------------------------------------- /torchfm/model/nfm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FactorizationMachine, FeaturesEmbedding, MultiLayerPerceptron, FeaturesLinear 4 | 5 | 6 | class NeuralFactorizationMachineModel(torch.nn.Module): 7 | """ 8 | A pytorch implementation of Neural Factorization Machine. 
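    Bi-interaction pooling reduces the pairwise terms to a single ``embed_dim``-sized vector before the MLP; a minimal usage sketch with illustrative sizes::

        >>> import torch
        >>> from torchfm.model.nfm import NeuralFactorizationMachineModel
        >>> model = NeuralFactorizationMachineModel(field_dims=[10, 20, 30], embed_dim=16, mlp_dims=(64,), dropouts=(0.2, 0.2))
        >>> x = torch.randint(0, 10, (4, 3))
        >>> model(x).shape
        torch.Size([4])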
9 | 10 | Reference: 11 | X He and TS Chua, Neural Factorization Machines for Sparse Predictive Analytics, 2017. 12 | """ 13 | 14 | def __init__(self, field_dims, embed_dim, mlp_dims, dropouts): 15 | super().__init__() 16 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 17 | self.linear = FeaturesLinear(field_dims) 18 | self.fm = torch.nn.Sequential( 19 | FactorizationMachine(reduce_sum=False), 20 | torch.nn.BatchNorm1d(embed_dim), 21 | torch.nn.Dropout(dropouts[0]) 22 | ) 23 | self.mlp = MultiLayerPerceptron(embed_dim, mlp_dims, dropouts[1]) 24 | 25 | def forward(self, x): 26 | """ 27 | :param x: Long tensor of size ``(batch_size, num_fields)`` 28 | """ 29 | cross_term = self.fm(self.embedding(x)) 30 | x = self.linear(x) + self.mlp(cross_term) 31 | return torch.sigmoid(x.squeeze(1)) 32 | 33 | 34 | class NeuralFactorizationMachineModelNoLinear(torch.nn.Module): 35 | """ 36 | A pytorch implementation of Neural Factorization Machine. 37 | 38 | Reference: 39 | X He and TS Chua, Neural Factorization Machines for Sparse Predictive Analytics, 2017. 40 | """ 41 | 42 | def __init__(self, field_dims, embed_dim, mlp_dims, dropouts): 43 | super().__init__() 44 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 45 | # self.linear = FeaturesLinear(field_dims) 46 | self.fm = torch.nn.Sequential( 47 | FactorizationMachine(reduce_sum=False), 48 | torch.nn.BatchNorm1d(embed_dim), 49 | torch.nn.Dropout(dropouts[0]) 50 | ) 51 | self.mlp = MultiLayerPerceptron(embed_dim, mlp_dims, dropouts[1]) 52 | 53 | def forward(self, x): 54 | """ 55 | :param x: Long tensor of size ``(batch_size, num_fields)`` 56 | """ 57 | cross_term = self.fm(self.embedding(x)) 58 | x = self.mlp(cross_term) 59 | return torch.sigmoid(x.squeeze(1)) 60 | 61 | 62 | class MultiNFMModelNoLinear(torch.nn.Module): 63 | 64 | def __init__(self, field_dims, embed_dims, mlp_dims, dropouts): 65 | super().__init__() 66 | self.num_fields = len(field_dims) 67 | self.embeddings = torch.nn.ModuleList([FeaturesEmbedding(field_dims, embed_dim) for embed_dim in embed_dims]) 68 | self.fms = torch.nn.ModuleList([ 69 | torch.nn.Sequential( 70 | FactorizationMachine(reduce_sum=False), 71 | torch.nn.BatchNorm1d(embed_dim), 72 | torch.nn.Dropout(dropouts[0]), 73 | MultiLayerPerceptron(embed_dim, mlp_dims[:1], dropouts[1], output_layer=False) 74 | ) for embed_dim in embed_dims 75 | ]) 76 | self.mlp = MultiLayerPerceptron(mlp_dims[0], mlp_dims[1:], dropouts[1]) 77 | 78 | def forward(self, x): 79 | x_l1 = torch.stack([fm(embedding(x)) for embedding, fm in zip(self.embeddings, self.fms)], dim=-1) 80 | x_l1 = x_l1.mean(dim=-1) 81 | p = self.mlp(x_l1) 82 | return torch.sigmoid(p.squeeze(1)) -------------------------------------------------------------------------------- /torchfm/model/pnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FeaturesEmbedding, FeaturesLinear, InnerProductNetwork, \ 4 | OuterProductNetwork, MultiLayerPerceptron 5 | 6 | 7 | class ProductNeuralNetworkModel(torch.nn.Module): 8 | """ 9 | A pytorch implementation of inner/outer Product Neural Network. 10 | Reference: 11 | Y Qu, et al. Product-based Neural Networks for User Response Prediction, 2016. 
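    Example (a minimal sketch; sizes are illustrative). ``method`` selects inner- or outer-product interactions; the MLP consumes the pairwise products concatenated with the flattened embeddings::

        >>> import torch
        >>> from torchfm.model.pnn import ProductNeuralNetworkModel
        >>> model = ProductNeuralNetworkModel(field_dims=[10, 20, 30], embed_dim=16, mlp_dims=(64,), dropout=0.2, method='inner')
        >>> x = torch.randint(0, 10, (4, 3))
        >>> model(x).shape
        torch.Size([4])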
12 | """ 13 | 14 | def __init__(self, field_dims, embed_dim, mlp_dims, dropout, method='inner'): 15 | super().__init__() 16 | num_fields = len(field_dims) 17 | if method == 'inner': 18 | self.pn = InnerProductNetwork() 19 | elif method == 'outer': 20 | self.pn = OuterProductNetwork(num_fields, embed_dim) 21 | else: 22 | raise ValueError('unknown product type: ' + method) 23 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 24 | # self.linear = FeaturesLinear(field_dims, embed_dim) 25 | self.embed_output_dim = num_fields * embed_dim 26 | self.mlp = MultiLayerPerceptron(num_fields * (num_fields - 1) // 2 + self.embed_output_dim, mlp_dims, dropout) 27 | 28 | def forward(self, x): 29 | """ 30 | :param x: Long tensor of size ``(batch_size, num_fields)`` 31 | """ 32 | embed_x = self.embedding(x) 33 | cross_term = self.pn(embed_x) 34 | x = torch.cat([embed_x.view(-1, self.embed_output_dim), cross_term], dim=1) 35 | x = self.mlp(x) 36 | return torch.sigmoid(x.squeeze(1)) 37 | 38 | 39 | class MultiPNNModel(torch.nn.Module): 40 | 41 | def __init__(self, field_dims, embed_dims, mlp_dims, dropout, method='inner'): 42 | super().__init__() 43 | self.num_fields = len(field_dims) 44 | self.embeddings = torch.nn.ModuleList([FeaturesEmbedding(field_dims, embed_dim) for embed_dim in embed_dims]) 45 | self.embed_output_dim = self.num_fields * (self.num_fields - 1) // 2 46 | if method == 'inner': 47 | self.pns = torch.nn.ModuleList([ 48 | torch.nn.Sequential( 49 | InnerProductNetwork(), 50 | MultiLayerPerceptron(self.embed_output_dim, mlp_dims[:1], dropout, output_layer=False) 51 | ) for _ in range(len(self.embeddings)) 52 | ]) 53 | elif method == 'outer': 54 | self.pns = torch.nn.ModuleList([ 55 | torch.nn.Sequential( 56 | OuterProductNetwork(self.num_fields, embed_dim), 57 | MultiLayerPerceptron(self.embed_output_dim, mlp_dims[:1], dropout, output_layer=False) 58 | ) for embed_dim in embed_dims 59 | ]) 60 | self.mlp = MultiLayerPerceptron(mlp_dims[0], mlp_dims[1:], dropout) 61 | 62 | def forward(self, x): 63 | x_l1 = torch.stack([pn(embedding(x)) for embedding, pn in zip(self.embeddings, self.pns)], dim=-1) 64 | x_l1 = x_l1.mean(dim=-1) 65 | p = self.mlp(x_l1) 66 | return torch.sigmoid(p.squeeze(1)) 67 | -------------------------------------------------------------------------------- /torchfm/model/rdcnv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchfm.layer import FeaturesEmbedding, MultiLayerPerceptron 3 | 4 | 5 | class CriterionWithLoss(torch.nn.Module): 6 | 7 | def __init__(self, criterion): 8 | super().__init__() 9 | self.criterion = criterion 10 | 11 | def forward(self, input, target): 12 | input_real, losses = input[0], input[1:] 13 | return self.criterion(input_real, target) + sum(losses) 14 | 15 | 16 | class RestrictedCrossNetworkV2(torch.nn.Module): # layer 17 | 18 | def __init__(self, input_dim, num_layers, num_fields, embed_dim): 19 | """ 20 | input_dim: num_fields*embed_dim 21 | """ 22 | super().__init__() 23 | self.num_layers = num_layers 24 | self.num_fields = num_fields 25 | self.embed_dim = embed_dim 26 | self.input_dim = input_dim 27 | self.W = torch.nn.ParameterList([ 28 | torch.nn.Parameter(torch.Tensor(input_dim, input_dim)) for _ in range(num_layers) 29 | ]) 30 | self.b = torch.nn.ParameterList([ 31 | torch.nn.Parameter(torch.zeros(input_dim, )) for _ in range(num_layers) 32 | ]) 33 | for i in range(num_layers): 34 | torch.nn.init.xavier_uniform_(self.W[i]) 35 | 36 | def forward(self, x): 37 | """ 38 | x: 
Tensor of size ``(batch_size, num_fields*embed_dim)`` 39 | """ 40 | x0 = x 41 | reg_loss = 0 42 | for i in range(self.num_layers): 43 | x = x.unsqueeze(2) 44 | xw = torch.matmul(self.W[i], x) 45 | xw = xw.squeeze(2) 46 | x = x.squeeze(2) 47 | x = x0 * (xw + self.b[i]) + x 48 | # Calculate regularization 49 | field_wise_w = self.W[i].reshape(self.num_fields, self.embed_dim, self.num_fields, self.embed_dim) 50 | field_wise_w = field_wise_w.transpose(1, 2).reshape(-1, self.embed_dim, self.embed_dim) 51 | identities = torch.eye(self.embed_dim, device=field_wise_w.device).unsqueeze(0) 52 | reg_loss += torch.square(identities - torch.bmm(field_wise_w, field_wise_w.transpose(1, 2))).sum() 53 | return x, reg_loss 54 | 55 | 56 | class RestrictedCrossNetworkV2Model(torch.nn.Module): # model 57 | """ 58 | A pytorch implementation of Deep & Cross Network - V2. 59 | Only Cross Network, without deep network. 60 | Reference: 61 | R Wang, et al. DCN V2: Improved Deep & Cross Network for Feature Cross Learning in Web-scale Learning to Rank Systems, 2021 62 | """ 63 | def __init__(self, field_dims, embed_dim, num_layers, mlp_dims, dropout): 64 | super().__init__() 65 | self.num_fields = len(field_dims) 66 | self.embed_dim = embed_dim 67 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 68 | self.embed_output_dim = len(field_dims) * embed_dim 69 | self.cn = RestrictedCrossNetworkV2(self.embed_output_dim, num_layers, self.num_fields, embed_dim) 70 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout, output_layer=False) 71 | self.linear = torch.nn.Linear(mlp_dims[-1], 1) 72 | 73 | def forward(self, x): 74 | """ 75 | x: Tensor of size ``(batch_size, num_fields)`` 76 | self.embedding(x): Tensor of size ``(batch_size, num_fields, embed_dim)`` 77 | embed_x: Tensor of size ``(batch_size, num_fields*embed_dim)`` 78 | x: Tensor of size ``(batch_size, num_fields*embed_dim)`` 79 | """ 80 | embed_x = self.embedding(x).view(-1, self.embed_output_dim) 81 | x_l1, reg_loss = self.cn(embed_x) 82 | h_l2 = self.mlp(x_l1) 83 | p = self.linear(h_l2) 84 | if self.training: 85 | return torch.sigmoid(p.squeeze(1)), 1e-4 * reg_loss 86 | else: 87 | return torch.sigmoid(p.squeeze(1)) 88 | -------------------------------------------------------------------------------- /torchfm/model/temp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class ME_TE(BaseModel): 5 | def __init__(self, feature_columns, device, num_tasks, tasks, 6 | target_name, loss_fn, l2_reg_embedding, tower_hidden_units, dropout,batch_norm,**kwargs): 7 | super(ME_TE, self).__init__(feature_columns, device, num_tasks, tasks, target_name, loss_fn, l2_reg_embedding) 8 | self.input_dim = self.get_input_dim(feature_columns) 9 | self.embedding_dict = nn.ModuleList([create_embedding_matrix(feature_columns, device=self.device) for _ in range(num_tasks)]) 10 | if self.dense_feat_num>0: 11 | self.dense_feature_embedding = nn.ModuleList([nn.Linear(self.dense_feat_num, self.sparse_feat_dim) for _ in range(num_tasks)]) 12 | self.towers = nn.ModuleList([MLP_Layer(input_dim=self.input_dim, 13 | hidden_units=tower_hidden_units, 14 | output_dim = 1, 15 | dropout_rates=dropout, 16 | batch_norm=batch_norm) for i in range(self.num_tasks)]) 17 | self.out = nn.ModuleList([PredictionLayer(task) for task in self.tasks]) 18 | self.regularization_weight = [] 19 | self.rep_gate = nn.Parameter(torch.normal(mean=0., std=1e-4, size=(num_tasks, num_tasks)), requires_grad=True) 20 | 
self.add_regularization_weight(self.embedding_dict.parameters(), l2=l2_reg_embedding) 21 | def predict(self, inputs): 22 | input_embedding = [] 23 | if self.dense_feat_num>0: 24 | task_embedding = [self.get_embedding(inputs, self.embedding_dict[i], self.dense_feature_embedding[i]) for i in range(self.num_tasks)] 25 | else: 26 | task_embedding = [self.get_embedding(inputs, self.embedding_dict[i]) for i in range(self.num_tasks)] 27 | rep_gate = self.rep_gate 28 | for i in range(self.num_tasks): 29 | task_input = [] 30 | for j in range(self.num_tasks): 31 | if j != i : 32 | task_input.append(task_embedding[j].detach()) 33 | else: 34 | task_input.append(task_embedding[j]) 35 | task_input = torch.stack(task_input,dim=1) # (Batchsize, num_tasks, m*D) 36 | task_gate = rep_gate[i,:].view(1,-1,1) # (1,num_tasks,1) 37 | task_input = torch.multiply(task_input,task_gate).sum(1) # (B, m*D) 38 | input_embedding.append(task_input) 39 | if not self.training: 40 | self.cache['rep_gate'] = rep_gate 41 | # tower 42 | output = [] 43 | for i in range(self.num_tasks): 44 | tower_output = self.towers[i](input_embedding[i]) #(Batchsize, 1) 45 | tower_output = self.out[i](tower_output) 46 | output.append(tower_output) 47 | result = torch.cat(output,-1) 48 | return result -------------------------------------------------------------------------------- /torchfm/model/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FeaturesLinear, MultiLayerPerceptron, FeaturesEmbedding 4 | 5 | 6 | class WideAndDeepModel(torch.nn.Module): 7 | """ 8 | A pytorch implementation of wide and deep learning. 9 | 10 | Reference: 11 | HT Cheng, et al. Wide & Deep Learning for Recommender Systems, 2016. 12 | """ 13 | 14 | def __init__(self, field_dims, embed_dim, mlp_dims, dropout): 15 | super().__init__() 16 | self.linear = FeaturesLinear(field_dims) 17 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 18 | self.embed_output_dim = len(field_dims) * embed_dim 19 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout) 20 | 21 | def forward(self, x): 22 | """ 23 | :param x: Long tensor of size ``(batch_size, num_fields)`` 24 | """ 25 | embed_x = self.embedding(x) 26 | x = self.linear(x) + self.mlp(embed_x.view(-1, self.embed_output_dim)) 27 | return torch.sigmoid(x.squeeze(1)) 28 | -------------------------------------------------------------------------------- /torchfm/model/wd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FeaturesLinear, MultiLayerPerceptron, FeaturesEmbedding 4 | 5 | 6 | class WideAndDeepModel(torch.nn.Module): 7 | """ 8 | A pytorch implementation of wide and deep learning. 9 | 10 | Reference: 11 | HT Cheng, et al. Wide & Deep Learning for Recommender Systems, 2016. 
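    Example (a minimal sketch; sizes are illustrative). The wide part is the linear term over raw feature ids; the deep part is an MLP over the flattened embeddings::

        >>> import torch
        >>> from torchfm.model.wd import WideAndDeepModel
        >>> model = WideAndDeepModel(field_dims=[10, 20, 30], embed_dim=16, mlp_dims=(64, 32), dropout=0.2)
        >>> x = torch.randint(0, 10, (4, 3))
        >>> model(x).shape
        torch.Size([4])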
12 | """ 13 | 14 | def __init__(self, field_dims, embed_dim, mlp_dims, dropout): 15 | super().__init__() 16 | self.linear = FeaturesLinear(field_dims) 17 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 18 | self.embed_output_dim = len(field_dims) * embed_dim 19 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout) 20 | 21 | def forward(self, x): 22 | """ 23 | :param x: Long tensor of size ``(batch_size, num_fields)`` 24 | """ 25 | embed_x = self.embedding(x) 26 | x = self.linear(x) + self.mlp(embed_x.view(-1, self.embed_output_dim)) 27 | return torch.sigmoid(x.squeeze(1)) 28 | 29 | 30 | class WideAndDeepModelNoLinear(torch.nn.Module): 31 | """ 32 | A pytorch implementation of wide and deep learning. 33 | 34 | Reference: 35 | HT Cheng, et al. Wide & Deep Learning for Recommender Systems, 2016. 36 | """ 37 | 38 | def __init__(self, field_dims, embed_dim, mlp_dims, dropout): 39 | super().__init__() 40 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 41 | self.embed_output_dim = len(field_dims) * embed_dim 42 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout) 43 | 44 | def forward(self, x): 45 | """ 46 | :param x: Long tensor of size ``(batch_size, num_fields)`` 47 | """ 48 | embed_x = self.embedding(x) 49 | x = self.mlp(embed_x.view(-1, self.embed_output_dim)) 50 | return torch.sigmoid(x.squeeze(1)) 51 | -------------------------------------------------------------------------------- /torchfm/model/wrdcnv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchfm.layer import FeaturesEmbedding, MultiLayerPerceptron 3 | 4 | 5 | class CriterionWithLoss(torch.nn.Module): 6 | 7 | def __init__(self, criterion): 8 | super().__init__() 9 | self.criterion = criterion 10 | 11 | def forward(self, input, target): 12 | input_real, losses = input[0], input[1:] 13 | return self.criterion(input_real, target) + sum(losses) 14 | 15 | 16 | class WeightedRestrictedCrossNetworkV2(torch.nn.Module): # layer 17 | 18 | def __init__(self, input_dim, num_layers, num_fields, embed_dim): 19 | """ 20 | input_dim: num_fields*embed_dim 21 | """ 22 | super().__init__() 23 | self.num_layers = num_layers 24 | self.num_fields = num_fields 25 | self.embed_dim = embed_dim 26 | self.input_dim = input_dim 27 | self.N = torch.nn.ParameterList([ 28 | torch.nn.Parameter(torch.ones(num_fields, num_fields)) for _ in range(num_layers) 29 | ]) 30 | self.W = torch.nn.ParameterList([ 31 | torch.nn.Parameter(torch.Tensor(input_dim, input_dim)) for _ in range(num_layers) 32 | ]) 33 | self.b = torch.nn.ParameterList([ 34 | torch.nn.Parameter(torch.zeros(input_dim, )) for _ in range(num_layers) 35 | ]) 36 | for i in range(num_layers): 37 | torch.nn.init.xavier_uniform_(self.W[i]) 38 | 39 | def forward(self, x): 40 | """ 41 | x: Tensor of size ``(batch_size, num_fields*embed_dim)`` 42 | """ 43 | x0 = x 44 | reg_loss = 0 45 | for i in range(self.num_layers): 46 | x = x.unsqueeze(2) 47 | ni = self.N[i].unsqueeze(1).unsqueeze(3).repeat(1, self.embed_dim, 1, self.embed_dim).reshape(self.input_dim, self.input_dim) 48 | xw = torch.matmul(ni * self.W[i], x) 49 | xw = xw.squeeze(2) 50 | x = x.squeeze(2) 51 | x = x0 * (xw + self.b[i]) + x 52 | # Calculate regularization 53 | field_wise_w = self.W[i].reshape(self.num_fields, self.embed_dim, self.num_fields, self.embed_dim) 54 | field_wise_w = field_wise_w.transpose(1, 2).reshape(-1, self.embed_dim, self.embed_dim) 55 | identities = torch.eye(self.embed_dim, 
device=field_wise_w.device).unsqueeze(0) 56 | reg_loss += torch.square(identities - torch.bmm(field_wise_w, field_wise_w.transpose(1, 2))).sum() 57 | return x, reg_loss 58 | 59 | 60 | class WeightedRestrictedCrossNetworkV2Model(torch.nn.Module): # model 61 | """ 62 | A pytorch implementation of Deep & Cross Network - V2. 63 | Only Cross Network, without deep network. 64 | Reference: 65 | R Wang, et al. DCN V2: Improved Deep & Cross Network for Feature Cross Learning in Web-scale Learning to Rank Systems, 2021 66 | """ 67 | def __init__(self, field_dims, embed_dim, num_layers, mlp_dims, dropout): 68 | super().__init__() 69 | self.num_fields = len(field_dims) 70 | self.embed_dim = embed_dim 71 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 72 | self.embed_output_dim = len(field_dims) * embed_dim 73 | self.cn = WeightedRestrictedCrossNetworkV2(self.embed_output_dim, num_layers, self.num_fields, embed_dim) 74 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout, output_layer=False) 75 | self.linear = torch.nn.Linear(mlp_dims[-1], 1) 76 | 77 | def forward(self, x): 78 | """ 79 | x: Tensor of size ``(batch_size, num_fields)`` 80 | self.embedding(x): Tensor of size ``(batch_size, num_fields, embed_dim)`` 81 | embed_x: Tensor of size ``(batch_size, num_fields*embed_dim)`` 82 | x: Tensor of size ``(batch_size, num_fields*embed_dim)`` 83 | """ 84 | embed_x = self.embedding(x).view(-1, self.embed_output_dim) 85 | x_l1, reg_loss = self.cn(embed_x) 86 | h_l2 = self.mlp(x_l1) 87 | p = self.linear(h_l2) 88 | if self.training: 89 | return torch.sigmoid(p.squeeze(1)), 1e-5 * reg_loss 90 | else: 91 | return torch.sigmoid(p.squeeze(1)) 92 | 93 | 94 | class WeightedRestrictedMultiDCN(torch.nn.Module): 95 | def __init__(self, field_dims, embed_dims, num_layers, mlp_dims, dropout): 96 | super().__init__() 97 | self.num_fields = len(field_dims) 98 | self.embeddings = torch.nn.ModuleList([FeaturesEmbedding(field_dims, embed_dim) for embed_dim in embed_dims]) 99 | self.embed_output_dims = [len(field_dims) * embed_dim for embed_dim in embed_dims] 100 | self.cns = torch.nn.ModuleList([ 101 | WeightedRestrictedCrossNetworkV2(embed_output_dim, num_layers, self.num_fields, embed_dim) 102 | for embed_dim, embed_output_dim in zip(embed_dims, self.embed_output_dims) 103 | ]) 104 | self.projs = torch.nn.ModuleList([ 105 | MultiLayerPerceptron(embed_output_dim, mlp_dims[:1], dropout, output_layer=False) 106 | for embed_output_dim in self.embed_output_dims 107 | ]) 108 | self.mlp = MultiLayerPerceptron(mlp_dims[0], mlp_dims[1:], dropout) 109 | 110 | 111 | def forward(self, x): 112 | h_and_regs = [cn(embedding(x).view(-1, embed_output_dim)) 113 | for cn, embedding, embed_output_dim in zip(self.cns, self.embeddings, self.embed_output_dims)] 114 | x_l1 = torch.stack([proj(h_and_reg[0]) for proj, h_and_reg in zip(self.projs, h_and_regs)], dim=-1).mean(dim=-1) 115 | reg_loss = sum([h_and_reg[1] for h_and_reg in h_and_regs]) 116 | p = self.mlp(x_l1) 117 | if self.training: 118 | return torch.sigmoid(p.squeeze(1)), 1e-5 * reg_loss 119 | else: 120 | return torch.sigmoid(p.squeeze(1)) 121 | -------------------------------------------------------------------------------- /torchfm/model/xdfm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from torchfm.layer import CompressedInteractionNetwork, FeaturesEmbedding, FeaturesLinear, MultiLayerPerceptron 5 | 6 | 7 | class 
ExtremeDeepFactorizationMachineModel(torch.nn.Module): 8 | """ 9 | A pytorch implementation of xDeepFM. 10 | 11 | Reference: 12 | J Lian, et al. xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems, 2018. 13 | """ 14 | 15 | def __init__(self, field_dims, embed_dim, mlp_dims, dropout, cross_layer_sizes, split_half=True): 16 | super().__init__() 17 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 18 | self.embed_output_dim = len(field_dims) * embed_dim 19 | self.cin = CompressedInteractionNetwork(len(field_dims), cross_layer_sizes, split_half) 20 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout) 21 | self.linear = FeaturesLinear(field_dims) 22 | 23 | def forward(self, x): 24 | """ 25 | :param x: Long tensor of size ``(batch_size, num_fields)`` 26 | """ 27 | embed_x = self.embedding(x) 28 | x = self.linear(x) + self.cin(embed_x) + self.mlp(embed_x.view(-1, self.embed_output_dim)) 29 | return torch.sigmoid(x.squeeze(1)) 30 | 31 | 32 | class CIN(torch.nn.Module): 33 | """ 34 | CIN w/o final linear layer 35 | """ 36 | 37 | def __init__(self, input_dim, cross_layer_sizes): 38 | super().__init__() 39 | self.num_layers = len(cross_layer_sizes) 40 | self.conv_layers = torch.nn.ModuleList() 41 | self.output_dim = 0 42 | prev_dim = input_dim 43 | for i in range(self.num_layers): 44 | cross_layer_size = cross_layer_sizes[i] 45 | self.conv_layers.append(torch.nn.Conv1d(input_dim * prev_dim, cross_layer_size, 1, 46 | stride=1, dilation=1, bias=True)) 47 | # (B, N x C, K) -> (B, C', K) 48 | prev_dim = cross_layer_size 49 | self.output_dim += prev_dim 50 | 51 | def forward(self, x): 52 | """ 53 | :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)`` 54 | """ 55 | xs = list() 56 | x0, h = x.unsqueeze(2), x 57 | # x0: (B, N, 1, K) 58 | # h: (B, C, K) 59 | for i in range(self.num_layers): 60 | x = x0 * h.unsqueeze(1) 61 | batch_size, f0_dim, fin_dim, embed_dim = x.shape 62 | x = x.view(batch_size, f0_dim * fin_dim, embed_dim) 63 | x = F.relu(self.conv_layers[i](x)) 64 | h = x 65 | xs.append(x) 66 | return torch.sum(torch.cat(xs, dim=1), 2) 67 | 68 | 69 | class XDeepFM(torch.nn.Module): 70 | """ 71 | XDeepFM w/o Linear 72 | """ 73 | 74 | def __init__(self, field_dims, embed_dim, mlp_dims, dropout, cross_layer_sizes): 75 | super().__init__() 76 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 77 | self.embed_output_dim = len(field_dims) * embed_dim 78 | self.cin = CIN(len(field_dims), cross_layer_sizes) 79 | self.cin_post = torch.nn.Linear(self.cin.output_dim, 1) 80 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout) if mlp_dims else None 81 | 82 | def forward(self, x): 83 | """ 84 | :param x: Long tensor of size ``(batch_size, num_fields)`` 85 | """ 86 | embed_x = self.embedding(x) 87 | x = self.cin_post(self.cin(embed_x)) 88 | if self.mlp: 89 | x += self.mlp(embed_x.flatten(1)) 90 | return torch.sigmoid(x.squeeze(1)) 91 | 92 | 93 | class MultiXDeepFM(torch.nn.Module): 94 | 95 | def __init__(self, field_dims, embed_dims, mlp_dims, dropout, cross_layer_sizes): 96 | super().__init__() 97 | self.embeddings = torch.nn.ModuleList([FeaturesEmbedding(field_dims, embed_dim) for embed_dim in embed_dims]) 98 | self.cins = torch.nn.ModuleList([CIN(len(field_dims), cross_layer_sizes) for _ in range(len(embed_dims))]) 99 | self.cin_post = torch.nn.Linear(self.cins[0].output_dim, 1) 100 | self.mlps = torch.nn.ModuleList([ 101 | MultiLayerPerceptron(embed_dim * len(field_dims), (mlp_dims[0],), dropout, 
/utils.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import sys
  3 | import time
  4 | import torch
  5 | from typing import Optional, List
  6 | 
  7 | 
  8 | class CriterionWithLoss(torch.nn.Module):
  9 | 
 10 | def __init__(self, criterion):
 11 | super().__init__()
 12 | self.criterion = criterion
 13 | 
 14 | def forward(self, input, target):
 15 | input_real, losses = input[0], input[1:]
 16 | return self.criterion(input_real, target) + sum(losses)
 17 | 
 18 | 
 19 | class EarlyStopper(object):
 20 | 
 21 | def __init__(self, num_trials, save_paths):
 22 | self.num_trials = num_trials
 23 | self.trial_counter = 0
 24 | self.best_accuracy = 0
 25 | self.save_paths = save_paths
 26 | 
 27 | def is_continuable(self, models, accuracy):
 28 | if accuracy > self.best_accuracy:
 29 | self.best_accuracy = accuracy
 30 | self.trial_counter = 0
 31 | for model, save_path in zip(models, self.save_paths):
 32 | torch.save(model.state_dict(), save_path)
 33 | return True
 34 | elif self.trial_counter + 1 < self.num_trials:
 35 | self.trial_counter += 1
 36 | return True
 37 | else:
 38 | return False
 39 | 
 40 | 
 41 | class AverageMeter(object):
 42 | r"""Computes and stores the average and current value.
 43 | 
 44 | Examples::
 45 | 
 46 | >>> # Initialize a meter to record loss
 47 | >>> losses = AverageMeter('loss')
 48 | >>> # Update meter after every minibatch update
 49 | >>> losses.update(loss_value, batch_size)
 50 | """
 51 | def __init__(self, name: str, fmt: Optional[str] = ':f'):
 52 | self.name = name
 53 | self.fmt = fmt
 54 | self.reset()
 55 | 
 56 | def reset(self):
 57 | self.val = 0
 58 | self.avg = 0
 59 | self.sum = 0
 60 | self.count = 0
 61 | 
 62 | def update(self, val, n=1):
 63 | self.val = val
 64 | self.sum += val * n
 65 | self.count += n
 66 | if self.count > 0:
 67 | self.avg = self.sum / self.count
 68 | 
 69 | def __str__(self):
 70 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
 71 | return fmtstr.format(**self.__dict__)
 72 | 
 73 | 
 74 | class AverageMeterDict(object):
 75 | def __init__(self, names: List, fmt: Optional[str] = ':f'):
 76 | self.dict = {
 77 | name: AverageMeter(name, fmt) for name in names
 78 | }
 79 | 
 80 | def reset(self):
 81 | for meter in self.dict.values():
 82 | meter.reset()
 83 | 
 84 | def update(self, accuracies, n=1):
 85 | for name, acc in accuracies.items():
 86 | self.dict[name].update(acc, n)
 87 | 
 88 | def average(self):
 89 | return {
 90 | name: meter.avg for name, meter in self.dict.items()
 91 | }
 92 | 
 93 | def __getitem__(self, item):
 94 | return self.dict[item]
 95 | 
 96 | 
 97 | class Meter(object):
 98 | """Computes and stores the current value."""
 99 | def __init__(self, name: str, fmt: Optional[str] = ':f'):
100 | self.name = name
101 | self.fmt = fmt
102 | self.reset()
103 | 
104 | def reset(self):
105 | self.val = 0
106 | 
107 | def update(self, val):
108 | self.val = val
109 | 
110 | def __str__(self):
111 | fmtstr = '{name} {val' + self.fmt + '}'
112 | return fmtstr.format(**self.__dict__)
113 | 
114 | 
115 | class ProgressMeter(object):
116 | def __init__(self, num_batches, meters, prefix=""):
117 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
118 | self.meters = meters
119 | self.prefix = prefix
120 | 
121 | def display(self, batch):
122 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
123 | entries += [str(meter) for meter in self.meters]
124 | print('\t'.join(entries))
125 | 
126 | def _get_batch_fmtstr(self, num_batches):
127 | num_digits = len(str(num_batches))
128 | fmt = '{:' + str(num_digits) + 'd}'
129 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
130 | 
131 | 
132 | class TextLogger(object):
133 | """Writes stream output to an external text file.
134 | 
135 | Args:
136 | filename (str): the file to write stream output
137 | stream: the stream to read from. Default: sys.stdout
138 | """
139 | def __init__(self, filename, stream=sys.stdout):
140 | self.terminal = stream
141 | self.log = open(filename, 'a')
142 | 
143 | def write(self, message):
144 | self.terminal.write(message)
145 | self.log.write(message)
146 | self.flush()
147 | 
148 | def flush(self):
149 | self.terminal.flush()
150 | self.log.flush()
151 | 
152 | def close(self):
153 | self.terminal.close()
154 | self.log.close()
155 | 
156 | 
157 | class CompleteLogger:
158 | """
159 | A useful logger that
160 | 
161 | - writes outputs to files and displays them on the console at the same time.
162 | - manages the directory of checkpoints and debugging images.
163 | 
164 | Args:
165 | root (str): the root directory of the logger
166 | phase (str): the phase of training.
167 | 
168 | """
169 | 
170 | def __init__(self, root, phase='train'):
171 | self.root = root
172 | self.phase = phase
173 | self.visualize_directory = os.path.join(self.root, "visualize")
174 | self.checkpoint_directory = os.path.join(self.root, "checkpoints")
175 | self.epoch = 0
176 | 
177 | os.makedirs(self.root, exist_ok=True)
178 | os.makedirs(self.visualize_directory, exist_ok=True)
179 | os.makedirs(self.checkpoint_directory, exist_ok=True)
180 | 
181 | # redirect std out
182 | now = time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime(time.time()))
183 | log_filename = os.path.join(self.root, "{}-{}.txt".format(phase, now))
184 | if os.path.exists(log_filename):
185 | os.remove(log_filename)
186 | self.logger = TextLogger(log_filename)
187 | sys.stdout = self.logger
188 | sys.stderr = self.logger
189 | if phase != 'train':
190 | self.set_epoch(phase)
191 | 
192 | def set_epoch(self, epoch):
193 | """Set the epoch number. Please use it during training."""
194 | os.makedirs(os.path.join(self.visualize_directory, str(epoch)), exist_ok=True)
195 | self.epoch = epoch
196 | 
197 | def _get_phase_or_epoch(self):
198 | if self.phase == 'train':
199 | return str(self.epoch)
200 | else:
201 | return self.phase
202 | 
203 | def get_image_path(self, filename: str):
204 | """
205 | Get the full image path for a specific filename
206 | """
207 | return os.path.join(self.visualize_directory, self._get_phase_or_epoch(), filename)
208 | 
209 | def get_checkpoint_path(self, name=None):
210 | """
211 | Get the full checkpoint path.
212 | 
213 | Args:
214 | name (optional): the filename (without file extension) to save the checkpoint.
215 | If None, when the phase is ``train``, the checkpoint will be saved to ``{epoch}.pth``.
216 | Otherwise, it will be saved to ``{phase}.pth``.
217 | 
218 | """
219 | if name is None:
220 | name = self._get_phase_or_epoch()
221 | name = str(name)
222 | return os.path.join(self.checkpoint_directory, name + ".pth")
223 | 
224 | def close(self):
225 | self.logger.close()
226 | 
--------------------------------------------------------------------------------
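Finally, a hypothetical wiring sketch, not a file in this repository, showing how the pieces above compose: a multi-embedding model that returns `(prediction, aux_loss)` in training mode, `CriterionWithLoss` to fold the auxiliary term into the objective, and `EarlyStopper`/`AverageMeter` from utils.py. The import path (assuming `WeightedRestrictedMultiDCN` lives in `torchfm/model/wrdcnv2.py`, as the tree suggests), the checkpoint name, and all hyper-parameters are illustrative assumptions.

```python
# Hypothetical end-to-end sketch under the assumptions stated above.
import torch
from torchfm.model.wrdcnv2 import WeightedRestrictedMultiDCN
from utils import CriterionWithLoss, EarlyStopper, AverageMeter

field_dims = [10, 20, 30]                                    # toy per-field vocabulary sizes
model = WeightedRestrictedMultiDCN(field_dims, embed_dims=(16, 16), num_layers=2,
                                   mlp_dims=(64, 32), dropout=0.2)
criterion = CriterionWithLoss(torch.nn.BCELoss())            # adds the returned reg term to the BCE loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
stopper = EarlyStopper(num_trials=3, save_paths=['ckpt.pth'])
loss_meter = AverageMeter('loss')

x = torch.randint(0, 10, (8, len(field_dims)))               # fake categorical batch
y = torch.randint(0, 2, (8,)).float()

model.train()
output = model(x)                                            # (prediction, 1e-5 * reg_loss)
loss = criterion(output, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_meter.update(loss.item(), x.size(0))
print(loss_meter)

# After each validation pass, keep training only while the metric improves:
val_auc = 0.5                                                # placeholder validation metric
if not stopper.is_continuable([model], val_auc):
    print('early stop; best =', stopper.best_accuracy)
```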