├── __init__.py ├── logs └── .gitkeep ├── outputs └── .gitkeep ├── src ├── __init__.py ├── models │ ├── __init__.py │ ├── components │ │ ├── __init__.py │ │ ├── network_blocks │ │ │ ├── normalize_layer.py │ │ │ ├── mlp.py │ │ │ ├── embedding_aggregator.py │ │ │ └── aggregation_strategy.py │ │ └── interfaces.py │ └── modules │ │ ├── __init__.py │ │ ├── huggingface │ │ ├── __init__.py │ │ └── transformer_base_module.py │ │ ├── semantic_id │ │ └── __init__.py │ │ ├── clustering │ │ └── mini_batch_kmeans.py │ │ └── base_module.py ├── modules │ ├── __init__.py │ ├── clustering │ │ ├── __init__.py │ │ └── vector_quantization.py │ └── semantic_embedding_inference_module.py ├── components │ ├── __init__.py │ ├── network_blocks │ │ ├── __init__.py │ │ └── hf_language_model.py │ ├── optimizer.py │ ├── scheduler.py │ ├── training_loop_functions.py │ ├── loss_functions.py │ └── distance_functions.py ├── data │ └── loading │ │ ├── __init__.py │ │ ├── components │ │ ├── __init__.py │ │ ├── custom_dataloader.py │ │ └── dataloading.py │ │ ├── datamodules │ │ └── __init__.py │ │ └── utils.py ├── utils │ ├── __init__.py │ ├── masking_utils.py │ ├── instantiators.py │ ├── pylogger.py │ ├── logging_utils.py │ ├── rich_utils.py │ ├── custom_hydra_resolvers.py │ ├── file_utils.py │ ├── decorators.py │ ├── launcher_utils.py │ ├── tensor_utils.py │ ├── restart_job_utils.py │ └── utils.py ├── inference.py └── train.py ├── configs ├── local │ └── .gitkeep ├── __init__.py ├── logger │ └── csv.yaml ├── callbacks │ ├── local_pickle_writer.yaml │ ├── model_summary.yaml │ ├── rich_progress_bar.yaml │ ├── default.yaml │ ├── early_stopping.yaml │ └── model_checkpoint.yaml ├── extras │ └── default.yaml ├── trainer │ ├── ddp.yaml │ └── default.yaml ├── hydra │ └── default.yaml ├── inference.yaml ├── paths │ └── default.yaml ├── train.yaml └── experiment │ ├── sem_embeds_inference_flat.yaml │ ├── rkmeans_inference_flat.yaml │ ├── tiger_inference_flat.yaml │ ├── rkmeans_train_flat.yaml │ └── rvq_train_flat.yaml ├── LICENSE ├── .gitignore └── README.md /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /logs/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /outputs/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/local/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/components/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/src/data/loading/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/components/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/modules/clustering/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/components/network_blocks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/data/loading/components/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/data/loading/datamodules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/modules/huggingface/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/modules/semantic_id/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | # this file is needed here to include configs when building project as a package 2 | -------------------------------------------------------------------------------- /configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | ## csv logger built in lightning 2 | 3 | csv: 4 | _target_: lightning.pytorch.loggers.csv_logs.CSVLogger 5 | save_dir: "${paths.output_dir}" 6 | name: "csv/" 7 | prefix: "" 8 | -------------------------------------------------------------------------------- /configs/callbacks/local_pickle_writer.yaml: -------------------------------------------------------------------------------- 1 | ## Inference callback to write predictions local pickle files 2 | pickle_writer: 3 | _target_: src.utils.inference_utils.LocalPickleWriter 4 | output_dir: ${paths.output_dir}/pickle 5 | flush_frequency: 100000 6 | write_interval: batch 7 | should_merge_files_on_main: True 8 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from src.utils.instantiators import instantiate_callbacks, instantiate_loggers 2 | from src.utils.logging_utils import finalize_loggers, log_hyperparameters 3 | from src.utils.pylogger import RankedLogger 4 | from src.utils.rich_utils import enforce_tags, print_config_tree 5 | from src.utils.utils import extras 6 | -------------------------------------------------------------------------------- /configs/extras/default.yaml: 
--------------------------------------------------------------------------------
1 | # disable python warnings if they annoy you
2 | ignore_warnings: False
3 | 
4 | # ask user for tags if none are provided in the config
5 | enforce_tags: True
6 | 
7 | # print warnings for missing configs
8 | print_config_warnings: True
9 | 
10 | # pretty print config tree at the start of the run using Rich library
11 | print_config: True
12 | 
--------------------------------------------------------------------------------
/configs/callbacks/model_summary.yaml:
--------------------------------------------------------------------------------
1 | ## Callback that prints a summary of the model to the console.
2 | ## Implements callback from https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html
3 | 
4 | model_summary:
5 |   _target_: lightning.pytorch.callbacks.RichModelSummary
6 |   max_depth: 1 # the maximum depth of layer nesting that the summary will include
7 | 
--------------------------------------------------------------------------------
/configs/callbacks/rich_progress_bar.yaml:
--------------------------------------------------------------------------------
1 | ## Callback to show a progress bar using the Rich library.
2 | ## Implements callback from https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html
3 | ## Does not work well for unbounded datasets.
4 | # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html
5 | 
6 | rich_progress_bar:
7 |   _target_: lightning.pytorch.callbacks.RichProgressBar
8 | 
--------------------------------------------------------------------------------
/configs/trainer/ddp.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - default
3 | 
4 | # If you want to pass parameters like the static_graph flag,
5 | # comment this one out and uncomment the one below.
6 | strategy: ddp
7 | 
8 | # This step can provide a 5% speedup for DDP if your model
9 | # works well with static graph. Uncomment the code below to enable it.
10 | # strategy:
11 | #   _target_: lightning.pytorch.strategies.DDPStrategy
12 | #   static_graph: True
13 | 
14 | accelerator: gpu
15 | devices: -1
16 | num_nodes: 1
17 | sync_batchnorm: True
18 | 
--------------------------------------------------------------------------------
/configs/callbacks/default.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - model_checkpoint
3 |   - early_stopping
4 |   - model_summary
5 |   # - rich_progress_bar # does not work well with unbounded datasets.
6 | - _self_ 7 | 8 | model_checkpoint: 9 | dirpath: ${paths.output_dir}/checkpoints 10 | filename: "checkpoint_{epoch:03d}_{step:06d}" 11 | monitor: "val/loss" 12 | mode: "min" 13 | auto_insert_metric_name: False 14 | every_n_train_steps: 5000 15 | 16 | early_stopping: 17 | monitor: "val/loss" 18 | patience: 100 19 | mode: "min" 20 | 21 | model_summary: 22 | max_depth: -1 23 | -------------------------------------------------------------------------------- /configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # https://hydra.cc/docs/configure_hydra/intro/ 2 | 3 | # enable color logging 4 | defaults: 5 | - override hydra_logging: colorlog 6 | - override job_logging: colorlog 7 | 8 | # output directory, generated dynamically on each run 9 | run: 10 | dir: ${paths.log_dir}/${task_name}/runs/${id} 11 | sweep: 12 | dir: ${paths.log_dir}/${task_name}/multiruns/${id} 13 | subdir: ${hydra.job.num} 14 | 15 | job_logging: 16 | handlers: 17 | file: 18 | # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242 19 | filename: ${hydra.runtime.output_dir}/${task_name}.log 20 | -------------------------------------------------------------------------------- /configs/inference.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - model: null # we need to have model before self to override the default model parameters that are not used in inference 5 | - _self_ 6 | - paths: default 7 | - logger: csv 8 | - trainer: default 9 | - hydra: default 10 | - data_loading: null 11 | - extras: default 12 | - experiment: null 13 | 14 | task_name: "inference" 15 | id: ${now:%Y-%m-%d}/${now:%H-%M-%S} 16 | tags: ["inference"] 17 | 18 | experiment: null 19 | 20 | model: 21 | loss_function: null 22 | optimizer: null 23 | scheduler: null 24 | evaluator: null 25 | 26 | callbacks: 27 | bq_writer: 28 | table_id: ??? 29 | 30 | 31 | 32 | # passing checkpoint path is necessary for inference 33 | ckpt_path: ??? 34 | -------------------------------------------------------------------------------- /src/models/components/network_blocks/normalize_layer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class NormalizeLayer(nn.Module): 6 | def __init__(self, dim=-1, p=2): 7 | """ 8 | Initialize the NormalizeLayer. 9 | 10 | This is a wrapper around `torch.nn.functional.normalize` that enables using it 11 | as an nn.Module. 12 | 13 | Args: 14 | dim (int): The dimension along which to normalize. Default is -1 (last dimension). 15 | p (float): The norm degree. Default is 2 (L2 normalization). 
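        Example (illustrative sketch; the tensor shape below is arbitrary):
            >>> import torch
            >>> layer = NormalizeLayer(dim=-1, p=2)
            >>> out = layer(torch.randn(4, 8))
            >>> # every row of `out` now has unit L2 norm, i.e. out.norm(p=2, dim=-1) is all ones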
16 |         """
17 |         super(NormalizeLayer, self).__init__()
18 |         self.dim = dim
19 |         self.p = p
20 | 
21 |     def forward(self, x):
22 |         return F.normalize(x, dim=self.dim, p=self.p)
23 | 
--------------------------------------------------------------------------------
/configs/trainer/default.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - _self_
3 | 
4 | _target_: lightning.pytorch.trainer.Trainer
5 | 
6 | default_root_dir: ${paths.output_dir}
7 | min_steps: 1
8 | # remember that if you use gradient accumulation, this number will only be updated after each accumulation
9 | # so for 40000 steps with gradient accumulation of 2, you will have 80000 batches
10 | max_steps: !!int 80000
11 | max_epochs: 10
12 | accelerator: cpu
13 | devices: 1
14 | num_nodes: 1
15 | # mixed precision for extra speed-up
16 | precision: bf16-mixed
17 | log_every_n_steps: 2500
18 | # perform a validation loop every N training batches
19 | val_check_interval: 5000
20 | 
21 | # set True to ensure deterministic results
22 | # makes training slower but gives more reproducibility than just setting seeds
23 | deterministic: False
24 | accumulate_grad_batches: 1
25 | 
--------------------------------------------------------------------------------
/configs/paths/default.yaml:
--------------------------------------------------------------------------------
1 | # path to root directory
2 | # this requires PROJECT_ROOT environment variable to exist
3 | # you can replace it with "." if you want the root to be the current working directory
4 | root_dir: .
5 | 
6 | # path to data directory
7 | data_dir: ???
8 | 
9 | # path to logging directory
10 | log_dir: ${paths.root_dir}/logs
11 | # path to output directory, created dynamically by hydra
12 | # path generation pattern is specified in `configs/hydra/default.yaml`
13 | # use it to store all files generated during the run, like ckpts and metrics
14 | output_dir: ${hydra:runtime.output_dir}
15 | 
16 | # path to working directory
17 | work_dir: ${hydra:runtime.cwd}
18 | 
19 | # We define profile_dir separately from output_dir because GCS paths
20 | # are not supported by the profiler. At the end of the run, we copy
21 | # the profiler output to the output_dir
22 | profile_dir: ${hydra:run.dir}/profile_output
23 | 
24 | # Path for metadata used to retry jobs if they fail.
25 | metadata_dir: ${paths.output_dir}/metadata
26 | 
--------------------------------------------------------------------------------
/src/components/optimizer.py:
--------------------------------------------------------------------------------
1 | from torch.optim import Optimizer
2 | 
3 | 
4 | class PassThroughOptimizer(Optimizer):
5 |     """
6 |     A dummy PyTorch optimizer that does nothing during the step() function.
7 | 
8 |     This can be used for testing purposes or for training Lightning modules with manual
9 |     parameter updates. In Lightning, we need to call `opt.step()` so that
10 |     trainer.global_step is incremented, even if we are doing manual optimization.
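    Example (illustrative sketch; assumes a LightningModule that sets
    `self.automatic_optimization = False` and updates its parameters manually):

        def configure_optimizers(self):
            return PassThroughOptimizer(self.parameters())

        def training_step(self, batch, batch_idx):
            ...  # apply manual parameter updates here
            opt = self.optimizers()
            opt.step()  # no-op update, but lets Lightning advance trainer.global_step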
11 | """ 12 | 13 | def __init__(self, params, lr=0.01): 14 | defaults = dict(lr=lr) 15 | super(PassThroughOptimizer, self).__init__(params, defaults) 16 | 17 | def step(self, closure=None): 18 | return None 19 | 20 | def zero_grad(self): 21 | for group in self.param_groups: 22 | for p in group["params"]: 23 | if p.grad is not None: 24 | p.grad.detach_() 25 | p.grad.zero_() 26 | 27 | def state_dict(self): 28 | # needed for lightning checkpointing 29 | return {} 30 | 31 | def load_state_dict(self, state_dict): 32 | # needed for lightning checkpointing 33 | pass 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2025 Snap Inc. All rights reserved. 2 | 3 | This sample code is made available by Snap Inc. for non-commercial, research purposes only. 4 | 5 | Non-commercial means not primarily intended for or directed towards commercial advantage or monetary compensation. Research purposes mean solely for study, instruction, or non-commercial research, testing or validation. 6 | 7 | No commercial license, whether implied or otherwise, is granted in or to this code, unless you have entered into a separate agreement with Snap Inc. for such rights. 8 | 9 | This sample code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, title, fitness for a particular purpose, non-infringement, or that the code is free of defects, errors or viruses. In no event will Snap Inc. be liable for any damages or losses of any kind arising from this sample code or your use thereof. 10 | 11 | Any redistribution of this sample code, including in binary form, must retain or reproduce this license text including all copyright notices, conditions and disclaimers. 12 | 13 | Please see notices.txt for attribution notices for third-party software that may be included in portions of this sample code. -------------------------------------------------------------------------------- /configs/callbacks/early_stopping.yaml: -------------------------------------------------------------------------------- 1 | ## EarlyStopping is used to stop the training if the monitored quantity does not improve. 2 | ## implements callback from https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html 3 | 4 | early_stopping: 5 | _target_: lightning.pytorch.callbacks.EarlyStopping 6 | monitor: ??? # quantity to be monitored, must be specified !!! 7 | min_delta: 0. 
# minimum change in the monitored quantity to qualify as an improvement 8 | patience: 3 # number of checks with no improvement after which training will be stopped 9 | verbose: False # verbosity mode 10 | mode: "min" # "max" means higher metric value is better, can be also "min" 11 | strict: True # whether to crash the training if monitor is not found in the validation metrics 12 | check_finite: True # when set True, stops training when the monitor becomes NaN or infinite 13 | stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold 14 | divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold 15 | check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch 16 | # log_rank_zero_only: False # this keyword argument isn't available in stable version 17 | -------------------------------------------------------------------------------- /configs/callbacks/model_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | ## Callback to save the model with the best score. 2 | ## Implements callback from https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html 3 | 4 | model_checkpoint: 5 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 6 | dirpath: null # directory to save the model file 7 | filename: null # checkpoint filename 8 | monitor: null # name of the logged metric which determines when model is improving 9 | verbose: True # verbosity mode 10 | save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt 11 | save_top_k: 1 # save k best models (determined by above metric) 12 | mode: "min" # "max" means higher metric value is better, can be also "min" 13 | auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name 14 | save_weights_only: False # if True, then only the model’s weights will be saved 15 | every_n_train_steps: null # number of training steps between checkpoints 16 | train_time_interval: null # checkpoints are monitored at the specified time interval 17 | every_n_epochs: null # number of epochs between checkpoints 18 | save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation 19 | -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | ## Main Training configurations. 4 | 5 | # specify here default configuration 6 | # order of defaults determines the order in which configs override each other 7 | defaults: 8 | - _self_ 9 | - data_loading: null 10 | - model: null 11 | - callbacks: default 12 | - logger: csv # set logger here or use command line (e.g. `python train.py logger=tensorboard`) 13 | - trainer: default 14 | - paths: default 15 | - extras: default 16 | - hydra: default 17 | - loss: null 18 | - optim: null 19 | - eval: null 20 | 21 | # experiment configs allow for version control of specific hyperparameters 22 | # e.g. 
best hyperparameters for given model and datamodule
23 |   - experiment: null
24 | 
25 | # task name, determines output directory path
26 | task_name: "train"
27 | 
28 | id: ${now:%Y-%m-%d}/${now:%H-%M-%S}
29 | # tags to help you identify your experiments
30 | # you can overwrite this in experiment configs
31 | # overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
32 | tags: ["dev"]
33 | 
34 | # set False to skip model training
35 | train: True
36 | 
37 | # evaluate on test set, using best model weights achieved during training
38 | # lightning chooses best weights based on the metric specified in checkpoint callback
39 | test: True
40 | 
41 | # simply provide checkpoint path to resume training
42 | ckpt_path: null
43 | 
44 | # seed for random number generators in pytorch, numpy and python.random
45 | seed: null
46 | 
--------------------------------------------------------------------------------
/src/utils/masking_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | 
3 | import torch
4 | 
5 | def create_last_k_mask(
6 |     sequence_length: int, last_item_index: torch.Tensor, last_k: Optional[int] = None
7 | ) -> torch.Tensor:
8 |     """
9 |     Creates a mask to select the last K items of sequences.
10 |     If a sequence has fewer than K items, all items are considered for the row.
11 |     If last_k is None, all items are considered for all rows.
12 | 
13 |     Args:
14 |         sequence_length (int): The length of the sequences.
15 |         last_item_index (torch.Tensor): Shape (batch_size,).
16 |             The tensor containing the index of the last item in each row.
17 |         last_k (Optional[int]): The number of last K items to consider.
18 |             If None, all items are considered.
19 |     Returns:
20 |         torch.Tensor: A boolean tensor of shape (batch_size, sequence_length) with
21 |             True for the last K items in each row and False for the rest.
22 |     """
23 | 
24 |     if last_k is None:
25 |         start_index = torch.zeros_like(last_item_index)
26 |     else:
27 |         if last_k < 1:
28 |             raise ValueError("last_k must be None or greater than or equal to 1")
29 |         start_index = torch.clamp(
30 |             last_item_index - last_k + 1, min=0
31 |         )  # Shape (batch_size,)
32 | 
33 |     indices = (
34 |         torch.arange(sequence_length, device=last_item_index.device)
35 |         .unsqueeze(0)
36 |         .expand(last_item_index.size(0), -1)
37 |     )  # shape (batch_size, sequence_length)
38 | 
39 |     mask = (indices >= start_index.unsqueeze(1)) & (
40 |         indices <= last_item_index.unsqueeze(1)
41 |     )  # Shape (batch_size, sequence_length)
42 |     return mask
43 | 
--------------------------------------------------------------------------------
/src/inference.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict
2 | 
3 | import hydra
4 | import rootutils
5 | from omegaconf import DictConfig
6 | 
7 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
8 | 
9 | from src.utils import RankedLogger, extras
10 | from src.utils.custom_hydra_resolvers import *
11 | from src.utils.launcher_utils import pipeline_launcher
12 | 
13 | command_line_logger = RankedLogger(__name__, rank_zero_only=True)
14 | 
15 | 
16 | def inference(cfg: DictConfig) -> Dict[str, Any]:
17 |     """Runs inference using a pre-trained model.
18 | 
19 |     :param cfg: A DictConfig configuration composed by Hydra.
20 |     :return: A dict with all instantiated objects.
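    Example launch command (illustrative; the experiment name comes from the configs
    listed under configs/experiment and the checkpoint path is a placeholder):

        python src/inference.py experiment=sem_embeds_inference_flat ckpt_path=/path/to/model.ckpt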
21 |     """
22 | 
23 |     with pipeline_launcher(cfg) as pipeline_modules:
24 |         command_line_logger.info("Starting inference!")
25 |         ckpt_path = pipeline_modules.cfg.get("ckpt_path", None)
26 |         if not ckpt_path:
27 |             command_line_logger.warning(
28 |                 "No ckpt_path was provided. If using a model you trained, this is mandatory. Only leave ckpt_path=None if using a pre-trained model."
29 |             )
30 | 
31 |         pipeline_modules.trainer.predict(
32 |             model=pipeline_modules.model,
33 |             datamodule=pipeline_modules.datamodule,
34 |             ckpt_path=ckpt_path,
35 |             return_predictions=False,
36 |         )
37 | 
38 | 
39 | @hydra.main(version_base="1.3", config_path="../configs", config_name="inference.yaml")
40 | def main(cfg: DictConfig) -> None:
41 |     """Main entry point for inference.
42 | 
43 |     :param cfg: DictConfig configuration composed by Hydra.
44 |     """
45 |     # apply extra utilities
46 |     extras(cfg)
47 | 
48 |     # run inference
49 |     inference(cfg)
50 | 
51 | 
52 | if __name__ == "__main__":
53 |     main()
54 | 
--------------------------------------------------------------------------------
/src/models/components/network_blocks/mlp.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 | 
3 | import torch
4 | from torch import nn
5 | 
6 | 
7 | class MLP(nn.Module):
8 |     """A simple fully-connected neural net for computing predictions."""
9 | 
10 |     def __init__(
11 |         self,
12 |         input_dim: int,
13 |         output_dim: int,
14 |         hidden_dim_list: Optional[List[int]] = None,
15 |         activation: nn.Module = nn.ReLU,
16 |         bias: bool = True,
17 |         dropout: float = 0.0,
18 |     ) -> None:
19 |         """Initialize the MLP.
20 | 
21 |         Args:
22 |             input_dim (int): The dimensionality of the input tensor.
23 |             output_dim (int): The dimensionality of the output tensor.
24 |             hidden_dim_list (Optional[List[int]]): A list of the dimensions of each hidden
25 |                 layer output. The number of layers in the MLP is the length of this list
26 |                 plus one.
27 |             activation (nn.Module): The activation function to use between layers.
28 |             bias (bool): Whether to include bias terms in the linear layers.
29 |             dropout (float): The dropout rate to apply after each layer.
30 |         """
31 |         super().__init__()
32 | 
33 |         if hidden_dim_list is None:
34 |             hidden_dim_list = []
35 |         hidden_dim_list = hidden_dim_list + [output_dim]  # copy to avoid mutating the caller's list
36 |         layers = [nn.Linear(input_dim, hidden_dim_list[0], bias=bias)]
37 |         for i in range(1, len(hidden_dim_list)):
38 |             layers.append(activation())
39 |             layers.append(
40 |                 nn.Linear(hidden_dim_list[i - 1], hidden_dim_list[i], bias=bias)
41 |             )
42 |             layers.append(nn.Dropout(dropout))
43 |         self.output_dim = output_dim
44 |         self.input_dim = input_dim
45 |         self.model = nn.Sequential(*layers)
46 | 
47 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
48 |         return self.model(x)
49 | 
--------------------------------------------------------------------------------
/src/models/components/network_blocks/embedding_aggregator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | from src.models.components.network_blocks.aggregation_strategy import (
5 |     AggregationStrategy,
6 | )
7 | 
8 | 
9 | class EmbeddingAggregator(nn.Module):
10 |     """Embedding aggregator. This determines how user embeddings are aggregated to form the final user embedding.
11 | 
12 |     Parameters
13 |     ----------
14 |     aggregation_strategy: AggregationStrategy
15 |         The strategy used to aggregate the sequence of embeddings into a single embedding.
16 |     """
17 | 
18 |     def __init__(
19 |         self,
20 |         aggregation_strategy: AggregationStrategy,
21 |     ):
22 |         super(EmbeddingAggregator, self).__init__()
23 |         self.aggregation_strategy = aggregation_strategy
24 | 
25 |     def forward(
26 |         self,
27 |         embeddings: torch.Tensor,
28 |         attention_mask: torch.Tensor,
29 |     ) -> torch.Tensor:
30 |         # embeddings: (batch_size, sequence_length, embedding_dim)
31 |         # attention_mask: (batch_size, sequence_length)
32 | 
33 |         # we subtract 1 because the token index starts from 0
34 |         last_item_index = attention_mask.sum(dim=1) - 1
35 | 
36 |         # The following 3 steps are equivalent to
37 |         # row_ids = torch.arange(embeddings.size(0))
38 |         # but in a way that is traceable with Fx.
39 | 
40 |         # 1. Create a dummy tensor with the same batch shape as attention_mask
41 |         dummy_tensor_for_batch_shape = attention_mask[:, 0]  # Shape (batch_size,)
42 | 
43 |         # 2. Use torch.ones_like to create a tensor of ones with that shape.
44 |         # Note that torch.ones is not traceable in Fx, so we use torch.ones_like.
45 |         ones_tensor = torch.ones_like(dummy_tensor_for_batch_shape, dtype=torch.long)
46 | 
47 |         # 3. Use cumsum to get the 0 to batch_size - 1 sequence
48 |         row_ids = torch.cumsum(ones_tensor, dim=0) - 1
49 | 
50 |         return self.aggregation_strategy.aggregate(embeddings, row_ids, last_item_index)
51 | 
--------------------------------------------------------------------------------
/src/utils/instantiators.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | 
3 | import hydra
4 | from lightning import Callback
5 | from lightning.pytorch.loggers import Logger
6 | from omegaconf import DictConfig
7 | 
8 | from src.utils import logging_utils, pylogger
9 | 
10 | log = pylogger.RankedLogger(__name__, rank_zero_only=True)
11 | 
12 | 
13 | def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
14 |     """Instantiates callbacks from config.
15 | 
16 |     :param callbacks_cfg: A DictConfig object containing callback configurations.
17 |     :return: A list of instantiated callbacks.
18 |     """
19 |     callbacks: List[Callback] = []
20 | 
21 |     if not callbacks_cfg:
22 |         log.warning("No callback configs found! Skipping...")
23 |         return callbacks
24 | 
25 |     if not isinstance(callbacks_cfg, DictConfig):
26 |         raise TypeError("Callbacks config must be a DictConfig!")
27 | 
28 |     for _, cb_conf in callbacks_cfg.items():
29 |         if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
30 |             log.info(f"Instantiating callback <{cb_conf._target_}>")
31 |             callbacks.append(hydra.utils.instantiate(cb_conf))
32 | 
33 |     return callbacks
34 | 
35 | 
36 | def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
37 |     """Instantiates loggers from config.
38 | 
39 |     :param logger_cfg: A DictConfig object containing logger configurations.
40 |     :return: A list of instantiated loggers.
41 |     """
42 |     logger: List[Logger] = []
43 | 
44 |     if not logger_cfg:
45 |         log.warning("No logger configs found!
Skipping...") 46 | return logger 47 | 48 | if not isinstance(logger_cfg, DictConfig): 49 | raise TypeError("Logger config must be a DictConfig!") 50 | 51 | for name, lg_conf in logger_cfg.items(): 52 | if name == "wandb": 53 | log.info("Authenticating to W&B!") 54 | logging_utils.login_wandb() 55 | 56 | if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: 57 | log.info(f"Instantiating logger <{lg_conf._target_}>") 58 | logger.append(hydra.utils.instantiate(lg_conf)) 59 | 60 | return logger 61 | -------------------------------------------------------------------------------- /src/data/loading/components/custom_dataloader.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from torch.utils.data import _utils 4 | from torch.utils.data.dataloader import ( 5 | DataLoader, 6 | _BaseDataLoaderIter, 7 | _MultiProcessingDataLoaderIter, 8 | _SingleProcessDataLoaderIter, 9 | ) 10 | 11 | from src.utils.pylogger import RankedLogger 12 | 13 | command_line_logger = RankedLogger(__name__, rank_zero_only=False) 14 | 15 | 16 | class _MultiProcessingDataLoaderIterWithRetry(_MultiProcessingDataLoaderIter): 17 | def __init__(self, loader, max_retries=3): 18 | super().__init__(loader) 19 | self._max_retries = max_retries 20 | 21 | def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL): 22 | # Note that the max wait time will be timeout * (max_retries + 1) due to original call + all retries. 23 | # This won't fix issues related to corrupted data, but can help with network issues. 24 | retries = 0 25 | for _ in range(self._max_retries): 26 | start_time = time.monotonic() 27 | status, data = super()._try_get_data(timeout) 28 | if status: 29 | return (True, data) 30 | end_time = time.monotonic() 31 | # If it took less than the timeout time, it means the issue was an empty queue. 32 | # If that happens or if we have reached our last retry, we return False. Else, we retry. 33 | # We need this because original _try_get_data will return False if the queue is empty or 34 | # if the timeout is reached. We need a way to differentiate and only retry if there is a timeout. 35 | if end_time - start_time < timeout or retries + 1 == self._max_retries: 36 | return (False, None) 37 | command_line_logger.warning( 38 | f"Retrying after timeout... 
Retry {retries + 1}/{self._max_retries}"
39 |             )
40 |             retries += 1
41 | 
42 | 
43 | class DataloaderWithIterationRetry(DataLoader):
44 |     def __init__(self, *args, max_retries=3, **kwargs):  # keyword-only, so positional DataLoader args (e.g. the dataset) pass straight through to the base class
45 |         super().__init__(*args, **kwargs)
46 |         self._max_retries = max_retries
47 | 
48 |     def _get_iterator(self) -> "_BaseDataLoaderIter":
49 |         if self.num_workers == 0:
50 |             return _SingleProcessDataLoaderIter(self)
51 |         else:
52 |             self.check_worker_number_rationality()
53 |             return _MultiProcessingDataLoaderIterWithRetry(
54 |                 self, max_retries=self._max_retries
55 |             )
56 | 
--------------------------------------------------------------------------------
/src/utils/pylogger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Mapping, Optional
3 | 
4 | from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
5 | 
6 | 
7 | class RankedLogger(logging.LoggerAdapter):
8 |     """A multi-GPU-friendly python command line logger."""
9 | 
10 |     def __init__(
11 |         self,
12 |         name: str = __name__,
13 |         rank_zero_only: bool = False,
14 |         extra: Optional[Mapping[str, object]] = None,
15 |     ) -> None:
16 |         """Initializes a multi-GPU-friendly python command line logger that logs on all processes
17 |         with their rank prefixed in the log message.
18 | 
19 |         :param name: The name of the logger. Default is ``__name__``.
20 |         :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
21 |         :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
22 |         """
23 |         logger = logging.getLogger(name)
24 |         super().__init__(logger=logger, extra=extra)
25 |         self.rank_zero_only = rank_zero_only
26 | 
27 |     def log(
28 |         self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs
29 |     ) -> None:
30 |         """Delegate a log call to the underlying logger, after prefixing its message with the rank
31 |         of the process it's being logged from. If `'rank'` is provided, then the log will only
32 |         occur on that rank/process.
33 | 
34 |         :param level: The level to log at. Look at `logging.__init__.py` for more information.
35 |         :param msg: The message to log.
36 |         :param rank: The rank to log at.
37 |         :param args: Additional args to pass to the underlying logging function.
38 |         :param kwargs: Any additional keyword args to pass to the underlying logging function.
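        Example (illustrative usage; assumes the process rank has already been set by
        Lightning / `lightning_utilities` before logging):

            >>> log = RankedLogger(__name__, rank_zero_only=False)
            >>> log.info("visible on every rank")
            >>> log.info("visible on rank 1 only", rank=1)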
39 |         """
40 |         if self.isEnabledFor(level):
41 |             msg, kwargs = self.process(msg, kwargs)
42 |             current_rank = getattr(rank_zero_only, "rank", None)
43 |             if current_rank is None:
44 |                 raise RuntimeError(
45 |                     "The `rank_zero_only.rank` needs to be set before use"
46 |                 )
47 |             msg = rank_prefixed_message(msg, current_rank)
48 |             if self.rank_zero_only:
49 |                 if current_rank == 0:
50 |                     self.logger.log(level, msg, *args, **kwargs)
51 |             else:
52 |                 if rank is None:
53 |                     self.logger.log(level, msg, *args, **kwargs)
54 |                 elif current_rank == rank:
55 |                     self.logger.log(level, msg, *args, **kwargs)
56 | 
--------------------------------------------------------------------------------
/src/components/network_blocks/hf_language_model.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Union
2 | 
3 | import torch
4 | import torch.nn as nn
5 | from transformers import PreTrainedModel
6 | from transformers.modeling_outputs import BaseModelOutput
7 | 
8 | from src.models.components.network_blocks.embedding_aggregator import (
9 |     EmbeddingAggregator,
10 | )
11 | 
12 | 
13 | class HFLanguageModel(nn.Module):
14 |     def __init__(
15 |         self,
16 |         huggingface_model: PreTrainedModel,
17 |         aggregator: EmbeddingAggregator,
18 |         postprocessor: nn.Module = nn.Identity(),
19 |         return_last_hidden_states: bool = False,
20 |     ):
21 |         """Initialize the HuggingFace language model.
22 | 
23 |         This is a wrapper around a HuggingFace PreTrainedModel that generates text
24 |         sequence embeddings by passing the last hidden states of the PreTrainedModel
25 |         through an aggregator and postprocessor.
26 | 
27 |         Args:
28 |             huggingface_model: HuggingFace model to use for language modeling
29 |             aggregator: Aggregator to use to aggregate the embeddings
30 |             postprocessor: Postprocessor to use to process the aggregated embeddings
31 |             return_last_hidden_states: Whether to return the last hidden states
32 |         """
33 |         super(HFLanguageModel, self).__init__()
34 |         self.huggingface_model = huggingface_model
35 |         self.aggregator = aggregator
36 |         self.postprocessor = postprocessor
37 |         self.return_last_hidden_states = return_last_hidden_states
38 | 
39 |     def forward(
40 |         self, input_ids: torch.Tensor, attention_mask: torch.Tensor
41 |     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
42 |         """Forward pass of the HuggingFace language model.
43 | 
44 |         Args:
45 |             input_ids: Tensor of token ids.
46 |             attention_mask: Tensor of attention masks.
47 | 
48 |         Returns:
49 |             postprocessed_embeddings: Postprocessed embeddings.
50 |             embeddings: Last hidden states if return_last_hidden_states is True.
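        Example (illustrative sketch; the checkpoint name is an arbitrary public model and
        the mean aggregator is just one possible choice):

            >>> from transformers import AutoModel, AutoTokenizer
            >>> from src.models.components.network_blocks.aggregation_strategy import MeanAggregation
            >>> backbone = AutoModel.from_pretrained("distilbert-base-uncased")
            >>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
            >>> model = HFLanguageModel(backbone, aggregator=EmbeddingAggregator(MeanAggregation()))
            >>> batch = tokenizer(["hello world"], return_tensors="pt")
            >>> embedding = model(batch["input_ids"], batch["attention_mask"])  # shape (1, hidden_dim)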
51 |         """
52 |         # TODO(lcollins2): Generalize this to handle other types of inputs
53 |         outputs: BaseModelOutput = self.huggingface_model(
54 |             input_ids=input_ids, attention_mask=attention_mask
55 |         )
56 |         embeddings = outputs.last_hidden_state
57 |         aggregated_embeddings = self.aggregator(embeddings, attention_mask)
58 |         postprocessed_embeddings = self.postprocessor(aggregated_embeddings)
59 |         if self.return_last_hidden_states:
60 |             return postprocessed_embeddings, embeddings
61 |         return postprocessed_embeddings
62 | 
--------------------------------------------------------------------------------
/src/components/scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | 
3 | import torch
4 | 
5 | 
6 | class WarmupCosineSchedulerNonzeroMin(torch.optim.lr_scheduler.LambdaLR):
7 |     """Cosine schedule with warmup and decay to a possibly nonzero min learning rate."""
8 | 
9 |     def __init__(
10 |         self,
11 |         optimizer: torch.optim.Optimizer,
12 |         warmup_steps: int,
13 |         scheduler_steps: int,
14 |         min_ratio: float = 0.1,
15 |         num_cycles: float = 0.5,
16 |         last_epoch: int = -1,
17 |         **kwargs,
18 |     ):
19 |         """
20 |         Create a schedule with a learning rate that decreases following the values of
21 |         the cosine function from the initial lr set in the optimizer down to
22 |         `min_ratio` times that lr, after a warmup period during which it increases
23 |         linearly from 0 to the initial lr set in the optimizer.
24 | 
25 |         Parameters
26 |         ----------
27 |         optimizer: torch.optim.Optimizer
28 |             The optimizer for which to schedule the learning rate.
29 |         warmup_steps: int
30 |             The number of steps for the warmup phase.
31 |         scheduler_steps: int
32 |             The total number of training steps.
33 |         min_ratio: float
34 |             Minimum learning rate divided by initial learning rate.
35 |         num_cycles: float
36 |             The number of waves in the cosine schedule (the default is to just decrease
37 |             from the max value to the minimum following a half-cosine).
38 |         last_epoch (`int`, *optional*, defaults to -1):
39 |             The index of the last epoch when resuming training.
40 |         """
41 |         self.warmup_steps = warmup_steps
42 |         self.scheduler_steps = scheduler_steps
43 |         self.min_ratio = min_ratio
44 |         self.num_cycles = num_cycles
45 |         super(WarmupCosineSchedulerNonzeroMin, self).__init__(
46 |             optimizer, self.lr_lambda, last_epoch=last_epoch
47 |         )
48 | 
49 |     def lr_lambda(
50 |         self,
51 |         step: int,
52 |     ) -> float:
53 |         """Return the factor to multiply the initial learning rate with."""
54 |         if step < self.warmup_steps:  # linear warm-up
55 |             return float(step) / float(max(1, self.warmup_steps))
56 |         if step <= self.scheduler_steps:
57 |             # cosine decay
58 |             decay_ratio = float(step - self.warmup_steps) / float(
59 |                 max(1, self.scheduler_steps - self.warmup_steps)
60 |             )
61 |             coeff = 0.5 * (
62 |                 1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * decay_ratio)
63 |             )
64 |             return max(self.min_ratio, self.min_ratio + coeff * (1 - self.min_ratio))
65 |         else:  # current_step > self.scheduler_steps
66 |             return self.min_ratio
67 | 
--------------------------------------------------------------------------------
/src/models/components/interfaces.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 | 
3 | import torch
4 | 
5 | 
6 | class ModelOutput:
7 |     def __init__(self, *args, **kwargs):
8 |         raise NotImplementedError
9 | 
10 |     @property
11 |     def list_of_row_format(self):
12 |         """
13 |         Function used to convert the predictions into a format that can be written to BigQuery.
14 | """ 15 | raise NotImplementedError 16 | 17 | def _convert_to_list(self, prediction: Union[torch.Tensor, List]) -> List: 18 | """ 19 | Convert the prediction to a list so it can be serialized. 20 | """ 21 | if isinstance(prediction, torch.Tensor): 22 | return prediction.detach().cpu().tolist() 23 | 24 | return prediction 25 | 26 | 27 | class SharedKeyAcrossPredictionsOutput(ModelOutput): 28 | """ 29 | A class to represent the output of a model with a single key for all predictions in a batch. 30 | 31 | Attributes: 32 | key: The single key associated with all the predictions. 33 | predictions: The predictions made by the model. 34 | key_name (str): The name of the key attribute. Default is "idx". 35 | prediction_name (str): The name of the prediction attribute. Default is "prediction". 36 | """ 37 | 38 | def __init__( 39 | self, 40 | key, 41 | predictions, 42 | key_name: str = "idx", 43 | prediction_name: str = "prediction", 44 | ): 45 | self.key = key 46 | self.predictions = predictions 47 | self.key_name = key_name 48 | self.prediction_name = prediction_name 49 | 50 | @property 51 | def list_of_row_format(self): 52 | return [ 53 | {self.key_name: self.key, self.prediction_name: pred} 54 | for pred in self._convert_to_list(self.predictions) 55 | ] 56 | 57 | 58 | class OneKeyPerPredictionOutput(ModelOutput): 59 | """ 60 | A class used to represent the output of a model where each prediction is associated with a unique key. 61 | Attributes 62 | ---------- 63 | keys : Any 64 | The keys associated with each prediction. 65 | predictions : Any 66 | The predictions made by the model. 67 | key_name : str, optional 68 | The name to be used for the key in the output dictionary (default is "idx"). 69 | prediction_name : str, optional 70 | The name to be used for the prediction in the output dictionary (default is "prediction"). 71 | """ 72 | 73 | def __init__( 74 | self, 75 | keys, 76 | predictions, 77 | key_name: str = "idx", 78 | prediction_name: str = "prediction", 79 | ): 80 | self.keys = keys 81 | self.predictions = predictions 82 | self.key_name = key_name 83 | self.prediction_name = prediction_name 84 | 85 | @property 86 | def list_of_row_format(self): 87 | return [ 88 | {self.key_name: key, self.prediction_name: pred} 89 | for key, pred in zip( 90 | self._convert_to_list(self.keys), 91 | self._convert_to_list(self.predictions), 92 | ) 93 | ] -------------------------------------------------------------------------------- /src/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from importlib.util import find_spec 4 | from typing import Any, Dict 5 | 6 | from dotenv import load_dotenv 7 | from lightning import Callback, LightningDataModule, LightningModule, Trainer 8 | from lightning_utilities.core.rank_zero import rank_zero_only 9 | from omegaconf import DictConfig, OmegaConf 10 | 11 | from src.utils import pylogger 12 | 13 | log = pylogger.RankedLogger(__name__, rank_zero_only=True) 14 | 15 | # logging constants 16 | END_RUN = "end_run" 17 | 18 | 19 | def convert_dict_to_json_string(data: dict) -> str: 20 | return json.dumps(data, indent=4) 21 | 22 | 23 | @rank_zero_only 24 | def login_wandb(): 25 | """ 26 | If WANDB_API_KEY is set in the environment, login to wandb. 
27 | """ 28 | # Load environment variables from .env file 29 | load_dotenv() 30 | 31 | # Now you can access the WANDB_API_KEY 32 | wandb_api_key = os.getenv("WANDB_API_KEY") 33 | if wandb_api_key: 34 | import wandb 35 | 36 | wandb.login(key=wandb_api_key, relogin=True) 37 | 38 | 39 | @rank_zero_only 40 | def finalize_loggers(trainer: Any, status=END_RUN) -> None: 41 | """ 42 | Finalize loggers after training is done. 43 | 44 | :param trainer: The Lightning trainer. 45 | """ 46 | [ 47 | logger.finalize(status) 48 | for logger in trainer.loggers 49 | if hasattr(logger, "finalize") 50 | ] 51 | 52 | if find_spec( 53 | "wandb" 54 | ): # check if wandb is installed. If so, close connection to wandb. 55 | import wandb 56 | 57 | if wandb.run: 58 | log.info("Closing wandb!") 59 | wandb.finish() 60 | 61 | 62 | @rank_zero_only 63 | def log_hyperparameters( 64 | cfg: DictConfig, model: LightningModule, trainer: Trainer 65 | ) -> None: 66 | """Controls which config parts are saved by Lightning loggers. 67 | 68 | Additionally saves: 69 | - Number of model parameters 70 | 71 | :param object_dict: A dictionary containing the following objects: 72 | - `"cfg"`: A DictConfig object containing the main config. 73 | - `"model"`: The Lightning model. 74 | - `"trainer"`: The Lightning trainer. 75 | """ 76 | hparams = {} 77 | # We resolve the configs to get the actual paths for logging. 78 | cfg = OmegaConf.to_container(cfg, resolve=True) 79 | 80 | if not trainer.logger: 81 | log.warning("Logger not found! Skipping hyperparameter logging...") 82 | return 83 | 84 | hparams["paths"] = cfg["paths"] 85 | hparams["model"] = cfg["model"] 86 | 87 | # save number of model parameters 88 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) 89 | hparams["model/params/trainable"] = sum( 90 | p.numel() for p in model.parameters() if p.requires_grad 91 | ) 92 | hparams["model/params/non_trainable"] = sum( 93 | p.numel() for p in model.parameters() if not p.requires_grad 94 | ) 95 | 96 | hparams["data_loading"] = cfg["data_loading"] 97 | hparams["trainer"] = cfg["trainer"] 98 | 99 | hparams["callbacks"] = cfg.get("callbacks") 100 | hparams["extras"] = cfg.get("extras") 101 | 102 | hparams["task_name"] = cfg.get("task_name") 103 | hparams["tags"] = cfg.get("tags") 104 | hparams["ckpt_path"] = cfg.get("ckpt_path") 105 | hparams["seed"] = cfg.get("seed") 106 | 107 | # send hparams to all loggers 108 | for logger in trainer.loggers: 109 | logger.log_hyperparams(hparams) 110 | -------------------------------------------------------------------------------- /src/models/components/network_blocks/aggregation_strategy.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Optional 3 | 4 | import torch 5 | 6 | from src.utils.masking_utils import create_last_k_mask 7 | 8 | 9 | class AggregationStrategy(ABC): 10 | @abstractmethod 11 | def aggregate( 12 | self, 13 | embeddings: torch.Tensor, 14 | row_ids: torch.Tensor, 15 | last_item_index: torch.Tensor, 16 | ) -> torch.Tensor: 17 | pass 18 | 19 | 20 | class MeanAggregation(AggregationStrategy): 21 | """ 22 | Aggregates the embeddings by computing their mean. If last_k is specified, only the last K embeddings are considered. 23 | """ 24 | 25 | def __init__(self, last_k: Optional[int] = None): 26 | """ 27 | Initializes the MeanAggregation class with the specified number of last embeddings to consider. 
28 | 
29 |         Args:
30 |             last_k (Optional[int]):
31 |                 The number of last K embeddings to consider for aggregation. If None, all embeddings are considered.
32 |         """
33 |         self.last_k = last_k
34 | 
35 |     def aggregate(
36 |         self,
37 |         embeddings: torch.Tensor,
38 |         row_ids: torch.Tensor,
39 |         last_item_index: torch.Tensor,
40 |     ) -> torch.Tensor:
41 |         """
42 |         Aggregates the last K embeddings for each row by computing their mean.
43 | 
44 |         Args:
45 |             embeddings (torch.Tensor): Shape (batch_size, sequence_length, embedding_dim).
46 |                 The tensor containing embeddings for each row.
47 |             row_ids (torch.Tensor): Shape (return_size,).
48 |                 The tensor containing row ids for which the aggregated embedding has to be returned.
49 |             last_item_index (torch.Tensor): Shape (return_size,).
50 |                 The tensor containing the indices of the last items in embeddings for each row in row_ids.
51 | 
52 |         Returns:
53 |             torch.Tensor: The aggregated embeddings of shape (return_size, embedding_dim).
54 |         """
55 |         # Select the embeddings for the specified row ids
56 |         embeddings = embeddings[
57 |             row_ids
58 |         ]  # Shape (return_size, sequence_length, embedding_dim)
59 |         # Create a mask to select the last K items of sequences
60 |         mask = create_last_k_mask(embeddings.size(1), last_item_index, self.last_k)
61 |         mask = mask.to(dtype=embeddings.dtype, device=embeddings.device)
62 | 
63 |         # Apply the mask to the embeddings
64 |         masked_embeddings = embeddings * mask.unsqueeze(
65 |             2
66 |         )  # Shape (return_size, sequence_length, embedding_dim)
67 | 
68 |         # Sum the masked embeddings and divide by the count of non-zero elements in the mask
69 |         sum_embeddings = torch.sum(
70 |             masked_embeddings, dim=1
71 |         )  # Shape (return_size, embedding_dim)
72 |         count = (
73 |             torch.sum(mask, dim=1).clamp(min=1).unsqueeze(1)
74 |         )  # Shape (return_size, 1)
75 |         return sum_embeddings / count  # Shape (return_size, embedding_dim)
76 | 
77 | 
78 | class LastAggregation(AggregationStrategy):
79 |     def aggregate(
80 |         self,
81 |         embeddings: torch.Tensor,
82 |         row_ids: torch.Tensor,
83 |         last_item_index: torch.Tensor,
84 |     ) -> torch.Tensor:
85 |         return embeddings[row_ids, last_item_index]
86 | 
87 | 
88 | class FirstAggregation(AggregationStrategy):
89 |     def aggregate(
90 |         self,
91 |         embeddings: torch.Tensor,
92 |         row_ids: torch.Tensor,
93 |         last_item_index: torch.Tensor,
94 |     ) -> torch.Tensor:
95 |         # Return the first item in each sequence, assuming sequences are right-padded
96 |         # TODO(liam): allow all aggregation strategies to handle left padding
97 |         return embeddings[row_ids, 0]
98 | 
--------------------------------------------------------------------------------
/src/modules/semantic_embedding_inference_module.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional, Union
2 | 
3 | import torch
4 | import transformers
5 | from lightning import LightningModule
6 | from torch import nn
7 | 
8 | from src.data.loading.components.interfaces import ItemData
9 | from src.models.components.interfaces import OneKeyPerPredictionOutput
10 | 
11 | 
12 | class SemanticEmbeddingInferenceModule(LightningModule):
13 |     def __init__(
14 |         self,
15 |         semantic_embedding_model: Union[nn.Module, transformers.PreTrainedModel],
16 |         semantic_embedding_model_input_map: Optional[Dict[str, str]] = None,
17 |         **kwargs,
18 |     ) -> None:
19 |         """
20 |         Initialize the SemanticEmbeddingInferenceModule.
21 | 
22 |         This module is used to compute semantic embeddings from input data using a
23 |         pre-trained, frozen semantic embedding model. It is intended to be used only for
24 |         inference.
25 | 
26 |         Args:
27 |             semantic_embedding_model: The model to use for computing semantic embeddings.
28 |             semantic_embedding_model_input_map: The mapping from feature names to input names
29 |                 expected by the semantic embedding model.
30 |         """
31 |         super().__init__()
32 | 
33 |         self.semantic_embedding_model = semantic_embedding_model
34 |         # We use a frozen embedding module to compute the input embeddings
35 |         for param in self.semantic_embedding_model.parameters():
36 |             param.requires_grad = False
37 |         self.semantic_embedding_model_input_map = semantic_embedding_model_input_map
38 | 
39 |     def forward(self, model_input: ItemData) -> torch.Tensor:
40 |         """
41 |         Get the semantic embeddings from the input data.
42 | 
43 |         Args:
44 |             model_input: ItemData consisting of the batch of input features.
45 | 
46 |         Returns:
47 |             semantic_embeddings: The semantic embeddings.
48 |                 Shape (batch_size, n_features)
49 |         """
50 |         semantic_embedding_model_input_name_to_feature = {
51 |             input_embedding_model_input_name: model_input.transformed_features[
52 |                 feature_name
53 |             ]
54 |             for input_embedding_model_input_name, feature_name in self.semantic_embedding_model_input_map.items()
55 |         }
56 |         with torch.no_grad():
57 |             semantic_embeddings = self.semantic_embedding_model(
58 |                 **semantic_embedding_model_input_name_to_feature
59 |             )
60 |         return semantic_embeddings
61 | 
62 |     def model_step(self, model_input: ItemData) -> torch.Tensor:
63 |         semantic_embeddings = self.forward(model_input)
64 |         return semantic_embeddings
65 | 
66 |     def predict_step(self, batch: ItemData) -> OneKeyPerPredictionOutput:
67 |         """
68 |         Perform a single prediction step on a batch of data.
69 | 
70 |         Save the semantic embeddings of the input items and the corresponding item ids
71 |         in a OneKeyPerPredictionOutput object.
72 | 
73 |         Args:
74 |             batch: A batch of data of ItemData type.
75 |             batch_idx: The index of the batch.
76 | 
77 |         Returns:
78 |             model_output: A OneKeyPerPredictionOutput object containing the item
79 |                 ids as keys and the semantic embeddings as predictions.
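        Example of the returned row format (illustrative; the ids and embedding values
        below are arbitrary):

            >>> out = OneKeyPerPredictionOutput(
            ...     keys=[101, 102],
            ...     predictions=torch.zeros(2, 3),
            ...     key_name="item_id",
            ...     prediction_name="embedding",
            ... )
            >>> out.list_of_row_format
            [{'item_id': 101, 'embedding': [0.0, 0.0, 0.0]}, {'item_id': 102, 'embedding': [0.0, 0.0, 0.0]}]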
80 | """ 81 | semantic_embeddings = self.model_step(batch) 82 | item_ids = [ 83 | item_id.item() if isinstance(item_id, torch.Tensor) else item_id 84 | for item_id in batch.item_ids 85 | ] 86 | 87 | model_output = OneKeyPerPredictionOutput( 88 | keys=item_ids, 89 | predictions=semantic_embeddings, 90 | key_name="item_id", 91 | prediction_name="embedding", 92 | ) 93 | return model_output 94 | -------------------------------------------------------------------------------- /src/utils/rich_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Sequence 3 | 4 | import rich 5 | import rich.syntax 6 | import rich.tree 7 | from hydra.core.hydra_config import HydraConfig 8 | from lightning_utilities.core.rank_zero import rank_zero_only 9 | from omegaconf import DictConfig, OmegaConf, open_dict 10 | from rich.prompt import Prompt 11 | 12 | from src.utils import pylogger 13 | from src.utils.file_utils import open_local_or_remote 14 | 15 | log = pylogger.RankedLogger(__name__, rank_zero_only=True) 16 | 17 | 18 | @rank_zero_only 19 | def print_config_tree( 20 | cfg: DictConfig, 21 | print_order: Sequence[str] = ( 22 | "data_loading", 23 | "model", 24 | "callbacks", 25 | "logger", 26 | "trainer", 27 | "paths", 28 | "extras", 29 | ), 30 | resolve: bool = False, 31 | save_to_file: bool = False, 32 | ) -> None: 33 | """Prints the contents of a DictConfig as a tree structure using the Rich library. 34 | 35 | :param cfg: A DictConfig composed by Hydra. 36 | :param print_order: Determines in what order config components are printed. Default is ``("data", "model", 37 | "callbacks", "logger", "trainer", "paths", "extras")``. 38 | :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. 39 | :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``. 40 | """ 41 | style = "dim" 42 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 43 | 44 | queue = [] 45 | 46 | # add fields from `print_order` to queue 47 | for field in print_order: 48 | queue.append(field) if field in cfg else log.warning( 49 | f"Field '{field}' not found in config. Skipping '{field}' config printing..." 50 | ) 51 | 52 | # add all the other fields to queue (not specified in `print_order`) 53 | for field in cfg: 54 | if field not in queue: 55 | queue.append(field) 56 | 57 | # generate config tree from queue 58 | for field in queue: 59 | branch = tree.add(field, style=style, guide_style=style) 60 | 61 | config_group = cfg[field] 62 | if isinstance(config_group, DictConfig): 63 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 64 | else: 65 | branch_content = str(config_group) 66 | 67 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 68 | 69 | # print config tree 70 | rich.print(tree) 71 | 72 | # save config tree to file 73 | if save_to_file: 74 | with open_local_or_remote( 75 | f"{cfg.paths.output_dir}/config_tree.log", "w" 76 | ) as file: 77 | rich.print(tree, file=file) 78 | 79 | 80 | @rank_zero_only 81 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: 82 | """Prompts user to input tags from command line if no tags are provided in config. 83 | 84 | :param cfg: A DictConfig composed by Hydra. 85 | :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. 
86 | """ 87 | if not cfg.get("tags"): 88 | if "id" in HydraConfig().cfg.hydra.job: 89 | raise ValueError("Specify tags before launching a multirun!") 90 | 91 | log.warning("No tags provided in config. Prompting user to input tags...") 92 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev") 93 | tags = [t.strip() for t in tags.split(",") if t != ""] 94 | 95 | with open_dict(cfg): 96 | cfg.tags = tags 97 | 98 | log.info(f"Tags: {cfg.tags}") 99 | 100 | if save_to_file: 101 | with open_local_or_remote(f"{cfg.paths.output_dir}/tags.log", "w") as file: 102 | rich.print(cfg.tags, file=file) 103 | -------------------------------------------------------------------------------- /src/components/training_loop_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch_lightning import LightningModule 3 | 4 | 5 | def scale_loss_by_world_size_for_initialization_training_loop( 6 | model: LightningModule, 7 | loss: torch.Tensor, 8 | world_size: int, 9 | is_initialized: bool = True, 10 | initalization_optimizer_lr: float = 0.5, 11 | initalization_optimizer: torch.optim.Optimizer = torch.optim.SGD, 12 | ): 13 | """ 14 | Training loop that scales the loss by the number of GPUs used for training during 15 | model initialization. 16 | 17 | This is used for training models with DDP that require special initialization 18 | computed based on the data, such as KMeans. To use this function to initialize a 19 | model: 20 | 21 | 1. Ensure that the model parameters are set to zero at the very start of training. 22 | 23 | 2. On one device, compute the desired initial parameters based on the data, and 24 | compute the loss as the squared distance between the model's parameters and the 25 | desired initial parameters. We use only one device to compute the initialization 26 | because different devices may have different data, and the average initialization 27 | across devices may not be a reasonable initialization. 28 | 29 | 3. Set the loss on other devices to zero. 30 | 31 | 4. Feed the loss to this training loop with the `is_initialized=False`. This will 32 | tell this function to scale the loss by the world size, which is necessary for 33 | the following reason: 34 | 35 | Without loss rescaling, the losses are: 36 | - Device 1: || model.parameters - desired_initial_parameters||^2 37 | - Device 2-world_size: 0 38 | The gradients are (recalling that model.parameters are all zero): 39 | - Device 1: -2 * desired_initial_parameters 40 | - Device 2-world_size: 0 41 | Recall that DDP will average the gradients across all devices. At this point, the 42 | average gradient is: 43 | -2 * desired_initial_parameters / world_size 44 | which means that with SGD optimizer with learning rate 0.5, the model's parameters 45 | will be `desired_initial_parameters / world_size` after the update. To avoid this, 46 | we scale the loss by the world size so that the average gradient across all devices 47 | is -2 * desired_initial_parameters, . This means that the model's parameters will be 48 | updated to the desired initial parameters in one step on all devices via DDP. 49 | Note that this is only used for the initialization step. After the model is 50 | initialized, we can use the default training loop without loss rescaling because the 51 | gradients on all devices are non-trivial. 52 | 53 | 5. 
Ensure that the optimizer is SGD with learning rate 0.5 so that the model's 54 | parameters are updated to the desired initial parameters in one step on all devices 55 | via DDP. 56 | 57 | 6. After the model is initialized in this step, set the `is_initialized` flag to 58 | `True` to use the default training loop, as now we can compute losses on all devices 59 | and no longer need to scale the loss by the world size. 60 | 61 | Args: 62 | trainer: The trainer object. 63 | loss: The loss value. If the model is not yet initialized, this loss 64 | should be the squared distance between the model's parameters, which should 65 | have value zero at this point, and the desired initial parameters. 66 | world_size: The number of GPUs used for training. 67 | is_initialized: A boolean indicating whether the model is already initialized. 68 | """ 69 | if not is_initialized: 70 | # Perform special initialization 71 | # We scale the loss by the world size for proper initialization 72 | opt = initalization_optimizer(model.parameters(), lr=initalization_optimizer_lr) 73 | loss = loss * world_size 74 | else: 75 | # Use the default training loop 76 | opt = model.optimizers() 77 | opt.zero_grad() 78 | model.manual_backward(loss) 79 | opt.step() 80 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Dict, Optional, Tuple 3 | 4 | import hydra 5 | import rootutils 6 | import torch 7 | from omegaconf import DictConfig 8 | 9 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 10 | 11 | # ------------------------------------------------------------------------------------ # 12 | # the setup_root above is equivalent to: 13 | # - adding project root dir to PYTHONPATH 14 | # (so you don't need to force user to install project as a package) 15 | # (necessary before importing any local modules e.g. `from src import utils`) 16 | # - setting up PROJECT_ROOT environment variable 17 | # (which is used as a base for paths in "configs/paths/default.yaml") 18 | # (this way all filepaths are the same no matter where you run the code) 19 | # - loading environment variables from ".env" in root dir 20 | # 21 | # you can remove it if you: 22 | # 1. either install project as a package or move entry files to project root dir 23 | # 2. set `root_dir` to "." in "configs/paths/default.yaml" 24 | # 25 | # more info: https://github.com/ashleve/rootutils 26 | # ------------------------------------------------------------------------------------ # 27 | from src.utils import RankedLogger, extras 28 | from src.utils.custom_hydra_resolvers import * 29 | from src.utils.launcher_utils import pipeline_launcher 30 | from src.utils.restart_job import LocalJobLauncher 31 | 32 | command_line_logger = RankedLogger(__name__, rank_zero_only=True) 33 | 34 | torch.set_float32_matmul_precision("medium") 35 | 36 | 37 | def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 38 | """Trains the model. Can additionally evaluate on a testset, using best weights obtained during 39 | training. 40 | 41 | This method is wrapped in optional @task_wrapper decorator, that controls the behavior during 42 | failure. Useful for multiruns, saving info about the crash, etc. 43 | 44 | :param cfg: A DictConfig configuration composed by Hydra. 45 | :return: A tuple with metrics and dict with all instantiated objects. 
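
    This is normally invoked through the Hydra entry point below, e.g.
    ``python -m src.train experiment=rkmeans_train_flat data_dir=...`` (see the README
    for complete command examples).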
46 | """ 47 | # Pipeline launcher initializes the modules needed for the pipeline to run. 48 | # It also serves as a context manager, so all resources are properly closed after the pipeline is done. 49 | with pipeline_launcher(cfg) as pipeline_modules: 50 | 51 | if cfg.get("train"): 52 | command_line_logger.info("Starting training!") 53 | pipeline_modules.trainer.fit( 54 | model=pipeline_modules.model, 55 | datamodule=pipeline_modules.datamodule, 56 | ckpt_path=cfg.get("ckpt_path"), 57 | ) 58 | train_metrics = pipeline_modules.trainer.callback_metrics 59 | 60 | if cfg.get("test"): 61 | command_line_logger.info("Starting testing!") 62 | ckpt_path = None 63 | # Check if a checkpoint callback is available and if it has a best model path. 64 | # Note that if multiple checkpoint callbacks are used, only the first one will be used 65 | # to determine the best model path for testing. 66 | checkpoint_callback = getattr( 67 | pipeline_modules.trainer, "checkpoint_callback", None 68 | ) 69 | if checkpoint_callback: 70 | ckpt_path = getattr(checkpoint_callback, "best_model_path", None) 71 | if ckpt_path == "": 72 | ckpt_path = None 73 | if not ckpt_path: 74 | command_line_logger.warning( 75 | "Best checkpoint not found! Using current weights for testing..." 76 | ) 77 | pipeline_modules.trainer.test( 78 | model=pipeline_modules.model, 79 | datamodule=pipeline_modules.datamodule, 80 | ckpt_path=ckpt_path, 81 | ) 82 | command_line_logger.info(f"Best ckpt path: {ckpt_path}") 83 | 84 | test_metrics = pipeline_modules.trainer.callback_metrics 85 | 86 | # merge train and test metrics 87 | metric_dict = {**train_metrics, **test_metrics} 88 | 89 | command_line_logger.info(f"Metrics: {metric_dict}") 90 | 91 | 92 | @hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml") 93 | def main(cfg: DictConfig) -> Optional[float]: 94 | """Main entry point for training. 95 | 96 | :param cfg: DictConfig configuration composed by Hydra. 97 | :return: Optional[float] with optimized metric value. 98 | """ 99 | # apply extra utilities 100 | # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) 101 | extras(cfg) 102 | job_launcher = LocalJobLauncher(cfg=cfg) 103 | job_launcher.launch(function_to_run=train) 104 | 105 | 106 | if __name__ == "__main__": 107 | main() 108 | -------------------------------------------------------------------------------- /src/components/loss_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from typing import Optional, Tuple, Union, Dict 5 | 6 | class FullBatchCrossEntropyLoss(nn.Module): 7 | """ 8 | Contrastive loss with negative samples being all candidates in the embedding table. 9 | """ 10 | 11 | def __init__( 12 | self, 13 | normalize: bool = True, 14 | **kwargs, 15 | ): 16 | """ 17 | Initialize the FullBatchContrastiveLoss. 18 | 19 | Parameters 20 | ---------- 21 | contrastive_tau: float 22 | Temperature parameter for the contrastive loss. 23 | normalize: bool 24 | Whether to normalize the embeddings before computing the logits via dot product. 
25 | """ 26 | super().__init__() 27 | self.normalize = normalize 28 | self.cross_entroy_loss = torch.nn.CrossEntropyLoss() 29 | 30 | def forward( 31 | self, 32 | query_embeddings: torch.Tensor, 33 | key_embeddings: torch.Tensor, 34 | label_locations: torch.Tensor, 35 | labels: torch.Tensor, 36 | ) -> torch.Tensor: 37 | """ 38 | Compute the contrastive loss with negative samples from the full vocabulary. 39 | 40 | Parameters 41 | ---------- 42 | query_embeddings: torch.Tensor (batch_size x sequence length x embedding_dim) 43 | The embeddings of the query items. 44 | key_embeddings: torch.Tensor (total number of items x embedding_dim) 45 | The embeddings of all items, i.e the full embedding table. 46 | label_locations: torch.Tensor (number of labels x 2) 47 | The locations of the labels in the input sequences. 48 | labels: torch.Tensor (number of labels) 49 | The labels for the input sequences. 50 | 51 | Returns 52 | ------- 53 | torch.Tensor 54 | The contrastive loss. 55 | """ 56 | # get representation of masked tokens 57 | # label_locations[:, 0] refers to the index of sequences 58 | # label_locations[:, 1] refers to the index of tokens in the sequences 59 | query_embeddings = query_embeddings[ 60 | label_locations[:, 0], label_locations[:, 1] 61 | ] 62 | 63 | if self.normalize: 64 | query_embeddings = F.normalize(query_embeddings, dim=-1) 65 | key_embeddings = F.normalize(key_embeddings, dim=-1) 66 | 67 | logits = torch.mm(query_embeddings, key_embeddings.t()) 68 | 69 | loss = self.cross_entroy_loss(logits, labels.long()) 70 | 71 | return loss 72 | 73 | class WeightedSquaredError(torch.nn.Module): 74 | def __init__(self): 75 | """Initialize the WeightedSquaredError loss function.""" 76 | super().__init__() 77 | 78 | def forward( 79 | self, x: torch.Tensor, y: torch.Tensor, weights: Optional[torch.Tensor] = None 80 | ) -> torch.Tensor: 81 | """ 82 | Compute the weighted squared error loss. 83 | 84 | Args: 85 | x: Predicted values of shape (n_points, n_features) 86 | y: Target values of shape (n_points, n_features) 87 | weights: Weights for each point of shape (n_points,) 88 | 89 | Returns: 90 | A tensor containing the weighted squared error loss of shape (1,) 91 | """ 92 | error = x - y 93 | squared_error = torch.sum(error**2, dim=-1) 94 | # If weights are not provided, use uniform weights 95 | # This is equivalent to the standard squared error loss 96 | if weights is None: 97 | return torch.sum(squared_error) 98 | return torch.sum(weights * squared_error) 99 | 100 | class BetaQuantizationLoss(torch.nn.Module): 101 | def __init__(self, beta: float = 0.25, reduction: str = "sum"): 102 | """Initialize the Beta Quantization Loss. 103 | 104 | Parameters 105 | ---------- 106 | beta: float 107 | Weighting factor for the reconstruction loss. 108 | reduction: str 109 | Reduction method to apply to the loss. Options are 'none', 'mean', and 'sum'. 110 | """ 111 | super().__init__() 112 | self.beta = beta 113 | self.criterion = torch.nn.MSELoss(reduction=reduction) 114 | 115 | def forward(self, x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor: 116 | """ 117 | Compute the beta quantization loss. 
118 | Args: 119 | x: Original tensor of shape (batch_size, n_features) 120 | x: Quantized tensor of shape (batch_size, n_features) 121 | Returns: 122 | A tensor containing the beta quantization loss of shape (1,) 123 | """ 124 | x_no_grad = x.detach() 125 | xq_no_grad = xq.detach() 126 | loss = self.criterion(x_no_grad, xq) + self.beta * self.criterion(x, xq_no_grad) 127 | return loss 128 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Abstra 171 | # Abstra is an AI-powered process automation framework. 172 | # Ignore directories containing user credentials, local state, and settings. 173 | # Learn more at https://abstra.io/docs 174 | .abstra/ 175 | 176 | # Visual Studio Code 177 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore 178 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore 179 | # and can be added to the global gitignore or merged into this file. However, if you prefer, 180 | # you could uncomment the following to ignore the enitre vscode folder 181 | # .vscode/ 182 | 183 | # Ruff stuff: 184 | .ruff_cache/ 185 | 186 | # PyPI configuration file 187 | .pypirc 188 | 189 | # Cursor 190 | # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to 191 | # exclude from AI features like autocomplete and code analysis. 
Recommended for sensitive data 192 | # refer to https://docs.cursor.com/context/ignore-files 193 | .cursorignore 194 | .cursorindexingignore -------------------------------------------------------------------------------- /src/modules/clustering/vector_quantization.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import Optional, Tuple 3 | 4 | import torch 5 | 6 | from src.components.distance_functions import DistanceFunction 7 | from src.components.clustering_initializers import ClusteringInitializer 8 | from src.components.loss_functions import WeightedSquaredError 9 | from src.components.quantization_strategies import QuantizationStrategy 10 | from src.models.modules.clustering.base_clustering_module import BaseClusteringModule 11 | 12 | 13 | class VectorQuantization(BaseClusteringModule): 14 | def __init__( 15 | self, 16 | n_clusters: int, 17 | n_features: int, 18 | distance_function: DistanceFunction, 19 | initializer: ClusteringInitializer, 20 | quantization_strategy: QuantizationStrategy, 21 | loss_function: torch.nn.Module = WeightedSquaredError(), 22 | optimizer: torch.optim.Optimizer = functools.partial( 23 | torch.optim.SGD, 24 | lr=0.5, 25 | ), 26 | init_buffer_size: int = 1000, 27 | ): 28 | """ 29 | Initialize the VectorQuantization module. 30 | 31 | Args: 32 | n_clusters: Number of clusters. 33 | n_features: Number of features in the input data. 34 | distance_function: Distance function to use for computing distances between points. 35 | loss_function: Loss function to use for training. 36 | optimizer: Optimizer to use for training. 37 | init_method: Initialization method ("random" or "k-means++"). 38 | init_buffer_size: Number of points to buffer for initialization. 39 | """ 40 | 41 | super().__init__( 42 | n_clusters=n_clusters, 43 | n_features=n_features, 44 | distance_function=distance_function, 45 | loss_function=loss_function, 46 | optimizer=optimizer, 47 | initializer=initializer, 48 | init_buffer_size=init_buffer_size, 49 | ) 50 | 51 | self.quantization_strategy = quantization_strategy 52 | 53 | def forward( 54 | self, batch: torch.Tensor 55 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 56 | """ 57 | Perform a forward pass of the K-Means model on the input batch. 58 | 59 | This function computes the cluster assignments for each input point, the number 60 | of points in each cluster, and the sum of points in each cluster. 61 | 62 | Args: 63 | batch: Data points of shape (batch_size, n_features) 64 | 65 | Returns: 66 | assignments: Cluster assignments of shape (batch_size,) 67 | embeddings: Embeddings of shape (batch_size, n_features). 68 | These embeddings will be used for computing the quantization loss. 69 | reconstruction_loss_embeddings: Embeddings of shape (batch_size, n_features) 70 | computed in a way that enables gradient backpropagation through the input 71 | embeddings. If the quantization strategy does not support this, this will 72 | be None. 73 | """ 74 | codebook = self.get_centroids() 75 | ( 76 | ids, 77 | embeddings, 78 | reconstruction_loss_embeddings, 79 | ) = self.quantization_strategy.quantize( 80 | codebook=codebook, 81 | batch=batch, 82 | ) 83 | return ids, embeddings, reconstruction_loss_embeddings 84 | 85 | def model_step( 86 | self, 87 | batch: torch.Tensor, 88 | ) -> Tuple[torch.Tensor, torch.Tensor, bool]: 89 | """ 90 | Perform a forward pass of the K-Means model on the batch and compute the loss. 
91 | 92 | This function may be called by another LightningModule, such as a residual 93 | K-means module, that is using this MiniBatchKMeans module as a submodule. 94 | 95 | Calling this function along will not update the centroids, and will not 96 | increment self.global_step. If a parent module is using this module as a 97 | submodule, the parent will be responsible for updating those parameters. 98 | Otherwise, these will be updated by Lightning after it calls 99 | training_step. 100 | 101 | Args: 102 | batch: Data points of shape (batch_size, n_features) 103 | 104 | Returns: 105 | assignments: Cluster assignments of shape (batch_size,) 106 | global_loss_embeddings: Embeddings of shape (batch_size, n_features) 107 | loss: Loss value. Tensor of shape (1,) 108 | """ 109 | if batch.device != self.device: 110 | batch = batch.to(self.device) 111 | 112 | # Initialize centroids using the chosen method 113 | # Buffer initial batches for better initialization 114 | if self.is_initial_step: 115 | self.is_initial_step = False 116 | self.is_initialized = True 117 | if not self.is_initialized: 118 | return self.initialization_step(batch) 119 | 120 | assignments, embeddings, reconstruction_loss_embeddings = self.forward(batch) 121 | loss = self.loss_function(batch, embeddings) # quantization loss 122 | return ( 123 | assignments, 124 | reconstruction_loss_embeddings 125 | if reconstruction_loss_embeddings is not None 126 | else embeddings, 127 | loss, 128 | ) 129 | -------------------------------------------------------------------------------- /src/components/distance_functions.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Optional 3 | import torch 4 | 5 | 6 | class DistanceFunction(ABC): 7 | @abstractmethod 8 | def compute(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 9 | """ 10 | Compute distances between the rows of x and the rows of y. 11 | 12 | Args: 13 | x: Data points of shape (n1, d) 14 | y: Centroids of shape (n2, d) 15 | 16 | Returns: 17 | Distances of shape (n1, n2) 18 | """ 19 | pass 20 | 21 | 22 | class SquaredEuclideanDistance(DistanceFunction): 23 | def compute( 24 | self, x: torch.Tensor, y: torch.Tensor, batch_size: int = 256 25 | ) -> torch.Tensor: 26 | """ 27 | Compute squared Euclidean distances between the rows of x and the rows of y, 28 | with optional batching along the x-axis to manage memory. 29 | 30 | Args: 31 | x: Data points of shape (n1, d) 32 | y: Centroids of shape (n2, d) 33 | batch_size: Optional. The number of rows from x to process at a time. 34 | If None, no batching is performed (original behavior). 
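                For example (illustrative numbers): with n1 = 1000 rows in x and
                batch_size = 256, the distances are computed in ceil(1000 / 256) = 4
                chunks of at most 256 rows each and concatenated along the first dimension.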
35 | 36 | Returns: 37 | Squared distances of shape (n1, n2) 38 | 39 | Raises: 40 | AssertionError: If the input tensors do not have the expected shapes 41 | """ 42 | assert x.dim() == 2, f"Data must be 2D, got {x.dim()} dimensions" 43 | assert y.dim() == 2, f"Data must be 2D, got {y.dim()} dimensions" 44 | assert x.size(1) == y.size(1), f"Data must have the same number of columns" 45 | 46 | n1, d = x.shape 47 | n2, _ = y.shape 48 | 49 | if batch_size is None or batch_size >= n1: 50 | # No batching needed or batch_size is larger than n1, compute directly 51 | x_expanded = x.unsqueeze(1) # Shape (n1, 1, d) 52 | y_expanded = y.unsqueeze(0) # Shape (1, n2, d) 53 | sq_diffs = (x_expanded - y_expanded).pow(2) # Shape (n1, n2, d) 54 | sq_distances = torch.sum(sq_diffs, dim=2) # Shape (n1, n2) 55 | return sq_distances 56 | else: 57 | # Perform batching 58 | all_sq_distances = [] 59 | num_batches = (n1 + batch_size - 1) // batch_size # Ceiling division 60 | 61 | for i in range(num_batches): 62 | start_idx = i * batch_size 63 | end_idx = min((i + 1) * batch_size, n1) 64 | x_batch = x[start_idx:end_idx] # Shape (current_batch_size, d) 65 | 66 | # Expand and compute for the current batch 67 | x_batch_expanded = x_batch.unsqueeze( 68 | 1 69 | ) # Shape (current_batch_size, 1, d) 70 | y_expanded = y.unsqueeze(0) # Shape (1, n2, d) - y remains the same 71 | 72 | sq_diffs_batch = (x_batch_expanded - y_expanded).pow( 73 | 2 74 | ) # Shape (current_batch_size, n2, d) 75 | sq_distances_batch = torch.sum( 76 | sq_diffs_batch, dim=2 77 | ) # Shape (current_batch_size, n2) 78 | 79 | all_sq_distances.append(sq_distances_batch) 80 | 81 | # Concatenate the results from all batches 82 | return torch.cat(all_sq_distances, dim=0) 83 | 84 | class WeightedSquaredError(torch.nn.Module): 85 | def __init__(self): 86 | """Initialize the WeightedSquaredError loss function.""" 87 | super().__init__() 88 | 89 | def forward( 90 | self, x: torch.Tensor, y: torch.Tensor, weights: Optional[torch.Tensor] = None 91 | ) -> torch.Tensor: 92 | """ 93 | Compute the weighted squared error loss. 94 | 95 | Args: 96 | x: Predicted values of shape (n_points, n_features) 97 | y: Target values of shape (n_points, n_features) 98 | weights: Weights for each point of shape (n_points,) 99 | 100 | Returns: 101 | A tensor containing the weighted squared error loss of shape (1,) 102 | """ 103 | error = x - y 104 | squared_error = torch.sum(error**2, dim=-1) 105 | # If weights are not provided, use uniform weights 106 | # This is equivalent to the standard squared error loss 107 | if weights is None: 108 | return torch.sum(squared_error) 109 | return torch.sum(weights * squared_error) 110 | 111 | 112 | class BetaQuantizationLoss(torch.nn.Module): 113 | def __init__(self, beta: float = 0.25, reduction: str = "sum"): 114 | """Initialize the Beta Quantization Loss. 115 | 116 | Parameters 117 | ---------- 118 | beta: float 119 | Weighting factor for the reconstruction loss. 120 | reduction: str 121 | Reduction method to apply to the loss. Options are 'none', 'mean', and 'sum'. 122 | """ 123 | super().__init__() 124 | self.beta = beta 125 | self.criterion = torch.nn.MSELoss(reduction=reduction) 126 | 127 | def forward(self, x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor: 128 | """ 129 | Compute the beta quantization loss. 
130 | Args: 131 | x: Original tensor of shape (batch_size, n_features) 132 | x: Quantized tensor of shape (batch_size, n_features) 133 | Returns: 134 | A tensor containing the beta quantization loss of shape (1,) 135 | """ 136 | x_no_grad = x.detach() 137 | xq_no_grad = xq.detach() 138 | loss = self.criterion(x_no_grad, xq) + self.beta * self.criterion(x, xq_no_grad) 139 | return loss 140 | -------------------------------------------------------------------------------- /configs/experiment/sem_embeds_inference_flat.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data_dir: ??? 3 | embedding_model: google/flan-t5-xl 4 | 5 | task_name: inference 6 | id: ${now:%Y-%m-%d}/${now:%H-%M-%S} 7 | tags: 8 | - amazon 9 | - semantic-embeddings-inference 10 | model: 11 | loss_function: null 12 | optimizer: null 13 | scheduler: null 14 | evaluator: null 15 | _target_: src.modules.semantic_embedding_inference_module.SemanticEmbeddingInferenceModule 16 | semantic_embedding_model_input_map: 17 | input_ids: text_tokens 18 | attention_mask: text_mask 19 | semantic_embedding_model: 20 | _target_: src.components.network_blocks.hf_language_model.HFLanguageModel 21 | huggingface_model: 22 | _target_: transformers.T5EncoderModel.from_pretrained 23 | pretrained_model_name_or_path: ${embedding_model} 24 | aggregator: 25 | _target_: src.models.components.network_blocks.embedding_aggregator.EmbeddingAggregator 26 | aggregation_strategy: 27 | _target_: src.models.components.network_blocks.aggregation_strategy.MeanAggregation 28 | callbacks: 29 | bq_writer: null 30 | pickle_writer: 31 | _target_: src.utils.inference_utils.LocalPickleWriter 32 | output_dir: ${paths.output_dir}/pickle 33 | flush_frequency: 64 34 | write_interval: batch 35 | should_merge_files_on_main: true 36 | prediction_key_name: item_id 37 | prediction_name: embedding 38 | should_merge_list_of_keyed_tensors_to_single_tensor: true 39 | ckpt_path: null 40 | paths: 41 | root_dir: . 42 | data_dir: ${data_dir} 43 | log_dir: ${paths.root_dir}/logs 44 | output_dir: ${hydra:runtime.output_dir} 45 | work_dir: ${hydra:runtime.cwd} 46 | profile_dir: ${hydra:run.dir}/profile_output 47 | metadata_dir: ${paths.output_dir}/metadata 48 | logger: {} 49 | trainer: 50 | _target_: lightning.pytorch.trainer.Trainer 51 | default_root_dir: ${paths.output_dir} 52 | min_steps: 1 53 | max_steps: 80000 54 | max_epochs: 10 55 | accelerator: gpu 56 | devices: -1 57 | num_nodes: 1 58 | precision: 32 59 | log_every_n_steps: 2500 60 | val_check_interval: 5000 61 | deterministic: false 62 | accumulate_grad_batches: 1 63 | profiler: 64 | _target_: lightning.pytorch.profilers.PassThroughProfiler 65 | data_loading: 66 | tokenizer_config: 67 | max_length: 128 68 | padding: max_length 69 | truncation: true 70 | add_special_tokens: true 71 | postprocess_eos_token: false 72 | tokenizer: 73 | _target_: transformers.AutoTokenizer.from_pretrained 74 | pretrained_model_name_or_path: ${embedding_model} 75 | features_config: 76 | features: 77 | - name: id 78 | num_placeholder_tokens: 0 79 | is_item_ids: true 80 | embeddings: ??? 
81 | type: 82 | _target_: torch.__dict__.get 83 | _args_: 84 | - int32 85 | - name: text 86 | type: 87 | _target_: torch.__dict__.get 88 | _args_: 89 | - bytes 90 | is_text: true 91 | - name: embedding 92 | type: 93 | _target_: torch.__dict__.get 94 | _args_: 95 | - float32 96 | is_embedding: true 97 | dataset_config: 98 | dataset: 99 | _target_: src.data.loading.components.interfaces.ItemDatasetConfig 100 | item_id_field: id 101 | keep_item_id: true 102 | iterate_per_row: true 103 | data_iterator: 104 | _target_: src.data.loading.components.iterators.TFRecordIterator 105 | features_to_consider: ${extract_fields_from_list_of_dicts:${data_loading.features_config.features}, 106 | "name", False, "is_text", "True"} 107 | num_placeholder_tokens_map: ${create_map_from_list_of_dicts:${data_loading.features_config.features}, 108 | "name", "num_placeholder_tokens"} 109 | preprocessing_functions: 110 | - _target_: src.data.loading.components.pre_processing.filter_features_to_consider 111 | _partial_: true 112 | - _target_: src.data.loading.components.pre_processing.convert_to_dense_numpy_array 113 | _partial_: true 114 | features_to_apply: 115 | - id 116 | - text 117 | - _target_: src.data.loading.components.pre_processing.convert_fields_to_tensors 118 | _partial_: true 119 | features_to_apply: 120 | - id 121 | - _target_: src.data.loading.components.pre_processing.convert_bytes_to_string 122 | _partial_: true 123 | features_to_apply: 124 | - text 125 | - _target_: src.data.loading.components.pre_processing.tokenize_text_features 126 | _partial_: true 127 | features_to_apply: 128 | - text 129 | tokenizer_config: ${data_loading.tokenizer_config} 130 | - _target_: src.data.loading.components.pre_processing.squeeze_tensor_in_place 131 | _partial_: true 132 | features_to_apply: 133 | - text 134 | - text_mask 135 | field_type_map: ${create_map_from_list_of_dicts:${data_loading.features_config.features}, 136 | "name", "type"} 137 | datamodule: 138 | _target_: src.data.loading.datamodules.sequence_datamodule.ItemDataModule 139 | predict_dataloader_config: 140 | _target_: src.data.loading.components.interfaces.ItemDataloaderConfig 141 | dataset_class: 142 | _target_: src.data.loading.components.dataloading.UnboundedSequenceIterable 143 | _partial_: true 144 | data_folder: ${paths.data_dir}/items 145 | should_shuffle_rows: false 146 | batch_size_per_device: 8 147 | num_workers: 2 148 | assign_files_by_size: true 149 | timeout: 60 150 | drop_last: false 151 | pin_memory: false 152 | persistent_workers: true 153 | collate_fn: 154 | _target_: src.data.loading.components.collate_functions.collate_fn_items 155 | _partial_: true 156 | item_id_field: ${data_loading.dataset_config.dataset.item_id_field} 157 | feature_to_input_name: 158 | id: item_ids 159 | text: text_tokens 160 | text_mask: text_mask 161 | embedding: input_embedding 162 | dataset_config: ${data_loading.dataset_config.dataset} 163 | limit_files: null 164 | extras: 165 | ignore_warnings: false 166 | enforce_tags: true 167 | print_config_warnings: true 168 | print_config: true 169 | seed: 42 170 | -------------------------------------------------------------------------------- /configs/experiment/rkmeans_inference_flat.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data_dir: ??? 3 | embedding_path: ??? 4 | codebook_width: ??? 5 | num_hierarchies: ??? 6 | embedding_dim: ??? 7 | ckpt_path: ??? 
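# Every ??? value above is mandatory and has no default; pass it as a Hydra command-line
# override, e.g. data_dir=... embedding_path=... embedding_dim=... num_hierarchies=...
# codebook_width=... ckpt_path=... (see the semantic-ID inference command in the README).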
8 | seed: 42 9 | 10 | 11 | model: 12 | _target_: src.modules.clustering.residual_quantization.ResidualQuantization 13 | track_residuals: true 14 | verbose: true 15 | train_layer_wise: true 16 | normalize_residuals: true 17 | input_dim: ${embedding_dim} 18 | n_layers: ${num_hierarchies} 19 | init_buffer_size: 3072 20 | quantization_layer: 21 | _target_: src.models.modules.clustering.mini_batch_kmeans.MiniBatchKMeans 22 | n_clusters: ${codebook_width} 23 | n_features: ${model.input_dim} 24 | distance_function: 25 | _target_: src.components.distance_functions.SquaredEuclideanDistance 26 | initializer: 27 | _target_: src.components.clustering_initializers.KMeansPlusPlusInitInitializer 28 | n_clusters: ${model.quantization_layer.n_clusters} 29 | distance_function: ${model.quantization_layer.distance_function} 30 | initialize_on_cpu: false 31 | init_buffer_size: ${model.init_buffer_size} 32 | optimizer: null 33 | scheduler: null 34 | quantization_layer_list: null 35 | training_loop_function: 36 | _target_: src.components.training_loop_functions.scale_loss_by_world_size_for_initialization_training_loop 37 | _partial_: true 38 | loss_function: null 39 | evaluator: null 40 | task_name: inference 41 | id: ${now:%Y-%m-%d}/${now:%H-%M-%S} 42 | tags: 43 | - amazon-assign-ids-inference 44 | experiment: null 45 | callbacks: 46 | bq_writer: 47 | table_id: ??? 48 | pickle_writer: 49 | _target_: src.utils.inference_utils.LocalPickleWriter 50 | output_dir: ${paths.output_dir}/pickle 51 | flush_frequency: 100000 52 | write_interval: batch 53 | should_merge_files_on_main: true 54 | prediction_key_name: item_id 55 | prediction_name: cluster_ids 56 | post_processing_functions: 57 | - function: 58 | _target_: src.utils.tensor_utils.deduplicate_rows_in_tensor 59 | _partial_: true 60 | main_only: true 61 | - function: 62 | _target_: src.utils.tensor_utils.transpose_tensor_from_file 63 | _partial_: true 64 | main_only: true 65 | paths: 66 | root_dir: . 
67 | data_dir: ${data_dir} 68 | log_dir: ${paths.root_dir}/logs 69 | output_dir: ${hydra:runtime.output_dir} 70 | work_dir: ${hydra:runtime.cwd} 71 | profile_dir: ${hydra:run.dir}/profile_output 72 | metadata_dir: ${paths.output_dir}/metadata 73 | logger: 74 | csv: 75 | _target_: lightning.pytorch.loggers.csv_logs.CSVLogger 76 | save_dir: ${paths.output_dir} 77 | name: csv/ 78 | prefix: '' 79 | trainer: 80 | _target_: lightning.pytorch.trainer.Trainer 81 | default_root_dir: ${paths.output_dir} 82 | min_steps: 1 83 | max_steps: 80000 84 | max_epochs: 10 85 | accelerator: gpu 86 | devices: -1 87 | num_nodes: 1 88 | precision: bf16-mixed 89 | log_every_n_steps: 2500 90 | val_check_interval: 5000 91 | deterministic: false 92 | accumulate_grad_batches: 1 93 | profiler: 94 | _target_: lightning.pytorch.profilers.PassThroughProfiler 95 | data_loading: 96 | features_config: 97 | features: 98 | - name: id 99 | num_placeholder_tokens: 0 100 | is_item_ids: true 101 | embeddings: 102 | _target_: torch.load 103 | _args_: 104 | - _target_: src.utils.file_utils.open_local_or_remote 105 | file_path: ${embedding_path} 106 | mode: rb 107 | type: 108 | _target_: torch.__dict__.get 109 | _args_: 110 | - int32 111 | dataset_config: 112 | dataset: 113 | _target_: src.data.loading.components.interfaces.ItemDatasetConfig 114 | item_id_field: id 115 | keep_item_id: true 116 | iterate_per_row: true 117 | data_iterator: 118 | _target_: src.data.loading.components.iterators.TFRecordIterator 119 | features_to_consider: ${extract_fields_from_list_of_dicts:${data_loading.features_config.features}, 120 | "name", False, "is_item_ids", "True"} 121 | embedding_map: 122 | id: ${data_loading.features_config.features[0].embeddings} 123 | num_placeholder_tokens_map: ${create_map_from_list_of_dicts:${data_loading.features_config.features}, 124 | "name", "num_placeholder_tokens"} 125 | preprocessing_functions: 126 | - _target_: src.data.loading.components.pre_processing.filter_features_to_consider 127 | _partial_: true 128 | - _target_: src.data.loading.components.pre_processing.convert_to_dense_numpy_array 129 | _partial_: true 130 | - _target_: src.data.loading.components.pre_processing.convert_fields_to_tensors 131 | _partial_: true 132 | - _target_: src.data.loading.components.pre_processing.map_sparse_id_to_embedding 133 | _partial_: true 134 | sparse_id_field: id 135 | embedding_field_to_add: embedding 136 | field_type_map: ${create_map_from_list_of_dicts:${data_loading.features_config.features}, 137 | "name", "type"} 138 | datamodule: 139 | _target_: src.data.loading.datamodules.sequence_datamodule.ItemDataModule 140 | predict_dataloader_config: 141 | _target_: src.data.loading.components.interfaces.ItemDataloaderConfig 142 | dataset_class: 143 | _target_: src.data.loading.components.dataloading.UnboundedSequenceIterable 144 | _partial_: true 145 | data_folder: ${paths.data_dir}/items 146 | should_shuffle_rows: false 147 | batch_size_per_device: 128 148 | num_workers: 2 149 | assign_files_by_size: true 150 | timeout: 60 151 | drop_last: false 152 | pin_memory: false 153 | persistent_workers: true 154 | collate_fn: 155 | _target_: src.data.loading.components.collate_functions.collate_fn_items 156 | _partial_: true 157 | item_id_field: ${data_loading.dataset_config.dataset.item_id_field} 158 | feature_to_input_name: 159 | id: item_ids 160 | text: text_tokens 161 | text_mask: text_mask 162 | embedding: input_embedding 163 | dataset_config: ${data_loading.dataset_config.dataset} 164 | limit_files: null 165 | extras: 166 | 
  ignore_warnings: false
167 |   enforce_tags: true
168 |   print_config_warnings: true
169 |   print_config: true
--------------------------------------------------------------------------------
/src/data/loading/utils.py:
--------------------------------------------------------------------------------
1 | """Utilities for data processing."""
2 |
3 | import heapq
4 | import random
5 | from collections import defaultdict
6 | from typing import Dict, List
7 |
8 | import torch
9 |
10 | from src.utils.file_utils import get_file_size
11 |
12 |
13 | def assign_files_to_workers(
14 |     list_of_files: List[str],
15 |     total_workers: int,
16 |     assign_by_size: bool,
17 |     should_shuffle_rows: bool,
18 |     assign_all_files_per_worker: bool,
19 | ) -> tuple[Dict[int, List[str]], bool]:
20 |     """Assign each file path in `list_of_files` to either one or all workers.
21 |
22 |     - If `total_workers == 0`, then the function returns a single-key dict
23 |       mapping 0 to `list_of_files` as well as a boolean indicating that the
24 |       files are shared among "workers". This is for debugging.
25 |     - Otherwise, if the list of files is shorter than `total_workers`, all files
26 |       are assigned to each worker, and the returned boolean indicates that the
27 |       files are shared among workers.
28 |     - Otherwise, each file gets a single worker, which may be assigned according
29 |       to file size, depending on the value of `assign_by_size`:
30 |       - If `assign_by_size`, files are sorted by size, then assigned in a
31 |         way that encourages even cumulative file size across workers.
32 |       - If not `assign_by_size`, files are assigned randomly to workers (the file order is shuffled first when `should_shuffle_rows` is True).
33 |       In this case, the return boolean is False, indicating that the files are
34 |       not shared among workers.
35 |
36 |     :param list_of_files: List of file paths to be assigned.
37 |     :param total_workers: The number of workers among which to assign files.
38 |     :param assign_by_size: Whether to assign files to balance size (if True),
39 |         or to assign randomly.
40 |     :param assign_all_files_per_worker: Whether to assign all files to each
41 |         worker.
42 |
43 |     :return: A dictionary mapping worker indices to file paths and a boolean
44 |         indicating whether files have been assigned to all workers (i.e. each
45 |         file is shared among all workers).
46 |         NOTE: The second returned parameter is currently ignored by
47 |         `sequence_datamodule.py` but it will be used after an upcoming PR.
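
    Illustrative example (file sizes in arbitrary units): with `total_workers=2`,
    `assign_by_size=True` and files of sizes [10, 8, 3, 2, 1], files are taken
    largest-first and each one goes to the worker with the smallest cumulative size so
    far, giving worker 0 the size-10 and size-2 files and worker 1 the size-8, size-3
    and size-1 files (about 12 units each); the returned boolean is False.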
48 | """ 49 | if total_workers == 0: 50 | return {0: list_of_files}, True 51 | 52 | # If more workers than files, then each worker gets all files, but reads 53 | # only a fraction of the rows 54 | if len(list_of_files) < total_workers or assign_all_files_per_worker: 55 | return {worker: list_of_files.copy() for worker in range(total_workers)}, True 56 | 57 | if not assign_by_size: 58 | # files are assigned randomly to workers 59 | list_of_files = list_of_files.copy() 60 | if should_shuffle_rows: 61 | random.shuffle(list_of_files) 62 | worker_to_files = { 63 | worker_id: list_of_files[worker_id::total_workers] 64 | for worker_id in range(total_workers) 65 | } 66 | return worker_to_files, False 67 | 68 | # Otherwise, assign files to workers balancing by file size 69 | list_of_files_and_sizes = [(file, get_file_size(file)) for file in list_of_files] 70 | list_of_files_and_sizes.sort(key=lambda x: x[1], reverse=True) 71 | 72 | worker_to_files = {i: [] for i in range(total_workers)} 73 | worker_loads = [(0, worker_id) for worker_id in range(total_workers)] 74 | 75 | for file, file_size in list_of_files_and_sizes: 76 | # assign file to the worker with smallest storage usage 77 | worker_load, min_worker_load_index = heapq.heappop(worker_loads) 78 | worker_to_files[min_worker_load_index].append(file) 79 | # update worker's total storage usage 80 | heapq.heappush(worker_loads, (worker_load + file_size, min_worker_load_index)) 81 | 82 | return worker_to_files, False 83 | 84 | 85 | def pad_or_trim_sequence( 86 | padded_sequence: torch.Tensor, sequence_length: int, padding_token: int = 0 87 | ) -> torch.Tensor: 88 | """Pad or trim the input sequence to the desired length.""" 89 | 90 | # truncation 91 | if padded_sequence.size(1) > sequence_length: 92 | # TODO (clark): if padded_sequence contains a lot of sequences sharing the same post-fix, 93 | # this current solution will create duplicate sequences. 
94 | bs, seq = padded_sequence.shape 95 | arange0 = torch.arange(seq, device=padded_sequence.device).repeat((bs, 1)) 96 | mask = padded_sequence == padding_token 97 | # gets the len before padding 98 | lengths = seq - mask.sum(1) 99 | # shifts only for sequences longer than max_len 100 | shift = torch.clamp(lengths - sequence_length, min=0).unsqueeze(1) 101 | # rotate the indexes so we can trim just the last ones 102 | final_idx = (arange0 + shift) % seq 103 | rotated = torch.gather(padded_sequence, 1, final_idx) 104 | # get just the max len 105 | padded_sequence = rotated[:, :sequence_length] 106 | 107 | # additional padding 108 | if padded_sequence.size(1) < sequence_length: 109 | padding_tensor = ( 110 | padding_token 111 | * torch.ones( 112 | (padded_sequence.shape[0], sequence_length - padded_sequence.size(1)) 113 | ).long() 114 | ) 115 | padded_sequence = torch.cat([padded_sequence, padding_tensor], dim=-1) 116 | return padded_sequence 117 | 118 | 119 | def combine_list_of_tensor_dicts( 120 | list_of_dicts: List[Dict[str, torch.Tensor]] 121 | ) -> Dict[str, List[torch.Tensor]]: 122 | batch = defaultdict(list) 123 | for sequence in list_of_dicts: 124 | for field_name, field_sequence in sequence.items(): 125 | batch[field_name].append(field_sequence) 126 | return batch 127 | 128 | 129 | def convert_all_tensors_to_device(object, device): 130 | if isinstance(object, torch.Tensor): 131 | return object.to(device) 132 | elif isinstance(object, dict): 133 | return { 134 | k: convert_all_tensors_to_device(v, device) 135 | for k, v in object.items() 136 | if v is not None and v != object 137 | } 138 | elif isinstance(object, list): 139 | return [ 140 | convert_all_tensors_to_device(v, device) 141 | for v in object 142 | if v is not None and v != object 143 | ] 144 | else: 145 | return object 146 | -------------------------------------------------------------------------------- /src/utils/custom_hydra_resolvers.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import operator as op 3 | from typing import Optional 4 | 5 | from omegaconf import DictConfig, ListConfig, OmegaConf 6 | 7 | """" 8 | Hydra allows for custom resolvers, which are functions that can be used to resolve values in the config. 9 | For example, one can manipulate strings or apply simple python functions to the config values. 10 | 11 | """ 12 | 13 | 14 | def remove_chars_from_string(s: str, chars: str) -> str: 15 | """Removes all occurrences of `chars` from `s`. 16 | 17 | :param s: The input string. 18 | :param chars: The characters to remove from `s`. 19 | 20 | :return: The string `s` with all occurrences of `chars` removed. 21 | """ 22 | return s.translate(str.maketrans("", "", chars)) 23 | 24 | 25 | def conditional_expression( 26 | condition_expression, value_if_true, value_if_false, **kwargs 27 | ): 28 | """ 29 | A generic resolver that evaluates a condition expression based on config values. 30 | """ 31 | try: 32 | # Evaluate the condition expression with the config context 33 | result = eval(condition_expression, {}, kwargs) 34 | return value_if_true if result else value_if_false 35 | except Exception as e: 36 | raise ValueError( 37 | f"Error evaluating condition: {condition_expression}. 
Error: {e}" 38 | ) 39 | 40 | 41 | def extract_fields_from_list_of_dicts( 42 | list_of_dicts: ListConfig, 43 | key: str, 44 | default: str = None, 45 | filter_key: str = None, 46 | filter_value: str = None, 47 | ) -> ListConfig: 48 | """ 49 | Extracts a list of values from a list of dictionaries based on a key, with an optional filter condition. 50 | 51 | :param list_of_dicts: The list of dictionaries to extract values from. 52 | :param key: The key to extract values for. 53 | :param default: The default value to use if the key is not found in a dictionary. 54 | :param filter_key: The key to filter dictionaries by. 55 | :param filter_value: The value that the filter_key should have for a dictionary to be included. 56 | 57 | Example: 58 | Given a list of dictionaries: 59 | [ 60 | {"name": "feature1", "is_sparse": True}, 61 | {"name": "feature2", "is_sparse": False}, 62 | {"name": "feature3"}, 63 | ] 64 | 65 | extract_fields_from_list_of_dicts(features, "name") 66 | will return ["feature1", "feature2", "feature3"] 67 | 68 | extract_fields_from_list_of_dicts(features, "is_sparse", default=False) 69 | will return [True, False, False] 70 | 71 | extract_fields_from_list_of_dicts(features, "name", default=False, filter_key="is_sparse", filter_value="True") 72 | will return ["feature1"] 73 | 74 | 75 | :return: A ListConfig of extracted values. 76 | """ 77 | if filter_key and filter_value: 78 | filtered_dicts = [ 79 | d for d in list_of_dicts if d.get(filter_key) == eval(filter_value) 80 | ] 81 | else: 82 | filtered_dicts = list_of_dicts 83 | 84 | return ListConfig([d.get(key, default) for d in filtered_dicts]) 85 | 86 | 87 | def create_map_from_list_of_dicts( 88 | list_of_dicts: ListConfig, key: str, value: Optional[str] = None 89 | ) -> DictConfig: 90 | """ 91 | Creates a dictionary from a list of dictionaries based on the key and value. 92 | For example, if a feature has a name and an attribute name dim, this function can be used to create a mapping 93 | from the feature name to the attribute: 94 | create_map_from_list_of_dicts(features, "name", "dim") 95 | If value is not provided, the function will return a dictionary with the key as the key and the value as the dictionary. 96 | create_map_from_list_of_dicts(features, "name") 97 | will return {"feature1": {"name": "feature1", "dim": 10}, "feature2": {"name": "feature2", "dim": 20}, "feature3": {"name": "feature3"}} 98 | """ 99 | if value is None: 100 | return DictConfig({d[key]: d for d in list_of_dicts if key in d}) 101 | 102 | return DictConfig( 103 | {d[key]: d[value] for d in list_of_dicts if key in d and value in d} 104 | ) 105 | 106 | 107 | def math_eval(expression: str) -> float: 108 | """ 109 | Evaluate a mathematical expression given as a string. 
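    The expression is parsed with `ast`, and only the operators listed in the `operators`
    mapping below are applied, so arbitrary Python code cannot be executed; unsupported
    constructs raise a TypeError.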
110 | 111 | Examples: 112 | 113 | ${math_eval:"2^6"} returns 64 114 | 115 | ${math_eval:"1 + 2*3**4 / (5 + -6)"} returns -161.0 116 | 117 | dim_1: 32 118 | dim_2: 96 119 | ${math_eval:${dim_1}+${dim_2}} returns 128 120 | """ 121 | # Supported operators 122 | operators = { 123 | ast.Add: op.add, 124 | ast.Sub: op.sub, 125 | ast.Mult: op.mul, 126 | ast.Div: op.truediv, 127 | ast.Pow: op.pow, 128 | ast.BitXor: op.xor, 129 | ast.USub: op.neg, 130 | } 131 | 132 | def eval_(node): 133 | # Recursively evaluate the AST nodes 134 | match node: 135 | case ast.Constant(value) if isinstance(value, int): 136 | return value # integer 137 | case ast.BinOp(left, op, right): 138 | return operators[type(op)](eval_(left), eval_(right)) 139 | case ast.UnaryOp(op, operand): # e.g., -1 140 | return operators[type(op)](eval_(operand)) 141 | case _: 142 | raise TypeError(node) 143 | 144 | return eval_(ast.parse(expression, mode="eval").body) 145 | 146 | 147 | def remove_item_from_list(input_list: ListConfig, item_to_remove: str) -> ListConfig: 148 | """ 149 | Removes all occurrences of a specific item from a list. 150 | :param input_list: The input list to remove items from. 151 | :param item_to_remove: The item to remove from the list. 152 | :return: A ListConfig with the specified item removed. 153 | """ 154 | return ListConfig([item for item in input_list if item != item_to_remove]) 155 | 156 | 157 | # resolvers need to be registered to be accessible during config composition. 158 | # The resolver name is the function name without the type annotations. 159 | OmegaConf.register_new_resolver("remove_chars_from_string", remove_chars_from_string) 160 | OmegaConf.register_new_resolver("conditional_expression", conditional_expression) 161 | OmegaConf.register_new_resolver( 162 | "extract_fields_from_list_of_dicts", extract_fields_from_list_of_dicts 163 | ) 164 | OmegaConf.register_new_resolver( 165 | "create_map_from_list_of_dicts", create_map_from_list_of_dicts 166 | ) 167 | OmegaConf.register_new_resolver("math_eval", math_eval) 168 | OmegaConf.register_new_resolver("remove_item_from_list", remove_item_from_list) 169 | -------------------------------------------------------------------------------- /src/models/modules/clustering/mini_batch_kmeans.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import Optional, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from src.components.distance_functions import DistanceFunction 8 | from src.components.clustering_initializers import ( 9 | ClusteringInitializer, 10 | KMeansPlusPlusInitInitializer, 11 | ) 12 | from src.components.loss_functions import WeightedSquaredError 13 | from src.models.modules.clustering.base_clustering_module import BaseClusteringModule 14 | 15 | 16 | class MiniBatchKMeans(BaseClusteringModule): 17 | def __init__( 18 | self, 19 | n_clusters: int, 20 | n_features: int, 21 | distance_function: DistanceFunction, 22 | initializer: ClusteringInitializer = KMeansPlusPlusInitInitializer, 23 | loss_function: torch.nn.Module = WeightedSquaredError(), 24 | optimizer: torch.optim.Optimizer = functools.partial( 25 | torch.optim.SGD, 26 | lr=0.5, 27 | ), 28 | init_buffer_size: int = 1000, 29 | update_manually: bool = False, 30 | ): 31 | """ 32 | Initialize an implementation of the mini-batch k-Means algorithm (Sculley 2010). 33 | 34 | Paper reference: https://dl.acm.org/doi/abs/10.1145/1772690.1772862 35 | 36 | Args: 37 | n_clusters: Number of clusters. 
38 | n_features: Number of features in the input data. 39 | distance_function: Distance function to use for computing distances between points. 40 | loss_function: Loss function to use for training. 41 | optimizer: Optimizer to use for training. 42 | initializer: Initialization method. 43 | init_buffer_size: Number of points to buffer for initialization. 44 | """ 45 | super().__init__( 46 | n_clusters=n_clusters, 47 | n_features=n_features, 48 | distance_function=distance_function, 49 | loss_function=loss_function, 50 | optimizer=optimizer, 51 | initializer=initializer, 52 | init_buffer_size=init_buffer_size, 53 | update_manually=update_manually, 54 | ) 55 | self.cluster_counts = torch.zeros(self.n_clusters) 56 | 57 | def forward( 58 | self, batch: torch.Tensor 59 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 60 | """ 61 | Perform a forward pass of the K-Means model on the input batch. 62 | 63 | This function computes the cluster assignments for each input point, the number 64 | of points in each cluster, and the sum of points in each cluster. 65 | 66 | Args: 67 | batch: Data points of shape (batch_size, n_features) 68 | 69 | Returns: 70 | assignments: Cluster assignments of shape (batch_size,) 71 | batch_cluster_counts: Number of points in each cluster of shape (n_clusters,) 72 | batch_cluster_sums: Sum of points in each cluster of shape (n_clusters, n_features) 73 | """ 74 | # Compute cluster assignments 75 | # Note that assignments is automatically detached from the computation graph 76 | # because it results from argmin 77 | assignments = self.predict_step(batch, return_embeddings=False) 78 | assignments_one_hot = ( 79 | nn.functional.one_hot(assignments, self.n_clusters) 80 | ).detach() 81 | # Count points in each cluster 82 | batch_cluster_counts = torch.sum(assignments_one_hot, dim=0) 83 | self.cluster_counts += batch_cluster_counts 84 | # Accumulate points for each cluster 85 | batch_cluster_sums = torch.mm(assignments_one_hot.float().t(), batch) 86 | 87 | return assignments, batch_cluster_counts, batch_cluster_sums 88 | 89 | def model_step( 90 | self, batch: torch.Tensor 91 | ) -> Tuple[torch.Tensor, torch.Tensor, bool]: 92 | """ 93 | Perform a forward pass of the K-Means model on the batch and compute the loss. 94 | 95 | This function may be called by another LightningModule, such as a residual 96 | K-means module, that is using this MiniBatchKMeans module as a submodule. 97 | 98 | Calling this function along will not update the centroids, and will not 99 | increment self.global_step. If a parent module is using this module as a 100 | submodule, the parent will be responsible for updating those parameters. 101 | Otherwise, these will be updated by Lightning after it calls 102 | training_step. 103 | 104 | Args: 105 | batch: Data points of shape (batch_size, n_features) 106 | 107 | Returns: 108 | assignments: Cluster assignments of shape (batch_size,) 109 | embeddings: Embeddings of shape (batch_size, n_features) 110 | loss: Loss value. Tensor of shape (1,), or None if update_manually is True. 
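
        Note on the update rule (matching the implementation below): for every cluster k
        that appears in the batch, the centroid moves as c_k <- c_k - w_k * (c_k - m_k),
        where m_k is the mean of the batch points assigned to k and w_k is the batch
        count of k divided by its cumulative count. When `update_manually` is False, the
        same update results from one SGD step with learning rate 0.5 on the weighted
        squared error sum_k w_k * ||c_k - m_k||^2.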
111 | """ 112 | batch = batch.to(self.device) 113 | 114 | # Initialize centroids using the chosen method 115 | # Buffer initial batches for better initialization 116 | if self.is_initial_step: 117 | self.is_initial_step = False 118 | self.is_initialized = True 119 | if not self.is_initialized: 120 | return self.initialization_step(batch) 121 | 122 | assignments, batch_cluster_counts, batch_cluster_sums = self.forward(batch) 123 | 124 | centroids = self.get_centroids() 125 | # Use a mask to avoid division by zero 126 | mask = batch_cluster_counts != 0 127 | mask_target = batch_cluster_sums[mask] / batch_cluster_counts[mask].unsqueeze(1) 128 | centroid_weights = batch_cluster_counts[mask] / self.cluster_counts[mask] 129 | 130 | if self.update_manually: 131 | self.centroids[mask] = self.centroids[mask].data - ( 132 | (centroids[mask].data - mask_target) * centroid_weights.unsqueeze(1) 133 | ) 134 | return assignments, centroids[assignments], None 135 | else: 136 | # The MiniBatchKMeans algorithm update above is equivalent to an SGD step 137 | # with learning rate 0.5 on the loss function below 138 | loss = self.loss_function(centroids[mask], mask_target, centroid_weights) 139 | return assignments, centroids[assignments], loss 140 | 141 | def on_train_start(self) -> None: 142 | """Lightning callback to reset the model state at the start of training.""" 143 | self.cluster_counts = torch.zeros(self.n_clusters, device=self.device) 144 | super().on_train_start() 145 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Generative Recommendation with Semantic IDs (GRID) 2 | [![PyTorch](https://img.shields.io/badge/pytorch-2.0%2B-red)](https://pytorch.org/) 3 | [![Hydra](https://img.shields.io/badge/config-hydra-89b8cd)](https://hydra.cc/) 4 | [![Lightning](https://img.shields.io/badge/pytorch-lightning-792ee5)](https://lightning.ai/) 5 | [![arXiv](https://img.shields.io/badge/arXiv-2507.22224-b31b1b.svg)](https://arxiv.org/abs/2507.22224) 6 | 7 | 8 | **GRID** (Generative Recommendation with Semantic IDs) is a state-of-the-art framework for generative recommendation systems using semantic IDs, developed by a group of scientists and engineers from [Snap Research](https://research.snap.com/team/user-modeling-and-personalization.html). This project implements novel approaches for learning semantic IDs from text embedding and generating recommendations through transformer-based generative models. 9 | 10 | ## 🚀 Overview 11 | 12 | GRID facilitates generative recommendation three overarching steps: 13 | 14 | - **Embedding Generation with LLMs**: Converting item text into embeddings using any LLMs available on Huggingface. 15 | - **Semantic ID Learning**: Converting item embedding into hierarchical semantic IDs using Residual Quantization techniques such as RQ-KMeans, RQ-VAE, RVQ. 16 | - **Generative Recommendations**: Using transformer architectures to generate recommendation sequences as semantic ID tokens. 17 | 18 | 19 | ## 📦 Installation 20 | 21 | ### Prerequisites 22 | - Python 3.10+ 23 | - CUDA-compatible GPU (recommended) 24 | 25 | ### Setup Environment 26 | 27 | ```bash 28 | # Clone the repository 29 | git clone https://github.com/snap-research/GRID.git 30 | cd GRID 31 | 32 | # Install dependencies 33 | pip install -r requirements.txt 34 | ``` 35 | 36 | ## 🎯 Quick Start 37 | 38 | ### 1. 
39 | 
40 | Prepare your dataset in the expected format:
41 | ```
42 | data/
43 | ├── train/       # training sequences of user history
44 | ├── validation/  # validation sequences of user history
45 | ├── test/        # testing sequences of user history
46 | └── items/       # text of all items in the dataset
47 | ```
48 | 
49 | We provide pre-processed Amazon data explored in the [P5 paper](https://arxiv.org/abs/2203.13366) [4]. The data can be downloaded from this [Google Drive link](https://drive.google.com/file/d/1B5_q_MT3GYxmHLrMK0-lAqgpbAuikKEz/view?usp=sharing).
50 | 
51 | ### 2. Embedding Generation with LLMs
52 | 
53 | Generate embeddings with an LLM; these embeddings will later be transformed into semantic IDs.
54 | 
55 | ```bash
56 | python -m src.inference experiment=sem_embeds_inference_flat data_dir=data/amazon_data/beauty # available data includes 'beauty', 'sports', and 'toys'
57 | ```
58 | 
59 | ### 3. Train and Generate Semantic IDs
60 | 
61 | Learn semantic ID centroids for the embeddings generated in step 2. Here `embedding_path` points to the merged predictions file found in the log dir from step 2, `embedding_dim` is the model dimension of the LLM used in step 2 (2048 for flan-t5-xl, as in this example), `num_hierarchies=3` trains 3 codebooks, and `codebook_width=256` gives each codebook 256 rows of centroids:
62 | 
63 | ```bash
64 | python -m src.train experiment=rkmeans_train_flat \
65 |     data_dir=data/amazon_data/beauty \
66 |     embedding_path=/merged_predictions_tensor.pt \
67 |     embedding_dim=2048 \
68 |     num_hierarchies=3 \
69 |     codebook_width=256
70 | ```
71 | 
72 | Generate SIDs:
73 | 
74 | ```bash
75 | python -m src.inference experiment=rkmeans_inference_flat \
76 |     data_dir=data/amazon_data/beauty \
77 |     embedding_path=/merged_predictions_tensor.pt \
78 |     embedding_dim=2048 \
79 |     num_hierarchies=3 \
80 |     codebook_width=256 \
81 |     ckpt_path= # this can be found in the log dir for training SIDs
82 | ```
83 | 
84 | 
85 | ### 4. Train Generative Recommendation Model with Semantic IDs
86 | 
87 | Train the recommendation model using the learned semantic IDs:
88 | 
89 | ```bash
90 | python -m src.train experiment=tiger_train_flat \
91 |     data_dir=data/amazon_data/beauty \
92 |     semantic_id_path=/pickle/merged_predictions_tensor.pt \
93 |     num_hierarchies=4 # Note that we add 1 to num_hierarchies because the previous step appended one additional digit to de-duplicate the generated semantic IDs.
94 | ```
95 | 
96 | ### 5. Generate Recommendations
97 | 
98 | Run inference to generate recommendations. `ckpt_path` should point to the checkpoint found in the log dir of the GR training run, and `num_hierarchies` is again 4 because of the extra de-duplication digit appended in step 3:
99 | 
100 | ```bash
101 | python -m src.inference experiment=tiger_inference_flat \
102 |     data_dir=data/amazon_data/beauty \
103 |     semantic_id_path=/pickle/merged_predictions_tensor.pt \
104 |     ckpt_path= \
105 |     num_hierarchies=4
106 | ```
107 | 
108 | ## Supported Models:
109 | 
110 | ### Semantic ID:
111 | 
112 | 1. Residual K-means proposed in One-Rec [2]
113 | 2. Residual Vector Quantization
114 | 3. Residual Quantization with Variational Autoencoder [3]
115 | 
116 | ### Generative Recommendation:
117 | 
118 | 1.
TIGER [1] 119 | 120 | ## 📚 Citation 121 | 122 | If you use GRID in your research, please cite: 123 | 124 | ```bibtex 125 | @inproceedings{grid, 126 | title = {Generative Recommendation with Semantic IDs: A Practitioner's Handbook}, 127 | author = {Ju, Clark Mingxuan and Collins, Liam and Neves, Leonardo and Kumar, Bhuvesh and Wang, Louis Yufeng and Zhao, Tong and Shah, Neil}, 128 | booktitle = {Proceedings of the 34th ACM International Conference on Information and Knowledge Management (CIKM)}, 129 | year = {2025} 130 | } 131 | ``` 132 | 133 | ## 🤝 Acknowledgments 134 | 135 | - Built with [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/) 136 | - Configuration management by [Hydra](https://hydra.cc/) 137 | - Inspired by recent advances in generative AI and recommendation systems 138 | - Part of this repo is built on top of https://github.com/ashleve/lightning-hydra-template 139 | 140 | ## 📞 Contact 141 | 142 | For questions and support: 143 | - Create an issue on GitHub 144 | - Contact the development team: Clark Mingxuan Ju (mju@snap.com), Liam Collins (lcollins2@snap.com), Bhuvesh Kumar (bhuvesh@snap.com) and Leonardo Neves (lneves@snap.com). 145 | 146 | ## Bibliography 147 | 148 | [1] Rajput, Shashank, et al. "Recommender systems with generative retrieval." Advances in Neural Information Processing Systems 36 (2023): 10299-10315. 149 | 150 | [2] Deng, Jiaxin, et al. "Onerec: Unifying retrieve and rank with generative recommender and iterative preference alignment." arXiv preprint arXiv:2502.18965 (2025). 151 | 152 | [3] Lee, Doyup, et al. "Autoregressive image generation using residual quantization." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2022. 153 | 154 | [4] Geng, Shijie, et al. "Recommendation as language processing (rlp): A unified pretrain, personalized prompt & predict paradigm (p5)." Proceedings of the 16th ACM conference on recommender systems. 2022. 155 | -------------------------------------------------------------------------------- /src/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from shutil import SameFileError 5 | from typing import BinaryIO, List, Optional 6 | 7 | from fsspec.core import url_to_fs 8 | from lightning.fabric.utilities.types import _PATH 9 | from pyarrow import fs as pyarrow_fs 10 | 11 | from src.utils.decorators import retry 12 | 13 | 14 | @retry() 15 | def get_file_size(file_path: str) -> int: 16 | fs, _ = url_to_fs(file_path) 17 | return fs.size(file_path) 18 | 19 | 20 | @retry() 21 | def copy_to_remote(local_path: str, remote_path: str, recursive: bool = True) -> None: 22 | try: 23 | logging.info(f"Copying {local_path} to {remote_path}") 24 | fs, _ = url_to_fs(remote_path) 25 | fs.put(local_path, remote_path, recursive=recursive) 26 | logging.info(f"Finished copying {local_path} to {remote_path}") 27 | except SameFileError: 28 | logging.warning(f"{local_path} and {remote_path} are the same. 
Skipping copy.") 29 | 30 | 31 | @retry() 32 | def file_exists_local_or_remote(file_path: str) -> bool: 33 | fs, _ = url_to_fs(file_path) 34 | return fs.exists(file_path) 35 | 36 | 37 | @retry() 38 | def open_local_or_remote(file_path: str, mode: str = "r") -> BinaryIO: 39 | fs, _ = url_to_fs(file_path) 40 | return fs.open(file_path, mode) 41 | 42 | 43 | def load_json(file_path: str) -> dict: 44 | with open_local_or_remote(file_path, "r") as f: 45 | feature_map = json.load(f) 46 | return feature_map 47 | 48 | 49 | @retry() 50 | def open_pyarrow_file(file_path: str): 51 | # Optimized function for large pyarrow files (ie. parquet) 52 | # enable us to read row groups instead of entire file. 53 | fs, path = pyarrow_fs.FileSystem.from_uri(file_path) 54 | return fs.open_input_file(path) 55 | 56 | 57 | def get_last_modified_file( 58 | folder_path: str, suffix="*", should_update_prefix=True 59 | ) -> str: 60 | """ 61 | Can get the last modified file in a folder from a local or remote filesystem. 62 | """ 63 | fs, _ = url_to_fs(folder_path) 64 | file_list = list_files(folder_path, suffix, should_update_prefix) 65 | if not file_list: 66 | return "" 67 | 68 | latest_mtime = 0 69 | for file in file_list: 70 | info = fs.info(file) 71 | mtime = info.get("mtime", 0) 72 | if latest_mtime == 0 or mtime > latest_mtime: 73 | latest_mtime = mtime 74 | latest_file = file 75 | return latest_file 76 | 77 | 78 | def remove_file_extension(path: _PATH) -> _PATH: 79 | """ 80 | Removes the file extension from a given file path. 81 | 82 | Args: 83 | path (_PATH): The file path from which to remove the extension. 84 | 85 | Returns: 86 | _PATH: The file path without the extension. 87 | 88 | Example: 89 | >>> remove_file_extension("example/file.txt") 90 | 'example/file' 91 | """ 92 | base, _ = os.path.splitext(path) 93 | return base 94 | 95 | 96 | def has_no_extension(filepath: _PATH) -> bool: 97 | # Extract just the filename from the path, handles both local and cloud paths 98 | filename = os.path.basename(filepath) 99 | # Split the filename and check if extension is empty 100 | _, extension = os.path.splitext(filename) 101 | return extension == "" 102 | 103 | 104 | def list_subfolders( 105 | directory_path, 106 | should_update_prefix: bool = True, 107 | ): 108 | """ 109 | List all folders inside a directory using fsspec. 110 | 111 | Args: 112 | directory_path (str): Path to the directory to search 113 | (can be local or remote like 's3://', 'gs://', etc.) 114 | should_update_prefix (bool): If True, adds the prefix based on the filesystem, 115 | otherwise returns the path generated by glob 116 | Returns: 117 | list: List of folder paths 118 | """ 119 | # Get the appropriate filesystem 120 | fs, _ = url_to_fs(directory_path) 121 | 122 | # List all items in the directory 123 | all_items = fs.ls(directory_path) 124 | 125 | # Filter to only include directories that are not the original one. 
126 | folders = [ 127 | f"{fs.protocol[0]}://{item}" if should_update_prefix else item 128 | for item in all_items 129 | if fs.isdir(item) and item != directory_path 130 | ] 131 | 132 | return folders 133 | 134 | 135 | @retry() 136 | def list_files( 137 | folder_path: str, 138 | suffix: str = "*", 139 | # if should_update_prefix is True, adds the prefix based on the filesystem, 140 | # otherwise returns the path generated by glob 141 | should_update_prefix: bool = True, 142 | ) -> List[str]: 143 | 144 | # We remove trailing slashes to avoid double slashes in the path 145 | folder_path = folder_path.removesuffix("/") 146 | 147 | fs, _ = url_to_fs(folder_path) 148 | return ( 149 | # add the prefix for gcs. 150 | [f"{fs.protocol[0]}://{x}" for x in fs.glob(f"{folder_path}/{suffix}")] 151 | if should_update_prefix 152 | else fs.glob(f"{folder_path}/{suffix}") 153 | ) 154 | 155 | 156 | def replace_char_after_segment( 157 | path: str, 158 | char_to_replace: str, 159 | replacement_char: str, 160 | segment_to_find: Optional[str] = None, 161 | ) -> str: 162 | """ 163 | Replace a specific character with another character in a path string. 164 | If segment_to_find is provided, replacements occur only after that segment, 165 | or returns the original string if segment is not found. 166 | If segment_to_find is None, replacements occur throughout the entire string. 167 | 168 | Args: 169 | path (str): The full path string to process 170 | char_to_replace (str): The character to be replaced 171 | replacement_char (str): The character to use as replacement 172 | segment_to_find (Optional[str]): The path segment after which replacements 173 | should occur. If None, replace in the entire string. 174 | 175 | Returns: 176 | str: The modified path with character replacements. 177 | If segment_to_find is provided but not found, returns the original path unchanged. 
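    Illustrative calls (hypothetical paths, evaluated against the logic below):

        >>> replace_char_after_segment("gs://bucket/run=1/a:b", ":", "_", segment_to_find="run=1")
        'gs://bucket/run=1/a_b'
        >>> replace_char_after_segment("a:b/c:d", ":", "_")
        'a_b/c_d'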
178 | """ 179 | # If no segment is specified, replace throughout the entire string 180 | if segment_to_find is None: 181 | return path.replace(char_to_replace, replacement_char) 182 | 183 | # Find the position of the segment in the path 184 | segment_index = path.find(segment_to_find) 185 | 186 | if segment_index != -1: 187 | # Include the full segment in the "before" part 188 | segment_end = segment_index + len(segment_to_find) 189 | before_segment = path[:segment_end] 190 | after_segment = path[segment_end:] 191 | 192 | # Replace characters only in the part after the segment 193 | modified_after_segment = after_segment.replace( 194 | char_to_replace, replacement_char 195 | ) 196 | 197 | # Combine the parts 198 | return before_segment + modified_after_segment 199 | 200 | # If segment is not found, return the original path unchanged 201 | return path 202 | -------------------------------------------------------------------------------- /src/utils/decorators.py: -------------------------------------------------------------------------------- 1 | import errno 2 | import os 3 | import signal 4 | import time 5 | from functools import wraps 6 | from typing import Callable, Optional, Tuple, Type, TypeVar 7 | 8 | from src.utils.pylogger import RankedLogger 9 | 10 | T = TypeVar("T") # return type 11 | 12 | logger = RankedLogger(__name__) 13 | 14 | 15 | class TimedOutException(Exception): 16 | pass 17 | 18 | 19 | class RetriesFailedException(Exception): 20 | pass 21 | 22 | 23 | class __RetriableTimeoutException(Exception): 24 | pass 25 | 26 | def timeout( 27 | seconds=10, 28 | error_message=os.strerror(errno.ETIME), 29 | timeout_action_func=None, 30 | exception_thrown_on_timeout=TimedOutException, 31 | **timeout_action_func_params, 32 | ): 33 | """ 34 | Decorator to exit a program when a function execution exceeds a specified timeout. 35 | 36 | This decorator exits the program when the decorated function timed out, and executes 37 | timeout_action_func before exiting. The timeout_action_func can be useful for cases 38 | like environment clean up (e.g., ray.shutdown()). Another way to handle timeouts, 39 | especially when there are multiple threads or child threads, is to run `func` with 40 | multiprocessing. However, multiprocessing does not work with Ray. 41 | 42 | References to the current timeout implementation: 43 | - https://docs.python.org/3/library/signal.html 44 | - https://www.saltycrane.com/blog/2010/04/using-python-timeout-decorator-uploading-s3/ 45 | 46 | Parameters: 47 | - seconds (int): Timeout duration in seconds. Defaults to 10. 48 | - error_message (str): Error message to log on timeout. Defaults to os.strerror(errno.ETIME). 49 | - timeout_action_func (callable, optional): Function to execute before exiting on timeout. 50 | - exception_thrown_on_timeout (Exception): Exception to raise on timeout. Defaults to TimedOutException. 51 | - **timeout_action_func_params: Arbitrary keyword arguments passed to timeout_action_func. 
52 | """ 53 | 54 | def decorator(func): 55 | def _handler(signum, frame): 56 | logger.info(error_message) 57 | if timeout_action_func: 58 | timeout_action_func(**timeout_action_func_params) 59 | raise exception_thrown_on_timeout() 60 | 61 | def wrapper(*args, **kwargs): 62 | old = signal.signal(signal.SIGALRM, _handler) 63 | signal.alarm(seconds) 64 | try: 65 | result = func(*args, **kwargs) 66 | finally: 67 | # cancel the alarm 68 | signal.alarm(0) 69 | # reinstall the old signal handler 70 | signal.signal(signal.SIGALRM, old) 71 | return result 72 | 73 | return wraps(func)(wrapper) 74 | 75 | return decorator 76 | 77 | 78 | def retry( 79 | exception_to_check: Type = Exception, 80 | tries: int = 5, 81 | delay_s: int = 3, 82 | backoff: int = 2, 83 | max_delay_s: Optional[int] = None, 84 | fn_execution_timeout_s: Optional[int] = None, 85 | deadline_s: Optional[int] = None, 86 | should_throw_original_exception: bool = False, 87 | ) -> Callable[[Callable[..., T]], Callable[..., T]]: 88 | """ 89 | Decorator that can be added around a function to retry incase it fails i.e. throws some exceptions 90 | 91 | Args: 92 | exception_to_check (Optional[Type]): the exception to check. may be a tuple of 93 | exceptions to check. Defaults to Exception. i.e. catches everything 94 | tries (Optional[int]): [description]. number of times to try (not retry) before giving up. Defaults to 5. 95 | delay_s (Optional[int]): [description]. initial delay between retries in seconds. Defaults to 3. 96 | backoff (Optional[int]): [description]. backoff multiplier e.g. value of 2 will double the delay 97 | each retry. Defaults to 2. 98 | max_delay_s (Optional[int]): [description]. maximum delay between retries in seconds. Defaults to None. 99 | fn_execution_timeout_s (Optional[int]): Maximum time given before a single function 100 | execution should time out. Defaults to None. 101 | deadline_s (Optional[int]): [description]. Total time in seconds to spend retrying, fails if exceeds 102 | this time. Note this timeout can also stop the first execution, so ensure to provide 103 | a lot of extra room so retries can actually take place. 104 | If timeout occurs, src.common.utils.timeout.TimedOutException is raised. 105 | Defaults to None. 106 | should_throw_original_exception (Optional[bool]): Defaults to False. 107 | """ 108 | 109 | def deco_retry(f) -> Callable[..., T]: 110 | @wraps(f) 111 | def f_retry(*args, **kwargs) -> T: # type: ignore[type-var] 112 | mtries, mdelay = tries, delay_s 113 | 114 | def fn(*args, **kwargs) -> T: 115 | if fn_execution_timeout_s is not None: 116 | timeout_individual_fn_call_decorator = timeout( 117 | seconds=fn_execution_timeout_s, 118 | exception_thrown_on_timeout=__RetriableTimeoutException, 119 | ) 120 | return timeout_individual_fn_call_decorator(f)(*args, **kwargs) 121 | return f(*args, **kwargs) 122 | 123 | acceptable_exceptions: Tuple[Type[Exception], ...] 
= ( 124 | exception_to_check 125 | if isinstance(exception_to_check, tuple) 126 | else (exception_to_check,) 127 | ) 128 | acceptable_exceptions = acceptable_exceptions + ( 129 | __RetriableTimeoutException, 130 | ) 131 | 132 | ret_val: T 133 | while mtries >= 0: 134 | try: 135 | ret_val = fn(*args, **kwargs) 136 | break 137 | except TimedOutException: 138 | raise 139 | except acceptable_exceptions as e: 140 | if mtries == 0: 141 | # Failed for the final time 142 | logger.exception(f"Failed for the last time: {e}") 143 | if should_throw_original_exception: 144 | raise # Reraise original exception 145 | raise RetriesFailedException( 146 | f"Retry failed, permanently failing {f.__module__}:{f.__name__}, see logs for {e}" 147 | ) 148 | msg = f"{e}, Retrying {f.__module__}:{f.__name__} in {mdelay} seconds..." 149 | logger.warning(msg) 150 | time.sleep(mdelay) 151 | mtries -= 1 152 | mdelay = ( 153 | min(mdelay * backoff, max_delay_s) 154 | if max_delay_s 155 | else mdelay * backoff 156 | ) 157 | 158 | return ret_val 159 | 160 | if deadline_s is not None: 161 | global_retry_timeout_decorator = timeout(seconds=deadline_s) 162 | return global_retry_timeout_decorator(f_retry) 163 | 164 | return f_retry 165 | 166 | return deco_retry 167 | -------------------------------------------------------------------------------- /src/models/modules/huggingface/transformer_base_module.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Tuple 2 | 3 | import torch 4 | import transformers 5 | from torchmetrics.aggregation import BaseAggregator 6 | 7 | from src.components.eval_metrics import RetrievalEvaluator 8 | from src.data.loading.components.interfaces import ( 9 | SequentialModelInputData, 10 | SequentialModuleLabelData, 11 | ) 12 | from src.models.components.interfaces import SharedKeyAcrossPredictionsOutput 13 | from src.models.components.network_blocks.embedding_aggregator import ( 14 | EmbeddingAggregator, 15 | ) 16 | from src.models.modules.base_module import BaseModule 17 | 18 | 19 | class TransformerBaseModule(BaseModule): 20 | def __init__( 21 | self, 22 | huggingface_model: transformers.PreTrainedModel, 23 | postprocessor: torch.nn.Module, 24 | aggregator: EmbeddingAggregator, 25 | optimizer: torch.optim.Optimizer, 26 | scheduler: torch.optim.lr_scheduler, 27 | loss_function: torch.nn.Module, 28 | evaluator: RetrievalEvaluator, 29 | weight_tying: bool, 30 | compile: bool, 31 | training_loop_function: callable = None, 32 | feature_to_model_input_map: Dict[str, str] = {}, 33 | decoder: torch.nn.Module = None, 34 | ) -> None: 35 | 36 | super().__init__( 37 | model=huggingface_model, 38 | optimizer=optimizer, 39 | scheduler=scheduler, 40 | loss_function=loss_function, 41 | evaluator=evaluator, 42 | training_loop_function=training_loop_function, 43 | ) 44 | 45 | # this line allows to access init params with 'self.hparams' attribute 46 | # also ensures init params will be stored in ckpt 47 | # we remove the nn.Modules as they are already checkpointed to avoid doing it twice 48 | 49 | self.save_hyperparameters( 50 | logger=False, 51 | ignore=[ 52 | "huggingface_model", 53 | "postprocessor", 54 | "aggregator", 55 | "decoder", 56 | "loss_function", 57 | ], 58 | ) 59 | 60 | self.encoder = huggingface_model 61 | self.embedding_post_processor = postprocessor 62 | self.decoder = decoder 63 | self.aggregator = aggregator 64 | self.feature_to_model_input_map = feature_to_model_input_map 65 | 66 | def get_embedding_table(self): 67 | if 
self.hparams.weight_tying: # type: ignore 68 | return self.encoder.get_input_embeddings().weight 69 | else: 70 | return self.decoder.weight 71 | 72 | def training_step( 73 | self, 74 | batch: Tuple[Tuple[SequentialModelInputData, SequentialModuleLabelData]], 75 | batch_idx: int, 76 | ) -> torch.Tensor: 77 | """Perform a single training step on a batch of data from the training set. 78 | 79 | :param batch: A batch of data of data (tuple). Because of lightning, the tuple is wrapped in another tuple, 80 | and the actual batch is at position 0. The batch is a tuple of data where first object is a SequentialModelInputData object 81 | and second is a SequentialModuleLabelData object. 82 | :param batch_idx: The index of the current batch. 83 | :return: A tensor of losses between model predictions and targets. 84 | """ 85 | # Lightning wraps it in a tuple for training, we get the batch from position 0. 86 | # this behavior only happens for training_step. 87 | batch = batch[0] 88 | # Batch is a tuple of model inputs and labels. 89 | model_input: SequentialModelInputData = batch[0] 90 | label_data: SequentialModuleLabelData = batch[1] 91 | # Batch will be a tuple of model inputs and labels. We use the index here to access them. 92 | model_output, loss = self.model_step( 93 | model_input=model_input, label_data=label_data 94 | ) 95 | 96 | # update and log metrics. Will only be logged at the interval specified in the logger config 97 | self.train_loss(loss) 98 | # checks logging interval and logs the loss 99 | self.log( 100 | "train/loss", 101 | self.train_loss, 102 | on_step=True, 103 | on_epoch=True, 104 | prog_bar=True, 105 | logger=True, 106 | sync_dist=True, 107 | ) 108 | 109 | # If a training loop function is passed, we call it with the module and the loss. 110 | # otherwise we use the automatic optimization provided by lightning 111 | if self.training_loop_function is not None: 112 | self.training_loop_function(self, loss) 113 | 114 | return loss 115 | 116 | def eval_step( 117 | self, 118 | batch: Tuple[SequentialModelInputData, SequentialModuleLabelData], 119 | loss_to_aggregate: BaseAggregator, 120 | ): 121 | """Perform a single evaluation step on a batch of data from the validation or test set. 122 | The method will update the metrics and the loss that is passed. 123 | """ 124 | # Batch is a tuple of model inputs and labels. 125 | model_input: SequentialModelInputData = batch[0] 126 | label_data: SequentialModuleLabelData = batch[1] 127 | 128 | model_output_before_aggregation, loss = self.model_step( 129 | model_input=model_input, label_data=label_data 130 | ) 131 | 132 | model_output_after_aggregation = self.aggregator( 133 | model_output_before_aggregation, model_input.mask 134 | ) 135 | 136 | # Updates metrics inside evaluator. 137 | self.evaluator( 138 | query_embeddings=model_output_after_aggregation, 139 | key_embeddings=self.get_embedding_table().to( 140 | model_output_after_aggregation.device 141 | ), 142 | # TODO: (lneves) hardcoded for now, will need to change for multiple features 143 | labels=list(label_data.labels.values())[0].to( 144 | model_output_after_aggregation.device 145 | ), 146 | ) 147 | loss_to_aggregate(loss) 148 | 149 | def predict_step( 150 | self, 151 | batch: Tuple[SequentialModelInputData, SequentialModuleLabelData], 152 | batch_idx: int, 153 | ): 154 | """Perform a single prediction step on a batch of data from the test set. 
155 | 156 | :param batch: A batch of data of data (tuple) where first object is a SequentialModelInputData object 157 | and second is a SequentialModuleLabelData object. 158 | """ 159 | model_input: SequentialModelInputData = batch[0] 160 | model_output_before_aggregation, _ = self.model_step(model_input=model_input) 161 | 162 | model_output_after_aggregation = self.aggregator( 163 | model_output_before_aggregation, model_input.mask 164 | ) 165 | # TODO(lneves): Currently passing batch idx, change it to user_id and allow for the user to specify the key and prediction names. 166 | model_output = SharedKeyAcrossPredictionsOutput( 167 | key=batch_idx, 168 | predictions=model_output_after_aggregation, 169 | key_name=self.prediction_key_name, 170 | prediction_name=self.prediction_name, 171 | ) 172 | return model_output 173 | -------------------------------------------------------------------------------- /src/data/loading/components/dataloading.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Any, List, Optional 3 | 4 | from torch.utils.data import IterableDataset, get_worker_info 5 | 6 | from src.data.loading.components.interfaces import BaseDatasetConfig 7 | from src.utils.pylogger import RankedLogger 8 | 9 | command_line_logger = RankedLogger(__name__, rank_zero_only=True) 10 | 11 | 12 | class BaseDataset: 13 | def __init__( 14 | self, 15 | dataset_config: BaseDatasetConfig, 16 | data_folder: str, 17 | should_shuffle_rows: bool = False, 18 | batch_size: int = 1, 19 | is_for_training: bool = True, 20 | assign_all_files_per_worker: bool = False, 21 | ): 22 | """ 23 | Base class for all datasets. This class is used to set up the dataset and provide the list of files to be used. 24 | Args: 25 | dataset_config (BaseDatasetConfig): Configuration for the dataset. 26 | data_folder (str): Path to the folder where the data is stored. 27 | should_shuffle_rows (bool): Whether to shuffle the rows of the dataset. 28 | batch_size (int): Batch size to be used for the dataset. 29 | is_for_training (bool): Whether the dataset is for training or not. 30 | assign_all_files_per_worker (bool): Whether to assign all files to each worker or not. 31 | This will enable each worker to access all files. Each worker will locally shuffle the files. 32 | This would be useful for small datasets. In smaller datasets, if each worker only observes a subset of the files, 33 | it may not be able to learn the distribution of the data. 
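# A minimal sketch of the file-sharding rule described above (the function name and
# signature are illustrative): each of num_workers dataloader workers takes every
# num_workers-th file, unless assign_all_files_per_worker is set, in which case every
# worker sees (and locally shuffles) the full file list.

def shard_files(files, worker_id, num_workers, assign_all_files_per_worker=False):
    if assign_all_files_per_worker:
        return list(files)
    return files[worker_id::num_workers]

# e.g. shard_files(["f0", "f1", "f2", "f3", "f4", "f5"], worker_id=1, num_workers=3)
# returns ["f1", "f4"].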
34 | """ 35 | self.dataset_config = dataset_config 36 | self.should_shuffle_rows = should_shuffle_rows 37 | self.data_folder = data_folder 38 | self.list_of_file_paths = [] 39 | self.batch_size = batch_size 40 | self.is_for_training = is_for_training 41 | self.assign_all_files_per_worker = assign_all_files_per_worker 42 | 43 | def set_list_of_files(self, list_of_files: List[str]): 44 | self.list_of_file_paths = list_of_files 45 | 46 | def set_distributed_params(self, total_workers: int, global_worker_id: int): 47 | # TODO (lneves): figure out how to do this in LightningDataModule 48 | self.total_workers = total_workers 49 | self.global_worker_id = global_worker_id 50 | 51 | def get_worker_id_and_num_workers(self): 52 | worker_info = get_worker_info() 53 | 54 | if worker_info is None: 55 | # Single-worker setup (no multiprocessing) 56 | worker_id = 0 57 | num_workers = 1 58 | else: 59 | # Multi-worker setup 60 | worker_id = worker_info.id 61 | num_workers = worker_info.num_workers 62 | 63 | self.global_dataloader_worker_id = ( 64 | self.global_worker_id * num_workers + worker_id 65 | ) 66 | 67 | return worker_id, num_workers 68 | 69 | def get_list_of_worker_files(self): 70 | # Get information about worker and then separate only files that belong to this worker 71 | worker_id, num_workers = self.get_worker_id_and_num_workers() 72 | if self.assign_all_files_per_worker: 73 | worker_files = self.list_of_file_paths 74 | else: 75 | worker_files = self.list_of_file_paths[worker_id::num_workers] 76 | command_line_logger.debug( 77 | f"GPU Worker: {self.global_worker_id}/{self.total_workers} CPU Worker {worker_id} has {len(worker_files)} files" 78 | ) 79 | return worker_files 80 | 81 | def setup(self): 82 | pass 83 | 84 | 85 | class UnboundedSequenceIterable(BaseDataset, IterableDataset): 86 | """An unbounded dataset is a dataset that we don't know the size of beforehand. 87 | For training, we will iterate over the dataset infinitely. For evaluation, we will iterate over the dataset once. 88 | """ 89 | 90 | def __init__( 91 | self, 92 | dataset_config: BaseDatasetConfig, 93 | data_folder: str, 94 | should_shuffle_rows: bool = False, 95 | batch_size: int = 1, 96 | is_for_training: bool = True, 97 | assign_all_files_per_worker: bool = False, 98 | ): 99 | super().__init__( 100 | dataset_config=dataset_config, 101 | data_folder=data_folder, 102 | should_shuffle_rows=should_shuffle_rows, 103 | batch_size=batch_size, 104 | is_for_training=is_for_training, 105 | assign_all_files_per_worker=assign_all_files_per_worker, 106 | ) 107 | self.data_iterator = dataset_config.data_iterator 108 | self.dataset_to_iterate = None 109 | 110 | def setup(self): 111 | # We update each worker's data iterator with the files just for that worker. 112 | self.data_iterator.update_list_of_file_paths(self.get_list_of_worker_files()) 113 | self.data_iterator = ( 114 | # here we use global_dataloader_worker_id as the seed for shuffling 115 | # this doesn't matter for the case where workers have non-overlapping files 116 | # but it does matter for the case where workers have all files 117 | # (e.g. 
when using assign_all_files_per_worker) 118 | # the same seed is used for all workers would cause duplicated examples returned by different workers 119 | self.data_iterator.shuffle(seed=self.global_dataloader_worker_id) 120 | if self.should_shuffle_rows 121 | else self.data_iterator 122 | ) 123 | self.data_iterator.should_shuffle_rows = self.should_shuffle_rows 124 | # We provide the flexibility to iterate per row, if per row preprocessing is needed, or per batch. 125 | self.dataset_to_iterate = ( 126 | self.data_iterator.iterrows() 127 | if self.dataset_config.iterate_per_row 128 | else self.data_iterator.iter_batches(self.batch_size) 129 | ) 130 | 131 | command_line_logger.debug( 132 | f"GLOBAL ID {self.global_dataloader_worker_id} GPU Worker: {self.global_worker_id}/{self.total_workers} with {len(self.data_iterator.list_of_file_paths)} files\ 133 | First five files are: {self.data_iterator.list_of_file_paths[:5]}" 134 | ) 135 | 136 | def __iter__(self): 137 | if self.dataset_to_iterate is None: 138 | # If it has not been set up, it means it is a forkserver worker. We need to set it up. 139 | self.setup() 140 | # If the dataset is for training, we want to keep iterating over the dataset infinitely. 141 | # On a streaming dataset, we will always be on Epoch 0. 142 | finished_iteration = False 143 | while not finished_iteration: 144 | 145 | for row_or_batch in self.dataset_to_iterate: 146 | for ( 147 | preprocessing_function 148 | ) in self.dataset_config.preprocessing_functions: 149 | row_or_batch = preprocessing_function( 150 | row_or_batch, dataset_config=self.dataset_config 151 | ) 152 | if row_or_batch is None: 153 | break 154 | if row_or_batch: 155 | yield row_or_batch 156 | # if the dataset is not for training, we stop the loop. Otherwise, we continue. 157 | finished_iteration = not self.is_for_training 158 | if not finished_iteration: 159 | self.setup() 160 | # We reset the dataset to iterate to None, so that it is set up again in the next iteration. 161 | # This is required for validation when persitent_workers = True. 162 | self.dataset_to_iterate = None 163 | return None -------------------------------------------------------------------------------- /configs/experiment/tiger_inference_flat.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data_dir: ??? 3 | semantic_id_path: ??? 4 | ckpt_path: ??? 5 | num_hierarchies: ??? 
6 | seed: 42 7 | sequence_length: 120 8 | 9 | model: 10 | huggingface_model: 11 | _target_: transformers.T5EncoderModel 12 | config: 13 | _target_: transformers.T5Config 14 | vocab_size: 256 15 | d_model: 128 16 | num_heads: 6 17 | dropout_rate: 0.15 18 | d_ff: 1024 19 | d_kv: 64 20 | num_layers: 4 21 | _target_: src.models.modules.semantic_id.tiger_generation_model.SemanticIDEncoderDecoder 22 | feature_to_model_input_map: 23 | sequence_data: input_ids 24 | user_id: user_id 25 | postprocessor: null 26 | aggregator: null 27 | loss_function: null 28 | optimizer: null 29 | scheduler: null 30 | evaluator: null 31 | weight_tying: true 32 | compile: false 33 | decoder: 34 | _target_: transformers.models.t5.modeling_t5.T5Stack 35 | config: 36 | _target_: transformers.models.t5.configuration_t5.T5Config 37 | vocab_size: ${model.huggingface_model.config.vocab_size} 38 | d_model: ${model.huggingface_model.config.d_model} 39 | num_heads: ${model.huggingface_model.config.num_heads} 40 | dropout_rate: 0.15 41 | d_ff: ${model.huggingface_model.config.d_ff} 42 | d_kv: ${model.huggingface_model.config.d_kv} 43 | num_layers: 4 44 | is_decoder: true 45 | is_encoder_decoder: false 46 | embed_tokens: 47 | _target_: torch.nn.Embedding 48 | num_embeddings: ${model.huggingface_model.config.vocab_size} 49 | embedding_dim: ${model.huggingface_model.config.d_model} 50 | num_hierarchies: ${num_hierarchies} 51 | num_user_bins: null 52 | codebooks: ${data_loading.predict_dataloader_config.dataloader.dataset_config.semantic_id_map.sequence_data} 53 | mlp_layers: 2 54 | top_k_for_generation: 10 55 | task_name: inference 56 | id: ${now:%Y-%m-%d}/${now:%H-%M-%S} 57 | tags: 58 | - amazon-p5-gr-train 59 | callbacks: 60 | bq_writer: null 61 | pickle_writer: 62 | _target_: src.utils.inference_utils.LocalPickleWriter 63 | output_dir: ${paths.output_dir}/pickle 64 | flush_frequency: 100000 65 | write_interval: batch 66 | should_merge_files_on_main: true 67 | prediction_key_name: user_id 68 | prediction_name: semantic_ids 69 | paths: 70 | root_dir: . 71 | data_dir: ${data_dir} 72 | log_dir: ${paths.root_dir}/logs 73 | output_dir: ${hydra:runtime.output_dir} 74 | work_dir: ${hydra:runtime.cwd} 75 | profile_dir: ${hydra:run.dir}/profile_output 76 | metadata_dir: ${paths.output_dir}/metadata 77 | logger: 78 | csv: 79 | _target_: lightning.pytorch.loggers.csv_logs.CSVLogger 80 | save_dir: ${paths.output_dir} 81 | name: csv/ 82 | prefix: '' 83 | trainer: 84 | _target_: lightning.pytorch.trainer.Trainer 85 | default_root_dir: ${paths.output_dir} 86 | min_steps: 1 87 | max_steps: 80000 88 | max_epochs: 10 89 | accelerator: gpu 90 | devices: -1 91 | num_nodes: 1 92 | precision: 32-true 93 | log_every_n_steps: 2500 94 | val_check_interval: 5000 95 | deterministic: false 96 | accumulate_grad_batches: 1 97 | profiler: 98 | _target_: lightning.pytorch.profilers.PassThroughProfiler 99 | strategy: ddp 100 | sync_batchnorm: true 101 | data_loading: 102 | features_config: 103 | features: 104 | - name: sequence_data 105 | num_placeholder_tokens: 0 106 | num_placeholder_tokens_sparse_ids: 2 107 | semantic_ids: ??? 
108 | is_item_ids: true 109 | type: 110 | _target_: torch.__dict__.get 111 | _args_: 112 | - int32 113 | - name: embedding 114 | type: 115 | _target_: torch.__dict__.get 116 | _args_: 117 | - float32 118 | - name: text 119 | type: 120 | _target_: torch.__dict__.get 121 | _args_: 122 | - bytes 123 | - name: user_id 124 | is_item_ids: true 125 | type: 126 | _target_: torch.__dict__.get 127 | _args_: 128 | - int32 129 | dataset_config: 130 | dataset: 131 | _target_: src.data.loading.components.interfaces.SemanticIDDatasetConfig 132 | user_id_field: user_id 133 | min_sequence_length: 1 134 | iterate_per_row: true 135 | keep_user_id: true 136 | features_to_consider: ${extract_fields_from_list_of_dicts:${data_loading.features_config.features}, 137 | "name", False, "is_item_ids", "True"} 138 | num_placeholder_tokens_map: ${create_map_from_list_of_dicts:${data_loading.features_config.features}, 139 | "name", "num_placeholder_tokens"} 140 | semantic_id_map: ${create_map_from_list_of_dicts:${data_loading.features_config.features}, 141 | "name", "semantic_ids"} 142 | data_iterator: 143 | _target_: src.data.loading.components.iterators.TFRecordIterator 144 | preprocessing_functions: 145 | - _target_: src.data.loading.components.pre_processing.filter_features_to_consider 146 | _partial_: true 147 | - _target_: src.data.loading.components.pre_processing.convert_to_dense_numpy_array 148 | _partial_: true 149 | - _target_: src.data.loading.components.pre_processing.convert_fields_to_tensors 150 | _partial_: true 151 | - _target_: src.data.loading.components.pre_processing.map_sparse_id_to_semantic_id 152 | _partial_: true 153 | features_to_apply: 154 | - sequence_data 155 | num_hierarchies: ${model.num_hierarchies} 156 | predict_dataloader_config: 157 | dataloader: 158 | _target_: src.data.loading.components.interfaces.SequenceDataloaderConfig 159 | dataset_class: 160 | _target_: src.data.loading.components.dataloading.UnboundedSequenceIterable 161 | _partial_: true 162 | data_folder: ${paths.data_dir}/testing 163 | should_shuffle_rows: false 164 | labels: ~ 165 | batch_size_per_device: 32 166 | num_workers: 8 167 | timeout: 60 168 | assign_files_by_size: true 169 | oov_token: null 170 | masking_token: 1 171 | sequence_length: ${sequence_length} 172 | padding_token: -1 173 | drop_last: false 174 | persistent_workers: false 175 | collate_fn: 176 | _target_: src.data.loading.components.collate_functions.collate_fn_inference_for_sequence 177 | _partial_: true 178 | sequence_length: ${data_loading.predict_dataloader_config.dataloader.sequence_length} 179 | padding_token: ${data_loading.predict_dataloader_config.dataloader.padding_token} 180 | id_field_name: ${data_loading.predict_dataloader_config.dataloader.dataset_config.user_id_field} 181 | dataset_config: 182 | _target_: src.data.loading.components.interfaces.SemanticIDDatasetConfig 183 | user_id_field: user_id 184 | min_sequence_length: 1 185 | iterate_per_row: true 186 | keep_user_id: true 187 | features_to_consider: ${extract_fields_from_list_of_dicts:${data_loading.features_config.features}, 188 | "name", False, "is_item_ids", "True"} 189 | num_placeholder_tokens_map: ${create_map_from_list_of_dicts:${data_loading.features_config.features}, 190 | "name", "num_placeholder_tokens"} 191 | semantic_id_map: 192 | sequence_data: 193 | _target_: torch.load 194 | _args_: 195 | - _target_: src.utils.file_utils.open_local_or_remote 196 | file_path: ${semantic_id_path} 197 | mode: rb 198 | data_iterator: 199 | _target_: 
src.data.loading.components.iterators.TFRecordIterator 200 | preprocessing_functions: 201 | - _target_: src.data.loading.components.pre_processing.filter_features_to_consider 202 | _partial_: true 203 | - _target_: src.data.loading.components.pre_processing.convert_to_dense_numpy_array 204 | _partial_: true 205 | - _target_: src.data.loading.components.pre_processing.convert_fields_to_tensors 206 | _partial_: true 207 | - _target_: src.data.loading.components.pre_processing.map_sparse_id_to_semantic_id 208 | _partial_: true 209 | features_to_apply: 210 | - sequence_data 211 | num_hierarchies: ${model.num_hierarchies} 212 | pin_memory: false 213 | datamodule: 214 | _target_: src.data.loading.datamodules.sequence_datamodule.SequenceDataModule 215 | predict_dataloader_config: ${..predict_dataloader_config.dataloader} 216 | extras: 217 | ignore_warnings: false 218 | enforce_tags: true 219 | print_config_warnings: true 220 | print_config: true 221 | 222 | -------------------------------------------------------------------------------- /src/utils/launcher_utils.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from dataclasses import dataclass 3 | from typing import List 4 | 5 | import hydra 6 | import lightning as L 7 | from lightning import Callback, LightningDataModule, LightningModule, Trainer 8 | from lightning.pytorch.callbacks import ModelCheckpoint, ModelSummary 9 | from lightning.pytorch.loggers import Logger 10 | from omegaconf import DictConfig 11 | 12 | from src.utils import ( 13 | RankedLogger, 14 | instantiate_callbacks, 15 | instantiate_loggers, 16 | log_hyperparameters, 17 | ) 18 | from src.utils.file_utils import ( 19 | get_last_modified_file, 20 | has_no_extension, 21 | list_subfolders, 22 | ) 23 | from src.utils.logging_utils import finalize_loggers 24 | from src.utils.restart_job_utils import get_attribute_from_metadata_file 25 | from src.utils.utils import has_class_object_inside_list 26 | 27 | command_line_logger = RankedLogger(__name__, rank_zero_only=True) 28 | 29 | 30 | @dataclass 31 | class PipelineModules: 32 | cfg: DictConfig 33 | datamodule: LightningDataModule 34 | model: LightningModule 35 | # We use the plural form to match the names used by lightning 36 | callbacks: List[Callback] 37 | loggers: List[Logger] 38 | trainer: Trainer 39 | 40 | 41 | def update_cfg_with_most_recent_checkpoint_path(cfg: DictConfig) -> str: 42 | """ 43 | Updates the configuration with the most recent checkpoint path if the job is a retry, a checkpoint callback exists, 44 | and a checkpoint file is found. 45 | 46 | This function is useful for resuming training from the most recent checkpoint when a job is restarted. 47 | It checks if the current run is part of a retry (using restart metadata), and if so, it looks for the 48 | most recently modified checkpoint file in the checkpoint directory to use instead of the initial checkpoint. 49 | 50 | Args: 51 | cfg (DictConfig): The configuration dictionary containing training parameters. 52 | 53 | Returns: 54 | DictConfig: The updated configuration dictionary with the most recent checkpoint path. 55 | """ 56 | 57 | ckpt_path = cfg.get("ckpt_path", None) 58 | 59 | if ( 60 | ckpt_path is not None 61 | and has_no_extension(ckpt_path) 62 | and cfg.get("should_retrieve_latest_ckpt_path", False) 63 | ): 64 | # If a path to a folder is passed, we assume it contains folders with versions of checkpoints. 65 | # We expect those folders to be named using a timestamp. 
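        # For example (hypothetical layout): ckpt_path="gs://bucket/ckpts" containing
        #   gs://bucket/ckpts/2024-01-01_00-00-00/ and gs://bucket/ckpts/2024-02-01_00-00-00/
        # sorts to the 2024-02-01 run, and the most recently modified *.ckpt inside that
        # folder becomes the checkpoint to resume from.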
66 | checkpoint_folders = list_subfolders(ckpt_path) 67 | if len(checkpoint_folders) > 0: 68 | # We sort them in reverse order to get the most recent one. 69 | checkpoint_folders.sort(reverse=True) 70 | # We take the first one, which is the most recent one. 71 | latest_ckpt_folder = checkpoint_folders[0] 72 | last_modified = get_last_modified_file( 73 | folder_path=latest_ckpt_folder, suffix="*.ckpt" 74 | ) 75 | if len(last_modified) > 0: 76 | ckpt_path = last_modified 77 | command_line_logger.info( 78 | f"Found most recent checkpoint path: {ckpt_path}. Starting job from this checkpoint." 79 | ) 80 | 81 | # If there is a checkpoint callback running, and the restart_metadata file shows we are not on the first run, 82 | # we check if there are checkpoints in the checkpoint folder and restart from there instead of the initial checkpoint. 83 | if ( 84 | cfg.get("callbacks") # has callbacks 85 | and cfg.callbacks.get("model_checkpoint") # has checkpoint callback 86 | and cfg.callbacks.get("restart_job") # has retry callback 87 | and get_attribute_from_metadata_file( 88 | f"{cfg.callbacks.restart_job.metadata_dir}/restart_metadata.json", 89 | "current_run", 90 | ) 91 | > 0 # current run is part of a retry 92 | ): 93 | checkpoint_folder = cfg.callbacks.model_checkpoint.dirpath 94 | # We check if there are files with the extension .ckpt in the checkpoint folder. If so, we get the latest one. 95 | last_modified = get_last_modified_file( 96 | folder_path=checkpoint_folder, suffix="*.ckpt" 97 | ) 98 | if len(last_modified) > 0: 99 | ckpt_path = last_modified 100 | command_line_logger.info( 101 | f"Found most recent checkpoint path: {ckpt_path}. Starting job from this checkpoint." 102 | ) 103 | 104 | cfg.ckpt_path = ckpt_path 105 | return cfg 106 | 107 | 108 | def initialize_pipeline_modules( 109 | cfg: DictConfig, 110 | ) -> PipelineModules: 111 | """ 112 | Initialize and instantiate various objects required for running pipelines. 113 | 114 | Args: 115 | cfg (DictConfig): Configuration object containing parameters for data, model, callbacks, logger, and trainer. 116 | 117 | Returns: 118 | PipelineModules: A dataclass containing the instantiated objects. 119 | """ 120 | # set seed for random number generators in pytorch, numpy and python.random 121 | if cfg.get("seed"): 122 | L.seed_everything(cfg.seed, workers=True) 123 | 124 | command_line_logger.info( 125 | f"Instantiating datamodule <{cfg.data_loading.datamodule._target_}>" 126 | ) 127 | datamodule: LightningDataModule = hydra.utils.instantiate( 128 | cfg.data_loading.datamodule 129 | ) 130 | 131 | command_line_logger.info(f"Instantiating model <{cfg.model._target_}>") 132 | model: LightningModule = hydra.utils.instantiate(cfg.model) 133 | 134 | command_line_logger.info("Instantiating callbacks...") 135 | callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks")) 136 | 137 | command_line_logger.info("Instantiating loggers...") 138 | loggers: List[Logger] = instantiate_loggers(cfg.get("logger")) 139 | 140 | command_line_logger.info(f"Instantiating trainer <{cfg.trainer._target_}>") 141 | 142 | cfg = update_cfg_with_most_recent_checkpoint_path(cfg) 143 | 144 | trainer: Trainer = hydra.utils.instantiate( 145 | cfg.trainer, 146 | callbacks=callbacks, 147 | # The default behavior for lightning it to set `enable_checkpointing` and 148 | # `enable_model_summary` to True, which might be misleading when we are trying to 149 | # debug. 
We change the default to False, but this can be overriden by either 150 | # setting the parameters in the config file or passing the callbacks as part 151 | # of the callbacks yaml. 152 | enable_checkpointing=cfg.trainer.get( 153 | "enable_checkpointing", 154 | has_class_object_inside_list(callbacks, ModelCheckpoint), 155 | ), 156 | enable_model_summary=cfg.trainer.get( 157 | "enable_model_summary", 158 | has_class_object_inside_list(callbacks, ModelSummary), 159 | ), 160 | logger=loggers, 161 | ) 162 | 163 | pipeline_modules = PipelineModules( 164 | cfg=cfg, 165 | datamodule=datamodule, 166 | model=model, 167 | callbacks=callbacks, 168 | loggers=loggers, 169 | trainer=trainer, 170 | ) 171 | 172 | return pipeline_modules 173 | 174 | 175 | @contextmanager 176 | def pipeline_launcher(cfg: DictConfig): 177 | """ 178 | Launches the pipeline with the given configuration and logger. 179 | Args: 180 | cfg (DictConfig): Configuration object containing pipeline settings. 181 | log (RankedLogger): Logger object for logging information. 182 | Yields: 183 | PipelineModules: A dataclass containing the instantiated objects. 184 | Raises: 185 | Exception: Propagates any exception that occurs during pipeline initialization. 186 | Notes: 187 | - If the configuration contains a logger, hyperparameters will be logged. 188 | - Ensures that loggers are finalized and profiler output is saved even if the task fails. 189 | """ 190 | 191 | try: 192 | pipeline_modules: PipelineModules = initialize_pipeline_modules(cfg) 193 | # Log hyperparameters if loggers are present 194 | if len(pipeline_modules.loggers) > 0: 195 | command_line_logger.info("Logging hyperparameters!") 196 | log_hyperparameters(cfg, pipeline_modules.model, pipeline_modules.trainer) 197 | yield pipeline_modules 198 | except Exception as ex: 199 | raise ex 200 | finally: 201 | # We add the try catch to make sure the loggers are finalized even if the task fails. 202 | finalize_loggers(pipeline_modules.trainer) 203 | -------------------------------------------------------------------------------- /src/utils/tensor_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | 5 | from src.utils.file_utils import open_local_or_remote 6 | 7 | 8 | def locations_to_index_tuple(locations: torch.Tensor, num_dims: int = 2) -> Tuple: 9 | """ 10 | Convert a tensor of locations to a tuple of index tensors for advanced indexing. 11 | 12 | Args: 13 | locations (torch.Tensor): A tensor of shape `[L, D]` where `L` is the number of 14 | locations and `D >= num_dims`. 15 | num_dims (int): The number of dimensions to extract. The first num_dims columns of 16 | the locations tensor are used. We explicitly specify this to make the 17 | function call traceable. 18 | 19 | Returns: 20 | Tuple: A tuple of `num_dims` tensors, each of shape `[L]` representing the 21 | indices for one dimension. 
22 | 23 | Example: 24 | >>> locations = torch.tensor([[0, 10], [1, 20], [2, 5]]) 25 | >>> locations_to_index_tuple(locations, num_dims=2) 26 | (tensor([0, 1, 2]), tensor([10, 20, 5])) 27 | 28 | >>> locations = torch.tensor([[0, 10], [1, 20], [2, 5]]) 29 | >>> locations_to_index_tuple(locations, num_dims=1) 30 | (tensor([0, 1, 2])) 31 | """ 32 | return tuple(locations[:, i] for i in range(num_dims)) 33 | 34 | 35 | def extract_locations( 36 | data: torch.tensor, locations: torch.tensor, num_dims: int = 2 37 | ) -> torch.tensor: 38 | """ 39 | Extracts the elements from a tensor at the specified indices. 40 | 41 | Args: 42 | data (torch.tensor): The input tensor of N dimensions from which to extract elements. 43 | locations (torch.tensor): Tensor of shape [L, D] where L is the number of 44 | elements where each D dimensional row reprecents the first D dimensions 45 | of the data tensor to extract. 46 | num_dims (int): The number of dimensions to extract. The first num_dims columns of 47 | the locations tensors are used. We need to specify to make this function call traceable. 48 | 49 | Returns: 50 | torch.tensor: A tensor of shape [L,...] with total N-num_dims+1 dimensions 51 | containing the extracted elements. 52 | 53 | Example: 54 | >>> data = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) 55 | >>> locations = torch.tensor([[0, 1], [1, 2]]) 56 | >>> extract_locations(data, locations, num_dims=2) 57 | tensor([2, 6]) # (Index 0,1 gives 2 (First row, second column); Index 1,2 gives 6 (Second row, third column)) 58 | 59 | >>> locations = torch.tensor([[0, 1], [2, 0]]) 60 | >>> extract_locations(data, locations, num_dims=1) 61 | tensor([[1, 2, 3], [7, 8, 9]]) 62 | # (num_dims = 1 implies we are extracting based on the first dimension only. 63 | # Thus, we get the first row (from [0,1] as 1 is ignored) and the third row 64 | # (from [2,0] as 0 is ignored) of the data tensor. 65 | """ 66 | 67 | # Separate the locations for each of the first D dimensions 68 | index_tuple = locations_to_index_tuple(locations=locations, num_dims=num_dims) 69 | 70 | # Use indexing with a tuple of index tensors 71 | extracted_values = data[index_tuple] 72 | 73 | return extracted_values 74 | 75 | 76 | def merge_list_of_keyed_tensors_to_single_tensor( 77 | data: list[dict[str, torch.Tensor]], 78 | index_key: str, 79 | value_key: str, 80 | ) -> torch.Tensor: 81 | """ 82 | Converts a list of dictionaries of id to tensors into a single tensor by squeezing 83 | the tensors along the specified index key. 84 | e.g., 85 | data = [ 86 | [ 87 | {'user_id': 123, 88 | {'semantic_id': torch.tensor([21, 32, 124]), 89 | other features.....} 90 | ], 91 | [ 92 | {'user_id': 456, 93 | {'semantic_id': torch.tensor([11, 22, 33]), 94 | other features.....} 95 | ] 96 | output: 97 | tensor([..., 98 | [21, 32, 124], # row 123 99 | ..., 100 | [11, 22, 33], # row 456 101 | ... 102 | ]) 103 | 104 | Args: 105 | data (list[dict[str, torch.Tensor]]): A list of dictionaries where each dictionary 106 | contains an index key and a value key. 107 | index_key (str): The key in the dictionary that contains the index for each row. 108 | value_key (str): The key in the dictionary that contains the tensor to be merged. 
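# Hypothetical usage of this helper (values are made up for illustration):

import torch
from src.utils.tensor_utils import merge_list_of_keyed_tensors_to_single_tensor

rows = [
    {"user_id": 2, "semantic_id": torch.tensor([11, 22, 33])},
    {"user_id": 0, "semantic_id": torch.tensor([21, 32, 124])},
    {"user_id": 1, "semantic_id": torch.tensor([5, 6, 7])},
]
merged = merge_list_of_keyed_tensors_to_single_tensor(rows, "user_id", "semantic_id")
# merged is a (3, 3) tensor whose row i holds the semantic_id row for user_id == i.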
109 | """ 110 | batch_size = len(data) 111 | dimensions = torch.tensor(data[0][value_key]).size() 112 | output_tensor = torch.zeros((batch_size, *dimensions)) 113 | for row in data: 114 | index = row[index_key] 115 | value = row[value_key] 116 | if index < batch_size: 117 | output_tensor[index] = torch.tensor(value) 118 | else: 119 | raise IndexError( 120 | f"Index {index} out of bounds for batch size {batch_size}." 121 | ) 122 | return output_tensor 123 | 124 | 125 | def deduplicate_rows_in_tensor( 126 | file_path: Optional[str] = None, return_tensor: bool = False 127 | ) -> Union[None, torch.Tensor]: 128 | """ 129 | Identifies and de-duplicate repeated rows in a PyTorch tensor. 130 | Rows that are not duplicated will have a new column with value 0, 131 | while rows that are duplicated will have a new column indicating the number of duplicates from 1 to N-1 132 | where N is the number of duplicates for that row. 133 | 134 | Args: 135 | file_path: Optional; Path to a file containing the tensor data. 136 | return_tensor: If True, returns the modified tensor; otherwise, saves it to the file. 137 | Returns: 138 | If return_tensor is True, returns the modified tensor with a new column indicating 139 | the number of duplicates for each row. If False, saves the modified tensor to the file. 140 | """ 141 | if not file_path.endswith(".pt"): 142 | return None 143 | data = torch.load(open_local_or_remote(file_path, mode="rb")) 144 | assert len(data.size()) == 2, "Input data must be a 2D PyTorch tensor." 145 | 146 | # Use torch.unique to get unique rows and their inverse indices 147 | unique_rows, inverse_indices, counts = torch.unique( 148 | data, dim=0, return_inverse=True, return_counts=True 149 | ) 150 | 151 | output_indices = torch.zeros_like(inverse_indices) 152 | 153 | # Find indices where counts > 1 (meaning duplicates exist) 154 | duplicate_indices = torch.where(counts > 1)[0] 155 | 156 | for i in range(len(duplicate_indices)): 157 | # Calculate number of collisions 158 | num_of_collisions = counts[duplicate_indices[i]] 159 | 160 | # Gather the indices where the collisions occur 161 | indices_to_change = torch.where(inverse_indices == duplicate_indices[i])[0] 162 | 163 | # Create a range based on the number of collision, starting from 1 164 | range_to_add = torch.arange(1, num_of_collisions + 1) 165 | 166 | # Scatter to those specific indices 167 | output_indices = output_indices.scatter(0, indices_to_change, range_to_add) 168 | 169 | # Concatenate the duplicate indicator column to the original data 170 | result = torch.cat((data, output_indices.unsqueeze(1)), dim=1).long() 171 | if return_tensor: 172 | return result 173 | else: 174 | # Save the result to a file 175 | torch.save(result, file_path) 176 | return None 177 | 178 | 179 | def transpose_tensor_from_file( 180 | file_path: Optional[str] = None, 181 | return_tensor: bool = False, 182 | dim1: int = -2, 183 | dim2: int = -1, 184 | ) -> Union[None, torch.Tensor]: 185 | """ 186 | Transposes a PyTorch tensor from a file accoridng to designated dimensions. 187 | 188 | Args: 189 | file_path: Optional; Path to a file containing the tensor data. 190 | return_tensor: If True, returns the modified tensor; otherwise, saves it to the file. 191 | dim1: The first dimension to transpose (default: -2). 192 | dim2: The second dimension to transpose (default: -1). 193 | Returns: 194 | If return_tensor is True, returns the modified tensor. If False, saves the modified tensor to the file. 
195 | """ 196 | if not file_path.endswith(".pt"): 197 | return None 198 | data = torch.load(open_local_or_remote(file_path, mode="rb")) 199 | 200 | # Transpose the tensor 201 | result = data.transpose(dim1, dim2) 202 | if return_tensor: 203 | return result 204 | else: 205 | # Save the result to a file 206 | torch.save(result, file_path) 207 | return None 208 | -------------------------------------------------------------------------------- /src/models/modules/base_module.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Union 2 | 3 | import torch 4 | import transformers 5 | from lightning import LightningModule 6 | from torchmetrics import MeanMetric 7 | from torchmetrics.aggregation import BaseAggregator 8 | 9 | from src.components.eval_metrics import Evaluator 10 | from src.utils.pylogger import RankedLogger 11 | 12 | command_line_logger = RankedLogger(__name__, rank_zero_only=True) 13 | 14 | 15 | class BaseModule(LightningModule): 16 | def __init__( 17 | self, 18 | model: Union[torch.nn.Module, transformers.PreTrainedModel], 19 | optimizer: torch.optim.Optimizer, 20 | scheduler: Optional[torch.optim.lr_scheduler._LRScheduler], 21 | loss_function: torch.nn.Module, 22 | evaluator: Evaluator, 23 | training_loop_function: callable = None, 24 | ) -> None: 25 | """ 26 | Args: 27 | model: The model to train. 28 | optimizer: The optimizer to use for the model. 29 | scheduler: The scheduler to use for the model. 30 | loss_function: The loss function to use for the model. 31 | evaluator: The evaluator to use for the model. 32 | training_loop_function: The training loop function to use for the model, in case it is different than the default one. 33 | """ 34 | super().__init__() 35 | 36 | self.model = model 37 | self.optimizer = optimizer 38 | self.scheduler = scheduler 39 | self.loss_function = loss_function 40 | self.evaluator = evaluator 41 | self.training_loop_function = training_loop_function 42 | # We use setters to set the prediction key and name. 43 | self._prediction_key_name = None 44 | self._prediction_name = None 45 | 46 | if self.training_loop_function is not None: 47 | self.automatic_optimization = False 48 | 49 | if self.evaluator: # For inference, evaluator is not set. 50 | for metric_name, metric_object in self.evaluator.metrics.items(): 51 | setattr(self, metric_name, metric_object) 52 | 53 | # for averaging loss across batches 54 | self.train_loss = MeanMetric() 55 | self.val_loss = MeanMetric() 56 | self.test_loss = MeanMetric() 57 | 58 | @property 59 | def prediction_key_name(self) -> Optional[str]: 60 | return self._prediction_key_name 61 | 62 | @prediction_key_name.setter 63 | def prediction_key_name(self, value: str) -> None: 64 | command_line_logger.debug(f"Setting prediction_key_name to {value}") 65 | self._prediction_key_name = value 66 | 67 | @property 68 | def prediction_name(self) -> Optional[str]: 69 | return self._prediction_name 70 | 71 | @prediction_name.setter 72 | def prediction_name(self, value: str) -> None: 73 | command_line_logger.debug(f"Setting prediction_name to {value}") 74 | self._prediction_name = value 75 | 76 | def forward( 77 | self, 78 | **kwargs: Dict[str, torch.Tensor], 79 | ) -> torch.Tensor: 80 | raise NotImplementedError( 81 | "Inherit from this class and implement the forward method." 
82 | ) 83 | 84 | def model_step( 85 | self, 86 | model_input: Any, 87 | label_data: Optional[Any] = None, 88 | ): 89 | raise NotImplementedError( 90 | "Inherit from this class and implement the model_step method." 91 | ) 92 | 93 | def on_train_start(self) -> None: 94 | """Lightning hook that is called when training begins.""" 95 | # by default lightning executes validation step sanity checks before training starts, 96 | # so it's worth to make sure validation metrics don't store results from these checks 97 | self.val_loss.reset() 98 | self.evaluator.reset() 99 | self.train_loss.reset() 100 | self.test_loss.reset() 101 | 102 | def on_validation_epoch_start(self) -> None: 103 | """Lightning hook that is called when a validation epoch starts.""" 104 | self.val_loss.reset() 105 | self.evaluator.reset() 106 | 107 | def on_test_epoch_start(self): 108 | self.test_loss.reset() 109 | self.evaluator.reset() 110 | 111 | def on_validation_epoch_end(self) -> None: 112 | "Lightning hook that is called when a validation epoch ends." 113 | self.log("val/loss", self.val_loss, sync_dist=False, prog_bar=True, logger=True) 114 | self.log_metrics("val") 115 | 116 | def on_test_epoch_end(self) -> None: 117 | self.log( 118 | "test/loss", self.test_loss, sync_dist=False, prog_bar=True, logger=True 119 | ) 120 | self.log_metrics("test") 121 | 122 | def on_exception(self, exception): 123 | self.trainer.should_stop = True # stop all workers 124 | self.trainer.logger.finalize(status="failure") 125 | 126 | def log_metrics( 127 | self, 128 | prefix: str, 129 | on_step=False, 130 | on_epoch=True, 131 | # We use sync_dist=False by default because, if using retrieval metrics, those are already synchronized. Change if using 132 | # different metrics than the default ones. 133 | sync_dist=False, 134 | logger=True, 135 | prog_bar=False, 136 | call_compute=False, 137 | ) -> Dict[str, Any]: 138 | 139 | metrics_dict = { 140 | f"{prefix}/{metric_name}": metric_object.compute() 141 | if call_compute 142 | else metric_object 143 | for metric_name, metric_object in self.evaluator.metrics.items() 144 | } 145 | 146 | self.log_dict( 147 | metrics_dict, 148 | on_step=on_step, 149 | on_epoch=on_epoch, 150 | sync_dist=sync_dist, 151 | logger=logger, 152 | prog_bar=prog_bar, 153 | ) 154 | 155 | def setup(self, stage: str) -> None: 156 | """Lightning hook that is called at the beginning of fit (train + validate), validate, 157 | test, or predict. 158 | 159 | This is a good hook when you need to build models dynamically or adjust something about 160 | them. This hook is called on every process when using DDP. 161 | 162 | :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. 163 | """ 164 | # if self.hparams.compile and stage == "fit": 165 | # self.net = torch.compile(self.net) 166 | pass 167 | 168 | def configure_optimizers(self) -> Dict[str, Any]: 169 | """Choose what optimizers and learning-rate schedulers to use in your optimization. 170 | Normally you'd need one. But in the case of GANs or similar you might have multiple. 171 | 172 | Examples: 173 | https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers 174 | 175 | :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training. 
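
        When a scheduler is configured, the returned dict mirrors the implementation below:

            {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "val/loss",
                    "interval": "step",
                    "frequency": 1,
                },
            }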
176 | """ 177 | optimizer = self.optimizer(params=self.trainer.model.parameters()) 178 | if self.scheduler is not None: 179 | scheduler = self.scheduler(optimizer=optimizer) 180 | return { 181 | "optimizer": optimizer, 182 | "lr_scheduler": { 183 | "scheduler": scheduler, 184 | "monitor": "val/loss", 185 | "interval": "step", 186 | "frequency": 1, 187 | }, 188 | } 189 | return {"optimizer": optimizer} 190 | 191 | def eval_step(self, batch: Any, loss_to_aggregate: BaseAggregator): 192 | raise NotImplementedError("eval_step method must be implemented.") 193 | 194 | def validation_step( 195 | self, 196 | batch: Any, 197 | batch_idx: int, 198 | ) -> None: 199 | """Perform a single validation step on a batch of data from the validation set. 200 | 201 | :param batch: A batch of data of data (tuple) where first object is a SequentialModelInputData object 202 | and second is a SequentialModuleLabelData object. 203 | """ 204 | self.eval_step(batch, self.val_loss) 205 | 206 | def test_step( 207 | self, 208 | batch: Any, 209 | batch_idx: int, 210 | ) -> None: 211 | """Perform a single test step on a batch of data from the test set. 212 | 213 | :param batch: A batch of data of data (tuple) where first object is a SequentialModelInputData object 214 | and second is a SequentialModuleLabelData object. 215 | """ 216 | self.eval_step(batch, self.test_loss) 217 | -------------------------------------------------------------------------------- /configs/experiment/rkmeans_train_flat.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data_dir: ??? 3 | embedding_path: ??? 4 | embedding_dim: ??? 5 | num_hierarchies: ??? 6 | codebook_width: ??? 7 | 8 | task_name: train 9 | id: ${now:%Y-%m-%d}/${now:%H-%M-%S} 10 | tags: 11 | - amazon-assign-ids-train 12 | train: true 13 | test: false 14 | ckpt_path: null 15 | seed: 42 16 | data_loading: 17 | features_config: 18 | features: 19 | - name: id 20 | num_placeholder_tokens: 0 21 | is_item_ids: true 22 | embeddings: 23 | _target_: torch.load 24 | _args_: 25 | - _target_: src.utils.file_utils.open_local_or_remote 26 | file_path: ${embedding_path} 27 | mode: rb 28 | type: 29 | _target_: torch.__dict__.get 30 | _args_: 31 | - int32 32 | dataset_config: 33 | dataset: 34 | _target_: src.data.loading.components.interfaces.ItemDatasetConfig 35 | item_id_field: id 36 | keep_item_id: true 37 | iterate_per_row: true 38 | data_iterator: 39 | _target_: src.data.loading.components.iterators.TFRecordIterator 40 | features_to_consider: ${extract_fields_from_list_of_dicts:${data_loading.features_config.features}, 41 | "name", False, "is_item_ids", "True"} 42 | embedding_map: 43 | id: ${data_loading.features_config.features[0].embeddings} 44 | num_placeholder_tokens_map: ${create_map_from_list_of_dicts:${data_loading.features_config.features}, 45 | "name", "num_placeholder_tokens"} 46 | preprocessing_functions: 47 | - _target_: src.data.loading.components.pre_processing.filter_features_to_consider 48 | _partial_: true 49 | - _target_: src.data.loading.components.pre_processing.convert_to_dense_numpy_array 50 | _partial_: true 51 | - _target_: src.data.loading.components.pre_processing.convert_fields_to_tensors 52 | _partial_: true 53 | - _target_: src.data.loading.components.pre_processing.map_sparse_id_to_embedding 54 | _partial_: true 55 | sparse_id_field: id 56 | embedding_field_to_add: embedding 57 | field_type_map: ${create_map_from_list_of_dicts:${data_loading.features_config.features}, 58 | "name", "type"} 59 | 
datamodule: 60 | _target_: src.data.loading.datamodules.sequence_datamodule.ItemDataModule 61 | train_dataloader_config: 62 | _target_: src.data.loading.components.interfaces.ItemDataloaderConfig 63 | dataset_class: 64 | _target_: src.data.loading.components.dataloading.UnboundedSequenceIterable 65 | _partial_: true 66 | data_folder: ${paths.data_dir}/items 67 | should_shuffle_rows: true 68 | batch_size_per_device: 2048 69 | num_workers: 12 70 | assign_files_by_size: false 71 | timeout: 60 72 | drop_last: false 73 | pin_memory: true 74 | persistent_workers: true 75 | collate_fn: 76 | _target_: src.data.loading.components.collate_functions.collate_fn_items 77 | _partial_: true 78 | item_id_field: ${data_loading.dataset_config.dataset.item_id_field} 79 | feature_to_input_name: 80 | id: item_ids 81 | text: text_tokens 82 | text_mask: text_mask 83 | embedding: input_embedding 84 | dataset_config: ${data_loading.dataset_config.dataset} 85 | limit_files: null 86 | assign_all_files_per_worker: true 87 | val_dataloader_config: 88 | _target_: src.data.loading.components.interfaces.ItemDataloaderConfig 89 | dataset_class: 90 | _target_: src.data.loading.components.dataloading.UnboundedSequenceIterable 91 | _partial_: true 92 | data_folder: ${paths.data_dir}/items 93 | should_shuffle_rows: false 94 | batch_size_per_device: 256 95 | num_workers: 2 96 | assign_files_by_size: true 97 | timeout: 60 98 | drop_last: false 99 | pin_memory: false 100 | persistent_workers: true 101 | collate_fn: 102 | _target_: src.data.loading.components.collate_functions.collate_fn_items 103 | _partial_: true 104 | item_id_field: ${data_loading.dataset_config.dataset.item_id_field} 105 | feature_to_input_name: 106 | id: item_ids 107 | text: text_tokens 108 | text_mask: text_mask 109 | embedding: input_embedding 110 | dataset_config: ${data_loading.dataset_config.dataset} 111 | limit_files: null 112 | test_dataloader_config: 113 | _target_: src.data.loading.components.interfaces.ItemDataloaderConfig 114 | dataset_class: 115 | _target_: src.data.loading.components.dataloading.UnboundedSequenceIterable 116 | _partial_: true 117 | data_folder: ${paths.data_dir}/items 118 | should_shuffle_rows: false 119 | batch_size_per_device: 256 120 | num_workers: 2 121 | assign_files_by_size: true 122 | timeout: 60 123 | drop_last: false 124 | pin_memory: false 125 | persistent_workers: true 126 | collate_fn: 127 | _target_: src.data.loading.components.collate_functions.collate_fn_items 128 | _partial_: true 129 | item_id_field: ${data_loading.dataset_config.dataset.item_id_field} 130 | feature_to_input_name: 131 | id: item_ids 132 | text: text_tokens 133 | text_mask: text_mask 134 | embedding: input_embedding 135 | dataset_config: ${data_loading.dataset_config.dataset} 136 | limit_files: null 137 | model: 138 | _target_: src.modules.clustering.residual_quantization.ResidualQuantization 139 | track_residuals: true 140 | verbose: true 141 | train_layer_wise: true 142 | normalize_residuals: true 143 | input_dim: ${embedding_dim} 144 | n_layers: ${num_hierarchies} 145 | init_buffer_size: 3072 146 | quantization_layer: 147 | _target_: src.models.modules.clustering.mini_batch_kmeans.MiniBatchKMeans 148 | n_clusters: ${codebook_width} 149 | n_features: ${model.input_dim} 150 | distance_function: 151 | _target_: src.components.distance_functions.SquaredEuclideanDistance 152 | initializer: 153 | _target_: src.components.clustering_initializers.KMeansPlusPlusInitInitializer 154 | n_clusters: ${model.quantization_layer.n_clusters} 155 | 
distance_function: ${model.quantization_layer.distance_function} 156 | initialize_on_cpu: false 157 | init_buffer_size: ${model.init_buffer_size} 158 | optimizer: null 159 | optimizer: ${optim.optimizer} 160 | scheduler: null 161 | quantization_layer_list: null 162 | training_loop_function: 163 | _target_: src.components.training_loop_functions.scale_loss_by_world_size_for_initialization_training_loop 164 | _partial_: true 165 | callbacks: 166 | model_checkpoint: 167 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 168 | dirpath: ${paths.output_dir}/checkpoints 169 | filename: checkpoint_{epoch:03d}_{step:06d} 170 | monitor: train/loss 171 | verbose: true 172 | save_last: null 173 | save_top_k: 1 174 | mode: min 175 | auto_insert_metric_name: false 176 | save_weights_only: false 177 | every_n_train_steps: ${trainer.max_steps} 178 | train_time_interval: null 179 | every_n_epochs: null 180 | save_on_train_epoch_end: null 181 | early_stopping: null 182 | model_summary: 183 | _target_: lightning.pytorch.callbacks.RichModelSummary 184 | max_depth: -1 185 | logger: 186 | csv: 187 | _target_: lightning.pytorch.loggers.csv_logs.CSVLogger 188 | save_dir: ${paths.output_dir} 189 | name: csv/ 190 | prefix: '' 191 | trainer: 192 | _target_: lightning.pytorch.trainer.Trainer 193 | default_root_dir: ${paths.output_dir} 194 | min_steps: 1 195 | max_steps: 30 196 | max_epochs: 10 197 | accelerator: gpu 198 | devices: -1 199 | num_nodes: 1 200 | precision: bf16-mixed 201 | log_every_n_steps: 10 202 | val_check_interval: 100000000 203 | deterministic: false 204 | accumulate_grad_batches: 1 205 | profiler: 206 | _target_: lightning.pytorch.profilers.PassThroughProfiler 207 | strategy: ddp_find_unused_parameters_true 208 | sync_batchnorm: true 209 | num_sanity_val_steps: 0 210 | paths: 211 | root_dir: . 212 | data_dir: ${data_dir} 213 | log_dir: ${paths.root_dir}/logs 214 | output_dir: ${hydra:runtime.output_dir} 215 | work_dir: ${hydra:runtime.cwd} 216 | profile_dir: ${hydra:run.dir}/profile_output 217 | metadata_dir: ${paths.output_dir}/metadata 218 | extras: 219 | ignore_warnings: false 220 | enforce_tags: true 221 | print_config_warnings: true 222 | print_config: true 223 | optim: 224 | optimizer: 225 | _target_: torch.optim.SGD 226 | _partial_: true 227 | lr: 0.5 228 | scheduler: null 229 | eval: 230 | evaluator: 231 | placeholder_token_buffer: 0 232 | -------------------------------------------------------------------------------- /configs/experiment/rvq_train_flat.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | data_dir: ??? 3 | embedding_path: ??? 4 | embedding_dim: ??? 5 | num_hierarchies: ??? 6 | codebook_width: ??? 
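# The ??? values above are mandatory and must be provided at launch time. A hypothetical
# Hydra invocation (paths and sizes shown purely for illustration) could look like:
#   python src/train.py experiment=rvq_train_flat data_dir=/path/to/data \
#     embedding_path=/path/to/item_embeddings.pt embedding_dim=768 \
#     num_hierarchies=3 codebook_width=256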
7 | 8 | task_name: train 9 | id: ${now:%Y-%m-%d}/${now:%H-%M-%S} 10 | tags: 11 | - amazon-assign-ids-train 12 | train: true 13 | test: false 14 | ckpt_path: null 15 | seed: 42 16 | data_loading: 17 | features_config: 18 | features: 19 | - name: id 20 | num_placeholder_tokens: 0 21 | is_item_ids: true 22 | embeddings: 23 | _target_: torch.load 24 | _args_: 25 | - _target_: src.utils.file_utils.open_local_or_remote 26 | file_path: ${embedding_path} 27 | mode: rb 28 | type: 29 | _target_: torch.__dict__.get 30 | _args_: 31 | - int32 32 | dataset_config: 33 | dataset: 34 | _target_: src.data.loading.components.interfaces.ItemDatasetConfig 35 | item_id_field: id 36 | keep_item_id: true 37 | iterate_per_row: true 38 | data_iterator: 39 | _target_: src.data.loading.components.iterators.TFRecordIterator 40 | features_to_consider: ${extract_fields_from_list_of_dicts:${data_loading.features_config.features}, 41 | "name", False, "is_item_ids", "True"} 42 | embedding_map: 43 | id: ${data_loading.features_config.features[0].embeddings} 44 | num_placeholder_tokens_map: ${create_map_from_list_of_dicts:${data_loading.features_config.features}, 45 | "name", "num_placeholder_tokens"} 46 | preprocessing_functions: 47 | - _target_: src.data.loading.components.pre_processing.filter_features_to_consider 48 | _partial_: true 49 | - _target_: src.data.loading.components.pre_processing.convert_to_dense_numpy_array 50 | _partial_: true 51 | - _target_: src.data.loading.components.pre_processing.convert_fields_to_tensors 52 | _partial_: true 53 | - _target_: src.data.loading.components.pre_processing.map_sparse_id_to_embedding 54 | _partial_: true 55 | sparse_id_field: id 56 | embedding_field_to_add: embedding 57 | field_type_map: ${create_map_from_list_of_dicts:${data_loading.features_config.features}, 58 | "name", "type"} 59 | datamodule: 60 | _target_: src.data.loading.datamodules.sequence_datamodule.ItemDataModule 61 | train_dataloader_config: 62 | _target_: src.data.loading.components.interfaces.ItemDataloaderConfig 63 | dataset_class: 64 | _target_: src.data.loading.components.dataloading.UnboundedSequenceIterable 65 | _partial_: true 66 | data_folder: ${paths.data_dir}/items 67 | should_shuffle_rows: true 68 | batch_size_per_device: 2048 69 | num_workers: 12 70 | assign_files_by_size: false 71 | timeout: 60 72 | drop_last: false 73 | pin_memory: true 74 | persistent_workers: true 75 | collate_fn: 76 | _target_: src.data.loading.components.collate_functions.collate_fn_items 77 | _partial_: true 78 | item_id_field: ${data_loading.dataset_config.dataset.item_id_field} 79 | feature_to_input_name: 80 | id: item_ids 81 | text: text_tokens 82 | text_mask: text_mask 83 | embedding: input_embedding 84 | dataset_config: ${data_loading.dataset_config.dataset} 85 | limit_files: null 86 | assign_all_files_per_worker: true 87 | val_dataloader_config: 88 | _target_: src.data.loading.components.interfaces.ItemDataloaderConfig 89 | dataset_class: 90 | _target_: src.data.loading.components.dataloading.UnboundedSequenceIterable 91 | _partial_: true 92 | data_folder: ${paths.data_dir}/items 93 | should_shuffle_rows: false 94 | batch_size_per_device: 256 95 | num_workers: 2 96 | assign_files_by_size: true 97 | timeout: 60 98 | drop_last: false 99 | pin_memory: false 100 | persistent_workers: true 101 | collate_fn: 102 | _target_: src.data.loading.components.collate_functions.collate_fn_items 103 | _partial_: true 104 | item_id_field: ${data_loading.dataset_config.dataset.item_id_field} 105 | feature_to_input_name: 106 | id: 
item_ids 107 | text: text_tokens 108 | text_mask: text_mask 109 | embedding: input_embedding 110 | dataset_config: ${data_loading.dataset_config.dataset} 111 | limit_files: null 112 | test_dataloader_config: 113 | _target_: src.data.loading.components.interfaces.ItemDataloaderConfig 114 | dataset_class: 115 | _target_: src.data.loading.components.dataloading.UnboundedSequenceIterable 116 | _partial_: true 117 | data_folder: ${paths.data_dir}/items 118 | should_shuffle_rows: false 119 | batch_size_per_device: 256 120 | num_workers: 2 121 | assign_files_by_size: true 122 | timeout: 60 123 | drop_last: false 124 | pin_memory: false 125 | persistent_workers: true 126 | collate_fn: 127 | _target_: src.data.loading.components.collate_functions.collate_fn_items 128 | _partial_: true 129 | item_id_field: ${data_loading.dataset_config.dataset.item_id_field} 130 | feature_to_input_name: 131 | id: item_ids 132 | text: text_tokens 133 | text_mask: text_mask 134 | embedding: input_embedding 135 | dataset_config: ${data_loading.dataset_config.dataset} 136 | limit_files: null 137 | model: 138 | _target_: src.modules.clustering.residual_quantization.ResidualQuantization 139 | track_residuals: true 140 | verbose: true 141 | train_layer_wise: true 142 | normalize_residuals: true 143 | input_dim: ${embedding_dim} 144 | n_layers: ${num_hierarchies} 145 | init_buffer_size: 3072 146 | quantization_layer: 147 | _target_: src.modules.clustering.vector_quantization.VectorQuantization 148 | n_clusters: ${codebook_width} 149 | n_features: ${model.input_dim} 150 | distance_function: 151 | _target_: src.components.distance_functions.SquaredEuclideanDistance 152 | quantization_strategy: 153 | _target_: src.components.quantization_strategies.STEQuantization 154 | distance_function: ${model.quantization_layer.distance_function} 155 | compute_reconstruction_loss_embeddings: false 156 | initializer: 157 | _target_: src.components.clustering_initializers.KMeansPlusPlusInitInitializer 158 | n_clusters: ${model.quantization_layer.n_clusters} 159 | initialize_on_cpu: false 160 | distance_function: ${model.quantization_layer.distance_function} 161 | init_buffer_size: ${model.init_buffer_size} 162 | loss_function: 163 | _target_: src.components.loss_functions.BetaQuantizationLoss 164 | beta: 0.25 165 | optimizer: null 166 | optimizer: ${optim.optimizer} 167 | scheduler: null 168 | quantization_layer_list: null 169 | training_loop_function: 170 | _target_: src.components.training_loop_functions.scale_loss_by_world_size_for_initialization_training_loop 171 | _partial_: true 172 | callbacks: 173 | model_checkpoint: 174 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 175 | dirpath: ${paths.output_dir}/checkpoints 176 | filename: checkpoint_{epoch:03d}_{step:06d} 177 | monitor: train/loss 178 | verbose: true 179 | save_last: null 180 | save_top_k: 1 181 | mode: min 182 | auto_insert_metric_name: false 183 | save_weights_only: false 184 | every_n_train_steps: ${trainer.max_steps} 185 | train_time_interval: null 186 | every_n_epochs: null 187 | save_on_train_epoch_end: null 188 | early_stopping: null 189 | model_summary: 190 | _target_: lightning.pytorch.callbacks.RichModelSummary 191 | max_depth: -1 192 | logger: 193 | csv: 194 | _target_: lightning.pytorch.loggers.csv_logs.CSVLogger 195 | save_dir: ${paths.output_dir} 196 | name: csv/ 197 | prefix: '' 198 | trainer: 199 | _target_: lightning.pytorch.trainer.Trainer 200 | default_root_dir: ${paths.output_dir} 201 | min_steps: 1 202 | max_steps: 3000 203 | max_epochs: 10 204 | 
accelerator: gpu 205 | devices: -1 206 | num_nodes: 1 207 | precision: bf16-mixed 208 | log_every_n_steps: 10 209 | val_check_interval: 100000000 210 | deterministic: false 211 | accumulate_grad_batches: 1 212 | profiler: 213 | _target_: lightning.pytorch.profilers.PassThroughProfiler 214 | strategy: ddp_find_unused_parameters_true 215 | sync_batchnorm: true 216 | num_sanity_val_steps: 0 217 | paths: 218 | root_dir: . 219 | data_dir: ${data_dir} 220 | log_dir: ${paths.root_dir}/logs 221 | output_dir: ${hydra:runtime.output_dir} 222 | work_dir: ${hydra:runtime.cwd} 223 | profile_dir: ${hydra:run.dir}/profile_output 224 | metadata_dir: ${paths.output_dir}/metadata 225 | extras: 226 | ignore_warnings: false 227 | enforce_tags: true 228 | print_config_warnings: true 229 | print_config: true 230 | optim: 231 | optimizer: 232 | lr: 0.001 233 | eval: 234 | evaluator: 235 | placeholder_token_buffer: 0 236 | -------------------------------------------------------------------------------- /src/utils/restart_job_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | from dataclasses import dataclass, field 4 | from datetime import datetime 5 | from typing import Any, Callable, Dict, List, Optional, Set, TypeVar 6 | 7 | import psutil 8 | from lightning import Trainer 9 | from omegaconf import DictConfig 10 | 11 | from src.utils.file_utils import ( 12 | file_exists_local_or_remote, 13 | load_json, 14 | open_local_or_remote, 15 | ) 16 | from src.utils.pylogger import RankedLogger 17 | 18 | command_line_logger = RankedLogger(__name__, rank_zero_only=True) 19 | F = TypeVar("F", bound=Callable[..., Any]) 20 | 21 | import os 22 | import sys 23 | from datetime import datetime 24 | from typing import Any, Callable, TypeVar 25 | 26 | import torch 27 | import torch.distributed as dist 28 | from lightning import Trainer 29 | from lightning.pytorch.strategies.launchers import _SubprocessScriptLauncher 30 | from lightning.pytorch.trainer.connectors.signal_connector import _get_sigkill_signal 31 | from omegaconf import DictConfig 32 | 33 | from src.utils.pylogger import RankedLogger 34 | 35 | 36 | @dataclass 37 | class JobCheckpointMetadata: 38 | """ 39 | A class that stores metadata for job checkpointing and restarts. 40 | 41 | Attributes: 42 | start_time (str): ISO formatted timestamp of when the job started. Defaults to the current time. 43 | restarts (List[Dict[str, Any]]): List of dictionaries containing information about previous restarts. 44 | current_run (int): Counter for the current run number. Starts at 0 and increments with each restart. 45 | used_ports (List[str]): List of ports that have been used by previous runs. 46 | world_size (int): The total number of processes participating in the distributed job. 47 | node_rank (int): The rank of this node in the distributed job. 48 | master_addr (str): The address of the master node for distributed training. 49 | original_args (List[str]): The original command-line arguments used to start the job. 50 | 51 | Methods: 52 | to_dict(): Converts the metadata object to a dictionary representation. 
53 | """ 54 | 55 | start_time: str = field(default_factory=lambda: datetime.now().isoformat()) 56 | restarts: List[Dict[str, Any]] = field(default_factory=list) 57 | current_run: int = 0 58 | used_ports: List[str] = field(default_factory=list) 59 | world_size: int = 0 60 | node_rank: int = 0 61 | master_addr: str = "" 62 | original_args: List[str] = field(default_factory=lambda: sys.argv) 63 | 64 | def to_dict(self) -> Dict[str, Any]: 65 | return { 66 | "start_time": self.start_time, 67 | "restarts": self.restarts, 68 | "current_run": self.current_run, 69 | "used_ports": self.used_ports, 70 | "world_size": self.world_size, 71 | "node_rank": self.node_rank, 72 | "master_addr": self.master_addr, 73 | "original_args": self.original_args, 74 | } 75 | 76 | 77 | @dataclass 78 | class RestartMetadata: 79 | """ 80 | A class to store metadata about a job restart. 81 | 82 | Attributes: 83 | time (str): The time when the job was restarted. 84 | exception (str): The exception that caused the job to be restarted. 85 | run_number (int): The number of times the job has been run. 86 | 87 | Methods: 88 | to_dict(): Converts the metadata to a dictionary. 89 | """ 90 | 91 | time: str 92 | exception: str 93 | run_number: int 94 | 95 | def to_dict(self) -> Dict[str, Any]: 96 | return { 97 | "time": self.time, 98 | "exception": self.exception, 99 | "run_number": self.run_number, 100 | } 101 | 102 | 103 | def load_metadata_from_local_or_remote(metadata_path: str) -> JobCheckpointMetadata: 104 | """ 105 | Loads a JobCheckpointMetadata from a local or remote filepath, if available. 106 | 107 | This function attempts to load and deserialize a JobCheckpointMetadata from the given path. 108 | If the file doesn't exist, it returns an empty JobCheckpointMetadata object and logs a warning. 109 | 110 | Args: 111 | metadata_path (str): Path to the metadata file, either local or remote. 112 | 113 | Returns: 114 | JobCheckpointMetadata: The loaded metadata if file exists, otherwise an empty metadata object. 115 | """ 116 | command_line_logger.info(f"Trying to load metadata from {metadata_path}") 117 | if file_exists_local_or_remote(metadata_path): 118 | metadata_dict = load_json(metadata_path) 119 | command_line_logger.info(f"Metadata loaded successfully from {metadata_path}") 120 | return JobCheckpointMetadata(**metadata_dict) 121 | else: 122 | command_line_logger.warning( 123 | f"Metadata file not found at {metadata_path}. Creating empty metadata." 124 | ) 125 | return JobCheckpointMetadata() 126 | 127 | 128 | def save_metadata_to_local_or_remote( 129 | metadata: JobCheckpointMetadata, metadata_path: str 130 | ) -> None: 131 | """ 132 | Save job checkpoint metadata to a local or remote file. 133 | This function serializes the metadata object to JSON and writes it to the specified path. 134 | It handles both local filesystem paths and remote paths (e.g., Google Cloud Storage 'gs://' paths) 135 | using the open_local_or_remote utility function. 136 | Args: 137 | metadata (JobCheckpointMetadata): The metadata object to save 138 | metadata_path (str): Path where to save the metadata, can be a local path or a remote path (e.g., gs://...) 139 | Returns: 140 | None 141 | Logs: 142 | - Info message indicating where metadata is being saved 143 | """ 144 | 145 | command_line_logger.info( 146 | f"Saving metadata to {metadata_path}. 
{metadata.to_dict()}" 147 | ) 148 | 149 | # Convert metadata to JSON string 150 | json_content = json.dumps(metadata.to_dict(), indent=2) 151 | 152 | # Use the open_local_or_remote function which should handle gs:// paths 153 | with open_local_or_remote(metadata_path, "w") as f: 154 | f.write(json_content) 155 | 156 | 157 | def get_attribute_from_metadata_file(metadata_path: str, attribute: str) -> Any: 158 | """ 159 | Extracts a specified attribute from a metadata file. 160 | 161 | This function loads the metadata from either a local or remote file and returns 162 | the value of the specified attribute from the loaded metadata object. 163 | 164 | Args: 165 | metadata_path (str): The path to the metadata file. This can be a local file path 166 | or a remote URL. 167 | attribute (str): The name of the attribute to retrieve from the metadata object. 168 | 169 | Returns: 170 | Any: The value of the specified attribute from the metadata object. 171 | 172 | Raises: 173 | AttributeError: If the specified attribute doesn't exist in the metadata object. 174 | 175 | Note: 176 | This function relies on the `load_metadata_from_local_or_remote` function to handle 177 | the loading of the metadata file. 178 | """ 179 | metadata = load_metadata_from_local_or_remote(metadata_path) 180 | attribute_value = getattr(metadata, attribute, None) 181 | command_line_logger.info( 182 | f"Retrieved {attribute}: {attribute_value} from metadata {metadata_path}" 183 | ) 184 | return attribute_value 185 | 186 | 187 | def _is_process_running(proc: psutil.Process) -> bool: 188 | """ 189 | Check if a process is still running. 190 | 191 | This function polls the process to check its status. It returns True if the process is still running (i.e., its return code is None), 192 | and False if the process has terminated (i.e., its return code is not None). 193 | 194 | Args: 195 | proc (psutil.Process): The process to check. 196 | 197 | Returns: 198 | bool: True if the process is still running, False otherwise. 199 | """ 200 | # Check if the process is running by checking the return code. 201 | # If the process is still running, poll() will return None. 202 | # If the process has finished, poll() will return the return code. 
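    # Note: poll() and returncode come from subprocess-style handles (e.g. psutil.Popen, which
    # wraps subprocess.Popen); a bare psutil.Process does not expose them and offers is_running() instead.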
203 | proc.poll() 204 | return proc.returncode is None 205 | 206 | 207 | def clean_up_resources( 208 | trainer: Optional[Trainer] = None, exception: Optional[Exception] = None 209 | ) -> None: 210 | """Clean up distributed processes and CUDA resources.""" 211 | if dist.is_initialized(): 212 | command_line_logger.info("Cleaning up distributed process group") 213 | dist.destroy_process_group() 214 | 215 | if torch.cuda.is_available(): 216 | command_line_logger.info("Clearing CUDA cache") 217 | torch.cuda.empty_cache() 218 | 219 | if trainer is not None: 220 | command_line_logger.info("Tearing down trainer") 221 | trainer.strategy.on_exception(exception) 222 | launcher = trainer.strategy.launcher if trainer.strategy is not None else None 223 | trainer._teardown() 224 | if isinstance(launcher, _SubprocessScriptLauncher): 225 | launcher.kill(_get_sigkill_signal()) 226 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from time import sleep 3 | from typing import Any, Dict, List, Optional, Tuple, Union 4 | 5 | import torch 6 | from omegaconf import DictConfig 7 | from transformers.cache_utils import DynamicCache 8 | 9 | from src.utils import pylogger, rich_utils 10 | from functools import partial 11 | from tokenizers.processors import TemplateProcessing 12 | from src.data.loading.components.interfaces import TokenizerConfig 13 | 14 | log = pylogger.RankedLogger(__name__, rank_zero_only=True) 15 | 16 | 17 | def print_warnings_for_missing_configs(cfg: DictConfig) -> None: 18 | _DEFAULT_CONFIGS = [ 19 | "data_loading", 20 | "model", 21 | "loss", 22 | "optim", 23 | "eval", 24 | ] 25 | has_warnings = False 26 | for config in _DEFAULT_CONFIGS: 27 | if not cfg.get(config): 28 | log.warning( 29 | f"Config {config} was not found in the config tree. Make sure this is expected." 30 | ) 31 | has_warnings = True 32 | if has_warnings: 33 | sleep(3) # wait for 3 seconds to let the user read the warning 34 | 35 | 36 | def extras(cfg: DictConfig) -> None: 37 | """Applies optional utilities before the task is started. 38 | 39 | Utilities: 40 | - Ignoring python warnings 41 | - Setting tags from command line 42 | - Rich config printing 43 | 44 | :param cfg: A DictConfig object containing the config tree. 45 | """ 46 | # return if no `extras` config 47 | if not cfg.get("extras"): 48 | log.warning("Extras config not found! ") 49 | return 50 | 51 | # disable python warnings 52 | if cfg.extras.get("ignore_warnings"): 53 | log.info("Disabling python warnings! ") 54 | warnings.filterwarnings("ignore") 55 | 56 | # prompt user to input tags from command line if none are provided in the config 57 | if cfg.extras.get("enforce_tags"): 58 | log.info("Enforcing tags! ") 59 | rich_utils.enforce_tags(cfg, save_to_file=True) 60 | 61 | if cfg.extras.get("print_config_warnings"): 62 | print_warnings_for_missing_configs(cfg) 63 | 64 | # pretty print config tree using Rich library 65 | if cfg.extras.get("print_config"): 66 | log.info("Printing config tree with Rich! ") 67 | rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) 68 | 69 | 70 | def delete_module(module: torch.nn.Module, module_name: str) -> None: 71 | """Recursively delete a submodule from a module. 72 | 73 | :param module: the parent module that we want the submodule to be removed from. 74 | :param module_name: the name of the submodule to be removed. 75 | :return: None. 
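
    Example (illustrative)::

        model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
        delete_module(model, "0")  # removes the Linear submodule registered under the name "0"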
76 | """ 77 | if hasattr(module, module_name): 78 | delattr(module, module_name) 79 | 80 | for name, submodule in module.named_children(): 81 | delete_module(submodule, module_name) 82 | 83 | 84 | def find_module_shape( 85 | module: torch.nn.Module, module_name: str 86 | ) -> Optional[torch.Size]: 87 | """Recursively find a submodule in a module and return its shape. 88 | 89 | :param module: the parent module that we want the submodule to be removed from. 90 | :param module_name: the name of the submodule to be removed. 91 | :return: the shape of the module if it exists. 92 | """ 93 | if hasattr(module, module_name): 94 | return getattr(module, module_name).weight.shape 95 | 96 | for name, submodule in module.named_children(): 97 | shape = find_module_shape(submodule, module_name) 98 | if shape: 99 | return shape 100 | return None 101 | 102 | 103 | def reset_parameters(module: torch.nn.Module) -> None: 104 | """Reset the parameters of a given module. 105 | 106 | :param module: the module whose parameters will be reset. 107 | :return: None. 108 | """ 109 | 110 | if hasattr(module, "reset_parameters"): 111 | module.reset_parameters() 112 | else: 113 | for layer in module.children(): 114 | reset_parameters(layer) 115 | 116 | 117 | def get_var_if_not_none(value: Optional[Any], default_value: Any) -> Any: 118 | """ 119 | :return: value if value is not None, else default_value 120 | Note that when value is: 121 | Boolean: False is not None 122 | Int: 0 is not None 123 | List: An empty list is not None 124 | Tensor: A tensor with all zeros is not None 125 | """ 126 | return value if value is not None else default_value 127 | 128 | 129 | def get_class_name_str(class_definition: Any) -> str: 130 | """ 131 | Args: 132 | class_definition: The class definition. 133 | 134 | Returns: 135 | The fully qualified name of the given class. 136 | """ 137 | return ".".join([class_definition.__module__, class_definition.__name__]) 138 | 139 | 140 | def has_class_object_inside_list(obj_list: list, class_type: Any) -> bool: 141 | """ 142 | Args: 143 | obj_list: List of objects. 144 | class_type: The class type to check. 145 | 146 | Returns: 147 | True if the list contains an object of the given class type. 148 | """ 149 | return any(isinstance(obj, class_type) for obj in obj_list) 150 | 151 | 152 | def convert_legacy_kv_cache_to_dynamic( 153 | kv_cache: Union[DynamicCache, Tuple[torch.Tensor]] 154 | ) -> DynamicCache: 155 | """ 156 | Converts a legacy key-value cache (Tuple of tensors) to a dynamic cache. 157 | 158 | Args: 159 | kv_cache: The key-value cache. 160 | 161 | Returns: 162 | The dynamic cache. 163 | """ 164 | if isinstance(kv_cache, DynamicCache): 165 | return kv_cache 166 | 167 | return DynamicCache.from_legacy_cache(kv_cache) 168 | 169 | 170 | def get_parent_module_and_attr( 171 | model: torch.nn.Module, module_name: str 172 | ) -> Tuple[torch.nn.Module, str]: 173 | """ 174 | Get the parent module and attribute name for a given module name. 175 | 176 | Args: 177 | model (torch.nn.Module): The model containing the module. 178 | module_name (str): The full name of the module. 179 | 180 | Returns: 181 | Tuple[torch.nn.Module, str]: The parent module and the attribute name. 182 | """ 183 | parts = module_name.split(".") 184 | parent = model 185 | for part in parts[:-1]: 186 | parent = getattr(parent, part) 187 | return parent, parts[-1] 188 | 189 | 190 | def lightning_precision_to_dtype(precision: str) -> torch.dtype: 191 | """ 192 | Convert Lightning precision identifier to PyTorch dtype. 
193 | 
194 |     Args:
195 |         precision (str): The precision identifier used in Lightning.
196 |             Expected values include '32', '32-true', '16', '16-mixed', 'bf16', '64', 'half'.
197 | 
198 |     Returns:
199 |         torch.dtype: The corresponding PyTorch dtype.
200 | 
201 |     Raises:
202 |         ValueError: If an unsupported precision type is provided.
203 |     """
204 |     # Mapping from Lightning precision identifiers to PyTorch dtypes
205 |     precision_map = {
206 |         "32": torch.float32,  # Single precision (float32)
207 |         "32-true": torch.float32,  # Also maps to float32, useful for clarity when specifying defaults
208 |         "64": torch.float64,  # Double precision
209 |         "16": torch.float16,  # Half precision
210 |         "16-mixed": torch.float16,  # Mixed precision typically uses torch.float16
211 |         "bf16": torch.bfloat16,  # BFloat16 precision
212 |         "half": torch.float16,  # Alias for half precision
213 |     }
214 | 
215 |     if precision in precision_map:
216 |         return precision_map[precision]
217 |     else:
218 |         raise ValueError(
219 |             f"Unsupported precision type: '{precision}'. "
220 |             "Supported precision types are: '32', '32-true', '64', '16', '16-mixed', 'bf16', 'half'."
221 |         )
222 | 
223 | def load_tokenize(config: TokenizerConfig) -> Any:
224 |     """Load tokenizer and return a partial function for tokenization."""
225 |     tokenizer = config.tokenizer
226 |     if hasattr(config, "special_tokens"):
227 |         tokenizer.add_special_tokens(config.special_tokens)
228 |     if config.postprocess_eos_token:
229 |         tokenizer._tokenizer.post_processor = TemplateProcessing(
230 |             single="$A " + tokenizer.eos_token,
231 |             special_tokens=[(tokenizer.eos_token, tokenizer.eos_token_id)],
232 |         )
233 |     tokenize = partial(
234 |         tokenizer.encode_plus,
235 |         max_length=config.max_length,
236 |         padding=config.padding,
237 |         truncation=config.truncation,
238 |         add_special_tokens=config.add_special_tokens,
239 |         return_tensors="pt",
240 |     )
241 |     return tokenize
242 | 
243 | def sample_gumbel(shape: Tuple, device: torch.device, eps=1e-20) -> torch.Tensor:
244 |     """Sample from Gumbel(0, 1)"""
245 |     U = torch.rand(shape, device=device)
246 |     return -torch.log(-torch.log(U + eps) + eps)
247 | 
248 | 
249 | def gumbel_softmax_sample(logits: torch.Tensor, temperature: float) -> torch.Tensor:
250 |     """Draw a sample from the Gumbel-Softmax distribution"""
251 |     y = logits + sample_gumbel(logits.shape, logits.device)
252 |     sample = torch.nn.functional.softmax(y / temperature, dim=-1)
253 |     return sample
254 | 
--------------------------------------------------------------------------------
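The two Gumbel helpers at the end of utils.py draw approximately one-hot, differentiable samples from categorical logits. A minimal usage sketch follows; it assumes the module is importable as src.utils.utils and is not a script shipped with the repository:

import torch

from src.utils.utils import gumbel_softmax_sample

logits = torch.randn(4, 8)                      # 4 categorical distributions over 8 classes
soft_sample = gumbel_softmax_sample(logits, temperature=0.5)
print(soft_sample.shape)                        # torch.Size([4, 8])
print(soft_sample.sum(dim=-1))                  # each row sums to ~1.0 (a relaxed one-hot vector)
hard_codes = soft_sample.argmax(dim=-1)         # discrete indices, if a hard choice is needed

Lower temperatures push each row closer to a one-hot vector while keeping the sample differentiable with respect to the logits.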