├── .github
│   └── workflows
│       └── main.yml
├── .pylintrc
├── LICENSE
├── README.md
├── demo
│   ├── README.md
│   ├── cycle_time_50M_users_5seeds.csv
│   ├── demo_examples.ipynb
│   ├── requirements.txt
│   ├── rs_performance_evaluation.ipynb
│   ├── simulation_cycle_time.ipynb
│   ├── synthetic_data_generation.ipynb
│   ├── synthetic_data_generation_time.ipynb
│   └── synthetic_gen_time.csv
├── docs
│   ├── Makefile
│   ├── make.bat
│   └── source
│       ├── conf.py
│       ├── index.rst
│       └── pages
│           ├── modules.rst
│           ├── response.rst
│           └── utils.rst
├── experiments
│   ├── Amazon
│   │   ├── RU_amazon_processing.ipynb
│   │   ├── RU_load_dataset.ipynb
│   │   ├── RU_train_test_split.ipynb
│   │   └── amazon.py
│   ├── Movielens
│   │   ├── ml20.py
│   │   ├── movie_preprocessing.ipynb
│   │   └── traint_test_split.ipynb
│   ├── Netflix
│   │   ├── RU_load_datasets.ipynb
│   │   ├── RU_netflix_processing.ipynb
│   │   ├── RU_train_test_split.ipynb
│   │   └── netflix.py
│   ├── RU_amazon_embeddings.ipynb
│   ├── RU_amazon_generators.ipynb
│   ├── RU_amazon_response.ipynb
│   ├── RU_amazon_surface.ipynb
│   ├── RU_netflix_embeddings.ipynb
│   ├── RU_netflix_generators.ipynb
│   ├── RU_netflix_response.ipynb
│   ├── RU_netflix_surface.ipynb
│   ├── RU_sim_validation.ipynb
│   ├── RU_validation_quality_control.ipynb
│   ├── custom_response_function.ipynb
│   ├── datautils.py
│   ├── generators_fit.ipynb
│   ├── generators_generate.ipynb
│   ├── generators_generate.py
│   ├── items_generation.ipynb
│   ├── movielens_embeddings.ipynb
│   ├── movielens_generators.ipynb
│   ├── movielens_quality_control.ipynb
│   ├── movielens_response.ipynb
│   ├── movielens_scenario.ipynb
│   ├── movielens_surface.ipynb
│   ├── response_models
│   │   ├── data
│   │   │   └── popular_items_popularity.parquet
│   │   │       ├── .part-00000-8ac17189-ecfe-473d-a9c5-d72f7f86d119-c000.snappy.parquet.crc
│   │   │       ├── _SUCCESS
│   │   │       └── part-00000-8ac17189-ecfe-473d-a9c5-d72f7f86d119-c000.snappy.parquet
│   │   ├── task_1_popular_items.py
│   │   └── utils.py
│   ├── simulator_time.ipynb
│   ├── simulator_time.py
│   └── transformers.py
├── notebooks
│   ├── embeddings.ipynb
│   └── pipeline.ipynb
├── poetry.lock
├── pyproject.toml
├── pytest.ini
├── sim4rec
│   ├── __init__.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── embeddings.py
│   │   ├── evaluation.py
│   │   ├── generator.py
│   │   ├── selectors.py
│   │   └── simulator.py
│   ├── params
│   │   ├── __init__.py
│   │   └── params.py
│   ├── recommenders
│   │   ├── ucb.py
│   │   └── utils.py
│   ├── response
│   │   ├── __init__.py
│   │   └── response.py
│   └── utils
│       ├── __init__.py
│       ├── convert.py
│       ├── session_handler.py
│       └── uce.py
└── tests
    ├── __init__.py
    ├── conftest.py
    ├── test_embeddings.py
    ├── test_evaluation.py
    ├── test_generators.py
    ├── test_responses.py
    ├── test_selectors.py
    └── test_simulator.py
/.github/workflows/main.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | # Controls when the workflow will run
4 | on:
5 | # Triggers the workflow on pull request events but only for the main branch
6 | pull_request:
7 | branches: [main]
8 |
9 | # Allows you to run this workflow manually from the Actions tab
10 | workflow_dispatch:
11 |
12 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel
13 | jobs:
14 | run_tests:
15 | runs-on: ubuntu-20.04
16 | strategy:
17 | matrix:
18 | python-version: ["3.8", "3.9"]
19 | steps:
20 | - uses: actions/checkout@v2
21 | - uses: actions/setup-python@v2
22 | with:
23 | python-version: ${{ matrix.python-version }}
24 | - name: Install package
25 | run: |
26 | python -m venv venv
27 | . ./venv/bin/activate
28 | pip install --upgrade pip wheel poetry pycodestyle pylint pytest-cov
29 |
30 | poetry cache clear pypi --all
31 | poetry lock
32 | poetry install
33 | - name: Build docs
34 | run: |
35 | . ./venv/bin/activate
36 | cd docs
37 | make clean html
38 | - name: pycodestyle
39 | run: |
40 | . ./venv/bin/activate
41 | pycodestyle --ignore=E203,E231,E501,W503,W605,E122,E125 --max-doc-length=160 sim4rec tests
42 | - name: pylint
43 | run: |
44 | . ./venv/bin/activate
45 | pylint --rcfile=.pylintrc sim4rec
46 | - name: pytest
47 | run: |
48 | . ./venv/bin/activate
49 | pytest --cov=sim4rec --cov-report=term-missing --doctest-modules sim4rec --cov-fail-under=89 tests
50 |
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | [GENERAL]
2 | fail-under=10
3 |
4 | [TYPECHECK]
5 | ignored-modules=pyspark.sql.functions,pyspark,torch,ignite,ignite.engine,pyspark.sql,statsmodels.stats.proportion,numba
6 | disable=bad-option-value,no-else-return
7 |
8 | [MESSAGES CONTROL]
9 | disable=anomalous-backslash-in-string,bad-continuation,missing-module-docstring,wrong-import-order,protected-access,consider-using-generator,consider-using-enumerate,invalid-name
10 |
11 | [FORMAT]
12 | ignore-long-lines=^.*([А-Яа-я]|>{3}|\.{3}|\\{2}|https://).*$
13 | max-line-length=100
14 | good-names=df,i,j,k,n,_,x,y
15 |
16 | [SIMILARITIES]
17 | min-similarity-lines=33
18 | ignore-comments=yes
19 | ignore-docstrings=yes
20 | ignore-imports=yes
21 |
22 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2023 Sber AI Lab
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Simulator
2 |
3 | Simulator is a framework for training and evaluating recommendation algorithms on real or synthetic data. The framework is based on the pyspark library for working with big data.
4 | As part of the simulation process, the framework includes data generators, response functions and other tools that allow flexible use of the simulator.
5 |
6 | # Table of contents
7 |
8 | * [Installation](#installation)
9 | * [Quickstart](#quickstart)
10 | * [Examples](#examples)
11 | * [Build from sources](#build-from-sources)
12 | * [Building documentation](#compile-documentation)
13 | * [Running tests](#tests)
14 |
15 | ## Installation
16 |
17 | ```bash
18 | pip install sim4rec
19 | ```
20 |
21 | If the installation takes too long, try
22 | ```bash
23 | pip install sim4rec --use-deprecated=legacy-resolver
24 | ```
25 |
26 | To install dependencies with poetry, run
27 |
28 | ```bash
29 | pip install --upgrade pip wheel poetry
30 | poetry install
31 | ```
32 |
33 | ## Quickstart
34 |
35 | The following example shows how to use the simulator to train a model iteratively by refitting the recommendation algorithm on the newly collected interaction log
36 |
37 | ```python
38 | import numpy as np
39 | import pandas as pd
40 |
41 | import pyspark.sql.types as st
42 | from pyspark.ml import PipelineModel
43 |
44 | from sim4rec.modules import RealDataGenerator, Simulator, EvaluateMetrics
45 | from sim4rec.response import NoiseResponse, BernoulliResponse
46 | from sim4rec.recommenders.ucb import UCB
47 | from sim4rec.utils import pandas_to_spark
48 |
49 | LOG_SCHEMA = st.StructType([
50 | st.StructField('user_idx', st.LongType(), True),
51 | st.StructField('item_idx', st.LongType(), True),
52 | st.StructField('relevance', st.DoubleType(), False),
53 | st.StructField('response', st.IntegerType(), False)
54 | ])
55 |
56 | users_df = pd.DataFrame(
57 | data=np.random.normal(0, 1, size=(100, 15)),
58 | columns=[f'user_attr_{i}' for i in range(15)]
59 | )
60 | items_df = pd.DataFrame(
61 | data=np.random.normal(1, 1, size=(30, 10)),
62 | columns=[f'item_attr_{i}' for i in range(10)]
63 | )
64 | history_df = pandas_to_spark(pd.DataFrame({
65 | 'user_idx' : [1, 10, 10, 50],
66 | 'item_idx' : [4, 25, 26, 25],
67 | 'relevance' : [1.0, 0.0, 1.0, 1.0],
68 | 'response' : [1, 0, 1, 1]
69 | }), schema=LOG_SCHEMA)
70 |
71 | users_df['user_idx'] = np.arange(len(users_df))
72 | items_df['item_idx'] = np.arange(len(items_df))
73 |
74 | users_df = pandas_to_spark(users_df)
75 | items_df = pandas_to_spark(items_df)
76 |
77 | user_gen = RealDataGenerator(label='users_real')
78 | item_gen = RealDataGenerator(label='items_real')
79 | user_gen.fit(users_df)
80 | item_gen.fit(items_df)
81 | _ = user_gen.generate(100)
82 | _ = item_gen.generate(30)
83 |
84 | sim = Simulator(
85 | user_gen=user_gen,
86 | item_gen=item_gen,
87 | data_dir='test_simulator',
88 | user_key_col='user_idx',
89 | item_key_col='item_idx',
90 | log_df=history_df
91 | )
92 |
93 | noise_resp = NoiseResponse(mu=0.5, sigma=0.2, outputCol='__noise')
94 | br = BernoulliResponse(inputCol='__noise', outputCol='response')
95 | pipeline = PipelineModel(stages=[noise_resp, br])
96 |
97 | model = UCB()
98 | model.fit(log=history_df)
99 |
100 | evaluator = EvaluateMetrics(
101 | userKeyCol='user_idx',
102 | itemKeyCol='item_idx',
103 | predictionCol='relevance',
104 | labelCol='response',
105 | mllib_metrics=['areaUnderROC']
106 | )
107 |
108 | metrics = []
109 | for i in range(10):
110 | users = sim.sample_users(0.1).cache()
111 |
112 | recs = model.predict(
113 | log=sim.log, k=5, users=users, items=items_df, filter_seen_items=True
114 | ).cache()
115 |
116 | true_resp = (
117 | sim.sample_responses(
118 | recs_df=recs,
119 | user_features=users,
120 | item_features=items_df,
121 | action_models=pipeline,
122 | )
123 | .select("user_idx", "item_idx", "relevance", "response")
124 | .cache()
125 | )
126 |
127 | sim.update_log(true_resp, iteration=i)
128 |
129 | metrics.append(evaluator(true_resp))
130 |
131 | model.fit(sim.log.drop("relevance").withColumnRenamed("response", "relevance"))
132 |
133 | users.unpersist()
134 | recs.unpersist()
135 | true_resp.unpersist()
136 | ```
137 |
138 | ## Examples
139 |
140 | The 'notebooks' folder contains useful examples that demonstrate how to use synthetic data generators and composite generators, evaluate the results of the generators, iteratively refit the recommendation algorithm, use response functions and more.
141 |
142 | Experiments with different datasets and a tutorial on writing custom response functions can be found in the 'experiments' folder.
143 |
144 | ## Case studies
145 |
146 | Case studies prepared for the demo track are available in the 'demo' directory.
147 |
148 | 1. Synthetic data generation
149 | 2. Long-term RS performance evaluation
150 |
151 | ## Build from sources
152 |
153 | ```bash
154 | poetry build
155 | pip install ./dist/sim4rec-0.0.1-py3-none-any.whl
156 | ```
157 |
158 | ## Compile documentation
159 |
160 | ```bash
161 | cd docs
162 | make clean && make html
163 | ```
164 |
165 | ## Tests
166 |
167 | The pytest Python library is used for testing. To run the tests for all modules, run the following command from the repository root directory:
168 |
169 | ```bash
170 | pytest
171 | ```
172 |
173 | ## Licence
174 | Sim4Rec is distributed under the [Apache Licence Version 2.0](https://github.com/sb-ai-lab/Sim4Rec/blob/main/LICENSE);
175 | however, the SDV package imported by Sim4Rec for synthetic data generation
176 | is distributed under the [Business Source Licence (BSL) 1.1](https://github.com/sdv-dev/SDV/blob/master/LICENSE).
177 |
178 | Synthetic tabular data generation is not the main purpose of the Sim4Rec framework.
179 | Sim4Rec provides an API and wrappers to run simulations with synthetic data, but the method of synthetic data generation is determined by the user.
180 | The SDV package is imported for illustration purposes and can be replaced by another synthetic data generation solution.
181 |
182 | Thus, the synthetic data generation and quality evaluation functionality provided via the SDV library,
183 | namely `SDVDataGenerator` from [generator.py](sim4rec/modules/generator.py) and `evaluate_synthetic` from [evaluation.py](sim4rec/modules/evaluation.py),
184 | should only be used for non-production purposes according to the SDV licence.
185 |
--------------------------------------------------------------------------------
/demo/README.md:
--------------------------------------------------------------------------------
1 | # Case studies for demo
2 | Cases for the demo:
3 | 1. [Synthetic data generation](https://github.com/sb-ai-lab/Sim4Rec/blob/main/demo/synthetic_data_generation.ipynb)
4 |
5 | Generation of synthetic users based on real data.
6 |
7 | 2. [Long-term RS performance evaluation](https://github.com/sb-ai-lab/Sim4Rec/blob/main/demo/rs_performance_evaluation.ipynb)
8 |
9 | A simple simulation pipeline for long-term performance evaluation of a recommender system.
10 |
11 | # Table of contents
12 |
13 | * [Installation](#installation)
14 | * [Synthetic data generation pipeline](#synthetic-data-generation-pipeline)
15 | * [Long-term RS performance evaluation pipeline](#long-term-rs-performance-evaluation-pipeline)
16 |
17 | ## Installation
18 |
19 | To install dependencies with poetry, run
20 |
21 | ```bash
22 | pip install --upgrade pip wheel poetry
23 | poetry install
24 | ```
25 |
26 | Please install `rs_datasets` to import the MovieLens dataset for synthetic_data_generation.ipynb
27 | ```bash
28 | pip install rs_datasets
29 | ```
30 |
31 | ## Synthetic data generation pipeline
32 | 1. Fit non-negative ALS to real data containing user interactions
33 | 2. Obtain vector representations of real users
34 | 3. Fit CopulaGAN to non-negative ALS embeddings of real users
35 | 4. Generate synthetic user feature vectors with CopulaGAN
36 | 5. Evaluate the quality of the synthetic user profiles
37 |
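38 | A minimal sketch of this pipeline is shown below. It is a sketch under assumptions, not the actual notebook code:
39 | the `SDVDataGenerator` constructor arguments and the `evaluate_synthetic` argument order are assumptions,
40 | and `log_df` stands for your interaction log; see synthetic_data_generation.ipynb for the real implementation.
41 |
42 | ```python
43 | from pyspark.ml.recommendation import ALS
44 | from sim4rec.modules import SDVDataGenerator, evaluate_synthetic
45 |
46 | # 1-2. Fit non-negative ALS on the interaction log and take user factors as embeddings
47 | als_model = ALS(
48 |     rank=8, userCol='user_idx', itemCol='item_idx',
49 |     ratingCol='relevance', nonnegative=True
50 | ).fit(log_df)
51 | real_user_embeddings = als_model.userFactors  # columns: id, features (may need flattening)
52 |
53 | # 3-4. Fit a CopulaGAN-based generator on the embeddings and sample synthetic users
54 | # (constructor arguments are assumptions, check the SDVDataGenerator reference)
55 | generator = SDVDataGenerator(label='synth_users', model_name='copulagan')
56 | generator.fit(real_user_embeddings)
57 | synthetic_users = generator.generate(100_000)
58 |
59 | # 5. Compare synthetic and real user profiles
60 | print(evaluate_synthetic(synthetic_users, real_user_embeddings))
61 | ```
62 |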
38 | ## Long-term RS performance evaluation pipeline
39 | 1. Before running the simulation cycle:
40 | - Initialize and fit recommender model to historical data
41 | - Construct response function pipeline and fit it to items
42 | - Initialize simulator
43 | 2. Run the simulation cycle (see the code sketch after this list):
44 | - Sample users using a simulator
45 | - Get recommendations for sampled users from the recommender system
46 | - Get responses to recommended items from the response function
47 | - Update user interaction history
48 | - Measure the quality of the recommender system
49 | - Refit the recommender model
50 | 3. After running the simulation cycle:
51 | - Get recommendations for all users from the recommender system
52 | - Get responses to recommended items from the response function
53 | - Measure the quality of the recommender system trained in the simulation cycle
54 |
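55 | The simulation cycle (step 2 above) looks roughly like this in code. It is a condensed fragment, not a
56 | self-contained script: `sim`, `model`, `items_df`, `response_pipeline` and `evaluator` are assumed to be
57 | initialized as in step 1 (see the Quickstart in the root README and rs_performance_evaluation.ipynb for a
58 | complete, runnable version).
59 |
60 | ```python
61 | metrics = []
62 | for i in range(10):
63 |     # Sample a fraction of users and get top-5 recommendations for them
64 |     users = sim.sample_users(0.1).cache()
65 |     recs = model.predict(
66 |         log=sim.log, k=5, users=users, items=items_df, filter_seen_items=True
67 |     ).cache()
68 |
69 |     # Simulate responses, append them to the history log and measure quality
70 |     resp = sim.sample_responses(
71 |         recs_df=recs,
72 |         user_features=users,
73 |         item_features=items_df,
74 |         action_models=response_pipeline,
75 |     ).select('user_idx', 'item_idx', 'relevance', 'response').cache()
76 |     sim.update_log(resp, iteration=i)
77 |     metrics.append(evaluator(resp))
78 |
79 |     # Refit the recommender on the updated interaction history
80 |     model.fit(sim.log.drop('relevance').withColumnRenamed('response', 'relevance'))
81 |     users.unpersist(); recs.unpersist(); resp.unpersist()
82 | ```
83 |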
55 |
--------------------------------------------------------------------------------
/demo/cycle_time_50M_users_5seeds.csv:
--------------------------------------------------------------------------------
1 | index,sample_users_time,sample_responses_time,update_log_time,metrics_time,iter_time,users_num
2 | 0,11.15279746055603,1.0836639404296875,4.48405385017395,10.712761878967283,39.2559826374054,1001660
3 | 1,6.915727138519287,1.144331455230713,2.2269480228424072,7.514208316802978,29.782680749893192,3450869
4 | 2,9.624614000320436,1.9051272869110107,2.5454702377319336,8.540910720825195,38.522289514541626,5897983
5 | 3,10.06862211227417,2.072542428970337,2.2254538536071777,10.562509536743164,41.597875118255615,8347862
6 | 4,10.775094747543335,3.0510990619659424,3.777689456939697,10.119225978851318,44.3191511631012,10798519
7 | 5,10.705054521560667,2.914153575897217,3.795412063598633,8.681732177734375,40.83335065841675,13252619
8 | 6,8.72693920135498,3.5513224601745605,3.9841458797454834,11.071977853775024,42.67783713340759,15702401
9 | 7,11.697437286376953,4.566967010498047,3.8343541622161865,9.24762773513794,44.81257128715515,18153435
10 | 8,9.278602838516237,4.395270109176637,4.1952335834503165,10.463558673858644,45.67495346069336,20603933
11 | 9,9.882215976715088,4.600301027297974,4.617482662200928,7.5286338329315186,39.09629273414612,23053131
12 | 10,8.283291816711426,4.990021228790283,4.571868658065796,6.711239576339723,39.4278380870819,25502789
13 | 11,7.111920833587647,5.580391883850098,5.736581087112428,7.497943639755249,39.17783188819885,27948898
14 | 12,7.849193811416626,6.416064977645874,5.17004656791687,6.878835916519165,41.43364119529724,30399766
15 | 13,7.180546760559082,6.605871915817263,6.033704519271852,7.945231914520264,43.03808045387268,32848393
16 | 14,8.542397260665894,6.9637413024902335,6.764723062515259,6.6079912185668945,43.448081731796265,35299228
17 | 15,6.880044460296631,7.377933502197266,7.637020826339723,7.429396629333496,42.84182572364807,37749079
18 | 16,6.994188547134399,7.831934452056885,7.783137798309326,7.6338841915130615,44.403237104415894,40197676
19 | 17,11.496174812316895,9.800821781158447,7.58228325843811,7.516530275344849,53.62795400619507,42648904
20 | 18,7.79708194732666,9.787442684173584,7.138163089752197,8.8120014667511,50.9351236820221,45099701
21 | 19,8.43359112739563,9.523060560226439,13.3839693069458,7.699332237243652,55.89987015724182,47551011
22 | 20,7.566620349884032,9.683169603347777,8.543299436569214,7.9015352725982675,48.8129768371582,50000000
23 | 0,7.190418481826782,0.707650899887085,2.28092622756958,8.264355659484862,29.746066331863407,1000052
24 | 1,7.42564058303833,1.5093376636505127,1.7751114368438718,7.884857416152954,31.61107087135315,3449012
25 | 2,9.563541889190674,1.6901743412017822,2.4463682174682617,10.842523097991943,37.391902208328254,5900034
26 | 3,6.786839962005615,2.04491662979126,3.031137704849243,8.219735622406006,31.60140872001648,8349078
27 | 4,6.9660797119140625,2.544528245925904,3.1702497005462646,6.7007696628570566,30.183650732040405,10797805
28 | 5,6.701206207275392,3.25626277923584,3.102719783782959,8.533268451690674,32.68921160697937,13247233
29 | 6,6.420436382293701,3.4489002227783203,3.6940829753875732,8.104418039321901,33.3744637966156,15697274
30 | 7,7.689843416213987,3.54607367515564,4.6508872509002686,6.960585594177246,35.21638894081116,18147315
31 | 8,7.622220993041992,4.401278734207153,4.334021091461182,7.482393980026245,37.205557823181145,20600284
32 | 9,7.009465456008911,4.764593839645387,4.193192958831787,6.897244215011598,36.11890935897827,23048885
33 | 10,7.099967002868652,5.215404510498047,4.7635498046875,6.3894689083099365,36.10999631881714,25499101
34 | 11,6.69516921043396,5.744151592254639,6.602449655532838,6.736882925033568,38.676912784576416,27948933
35 | 12,7.4403440952301025,6.7572083473205575,6.003504753112793,7.172186851501465,40.154664516448975,30401023
36 | 13,6.998932123184204,6.20236349105835,5.672366619110107,7.1271512508392325,39.302847623825066,32851486
37 | 14,6.915180683135986,7.822373628616332,6.337413787841798,7.708374500274657,42.24116182327271,35302731
38 | 15,7.066570281982423,8.524396419525146,6.580228328704834,6.166517734527588,41.87582015991211,37752524
39 | 16,6.589472532272339,7.94525933265686,7.2967987060546875,6.8301634788513175,42.06322407722472,40202517
40 | 17,7.117594242095947,8.471821308135986,7.566959619522095,7.366473436355591,43.776485204696655,42652356
41 | 18,7.089627504348755,8.872769355773926,9.707370281219482,7.241292476654053,46.907626152038574,45099245
42 | 19,8.168758392333983,8.996371507644652,7.879936456680298,7.036686420440674,46.9917676448822,47551931
43 | 20,7.7413179874420175,9.81250810623169,8.90994095802307,8.575214385986326,50.75287199020386,50000000
44 | 0,14.046154499053955,1.1173858642578125,3.2084243297576904,13.64061188697815,53.393492698669434,999235
45 | 1,11.949691534042358,1.360381841659546,3.777711153030397,13.68442177772522,49.955246925354004,3446144
46 | 2,12.614464521408081,2.5385425090789795,3.680132150650024,9.67790412902832,48.45165801048279,5896817
47 | 3,7.58543848991394,2.63399624824524,3.206125497817993,7.958815097808838,35.34187650680542,8347080
48 | 4,8.984705686569212,2.444612979888916,3.4341411590576167,8.198475122451782,35.859378814697266,10796748
49 | 5,8.170023202896116,3.02512526512146,3.568270444869995,9.304476022720335,37.61494064331055,13245268
50 | 6,8.56131911277771,3.1345784664154053,3.905285120010376,7.664520978927612,37.156663179397576,15695938
51 | 7,7.557111024856567,3.866744756698608,4.302062511444092,7.748353481292725,37.79162979125977,18148543
52 | 8,7.450805902481079,4.136418581008911,4.425943851470947,8.722140789031982,38.14530491828919,20598543
53 | 9,8.626715898513794,4.416030645370483,6.006159543991089,8.419605731964111,42.09596753120423,23050550
54 | 10,8.417001008987429,5.234817981719972,4.37753963470459,8.036226987838745,41.50021719932556,25502173
55 | 11,7.920273303985598,5.881896018981934,5.305891990661621,8.453716993331911,46.445900678634644,27950737
56 | 12,8.933857202529909,6.833164215087892,7.10662579536438,7.8169307708740225,46.64496946334839,30400600
57 | 13,8.500087738037111,7.582361221313477,7.917228937149048,8.558687925338745,49.47708082199097,32849515
58 | 14,8.397745609283447,6.668765306472777,6.310772657394409,8.957489013671875,46.73683547973633,35301288
59 | 15,8.917690038681028,7.176633358001709,7.5546300411224365,8.910341739654541,49.477648973464966,37750324
60 | 16,9.176238298416136,7.644198417663574,7.773294448852539,8.591369390487671,50.09724831581116,40197372
61 | 17,9.06108808517456,9.202136516571043,9.213278055191038,9.451592922210693,54.75792002677918,42649165
62 | 18,9.47390842437744,9.417816400527952,7.360036134719849,8.879366874694824,53.35044598579407,45096715
63 | 19,9.15760898590088,9.921532869338991,10.288001775741575,9.495866537094118,56.78333520889282,47548777
64 | 20,16.184381008148193,9.112372875213623,8.641405582427979,9.695455789566038,61.253098726272576,50000000
65 | 0,6.146031141281128,0.6092917919158936,1.4211950302124023,5.874315500259399,23.75724244117737,1000900
66 | 1,6.513831615447997,1.0445129871368408,1.6927015781402588,5.854241371154785,25.26718783378601,3450816
67 | 2,6.894704580307008,1.874687910079956,2.5109894275665283,6.259218692779541,27.11310958862305,5902225
68 | 3,6.681285381317139,1.8760371208190918,2.541154146194458,6.62699556350708,28.388039350509644,8350918
69 | 4,6.579586744308473,2.4730172157287598,3.8537116050720215,6.489072561264037,30.003851413726807,10802310
70 | 5,7.6521430015563965,3.9225046634674072,5.195086479187012,9.061530828475954,40.26849174499512,13251873
71 | 6,8.345972299575807,3.4538462162017822,3.8924813270568848,9.229451656341553,40.28138494491577,15700939
72 | 7,8.912295579910277,3.58895206451416,5.256728410720825,9.334285974502565,42.195481300354,18152575
73 | 8,8.239080429077147,4.429729700088501,4.059387922286986,8.210181474685669,40.50559496879577,20603259
74 | 9,9.318784713745117,4.529031991958618,4.944356679916382,8.913965463638307,43.021955728530884,23055321
75 | 10,7.842668533325195,5.421630144119263,5.258700609207153,8.723150491714478,42.82481050491333,25502107
76 | 11,9.253255605697634,5.528379440307617,5.403992176055908,9.261732339859007,46.435377836227424,27951857
77 | 12,8.715843677520754,6.042918205261231,6.0950984954833975,10.580655574798584,48.58100295066834,30401979
78 | 13,10.208436012268066,6.864079475402832,6.1206886768341064,9.028311014175417,49.34014439582825,32852565
79 | 14,9.33082103729248,7.135979652404785,6.008606195449829,9.459828853607178,49.91624927520752,35301288
80 | 15,9.565973281860352,8.052979946136475,6.41303563117981,9.887627601623537,51.34547996520996,37751471
81 | 16,10.069446802139284,7.863379001617432,10.273854017257689,10.316685199737547,56.69612789154053,40200631
82 | 17,9.850059986114502,8.273949146270754,6.978187322616577,9.329088449478151,53.352537870407104,42651732
83 | 18,10.059386253356934,8.744035959243773,8.563298940658571,15.396541357040405,69.58081459999084,45102261
84 | 19,14.43708348274231,9.591934204101562,8.743757486343384,15.680865287780762,75.4746916294098,47551779
85 | 20,14.997735500335693,9.706062555313109,9.138372182846071,14.083873987197876,74.37188625335693,50000000
86 | 0,8.0369713306427,0.6770808696746826,1.845033884048462,8.336313486099241,30.60119819641113,998952
87 | 1,7.773632764816284,1.0916354656219482,2.2429637908935547,8.532743692398071,32.707759380340576,3451097
88 | 2,7.750965356826782,1.7475388050079346,3.01460862159729,7.785831451416016,32.291831016540534,5901157
89 | 3,9.205931425094603,2.2260079383850098,3.6819036006927486,8.288788080215454,39.349198341369636,8352896
90 | 4,8.438814640045166,2.592901468276977,3.4753494262695312,7.9989845752716064,36.04001092910767,10802364
91 | 5,7.781528472900391,3.16336727142334,4.267394781112672,8.632305383682251,39.529292345047004,13254049
92 | 6,7.7201831340789795,3.8348591327667236,4.173544406890868,8.322942733764647,37.802941083908074,15702498
93 | 7,7.653884172439575,4.125317096710205,4.42400050163269,8.175558090209961,39.1270318031311,18156231
94 | 8,9.750343084335329,4.440631151199341,4.769840717315674,8.377973794937134,42.31133937835693,20606681
95 | 9,8.555236101150513,5.129318952560425,5.573054552078247,8.976478576660156,43.04151272773743,23057180
96 | 10,9.08205509185791,5.223078727722168,5.408283948898315,8.495746850967407,43.826560258865356,25507343
97 | 11,8.577333211898804,6.8158326148986825,6.167778730392456,9.532485723495485,47.65059852600098,27957800
98 | 12,10.51621079444885,6.293447494506836,5.392074823379517,11.002318143844604,49.65189242362976,30407546
99 | 13,9.719135522842409,6.897059202194214,8.705862045288086,8.648547649383545,49.54546356201172,32858175
100 | 14,8.366340398788452,7.308839082717895,7.413874626159668,8.331423759460451,48.2678542137146,35307198
101 | 15,8.730891704559326,7.683967351913452,7.198543071746826,9.149252891540527,49.236247062683105,37756384
102 | 16,10.17529797554016,7.613230228424072,8.572442531585692,8.933055639266966,51.33104181289673,40205103
103 | 17,9.738165616989138,8.498561859130861,7.544682502746582,8.592431545257567,51.26230978965759,42652737
104 | 18,8.937159538269043,9.955829858779909,8.697373151779175,9.49044942855835,56.309436082839966,45101848
105 | 19,9.159666299819945,8.806410551071169,8.19353199005127,9.085792779922484,53.98657464981079,47550366
106 | 20,8.804518222808838,9.738995313644411,10.185393571853636,9.620959997177122,56.602333784103394,50000000
107 |
--------------------------------------------------------------------------------
/demo/requirements.txt:
--------------------------------------------------------------------------------
1 | rs_datasets
--------------------------------------------------------------------------------
/demo/synthetic_gen_time.csv:
--------------------------------------------------------------------------------
1 | n_samples,"time, s"
2 | 100000.0,38.741114616394036
3 | 595000.0,192.2492127418518
4 | 1090000.0,311.1856393814087
5 | 1585000.0,397.0943524837494
6 | 2080000.0,553.0572288036346
7 | 2575000.0,670.9549901485442
8 | 3070000.0,723.5816826820374
9 | 3565000.0,856.1593241691589
10 | 4060000.0,977.9992072582244
11 | 4555000.0,1063.0632655620577
12 | 5050000.0,1089.8609247207644
13 | 5545000.0,1322.955791473389
14 | 6040000.0,1417.3582293987274
15 | 6535000.0,1430.8968422412872
16 | 7030000.0,1486.5244340896606
17 | 7525000.0,1647.060962677002
18 | 8020000.0,1754.9148001670835
19 | 8515000.0,1840.006804227829
20 | 9010000.0,1771.7790541648865
21 | 9505000.0,1953.510559797287
22 | 10000000.0,2095.720278024673
23 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
22 | gh-deploy:
23 | @make html
24 | @ghp-import _build/html -p -o -n
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.https://www.sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | import os
14 | import sys
15 | sys.path.insert(0, os.path.abspath('../..'))
16 |
17 |
18 | # -- Project information -----------------------------------------------------
19 |
20 | project = 'Sim4Rec'
21 | copyright = '2022, Sberbank AI Laboratory'
22 | author = 'Alexey Vasilev, Anna Volodkevich, Andrey Gurov, Elizaveta Stavinova, Anton Lysenko'
23 |
24 | # The full version, including alpha/beta/rc tags
25 | release = '0.0.1'
26 |
27 |
28 | # -- General configuration ---------------------------------------------------
29 |
30 | # Add any Sphinx extension module names here, as strings. They can be
31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
32 | # ones.
33 | extensions = [
34 | 'sphinx.ext.autodoc'
35 | ]
36 |
37 | autoclass_content = 'both'
38 |
39 | # Add any paths that contain templates here, relative to this directory.
40 | templates_path = ['_templates']
41 |
42 | # List of patterns, relative to source directory, that match files and
43 | # directories to ignore when looking for source files.
44 | # This pattern also affects html_static_path and html_extra_path.
45 | exclude_patterns = []
46 |
47 |
48 | # -- Options for HTML output -------------------------------------------------
49 |
50 | # The theme to use for HTML and HTML Help pages. See the documentation for
51 | # a list of builtin themes.
52 | #
53 | html_theme = 'sphinx_rtd_theme'
54 |
55 | # Add any paths that contain custom static files (such as style sheets) here,
56 | # relative to this directory. They are copied after the builtin static files,
57 | # so a file named "default.css" will overwrite the builtin "default.css".
58 | # html_static_path = ['_static']
59 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. Simulator documentation master file, created by
2 | sphinx-quickstart on Mon Sep 19 11:23:53 2022.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to Sim4Rec's documentation!
7 | =====================================
8 |
9 | .. toctree::
10 | :maxdepth: 2
11 | :caption: Contents:
12 |
13 | pages/modules
14 | pages/response
15 | pages/utils
16 |
17 |
18 | Indices and tables
19 | ==================
20 |
21 | * :ref:`genindex`
22 | * :ref:`modindex`
23 | * :ref:`search`
24 |
--------------------------------------------------------------------------------
/docs/source/pages/modules.rst:
--------------------------------------------------------------------------------
1 | Modules
2 | =======
3 |
4 | .. automodule:: sim4rec.modules
5 |
6 |
7 | Generators
8 | __________
9 |
10 | Generators serve to produce either synthetic or real data for the simulation process.
11 | All generators are derived from the ``GeneratorBase`` base class, and to implement
12 | your own generator you must inherit from it. Basically, a generator is fitted on a provided
13 | dataset (in case of the real data generator it just remembers it), then it creates a population
14 | to sample from with a given number of rows by calling the ``generate()`` method, and finally
15 | samples from this population with the ``sample()`` method. Note that ``sample()`` takes a
16 | fraction of the population size.
17 |
18 | If you are interested in using multiple generators at once (e.g. modelling multiple groups
19 | of users or mixing results from different generating models), it is useful to look at
20 | ``CompositeGenerator``, which can handle a list of generators and has a proportion mixing
21 | parameter that controls the weights of particular generators at sampling and generating time.
22 |
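23 | A minimal usage sketch (``users_df`` is assumed to be a Spark dataframe with user features):
24 |
25 | .. code-block:: python
26 |
27 |     from sim4rec.modules import RealDataGenerator
28 |
29 |     user_gen = RealDataGenerator(label='users_real')
30 |     user_gen.fit(users_df)          # the real data generator just remembers the dataset
31 |     _ = user_gen.generate(100)      # create a population of 100 rows
32 |     sampled = user_gen.sample(0.5)  # sample 50% of the generated population
33 |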
23 | .. autoclass:: sim4rec.modules.GeneratorBase
24 | :members:
25 |
26 | .. autoclass:: sim4rec.modules.RealDataGenerator
27 | :members:
28 |
29 | .. autoclass:: sim4rec.modules.SDVDataGenerator
30 | :members:
31 |
32 | .. autoclass:: sim4rec.modules.CompositeGenerator
33 | :members:
34 |
35 |
36 | Embeddings
37 | __________
38 |
39 | Embeddings can be utilized in case of high-dimensional or highly sparse data and should be
40 | applied before running the main simulator pipeline. An autoencoder estimator and transformer
41 | are implemented here in case the existing Spark methods are not enough for your purpose. A usage
42 | example can be found in the `notebooks` directory.
43 |
44 | .. autoclass:: sim4rec.modules.EncoderEstimator
45 | :members:
46 |
47 | .. autoclass:: sim4rec.modules.EncoderTransformer
48 | :members:
49 |
50 |
51 | Items selectors
52 | _______________
53 |
54 | These Spark transformers are used to assign items to given users while making the candidate
55 | pairs for prediction by a recommendation system. Using a selector in your pipelines is optional
56 | if the recommendation algorithm can recommend any item with no restrictions. You should implement
57 | your own selector if you need custom logic for selecting items for certain users. A selector
58 | could implement some business rules (e.g. some items are not available for a user), be a simple
59 | recommendation model generating candidates, or create user-specific items (e.g. price offers)
60 | online. To implement a custom selector, you can derive from the ``ItemSelectionEstimator`` base
61 | class for any pre-calculation logic and from ``ItemSelectionTransformer`` to perform pair creation.
62 | Both classes are inherited from Spark's Estimator and Transformer classes, and to define the
63 | fit() and transform() methods you can just override ``_fit()`` and ``_transform()``.
64 |
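65 | A minimal sketch of a custom selector. It assumes that ``ItemSelectionTransformer`` follows
66 | pyspark's ``Transformer`` contract and can be constructed without extra arguments; check the
67 | class reference below before relying on it:
68 |
69 | .. code-block:: python
70 |
71 |     from pyspark.sql import DataFrame
72 |     from sim4rec.modules import ItemSelectionTransformer
73 |
74 |     class FirstKItemsSelector(ItemSelectionTransformer):
75 |         """Illustrative selector: pairs every user with the first k items."""
76 |         def __init__(self, items_df: DataFrame, k: int = 10):
77 |             super().__init__()
78 |             self._items = items_df.limit(k)
79 |
80 |         def _transform(self, users_df: DataFrame) -> DataFrame:
81 |             # Produce candidate user-item pairs for the recommender
82 |             return users_df.crossJoin(self._items)
83 |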
65 | .. autoclass:: sim4rec.modules.ItemSelectionEstimator
66 | :members:
67 |
68 | .. autoclass:: sim4rec.modules.ItemSelectionTransformer
69 | :members:
70 |
71 | .. autoclass:: sim4rec.modules.CrossJoinItemEstimator
72 | :members:
73 | :private-members: _fit
74 |
75 | .. autoclass:: sim4rec.modules.CrossJoinItemTransformer
76 | :members:
77 | :private-members: _transform
78 |
79 |
80 | Simulator
81 | _________
82 |
83 | The simulator class provides a way to handle the simulation process by connecting different
84 | parts of the module, such as generators and response pipelines, and saving the results to a
85 | given directory. The simulation process consists of the following steps:
86 |
87 | - Sampling random real or synthetic users
88 | - Creation of candidate user-item pairs for a recommendation algorithm
89 | - Prediction by a recommendation system
90 | - Evaluating responses to the given recommendations
91 | - Updating the history log
92 | - Metrics evaluation
93 | - Refitting the recommendation model with the new data
94 |
95 | Some of the steps can be skipped depending on the task you perform. For example, you don't need
96 | the second step if your algorithm doesn't use user-item pairs as an input, and you don't need to
97 | refit the model if you just want to evaluate it on some data. For more details, please refer to the examples.
98 |
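99 | A condensed sketch of the ``Simulator`` API (``user_gen``/``item_gen`` are fitted generators from
100 | the generators section; ``recs``, ``items_df`` and ``response_pipeline`` are assumed to come from
101 | your recommender and response pipeline, see the repository README for a complete example):
102 |
103 | .. code-block:: python
104 |
105 |     from sim4rec.modules import Simulator
106 |
107 |     sim = Simulator(
108 |         user_gen=user_gen, item_gen=item_gen,
109 |         data_dir='test_simulator',              # results are saved here
110 |         user_key_col='user_idx', item_key_col='item_idx',
111 |         log_df=history_df,                      # starting history log
112 |     )
113 |
114 |     users = sim.sample_users(0.1)               # sample users for one iteration
115 |     responses = sim.sample_responses(           # simulate responses on recommendations
116 |         recs_df=recs, user_features=users,
117 |         item_features=items_df, action_models=response_pipeline,
118 |     )
119 |     sim.update_log(responses, iteration=0)      # append responses to the history log
120 |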
99 | .. autoclass:: sim4rec.modules.Simulator
100 | :members:
101 |
102 |
103 | Evaluation
104 | __________
105 |
106 | .. autoclass:: sim4rec.modules.EvaluateMetrics
107 | :members:
108 | :special-members: __call__
109 |
110 | .. autofunction:: sim4rec.modules.evaluate_synthetic
111 |
112 | .. autofunction:: sim4rec.modules.ks_test
113 |
114 | .. autofunction:: sim4rec.modules.kl_divergence
115 |
--------------------------------------------------------------------------------
/docs/source/pages/response.rst:
--------------------------------------------------------------------------------
1 | Response functions
2 | ==================
3 |
4 | .. automodule:: sim4rec.response
5 |
6 |
7 | Response functions are used to model user behaviour on items and to simulate any
8 | kind of reaction a user performs. For example, it can be a binary classification
9 | model that determines whether a user clicked on an item, or the rating that a user
10 | gave to an item.
11 |
12 |
13 | Base classes
14 | ____________
15 |
16 | All of the existing response functions are built on top of the underlying base classes
17 | ``ActionModelEstimator`` and ``ActionModelTransformer``, which follow the logic of
18 | Spark's Estimator and Transformer classes. To implement a custom response function,
19 | one can inherit from the base classes:
20 |
21 | * ``ActionModelEstimator`` if any learning logic is necessary
22 | * ``ActionModelTransformer`` for performing model inference
23 |
24 | The base classes are inherited from Spark's Estimator and Transformer, and to define fit()
25 | or transform() logic one should override ``_fit()`` and ``_transform()`` respectively.
26 | Note that these base classes are useful for implementing your own response function, but they are
27 | not required: any proper Spark estimators/transformers can be used to create a response pipeline.
28 |
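29 | For example, the built-in response functions can be composed into a regular Spark pipeline
30 | (column names here are illustrative):
31 |
32 | .. code-block:: python
33 |
34 |     from pyspark.ml import PipelineModel
35 |     from sim4rec.response import NoiseResponse, BernoulliResponse
36 |
37 |     # Draw a noisy probability and then sample a binary response from it
38 |     noise = NoiseResponse(mu=0.5, sigma=0.2, outputCol='__noise')
39 |     bernoulli = BernoulliResponse(inputCol='__noise', outputCol='response')
40 |     response_pipeline = PipelineModel(stages=[noise, bernoulli])
41 |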
29 | .. autoclass:: sim4rec.response.ActionModelEstimator
30 | :members:
31 |
32 | .. autoclass:: sim4rec.response.ActionModelTransformer
33 | :members:
34 |
35 |
36 | Response functions
37 | __________________
38 |
39 | .. autoclass:: sim4rec.response.ConstantResponse
40 | :members:
41 |
42 | .. autoclass:: sim4rec.response.NoiseResponse
43 | :members:
44 |
45 | .. autoclass:: sim4rec.response.CosineSimilatiry
46 | :members:
47 |
48 | .. autoclass:: sim4rec.response.BernoulliResponse
49 | :members:
50 |
51 | .. autoclass:: sim4rec.response.ParametricResponseFunction
52 | :members:
53 |
--------------------------------------------------------------------------------
/docs/source/pages/utils.rst:
--------------------------------------------------------------------------------
1 | Utils
2 | =====
3 |
4 | .. automodule:: sim4rec.utils
5 |
6 |
7 | Dataframe conversion
8 | ______________________
9 |
10 | .. autofunction:: sim4rec.utils.pandas_to_spark
11 |
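12 | A minimal usage sketch (an active Spark session is assumed):
13 |
14 | .. code-block:: python
15 |
16 |     import pandas as pd
17 |     from sim4rec.utils import pandas_to_spark
18 |
19 |     users_df = pd.DataFrame({'user_idx': [0, 1], 'user_attr_0': [0.3, 0.7]})
20 |     users_sdf = pandas_to_spark(users_df)  # returns a Spark dataframe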
12 |
13 | Session
14 | ______________________
15 |
16 | .. autofunction:: sim4rec.utils.session_handler.get_spark_session
17 |
18 | .. autoclass:: sim4rec.utils.session_handler.State
19 | :members:
20 |
21 | Exceptions
22 | __________
23 |
24 | .. autoclass:: sim4rec.utils.NotFittedError
25 | :members:
26 |
27 | .. autoclass:: sim4rec.utils.EmptyDataFrameError
28 | :members:
29 |
30 |
31 | Transformers
32 | ____________
33 |
34 | .. autoclass:: sim4rec.utils.VectorElementExtractor
35 | :members:
36 |
37 |
38 | File management
39 | _______________
40 |
41 | .. autofunction:: sim4rec.utils.save
42 | .. autofunction:: sim4rec.utils.load
43 |
--------------------------------------------------------------------------------
/experiments/Amazon/RU_train_test_split.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "6f3723f0-15da-4e41-baff-d1fac30b6da1",
6 | "metadata": {},
7 | "source": [
8 | "# Иллюстрация разбиения датасета Amazon на train/test/split\n",
9 | "
\n",
10 | " Предобработки признаков для датасета Amazon, используемая здесь вынесена в \n",
11 | " файл,\n",
12 | " иллюстрация её работы продемонстрирована в ноутбуке.
\n",
14 | "
"
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "id": "8854aa20-25e6-4468-935b-46fbb1dfb0a3",
20 | "metadata": {},
21 | "source": [
22 | "### $\\textbf{Содержание}$:\n",
23 | "\n",
24 | "### $\\textbf{I. Загрузка данных }$\n",
25 | "#### - Чтение данных с диска;\n",
26 | "---"
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "id": "d9d4aafa-b8c5-4e7d-b8c0-3a79ea3fae26",
32 | "metadata": {},
33 | "source": [
34 | "### $\\textbf{II. Разбиение данных для эксперимента}$\n",
35 | "### Для разбиения данных на $\\it{train/test/split}$ производится деление исходного датасета *df_rating* по квантилям атрибута $\\it{timestamp}$, $\\mathbb{q}$ для генерации признаков:\n",
36 | "#### $\\it{rating}_{t}$ = *df_rating*$[0, \\mathbb{q}_{t}]$, где $\\mathbb{q}_{train}=0.5$, $\\mathbb{q}_{val}=0.75$, $\\mathbb{q}_{test}=1$:\n",
37 | "#### - $\\it{rating}_{train}$ = *df_rating*$[0, 0.5]$;\n",
38 | "#### - $\\it{rating}_{val}$ = *df_rating*$[0, 0.75]$;\n",
39 | "#### - $\\it{rating}_{test}$ = *df_rating*$[0, 1]$;\n",
40 | "### Далее для каждого из промежутков {$\\it{rating}_{train}$, $\\it{rating}_{val}$, $\\it{rating}_{test}$} генерируются соответствующие им признаки пользователей и предложений:\n",
41 | "#### - $\\it{items}_{t}$, $\\it{users}_{t}$, $\\it{rating}_{t}$ = data_processing(movies, $\\it{rating}_{t}$, tags), $t \\in \\{\\it{train}, \\it{val}, \\it{test}\\}$;\n",
42 | "### После чего формируются окончательные рейтинги:\n",
43 | "#### - $\\it{rating}_{train}$ = $\\it{rating}_{train}$ = *df_rating*$[0, 0.5]$;\n",
44 | "#### - $\\it{rating}_{val}$ = $\\it{rating}_{val}$[$\\mathbb{q}>\\mathbb{q}_{train}$] = *df_rating*$(0.5, 0.75]$;\n",
45 | "#### - $\\it{rating}_{test}$ = $\\it{rating}_{test}$[$\\mathbb{q}>\\mathbb{q}_{val}$] = *df_rating*$(0.75, 1]$;\n",
46 | "\n",
47 | "\n",
48 | " То есть, если для генерации признаков для валидационного набора данных мы используем временные метки с 0 по 0.75 квантиль, то в качестве рейтингов мы возьмем оценки\n",
49 | " только с 0.5 по 0.75 квантили. Аналогично для тестового набора: все временные метки для генерации признаков, но в качестве рейтингов только оценки с 0.75 по 1\n",
50 | " квантили.
\n",
51 | "
\n",
52 | "
"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": null,
58 | "id": "0a23f90f-056c-44c3-b18c-13e8064a4bac",
59 | "metadata": {},
60 | "outputs": [
61 | {
62 | "name": "stderr",
63 | "output_type": "stream",
64 | "text": [
65 | "[nltk_data] Downloading package punkt to\n",
66 | "[nltk_data] /data/home/agurov/nltk_data...\n",
67 | "[nltk_data] Package punkt is already up-to-date!\n",
68 | "[nltk_data] Downloading package stopwords to\n",
69 | "[nltk_data] /data/home/agurov/nltk_data...\n",
70 | "[nltk_data] Package stopwords is already up-to-date!\n",
71 | "[nltk_data] Downloading package words to\n",
72 | "[nltk_data] /data/home/agurov/nltk_data...\n",
73 | "[nltk_data] Package words is already up-to-date!\n",
74 | "/data/home/agurov/.conda/envs/sber3.8/lib/python3.8/site-packages/pyspark/sql/pandas/functions.py:394: UserWarning: In Python 3.6+ and Spark 3.0+, it is preferred to specify type hints for pandas UDF instead of specifying pandas UDF type which will be deprecated in the future releases. See SPARK-28264 for more details.\n",
75 | " warnings.warn(\n"
76 | ]
77 | }
78 | ],
79 | "source": [
80 | "import pandas as pd\n",
81 | "import numpy as np\n",
82 | "import re\n",
83 | "import itertools\n",
84 | "import tqdm\n",
85 | "\n",
86 | "from amazon import data_processing, get_spark_session"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": null,
92 | "id": "91a4cb15-e49f-4ae8-a569-14703cf9f7a1",
93 | "metadata": {},
94 | "outputs": [],
95 | "source": [
96 | "file_name_parquet = \"all.parquet\"\n",
97 | "save_path = \"hdfs://namenode:9000/Sber_data/Amazon/final_data\"\n",
98 | "\n",
99 | "spark = get_spark_session(1)"
100 | ]
101 | },
102 | {
103 | "cell_type": "markdown",
104 | "id": "4625524c-b878-4dc8-84bc-b41d7ffddd1c",
105 | "metadata": {},
106 | "source": [
107 | "I. Загрузка данных"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": null,
113 | "id": "6aec812a-a519-409e-9258-5d331cc9633c",
114 | "metadata": {},
115 | "outputs": [],
116 | "source": [
117 | "df = pd.read_parquet(file_name_parquet)\n",
118 | "df_sp = spark.createDataFrame(df)"
119 | ]
120 | },
121 | {
122 | "cell_type": "markdown",
123 | "id": "d09161ae-21f5-4a78-8c42-376c20009a4a",
124 | "metadata": {},
125 | "source": [
126 | "### II. Разбиение данных для эксперимента"
127 | ]
128 | },
129 | {
130 | "cell_type": "markdown",
131 | "id": "648d7c9e-b0b4-4bdc-9878-a0e46b5832bd",
132 | "metadata": {},
133 | "source": [
134 | "### Разбиение df_rating на train/test/validation части по квантилям timestamp:\n",
135 | "#### - train [0, 0.5]\n",
136 | "#### - validation [0, 0.75]\n",
137 | "#### - test [0, 1.]"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": 3,
143 | "id": "5dbd2348-4188-4a68-8fe6-b0d463703e0b",
144 | "metadata": {},
145 | "outputs": [
146 | {
147 | "name": "stdout",
148 | "output_type": "stream",
149 | "text": [
150 | "Quantile: 0.5 - 1189555200.0, 0.75 - 1301443200.0\n"
151 | ]
152 | }
153 | ],
154 | "source": [
155 | "q_50, q_75 = df_sp.approxQuantile(\"timestamp\", [0.5, 0.75], 0)\n",
156 | "print(f\"Quantile: 0.5 - {q_50}, 0.75 - {q_75}\")"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": null,
162 | "id": "fc1aa2c3-29fe-4278-9b33-08014ede15da",
163 | "metadata": {},
164 | "outputs": [],
165 | "source": [
166 | "df_train = df_procc.filter(sf.col(\"timestamp\") <= q_50)\n",
167 | "df_train.cache()"
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": null,
173 | "id": "41174c08-956e-47a6-881c-1164a6f4f5e0",
174 | "metadata": {},
175 | "outputs": [],
176 | "source": [
177 | "df_val = df_procc.filter(sf.col(\"timestamp\") <= q_75)\n",
178 | "df_val.cache()"
179 | ]
180 | },
181 | {
182 | "cell_type": "code",
183 | "execution_count": null,
184 | "id": "0fd08ad6-4d4f-489d-b530-a2a4ce26fc26",
185 | "metadata": {},
186 | "outputs": [],
187 | "source": [
188 | "df_test = df_procc\n",
189 | "df_test.cache()"
190 | ]
191 | },
192 | {
193 | "cell_type": "markdown",
194 | "id": "03ef300d-d957-443c-b8f9-413354e574ea",
195 | "metadata": {},
196 | "source": [
197 | "#### Генерация признаков по временным промежуткам"
198 | ]
199 | },
200 | {
201 | "cell_type": "markdown",
202 | "id": "427898bd-811b-4c96-b85a-0003929bfc69",
203 | "metadata": {},
204 | "source": [
205 | "#### Train data"
206 | ]
207 | },
208 | {
209 | "cell_type": "code",
210 | "execution_count": null,
211 | "id": "5581f3cd-2ee3-4dae-bcba-bc227f319463",
212 | "metadata": {},
213 | "outputs": [],
214 | "source": [
215 | "df_users_train, df_items_train, df_rating_train = data_processing(df_train)\n",
216 | "\n",
217 | "df_items_train.write.parquet(os.path.join(save_path, r'train/items.parquet'))\n",
218 | "df_users_train.write.parquet(os.path.join(save_path, r'train/users.parquet'))\n",
219 | "df_rating_train.write.parquet(os.path.join(save_path, r'train/rating.parquet'))\n",
220 | "\n",
221 | "df_train.unpersist()"
222 | ]
223 | },
224 | {
225 | "cell_type": "markdown",
226 | "id": "72abd8d2-eec1-4aab-9de5-51da78fe1530",
227 | "metadata": {},
228 | "source": [
229 | "#### Validation data"
230 | ]
231 | },
232 | {
233 | "cell_type": "code",
234 | "execution_count": null,
235 | "id": "0df22394-6d7f-4499-bedd-2d35399de706",
236 | "metadata": {},
237 | "outputs": [],
238 | "source": [
239 | "df_users_val, df_items_val, df_rating_val = data_processing(df_val)\n",
240 | "\n",
241 | "df_rating_val = df_rating_val.filter(sf.col(\"timestamp\") > q_50)\n",
242 | "df_items_val.write.parquet(os.path.join(save_path, r'val/items.parquet'))\n",
243 | "df_users_val.write.parquet(os.path.join(save_path, r'val/users.parquet'))\n",
244 | "df_rating_val.write.parquet(os.path.join(save_path, r'val/rating.parquet'))\n",
245 | "\n",
246 | "df_val.unpersist()"
247 | ]
248 | },
249 | {
250 | "cell_type": "markdown",
251 | "id": "31ebc6f5-4f65-4b99-a261-dbf02ea7d132",
252 | "metadata": {},
253 | "source": [
254 | "#### Test data"
255 | ]
256 | },
257 | {
258 | "cell_type": "code",
259 | "execution_count": null,
260 | "id": "4fa40c1f-c8ce-48c2-9d79-e58abc78752c",
261 | "metadata": {},
262 | "outputs": [],
263 | "source": [
264 | "df_users_test, df_items_test, df_rating_test = data_processing(df_test)\n",
265 | "\n",
266 | "df_rating_test = df_rating_test.filter(sf.col(\"timestamp\") > q_75)\n",
267 | "df_items_test.write.parquet(os.path.join(save_path, r'val/items.parquet'))\n",
268 | "df_users_test.write.parquet(os.path.join(save_path, r'val/users.parquet'))\n",
269 | "df_rating_test.write.parquet(os.path.join(save_path, r'val/rating.parquet'))\n",
270 | "\n",
271 | "df_test.unpersist()"
272 | ]
273 | }
274 | ],
275 | "metadata": {
276 | "kernelspec": {
277 | "display_name": "Python [conda env:.conda-sber3.8]",
278 | "language": "python",
279 | "name": "conda-env-.conda-sber3.8-py"
280 | },
281 | "language_info": {
282 | "codemirror_mode": {
283 | "name": "ipython",
284 | "version": 3
285 | },
286 | "file_extension": ".py",
287 | "mimetype": "text/x-python",
288 | "name": "python",
289 | "nbconvert_exporter": "python",
290 | "pygments_lexer": "ipython3",
291 | "version": "3.8.13"
292 | }
293 | },
294 | "nbformat": 4,
295 | "nbformat_minor": 5
296 | }
297 |
--------------------------------------------------------------------------------
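The quantile-based split described in the notebook above can be summarised in a minimal standalone sketch; the toy dataframe and the local session settings below are illustrative only and are not part of the experiment files:

```python
# Minimal sketch of the quantile-based train/val/test split described in the
# notebook above. The toy data and the local Spark session are illustrative only.
from pyspark.sql import SparkSession, functions as sf

spark = SparkSession.builder.master("local[2]").appName("split-sketch").getOrCreate()
df = spark.createDataFrame(
    [(1, 10, 100.0), (1, 11, 200.0), (2, 10, 300.0), (2, 12, 400.0)],
    ["user_idx", "item_idx", "timestamp"],
)

q_50, q_75 = df.approxQuantile("timestamp", [0.5, 0.75], 0)

df_train = df.filter(sf.col("timestamp") <= q_50)           # features and ratings: [0, 0.5]
df_val = df.filter(sf.col("timestamp") <= q_75)             # features: [0, 0.75]
df_val_rating = df_val.filter(sf.col("timestamp") > q_50)   # ratings: (0.5, 0.75]
df_test = df                                                # features: [0, 1]
df_test_rating = df.filter(sf.col("timestamp") > q_75)      # ratings: (0.75, 1]
```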
/experiments/Amazon/amazon.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_VISIBLE_DEVICES"]=""
3 |
4 | import pandas as pd
5 | import numpy as np
6 | import requests
7 | import gzip
8 | import json
9 | import re
10 |
11 | import matplotlib.pyplot as plt
12 | from tqdm.notebook import tqdm
13 |
14 | from typing import List, Optional
15 |
16 | from bs4 import BeautifulSoup
17 | from nltk.tokenize import TreebankWordTokenizer, WhitespaceTokenizer
18 |
19 | import nltk
20 | nltk.download('punkt')
21 | nltk.download('stopwords')
22 | nltk.download('words')
23 | words = set(nltk.corpus.words.words())
24 | words = set([w.lower() for w in words])
25 | from nltk.stem import PorterStemmer, WordNetLemmatizer
26 | from nltk.corpus import stopwords
27 | stop_words = set(stopwords.words("english"))
28 | from nltk.tokenize import sent_tokenize
29 |
30 | import gensim
31 | from gensim.downloader import load
32 | from gensim.models import Word2Vec
33 | w2v_model = gensim.downloader.load('word2vec-google-news-300')
34 |
35 | import pyspark
36 | from pyspark.sql.types import *
37 | from pyspark import SparkConf
38 | from pyspark.sql import SparkSession
39 | from pyspark.sql import functions as sf
40 | from pyspark.ml.feature import Tokenizer, RegexTokenizer, StopWordsRemover
41 | from pyspark.ml import Pipeline
42 | from pyspark.sql.functions import expr
43 |
44 |
45 | def get_spark_session(
46 | mode
47 | ) -> SparkSession:
48 | """
49 | The function creates spark session
50 | :param mode: session mode
51 | :type mode: int
52 | :return: SparkSession
53 | :rtype: SparkSession
54 | """
55 | if mode == 1:
56 |
57 | SPARK_MASTER_URL = 'spark://spark:7077'
58 | SPARK_DRIVER_HOST = 'jupyterhub'
59 |
60 | conf = SparkConf().setAll([
61 | ('spark.master', SPARK_MASTER_URL),
62 | ('spark.driver.bindAddress', '0.0.0.0'),
63 | ('spark.driver.host', SPARK_DRIVER_HOST),
64 | ('spark.driver.blockManager.port', '12346'),
65 | ('spark.driver.port', '12345'),
66 | ('spark.driver.memory', '8g'), #4
67 | ('spark.driver.memoryOverhead', '2g'),
68 | ('spark.executor.memory', '14g'), #14
69 | ('spark.executor.memoryOverhead', '2g'),
70 | ('spark.app.name', 'simulator'),
71 | ('spark.submit.deployMode', 'client'),
72 | ('spark.ui.showConsoleProgress', 'true'),
73 | ('spark.eventLog.enabled', 'false'),
74 | ('spark.logConf', 'false'),
75 | ('spark.network.timeout', '10000000'),
76 | ('spark.executor.heartbeatInterval', '10000000'),
77 | ('spark.sql.shuffle.partitions', '4'),
78 | ('spark.default.parallelism', '4'),
79 | ("spark.kryoserializer.buffer","1024"),
80 | ('spark.sql.execution.arrow.pyspark.enabled', 'true'),
81 | ('spark.rpc.message.maxSize', '1000'),
82 | ("spark.driver.maxResultSize", "2g")
83 | ])
84 | spark = SparkSession.builder\
85 | .config(conf=conf)\
86 | .getOrCreate()
87 |
88 | elif mode == 0:
89 | spark = SparkSession.builder\
90 | .appName('simulator')\
91 | .master('local[4]')\
92 | .config('spark.sql.shuffle.partitions', '4')\
93 | .config('spark.default.parallelism', '4')\
94 | .config('spark.driver.extraJavaOptions', '-XX:+UseG1GC')\
95 | .config('spark.executor.extraJavaOptions', '-XX:+UseG1GC')\
96 | .config('spark.sql.autoBroadcastJoinThreshold', '-1')\
97 | .config('spark.sql.execution.arrow.pyspark.enabled', 'true')\
98 | .getOrCreate()
99 |
100 | return spark
101 |
102 | def clean_text(text: str) -> str:
103 |
104 | """
105 | Cleaning and preprocessing of the tags text with help of regular expressions.
106 | :param text: initial text
107 | :type text: str
108 | :return: cleaned text
109 | :rtype: str
110 | """
111 |
112 | text = re.sub("[^a-zA-Z]", " ",text)
113 | text = re.sub(r"\s+", " ", text)
114 | text = re.sub(r"\s+$", "", text)
115 | text = re.sub(r"^\s+", "", text)
116 | text = text.lower()
117 |
118 | return text
119 |
120 |
121 | def string_embedding(arr: list) -> np.ndarray:
122 | """
123 |     Processes each word of the input with word2vec and returns the mean of the word vectors.
124 |
125 | :param arr: words
126 |     :type arr: List[str]
127 | :return: average vector of word2vec words representations
128 | :rtype: np.ndarray
129 | """
130 |
131 | vec = 0
132 | cnt = 0
133 | for i in arr:
134 | try:
135 | vec += w2v_model[i]
136 | cnt += 1
137 | except:
138 | pass
139 | if cnt == 0:
140 | vec = np.zeros((300,))
141 | else:
142 | vec /= cnt
143 | return vec
144 |
145 | @sf.pandas_udf(StringType(), sf.PandasUDFType.SCALAR)
146 | def clean_udf(str_series):
147 | """
148 | pandas udf of the clean_text function
149 | """
150 | result = []
151 | for x in str_series:
152 | x_procc = clean_text(x)
153 | result.append(x_procc)
154 | return pd.Series(result)
155 |
156 | @sf.pandas_udf(ArrayType(DoubleType()), sf.PandasUDFType.SCALAR)
157 | def embedding_udf(str_series):
158 | """
159 | pandas udf of the string_embedding function
160 | """
161 | result = []
162 | for x in str_series:
163 | x_procc = string_embedding(x)
164 | result.append(x_procc)
165 | return pd.Series(result)
166 |
167 |
168 | def data_processing(df_sp: pyspark.sql.DataFrame):
169 |     """
170 |     Builds user, item and rating tables from the raw Amazon review dataframe:
171 |     review texts are cleaned, tokenized and embedded with word2vec, and the
172 |     embeddings are aggregated (mean) per item and per user.
173 |     """
174 |     # Tokenization stages; their column names are inferred from how they are used below.
175 |     tokenizer = Tokenizer(inputCol="review_clean", outputCol="tokens")
176 |     remover = StopWordsRemover(inputCol="tokens", outputCol="tokens_clean")
177 | 
178 |     df_procc = df_sp.withColumnRenamed("item_id", "item_idx")\
179 |         .withColumnRenamed("user_id", "user_idx")\
180 |         .withColumn("review_clean", clean_udf(sf.col("review"))).drop("review", "__index_level_0__")
181 |     df_procc = tokenizer.transform(df_procc).drop("review_clean")
182 |     df_procc = remover.transform(df_procc).drop("tokens")
183 |     df_procc = df_procc.withColumn("embedding", embedding_udf(sf.col("tokens_clean"))).drop("tokens_clean")
184 | 
185 |     df_items = df_procc.groupby("item_idx").agg(sf.array(*[sf.mean(sf.col("embedding")[i]) for i in range(300)]).alias("embedding"),
186 |                                                 sf.mean("helpfulness").alias("helpfulness"),
187 |                                                 sf.mean("score").alias("rating_avg"),
188 |                                                 sf.count("score").alias("rating_cnt"))
189 | 
190 |     df_users = df_procc.groupby("user_idx").agg(sf.array(*[sf.mean(sf.col("embedding")[i]) for i in range(300)]).alias("embedding"),
191 |                                                 sf.mean("helpfulness").alias("helpfulness"),
192 |                                                 sf.mean("score").alias("rating_avg"),
193 |                                                 sf.count("score").alias("rating_cnt"))
194 | 
195 |     df_rating = df_procc.groupby("user_idx", "item_idx", "timestamp").agg(sf.mean("score").alias("relevance"), sf.count("score").alias("rating_cnt"))
196 | 
197 |     df_items = df_items.select(['item_idx', 'helpfulness', 'rating_avg', 'rating_cnt']+[expr('embedding[' + str(x) + ']') for x in range(0, 300)])
198 |     new_colnames = ['item_idx', 'helpfulness', 'rating_avg', 'rating_cnt'] + ['w2v_' + str(i) for i in range(0, 300)]
199 |     df_items = df_items.toDF(*new_colnames)
200 | 
201 |     df_users = df_users.select(['user_idx', 'helpfulness', 'rating_avg', 'rating_cnt']+[expr('embedding[' + str(x) + ']') for x in range(0, 300)])
202 |     new_colnames = ['user_idx', 'helpfulness', 'rating_avg', 'rating_cnt'] + ['w2v_' + str(i) for i in range(0, 300)]
203 |     df_users = df_users.toDF(*new_colnames)
204 | 
205 |     return [df_users, df_items, df_rating]
206 | 
--------------------------------------------------------------------------------
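A hedged usage sketch for the helpers in amazon.py: the parquet path and the column layout of the raw review table (item_id, user_id, review, helpfulness, score, timestamp) are assumptions inferred from how data_processing uses them, and importing the module downloads the NLTK corpora and the word2vec-google-news-300 model:

```python
# Hedged usage sketch: a local session (mode 0) plus the full processing pipeline.
# The parquet path and the raw column layout are assumptions, not repo artefacts.
import pandas as pd

from amazon import get_spark_session, data_processing

spark = get_spark_session(0)                # mode 0 -> local[4] session
raw_pdf = pd.read_parquet("all.parquet")    # item_id, user_id, review, helpfulness, score, timestamp
raw_sdf = spark.createDataFrame(raw_pdf)

df_users, df_items, df_rating = data_processing(raw_sdf)
df_users.show(5)
```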
/experiments/Movielens/ml20.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import re
4 | import itertools
5 | import tqdm
6 |
7 | import seaborn as sns
8 | import tqdm
9 | import matplotlib.pyplot as plt
10 |
11 | from bs4 import BeautifulSoup
12 | from nltk.tokenize import TreebankWordTokenizer, WhitespaceTokenizer
13 |
14 | import nltk
15 | nltk.download('stopwords')
16 | nltk.download('words')
17 | words = set(nltk.corpus.words.words())
18 | words = set([w.lower() for w in words])
19 |
20 | from nltk.stem import PorterStemmer, WordNetLemmatizer
21 | nltk.download("wordnet")
22 |
23 | from nltk.corpus import stopwords
24 | stop_words = set(stopwords.words("english"))
25 |
26 |
27 | from nltk.tokenize import sent_tokenize
28 |
29 | import gensim
30 | from gensim.downloader import load
31 | from gensim.models import Word2Vec
32 | w2v_model = gensim.downloader.load('word2vec-google-news-300')
33 |
34 | from typing import Dict, List, Optional, Tuple
35 |
36 |
37 | def title_prep(title: str) -> str:
38 |
39 | """
 40 |     Cleans the movie title: collapses repeated whitespace, strips leading and
 41 |     trailing spaces and converts the title to lowercase.
42 |
43 | :param title: the title of the movie
44 | :type title: str
45 | :return: cleaned title of the movie
46 | :rtype: str
47 | """
48 |
49 | title = re.sub(r'\s+', r' ', title)
 50 |     title = re.sub(r'(\s+$|^\s+)', '', title)
51 | title = title.lower()
52 |
53 | return title
54 |
 55 | def extract_year(title: str) -> Optional[int]:
56 |
57 | """
58 | Extracting year from the movie title
59 |
60 | :param title: the cleaned title of the movie
61 | :type title: str
62 | :return: movie year
63 | :rtype: float, optional
64 | """
65 |
66 | one_year = re.findall(r'\(\d{4}\)', title)
67 | two_years = re.findall(r'\(\d{4}-\d{4}\)', title)
68 | one_year_till_today = re.findall(r'\(\d{4}[-–]\s?\)', title)
69 | if len(one_year) == 1:
70 | return int(one_year[0][1:-1])
71 |
72 | elif len(two_years) == 1:
73 | return round((int(two_years[0][1:5]) + int(two_years[0][6:-1]))/2)
74 |
75 | elif len(one_year_till_today) == 1:
76 | return int(one_year_till_today[0][1:5])
77 | else:
78 | return np.nan
79 |
80 | def genres_processing(movies: pd.DataFrame) -> pd.DataFrame:
81 |
82 | """
83 | Processing movie genres by constructing a binary vector of length n, where n is the number of all possible genres.
84 | For example a string like 'genre1|genre3|...' will be transformed into a vector [0,1,0,1,...].
85 |
86 | :param movies: DataFrame with column 'genres'
 87 |     :type movies: pd.DataFrame
88 | :return: DataFrame with processed genres
89 | :rtype: pd.DataFrame
90 | """
91 |
92 | genre_lists = [set(item.split('|')).difference(set(['(no genres listed)'])) for item in movies['genres']]
93 | genre_lists = pd.DataFrame(genre_lists)
94 |
95 | genre_dict = {token: idx for idx, token in enumerate(set(itertools.chain.from_iterable([item.split('|')
96 | for item in movies['genres']])).difference(set(['(no genres listed)'])))}
97 | genre_dict = pd.DataFrame(genre_dict.items())
98 | genre_dict.columns = ['genre', 'index']
99 |
100 | dummy = np.zeros([len(movies), len(genre_dict)])
101 |
102 | for i in range(dummy.shape[0]):
103 | for j in range(dummy.shape[1]):
104 | if genre_dict['genre'][j] in list(genre_lists.iloc[i, :]):
105 | dummy[i, j] = 1
106 |
107 | df_dummy = pd.DataFrame(dummy, columns = ['genre' + str(i) for i in range(dummy.shape[1])])
108 |
109 |     movies_return = pd.concat([movies, df_dummy], axis=1)
110 | return movies_return
111 |
112 | def fill_null_years(movies: pd.DataFrame) -> pd.DataFrame:
113 |
114 | """
115 | Processing null years
116 |
117 | :param movies: DataFrame with processed years
118 |     :type movies: pd.DataFrame
119 | :return: DataFrame with processed not null years
120 | :rtype: pd.DataFrame
121 | """
122 |
123 | df_movies = movies.copy()
124 | genres_columns = [item for item in movies.columns.tolist() if item[:5]=='genre' and item !='genres']
125 | df_no_year = movies[movies.year.isna()][['movieId', *genres_columns]]
126 |
127 | years_mean = {}
128 | for i in df_no_year.index:
129 |
130 | row = np.asarray(df_no_year.loc[i, :][genres_columns])
131 | years = []
132 | for j in np.asarray(movies[['year', *genres_columns]]):
133 | if np.sum(row == j[1:]) == len(genres_columns):
134 | try:
135 | years.append(int(j[0]))
136 | except:
137 | pass
138 |
139 | years_mean[i] = round(np.mean(years))
140 |
141 | for i in years_mean:
142 | df_movies.loc[i, 'year'] = years_mean[i]
143 |
144 | df_movies.year=df_movies.year.astype('int')
145 |
146 | return df_movies
147 |
148 | def clean_text(text: str) -> str:
149 | """
150 | Cleaning text: remove extra spaces and non-text characters
151 |
152 | :param text: tag row text
153 |     :type text: str
154 | :return: tag cleaned text
155 | :rtype: str
156 | """
157 |
158 | text = re.sub("[^a-zA-Z]", " ",text)
159 | text = re.sub(r"\s+", " ", text)
160 | text = re.sub(r"\s+$", "", text)
161 | text = re.sub(r"^\s+", "", text)
162 | text = text.lower()
163 |
164 | return text
165 |
166 |
167 | def procces_text(text):
168 | """
169 | Processing text: lemmatization, tokenization, removing stop-words
170 |
171 | :param text: tag cleaned text
172 |     :type text: str
173 | :return: tag processed text
174 | :rtype: str
175 | """
176 | lemmatizer = WordNetLemmatizer()
177 |
178 | text = [word for word in nltk.word_tokenize(text) if not word in stop_words]
179 | text = [lemmatizer.lemmatize(token) for token in text]
180 | text = [word for word in text if word in words]
181 |
182 | text = " ".join(text)
183 |
184 | return text
185 |
186 | def string_embedding(string: str) -> np.ndarray:
187 | """
188 |     Embeds each word of the string with word2vec and returns the mean of the word vectors.
189 |
190 | :param string: cleaned and processed tags
191 |     :type string: str
192 | :return: average vector of the string words embeddings
193 | :rtype: np.ndarray, optional
194 | """
195 |
196 | arr = string.split(' ')
197 | vec = 0
198 | cnt = 0
199 | for i in arr:
200 | try:
201 | vec += w2v_model[i]
202 | cnt += 1
203 | except:
204 | pass
205 | if cnt == 0:
206 |         vec = np.zeros((300,))
207 | else:
208 | vec /= cnt
209 | return vec
210 |
211 |
212 | def data_processing(df_movie: pd.DataFrame,
213 | df_rating: pd.DataFrame,
214 | df_tags: pd.DataFrame
215 | ) -> List[pd.DataFrame]:
216 |
217 | print("------------------------ Movie processing ------------------------")
218 | #Extraction of the movies' years and transform genres lists to genres vector
219 | df_movies_procc = df_movie.copy()
220 | df_movies_procc.title = df_movies_procc.title.apply(title_prep) #title processing
221 | df_movies_procc['year'] = df_movies_procc.title.apply(extract_year) #year processing
222 | df_movies_procc = genres_processing(df_movies_procc) #genres processing
223 |     df_movies_procc = fill_null_years(df_movies_procc) #filling null year values
224 |
225 | #Creating rating_avg column
226 | print("------------------------ Rating processing ------------------------")
227 | df_movies_procc = pd.merge(df_movies_procc, df_rating.groupby('movieId', as_index=False).rating.mean(), on='movieId', how='left')
228 | df_movies_procc.rating = df_movies_procc.rating.fillna(0.0)
229 | df_movies_procc = df_movies_procc.rename(columns={'rating' : 'rating_avg'})
230 | df_movies_clean = df_movies_procc.drop(['title', 'genres'], axis=1)[['movieId', 'year', 'rating_avg', *['genre' + str(i) for i in range(19)]]]
231 |
232 | print("------------------------ Tags processing ------------------------")
233 | df_tags_ = df_tags.drop(df_tags[df_tags.tag.isna()].index)
234 | df_movie_tags = df_tags_.sort_values(by=['movieId', 'timestamp'])[['movieId', 'tag', 'timestamp']]
235 | df_movie_tags['clean_tag'] = df_movie_tags.tag.apply(lambda x : procces_text(clean_text(x)))
236 | df_movie_tags = df_movie_tags[df_movie_tags.clean_tag.str.len()!=0]
237 |
238 | print("------------------------ Tags embedding ------------------------")
239 | #tags text gathering
240 | docs_movie_tags = df_movie_tags.sort_values(["movieId", "timestamp"]).groupby("movieId", as_index=False).agg({"clean_tag":lambda x: " ".join(x)})
241 | df_movies_tags = pd.concat([docs_movie_tags.movieId, pd.DataFrame(docs_movie_tags.clean_tag.apply(string_embedding).to_list(), columns = ['w2v_' + str(i) for i in range(300)])], axis = 1)
242 | df_movies_clean = pd.merge(df_movies_clean, df_movies_tags, on = "movieId", how = "left").fillna(0.0)
243 |
244 | print("------------------------ Users processing ------------------------")
245 | #users procc
246 | df_users = df_rating.copy()
247 | df_users = df_users.groupby(by=['userId'], as_index=False).rating.mean().rename(columns = {'rating' : 'rating_avg'})
248 | df_users_genres = pd.merge(df_movies_clean[['movieId', *df_movies_clean.columns[3:22]]], pd.merge(df_rating, df_users, on = 'userId')[['userId', 'movieId']],
249 | on = 'movieId')
250 |
251 | df_users_genres = df_users_genres.groupby(by = ['userId'], as_index = False)[df_movies_clean.columns[3:22]].mean()
252 | df_users_genres = pd.merge(df_users_genres, df_users, on = 'userId')
253 | df_pairs = pd.merge(df_rating, df_users, on = 'userId')[['userId', 'movieId']]
254 |
255 | print("------------------------ Users embedding ------------------------")
256 | users_id = []
257 | vect_space = []
258 | for Id in tqdm.tqdm(df_pairs.userId.unique()):
259 | movie_list = df_pairs[df_pairs.userId == Id].movieId.tolist()
260 | vect = np.asarray(df_movies_clean[df_movies_clean.movieId.isin(movie_list)][[*df_movies_clean.columns[22:]]].mean().tolist())
261 | users_id.append(Id)
262 | vect_space.append(vect)
263 |
264 | df_users_w2v = pd.DataFrame(vect_space, columns = ['w2v_' + str(i) for i in range(len(df_movies_clean.columns[22:]))])
265 | df_users_w2v['userId'] = users_id
266 | df_users_clean = pd.merge(df_users_genres, df_users_w2v, on = 'userId')
267 | df_rating_clean = df_rating[['userId', 'movieId', 'rating', 'timestamp']]
268 |
269 | """
270 | cat_dict_movies = pd.Series(df_movies_clean.movieId.astype("category").cat.codes.values, index=df_movies_clean.movieId).to_dict()
271 | cat_dict_users = pd.Series(df_users_clean.userId.astype("category").cat.codes.values, index=df_users_clean.userId).to_dict()
272 |
273 | df_movies_clean.movieId = df_movies_clean.movieId.apply(lambda x: cat_dict_movies[x])
274 | df_users_clean.userId = df_users_clean.userId.apply(lambda x: cat_dict_users[x])
275 | df_rating_clean.movieId = df_rating.movieId.apply(lambda x: cat_dict_movies[x])
276 | df_rating_clean.userId = df_rating.userId.apply(lambda x: cat_dict_users[x])
277 | """
278 |
279 | df_movies_clean = df_movies_clean.rename(columns={'movieId': 'item_idx'})
280 | df_users_clean = df_users_clean.rename(columns={'userId': 'user_idx'})
281 | df_rating_clean = df_rating_clean.rename(columns={'movieId': 'item_idx', 'userId': 'user_idx', 'rating': 'relevance'})
282 |
283 | return [df_movies_clean, df_users_clean, df_rating_clean]
--------------------------------------------------------------------------------
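A tiny worked example of genres_processing from ml20.py; the movie rows are made up, and note that importing ml20 also pulls the word2vec-google-news-300 model:

```python
# Toy illustration of genres_processing: the pipe-separated "genres" string is
# expanded into binary genre0..genreN indicator columns. Data is made up.
import pandas as pd

from ml20 import genres_processing   # heavy import: downloads the word2vec model

movies = pd.DataFrame({
    "movieId": [1, 2],
    "title": ["toy story (1995)", "heat (1995)"],
    "genres": ["Adventure|Comedy", "Action|Crime"],
})
print(genres_processing(movies).filter(regex=r"^genre\d"))
```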
/experiments/Netflix/RU_load_datasets.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "e481238a-e6ea-431d-b2a7-957600884500",
6 | "metadata": {},
7 | "source": [
8 | "# Иллюстрация загрузки и подготовки данных Netflix"
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "id": "d3607d22-edeb-445a-b281-8bc237e86a98",
14 | "metadata": {},
15 | "source": [
16 | "### $\\textbf{Содержание}$:"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "id": "3c60e1bb-1f79-47b0-9cb4-1240e29f2d7a",
22 | "metadata": {},
23 | "source": [
24 | "### $\\textbf{I. Создание табличных данных на основании txt файлов}$\n",
25 | "### Из записей txt файлов формируются в табличные данные csv формата с атрибутами:\n",
26 | "#### - $\\it{movie\\_Id} \\in \\mathbb{N}$: идентификатор предложения; \n",
27 | "#### - $\\it{user_\\_i_d} \\in \\mathbb{N}$: идентификатор пользователя;\n",
28 | "#### - $\\it{rating} \\in [1, 5]$: полезность предложения для пользователя;\n",
29 | "#### - $\\it{date} \\in \\mathbb{N}$: временная метка;\n",
30 | "------\n"
31 | ]
32 | },
33 | {
34 | "cell_type": "markdown",
35 | "id": "344a920d-f5a1-4770-80b5-80d8959acc21",
36 | "metadata": {},
37 | "source": [
38 | "### $\\textbf{II. Обработка фильмов}$\n",
39 | "### Из названия фильмов генерируются следующие признаки:\n",
40 | "#### - год выпуска фильма $\\it{year} \\in \\mathbb{N}$;\n",
41 | "### Из оценок, выставленных пользователями, генерируются следующие признаки:\n",
42 | "#### - средняя оценка фильма $\\it{rating\\_avg} \\in [0, 5]$;\n",
43 | "#### - количество оценок фильма $\\it{rating\\_cnt} \\in \\mathbb{N} \\cup \\{0\\}$;\n",
44 | "---"
45 | ]
46 | },
47 | {
48 | "cell_type": "markdown",
49 | "id": "272be4e1-9592-4d71-8211-18c94b840c3a",
50 | "metadata": {},
51 | "source": [
52 | "### $\\textbf{III. Обработка рейтингов}$\n",
53 | "#### Никакие признаки на это этапе не генерируются;\n",
54 | "#### Атрибут даты приводятся к формату timestamp;\n"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": null,
60 | "id": "b51300d1-d5c1-425a-aef5-a2f5c00483ad",
61 | "metadata": {
62 | "tags": []
63 | },
64 | "outputs": [],
65 | "source": [
66 | "import os\n",
67 | "import sys\n",
68 | "import pandas as pd\n",
69 | "import numpy as np\n",
70 | "\n",
71 | "import re\n",
72 | "from datetime import datetime\n",
73 | "\n",
74 | "import tqdm\n",
75 | "\n",
76 | "import pyspark\n",
77 | "from pyspark.sql import SparkSession\n",
78 | "from pyspark.sql import functions as F\n",
79 | "\n",
80 | "spark = SparkSession.builder\\\n",
81 | " .appName(\"processingApp\")\\\n",
82 | " .config(\"spark.driver.memory\", \"8G\")\\\n",
83 | " .config(\"spark.executor.cores\", \"8G\")\\\n",
84 | " .config(\"spark.executor.memory\", \"2G\")\\\n",
85 | " .getOrCreate()\n",
86 | "\n",
87 | "spark.conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": null,
93 | "id": "e8354c1e-a4d2-4cc6-810a-5e76e7b8c8fe",
94 | "metadata": {},
95 | "outputs": [],
96 | "source": [
97 | "data_folder = r'./'\n",
98 | "row_data_files = ['combined_data_' + str(i) + '.txt' for i in range(1,5)]\n",
99 | "\n",
100 | "movie_titles_path = r'./movie_titles.csv'\n",
101 | "\n",
102 | "save_file_name = r\"./data_clean/netflix_full.csv\"\n",
103 | "\n",
104 | "save_movies = r'./data_clean/movies.csv'\n",
105 | "save_rating = r'./data_clean/rating.csv'"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": null,
111 | "id": "c08cea48-8e20-433b-ad67-5dd6b19ac0bd",
112 | "metadata": {},
113 | "outputs": [],
114 | "source": []
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": null,
119 | "id": "222b7bcc-208a-459b-a7db-65245f7a4e77",
120 | "metadata": {},
121 | "outputs": [],
122 | "source": [
123 | "row_data_files = [os.path.join(data_folder, i) for i in row_data_files]"
124 | ]
125 | },
126 | {
127 | "cell_type": "markdown",
128 | "id": "5e51560c-2254-49be-b33d-13a43075d4b2",
129 | "metadata": {},
130 | "source": [
131 | "### I. Создание табличных данных на основании txt файлов"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": null,
137 | "id": "e88db14a-d515-4ee8-be4d-a4ee6de986b7",
138 | "metadata": {},
139 | "outputs": [],
140 | "source": [
141 | "def row_data_ops(\n",
142 | " files: List[str],\n",
143 | " save_file_name: str\n",
144 | "):\n",
145 | " \"\"\"\n",
146 | " Creating table data from txt files\n",
147 | " :param files: txt files names\n",
148 | " :type files: list\n",
149 | " :param save_file_name: file name for saving\n",
150 | " :type save_file_name: str\n",
151 | "\n",
152 | " \n",
153 | " \"\"\"\n",
154 | " for _, file_name in enumerate(files):\n",
155 | " df = spark.read.text(os.path.join(file_name))\n",
156 | " df = df.coalesce(1).withColumn(\"row_num\", F.monotonically_increasing_id())\n",
157 | "\n",
158 | " df_partitions = df.select( F.col(\"row_num\").alias(\"Id\"), \n",
159 | " F.regexp_extract(F.col(\"value\"), r'\\d+', 0).alias(\"Id_start\") ).where( F.substring(F.col(\"value\"), -1, 1)==\":\" )\n",
160 | " df_partitions = df_partitions.select( F.col(\"Id\").cast('int'),\n",
161 | " F.col(\"Id_start\").cast('int'))\n",
162 | "\n",
163 | " df_rows = df.select( F.col(\"row_num\").alias(\"Id\"),\n",
164 | " F.col(\"value\") ).where( F.substring(F.col(\"value\"), -1, 1)!=\":\" )\n",
165 | " df_rows = df_rows.select( F.col(\"Id\"),\n",
166 | " F.regexp_extract(F.col(\"value\"), r'(\\d+),(\\d+),(\\d+-\\d+-\\d+)', 1).cast('int').alias(\"user_Id\"),\n",
167 | " F.regexp_extract(F.col(\"value\"), r'(\\d+),(\\d+),(\\d+-\\d+-\\d+)', 2).cast('int').alias(\"rating\"),\n",
168 | " F.to_date(F.regexp_extract(F.col(\"value\"), r'(\\d+),(\\d+),(\\d+-\\d+-\\d+)', 3), \"yyyy-mm-dd\").alias(\"date\"))\n",
169 | " df_partitions2 = df_partitions.select( F.col(\"Id\").alias(\"Id2\"),\n",
170 | " F.col(\"Id_start\").alias(\"Id_end\"))\n",
171 | " df_indexes = df_partitions.join(df_partitions2, df_partitions2.Id_end - df_partitions.Id_start == 1, \"left\").select( \n",
172 | " F.col('Id').alias('Idx_start'), \n",
173 | " F.col('Id2').alias('Idx_stop'),\n",
174 | " F.col('Id_start').alias('Index'))\n",
175 | "\n",
176 | " df_result = df_rows.join(F.broadcast(df_indexes), (df_rows.Id > df_indexes.Idx_start) & ((df_rows.Id < df_indexes.Idx_stop) | (df_indexes.Idx_stop.isNull())), \"inner\").select(\n",
177 | " F.col('Index').alias('movie_Id'),\n",
178 | " F.col('user_Id'),\n",
179 | " F.col('rating'),\n",
180 | " F.col('date')\n",
181 | " ).distinct()\n",
182 | " \n",
183 | " if _ == 0:\n",
184 | " df_all_users = df_result\n",
185 | " else:\n",
186 | " df_all_users = df_all_users.union(df_result)\n",
187 | " \n",
188 | " df_all_users.write.csv(save_file_name)"
189 | ]
190 | },
191 | {
192 | "cell_type": "code",
193 | "execution_count": null,
194 | "id": "7d25ad6e-c18e-4e7e-a7ee-2eaf80de0621",
195 | "metadata": {},
196 | "outputs": [],
197 | "source": [
198 | "row_data_ops(row_data_files, save_file_name)"
199 | ]
200 | },
201 | {
202 | "cell_type": "markdown",
203 | "id": "8a164e5c-e723-4a45-96a3-1d84fd04221a",
204 | "metadata": {},
205 | "source": [
206 | "### II. Обработка фильмов"
207 | ]
208 | },
209 | {
210 | "cell_type": "code",
211 | "execution_count": 1,
212 | "id": "033358f6-fdca-4872-b638-79db19f61725",
213 | "metadata": {},
214 | "outputs": [],
215 | "source": [
216 | "def movies_ops(\n",
217 | " data_path, \n",
218 | " movies_path, \n",
219 | " movie_save_path\n",
220 | "):\n",
221 | " \n",
222 | " \"\"\"\n",
223 | " operate movies name\n",
224 | " :param data_path: path to netflix full file\n",
225 | " :type data_path: str\n",
226 | " :param movies_path: path to netflix movies file\n",
227 | " :type movies_path: str\n",
228 | " :param movie_save_path: file path to save clear netflix movies\n",
229 | " :type movie_save_path: str\n",
230 | "\n",
231 | " \"\"\"\n",
232 | " \n",
233 | " df_all = pd.read_csv(data_path) \n",
234 | " df_movies = pd.merge(df_all.groupby(by = ['movie_Id'], as_index=False).rating.count().rename(columns={'rating':'rating_cnt'}),\n",
235 | " df_all.groupby(by = ['movie_Id'], as_index=False).rating.mean().rename(columns={'rating':'rating_avg'}), \n",
236 | " on = 'movie_Id', how = 'inner')\n",
237 | " \n",
238 | " with open(movies_path, 'r', encoding=\"ISO-8859-1\") as f:\n",
239 | " file = f.read()\n",
240 | " file_arr = file.split('\\n')\n",
241 | " \n",
242 | " file_arr_id = []\n",
243 | " file_arr_year = []\n",
244 | " file_arr_name = []\n",
245 | "\n",
246 | " file_arr_problem = []\n",
247 | "\n",
248 | " for i in file_arr:\n",
249 | " row = re.sub(r'^\\s+', '', i)\n",
250 | " row = re.sub(r'\\s+$', '', i)\n",
251 | " row_group = re.match(r'(\\d+),(\\d+),(.+)', row)\n",
252 | " if row_group != None:\n",
253 | " assert row == row_group.group(0)\n",
254 | "\n",
255 | " file_arr_id.append(int(row_group.group(1)))\n",
256 | " file_arr_year.append(int(row_group.group(2)))\n",
257 | " file_arr_name.append(row_group.group(3))\n",
258 | "\n",
259 | " else:\n",
260 | " file_arr_problem.append(row)\n",
261 | "\n",
262 | " \n",
263 | " df_names = pd.DataFrame({ 'movie_Id':file_arr_id, 'year':file_arr_year, 'title':file_arr_name })\n",
264 | " fill_na_year = ['2002', '2002', '2002', '1974', '1999', '1994', '1999']\n",
265 | " fill_na_name = []\n",
266 | " fill_na_id = []\n",
267 | "\n",
268 | " for i in range(len(file_arr_problem)-1):\n",
269 | " row_group = re.match(r'(\\d+),(NULL),(.+)', file_arr_problem[i])\n",
270 | "\n",
271 | " fill_na_id.append(int(row_group.group(1)))\n",
272 | " fill_na_name.append(row_group.group(3))\n",
273 | "\n",
274 | " df_names = pd.concat([df_names, pd.DataFrame({ 'movie_Id':fill_na_id, 'year':fill_na_year, 'title':fill_na_name })])\n",
275 | " df_names.movie_Id = df_names.movie_Id.astype('int')\n",
276 | " df_names.year = df_names.year.astype('int')\n",
277 | " \n",
278 | " df_movies = pd.merge(df_movies, df_names, on = 'movie_Id', how = 'left')\n",
279 | " df_movies.reset_index(drop=True).to_csv(movie_save_path, index=False)"
280 | ]
281 | },
282 | {
283 | "cell_type": "code",
284 | "execution_count": null,
285 | "id": "344fe1df-ae89-47cb-b777-4d4214e86bbf",
286 | "metadata": {},
287 | "outputs": [],
288 | "source": [
289 | "movies_ops(save_file_name, movie_titles_path, save_movies)"
290 | ]
291 | },
292 | {
293 | "cell_type": "markdown",
294 | "id": "4fa05a96-ddd8-4d92-82ec-d16145ecd774",
295 | "metadata": {},
296 | "source": [
297 | "### III. Обработка рейтингов"
298 | ]
299 | },
300 | {
301 | "cell_type": "code",
302 | "execution_count": 2,
303 | "id": "10d08d3b-f540-4b1b-8c81-7a2ac2d024fc",
304 | "metadata": {},
305 | "outputs": [],
306 | "source": [
307 | "def rating_op(\n",
308 | " data_path, \n",
309 | " rating_save_path\n",
310 | "):\n",
311 | " \"\"\"\n",
312 | " operate ratings \n",
313 | " :param data_path: path to netflix full file\n",
314 | " :type data_path: str\n",
315 | " :param rating_save_path: file path to save operated netflix ratings\n",
316 | " :type rating_save_path: str\n",
317 | "\n",
318 | " \"\"\"\n",
319 | " \n",
320 | " df_rating = pd.read_csv(data_path) \n",
321 | " df_rating['timestamp'] = df_rating.date.apply(lambda x: pd.to_datetime(x))\n",
322 | " df_rating['timestamp'] = df_rating.timestamp.apply(lambda x: x.timestamp())\n",
323 | " df_rating = df_rating[['movie_Id', 'user_Id', 'rating', 'timestamp']]\n",
324 | " \n",
325 | " df_rating.reset_index(drop=True).to_csv(rating_save_path, index=False)"
326 | ]
327 | },
328 | {
329 | "cell_type": "code",
330 | "execution_count": null,
331 | "id": "5744d2fd-2507-48f7-9081-69d9daa8a1e2",
332 | "metadata": {},
333 | "outputs": [],
334 | "source": [
335 | "rating_op(save_file_name, save_rating)"
336 | ]
337 | }
338 | ],
339 | "metadata": {
340 | "kernelspec": {
341 | "display_name": "Python [conda env:.conda-sber3.8]",
342 | "language": "python",
343 | "name": "conda-env-.conda-sber3.8-py"
344 | },
345 | "language_info": {
346 | "codemirror_mode": {
347 | "name": "ipython",
348 | "version": 3
349 | },
350 | "file_extension": ".py",
351 | "mimetype": "text/x-python",
352 | "name": "python",
353 | "nbconvert_exporter": "python",
354 | "pygments_lexer": "ipython3",
355 | "version": "3.8.13"
356 | }
357 | },
358 | "nbformat": 4,
359 | "nbformat_minor": 5
360 | }
361 |
--------------------------------------------------------------------------------
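A pure-Python sketch of the combined_data_*.txt layout that row_data_ops parses with Spark in the notebook above; the sample lines are made up for illustration:

```python
# Sketch of the Netflix combined_data_*.txt layout: a "movie_id:" header line opens
# each block, followed by "user_id,rating,YYYY-MM-DD" rows. Sample data is made up.
import re

sample = ["1:", "1488844,3,2005-09-06", "822109,5,2005-05-13", "2:", "885013,4,2005-10-19"]

records, movie_id = [], None
for line in sample:
    if line.endswith(":"):
        movie_id = int(line[:-1])                 # start of a new movie block
    else:
        user_id, rating, date = re.match(r"(\d+),(\d+),(\d+-\d+-\d+)", line).groups()
        records.append((movie_id, int(user_id), int(rating), date))

print(records)   # [(1, 1488844, 3, '2005-09-06'), (1, 822109, 5, '2005-05-13'), (2, 885013, 4, '2005-10-19')]
```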
/experiments/Netflix/netflix.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import re
4 | import itertools
5 | import tqdm
6 |
7 | import seaborn as sns
8 | import tqdm
9 | import matplotlib.pyplot as plt
10 |
11 | from bs4 import BeautifulSoup
12 | from nltk.tokenize import TreebankWordTokenizer, WhitespaceTokenizer
13 |
14 | import nltk
15 | nltk.download('stopwords')
16 | nltk.download('words')
17 | words = set(nltk.corpus.words.words())
18 | words = set([w.lower() for w in words])
19 |
20 | from nltk.stem import PorterStemmer, WordNetLemmatizer
21 | nltk.download("wordnet")
22 |
23 | from nltk.corpus import stopwords
24 | stop_words = set(stopwords.words("english"))
25 |
26 |
27 | from nltk.tokenize import sent_tokenize
28 |
29 | import gensim
30 | from gensim.downloader import load
31 | from gensim.models import Word2Vec
32 | w2v_model = gensim.downloader.load('word2vec-google-news-300')
33 |
34 | from typing import Dict, List, Optional, Tuple
35 |
36 | def clean_text(text: str) -> str:
37 | """
38 | Cleaning text: remove extra spaces and non-text characters
39 | :param text: tag row text
 40 |     :type text: str
41 | :return: tag cleaned text
42 | :rtype: str
43 | """
44 |
45 | text = re.sub("[^a-zA-Z]", " ",text)
46 | text = re.sub(r"\s+", " ", text)
47 | text = re.sub(r"\s+$", "", text)
48 | text = re.sub(r"^\s+", "", text)
49 | text = text.lower()
50 |
51 | return text
52 |
53 |
54 | def procces_text(text):
55 | """
56 | Processing text: lemmatization, tokenization, removing stop-words
57 | :param text: tag cleaned text
58 | :type text: str
59 | :return: tag processed text
60 | :rtype: str
61 | """
62 | lemmatizer = WordNetLemmatizer()
63 |
64 | text = [word for word in nltk.word_tokenize(text) if not word in stop_words]
65 | text = [lemmatizer.lemmatize(token) for token in text]
66 | text = [word for word in text if word in words]
67 |
68 | text = " ".join(text)
69 |
70 | return text
71 |
72 | def string_embedding(string: str) -> np.ndarray:
 73 |     """
 74 |     Embeds each word of the string with word2vec and returns the mean vector.
75 | :param string: cleaned and processed tags
76 | :type string: str
77 | :return: average vector of the string words embeddings
78 | :rtype: np.ndarray, optional
79 | """
80 |
81 | arr = string.split(' ')
82 | vec = 0
83 | cnt = 0
84 | for i in arr:
85 | try:
86 | vec += w2v_model[i]
87 | cnt += 1
88 | except:
89 | pass
90 | if cnt == 0:
 91 |         vec = np.zeros((300,))
92 | else:
93 | vec /= cnt
94 | return vec
95 |
96 | def group_w2v(history: pd.DataFrame, movies: pd.DataFrame) -> pd.DataFrame:
 97 |     """
 98 |     Aggregates (mean) the embedded movie data over each user's watch history.
 99 | 
100 |     :param history: users' movie watch history
101 |     :type history: pd.DataFrame
102 |     :param movies: movies data with w2v embedding columns
103 |     :type movies: pd.DataFrame
104 |     :return: dataframe of users with mean-aggregated movie embeddings
105 |     :rtype: pd.DataFrame
106 |     """
107 | 
117 | users_id_arr = history.user_Id.unique()
118 |
119 | id_arr = []
120 | vec_arr = np.zeros((len(users_id_arr), 300))
121 |
122 | for user_id in tqdm.tqdm_notebook(range(len(users_id_arr))):
123 | vec = np.asarray(movies[movies.movie_Id.isin(history[history.user_Id == users_id_arr[user_id]].movie_Id)].iloc[:, 4:]).mean(axis=0)
124 |
125 | id_arr.append(users_id_arr[user_id])
126 | vec_arr[user_id] = vec
127 |
128 | df = pd.DataFrame(vec_arr)
129 | df['user_Id'] = id_arr
130 |
131 | return df
132 |
133 | def data_processing(df_movies: pd.DataFrame,
134 | df_rating: pd.DataFrame,
135 | rename: bool = True
136 | ) -> List[pd.DataFrame]:
137 |
138 | df_movies = df_movies.drop(["rating_cnt", "rating_avg"], axis=1)
139 | df_movies['clean_title'] = df_movies.title.apply(lambda x : procces_text(clean_text(x)))
140 | df_movies.drop("title", axis = 1, inplace = True)
141 | df_movies_clean = pd.concat([df_movies.drop("clean_title", axis=1),
142 | pd.DataFrame(df_movies.clean_title.apply(string_embedding).to_list(), columns = ['w2v_' + str(i) for i in range(300)])], axis = 1)
143 |
144 | movies_vector = df_movies_clean.drop(['year'], axis=1)
145 | for col in movies_vector.drop("movie_Id", axis=1).columns:
146 | movies_vector[col] = movies_vector[col].astype('float')
147 |
148 | agg_columns = []
149 | df_result = pd.DataFrame()
150 |
151 | chunksize=10000
152 | chunk_count = (df_rating.shape[0] // chunksize) + 1 if df_rating.shape[0]%chunksize!=0 else df_rating.shape[0] // chunksize
153 | for idx in tqdm.tqdm_notebook(range(chunk_count)):
154 | chunk = df_rating.iloc[idx*chunksize:(idx+1)*chunksize, :]
155 | df_history = pd.merge(chunk[['user_Id', 'movie_Id', 'rating']], movies_vector.movie_Id, on = 'movie_Id', how = 'left')
156 | df_history = pd.merge(df_history, movies_vector, how='left', on='movie_Id').drop('movie_Id', axis=1)
157 | df_history['cnt'] = 1
158 |
159 | if idx == 0:
160 | agg_columns = df_history.drop(['user_Id'], axis=1).columns
161 | df_history_aggregated = df_history.groupby("user_Id", as_index=False)[agg_columns].sum()
162 | df_result = df_result.append(df_history_aggregated, ignore_index=True)
163 |
164 | if idx % 20 == 0:
165 | df_result = df_result.groupby("user_Id", as_index=False)[agg_columns].sum()
166 |
167 | df_result = df_result.groupby("user_Id", as_index=False)[agg_columns].sum()
168 | for col in agg_columns:
169 | if col != "cnt":
170 | df_result[col] = df_result[col] / df_result["cnt"]
171 | df_result = df_result.rename(columns={"rating": "rating_avg", "cnt": "rating_cnt"})
172 | df_users_clean = df_result
173 |
174 | df_movies_clean = pd.merge(df_rating.groupby("movie_Id", as_index=False)["rating"]\
175 | .agg(['mean', 'count'])\
176 | .rename(columns={"mean": "rating_avg", "count": "rating_cnt"}), df_movies_clean,how='left', on='movie_Id').fillna(0.0)
177 |
178 | df_rating_clean = df_rating
179 | if rename:
180 | cat_dict_movies = pd.Series(df_movies_clean.movie_Id.astype("category").cat.codes.values, index=df_movies_clean.movie_Id).to_dict()
181 | cat_dict_users = pd.Series(df_users_clean.user_Id.astype("category").cat.codes.values, index=df_users_clean.user_Id).to_dict()
182 | df_movies_clean.movie_Id = df_movies_clean.movie_Id.apply(lambda x: cat_dict_movies[x])
183 | df_users_clean.user_Id = df_users_clean.user_Id.apply(lambda x: cat_dict_users[x])
184 | df_rating_clean.movie_Id = df_rating_clean.movie_Id.apply(lambda x: cat_dict_movies[x])
185 | df_rating_clean.user_Id = df_rating_clean.user_Id.apply(lambda x: cat_dict_users[x])
186 |
187 | df_movies_clean = df_movies_clean.rename(columns={'movie_Id': 'item_idx'})
188 | df_users_clean = df_users_clean.rename(columns={'user_Id': 'user_idx'})
189 | df_rating_clean = df_rating_clean.rename(columns={'movie_Id': 'item_idx', 'user_Id': 'user_idx', 'rating': 'relevance'})
190 |
191 | return [df_movies_clean, df_users_clean, df_rating_clean]
--------------------------------------------------------------------------------
/experiments/datautils.py:
--------------------------------------------------------------------------------
1 | import pyspark.sql.types as st
2 | import pyspark.sql.functions as sf
3 | from pyspark.sql import DataFrame
4 | from pyspark.ml.feature import Bucketizer, VectorAssembler
5 | from pyspark.ml.clustering import KMeans
6 |
7 |
8 | USER_PREFIX = 'user_'
9 | ITEM_PREFIX = 'item_'
10 |
11 | MOVIELENS_USER_SCHEMA = st.StructType(
12 | [st.StructField('user_idx', st.IntegerType())] +\
13 | [st.StructField(f'genre{i}', st.DoubleType()) for i in range(19)] +\
14 | [st.StructField('rating_avg', st.DoubleType())] +\
15 | [st.StructField(f'w2v_{i}', st.DoubleType()) for i in range(300)]
16 | )
17 | MOVIELENS_ITEM_SCHEMA = st.StructType(
18 | [st.StructField('item_idx', st.IntegerType())] +\
19 | [st.StructField('year', st.IntegerType())] +\
20 | [st.StructField('rating_avg', st.DoubleType())] +\
21 | [st.StructField(f'genre{i}', st.DoubleType()) for i in range(19)] +\
22 | [st.StructField(f'w2v_{i}', st.DoubleType()) for i in range(300)]
23 | )
24 | MOVIELENS_LOG_SCHEMA = st.StructType([
25 | st.StructField('user_idx', st.IntegerType()),
26 | st.StructField('item_idx', st.IntegerType()),
27 | st.StructField('relevance', st.DoubleType()),
28 | st.StructField('timestamp', st.IntegerType())
29 | ])
30 |
31 | NETFLIX_USER_SCHEMA = st.StructType(
32 | [st.StructField('user_idx', st.IntegerType())] +\
33 | [st.StructField('rating_avg', st.DoubleType())] +\
34 | [st.StructField(f'w2v_{i}', st.DoubleType()) for i in range(300)] +\
35 | [st.StructField('rating_cnt', st.IntegerType())]
36 | )
37 | NETFLIX_ITEM_SCHEMA = st.StructType(
38 | [st.StructField('item_idx', st.IntegerType())] +\
39 | [st.StructField('rating_avg', st.DoubleType())] +\
40 | [st.StructField('rating_cnt', st.IntegerType())] +\
41 | [st.StructField('year', st.IntegerType())] +\
42 | [st.StructField(f'w2v_{i}', st.DoubleType()) for i in range(300)]
43 | )
44 | NETFLIX_LOG_SCHEMA = st.StructType([
45 | st.StructField('item_idx', st.IntegerType()),
46 | st.StructField('user_idx', st.IntegerType()),
47 | st.StructField('relevance', st.DoubleType()),
48 | st.StructField('timestamp', st.DoubleType())
49 | ])
50 |
51 | MOVIELENS_CLUSTER_COLS = [
52 | 'genre0', 'genre1', 'genre2', 'genre3', 'genre4',
53 | 'genre5', 'genre6', 'genre7', 'genre8', 'genre9',
54 | 'genre10', 'genre11', 'genre12', 'genre13', 'genre14',
55 | 'genre15', 'genre16', 'genre17', 'genre18'
56 | ]
57 | NETFLIX_STRAT_COL = 'rating_cnt'
58 | NETFLIX_CLUSTER_COLS = [f'w2v_{i}' for i in range(300)]
59 |
60 |
61 | def read_movielens(base_path, type, spark_session):
62 | if type not in ['train', 'val', 'test']:
63 | raise ValueError('Wrong dataset type')
64 |
65 | users = spark_session\
66 | .read.csv(f'{base_path}/{type}/users.csv', header=True, schema=MOVIELENS_USER_SCHEMA)\
67 | .withColumnRenamed('user_idx', 'user_id')
68 | items = spark_session\
69 | .read.csv(f'{base_path}/{type}/items.csv', header=True, schema=MOVIELENS_ITEM_SCHEMA)\
70 | .withColumnRenamed('item_idx', 'item_id')
71 | log = spark_session\
72 | .read.csv(f'{base_path}/{type}/rating.csv', header=True, schema=MOVIELENS_LOG_SCHEMA)\
73 | .withColumnRenamed('user_idx', 'user_id')\
74 | .withColumnRenamed('item_idx', 'item_id')
75 |
76 | log = log\
77 | .join(users, on='user_id', how='leftsemi')\
78 | .join(items, on='item_id', how='leftsemi')
79 |
80 | for c in users.columns:
81 | if not c.startswith('user_'):
82 | users = users.withColumnRenamed(c, 'user_' + c)
83 |
84 | for c in items.columns:
85 | if not c.startswith('item_'):
86 | items = items.withColumnRenamed(c, 'item_' + c)
87 |
88 | log = log.withColumn('relevance', sf.when(sf.col('relevance') >= 3, 1).otherwise(0))
89 |
90 | users = users.na.drop()
91 | items = items.na.drop()
92 | log = log.na.drop()
93 |
94 | return users, items, log
95 |
96 | def read_netflix(base_path, type, spark_session):
97 | if type not in ['train', 'val', 'test']:
98 | raise ValueError('Wrong dataset type')
99 |
100 | users = spark_session\
101 | .read.csv(f'{base_path}/{type}/users.csv', header=True, schema=NETFLIX_USER_SCHEMA)\
102 | .withColumnRenamed('user_idx', 'user_id')
103 | items = spark_session\
104 | .read.csv(f'{base_path}/{type}/items.csv', header=True, schema=NETFLIX_ITEM_SCHEMA)\
105 | .withColumnRenamed('item_idx', 'item_id')
106 | log = spark_session\
107 | .read.csv(f'{base_path}/{type}/rating.csv', header=True, schema=NETFLIX_LOG_SCHEMA)\
108 | .withColumnRenamed('user_idx', 'user_id')\
109 | .withColumnRenamed('item_idx', 'item_id')
110 |
111 | log = log\
112 | .join(users, on='user_id', how='leftsemi')\
113 | .join(items, on='item_id', how='leftsemi')
114 |
115 | for c in users.columns:
116 | if not c.startswith('user_'):
117 | users = users.withColumnRenamed(c, 'user_' + c)
118 |
119 | for c in items.columns:
120 | if not c.startswith('item_'):
121 | items = items.withColumnRenamed(c, 'item_' + c)
122 |
123 | log = log.withColumn('relevance', sf.when(sf.col('relevance') >= 3, 1).otherwise(0))
124 |
125 | users = users.na.drop()
126 | items = items.na.drop()
127 | log = log.na.drop()
128 |
129 | return users, items, log
130 |
131 | def read_amazon(base_path, type, spark_session):
132 | if type not in ['train', 'val', 'test']:
133 | raise ValueError('Wrong dataset type')
134 |
135 | users = spark_session\
136 | .read.parquet(f'{base_path}/{type}/users.parquet').withColumnRenamed('user_idx', 'user_id')
137 | items = spark_session\
138 | .read.parquet(f'{base_path}/{type}/items.parquet').withColumnRenamed('item_idx', 'item_id')
139 | log = spark_session\
140 | .read.parquet(f'{base_path}/{type}/rating.parquet')\
141 | .withColumnRenamed('user_idx', 'user_id')\
142 | .withColumnRenamed('item_idx', 'item_id')
143 |
144 | log = log\
145 | .join(users, on='user_id', how='leftsemi')\
146 | .join(items, on='item_id', how='leftsemi')
147 |
148 | for c in users.columns:
149 | if not c.startswith('user_'):
150 | users = users.withColumnRenamed(c, 'user_' + c)
151 |
152 | for c in items.columns:
153 | if not c.startswith('item_'):
154 | items = items.withColumnRenamed(c, 'item_' + c)
155 |
156 | log = log.withColumn('relevance', sf.when(sf.col('relevance') >= 3, 1).otherwise(0))
157 |
158 | users = users.na.drop()
159 | items = items.na.drop()
160 | log = log.na.drop()
161 |
162 | return users, items, log
163 |
164 |
165 | def netflix_cluster_users(
166 | df : DataFrame,
167 | outputCol : str = 'cluster',
168 | column_prefix : str = '',
169 | seed : int = None
170 | ):
171 | cluster_cols = [f'{column_prefix}{c}' for c in NETFLIX_CLUSTER_COLS]
172 | assembler = VectorAssembler(
173 | inputCols=cluster_cols, outputCol='__features'
174 | )
175 | kmeans = KMeans(
176 | k=10, featuresCol='__features',
177 | predictionCol=outputCol, maxIter=300, seed=seed
178 | )
179 |
180 | df = assembler.transform(df)
181 | kmeans_model = kmeans.fit(df)
182 | df = kmeans_model.transform(df)
183 |
184 | return df.drop('__features')
185 |
186 | def movielens_cluster_users(
187 | df : DataFrame,
188 | outputCol : str = 'cluster',
189 | column_prefix : str = '',
190 | seed : int = None
191 | ):
192 | cluster_cols = [f'{column_prefix}{c}' for c in MOVIELENS_CLUSTER_COLS]
193 | assembler = VectorAssembler(
194 | inputCols=cluster_cols, outputCol='__features'
195 | )
196 | kmeans = KMeans(
197 | k=10, featuresCol='__features',
198 | predictionCol=outputCol, maxIter=300, seed=seed
199 | )
200 |
201 | df = assembler.transform(df)
202 | kmeans_model = kmeans.fit(df)
203 | df = kmeans_model.transform(df)
204 |
205 | return df.drop('__features')
206 |
--------------------------------------------------------------------------------
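A hedged usage sketch for datautils.py: load one MovieLens split and cluster its users on the genre columns; the base path is illustrative and must contain the train/val/test folders with users.csv, items.csv and rating.csv that read_movielens expects:

```python
# Hedged usage sketch for datautils.py: one MovieLens split is loaded and its users
# are clustered on the genre columns. The base path is an assumption.
from pyspark.sql import SparkSession

from datautils import read_movielens, movielens_cluster_users

spark = SparkSession.builder.master("local[4]").appName("datautils-sketch").getOrCreate()

# expects {base}/train/users.csv, items.csv and rating.csv, as read_movielens assumes
users, items, log = read_movielens("/path/to/movielens", "train", spark)

# genre columns were prefixed with "user_" inside read_movielens, hence column_prefix
users = movielens_cluster_users(users, outputCol="cluster", column_prefix="user_", seed=42)
users.groupBy("cluster").count().show()
```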
/experiments/generators_generate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import pandas as pd
5 | import torch
6 | from pyspark.sql import SparkSession
7 | from replay.session_handler import State
8 | from sim4rec.modules import SDVDataGenerator
9 |
10 | SPARK_LOCAL_DIR = '/data/home/anthony/tmp'
11 | RESULT_DIR = '../bin'
12 |
13 |
14 | NUM_JOBS = int(sys.argv[1])
15 | torch.set_num_threads(8)
16 |
17 | spark = SparkSession.builder\
18 | .appName('simulator')\
19 | .master(f'local[{NUM_JOBS}]')\
20 | .config('spark.sql.shuffle.partitions', f'{NUM_JOBS}')\
21 | .config('spark.default.parallelism', f'{NUM_JOBS}')\
22 | .config('spark.driver.extraJavaOptions', '-XX:+UseG1GC')\
23 | .config('spark.executor.extraJavaOptions', '-XX:+UseG1GC')\
24 | .config('spark.sql.autoBroadcastJoinThreshold', '-1')\
25 | .config('spark.driver.memory', '256g')\
26 | .config('spark.local.dir', SPARK_LOCAL_DIR)\
27 | .getOrCreate()
28 |
29 | State(spark)
30 |
31 | def generate_time(generator, num_samples):
32 | start = time.time()
33 | df = generator.generate(num_samples).cache()
34 | df.count()
35 | result_time = time.time() - start
36 | df.unpersist()
37 |
38 | return (generator.getLabel(), result_time, num_samples, NUM_JOBS)
39 |
40 | generators = [SDVDataGenerator.load(f'{RESULT_DIR}/genscale_{g}_{10000}.pkl') for g in ['copulagan', 'ctgan', 'gaussiancopula', 'tvae']]
41 | for g in generators:
42 | g.setParallelizationLevel(NUM_JOBS)
43 |
44 | NUM_TEST_SAMPLES = [10, 100, 1000, 10000, 100000, 1000000, 10000000]
45 |
46 | result_df = pd.DataFrame(columns=['model_label', 'generate_time', 'num_samples', 'num_threads'])
47 |
48 | for g in generators:
49 | _ = g.generate(100).cache().count()
50 | for n in NUM_TEST_SAMPLES:
51 | print(f'Generating with {g.getLabel()} {n} samples')
52 | result_df.loc[len(result_df)] = generate_time(g, n)
53 |
54 | old_df = None
55 | if os.path.isfile(f'{RESULT_DIR}/gens_sample_time.csv'):
56 | old_df = pd.read_csv(f'{RESULT_DIR}/gens_sample_time.csv')
57 | else:
58 | old_df = pd.DataFrame(columns=['model_label', 'generate_time', 'num_samples', 'num_threads'])
59 |
60 | pd.concat([old_df, result_df], ignore_index=True).to_csv(f'{RESULT_DIR}/gens_sample_time.csv', index=False)
61 |
--------------------------------------------------------------------------------
/experiments/response_models/data/popular_items_popularity.parquet/.part-00000-8ac17189-ecfe-473d-a9c5-d72f7f86d119-c000.snappy.parquet.crc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sb-ai-lab/Sim4Rec/267b644932b0e120cc9887f0f3a16be38d5739b7/experiments/response_models/data/popular_items_popularity.parquet/.part-00000-8ac17189-ecfe-473d-a9c5-d72f7f86d119-c000.snappy.parquet.crc
--------------------------------------------------------------------------------
/experiments/response_models/data/popular_items_popularity.parquet/_SUCCESS:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sb-ai-lab/Sim4Rec/267b644932b0e120cc9887f0f3a16be38d5739b7/experiments/response_models/data/popular_items_popularity.parquet/_SUCCESS
--------------------------------------------------------------------------------
/experiments/response_models/data/popular_items_popularity.parquet/part-00000-8ac17189-ecfe-473d-a9c5-d72f7f86d119-c000.snappy.parquet:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sb-ai-lab/Sim4Rec/267b644932b0e120cc9887f0f3a16be38d5739b7/experiments/response_models/data/popular_items_popularity.parquet/part-00000-8ac17189-ecfe-473d-a9c5-d72f7f86d119-c000.snappy.parquet
--------------------------------------------------------------------------------
/experiments/response_models/task_1_popular_items.py:
--------------------------------------------------------------------------------
1 | from pyspark.ml import PipelineModel
2 | from pyspark.sql import SparkSession
3 |
4 | import pyspark.sql.functions as sf
5 |
6 | from sim4rec.response import BernoulliResponse, ActionModelTransformer
7 |
8 | from response_models.utils import get_session
9 |
10 |
11 | class PopBasedTransformer(ActionModelTransformer):
12 | def __init__(
13 | self,
14 | spark: SparkSession,
15 | outputCol: str = None,
16 | pop_df_path: str = None,
17 | ):
18 | """
19 | :param outputCol: Name of the response probability column
20 | :param pop_df_path: path to a spark dataframe with items' popularity
21 | """
22 | self.pop_df = sf.broadcast(spark.read.parquet(pop_df_path))
23 | self.outputCol = outputCol
24 |
25 | def _transform(self, dataframe):
26 | return (dataframe
27 | .join(self.pop_df, on='item_idx')
28 | .drop(*set(self.pop_df.columns).difference(["item_idx", self.outputCol]))
29 | )
30 |
31 |
32 | class TaskOneResponse:
33 | def __init__(self, spark, pop_df_path="./response_models/data/popular_items_popularity.parquet", seed=123):
34 | pop_resp = PopBasedTransformer(spark=spark, outputCol="popularity", pop_df_path=pop_df_path)
35 | br = BernoulliResponse(seed=seed, inputCol='popularity', outputCol='response')
36 | self.model = PipelineModel(
37 | stages=[pop_resp, br])
38 |
39 | def transform(self, df):
40 | return self.model.transform(df).drop("popularity")
41 |
42 |
43 | if __name__ == '__main__':
44 | import pandas as pd
45 | spark = get_session()
46 | task = TaskOneResponse(spark, pop_df_path="./data/popular_items_popularity.parquet", seed=123)
47 | task.model.stages[0].pop_df.show()
48 | test_df = spark.createDataFrame(pd.DataFrame({"item_idx": [3, 5, 1], "user_idx": [5, 2, 1]}))
49 | task.transform(test_df).show()
50 | task.transform(test_df).show()
51 | task.transform(test_df).show()
52 |
--------------------------------------------------------------------------------
/experiments/response_models/utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import pyspark.sql.functions as sf
3 | from pyspark.sql import SparkSession, DataFrame
4 | from pyspark.ml import PipelineModel
5 | from sim4rec.response import BernoulliResponse, ActionModelTransformer
6 | from IPython.display import clear_output
7 |
8 | def get_session(num_threads=4) -> SparkSession:
9 | return SparkSession.builder \
10 | .appName('simulator') \
11 | .master(f'local[{num_threads}]') \
12 | .config('spark.sql.shuffle.partitions', f'{num_threads * 3}') \
13 | .config('spark.default.parallelism', f'{num_threads * 3}') \
14 | .config('spark.driver.extraJavaOptions', '-XX:+UseG1GC') \
15 | .config('spark.executor.extraJavaOptions', '-XX:+UseG1GC') \
16 | .getOrCreate()
17 |
18 |
19 | def plot_metric(metrics):
20 | clear_output(wait=True)
21 | # plt.ylim(0, max(metrics) + 1)
22 | plt.plot(metrics)
23 | plt.grid()
24 | plt.xlabel('iteration')
25 | plt.ylabel('# of clicks')
26 | plt.show()
27 |
28 | def calc_metric(response_df):
29 | return (response_df
30 | .groupBy("user_idx").agg(sf.sum("response").alias("num_positive"))
31 | .select(sf.mean("num_positive")).collect()[0][0]
32 | )
33 |
34 | class ResponseTransformer(ActionModelTransformer):
35 |
36 | def __init__(
37 | self,
38 | spark: SparkSession,
39 | outputCol: str = None,
40 | proba_df_path: str = None,
41 | ):
42 | """
43 | Calculates users' response based on precomputed probability of item interaction
44 |
45 | :param outputCol: Name of the response probability column
46 |         :param proba_df_path: path to a Spark dataframe with precomputed user-item interaction probabilities
47 | """
48 | self.proba_df = spark.read.parquet(proba_df_path)
49 | self.outputCol = outputCol
50 |
51 | def _transform(self, dataframe):
52 | return (dataframe
53 | .join(self.proba_df, on=['item_idx', 'user_idx'])
54 | .drop(*set(self.proba_df.columns).difference(["user_idx", "item_idx", self.outputCol]))
55 | )
56 |
--------------------------------------------------------------------------------
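An illustrative usage sketch for ResponseTransformer above (not a file from the repository; it assumes it runs alongside the utils.py module shown here, and the parquet path, column names and `recs_df` are placeholders chosen for illustration):

    from pyspark.ml import PipelineModel
    from sim4rec.response import BernoulliResponse

    spark = get_session(num_threads=4)
    # hypothetical parquet with precomputed user-item interaction probabilities
    proba_tr = ResponseTransformer(
        spark=spark,
        outputCol='proba',
        proba_df_path='./data/user_item_proba.parquet',
    )
    br = BernoulliResponse(seed=123, inputCol='proba', outputCol='response')
    response_pipeline = PipelineModel(stages=[proba_tr, br])
    # responses = response_pipeline.transform(recs_df).drop('proba')
    # print(calc_metric(responses))
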
/experiments/simulator_time.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.filterwarnings("ignore")
3 |
4 | import sys
5 | import time
6 | import random
7 |
8 | import pandas as pd
9 | import numpy as np
10 |
11 | from pyspark.sql import SparkSession
12 | from replay.session_handler import State
13 | from sim4rec.utils import pandas_to_spark
14 | from sim4rec.modules import SDVDataGenerator
15 | from pyspark.ml.classification import LogisticRegression
16 | from pyspark.ml.feature import VectorAssembler
17 | from pyspark.ml import PipelineModel
18 | from sim4rec.response import CosineSimilatiry, BernoulliResponse, NoiseResponse, ParametricResponseFunction
19 | from sim4rec.utils import VectorElementExtractor
20 |
21 | from replay.data_preparator import Indexer
22 | from replay.models import UCB
23 |
24 | NUM_JOBS = int(sys.argv[1])
25 |
26 | SPARK_LOCAL_DIR = '/data/home/anthony/tmp'
27 | CHECKPOINT_DIR = '/data/home/anthony/tmp/checkpoints'
28 | MODELS_PATH = '../bin'
29 |
30 | spark = SparkSession.builder\
31 | .appName('simulator_validation')\
32 | .master(f'local[{NUM_JOBS}]')\
33 | .config('spark.sql.shuffle.partitions', f'{NUM_JOBS}')\
34 | .config('spark.default.parallelism', f'{NUM_JOBS}')\
35 | .config('spark.driver.extraJavaOptions', '-XX:+UseG1GC')\
36 | .config('spark.executor.extraJavaOptions', '-XX:+UseG1GC')\
37 | .config('spark.sql.autoBroadcastJoinThreshold', '-1')\
38 | .config('spark.driver.memory', '64g')\
39 | .config('spark.local.dir', SPARK_LOCAL_DIR)\
40 | .getOrCreate()
41 |
42 | State(spark)
43 |
44 | users_df = pd.DataFrame(data=np.random.normal(1, 1, size=(10000, 100)), columns=[f'user_attr_{i}' for i in range(100)])
45 | items_df = pd.DataFrame(data=np.random.normal(-1, 1, size=(2000, 100)), columns=[f'item_attr_{i}' for i in range(100)])
46 | items_df.loc[random.sample(range(2000), 1000)] = np.random.normal(1, 1, size=(1000, 100))
47 | users_df['user_id'] = np.arange(len(users_df))
48 | items_df['item_id'] = np.arange(len(items_df))
49 | history_df_all = pd.DataFrame()
50 | history_df_all['user_id'] = np.random.randint(0, 10000, size=33000)
51 | history_df_all['item_id'] = np.random.randint(0, 2000, size=33000)
52 | history_df_all['relevance'] = 0
53 |
54 | users_matrix = users_df.values[history_df_all.values[:, 0], :-1]
55 | items_matrix = items_df.values[history_df_all.values[:, 1], :-1]
56 | dot = np.sum(users_matrix * items_matrix, axis=1)
57 | history_df_all['relevance'] = np.where(dot >= 0.5, 1, 0)
58 | history_df_all = history_df_all.drop_duplicates(subset=['user_id', 'item_id'], ignore_index=True)
59 |
60 | history_df_train = history_df_all.iloc[:30000]
61 | history_df_val = history_df_all.iloc[30000:]
62 |
63 | users_df = pandas_to_spark(users_df)
64 | items_df = pandas_to_spark(items_df)
65 | history_df_train = pandas_to_spark(history_df_train)
66 | history_df_val = pandas_to_spark(history_df_val)
67 |
68 | user_generator = SDVDataGenerator.load(f'{MODELS_PATH}/cycle_scale_users_gen.pkl')
69 | item_generator = SDVDataGenerator.load(f'{MODELS_PATH}/cycle_scale_items_gen.pkl')
70 |
71 | user_generator.setDevice('cpu')
72 | item_generator.setDevice('cpu')
73 | user_generator.setParallelizationLevel(NUM_JOBS)
74 | item_generator.setParallelizationLevel(NUM_JOBS)
75 |
76 | syn_users = user_generator.generate(1000000).cache()
77 | syn_items = item_generator.generate(10000).cache()
78 |
79 | va_users_items = VectorAssembler(
80 | inputCols=users_df.columns[:-1] + items_df.columns[:-1],
81 | outputCol='features'
82 | )
83 |
84 | lr = LogisticRegression(
85 | featuresCol='features',
86 | labelCol='relevance',
87 | probabilityCol='__lr_prob'
88 | )
89 |
90 | vee = VectorElementExtractor(inputCol='__lr_prob', outputCol='__lr_prob', index=1)
91 |
92 | lr_train_df = history_df_train\
93 | .join(users_df, 'user_id', 'left')\
94 | .join(items_df, 'item_id', 'left')
95 |
96 | lr_model = lr.fit(va_users_items.transform(lr_train_df))
97 |
98 |
99 | va_users = VectorAssembler(
100 | inputCols=users_df.columns[:-1],
101 | outputCol='features_usr'
102 | )
103 |
104 | va_items = VectorAssembler(
105 | inputCols=items_df.columns[:-1],
106 | outputCol='features_itm'
107 | )
108 |
109 | cos_sim = CosineSimilatiry(
110 | inputCols=["features_usr", "features_itm"],
111 | outputCol="__cos_prob"
112 | )
113 |
114 | noise_resp = NoiseResponse(mu=0.5, sigma=0.2, outputCol='__noise_prob', seed=1234)
115 |
116 | param_resp = ParametricResponseFunction(
117 | inputCols=['__lr_prob', '__cos_prob', '__noise_prob'],
118 | outputCol='__proba',
119 | weights=[1/3, 1/3, 1/3]
120 | )
121 |
122 | br = BernoulliResponse(inputCol='__proba', outputCol='response')
123 |
124 | pipeline = PipelineModel(
125 | stages=[
126 | va_users_items,
127 | lr_model,
128 | vee,
129 | va_users,
130 | va_items,
131 | cos_sim,
132 | noise_resp,
133 | param_resp,
134 | br
135 | ]
136 | )
137 |
138 | from sim4rec.modules import Simulator, EvaluateMetrics
139 | from replay.metrics import NDCG
140 |
141 | sim = Simulator(
142 | user_gen=user_generator,
143 | item_gen=item_generator,
144 | user_key_col='user_id',
145 | item_key_col='item_id',
146 | spark_session=spark,
147 | data_dir=f'{CHECKPOINT_DIR}/cycle_load_test_{NUM_JOBS}',
148 | )
149 |
150 | evaluator = EvaluateMetrics(
151 | userKeyCol='user_id',
152 | itemKeyCol='item_id',
153 | predictionCol='relevance',
154 | labelCol='response',
155 | replay_label_filter=1.0,
156 | replay_metrics={NDCG() : 100}
157 | )
158 |
159 | indexer = Indexer(user_col='user_id', item_col='item_id')
160 | indexer.fit(users=syn_users, items=syn_items)
161 |
162 | ucb = UCB(sample=True)
163 | ucb.fit(log=indexer.transform(history_df_train.limit(1)))
164 |
165 | items_replay = indexer.transform(syn_items).cache()
166 |
167 | ucb_metrics = []
168 |
169 | time_list = []
170 | for i in range(30):
171 | cycle_time = {}
172 | iter_start = time.time()
173 |
174 | start = time.time()
175 | users = sim.sample_users(0.02).cache()
176 | users.count()
177 | cycle_time['sample_users_time'] = time.time() - start
178 |
179 | start = time.time()
180 | log = sim.get_log(users)
181 | if log is not None:
182 | log = indexer.transform(log).cache()
183 | else:
184 | log = indexer.transform(history_df_train.limit(1)).cache()
185 | log.count()
186 | cycle_time['get_log_time'] = time.time() - start
187 |
188 | start = time.time()
189 | recs_ucb = ucb.predict(
190 | log=log,
191 | k=100,
192 | users=indexer.transform(users),
193 | items=items_replay
194 | )
195 | recs_ucb = indexer.inverse_transform(recs_ucb).cache()
196 | recs_ucb.count()
197 | cycle_time['model_predict_time'] = time.time() - start
198 |
199 | start = time.time()
200 | resp_ucb = sim.sample_responses(
201 | recs_df=recs_ucb,
202 | user_features=users,
203 | item_features=syn_items,
204 | action_models=pipeline
205 | ).select('user_id', 'item_id', 'relevance', 'response').cache()
206 | resp_ucb.count()
207 | cycle_time['sample_responses_time'] = time.time() - start
208 |
209 | start = time.time()
210 | sim.update_log(resp_ucb, iteration=i)
211 | cycle_time['update_log_time'] = time.time() - start
212 |
213 | start = time.time()
214 | ucb_metrics.append(evaluator(resp_ucb))
215 | cycle_time['metrics_time'] = time.time() - start
216 |
217 | start = time.time()
218 | ucb._clear_cache()
219 | ucb_train_log = sim.log.cache()
220 | cycle_time['log_size'] = ucb_train_log.count()
221 | ucb.fit(
222 | log=indexer.transform(
223 | ucb_train_log\
224 | .select('user_id', 'item_id', 'response')\
225 | .withColumnRenamed('response', 'relevance')
226 | )
227 | )
228 | cycle_time['model_train'] = time.time() - start
229 |
230 | users.unpersist()
231 | if log is not None:
232 | log.unpersist()
233 | recs_ucb.unpersist()
234 | resp_ucb.unpersist()
235 | ucb_train_log.unpersist()
236 |
237 | cycle_time['iter_time'] = time.time() - iter_start
238 | cycle_time['iteration'] = i
239 | cycle_time['num_threads'] = NUM_JOBS
240 |
241 | time_list.append(cycle_time)
242 |
243 | print(f'Iteration {i} ended in {cycle_time["iter_time"]} seconds')
244 |
245 | items_replay.unpersist()
246 |
247 | import os
248 |
249 | if os.path.isfile(f'{MODELS_PATH}/cycle_time.csv'):
250 | pd.concat([pd.read_csv(f'{MODELS_PATH}/cycle_time.csv'), pd.DataFrame(time_list)]).to_csv(f'{MODELS_PATH}/cycle_time.csv', index=False)
251 | else:
252 | pd.DataFrame(time_list).to_csv(f'{MODELS_PATH}/cycle_time.csv', index=False)
253 |
--------------------------------------------------------------------------------
/experiments/transformers.py:
--------------------------------------------------------------------------------
1 | from sim4rec.response import ActionModelTransformer
2 | from pyspark.ml.param.shared import HasInputCol
3 | from pyspark.sql import DataFrame
4 | import pyspark.sql.functions as sf
5 | from pyspark.ml.param.shared import Params, Param, TypeConverters
6 | from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
7 |
8 | class HasMultiplierValue(Params):
9 | multiplierValue = Param(
10 | Params._dummy(),
11 | "multiplierValue",
12 | "Multiplier value parameter",
13 | typeConverter=TypeConverters.toFloat
14 | )
15 |
16 | def __init__(self):
17 | super().__init__()
18 |
19 | def setMultiplierValue(self, value):
20 | return self._set(multiplierValue=value)
21 |
22 | def getMultiplierValue(self):
23 | return self.getOrDefault(self.multiplierValue)
24 |
25 | class ModelCalibration(ActionModelTransformer,
26 | HasInputCol,
27 | HasMultiplierValue,
28 | DefaultParamsReadable,
29 | DefaultParamsWritable):
30 | def __init__(
31 | self,
32 | value : float = 0.0,
33 | inputCol : str = None,
34 | outputCol : str = None
35 | ):
36 | """
37 | Multiplies response function output by the chosen value.
38 | :param value: Multiplier value
39 | :param outputCol: Output column name
40 | """
41 |
42 | super().__init__(outputCol=outputCol)
43 |
44 | self._set(inputCol=inputCol)
45 | self._set(multiplierValue=value)
46 |
47 | def _transform(
48 | self,
49 | df : DataFrame
50 | ):
51 | value = self.getMultiplierValue()
52 | inputCol = self.getInputCol()
53 | outputColumn = self.getOutputCol()
54 |
55 |
56 | return df.withColumn(outputColumn, sf.lit(value)*sf.col(inputCol))
57 |
--------------------------------------------------------------------------------
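A minimal sketch of ModelCalibration in use (illustration only; `probs_df` and its '__proba' column are assumed):

    from transformers import ModelCalibration  # the experiments/transformers.py module above

    # scale an existing response probability column by a factor of 0.5
    calib = ModelCalibration(value=0.5, inputCol='__proba', outputCol='__proba_calibrated')
    calibrated_df = calib.transform(probs_df)
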
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "sim4rec"
3 | version = "0.0.2"
4 | description = "Simulator for recommendation algorithms"
5 | authors = ["Alexey Vasilev",
6 | "Anna Volodkevich",
7 | "Andrey Gurov",
8 | "Elizaveta Stavinova",
9 | "Anton Lysenko"]
10 | packages = [
11 | { include = "sim4rec" }
12 | ]
13 | readme = "README.md"
14 | repository = "https://github.com/sb-ai-lab/Sim4Rec"
15 |
16 | [tool.poetry.dependencies]
17 | python = ">=3.8, <3.10"
18 | pyarrow = "*"
19 | sdv = "0.15.0"
20 | torch = "*"
21 | pandas = "*"
22 | pyspark = ">=3.0"
23 | numpy = ">=1.20.0"
24 | scipy = "*"
25 |
26 | [tool.poetry.dev-dependencies]
27 | # visualization
28 | jupyter = "*"
29 | jupyterlab = "*"
30 | matplotlib = "*"
31 | # testing
32 | pytest-cov = "*"
33 | pycodestyle = "*"
34 | pylint = "*"
35 | # docs
36 | Sphinx = "*"
37 | sphinx-rtd-theme = "*"
38 | sphinx-autodoc-typehints = "*"
39 | ghp-import = "*"
40 |
41 | [build-system]
42 | requires = ["poetry-core>=1.0.0"]
43 | build-backend = "poetry.core.masonry.api"
44 |
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 |
2 | [pytest]
3 | minversion = 6.0
4 | log_cli = 1
5 | log_cli_level = ERROR
6 | log_cli_format = %(asctime)s [%(levelname)8s] [%(name)s] %(message)s (%(filename)s:%(lineno)s)
7 | log_cli_date_format = %H:%M:%S
8 |
9 | addopts =
10 | --capture=tee-sys -q
11 | -m 'not mycandidate'
12 |
13 | testpaths =
14 | tests
15 | filterwarnings =
16 | ignore::DeprecationWarning
17 | markers =
18 | mycandidate: test mycandidate api integration (disabled by default)
--------------------------------------------------------------------------------
/sim4rec/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sb-ai-lab/Sim4Rec/267b644932b0e120cc9887f0f3a16be38d5739b7/sim4rec/__init__.py
--------------------------------------------------------------------------------
/sim4rec/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .generator import (
2 | GeneratorBase,
3 | RealDataGenerator,
4 | SDVDataGenerator,
5 | CompositeGenerator
6 | )
7 | from .selectors import (
8 | ItemSelectionEstimator,
9 | ItemSelectionTransformer,
10 | CrossJoinItemEstimator,
11 | CrossJoinItemTransformer
12 | )
13 | from .simulator import Simulator
14 | from .embeddings import (
15 | EncoderEstimator,
16 | EncoderTransformer
17 | )
18 | from .evaluation import (
19 | evaluate_synthetic,
20 | EvaluateMetrics,
21 | ks_test,
22 | kl_divergence
23 | )
24 |
25 | __all__ = [
26 | 'GeneratorBase',
27 | 'RealDataGenerator',
28 | 'SDVDataGenerator',
29 | 'CompositeGenerator',
30 | 'ItemSelectionEstimator',
31 | 'ItemSelectionTransformer',
32 | 'CrossJoinItemEstimator',
33 | 'CrossJoinItemTransformer',
34 | 'Simulator',
35 | 'EncoderEstimator',
36 | 'EncoderTransformer',
37 | 'evaluate_synthetic',
38 | 'EvaluateMetrics',
39 | 'ks_test',
40 | 'kl_divergence'
41 | ]
42 |
--------------------------------------------------------------------------------
/sim4rec/modules/embeddings.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import pandas as pd
4 | import torch
5 | import torch.optim as opt
6 | import torch.nn.functional as F
7 | from torch.utils.data import DataLoader
8 |
9 | import pyspark.sql.types as st
10 | from pyspark.sql import DataFrame
11 | from pyspark.ml import Transformer, Estimator
12 | from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
13 | from pyspark.ml.param.shared import HasInputCols, HasOutputCols
14 |
15 | from sim4rec.params import HasDevice, HasSeed
16 |
17 |
18 | class Encoder(torch.nn.Module):
19 | """
20 | Encoder layer
21 | """
22 | def __init__(
23 | self,
24 | input_dim : int,
25 | hidden_dim : int,
26 | latent_dim : int
27 | ):
28 | super().__init__()
29 |
30 | input_dims = [input_dim, hidden_dim, latent_dim]
31 | self._layers = torch.nn.ModuleList([
32 | torch.nn.Linear(_in, _out)
33 | for _in, _out in zip(input_dims[:-1], input_dims[1:])
34 | ])
35 |
36 | def forward(self, X):
37 | """
38 | Performs forward pass through layer
39 | """
40 | X = F.normalize(X, p=2)
41 | for layer in self._layers[:-1]:
42 | X = layer(X)
43 | X = F.leaky_relu(X)
44 |
45 | X = self._layers[-1](X)
46 |
47 | return X
48 |
49 |
50 | class Decoder(torch.nn.Module):
51 | """
52 | Decoder layer
53 | """
54 | def __init__(
55 | self,
56 | input_dim : int,
57 | hidden_dim : int,
58 | latent_dim : int
59 | ):
60 | super().__init__()
61 |
62 | input_dims = [latent_dim, hidden_dim, input_dim]
63 | self._layers = torch.nn.ModuleList([
64 | torch.nn.Linear(_in, _out)
65 | for _in, _out in zip(input_dims[:-1], input_dims[1:])
66 | ])
67 |
68 | def forward(self, X):
69 | """
70 | Performs forward pass through layer
71 | """
72 | for layer in self._layers[:-1]:
73 | X = layer(X)
74 | X = F.leaky_relu(X)
75 |
76 | X = self._layers[-1](X)
77 |
78 | return X
79 |
80 |
81 | # pylint: disable=too-many-ancestors
82 | class EncoderEstimator(Estimator,
83 | HasInputCols,
84 | HasOutputCols,
85 | HasDevice,
86 | HasSeed,
87 | DefaultParamsReadable,
88 | DefaultParamsWritable):
89 | """
90 | Estimator for encoder part of the autoencoder pipeline. Trains
91 | the encoder to process incoming data into latent representation
92 | """
93 | # pylint: disable=too-many-arguments
94 | def __init__(
95 | self,
96 | inputCols : List[str],
97 | outputCols : List[str],
98 | hidden_dim : int,
99 | lr : float,
100 | batch_size : int,
101 | num_loader_workers : int,
102 | max_iter : int = 100,
103 | device_name : str = 'cpu',
104 | seed : int = None
105 | ):
106 | """
107 | :param inputCols: Column names to process
108 | :param outputCols: List of output column names per latent coordinate.
109 | The length of outputCols will determine the embedding dimension size
110 | :param hidden_dim: Size of hidden layers
111 | :param lr: Learning rate
112 | :param batch_size: Batch size during training process
113 | :param num_loader_workers: Number of cpus to use for data loader
114 | :param max_iter: Maximum number of iterations, defaults to 100
115 | :param device_name: PyTorch device name, defaults to 'cpu'
116 | """
117 |
118 | super().__init__()
119 |
120 | self._set(inputCols=inputCols, outputCols=outputCols)
121 | self.setDevice(device_name)
122 | self.setSeed(seed)
123 |
124 | self._input_dim = len(inputCols)
125 | self._hidden_dim = hidden_dim
126 | self._latent_dim = len(outputCols)
127 |
128 | self._lr = lr
129 | self._batch_size = batch_size
130 | self._num_loader_workers = num_loader_workers
131 | self._max_iter = max_iter
132 |
133 | # pylint: disable=too-many-locals, not-callable
134 | def _fit(
135 | self,
136 | dataset : DataFrame
137 | ):
138 | inputCols = self.getInputCols()
139 | outputCols = self.getOutputCols()
140 | device_name = self.getDevice()
141 | seed = self.getSeed()
142 | device = torch.device(self.getDevice())
143 | # pylint: disable=not-an-iterable
144 | X = dataset.select(*inputCols).toPandas().values
145 |
146 | torch.manual_seed(torch.seed() if seed is None else seed)
147 |
148 | train_loader = DataLoader(X, batch_size=self._batch_size,
149 | shuffle=True, num_workers=self._num_loader_workers)
150 |
151 | encoder = Encoder(
152 | input_dim=self._input_dim,
153 | hidden_dim=self._hidden_dim,
154 | latent_dim=self._latent_dim
155 | )
156 | decoder = Decoder(
157 | input_dim=self._input_dim,
158 | hidden_dim=self._hidden_dim,
159 | latent_dim=self._latent_dim
160 | )
161 |
162 | model = torch.nn.Sequential(encoder, decoder).to(torch.device(self.getDevice()))
163 |
164 | optimizer = opt.Adam(model.parameters(), lr=self._lr)
165 | crit = torch.nn.MSELoss()
166 |
167 | for _ in range(self._max_iter):
168 | loss = 0
169 | for X_batch in train_loader:
170 | X_batch = X_batch.float().to(device)
171 |
172 | optimizer.zero_grad()
173 |
174 | pred = model(X_batch)
175 | train_loss = crit(pred, X_batch)
176 |
177 | train_loss.backward()
178 | optimizer.step()
179 |
180 | loss += train_loss.item()
181 |
182 | torch.manual_seed(torch.seed())
183 |
184 | return EncoderTransformer(
185 | inputCols=inputCols,
186 | outputCols=outputCols,
187 | encoder=encoder,
188 | device_name=device_name
189 | )
190 |
191 |
192 | class EncoderTransformer(Transformer,
193 | HasInputCols,
194 | HasOutputCols,
195 | HasDevice,
196 | DefaultParamsReadable,
197 | DefaultParamsWritable):
198 | """
199 | Encoder transformer that transforms incoming columns into latent
200 | representation. Output data will be appended to dataframe and
201 | named according to outputCols parameter
202 | """
203 | def __init__(
204 | self,
205 | inputCols : List[str],
206 | outputCols : List[str],
207 | encoder : Encoder,
208 | device_name : str = 'cpu'
209 | ):
210 | """
211 | :param inputCols: Column names to process
212 | :param outputCols: List of output column names per latent coordinate.
213 | The length of outputCols must be equal to embedding dimension of
214 | a trained encoder
215 | :param encoder: Trained encoder
216 | :param device_name: PyTorch device name, defaults to 'cpu'
217 | """
218 |
219 | super().__init__()
220 |
221 | self._set(inputCols=inputCols, outputCols=outputCols)
222 | self._encoder = encoder
223 | self.setDevice(device_name)
224 |
225 | def setDevice(self, value):
226 | super().setDevice(value)
227 |
228 | self._encoder.to(torch.device(value))
229 |
230 | # pylint: disable=not-callable
231 | def _transform(
232 | self,
233 | dataset : DataFrame
234 | ):
235 | inputCols = self.getInputCols()
236 | outputCols = self.getOutputCols()
237 | device = torch.device(self.getDevice())
238 |
239 | encoder = self._encoder
240 |
241 | @torch.no_grad()
242 | def encode(iterator):
243 | for pdf in iterator:
244 | X = torch.tensor(pdf.loc[:, inputCols].values).float().to(device)
245 | yield pd.DataFrame(
246 | data=encoder(X).cpu().numpy(),
247 | columns=outputCols
248 | )
249 |
250 | schema = st.StructType(
251 | # pylint: disable=not-an-iterable
252 | [st.StructField(c, st.FloatType()) for c in outputCols]
253 | )
254 |
255 | return dataset.mapInPandas(encode, schema)
256 |
--------------------------------------------------------------------------------
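A short sketch of the autoencoder embedding flow defined above (assumption: `users_df` is a Spark dataframe with numeric columns attr_0 .. attr_9):

    from sim4rec.modules import EncoderEstimator

    estimator = EncoderEstimator(
        inputCols=[f'attr_{i}' for i in range(10)],
        outputCols=[f'emb_{i}' for i in range(4)],  # 4-dimensional embedding
        hidden_dim=16,
        lr=1e-3,
        batch_size=128,
        num_loader_workers=2,
        max_iter=50,
        seed=42,
    )
    encoder = estimator.fit(users_df)       # returns an EncoderTransformer
    emb_df = encoder.transform(users_df)    # dataframe containing only emb_0 .. emb_3
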
/sim4rec/modules/evaluation.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from typing import List, Union, Dict, Optional
3 |
4 | import numpy as np
5 | from scipy.stats import kstest
6 | # TL;DR scipy.special is a C library, pylint needs python source code
7 | # https://github.com/pylint-dev/pylint/issues/3703
8 | # pylint: disable=no-name-in-module
9 | from scipy.special import kl_div
10 |
11 | from pyspark.sql import DataFrame
12 | from pyspark.ml.evaluation import (
13 | BinaryClassificationEvaluator,
14 | RegressionEvaluator,
15 | MulticlassClassificationEvaluator
16 | )
17 |
18 | from sdv.evaluation import evaluate
19 |
20 |
21 | def evaluate_synthetic(
22 | synth_df : DataFrame,
23 | real_df : DataFrame
24 | ) -> dict:
25 | """
26 | Evaluates the quality of synthetic data against real. The following
27 | metrics will be calculated:
28 |
29 | - LogisticDetection: The metric evaluates how hard it is to distinguish the synthetic
30 | data from the real data by using a Logistic regression model
31 | - SVCDetection: The metric evaluates how hard it is to distinguish the synthetic data
32 | from the real data by using a C-Support Vector Classification model
33 | - KSTest: This metric uses the two-sample Kolmogorov-Smirnov test to compare
34 | the distributions of continuous columns using the empirical CDF
35 | - ContinuousKLDivergence: This approximates the KL divergence by binning the continuous values
36 | to turn them into categorical values and then computing the relative entropy
37 |
38 | :param synth_df: Synthetic data without any identifiers
39 | :param real_df: Real data without any identifiers
40 | :return: Dictionary with metrics on synthetic data quality
41 | """
42 |
43 | result = evaluate(
44 | synthetic_data=synth_df.toPandas(),
45 | real_data=real_df.toPandas(),
46 | metrics=[
47 | 'LogisticDetection',
48 | 'SVCDetection',
49 | 'KSTest',
50 | 'ContinuousKLDivergence'
51 | ],
52 | aggregate=False
53 | )
54 |
55 | return {
56 | row['metric'] : row['normalized_score']
57 | for _, row in result.iterrows()
58 | }
59 |
60 |
61 | def ks_test(
62 | df : DataFrame,
63 | predCol : str,
64 | labelCol : str
65 | ) -> float:
66 | """
67 | Kolmogorov-Smirnov test on two dataframe columns
68 |
69 | :param df: Dataframe with two target columns
70 | :param predCol: Column name with values to test
71 | :param labelCol: Column name with values to test against
72 | :return: Result of KS test
73 | """
74 |
75 | pdf = df.select(predCol, labelCol).toPandas()
76 | rvs, cdf = pdf[predCol].values, pdf[labelCol].values
77 |
78 | return kstest(rvs, cdf).statistic
79 |
80 |
81 | def kl_divergence(
82 | df : DataFrame,
83 | predCol : str,
84 | labelCol : str
85 | ) -> float:
86 | """
87 | Normalized Kullback–Leibler divergence on two dataframe columns. The normalization is
88 | as follows:
89 |
90 | .. math::
91 | \\frac{1}{1 + KL\_div}
92 |
93 | :param df: Dataframe with two target columns
94 | :param predCol: First column name
95 | :param labelCol: Second column name
96 | :return: Result of KL divergence
97 | """
98 |
99 | pdf = df.select(predCol, labelCol).toPandas()
100 | predicted, ground_truth = pdf[predCol].values, pdf[labelCol].values
101 |
102 | f_obs, edges = np.histogram(ground_truth)
103 | f_exp, _ = np.histogram(predicted, bins=edges)
104 |
105 | f_obs = f_obs.flatten() + 1e-5
106 | f_exp = f_exp.flatten() + 1e-5
107 |
108 | return 1 / (1 + np.sum(kl_div(f_obs, f_exp)))
109 |
110 |
111 | # pylint: disable=too-few-public-methods
112 | class EvaluateMetrics(ABC):
113 | """
114 | Recommendation systems and response function metric evaluator class.
115 | The class allows you to evaluate the quality of a response function on
116 | historical data or a recommender system on historical data or based on
117 | the results of an experiment in a simulator. Provides simultaneous
118 | calculation of several metrics using metrics from the Spark MLlib library.
119 | A created instance is callable on a dataframe with ``user_id, item_id,
120 | predicted relevance/response, true relevance/response`` format, which
121 |     you can usually retrieve from the simulator's sample_responses() or log data
122 | with recommendation algorithm scores.
123 | """
124 |
125 | REGRESSION_METRICS = set(['rmse', 'mse', 'r2', 'mae', 'var'])
126 | MULTICLASS_METRICS = set([
127 | 'f1', 'accuracy', 'weightedPrecision', 'weightedRecall',
128 | 'weightedTruePositiveRate', 'weightedFalsePositiveRate',
129 | 'weightedFMeasure', 'truePositiveRateByLabel', 'falsePositiveRateByLabel',
130 | 'precisionByLabel', 'recallByLabel', 'fMeasureByLabel',
131 | 'logLoss', 'hammingLoss'
132 | ])
133 | BINARY_METRICS = set(['areaUnderROC', 'areaUnderPR'])
134 |
135 | # pylint: disable=too-many-arguments
136 | def __init__(
137 | self,
138 | userKeyCol : str,
139 | itemKeyCol : str,
140 | predictionCol : str,
141 | labelCol : str,
142 | mllib_metrics : Optional[Union[str, List[str]]] = None
143 | ):
144 | """
145 | :param userKeyCol: User identifier column name
146 | :param itemKeyCol: Item identifier column name
147 | :param predictionCol: Predicted scores column name
148 | :param labelCol: True label column name
149 | :param mllib_metrics: Metrics to calculate from spark's mllib. See
150 | REGRESSION_METRICS, MULTICLASS_METRICS, BINARY_METRICS for available
151 | values, defaults to None
152 | """
153 |
154 | super().__init__()
155 |
156 | self._userKeyCol = userKeyCol
157 | self._itemKeyCol = itemKeyCol
158 | self._predictionCol = predictionCol
159 | self._labelCol = labelCol
160 |
161 | if isinstance(mllib_metrics, str):
162 | mllib_metrics = [mllib_metrics]
163 |
164 | if mllib_metrics is None:
165 | mllib_metrics = []
166 |
167 | self._mllib_metrics = mllib_metrics
168 |
169 | def __call__(
170 | self,
171 | df : DataFrame
172 | ) -> Dict[str, float]:
173 | """
174 | Performs metrics calculations on passed dataframe
175 |
176 | :param df: Spark dataframe with userKeyCol, itemKeyCol, predictionCol
177 | and labelCol columns
178 | :return: Dictionary with metrics
179 | """
180 |
181 | df = df.withColumnRenamed(self._userKeyCol, 'user_idx')\
182 | .withColumnRenamed(self._itemKeyCol, 'item_idx')
183 |
184 | result = {}
185 |
186 | for m in self._mllib_metrics:
187 | evaluator = self._get_evaluator(m)
188 | result[m] = evaluator.evaluate(df)
189 |
190 | return result
191 |
192 | def _reg_or_multiclass_params(self):
193 | return {'predictionCol' : self._predictionCol, 'labelCol' : self._labelCol}
194 |
195 | def _binary_params(self):
196 | return {'rawPredictionCol' : self._predictionCol, 'labelCol' : self._labelCol}
197 |
198 | def _get_evaluator(self, metric):
199 | if metric in self.REGRESSION_METRICS:
200 | return RegressionEvaluator(
201 | metricName=metric, **self._reg_or_multiclass_params())
202 | if metric in self.BINARY_METRICS:
203 | return BinaryClassificationEvaluator(
204 | metricName=metric, **self._binary_params())
205 | if metric in self.MULTICLASS_METRICS:
206 | return MulticlassClassificationEvaluator(
207 | metricName=metric, **self._reg_or_multiclass_params())
208 |
209 | raise ValueError(f'Non existing metric was passed: {metric}')
210 |
--------------------------------------------------------------------------------
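A sketch of EvaluateMetrics applied to a responses dataframe (`resp_df` is assumed to use sim4rec's default column names, e.g. the output of Simulator.sample_responses()):

    from sim4rec.modules import EvaluateMetrics

    evaluator = EvaluateMetrics(
        userKeyCol='user_idx',
        itemKeyCol='item_idx',
        predictionCol='relevance',
        labelCol='response',
        mllib_metrics=['areaUnderROC', 'mse'],
    )
    metrics = evaluator(resp_df)  # {'areaUnderROC': ..., 'mse': ...}
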
/sim4rec/modules/selectors.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=no-member,unused-argument,too-many-ancestors,abstract-method
2 | from pyspark.sql import functions as sf
3 | from pyspark.sql import DataFrame
4 | from pyspark.ml import Transformer, Estimator
5 | from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
6 | from pyspark import keyword_only
7 | from sim4rec.params import HasUserKeyColumn, HasItemKeyColumn, HasSeed, HasSeedSequence
8 |
9 |
10 | class ItemSelectionEstimator(Estimator,
11 | HasUserKeyColumn,
12 | HasItemKeyColumn,
13 | DefaultParamsReadable,
14 | DefaultParamsWritable):
15 | """
16 | Base class for item selection estimator
17 | """
18 | @keyword_only
19 | def __init__(
20 | self,
21 | userKeyColumn : str = None,
22 | itemKeyColumn : str = None
23 | ):
24 | super().__init__()
25 | self.setParams(**self._input_kwargs)
26 |
27 | @keyword_only
28 | def setParams(
29 | self,
30 | userKeyColumn : str = None,
31 | itemKeyColumn : str = None
32 | ):
33 | """
34 | Sets Estimator parameters
35 | """
36 | return self._set(**self._input_kwargs)
37 |
38 |
39 | class ItemSelectionTransformer(Transformer,
40 | HasUserKeyColumn,
41 | HasItemKeyColumn,
42 | DefaultParamsReadable,
43 | DefaultParamsWritable):
44 | """
45 | Base class for item selection transformer. transform()
46 | will be used to create user-item pairs
47 | """
48 | @keyword_only
49 | def __init__(
50 | self,
51 | userKeyColumn : str = None,
52 | itemKeyColumn : str = None
53 | ):
54 | super().__init__()
55 | self.setParams(**self._input_kwargs)
56 |
57 | @keyword_only
58 | def setParams(
59 | self,
60 | userKeyColumn : str = None,
61 | itemKeyColumn : str = None
62 | ):
63 | """
64 | Sets Transformer parameters
65 | """
66 | self._set(**self._input_kwargs)
67 |
68 |
69 | class CrossJoinItemEstimator(ItemSelectionEstimator, HasSeed):
70 | """
71 | Assigns k items for every user from random items subsample
72 | """
73 | def __init__(
74 | self,
75 | k : int,
76 | userKeyColumn : str = None,
77 | itemKeyColumn : str = None,
78 | seed : int = None
79 | ):
80 | """
81 | :param k: Number of items for every user
82 | :param userKeyColumn: Users identifier column, defaults to None
83 | :param itemKeyColumn: Items identifier column, defaults to None
84 | :param seed: Random state seed, defaults to None
85 | """
86 |
87 | super().__init__(userKeyColumn=userKeyColumn,
88 | itemKeyColumn=itemKeyColumn)
89 |
90 | self.setSeed(seed)
91 |
92 | self._k = k
93 |
94 | def _fit(
95 | self,
96 | dataset : DataFrame
97 | ):
98 | """
99 | Fits estimator with items dataframe
100 |
101 |         :param dataset: Items dataframe
102 | :returns: CrossJoinItemTransformer instance
103 | """
104 |
105 | userKeyColumn = self.getUserKeyColumn()
106 | itemKeyColumn = self.getItemKeyColumn()
107 | seed = self.getSeed()
108 |
109 | if itemKeyColumn not in dataset.columns:
110 | raise ValueError(f'Dataframe has no {itemKeyColumn} column')
111 |
112 | return CrossJoinItemTransformer(
113 | item_df=dataset,
114 | k=self._k,
115 | userKeyColumn=userKeyColumn,
116 | itemKeyColumn=itemKeyColumn,
117 | seed=seed
118 | )
119 |
120 |
121 | class CrossJoinItemTransformer(ItemSelectionTransformer, HasSeedSequence):
122 | """
123 | Assigns k items for every user from random items subsample
124 | """
125 | # pylint: disable=too-many-arguments
126 | def __init__(
127 | self,
128 | item_df : DataFrame,
129 | k : int,
130 | userKeyColumn : str = None,
131 | itemKeyColumn : str = None,
132 | seed : int = None
133 | ):
134 | super().__init__(userKeyColumn=userKeyColumn,
135 | itemKeyColumn=itemKeyColumn)
136 |
137 | self.initSeedSequence(seed)
138 |
139 | self._item_df = item_df
140 | self._k = k
141 |
142 | def _transform(
143 | self,
144 | dataset : DataFrame
145 | ):
146 | """
147 |         Takes a users dataframe and assigns the defined number of items
148 |
149 |         :param dataset: Users dataframe
150 | :returns: Users cross join on random items subsample
151 | """
152 |
153 | userKeyColumn = self.getUserKeyColumn()
154 | itemKeyColumn = self.getItemKeyColumn()
155 | seed = self.getNextSeed()
156 |
157 | if userKeyColumn not in dataset.columns:
158 | raise ValueError(f'Dataframe has no {userKeyColumn} column')
159 |
160 | random_items = self._item_df.orderBy(sf.rand(seed=seed))\
161 | .limit(self._k)
162 |
163 | return dataset.select(userKeyColumn)\
164 | .crossJoin(random_items.select(itemKeyColumn))
165 |
--------------------------------------------------------------------------------
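A sketch of candidate generation with CrossJoinItemEstimator (`users_df` and `items_df` with 'user_idx'/'item_idx' columns are assumed):

    from sim4rec.modules import CrossJoinItemEstimator

    selector = CrossJoinItemEstimator(
        k=10,
        userKeyColumn='user_idx',
        itemKeyColumn='item_idx',
        seed=1234,
    )
    selector_model = selector.fit(items_df)     # CrossJoinItemTransformer bound to items_df
    pairs = selector_model.transform(users_df)  # each user crossed with 10 randomly sampled items
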
/sim4rec/modules/simulator.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | from abc import ABC
3 | from typing import Tuple, Union, Optional
4 |
5 | from pyspark.sql import SparkSession
6 | from pyspark.sql import DataFrame
7 | from pyspark.ml import Transformer, PipelineModel
8 |
9 | from sim4rec.utils.session_handler import State
10 | from sim4rec.modules.generator import GeneratorBase
11 |
12 |
13 | # pylint: disable=too-many-instance-attributes
14 | class Simulator(ABC):
15 | """
16 | Simulator for recommendation systems, which uses the users
17 | and items data passed to it, to simulate the users responses
18 | to recommended items
19 | """
20 |
21 | ITER_COLUMN = '__iter'
22 | DEFAULT_LOG_FILENAME = 'log.parquet'
23 |
24 | # pylint: disable=too-many-arguments
25 | def __init__(
26 | self,
27 | user_gen : GeneratorBase,
28 | item_gen : GeneratorBase,
29 | data_dir : str,
30 | log_df : DataFrame = None,
31 | user_key_col : str = 'user_idx',
32 | item_key_col : str = 'item_idx',
33 | spark_session : SparkSession = None
34 | ):
35 | """
36 | :param user_gen: Users data generator instance
37 | :param item_gen: Items data generator instance
38 | :param log_df: The history log with user-item pairs with other
39 | necessary fields. During the simulation the results will be
40 | appended to this log on update_log() call, defaults to None
41 | :param user_key_col: User identifier column name, defaults
42 | to 'user_idx'
43 | :param item_key_col: Item identifier column name, defaults
44 | to 'item_idx'
45 | :param data_dir: Directory name to save simulator data
46 | :param spark_session: Spark session to use, defaults to None
47 | """
48 |
49 | self._spark = spark_session if spark_session is not None else State().session
50 |
51 | self._user_key_col = user_key_col
52 | self._item_key_col = item_key_col
53 | self._user_gen = user_gen
54 | self._item_gen = item_gen
55 |
56 | if data_dir is None:
57 | raise ValueError('Pass directory name as `data_dir` parameter')
58 |
59 | self._data_dir = data_dir
60 | pathlib.Path(self._data_dir).mkdir(parents=True, exist_ok=False)
61 |
62 | self._log_filename = self.DEFAULT_LOG_FILENAME
63 |
64 | self._log = None
65 | self._log_schema = None
66 | if log_df is not None:
67 | self.update_log(log_df, iteration='start')
68 |
69 | @property
70 | def log(self):
71 | """
72 | Returns log
73 | """
74 | return self._log
75 |
76 | @property
77 | def data_dir(self):
78 | """
79 | Returns directory with saved simulator data
80 | """
81 | return self._data_dir
82 |
83 | @data_dir.setter
84 | def data_dir(self, value):
85 | self._data_dir = value
86 |
87 | @property
88 | def log_filename(self):
89 | """
90 | Returns name of log file
91 | """
92 | return self._log_filename
93 |
94 | @log_filename.setter
95 | def log_filename(self, value):
96 | self._log_filename = value
97 |
98 | def clear_log(
99 | self
100 | ) -> None:
101 | """
102 | Clears the log
103 | """
104 |
105 | self._log = None
106 | self._log_schema = None
107 |
108 | @staticmethod
109 | def _check_names_and_types(df1_schema, df2_schema):
110 | """
111 |         Check if column names and their types are equal for two schemas.
112 | `Nullable` parameter is not compared.
113 |
114 | """
115 | df1_schema_s = sorted(
116 | [(x.name, x.dataType) for x in df1_schema],
117 | key=lambda x: (x[0], x[1])
118 | )
119 | df2_schema_s = sorted(
120 | [(x.name, x.dataType) for x in df2_schema],
121 | key=lambda x: (x[0], x[1])
122 | )
123 | names_diff = set(df1_schema_s).symmetric_difference(set(df2_schema_s))
124 |
125 | if names_diff:
126 | raise ValueError(
127 | f'Columns of two dataframes are different.\nDifferences: \n'
128 | f'In the first dataframe:\n'
129 | f'{[name_type for name_type in df1_schema_s if name_type in names_diff]}\n'
130 | f'In the second dataframe:\n'
131 | f'{[name_type for name_type in df2_schema_s if name_type in names_diff]}'
132 | )
133 |
134 | def update_log(
135 | self,
136 | log : DataFrame,
137 | iteration : Union[int, str]
138 | ) -> None:
139 | """
140 | Appends the passed log to the existing one
141 |
142 | :param log: The log with user-item pairs with their respective
143 | necessary fields. If there was no log before this: remembers
144 | the log schema, to which the future logs will be compared.
145 | To reset current log and the schema see clear_log()
146 | :param iteration: Iteration label or index
147 | """
148 |
149 | if self._log_schema is None:
150 | self._log_schema = log.schema.fields
151 | else:
152 | self._check_names_and_types(self._log_schema, log.schema)
153 |
154 | write_path = str(
155 | pathlib.Path(self._data_dir)
156 | .joinpath(f'{self.log_filename}/{self.ITER_COLUMN}={iteration}')
157 | )
158 | log.write.parquet(write_path)
159 |
160 | read_path = str(pathlib.Path(self._data_dir).joinpath(f'{self.log_filename}'))
161 | self._log = self._spark.read.parquet(read_path)
162 |
163 | def sample_users(
164 | self,
165 | frac_users : float
166 | ) -> DataFrame:
167 | """
168 | Samples a fraction of random users
169 |
170 | :param frac_users: Fractions of users to sample from user generator
171 | :returns: Sampled users dataframe
172 | """
173 |
174 | return self._user_gen.sample(frac_users)
175 |
176 | def sample_items(
177 | self,
178 | frac_items : float
179 | ) -> DataFrame:
180 | """
181 | Samples a fraction of random items
182 |
183 | :param frac_items: Fractions of items to sample from item generator
184 |         :returns: Sampled items dataframe
185 | """
186 |
187 | return self._item_gen.sample(frac_items)
188 |
189 | def get_log(
190 | self,
191 | user_df : DataFrame
192 | ) -> DataFrame:
193 | """
194 | Returns log for users listed in passed users' dataframe
195 |
196 | :param user_df: Dataframe with user identifiers to get log for
197 | :return: Users' history log. Will return None, if there is no log data
198 | """
199 |
200 | if self.log is not None:
201 | return self.log.join(
202 | user_df, on=self._user_key_col, how='leftsemi'
203 | )
204 |
205 | return None
206 |
207 | def get_user_items(
208 | self,
209 | user_df : DataFrame,
210 | selector : Transformer
211 | ) -> Tuple[DataFrame, DataFrame]:
212 | """
213 |         Forms candidate pairs to pass to the recommendation algorithm based
214 | on the provided users
215 |
216 | :param user_df: Users dataframe with features and identifiers
217 | :param selector: Transformer to use for creating user-item pairs
218 | :returns: Tuple of user-item pairs and log dataframes which will
219 | be used by recommendation algorithm. Will return None as a log,
220 | if there is no log data
221 | """
222 |
223 | log = self.get_log(user_df)
224 | pairs = selector.transform(user_df)
225 |
226 | return pairs, log
227 |
228 | def sample_responses(
229 | self,
230 | recs_df : DataFrame,
231 | action_models : PipelineModel,
232 | user_features : Optional[DataFrame] = None,
233 | item_features : Optional[DataFrame] = None
234 | ) -> DataFrame:
235 | """
236 | Simulates the actions users took on their recommended items
237 |
238 | :param recs_df: Dataframe with recommendations. Must contain
239 | user's and item's identifier columns. Other columns will
240 | be ignored
241 | :param user_features: Users dataframe with features and identifiers,
242 | can be None
243 | :param item_features: Items dataframe with features and identifiers,
244 | can be None
245 | :param action_models: Spark pipeline to evaluate responses
246 | :returns: DataFrame with user-item pairs and the respective actions
247 | """
248 |
249 | if user_features is not None:
250 | recs_df = recs_df.join(user_features, self._user_key_col, 'left')
251 |
252 | if item_features is not None:
253 | recs_df = recs_df.join(item_features, self._item_key_col, 'left')
254 |
255 | return action_models.transform(recs_df)
256 |
--------------------------------------------------------------------------------
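A compact sketch of one simulation step through the Simulator API above (a full loop is shown in experiments/simulator_time.py; `sim`, `selector_model` and `response_pipeline` are assumed to be constructed already, and `my_recommender.predict()` stands in for an arbitrary recommender call):

    users = sim.sample_users(0.1).cache()
    pairs, log = sim.get_user_items(users, selector_model)  # candidate pairs + history for these users
    recs = my_recommender.predict(pairs)                    # hypothetical: score the candidate pairs
    responses = sim.sample_responses(
        recs_df=recs,
        action_models=response_pipeline,
        user_features=users,
    )
    sim.update_log(responses, iteration=0)
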
/sim4rec/params/__init__.py:
--------------------------------------------------------------------------------
1 | from .params import (
2 | HasUserKeyColumn,
3 | HasItemKeyColumn,
4 | HasSeed,
5 | HasWeights,
6 | HasMean,
7 | HasStandardDeviation,
8 | HasClipNegative,
9 | HasConstantValue,
10 | HasLabel,
11 | HasDevice,
12 | HasDataSize,
13 | HasParallelizationLevel,
14 | HasSeedSequence
15 | )
16 |
17 | __all__ = [
18 | 'HasUserKeyColumn',
19 | 'HasItemKeyColumn',
20 | 'HasSeed',
21 | 'HasWeights',
22 | 'HasMean',
23 | 'HasStandardDeviation',
24 | 'HasClipNegative',
25 | 'HasConstantValue',
26 | 'HasLabel',
27 | 'HasDevice',
28 | 'HasDataSize',
29 | 'HasParallelizationLevel',
30 | 'HasSeedSequence'
31 | ]
32 |
--------------------------------------------------------------------------------
/sim4rec/params/params.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import numpy as np
3 | from pyspark.ml.param.shared import Params, Param, TypeConverters
4 |
5 |
6 | class HasUserKeyColumn(Params):
7 | """
8 | Controls user identifier column name
9 | """
10 |
11 | userKeyColumn = Param(
12 | Params._dummy(),
13 | "userKeyColumn",
14 | "User identifier column name",
15 | typeConverter=TypeConverters.toString
16 | )
17 |
18 | def setUserKeyColumn(self, value):
19 | """
20 |         Sets user identifier column name
21 |
22 | :param value: new column name
23 | """
24 | return self._set(userKeyColumn=value)
25 |
26 | def getUserKeyColumn(self):
27 | """
28 |         Returns user identifier column name
29 | """
30 | return self.getOrDefault(self.userKeyColumn)
31 |
32 |
33 | class HasItemKeyColumn(Params):
34 | """
35 | Controls item identifier column name
36 | """
37 |
38 | itemKeyColumn = Param(
39 | Params._dummy(),
40 | "itemKeyColumn",
41 | "Item identifier column name",
42 | typeConverter=TypeConverters.toString
43 | )
44 |
45 | def setItemKeyColumn(self, value):
46 | """
47 |         Sets item identifier column name
48 |
49 | :param value: new column name
50 | """
51 | return self._set(itemKeyColumn=value)
52 |
53 | def getItemKeyColumn(self):
54 | """
55 |         Returns item identifier column name
56 | """
57 | return self.getOrDefault(self.itemKeyColumn)
58 |
59 |
60 | class HasSeed(Params):
61 | """
62 | Controls random state seed
63 | """
64 |
65 | seed = Param(
66 | Params._dummy(),
67 | "seed",
68 | "Random state seed",
69 | typeConverter=TypeConverters.toInt
70 | )
71 |
72 | def setSeed(self, value):
73 | """
74 | Changes random state seed
75 |
76 | :param value: new random state seed
77 | """
78 | return self._set(seed=value)
79 |
80 | def getSeed(self):
81 | """
82 |         Returns random state seed
83 | """
84 | return self.getOrDefault(self.seed)
85 |
86 |
87 | class HasSeedSequence(Params):
88 | """
89 | Controls random state seed of sequence
90 | """
91 | _rng : np.random.Generator
92 |
93 | current_seed = Param(
94 | Params._dummy(),
95 | "current_seed",
96 | "Random state seed sequence",
97 | typeConverter=TypeConverters.toInt
98 | )
99 |
100 | init_seed = Param(
101 | Params._dummy(),
102 | "init_seed",
103 | "Sequence initial seed",
104 | typeConverter=TypeConverters.toInt
105 | )
106 |
107 | def initSeedSequence(self, value):
108 | """
109 | Sets initial random state seed of sequence
110 |
111 | :param value: new initial random state seed of sequence
112 | """
113 | self._rng = np.random.default_rng(value)
114 | return self._set(
115 | init_seed=value if value is not None else -1,
116 | current_seed=self._rng.integers(0, sys.maxsize)
117 | )
118 |
119 | def getInitSeed(self):
120 | """
121 | Returns initial random state seed of sequence
122 | """
123 | value = self.getOrDefault(self.init_seed)
124 | return None if value == -1 else value
125 |
126 | def getNextSeed(self):
127 | """
128 | Returns current random state seed of sequence
129 | """
130 | seed = self.getOrDefault(self.current_seed)
131 | self._set(current_seed=self._rng.integers(0, sys.maxsize))
132 | return seed
133 |
134 |
135 | class HasWeights(Params):
136 | """
137 | Controls weights for models ensemble
138 | """
139 |
140 | weights = Param(
141 | Params._dummy(),
142 | "weights",
143 | "Weights for models ensemble",
144 | typeConverter=TypeConverters.toListFloat
145 | )
146 |
147 | def setWeights(self, value):
148 | """
149 | Changes weights for models ensemble
150 |
151 | :param value: new weights
152 | """
153 | return self._set(weights=value)
154 |
155 | def getWeights(self):
156 | """
157 |         Returns weights for models ensemble
158 | """
159 | return self.getOrDefault(self.weights)
160 |
161 |
162 | class HasMean(Params):
163 | """
164 | Controls mean parameter of normal distribution
165 | """
166 |
167 | mean = Param(
168 | Params._dummy(),
169 | "mean",
170 | "Mean parameter of normal distribution",
171 | typeConverter=TypeConverters.toFloat
172 | )
173 |
174 | def setMean(self, value):
175 | """
176 | Changes mean parameter of normal distribution
177 |
178 | :param value: new value of mean parameter
179 | """
180 | return self._set(mean=value)
181 |
182 | def getMean(self):
183 | """
184 | Returns mean parameter
185 | """
186 | return self.getOrDefault(self.mean)
187 |
188 |
189 | class HasStandardDeviation(Params):
190 | """
191 | Controls Standard Deviation parameter of normal distribution
192 | """
193 |
194 | std = Param(
195 | Params._dummy(),
196 | "std",
197 | "Standard Deviation parameter of normal distribution",
198 | typeConverter=TypeConverters.toFloat
199 | )
200 |
201 | def setStandardDeviation(self, value):
202 | """
203 | Changes Standard Deviation parameter of normal distribution
204 |
205 | :param value: new value of std parameter
206 | """
207 |
208 | return self._set(std=value)
209 |
210 | def getStandardDeviation(self):
211 | """
212 | Returns value of std parameter
213 | """
214 | return self.getOrDefault(self.std)
215 |
216 |
217 | class HasClipNegative(Params):
218 | """
219 |     Controls the flag for clipping negative values
220 | """
221 |
222 | clipNegative = Param(
223 | Params._dummy(),
224 | "clipNegative",
225 | "Boolean flag to clip negative values",
226 | typeConverter=TypeConverters.toBoolean
227 | )
228 |
229 | def setClipNegative(self, value):
230 | """
231 | Changes flag that controls clipping of negative values
232 |
233 | :param value: New value of flag
234 | """
235 | return self._set(clipNegative=value)
236 |
237 | def getClipNegative(self):
238 | """
239 | Returns flag that controls clipping of negative values
240 | """
241 | return self.getOrDefault(self.clipNegative)
242 |
243 |
244 | class HasConstantValue(Params):
245 | """
246 | Controls constant value parameter
247 | """
248 |
249 | constantValue = Param(
250 | Params._dummy(),
251 | "constantValue",
252 | "Constant value parameter",
253 | typeConverter=TypeConverters.toFloat
254 | )
255 |
256 | def setConstantValue(self, value):
257 | """
258 | Sets constant value parameter
259 |
260 | :param value: Value
261 | """
262 | return self._set(constantValue=value)
263 |
264 | def getConstantValue(self):
265 | """
266 | Returns constant value
267 | """
268 | return self.getOrDefault(self.constantValue)
269 |
270 |
271 | class HasLabel(Params):
272 | """
273 | Controls string label
274 | """
275 | label = Param(
276 | Params._dummy(),
277 | "label",
278 | "String label",
279 | typeConverter=TypeConverters.toString
280 | )
281 |
282 | def setLabel(self, value):
283 | """
284 | Sets string label
285 |
286 | :param value: Label
287 | """
288 | return self._set(label=value)
289 |
290 | def getLabel(self):
291 | """
292 | Returns current string label
293 | """
294 | return self.getOrDefault(self.label)
295 |
296 |
297 | class HasDevice(Params):
298 | """
299 | Controls device
300 | """
301 | device = Param(
302 | Params._dummy(),
303 | "device",
304 | "Name of a device to use",
305 | typeConverter=TypeConverters.toString
306 | )
307 |
308 | def setDevice(self, value):
309 | """
310 | Sets device
311 |
312 | :param value: Name of device to use
313 | """
314 | return self._set(device=value)
315 |
316 | def getDevice(self):
317 | """
318 | Returns current device
319 | """
320 | return self.getOrDefault(self.device)
321 |
322 |
323 | class HasDataSize(Params):
324 | """
325 | Controls data size
326 | """
327 | data_size = Param(
328 | Params._dummy(),
329 | "data_size",
330 | "Size of a DataFrame",
331 | typeConverter=TypeConverters.toInt
332 | )
333 |
334 | def setDataSize(self, value):
335 | """
336 | Sets data size to a certain value
337 |
338 | :param value: Size of a DataFrame
339 | """
340 | return self._set(data_size=value)
341 |
342 | def getDataSize(self):
343 | """
344 | Returns current size of a DataFrame
345 | """
346 | return self.getOrDefault(self.data_size)
347 |
348 |
349 | class HasParallelizationLevel(Params):
350 | """
351 | Controls parallelization level
352 | """
353 | parallelizationLevel = Param(
354 | Params._dummy(),
355 | "parallelizationLevel",
356 | "Level of parallelization",
357 | typeConverter=TypeConverters.toInt
358 | )
359 |
360 | def setParallelizationLevel(self, value):
361 | """
362 | Sets level of parallelization
363 |
364 | :param value: Level of parallelization
365 | """
366 | return self._set(parallelizationLevel=value)
367 |
368 | def getParallelizationLevel(self):
369 | """
370 | Returns current level of parallelization
371 | """
372 | return self.getOrDefault(self.parallelizationLevel)
373 |
--------------------------------------------------------------------------------
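A sketch of the HasSeedSequence seed stream, using CrossJoinItemTransformer from selectors.py as the host class (`items_df` is assumed):

    from sim4rec.modules import CrossJoinItemTransformer

    t = CrossJoinItemTransformer(
        item_df=items_df,
        k=5,
        userKeyColumn='user_idx',
        itemKeyColumn='item_idx',
        seed=42,
    )
    print(t.getInitSeed())   # 42
    print(t.getNextSeed())   # first seed drawn from the reproducible sequence
    print(t.getNextSeed())   # a different seed on every call
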
/sim4rec/recommenders/ucb.py:
--------------------------------------------------------------------------------
1 | import joblib
2 | import math
3 |
4 | from os.path import join
5 | from typing import Any, Dict, List, Optional
6 |
7 | from abc import ABC
8 | import numpy as np
9 | import pandas as pd
10 | from numpy.random import default_rng
11 |
12 | from pyspark.sql import DataFrame, Window
13 | from pyspark.sql import functions as sf
14 |
15 | from sim4rec.recommenders.utils import REC_SCHEMA
16 |
17 |
18 | class UCB(ABC):
19 |     """Simple bandit model, which calculates item relevance as the upper confidence bound
20 | (`UCB `_)
21 | for the confidence interval of true fraction of positive ratings.
22 |     Should be used in iterative (online) mode to achieve proper recommendation quality.
23 | ``relevance`` from log must be converted to binary 0-1 form.
24 | .. math::
25 | pred_i = ctr_i + \\sqrt{\\frac{c\\ln{n}}{n_i}}
26 | :math:`pred_i` -- predicted relevance of item :math:`i`
27 |     :math:`c` -- exploration coefficient
28 | :math:`n` -- number of interactions in log
29 | :math:`n_i` -- number of interactions with item :math:`i`
30 | """
31 |
32 | can_predict_cold_users = True
33 | can_predict_cold_items = True
34 | item_popularity: DataFrame
35 | fill: float
36 |
37 | def __init__(
38 | self,
39 | exploration_coef: float = 2,
40 | sample: bool = False,
41 | seed: Optional[int] = None,
42 | ):
43 | """
44 | :param exploration_coef: exploration coefficient
45 | :param sample: flag to choose recommendation strategy.
46 | If True, items are sampled with a probability proportional
47 | to the calculated predicted relevance
48 | :param seed: random seed. Provides reproducibility if fixed
49 | """
50 | # pylint: disable=super-init-not-called
51 | self.coef = exploration_coef
52 | self.sample = sample
53 | self.seed = seed
54 |
55 | @property
56 | def _init_args(self):
57 | return {
58 | "exploration_coef": self.coef,
59 | "sample": self.sample,
60 | "seed": self.seed,
61 | }
62 |
63 | @property
64 | def _dataframes(self):
65 | return {"item_popularity": self.item_popularity}
66 |
67 | def _save_model(self, path: str):
68 | joblib.dump({"fill": self.fill}, join(path))
69 |
70 | def _load_model(self, path: str):
71 | self.fill = joblib.load(join(path))["fill"]
72 |
73 | def fit(
74 | self,
75 | log: DataFrame,
76 | user_features: Optional[DataFrame] = None,
77 | item_features: Optional[DataFrame] = None,
78 | ) -> None:
79 | vals = log.select("relevance").where(
80 | (sf.col("relevance") != 1) & (sf.col("relevance") != 0)
81 | )
82 | if vals.count() > 0:
83 | raise ValueError("Relevance values in log must be 0 or 1")
84 |
85 | items_counts = log.groupby("item_idx").agg(
86 | sf.sum("relevance").alias("pos"),
87 | sf.count("relevance").alias("total"),
88 | )
89 |
90 | full_count = log.count()
91 | items_counts = items_counts.withColumn(
92 | "relevance",
93 | (
94 | sf.col("pos") / sf.col("total")
95 | + sf.sqrt(
96 | sf.log(sf.lit(self.coef * full_count)) / sf.col("total")
97 | )
98 | ),
99 | )
100 |
101 | self.item_popularity = items_counts.drop("pos", "total")
102 | self.item_popularity.cache().count()
103 |
104 | self.fill = 1 + math.sqrt(math.log(self.coef * full_count))
105 |
106 | def _clear_cache(self):
107 | if hasattr(self, "item_popularity"):
108 | self.item_popularity.unpersist()
109 |
110 | def _predict_with_sampling(
111 | self,
112 | log: DataFrame,
113 | item_popularity: DataFrame,
114 | k: int,
115 | users: DataFrame,
116 | filter_seen_items: bool = True,
117 | ):
118 | items_pd = item_popularity.withColumn(
119 | "probability",
120 | sf.col("relevance")
121 | / item_popularity.select(sf.sum("relevance")).first()[0],
122 | ).toPandas()
123 |
124 | seed = self.seed
125 |
126 | def grouped_map(pandas_df: pd.DataFrame) -> pd.DataFrame:
127 | user_idx = pandas_df["user_idx"][0]
128 | cnt = pandas_df["cnt"][0]
129 |
130 | if seed is not None:
131 | local_rng = default_rng(seed + user_idx)
132 | else:
133 | local_rng = default_rng()
134 |
135 | items_positions = local_rng.choice(
136 | np.arange(items_pd.shape[0]),
137 | size=cnt,
138 | p=items_pd["probability"].values,
139 | replace=False,
140 | )
141 |
142 | return pd.DataFrame(
143 | {
144 | "user_idx": cnt * [user_idx],
145 | "item_idx": items_pd["item_idx"].values[items_positions],
146 | "relevance": items_pd["probability"].values[
147 | items_positions
148 | ],
149 | }
150 | )
151 |
152 | recs = users.withColumn("cnt", sf.lit(k))
153 | if log is not None and filter_seen_items:
154 | recs = (
155 | log.join(users, how="right", on="user_idx")
156 | .select("user_idx", "item_idx")
157 | .groupby("user_idx")
158 | .agg(sf.countDistinct("item_idx").alias("cnt"))
159 | .selectExpr(
160 | "user_idx",
161 | f"LEAST(cnt + {k}, {items_pd.shape[0]}) AS cnt",
162 | )
163 | )
164 | return recs.groupby("user_idx").applyInPandas(grouped_map, REC_SCHEMA)
165 |
166 | @staticmethod
167 | def _calc_max_hist_len(log, users):
168 | max_hist_len = (
169 | (
170 | log.join(users, on="user_idx")
171 | .groupBy("user_idx")
172 | .agg(sf.countDistinct("item_idx").alias("items_count"))
173 | )
174 | .select(sf.max("items_count"))
175 | .collect()[0][0]
176 | )
177 | # all users have empty history
178 | if max_hist_len is None:
179 | return 0
180 | return max_hist_len
181 |
182 | # pylint: disable=too-many-arguments
183 | def predict(
184 | self,
185 | log: DataFrame,
186 | k: int,
187 | users: DataFrame,
188 | items: DataFrame,
189 | user_features: Optional[DataFrame] = None,
190 | item_features: Optional[DataFrame] = None,
191 | filter_seen_items: bool = True,
192 | ) -> DataFrame:
193 |
194 | selected_item_popularity = self.item_popularity.join(
195 | items,
196 | on="item_idx",
197 | how="right",
198 | ).fillna(value=self.fill, subset=["relevance"])
199 |
200 | if self.sample:
201 | return self._predict_with_sampling(
202 | log=log,
203 | item_popularity=selected_item_popularity,
204 | k=k,
205 | users=users,
206 | filter_seen_items=filter_seen_items,
207 | )
208 |
209 | selected_item_popularity = selected_item_popularity.withColumn(
210 | "rank",
211 | sf.row_number().over(
212 | Window.orderBy(
213 | sf.col("relevance").desc(), sf.col("item_idx").desc()
214 | )
215 | ),
216 | )
217 |
218 | max_hist_len = (
219 | self._calc_max_hist_len(log, users)
220 | if filter_seen_items and log is not None
221 | else 0
222 | )
223 |
224 | return users.crossJoin(
225 | selected_item_popularity.filter(sf.col("rank") <= k + max_hist_len)
226 | ).drop("rank")
227 |
228 | def _predict_pairs(
229 | self,
230 | pairs: DataFrame,
231 | log: Optional[DataFrame] = None,
232 | user_features: Optional[DataFrame] = None,
233 | item_features: Optional[DataFrame] = None,
234 | ) -> DataFrame:
235 | return pairs.join(
236 | self.item_popularity, on="item_idx", how="left"
237 | ).fillna(value=self.fill, subset=["relevance"])
--------------------------------------------------------------------------------
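The fit method above implements a UCB-style score: the empirical positive rate per item plus an exploration bonus sqrt(log(coef * N) / n_i) that shrinks as the item accumulates impressions. A minimal sketch of the same formula in plain Python (coef and the counts below are illustrative, not taken from the repository):

import math

coef, full_count = 1.0, 1000                      # assumed exploration coefficient and log size
for pos, total in [(5, 10), (50, 100), (500, 1000)]:
    # same expression as the withColumn("relevance", ...) in fit()
    relevance = pos / total + math.sqrt(math.log(coef * full_count) / total)
    print(total, round(relevance, 3))             # bonus shrinks from ~0.83 to ~0.08
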
/sim4rec/recommenders/utils.py:
--------------------------------------------------------------------------------
1 | from pyspark.sql.types import DoubleType, IntegerType, StructField, StructType, TimestampType
2 |
3 |
4 | def get_schema(
5 | query_column: str = "query_id",
6 | item_column: str = "item_id",
7 | timestamp_column: str = "timestamp",
8 | rating_column: str = "rating",
9 | has_timestamp: bool = True,
10 | has_rating: bool = True,
11 | ):
12 | """
13 | Get Spark Schema with query_id, item_id, rating, timestamp columns
14 |
15 | :param query_column: column name with query ids
16 | :param item_column: column name with item ids
17 | :param timestamp_column: column name with timestamps
18 | :param rating_column: column name with ratings
19 | :param has_timestamp: flag to add timestamp to schema
20 | :param has_rating: flag to add rating to schema
21 | """
22 | base = [
23 | StructField(query_column, IntegerType()),
24 | StructField(item_column, IntegerType()),
25 | ]
26 | if has_timestamp:
27 | base += [StructField(timestamp_column, TimestampType())]
28 | if has_rating:
29 | base += [StructField(rating_column, DoubleType())]
30 | return StructType(base)
31 |
32 | REC_SCHEMA = get_schema(
33 | query_column="user_idx",
34 | item_column="item_idx",
35 | rating_column="relevance",
36 | has_timestamp=False,
37 | )
38 |
--------------------------------------------------------------------------------
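For reference, a small usage sketch of get_schema (assuming the package is importable as laid out in the tree; the column names mirror REC_SCHEMA above):

from sim4rec.recommenders.utils import get_schema

log_schema = get_schema(
    query_column="user_idx",
    item_column="item_idx",
    rating_column="relevance",
    has_timestamp=True,
)
print(log_schema.fieldNames())   # ['user_idx', 'item_idx', 'timestamp', 'relevance']
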
/sim4rec/response/__init__.py:
--------------------------------------------------------------------------------
1 | from .response import (
2 | ActionModelEstimator,
3 | ActionModelTransformer,
4 | ConstantResponse,
5 | NoiseResponse,
6 | CosineSimilatiry,
7 | BernoulliResponse,
8 | ParametricResponseFunction
9 | )
10 |
11 | __all__ = [
12 | 'ActionModelEstimator',
13 | 'ActionModelTransformer',
14 | 'ConstantResponse',
15 | 'NoiseResponse',
16 | 'CosineSimilatiry',
17 | 'BernoulliResponse',
18 | 'ParametricResponseFunction'
19 | ]
20 |
--------------------------------------------------------------------------------
/sim4rec/response/response.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=no-member,unused-argument,too-many-ancestors,abstract-method
2 | import math
3 | from typing import List
4 | from collections.abc import Iterable
5 |
6 | import pyspark.sql.types as st
7 | import pyspark.sql.functions as sf
8 | from pyspark.ml import Transformer, Estimator
9 | from pyspark.ml.param.shared import HasInputCols, HasInputCol, HasOutputCol
10 | from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
11 |
12 | from pyspark.sql import DataFrame
13 | from pyspark import keyword_only
14 |
15 | from sim4rec.params import (
16 | HasWeights, HasSeedSequence,
17 | HasConstantValue, HasClipNegative,
18 | HasMean, HasStandardDeviation
19 | )
20 |
21 |
22 | class ActionModelEstimator(Estimator,
23 | HasOutputCol,
24 | DefaultParamsReadable,
25 | DefaultParamsWritable):
26 | """
27 | Base class for response estimator
28 | """
29 | @keyword_only
30 | def __init__(
31 | self,
32 | outputCol : str = None
33 | ):
34 | """
35 | :param outputCol: Name of the response score column, defaults
36 | to None
37 | """
38 |
39 | super().__init__()
40 | self.setParams(**self._input_kwargs)
41 |
42 | @keyword_only
43 | def setParams(
44 | self,
45 | outputCol : str = None
46 | ):
47 | """
48 | Sets parameters for response estimator
49 | """
50 | return self._set(**self._input_kwargs)
51 |
52 |
53 | class ActionModelTransformer(Transformer,
54 | HasOutputCol,
55 | DefaultParamsReadable,
56 | DefaultParamsWritable):
57 | """
58 | Base class for response transformers. transform() computes
59 | a score from the input columns and writes it to the
60 | outputCol column
61 | """
62 | @keyword_only
63 | def __init__(
64 | self,
65 | outputCol : str = None
66 | ):
67 | """
68 | :param outputCol: Name of the response score column, defaults
69 | to None
70 | """
71 |
72 | super().__init__()
73 | self.setParams(**self._input_kwargs)
74 |
75 | @keyword_only
76 | def setParams(
77 | self,
78 | outputCol : str = None
79 | ):
80 | """
81 | Sets parameters for response transformer
82 | """
83 | return self._set(**self._input_kwargs)
84 |
85 |
86 | class BernoulliResponse(ActionModelTransformer,
87 | HasInputCol,
88 | HasSeedSequence):
89 | """
90 | Samples binary (0/1) responses from a probability column
91 | """
92 | def __init__(
93 | self,
94 | inputCol : str = None,
95 | outputCol : str = None,
96 | seed : int = None
97 | ):
98 | """
99 | :param inputCol: Probability column name. Probability should
100 | be in range [0; 1]
101 | :param outputCol: Output column name
102 | :param seed: Random state seed, defaults to None
103 | """
104 |
105 | super().__init__(outputCol=outputCol)
106 |
107 | self._set(inputCol=inputCol)
108 | self.initSeedSequence(seed)
109 |
110 | def _transform(
111 | self,
112 | dataset : DataFrame
113 | ):
114 | inputCol = self.getInputCol()
115 | outputCol = self.getOutputCol()
116 | seed = self.getNextSeed()
117 |
118 | return dataset.withColumn(
119 | outputCol,
120 | sf.when(sf.rand(seed=seed) <= sf.col(inputCol), 1).otherwise(0)
121 | )
122 |
123 |
124 | class NoiseResponse(ActionModelTransformer,
125 | HasMean,
126 | HasStandardDeviation,
127 | HasClipNegative,
128 | HasSeedSequence):
129 | # pylint: disable=too-many-arguments
130 | """
131 | Creates a random response sampled from a normal distribution
132 | """
133 | def __init__(
134 | self,
135 | mu : float = None,
136 | sigma : float = None,
137 | outputCol : str = None,
138 | clipNegative : bool = True,
139 | seed : int = None
140 | ):
141 | """
142 | :param mu: Mean parameter of normal distribution
143 | :param sigma: Standard deviation parameter of normal distribution
144 | :param outputCol: Output column name
145 | :param clipNegative: Whether to make the response non-negative,
146 | defaults to True
147 | :param seed: Random state seed, defaults to None
148 | """
149 |
150 | super().__init__(outputCol=outputCol)
151 |
152 | self._set(mean=mu, std=sigma, clipNegative=clipNegative)
153 | self.initSeedSequence(seed)
154 |
155 | def _transform(
156 | self,
157 | dataset : DataFrame
158 | ):
159 | mu = self.getMean()
160 | sigma = self.getStandardDeviation()
161 | clip_negative = self.getClipNegative()
162 | outputCol = self.getOutputCol()
163 | seed = self.getNextSeed()
164 |
165 | expr = sf.randn(seed=seed) * sigma + mu
166 | if clip_negative:
167 | expr = sf.greatest(expr, sf.lit(0))
168 |
169 | return dataset.withColumn(outputCol, expr)
170 |
171 |
172 | class ConstantResponse(ActionModelTransformer,
173 | HasConstantValue):
174 | """
175 | Always returns a constant-valued response
176 | """
177 | def __init__(
178 | self,
179 | value : float = 0.0,
180 | outputCol : str = None
181 | ):
182 | """
183 | :param value: Response value
184 | :param outputCol: Output column name
185 | """
186 |
187 | super().__init__(outputCol=outputCol)
188 |
189 | self._set(constantValue=value)
190 |
191 | def _transform(
192 | self,
193 | dataset : DataFrame
194 | ):
195 | value = self.getConstantValue()
196 | outputColumn = self.getOutputCol()
197 |
198 | return dataset.withColumn(outputColumn, sf.lit(value))
199 |
200 |
201 | class CosineSimilatiry(ActionModelTransformer,
202 | HasInputCols):
203 | """
204 | Calculates a cosine-based similarity between two vectors;
205 | the angle is mapped into the [0; 1] range as 1 - arccos(cos) / pi
206 | """
207 | def __init__(
208 | self,
209 | inputCols : List[str] = None,
210 | outputCol : str = None
211 | ):
212 | """
213 | :param inputCols: Two column names with dense vectors
214 | :param outputCol: Output column name
215 | """
216 |
217 | if inputCols is not None and len(inputCols) != 2:
218 | raise ValueError('There must be two array columns '
219 | 'to calculate cosine similarity')
220 |
221 | super().__init__(outputCol=outputCol)
222 | self._set(inputCols=inputCols)
223 |
224 | def _transform(
225 | self,
226 | dataset : DataFrame
227 | ):
228 | inputCols = self.getInputCols()
229 | outputCol = self.getOutputCol()
230 |
231 | def cosine_similarity(first, second):
232 | num = first.dot(second)
233 | den = first.norm(2) * second.norm(2)
234 |
235 | if den == 0:
236 | return float(0)
237 |
238 | cosine = max(min(num / den, 1.0), -1.0)
239 | return float(1 - math.acos(cosine) / math.pi)
240 |
241 | cos_udf = sf.udf(cosine_similarity, st.DoubleType())
242 |
243 | return dataset.withColumn(
244 | outputCol,
245 | # pylint: disable=unsubscriptable-object
246 | cos_udf(sf.col(inputCols[0]), sf.col(inputCols[1]))
247 | )
248 |
249 |
250 | class ParametricResponseFunction(ActionModelTransformer,
251 | HasInputCols,
252 | HasWeights):
253 | """
254 | Calculates the response as a weighted sum of the input responses
255 | """
256 | def __init__(
257 | self,
258 | inputCols : List[str] = None,
259 | outputCol : str = None,
260 | weights : Iterable = None
261 | ):
262 | """
263 | :param inputCols: Input responses column names
264 | :param outputCol: Output column name
265 | :param weights: Input responses weights
266 | """
267 |
268 | super().__init__(outputCol=outputCol)
269 | self._set(inputCols=inputCols, weights=weights)
270 |
271 | def _transform(
272 | self,
273 | dataset : DataFrame
274 | ):
275 | inputCols = self.getInputCols()
276 | outputCol = self.getOutputCol()
277 | weights = self.getWeights()
278 |
279 | return dataset.withColumn(
280 | outputCol,
281 | sum([
282 | # pylint: disable=unsubscriptable-object
283 | sf.col(c) * weights[i]
284 | for i, c in enumerate(inputCols)
285 | ])
286 | )
287 |
--------------------------------------------------------------------------------
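The transformers in response.py compose like ordinary Spark ML stages. A minimal sketch of a response pipeline (column names and the seed are illustrative) that turns a cosine score into a sampled binary response:

from pyspark.ml import PipelineModel
from sim4rec.response import CosineSimilatiry, BernoulliResponse

# The cosine score lies in [0; 1], so it can be reused directly as a click
# probability; BernoulliResponse then samples a 0/1 outcome from it.
cosine = CosineSimilatiry(inputCols=["user_vec", "item_vec"], outputCol="proba")
bernoulli = BernoulliResponse(inputCol="proba", outputCol="response", seed=1234)

response_model = PipelineModel(stages=[cosine, bernoulli])
# responses = response_model.transform(user_item_pairs_df)  # pairs_df assumed to exist
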
/sim4rec/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .uce import (
2 | VectorElementExtractor,
3 | NotFittedError,
4 | EmptyDataFrameError,
5 | save,
6 | load
7 | )
8 | from .convert import pandas_to_spark
9 |
10 | from .session_handler import get_spark_session, State
11 |
12 | __all__ = [
13 | 'VectorElementExtractor',
14 | 'NotFittedError',
15 | 'EmptyDataFrameError',
16 | 'State',
17 | 'save',
18 | 'load',
19 | 'pandas_to_spark',
20 | 'get_spark_session'
21 | ]
22 |
--------------------------------------------------------------------------------
/sim4rec/utils/convert.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from pyspark.sql import SparkSession, DataFrame
3 |
4 | from sim4rec.utils.session_handler import State
5 |
6 |
7 | def pandas_to_spark(
8 | df: pd.DataFrame,
9 | schema=None,
10 | spark_session : SparkSession = None) -> DataFrame:
11 | """
12 | Converts pandas DataFrame to spark DataFrame
13 |
14 | :param df: DataFrame to convert
15 | :param schema: Schema of the dataframe, defaults to None
16 | :param spark_session: Spark session to use, defaults to None
17 | :returns: data converted to spark DataFrame
18 | """
19 |
20 | if not isinstance(df, pd.DataFrame):
21 | raise ValueError('df must be an instance of pd.DataFrame')
22 |
23 | if len(df) == 0:
24 | raise ValueError('Dataframe is empty')
25 |
26 | if spark_session is not None:
27 | spark = spark_session
28 | else:
29 | spark = State().session
30 |
31 | return spark.createDataFrame(df, schema=schema)
32 |
--------------------------------------------------------------------------------
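A minimal usage sketch for pandas_to_spark (the data frame below is illustrative; when no session is passed, State().session is used):

import pandas as pd
from sim4rec.utils import pandas_to_spark

pdf = pd.DataFrame({"user_id": [0, 1, 2], "age": [25, 31, 40]})
sdf = pandas_to_spark(pdf)   # falls back to State().session when spark_session is None
sdf.show()
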
/sim4rec/utils/session_handler.py:
--------------------------------------------------------------------------------
1 | """
2 | Painless creation and retrieval of Spark sessions
3 | """
4 |
5 | import os
6 | import sys
7 | from math import floor
8 | from typing import Any, Dict, Optional
9 |
10 | import psutil
11 | import torch
12 | from pyspark.sql import SparkSession
13 |
14 |
15 | def get_spark_session(
16 | spark_memory: Optional[int] = None,
17 | shuffle_partitions: Optional[int] = None,
18 | ) -> SparkSession:
19 | """
20 | Get default SparkSession
21 |
22 | :param spark_memory: GB of memory allocated for Spark;
23 | 70% of RAM by default.
24 | :param shuffle_partitions: number of partitions for Spark; triple CPU count by default
25 | """
26 | if os.environ.get("SCRIPT_ENV", None) == "cluster":
27 | return SparkSession.builder.getOrCreate()
28 |
29 | os.environ["PYSPARK_PYTHON"] = sys.executable
30 | os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable
31 |
32 | if spark_memory is None:
33 | spark_memory = floor(psutil.virtual_memory().total / 1024**3 * 0.7)
34 | if shuffle_partitions is None:
35 | shuffle_partitions = os.cpu_count() * 3
36 | driver_memory = f"{spark_memory}g"
37 | user_home = os.environ["HOME"]
38 | spark = (
39 | SparkSession.builder.config("spark.driver.memory", driver_memory)
40 | .config(
41 | "spark.driver.extraJavaOptions",
42 | "-Dio.netty.tryReflectionSetAccessible=true",
43 | )
44 | .config("spark.sql.shuffle.partitions", str(shuffle_partitions))
45 | .config("spark.local.dir", os.path.join(user_home, "tmp"))
46 | .config("spark.driver.maxResultSize", "4g")
47 | .config("spark.driver.bindAddress", "127.0.0.1")
48 | .config("spark.driver.host", "localhost")
49 | .config("spark.sql.execution.arrow.pyspark.enabled", "true")
50 | .config("spark.kryoserializer.buffer.max", "256m")
51 | .config("spark.files.overwrite", "true")
52 | .master("local[*]")
53 | .enableHiveSupport()
54 | .getOrCreate()
55 | )
56 | return spark
57 |
58 |
59 | # pylint: disable=too-few-public-methods
60 | class Borg:
61 | """
62 | Shares state between all of its instances (Borg pattern).
63 | """
64 |
65 | _shared_state: Dict[str, Any] = {}
66 |
67 | def __init__(self):
68 | self.__dict__ = self._shared_state
69 |
70 |
71 | # pylint: disable=too-few-public-methods
72 | class State(Borg):
73 | """
74 | All modules look for Spark session via this class. You can put your own session here.
75 |
76 | Other parameters are stored here too: ``default device`` for ``pytorch`` (CPU/CUDA)
77 | """
78 |
79 | def __init__(
80 | self,
81 | session: Optional[SparkSession] = None,
82 | device: Optional[torch.device] = None,
83 | ):
84 | Borg.__init__(self)
85 |
86 | if session is None:
87 | if not hasattr(self, "session"):
88 | self.session = get_spark_session()
89 | else:
90 | self.session = session
91 |
92 | if device is None:
93 | if not hasattr(self, "device"):
94 | if torch.cuda.is_available():
95 | self.device = torch.device(
96 | f"cuda:{torch.cuda.current_device()}"
97 | )
98 | else:
99 | self.device = torch.device("cpu")
100 | else:
101 | self.device = device
102 |
--------------------------------------------------------------------------------
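How the Borg-style State is typically used (a sketch; the memory and partition values are illustrative):

from sim4rec.utils import get_spark_session, State

# Build a session explicitly and register it so every module picks it up ...
custom = get_spark_session(spark_memory=8, shuffle_partitions=16)
State(session=custom)
assert State().session is custom   # shared through Borg's _shared_state

# ... or simply access State().session and let it create a default local session lazily.
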
/sim4rec/utils/uce.py:
--------------------------------------------------------------------------------
1 | # Utility class extensions
2 | # pylint: disable=no-member,unused-argument
3 | import pickle
4 | import pyspark.sql.functions as sf
5 | import pyspark.sql.types as st
6 |
7 | from pyspark.sql import DataFrame
8 | from pyspark.ml import Transformer
9 | from pyspark.ml.param.shared import HasInputCol, HasOutputCol
10 | from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
11 |
12 |
13 | from pyspark.ml.param.shared import Params, Param, TypeConverters
14 | from pyspark import keyword_only
15 |
16 |
17 | class NotFittedError(Exception):
18 | # pylint: disable=missing-class-docstring
19 | pass
20 |
21 |
22 | class EmptyDataFrameError(Exception):
23 | # pylint: disable=missing-class-docstring
24 | pass
25 |
26 |
27 | # pylint: disable=too-many-ancestors
28 | class VectorElementExtractor(Transformer,
29 | HasInputCol, HasOutputCol,
30 | DefaultParamsReadable, DefaultParamsWritable):
31 | """
32 | Extracts the element at the given index from an array/vector column
33 | """
34 |
35 | index = Param(
36 | Params._dummy(),
37 | "index",
38 | "Array index to extract",
39 | typeConverter=TypeConverters.toInt
40 | )
41 |
42 | def setIndex(self, value):
43 | """
44 | Sets the array index to extract
45 | :param value: New index value
46 | """
47 | return self._set(index=value)
48 |
49 | def getIndex(self):
50 | """
51 | Returns index of element
52 | """
53 | return self.getOrDefault(self.index)
54 |
55 | @keyword_only
56 | def __init__(
57 | self,
58 | inputCol : str = None,
59 | outputCol : str = None,
60 | index : int = None
61 | ):
62 | """
63 | :param inputCol: Input column with array
64 | :param outputCol: Output column name
65 | :param index: Index of an element within array
66 | """
67 | super().__init__()
68 | self.setParams(**self._input_kwargs)
69 |
70 | @keyword_only
71 | def setParams(
72 | self,
73 | inputCol : str = None,
74 | outputCol : str = None,
75 | index : int = None
76 | ):
77 | """
78 | Sets parameters for extractor
79 | """
80 | return self._set(**self._input_kwargs)
81 |
82 | def _transform(
83 | self,
84 | dataset : DataFrame
85 | ):
86 | index = self.getIndex()
87 |
88 | el_udf = sf.udf(
89 | lambda x : float(x[index]), st.DoubleType()
90 | )
91 |
92 | inputCol = self.getInputCol()
93 | outputCol = self.getOutputCol()
94 |
95 | return dataset.withColumn(outputCol, el_udf(inputCol))
96 |
97 |
98 | def save(obj : object, filename : str):
99 | """
100 | Saves an object to a pickle dump
101 | :param obj: Object instance to save
102 | :param filename: File name of a dump
103 | """
104 |
105 | with open(filename, 'wb') as f:
106 | pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
107 |
108 |
109 | def load(filename : str):
110 | """
111 | Loads a pickle dump from file
112 | :param filename: File name of a dump
113 | :return: Read instance
114 | """
115 |
116 | with open(filename, 'rb') as f:
117 | obj = pickle.load(f)
118 |
119 | return obj
120 |
--------------------------------------------------------------------------------
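A short usage sketch for VectorElementExtractor and the pickle helpers (the data and file path are illustrative):

from pyspark.ml.linalg import DenseVector
from sim4rec.utils import State, VectorElementExtractor, save, load

df = State().session.createDataFrame(
    [(0, DenseVector([0.1, 0.9])), (1, DenseVector([0.7, 0.3]))],
    schema=["id", "probs"],
)
vee = VectorElementExtractor(inputCol="probs", outputCol="p1", index=1)
vee.transform(df).show()                     # element 1 of each vector as a DoubleType column

save({"index": 1}, "/tmp/vee_params.pkl")    # plain pickle helpers from this module
params = load("/tmp/vee_params.pkl")
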
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sb-ai-lab/Sim4Rec/267b644932b0e120cc9887f0f3a16be38d5739b7/tests/__init__.py
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pyspark.sql import DataFrame, SparkSession
3 |
4 | SEED = 1234
5 |
6 |
7 | @pytest.fixture(scope="session")
8 | def spark() -> SparkSession:
9 | return SparkSession.builder\
10 | .appName('simulator_test')\
11 | .master('local[4]')\
12 | .config('spark.sql.shuffle.partitions', '4')\
13 | .config('spark.default.parallelism', '4')\
14 | .config('spark.driver.extraJavaOptions', '-XX:+UseG1GC')\
15 | .config('spark.executor.extraJavaOptions', '-XX:+UseG1GC')\
16 | .config('spark.sql.autoBroadcastJoinThreshold', '-1')\
17 | .config('spark.driver.memory', '4g')\
18 | .getOrCreate()
19 |
20 |
21 | @pytest.fixture(scope="session")
22 | def users_df(spark: SparkSession) -> DataFrame:
23 | data = [
24 | (0, 1.25, -0.75, 0.5),
25 | (1, 0.2, -1.0, 0.0),
26 | (2, 0.85, -0.5, -1.5),
27 | (3, -0.33, -0.33, 0.33),
28 | (4, 0.1, 0.2, 0.3)
29 | ]
30 |
31 | return spark.createDataFrame(
32 | data=data,
33 | schema=['user_id', 'user_attr_1', 'user_attr_2', 'user_attr_3']
34 | )
35 |
36 |
37 | @pytest.fixture(scope="session")
38 | def items_df(spark: SparkSession) -> DataFrame:
39 | data = [
40 | (0, 0.45, -0.45, 1.2),
41 | (1, -0.3, 0.75, 0.25),
42 | (2, 1.25, -0.75, 0.5),
43 | (3, -1.0, 0.0, -0.5),
44 | (4, 0.5, -0.5, -1.0)
45 | ]
46 |
47 | return spark.createDataFrame(
48 | data=data,
49 | schema=['item_id', 'item_attr_1', 'item_attr_2', 'item_attr_3']
50 | )
51 |
52 |
53 | @pytest.fixture(scope="session")
54 | def log_df(spark: SparkSession) -> DataFrame:
55 | data = [
56 | (0, 2, 1.0, 0),
57 | (1, 1, 0.0, 0),
58 | (1, 2, 0.0, 0),
59 | (2, 0, 1.0, 0),
60 | (2, 2, 1.0, 0)
61 | ]
62 | return spark.createDataFrame(
63 | data=data,
64 | schema=['user_id', 'item_id', 'relevance', 'timestamp']
65 | )
66 |
--------------------------------------------------------------------------------
/tests/test_embeddings.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pyspark.sql import DataFrame
3 |
4 | from sim4rec.modules import (
5 | EncoderEstimator,
6 | EncoderTransformer
7 | )
8 |
9 | SEED = 1234
10 |
11 |
12 | @pytest.fixture(scope="module")
13 | def estimator(users_df : DataFrame) -> EncoderEstimator:
14 | return EncoderEstimator(
15 | inputCols=users_df.columns,
16 | outputCols=[f'encoded_{i}' for i in range(5)],
17 | hidden_dim=10,
18 | lr=0.001,
19 | batch_size=32,
20 | num_loader_workers=2,
21 | max_iter=10,
22 | device_name='cpu',
23 | seed=SEED
24 | )
25 |
26 |
27 | @pytest.fixture(scope="module")
28 | def transformer(
29 | users_df : DataFrame,
30 | estimator : EncoderEstimator
31 | ) -> EncoderTransformer:
32 | return estimator.fit(users_df)
33 |
34 |
35 | def test_estimator_fit(
36 | estimator : EncoderEstimator,
37 | transformer : EncoderTransformer
38 | ):
39 | assert estimator._input_dim == len(estimator.getInputCols())
40 | assert estimator._latent_dim == len(estimator.getOutputCols())
41 | assert estimator.getDevice() == transformer.getDevice()
42 | assert str(next(transformer._encoder.parameters()).device) == transformer.getDevice()
43 |
44 |
45 | def test_transformer_transform(
46 | users_df : DataFrame,
47 | transformer : EncoderTransformer
48 | ):
49 | result = transformer.transform(users_df)
50 |
51 | assert result.count() == users_df.count()
52 | assert len(result.columns) == 5
53 | assert set(result.columns) == set([f'encoded_{i}' for i in range(5)])
54 |
--------------------------------------------------------------------------------
/tests/test_evaluation.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from pyspark.sql import DataFrame, SparkSession
4 |
5 | from sim4rec.modules import (
6 | evaluate_synthetic,
7 | EvaluateMetrics,
8 | ks_test,
9 | kl_divergence
10 | )
11 | from sim4rec.response import ConstantResponse
12 |
13 |
14 | @pytest.fixture(scope="function")
15 | def evaluator() -> EvaluateMetrics:
16 | return EvaluateMetrics(
17 | userKeyCol='user_id',
18 | itemKeyCol='item_id',
19 | predictionCol='relevance',
20 | labelCol='response',
21 | mllib_metrics=['mse', 'f1', 'areaUnderROC']
22 | )
23 |
24 |
25 | @pytest.fixture(scope="module")
26 | def response_df(spark : SparkSession) -> DataFrame:
27 | data = [
28 | (0, 0, 0.0, 0.0),
29 | (0, 1, 0.5, 1.0),
30 | (0, 2, 1.0, 1.0),
31 | (1, 0, 1.0, 0.0),
32 | (1, 1, 0.5, 0.0),
33 | (1, 2, 0.0, 1.0)
34 | ]
35 | return spark.createDataFrame(data=data, schema=['user_id', 'item_id', 'relevance', 'response'])
36 |
37 |
38 | def test_evaluate_metrics(
39 | evaluator : EvaluateMetrics,
40 | response_df : DataFrame
41 | ):
42 | result = evaluator(response_df)
43 | assert 'mse' in result
44 | assert 'f1' in result
45 | assert 'areaUnderROC' in result
46 |
47 | result = evaluator(response_df)
48 | assert 'mse' in result
49 | assert 'f1' in result
50 | assert 'areaUnderROC' in result
51 |
52 | evaluator._mllib_metrics = []
53 |
54 | result = evaluator(response_df)
55 |
56 |
57 | def test_evaluate_synthetic(
58 | users_df : DataFrame
59 | ):
60 | import pandas as pd
61 | pd.options.mode.chained_assignment = None
62 |
63 | result = evaluate_synthetic(
64 | users_df.sample(0.5).drop('user_id'),
65 | users_df.sample(0.5).drop('user_id')
66 | )
67 |
68 | assert result['LogisticDetection'] is not None
69 | assert result['SVCDetection'] is not None
70 | assert result['KSTest'] is not None
71 | assert result['ContinuousKLDivergence'] is not None
72 |
73 |
74 | def test_kstest(
75 | users_df : DataFrame
76 | ):
77 | result = ks_test(
78 | df=users_df.select('user_attr_1', 'user_attr_2'),
79 | predCol='user_attr_1',
80 | labelCol='user_attr_2'
81 | )
82 |
83 | assert isinstance(result, float)
84 |
85 |
86 | def test_kldiv(
87 | users_df : DataFrame
88 | ):
89 | result = kl_divergence(
90 | df=users_df.select('user_attr_1', 'user_attr_2'),
91 | predCol='user_attr_1',
92 | labelCol='user_attr_2'
93 | )
94 |
95 | assert isinstance(result, float)
96 |
--------------------------------------------------------------------------------
/tests/test_generators.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pytest
3 | import numpy as np
4 | import pandas as pd
5 | import pyspark.sql.functions as sf
6 | from pyspark.sql import DataFrame
7 |
8 | from sim4rec.modules import (
9 | RealDataGenerator,
10 | SDVDataGenerator,
11 | CompositeGenerator
12 | )
13 |
14 | SEED = 1234
15 |
16 |
17 | @pytest.fixture(scope="function")
18 | def real_gen() -> RealDataGenerator:
19 | return RealDataGenerator(label='real', seed=SEED)
20 |
21 |
22 | @pytest.fixture(scope="function")
23 | def synth_gen() -> SDVDataGenerator:
24 | return SDVDataGenerator(
25 | label='synth',
26 | id_column_name='user_id',
27 | model_name='gaussiancopula',
28 | parallelization_level=2,
29 | device_name='cpu',
30 | seed=SEED
31 | )
32 |
33 |
34 | @pytest.fixture(scope="function")
35 | def comp_gen(real_gen : RealDataGenerator, synth_gen : SDVDataGenerator) -> CompositeGenerator:
36 | return CompositeGenerator(
37 | generators=[real_gen, synth_gen],
38 | label='composite',
39 | weights=[0.5, 0.5]
40 | )
41 |
42 |
43 | def test_realdatagenerator_fit(real_gen : RealDataGenerator, users_df : DataFrame):
44 | real_gen.fit(users_df)
45 |
46 | assert real_gen._fit_called
47 | assert real_gen._source_df.count() == users_df.count()
48 |
49 |
50 | def test_sdvdatagenerator_fit(synth_gen : SDVDataGenerator, users_df : DataFrame):
51 | synth_gen.fit(users_df)
52 |
53 | assert synth_gen._fit_called
54 | assert isinstance(synth_gen._model.sample(100), pd.DataFrame)
55 |
56 |
57 | def test_realdatagenerator_generate(real_gen : RealDataGenerator, users_df : DataFrame):
58 | real_gen.fit(users_df)
59 |
60 | assert real_gen.generate(5).count() == 5
61 | assert real_gen._df.count() == 5
62 | assert real_gen.getDataSize() == 5
63 |
64 |
65 | def test_sdvdatagenerator_generate(synth_gen : SDVDataGenerator, users_df : DataFrame):
66 | synth_gen.fit(users_df)
67 |
68 | assert synth_gen.generate(100).count() == 100
69 | assert synth_gen._df.count() == 100
70 | assert synth_gen.getDataSize() == 100
71 |
72 |
73 | def test_compositegenerator_generate(
74 | real_gen : RealDataGenerator,
75 | synth_gen : SDVDataGenerator,
76 | comp_gen : CompositeGenerator,
77 | users_df : DataFrame
78 | ):
79 | real_gen.fit(users_df)
80 | synth_gen.fit(users_df)
81 | comp_gen.generate(10)
82 |
83 | assert real_gen.getDataSize() == 5
84 | assert synth_gen.getDataSize() == 5
85 |
86 | comp_gen.setWeights([1.0, 0.0])
87 | comp_gen.generate(5)
88 |
89 | assert real_gen.getDataSize() == 5
90 | assert synth_gen.getDataSize() == 0
91 |
92 | comp_gen.setWeights([0.0, 1.0])
93 | comp_gen.generate(10)
94 |
95 | assert real_gen.getDataSize() == 0
96 | assert synth_gen.getDataSize() == 10
97 |
98 |
99 | def test_realdatagenerator_sample(real_gen : RealDataGenerator, users_df : DataFrame):
100 | real_gen.fit(users_df)
101 | _ = real_gen.generate(5)
102 |
103 | assert real_gen.sample(1.0).count() == 5
104 | assert real_gen.sample(0.5).count() == 2
105 | assert real_gen.sample(0.0).count() == 0
106 |
107 |
108 | def test_sdvdatagenerator_sample(synth_gen : SDVDataGenerator, users_df : DataFrame):
109 | synth_gen.fit(users_df)
110 | _ = synth_gen.generate(100)
111 |
112 | assert synth_gen.sample(1.0).count() == 100
113 | assert synth_gen.sample(0.5).count() == 46
114 | assert synth_gen.sample(0.0).count() == 0
115 |
116 |
117 | def test_compositegenerator_sample(
118 | real_gen : RealDataGenerator,
119 | synth_gen : SDVDataGenerator,
120 | comp_gen : CompositeGenerator,
121 | users_df : DataFrame
122 | ):
123 | real_gen.fit(users_df)
124 | synth_gen.fit(users_df)
125 | comp_gen.generate(10)
126 |
127 | assert comp_gen.sample(1.0).count() == 10
128 | assert comp_gen.sample(0.5).count() == 4
129 | assert comp_gen.sample(0.0).count() == 0
130 |
131 | df = comp_gen.sample(1.0).toPandas()
132 | assert df['user_id'].str.startswith('synth').sum() == 5
133 |
134 | comp_gen.setWeights([1.0, 0.0])
135 | comp_gen.generate(5)
136 | df = comp_gen.sample(1.0).toPandas()
137 | assert df['user_id'].str.startswith('synth').sum() == 0
138 |
139 | comp_gen.setWeights([0.0, 1.0])
140 | comp_gen.generate(10)
141 | df = comp_gen.sample(1.0).toPandas()
142 | assert df['user_id'].str.startswith('synth').sum() == 10
143 |
144 |
145 | def test_realdatagenerator_iterdiff(real_gen : RealDataGenerator, users_df : DataFrame):
146 | real_gen.fit(users_df)
147 | generated_1 = real_gen.generate(5).toPandas()
148 | sampled_1 = real_gen.sample(0.5).toPandas()
149 |
150 | generated_2 = real_gen.generate(5).toPandas()
151 | sampled_2 = real_gen.sample(0.5).toPandas()
152 |
153 | assert not generated_1.equals(generated_2)
154 | assert not sampled_1.equals(sampled_2)
155 |
156 |
157 | def test_sdvdatagenerator_iterdiff(synth_gen : SDVDataGenerator, users_df : DataFrame):
158 | synth_gen.fit(users_df)
159 |
160 | generated_1 = synth_gen.generate(100).toPandas()
161 | sampled_1 = synth_gen.sample(0.1).toPandas()
162 |
163 | generated_2 = synth_gen.generate(100).toPandas()
164 | sampled_2 = synth_gen.sample(0.1).toPandas()
165 |
166 | assert not generated_1.equals(generated_2)
167 | assert not sampled_1.equals(sampled_2)
168 |
169 |
170 | def test_compositegenerator_iterdiff(
171 | real_gen : RealDataGenerator,
172 | synth_gen : SDVDataGenerator,
173 | comp_gen : CompositeGenerator,
174 | users_df : DataFrame
175 | ):
176 | real_gen.fit(users_df)
177 | synth_gen.fit(users_df)
178 | comp_gen.generate(10)
179 |
180 | sampled_1 = comp_gen.sample(0.5).toPandas()
181 | sampled_2 = comp_gen.sample(0.5).toPandas()
182 |
183 | assert not sampled_1.equals(sampled_2)
184 |
185 |
186 | def test_sdvdatagenerator_partdiff(synth_gen : SDVDataGenerator, users_df : DataFrame):
187 | synth_gen.fit(users_df)
188 |
189 | generated = synth_gen.generate(100)\
190 | .drop('user_id')\
191 | .withColumn('__partition_id', sf.spark_partition_id())
192 | df_1 = generated.filter(sf.col('__partition_id') == 0)\
193 | .drop('__partition_id').toPandas()
194 | df_2 = generated.filter(sf.col('__partition_id') == 1)\
195 | .drop('__partition_id').toPandas()
196 |
197 | assert not df_1.equals(df_2)
198 |
199 |
200 | def test_sdv_save_load(
201 | synth_gen : SDVDataGenerator,
202 | users_df : DataFrame,
203 | tmp_path
204 | ):
205 | synth_gen.fit(users_df)
206 | synth_gen.save_model(f'{tmp_path}/generator.pkl')
207 |
208 | assert os.path.isfile(f'{tmp_path}/generator.pkl')
209 |
210 | g = SDVDataGenerator.load(f'{tmp_path}/generator.pkl')
211 |
212 | assert g.getLabel() == 'synth'
213 | assert g._id_col_name == 'user_id'
214 | assert g._model_name == 'gaussiancopula'
215 | assert g.getParallelizationLevel() == 2
216 | assert g.getDevice() == 'cpu'
217 | assert g.getInitSeed() == 1234
218 | assert g._fit_called
219 | assert hasattr(g, '_model')
220 | assert hasattr(g, '_schema')
221 |
--------------------------------------------------------------------------------
/tests/test_responses.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 | from pyspark.sql import DataFrame, SparkSession
4 | from pyspark.ml.feature import VectorAssembler
5 | from pyspark.ml.linalg import DenseVector
6 |
7 | from sim4rec.response import (
8 | BernoulliResponse,
9 | NoiseResponse,
10 | ConstantResponse,
11 | CosineSimilatiry,
12 | ParametricResponseFunction
13 | )
14 |
15 | SEED = 1234
16 |
17 |
18 | @pytest.fixture(scope="function")
19 | def users_va_left() -> VectorAssembler:
20 | return VectorAssembler(
21 | inputCols=[f'user_attr_{i}' for i in range(0, 5)],
22 | outputCol='__v1'
23 | )
24 |
25 |
26 | @pytest.fixture(scope="function")
27 | def users_va_right() -> VectorAssembler:
28 | return VectorAssembler(
29 | inputCols=[f'user_attr_{i}' for i in range(5, 10)],
30 | outputCol='__v2'
31 | )
32 |
33 |
34 | @pytest.fixture(scope="function")
35 | def bernoulli_resp() -> BernoulliResponse:
36 | return BernoulliResponse(
37 | inputCol='__proba',
38 | outputCol='relevance',
39 | seed=SEED
40 | )
41 |
42 |
43 | @pytest.fixture(scope="function")
44 | def noise_resp() -> NoiseResponse:
45 | return NoiseResponse(
46 | mu=0.5,
47 | sigma=0.2,
48 | outputCol='__noise',
49 | seed=SEED
50 | )
51 |
52 |
53 | @pytest.fixture(scope="function")
54 | def const_resp() -> ConstantResponse:
55 | return ConstantResponse(
56 | value=0.5,
57 | outputCol='__const',
58 | )
59 |
60 |
61 | @pytest.fixture(scope="function")
62 | def cosine_resp() -> CosineSimilatiry:
63 | return CosineSimilatiry(
64 | inputCols=['__v1', '__v2'],
65 | outputCol='__cosine'
66 | )
67 |
68 |
69 | @pytest.fixture(scope="function")
70 | def param_resp() -> ParametricResponseFunction:
71 | return ParametricResponseFunction(
72 | inputCols=['__const', '__cosine'],
73 | outputCol='__proba',
74 | weights=[0.5, 0.5]
75 | )
76 |
77 |
78 | @pytest.fixture(scope="module")
79 | def random_df(spark : SparkSession) -> DataFrame:
80 | data = [
81 | (0, 0.0),
82 | (1, 0.2),
83 | (2, 0.4),
84 | (3, 0.6),
85 | (4, 1.0)
86 | ]
87 | return spark.createDataFrame(data=data, schema=['id', '__proba'])
88 |
89 |
90 | @pytest.fixture(scope="module")
91 | def vector_df(spark : SparkSession) -> DataFrame:
92 | data = [
93 | (0, DenseVector([1.0, 0.0]), DenseVector([0.0, 1.0])),
94 | (1, DenseVector([-1.0, 0.0]), DenseVector([1.0, 0.0])),
95 | (2, DenseVector([1.0, 0.0]), DenseVector([1.0, 0.0])),
96 | (3, DenseVector([0.5, 0.5]), DenseVector([-0.5, 0.5])),
97 | (4, DenseVector([0.0, 0.0]), DenseVector([1.0, 0.0]))
98 | ]
99 | return spark.createDataFrame(data=data, schema=['id', '__v1', '__v2'])
100 |
101 |
102 | def test_bernoulli_transform(
103 | bernoulli_resp : BernoulliResponse,
104 | random_df : DataFrame
105 | ):
106 | result = bernoulli_resp.transform(random_df).toPandas().sort_values(['id'])
107 |
108 | assert 'relevance' in result.columns
109 | assert len(result) == 5
110 | assert set(result['relevance']) == set([0, 1])
111 | assert list(result['relevance'][:5]) == [0, 0, 0, 1, 1]
112 |
113 |
114 | def test_bernoulli_iterdiff(
115 | bernoulli_resp : BernoulliResponse,
116 | random_df : DataFrame
117 | ):
118 | result1 = bernoulli_resp.transform(random_df).toPandas()
119 | result1 = result1.sort_values(['id'])
120 | result2 = bernoulli_resp.transform(random_df).toPandas()
121 | result2 = result2.sort_values(['id'])
122 |
123 | assert list(result1['relevance'][:5]) != list(result2['relevance'][:5])
124 |
125 |
126 | def test_noise_transform(
127 | noise_resp : NoiseResponse,
128 | random_df : DataFrame
129 | ):
130 | result = noise_resp.transform(random_df).toPandas().sort_values(['id'])
131 |
132 | assert '__noise' in result.columns
133 | assert len(result) == 5
134 | assert np.allclose(result['__noise'][0], 0.6117798825975235)
135 |
136 |
137 | def test_noise_iterdiff(
138 | noise_resp : NoiseResponse,
139 | random_df : DataFrame
140 | ):
141 | result1 = noise_resp.transform(random_df).toPandas().sort_values(['id'])
142 | result2 = noise_resp.transform(random_df).toPandas().sort_values(['id'])
143 |
144 | assert result1['__noise'][0] != result2['__noise'][0]
145 |
146 |
147 | def test_const_transform(
148 | const_resp : ConstantResponse,
149 | random_df : DataFrame
150 | ):
151 | result = const_resp.transform(random_df).toPandas().sort_values(['id'])
152 |
153 | assert '__const' in result.columns
154 | assert len(result) == 5
155 | assert list(result['__const']) == [0.5] * 5
156 |
157 |
158 | def test_cosine_transform(
159 | cosine_resp : CosineSimilatiry,
160 | vector_df : DataFrame
161 | ):
162 | result = cosine_resp.transform(vector_df)\
163 | .drop('__v1', '__v2')\
164 | .toPandas()\
165 | .sort_values(['id'])
166 |
167 | assert '__cosine' in result.columns
168 | assert len(result) == 5
169 | assert list(result['__cosine']) == [0.5, 0.0, 1.0, 0.5, 0.0]
170 |
171 |
172 | def test_paramresp_transform(
173 | param_resp : ParametricResponseFunction,
174 | const_resp : ConstantResponse,
175 | cosine_resp : CosineSimilatiry,
176 | vector_df : DataFrame
177 | ):
178 | df = const_resp.transform(vector_df)
179 | df = cosine_resp.transform(df)
180 |
181 | result = param_resp.transform(df)\
182 | .drop('__v1', '__v2')\
183 | .toPandas()\
184 | .sort_values(['id'])
185 |
186 | assert '__proba' in result.columns
187 | assert len(result) == 5
188 | assert list(result['__proba']) == [0.5, 0.25, 0.75, 0.5, 0.25]
189 |
190 | r = result['__const'][0] / 2 +\
191 | result['__cosine'][0] / 2
192 | assert result['__proba'][0] == r
193 |
194 | param_resp.setWeights([1.0, 0.0])
195 | result = param_resp.transform(df)\
196 | .drop('__v1', '__v2')\
197 | .toPandas()\
198 | .sort_values(['id'])
199 | assert result['__proba'][0] == result['__const'][0]
200 |
201 | param_resp.setWeights([0.0, 1.0])
202 | result = param_resp.transform(df)\
203 | .drop('__v1', '__v2')\
204 | .toPandas()\
205 | .sort_values(['id'])
206 | assert result['__proba'][0] == result['__cosine'][0]
207 |
--------------------------------------------------------------------------------
/tests/test_selectors.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pyspark.sql import DataFrame
3 |
4 | from sim4rec.modules import (
5 | CrossJoinItemEstimator,
6 | CrossJoinItemTransformer
7 | )
8 |
9 | SEED = 1234
10 | K = 5
11 |
12 |
13 | @pytest.fixture(scope="function")
14 | def estimator() -> CrossJoinItemEstimator:
15 | return CrossJoinItemEstimator(
16 | k=K,
17 | userKeyColumn='user_id',
18 | itemKeyColumn='item_id',
19 | seed=SEED
20 | )
21 |
22 |
23 | @pytest.fixture(scope="function")
24 | def transformer(
25 | estimator : CrossJoinItemEstimator,
26 | items_df : DataFrame
27 | ) -> CrossJoinItemTransformer:
28 | return estimator.fit(items_df)
29 |
30 |
31 | def test_crossjoinestimator_fit(
32 | transformer : CrossJoinItemTransformer,
33 | items_df : DataFrame
34 | ):
35 | assert transformer._item_df is not None
36 | assert transformer._item_df.toPandas().equals(items_df.toPandas())
37 |
38 |
39 | def test_crossjointransformer_transform(
40 | transformer : CrossJoinItemTransformer,
41 | users_df : DataFrame
42 | ):
43 | sample = transformer.transform(users_df).toPandas()
44 | sample = sample.sort_values(['user_id', 'item_id'])
45 |
46 | assert 'user_id' in sample.columns[0]
47 | assert 'item_id' in sample.columns[1]
48 | assert len(sample) == users_df.count() * K
49 | assert sample.iloc[K, 0] == 1
50 | assert sample.iloc[K, 1] == 0
51 |
52 |
53 | def test_crossjointransformer_iterdiff(
54 | transformer : CrossJoinItemTransformer,
55 | users_df : DataFrame
56 | ):
57 | sample_1 = transformer.transform(users_df).toPandas()
58 | sample_2 = transformer.transform(users_df).toPandas()
59 | sample_1 = sample_1.sort_values(['user_id', 'item_id'])
60 | sample_2 = sample_2.sort_values(['user_id', 'item_id'])
61 |
62 | assert not sample_1.equals(sample_2)
63 |
64 |
65 | def test_crossjointransformer_fixedseed(
66 | transformer : CrossJoinItemTransformer,
67 | users_df : DataFrame,
68 | items_df : DataFrame
69 | ):
70 | e = CrossJoinItemEstimator(
71 | k=K,
72 | userKeyColumn='user_id',
73 | itemKeyColumn='item_id',
74 | seed=SEED
75 | )
76 | t = e.fit(items_df)
77 |
78 | sample_1 = transformer.transform(users_df).toPandas()
79 | sample_2 = t.transform(users_df).toPandas()
80 | sample_1 = sample_1.sort_values(['user_id', 'item_id'])
81 | sample_2 = sample_2.sort_values(['user_id', 'item_id'])
82 |
83 | assert sample_1.equals(sample_2)
84 |
--------------------------------------------------------------------------------
/tests/test_simulator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import pytest
4 | import pyspark.sql.functions as sf
5 | from pyspark.sql import DataFrame, SparkSession
6 | from pyspark.ml import PipelineModel
7 | from pyspark.ml.feature import VectorAssembler
8 |
9 | from sim4rec.modules import (
10 | Simulator,
11 | RealDataGenerator,
12 | SDVDataGenerator,
13 | CompositeGenerator,
14 | CrossJoinItemEstimator,
15 | CrossJoinItemTransformer
16 | )
17 | from sim4rec.response import CosineSimilatiry
18 |
19 |
20 | SEED = 1234
21 |
22 |
23 | @pytest.fixture(scope="module")
24 | def real_users_gen(users_df : DataFrame) -> RealDataGenerator:
25 | gen = RealDataGenerator(label='real', seed=SEED)
26 | gen.fit(users_df)
27 | gen.generate(5)
28 |
29 | return gen
30 |
31 |
32 | @pytest.fixture(scope="module")
33 | def synth_users_gen(users_df : DataFrame) -> SDVDataGenerator:
34 | gen = SDVDataGenerator(
35 | label='synth',
36 | id_column_name='user_id',
37 | model_name='gaussiancopula',
38 | parallelization_level=2,
39 | device_name='cpu',
40 | seed=SEED
41 | )
42 | gen.fit(users_df)
43 | gen.generate(5)
44 |
45 | return gen
46 |
47 |
48 | @pytest.fixture(scope="module")
49 | def comp_users_gen(
50 | real_users_gen : RealDataGenerator,
51 | synth_users_gen : SDVDataGenerator
52 | ) -> CompositeGenerator:
53 | return CompositeGenerator(
54 | generators=[real_users_gen, synth_users_gen],
55 | label='composite',
56 | weights=[0.5, 0.5]
57 | )
58 |
59 |
60 | @pytest.fixture(scope="module")
61 | def real_items_gen(items_df : DataFrame) -> RealDataGenerator:
62 | gen = RealDataGenerator(label='real', seed=SEED)
63 | gen.fit(items_df)
64 | gen.generate(5)
65 |
66 | return gen
67 |
68 |
69 | @pytest.fixture(scope="module")
70 | def selector(items_df : DataFrame) -> CrossJoinItemTransformer:
71 | estimator = CrossJoinItemEstimator(
72 | k=3,
73 | userKeyColumn='user_id',
74 | itemKeyColumn='item_id',
75 | seed=SEED
76 | )
77 | return estimator.fit(items_df)
78 |
79 |
80 | @pytest.fixture(scope="module")
81 | def pipeline() -> PipelineModel:
82 | va_left = VectorAssembler(inputCols=['user_attr_1', 'user_attr_2'], outputCol='__v1')
83 | va_right = VectorAssembler(inputCols=['item_attr_1', 'item_attr_2'], outputCol='__v2')
84 |
85 | c = CosineSimilatiry(inputCols=['__v1', '__v2'], outputCol='response')
86 |
87 | return PipelineModel(stages=[va_left, va_right, c])
88 |
89 |
90 | @pytest.fixture(scope="function")
91 | def simulator_empty(
92 | comp_users_gen : CompositeGenerator,
93 | real_items_gen : RealDataGenerator,
94 | spark : SparkSession,
95 | tmp_path
96 | ) -> Simulator:
97 | shutil.rmtree(str(tmp_path / 'sim_empty'), ignore_errors=True)
98 | return Simulator(
99 | user_gen=comp_users_gen,
100 | item_gen=real_items_gen,
101 | log_df=None,
102 | user_key_col='user_id',
103 | item_key_col='item_id',
104 | data_dir=str(tmp_path / 'sim_empty'),
105 | spark_session=spark
106 | )
107 |
108 |
109 | @pytest.fixture(scope="function")
110 | def simulator_with_log(
111 | comp_users_gen : CompositeGenerator,
112 | real_items_gen : RealDataGenerator,
113 | log_df : DataFrame,
114 | spark : SparkSession,
115 | tmp_path
116 | ) -> Simulator:
117 | shutil.rmtree(str(tmp_path / 'sim_with_log'), ignore_errors=True)
118 | return Simulator(
119 | user_gen=comp_users_gen,
120 | item_gen=real_items_gen,
121 | log_df=log_df,
122 | user_key_col='user_id',
123 | item_key_col='item_id',
124 | data_dir=str(tmp_path / 'sim_with_log'),
125 | spark_session=spark
126 | )
127 |
128 |
129 | def test_simulator_init(
130 | simulator_empty : Simulator,
131 | simulator_with_log : Simulator
132 | ):
133 | assert os.path.isdir(simulator_empty._data_dir)
134 | assert os.path.isdir(simulator_with_log._data_dir)
135 |
136 | assert simulator_empty._log is None
137 | assert Simulator.ITER_COLUMN in simulator_with_log._log.columns
138 |
139 | assert simulator_with_log._log.count() == 5
140 | assert os.path.isdir(f'{simulator_with_log._data_dir}/{simulator_with_log.log_filename}/{Simulator.ITER_COLUMN}=start')
141 |
142 |
143 | def test_simulator_clearlog(
144 | simulator_with_log : Simulator
145 | ):
146 | simulator_with_log.clear_log()
147 |
148 | assert simulator_with_log.log is None
149 | assert simulator_with_log._log_schema is None
150 |
151 |
152 | def test_simulator_updatelog(
153 | simulator_empty : Simulator,
154 | simulator_with_log : Simulator,
155 | log_df : DataFrame
156 | ):
157 | simulator_empty.update_log(log_df, iteration=0)
158 | simulator_with_log.update_log(log_df, iteration=0)
159 |
160 | assert simulator_empty.log.count() == 5
161 | assert simulator_with_log.log.count() == 10
162 |
163 | assert set(simulator_empty.log.toPandas()[Simulator.ITER_COLUMN].unique()) == set([0])
164 | assert set(simulator_with_log.log.toPandas()[Simulator.ITER_COLUMN].unique()) == set(['0', 'start'])
165 |
166 | assert os.path.isdir(f'{simulator_empty._data_dir}/{simulator_empty.log_filename}/{Simulator.ITER_COLUMN}=0')
167 | assert os.path.isdir(f'{simulator_with_log._data_dir}/{simulator_with_log.log_filename}/{Simulator.ITER_COLUMN}=start')
168 | assert os.path.isdir(f'{simulator_with_log._data_dir}/{simulator_with_log.log_filename}/{Simulator.ITER_COLUMN}=0')
169 |
170 |
171 | def test_simulator_sampleusers(
172 | simulator_empty : Simulator
173 | ):
174 | sampled1 = simulator_empty.sample_users(0.5)\
175 | .toPandas().sort_values(['user_id'])
176 | sampled2 = simulator_empty.sample_users(0.5)\
177 | .toPandas().sort_values(['user_id'])
178 |
179 | assert not sampled1.equals(sampled2)
180 |
181 | assert len(sampled1) == 2
182 | assert len(sampled2) == 4
183 |
184 |
185 | def test_simulator_sampleitems(
186 | simulator_empty : Simulator
187 | ):
188 | sampled1 = simulator_empty.sample_items(0.5)\
189 | .toPandas().sort_values(['item_id'])
190 | sampled2 = simulator_empty.sample_items(0.5)\
191 | .toPandas().sort_values(['item_id'])
192 |
193 | assert not sampled1.equals(sampled2)
194 |
195 | assert len(sampled1) == 2
196 | assert len(sampled2) == 2
197 |
198 |
199 | def test_simulator_getuseritem(
200 | simulator_with_log : Simulator,
201 | selector : CrossJoinItemTransformer,
202 | users_df : DataFrame
203 | ):
204 | users = users_df.filter(sf.col('user_id').isin([0, 1, 2]))
205 | pairs, log = simulator_with_log.get_user_items(users, selector)
206 |
207 | assert pairs.count() == users.count() * selector._k
208 | assert log.count() == 5
209 |
210 | assert 'user_id' in pairs.columns
211 | assert 'item_id' in pairs.columns
212 |
213 | assert set(log.toPandas()['user_id']) == set([0, 1, 2])
214 |
215 |
216 | def test_simulator_responses(
217 | simulator_empty : Simulator,
218 | pipeline : PipelineModel,
219 | users_df : DataFrame,
220 | items_df : DataFrame,
221 | log_df : DataFrame
222 | ):
223 | resp = simulator_empty.sample_responses(
224 | recs_df=log_df,
225 | user_features=users_df,
226 | item_features=items_df,
227 | action_models=pipeline
228 | ).drop('__v1', '__v2').toPandas().sort_values(['user_id'])
229 |
230 | assert 'user_id' in resp.columns
231 | assert 'item_id' in resp.columns
232 | assert 'response' in resp.columns
233 | assert len(resp) == log_df.count()
234 | assert resp['response'].values[0] == 1
235 |
--------------------------------------------------------------------------------