├── LICENSE ├── README.md ├── docs ├── .nojekyll ├── Makefile ├── _sources │ ├── index.rst.txt │ ├── torchfm.dataset.rst.txt │ ├── torchfm.model.rst.txt │ └── torchfm.rst.txt ├── _static │ ├── basic.css │ ├── css │ │ ├── badge_only.css │ │ └── theme.css │ ├── doctools.js │ ├── documentation_options.js │ ├── file.png │ ├── fonts │ │ ├── Inconsolata-Bold.ttf │ │ ├── Inconsolata-Regular.ttf │ │ ├── Inconsolata.ttf │ │ ├── Lato-Bold.ttf │ │ ├── Lato-Regular.ttf │ │ ├── Lato │ │ │ ├── lato-bold.eot │ │ │ ├── lato-bold.ttf │ │ │ ├── lato-bold.woff │ │ │ ├── lato-bold.woff2 │ │ │ ├── lato-bolditalic.eot │ │ │ ├── lato-bolditalic.ttf │ │ │ ├── lato-bolditalic.woff │ │ │ ├── lato-bolditalic.woff2 │ │ │ ├── lato-italic.eot │ │ │ ├── lato-italic.ttf │ │ │ ├── lato-italic.woff │ │ │ ├── lato-italic.woff2 │ │ │ ├── lato-regular.eot │ │ │ ├── lato-regular.ttf │ │ │ ├── lato-regular.woff │ │ │ └── lato-regular.woff2 │ │ ├── RobotoSlab-Bold.ttf │ │ ├── RobotoSlab-Regular.ttf │ │ ├── RobotoSlab │ │ │ ├── roboto-slab-v7-bold.eot │ │ │ ├── roboto-slab-v7-bold.ttf │ │ │ ├── roboto-slab-v7-bold.woff │ │ │ ├── roboto-slab-v7-bold.woff2 │ │ │ ├── roboto-slab-v7-regular.eot │ │ │ ├── roboto-slab-v7-regular.ttf │ │ │ ├── roboto-slab-v7-regular.woff │ │ │ └── roboto-slab-v7-regular.woff2 │ │ ├── fontawesome-webfont.eot │ │ ├── fontawesome-webfont.svg │ │ ├── fontawesome-webfont.ttf │ │ ├── fontawesome-webfont.woff │ │ └── fontawesome-webfont.woff2 │ ├── jquery-3.2.1.js │ ├── jquery.js │ ├── js │ │ ├── modernizr.min.js │ │ └── theme.js │ ├── language_data.js │ ├── minus.png │ ├── plus.png │ ├── pygments.css │ ├── searchtools.js │ ├── underscore-1.3.1.js │ └── underscore.js ├── conf.py ├── genindex.html ├── index.html ├── index.rst ├── make.bat ├── objects.inv ├── py-modindex.html ├── search.html ├── searchindex.js ├── torchfm.dataset.html ├── torchfm.dataset.rst ├── torchfm.html ├── torchfm.model.html ├── torchfm.model.rst └── torchfm.rst ├── examples └── main.py ├── requirements.txt ├── 
setup.py ├── test └── test_layers.py └── torchfm ├── __init__.py ├── dataset ├── __init__.py ├── avazu.py ├── criteo.py └── movielens.py ├── layer.py └── model ├── __init__.py ├── afi.py ├── afm.py ├── afn.py ├── dcn.py ├── dfm.py ├── ffm.py ├── fm.py ├── fnfm.py ├── fnn.py ├── hofm.py ├── lr.py ├── ncf.py ├── nfm.py ├── pnn.py ├── wd.py └── xdfm.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 rixwew 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Factorization Machine models in PyTorch 2 | 3 | This package provides a PyTorch implementation of factorization machine models and common datasets in CTR prediction. 
4 | 5 | 6 | ## Available Datasets 7 | 8 | * [MovieLens Dataset](https://grouplens.org/datasets/movielens) 9 | * [Criteo Display Advertising Challenge](https://www.kaggle.com/c/criteo-display-ad-challenge) 10 | * [Avazu Click-Through Rate Prediction](https://www.kaggle.com/c/avazu-ctr-prediction) 11 | 12 | 13 | ## Available Models 14 | 15 | | Model | Reference | 16 | |-------|-----------| 17 | | Logistic Regression | | 18 | | Factorization Machine | [S Rendle, Factorization Machines, 2010.](https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf) | 19 | | Field-aware Factorization Machine | [Y Juan, et al. Field-aware Factorization Machines for CTR Prediction, 2015.](https://www.csie.ntu.edu.tw/~cjlin/papers/ffm.pdf) | 20 | | Higher-Order Factorization Machines | [ M Blondel, et al. Higher-Order Factorization Machines, 2016.](https://dl.acm.org/doi/10.5555/3157382.3157473) | 21 | | Factorization-Supported Neural Network | [W Zhang, et al. Deep Learning over Multi-field Categorical Data - A Case Study on User Response Prediction, 2016.](https://arxiv.org/abs/1601.02376) | 22 | | Wide&Deep | [HT Cheng, et al. Wide & Deep Learning for Recommender Systems, 2016.](https://arxiv.org/abs/1606.07792) | 23 | | Attentional Factorization Machine | [J Xiao, et al. Attentional Factorization Machines: Learning the Weight of Feature Interactions via Attention Networks, 2017.](https://arxiv.org/abs/1708.04617) | 24 | | Neural Factorization Machine | [X He and TS Chua, Neural Factorization Machines for Sparse Predictive Analytics, 2017.](https://arxiv.org/abs/1708.05027) | 25 | | Neural Collaborative Filtering | [X He, et al. Neural Collaborative Filtering, 2017.](https://arxiv.org/abs/1708.05031) | 26 | | Field-aware Neural Factorization Machine | [L Zhang, et al. Field-aware Neural Factorization Machine for Click-Through Rate Prediction, 2019.](https://arxiv.org/abs/1902.09096) | 27 | | Product Neural Network | [Y Qu, et al. 
Product-based Neural Networks for User Response Prediction, 2016.](https://arxiv.org/abs/1611.00144) | 28 | | Deep Cross Network | [R Wang, et al. Deep & Cross Network for Ad Click Predictions, 2017.](https://arxiv.org/abs/1708.05123) | 29 | | DeepFM | [H Guo, et al. DeepFM: A Factorization-Machine based Neural Network for CTR Prediction, 2017.](https://arxiv.org/abs/1703.04247) | 30 | | xDeepFM | [J Lian, et al. xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems, 2018.](https://arxiv.org/abs/1803.05170) | 31 | | AutoInt (Automatic Feature Interaction Model) | [W Song, et al. AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks, 2018.](https://arxiv.org/abs/1810.11921) | 32 | | AFN (Adaptive Factorization Network) | [W Cheng, et al. Adaptive Factorization Network: Learning Adaptive-Order Feature Interactions, AAAI'20.](https://arxiv.org/pdf/1909.03276.pdf) | 33 | 34 | Each model achieves an AUC of about 0.80 on the Criteo dataset and about 0.78 on the Avazu dataset (see the [example code](examples/main.py)). 35 | 36 | 37 | ## Installation 38 | 39 | pip install torchfm 40 | 41 | 42 | ## API Documentation 43 | 44 | https://rixwew.github.io/pytorch-fm 45 | 46 | 47 | ## License 48 | 49 | MIT 50 | -------------------------------------------------------------------------------- /docs/.nojekyll: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = . 8 | BUILDDIR = _build 9 | 10 | # Put it first so that "make" without argument is like "make help".
11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/_sources/index.rst.txt: -------------------------------------------------------------------------------- 1 | pytorch-fm 2 | =========== 3 | 4 | *Factorization Machine models in PyTorch.* 5 | 6 | This package provides an implementation of various factorization machine models and common datasets in PyTorch. 7 | 8 | 9 | 10 | Minimal requirements 11 | ==================== 12 | 13 | * Python 3.x 14 | * PyTorch 1.1.0 15 | * numpy 16 | * lmdb 17 | 18 | Installation 19 | ============ 20 | 21 | Install with pip:: 22 | 23 | pip install torchfm 24 | 25 | 26 | API documentation 27 | ================= 28 | 29 | .. toctree:: 30 | torchfm 31 | 32 | 33 | Indices and tables 34 | ================== 35 | 36 | * :ref:`genindex` 37 | * :ref:`modindex` 38 | * :ref:`search` 39 | -------------------------------------------------------------------------------- /docs/_sources/torchfm.dataset.rst.txt: -------------------------------------------------------------------------------- 1 | torchfm.dataset 2 | ======================= 3 | 4 | torchfm.dataset.avazu 5 | ---------------------------- 6 | 7 | .. automodule:: torchfm.dataset.avazu 8 | :members: 9 | 10 | torchfm.dataset.criteo 11 | ----------------------------- 12 | 13 | .. automodule:: torchfm.dataset.criteo 14 | :members: 15 | 16 | torchfm.dataset.movielens 17 | -------------------------------- 18 | 19 | .. 
automodule:: torchfm.dataset.movielens 20 | :members: 21 | -------------------------------------------------------------------------------- /docs/_sources/torchfm.model.rst.txt: -------------------------------------------------------------------------------- 1 | torchfm.model 2 | ===================== 3 | 4 | torchfm.model.afi 5 | ------------------------ 6 | 7 | .. automodule:: torchfm.model.afi 8 | :members: 9 | 10 | torchfm.model.afm 11 | ------------------------ 12 | 13 | .. automodule:: torchfm.model.afm 14 | :members: 15 | 16 | torchfm.model.dcn 17 | ------------------------ 18 | 19 | .. automodule:: torchfm.model.dcn 20 | :members: 21 | 22 | torchfm.model.dfm 23 | ------------------------ 24 | 25 | .. automodule:: torchfm.model.dfm 26 | :members: 27 | 28 | torchfm.model.ffm 29 | ------------------------ 30 | 31 | .. automodule:: torchfm.model.ffm 32 | :members: 33 | 34 | torchfm.model.fm 35 | ----------------------- 36 | 37 | .. automodule:: torchfm.model.fm 38 | :members: 39 | 40 | torchfm.model.fnfm 41 | ------------------------- 42 | 43 | .. automodule:: torchfm.model.fnfm 44 | :members: 45 | 46 | torchfm.model.fnn 47 | ------------------------ 48 | 49 | .. automodule:: torchfm.model.fnn 50 | :members: 51 | 52 | torchfm.model.lr 53 | ----------------------- 54 | 55 | .. automodule:: torchfm.model.lr 56 | :members: 57 | 58 | torchfm.model.nfm 59 | ------------------------ 60 | 61 | .. automodule:: torchfm.model.nfm 62 | :members: 63 | 64 | torchfm.model.pnn 65 | ------------------------ 66 | 67 | .. automodule:: torchfm.model.pnn 68 | :members: 69 | 70 | torchfm.model.wd 71 | ----------------------- 72 | 73 | .. automodule:: torchfm.model.wd 74 | :members: 75 | 76 | torchfm.model.xdfm 77 | ------------------------- 78 | 79 | .. 
automodule:: torchfm.model.xdfm 80 | :members: 81 | -------------------------------------------------------------------------------- /docs/_sources/torchfm.rst.txt: -------------------------------------------------------------------------------- 1 | torchfm package 2 | =============== 3 | 4 | 5 | 6 | .. toctree:: 7 | 8 | torchfm.dataset 9 | torchfm.model 10 | 11 | 12 | torchfm.layer 13 | -------------------- 14 | 15 | .. automodule:: torchfm.layer 16 | :members: 17 | -------------------------------------------------------------------------------- /docs/_static/basic.css: -------------------------------------------------------------------------------- 1 | /* 2 | * basic.css 3 | * ~~~~~~~~~ 4 | * 5 | * Sphinx stylesheet -- basic theme. 6 | * 7 | * :copyright: Copyright 2007-2019 by the Sphinx team, see AUTHORS. 8 | * :license: BSD, see LICENSE for details. 9 | * 10 | */ 11 | 12 | /* -- main layout ----------------------------------------------------------- */ 13 | 14 | div.clearer { 15 | clear: both; 16 | } 17 | 18 | /* -- relbar ---------------------------------------------------------------- */ 19 | 20 | div.related { 21 | width: 100%; 22 | font-size: 90%; 23 | } 24 | 25 | div.related h3 { 26 | display: none; 27 | } 28 | 29 | div.related ul { 30 | margin: 0; 31 | padding: 0 0 0 10px; 32 | list-style: none; 33 | } 34 | 35 | div.related li { 36 | display: inline; 37 | } 38 | 39 | div.related li.right { 40 | float: right; 41 | margin-right: 5px; 42 | } 43 | 44 | /* -- sidebar --------------------------------------------------------------- */ 45 | 46 | div.sphinxsidebarwrapper { 47 | padding: 10px 5px 0 10px; 48 | } 49 | 50 | div.sphinxsidebar { 51 | float: left; 52 | width: 230px; 53 | margin-left: -100%; 54 | font-size: 90%; 55 | word-wrap: break-word; 56 | overflow-wrap : break-word; 57 | } 58 | 59 | div.sphinxsidebar ul { 60 | list-style: none; 61 | } 62 | 63 | div.sphinxsidebar ul ul, 64 | div.sphinxsidebar ul.want-points { 65 | margin-left: 20px; 66 | list-style: 
square; 67 | } 68 | 69 | div.sphinxsidebar ul ul { 70 | margin-top: 0; 71 | margin-bottom: 0; 72 | } 73 | 74 | div.sphinxsidebar form { 75 | margin-top: 10px; 76 | } 77 | 78 | div.sphinxsidebar input { 79 | border: 1px solid #98dbcc; 80 | font-family: sans-serif; 81 | font-size: 1em; 82 | } 83 | 84 | div.sphinxsidebar #searchbox form.search { 85 | overflow: hidden; 86 | } 87 | 88 | div.sphinxsidebar #searchbox input[type="text"] { 89 | float: left; 90 | width: 80%; 91 | padding: 0.25em; 92 | box-sizing: border-box; 93 | } 94 | 95 | div.sphinxsidebar #searchbox input[type="submit"] { 96 | float: left; 97 | width: 20%; 98 | border-left: none; 99 | padding: 0.25em; 100 | box-sizing: border-box; 101 | } 102 | 103 | 104 | img { 105 | border: 0; 106 | max-width: 100%; 107 | } 108 | 109 | /* -- search page ----------------------------------------------------------- */ 110 | 111 | ul.search { 112 | margin: 10px 0 0 20px; 113 | padding: 0; 114 | } 115 | 116 | ul.search li { 117 | padding: 5px 0 5px 20px; 118 | background-image: url(file.png); 119 | background-repeat: no-repeat; 120 | background-position: 0 7px; 121 | } 122 | 123 | ul.search li a { 124 | font-weight: bold; 125 | } 126 | 127 | ul.search li div.context { 128 | color: #888; 129 | margin: 2px 0 0 30px; 130 | text-align: left; 131 | } 132 | 133 | ul.keywordmatches li.goodmatch a { 134 | font-weight: bold; 135 | } 136 | 137 | /* -- index page ------------------------------------------------------------ */ 138 | 139 | table.contentstable { 140 | width: 90%; 141 | margin-left: auto; 142 | margin-right: auto; 143 | } 144 | 145 | table.contentstable p.biglink { 146 | line-height: 150%; 147 | } 148 | 149 | a.biglink { 150 | font-size: 1.3em; 151 | } 152 | 153 | span.linkdescr { 154 | font-style: italic; 155 | padding-top: 5px; 156 | font-size: 90%; 157 | } 158 | 159 | /* -- general index --------------------------------------------------------- */ 160 | 161 | table.indextable { 162 | width: 100%; 163 | } 164 | 165 | 
table.indextable td { 166 | text-align: left; 167 | vertical-align: top; 168 | } 169 | 170 | table.indextable ul { 171 | margin-top: 0; 172 | margin-bottom: 0; 173 | list-style-type: none; 174 | } 175 | 176 | table.indextable > tbody > tr > td > ul { 177 | padding-left: 0em; 178 | } 179 | 180 | table.indextable tr.pcap { 181 | height: 10px; 182 | } 183 | 184 | table.indextable tr.cap { 185 | margin-top: 10px; 186 | background-color: #f2f2f2; 187 | } 188 | 189 | img.toggler { 190 | margin-right: 3px; 191 | margin-top: 3px; 192 | cursor: pointer; 193 | } 194 | 195 | div.modindex-jumpbox { 196 | border-top: 1px solid #ddd; 197 | border-bottom: 1px solid #ddd; 198 | margin: 1em 0 1em 0; 199 | padding: 0.4em; 200 | } 201 | 202 | div.genindex-jumpbox { 203 | border-top: 1px solid #ddd; 204 | border-bottom: 1px solid #ddd; 205 | margin: 1em 0 1em 0; 206 | padding: 0.4em; 207 | } 208 | 209 | /* -- domain module index --------------------------------------------------- */ 210 | 211 | table.modindextable td { 212 | padding: 2px; 213 | border-collapse: collapse; 214 | } 215 | 216 | /* -- general body styles --------------------------------------------------- */ 217 | 218 | div.body { 219 | min-width: 450px; 220 | max-width: 800px; 221 | } 222 | 223 | div.body p, div.body dd, div.body li, div.body blockquote { 224 | -moz-hyphens: auto; 225 | -ms-hyphens: auto; 226 | -webkit-hyphens: auto; 227 | hyphens: auto; 228 | } 229 | 230 | a.headerlink { 231 | visibility: hidden; 232 | } 233 | 234 | a.brackets:before, 235 | span.brackets > a:before{ 236 | content: "["; 237 | } 238 | 239 | a.brackets:after, 240 | span.brackets > a:after { 241 | content: "]"; 242 | } 243 | 244 | h1:hover > a.headerlink, 245 | h2:hover > a.headerlink, 246 | h3:hover > a.headerlink, 247 | h4:hover > a.headerlink, 248 | h5:hover > a.headerlink, 249 | h6:hover > a.headerlink, 250 | dt:hover > a.headerlink, 251 | caption:hover > a.headerlink, 252 | p.caption:hover > a.headerlink, 253 | 
div.code-block-caption:hover > a.headerlink { 254 | visibility: visible; 255 | } 256 | 257 | div.body p.caption { 258 | text-align: inherit; 259 | } 260 | 261 | div.body td { 262 | text-align: left; 263 | } 264 | 265 | .first { 266 | margin-top: 0 !important; 267 | } 268 | 269 | p.rubric { 270 | margin-top: 30px; 271 | font-weight: bold; 272 | } 273 | 274 | img.align-left, .figure.align-left, object.align-left { 275 | clear: left; 276 | float: left; 277 | margin-right: 1em; 278 | } 279 | 280 | img.align-right, .figure.align-right, object.align-right { 281 | clear: right; 282 | float: right; 283 | margin-left: 1em; 284 | } 285 | 286 | img.align-center, .figure.align-center, object.align-center { 287 | display: block; 288 | margin-left: auto; 289 | margin-right: auto; 290 | } 291 | 292 | .align-left { 293 | text-align: left; 294 | } 295 | 296 | .align-center { 297 | text-align: center; 298 | } 299 | 300 | .align-right { 301 | text-align: right; 302 | } 303 | 304 | /* -- sidebars -------------------------------------------------------------- */ 305 | 306 | div.sidebar { 307 | margin: 0 0 0.5em 1em; 308 | border: 1px solid #ddb; 309 | padding: 7px 7px 0 7px; 310 | background-color: #ffe; 311 | width: 40%; 312 | float: right; 313 | } 314 | 315 | p.sidebar-title { 316 | font-weight: bold; 317 | } 318 | 319 | /* -- topics ---------------------------------------------------------------- */ 320 | 321 | div.topic { 322 | border: 1px solid #ccc; 323 | padding: 7px 7px 0 7px; 324 | margin: 10px 0 10px 0; 325 | } 326 | 327 | p.topic-title { 328 | font-size: 1.1em; 329 | font-weight: bold; 330 | margin-top: 10px; 331 | } 332 | 333 | /* -- admonitions ----------------------------------------------------------- */ 334 | 335 | div.admonition { 336 | margin-top: 10px; 337 | margin-bottom: 10px; 338 | padding: 7px; 339 | } 340 | 341 | div.admonition dt { 342 | font-weight: bold; 343 | } 344 | 345 | div.admonition dl { 346 | margin-bottom: 0; 347 | } 348 | 349 | p.admonition-title { 
350 | margin: 0px 10px 5px 0px; 351 | font-weight: bold; 352 | } 353 | 354 | div.body p.centered { 355 | text-align: center; 356 | margin-top: 25px; 357 | } 358 | 359 | /* -- tables ---------------------------------------------------------------- */ 360 | 361 | table.docutils { 362 | border: 0; 363 | border-collapse: collapse; 364 | } 365 | 366 | table.align-center { 367 | margin-left: auto; 368 | margin-right: auto; 369 | } 370 | 371 | table caption span.caption-number { 372 | font-style: italic; 373 | } 374 | 375 | table caption span.caption-text { 376 | } 377 | 378 | table.docutils td, table.docutils th { 379 | padding: 1px 8px 1px 5px; 380 | border-top: 0; 381 | border-left: 0; 382 | border-right: 0; 383 | border-bottom: 1px solid #aaa; 384 | } 385 | 386 | table.footnote td, table.footnote th { 387 | border: 0 !important; 388 | } 389 | 390 | th { 391 | text-align: left; 392 | padding-right: 5px; 393 | } 394 | 395 | table.citation { 396 | border-left: solid 1px gray; 397 | margin-left: 1px; 398 | } 399 | 400 | table.citation td { 401 | border-bottom: none; 402 | } 403 | 404 | th > p:first-child, 405 | td > p:first-child { 406 | margin-top: 0px; 407 | } 408 | 409 | th > p:last-child, 410 | td > p:last-child { 411 | margin-bottom: 0px; 412 | } 413 | 414 | /* -- figures --------------------------------------------------------------- */ 415 | 416 | div.figure { 417 | margin: 0.5em; 418 | padding: 0.5em; 419 | } 420 | 421 | div.figure p.caption { 422 | padding: 0.3em; 423 | } 424 | 425 | div.figure p.caption span.caption-number { 426 | font-style: italic; 427 | } 428 | 429 | div.figure p.caption span.caption-text { 430 | } 431 | 432 | /* -- field list styles ----------------------------------------------------- */ 433 | 434 | table.field-list td, table.field-list th { 435 | border: 0 !important; 436 | } 437 | 438 | .field-list ul { 439 | margin: 0; 440 | padding-left: 1em; 441 | } 442 | 443 | .field-list p { 444 | margin: 0; 445 | } 446 | 447 | .field-name { 448 | 
-moz-hyphens: manual; 449 | -ms-hyphens: manual; 450 | -webkit-hyphens: manual; 451 | hyphens: manual; 452 | } 453 | 454 | /* -- hlist styles ---------------------------------------------------------- */ 455 | 456 | table.hlist td { 457 | vertical-align: top; 458 | } 459 | 460 | 461 | /* -- other body styles ----------------------------------------------------- */ 462 | 463 | ol.arabic { 464 | list-style: decimal; 465 | } 466 | 467 | ol.loweralpha { 468 | list-style: lower-alpha; 469 | } 470 | 471 | ol.upperalpha { 472 | list-style: upper-alpha; 473 | } 474 | 475 | ol.lowerroman { 476 | list-style: lower-roman; 477 | } 478 | 479 | ol.upperroman { 480 | list-style: upper-roman; 481 | } 482 | 483 | li > p:first-child { 484 | margin-top: 0px; 485 | } 486 | 487 | li > p:last-child { 488 | margin-bottom: 0px; 489 | } 490 | 491 | dl.footnote > dt, 492 | dl.citation > dt { 493 | float: left; 494 | } 495 | 496 | dl.footnote > dd, 497 | dl.citation > dd { 498 | margin-bottom: 0em; 499 | } 500 | 501 | dl.footnote > dd:after, 502 | dl.citation > dd:after { 503 | content: ""; 504 | clear: both; 505 | } 506 | 507 | dl.field-list { 508 | display: flex; 509 | flex-wrap: wrap; 510 | } 511 | 512 | dl.field-list > dt { 513 | flex-basis: 20%; 514 | font-weight: bold; 515 | word-break: break-word; 516 | } 517 | 518 | dl.field-list > dt:after { 519 | content: ":"; 520 | } 521 | 522 | dl.field-list > dd { 523 | flex-basis: 70%; 524 | padding-left: 1em; 525 | margin-left: 0em; 526 | margin-bottom: 0em; 527 | } 528 | 529 | dl { 530 | margin-bottom: 15px; 531 | } 532 | 533 | dd > p:first-child { 534 | margin-top: 0px; 535 | } 536 | 537 | dd ul, dd table { 538 | margin-bottom: 10px; 539 | } 540 | 541 | dd { 542 | margin-top: 3px; 543 | margin-bottom: 10px; 544 | margin-left: 30px; 545 | } 546 | 547 | dt:target, span.highlighted { 548 | background-color: #fbe54e; 549 | } 550 | 551 | rect.highlighted { 552 | fill: #fbe54e; 553 | } 554 | 555 | dl.glossary dt { 556 | font-weight: bold; 557 | 
font-size: 1.1em; 558 | } 559 | 560 | .optional { 561 | font-size: 1.3em; 562 | } 563 | 564 | .sig-paren { 565 | font-size: larger; 566 | } 567 | 568 | .versionmodified { 569 | font-style: italic; 570 | } 571 | 572 | .system-message { 573 | background-color: #fda; 574 | padding: 5px; 575 | border: 3px solid red; 576 | } 577 | 578 | .footnote:target { 579 | background-color: #ffa; 580 | } 581 | 582 | .line-block { 583 | display: block; 584 | margin-top: 1em; 585 | margin-bottom: 1em; 586 | } 587 | 588 | .line-block .line-block { 589 | margin-top: 0; 590 | margin-bottom: 0; 591 | margin-left: 1.5em; 592 | } 593 | 594 | .guilabel, .menuselection { 595 | font-family: sans-serif; 596 | } 597 | 598 | .accelerator { 599 | text-decoration: underline; 600 | } 601 | 602 | .classifier { 603 | font-style: oblique; 604 | } 605 | 606 | .classifier:before { 607 | font-style: normal; 608 | margin: 0.5em; 609 | content: ":"; 610 | } 611 | 612 | abbr, acronym { 613 | border-bottom: dotted 1px; 614 | cursor: help; 615 | } 616 | 617 | /* -- code displays --------------------------------------------------------- */ 618 | 619 | pre { 620 | overflow: auto; 621 | overflow-y: hidden; /* fixes display issues on Chrome browsers */ 622 | } 623 | 624 | span.pre { 625 | -moz-hyphens: none; 626 | -ms-hyphens: none; 627 | -webkit-hyphens: none; 628 | hyphens: none; 629 | } 630 | 631 | td.linenos pre { 632 | padding: 5px 0px; 633 | border: 0; 634 | background-color: transparent; 635 | color: #aaa; 636 | } 637 | 638 | table.highlighttable { 639 | margin-left: 0.5em; 640 | } 641 | 642 | table.highlighttable td { 643 | padding: 0 0.5em 0 0.5em; 644 | } 645 | 646 | div.code-block-caption { 647 | padding: 2px 5px; 648 | font-size: small; 649 | } 650 | 651 | div.code-block-caption code { 652 | background-color: transparent; 653 | } 654 | 655 | div.code-block-caption + div > div.highlight > pre { 656 | margin-top: 0; 657 | } 658 | 659 | div.code-block-caption span.caption-number { 660 | padding: 0.1em 
0.3em; 661 | font-style: italic; 662 | } 663 | 664 | div.code-block-caption span.caption-text { 665 | } 666 | 667 | div.literal-block-wrapper { 668 | padding: 1em 1em 0; 669 | } 670 | 671 | div.literal-block-wrapper div.highlight { 672 | margin: 0; 673 | } 674 | 675 | code.descname { 676 | background-color: transparent; 677 | font-weight: bold; 678 | font-size: 1.2em; 679 | } 680 | 681 | code.descclassname { 682 | background-color: transparent; 683 | } 684 | 685 | code.xref, a code { 686 | background-color: transparent; 687 | font-weight: bold; 688 | } 689 | 690 | h1 code, h2 code, h3 code, h4 code, h5 code, h6 code { 691 | background-color: transparent; 692 | } 693 | 694 | .viewcode-link { 695 | float: right; 696 | } 697 | 698 | .viewcode-back { 699 | float: right; 700 | font-family: sans-serif; 701 | } 702 | 703 | div.viewcode-block:target { 704 | margin: -1px -10px; 705 | padding: 0 10px; 706 | } 707 | 708 | /* -- math display ---------------------------------------------------------- */ 709 | 710 | img.math { 711 | vertical-align: middle; 712 | } 713 | 714 | div.body div.math p { 715 | text-align: center; 716 | } 717 | 718 | span.eqno { 719 | float: right; 720 | } 721 | 722 | span.eqno a.headerlink { 723 | position: relative; 724 | left: 0px; 725 | z-index: 1; 726 | } 727 | 728 | div.math:hover a.headerlink { 729 | visibility: visible; 730 | } 731 | 732 | /* -- printout stylesheet --------------------------------------------------- */ 733 | 734 | @media print { 735 | div.document, 736 | div.documentwrapper, 737 | div.bodywrapper { 738 | margin: 0 !important; 739 | width: 100%; 740 | } 741 | 742 | div.sphinxsidebar, 743 | div.related, 744 | div.footer, 745 | #top-link { 746 | display: none; 747 | } 748 | } -------------------------------------------------------------------------------- /docs/_static/css/badge_only.css: -------------------------------------------------------------------------------- 1 | 
.fa:before{-webkit-font-smoothing:antialiased}.clearfix{*zoom:1}.clearfix:before,.clearfix:after{display:table;content:""}.clearfix:after{clear:both}@font-face{font-family:FontAwesome;font-weight:normal;font-style:normal;src:url("../fonts/fontawesome-webfont.eot");src:url("../fonts/fontawesome-webfont.eot?#iefix") format("embedded-opentype"),url("../fonts/fontawesome-webfont.woff") format("woff"),url("../fonts/fontawesome-webfont.ttf") format("truetype"),url("../fonts/fontawesome-webfont.svg#FontAwesome") format("svg")}.fa:before{display:inline-block;font-family:FontAwesome;font-style:normal;font-weight:normal;line-height:1;text-decoration:inherit}a .fa{display:inline-block;text-decoration:inherit}li .fa{display:inline-block}li .fa-large:before,li .fa-large:before{width:1.875em}ul.fas{list-style-type:none;margin-left:2em;text-indent:-0.8em}ul.fas li .fa{width:.8em}ul.fas li .fa-large:before,ul.fas li .fa-large:before{vertical-align:baseline}.fa-book:before{content:""}.icon-book:before{content:""}.fa-caret-down:before{content:""}.icon-caret-down:before{content:""}.fa-caret-up:before{content:""}.icon-caret-up:before{content:""}.fa-caret-left:before{content:""}.icon-caret-left:before{content:""}.fa-caret-right:before{content:""}.icon-caret-right:before{content:""}.rst-versions{position:fixed;bottom:0;left:0;width:300px;color:#fcfcfc;background:#1f1d1d;font-family:"Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;z-index:400}.rst-versions a{color:#2980B9;text-decoration:none}.rst-versions .rst-badge-small{display:none}.rst-versions .rst-current-version{padding:12px;background-color:#272525;display:block;text-align:right;font-size:90%;cursor:pointer;color:#27AE60;*zoom:1}.rst-versions .rst-current-version:before,.rst-versions .rst-current-version:after{display:table;content:""}.rst-versions .rst-current-version:after{clear:both}.rst-versions .rst-current-version .fa{color:#fcfcfc}.rst-versions .rst-current-version .fa-book{float:left}.rst-versions 
.rst-current-version .icon-book{float:left}.rst-versions .rst-current-version.rst-out-of-date{background-color:#E74C3C;color:#fff}.rst-versions .rst-current-version.rst-active-old-version{background-color:#F1C40F;color:#000}.rst-versions.shift-up{height:auto;max-height:100%;overflow-y:scroll}.rst-versions.shift-up .rst-other-versions{display:block}.rst-versions .rst-other-versions{font-size:90%;padding:12px;color:gray;display:none}.rst-versions .rst-other-versions hr{display:block;height:1px;border:0;margin:20px 0;padding:0;border-top:solid 1px #413d3d}.rst-versions .rst-other-versions dd{display:inline-block;margin:0}.rst-versions .rst-other-versions dd a{display:inline-block;padding:6px;color:#fcfcfc}.rst-versions.rst-badge{width:auto;bottom:20px;right:20px;left:auto;border:none;max-width:300px;max-height:90%}.rst-versions.rst-badge .icon-book{float:none}.rst-versions.rst-badge .fa-book{float:none}.rst-versions.rst-badge.shift-up .rst-current-version{text-align:right}.rst-versions.rst-badge.shift-up .rst-current-version .fa-book{float:left}.rst-versions.rst-badge.shift-up .rst-current-version .icon-book{float:left}.rst-versions.rst-badge .rst-current-version{width:auto;height:30px;line-height:30px;padding:0 6px;display:block;text-align:center}@media screen and (max-width: 768px){.rst-versions{width:85%;display:none}.rst-versions.shift{display:block}} 2 | -------------------------------------------------------------------------------- /docs/_static/doctools.js: -------------------------------------------------------------------------------- 1 | /* 2 | * doctools.js 3 | * ~~~~~~~~~~~ 4 | * 5 | * Sphinx JavaScript utilities for all documentation. 6 | * 7 | * :copyright: Copyright 2007-2019 by the Sphinx team, see AUTHORS. 8 | * :license: BSD, see LICENSE for details. 
9 | * 10 | */ 11 | 12 | /** 13 | * select a different prefix for underscore 14 | */ 15 | $u = _.noConflict(); 16 | 17 | /** 18 | * make the code below compatible with browsers without 19 | * an installed firebug like debugger 20 | if (!window.console || !console.firebug) { 21 | var names = ["log", "debug", "info", "warn", "error", "assert", "dir", 22 | "dirxml", "group", "groupEnd", "time", "timeEnd", "count", "trace", 23 | "profile", "profileEnd"]; 24 | window.console = {}; 25 | for (var i = 0; i < names.length; ++i) 26 | window.console[names[i]] = function() {}; 27 | } 28 | */ 29 | 30 | /** 31 | * small helper function to urldecode strings 32 | */ 33 | jQuery.urldecode = function(x) { 34 | return decodeURIComponent(x).replace(/\+/g, ' '); 35 | }; 36 | 37 | /** 38 | * small helper function to urlencode strings 39 | */ 40 | jQuery.urlencode = encodeURIComponent; 41 | 42 | /** 43 | * This function returns the parsed url parameters of the 44 | * current request. Multiple values per key are supported, 45 | * it will always return arrays of strings for the value parts. 46 | */ 47 | jQuery.getQueryParameters = function(s) { 48 | if (typeof s === 'undefined') 49 | s = document.location.search; 50 | var parts = s.substr(s.indexOf('?') + 1).split('&'); 51 | var result = {}; 52 | for (var i = 0; i < parts.length; i++) { 53 | var tmp = parts[i].split('=', 2); 54 | var key = jQuery.urldecode(tmp[0]); 55 | var value = jQuery.urldecode(tmp[1]); 56 | if (key in result) 57 | result[key].push(value); 58 | else 59 | result[key] = [value]; 60 | } 61 | return result; 62 | }; 63 | 64 | /** 65 | * highlight a given string on a jquery object by wrapping it in 66 | * span elements with the given class name. 
67 | */ 68 | jQuery.fn.highlightText = function(text, className) { 69 | function highlight(node, addItems) { 70 | if (node.nodeType === 3) { 71 | var val = node.nodeValue; 72 | var pos = val.toLowerCase().indexOf(text); 73 | if (pos >= 0 && 74 | !jQuery(node.parentNode).hasClass(className) && 75 | !jQuery(node.parentNode).hasClass("nohighlight")) { 76 | var span; 77 | var isInSVG = jQuery(node).closest("body, svg, foreignObject").is("svg"); 78 | if (isInSVG) { 79 | span = document.createElementNS("http://www.w3.org/2000/svg", "tspan"); 80 | } else { 81 | span = document.createElement("span"); 82 | span.className = className; 83 | } 84 | span.appendChild(document.createTextNode(val.substr(pos, text.length))); 85 | node.parentNode.insertBefore(span, node.parentNode.insertBefore( 86 | document.createTextNode(val.substr(pos + text.length)), 87 | node.nextSibling)); 88 | node.nodeValue = val.substr(0, pos); 89 | if (isInSVG) { 90 | var rect = document.createElementNS("http://www.w3.org/2000/svg", "rect"); 91 | var bbox = node.parentElement.getBBox(); 92 | rect.x.baseVal.value = bbox.x; 93 | rect.y.baseVal.value = bbox.y; 94 | rect.width.baseVal.value = bbox.width; 95 | rect.height.baseVal.value = bbox.height; 96 | rect.setAttribute('class', className); 97 | addItems.push({ 98 | "parent": node.parentNode, 99 | "target": rect}); 100 | } 101 | } 102 | } 103 | else if (!jQuery(node).is("button, select, textarea")) { 104 | jQuery.each(node.childNodes, function() { 105 | highlight(this, addItems); 106 | }); 107 | } 108 | } 109 | var addItems = []; 110 | var result = this.each(function() { 111 | highlight(this, addItems); 112 | }); 113 | for (var i = 0; i < addItems.length; ++i) { 114 | jQuery(addItems[i].parent).before(addItems[i].target); 115 | } 116 | return result; 117 | }; 118 | 119 | /* 120 | * backward compatibility for jQuery.browser 121 | * This will be supported until firefox bug is fixed. 
122 | */ 123 | if (!jQuery.browser) { 124 | jQuery.uaMatch = function(ua) { 125 | ua = ua.toLowerCase(); 126 | 127 | var match = /(chrome)[ \/]([\w.]+)/.exec(ua) || 128 | /(webkit)[ \/]([\w.]+)/.exec(ua) || 129 | /(opera)(?:.*version|)[ \/]([\w.]+)/.exec(ua) || 130 | /(msie) ([\w.]+)/.exec(ua) || 131 | ua.indexOf("compatible") < 0 && /(mozilla)(?:.*? rv:([\w.]+)|)/.exec(ua) || 132 | []; 133 | 134 | return { 135 | browser: match[ 1 ] || "", 136 | version: match[ 2 ] || "0" 137 | }; 138 | }; 139 | jQuery.browser = {}; 140 | jQuery.browser[jQuery.uaMatch(navigator.userAgent).browser] = true; 141 | } 142 | 143 | /** 144 | * Small JavaScript module for the documentation. 145 | */ 146 | var Documentation = { 147 | 148 | init : function() { 149 | this.fixFirefoxAnchorBug(); 150 | this.highlightSearchWords(); 151 | this.initIndexTable(); 152 | if (DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS) { 153 | this.initOnKeyListeners(); 154 | } 155 | }, 156 | 157 | /** 158 | * i18n support 159 | */ 160 | TRANSLATIONS : {}, 161 | PLURAL_EXPR : function(n) { return n === 1 ? 0 : 1; }, 162 | LOCALE : 'unknown', 163 | 164 | // gettext and ngettext don't access this so that the functions 165 | // can safely bound to a different name (_ = Documentation.gettext) 166 | gettext : function(string) { 167 | var translated = Documentation.TRANSLATIONS[string]; 168 | if (typeof translated === 'undefined') 169 | return string; 170 | return (typeof translated === 'string') ? translated : translated[0]; 171 | }, 172 | 173 | ngettext : function(singular, plural, n) { 174 | var translated = Documentation.TRANSLATIONS[singular]; 175 | if (typeof translated === 'undefined') 176 | return (n == 1) ? 
singular : plural; 177 | return translated[Documentation.PLURALEXPR(n)]; 178 | }, 179 | 180 | addTranslations : function(catalog) { 181 | for (var key in catalog.messages) 182 | this.TRANSLATIONS[key] = catalog.messages[key]; 183 | this.PLURAL_EXPR = new Function('n', 'return +(' + catalog.plural_expr + ')'); 184 | this.LOCALE = catalog.locale; 185 | }, 186 | 187 | /** 188 | * add context elements like header anchor links 189 | */ 190 | addContextElements : function() { 191 | $('div[id] > :header:first').each(function() { 192 | $('\u00B6'). 193 | attr('href', '#' + this.id). 194 | attr('title', _('Permalink to this headline')). 195 | appendTo(this); 196 | }); 197 | $('dt[id]').each(function() { 198 | $('\u00B6'). 199 | attr('href', '#' + this.id). 200 | attr('title', _('Permalink to this definition')). 201 | appendTo(this); 202 | }); 203 | }, 204 | 205 | /** 206 | * workaround a firefox stupidity 207 | * see: https://bugzilla.mozilla.org/show_bug.cgi?id=645075 208 | */ 209 | fixFirefoxAnchorBug : function() { 210 | if (document.location.hash && $.browser.mozilla) 211 | window.setTimeout(function() { 212 | document.location.href += ''; 213 | }, 10); 214 | }, 215 | 216 | /** 217 | * highlight the search words provided in the url in the text 218 | */ 219 | highlightSearchWords : function() { 220 | var params = $.getQueryParameters(); 221 | var terms = (params.highlight) ? 
params.highlight[0].split(/\s+/) : []; 222 | if (terms.length) { 223 | var body = $('div.body'); 224 | if (!body.length) { 225 | body = $('body'); 226 | } 227 | window.setTimeout(function() { 228 | $.each(terms, function() { 229 | body.highlightText(this.toLowerCase(), 'highlighted'); 230 | }); 231 | }, 10); 232 | $('') 234 | .appendTo($('#searchbox')); 235 | } 236 | }, 237 | 238 | /** 239 | * init the domain index toggle buttons 240 | */ 241 | initIndexTable : function() { 242 | var togglers = $('img.toggler').click(function() { 243 | var src = $(this).attr('src'); 244 | var idnum = $(this).attr('id').substr(7); 245 | $('tr.cg-' + idnum).toggle(); 246 | if (src.substr(-9) === 'minus.png') 247 | $(this).attr('src', src.substr(0, src.length-9) + 'plus.png'); 248 | else 249 | $(this).attr('src', src.substr(0, src.length-8) + 'minus.png'); 250 | }).css('display', ''); 251 | if (DOCUMENTATION_OPTIONS.COLLAPSE_INDEX) { 252 | togglers.click(); 253 | } 254 | }, 255 | 256 | /** 257 | * helper function to hide the search marks again 258 | */ 259 | hideSearchWords : function() { 260 | $('#searchbox .highlight-link').fadeOut(300); 261 | $('span.highlighted').removeClass('highlighted'); 262 | }, 263 | 264 | /** 265 | * make the url absolute 266 | */ 267 | makeURL : function(relativeURL) { 268 | return DOCUMENTATION_OPTIONS.URL_ROOT + '/' + relativeURL; 269 | }, 270 | 271 | /** 272 | * get the current relative url 273 | */ 274 | getCurrentURL : function() { 275 | var path = document.location.pathname; 276 | var parts = path.split(/\//); 277 | $.each(DOCUMENTATION_OPTIONS.URL_ROOT.split(/\//), function() { 278 | if (this === '..') 279 | parts.pop(); 280 | }); 281 | var url = parts.join('/'); 282 | return path.substring(url.lastIndexOf('/') + 1, path.length - 1); 283 | }, 284 | 285 | initOnKeyListeners: function() { 286 | $(document).keyup(function(event) { 287 | var activeElementType = document.activeElement.tagName; 288 | // don't navigate when in search box or textarea 289 | 
if (activeElementType !== 'TEXTAREA' && activeElementType !== 'INPUT' && activeElementType !== 'SELECT') { 290 | switch (event.keyCode) { 291 | case 37: // left 292 | var prevHref = $('link[rel="prev"]').prop('href'); 293 | if (prevHref) { 294 | window.location.href = prevHref; 295 | return false; 296 | } 297 | case 39: // right 298 | var nextHref = $('link[rel="next"]').prop('href'); 299 | if (nextHref) { 300 | window.location.href = nextHref; 301 | return false; 302 | } 303 | } 304 | } 305 | }); 306 | } 307 | }; 308 | 309 | // quick alias for translations 310 | _ = Documentation.gettext; 311 | 312 | $(document).ready(function() { 313 | Documentation.init(); 314 | }); 315 | -------------------------------------------------------------------------------- /docs/_static/documentation_options.js: -------------------------------------------------------------------------------- 1 | var DOCUMENTATION_OPTIONS = { 2 | URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'), 3 | VERSION: '0.1', 4 | LANGUAGE: 'None', 5 | COLLAPSE_INDEX: false, 6 | FILE_SUFFIX: '.html', 7 | HAS_SOURCE: true, 8 | SOURCELINK_SUFFIX: '.txt', 9 | NAVIGATION_WITH_KEYS: false 10 | }; -------------------------------------------------------------------------------- /docs/_static/file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/file.png -------------------------------------------------------------------------------- /docs/_static/fonts/Inconsolata-Bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/Inconsolata-Bold.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/Inconsolata-Regular.ttf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/Inconsolata-Regular.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/Inconsolata.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/Inconsolata.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/Lato-Bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/Lato-Bold.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/Lato-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/Lato-Regular.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bold.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/Lato/lato-bold.eot -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/Lato/lato-bold.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bold.woff: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/Lato/lato-bold.woff -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/Lato/lato-bold.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bolditalic.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/Lato/lato-bolditalic.eot -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bolditalic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/Lato/lato-bolditalic.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bolditalic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/Lato/lato-bolditalic.woff -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bolditalic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/Lato/lato-bolditalic.woff2 
-------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-italic.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/Lato/lato-italic.eot -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-italic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/Lato/lato-italic.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/Lato/lato-italic.woff -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/Lato/lato-italic.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-regular.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/Lato/lato-regular.eot -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-regular.ttf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/Lato/lato-regular.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/Lato/lato-regular.woff -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/Lato/lato-regular.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab-Bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/RobotoSlab-Bold.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/RobotoSlab-Regular.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.eot -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.ttf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.eot -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/fontawesome-webfont.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/fontawesome-webfont.eot -------------------------------------------------------------------------------- /docs/_static/fonts/fontawesome-webfont.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/fontawesome-webfont.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/fontawesome-webfont.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/fontawesome-webfont.woff -------------------------------------------------------------------------------- /docs/_static/fonts/fontawesome-webfont.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/fonts/fontawesome-webfont.woff2 -------------------------------------------------------------------------------- 
/docs/_static/js/modernizr.min.js: -------------------------------------------------------------------------------- 1 | /* Modernizr 2.6.2 (Custom Build) | MIT & BSD 2 | * Build: http://modernizr.com/download/#-fontface-backgroundsize-borderimage-borderradius-boxshadow-flexbox-hsla-multiplebgs-opacity-rgba-textshadow-cssanimations-csscolumns-generatedcontent-cssgradients-cssreflections-csstransforms-csstransforms3d-csstransitions-applicationcache-canvas-canvastext-draganddrop-hashchange-history-audio-video-indexeddb-input-inputtypes-localstorage-postmessage-sessionstorage-websockets-websqldatabase-webworkers-geolocation-inlinesvg-smil-svg-svgclippaths-touch-webgl-shiv-mq-cssclasses-addtest-prefixed-teststyles-testprop-testallprops-hasevent-prefixes-domprefixes-load 3 | */ 4 | ;window.Modernizr=function(a,b,c){function D(a){j.cssText=a}function E(a,b){return D(n.join(a+";")+(b||""))}function F(a,b){return typeof a===b}function G(a,b){return!!~(""+a).indexOf(b)}function H(a,b){for(var d in a){var e=a[d];if(!G(e,"-")&&j[e]!==c)return b=="pfx"?e:!0}return!1}function I(a,b,d){for(var e in a){var f=b[a[e]];if(f!==c)return d===!1?a[e]:F(f,"function")?f.bind(d||b):f}return!1}function J(a,b,c){var d=a.charAt(0).toUpperCase()+a.slice(1),e=(a+" "+p.join(d+" ")+d).split(" ");return F(b,"string")||F(b,"undefined")?H(e,b):(e=(a+" "+q.join(d+" ")+d).split(" "),I(e,b,c))}function K(){e.input=function(c){for(var d=0,e=c.length;d',a,""].join(""),l.id=h,(m?l:n).innerHTML+=f,n.appendChild(l),m||(n.style.background="",n.style.overflow="hidden",k=g.style.overflow,g.style.overflow="hidden",g.appendChild(n)),i=c(l,a),m?l.parentNode.removeChild(l):(n.parentNode.removeChild(n),g.style.overflow=k),!!i},z=function(b){var c=a.matchMedia||a.msMatchMedia;if(c)return c(b).matches;var d;return y("@media "+b+" { #"+h+" { position: absolute; } }",function(b){d=(a.getComputedStyle?getComputedStyle(b,null):b.currentStyle)["position"]=="absolute"}),d},A=function(){function 
d(d,e){e=e||b.createElement(a[d]||"div"),d="on"+d;var f=d in e;return f||(e.setAttribute||(e=b.createElement("div")),e.setAttribute&&e.removeAttribute&&(e.setAttribute(d,""),f=F(e[d],"function"),F(e[d],"undefined")||(e[d]=c),e.removeAttribute(d))),e=null,f}var a={select:"input",change:"input",submit:"form",reset:"form",error:"img",load:"img",abort:"img"};return d}(),B={}.hasOwnProperty,C;!F(B,"undefined")&&!F(B.call,"undefined")?C=function(a,b){return B.call(a,b)}:C=function(a,b){return b in a&&F(a.constructor.prototype[b],"undefined")},Function.prototype.bind||(Function.prototype.bind=function(b){var c=this;if(typeof c!="function")throw new TypeError;var d=w.call(arguments,1),e=function(){if(this instanceof e){var a=function(){};a.prototype=c.prototype;var f=new a,g=c.apply(f,d.concat(w.call(arguments)));return Object(g)===g?g:f}return c.apply(b,d.concat(w.call(arguments)))};return e}),s.flexbox=function(){return J("flexWrap")},s.canvas=function(){var a=b.createElement("canvas");return!!a.getContext&&!!a.getContext("2d")},s.canvastext=function(){return!!e.canvas&&!!F(b.createElement("canvas").getContext("2d").fillText,"function")},s.webgl=function(){return!!a.WebGLRenderingContext},s.touch=function(){var c;return"ontouchstart"in a||a.DocumentTouch&&b instanceof DocumentTouch?c=!0:y(["@media (",n.join("touch-enabled),("),h,")","{#modernizr{top:9px;position:absolute}}"].join(""),function(a){c=a.offsetTop===9}),c},s.geolocation=function(){return"geolocation"in navigator},s.postmessage=function(){return!!a.postMessage},s.websqldatabase=function(){return!!a.openDatabase},s.indexedDB=function(){return!!J("indexedDB",a)},s.hashchange=function(){return A("hashchange",a)&&(b.documentMode===c||b.documentMode>7)},s.history=function(){return!!a.history&&!!history.pushState},s.draganddrop=function(){var a=b.createElement("div");return"draggable"in a||"ondragstart"in a&&"ondrop"in a},s.websockets=function(){return"WebSocket"in a||"MozWebSocket"in a},s.rgba=function(){return 
D("background-color:rgba(150,255,150,.5)"),G(j.backgroundColor,"rgba")},s.hsla=function(){return D("background-color:hsla(120,40%,100%,.5)"),G(j.backgroundColor,"rgba")||G(j.backgroundColor,"hsla")},s.multiplebgs=function(){return D("background:url(https://),url(https://),red url(https://)"),/(url\s*\(.*?){3}/.test(j.background)},s.backgroundsize=function(){return J("backgroundSize")},s.borderimage=function(){return J("borderImage")},s.borderradius=function(){return J("borderRadius")},s.boxshadow=function(){return J("boxShadow")},s.textshadow=function(){return b.createElement("div").style.textShadow===""},s.opacity=function(){return E("opacity:.55"),/^0.55$/.test(j.opacity)},s.cssanimations=function(){return J("animationName")},s.csscolumns=function(){return J("columnCount")},s.cssgradients=function(){var a="background-image:",b="gradient(linear,left top,right bottom,from(#9f9),to(white));",c="linear-gradient(left top,#9f9, white);";return D((a+"-webkit- ".split(" ").join(b+a)+n.join(c+a)).slice(0,-a.length)),G(j.backgroundImage,"gradient")},s.cssreflections=function(){return J("boxReflect")},s.csstransforms=function(){return!!J("transform")},s.csstransforms3d=function(){var a=!!J("perspective");return a&&"webkitPerspective"in g.style&&y("@media (transform-3d),(-webkit-transform-3d){#modernizr{left:9px;position:absolute;height:3px;}}",function(b,c){a=b.offsetLeft===9&&b.offsetHeight===3}),a},s.csstransitions=function(){return J("transition")},s.fontface=function(){var a;return y('@font-face {font-family:"font";src:url("https://")}',function(c,d){var e=b.getElementById("smodernizr"),f=e.sheet||e.styleSheet,g=f?f.cssRules&&f.cssRules[0]?f.cssRules[0].cssText:f.cssText||"":"";a=/src/i.test(g)&&g.indexOf(d.split(" ")[0])===0}),a},s.generatedcontent=function(){var a;return y(["#",h,"{font:0/0 a}#",h,':after{content:"',l,'";visibility:hidden;font:3px/1 a}'].join(""),function(b){a=b.offsetHeight>=3}),a},s.video=function(){var 
a=b.createElement("video"),c=!1;try{if(c=!!a.canPlayType)c=new Boolean(c),c.ogg=a.canPlayType('video/ogg; codecs="theora"').replace(/^no$/,""),c.h264=a.canPlayType('video/mp4; codecs="avc1.42E01E"').replace(/^no$/,""),c.webm=a.canPlayType('video/webm; codecs="vp8, vorbis"').replace(/^no$/,"")}catch(d){}return c},s.audio=function(){var a=b.createElement("audio"),c=!1;try{if(c=!!a.canPlayType)c=new Boolean(c),c.ogg=a.canPlayType('audio/ogg; codecs="vorbis"').replace(/^no$/,""),c.mp3=a.canPlayType("audio/mpeg;").replace(/^no$/,""),c.wav=a.canPlayType('audio/wav; codecs="1"').replace(/^no$/,""),c.m4a=(a.canPlayType("audio/x-m4a;")||a.canPlayType("audio/aac;")).replace(/^no$/,"")}catch(d){}return c},s.localstorage=function(){try{return localStorage.setItem(h,h),localStorage.removeItem(h),!0}catch(a){return!1}},s.sessionstorage=function(){try{return sessionStorage.setItem(h,h),sessionStorage.removeItem(h),!0}catch(a){return!1}},s.webworkers=function(){return!!a.Worker},s.applicationcache=function(){return!!a.applicationCache},s.svg=function(){return!!b.createElementNS&&!!b.createElementNS(r.svg,"svg").createSVGRect},s.inlinesvg=function(){var a=b.createElement("div");return a.innerHTML="",(a.firstChild&&a.firstChild.namespaceURI)==r.svg},s.smil=function(){return!!b.createElementNS&&/SVGAnimate/.test(m.call(b.createElementNS(r.svg,"animate")))},s.svgclippaths=function(){return!!b.createElementNS&&/SVGClipPath/.test(m.call(b.createElementNS(r.svg,"clipPath")))};for(var L in s)C(s,L)&&(x=L.toLowerCase(),e[x]=s[L](),v.push((e[x]?"":"no-")+x));return e.input||K(),e.addTest=function(a,b){if(typeof a=="object")for(var d in a)C(a,d)&&e.addTest(d,a[d]);else{a=a.toLowerCase();if(e[a]!==c)return e;b=typeof b=="function"?b():b,typeof f!="undefined"&&f&&(g.className+=" "+(b?"":"no-")+a),e[a]=b}return e},D(""),i=k=null,function(a,b){function k(a,b){var c=a.createElement("p"),d=a.getElementsByTagName("head")[0]||a.documentElement;return 
c.innerHTML="x",d.insertBefore(c.lastChild,d.firstChild)}function l(){var a=r.elements;return typeof a=="string"?a.split(" "):a}function m(a){var b=i[a[g]];return b||(b={},h++,a[g]=h,i[h]=b),b}function n(a,c,f){c||(c=b);if(j)return c.createElement(a);f||(f=m(c));var g;return f.cache[a]?g=f.cache[a].cloneNode():e.test(a)?g=(f.cache[a]=f.createElem(a)).cloneNode():g=f.createElem(a),g.canHaveChildren&&!d.test(a)?f.frag.appendChild(g):g}function o(a,c){a||(a=b);if(j)return a.createDocumentFragment();c=c||m(a);var d=c.frag.cloneNode(),e=0,f=l(),g=f.length;for(;e",f="hidden"in a,j=a.childNodes.length==1||function(){b.createElement("a");var a=b.createDocumentFragment();return typeof a.cloneNode=="undefined"||typeof a.createDocumentFragment=="undefined"||typeof a.createElement=="undefined"}()}catch(c){f=!0,j=!0}})();var r={elements:c.elements||"abbr article aside audio bdi canvas data datalist details figcaption figure footer header hgroup mark meter nav output progress section summary time video",shivCSS:c.shivCSS!==!1,supportsUnknownElements:j,shivMethods:c.shivMethods!==!1,type:"default",shivDocument:q,createElement:n,createDocumentFragment:o};a.html5=r,q(b)}(this,b),e._version=d,e._prefixes=n,e._domPrefixes=q,e._cssomPrefixes=p,e.mq=z,e.hasEvent=A,e.testProp=function(a){return H([a])},e.testAllProps=J,e.testStyles=y,e.prefixed=function(a,b,c){return b?J(a,b,c):J(a,"pfx")},g.className=g.className.replace(/(^|\s)no-js(\s|$)/,"$1$2")+(f?" 
js "+v.join(" "):""),e}(this,this.document),function(a,b,c){function d(a){return"[object Function]"==o.call(a)}function e(a){return"string"==typeof a}function f(){}function g(a){return!a||"loaded"==a||"complete"==a||"uninitialized"==a}function h(){var a=p.shift();q=1,a?a.t?m(function(){("c"==a.t?B.injectCss:B.injectJs)(a.s,0,a.a,a.x,a.e,1)},0):(a(),h()):q=0}function i(a,c,d,e,f,i,j){function k(b){if(!o&&g(l.readyState)&&(u.r=o=1,!q&&h(),l.onload=l.onreadystatechange=null,b)){"img"!=a&&m(function(){t.removeChild(l)},50);for(var d in y[c])y[c].hasOwnProperty(d)&&y[c][d].onload()}}var j=j||B.errorTimeout,l=b.createElement(a),o=0,r=0,u={t:d,s:c,e:f,a:i,x:j};1===y[c]&&(r=1,y[c]=[]),"object"==a?l.data=c:(l.src=c,l.type=a),l.width=l.height="0",l.onerror=l.onload=l.onreadystatechange=function(){k.call(this,r)},p.splice(e,0,u),"img"!=a&&(r||2===y[c]?(t.insertBefore(l,s?null:n),m(k,j)):y[c].push(l))}function j(a,b,c,d,f){return q=0,b=b||"j",e(a)?i("c"==b?v:u,a,b,this.i++,c,d,f):(p.splice(this.i++,0,a),1==p.length&&h()),this}function k(){var a=B;return a.loader={load:j,i:0},a}var l=b.documentElement,m=a.setTimeout,n=b.getElementsByTagName("script")[0],o={}.toString,p=[],q=0,r="MozAppearance"in l.style,s=r&&!!b.createRange().compareNode,t=s?l:n.parentNode,l=a.opera&&"[object Opera]"==o.call(a.opera),l=!!b.attachEvent&&!l,u=r?"object":l?"script":"img",v=l?"script":u,w=Array.isArray||function(a){return"[object Array]"==o.call(a)},x=[],y={},z={timeout:function(a,b){return b.length&&(a.timeout=b[0]),a}},A,B;B=function(a){function b(a){var a=a.split("!"),b=x.length,c=a.pop(),d=a.length,c={url:c,origUrl:c,prefixes:a},e,f,g;for(f=0;f"),i("table.docutils.footnote").wrap("
"),i("table.docutils.citation").wrap("
"),i(".wy-menu-vertical ul").not(".simple").siblings("a").each(function(){var e=i(this);expand=i(''),expand.on("click",function(n){return t.toggleCurrent(e),n.stopPropagation(),!1}),e.prepend(expand)})},reset:function(){var n=encodeURI(window.location.hash)||"#";try{var e=$(".wy-menu-vertical"),i=e.find('[href="'+n+'"]');if(0===i.length){var t=$('.document [id="'+n.substring(1)+'"]').closest("div.section");0===(i=e.find('[href="#'+t.attr("id")+'"]')).length&&(i=e.find('[href="#"]'))}0this.docHeight||(this.navBar.scrollTop(i),this.winPosition=n)},onResize:function(){this.winResize=!1,this.winHeight=this.win.height(),this.docHeight=$(document).height()},hashChange:function(){this.linkScroll=!0,this.win.one("hashchange",function(){this.linkScroll=!1})},toggleCurrent:function(n){var e=n.closest("li");e.siblings("li.current").removeClass("current"),e.siblings().find("li.current").removeClass("current"),e.find("> ul li.current").removeClass("current"),e.toggleClass("current")}},"undefined"!=typeof window&&(window.SphinxRtdTheme={Navigation:e.exports.ThemeNav,StickyNav:e.exports.ThemeNav}),function(){for(var r=0,n=["ms","moz","webkit","o"],e=0;e0 62 | var meq1 = "^(" + C + ")?" + V + C + "(" + V + ")?$"; // [C]VC[V] is m=1 63 | var mgr1 = "^(" + C + ")?" + V + C + V + C; // [C]VCVC... is m>1 64 | var s_v = "^(" + C + ")?" 
+ v; // vowel in stem 65 | 66 | this.stemWord = function (w) { 67 | var stem; 68 | var suffix; 69 | var firstch; 70 | var origword = w; 71 | 72 | if (w.length < 3) 73 | return w; 74 | 75 | var re; 76 | var re2; 77 | var re3; 78 | var re4; 79 | 80 | firstch = w.substr(0,1); 81 | if (firstch == "y") 82 | w = firstch.toUpperCase() + w.substr(1); 83 | 84 | // Step 1a 85 | re = /^(.+?)(ss|i)es$/; 86 | re2 = /^(.+?)([^s])s$/; 87 | 88 | if (re.test(w)) 89 | w = w.replace(re,"$1$2"); 90 | else if (re2.test(w)) 91 | w = w.replace(re2,"$1$2"); 92 | 93 | // Step 1b 94 | re = /^(.+?)eed$/; 95 | re2 = /^(.+?)(ed|ing)$/; 96 | if (re.test(w)) { 97 | var fp = re.exec(w); 98 | re = new RegExp(mgr0); 99 | if (re.test(fp[1])) { 100 | re = /.$/; 101 | w = w.replace(re,""); 102 | } 103 | } 104 | else if (re2.test(w)) { 105 | var fp = re2.exec(w); 106 | stem = fp[1]; 107 | re2 = new RegExp(s_v); 108 | if (re2.test(stem)) { 109 | w = stem; 110 | re2 = /(at|bl|iz)$/; 111 | re3 = new RegExp("([^aeiouylsz])\\1$"); 112 | re4 = new RegExp("^" + C + v + "[^aeiouwxy]$"); 113 | if (re2.test(w)) 114 | w = w + "e"; 115 | else if (re3.test(w)) { 116 | re = /.$/; 117 | w = w.replace(re,""); 118 | } 119 | else if (re4.test(w)) 120 | w = w + "e"; 121 | } 122 | } 123 | 124 | // Step 1c 125 | re = /^(.+?)y$/; 126 | if (re.test(w)) { 127 | var fp = re.exec(w); 128 | stem = fp[1]; 129 | re = new RegExp(s_v); 130 | if (re.test(stem)) 131 | w = stem + "i"; 132 | } 133 | 134 | // Step 2 135 | re = /^(.+?)(ational|tional|enci|anci|izer|bli|alli|entli|eli|ousli|ization|ation|ator|alism|iveness|fulness|ousness|aliti|iviti|biliti|logi)$/; 136 | if (re.test(w)) { 137 | var fp = re.exec(w); 138 | stem = fp[1]; 139 | suffix = fp[2]; 140 | re = new RegExp(mgr0); 141 | if (re.test(stem)) 142 | w = stem + step2list[suffix]; 143 | } 144 | 145 | // Step 3 146 | re = /^(.+?)(icate|ative|alize|iciti|ical|ful|ness)$/; 147 | if (re.test(w)) { 148 | var fp = re.exec(w); 149 | stem = fp[1]; 150 | suffix = fp[2]; 151 | re = 
new RegExp(mgr0); 152 | if (re.test(stem)) 153 | w = stem + step3list[suffix]; 154 | } 155 | 156 | // Step 4 157 | re = /^(.+?)(al|ance|ence|er|ic|able|ible|ant|ement|ment|ent|ou|ism|ate|iti|ous|ive|ize)$/; 158 | re2 = /^(.+?)(s|t)(ion)$/; 159 | if (re.test(w)) { 160 | var fp = re.exec(w); 161 | stem = fp[1]; 162 | re = new RegExp(mgr1); 163 | if (re.test(stem)) 164 | w = stem; 165 | } 166 | else if (re2.test(w)) { 167 | var fp = re2.exec(w); 168 | stem = fp[1] + fp[2]; 169 | re2 = new RegExp(mgr1); 170 | if (re2.test(stem)) 171 | w = stem; 172 | } 173 | 174 | // Step 5 175 | re = /^(.+?)e$/; 176 | if (re.test(w)) { 177 | var fp = re.exec(w); 178 | stem = fp[1]; 179 | re = new RegExp(mgr1); 180 | re2 = new RegExp(meq1); 181 | re3 = new RegExp("^" + C + v + "[^aeiouwxy]$"); 182 | if (re.test(stem) || (re2.test(stem) && !(re3.test(stem)))) 183 | w = stem; 184 | } 185 | re = /ll$/; 186 | re2 = new RegExp(mgr1); 187 | if (re.test(w) && re2.test(w)) { 188 | re = /.$/; 189 | w = w.replace(re,""); 190 | } 191 | 192 | // and turn initial Y back to y 193 | if (firstch == "y") 194 | w = firstch.toLowerCase() + w.substr(1); 195 | return w; 196 | } 197 | } 198 | 199 | 200 | 201 | 202 | 203 | var splitChars = (function() { 204 | var result = {}; 205 | var singles = [96, 180, 187, 191, 215, 247, 749, 885, 903, 907, 909, 930, 1014, 1648, 206 | 1748, 1809, 2416, 2473, 2481, 2526, 2601, 2609, 2612, 2615, 2653, 2702, 207 | 2706, 2729, 2737, 2740, 2857, 2865, 2868, 2910, 2928, 2948, 2961, 2971, 208 | 2973, 3085, 3089, 3113, 3124, 3213, 3217, 3241, 3252, 3295, 3341, 3345, 209 | 3369, 3506, 3516, 3633, 3715, 3721, 3736, 3744, 3748, 3750, 3756, 3761, 210 | 3781, 3912, 4239, 4347, 4681, 4695, 4697, 4745, 4785, 4799, 4801, 4823, 211 | 4881, 5760, 5901, 5997, 6313, 7405, 8024, 8026, 8028, 8030, 8117, 8125, 212 | 8133, 8181, 8468, 8485, 8487, 8489, 8494, 8527, 11311, 11359, 11687, 11695, 213 | 11703, 11711, 11719, 11727, 11735, 12448, 12539, 43010, 43014, 43019, 43587, 214 | 43696, 43713, 
64286, 64297, 64311, 64317, 64319, 64322, 64325, 65141]; 215 | var i, j, start, end; 216 | for (i = 0; i < singles.length; i++) { 217 | result[singles[i]] = true; 218 | } 219 | var ranges = [[0, 47], [58, 64], [91, 94], [123, 169], [171, 177], [182, 184], [706, 709], 220 | [722, 735], [741, 747], [751, 879], [888, 889], [894, 901], [1154, 1161], 221 | [1318, 1328], [1367, 1368], [1370, 1376], [1416, 1487], [1515, 1519], [1523, 1568], 222 | [1611, 1631], [1642, 1645], [1750, 1764], [1767, 1773], [1789, 1790], [1792, 1807], 223 | [1840, 1868], [1958, 1968], [1970, 1983], [2027, 2035], [2038, 2041], [2043, 2047], 224 | [2070, 2073], [2075, 2083], [2085, 2087], [2089, 2307], [2362, 2364], [2366, 2383], 225 | [2385, 2391], [2402, 2405], [2419, 2424], [2432, 2436], [2445, 2446], [2449, 2450], 226 | [2483, 2485], [2490, 2492], [2494, 2509], [2511, 2523], [2530, 2533], [2546, 2547], 227 | [2554, 2564], [2571, 2574], [2577, 2578], [2618, 2648], [2655, 2661], [2672, 2673], 228 | [2677, 2692], [2746, 2748], [2750, 2767], [2769, 2783], [2786, 2789], [2800, 2820], 229 | [2829, 2830], [2833, 2834], [2874, 2876], [2878, 2907], [2914, 2917], [2930, 2946], 230 | [2955, 2957], [2966, 2968], [2976, 2978], [2981, 2983], [2987, 2989], [3002, 3023], 231 | [3025, 3045], [3059, 3076], [3130, 3132], [3134, 3159], [3162, 3167], [3170, 3173], 232 | [3184, 3191], [3199, 3204], [3258, 3260], [3262, 3293], [3298, 3301], [3312, 3332], 233 | [3386, 3388], [3390, 3423], [3426, 3429], [3446, 3449], [3456, 3460], [3479, 3481], 234 | [3518, 3519], [3527, 3584], [3636, 3647], [3655, 3663], [3674, 3712], [3717, 3718], 235 | [3723, 3724], [3726, 3731], [3752, 3753], [3764, 3772], [3774, 3775], [3783, 3791], 236 | [3802, 3803], [3806, 3839], [3841, 3871], [3892, 3903], [3949, 3975], [3980, 4095], 237 | [4139, 4158], [4170, 4175], [4182, 4185], [4190, 4192], [4194, 4196], [4199, 4205], 238 | [4209, 4212], [4226, 4237], [4250, 4255], [4294, 4303], [4349, 4351], [4686, 4687], 239 | [4702, 4703], [4750, 
4751], [4790, 4791], [4806, 4807], [4886, 4887], [4955, 4968], 240 | [4989, 4991], [5008, 5023], [5109, 5120], [5741, 5742], [5787, 5791], [5867, 5869], 241 | [5873, 5887], [5906, 5919], [5938, 5951], [5970, 5983], [6001, 6015], [6068, 6102], 242 | [6104, 6107], [6109, 6111], [6122, 6127], [6138, 6159], [6170, 6175], [6264, 6271], 243 | [6315, 6319], [6390, 6399], [6429, 6469], [6510, 6511], [6517, 6527], [6572, 6592], 244 | [6600, 6607], [6619, 6655], [6679, 6687], [6741, 6783], [6794, 6799], [6810, 6822], 245 | [6824, 6916], [6964, 6980], [6988, 6991], [7002, 7042], [7073, 7085], [7098, 7167], 246 | [7204, 7231], [7242, 7244], [7294, 7400], [7410, 7423], [7616, 7679], [7958, 7959], 247 | [7966, 7967], [8006, 8007], [8014, 8015], [8062, 8063], [8127, 8129], [8141, 8143], 248 | [8148, 8149], [8156, 8159], [8173, 8177], [8189, 8303], [8306, 8307], [8314, 8318], 249 | [8330, 8335], [8341, 8449], [8451, 8454], [8456, 8457], [8470, 8472], [8478, 8483], 250 | [8506, 8507], [8512, 8516], [8522, 8525], [8586, 9311], [9372, 9449], [9472, 10101], 251 | [10132, 11263], [11493, 11498], [11503, 11516], [11518, 11519], [11558, 11567], 252 | [11622, 11630], [11632, 11647], [11671, 11679], [11743, 11822], [11824, 12292], 253 | [12296, 12320], [12330, 12336], [12342, 12343], [12349, 12352], [12439, 12444], 254 | [12544, 12548], [12590, 12592], [12687, 12689], [12694, 12703], [12728, 12783], 255 | [12800, 12831], [12842, 12880], [12896, 12927], [12938, 12976], [12992, 13311], 256 | [19894, 19967], [40908, 40959], [42125, 42191], [42238, 42239], [42509, 42511], 257 | [42540, 42559], [42592, 42593], [42607, 42622], [42648, 42655], [42736, 42774], 258 | [42784, 42785], [42889, 42890], [42893, 43002], [43043, 43055], [43062, 43071], 259 | [43124, 43137], [43188, 43215], [43226, 43249], [43256, 43258], [43260, 43263], 260 | [43302, 43311], [43335, 43359], [43389, 43395], [43443, 43470], [43482, 43519], 261 | [43561, 43583], [43596, 43599], [43610, 43615], [43639, 43641], [43643, 43647], 
262 | [43698, 43700], [43703, 43704], [43710, 43711], [43715, 43738], [43742, 43967], 263 | [44003, 44015], [44026, 44031], [55204, 55215], [55239, 55242], [55292, 55295], 264 | [57344, 63743], [64046, 64047], [64110, 64111], [64218, 64255], [64263, 64274], 265 | [64280, 64284], [64434, 64466], [64830, 64847], [64912, 64913], [64968, 65007], 266 | [65020, 65135], [65277, 65295], [65306, 65312], [65339, 65344], [65371, 65381], 267 | [65471, 65473], [65480, 65481], [65488, 65489], [65496, 65497]]; 268 | for (i = 0; i < ranges.length; i++) { 269 | start = ranges[i][0]; 270 | end = ranges[i][1]; 271 | for (j = start; j <= end; j++) { 272 | result[j] = true; 273 | } 274 | } 275 | return result; 276 | })(); 277 | 278 | function splitQuery(query) { 279 | var result = []; 280 | var start = -1; 281 | for (var i = 0; i < query.length; i++) { 282 | if (splitChars[query.charCodeAt(i)]) { 283 | if (start !== -1) { 284 | result.push(query.slice(start, i)); 285 | start = -1; 286 | } 287 | } else if (start === -1) { 288 | start = i; 289 | } 290 | } 291 | if (start !== -1) { 292 | result.push(query.slice(start)); 293 | } 294 | return result; 295 | } 296 | 297 | 298 | -------------------------------------------------------------------------------- /docs/_static/minus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/minus.png -------------------------------------------------------------------------------- /docs/_static/plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/_static/plus.png -------------------------------------------------------------------------------- /docs/_static/pygments.css: -------------------------------------------------------------------------------- 1 | .highlight .hll { 
background-color: #ffffcc } 2 | .highlight { background: #f8f8f8; } 3 | .highlight .c { color: #408080; font-style: italic } /* Comment */ 4 | .highlight .err { border: 1px solid #FF0000 } /* Error */ 5 | .highlight .k { color: #008000; font-weight: bold } /* Keyword */ 6 | .highlight .o { color: #666666 } /* Operator */ 7 | .highlight .ch { color: #408080; font-style: italic } /* Comment.Hashbang */ 8 | .highlight .cm { color: #408080; font-style: italic } /* Comment.Multiline */ 9 | .highlight .cp { color: #BC7A00 } /* Comment.Preproc */ 10 | .highlight .cpf { color: #408080; font-style: italic } /* Comment.PreprocFile */ 11 | .highlight .c1 { color: #408080; font-style: italic } /* Comment.Single */ 12 | .highlight .cs { color: #408080; font-style: italic } /* Comment.Special */ 13 | .highlight .gd { color: #A00000 } /* Generic.Deleted */ 14 | .highlight .ge { font-style: italic } /* Generic.Emph */ 15 | .highlight .gr { color: #FF0000 } /* Generic.Error */ 16 | .highlight .gh { color: #000080; font-weight: bold } /* Generic.Heading */ 17 | .highlight .gi { color: #00A000 } /* Generic.Inserted */ 18 | .highlight .go { color: #888888 } /* Generic.Output */ 19 | .highlight .gp { color: #000080; font-weight: bold } /* Generic.Prompt */ 20 | .highlight .gs { font-weight: bold } /* Generic.Strong */ 21 | .highlight .gu { color: #800080; font-weight: bold } /* Generic.Subheading */ 22 | .highlight .gt { color: #0044DD } /* Generic.Traceback */ 23 | .highlight .kc { color: #008000; font-weight: bold } /* Keyword.Constant */ 24 | .highlight .kd { color: #008000; font-weight: bold } /* Keyword.Declaration */ 25 | .highlight .kn { color: #008000; font-weight: bold } /* Keyword.Namespace */ 26 | .highlight .kp { color: #008000 } /* Keyword.Pseudo */ 27 | .highlight .kr { color: #008000; font-weight: bold } /* Keyword.Reserved */ 28 | .highlight .kt { color: #B00040 } /* Keyword.Type */ 29 | .highlight .m { color: #666666 } /* Literal.Number */ 30 | .highlight .s { color: 
#BA2121 } /* Literal.String */ 31 | .highlight .na { color: #7D9029 } /* Name.Attribute */ 32 | .highlight .nb { color: #008000 } /* Name.Builtin */ 33 | .highlight .nc { color: #0000FF; font-weight: bold } /* Name.Class */ 34 | .highlight .no { color: #880000 } /* Name.Constant */ 35 | .highlight .nd { color: #AA22FF } /* Name.Decorator */ 36 | .highlight .ni { color: #999999; font-weight: bold } /* Name.Entity */ 37 | .highlight .ne { color: #D2413A; font-weight: bold } /* Name.Exception */ 38 | .highlight .nf { color: #0000FF } /* Name.Function */ 39 | .highlight .nl { color: #A0A000 } /* Name.Label */ 40 | .highlight .nn { color: #0000FF; font-weight: bold } /* Name.Namespace */ 41 | .highlight .nt { color: #008000; font-weight: bold } /* Name.Tag */ 42 | .highlight .nv { color: #19177C } /* Name.Variable */ 43 | .highlight .ow { color: #AA22FF; font-weight: bold } /* Operator.Word */ 44 | .highlight .w { color: #bbbbbb } /* Text.Whitespace */ 45 | .highlight .mb { color: #666666 } /* Literal.Number.Bin */ 46 | .highlight .mf { color: #666666 } /* Literal.Number.Float */ 47 | .highlight .mh { color: #666666 } /* Literal.Number.Hex */ 48 | .highlight .mi { color: #666666 } /* Literal.Number.Integer */ 49 | .highlight .mo { color: #666666 } /* Literal.Number.Oct */ 50 | .highlight .sa { color: #BA2121 } /* Literal.String.Affix */ 51 | .highlight .sb { color: #BA2121 } /* Literal.String.Backtick */ 52 | .highlight .sc { color: #BA2121 } /* Literal.String.Char */ 53 | .highlight .dl { color: #BA2121 } /* Literal.String.Delimiter */ 54 | .highlight .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */ 55 | .highlight .s2 { color: #BA2121 } /* Literal.String.Double */ 56 | .highlight .se { color: #BB6622; font-weight: bold } /* Literal.String.Escape */ 57 | .highlight .sh { color: #BA2121 } /* Literal.String.Heredoc */ 58 | .highlight .si { color: #BB6688; font-weight: bold } /* Literal.String.Interpol */ 59 | .highlight .sx { color: #008000 } /* 
Literal.String.Other */ 60 | .highlight .sr { color: #BB6688 } /* Literal.String.Regex */ 61 | .highlight .s1 { color: #BA2121 } /* Literal.String.Single */ 62 | .highlight .ss { color: #19177C } /* Literal.String.Symbol */ 63 | .highlight .bp { color: #008000 } /* Name.Builtin.Pseudo */ 64 | .highlight .fm { color: #0000FF } /* Name.Function.Magic */ 65 | .highlight .vc { color: #19177C } /* Name.Variable.Class */ 66 | .highlight .vg { color: #19177C } /* Name.Variable.Global */ 67 | .highlight .vi { color: #19177C } /* Name.Variable.Instance */ 68 | .highlight .vm { color: #19177C } /* Name.Variable.Magic */ 69 | .highlight .il { color: #666666 } /* Literal.Number.Integer.Long */ -------------------------------------------------------------------------------- /docs/_static/underscore.js: -------------------------------------------------------------------------------- 1 | // Underscore.js 1.3.1 2 | // (c) 2009-2012 Jeremy Ashkenas, DocumentCloud Inc. 3 | // Underscore is freely distributable under the MIT license. 4 | // Portions of Underscore are inspired or borrowed from Prototype, 5 | // Oliver Steele's Functional, and John Resig's Micro-Templating. 
6 | // For all details and documentation: 7 | // http://documentcloud.github.com/underscore 8 | (function(){function q(a,c,d){if(a===c)return a!==0||1/a==1/c;if(a==null||c==null)return a===c;if(a._chain)a=a._wrapped;if(c._chain)c=c._wrapped;if(a.isEqual&&b.isFunction(a.isEqual))return a.isEqual(c);if(c.isEqual&&b.isFunction(c.isEqual))return c.isEqual(a);var e=l.call(a);if(e!=l.call(c))return false;switch(e){case "[object String]":return a==String(c);case "[object Number]":return a!=+a?c!=+c:a==0?1/a==1/c:a==+c;case "[object Date]":case "[object Boolean]":return+a==+c;case "[object RegExp]":return a.source== 9 | c.source&&a.global==c.global&&a.multiline==c.multiline&&a.ignoreCase==c.ignoreCase}if(typeof a!="object"||typeof c!="object")return false;for(var f=d.length;f--;)if(d[f]==a)return true;d.push(a);var f=0,g=true;if(e=="[object Array]"){if(f=a.length,g=f==c.length)for(;f--;)if(!(g=f in a==f in c&&q(a[f],c[f],d)))break}else{if("constructor"in a!="constructor"in c||a.constructor!=c.constructor)return false;for(var h in a)if(b.has(a,h)&&(f++,!(g=b.has(c,h)&&q(a[h],c[h],d))))break;if(g){for(h in c)if(b.has(c, 10 | h)&&!f--)break;g=!f}}d.pop();return g}var r=this,G=r._,n={},k=Array.prototype,o=Object.prototype,i=k.slice,H=k.unshift,l=o.toString,I=o.hasOwnProperty,w=k.forEach,x=k.map,y=k.reduce,z=k.reduceRight,A=k.filter,B=k.every,C=k.some,p=k.indexOf,D=k.lastIndexOf,o=Array.isArray,J=Object.keys,s=Function.prototype.bind,b=function(a){return new m(a)};if(typeof exports!=="undefined"){if(typeof module!=="undefined"&&module.exports)exports=module.exports=b;exports._=b}else r._=b;b.VERSION="1.3.1";var j=b.each= 11 | b.forEach=function(a,c,d){if(a!=null)if(w&&a.forEach===w)a.forEach(c,d);else if(a.length===+a.length)for(var e=0,f=a.length;e2;a== 12 | null&&(a=[]);if(y&&a.reduce===y)return e&&(c=b.bind(c,e)),f?a.reduce(c,d):a.reduce(c);j(a,function(a,b,i){f?d=c.call(e,d,a,b,i):(d=a,f=true)});if(!f)throw new TypeError("Reduce of empty array with no initial value");return 
d};b.reduceRight=b.foldr=function(a,c,d,e){var f=arguments.length>2;a==null&&(a=[]);if(z&&a.reduceRight===z)return e&&(c=b.bind(c,e)),f?a.reduceRight(c,d):a.reduceRight(c);var g=b.toArray(a).reverse();e&&!f&&(c=b.bind(c,e));return f?b.reduce(g,c,d,e):b.reduce(g,c)};b.find=b.detect= 13 | function(a,c,b){var e;E(a,function(a,g,h){if(c.call(b,a,g,h))return e=a,true});return e};b.filter=b.select=function(a,c,b){var e=[];if(a==null)return e;if(A&&a.filter===A)return a.filter(c,b);j(a,function(a,g,h){c.call(b,a,g,h)&&(e[e.length]=a)});return e};b.reject=function(a,c,b){var e=[];if(a==null)return e;j(a,function(a,g,h){c.call(b,a,g,h)||(e[e.length]=a)});return e};b.every=b.all=function(a,c,b){var e=true;if(a==null)return e;if(B&&a.every===B)return a.every(c,b);j(a,function(a,g,h){if(!(e= 14 | e&&c.call(b,a,g,h)))return n});return e};var E=b.some=b.any=function(a,c,d){c||(c=b.identity);var e=false;if(a==null)return e;if(C&&a.some===C)return a.some(c,d);j(a,function(a,b,h){if(e||(e=c.call(d,a,b,h)))return n});return!!e};b.include=b.contains=function(a,c){var b=false;if(a==null)return b;return p&&a.indexOf===p?a.indexOf(c)!=-1:b=E(a,function(a){return a===c})};b.invoke=function(a,c){var d=i.call(arguments,2);return b.map(a,function(a){return(b.isFunction(c)?c||a:a[c]).apply(a,d)})};b.pluck= 15 | function(a,c){return b.map(a,function(a){return a[c]})};b.max=function(a,c,d){if(!c&&b.isArray(a))return Math.max.apply(Math,a);if(!c&&b.isEmpty(a))return-Infinity;var e={computed:-Infinity};j(a,function(a,b,h){b=c?c.call(d,a,b,h):a;b>=e.computed&&(e={value:a,computed:b})});return e.value};b.min=function(a,c,d){if(!c&&b.isArray(a))return Math.min.apply(Math,a);if(!c&&b.isEmpty(a))return Infinity;var e={computed:Infinity};j(a,function(a,b,h){b=c?c.call(d,a,b,h):a;bd?1:0}),"value")};b.groupBy=function(a,c){var d={},e=b.isFunction(c)?c:function(a){return a[c]};j(a,function(a,b){var c=e(a,b);(d[c]||(d[c]=[])).push(a)});return d};b.sortedIndex=function(a, 17 | 
c,d){d||(d=b.identity);for(var e=0,f=a.length;e>1;d(a[g])=0})})};b.difference=function(a){var c=b.flatten(i.call(arguments,1));return b.filter(a,function(a){return!b.include(c,a)})};b.zip=function(){for(var a=i.call(arguments),c=b.max(b.pluck(a,"length")),d=Array(c),e=0;e=0;d--)b=[a[d].apply(this,b)];return b[0]}}; 24 | b.after=function(a,b){return a<=0?b():function(){if(--a<1)return b.apply(this,arguments)}};b.keys=J||function(a){if(a!==Object(a))throw new TypeError("Invalid object");var c=[],d;for(d in a)b.has(a,d)&&(c[c.length]=d);return c};b.values=function(a){return b.map(a,b.identity)};b.functions=b.methods=function(a){var c=[],d;for(d in a)b.isFunction(a[d])&&c.push(d);return c.sort()};b.extend=function(a){j(i.call(arguments,1),function(b){for(var d in b)a[d]=b[d]});return a};b.defaults=function(a){j(i.call(arguments, 25 | 1),function(b){for(var d in b)a[d]==null&&(a[d]=b[d])});return a};b.clone=function(a){return!b.isObject(a)?a:b.isArray(a)?a.slice():b.extend({},a)};b.tap=function(a,b){b(a);return a};b.isEqual=function(a,b){return q(a,b,[])};b.isEmpty=function(a){if(b.isArray(a)||b.isString(a))return a.length===0;for(var c in a)if(b.has(a,c))return false;return true};b.isElement=function(a){return!!(a&&a.nodeType==1)};b.isArray=o||function(a){return l.call(a)=="[object Array]"};b.isObject=function(a){return a===Object(a)}; 26 | b.isArguments=function(a){return l.call(a)=="[object Arguments]"};if(!b.isArguments(arguments))b.isArguments=function(a){return!(!a||!b.has(a,"callee"))};b.isFunction=function(a){return l.call(a)=="[object Function]"};b.isString=function(a){return l.call(a)=="[object String]"};b.isNumber=function(a){return l.call(a)=="[object Number]"};b.isNaN=function(a){return a!==a};b.isBoolean=function(a){return a===true||a===false||l.call(a)=="[object Boolean]"};b.isDate=function(a){return l.call(a)=="[object Date]"}; 27 | b.isRegExp=function(a){return l.call(a)=="[object RegExp]"};b.isNull=function(a){return 
a===null};b.isUndefined=function(a){return a===void 0};b.has=function(a,b){return I.call(a,b)};b.noConflict=function(){r._=G;return this};b.identity=function(a){return a};b.times=function(a,b,d){for(var e=0;e/g,">").replace(/"/g,""").replace(/'/g,"'").replace(/\//g,"/")};b.mixin=function(a){j(b.functions(a), 28 | function(c){K(c,b[c]=a[c])})};var L=0;b.uniqueId=function(a){var b=L++;return a?a+b:b};b.templateSettings={evaluate:/<%([\s\S]+?)%>/g,interpolate:/<%=([\s\S]+?)%>/g,escape:/<%-([\s\S]+?)%>/g};var t=/.^/,u=function(a){return a.replace(/\\\\/g,"\\").replace(/\\'/g,"'")};b.template=function(a,c){var d=b.templateSettings,d="var __p=[],print=function(){__p.push.apply(__p,arguments);};with(obj||{}){__p.push('"+a.replace(/\\/g,"\\\\").replace(/'/g,"\\'").replace(d.escape||t,function(a,b){return"',_.escape("+ 29 | u(b)+"),'"}).replace(d.interpolate||t,function(a,b){return"',"+u(b)+",'"}).replace(d.evaluate||t,function(a,b){return"');"+u(b).replace(/[\r\n\t]/g," ")+";__p.push('"}).replace(/\r/g,"\\r").replace(/\n/g,"\\n").replace(/\t/g,"\\t")+"');}return __p.join('');",e=new Function("obj","_",d);return c?e(c,b):function(a){return e.call(this,a,b)}};b.chain=function(a){return b(a).chain()};var m=function(a){this._wrapped=a};b.prototype=m.prototype;var v=function(a,c){return c?b(a).chain():a},K=function(a,c){m.prototype[a]= 30 | function(){var a=i.call(arguments);H.call(a,this._wrapped);return v(c.apply(b,a),this._chain)}};b.mixin(b);j("pop,push,reverse,shift,sort,splice,unshift".split(","),function(a){var b=k[a];m.prototype[a]=function(){var d=this._wrapped;b.apply(d,arguments);var e=d.length;(a=="shift"||a=="splice")&&e===0&&delete d[0];return v(d,this._chain)}});j(["concat","join","slice"],function(a){var b=k[a];m.prototype[a]=function(){return v(b.apply(this._wrapped,arguments),this._chain)}});m.prototype.chain=function(){this._chain= 31 | true;return this};m.prototype.value=function(){return this._wrapped}}).call(this); 32 | 
-------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # http://www.sphinx-doc.org/en/master/config 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('..')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'pytorch-fm' 21 | copyright = '2019, rixwew@gmail.com' 22 | author = 'rixwew@gmail.com' 23 | 24 | # The full version, including alpha/beta/rc tags 25 | release = '0.1' 26 | 27 | # -- General configuration --------------------------------------------------- 28 | 29 | # Add any Sphinx extension module names here, as strings. They can be 30 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 31 | # ones. 32 | extensions = [ 33 | 'sphinx.ext.todo', 34 | 'sphinx.ext.autosummary', 35 | 'sphinx.ext.viewcode', 36 | 'sphinx.ext.autodoc' 37 | ] 38 | 39 | # Add any paths that contain templates here, relative to this directory. 40 | templates_path = ['_templates'] 41 | 42 | # List of patterns, relative to source directory, that match files and 43 | # directories to ignore when looking for source files. 44 | # This pattern also affects html_static_path and html_extra_path. 
45 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 46 | 47 | # -- Options for HTML output ------------------------------------------------- 48 | 49 | # The theme to use for HTML and HTML Help pages. See the documentation for 50 | # a list of builtin themes. 51 | # 52 | html_theme = 'sphinx_rtd_theme' 53 | 54 | # Add any paths that contain custom static files (such as style sheets) here, 55 | # relative to this directory. They are copied after the builtin static files, 56 | # so a file named "default.css" will overwrite the builtin "default.css". 57 | html_static_path = ['_static'] 58 | -------------------------------------------------------------------------------- /docs/genindex.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | Index — pytorch-fm 0.1 documentation 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 |
45 | 46 | 91 | 92 |
93 | 94 | 95 | 101 | 102 | 103 |
104 | 105 |
106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 |
124 | 125 |
    126 | 127 |
  • Docs »
  • 128 | 129 |
  • Index
  • 130 | 131 | 132 |
  • 133 | 134 | 135 | 136 |
  • 137 | 138 |
139 | 140 | 141 |
142 |
143 |
144 |
145 | 146 | 147 |

Index

148 | 149 |
150 | A 151 | | C 152 | | D 153 | | E 154 | | F 155 | | I 156 | | L 157 | | M 158 | | N 159 | | O 160 | | P 161 | | T 162 | | W 163 | 164 |
165 |

A

166 | 167 | 173 | 179 |
180 | 181 |

C

182 | 183 | 187 | 193 |
194 | 195 |

D

196 | 197 | 201 | 205 |
206 | 207 |

E

208 | 209 | 213 |
214 | 215 |

F

216 | 217 | 283 |
284 | 285 |

I

286 | 287 | 291 |
292 | 293 |

L

294 | 295 | 299 |
300 | 301 |

M

302 | 303 | 307 | 313 |
314 | 315 |

N

316 | 317 | 321 |
322 | 323 |

O

324 | 325 | 329 |
330 | 331 |

P

332 | 333 | 337 |
338 | 339 |

T

340 | 341 | 359 | 379 |
380 | 381 |

W

382 | 383 | 387 |
388 | 389 | 390 | 391 |
392 | 393 |
394 |
395 | 396 | 397 |
398 | 399 |
400 |

401 | © Copyright 2019, rixwew@gmail.com 402 | 403 |

404 |
405 | Built with Sphinx using a theme provided by Read the Docs. 406 | 407 |
408 | 409 |
410 |
411 | 412 |
413 | 414 |
415 | 416 | 417 | 418 | 423 | 424 | 425 | 426 | 427 | 428 | 429 | 430 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | pytorch-fm — pytorch-fm 0.1 documentation 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 |
45 | 46 | 91 | 92 |
93 | 94 | 95 | 101 | 102 | 103 |
104 | 105 |
106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 |
124 | 125 |
    126 | 127 |
  • Docs »
  • 128 | 129 |
  • pytorch-fm
  • 130 | 131 | 132 |
  • 133 | 134 | 135 | View page source 136 | 137 | 138 |
  • 139 | 140 |
141 | 142 | 143 |
144 |
145 |
146 |
147 | 148 |
149 |

pytorch-fm

150 |

Factorization Machine models in PyTorch.

151 |

This package provides an implementation of various factorization machine models and common datasets in PyTorch.

152 |
153 |
154 |

Minimal requirements

155 |
    156 |
  • Python 3.x

  • 157 |
  • PyTorch 1.1.0

  • 158 |
  • numpy

  • 159 |
  • lmdb

  • 160 |
161 |
162 |
163 |

Installation

164 |

Install with pip:

165 |
pip install torchfm
166 | 
167 |
168 |
169 |
170 |

API documentation

171 | 201 |
202 |
203 |

Indices and tables

204 | 209 |
210 | 211 | 212 |
213 | 214 |
215 |
216 | 217 | 223 | 224 | 225 |
226 | 227 |
228 |

229 | © Copyright 2019, rixwew@gmail.com 230 | 231 |

232 |
233 | Built with Sphinx using a theme provided by Read the Docs. 234 | 235 |
236 | 237 |
238 |
239 | 240 |
241 | 242 |
243 | 244 | 245 | 246 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | pytorch-fm 2 | =========== 3 | 4 | *Factorization Machine models in PyTorch.* 5 | 6 | This package provides an implementation of various factorization machine models and common datasets in PyTorch. 7 | 8 | 9 | 10 | Minimal requirements 11 | ==================== 12 | 13 | * Python 3.x 14 | * PyTorch 1.1.0 15 | * numpy 16 | * lmdb 17 | 18 | Installation 19 | ============ 20 | 21 | Install with pip:: 22 | 23 | pip install torchfm 24 | 25 | 26 | API documentation 27 | ================= 28 | 29 | .. toctree:: 30 | torchfm 31 | 32 | 33 | Indices and tables 34 | ================== 35 | 36 | * :ref:`genindex` 37 | * :ref:`modindex` 38 | * :ref:`search` 39 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/objects.inv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/docs/objects.inv -------------------------------------------------------------------------------- /docs/py-modindex.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | Python Module Index — pytorch-fm 0.1 documentation 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 |
47 | 48 | 93 | 94 |
95 | 96 | 97 | 103 | 104 | 105 |
106 | 107 |
108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 |
126 | 127 |
    128 | 129 |
  • Docs »
  • 130 | 131 |
  • Python Module Index
  • 132 | 133 | 134 |
  • 135 | 136 |
  • 137 | 138 |
139 | 140 | 141 |
142 |
143 |
144 |
145 | 146 | 147 |

Python Module Index

148 | 149 |
150 | t 151 |
152 | 153 | 154 | 155 | 157 | 158 | 160 | 163 | 164 | 165 | 168 | 169 | 170 | 173 | 174 | 175 | 178 | 179 | 180 | 183 | 184 | 185 | 188 | 189 | 190 | 193 | 194 | 195 | 198 | 199 | 200 | 203 | 204 | 205 | 208 | 209 | 210 | 213 | 214 | 215 | 218 | 219 | 220 | 223 | 224 | 225 | 228 | 229 | 230 | 233 | 234 | 235 | 238 | 239 | 240 | 243 | 244 | 245 | 248 |
 
156 | t
161 | torchfm 162 |
    166 | torchfm.dataset.avazu 167 |
    171 | torchfm.dataset.criteo 172 |
    176 | torchfm.dataset.movielens 177 |
    181 | torchfm.layer 182 |
    186 | torchfm.model.afi 187 |
    191 | torchfm.model.afm 192 |
    196 | torchfm.model.dcn 197 |
    201 | torchfm.model.dfm 202 |
    206 | torchfm.model.ffm 207 |
    211 | torchfm.model.fm 212 |
    216 | torchfm.model.fnfm 217 |
    221 | torchfm.model.fnn 222 |
    226 | torchfm.model.lr 227 |
    231 | torchfm.model.nfm 232 |
    236 | torchfm.model.pnn 237 |
    241 | torchfm.model.wd 242 |
    246 | torchfm.model.xdfm 247 |
249 | 250 | 251 |
252 | 253 |
254 |
255 | 256 | 257 |
258 | 259 |
260 |

261 | © Copyright 2019, rixwew@gmail.com 262 | 263 |

264 |
265 | Built with Sphinx using a theme provided by Read the Docs. 266 | 267 |
268 | 269 |
270 |
271 | 272 |
273 | 274 |
275 | 276 | 277 | 278 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | -------------------------------------------------------------------------------- /docs/search.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | Search — pytorch-fm 0.1 documentation 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 |
45 | 46 | 91 | 92 |
93 | 94 | 95 | 101 | 102 | 103 |
104 | 105 |
106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 |
124 | 125 |
    126 | 127 |
  • Docs »
  • 128 | 129 |
  • Search
  • 130 | 131 | 132 |
  • 133 | 134 | 135 | 136 |
  • 137 | 138 |
139 | 140 | 141 |
142 |
143 |
144 |
145 | 146 | 154 | 155 | 156 |
157 | 158 |
159 | 160 |
161 | 162 |
163 |
164 | 165 | 166 |
167 | 168 |
169 |

170 | © Copyright 2019, rixwew@gmail.com 171 | 172 |

173 |
174 | Built with Sphinx using a theme provided by Read the Docs. 175 | 176 |
177 | 178 |
179 |
180 | 181 |
182 | 183 |
184 | 185 | 186 | 187 | 192 | 193 | 194 | 195 | 196 | 197 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | -------------------------------------------------------------------------------- /docs/searchindex.js: -------------------------------------------------------------------------------- 1 | Search.setIndex({docnames:["index","torchfm","torchfm.dataset","torchfm.model"],envversion:{"sphinx.domains.c":1,"sphinx.domains.changeset":1,"sphinx.domains.cpp":1,"sphinx.domains.javascript":1,"sphinx.domains.math":2,"sphinx.domains.python":1,"sphinx.domains.rst":1,"sphinx.domains.std":1,"sphinx.ext.todo":1,"sphinx.ext.viewcode":1,sphinx:56},filenames:["index.rst","torchfm.rst","torchfm.dataset.rst","torchfm.model.rst"],objects:{"torchfm.dataset":{avazu:[2,0,0,"-"],criteo:[2,0,0,"-"],movielens:[2,0,0,"-"]},"torchfm.dataset.avazu":{AvazuDataset:[2,1,1,""]},"torchfm.dataset.criteo":{CriteoDataset:[2,1,1,""]},"torchfm.dataset.movielens":{MovieLens1MDataset:[2,1,1,""],MovieLens20MDataset:[2,1,1,""]},"torchfm.layer":{AttentionalFactorizationMachine:[1,1,1,""],CompressedInteractionNetwork:[1,1,1,""],CrossNetwork:[1,1,1,""],FactorizationMachine:[1,1,1,""],FeaturesEmbedding:[1,1,1,""],FeaturesLinear:[1,1,1,""],FieldAwareFactorizationMachine:[1,1,1,""],InnerProductNetwork:[1,1,1,""],MultiLayerPerceptron:[1,1,1,""],OuterProductNetwork:[1,1,1,""]},"torchfm.layer.AttentionalFactorizationMachine":{forward:[1,2,1,""]},"torchfm.layer.CompressedInteractionNetwork":{forward:[1,2,1,""]},"torchfm.layer.CrossNetwork":{forward:[1,2,1,""]},"torchfm.layer.FactorizationMachine":{forward:[1,2,1,""]},"torchfm.layer.FeaturesEmbedding":{forward:[1,2,1,""]},"torchfm.layer.FeaturesLinear":{forward:[1,2,1,""]},"torchfm.layer.FieldAwareFactorizationMachine":{forward:[1,2,1,""]},"torchfm.layer.InnerProductNetwork":{forward:[1,2,1,""]},"torchfm.layer.MultiLayerPerceptron":{forward:[1,2,1,""]},"torchfm.layer.OuterProductNetwork":{forward:[1,2,1,""]},"torchfm.model":{afi:[3,0,0,"-"],afm:[3,0,0,"-"],dcn:[3,0,0,"-"],
dfm:[3,0,0,"-"],ffm:[3,0,0,"-"],fm:[3,0,0,"-"],fnfm:[3,0,0,"-"],fnn:[3,0,0,"-"],lr:[3,0,0,"-"],nfm:[3,0,0,"-"],pnn:[3,0,0,"-"],wd:[3,0,0,"-"],xdfm:[3,0,0,"-"]},"torchfm.model.afi":{AutomaticFeatureInteractionModel:[3,1,1,""]},"torchfm.model.afi.AutomaticFeatureInteractionModel":{forward:[3,2,1,""]},"torchfm.model.afm":{AttentionalFactorizationMachineModel:[3,1,1,""]},"torchfm.model.afm.AttentionalFactorizationMachineModel":{forward:[3,2,1,""]},"torchfm.model.dcn":{DeepCrossNetworkModel:[3,1,1,""]},"torchfm.model.dcn.DeepCrossNetworkModel":{forward:[3,2,1,""]},"torchfm.model.dfm":{DeepFactorizationMachineModel:[3,1,1,""]},"torchfm.model.dfm.DeepFactorizationMachineModel":{forward:[3,2,1,""]},"torchfm.model.ffm":{FieldAwareFactorizationMachineModel:[3,1,1,""]},"torchfm.model.ffm.FieldAwareFactorizationMachineModel":{forward:[3,2,1,""]},"torchfm.model.fm":{FactorizationMachineModel:[3,1,1,""]},"torchfm.model.fm.FactorizationMachineModel":{forward:[3,2,1,""]},"torchfm.model.fnfm":{FieldAwareNeuralFactorizationMachineModel:[3,1,1,""]},"torchfm.model.fnfm.FieldAwareNeuralFactorizationMachineModel":{forward:[3,2,1,""]},"torchfm.model.fnn":{FactorizationSupportedNeuralNetworkModel:[3,1,1,""]},"torchfm.model.fnn.FactorizationSupportedNeuralNetworkModel":{forward:[3,2,1,""]},"torchfm.model.lr":{LogisticRegressionModel:[3,1,1,""]},"torchfm.model.lr.LogisticRegressionModel":{forward:[3,2,1,""]},"torchfm.model.nfm":{NeuralFactorizationMachineModel:[3,1,1,""]},"torchfm.model.nfm.NeuralFactorizationMachineModel":{forward:[3,2,1,""]},"torchfm.model.pnn":{ProductNeuralNetworkModel:[3,1,1,""]},"torchfm.model.pnn.ProductNeuralNetworkModel":{forward:[3,2,1,""]},"torchfm.model.wd":{WideAndDeepModel:[3,1,1,""]},"torchfm.model.wd.WideAndDeepModel":{forward:[3,2,1,""]},"torchfm.model.xdfm":{ExtremeDeepFactorizationMachineModel:[3,1,1,""]},"torchfm.model.xdfm.ExtremeDeepFactorizationMachineModel":{forward:[3,2,1,""]},torchfm:{layer:[1,0,0,"-"]}},objnames:{"0":["py","module","Python 
module"],"1":["py","class","Python class"],"2":["py","method","Python method"]},objtypes:{"0":"py:module","1":"py:class","2":"py:method"},terms:{"20m":2,"case":3,"class":[1,2,3],"float":1,"long":[1,3],"true":[1,2,3],advertis:2,afi:[0,1],afm:[0,1],analyt:3,appear:2,attent:3,attentionalfactorizationmachin:1,attentionalfactorizationmachinemodel:3,attn_siz:[1,3],autoint:3,automat:3,automaticfeatureinteractionmodel:3,avazu:[0,1],avazudataset:2,awar:3,base:3,batch_siz:[1,3],cach:2,cache_path:2,categor:3,challeng:2,cheng:3,chua:3,click:[2,3],com:2,combin:3,common:0,competit:2,compressedinteractionnetwork:1,criteo:[0,1],criteodataset:2,cross:3,cross_layer_s:[1,3],crossnetwork:1,csie:2,ctr:[2,3],data:[2,3],dataset:[0,1],dataset_path:2,dcn:[0,1],deep:3,deepcrossnetworkmodel:3,deepfactorizationmachinemodel:3,deepfm:3,dfm:[0,1],discret:2,displai:2,dropout:[1,3],edu:2,embed_dim:[1,3],explicit:3,extremedeepfactorizationmachinemodel:3,factor:[0,3],factorizationmachin:1,factorizationmachinemodel:3,factorizationsupportedneuralnetworkmodel:3,fals:2,featur:[2,3],featuresembed:1,featureslinear:1,ffm:[0,1],field:3,field_dim:[1,3],fieldawarefactorizationmachin:1,fieldawarefactorizationmachinemodel:3,fieldawareneuralfactorizationmachinemodel:3,fnfm:[0,1],fnn:[0,1],forward:[1,3],grouplen:2,guo:3,http:2,implement:[0,3],implicit:3,index:0,infrequ:2,inner:3,innerproductnetwork:1,input_dim:1,instanc:2,interact:3,juan:3,kaggl:2,kernel_typ:1,lab:2,layer:0,learn:3,less:2,lian:3,lmdb:[0,2],log2:2,logist:3,logisticregressionmodel:3,machin:[0,3],mat:1,method:3,min_threshold:2,mlp_dim:3,model:[0,1],modul:0,movielen:[0,1],movielens1mdataset:2,movielens20mdataset:2,multi:3,multilayerperceptron:1,neg:2,network:3,neural:3,neuralfactorizationmachinemodel:3,nfm:[0,1],none:2,ntu:2,num_field:[1,3],num_head:3,num_lay:[1,3],numer:2,numpi:0,org:2,outer:3,outerproductnetwork:1,output_dim:1,output_lay:1,over:3,packag:0,page:0,paramet:[1,2,3],path:2,pdf:2,pip:0,pnn:[0,1],predict:[2,3],prepar:2,preprat:2,product:3,
productneuralnetworkmodel:3,propos:2,provid:0,python:0,pytorch:3,r01922136:2,rate:[2,3],rebuild_cach:2,recommend:3,reduce_sum:1,refer:[2,3],refresh:2,regress:3,remov:2,rendl:3,respons:3,sampl:2,search:0,self:3,sep:2,singl:2,size:[1,3],song:3,sourc:[1,2,3],spars:3,split_half:[1,3],studi:3,system:3,tensor:[1,3],than:2,them:2,thi:0,threshold:2,through:[2,3],torchfm:0,train:2,transform:2,treat:2,txt:2,user:3,valu:2,variou:0,via:3,wang:3,weight:3,which:2,wide:3,wideanddeepmodel:3,winner:2,www:2,xdeepfm:3,xdfm:[0,1],xiao:3,zhang:3},titles:["pytorch-fm","torchfm package","torchfm.dataset","torchfm.model"],titleterms:{afi:3,afm:3,api:0,avazu:2,criteo:2,dataset:2,dcn:3,dfm:3,document:0,ffm:3,fnfm:3,fnn:3,indic:0,instal:0,layer:1,minim:0,model:3,movielen:2,nfm:3,packag:1,pnn:3,pytorch:0,requir:0,tabl:0,torchfm:[1,2,3],xdfm:3}}) -------------------------------------------------------------------------------- /docs/torchfm.dataset.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | torchfm.dataset — pytorch-fm 0.1 documentation 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 |
46 | 47 | 102 | 103 |
104 | 105 | 106 | 112 | 113 | 114 |
115 | 116 |
117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 |
135 | 136 | 154 | 155 | 156 |
157 |
158 |
159 |
160 | 161 |
162 |

torchfm.dataset

163 |
164 |

torchfm.dataset.avazu

165 |
166 |
167 | class torchfm.dataset.avazu.AvazuDataset(dataset_path=None, cache_path='.avazu', rebuild_cache=False, min_threshold=4)[source]
168 |

Avazu Click-Through Rate Prediction Dataset

169 |
170 |
Dataset preparation

Remove infrequent features (those appearing in fewer than threshold instances) and treat them as a single feature

171 |
172 |
173 |
174 |
Parameters
175 |
    176 |
  • dataset_path – avazu train path

  • 177 |
  • cache_path – lmdb cache path

  • 178 |
  • rebuild_cache – If True, lmdb cache is refreshed

  • 179 |
  • min_threshold – infrequent feature threshold

  • 180 |
181 |
182 |
183 |
184 |
Reference

https://www.kaggle.com/c/avazu-ctr-prediction

185 |
186 |
187 |
188 | 189 |
190 |
191 |

torchfm.dataset.criteo

192 |
193 |
194 | class torchfm.dataset.criteo.CriteoDataset(dataset_path=None, cache_path='.criteo', rebuild_cache=False, min_threshold=10)[source]
195 |

Criteo Display Advertising Challenge Dataset

196 |
197 |
Data preparation:
    198 |
  • Remove infrequent features (those appearing in fewer than threshold instances) and treat them as a single feature

  • 199 |
  • Discretize numerical values with a log2 transformation, as proposed by the winner of the Criteo competition

  • 200 |
201 |
202 |
203 |
204 |
Parameters
205 |
    206 |
  • dataset_path – criteo train.txt path.

  • 207 |
  • cache_path – lmdb cache path.

  • 208 |
  • rebuild_cache – If True, lmdb cache is refreshed.

  • 209 |
  • min_threshold – infrequent feature threshold.

  • 210 |
211 |
212 |
213 |
214 |
Reference:

https://labs.criteo.com/2014/02/kaggle-display-advertising-challenge-dataset 215 | https://www.csie.ntu.edu.tw/~r01922136/kaggle-2014-criteo.pdf

216 |
217 |
218 |
219 | 220 |
221 |
222 |

torchfm.dataset.movielens

223 |
224 |
225 | class torchfm.dataset.movielens.MovieLens1MDataset(dataset_path)[source]
226 |

MovieLens 1M Dataset

227 |
228 |
Data preparation

treat samples with a rating less than 3 as negative samples

229 |
230 |
231 |
232 |
Parameters
233 |

dataset_path – MovieLens dataset path

234 |
235 |
236 |
237 |
Reference:

https://grouplens.org/datasets/movielens

238 |
239 |
240 |
241 | 242 |
243 |
244 | class torchfm.dataset.movielens.MovieLens20MDataset(dataset_path, sep=', ')[source]
245 |

MovieLens 20M Dataset

246 |
247 |
Data preparation

treat samples with a rating less than 3 as negative samples

248 |
249 |
250 |
251 |
Parameters
252 |

dataset_path – MovieLens dataset path

253 |
254 |
255 |
256 |
Reference:

https://grouplens.org/datasets/movielens

257 |
258 |
259 |
260 | 261 |
262 |
263 | 264 | 265 |
266 | 267 |
268 |
269 | 270 | 278 | 279 | 280 |
281 | 282 |
283 |

284 | © Copyright 2019, rixwew@gmail.com 285 | 286 |

287 |
288 | Built with Sphinx using a theme provided by Read the Docs. 289 | 290 |
291 | 292 |
293 |
294 | 295 |
296 | 297 |
298 | 299 | 300 | 301 | 306 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | -------------------------------------------------------------------------------- /docs/torchfm.dataset.rst: -------------------------------------------------------------------------------- 1 | torchfm.dataset 2 | ======================= 3 | 4 | torchfm.dataset.avazu 5 | ---------------------------- 6 | 7 | .. automodule:: torchfm.dataset.avazu 8 | :members: 9 | 10 | torchfm.dataset.criteo 11 | ----------------------------- 12 | 13 | .. automodule:: torchfm.dataset.criteo 14 | :members: 15 | 16 | torchfm.dataset.movielens 17 | -------------------------------- 18 | 19 | .. automodule:: torchfm.dataset.movielens 20 | :members: 21 | -------------------------------------------------------------------------------- /docs/torchfm.model.rst: -------------------------------------------------------------------------------- 1 | torchfm.model 2 | ===================== 3 | 4 | torchfm.model.afi 5 | ------------------------ 6 | 7 | .. automodule:: torchfm.model.afi 8 | :members: 9 | 10 | torchfm.model.afm 11 | ------------------------ 12 | 13 | .. automodule:: torchfm.model.afm 14 | :members: 15 | 16 | torchfm.model.dcn 17 | ------------------------ 18 | 19 | .. automodule:: torchfm.model.dcn 20 | :members: 21 | 22 | torchfm.model.dfm 23 | ------------------------ 24 | 25 | .. automodule:: torchfm.model.dfm 26 | :members: 27 | 28 | torchfm.model.ffm 29 | ------------------------ 30 | 31 | .. automodule:: torchfm.model.ffm 32 | :members: 33 | 34 | torchfm.model.fm 35 | ----------------------- 36 | 37 | .. automodule:: torchfm.model.fm 38 | :members: 39 | 40 | torchfm.model.fnfm 41 | ------------------------- 42 | 43 | .. automodule:: torchfm.model.fnfm 44 | :members: 45 | 46 | torchfm.model.fnn 47 | ------------------------ 48 | 49 | .. automodule:: torchfm.model.fnn 50 | :members: 51 | 52 | torchfm.model.lr 53 | ----------------------- 54 | 55 | .. 
automodule:: torchfm.model.lr 56 | :members: 57 | 58 | torchfm.model.nfm 59 | ------------------------ 60 | 61 | .. automodule:: torchfm.model.nfm 62 | :members: 63 | 64 | torchfm.model.pnn 65 | ------------------------ 66 | 67 | .. automodule:: torchfm.model.pnn 68 | :members: 69 | 70 | torchfm.model.wd 71 | ----------------------- 72 | 73 | .. automodule:: torchfm.model.wd 74 | :members: 75 | 76 | torchfm.model.xdfm 77 | ------------------------- 78 | 79 | .. automodule:: torchfm.model.xdfm 80 | :members: 81 | -------------------------------------------------------------------------------- /docs/torchfm.rst: -------------------------------------------------------------------------------- 1 | torchfm package 2 | =============== 3 | 4 | 5 | 6 | .. toctree:: 7 | 8 | torchfm.dataset 9 | torchfm.model 10 | 11 | 12 | torchfm.layer 13 | -------------------- 14 | 15 | .. automodule:: torchfm.layer 16 | :members: 17 | -------------------------------------------------------------------------------- /examples/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tqdm 3 | from sklearn.metrics import roc_auc_score 4 | from torch.utils.data import DataLoader 5 | 6 | from torchfm.dataset.avazu import AvazuDataset 7 | from torchfm.dataset.criteo import CriteoDataset 8 | from torchfm.dataset.movielens import MovieLens1MDataset, MovieLens20MDataset 9 | from torchfm.model.afi import AutomaticFeatureInteractionModel 10 | from torchfm.model.afm import AttentionalFactorizationMachineModel 11 | from torchfm.model.dcn import DeepCrossNetworkModel 12 | from torchfm.model.dfm import DeepFactorizationMachineModel 13 | from torchfm.model.ffm import FieldAwareFactorizationMachineModel 14 | from torchfm.model.fm import FactorizationMachineModel 15 | from torchfm.model.fnfm import FieldAwareNeuralFactorizationMachineModel 16 | from torchfm.model.fnn import FactorizationSupportedNeuralNetworkModel 17 | from torchfm.model.hofm 
import HighOrderFactorizationMachineModel 18 | from torchfm.model.lr import LogisticRegressionModel 19 | from torchfm.model.ncf import NeuralCollaborativeFiltering 20 | from torchfm.model.nfm import NeuralFactorizationMachineModel 21 | from torchfm.model.pnn import ProductNeuralNetworkModel 22 | from torchfm.model.wd import WideAndDeepModel 23 | from torchfm.model.xdfm import ExtremeDeepFactorizationMachineModel 24 | from torchfm.model.afn import AdaptiveFactorizationNetwork 25 | 26 | 27 | def get_dataset(name, path): 28 | if name == 'movielens1M': 29 | return MovieLens1MDataset(path) 30 | elif name == 'movielens20M': 31 | return MovieLens20MDataset(path) 32 | elif name == 'criteo': 33 | return CriteoDataset(path) 34 | elif name == 'avazu': 35 | return AvazuDataset(path) 36 | else: 37 | raise ValueError('unknown dataset name: ' + name) 38 | 39 | 40 | def get_model(name, dataset): 41 | """ 42 | Hyperparameters are empirically determined, not opitmized. 43 | """ 44 | field_dims = dataset.field_dims 45 | if name == 'lr': 46 | return LogisticRegressionModel(field_dims) 47 | elif name == 'fm': 48 | return FactorizationMachineModel(field_dims, embed_dim=16) 49 | elif name == 'hofm': 50 | return HighOrderFactorizationMachineModel(field_dims, order=3, embed_dim=16) 51 | elif name == 'ffm': 52 | return FieldAwareFactorizationMachineModel(field_dims, embed_dim=4) 53 | elif name == 'fnn': 54 | return FactorizationSupportedNeuralNetworkModel(field_dims, embed_dim=16, mlp_dims=(16, 16), dropout=0.2) 55 | elif name == 'wd': 56 | return WideAndDeepModel(field_dims, embed_dim=16, mlp_dims=(16, 16), dropout=0.2) 57 | elif name == 'ipnn': 58 | return ProductNeuralNetworkModel(field_dims, embed_dim=16, mlp_dims=(16,), method='inner', dropout=0.2) 59 | elif name == 'opnn': 60 | return ProductNeuralNetworkModel(field_dims, embed_dim=16, mlp_dims=(16,), method='outer', dropout=0.2) 61 | elif name == 'dcn': 62 | return DeepCrossNetworkModel(field_dims, embed_dim=16, num_layers=3, 
mlp_dims=(16, 16), dropout=0.2) 63 | elif name == 'nfm': 64 | return NeuralFactorizationMachineModel(field_dims, embed_dim=64, mlp_dims=(64,), dropouts=(0.2, 0.2)) 65 | elif name == 'ncf': 66 | # only supports MovieLens dataset because for other datasets user/item colums are indistinguishable 67 | assert isinstance(dataset, MovieLens20MDataset) or isinstance(dataset, MovieLens1MDataset) 68 | return NeuralCollaborativeFiltering(field_dims, embed_dim=16, mlp_dims=(16, 16), dropout=0.2, 69 | user_field_idx=dataset.user_field_idx, 70 | item_field_idx=dataset.item_field_idx) 71 | elif name == 'fnfm': 72 | return FieldAwareNeuralFactorizationMachineModel(field_dims, embed_dim=4, mlp_dims=(64,), dropouts=(0.2, 0.2)) 73 | elif name == 'dfm': 74 | return DeepFactorizationMachineModel(field_dims, embed_dim=16, mlp_dims=(16, 16), dropout=0.2) 75 | elif name == 'xdfm': 76 | return ExtremeDeepFactorizationMachineModel( 77 | field_dims, embed_dim=16, cross_layer_sizes=(16, 16), split_half=False, mlp_dims=(16, 16), dropout=0.2) 78 | elif name == 'afm': 79 | return AttentionalFactorizationMachineModel(field_dims, embed_dim=16, attn_size=16, dropouts=(0.2, 0.2)) 80 | elif name == 'afi': 81 | return AutomaticFeatureInteractionModel( 82 | field_dims, embed_dim=16, atten_embed_dim=64, num_heads=2, num_layers=3, mlp_dims=(400, 400), dropouts=(0, 0, 0)) 83 | elif name == 'afn': 84 | print("Model:AFN") 85 | return AdaptiveFactorizationNetwork( 86 | field_dims, embed_dim=16, LNN_dim=1500, mlp_dims=(400, 400, 400), dropouts=(0, 0, 0)) 87 | else: 88 | raise ValueError('unknown model name: ' + name) 89 | 90 | 91 | class EarlyStopper(object): 92 | 93 | def __init__(self, num_trials, save_path): 94 | self.num_trials = num_trials 95 | self.trial_counter = 0 96 | self.best_accuracy = 0 97 | self.save_path = save_path 98 | 99 | def is_continuable(self, model, accuracy): 100 | if accuracy > self.best_accuracy: 101 | self.best_accuracy = accuracy 102 | self.trial_counter = 0 103 | torch.save(model, 
self.save_path) 104 | return True 105 | elif self.trial_counter + 1 < self.num_trials: 106 | self.trial_counter += 1 107 | return True 108 | else: 109 | return False 110 | 111 | 112 | def train(model, optimizer, data_loader, criterion, device, log_interval=100): 113 | model.train() 114 | total_loss = 0 115 | tk0 = tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0) 116 | for i, (fields, target) in enumerate(tk0): 117 | fields, target = fields.to(device), target.to(device) 118 | y = model(fields) 119 | loss = criterion(y, target.float()) 120 | model.zero_grad() 121 | loss.backward() 122 | optimizer.step() 123 | total_loss += loss.item() 124 | if (i + 1) % log_interval == 0: 125 | tk0.set_postfix(loss=total_loss / log_interval) 126 | total_loss = 0 127 | 128 | 129 | def test(model, data_loader, device): 130 | model.eval() 131 | targets, predicts = list(), list() 132 | with torch.no_grad(): 133 | for fields, target in tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0): 134 | fields, target = fields.to(device), target.to(device) 135 | y = model(fields) 136 | targets.extend(target.tolist()) 137 | predicts.extend(y.tolist()) 138 | return roc_auc_score(targets, predicts) 139 | 140 | 141 | def main(dataset_name, 142 | dataset_path, 143 | model_name, 144 | epoch, 145 | learning_rate, 146 | batch_size, 147 | weight_decay, 148 | device, 149 | save_dir): 150 | device = torch.device(device) 151 | dataset = get_dataset(dataset_name, dataset_path) 152 | train_length = int(len(dataset) * 0.8) 153 | valid_length = int(len(dataset) * 0.1) 154 | test_length = len(dataset) - train_length - valid_length 155 | train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split( 156 | dataset, (train_length, valid_length, test_length)) 157 | train_data_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=8) 158 | valid_data_loader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=8) 159 | test_data_loader = DataLoader(test_dataset, 
batch_size=batch_size, num_workers=8) 160 | model = get_model(model_name, dataset).to(device) 161 | criterion = torch.nn.BCELoss() 162 | optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=weight_decay) 163 | early_stopper = EarlyStopper(num_trials=2, save_path=f'{save_dir}/{model_name}.pt') 164 | for epoch_i in range(epoch): 165 | train(model, optimizer, train_data_loader, criterion, device) 166 | auc = test(model, valid_data_loader, device) 167 | print('epoch:', epoch_i, 'validation: auc:', auc) 168 | if not early_stopper.is_continuable(model, auc): 169 | print(f'validation: best auc: {early_stopper.best_accuracy}') 170 | break 171 | auc = test(model, test_data_loader, device) 172 | print(f'test auc: {auc}') 173 | 174 | 175 | if __name__ == '__main__': 176 | import argparse 177 | 178 | parser = argparse.ArgumentParser() 179 | parser.add_argument('--dataset_name', default='criteo') 180 | parser.add_argument('--dataset_path', help='criteo/train.txt, avazu/train, or ml-1m/ratings.dat') 181 | parser.add_argument('--model_name', default='afi') 182 | parser.add_argument('--epoch', type=int, default=100) 183 | parser.add_argument('--learning_rate', type=float, default=0.001) 184 | parser.add_argument('--batch_size', type=int, default=2048) 185 | parser.add_argument('--weight_decay', type=float, default=1e-6) 186 | parser.add_argument('--device', default='cuda:0') 187 | parser.add_argument('--save_dir', default='chkpt') 188 | args = parser.parse_args() 189 | main(args.dataset_name, 190 | args.dataset_path, 191 | args.model_name, 192 | args.epoch, 193 | args.learning_rate, 194 | args.batch_size, 195 | args.weight_decay, 196 | args.device, 197 | args.save_dir) 198 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.1 2 | scikit-learn 3 | numpy 4 | pandas 5 | tqdm 6 | lmdb 
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from setuptools import setup, find_packages 4 | 5 | 6 | if __name__ == '__main__': 7 | 8 | with open(Path(__file__).parent / 'README.md', encoding='utf-8') as f: 9 | long_description = f.read() 10 | 11 | setup( 12 | name="torchfm", 13 | version="0.7.0", 14 | description="PyTorch implementation of Factorization Machine Models", 15 | long_description=long_description, 16 | long_description_content_type="text/markdown", 17 | url="https://github.com/rixwew/torchfm", 18 | author="rixwew", 19 | author_email="rixwew@gmail.com", 20 | packages=find_packages(exclude=["examples", "docs"]), 21 | ) 22 | -------------------------------------------------------------------------------- /test/test_layers.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from torchfm.layer import AnovaKernel 7 | 8 | 9 | class TestAnovaKernel(unittest.TestCase): 10 | 11 | def test_forward_order_2(self): 12 | batch_size, num_fields, embed_dim = 32, 16, 16 13 | kernel = AnovaKernel(order=2, reduce_sum=True).eval() 14 | with torch.no_grad(): 15 | x = torch.FloatTensor(np.random.randn(batch_size, num_fields, embed_dim)) 16 | y_true = 0 17 | for i in range(num_fields - 1): 18 | for j in range(i + 1, num_fields): 19 | y_true = x[:, i, :] * x[:, j, :] + y_true 20 | y_true = torch.sum(y_true, dim=1, keepdim=True).numpy() 21 | y_pred = kernel(x).numpy() 22 | np.testing.assert_almost_equal(y_pred, y_true, 3) 23 | 24 | def test_forward_order_3(self): 25 | batch_size, num_fields, embed_dim = 32, 16, 16 26 | kernel = AnovaKernel(order=3, reduce_sum=True).eval() 27 | with torch.no_grad(): 28 | x = torch.FloatTensor(np.random.randn(batch_size, num_fields, embed_dim)) 29 | y_true = 0 30 | for i in 
range(num_fields - 2): 31 | for j in range(i + 1, num_fields - 1): 32 | for k in range(j + 1, num_fields): 33 | y_true = x[:, i, :] * x[:, j, :] * x[:, k, :] + y_true 34 | y_true = torch.sum(y_true, dim=1, keepdim=True).numpy() 35 | y_pred = kernel(x).numpy() 36 | np.testing.assert_almost_equal(y_pred, y_true, 3) 37 | 38 | def test_forward_order_4(self): 39 | batch_size, num_fields, embed_dim = 32, 16, 16 40 | kernel = AnovaKernel(order=4, reduce_sum=True).eval() 41 | with torch.no_grad(): 42 | x = torch.FloatTensor(np.random.randn(batch_size, num_fields, embed_dim)) 43 | y_true = 0 44 | for i in range(num_fields - 3): 45 | for j in range(i + 1, num_fields - 2): 46 | for k in range(j + 1, num_fields - 1): 47 | for l in range(k + 1, num_fields): 48 | y_true = x[:, i, :] * x[:, j, :] * x[:, k, :] * x[:, l, :] + y_true 49 | y_true = torch.sum(y_true, dim=1, keepdim=True).numpy() 50 | y_pred = kernel(x).numpy() 51 | np.testing.assert_almost_equal(y_pred, y_true, 3) 52 | 53 | 54 | if __name__ == '__main__': 55 | unittest.main(verbosity=2) 56 | -------------------------------------------------------------------------------- /torchfm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/torchfm/__init__.py -------------------------------------------------------------------------------- /torchfm/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/torchfm/dataset/__init__.py -------------------------------------------------------------------------------- /torchfm/dataset/avazu.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import struct 3 | from collections import defaultdict 4 | from pathlib import Path 5 | 6 | import lmdb 7 | import numpy as 
np 8 | import torch.utils.data 9 | from tqdm import tqdm 10 | 11 | 12 | class AvazuDataset(torch.utils.data.Dataset): 13 | """ 14 | Avazu Click-Through Rate Prediction Dataset 15 | 16 | Dataset preparation 17 | Remove the infrequent features (appearing in less than threshold instances) and treat them as a single feature 18 | 19 | :param dataset_path: avazu train path 20 | :param cache_path: lmdb cache path 21 | :param rebuild_cache: If True, lmdb cache is refreshed 22 | :param min_threshold: infrequent feature threshold 23 | 24 | Reference 25 | https://www.kaggle.com/c/avazu-ctr-prediction 26 | """ 27 | 28 | def __init__(self, dataset_path=None, cache_path='.avazu', rebuild_cache=False, min_threshold=4): 29 | self.NUM_FEATS = 22 30 | self.min_threshold = min_threshold 31 | if rebuild_cache or not Path(cache_path).exists(): 32 | shutil.rmtree(cache_path, ignore_errors=True) 33 | if dataset_path is None: 34 | raise ValueError('create cache: failed: dataset_path is None') 35 | self.__build_cache(dataset_path, cache_path) 36 | self.env = lmdb.open(cache_path, create=False, lock=False, readonly=True) 37 | with self.env.begin(write=False) as txn: 38 | self.length = txn.stat()['entries'] - 1 39 | self.field_dims = np.frombuffer(txn.get(b'field_dims'), dtype=np.uint32) 40 | 41 | def __getitem__(self, index): 42 | with self.env.begin(write=False) as txn: 43 | np_array = np.frombuffer( 44 | txn.get(struct.pack('>I', index)), dtype=np.uint32).astype(dtype=np.long) 45 | return np_array[1:], np_array[0] 46 | 47 | def __len__(self): 48 | return self.length 49 | 50 | def __build_cache(self, path, cache_path): 51 | feat_mapper, defaults = self.__get_feat_mapper(path) 52 | with lmdb.open(cache_path, map_size=int(1e11)) as env: 53 | field_dims = np.zeros(self.NUM_FEATS, dtype=np.uint32) 54 | for i, fm in feat_mapper.items(): 55 | field_dims[i - 1] = len(fm) + 1 56 | with env.begin(write=True) as txn: 57 | txn.put(b'field_dims', field_dims.tobytes()) 58 | for buffer in 
self.__yield_buffer(path, feat_mapper, defaults): 59 | with env.begin(write=True) as txn: 60 | for key, value in buffer: 61 | txn.put(key, value) 62 | 63 | def __get_feat_mapper(self, path): 64 | feat_cnts = defaultdict(lambda: defaultdict(int)) 65 | with open(path) as f: 66 | f.readline() 67 | pbar = tqdm(f, mininterval=1, smoothing=0.1) 68 | pbar.set_description('Create avazu dataset cache: counting features') 69 | for line in pbar: 70 | values = line.rstrip('\n').split(',') 71 | if len(values) != self.NUM_FEATS + 2: 72 | continue 73 | for i in range(1, self.NUM_FEATS + 1): 74 | feat_cnts[i][values[i + 1]] += 1 75 | feat_mapper = {i: {feat for feat, c in cnt.items() if c >= self.min_threshold} for i, cnt in feat_cnts.items()} 76 | feat_mapper = {i: {feat: idx for idx, feat in enumerate(cnt)} for i, cnt in feat_mapper.items()} 77 | defaults = {i: len(cnt) for i, cnt in feat_mapper.items()} 78 | return feat_mapper, defaults 79 | 80 | def __yield_buffer(self, path, feat_mapper, defaults, buffer_size=int(1e5)): 81 | item_idx = 0 82 | buffer = list() 83 | with open(path) as f: 84 | f.readline() 85 | pbar = tqdm(f, mininterval=1, smoothing=0.1) 86 | pbar.set_description('Create avazu dataset cache: setup lmdb') 87 | for line in pbar: 88 | values = line.rstrip('\n').split(',') 89 | if len(values) != self.NUM_FEATS + 2: 90 | continue 91 | np_array = np.zeros(self.NUM_FEATS + 1, dtype=np.uint32) 92 | np_array[0] = int(values[1]) 93 | for i in range(1, self.NUM_FEATS + 1): 94 | np_array[i] = feat_mapper[i].get(values[i+1], defaults[i]) 95 | buffer.append((struct.pack('>I', item_idx), np_array.tobytes())) 96 | item_idx += 1 97 | if item_idx % buffer_size == 0: 98 | yield buffer 99 | buffer.clear() 100 | yield buffer 101 | -------------------------------------------------------------------------------- /torchfm/dataset/criteo.py: -------------------------------------------------------------------------------- 1 | import math 2 | import shutil 3 | import struct 4 | from 
collections import defaultdict 5 | from functools import lru_cache 6 | from pathlib import Path 7 | 8 | import lmdb 9 | import numpy as np 10 | import torch.utils.data 11 | from tqdm import tqdm 12 | 13 | 14 | class CriteoDataset(torch.utils.data.Dataset): 15 | """ 16 | Criteo Display Advertising Challenge Dataset 17 | 18 | Data prepration: 19 | * Remove the infrequent features (appearing in less than threshold instances) and treat them as a single feature 20 | * Discretize numerical values by log2 transformation which is proposed by the winner of Criteo Competition 21 | 22 | :param dataset_path: criteo train.txt path. 23 | :param cache_path: lmdb cache path. 24 | :param rebuild_cache: If True, lmdb cache is refreshed. 25 | :param min_threshold: infrequent feature threshold. 26 | 27 | Reference: 28 | https://labs.criteo.com/2014/02/kaggle-display-advertising-challenge-dataset 29 | https://www.csie.ntu.edu.tw/~r01922136/kaggle-2014-criteo.pdf 30 | """ 31 | 32 | def __init__(self, dataset_path=None, cache_path='.criteo', rebuild_cache=False, min_threshold=10): 33 | self.NUM_FEATS = 39 34 | self.NUM_INT_FEATS = 13 35 | self.min_threshold = min_threshold 36 | if rebuild_cache or not Path(cache_path).exists(): 37 | shutil.rmtree(cache_path, ignore_errors=True) 38 | if dataset_path is None: 39 | raise ValueError('create cache: failed: dataset_path is None') 40 | self.__build_cache(dataset_path, cache_path) 41 | self.env = lmdb.open(cache_path, create=False, lock=False, readonly=True) 42 | with self.env.begin(write=False) as txn: 43 | self.length = txn.stat()['entries'] - 1 44 | self.field_dims = np.frombuffer(txn.get(b'field_dims'), dtype=np.uint32) 45 | 46 | def __getitem__(self, index): 47 | with self.env.begin(write=False) as txn: 48 | np_array = np.frombuffer( 49 | txn.get(struct.pack('>I', index)), dtype=np.uint32).astype(dtype=np.long) 50 | return np_array[1:], np_array[0] 51 | 52 | def __len__(self): 53 | return self.length 54 | 55 | def __build_cache(self, path, 
cache_path): 56 | feat_mapper, defaults = self.__get_feat_mapper(path) 57 | with lmdb.open(cache_path, map_size=int(1e11)) as env: 58 | field_dims = np.zeros(self.NUM_FEATS, dtype=np.uint32) 59 | for i, fm in feat_mapper.items(): 60 | field_dims[i - 1] = len(fm) + 1 61 | with env.begin(write=True) as txn: 62 | txn.put(b'field_dims', field_dims.tobytes()) 63 | for buffer in self.__yield_buffer(path, feat_mapper, defaults): 64 | with env.begin(write=True) as txn: 65 | for key, value in buffer: 66 | txn.put(key, value) 67 | 68 | def __get_feat_mapper(self, path): 69 | feat_cnts = defaultdict(lambda: defaultdict(int)) 70 | with open(path) as f: 71 | pbar = tqdm(f, mininterval=1, smoothing=0.1) 72 | pbar.set_description('Create criteo dataset cache: counting features') 73 | for line in pbar: 74 | values = line.rstrip('\n').split('\t') 75 | if len(values) != self.NUM_FEATS + 1: 76 | continue 77 | for i in range(1, self.NUM_INT_FEATS + 1): 78 | feat_cnts[i][convert_numeric_feature(values[i])] += 1 79 | for i in range(self.NUM_INT_FEATS + 1, self.NUM_FEATS + 1): 80 | feat_cnts[i][values[i]] += 1 81 | feat_mapper = {i: {feat for feat, c in cnt.items() if c >= self.min_threshold} for i, cnt in feat_cnts.items()} 82 | feat_mapper = {i: {feat: idx for idx, feat in enumerate(cnt)} for i, cnt in feat_mapper.items()} 83 | defaults = {i: len(cnt) for i, cnt in feat_mapper.items()} 84 | return feat_mapper, defaults 85 | 86 | def __yield_buffer(self, path, feat_mapper, defaults, buffer_size=int(1e5)): 87 | item_idx = 0 88 | buffer = list() 89 | with open(path) as f: 90 | pbar = tqdm(f, mininterval=1, smoothing=0.1) 91 | pbar.set_description('Create criteo dataset cache: setup lmdb') 92 | for line in pbar: 93 | values = line.rstrip('\n').split('\t') 94 | if len(values) != self.NUM_FEATS + 1: 95 | continue 96 | np_array = np.zeros(self.NUM_FEATS + 1, dtype=np.uint32) 97 | np_array[0] = int(values[0]) 98 | for i in range(1, self.NUM_INT_FEATS + 1): 99 | np_array[i] = 
feat_mapper[i].get(convert_numeric_feature(values[i]), defaults[i]) 100 | for i in range(self.NUM_INT_FEATS + 1, self.NUM_FEATS + 1): 101 | np_array[i] = feat_mapper[i].get(values[i], defaults[i]) 102 | buffer.append((struct.pack('>I', item_idx), np_array.tobytes())) 103 | item_idx += 1 104 | if item_idx % buffer_size == 0: 105 | yield buffer 106 | buffer.clear() 107 | yield buffer 108 | 109 | 110 | @lru_cache(maxsize=None) 111 | def convert_numeric_feature(val: str): 112 | if val == '': 113 | return 'NULL' 114 | v = int(val) 115 | if v > 2: 116 | return str(int(math.log(v) ** 2)) 117 | else: 118 | return str(v - 2) 119 | -------------------------------------------------------------------------------- /torchfm/dataset/movielens.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import torch.utils.data 4 | 5 | 6 | class MovieLens20MDataset(torch.utils.data.Dataset): 7 | """ 8 | MovieLens 20M Dataset 9 | 10 | Data preparation 11 | treat samples with a rating less than 3 as negative samples 12 | 13 | :param dataset_path: MovieLens dataset path 14 | 15 | Reference: 16 | https://grouplens.org/datasets/movielens 17 | """ 18 | 19 | def __init__(self, dataset_path, sep=',', engine='c', header='infer'): 20 | data = pd.read_csv(dataset_path, sep=sep, engine=engine, header=header).to_numpy()[:, :3] 21 | self.items = data[:, :2].astype(np.int) - 1 # -1 because ID begins from 1 22 | self.targets = self.__preprocess_target(data[:, 2]).astype(np.float32) 23 | self.field_dims = np.max(self.items, axis=0) + 1 24 | self.user_field_idx = np.array((0, ), dtype=np.long) 25 | self.item_field_idx = np.array((1,), dtype=np.long) 26 | 27 | def __len__(self): 28 | return self.targets.shape[0] 29 | 30 | def __getitem__(self, index): 31 | return self.items[index], self.targets[index] 32 | 33 | def __preprocess_target(self, target): 34 | target[target <= 3] = 0 35 | target[target > 3] = 1 36 | return target 37 
| 38 | 39 | class MovieLens1MDataset(MovieLens20MDataset): 40 | """ 41 | MovieLens 1M Dataset 42 | 43 | Data preparation 44 | treat samples with a rating less than 3 as negative samples 45 | 46 | :param dataset_path: MovieLens dataset path 47 | 48 | Reference: 49 | https://grouplens.org/datasets/movielens 50 | """ 51 | 52 | def __init__(self, dataset_path): 53 | super().__init__(dataset_path, sep='::', engine='python', header=None) 54 | -------------------------------------------------------------------------------- /torchfm/layer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | class FeaturesLinear(torch.nn.Module): 7 | 8 | def __init__(self, field_dims, output_dim=1): 9 | super().__init__() 10 | self.fc = torch.nn.Embedding(sum(field_dims), output_dim) 11 | self.bias = torch.nn.Parameter(torch.zeros((output_dim,))) 12 | self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.long) 13 | 14 | def forward(self, x): 15 | """ 16 | :param x: Long tensor of size ``(batch_size, num_fields)`` 17 | """ 18 | x = x + x.new_tensor(self.offsets).unsqueeze(0) 19 | return torch.sum(self.fc(x), dim=1) + self.bias 20 | 21 | 22 | class FeaturesEmbedding(torch.nn.Module): 23 | 24 | def __init__(self, field_dims, embed_dim): 25 | super().__init__() 26 | self.embedding = torch.nn.Embedding(sum(field_dims), embed_dim) 27 | self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.long) 28 | torch.nn.init.xavier_uniform_(self.embedding.weight.data) 29 | 30 | def forward(self, x): 31 | """ 32 | :param x: Long tensor of size ``(batch_size, num_fields)`` 33 | """ 34 | x = x + x.new_tensor(self.offsets).unsqueeze(0) 35 | return self.embedding(x) 36 | 37 | 38 | class FieldAwareFactorizationMachine(torch.nn.Module): 39 | 40 | def __init__(self, field_dims, embed_dim): 41 | super().__init__() 42 | self.num_fields = len(field_dims) 43 | self.embeddings 
= torch.nn.ModuleList([ 44 | torch.nn.Embedding(sum(field_dims), embed_dim) for _ in range(self.num_fields) 45 | ]) 46 | self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.long) 47 | for embedding in self.embeddings: 48 | torch.nn.init.xavier_uniform_(embedding.weight.data) 49 | 50 | def forward(self, x): 51 | """ 52 | :param x: Long tensor of size ``(batch_size, num_fields)`` 53 | """ 54 | x = x + x.new_tensor(self.offsets).unsqueeze(0) 55 | xs = [self.embeddings[i](x) for i in range(self.num_fields)] 56 | ix = list() 57 | for i in range(self.num_fields - 1): 58 | for j in range(i + 1, self.num_fields): 59 | ix.append(xs[j][:, i] * xs[i][:, j]) 60 | ix = torch.stack(ix, dim=1) 61 | return ix 62 | 63 | 64 | class FactorizationMachine(torch.nn.Module): 65 | 66 | def __init__(self, reduce_sum=True): 67 | super().__init__() 68 | self.reduce_sum = reduce_sum 69 | 70 | def forward(self, x): 71 | """ 72 | :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)`` 73 | """ 74 | square_of_sum = torch.sum(x, dim=1) ** 2 75 | sum_of_square = torch.sum(x ** 2, dim=1) 76 | ix = square_of_sum - sum_of_square 77 | if self.reduce_sum: 78 | ix = torch.sum(ix, dim=1, keepdim=True) 79 | return 0.5 * ix 80 | 81 | 82 | class MultiLayerPerceptron(torch.nn.Module): 83 | 84 | def __init__(self, input_dim, embed_dims, dropout, output_layer=True): 85 | super().__init__() 86 | layers = list() 87 | for embed_dim in embed_dims: 88 | layers.append(torch.nn.Linear(input_dim, embed_dim)) 89 | layers.append(torch.nn.BatchNorm1d(embed_dim)) 90 | layers.append(torch.nn.ReLU()) 91 | layers.append(torch.nn.Dropout(p=dropout)) 92 | input_dim = embed_dim 93 | if output_layer: 94 | layers.append(torch.nn.Linear(input_dim, 1)) 95 | self.mlp = torch.nn.Sequential(*layers) 96 | 97 | def forward(self, x): 98 | """ 99 | :param x: Float tensor of size ``(batch_size, embed_dim)`` 100 | """ 101 | return self.mlp(x) 102 | 103 | 104 | class InnerProductNetwork(torch.nn.Module): 105 | 
106 | def forward(self, x): 107 | """ 108 | :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)`` 109 | """ 110 | num_fields = x.shape[1] 111 | row, col = list(), list() 112 | for i in range(num_fields - 1): 113 | for j in range(i + 1, num_fields): 114 | row.append(i), col.append(j) 115 | return torch.sum(x[:, row] * x[:, col], dim=2) 116 | 117 | 118 | class OuterProductNetwork(torch.nn.Module): 119 | 120 | def __init__(self, num_fields, embed_dim, kernel_type='mat'): 121 | super().__init__() 122 | num_ix = num_fields * (num_fields - 1) // 2 123 | if kernel_type == 'mat': 124 | kernel_shape = embed_dim, num_ix, embed_dim 125 | elif kernel_type == 'vec': 126 | kernel_shape = num_ix, embed_dim 127 | elif kernel_type == 'num': 128 | kernel_shape = num_ix, 1 129 | else: 130 | raise ValueError('unknown kernel type: ' + kernel_type) 131 | self.kernel_type = kernel_type 132 | self.kernel = torch.nn.Parameter(torch.zeros(kernel_shape)) 133 | torch.nn.init.xavier_uniform_(self.kernel.data) 134 | 135 | def forward(self, x): 136 | """ 137 | :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)`` 138 | """ 139 | num_fields = x.shape[1] 140 | row, col = list(), list() 141 | for i in range(num_fields - 1): 142 | for j in range(i + 1, num_fields): 143 | row.append(i), col.append(j) 144 | p, q = x[:, row], x[:, col] 145 | if self.kernel_type == 'mat': 146 | kp = torch.sum(p.unsqueeze(1) * self.kernel, dim=-1).permute(0, 2, 1) 147 | return torch.sum(kp * q, -1) 148 | else: 149 | return torch.sum(p * q * self.kernel.unsqueeze(0), -1) 150 | 151 | 152 | class CrossNetwork(torch.nn.Module): 153 | 154 | def __init__(self, input_dim, num_layers): 155 | super().__init__() 156 | self.num_layers = num_layers 157 | self.w = torch.nn.ModuleList([ 158 | torch.nn.Linear(input_dim, 1, bias=False) for _ in range(num_layers) 159 | ]) 160 | self.b = torch.nn.ParameterList([ 161 | torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers) 162 | ]) 163 | 
164 | def forward(self, x): 165 | """ 166 | :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)`` 167 | """ 168 | x0 = x 169 | for i in range(self.num_layers): 170 | xw = self.w[i](x) 171 | x = x0 * xw + self.b[i] + x 172 | return x 173 | 174 | 175 | class AttentionalFactorizationMachine(torch.nn.Module): 176 | 177 | def __init__(self, embed_dim, attn_size, dropouts): 178 | super().__init__() 179 | self.attention = torch.nn.Linear(embed_dim, attn_size) 180 | self.projection = torch.nn.Linear(attn_size, 1) 181 | self.fc = torch.nn.Linear(embed_dim, 1) 182 | self.dropouts = dropouts 183 | 184 | def forward(self, x): 185 | """ 186 | :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)`` 187 | """ 188 | num_fields = x.shape[1] 189 | row, col = list(), list() 190 | for i in range(num_fields - 1): 191 | for j in range(i + 1, num_fields): 192 | row.append(i), col.append(j) 193 | p, q = x[:, row], x[:, col] 194 | inner_product = p * q 195 | attn_scores = F.relu(self.attention(inner_product)) 196 | attn_scores = F.softmax(self.projection(attn_scores), dim=1) 197 | attn_scores = F.dropout(attn_scores, p=self.dropouts[0], training=self.training) 198 | attn_output = torch.sum(attn_scores * inner_product, dim=1) 199 | attn_output = F.dropout(attn_output, p=self.dropouts[1], training=self.training) 200 | return self.fc(attn_output) 201 | 202 | 203 | class CompressedInteractionNetwork(torch.nn.Module): 204 | 205 | def __init__(self, input_dim, cross_layer_sizes, split_half=True): 206 | super().__init__() 207 | self.num_layers = len(cross_layer_sizes) 208 | self.split_half = split_half 209 | self.conv_layers = torch.nn.ModuleList() 210 | prev_dim, fc_input_dim = input_dim, 0 211 | for i in range(self.num_layers): 212 | cross_layer_size = cross_layer_sizes[i] 213 | self.conv_layers.append(torch.nn.Conv1d(input_dim * prev_dim, cross_layer_size, 1, 214 | stride=1, dilation=1, bias=True)) 215 | if self.split_half and i != self.num_layers - 1: 216 | 
cross_layer_size //= 2 217 | prev_dim = cross_layer_size 218 | fc_input_dim += prev_dim 219 | self.fc = torch.nn.Linear(fc_input_dim, 1) 220 | 221 | def forward(self, x): 222 | """ 223 | :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)`` 224 | """ 225 | xs = list() 226 | x0, h = x.unsqueeze(2), x 227 | for i in range(self.num_layers): 228 | x = x0 * h.unsqueeze(1) 229 | batch_size, f0_dim, fin_dim, embed_dim = x.shape 230 | x = x.view(batch_size, f0_dim * fin_dim, embed_dim) 231 | x = F.relu(self.conv_layers[i](x)) 232 | if self.split_half and i != self.num_layers - 1: 233 | x, h = torch.split(x, x.shape[1] // 2, dim=1) 234 | else: 235 | h = x 236 | xs.append(x) 237 | return self.fc(torch.sum(torch.cat(xs, dim=1), 2)) 238 | 239 | 240 | class AnovaKernel(torch.nn.Module): 241 | 242 | def __init__(self, order, reduce_sum=True): 243 | super().__init__() 244 | self.order = order 245 | self.reduce_sum = reduce_sum 246 | 247 | def forward(self, x): 248 | """ 249 | :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)`` 250 | """ 251 | batch_size, num_fields, embed_dim = x.shape 252 | a_prev = torch.ones((batch_size, num_fields + 1, embed_dim), dtype=torch.float).to(x.device) 253 | for t in range(self.order): 254 | a = torch.zeros((batch_size, num_fields + 1, embed_dim), dtype=torch.float).to(x.device) 255 | a[:, t+1:, :] += x[:, t:, :] * a_prev[:, t:-1, :] 256 | a = torch.cumsum(a, dim=1) 257 | a_prev = a 258 | if self.reduce_sum: 259 | return torch.sum(a[:, -1, :], dim=-1, keepdim=True) 260 | else: 261 | return a[:, -1, :] 262 | -------------------------------------------------------------------------------- /torchfm/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/pytorch-fm/f74ad19771eda104e99874d19dc892e988ec53fa/torchfm/model/__init__.py -------------------------------------------------------------------------------- /torchfm/model/afi.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from torchfm.layer import FeaturesEmbedding, FeaturesLinear, MultiLayerPerceptron 5 | 6 | 7 | class AutomaticFeatureInteractionModel(torch.nn.Module): 8 | """ 9 | A pytorch implementation of AutoInt. 10 | 11 | Reference: 12 | W Song, et al. AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks, 2018. 13 | """ 14 | 15 | def __init__(self, field_dims, embed_dim, atten_embed_dim, num_heads, num_layers, mlp_dims, dropouts, has_residual=True): 16 | super().__init__() 17 | self.num_fields = len(field_dims) 18 | self.linear = FeaturesLinear(field_dims) 19 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 20 | self.atten_embedding = torch.nn.Linear(embed_dim, atten_embed_dim) 21 | self.embed_output_dim = len(field_dims) * embed_dim 22 | self.atten_output_dim = len(field_dims) * atten_embed_dim 23 | self.has_residual = has_residual 24 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropouts[1]) # dropouts[1]: MLP; dropouts[0]: attention 25 | self.self_attns = torch.nn.ModuleList([ 26 | torch.nn.MultiheadAttention(atten_embed_dim, num_heads, dropout=dropouts[0]) for _ in range(num_layers) 27 | ]) 28 | self.attn_fc = torch.nn.Linear(self.atten_output_dim, 1) 29 | if self.has_residual: 30 | self.V_res_embedding = torch.nn.Linear(embed_dim, atten_embed_dim) 31 | 32 | def forward(self, x): 33 | """ 34 | :param x: Long tensor of size ``(batch_size, num_fields)`` 35 | """ 36 | embed_x = self.embedding(x) 37 | atten_x = self.atten_embedding(embed_x) 38 | cross_term = atten_x.transpose(0, 1) # NOTE(review): swaps the first two axes, presumably to the sequence-first layout MultiheadAttention uses by default; transposed back after the loop 39 | for self_attn in self.self_attns: 40 | cross_term, _ = self_attn(cross_term, cross_term, cross_term) 41 | cross_term = cross_term.transpose(0, 1) 42 | if self.has_residual: 43 | V_res = self.V_res_embedding(embed_x) 44 | cross_term += V_res 45 | cross_term = F.relu(cross_term).contiguous().view(-1, self.atten_output_dim) 46 | x = 
self.linear(x) + self.attn_fc(cross_term) + self.mlp(embed_x.view(-1, self.embed_output_dim)) 47 | return torch.sigmoid(x.squeeze(1)) 48 | -------------------------------------------------------------------------------- /torchfm/model/afm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FeaturesEmbedding, FeaturesLinear, AttentionalFactorizationMachine 4 | 5 | 6 | class AttentionalFactorizationMachineModel(torch.nn.Module): 7 | """ 8 | A pytorch implementation of Attentional Factorization Machine. 9 | 10 | Reference: 11 | J Xiao, et al. Attentional Factorization Machines: Learning the Weight of Feature Interactions via Attention Networks, 2017. 12 | """ 13 | 14 | def __init__(self, field_dims, embed_dim, attn_size, dropouts): 15 | super().__init__() 16 | self.num_fields = len(field_dims) 17 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 18 | self.linear = FeaturesLinear(field_dims) 19 | self.afm = AttentionalFactorizationMachine(embed_dim, attn_size, dropouts) 20 | 21 | def forward(self, x): 22 | """ 23 | :param x: Long tensor of size ``(batch_size, num_fields)`` 24 | """ 25 | x = self.linear(x) + self.afm(self.embedding(x)) 26 | return torch.sigmoid(x.squeeze(1)) 27 | -------------------------------------------------------------------------------- /torchfm/model/afn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from torchfm.layer import FeaturesEmbedding, FeaturesLinear, MultiLayerPerceptron 6 | 7 | class LNN(torch.nn.Module): 8 | """ 9 | A pytorch implementation of LNN layer 10 | Input shape 11 | - A 3D tensor with shape: ``(batch_size,field_size,embedding_size)``. 12 | Output shape 13 | - 2D tensor with shape:``(batch_size,LNN_dim*embedding_size)``. 14 | Arguments 15 | - **embed_dim**: int. The embedding size of each feature. 
16 | - **num_fields**: int. The number of feature fields. 17 | - **LNN_dim**: int. The number of logarithmic neurons. 18 | - **bias**: bool. Whether to add a learnable bias in the LNN. 19 | """ 20 | def __init__(self, num_fields, embed_dim, LNN_dim, bias=False): 21 | super(LNN, self).__init__() 22 | self.num_fields = num_fields 23 | self.embed_dim = embed_dim 24 | self.LNN_dim = LNN_dim 25 | self.lnn_output_dim = LNN_dim * embed_dim 26 | self.weight = torch.nn.Parameter(torch.Tensor(LNN_dim, num_fields)) 27 | if bias: 28 | self.bias = torch.nn.Parameter(torch.Tensor(LNN_dim, embed_dim)) 29 | else: 30 | self.register_parameter('bias', None) 31 | self.reset_parameters() 32 | 33 | def reset_parameters(self): 34 | stdv = 1. / math.sqrt(self.weight.size(1)) 35 | self.weight.data.uniform_(-stdv, stdv) 36 | if self.bias is not None: 37 | self.bias.data.uniform_(-stdv, stdv) 38 | 39 | def forward(self, x): 40 | """ 41 | :param x: Float tensor of size ``(batch_size, num_fields, embedding_size)``; returns a float tensor of size ``(batch_size, LNN_dim * embedding_size)`` 42 | """ 43 | embed_x_abs = torch.abs(x) # Computes the element-wise absolute value of the given input tensor. 44 | embed_x_afn = torch.add(embed_x_abs, 1e-7) 45 | # Logarithmic Transformation 46 | embed_x_log = torch.log1p(embed_x_afn) # log(1 + |x| + 1e-7); inverted by torch.expm1 after the weighted sum 47 | lnn_out = torch.matmul(self.weight, embed_x_log) # (LNN_dim, num_fields) @ (batch_size, num_fields, embedding_size) -> (batch_size, LNN_dim, embedding_size) 48 | if self.bias is not None: 49 | lnn_out += self.bias 50 | lnn_exp = torch.expm1(lnn_out) 51 | output = F.relu(lnn_exp).contiguous().view(-1, self.lnn_output_dim) 52 | return output 53 | 54 | 55 | 56 | 57 | 58 | 59 | class AdaptiveFactorizationNetwork(torch.nn.Module): 60 | """ 61 | A pytorch implementation of AFN. 62 | 63 | Reference: 64 | Cheng W, et al. Adaptive Factorization Network: Learning Adaptive-Order Feature Interactions, 2019. 
65 | """ 66 | def __init__(self, field_dims, embed_dim, LNN_dim, mlp_dims, dropouts): 67 | super().__init__() 68 | self.num_fields = len(field_dims) 69 | self.linear = FeaturesLinear(field_dims) # Linear 70 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) # Embedding 71 | self.LNN_dim = LNN_dim 72 | self.LNN_output_dim = self.LNN_dim * embed_dim 73 | self.LNN = LNN(self.num_fields, embed_dim, LNN_dim) 74 | self.mlp = MultiLayerPerceptron(self.LNN_output_dim, mlp_dims, dropouts[0]) # NOTE(review): only dropouts[0] is read in this model 75 | 76 | def forward(self, x): 77 | """ 78 | :param x: Long tensor of size ``(batch_size, num_fields)`` 79 | """ 80 | embed_x = self.embedding(x) 81 | lnn_out = self.LNN(embed_x) 82 | x = self.linear(x) + self.mlp(lnn_out) 83 | return torch.sigmoid(x.squeeze(1)) 84 | 85 | -------------------------------------------------------------------------------- /torchfm/model/dcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FeaturesEmbedding, CrossNetwork, MultiLayerPerceptron 4 | 5 | 6 | class DeepCrossNetworkModel(torch.nn.Module): 7 | """ 8 | A pytorch implementation of Deep & Cross Network. 9 | 10 | Reference: 11 | R Wang, et al. Deep & Cross Network for Ad Click Predictions, 2017. 
12 | """ 13 | 14 | def __init__(self, field_dims, embed_dim, num_layers, mlp_dims, dropout): 15 | super().__init__() 16 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 17 | self.embed_output_dim = len(field_dims) * embed_dim 18 | self.cn = CrossNetwork(self.embed_output_dim, num_layers) 19 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout, output_layer=False) 20 | self.linear = torch.nn.Linear(mlp_dims[-1] + self.embed_output_dim, 1) 21 | 22 | def forward(self, x): 23 | """ 24 | :param x: Long tensor of size ``(batch_size, num_fields)`` 25 | """ 26 | embed_x = self.embedding(x).view(-1, self.embed_output_dim) 27 | x_l1 = self.cn(embed_x) 28 | h_l2 = self.mlp(embed_x) 29 | x_stack = torch.cat([x_l1, h_l2], dim=1) 30 | p = self.linear(x_stack) 31 | return torch.sigmoid(p.squeeze(1)) 32 | -------------------------------------------------------------------------------- /torchfm/model/dfm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FactorizationMachine, FeaturesEmbedding, FeaturesLinear, MultiLayerPerceptron 4 | 5 | 6 | class DeepFactorizationMachineModel(torch.nn.Module): 7 | """ 8 | A pytorch implementation of DeepFM. 9 | 10 | Reference: 11 | H Guo, et al. DeepFM: A Factorization-Machine based Neural Network for CTR Prediction, 2017. 
12 | """ 13 | 14 | def __init__(self, field_dims, embed_dim, mlp_dims, dropout): 15 | super().__init__() 16 | self.linear = FeaturesLinear(field_dims) 17 | self.fm = FactorizationMachine(reduce_sum=True) 18 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 19 | self.embed_output_dim = len(field_dims) * embed_dim 20 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout) 21 | 22 | def forward(self, x): 23 | """ 24 | :param x: Long tensor of size ``(batch_size, num_fields)`` 25 | """ 26 | embed_x = self.embedding(x) 27 | x = self.linear(x) + self.fm(embed_x) + self.mlp(embed_x.view(-1, self.embed_output_dim)) 28 | return torch.sigmoid(x.squeeze(1)) 29 | -------------------------------------------------------------------------------- /torchfm/model/ffm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FeaturesLinear, FieldAwareFactorizationMachine 4 | 5 | 6 | class FieldAwareFactorizationMachineModel(torch.nn.Module): 7 | """ 8 | A pytorch implementation of Field-aware Factorization Machine. 9 | 10 | Reference: 11 | Y Juan, et al. Field-aware Factorization Machines for CTR Prediction, 2015. 
12 | """ 13 | 14 | def __init__(self, field_dims, embed_dim): 15 | super().__init__() 16 | self.linear = FeaturesLinear(field_dims) 17 | self.ffm = FieldAwareFactorizationMachine(field_dims, embed_dim) 18 | 19 | def forward(self, x): 20 | """ 21 | :param x: Long tensor of size ``(batch_size, num_fields)`` 22 | """ 23 | ffm_term = torch.sum(torch.sum(self.ffm(x), dim=1), dim=1, keepdim=True) 24 | x = self.linear(x) + ffm_term 25 | return torch.sigmoid(x.squeeze(1)) 26 | -------------------------------------------------------------------------------- /torchfm/model/fm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FactorizationMachine, FeaturesEmbedding, FeaturesLinear 4 | 5 | 6 | class FactorizationMachineModel(torch.nn.Module): 7 | """ 8 | A pytorch implementation of Factorization Machine. 9 | 10 | Reference: 11 | S Rendle, Factorization Machines, 2010. 12 | """ 13 | 14 | def __init__(self, field_dims, embed_dim): 15 | super().__init__() 16 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 17 | self.linear = FeaturesLinear(field_dims) 18 | self.fm = FactorizationMachine(reduce_sum=True) 19 | 20 | def forward(self, x): 21 | """ 22 | :param x: Long tensor of size ``(batch_size, num_fields)`` 23 | """ 24 | x = self.linear(x) + self.fm(self.embedding(x)) 25 | return torch.sigmoid(x.squeeze(1)) 26 | -------------------------------------------------------------------------------- /torchfm/model/fnfm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FieldAwareFactorizationMachine, MultiLayerPerceptron, FeaturesLinear 4 | 5 | 6 | class FieldAwareNeuralFactorizationMachineModel(torch.nn.Module): 7 | """ 8 | A pytorch implementation of Field-aware Neural Factorization Machine. 9 | 10 | Reference: 11 | L Zhang, et al. 
Field-aware Neural Factorization Machine for Click-Through Rate Prediction, 2019. 12 | """ 13 | 14 | def __init__(self, field_dims, embed_dim, mlp_dims, dropouts): 15 | super().__init__() 16 | self.linear = FeaturesLinear(field_dims) 17 | self.ffm = FieldAwareFactorizationMachine(field_dims, embed_dim) 18 | self.ffm_output_dim = len(field_dims) * (len(field_dims) - 1) // 2 * embed_dim 19 | self.bn = torch.nn.BatchNorm1d(self.ffm_output_dim) 20 | self.dropout = torch.nn.Dropout(dropouts[0]) 21 | self.mlp = MultiLayerPerceptron(self.ffm_output_dim, mlp_dims, dropouts[1]) 22 | 23 | def forward(self, x): 24 | """ 25 | :param x: Long tensor of size ``(batch_size, num_fields)`` 26 | """ 27 | cross_term = self.ffm(x).view(-1, self.ffm_output_dim) 28 | cross_term = self.bn(cross_term) 29 | cross_term = self.dropout(cross_term) 30 | x = self.linear(x) + self.mlp(cross_term) 31 | return torch.sigmoid(x.squeeze(1)) 32 | -------------------------------------------------------------------------------- /torchfm/model/fnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FeaturesEmbedding, MultiLayerPerceptron 4 | 5 | 6 | class FactorizationSupportedNeuralNetworkModel(torch.nn.Module): 7 | """ 8 | A pytorch implementation of Factorization-supported Neural Network (FNN). 9 | 10 | Reference: 11 | W Zhang, et al. Deep Learning over Multi-field Categorical Data - A Case Study on User Response Prediction, 2016. 
12 | """ 13 | 14 | def __init__(self, field_dims, embed_dim, mlp_dims, dropout): 15 | super().__init__() 16 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 17 | self.embed_output_dim = len(field_dims) * embed_dim 18 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout) 19 | 20 | def forward(self, x): 21 | """ 22 | :param x: Long tensor of size ``(batch_size, num_fields)`` 23 | """ 24 | embed_x = self.embedding(x) 25 | x = self.mlp(embed_x.view(-1, self.embed_output_dim)) 26 | return torch.sigmoid(x.squeeze(1)) 27 | -------------------------------------------------------------------------------- /torchfm/model/hofm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FeaturesLinear, FactorizationMachine, AnovaKernel, FeaturesEmbedding 4 | 5 | 6 | class HighOrderFactorizationMachineModel(torch.nn.Module): 7 | """ 8 | A pytorch implementation of Higher-Order Factorization Machines. 9 | 10 | Reference: 11 | M Blondel, et al. Higher-Order Factorization Machines, 2016. 
12 | """ 13 | 14 | def __init__(self, field_dims, order, embed_dim): 15 | super().__init__() 16 | if order < 1: 17 | raise ValueError(f'invalid order: {order}') 18 | self.order = order 19 | self.embed_dim = embed_dim 20 | self.linear = FeaturesLinear(field_dims) 21 | if order >= 2: 22 | self.embedding = FeaturesEmbedding(field_dims, embed_dim * (order - 1)) 23 | self.fm = FactorizationMachine(reduce_sum=True) 24 | if order >= 3: 25 | self.kernels = torch.nn.ModuleList([ 26 | AnovaKernel(order=i, reduce_sum=True) for i in range(3, order + 1) 27 | ]) 28 | 29 | def forward(self, x): 30 | """ 31 | :param x: Long tensor of size ``(batch_size, num_fields)`` 32 | """ 33 | y = self.linear(x).squeeze(1) 34 | if self.order >= 2: 35 | x = self.embedding(x) 36 | x_part = x[:, :, :self.embed_dim] # the first embed_dim-wide slice feeds the degree-2 FM term 37 | y += self.fm(x_part).squeeze(1) 38 | for i in range(self.order - 2): # kernels[i] has degree i + 3 and reads the (i + 1)-th embed_dim-wide slice 39 | x_part = x[:, :, (i + 1) * self.embed_dim: (i + 2) * self.embed_dim] 40 | y += self.kernels[i](x_part).squeeze(1) 41 | return torch.sigmoid(y) 42 | -------------------------------------------------------------------------------- /torchfm/model/lr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FeaturesLinear 4 | 5 | 6 | class LogisticRegressionModel(torch.nn.Module): 7 | """ 8 | A pytorch implementation of Logistic Regression. 
9 | """ 10 | 11 | def __init__(self, field_dims): 12 | super().__init__() 13 | self.linear = FeaturesLinear(field_dims) 14 | 15 | def forward(self, x): 16 | """ 17 | :param x: Long tensor of size ``(batch_size, num_fields)`` 18 | """ 19 | return torch.sigmoid(self.linear(x).squeeze(1)) 20 | -------------------------------------------------------------------------------- /torchfm/model/ncf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchfm.layer import FeaturesEmbedding, MultiLayerPerceptron 3 | 4 | 5 | class NeuralCollaborativeFiltering(torch.nn.Module): 6 | """ 7 | A pytorch implementation of Neural Collaborative Filtering. 8 | 9 | Reference: 10 | X He, et al. Neural Collaborative Filtering, 2017. 11 | """ 12 | 13 | def __init__(self, field_dims, user_field_idx, item_field_idx, embed_dim, mlp_dims, dropout): 14 | super().__init__() 15 | self.user_field_idx = user_field_idx 16 | self.item_field_idx = item_field_idx 17 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 18 | self.embed_output_dim = len(field_dims) * embed_dim 19 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout, output_layer=False) 20 | self.fc = torch.nn.Linear(mlp_dims[-1] + embed_dim, 1) 21 | 22 | def forward(self, x): 23 | """ 24 | :param x: Long tensor of size ``(batch_size, num_fields)`` 25 | """ 26 | x = self.embedding(x) 27 | user_x = x[:, self.user_field_idx].squeeze(1) # NOTE(review): fc expects gmf to have embed_dim columns, so this presumably assumes a length-1 index sequence (selection of size (batch, 1, embed_dim)); confirm with callers 28 | item_x = x[:, self.item_field_idx].squeeze(1) 29 | x = self.mlp(x.view(-1, self.embed_output_dim)) 30 | gmf = user_x * item_x 31 | x = torch.cat([gmf, x], dim=1) 32 | x = self.fc(x).squeeze(1) 33 | return torch.sigmoid(x) 34 | -------------------------------------------------------------------------------- /torchfm/model/nfm.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FactorizationMachine, FeaturesEmbedding, 
MultiLayerPerceptron, FeaturesLinear 4 | 5 | 6 | class NeuralFactorizationMachineModel(torch.nn.Module): 7 | """ 8 | A pytorch implementation of Neural Factorization Machine. 9 | 10 | Reference: 11 | X He and TS Chua, Neural Factorization Machines for Sparse Predictive Analytics, 2017. 12 | """ 13 | 14 | def __init__(self, field_dims, embed_dim, mlp_dims, dropouts): 15 | super().__init__() 16 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 17 | self.linear = FeaturesLinear(field_dims) 18 | self.fm = torch.nn.Sequential( 19 | FactorizationMachine(reduce_sum=False), 20 | torch.nn.BatchNorm1d(embed_dim), 21 | torch.nn.Dropout(dropouts[0]) 22 | ) 23 | self.mlp = MultiLayerPerceptron(embed_dim, mlp_dims, dropouts[1]) 24 | 25 | def forward(self, x): 26 | """ 27 | :param x: Long tensor of size ``(batch_size, num_fields)`` 28 | """ 29 | cross_term = self.fm(self.embedding(x)) 30 | x = self.linear(x) + self.mlp(cross_term) 31 | return torch.sigmoid(x.squeeze(1)) 32 | -------------------------------------------------------------------------------- /torchfm/model/pnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FeaturesEmbedding, FeaturesLinear, InnerProductNetwork, \ 4 | OuterProductNetwork, MultiLayerPerceptron 5 | 6 | 7 | class ProductNeuralNetworkModel(torch.nn.Module): 8 | """ 9 | A pytorch implementation of inner/outer Product Neural Network. 10 | Reference: 11 | Y Qu, et al. Product-based Neural Networks for User Response Prediction, 2016. 
12 | """ 13 | 14 | def __init__(self, field_dims, embed_dim, mlp_dims, dropout, method='inner'): 15 | super().__init__() 16 | num_fields = len(field_dims) 17 | if method == 'inner': 18 | self.pn = InnerProductNetwork() 19 | elif method == 'outer': 20 | self.pn = OuterProductNetwork(num_fields, embed_dim) 21 | else: 22 | raise ValueError('unknown product type: ' + method) 23 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 24 | self.linear = FeaturesLinear(field_dims, embed_dim) # NOTE(review): not used by forward(); removing it would change the state_dict keys 25 | self.embed_output_dim = num_fields * embed_dim 26 | self.mlp = MultiLayerPerceptron(num_fields * (num_fields - 1) // 2 + self.embed_output_dim, mlp_dims, dropout) # input: one product per field pair + the flattened embeddings 27 | 28 | def forward(self, x): 29 | """ 30 | :param x: Long tensor of size ``(batch_size, num_fields)`` 31 | """ 32 | embed_x = self.embedding(x) 33 | cross_term = self.pn(embed_x) 34 | x = torch.cat([embed_x.view(-1, self.embed_output_dim), cross_term], dim=1) 35 | x = self.mlp(x) 36 | return torch.sigmoid(x.squeeze(1)) 37 | -------------------------------------------------------------------------------- /torchfm/model/wd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import FeaturesLinear, MultiLayerPerceptron, FeaturesEmbedding 4 | 5 | 6 | class WideAndDeepModel(torch.nn.Module): 7 | """ 8 | A pytorch implementation of wide and deep learning. 9 | 10 | Reference: 11 | HT Cheng, et al. Wide & Deep Learning for Recommender Systems, 2016. 
12 | """ 13 | 14 | def __init__(self, field_dims, embed_dim, mlp_dims, dropout): 15 | super().__init__() 16 | self.linear = FeaturesLinear(field_dims) 17 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 18 | self.embed_output_dim = len(field_dims) * embed_dim 19 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout) 20 | 21 | def forward(self, x): 22 | """ 23 | :param x: Long tensor of size ``(batch_size, num_fields)`` 24 | """ 25 | embed_x = self.embedding(x) 26 | x = self.linear(x) + self.mlp(embed_x.view(-1, self.embed_output_dim)) 27 | return torch.sigmoid(x.squeeze(1)) 28 | -------------------------------------------------------------------------------- /torchfm/model/xdfm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchfm.layer import CompressedInteractionNetwork, FeaturesEmbedding, FeaturesLinear, MultiLayerPerceptron 4 | 5 | 6 | class ExtremeDeepFactorizationMachineModel(torch.nn.Module): 7 | """ 8 | A pytorch implementation of xDeepFM. 9 | 10 | Reference: 11 | J Lian, et al. xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems, 2018. 
12 | """ 13 | 14 | def __init__(self, field_dims, embed_dim, mlp_dims, dropout, cross_layer_sizes, split_half=True): 15 | super().__init__() 16 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 17 | self.embed_output_dim = len(field_dims) * embed_dim 18 | self.cin = CompressedInteractionNetwork(len(field_dims), cross_layer_sizes, split_half) 19 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout) 20 | self.linear = FeaturesLinear(field_dims) 21 | 22 | def forward(self, x): 23 | """ 24 | :param x: Long tensor of size ``(batch_size, num_fields)`` 25 | """ 26 | embed_x = self.embedding(x) 27 | x = self.linear(x) + self.cin(embed_x) + self.mlp(embed_x.view(-1, self.embed_output_dim)) 28 | return torch.sigmoid(x.squeeze(1)) 29 | --------------------------------------------------------------------------------