├── .circleci ├── config.yml └── tests │ ├── learning.py │ └── main.py ├── .gitignore ├── LICENSE ├── docs ├── Makefile ├── make.bat └── source │ ├── conf.py │ ├── data.rst │ ├── examples │ ├── cloning_from_github.rst │ ├── examples.rst │ ├── getting_started.rst │ ├── pandas_backend.rst │ └── your_data.rst │ ├── index.rst │ └── nn.rst ├── examples ├── 0. Embeddings Generation │ ├── 1. (proof of concept) DQN.ipynb │ └── Pipelines │ │ └── ML20M │ │ ├── 1. Async Parsing.ipynb │ │ ├── 2. NLP.ipynb │ │ ├── 3. Feature Engineering .ipynb │ │ ├── 4. Graph Embeddings.ipynb │ │ └── 5. The Big Merge.ipynb ├── 1. Vanilla RL │ ├── 1. Anomaly Detection.ipynb │ ├── 2. DDPG.ipynb │ ├── 3. TD3.ipynb │ ├── 4. SAC.ipynb │ └── 5. LSTM State Encoder.ipynb ├── 2. REINFORCE TopK Off Policy Correction │ ├── 0. Inner workings of REINFORCE inside recnn (optional).ipynb │ ├── 1. Basic Reinforce with RecNN.ipynb │ ├── 2. Reinforce Off Policy Correction.ipynb │ └── 3. TopK Reinforce Off Policy Correction.ipynb ├── 99.To be released, but working │ ├── 2. BCQ │ │ ├── 1. BCQ PyTorch .ipynb │ │ └── 2. BCQ Pyro.ipynb │ ├── 3. Decentralized Recommendation with PySyft │ │ └── 1. PySyft with RecNN Environment .ipynb │ └── 4. SearchNet │ │ └── 1. DDPG_SN.ipynb ├── [Library Basics] │ ├── 1. Getting Started.ipynb │ ├── 2. Different Pandas Backends.ipynb │ └── algorithms how to │ │ ├── ddpg.ipynb │ │ ├── reinforce.ipynb │ │ └── td3.ipynb ├── [Results] │ ├── 1. Ranking.ipynb │ ├── 2. Diversity Test (Indexes).ipynb │ ├── 3. Distances Test.ipynb │ └── 4. BCQ Stochastic Diversity .ipynb ├── readme.md └── streamlit_demo.py ├── readme.md ├── recnn ├── __init__.py ├── data │ ├── __init__.py │ ├── dataset_functions.py │ ├── db_con.py │ ├── env.py │ ├── pandas_backend.py │ └── utils.py ├── nn │ ├── __init__.py │ ├── algo.py │ ├── models.py │ └── update │ │ ├── __init__.py │ │ ├── bcq.py │ │ ├── ddpg.py │ │ ├── misc.py │ │ ├── reinforce.py │ │ └── td3.py └── utils │ ├── __init__.py │ ├── misc.py │ └── plot.py ├── requirements.txt ├── res ├── article_1.png ├── article_2.png └── logo big.png ├── setup.cfg └── setup.py /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | # Python CircleCI 2.0 configuration file 2 | # 3 | # Check https://circleci.com/docs/2.0/language-python/ for more details 4 | # 5 | version: 2 6 | jobs: 7 | build: 8 | docker: 9 | - image: pytorch/pytorch 10 | 11 | 12 | working_directory: ~/repo 13 | 14 | steps: 15 | - checkout 16 | 17 | # Download and cache dependencies 18 | - restore_cache: 19 | keys: 20 | - v1-dependencies-{{ checksum "requirements.txt" }} 21 | # fallback to using the latest cache if no exact match is found 22 | - v1-dependencies- 23 | 24 | - run: 25 | name: install dependencies 26 | command: | 27 | python3 -m venv venv 28 | . venv/bin/activate 29 | pip install -r requirements.txt 30 | 31 | - save_cache: 32 | paths: 33 | - ./venv 34 | key: v1-dependencies-{{ checksum "requirements.txt" }} 35 | 36 | - run: 37 | name: run tests 38 | command: | 39 | . venv/bin/activate 40 | pip3 install torch==1.7.0 41 | pip3 install matplotlib 42 | pip3 install . 
43 | pytest -v -s ./.circleci/tests/main.py 44 | 45 | - store_artifacts: 46 | path: test-reports 47 | destination: test-reports -------------------------------------------------------------------------------- /.circleci/tests/learning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # == recnn == 4 | import recnn 5 | import torch_optimizer as optim 6 | 7 | state = torch.randn(10, 1290) 8 | action = torch.randn(10, 128) 9 | reward = torch.randn(10, 1) 10 | next_state = torch.randn(10, 1290) 11 | done = torch.randn(10, 1) 12 | batch = { 13 | "state": state, 14 | "action": action, 15 | "reward": reward, 16 | "next_state": next_state, 17 | "done": done, 18 | } 19 | 20 | value_net = recnn.nn.Critic(1290, 128, 256, 54e-2) 21 | policy_net = recnn.nn.Actor(1290, 128, 256, 6e-1) 22 | 23 | 24 | def test_recommendation(): 25 | 26 | recommendation = policy_net(state) 27 | value = value_net(state, recommendation) 28 | 29 | assert recommendation.std() > 0 and recommendation.mean != 0 30 | assert value.std() > 0 31 | 32 | 33 | def check_loss_and_networks(loss, nets): 34 | assert loss["value"] > 0 and loss["policy"] != 0 and loss["step"] == 0 35 | for name, netw in nets.items(): 36 | assert netw.training == ("target" not in name) 37 | 38 | 39 | def test_update_function(): 40 | target_value_net = recnn.nn.Critic(1290, 128, 256) 41 | target_policy_net = recnn.nn.Actor(1290, 128, 256) 42 | 43 | target_policy_net.eval() 44 | target_value_net.eval() 45 | 46 | # soft update 47 | recnn.utils.soft_update(value_net, target_value_net, soft_tau=1.0) 48 | recnn.utils.soft_update(policy_net, target_policy_net, soft_tau=1.0) 49 | 50 | # define optimizers 51 | value_optimizer = optim.RAdam(value_net.parameters(), lr=1e-5, weight_decay=1e-2) 52 | policy_optimizer = optim.RAdam(policy_net.parameters(), lr=1e-5, weight_decay=1e-2) 53 | 54 | nets = { 55 | "value_net": value_net, 56 | "target_value_net": target_value_net, 57 | "policy_net": policy_net, 58 | "target_policy_net": target_policy_net, 59 | } 60 | 61 | optimizer = { 62 | "policy_optimizer": policy_optimizer, 63 | "value_optimizer": value_optimizer, 64 | } 65 | 66 | debug = {} 67 | writer = recnn.utils.misc.DummyWriter() 68 | 69 | step = 0 70 | params = { 71 | "gamma": 0.99, 72 | "min_value": -10, 73 | "max_value": 10, 74 | "policy_step": 10, 75 | "soft_tau": 0.001, 76 | } 77 | 78 | loss = recnn.nn.update.ddpg_update( 79 | batch, params, nets, optimizer, torch.device("cpu"), debug, writer, step=step 80 | ) 81 | 82 | check_loss_and_networks(loss, nets) 83 | 84 | 85 | def test_algo(): 86 | value_net = recnn.nn.Critic(1290, 128, 256, 54e-2) 87 | policy_net = recnn.nn.Actor(1290, 128, 256, 6e-1) 88 | 89 | ddpg = recnn.nn.DDPG(policy_net, value_net) 90 | ddpg = ddpg 91 | loss = ddpg.update(batch, learn=True) 92 | check_loss_and_networks(loss, ddpg.nets) 93 | -------------------------------------------------------------------------------- /.circleci/tests/main.py: -------------------------------------------------------------------------------- 1 | from learning import * 2 | 3 | if __name__ == "__main__": 4 | test_recommendation() 5 | test_update_function() 6 | test_algo() 7 | print("All passed") 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/* 2 | !recnn/data/ 3 | runs/ 4 | envs/ 5 | .vscode/ 6 | models/ 7 | recnn/.idea/ 8 | .ipynb_checkpoints/ 9 | *pt 10 | *csv 11 | *zip 
12 | .idea/ 13 | .desktopfolder 14 | 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | *$py.class 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | MANIFEST 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # celery beat schedule file 93 | celerybeat-schedule 94 | 95 | # SageMath parsed files 96 | *.sage.py 97 | 98 | # Environments 99 | .env 100 | .venv 101 | env/ 102 | venv/ 103 | ENV/ 104 | env.bak/ 105 | venv.bak/ 106 | 107 | # Spyder project settings 108 | .spyderproject 109 | .spyproject 110 | 111 | # Rope project settings 112 | .ropeproject 113 | 114 | # mkdocs documentation 115 | /site 116 | 117 | # mypy 118 | .mypy_cache/ 119 | 120 | # other 121 | .pyre_configuration -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # http://www.sphinx-doc.org/en/master/config 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import sys 14 | 15 | sys.path.append("../../") 16 | import recnn 17 | import sphinx_rtd_theme 18 | 19 | autodoc_mock_imports = ["torch", "tqdm"] 20 | 21 | # sys.path.insert(0, os.path.abspath('.')) 22 | 23 | 24 | # -- Project information ----------------------------------------------------- 25 | 26 | project = "recnn" 27 | copyright = "2019, Mike Watts" 28 | author = "Mike Watts" 29 | 30 | # The full version, including alpha/beta/rc tags 31 | release = "0.1" 32 | 33 | 34 | # -- General configuration --------------------------------------------------- 35 | 36 | # Add any Sphinx extension module names here, as strings. They can be 37 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 38 | # ones. 39 | extensions = [ 40 | "sphinx.ext.autodoc", 41 | "sphinx_rtd_theme", 42 | ] 43 | 44 | # Add any paths that contain templates here, relative to this directory. 45 | templates_path = ["_templates"] 46 | 47 | # List of patterns, relative to source directory, that match files and 48 | # directories to ignore when looking for source files. 49 | # This pattern also affects html_static_path and html_extra_path. 50 | exclude_patterns = [] 51 | 52 | 53 | # -- Options for HTML output ------------------------------------------------- 54 | 55 | # The theme to use for HTML and HTML Help pages. 
See the documentation for 56 | # a list of builtin themes. 57 | # 58 | html_theme = "sphinx_rtd_theme" 59 | master_doc = "index" 60 | 61 | # Add any paths that contain custom static files (such as style sheets) here, 62 | # relative to this directory. They are copied after the builtin static files, 63 | # so a file named "default.css" will overwrite the builtin "default.css". 64 | html_static_path = ["_static"] 65 | -------------------------------------------------------------------------------- /docs/source/data.rst: -------------------------------------------------------------------------------- 1 | Data 2 | ==== 3 | This module contains things to work with datasets. At the moment, the utils are pretty messy and will be rewritten. 4 | 5 | env 6 | --- 7 | 8 | The library's main abstraction for datasets is called an environment, similar to how other reinforcement learning libraries name it. This interface provides SARSA-like input for your RL models. When you are working with a recommendation env, you have two choices: static-length inputs (say, 10 items) or dynamic-length time series with a sequential encoder (many-to-one RNN). Static length is provided via FrameEnv, and dynamic length, along with a sequential state representation encoder, is implemented in SeqEnv. Let's take a look at FrameEnv first: 9 | 10 | 11 | .. automodule:: recnn.data.env 12 | :members: 13 | 14 | Reference 15 | +++++++++ 16 | 17 | .. autoclass:: UserDataset 18 | :members: __init__, __len__, __getitem__ 19 | 20 | .. autoclass:: Env 21 | :members: __init__ 22 | 23 | .. autoclass:: FrameEnv 24 | :members: __init__, train_batch, test_batch 25 | 26 | .. autoclass:: SeqEnv 27 | :members: __init__, train_batch, test_batch 28 | 29 | 30 | dataset_functions 31 | ----------------- 32 | 33 | What? 34 | +++++ 35 | 36 | RecNN is designed to work with your data flow. 37 | 38 | Set kwargs at the beginning of the prepare_dataset function. 39 | Kwargs you set are immutable. 40 | 41 | args_mut holds the mutable arguments; you can access the following: 42 | base: data.EnvBase, df: DataFrame, users: List[int], 43 | user_dict: Dict[int, Dict[str, np.ndarray]] 44 | 45 | Access args_mut and modify it in functions defined by you. 46 | It is best to use function chaining with build_data_pipeline. 47 | 48 | recnn.data.prepare_dataset is the function used by default in Env.__init__, 49 | but sometimes you want something extra. I have also predefined truncate_dataset. 50 | This function truncates the number of items to a specified count. 51 | In the reinforce example I modify it to look like:: 52 | 53 | def prepare_dataset(args_mut, kwargs): 54 | kwargs.set('reduce_items_to', num_items) # set kwargs for your functions here! 55 | pipeline = [recnn.data.truncate_dataset, recnn.data.prepare_dataset] 56 | recnn.data.build_data_pipeline(pipeline, kwargs, args_mut) 57 | 58 | # embeddings: https://drive.google.com/open?id=1EQ_zXBR3DKpmJR3jBgLvt-xoOvArGMsL 59 | env = recnn.data.env.FrameEnv('..', 60 | '...', frame_size, batch_size, 61 | embed_batch=embed_batch, prepare_dataset=prepare_dataset, 62 | num_workers=0) 63 | 64 | ..
automodule:: recnn.data.dataset_functions 65 | :members: 66 | 67 | -------------------------------------------------------------------------------- /docs/source/examples/cloning_from_github.rst: -------------------------------------------------------------------------------- 1 | Cloning from GitHub 2 | =================== 3 | 4 | Pro tip: clone without history (unless you need it):: 5 | 6 | git clone --depth 1 git@github.com:awarebayes/RecNN.git 7 | 8 | Create a conda environment and install the dependencies:: 9 | 10 | conda create --name recnn 11 | conda activate recnn 12 | cd RecNN 13 | pip install -r requirements.txt 14 | 15 | Download the data from the downloads section. 16 | 17 | Start Jupyter Notebook and jump to the examples folder:: 18 | 19 | jupyter-notebook . 20 | 21 | Here is what my project directory looks like (shallow):: 22 | 23 | RecNN 24 | ├── .circleci 25 | ├── data 26 | ├── docs 27 | ├── examples 28 | ├── .git 29 | ├── .gitignore 30 | ├── LICENSE 31 | ├── models 32 | ├── readme.md 33 | ├── recnn 34 | ├── requirements.txt 35 | ├── res 36 | ├── runs 37 | ├── setup.cfg 38 | └── setup.py 39 | 40 | Here is the data directory (ignore the cache):: 41 | 42 | data 43 | ├── cache 44 | │ ├── frame_env.pkl 45 | │ └── frame_env_truncated.pkl 46 | ├── embeddings 47 | │ └── ml20_pca128.pkl 48 | └── ml-20m 49 | ├── genome-scores.csv 50 | ├── genome-tags.csv 51 | ├── links.csv 52 | ├── movies.csv 53 | ├── ratings.csv 54 | ├── README.txt 55 | └── tags.csv -------------------------------------------------------------------------------- /docs/source/examples/examples.rst: -------------------------------------------------------------------------------- 1 | Examples Page 2 | =============== 3 | 4 | Welcome to the tutorial page. It is advised that you use local Jupyter, Google Colab, or Gradient instead. 5 | These pages are mostly copies of the notebooks. 6 | 7 | .. toctree:: 8 | :maxdepth: 2 9 | :caption: Reference: 10 | 11 | Getting Started 12 | Working with Your Data 13 | Using Pandas backends 14 | 15 | -------------------------------------------------------------------------------- /docs/source/examples/getting_started.rst: -------------------------------------------------------------------------------- 1 | Getting Started with recnn 2 | ========================== 3 | 4 | Colab Version Here (clickable): 5 | 6 | .. image:: https://colab.research.google.com/assets/colab-badge.svg 7 | :target: https://colab.research.google.com/drive/1xWX4JQvlcx3mizwL4gB0THEyxw6LsXTL 8 | 9 | 10 | The offline example is in: RecNN/examples/[Library Basics]/1. Getting Started.ipynb 11 | 12 | Let's do some imports:: 13 | 14 | import recnn 15 | 16 | 17 | import torch 18 | import torch.nn as nn 19 | from tqdm.auto import tqdm 20 | 21 | tqdm.pandas() 22 | 23 | from jupyterthemes import jtplot 24 | jtplot.style(theme='grade3') 25 | 26 | Environments 27 | ++++++++++++ 28 | The library's main abstraction for datasets is called an environment, similar to how other reinforcement learning libraries name it. This interface provides SARSA-like input for your RL models. When you are working with a recommendation env, you have two choices: static-length inputs (say, 10 items) or dynamic-length time series with a sequential encoder (many-to-one RNN). Static length is provided via FrameEnv, and dynamic length, along with a sequential state representation encoder, is implemented in SeqEnv.
Let's take a look at FrameEnv first: 29 | 30 | In order to initialize an env, you need to provide the embeddings and ratings paths:: 31 | 32 | frame_size = 10 33 | batch_size = 25 34 | # embeddings: https://drive.google.com/open?id=1EQ_zXBR3DKpmJR3jBgLvt-xoOvArGMsL 35 | dirs = recnn.data.env.DataPath( 36 | base="../../../data/", 37 | embeddings="embeddings/ml20_pca128.pkl", 38 | ratings="ml-20m/ratings.csv", 39 | cache="cache/frame_env.pkl", # cache will generate after you run 40 | use_cache=True 41 | ) 42 | env = recnn.data.env.FrameEnv(dirs, frame_size, batch_size) 43 | 44 | train = env.train_batch() 45 | test = env.test_batch() 46 | state, action, reward, next_state, done = recnn.data.get_base_batch(train, device=torch.device('cpu')) 47 | 48 | print(state) 49 | 50 | # State 51 | tensor([[ 5.4261, -4.6243, 2.3351, ..., 3.0000, 4.0000, 1.0000], 52 | [ 6.2052, -1.8592, -0.3248, ..., 4.0000, 1.0000, 4.0000], 53 | [ 3.2902, -5.0021, -10.7066, ..., 1.0000, 4.0000, 2.0000], 54 | ..., 55 | [ 3.0571, -4.1390, -2.7344, ..., 3.0000, -3.0000, -1.0000], 56 | [ 0.8177, -7.0827, -0.6607, ..., -3.0000, -1.0000, 3.0000], 57 | [ 9.0742, 0.3944, -6.4801, ..., -1.0000, 3.0000, -1.0000]]) 58 | 59 | Recommending 60 | ++++++++++++ 61 | 62 | Let's initialize the main networks and recommend something:: 63 | 64 | value_net = recnn.nn.Critic(1290, 128, 256, 54e-2) 65 | policy_net = recnn.nn.Actor(1290, 128, 256, 6e-1) 66 | 67 | recommendation = policy_net(state) 68 | value = value_net(state, recommendation) 69 | print(recommendation) 70 | print(value) 71 | 72 | # Output: 73 | 74 | tensor([[ 1.5302, -2.3658, 1.6439, ..., 0.1297, 2.2236, 2.9672], 75 | [ 0.8570, -1.3491, -0.3350, ..., -0.8712, 5.8390, 3.0899], 76 | [-3.3727, -3.6797, -3.9109, ..., 3.2436, 1.2161, -1.4018], 77 | ..., 78 | [-1.7834, -0.4289, 0.9808, ..., -2.3487, -5.8386, 3.5981], 79 | [ 2.3813, -1.9076, 4.3054, ..., 5.2221, 2.3165, -0.0192], 80 | [-3.8265, 1.8143, -1.8106, ..., 3.3988, -3.1845, 0.7432]], 81 | grad_fn=) 82 | tensor([[-1.0065], 83 | [ 0.3728], 84 | [ 2.1063], 85 | ..., 86 | [-2.1382], 87 | [ 0.3330], 88 | [ 5.4069]], grad_fn=) 89 | 90 | Algo classes 91 | ++++++++++++ 92 | 93 | Algo is a high-level abstraction for an RL algorithm. You need two networks 94 | (policy and value) to initialize it. Later on, you can tweak the parameters 95 | and other settings in the algo itself.
96 | 97 | Important: you can set the writer to torch.utils.tensorboard.SummaryWriter and get debug output. 98 | Tweak it how you want:: 99 | 100 | ddpg = recnn.nn.DDPG(policy_net, value_net) 101 | print(ddpg.params) 102 | ddpg.params['gamma'] = 0.9 103 | ddpg.params['policy_step'] = 3 104 | ddpg.optimizers['policy_optimizer'] = torch.optim.Adam(ddpg.nets['policy_net'].parameters(), your_lr) 105 | ddpg.writer = torch.utils.tensorboard.SummaryWriter('./runs') 106 | ddpg = ddpg.to(torch.device('cuda')) 107 | 108 | ddpg.loss_layout is also handy; it allows you to see what the loss should look like:: 109 | 110 | # test function 111 | def run_tests(): 112 | batch = next(iter(env.test_dataloader)) 113 | loss = ddpg.update(batch, learn=False) 114 | return loss 115 | 116 | value_net = recnn.nn.Critic(1290, 128, 256, 54e-2) 117 | policy_net = recnn.nn.Actor(1290, 128, 256, 6e-1) 118 | 119 | cuda = torch.device('cuda') 120 | ddpg = recnn.nn.DDPG(policy_net, value_net) 121 | ddpg = ddpg.to(cuda) 122 | plotter = recnn.utils.Plotter(ddpg.loss_layout, [['value', 'policy']],) 123 | ddpg.writer = torch.utils.tensorboard.SummaryWriter(log_dir='./runs') 124 | 125 | from IPython.display import clear_output 126 | import matplotlib.pyplot as plt 127 | %matplotlib inline 128 | 129 | plot_every = 50 130 | n_epochs = 2 131 | 132 | def learn(): 133 | for epoch in range(n_epochs): 134 | for batch in tqdm(env.train_dataloader): 135 | loss = ddpg.update(batch, learn=True) 136 | plotter.log_losses(loss) 137 | ddpg.step() 138 | if ddpg._step % plot_every == 0: 139 | clear_output(True) 140 | print('step', ddpg._step) 141 | test_loss = run_tests() 142 | plotter.log_losses(test_loss, test=True) 143 | plotter.plot_loss() 144 | if ddpg._step > 1000: 145 | return 146 | 147 | learn() 148 | 149 | Update Functions 150 | ++++++++++++++++ 151 | 152 | Basically, the Algo class is a high-level wrapper around the update function. The code for that is pretty messy, 153 | so if you want to check it out, I explained it in the colab notebook linked at the top. 154 | -------------------------------------------------------------------------------- /docs/source/examples/pandas_backend.rst: -------------------------------------------------------------------------------- 1 | Using Pandas Backends 2 | ========================== 3 | 4 | 5 | RecNN supports different pandas backends for faster data loading/processing, in and out of core. 6 | 7 | 8 | Pandas is the default backend:: 9 | 10 | # but you can also set it directly: 11 | recnn.pd.set("pandas") 12 | frame_size = 10 13 | batch_size = 25 14 | dirs = recnn.data.env.DataPath( 15 | base="../../../data/", 16 | embeddings="embeddings/ml20_pca128.pkl", 17 | ratings="ml-20m/ratings.csv", 18 | cache="cache/frame_env.pkl", # cache will generate after you run 19 | use_cache=False # disable for testing purposes 20 | ) 21 | 22 | %%time 23 | env = recnn.data.env.FrameEnv(dirs, frame_size, batch_size) 24 | 25 | # Output: 26 | 100%|██████████| 20000263/20000263 [00:13<00:00, 1469488.15it/s] 27 | 100%|██████████| 20000263/20000263 [00:15<00:00, 1265183.17it/s] 28 | 100%|██████████| 138493/138493 [00:06<00:00, 19935.53it/s] 29 | CPU times: user 41.6 s, sys: 1.89 s, total: 43.5 s 30 | Wall time: 43.5 s 31 | 32 | 33 | P.S. Install Modin `here 34 | `_ , it is not installed via RecNN's deps. 35 | 36 | You can also use Modin with Dask / Ray.
37 | 38 | Here is a little Ray example:: 39 | 40 | import os 41 | import ray 42 | 43 | if ray.is_initialized(): 44 | ray.shutdown() 45 | os.environ["MODIN_ENGINE"] = "ray" # Modin will use Ray 46 | ray.init(num_cpus=10) # adjust to your liking 47 | recnn.pd.set("modin") 48 | %%time 49 | env = recnn.data.env.FrameEnv(dirs, frame_size, batch_size) 50 | 51 | 100%|██████████| 138493/138493 [00:07<00:00, 18503.97it/s] 52 | CPU times: user 12 s, sys: 2.06 s, total: 14 s 53 | Wall time: 21.4 s 54 | 55 | Using Dask:: 56 | 57 | ### dask 58 | import os 59 | os.environ["MODIN_ENGINE"] = "dask" # Modin will use Dask 60 | recnn.pd.set("modin") 61 | %%time 62 | env = recnn.data.env.FrameEnv(dirs, frame_size, batch_size) 63 | 64 | 100%|██████████| 138493/138493 [00:06<00:00, 19785.99it/s] 65 | CPU times: user 14.2 s, sys: 2.13 s, total: 16.3 s 66 | Wall time: 22 s 67 | 68 | 69 | 70 | **Free 2x improvement in loading speed** 71 | -------------------------------------------------------------------------------- /docs/source/examples/your_data.rst: -------------------------------------------------------------------------------- 1 | Working with your own data 2 | ========================== 3 | 4 | 5 | Colab Version Here (clickable): 6 | 7 | .. image:: https://colab.research.google.com/assets/colab-badge.svg 8 | :target: https://colab.research.google.com/drive/1xWX4JQvlcx3mizwL4gB0THEyxw6LsXTL 9 | 10 | **Some things to know beforehand:** 11 | 12 | When you load and preprocess data, all of the additional data preprocessing happens in the 'prepare_dataset' 13 | function that you should pass. An example of that is in the 'your own data' notebook. Also, if you have inconsistent 14 | indices (e.g. the movie index in MovieLens looks like [1, 3, 10, 20]), recnn handles it on its own, reducing 15 | memory usage. There is no need to worry about mixing up indices while preprocessing your own data. 16 | 17 | Here is how the default ML20M dataset is processed. Use this as a reference:: 18 | 19 | def prepare_dataset(args_mut: DataFuncArgsMut, kwargs: DataFuncKwargs): 20 | # get args 21 | frame_size = kwargs.get('frame_size') 22 | key_to_id = args_mut.base.key_to_id 23 | df = args_mut.df 24 | 25 | # rating range mapped from [0, 5] to [-5, 5] 26 | df['rating'] = try_progress_apply(df['rating'], lambda i: 2 * (i - 2.5)) 27 | # id's tend to be inconsistent and sparse so they are remapped here 28 | df['movieId'] = try_progress_apply(df['movieId'], lambda i: key_to_id.get(i)) 29 | users = df[['userId', 'movieId']].groupby(['userId']).size() 30 | users = users[users > frame_size].sort_values(ascending=False).index 31 | 32 | if pd.get_type() == "modin": 33 | df = df._to_pandas() # pandas groupby is sync and doesn't affect performance 34 | ratings = df.sort_values(by='timestamp').set_index('userId').drop('timestamp', axis=1).groupby('userId') 35 | 36 | # Groupby user 37 | user_dict = {} 38 | 39 | def app(x): 40 | userid = int(x.index[0]) 41 | user_dict[userid] = {} 42 | user_dict[userid]['items'] = x['movieId'].values 43 | user_dict[userid]['ratings'] = x['rating'].values 44 | 45 | try_progress_apply(ratings, app) 46 | 47 | args_mut.user_dict = user_dict 48 | args_mut.users = users 49 | 50 | return args_mut, kwargs 51 | 52 | Look in reference/data/dataset_functions for further details.
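If you need several preprocessing steps, you can chain them with build_data_pipeline, the same way the reinforce example in the Data reference does it. Below is a minimal sketch of that pattern, assuming ``num_items`` is a count you choose and the function name is just a placeholder::

    def prepare_truncated_dataset(args_mut, kwargs):
        # kwargs set here are visible to every function in the pipeline
        kwargs.set('reduce_items_to', num_items)
        # run truncation first, then the default preparation
        pipeline = [recnn.data.truncate_dataset, recnn.data.prepare_dataset]
        recnn.data.build_data_pipeline(pipeline, kwargs, args_mut)
        return args_mut, kwargs

You would then pass it to FrameEnv via ``prepare_dataset=prepare_truncated_dataset``, just like the custom function in the sections below.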
53 | 54 | Toy Dataset 55 | +++++++++++ 56 | 57 | The code below generates an artificial dataset:: 58 | 59 | import pandas as pd 60 | import numpy as np 61 | import datetime 62 | import random 63 | import time 64 | 65 | def random_string_date(): 66 | return datetime.datetime.strptime('{} {} {} {}'.format(random.randint(1, 366), 67 | random.randint(0, 23), 68 | random.randint(1, 59), 69 | 2019), '%j %H %M %Y').strftime("%m/%d/%Y, %H:%M:%S") 70 | 71 | def string_time_to_unix(s): 72 | return int(time.mktime(datetime.datetime.strptime(s, "%m/%d/%Y, %H:%M:%S").timetuple())) 73 | 74 | size = 100000 75 | n_emb = 1000 76 | n_usr = 1000 77 | mydf = pd.DataFrame({'book_id': np.random.randint(0, n_emb, size=size), 78 | 'reader_id': np.random.randint(1, n_usr, size=size), 79 | 'liked': np.random.randint(0, 2, size=size), 80 | 'when': [random_string_date() for i in range(size)]}) 81 | my_embeddings = dict([(i, torch.tensor(np.random.randn(128)).float()) for i in range(n_emb)]) 82 | mydf.head() 83 | 84 | # output: 85 | book_id reader_id liked when 86 | 0 919 130 0 06/16/2019, 11:54:00 87 | 1 850 814 1 11/29/2019, 12:35:00 88 | 2 733 553 0 07/07/2019, 05:45:00 89 | 3 902 695 1 02/03/2019, 10:29:00 90 | 4 960 993 1 05/29/2019, 01:35:00 91 | 92 | # saving the data 93 | ! mkdir mydataset 94 | import pickle 95 | 96 | mydf.to_csv('mydataset/mydf.csv', index=False) 97 | with open('mydataset/myembeddings.pickle', 'wb') as handle: 98 | pickle.dump(my_embeddings, handle) 99 | 100 | 101 | Writing custom preprocessing function 102 | +++++++++++++++++++++++++++++++++++++ 103 | 104 | The following is a copy of the preprocessing function listed above to work with the toy dataset:: 105 | 106 | def prepare_my_dataset(args_mut, kwargs): 107 | 108 | # get args 109 | frame_size = kwargs.get('frame_size') 110 | key_to_id = args_mut.base.key_to_id 111 | df = args_mut.df 112 | 113 | df['liked'] = df['liked'].apply(lambda a: (a - 1) * (1 - a) + a) 114 | df['when'] = df['when'].apply(string_time_to_unix) 115 | df['book_id'] = df['book_id'].apply(key_to_id.get) 116 | 117 | users = df[['reader_id', 'book_id']].groupby(['reader_id']).size() 118 | users = users[users > frame_size].sort_values(ascending=False).index 119 | 120 | # If using modin: pandas groupby is sync and doesnt affect performance 121 | # if pd.get_type() == "modin": df = df._to_pandas() 122 | ratings = df.sort_values(by='when').set_index('reader_id').drop('when', axis=1).groupby('reader_id') 123 | 124 | # Groupby user 125 | user_dict = {} 126 | 127 | def app(x): 128 | userid = x.index[0] 129 | user_dict[int(userid)] = {} 130 | user_dict[int(userid)]['items'] = x['book_id'].values 131 | user_dict[int(userid)]['ratings'] = x['liked'].values 132 | 133 | ratings.apply(app) 134 | 135 | args_mut.user_dict = user_dict 136 | args_mut.users = users 137 | 138 | return args_mut, kwargs 139 | 140 | 141 | Putting it all together 142 | +++++++++++++++++++++++ 143 | 144 | Final touches:: 145 | 146 | frame_size = 10 147 | batch_size = 25 148 | 149 | dirs = recnn.data.env.DataPath( 150 | base="/mydataset", 151 | embeddings="myembeddings.pickle", 152 | ratings="mydf.csv", 153 | cache="cache/frame_env.pkl", # cache will generate after you run 154 | use_cache=True # generally you want to save env after it runs 155 | ) 156 | # pass prepare_my_dataset here 157 | env = recnn.data.env.FrameEnv(dirs, frame_size, batch_size, prepare_dataset=prepare_my_dataset) 158 | 159 | # test function 160 | def run_tests(): 161 | batch = next(iter(env.test_dataloader)) 162 | loss = ddpg.update(batch, 
learn=False) 163 | return loss 164 | 165 | value_net = recnn.nn.Critic(1290, 128, 256, 54e-2) 166 | policy_net = recnn.nn.Actor(1290, 128, 256, 6e-1) 167 | 168 | cuda = torch.device('cuda') 169 | ddpg = recnn.nn.DDPG(policy_net, value_net) 170 | ddpg = ddpg.to(cuda) 171 | plotter = recnn.utils.Plotter(ddpg.loss_layout, [['value', 'policy']],) 172 | 173 | from IPython.display import clear_output 174 | import matplotlib.pyplot as plt 175 | %matplotlib inline 176 | 177 | plot_every = 3 178 | n_epochs = 2 179 | 180 | def learn(): 181 | for epoch in range(n_epochs): 182 | for batch in tqdm(env.train_dataloader): 183 | loss = ddpg.update(batch, learn=True) 184 | plotter.log_losses(loss) 185 | ddpg.step() 186 | if ddpg._step % plot_every == 0: 187 | clear_output(True) 188 | print('step', ddpg._step) 189 | test_loss = run_tests() 190 | plotter.log_losses(test_loss, test=True) 191 | plotter.plot_loss() 192 | if ddpg._step > 100: 193 | return 194 | 195 | learn() 196 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. image:: https://github.com/awarebayes/RecNN/raw/master/res/logo%20big.png 2 | :align: center 3 | 4 | Welcome to recnn's documentation! 5 | ======================================== 6 | 7 | What 8 | ++++ 9 | 10 | This is my school project. It focuses on Reinforcement Learning for personalized 11 | news recommendation. The main distinction is that it tries to solve online off-policy 12 | learning with dynamically generated item embeddings. Also, there is no exploration, 13 | since we are working with a dataset. In the example section, I use Google's BERT on 14 | the ML20M dataset to extract contextual information from the movie description to form 15 | the latent vector representations. Later, you can use the same transformation on new, 16 | previously unseen items (hence, the embeddings are dynamically generated). If you don't 17 | want to bother with embeddings pipeline, I have a DQN embeddings generator as a proof 18 | of concept. 19 | 20 | Getting Started 21 | +++++++++++++++ 22 | 23 | There are a couple of ways you can get started. The most straightforward is to clone and go to the examples section. 24 | You can also use Google Colab or Gradient Experiment. 25 | 26 | 27 | How parameters should look like:: 28 | 29 | import torch 30 | import recnn 31 | 32 | env = recnn.data.env.FrameEnv('ml20_pca128.pkl','ml-20m/ratings.csv') 33 | 34 | value_net = recnn.nn.Critic(1290, 128, 256, 54e-2) 35 | policy_net = recnn.nn.Actor(1290, 128, 256, 6e-1) 36 | 37 | cuda = torch.device('cuda') 38 | ddpg = recnn.nn.DDPG(policy_net, value_net) 39 | ddpg = ddpg.to(cuda) 40 | 41 | for batch in env.train_dataloader: 42 | ddpg.update(batch, learn=True) 43 | 44 | 45 | .. toctree:: 46 | :maxdepth: 3 47 | :caption: Tutorials: 48 | 49 | Tutorials 50 | 51 | 52 | .. toctree:: 53 | :maxdepth: 2 54 | :caption: Reference: 55 | 56 | NN 57 | 58 | Data 59 | 60 | Indices and tables 61 | ================== 62 | 63 | * :ref:`genindex` 64 | * :ref:`modindex` 65 | * :ref:`search` 66 | 67 | -------------------------------------------------------------------------------- /docs/source/nn.rst: -------------------------------------------------------------------------------- 1 | NN 2 | == 3 | 4 | 5 | Models 6 | ------ 7 | 8 | .. automodule:: recnn.nn.models 9 | :members: 10 | 11 | Update 12 | ------ 13 | 14 | .. 
automodule:: recnn.nn.update.ddpg 15 | :members: 16 | :exclude-members: forward 17 | 18 | .. automodule:: recnn.nn.update.td3 19 | :members: 20 | :exclude-members: forward 21 | 22 | .. automodule:: recnn.nn.update.bcq 23 | :members: 24 | :exclude-members: forward 25 | 26 | .. automodule:: recnn.nn.update.misc 27 | :members: 28 | :exclude-members: forward 29 | 30 | Algo 31 | ------ 32 | 33 | .. automodule:: recnn.nn.algo 34 | :members: 35 | -------------------------------------------------------------------------------- /examples/0. Embeddings Generation/Pipelines/ML20M/1. Async Parsing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Movie Parsing\n", 8 | "\n", 9 | "## Disclamer: this code takes a coulple of hours to run. \n", 10 | "## You can download parsed data [here](https://drive.google.com/open?id=1t0LNCbqLjiLkAMFwtP8OIYU-zPUCNAjK)\n", 11 | "\n", 12 | "\n", 13 | "## OMDB\n", 14 | "OMDB is Open Movie Database. Although, it is open, you will need to pay 1 doller to get the key and send up to 100k requests/day. For 5 you get access to the poster API.\n", 15 | "\n", 16 | "http://www.omdbapi.com/" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "import pandas as pd\n", 26 | "import requests\n", 27 | "from tqdm import tqdm_notebook as tqdm\n", 28 | "import json\n", 29 | "\n", 30 | "myOmdbKey = 'your key here' # you need to buy omdb key for 1$ on patreon\n", 31 | "movies = pd.read_csv('../../../../data/ml-20m/links.csv')\n", 32 | "movies['imdbId'] = movies['imdbId'].apply(lambda i: '0' * (8 - len(str(i))) + str(i))\n", 33 | "movies['tmdbId'] = movies['tmdbId'].fillna(-1).astype(int).apply(str)\n", 34 | "movies = movies.set_index('movieId')\n", 35 | "movies = movies.to_dict(orient='index')" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": {}, 41 | "source": [ 42 | "> If failed pops up, run this block again till it's done" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 44, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "# movies = json.load(open(\"../../../../data/parsed/omdb.json\", \"r\") )" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 9, 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "data": { 61 | "application/vnd.jupyter.widget-view+json": { 62 | "model_id": "c1f10c26028c4326b7c7e6e2e43d1894", 63 | "version_major": 2, 64 | "version_minor": 0 65 | }, 66 | "text/plain": [ 67 | "HBox(children=(IntProgress(value=0, max=27278), HTML(value='')))" 68 | ] 69 | }, 70 | "metadata": {}, 71 | "output_type": "display_data" 72 | }, 73 | { 74 | "name": "stdout", 75 | "output_type": "stream", 76 | "text": [ 77 | "\n" 78 | ] 79 | } 80 | ], 81 | "source": [ 82 | "for id in tqdm(movies.keys()):\n", 83 | " imdbId = movies[id]['imdbId']\n", 84 | " if movies[id].get('omdb', False):\n", 85 | " continue\n", 86 | " try:\n", 87 | " movies[id]['omdb'] = requests.get(\"http://www.omdbapi.com/?i=tt{}&apikey={}&plot=full\".format(imdbId,\n", 88 | " myOmdbKey)).json()\n", 89 | " except:\n", 90 | " print(id, imdbId, 'failed')" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 17, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "with open(\"../../../../data/parsed/omdb.json\", \"w\") as write_file:\n", 100 | " json.dump(movies, write_file)" 101 | ] 102 | }, 103 | { 104 
| "cell_type": "code", 105 | "execution_count": 46, 106 | "metadata": {}, 107 | "outputs": [ 108 | { 109 | "data": { 110 | "text/plain": [ 111 | "{'imdbId': '00114709',\n", 112 | " 'tmdbId': '862',\n", 113 | " 'omdb': {'Title': 'Toy Story',\n", 114 | " 'Year': '1995',\n", 115 | " 'Rated': 'G',\n", 116 | " 'Released': '22 Nov 1995',\n", 117 | " 'Runtime': '81 min',\n", 118 | " 'Genre': 'Animation, Adventure, Comedy, Family, Fantasy',\n", 119 | " 'Director': 'John Lasseter',\n", 120 | " 'Writer': 'John Lasseter (original story by), Pete Docter (original story by), Andrew Stanton (original story by), Joe Ranft (original story by), Joss Whedon (screenplay by), Andrew Stanton (screenplay by), Joel Cohen (screenplay by), Alec Sokolow (screenplay by)',\n", 121 | " 'Actors': 'Tom Hanks, Tim Allen, Don Rickles, Jim Varney',\n", 122 | " 'Plot': 'A little boy named Andy loves to be in his room, playing with his toys, especially his doll named \"Woody\". But, what do the toys do when Andy is not with them, they come to life. Woody believes that he has life (as a toy) good. However, he must worry about Andy\\'s family moving, and what Woody does not know is about Andy\\'s birthday party. Woody does not realize that Andy\\'s mother gave him an action figure known as Buzz Lightyear, who does not believe that he is a toy, and quickly becomes Andy\\'s new favorite toy. Woody, who is now consumed with jealousy, tries to get rid of Buzz. Then, both Woody and Buzz are now lost. They must find a way to get back to Andy before he moves without them, but they will have to pass through a ruthless toy killer, Sid Phillips.',\n", 123 | " 'Language': 'English',\n", 124 | " 'Country': 'USA',\n", 125 | " 'Awards': 'Nominated for 3 Oscars. Another 23 wins & 17 nominations.',\n", 126 | " 'Poster': 'https://m.media-amazon.com/images/M/MV5BMDU2ZWJlMjktMTRhMy00ZTA5LWEzNDgtYmNmZTEwZTViZWJkXkEyXkFqcGdeQXVyNDQ2OTk4MzI@._V1_SX300.jpg',\n", 127 | " 'Ratings': [{'Source': 'Internet Movie Database', 'Value': '8.3/10'},\n", 128 | " {'Source': 'Rotten Tomatoes', 'Value': '100%'},\n", 129 | " {'Source': 'Metacritic', 'Value': '95/100'}],\n", 130 | " 'Metascore': '95',\n", 131 | " 'imdbRating': '8.3',\n", 132 | " 'imdbVotes': '810,875',\n", 133 | " 'imdbID': 'tt0114709',\n", 134 | " 'Type': 'movie',\n", 135 | " 'DVD': '20 Mar 2001',\n", 136 | " 'BoxOffice': 'N/A',\n", 137 | " 'Production': 'Buena Vista',\n", 138 | " 'Website': 'http://www.disney.com/ToyStory',\n", 139 | " 'Response': 'True'}}" 140 | ] 141 | }, 142 | "execution_count": 46, 143 | "metadata": {}, 144 | "output_type": "execute_result" 145 | } 146 | ], 147 | "source": [ 148 | "movies['1']" 149 | ] 150 | }, 151 | { 152 | "cell_type": "markdown", 153 | "metadata": {}, 154 | "source": [ 155 | "## TMDB" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 25, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "import pandas as pd\n", 165 | "import requests\n", 166 | "from tqdm import tqdm_notebook as tqdm\n", 167 | "import json\n", 168 | "\n", 169 | "myTmdbKey = 'your key here' # you can get it for free if you ask them nicely\n", 170 | "movies = pd.read_csv('../../../../data/ml-20m/links.csv')\n", 171 | "movies['imdbId'] = movies['imdbId'].apply(lambda i: '0' * (8 - len(str(i))) + str(i))\n", 172 | "movies['tmdbId'] = movies['tmdbId'].fillna(-1).astype(int).apply(str)\n", 173 | "movies = movies.set_index('movieId')\n", 174 | "movies = movies.to_dict(orient='index')" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | 
"execution_count": 26, 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [ 183 | "import asyncio\n", 184 | "# ! pip install aiohttp --user\n", 185 | "import aiohttp\n", 186 | "# ! pip install asyncio-throttle --user\n", 187 | "from asyncio_throttle import Throttler" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 27, 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "# movies = json.load(open(\"../../../../data/parsed/tmdb.json\", \"r\") )" 197 | ] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | "metadata": {}, 202 | "source": [ 203 | "> you can also run this code multiple times" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 28, 209 | "metadata": {}, 210 | "outputs": [ 211 | { 212 | "data": { 213 | "application/vnd.jupyter.widget-view+json": { 214 | "model_id": "aaacbda54580430396f4ae20b114f146", 215 | "version_major": 2, 216 | "version_minor": 0 217 | }, 218 | "text/plain": [ 219 | "HBox(children=(IntProgress(value=0, max=27278), HTML(value='')))" 220 | ] 221 | }, 222 | "metadata": {}, 223 | "output_type": "display_data" 224 | }, 225 | { 226 | "name": "stdout", 227 | "output_type": "stream", 228 | "text": [ 229 | "\n" 230 | ] 231 | } 232 | ], 233 | "source": [ 234 | "throttler = Throttler(rate_limit=4, period=2)\n", 235 | "\n", 236 | "async def tmdb(session, id, tmdbId):\n", 237 | " url = \"https://api.themoviedb.org/3/movie/{}?api_key={}\".format(tmdbId, myTmdbKey)\n", 238 | " async with throttler:\n", 239 | " async with session.get(url) as resp:\n", 240 | " if resp.status == 429:\n", 241 | " print('throttling')\n", 242 | " await asyncio.sleep(0.2)\n", 243 | " \n", 244 | " movies[id]['tmdb'] = await resp.json()\n", 245 | " \n", 246 | " # this also controlls the timespan between calls\n", 247 | " await asyncio.sleep(0.05)\n", 248 | " \n", 249 | "\n", 250 | "async def main():\n", 251 | " async with aiohttp.ClientSession() as session:\n", 252 | " for id in tqdm(movies.keys()):\n", 253 | " tmdbId = movies[id]['tmdbId']\n", 254 | " if movies[id].get('tmdb', False) and 'status_code' not in movies[id]['tmdb']:\n", 255 | " continue\n", 256 | " await tmdb(session, id, tmdbId)\n", 257 | " \n", 258 | " \n", 259 | "if __name__ == '__main__':\n", 260 | " loop = asyncio.get_event_loop()\n", 261 | " loop.create_task(main())" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": 31, 267 | "metadata": {}, 268 | "outputs": [], 269 | "source": [ 270 | "with open(\"../../../../data/parsed/tmdb.json\", \"w\") as write_file:\n", 271 | " json.dump(movies, write_file)" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": 43, 277 | "metadata": {}, 278 | "outputs": [ 279 | { 280 | "data": { 281 | "text/plain": [ 282 | "{'imdbId': '00114709',\n", 283 | " 'tmdbId': '862',\n", 284 | " 'tmdb': {'adult': False,\n", 285 | " 'backdrop_path': '/dji4Fm0gCDVb9DQQMRvAI8YNnTz.jpg',\n", 286 | " 'belongs_to_collection': {'id': 10194,\n", 287 | " 'name': 'Toy Story Collection',\n", 288 | " 'poster_path': '/7G9915LfUQ2lVfwMEEhDsn3kT4B.jpg',\n", 289 | " 'backdrop_path': '/9FBwqcd9IRruEDUrTdcaafOMKUq.jpg'},\n", 290 | " 'budget': 30000000,\n", 291 | " 'genres': [{'id': 16, 'name': 'Animation'},\n", 292 | " {'id': 35, 'name': 'Comedy'},\n", 293 | " {'id': 10751, 'name': 'Family'}],\n", 294 | " 'homepage': 'http://toystory.disney.com/toy-story',\n", 295 | " 'id': 862,\n", 296 | " 'imdb_id': 'tt0114709',\n", 297 | " 'original_language': 'en',\n", 298 | " 'original_title': 'Toy Story',\n", 299 | " 
'overview': \"Led by Woody, Andy's toys live happily in his room until Andy's birthday brings Buzz Lightyear onto the scene. Afraid of losing his place in Andy's heart, Woody plots against Buzz. But when circumstances separate Buzz and Woody from their owner, the duo eventually learns to put aside their differences.\",\n", 300 | " 'popularity': 29.3,\n", 301 | " 'poster_path': '/rhIRbceoE9lR4veEXuwCC2wARtG.jpg',\n", 302 | " 'production_companies': [{'id': 3,\n", 303 | " 'logo_path': '/1TjvGVDMYsj6JBxOAkUHpPEwLf7.png',\n", 304 | " 'name': 'Pixar',\n", 305 | " 'origin_country': 'US'}],\n", 306 | " 'production_countries': [{'iso_3166_1': 'US',\n", 307 | " 'name': 'United States of America'}],\n", 308 | " 'release_date': '1995-10-30',\n", 309 | " 'revenue': 373554033,\n", 310 | " 'runtime': 81,\n", 311 | " 'spoken_languages': [{'iso_639_1': 'en', 'name': 'English'}],\n", 312 | " 'status': 'Released',\n", 313 | " 'tagline': '',\n", 314 | " 'title': 'Toy Story',\n", 315 | " 'video': False,\n", 316 | " 'vote_average': 7.9,\n", 317 | " 'vote_count': 10896}}" 318 | ] 319 | }, 320 | "execution_count": 43, 321 | "metadata": {}, 322 | "output_type": "execute_result" 323 | } 324 | ], 325 | "source": [ 326 | "movies['1']" 327 | ] 328 | }, 329 | { 330 | "cell_type": "code", 331 | "execution_count": null, 332 | "metadata": {}, 333 | "outputs": [], 334 | "source": [] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": null, 339 | "metadata": {}, 340 | "outputs": [], 341 | "source": [] 342 | } 343 | ], 344 | "metadata": { 345 | "kernelspec": { 346 | "display_name": "Python 3", 347 | "language": "python", 348 | "name": "python3" 349 | }, 350 | "language_info": { 351 | "codemirror_mode": { 352 | "name": "ipython", 353 | "version": 3 354 | }, 355 | "file_extension": ".py", 356 | "mimetype": "text/x-python", 357 | "name": "python", 358 | "nbconvert_exporter": "python", 359 | "pygments_lexer": "ipython3", 360 | "version": "3.7.3" 361 | } 362 | }, 363 | "nbformat": 4, 364 | "nbformat_minor": 2 365 | } 366 | -------------------------------------------------------------------------------- /examples/0. Embeddings Generation/Pipelines/ML20M/2. 
NLP.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# NLP with RoBERTa.\n", 8 | "\n", 9 | "Yeah, I am somewhat of an NLP engineer myself" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import torch\n", 19 | "import json\n", 20 | "import torch.nn.functional as F\n", 21 | "import pandas as pd\n", 22 | "from fairseq.data.data_utils import collate_tokens\n", 23 | "from tqdm.auto import tqdm\n", 24 | "import numpy as np" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 2, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "omdb = json.load(open(\"../../../../data/parsed/omdb.json\", \"r\") )\n", 34 | "tmdb = json.load(open(\"../../../../data/parsed/tmdb.json\", \"r\") )" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 3, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "batch_size = 4\n", 44 | "cuda = torch.device('cuda')" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 4, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "plots = []\n", 54 | "for i in tmdb.keys():\n", 55 | " omdb_plot = omdb[i]['omdb'].get('Plot', '')\n", 56 | " tmdb_plot = tmdb[i]['tmdb'].get('overview', '')\n", 57 | " plot = tmdb_plot + ' ' + omdb_plot\n", 58 | " plots.append((i, plot, len(plot)))\n", 59 | " \n", 60 | "plots = list(sorted(plots, key=lambda x: x[2]))\n", 61 | "plots = list(filter(lambda x: x[2] > 4, plots))\n", 62 | "\n", 63 | "def chunks(l, n):\n", 64 | " for i in range(0, len(l), n):\n", 65 | " yield l[i:i + n]\n", 66 | "\n", 67 | "ids = [i[0] for i in plots]\n", 68 | "plots = [i[1] for i in plots]\n", 69 | "plots = list(chunks(plots, batch_size))\n", 70 | "ids = list(chunks(ids, batch_size))" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 5, 76 | "metadata": {}, 77 | "outputs": [ 78 | { 79 | "name": "stderr", 80 | "output_type": "stream", 81 | "text": [ 82 | "Using cache found in /home/dev/.cache/torch/hub/pytorch_fairseq_master\n" 83 | ] 84 | }, 85 | { 86 | "name": "stdout", 87 | "output_type": "stream", 88 | "text": [ 89 | "loading archive file http://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz from cache at /home/dev/.cache/torch/pytorch_fairseq/37d2bc14cf6332d61ed5abeb579948e6054e46cc724c7d23426382d11a31b2d6.ae5852b4abc6bf762e0b6b30f19e741aa05562471e9eb8f4a6ae261f04f9b350\n", 90 | "| dictionary: 50264 types\n", 91 | "\n" 92 | ] 93 | } 94 | ], 95 | "source": [ 96 | "roberta = torch.hub.load('pytorch/fairseq', 'roberta.base').to(cuda)\n", 97 | "roberta.eval()\n", 98 | "print()" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 6, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "fs = {}\n", 108 | "\n", 109 | "def extract_features(batch, ids):\n", 110 | " batch = collate_tokens([roberta.encode(sent) for sent in batch], pad_idx=1).to(cuda)\n", 111 | " batch = batch[:, :512]\n", 112 | " features = roberta.extract_features(batch)\n", 113 | " pooled_features = F.avg_pool2d(features, (features.size(1), 1)).squeeze()\n", 114 | " for i in range(pooled_features.size(0)):\n", 115 | " fs[ids[i]] = pooled_features[i].detach().cpu().numpy()" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 7, 121 | "metadata": {}, 122 | "outputs": [ 123 | { 124 | "data": { 125 | 
"application/vnd.jupyter.widget-view+json": { 126 | "model_id": "51a46152153e49e6ab7c6744feb97427", 127 | "version_major": 2, 128 | "version_minor": 0 129 | }, 130 | "text/plain": [ 131 | "HBox(children=(IntProgress(value=0, max=6779), HTML(value='')))" 132 | ] 133 | }, 134 | "metadata": {}, 135 | "output_type": "display_data" 136 | }, 137 | { 138 | "name": "stdout", 139 | "output_type": "stream", 140 | "text": [ 141 | "\n" 142 | ] 143 | } 144 | ], 145 | "source": [ 146 | "for batch, ids in tqdm(zip(plots[::-1], ids[::-1]), total=len(plots)):\n", 147 | " extract_features(batch, ids)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 15, 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "transformed = pd.DataFrame(fs).T\n", 157 | "transformed.index = transformed.index.astype(int)\n", 158 | "transformed = transformed.sort_index()" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 16, 164 | "metadata": {}, 165 | "outputs": [ 166 | { 167 | "data": { 168 | "text/html": [ 169 | "
\n", 170 | "\n", 183 | "\n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | "
0123456789...758759760761762763764765766767
1-0.0055990.1384940.047051-0.0999810.2082670.163597-0.0502470.0353690.021860-0.001333...-0.0534770.014401-0.035731-0.0686120.1469320.106177-0.128289-0.2316060.047912-0.046285
2-0.0289360.0537340.066000-0.1307390.1975910.014505-0.0017840.0911640.036338-0.002871...-0.0584950.049999-0.049668-0.0378010.0880530.142559-0.166629-0.0814390.034168-0.023142
30.0239510.0820140.041002-0.0583340.1885240.0992000.0092920.0442680.0514450.032975...-0.031117-0.017112-0.016568-0.0092610.0706780.122078-0.029504-0.0450540.1142560.064617
40.0284170.1694140.063841-0.0369330.1143280.0820390.0174220.084967-0.0016090.048082...-0.081082-0.0446950.1646800.0292100.0155970.0805080.006273-0.1553800.0397710.049289
50.0114590.1311490.039703-0.0374070.2890720.121404-0.046844-0.013482-0.1030100.039538...-0.0756060.0075510.031218-0.0005650.1133640.0927640.033090-0.2854670.0503610.061391
\n", 333 | "

5 rows × 768 columns

\n", 334 | "
" 335 | ], 336 | "text/plain": [ 337 | " 0 1 2 3 4 5 6 \\\n", 338 | "1 -0.005599 0.138494 0.047051 -0.099981 0.208267 0.163597 -0.050247 \n", 339 | "2 -0.028936 0.053734 0.066000 -0.130739 0.197591 0.014505 -0.001784 \n", 340 | "3 0.023951 0.082014 0.041002 -0.058334 0.188524 0.099200 0.009292 \n", 341 | "4 0.028417 0.169414 0.063841 -0.036933 0.114328 0.082039 0.017422 \n", 342 | "5 0.011459 0.131149 0.039703 -0.037407 0.289072 0.121404 -0.046844 \n", 343 | "\n", 344 | " 7 8 9 ... 758 759 760 761 \\\n", 345 | "1 0.035369 0.021860 -0.001333 ... -0.053477 0.014401 -0.035731 -0.068612 \n", 346 | "2 0.091164 0.036338 -0.002871 ... -0.058495 0.049999 -0.049668 -0.037801 \n", 347 | "3 0.044268 0.051445 0.032975 ... -0.031117 -0.017112 -0.016568 -0.009261 \n", 348 | "4 0.084967 -0.001609 0.048082 ... -0.081082 -0.044695 0.164680 0.029210 \n", 349 | "5 -0.013482 -0.103010 0.039538 ... -0.075606 0.007551 0.031218 -0.000565 \n", 350 | "\n", 351 | " 762 763 764 765 766 767 \n", 352 | "1 0.146932 0.106177 -0.128289 -0.231606 0.047912 -0.046285 \n", 353 | "2 0.088053 0.142559 -0.166629 -0.081439 0.034168 -0.023142 \n", 354 | "3 0.070678 0.122078 -0.029504 -0.045054 0.114256 0.064617 \n", 355 | "4 0.015597 0.080508 0.006273 -0.155380 0.039771 0.049289 \n", 356 | "5 0.113364 0.092764 0.033090 -0.285467 0.050361 0.061391 \n", 357 | "\n", 358 | "[5 rows x 768 columns]" 359 | ] 360 | }, 361 | "execution_count": 16, 362 | "metadata": {}, 363 | "output_type": "execute_result" 364 | } 365 | ], 366 | "source": [ 367 | "transformed.head()" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": 18, 373 | "metadata": {}, 374 | "outputs": [], 375 | "source": [ 376 | "transformed.to_csv('../../../../data/engineering/roberta.csv', index=True, index_label='idx')" 377 | ] 378 | }, 379 | { 380 | "cell_type": "code", 381 | "execution_count": null, 382 | "metadata": {}, 383 | "outputs": [], 384 | "source": [] 385 | } 386 | ], 387 | "metadata": { 388 | "kernelspec": { 389 | "display_name": "Python 3", 390 | "language": "python", 391 | "name": "python3" 392 | }, 393 | "language_info": { 394 | "codemirror_mode": { 395 | "name": "ipython", 396 | "version": 3 397 | }, 398 | "file_extension": ".py", 399 | "mimetype": "text/x-python", 400 | "name": "python", 401 | "nbconvert_exporter": "python", 402 | "pygments_lexer": "ipython3", 403 | "version": "3.7.3" 404 | } 405 | }, 406 | "nbformat": 4, 407 | "nbformat_minor": 2 408 | } 409 | -------------------------------------------------------------------------------- /examples/0. Embeddings Generation/Pipelines/ML20M/5. 
The Big Merge.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# The Big Merge\n", 8 | "\n", 9 | "Other methods will be added soon" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import json\n", 19 | "import pandas as pd" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "roberta = pd.read_csv('../../../../data/engineering/roberta.csv')\n", 29 | "cat = pd.read_csv('../../../../data/engineering/mca.csv')\n", 30 | "num = pd.read_csv('../../../../data/engineering/pca.csv')\n", 31 | "num = num.set_index('idx')\n", 32 | "cat = cat.set_index(cat.columns[0])\n", 33 | "roberta = roberta.set_index('idx')" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "movies = pd.read_csv('../../../../data/ml-20m/links.csv')" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 4, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "df = pd.concat([roberta, cat, num], axis=1)" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 5, 57 | "metadata": { 58 | "scrolled": true 59 | }, 60 | "outputs": [ 61 | { 62 | "name": "stderr", 63 | "output_type": "stream", 64 | "text": [ 65 | "/home/dev/.local/lib/python3.7/site-packages/ppca/_ppca.py:82: RuntimeWarning: divide by zero encountered in log\n", 66 | " det = np.log(np.linalg.det(Sx))\n" 67 | ] 68 | }, 69 | { 70 | "name": "stdout", 71 | "output_type": "stream", 72 | "text": [ 73 | "1.0\n", 74 | "0.3197493112573442\n", 75 | "0.08642294395895767\n", 76 | "0.011613807065083748\n", 77 | "0.006629248893893269\n", 78 | "0.011409862478935606\n", 79 | "0.8079228072819928\n", 80 | "0.06312485654055533\n", 81 | "0.05411146255289068\n", 82 | "0.04663160316817749\n", 83 | "0.04029385825484155\n", 84 | "0.03489849054666294\n", 85 | "0.03031603746989564\n", 86 | "0.02643809621940929\n", 87 | "0.023162995698678968\n", 88 | "0.0203959192899823\n", 89 | "0.018052295018768483\n", 90 | "0.016059788417409626\n", 91 | "0.014358180687624955\n", 92 | "0.012897967368266539\n", 93 | "0.011638581618058197\n", 94 | "0.010546746675690777\n", 95 | "0.009595126661316566\n", 96 | "0.008761272289783184\n", 97 | "0.008026793373915764\n", 98 | "0.007376680908001365\n", 99 | "0.006798720521605572\n", 100 | "0.0062829681706675355\n", 101 | "0.005821283576526337\n", 102 | "0.005406929219866408\n", 103 | "0.005034242882371531\n", 104 | "0.0046983850262487525\n", 105 | "0.004395154395472112\n", 106 | "0.0041208596504158646\n", 107 | "0.003872232710471879\n", 108 | "0.003646370274946076\n", 109 | "0.0034406925665739774\n", 110 | "0.003252911535067904\n", 111 | "0.003081003730827092\n", 112 | "0.0029231853657245566\n", 113 | "0.0027778886027165495\n", 114 | "0.0026437389509543774\n", 115 | "0.002519533975680055\n", 116 | "0.0024042235610552964\n", 117 | "0.002296891856018224\n", 118 | "0.002196740901195815\n", 119 | "0.002103075838896906\n", 120 | "0.0020152915670275107\n", 121 | "0.0019328607029456268\n", 122 | "0.0018553227544089168\n", 123 | "0.001782274431118891\n", 124 | "0.0017133610573070168\n", 125 | "0.0016482690561274715\n", 126 | "0.0015867194714918043\n", 127 | "0.0015284624761753296\n", 128 | "0.0014732727944593016\n", 129 | 
"0.0014209459480452047\n", 130 | "0.0013712952201179185\n", 131 | "0.001324149226606064\n", 132 | "0.0012793499847434386\n", 133 | "0.0012367513771092131\n", 134 | "0.0011962179212254842\n", 135 | "0.0011576237692405567\n", 136 | "0.001120851876855733\n", 137 | "0.0010857932949968063\n", 138 | "0.0010523465494305384\n", 139 | "0.001020417084160119\n", 140 | "0.0009899167520883712\n", 141 | "0.0009607633424089101\n", 142 | "0.0009328801386421226\n", 143 | "0.0009061955037479308\n", 144 | "0.0008806424908571753\n", 145 | "0.000856158479084268\n", 146 | "0.0008326848342841142\n", 147 | "0.0008101665949125092\n", 148 | "0.000788552183035085\n", 149 | "0.0007677931401675053\n", 150 | "0.0007478438876611371\n", 151 | "0.0007286615106072425\n", 152 | "0.0007102055644578886\n", 153 | "0.0006924379028039329\n", 154 | "0.0006753225248765649\n", 155 | "0.0006588254409225502\n", 156 | "0.000642914553711682\n", 157 | "0.0006275595539855239\n", 158 | "0.000612731828080193\n", 159 | "0.0005984043756386281\n", 160 | "0.0005845517355467234\n", 161 | "0.0005711499183786994\n", 162 | "0.0005581763436834919\n", 163 | "0.0005456097808138605\n", 164 | "0.0005334302919983713\n", 165 | "0.0005216191767176692\n", 166 | "0.0005101589167735288\n", 167 | "0.0004990331212120225\n", 168 | "0.0004882264711523199\n", 169 | "0.00047772466406392766\n", 170 | "0.00046751435769309957\n", 171 | "0.0004575831136903741\n", 172 | "0.00044791934120325116\n", 173 | "0.0004385122407848385\n", 174 | "0.0004293517489004639\n", 175 | "0.0004204284835234162\n", 176 | "0.00041173369106006774\n", 177 | "0.0004032591951146358\n", 178 | "0.0003949973473198476\n", 179 | "0.00038694098056102355\n", 180 | "0.00037908336481762284\n", 181 | "0.00037141816578634135\n", 182 | "0.0003639394064953727\n", 183 | "0.00035664143180103025\n", 184 | "0.0003495188759925494\n", 185 | "0.00034256663333187554\n", 186 | "0.00033577983148291857\n", 187 | "0.0003291538078078471\n", 188 | "0.00032268408822133665\n", 189 | "0.000316366368663612\n", 190 | "0.0003101964988172501\n", 191 | "0.0003041704680177837\n", 192 | "0.0002982843930814383\n", 193 | "0.000292534507940978\n", 194 | "0.00028691715479323143\n", 195 | "0.0002814287767278767\n", 196 | "0.0002760659114651176\n", 197 | "0.000270825186255097\n", 198 | "0.000265703313633292\n", 199 | "0.0002606970879552861\n", 200 | "0.00025580338255681845\n", 201 | "0.00025101914752978516\n", 202 | "0.00024634140783730274\n", 203 | "0.0002417672618386657\n", 204 | "0.00023729388011872743\n", 205 | "0.0002329185044387394\n", 206 | "0.00022863844695875102\n", 207 | "0.0002244510894975349\n", 208 | "0.00022035388291419267\n", 209 | "0.0002163443465119652\n", 210 | "0.00021242006746824416\n", 211 | "0.000208578700262807\n", 212 | "0.000204817966149351\n", 213 | "0.00020113565245827303\n", 214 | "0.00019752961203156616\n", 215 | "0.0001939977624916267\n", 216 | "0.0001905380855302674\n", 217 | "0.00018714862607027705\n", 218 | "0.00018382749150203104\n", 219 | "0.00018057285068118212\n", 220 | "0.00017738293307356656\n", 221 | "0.00017425602765497317\n", 222 | "0.00017119048191438502\n", 223 | "0.00016818470067314628\n", 224 | "0.00016523714497562736\n", 225 | "0.00016234633080802752\n", 226 | "0.00015951082789844584\n", 227 | "0.0001567292583348756\n", 228 | "0.0001540002952917785\n", 229 | "0.00015132266158923713\n", 230 | "0.0001486951282998472\n", 231 | "0.00014611651331408737\n", 232 | "0.0001435856798353008\n", 233 | "0.00014110153492197242\n", 234 | "0.00013866302796849972\n", 235 | "0.00013626914917508337\n", 236 | 
"0.00013391892803160665\n", 237 | "0.00013161143180129287\n", 238 | "0.00012934576395484676\n", 239 | "0.0001271210626927477\n", 240 | "0.00012493649938116747\n", 241 | "0.00012279127712133686\n", 242 | "0.00012068462918946032\n", 243 | "0.000118615817648271\n", 244 | "0.00011658413186843575\n", 245 | "0.00011458888713344884\n", 246 | "0.00011262942328538195\n", 247 | "0.00011070510334643124\n", 248 | "0.00010881531226547558\n", 249 | "0.00010695945561289832\n", 250 | "0.00010513695836045223\n", 251 | "0.00010334726374328085\n", 252 | "0.00010158983205243999\n", 253 | "9.986413961660112e-05\n" 254 | ] 255 | } 256 | ], 257 | "source": [ 258 | "from ppca import PPCA\n", 259 | "ppca = PPCA()\n", 260 | "ppca.fit(data=df.values.astype(float), d=128, verbose=True)" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": 34, 266 | "metadata": {}, 267 | "outputs": [ 268 | { 269 | "data": { 270 | "text/plain": [ 271 | "array([0.10027779, 0.15507462, 0.19134567, 0.21945611, 0.24460353,\n", 272 | " 0.26845611, 0.29102592, 0.31095034, 0.32996193, 0.34643802,\n", 273 | " 0.3620572 , 0.37633324, 0.38997795, 0.40261482, 0.41483703,\n", 274 | " 0.42643099, 0.4368385 , 0.44636099, 0.45568047, 0.46451587,\n", 275 | " 0.47289856, 0.480961 , 0.48863717, 0.49602041, 0.50310459,\n", 276 | " 0.51003156, 0.51656425, 0.52289624, 0.52908073, 0.53498392,\n", 277 | " 0.54078847, 0.54642035, 0.55186093, 0.55712029, 0.56207776,\n", 278 | " 0.56691434, 0.57167964, 0.57630872, 0.58075463, 0.58516137,\n", 279 | " 0.58940848, 0.59361722, 0.597649 , 0.60165605, 0.60551434,\n", 280 | " 0.60932604, 0.61301356, 0.61662949, 0.62021721, 0.62374785,\n", 281 | " 0.62722111, 0.63066072, 0.63403963, 0.63737077, 0.64063838,\n", 282 | " 0.64385491, 0.64703044, 0.65019842, 0.65329135, 0.65630261,\n", 283 | " 0.65926582, 0.66220377, 0.66510903, 0.66796339, 0.67077644,\n", 284 | " 0.67351453, 0.67620537, 0.67883389, 0.68142576, 0.68400129,\n", 285 | " 0.6865571 , 0.68904796, 0.69152211, 0.69396168, 0.69638304,\n", 286 | " 0.69873471, 0.70108001, 0.70338456, 0.70567939, 0.70792041,\n", 287 | " 0.71015812, 0.71239188, 0.71457543, 0.71672394, 0.71886385,\n", 288 | " 0.72098915, 0.72306577, 0.7251115 , 0.72713712, 0.72914898,\n", 289 | " 0.73113489, 0.73309689, 0.73505694, 0.73698951, 0.73889228,\n", 290 | " 0.74079208, 0.74265585, 0.74450515, 0.74634531, 0.7481664 ,\n", 291 | " 0.74995547, 0.75172597, 0.75347283, 0.75519582, 0.75690472,\n", 292 | " 0.75859582, 0.76027799, 0.7619423 , 0.7635889 , 0.76521979,\n", 293 | " 0.76682305, 0.76841543, 0.77000291, 0.77157517, 0.77313363,\n", 294 | " 0.77467194, 0.77619465, 0.77770459, 0.77920782, 0.78069985,\n", 295 | " 0.78217445, 0.78363201, 0.78507767, 0.78651602, 0.78793485,\n", 296 | " 0.78934381, 0.79074488, 0.79214315])" 297 | ] 298 | }, 299 | "execution_count": 34, 300 | "metadata": {}, 301 | "output_type": "execute_result" 302 | } 303 | ], 304 | "source": [ 305 | "ppca.var_exp" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": 36, 311 | "metadata": {}, 312 | "outputs": [ 313 | { 314 | "data": { 315 | "text/plain": [ 316 | "Int64Index([ 1, 2, 3, 4, 5, 6, 7, 8,\n", 317 | " 9, 10,\n", 318 | " ...\n", 319 | " 131241, 131243, 131248, 131250, 131252, 131254, 131256, 131258,\n", 320 | " 131260, 131262],\n", 321 | " dtype='int64', length=27278)" 322 | ] 323 | }, 324 | "execution_count": 36, 325 | "metadata": {}, 326 | "output_type": "execute_result" 327 | } 328 | ], 329 | "source": [ 330 | "df.index" 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 
335 | "execution_count": 7, 336 | "metadata": {}, 337 | "outputs": [], 338 | "source": [ 339 | "import pickle\n", 340 | "import torch\n", 341 | "transformed = ppca.transform()\n", 342 | "films_dict = dict([(k, torch.tensor(transformed[i]).float()) for k, i in zip(df.index, range(transformed.shape[0]))])" 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": 41, 348 | "metadata": {}, 349 | "outputs": [], 350 | "source": [ 351 | "pickle.dump(films_dict, open('../../../../data/embeddings/ml20_pca128.pkl', 'wb'))" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": null, 357 | "metadata": {}, 358 | "outputs": [], 359 | "source": [] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": null, 364 | "metadata": {}, 365 | "outputs": [], 366 | "source": [] 367 | } 368 | ], 369 | "metadata": { 370 | "kernelspec": { 371 | "display_name": "Python 3", 372 | "language": "python", 373 | "name": "python3" 374 | }, 375 | "language_info": { 376 | "codemirror_mode": { 377 | "name": "ipython", 378 | "version": 3 379 | }, 380 | "file_extension": ".py", 381 | "mimetype": "text/x-python", 382 | "name": "python", 383 | "nbconvert_exporter": "python", 384 | "pygments_lexer": "ipython3", 385 | "version": "3.7.3" 386 | } 387 | }, 388 | "nbformat": 4, 389 | "nbformat_minor": 2 390 | } 391 | -------------------------------------------------------------------------------- /examples/99.To be released, but working/3. Decentralized Recommendation with PySyft/1. PySyft with RecNN Environment .ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Basic PySyft Setup\n", 8 | "\n", 9 | "### RecNN prerequisites " 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import numpy as np\n", 19 | "import pandas as pd\n", 20 | "import torch\n", 21 | "from torch.utils.tensorboard import SummaryWriter\n", 22 | "import syft as sy\n", 23 | "from tqdm.auto import tqdm\n", 24 | "\n", 25 | "from IPython.display import clear_output\n", 26 | "import matplotlib.pyplot as plt\n", 27 | "%matplotlib inline\n", 28 | "\n", 29 | "\n", 30 | "# == recnn ==\n", 31 | "import sys\n", 32 | "sys.path.append(\"../../\")\n", 33 | "import recnn\n", 34 | "\n", 35 | "# you can enable cuda here\n", 36 | "cuda = False\n", 37 | "if cuda:\n", 38 | " cuda = torch.device('cuda')\n", 39 | " torch.set_default_tensor_type('torch.cuda.FloatTensor')\n", 40 | "\n", 41 | "# ---\n", 42 | "frame_size = 10\n", 43 | "batch_size = 25\n", 44 | "n_epochs = 100\n", 45 | "plot_every = 30\n", 46 | "step = 0\n", 47 | "# --- \n", 48 | "\n", 49 | "tqdm.pandas()\n", 50 | "\n", 51 | "from jupyterthemes import jtplot\n", 52 | "jtplot.style(theme='grade3')" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 3, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "def overwrite_batch_tensor_embeddings(batch, item_embeddings_tensor, frame_size):\n", 62 | " \n", 63 | " from recnn.data.utils import get_irsu\n", 64 | " items_t, ratings_t, sizes_t, users_t = get_irsu(batch)\n", 65 | " \n", 66 | " items_emb = item_embeddings_tensor[items_t.long()]\n", 67 | " b_size = ratings_t.size(0)\n", 68 | "\n", 69 | " items = items_emb[:, :-1, :].view(b_size, -1)\n", 70 | " next_items = items_emb[:, 1:, :].view(b_size, -1)\n", 71 | " ratings = ratings_t[:, :-1]\n", 72 | " next_ratings = ratings_t[:, 
1:]\n", 73 | " \n", 74 | " state = torch.cat([items, ratings], 1)\n", 75 | " next_state = torch.cat([next_items, next_ratings], 1)\n", 76 | " action = items_emb[:, -1, :]\n", 77 | " reward = ratings_t[:, -1]\n", 78 | "\n", 79 | " done = torch.zeros(b_size)\n", 80 | " # for some reason syft dies at this line\n", 81 | " # so no done in training. not a big deal\n", 82 | " # done[torch.cumsum(sizes_t - frame_size, dim=0) - 1] = 1\n", 83 | " \n", 84 | " batch = {'state': share(state), 'action': share(action),\n", 85 | " 'reward': share(reward), 'next_state': share(next_state),\n", 86 | " 'done': share(done),\n", 87 | " 'meta': {'users': users_t, 'sizes': sizes_t}}\n", 88 | " return batch\n", 89 | "\n", 90 | "# overwrite batch generation function\n", 91 | "recnn.data.utils.batch_tensor_embeddings = overwrite_batch_tensor_embeddings" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 4, 97 | "metadata": {}, 98 | "outputs": [ 99 | { 100 | "data": { 101 | "application/vnd.jupyter.widget-view+json": { 102 | "model_id": "4158c3975d294f21af6dbd8e2e7cfff2", 103 | "version_major": 2, 104 | "version_minor": 0 105 | }, 106 | "text/plain": [ 107 | "HBox(children=(IntProgress(value=0, max=20000263), HTML(value='')))" 108 | ] 109 | }, 110 | "metadata": {}, 111 | "output_type": "display_data" 112 | }, 113 | { 114 | "name": "stdout", 115 | "output_type": "stream", 116 | "text": [ 117 | "\n" 118 | ] 119 | }, 120 | { 121 | "data": { 122 | "application/vnd.jupyter.widget-view+json": { 123 | "model_id": "d12e989ce1c045a1ae8d85236248d668", 124 | "version_major": 2, 125 | "version_minor": 0 126 | }, 127 | "text/plain": [ 128 | "HBox(children=(IntProgress(value=0, max=20000263), HTML(value='')))" 129 | ] 130 | }, 131 | "metadata": {}, 132 | "output_type": "display_data" 133 | }, 134 | { 135 | "name": "stdout", 136 | "output_type": "stream", 137 | "text": [ 138 | "\n" 139 | ] 140 | }, 141 | { 142 | "data": { 143 | "application/vnd.jupyter.widget-view+json": { 144 | "model_id": "f4babedc28a04872ad9e2cb65bcee89c", 145 | "version_major": 2, 146 | "version_minor": 0 147 | }, 148 | "text/plain": [ 149 | "HBox(children=(IntProgress(value=0, max=138493), HTML(value='')))" 150 | ] 151 | }, 152 | "metadata": {}, 153 | "output_type": "display_data" 154 | }, 155 | { 156 | "name": "stdout", 157 | "output_type": "stream", 158 | "text": [ 159 | "\n" 160 | ] 161 | } 162 | ], 163 | "source": [ 164 | "# embeddgings: https://drive.google.com/open?id=1EQ_zXBR3DKpmJR3jBgLvt-xoOvArGMsL\n", 165 | "env = recnn.data.env.FrameEnv('../../data/embeddings/ml20_pca128.pkl',\n", 166 | " '../../data/ml-20m/ratings.csv', frame_size, batch_size, num_workers=0)" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 5, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "# test function\n", 176 | "def run_tests():\n", 177 | " batch = next(iter(env.test_dataloader))\n", 178 | " loss = ddpg.update(batch, learn=False)\n", 179 | " return loss\n", 180 | "\n", 181 | "\n", 182 | "\n", 183 | "value_net = recnn.nn.Critic(1290, 128, 256, 54e-2)\n", 184 | "policy_net = recnn.nn.Actor(1290, 128, 256, 6e-1)\n", 185 | "\n", 186 | "if cuda:\n", 187 | " torch.set_default_tensor_type('torch.cuda.FloatTensor')\n", 188 | " value_net = recnn.nn.Critic(1290, 128, 256, 54e-2).to(cuda)\n", 189 | " policy_net = recnn.nn.Actor(1290, 128, 256, 6e-1).to(cuda)\n", 190 | " torch.set_default_tensor_type('torch.FloatTensor')\n", 191 | "\n", 192 | "ddpg = recnn.nn.DDPG(policy_net, value_net)\n", 193 | "ddpg.writer = 
SummaryWriter(log_dir='../../runs')\n", 194 | "plotter = recnn.utils.Plotter(ddpg.loss_layout, [['value', 'policy']],)" 195 | ] 196 | }, 197 | { 198 | "cell_type": "markdown", 199 | "metadata": {}, 200 | "source": [ 201 | "## Syft" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 6, 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [ 210 | "hook = sy.TorchHook(torch) \n", 211 | "\n", 212 | "alice = sy.VirtualWorker(id=\"alice\", hook=hook)\n", 213 | "bob = sy.VirtualWorker(id=\"bob\", hook=hook)\n", 214 | "james = sy.VirtualWorker(id=\"james\", hook=hook)" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 7, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "def share_no_cuda(m):\n", 224 | " m = m.fix_precision().share(bob, alice, crypto_provider=james, requires_grad=True)\n", 225 | " return m\n", 226 | "\n", 227 | "def share_cuda(m):\n", 228 | " m = m.fix_precision()\n", 229 | " m = m.to(cuda).share(bob, alice, crypto_provider=james, requires_grad=True)\n", 230 | " return m\n", 231 | "\n", 232 | "def share(m):\n", 233 | " if cuda:\n", 234 | " return share_cuda(m)\n", 235 | " else:\n", 236 | " return share_no_cuda(m)\n", 237 | "\n", 238 | "if not cuda:\n", 239 | " ddpg.nets = dict([(k, share(v)) for k, v in ddpg.nets.items()])\n", 240 | "share_optim = lambda o: o.fix_precision()\n", 241 | "ddpg.optimizers = dict([(k, share_optim(v)) for k, v in ddpg.optimizers.items()])" 242 | ] 243 | }, 244 | { 245 | "cell_type": "markdown", 246 | "metadata": {}, 247 | "source": [ 248 | "## Warning: it may freeze your system here.\n", 249 | "\n", 250 | "On my PC it just freezes it. But google colab seem to withstand the BIG FREEZE. Just wait for 5 minutes and restart, click on stop button and restart the cell." 
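Before running the full encrypted training loop, it can help to sanity-check what `share` does to a single tensor. The sketch below is not part of the notebook; it assumes the same PySyft 0.2-era API and the `bob`/`alice`/`james` virtual workers defined above, and simply round-trips one tensor through fixed-precision encoding, additive secret sharing, an encrypted addition, and decryption.

```python
import torch
import syft as sy

# assumed setup, mirroring the cells above (PySyft 0.2-era API)
hook = sy.TorchHook(torch)
alice = sy.VirtualWorker(id="alice", hook=hook)
bob = sy.VirtualWorker(id="bob", hook=hook)
james = sy.VirtualWorker(id="james", hook=hook)  # crypto provider

x = torch.tensor([0.1, 0.2, 0.3])

# encode floats as fixed-precision integers, then split them into additive
# shares held by bob and alice; james only supplies crypto primitives
x_sh = x.fix_precision().share(bob, alice, crypto_provider=james, requires_grad=True)

y_sh = x_sh + x_sh                 # arithmetic runs on the shares; no worker sees plaintext
y = y_sh.get().float_precision()   # collect the shares and decode back to floats
print(y)                           # ~ tensor([0.2000, 0.4000, 0.6000])
```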
251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 10, 256 | "metadata": { 257 | "scrolled": false 258 | }, 259 | "outputs": [ 260 | { 261 | "data": { 262 | "application/vnd.jupyter.widget-view+json": { 263 | "model_id": "59d02a596bd84002ac590e0cb28007cc", 264 | "version_major": 2, 265 | "version_minor": 0 266 | }, 267 | "text/plain": [ 268 | "HBox(children=(IntProgress(value=0, max=5263), HTML(value='')))" 269 | ] 270 | }, 271 | "metadata": {}, 272 | "output_type": "display_data" 273 | }, 274 | { 275 | "name": "stderr", 276 | "output_type": "stream", 277 | "text": [ 278 | "ERROR:root:Internal Python error in the inspect module.\n", 279 | "Below is the traceback from this internal error.\n", 280 | "\n" 281 | ] 282 | }, 283 | { 284 | "name": "stdout", 285 | "output_type": "stream", 286 | "text": [ 287 | "Traceback (most recent call last):\n", 288 | " File \"/home/dev/anaconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py\", line 3325, in run_code\n", 289 | " exec(code_obj, self.user_global_ns, self.user_ns)\n", 290 | " File \"\", line 19, in \n", 291 | " learn()\n", 292 | " File \"\", line 7, in learn\n", 293 | " loss = ddpg.update(batch, learn=True)\n", 294 | " File \"../../recnn/nn/algo.py\", line 67, in update\n", 295 | " self.device, self.debug, self.writer, step=self._step, learn=learn)\n", 296 | " File \"../../recnn/nn/update.py\", line 60, in ddpg_update\n", 297 | " next_action = nets['target_policy_net'](next_state)\n", 298 | " File \"/home/dev/.local/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 541, in __call__\n", 299 | " result = self.forward(*input, **kwargs)\n", 300 | " File \"../../recnn/nn/models.py\", line 65, in forward\n", 301 | " action = F.relu(self.linear1(state))\n", 302 | " File \"/home/dev/.local/lib/python3.7/site-packages/syft/generic/frameworks/hook/hook.py\", line 449, in overloaded_func\n", 303 | " response = handle_func_command(command)\n", 304 | " File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/native.py\", line 298, in handle_func_command\n", 305 | " response = new_type.handle_func_command(new_command)\n", 306 | " File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/autograd.py\", line 237, in handle_func_command\n", 307 | " return cmd(*args, **kwargs)\n", 308 | " File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/autograd.py\", line 207, in relu\n", 309 | " return tensor.relu()\n", 310 | " File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/autograd.py\", line 141, in method_with_grad\n", 311 | " result = getattr(new_self, name)(*new_args, **new_kwargs)\n", 312 | " File \"/home/dev/.local/lib/python3.7/site-packages/syft/generic/frameworks/hook/hook.py\", line 341, in overloaded_syft_method\n", 313 | " response = getattr(new_self, attr)(*new_args, **new_kwargs)\n", 314 | " File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/additive_shared.py\", line 805, in relu\n", 315 | " return securenn.relu(self)\n", 316 | " File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/crypto/securenn.py\", line 434, in relu\n", 317 | " return a_sh * relu_deriv(a_sh) + u\n", 318 | " File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/crypto/securenn.py\", line 399, in relu_deriv\n", 319 | " y_sh = share_convert(y_sh)\n", 320 | " File 
\"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/crypto/securenn.py\", line 354, in share_convert\n", 321 | " eta_p = private_compare(x_bit_sh, r - 1, eta_pp)\n", 322 | " File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/crypto/securenn.py\", line 172, in private_compare\n", 323 | " c_beta0 = -x_bit_sh + (j * r_bit) + j + wc\n", 324 | " File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/additive_shared.py\", line 362, in __add__\n", 325 | " return self.add(other, **kwargs)\n", 326 | " File \"/home/dev/.local/lib/python3.7/site-packages/syft/generic/frameworks/overload.py\", line 28, in _hook_method_args\n", 327 | " response = attr(self, new_self, *new_args, **new_kwargs)\n", 328 | " File \"/home/dev/.local/lib/python3.7/site-packages/syft/frameworks/torch/tensors/interpreters/additive_shared.py\", line 355, in add\n", 329 | " new_shares[k] = (other[k] + v) % self.field\n", 330 | " File \"/home/dev/.local/lib/python3.7/site-packages/syft/generic/frameworks/hook/hook.py\", line 484, in overloaded_pointer_method\n", 331 | " response = owner.send_command(location, command)\n", 332 | " File \"/home/dev/.local/lib/python3.7/site-packages/syft/workers/base.py\", line 489, in send_command\n", 333 | " ret_val = self.send_msg(Operation(message, return_ids), location=recipient)\n", 334 | " File \"/home/dev/.local/lib/python3.7/site-packages/syft/workers/base.py\", line 258, in send_msg\n", 335 | " bin_response = self._send_msg(bin_message, location)\n", 336 | " File \"/home/dev/.local/lib/python3.7/site-packages/syft/workers/virtual.py\", line 7, in _send_msg\n", 337 | " return location._recv_msg(message)\n", 338 | " File \"/home/dev/.local/lib/python3.7/site-packages/syft/workers/virtual.py\", line 10, in _recv_msg\n", 339 | " return self.recv_msg(message)\n", 340 | " File \"/home/dev/.local/lib/python3.7/site-packages/syft/workers/base.py\", line 292, in recv_msg\n", 341 | " response = self._message_router[msg_type](contents)\n", 342 | " File \"/home/dev/.local/lib/python3.7/site-packages/syft/workers/base.py\", line 412, in execute_command\n", 343 | " response = getattr(_self, command_name)(*args, **kwargs)\n", 344 | " File \"/home/dev/.local/lib/python3.7/site-packages/syft/generic/frameworks/hook/hook.py\", line 381, in overloaded_native_method\n", 345 | " raise route_method_exception(e, self, args, kwargs)\n", 346 | " File \"/home/dev/.local/lib/python3.7/site-packages/syft/generic/frameworks/hook/hook.py\", line 378, in overloaded_native_method\n", 347 | " response = method(*args, **kwargs)\n", 348 | "KeyboardInterrupt\n", 349 | "\n", 350 | "During handling of the above exception, another exception occurred:\n", 351 | "\n", 352 | "Traceback (most recent call last):\n", 353 | " File \"/home/dev/anaconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py\", line 2039, in showtraceback\n", 354 | " stb = value._render_traceback_()\n", 355 | "AttributeError: 'KeyboardInterrupt' object has no attribute '_render_traceback_'\n", 356 | "\n", 357 | "During handling of the above exception, another exception occurred:\n", 358 | "\n", 359 | "Traceback (most recent call last):\n", 360 | " File \"/home/dev/anaconda3/lib/python3.7/site-packages/IPython/core/ultratb.py\", line 1101, in get_records\n", 361 | " return _fixed_getinnerframes(etb, number_of_lines_of_context, tb_offset)\n", 362 | " File \"/home/dev/anaconda3/lib/python3.7/site-packages/IPython/core/ultratb.py\", line 319, in wrapped\n", 363 | " return 
f(*args, **kwargs)\n", 364 | " File \"/home/dev/anaconda3/lib/python3.7/site-packages/IPython/core/ultratb.py\", line 353, in _fixed_getinnerframes\n", 365 | " records = fix_frame_records_filenames(inspect.getinnerframes(etb, context))\n", 366 | " File \"/home/dev/anaconda3/lib/python3.7/inspect.py\", line 1502, in getinnerframes\n", 367 | " frameinfo = (tb.tb_frame,) + getframeinfo(tb, context)\n", 368 | " File \"/home/dev/anaconda3/lib/python3.7/inspect.py\", line 1460, in getframeinfo\n", 369 | " filename = getsourcefile(frame) or getfile(frame)\n", 370 | " File \"/home/dev/anaconda3/lib/python3.7/inspect.py\", line 693, in getsourcefile\n", 371 | " if os.path.exists(filename):\n", 372 | " File \"/home/dev/anaconda3/lib/python3.7/genericpath.py\", line 19, in exists\n", 373 | " os.stat(path)\n", 374 | "KeyboardInterrupt\n" 375 | ] 376 | }, 377 | { 378 | "ename": "KeyboardInterrupt", 379 | "evalue": "", 380 | "output_type": "error", 381 | "traceback": [ 382 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m" 383 | ] 384 | } 385 | ], 386 | "source": [ 387 | "plot_every = 50\n", 388 | "n_epochs = 2\n", 389 | "\n", 390 | "def learn():\n", 391 | " for epoch in range(n_epochs):\n", 392 | " for batch in tqdm(env.train_dataloader):\n", 393 | " loss = ddpg.update(batch, learn=True)\n", 394 | " plotter.log_losses(loss)\n", 395 | " ddpg.step()\n", 396 | " if ddpg._step % plot_every == 0:\n", 397 | " clear_output(True)\n", 398 | " print('step', ddpg._step)\n", 399 | " test_loss = run_tests()\n", 400 | " plotter.log_losses(test_loss, test=True)\n", 401 | " plotter.plot_loss()\n", 402 | " if ddpg._step > 1000:\n", 403 | " return\n", 404 | " \n", 405 | "learn()" 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "execution_count": null, 411 | "metadata": {}, 412 | "outputs": [], 413 | "source": [] 414 | } 415 | ], 416 | "metadata": { 417 | "kernelspec": { 418 | "display_name": "Python 3", 419 | "language": "python", 420 | "name": "python3" 421 | }, 422 | "language_info": { 423 | "codemirror_mode": { 424 | "name": "ipython", 425 | "version": 3 426 | }, 427 | "file_extension": ".py", 428 | "mimetype": "text/x-python", 429 | "name": "python", 430 | "nbconvert_exporter": "python", 431 | "pygments_lexer": "ipython3", 432 | "version": "3.7.3" 433 | } 434 | }, 435 | "nbformat": 4, 436 | "nbformat_minor": 2 437 | } 438 | -------------------------------------------------------------------------------- /examples/99.To be released, but working/4. SearchNet/1. 
DDPG_SN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": [ 6 | "## Deep TopK Search with Critic Adjustment" 7 | ], 8 | "metadata": { 9 | "collapsed": false, 10 | "pycharm": { 11 | "name": "#%% md\n" 12 | } 13 | } 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "outputs": [], 19 | "source": [ 20 | "from abc import ABC\n", 21 | "import torch\n", 22 | "import torch.nn as nn\n", 23 | "from torch.utils.tensorboard import SummaryWriter\n", 24 | "import torch.nn.functional as F\n", 25 | "import torch_optimizer as optim\n", 26 | "\n", 27 | "from tqdm.auto import tqdm\n", 28 | "\n", 29 | "from IPython.display import clear_output\n", 30 | "%matplotlib inline\n", 31 | "\n", 32 | "\n", 33 | "# == recnn ==\n", 34 | "import sys\n", 35 | "sys.path.append(\"../../\")\n", 36 | "import recnn\n", 37 | "\n", 38 | "cuda = torch.device('cuda')\n", 39 | "\n", 40 | "# ---\n", 41 | "frame_size = 10\n", 42 | "batch_size = 25\n", 43 | "n_epochs = 100\n", 44 | "plot_every = 30\n", 45 | "step = 0\n", 46 | "# ---\n", 47 | "\n", 48 | "tqdm.pandas()\n", 49 | "\n", 50 | "from jupyterthemes import jtplot\n", 51 | "jtplot.style(theme='grade3')" 52 | ], 53 | "metadata": { 54 | "collapsed": false, 55 | "pycharm": { 56 | "name": "#%%\n" 57 | } 58 | } 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "outputs": [], 64 | "source": [ 65 | "# embeddgings: https://drive.google.com/open?id=1EQ_zXBR3DKpmJR3jBgLvt-xoOvArGMsL\n", 66 | "dirs = recnn.data.env.DataPath(\n", 67 | " base=\"../../data/\",\n", 68 | " embeddings=\"embeddings/ml20_pca128.pkl\",\n", 69 | " ratings=\"ml-20m/ratings.csv\",\n", 70 | " cache=\"cache/frame_env.pkl\", # cache will generate after you run\n", 71 | " use_cache=True\n", 72 | ")\n", 73 | "env = recnn.data.env.FrameEnv(dirs, frame_size, batch_size)" 74 | ], 75 | "metadata": { 76 | "collapsed": false, 77 | "pycharm": { 78 | "name": "#%%\n" 79 | } 80 | } 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "outputs": [], 86 | "source": [ 87 | "class Actor(nn.Module):\n", 88 | " def __init__(self, input_dim, action_dim, hidden_size, init_w=3e-1):\n", 89 | " super(Actor, self).__init__()\n", 90 | "\n", 91 | " self.drop_layer = nn.Dropout(p=0.5)\n", 92 | "\n", 93 | " self.linear1 = nn.Linear(input_dim, hidden_size)\n", 94 | " self.linear2 = nn.Linear(hidden_size, hidden_size)\n", 95 | " self.linear3 = nn.Linear(hidden_size, action_dim)\n", 96 | "\n", 97 | " self.linear3.weight.data.uniform_(-init_w, init_w)\n", 98 | " self.linear3.bias.data.uniform_(-init_w, init_w)\n", 99 | "\n", 100 | " def forward(self, state):\n", 101 | " # state = self.state_rep(state)\n", 102 | " x = F.relu(self.linear1(state))\n", 103 | " x = self.drop_layer(x)\n", 104 | " x = F.relu(self.linear2(x))\n", 105 | " x = self.drop_layer(x)\n", 106 | " # x = torch.tanh(self.linear3(x)) # in case embeds are -1 1 normalized\n", 107 | " x = self.linear3(x) # in case embeds are standard scaled / wiped using PCA whitening\n", 108 | " # return state, x\n", 109 | " return x" 110 | ], 111 | "metadata": { 112 | "collapsed": false, 113 | "pycharm": { 114 | "name": "#%%\n" 115 | } 116 | } 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "outputs": [], 122 | "source": [ 123 | "class Critic(nn.Module):\n", 124 | " def __init__(self, input_dim, action_dim, hidden_size, init_w=3e-5):\n", 125 | " super(Critic, self).__init__()\n", 126 | "\n", 127 
| " self.drop_layer = nn.Dropout(p=0.5)\n", 128 | "\n", 129 | " self.linear1 = nn.Linear(input_dim + action_dim, hidden_size)\n", 130 | " self.linear2 = nn.Linear(hidden_size, hidden_size)\n", 131 | " self.linear3 = nn.Linear(hidden_size, 1)\n", 132 | "\n", 133 | " self.linear3.weight.data.uniform_(-init_w, init_w)\n", 134 | " self.linear3.bias.data.uniform_(-init_w, init_w)\n", 135 | "\n", 136 | " def forward(self, state, action):\n", 137 | " x = torch.cat([state, action], 1)\n", 138 | " x = F.relu(self.linear1(x))\n", 139 | " x = self.drop_layer(x)\n", 140 | " x = F.relu(self.linear2(x))\n", 141 | " x = self.drop_layer(x)\n", 142 | " x = self.linear3(x)\n", 143 | " return x" 144 | ], 145 | "metadata": { 146 | "collapsed": false, 147 | "pycharm": { 148 | "name": "#%%\n" 149 | } 150 | } 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "outputs": [], 156 | "source": [ 157 | "class SearchK(nn.Module):\n", 158 | " def __init__(self, input_dim, action_dim, hidden_size, topK, init_w=3e-1):\n", 159 | " super(SearchK, self).__init__()\n", 160 | "\n", 161 | " self.drop_layer = nn.Dropout(p=0.5)\n", 162 | " self.linear1 = nn.Linear(input_dim + action_dim, hidden_size)\n", 163 | " self.linear2 = nn.Linear(hidden_size, hidden_size)\n", 164 | " self.linear3 = nn.Linear(hidden_size, action_dim*topK)\n", 165 | "\n", 166 | " self.linear3.weight.data.uniform_(-init_w, init_w)\n", 167 | " self.linear3.bias.data.uniform_(-init_w, init_w)\n", 168 | "\n", 169 | " def forward(self, state, action):\n", 170 | " x = torch.cat([state, action], 1)\n", 171 | " x = F.relu(self.linear1(x))\n", 172 | " x = self.drop_layer(x)\n", 173 | " x = F.relu(self.linear2(x))\n", 174 | " x = self.drop_layer(x)\n", 175 | " x = self.linear3(x)\n", 176 | " return x" 177 | ], 178 | "metadata": { 179 | "collapsed": false, 180 | "pycharm": { 181 | "name": "#%%\n" 182 | } 183 | } 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "outputs": [], 189 | "source": [ 190 | "def soft_update(net, target_net, soft_tau=1e-2):\n", 191 | " for target_param, param in zip(target_net.parameters(), net.parameters()):\n", 192 | " target_param.data.copy_(\n", 193 | " target_param.data * (1.0 - soft_tau) + param.data * soft_tau\n", 194 | " )\n", 195 | "\n", 196 | "def run_tests():\n", 197 | " test_batch = next(iter(env.test_dataloader))\n", 198 | " losses = ddpg_sn_update(test_batch, params, learn=False, step=step)\n", 199 | "\n", 200 | " gen_actions = debug['next_action']\n", 201 | " true_actions = env.base.embeddings.detach().cpu().numpy()\n", 202 | "\n", 203 | " f = plotter.kde_reconstruction_error(ad, gen_actions, true_actions, cuda)\n", 204 | " writer.add_figure('rec_error',f, losses['step'])\n", 205 | " return losses" 206 | ], 207 | "metadata": { 208 | "collapsed": false, 209 | "pycharm": { 210 | "name": "#%%\n" 211 | } 212 | } 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "outputs": [], 218 | "source": [ 219 | "def ddpg_sn_update(batch, params, learn=True, step=-1):\n", 220 | "\n", 221 | " state, action, reward, next_state, done = recnn.data.get_base_batch(batch)\n", 222 | "\n", 223 | " # --------------------------------------------------------#\n", 224 | " # Value Learning\n", 225 | "\n", 226 | " with torch.no_grad():\n", 227 | " next_action = target_policy_net(next_state)\n", 228 | " target_value = target_value_net(next_state, next_action.detach())\n", 229 | " expected_value = reward + (1.0 - done) * params['gamma'] * target_value\n", 230 | " 
expected_value = torch.clamp(expected_value,\n", 231 | " params['min_value'], params['max_value'])\n", 232 | "\n", 233 | " value = value_net(state, action)\n", 234 | "\n", 235 | " value_loss = torch.pow(value - expected_value.detach(), 2).mean()\n", 236 | "\n", 237 | " if learn:\n", 238 | " value_optimizer.zero_grad()\n", 239 | " value_loss.backward()\n", 240 | " value_optimizer.step()\n", 241 | " else:\n", 242 | " debug['next_action'] = next_action\n", 243 | " writer.add_figure('next_action',\n", 244 | " recnn.utils.pairwise_distances_fig(next_action[:50]), step)\n", 245 | " writer.add_histogram('value', value, step)\n", 246 | " writer.add_histogram('target_value', target_value, step)\n", 247 | " writer.add_histogram('expected_value', expected_value, step)\n", 248 | "\n", 249 | " # --------------------------------------------------------#\n", 250 | " # Policy learning\n", 251 | "\n", 252 | " gen_action = policy_net(state)\n", 253 | " policy_loss = -value_net(state, gen_action)\n", 254 | "\n", 255 | " if not learn:\n", 256 | " debug['gen_action'] = gen_action\n", 257 | " writer.add_histogram('policy_loss', policy_loss, step)\n", 258 | " writer.add_figure('next_action',\n", 259 | " recnn.utils.pairwise_distances_fig(gen_action[:50]), step)\n", 260 | "\n", 261 | " policy_loss = policy_loss.mean()\n", 262 | "\n", 263 | " if learn and step % params['policy_step']== 0:\n", 264 | " policy_optimizer.zero_grad()\n", 265 | " policy_loss.backward()\n", 266 | " torch.nn.utils.clip_grad_norm_(policy_net.parameters(), -1, 1)\n", 267 | " policy_optimizer.step()\n", 268 | "\n", 269 | " soft_update(value_net, target_value_net, soft_tau=params['soft_tau'])\n", 270 | " soft_update(policy_net, target_policy_net, soft_tau=params['soft_tau'])\n", 271 | "\n", 272 | " # dont forget search loss here !\n", 273 | " losses = {'value': value_loss.item(), 'policy': policy_loss.item(), 'step': step}\n", 274 | " recnn.utils.write_losses(writer, losses, kind='train' if learn else 'test')\n", 275 | " return losses" 276 | ], 277 | "metadata": { 278 | "collapsed": false, 279 | "pycharm": { 280 | "name": "#%%\n" 281 | } 282 | } 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": null, 287 | "outputs": [], 288 | "source": [ 289 | "# === ddpg settings ===\n", 290 | "\n", 291 | "params = {\n", 292 | " 'gamma' : 0.99,\n", 293 | " 'min_value' : -10,\n", 294 | " 'max_value' : 10,\n", 295 | " 'policy_step': 10,\n", 296 | " 'soft_tau' : 0.001,\n", 297 | "\n", 298 | " 'policy_lr' : 1e-5,\n", 299 | " 'value_lr' : 1e-5,\n", 300 | " 'search_lr' : 1e-5,\n", 301 | " 'actor_weight_init': 54e-2,\n", 302 | " 'search_weight_init': 54e-2,\n", 303 | " 'critic_weight_init': 6e-1,\n", 304 | "}\n", 305 | "\n", 306 | "# === end ===" 307 | ], 308 | "metadata": { 309 | "collapsed": false, 310 | "pycharm": { 311 | "name": "#%%\n" 312 | } 313 | } 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": null, 318 | "outputs": [], 319 | "source": [ 320 | "value_net = Critic(1290, 128, 256, params['critic_weight_init']).to(cuda)\n", 321 | "policy_net = Actor(1290, 128, 256, params['actor_weight_init']).to(cuda)\n", 322 | "search_net = SearchK(1290, 128, 2048, topK=10, init_w=params['search_weight_init']).to(cuda)\n", 323 | "\n", 324 | "target_value_net = Critic(1290, 128, 256).to(cuda)\n", 325 | "target_policy_net = Actor(1290, 128, 256).to(cuda)\n", 326 | "target_search_net = SearchK(1290, 128, 2048, topK=10).to(cuda)\n", 327 | "\n", 328 | "ad = recnn.nn.models.AnomalyDetector().to(cuda)\n", 329 | 
"ad.load_state_dict(torch.load('../../models/anomaly.pt'))\n", 330 | "ad.eval()\n", 331 | "\n", 332 | "target_policy_net.eval()\n", 333 | "target_value_net.eval()\n", 334 | "\n", 335 | "soft_update(value_net, target_value_net, soft_tau=1.0)\n", 336 | "soft_update(policy_net, target_policy_net, soft_tau=1.0)\n", 337 | "soft_update(search_net, target_search_net, soft_tau=1.0)\n", 338 | "\n", 339 | "value_criterion = nn.MSELoss()\n", 340 | "search_criterion = nn.MSELoss()\n", 341 | "\n", 342 | "# from good to bad: Ranger Radam Adam RMSprop\n", 343 | "value_optimizer = optim.Ranger(value_net.parameters(),\n", 344 | " lr=params['value_lr'], weight_decay=1e-2)\n", 345 | "policy_optimizer = optim.Ranger(policy_net.parameters(),\n", 346 | " lr=params['policy_lr'], weight_decay=1e-5)\n", 347 | "search_optimizer = optim.Ranger(search_net.parameters(),\n", 348 | " weight_decay=1e-5,\n", 349 | " lr=params['search_lr'])\n", 350 | "\n", 351 | "loss = {\n", 352 | " 'test': {'value': [], 'policy': [], 'search': [], 'step': []},\n", 353 | " 'train': {'value': [], 'policy': [], 'search': [], 'step': []}\n", 354 | " }\n", 355 | "\n", 356 | "debug = {}\n", 357 | "\n", 358 | "writer = SummaryWriter(log_dir='../../runs')\n", 359 | "plotter = recnn.utils.Plotter(loss, [['value', 'policy', 'search']],)" 360 | ], 361 | "metadata": { 362 | "collapsed": false, 363 | "pycharm": { 364 | "name": "#%%\n" 365 | } 366 | } 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": null, 371 | "outputs": [], 372 | "source": [ 373 | "for epoch in range(n_epochs):\n", 374 | " for batch in tqdm(env.train_dataloader):\n", 375 | " loss = ddpg_sn_update(batch, params, step=step)\n", 376 | " plotter.log_losses(loss)\n", 377 | " step += 1\n", 378 | " if step % plot_every == 0:\n", 379 | " clear_output(True)\n", 380 | " print('step', step)\n", 381 | " test_loss = run_tests()\n", 382 | " plotter.log_losses(test_loss, test=True)\n", 383 | " plotter.plot_loss()\n", 384 | " if step > 1000:\n", 385 | " assert False" 386 | ], 387 | "metadata": { 388 | "collapsed": false, 389 | "pycharm": { 390 | "name": "#%%\n" 391 | } 392 | } 393 | }, 394 | { 395 | "cell_type": "code", 396 | "execution_count": null, 397 | "outputs": [], 398 | "source": [ 399 | "torch.save(value_net.state_dict(), \"../../models/ddpg_value.pt\")\n", 400 | "torch.save(policy_net.state_dict(), \"../../models/ddpg_policy.pt\")" 401 | ], 402 | "metadata": { 403 | "collapsed": false, 404 | "pycharm": { 405 | "name": "#%%\n" 406 | } 407 | } 408 | }, 409 | { 410 | "cell_type": "markdown", 411 | "source": [ 412 | "# Reconstruction error" 413 | ], 414 | "metadata": { 415 | "collapsed": false 416 | } 417 | }, 418 | { 419 | "cell_type": "code", 420 | "execution_count": null, 421 | "outputs": [], 422 | "source": [ 423 | "gen_actions = debug['next_action']\n", 424 | "true_actions = env.base.embeddings.numpy()\n", 425 | "\n", 426 | "\n", 427 | "ad = recnn.nn.AnomalyDetector().to(cuda)\n", 428 | "ad.load_state_dict(torch.load('../../models/anomaly.pt'))\n", 429 | "ad.eval()\n", 430 | "\n", 431 | "plotter.plot_kde_reconstruction_error(ad, gen_actions, true_actions, cuda)" 432 | ], 433 | "metadata": { 434 | "collapsed": false, 435 | "pycharm": { 436 | "name": "#%%\n" 437 | } 438 | } 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": null, 443 | "outputs": [], 444 | "source": [], 445 | "metadata": { 446 | "collapsed": false, 447 | "pycharm": { 448 | "name": "#%%\n" 449 | } 450 | } 451 | } 452 | ], 453 | "metadata": { 454 | "kernelspec": { 455 | "display_name": 
"Python 3", 456 | "language": "python", 457 | "name": "python3" 458 | }, 459 | "language_info": { 460 | "codemirror_mode": { 461 | "name": "ipython", 462 | "version": 2 463 | }, 464 | "file_extension": ".py", 465 | "mimetype": "text/x-python", 466 | "name": "python", 467 | "nbconvert_exporter": "python", 468 | "pygments_lexer": "ipython2", 469 | "version": "2.7.6" 470 | } 471 | }, 472 | "nbformat": 4, 473 | "nbformat_minor": 0 474 | } -------------------------------------------------------------------------------- /examples/[Library Basics]/2. Different Pandas Backends.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# == recnn ==\n", 10 | "import sys\n", 11 | "sys.path.append(\"../../\")\n", 12 | "import recnn" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "metadata": {}, 18 | "source": [ 19 | "## RecNN supports different types of pandas backends\n", 20 | "### for faster loading/processing in and out of core" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "\n", 28 | "![here be pandas logo](https://dev.pandas.io/static/img/pandas.svg \"Pandas\")\n", 29 | "\n", 30 | "#### Pandas is you default backend.\n", 31 | " (no need to set it like that)" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 8, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "# but you can also set it directly:\n", 41 | "recnn.pd.set(\"pandas\")\n", 42 | "frame_size = 10\n", 43 | "batch_size = 25\n", 44 | "dirs = recnn.data.env.DataPath(\n", 45 | " base=\"../../data/\",\n", 46 | " embeddings=\"embeddings/ml20_pca128.pkl\",\n", 47 | " ratings=\"ml-20m/ratings.csv\",\n", 48 | " cache=\"cache/frame_env.pkl\", # cache will generate after you run\n", 49 | " use_cache=False\n", 50 | ")" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 3, 56 | "metadata": { 57 | "tags": [] 58 | }, 59 | "outputs": [ 60 | { 61 | "output_type": "stream", 62 | "name": "stderr", 63 | "text": "100%|██████████| 20000263/20000263 [00:13<00:00, 1469488.15it/s]\n100%|██████████| 20000263/20000263 [00:15<00:00, 1265183.17it/s]\n100%|██████████| 138493/138493 [00:06<00:00, 19935.53it/s]\nCPU times: user 41.6 s, sys: 1.89 s, total: 43.5 s\nWall time: 43.5 s\n" 64 | }, 65 | { 66 | "output_type": "execute_result", 67 | "data": { 68 | "text/plain": "" 69 | }, 70 | "metadata": {}, 71 | "execution_count": 3 72 | } 73 | ], 74 | "source": [ 75 | "%%time\n", 76 | "env = recnn.data.env.FrameEnv(dirs, frame_size, batch_size)" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "![here be modin logo](https://modin.readthedocs.io/en/latest/_images/MODIN_ver2_hrz.png \"Modin\") " 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "Modin uses Ray or Dask to provide an effortless way to speed up your pandas notebooks, scripts, and libraries. Unlike other distributed DataFrame libraries, Modin provides seamless integration and compatibility with existing pandas code. Even using the DataFrame constructor is identical." 
91 | ] 92 | }, 93 | { 94 | "cell_type": "markdown", 95 | "metadata": {}, 96 | "source": [ 97 | "![here be Ray logo](https://github.com/ray-project/ray/raw/master/doc/source/images/ray_header_logo.png \"Ray\") " 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": {}, 103 | "source": [ 104 | "A fast and simple framework for building and running distributed applications. Ray is packaged with RLlib, a scalable reinforcement learning library, and Tune, a scalable hyperparameter tuning library." 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 9, 110 | "metadata": { 111 | "tags": [] 112 | }, 113 | "outputs": [ 114 | { 115 | "output_type": "stream", 116 | "name": "stderr", 117 | "text": "2020-08-09 16:55:54,693\tINFO resource_spec.py:204 -- Starting Ray with 4.98 GiB memory available for workers and up to 2.51 GiB for objects. You can adjust these settings with ray.init(memory=, object_store_memory=).\n2020-08-09 16:55:54,919\tWARNING services.py:923 -- Redis failed to start, retrying now.\n2020-08-09 16:55:55,069\tINFO services.py:1163 -- View the Ray dashboard at \u001b[1m\u001b[32mlocalhost:8265\u001b[39m\u001b[22m\n" 118 | } 119 | ], 120 | "source": [ 121 | "import os\n", 122 | "import ray\n", 123 | "\n", 124 | "if ray.is_initialized():\n", 125 | " ray.shutdown()\n", 126 | "os.environ[\"MODIN_ENGINE\"] = \"ray\" # Modin will use Ray\n", 127 | "ray.init(num_cpus=10) # adjust for your liking\n", 128 | "recnn.pd.set(\"modin\")" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 10, 134 | "metadata": { 135 | "tags": [] 136 | }, 137 | "outputs": [ 138 | { 139 | "output_type": "stream", 140 | "name": "stderr", 141 | "text": "100%|██████████| 138493/138493 [00:07<00:00, 18503.97it/s]\nCPU times: user 12 s, sys: 2.06 s, total: 14 s\nWall time: 21.4 s\n" 142 | } 143 | ], 144 | "source": [ 145 | "%%time\n", 146 | "env = recnn.data.env.FrameEnv(dirs, frame_size, batch_size)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "metadata": {}, 152 | "source": [ 153 | "![here be Ray logo](https://dask.org/_images/dask_horizontal_white_no_pad.svg \"Ray\") " 154 | ] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "metadata": {}, 159 | "source": [ 160 | "## Dask is a flexible library for parallel computing in Python." 
161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 3, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "### dask\n", 170 | "import os\n", 171 | "os.environ[\"MODIN_ENGINE\"] = \"dask\" # Modin will use Dask\n", 172 | "recnn.pd.set(\"modin\")" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 4, 178 | "metadata": { 179 | "tags": [] 180 | }, 181 | "outputs": [ 182 | { 183 | "name": "stderr", 184 | "output_type": "stream", 185 | "text": [ 186 | "100%|██████████| 138493/138493 [00:06<00:00, 19785.99it/s]\n", 187 | "CPU times: user 14.2 s, sys: 2.13 s, total: 16.3 s\n", 188 | "Wall time: 22 s\n" 189 | ] 190 | }, 191 | { 192 | "data": { 193 | "text/plain": [ 194 | "" 195 | ] 196 | }, 197 | "execution_count": 4, 198 | "metadata": {}, 199 | "output_type": "execute_result" 200 | } 201 | ], 202 | "source": [ 203 | "%%time\n", 204 | "env = recnn.data.env.FrameEnv(dirs, frame_size, batch_size)" 205 | ] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "metadata": {}, 210 | "source": [ 211 | "# Free 2x increase in load speed!\n", 212 | "\n", 213 | "### Pandas Wall time: 40.6 s\n", 214 | "### Modin/Ray Wall time: 20.8 s\n", 215 | "### Modin/Dask Wall time: 22 s\n" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [] 224 | } 225 | ], 226 | "metadata": { 227 | "kernelspec": { 228 | "display_name": "Python 3.8.5 64-bit", 229 | "language": "python", 230 | "name": "python38564bitfba12b29602d49fd94d253df959599f4" 231 | }, 232 | "language_info": { 233 | "codemirror_mode": { 234 | "name": "ipython", 235 | "version": 3 236 | }, 237 | "file_extension": ".py", 238 | "mimetype": "text/x-python", 239 | "name": "python", 240 | "nbconvert_exporter": "python", 241 | "pygments_lexer": "ipython3", 242 | "version": "3.8.5-final" 243 | } 244 | }, 245 | "nbformat": 4, 246 | "nbformat_minor": 2 247 | } -------------------------------------------------------------------------------- /examples/readme.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | This is the primary section that contains all of my experiments. As you can see, it is divided into subfolders, one per topic. 3 | 4 | Topics: 5 | 6 | 0. Embeddings generation. This section covers basic Python web scraping of RESTful APIs (OMDB, IMDB, TMDB) and feature engineering of the gathered data. The feature engineering part shows how to efficiently encode categories with MCA (Multiple Correspondence Analysis), assuming you have no missing values. Numerical data is processed with a probabilistic version of Principal Component Analysis (PCA) that can handle missing values and do all sorts of other fun stuff. I also attempt to visualize the embedding space using Uniform Manifold Approximation and Projection (UMAP). Text data is processed with Facebook's FairSeq RoBERTa, which I have fine-tuned on an MNLI-like dataset (Multi-Genre Natural Language Inference). I fine-tune it to predict whether two (or more) plot descriptions belong to the same movie. Then I compressed all of the gathered data into 3 embedding types you can choose from: UMAP, PCA (recommended) and AutoEncoder. 7 | 8 | 1. Vanilla Reinforcement Learning: this section contains my implementations of basic RL Actor-Critic algorithms. 9 | 10 | 2. REINFORCE Top-K Off-Policy Correction; a short sketch of the correction term is given below.
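
For reference, the core of the correction (as implemented in recnn/nn/models.py) is just an importance weight between the learned policy pi and the logging policy beta, multiplied by a top-K term. A minimal sketch, assuming you already have log-probabilities of the chosen action under both policies; the function name here is only for illustration:

```
import torch

def topk_off_policy_weight(pi_log_prob, beta_log_prob, K=10):
    # importance weight: pi(a|s) / beta(a|s)
    corr = torch.exp(pi_log_prob) / torch.exp(beta_log_prob)
    # top-K correction: lambda_K = K * (1 - pi(a|s)) ** (K - 1)
    l_k = K * (1 - torch.exp(pi_log_prob)) ** (K - 1)
    return corr * l_k
```

In the notebooks these terms are accumulated by DiscreteActor during action selection and are then used to weight the REINFORCE policy gradient.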
11 | 12 | I took inspiration from Higgsfield's code, but I have rewritten most of it to keep 13 | the code abstract and easy to read. -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 |

6 | 7 | 8 | Documentation Status 9 | 10 | 11 | 12 | Documentation Status 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | Code style: black 29 | 30 |

31 | 32 | 33 |

34 | This is my school project. It focuses on Reinforcement Learning for personalized news recommendation. The main distinction is that it tries to solve online off-policy learning with dynamically generated item embeddings. I want to create a library with SOTA algorithms for reinforcement learning recommendation, providing the level of abstraction you like. 35 |

36 | 37 |

38 | recnn.readthedocs.io 39 |

40 | 41 | ### 📊 The features can be summed up as 42 | 43 | - Abstract as you decide: you can import the entire algorithm (say DDPG) and simply call ddpg.update(batch), you can import networks and the learning function separately, create a custom loader for your task, or define everything yourself. 44 | 45 | - Examples do not contain any junk code or workarounds: pure model definition and the algorithm itself in one file. I wrote a couple of articles explaining how it functions. 46 | 47 | - The learning is built around a sequential or frame environment that supports ML20M and the like. Seq and Frame refer to how sequence length is handled: Seq is fully sequential with dynamic length (WIP), while Frame is a fixed-size window. 48 | 49 | - State Representation module with various methods. For sequential state representation, you can use LSTM/RNN/GRU (WIP). 50 | 51 | - Parallel data loading with Modin (Dask / Ray) and caching. 52 | 53 | - PyTorch 1.7 support with TensorBoard visualization. 54 | 55 | - New datasets will be added in the future. 56 | 57 | ## 📚 Medium Articles 58 | 59 | The repo consists of two parts: the library (./recnn) and the playground (./examples), where I explain how to work with certain things. 60 | 61 | - Pretty much what you need to get started with this library if you know recommenders but don't know much about 62 | reinforcement learning: 63 | 64 |

65 | 66 | 67 | 68 |

69 | 70 | - Top-K Off-Policy Correction for a REINFORCE Recommender System: 71 |

72 | 73 | 74 | 75 |

76 | 77 | 78 | ## Algorithms that are/will be added 79 | 80 |

81 | 82 | | Algorithm | Paper | Code | 83 | |---------------------------------------|----------------------------------|----------------------------| 84 | | Deep Q Learning (PoC) | https://arxiv.org/abs/1312.5602 | examples/0. Embeddings/ 1.DQN | 85 | | Deep Deterministic Policy Gradients | https://arxiv.org/abs/1509.02971 | examples/1.Vanilla RL/DDPG | 86 | | Twin Delayed DDPG (TD3) | https://arxiv.org/abs/1802.09477 | examples/1.Vanilla RL/TD3 | 87 | | Soft Actor-Critic | https://arxiv.org/abs/1801.01290 | examples/1.Vanilla RL/SAC | 88 | | Batch Constrained Q-Learning | https://arxiv.org/abs/1812.02900 | examples/99.To be released/BCQ | 89 | | REINFORCE Top-K Off-Policy Correction | https://arxiv.org/abs/1812.02353 | examples/2. REINFORCE TopK | 90 | 91 |
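
To get a feel for how these pieces fit together, here is a rough sketch of the DDPG variant. Paths and dimensions are the ML20M defaults from the examples (a frame of 10 items with 128-d embeddings plus their ratings gives a 1290-d state); adjust them for your own data.

```
import recnn

# data: item embeddings + ratings turned into SARSA-like batches
dirs = recnn.data.env.DataPath(
    base="data/",
    embeddings="embeddings/ml20_pca128.pkl",
    ratings="ml-20m/ratings.csv",
    cache="cache/frame_env.pkl",
)
env = recnn.data.env.FrameEnv(dirs, frame_size=10, batch_size=25)

# networks and the algorithm wrapper
value_net = recnn.nn.Critic(1290, 128, 256)
policy_net = recnn.nn.Actor(1290, 128, 256)
ddpg = recnn.nn.DDPG(policy_net, value_net)

for batch in env.train_dataloader:
    loss = ddpg.update(batch, learn=True)  # dict with value/policy losses
    ddpg.step()
```

Complete, runnable versions live in examples/[Library Basics] and examples/1. Vanilla RL.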

92 | 93 | ### Repos I used code from 94 | 95 | - Sfujim's [BCQ](https://github.com/sfujim/BCQ) (not implemented yet) 96 | - Higgsfield's [RL Adventure 2](https://github.com/higgsfield/RL-Adventure-2) (great inspiration) 97 | 98 | ### 🤔 What is this 99 | 100 |

101 | This is my school project. It focuses on Reinforcement Learning for personalized news recommendation. The main distinction is that it tries to solve online off-policy learning with dynamically generated item embeddings. Also, there is no exploration, since we are working with a dataset. In the example section, I use Google's BERT on the ML20M dataset to extract contextual information from the movie descriptions and form the latent vector representations. Later, you can apply the same transformation to new, previously unseen items (hence, the embeddings are dynamically generated). If you don't want to bother with the embeddings pipeline, I have a DQN embeddings generator as a proof of concept. 102 |
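
In practice, the only contract between your embedding pipeline and the rest of the library is a pickled {item_id: torch.tensor} dictionary, which the environment loads and packs into a single tensor. Here is a minimal sketch of producing one for new items; `encode` and `new_items` are placeholders for whatever encoder and item source you use:

```
import pickle
import torch

# placeholder: encode(meta) is whatever model you use (e.g. BERT over the description)
# and should return a 128-d vector to match the provided ML20M embeddings
embeddings = {item_id: torch.tensor(encode(meta), dtype=torch.float32)
              for item_id, meta in new_items.items()}

with open("data/embeddings/my_embeddings.pkl", "wb") as f:
    pickle.dump(embeddings, f)
```

Point DataPath's embeddings argument at this file and the rest of the pipeline stays unchanged.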

103 | 104 | 105 | ## ✋ Getting Started 106 |

107 | 108 |

109 | 110 |

111 | 112 |

113 | 114 | 115 | p.s. Image is clickable. here is direct link: 116 | 117 | 118 | 119 | 120 | To learn more about recnn, read the docs: recnn.readthedocs.io 121 | 122 | ### ⚙️ Installing 123 | 124 | ``` 125 | pip install git+git://github.com/awarebayes/RecNN.git 126 | ``` 127 | 128 | PyPi is on its way... 129 | 130 | ### 🚀 Try demo 131 | 132 | I built a [Streamlit](https://www.streamlit.io/) demo to showcase its features. 133 | It has 'recommend me a movie' feature! Note how the score changes when you **rate** the movies. When you start 134 | and the movies aren't rated (5/10 by default) the score is about ~40 (euc), but as you rate them it drops to <10, 135 | indicating more personalized and precise predictions. You can also test diversity, check out the correlation of 136 | recommendations, pairwise distances, and pinpoint accuracy. 137 | 138 | Run it: 139 | ``` 140 | git clone git@github.com:awarebayes/RecNN.git 141 | cd RecNN && streamlit run examples/streamlit_demo.py 142 | ``` 143 | 144 | [Docker image is available here](https://github.com/awarebayes/recnn-demo) 145 | 146 | ## 📁 Downloads 147 | - [MovieLens 20M](https://grouplens.org/datasets/movielens/20m/) 148 | - [Movie Embeddings](https://drive.google.com/open?id=1EQ_zXBR3DKpmJR3jBgLvt-xoOvArGMsL) 149 | - [Misc Data](https://drive.google.com/open?id=1TclEmCnZN_Xkl3TfUXL5ivPYmLnIjQSu) 150 | - [Parsed (omdb,tmdb)](https://drive.google.com/open?id=1t0LNCbqLjiLkAMFwtP8OIYU-zPUCNAjK) 151 | 152 | ## 📁 [Download the Models](https://drive.google.com/file/d/1goGa15XZmDAp2msZvRi2v_1h9xfmnhz7/view?usp=sharing) 153 | 154 | ## 📄 Citing 155 | If you find RecNN useful for an academic publication, then please use the following BibTeX to cite it: 156 | 157 | ``` 158 | @misc{RecNN, 159 | author = {M Scherbina}, 160 | title = {RecNN: RL Recommendation with PyTorch}, 161 | year = {2019}, 162 | publisher = {GitHub}, 163 | journal = {GitHub repository}, 164 | howpublished = {\url{https://github.com/awarebayes/RecNN}}, 165 | } 166 | ``` 167 | -------------------------------------------------------------------------------- /recnn/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data, utils, nn 2 | from .data import pd 3 | -------------------------------------------------------------------------------- /recnn/data/__init__.py: -------------------------------------------------------------------------------- 1 | from . import utils, env 2 | from .utils import * 3 | from .env import * 4 | from .dataset_functions import * 5 | from .pandas_backend import pd 6 | -------------------------------------------------------------------------------- /recnn/data/dataset_functions.py: -------------------------------------------------------------------------------- 1 | from .pandas_backend import pd 2 | import numpy as np 3 | from typing import List, Dict, Callable 4 | 5 | """ 6 | What? 7 | +++++ 8 | 9 | RecNN is designed to work with your data flow. 10 | 11 | Set kwargs in the beginning of prepare_dataset function. 12 | Kwargs you set are immutable. 13 | 14 | args_mut are mutable arguments, you can access the following: 15 | base: data.EnvBase, df: DataFrame, users: List[int], 16 | user_dict: Dict[int, Dict[str, np.ndarray] 17 | 18 | Access args_mut and modify them in functions defined by you. 19 | Best to use function chaining with build_data_pipeline. 20 | 21 | recnn.data.prepare_dataset is a function that is used by default in Env.__init__ 22 | But sometimes you want some extra. 
I have also predefined truncate_dataset. 23 | This function truncates the number of items to specified one. 24 | In reinforce example I modify it to look like:: 25 | 26 | def prepare_dataset(args_mut, kwargs): 27 | kwargs.set('reduce_items_to', num_items) # set kwargs for your functions here! 28 | pipeline = [recnn.data.truncate_dataset, recnn.data.prepare_dataset] 29 | recnn.data.build_data_pipeline(pipeline, kwargs, args_mut) 30 | 31 | # embeddgings: https://drive.google.com/open?id=1EQ_zXBR3DKpmJR3jBgLvt-xoOvArGMsL 32 | env = recnn.data.env.FrameEnv('..', 33 | '...', frame_size, batch_size, 34 | embed_batch=embed_batch, prepare_dataset=prepare_dataset, 35 | num_workers=0) 36 | 37 | """ 38 | 39 | 40 | def try_progress_apply(dataframe, function): 41 | try: 42 | return dataframe.progress_apply(function) 43 | except AttributeError: 44 | return dataframe.apply(function) 45 | 46 | 47 | # Plain args. Shouldn't be mutated 48 | class DataFuncKwargs: 49 | def __init__(self, **kwargs): 50 | self.kwargs = kwargs 51 | 52 | def keys(self): 53 | return self.kwargs.keys() 54 | 55 | def get(self, name: str): 56 | if name not in self.kwargs: 57 | example = """ 58 | # example on how to use kwargs: 59 | def prepare_dataset(args, args_mut): 60 | args.set_kwarg('{}', your_value) # set kwargs for your functions here! 61 | pipeline = [recnn.data.truncate_dataset, recnn.data.prepare_dataset] 62 | recnn.data.build_data_pipeline(pipeline, args, args_mut) 63 | """ 64 | raise AttributeError( 65 | "No kwarg with name {} found!\n{}".format(name, example.format(example)) 66 | ) 67 | return self.kwargs[name] 68 | 69 | def set(self, name: str, value): 70 | self.kwargs[name] = value 71 | 72 | 73 | # Used for returning, arguments are mutable 74 | class DataFuncArgsMut: 75 | def __init__( 76 | self, df, base, users: List[int], user_dict: Dict[int, Dict[str, np.ndarray]] 77 | ): 78 | self.base = base 79 | self.users = users 80 | self.user_dict = user_dict 81 | self.df = df 82 | 83 | 84 | def prepare_dataset(args_mut: DataFuncArgsMut, kwargs: DataFuncKwargs): 85 | 86 | """ 87 | Basic prepare dataset function. Automatically makes index linear, in ml20 movie indices look like: 88 | [1, 34, 123, 2000], recnn makes it look like [0,1,2,3] for you. 
89 | """ 90 | 91 | # get args 92 | frame_size = kwargs.get("frame_size") 93 | key_to_id = args_mut.base.key_to_id 94 | df = args_mut.df 95 | 96 | # rating range mapped from [0, 5] to [-5, 5] 97 | df["rating"] = try_progress_apply(df["rating"], lambda i: 2 * (i - 2.5)) 98 | # id's tend to be inconsistent and sparse so they are remapped here 99 | df["movieId"] = try_progress_apply(df["movieId"], key_to_id.get) 100 | users = df[["userId", "movieId"]].groupby(["userId"]).size() 101 | users = users[users > frame_size].sort_values(ascending=False).index 102 | 103 | if pd.get_type() == "modin": 104 | df = df._to_pandas() # pandas groupby is sync and doesnt affect performance 105 | ratings = ( 106 | df.sort_values(by="timestamp") 107 | .set_index("userId") 108 | .drop("timestamp", axis=1) 109 | .groupby("userId") 110 | ) 111 | 112 | # Groupby user 113 | user_dict = {} 114 | 115 | def app(x): 116 | userid = x.index[0] 117 | user_dict[userid] = {} 118 | user_dict[userid]["items"] = x["movieId"].values 119 | user_dict[userid]["ratings"] = x["rating"].values 120 | 121 | try_progress_apply(ratings, app) 122 | 123 | args_mut.user_dict = user_dict 124 | args_mut.users = users 125 | 126 | return args_mut, kwargs 127 | 128 | 129 | def truncate_dataset(args_mut: DataFuncArgsMut, kwargs: DataFuncKwargs): 130 | """ 131 | Truncate #items to reduce_items_to provided in kwargs 132 | """ 133 | 134 | # here are adjusted n items to keep 135 | num_items = kwargs.get("reduce_items_to") 136 | df = args_mut.df 137 | 138 | counts = df["movieId"].value_counts().sort_values() 139 | to_remove = counts[:-num_items].index 140 | to_keep = counts[-num_items:].index 141 | to_keep_id = pd.get().Series(to_keep).apply(args_mut.base.key_to_id.get).values 142 | to_keep_mask = np.zeros(len(counts)) 143 | to_keep_mask[to_keep_id] = 1 144 | 145 | args_mut.df = df.drop(df[df["movieId"].isin(to_remove)].index) 146 | 147 | key_to_id_new = {} 148 | id_to_key_new = {} 149 | count = 0 150 | 151 | for idx, i in enumerate(list(args_mut.base.key_to_id.keys())): 152 | if i in to_keep: 153 | key_to_id_new[i] = count 154 | id_to_key_new[idx] = i 155 | count += 1 156 | 157 | args_mut.base.embeddings = args_mut.base.embeddings[to_keep_mask] 158 | args_mut.base.key_to_id = key_to_id_new 159 | args_mut.base.id_to_key = id_to_key_new 160 | 161 | print( 162 | "action space is reduced to {} - {} = {}".format( 163 | num_items + len(to_remove), len(to_remove), num_items 164 | ) 165 | ) 166 | 167 | return args_mut, kwargs 168 | 169 | 170 | def build_data_pipeline( 171 | chain: List[Callable], kwargs: DataFuncKwargs, args_mut: DataFuncArgsMut 172 | ): 173 | """ 174 | Higher order function 175 | :param chain: array of callable 176 | :param **kwargs: any kwargs you like 177 | """ 178 | for call in chain: 179 | # note: returned kwargs are not utilized to guarantee immutability 180 | args_mut, _ = call(args_mut, kwargs) 181 | return args_mut, kwargs 182 | -------------------------------------------------------------------------------- /recnn/data/db_con.py: -------------------------------------------------------------------------------- 1 | from milvus import Milvus, MetricType 2 | import torch 3 | 4 | 5 | class SearchResult: 6 | def __init__(self, data): 7 | self.data = data 8 | 9 | def id(self, device): 10 | return torch.tensor(self.data.id_array).to(device) 11 | 12 | def dist(self, device): 13 | return torch.tensor(self.data.distance_array).to(device) 14 | 15 | 16 | class MilvusConnection: 17 | def __init__(self, env, name="movies_L2", port="19530", 
param=None): 18 | 19 | if param is None: 20 | param = dict() 21 | param = { 22 | "collection_name": name, 23 | "dimension": 128, 24 | "index_file_size": 1024, 25 | "metric_type": MetricType.L2, 26 | **param, 27 | } 28 | self.name = name 29 | self.client = Milvus(host="localhost", port=port) 30 | self.statuses = {} 31 | if not self.client.has_collection(name)[1]: 32 | status_created_collection = self.client.create_collection(param) 33 | vectors = env.base.embeddings.detach().cpu().numpy().astype("float32") 34 | target_ids = list(range(vectors.shape[0])) 35 | status_inserted, inserted_vector_ids = self.client.insert( 36 | collection_name=name, records=vectors, ids=target_ids 37 | ) 38 | status_flushed = self.client.flush([name]) 39 | status_compacted = self.client.compact(collection_name=name) 40 | self.statuses["created_collection"] = status_created_collection 41 | self.statuses["inserted"] = status_inserted 42 | self.statuses["flushed"] = status_flushed 43 | self.statuses["compacted"] = status_compacted 44 | 45 | def search(self, search_vecs, topk=10, search_param=None): 46 | if search_param is None: 47 | search_param = dict() 48 | search_param = {"nprobe": 16, **search_param} 49 | status, results = self.client.search( 50 | collection_name=self.name, 51 | query_records=search_vecs, 52 | top_k=topk, 53 | params=search_param, 54 | ) 55 | self.statuses['last_search'] = status 56 | return SearchResult(results) 57 | 58 | def get_log(self): 59 | return self.statuses 60 | -------------------------------------------------------------------------------- /recnn/data/env.py: -------------------------------------------------------------------------------- 1 | from . import utils, dataset_functions as dset_F 2 | from .pandas_backend import pd 3 | import pickle 4 | from torch.utils.data import Dataset, DataLoader 5 | from sklearn.model_selection import train_test_split 6 | import os 7 | 8 | """ 9 | .. module:: env 10 | :synopsis: Main abstraction of the library for datasets is called environment, similar to how other reinforcement 11 | learning libraries name it. This interface is created to provide SARSA like input for your RL Models. When you are 12 | working with recommendation env, you have two choices: using static length inputs (say 10 items) or dynamic length 13 | time series with sequential encoders (many to one rnn). Static length is provided via FrameEnv, and dynamic length 14 | along with sequential state representation encoder is implemented in SeqEnv. Let’s take a look at FrameEnv first: 15 | 16 | 17 | .. moduleauthor:: Mike Watts 18 | 19 | 20 | """ 21 | 22 | 23 | class UserDataset(Dataset): 24 | 25 | """ 26 | Low Level API: dataset class user: [items, ratings], Instance of torch.DataSet 27 | """ 28 | 29 | def __init__(self, users, user_dict): 30 | """ 31 | 32 | :param users: integer list of user_id. Useful for train/test splitting 33 | :type users: list. 34 | :param user_dict: dictionary of users with user_id as key and [items, ratings] as value 35 | :type user_dict: (dict{ user_id: dict{'items': list, 'ratings': list} }). 36 | 37 | """ 38 | 39 | self.users = users 40 | self.user_dict = user_dict 41 | 42 | def __len__(self): 43 | """ 44 | useful for tqdm, consists of a single line: 45 | return len(self.users) 46 | """ 47 | return len(self.users) 48 | 49 | def __getitem__(self, idx): 50 | """ 51 | getitem is a function where non linear user_id maps to a linear index. For instance in the ml20m dataset, 52 | there are big gaps between neighbouring user_id. 
getitem removes these gaps, optimizing the speed. 53 | 54 | :param idx: index drawn from range(0, len(self.users)). User id can be not linear, idx is. 55 | :type idx: int 56 | 57 | :returns: dict{'items': list, rates:list, sizes: int} 58 | """ 59 | idx = self.users[idx] 60 | group = self.user_dict[idx] 61 | items = group["items"][:] 62 | rates = group["ratings"][:] 63 | size = items.shape[0] 64 | return {"items": items, "rates": rates, "sizes": size, "users": idx} 65 | 66 | 67 | class EnvBase: 68 | 69 | """ 70 | Misc class used for serializing 71 | """ 72 | 73 | def __init__(self): 74 | self.train_user_dataset = None 75 | self.test_user_dataset = None 76 | self.embeddings = None 77 | self.key_to_id = None 78 | self.id_to_key = None 79 | 80 | 81 | class DataPath: 82 | 83 | """ 84 | [New!] Path to your data. Note: cache is optional. It saves EnvBase as a pickle 85 | """ 86 | 87 | def __init__( 88 | self, 89 | base: str, 90 | ratings: str, 91 | embeddings: str, 92 | cache: str = "", 93 | use_cache: bool = True, 94 | ): 95 | self.ratings = base + ratings 96 | self.embeddings = base + embeddings 97 | self.cache = base + cache 98 | self.use_cache = use_cache 99 | 100 | 101 | class Env: 102 | 103 | """ 104 | Env abstract class 105 | """ 106 | 107 | def __init__( 108 | self, 109 | path: DataPath, 110 | prepare_dataset=dset_F.prepare_dataset, 111 | embed_batch=utils.batch_tensor_embeddings, 112 | **kwargs 113 | ): 114 | 115 | """ 116 | .. note:: 117 | embeddings need to be provided in {movie_id: torch.tensor} format! 118 | 119 | :param path: DataPath to where item embeddings are stored. 120 | :type path: DataPath 121 | :param test_size: ratio of users to use in testing. Rest will be used for training/validation 122 | :type test_size: int 123 | :param min_seq_size: (use as kwarg) filter users: len(user.items) > min seq size 124 | :type min_seq_size: int 125 | :param prepare_dataset: (use as kwarg) function you provide. 126 | :type prepare_dataset: function 127 | :param embed_batch: function to apply embeddings to batch. 
Can be set to yield continuous/discrete state/action 128 | :type embed_batch: function 129 | """ 130 | 131 | self.base = EnvBase() 132 | self.embed_batch = embed_batch 133 | self.prepare_dataset = prepare_dataset 134 | if path.use_cache and os.path.isfile(path.cache): 135 | self.load_env(path.cache) 136 | else: 137 | self.process_env(path) 138 | if path.use_cache: 139 | self.save_env(path.cache) 140 | 141 | def process_env(self, path: DataPath, **kwargs): 142 | if "frame_size" in kwargs.keys(): 143 | frame_size = kwargs["frame_size"] 144 | else: 145 | frame_size = 10 146 | 147 | if "test_size" in kwargs.keys(): 148 | test_size = kwargs["test_size"] 149 | else: 150 | test_size = 0.05 151 | 152 | movie_embeddings_key_dict = pickle.load(open(path.embeddings, "rb")) 153 | ( 154 | self.base.embeddings, 155 | self.base.key_to_id, 156 | self.base.id_to_key, 157 | ) = utils.make_items_tensor(movie_embeddings_key_dict) 158 | ratings = pd.get().read_csv(path.ratings) 159 | 160 | process_kwargs = dset_F.DataFuncKwargs( 161 | frame_size=frame_size, # remove when lstm gets implemented 162 | ) 163 | 164 | process_args_mut = dset_F.DataFuncArgsMut( 165 | df=ratings, 166 | base=self.base, 167 | users=None, # will be set later 168 | user_dict=None, # will be set later 169 | ) 170 | 171 | self.prepare_dataset(process_args_mut, process_kwargs) 172 | self.base = process_args_mut.base 173 | self.df = process_args_mut.df 174 | users = process_args_mut.users 175 | user_dict = process_args_mut.user_dict 176 | 177 | train_users, test_users = train_test_split(users, test_size=test_size) 178 | train_users = utils.sort_users_itemwise(user_dict, train_users)[2:] 179 | test_users = utils.sort_users_itemwise(user_dict, test_users) 180 | self.base.train_user_dataset = UserDataset(train_users, user_dict) 181 | self.base.test_user_dataset = UserDataset(test_users, user_dict) 182 | 183 | def load_env(self, where: str): 184 | self.base = pickle.load(open(where, "rb")) 185 | 186 | def save_env(self, where: str): 187 | pickle.dump(self.base, open(where, "wb")) 188 | 189 | 190 | class FrameEnv(Env): 191 | """ 192 | Static length user environment. 193 | """ 194 | 195 | def __init__( 196 | self, path, frame_size=10, batch_size=25, num_workers=1, *args, **kwargs 197 | ): 198 | 199 | """ 200 | :param embeddings: path to where item embeddings are stored. 201 | :type embeddings: str 202 | :param ratings: path to the dataset that is similar to the ml20m 203 | :type ratings: str 204 | :param frame_size: len of a static sequence, frame 205 | :type frame_size: int 206 | 207 | p.s. you can also provide **pandas_conf in the arguments. 
208 | 209 | It is useful if you dataset columns are different from ml20:: 210 | 211 | pandas_conf = {user_id='userId', rating='rating', item='movieId', timestamp='timestamp'} 212 | env = FrameEnv(embed_dir, rating_dir, **pandas_conf) 213 | 214 | """ 215 | 216 | kwargs["frame_size"] = frame_size 217 | super(FrameEnv, self).__init__( 218 | path, min_seq_size=frame_size + 1, *args, **kwargs 219 | ) 220 | 221 | self.frame_size = frame_size 222 | self.batch_size = batch_size 223 | self.num_workers = num_workers 224 | 225 | self.train_dataloader = DataLoader( 226 | self.base.train_user_dataset, 227 | batch_size=batch_size, 228 | shuffle=True, 229 | num_workers=num_workers, 230 | collate_fn=self.prepare_batch_wrapper, 231 | ) 232 | 233 | self.test_dataloader = DataLoader( 234 | self.base.test_user_dataset, 235 | batch_size=batch_size, 236 | shuffle=True, 237 | num_workers=num_workers, 238 | collate_fn=self.prepare_batch_wrapper, 239 | ) 240 | 241 | def prepare_batch_wrapper(self, x): 242 | batch = utils.prepare_batch_static_size( 243 | x, 244 | self.base.embeddings, 245 | embed_batch=self.embed_batch, 246 | frame_size=self.frame_size, 247 | ) 248 | return batch 249 | 250 | def train_batch(self): 251 | """ Get batch for training """ 252 | return next(iter(self.train_dataloader)) 253 | 254 | def test_batch(self): 255 | """ Get batch for testing """ 256 | return next(iter(self.test_dataloader)) 257 | 258 | 259 | # I will rewrite it some day 260 | # Pm me if I get lazy 261 | ''' 262 | class SeqEnv(Env): 263 | 264 | """ 265 | WARNING: THIS FEATURE IS IN ALPHA 266 | Dynamic length user environment. 267 | Due to some complications, this module is implemented quiet differently from FrameEnv. 268 | First of all, it relies on the replay buffer. Train/Test batch is a generator. 269 | In batch generator, I iterate through the batch, and choose target action with certain probability. 270 | Hence, ~95% is state that is encoded with state encoder and ~5% are actions. 271 | If you have a better solution, your contribution is welcome 272 | """ 273 | 274 | def __init__(self, embeddings, ratings, state_encoder, batch_size=25, device=torch.device('cuda'), 275 | layout=None, max_buf_size=1000, num_workers=1, embed_batch=utils.batch_tensor_embeddings, 276 | *args, **kwargs): 277 | 278 | """ 279 | :param embeddings: path to where item embeddings are stored. 
280 | :type embeddings: str 281 | :param ratings: path to the dataset that is similar to the ml20m 282 | :type ratings: str 283 | :param state_encoder: state encoder of your choice 284 | :type state_encoder: nn.Module 285 | :param device: device of your choice 286 | :type device: torch.device 287 | :param max_buf_size: maximum size of a replay buffer 288 | :type max_buf_size: int 289 | :param layout: how sizes in batch should look like 290 | :type layout: list 291 | """ 292 | 293 | super(SeqEnv, self).__init__(embeddings, ratings, min_seq_size=10, *args, **kwargs) 294 | print("Sequential support is super experimental and is not guaranteed to work") 295 | self.embed_batch = embed_batch 296 | 297 | if layout is None: 298 | # default ml20m layout for my (action=128) embeddings, (state=256) 299 | layout = [torch.Size([max_buf_size, 256]), 300 | torch.Size([max_buf_size, 128]), 301 | torch.Size([max_buf_size, 1]), 302 | torch.Size([max_buf_size, 256])] 303 | 304 | def prepare_batch_wrapper(batch): 305 | 306 | batch = utils.padder(batch) 307 | batch = utils.prepare_batch_dynamic_size(batch, self.embeddings) 308 | return batch 309 | 310 | self.prepare_batch_wrapper = prepare_batch_wrapper 311 | self.batch_size = batch_size 312 | self.num_workers = num_workers 313 | 314 | self.device = device 315 | self.state_encoder = state_encoder 316 | self.max_buf_size = max_buf_size 317 | self.train_dataloader = DataLoader(self.train_user_dataset, batch_size=batch_size, 318 | shuffle=False, num_workers=num_workers, collate_fn=prepare_batch_wrapper) 319 | self.test_dataloader = DataLoader(self.test_user_dataset, batch_size=batch_size, 320 | shuffle=False, num_workers=num_workers, collate_fn=prepare_batch_wrapper) 321 | 322 | self.buffer_layout = layout 323 | 324 | self.train_buffer = utils.ReplayBuffer(self.max_buf_size, layout=self.buffer_layout) 325 | self.test_buffer = utils.ReplayBuffer(self.max_buf_size, layout=self.buffer_layout) 326 | 327 | def train_batch(self): 328 | while 1: 329 | for batch in tqdm(self.train_dataloader): 330 | items, ratings, sizes, users = utils.get_irsu(batch) 331 | items, ratings, sizes = [i.to(self.device) for i in [items, ratings, sizes]] 332 | hidden = None 333 | state = None 334 | self.train_buffer.meta.update({'sizes': sizes, 'users': users}) 335 | for t in range(int(sizes.min().item()) - 1): 336 | action = items[:, t] 337 | reward = ratings[:, t].unsqueeze(-1) 338 | s = torch.cat([action, reward], 1).unsqueeze(0) 339 | next_state, hidden = self.state_encoder(s, hidden) if hidden else self.state_encoder(s) 340 | next_state = next_state.squeeze() 341 | 342 | if np.random.random() > 0.95 and state is not None: 343 | batch = {'state': state, 'action': action, 'reward': reward, 344 | 'next_state': next_state, 'step': t} 345 | self.train_buffer.append(batch) 346 | 347 | if self.train_buffer.len() >= self.max_buf_size: 348 | g = self.train_buffer.get() 349 | self.train_buffer.flush() 350 | yield g 351 | 352 | state = next_state 353 | 354 | def test_batch(self): 355 | while 1: 356 | for batch in tqdm(self.test_dataloader): 357 | batch = [i.to(self.device) for i in batch] 358 | items, ratings, sizes = batch 359 | hidden = None 360 | state = None 361 | for t in range(int(sizes.min().item()) - 1): 362 | action = items[:, t] 363 | reward = ratings[:, t].unsqueeze(-1) 364 | s = torch.cat([action, reward], 1).unsqueeze(0) 365 | next_state, hidden = self.state_encoder(s, hidden) if hidden else self.state_encoder(s) 366 | next_state = next_state.squeeze() 367 | 368 | if np.random.random() 
> 0.95 and state is not None: 369 | batch = [state, action, reward, next_state] 370 | self.test_buffer.append(batch) 371 | 372 | if self.test_buffer.len() >= self.max_buf_size: 373 | g = self.test_buffer.get() 374 | self.test_buffer.flush() 375 | yield g 376 | del g 377 | 378 | state = next_state 379 | ''' 380 | -------------------------------------------------------------------------------- /recnn/data/pandas_backend.py: -------------------------------------------------------------------------------- 1 | class PandasBackend: 2 | def __init__(self): 3 | self.backend = None 4 | self.type = "pandas" 5 | self.set() 6 | 7 | def set(self, backend="pandas"): 8 | if backend not in ["pandas", "modin"]: 9 | print("Wrong backend specified! Usage: pd.set('pandas') or pd.set('modin')") 10 | print("Using default pandas backend!") 11 | backend = "pandas" 12 | self.type = backend 13 | if backend == "pandas": 14 | import pandas 15 | 16 | try: 17 | from tqdm.auto import tqdm 18 | 19 | tqdm.pandas() 20 | except ImportError: 21 | print("Error in tqdm.pandas()") 22 | print("Pandas progress is disabled") 23 | self.backend = pandas 24 | elif backend == "modin": 25 | from modin import pandas 26 | 27 | self.backend = pandas 28 | 29 | def get(self): 30 | return self.backend 31 | 32 | def get_type(self): 33 | return self.type 34 | 35 | 36 | pd = PandasBackend() 37 | -------------------------------------------------------------------------------- /recnn/data/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from .pandas_backend import pd 4 | 5 | 6 | # helper function similar to pandas.Series.rolling 7 | def rolling_window(a, window): 8 | shape = a.shape[:-1] + (a.shape[-1] - window + 1, window) 9 | strides = a.strides + (a.strides[-1],) 10 | return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides) 11 | 12 | 13 | def get_irsu(batch): 14 | items_t, ratings_t, sizes_t, users_t = ( 15 | batch["items"], 16 | batch["ratings"], 17 | batch["sizes"], 18 | batch["users"], 19 | ) 20 | return items_t, ratings_t, sizes_t, users_t 21 | 22 | 23 | def batch_no_embeddings(batch, frame_size, *args, **kwargs): 24 | """ 25 | Embed Batch: discrete state discrete action 26 | """ 27 | items_t, ratings_t, sizes_t, users_t = get_irsu(batch) 28 | b_size = ratings_t.size(0) 29 | items = items_t[:, :-1] 30 | next_items = items_t[:, 1:] 31 | ratings = ratings_t[:, :-1] 32 | next_ratings = ratings_t[:, 1:] 33 | action = items_t[:, -1] 34 | reward = ratings_t[:, -1] 35 | done = torch.zeros(b_size) 36 | 37 | done[torch.cumsum(sizes_t - frame_size, dim=0) - 1] = 1 38 | batch = { 39 | "items": items, 40 | "next_items": next_items, 41 | ratings: "ratings", 42 | "next_ratings": next_ratings, 43 | "action": action, 44 | "reward": reward, 45 | "done": done, 46 | "meta": {"users": users_t, "sizes": sizes_t}, 47 | } 48 | return batch 49 | 50 | 51 | def batch_tensor_embeddings(batch, item_embeddings_tensor, frame_size, *args, **kwargs): 52 | """ 53 | Embed Batch: continuous state continuous action 54 | """ 55 | 56 | items_t, ratings_t, sizes_t, users_t = get_irsu(batch) 57 | items_emb = item_embeddings_tensor[items_t.long()] 58 | b_size = ratings_t.size(0) 59 | 60 | items = items_emb[:, :-1, :].view(b_size, -1) 61 | next_items = items_emb[:, 1:, :].view(b_size, -1) 62 | ratings = ratings_t[:, :-1] 63 | next_ratings = ratings_t[:, 1:] 64 | 65 | state = torch.cat([items, ratings], 1) 66 | next_state = torch.cat([next_items, next_ratings], 1) 67 | action = 
items_emb[:, -1, :] 68 | reward = ratings_t[:, -1] 69 | 70 | done = torch.zeros(b_size) 71 | done[torch.cumsum(sizes_t - frame_size, dim=0) - 1] = 1 72 | 73 | batch = { 74 | "state": state, 75 | "action": action, 76 | "reward": reward, 77 | "next_state": next_state, 78 | "done": done, 79 | "meta": {"users": users_t, "sizes": sizes_t}, 80 | } 81 | return batch 82 | 83 | 84 | def batch_contstate_discaction( 85 | batch, item_embeddings_tensor, frame_size, num_items, *args, **kwargs 86 | ): 87 | 88 | """ 89 | Embed Batch: continuous state discrete action 90 | """ 91 | 92 | items_t, ratings_t, sizes_t, users_t = get_irsu(batch) 93 | items_emb = item_embeddings_tensor[items_t.long()] 94 | b_size = ratings_t.size(0) 95 | 96 | items = items_emb[:, :-1, :].view(b_size, -1) 97 | next_items = items_emb[:, 1:, :].view(b_size, -1) 98 | ratings = ratings_t[:, :-1] 99 | next_ratings = ratings_t[:, 1:] 100 | 101 | state = torch.cat([items, ratings], 1) 102 | next_state = torch.cat([next_items, next_ratings], 1) 103 | action = items_t[:, -1] 104 | reward = ratings_t[:, -1] 105 | 106 | done = torch.zeros(b_size) 107 | done[torch.cumsum(sizes_t - frame_size, dim=0) - 1] = 1 108 | 109 | one_hot_action = torch.zeros(b_size, num_items) 110 | one_hot_action.scatter_(1, action.view(-1, 1), 1) 111 | 112 | batch = { 113 | "state": state, 114 | "action": one_hot_action, 115 | "reward": reward, 116 | "next_state": next_state, 117 | "done": done, 118 | "meta": {"users": users_t, "sizes": sizes_t}, 119 | } 120 | return batch 121 | 122 | 123 | # pads stuff to work with lstms 124 | def padder(x): 125 | items_t = [] 126 | ratings_t = [] 127 | sizes_t = [] 128 | users_t = [] 129 | for i in range(len(x)): 130 | items_t.append(torch.tensor(x[i]["items"])) 131 | ratings_t.append(torch.tensor(x[i]["rates"])) 132 | sizes_t.append(x[i]["sizes"]) 133 | users_t.append(x[i]["users"]) 134 | items_t = torch.nn.utils.rnn.pad_sequence(items_t, batch_first=True).long() 135 | ratings_t = torch.nn.utils.rnn.pad_sequence(ratings_t, batch_first=True).float() 136 | sizes_t = torch.tensor(sizes_t).float() 137 | return {"items": items_t, "ratings": ratings_t, "sizes": sizes_t, "users": users_t} 138 | 139 | 140 | def sort_users_itemwise(user_dict, users): 141 | return ( 142 | pd.get() 143 | .Series(dict([(i, user_dict[i]["items"].shape[0]) for i in users])) 144 | .sort_values(ascending=False) 145 | .index 146 | ) 147 | 148 | 149 | def prepare_batch_dynamic_size(batch, item_embeddings_tensor, embed_batch=None): 150 | item_idx, ratings_t, sizes_t, users_t = get_irsu(batch) 151 | item_t = item_embeddings_tensor[item_idx] 152 | batch = {"items": item_t, "users": users_t, "ratings": ratings_t, "sizes": sizes_t} 153 | return batch 154 | 155 | 156 | # Main function that is used as torch.DataLoader->collate_fn 157 | # CollateFn docs: 158 | # https://pytorch.org/docs/stable/data.html#working-with-collate-fn 159 | 160 | 161 | def prepare_batch_static_size( 162 | batch, item_embeddings_tensor, frame_size=10, embed_batch=batch_tensor_embeddings 163 | ): 164 | item_t, ratings_t, sizes_t, users_t = [], [], [], [] 165 | for i in range(len(batch)): 166 | item_t.append(batch[i]["items"]) 167 | ratings_t.append(batch[i]["rates"]) 168 | sizes_t.append(batch[i]["sizes"]) 169 | users_t.append(batch[i]["users"]) 170 | 171 | item_t = np.concatenate([rolling_window(i, frame_size + 1) for i in item_t], 0) 172 | ratings_t = np.concatenate( 173 | [rolling_window(i, frame_size + 1) for i in ratings_t], 0 174 | ) 175 | 176 | item_t = torch.tensor(item_t) 177 | users_t = 
torch.tensor(users_t) 178 | ratings_t = torch.tensor(ratings_t).float() 179 | sizes_t = torch.tensor(sizes_t) 180 | 181 | batch = {"items": item_t, "users": users_t, "ratings": ratings_t, "sizes": sizes_t} 182 | 183 | return embed_batch( 184 | batch=batch, 185 | item_embeddings_tensor=item_embeddings_tensor, 186 | frame_size=frame_size, 187 | ) 188 | 189 | 190 | # Usually in data sets there item index is inconsistent (if you plot it doesn't look like a line) 191 | # This function makes the index linear, allows for better compression of the data 192 | # And also makes use of tensor[tensor] semantics 193 | 194 | # items_embeddings_key_dict:arg - item embeddings by key 195 | # include_zero:arg - whether to include items_embeddings_id_dict[0] = [0, 0, 0, ..., 0] (128) 196 | # sometimes needed for rnn padding, by default True 197 | # returns: 198 | # items_embeddings_tensor - items_embeddings_dict compressed into tensor 199 | # key_to_id - dict key -> index 200 | # id_to_key - dict index -> key 201 | 202 | 203 | def make_items_tensor(items_embeddings_key_dict): 204 | keys = list(sorted(items_embeddings_key_dict.keys())) 205 | key_to_id = dict(zip(keys, range(len(keys)))) 206 | id_to_key = dict(zip(range(len(keys)), keys)) 207 | 208 | items_embeddings_id_dict = {} 209 | for k in items_embeddings_key_dict.keys(): 210 | items_embeddings_id_dict[key_to_id[k]] = items_embeddings_key_dict[k] 211 | items_embeddings_tensor = torch.stack( 212 | [items_embeddings_id_dict[i] for i in range(len(items_embeddings_id_dict))] 213 | ) 214 | return items_embeddings_tensor, key_to_id, id_to_key 215 | 216 | 217 | class ReplayBuffer: 218 | def __init__(self, buffer_size, layout): 219 | self.buffer = None 220 | self.idx = 0 221 | self.size = buffer_size 222 | self.layout = layout 223 | self.meta = {"step": []} 224 | self.flush() 225 | 226 | def flush(self): 227 | # state, action, reward, next_state 228 | del self.buffer 229 | self.buffer = [torch.zeros(i) for i in self.layout] 230 | self.idx = 0 231 | self.meta["step"] = [] 232 | 233 | def append(self, batch): 234 | state, action, reward, next_state, step = ( 235 | batch["state"], 236 | batch["action"], 237 | batch["reward"], 238 | batch["next_state"], 239 | batch["step"], 240 | ) 241 | self.meta["step"].append(step) 242 | lower = self.idx 243 | upper = state.size(0) + lower 244 | self.buffer[0][lower:upper] = state 245 | self.buffer[1][lower:upper] = action 246 | self.buffer[2][lower:upper] = reward 247 | self.buffer[3][lower:upper] = next_state 248 | self.idx = upper 249 | 250 | def get(self): 251 | state, action, reward, next_state = self.buffer 252 | batch = { 253 | "state": state, 254 | "action": action, 255 | "reward": reward, 256 | "next_state": next_state, 257 | "meta": self.meta, 258 | } 259 | return batch 260 | 261 | def len(self): 262 | return self.idx 263 | 264 | 265 | def get_base_batch(batch, device=torch.device("cuda"), done=True): 266 | b = [ 267 | batch["state"], 268 | batch["action"], 269 | batch["reward"].unsqueeze(1), 270 | batch["next_state"], 271 | ] 272 | if done: 273 | b.append(batch["done"].unsqueeze(1)) 274 | else: 275 | batch.append(torch.zeros_like(batch["reward"])) 276 | return [i.to(device) for i in b] 277 | -------------------------------------------------------------------------------- /recnn/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import algo, models, update 2 | from .models import * 3 | from .algo import * 4 | from .update import * 5 | -------------------------------------------------------------------------------- /recnn/nn/algo.py: -------------------------------------------------------------------------------- 1 | from recnn import utils 2 | from recnn.nn import update 3 | from recnn.nn.update import ChooseREINFORCE 4 | 5 | import torch 6 | import torch_optimizer as optim 7 | import copy 8 | 9 | 10 | """ 11 | Algorithms as is aren't much. Just classes with pre set up parameters, optimizers and stuff. 12 | """ 13 | 14 | 15 | class Algo: 16 | def __init__(self): 17 | self.nets = { 18 | "value_net": None, 19 | "policy_net": None, 20 | } 21 | 22 | self.optimizers = {"policy_optimizer": None, "value_optimizer": None} 23 | 24 | self.params = {"Some parameters here": None} 25 | 26 | self._step = 0 27 | 28 | self.debug = {} 29 | 30 | # by default it will not output anything 31 | # use torch.SummaryWriter instance if you want output 32 | self.writer = utils.misc.DummyWriter() 33 | 34 | self.device = torch.device("cpu") 35 | 36 | self.loss_layout = { 37 | "test": {"value": [], "policy": [], "step": []}, 38 | "train": {"value": [], "policy": [], "step": []}, 39 | } 40 | 41 | self.algorithm = None 42 | 43 | def update(self, batch, learn=True): 44 | return self.algorithm( 45 | batch, 46 | self.params, 47 | self.nets, 48 | self.optimizers, 49 | device=self.device, 50 | debug=self.debug, 51 | writer=self.writer, 52 | learn=learn, 53 | step=self._step, 54 | ) 55 | 56 | def to(self, device): 57 | self.nets = {k: v.to(device) for k, v in self.nets.items()} 58 | self.device = device 59 | return self 60 | 61 | def step(self): 62 | self._step += 1 63 | 64 | 65 | class DDPG(Algo): 66 | def __init__(self, policy_net, value_net): 67 | 68 | super(DDPG, self).__init__() 69 | 70 | self.algorithm = update.ddpg_update 71 | 72 | # these are target networks that we need for ddpg algorigm to work 73 | target_policy_net = copy.deepcopy(policy_net) 74 | target_value_net = copy.deepcopy(value_net) 75 | 76 | target_policy_net.eval() 77 | target_value_net.eval() 78 | 79 | # soft update 80 | utils.soft_update(value_net, target_value_net, soft_tau=1.0) 81 | utils.soft_update(policy_net, target_policy_net, soft_tau=1.0) 82 | 83 | # define optimizers 84 | value_optimizer = optim.Ranger( 85 | value_net.parameters(), lr=1e-5, weight_decay=1e-2 86 | ) 87 | policy_optimizer = optim.Ranger( 88 | policy_net.parameters(), lr=1e-5, weight_decay=1e-2 89 | ) 90 | 91 | self.nets = { 92 | "value_net": value_net, 93 | "target_value_net": target_value_net, 94 | "policy_net": policy_net, 95 | "target_policy_net": target_policy_net, 96 | } 97 | 98 | self.optimizers = { 99 | "policy_optimizer": policy_optimizer, 100 | "value_optimizer": value_optimizer, 101 | } 102 | 103 | self.params = { 104 | "gamma": 0.99, 105 | "min_value": -10, 106 | "max_value": 10, 107 | "policy_step": 10, 108 | "soft_tau": 0.001, 109 | } 110 | 111 | self.loss_layout = { 112 | "test": {"value": [], "policy": [], "step": []}, 113 | "train": {"value": [], "policy": [], "step": []}, 114 | } 115 | 116 | 117 | class TD3(Algo): 118 | def __init__(self, policy_net, value_net1, value_net2): 119 | 120 | super(TD3, self).__init__() 121 | 122 | self.algorithm = update.td3_update 123 | 124 | # these are target networks that we need for TD3 algorigm to work 125 | target_policy_net = copy.deepcopy(policy_net) 126 | target_value_net1 = copy.deepcopy(value_net1) 127 | target_value_net2 = 
copy.deepcopy(value_net2) 128 | 129 | target_policy_net.eval() 130 | target_value_net1.eval() 131 | target_value_net2.eval() 132 | 133 | # soft update 134 | utils.soft_update(value_net1, target_value_net1, soft_tau=1.0) 135 | utils.soft_update(value_net2, target_value_net2, soft_tau=1.0) 136 | utils.soft_update(policy_net, target_policy_net, soft_tau=1.0) 137 | 138 | # define optimizers 139 | value_optimizer1 = optim.Ranger( 140 | value_net1.parameters(), lr=1e-5, weight_decay=1e-2 141 | ) 142 | value_optimizer2 = optim.Ranger( 143 | value_net2.parameters(), lr=1e-5, weight_decay=1e-2 144 | ) 145 | policy_optimizer = optim.Ranger( 146 | policy_net.parameters(), lr=1e-5, weight_decay=1e-2 147 | ) 148 | 149 | self.nets = { 150 | "value_net1": value_net1, 151 | "target_value_net1": target_value_net1, 152 | "value_net2": value_net2, 153 | "target_value_net2": target_value_net2, 154 | "policy_net": policy_net, 155 | "target_policy_net": target_policy_net, 156 | } 157 | 158 | self.optimizers = { 159 | "policy_optimizer": policy_optimizer, 160 | "value_optimizer1": value_optimizer1, 161 | "value_optimizer2": value_optimizer2, 162 | } 163 | 164 | self.params = { 165 | "gamma": 0.99, 166 | "noise_std": 0.5, 167 | "noise_clip": 3, 168 | "soft_tau": 0.001, 169 | "policy_update": 10, 170 | "policy_lr": 1e-5, 171 | "value_lr": 1e-5, 172 | "actor_weight_init": 25e-2, 173 | "critic_weight_init": 6e-1, 174 | } 175 | 176 | self.loss_layout = { 177 | "test": {"value1": [], "value2": [], "policy": [], "step": []}, 178 | "train": {"value1": [], "value2": [], "policy": [], "step": []}, 179 | } 180 | 181 | 182 | class Reinforce(Algo): 183 | def __init__(self, policy_net, value_net): 184 | 185 | super(Reinforce, self).__init__() 186 | 187 | self.algorithm = update.reinforce_update 188 | 189 | # these are target networks that we need for ddpg algorigm to work 190 | target_policy_net = copy.deepcopy(policy_net) 191 | target_value_net = copy.deepcopy(value_net) 192 | 193 | target_policy_net.eval() 194 | target_value_net.eval() 195 | 196 | # soft update 197 | utils.soft_update(value_net, target_value_net, soft_tau=1.0) 198 | utils.soft_update(policy_net, target_policy_net, soft_tau=1.0) 199 | 200 | # define optimizers 201 | value_optimizer = optim.Ranger( 202 | value_net.parameters(), lr=1e-5, weight_decay=1e-2 203 | ) 204 | policy_optimizer = optim.Ranger( 205 | policy_net.parameters(), lr=1e-5, weight_decay=1e-2 206 | ) 207 | 208 | self.nets = { 209 | "value_net": value_net, 210 | "target_value_net": target_value_net, 211 | "policy_net": policy_net, 212 | "target_policy_net": target_policy_net, 213 | } 214 | 215 | self.optimizers = { 216 | "policy_optimizer": policy_optimizer, 217 | "value_optimizer": value_optimizer, 218 | } 219 | 220 | self.params = { 221 | "reinforce": ChooseREINFORCE(ChooseREINFORCE.basic_reinforce), 222 | "K": 10, 223 | "gamma": 0.99, 224 | "min_value": -10, 225 | "max_value": 10, 226 | "policy_step": 10, 227 | "soft_tau": 0.001, 228 | } 229 | 230 | self.loss_layout = { 231 | "test": {"value": [], "policy": [], "step": []}, 232 | "train": {"value": [], "policy": [], "step": []}, 233 | } 234 | -------------------------------------------------------------------------------- /recnn/nn/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.distributions import Categorical 5 | 6 | 7 | class AnomalyDetector(nn.Module): 8 | 9 | """ 10 | Anomaly detector used for 
debugging. Basically an auto encoder. 11 | P.S. You need to use different weights for different embeddings. 12 | """ 13 | 14 | def __init__(self): 15 | super(AnomalyDetector, self).__init__() 16 | self.ae = nn.Sequential( 17 | nn.Linear(128, 64), 18 | nn.ReLU(), 19 | nn.BatchNorm1d(64), 20 | nn.Linear(64, 32), 21 | nn.ReLU(), 22 | nn.BatchNorm1d(32), 23 | nn.Linear(32, 64), 24 | nn.ReLU(), 25 | nn.BatchNorm1d(64), 26 | nn.Linear(64, 128), 27 | nn.ReLU(), 28 | ) 29 | 30 | def forward(self, x): 31 | """""" 32 | return self.ae(x) 33 | 34 | def rec_error(self, x): 35 | error = torch.sum((x - self.ae(x)) ** 2, 1) 36 | if x.size(1) != 1: 37 | return error.detach() 38 | return error.item() 39 | 40 | 41 | class Actor(nn.Module): 42 | 43 | """ 44 | Vanilla actor. Takes state as an argument, returns action. 45 | """ 46 | 47 | def __init__(self, input_dim, action_dim, hidden_size, init_w=2e-1): 48 | super(Actor, self).__init__() 49 | 50 | self.drop_layer = nn.Dropout(p=0.5) 51 | 52 | self.linear1 = nn.Linear(input_dim, hidden_size) 53 | self.linear2 = nn.Linear(hidden_size, hidden_size) 54 | self.linear3 = nn.Linear(hidden_size, action_dim) 55 | 56 | self.linear3.weight.data.uniform_(-init_w, init_w) 57 | self.linear3.bias.data.uniform_(-init_w, init_w) 58 | 59 | def forward(self, state, tanh=False): 60 | """ 61 | :param action: nothing should be provided here. 62 | :param state: state 63 | :param tanh: whether to use tahn as action activation 64 | :return: action 65 | """ 66 | action = F.relu(self.linear1(state)) 67 | action = self.drop_layer(action) 68 | action = F.relu(self.linear2(action)) 69 | action = self.drop_layer(action) 70 | action = self.linear3(action) 71 | if tanh: 72 | action = F.tanh(action) 73 | return action 74 | 75 | 76 | class DiscreteActor(nn.Module): 77 | def __init__(self, input_dim, action_dim, hidden_size, init_w=0): 78 | super(DiscreteActor, self).__init__() 79 | 80 | self.linear1 = nn.Linear(input_dim, hidden_size) 81 | self.linear2 = nn.Linear(hidden_size, action_dim) 82 | 83 | self.saved_log_probs = [] 84 | self.rewards = [] 85 | self.correction = [] 86 | self.lambda_k = [] 87 | 88 | # What's action source? See this issue: https://github.com/awarebayes/RecNN/issues/7 89 | # by default {pi: pi, beta: beta} 90 | # you can change it to be like {pi: beta, beta: beta} as miracle24 suggested 91 | 92 | self.action_source = {"pi": "pi", "beta": "beta"} 93 | self.select_action = self._select_action 94 | 95 | def forward(self, inputs): 96 | x = inputs 97 | x = F.relu(self.linear1(x)) 98 | action_scores = self.linear2(x) 99 | return F.softmax(action_scores) 100 | 101 | def gc(self): 102 | del self.rewards[:] 103 | del self.saved_log_probs[:] 104 | del self.correction[:] 105 | del self.lambda_k[:] 106 | 107 | def _select_action(self, state, **kwargs): 108 | 109 | # for reinforce without correction only pi_probs is available. 110 | # the action source is ignored, since there is no beta 111 | 112 | pi_probs = self.forward(state) 113 | pi_categorical = Categorical(pi_probs) 114 | pi_action = pi_categorical.sample() 115 | self.saved_log_probs.append(pi_categorical.log_prob(pi_action)) 116 | return pi_probs 117 | 118 | def pi_beta_sample(self, state, beta, action, **kwargs): 119 | # 1. obtain probabilities 120 | # note: detach is to block gradient 121 | beta_probs = beta(state.detach(), action=action) 122 | pi_probs = self.forward(state) 123 | 124 | # 2. probabilities -> categorical distribution. 
125 | beta_categorical = Categorical(beta_probs) 126 | pi_categorical = Categorical(pi_probs) 127 | 128 | # 3. sample the actions 129 | # See this issue: https://github.com/awarebayes/RecNN/issues/7 130 | # usually it works like: 131 | # pi_action = pi_categorical.sample(); beta_action = beta_categorical.sample(); 132 | # but changing the action_source to {pi: beta, beta: beta} can be configured to be: 133 | # pi_action = beta_categorical.sample(); beta_action = beta_categorical.sample(); 134 | available_actions = { 135 | "pi": pi_categorical.sample(), 136 | "beta": beta_categorical.sample(), 137 | } 138 | pi_action = available_actions[self.action_source["pi"]] 139 | beta_action = available_actions[self.action_source["beta"]] 140 | 141 | # 4. calculate stuff we need 142 | pi_log_prob = pi_categorical.log_prob(pi_action) 143 | beta_log_prob = beta_categorical.log_prob(beta_action) 144 | 145 | return pi_log_prob, beta_log_prob, pi_probs 146 | 147 | def _select_action_with_correction( 148 | self, state, beta, action, writer, step, **kwargs 149 | ): 150 | pi_log_prob, beta_log_prob, pi_probs = self.pi_beta_sample(state, beta, action) 151 | 152 | # calculate correction 153 | corr = torch.exp(pi_log_prob) / torch.exp(beta_log_prob) 154 | 155 | writer.add_histogram("correction", corr, step) 156 | writer.add_histogram("pi_log_prob", pi_log_prob, step) 157 | writer.add_histogram("beta_log_prob", beta_log_prob, step) 158 | 159 | self.correction.append(corr) 160 | self.saved_log_probs.append(pi_log_prob) 161 | 162 | return pi_probs 163 | 164 | def _select_action_with_TopK_correction( 165 | self, state, beta, action, K, writer, step, **kwargs 166 | ): 167 | pi_log_prob, beta_log_prob, pi_probs = self.pi_beta_sample(state, beta, action) 168 | 169 | # calculate correction 170 | corr = torch.exp(pi_log_prob) / torch.exp(beta_log_prob) 171 | 172 | # calculate top K correction 173 | l_k = K * (1 - torch.exp(pi_log_prob)) ** (K - 1) 174 | 175 | writer.add_histogram("correction", corr, step) 176 | writer.add_histogram("l_k", l_k, step) 177 | writer.add_histogram("pi_log_prob", pi_log_prob, step) 178 | writer.add_histogram("beta_log_prob", beta_log_prob, step) 179 | 180 | self.correction.append(corr) 181 | self.lambda_k.append(l_k) 182 | self.saved_log_probs.append(pi_log_prob) 183 | 184 | return pi_probs 185 | 186 | 187 | class Critic(nn.Module): 188 | 189 | """ 190 | Vanilla critic. Takes state and action as an argument, returns value. 191 | """ 192 | 193 | def __init__(self, input_dim, action_dim, hidden_size, init_w=3e-5): 194 | super(Critic, self).__init__() 195 | 196 | self.drop_layer = nn.Dropout(p=0.5) 197 | 198 | self.linear1 = nn.Linear(input_dim + action_dim, hidden_size) 199 | self.linear2 = nn.Linear(hidden_size, hidden_size) 200 | self.linear3 = nn.Linear(hidden_size, 1) 201 | 202 | self.linear3.weight.data.uniform_(-init_w, init_w) 203 | self.linear3.bias.data.uniform_(-init_w, init_w) 204 | 205 | def forward(self, state, action): 206 | """""" 207 | value = torch.cat([state, action], 1) 208 | value = F.relu(self.linear1(value)) 209 | value = self.drop_layer(value) 210 | value = F.relu(self.linear2(value)) 211 | value = self.drop_layer(value) 212 | value = self.linear3(value) 213 | return value 214 | 215 | 216 | class bcqPerturbator(nn.Module): 217 | 218 | """ 219 | Batch constrained perturbative actor. Takes action as an argument, adjusts it. 
220 | """ 221 | 222 | def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-1): 223 | super(bcqPerturbator, self).__init__() 224 | 225 | self.drop_layer = nn.Dropout(p=0.5) 226 | 227 | self.linear1 = nn.Linear(num_inputs + num_actions, hidden_size) 228 | self.linear2 = nn.Linear(hidden_size, hidden_size) 229 | self.linear3 = nn.Linear(hidden_size, num_actions) 230 | 231 | self.linear3.weight.data.uniform_(-init_w, init_w) 232 | self.linear3.bias.data.uniform_(-init_w, init_w) 233 | 234 | def forward(self, state, action): 235 | """""" 236 | a = torch.cat([state, action], 1) 237 | a = F.relu(self.linear1(a)) 238 | a = self.drop_layer(a) 239 | a = F.relu(self.linear2(a)) 240 | a = self.drop_layer(a) 241 | a = self.linear3(a) 242 | return a + action 243 | 244 | 245 | class bcqGenerator(nn.Module): 246 | 247 | """ 248 | Batch constrained generator. Basically a VAE. 249 | """ 250 | 251 | def __init__(self, state_dim, action_dim, latent_dim): 252 | super(bcqGenerator, self).__init__() 253 | # encoder 254 | self.e1 = nn.Linear(state_dim + action_dim, 750) 255 | self.e2 = nn.Linear(750, 750) 256 | 257 | self.mean = nn.Linear(750, latent_dim) 258 | self.log_std = nn.Linear(750, latent_dim) 259 | 260 | # decoder 261 | self.d1 = nn.Linear(state_dim + latent_dim, 750) 262 | self.d2 = nn.Linear(750, 750) 263 | self.d3 = nn.Linear(750, action_dim) 264 | 265 | self.latent_dim = latent_dim 266 | self.normal = torch.distributions.Normal(0, 1) 267 | 268 | def forward(self, state, action): 269 | """""" 270 | # z is encoded state + action 271 | z = F.relu(self.e1(torch.cat([state, action], 1))) 272 | z = F.relu(self.e2(z)) 273 | 274 | mean = self.mean(z) 275 | # Clamped for numerical stability 276 | log_std = self.log_std(z).clamp(-4, 15) 277 | std = torch.exp(log_std) 278 | z = mean + std * self.normal.sample(std.size()).to( 279 | next(self.parameters()).device 280 | ) 281 | 282 | # u is decoded action 283 | u = self.decode(state, z) 284 | 285 | return u, mean, std 286 | 287 | def decode(self, state, z=None): 288 | # When sampling from the VAE, the latent vector is clipped to [-0.5, 0.5] 289 | if z is None: 290 | z = self.normal.sample([state.size(0), self.latent_dim]) 291 | z = z.clamp(-0.5, 0.5).to(next(self.parameters()).device) 292 | 293 | a = F.relu(self.d1(torch.cat([state, z], 1))) 294 | a = F.relu(self.d2(a)) 295 | return self.d3(a) 296 | -------------------------------------------------------------------------------- /recnn/nn/update/__init__.py: -------------------------------------------------------------------------------- 1 | from .misc import temporal_difference, value_update 2 | from .ddpg import ddpg_update 3 | from .td3 import td3_update 4 | from .bcq import bcq_update 5 | from .reinforce import ChooseREINFORCE, reinforce_update 6 | -------------------------------------------------------------------------------- /recnn/nn/update/bcq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from recnn import utils 5 | from recnn import data 6 | from recnn.utils import soft_update 7 | from recnn.nn.update import temporal_difference 8 | 9 | 10 | # batch, params, writer, debug, learn=True, step=-1 11 | def bcq_update( 12 | batch, 13 | params, 14 | nets, 15 | optimizer, 16 | device=torch.device("cpu"), 17 | debug=None, 18 | writer=utils.DummyWriter(), 19 | learn=False, 20 | step=-1, 21 | ): 22 | 23 | """ 24 | :param batch: batch [state, action, reward, next_state] returned by environment.
25 | :param params: dict of algorithm parameters. 26 | :param nets: dict of networks. 27 | :param optimizer: dict of optimizers 28 | :param device: torch.device 29 | :param debug: dictionary where debug data about actions is saved 30 | :param writer: torch.SummaryWriter 31 | :param learn: whether to learn on this step (used for testing) 32 | :param step: integer step for policy update 33 | :return: loss dictionary 34 | 35 | How parameters should look like:: 36 | 37 | params = { 38 | # algorithm parameters 39 | 'gamma' : 0.99, 40 | 'soft_tau' : 0.001, 41 | 'n_generator_samples': 10, 42 | 'perturbator_step' : 30, 43 | 44 | # learning rates 45 | 'perturbator_lr' : 1e-5, 46 | 'value_lr' : 1e-5, 47 | 'generator_lr' : 1e-3, 48 | } 49 | 50 | 51 | nets = { 52 | 'generator_net': models.bcqGenerator, 53 | 'perturbator_net': models.bcqPerturbator, 54 | 'target_perturbator_net': models.bcqPerturbator, 55 | 'value_net1': models.Critic, 56 | 'target_value_net1': models.Critic, 57 | 'value_net2': models.Critic, 58 | 'target_value_net2': models.Critic, 59 | } 60 | 61 | optimizer = { 62 | 'generator_optimizer': some optimizer 63 | 'policy_optimizer': some optimizer 64 | 'value_optimizer1': some optimizer 65 | 'value_optimizer2': some optimizer 66 | } 67 | 68 | 69 | """ 70 | 71 | if debug is None: 72 | debug = dict() 73 | state, action, reward, next_state, done = data.get_base_batch(batch, device=device) 74 | batch_size = done.size(0) 75 | 76 | # --------------------------------------------------------# 77 | # Variational Auto-Encoder Learning 78 | recon, mean, std = nets["generator_net"](state, action) 79 | recon_loss = F.mse_loss(recon, action) 80 | KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean() 81 | generator_loss = recon_loss + 0.5 * KL_loss 82 | 83 | if not learn: 84 | writer.add_histogram("generator_mean", mean, step) 85 | writer.add_histogram("generator_std", std, step) 86 | debug["recon"] = recon 87 | writer.add_figure( 88 | "reconstructed", utils.pairwise_distances_fig(recon[:50]), step 89 | ) 90 | 91 | if learn: 92 | optimizer["generator_optimizer"].zero_grad() 93 | generator_loss.backward() 94 | optimizer["generator_optimizer"].step() 95 | # --------------------------------------------------------# 96 | # Value Learning 97 | with torch.no_grad(): 98 | # p.s. 
repeat_interleave was added in torch 1.1 99 | # if an error pops up, run 'conda update pytorch' 100 | state_rep = torch.repeat_interleave( 101 | next_state, params["n_generator_samples"], 0 102 | ) 103 | sampled_action = nets["generator_net"].decode(state_rep) 104 | perturbed_action = nets["target_perturbator_net"](state_rep, sampled_action) 105 | target_Q1 = nets["target_value_net1"](state_rep, perturbed_action) 106 | target_Q2 = nets["target_value_net2"](state_rep, perturbed_action) 107 | target_value = 0.75 * torch.min(target_Q1, target_Q2) # soft clipped double-Q estimate: 108 | target_value += 0.25 * torch.max(target_Q1, target_Q2) # 0.75 * min + 0.25 * max 109 | target_value = target_value.view(batch_size, -1).max(1)[0].view(-1, 1) 110 | 111 | expected_value = temporal_difference( 112 | reward, done, params["gamma"], target_value 113 | ) 114 | 115 | value = nets["value_net1"](state, action) 116 | value_loss = torch.pow(value - expected_value.detach(), 2).mean() 117 | 118 | if learn: 119 | optimizer["value_optimizer1"].zero_grad() 120 | optimizer["value_optimizer2"].zero_grad() 121 | value_loss.backward() 122 | optimizer["value_optimizer1"].step() 123 | optimizer["value_optimizer2"].step() 124 | else: 125 | writer.add_histogram("value", value, step) 126 | writer.add_histogram("target_value", target_value, step) 127 | writer.add_histogram("expected_value", expected_value, step) 128 | writer.close() 129 | 130 | # --------------------------------------------------------# 131 | # Perturbator learning 132 | sampled_actions = nets["generator_net"].decode(state) 133 | perturbed_actions = nets["perturbator_net"](state, sampled_actions) 134 | perturbator_loss = -nets["value_net1"](state, perturbed_actions) 135 | if not learn: 136 | writer.add_histogram("perturbator_loss", perturbator_loss, step) 137 | perturbator_loss = perturbator_loss.mean() 138 | 139 | if learn: 140 | if step % params["perturbator_step"] == 0: 141 | optimizer["perturbator_optimizer"].zero_grad() 142 | perturbator_loss.backward() 143 | torch.nn.utils.clip_grad_norm_(nets["perturbator_net"].parameters(), -1, 1) 144 | optimizer["perturbator_optimizer"].step() 145 | 146 | soft_update( 147 | nets["value_net1"], nets["target_value_net1"], soft_tau=params["soft_tau"] 148 | ) 149 | soft_update( 150 | nets["value_net2"], nets["target_value_net2"], soft_tau=params["soft_tau"] 151 | ) 152 | soft_update( 153 | nets["perturbator_net"], 154 | nets["target_perturbator_net"], 155 | soft_tau=params["soft_tau"], 156 | ) 157 | else: 158 | debug["sampled_actions"] = sampled_actions 159 | debug["perturbed_actions"] = perturbed_actions 160 | writer.add_figure( 161 | "sampled_actions", utils.pairwise_distances_fig(sampled_actions[:50]), step 162 | ) 163 | writer.add_figure( 164 | "perturbed_actions", 165 | utils.pairwise_distances_fig(perturbed_actions[:50]), 166 | step, 167 | ) 168 | 169 | # --------------------------------------------------------# 170 | 171 | losses = { 172 | "value": value_loss.item(), 173 | "perturbator": perturbator_loss.item(), 174 | "generator": generator_loss.item(), 175 | "step": step, 176 | } 177 | 178 | utils.write_losses(writer, losses, kind="train" if learn else "test") 179 | return losses 180 | -------------------------------------------------------------------------------- /recnn/nn/update/ddpg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from recnn import utils 3 | from recnn import data 4 | from recnn.utils import soft_update 5 | from recnn.nn.update import value_update 6 | 7 | 8 | def
ddpg_update( 9 | batch, 10 | params, 11 | nets, 12 | optimizer, 13 | device=torch.device("cpu"), 14 | debug=None, 15 | writer=utils.DummyWriter(), 16 | learn=False, 17 | step=-1, 18 | ): 19 | 20 | """ 21 | :param batch: batch [state, action, reward, next_state] returned by environment. 22 | :param params: dict of algorithm parameters. 23 | :param nets: dict of networks. 24 | :param optimizer: dict of optimizers 25 | :param device: torch.device 26 | :param debug: dictionary where debug data about actions is saved 27 | :param writer: torch.SummaryWriter 28 | :param learn: whether to learn on this step (used for testing) 29 | :param step: integer step for policy update 30 | :return: loss dictionary 31 | 32 | How parameters should look like:: 33 | 34 | params = { 35 | 'gamma' : 0.99, 36 | 'min_value' : -10, 37 | 'max_value' : 10, 38 | 'policy_step': 3, 39 | 'soft_tau' : 0.001, 40 | 'policy_lr' : 1e-5, 41 | 'value_lr' : 1e-5, 42 | 'actor_weight_init': 3e-1, 43 | 'critic_weight_init': 6e-1, 44 | } 45 | nets = { 46 | 'value_net': models.Critic, 47 | 'target_value_net': models.Critic, 48 | 'policy_net': models.Actor, 49 | 'target_policy_net': models.Actor, 50 | } 51 | optimizer - { 52 | 'policy_optimizer': some optimizer 53 | 'value_optimizer': some optimizer 54 | } 55 | 56 | """ 57 | 58 | state, action, reward, next_state, _ = data.get_base_batch(batch, device=device) 59 | 60 | # --------------------------------------------------------# 61 | # Value Learning 62 | 63 | value_loss = value_update( 64 | batch, 65 | params, 66 | nets, 67 | optimizer, 68 | writer=writer, 69 | device=device, 70 | debug=debug, 71 | learn=learn, 72 | step=step, 73 | ) 74 | 75 | # --------------------------------------------------------# 76 | # Policy learning 77 | 78 | gen_action = nets["policy_net"](state) 79 | policy_loss = -nets["value_net"](state, gen_action) 80 | 81 | if not learn: 82 | debug["gen_action"] = gen_action 83 | writer.add_histogram("policy_loss", policy_loss, step) 84 | writer.add_figure( 85 | "next_action", utils.pairwise_distances_fig(gen_action[:50]), step 86 | ) 87 | policy_loss = policy_loss.mean() 88 | 89 | if learn and step % params["policy_step"] == 0: 90 | optimizer["policy_optimizer"].zero_grad() 91 | policy_loss.backward(retain_graph=True) 92 | torch.nn.utils.clip_grad_norm_(nets["policy_net"].parameters(), -1, 1) 93 | optimizer["policy_optimizer"].step() 94 | 95 | soft_update( 96 | nets["value_net"], nets["target_value_net"], soft_tau=params["soft_tau"] 97 | ) 98 | soft_update( 99 | nets["policy_net"], nets["target_policy_net"], soft_tau=params["soft_tau"] 100 | ) 101 | 102 | losses = {"value": value_loss.item(), "policy": policy_loss.item(), "step": step} 103 | utils.write_losses(writer, losses, kind="train" if learn else "test") 104 | return losses 105 | -------------------------------------------------------------------------------- /recnn/nn/update/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from recnn import utils 3 | from recnn import data 4 | 5 | 6 | def temporal_difference(reward, done, gamma, target): 7 | return reward + (1.0 - done) * gamma * target 8 | 9 | 10 | def value_update( 11 | batch, 12 | params, 13 | nets, 14 | optimizer, 15 | device=torch.device("cpu"), 16 | debug=None, 17 | writer=utils.DummyWriter(), 18 | learn=False, 19 | step=-1, 20 | ): 21 | """ 22 | Everything is the same as in ddpg_update 23 | """ 24 | 25 | state, action, reward, next_state, done = data.get_base_batch(batch, device=device) 26 | 27 | 
with torch.no_grad(): 28 | next_action = nets["target_policy_net"](next_state) 29 | target_value = nets["target_value_net"](next_state, next_action.detach()) 30 | expected_value = temporal_difference( 31 | reward, done, params["gamma"], target_value 32 | ) 33 | expected_value = torch.clamp( 34 | expected_value, params["min_value"], params["max_value"] 35 | ) 36 | 37 | value = nets["value_net"](state, action) 38 | 39 | value_loss = torch.pow(value - expected_value.detach(), 2).mean() 40 | 41 | if learn: 42 | optimizer["value_optimizer"].zero_grad() 43 | value_loss.backward(retain_graph=True) 44 | optimizer["value_optimizer"].step() 45 | 46 | elif not learn: 47 | debug["next_action"] = next_action 48 | writer.add_figure( 49 | "next_action", utils.pairwise_distances_fig(next_action[:50]), step 50 | ) 51 | writer.add_histogram("value", value, step) 52 | writer.add_histogram("target_value", target_value, step) 53 | writer.add_histogram("expected_value", expected_value, step) 54 | 55 | return value_loss 56 | -------------------------------------------------------------------------------- /recnn/nn/update/reinforce.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from recnn import utils 4 | from recnn import data 5 | from recnn.utils import soft_update 6 | from recnn.nn.update import value_update 7 | import gc 8 | 9 | 10 | class ChooseREINFORCE: 11 | def __init__(self, method=None): 12 | if method is None: 13 | method = ChooseREINFORCE.basic_reinforce 14 | self.method = method 15 | 16 | @staticmethod 17 | def basic_reinforce(policy, returns, *args, **kwargs): 18 | policy_loss = [] 19 | for log_prob, R in zip(policy.saved_log_probs, returns): 20 | policy_loss.append(-log_prob * R) # <- this line here 21 | policy_loss = torch.cat(policy_loss).sum() 22 | return policy_loss 23 | 24 | @staticmethod 25 | def reinforce_with_correction(policy, returns, *args, **kwargs): 26 | policy_loss = [] 27 | for corr, log_prob, R in zip( 28 | policy.correction, policy.saved_log_probs, returns 29 | ): 30 | policy_loss.append(corr * -log_prob * R) # <- this line here 31 | policy_loss = torch.cat(policy_loss).sum() 32 | return policy_loss 33 | 34 | @staticmethod 35 | def reinforce_with_TopK_correction(policy, returns, *args, **kwargs): 36 | policy_loss = [] 37 | for l_k, corr, log_prob, R in zip( 38 | policy.lambda_k, policy.correction, policy.saved_log_probs, returns 39 | ): 40 | policy_loss.append(l_k * corr * -log_prob * R) # <- this line here 41 | policy_loss = torch.cat(policy_loss).sum() 42 | return policy_loss 43 | 44 | def __call__(self, policy, optimizer, learn=True): 45 | R = 0 46 | 47 | returns = [] 48 | for r in policy.rewards[::-1]: 49 | R = r + 0.99 * R 50 | returns.insert(0, R) 51 | 52 | returns = torch.tensor(returns) 53 | returns = (returns - returns.mean()) / (returns.std() + 0.0001) 54 | 55 | policy_loss = self.method(policy, returns) 56 | 57 | if learn: 58 | optimizer.zero_grad() 59 | policy_loss.backward() 60 | optimizer.step() 61 | 62 | policy.gc() 63 | gc.collect() 64 | 65 | return policy_loss 66 | 67 | 68 | def reinforce_update( 69 | batch, 70 | params, 71 | nets, 72 | optimizer, 73 | device=torch.device("cpu"), 74 | debug=None, 75 | writer=utils.DummyWriter(), 76 | learn=True, 77 | step=-1, 78 | ): 79 | 80 | # Due to its mechanics, reinforce doesn't support testing! 
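    # (rewards and log-probs accumulate on the policy net between updates and are only cleared in policy.gc(), so a separate no-learn pass would pollute those buffers)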
81 | learn = True 82 | 83 | state, action, reward, next_state, done = data.get_base_batch(batch) 84 | 85 | predicted_probs = nets["policy_net"].select_action( 86 | state=state, action=action, K=params["K"], learn=learn, writer=writer, step=step 87 | ) 88 | writer.add_histogram("predicted_probs_std", predicted_probs.std(), step) 89 | writer.add_histogram("predicted_probs_mean", predicted_probs.mean(), step) 90 | mx = predicted_probs.max(dim=1).values 91 | writer.add_histogram("predicted_probs_max_mean", mx.mean(), step) 92 | writer.add_histogram("predicted_probs_max_std", mx.std(), step) 93 | reward = nets["value_net"](state, predicted_probs).detach() 94 | nets["policy_net"].rewards.append(reward.mean()) 95 | 96 | value_loss = value_update( 97 | batch, 98 | params, 99 | nets, 100 | optimizer, 101 | writer=writer, 102 | device=device, 103 | debug=debug, 104 | learn=True, 105 | step=step, 106 | ) 107 | 108 | if step % params["policy_step"] == 0 and step > 0: 109 | policy_loss = params["reinforce"]( 110 | nets["policy_net"], 111 | optimizer["policy_optimizer"], 112 | ) 113 | 114 | utils.soft_update( 115 | nets["value_net"], nets["target_value_net"], soft_tau=params["soft_tau"] 116 | ) 117 | utils.soft_update( 118 | nets["policy_net"], nets["target_policy_net"], soft_tau=params["soft_tau"] 119 | ) 120 | 121 | losses = { 122 | "value": value_loss.item(), 123 | "policy": policy_loss.item(), 124 | "step": step, 125 | } 126 | 127 | utils.write_losses(writer, losses, kind="train" if learn else "test") 128 | 129 | return losses 130 | -------------------------------------------------------------------------------- /recnn/nn/update/td3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from recnn import utils 3 | from recnn import data 4 | from recnn.utils import soft_update 5 | from recnn.nn.update import temporal_difference 6 | 7 | 8 | def td3_update( 9 | batch, 10 | params, 11 | nets, 12 | optimizer, 13 | device=torch.device("cpu"), 14 | debug=None, 15 | writer=utils.DummyWriter(), 16 | learn=False, 17 | step=-1, 18 | ): 19 | """ 20 | :param batch: batch [state, action, reward, next_state] returned by environment. 21 | :param params: dict of algorithm parameters. 22 | :param nets: dict of networks. 
23 | :param optimizer: dict of optimizers 24 | :param device: torch.device 25 | :param debug: dictionary where debug data about actions is saved 26 | :param writer: torch.SummaryWriter 27 | :param learn: whether to learn on this step (used for testing) 28 | :param step: integer step for policy update 29 | :return: loss dictionary 30 | 31 | How parameters should look like:: 32 | 33 | params = { 34 | 'gamma': 0.99, 35 | 'noise_std': 0.5, 36 | 'noise_clip': 3, 37 | 'soft_tau': 0.001, 38 | 'policy_update': 10, 39 | 40 | 'policy_lr': 1e-5, 41 | 'value_lr': 1e-5, 42 | 43 | 'actor_weight_init': 25e-2, 44 | 'critic_weight_init': 6e-1, 45 | } 46 | 47 | 48 | nets = { 49 | 'value_net1': models.Critic, 50 | 'target_value_net1': models.Critic, 51 | 'value_net2': models.Critic, 52 | 'target_value_net2': models.Critic, 53 | 'policy_net': models.Actor, 54 | 'target_policy_net': models.Actor, 55 | } 56 | 57 | optimizer = { 58 | 'policy_optimizer': some optimizer 59 | 'value_optimizer1': some optimizer 60 | 'value_optimizer2': some optimizer 61 | } 62 | 63 | 64 | """ 65 | 66 | if debug is None: 67 | debug = dict() 68 | state, action, reward, next_state, done = data.get_base_batch(batch, device=device) 69 | 70 | # --------------------------------------------------------# 71 | # Value Learning 72 | 73 | next_action = nets["target_policy_net"](next_state) 74 | noise = torch.normal(torch.zeros(next_action.size()), params["noise_std"]).to( 75 | device 76 | ) 77 | noise = torch.clamp(noise, -params["noise_clip"], params["noise_clip"]) 78 | next_action += noise 79 | 80 | with torch.no_grad(): 81 | target_q_value1 = nets["target_value_net1"](next_state, next_action) 82 | target_q_value2 = nets["target_value_net2"](next_state, next_action) 83 | target_q_value = torch.min(target_q_value1, target_q_value2) 84 | expected_q_value = temporal_difference( 85 | reward, done, params["gamma"], target_q_value 86 | ) 87 | 88 | q_value1 = nets["value_net1"](state, action) 89 | q_value2 = nets["value_net2"](state, action) 90 | 91 | value_criterion = torch.nn.MSELoss() 92 | value_loss1 = value_criterion(q_value1, expected_q_value.detach()) 93 | value_loss2 = value_criterion(q_value2, expected_q_value.detach()) 94 | 95 | if learn: 96 | optimizer["value_optimizer1"].zero_grad() 97 | value_loss1.backward() 98 | optimizer["value_optimizer1"].step() 99 | 100 | optimizer["value_optimizer2"].zero_grad() 101 | value_loss2.backward() 102 | optimizer["value_optimizer2"].step() 103 | else: 104 | debug["next_action"] = next_action 105 | writer.add_figure( 106 | "next_action", utils.pairwise_distances_fig(next_action[:50]), step 107 | ) 108 | writer.add_histogram("value1", q_value1, step) 109 | writer.add_histogram("value2", q_value2, step) 110 | writer.add_histogram("target_value", target_q_value, step) 111 | writer.add_histogram("expected_value", expected_q_value, step) 112 | 113 | # --------------------------------------------------------# 114 | # Policy learning 115 | 116 | gen_action = nets["policy_net"](state) 117 | policy_loss = nets["value_net1"](state, gen_action) 118 | policy_loss = -policy_loss 119 | 120 | if not learn: 121 | debug["gen_action"] = gen_action 122 | writer.add_figure( 123 | "gen_action", utils.pairwise_distances_fig(gen_action[:50]), step 124 | ) 125 | writer.add_histogram("policy_loss", policy_loss, step) 126 | 127 | policy_loss = policy_loss.mean() 128 | 129 | # delayed policy update 130 | if step % params["policy_update"] == 0 and learn: 131 | optimizer["policy_optimizer"].zero_grad() 132 | policy_loss.backward() 
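        # this branch is TD3's delayed policy update: the actor step below, and the Polyak-averaged soft_update of the target critics, only run every params["policy_update"] critic updates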
133 | torch.nn.utils.clip_grad_norm_(nets["policy_net"].parameters(), -1, 1) 134 | optimizer["policy_optimizer"].step() 135 | 136 | soft_update( 137 | nets["value_net1"], nets["target_value_net1"], soft_tau=params["soft_tau"] 138 | ) 139 | soft_update( 140 | nets["value_net2"], nets["target_value_net2"], soft_tau=params["soft_tau"] 141 | ) 142 | 143 | losses = { 144 | "value1": value_loss1.item(), 145 | "value2": value_loss2.item(), 146 | "policy": policy_loss.item(), 147 | "step": step, 148 | } 149 | utils.write_losses(writer, losses, kind="train" if learn else "test") 150 | return losses 151 | -------------------------------------------------------------------------------- /recnn/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import misc, plot 2 | from .misc import * 3 | from .plot import * 4 | -------------------------------------------------------------------------------- /recnn/utils/misc.py: -------------------------------------------------------------------------------- 1 | def soft_update(net, target_net, soft_tau=1e-2): 2 | for target_param, param in zip(target_net.parameters(), net.parameters()): 3 | target_param.data.copy_( 4 | target_param.data * (1.0 - soft_tau) + param.data * soft_tau 5 | ) 6 | 7 | 8 | def write_losses(writer, loss_dict, kind="train"): 9 | def write_loss(kind, key, item, step): 10 | writer.add_scalar(kind + "/" + key, item, global_step=step) 11 | 12 | step = loss_dict["step"] 13 | for k, v in loss_dict.items(): 14 | if k == "step": 15 | continue 16 | write_loss(kind, k, v, step) 17 | 18 | writer.close() 19 | 20 | 21 | class DummyWriter: 22 | def add_figure(self, *args, **kwargs): 23 | pass 24 | 25 | def add_histogram(self, *args, **kwargs): 26 | pass 27 | 28 | def add_scalar(self, *args, **kwargs): 29 | pass 30 | 31 | def add_scalars(self, *args, **kwargs): 32 | pass 33 | 34 | def close(self, *args, **kwargs): 35 | pass 36 | -------------------------------------------------------------------------------- /recnn/utils/plot.py: -------------------------------------------------------------------------------- 1 | from scipy.spatial import distance 2 | from scipy import ndimage 3 | import matplotlib.pyplot as plt 4 | import torch 5 | from scipy import stats 6 | import numpy as np 7 | 8 | 9 | def pairwise_distances_fig(embs): 10 | embs = embs.detach().cpu().numpy() 11 | similarity_matrix_cos = distance.cdist(embs, embs, "cosine") 12 | similarity_matrix_euc = distance.cdist(embs, embs, "euclidean") 13 | 14 | fig = plt.figure(figsize=(16, 10)) 15 | 16 | ax = fig.add_subplot(121) 17 | cax = ax.matshow(similarity_matrix_cos) 18 | fig.colorbar(cax) 19 | ax.set_title("Cosine") 20 | ax.axis("off") 21 | 22 | ax = fig.add_subplot(122) 23 | cax = ax.matshow(similarity_matrix_euc) 24 | fig.colorbar(cax) 25 | ax.set_title("Euclidian") 26 | ax.axis("off") 27 | 28 | fig.suptitle("Action pairwise distances") 29 | plt.close() 30 | return fig 31 | 32 | 33 | def pairwise_distances(embs): 34 | fig = pairwise_distances_fig(embs) 35 | fig.show() 36 | 37 | 38 | def smooth(scalars, weight): # Weight between 0 and 1 39 | last = scalars[0] # First value in the plot (first timestep) 40 | smoothed = list() 41 | for point in scalars: 42 | smoothed_val = last * weight + (1 - weight) * point # Calculate smoothed value 43 | smoothed.append(smoothed_val) # Save it 44 | last = smoothed_val # Anchor the last smoothed value 45 | 46 | return smoothed 47 | 48 | 49 | def smooth_gauss(arr, var): 50 | return ndimage.gaussian_filter1d(arr, 
var) 51 | 52 | 53 | class Plotter: 54 | def __init__(self, loss, style): 55 | self.loss = loss 56 | self.style = style 57 | self.smoothing = lambda x: smooth_gauss(x, 4) 58 | 59 | def set_smoothing_func(self, f): 60 | self.smoothing = f 61 | 62 | def plot_loss(self): 63 | for row in self.style: 64 | fig, axes = plt.subplots(1, len(row), figsize=(16, 6)) 65 | if len(row) == 1: 66 | axes = [axes] 67 | for col in range(len(row)): 68 | key = row[col] 69 | axes[col].set_title(key) 70 | axes[col].plot( 71 | self.loss["train"]["step"], 72 | self.smoothing(self.loss["train"][key]), 73 | "b-", 74 | label="train", 75 | ) 76 | axes[col].plot( 77 | self.loss["test"]["step"], 78 | self.loss["test"][key], 79 | "r-.", 80 | label="test", 81 | ) 82 | plt.legend() 83 | plt.show() 84 | 85 | def log_loss(self, key, item, test=False): 86 | kind = "train" 87 | if test: 88 | kind = "test" 89 | self.loss[kind][key].append(item) 90 | 91 | def log_losses(self, losses, test=False): 92 | for key, val in losses.items(): 93 | self.log_loss(key, val, test) 94 | 95 | @staticmethod 96 | def kde_reconstruction_error( 97 | ad, gen_actions, true_actions, device=torch.device("cpu") 98 | ): 99 | def rec_score(actions): 100 | return ( 101 | ad.rec_error(torch.tensor(actions).to(device).float()) 102 | .detach() 103 | .cpu() 104 | .numpy() 105 | ) 106 | 107 | true_scores = rec_score(true_actions) 108 | gen_scores = rec_score(gen_actions) 109 | 110 | true_kernel = stats.gaussian_kde(true_scores) 111 | gen_kernel = stats.gaussian_kde(gen_scores) 112 | 113 | x = np.linspace(0, 1000, 100) 114 | probs_true = true_kernel(x) 115 | probs_gen = gen_kernel(x) 116 | fig = plt.figure(figsize=(16, 10)) 117 | ax = fig.add_subplot(111) 118 | ax.plot(x, probs_true, "-b", label="true dist") 119 | ax.plot(x, probs_gen, "-r", label="generated dist") 120 | ax.legend() 121 | return fig 122 | 123 | @staticmethod 124 | def plot_kde_reconstruction_error(*args, **kwargs): 125 | fig = Plotter.kde_reconstruction_error(*args, **kwargs) 126 | fig.show() 127 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy 2 | # numpy==1.19.2 3 | tensorboard 4 | pymilvus 5 | # torch==1.6.0 # install with conda instead https://pytorch.org/ 6 | tqdm 7 | torch_optimizer 8 | matplotlib 9 | jupyterthemes 10 | pandas 11 | scikit_learn 12 | sphinx_rtd_theme 13 | streamlit 14 | pytest 15 | -------------------------------------------------------------------------------- /res/article_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awarebayes/RecNN/e61c764295ccd6d52f716b4bf9a0898a707ef85a/res/article_1.png -------------------------------------------------------------------------------- /res/article_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awarebayes/RecNN/e61c764295ccd6d52f716b4bf9a0898a707ef85a/res/article_2.png -------------------------------------------------------------------------------- /res/logo big.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awarebayes/RecNN/e61c764295ccd6d52f716b4bf9a0898a707ef85a/res/logo big.png -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | 
description-file = readme.md 3 | [pycodestyle] 4 | max-line-length = 160 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="recnn", 5 | version="0.1", 6 | description="A Python toolkit for Reinforced News Recommendation.", 7 | long_description="A Python toolkit for Reinforced News Recommendation.", 8 | author="Mike Watts", 9 | author_email="awarebayes@gmail.com", 10 | license="Apache 2.0", 11 | packages=find_packages(), 12 | install_requires=[ 13 | "torch", 14 | "numpy", 15 | ], 16 | url="https://github.com/awarebayes/RecNN", 17 | zip_safe=False, 18 | classifiers=[ 19 | "Development Status :: 4 - Beta", 20 | "Intended Audience :: Developers", # the intended audience is developers 21 | "Topic :: Software Development :: Build Tools", 22 | "License :: OSI Approved :: MIT License", # Again, pick a license 23 | "Programming Language :: Python :: 3", # Specify which Python versions you want to support 24 | "Programming Language :: Python :: 3.6", 25 | "Programming Language :: Python :: 3.7", 26 | "Programming Language :: Python :: 3.8", 27 | ], 28 | ) 29 | --------------------------------------------------------------------------------