├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── images ├── deep_hash_embeddings.png ├── memorization_vs_generalization.png ├── moe_cohort_plus_id_lookup.png ├── schematic_multi_task_estimator.png ├── user_feature_based_lookup.png ├── user_feature_based_lookup_2.png └── user_id_embedding_lookup.png ├── src ├── __init__.py ├── conftest.py ├── deep_hash_embeddings.py ├── multi_task_estimator.py ├── user_cohort_embeddings.py ├── user_mo_representations.py └── userid_lookup_representaion.py └── tests ├── conftest.py ├── test_dhe_rep.py ├── test_user_cohort_rep.py ├── test_user_mor.py └── test_userid_lookup_rep.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .DS_Store 6 | .vscode/ 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Gaurav Chakravorty 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multiple ways to model user preference in recommender systems 2 | 3 | Modeling the preference of the user as an input to the retrieval or ranking model has been a successful strategy in recommender systems. In this repo, we will show how to do it effectively. We will use the example of a ranking model but the approach equally applies to a retrieval model. 
The overall schematic of a ranking model is in ![Fig 0: schematic_multi_task_estimator](./images/schematic_multi_task_estimator.png) 4 | 5 | The conventional approach of representing a user is using an embedding table lookup as shown in the image below and implemented in [userid_lookup_representaion.py](./src/userid_lookup_representaion.py). 6 | ![Fig 1: user_id_embedding_lookup](./images/user_id_embedding_lookup.png) 7 | 8 | 9 | We will also look at the schematic of an implementation using Deep Hash Embeddings. 10 | ![Fig 2: deep_hash_embeddings](./images/deep_hash_embeddings.png) 11 | 12 | Then we will look at an approach where we reuse the machinery of Deep Hash Embeddings but seed it with an embedding that is looked up in a relatively small table as a function of the user's features (not including user id). 13 | ![Fig 3: user_feature_based_lookup](./images/user_feature_based_lookup_2.png) 14 | 15 | Finally we will put id lookup and cohort lookup together using an idea from [this paper from Google](https://arxiv.org/abs/2210.14309). This image from the paper captures the idea: 16 | ![Fig 4: Memorization vs Generalization](./images/memorization_vs_generalization.png) 17 | 18 | The implementation in our repository is: 19 | ![Fig 5: Mixture of Representations](./images/moe_cohort_plus_id_lookup.png) 20 | 21 | ### Customization 22 | If you want to allocate more of your memorization capacity to a certain cohort, for instance you could care more about US users, you could do that by encoding the weight in the loss function and perhaps adding the country / feature in the input to the Mixture of Representations tower. 23 | 24 | ## Contributing 25 | 26 | Run `pytest tests/*` from main directory before submitting a PR. 
27 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gauravchak/user_preference_modeling/fbe8b42275800be54b10681d7bfe944bd8b462c0/__init__.py -------------------------------------------------------------------------------- /images/deep_hash_embeddings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gauravchak/user_preference_modeling/fbe8b42275800be54b10681d7bfe944bd8b462c0/images/deep_hash_embeddings.png -------------------------------------------------------------------------------- /images/memorization_vs_generalization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gauravchak/user_preference_modeling/fbe8b42275800be54b10681d7bfe944bd8b462c0/images/memorization_vs_generalization.png -------------------------------------------------------------------------------- /images/moe_cohort_plus_id_lookup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gauravchak/user_preference_modeling/fbe8b42275800be54b10681d7bfe944bd8b462c0/images/moe_cohort_plus_id_lookup.png -------------------------------------------------------------------------------- /images/schematic_multi_task_estimator.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gauravchak/user_preference_modeling/fbe8b42275800be54b10681d7bfe944bd8b462c0/images/schematic_multi_task_estimator.png -------------------------------------------------------------------------------- /images/user_feature_based_lookup.png: -------------------------------------------------------------------------------- 
class DHERepresentation(MultiTaskEstimator):
    """MultiTaskEstimator whose user embedding uses Deep Hash Embeddings.

    Instead of looking the user id up in an embedding table, the id is
    encoded into a dense ``D_dhe_in``-dimensional vector by a bank of hash
    functions (``hash_fn``) and that vector is decoded into the
    ``DU``-dimensional user embedding by an MLP (``dhe_stack``).
    """

    # Mersenne prime used as the modulus of the universal hash family.
    # Ids are reduced modulo this prime before being multiplied, so every
    # intermediate value stays below 2**63 and cannot overflow int64.
    _DHE_HASH_PRIME: int = 2**31 - 1

    def __init__(
        self,
        num_tasks: int,
        user_id_embedding_dim: int,
        user_features_size: int,
        item_id_hash_size: int,
        item_id_embedding_dim: int,
        item_features_size: int,
        cross_features_size: int,
        user_value_weights: List[float],
        dhe_stack_in_embedding_dim: int,
        dhe_hash_buckets: int = 1_000_000,
        dhe_hash_seed: int = 0,
    ) -> None:
        """
        params:
            num_tasks (T): The tasks to compute estimates of
            user_id_embedding_dim (DU): internal dimension
            user_features_size (IU): input feature size for users
            item_id_hash_size: the size of the embedding table for items
            item_id_embedding_dim (DI): internal dimension
            item_features_size: (II) input feature size for items
            cross_features_size: (IC) size of cross features
            user_value_weights: T dimensional weights, such that a linear
                combination of point-wise immediate rewards is the best
                predictor of long term user satisfaction.
            dhe_stack_in_embedding_dim (D_dhe_in): input emb dim for DHE,
                i.e. the number of hash functions applied to each user id
            dhe_hash_buckets: number of buckets each hash function maps
                into before the value is rescaled to [-1, 1]
            dhe_hash_seed: seed used to draw the hash coefficients, so
                that two models built with the same seed encode ids
                identically

        raises:
            ValueError: if dhe_hash_buckets is smaller than 2
        """
        if dhe_hash_buckets < 2:
            raise ValueError("dhe_hash_buckets must be at least 2")
        super().__init__(
            num_tasks=num_tasks,
            user_id_embedding_dim=user_id_embedding_dim,
            user_features_size=user_features_size,
            item_id_hash_size=item_id_hash_size,
            item_id_embedding_dim=item_id_embedding_dim,
            item_features_size=item_features_size,
            cross_features_size=cross_features_size,
            user_value_weights=user_value_weights,
        )
        # In DHE paper this was more than DU
        # In Twitter, GNN they found something similar to work where
        # the final layer was about one third of the start input dim.
        self.dhe_stack_in: int = dhe_stack_in_embedding_dim
        self.dhe_hash_buckets: int = dhe_hash_buckets
        self.dhe_stack = nn.Sequential(
            nn.Linear(self.dhe_stack_in, user_id_embedding_dim),
            nn.ReLU(),
            nn.Linear(user_id_embedding_dim, user_id_embedding_dim),
            nn.ReLU(),
            nn.Linear(user_id_embedding_dim, user_id_embedding_dim),
            nn.ReLU(),
            nn.Linear(user_id_embedding_dim, user_id_embedding_dim),
        )

        # One (multiplier, offset) pair per hash function. A private
        # generator is used so the global RNG stream (and therefore the
        # initialisation of the layers above) is left untouched.
        generator = torch.Generator().manual_seed(dhe_hash_seed)
        multipliers = torch.randint(
            1,
            self._DHE_HASH_PRIME,
            (self.dhe_stack_in,),
            generator=generator,
            dtype=torch.int64,
        )
        offsets = torch.randint(
            0,
            self._DHE_HASH_PRIME,
            (self.dhe_stack_in,),
            generator=generator,
            dtype=torch.int64,
        )
        # Non-persistent buffers: they move with the module across
        # devices but do not add keys to state_dict, so checkpoints saved
        # before this change still load.
        self.register_buffer(
            "_dhe_hash_multipliers", multipliers, persistent=False
        )
        self.register_buffer("_dhe_hash_offsets", offsets, persistent=False)

    def hash_fn(
        self,
        user_id: torch.Tensor,  # [B]
    ) -> torch.Tensor:
        """
        Deterministically encode user ids into dense vectors.

        Each of the D_dhe_in hash functions has the form
        ``((a * id + b) mod p) mod m`` with p the Mersenne prime
        2**31 - 1 and m = dhe_hash_buckets; the bucket index is then
        rescaled linearly to [-1, 1]. The same id always produces the
        same encoding, which the previous ``torch.randn`` placeholder did
        not guarantee.

        Returns [B, self.dhe_stack_in]
        """
        prime = self._DHE_HASH_PRIME
        # Reduce first so that id * multiplier < 2**62 (no int64 overflow).
        # torch.remainder is non-negative for a positive modulus, so
        # negative ids are handled too.
        ids = torch.remainder(user_id.to(torch.int64), prime)
        ids = ids.unsqueeze(-1)  # [B, 1]
        mixed = torch.remainder(
            ids * self._dhe_hash_multipliers + self._dhe_hash_offsets,
            prime,
        )  # [B, D_dhe_in]
        buckets = torch.remainder(mixed, self.dhe_hash_buckets)
        encoding = buckets.to(torch.get_default_dtype())
        # Map bucket indices {0, ..., m - 1} onto [-1, 1].
        return encoding / (self.dhe_hash_buckets - 1) * 2.0 - 1.0

    def get_user_embedding(
        self,
        user_id: torch.Tensor,  # [B]
        user_features: torch.Tensor,  # [B, IU]
    ) -> torch.Tensor:
        """
        Decode the hashed user id into the user embedding.

        ``user_features`` is accepted for interface compatibility with
        the base class but is not used here.

        Returns: [B, user_id_embedding_dim]
        """
        user_hash = self.hash_fn(user_id)  # [B, D_dhe_in]
        user_id_embeddings = self.dhe_stack(user_hash)  # [B, DU]
        return user_id_embeddings
class MultiTaskEstimator(nn.Module):
    """A core component of multi-task ranking systems where
    we compute estimates of getting each binary feedback
    label from the user.

    The class is abstract with respect to the user representation:
    subclasses must implement ``get_user_embedding``.
    """

    def __init__(
        self,
        num_tasks: int,
        user_id_embedding_dim: int,
        user_features_size: int,
        item_id_hash_size: int,
        item_id_embedding_dim: int,
        item_features_size: int,
        cross_features_size: int,
        user_value_weights: List[float],
    ) -> None:
        """
        params:
            num_tasks (T): The tasks to compute estimates of
            user_id_embedding_dim (DU): internal dimension
            user_features_size (IU): input feature size for users
            item_id_hash_size: the size of the embedding table for items
            item_id_embedding_dim (DI): internal dimension
            item_features_size: (II) input feature size for items
            cross_features_size: (IC) size of cross features
            user_value_weights: T dimensional weights, such that a linear
                combination of point-wise immediate rewards is the best
                predictor of long term user satisfaction.
        """
        super().__init__()
        # Registered as a buffer so the weights follow the module through
        # .to(device) / .cuda(). persistent=False keeps them out of
        # state_dict, so existing checkpoints load unchanged.
        self.register_buffer(
            "user_value_weights",
            torch.tensor(user_value_weights),
            persistent=False,
        )
        self.user_id_embedding_dim = user_id_embedding_dim
        self.item_id_embedding_dim = item_id_embedding_dim

        # Embedding layers for item ids
        self.item_id_embedding_arch = nn.Embedding(
            item_id_hash_size, item_id_embedding_dim
        )

        # Linear projection layer for user features
        self.user_features_layer = nn.Linear(
            in_features=user_features_size,
            out_features=user_id_embedding_dim,
        )

        # Linear projection layer for item features
        self.item_features_layer = nn.Linear(
            in_features=item_features_size,
            out_features=item_id_embedding_dim,
        )

        self.cross_feature_proc_dim = 128
        # Linear projection layer for cross features
        self.cross_features_layer = nn.Linear(
            in_features=cross_features_size,
            out_features=self.cross_feature_proc_dim,
        )

        # Linear layer for final prediction. Its input is the
        # concatenation built in process_features: user embedding and
        # projected user features (2 * DU), item embedding and projected
        # item features (2 * DI), and the projected cross features.
        self.task_arch = nn.Linear(
            in_features=(
                2 * user_id_embedding_dim
                + 2 * item_id_embedding_dim
                + self.cross_feature_proc_dim
            ),
            out_features=num_tasks,
        )

    def get_user_embedding(
        self,
        user_id: torch.Tensor,  # [B]
        user_features: torch.Tensor,  # [B, IU]
    ) -> torch.Tensor:
        """Please implement in subclass.

        Expected to return a [B, DU] tensor.
        """
        raise NotImplementedError("Subclasses must implement get_user_embedding method")

    def process_features(
        self,
        user_id: torch.Tensor,  # [B]
        user_features: torch.Tensor,  # [B, IU]
        item_id: torch.Tensor,  # [B]
        item_features: torch.Tensor,  # [B, II]
        cross_features: torch.Tensor,  # [B, IC]
        position: torch.Tensor,  # [B]
    ) -> torch.Tensor:
        """
        Process features. Separated from forward function so that we can
        change handling of forward and train_forward in future without
        duplicating feature processing.

        ``position`` is accepted but not used by this implementation.

        Returns: [B, 2 * DU + 2 * DI + cross_feature_proc_dim]
        """

        # Get user embedding
        user_id_embedding = self.get_user_embedding(
            user_id=user_id,
            user_features=user_features,
        )
        # Embedding lookup for item ids
        item_id_embedding = self.item_id_embedding_arch(item_id)

        # Linear transformation for user features
        user_features_transformed = self.user_features_layer(user_features)

        # Linear transformation for item features
        item_features_transformed = self.item_features_layer(item_features)

        # Linear transformation for cross features
        cross_features_transformed = self.cross_features_layer(cross_features)

        # Concatenate user, item and cross representations
        combined_features = torch.cat(
            [
                user_id_embedding,
                user_features_transformed,
                item_id_embedding,
                item_features_transformed,
                cross_features_transformed,
            ],
            dim=1,
        )

        return combined_features

    def forward(
        self,
        user_id: torch.Tensor,  # [B]
        user_features: torch.Tensor,  # [B, IU]
        item_id: torch.Tensor,  # [B]
        item_features: torch.Tensor,  # [B, II]
        cross_features: torch.Tensor,  # [B, IC]
        position: torch.Tensor,  # [B]
    ) -> torch.Tensor:
        """Return the per-task logits, shape [B, T]."""
        combined_features = self.process_features(
            user_id=user_id,
            user_features=user_features,
            item_id=item_id,
            item_features=item_features,
            cross_features=cross_features,
            position=position,
        )
        # Compute per-task scores/logits
        ui_logits = self.task_arch(combined_features)  # [B, T]

        return ui_logits

    def train_forward(
        self,
        user_id: torch.Tensor,  # [B]
        user_features: torch.Tensor,  # [B, IU]
        item_id: torch.Tensor,  # [B]
        item_features: torch.Tensor,  # [B, II]
        cross_features: torch.Tensor,  # [B, IC]
        position: torch.Tensor,  # [B]
        labels: torch.Tensor,  # [B, T]
    ) -> torch.Tensor:
        """Compute the loss during training.

        Returns a scalar: binary cross-entropy between the task logits
        and ``labels``, summed over the batch and over tasks.
        """
        # Get task logits using forward method
        ui_logits = self.forward(
            user_id=user_id,
            user_features=user_features,
            item_id=item_id,
            item_features=item_features,
            cross_features=cross_features,
            position=position,
        )

        # Compute binary cross-entropy loss
        cross_entropy_loss = F.binary_cross_entropy_with_logits(
            input=ui_logits, target=labels.float(), reduction="sum"
        )

        return cross_entropy_loss
15 | """ 16 | 17 | def __init__( 18 | self, 19 | num_tasks: int, 20 | user_id_embedding_dim: int, 21 | user_features_size: int, 22 | item_id_hash_size: int, 23 | item_id_embedding_dim: int, 24 | item_features_size: int, 25 | cross_features_size: int, 26 | user_value_weights: List[float], 27 | cohort_table_size: int, 28 | cohort_lookup_dropout_rate: float, 29 | cohort_enable_topk_regularization: bool, 30 | ) -> None: 31 | super(UserCohortRepresentation, self).__init__( 32 | num_tasks=num_tasks, 33 | user_id_embedding_dim=user_id_embedding_dim, 34 | user_features_size=user_features_size, 35 | item_id_hash_size=item_id_hash_size, 36 | item_id_embedding_dim=item_id_embedding_dim, 37 | item_features_size=item_features_size, 38 | cross_features_size=cross_features_size, 39 | user_value_weights=user_value_weights, 40 | ) 41 | self.dhe_stack_in: int = user_id_embedding_dim 42 | self.dhe_stack = nn.Sequential( 43 | nn.Linear(self.dhe_stack_in, user_id_embedding_dim), 44 | nn.ReLU(), 45 | nn.Linear(user_id_embedding_dim, user_id_embedding_dim), 46 | nn.ReLU(), 47 | nn.Linear(user_id_embedding_dim, user_id_embedding_dim), 48 | nn.ReLU(), 49 | nn.Linear(user_id_embedding_dim, user_id_embedding_dim), 50 | ) 51 | # Initialize the cohort embedding matrix with random values 52 | self.cohort_embedding_matrix = nn.Parameter( 53 | torch.randn(cohort_table_size, self.dhe_stack_in) 54 | ) # [cohort_table_size, self.dhe_stack_in] 55 | self.cohort_enable_topk_regularization: bool = cohort_enable_topk_regularization 56 | self.topk: int = 1 57 | # Get cohort addressing from user features. 58 | # Input: [B, user features] 59 | # Output: [B, cohort_table_size] 60 | self.cohort_addressing_layer = nn.Sequential( 61 | nn.Linear(in_features=user_features_size, out_features=cohort_table_size), 62 | # Adding dropout could increase generalization 63 | # by allowing multiple clusters in the table to learn 64 | # from the behavior of a user. 
def compute_cohort_affinity(
    self,
    user_features: torch.Tensor,  # [B, IU]
) -> torch.Tensor:
    """
    Compute how strongly each user is assigned to each cohort.

    Without top-k regularization the result is the softmax of the
    addressing logits. With top-k regularization the result is a hard
    assignment: 1 / self.topk on the top-k cohorts and 0 elsewhere, so
    each row still sums to 1.

    The hard assignment is built out-of-place and combined with the
    softmax through a straight-through estimator. The forward values
    are exactly the hard assignment, while gradients flow through the
    softmax so that the addressing layer can learn. Overwriting the
    logits in place with constants, as was done before, gives the
    addressing layer a zero gradient.

    Returns: [B, H] where H is the cohort table size
    """
    # Pass user features through the cohort addressing layer
    cohort_logits = self.cohort_addressing_layer(user_features)  # [B, H]
    soft_affinity = F.softmax(input=cohort_logits, dim=-1)  # [B, H]
    if not self.cohort_enable_topk_regularization:
        return soft_affinity

    # Indices of the top-k cohorts for every user
    _, topk_indices = torch.topk(cohort_logits, k=self.topk, dim=1)

    # Fresh tensor holding 1/topk at the selected cohorts, 0 elsewhere.
    # Note that if topk > 1 the sum of values along dim=1 is still 1.
    hard_affinity = torch.zeros_like(cohort_logits).scatter_(
        dim=1, index=topk_indices, value=(1 / self.topk)
    )

    # Straight-through: (soft - soft.detach()) is exactly zero in the
    # forward pass but carries the softmax gradient in the backward pass.
    return hard_affinity + (soft_affinity - soft_affinity.detach())
UserMORepresentations(UserCohortRepresentation): 11 | """ 12 | This is the kicthen soup idea that has proliferated in ML recently. 13 | Here we use both 14 | - the table lookup approach of how user id embeddings 15 | are implemented in UseridLookupRepresentation. 16 | - the cohort/cluster embeddings of UserCohortRepresentation 17 | Then we mix it up using a [(k * emb_lookup) + ((1 - k) * emb_cohort)] 18 | approach that is popularly known as Mixture of Experts, where k is 19 | a function of user features. 20 | If you are implementing this, I suggest you to verify that k is higher 21 | for power users and lower for cold start users/marginal users. You can 22 | see an example of this in paper https://arxiv.org/abs/2210.14309. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | num_tasks: int, 28 | user_id_embedding_dim: int, 29 | user_features_size: int, 30 | item_id_hash_size: int, 31 | item_id_embedding_dim: int, 32 | item_features_size: int, 33 | cross_features_size: int, 34 | user_value_weights: List[float], 35 | cohort_table_size: int, 36 | cohort_lookup_dropout_rate: float, 37 | cohort_enable_topk_regularization: bool, 38 | user_id_hash_size: int, 39 | ) -> None: 40 | super(UserMORepresentations, self).__init__( 41 | num_tasks=num_tasks, 42 | user_id_embedding_dim=user_id_embedding_dim, 43 | user_features_size=user_features_size, 44 | item_id_hash_size=item_id_hash_size, 45 | item_id_embedding_dim=item_id_embedding_dim, 46 | item_features_size=item_features_size, 47 | cross_features_size=cross_features_size, 48 | user_value_weights=user_value_weights, 49 | cohort_table_size=cohort_table_size, 50 | cohort_lookup_dropout_rate=cohort_lookup_dropout_rate, 51 | cohort_enable_topk_regularization=cohort_enable_topk_regularization, 52 | ) 53 | # Setup id lookup 54 | self.user_id_embedding_arch = nn.Embedding( 55 | user_id_hash_size, user_id_embedding_dim 56 | ) 57 | # Setup Mixture network 58 | self.mixture_layer = nn.Sequential( 59 | nn.Linear(user_features_size, 2), 60 | 
def get_user_embedding(
    self,
    user_id: torch.Tensor,  # [B]
    user_features: torch.Tensor,  # [B, IU]
) -> torch.Tensor:
    """
    Blend the id-lookup embedding with the cohort embedding.

    The two candidate representations are weighted per user by a
    softmax over the two outputs of ``mixture_layer`` applied to the
    user features.

    Returns: [B, user_id_embedding_dim]
    """
    # Candidate 1: table lookup keyed on the user id
    memorized = self.user_id_embedding_arch(user_id)  # [B, DU]
    # Candidate 2: cohort-based representation from the parent class
    generalized = super().get_user_embedding(
        user_features=user_features, user_id=user_id
    )  # [B, DU]

    # Per-user mixing weights over the two candidates
    gate = F.softmax(self.mixture_layer(user_features), dim=1)  # [B, 2]

    candidates = torch.stack((memorized, generalized), dim=2)  # [B, DU, 2]
    # [B, DU, 2] x [B, 2, 1] -> [B, DU, 1]
    blended = torch.bmm(candidates, gate.unsqueeze(2))
    return blended.squeeze(-1)  # [B, DU]
class TestDHERepresentation(unittest.TestCase):
    """Smoke test for the DHERepresentation based estimator."""

    def test_dhe_rep(self):
        """Run one training pass and one inference pass; check the outputs."""
        batch_size = 3
        num_tasks = 3
        # User, item and cross features all share this input width.
        feature_width = 10

        # unused in the baseline MultiTaskEstimator implementation
        user_value_weights = [0.5, 0.3, 0.2]
        assert len(user_value_weights) == num_tasks

        # Instantiate DHERepresentation based estimator
        model_config = dict(
            num_tasks=num_tasks,
            user_id_embedding_dim=50,
            user_features_size=feature_width,
            item_id_hash_size=200,
            item_id_embedding_dim=30,
            item_features_size=feature_width,
            cross_features_size=feature_width,
            user_value_weights=user_value_weights,
            dhe_stack_in_embedding_dim=60,
        )
        model: DHERepresentation = DHERepresentation(**model_config)

        # Example input data
        user_id = torch.tensor([1, 2, 3])
        user_features = torch.randn(batch_size, feature_width)
        item_id = torch.tensor([4, 5, 6])
        item_features = torch.randn(batch_size, feature_width)
        cross_features = torch.randn(batch_size, feature_width)
        train_position = torch.tensor([1, 2, 3], dtype=torch.int32)
        labels = torch.randint(2, size=(batch_size, num_tasks))

        # Inputs shared by the training and the inference call
        shared_inputs = (
            user_id,
            user_features,
            item_id,
            item_features,
            cross_features,
        )

        # Example train_forward pass
        loss = model.train_forward(*shared_inputs, train_position, labels)
        self.assertIsInstance(loss, torch.Tensor)
        self.assertGreaterEqual(loss.item(), 0)

        # Example forward pass
        inference_position = torch.zeros(batch_size, dtype=torch.int32)
        output = model(*shared_inputs, inference_position)
        self.assertIsInstance(output, torch.Tensor)
        self.assertEqual(output.shape, (batch_size, num_tasks))
class TestUserCohortRepresentation(unittest.TestCase):
    """Exercise UserCohortRepresentation with top-k regularization off and on.

    Both cases were originally defined under the same method name,
    ``test_user_cohort_rep_reg``. The later definition replaced the earlier
    one in the class namespace, so the case that disables regularization was
    never collected or run. Each case now has its own name, and the shared
    body lives in a single private helper.
    """

    def _check_train_and_inference(
        self, cohort_enable_topk_regularization: bool
    ) -> None:
        """Build the estimator and check one training and one inference pass.

        Args:
            cohort_enable_topk_regularization: value forwarded to the
                constructor argument of the same name.
        """
        num_tasks = 3
        user_id_embedding_dim = 50
        user_features_size = 10
        item_id_hash_size = 200
        item_id_embedding_dim = 30
        item_features_size = 10
        cross_features_size = 10
        cohort_table_size = 256
        cohort_lookup_dropout_rate = 0.5
        batch_size = 3

        # unused in the baseline MultiTaskEstimator implementation
        user_value_weights = [0.5, 0.3, 0.2]
        assert len(user_value_weights) == num_tasks

        # Instantiate UserCohortRepresentation based estimator
        model: UserCohortRepresentation = UserCohortRepresentation(
            num_tasks=num_tasks,
            user_id_embedding_dim=user_id_embedding_dim,
            user_features_size=user_features_size,
            item_id_hash_size=item_id_hash_size,
            item_id_embedding_dim=item_id_embedding_dim,
            item_features_size=item_features_size,
            cross_features_size=cross_features_size,
            user_value_weights=user_value_weights,
            cohort_table_size=cohort_table_size,
            cohort_lookup_dropout_rate=cohort_lookup_dropout_rate,
            cohort_enable_topk_regularization=cohort_enable_topk_regularization,
        )

        # Example input data
        user_id = torch.tensor([1, 2, 3])
        user_features = torch.randn(batch_size, user_features_size)
        item_id = torch.tensor([4, 5, 6])
        item_features = torch.randn(batch_size, item_features_size)
        cross_features = torch.randn(batch_size, cross_features_size)
        position = torch.tensor([1, 2, 3], dtype=torch.int32)
        labels = torch.randint(2, size=(batch_size, num_tasks))

        # Example train_forward pass
        loss = model.train_forward(
            user_id, user_features,
            item_id, item_features,
            cross_features, position,
            labels
        )
        self.assertIsInstance(loss, torch.Tensor)
        self.assertGreaterEqual(loss.item(), 0)

        # Example forward pass
        inference_position = torch.zeros(batch_size, dtype=torch.int32)
        output = model(
            user_id, user_features,
            item_id, item_features,
            cross_features, inference_position
        )
        self.assertIsInstance(output, torch.Tensor)
        self.assertEqual(output.shape, (batch_size, num_tasks))

    def test_user_cohort_rep_no_reg(self):
        """Train and inference passes with top-k regularization disabled."""
        self._check_train_and_inference(
            cohort_enable_topk_regularization=False
        )

    def test_user_cohort_rep_reg(self):
        """Train and inference passes with top-k regularization enabled."""
        self._check_train_and_inference(
            cohort_enable_topk_regularization=True
        )
def test_user_cohort_rep_reg(self):
    """Train and inference passes with top-k regularization enabled.

    NOTE(review): the name says "cohort" although this method belongs to the
    UserMORepresentations test class; it looks copied from the cohort tests.
    The name is kept so existing test ids stay the same.
    """
    batch_size = 3
    num_tasks = 3

    # unused in the baseline MultiTaskEstimator implementation
    user_value_weights = [0.5, 0.3, 0.2]
    assert len(user_value_weights) == num_tasks

    # Constructor arguments gathered in one place.
    config = {
        "num_tasks": num_tasks,
        "user_id_embedding_dim": 50,
        "user_features_size": 10,
        "item_id_hash_size": 200,
        "item_id_embedding_dim": 30,
        "item_features_size": 10,
        "cross_features_size": 10,
        "user_value_weights": user_value_weights,
        "cohort_table_size": 256,
        "cohort_lookup_dropout_rate": 0.5,
        "cohort_enable_topk_regularization": True,
        "user_id_hash_size": 1024,
    }

    # Instantiate UserMORepresentations based estimator
    estimator: UserMORepresentations = UserMORepresentations(**config)

    # Example input data
    user_id = torch.tensor([1, 2, 3])
    user_features = torch.randn(batch_size, config["user_features_size"])
    item_id = torch.tensor([4, 5, 6])
    item_features = torch.randn(batch_size, config["item_features_size"])
    cross_features = torch.randn(batch_size, config["cross_features_size"])
    position = torch.tensor([1, 2, 3], dtype=torch.int32)
    labels = torch.randint(2, size=(batch_size, num_tasks))

    # The first five inputs are identical for training and inference.
    shared_inputs = (
        user_id,
        user_features,
        item_id,
        item_features,
        cross_features,
    )

    # Example train_forward pass
    loss = estimator.train_forward(*shared_inputs, position, labels)
    self.assertIsInstance(loss, torch.Tensor)
    self.assertGreaterEqual(loss.item(), 0)

    # Example forward pass
    inference_position = torch.zeros(batch_size, dtype=torch.int32)
    output = estimator(*shared_inputs, inference_position)
    self.assertIsInstance(output, torch.Tensor)
    self.assertEqual(output.shape, (batch_size, num_tasks))
class TestUseridLookupRepresentation(unittest.TestCase):
    """Smoke test for the UseridLookupRepresentation based estimator."""

    def test_user_id_lookup_rep(self):
        """Run one train_forward call and one inference call on a batch of 3.

        Checks that train_forward returns a tensor whose scalar value is
        non-negative, and that calling the model returns a tensor of shape
        (batch_size, num_tasks).
        """
        batch_size = 3
        num_tasks = 3

        # Feature widths are used both by the constructor and by the inputs.
        user_features_size = 10
        item_features_size = 10
        cross_features_size = 10

        # unused in the baseline MultiTaskEstimator implementation
        user_value_weights = [0.5, 0.3, 0.2]
        assert len(user_value_weights) == num_tasks

        # Instantiate UseridLookupRepresentation based estimator.
        # Arguments are passed positionally, in the original order.
        estimator: UseridLookupRepresentation = UseridLookupRepresentation(
            num_tasks,
            100,  # user_id_hash_size
            50,  # user_id_embedding_dim
            user_features_size,
            200,  # item_id_hash_size
            30,  # item_id_embedding_dim
            item_features_size,
            cross_features_size,
            user_value_weights,
        )

        # Example input data
        user_id = torch.tensor([1, 2, 3])
        user_features = torch.randn(batch_size, user_features_size)
        item_id = torch.tensor([4, 5, 6])
        item_features = torch.randn(batch_size, item_features_size)
        cross_features = torch.randn(batch_size, cross_features_size)
        position = torch.tensor([1, 2, 3], dtype=torch.int32)
        labels = torch.randint(2, size=(batch_size, num_tasks))

        # The first five inputs are identical for training and inference.
        shared_inputs = (
            user_id,
            user_features,
            item_id,
            item_features,
            cross_features,
        )

        # Example train_forward pass
        loss = estimator.train_forward(*shared_inputs, position, labels)
        self.assertIsInstance(loss, torch.Tensor)
        self.assertGreaterEqual(loss.item(), 0)

        # Example forward pass
        inference_position = torch.zeros(batch_size, dtype=torch.int32)
        output = estimator(*shared_inputs, inference_position)
        self.assertIsInstance(output, torch.Tensor)
        self.assertEqual(output.shape, (batch_size, num_tasks))