├── generative_recommenders ├── indexing │ ├── __init__.py │ ├── utils.py │ ├── mips_top_k.py │ └── candidate_index.py ├── modeling │ ├── __init__.py │ ├── sequential │ │ ├── __init__.py │ │ ├── output_postprocessors.py │ │ ├── features.py │ │ ├── embedding_modules.py │ │ ├── utils.py │ │ ├── encoder_utils.py │ │ ├── input_features_preprocessors.py │ │ ├── sasrec.py │ │ └── autoregressive_losses.py │ ├── initialization.py │ ├── ndp_module.py │ ├── similarity │ │ ├── dot_product.py │ │ └── mol.py │ ├── similarity_module.py │ └── similarity_utils.py ├── data │ ├── item_features.py │ ├── reco_dataset.py │ ├── dataset.py │ ├── eval.py │ └── preprocessor.py └── trainer │ ├── data_loader.py │ └── train.py ├── assets └── 1742544226938.png ├── preprocess_public_data.py ├── configs ├── ml-1m │ ├── fuxi-sampled-softmax-n128-final.gin │ ├── fuxi-sampled-softmax-n128-large-final.gin │ ├── sasrec-sampled-softmax-n128-final.gin │ ├── hstu-sampled-softmax-n128-large-final.gin │ └── hstu-sampled-softmax-n128-final.gin └── ml-20m │ ├── fuxi-sampled-softmax-n128-final.gin │ ├── fuxi-sampled-softmax-n128-large-final.gin │ ├── sasrec-sampled-softmax-n128-final.gin │ ├── hstu-sampled-softmax-n128-final.gin │ └── hstu-sampled-softmax-n128-large-final.gin ├── requirements.txt ├── main.py ├── README.md └── .gitignore /generative_recommenders/indexing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /generative_recommenders/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /generative_recommenders/modeling/sequential/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/1742544226938.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/USTC-StarTeam/FuXi-alpha/HEAD/assets/1742544226938.png -------------------------------------------------------------------------------- /generative_recommenders/data/item_features.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyre-unsafe 16 | 17 | from dataclasses import dataclass 18 | from typing import List 19 | 20 | import torch 21 | 22 | 23 | @dataclass 24 | class ItemFeatures: 25 | num_items: int 26 | max_jagged_dimension: int 27 | max_ind_range: List[int] # [(,)] x num_features 28 | lengths: List[torch.Tensor] # [(num_items,)] x num_features 29 | values: List[torch.Tensor] # [(num_items, max_jagged_dimension)] x num_features 30 | -------------------------------------------------------------------------------- /preprocess_public_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyre-unsafe 16 | 17 | """ 18 | Usage: mkdir -p tmp/ && python3 preprocess_public_data.py 19 | """ 20 | from generative_recommenders.data.preprocessor import get_common_preprocessors 21 | 22 | 23 | def main() -> None: 24 | get_common_preprocessors()["ml-1m"].preprocess_rating() 25 | get_common_preprocessors()["ml-20m"].preprocess_rating() 26 | # get_common_preprocessors()["ml-1b"].preprocess_rating() 27 | # get_common_preprocessors()["amzn-books"].preprocess_rating() 28 | 29 | 30 | if __name__ == "__main__": 31 | main() 32 | -------------------------------------------------------------------------------- /configs/ml-1m/fuxi-sampled-softmax-n128-final.gin: -------------------------------------------------------------------------------- 1 | # Run this as: 2 | # mkdir -p logs/ml-1m-l200/ 3 | # CUDA_VISIBLE_DEVICES=0 python3 main.py --gin_config_file=configs/ml-1m/fuxi-sampled-softmax-n128-final.gin --master_port=12345 2>&1 | tee logs/ml-1m-l200/fuxi-sampled-softmax-n128-final.log 4 | 5 | train_fn.dataset_name = "ml-1m" 6 | train_fn.max_sequence_length = 200 7 | train_fn.local_batch_size = 128 8 | 9 | train_fn.main_module = "FuXi" 10 | train_fn.dropout_rate = 0.2 11 | train_fn.user_embedding_norm = "l2_norm" 12 | train_fn.num_epochs = 101 13 | train_fn.item_embedding_dim = 50 14 | 15 | fuxi_encoder.num_blocks = 2 16 | fuxi_encoder.num_heads = 1 17 | fuxi_encoder.dqk = 50 18 | fuxi_encoder.dv = 50 19 | fuxi_encoder.linear_dropout_rate = 0.2 20 | 21 | train_fn.learning_rate = 1e-3 22 | train_fn.weight_decay = 0 23 | train_fn.num_warmup_steps = 0 24 | 25 | train_fn.interaction_module_type = "DotProduct" 26 | train_fn.top_k_method = "MIPSBruteForceTopK" 27 | 28 | train_fn.loss_module = "SampledSoftmaxLoss" 29 | train_fn.num_negatives = 128 30 | 31 | train_fn.sampling_strategy = "local" 32 | train_fn.temperature = 0.05 33 | train_fn.item_l2_norm = True 34 | train_fn.l2_norm_eps = 1e-6 35 | 36 | train_fn.enable_tf32 = True 37 | 38 | create_data_loader.prefetch_factor = 128 39 | create_data_loader.num_workers = 8 40 | -------------------------------------------------------------------------------- /configs/ml-1m/fuxi-sampled-softmax-n128-large-final.gin: -------------------------------------------------------------------------------- 1 | # Run this as: 2 | # mkdir -p logs/ml-1m-l200/ 3 | # CUDA_VISIBLE_DEVICES=0 python3 main.py --gin_config_file=configs/ml-1m/fuxi-sampled-softmax-n128-large-final.gin --master_port=12345 2>&1 | tee logs/ml-1m-l200/fuxi-sampled-softmax-n128-large-final.log 4 | 5 | train_fn.dataset_name = "ml-1m" 6 | train_fn.max_sequence_length = 200 7 | train_fn.local_batch_size = 128 8 | 9 | train_fn.main_module = "FuXi" 10 | train_fn.dropout_rate = 0.2 11 | train_fn.user_embedding_norm = "l2_norm" 12 | train_fn.num_epochs = 101 13 | train_fn.item_embedding_dim = 50 14 | 15 | fuxi_encoder.num_blocks = 8 16 | fuxi_encoder.num_heads = 2 17 | fuxi_encoder.dqk = 25 18 | fuxi_encoder.dv = 25 19 | fuxi_encoder.linear_dropout_rate = 0.2 20 | 21 | train_fn.learning_rate = 1e-3 22 | train_fn.weight_decay = 0 23 | train_fn.num_warmup_steps = 0 24 | 25 | train_fn.interaction_module_type = "DotProduct" 26 | train_fn.top_k_method = "MIPSBruteForceTopK" 27 | 28 | train_fn.loss_module = "SampledSoftmaxLoss" 29 | train_fn.num_negatives = 128 30 | 31 | train_fn.sampling_strategy = "local" 32 | train_fn.temperature = 0.05 33 | train_fn.item_l2_norm = True 34 | train_fn.l2_norm_eps = 1e-6 35 | 36 | train_fn.enable_tf32 = True 37 | 38 | create_data_loader.prefetch_factor = 128 39 | create_data_loader.num_workers = 8 40 | -------------------------------------------------------------------------------- /configs/ml-20m/fuxi-sampled-softmax-n128-final.gin: -------------------------------------------------------------------------------- 1 | # Run this as: 2 | # mkdir -p logs/ml-20m-l200/ 3 | # CUDA_VISIBLE_DEVICES=0 python3 main.py --gin_config_file=configs/ml-20m/fuxi-sampled-softmax-n128-final.gin --master_port=12345 2>&1 | tee logs/ml-20m-l200/fuxi-sampled-softmax-n128-final.log 4 | 5 | train_fn.dataset_name = "ml-20m" 6 | train_fn.max_sequence_length = 200 7 | train_fn.local_batch_size = 128 8 | 9 | train_fn.main_module = "FuXi" 10 | train_fn.dropout_rate = 0.2 11 | train_fn.user_embedding_norm = "l2_norm" 12 | train_fn.num_epochs = 101 13 | train_fn.item_embedding_dim = 256 14 | 15 | fuxi_encoder.num_blocks = 2 16 | fuxi_encoder.num_heads = 4 17 | fuxi_encoder.dv = 64 18 | fuxi_encoder.dqk = 64 19 | fuxi_encoder.linear_dropout_rate = 0.2 20 | fuxi_encoder.ffn_multiply = 4 21 | 22 | train_fn.learning_rate = 1e-3 23 | train_fn.weight_decay = 0 24 | train_fn.num_warmup_steps = 0 25 | 26 | train_fn.interaction_module_type = "DotProduct" 27 | train_fn.top_k_method = "MIPSBruteForceTopK" 28 | 29 | train_fn.loss_module = "SampledSoftmaxLoss" 30 | train_fn.num_negatives = 128 31 | 32 | train_fn.sampling_strategy = "local" 33 | train_fn.temperature = 0.05 34 | train_fn.item_l2_norm = True 35 | train_fn.l2_norm_eps = 1e-6 36 | 37 | train_fn.enable_tf32 = True 38 | 39 | create_data_loader.prefetch_factor = 128 40 | create_data_loader.num_workers = 8 41 | -------------------------------------------------------------------------------- /generative_recommenders/indexing/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyre-unsafe 16 | 17 | import torch 18 | 19 | from generative_recommenders.indexing.candidate_index import CandidateIndex, TopKModule 20 | from generative_recommenders.indexing.mips_top_k import MIPSBruteForceTopK 21 | 22 | 23 | def get_top_k_module( 24 | top_k_method: str, 25 | model: torch.nn.Module, 26 | item_embeddings: torch.Tensor, 27 | item_ids: torch.Tensor, 28 | ) -> TopKModule: 29 | if top_k_method == "MIPSBruteForceTopK": 30 | top_k_module = MIPSBruteForceTopK( 31 | item_embeddings=item_embeddings, 32 | item_ids=item_ids, 33 | ) 34 | else: 35 | raise ValueError(f"Invalid top-k method {top_k_method}") 36 | return top_k_module 37 | -------------------------------------------------------------------------------- /generative_recommenders/modeling/initialization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyre-unsafe 16 | 17 | import torch 18 | 19 | 20 | def truncated_normal(x: torch.Tensor, mean: float, std: float) -> torch.Tensor: 21 | with torch.no_grad(): 22 | size = x.shape 23 | tmp = x.new_empty(size + (4,)).normal_() 24 | valid = (tmp < 2) & (tmp > -2) 25 | ind = valid.max(-1, keepdim=True)[1] 26 | x.data.copy_(tmp.gather(-1, ind).squeeze(-1)) 27 | x.data.mul_(std).add_(mean) 28 | return x 29 | 30 | 31 | def init_mlp_xavier_weights_zero_bias(m: torch.nn.Module) -> None: 32 | if isinstance(m, torch.nn.Linear): 33 | torch.nn.init.xavier_uniform(m.weight) 34 | if getattr(m, "bias", None) is not None: 35 | m.bias.data.fill_(0.0) 36 | -------------------------------------------------------------------------------- /generative_recommenders/modeling/ndp_module.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyre-unsafe 16 | 17 | import abc 18 | from typing import Dict, Optional, Tuple 19 | 20 | import torch 21 | 22 | 23 | class NDPModule(torch.nn.Module): 24 | 25 | def forward( # pyre-ignore[3] 26 | self, 27 | input_embeddings: torch.Tensor, 28 | item_embeddings: torch.Tensor, 29 | item_sideinfo: Optional[torch.Tensor], 30 | item_ids: torch.Tensor, 31 | precomputed_logits: Optional[torch.Tensor] = None, 32 | ): 33 | """ 34 | Args: 35 | input_embeddings: (B, input_embedding_dim) x float 36 | item_embeddings: (1/B, X, item_embedding_dim) x float 37 | item_sideinfo: (1/B, X, item_sideinfo_dim) x float 38 | 39 | Returns: 40 | Tuple of (B, X,) similarity values, keyed outputs 41 | """ 42 | pass 43 | -------------------------------------------------------------------------------- /configs/ml-20m/fuxi-sampled-softmax-n128-large-final.gin: -------------------------------------------------------------------------------- 1 | # Run this as: 2 | # mkdir -p logs/ml-20m-l200/ 3 | # CUDA_VISIBLE_DEVICES=0 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python3 main.py --gin_config_file=configs/ml-20m/fuxi-sampled-softmax-n128-large-final.gin --master_port=12345 2>&1 | tee logs/ml-20m-l200/fuxi-sampled-softmax-n128-large-final.log 4 | # CUDA_VISIBLE_DEVICES=0 python3 main.py --gin_config_file=configs/ml-20m/fuxi-sampled-softmax-n128-large-final.gin --master_port=12345 5 | 6 | train_fn.dataset_name = "ml-20m" 7 | train_fn.max_sequence_length = 200 8 | train_fn.local_batch_size = 128 9 | 10 | train_fn.main_module = "FuXi" 11 | train_fn.dropout_rate = 0.2 12 | train_fn.user_embedding_norm = "l2_norm" 13 | train_fn.num_epochs = 101 14 | train_fn.item_embedding_dim = 256 15 | 16 | fuxi_encoder.num_blocks = 8 17 | fuxi_encoder.num_heads = 8 18 | fuxi_encoder.dv = 32 19 | fuxi_encoder.dqk = 32 20 | fuxi_encoder.linear_dropout_rate = 0.2 21 | fuxi_encoder.ffn_multiply = 4 22 | 23 | train_fn.learning_rate = 1e-3 24 | train_fn.weight_decay = 0 25 | train_fn.num_warmup_steps = 0 26 | 27 | train_fn.interaction_module_type = "DotProduct" 28 | train_fn.top_k_method = "MIPSBruteForceTopK" 29 | 30 | train_fn.loss_module = "SampledSoftmaxLoss" 31 | train_fn.num_negatives = 128 32 | 33 | train_fn.sampling_strategy = "local" 34 | train_fn.temperature = 0.05 35 | train_fn.item_l2_norm = True 36 | train_fn.l2_norm_eps = 1e-6 37 | 38 | train_fn.enable_tf32 = True 39 | 40 | create_data_loader.prefetch_factor = 128 41 | create_data_loader.num_workers = 8 42 | -------------------------------------------------------------------------------- /configs/ml-1m/sasrec-sampled-softmax-n128-final.gin: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Frozen config, validated on 04/11/2024. 4 | # Based on baseline settings in Revisiting Neural Retrieval on Accelerators (https://arxiv.org/abs/2306.04039, KDD'23). 5 | # 6 | # Run this as: 7 | # mkdir -p logs/ml-1m-l200/ 8 | # CUDA_VISIBLE_DEVICES=0 python3 main.py --gin_config_file=configs/ml-1m/sasrec-sampled-softmax-n128-final.gin --master_port=12345 2>&1 | tee logs/ml-1m-l200/sasrec-sampled-softmax-n128-final.log 9 | 10 | train_fn.dataset_name = "ml-1m" 11 | train_fn.max_sequence_length = 200 12 | train_fn.local_batch_size = 128 13 | 14 | train_fn.main_module = "SASRec" 15 | train_fn.dropout_rate = 0.2 16 | train_fn.user_embedding_norm = "l2_norm" 17 | train_fn.num_epochs = 101 18 | train_fn.item_embedding_dim = 50 19 | 20 | sasrec_encoder.num_blocks = 2 21 | sasrec_encoder.num_heads = 1 22 | sasrec_encoder.ffn_dropout_rate = 0.2 23 | sasrec_encoder.ffn_hidden_dim = 50 24 | sasrec_encoder.ffn_activation_fn = "relu" 25 | 26 | train_fn.learning_rate = 1e-3 27 | train_fn.weight_decay = 0 28 | train_fn.num_warmup_steps = 0 29 | 30 | train_fn.top_k_method = "MIPSBruteForceTopK" 31 | train_fn.interaction_module_type = "DotProduct" 32 | 33 | train_fn.loss_module = "SampledSoftmaxLoss" 34 | train_fn.num_negatives = 128 35 | 36 | train_fn.sampling_strategy = "local" 37 | train_fn.temperature = 0.05 38 | train_fn.item_l2_norm = True 39 | train_fn.l2_norm_eps = 1e-6 40 | 41 | train_fn.enable_tf32 = True 42 | 43 | create_data_loader.prefetch_factor = 128 44 | create_data_loader.num_workers = 8 45 | -------------------------------------------------------------------------------- /configs/ml-20m/sasrec-sampled-softmax-n128-final.gin: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Frozen config, validated on 04/12/2024. 4 | # Based on baseline settings in Revisiting Neural Retrieval on Accelerators (https://arxiv.org/abs/2306.04039, KDD'23). 5 | # 6 | # Run this as: 7 | # mkdir -p logs/ml-20m-l200/ 8 | # CUDA_VISIBLE_DEVICES=0 python3 main.py --gin_config_file=configs/ml-20m/sasrec-sampled-softmax-n128-final.gin --master_port=12345 2>&1 | tee logs/ml-20m-l200/sasrec-sampled-softmax-n128-final.log 9 | 10 | train_fn.dataset_name = "ml-20m" 11 | train_fn.max_sequence_length = 200 12 | train_fn.local_batch_size = 128 13 | 14 | train_fn.main_module = "SASRec" 15 | train_fn.dropout_rate = 0.2 16 | train_fn.user_embedding_norm = "l2_norm" 17 | train_fn.num_epochs = 101 18 | train_fn.item_embedding_dim = 256 19 | 20 | sasrec_encoder.num_blocks = 4 21 | sasrec_encoder.num_heads = 4 22 | sasrec_encoder.ffn_dropout_rate = 0.2 23 | sasrec_encoder.ffn_hidden_dim = 256 24 | sasrec_encoder.ffn_activation_fn = "relu" 25 | 26 | train_fn.learning_rate = 1e-3 27 | train_fn.weight_decay = 0 28 | train_fn.num_warmup_steps = 0 29 | 30 | train_fn.top_k_method = "MIPSBruteForceTopK" 31 | train_fn.interaction_module_type = "DotProduct" 32 | 33 | train_fn.loss_module = "SampledSoftmaxLoss" 34 | train_fn.num_negatives = 128 35 | 36 | train_fn.sampling_strategy = "local" 37 | train_fn.temperature = 0.05 38 | train_fn.item_l2_norm = True 39 | train_fn.l2_norm_eps = 1e-6 40 | 41 | train_fn.enable_tf32 = True 42 | 43 | create_data_loader.prefetch_factor = 128 44 | create_data_loader.num_workers = 8 45 | -------------------------------------------------------------------------------- /configs/ml-1m/hstu-sampled-softmax-n128-large-final.gin: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Frozen config, validated on 04/11/2024. 4 | # Based on HSTU-large results in 5 | # Actions Speak Louder than Words: Trillion-Parameter Sequential Transducers for Generative Recommendations (https://arxiv.org/abs/2402.17152). 6 | # 7 | # Run this as: 8 | # mkdir -p logs/ml-1m-l200/ 9 | # CUDA_VISIBLE_DEVICES=0 python3 main.py --gin_config_file=configs/ml-1m/hstu-sampled-softmax-n128-large-final.gin --master_port=12345 2>&1 | tee logs/ml-1m-l200/hstu-sampled-softmax-n128-large-final.log 10 | 11 | train_fn.dataset_name = "ml-1m" 12 | train_fn.max_sequence_length = 200 13 | train_fn.local_batch_size = 128 14 | 15 | train_fn.main_module = "HSTU" 16 | train_fn.dropout_rate = 0.2 17 | train_fn.user_embedding_norm = "l2_norm" 18 | train_fn.num_epochs = 101 19 | train_fn.item_embedding_dim = 50 20 | 21 | hstu_encoder.num_blocks = 8 22 | hstu_encoder.num_heads = 2 23 | hstu_encoder.dqk = 25 24 | hstu_encoder.dv = 25 25 | hstu_encoder.linear_dropout_rate = 0.2 26 | 27 | train_fn.learning_rate = 1e-3 28 | train_fn.weight_decay = 0 29 | train_fn.num_warmup_steps = 0 30 | 31 | train_fn.interaction_module_type = "DotProduct" 32 | train_fn.top_k_method = "MIPSBruteForceTopK" 33 | 34 | train_fn.loss_module = "SampledSoftmaxLoss" 35 | train_fn.num_negatives = 128 36 | 37 | train_fn.sampling_strategy = "local" 38 | train_fn.temperature = 0.05 39 | train_fn.item_l2_norm = True 40 | train_fn.l2_norm_eps = 1e-6 41 | 42 | train_fn.enable_tf32 = True 43 | 44 | create_data_loader.prefetch_factor = 128 45 | create_data_loader.num_workers = 8 46 | -------------------------------------------------------------------------------- /configs/ml-1m/hstu-sampled-softmax-n128-final.gin: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Frozen config, validated on 04/11/2024. 4 | # Based on HSTU results (w/ identical configurations as a SotA Transformer baseline) in 5 | # Actions Speak Louder than Words: Trillion-Parameter Sequential Transducers for Generative Recommendations (https://arxiv.org/abs/2402.17152). 6 | # 7 | # Run this as: 8 | # mkdir -p logs/ml-1m-l200/ 9 | # CUDA_VISIBLE_DEVICES=0 python3 main.py --gin_config_file=configs/ml-1m/hstu-sampled-softmax-n128-final.gin --master_port=12345 2>&1 | tee logs/ml-1m-l200/hstu-sampled-softmax-n128-final.log 10 | 11 | train_fn.dataset_name = "ml-1m" 12 | train_fn.max_sequence_length = 200 13 | train_fn.local_batch_size = 128 14 | 15 | train_fn.main_module = "HSTU" 16 | train_fn.dropout_rate = 0.2 17 | train_fn.user_embedding_norm = "l2_norm" 18 | train_fn.num_epochs = 101 19 | train_fn.item_embedding_dim = 50 20 | 21 | hstu_encoder.num_blocks = 2 22 | hstu_encoder.num_heads = 1 23 | hstu_encoder.dqk = 50 24 | hstu_encoder.dv = 50 25 | hstu_encoder.linear_dropout_rate = 0.2 26 | 27 | train_fn.learning_rate = 1e-3 28 | train_fn.weight_decay = 0 29 | train_fn.num_warmup_steps = 0 30 | 31 | train_fn.interaction_module_type = "DotProduct" 32 | train_fn.top_k_method = "MIPSBruteForceTopK" 33 | 34 | train_fn.loss_module = "SampledSoftmaxLoss" 35 | train_fn.num_negatives = 128 36 | 37 | train_fn.sampling_strategy = "local" 38 | train_fn.temperature = 0.05 39 | train_fn.item_l2_norm = True 40 | train_fn.l2_norm_eps = 1e-6 41 | 42 | train_fn.enable_tf32 = True 43 | 44 | create_data_loader.prefetch_factor = 128 45 | create_data_loader.num_workers = 8 46 | -------------------------------------------------------------------------------- /configs/ml-20m/hstu-sampled-softmax-n128-final.gin: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Frozen config, validated on 04/12/2024. 4 | # Based on HSTU results (w/ identical configurations as a SotA Transformer baseline) in 5 | # Actions Speak Louder than Words: Trillion-Parameter Sequential Transducers for Generative Recommendations (https://arxiv.org/abs/2402.17152). 6 | # 7 | # Run this as: 8 | # mkdir -p logs/ml-20m-l200/ 9 | # CUDA_VISIBLE_DEVICES=0 python3 main.py --gin_config_file=configs/ml-20m/hstu-sampled-softmax-n128-final.gin --master_port=12345 2>&1 | tee logs/ml-20m-l200/hstu-sampled-softmax-n128-final.log 10 | 11 | train_fn.dataset_name = "ml-20m" 12 | train_fn.max_sequence_length = 200 13 | train_fn.local_batch_size = 128 14 | 15 | train_fn.main_module = "HSTU" 16 | train_fn.dropout_rate = 0.2 17 | train_fn.user_embedding_norm = "l2_norm" 18 | train_fn.num_epochs = 101 19 | train_fn.item_embedding_dim = 256 20 | 21 | hstu_encoder.num_blocks = 2 22 | hstu_encoder.num_heads = 4 23 | hstu_encoder.dv = 64 24 | hstu_encoder.dqk = 64 25 | hstu_encoder.linear_dropout_rate = 0.2 26 | 27 | train_fn.learning_rate = 1e-3 28 | train_fn.weight_decay = 0 29 | train_fn.num_warmup_steps = 0 30 | 31 | train_fn.interaction_module_type = "DotProduct" 32 | train_fn.top_k_method = "MIPSBruteForceTopK" 33 | 34 | train_fn.loss_module = "SampledSoftmaxLoss" 35 | train_fn.num_negatives = 128 36 | 37 | train_fn.sampling_strategy = "local" 38 | train_fn.temperature = 0.05 39 | train_fn.item_l2_norm = True 40 | train_fn.l2_norm_eps = 1e-6 41 | 42 | train_fn.enable_tf32 = True 43 | 44 | create_data_loader.prefetch_factor = 128 45 | create_data_loader.num_workers = 8 46 | -------------------------------------------------------------------------------- /configs/ml-20m/hstu-sampled-softmax-n128-large-final.gin: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Frozen config, validated on 04/12/2024. 4 | # Based on HSTU-large results in 5 | # Actions Speak Louder than Words: Trillion-Parameter Sequential Transducers for Generative Recommendations (https://arxiv.org/abs/2402.17152). 6 | # 7 | # Run this as: 8 | # mkdir -p logs/ml-20m-l200/ 9 | # CUDA_VISIBLE_DEVICES=0 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python3 main.py --gin_config_file=configs/ml-20m/hstu-sampled-softmax-n128-large-final.gin --master_port=12345 2>&1 | tee logs/ml-20m-l200/hstu-sampled-softmax-n128-large-final.log 10 | # CUDA_VISIBLE_DEVICES=0 python3 main.py --gin_config_file=configs/ml-20m/hstu-sampled-softmax-n128-large-final.gin --master_port=12345 11 | 12 | train_fn.dataset_name = "ml-20m" 13 | train_fn.max_sequence_length = 200 14 | train_fn.local_batch_size = 128 15 | 16 | train_fn.main_module = "HSTU" 17 | train_fn.dropout_rate = 0.2 18 | train_fn.user_embedding_norm = "l2_norm" 19 | train_fn.num_epochs = 101 20 | train_fn.item_embedding_dim = 256 21 | 22 | hstu_encoder.num_blocks = 8 23 | hstu_encoder.num_heads = 8 24 | hstu_encoder.dv = 32 25 | hstu_encoder.dqk = 32 26 | hstu_encoder.linear_dropout_rate = 0.2 27 | 28 | train_fn.learning_rate = 1e-3 29 | train_fn.weight_decay = 0 30 | train_fn.num_warmup_steps = 0 31 | 32 | train_fn.interaction_module_type = "DotProduct" 33 | train_fn.top_k_method = "MIPSBruteForceTopK" 34 | 35 | train_fn.loss_module = "SampledSoftmaxLoss" 36 | train_fn.num_negatives = 128 37 | 38 | train_fn.sampling_strategy = "local" 39 | train_fn.temperature = 0.05 40 | train_fn.item_l2_norm = True 41 | train_fn.l2_norm_eps = 1e-6 42 | 43 | train_fn.enable_tf32 = True 44 | 45 | create_data_loader.prefetch_factor = 128 46 | create_data_loader.num_workers = 8 47 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | astunparse==1.6.3 3 | certifi==2024.2.2 4 | charset-normalizer==3.3.2 5 | contourpy==1.2.1 6 | cycler==0.12.1 7 | fbgemm-gpu==0.6.0 8 | filelock==3.13.4 9 | flatbuffers==24.3.25 10 | fonttools==4.51.0 11 | fsspec==2024.3.1 12 | fvcore==0.1.5.post20221221 13 | gast==0.5.4 14 | gin-config==0.5.0 15 | google-pasta==0.2.0 16 | grpcio==1.62.1 17 | h5py==3.11.0 18 | iopath==0.1.10 19 | joblib==1.4.0 20 | keras==3.2.0 21 | kiwisolver==1.4.5 22 | libclang==18.1.1 23 | Markdown==3.6 24 | markdown-it-py==3.0.0 25 | MarkupSafe==2.1.5 26 | matplotlib==3.8.4 27 | mdurl==0.1.2 28 | ml-dtypes==0.3.2 29 | mpmath==1.3.0 30 | namex==0.0.7 31 | networkx==3.3 32 | numpy==1.26.4 33 | nvidia-cublas-cu12==12.1.3.1 34 | nvidia-cuda-cupti-cu12==12.1.105 35 | nvidia-cuda-nvrtc-cu12==12.1.105 36 | nvidia-cuda-runtime-cu12==12.1.105 37 | nvidia-cudnn-cu12==8.9.2.26 38 | nvidia-cufft-cu12==11.0.2.54 39 | nvidia-curand-cu12==10.3.2.106 40 | nvidia-cusolver-cu12==11.4.5.107 41 | nvidia-cusparse-cu12==12.1.0.106 42 | nvidia-nccl-cu12==2.19.3 43 | nvidia-nvjitlink-cu12==12.4.127 44 | nvidia-nvtx-cu12==12.1.105 45 | opt-einsum==3.3.0 46 | optree==0.11.0 47 | packaging==24.0 48 | pandas==2.2.1 49 | pillow==10.3.0 50 | portalocker==2.8.2 51 | protobuf==4.25.3 52 | Pygments==2.17.2 53 | pyparsing==3.1.2 54 | python-dateutil==2.9.0.post0 55 | pytz==2024.1 56 | PyYAML==6.0.1 57 | rich==13.7.1 58 | scipy==1.13.0 59 | six==1.16.0 60 | sympy==1.12 61 | tabulate==0.9.0 62 | tensorboard==2.16.2 63 | tensorboard-data-server==0.7.2 64 | tensorflow==2.16.1 65 | tensorflow-io-gcs-filesystem==0.36.0 66 | tensorrt==8.6.1.post1 67 | tensorrt-bindings==8.6.1 68 | tensorrt-libs==8.6.1 69 | termcolor==2.4.0 70 | threadpoolctl==3.4.0 71 | torch==2.2.2 72 | torch_tensorrt==2.2.0 73 | torchaudio==2.2.2 74 | torchvision==0.17.2 75 | triton==2.2.0 76 | typing_extensions==4.11.0 77 | tzdata==2024.1 78 | wrapt==1.16.0 79 | yacs==0.1.8 80 | -------------------------------------------------------------------------------- /generative_recommenders/trainer/data_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyre-unsafe 16 | 17 | import os 18 | from typing import Optional, Tuple 19 | 20 | import gin 21 | import torch 22 | 23 | 24 | @gin.configurable 25 | def create_data_loader( 26 | dataset: torch.utils.data.Dataset, 27 | batch_size: int, 28 | world_size: int, 29 | rank: int, 30 | shuffle: bool, 31 | prefetch_factor: int = 128, 32 | num_workers: Optional[int] = os.cpu_count(), 33 | drop_last: bool = False, 34 | ) -> Tuple[ 35 | Optional[torch.utils.data.distributed.DistributedSampler[torch.utils.data.Dataset]], 36 | torch.utils.data.DataLoader, 37 | ]: 38 | # print(f"num_workers={num_workers}") 39 | if shuffle: 40 | sampler = torch.utils.data.distributed.DistributedSampler( 41 | dataset, 42 | num_replicas=world_size, 43 | rank=rank, 44 | shuffle=True, 45 | seed=0, 46 | drop_last=drop_last, 47 | ) 48 | else: 49 | sampler = None 50 | data_loader = torch.utils.data.DataLoader( 51 | dataset, 52 | batch_size=batch_size, 53 | # shuffle=True, cannot use with sampler 54 | num_workers=0, 55 | sampler=sampler, 56 | # prefetch_factor=prefetch_factor, 57 | ) 58 | return sampler, data_loader 59 | -------------------------------------------------------------------------------- /generative_recommenders/modeling/similarity/dot_product.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyre-unsafe 16 | 17 | from typing import Callable, Dict, Optional, Tuple 18 | 19 | import torch 20 | 21 | from generative_recommenders.modeling.ndp_module import NDPModule 22 | 23 | 24 | class DotProductSimilarity(NDPModule): 25 | def __init__( 26 | self, 27 | ) -> None: 28 | super().__init__() 29 | 30 | def debug_str(self) -> str: 31 | return "dp" 32 | 33 | def forward( # pyre-ignore [3] 34 | self, 35 | input_embeddings: torch.Tensor, 36 | item_embeddings: torch.Tensor, 37 | item_sideinfo: Optional[torch.Tensor], 38 | item_ids: torch.Tensor, 39 | precomputed_logits: Optional[torch.Tensor] = None, 40 | ): 41 | """ 42 | Args: 43 | input_embeddings: (B, D,) or (B * r, D) x float. 44 | item_embeddings: (1, X, D) or (B, X, D) x float. 45 | 46 | Returns: 47 | (B, X) x float (or (B * r, X) x float). 48 | """ 49 | del item_ids 50 | 51 | if item_embeddings.size(0) == 1: 52 | # [B, D] x ([1, X, D] -> [D, X]) => [B, X] 53 | return ( 54 | torch.mm(input_embeddings, item_embeddings.squeeze(0).t()), 55 | {}, 56 | ) # [B, X] 57 | elif input_embeddings.size(0) != item_embeddings.size(0): 58 | # (B * r, D) x (B, X, D). 59 | B, X, D = item_embeddings.size() 60 | return torch.bmm( 61 | input_embeddings.view(B, -1, D), item_embeddings.permute(0, 2, 1) 62 | ).view(-1, X) 63 | else: 64 | # assert input_embeddings.size(0) == item_embeddings.size(0) 65 | # [B, X, D] x ([B, D] -> [B, D, 1]) => [B, X, 1] -> [B, X] 66 | return torch.bmm(item_embeddings, input_embeddings.unsqueeze(2)).squeeze(2) 67 | -------------------------------------------------------------------------------- /generative_recommenders/modeling/sequential/output_postprocessors.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyre-unsafe 16 | 17 | import abc 18 | 19 | import torch 20 | import torch.nn.functional as F 21 | 22 | 23 | class OutputPostprocessorModule(torch.nn.Module): 24 | 25 | @abc.abstractmethod 26 | def debug_str(self) -> str: 27 | pass 28 | 29 | @abc.abstractmethod 30 | def forward( 31 | self, 32 | output_embeddings: torch.Tensor, 33 | ) -> torch.Tensor: 34 | pass 35 | 36 | 37 | class L2NormEmbeddingPostprocessor(OutputPostprocessorModule): 38 | 39 | def __init__( 40 | self, 41 | embedding_dim: int, 42 | eps: float = 1e-6, 43 | ) -> None: 44 | super().__init__() 45 | self._embedding_dim: int = embedding_dim 46 | self._eps: float = eps 47 | 48 | def debug_str(self) -> str: 49 | return "l2" 50 | 51 | def forward( 52 | self, 53 | output_embeddings: torch.Tensor, 54 | ) -> torch.Tensor: 55 | output_embeddings = output_embeddings[..., : self._embedding_dim] 56 | return output_embeddings / torch.clamp( 57 | torch.linalg.norm(output_embeddings, ord=None, dim=-1, keepdim=True), 58 | min=self._eps, 59 | ) 60 | 61 | 62 | class LayerNormEmbeddingPostprocessor(OutputPostprocessorModule): 63 | 64 | def __init__( 65 | self, 66 | embedding_dim: int, 67 | eps: float = 1e-6, 68 | ) -> None: 69 | super().__init__() 70 | self._embedding_dim: int = embedding_dim 71 | self._eps: float = eps 72 | 73 | def debug_str(self) -> str: 74 | return "ln" 75 | 76 | def forward( 77 | self, 78 | output_embeddings: torch.Tensor, 79 | ) -> torch.Tensor: 80 | output_embeddings = output_embeddings[..., : self._embedding_dim] 81 | return F.layer_norm( 82 | output_embeddings, 83 | normalized_shape=(self._embedding_dim,), 84 | eps=self._eps, 85 | ) 86 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyre-unsafe 16 | 17 | """ 18 | Main entry point for model training. Please refer to README.md for usage instructions. 19 | """ 20 | 21 | import logging 22 | import os 23 | 24 | from typing import List, Optional 25 | 26 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1" # Hide excessive tensorflow debug messages 27 | import sys 28 | 29 | import fbgemm_gpu # noqa: F401, E402 30 | import gin 31 | 32 | import torch 33 | import torch.multiprocessing as mp 34 | 35 | from absl import app, flags 36 | from generative_recommenders.trainer.train import train_fn 37 | 38 | logging.basicConfig(stream=sys.stdout, level=logging.INFO) 39 | 40 | 41 | def delete_flags(FLAGS, keys_to_delete: List[str]) -> None: # pyre-ignore [2] 42 | keys = [key for key in FLAGS._flags()] 43 | for key in keys: 44 | if key in keys_to_delete: 45 | delattr(FLAGS, key) 46 | 47 | 48 | delete_flags(flags.FLAGS, ["gin_config_file", "master_port"]) 49 | flags.DEFINE_string("gin_config_file", None, "Path to the config file.") 50 | flags.DEFINE_integer("master_port", 12355, "Master port.") 51 | FLAGS = flags.FLAGS # pyre-ignore [5] 52 | 53 | 54 | def mp_train_fn( 55 | rank: int, 56 | world_size: int, 57 | master_port: int, 58 | gin_config_file: Optional[str], 59 | ) -> None: 60 | if gin_config_file is not None: 61 | # Hack as absl doesn't support flag parsing inside multiprocessing. 62 | logging.info(f"Rank {rank}: loading gin config from {gin_config_file}") 63 | gin.parse_config_file(gin_config_file) 64 | 65 | train_fn(rank, world_size, master_port) 66 | 67 | 68 | def _main(argv) -> None: # pyre-ignore [2] 69 | world_size = torch.cuda.device_count() 70 | 71 | mp.set_start_method("forkserver") 72 | mp.spawn( 73 | mp_train_fn, 74 | args=(world_size, FLAGS.master_port, FLAGS.gin_config_file), 75 | nprocs=world_size, 76 | join=True, 77 | ) 78 | 79 | 80 | def main() -> None: 81 | app.run(_main) 82 | 83 | 84 | if __name__ == "__main__": 85 | main() 86 | -------------------------------------------------------------------------------- /generative_recommenders/indexing/mips_top_k.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyre-unsafe 16 | 17 | from typing import Tuple 18 | 19 | import torch 20 | 21 | from generative_recommenders.indexing.candidate_index import TopKModule 22 | 23 | 24 | class MIPSTopKModule(TopKModule): 25 | 26 | def __init__( 27 | self, 28 | item_embeddings: torch.Tensor, 29 | item_ids: torch.Tensor, 30 | ) -> None: 31 | """ 32 | Args: 33 | item_embeddings: (1, X, D) 34 | item_ids: (1, X,) 35 | """ 36 | super().__init__() 37 | 38 | self._item_embeddings: torch.Tensor = item_embeddings 39 | self._item_ids: torch.Tensor = item_ids 40 | 41 | 42 | class MIPSBruteForceTopK(MIPSTopKModule): 43 | 44 | def __init__( 45 | self, 46 | item_embeddings: torch.Tensor, 47 | item_ids: torch.Tensor, 48 | ) -> None: 49 | super().__init__( 50 | item_embeddings=item_embeddings, 51 | item_ids=item_ids, 52 | ) 53 | del self._item_embeddings 54 | self._item_embeddings_t: torch.Tensor = item_embeddings.permute( 55 | 2, 1, 0 56 | ).squeeze(2) 57 | 58 | def forward( 59 | self, 60 | query_embeddings: torch.Tensor, 61 | k: int, 62 | sorted: bool = True, 63 | ) -> Tuple[torch.Tensor, torch.Tensor]: 64 | """ 65 | Args: 66 | query_embeddings: (B, ...). Implementation-specific. 67 | k: int. final top-k to return. 68 | sorted: bool. whether to sort final top-k results or not. 69 | 70 | Returns: 71 | Tuple of (top_k_scores x float, top_k_ids x int), both of shape (B, K,) 72 | """ 73 | # (B, X,) 74 | all_logits = torch.mm(query_embeddings, self._item_embeddings_t) 75 | top_k_logits, top_k_indices = torch.topk( 76 | all_logits, 77 | dim=1, 78 | k=k, 79 | sorted=sorted, 80 | largest=True, 81 | ) # (B, k,) 82 | return top_k_logits, self._item_ids.squeeze(0)[top_k_indices] 83 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FuXi-$\alpha$ 2 | 3 | This is the Pytorch implementation for our paper `FuXi-𝛼: Scaling Recommendation Model with Feature Interaction Enhanced Transformer` 4 | 5 | ## Paper Overview 6 | 7 | FuXi-𝛼 is a novel recommendation model that leverages an Adaptive Multi-channel Self-attention mechanism to distinctly model temporal, positional, and semantic features, along with a Multi-stage Feed-Forward Network to enhance implicit feature interactions. This model addresses the limitations of previous sequential recommendation models that inadequately integrate temporal and positional information. Our experiments demonstrate that FuXi-𝛼 outperforms existing models and its performance continues to improve as the model size increases. 8 | 9 | 10 | 11 | ## Getting started 12 | 13 | ### Public experiments 14 | 15 | To replicate the public experiments conducted in the traditional sequential recommender setting on MovieLens as described in the paper, please follow these steps: 16 | 17 | #### Install dependencies. 18 | 19 | Install PyTorch based on official instructions. Then, 20 | 21 | ``` 22 | pip3 install gin-config absl-py scikit-learn scipy matplotlib numpy apex hypothesis pandas fbgemm_gpu iopath 23 | ``` 24 | 25 | #### Download and preprocess data. 26 | 27 | ``` 28 | mkdir -p tmp/ && python3 preprocess_public_data.py 29 | ``` 30 | 31 | #### Run model training. 32 | 33 | A GPU with 24GB or more HBM should work for most datasets. 34 | 35 | ``` 36 | CUDA_VISIBLE_DEVICES=0 python3 main.py --gin_config_file=configs/ml-1m/fuxi-sampled-softmax-n128-final.gin --master_port=12345 37 | ``` 38 | 39 | Other configurations are included in configs/ml-1m, configs/ml-20m to make reproducing these experiments easier. 40 | 41 | #### Verify results. 42 | 43 | By default we write experimental logs to exps/. We can launch tensorboard with something like the following: 44 | 45 | ``` 46 | tensorboard --logdir ~/generative-recommenders/exps/ml-1m-l200/ --port 24001 --bind_all 47 | tensorboard --logdir ~/generative-recommenders/exps/ml-20m-l200/ --port 24001 --bind_all 48 | ``` 49 | 50 | ## Citation 51 | 52 | If you find FuXi-$\alpha$ useful, please cite it as: 53 | 54 | ``` 55 | @article{ye2025fuxi, 56 | title={FuXi-$$\backslash$alpha $: Scaling Recommendation Model with Feature Interaction Enhanced Transformer}, 57 | author={Ye, Yufei and Guo, Wei and Chin, Jin Yao and Wang, Hao and Zhu, Hong and Lin, Xi and Ye, Yuyang and Liu, Yong and Tang, Ruiming and Lian, Defu and others}, 58 | journal={arXiv preprint arXiv:2502.03036}, 59 | year={2025} 60 | } 61 | ``` 62 | 63 | > Thanks to the excellent code repository [HSTU](https://github.com/facebookresearch/generative-recommenders),which has saved us a lot of work in implementing the code. 64 | -------------------------------------------------------------------------------- /generative_recommenders/modeling/similarity_module.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyre-unsafe 16 | 17 | import abc 18 | from typing import Optional, Tuple 19 | 20 | import torch 21 | 22 | from generative_recommenders.modeling.ndp_module import NDPModule 23 | 24 | 25 | class InteractionModule(torch.nn.Module): 26 | 27 | @abc.abstractmethod 28 | def get_item_embeddings( 29 | self, 30 | item_ids: torch.Tensor, 31 | ) -> torch.Tensor: 32 | pass 33 | 34 | @abc.abstractmethod 35 | def get_item_sideinfo( 36 | self, 37 | item_ids: torch.Tensor, 38 | ) -> Optional[torch.Tensor]: 39 | pass 40 | 41 | @abc.abstractmethod 42 | def interaction( 43 | self, 44 | input_embeddings: torch.Tensor, # [B, D] 45 | target_ids: torch.Tensor, # [1, X] or [B, X] 46 | target_embeddings: Optional[torch.Tensor] = None, # [1, X, D'] or [B, X, D'] 47 | ) -> torch.Tensor: 48 | pass 49 | 50 | 51 | class GeneralizedInteractionModule(InteractionModule): 52 | def __init__( 53 | self, 54 | ndp_module: NDPModule, 55 | ) -> None: 56 | super().__init__() 57 | 58 | self._ndp_module: NDPModule = ndp_module 59 | 60 | @abc.abstractmethod 61 | def debug_str( 62 | self, 63 | ) -> str: 64 | pass 65 | 66 | def interaction( 67 | self, 68 | input_embeddings: torch.Tensor, 69 | target_ids: torch.Tensor, 70 | target_embeddings: Optional[torch.Tensor] = None, 71 | ) -> torch.Tensor: 72 | torch._assert( 73 | len(input_embeddings.size()) == 2, "len(input_embeddings.size()) must be 2" 74 | ) 75 | torch._assert(len(target_ids.size()) == 2, "len(target_ids.size()) must be 2") 76 | if target_embeddings is None: 77 | target_embeddings = self.get_item_embeddings(target_ids) 78 | torch._assert( 79 | len(target_embeddings.size()) == 3, 80 | "len(target_embeddings.size()) must be 3", 81 | ) 82 | 83 | with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"): 84 | return self._ndp_module( 85 | input_embeddings=input_embeddings, # [B, self._input_embedding_dim] 86 | item_embeddings=target_embeddings, # [1/B, X, self._item_embedding_dim] 87 | item_sideinfo=self.get_item_sideinfo( 88 | item_ids=target_ids 89 | ), # [1/B, X, self._item_sideinfo_dim] 90 | item_ids=target_ids, 91 | precomputed_logits=None, 92 | ) 93 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Don't check in parsed data files 2 | tmp/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # Experiments 159 | exps/ 160 | ckpts/ -------------------------------------------------------------------------------- /generative_recommenders/modeling/sequential/features.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyre-unsafe 16 | 17 | from typing import Dict, NamedTuple, Optional, Tuple 18 | 19 | import torch 20 | 21 | 22 | class SequentialFeatures(NamedTuple): 23 | # (B,) x int64. Requires past_lengths[i] > 0 \forall i. 24 | past_lengths: torch.Tensor 25 | # (B, N,) x int64. 0 denotes valid ids. 26 | past_ids: torch.Tensor 27 | # (B, N, D) x float. 28 | past_embeddings: Optional[torch.Tensor] 29 | # Implementation-specific payloads. 30 | # e.g., past timestamps, past event_types (e.g., clicks, likes), etc. 31 | past_payloads: Dict[str, torch.Tensor] 32 | 33 | 34 | def movielens_seq_features_from_row( 35 | row: Dict[str, torch.Tensor], 36 | device: int, 37 | max_output_length: int, 38 | ) -> Tuple[SequentialFeatures, torch.Tensor, torch.Tensor]: 39 | historical_lengths = row["history_lengths"].to(device) # [B] 40 | historical_ids = row["historical_ids"].to(device) # [B, N] 41 | historical_ratings = row["historical_ratings"].to(device) 42 | historical_timestamps = row["historical_timestamps"].to(device) 43 | target_ids = row["target_ids"].to(device).unsqueeze(1) # [B, 1] 44 | target_ratings = row["target_ratings"].to(device).unsqueeze(1) 45 | target_timestamps = row["target_timestamps"].to(device).unsqueeze(1) 46 | if max_output_length > 0: 47 | B = historical_lengths.size(0) 48 | historical_ids = torch.cat( 49 | [ 50 | historical_ids, 51 | torch.zeros( 52 | (B, max_output_length), dtype=historical_ids.dtype, device=device 53 | ), 54 | ], 55 | dim=1, 56 | ) 57 | historical_ratings = torch.cat( 58 | [ 59 | historical_ratings, 60 | torch.zeros( 61 | (B, max_output_length), 62 | dtype=historical_ratings.dtype, 63 | device=device, 64 | ), 65 | ], 66 | dim=1, 67 | ) 68 | historical_timestamps = torch.cat( 69 | [ 70 | historical_timestamps, 71 | torch.zeros( 72 | (B, max_output_length), 73 | dtype=historical_timestamps.dtype, 74 | device=device, 75 | ), 76 | ], 77 | dim=1, 78 | ) 79 | historical_timestamps.scatter_( 80 | dim=1, 81 | index=historical_lengths.view(-1, 1), 82 | src=target_timestamps.view(-1, 1), 83 | ) 84 | # print(f"historical_ids.size()={historical_ids.size()}, historical_timestamps.size()={historical_timestamps.size()}") 85 | features = SequentialFeatures( 86 | past_lengths=historical_lengths, 87 | past_ids=historical_ids, 88 | past_embeddings=None, 89 | past_payloads={ 90 | "timestamps": historical_timestamps, 91 | "ratings": historical_ratings, 92 | }, 93 | ) 94 | return features, target_ids, target_ratings 95 | -------------------------------------------------------------------------------- /generative_recommenders/modeling/sequential/embedding_modules.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyre-unsafe 16 | 17 | import abc 18 | 19 | import torch 20 | 21 | from generative_recommenders.modeling.initialization import truncated_normal 22 | 23 | 24 | class EmbeddingModule(torch.nn.Module): 25 | 26 | @abc.abstractmethod 27 | def debug_str(self) -> str: 28 | pass 29 | 30 | @abc.abstractmethod 31 | def get_item_embeddings(self, item_ids: torch.Tensor) -> torch.Tensor: 32 | pass 33 | 34 | @property 35 | @abc.abstractmethod 36 | def item_embedding_dim(self) -> int: 37 | pass 38 | 39 | 40 | class LocalEmbeddingModule(EmbeddingModule): 41 | 42 | def __init__( 43 | self, 44 | num_items: int, 45 | item_embedding_dim: int, 46 | ) -> None: 47 | super().__init__() 48 | 49 | self._item_embedding_dim: int = item_embedding_dim 50 | self._item_emb = torch.nn.Embedding( 51 | num_items + 1, item_embedding_dim, padding_idx=0 52 | ) 53 | self.reset_params() 54 | 55 | def debug_str(self) -> str: 56 | return f"local_emb_d{self._item_embedding_dim}" 57 | 58 | def reset_params(self) -> None: 59 | for name, params in self.named_parameters(): 60 | if "_item_emb" in name: 61 | print( 62 | f"Initialize {name} as truncated normal: {params.data.size()} params" 63 | ) 64 | truncated_normal(params, mean=0.0, std=0.02) 65 | else: 66 | print(f"Skipping initializing params {name} - not configured") 67 | 68 | def get_item_embeddings(self, item_ids: torch.Tensor) -> torch.Tensor: 69 | return self._item_emb(item_ids) 70 | 71 | @property 72 | def item_embedding_dim(self) -> int: 73 | return self._item_embedding_dim 74 | 75 | 76 | class CategoricalEmbeddingModule(EmbeddingModule): 77 | 78 | def __init__( 79 | self, 80 | num_items: int, 81 | item_embedding_dim: int, 82 | item_id_to_category_id: torch.Tensor, 83 | ) -> None: 84 | super().__init__() 85 | 86 | self._item_embedding_dim: int = item_embedding_dim 87 | self._item_emb: torch.nn.Embedding = torch.nn.Embedding( 88 | num_items + 1, item_embedding_dim, padding_idx=0 89 | ) 90 | self.register_buffer("_item_id_to_category_id", item_id_to_category_id) 91 | self.reset_params() 92 | 93 | def debug_str(self) -> str: 94 | return f"cat_emb_d{self._item_embedding_dim}" 95 | 96 | def reset_params(self) -> None: 97 | for name, params in self.named_parameters(): 98 | if "_item_emb" in name: 99 | print( 100 | f"Initialize {name} as truncated normal: {params.data.size()} params" 101 | ) 102 | truncated_normal(params, mean=0.0, std=0.02) 103 | else: 104 | print(f"Skipping initializing params {name} - not configured") 105 | 106 | def get_item_embeddings(self, item_ids: torch.Tensor) -> torch.Tensor: 107 | item_ids = self._item_id_to_category_id[(item_ids - 1).clamp(min=0)] + 1 108 | return self._item_emb(item_ids) 109 | 110 | @property 111 | def item_embedding_dim(self) -> int: 112 | return self._item_embedding_dim 113 | -------------------------------------------------------------------------------- /generative_recommenders/modeling/sequential/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyre-unsafe 16 | 17 | import torch 18 | 19 | 20 | def batch_gather_embeddings( 21 | rowwise_indices: torch.Tensor, 22 | embeddings: torch.Tensor, 23 | ) -> torch.Tensor: 24 | """ 25 | Args: 26 | rowwise_indices: (B, N) x int, where each entry is in [0, X). 27 | embeddings: (B, X, D,) x float. 28 | 29 | Returns: 30 | (B, N, D,) x float, embeddings corresponding to rowwise_indices. 31 | """ 32 | _, N = rowwise_indices.size() 33 | B, X, D = embeddings.size() 34 | flattened_indices = ( 35 | rowwise_indices 36 | + torch.arange( 37 | start=0, 38 | end=B, 39 | step=1, 40 | dtype=rowwise_indices.dtype, 41 | device=rowwise_indices.device, 42 | ) 43 | .unsqueeze(1) 44 | .expand(-1, N) 45 | * X 46 | ) 47 | return embeddings.view(-1, D)[flattened_indices, :].reshape( 48 | rowwise_indices.size() + (D,) 49 | ) 50 | 51 | 52 | def batch_scatter_embeddings( 53 | dst_embeddings: torch.Tensor, 54 | rowwise_indices: torch.Tensor, 55 | src_embeddings: torch.Tensor, 56 | ) -> None: 57 | """ 58 | Args: 59 | dst_embeddings: (B, N, D,) x float. 60 | rowwise_indices: (B,) x int, where each entry is in [0, N - 1). 61 | source_embeddings: (B, D,) x float. 62 | """ 63 | B, N, D = dst_embeddings.size() 64 | flattened_indices = rowwise_indices + torch.arange( 65 | start=0, 66 | end=B * N, 67 | step=N, 68 | dtype=rowwise_indices.dtype, 69 | device=rowwise_indices.device, 70 | ) 71 | dst_embeddings.view(B * N, D)[flattened_indices, :] = src_embeddings 72 | 73 | 74 | def get_current_embeddings( 75 | lengths: torch.Tensor, 76 | encoded_embeddings: torch.Tensor, 77 | ) -> torch.Tensor: 78 | """ 79 | Args: 80 | lengths: (B,) x int 81 | seq_embeddings: (B, N, D,) x float 82 | 83 | Returns: 84 | (B, D,) x float, where [i, :] == encoded_embeddings[i, lengths[i] - 1, :] 85 | """ 86 | B, N, D = encoded_embeddings.size() 87 | flattened_offsets = (lengths - 1) + torch.arange( 88 | start=0, end=B, step=1, dtype=lengths.dtype, device=lengths.device 89 | ) * N 90 | return encoded_embeddings.reshape(-1, D)[flattened_offsets, :].reshape(B, D) 91 | 92 | 93 | def jagged_or_dense_repeat_interleave_dim0( 94 | x: torch.Tensor, lengths: torch.Tensor, repeats: int 95 | ) -> torch.Tensor: 96 | if len(x.size()) == 3: 97 | return x.repeat_interleave(repeats, dim=0) 98 | else: 99 | assert len(x.size()) == 2, f"x.size() = {x.size()}" 100 | padded_x = torch.ops.fbgemm.jagged_to_padded_dense( 101 | values=x, 102 | offsets=[torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)], 103 | max_lengths=[lengths.max()], 104 | padding_value=0.0, 105 | ) 106 | lengths = lengths.repeat_interleave(repeats, dim=0) 107 | return torch.ops.fbgemm.dense_to_jagged( 108 | padded_x.repeat_interleave(repeats, dim=0), 109 | [torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)], 110 | )[0] 111 | 112 | 113 | def jagged_or_dense_index_select_dim0( 114 | x: torch.Tensor, lengths: torch.Tensor, indices: torch.Tensor 115 | ) -> torch.Tensor: 116 | if len(x.size()) == 3: 117 | return x[indices, :, :] 118 | else: 119 | assert len(x.size()) == 2, f"x.size() = {x.size()}" 120 | padded_x = torch.ops.fbgemm.jagged_to_padded_dense( 121 | values=x, 122 | offsets=[torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)], 123 | max_lengths=[lengths.max()], 124 | padding_value=0.0, 125 | ) 126 | return torch.ops.fbgemm.dense_to_jagged( 127 | padded_x[indices, :], 128 | [torch.ops.fbgemm.asynchronous_complete_cumsum(lengths[indices])], 129 | )[0] 130 | -------------------------------------------------------------------------------- /generative_recommenders/data/reco_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyre-unsafe 16 | 17 | from dataclasses import dataclass 18 | from typing import List 19 | 20 | import pandas as pd 21 | 22 | import torch 23 | 24 | from generative_recommenders.data.dataset import DatasetV2, MultiFileDatasetV2 25 | from generative_recommenders.data.item_features import ItemFeatures 26 | from generative_recommenders.data.preprocessor import get_common_preprocessors 27 | 28 | 29 | @dataclass 30 | class RecoDataset: 31 | max_sequence_length: int 32 | num_unique_items: int 33 | max_item_id: int 34 | all_item_ids: List[int] 35 | train_dataset: torch.utils.data.Dataset 36 | eval_dataset: torch.utils.data.Dataset 37 | 38 | 39 | def get_reco_dataset( 40 | dataset_name: str, 41 | max_sequence_length: int, 42 | chronological: bool, 43 | positional_sampling_ratio: float = 1.0, 44 | ) -> RecoDataset: 45 | if dataset_name == "ml-1m": 46 | dp = get_common_preprocessors()[dataset_name] 47 | train_dataset = DatasetV2( 48 | ratings_file=dp.output_format_csv(), 49 | padding_length=max_sequence_length + 1, # target 50 | ignore_last_n=1, 51 | chronological=chronological, 52 | sample_ratio=positional_sampling_ratio, 53 | ) 54 | eval_dataset = DatasetV2( 55 | ratings_file=dp.output_format_csv(), 56 | padding_length=max_sequence_length + 1, # target 57 | ignore_last_n=0, 58 | chronological=chronological, 59 | sample_ratio=1.0, # do not sample 60 | ) 61 | elif dataset_name == "ml-20m": 62 | dp = get_common_preprocessors()[dataset_name] 63 | train_dataset = DatasetV2( 64 | ratings_file=dp.output_format_csv(), 65 | padding_length=max_sequence_length + 1, # target 66 | ignore_last_n=1, 67 | chronological=chronological, 68 | ) 69 | eval_dataset = DatasetV2( 70 | ratings_file=dp.output_format_csv(), 71 | padding_length=max_sequence_length + 1, # target 72 | ignore_last_n=0, 73 | chronological=chronological, 74 | ) 75 | elif dataset_name == "ml-3b": 76 | dp = get_common_preprocessors()[dataset_name] 77 | train_dataset = MultiFileDatasetV2( 78 | file_prefix="tmp/ml-3b/16x32", 79 | num_files=16, 80 | padding_length=max_sequence_length + 1, # target 81 | ignore_last_n=1, 82 | chronological=chronological, 83 | ) 84 | eval_dataset = MultiFileDatasetV2( 85 | file_prefix="tmp/ml-3b/16x32", 86 | num_files=16, 87 | padding_length=max_sequence_length + 1, # target 88 | ignore_last_n=0, 89 | chronological=chronological, 90 | ) 91 | elif dataset_name == "amzn-books": 92 | dp = get_common_preprocessors()[dataset_name] 93 | train_dataset = DatasetV2( 94 | ratings_file=dp.output_format_csv(), 95 | padding_length=max_sequence_length + 1, # target 96 | ignore_last_n=1, 97 | shift_id_by=1, # [0..n-1] -> [1..n] 98 | chronological=chronological, 99 | ) 100 | eval_dataset = DatasetV2( 101 | ratings_file=dp.output_format_csv(), 102 | padding_length=max_sequence_length + 1, # target 103 | ignore_last_n=0, 104 | shift_id_by=1, # [0..n-1] -> [1..n] 105 | chronological=chronological, 106 | ) 107 | else: 108 | raise ValueError(f"Unknown dataset {dataset_name}") 109 | 110 | if dataset_name == "ml-1m" or dataset_name == "ml-20m": 111 | items = pd.read_csv(dp.processed_item_csv(), delimiter=",") 112 | max_jagged_dimension = 16 113 | expected_max_item_id = dp.expected_max_item_id() 114 | assert expected_max_item_id is not None 115 | item_features: ItemFeatures = ItemFeatures( 116 | max_ind_range=[63, 16383, 511], 117 | num_items=expected_max_item_id + 1, 118 | max_jagged_dimension=max_jagged_dimension, 119 | lengths=[ 120 | torch.zeros((expected_max_item_id + 1,), dtype=torch.int64), 121 | torch.zeros((expected_max_item_id + 1,), dtype=torch.int64), 122 | torch.zeros((expected_max_item_id + 1,), dtype=torch.int64), 123 | ], 124 | values=[ 125 | torch.zeros( 126 | (expected_max_item_id + 1, max_jagged_dimension), 127 | dtype=torch.int64, 128 | ), 129 | torch.zeros( 130 | (expected_max_item_id + 1, max_jagged_dimension), 131 | dtype=torch.int64, 132 | ), 133 | torch.zeros( 134 | (expected_max_item_id + 1, max_jagged_dimension), 135 | dtype=torch.int64, 136 | ), 137 | ], 138 | ) 139 | all_item_ids = [] 140 | for df_index, row in items.iterrows(): 141 | # print(f"index {df_index}: {row}") 142 | movie_id = int(row["movie_id"]) 143 | genres = row["genres"].split("|") 144 | titles = row["cleaned_title"].split(" ") 145 | # print(f"{index}: genres{genres}, title{titles}") 146 | genres_vector = [hash(x) % item_features.max_ind_range[0] for x in genres] 147 | titles_vector = [hash(x) % item_features.max_ind_range[1] for x in titles] 148 | years_vector = [hash(row["year"]) % item_features.max_ind_range[2]] 149 | item_features.lengths[0][movie_id] = min( 150 | len(genres_vector), max_jagged_dimension 151 | ) 152 | item_features.lengths[1][movie_id] = min( 153 | len(titles_vector), max_jagged_dimension 154 | ) 155 | item_features.lengths[2][movie_id] = min( 156 | len(years_vector), max_jagged_dimension 157 | ) 158 | for f, f_values in enumerate([genres_vector, titles_vector, years_vector]): 159 | for j in range(min(len(f_values), max_jagged_dimension)): 160 | item_features.values[f][movie_id][j] = f_values[j] 161 | all_item_ids.append(movie_id) 162 | max_item_id = dp.expected_max_item_id() 163 | for x in all_item_ids: 164 | assert x > 0, "x in all_item_ids should be positive" 165 | else: 166 | # expected_max_item_id and item_features are not set for Amazon datasets. 167 | item_features = None 168 | max_item_id = dp.expected_num_unique_items() 169 | all_item_ids = [x + 1 for x in range(max_item_id)] # pyre-ignore [6] 170 | 171 | return RecoDataset( 172 | max_sequence_length=max_sequence_length, 173 | num_unique_items=dp.expected_num_unique_items(), # pyre-ignore [6] 174 | max_item_id=max_item_id, # pyre-ignore [6] 175 | all_item_ids=all_item_ids, 176 | train_dataset=train_dataset, 177 | eval_dataset=eval_dataset, 178 | ) 179 | -------------------------------------------------------------------------------- /generative_recommenders/modeling/sequential/encoder_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyre-unsafe 16 | 17 | import gin 18 | from generative_recommenders.modeling.sequential.embedding_modules import ( 19 | EmbeddingModule, 20 | ) 21 | from generative_recommenders.modeling.sequential.hstu import HSTU 22 | from generative_recommenders.modeling.sequential.fuxi import FuXi 23 | from generative_recommenders.modeling.sequential.input_features_preprocessors import ( 24 | InputFeaturesPreprocessorModule, 25 | ) 26 | from generative_recommenders.modeling.sequential.output_postprocessors import ( 27 | OutputPostprocessorModule, 28 | ) 29 | from generative_recommenders.modeling.sequential.sasrec import SASRec 30 | 31 | from generative_recommenders.modeling.similarity_module import ( 32 | GeneralizedInteractionModule, 33 | InteractionModule, 34 | ) 35 | 36 | 37 | @gin.configurable 38 | def sasrec_encoder( 39 | max_sequence_length: int, 40 | max_output_length: int, 41 | embedding_module: EmbeddingModule, 42 | interaction_module: InteractionModule, 43 | input_preproc_module: InputFeaturesPreprocessorModule, 44 | output_postproc_module: OutputPostprocessorModule, 45 | activation_checkpoint: bool, 46 | verbose: bool, 47 | ffn_hidden_dim: int = 64, 48 | ffn_activation_fn: str = "relu", 49 | ffn_dropout_rate: float = 0.2, 50 | num_blocks: int = 2, 51 | num_heads: int = 1, 52 | ) -> GeneralizedInteractionModule: 53 | return SASRec( 54 | embedding_module=embedding_module, 55 | max_sequence_len=max_sequence_length, 56 | max_output_len=max_output_length, 57 | embedding_dim=embedding_module.item_embedding_dim, 58 | ffn_hidden_dim=ffn_hidden_dim, 59 | ffn_activation_fn=ffn_activation_fn, 60 | ffn_dropout_rate=ffn_dropout_rate, 61 | num_blocks=num_blocks, 62 | num_heads=num_heads, 63 | similarity_module=interaction_module, # pyre-ignore [6] 64 | input_features_preproc_module=input_preproc_module, 65 | output_postproc_module=output_postproc_module, 66 | activation_checkpoint=activation_checkpoint, 67 | verbose=verbose, 68 | ) 69 | 70 | 71 | @gin.configurable 72 | def hstu_encoder( 73 | max_sequence_length: int, 74 | max_output_length: int, 75 | embedding_module: EmbeddingModule, 76 | interaction_module: InteractionModule, 77 | input_preproc_module: InputFeaturesPreprocessorModule, 78 | output_postproc_module: OutputPostprocessorModule, 79 | activation_checkpoint: bool, 80 | verbose: bool, 81 | num_blocks: int = 2, 82 | num_heads: int = 1, 83 | dqk: int = 64, 84 | dv: int = 64, 85 | linear_dropout_rate: float = 0.0, 86 | attn_dropout_rate: float = 0.0, 87 | normalization: str = "rel_bias", 88 | linear_config: str = "uvqk", 89 | linear_activation: str = "silu", 90 | concat_ua: bool = False, 91 | enable_relative_attention_bias: bool = True, 92 | ) -> GeneralizedInteractionModule: 93 | return HSTU( 94 | embedding_module=embedding_module, 95 | similarity_module=interaction_module, # pyre-ignore [6] 96 | input_features_preproc_module=input_preproc_module, 97 | output_postproc_module=output_postproc_module, 98 | max_sequence_len=max_sequence_length, 99 | max_output_len=max_output_length, 100 | embedding_dim=embedding_module.item_embedding_dim, 101 | num_blocks=num_blocks, 102 | num_heads=num_heads, 103 | attention_dim=dqk, 104 | linear_dim=dv, 105 | linear_dropout_rate=linear_dropout_rate, 106 | attn_dropout_rate=attn_dropout_rate, 107 | linear_config=linear_config, 108 | linear_activation=linear_activation, 109 | normalization=normalization, 110 | concat_ua=concat_ua, 111 | enable_relative_attention_bias=enable_relative_attention_bias, 112 | verbose=verbose, 113 | ) 114 | 115 | # TODO 116 | 117 | @gin.configurable 118 | def fuxi_encoder( 119 | max_sequence_length: int, 120 | max_output_length: int, 121 | embedding_module: EmbeddingModule, 122 | interaction_module: InteractionModule, 123 | input_preproc_module: InputFeaturesPreprocessorModule, 124 | output_postproc_module: OutputPostprocessorModule, 125 | activation_checkpoint: bool, 126 | verbose: bool, 127 | num_blocks: int = 2, 128 | num_heads: int = 1, 129 | dqk: int = 64, 130 | dv: int = 64, 131 | linear_dropout_rate: float = 0.0, 132 | attn_dropout_rate: float = 0.0, 133 | ffn_multiply: int = 1, 134 | ffn_single_stage: bool = False, 135 | normalization: str = "rel_bias", 136 | linear_config: str = "uvqk", 137 | linear_activation: str = "silu", 138 | enable_relative_attention_bias: bool = True, 139 | ) -> GeneralizedInteractionModule: 140 | return FuXi( 141 | embedding_module=embedding_module, 142 | similarity_module=interaction_module, # pyre-ignore [6] 143 | input_features_preproc_module=input_preproc_module, 144 | output_postproc_module=output_postproc_module, 145 | max_sequence_len=max_sequence_length, 146 | max_output_len=max_output_length, 147 | embedding_dim=embedding_module.item_embedding_dim, 148 | num_blocks=num_blocks, 149 | num_heads=num_heads, 150 | attention_dim=dqk, 151 | linear_dim=dv, 152 | linear_dropout_rate=linear_dropout_rate, 153 | attn_dropout_rate=attn_dropout_rate, 154 | ffn_multiply=ffn_multiply, 155 | ffn_single_stage=ffn_single_stage, 156 | linear_config=linear_config, 157 | linear_activation=linear_activation, 158 | normalization=normalization, 159 | enable_relative_attention_bias=enable_relative_attention_bias, 160 | verbose=verbose, 161 | ) 162 | 163 | @gin.configurable 164 | def get_sequential_encoder( 165 | module_type: str, 166 | max_sequence_length: int, 167 | max_output_length: int, 168 | embedding_module: EmbeddingModule, 169 | interaction_module: InteractionModule, 170 | input_preproc_module: InputFeaturesPreprocessorModule, 171 | output_postproc_module: OutputPostprocessorModule, 172 | verbose: bool, 173 | activation_checkpoint: bool = False, 174 | ) -> GeneralizedInteractionModule: 175 | if module_type == "SASRec": 176 | model = sasrec_encoder( 177 | max_sequence_length=max_sequence_length, 178 | max_output_length=max_output_length, 179 | embedding_module=embedding_module, 180 | interaction_module=interaction_module, 181 | input_preproc_module=input_preproc_module, 182 | output_postproc_module=output_postproc_module, 183 | activation_checkpoint=activation_checkpoint, 184 | verbose=verbose, 185 | ) 186 | elif module_type == "HSTU": 187 | model = hstu_encoder( 188 | max_sequence_length=max_sequence_length, 189 | max_output_length=max_output_length, 190 | embedding_module=embedding_module, 191 | interaction_module=interaction_module, 192 | input_preproc_module=input_preproc_module, 193 | output_postproc_module=output_postproc_module, 194 | activation_checkpoint=activation_checkpoint, 195 | verbose=verbose, 196 | ) 197 | elif module_type == 'FuXi' : 198 | model = fuxi_encoder( 199 | max_sequence_length=max_sequence_length, 200 | max_output_length=max_output_length, 201 | embedding_module=embedding_module, 202 | interaction_module=interaction_module, 203 | input_preproc_module=input_preproc_module, 204 | output_postproc_module=output_postproc_module, 205 | activation_checkpoint=activation_checkpoint, 206 | verbose=verbose, 207 | ) 208 | else: 209 | raise ValueError(f"Unsupported module_type {module_type}") 210 | return model 211 | -------------------------------------------------------------------------------- /generative_recommenders/modeling/sequential/input_features_preprocessors.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyre-unsafe 16 | 17 | import abc 18 | import math 19 | from typing import Dict, Tuple 20 | 21 | import torch 22 | 23 | from generative_recommenders.modeling.initialization import truncated_normal 24 | 25 | 26 | class InputFeaturesPreprocessorModule(torch.nn.Module): 27 | 28 | @abc.abstractmethod 29 | def debug_str(self) -> str: 30 | pass 31 | 32 | @abc.abstractmethod 33 | def forward( 34 | self, 35 | past_lengths: torch.Tensor, 36 | past_ids: torch.Tensor, 37 | past_embeddings: torch.Tensor, 38 | past_payloads: Dict[str, torch.Tensor], 39 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 40 | pass 41 | 42 | 43 | class LearnablePositionalEmbeddingInputFeaturesPreprocessor( 44 | InputFeaturesPreprocessorModule 45 | ): 46 | 47 | def __init__( 48 | self, 49 | max_sequence_len: int, 50 | embedding_dim: int, 51 | dropout_rate: float, 52 | ) -> None: 53 | super().__init__() 54 | 55 | self._embedding_dim: int = embedding_dim 56 | self._pos_emb: torch.nn.Embedding = torch.nn.Embedding( 57 | max_sequence_len, 58 | self._embedding_dim, 59 | ) 60 | self._dropout_rate: float = dropout_rate 61 | self._emb_dropout = torch.nn.Dropout(p=dropout_rate) 62 | self.reset_state() 63 | 64 | def debug_str(self) -> str: 65 | return f"posi_d{self._dropout_rate}" 66 | 67 | def reset_state(self) -> None: 68 | truncated_normal( 69 | self._pos_emb.weight.data, 70 | mean=0.0, 71 | std=math.sqrt(1.0 / self._embedding_dim), 72 | ) 73 | 74 | def forward( 75 | self, 76 | past_lengths: torch.Tensor, 77 | past_ids: torch.Tensor, 78 | past_embeddings: torch.Tensor, 79 | past_payloads: Dict[str, torch.Tensor], 80 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 81 | B, N = past_ids.size() 82 | D = past_embeddings.size(-1) 83 | 84 | user_embeddings = past_embeddings * (self._embedding_dim**0.5) + self._pos_emb( 85 | torch.arange(N, device=past_ids.device).unsqueeze(0).repeat(B, 1) 86 | ) 87 | user_embeddings = self._emb_dropout(user_embeddings) 88 | 89 | valid_mask = (past_ids != 0).unsqueeze(-1).float() # [B, N, 1] 90 | user_embeddings *= valid_mask 91 | return past_lengths, user_embeddings, valid_mask 92 | 93 | 94 | class LearnablePositionalEmbeddingRatedInputFeaturesPreprocessor( 95 | InputFeaturesPreprocessorModule 96 | ): 97 | 98 | def __init__( 99 | self, 100 | max_sequence_len: int, 101 | item_embedding_dim: int, 102 | dropout_rate: float, 103 | rating_embedding_dim: int, 104 | num_ratings: int, 105 | ) -> None: 106 | super().__init__() 107 | 108 | self._embedding_dim: int = item_embedding_dim + rating_embedding_dim 109 | self._pos_emb: torch.nn.Embedding = torch.nn.Embedding( 110 | max_sequence_len, 111 | self._embedding_dim, 112 | ) 113 | self._dropout_rate: float = dropout_rate 114 | self._emb_dropout = torch.nn.Dropout(p=dropout_rate) 115 | self._rating_emb: torch.nn.Embedding = torch.nn.Embedding( 116 | num_ratings, 117 | rating_embedding_dim, 118 | ) 119 | self.reset_state() 120 | 121 | def debug_str(self) -> str: 122 | return f"posir_d{self._dropout_rate}" 123 | 124 | def reset_state(self) -> None: 125 | truncated_normal( 126 | self._pos_emb.weight.data, 127 | mean=0.0, 128 | std=math.sqrt(1.0 / self._embedding_dim), 129 | ) 130 | truncated_normal( 131 | self._rating_emb.weight.data, 132 | mean=0.0, 133 | std=math.sqrt(1.0 / self._embedding_dim), 134 | ) 135 | 136 | def forward( 137 | self, 138 | past_lengths: torch.Tensor, 139 | past_ids: torch.Tensor, 140 | past_embeddings: torch.Tensor, 141 | past_payloads: Dict[str, torch.Tensor], 142 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 143 | B, N = past_ids.size() 144 | 145 | user_embeddings = torch.cat( 146 | [past_embeddings, self._rating_emb(past_payloads["ratings"].int())], 147 | dim=-1, 148 | ) * (self._embedding_dim**0.5) + self._pos_emb( 149 | torch.arange(N, device=past_ids.device).unsqueeze(0).repeat(B, 1) 150 | ) 151 | user_embeddings = self._emb_dropout(user_embeddings) 152 | 153 | valid_mask = (past_ids != 0).unsqueeze(-1).float() # [B, N, 1] 154 | user_embeddings *= valid_mask 155 | return past_lengths, user_embeddings, valid_mask 156 | 157 | 158 | class CombinedItemAndRatingInputFeaturesPreprocessor(InputFeaturesPreprocessorModule): 159 | 160 | def __init__( 161 | self, 162 | max_sequence_len: int, 163 | item_embedding_dim: int, 164 | dropout_rate: float, 165 | rating_embedding_dim: int, 166 | num_ratings: int, 167 | ) -> None: 168 | super().__init__() 169 | 170 | self._embedding_dim: int = item_embedding_dim 171 | self._rating_embedding_dim: int = rating_embedding_dim 172 | # Due to [item_0, rating_0, item_1, rating_1, ...] 173 | self._pos_emb: torch.nn.Embedding = torch.nn.Embedding( 174 | max_sequence_len * 2, 175 | self._embedding_dim, 176 | ) 177 | self._dropout_rate: float = dropout_rate 178 | self._emb_dropout = torch.nn.Dropout(p=dropout_rate) 179 | self._rating_emb: torch.nn.Embedding = torch.nn.Embedding( 180 | num_ratings, 181 | rating_embedding_dim, 182 | ) 183 | self.reset_state() 184 | 185 | def debug_str(self) -> str: 186 | return f"combir_d{self._dropout_rate}" 187 | 188 | def reset_state(self) -> None: 189 | truncated_normal( 190 | self._pos_emb.weight.data, 191 | mean=0.0, 192 | std=math.sqrt(1.0 / self._embedding_dim), 193 | ) 194 | truncated_normal( 195 | self._rating_emb.weight.data, 196 | mean=0.0, 197 | std=math.sqrt(1.0 / self._embedding_dim), 198 | ) 199 | 200 | def get_preprocessed_ids( 201 | self, 202 | past_lengths: torch.Tensor, 203 | past_ids: torch.Tensor, 204 | past_embeddings: torch.Tensor, 205 | past_payloads: Dict[str, torch.Tensor], 206 | ) -> torch.Tensor: 207 | """ 208 | Returns (B, N * 2,) x int64. 209 | """ 210 | B, N = past_ids.size() 211 | return torch.cat( 212 | [ 213 | past_ids.unsqueeze(2), # (B, N, 1) 214 | past_payloads["ratings"].to(past_ids.dtype).unsqueeze(2), 215 | ], 216 | dim=2, 217 | ).reshape(B, N * 2) 218 | 219 | def get_preprocessed_masks( 220 | self, 221 | past_lengths: torch.Tensor, 222 | past_ids: torch.Tensor, 223 | past_embeddings: torch.Tensor, 224 | past_payloads: Dict[str, torch.Tensor], 225 | ) -> torch.Tensor: 226 | """ 227 | Returns (B, N * 2,) x bool. 228 | """ 229 | B, N = past_ids.size() 230 | return (past_ids != 0).unsqueeze(2).expand(-1, -1, 2).reshape(B, N * 2) 231 | 232 | def forward( 233 | self, 234 | past_lengths: torch.Tensor, 235 | past_ids: torch.Tensor, 236 | past_embeddings: torch.Tensor, 237 | past_payloads: Dict[str, torch.Tensor], 238 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 239 | B, N = past_ids.size() 240 | D = past_embeddings.size(-1) 241 | 242 | user_embeddings = torch.cat( 243 | [ 244 | past_embeddings, # (B, N, D) 245 | self._rating_emb(past_payloads["ratings"].int()), 246 | ], 247 | dim=2, 248 | ) * (self._embedding_dim**0.5) 249 | user_embeddings = user_embeddings.view(B, N * 2, D) 250 | user_embeddings = user_embeddings + self._pos_emb( 251 | torch.arange(N * 2, device=past_ids.device).unsqueeze(0).repeat(B, 1) 252 | ) 253 | user_embeddings = self._emb_dropout(user_embeddings) 254 | 255 | valid_mask = ( 256 | self.get_preprocessed_masks( 257 | past_lengths, 258 | past_ids, 259 | past_embeddings, 260 | past_payloads, 261 | ) 262 | .unsqueeze(2) 263 | .float() 264 | ) # (B, N * 2, 1,) 265 | user_embeddings *= valid_mask 266 | return past_lengths * 2, user_embeddings, valid_mask 267 | -------------------------------------------------------------------------------- /generative_recommenders/indexing/candidate_index.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyre-unsafe 16 | 17 | import abc 18 | from typing import Optional, Tuple 19 | 20 | import torch 21 | 22 | from generative_recommenders.modeling.sequential.utils import batch_gather_embeddings 23 | 24 | 25 | class TopKModule(torch.nn.Module): 26 | 27 | @abc.abstractmethod 28 | def forward( 29 | self, 30 | query_embeddings: torch.Tensor, 31 | k: int, 32 | sorted: bool = True, 33 | ) -> Tuple[torch.Tensor, torch.Tensor]: 34 | """ 35 | Args: 36 | query_embeddings: (B, X, ...). Implementation-specific. 37 | k: int. top k to return. 38 | sorted: bool. 39 | 40 | Returns: 41 | Tuple of (top_k_scores, top_k_ids), both of shape (B, K,) 42 | """ 43 | pass 44 | 45 | 46 | class CandidateIndex(object): 47 | 48 | def __init__( 49 | self, 50 | ids: torch.Tensor, 51 | embeddings: torch.Tensor, 52 | invalid_ids: Optional[torch.Tensor] = None, 53 | debug_path: Optional[str] = None, 54 | ) -> None: 55 | super().__init__() 56 | 57 | self._ids: torch.Tensor = ids 58 | self._embeddings: torch.Tensor = embeddings 59 | self._invalid_ids: Optional[torch.Tensor] = invalid_ids 60 | self._debug_path: Optional[str] = debug_path 61 | 62 | @property 63 | def ids(self) -> torch.Tensor: 64 | """ 65 | Returns: 66 | (1, X) or (B, X), where valid ids are positive integers. 67 | """ 68 | return self._ids 69 | 70 | @property 71 | def num_objects(self) -> int: 72 | return self._ids.size(1) 73 | 74 | @property 75 | def embeddings(self) -> torch.Tensor: 76 | """ 77 | Returns: 78 | (1, X, D) or (B, X, D) with the same shape as `ids'. 79 | """ 80 | return self._embeddings 81 | 82 | def filter_invalid_ids( 83 | self, 84 | invalid_ids: torch.Tensor, 85 | ) -> "CandidateIndex": 86 | """ 87 | Filters invalid_ids (batch dimension dependent) from the current index. 88 | 89 | Args: 90 | invalid_ids: (B, N) x int64. 91 | 92 | Returns: 93 | CandidateIndex with invalid_ids filtered. 94 | """ 95 | X = self._ids.size(1) 96 | # if self._ids.size(0) == 1 and X <= 100000: 97 | if self._ids.size(0) == 1: 98 | # ((1, X, 1) == (B, 1, N)) -> (B, X) 99 | invalid_mask, _ = (self._ids.unsqueeze(2) == invalid_ids.unsqueeze(1)).max( 100 | dim=2 101 | ) 102 | lengths = (~invalid_mask).int().sum(-1) # (B,) 103 | valid_1d_mask = (~invalid_mask).view(-1) 104 | B: int = lengths.size(0) 105 | D: int = self._embeddings.size(-1) 106 | jagged_ids = self._ids.expand(B, -1).reshape(-1)[valid_1d_mask] 107 | jagged_embeddings = self._embeddings.expand(B, -1, -1).reshape(-1, D)[ 108 | valid_1d_mask 109 | ] 110 | X_prime: int = lengths.max(-1)[0].item() 111 | jagged_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) 112 | return CandidateIndex( 113 | ids=torch.ops.fbgemm.jagged_to_padded_dense( 114 | values=jagged_ids.unsqueeze(-1), 115 | offsets=[jagged_offsets], 116 | max_lengths=[X_prime], 117 | padding_value=0, 118 | ).squeeze(-1), 119 | embeddings=torch.ops.fbgemm.jagged_to_padded_dense( 120 | values=jagged_embeddings, 121 | offsets=[jagged_offsets], 122 | max_lengths=[X_prime], 123 | padding_value=0.0, 124 | ), 125 | debug_path=self._debug_path, 126 | ) 127 | else: 128 | assert self._invalid_ids == None 129 | return CandidateIndex( 130 | ids=self.ids, 131 | embeddings=self.embeddings, 132 | invalid_ids=invalid_ids, 133 | debug_path=self._debug_path, 134 | ) 135 | 136 | def get_top_k_outputs( 137 | self, 138 | query_embeddings: torch.Tensor, 139 | k: int, 140 | top_k_module: TopKModule, 141 | invalid_ids: Optional[torch.Tensor], 142 | r: int = 1, 143 | return_embeddings: bool = False, 144 | ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: 145 | """ 146 | Gets top-k outputs specified by `policy_fn', while filtering out 147 | invalid ids per row as specified by `invalid_ids'. 148 | 149 | Args: 150 | k: int. top k to return. 151 | policy_fn: lambda that takes in item-side embeddings (B, X, D,) and user-side 152 | embeddings (B * r, ...), and returns predictions (unnormalized logits) 153 | of shape (B * r, X,). 154 | invalid_ids: (B * r, N_0) x int64. The list of ids (if > 0) to filter from 155 | results if present. Expect N_0 to be a small constant. 156 | return_embeddings: bool if we should additionally return embeddings for the 157 | top k results. 158 | 159 | Returns: 160 | A tuple of (top_k_ids, top_k_prs, top_k_embeddings) of shape (B * r, k, ...). 161 | """ 162 | B: int = query_embeddings.size(0) 163 | max_num_invalid_ids = 0 164 | if invalid_ids is not None: 165 | max_num_invalid_ids = invalid_ids.size(1) 166 | 167 | k_prime = min(k + max_num_invalid_ids, self.num_objects) 168 | top_k_prime_scores, top_k_prime_ids = top_k_module( 169 | query_embeddings=query_embeddings, k=k_prime 170 | ) 171 | """ 172 | B: int = candidate_logits.size(0) 173 | candidate_logits, debug_info = policy_fn(self.ids, self.embeddings) # (B, X,) 174 | if self._debug_path is not None: 175 | # print(f"Saving debug dict to {self._debug_path}") 176 | for debug_k, debug_v in debug_info.items(): 177 | with open(self._debug_path + "." + debug_k, 'wb') as f: 178 | # np doesn't work with bf16 179 | np.save(f, debug_v.to(torch.float32).cpu().numpy()) 180 | 181 | # assume that the top X ids can approximately capture the probability mass. 182 | candidate_prs = F.softmax(candidate_logits, dim=1) 183 | 184 | top_k_prime_prs, top_k_prime_indices = torch.topk( 185 | candidate_prs, dim=1, k=min(k + max_num_invalid_ids, self.num_objects), 186 | ) # [B, K + N_0] 187 | expanded_ids = self.ids.repeat_interleave(r, dim=0) if r > 1 else self.ids 188 | # TODO revisit. For amzn-books only 189 | if expanded_ids.size(0) == 1: 190 | expanded_ids = expanded_ids.expand(B, -1) 191 | 192 | top_k_prime_ids = torch.gather(expanded_ids, dim=1, index=top_k_prime_indices) # [B * r, K + N_0] 193 | """ 194 | # Masks out invalid items rowwise. 195 | if invalid_ids is not None: 196 | id_is_valid = ~( 197 | (top_k_prime_ids.unsqueeze(2) == invalid_ids.unsqueeze(1)).max(2)[0] 198 | ) # [B, K + N_0] 199 | id_is_valid = torch.logical_and( 200 | id_is_valid, torch.cumsum(id_is_valid.int(), dim=1) <= k 201 | ) 202 | # [[1, 0, 1, 0], [0, 1, 1, 1]], k=2 -> [[0, 2], [1, 2]] 203 | top_k_rowwise_offsets = torch.nonzero(id_is_valid, as_tuple=True)[1].view( 204 | -1, k 205 | ) 206 | top_k_scores = torch.gather( 207 | top_k_prime_scores, dim=1, index=top_k_rowwise_offsets 208 | ) 209 | top_k_ids = torch.gather( 210 | top_k_prime_ids, dim=1, index=top_k_rowwise_offsets 211 | ) 212 | else: 213 | # id_is_valid = torch.ones_like(top_k_prime_indices, dtype=torch.bool, device=expanded_ids.device) 214 | top_k_scores = top_k_prime_scores 215 | top_k_ids = top_k_prime_ids 216 | 217 | # top_k_indices = torch.gather(top_k_prime_indices, dim=1, index=top_k_rowwise_offsets) 218 | # top_k_prs = torch.gather(top_k_prime_prs, dim=1, index=top_k_rowwise_offsets) 219 | # top_k_ids = torch.gather(expanded_ids, dim=1, index=top_k_indices) # [B * r, k] 220 | # TODO: this should be decoupled from candidate_index. 221 | if return_embeddings: 222 | # TODO: get rid of repeat_interleave in the final version. 223 | expanded_embeddings = ( 224 | self.embeddings.repeat_interleave(r, dim=0) 225 | if r > 1 226 | else self.embeddings 227 | ) 228 | top_k_embeddings = batch_gather_embeddings( 229 | rowwise_indices=top_k_indices, embeddings=expanded_embeddings # pyre-ignore[10] 230 | ) 231 | else: 232 | top_k_embeddings = None 233 | return top_k_ids, top_k_scores, top_k_embeddings 234 | 235 | def apply_object_filter(self) -> "CandidateIndex": 236 | """ 237 | Applies general per batch filters. 238 | """ 239 | raise NotImplementedError("not implemented.") 240 | -------------------------------------------------------------------------------- /generative_recommenders/data/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyre-unsafe 16 | 17 | import csv 18 | import linecache 19 | 20 | from typing import Dict, List, Optional, Tuple 21 | 22 | import numpy as np 23 | import pandas as pd 24 | import torch 25 | 26 | import logging 27 | from tqdm import tqdm 28 | 29 | class DatasetV2(torch.utils.data.Dataset): 30 | """In reverse chronological order.""" 31 | 32 | def __init__( 33 | self, 34 | ratings_file: str, 35 | padding_length: int, 36 | ignore_last_n: int, # used for creating train/valid/test sets 37 | shift_id_by: int = 0, 38 | chronological: bool = False, 39 | sample_ratio: float = 1.0, 40 | ) -> None: 41 | """ 42 | Args: 43 | csv_file (string): Path to the csv file. 44 | """ 45 | super().__init__() 46 | 47 | self.ratings_frame: pd.DataFrame = pd.read_csv( 48 | ratings_file, 49 | delimiter=",", 50 | # iterator=True, 51 | ) 52 | self._padding_length: int = padding_length 53 | self._ignore_last_n: int = ignore_last_n 54 | self._cache: Dict[int, Dict[str, torch.Tensor]] = dict() 55 | self._shift_id_by: int = shift_id_by 56 | self._chronological: bool = chronological 57 | self._sample_ratio: float = sample_ratio 58 | for i in tqdm(range(len(self))) : 59 | self.__getitem__(i) 60 | 61 | def __len__(self) -> int: 62 | return len(self.ratings_frame) 63 | 64 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 65 | if idx in self._cache.keys(): 66 | return self._cache[idx] 67 | data = self.ratings_frame.iloc[idx] 68 | sample = self.load_item(data) 69 | self._cache[idx] = sample 70 | return sample 71 | 72 | def load_item(self, data) -> Dict[str, torch.Tensor]: 73 | user_id = data.user_id 74 | 75 | def eval_as_list(x: str, ignore_last_n: int) -> List[int]: 76 | y = eval(x) 77 | y_list = [y] if type(y) == int else list(y) 78 | if ignore_last_n > 0: 79 | # for training data creation 80 | y_list = y_list[:-ignore_last_n] 81 | return y_list 82 | 83 | def eval_int_list( 84 | x: str, 85 | target_len: int, 86 | ignore_last_n: int, 87 | shift_id_by: int, 88 | sampling_kept_mask: Optional[List[bool]], 89 | ) -> Tuple[List[int], int]: 90 | y = eval_as_list(x, ignore_last_n=ignore_last_n) 91 | if sampling_kept_mask is not None: 92 | y = [x for x, kept in zip(y, sampling_kept_mask) if kept] 93 | y_len = len(y) 94 | y.reverse() 95 | if shift_id_by > 0: 96 | y = [x + shift_id_by for x in y] 97 | return y, y_len 98 | 99 | if self._sample_ratio < 1.0: 100 | raw_length = len(eval_as_list(data.sequence_item_ids, self._ignore_last_n)) 101 | sampling_kept_mask = ( 102 | torch.rand((raw_length,), dtype=torch.float32) < self._sample_ratio 103 | ).tolist() 104 | else: 105 | sampling_kept_mask = None 106 | 107 | movie_history, movie_history_len = eval_int_list( 108 | data.sequence_item_ids, 109 | self._padding_length, 110 | self._ignore_last_n, 111 | shift_id_by=self._shift_id_by, 112 | sampling_kept_mask=sampling_kept_mask, 113 | ) 114 | movie_history_ratings, ratings_len = eval_int_list( 115 | data.sequence_ratings, 116 | self._padding_length, 117 | self._ignore_last_n, 118 | 0, 119 | sampling_kept_mask=sampling_kept_mask, 120 | ) 121 | movie_timestamps, timestamps_len = eval_int_list( 122 | data.sequence_timestamps, 123 | self._padding_length, 124 | self._ignore_last_n, 125 | 0, 126 | sampling_kept_mask=sampling_kept_mask, 127 | ) 128 | assert ( 129 | movie_history_len == timestamps_len 130 | ), f"history len {movie_history_len} differs from timestamp len {timestamps_len}." 131 | assert ( 132 | movie_history_len == ratings_len 133 | ), f"history len {movie_history_len} differs from ratings len {ratings_len}." 134 | 135 | def _truncate_or_pad_seq( 136 | y: List[int], target_len: int, chronological: bool 137 | ) -> List[int]: 138 | y_len = len(y) 139 | if y_len < target_len: 140 | y = y + [0] * (target_len - y_len) 141 | else: 142 | if not chronological: 143 | y = y[:target_len] 144 | else: 145 | y = y[-target_len:] 146 | assert len(y) == target_len 147 | return y 148 | 149 | historical_ids = movie_history[1:] 150 | historical_ratings = movie_history_ratings[1:] 151 | historical_timestamps = movie_timestamps[1:] 152 | target_ids = movie_history[0] 153 | target_ratings = movie_history_ratings[0] 154 | target_timestamps = movie_timestamps[0] 155 | if self._chronological: 156 | historical_ids.reverse() 157 | historical_ratings.reverse() 158 | historical_timestamps.reverse() 159 | 160 | max_seq_len = self._padding_length - 1 161 | history_length = min(len(historical_ids), max_seq_len) 162 | historical_ids = _truncate_or_pad_seq( 163 | historical_ids, 164 | max_seq_len, 165 | self._chronological, 166 | ) 167 | historical_ratings = _truncate_or_pad_seq( 168 | historical_ratings, 169 | max_seq_len, 170 | self._chronological, 171 | ) 172 | historical_timestamps = _truncate_or_pad_seq( 173 | historical_timestamps, 174 | max_seq_len, 175 | self._chronological, 176 | ) 177 | # moved to features.py 178 | # if self._chronological: 179 | # historical_ids.append(0) 180 | # historical_ratings.append(0) 181 | # historical_timestamps.append(0) 182 | # print(historical_ids, historical_ratings, historical_timestamps, target_ids, target_ratings, target_timestamps) 183 | ret = { 184 | "user_id": user_id, 185 | "historical_ids": torch.tensor(historical_ids, dtype=torch.int64), 186 | "historical_ratings": torch.tensor(historical_ratings, dtype=torch.int64), 187 | "historical_timestamps": torch.tensor( 188 | historical_timestamps, dtype=torch.int64 189 | ), 190 | "history_lengths": history_length, 191 | "target_ids": target_ids, 192 | "target_ratings": target_ratings, 193 | "target_timestamps": target_timestamps, 194 | } 195 | return ret 196 | 197 | 198 | class MultiFileDatasetV2(DatasetV2, torch.utils.data.Dataset): 199 | 200 | def __init__( 201 | self, 202 | file_prefix: str, 203 | num_files: int, 204 | padding_length: int, 205 | ignore_last_n: int, # used for creating train/valid/test sets 206 | shift_id_by: int = 0, 207 | chronological: bool = False, 208 | sample_ratio: float = 1.0, 209 | ) -> None: 210 | torch.utils.data.Dataset().__init__() 211 | self._file_prefix: str = file_prefix 212 | self._num_files: int = num_files 213 | with open(f"{file_prefix}_users.csv", "r") as file: 214 | reader = csv.reader(file) 215 | self.users_cumsum: List[int] = np.cumsum( 216 | [int(row[1]) for row in reader] 217 | ).tolist() 218 | self._padding_length: int = padding_length 219 | self._ignore_last_n: int = ignore_last_n 220 | self._shift_id_by: int = shift_id_by 221 | self._chronological: bool = chronological 222 | self._sample_ratio: float = sample_ratio 223 | 224 | def __len__(self) -> int: 225 | return self.users_cumsum[-1] 226 | 227 | def _process_line(self, line: str) -> pd.Series: 228 | reader = csv.reader([line]) 229 | parsed_line = next(reader) 230 | user_id = int(parsed_line[0]) 231 | sequence_item_ids = parsed_line[1] 232 | sequence_ratings = parsed_line[2] 233 | return pd.Series( 234 | data={ 235 | "user_id": user_id, 236 | "sequence_item_ids": sequence_item_ids, 237 | "sequence_ratings": sequence_ratings, 238 | "sequence_timestamps": sequence_item_ids, # placeholder 239 | } 240 | ) 241 | 242 | def __getitem__(self, idx) -> Dict[str, torch.Tensor]: 243 | assert idx < self.users_cumsum[-1] 244 | file_idx: int = 0 245 | while self.users_cumsum[file_idx] <= idx: 246 | file_idx += 1 247 | if file_idx == 0: 248 | local_idx = idx 249 | else: 250 | local_idx = idx - self.users_cumsum[file_idx - 1] 251 | line = linecache.getline(f"{self._file_prefix}_{file_idx}.csv", local_idx + 1) 252 | data = self._process_line(line) 253 | sample = self.load_item(data) 254 | return sample 255 | -------------------------------------------------------------------------------- /generative_recommenders/data/eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyre-unsafe 16 | 17 | import logging 18 | import sys 19 | from dataclasses import dataclass 20 | from typing import Callable, Dict, List, Optional, Set, Union 21 | 22 | import torch 23 | import torch.distributed as dist 24 | 25 | from generative_recommenders.indexing.candidate_index import CandidateIndex, TopKModule 26 | from generative_recommenders.modeling.ndp_module import NDPModule 27 | from generative_recommenders.modeling.sequential.features import SequentialFeatures 28 | from torch.utils.tensorboard import SummaryWriter 29 | 30 | 31 | logging.basicConfig(stream=sys.stdout, level=logging.INFO) 32 | 33 | 34 | @dataclass 35 | class EvalState: 36 | all_item_ids: Set[int] 37 | candidate_index: CandidateIndex 38 | top_k_module: TopKModule 39 | 40 | 41 | def get_eval_state( 42 | model: NDPModule, 43 | all_item_ids: List[int], # [X] 44 | negatives_sampler: torch.nn.Module, 45 | top_k_module_fn: Callable[[torch.Tensor, torch.Tensor], TopKModule], 46 | device: int, 47 | float_dtype: Optional[torch.dtype] = None, 48 | ) -> EvalState: 49 | # Exhaustively eval all items (incl. seen ids). 50 | eval_negatives_ids = torch.as_tensor(all_item_ids).to(device).unsqueeze(0) # [1, X] 51 | eval_negative_embeddings = negatives_sampler.normalize_embeddings( 52 | model.get_item_embeddings(eval_negatives_ids) 53 | ) 54 | if float_dtype is not None: 55 | eval_negative_embeddings = eval_negative_embeddings.to(float_dtype) 56 | candidates = CandidateIndex( 57 | ids=eval_negatives_ids, 58 | embeddings=eval_negative_embeddings, 59 | ) 60 | return EvalState( 61 | all_item_ids=set(all_item_ids), 62 | candidate_index=candidates, 63 | top_k_module=top_k_module_fn(eval_negative_embeddings, eval_negatives_ids), 64 | ) 65 | 66 | 67 | @torch.inference_mode # pyre-ignore [56] 68 | def eval_metrics_v2_from_tensors( 69 | eval_state: EvalState, 70 | model: NDPModule, 71 | seq_features: SequentialFeatures, 72 | target_ids: torch.Tensor, # [B, 1] 73 | min_positive_rating: int = 4, 74 | target_ratings: Optional[torch.Tensor] = None, # [B, 1] 75 | epoch: Optional[str] = None, 76 | filter_invalid_ids: bool = True, 77 | user_max_batch_size: Optional[int] = None, 78 | dtype: Optional[torch.dtype] = None, 79 | ) -> Dict[str, Union[float, torch.Tensor]]: 80 | """ 81 | Args: 82 | eval_negatives_ids: Optional[Tensor]. If not present, defaults to eval over 83 | the entire corpus (`num_items`) excluding all the items that users have 84 | seen in the past (historical_ids, target_ids). This is consistent with 85 | papers like SASRec and TDM but may not be fair in practice as retrieval 86 | modules don't have access to read state during the initial fetch stage. 87 | filter_invalid_ids: bool. If true, filters seen ids by default. 88 | Returns: 89 | keyed metric -> list of values for each example. 90 | """ 91 | B, _ = target_ids.shape 92 | device = target_ids.device 93 | 94 | for target_id in target_ids: 95 | target_id = int(target_id) 96 | if target_id not in eval_state.all_item_ids: 97 | print(f"missing target_id {target_id}") 98 | 99 | # computes ro- part exactly once. 100 | shared_input_embeddings = model.encode( 101 | past_lengths=seq_features.past_lengths, 102 | past_ids=seq_features.past_ids, 103 | past_embeddings=model.get_item_embeddings(seq_features.past_ids), 104 | past_payloads=seq_features.past_payloads, 105 | ) 106 | if dtype is not None: 107 | shared_input_embeddings = shared_input_embeddings.to(dtype) 108 | 109 | MAX_K = 2500 110 | k = min(MAX_K, eval_state.candidate_index.ids.size(1)) 111 | user_max_batch_size = user_max_batch_size or shared_input_embeddings.size(0) 112 | num_batches = ( 113 | shared_input_embeddings.size(0) + user_max_batch_size - 1 114 | ) // user_max_batch_size 115 | eval_top_k_ids_all = [] 116 | eval_top_k_prs_all = [] 117 | for mb in range(num_batches): 118 | eval_top_k_ids, eval_top_k_prs, _ = ( 119 | eval_state.candidate_index.get_top_k_outputs( 120 | query_embeddings=shared_input_embeddings[ 121 | mb * user_max_batch_size : (mb + 1) * user_max_batch_size, ... 122 | ], 123 | top_k_module=eval_state.top_k_module, 124 | k=k, 125 | invalid_ids=( 126 | seq_features.past_ids[ 127 | mb * user_max_batch_size : (mb + 1) * user_max_batch_size, : 128 | ] 129 | if filter_invalid_ids 130 | else None 131 | ), 132 | return_embeddings=False, 133 | ) 134 | ) 135 | eval_top_k_ids_all.append(eval_top_k_ids) 136 | eval_top_k_prs_all.append(eval_top_k_prs) 137 | 138 | if num_batches == 1: 139 | eval_top_k_ids = eval_top_k_ids_all[0] 140 | eval_top_k_prs = eval_top_k_prs_all[0] 141 | else: 142 | eval_top_k_ids = torch.cat(eval_top_k_ids_all, dim=0) 143 | eval_top_k_prs = torch.cat(eval_top_k_prs_all, dim=0) 144 | 145 | assert eval_top_k_ids.size(1) == k 146 | _, eval_rank_indices = torch.max( 147 | torch.cat( 148 | [eval_top_k_ids, target_ids], 149 | dim=1, 150 | ) 151 | == target_ids, 152 | dim=1, 153 | ) 154 | eval_ranks = torch.where(eval_rank_indices == k, MAX_K + 1, eval_rank_indices + 1) 155 | 156 | output = { 157 | "ndcg@1": torch.where( 158 | eval_ranks <= 1, 159 | torch.div(1.0, torch.log2(eval_ranks + 1)), 160 | torch.zeros(1, dtype=torch.float32, device=device), 161 | ), 162 | "ndcg@10": torch.where( 163 | eval_ranks <= 10, 164 | torch.div(1.0, torch.log2(eval_ranks + 1)), 165 | torch.zeros(1, dtype=torch.float32, device=device), 166 | ), 167 | "ndcg@50": torch.where( 168 | eval_ranks <= 50, 169 | torch.div(1.0, torch.log2(eval_ranks + 1)), 170 | torch.zeros(1, dtype=torch.float32, device=device), 171 | ), 172 | "ndcg@100": torch.where( 173 | eval_ranks <= 100, 174 | torch.div(1.0, torch.log2(eval_ranks + 1)), 175 | torch.zeros(1, dtype=torch.float32, device=device), 176 | ), 177 | "ndcg@200": torch.where( 178 | eval_ranks <= 200, 179 | torch.div(1.0, torch.log2(eval_ranks + 1)), 180 | torch.zeros(1, dtype=torch.float32, device=device), 181 | ), 182 | "hr@1": (eval_ranks <= 1), 183 | "hr@10": (eval_ranks <= 10), 184 | "hr@50": (eval_ranks <= 50), 185 | "hr@100": (eval_ranks <= 100), 186 | "hr@200": (eval_ranks <= 200), 187 | "hr@500": (eval_ranks <= 500), 188 | "hr@1000": (eval_ranks <= 1000), 189 | "mrr": torch.div(1.0, eval_ranks), 190 | } 191 | if target_ratings is not None: 192 | target_ratings = target_ratings.squeeze(1) # [B] 193 | output["ndcg@10_>=4"] = torch.where( 194 | eval_ranks[target_ratings >= 4] <= 10, 195 | torch.div(1.0, torch.log2(eval_ranks[target_ratings >= 4] + 1)), 196 | torch.zeros(1, dtype=torch.float32, device=device), 197 | ) 198 | output[f"hr@10_>={min_positive_rating}"] = ( 199 | eval_ranks[target_ratings >= min_positive_rating] <= 10 200 | ) 201 | output[f"hr@50_>={min_positive_rating}"] = ( 202 | eval_ranks[target_ratings >= min_positive_rating] <= 50 203 | ) 204 | output[f"mrr_>={min_positive_rating}"] = torch.div( 205 | 1.0, eval_ranks[target_ratings >= min_positive_rating] 206 | ) 207 | 208 | return output # pyre-ignore [7] 209 | 210 | 211 | def eval_recall_metrics_from_tensors( 212 | eval_state: EvalState, 213 | model: NDPModule, 214 | seq_features: SequentialFeatures, 215 | user_max_batch_size: Optional[int] = None, 216 | dtype: Optional[torch.dtype] = None, 217 | ) -> Dict[str, torch.Tensor]: 218 | target_ids = seq_features.past_ids[:, -1].unsqueeze(1) 219 | filtered_past_ids = seq_features.past_ids.detach().clone() 220 | filtered_past_ids[:, -1] = torch.zeros_like(target_ids.squeeze(1)) 221 | return eval_metrics_v2_from_tensors( 222 | eval_state=eval_state, 223 | model=model, 224 | seq_features=SequentialFeatures( 225 | past_lengths=seq_features.past_lengths - 1, 226 | past_ids=filtered_past_ids, 227 | past_embeddings=seq_features.past_embeddings, 228 | past_payloads=seq_features.past_payloads, 229 | ), 230 | target_ids=target_ids, 231 | user_max_batch_size=user_max_batch_size, 232 | dtype=dtype, 233 | ) 234 | 235 | 236 | def _avg(x: torch.Tensor, world_size: int) -> torch.Tensor: 237 | _sum_and_numel = torch.tensor( 238 | [x.sum(), x.numel()], dtype=torch.float32, device=x.device 239 | ) 240 | if world_size > 1: 241 | dist.all_reduce(_sum_and_numel, op=dist.ReduceOp.SUM) 242 | return _sum_and_numel[0] / _sum_and_numel[1] 243 | 244 | 245 | def add_to_summary_writer( 246 | writer: Optional[SummaryWriter], 247 | batch_id: int, 248 | metrics: Dict[str, torch.Tensor], 249 | prefix: str, 250 | world_size: int, 251 | ) -> None: 252 | for key, values in metrics.items(): 253 | avg_value = _avg(values, world_size) 254 | if writer is not None: 255 | writer.add_scalar(f"{prefix}/{key}", avg_value, batch_id) 256 | -------------------------------------------------------------------------------- /generative_recommenders/modeling/similarity_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyre-unsafe 16 | 17 | from typing import Tuple 18 | 19 | import gin 20 | import torch 21 | 22 | from generative_recommenders.modeling.initialization import ( 23 | init_mlp_xavier_weights_zero_bias, 24 | ) 25 | from generative_recommenders.modeling.similarity.dot_product import DotProductSimilarity 26 | from generative_recommenders.modeling.similarity.mol import ( 27 | GeGLU, 28 | IdentityMLPProjectionFn, 29 | MoLSimilarity, 30 | SoftmaxDropoutCombiner, 31 | ) 32 | 33 | 34 | @gin.configurable 35 | def create_mol_interaction_module( 36 | query_embedding_dim: int, 37 | item_embedding_dim: int, 38 | dot_product_dimension: int, 39 | query_dot_product_groups: int, 40 | item_dot_product_groups: int, 41 | temperature: float, 42 | dot_product_l2_norm: bool, 43 | query_use_identity_fn: bool, 44 | query_dropout_rate: float, 45 | query_hidden_dim: int, 46 | item_use_identity_fn: bool, 47 | item_dropout_rate: float, 48 | item_hidden_dim: int, 49 | gating_combination_type: str, 50 | gating_qi_hidden_dim: int, 51 | gating_query_hidden_dim: int, 52 | gating_item_hidden_dim: int, 53 | gating_softmax_dropout_rate: float, 54 | bf16_training: bool, 55 | gating_query_fn: bool = True, 56 | gating_item_fn: bool = True, 57 | gating_qi_dropout_rate: float = 0.0, 58 | gating_item_dropout_rate: float = 0.0, 59 | gating_use_custom_tau: bool = False, 60 | gating_tau_alpha: float = 0.01, 61 | eps: float = 1e-6, 62 | ) -> Tuple[MoLSimilarity, str]: 63 | mol_module = MoLSimilarity( 64 | input_embedding_dim=query_embedding_dim, 65 | item_embedding_dim=item_embedding_dim, 66 | dot_product_dimension=dot_product_dimension, 67 | input_dot_product_groups=query_dot_product_groups, 68 | item_dot_product_groups=item_dot_product_groups, 69 | temperature=temperature, 70 | dot_product_l2_norm=dot_product_l2_norm, 71 | num_precomputed_logits=0, 72 | # item_feature_embedding_dim * 3 if not ablate_item_features else 73 | item_sideinfo_dim=0, # not configured 74 | context_proj_fn=lambda input_dim, output_dim: ( 75 | IdentityMLPProjectionFn( 76 | input_dim=input_dim, 77 | output_num_features=query_dot_product_groups, 78 | output_dim=output_dim // query_dot_product_groups, 79 | input_dropout_rate=query_dropout_rate, 80 | ) 81 | if query_use_identity_fn 82 | else ( 83 | torch.nn.Sequential( 84 | torch.nn.Dropout(p=query_dropout_rate), 85 | GeGLU( 86 | in_features=input_dim, 87 | out_features=query_hidden_dim, 88 | ), 89 | torch.nn.Linear( 90 | in_features=query_hidden_dim, 91 | out_features=output_dim, 92 | ), 93 | ).apply(init_mlp_xavier_weights_zero_bias) 94 | if query_hidden_dim > 0 95 | else torch.nn.Sequential( 96 | torch.nn.Dropout(p=query_dropout_rate), 97 | torch.nn.Linear( 98 | in_features=input_dim, 99 | out_features=output_dim, 100 | ), 101 | ).apply(init_mlp_xavier_weights_zero_bias) 102 | ) 103 | ), 104 | item_proj_fn=lambda input_dim, output_dim: ( 105 | IdentityMLPProjectionFn( 106 | input_dim=input_dim, 107 | output_num_features=item_dot_product_groups, 108 | output_dim=output_dim // item_dot_product_groups, 109 | input_dropout_rate=item_dropout_rate, 110 | ) 111 | if item_use_identity_fn 112 | else ( 113 | torch.nn.Sequential( 114 | torch.nn.Dropout(p=item_dropout_rate), 115 | GeGLU( 116 | in_features=input_dim, 117 | out_features=item_hidden_dim, 118 | ), 119 | torch.nn.Linear( 120 | in_features=item_hidden_dim, 121 | out_features=output_dim, 122 | ), 123 | ).apply(init_mlp_xavier_weights_zero_bias) 124 | if item_hidden_dim > 0 125 | else torch.nn.Sequential( 126 | torch.nn.Dropout(p=item_dropout_rate), 127 | torch.nn.Linear( 128 | in_features=input_dim, 129 | out_features=output_dim, 130 | ), 131 | ).apply(init_mlp_xavier_weights_zero_bias) 132 | ) 133 | ), 134 | gating_context_only_partial_fn=lambda input_dim, output_dim: ( # pyre-ignore [6] 135 | torch.nn.Sequential( 136 | torch.nn.Linear( 137 | in_features=input_dim, 138 | out_features=gating_query_hidden_dim, 139 | ), 140 | torch.nn.SiLU(), 141 | torch.nn.Linear( 142 | in_features=gating_query_hidden_dim, 143 | out_features=output_dim, 144 | bias=False, 145 | ), 146 | ).apply(init_mlp_xavier_weights_zero_bias) 147 | if gating_query_fn 148 | else None 149 | ), 150 | gating_item_only_partial_fn=lambda input_dim, output_dim: ( # pyre-ignore [6] 151 | torch.nn.Sequential( 152 | torch.nn.Dropout(p=gating_item_dropout_rate), 153 | torch.nn.Linear( 154 | in_features=input_dim, 155 | out_features=gating_item_hidden_dim, 156 | ), 157 | torch.nn.SiLU(), 158 | torch.nn.Linear( 159 | in_features=gating_item_hidden_dim, 160 | out_features=output_dim, 161 | bias=False, 162 | ), 163 | ).apply(init_mlp_xavier_weights_zero_bias) 164 | if gating_item_fn 165 | else None 166 | ), 167 | gating_ci_partial_fn=lambda input_dim, output_dim: ( # pyre-ignore [6] 168 | torch.nn.Sequential( 169 | torch.nn.Dropout(p=gating_qi_dropout_rate), 170 | torch.nn.Linear( 171 | in_features=input_dim, 172 | out_features=gating_qi_hidden_dim, 173 | ), 174 | torch.nn.SiLU(), 175 | torch.nn.Linear( 176 | in_features=gating_qi_hidden_dim, 177 | out_features=output_dim, 178 | ), 179 | ).apply(init_mlp_xavier_weights_zero_bias) 180 | if gating_qi_hidden_dim > 0 181 | else torch.nn.Sequential( 182 | torch.nn.Dropout(p=gating_qi_dropout_rate), 183 | torch.nn.Linear( 184 | in_features=input_dim, 185 | out_features=output_dim, 186 | ), 187 | ).apply(init_mlp_xavier_weights_zero_bias) 188 | ), 189 | gating_combination_type=gating_combination_type, 190 | gating_normalization_fn=lambda _: SoftmaxDropoutCombiner( 191 | dropout_rate=gating_softmax_dropout_rate, eps=1e-6 192 | ), 193 | eps=eps, 194 | gating_combine_item_sideinfo_into_ci=False, 195 | gating_use_custom_tau=gating_use_custom_tau, 196 | gating_tau_alpha=gating_tau_alpha, 197 | bf16_training=bf16_training, 198 | ) 199 | interaction_module_debug_str = ( 200 | f"MoL-{query_dot_product_groups}x{item_dot_product_groups}x{dot_product_dimension}" 201 | + f"-t{temperature}-d{gating_softmax_dropout_rate}" 202 | + f"{'-l2' if dot_product_l2_norm else ''}" 203 | + ( 204 | f"-q{query_hidden_dim}d{query_dropout_rate}geglu" 205 | if query_hidden_dim > 0 206 | else f"-qd{query_dropout_rate}" 207 | ) 208 | + ( 209 | "-i_id" 210 | if item_use_identity_fn 211 | else ( 212 | f"-{item_hidden_dim}d{item_dropout_rate}-geglu" 213 | if item_hidden_dim > 0 214 | else f"-id{item_dropout_rate}" 215 | ) 216 | ) 217 | + (f"-gq{gating_query_hidden_dim}" if gating_query_fn else "") 218 | + ( 219 | f"-gi{gating_item_hidden_dim}d{gating_item_dropout_rate}" 220 | if gating_item_fn 221 | else "" 222 | ) 223 | + f"-gqi{gating_qi_hidden_dim}d{gating_qi_dropout_rate}-x-{gating_combination_type}" 224 | ) 225 | if gating_use_custom_tau: 226 | interaction_module_debug_str += f"-tau{gating_tau_alpha}" 227 | return mol_module, interaction_module_debug_str 228 | 229 | 230 | @gin.configurable 231 | def get_similarity_function( 232 | module_type: str, 233 | query_embedding_dim: int, 234 | item_embedding_dim: int, 235 | bf16_training: bool = False, 236 | activation_checkpoint: bool = False, 237 | ) -> Tuple[torch.nn.Module, str]: 238 | if module_type == "DotProduct": 239 | interaction_module = DotProductSimilarity() 240 | interaction_module_debug_str = "DotProduct" 241 | elif module_type == "MoL": 242 | interaction_module, interaction_module_debug_str = ( 243 | create_mol_interaction_module( 244 | query_embedding_dim=query_embedding_dim, 245 | item_embedding_dim=item_embedding_dim, 246 | bf16_training=bf16_training, 247 | ) 248 | ) 249 | else: 250 | raise ValueError(f"Unknown interaction_module_type {module_type}") 251 | return interaction_module, interaction_module_debug_str 252 | -------------------------------------------------------------------------------- /generative_recommenders/modeling/sequential/sasrec.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyre-unsafe 16 | 17 | """ 18 | Implements SASRec (Self-Attentive Sequential Recommendation, https://arxiv.org/abs/1808.09781, ICDM'18). 19 | 20 | Compared with the original paper which used BCE loss, this implementation is modified so that 21 | we can utilize a Sampled Softmax loss proposed in Revisiting Neural Retrieval on Accelerators 22 | (https://arxiv.org/abs/2306.04039, KDD'23) and Turning Dross Into Gold Loss: is BERT4Rec really 23 | better than SASRec? (https://arxiv.org/abs/2309.07602, RecSys'23), where the authors showed 24 | sampled softmax loss to significantly improved SASRec model quality. 25 | """ 26 | 27 | from typing import Dict, Optional, Tuple 28 | 29 | import torch 30 | import torch.nn.functional as F 31 | 32 | from generative_recommenders.modeling.ndp_module import NDPModule 33 | from generative_recommenders.modeling.sequential.embedding_modules import ( 34 | EmbeddingModule, 35 | ) 36 | from generative_recommenders.modeling.sequential.input_features_preprocessors import ( 37 | InputFeaturesPreprocessorModule, 38 | ) 39 | from generative_recommenders.modeling.sequential.output_postprocessors import ( 40 | OutputPostprocessorModule, 41 | ) 42 | from generative_recommenders.modeling.sequential.utils import get_current_embeddings 43 | from generative_recommenders.modeling.similarity_module import ( 44 | GeneralizedInteractionModule, 45 | ) 46 | 47 | 48 | class StandardAttentionFF(torch.nn.Module): 49 | def __init__( 50 | self, 51 | embedding_dim: int, 52 | hidden_dim: int, 53 | activation_fn: str, 54 | dropout_rate: float, 55 | ) -> None: 56 | super().__init__() 57 | 58 | assert ( 59 | activation_fn == "relu" or activation_fn == "gelu" 60 | ), f"Invalid activation_fn {activation_fn}" 61 | 62 | self._conv1d = torch.nn.Sequential( 63 | torch.nn.Conv1d( 64 | in_channels=embedding_dim, 65 | out_channels=hidden_dim, 66 | kernel_size=1, 67 | ), 68 | torch.nn.GELU() if activation_fn == "gelu" else torch.nn.ReLU(), 69 | torch.nn.Dropout(p=dropout_rate), 70 | torch.nn.Conv1d( 71 | in_channels=hidden_dim, 72 | out_channels=embedding_dim, 73 | kernel_size=1, 74 | ), 75 | torch.nn.Dropout(p=dropout_rate), 76 | ) 77 | 78 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 79 | # Conv1D requires (B, D, N) 80 | return self._conv1d(inputs.transpose(-1, -2)).transpose(-1, -2) + inputs 81 | 82 | 83 | class SASRec(GeneralizedInteractionModule): 84 | """ 85 | Implements SASRec (Self-Attentive Sequential Recommendation, https://arxiv.org/abs/1808.09781, ICDM'18). 86 | 87 | Compared with the original paper which used BCE loss, this implementation is modified so that 88 | we can utilize a Sampled Softmax loss proposed in Revisiting Neural Retrieval on Accelerators 89 | (https://arxiv.org/abs/2306.04039, KDD'23) and Turning Dross Into Gold Loss: is BERT4Rec really 90 | better than SASRec? (https://arxiv.org/abs/2309.07602, RecSys'23), where the authors showed 91 | sampled softmax loss to significantly improved SASRec model quality. 92 | """ 93 | 94 | def __init__( 95 | self, 96 | max_sequence_len: int, 97 | max_output_len: int, 98 | embedding_dim: int, 99 | num_blocks: int, 100 | num_heads: int, 101 | ffn_hidden_dim: int, 102 | ffn_activation_fn: str, 103 | ffn_dropout_rate: float, 104 | embedding_module: EmbeddingModule, 105 | similarity_module: NDPModule, 106 | input_features_preproc_module: InputFeaturesPreprocessorModule, 107 | output_postproc_module: OutputPostprocessorModule, 108 | activation_checkpoint: bool = False, 109 | verbose: bool = False, 110 | ) -> None: 111 | super().__init__(ndp_module=similarity_module) 112 | 113 | self._embedding_module: EmbeddingModule = embedding_module 114 | self._embedding_dim: int = embedding_dim 115 | self._item_embedding_dim: int = embedding_module.item_embedding_dim 116 | self._max_sequence_length: int = max_sequence_len + max_output_len 117 | self._input_features_preproc: InputFeaturesPreprocessorModule = ( 118 | input_features_preproc_module 119 | ) 120 | self._output_postproc: OutputPostprocessorModule = output_postproc_module 121 | self._activation_checkpoint: bool = activation_checkpoint 122 | self._verbose: bool = verbose 123 | 124 | self.attention_layers = torch.nn.ModuleList() 125 | self.forward_layers = torch.nn.ModuleList() 126 | self._num_blocks: int = num_blocks 127 | self._num_heads: int = num_heads 128 | self._ffn_hidden_dim: int = ffn_hidden_dim 129 | self._ffn_activation_fn: str = ffn_activation_fn 130 | self._ffn_dropout_rate: float = ffn_dropout_rate 131 | 132 | for _ in range(num_blocks): 133 | self.attention_layers.append( 134 | torch.nn.MultiheadAttention( 135 | embed_dim=self._embedding_dim, 136 | num_heads=num_heads, 137 | dropout=ffn_dropout_rate, 138 | batch_first=True, 139 | ) 140 | ) 141 | self.forward_layers.append( 142 | StandardAttentionFF( 143 | embedding_dim=self._embedding_dim, 144 | hidden_dim=ffn_hidden_dim, 145 | activation_fn=ffn_activation_fn, 146 | dropout_rate=self._ffn_dropout_rate, 147 | ) 148 | ) 149 | 150 | self.register_buffer( 151 | "_attn_mask", 152 | torch.triu( 153 | torch.ones( 154 | (self._max_sequence_length, self._max_sequence_length), 155 | dtype=torch.bool, 156 | ), 157 | diagonal=1, 158 | ), 159 | ) 160 | self.reset_state() 161 | 162 | def reset_state(self) -> None: 163 | for name, params in self.named_parameters(): 164 | if ( 165 | "_input_features_preproc" in name 166 | or "_embedding_module" in name 167 | or "_output_postproc" in name 168 | ): 169 | if self._verbose: 170 | print(f"Skipping initialization for {name}") 171 | continue 172 | try: 173 | torch.nn.init.xavier_normal_(params.data) 174 | if self._verbose: 175 | print( 176 | f"Initialize {name} as xavier normal: {params.data.size()} params" 177 | ) 178 | except: 179 | if self._verbose: 180 | print(f"Failed to initialize {name}: {params.data.size()} params") 181 | 182 | def get_item_embeddings(self, item_ids: torch.Tensor) -> torch.Tensor: 183 | return self._embedding_module.get_item_embeddings(item_ids) 184 | 185 | def debug_str(self) -> str: 186 | return ( 187 | f"SASRec-d{self._item_embedding_dim}-b{self._num_blocks}-h{self._num_heads}" 188 | + "-" 189 | + self._input_features_preproc.debug_str() 190 | + "-" 191 | + self._output_postproc.debug_str() 192 | + f"-ffn{self._ffn_hidden_dim}-{self._ffn_activation_fn}-d{self._ffn_dropout_rate}" 193 | + f"{'-ac' if self._activation_checkpoint else ''}" 194 | ) 195 | 196 | def _run_one_layer( 197 | self, 198 | i: int, 199 | user_embeddings: torch.Tensor, 200 | valid_mask: torch.Tensor, 201 | ) -> torch.Tensor: 202 | Q = F.layer_norm( 203 | user_embeddings, 204 | normalized_shape=(self._embedding_dim,), 205 | eps=1e-8, 206 | ) 207 | mha_outputs, _ = self.attention_layers[i]( 208 | query=Q, 209 | key=user_embeddings, 210 | value=user_embeddings, 211 | attn_mask=self._attn_mask, 212 | ) 213 | user_embeddings = self.forward_layers[i]( 214 | F.layer_norm( 215 | Q + mha_outputs, 216 | normalized_shape=(self._embedding_dim,), 217 | eps=1e-8, 218 | ) 219 | ) 220 | user_embeddings *= valid_mask 221 | return user_embeddings 222 | 223 | def generate_user_embeddings( 224 | self, 225 | past_lengths: torch.Tensor, 226 | past_ids: torch.Tensor, 227 | past_embeddings: torch.Tensor, 228 | past_payloads: Dict[str, torch.Tensor], 229 | ) -> torch.Tensor: 230 | """ 231 | Args: 232 | past_ids: (B, N,) x int 233 | 234 | Returns: 235 | (B, N, D,) x float 236 | """ 237 | past_lengths, user_embeddings, valid_mask = self._input_features_preproc( 238 | past_lengths=past_lengths, 239 | past_ids=past_ids, 240 | past_embeddings=past_embeddings, 241 | past_payloads=past_payloads, 242 | ) 243 | 244 | for i in range(len(self.attention_layers)): 245 | if self._activation_checkpoint: 246 | user_embeddings = torch.utils.checkpoint.checkpoint( 247 | self._run_one_layer, 248 | i, 249 | user_embeddings, 250 | valid_mask, 251 | use_reentrant=False, 252 | ) 253 | else: 254 | user_embeddings = self._run_one_layer(i, user_embeddings, valid_mask) 255 | 256 | return self._output_postproc(user_embeddings) 257 | 258 | def forward( 259 | self, 260 | past_lengths: torch.Tensor, 261 | past_ids: torch.Tensor, 262 | past_embeddings: torch.Tensor, 263 | past_payloads: Dict[str, torch.Tensor], 264 | batch_id: Optional[int] = None, 265 | ) -> torch.Tensor: 266 | """ 267 | Args: 268 | past_ids: [B, N] x int64 where the latest engaged ids come first. In 269 | particular, [:, 0] should correspond to the last engaged values. 270 | past_ratings: [B, N] x int64. 271 | past_timestamps: [B, N] x int64. 272 | 273 | Returns: 274 | encoded_embeddings of [B, N, D]. 275 | """ 276 | encoded_embeddings = self.generate_user_embeddings( 277 | past_lengths, 278 | past_ids, 279 | past_embeddings, 280 | past_payloads, 281 | ) 282 | return encoded_embeddings 283 | 284 | def encode( 285 | self, 286 | past_lengths: torch.Tensor, 287 | past_ids: torch.Tensor, # [B, N] x int64 288 | past_embeddings: torch.Tensor, 289 | past_payloads: Dict[str, torch.Tensor], 290 | ) -> torch.Tensor: 291 | encoded_seq_embeddings = self.generate_user_embeddings( 292 | past_lengths, past_ids, past_embeddings, past_payloads 293 | ) # [B, N, D] 294 | return get_current_embeddings( 295 | lengths=past_lengths, encoded_embeddings=encoded_seq_embeddings 296 | ) 297 | 298 | def predict( 299 | self, 300 | past_ids: torch.Tensor, 301 | past_ratings: torch.Tensor, 302 | past_timestamps: torch.Tensor, 303 | next_timestamps: torch.Tensor, 304 | target_ids: torch.Tensor, 305 | batch_id: Optional[int] = None, 306 | ) -> torch.Tensor: 307 | return self.interaction( 308 | self.encode(past_ids, past_ratings, past_timestamps, next_timestamps), # pyre-ignore [6] 309 | target_ids, 310 | ) # [B, X] 311 | -------------------------------------------------------------------------------- /generative_recommenders/data/preprocessor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyre-unsafe 16 | 17 | import abc 18 | import logging 19 | import os 20 | import sys 21 | import tarfile 22 | from typing import Dict, Optional, Union 23 | 24 | from urllib.request import urlretrieve 25 | from zipfile import ZipFile 26 | 27 | import numpy as np 28 | 29 | import pandas as pd 30 | 31 | 32 | logging.basicConfig(stream=sys.stdout, level=logging.INFO) 33 | 34 | 35 | class DataProcessor: 36 | """ 37 | This preprocessor does not remap item_ids. This is intended so that we can easily join other 38 | side-information based on item_ids later. 39 | """ 40 | 41 | def __init__( 42 | self, 43 | prefix: str, 44 | expected_num_unique_items: Optional[int], 45 | expected_max_item_id: Optional[int], 46 | ) -> None: 47 | self._prefix: str = prefix 48 | self._expected_num_unique_items = expected_num_unique_items 49 | self._expected_max_item_id = expected_max_item_id 50 | 51 | @abc.abstractmethod 52 | def expected_num_unique_items(self) -> Optional[int]: 53 | return self._expected_num_unique_items 54 | 55 | @abc.abstractmethod 56 | def expected_max_item_id(self) -> Optional[int]: 57 | return self._expected_max_item_id 58 | 59 | @abc.abstractmethod 60 | def processed_item_csv(self) -> str: 61 | pass 62 | 63 | def output_format_csv(self) -> str: 64 | return f"tmp/{self._prefix}/sasrec_format.csv" 65 | 66 | def to_seq_data( 67 | self, 68 | ratings_data: pd.DataFrame, 69 | user_data: Optional[pd.DataFrame] = None, 70 | ) -> pd.DataFrame: 71 | if user_data is not None: 72 | ratings_data_transformed = ratings_data.join( 73 | user_data.set_index("user_id"), on="user_id" 74 | ) 75 | else: 76 | ratings_data_transformed = ratings_data 77 | ratings_data_transformed.item_ids = ratings_data_transformed.item_ids.apply( 78 | lambda x: ",".join([str(v) for v in x]) 79 | ) 80 | ratings_data_transformed.ratings = ratings_data_transformed.ratings.apply( 81 | lambda x: ",".join([str(v) for v in x]) 82 | ) 83 | ratings_data_transformed.timestamps = ratings_data_transformed.timestamps.apply( 84 | lambda x: ",".join([str(v) for v in x]) 85 | ) 86 | ratings_data_transformed.rename( 87 | columns={ 88 | "item_ids": "sequence_item_ids", 89 | "ratings": "sequence_ratings", 90 | "timestamps": "sequence_timestamps", 91 | }, 92 | inplace=True, 93 | ) 94 | return ratings_data_transformed 95 | 96 | def file_exists(self, name: str) -> bool: 97 | return os.path.isfile("%s/%s" % (os.getcwd(), name)) 98 | 99 | 100 | class MovielensSyntheticDataProcessor(DataProcessor): 101 | def __init__( 102 | self, 103 | prefix: str, 104 | expected_num_unique_items: Optional[int] = None, 105 | expected_max_item_id: Optional[int] = None, 106 | ) -> None: 107 | super().__init__(prefix, expected_num_unique_items, expected_max_item_id) 108 | 109 | def preprocess_rating(self) -> None: 110 | return 111 | 112 | 113 | class MovielensDataProcessor(DataProcessor): 114 | def __init__( 115 | self, 116 | download_path: str, 117 | saved_name: str, 118 | prefix: str, 119 | convert_timestamp: bool, 120 | expected_num_unique_items: Optional[int] = None, 121 | expected_max_item_id: Optional[int] = None, 122 | ) -> None: 123 | super().__init__(prefix, expected_num_unique_items, expected_max_item_id) 124 | self._download_path = download_path 125 | self._saved_name = saved_name 126 | self._convert_timestamp: bool = convert_timestamp 127 | 128 | def download(self) -> None: 129 | if not self.file_exists(self._saved_name): 130 | urlretrieve(self._download_path, self._saved_name) 131 | if self._saved_name[-4:] == ".zip": 132 | ZipFile(self._saved_name, "r").extractall(path="tmp/") 133 | else: 134 | with tarfile.open(self._saved_name, "r:*") as tar_ref: 135 | tar_ref.extractall("tmp/") 136 | 137 | def processed_item_csv(self) -> str: 138 | return f"tmp/processed/{self._prefix}/movies.csv" 139 | 140 | def sasrec_format_csv_by_user_train(self) -> str: 141 | return f"tmp/{self._prefix}/sasrec_format_by_user_train.csv" 142 | 143 | def sasrec_format_csv_by_user_test(self) -> str: 144 | return f"tmp/{self._prefix}/sasrec_format_by_user_test.csv" 145 | 146 | def preprocess_rating(self) -> int: 147 | self.download() 148 | 149 | if self._prefix == "ml-1m": 150 | users = pd.read_csv( 151 | f"tmp/{self._prefix}/users.dat", 152 | sep="::", 153 | names=["user_id", "sex", "age_group", "occupation", "zip_code"], 154 | ) 155 | ratings = pd.read_csv( 156 | f"tmp/{self._prefix}/ratings.dat", 157 | sep="::", 158 | names=["user_id", "movie_id", "rating", "unix_timestamp"], 159 | ) 160 | movies = pd.read_csv( 161 | f"tmp/{self._prefix}/movies.dat", 162 | sep="::", 163 | names=["movie_id", "title", "genres"], 164 | encoding="iso-8859-1", 165 | ) 166 | elif self._prefix == "ml-20m": 167 | # ml-20m 168 | # ml-20m doesn't have user data. 169 | users = None 170 | # ratings: userId,movieId,rating,timestamp 171 | ratings = pd.read_csv( 172 | f"tmp/{self._prefix}/ratings.csv", 173 | sep=",", 174 | ) 175 | ratings.rename( 176 | columns={ 177 | "userId": "user_id", 178 | "movieId": "movie_id", 179 | "timestamp": "unix_timestamp", 180 | }, 181 | inplace=True, 182 | ) 183 | # movieId,title,genres 184 | # 1,Toy Story (1995),Adventure|Animation|Children|Comedy|Fantasy 185 | # 2,Jumanji (1995),Adventure|Children|Fantasy 186 | movies = pd.read_csv( 187 | f"tmp/{self._prefix}/movies.csv", 188 | sep=",", 189 | encoding="iso-8859-1", 190 | ) 191 | movies.rename(columns={"movieId": "movie_id"}, inplace=True) 192 | else: 193 | assert self._prefix == "ml-20mx16x32" 194 | # ml-1b 195 | user_ids = [] 196 | movie_ids = [] 197 | for i in range(16): 198 | train_file = f"tmp/{self._prefix}/trainx16x32_{i}.npz" 199 | with np.load(train_file) as data: 200 | user_ids.extend([x[0] for x in data["arr_0"]]) 201 | movie_ids.extend([x[1] for x in data["arr_0"]]) 202 | ratings = pd.DataFrame( 203 | data={ 204 | "user_id": user_ids, 205 | "movie_id": movie_ids, 206 | "rating": user_ids, # placeholder 207 | "unix_timestamp": movie_ids, # placeholder 208 | } 209 | ) 210 | users = None 211 | movies = None 212 | 213 | if movies is not None: 214 | # ML-1M and ML-20M only 215 | movies["year"] = movies["title"].apply(lambda x: x[-5:-1]) 216 | movies["cleaned_title"] = movies["title"].apply(lambda x: x[:-7]) 217 | # movies.year = pd.Categorical(movies.year) 218 | # movies["year"] = movies.year.cat.codes 219 | 220 | if users is not None: 221 | ## Users (ml-1m only) 222 | users.sex = pd.Categorical(users.sex) 223 | users["sex"] = users.sex.cat.codes 224 | 225 | users.age_group = pd.Categorical(users.age_group) 226 | users["age_group"] = users.age_group.cat.codes 227 | 228 | users.occupation = pd.Categorical(users.occupation) 229 | users["occupation"] = users.occupation.cat.codes 230 | 231 | users.zip_code = pd.Categorical(users.zip_code) 232 | users["zip_code"] = users.zip_code.cat.codes 233 | 234 | # Normalize movie ids to speed up training 235 | print( 236 | f"{self._prefix} #item before normalize: {len(set(ratings['movie_id'].values))}" 237 | ) 238 | print( 239 | f"{self._prefix} max item id before normalize: {max(set(ratings['movie_id'].values))}" 240 | ) 241 | # print(f"ratings.movie_id.cat.categories={ratings.movie_id.cat.categories}; {type(ratings.movie_id.cat.categories)}") 242 | # print(f"ratings.movie_id.cat.codes={ratings.movie_id.cat.codes}; {type(ratings.movie_id.cat.codes)}") 243 | # print(movie_id_to_cat) 244 | # ratings["movie_id"] = ratings.movie_id.cat.codes 245 | # print(f"{self._prefix} #item after normalize: {len(set(ratings['movie_id'].values))}") 246 | # print(f"{self._prefix} max item id after normalize: {max(set(ratings['movie_id'].values))}") 247 | # movies["remapped_id"] = movies["movie_id"].apply(lambda x: movie_id_to_cat[x]) 248 | 249 | if self._convert_timestamp: 250 | ratings["unix_timestamp"] = pd.to_datetime( 251 | ratings["unix_timestamp"], unit="s" 252 | ) 253 | 254 | # Save primary csv's 255 | if not os.path.exists(f"tmp/processed/{self._prefix}"): 256 | os.makedirs(f"tmp/processed/{self._prefix}") 257 | if users is not None: 258 | users.to_csv(f"tmp/processed/{self._prefix}/users.csv", index=False) 259 | if movies is not None: 260 | movies.to_csv(f"tmp/processed/{self._prefix}/movies.csv", index=False) 261 | ratings.to_csv(f"tmp/processed/{self._prefix}/ratings.csv", index=False) 262 | 263 | num_unique_users = len(set(ratings["user_id"].values)) 264 | num_unique_items = len(set(ratings["movie_id"].values)) 265 | 266 | # SASRec version 267 | ratings_group = ratings.sort_values(by=["unix_timestamp"]).groupby("user_id") 268 | seq_ratings_data = pd.DataFrame( 269 | data={ 270 | "user_id": list(ratings_group.groups.keys()), 271 | "item_ids": list(ratings_group.movie_id.apply(list)), 272 | "ratings": list(ratings_group.rating.apply(list)), 273 | "timestamps": list(ratings_group.unix_timestamp.apply(list)), 274 | } 275 | ) 276 | 277 | result = pd.DataFrame([[]]) 278 | for col in ["item_ids"]: 279 | result[col + "_mean"] = seq_ratings_data[col].apply(len).mean() 280 | result[col + "_min"] = seq_ratings_data[col].apply(len).min() 281 | result[col + "_max"] = seq_ratings_data[col].apply(len).max() 282 | print(self._prefix) 283 | print(result) 284 | 285 | seq_ratings_data = self.to_seq_data(seq_ratings_data, users) 286 | seq_ratings_data.sample(frac=1).reset_index().to_csv( 287 | self.output_format_csv(), index=False, sep="," 288 | ) 289 | 290 | # Split by user ids (not tested yet) 291 | user_id_split = int(num_unique_users * 0.9) 292 | seq_ratings_data_train = seq_ratings_data[ 293 | seq_ratings_data["user_id"] <= user_id_split 294 | ] 295 | seq_ratings_data_train.sample(frac=1).reset_index().to_csv( 296 | self.sasrec_format_csv_by_user_train(), 297 | index=False, 298 | sep=",", 299 | ) 300 | seq_ratings_data_test = seq_ratings_data[ 301 | seq_ratings_data["user_id"] > user_id_split 302 | ] 303 | seq_ratings_data_test.sample(frac=1).reset_index().to_csv( 304 | self.sasrec_format_csv_by_user_test(), index=False, sep="," 305 | ) 306 | print( 307 | f"{self._prefix}: train num user: {len(set(seq_ratings_data_train['user_id'].values))}" 308 | ) 309 | print( 310 | f"{self._prefix}: test num user: {len(set(seq_ratings_data_test['user_id'].values))}" 311 | ) 312 | 313 | # print(seq_ratings_data) 314 | if self.expected_num_unique_items() is not None: 315 | assert ( 316 | self.expected_num_unique_items() == num_unique_items 317 | ), f"Expected items: {self.expected_num_unique_items()}, got: {num_unique_items}" 318 | 319 | return num_unique_items 320 | 321 | 322 | class AmazonDataProcessor(DataProcessor): 323 | def __init__( 324 | self, 325 | download_path: str, 326 | saved_name: str, 327 | prefix: str, 328 | expected_num_unique_items: Optional[int], 329 | ) -> None: 330 | super().__init__( 331 | prefix, 332 | expected_num_unique_items=expected_num_unique_items, 333 | expected_max_item_id=None, 334 | ) 335 | self._download_path = download_path 336 | self._saved_name = saved_name 337 | self._prefix = prefix 338 | 339 | def download(self) -> None: 340 | if not self.file_exists(self._saved_name): 341 | urlretrieve(self._download_path, self._saved_name) 342 | 343 | def preprocess_rating(self) -> int: 344 | self.download() 345 | 346 | ratings = pd.read_csv( 347 | self._saved_name, 348 | sep=",", 349 | names=["user_id", "item_id", "rating", "timestamp"], 350 | ) 351 | print(f"{self._prefix} #data points before filter: {ratings.shape[0]}") 352 | print( 353 | f"{self._prefix} #user before filter: {len(set(ratings['user_id'].values))}" 354 | ) 355 | print( 356 | f"{self._prefix} #item before filter: {len(set(ratings['item_id'].values))}" 357 | ) 358 | 359 | # filter users and items with presence < 5 360 | item_id_count = ( 361 | ratings["item_id"] 362 | .value_counts() 363 | .rename_axis("unique_values") 364 | .reset_index(name="item_count") 365 | ) 366 | user_id_count = ( 367 | ratings["user_id"] 368 | .value_counts() 369 | .rename_axis("unique_values") 370 | .reset_index(name="user_count") 371 | ) 372 | ratings = ratings.join(item_id_count.set_index("unique_values"), on="item_id") 373 | ratings = ratings.join(user_id_count.set_index("unique_values"), on="user_id") 374 | ratings = ratings[ratings["item_count"] >= 5] 375 | ratings = ratings[ratings["user_count"] >= 5] 376 | print(f"{self._prefix} #data points after filter: {ratings.shape[0]}") 377 | 378 | # categorize user id and item id 379 | ratings["item_id"] = pd.Categorical(ratings["item_id"]) 380 | ratings["item_id"] = ratings["item_id"].cat.codes 381 | ratings["user_id"] = pd.Categorical(ratings["user_id"]) 382 | ratings["user_id"] = ratings["user_id"].cat.codes 383 | print( 384 | f"{self._prefix} #user after filter: {len(set(ratings['user_id'].values))}" 385 | ) 386 | print( 387 | f"{self._prefix} #item ater filter: {len(set(ratings['item_id'].values))}" 388 | ) 389 | 390 | num_unique_items = len(set(ratings["item_id"].values)) 391 | 392 | # SASRec version 393 | ratings_group = ratings.sort_values(by=["timestamp"]).groupby("user_id") 394 | 395 | seq_ratings_data = pd.DataFrame( 396 | data={ 397 | "user_id": list(ratings_group.groups.keys()), 398 | "item_ids": list(ratings_group.item_id.apply(list)), 399 | "ratings": list(ratings_group.rating.apply(list)), 400 | "timestamps": list(ratings_group.timestamp.apply(list)), 401 | } 402 | ) 403 | 404 | seq_ratings_data = seq_ratings_data[ 405 | seq_ratings_data["item_ids"].apply(len) >= 5 406 | ] 407 | 408 | result = pd.DataFrame([[]]) 409 | for col in ["item_ids"]: 410 | result[col + "_mean"] = seq_ratings_data[col].apply(len).mean() 411 | result[col + "_min"] = seq_ratings_data[col].apply(len).min() 412 | result[col + "_max"] = seq_ratings_data[col].apply(len).max() 413 | print(self._prefix) 414 | print(result) 415 | 416 | if not os.path.exists(f"tmp/{self._prefix}"): 417 | os.makedirs(f"tmp/{self._prefix}") 418 | 419 | seq_ratings_data = self.to_seq_data(seq_ratings_data) 420 | seq_ratings_data.sample(frac=1).reset_index().to_csv( 421 | self.output_format_csv(), index=False, sep="," 422 | ) 423 | 424 | if self.expected_num_unique_items() is not None: 425 | assert ( 426 | self.expected_num_unique_items() == num_unique_items 427 | ), f"expected: {self.expected_num_unique_items()}, actual: {num_unique_items}" 428 | logging.info(f"{self.expected_num_unique_items()} unique items.") 429 | 430 | return num_unique_items 431 | 432 | 433 | def get_common_preprocessors() -> Dict[ 434 | str, 435 | Union[AmazonDataProcessor, MovielensDataProcessor, MovielensSyntheticDataProcessor], 436 | ]: 437 | ml_1m_dp = MovielensDataProcessor( # pyre-ignore [45] 438 | "http://files.grouplens.org/datasets/movielens/ml-1m.zip", 439 | "tmp/movielens1m.zip", 440 | prefix="ml-1m", 441 | convert_timestamp=False, 442 | expected_num_unique_items=3706, 443 | expected_max_item_id=3952, 444 | ) 445 | ml_20m_dp = MovielensDataProcessor( # pyre-ignore [45] 446 | "http://files.grouplens.org/datasets/movielens/ml-20m.zip", 447 | "tmp/movielens20m.zip", 448 | prefix="ml-20m", 449 | convert_timestamp=False, 450 | expected_num_unique_items=26744, 451 | expected_max_item_id=131262, 452 | ) 453 | ml_1b_dp = MovielensDataProcessor( # pyre-ignore [45] 454 | "https://files.grouplens.org/datasets/movielens/ml-20mx16x32.tar", 455 | "tmp/movielens1b.tar", 456 | prefix="ml-20mx16x32", 457 | convert_timestamp=False, 458 | ) 459 | ml_3b_dp = MovielensSyntheticDataProcessor( # pyre-ignore [45] 460 | prefix="ml-3b", 461 | expected_num_unique_items=26743 * 32, 462 | expected_max_item_id=26743 * 32, 463 | ) 464 | amzn_books_dp = AmazonDataProcessor( # pyre-ignore [45] 465 | "http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/ratings_Books.csv", 466 | "tmp/ratings_Books.csv", 467 | prefix="amzn_books", 468 | expected_num_unique_items=695762, 469 | ) 470 | return { 471 | "ml-1m": ml_1m_dp, 472 | "ml-20m": ml_20m_dp, 473 | "ml-1b": ml_1b_dp, 474 | "ml-3b": ml_3b_dp, 475 | "amzn-books": amzn_books_dp, 476 | } 477 | -------------------------------------------------------------------------------- /generative_recommenders/trainer/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyre-unsafe 16 | 17 | import logging 18 | import os 19 | import random 20 | 21 | import time 22 | 23 | from datetime import date 24 | from typing import Optional 25 | 26 | import gin 27 | 28 | import torch 29 | import torch.distributed as dist 30 | 31 | from tqdm import tqdm 32 | 33 | from generative_recommenders.data.eval import ( 34 | _avg, 35 | add_to_summary_writer, 36 | eval_metrics_v2_from_tensors, 37 | get_eval_state, 38 | ) 39 | 40 | from generative_recommenders.data.reco_dataset import get_reco_dataset 41 | from generative_recommenders.indexing.utils import get_top_k_module 42 | from generative_recommenders.modeling.sequential.autoregressive_losses import ( 43 | BCELoss, 44 | InBatchNegativesSampler, 45 | LocalNegativesSampler, 46 | SampledSoftmaxLoss, 47 | ) 48 | from generative_recommenders.modeling.sequential.embedding_modules import ( 49 | EmbeddingModule, 50 | LocalEmbeddingModule, 51 | ) 52 | from generative_recommenders.modeling.sequential.encoder_utils import ( 53 | get_sequential_encoder, 54 | ) 55 | from generative_recommenders.modeling.sequential.features import ( 56 | movielens_seq_features_from_row, 57 | ) 58 | from generative_recommenders.modeling.sequential.input_features_preprocessors import ( 59 | LearnablePositionalEmbeddingInputFeaturesPreprocessor, 60 | ) 61 | from generative_recommenders.modeling.sequential.output_postprocessors import ( 62 | L2NormEmbeddingPostprocessor, 63 | LayerNormEmbeddingPostprocessor, 64 | ) 65 | from generative_recommenders.modeling.similarity_utils import get_similarity_function 66 | from generative_recommenders.trainer.data_loader import create_data_loader 67 | from torch.nn.parallel import DistributedDataParallel as DDP 68 | from torch.utils.tensorboard import SummaryWriter 69 | 70 | 71 | def setup(rank: int, world_size: int, master_port: int) -> None: 72 | os.environ["MASTER_ADDR"] = "localhost" 73 | os.environ["MASTER_PORT"] = str(master_port) 74 | 75 | # initialize the process group 76 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 77 | 78 | 79 | def cleanup() -> None: 80 | dist.destroy_process_group() 81 | 82 | 83 | @gin.configurable 84 | def train_fn( 85 | rank: int, 86 | world_size: int, 87 | master_port: int, 88 | dataset_name: str = "ml-20m", 89 | max_sequence_length: int = 200, 90 | positional_sampling_ratio: float = 1.0, 91 | local_batch_size: int = 128, 92 | eval_batch_size: int = 128, 93 | eval_user_max_batch_size: Optional[int] = None, 94 | main_module: str = "SASRec", 95 | main_module_bf16: bool = False, 96 | dropout_rate: float = 0.2, 97 | user_embedding_norm: str = "l2_norm", 98 | sampling_strategy: str = "in-batch", 99 | loss_module: str = "SampledSoftmaxLoss", 100 | num_negatives: int = 1, 101 | loss_activation_checkpoint: bool = False, 102 | item_l2_norm: bool = False, 103 | temperature: float = 0.05, 104 | num_epochs: int = 101, 105 | learning_rate: float = 1e-3, 106 | num_warmup_steps: int = 0, 107 | weight_decay: float = 1e-3, 108 | top_k_method: str = "MIPSBruteForceTopK", 109 | eval_interval: int = 100, 110 | full_eval_every_n: int = 1, 111 | save_ckpt_every_n: int = 1000, 112 | partial_eval_num_iters: int = 32, 113 | embedding_module_type: str = "local", 114 | item_embedding_dim: int = 240, 115 | interaction_module_type: str = "", 116 | gr_output_length: int = 10, 117 | l2_norm_eps: float = 1e-6, 118 | enable_tf32: bool = False, 119 | random_seed: int = 42, 120 | ) -> None: 121 | # to enable more deterministic results. 122 | random.seed(random_seed) 123 | torch.backends.cuda.matmul.allow_tf32 = enable_tf32 124 | torch.backends.cudnn.allow_tf32 = enable_tf32 125 | logging.info(f"cuda.matmul.allow_tf32: {enable_tf32}") 126 | logging.info(f"cudnn.allow_tf32: {enable_tf32}") 127 | logging.info(f"Training model on rank {rank}.") 128 | setup(rank, world_size, master_port) 129 | 130 | dataset = get_reco_dataset( 131 | dataset_name=dataset_name, 132 | max_sequence_length=max_sequence_length, 133 | chronological=True, 134 | positional_sampling_ratio=positional_sampling_ratio, 135 | ) 136 | 137 | train_data_sampler, train_data_loader = create_data_loader( 138 | dataset.train_dataset, 139 | batch_size=local_batch_size, 140 | world_size=world_size, 141 | rank=rank, 142 | shuffle=True, 143 | drop_last=world_size > 1, 144 | ) 145 | eval_data_sampler, eval_data_loader = create_data_loader( 146 | dataset.eval_dataset, 147 | batch_size=eval_batch_size, 148 | world_size=world_size, 149 | rank=rank, 150 | shuffle=True, # needed for partial eval 151 | drop_last=world_size > 1, 152 | ) 153 | 154 | model_debug_str = main_module 155 | if embedding_module_type == "local": 156 | embedding_module: EmbeddingModule = LocalEmbeddingModule( 157 | num_items=dataset.max_item_id, 158 | item_embedding_dim=item_embedding_dim, 159 | ) 160 | else: 161 | raise ValueError(f"Unknown embedding_module_type {embedding_module_type}") 162 | model_debug_str += f"-{embedding_module.debug_str()}" 163 | 164 | interaction_module, interaction_module_debug_str = get_similarity_function( 165 | module_type=interaction_module_type, 166 | query_embedding_dim=item_embedding_dim, 167 | item_embedding_dim=item_embedding_dim, 168 | ) 169 | 170 | assert ( 171 | user_embedding_norm == "l2_norm" or user_embedding_norm == "layer_norm" 172 | ), f"Not implemented for {user_embedding_norm}" 173 | output_postproc_module = ( 174 | L2NormEmbeddingPostprocessor( 175 | embedding_dim=item_embedding_dim, 176 | eps=1e-6, 177 | ) 178 | if user_embedding_norm == "l2_norm" 179 | else LayerNormEmbeddingPostprocessor( 180 | embedding_dim=item_embedding_dim, 181 | eps=1e-6, 182 | ) 183 | ) 184 | input_preproc_module = LearnablePositionalEmbeddingInputFeaturesPreprocessor( 185 | max_sequence_len=dataset.max_sequence_length + gr_output_length + 1, 186 | embedding_dim=item_embedding_dim, 187 | dropout_rate=dropout_rate, 188 | ) 189 | 190 | model = get_sequential_encoder( 191 | module_type=main_module, 192 | max_sequence_length=dataset.max_sequence_length, 193 | max_output_length=gr_output_length + 1, 194 | embedding_module=embedding_module, 195 | interaction_module=interaction_module, 196 | input_preproc_module=input_preproc_module, 197 | output_postproc_module=output_postproc_module, 198 | verbose=True, 199 | ) 200 | model_debug_str = model.debug_str() 201 | 202 | # loss 203 | loss_debug_str = loss_module 204 | if loss_module == "BCELoss": 205 | loss_debug_str = loss_debug_str[:-4] 206 | assert temperature == 1.0 207 | ar_loss = BCELoss(temperature=temperature, model=model) 208 | elif loss_module == "SampledSoftmaxLoss": 209 | loss_debug_str = "ssl" 210 | if temperature != 1.0: 211 | loss_debug_str += f"-t{temperature}" 212 | ar_loss = SampledSoftmaxLoss( 213 | num_to_sample=num_negatives, 214 | softmax_temperature=temperature, 215 | model=model, 216 | activation_checkpoint=loss_activation_checkpoint, 217 | ) 218 | loss_debug_str += ( 219 | f"-n{num_negatives}{'-ac' if loss_activation_checkpoint else ''}" 220 | ) 221 | else: 222 | raise ValueError(f"Unrecognized loss module {loss_module}.") 223 | 224 | # sampling 225 | if sampling_strategy == "in-batch": 226 | negatives_sampler = InBatchNegativesSampler( 227 | l2_norm=item_l2_norm, 228 | l2_norm_eps=l2_norm_eps, 229 | dedup_embeddings=True, 230 | ) 231 | sampling_debug_str = ( 232 | f"in-batch{f'-l2-eps{l2_norm_eps}' if item_l2_norm else ''}-dedup" 233 | ) 234 | elif sampling_strategy == "local": 235 | negatives_sampler = LocalNegativesSampler( 236 | num_items=dataset.max_item_id, 237 | item_emb=model._embedding_module._item_emb, 238 | all_item_ids=dataset.all_item_ids, 239 | l2_norm=item_l2_norm, 240 | l2_norm_eps=l2_norm_eps, 241 | ) 242 | else: 243 | raise ValueError(f"Unrecognized sampling strategy {sampling_strategy}.") 244 | sampling_debug_str = negatives_sampler.debug_str() 245 | 246 | # Creates model and moves it to GPU with id rank 247 | device = rank 248 | if main_module_bf16: 249 | model = model.to(torch.bfloat16) 250 | model = model.to(device) 251 | ar_loss = ar_loss.to(device) 252 | negatives_sampler = negatives_sampler.to(device) 253 | model = DDP(model, device_ids=[rank], broadcast_buffers=False) 254 | 255 | # TODO: wrap in create_optimizer. 256 | opt = torch.optim.AdamW( 257 | model.parameters(), 258 | lr=learning_rate, 259 | betas=(0.9, 0.98), 260 | weight_decay=weight_decay, 261 | ) 262 | 263 | date_str = date.today().strftime("%Y-%m-%d") 264 | model_subfolder = f"{dataset_name}-l{max_sequence_length}" 265 | model_desc = ( 266 | f"{model_subfolder}" 267 | + f"/{model_debug_str}_{interaction_module_debug_str}_{sampling_debug_str}_{loss_debug_str}" 268 | + f"{f'-ddp{world_size}' if world_size > 1 else ''}-b{local_batch_size}-lr{learning_rate}-wu{num_warmup_steps}-wd{weight_decay}{'' if enable_tf32 else '-notf32'}-{date_str}" 269 | ) 270 | if full_eval_every_n > 1: 271 | model_desc += f"-fe{full_eval_every_n}" 272 | if positional_sampling_ratio is not None and positional_sampling_ratio < 1: 273 | model_desc += f"-d{positional_sampling_ratio}" 274 | # creates subfolders. 275 | os.makedirs(f"./exps/{model_subfolder}", exist_ok=True) 276 | os.makedirs(f"./ckpts/{model_subfolder}", exist_ok=True) 277 | log_dir = f"./exps/{model_desc}" 278 | if rank == 0: 279 | writer = SummaryWriter(log_dir=log_dir) 280 | logging.info(f"Rank {rank}: writing logs to {log_dir}") 281 | else: 282 | writer = None 283 | logging.info(f"Rank {rank}: disabling summary writer") 284 | 285 | last_training_time = time.time() 286 | torch.autograd.set_detect_anomaly(True) 287 | 288 | batch_id = 0 289 | epoch = 0 290 | 291 | eval_elapse = 0 292 | train_elapse = 0 293 | 294 | for epoch in tqdm(range(num_epochs)): 295 | if train_data_sampler is not None: 296 | train_data_sampler.set_epoch(epoch) 297 | if eval_data_sampler is not None: 298 | eval_data_sampler.set_epoch(epoch) 299 | model.train() 300 | train_elapse -= time.time() 301 | for row in iter(train_data_loader): 302 | seq_features, target_ids, target_ratings = movielens_seq_features_from_row( 303 | row, 304 | device=device, 305 | max_output_length=gr_output_length + 1, 306 | ) 307 | 308 | if (batch_id % eval_interval) == 0: 309 | model.eval() 310 | 311 | seq_features, target_ids, target_ratings = ( 312 | movielens_seq_features_from_row( 313 | row, 314 | device=device, 315 | max_output_length=gr_output_length + 1, 316 | ) 317 | ) 318 | eval_state = get_eval_state( 319 | model=model.module, 320 | all_item_ids=dataset.all_item_ids, 321 | negatives_sampler=negatives_sampler, 322 | top_k_module_fn=lambda item_embeddings, item_ids: get_top_k_module( 323 | top_k_method=top_k_method, 324 | model=model.module, 325 | item_embeddings=item_embeddings, 326 | item_ids=item_ids, 327 | ), 328 | device=device, 329 | float_dtype=torch.bfloat16 if main_module_bf16 else None, 330 | ) 331 | eval_dict = eval_metrics_v2_from_tensors( 332 | eval_state, 333 | model.module, 334 | seq_features, 335 | target_ids=target_ids, 336 | target_ratings=target_ratings, 337 | user_max_batch_size=eval_user_max_batch_size, 338 | dtype=torch.bfloat16 if main_module_bf16 else None, 339 | ) 340 | add_to_summary_writer( 341 | writer, batch_id, eval_dict, prefix="eval", world_size=world_size 342 | ) 343 | logging.info( 344 | f"rank {rank}: batch-stat (eval): iter {batch_id} (epoch {epoch}): " 345 | + f"NDCG@10 {_avg(eval_dict['ndcg@10'], world_size):.4f}, " 346 | f"HR@10 {_avg(eval_dict['hr@10'], world_size):.4f}, " 347 | f"HR@50 {_avg(eval_dict['hr@50'], world_size):.4f}, " 348 | + f"MRR {_avg(eval_dict['mrr'], world_size):.4f} " 349 | ) 350 | model.train() 351 | 352 | # TODO: consider separating this out? 353 | B, N = seq_features.past_ids.shape 354 | seq_features.past_ids.scatter_( 355 | dim=1, 356 | index=seq_features.past_lengths.view(-1, 1), 357 | src=target_ids.view(-1, 1), 358 | ) 359 | 360 | opt.zero_grad() 361 | input_embeddings = model.module.get_item_embeddings(seq_features.past_ids) 362 | seq_embeddings = model( 363 | past_lengths=seq_features.past_lengths, 364 | past_ids=seq_features.past_ids, 365 | past_embeddings=input_embeddings, 366 | past_payloads=seq_features.past_payloads, 367 | ) # [B, X] 368 | 369 | supervision_ids = seq_features.past_ids 370 | 371 | if sampling_strategy == "in-batch": 372 | # get_item_embeddings currently assume 1-d tensor. 373 | in_batch_ids = supervision_ids.view(-1) 374 | negatives_sampler.process_batch( 375 | ids=in_batch_ids, 376 | presences=(in_batch_ids != 0), 377 | embeddings=model.module.get_item_embeddings(in_batch_ids), 378 | ) 379 | else: 380 | negatives_sampler._item_emb = model.module._embedding_module._item_emb 381 | 382 | ar_mask = supervision_ids[:, 1:] != 0 383 | loss = ar_loss( 384 | lengths=seq_features.past_lengths, # [B], 385 | output_embeddings=seq_embeddings[:, :-1, :], # [B, N-1, D] 386 | supervision_ids=supervision_ids[:, 1:], # [B, N-1] 387 | supervision_embeddings=input_embeddings[:, 1:, :], # [B, N - 1, D] 388 | supervision_weights=ar_mask.float(), 389 | negatives_sampler=negatives_sampler, 390 | ) # [B, N] 391 | if rank == 0: 392 | assert writer is not None 393 | writer.add_scalar("losses/ar_loss", loss, batch_id) 394 | 395 | loss.backward() 396 | 397 | # Optional linear warmup. 398 | if batch_id < num_warmup_steps: 399 | lr_scalar = min(1.0, float(batch_id + 1) / num_warmup_steps) 400 | for pg in opt.param_groups: 401 | pg["lr"] = lr_scalar * learning_rate 402 | lr = lr_scalar * learning_rate 403 | else: 404 | lr = learning_rate 405 | 406 | if (batch_id % eval_interval) == 0: 407 | logging.info( 408 | f" rank: {rank}, batch-stat (train): step {batch_id} " 409 | f"(epoch {epoch} in {time.time() - last_training_time:.2f}s): {loss:.6f}" 410 | ) 411 | last_training_time = time.time() 412 | if rank == 0: 413 | assert writer is not None 414 | writer.add_scalar("loss/train", loss, batch_id) 415 | writer.add_scalar("lr", lr, batch_id) 416 | 417 | opt.step() 418 | 419 | batch_id += 1 420 | train_elapse += time.time() 421 | 422 | def is_full_eval(epoch: int) -> bool: 423 | return (epoch % full_eval_every_n) == 0 424 | 425 | # eval per epoch 426 | eval_dict_all = None 427 | eval_start_time = time.time() 428 | model.eval() 429 | eval_state = get_eval_state( 430 | model=model.module, 431 | all_item_ids=dataset.all_item_ids, 432 | negatives_sampler=negatives_sampler, 433 | top_k_module_fn=lambda item_embeddings, item_ids: get_top_k_module( 434 | top_k_method=top_k_method, 435 | model=model.module, 436 | item_embeddings=item_embeddings, 437 | item_ids=item_ids, 438 | ), 439 | device=device, 440 | float_dtype=torch.bfloat16 if main_module_bf16 else None, 441 | ) 442 | eval_elapse -= time.time() 443 | for eval_iter, row in enumerate(iter(eval_data_loader)): 444 | seq_features, target_ids, target_ratings = movielens_seq_features_from_row( 445 | row, device=device, max_output_length=gr_output_length + 1 446 | ) 447 | eval_dict = eval_metrics_v2_from_tensors( 448 | eval_state, 449 | model.module, 450 | seq_features, 451 | target_ids=target_ids, 452 | target_ratings=target_ratings, 453 | user_max_batch_size=eval_user_max_batch_size, 454 | dtype=torch.bfloat16 if main_module_bf16 else None, 455 | ) 456 | 457 | if eval_dict_all is None: 458 | eval_dict_all = {} 459 | for k, v in eval_dict.items(): 460 | eval_dict_all[k] = [] 461 | 462 | for k, v in eval_dict.items(): 463 | eval_dict_all[k] = eval_dict_all[k] + [v] 464 | del eval_dict 465 | 466 | if (eval_iter + 1 >= partial_eval_num_iters) and (not is_full_eval(epoch)): 467 | logging.info( 468 | f"Truncating epoch {epoch} eval to {eval_iter + 1} iters to save cost.." 469 | ) 470 | break 471 | eval_elapse += time.time() 472 | 473 | assert eval_dict_all is not None 474 | for k, v in eval_dict_all.items(): 475 | eval_dict_all[k] = torch.cat(v, dim=-1) 476 | 477 | ndcg_10 = _avg(eval_dict_all["ndcg@10"], world_size=world_size) 478 | ndcg_50 = _avg(eval_dict_all["ndcg@50"], world_size=world_size) 479 | hr_10 = _avg(eval_dict_all["hr@10"], world_size=world_size) 480 | hr_50 = _avg(eval_dict_all["hr@50"], world_size=world_size) 481 | mrr = _avg(eval_dict_all["mrr"], world_size=world_size) 482 | 483 | add_to_summary_writer( 484 | writer, 485 | batch_id=epoch, 486 | metrics=eval_dict_all, 487 | prefix="eval_epoch", 488 | world_size=world_size, 489 | ) 490 | if full_eval_every_n > 1 and is_full_eval(epoch): 491 | add_to_summary_writer( 492 | writer, 493 | batch_id=epoch, 494 | metrics=eval_dict_all, 495 | prefix="eval_epoch_full", 496 | world_size=world_size, 497 | ) 498 | if rank == 0 and epoch > 0 and (epoch % save_ckpt_every_n) == 0: 499 | torch.save( 500 | { 501 | "epoch": epoch, 502 | "model_state_dict": model.state_dict(), 503 | "optimizer_state_dict": opt.state_dict(), 504 | }, 505 | f"./ckpts/{model_desc}_ep{epoch}", 506 | ) 507 | 508 | logging.info( 509 | f"rank {rank}: eval @ epoch {epoch} in {time.time() - eval_start_time:.2f}s: " 510 | f"NDCG@10 {ndcg_10:.4f}, NDCG@50 {ndcg_50:.4f}, HR@10 {hr_10:.4f}, HR@50 {hr_50:.4f}, MRR {mrr:.4f}" 511 | ) 512 | average_train_elapse = train_elapse / (epoch + 1) 513 | average_eval_elapse = eval_elapse / (epoch + 1) 514 | logging.info(f'train speed: {average_train_elapse: .4f}s/epoch, eval speed: {average_eval_elapse: .4f}s/epoch') 515 | last_training_time = time.time() 516 | 517 | average_eval_elapse = eval_elapse / num_epochs 518 | average_train_elapse = train_elapse / num_epochs 519 | logging.info(f'average eval time: {average_eval_elapse: .4f}s/epoch') 520 | logging.info(f'average train time: {average_train_elapse: .4f}s/epoch') 521 | 522 | logging.info(f'model desc: {model_desc}') 523 | 524 | if rank == 0: 525 | if writer is not None: 526 | writer.flush() 527 | writer.close() 528 | 529 | torch.save( 530 | { 531 | "epoch": epoch, 532 | "model_state_dict": model.state_dict(), 533 | "optimizer_state_dict": opt.state_dict(), 534 | }, 535 | f"./ckpts/{model_desc}_ep{epoch}", 536 | ) 537 | 538 | cleanup() 539 | -------------------------------------------------------------------------------- /generative_recommenders/modeling/sequential/autoregressive_losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyre-unsafe 16 | 17 | import abc 18 | from collections import OrderedDict 19 | from typing import List, Tuple 20 | 21 | import torch 22 | import torch.nn.functional as F 23 | from generative_recommenders.modeling.ndp_module import NDPModule 24 | from torch.utils.checkpoint import checkpoint 25 | 26 | 27 | class NegativesSampler(torch.nn.Module): 28 | 29 | def __init__(self, l2_norm: bool, l2_norm_eps: float) -> None: 30 | super().__init__() 31 | 32 | self._l2_norm: bool = l2_norm 33 | self._l2_norm_eps: float = l2_norm_eps 34 | 35 | def normalize_embeddings(self, x: torch.Tensor) -> torch.Tensor: 36 | return self._maybe_l2_norm(x) 37 | 38 | def _maybe_l2_norm(self, x: torch.Tensor) -> torch.Tensor: 39 | if self._l2_norm: 40 | x = x / torch.clamp( 41 | torch.linalg.norm(x, ord=2, dim=-1, keepdim=True), 42 | min=self._l2_norm_eps, 43 | ) 44 | return x 45 | 46 | @abc.abstractmethod 47 | def debug_str(self) -> str: 48 | pass 49 | 50 | @abc.abstractmethod 51 | def process_batch( 52 | self, 53 | ids: torch.Tensor, 54 | presences: torch.Tensor, 55 | embeddings: torch.Tensor, 56 | ) -> None: 57 | pass 58 | 59 | @abc.abstractmethod 60 | def forward( 61 | self, 62 | positive_ids: torch.Tensor, 63 | num_to_sample: int, 64 | ) -> Tuple[torch.Tensor, torch.Tensor]: 65 | """ 66 | Returns: 67 | A tuple of (sampled_ids, sampled_negative_embeddings). 68 | """ 69 | pass 70 | 71 | 72 | class LocalNegativesSampler(NegativesSampler): 73 | 74 | def __init__( 75 | self, 76 | num_items: int, 77 | item_emb: torch.nn.Embedding, 78 | all_item_ids: List[int], 79 | l2_norm: bool, 80 | l2_norm_eps: float, 81 | ) -> None: 82 | super().__init__(l2_norm=l2_norm, l2_norm_eps=l2_norm_eps) 83 | 84 | self._num_items: int = len(all_item_ids) 85 | self._item_emb: torch.nn.Embedding = item_emb 86 | self.register_buffer("_all_item_ids", torch.tensor(all_item_ids)) 87 | 88 | def debug_str(self) -> str: 89 | sampling_debug_str = ( 90 | f"local{f'-l2-eps{self._l2_norm_eps}' if self._l2_norm else ''}" 91 | ) 92 | return sampling_debug_str 93 | 94 | def process_batch( 95 | self, 96 | ids: torch.Tensor, 97 | presences: torch.Tensor, 98 | embeddings: torch.Tensor, 99 | ) -> None: 100 | pass 101 | 102 | def forward( 103 | self, 104 | positive_ids: torch.Tensor, 105 | num_to_sample: int, 106 | ) -> Tuple[torch.Tensor, torch.Tensor]: 107 | """ 108 | Returns: 109 | A tuple of (sampled_ids, sampled_negative_embeddings). 110 | """ 111 | # assert torch.max(torch.abs(self._item_emb(positive_ids) - positive_embeddings)) < 1e-4 112 | output_shape = positive_ids.size() + (num_to_sample,) 113 | sampled_offsets = torch.randint( 114 | low=0, 115 | high=self._num_items, 116 | size=output_shape, 117 | dtype=positive_ids.dtype, 118 | device=positive_ids.device, 119 | ) 120 | sampled_ids = self._all_item_ids[sampled_offsets.view(-1)].reshape(output_shape) 121 | return sampled_ids, self.normalize_embeddings(self._item_emb(sampled_ids)) 122 | 123 | 124 | class InBatchNegativesSampler(NegativesSampler): 125 | 126 | def __init__( 127 | self, 128 | l2_norm: bool, 129 | l2_norm_eps: float, 130 | dedup_embeddings: bool, 131 | ) -> None: 132 | super().__init__(l2_norm=l2_norm, l2_norm_eps=l2_norm_eps) 133 | 134 | self._dedup_embeddings: bool = dedup_embeddings 135 | 136 | def debug_str(self) -> str: 137 | sampling_debug_str = ( 138 | f"in-batch{f'-l2-eps{self._l2_norm_eps}' if self._l2_norm else ''}" 139 | ) 140 | if self._dedup_embeddings: 141 | sampling_debug_str += "-dedup" 142 | return sampling_debug_str 143 | 144 | def process_batch( 145 | self, 146 | ids: torch.Tensor, 147 | presences: torch.Tensor, 148 | embeddings: torch.Tensor, 149 | ) -> None: 150 | """ 151 | Args: 152 | ids: (N') or (B, N) x int64 153 | presences: (N') or (B, N) x bool 154 | embeddings: (N', D) or (B, N, D) x float 155 | """ 156 | assert ids.size() == presences.size() 157 | assert ids.size() == embeddings.size()[:-1] 158 | if self._dedup_embeddings: 159 | valid_ids = ids[presences] 160 | unique_ids, unique_ids_inverse_indices = torch.unique( 161 | input=valid_ids, sorted=False, return_inverse=True 162 | ) 163 | device = unique_ids.device 164 | unique_embedding_offsets = torch.empty( 165 | (unique_ids.numel(),), 166 | dtype=torch.int64, 167 | device=device, 168 | ) 169 | unique_embedding_offsets[unique_ids_inverse_indices] = torch.arange( 170 | valid_ids.numel(), dtype=torch.int64, device=device 171 | ) 172 | unique_embeddings = embeddings[presences][unique_embedding_offsets, :] 173 | self._cached_embeddings = self._maybe_l2_norm(unique_embeddings) 174 | self._cached_ids = unique_ids 175 | else: 176 | self._cached_embeddings = self._maybe_l2_norm(embeddings[presences]) 177 | self._cached_ids = ids[presences] 178 | 179 | def get_all_ids_and_embeddings(self) -> Tuple[torch.Tensor, torch.Tensor]: 180 | return self._cached_ids, self._cached_embeddings 181 | 182 | def forward( 183 | self, 184 | positive_ids: torch.Tensor, 185 | num_to_sample: int, 186 | ) -> Tuple[torch.Tensor, torch.Tensor]: 187 | """ 188 | Returns: 189 | A tuple of (sampled_ids, sampled_negative_embeddings,). 190 | """ 191 | X = self._cached_ids.size(0) 192 | sampled_offsets = torch.randint( 193 | low=0, 194 | high=X, 195 | size=positive_ids.size() + (num_to_sample,), 196 | dtype=positive_ids.dtype, 197 | device=positive_ids.device, 198 | ) 199 | return ( 200 | self._cached_ids[sampled_offsets], 201 | self._cached_embeddings[sampled_offsets], 202 | ) 203 | 204 | 205 | class AutoregressiveLoss(torch.nn.Module): 206 | 207 | @abc.abstractmethod 208 | def jagged_forward( 209 | self, 210 | output_embeddings: torch.Tensor, 211 | supervision_ids: torch.Tensor, 212 | supervision_embeddings: torch.Tensor, 213 | supervision_weights: torch.Tensor, 214 | negatives_sampler: NegativesSampler, 215 | ) -> torch.Tensor: 216 | """ 217 | Variant of forward() when the tensors are already in jagged format. 218 | 219 | Args: 220 | output_embeddings: [N', D] x float, embeddings for the current 221 | input sequence. 222 | supervision_ids: [N'] x int64, (positive) supervision ids. 223 | supervision_embeddings: [N', D] x float. 224 | supervision_weights: Optional [N'] x float. Optional weights for 225 | masking out invalid positions, or reweighting supervision labels. 226 | negatives_sampler: sampler used to obtain negative examples paired with 227 | positives. 228 | 229 | Returns: 230 | (1), loss for the current engaged sequence. 231 | """ 232 | pass 233 | 234 | @abc.abstractmethod 235 | def forward( 236 | self, 237 | lengths: torch.Tensor, 238 | output_embeddings: torch.Tensor, 239 | supervision_ids: torch.Tensor, 240 | supervision_embeddings: torch.Tensor, 241 | supervision_weights: torch.Tensor, 242 | negatives_sampler: NegativesSampler, 243 | ) -> torch.Tensor: 244 | """ 245 | Args: 246 | lengths: [B] x int32 representing number of non-zero elements per row. 247 | output_embeddings: [B, N, D] x float, embeddings for the current 248 | input sequence. 249 | supervision_ids: [B, N] x int64, (positive) supervision ids. 250 | supervision_embeddings: [B, N, D] x float. 251 | supervision_weights: Optional [B, N] x float. Optional weights for 252 | masking out invalid positions, or reweighting supervision labels. 253 | negatives_sampler: sampler used to obtain negative examples paired with 254 | positives. 255 | 256 | Returns: 257 | (1), loss for the current engaged sequence. 258 | """ 259 | pass 260 | 261 | 262 | class BCELoss(AutoregressiveLoss): 263 | def __init__( 264 | self, 265 | temperature: float, 266 | model: NDPModule, 267 | ) -> None: 268 | super().__init__() 269 | self._temperature: float = temperature 270 | self._model = model 271 | 272 | def jagged_forward( 273 | self, 274 | output_embeddings: torch.Tensor, 275 | supervision_ids: torch.Tensor, 276 | supervision_embeddings: torch.Tensor, 277 | supervision_weights: torch.Tensor, 278 | negatives_sampler: NegativesSampler, 279 | ) -> torch.Tensor: 280 | assert output_embeddings.size() == supervision_embeddings.size() 281 | assert supervision_ids.size() == supervision_embeddings.size()[:-1] 282 | assert supervision_ids.size() == supervision_weights.size() 283 | 284 | sampled_ids, sampled_negative_embeddings = negatives_sampler( 285 | positive_ids=supervision_ids, 286 | num_to_sample=1, 287 | ) 288 | 289 | positive_logits = ( 290 | self._model.interaction( 291 | input_embeddings=output_embeddings, # [B, D] = [N', D] 292 | target_ids=supervision_ids.unsqueeze(1), # [N', 1] 293 | target_embeddings=supervision_embeddings.unsqueeze( 294 | 1 295 | ), # [N', D] -> [N', 1, D] 296 | )[0].squeeze(1) 297 | / self._temperature 298 | ) # [N'] 299 | 300 | sampled_negatives_logits = ( 301 | self._model.interaction( 302 | input_embeddings=output_embeddings, # [N', D] 303 | target_ids=sampled_ids, # [N', 1] 304 | target_embeddings=sampled_negative_embeddings, # [N', 1, D] 305 | )[0].squeeze(1) 306 | / self._temperature 307 | ) # [N'] 308 | sampled_negatives_valid_mask = ( 309 | supervision_ids != sampled_ids.squeeze(1) 310 | ).float() # [N'] 311 | loss_weights = supervision_weights * sampled_negatives_valid_mask 312 | weighted_losses = ( 313 | ( 314 | F.binary_cross_entropy_with_logits( 315 | input=positive_logits, 316 | target=torch.ones_like(positive_logits), 317 | reduction="none", 318 | ) 319 | + F.binary_cross_entropy_with_logits( 320 | input=sampled_negatives_logits, 321 | target=torch.zeros_like(sampled_negatives_logits), 322 | reduction="none", 323 | ) 324 | ) 325 | * loss_weights 326 | * 0.5 327 | ) 328 | return weighted_losses.sum() / loss_weights.sum() 329 | 330 | def forward( 331 | self, 332 | lengths: torch.Tensor, 333 | output_embeddings: torch.Tensor, 334 | supervision_ids: torch.Tensor, 335 | supervision_embeddings: torch.Tensor, 336 | supervision_weights: torch.Tensor, 337 | negatives_sampler: NegativesSampler, 338 | ) -> torch.Tensor: 339 | """ 340 | Args: 341 | lengths: [B] x int32 representing number of non-zero elements per row. 342 | output_embeddings: [B, N, D] x float, embeddings for the current 343 | input sequence. 344 | supervision_ids: [B, N] x int64, (positive) supervision ids. 345 | supervision_embeddings: [B, N, D] x float. 346 | supervision_weights: Optional [B, N] x float. Optional weights for 347 | masking out invalid positions, or reweighting supervision labels. 348 | negatives_sampler: sampler used to obtain negative examples paired with 349 | positives. 350 | Returns: 351 | (1), loss for the current engaged sequence. 352 | """ 353 | assert output_embeddings.size() == supervision_embeddings.size() 354 | assert supervision_ids.size() == supervision_embeddings.size()[:-1] 355 | jagged_id_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) 356 | jagged_supervision_ids = ( 357 | torch.ops.fbgemm.dense_to_jagged( 358 | supervision_ids.unsqueeze(-1).float(), [jagged_id_offsets] 359 | )[0] 360 | .squeeze(1) 361 | .long() 362 | ) 363 | jagged_supervision_weights = torch.ops.fbgemm.dense_to_jagged( 364 | supervision_weights.unsqueeze(-1), 365 | [jagged_id_offsets], 366 | )[0].squeeze(1) 367 | return self.jagged_forward( 368 | output_embeddings=torch.ops.fbgemm.dense_to_jagged( 369 | output_embeddings, 370 | [jagged_id_offsets], 371 | )[0], 372 | supervision_ids=jagged_supervision_ids, 373 | supervision_embeddings=torch.ops.fbgemm.dense_to_jagged( 374 | supervision_embeddings, 375 | [jagged_id_offsets], 376 | )[0], 377 | supervision_weights=jagged_supervision_weights, 378 | negatives_sampler=negatives_sampler, 379 | ) 380 | 381 | 382 | class BCELossWithRatings(AutoregressiveLoss): 383 | def __init__( 384 | self, 385 | temperature: float, 386 | model: NDPModule, 387 | ) -> None: 388 | super().__init__() 389 | self._temperature: float = temperature 390 | self._model = model 391 | 392 | def jagged_forward( 393 | self, 394 | output_embeddings: torch.Tensor, 395 | supervision_ids: torch.Tensor, 396 | supervision_embeddings: torch.Tensor, 397 | supervision_weights: torch.Tensor, 398 | supervision_ratings: torch.Tensor, 399 | negatives_sampler: NegativesSampler, 400 | ) -> torch.Tensor: 401 | assert output_embeddings.size() == supervision_embeddings.size() 402 | assert supervision_ids.size() == supervision_embeddings.size()[:-1] 403 | assert supervision_ids.size() == supervision_weights.size() 404 | 405 | target_logits = ( 406 | self._model.interaction( 407 | input_embeddings=output_embeddings, # [B, D] = [N', D] 408 | target_ids=supervision_ids.unsqueeze(1), # [N', 1] 409 | target_embeddings=supervision_embeddings.unsqueeze( 410 | 1 411 | ), # [N', D] -> [N', 1, D] 412 | )[0].squeeze(1) 413 | / self._temperature 414 | ) # [N', 1] 415 | 416 | # loss_weights = (supervision_ids > 0).to(torch.float32) 417 | 418 | weighted_losses = ( 419 | F.binary_cross_entropy_with_logits( 420 | input=target_logits, 421 | target=supervision_ratings.to(dtype=target_logits.dtype), 422 | reduction="none", 423 | ) 424 | ) * supervision_weights 425 | return weighted_losses.sum() / supervision_weights.sum() 426 | 427 | def forward( 428 | self, 429 | lengths: torch.Tensor, 430 | output_embeddings: torch.Tensor, 431 | supervision_ids: torch.Tensor, 432 | supervision_embeddings: torch.Tensor, 433 | supervision_weights: torch.Tensor, 434 | supervision_ratings: torch.Tensor, 435 | negatives_sampler: NegativesSampler, 436 | ) -> torch.Tensor: 437 | """ 438 | Args: 439 | lengths: [B] x int32 representing number of non-zero elements per row. 440 | output_embeddings: [B, N, D] x float, embeddings for the current 441 | input sequence. 442 | supervision_ids: [B, N] x int64, (positive) supervision ids. 443 | supervision_embeddings: [B, N, D] x float. 444 | supervision_weights: Optional [B, N] x float. Optional weights for 445 | masking out invalid positions, or reweighting supervision labels. 446 | negatives_sampler: sampler used to obtain negative examples paired with 447 | positives. 448 | Returns: 449 | (1), loss for the current engaged sequence. 450 | """ 451 | assert output_embeddings.size() == supervision_embeddings.size() 452 | assert supervision_ids.size() == supervision_embeddings.size()[:-1] 453 | jagged_id_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) 454 | jagged_supervision_ids = ( 455 | torch.ops.fbgemm.dense_to_jagged( 456 | supervision_ids.unsqueeze(-1).float(), [jagged_id_offsets] 457 | )[0] 458 | .squeeze(1) 459 | .long() 460 | ) 461 | jagged_supervision_weights = torch.ops.fbgemm.dense_to_jagged( 462 | supervision_weights.unsqueeze(-1), 463 | [jagged_id_offsets], 464 | )[0].squeeze(1) 465 | return self.jagged_forward( 466 | output_embeddings=torch.ops.fbgemm.dense_to_jagged( 467 | output_embeddings, 468 | [jagged_id_offsets], 469 | )[0], 470 | supervision_ids=jagged_supervision_ids, 471 | supervision_embeddings=torch.ops.fbgemm.dense_to_jagged( 472 | supervision_embeddings, 473 | [jagged_id_offsets], 474 | )[0], 475 | supervision_weights=jagged_supervision_weights, 476 | supervision_ratings=torch.ops.fbgemm.dense_to_jagged( 477 | supervision_ratings.unsqueeze(-1), 478 | [jagged_id_offsets], 479 | )[0].squeeze(1), 480 | negatives_sampler=negatives_sampler, 481 | ) 482 | 483 | 484 | class SampledSoftmaxLoss(AutoregressiveLoss): 485 | 486 | def __init__( 487 | self, 488 | num_to_sample: int, 489 | softmax_temperature: float, 490 | model: torch.nn.Module, 491 | activation_checkpoint: bool = False, 492 | ) -> None: 493 | super().__init__() 494 | self._num_to_sample: int = num_to_sample 495 | self._softmax_temperature: float = softmax_temperature 496 | self._model = model 497 | self._activation_checkpoint: bool = activation_checkpoint 498 | 499 | def jagged_forward( 500 | self, 501 | output_embeddings: torch.Tensor, 502 | supervision_ids: torch.Tensor, 503 | supervision_embeddings: torch.Tensor, 504 | supervision_weights: torch.Tensor, 505 | negatives_sampler: NegativesSampler, 506 | ) -> torch.Tensor: 507 | assert output_embeddings.size() == supervision_embeddings.size() 508 | assert supervision_ids.size() == supervision_embeddings.size()[:-1] 509 | assert supervision_ids.size() == supervision_weights.size() 510 | 511 | sampled_ids, sampled_negative_embeddings = negatives_sampler( 512 | positive_ids=supervision_ids, 513 | num_to_sample=self._num_to_sample, 514 | ) 515 | positive_embeddings = negatives_sampler.normalize_embeddings( 516 | supervision_embeddings 517 | ) 518 | positive_logits = ( 519 | self._model.interaction( 520 | input_embeddings=output_embeddings, # [B, D] = [N', D] 521 | target_ids=supervision_ids.unsqueeze(1), # [N', 1] 522 | target_embeddings=positive_embeddings.unsqueeze( 523 | 1 524 | ), # [N', D] -> [N', 1, D] 525 | ) 526 | / self._softmax_temperature 527 | ) # [0] 528 | sampled_negatives_logits = self._model.interaction( 529 | input_embeddings=output_embeddings, # [N', D] 530 | target_ids=sampled_ids, # [N', R] 531 | target_embeddings=sampled_negative_embeddings, # [N', R, D] 532 | ) # [N', R] # [0] 533 | sampled_negatives_logits = torch.where( 534 | supervision_ids.unsqueeze(1) == sampled_ids, # [N', R] 535 | -5e4, 536 | sampled_negatives_logits / self._softmax_temperature, 537 | ) 538 | jagged_loss = -F.log_softmax( 539 | torch.cat([positive_logits, sampled_negatives_logits], dim=1), dim=1 540 | )[:, 0] 541 | return (jagged_loss * supervision_weights).sum() / supervision_weights.sum() 542 | 543 | def forward( 544 | self, 545 | lengths: torch.Tensor, 546 | output_embeddings: torch.Tensor, 547 | supervision_ids: torch.Tensor, 548 | supervision_embeddings: torch.Tensor, 549 | supervision_weights: torch.Tensor, 550 | negatives_sampler: NegativesSampler, 551 | ) -> torch.Tensor: 552 | """ 553 | Args: 554 | lengths: [B] x int32 representing number of non-zero elements per row. 555 | output_embeddings: [B, N, D] x float, embeddings for the current 556 | input sequence. 557 | supervision_ids: [B, N] x int64, (positive) supervision ids. 558 | supervision_embeddings: [B, N, D] x float. 559 | supervision_weights: Optional [B, N] x float. Optional weights for 560 | masking out invalid positions, or reweighting supervision labels. 561 | negatives_sampler: sampler used to obtain negative examples paired with 562 | positives. 563 | 564 | Returns: 565 | (1), loss for the current engaged sequence. 566 | """ 567 | assert output_embeddings.size() == supervision_embeddings.size() 568 | assert supervision_ids.size() == supervision_embeddings.size()[:-1] 569 | jagged_id_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) 570 | jagged_supervision_ids = ( 571 | torch.ops.fbgemm.dense_to_jagged( 572 | supervision_ids.unsqueeze(-1).float(), [jagged_id_offsets] 573 | )[0] 574 | .squeeze(1) 575 | .long() 576 | ) 577 | 578 | args = OrderedDict( 579 | [ 580 | ( 581 | "output_embeddings", 582 | torch.ops.fbgemm.dense_to_jagged( 583 | output_embeddings, 584 | [jagged_id_offsets], 585 | )[0], 586 | ), 587 | ("supervision_ids", jagged_supervision_ids), 588 | ( 589 | "supervision_embeddings", 590 | torch.ops.fbgemm.dense_to_jagged( 591 | supervision_embeddings, 592 | [jagged_id_offsets], 593 | )[0], 594 | ), 595 | ( 596 | "supervision_weights", 597 | torch.ops.fbgemm.dense_to_jagged( 598 | supervision_weights.unsqueeze(-1), 599 | [jagged_id_offsets], 600 | )[0].squeeze(1), 601 | ), 602 | ("negatives_sampler", negatives_sampler), 603 | ] 604 | ) 605 | if self._activation_checkpoint: 606 | return checkpoint( 607 | self.jagged_forward, 608 | *args.values(), 609 | use_reentrant=False, 610 | ) 611 | else: 612 | return self.jagged_forward( 613 | output_embeddings=torch.ops.fbgemm.dense_to_jagged( 614 | output_embeddings, 615 | [jagged_id_offsets], 616 | )[0], 617 | supervision_ids=jagged_supervision_ids, 618 | supervision_embeddings=torch.ops.fbgemm.dense_to_jagged( 619 | supervision_embeddings, 620 | [jagged_id_offsets], 621 | )[0], 622 | supervision_weights=torch.ops.fbgemm.dense_to_jagged( 623 | supervision_weights.unsqueeze(-1), 624 | [jagged_id_offsets], 625 | )[0].squeeze(1), 626 | negatives_sampler=negatives_sampler, 627 | ) 628 | -------------------------------------------------------------------------------- /generative_recommenders/modeling/similarity/mol.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pyre-unsafe 16 | 17 | """ 18 | Implements MoL (Mixture-of-Logits) in 19 | Revisiting Neural Retrieval on Accelerators (https://arxiv.org/abs/2306.04039, KDD'23). 20 | """ 21 | from typing import Callable, Dict, Optional, Tuple 22 | 23 | import torch 24 | import torch.nn.functional as F 25 | 26 | from generative_recommenders.modeling.initialization import ( 27 | init_mlp_xavier_weights_zero_bias, 28 | ) 29 | 30 | 31 | class SoftmaxDropout(torch.nn.Module): 32 | 33 | def __init__( 34 | self, 35 | dropout_rate: float, 36 | eps: float = 1e-6, 37 | ) -> None: 38 | super().__init__() 39 | 40 | self._softmax: torch.nn.Module = torch.nn.Softmax(dim=-1) 41 | self._dropout: torch.nn.Module = torch.nn.Dropout(p=dropout_rate) 42 | self._eps = eps 43 | 44 | def forward( 45 | self, x: torch.Tensor, tau: Optional[torch.Tensor] = None 46 | ) -> torch.Tensor: 47 | if tau is not None: 48 | x = x / tau 49 | x = self._dropout(self._softmax(x)) 50 | return x / torch.clamp(x.sum(-1, keepdims=True), min=self._eps) 51 | 52 | 53 | class SoftmaxDropoutCombiner(torch.nn.Module): 54 | 55 | def __init__( 56 | self, 57 | dropout_rate: float, 58 | eps: float, 59 | keep_debug_info: bool = False, 60 | ) -> None: 61 | super().__init__() 62 | 63 | self._softmax_dropout: torch.nn.Module = SoftmaxDropout( 64 | dropout_rate=dropout_rate, eps=eps 65 | ) 66 | self._keep_debug_info: bool = keep_debug_info 67 | 68 | def forward( 69 | self, 70 | gating_weights: torch.Tensor, 71 | x: torch.Tensor, 72 | tau: Optional[torch.Tensor] = None, 73 | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 74 | combined_logits = (self._softmax_dropout(gating_weights, tau) * x).sum(-1) 75 | if self._keep_debug_info: 76 | return combined_logits, { 77 | "gating_weights": gating_weights.detach().clone(), 78 | "x": x.detach().clone(), 79 | } 80 | else: 81 | return combined_logits 82 | 83 | 84 | class IdentityMLPProjectionFn(torch.nn.Module): 85 | def __init__( 86 | self, 87 | input_dim: int, 88 | output_num_features: int, 89 | output_dim: int, 90 | input_dropout_rate: float, 91 | ) -> None: 92 | super().__init__() 93 | 94 | self._output_num_features = output_num_features 95 | self._output_dim = output_dim 96 | if output_num_features > 1: 97 | self._proj_mlp: torch.nn.modules.container.Sequential = torch.nn.Sequential( 98 | torch.nn.Dropout(p=input_dropout_rate), 99 | torch.nn.Linear( 100 | in_features=input_dim, 101 | out_features=(output_num_features - 1) * output_dim, 102 | ), 103 | ).apply(init_mlp_xavier_weights_zero_bias) 104 | 105 | def forward(self, x: torch.Tensor) -> torch.Tensor: 106 | output_emb_0 = x[..., : self._output_dim] # [.., D] -> [.., 1, D'] 107 | if self._output_num_features > 1: 108 | return torch.cat([output_emb_0, self._proj_mlp(x)], dim=-1) 109 | return output_emb_0 110 | 111 | 112 | class TauFn(torch.nn.Module): 113 | 114 | def __init__( 115 | self, 116 | alpha: float, 117 | item_sideinfo_dim: int, 118 | ) -> None: 119 | super().__init__() 120 | 121 | self._tau_fn: torch.nn.Module = torch.nn.Sequential( 122 | torch.nn.Linear(in_features=item_sideinfo_dim, out_features=1), 123 | torch.nn.Sigmoid(), 124 | ) 125 | self._alpha: float = alpha 126 | 127 | def forward( 128 | self, 129 | item_sideinfo: torch.Tensor, 130 | ) -> torch.Tensor: 131 | return (self._tau_fn(item_sideinfo) + self._alpha) / self._alpha 132 | 133 | 134 | class GeGLU(torch.nn.Module): 135 | 136 | def __init__( 137 | self, 138 | in_features: int, 139 | out_features: int, 140 | ) -> None: 141 | super().__init__() 142 | 143 | self._in_features = in_features 144 | self._out_features = out_features 145 | self._w = torch.nn.Parameter( 146 | torch.empty((in_features, out_features * 2)).normal_(mean=0, std=0.02), 147 | ) 148 | self._b = torch.nn.Parameter( 149 | torch.zeros( 150 | ( 151 | 1, 152 | out_features * 2, 153 | ) 154 | ), 155 | ) 156 | 157 | def forward(self, x: torch.Tensor) -> torch.Tensor: 158 | bs = x.size()[:-1] 159 | lhs, rhs = torch.split( 160 | torch.mm(x.reshape(-1, self._in_features), self._w) + self._b, 161 | [self._out_features, self._out_features], 162 | dim=-1, 163 | ) 164 | return (F.gelu(lhs) * rhs).reshape(bs + (self._out_features,)) 165 | 166 | 167 | class SwiGLU(torch.nn.Module): 168 | """ 169 | SwiGLU as proposed in ``GLU Variants Improve Transformer'' (https://arxiv.org/abs/2002.05202). 170 | """ 171 | 172 | def __init__( 173 | self, 174 | in_features: int, 175 | out_features: int, 176 | ) -> None: 177 | super().__init__() 178 | 179 | self._in_features = in_features 180 | self._out_features = out_features 181 | self._w = torch.nn.Parameter( 182 | torch.empty((in_features, out_features * 2)).normal_(mean=0, std=0.02), 183 | ) 184 | self._b = torch.nn.Parameter( 185 | torch.zeros( 186 | ( 187 | 1, 188 | out_features * 2, 189 | ) 190 | ), 191 | ) 192 | 193 | def forward(self, x: torch.Tensor) -> torch.Tensor: 194 | bs = x.size()[:-1] 195 | lhs, rhs = torch.split( 196 | torch.mm(x.reshape(-1, self._in_features), self._w) + self._b, 197 | [self._out_features, self._out_features], 198 | dim=-1, 199 | ) 200 | return (F.silu(lhs) * rhs).reshape(bs + (self._out_features,)) 201 | 202 | 203 | class MoLGatingFn(torch.nn.Module): 204 | 205 | def __init__( 206 | self, 207 | num_logits: int, 208 | context_embedding_dim: int, 209 | item_embedding_dim: int, 210 | item_sideinfo_dim: int, 211 | context_only_partial_fn: Optional[Callable[[int, int], torch.nn.Module]], 212 | item_only_partial_fn: Optional[Callable[[int, int], torch.nn.Module]], 213 | ci_partial_fn: Optional[Callable[[int, int], torch.nn.Module]], 214 | combination_type: str, 215 | normalization_fn: Callable[[int], torch.nn.Module], 216 | combine_item_sideinfo_into_ci: bool = False, 217 | gating_use_custom_tau: bool = False, 218 | gating_tau_alpha: float = 0.01, 219 | ) -> None: 220 | super().__init__() 221 | 222 | self._context_only_partial_module: Optional[torch.nn.Module] = ( 223 | context_only_partial_fn(context_embedding_dim, num_logits) 224 | if context_only_partial_fn 225 | else None 226 | ) 227 | self._item_only_partial_module: Optional[torch.nn.Module] = ( 228 | item_only_partial_fn(item_embedding_dim + item_sideinfo_dim, num_logits) 229 | if item_only_partial_fn 230 | else None 231 | ) 232 | self._ci_partial_module: Optional[torch.nn.Module] = ( 233 | ci_partial_fn( 234 | num_logits 235 | + (item_sideinfo_dim if combine_item_sideinfo_into_ci else 0), 236 | num_logits, 237 | ) 238 | if ci_partial_fn is not None 239 | else None 240 | ) 241 | if ( 242 | self._context_only_partial_module is None 243 | and self._item_only_partial_module is None 244 | and self._ci_partial_module is None 245 | ): 246 | raise ValueError( 247 | "At least one of context_only_partial_fn, item_only_partial_fn, " 248 | "and ci_partial_fn must not be None." 249 | ) 250 | self._num_logits: int = num_logits 251 | self._combination_type: str = combination_type 252 | self._combine_item_sideinfo_into_ci: bool = combine_item_sideinfo_into_ci 253 | self._normalization_fn: torch.nn.Module = normalization_fn(num_logits) 254 | if gating_use_custom_tau: 255 | self._tau_fn: Optional[TauFn] = TauFn( 256 | item_sideinfo_dim=item_sideinfo_dim, alpha=gating_tau_alpha 257 | ) 258 | else: 259 | self._tau_fn: Optional[TauFn] = None 260 | 261 | def forward( 262 | self, 263 | logits: torch.Tensor, 264 | context_embeddings: torch.Tensor, 265 | item_embeddings: torch.Tensor, 266 | item_sideinfo: Optional[torch.Tensor] = None, 267 | batch_id: Optional[torch.Tensor] = None, 268 | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 269 | """ 270 | Args: 271 | logits: (B, X, L) x float 272 | context_embeddings: (B, D) x float 273 | item_embeddings: (1/B, X, D') x float 274 | item_sideinfo: (1/B, X, F) x float or None 275 | batch_id: if present, (,) x int 276 | 277 | Returns: 278 | (B, X) x float 279 | """ 280 | B, X, _ = logits.size() 281 | # [B, 1, F], [1/B, X, F], [B, X, F] 282 | context_partial_inputs, item_partial_inputs, ci_partial_inputs = ( 283 | None, 284 | None, 285 | None, 286 | ) 287 | if self._context_only_partial_module is not None: 288 | context_partial_inputs = self._context_only_partial_module( 289 | context_embeddings 290 | ).unsqueeze(1) 291 | if self._item_only_partial_module is not None: 292 | if item_sideinfo is not None: 293 | item_embeddings = torch.cat([item_embeddings, item_sideinfo], dim=-1) 294 | item_partial_inputs = self._item_only_partial_module(item_embeddings) # pyre-ignore [29] 295 | if self._ci_partial_module is not None: 296 | if self._combine_item_sideinfo_into_ci: 297 | assert item_sideinfo is not None 298 | B_prime = item_sideinfo.size(0) 299 | if B_prime == 1: 300 | item_sideinfo = item_sideinfo.expand(B, -1, -1) 301 | ci_partial_inputs = self._ci_partial_module( # pyre-ignore [29] 302 | torch.cat([logits, item_sideinfo], dim=2) 303 | ) 304 | else: 305 | ci_partial_inputs = self._ci_partial_module(logits) 306 | 307 | if self._combination_type == "glu_silu": 308 | gating_inputs = ( 309 | context_partial_inputs * item_partial_inputs + ci_partial_inputs 310 | ) 311 | gating_weights = gating_inputs * F.sigmoid(gating_inputs) 312 | elif self._combination_type == "glu_silu_ln": 313 | gating_inputs = ( 314 | context_partial_inputs * item_partial_inputs + ci_partial_inputs 315 | ) 316 | gating_weights = gating_inputs * F.sigmoid( 317 | F.layer_norm(gating_inputs, normalized_shape=[self._num_logits]) 318 | ) 319 | elif self._combination_type == "silu": 320 | if context_partial_inputs is not None: 321 | gating_inputs = context_partial_inputs.expand(-1, X, -1) 322 | else: 323 | gating_inputs = None 324 | 325 | if gating_inputs is None: 326 | gating_inputs = item_partial_inputs 327 | elif item_partial_inputs is not None: 328 | gating_inputs = gating_inputs + item_partial_inputs 329 | 330 | if gating_inputs is None: 331 | gating_inputs = ci_partial_inputs 332 | elif ci_partial_inputs is not None: 333 | gating_inputs = gating_inputs + ci_partial_inputs 334 | 335 | gating_weights = gating_inputs * F.sigmoid(gating_inputs) 336 | elif self._combination_type == "none": 337 | gating_inputs = context_partial_inputs 338 | if gating_inputs is None: 339 | gating_inputs = item_partial_inputs 340 | elif item_partial_inputs is not None: 341 | gating_inputs += item_partial_inputs 342 | if gating_inputs is None: 343 | gating_inputs = ci_partial_inputs 344 | elif ci_partial_inputs is not None: 345 | gating_inputs += ci_partial_inputs 346 | gating_weights = gating_inputs 347 | else: 348 | raise ValueError(f"Unknown combination_type {self._combination_type}") 349 | 350 | tau = None 351 | if self._tau_fn is not None: 352 | tau = self._tau_fn(item_sideinfo) 353 | return self._normalization_fn(gating_weights, logits, tau) # , {} 354 | 355 | 356 | class MoLSimilarity(torch.nn.Module): 357 | """ 358 | Implements MoL (Mixture-of-Logits) learned similarity in 359 | Revisiting Neural Retrieval on Accelerators (https://arxiv.org/abs/2306.04039, KDD'23). 360 | """ 361 | 362 | def __init__( 363 | self, 364 | input_embedding_dim: int, 365 | item_embedding_dim: int, 366 | dot_product_dimension: int, 367 | input_dot_product_groups: int, 368 | item_dot_product_groups: int, 369 | temperature: float, 370 | dot_product_l2_norm: bool, 371 | num_precomputed_logits: int, 372 | item_sideinfo_dim: int, 373 | context_proj_fn: Callable[[int, int], torch.nn.Module], 374 | item_proj_fn: Callable[[int, int], torch.nn.Module], 375 | gating_context_only_partial_fn: Optional[Callable[[int, int], torch.nn.Module]], 376 | gating_item_only_partial_fn: Optional[Callable[[int, int], torch.nn.Module]], 377 | gating_ci_partial_fn: Optional[Callable[[int], torch.nn.Module]], 378 | gating_combination_type: str, 379 | gating_normalization_fn: Callable[[int], torch.nn.Module], 380 | eps: float, 381 | gating_combine_item_sideinfo_into_ci: bool = False, 382 | gating_use_custom_tau: bool = False, 383 | gating_tau_alpha: float = 0.01, 384 | bf16_training: bool = False, 385 | ) -> None: 386 | super().__init__() 387 | 388 | self._gating_fn: MoLGatingFn = MoLGatingFn( 389 | num_logits=input_dot_product_groups * item_dot_product_groups 390 | + num_precomputed_logits, 391 | context_embedding_dim=input_embedding_dim, 392 | item_embedding_dim=item_embedding_dim, 393 | item_sideinfo_dim=item_sideinfo_dim, 394 | context_only_partial_fn=gating_context_only_partial_fn, 395 | item_only_partial_fn=gating_item_only_partial_fn, 396 | ci_partial_fn=gating_ci_partial_fn, # pyre-ignore [6] 397 | combine_item_sideinfo_into_ci=gating_combine_item_sideinfo_into_ci, 398 | combination_type=gating_combination_type, 399 | normalization_fn=gating_normalization_fn, 400 | gating_use_custom_tau=gating_use_custom_tau, 401 | gating_tau_alpha=gating_tau_alpha, 402 | ) 403 | self._context_proj_module: torch.nn.Module = context_proj_fn( 404 | input_embedding_dim, 405 | dot_product_dimension * input_dot_product_groups, 406 | ) 407 | self._item_proj_module: torch.nn.Module = item_proj_fn( 408 | item_embedding_dim, # + item_sideinfo_dim, 409 | dot_product_dimension * item_dot_product_groups, 410 | ) 411 | self._item_sideinfo_dim: int = item_sideinfo_dim 412 | self._dot_product_l2_norm: bool = dot_product_l2_norm 413 | self._input_dot_product_groups: int = input_dot_product_groups 414 | self._item_dot_product_groups: int = item_dot_product_groups 415 | self._dot_product_dimension: int = dot_product_dimension 416 | self._temperature: float = temperature 417 | self._eps: float = eps 418 | self._bf16_training: bool = bf16_training 419 | 420 | def _frequency_estimator_old(self, ids: torch.Tensor) -> torch.Tensor: 421 | ids_shape = ids.size() 422 | ids = ids.reshape(-1) 423 | temp = (1 - self._lnx_estimator_alpha) * self._B[ 424 | ids 425 | ] + self._lnx_estimator_alpha * (self._lnx_num_batches + 1 - self._A[ids]) 426 | temp = torch.clamp(temp, max=self._lnx_estimator_b_cap) # pyre-ignore [6] 427 | if self.train: 428 | self._lnx_num_batches = self._lnx_num_batches + 1 429 | self._B[ids] = temp 430 | self._A[ids] = self._lnx_num_batches 431 | return torch.div(1.0, temp.reshape(ids_shape)) 432 | 433 | def _frequency_estimator(self, ids: torch.Tensor, update: bool) -> torch.Tensor: 434 | ids_shape = ids.size() 435 | ids = ids.reshape(-1) 436 | sorted_id_values, sorted_id_indices = ids.sort() 437 | ( 438 | sorted_unique_ids, 439 | sorted_unique_inverses, 440 | sorted_unique_cnts, 441 | ) = sorted_id_values.unique_consecutive( 442 | return_counts=True, 443 | return_inverse=True, 444 | ) 445 | most_recent_batches = torch.zeros_like(sorted_unique_ids, dtype=torch.int64) 446 | most_recent_batches[sorted_unique_inverses] = ( 447 | sorted_id_indices + self._lnx_estimator_num_elements 448 | ) 449 | delta_batches = torch.zeros_like(ids, dtype=torch.float32) 450 | delta_batches[sorted_id_indices] = torch.gather( 451 | input=(most_recent_batches - self._A[sorted_unique_ids]).float() 452 | / sorted_unique_cnts.float(), 453 | dim=0, 454 | index=sorted_unique_inverses, 455 | ) 456 | 457 | temp = (1 - self._lnx_estimator_alpha) * self._B[ 458 | ids 459 | ] + self._lnx_estimator_alpha * delta_batches 460 | temp = torch.clamp(temp, max=self._lnx_estimator_b_cap) # pyre-ignore [6] 461 | 462 | if update: 463 | self._B[ids] = temp 464 | self._A[sorted_unique_ids] = most_recent_batches 465 | self._lnx_estimator_num_elements = ( 466 | self._lnx_estimator_num_elements + ids.numel() 467 | ) 468 | return torch.div(1.0, temp.reshape(ids_shape)) 469 | 470 | def get_query_component_embeddings( 471 | self, 472 | input_embeddings: torch.Tensor, 473 | ) -> torch.Tensor: 474 | """ 475 | Args: 476 | input_embeddings: (B, self._input_embedding_dim,) x float. 477 | 478 | Returns: 479 | (B, query_dot_product_groups, dot_product_embedding_dim) x float. 480 | """ 481 | with torch.autocast( 482 | enabled=self._bf16_training, dtype=torch.bfloat16, device_type="cuda" 483 | ): 484 | split_user_embeddings = self._context_proj_module(input_embeddings).reshape( 485 | ( 486 | input_embeddings.size(0), 487 | self._input_dot_product_groups, 488 | self._dot_product_dimension, 489 | ) 490 | ) 491 | if self._dot_product_l2_norm: 492 | split_user_embeddings = split_user_embeddings / torch.clamp( 493 | torch.linalg.norm( 494 | split_user_embeddings, 495 | ord=None, 496 | dim=-1, 497 | keepdim=True, 498 | ), 499 | min=self._eps, 500 | ) 501 | return split_user_embeddings 502 | 503 | def get_item_component_embeddings( 504 | self, 505 | input_embeddings: torch.Tensor, 506 | ) -> torch.Tensor: 507 | """ 508 | Args: 509 | input_embeddings: (B, self._input_embedding_dim,) x float. 510 | 511 | Returns: 512 | (B, item_dot_product_groups, dot_product_embedding_dim) x float. 513 | """ 514 | with torch.autocast( 515 | enabled=self._bf16_training, dtype=torch.bfloat16, device_type="cuda" 516 | ): 517 | split_item_embeddings = self._item_proj_module(input_embeddings).reshape( 518 | input_embeddings.size()[:-1] 519 | + ( 520 | self._item_dot_product_groups, 521 | self._dot_product_dimension, 522 | ) 523 | ) 524 | if self._dot_product_l2_norm: 525 | split_item_embeddings = split_item_embeddings / torch.clamp( 526 | torch.linalg.norm( 527 | split_item_embeddings, 528 | ord=None, 529 | dim=-1, 530 | keepdim=True, 531 | ), 532 | min=self._eps, 533 | ) 534 | return split_item_embeddings 535 | 536 | def forward( 537 | self, 538 | input_embeddings: torch.Tensor, 539 | item_embeddings: torch.Tensor, 540 | item_sideinfo: Optional[torch.Tensor], 541 | item_ids: torch.Tensor, 542 | precomputed_logits: Optional[torch.Tensor] = None, 543 | batch_id: Optional[int] = None, 544 | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 545 | """ 546 | Args: 547 | input_embeddings: (B, self._input_embedding_dim) 548 | item_embeddings: (1/B, X, self._item_embedding_dim) 549 | item_sideinfo: (1/B, X, self._item_sideinfo_dim) 550 | item_ids: (1/B, X,) 551 | precomputed_logits: (B, X, self._num_precomputed_logits,) 552 | """ 553 | with torch.autocast( 554 | enabled=self._bf16_training, dtype=torch.bfloat16, device_type="cuda" 555 | ): 556 | B = input_embeddings.size(0) 557 | B_prime, X, D = item_embeddings.shape 558 | 559 | # if self._item_sideinfo_dim > 0: 560 | # item_proj_input = torch.cat([item_embeddings, item_sideinfo], dim=-1) 561 | # else: 562 | item_proj_input = item_embeddings 563 | 564 | split_user_embeddings = self._context_proj_module(input_embeddings).reshape( 565 | (B, self._input_dot_product_groups, self._dot_product_dimension) 566 | ) 567 | split_item_embeddings = self._item_proj_module(item_proj_input).reshape( 568 | (B_prime, X, self._item_dot_product_groups, self._dot_product_dimension) 569 | ) 570 | if self._dot_product_l2_norm: 571 | split_user_embeddings = split_user_embeddings / torch.clamp( 572 | torch.linalg.norm( 573 | split_user_embeddings, 574 | ord=None, 575 | dim=-1, 576 | keepdim=True, 577 | ), 578 | min=self._eps, 579 | ) 580 | split_item_embeddings = split_item_embeddings / torch.clamp( 581 | torch.linalg.norm( 582 | split_item_embeddings, 583 | ord=None, 584 | dim=-1, 585 | keepdim=True, 586 | ), 587 | min=self._eps, 588 | ) 589 | if B_prime == 1: 590 | # logits = torch.mm(split_user_embeddings, split_item_embeddings.t()).reshape( 591 | # B, self._input_dot_product_groups, X, self._item_dot_product_groups 592 | # ).permute(0, 2, 1, 3) # (bn, xm) -> (b, n, x, m) -> (b, x, n, m) 593 | logits = torch.einsum( 594 | "bnd,xmd->bxnm", 595 | split_user_embeddings, 596 | split_item_embeddings.squeeze(0), 597 | ).reshape( 598 | B, X, self._input_dot_product_groups * self._item_dot_product_groups 599 | ) 600 | else: 601 | # logits = torch.bmm( 602 | # split_user_embeddings, 603 | # split_item_embeddings.permute(0, 2, 1) # [b, n, d], [b, xm, d] -> [b, n, xm] 604 | # ).reshape(B, self._input_dot_product_groups, X, self._item_dot_product_groups).permute(0, 2, 1, 3) 605 | logits = torch.einsum( 606 | "bnd,bxmd->bxnm", split_user_embeddings, split_item_embeddings 607 | ).reshape( 608 | B, X, self._input_dot_product_groups * self._item_dot_product_groups 609 | ) 610 | # [b, x, n, m] -> [b, x, n * m] 611 | # logits = logits.reshape(B, X, self._input_dot_product_groups * self._item_dot_product_groups) 612 | 613 | return self._gating_fn( 614 | logits=logits / self._temperature, # [B, X, L] 615 | context_embeddings=input_embeddings, # [B, D] 616 | item_embeddings=item_embeddings, # [1/B, X, D'] 617 | item_sideinfo=item_sideinfo, # [1/B, X, D''] 618 | batch_id=batch_id, 619 | ) 620 | --------------------------------------------------------------------------------