├── .github └── workflows │ └── main.yml ├── .pylintrc ├── LICENSE ├── README.md ├── demo ├── README.md ├── cycle_time_50M_users_5seeds.csv ├── demo_examples.ipynb ├── requirements.txt ├── rs_performance_evaluation.ipynb ├── simulation_cycle_time.ipynb ├── synthetic_data_generation.ipynb ├── synthetic_data_generation_time.ipynb └── synthetic_gen_time.csv ├── docs ├── Makefile ├── make.bat └── source │ ├── conf.py │ ├── index.rst │ └── pages │ ├── modules.rst │ ├── response.rst │ └── utils.rst ├── experiments ├── Amazon │ ├── RU_amazon_processing.ipynb │ ├── RU_load_dataset.ipynb │ ├── RU_train_test_split.ipynb │ └── amazon.py ├── Movielens │ ├── ml20.py │ ├── movie_preprocessing.ipynb │ └── traint_test_split.ipynb ├── Netflix │ ├── RU_load_datasets.ipynb │ ├── RU_netflix_processing.ipynb │ ├── RU_train_test_split.ipynb │ └── netflix.py ├── RU_amazon_embeddings.ipynb ├── RU_amazon_generators.ipynb ├── RU_amazon_response.ipynb ├── RU_amazon_surface.ipynb ├── RU_netflix_embeddings.ipynb ├── RU_netflix_generators.ipynb ├── RU_netflix_response.ipynb ├── RU_netflix_surface.ipynb ├── RU_sim_validation.ipynb ├── RU_validation_quality_control.ipynb ├── custom_response_function.ipynb ├── datautils.py ├── generators_fit.ipynb ├── generators_generate.ipynb ├── generators_generate.py ├── items_generation.ipynb ├── movielens_embeddings.ipynb ├── movielens_generators.ipynb ├── movielens_quality_control.ipynb ├── movielens_response.ipynb ├── movielens_scenario.ipynb ├── movielens_surface.ipynb ├── response_models │ ├── data │ │ └── popular_items_popularity.parquet │ │ │ ├── .part-00000-8ac17189-ecfe-473d-a9c5-d72f7f86d119-c000.snappy.parquet.crc │ │ │ ├── _SUCCESS │ │ │ └── part-00000-8ac17189-ecfe-473d-a9c5-d72f7f86d119-c000.snappy.parquet │ ├── task_1_popular_items.py │ └── utils.py ├── simulator_time.ipynb ├── simulator_time.py └── transformers.py ├── notebooks ├── embeddings.ipynb └── pipeline.ipynb ├── poetry.lock ├── pyproject.toml ├── pytest.ini ├── sim4rec ├── __init__.py ├── modules │ ├── __init__.py │ ├── embeddings.py │ ├── evaluation.py │ ├── generator.py │ ├── selectors.py │ └── simulator.py ├── params │ ├── __init__.py │ └── params.py ├── recommenders │ ├── ucb.py │ └── utils.py ├── response │ ├── __init__.py │ └── response.py └── utils │ ├── __init__.py │ ├── convert.py │ ├── session_handler.py │ └── uce.py └── tests ├── __init__.py ├── conftest.py ├── test_embeddings.py ├── test_evaluation.py ├── test_generators.py ├── test_responses.py ├── test_selectors.py └── test_simulator.py /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | # Controls when the workflow will run 4 | on: 5 | # Triggers the workflow on pull request events but only for the main branch 6 | pull_request: 7 | branches: [main] 8 | 9 | # Allows you to run this workflow manually from the Actions tab 10 | workflow_dispatch: 11 | 12 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel 13 | jobs: 14 | run_tests: 15 | runs-on: ubuntu-20.04 16 | strategy: 17 | matrix: 18 | python-version: ["3.8", "3.9"] 19 | steps: 20 | - uses: actions/checkout@v2 21 | - uses: actions/setup-python@v2 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | - name: Install package 25 | run: | 26 | python -m venv venv 27 | . 
./venv/bin/activate 28 | pip install --upgrade pip wheel poetry pycodestyle pylint pytest-cov 29 | 30 | poetry cache clear pypi --all 31 | poetry lock 32 | poetry install 33 | - name: Build docs 34 | run: | 35 | . ./venv/bin/activate 36 | cd docs 37 | make clean html 38 | - name: pycodestyle 39 | run: | 40 | . ./venv/bin/activate 41 | pycodestyle --ignore=E203,E231,E501,W503,W605,E122,E125 --max-doc-length=160 sim4rec tests 42 | - name: pylint 43 | run: | 44 | . ./venv/bin/activate 45 | pylint --rcfile=.pylintrc sim4rec 46 | - name: pytest 47 | run: | 48 | . ./venv/bin/activate 49 | pytest --cov=sim4rec --cov-report=term-missing --doctest-modules sim4rec --cov-fail-under=89 tests 50 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [GENERAL] 2 | fail-under=10 3 | 4 | [TYPECHECK] 5 | ignored-modules=pyspark.sql.functions,pyspark,torch,ignite,ignite.engine,pyspark.sql,statsmodels.stats.proportion,numba 6 | disable=bad-option-value,no-else-return 7 | 8 | [MESSAGES CONTROL] 9 | disable=anomalous-backslash-in-string,bad-continuation,missing-module-docstring,wrong-import-order,protected-access,consider-using-generator,consider-using-enumerate,invalid-name 10 | 11 | [FORMAT] 12 | ignore-long-lines=^.*([А-Яа-я]|>{3}|\.{3}|\\{2}|https://).*$ 13 | max-line-length=100 14 | good-names=df,i,j,k,n,_,x,y 15 | 16 | [SIMILARITIES] 17 | min-similarity-lines=33 18 | ignore-comments=yes 19 | ignore-docstrings=yes 20 | ignore-imports=yes 21 | 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 Sber AI Lab 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Simulator 2 | 3 | Simulator is a framework for training and evaluating recommendation algorithms on real or synthetic data. The framework is based on the pyspark library for working with big data. 4 | As part of the simulation process, the framework includes data generators, response functions and other tools that allow flexible use of the simulator. 
5 | 6 | # Table of contents 7 | 8 | * [Installation](#installation) 9 | * [Quickstart](#quickstart) 10 | * [Examples](#examples) 11 | * [Build from sources](#build-from-sources) 12 | * [Building documentation](#compile-documentation) 13 | * [Running tests](#tests) 14 | 15 | ## Installation 16 | 17 | ```bash 18 | pip install sim4rec 19 | ``` 20 | 21 | If the installation takes too long, try 22 | ```bash 23 | pip install sim4rec --use-deprecated=legacy-resolver 24 | ``` 25 | 26 | To install dependencies with poetry run 27 | 28 | ```bash 29 | pip install --upgrade pip wheel poetry 30 | poetry install 31 | ``` 32 | 33 | ## Quickstart 34 | 35 | The following example shows how to use simulator to train model iteratively by refitting recommendation algorithm on the new upcoming history log 36 | 37 | ```python 38 | import numpy as np 39 | import pandas as pd 40 | 41 | import pyspark.sql.types as st 42 | from pyspark.ml import PipelineModel 43 | 44 | from sim4rec.modules import RealDataGenerator, Simulator, EvaluateMetrics 45 | from sim4rec.response import NoiseResponse, BernoulliResponse 46 | from sim4rec.recommenders.ucb import UCB 47 | from sim4rec.utils import pandas_to_spark 48 | 49 | LOG_SCHEMA = st.StructType([ 50 | st.StructField('user_idx', st.LongType(), True), 51 | st.StructField('item_idx', st.LongType(), True), 52 | st.StructField('relevance', st.DoubleType(), False), 53 | st.StructField('response', st.IntegerType(), False) 54 | ]) 55 | 56 | users_df = pd.DataFrame( 57 | data=np.random.normal(0, 1, size=(100, 15)), 58 | columns=[f'user_attr_{i}' for i in range(15)] 59 | ) 60 | items_df = pd.DataFrame( 61 | data=np.random.normal(1, 1, size=(30, 10)), 62 | columns=[f'item_attr_{i}' for i in range(10)] 63 | ) 64 | history_df = pandas_to_spark(pd.DataFrame({ 65 | 'user_idx' : [1, 10, 10, 50], 66 | 'item_idx' : [4, 25, 26, 25], 67 | 'relevance' : [1.0, 0.0, 1.0, 1.0], 68 | 'response' : [1, 0, 1, 1] 69 | }), schema=LOG_SCHEMA) 70 | 71 | users_df['user_idx'] = np.arange(len(users_df)) 72 | items_df['item_idx'] = np.arange(len(items_df)) 73 | 74 | users_df = pandas_to_spark(users_df) 75 | items_df = pandas_to_spark(items_df) 76 | 77 | user_gen = RealDataGenerator(label='users_real') 78 | item_gen = RealDataGenerator(label='items_real') 79 | user_gen.fit(users_df) 80 | item_gen.fit(items_df) 81 | _ = user_gen.generate(100) 82 | _ = item_gen.generate(30) 83 | 84 | sim = Simulator( 85 | user_gen=user_gen, 86 | item_gen=item_gen, 87 | data_dir='test_simulator', 88 | user_key_col='user_idx', 89 | item_key_col='item_idx', 90 | log_df=history_df 91 | ) 92 | 93 | noise_resp = NoiseResponse(mu=0.5, sigma=0.2, outputCol='__noise') 94 | br = BernoulliResponse(inputCol='__noise', outputCol='response') 95 | pipeline = PipelineModel(stages=[noise_resp, br]) 96 | 97 | model = UCB() 98 | model.fit(log=history_df) 99 | 100 | evaluator = EvaluateMetrics( 101 | userKeyCol='user_idx', 102 | itemKeyCol='item_idx', 103 | predictionCol='relevance', 104 | labelCol='response', 105 | mllib_metrics=['areaUnderROC'] 106 | ) 107 | 108 | metrics = [] 109 | for i in range(10): 110 | users = sim.sample_users(0.1).cache() 111 | 112 | recs = model.predict( 113 | log=sim.log, k=5, users=users, items=items_df, filter_seen_items=True 114 | ).cache() 115 | 116 | true_resp = ( 117 | sim.sample_responses( 118 | recs_df=recs, 119 | user_features=users, 120 | item_features=items_df, 121 | action_models=pipeline, 122 | ) 123 | .select("user_idx", "item_idx", "relevance", "response") 124 | .cache() 125 | ) 126 | 127 | 
sim.update_log(true_resp, iteration=i) 128 | 129 | metrics.append(evaluator(true_resp)) 130 | 131 | model.fit(sim.log.drop("relevance").withColumnRenamed("response", "relevance")) 132 | 133 | users.unpersist() 134 | recs.unpersist() 135 | true_resp.unpersist() 136 | ``` 137 | 138 | ## Examples 139 | 140 | You can find useful examples in the 'notebooks' folder; they demonstrate how to use synthetic data generators and composite generators, evaluate the generators' results, iteratively refit the recommendation algorithm, use response functions, and more. 141 | 142 | Experiments with different datasets and a tutorial on writing custom response functions can be found in the 'experiments' folder. 143 | 144 | ## Case studies 145 | 146 | Case studies prepared for the demo track are available in the 'demo' directory. 147 | 148 | 1. Synthetic data generation 149 | 2. Long-term RS performance evaluation 150 | 151 | ## Build from sources 152 | 153 | ```bash 154 | poetry build 155 | pip install ./dist/sim4rec-0.0.1-py3-none-any.whl 156 | ``` 157 | 158 | ## Compile documentation 159 | 160 | ```bash 161 | cd docs 162 | make clean && make html 163 | ``` 164 | 165 | ## Tests 166 | 167 | The pytest Python library is used for testing. To run the tests for all modules, run the following command from the repository root directory: 168 | 169 | ```bash 170 | pytest 171 | ``` 172 | 173 | ## Licence 174 | Sim4Rec is distributed under the [Apache License Version 2.0](https://github.com/sb-ai-lab/Sim4Rec/blob/main/LICENSE); 175 | however, the SDV package imported by Sim4Rec for synthetic data generation 176 | is distributed under the [Business Source License (BSL) 1.1](https://github.com/sdv-dev/SDV/blob/master/LICENSE). 177 | 178 | Synthetic tabular data generation is not the purpose of the Sim4Rec framework. 179 | Sim4Rec provides an API and wrappers to run simulations with synthetic data, but the method of synthetic data generation is determined by the user. 180 | The SDV package is imported for illustration purposes and can be replaced by another synthetic data generation solution. 181 | 182 | Thus, the synthetic data generation and quality evaluation functionality provided by the SDV library, 183 | namely `SDVDataGenerator` from [generator.py](sim4rec/modules/generator.py) and `evaluate_synthetic` from [evaluation.py](sim4rec/modules/evaluation.py), 184 | should only be used for non-production purposes according to the SDV licence. 185 | -------------------------------------------------------------------------------- /demo/README.md: -------------------------------------------------------------------------------- 1 | # Case studies for demo 2 | Cases for the demo: 3 | 1. [Synthetic data generation](https://github.com/sb-ai-lab/Sim4Rec/blob/main/demo/synthetic_data_generation.ipynb) 4 | 5 | Generation of synthetic users based on real data. 6 | 7 | 2. [Long-term RS performance evaluation](https://github.com/sb-ai-lab/Sim4Rec/blob/main/demo/rs_performance_evaluation.ipynb) 8 | 9 | A simple simulation pipeline for the long-term performance evaluation of a recommender system. 
10 | 11 | # Table of contents 12 | 13 | * [Installation](#installation) 14 | * [Synthetic data generation pipeline](#synthetic-data-generation-pipeline) 15 | * [Long-term RS performance evaluation pipeline](#long-term-RS-performance-evaluation-pipeline) 16 | 17 | ## Installation 18 | 19 | Install dependencies with poetry run 20 | 21 | ```bash 22 | pip install --upgrade pip wheel poetry 23 | poetry install 24 | ``` 25 | 26 | Please install `rs_datasets` to import MovieLens dataset for the synthetic_data_generation.ipynb 27 | ```bash 28 | pip install rs_datasets 29 | ``` 30 | 31 | ## Synthetic data generation pipeline 32 | 1. Fit non-negative ALS to real data containing user interactions 33 | 2. Obtain vector representations of real users 34 | 3. Fit CopulaGAN to non-negative ALS embeddings of real users 35 | 4. Generate synthetic user feature vectors with CopulaGAN 36 | 5. Evaluate the quality of the synthetic user profiles 37 | 38 | ## Long-term RS performance evaluation pipeline 39 | 1. Before running the simulation cycle: 40 | - Initialize and fit recommender model to historical data 41 | - Construct response function pipeline and fit it to items 42 | - Initialize simulator 43 | 2. Run the simulation cycle: 44 | - Sample users using a simulator 45 | - Get recommendations for sampled users from the recommender system 46 | - Get responses to recommended items from the response function 47 | - Update user interaction history 48 | - Measure the quality of the recommender system 49 | - Refit the recommender model 50 | 3. After running the simulation cycle: 51 | - Get recommendations for all users from the recommender system 52 | - Get responses to recommended items from the response function 53 | - Measure the quality of the recommender system trained in the simulation cycle 54 | 55 | -------------------------------------------------------------------------------- /demo/cycle_time_50M_users_5seeds.csv: -------------------------------------------------------------------------------- 1 | index,sample_users_time,sample_responses_time,update_log_time,metrics_time,iter_time,users_num 2 | 0,11.15279746055603,1.0836639404296875,4.48405385017395,10.712761878967283,39.2559826374054,1001660 3 | 1,6.915727138519287,1.144331455230713,2.2269480228424072,7.514208316802978,29.782680749893192,3450869 4 | 2,9.624614000320436,1.9051272869110107,2.5454702377319336,8.540910720825195,38.522289514541626,5897983 5 | 3,10.06862211227417,2.072542428970337,2.2254538536071777,10.562509536743164,41.597875118255615,8347862 6 | 4,10.775094747543335,3.0510990619659424,3.777689456939697,10.119225978851318,44.3191511631012,10798519 7 | 5,10.705054521560667,2.914153575897217,3.795412063598633,8.681732177734375,40.83335065841675,13252619 8 | 6,8.72693920135498,3.5513224601745605,3.9841458797454834,11.071977853775024,42.67783713340759,15702401 9 | 7,11.697437286376953,4.566967010498047,3.8343541622161865,9.24762773513794,44.81257128715515,18153435 10 | 8,9.278602838516237,4.395270109176637,4.1952335834503165,10.463558673858644,45.67495346069336,20603933 11 | 9,9.882215976715088,4.600301027297974,4.617482662200928,7.5286338329315186,39.09629273414612,23053131 12 | 10,8.283291816711426,4.990021228790283,4.571868658065796,6.711239576339723,39.4278380870819,25502789 13 | 11,7.111920833587647,5.580391883850098,5.736581087112428,7.497943639755249,39.17783188819885,27948898 14 | 12,7.849193811416626,6.416064977645874,5.17004656791687,6.878835916519165,41.43364119529724,30399766 15 | 
13,7.180546760559082,6.605871915817263,6.033704519271852,7.945231914520264,43.03808045387268,32848393 16 | 14,8.542397260665894,6.9637413024902335,6.764723062515259,6.6079912185668945,43.448081731796265,35299228 17 | 15,6.880044460296631,7.377933502197266,7.637020826339723,7.429396629333496,42.84182572364807,37749079 18 | 16,6.994188547134399,7.831934452056885,7.783137798309326,7.6338841915130615,44.403237104415894,40197676 19 | 17,11.496174812316895,9.800821781158447,7.58228325843811,7.516530275344849,53.62795400619507,42648904 20 | 18,7.79708194732666,9.787442684173584,7.138163089752197,8.8120014667511,50.9351236820221,45099701 21 | 19,8.43359112739563,9.523060560226439,13.3839693069458,7.699332237243652,55.89987015724182,47551011 22 | 20,7.566620349884032,9.683169603347777,8.543299436569214,7.9015352725982675,48.8129768371582,50000000 23 | 0,7.190418481826782,0.707650899887085,2.28092622756958,8.264355659484862,29.746066331863407,1000052 24 | 1,7.42564058303833,1.5093376636505127,1.7751114368438718,7.884857416152954,31.61107087135315,3449012 25 | 2,9.563541889190674,1.6901743412017822,2.4463682174682617,10.842523097991943,37.391902208328254,5900034 26 | 3,6.786839962005615,2.04491662979126,3.031137704849243,8.219735622406006,31.60140872001648,8349078 27 | 4,6.9660797119140625,2.544528245925904,3.1702497005462646,6.7007696628570566,30.183650732040405,10797805 28 | 5,6.701206207275392,3.25626277923584,3.102719783782959,8.533268451690674,32.68921160697937,13247233 29 | 6,6.420436382293701,3.4489002227783203,3.6940829753875732,8.104418039321901,33.3744637966156,15697274 30 | 7,7.689843416213987,3.54607367515564,4.6508872509002686,6.960585594177246,35.21638894081116,18147315 31 | 8,7.622220993041992,4.401278734207153,4.334021091461182,7.482393980026245,37.205557823181145,20600284 32 | 9,7.009465456008911,4.764593839645387,4.193192958831787,6.897244215011598,36.11890935897827,23048885 33 | 10,7.099967002868652,5.215404510498047,4.7635498046875,6.3894689083099365,36.10999631881714,25499101 34 | 11,6.69516921043396,5.744151592254639,6.602449655532838,6.736882925033568,38.676912784576416,27948933 35 | 12,7.4403440952301025,6.7572083473205575,6.003504753112793,7.172186851501465,40.154664516448975,30401023 36 | 13,6.998932123184204,6.20236349105835,5.672366619110107,7.1271512508392325,39.302847623825066,32851486 37 | 14,6.915180683135986,7.822373628616332,6.337413787841798,7.708374500274657,42.24116182327271,35302731 38 | 15,7.066570281982423,8.524396419525146,6.580228328704834,6.166517734527588,41.87582015991211,37752524 39 | 16,6.589472532272339,7.94525933265686,7.2967987060546875,6.8301634788513175,42.06322407722472,40202517 40 | 17,7.117594242095947,8.471821308135986,7.566959619522095,7.366473436355591,43.776485204696655,42652356 41 | 18,7.089627504348755,8.872769355773926,9.707370281219482,7.241292476654053,46.907626152038574,45099245 42 | 19,8.168758392333983,8.996371507644652,7.879936456680298,7.036686420440674,46.9917676448822,47551931 43 | 20,7.7413179874420175,9.81250810623169,8.90994095802307,8.575214385986326,50.75287199020386,50000000 44 | 0,14.046154499053955,1.1173858642578125,3.2084243297576904,13.64061188697815,53.393492698669434,999235 45 | 1,11.949691534042358,1.360381841659546,3.777711153030397,13.68442177772522,49.955246925354004,3446144 46 | 2,12.614464521408081,2.5385425090789795,3.680132150650024,9.67790412902832,48.45165801048279,5896817 47 | 3,7.58543848991394,2.63399624824524,3.206125497817993,7.958815097808838,35.34187650680542,8347080 48 | 
4,8.984705686569212,2.444612979888916,3.4341411590576167,8.198475122451782,35.859378814697266,10796748 49 | 5,8.170023202896116,3.02512526512146,3.568270444869995,9.304476022720335,37.61494064331055,13245268 50 | 6,8.56131911277771,3.1345784664154053,3.905285120010376,7.664520978927612,37.156663179397576,15695938 51 | 7,7.557111024856567,3.866744756698608,4.302062511444092,7.748353481292725,37.79162979125977,18148543 52 | 8,7.450805902481079,4.136418581008911,4.425943851470947,8.722140789031982,38.14530491828919,20598543 53 | 9,8.626715898513794,4.416030645370483,6.006159543991089,8.419605731964111,42.09596753120423,23050550 54 | 10,8.417001008987429,5.234817981719972,4.37753963470459,8.036226987838745,41.50021719932556,25502173 55 | 11,7.920273303985598,5.881896018981934,5.305891990661621,8.453716993331911,46.445900678634644,27950737 56 | 12,8.933857202529909,6.833164215087892,7.10662579536438,7.8169307708740225,46.64496946334839,30400600 57 | 13,8.500087738037111,7.582361221313477,7.917228937149048,8.558687925338745,49.47708082199097,32849515 58 | 14,8.397745609283447,6.668765306472777,6.310772657394409,8.957489013671875,46.73683547973633,35301288 59 | 15,8.917690038681028,7.176633358001709,7.5546300411224365,8.910341739654541,49.477648973464966,37750324 60 | 16,9.176238298416136,7.644198417663574,7.773294448852539,8.591369390487671,50.09724831581116,40197372 61 | 17,9.06108808517456,9.202136516571043,9.213278055191038,9.451592922210693,54.75792002677918,42649165 62 | 18,9.47390842437744,9.417816400527952,7.360036134719849,8.879366874694824,53.35044598579407,45096715 63 | 19,9.15760898590088,9.921532869338991,10.288001775741575,9.495866537094118,56.78333520889282,47548777 64 | 20,16.184381008148193,9.112372875213623,8.641405582427979,9.695455789566038,61.253098726272576,50000000 65 | 0,6.146031141281128,0.6092917919158936,1.4211950302124023,5.874315500259399,23.75724244117737,1000900 66 | 1,6.513831615447997,1.0445129871368408,1.6927015781402588,5.854241371154785,25.26718783378601,3450816 67 | 2,6.894704580307008,1.874687910079956,2.5109894275665283,6.259218692779541,27.11310958862305,5902225 68 | 3,6.681285381317139,1.8760371208190918,2.541154146194458,6.62699556350708,28.388039350509644,8350918 69 | 4,6.579586744308473,2.4730172157287598,3.8537116050720215,6.489072561264037,30.003851413726807,10802310 70 | 5,7.6521430015563965,3.9225046634674072,5.195086479187012,9.061530828475954,40.26849174499512,13251873 71 | 6,8.345972299575807,3.4538462162017822,3.8924813270568848,9.229451656341553,40.28138494491577,15700939 72 | 7,8.912295579910277,3.58895206451416,5.256728410720825,9.334285974502565,42.195481300354,18152575 73 | 8,8.239080429077147,4.429729700088501,4.059387922286986,8.210181474685669,40.50559496879577,20603259 74 | 9,9.318784713745117,4.529031991958618,4.944356679916382,8.913965463638307,43.021955728530884,23055321 75 | 10,7.842668533325195,5.421630144119263,5.258700609207153,8.723150491714478,42.82481050491333,25502107 76 | 11,9.253255605697634,5.528379440307617,5.403992176055908,9.261732339859007,46.435377836227424,27951857 77 | 12,8.715843677520754,6.042918205261231,6.0950984954833975,10.580655574798584,48.58100295066834,30401979 78 | 13,10.208436012268066,6.864079475402832,6.1206886768341064,9.028311014175417,49.34014439582825,32852565 79 | 14,9.33082103729248,7.135979652404785,6.008606195449829,9.459828853607178,49.91624927520752,35301288 80 | 15,9.565973281860352,8.052979946136475,6.41303563117981,9.887627601623537,51.34547996520996,37751471 81 | 
16,10.069446802139284,7.863379001617432,10.273854017257689,10.316685199737547,56.69612789154053,40200631 82 | 17,9.850059986114502,8.273949146270754,6.978187322616577,9.329088449478151,53.352537870407104,42651732 83 | 18,10.059386253356934,8.744035959243773,8.563298940658571,15.396541357040405,69.58081459999084,45102261 84 | 19,14.43708348274231,9.591934204101562,8.743757486343384,15.680865287780762,75.4746916294098,47551779 85 | 20,14.997735500335693,9.706062555313109,9.138372182846071,14.083873987197876,74.37188625335693,50000000 86 | 0,8.0369713306427,0.6770808696746826,1.845033884048462,8.336313486099241,30.60119819641113,998952 87 | 1,7.773632764816284,1.0916354656219482,2.2429637908935547,8.532743692398071,32.707759380340576,3451097 88 | 2,7.750965356826782,1.7475388050079346,3.01460862159729,7.785831451416016,32.291831016540534,5901157 89 | 3,9.205931425094603,2.2260079383850098,3.6819036006927486,8.288788080215454,39.349198341369636,8352896 90 | 4,8.438814640045166,2.592901468276977,3.4753494262695312,7.9989845752716064,36.04001092910767,10802364 91 | 5,7.781528472900391,3.16336727142334,4.267394781112672,8.632305383682251,39.529292345047004,13254049 92 | 6,7.7201831340789795,3.8348591327667236,4.173544406890868,8.322942733764647,37.802941083908074,15702498 93 | 7,7.653884172439575,4.125317096710205,4.42400050163269,8.175558090209961,39.1270318031311,18156231 94 | 8,9.750343084335329,4.440631151199341,4.769840717315674,8.377973794937134,42.31133937835693,20606681 95 | 9,8.555236101150513,5.129318952560425,5.573054552078247,8.976478576660156,43.04151272773743,23057180 96 | 10,9.08205509185791,5.223078727722168,5.408283948898315,8.495746850967407,43.826560258865356,25507343 97 | 11,8.577333211898804,6.8158326148986825,6.167778730392456,9.532485723495485,47.65059852600098,27957800 98 | 12,10.51621079444885,6.293447494506836,5.392074823379517,11.002318143844604,49.65189242362976,30407546 99 | 13,9.719135522842409,6.897059202194214,8.705862045288086,8.648547649383545,49.54546356201172,32858175 100 | 14,8.366340398788452,7.308839082717895,7.413874626159668,8.331423759460451,48.2678542137146,35307198 101 | 15,8.730891704559326,7.683967351913452,7.198543071746826,9.149252891540527,49.236247062683105,37756384 102 | 16,10.17529797554016,7.613230228424072,8.572442531585692,8.933055639266966,51.33104181289673,40205103 103 | 17,9.738165616989138,8.498561859130861,7.544682502746582,8.592431545257567,51.26230978965759,42652737 104 | 18,8.937159538269043,9.955829858779909,8.697373151779175,9.49044942855835,56.309436082839966,45101848 105 | 19,9.159666299819945,8.806410551071169,8.19353199005127,9.085792779922484,53.98657464981079,47550366 106 | 20,8.804518222808838,9.738995313644411,10.185393571853636,9.620959997177122,56.602333784103394,50000000 107 | -------------------------------------------------------------------------------- /demo/requirements.txt: -------------------------------------------------------------------------------- 1 | rs_datasets -------------------------------------------------------------------------------- /demo/synthetic_gen_time.csv: -------------------------------------------------------------------------------- 1 | n_samples,"time, s" 2 | 100000.0,38.741114616394036 3 | 595000.0,192.2492127418518 4 | 1090000.0,311.1856393814087 5 | 1585000.0,397.0943524837494 6 | 2080000.0,553.0572288036346 7 | 2575000.0,670.9549901485442 8 | 3070000.0,723.5816826820374 9 | 3565000.0,856.1593241691589 10 | 4060000.0,977.9992072582244 11 | 4555000.0,1063.0632655620577 12 | 
5050000.0,1089.8609247207644 13 | 5545000.0,1322.955791473389 14 | 6040000.0,1417.3582293987274 15 | 6535000.0,1430.8968422412872 16 | 7030000.0,1486.5244340896606 17 | 7525000.0,1647.060962677002 18 | 8020000.0,1754.9148001670835 19 | 8515000.0,1840.006804227829 20 | 9010000.0,1771.7790541648865 21 | 9505000.0,1953.510559797287 22 | 10000000.0,2095.720278024673 23 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | 22 | gh-deploy: 23 | @make html 24 | @ghp-import _build/html -p -o -n -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('../..')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'Sim4Rec' 21 | copyright = '2022, Sberbank AI Laboratory' 22 | author = 'Alexey Vasilev, Anna Volodkevich, Andrey Gurov, Elizaveta Stavinova, Anton Lysenko' 23 | 24 | # The full version, including alpha/beta/rc tags 25 | release = '0.0.1' 26 | 27 | 28 | # -- General configuration --------------------------------------------------- 29 | 30 | # Add any Sphinx extension module names here, as strings. They can be 31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 32 | # ones. 33 | extensions = [ 34 | 'sphinx.ext.autodoc' 35 | ] 36 | 37 | autoclass_content = 'both' 38 | 39 | # Add any paths that contain templates here, relative to this directory. 40 | templates_path = ['_templates'] 41 | 42 | # List of patterns, relative to source directory, that match files and 43 | # directories to ignore when looking for source files. 44 | # This pattern also affects html_static_path and html_extra_path. 45 | exclude_patterns = [] 46 | 47 | 48 | # -- Options for HTML output ------------------------------------------------- 49 | 50 | # The theme to use for HTML and HTML Help pages. See the documentation for 51 | # a list of builtin themes. 52 | # 53 | html_theme = 'sphinx_rtd_theme' 54 | 55 | # Add any paths that contain custom static files (such as style sheets) here, 56 | # relative to this directory. They are copied after the builtin static files, 57 | # so a file named "default.css" will overwrite the builtin "default.css". 58 | # html_static_path = ['_static'] 59 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. Simulator documentation master file, created by 2 | sphinx-quickstart on Mon Sep 19 11:23:53 2022. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to Sim4Rec's documentation! 7 | ===================================== 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | pages/modules 14 | pages/response 15 | pages/utils 16 | 17 | 18 | Indices and tables 19 | ================== 20 | 21 | * :ref:`genindex` 22 | * :ref:`modindex` 23 | * :ref:`search` 24 | -------------------------------------------------------------------------------- /docs/source/pages/modules.rst: -------------------------------------------------------------------------------- 1 | Modules 2 | ======= 3 | 4 | .. automodule:: sim4rec.modules 5 | 6 | 7 | Generators 8 | __________ 9 | 10 | Generators serve for generating either synthetic or real data for the simulation process. 11 | All of the generators are derived from the ``GeneratorBase`` base class, and to implement 12 | your own generator you must inherit from it. Basically, a generator is fitted on a provided 13 | dataset (in the case of a real data generator it just remembers it), then it creates a population to 14 | sample from with a given number of rows by calling the ``generate()`` method, and samples from 15 | this population with the ``sample()`` method. Note that ``sample()`` takes a fraction 16 | of the population size. 17 | 18 | If a user is interested in using multiple generators at once (e.g. 
modelling multiple groups 19 | of users or mixing results from different generating models), it may be useful to look 20 | at ``CompositeGenerator``, which can handle a list of generators and has a proportion mixing parameter 21 | that controls the weights of particular generators at sampling and generation time. 22 | 23 | .. autoclass:: sim4rec.modules.GeneratorBase 24 | :members: 25 | 26 | .. autoclass:: sim4rec.modules.RealDataGenerator 27 | :members: 28 | 29 | .. autoclass:: sim4rec.modules.SDVDataGenerator 30 | :members: 31 | 32 | .. autoclass:: sim4rec.modules.CompositeGenerator 33 | :members: 34 | 35 | 36 | Embeddings 37 | __________ 38 | 39 | Embeddings can be utilized in the case of high-dimensional or highly sparse data, and they should be 40 | applied before running the main simulator pipeline. An autoencoder estimator and transformer 41 | are implemented here in case the existing Spark methods are not enough for your purpose. A usage 42 | example can be found in the `notebooks` directory. 43 | 44 | .. autoclass:: sim4rec.modules.EncoderEstimator 45 | :members: 46 | 47 | .. autoclass:: sim4rec.modules.EncoderTransformer 48 | :members: 49 | 50 | 51 | Items selectors 52 | _______________ 53 | 54 | These Spark transformers are used to assign items to given users while making the candidate 55 | pairs for prediction by a recommendation system. Using a selector in your pipelines is optional 56 | if the recommendation algorithm can recommend any item with no restrictions. You should 57 | implement your own selector if custom item selection logic is required for certain users. A selector 58 | could implement some business rules (e.g. some items are not available for a user), be a simple 59 | recommendation model that generates candidates, or create user-specific items (e.g. price offers) 60 | online. To implement a custom selector, you can derive from the ``ItemSelectionEstimator`` base 61 | class, to implement some pre-calculation logic, and from ``ItemSelectionTransformer`` to perform pair 62 | creation. Both classes are inherited from Spark's Estimator and Transformer classes, and to define the 63 | fit() and transform() methods one can just overwrite ``_fit()`` and ``_transform()``. 64 | 65 | .. autoclass:: sim4rec.modules.ItemSelectionEstimator 66 | :members: 67 | 68 | .. autoclass:: sim4rec.modules.ItemSelectionTransformer 69 | :members: 70 | 71 | .. autoclass:: sim4rec.modules.CrossJoinItemEstimator 72 | :members: 73 | :private-members: _fit 74 | 75 | .. autoclass:: sim4rec.modules.CrossJoinItemTransformer 76 | :members: 77 | :private-members: _transform 78 | 79 | 80 | Simulator 81 | _________ 82 | 83 | The simulator class provides a way to handle the simulation process by connecting different 84 | parts of the module, such as generators and response pipelines, and saving the results to a 85 | given directory. The simulation process consists of the following steps: 86 | 87 | - Sampling random real or synthetic users 88 | - Creation of candidate user-item pairs for a recommendation algorithm 89 | - Prediction by a recommendation system 90 | - Evaluating responses to the given recommendations 91 | - Updating the history log 92 | - Metrics evaluation 93 | - Refitting the recommendation model with the new data 94 | 95 | Some of the steps can be skipped depending on the task you perform. For example, you don't need 96 | the second step if your algorithm doesn't use user-item pairs as an input, and you don't need to refit 97 | the model if you just want to evaluate it on some data. 
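A minimal sketch of one such iteration is shown below; it mirrors the Quickstart example from the repository README, and the names ``sim``, ``model``, ``items_df``, ``pipeline`` and ``evaluator`` are illustrative placeholders for objects created beforehand, not part of the API itself.

.. code-block:: python

    # One iteration of the simulation cycle (illustrative sketch):
    # `sim` is a Simulator, `model` is a fitted recommender (e.g. UCB),
    # `items_df` holds item features, `pipeline` is a response PipelineModel,
    # and `evaluator` is an EvaluateMetrics instance.
    users = sim.sample_users(0.1).cache()              # sample users
    recs = model.predict(                              # candidate pairs + prediction
        log=sim.log, k=5, users=users,
        items=items_df, filter_seen_items=True
    ).cache()
    resp = sim.sample_responses(                       # evaluate responses
        recs_df=recs,
        user_features=users,
        item_features=items_df,
        action_models=pipeline,
    ).select("user_idx", "item_idx", "relevance", "response").cache()
    sim.update_log(resp, iteration=0)                  # update the history log
    metrics = evaluator(resp)                          # measure quality
    model.fit(                                         # refit on the updated log
        log=sim.log.drop("relevance").withColumnRenamed("response", "relevance")
    )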
For more usage examples, please refer to the `notebooks` and `experiments` directories. 98 | 99 | .. autoclass:: sim4rec.modules.Simulator 100 | :members: 101 | 102 | 103 | Evaluation 104 | __________ 105 | 106 | .. autoclass:: sim4rec.modules.EvaluateMetrics 107 | :members: 108 | :special-members: __call__ 109 | 110 | .. autofunction:: sim4rec.modules.evaluate_synthetic 111 | 112 | .. autofunction:: sim4rec.modules.ks_test 113 | 114 | .. autofunction:: sim4rec.modules.kl_divergence 115 | -------------------------------------------------------------------------------- /docs/source/pages/response.rst: -------------------------------------------------------------------------------- 1 | Response functions 2 | ================== 3 | 4 | .. automodule:: sim4rec.response 5 | 6 | 7 | Response functions are used to model user behaviour on items and to simulate any 8 | kind of reaction that a user performs. For example, it can be a binary classification 9 | model that determines whether a user clicked on an item, or the rating that a user gave 10 | to an item. 11 | 12 | 13 | Base classes 14 | ____________ 15 | 16 | All of the existing response functions are built on top of the underlying base classes 17 | ``ActionModelEstimator`` and ``ActionModelTransformer``, which follow the logic of 18 | Spark's Estimator and Transformer classes. To implement a custom response function, 19 | one can inherit from the base classes: 20 | 21 | * ``ActionModelEstimator`` if any learning logic is necessary 22 | * ``ActionModelTransformer`` for performing model inference 23 | 24 | The base classes are inherited from Spark's Estimator and Transformer, and to define the fit() 25 | or transform() logic one should overwrite ``_fit()`` and ``_transform()`` respectively. 26 | Note that these base classes are convenient for implementing your own response function, but they are not 27 | required: any proper Spark estimators/transformers can be used to create a response pipeline. 28 | 29 | .. autoclass:: sim4rec.response.ActionModelEstimator 30 | :members: 31 | 32 | .. autoclass:: sim4rec.response.ActionModelTransformer 33 | :members: 34 | 35 | 36 | Response functions 37 | __________________ 38 | 39 | .. autoclass:: sim4rec.response.ConstantResponse 40 | :members: 41 | 42 | .. autoclass:: sim4rec.response.NoiseResponse 43 | :members: 44 | 45 | .. autoclass:: sim4rec.response.CosineSimilatiry 46 | :members: 47 | 48 | .. autoclass:: sim4rec.response.BernoulliResponse 49 | :members: 50 | 51 | .. autoclass:: sim4rec.response.ParametricResponseFunction 52 | :members: 53 | -------------------------------------------------------------------------------- /docs/source/pages/utils.rst: -------------------------------------------------------------------------------- 1 | Utils 2 | ===== 3 | 4 | .. automodule:: sim4rec.utils 5 | 6 | 7 | Dataframe conversion 8 | ______________________ 9 | 10 | .. autofunction:: sim4rec.utils.pandas_to_spark 11 | 12 | 13 | Session 14 | ______________________ 15 | 16 | .. autofunction:: sim4rec.utils.session_handler.get_spark_session 17 | 18 | .. autoclass:: sim4rec.utils.session_handler.State 19 | :members: 20 | 21 | Exceptions 22 | __________ 23 | 24 | .. autoclass:: sim4rec.utils.NotFittedError 25 | :members: 26 | 27 | .. autoclass:: sim4rec.utils.EmptyDataFrameError 28 | :members: 29 | 30 | 31 | Transformers 32 | ____________ 33 | 34 | .. autoclass:: sim4rec.utils.VectorElementExtractor 35 | :members: 36 | 37 | 38 | File management 39 | _______________ 40 | 41 | .. autofunction:: sim4rec.utils.save 42 | .. autofunction:: sim4rec.utils.load
43 | -------------------------------------------------------------------------------- /experiments/Amazon/RU_train_test_split.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "6f3723f0-15da-4e41-baff-d1fac30b6da1", 6 | "metadata": {}, 7 | "source": [ 8 | "# Illustration of splitting the Amazon dataset into train/test/split\n", 9 | "\n", 10 | "The feature preprocessing for the Amazon dataset that is used here has been moved into a separate\n", 11 | "file,\n", 12 | "and an illustration of how it works is given in a separate notebook.\n", 13 | "\n", 14 | "" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "id": "8854aa20-25e6-4468-935b-46fbb1dfb0a3", 20 | "metadata": {}, 21 | "source": [ 22 | "### $\textbf{Contents}$:\n", 23 | "\n", 24 | "### $\textbf{I. Loading the data }$\n", 25 | "#### - Reading the data from disk;\n", 26 | "---" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "id": "d9d4aafa-b8c5-4e7d-b8c0-3a79ea3fae26", 32 | "metadata": {}, 33 | "source": [ 34 | "### $\textbf{II. Splitting the data for the experiment}$\n", 35 | "### To split the data into $\it{train/test/split}$, the original *df_rating* dataset is divided by the quantiles $\mathbb{q}$ of the $\it{timestamp}$ attribute for feature generation:\n", 36 | "#### $\it{rating}_{t}$ = *df_rating*$[0, \mathbb{q}_{t}]$, where $\mathbb{q}_{train}=0.5$, $\mathbb{q}_{val}=0.75$, $\mathbb{q}_{test}=1$:\n", 37 | "#### - $\it{rating}_{train}$ = *df_rating*$[0, 0.5]$;\n", 38 | "#### - $\it{rating}_{val}$ = *df_rating*$[0, 0.75]$;\n", 39 | "#### - $\it{rating}_{test}$ = *df_rating*$[0, 1]$;\n", 40 | "### Next, for each of the intervals {$\it{rating}_{train}$, $\it{rating}_{val}$, $\it{rating}_{test}$}, the corresponding user and item features are generated:\n", 41 | "#### - $\it{items}_{t}$, $\it{users}_{t}$, $\it{rating}_{t}$ = data_processing(movies, $\it{rating}_{t}$, tags), $t \in \{\it{train}, \it{val}, \it{test}\}$;\n", 42 | "### After that, the final ratings are formed:\n", 43 | "#### - $\it{rating}_{train}$ = $\it{rating}_{train}$ = *df_rating*$[0, 0.5]$;\n", 44 | "#### - $\it{rating}_{val}$ = $\it{rating}_{val}$[$\mathbb{q}>\mathbb{q}_{train}$] = *df_rating*$(0.5, 0.75]$;\n", 45 | "#### - $\it{rating}_{test}$ = $\it{rating}_{test}$[$\mathbb{q}>\mathbb{q}_{val}$] = *df_rating*$(0.75, 1]$;\n", 46 | "\n", 47 | "\n", 48 | "That is, if we use the timestamps from quantile 0 to quantile 0.75 to generate the features for the validation set, then as ratings we take only the scores\n", 49 | "    from the 0.5 to the 0.75 quantile. Similarly for the test set: all timestamps are used for feature generation, but only the scores from the 0.75 to the 1\n", 50 | "    quantile are used as ratings.\n", 51 | "\n", 52 | "
" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "id": "0a23f90f-056c-44c3-b18c-13e8064a4bac", 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "name": "stderr", 63 | "output_type": "stream", 64 | "text": [ 65 | "[nltk_data] Downloading package punkt to\n", 66 | "[nltk_data] /data/home/agurov/nltk_data...\n", 67 | "[nltk_data] Package punkt is already up-to-date!\n", 68 | "[nltk_data] Downloading package stopwords to\n", 69 | "[nltk_data] /data/home/agurov/nltk_data...\n", 70 | "[nltk_data] Package stopwords is already up-to-date!\n", 71 | "[nltk_data] Downloading package words to\n", 72 | "[nltk_data] /data/home/agurov/nltk_data...\n", 73 | "[nltk_data] Package words is already up-to-date!\n", 74 | "/data/home/agurov/.conda/envs/sber3.8/lib/python3.8/site-packages/pyspark/sql/pandas/functions.py:394: UserWarning: In Python 3.6+ and Spark 3.0+, it is preferred to specify type hints for pandas UDF instead of specifying pandas UDF type which will be deprecated in the future releases. See SPARK-28264 for more details.\n", 75 | " warnings.warn(\n" 76 | ] 77 | } 78 | ], 79 | "source": [ 80 | "import pandas as pd\n", 81 | "import numpy as np\n", 82 | "import re\n", 83 | "import itertools\n", 84 | "import tqdm\n", 85 | "\n", 86 | "from amazon import data_processing, get_spark_session" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "id": "91a4cb15-e49f-4ae8-a569-14703cf9f7a1", 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "file_name_parquet = \"all.parquet\"\n", 97 | "save_path = \"hdfs://namenode:9000/Sber_data/Amazon/final_data\"\n", 98 | "\n", 99 | "spark = get_spark_session(1)" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "id": "4625524c-b878-4dc8-84bc-b41d7ffddd1c", 105 | "metadata": {}, 106 | "source": [ 107 | "I. Загрузка данных" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "id": "6aec812a-a519-409e-9258-5d331cc9633c", 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "df = pd.read_parquet(file_name_parquet)\n", 118 | "df_sp = spark.createDataFrame(df)" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "id": "d09161ae-21f5-4a78-8c42-376c20009a4a", 124 | "metadata": {}, 125 | "source": [ 126 | "### II. 
Разбиение данных для эксперимента" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "id": "648d7c9e-b0b4-4bdc-9878-a0e46b5832bd", 132 | "metadata": {}, 133 | "source": [ 134 | "### Разбиение df_rating на train/test/validation части по квантилям timestamp:\n", 135 | "#### - train [0, 0.5]\n", 136 | "#### - validation [0, 0.75]\n", 137 | "#### - test [0, 1.]" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 3, 143 | "id": "5dbd2348-4188-4a68-8fe6-b0d463703e0b", 144 | "metadata": {}, 145 | "outputs": [ 146 | { 147 | "name": "stdout", 148 | "output_type": "stream", 149 | "text": [ 150 | "Quantile: 0.5 - 1189555200.0, 0.75 - 1301443200.0\n" 151 | ] 152 | } 153 | ], 154 | "source": [ 155 | "q_50, q_75 = df_sp.approxQuantile(\"timestamp\", [0.5, 0.75], 0)\n", 156 | "print(f\"Quantile: 0.5 - {q_50}, 0.75 - {q_75}\")" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "id": "fc1aa2c3-29fe-4278-9b33-08014ede15da", 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [ 166 | "df_train = df_procc.filter(sf.col(\"timestamp\") <= q_50)\n", 167 | "df_train.cache()" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "id": "41174c08-956e-47a6-881c-1164a6f4f5e0", 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "df_val = df_procc.filter(sf.col(\"timestamp\") <= q_75)\n", 178 | "df_val.cache()" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "id": "0fd08ad6-4d4f-489d-b530-a2a4ce26fc26", 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "df_test = df_procc\n", 189 | "df_test.cache()" 190 | ] 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "id": "03ef300d-d957-443c-b8f9-413354e574ea", 195 | "metadata": {}, 196 | "source": [ 197 | "#### Генерация признаков по временным промежуткам" 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "id": "427898bd-811b-4c96-b85a-0003929bfc69", 203 | "metadata": {}, 204 | "source": [ 205 | "#### Train data" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "id": "5581f3cd-2ee3-4dae-bcba-bc227f319463", 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "df_users_train, df_items_train, df_rating_train = data_processing(df_train)\n", 216 | "\n", 217 | "df_items_train.write.parquet(os.path.join(save_path, r'train/items.parquet'))\n", 218 | "df_users_train.write.parquet(os.path.join(save_path, r'train/users.parquet'))\n", 219 | "df_rating_train.write.parquet(os.path.join(save_path, r'train/rating.parquet'))\n", 220 | "\n", 221 | "df_train.unpersist()" 222 | ] 223 | }, 224 | { 225 | "cell_type": "markdown", 226 | "id": "72abd8d2-eec1-4aab-9de5-51da78fe1530", 227 | "metadata": {}, 228 | "source": [ 229 | "#### Validation data" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "id": "0df22394-6d7f-4499-bedd-2d35399de706", 236 | "metadata": {}, 237 | "outputs": [], 238 | "source": [ 239 | "df_users_val, df_items_val, df_rating_val = data_processing(df_val)\n", 240 | "\n", 241 | "df_rating_val = df_rating_val.filter(sf.col(\"timestamp\") > q_50)\n", 242 | "df_items_val.write.parquet(os.path.join(save_path, r'val/items.parquet'))\n", 243 | "df_users_val.write.parquet(os.path.join(save_path, r'val/users.parquet'))\n", 244 | "df_rating_val.write.parquet(os.path.join(save_path, r'val/rating.parquet'))\n", 245 | "\n", 246 | "df_val.unpersist()" 247 | ] 248 | }, 249 | { 250 | "cell_type": 
"markdown", 251 | "id": "31ebc6f5-4f65-4b99-a261-dbf02ea7d132", 252 | "metadata": {}, 253 | "source": [ 254 | "#### Test data" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": null, 260 | "id": "4fa40c1f-c8ce-48c2-9d79-e58abc78752c", 261 | "metadata": {}, 262 | "outputs": [], 263 | "source": [ 264 | "df_users_test, df_items_test, df_rating_test = data_processing(df_test)\n", 265 | "\n", 266 | "df_rating_test = df_rating_test.filter(sf.col(\"timestamp\") > q_75)\n", 267 | "df_items_test.write.parquet(os.path.join(save_path, r'val/items.parquet'))\n", 268 | "df_users_test.write.parquet(os.path.join(save_path, r'val/users.parquet'))\n", 269 | "df_rating_test.write.parquet(os.path.join(save_path, r'val/rating.parquet'))\n", 270 | "\n", 271 | "df_test.unpersist()" 272 | ] 273 | } 274 | ], 275 | "metadata": { 276 | "kernelspec": { 277 | "display_name": "Python [conda env:.conda-sber3.8]", 278 | "language": "python", 279 | "name": "conda-env-.conda-sber3.8-py" 280 | }, 281 | "language_info": { 282 | "codemirror_mode": { 283 | "name": "ipython", 284 | "version": 3 285 | }, 286 | "file_extension": ".py", 287 | "mimetype": "text/x-python", 288 | "name": "python", 289 | "nbconvert_exporter": "python", 290 | "pygments_lexer": "ipython3", 291 | "version": "3.8.13" 292 | } 293 | }, 294 | "nbformat": 4, 295 | "nbformat_minor": 5 296 | } 297 | -------------------------------------------------------------------------------- /experiments/Amazon/amazon.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"]="" 3 | 4 | import pandas as pd 5 | import numpy as np 6 | import requests 7 | import gzip 8 | import json 9 | import re 10 | 11 | import matplotlib.pyplot as plt 12 | from tqdm.notebook import tqdm 13 | 14 | from typing import List, Optional 15 | 16 | from bs4 import BeautifulSoup 17 | from nltk.tokenize import TreebankWordTokenizer, WhitespaceTokenizer 18 | 19 | import nltk 20 | nltk.download('punkt') 21 | nltk.download('stopwords') 22 | nltk.download('words') 23 | words = set(nltk.corpus.words.words()) 24 | words = set([w.lower() for w in words]) 25 | from nltk.stem import PorterStemmer, WordNetLemmatizer 26 | from nltk.corpus import stopwords 27 | stop_words = set(stopwords.words("english")) 28 | from nltk.tokenize import sent_tokenize 29 | 30 | import gensim 31 | from gensim.downloader import load 32 | from gensim.models import Word2Vec 33 | w2v_model = gensim.downloader.load('word2vec-google-news-300') 34 | 35 | import pyspark 36 | from pyspark.sql.types import * 37 | from pyspark import SparkConf 38 | from pyspark.sql import SparkSession 39 | from pyspark.sql import functions as sf 40 | from pyspark.ml.feature import Tokenizer, RegexTokenizer, StopWordsRemover 41 | from pyspark.ml import Pipeline 42 | from pyspark.sql.functions import expr 43 | 44 | 45 | def get_spark_session( 46 | mode 47 | ) -> SparkSession: 48 | """ 49 | The function creates spark session 50 | :param mode: session mode 51 | :type mode: int 52 | :return: SparkSession 53 | :rtype: SparkSession 54 | """ 55 | if mode == 1: 56 | 57 | SPARK_MASTER_URL = 'spark://spark:7077' 58 | SPARK_DRIVER_HOST = 'jupyterhub' 59 | 60 | conf = SparkConf().setAll([ 61 | ('spark.master', SPARK_MASTER_URL), 62 | ('spark.driver.bindAddress', '0.0.0.0'), 63 | ('spark.driver.host', SPARK_DRIVER_HOST), 64 | ('spark.driver.blockManager.port', '12346'), 65 | ('spark.driver.port', '12345'), 66 | ('spark.driver.memory', '8g'), #4 67 | 
('spark.driver.memoryOverhead', '2g'), 68 | ('spark.executor.memory', '14g'), #14 69 | ('spark.executor.memoryOverhead', '2g'), 70 | ('spark.app.name', 'simulator'), 71 | ('spark.submit.deployMode', 'client'), 72 | ('spark.ui.showConsoleProgress', 'true'), 73 | ('spark.eventLog.enabled', 'false'), 74 | ('spark.logConf', 'false'), 75 | ('spark.network.timeout', '10000000'), 76 | ('spark.executor.heartbeatInterval', '10000000'), 77 | ('spark.sql.shuffle.partitions', '4'), 78 | ('spark.default.parallelism', '4'), 79 | ("spark.kryoserializer.buffer","1024"), 80 | ('spark.sql.execution.arrow.pyspark.enabled', 'true'), 81 | ('spark.rpc.message.maxSize', '1000'), 82 | ("spark.driver.maxResultSize", "2g") 83 | ]) 84 | spark = SparkSession.builder\ 85 | .config(conf=conf)\ 86 | .getOrCreate() 87 | 88 | elif mode == 0: 89 | spark = SparkSession.builder\ 90 | .appName('simulator')\ 91 | .master('local[4]')\ 92 | .config('spark.sql.shuffle.partitions', '4')\ 93 | .config('spark.default.parallelism', '4')\ 94 | .config('spark.driver.extraJavaOptions', '-XX:+UseG1GC')\ 95 | .config('spark.executor.extraJavaOptions', '-XX:+UseG1GC')\ 96 | .config('spark.sql.autoBroadcastJoinThreshold', '-1')\ 97 | .config('spark.sql.execution.arrow.pyspark.enabled', 'true')\ 98 | .getOrCreate() 99 | 100 | return spark 101 | 102 | def clean_text(text: str) -> str: 103 | 104 | """ 105 | Cleaning and preprocessing of the tags text with help of regular expressions. 106 | :param text: initial text 107 | :type text: str 108 | :return: cleaned text 109 | :rtype: str 110 | """ 111 | 112 | text = re.sub("[^a-zA-Z]", " ",text) 113 | text = re.sub(r"\s+", " ", text) 114 | text = re.sub(r"\s+$", "", text) 115 | text = re.sub(r"^\s+", "", text) 116 | text = text.lower() 117 | 118 | return text 119 | 120 | 121 | def string_embedding(arr: list) -> np.ndarray: 122 | """ 123 | Processing each word in the string with word2vec and return their aggregation (mean). 
124 | 125 | :param arr: words 126 | :type arr: List[str] 127 | :return: average vector of word2vec words representations 128 | :rtype: np.ndarray 129 | """ 130 | 131 | vec = 0 132 | cnt = 0 133 | for i in arr: 134 | try: 135 | vec += w2v_model[i] 136 | cnt += 1 137 | except KeyError: 138 | pass 139 | if cnt == 0: 140 | vec = np.zeros((300,)) 141 | else: 142 | vec /= cnt 143 | return vec 144 | 145 | @sf.pandas_udf(StringType(), sf.PandasUDFType.SCALAR) 146 | def clean_udf(str_series): 147 | """ 148 | pandas udf of the clean_text function 149 | """ 150 | result = [] 151 | for x in str_series: 152 | x_procc = clean_text(x) 153 | result.append(x_procc) 154 | return pd.Series(result) 155 | 156 | @sf.pandas_udf(ArrayType(DoubleType()), sf.PandasUDFType.SCALAR) 157 | def embedding_udf(str_series): 158 | """ 159 | pandas udf of the string_embedding function 160 | """ 161 | result = [] 162 | for x in str_series: 163 | x_procc = string_embedding(x) 164 | result.append(x_procc) 165 | return pd.Series(result) 166 | 167 | 168 | def data_processing(df_sp: pyspark.sql.DataFrame): 169 | df_procc = df_sp.withColumnRenamed("item_id", "item_idx")\ 170 | .withColumnRenamed("user_id", "user_idx")\ 171 | .withColumn("review_clean", clean_udf(sf.col("review"))).drop("review", "__index_level_0__") 172 | df_procc = Tokenizer(inputCol="review_clean", outputCol="tokens").transform(df_procc).drop("review_clean") 173 | df_procc = StopWordsRemover(inputCol="tokens", outputCol="tokens_clean").transform(df_procc).drop("tokens") 174 | df_procc = df_procc.withColumn("embedding", embedding_udf(sf.col("tokens_clean"))).drop("tokens_clean") 175 | 176 | df_items = df_procc.groupby("item_idx").agg(sf.array(*[sf.mean(sf.col("embedding")[i]) for i in range(300)]).alias("embedding"), 177 | sf.mean("helpfulness").alias("helpfulness"), 178 | sf.mean("score").alias("rating_avg"), 179 | sf.count("score").alias("rating_cnt")) 180 | 181 | df_users = df_procc.groupby("user_idx").agg(sf.array(*[sf.mean(sf.col("embedding")[i]) for i in range(300)]).alias("embedding"), 182 | sf.mean("helpfulness").alias("helpfulness"), 183 | sf.mean("score").alias("rating_avg"), 184 | sf.count("score").alias("rating_cnt")) 185 | 186 | df_rating = df_procc.groupby("user_idx", "item_idx", "timestamp").agg(sf.mean("score").alias("relevance"), sf.count("score").alias("rating_cnt")) 187 | 188 | df_items = df_items.select(['item_idx', 'helpfulness', 'rating_avg', 'rating_cnt']+[expr('embedding[' + str(x) + ']') for x in range(0, 300)]) 189 | new_colnames = ['item_idx', 'helpfulness', 'rating_avg', 'rating_cnt'] + ['w2v_' + str(i) for i in range(0, 300)] 190 | df_items = df_items.toDF(*new_colnames) 191 | 192 | df_users = df_users.select(['user_idx', 'helpfulness', 'rating_avg', 'rating_cnt']+[expr('embedding[' + str(x) + ']') for x in range(0, 300)]) 193 | new_colnames = ['user_idx', 'helpfulness', 'rating_avg', 'rating_cnt'] + ['w2v_' + str(i) for i in range(0, 300)] 194 | df_users = df_users.toDF(*new_colnames) 195 | 196 | return [df_users, df_items, df_rating] 197 | -------------------------------------------------------------------------------- /experiments/Movielens/ml20.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import re 4 | import itertools 5 | import tqdm 6 | 7 | import seaborn as sns 8 | import tqdm 9 | import matplotlib.pyplot as plt 10 | 11 | from bs4 import BeautifulSoup 12 | from nltk.tokenize import TreebankWordTokenizer, WhitespaceTokenizer 13 | 14 | import nltk 15 | nltk.download('stopwords') 16 | nltk.download('words') 17 | words = set(nltk.corpus.words.words())
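# Illustration (assumed example, not part of the original pipeline): the
# genres_processing() function defined below turns a MovieLens genre string
# such as "Adventure|Comedy" into a fixed-length binary vector over the genre
# vocabulary found in the data. A minimal self-contained sketch of that idea,
# with an invented three-genre vocabulary:
#
#     genre_vocab = {"Adventure": 0, "Comedy": 1, "Drama": 2}
#     vec = [0] * len(genre_vocab)
#     for g in "Adventure|Comedy".split("|"):
#         vec[genre_vocab[g]] = 1
#     # vec == [1, 1, 0]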
18 | words = set([w.lower() for w in words]) 19 | 20 | from nltk.stem import PorterStemmer, WordNetLemmatizer 21 | nltk.download("wordnet") 22 | 23 | from nltk.corpus import stopwords 24 | stop_words = set(stopwords.words("english")) 25 | 26 | 27 | from nltk.tokenize import sent_tokenize 28 | 29 | import gensim 30 | from gensim.downloader import load 31 | from gensim.models import Word2Vec 32 | w2v_model = gensim.downloader.load('word2vec-google-news-300') 33 | 34 | from typing import Dict, List, Optional, Tuple 35 | 36 | 37 | def title_prep(title: str) -> str: 38 | 39 | """ 40 | Cleans the movie title: collapses repeated whitespace, strips leading and 41 | trailing spaces and converts the title to lowercase. 42 | 43 | :param title: the title of the movie 44 | :type title: str 45 | :return: cleaned title of the movie 46 | :rtype: str 47 | """ 48 | 49 | title = re.sub(r'\s+', r' ', title) 50 | title = re.sub(r'(\s+$|^\s+)', '', title) 51 | title = title.lower() 52 | 53 | return title 54 | 55 | def extract_year(title: str) -> Optional[int]: 56 | 57 | """ 58 | Extracting the release year from the movie title 59 | 60 | :param title: the cleaned title of the movie 61 | :type title: str 62 | :return: movie year (NaN if no year can be extracted) 63 | :rtype: int or float 64 | """ 65 | 66 | one_year = re.findall(r'\(\d{4}\)', title) 67 | two_years = re.findall(r'\(\d{4}-\d{4}\)', title) 68 | one_year_till_today = re.findall(r'\(\d{4}[-–]\s?\)', title) 69 | if len(one_year) == 1: 70 | return int(one_year[0][1:-1]) 71 | 72 | elif len(two_years) == 1: 73 | return round((int(two_years[0][1:5]) + int(two_years[0][6:-1]))/2) 74 | 75 | elif len(one_year_till_today) == 1: 76 | return int(one_year_till_today[0][1:5]) 77 | else: 78 | return np.nan 79 | 80 | def genres_processing(movies: pd.DataFrame) -> pd.DataFrame: 81 | 82 | """ 83 | Processing movie genres by constructing a binary vector of length n, where n is the number of all possible genres. 84 | For example a string like 'genre1|genre3|...' will be transformed into a vector [0,1,0,1,...].
85 | 86 | :param movies: DataFrame with column 'genres' 87 | :type title: pd.DataFrame 88 | :return: DataFrame with processed genres 89 | :rtype: pd.DataFrame 90 | """ 91 | 92 | genre_lists = [set(item.split('|')).difference(set(['(no genres listed)'])) for item in movies['genres']] 93 | genre_lists = pd.DataFrame(genre_lists) 94 | 95 | genre_dict = {token: idx for idx, token in enumerate(set(itertools.chain.from_iterable([item.split('|') 96 | for item in movies['genres']])).difference(set(['(no genres listed)'])))} 97 | genre_dict = pd.DataFrame(genre_dict.items()) 98 | genre_dict.columns = ['genre', 'index'] 99 | 100 | dummy = np.zeros([len(movies), len(genre_dict)]) 101 | 102 | for i in range(dummy.shape[0]): 103 | for j in range(dummy.shape[1]): 104 | if genre_dict['genre'][j] in list(genre_lists.iloc[i, :]): 105 | dummy[i, j] = 1 106 | 107 | df_dummy = pd.DataFrame(dummy, columns = ['genre' + str(i) for i in range(dummy.shape[1])]) 108 | 109 | movies_return = pd.concat([movies, df_dummy], 1) 110 | return movies_return 111 | 112 | def fill_null_years(movies: pd.DataFrame) -> pd.DataFrame: 113 | 114 | """ 115 | Processing null years 116 | 117 | :param movies: DataFrame with processed years 118 | :type title: pd.DataFrame 119 | :return: DataFrame with processed not null years 120 | :rtype: pd.DataFrame 121 | """ 122 | 123 | df_movies = movies.copy() 124 | genres_columns = [item for item in movies.columns.tolist() if item[:5]=='genre' and item !='genres'] 125 | df_no_year = movies[movies.year.isna()][['movieId', *genres_columns]] 126 | 127 | years_mean = {} 128 | for i in df_no_year.index: 129 | 130 | row = np.asarray(df_no_year.loc[i, :][genres_columns]) 131 | years = [] 132 | for j in np.asarray(movies[['year', *genres_columns]]): 133 | if np.sum(row == j[1:]) == len(genres_columns): 134 | try: 135 | years.append(int(j[0])) 136 | except: 137 | pass 138 | 139 | years_mean[i] = round(np.mean(years)) 140 | 141 | for i in years_mean: 142 | df_movies.loc[i, 'year'] = years_mean[i] 143 | 144 | df_movies.year=df_movies.year.astype('int') 145 | 146 | return df_movies 147 | 148 | def clean_text(text: str) -> str: 149 | """ 150 | Cleaning text: remove extra spaces and non-text characters 151 | 152 | :param text: tag row text 153 | :type title: str 154 | :return: tag cleaned text 155 | :rtype: str 156 | """ 157 | 158 | text = re.sub("[^a-zA-Z]", " ",text) 159 | text = re.sub(r"\s+", " ", text) 160 | text = re.sub(r"\s+$", "", text) 161 | text = re.sub(r"^\s+", "", text) 162 | text = text.lower() 163 | 164 | return text 165 | 166 | 167 | def procces_text(text): 168 | """ 169 | Processing text: lemmatization, tokenization, removing stop-words 170 | 171 | :param text: tag cleaned text 172 | :type title: str 173 | :return: tag processed text 174 | :rtype: str 175 | """ 176 | lemmatizer = WordNetLemmatizer() 177 | 178 | text = [word for word in nltk.word_tokenize(text) if not word in stop_words] 179 | text = [lemmatizer.lemmatize(token) for token in text] 180 | text = [word for word in text if word in words] 181 | 182 | text = " ".join(text) 183 | 184 | return text 185 | 186 | def string_embedding(string: str) -> np.ndarray: 187 | """ 188 | Processing text: lemmatization, tokenization, removing stop-words 189 | 190 | :param string: cleaned and processed tags 191 | :type title: str 192 | :return: average vector of the string words embeddings 193 | :rtype: np.ndarray, optional 194 | """ 195 | 196 | arr = string.split(' ') 197 | vec = 0 198 | cnt = 0 199 | for i in arr: 200 | try: 201 | vec += 
w2v_model[i] 202 | cnt += 1 203 | except: 204 | pass 205 | if cnt == 0: 206 | vec = np.zeros((300, 1)) 207 | else: 208 | vec /= cnt 209 | return vec 210 | 211 | 212 | def data_processing(df_movie: pd.DataFrame, 213 | df_rating: pd.DataFrame, 214 | df_tags: pd.DataFrame 215 | ) -> List[pd.DataFrame]: 216 | 217 | print("------------------------ Movie processing ------------------------") 218 | #Extraction of the movies' years and transform genres lists to genres vector 219 | df_movies_procc = df_movie.copy() 220 | df_movies_procc.title = df_movies_procc.title.apply(title_prep) #title processing 221 | df_movies_procc['year'] = df_movies_procc.title.apply(extract_year) #year processing 222 | df_movies_procc = genres_processing(df_movies_procc) #genres processing 223 | df_movies_procc = fill_null_years(df_movies_procc) #fillimg null year values 224 | 225 | #Creating rating_avg column 226 | print("------------------------ Rating processing ------------------------") 227 | df_movies_procc = pd.merge(df_movies_procc, df_rating.groupby('movieId', as_index=False).rating.mean(), on='movieId', how='left') 228 | df_movies_procc.rating = df_movies_procc.rating.fillna(0.0) 229 | df_movies_procc = df_movies_procc.rename(columns={'rating' : 'rating_avg'}) 230 | df_movies_clean = df_movies_procc.drop(['title', 'genres'], axis=1)[['movieId', 'year', 'rating_avg', *['genre' + str(i) for i in range(19)]]] 231 | 232 | print("------------------------ Tags processing ------------------------") 233 | df_tags_ = df_tags.drop(df_tags[df_tags.tag.isna()].index) 234 | df_movie_tags = df_tags_.sort_values(by=['movieId', 'timestamp'])[['movieId', 'tag', 'timestamp']] 235 | df_movie_tags['clean_tag'] = df_movie_tags.tag.apply(lambda x : procces_text(clean_text(x))) 236 | df_movie_tags = df_movie_tags[df_movie_tags.clean_tag.str.len()!=0] 237 | 238 | print("------------------------ Tags embedding ------------------------") 239 | #tags text gathering 240 | docs_movie_tags = df_movie_tags.sort_values(["movieId", "timestamp"]).groupby("movieId", as_index=False).agg({"clean_tag":lambda x: " ".join(x)}) 241 | df_movies_tags = pd.concat([docs_movie_tags.movieId, pd.DataFrame(docs_movie_tags.clean_tag.apply(string_embedding).to_list(), columns = ['w2v_' + str(i) for i in range(300)])], axis = 1) 242 | df_movies_clean = pd.merge(df_movies_clean, df_movies_tags, on = "movieId", how = "left").fillna(0.0) 243 | 244 | print("------------------------ Users processing ------------------------") 245 | #users procc 246 | df_users = df_rating.copy() 247 | df_users = df_users.groupby(by=['userId'], as_index=False).rating.mean().rename(columns = {'rating' : 'rating_avg'}) 248 | df_users_genres = pd.merge(df_movies_clean[['movieId', *df_movies_clean.columns[3:22]]], pd.merge(df_rating, df_users, on = 'userId')[['userId', 'movieId']], 249 | on = 'movieId') 250 | 251 | df_users_genres = df_users_genres.groupby(by = ['userId'], as_index = False)[df_movies_clean.columns[3:22]].mean() 252 | df_users_genres = pd.merge(df_users_genres, df_users, on = 'userId') 253 | df_pairs = pd.merge(df_rating, df_users, on = 'userId')[['userId', 'movieId']] 254 | 255 | print("------------------------ Users embedding ------------------------") 256 | users_id = [] 257 | vect_space = [] 258 | for Id in tqdm.tqdm(df_pairs.userId.unique()): 259 | movie_list = df_pairs[df_pairs.userId == Id].movieId.tolist() 260 | vect = np.asarray(df_movies_clean[df_movies_clean.movieId.isin(movie_list)][[*df_movies_clean.columns[22:]]].mean().tolist()) 261 | users_id.append(Id) 262 
| vect_space.append(vect) 263 | 264 | df_users_w2v = pd.DataFrame(vect_space, columns = ['w2v_' + str(i) for i in range(len(df_movies_clean.columns[22:]))]) 265 | df_users_w2v['userId'] = users_id 266 | df_users_clean = pd.merge(df_users_genres, df_users_w2v, on = 'userId') 267 | df_rating_clean = df_rating[['userId', 'movieId', 'rating', 'timestamp']] 268 | 269 | """ 270 | cat_dict_movies = pd.Series(df_movies_clean.movieId.astype("category").cat.codes.values, index=df_movies_clean.movieId).to_dict() 271 | cat_dict_users = pd.Series(df_users_clean.userId.astype("category").cat.codes.values, index=df_users_clean.userId).to_dict() 272 | 273 | df_movies_clean.movieId = df_movies_clean.movieId.apply(lambda x: cat_dict_movies[x]) 274 | df_users_clean.userId = df_users_clean.userId.apply(lambda x: cat_dict_users[x]) 275 | df_rating_clean.movieId = df_rating.movieId.apply(lambda x: cat_dict_movies[x]) 276 | df_rating_clean.userId = df_rating.userId.apply(lambda x: cat_dict_users[x]) 277 | """ 278 | 279 | df_movies_clean = df_movies_clean.rename(columns={'movieId': 'item_idx'}) 280 | df_users_clean = df_users_clean.rename(columns={'userId': 'user_idx'}) 281 | df_rating_clean = df_rating_clean.rename(columns={'movieId': 'item_idx', 'userId': 'user_idx', 'rating': 'relevance'}) 282 | 283 | return [df_movies_clean, df_users_clean, df_rating_clean] -------------------------------------------------------------------------------- /experiments/Netflix/RU_load_datasets.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "e481238a-e6ea-431d-b2a7-957600884500", 6 | "metadata": {}, 7 | "source": [ 8 | "# Иллюстрация загрузки и подготовки данных Netflix" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "d3607d22-edeb-445a-b281-8bc237e86a98", 14 | "metadata": {}, 15 | "source": [ 16 | "### $\\textbf{Содержание}$:" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "3c60e1bb-1f79-47b0-9cb4-1240e29f2d7a", 22 | "metadata": {}, 23 | "source": [ 24 | "### $\\textbf{I. Создание табличных данных на основании txt файлов}$\n", 25 | "### Из записей txt файлов формируются в табличные данные csv формата с атрибутами:\n", 26 | "#### - $\\it{movie\\_Id} \\in \\mathbb{N}$: идентификатор предложения; \n", 27 | "#### - $\\it{user_\\_i_d} \\in \\mathbb{N}$: идентификатор пользователя;\n", 28 | "#### - $\\it{rating} \\in [1, 5]$: полезность предложения для пользователя;\n", 29 | "#### - $\\it{date} \\in \\mathbb{N}$: временная метка;\n", 30 | "------\n" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "id": "344a920d-f5a1-4770-80b5-80d8959acc21", 36 | "metadata": {}, 37 | "source": [ 38 | "### $\\textbf{II. Обработка фильмов}$\n", 39 | "### Из названия фильмов генерируются следующие признаки:\n", 40 | "#### - год выпуска фильма $\\it{year} \\in \\mathbb{N}$;\n", 41 | "### Из оценок, выставленных пользователями, генерируются следующие признаки:\n", 42 | "#### - средняя оценка фильма $\\it{rating\\_avg} \\in [0, 5]$;\n", 43 | "#### - количество оценок фильма $\\it{rating\\_cnt} \\in \\mathbb{N} \\cup \\{0\\}$;\n", 44 | "---" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "id": "272be4e1-9592-4d71-8211-18c94b840c3a", 50 | "metadata": {}, 51 | "source": [ 52 | "### $\\textbf{III. 
Обработка рейтингов}$\n", 53 | "#### Никакие признаки на это этапе не генерируются;\n", 54 | "#### Атрибут даты приводятся к формату timestamp;\n" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "id": "b51300d1-d5c1-425a-aef5-a2f5c00483ad", 61 | "metadata": { 62 | "tags": [] 63 | }, 64 | "outputs": [], 65 | "source": [ 66 | "import os\n", 67 | "import sys\n", 68 | "import pandas as pd\n", 69 | "import numpy as np\n", 70 | "\n", 71 | "import re\n", 72 | "from datetime import datetime\n", 73 | "\n", 74 | "import tqdm\n", 75 | "\n", 76 | "import pyspark\n", 77 | "from pyspark.sql import SparkSession\n", 78 | "from pyspark.sql import functions as F\n", 79 | "\n", 80 | "spark = SparkSession.builder\\\n", 81 | " .appName(\"processingApp\")\\\n", 82 | " .config(\"spark.driver.memory\", \"8G\")\\\n", 83 | " .config(\"spark.executor.cores\", \"8G\")\\\n", 84 | " .config(\"spark.executor.memory\", \"2G\")\\\n", 85 | " .getOrCreate()\n", 86 | "\n", 87 | "spark.conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "id": "e8354c1e-a4d2-4cc6-810a-5e76e7b8c8fe", 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "data_folder = r'./'\n", 98 | "row_data_files = ['combined_data_' + str(i) + '.txt' for i in range(1,5)]\n", 99 | "\n", 100 | "movie_titles_path = r'./movie_titles.csv'\n", 101 | "\n", 102 | "save_file_name = r\"./data_clean/netflix_full.csv\"\n", 103 | "\n", 104 | "save_movies = r'./data_clean/movies.csv'\n", 105 | "save_rating = r'./data_clean/rating.csv'" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "id": "c08cea48-8e20-433b-ad67-5dd6b19ac0bd", 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "id": "222b7bcc-208a-459b-a7db-65245f7a4e77", 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "row_data_files = [os.path.join(data_folder, i) for i in row_data_files]" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "id": "5e51560c-2254-49be-b33d-13a43075d4b2", 129 | "metadata": {}, 130 | "source": [ 131 | "### I. 
Создание табличных данных на основании txt файлов" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "id": "e88db14a-d515-4ee8-be4d-a4ee6de986b7", 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "def row_data_ops(\n", 142 | " files: List[str],\n", 143 | " save_file_name: str\n", 144 | "):\n", 145 | " \"\"\"\n", 146 | " Creating table data from txt files\n", 147 | " :param files: txt files names\n", 148 | " :type files: list\n", 149 | " :param save_file_name: file name for saving\n", 150 | " :type save_file_name: str\n", 151 | "\n", 152 | " \n", 153 | " \"\"\"\n", 154 | " for _, file_name in enumerate(files):\n", 155 | " df = spark.read.text(os.path.join(file_name))\n", 156 | " df = df.coalesce(1).withColumn(\"row_num\", F.monotonically_increasing_id())\n", 157 | "\n", 158 | " df_partitions = df.select( F.col(\"row_num\").alias(\"Id\"), \n", 159 | " F.regexp_extract(F.col(\"value\"), r'\\d+', 0).alias(\"Id_start\") ).where( F.substring(F.col(\"value\"), -1, 1)==\":\" )\n", 160 | " df_partitions = df_partitions.select( F.col(\"Id\").cast('int'),\n", 161 | " F.col(\"Id_start\").cast('int'))\n", 162 | "\n", 163 | " df_rows = df.select( F.col(\"row_num\").alias(\"Id\"),\n", 164 | " F.col(\"value\") ).where( F.substring(F.col(\"value\"), -1, 1)!=\":\" )\n", 165 | " df_rows = df_rows.select( F.col(\"Id\"),\n", 166 | " F.regexp_extract(F.col(\"value\"), r'(\\d+),(\\d+),(\\d+-\\d+-\\d+)', 1).cast('int').alias(\"user_Id\"),\n", 167 | " F.regexp_extract(F.col(\"value\"), r'(\\d+),(\\d+),(\\d+-\\d+-\\d+)', 2).cast('int').alias(\"rating\"),\n", 168 | " F.to_date(F.regexp_extract(F.col(\"value\"), r'(\\d+),(\\d+),(\\d+-\\d+-\\d+)', 3), \"yyyy-mm-dd\").alias(\"date\"))\n", 169 | " df_partitions2 = df_partitions.select( F.col(\"Id\").alias(\"Id2\"),\n", 170 | " F.col(\"Id_start\").alias(\"Id_end\"))\n", 171 | " df_indexes = df_partitions.join(df_partitions2, df_partitions2.Id_end - df_partitions.Id_start == 1, \"left\").select( \n", 172 | " F.col('Id').alias('Idx_start'), \n", 173 | " F.col('Id2').alias('Idx_stop'),\n", 174 | " F.col('Id_start').alias('Index'))\n", 175 | "\n", 176 | " df_result = df_rows.join(F.broadcast(df_indexes), (df_rows.Id > df_indexes.Idx_start) & ((df_rows.Id < df_indexes.Idx_stop) | (df_indexes.Idx_stop.isNull())), \"inner\").select(\n", 177 | " F.col('Index').alias('movie_Id'),\n", 178 | " F.col('user_Id'),\n", 179 | " F.col('rating'),\n", 180 | " F.col('date')\n", 181 | " ).distinct()\n", 182 | " \n", 183 | " if _ == 0:\n", 184 | " df_all_users = df_result\n", 185 | " else:\n", 186 | " df_all_users = df_all_users.union(df_result)\n", 187 | " \n", 188 | " df_all_users.write.csv(save_file_name)" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "id": "7d25ad6e-c18e-4e7e-a7ee-2eaf80de0621", 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "row_data_ops(row_data_files, save_file_name)" 199 | ] 200 | }, 201 | { 202 | "cell_type": "markdown", 203 | "id": "8a164e5c-e723-4a45-96a3-1d84fd04221a", 204 | "metadata": {}, 205 | "source": [ 206 | "### II. 
Обработка фильмов" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": 1, 212 | "id": "033358f6-fdca-4872-b638-79db19f61725", 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [ 216 | "def movies_ops(\n", 217 | " data_path, \n", 218 | " movies_path, \n", 219 | " movie_save_path\n", 220 | "):\n", 221 | " \n", 222 | " \"\"\"\n", 223 | " operate movies name\n", 224 | " :param data_path: path to netflix full file\n", 225 | " :type data_path: str\n", 226 | " :param movies_path: path to netflix movies file\n", 227 | " :type movies_path: str\n", 228 | " :param movie_save_path: file path to save clear netflix movies\n", 229 | " :type movie_save_path: str\n", 230 | "\n", 231 | " \"\"\"\n", 232 | " \n", 233 | " df_all = pd.read_csv(data_path) \n", 234 | " df_movies = pd.merge(df_all.groupby(by = ['movie_Id'], as_index=False).rating.count().rename(columns={'rating':'rating_cnt'}),\n", 235 | " df_all.groupby(by = ['movie_Id'], as_index=False).rating.mean().rename(columns={'rating':'rating_avg'}), \n", 236 | " on = 'movie_Id', how = 'inner')\n", 237 | " \n", 238 | " with open(movies_path, 'r', encoding=\"ISO-8859-1\") as f:\n", 239 | " file = f.read()\n", 240 | " file_arr = file.split('\\n')\n", 241 | " \n", 242 | " file_arr_id = []\n", 243 | " file_arr_year = []\n", 244 | " file_arr_name = []\n", 245 | "\n", 246 | " file_arr_problem = []\n", 247 | "\n", 248 | " for i in file_arr:\n", 249 | " row = re.sub(r'^\\s+', '', i)\n", 250 | " row = re.sub(r'\\s+$', '', i)\n", 251 | " row_group = re.match(r'(\\d+),(\\d+),(.+)', row)\n", 252 | " if row_group != None:\n", 253 | " assert row == row_group.group(0)\n", 254 | "\n", 255 | " file_arr_id.append(int(row_group.group(1)))\n", 256 | " file_arr_year.append(int(row_group.group(2)))\n", 257 | " file_arr_name.append(row_group.group(3))\n", 258 | "\n", 259 | " else:\n", 260 | " file_arr_problem.append(row)\n", 261 | "\n", 262 | " \n", 263 | " df_names = pd.DataFrame({ 'movie_Id':file_arr_id, 'year':file_arr_year, 'title':file_arr_name })\n", 264 | " fill_na_year = ['2002', '2002', '2002', '1974', '1999', '1994', '1999']\n", 265 | " fill_na_name = []\n", 266 | " fill_na_id = []\n", 267 | "\n", 268 | " for i in range(len(file_arr_problem)-1):\n", 269 | " row_group = re.match(r'(\\d+),(NULL),(.+)', file_arr_problem[i])\n", 270 | "\n", 271 | " fill_na_id.append(int(row_group.group(1)))\n", 272 | " fill_na_name.append(row_group.group(3))\n", 273 | "\n", 274 | " df_names = pd.concat([df_names, pd.DataFrame({ 'movie_Id':fill_na_id, 'year':fill_na_year, 'title':fill_na_name })])\n", 275 | " df_names.movie_Id = df_names.movie_Id.astype('int')\n", 276 | " df_names.year = df_names.year.astype('int')\n", 277 | " \n", 278 | " df_movies = pd.merge(df_movies, df_names, on = 'movie_Id', how = 'left')\n", 279 | " df_movies.reset_index(drop=True).to_csv(movie_save_path, index=False)" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": null, 285 | "id": "344fe1df-ae89-47cb-b777-4d4214e86bbf", 286 | "metadata": {}, 287 | "outputs": [], 288 | "source": [ 289 | "movies_ops(save_file_name, movie_titles_path, save_movies)" 290 | ] 291 | }, 292 | { 293 | "cell_type": "markdown", 294 | "id": "4fa05a96-ddd8-4d92-82ec-d16145ecd774", 295 | "metadata": {}, 296 | "source": [ 297 | "### III. 
Обработка рейтингов" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": 2, 303 | "id": "10d08d3b-f540-4b1b-8c81-7a2ac2d024fc", 304 | "metadata": {}, 305 | "outputs": [], 306 | "source": [ 307 | "def rating_op(\n", 308 | " data_path, \n", 309 | " rating_save_path\n", 310 | "):\n", 311 | " \"\"\"\n", 312 | " operate ratings \n", 313 | " :param data_path: path to netflix full file\n", 314 | " :type data_path: str\n", 315 | " :param rating_save_path: file path to save operated netflix ratings\n", 316 | " :type rating_save_path: str\n", 317 | "\n", 318 | " \"\"\"\n", 319 | " \n", 320 | " df_rating = pd.read_csv(data_path) \n", 321 | " df_rating['timestamp'] = df_rating.date.apply(lambda x: pd.to_datetime(x))\n", 322 | " df_rating['timestamp'] = df_rating.timestamp.apply(lambda x: x.timestamp())\n", 323 | " df_rating = df_rating[['movie_Id', 'user_Id', 'rating', 'timestamp']]\n", 324 | " \n", 325 | " df_rating.reset_index(drop=True).to_csv(rating_save_path, index=False)" 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": null, 331 | "id": "5744d2fd-2507-48f7-9081-69d9daa8a1e2", 332 | "metadata": {}, 333 | "outputs": [], 334 | "source": [ 335 | "rating_op(save_file_name, save_rating)" 336 | ] 337 | } 338 | ], 339 | "metadata": { 340 | "kernelspec": { 341 | "display_name": "Python [conda env:.conda-sber3.8]", 342 | "language": "python", 343 | "name": "conda-env-.conda-sber3.8-py" 344 | }, 345 | "language_info": { 346 | "codemirror_mode": { 347 | "name": "ipython", 348 | "version": 3 349 | }, 350 | "file_extension": ".py", 351 | "mimetype": "text/x-python", 352 | "name": "python", 353 | "nbconvert_exporter": "python", 354 | "pygments_lexer": "ipython3", 355 | "version": "3.8.13" 356 | } 357 | }, 358 | "nbformat": 4, 359 | "nbformat_minor": 5 360 | } 361 | -------------------------------------------------------------------------------- /experiments/Netflix/netflix.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import re 4 | import itertools 5 | import tqdm 6 | 7 | import seaborn as sns 8 | import tqdm 9 | import matplotlib.pyplot as plt 10 | 11 | from bs4 import BeautifulSoup 12 | from nltk.tokenize import TreebankWordTokenizer, WhitespaceTokenizer 13 | 14 | import nltk 15 | nltk.download('stopwords') 16 | nltk.download('words') 17 | words = set(nltk.corpus.words.words()) 18 | words = set([w.lower() for w in words]) 19 | 20 | from nltk.stem import PorterStemmer, WordNetLemmatizer 21 | nltk.download("wordnet") 22 | 23 | from nltk.corpus import stopwords 24 | stop_words = set(stopwords.words("english")) 25 | 26 | 27 | from nltk.tokenize import sent_tokenize 28 | 29 | import gensim 30 | from gensim.downloader import load 31 | from gensim.models import Word2Vec 32 | w2v_model = gensim.downloader.load('word2vec-google-news-300') 33 | 34 | from typing import Dict, List, Optional, Tuple 35 | 36 | def clean_text(text: str) -> str: 37 | """ 38 | Cleaning text: remove extra spaces and non-text characters 39 | :param text: tag row text 40 | :type title: str 41 | :return: tag cleaned text 42 | :rtype: str 43 | """ 44 | 45 | text = re.sub("[^a-zA-Z]", " ",text) 46 | text = re.sub(r"\s+", " ", text) 47 | text = re.sub(r"\s+$", "", text) 48 | text = re.sub(r"^\s+", "", text) 49 | text = text.lower() 50 | 51 | return text 52 | 53 | 54 | def procces_text(text): 55 | """ 56 | Processing text: lemmatization, tokenization, removing stop-words 57 | :param text: tag cleaned text 58 
| :type text: str 59 | :return: tag processed text 60 | :rtype: str 61 | """ 62 | lemmatizer = WordNetLemmatizer() 63 | 64 | text = [word for word in nltk.word_tokenize(text) if not word in stop_words] 65 | text = [lemmatizer.lemmatize(token) for token in text] 66 | text = [word for word in text if word in words] 67 | 68 | text = " ".join(text) 69 | 70 | return text 71 | 72 | def string_embedding(string: str) -> np.ndarray: 73 | """" 74 | Processing text: lemmatization, tokenization, removing stop-words 75 | :param string: cleaned and processed tags 76 | :type string: str 77 | :return: average vector of the string words embeddings 78 | :rtype: np.ndarray, optional 79 | """ 80 | 81 | arr = string.split(' ') 82 | vec = 0 83 | cnt = 0 84 | for i in arr: 85 | try: 86 | vec += w2v_model[i] 87 | cnt += 1 88 | except: 89 | pass 90 | if cnt == 0: 91 | vec = np.zeros((300, 1)) 92 | else: 93 | vec /= cnt 94 | return vec 95 | 96 | def group_w2v(history: pd.DataFrame, movies: pd.DataFrame) -> pd.DataFrame: 97 | """" 98 | Aggregate embedded data for users partitions. 99 | :param history: users movies history data . 100 | :type history: pd.DataFrame 101 | :param movies: movies data. 102 | :type movies: pd.DataFrame 103 | :return: average vector of the string words embeddings 104 | :rtype: np.ndarray, optional 105 | """ 106 | 107 | """ 108 | Aggregation (mean) embedded data for users watch history partitions. 109 | 110 | Arguments: 111 | --history: data frame of users movies history. 112 | --movies: data frame of movies. 113 | 114 | Return: 115 | --df: data frame of users with aggregation embedded movies data. 116 | """ 117 | users_id_arr = history.user_Id.unique() 118 | 119 | id_arr = [] 120 | vec_arr = np.zeros((len(users_id_arr), 300)) 121 | 122 | for user_id in tqdm.tqdm_notebook(range(len(users_id_arr))): 123 | vec = np.asarray(movies[movies.movie_Id.isin(history[history.user_Id == users_id_arr[user_id]].movie_Id)].iloc[:, 4:]).mean(axis=0) 124 | 125 | id_arr.append(users_id_arr[user_id]) 126 | vec_arr[user_id] = vec 127 | 128 | df = pd.DataFrame(vec_arr) 129 | df['user_Id'] = id_arr 130 | 131 | return df 132 | 133 | def data_processing(df_movies: pd.DataFrame, 134 | df_rating: pd.DataFrame, 135 | rename: bool = True 136 | ) -> List[pd.DataFrame]: 137 | 138 | df_movies = df_movies.drop(["rating_cnt", "rating_avg"], axis=1) 139 | df_movies['clean_title'] = df_movies.title.apply(lambda x : procces_text(clean_text(x))) 140 | df_movies.drop("title", axis = 1, inplace = True) 141 | df_movies_clean = pd.concat([df_movies.drop("clean_title", axis=1), 142 | pd.DataFrame(df_movies.clean_title.apply(string_embedding).to_list(), columns = ['w2v_' + str(i) for i in range(300)])], axis = 1) 143 | 144 | movies_vector = df_movies_clean.drop(['year'], axis=1) 145 | for col in movies_vector.drop("movie_Id", axis=1).columns: 146 | movies_vector[col] = movies_vector[col].astype('float') 147 | 148 | agg_columns = [] 149 | df_result = pd.DataFrame() 150 | 151 | chunksize=10000 152 | chunk_count = (df_rating.shape[0] // chunksize) + 1 if df_rating.shape[0]%chunksize!=0 else df_rating.shape[0] // chunksize 153 | for idx in tqdm.tqdm_notebook(range(chunk_count)): 154 | chunk = df_rating.iloc[idx*chunksize:(idx+1)*chunksize, :] 155 | df_history = pd.merge(chunk[['user_Id', 'movie_Id', 'rating']], movies_vector.movie_Id, on = 'movie_Id', how = 'left') 156 | df_history = pd.merge(df_history, movies_vector, how='left', on='movie_Id').drop('movie_Id', axis=1) 157 | df_history['cnt'] = 1 158 | 159 | if idx == 0: 160 | 
agg_columns = df_history.drop(['user_Id'], axis=1).columns 161 | df_history_aggregated = df_history.groupby("user_Id", as_index=False)[agg_columns].sum() 162 | df_result = df_result.append(df_history_aggregated, ignore_index=True) 163 | 164 | if idx % 20 == 0: 165 | df_result = df_result.groupby("user_Id", as_index=False)[agg_columns].sum() 166 | 167 | df_result = df_result.groupby("user_Id", as_index=False)[agg_columns].sum() 168 | for col in agg_columns: 169 | if col != "cnt": 170 | df_result[col] = df_result[col] / df_result["cnt"] 171 | df_result = df_result.rename(columns={"rating": "rating_avg", "cnt": "rating_cnt"}) 172 | df_users_clean = df_result 173 | 174 | df_movies_clean = pd.merge(df_rating.groupby("movie_Id", as_index=False)["rating"]\ 175 | .agg(['mean', 'count'])\ 176 | .rename(columns={"mean": "rating_avg", "count": "rating_cnt"}), df_movies_clean,how='left', on='movie_Id').fillna(0.0) 177 | 178 | df_rating_clean = df_rating 179 | if rename: 180 | cat_dict_movies = pd.Series(df_movies_clean.movie_Id.astype("category").cat.codes.values, index=df_movies_clean.movie_Id).to_dict() 181 | cat_dict_users = pd.Series(df_users_clean.user_Id.astype("category").cat.codes.values, index=df_users_clean.user_Id).to_dict() 182 | df_movies_clean.movie_Id = df_movies_clean.movie_Id.apply(lambda x: cat_dict_movies[x]) 183 | df_users_clean.user_Id = df_users_clean.user_Id.apply(lambda x: cat_dict_users[x]) 184 | df_rating_clean.movie_Id = df_rating_clean.movie_Id.apply(lambda x: cat_dict_movies[x]) 185 | df_rating_clean.user_Id = df_rating_clean.user_Id.apply(lambda x: cat_dict_users[x]) 186 | 187 | df_movies_clean = df_movies_clean.rename(columns={'movie_Id': 'item_idx'}) 188 | df_users_clean = df_users_clean.rename(columns={'user_Id': 'user_idx'}) 189 | df_rating_clean = df_rating_clean.rename(columns={'movie_Id': 'item_idx', 'user_Id': 'user_idx', 'rating': 'relevance'}) 190 | 191 | return [df_movies_clean, df_users_clean, df_rating_clean] -------------------------------------------------------------------------------- /experiments/datautils.py: -------------------------------------------------------------------------------- 1 | import pyspark.sql.types as st 2 | import pyspark.sql.functions as sf 3 | from pyspark.sql import DataFrame 4 | from pyspark.ml.feature import Bucketizer, VectorAssembler 5 | from pyspark.ml.clustering import KMeans 6 | 7 | 8 | USER_PREFIX = 'user_' 9 | ITEM_PREFIX = 'item_' 10 | 11 | MOVIELENS_USER_SCHEMA = st.StructType( 12 | [st.StructField('user_idx', st.IntegerType())] +\ 13 | [st.StructField(f'genre{i}', st.DoubleType()) for i in range(19)] +\ 14 | [st.StructField('rating_avg', st.DoubleType())] +\ 15 | [st.StructField(f'w2v_{i}', st.DoubleType()) for i in range(300)] 16 | ) 17 | MOVIELENS_ITEM_SCHEMA = st.StructType( 18 | [st.StructField('item_idx', st.IntegerType())] +\ 19 | [st.StructField('year', st.IntegerType())] +\ 20 | [st.StructField('rating_avg', st.DoubleType())] +\ 21 | [st.StructField(f'genre{i}', st.DoubleType()) for i in range(19)] +\ 22 | [st.StructField(f'w2v_{i}', st.DoubleType()) for i in range(300)] 23 | ) 24 | MOVIELENS_LOG_SCHEMA = st.StructType([ 25 | st.StructField('user_idx', st.IntegerType()), 26 | st.StructField('item_idx', st.IntegerType()), 27 | st.StructField('relevance', st.DoubleType()), 28 | st.StructField('timestamp', st.IntegerType()) 29 | ]) 30 | 31 | NETFLIX_USER_SCHEMA = st.StructType( 32 | [st.StructField('user_idx', st.IntegerType())] +\ 33 | [st.StructField('rating_avg', st.DoubleType())] +\ 34 | 
[st.StructField(f'w2v_{i}', st.DoubleType()) for i in range(300)] +\ 35 | [st.StructField('rating_cnt', st.IntegerType())] 36 | ) 37 | NETFLIX_ITEM_SCHEMA = st.StructType( 38 | [st.StructField('item_idx', st.IntegerType())] +\ 39 | [st.StructField('rating_avg', st.DoubleType())] +\ 40 | [st.StructField('rating_cnt', st.IntegerType())] +\ 41 | [st.StructField('year', st.IntegerType())] +\ 42 | [st.StructField(f'w2v_{i}', st.DoubleType()) for i in range(300)] 43 | ) 44 | NETFLIX_LOG_SCHEMA = st.StructType([ 45 | st.StructField('item_idx', st.IntegerType()), 46 | st.StructField('user_idx', st.IntegerType()), 47 | st.StructField('relevance', st.DoubleType()), 48 | st.StructField('timestamp', st.DoubleType()) 49 | ]) 50 | 51 | MOVIELENS_CLUSTER_COLS = [ 52 | 'genre0', 'genre1', 'genre2', 'genre3', 'genre4', 53 | 'genre5', 'genre6', 'genre7', 'genre8', 'genre9', 54 | 'genre10', 'genre11', 'genre12', 'genre13', 'genre14', 55 | 'genre15', 'genre16', 'genre17', 'genre18' 56 | ] 57 | NETFLIX_STRAT_COL = 'rating_cnt' 58 | NETFLIX_CLUSTER_COLS = [f'w2v_{i}' for i in range(300)] 59 | 60 | 61 | def read_movielens(base_path, type, spark_session): 62 | if type not in ['train', 'val', 'test']: 63 | raise ValueError('Wrong dataset type') 64 | 65 | users = spark_session\ 66 | .read.csv(f'{base_path}/{type}/users.csv', header=True, schema=MOVIELENS_USER_SCHEMA)\ 67 | .withColumnRenamed('user_idx', 'user_id') 68 | items = spark_session\ 69 | .read.csv(f'{base_path}/{type}/items.csv', header=True, schema=MOVIELENS_ITEM_SCHEMA)\ 70 | .withColumnRenamed('item_idx', 'item_id') 71 | log = spark_session\ 72 | .read.csv(f'{base_path}/{type}/rating.csv', header=True, schema=MOVIELENS_LOG_SCHEMA)\ 73 | .withColumnRenamed('user_idx', 'user_id')\ 74 | .withColumnRenamed('item_idx', 'item_id') 75 | 76 | log = log\ 77 | .join(users, on='user_id', how='leftsemi')\ 78 | .join(items, on='item_id', how='leftsemi') 79 | 80 | for c in users.columns: 81 | if not c.startswith('user_'): 82 | users = users.withColumnRenamed(c, 'user_' + c) 83 | 84 | for c in items.columns: 85 | if not c.startswith('item_'): 86 | items = items.withColumnRenamed(c, 'item_' + c) 87 | 88 | log = log.withColumn('relevance', sf.when(sf.col('relevance') >= 3, 1).otherwise(0)) 89 | 90 | users = users.na.drop() 91 | items = items.na.drop() 92 | log = log.na.drop() 93 | 94 | return users, items, log 95 | 96 | def read_netflix(base_path, type, spark_session): 97 | if type not in ['train', 'val', 'test']: 98 | raise ValueError('Wrong dataset type') 99 | 100 | users = spark_session\ 101 | .read.csv(f'{base_path}/{type}/users.csv', header=True, schema=NETFLIX_USER_SCHEMA)\ 102 | .withColumnRenamed('user_idx', 'user_id') 103 | items = spark_session\ 104 | .read.csv(f'{base_path}/{type}/items.csv', header=True, schema=NETFLIX_ITEM_SCHEMA)\ 105 | .withColumnRenamed('item_idx', 'item_id') 106 | log = spark_session\ 107 | .read.csv(f'{base_path}/{type}/rating.csv', header=True, schema=NETFLIX_LOG_SCHEMA)\ 108 | .withColumnRenamed('user_idx', 'user_id')\ 109 | .withColumnRenamed('item_idx', 'item_id') 110 | 111 | log = log\ 112 | .join(users, on='user_id', how='leftsemi')\ 113 | .join(items, on='item_id', how='leftsemi') 114 | 115 | for c in users.columns: 116 | if not c.startswith('user_'): 117 | users = users.withColumnRenamed(c, 'user_' + c) 118 | 119 | for c in items.columns: 120 | if not c.startswith('item_'): 121 | items = items.withColumnRenamed(c, 'item_' + c) 122 | 123 | log = log.withColumn('relevance', sf.when(sf.col('relevance') >= 3, 1).otherwise(0)) 124 | 
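    # Ratings are binarized at this point: a score of 3 or higher is treated as
    # a positive interaction (relevance = 1), anything lower as negative (0),
    # mirroring read_movielens above and read_amazon below.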
125 | users = users.na.drop() 126 | items = items.na.drop() 127 | log = log.na.drop() 128 | 129 | return users, items, log 130 | 131 | def read_amazon(base_path, type, spark_session): 132 | if type not in ['train', 'val', 'test']: 133 | raise ValueError('Wrong dataset type') 134 | 135 | users = spark_session\ 136 | .read.parquet(f'{base_path}/{type}/users.parquet').withColumnRenamed('user_idx', 'user_id') 137 | items = spark_session\ 138 | .read.parquet(f'{base_path}/{type}/items.parquet').withColumnRenamed('item_idx', 'item_id') 139 | log = spark_session\ 140 | .read.parquet(f'{base_path}/{type}/rating.parquet')\ 141 | .withColumnRenamed('user_idx', 'user_id')\ 142 | .withColumnRenamed('item_idx', 'item_id') 143 | 144 | log = log\ 145 | .join(users, on='user_id', how='leftsemi')\ 146 | .join(items, on='item_id', how='leftsemi') 147 | 148 | for c in users.columns: 149 | if not c.startswith('user_'): 150 | users = users.withColumnRenamed(c, 'user_' + c) 151 | 152 | for c in items.columns: 153 | if not c.startswith('item_'): 154 | items = items.withColumnRenamed(c, 'item_' + c) 155 | 156 | log = log.withColumn('relevance', sf.when(sf.col('relevance') >= 3, 1).otherwise(0)) 157 | 158 | users = users.na.drop() 159 | items = items.na.drop() 160 | log = log.na.drop() 161 | 162 | return users, items, log 163 | 164 | 165 | def netflix_cluster_users( 166 | df : DataFrame, 167 | outputCol : str = 'cluster', 168 | column_prefix : str = '', 169 | seed : int = None 170 | ): 171 | cluster_cols = [f'{column_prefix}{c}' for c in NETFLIX_CLUSTER_COLS] 172 | assembler = VectorAssembler( 173 | inputCols=cluster_cols, outputCol='__features' 174 | ) 175 | kmeans = KMeans( 176 | k=10, featuresCol='__features', 177 | predictionCol=outputCol, maxIter=300, seed=seed 178 | ) 179 | 180 | df = assembler.transform(df) 181 | kmeans_model = kmeans.fit(df) 182 | df = kmeans_model.transform(df) 183 | 184 | return df.drop('__features') 185 | 186 | def movielens_cluster_users( 187 | df : DataFrame, 188 | outputCol : str = 'cluster', 189 | column_prefix : str = '', 190 | seed : int = None 191 | ): 192 | cluster_cols = [f'{column_prefix}{c}' for c in MOVIELENS_CLUSTER_COLS] 193 | assembler = VectorAssembler( 194 | inputCols=cluster_cols, outputCol='__features' 195 | ) 196 | kmeans = KMeans( 197 | k=10, featuresCol='__features', 198 | predictionCol=outputCol, maxIter=300, seed=seed 199 | ) 200 | 201 | df = assembler.transform(df) 202 | kmeans_model = kmeans.fit(df) 203 | df = kmeans_model.transform(df) 204 | 205 | return df.drop('__features') 206 | -------------------------------------------------------------------------------- /experiments/generators_generate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import pandas as pd 5 | import torch 6 | from pyspark.sql import SparkSession 7 | from replay.session_handler import State 8 | from sim4rec.modules import SDVDataGenerator 9 | 10 | SPARK_LOCAL_DIR = '/data/home/anthony/tmp' 11 | RESULT_DIR = '../bin' 12 | 13 | 14 | NUM_JOBS = int(sys.argv[1]) 15 | torch.set_num_threads(8) 16 | 17 | spark = SparkSession.builder\ 18 | .appName('simulator')\ 19 | .master(f'local[{NUM_JOBS}]')\ 20 | .config('spark.sql.shuffle.partitions', f'{NUM_JOBS}')\ 21 | .config('spark.default.parallelism', f'{NUM_JOBS}')\ 22 | .config('spark.driver.extraJavaOptions', '-XX:+UseG1GC')\ 23 | .config('spark.executor.extraJavaOptions', '-XX:+UseG1GC')\ 24 | .config('spark.sql.autoBroadcastJoinThreshold', '-1')\ 25 | 
.config('spark.driver.memory', '256g')\ 26 | .config('spark.local.dir', SPARK_LOCAL_DIR)\ 27 | .getOrCreate() 28 | 29 | State(spark) 30 | 31 | def generate_time(generator, num_samples): 32 | start = time.time() 33 | df = generator.generate(num_samples).cache() 34 | df.count() 35 | result_time = time.time() - start 36 | df.unpersist() 37 | 38 | return (generator.getLabel(), result_time, num_samples, NUM_JOBS) 39 | 40 | generators = [SDVDataGenerator.load(f'{RESULT_DIR}/genscale_{g}_{10000}.pkl') for g in ['copulagan', 'ctgan', 'gaussiancopula', 'tvae']] 41 | for g in generators: 42 | g.setParallelizationLevel(NUM_JOBS) 43 | 44 | NUM_TEST_SAMPLES = [10, 100, 1000, 10000, 100000, 1000000, 10000000] 45 | 46 | result_df = pd.DataFrame(columns=['model_label', 'generate_time', 'num_samples', 'num_threads']) 47 | 48 | for g in generators: 49 | _ = g.generate(100).cache().count() 50 | for n in NUM_TEST_SAMPLES: 51 | print(f'Generating with {g.getLabel()} {n} samples') 52 | result_df.loc[len(result_df)] = generate_time(g, n) 53 | 54 | old_df = None 55 | if os.path.isfile(f'{RESULT_DIR}/gens_sample_time.csv'): 56 | old_df = pd.read_csv(f'{RESULT_DIR}/gens_sample_time.csv') 57 | else: 58 | old_df = pd.DataFrame(columns=['model_label', 'generate_time', 'num_samples', 'num_threads']) 59 | 60 | pd.concat([old_df, result_df], ignore_index=True).to_csv(f'{RESULT_DIR}/gens_sample_time.csv', index=False) 61 | -------------------------------------------------------------------------------- /experiments/response_models/data/popular_items_popularity.parquet/.part-00000-8ac17189-ecfe-473d-a9c5-d72f7f86d119-c000.snappy.parquet.crc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sb-ai-lab/Sim4Rec/267b644932b0e120cc9887f0f3a16be38d5739b7/experiments/response_models/data/popular_items_popularity.parquet/.part-00000-8ac17189-ecfe-473d-a9c5-d72f7f86d119-c000.snappy.parquet.crc -------------------------------------------------------------------------------- /experiments/response_models/data/popular_items_popularity.parquet/_SUCCESS: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sb-ai-lab/Sim4Rec/267b644932b0e120cc9887f0f3a16be38d5739b7/experiments/response_models/data/popular_items_popularity.parquet/_SUCCESS -------------------------------------------------------------------------------- /experiments/response_models/data/popular_items_popularity.parquet/part-00000-8ac17189-ecfe-473d-a9c5-d72f7f86d119-c000.snappy.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sb-ai-lab/Sim4Rec/267b644932b0e120cc9887f0f3a16be38d5739b7/experiments/response_models/data/popular_items_popularity.parquet/part-00000-8ac17189-ecfe-473d-a9c5-d72f7f86d119-c000.snappy.parquet -------------------------------------------------------------------------------- /experiments/response_models/task_1_popular_items.py: -------------------------------------------------------------------------------- 1 | from pyspark.ml import PipelineModel 2 | from pyspark.sql import SparkSession 3 | 4 | import pyspark.sql.functions as sf 5 | 6 | from sim4rec.response import BernoulliResponse, ActionModelTransformer 7 | 8 | from response_models.utils import get_session 9 | 10 | 11 | class PopBasedTransformer(ActionModelTransformer): 12 | def __init__( 13 | self, 14 | spark: SparkSession, 15 | outputCol: str = None, 16 | pop_df_path: str = None, 17 | ): 18 | """ 19 
| :param outputCol: Name of the response probability column 20 | :param pop_df_path: path to a spark dataframe with items' popularity 21 | """ 22 | self.pop_df = sf.broadcast(spark.read.parquet(pop_df_path)) 23 | self.outputCol = outputCol 24 | 25 | def _transform(self, dataframe): 26 | return (dataframe 27 | .join(self.pop_df, on='item_idx') 28 | .drop(*set(self.pop_df.columns).difference(["item_idx", self.outputCol])) 29 | ) 30 | 31 | 32 | class TaskOneResponse: 33 | def __init__(self, spark, pop_df_path="./response_models/data/popular_items_popularity.parquet", seed=123): 34 | pop_resp = PopBasedTransformer(spark=spark, outputCol="popularity", pop_df_path=pop_df_path) 35 | br = BernoulliResponse(seed=seed, inputCol='popularity', outputCol='response') 36 | self.model = PipelineModel( 37 | stages=[pop_resp, br]) 38 | 39 | def transform(self, df): 40 | return self.model.transform(df).drop("popularity") 41 | 42 | 43 | if __name__ == '__main__': 44 | import pandas as pd 45 | spark = get_session() 46 | task = TaskOneResponse(spark, pop_df_path="./data/popular_items_popularity.parquet", seed=123) 47 | task.model.stages[0].pop_df.show() 48 | test_df = spark.createDataFrame(pd.DataFrame({"item_idx": [3, 5, 1], "user_idx": [5, 2, 1]})) 49 | task.transform(test_df).show() 50 | task.transform(test_df).show() 51 | task.transform(test_df).show() 52 | -------------------------------------------------------------------------------- /experiments/response_models/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import pyspark.sql.functions as sf 3 | from pyspark.sql import SparkSession, DataFrame 4 | from pyspark.ml import PipelineModel 5 | from sim4rec.response import BernoulliResponse, ActionModelTransformer 6 | from IPython.display import clear_output 7 | 8 | def get_session(num_threads=4) -> SparkSession: 9 | return SparkSession.builder \ 10 | .appName('simulator') \ 11 | .master(f'local[{num_threads}]') \ 12 | .config('spark.sql.shuffle.partitions', f'{num_threads * 3}') \ 13 | .config('spark.default.parallelism', f'{num_threads * 3}') \ 14 | .config('spark.driver.extraJavaOptions', '-XX:+UseG1GC') \ 15 | .config('spark.executor.extraJavaOptions', '-XX:+UseG1GC') \ 16 | .getOrCreate() 17 | 18 | 19 | def plot_metric(metrics): 20 | clear_output(wait=True) 21 | # plt.ylim(0, max(metrics) + 1) 22 | plt.plot(metrics) 23 | plt.grid() 24 | plt.xlabel('iteration') 25 | plt.ylabel('# of clicks') 26 | plt.show() 27 | 28 | def calc_metric(response_df): 29 | return (response_df 30 | .groupBy("user_idx").agg(sf.sum("response").alias("num_positive")) 31 | .select(sf.mean("num_positive")).collect()[0][0] 32 | ) 33 | 34 | class ResponseTransformer(ActionModelTransformer): 35 | 36 | def __init__( 37 | self, 38 | spark: SparkSession, 39 | outputCol: str = None, 40 | proba_df_path: str = None, 41 | ): 42 | """ 43 | Calculates users' response based on precomputed probability of item interaction 44 | 45 | :param outputCol: Name of the response probability column 46 | :param boost_df_path: path to a spark dataframe with precomputed user-item probability of interaction 47 | """ 48 | self.proba_df = spark.read.parquet(proba_df_path) 49 | self.outputCol = outputCol 50 | 51 | def _transform(self, dataframe): 52 | return (dataframe 53 | .join(self.proba_df, on=['item_idx', 'user_idx']) 54 | .drop(*set(self.proba_df.columns).difference(["user_idx", "item_idx", self.outputCol])) 55 | ) 56 | 
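# Minimal usage sketch (assumed example, not part of the original module):
# it shows how get_session, calc_metric and plot_metric fit together in a
# simulation loop. The toy response dataframe below is invented; in the
# experiments the responses come from the simulator's response models.
if __name__ == '__main__':
    import pandas as pd

    spark = get_session(num_threads=2)
    toy_responses = spark.createDataFrame(pd.DataFrame({
        "user_idx": [0, 0, 1, 1, 2],
        "item_idx": [10, 11, 10, 12, 13],
        "response": [1, 0, 1, 1, 0],
    }))

    clicks_history = []
    # average number of positive responses per user for this iteration
    clicks_history.append(calc_metric(toy_responses))
    # redraw the metric curve (clears the previous notebook output first)
    plot_metric(clicks_history)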
-------------------------------------------------------------------------------- /experiments/simulator_time.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings("ignore") 3 | 4 | import sys 5 | import time 6 | import random 7 | 8 | import pandas as pd 9 | import numpy as np 10 | 11 | from pyspark.sql import SparkSession 12 | from replay.session_handler import State 13 | from sim4rec.utils import pandas_to_spark 14 | from sim4rec.modules import SDVDataGenerator 15 | from pyspark.ml.classification import LogisticRegression 16 | from pyspark.ml.feature import VectorAssembler 17 | from pyspark.ml import PipelineModel 18 | from sim4rec.response import CosineSimilatiry, BernoulliResponse, NoiseResponse, ParametricResponseFunction 19 | from sim4rec.utils import VectorElementExtractor 20 | 21 | from replay.data_preparator import Indexer 22 | from replay.models import UCB 23 | 24 | NUM_JOBS = int(sys.argv[1]) 25 | 26 | SPARK_LOCAL_DIR = '/data/home/anthony/tmp' 27 | CHECKPOINT_DIR = '/data/home/anthony/tmp/checkpoints' 28 | MODELS_PATH = '../bin' 29 | 30 | spark = SparkSession.builder\ 31 | .appName('simulator_validation')\ 32 | .master(f'local[{NUM_JOBS}]')\ 33 | .config('spark.sql.shuffle.partitions', f'{NUM_JOBS}')\ 34 | .config('spark.default.parallelism', f'{NUM_JOBS}')\ 35 | .config('spark.driver.extraJavaOptions', '-XX:+UseG1GC')\ 36 | .config('spark.executor.extraJavaOptions', '-XX:+UseG1GC')\ 37 | .config('spark.sql.autoBroadcastJoinThreshold', '-1')\ 38 | .config('spark.driver.memory', '64g')\ 39 | .config('spark.local.dir', SPARK_LOCAL_DIR)\ 40 | .getOrCreate() 41 | 42 | State(spark) 43 | 44 | users_df = pd.DataFrame(data=np.random.normal(1, 1, size=(10000, 100)), columns=[f'user_attr_{i}' for i in range(100)]) 45 | items_df = pd.DataFrame(data=np.random.normal(-1, 1, size=(2000, 100)), columns=[f'item_attr_{i}' for i in range(100)]) 46 | items_df.loc[random.sample(range(2000), 1000)] = np.random.normal(1, 1, size=(1000, 100)) 47 | users_df['user_id'] = np.arange(len(users_df)) 48 | items_df['item_id'] = np.arange(len(items_df)) 49 | history_df_all = pd.DataFrame() 50 | history_df_all['user_id'] = np.random.randint(0, 10000, size=33000) 51 | history_df_all['item_id'] = np.random.randint(0, 2000, size=33000) 52 | history_df_all['relevance'] = 0 53 | 54 | users_matrix = users_df.values[history_df_all.values[:, 0], :-1] 55 | items_matrix = items_df.values[history_df_all.values[:, 1], :-1] 56 | dot = np.sum(users_matrix * items_matrix, axis=1) 57 | history_df_all['relevance'] = np.where(dot >= 0.5, 1, 0) 58 | history_df_all = history_df_all.drop_duplicates(subset=['user_id', 'item_id'], ignore_index=True) 59 | 60 | history_df_train = history_df_all.iloc[:30000] 61 | history_df_val = history_df_all.iloc[30000:] 62 | 63 | users_df = pandas_to_spark(users_df) 64 | items_df = pandas_to_spark(items_df) 65 | history_df_train = pandas_to_spark(history_df_train) 66 | history_df_val = pandas_to_spark(history_df_val) 67 | 68 | user_generator = SDVDataGenerator.load(f'{MODELS_PATH}/cycle_scale_users_gen.pkl') 69 | item_generator = SDVDataGenerator.load(f'{MODELS_PATH}/cycle_scale_items_gen.pkl') 70 | 71 | user_generator.setDevice('cpu') 72 | item_generator.setDevice('cpu') 73 | user_generator.setParallelizationLevel(NUM_JOBS) 74 | item_generator.setParallelizationLevel(NUM_JOBS) 75 | 76 | syn_users = user_generator.generate(1000000).cache() 77 | syn_items = item_generator.generate(10000).cache() 78 | 79 | va_users_items = 
VectorAssembler( 80 | inputCols=users_df.columns[:-1] + items_df.columns[:-1], 81 | outputCol='features' 82 | ) 83 | 84 | lr = LogisticRegression( 85 | featuresCol='features', 86 | labelCol='relevance', 87 | probabilityCol='__lr_prob' 88 | ) 89 | 90 | vee = VectorElementExtractor(inputCol='__lr_prob', outputCol='__lr_prob', index=1) 91 | 92 | lr_train_df = history_df_train\ 93 | .join(users_df, 'user_id', 'left')\ 94 | .join(items_df, 'item_id', 'left') 95 | 96 | lr_model = lr.fit(va_users_items.transform(lr_train_df)) 97 | 98 | 99 | va_users = VectorAssembler( 100 | inputCols=users_df.columns[:-1], 101 | outputCol='features_usr' 102 | ) 103 | 104 | va_items = VectorAssembler( 105 | inputCols=items_df.columns[:-1], 106 | outputCol='features_itm' 107 | ) 108 | 109 | cos_sim = CosineSimilatiry( 110 | inputCols=["features_usr", "features_itm"], 111 | outputCol="__cos_prob" 112 | ) 113 | 114 | noise_resp = NoiseResponse(mu=0.5, sigma=0.2, outputCol='__noise_prob', seed=1234) 115 | 116 | param_resp = ParametricResponseFunction( 117 | inputCols=['__lr_prob', '__cos_prob', '__noise_prob'], 118 | outputCol='__proba', 119 | weights=[1/3, 1/3, 1/3] 120 | ) 121 | 122 | br = BernoulliResponse(inputCol='__proba', outputCol='response') 123 | 124 | pipeline = PipelineModel( 125 | stages=[ 126 | va_users_items, 127 | lr_model, 128 | vee, 129 | va_users, 130 | va_items, 131 | cos_sim, 132 | noise_resp, 133 | param_resp, 134 | br 135 | ] 136 | ) 137 | 138 | from sim4rec.modules import Simulator, EvaluateMetrics 139 | from replay.metrics import NDCG 140 | 141 | sim = Simulator( 142 | user_gen=user_generator, 143 | item_gen=item_generator, 144 | user_key_col='user_id', 145 | item_key_col='item_id', 146 | spark_session=spark, 147 | data_dir=f'{CHECKPOINT_DIR}/cycle_load_test_{NUM_JOBS}', 148 | ) 149 | 150 | evaluator = EvaluateMetrics( 151 | userKeyCol='user_id', 152 | itemKeyCol='item_id', 153 | predictionCol='relevance', 154 | labelCol='response', 155 | replay_label_filter=1.0, 156 | replay_metrics={NDCG() : 100} 157 | ) 158 | 159 | indexer = Indexer(user_col='user_id', item_col='item_id') 160 | indexer.fit(users=syn_users, items=syn_items) 161 | 162 | ucb = UCB(sample=True) 163 | ucb.fit(log=indexer.transform(history_df_train.limit(1))) 164 | 165 | items_replay = indexer.transform(syn_items).cache() 166 | 167 | ucb_metrics = [] 168 | 169 | time_list = [] 170 | for i in range(30): 171 | cycle_time = {} 172 | iter_start = time.time() 173 | 174 | start = time.time() 175 | users = sim.sample_users(0.02).cache() 176 | users.count() 177 | cycle_time['sample_users_time'] = time.time() - start 178 | 179 | start = time.time() 180 | log = sim.get_log(users) 181 | if log is not None: 182 | log = indexer.transform(log).cache() 183 | else: 184 | log = indexer.transform(history_df_train.limit(1)).cache() 185 | log.count() 186 | cycle_time['get_log_time'] = time.time() - start 187 | 188 | start = time.time() 189 | recs_ucb = ucb.predict( 190 | log=log, 191 | k=100, 192 | users=indexer.transform(users), 193 | items=items_replay 194 | ) 195 | recs_ucb = indexer.inverse_transform(recs_ucb).cache() 196 | recs_ucb.count() 197 | cycle_time['model_predict_time'] = time.time() - start 198 | 199 | start = time.time() 200 | resp_ucb = sim.sample_responses( 201 | recs_df=recs_ucb, 202 | user_features=users, 203 | item_features=syn_items, 204 | action_models=pipeline 205 | ).select('user_id', 'item_id', 'relevance', 'response').cache() 206 | resp_ucb.count() 207 | cycle_time['sample_responses_time'] = time.time() - start 208 | 209 | 
start = time.time() 210 | sim.update_log(resp_ucb, iteration=i) 211 | cycle_time['update_log_time'] = time.time() - start 212 | 213 | start = time.time() 214 | ucb_metrics.append(evaluator(resp_ucb)) 215 | cycle_time['metrics_time'] = time.time() - start 216 | 217 | start = time.time() 218 | ucb._clear_cache() 219 | ucb_train_log = sim.log.cache() 220 | cycle_time['log_size'] = ucb_train_log.count() 221 | ucb.fit( 222 | log=indexer.transform( 223 | ucb_train_log\ 224 | .select('user_id', 'item_id', 'response')\ 225 | .withColumnRenamed('response', 'relevance') 226 | ) 227 | ) 228 | cycle_time['model_train'] = time.time() - start 229 | 230 | users.unpersist() 231 | if log is not None: 232 | log.unpersist() 233 | recs_ucb.unpersist() 234 | resp_ucb.unpersist() 235 | ucb_train_log.unpersist() 236 | 237 | cycle_time['iter_time'] = time.time() - iter_start 238 | cycle_time['iteration'] = i 239 | cycle_time['num_threads'] = NUM_JOBS 240 | 241 | time_list.append(cycle_time) 242 | 243 | print(f'Iteration {i} ended in {cycle_time["iter_time"]} seconds') 244 | 245 | items_replay.unpersist() 246 | 247 | import os 248 | 249 | if os.path.isfile(f'{MODELS_PATH}/cycle_time.csv'): 250 | pd.concat([pd.read_csv(f'{MODELS_PATH}/cycle_time.csv'), pd.DataFrame(time_list)]).to_csv(f'{MODELS_PATH}/cycle_time.csv', index=False) 251 | else: 252 | pd.DataFrame(time_list).to_csv(f'{MODELS_PATH}/cycle_time.csv', index=False) 253 | -------------------------------------------------------------------------------- /experiments/transformers.py: -------------------------------------------------------------------------------- 1 | from sim4rec.response import ActionModelTransformer 2 | from pyspark.ml.param.shared import HasInputCol 3 | from pyspark.sql import DataFrame 4 | import pyspark.sql.functions as sf 5 | from pyspark.ml.param.shared import Params, Param, TypeConverters 6 | from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable 7 | 8 | class HasMultiplierValue(Params): 9 | multiplierValue = Param( 10 | Params._dummy(), 11 | "multiplierValue", 12 | "Multiplier value parameter", 13 | typeConverter=TypeConverters.toFloat 14 | ) 15 | 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def setMultiplierValue(self, value): 20 | return self._set(multiplierValue=value) 21 | 22 | def getMultiplierValue(self): 23 | return self.getOrDefault(self.multiplierValue) 24 | 25 | class ModelCalibration(ActionModelTransformer, 26 | HasInputCol, 27 | HasMultiplierValue, 28 | DefaultParamsReadable, 29 | DefaultParamsWritable): 30 | def __init__( 31 | self, 32 | value : float = 0.0, 33 | inputCol : str = None, 34 | outputCol : str = None 35 | ): 36 | """ 37 | Multiplies response function output by the chosen value. 
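:param inputCol: Input column name with the response scores to be multiplied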
38 | :param value: Multiplier value 39 | :param outputCol: Output column name 40 | """ 41 | 42 | super().__init__(outputCol=outputCol) 43 | 44 | self._set(inputCol=inputCol) 45 | self._set(multiplierValue=value) 46 | 47 | def _transform( 48 | self, 49 | df : DataFrame 50 | ): 51 | value = self.getMultiplierValue() 52 | inputCol = self.getInputCol() 53 | outputColumn = self.getOutputCol() 54 | 55 | 56 | return df.withColumn(outputColumn, sf.lit(value)*sf.col(inputCol)) 57 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "sim4rec" 3 | version = "0.0.2" 4 | description = "Simulator for recommendation algorithms" 5 | authors = ["Alexey Vasilev", 6 | "Anna Volodkevich", 7 | "Andrey Gurov", 8 | "Elizaveta Stavinova", 9 | "Anton Lysenko"] 10 | packages = [ 11 | { include = "sim4rec" } 12 | ] 13 | readme = "README.md" 14 | repository = "https://github.com/sb-ai-lab/Sim4Rec" 15 | 16 | [tool.poetry.dependencies] 17 | python = ">=3.8, <3.10" 18 | pyarrow = "*" 19 | sdv = "0.15.0" 20 | torch = "*" 21 | pandas = "*" 22 | pyspark = ">=3.0" 23 | numpy = ">=1.20.0" 24 | scipy = "*" 25 | 26 | [tool.poetry.dev-dependencies] 27 | # visualization 28 | jupyter = "*" 29 | jupyterlab = "*" 30 | matplotlib = "*" 31 | # testing 32 | pytest-cov = "*" 33 | pycodestyle = "*" 34 | pylint = "*" 35 | # docs 36 | Sphinx = "*" 37 | sphinx-rtd-theme = "*" 38 | sphinx-autodoc-typehints = "*" 39 | ghp-import = "*" 40 | 41 | [build-system] 42 | requires = ["poetry-core>=1.0.0"] 43 | build-backend = "poetry.core.masonry.api" 44 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | 2 | [pytest] 3 | minversion = 6.0 4 | log_cli = 1 5 | log_cli_level = ERROR 6 | log_cli_format = %(asctime)s [%(levelname)8s] [%(name)s] %(message)s (%(filename)s:%(lineno)s) 7 | log_cli_date_format = %H:%M:%S 8 | 9 | addopts = 10 | --capture=tee-sys -q 11 | -m 'not mycandidate' 12 | 13 | testpaths = 14 | tests 15 | filterwarnings = 16 | ignore::DeprecationWarning 17 | markers = 18 | mycandidate: test mycandidate api integration (disabled by default) -------------------------------------------------------------------------------- /sim4rec/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sb-ai-lab/Sim4Rec/267b644932b0e120cc9887f0f3a16be38d5739b7/sim4rec/__init__.py -------------------------------------------------------------------------------- /sim4rec/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .generator import ( 2 | GeneratorBase, 3 | RealDataGenerator, 4 | SDVDataGenerator, 5 | CompositeGenerator 6 | ) 7 | from .selectors import ( 8 | ItemSelectionEstimator, 9 | ItemSelectionTransformer, 10 | CrossJoinItemEstimator, 11 | CrossJoinItemTransformer 12 | ) 13 | from .simulator import Simulator 14 | from .embeddings import ( 15 | EncoderEstimator, 16 | EncoderTransformer 17 | ) 18 | from .evaluation import ( 19 | evaluate_synthetic, 20 | EvaluateMetrics, 21 | ks_test, 22 | kl_divergence 23 | ) 24 | 25 | __all__ = [ 26 | 'GeneratorBase', 27 | 'RealDataGenerator', 28 | 'SDVDataGenerator', 29 | 'CompositeGenerator', 30 | 'ItemSelectionEstimator', 31 | 'ItemSelectionTransformer', 32 | 'CrossJoinItemEstimator', 33 | 
'CrossJoinItemTransformer', 34 | 'Simulator', 35 | 'EncoderEstimator', 36 | 'EncoderTransformer', 37 | 'evaluate_synthetic', 38 | 'EvaluateMetrics', 39 | 'ks_test', 40 | 'kl_divergence' 41 | ] 42 | -------------------------------------------------------------------------------- /sim4rec/modules/embeddings.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pandas as pd 4 | import torch 5 | import torch.optim as opt 6 | import torch.nn.functional as F 7 | from torch.utils.data import DataLoader 8 | 9 | import pyspark.sql.types as st 10 | from pyspark.sql import DataFrame 11 | from pyspark.ml import Transformer, Estimator 12 | from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable 13 | from pyspark.ml.param.shared import HasInputCols, HasOutputCols 14 | 15 | from sim4rec.params import HasDevice, HasSeed 16 | 17 | 18 | class Encoder(torch.nn.Module): 19 | """ 20 | Encoder layer 21 | """ 22 | def __init__( 23 | self, 24 | input_dim : int, 25 | hidden_dim : int, 26 | latent_dim : int 27 | ): 28 | super().__init__() 29 | 30 | input_dims = [input_dim, hidden_dim, latent_dim] 31 | self._layers = torch.nn.ModuleList([ 32 | torch.nn.Linear(_in, _out) 33 | for _in, _out in zip(input_dims[:-1], input_dims[1:]) 34 | ]) 35 | 36 | def forward(self, X): 37 | """ 38 | Performs forward pass through layer 39 | """ 40 | X = F.normalize(X, p=2) 41 | for layer in self._layers[:-1]: 42 | X = layer(X) 43 | X = F.leaky_relu(X) 44 | 45 | X = self._layers[-1](X) 46 | 47 | return X 48 | 49 | 50 | class Decoder(torch.nn.Module): 51 | """ 52 | Decoder layer 53 | """ 54 | def __init__( 55 | self, 56 | input_dim : int, 57 | hidden_dim : int, 58 | latent_dim : int 59 | ): 60 | super().__init__() 61 | 62 | input_dims = [latent_dim, hidden_dim, input_dim] 63 | self._layers = torch.nn.ModuleList([ 64 | torch.nn.Linear(_in, _out) 65 | for _in, _out in zip(input_dims[:-1], input_dims[1:]) 66 | ]) 67 | 68 | def forward(self, X): 69 | """ 70 | Performs forward pass through layer 71 | """ 72 | for layer in self._layers[:-1]: 73 | X = layer(X) 74 | X = F.leaky_relu(X) 75 | 76 | X = self._layers[-1](X) 77 | 78 | return X 79 | 80 | 81 | # pylint: disable=too-many-ancestors 82 | class EncoderEstimator(Estimator, 83 | HasInputCols, 84 | HasOutputCols, 85 | HasDevice, 86 | HasSeed, 87 | DefaultParamsReadable, 88 | DefaultParamsWritable): 89 | """ 90 | Estimator for encoder part of the autoencoder pipeline. Trains 91 | the encoder to process incoming data into latent representation 92 | """ 93 | # pylint: disable=too-many-arguments 94 | def __init__( 95 | self, 96 | inputCols : List[str], 97 | outputCols : List[str], 98 | hidden_dim : int, 99 | lr : float, 100 | batch_size : int, 101 | num_loader_workers : int, 102 | max_iter : int = 100, 103 | device_name : str = 'cpu', 104 | seed : int = None 105 | ): 106 | """ 107 | :param inputCols: Column names to process 108 | :param outputCols: List of output column names per latent coordinate. 
109 | The length of outputCols will determine the embedding dimension size 110 | :param hidden_dim: Size of hidden layers 111 | :param lr: Learning rate 112 | :param batch_size: Batch size during training process 113 | :param num_loader_workers: Number of cpus to use for data loader 114 | :param max_iter: Maximum number of iterations, defaults to 100 115 | :param device_name: PyTorch device name, defaults to 'cpu' 116 | """ 117 | 118 | super().__init__() 119 | 120 | self._set(inputCols=inputCols, outputCols=outputCols) 121 | self.setDevice(device_name) 122 | self.setSeed(seed) 123 | 124 | self._input_dim = len(inputCols) 125 | self._hidden_dim = hidden_dim 126 | self._latent_dim = len(outputCols) 127 | 128 | self._lr = lr 129 | self._batch_size = batch_size 130 | self._num_loader_workers = num_loader_workers 131 | self._max_iter = max_iter 132 | 133 | # pylint: disable=too-many-locals, not-callable 134 | def _fit( 135 | self, 136 | dataset : DataFrame 137 | ): 138 | inputCols = self.getInputCols() 139 | outputCols = self.getOutputCols() 140 | device_name = self.getDevice() 141 | seed = self.getSeed() 142 | device = torch.device(self.getDevice()) 143 | # pylint: disable=not-an-iterable 144 | X = dataset.select(*inputCols).toPandas().values 145 | 146 | torch.manual_seed(torch.seed() if seed is None else seed) 147 | 148 | train_loader = DataLoader(X, batch_size=self._batch_size, 149 | shuffle=True, num_workers=self._num_loader_workers) 150 | 151 | encoder = Encoder( 152 | input_dim=self._input_dim, 153 | hidden_dim=self._hidden_dim, 154 | latent_dim=self._latent_dim 155 | ) 156 | decoder = Decoder( 157 | input_dim=self._input_dim, 158 | hidden_dim=self._hidden_dim, 159 | latent_dim=self._latent_dim 160 | ) 161 | 162 | model = torch.nn.Sequential(encoder, decoder).to(torch.device(self.getDevice())) 163 | 164 | optimizer = opt.Adam(model.parameters(), lr=self._lr) 165 | crit = torch.nn.MSELoss() 166 | 167 | for _ in range(self._max_iter): 168 | loss = 0 169 | for X_batch in train_loader: 170 | X_batch = X_batch.float().to(device) 171 | 172 | optimizer.zero_grad() 173 | 174 | pred = model(X_batch) 175 | train_loss = crit(pred, X_batch) 176 | 177 | train_loss.backward() 178 | optimizer.step() 179 | 180 | loss += train_loss.item() 181 | 182 | torch.manual_seed(torch.seed()) 183 | 184 | return EncoderTransformer( 185 | inputCols=inputCols, 186 | outputCols=outputCols, 187 | encoder=encoder, 188 | device_name=device_name 189 | ) 190 | 191 | 192 | class EncoderTransformer(Transformer, 193 | HasInputCols, 194 | HasOutputCols, 195 | HasDevice, 196 | DefaultParamsReadable, 197 | DefaultParamsWritable): 198 | """ 199 | Encoder transformer that transforms incoming columns into latent 200 | representation. Output data will be appended to dataframe and 201 | named according to outputCols parameter 202 | """ 203 | def __init__( 204 | self, 205 | inputCols : List[str], 206 | outputCols : List[str], 207 | encoder : Encoder, 208 | device_name : str = 'cpu' 209 | ): 210 | """ 211 | :param inputCols: Column names to process 212 | :param outputCols: List of output column names per latent coordinate. 
213 | The length of outputCols must be equal to embedding dimension of 214 | a trained encoder 215 | :param encoder: Trained encoder 216 | :param device_name: PyTorch device name, defaults to 'cpu' 217 | """ 218 | 219 | super().__init__() 220 | 221 | self._set(inputCols=inputCols, outputCols=outputCols) 222 | self._encoder = encoder 223 | self.setDevice(device_name) 224 | 225 | def setDevice(self, value): 226 | super().setDevice(value) 227 | 228 | self._encoder.to(torch.device(value)) 229 | 230 | # pylint: disable=not-callable 231 | def _transform( 232 | self, 233 | dataset : DataFrame 234 | ): 235 | inputCols = self.getInputCols() 236 | outputCols = self.getOutputCols() 237 | device = torch.device(self.getDevice()) 238 | 239 | encoder = self._encoder 240 | 241 | @torch.no_grad() 242 | def encode(iterator): 243 | for pdf in iterator: 244 | X = torch.tensor(pdf.loc[:, inputCols].values).float().to(device) 245 | yield pd.DataFrame( 246 | data=encoder(X).cpu().numpy(), 247 | columns=outputCols 248 | ) 249 | 250 | schema = st.StructType( 251 | # pylint: disable=not-an-iterable 252 | [st.StructField(c, st.FloatType()) for c in outputCols] 253 | ) 254 | 255 | return dataset.mapInPandas(encode, schema) 256 | -------------------------------------------------------------------------------- /sim4rec/modules/evaluation.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from typing import List, Union, Dict, Optional 3 | 4 | import numpy as np 5 | from scipy.stats import kstest 6 | # TL;DR scipy.special is a C library, pylint needs python source code 7 | # https://github.com/pylint-dev/pylint/issues/3703 8 | # pylint: disable=no-name-in-module 9 | from scipy.special import kl_div 10 | 11 | from pyspark.sql import DataFrame 12 | from pyspark.ml.evaluation import ( 13 | BinaryClassificationEvaluator, 14 | RegressionEvaluator, 15 | MulticlassClassificationEvaluator 16 | ) 17 | 18 | from sdv.evaluation import evaluate 19 | 20 | 21 | def evaluate_synthetic( 22 | synth_df : DataFrame, 23 | real_df : DataFrame 24 | ) -> dict: 25 | """ 26 | Evaluates the quality of synthetic data against real. 
The following 27 | metrics will be calculated: 28 | 29 | - LogisticDetection: The metric evaluates how hard it is to distinguish the synthetic 30 | data from the real data by using a Logistic regression model 31 | - SVCDetection: The metric evaluates how hard it is to distinguish the synthetic data 32 | from the real data by using a C-Support Vector Classification model 33 | - KSTest: This metric uses the two-sample Kolmogorov-Smirnov test to compare 34 | the distributions of continuous columns using the empirical CDF 35 | - ContinuousKLDivergence: This approximates the KL divergence by binning the continuous values 36 | to turn them into categorical values and then computing the relative entropy 37 | 38 | :param synth_df: Synthetic data without any identifiers 39 | :param real_df: Real data without any identifiers 40 | :return: Dictionary with metrics on synthetic data quality 41 | """ 42 | 43 | result = evaluate( 44 | synthetic_data=synth_df.toPandas(), 45 | real_data=real_df.toPandas(), 46 | metrics=[ 47 | 'LogisticDetection', 48 | 'SVCDetection', 49 | 'KSTest', 50 | 'ContinuousKLDivergence' 51 | ], 52 | aggregate=False 53 | ) 54 | 55 | return { 56 | row['metric'] : row['normalized_score'] 57 | for _, row in result.iterrows() 58 | } 59 | 60 | 61 | def ks_test( 62 | df : DataFrame, 63 | predCol : str, 64 | labelCol : str 65 | ) -> float: 66 | """ 67 | Kolmogorov-Smirnov test on two dataframe columns 68 | 69 | :param df: Dataframe with two target columns 70 | :param predCol: Column name with values to test 71 | :param labelCol: Column name with values to test against 72 | :return: Result of KS test 73 | """ 74 | 75 | pdf = df.select(predCol, labelCol).toPandas() 76 | rvs, cdf = pdf[predCol].values, pdf[labelCol].values 77 | 78 | return kstest(rvs, cdf).statistic 79 | 80 | 81 | def kl_divergence( 82 | df : DataFrame, 83 | predCol : str, 84 | labelCol : str 85 | ) -> float: 86 | """ 87 | Normalized Kullback–Leibler divergence on two dataframe columns. The normalization is 88 | as follows: 89 | 90 | .. math:: 91 | \\frac{1}{1 + KL\_div} 92 | 93 | :param df: Dataframe with two target columns 94 | :param predCol: First column name 95 | :param labelCol: Second column name 96 | :return: Result of KL divergence 97 | """ 98 | 99 | pdf = df.select(predCol, labelCol).toPandas() 100 | predicted, ground_truth = pdf[predCol].values, pdf[labelCol].values 101 | 102 | f_obs, edges = np.histogram(ground_truth) 103 | f_exp, _ = np.histogram(predicted, bins=edges) 104 | 105 | f_obs = f_obs.flatten() + 1e-5 106 | f_exp = f_exp.flatten() + 1e-5 107 | 108 | return 1 / (1 + np.sum(kl_div(f_obs, f_exp))) 109 | 110 | 111 | # pylint: disable=too-few-public-methods 112 | class EvaluateMetrics(ABC): 113 | """ 114 | Recommendation systems and response function metric evaluator class. 115 | The class allows you to evaluate the quality of a response function on 116 | historical data or a recommender system on historical data or based on 117 | the results of an experiment in a simulator. Provides simultaneous 118 | calculation of several metrics using metrics from the Spark MLlib library. 119 | A created instance is callable on a dataframe with ``user_id, item_id, 120 | predicted relevance/response, true relevance/response`` format, which 121 | you can usually retrieve from simulators sample_responses() or log data 122 | with recommendation algorithm scores. 
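For illustration (the column names here are assumptions, matching those used elsewhere in this repository): ``EvaluateMetrics('user_id', 'item_id', 'relevance', 'response', mllib_metrics=['areaUnderROC'])(responses_df)`` returns a dictionary mapping each requested metric name to its value.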
123 | """ 124 | 125 | REGRESSION_METRICS = set(['rmse', 'mse', 'r2', 'mae', 'var']) 126 | MULTICLASS_METRICS = set([ 127 | 'f1', 'accuracy', 'weightedPrecision', 'weightedRecall', 128 | 'weightedTruePositiveRate', 'weightedFalsePositiveRate', 129 | 'weightedFMeasure', 'truePositiveRateByLabel', 'falsePositiveRateByLabel', 130 | 'precisionByLabel', 'recallByLabel', 'fMeasureByLabel', 131 | 'logLoss', 'hammingLoss' 132 | ]) 133 | BINARY_METRICS = set(['areaUnderROC', 'areaUnderPR']) 134 | 135 | # pylint: disable=too-many-arguments 136 | def __init__( 137 | self, 138 | userKeyCol : str, 139 | itemKeyCol : str, 140 | predictionCol : str, 141 | labelCol : str, 142 | mllib_metrics : Optional[Union[str, List[str]]] = None 143 | ): 144 | """ 145 | :param userKeyCol: User identifier column name 146 | :param itemKeyCol: Item identifier column name 147 | :param predictionCol: Predicted scores column name 148 | :param labelCol: True label column name 149 | :param mllib_metrics: Metrics to calculate from spark's mllib. See 150 | REGRESSION_METRICS, MULTICLASS_METRICS, BINARY_METRICS for available 151 | values, defaults to None 152 | """ 153 | 154 | super().__init__() 155 | 156 | self._userKeyCol = userKeyCol 157 | self._itemKeyCol = itemKeyCol 158 | self._predictionCol = predictionCol 159 | self._labelCol = labelCol 160 | 161 | if isinstance(mllib_metrics, str): 162 | mllib_metrics = [mllib_metrics] 163 | 164 | if mllib_metrics is None: 165 | mllib_metrics = [] 166 | 167 | self._mllib_metrics = mllib_metrics 168 | 169 | def __call__( 170 | self, 171 | df : DataFrame 172 | ) -> Dict[str, float]: 173 | """ 174 | Performs metrics calculations on passed dataframe 175 | 176 | :param df: Spark dataframe with userKeyCol, itemKeyCol, predictionCol 177 | and labelCol columns 178 | :return: Dictionary with metrics 179 | """ 180 | 181 | df = df.withColumnRenamed(self._userKeyCol, 'user_idx')\ 182 | .withColumnRenamed(self._itemKeyCol, 'item_idx') 183 | 184 | result = {} 185 | 186 | for m in self._mllib_metrics: 187 | evaluator = self._get_evaluator(m) 188 | result[m] = evaluator.evaluate(df) 189 | 190 | return result 191 | 192 | def _reg_or_multiclass_params(self): 193 | return {'predictionCol' : self._predictionCol, 'labelCol' : self._labelCol} 194 | 195 | def _binary_params(self): 196 | return {'rawPredictionCol' : self._predictionCol, 'labelCol' : self._labelCol} 197 | 198 | def _get_evaluator(self, metric): 199 | if metric in self.REGRESSION_METRICS: 200 | return RegressionEvaluator( 201 | metricName=metric, **self._reg_or_multiclass_params()) 202 | if metric in self.BINARY_METRICS: 203 | return BinaryClassificationEvaluator( 204 | metricName=metric, **self._binary_params()) 205 | if metric in self.MULTICLASS_METRICS: 206 | return MulticlassClassificationEvaluator( 207 | metricName=metric, **self._reg_or_multiclass_params()) 208 | 209 | raise ValueError(f'Non existing metric was passed: {metric}') 210 | -------------------------------------------------------------------------------- /sim4rec/modules/selectors.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=no-member,unused-argument,too-many-ancestors,abstract-method 2 | from pyspark.sql import functions as sf 3 | from pyspark.sql import DataFrame 4 | from pyspark.ml import Transformer, Estimator 5 | from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable 6 | from pyspark import keyword_only 7 | from sim4rec.params import HasUserKeyColumn, HasItemKeyColumn, HasSeed, HasSeedSequence 8 | 
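# Editorial usage sketch (not part of the original module; the column names and
# k below are illustrative assumptions): the estimator defined further down is
# fitted on an items dataframe, and the resulting transformer assigns k random
# items to every user:
#
#     selector = CrossJoinItemEstimator(k=10,
#                                       userKeyColumn='user_idx',
#                                       itemKeyColumn='item_idx',
#                                       seed=42)
#     pairs = selector.fit(items_df).transform(users_df)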
9 | 10 | class ItemSelectionEstimator(Estimator, 11 | HasUserKeyColumn, 12 | HasItemKeyColumn, 13 | DefaultParamsReadable, 14 | DefaultParamsWritable): 15 | """ 16 | Base class for item selection estimator 17 | """ 18 | @keyword_only 19 | def __init__( 20 | self, 21 | userKeyColumn : str = None, 22 | itemKeyColumn : str = None 23 | ): 24 | super().__init__() 25 | self.setParams(**self._input_kwargs) 26 | 27 | @keyword_only 28 | def setParams( 29 | self, 30 | userKeyColumn : str = None, 31 | itemKeyColumn : str = None 32 | ): 33 | """ 34 | Sets Estimator parameters 35 | """ 36 | return self._set(**self._input_kwargs) 37 | 38 | 39 | class ItemSelectionTransformer(Transformer, 40 | HasUserKeyColumn, 41 | HasItemKeyColumn, 42 | DefaultParamsReadable, 43 | DefaultParamsWritable): 44 | """ 45 | Base class for item selection transformer. transform() 46 | will be used to create user-item pairs 47 | """ 48 | @keyword_only 49 | def __init__( 50 | self, 51 | userKeyColumn : str = None, 52 | itemKeyColumn : str = None 53 | ): 54 | super().__init__() 55 | self.setParams(**self._input_kwargs) 56 | 57 | @keyword_only 58 | def setParams( 59 | self, 60 | userKeyColumn : str = None, 61 | itemKeyColumn : str = None 62 | ): 63 | """ 64 | Sets Transformer parameters 65 | """ 66 | self._set(**self._input_kwargs) 67 | 68 | 69 | class CrossJoinItemEstimator(ItemSelectionEstimator, HasSeed): 70 | """ 71 | Assigns k items for every user from random items subsample 72 | """ 73 | def __init__( 74 | self, 75 | k : int, 76 | userKeyColumn : str = None, 77 | itemKeyColumn : str = None, 78 | seed : int = None 79 | ): 80 | """ 81 | :param k: Number of items for every user 82 | :param userKeyColumn: Users identifier column, defaults to None 83 | :param itemKeyColumn: Items identifier column, defaults to None 84 | :param seed: Random state seed, defaults to None 85 | """ 86 | 87 | super().__init__(userKeyColumn=userKeyColumn, 88 | itemKeyColumn=itemKeyColumn) 89 | 90 | self.setSeed(seed) 91 | 92 | self._k = k 93 | 94 | def _fit( 95 | self, 96 | dataset : DataFrame 97 | ): 98 | """ 99 | Fits estimator with items dataframe 100 | 101 | :param df: Items dataframe 102 | :returns: CrossJoinItemTransformer instance 103 | """ 104 | 105 | userKeyColumn = self.getUserKeyColumn() 106 | itemKeyColumn = self.getItemKeyColumn() 107 | seed = self.getSeed() 108 | 109 | if itemKeyColumn not in dataset.columns: 110 | raise ValueError(f'Dataframe has no {itemKeyColumn} column') 111 | 112 | return CrossJoinItemTransformer( 113 | item_df=dataset, 114 | k=self._k, 115 | userKeyColumn=userKeyColumn, 116 | itemKeyColumn=itemKeyColumn, 117 | seed=seed 118 | ) 119 | 120 | 121 | class CrossJoinItemTransformer(ItemSelectionTransformer, HasSeedSequence): 122 | """ 123 | Assigns k items for every user from random items subsample 124 | """ 125 | # pylint: disable=too-many-arguments 126 | def __init__( 127 | self, 128 | item_df : DataFrame, 129 | k : int, 130 | userKeyColumn : str = None, 131 | itemKeyColumn : str = None, 132 | seed : int = None 133 | ): 134 | super().__init__(userKeyColumn=userKeyColumn, 135 | itemKeyColumn=itemKeyColumn) 136 | 137 | self.initSeedSequence(seed) 138 | 139 | self._item_df = item_df 140 | self._k = k 141 | 142 | def _transform( 143 | self, 144 | dataset : DataFrame 145 | ): 146 | """ 147 | Takes a users dataframe and assings defined number of items 148 | 149 | :param df: Users dataframe 150 | :returns: Users cross join on random items subsample 151 | """ 152 | 153 | userKeyColumn = self.getUserKeyColumn() 154 | 
itemKeyColumn = self.getItemKeyColumn() 155 | seed = self.getNextSeed() 156 | 157 | if userKeyColumn not in dataset.columns: 158 | raise ValueError(f'Dataframe has no {userKeyColumn} column') 159 | 160 | random_items = self._item_df.orderBy(sf.rand(seed=seed))\ 161 | .limit(self._k) 162 | 163 | return dataset.select(userKeyColumn)\ 164 | .crossJoin(random_items.select(itemKeyColumn)) 165 | -------------------------------------------------------------------------------- /sim4rec/modules/simulator.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from abc import ABC 3 | from typing import Tuple, Union, Optional 4 | 5 | from pyspark.sql import SparkSession 6 | from pyspark.sql import DataFrame 7 | from pyspark.ml import Transformer, PipelineModel 8 | 9 | from sim4rec.utils.session_handler import State 10 | from sim4rec.modules.generator import GeneratorBase 11 | 12 | 13 | # pylint: disable=too-many-instance-attributes 14 | class Simulator(ABC): 15 | """ 16 | Simulator for recommendation systems, which uses the users 17 | and items data passed to it, to simulate the users responses 18 | to recommended items 19 | """ 20 | 21 | ITER_COLUMN = '__iter' 22 | DEFAULT_LOG_FILENAME = 'log.parquet' 23 | 24 | # pylint: disable=too-many-arguments 25 | def __init__( 26 | self, 27 | user_gen : GeneratorBase, 28 | item_gen : GeneratorBase, 29 | data_dir : str, 30 | log_df : DataFrame = None, 31 | user_key_col : str = 'user_idx', 32 | item_key_col : str = 'item_idx', 33 | spark_session : SparkSession = None 34 | ): 35 | """ 36 | :param user_gen: Users data generator instance 37 | :param item_gen: Items data generator instance 38 | :param log_df: The history log with user-item pairs with other 39 | necessary fields. 
During the simulation the results will be 40 | appended to this log on update_log() call, defaults to None 41 | :param user_key_col: User identifier column name, defaults 42 | to 'user_idx' 43 | :param item_key_col: Item identifier column name, defaults 44 | to 'item_idx' 45 | :param data_dir: Directory name to save simulator data 46 | :param spark_session: Spark session to use, defaults to None 47 | """ 48 | 49 | self._spark = spark_session if spark_session is not None else State().session 50 | 51 | self._user_key_col = user_key_col 52 | self._item_key_col = item_key_col 53 | self._user_gen = user_gen 54 | self._item_gen = item_gen 55 | 56 | if data_dir is None: 57 | raise ValueError('Pass directory name as `data_dir` parameter') 58 | 59 | self._data_dir = data_dir 60 | pathlib.Path(self._data_dir).mkdir(parents=True, exist_ok=False) 61 | 62 | self._log_filename = self.DEFAULT_LOG_FILENAME 63 | 64 | self._log = None 65 | self._log_schema = None 66 | if log_df is not None: 67 | self.update_log(log_df, iteration='start') 68 | 69 | @property 70 | def log(self): 71 | """ 72 | Returns log 73 | """ 74 | return self._log 75 | 76 | @property 77 | def data_dir(self): 78 | """ 79 | Returns directory with saved simulator data 80 | """ 81 | return self._data_dir 82 | 83 | @data_dir.setter 84 | def data_dir(self, value): 85 | self._data_dir = value 86 | 87 | @property 88 | def log_filename(self): 89 | """ 90 | Returns name of log file 91 | """ 92 | return self._log_filename 93 | 94 | @log_filename.setter 95 | def log_filename(self, value): 96 | self._log_filename = value 97 | 98 | def clear_log( 99 | self 100 | ) -> None: 101 | """ 102 | Clears the log 103 | """ 104 | 105 | self._log = None 106 | self._log_schema = None 107 | 108 | @staticmethod 109 | def _check_names_and_types(df1_schema, df2_schema): 110 | """ 111 | Check if names of columns and their types are equal for two schema. 112 | `Nullable` parameter is not compared. 113 | 114 | """ 115 | df1_schema_s = sorted( 116 | [(x.name, x.dataType) for x in df1_schema], 117 | key=lambda x: (x[0], x[1]) 118 | ) 119 | df2_schema_s = sorted( 120 | [(x.name, x.dataType) for x in df2_schema], 121 | key=lambda x: (x[0], x[1]) 122 | ) 123 | names_diff = set(df1_schema_s).symmetric_difference(set(df2_schema_s)) 124 | 125 | if names_diff: 126 | raise ValueError( 127 | f'Columns of two dataframes are different.\nDifferences: \n' 128 | f'In the first dataframe:\n' 129 | f'{[name_type for name_type in df1_schema_s if name_type in names_diff]}\n' 130 | f'In the second dataframe:\n' 131 | f'{[name_type for name_type in df2_schema_s if name_type in names_diff]}' 132 | ) 133 | 134 | def update_log( 135 | self, 136 | log : DataFrame, 137 | iteration : Union[int, str] 138 | ) -> None: 139 | """ 140 | Appends the passed log to the existing one 141 | 142 | :param log: The log with user-item pairs with their respective 143 | necessary fields. If there was no log before this: remembers 144 | the log schema, to which the future logs will be compared. 
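Internally the log is persisted as parquet under ``data_dir/log_filename`` with one partition per iteration (``__iter=<iteration>``), and the full log is re-read after each append.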
145 | To reset current log and the schema see clear_log() 146 | :param iteration: Iteration label or index 147 | """ 148 | 149 | if self._log_schema is None: 150 | self._log_schema = log.schema.fields 151 | else: 152 | self._check_names_and_types(self._log_schema, log.schema) 153 | 154 | write_path = str( 155 | pathlib.Path(self._data_dir) 156 | .joinpath(f'{self.log_filename}/{self.ITER_COLUMN}={iteration}') 157 | ) 158 | log.write.parquet(write_path) 159 | 160 | read_path = str(pathlib.Path(self._data_dir).joinpath(f'{self.log_filename}')) 161 | self._log = self._spark.read.parquet(read_path) 162 | 163 | def sample_users( 164 | self, 165 | frac_users : float 166 | ) -> DataFrame: 167 | """ 168 | Samples a fraction of random users 169 | 170 | :param frac_users: Fractions of users to sample from user generator 171 | :returns: Sampled users dataframe 172 | """ 173 | 174 | return self._user_gen.sample(frac_users) 175 | 176 | def sample_items( 177 | self, 178 | frac_items : float 179 | ) -> DataFrame: 180 | """ 181 | Samples a fraction of random items 182 | 183 | :param frac_items: Fractions of items to sample from item generator 184 | :returns: Sampled users dataframe 185 | """ 186 | 187 | return self._item_gen.sample(frac_items) 188 | 189 | def get_log( 190 | self, 191 | user_df : DataFrame 192 | ) -> DataFrame: 193 | """ 194 | Returns log for users listed in passed users' dataframe 195 | 196 | :param user_df: Dataframe with user identifiers to get log for 197 | :return: Users' history log. Will return None, if there is no log data 198 | """ 199 | 200 | if self.log is not None: 201 | return self.log.join( 202 | user_df, on=self._user_key_col, how='leftsemi' 203 | ) 204 | 205 | return None 206 | 207 | def get_user_items( 208 | self, 209 | user_df : DataFrame, 210 | selector : Transformer 211 | ) -> Tuple[DataFrame, DataFrame]: 212 | """ 213 | Froms candidate pairs to pass to the recommendation algorithm based 214 | on the provided users 215 | 216 | :param user_df: Users dataframe with features and identifiers 217 | :param selector: Transformer to use for creating user-item pairs 218 | :returns: Tuple of user-item pairs and log dataframes which will 219 | be used by recommendation algorithm. Will return None as a log, 220 | if there is no log data 221 | """ 222 | 223 | log = self.get_log(user_df) 224 | pairs = selector.transform(user_df) 225 | 226 | return pairs, log 227 | 228 | def sample_responses( 229 | self, 230 | recs_df : DataFrame, 231 | action_models : PipelineModel, 232 | user_features : Optional[DataFrame] = None, 233 | item_features : Optional[DataFrame] = None 234 | ) -> DataFrame: 235 | """ 236 | Simulates the actions users took on their recommended items 237 | 238 | :param recs_df: Dataframe with recommendations. Must contain 239 | user's and item's identifier columns. 
Other columns will 240 | be ignored 241 | :param user_features: Users dataframe with features and identifiers, 242 | can be None 243 | :param item_features: Items dataframe with features and identifiers, 244 | can be None 245 | :param action_models: Spark pipeline to evaluate responses 246 | :returns: DataFrame with user-item pairs and the respective actions 247 | """ 248 | 249 | if user_features is not None: 250 | recs_df = recs_df.join(user_features, self._user_key_col, 'left') 251 | 252 | if item_features is not None: 253 | recs_df = recs_df.join(item_features, self._item_key_col, 'left') 254 | 255 | return action_models.transform(recs_df) 256 | -------------------------------------------------------------------------------- /sim4rec/params/__init__.py: -------------------------------------------------------------------------------- 1 | from .params import ( 2 | HasUserKeyColumn, 3 | HasItemKeyColumn, 4 | HasSeed, 5 | HasWeights, 6 | HasMean, 7 | HasStandardDeviation, 8 | HasClipNegative, 9 | HasConstantValue, 10 | HasLabel, 11 | HasDevice, 12 | HasDataSize, 13 | HasParallelizationLevel, 14 | HasSeedSequence 15 | ) 16 | 17 | __all__ = [ 18 | 'HasUserKeyColumn', 19 | 'HasItemKeyColumn', 20 | 'HasSeed', 21 | 'HasWeights', 22 | 'HasMean', 23 | 'HasStandardDeviation', 24 | 'HasClipNegative', 25 | 'HasConstantValue', 26 | 'HasLabel', 27 | 'HasDevice', 28 | 'HasDataSize', 29 | 'HasParallelizationLevel', 30 | 'HasSeedSequence' 31 | ] 32 | -------------------------------------------------------------------------------- /sim4rec/params/params.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | from pyspark.ml.param.shared import Params, Param, TypeConverters 4 | 5 | 6 | class HasUserKeyColumn(Params): 7 | """ 8 | Controls user identifier column name 9 | """ 10 | 11 | userKeyColumn = Param( 12 | Params._dummy(), 13 | "userKeyColumn", 14 | "User identifier column name", 15 | typeConverter=TypeConverters.toString 16 | ) 17 | 18 | def setUserKeyColumn(self, value): 19 | """ 20 | Sets user indentifier column name 21 | 22 | :param value: new column name 23 | """ 24 | return self._set(userKeyColumn=value) 25 | 26 | def getUserKeyColumn(self): 27 | """ 28 | Returns item indentifier column name 29 | """ 30 | return self.getOrDefault(self.userKeyColumn) 31 | 32 | 33 | class HasItemKeyColumn(Params): 34 | """ 35 | Controls item identifier column name 36 | """ 37 | 38 | itemKeyColumn = Param( 39 | Params._dummy(), 40 | "itemKeyColumn", 41 | "Item identifier column name", 42 | typeConverter=TypeConverters.toString 43 | ) 44 | 45 | def setItemKeyColumn(self, value): 46 | """ 47 | Sets item indentifier column name 48 | 49 | :param value: new column name 50 | """ 51 | return self._set(itemKeyColumn=value) 52 | 53 | def getItemKeyColumn(self): 54 | """ 55 | Returns item indentifier column name 56 | """ 57 | return self.getOrDefault(self.itemKeyColumn) 58 | 59 | 60 | class HasSeed(Params): 61 | """ 62 | Controls random state seed 63 | """ 64 | 65 | seed = Param( 66 | Params._dummy(), 67 | "seed", 68 | "Random state seed", 69 | typeConverter=TypeConverters.toInt 70 | ) 71 | 72 | def setSeed(self, value): 73 | """ 74 | Changes random state seed 75 | 76 | :param value: new random state seed 77 | """ 78 | return self._set(seed=value) 79 | 80 | def getSeed(self): 81 | """ 82 | Returns state seed 83 | """ 84 | return self.getOrDefault(self.seed) 85 | 86 | 87 | class HasSeedSequence(Params): 88 | """ 89 | Controls random state seed of sequence 90 
| """ 91 | _rng : np.random.Generator 92 | 93 | current_seed = Param( 94 | Params._dummy(), 95 | "current_seed", 96 | "Random state seed sequence", 97 | typeConverter=TypeConverters.toInt 98 | ) 99 | 100 | init_seed = Param( 101 | Params._dummy(), 102 | "init_seed", 103 | "Sequence initial seed", 104 | typeConverter=TypeConverters.toInt 105 | ) 106 | 107 | def initSeedSequence(self, value): 108 | """ 109 | Sets initial random state seed of sequence 110 | 111 | :param value: new initial random state seed of sequence 112 | """ 113 | self._rng = np.random.default_rng(value) 114 | return self._set( 115 | init_seed=value if value is not None else -1, 116 | current_seed=self._rng.integers(0, sys.maxsize) 117 | ) 118 | 119 | def getInitSeed(self): 120 | """ 121 | Returns initial random state seed of sequence 122 | """ 123 | value = self.getOrDefault(self.init_seed) 124 | return None if value == -1 else value 125 | 126 | def getNextSeed(self): 127 | """ 128 | Returns current random state seed of sequence 129 | """ 130 | seed = self.getOrDefault(self.current_seed) 131 | self._set(current_seed=self._rng.integers(0, sys.maxsize)) 132 | return seed 133 | 134 | 135 | class HasWeights(Params): 136 | """ 137 | Controls weights for models ensemble 138 | """ 139 | 140 | weights = Param( 141 | Params._dummy(), 142 | "weights", 143 | "Weights for models ensemble", 144 | typeConverter=TypeConverters.toListFloat 145 | ) 146 | 147 | def setWeights(self, value): 148 | """ 149 | Changes weights for models ensemble 150 | 151 | :param value: new weights 152 | """ 153 | return self._set(weights=value) 154 | 155 | def getWeights(self): 156 | """ 157 | Returns weigths for models ensemble 158 | """ 159 | return self.getOrDefault(self.weights) 160 | 161 | 162 | class HasMean(Params): 163 | """ 164 | Controls mean parameter of normal distribution 165 | """ 166 | 167 | mean = Param( 168 | Params._dummy(), 169 | "mean", 170 | "Mean parameter of normal distribution", 171 | typeConverter=TypeConverters.toFloat 172 | ) 173 | 174 | def setMean(self, value): 175 | """ 176 | Changes mean parameter of normal distribution 177 | 178 | :param value: new value of mean parameter 179 | """ 180 | return self._set(mean=value) 181 | 182 | def getMean(self): 183 | """ 184 | Returns mean parameter 185 | """ 186 | return self.getOrDefault(self.mean) 187 | 188 | 189 | class HasStandardDeviation(Params): 190 | """ 191 | Controls Standard Deviation parameter of normal distribution 192 | """ 193 | 194 | std = Param( 195 | Params._dummy(), 196 | "std", 197 | "Standard Deviation parameter of normal distribution", 198 | typeConverter=TypeConverters.toFloat 199 | ) 200 | 201 | def setStandardDeviation(self, value): 202 | """ 203 | Changes Standard Deviation parameter of normal distribution 204 | 205 | :param value: new value of std parameter 206 | """ 207 | 208 | return self._set(std=value) 209 | 210 | def getStandardDeviation(self): 211 | """ 212 | Returns value of std parameter 213 | """ 214 | return self.getOrDefault(self.std) 215 | 216 | 217 | class HasClipNegative(Params): 218 | """ 219 | Controls flag that controls clipping of negative values 220 | """ 221 | 222 | clipNegative = Param( 223 | Params._dummy(), 224 | "clipNegative", 225 | "Boolean flag to clip negative values", 226 | typeConverter=TypeConverters.toBoolean 227 | ) 228 | 229 | def setClipNegative(self, value): 230 | """ 231 | Changes flag that controls clipping of negative values 232 | 233 | :param value: New value of flag 234 | """ 235 | return self._set(clipNegative=value) 236 | 
237 | def getClipNegative(self): 238 | """ 239 | Returns flag that controls clipping of negative values 240 | """ 241 | return self.getOrDefault(self.clipNegative) 242 | 243 | 244 | class HasConstantValue(Params): 245 | """ 246 | Controls constant value parameter 247 | """ 248 | 249 | constantValue = Param( 250 | Params._dummy(), 251 | "constantValue", 252 | "Constant value parameter", 253 | typeConverter=TypeConverters.toFloat 254 | ) 255 | 256 | def setConstantValue(self, value): 257 | """ 258 | Sets constant value parameter 259 | 260 | :param value: Value 261 | """ 262 | return self._set(constantValue=value) 263 | 264 | def getConstantValue(self): 265 | """ 266 | Returns constant value 267 | """ 268 | return self.getOrDefault(self.constantValue) 269 | 270 | 271 | class HasLabel(Params): 272 | """ 273 | Controls string label 274 | """ 275 | label = Param( 276 | Params._dummy(), 277 | "label", 278 | "String label", 279 | typeConverter=TypeConverters.toString 280 | ) 281 | 282 | def setLabel(self, value): 283 | """ 284 | Sets string label 285 | 286 | :param value: Label 287 | """ 288 | return self._set(label=value) 289 | 290 | def getLabel(self): 291 | """ 292 | Returns current string label 293 | """ 294 | return self.getOrDefault(self.label) 295 | 296 | 297 | class HasDevice(Params): 298 | """ 299 | Controls device 300 | """ 301 | device = Param( 302 | Params._dummy(), 303 | "device", 304 | "Name of a device to use", 305 | typeConverter=TypeConverters.toString 306 | ) 307 | 308 | def setDevice(self, value): 309 | """ 310 | Sets device 311 | 312 | :param value: Name of device to use 313 | """ 314 | return self._set(device=value) 315 | 316 | def getDevice(self): 317 | """ 318 | Returns current device 319 | """ 320 | return self.getOrDefault(self.device) 321 | 322 | 323 | class HasDataSize(Params): 324 | """ 325 | Controls data size 326 | """ 327 | data_size = Param( 328 | Params._dummy(), 329 | "data_size", 330 | "Size of a DataFrame", 331 | typeConverter=TypeConverters.toInt 332 | ) 333 | 334 | def setDataSize(self, value): 335 | """ 336 | Sets data size to a certain value 337 | 338 | :param value: Size of a DataFrame 339 | """ 340 | return self._set(data_size=value) 341 | 342 | def getDataSize(self): 343 | """ 344 | Returns current size of a DataFrame 345 | """ 346 | return self.getOrDefault(self.data_size) 347 | 348 | 349 | class HasParallelizationLevel(Params): 350 | """ 351 | Controls parallelization level 352 | """ 353 | parallelizationLevel = Param( 354 | Params._dummy(), 355 | "parallelizationLevel", 356 | "Level of parallelization", 357 | typeConverter=TypeConverters.toInt 358 | ) 359 | 360 | def setParallelizationLevel(self, value): 361 | """ 362 | Sets level of parallelization 363 | 364 | :param value: Level of parallelization 365 | """ 366 | return self._set(parallelizationLevel=value) 367 | 368 | def getParallelizationLevel(self): 369 | """ 370 | Returns current level of parallelization 371 | """ 372 | return self.getOrDefault(self.parallelizationLevel) 373 | -------------------------------------------------------------------------------- /sim4rec/recommenders/ucb.py: -------------------------------------------------------------------------------- 1 | import joblib 2 | import math 3 | 4 | from os.path import join 5 | from typing import Any, Dict, List, Optional 6 | 7 | from abc import ABC 8 | import numpy as np 9 | import pandas as pd 10 | from numpy.random import default_rng 11 | 12 | from pyspark.sql import DataFrame, Window 13 | from pyspark.sql import functions as sf 14 | 
15 | from sim4rec.recommenders.utils import REC_SCHEMA 16 | 17 | 18 | class UCB(ABC): 19 | """Simple bandit model, which caclulate item relevance as upper confidence bound 20 | (`UCB `_) 21 | for the confidence interval of true fraction of positive ratings. 22 | Should be used in iterative (online) mode to achive proper recommendation quality. 23 | ``relevance`` from log must be converted to binary 0-1 form. 24 | .. math:: 25 | pred_i = ctr_i + \\sqrt{\\frac{c\\ln{n}}{n_i}} 26 | :math:`pred_i` -- predicted relevance of item :math:`i` 27 | :math:`c` -- exploration coeficient 28 | :math:`n` -- number of interactions in log 29 | :math:`n_i` -- number of interactions with item :math:`i` 30 | """ 31 | 32 | can_predict_cold_users = True 33 | can_predict_cold_items = True 34 | item_popularity: DataFrame 35 | fill: float 36 | 37 | def __init__( 38 | self, 39 | exploration_coef: float = 2, 40 | sample: bool = False, 41 | seed: Optional[int] = None, 42 | ): 43 | """ 44 | :param exploration_coef: exploration coefficient 45 | :param sample: flag to choose recommendation strategy. 46 | If True, items are sampled with a probability proportional 47 | to the calculated predicted relevance 48 | :param seed: random seed. Provides reproducibility if fixed 49 | """ 50 | # pylint: disable=super-init-not-called 51 | self.coef = exploration_coef 52 | self.sample = sample 53 | self.seed = seed 54 | 55 | @property 56 | def _init_args(self): 57 | return { 58 | "exploration_coef": self.coef, 59 | "sample": self.sample, 60 | "seed": self.seed, 61 | } 62 | 63 | @property 64 | def _dataframes(self): 65 | return {"item_popularity": self.item_popularity} 66 | 67 | def _save_model(self, path: str): 68 | joblib.dump({"fill": self.fill}, join(path)) 69 | 70 | def _load_model(self, path: str): 71 | self.fill = joblib.load(join(path))["fill"] 72 | 73 | def fit( 74 | self, 75 | log: DataFrame, 76 | user_features: Optional[DataFrame] = None, 77 | item_features: Optional[DataFrame] = None, 78 | ) -> None: 79 | vals = log.select("relevance").where( 80 | (sf.col("relevance") != 1) & (sf.col("relevance") != 0) 81 | ) 82 | if vals.count() > 0: 83 | raise ValueError("Relevance values in log must be 0 or 1") 84 | 85 | items_counts = log.groupby("item_idx").agg( 86 | sf.sum("relevance").alias("pos"), 87 | sf.count("relevance").alias("total"), 88 | ) 89 | 90 | full_count = log.count() 91 | items_counts = items_counts.withColumn( 92 | "relevance", 93 | ( 94 | sf.col("pos") / sf.col("total") 95 | + sf.sqrt( 96 | sf.log(sf.lit(self.coef * full_count)) / sf.col("total") 97 | ) 98 | ), 99 | ) 100 | 101 | self.item_popularity = items_counts.drop("pos", "total") 102 | self.item_popularity.cache().count() 103 | 104 | self.fill = 1 + math.sqrt(math.log(self.coef * full_count)) 105 | 106 | def _clear_cache(self): 107 | if hasattr(self, "item_popularity"): 108 | self.item_popularity.unpersist() 109 | 110 | def _predict_with_sampling( 111 | self, 112 | log: DataFrame, 113 | item_popularity: DataFrame, 114 | k: int, 115 | users: DataFrame, 116 | filter_seen_items: bool = True, 117 | ): 118 | items_pd = item_popularity.withColumn( 119 | "probability", 120 | sf.col("relevance") 121 | / item_popularity.select(sf.sum("relevance")).first()[0], 122 | ).toPandas() 123 | 124 | seed = self.seed 125 | 126 | def grouped_map(pandas_df: pd.DataFrame) -> pd.DataFrame: 127 | user_idx = pandas_df["user_idx"][0] 128 | cnt = pandas_df["cnt"][0] 129 | 130 | if seed is not None: 131 | local_rng = default_rng(seed + user_idx) 132 | else: 133 | local_rng = default_rng() 
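# Items are drawn without replacement, weighted by the normalized
# UCB relevance scores computed above, so higher-scored items are
# recommended more often.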
134 | 135 | items_positions = local_rng.choice( 136 | np.arange(items_pd.shape[0]), 137 | size=cnt, 138 | p=items_pd["probability"].values, 139 | replace=False, 140 | ) 141 | 142 | return pd.DataFrame( 143 | { 144 | "user_idx": cnt * [user_idx], 145 | "item_idx": items_pd["item_idx"].values[items_positions], 146 | "relevance": items_pd["probability"].values[ 147 | items_positions 148 | ], 149 | } 150 | ) 151 | 152 | recs = users.withColumn("cnt", sf.lit(k)) 153 | if log is not None and filter_seen_items: 154 | recs = ( 155 | log.join(users, how="right", on="user_idx") 156 | .select("user_idx", "item_idx") 157 | .groupby("user_idx") 158 | .agg(sf.countDistinct("item_idx").alias("cnt")) 159 | .selectExpr( 160 | "user_idx", 161 | f"LEAST(cnt + {k}, {items_pd.shape[0]}) AS cnt", 162 | ) 163 | ) 164 | return recs.groupby("user_idx").applyInPandas(grouped_map, REC_SCHEMA) 165 | 166 | @staticmethod 167 | def _calc_max_hist_len(log, users): 168 | max_hist_len = ( 169 | ( 170 | log.join(users, on="user_idx") 171 | .groupBy("user_idx") 172 | .agg(sf.countDistinct("item_idx").alias("items_count")) 173 | ) 174 | .select(sf.max("items_count")) 175 | .collect()[0][0] 176 | ) 177 | # all users have empty history 178 | if max_hist_len is None: 179 | return 0 180 | return max_hist_len 181 | 182 | # pylint: disable=too-many-arguments 183 | def predict( 184 | self, 185 | log: DataFrame, 186 | k: int, 187 | users: DataFrame, 188 | items: DataFrame, 189 | user_features: Optional[DataFrame] = None, 190 | item_features: Optional[DataFrame] = None, 191 | filter_seen_items: bool = True, 192 | ) -> DataFrame: 193 | 194 | selected_item_popularity = self.item_popularity.join( 195 | items, 196 | on="item_idx", 197 | how="right", 198 | ).fillna(value=self.fill, subset=["relevance"]) 199 | 200 | if self.sample: 201 | return self._predict_with_sampling( 202 | log=log, 203 | item_popularity=selected_item_popularity, 204 | k=k, 205 | users=users, 206 | filter_seen_items=filter_seen_items, 207 | ) 208 | 209 | selected_item_popularity = selected_item_popularity.withColumn( 210 | "rank", 211 | sf.row_number().over( 212 | Window.orderBy( 213 | sf.col("relevance").desc(), sf.col("item_idx").desc() 214 | ) 215 | ), 216 | ) 217 | 218 | max_hist_len = ( 219 | self._calc_max_hist_len(log, users) 220 | if filter_seen_items and log is not None 221 | else 0 222 | ) 223 | 224 | return users.crossJoin( 225 | selected_item_popularity.filter(sf.col("rank") <= k + max_hist_len) 226 | ).drop("rank") 227 | 228 | def _predict_pairs( 229 | self, 230 | pairs: DataFrame, 231 | log: Optional[DataFrame] = None, 232 | user_features: Optional[DataFrame] = None, 233 | item_features: Optional[DataFrame] = None, 234 | ) -> DataFrame: 235 | return pairs.join( 236 | self.item_popularity, on="item_idx", how="left" 237 | ).fillna(value=self.fill, subset=["relevance"]) -------------------------------------------------------------------------------- /sim4rec/recommenders/utils.py: -------------------------------------------------------------------------------- 1 | from pyspark.sql.types import DoubleType, IntegerType, StructField, StructType, TimestampType 2 | 3 | 4 | def get_schema( 5 | query_column: str = "query_id", 6 | item_column: str = "item_id", 7 | timestamp_column: str = "timestamp", 8 | rating_column: str = "rating", 9 | has_timestamp: bool = True, 10 | has_rating: bool = True, 11 | ): 12 | """ 13 | Get Spark Schema with query_id, item_id, rating, timestamp columns 14 | 15 | :param query_column: column name with query ids 16 | :param item_column: 
column name with item ids 17 | :param timestamp_column: column name with timestamps 18 | :param rating_column: column name with ratings 19 | :param has_rating: flag to add rating to schema 20 | :param has_timestamp: flag to add tomestamp to schema 21 | """ 22 | base = [ 23 | StructField(query_column, IntegerType()), 24 | StructField(item_column, IntegerType()), 25 | ] 26 | if has_timestamp: 27 | base += [StructField(timestamp_column, TimestampType())] 28 | if has_rating: 29 | base += [StructField(rating_column, DoubleType())] 30 | return StructType(base) 31 | 32 | REC_SCHEMA = get_schema( 33 | query_column="user_idx", 34 | item_column="item_idx", 35 | rating_column="relevance", 36 | has_timestamp=False, 37 | ) 38 | -------------------------------------------------------------------------------- /sim4rec/response/__init__.py: -------------------------------------------------------------------------------- 1 | from .response import ( 2 | ActionModelEstimator, 3 | ActionModelTransformer, 4 | ConstantResponse, 5 | NoiseResponse, 6 | CosineSimilatiry, 7 | BernoulliResponse, 8 | ParametricResponseFunction 9 | ) 10 | 11 | __all__ = [ 12 | 'ActionModelEstimator', 13 | 'ActionModelTransformer', 14 | 'ConstantResponse', 15 | 'NoiseResponse', 16 | 'CosineSimilatiry', 17 | 'BernoulliResponse', 18 | 'ParametricResponseFunction' 19 | ] 20 | -------------------------------------------------------------------------------- /sim4rec/response/response.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=no-member,unused-argument,too-many-ancestors,abstract-method 2 | import math 3 | from typing import List 4 | from collections.abc import Iterable 5 | 6 | import pyspark.sql.types as st 7 | import pyspark.sql.functions as sf 8 | from pyspark.ml import Transformer, Estimator 9 | from pyspark.ml.param.shared import HasInputCols, HasInputCol, HasOutputCol 10 | from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable 11 | 12 | from pyspark.sql import DataFrame 13 | from pyspark import keyword_only 14 | 15 | from sim4rec.params import ( 16 | HasWeights, HasSeedSequence, 17 | HasConstantValue, HasClipNegative, 18 | HasMean, HasStandardDeviation 19 | ) 20 | 21 | 22 | class ActionModelEstimator(Estimator, 23 | HasOutputCol, 24 | DefaultParamsReadable, 25 | DefaultParamsWritable): 26 | """ 27 | Base class for response estimator 28 | """ 29 | @keyword_only 30 | def __init__( 31 | self, 32 | outputCol : str = None 33 | ): 34 | """ 35 | :param outputCol: Name of the response score column, defaults 36 | to None 37 | """ 38 | 39 | super().__init__() 40 | self.setParams(**self._input_kwargs) 41 | 42 | @keyword_only 43 | def setParams( 44 | self, 45 | outputCol : str = None 46 | ): 47 | """ 48 | Sets parameters for response estimator 49 | """ 50 | return self._set(**self._input_kwargs) 51 | 52 | 53 | class ActionModelTransformer(Transformer, 54 | HasOutputCol, 55 | DefaultParamsReadable, 56 | DefaultParamsWritable): 57 | """ 58 | Base class for response transformer. 
transform() will be 59 | used to calculate score based on inputCols, and write it 60 | to outputCol column 61 | """ 62 | @keyword_only 63 | def __init__( 64 | self, 65 | outputCol : str = None 66 | ): 67 | """ 68 | :param outputCol: Name of the response score column, defaults 69 | to None 70 | """ 71 | 72 | super().__init__() 73 | self.setParams(**self._input_kwargs) 74 | 75 | @keyword_only 76 | def setParams( 77 | self, 78 | outputCol : str = None 79 | ): 80 | """ 81 | Sets parameters for response transformer 82 | """ 83 | return self._set(**self._input_kwargs) 84 | 85 | 86 | class BernoulliResponse(ActionModelTransformer, 87 | HasInputCol, 88 | HasSeedSequence): 89 | """ 90 | Samples responses from probability column 91 | """ 92 | def __init__( 93 | self, 94 | inputCol : str = None, 95 | outputCol : str = None, 96 | seed : int = None 97 | ): 98 | """ 99 | :param inputCol: Probability column name. Probability should 100 | be in range [0; 1] 101 | :param outputCol: Output column name 102 | :param seed: Random state seed, defaults to None 103 | """ 104 | 105 | super().__init__(outputCol=outputCol) 106 | 107 | self._set(inputCol=inputCol) 108 | self.initSeedSequence(seed) 109 | 110 | def _transform( 111 | self, 112 | dataset : DataFrame 113 | ): 114 | inputCol = self.getInputCol() 115 | outputCol = self.getOutputCol() 116 | seed = self.getNextSeed() 117 | 118 | return dataset.withColumn( 119 | outputCol, 120 | sf.when(sf.rand(seed=seed) <= sf.col(inputCol), 1).otherwise(0) 121 | ) 122 | 123 | 124 | class NoiseResponse(ActionModelTransformer, 125 | HasMean, 126 | HasStandardDeviation, 127 | HasClipNegative, 128 | HasSeedSequence): 129 | # pylint: disable=too-many-arguments 130 | """ 131 | Creates random response sampled from normal distribution 132 | """ 133 | def __init__( 134 | self, 135 | mu : float = None, 136 | sigma : float = None, 137 | outputCol : str = None, 138 | clipNegative : bool = True, 139 | seed : int = None 140 | ): 141 | """ 142 | :param mu: Mean parameter of normal distribution 143 | :param sigma: Standard deviation parameter of normal distribution 144 | :param outputCol: Output column name 145 | :param clip_negative: Whether to make response non-negative, 146 | defaults to True 147 | :param seed: Random state seed, defaults to None 148 | """ 149 | 150 | super().__init__(outputCol=outputCol) 151 | 152 | self._set(mean=mu, std=sigma, clipNegative=clipNegative) 153 | self.initSeedSequence(seed) 154 | 155 | def _transform( 156 | self, 157 | dataset : DataFrame 158 | ): 159 | mu = self.getMean() 160 | sigma = self.getStandardDeviation() 161 | clip_negative = self.getClipNegative() 162 | outputCol = self.getOutputCol() 163 | seed = self.getNextSeed() 164 | 165 | expr = sf.randn(seed=seed) * sigma + mu 166 | if clip_negative: 167 | expr = sf.greatest(expr, sf.lit(0)) 168 | 169 | return dataset.withColumn(outputCol, expr) 170 | 171 | 172 | class ConstantResponse(ActionModelTransformer, 173 | HasConstantValue): 174 | """ 175 | Always returns constant valued response 176 | """ 177 | def __init__( 178 | self, 179 | value : float = 0.0, 180 | outputCol : str = None 181 | ): 182 | """ 183 | :param value: Response value 184 | :param outputCol: Output column name 185 | """ 186 | 187 | super().__init__(outputCol=outputCol) 188 | 189 | self._set(constantValue=value) 190 | 191 | def _transform( 192 | self, 193 | dataset : DataFrame 194 | ): 195 | value = self.getConstantValue() 196 | outputColumn = self.getOutputCol() 197 | 198 | return dataset.withColumn(outputColumn, sf.lit(value)) 199 | 
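The transformers defined above (ConstantResponse, NoiseResponse, BernoulliResponse) are ordinary Spark ML stages that each add one column, so the usual pattern is to chain them, optionally inside a PipelineModel, and apply the result to a user-item DataFrame. A minimal usage sketch follows; it belongs in user code rather than in this module, and the toy DataFrame, its column names and the local Spark session are illustrative assumptions.

from pyspark.sql import SparkSession
from pyspark.ml import PipelineModel

from sim4rec.response import BernoulliResponse, NoiseResponse, ConstantResponse

spark = SparkSession.builder.master("local[1]").getOrCreate()

# Hypothetical user-item pairs with a precomputed click probability
pairs = spark.createDataFrame(
    [(0, 0, 0.9), (0, 1, 0.1), (1, 0, 0.4)],
    ["user_idx", "item_idx", "proba"],
)

response_model = PipelineModel(stages=[
    ConstantResponse(value=0.5, outputCol="bias"),                       # constant score column
    NoiseResponse(mu=0.0, sigma=0.1, outputCol="noise", seed=42),        # Gaussian noise, clipped at zero by default
    BernoulliResponse(inputCol="proba", outputCol="response", seed=42),  # 0/1 draw from the probability column
])

response_model.transform(pairs).show()

Because every stage only appends a column, the relative order of ConstantResponse and NoiseResponse does not matter; BernoulliResponse is typically placed last so that the final response column is a binary outcome.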
200 | 201 | class CosineSimilatiry(ActionModelTransformer, 202 | HasInputCols): 203 | """ 204 | Calculates the cosine similarity between two vectors. 205 | The result is in [0; 1] range 206 | """ 207 | def __init__( 208 | self, 209 | inputCols : List[str] = None, 210 | outputCol : str = None 211 | ): 212 | """ 213 | :param inputCols: Two column names with dense vectors 214 | :param outputCol: Output column name 215 | """ 216 | 217 | if inputCols is not None and len(inputCols) != 2: 218 | raise ValueError('There must be two array columns ' 219 | 'to calculate cosine similarity') 220 | 221 | super().__init__(outputCol=outputCol) 222 | self._set(inputCols=inputCols) 223 | 224 | def _transform( 225 | self, 226 | dataset : DataFrame 227 | ): 228 | inputCols = self.getInputCols() 229 | outputCol = self.getOutputCol() 230 | 231 | def cosine_similarity(first, second): 232 | num = first.dot(second) 233 | den = first.norm(2) * second.norm(2) 234 | 235 | if den == 0: 236 | return float(0) 237 | 238 | cosine = max(min(num / den, 1.0), -1.0) 239 | return float(1 - math.acos(cosine) / math.pi) 240 | 241 | cos_udf = sf.udf(cosine_similarity, st.DoubleType()) 242 | 243 | return dataset.withColumn( 244 | outputCol, 245 | # pylint: disable=unsubscriptable-object 246 | cos_udf(sf.col(inputCols[0]), sf.col(inputCols[1])) 247 | ) 248 | 249 | 250 | class ParametricResponseFunction(ActionModelTransformer, 251 | HasInputCols, 252 | HasWeights): 253 | """ 254 | Calculates response based on the weighted sum of input responses 255 | """ 256 | def __init__( 257 | self, 258 | inputCols : List[str] = None, 259 | outputCol : str = None, 260 | weights : Iterable = None 261 | ): 262 | """ 263 | :param inputCols: Input responses column names 264 | :param outputCol: Output column name 265 | :param weights: Input responses weights 266 | """ 267 | 268 | super().__init__(outputCol=outputCol) 269 | self._set(inputCols=inputCols, weights=weights) 270 | 271 | def _transform( 272 | self, 273 | dataset : DataFrame 274 | ): 275 | inputCols = self.getInputCols() 276 | outputCol = self.getOutputCol() 277 | weights = self.getWeights() 278 | 279 | return dataset.withColumn( 280 | outputCol, 281 | sum([ 282 | # pylint: disable=unsubscriptable-object 283 | sf.col(c) * weights[i] 284 | for i, c in enumerate(inputCols) 285 | ]) 286 | ) 287 | -------------------------------------------------------------------------------- /sim4rec/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .uce import ( 2 | VectorElementExtractor, 3 | NotFittedError, 4 | EmptyDataFrameError, 5 | save, 6 | load 7 | ) 8 | from .convert import pandas_to_spark 9 | 10 | from .session_handler import get_spark_session, State 11 | 12 | __all__ = [ 13 | 'VectorElementExtractor', 14 | 'NotFittedError', 15 | 'EmptyDataFrameError', 16 | 'State', 17 | 'save', 18 | 'load', 19 | 'pandas_to_spark', 20 | 'get_spark_session' 21 | ] 22 | -------------------------------------------------------------------------------- /sim4rec/utils/convert.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from pyspark.sql import SparkSession, DataFrame 3 | 4 | from sim4rec.utils.session_handler import State 5 | 6 | 7 | def pandas_to_spark( 8 | df: pd.DataFrame, 9 | schema=None, 10 | spark_session : SparkSession = None) -> DataFrame: 11 | """ 12 | Converts pandas DataFrame to spark DataFrame 13 | 14 | :param df: DataFrame to convert 15 | :param schema: Schema of the dataframe, 
defaults to None 16 | :param spark_session: Spark session to use, defaults to None 17 | :returns: data converted to spark DataFrame 18 | """ 19 | 20 | if not isinstance(df, pd.DataFrame): 21 | raise ValueError('df must be an instance of pd.DataFrame') 22 | 23 | if len(df) == 0: 24 | raise ValueError('Dataframe is empty') 25 | 26 | if spark_session is not None: 27 | spark = spark_session 28 | else: 29 | spark = State().session 30 | 31 | return spark.createDataFrame(df, schema=schema) 32 | -------------------------------------------------------------------------------- /sim4rec/utils/session_handler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Painless creation and retrieval of Spark sessions 3 | """ 4 | 5 | import os 6 | import sys 7 | from math import floor 8 | from typing import Any, Dict, Optional 9 | 10 | import psutil 11 | import torch 12 | from pyspark.sql import SparkSession 13 | 14 | 15 | def get_spark_session( 16 | spark_memory: Optional[int] = None, 17 | shuffle_partitions: Optional[int] = None, 18 | ) -> SparkSession: 19 | """ 20 | Get default SparkSession 21 | 22 | :param spark_memory: GB of memory allocated for Spark; 23 | 70% of RAM by default. 24 | :param shuffle_partitions: number of partitions for Spark; triple CPU count by default 25 | """ 26 | if os.environ.get("SCRIPT_ENV", None) == "cluster": 27 | return SparkSession.builder.getOrCreate() 28 | 29 | os.environ["PYSPARK_PYTHON"] = sys.executable 30 | os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable 31 | 32 | if spark_memory is None: 33 | spark_memory = floor(psutil.virtual_memory().total / 1024**3 * 0.7) 34 | if shuffle_partitions is None: 35 | shuffle_partitions = os.cpu_count() * 3 36 | driver_memory = f"{spark_memory}g" 37 | user_home = os.environ["HOME"] 38 | spark = ( 39 | SparkSession.builder.config("spark.driver.memory", driver_memory) 40 | .config( 41 | "spark.driver.extraJavaOptions", 42 | "-Dio.netty.tryReflectionSetAccessible=true", 43 | ) 44 | .config("spark.sql.shuffle.partitions", str(shuffle_partitions)) 45 | .config("spark.local.dir", os.path.join(user_home, "tmp")) 46 | .config("spark.driver.maxResultSize", "4g") 47 | .config("spark.driver.bindAddress", "127.0.0.1") 48 | .config("spark.driver.host", "localhost") 49 | .config("spark.sql.execution.arrow.pyspark.enabled", "true") 50 | .config("spark.kryoserializer.buffer.max", "256m") 51 | .config("spark.files.overwrite", "true") 52 | .master("local[*]") 53 | .enableHiveSupport() 54 | .getOrCreate() 55 | ) 56 | return spark 57 | 58 | 59 | # pylint: disable=too-few-public-methods 60 | class Borg: 61 | """ 62 | This class allows to share objects between instances. 63 | """ 64 | 65 | _shared_state: Dict[str, Any] = {} 66 | 67 | def __init__(self): 68 | self.__dict__ = self._shared_state 69 | 70 | 71 | # pylint: disable=too-few-public-methods 72 | class State(Borg): 73 | """ 74 | All modules look for Spark session via this class. You can put your own session here. 
75 | 76 | Other parameters are stored here too: ``default device`` for ``pytorch`` (CPU/CUDA) 77 | """ 78 | 79 | def __init__( 80 | self, 81 | session: Optional[SparkSession] = None, 82 | device: Optional[torch.device] = None, 83 | ): 84 | Borg.__init__(self) 85 | 86 | if session is None: 87 | if not hasattr(self, "session"): 88 | self.session = get_spark_session() 89 | else: 90 | self.session = session 91 | 92 | if device is None: 93 | if not hasattr(self, "device"): 94 | if torch.cuda.is_available(): 95 | self.device = torch.device( 96 | f"cuda:{torch.cuda.current_device()}" 97 | ) 98 | else: 99 | self.device = torch.device("cpu") 100 | else: 101 | self.device = device 102 | -------------------------------------------------------------------------------- /sim4rec/utils/uce.py: -------------------------------------------------------------------------------- 1 | # Utility class extensions 2 | # pylint: disable=no-member,unused-argument 3 | import pickle 4 | import pyspark.sql.functions as sf 5 | import pyspark.sql.types as st 6 | 7 | from pyspark.sql import DataFrame 8 | from pyspark.ml import Transformer 9 | from pyspark.ml.param.shared import HasInputCol, HasOutputCol 10 | from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable 11 | 12 | 13 | from pyspark.ml.param.shared import Params, Param, TypeConverters 14 | from pyspark import keyword_only 15 | 16 | 17 | class NotFittedError(Exception): 18 | # pylint: disable=missing-class-docstring 19 | pass 20 | 21 | 22 | class EmptyDataFrameError(Exception): 23 | # pylint: disable=missing-class-docstring 24 | pass 25 | 26 | 27 | # pylint: disable=too-many-ancestors 28 | class VectorElementExtractor(Transformer, 29 | HasInputCol, HasOutputCol, 30 | DefaultParamsReadable, DefaultParamsWritable): 31 | """ 32 | Extracts element at index from array column 33 | """ 34 | 35 | index = Param( 36 | Params._dummy(), 37 | "index", 38 | "Array index to extract", 39 | typeConverter=TypeConverters.toInt 40 | ) 41 | 42 | def setIndex(self, value): 43 | """ 44 | Sets index to a certain value 45 | :param value: Value to set index of an element 46 | """ 47 | return self._set(index=value) 48 | 49 | def getIndex(self): 50 | """ 51 | Returns index of element 52 | """ 53 | return self.getOrDefault(self.index) 54 | 55 | @keyword_only 56 | def __init__( 57 | self, 58 | inputCol : str = None, 59 | outputCol : str = None, 60 | index : int = None 61 | ): 62 | """ 63 | :param inputCol: Input column with array 64 | :param outputCol: Output column name 65 | :param index: Index of an element within array 66 | """ 67 | super().__init__() 68 | self.setParams(**self._input_kwargs) 69 | 70 | @keyword_only 71 | def setParams( 72 | self, 73 | inputCol : str = None, 74 | outputCol : str = None, 75 | index : int = None 76 | ): 77 | """ 78 | Sets parameters for extractor 79 | """ 80 | return self._set(**self._input_kwargs) 81 | 82 | def _transform( 83 | self, 84 | dataset : DataFrame 85 | ): 86 | index = self.getIndex() 87 | 88 | el_udf = sf.udf( 89 | lambda x : float(x[index]), st.DoubleType() 90 | ) 91 | 92 | inputCol = self.getInputCol() 93 | outputCol = self.getOutputCol() 94 | 95 | return dataset.withColumn(outputCol, el_udf(inputCol)) 96 | 97 | 98 | def save(obj : object, filename : str): 99 | """ 100 | Saves an object to pickle dump 101 | :param obj: Instance 102 | :param filename: File name of a dump 103 | """ 104 | 105 | with open(filename, 'wb') as f: 106 | pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) 107 | 108 | 109 | def load(filename : str): 110 | """ 111 | 
Loads a pickle dump from file 112 | :param filename: File name of a dump 113 | :return: Read instance 114 | """ 115 | 116 | with open(filename, 'rb') as f: 117 | obj = pickle.load(f) 118 | 119 | return obj 120 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sb-ai-lab/Sim4Rec/267b644932b0e120cc9887f0f3a16be38d5739b7/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pyspark.sql import DataFrame, SparkSession 3 | 4 | SEED = 1234 5 | 6 | 7 | @pytest.fixture(scope="session") 8 | def spark() -> SparkSession: 9 | return SparkSession.builder\ 10 | .appName('simulator_test')\ 11 | .master('local[4]')\ 12 | .config('spark.sql.shuffle.partitions', '4')\ 13 | .config('spark.default.parallelism', '4')\ 14 | .config('spark.driver.extraJavaOptions', '-XX:+UseG1GC')\ 15 | .config('spark.executor.extraJavaOptions', '-XX:+UseG1GC')\ 16 | .config('spark.sql.autoBroadcastJoinThreshold', '-1')\ 17 | .config('spark.driver.memory', '4g')\ 18 | .getOrCreate() 19 | 20 | 21 | @pytest.fixture(scope="session") 22 | def users_df(spark: SparkSession) -> DataFrame: 23 | data = [ 24 | (0, 1.25, -0.75, 0.5), 25 | (1, 0.2, -1.0, 0.0), 26 | (2, 0.85, -0.5, -1.5), 27 | (3, -0.33, -0.33, 0.33), 28 | (4, 0.1, 0.2, 0.3) 29 | ] 30 | 31 | return spark.createDataFrame( 32 | data=data, 33 | schema=['user_id', 'user_attr_1', 'user_attr_2', 'user_attr_3'] 34 | ) 35 | 36 | 37 | @pytest.fixture(scope="session") 38 | def items_df(spark: SparkSession) -> DataFrame: 39 | data = [ 40 | (0, 0.45, -0.45, 1.2), 41 | (1, -0.3, 0.75, 0.25), 42 | (2, 1.25, -0.75, 0.5), 43 | (3, -1.0, 0.0, -0.5), 44 | (4, 0.5, -0.5, -1.0) 45 | ] 46 | 47 | return spark.createDataFrame( 48 | data=data, 49 | schema=['item_id', 'item_attr_1', 'item_attr_2', 'item_attr_3'] 50 | ) 51 | 52 | 53 | @pytest.fixture(scope="session") 54 | def log_df(spark: SparkSession) -> DataFrame: 55 | data = [ 56 | (0, 2, 1.0, 0), 57 | (1, 1, 0.0, 0), 58 | (1, 2, 0.0, 0), 59 | (2, 0, 1.0, 0), 60 | (2, 2, 1.0, 0) 61 | ] 62 | return spark.createDataFrame( 63 | data=data, 64 | schema=['user_id', 'item_id', 'relevance', 'timestamp'] 65 | ) 66 | -------------------------------------------------------------------------------- /tests/test_embeddings.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pyspark.sql import DataFrame 3 | 4 | from sim4rec.modules import ( 5 | EncoderEstimator, 6 | EncoderTransformer 7 | ) 8 | 9 | SEED = 1234 10 | 11 | 12 | @pytest.fixture(scope="module") 13 | def estimator(users_df : DataFrame) -> EncoderEstimator: 14 | return EncoderEstimator( 15 | inputCols=users_df.columns, 16 | outputCols=[f'encoded_{i}' for i in range(5)], 17 | hidden_dim=10, 18 | lr=0.001, 19 | batch_size=32, 20 | num_loader_workers=2, 21 | max_iter=10, 22 | device_name='cpu', 23 | seed=SEED 24 | ) 25 | 26 | 27 | @pytest.fixture(scope="module") 28 | def transformer( 29 | users_df : DataFrame, 30 | estimator : EncoderEstimator 31 | ) -> EncoderTransformer: 32 | return estimator.fit(users_df) 33 | 34 | 35 | def test_estimator_fit( 36 | estimator : EncoderEstimator, 37 | transformer : EncoderTransformer 38 | ): 39 | assert estimator._input_dim == len(estimator.getInputCols()) 40 | assert 
estimator._latent_dim == len(estimator.getOutputCols()) 41 | assert estimator.getDevice() == transformer.getDevice() 42 | assert str(next(transformer._encoder.parameters()).device) == transformer.getDevice() 43 | 44 | 45 | def test_transformer_transform( 46 | users_df : DataFrame, 47 | transformer : EncoderTransformer 48 | ): 49 | result = transformer.transform(users_df) 50 | 51 | assert result.count() == users_df.count() 52 | assert len(result.columns) == 5 53 | assert set(result.columns) == set([f'encoded_{i}' for i in range(5)]) 54 | -------------------------------------------------------------------------------- /tests/test_evaluation.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from pyspark.sql import DataFrame, SparkSession 4 | 5 | from sim4rec.modules import ( 6 | evaluate_synthetic, 7 | EvaluateMetrics, 8 | ks_test, 9 | kl_divergence 10 | ) 11 | from sim4rec.response import ConstantResponse 12 | 13 | 14 | @pytest.fixture(scope="function") 15 | def evaluator() -> EvaluateMetrics: 16 | return EvaluateMetrics( 17 | userKeyCol='user_id', 18 | itemKeyCol='item_id', 19 | predictionCol='relevance', 20 | labelCol='response', 21 | mllib_metrics=['mse', 'f1', 'areaUnderROC'] 22 | ) 23 | 24 | 25 | @pytest.fixture(scope="module") 26 | def response_df(spark : SparkSession) -> DataFrame: 27 | data = [ 28 | (0, 0, 0.0, 0.0), 29 | (0, 1, 0.5, 1.0), 30 | (0, 2, 1.0, 1.0), 31 | (1, 0, 1.0, 0.0), 32 | (1, 1, 0.5, 0.0), 33 | (1, 2, 0.0, 1.0) 34 | ] 35 | return spark.createDataFrame(data=data, schema=['user_id', 'item_id', 'relevance', 'response']) 36 | 37 | 38 | def test_evaluate_metrics( 39 | evaluator : EvaluateMetrics, 40 | response_df : DataFrame 41 | ): 42 | result = evaluator(response_df) 43 | assert 'mse' in result 44 | assert 'f1' in result 45 | assert 'areaUnderROC' in result 46 | 47 | result = evaluator(response_df) 48 | assert 'mse' in result 49 | assert 'f1' in result 50 | assert 'areaUnderROC' in result 51 | 52 | evaluator._mllib_metrics = [] 53 | 54 | result = evaluator(response_df) 55 | 56 | 57 | def test_evaluate_synthetic( 58 | users_df : DataFrame 59 | ): 60 | import pandas as pd 61 | pd.options.mode.chained_assignment = None 62 | 63 | result = evaluate_synthetic( 64 | users_df.sample(0.5).drop('user_id'), 65 | users_df.sample(0.5).drop('user_id') 66 | ) 67 | 68 | assert result['LogisticDetection'] is not None 69 | assert result['SVCDetection'] is not None 70 | assert result['KSTest'] is not None 71 | assert result['ContinuousKLDivergence'] is not None 72 | 73 | 74 | def test_kstest( 75 | users_df : DataFrame 76 | ): 77 | result = ks_test( 78 | df=users_df.select('user_attr_1', 'user_attr_2'), 79 | predCol='user_attr_1', 80 | labelCol='user_attr_2' 81 | ) 82 | 83 | assert isinstance(result, float) 84 | 85 | 86 | def test_kldiv( 87 | users_df : DataFrame 88 | ): 89 | result = kl_divergence( 90 | df=users_df.select('user_attr_1', 'user_attr_2'), 91 | predCol='user_attr_1', 92 | labelCol='user_attr_2' 93 | ) 94 | 95 | assert isinstance(result, float) 96 | -------------------------------------------------------------------------------- /tests/test_generators.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | import numpy as np 4 | import pandas as pd 5 | import pyspark.sql.functions as sf 6 | from pyspark.sql import DataFrame 7 | 8 | from sim4rec.modules import ( 9 | RealDataGenerator, 10 | SDVDataGenerator, 11 | CompositeGenerator 12 | ) 13 | 14 | SEED 
= 1234 15 | 16 | 17 | @pytest.fixture(scope="function") 18 | def real_gen() -> RealDataGenerator: 19 | return RealDataGenerator(label='real', seed=SEED) 20 | 21 | 22 | @pytest.fixture(scope="function") 23 | def synth_gen() -> SDVDataGenerator: 24 | return SDVDataGenerator( 25 | label='synth', 26 | id_column_name='user_id', 27 | model_name='gaussiancopula', 28 | parallelization_level=2, 29 | device_name='cpu', 30 | seed=SEED 31 | ) 32 | 33 | 34 | @pytest.fixture(scope="function") 35 | def comp_gen(real_gen : RealDataGenerator, synth_gen : SDVDataGenerator) -> CompositeGenerator: 36 | return CompositeGenerator( 37 | generators=[real_gen, synth_gen], 38 | label='composite', 39 | weights=[0.5, 0.5] 40 | ) 41 | 42 | 43 | def test_realdatagenerator_fit(real_gen : RealDataGenerator, users_df : DataFrame): 44 | real_gen.fit(users_df) 45 | 46 | assert real_gen._fit_called 47 | assert real_gen._source_df.count() == users_df.count() 48 | 49 | 50 | def test_sdvdatagenerator_fit(synth_gen : SDVDataGenerator, users_df : DataFrame): 51 | synth_gen.fit(users_df) 52 | 53 | assert synth_gen._fit_called 54 | assert isinstance(synth_gen._model.sample(100), pd.DataFrame) 55 | 56 | 57 | def test_realdatagenerator_generate(real_gen : RealDataGenerator, users_df : DataFrame): 58 | real_gen.fit(users_df) 59 | 60 | assert real_gen.generate(5).count() == 5 61 | assert real_gen._df.count() == 5 62 | assert real_gen.getDataSize() == 5 63 | 64 | 65 | def test_sdvdatagenerator_generate(synth_gen : SDVDataGenerator, users_df : DataFrame): 66 | synth_gen.fit(users_df) 67 | 68 | assert synth_gen.generate(100).count() == 100 69 | assert synth_gen._df.count() == 100 70 | assert synth_gen.getDataSize() == 100 71 | 72 | 73 | def test_compositegenerator_generate( 74 | real_gen : RealDataGenerator, 75 | synth_gen : SDVDataGenerator, 76 | comp_gen : CompositeGenerator, 77 | users_df : DataFrame 78 | ): 79 | real_gen.fit(users_df) 80 | synth_gen.fit(users_df) 81 | comp_gen.generate(10) 82 | 83 | assert real_gen.getDataSize() == 5 84 | assert synth_gen.getDataSize() == 5 85 | 86 | comp_gen.setWeights([1.0, 0.0]) 87 | comp_gen.generate(5) 88 | 89 | assert real_gen.getDataSize() == 5 90 | assert synth_gen.getDataSize() == 0 91 | 92 | comp_gen.setWeights([0.0, 1.0]) 93 | comp_gen.generate(10) 94 | 95 | assert real_gen.getDataSize() == 0 96 | assert synth_gen.getDataSize() == 10 97 | 98 | 99 | def test_realdatagenerator_sample(real_gen : RealDataGenerator, users_df : DataFrame): 100 | real_gen.fit(users_df) 101 | _ = real_gen.generate(5) 102 | 103 | assert real_gen.sample(1.0).count() == 5 104 | assert real_gen.sample(0.5).count() == 2 105 | assert real_gen.sample(0.0).count() == 0 106 | 107 | 108 | def test_sdvdatagenerator_sample(synth_gen : SDVDataGenerator, users_df : DataFrame): 109 | synth_gen.fit(users_df) 110 | _ = synth_gen.generate(100) 111 | 112 | assert synth_gen.sample(1.0).count() == 100 113 | assert synth_gen.sample(0.5).count() == 46 114 | assert synth_gen.sample(0.0).count() == 0 115 | 116 | 117 | def test_compositegenerator_sample( 118 | real_gen : RealDataGenerator, 119 | synth_gen : SDVDataGenerator, 120 | comp_gen : CompositeGenerator, 121 | users_df : DataFrame 122 | ): 123 | real_gen.fit(users_df) 124 | synth_gen.fit(users_df) 125 | comp_gen.generate(10) 126 | 127 | assert comp_gen.sample(1.0).count() == 10 128 | assert comp_gen.sample(0.5).count() == 4 129 | assert comp_gen.sample(0.0).count() == 0 130 | 131 | df = comp_gen.sample(1.0).toPandas() 132 | assert df['user_id'].str.startswith('synth').sum() == 5 133 
| 134 | comp_gen.setWeights([1.0, 0.0]) 135 | comp_gen.generate(5) 136 | df = comp_gen.sample(1.0).toPandas() 137 | assert df['user_id'].str.startswith('synth').sum() == 0 138 | 139 | comp_gen.setWeights([0.0, 1.0]) 140 | comp_gen.generate(10) 141 | df = comp_gen.sample(1.0).toPandas() 142 | assert df['user_id'].str.startswith('synth').sum() == 10 143 | 144 | 145 | def test_realdatagenerator_iterdiff(real_gen : RealDataGenerator, users_df : DataFrame): 146 | real_gen.fit(users_df) 147 | generated_1 = real_gen.generate(5).toPandas() 148 | sampled_1 = real_gen.sample(0.5).toPandas() 149 | 150 | generated_2 = real_gen.generate(5).toPandas() 151 | sampled_2 = real_gen.sample(0.5).toPandas() 152 | 153 | assert not generated_1.equals(generated_2) 154 | assert not sampled_1.equals(sampled_2) 155 | 156 | 157 | def test_sdvdatagenerator_iterdiff(synth_gen : SDVDataGenerator, users_df : DataFrame): 158 | synth_gen.fit(users_df) 159 | 160 | generated_1 = synth_gen.generate(100).toPandas() 161 | sampled_1 = synth_gen.sample(0.1).toPandas() 162 | 163 | generated_2 = synth_gen.generate(100).toPandas() 164 | sampled_2 = synth_gen.sample(0.1).toPandas() 165 | 166 | assert not generated_1.equals(generated_2) 167 | assert not sampled_1.equals(sampled_2) 168 | 169 | 170 | def test_compositegenerator_iterdiff( 171 | real_gen : RealDataGenerator, 172 | synth_gen : SDVDataGenerator, 173 | comp_gen : CompositeGenerator, 174 | users_df : DataFrame 175 | ): 176 | real_gen.fit(users_df) 177 | synth_gen.fit(users_df) 178 | comp_gen.generate(10) 179 | 180 | sampled_1 = comp_gen.sample(0.5).toPandas() 181 | sampled_2 = comp_gen.sample(0.5).toPandas() 182 | 183 | assert not sampled_1.equals(sampled_2) 184 | 185 | 186 | def test_sdvdatagenerator_partdiff(synth_gen : SDVDataGenerator, users_df : DataFrame): 187 | synth_gen.fit(users_df) 188 | 189 | generated = synth_gen.generate(100)\ 190 | .drop('user_id')\ 191 | .withColumn('__partition_id', sf.spark_partition_id()) 192 | df_1 = generated.filter(sf.col('__partition_id') == 0)\ 193 | .drop('__partition_id').toPandas() 194 | df_2 = generated.filter(sf.col('__partition_id') == 1)\ 195 | .drop('__partition_id').toPandas() 196 | 197 | assert not df_1.equals(df_2) 198 | 199 | 200 | def test_sdv_save_load( 201 | synth_gen : SDVDataGenerator, 202 | users_df : DataFrame, 203 | tmp_path 204 | ): 205 | synth_gen.fit(users_df) 206 | synth_gen.save_model(f'{tmp_path}/generator.pkl') 207 | 208 | assert os.path.isfile(f'{tmp_path}/generator.pkl') 209 | 210 | g = SDVDataGenerator.load(f'{tmp_path}/generator.pkl') 211 | 212 | assert g.getLabel() == 'synth' 213 | assert g._id_col_name == 'user_id' 214 | assert g._model_name == 'gaussiancopula' 215 | assert g.getParallelizationLevel() == 2 216 | assert g.getDevice() == 'cpu' 217 | assert g.getInitSeed() == 1234 218 | assert g._fit_called 219 | assert hasattr(g, '_model') 220 | assert hasattr(g, '_schema') 221 | -------------------------------------------------------------------------------- /tests/test_responses.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | from pyspark.sql import DataFrame, SparkSession 4 | from pyspark.ml.feature import VectorAssembler 5 | from pyspark.ml.linalg import DenseVector 6 | 7 | from sim4rec.response import ( 8 | BernoulliResponse, 9 | NoiseResponse, 10 | ConstantResponse, 11 | CosineSimilatiry, 12 | ParametricResponseFunction 13 | ) 14 | 15 | SEED = 1234 16 | 17 | 18 | @pytest.fixture(scope="function") 19 | def 
users_va_left() -> VectorAssembler: 20 | return VectorAssembler( 21 | inputCols=[f'user_attr_{i}' for i in range(0, 5)], 22 | outputCol='__v1' 23 | ) 24 | 25 | 26 | @pytest.fixture(scope="function") 27 | def users_va_right() -> VectorAssembler: 28 | return VectorAssembler( 29 | inputCols=[f'user_attr_{i}' for i in range(5, 10)], 30 | outputCol='__v2' 31 | ) 32 | 33 | 34 | @pytest.fixture(scope="function") 35 | def bernoulli_resp() -> BernoulliResponse: 36 | return BernoulliResponse( 37 | inputCol='__proba', 38 | outputCol='relevance', 39 | seed=SEED 40 | ) 41 | 42 | 43 | @pytest.fixture(scope="function") 44 | def noise_resp() -> NoiseResponse: 45 | return NoiseResponse( 46 | mu=0.5, 47 | sigma=0.2, 48 | outputCol='__noise', 49 | seed=SEED 50 | ) 51 | 52 | 53 | @pytest.fixture(scope="function") 54 | def const_resp() -> ConstantResponse: 55 | return ConstantResponse( 56 | value=0.5, 57 | outputCol='__const', 58 | ) 59 | 60 | 61 | @pytest.fixture(scope="function") 62 | def cosine_resp() -> CosineSimilatiry: 63 | return CosineSimilatiry( 64 | inputCols=['__v1', '__v2'], 65 | outputCol='__cosine' 66 | ) 67 | 68 | 69 | @pytest.fixture(scope="function") 70 | def param_resp() -> ParametricResponseFunction: 71 | return ParametricResponseFunction( 72 | inputCols=['__const', '__cosine'], 73 | outputCol='__proba', 74 | weights=[0.5, 0.5] 75 | ) 76 | 77 | 78 | @pytest.fixture(scope="module") 79 | def random_df(spark : SparkSession) -> DataFrame: 80 | data = [ 81 | (0, 0.0), 82 | (1, 0.2), 83 | (2, 0.4), 84 | (3, 0.6), 85 | (4, 1.0) 86 | ] 87 | return spark.createDataFrame(data=data, schema=['id', '__proba']) 88 | 89 | 90 | @pytest.fixture(scope="module") 91 | def vector_df(spark : SparkSession) -> DataFrame: 92 | data = [ 93 | (0, DenseVector([1.0, 0.0]), DenseVector([0.0, 1.0])), 94 | (1, DenseVector([-1.0, 0.0]), DenseVector([1.0, 0.0])), 95 | (2, DenseVector([1.0, 0.0]), DenseVector([1.0, 0.0])), 96 | (3, DenseVector([0.5, 0.5]), DenseVector([-0.5, 0.5])), 97 | (4, DenseVector([0.0, 0.0]), DenseVector([1.0, 0.0])) 98 | ] 99 | return spark.createDataFrame(data=data, schema=['id', '__v1', '__v2']) 100 | 101 | 102 | def test_bernoulli_transform( 103 | bernoulli_resp : BernoulliResponse, 104 | random_df : DataFrame 105 | ): 106 | result = bernoulli_resp.transform(random_df).toPandas().sort_values(['id']) 107 | 108 | assert 'relevance' in result.columns 109 | assert len(result) == 5 110 | assert set(result['relevance']) == set([0, 1]) 111 | assert list(result['relevance'][:5]) == [0, 0, 0, 1, 1] 112 | 113 | 114 | def test_bernoulli_iterdiff( 115 | bernoulli_resp : BernoulliResponse, 116 | random_df : DataFrame 117 | ): 118 | result1 = bernoulli_resp.transform(random_df).toPandas() 119 | result1 = result1.sort_values(['id']) 120 | result2 = bernoulli_resp.transform(random_df).toPandas() 121 | result2 = result2.sort_values(['id']) 122 | 123 | assert list(result1['relevance'][:5]) != list(result2['relevance'][:5]) 124 | 125 | 126 | def test_noise_transform( 127 | noise_resp : NoiseResponse, 128 | random_df : DataFrame 129 | ): 130 | result = noise_resp.transform(random_df).toPandas().sort_values(['id']) 131 | 132 | assert '__noise' in result.columns 133 | assert len(result) == 5 134 | assert np.allclose(result['__noise'][0], 0.6117798825975235) 135 | 136 | 137 | def test_noise_iterdiff( 138 | noise_resp : NoiseResponse, 139 | random_df : DataFrame 140 | ): 141 | result1 = noise_resp.transform(random_df).toPandas().sort_values(['id']) 142 | result2 = 
noise_resp.transform(random_df).toPandas().sort_values(['id']) 143 | 144 | assert result1['__noise'][0] != result2['__noise'][0] 145 | 146 | 147 | def test_const_transform( 148 | const_resp : ConstantResponse, 149 | random_df : DataFrame 150 | ): 151 | result = const_resp.transform(random_df).toPandas().sort_values(['id']) 152 | 153 | assert '__const' in result.columns 154 | assert len(result) == 5 155 | assert list(result['__const']) == [0.5] * 5 156 | 157 | 158 | def test_cosine_transform( 159 | cosine_resp : CosineSimilatiry, 160 | vector_df : DataFrame 161 | ): 162 | result = cosine_resp.transform(vector_df)\ 163 | .drop('__v1', '__v2')\ 164 | .toPandas()\ 165 | .sort_values(['id']) 166 | 167 | assert '__cosine' in result.columns 168 | assert len(result) == 5 169 | assert list(result['__cosine']) == [0.5, 0.0, 1.0, 0.5, 0.0] 170 | 171 | 172 | def test_paramresp_transform( 173 | param_resp : ParametricResponseFunction, 174 | const_resp : ConstantResponse, 175 | cosine_resp : CosineSimilatiry, 176 | vector_df : DataFrame 177 | ): 178 | df = const_resp.transform(vector_df) 179 | df = cosine_resp.transform(df) 180 | 181 | result = param_resp.transform(df)\ 182 | .drop('__v1', '__v2')\ 183 | .toPandas()\ 184 | .sort_values(['id']) 185 | 186 | assert '__proba' in result.columns 187 | assert len(result) == 5 188 | assert list(result['__proba']) == [0.5, 0.25, 0.75, 0.5, 0.25] 189 | 190 | r = result['__const'][0] / 2 +\ 191 | result['__cosine'][0] / 2 192 | assert result['__proba'][0] == r 193 | 194 | param_resp.setWeights([1.0, 0.0]) 195 | result = param_resp.transform(df)\ 196 | .drop('__v1', '__v2')\ 197 | .toPandas()\ 198 | .sort_values(['id']) 199 | assert result['__proba'][0] == result['__const'][0] 200 | 201 | param_resp.setWeights([0.0, 1.0]) 202 | result = param_resp.transform(df)\ 203 | .drop('__v1', '__v2')\ 204 | .toPandas()\ 205 | .sort_values(['id']) 206 | assert result['__proba'][0] == result['__cosine'][0] 207 | -------------------------------------------------------------------------------- /tests/test_selectors.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pyspark.sql import DataFrame 3 | 4 | from sim4rec.modules import ( 5 | CrossJoinItemEstimator, 6 | CrossJoinItemTransformer 7 | ) 8 | 9 | SEED = 1234 10 | K = 5 11 | 12 | 13 | @pytest.fixture(scope="function") 14 | def estimator() -> CrossJoinItemEstimator: 15 | return CrossJoinItemEstimator( 16 | k=K, 17 | userKeyColumn='user_id', 18 | itemKeyColumn='item_id', 19 | seed=SEED 20 | ) 21 | 22 | 23 | @pytest.fixture(scope="function") 24 | def transformer( 25 | estimator : CrossJoinItemEstimator, 26 | items_df : DataFrame 27 | ) -> CrossJoinItemTransformer: 28 | return estimator.fit(items_df) 29 | 30 | 31 | def test_crossjoinestimator_fit( 32 | transformer : CrossJoinItemTransformer, 33 | items_df : DataFrame 34 | ): 35 | assert transformer._item_df is not None 36 | assert transformer._item_df.toPandas().equals(items_df.toPandas()) 37 | 38 | 39 | def test_crossjointransformer_transform( 40 | transformer : CrossJoinItemTransformer, 41 | users_df : DataFrame 42 | ): 43 | sample = transformer.transform(users_df).toPandas() 44 | sample = sample.sort_values(['user_id', 'item_id']) 45 | 46 | assert 'user_id' in sample.columns[0] 47 | assert 'item_id' in sample.columns[1] 48 | assert len(sample) == users_df.count() * K 49 | assert sample.iloc[K, 0] == 1 50 | assert sample.iloc[K, 1] == 0 51 | 52 | 53 | def test_crossjointransformer_iterdiff( 54 | transformer : 
CrossJoinItemTransformer, 55 | users_df : DataFrame 56 | ): 57 | sample_1 = transformer.transform(users_df).toPandas() 58 | sample_2 = transformer.transform(users_df).toPandas() 59 | sample_1 = sample_1.sort_values(['user_id', 'item_id']) 60 | sample_2 = sample_2.sort_values(['user_id', 'item_id']) 61 | 62 | assert not sample_1.equals(sample_2) 63 | 64 | 65 | def test_crossjointransformer_fixedseed( 66 | transformer : CrossJoinItemTransformer, 67 | users_df : DataFrame, 68 | items_df : DataFrame 69 | ): 70 | e = CrossJoinItemEstimator( 71 | k=K, 72 | userKeyColumn='user_id', 73 | itemKeyColumn='item_id', 74 | seed=SEED 75 | ) 76 | t = e.fit(items_df) 77 | 78 | sample_1 = transformer.transform(users_df).toPandas() 79 | sample_2 = t.transform(users_df).toPandas() 80 | sample_1 = sample_1.sort_values(['user_id', 'item_id']) 81 | sample_2 = sample_2.sort_values(['user_id', 'item_id']) 82 | 83 | assert sample_1.equals(sample_2) 84 | -------------------------------------------------------------------------------- /tests/test_simulator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import pytest 4 | import pyspark.sql.functions as sf 5 | from pyspark.sql import DataFrame, SparkSession 6 | from pyspark.ml import PipelineModel 7 | from pyspark.ml.feature import VectorAssembler 8 | 9 | from sim4rec.modules import ( 10 | Simulator, 11 | RealDataGenerator, 12 | SDVDataGenerator, 13 | CompositeGenerator, 14 | CrossJoinItemEstimator, 15 | CrossJoinItemTransformer 16 | ) 17 | from sim4rec.response import CosineSimilatiry 18 | 19 | 20 | SEED = 1234 21 | 22 | 23 | @pytest.fixture(scope="module") 24 | def real_users_gen(users_df : DataFrame) -> RealDataGenerator: 25 | gen = RealDataGenerator(label='real', seed=SEED) 26 | gen.fit(users_df) 27 | gen.generate(5) 28 | 29 | return gen 30 | 31 | 32 | @pytest.fixture(scope="module") 33 | def synth_users_gen(users_df : DataFrame) -> SDVDataGenerator: 34 | gen = SDVDataGenerator( 35 | label='synth', 36 | id_column_name='user_id', 37 | model_name='gaussiancopula', 38 | parallelization_level=2, 39 | device_name='cpu', 40 | seed=SEED 41 | ) 42 | gen.fit(users_df) 43 | gen.generate(5) 44 | 45 | return gen 46 | 47 | 48 | @pytest.fixture(scope="module") 49 | def comp_users_gen( 50 | real_users_gen : RealDataGenerator, 51 | synth_users_gen : SDVDataGenerator 52 | ) -> CompositeGenerator: 53 | return CompositeGenerator( 54 | generators=[real_users_gen, synth_users_gen], 55 | label='composite', 56 | weights=[0.5, 0.5] 57 | ) 58 | 59 | 60 | @pytest.fixture(scope="module") 61 | def real_items_gen(items_df : DataFrame) -> RealDataGenerator: 62 | gen = RealDataGenerator(label='real', seed=SEED) 63 | gen.fit(items_df) 64 | gen.generate(5) 65 | 66 | return gen 67 | 68 | 69 | @pytest.fixture(scope="module") 70 | def selector(items_df : DataFrame) -> CrossJoinItemTransformer: 71 | estimator = CrossJoinItemEstimator( 72 | k=3, 73 | userKeyColumn='user_id', 74 | itemKeyColumn='item_id', 75 | seed=SEED 76 | ) 77 | return estimator.fit(items_df) 78 | 79 | 80 | @pytest.fixture(scope="module") 81 | def pipeline() -> PipelineModel: 82 | va_left = VectorAssembler(inputCols=['user_attr_1', 'user_attr_2'], outputCol='__v1') 83 | va_right = VectorAssembler(inputCols=['item_attr_1', 'item_attr_2'], outputCol='__v2') 84 | 85 | c = CosineSimilatiry(inputCols=['__v1', '__v2'], outputCol='response') 86 | 87 | return PipelineModel(stages=[va_left, va_right, c]) 88 | 89 | 90 | @pytest.fixture(scope="function") 91 | def 
simulator_empty( 92 | comp_users_gen : CompositeGenerator, 93 | real_items_gen : RealDataGenerator, 94 | spark : SparkSession, 95 | tmp_path 96 | ) -> Simulator: 97 | shutil.rmtree(str(tmp_path / 'sim_empty'), ignore_errors=True) 98 | return Simulator( 99 | user_gen=comp_users_gen, 100 | item_gen=real_items_gen, 101 | log_df=None, 102 | user_key_col='user_id', 103 | item_key_col='item_id', 104 | data_dir=str(tmp_path / 'sim_empty'), 105 | spark_session=spark 106 | ) 107 | 108 | 109 | @pytest.fixture(scope="function") 110 | def simulator_with_log( 111 | comp_users_gen : CompositeGenerator, 112 | real_items_gen : RealDataGenerator, 113 | log_df : DataFrame, 114 | spark : SparkSession, 115 | tmp_path 116 | ) -> Simulator: 117 | shutil.rmtree(str(tmp_path / 'sim_with_log'), ignore_errors=True) 118 | return Simulator( 119 | user_gen=comp_users_gen, 120 | item_gen=real_items_gen, 121 | log_df=log_df, 122 | user_key_col='user_id', 123 | item_key_col='item_id', 124 | data_dir=str(tmp_path / 'sim_with_log'), 125 | spark_session=spark 126 | ) 127 | 128 | 129 | def test_simulator_init( 130 | simulator_empty : Simulator, 131 | simulator_with_log : Simulator 132 | ): 133 | assert os.path.isdir(simulator_empty._data_dir) 134 | assert os.path.isdir(simulator_with_log._data_dir) 135 | 136 | assert simulator_empty._log is None 137 | assert Simulator.ITER_COLUMN in simulator_with_log._log.columns 138 | 139 | assert simulator_with_log._log.count() == 5 140 | assert os.path.isdir(f'{simulator_with_log._data_dir}/{simulator_with_log.log_filename}/{Simulator.ITER_COLUMN}=start') 141 | 142 | 143 | def test_simulator_clearlog( 144 | simulator_with_log : Simulator 145 | ): 146 | simulator_with_log.clear_log() 147 | 148 | assert simulator_with_log.log is None 149 | assert simulator_with_log._log_schema is None 150 | 151 | 152 | def test_simulator_updatelog( 153 | simulator_empty : Simulator, 154 | simulator_with_log : Simulator, 155 | log_df : DataFrame 156 | ): 157 | simulator_empty.update_log(log_df, iteration=0) 158 | simulator_with_log.update_log(log_df, iteration=0) 159 | 160 | assert simulator_empty.log.count() == 5 161 | assert simulator_with_log.log.count() == 10 162 | 163 | assert set(simulator_empty.log.toPandas()[Simulator.ITER_COLUMN].unique()) == set([0]) 164 | assert set(simulator_with_log.log.toPandas()[Simulator.ITER_COLUMN].unique()) == set(['0', 'start']) 165 | 166 | assert os.path.isdir(f'{simulator_empty._data_dir}/{simulator_empty.log_filename}/{Simulator.ITER_COLUMN}=0') 167 | assert os.path.isdir(f'{simulator_with_log._data_dir}/{simulator_with_log.log_filename}/{Simulator.ITER_COLUMN}=start') 168 | assert os.path.isdir(f'{simulator_with_log._data_dir}/{simulator_with_log.log_filename}/{Simulator.ITER_COLUMN}=0') 169 | 170 | 171 | def test_simulator_sampleusers( 172 | simulator_empty : Simulator 173 | ): 174 | sampled1 = simulator_empty.sample_users(0.5)\ 175 | .toPandas().sort_values(['user_id']) 176 | sampled2 = simulator_empty.sample_users(0.5)\ 177 | .toPandas().sort_values(['user_id']) 178 | 179 | assert not sampled1.equals(sampled2) 180 | 181 | assert len(sampled1) == 2 182 | assert len(sampled2) == 4 183 | 184 | 185 | def test_simulator_sampleitems( 186 | simulator_empty : Simulator 187 | ): 188 | sampled1 = simulator_empty.sample_items(0.5)\ 189 | .toPandas().sort_values(['item_id']) 190 | sampled2 = simulator_empty.sample_items(0.5)\ 191 | .toPandas().sort_values(['item_id']) 192 | 193 | assert not sampled1.equals(sampled2) 194 | 195 | assert len(sampled1) == 2 196 | assert 
len(sampled2) == 2 197 | 198 | 199 | def test_simulator_getuseritem( 200 | simulator_with_log : Simulator, 201 | selector : CrossJoinItemTransformer, 202 | users_df : DataFrame 203 | ): 204 | users = users_df.filter(sf.col('user_id').isin([0, 1, 2])) 205 | pairs, log = simulator_with_log.get_user_items(users, selector) 206 | 207 | assert pairs.count() == users.count() * selector._k 208 | assert log.count() == 5 209 | 210 | assert 'user_id' in pairs.columns 211 | assert 'item_id' in pairs.columns 212 | 213 | assert set(log.toPandas()['user_id']) == set([0, 1, 2]) 214 | 215 | 216 | def test_simulator_responses( 217 | simulator_empty : Simulator, 218 | pipeline : PipelineModel, 219 | users_df : DataFrame, 220 | items_df : DataFrame, 221 | log_df : DataFrame 222 | ): 223 | resp = simulator_empty.sample_responses( 224 | recs_df=log_df, 225 | user_features=users_df, 226 | item_features=items_df, 227 | action_models=pipeline 228 | ).drop('__v1', '__v2').toPandas().sort_values(['user_id']) 229 | 230 | assert 'user_id' in resp.columns 231 | assert 'item_id' in resp.columns 232 | assert 'response' in resp.columns 233 | assert len(resp) == log_df.count() 234 | assert resp['response'].values[0] == 1 235 | --------------------------------------------------------------------------------
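Taken together, the tests above exercise the pieces of one simulation cycle: generators fitted on reference data, a selector that proposes candidate items, a response pipeline, and a Simulator that maintains the interaction log. The sketch below strings those calls into a loop in the way the fixtures suggest. It is an illustrative assembly, not a prescribed recipe: the toy user/item frames, the ./sim_data directory, the three iterations, the use of plain RealDataGenerator objects where the tests wrap generators in a CompositeGenerator, and feeding the selector output directly into response sampling are all assumptions; in practice a recommender model would rank the candidate pairs first.

from pyspark.ml import PipelineModel
from pyspark.ml.feature import VectorAssembler

from sim4rec.modules import (
    Simulator,
    RealDataGenerator,
    CrossJoinItemEstimator,
)
from sim4rec.response import CosineSimilatiry, BernoulliResponse
from sim4rec.utils import get_spark_session

spark = get_spark_session()

# Toy reference data shaped like the fixtures in tests/conftest.py
users_df = spark.createDataFrame(
    [(0, 1.25, -0.75), (1, 0.2, -1.0), (2, 0.85, -0.5)],
    ["user_id", "user_attr_1", "user_attr_2"],
)
items_df = spark.createDataFrame(
    [(0, 0.45, -0.45), (1, -0.3, 0.75), (2, 1.25, -0.75)],
    ["item_id", "item_attr_1", "item_attr_2"],
)

# Generators are fitted on the reference data and asked for an initial population
user_gen = RealDataGenerator(label="users_real", seed=1234)
user_gen.fit(users_df)
user_gen.generate(3)

item_gen = RealDataGenerator(label="items_real", seed=1234)
item_gen.fit(items_df)
item_gen.generate(3)

# Candidate selector: pairs every sampled user with k items
selector = CrossJoinItemEstimator(
    k=2,
    userKeyColumn="user_id",
    itemKeyColumn="item_id",
    seed=1234,
).fit(items_df)

# Response model: cosine similarity of user/item features turned into a 0/1 click
response_pipeline = PipelineModel(stages=[
    VectorAssembler(inputCols=["user_attr_1", "user_attr_2"], outputCol="__v1"),
    VectorAssembler(inputCols=["item_attr_1", "item_attr_2"], outputCol="__v2"),
    CosineSimilatiry(inputCols=["__v1", "__v2"], outputCol="__proba"),
    BernoulliResponse(inputCol="__proba", outputCol="response", seed=1234),
])

sim = Simulator(
    user_gen=user_gen,
    item_gen=item_gen,
    log_df=None,
    user_key_col="user_id",
    item_key_col="item_id",
    data_dir="./sim_data",  # assumption: any writable directory
    spark_session=spark,
)

for iteration in range(3):
    users = sim.sample_users(1.0)
    pairs, _ = sim.get_user_items(users, selector)
    responses = sim.sample_responses(
        recs_df=pairs,
        user_features=users_df,
        item_features=items_df,
        action_models=response_pipeline,
    ).select("user_id", "item_id", "response")
    sim.update_log(responses, iteration=iteration)

print(sim.log.count())

Each update_log call appends a new ITER_COLUMN partition under data_dir, which is exactly what the directory assertions in test_simulator_updatelog check.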